diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 77047636bb95..a655a650cb32 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -368,7 +368,7 @@ def parse_client_command(cmd: str) -> dict[str, Any]: # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", # we want to turn it into "8xGPUTYPE" df["GPU"] = df["GPU"].apply( - lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}" + lambda x: f"{len(x.splitlines())}x{x.splitlines()[0]}" ) # get markdown tables diff --git a/.buildkite/pyproject.toml b/.buildkite/pyproject.toml deleted file mode 100644 index d5cad1c73c6f..000000000000 --- a/.buildkite/pyproject.toml +++ /dev/null @@ -1,46 +0,0 @@ -# This local pyproject file is part of the migration from yapf to ruff format. -# It uses the same core rules as the main pyproject.toml file, but with the -# following differences: -# - ruff line length is overridden to 88 -# - deprecated typing ignores (UP006, UP035) have been removed - -[tool.ruff] -line-length = 88 - -[tool.ruff.lint.per-file-ignores] -"vllm/third_party/**" = ["ALL"] -"vllm/version.py" = ["F401"] -"vllm/_version.py" = ["ALL"] - -[tool.ruff.lint] -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # flake8-logging-format - "G", -] -ignore = [ - # star imports - "F405", "F403", - # lambda expression assignment - "E731", - # Loop control variable not used within loop body - "B007", - # f-string format - "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", -] - -[tool.ruff.format] -docstring-code-format = true diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9c200a577167..d11a43377548 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -477,6 +477,7 @@ steps: source_file_dependencies: - csrc/mamba/ - tests/kernels/mamba + - vllm/model_executor/layers/mamba/ops commands: - pytest -v -s kernels/mamba @@ -834,11 +835,11 @@ steps: - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py -- label: GPT-OSS Eval (Blackwell) +- label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 working_dir: "/vllm-workspace/" gpu: b200 - optional: true # disable while debugging + optional: true # run on nightlies source_file_dependencies: - tests/evals/gpt_oss - vllm/model_executor/models/gpt_oss.py @@ -865,6 +866,16 @@ steps: commands: - pytest -s -v tests/quantization/test_blackwell_moe.py +- label: Blackwell LM Eval Small Models + timeout_in_minutes: 75 + gpu: b200 + optional: true # run on nightlies + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1 + ##### 1 GPU test ##### ##### multi gpus test ##### diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0b9c054b968a..dbcad3aa308f 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -23,6 +23,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson # Any change to the VllmConfig changes can have a large user-facing impact, # so spam a lot of people /vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 
@ProExpertProg +/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345 # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 82844810a633..dca3089f496c 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: actions: write runs-on: ubuntu-latest steps: - - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0 + - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0 with: # Increasing this value ensures that changes to this workflow # propagate to all issues and PRs in days rather than months diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8ca414ee4269..95a3866e6bb8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,28 +6,16 @@ default_stages: - manual # Run in CI exclude: 'vllm/third_party/.*' repos: -- repo: https://github.com/google/yapf - rev: v0.43.0 - hooks: - - id: yapf - args: [--in-place, --verbose] - # Keep the same list from yapfignore here to avoid yapf failing without any inputs - exclude: '(.buildkite|benchmarks|build|examples)/.*' - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.7 + rev: v0.13.3 hooks: - - id: ruff + - id: ruff-check args: [--output-format, github, --fix] - id: ruff-format - files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos rev: v1.35.5 hooks: - id: typos -- repo: https://github.com/PyCQA/isort - rev: 6.0.1 - hooks: - - id: isort - repo: https://github.com/pre-commit/mirrors-clang-format rev: v20.1.3 hooks: diff --git a/benchmarks/benchmark_block_pool.py b/benchmarks/benchmark_block_pool.py index eae8d9927ea3..5434f8b6a4e4 100644 --- a/benchmarks/benchmark_block_pool.py +++ b/benchmarks/benchmark_block_pool.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc +from benchmark_utils import TimeCollector from tabulate import tabulate -from benchmark_utils import TimeCollector from vllm.utils import FlexibleArgumentParser from vllm.v1.core.block_pool import BlockPool diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index d4b83edbd940..626b150ee4ce 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -5,9 +5,9 @@ from unittest import mock import numpy as np +from benchmark_utils import TimeCollector from tabulate import tabulate -from benchmark_utils import TimeCollector from vllm.config import ( CacheConfig, DeviceConfig, @@ -164,7 +164,7 @@ def invoke_main() -> None: ) parser.add_argument( "--batched", action="store_true", help="consider time to prepare batch" - ) # noqa: E501 + ) parser.add_argument( "--num-iteration", type=int, diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index a0350625491f..58b9767d0939 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -37,14 +37,13 @@ import datasets import numpy as np import pandas as pd -from tqdm.asyncio import tqdm -from transformers import PreTrainedTokenizerBase - from backend_request_func import ( ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput, ) +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase try: from 
vllm.transformers_utils.tokenizer import get_tokenizer @@ -910,13 +909,13 @@ def create_argument_parser(): parser.add_argument( "--tokenizer", type=str, - help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", ) parser.add_argument( "--tokenizer-mode", type=str, default="auto", - help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", ) parser.add_argument( "--num-prompts", diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index 2010b8038563..ba31bc563829 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# fmt: off # ruff: noqa: E501 import time @@ -20,19 +19,21 @@ ) -def benchmark_shape(m: int, - n: int, - k: int, - warmup: int = 100, - repeat: int = 10000, - verbose: bool = False) -> dict: +def benchmark_shape( + m: int, + n: int, + k: int, + warmup: int = 100, + repeat: int = 10000, + verbose: bool = False, +) -> dict: """Benchmark all implementations for a specific (m, n, k) shape.""" if verbose: print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===") # Create test tensors - A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) # Reference result in BF16 torch.cuda.synchronize() @@ -49,34 +50,39 @@ def benchmark_shape(m: int, # Pre-quantize A for all implementations A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1]) A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - A, block_size[1], column_major_scales=True) + A, block_size[1], column_major_scales=True + ) # === DeepGEMM Implementation === def deepgemm_gemm(): - fp8_gemm_nt((A_deepgemm, A_scale_deepgemm), - (B_deepgemm, B_scale_deepgemm), - C_deepgemm) + fp8_gemm_nt( + (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm + ) return C_deepgemm # === vLLM Triton Implementation === def vllm_triton_gemm(): - return w8a8_triton_block_scaled_mm(A_vllm, - B_vllm, - A_scale_vllm, - B_scale_vllm, - block_size, - output_dtype=torch.bfloat16) + return w8a8_triton_block_scaled_mm( + A_vllm, + B_vllm, + A_scale_vllm, + B_scale_vllm, + block_size, + output_dtype=torch.bfloat16, + ) # === vLLM CUTLASS Implementation === def vllm_cutlass_gemm(): - return ops.cutlass_scaled_mm(A_vllm_cutlass, - B_vllm.T, - scale_a=A_scale_vllm_cutlass, - scale_b=B_scale_vllm.T, - out_dtype=torch.bfloat16) + return ops.cutlass_scaled_mm( + A_vllm_cutlass, + B_vllm.T, + scale_a=A_scale_vllm_cutlass, + scale_b=B_scale_vllm.T, + out_dtype=torch.bfloat16, + ) # Run correctness check first if verbose: @@ -93,26 +99,23 @@ def vllm_cutlass_gemm(): print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") print(f"vLLM 
Triton vs Reference difference: {vllm_triton_diff:.6f}") print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") - print("vLLM Triton vs DeepGEMM difference: " - f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") - print("vLLM CUTLASS vs DeepGEMM difference: " - f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") + print( + "vLLM Triton vs DeepGEMM difference: " + f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}" + ) + print( + "vLLM CUTLASS vs DeepGEMM difference: " + f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}" + ) # Benchmark implementations implementations = { "DeepGEMM": deepgemm_gemm, "vLLM Triton": vllm_triton_gemm, - "vLLM CUTLASS": vllm_cutlass_gemm + "vLLM CUTLASS": vllm_cutlass_gemm, } - benchmark_results = { - "shape": { - "m": m, - "n": n, - "k": k - }, - "implementations": {} - } + benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}} for name, func in implementations.items(): # Warmup @@ -140,38 +143,36 @@ def vllm_cutlass_gemm(): "tflops": tflops, "gb_s": gb_s, "diff": { - "DeepGEMM": - 0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), - "Reference": - deepgemm_diff if name == "DeepGEMM" else - (vllm_triton_diff - if name == "vLLM Triton" else vllm_cutlass_diff) - } + "DeepGEMM": 0.0 + if name == "DeepGEMM" + else calc_diff(func(), C_deepgemm), + "Reference": deepgemm_diff + if name == "DeepGEMM" + else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff), + }, } if verbose: - print( - f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s" - ) + print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s") # Calculate speedups baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] for name, data in benchmark_results["implementations"].items(): if name != "DeepGEMM": speedup = baseline / data["time_ms"] - benchmark_results["implementations"][name][ - "speedup_vs_deepgemm"] = speedup + benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup if verbose: - print(f"DeepGEMM is {1/speedup:.2f}x " - f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") + print( + f"DeepGEMM is {1 / speedup:.2f}x " + f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}" + ) - vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ - "time_ms"] - vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][ - "time_ms"] + vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"] + vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"] cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time - benchmark_results["implementations"]["vLLM CUTLASS"][ - "speedup_vs_triton"] = cutlass_vs_triton + benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = ( + cutlass_vs_triton + ) if verbose: print( f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " @@ -183,8 +184,7 @@ def vllm_cutlass_gemm(): def format_table_row(values, widths): """Format a row with specified column widths.""" - return "| " + " | ".join(f"{val:{w}}" - for val, w in zip(values, widths)) + " |" + return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |" def print_table(headers, rows, title=None): @@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False): for result in all_results: shape = result["shape"] impl_data = result["implementations"]["DeepGEMM"] - deepgemm_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", 
f"{impl_data['gb_s']:.1f}" - ]) + deepgemm_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + ] + ) - print_table(deepgemm_headers, - deepgemm_rows, - title="DeepGEMM Implementation:") + print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:") # Print vLLM Triton table - triton_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM" - ] + triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"] triton_rows = [] for result in all_results: shape = result["shape"] impl_data = result["implementations"]["vLLM Triton"] speedup = impl_data.get("speedup_vs_deepgemm", 1.0) - triton_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(speedup) - ]) + triton_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(speedup), + ] + ) - print_table(triton_headers, - triton_rows, - title="vLLM Triton Implementation:") + print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:") # Print vLLM CUTLASS table cutlass_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", - "vs Triton" + "m", + "n", + "k", + "Time (μs)", + "TFLOPS", + "GB/s", + "vs DeepGEMM", + "vs Triton", ] cutlass_rows = [] for result in all_results: @@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False): impl_data = result["implementations"]["vLLM CUTLASS"] vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) vs_triton = impl_data.get("speedup_vs_triton", 1.0) - cutlass_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(vs_deepgemm), - format_speedup(vs_triton) - ]) + cutlass_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(vs_deepgemm), + format_speedup(vs_triton), + ] + ) - print_table(cutlass_headers, - cutlass_rows, - title="vLLM CUTLASS Implementation:") + print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:") # Calculate and print averages print("\n===== AVERAGE PERFORMANCE =====") implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] avg_metrics = { - impl: { - "tflops": 0, - "gb_s": 0, - "time_ms": 0 - } - for impl in implementations + impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations } for result in all_results: @@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False): avg_tflops = avg_metrics[impl]["tflops"] / num_shapes avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes avg_time = avg_metrics[impl]["time_ms"] / num_shapes - avg_rows.append([ - impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" - ]) + avg_rows.append( + [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"] + ) print_table(avg_headers, avg_rows) @@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False): avg_speedups = { "DeepGEMM vs vLLM Triton": 0, "DeepGEMM vs vLLM CUTLASS": 0, - "vLLM CUTLASS vs vLLM Triton": 0 + "vLLM CUTLASS vs vLLM Triton": 0, } for result in all_results: deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] 
- vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ - "time_ms"] + vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"] - avg_speedups[ - "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time - avg_speedups[ - "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time - avg_speedups[ - "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time + avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time + avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time + avg_speedups["vLLM CUTLASS vs vLLM Triton"] += ( + vllm_triton_time / vllm_cutlass_time + ) print("\n===== AVERAGE SPEEDUPS =====") speedup_headers = ["Comparison", "Speedup"] @@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False): for result in all_results: for impl in implementations: - avg_diff[impl] += result["implementations"][impl]["diff"][ - "Reference"] + avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"] diff_headers = ["Implementation", "Avg Diff vs Reference"] diff_rows = [] diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml deleted file mode 100644 index 65b1e09a247e..000000000000 --- a/benchmarks/pyproject.toml +++ /dev/null @@ -1,49 +0,0 @@ -# This local pyproject file is part of the migration from yapf to ruff format. -# It uses the same core rules as the main pyproject.toml file, but with the -# following differences: -# - ruff line length is overridden to 88 -# - deprecated typing ignores (UP006, UP035) have been removed - -[tool.ruff] -line-length = 88 - -[tool.ruff.lint.per-file-ignores] -"vllm/third_party/**" = ["ALL"] -"vllm/version.py" = ["F401"] -"vllm/_version.py" = ["ALL"] - -[tool.ruff.lint] -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # flake8-logging-format - "G", -] -ignore = [ - # star imports - "F405", "F403", - # lambda expression assignment - "E731", - # Loop control variable not used within loop body - "B007", - # f-string format - "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", -] - -[tool.ruff.lint.isort] -known-first-party = ["vllm"] - -[tool.ruff.format] -docstring-code-format = true \ No newline at end of file diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index e6d0012c1a4b..c962564c8da0 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -213,6 +213,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON endif() set(ONEDNN_AARCH64_USE_ACL "ON") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") + add_compile_definitions(VLLM_USE_ACL) endif() set(ONEDNN_LIBRARY_TYPE "STATIC") @@ -226,7 +227,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") - set(ONEDNN_VERBOSE "OFF") + set(ONEDNN_VERBOSE "ON") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) FetchContent_MakeAvailable(oneDNN) diff --git a/cmake/hipify.py b/cmake/hipify.py index 55d378f5b111..8504f9defee9 100755 --- a/cmake/hipify.py +++ b/cmake/hipify.py @@ -16,7 +16,7 @@ from torch.utils.hipify.hipify_python import hipify -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() # Project directory where all the source + include files live. @@ -34,15 +34,14 @@ ) # Source files to convert. 
- parser.add_argument("sources", - help="Source files to hipify.", - nargs="*", - default=[]) + parser.add_argument( + "sources", help="Source files to hipify.", nargs="*", default=[] + ) args = parser.parse_args() # Limit include scope to project_dir only - includes = [os.path.join(args.project_dir, '*')] + includes = [os.path.join(args.project_dir, "*")] # Get absolute path for all source files. extra_files = [os.path.abspath(s) for s in args.sources] @@ -51,25 +50,31 @@ # The directory might already exist to hold object files so we ignore that. shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True) - hipify_result = hipify(project_directory=args.project_dir, - output_directory=args.output_dir, - header_include_dirs=[], - includes=includes, - extra_files=extra_files, - show_detailed=True, - is_pytorch_extension=True, - hipify_extra_files_only=True) + hipify_result = hipify( + project_directory=args.project_dir, + output_directory=args.output_dir, + header_include_dirs=[], + includes=includes, + extra_files=extra_files, + show_detailed=True, + is_pytorch_extension=True, + hipify_extra_files_only=True, + ) hipified_sources = [] for source in args.sources: s_abs = os.path.abspath(source) - hipified_s_abs = (hipify_result[s_abs].hipified_path if - (s_abs in hipify_result - and hipify_result[s_abs].hipified_path is not None) - else s_abs) + hipified_s_abs = ( + hipify_result[s_abs].hipified_path + if ( + s_abs in hipify_result + and hipify_result[s_abs].hipified_path is not None + ) + else s_abs + ) hipified_sources.append(hipified_s_abs) - assert (len(hipified_sources) == len(args.sources)) + assert len(hipified_sources) == len(args.sources) # Print hipified source files. print("\n".join(hipified_sources)) diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp index 6def0e061fa9..0f0cc34602b3 100644 --- a/csrc/cpu/dnnl_helper.cpp +++ b/csrc/cpu/dnnl_helper.cpp @@ -137,9 +137,8 @@ DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler( } void DNNLMatMulPrimitiveHandler::prepack_weight( - void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) { - dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, - {b_k_stride_, b_n_stride_}); + void* original_b_ptr, dnnl::memory::desc original_b_md, + dnnl::memory::desc b_target_mem_desc) { dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr); dnnl::memory packed_weight(b_target_mem_desc, default_engine()); { @@ -250,7 +249,9 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) if (a_qs_ == QuantizationStrategy::PER_TOKEN) { assert(!use_azp_); }; - prepack_weight(args.b_ptr, + dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, + {b_k_stride_, b_n_stride_}); + prepack_weight(args.b_ptr, original_b_md, create_primitive_desc( MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, .use_bias = false, @@ -412,12 +413,25 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args) assert(ab_type_ == dnnl::memory::data_type::f32 || ab_type_ == dnnl::memory::data_type::bf16 || ab_type_ == dnnl::memory::data_type::f16); - prepack_weight(args.b_ptr, + + dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, + {b_k_stride_, b_n_stride_}); + + prepack_weight(args.b_ptr, original_b_md, create_primitive_desc( - MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, - .a_m_stride = DNNL_RUNTIME_DIM_VAL, - .use_bias = false, - .bias_type = dnnl::memory::data_type::undef}, + MSizeCacheKey{ +#ifdef VLLM_USE_ACL + // Arm Compute Library (ACL) backend for oneDNN does + 
// not support runtime + // dimensions, so we set M to a default value + .a_m_size = 128, + .a_m_stride = b_k_size_, +#else + .a_m_size = DNNL_RUNTIME_DIM_VAL, + .a_m_stride = DNNL_RUNTIME_DIM_VAL, +#endif + .use_bias = false, + .bias_type = dnnl::memory::data_type::undef}, true) .weights_desc()); init_runtime_memory_cache(args); @@ -443,13 +457,31 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) { c_storage->set_data_handle((void*)args.c_ptr); c_mem_desc->dims[0] = args.a_m_size; +#ifndef VLLM_USE_ACL + // The ACL backend of oneDNN does not support a bias_md, so with ACL we handle bias by: + // 1. copying it into the result tensor + // 2. attaching a fused-sum post-op to the matmul primitive if (args.use_bias) { auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2); bias_storage->set_data_handle((void*)args.bias_ptr); } - +#endif dnnl::matmul matmul = get_matmul_cache(args); +// With the ACL backend of oneDNN, the required memory format might change when the +// source tensor dims change. This does not really happen in practice, so it isn't +// a performance hit, but we need to support it because the API allows for it. +#ifdef VLLM_USE_ACL + auto new_expected_wei_desc = + dnnl::matmul::primitive_desc( + const_cast(matmul.get_primitive_desc())) + .weights_desc(); + if (new_expected_wei_desc != b_target_mem_desc_) { + prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(), + b_target_mem_desc_, new_expected_wei_desc); + } +#endif + auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3); scratchpad_storage->set_data_handle( DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data()); @@ -484,7 +516,13 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc( } else { a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_, {key.a_m_stride, 1}); +#ifdef VLLM_USE_ACL + // The ACL backend of oneDNN always expects the weight format to be "any" + b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_, + dnnl::memory::format_tag::any); +#else b_md = b_target_mem_desc_; +#endif } dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, dnnl::memory::format_tag::ab); @@ -494,8 +532,18 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc( if (key.use_bias) { dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1}); +// Since ACL's matmuls don't support passing a bias_md, we apply the bias +// through a fused-sum post-op +#ifdef VLLM_USE_ACL + dnnl::post_ops post_ops; + post_ops.append_sum(); + attr.set_post_ops(post_ops); + return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, + attr); +#else return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md, c_md, attr); +#endif } else { return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, attr); @@ -511,13 +559,23 @@ void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { default_engine(), nullptr); set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get()); +// ACL matmuls don't support bias_md, so we don't need these +#ifndef VLLM_USE_ACL memory_cache_[DNNL_ARG_BIAS] = dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get()); - +#endif memory_cache_[DNNL_ARG_SCRATCHPAD] = dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get()); } + +bool is_onednn_acl_supported() { +#ifdef VLLM_USE_ACL +
return true; +#else + return false; +#endif +} diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h index ad6773d2b9fd..f0cb197d81a3 100644 --- a/csrc/cpu/dnnl_helper.h +++ b/csrc/cpu/dnnl_helper.h @@ -101,7 +101,7 @@ class DNNLMatMulPrimitiveHandler { protected: DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type); - void prepack_weight(void* original_b_ptr, + void prepack_weight(void* original_b_ptr, dnnl::memory::desc original_b_md, dnnl::memory::desc b_target_mem_desc); void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr); diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp index 1c42a75bc2d6..6d062c71e767 100644 --- a/csrc/cpu/dnnl_kernels.cpp +++ b/csrc/cpu/dnnl_kernels.cpp @@ -527,21 +527,42 @@ void onednn_mm(torch::Tensor& c, // [M, OC], row-major MatMulPrimitiveHandler* ptr = reinterpret_cast(handler); +// ACL matmuls expect contiguous source tensors +#ifdef VLLM_USE_ACL + torch::Tensor a_contig = a.contiguous(); +#endif + MatMulPrimitiveHandler::ExecArgs exec_args; + +#ifdef VLLM_USE_ACL + exec_args.a_m_size = a_contig.size(0); + exec_args.a_m_stride = a_contig.stride(0); +#else exec_args.a_m_size = a.size(0); exec_args.a_m_stride = a.stride(0); - +#endif VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] { if (bias.has_value()) { exec_args.use_bias = true; exec_args.bias_type = get_dnnl_type(); +#ifdef VLLM_USE_ACL + // ACL matmuls in oneDNN do not support a bias. + // We handle a matmul with bias by doing: c = bias; c += matmul(a, b) + c.copy_(bias.value()); +#else exec_args.bias_ptr = bias->data_ptr(); +#endif } else { exec_args.use_bias = false; exec_args.bias_type = get_dnnl_type(); exec_args.bias_ptr = nullptr; } +#ifdef VLLM_USE_ACL + exec_args.a_ptr = a_contig.data_ptr(); +#else exec_args.a_ptr = a.data_ptr(); + +#endif exec_args.c_ptr = c.data_ptr(); ptr->execute(exec_args); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index d279c03e0b59..9df19d1ac392 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -27,6 +27,8 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b, void onednn_mm(torch::Tensor& c, const torch::Tensor& a, const std::optional& bias, int64_t handler); +bool is_onednn_acl_supported(); + void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens); @@ -181,6 +183,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int handler) -> ()"); ops.impl("onednn_mm", torch::kCPU, &onednn_mm); + // Check if oneDNN was built with ACL backend + ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported); + // Create oneDNN W8A8 handler ops.def( "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType " diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index 1dd7101acc27..5e742d0b0293 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -27,7 +27,7 @@ class MixedInputKernelScheduleType(enum.Enum): **{ VLLMDataType.u4b8: "u4b8", VLLMDataType.u8b128: "u8b128", - } + }, } VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -35,7 +35,7 @@ class MixedInputKernelScheduleType(enum.Enum): **{ VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", - } + }, } VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], 
int] = { @@ -43,7 +43,7 @@ class MixedInputKernelScheduleType(enum.Enum): **{ VLLMDataType.u4b8: 4, VLLMDataType.u8b128: 8, - } + }, } VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = { @@ -67,15 +67,13 @@ class MixedInputKernelScheduleType(enum.Enum): DataType.f32: "at::ScalarType::Float", } -VLLMKernelScheduleTag: dict[Union[ - MixedInputKernelScheduleType, KernelScheduleType], str] = { - **KernelScheduleTag, # type: ignore - **{ - MixedInputKernelScheduleType.TmaWarpSpecialized: - "cutlass::gemm::KernelTmaWarpSpecialized", - MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: - "cutlass::gemm::KernelTmaWarpSpecializedPingpong", - MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: - "cutlass::gemm::KernelTmaWarpSpecializedCooperative", - } - } +VLLMKernelScheduleTag: dict[ + Union[MixedInputKernelScheduleType, KernelScheduleType], str +] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", # noqa: E501 + MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong", # noqa: E501 + MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative", # noqa: E501 + }, +} diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 698deb107cc0..be5b68cc53e6 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -17,25 +17,30 @@ namespace MARLIN_NAMESPACE_NAME { """.strip() -TEMPLATE = ("template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " - "{{s_type_id}}, " - "{{threads}}, " - "{{thread_m_blocks}}, " - "{{thread_n_blocks}}, " - "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " - "{{stages}}, " - "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" - "( MARLIN_KERNEL_PARAMS );") +TEMPLATE = ( + "template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{s_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );" +) # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. 
SCALAR_TYPES = [ - "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", - "vllm::kFE2M1f" + "vllm::kU4", + "vllm::kU4B8", + "vllm::kU8B128", + "vllm::kFE4M3fn", + "vllm::kFE2M1f", ] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] @@ -58,11 +63,12 @@ def generate_new_kernels(): all_template_str_list = [] for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): - + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS + ): # act order case only support gptq-int4 and gptq-int8 if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", "vllm::kU8B128" + "vllm::kU4B8", + "vllm::kU8B128", ]: continue if thread_configs[2] == 256: diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py index 7576e0548abe..42d3b456096e 100644 --- a/csrc/quantization/gptq_marlin/generate_kernels.py +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -17,28 +17,32 @@ namespace MARLIN_NAMESPACE_NAME { """.strip() -TEMPLATE = ("template __global__ void Marlin<" - "{{scalar_t}}, " - "{{w_type_id}}, " - "{{s_type_id}}, " - "{{threads}}, " - "{{thread_m_blocks}}, " - "{{thread_n_blocks}}, " - "{{thread_k_blocks}}, " - "{{'true' if m_block_size_8 else 'false'}}, " - "{{stages}}, " - "{{group_blocks}}, " - "{{'true' if is_zp_float else 'false'}}>" - "( MARLIN_KERNEL_PARAMS );") +TEMPLATE = ( + "template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{s_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );" +) # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. 
SCALAR_TYPES = [ - "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", - "vllm::kFE2M1f" + "vllm::kU4", + "vllm::kU4B8", + "vllm::kU8B128", + "vllm::kFE4M3fn", + "vllm::kFE2M1f", ] -THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), - (128, 64, 128)] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] # group_blocks: @@ -59,11 +63,12 @@ def generate_new_kernels(): all_template_str_list = [] for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): - + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS + ): # act order case only support gptq-int4 and gptq-int8 if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", "vllm::kU8B128" + "vllm::kU4B8", + "vllm::kU8B128", ]: continue if thread_configs[2] == 256: @@ -93,8 +98,7 @@ def generate_new_kernels(): c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" is_zp_float_list = [False] - if dtype == "fp16" and scalar_type == "vllm::kU4" and \ - group_blocks == 4: + if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4: # HQQ (is_zp_float = true) only supports # 4bit quantization and fp16 is_zp_float_list.append(True) diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 8fd536ef46e3..d29a199c5d32 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -12,20 +12,21 @@ from typing import Optional, Union import jinja2 -# yapf conflicts with isort for this block -# yapf: disable -from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag, - EpilogueScheduleType, - MixedInputKernelScheduleType, - TileSchedulerTag, - TileSchedulerType, VLLMDataType, - VLLMDataTypeNames, - VLLMDataTypeSize, VLLMDataTypeTag, - VLLMDataTypeTorchDataTypeTag, - VLLMDataTypeVLLMScalarTypeTag, - VLLMKernelScheduleTag) - -# yapf: enable +from vllm_cutlass_library_extension import ( + DataType, + EpilogueScheduleTag, + EpilogueScheduleType, + MixedInputKernelScheduleType, + TileSchedulerTag, + TileSchedulerType, + VLLMDataType, + VLLMDataTypeNames, + VLLMDataTypeSize, + VLLMDataTypeTag, + VLLMDataTypeTorchDataTypeTag, + VLLMDataTypeVLLMScalarTypeTag, + VLLMKernelScheduleTag, +) # # Generator templating @@ -286,18 +287,23 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str: tile_shape = ( f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" ) - cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" + - f"x{schedule_config.cluster_shape_mnk[1]}" + - f"x{schedule_config.cluster_shape_mnk[2]}") - kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\ - .split("::")[-1] - epilogue_schedule = EpilogueScheduleTag[ - schedule_config.epilogue_schedule].split("::")[-1] - tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\ - .split("::")[-1] - - return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + - f"_{epilogue_schedule}_{tile_scheduler}") + cluster_shape = ( + f"{schedule_config.cluster_shape_mnk[0]}" + + f"x{schedule_config.cluster_shape_mnk[1]}" + + f"x{schedule_config.cluster_shape_mnk[2]}" + ) + kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split( + "::" + )[-1] + epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split( + "::" + )[-1] + tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1] + + return ( + f"{tile_shape}_{cluster_shape}_{kernel_schedule}" + + 
f"_{epilogue_schedule}_{tile_scheduler}" + ) # mostly unique shorter sch_sig @@ -316,18 +322,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: # unique type_name def generate_type_signature(kernel_types: TypeConfig): - return str("".join([ - VLLMDataTypeNames[getattr(kernel_types, field.name)] - for field in fields(TypeConfig) - ])) + return str( + "".join( + [ + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) + ) def generate_type_option_name(kernel_types: TypeConfig): - return ", ".join([ - f"{field.name.replace('b_', 'with_')+'_type'}=" + - VLLMDataTypeNames[getattr(kernel_types, field.name)] - for field in fields(TypeConfig) - ]) + return ", ".join( + [ + f"{field.name.replace('b_', 'with_') + '_type'}=" + + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ] + ) def is_power_of_two(n): @@ -335,7 +347,6 @@ def is_power_of_two(n): def to_cute_constant(value: list[int]): - def _to_cute_constant(value: int): if is_power_of_two(value): return f"_{value}" @@ -350,11 +361,11 @@ def _to_cute_constant(value: int): def unique_schedules(impl_configs: list[ImplConfig]): # Use dict over set for deterministic ordering - return list({ - sch: None - for impl_config in impl_configs - for sch in impl_config.schedules - }.keys()) + return list( + { + sch: None for impl_config in impl_configs for sch in impl_config.schedules + }.keys() + ) def unsigned_type_with_bitwidth(num_bits): @@ -380,7 +391,7 @@ def unsigned_type_with_bitwidth(num_bits): "gen_type_sig": generate_type_signature, "unique_schedules": unique_schedules, "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, - "gen_type_option_name": generate_type_option_name + "gen_type_option_name": generate_type_option_name, } @@ -398,23 +409,28 @@ def create_template(template_str): def create_sources(impl_configs: list[ImplConfig], num_impl_files=8): sources = [] - sources.append(( - "machete_mm_dispatch", - mm_dispatch_template.render(impl_configs=impl_configs), - )) + sources.append( + ( + "machete_mm_dispatch", + mm_dispatch_template.render(impl_configs=impl_configs), + ) + ) prepack_types = [] for impl_config in impl_configs: - convert_type = impl_config.types.a \ - if impl_config.types.b_group_scale == DataType.void \ - else impl_config.types.b_group_scale + convert_type = ( + impl_config.types.a + if impl_config.types.b_group_scale == DataType.void + else impl_config.types.b_group_scale + ) prepack_types.append( PrepackTypeConfig( a=impl_config.types.a, b_num_bits=VLLMDataTypeSize[impl_config.types.b], convert=convert_type, accumulator=impl_config.types.accumulator, - )) + ) + ) def prepacked_type_key(prepack_type: PrepackTypeConfig): # For now, we can just use the first accumulator type seen since @@ -430,10 +446,14 @@ def prepacked_type_key(prepack_type: PrepackTypeConfig): unique_prepack_types.append(prepack_type) prepack_types_seen.add(key) - sources.append(( - "machete_prepack", - prepack_dispatch_template.render(types=unique_prepack_types, ), - )) + sources.append( + ( + "machete_prepack", + prepack_dispatch_template.render( + types=unique_prepack_types, + ), + ) + ) # Split up impls across files num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) @@ -466,10 +486,12 @@ def prepacked_type_key(prepack_type: PrepackTypeConfig): curr_impl_in_file += len(files_impls[-1][-1].schedules) for part, file_impls in enumerate(files_impls): - sources.append(( - f"machete_mm_impl_part{part+1}", - 
mm_impl_template.render(impl_configs=file_impls), - )) + sources.append( + ( + f"machete_mm_impl_part{part + 1}", + mm_impl_template.render(impl_configs=file_impls), + ) + ) return sources @@ -514,8 +536,7 @@ def generate(): # For now we use the same heuristic for all types # Heuristic is currently tuned for H100s default_heuristic = [ - (cond, ScheduleConfig(*tile_config, - **sch_common_params)) # type: ignore + (cond, ScheduleConfig(*tile_config, **sch_common_params)) # type: ignore for cond, tile_config in default_tile_heuristic_config.items() ] @@ -541,14 +562,18 @@ def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): a_token_scale=DataType.void, out=a, accumulator=DataType.f32, - ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) - for a in (DataType.f16, DataType.bf16)) + ) + for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16) + ) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(GPTQ_kernel_type_configs, - itertools.repeat(get_unique_schedules(default_heuristic)), - itertools.repeat(default_heuristic)) + for x in zip( + GPTQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) ] AWQ_kernel_type_configs = list( @@ -561,14 +586,18 @@ def get_unique_schedules(heuristic: dict[str, ScheduleConfig]): a_token_scale=DataType.void, out=a, accumulator=DataType.f32, - ) for b in (DataType.u4, DataType.u8) - for a in (DataType.f16, DataType.bf16)) + ) + for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16) + ) impl_configs += [ ImplConfig(x[0], x[1], x[2]) - for x in zip(AWQ_kernel_type_configs, - itertools.repeat(get_unique_schedules(default_heuristic)), - itertools.repeat(default_heuristic)) + for x in zip( + AWQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic), + ) ] # TODO: Support W4A8 when ready diff --git a/docs/deployment/frameworks/hf_inference_endpoints.md b/docs/deployment/frameworks/hf_inference_endpoints.md index 50c981f42c03..75a234bdf142 100644 --- a/docs/deployment/frameworks/hf_inference_endpoints.md +++ b/docs/deployment/frameworks/hf_inference_endpoints.md @@ -61,7 +61,7 @@ This is the easiest way to get started with vLLM on Hugging Face Inference Endpo ### Method 2: Guided Deployment (Transformers Models) -This method applies to models with the `transformers` library tag in their metadata. It allows you to deploy a model directly from the Hub UI without manual configuration. +This method applies to models with the [`transformers` library tag](https://huggingface.co/models?library=transformers) in their metadata. It allows you to deploy a model directly from the Hub UI without manual configuration. 1. Navigate to a model on [Hugging Face Hub](https://huggingface.co/models). For this example we will use the [`ibm-granite/granite-docling-258M`](https://huggingface.co/ibm-granite/granite-docling-258M) model. You can verify that the model is compatible by checking the front matter in the [README](https://huggingface.co/ibm-granite/granite-docling-258M/blob/main/README.md), where the library is tagged as `library: transformers`. @@ -128,7 +128,7 @@ Some models require manual deployment because they: These models cannot be deployed using the **Deploy** button on the model card. 
-In this guide, we demonstrate manual deployment using the [rednote-hilab/dots.ocr](https://huggingface.co/rednote-hilab/dots.ocr) model, an OCR model integrated with vLLM (see vLLM [PR](https://github.com/vllm-project/vllm/pull/24645)). +In this guide, we demonstrate manual deployment using the [`rednote-hilab/dots.ocr`](https://huggingface.co/rednote-hilab/dots.ocr) model, an OCR model integrated with vLLM (see vLLM [PR](https://github.com/vllm-project/vllm/pull/24645)). 1. Start a new deployment. Go to [Inference Endpoints](https://endpoints.huggingface.co/) and click `New`. diff --git a/docs/deployment/integrations/kaito.md b/docs/deployment/integrations/kaito.md new file mode 100644 index 000000000000..ff050d3eeaf4 --- /dev/null +++ b/docs/deployment/integrations/kaito.md @@ -0,0 +1,5 @@ +# KAITO + +[KAITO](https://kaito-project.github.io/kaito/docs/) is a Kubernetes operator that supports deploying and serving LLMs with vLLM. It offers management of large models via container images with built-in OpenAI-compatible inference, auto-provisioning of GPU nodes, and curated model presets. + +Please refer to the [quick start](https://kaito-project.github.io/kaito/docs/quick-start) for more details. diff --git a/docs/deployment/integrations/production-stack.md b/docs/deployment/integrations/production-stack.md index fae392589c06..2f1894ccf002 100644 --- a/docs/deployment/integrations/production-stack.md +++ b/docs/deployment/integrations/production-stack.md @@ -55,7 +55,7 @@ sudo kubectl port-forward svc/vllm-router-service 30080:80 And then you can send out a query to the OpenAI-compatible API to check the available models: ```bash -curl -o- http://localhost:30080/models +curl -o- http://localhost:30080/v1/models ``` ??? console "Output" @@ -78,7 +78,7 @@ curl -o- http://localhost:30080/models To send an actual chatting request, you can issue a curl request to the OpenAI `/completion` endpoint: ```bash -curl -X POST http://localhost:30080/completions \ +curl -X POST http://localhost:30080/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "facebook/opt-125m", diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index ca23e0b9fd8a..d3fda7eb6fb6 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -12,6 +12,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following: - [Helm](frameworks/helm.md) - [InftyAI/llmaz](integrations/llmaz.md) +- [KAITO](integrations/kaito.md) - [KServe](integrations/kserve.md) - [KubeRay](integrations/kuberay.md) - [kubernetes-sigs/lws](frameworks/lws.md) diff --git a/docs/design/plugin_system.md b/docs/design/plugin_system.md index 37193809776a..a384c6289f4f 100644 --- a/docs/design/plugin_system.md +++ b/docs/design/plugin_system.md @@ -49,7 +49,7 @@ Every plugin has three parts: - **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported. -- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for poling models. The plugin function returns the IOProcessor's class fully qualified name.
+- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor's class fully qualified name. ## Guidelines for Writing Plugins diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 241438ae5578..6a0bcfac66d0 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -191,10 +191,14 @@ VLLM also provides a pythonic and JSON-based chat template for Llama 4, but pyth For Llama 4 model, use `--tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`. -#### IBM Granite +### IBM Granite Supported models: +* `ibm-granite/granite-4.0-h-small` and other Granite 4.0 models + + Recommended flags: `--tool-call-parser hermes` + * `ibm-granite/granite-3.0-8b-instruct` Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index d026235dd9d5..ecd71ee1f3f6 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -33,8 +33,11 @@ def auto_mock(module, attr, max_mocks=50): try: # First treat attr as an attr, then as a submodule with patch("importlib.metadata.version", return_value="0.0.0"): - return getattr(importlib.import_module(module), attr, - importlib.import_module(f"{module}.{attr}")) + return getattr( + importlib.import_module(module), + attr, + importlib.import_module(f"{module}.{attr}"), + ) except importlib.metadata.PackageNotFoundError as e: raise e except ModuleNotFoundError as e: @@ -42,7 +45,8 @@ def auto_mock(module, attr, max_mocks=50): sys.modules[e.name] = PydanticMagicMock() raise ImportError( - f"Failed to import {module}.{attr} after mocking {max_mocks} imports") + f"Failed to import {module}.{attr} after mocking {max_mocks} imports" + ) latency = auto_mock("vllm.benchmarks", "latency") @@ -61,9 +65,7 @@ class MarkdownFormatter(HelpFormatter): """Custom formatter that generates markdown for argument groups.""" def __init__(self, prog, starting_heading_level=3): - super().__init__(prog, - max_help_position=float('inf'), - width=float('inf')) + super().__init__(prog, max_help_position=float("inf"), width=float("inf")) self._section_heading_prefix = "#" * starting_heading_level self._argument_heading_prefix = "#" * (starting_heading_level + 1) self._markdown_output = [] @@ -85,23 +87,19 @@ def add_usage(self, usage, actions, groups, prefix=None): def add_arguments(self, actions): for action in actions: - if (len(action.option_strings) == 0 - or "--help" in action.option_strings): + if len(action.option_strings) == 0 or "--help" in action.option_strings: continue - option_strings = f'`{"`, `".join(action.option_strings)}`' + option_strings = f"`{'`, `'.join(action.option_strings)}`" heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n" self._markdown_output.append(heading_md) if choices := action.choices: - choices = f'`{"`, `".join(str(c) for c in choices)}`' - self._markdown_output.append( - f"Possible choices: {choices}\n\n") - elif ((metavar := action.metavar) - and isinstance(metavar, (list, tuple))): - metavar = f'`{"`, `".join(str(m) for m in metavar)}`' - self._markdown_output.append( - f"Possible choices: {metavar}\n\n") + choices = f"`{'`, `'.join(str(c) for c in choices)}`" + 
self._markdown_output.append(f"Possible choices: {choices}\n\n") + elif (metavar := action.metavar) and isinstance(metavar, (list, tuple)): + metavar = f"`{'`, `'.join(str(m) for m in metavar)}`" + self._markdown_output.append(f"Possible choices: {metavar}\n\n") if action.help: self._markdown_output.append(f"{action.help}\n\n") @@ -116,7 +114,7 @@ def format_help(self): def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser: """Create a parser for the given class with markdown formatting. - + Args: cls: The class to create a parser for **kwargs: Additional keyword arguments to pass to `cls.add_cli_args`. @@ -143,24 +141,17 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): # Create parsers to document parsers = { - "engine_args": - create_parser(EngineArgs.add_cli_args), - "async_engine_args": - create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True), - "serve": - create_parser(cli_args.make_arg_parser), - "chat": - create_parser(ChatCommand.add_cli_args), - "complete": - create_parser(CompleteCommand.add_cli_args), - "bench_latency": - create_parser(latency.add_cli_args), - "bench_throughput": - create_parser(throughput.add_cli_args), - "bench_serve": - create_parser(serve.add_cli_args), - "run-batch": - create_parser(run_batch.make_arg_parser), + "engine_args": create_parser(EngineArgs.add_cli_args), + "async_engine_args": create_parser( + AsyncEngineArgs.add_cli_args, async_args_only=True + ), + "serve": create_parser(cli_args.make_arg_parser), + "chat": create_parser(ChatCommand.add_cli_args), + "complete": create_parser(CompleteCommand.add_cli_args), + "bench_latency": create_parser(latency.add_cli_args), + "bench_throughput": create_parser(throughput.add_cli_args), + "bench_serve": create_parser(serve.add_cli_args), + "run-batch": create_parser(run_batch.make_arg_parser), } # Generate documentation for each parser diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 0cbaebb598a3..ed8277f628d4 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -11,7 +11,7 @@ logger = logging.getLogger("mkdocs") ROOT_DIR = Path(__file__).parent.parent.parent.parent -ROOT_DIR_RELATIVE = '../../../../..' +ROOT_DIR_RELATIVE = "../../../../.." EXAMPLE_DIR = ROOT_DIR / "examples" EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples" @@ -36,7 +36,7 @@ def fix_case(text: str) -> str: r"int\d+": lambda x: x.group(0).upper(), # e.g. int8, int16 } for pattern, repl in subs.items(): - text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE) + text = re.sub(rf"\b{pattern}\b", repl, text, flags=re.IGNORECASE) return text @@ -58,7 +58,8 @@ class Example: determine_other_files() -> list[Path]: Determines other files in the directory excluding the main file. determine_title() -> str: Determines the title of the document. generate() -> str: Generates the documentation content. - """ # noqa: E501 + """ # noqa: E501 + path: Path category: str = None main_file: Path = field(init=False) @@ -84,9 +85,8 @@ def determine_main_file(self) -> Path: Markdown file found in the directory. Raises: IndexError: If no Markdown files are found in the directory. 
- """ # noqa: E501 - return self.path if self.path.is_file() else list( - self.path.glob("*.md")).pop() + """ # noqa: E501 + return self.path if self.path.is_file() else list(self.path.glob("*.md")).pop() def determine_other_files(self) -> list[Path]: """ @@ -98,7 +98,7 @@ def determine_other_files(self) -> list[Path]: Returns: list[Path]: A list of Path objects representing the other files in the directory. - """ # noqa: E501 + """ # noqa: E501 if self.path.is_file(): return [] is_other_file = lambda file: file.is_file() and file != self.main_file @@ -109,25 +109,25 @@ def determine_title(self) -> str: # Specify encoding for building on Windows with open(self.main_file, encoding="utf-8") as f: first_line = f.readline().strip() - match = re.match(r'^#\s+(?P.+)$', first_line) + match = re.match(r"^#\s+(?P<title>.+)$", first_line) if match: - return match.group('title') + return match.group("title") return fix_case(self.path.stem.replace("_", " ").title()) def fix_relative_links(self, content: str) -> str: """ Fix relative links in markdown content by converting them to gh-file format. - + Args: content (str): The markdown content to process - + Returns: str: Content with relative links converted to gh-file format """ # Regex to match markdown links [text](relative_path) # This matches links that don't start with http, https, ftp, or # - link_pattern = r'\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)' + link_pattern = r"\[([^\]]*)\]\((?!(?:https?|ftp)://|#)([^)]+)\)" def replace_link(match): link_text = match.group(1) @@ -137,7 +137,7 @@ def replace_link(match): gh_file = (self.main_file.parent / relative_path).resolve() gh_file = gh_file.relative_to(ROOT_DIR) - return f'[{link_text}](gh-file:{gh_file})' + return f"[{link_text}](gh-file:{gh_file})" return re.sub(link_pattern, replace_link, content) @@ -150,9 +150,11 @@ def generate(self) -> str: code_fence = "``````" if self.is_code: - content += (f"{code_fence}{self.main_file.suffix[1:]}\n" - f'--8<-- "{self.main_file}"\n' - f"{code_fence}\n") + content += ( + f"{code_fence}{self.main_file.suffix[1:]}\n" + f'--8<-- "{self.main_file}"\n' + f"{code_fence}\n" + ) else: with open(self.main_file) as f: # Skip the title from md snippets as it's been included above diff --git a/docs/mkdocs/hooks/remove_announcement.py b/docs/mkdocs/hooks/remove_announcement.py index 1a84039abc14..12db2265b9f8 100644 --- a/docs/mkdocs/hooks/remove_announcement.py +++ b/docs/mkdocs/hooks/remove_announcement.py @@ -7,7 +7,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): # see https://docs.readthedocs.io/en/stable/reference/environment-variables.html # noqa - if os.getenv('READTHEDOCS_VERSION_TYPE') == "tag": + if os.getenv("READTHEDOCS_VERSION_TYPE") == "tag": # remove the warning banner if the version is a tagged release mkdocs_dir = Path(__file__).parent.parent announcement_path = mkdocs_dir / "overrides/main.html" diff --git a/docs/mkdocs/hooks/url_schemes.py b/docs/mkdocs/hooks/url_schemes.py index 6fce6bd8130e..53b1fbca26b9 100644 --- a/docs/mkdocs/hooks/url_schemes.py +++ b/docs/mkdocs/hooks/url_schemes.py @@ -25,8 +25,9 @@ from mkdocs.structure.pages import Page -def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, - files: Files) -> str: +def on_page_markdown( + markdown: str, *, page: Page, config: MkDocsConfig, files: Files +) -> str: """ Custom MkDocs plugin hook to rewrite special GitHub reference links in Markdown. 
@@ -35,7 +36,7 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, GitHub shorthand links, such as: - `[Link text](gh-issue:123)` - `<gh-pr:456>` - + And rewrites them into fully-qualified GitHub URLs with GitHub icons: - `[:octicons-mark-github-16: Link text](https://github.com/vllm-project/vllm/issues/123)` - `[:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)` @@ -88,21 +89,21 @@ def replace_inline_link(match: re.Match) -> str: """ Replaces a matched inline-style GitHub shorthand link with a full Markdown link. - + Example: [My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123) """ - url = f'{urls[match.group("type")]}/{match.group("path")}' + url = f"{urls[match.group('type')]}/{match.group('path')}" if fragment := match.group("fragment"): url += f"#{fragment}" - return f'[{gh_icon} {match.group("title")}]({url})' + return f"[{gh_icon} {match.group('title')}]({url})" def replace_auto_link(match: re.Match) -> str: """ Replaces a matched autolink-style GitHub shorthand with a full Markdown link. - + Example: <gh-pr:456> → [:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456) """ diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index c705a70b93f5..60fe5b887952 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -32,8 +32,9 @@ If the Transformers model implementation follows all the steps in [writing a cus - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature) - Any combination of the following vLLM parallelisation schemes: - Data parallel - - Pipeline parallel - Tensor parallel + - Expert parallel + - Pipeline parallel Checking if the modeling backend is Transformers is as simple as: @@ -828,6 +829,7 @@ The following table lists those that are tested in vLLM. | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) | |--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| +| `CLIPModel` | CLIP | T / I | `openai/clip-vit-base-patch32`, `openai/clip-vit-large-patch14`, etc. | | | ✅︎ | | `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | ✅︎ | ✅︎ | | `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | | ✅︎ | ✅︎ | | `*ForConditionalGeneration`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. 
| Generative models | \* | N/A | \* | \* | \* | diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index e0d95758a822..3e711e205313 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -371,13 +371,14 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: ) -def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "Kwai-Keye/Keye-VL-8B-Preview" engine_args = EngineArgs( model=model_name, - max_model_len=131072, - tensor_parallel_size=8, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=5, limit_mm_per_prompt={"image": len(image_urls)}, ) @@ -389,29 +390,32 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: *placeholders, {"type": "text", "text": question}, ], - } + }, ] - processor = AutoProcessor.from_pretrained(model_name) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) prompt = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) + image_data = [fetch_image(url) for url in image_urls] + return ModelRequestData( engine_args=engine_args, prompt=prompt, - image_data=[fetch_image(url) for url in image_urls], + image_data=image_data, ) -def load_llava(question: str, image_urls: list[str]) -> ModelRequestData: - # NOTE: CAUTION! Original Llava models wasn't really trained on multi-image inputs, - # it will generate poor response for multi-image inputs! - model_name = "llava-hf/llava-1.5-7b-hf" +def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "Kwai-Keye/Keye-VL-1_5-8B" + engine_args = EngineArgs( model=model_name, - max_num_seqs=16, + trust_remote_code=True, + max_model_len=32768, + max_num_seqs=5, limit_mm_per_prompt={"image": len(image_urls)}, ) @@ -423,28 +427,32 @@ def load_llava(question: str, image_urls: list[str]) -> ModelRequestData: *placeholders, {"type": "text", "text": question}, ], - } + }, ] - processor = AutoProcessor.from_pretrained(model_name) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) prompt = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) + image_data = [fetch_image(url) for url in image_urls] + return ModelRequestData( engine_args=engine_args, prompt=prompt, - image_data=[fetch_image(url) for url in image_urls], + image_data=image_data, ) -def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "llava-hf/llava-v1.6-mistral-7b-hf" +def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "moonshotai/Kimi-VL-A3B-Instruct" + engine_args = EngineArgs( model=model_name, - max_model_len=8192, - max_num_seqs=16, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=4, limit_mm_per_prompt={"image": len(image_urls)}, ) @@ -459,7 +467,7 @@ def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData: } ] - processor = AutoProcessor.from_pretrained(model_name) + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) prompt = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True @@ -472,12 +480,13 @@ def load_llava_next(question: str, image_urls: list[str]) -> 
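A minimal sketch of combining these parallelisation schemes with the Transformers modeling backend is shown below; the model choice and parallel sizes are illustrative and assume a host with at least four GPUs (expert parallelism would additionally require an MoE model).

```python
from vllm import LLM

# Illustrative only: force the Transformers modeling backend and combine
# tensor parallelism with pipeline parallelism across four GPUs.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    model_impl="transformers",
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
)
print(llm.generate("The capital of France is")[0].outputs[0].text)
```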
ModelRequestData: ) -def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "llava-hf/llava-onevision-qwen2-7b-ov-hf" +def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + engine_args = EngineArgs( model=model_name, - max_model_len=16384, - max_num_seqs=16, + max_model_len=131072, + tensor_parallel_size=8, limit_mm_per_prompt={"image": len(image_urls)}, ) @@ -505,14 +514,13 @@ def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestDa ) -def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "Kwai-Keye/Keye-VL-8B-Preview" - +def load_llava(question: str, image_urls: list[str]) -> ModelRequestData: + # NOTE: CAUTION! Original Llava models wasn't really trained on multi-image inputs, + # it will generate poor response for multi-image inputs! + model_name = "llava-hf/llava-1.5-7b-hf" engine_args = EngineArgs( model=model_name, - trust_remote_code=True, - max_model_len=8192, - max_num_seqs=5, + max_num_seqs=16, limit_mm_per_prompt={"image": len(image_urls)}, ) @@ -524,32 +532,28 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData: *placeholders, {"type": "text", "text": question}, ], - }, + } ] - processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(model_name) prompt = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - image_data = [fetch_image(url) for url in image_urls] - return ModelRequestData( engine_args=engine_args, prompt=prompt, - image_data=image_data, + image_data=[fetch_image(url) for url in image_urls], ) -def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "Kwai-Keye/Keye-VL-1_5-8B" - +def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "llava-hf/llava-v1.6-mistral-7b-hf" engine_args = EngineArgs( model=model_name, - trust_remote_code=True, max_model_len=8192, - max_num_seqs=5, + max_num_seqs=16, limit_mm_per_prompt={"image": len(image_urls)}, ) @@ -561,32 +565,28 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData: *placeholders, {"type": "text", "text": question}, ], - }, + } ] - processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(model_name) prompt = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - image_data = [fetch_image(url) for url in image_urls] - return ModelRequestData( engine_args=engine_args, prompt=prompt, - image_data=image_data, + image_data=[fetch_image(url) for url in image_urls], ) -def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: - model_name = "moonshotai/Kimi-VL-A3B-Instruct" - +def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "llava-hf/llava-onevision-qwen2-7b-ov-hf" engine_args = EngineArgs( model=model_name, - trust_remote_code=True, - max_model_len=4096, - max_num_seqs=4, + max_model_len=16384, + max_num_seqs=16, limit_mm_per_prompt={"image": len(image_urls)}, ) @@ -601,7 +601,7 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: } ] - processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(model_name) prompt = processor.apply_chat_template( messages, 
tokenize=False, add_generation_prompt=True diff --git a/examples/offline_inference/vision_language_pooling.py b/examples/offline_inference/vision_language_pooling.py index 3d1daf4d19ff..33ffb59014d8 100644 --- a/examples/offline_inference/vision_language_pooling.py +++ b/examples/offline_inference/vision_language_pooling.py @@ -58,6 +58,30 @@ class ModelRequestData(NamedTuple): documents: Optional[ScoreMultiModalParam] = None +def run_clip(query: Query) -> ModelRequestData: + if query["modality"] == "text": + prompt = query["text"] + image = None + elif query["modality"] == "image": + prompt = "" # For image input, make sure that the prompt text is empty + image = query["image"] + else: + modality = query["modality"] + raise ValueError(f"Unsupported query modality: '{modality}'") + + engine_args = EngineArgs( + model="openai/clip-vit-base-patch32", + runner="pooling", + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image=image, + ) + + def run_e5_v(query: Query) -> ModelRequestData: llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501 @@ -89,7 +113,7 @@ def run_e5_v(query: Query) -> ModelRequestData: def _get_vlm2vec_prompt_image(query: Query, image_token: str): if query["modality"] == "text": text = query["text"] - prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501 + prompt = f"Find me an everyday image that matches the given caption: {text}" image = None elif query["modality"] == "image": prompt = f"{image_token} Find a day-to-day image that looks similar to the provided image." # noqa: E501 @@ -146,7 +170,8 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData: processor = AutoProcessor.from_pretrained( model_id, - # `min_pixels` and `max_pixels` are deprecated + # `min_pixels` and `max_pixels` are deprecated for + # transformers `preprocessor_config.json` size={"shortest_edge": 3136, "longest_edge": 12845056}, ) processor.chat_template = load_chat_template( @@ -172,8 +197,10 @@ def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData: model=merged_path, runner="pooling", max_model_len=4096, - trust_remote_code=True, - mm_processor_kwargs={"num_crops": 4}, + mm_processor_kwargs={ + "min_pixels": 3136, + "max_pixels": 12845056, + }, limit_mm_per_prompt={"image": 1}, ) @@ -299,6 +326,7 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]): model_example_map = { + "clip": run_clip, "e5_v": run_e5_v, "vlm2vec_phi3v": run_vlm2vec_phi3v, "vlm2vec_qwen2vl": run_vlm2vec_qwen2vl, diff --git a/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py index d39edb0b9d15..1df11d9d8495 100644 --- a/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py +++ b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py @@ -203,9 +203,9 @@ async def forward_request(self, url, data, use_chunked=True): async with session.post( url=url, json=data, headers=headers ) as response: - if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501 + if 200 <= response.status < 300 or 400 <= response.status < 500: if use_chunked: - async for chunk_bytes in response.content.iter_chunked( # noqa: E501 + async for chunk_bytes in response.content.iter_chunked( 1024 ): yield chunk_bytes diff --git a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py 
b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py index 6e31c3836806..16ac4378c686 100644 --- a/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py +++ b/examples/online_serving/pooling/openai_chat_embedding_client_for_multimodal.py @@ -1,14 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # ruff: noqa: E501 -"""Example Python client for multimodal embedding API using vLLM API server -NOTE: - start a supported multimodal embeddings model server with `vllm serve`, e.g. - vllm serve TIGER-Lab/VLM2Vec-Full \ - --runner pooling \ - --trust-remote-code \ - --max-model-len 4096 \ - --chat-template examples/template_vlm2vec_phi3v.jinja +"""Example Python client for multimodal embedding API using vLLM API server. + +Refer to each `run_*` function for the command to run the server for that model. """ import argparse @@ -47,7 +42,58 @@ def create_chat_embeddings( ) +def run_clip(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve openai/clip-vit-base-patch32 \ + --runner pooling + """ + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Image embedding output:", response.data[0].embedding) + + response = create_chat_embeddings( + client, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "a photo of a cat"}, + ], + } + ], + model=model, + encoding_format="float", + ) + + print("Text embedding output:", response.data[0].embedding) + + def run_vlm2vec(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve TIGER-Lab/VLM2Vec-Full \ + --runner pooling \ + --trust-remote-code \ + --max-model-len 4096 \ + --chat-template examples/template_vlm2vec_phi3v.jinja + """ + response = create_chat_embeddings( client, messages=[ @@ -103,6 +149,15 @@ def run_vlm2vec(client: OpenAI, model: str): def run_dse_qwen2_vl(client: OpenAI, model: str): + """ + Start the server using: + + vllm serve MrLight/dse-qwen2-2b-mrl-v1 \ + --runner pooling \ + --trust-remote-code \ + --max-model-len 8192 \ + --chat-template examples/template_dse_qwen2_vl.jinja + """ response = create_chat_embeddings( client, messages=[ @@ -156,6 +211,7 @@ def run_dse_qwen2_vl(client: OpenAI, model: str): model_example_map = { + "clip": run_clip, "vlm2vec": run_vlm2vec, "dse_qwen2_vl": run_dse_qwen2_vl, } diff --git a/examples/others/tensorize_vllm_model.py b/examples/others/tensorize_vllm_model.py index 2b7f0beab227..acbfd8cda489 100644 --- a/examples/others/tensorize_vllm_model.py +++ b/examples/others/tensorize_vllm_model.py @@ -21,8 +21,6 @@ logger = logging.getLogger() -# yapf conflicts with isort for this docstring -# yapf: disable """ tensorize_vllm_model.py is a script that can be used to serialize and deserialize vLLM models. These models can be loaded using tensorizer @@ -132,7 +130,8 @@ def get_parser(): "can be loaded using tensorizer directly to the GPU " "extremely quickly. Tensor encryption and decryption is " "also supported, although libsodium must be installed to " - "use it.") + "use it." 
+ ) parser = EngineArgs.add_cli_args(parser) parser.add_argument( @@ -144,13 +143,14 @@ def get_parser(): "along with the model by instantiating a TensorizerConfig object, " "creating a dict from it with TensorizerConfig.to_serializable(), " "and passing it to LoRARequest's initializer with the kwarg " - "tensorizer_config_dict." + "tensorizer_config_dict.", ) - subparsers = parser.add_subparsers(dest='command', required=True) + subparsers = parser.add_subparsers(dest="command", required=True) serialize_parser = subparsers.add_parser( - 'serialize', help="Serialize a model to `--serialized-directory`") + "serialize", help="Serialize a model to `--serialized-directory`" + ) serialize_parser.add_argument( "--suffix", @@ -163,7 +163,9 @@ def get_parser(): "`--suffix` is `v1`, the serialized model tensors will be " "saved to " "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " - "If none is provided, a random UUID will be used.")) + "If none is provided, a random UUID will be used." + ), + ) serialize_parser.add_argument( "--serialized-directory", type=str, @@ -175,108 +177,127 @@ def get_parser(): "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will " "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, " "where `suffix` is given by `--suffix` or a random UUID if not " - "provided.") + "provided.", + ) serialize_parser.add_argument( "--serialization-kwargs", type=tensorizer_kwargs_arg, required=False, - help=("A JSON string containing additional keyword arguments to " - "pass to Tensorizer's TensorSerializer during " - "serialization.")) + help=( + "A JSON string containing additional keyword arguments to " + "pass to Tensorizer's TensorSerializer during " + "serialization." + ), + ) serialize_parser.add_argument( "--keyfile", type=str, required=False, - help=("Encrypt the model weights with a randomly-generated binary key," - " and save the key at this path")) + help=( + "Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path" + ), + ) deserialize_parser = subparsers.add_parser( - 'deserialize', - help=("Deserialize a model from `--path-to-tensors`" - " to verify it can be loaded and used.")) + "deserialize", + help=( + "Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used." + ), + ) deserialize_parser.add_argument( "--path-to-tensors", type=str, required=False, - help="The local path or S3 URI to the model tensors to deserialize. ") + help="The local path or S3 URI to the model tensors to deserialize. ", + ) deserialize_parser.add_argument( "--serialized-directory", type=str, required=False, help="Directory with model artifacts for loading. Assumes a " - "model.tensors file exists therein. Can supersede " - "--path-to-tensors.") + "model.tensors file exists therein. 
Can supersede " + "--path-to-tensors.", + ) deserialize_parser.add_argument( "--keyfile", type=str, required=False, - help=("Path to a binary key to use to decrypt the model weights," - " if the model was serialized with encryption")) + help=( + "Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption" + ), + ) deserialize_parser.add_argument( "--deserialization-kwargs", type=tensorizer_kwargs_arg, required=False, - help=("A JSON string containing additional keyword arguments to " - "pass to Tensorizer's `TensorDeserializer` during " - "deserialization.")) + help=( + "A JSON string containing additional keyword arguments to " + "pass to Tensorizer's `TensorDeserializer` during " + "deserialization." + ), + ) TensorizerArgs.add_cli_args(deserialize_parser) return parser -def merge_extra_config_with_tensorizer_config(extra_cfg: dict, - cfg: TensorizerConfig): + +def merge_extra_config_with_tensorizer_config(extra_cfg: dict, cfg: TensorizerConfig): for k, v in extra_cfg.items(): if hasattr(cfg, k): setattr(cfg, k, v) logger.info( "Updating TensorizerConfig with %s from " - "--model-loader-extra-config provided", k + "--model-loader-extra-config provided", + k, ) + def deserialize(args, tensorizer_config): if args.lora_path: tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir - llm = LLM(model=args.model, - load_format="tensorizer", - tensor_parallel_size=args.tensor_parallel_size, - model_loader_extra_config=tensorizer_config, - enable_lora=True, + llm = LLM( + model=args.model, + load_format="tensorizer", + tensor_parallel_size=args.tensor_parallel_size, + model_loader_extra_config=tensorizer_config, + enable_lora=True, ) sampling_params = SamplingParams( - temperature=0, - max_tokens=256, - stop=["[/assistant]"] + temperature=0, max_tokens=256, stop=["[/assistant]"] ) # Truncating this as the extra text isn't necessary - prompts = [ - "[user] Write a SQL query to answer the question based on ..." 
- ] + prompts = ["[user] Write a SQL query to answer the question based on ..."] # Test LoRA load print( llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest("sql-lora", - 1, - args.lora_path, - tensorizer_config_dict = tensorizer_config - .to_serializable()) + prompts, + sampling_params, + lora_request=LoRARequest( + "sql-lora", + 1, + args.lora_path, + tensorizer_config_dict=tensorizer_config.to_serializable(), + ), ) ) else: - llm = LLM(model=args.model, - load_format="tensorizer", - tensor_parallel_size=args.tensor_parallel_size, - model_loader_extra_config=tensorizer_config + llm = LLM( + model=args.model, + load_format="tensorizer", + tensor_parallel_size=args.tensor_parallel_size, + model_loader_extra_config=tensorizer_config, ) return llm @@ -285,17 +306,20 @@ def main(): parser = get_parser() args = parser.parse_args() - s3_access_key_id = (getattr(args, 's3_access_key_id', None) - or os.environ.get("S3_ACCESS_KEY_ID", None)) - s3_secret_access_key = (getattr(args, 's3_secret_access_key', None) - or os.environ.get("S3_SECRET_ACCESS_KEY", None)) - s3_endpoint = (getattr(args, 's3_endpoint', None) - or os.environ.get("S3_ENDPOINT_URL", None)) + s3_access_key_id = getattr(args, "s3_access_key_id", None) or os.environ.get( + "S3_ACCESS_KEY_ID", None + ) + s3_secret_access_key = getattr( + args, "s3_secret_access_key", None + ) or os.environ.get("S3_SECRET_ACCESS_KEY", None) + s3_endpoint = getattr(args, "s3_endpoint", None) or os.environ.get( + "S3_ENDPOINT_URL", None + ) credentials = { "s3_access_key_id": s3_access_key_id, "s3_secret_access_key": s3_secret_access_key, - "s3_endpoint": s3_endpoint + "s3_endpoint": s3_endpoint, } model_ref = args.model @@ -309,25 +333,25 @@ def main(): if args.model_loader_extra_config: extra_config = json.loads(args.model_loader_extra_config) - - tensorizer_dir = (args.serialized_directory or - extra_config.get("tensorizer_dir")) - tensorizer_uri = (getattr(args, "path_to_tensors", None) - or extra_config.get("tensorizer_uri")) + tensorizer_dir = args.serialized_directory or extra_config.get("tensorizer_dir") + tensorizer_uri = getattr(args, "path_to_tensors", None) or extra_config.get( + "tensorizer_uri" + ) if tensorizer_dir and tensorizer_uri: - parser.error("--serialized-directory and --path-to-tensors " - "cannot both be provided") + parser.error( + "--serialized-directory and --path-to-tensors cannot both be provided" + ) if not tensorizer_dir and not tensorizer_uri: - parser.error("Either --serialized-directory or --path-to-tensors " - "must be provided") - + parser.error( + "Either --serialized-directory or --path-to-tensors must be provided" + ) if args.command == "serialize": engine_args = EngineArgs.from_cli_args(args) - input_dir = tensorizer_dir.rstrip('/') + input_dir = tensorizer_dir.rstrip("/") suffix = args.suffix if args.suffix else uuid.uuid4().hex base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" if engine_args.tensor_parallel_size > 1: @@ -339,15 +363,14 @@ def main(): tensorizer_uri=model_path, encryption_keyfile=keyfile, serialization_kwargs=args.serialization_kwargs or {}, - **credentials + **credentials, ) if args.lora_path: tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir tensorize_lora_adapter(args.lora_path, tensorizer_config) - merge_extra_config_with_tensorizer_config(extra_config, - tensorizer_config) + merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config) tensorize_vllm_model(engine_args, tensorizer_config) elif args.command == "deserialize": @@ -356,11 +379,10 @@ 
def main(): tensorizer_dir=args.serialized_directory, encryption_keyfile=keyfile, deserialization_kwargs=args.deserialization_kwargs or {}, - **credentials + **credentials, ) - merge_extra_config_with_tensorizer_config(extra_config, - tensorizer_config) + merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config) deserialize(args, tensorizer_config) else: raise ValueError("Either serialize or deserialize must be specified.") diff --git a/examples/pyproject.toml b/examples/pyproject.toml deleted file mode 100644 index f825cb203269..000000000000 --- a/examples/pyproject.toml +++ /dev/null @@ -1,54 +0,0 @@ -# This local pyproject file is part of the migration from yapf to ruff format. -# It uses the same core rules as the main pyproject.toml file, but with the -# following differences: -# - ruff line length is overridden to 88 -# - deprecated typing ignores (UP006, UP035) have been removed - -[tool.ruff] -line-length = 88 -exclude = [ - # External file, leaving license intact - "examples/other/fp8/quantizer/quantize.py", - "vllm/vllm_flash_attn/flash_attn_interface.pyi" -] - -[tool.ruff.lint.per-file-ignores] -"vllm/third_party/**" = ["ALL"] -"vllm/version.py" = ["F401"] -"vllm/_version.py" = ["ALL"] - -[tool.ruff.lint] -select = [ - # pycodestyle - "E", - # Pyflakes - "F", - # pyupgrade - "UP", - # flake8-bugbear - "B", - # flake8-simplify - "SIM", - # isort - "I", - # flake8-logging-format - "G", -] -ignore = [ - # star imports - "F405", "F403", - # lambda expression assignment - "E731", - # Loop control variable not used within loop body - "B007", - # f-string format - "UP032", - # Can remove once 3.10+ is the minimum Python version - "UP007", -] - -[tool.ruff.lint.isort] -known-first-party = ["vllm"] - -[tool.ruff.format] -docstring-code-format = true \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 034a21f1c12b..704f28fa6536 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,27 +52,10 @@ lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:regi where = ["."] include = ["vllm*"] -[tool.yapfignore] -ignore_patterns = [ - ".buildkite/**", - "benchmarks/**", - "build/**", - "examples/**", -] - -[tool.ruff] -# Allow lines to be as long as 80. 
-line-length = 80 - [tool.ruff.lint.per-file-ignores] "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] -# Python 3.8 typing - skip V0 code -"vllm/attention/**/*.py" = ["UP006", "UP035"] -"vllm/engine/**/*.py" = ["UP006", "UP035"] -"vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/worker/**/*.py" = ["UP006", "UP035"] [tool.ruff.lint] select = [ @@ -87,7 +70,7 @@ select = [ # flake8-simplify "SIM", # isort - # "I", + "I", # flake8-logging-format "G", ] @@ -104,21 +87,15 @@ ignore = [ "UP007", ] +[tool.ruff.format] +docstring-code-format = true + [tool.mypy] plugins = ['pydantic.mypy'] ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" -[tool.isort] -skip_glob = [ - ".buildkite/*", - "benchmarks/*", - "examples/*", -] -use_parentheses = true -skip_gitignore = true - [tool.pytest.ini_options] markers = [ "slow_test", diff --git a/requirements/common.txt b/requirements/common.txt index a52745f69870..1530e5a09e75 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -49,3 +49,4 @@ pybase64 # fast base64 implementation cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss +gpt-oss >= 0.0.7 diff --git a/setup.py b/setup.py index 5491046991ca..53c460d2c5b8 100644 --- a/setup.py +++ b/setup.py @@ -34,32 +34,36 @@ def load_module_from_path(module_name, path): # cannot import envs directly because it depends on vllm, # which is not installed yet -envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py')) +envs = load_module_from_path("envs", os.path.join(ROOT_DIR, "vllm", "envs.py")) VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE if sys.platform.startswith("darwin") and VLLM_TARGET_DEVICE != "cpu": - logger.warning( - "VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS") + logger.warning("VLLM_TARGET_DEVICE automatically set to `cpu` due to macOS") VLLM_TARGET_DEVICE = "cpu" -elif not (sys.platform.startswith("linux") - or sys.platform.startswith("darwin")): +elif not (sys.platform.startswith("linux") or sys.platform.startswith("darwin")): logger.warning( "vLLM only supports Linux platform (including WSL) and MacOS." 
"Building on %s, " - "so vLLM may not be able to run correctly", sys.platform) + "so vLLM may not be able to run correctly", + sys.platform, + ) VLLM_TARGET_DEVICE = "empty" -elif (sys.platform.startswith("linux") and torch.version.cuda is None - and os.getenv("VLLM_TARGET_DEVICE") is None - and torch.version.hip is None): +elif ( + sys.platform.startswith("linux") + and torch.version.cuda is None + and os.getenv("VLLM_TARGET_DEVICE") is None + and torch.version.hip is None +): # if cuda or hip is not available and VLLM_TARGET_DEVICE is not set, # fallback to cpu VLLM_TARGET_DEVICE = "cpu" def is_sccache_available() -> bool: - return which("sccache") is not None and \ - not bool(int(os.getenv("VLLM_DISABLE_SCCACHE", "0"))) + return which("sccache") is not None and not bool( + int(os.getenv("VLLM_DISABLE_SCCACHE", "0")) + ) def is_ccache_available() -> bool: @@ -83,8 +87,7 @@ def is_url_available(url: str) -> bool: class CMakeExtension(Extension): - - def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: + def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None: super().__init__(name, sources=[], py_limited_api=True, **kwa) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) @@ -121,8 +124,8 @@ def compute_num_jobs(self): if nvcc_threads is not None: nvcc_threads = int(nvcc_threads) logger.info( - "Using NVCC_THREADS=%d as the number of nvcc threads.", - nvcc_threads) + "Using NVCC_THREADS=%d as the number of nvcc threads.", nvcc_threads + ) else: nvcc_threads = 1 num_jobs = max(1, num_jobs // nvcc_threads) @@ -146,36 +149,36 @@ def configure(self, ext: CMakeExtension) -> None: cfg = envs.CMAKE_BUILD_TYPE or default_cfg cmake_args = [ - '-DCMAKE_BUILD_TYPE={}'.format(cfg), - '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), + "-DCMAKE_BUILD_TYPE={}".format(cfg), + "-DVLLM_TARGET_DEVICE={}".format(VLLM_TARGET_DEVICE), ] verbose = envs.VERBOSE if verbose: - cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON'] + cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"] if is_sccache_available(): cmake_args += [ - '-DCMAKE_C_COMPILER_LAUNCHER=sccache', - '-DCMAKE_CXX_COMPILER_LAUNCHER=sccache', - '-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache', - '-DCMAKE_HIP_COMPILER_LAUNCHER=sccache', + "-DCMAKE_C_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=sccache", + "-DCMAKE_CUDA_COMPILER_LAUNCHER=sccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=sccache", ] elif is_ccache_available(): cmake_args += [ - '-DCMAKE_C_COMPILER_LAUNCHER=ccache', - '-DCMAKE_CXX_COMPILER_LAUNCHER=ccache', - '-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache', - '-DCMAKE_HIP_COMPILER_LAUNCHER=ccache', + "-DCMAKE_C_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CUDA_COMPILER_LAUNCHER=ccache", + "-DCMAKE_HIP_COMPILER_LAUNCHER=ccache", ] # Pass the python executable to cmake so it can find an exact # match. - cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)] + cmake_args += ["-DVLLM_PYTHON_EXECUTABLE={}".format(sys.executable)] # Pass the python path to cmake so it can reuse the build dependencies # on subsequent calls to python. - cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))] + cmake_args += ["-DVLLM_PYTHON_PATH={}".format(":".join(sys.path))] # Override the base directory for FetchContent downloads to $ROOT/.deps # This allows sharing dependencies between profiles, @@ -183,7 +186,7 @@ def configure(self, ext: CMakeExtension) -> None: # To override this, set the FETCHCONTENT_BASE_DIR environment variable. 
fc_base_dir = os.path.join(ROOT_DIR, ".deps") fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir) - cmake_args += ['-DFETCHCONTENT_BASE_DIR={}'.format(fc_base_dir)] + cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)] # # Setup parallelism and build tool @@ -191,30 +194,36 @@ def configure(self, ext: CMakeExtension) -> None: num_jobs, nvcc_threads = self.compute_num_jobs() if nvcc_threads: - cmake_args += ['-DNVCC_THREADS={}'.format(nvcc_threads)] + cmake_args += ["-DNVCC_THREADS={}".format(nvcc_threads)] if is_ninja_available(): - build_tool = ['-G', 'Ninja'] + build_tool = ["-G", "Ninja"] cmake_args += [ - '-DCMAKE_JOB_POOL_COMPILE:STRING=compile', - '-DCMAKE_JOB_POOLS:STRING=compile={}'.format(num_jobs), + "-DCMAKE_JOB_POOL_COMPILE:STRING=compile", + "-DCMAKE_JOB_POOLS:STRING=compile={}".format(num_jobs), ] else: # Default build tool to whatever cmake picks. build_tool = [] # Make sure we use the nvcc from CUDA_HOME if _is_cuda(): - cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc'] + cmake_args += [f"-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc"] + + other_cmake_args = os.environ.get("CMAKE_ARGS") + if other_cmake_args: + cmake_args += other_cmake_args.split() + subprocess.check_call( - ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args], - cwd=self.build_temp) + ["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args], + cwd=self.build_temp, + ) def build_extensions(self) -> None: # Ensure that CMake is present and working try: - subprocess.check_output(['cmake', '--version']) + subprocess.check_output(["cmake", "--version"]) except OSError as e: - raise RuntimeError('Cannot find CMake executable') from e + raise RuntimeError("Cannot find CMake executable") from e # Create build directory if it does not exist. if not os.path.exists(self.build_temp): @@ -253,13 +262,18 @@ def target_name(s: str) -> str: # CMake appends the extension prefix to the install path, # and outdir already contains that prefix, so we need to remove it. 
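The configure step above now appends anything found in the `CMAKE_ARGS` environment variable to its `cmake` invocation. A minimal sketch of using that hook from a source checkout is below, assuming a working build environment; the chosen flag is only illustrative, and any valid CMake argument can be passed the same way.

```python
import os
import subprocess
import sys

# Illustrative: forward an extra CMake flag to setup.py's configure() through
# CMAKE_ARGS, then build vLLM in editable mode from the repository root.
env = dict(os.environ, CMAKE_ARGS="-DCMAKE_VERBOSE_MAKEFILE=ON")
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-e", "."],
    env=env,
)
```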
prefix = outdir - for _ in range(ext.name.count('.')): + for _ in range(ext.name.count(".")): prefix = prefix.parent # prefix here should actually be the same for all components install_args = [ - "cmake", "--install", ".", "--prefix", prefix, "--component", - target_name(ext.name) + "cmake", + "--install", + ".", + "--prefix", + prefix, + "--component", + target_name(ext.name), ] subprocess.check_call(install_args, cwd=self.build_temp) @@ -270,12 +284,15 @@ def run(self): # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current # directory so that they can be included in the editable build import glob - files = glob.glob(os.path.join(self.build_lib, "vllm", - "vllm_flash_attn", "**", "*.py"), - recursive=True) + + files = glob.glob( + os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "**", "*.py"), + recursive=True, + ) for file in files: - dst_file = os.path.join("vllm/vllm_flash_attn", - file.split("vllm/vllm_flash_attn/")[-1]) + dst_file = os.path.join( + "vllm/vllm_flash_attn", file.split("vllm/vllm_flash_attn/")[-1] + ) print(f"Copying {file} to {dst_file}") os.makedirs(os.path.dirname(dst_file), exist_ok=True) self.copy_file(file, dst_file) @@ -285,8 +302,7 @@ class precompiled_build_ext(build_ext): """Disables extension building when using precompiled binaries.""" def run(self) -> None: - assert _is_cuda( - ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" + assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" def build_extensions(self) -> None: print("Skipping build_ext: using precompiled extensions.") @@ -307,9 +323,9 @@ def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict: wheel_filename = wheel_url_or_path.split("/")[-1] temp_dir = tempfile.mkdtemp(prefix="vllm-wheels") wheel_path = os.path.join(temp_dir, wheel_filename) - print(f"Downloading wheel from {wheel_url_or_path} " - f"to {wheel_path}") + print(f"Downloading wheel from {wheel_url_or_path} to {wheel_path}") from urllib.request import urlretrieve + urlretrieve(wheel_url_or_path, filename=wheel_path) else: wheel_path = wheel_url_or_path @@ -330,25 +346,29 @@ def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict: ] compiled_regex = re.compile( - r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") + r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" + ) file_members = list( - filter(lambda x: x.filename in files_to_copy, - wheel.filelist)) + filter(lambda x: x.filename in files_to_copy, wheel.filelist) + ) file_members += list( - filter(lambda x: compiled_regex.match(x.filename), - wheel.filelist)) + filter(lambda x: compiled_regex.match(x.filename), wheel.filelist) + ) for file in file_members: print(f"[extract] {file.filename}") target_path = os.path.join(".", file.filename) os.makedirs(os.path.dirname(target_path), exist_ok=True) - with wheel.open(file.filename) as src, open( - target_path, "wb") as dst: + with ( + wheel.open(file.filename) as src, + open(target_path, "wb") as dst, + ): shutil.copyfileobj(src, dst) pkg = os.path.dirname(file.filename).replace("/", ".") package_data_patch.setdefault(pkg, []).append( - os.path.basename(file.filename)) + os.path.basename(file.filename) + ) return package_data_patch finally: @@ -364,10 +384,13 @@ def get_base_commit_in_main_branch() -> str: try: # Get the latest commit hash of the upstream main branch. 
- resp_json = subprocess.check_output([ - "curl", "-s", - "https://api.github.com/repos/vllm-project/vllm/commits/main" - ]).decode("utf-8") + resp_json = subprocess.check_output( + [ + "curl", + "-s", + "https://api.github.com/repos/vllm-project/vllm/commits/main", + ] + ).decode("utf-8") upstream_main_commit = json.loads(resp_json)["sha"] # In Docker build context, .git may be immutable or missing. @@ -377,25 +400,32 @@ def get_base_commit_in_main_branch() -> str: # Check if the upstream_main_commit exists in the local repo try: subprocess.check_output( - ["git", "cat-file", "-e", f"{upstream_main_commit}"]) + ["git", "cat-file", "-e", f"{upstream_main_commit}"] + ) except subprocess.CalledProcessError: # If not present, fetch it from the remote repository. # Note that this does not update any local branches, # but ensures that this commit ref and its history are # available in our local repo. - subprocess.check_call([ - "git", "fetch", "https://github.com/vllm-project/vllm", - "main" - ]) + subprocess.check_call( + ["git", "fetch", "https://github.com/vllm-project/vllm", "main"] + ) # Then get the commit hash of the current branch that is the same as # the upstream main commit. - current_branch = subprocess.check_output( - ["git", "branch", "--show-current"]).decode("utf-8").strip() + current_branch = ( + subprocess.check_output(["git", "branch", "--show-current"]) + .decode("utf-8") + .strip() + ) - base_commit = subprocess.check_output([ - "git", "merge-base", f"{upstream_main_commit}", current_branch - ]).decode("utf-8").strip() + base_commit = ( + subprocess.check_output( + ["git", "merge-base", f"{upstream_main_commit}", current_branch] + ) + .decode("utf-8") + .strip() + ) return base_commit except ValueError as err: raise ValueError(err) from None @@ -403,7 +433,9 @@ def get_base_commit_in_main_branch() -> str: logger.warning( "Failed to get the base commit in the main branch. " "Using the nightly wheel. 
The libraries in this " - "wheel may not be compatible with your dev branch: %s", err) + "wheel may not be compatible with your dev branch: %s", + err, + ) return "nightly" @@ -413,12 +445,13 @@ def _no_device() -> bool: def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None - return (VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu()) + return VLLM_TARGET_DEVICE == "cuda" and has_cuda and not _is_tpu() def _is_hip() -> bool: - return (VLLM_TARGET_DEVICE == "cuda" - or VLLM_TARGET_DEVICE == "rocm") and torch.version.hip is not None + return ( + VLLM_TARGET_DEVICE == "cuda" or VLLM_TARGET_DEVICE == "rocm" + ) and torch.version.hip is not None def _is_tpu() -> bool: @@ -457,8 +490,12 @@ def get_rocm_version(): minor = ctypes.c_uint32() patch = ctypes.c_uint32() - if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor), - ctypes.byref(patch)) == 0): + if ( + get_rocm_core_version( + ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch) + ) + == 0 + ): return f"{major.value}.{minor.value}.{patch.value}" return None except Exception: @@ -471,8 +508,9 @@ def get_nvcc_cuda_version() -> Version: Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py """ assert CUDA_HOME is not None, "CUDA_HOME is not set" - nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], - universal_newlines=True) + nvcc_output = subprocess.check_output( + [CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True + ) output = nvcc_output.split() release_idx = output.index("release") + 1 nvcc_cuda_version = parse(output[release_idx].split(",")[0]) @@ -484,14 +522,20 @@ def get_gaudi_sw_version(): Returns the driver version. """ # Enable console printing for `hl-smi` check - output = subprocess.run("hl-smi", - shell=True, - text=True, - capture_output=True, - env={"ENABLE_CONSOLE": "true"}) + output = subprocess.run( + "hl-smi", + shell=True, + text=True, + capture_output=True, + env={"ENABLE_CONSOLE": "true"}, + ) if output.returncode == 0 and output.stdout: - return output.stdout.split("\n")[2].replace( - " ", "").split(":")[1][:-1].split("-")[0] + return ( + output.stdout.split("\n")[2] + .replace(" ", "") + .split(":")[1][:-1] + .split("-")[0] + ) return "0.0.0" # when hl-smi is not available @@ -541,8 +585,11 @@ def _read_requirements(filename: str) -> list[str]: for line in requirements: if line.startswith("-r "): resolved_requirements += _read_requirements(line.split()[1]) - elif not line.startswith("--") and not line.startswith( - "#") and line.strip() != "": + elif ( + not line.startswith("--") + and not line.startswith("#") + and line.strip() != "" + ): resolved_requirements.append(line) return resolved_requirements @@ -553,7 +600,7 @@ def _read_requirements(filename: str) -> list[str]: cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: - if ("vllm-flash-attn" in req and cuda_major != "12"): + if "vllm-flash-attn" in req and cuda_major != "12": # vllm-flash-attn is built only for CUDA 12.x. # Skip for other versions. 
continue @@ -568,8 +615,7 @@ def _read_requirements(filename: str) -> list[str]: elif _is_xpu(): requirements = _read_requirements("xpu.txt") else: - raise ValueError( - "Unsupported platform, please use CUDA, ROCm, or CPU.") + raise ValueError("Unsupported platform, please use CUDA, ROCm, or CPU.") return requirements @@ -585,14 +631,13 @@ def _read_requirements(filename: str) -> list[str]: ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"): # FA3 requires CUDA 12.3 or later - ext_modules.append( - CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) + ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) # Optional since this doesn't get built (produce an .so file) when # not targeting a hopper system + ext_modules.append(CMakeExtension(name="vllm._flashmla_C", optional=True)) ext_modules.append( - CMakeExtension(name="vllm._flashmla_C", optional=True)) - ext_modules.append( - CMakeExtension(name="vllm._flashmla_extension_C", optional=True)) + CMakeExtension(name="vllm._flashmla_extension_C", optional=True) + ) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): @@ -614,6 +659,7 @@ def _read_requirements(filename: str) -> list[str]: wheel_url = wheel_location else: import platform + arch = platform.machine() if arch == "x86_64": wheel_tag = "manylinux1_x86_64" @@ -623,8 +669,11 @@ def _read_requirements(filename: str) -> list[str]: raise ValueError(f"Unsupported architecture: {arch}") base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch() wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" - nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" + nightly_wheel_url = ( + f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl" + ) from urllib.request import urlopen + try: with urlopen(wheel_url) as resp: if resp.status != 200: @@ -633,8 +682,7 @@ def _read_requirements(filename: str) -> list[str]: print(f"[warn] Falling back to nightly wheel: {e}") wheel_url = nightly_wheel_url - patch = precompiled_wheel_utils.extract_precompiled_and_patch_package( - wheel_url) + patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(wheel_url) for pkg, files in patch.items(): package_data.setdefault(pkg, []).extend(files) @@ -645,8 +693,9 @@ def _read_requirements(filename: str) -> list[str]: cmdclass = {} else: cmdclass = { - "build_ext": - precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext + "build_ext": precompiled_build_ext + if envs.VLLM_USE_PRECOMPILED + else cmake_build_ext } setup( @@ -659,8 +708,11 @@ def _read_requirements(filename: str) -> list[str]: "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.1.10"], "runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"], - "audio": ["librosa", "soundfile", - "mistral_common[audio]"], # Required for audio processing + "audio": [ + "librosa", + "soundfile", + "mistral_common[audio]", + ], # Required for audio processing "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile "flashinfer": ["flashinfer-python==0.3.1"], diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 411f3e01bc2c..d63c82102b6b 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ 
b/tests/basic_correctness/test_basic_correctness.py @@ -4,6 +4,7 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ + import os import weakref from unittest.mock import Mock @@ -37,16 +38,21 @@ def test_vllm_gc_ed(): def _fix_prompt_embed_outputs( - vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner, - example_prompts: list[str]) -> list[tuple[list[int], str]]: + vllm_outputs: list[tuple[list[int], str]], + hf_model: HfRunner, + example_prompts: list[str], +) -> list[tuple[list[int], str]]: fixed_vllm_outputs = [] for vllm_output, hf_input, prompt in zip( - vllm_outputs, hf_model.get_inputs(example_prompts), - example_prompts): + vllm_outputs, hf_model.get_inputs(example_prompts), example_prompts + ): hf_input_ids = hf_input["input_ids"].tolist()[0] fixed_vllm_outputs.append( - (hf_input_ids + vllm_output[0][len(hf_input_ids):], - prompt + vllm_output[1])) + ( + hf_input_ids + vllm_output[0][len(hf_input_ids) :], + prompt + vllm_output[1], + ) + ) return fixed_vllm_outputs @@ -69,8 +75,7 @@ def test_models( enable_prompt_embeds: bool, ) -> None: if backend == "XFORMERS" and model == "google/gemma-2-2b-it": - pytest.skip( - f"{backend} does not support gemma2 with full context length.") + pytest.skip(f"{backend} does not support gemma2 with full context length.") with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", backend) @@ -78,34 +83,35 @@ def test_models( # 5042 tokens for gemma2 # gemma2 has alternating sliding window size of 4096 # we need a prompt with more than 4096 tokens to test the sliding window - prompt = "The following numbers of the sequence " + ", ".join( - str(i) for i in range(1024)) + " are:" + prompt = ( + "The following numbers of the sequence " + + ", ".join(str(i) for i in range(1024)) + + " are:" + ) example_prompts = [prompt] with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) if enable_prompt_embeds: with torch.no_grad(): - prompt_embeds = hf_model.get_prompt_embeddings( - example_prompts) + prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) with VllmRunner( - model, - max_model_len=8192, - enforce_eager=enforce_eager, - enable_prompt_embeds=enable_prompt_embeds, - gpu_memory_utilization=0.7, - async_scheduling=async_scheduling, - distributed_executor_backend=model_executor, + model, + max_model_len=8192, + enforce_eager=enforce_eager, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, + async_scheduling=async_scheduling, + distributed_executor_backend=model_executor, ) as vllm_model: if enable_prompt_embeds: - vllm_outputs = vllm_model.generate_greedy( - prompt_embeds, max_tokens) + vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens) vllm_outputs = _fix_prompt_embed_outputs( - vllm_outputs, hf_model, example_prompts) + vllm_outputs, hf_model, example_prompts + ) else: - vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -117,21 +123,18 @@ def test_models( @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( - "model, distributed_executor_backend, attention_backend, " - "test_suite, extra_env", [ + "model, distributed_executor_backend, attention_backend, test_suite, extra_env", + [ ("distilbert/distilgpt2", "ray", "", "L4", {}), ("distilbert/distilgpt2", "mp", "", "L4", {}), - ("distilbert/distilgpt2", "ray", "", "L4", { - "VLLM_SLEEP_WHEN_IDLE": "1" - }), - 
("distilbert/distilgpt2", "mp", "", "L4", { - "VLLM_SLEEP_WHEN_IDLE": "1" - }), + ("distilbert/distilgpt2", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), + ("distilbert/distilgpt2", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("distilbert/distilgpt2", "ray", "", "A100", {}), ("distilbert/distilgpt2", "mp", "", "A100", {}), - ]) + ], +) @pytest.mark.parametrize("enable_prompt_embeds", [True, False]) def test_models_distributed( monkeypatch: pytest.MonkeyPatch, @@ -149,11 +152,14 @@ def test_models_distributed( pytest.skip(f"Skip test for {test_suite}") with monkeypatch.context() as monkeypatch_context: - if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa + if ( + model == "meta-llama/Llama-3.2-1B-Instruct" + and distributed_executor_backend == "ray" + and attention_backend == "" + and test_suite == "L4" + ): # noqa if enable_prompt_embeds: - pytest.skip( - "enable_prompt_embeds does not work with ray compiled dag." - ) + pytest.skip("enable_prompt_embeds does not work with ray compiled dag.") monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1") @@ -175,30 +181,26 @@ def test_models_distributed( # will hurt multiprocessing backend with fork method # (the default method). with vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - distributed_executor_backend=distributed_executor_backend, - enable_prompt_embeds=enable_prompt_embeds, - gpu_memory_utilization=0.7, + model, + dtype=dtype, + tensor_parallel_size=2, + distributed_executor_backend=distributed_executor_backend, + enable_prompt_embeds=enable_prompt_embeds, + gpu_memory_utilization=0.7, ) as vllm_model: if enable_prompt_embeds: with hf_runner(model, dtype=dtype) as hf_model: with torch.no_grad(): - prompt_embeds = hf_model.get_prompt_embeddings( - example_prompts) - vllm_outputs = vllm_model.generate_greedy( - prompt_embeds, max_tokens) + prompt_embeds = hf_model.get_prompt_embeddings(example_prompts) + vllm_outputs = vllm_model.generate_greedy(prompt_embeds, max_tokens) vllm_outputs = _fix_prompt_embed_outputs( - vllm_outputs, hf_model, example_prompts) - hf_outputs = hf_model.generate_greedy( - example_prompts, max_tokens) + vllm_outputs, hf_model, example_prompts + ) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) else: - vllm_outputs = vllm_model.generate_greedy( - example_prompts, max_tokens) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy( - example_prompts, max_tokens) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -209,27 +211,23 @@ def test_models_distributed( def test_failed_model_execution(vllm_runner, monkeypatch) -> None: - from vllm.envs import VLLM_USE_V1 if not VLLM_USE_V1: pytest.skip("Skipping V0 test, dump input not supported") # Needed to mock an error in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: + with vllm_runner("facebook/opt-125m", enforce_eager=True) as vllm_model: if isinstance(vllm_model.llm.llm_engine, LLMEngineV1): 
v1_test_failed_model_execution(vllm_model)
 
 
 def v1_test_failed_model_execution(vllm_model):
-
     engine = vllm_model.llm.llm_engine
-    mocked_execute_model = Mock(
-        side_effect=RuntimeError("Mocked Critical Error"))
-    engine.engine_core.engine_core.model_executor.execute_model =\
-         mocked_execute_model
+    mocked_execute_model = Mock(side_effect=RuntimeError("Mocked Critical Error"))
+    engine.engine_core.engine_core.model_executor.execute_model = mocked_execute_model
 
     with pytest.raises(RuntimeError) as exc_info:
         prompts = [
diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py
index 28bfe9e7c802..3c1e01d072b9 100644
--- a/tests/basic_correctness/test_cpu_offload.py
+++ b/tests/basic_correctness/test_cpu_offload.py
@@ -5,5 +5,6 @@
 
 
 def test_cpu_offload():
-    compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", [],
-                         ["--cpu-offload-gb", "1"])
+    compare_two_settings(
+        "meta-llama/Llama-3.2-1B-Instruct", [], ["--cpu-offload-gb", "1"]
+    )
diff --git a/tests/basic_correctness/test_cumem.py b/tests/basic_correctness/test_cumem.py
index 508740ab2938..b7cd98e27403 100644
--- a/tests/basic_correctness/test_cumem.py
+++ b/tests/basic_correctness/test_cumem.py
@@ -23,13 +23,13 @@ def test_python_error():
     tensors = []
     with allocator.use_memory_pool():
         # allocate 70% of the total memory
-        x = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
+        x = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
         tensors.append(x)
         # release the memory
         allocator.sleep()
 
         # allocate more memory than the total memory
-        y = torch.empty(alloc_bytes, dtype=torch.uint8, device='cuda')
+        y = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda")
         tensors.append(y)
     with pytest.raises(RuntimeError):
         # when the allocator is woken up, it should raise an error
@@ -41,17 +41,17 @@ def test_python_error():
 def test_basic_cumem():
     # some tensors from default memory pool
     shape = (1024, 1024)
-    x = torch.empty(shape, device='cuda')
+    x = torch.empty(shape, device="cuda")
     x.zero_()
 
     # some tensors from custom memory pool
     allocator = CuMemAllocator.get_instance()
     with allocator.use_memory_pool():
         # custom memory pool
-        y = torch.empty(shape, device='cuda')
+        y = torch.empty(shape, device="cuda")
         y.zero_()
         y += 1
-        z = torch.empty(shape, device='cuda')
+        z = torch.empty(shape, device="cuda")
         z.zero_()
         z += 2
 
@@ -74,16 +74,16 @@ def test_basic_cumem():
 def test_cumem_with_cudagraph():
     allocator = CuMemAllocator.get_instance()
     with allocator.use_memory_pool():
-        weight = torch.eye(1024, device='cuda')
+        weight = torch.eye(1024, device="cuda")
     with allocator.use_memory_pool(tag="discard"):
-        cache = torch.empty(1024, 1024, device='cuda')
+        cache = torch.empty(1024, 1024, device="cuda")
 
     def model(x):
         out = x @ weight
-        cache[:out.size(0)].copy_(out)
+        cache[: out.size(0)].copy_(out)
         return out + 1
 
-    x = torch.empty(128, 1024, device='cuda')
+    x = torch.empty(128, 1024, device="cuda")
 
     # warmup
     model(x)
@@ -109,7 +109,7 @@ def model(x):
         model_graph.replay()
 
     # cache content is as expected
-    assert torch.allclose(x, cache[:x.size(0)])
+    assert torch.allclose(x, cache[: x.size(0)])
 
     # output content is as expected
     assert torch.allclose(y, x + 1)
@@ -123,7 +123,8 @@ def model(x):
         ("meta-llama/Llama-3.2-1B", True),
         # sleep mode with pytorch checkpoint
         ("facebook/opt-125m", True),
-    ])
+    ],
+)
 def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
     with monkeypatch.context() as m:
         assert use_v1
diff --git a/tests/benchmarks/test_latency_cli.py 
b/tests/benchmarks/test_latency_cli.py index 2279c846e01c..54075a3a15e6 100644 --- a/tests/benchmarks/test_latency_cli.py +++ b/tests/benchmarks/test_latency_cli.py @@ -10,8 +10,18 @@ @pytest.mark.benchmark def test_bench_latency(): command = [ - "vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32", - "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + "vllm", + "bench", + "latency", + "--model", + MODEL_NAME, + "--input-len", + "32", + "--output-len", + "1", + "--enforce-eager", + "--load-format", + "dummy", ] result = subprocess.run(command, capture_output=True, text=True) print(result.stdout) diff --git a/tests/benchmarks/test_random_dataset.py b/tests/benchmarks/test_random_dataset.py index 26cae369cdd5..90527dbeae28 100644 --- a/tests/benchmarks/test_random_dataset.py +++ b/tests/benchmarks/test_random_dataset.py @@ -7,8 +7,11 @@ import pytest from transformers import AutoTokenizer, PreTrainedTokenizerBase -from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset, - SampleRequest) +from vllm.benchmarks.datasets import ( + RandomDataset, + RandomMultiModalDataset, + SampleRequest, +) @pytest.fixture(scope="session") @@ -27,11 +30,9 @@ class Params(NamedTuple): @pytest.fixture(scope="session") def random_dataset_params() -> Params: - return Params(num_requests=16, - prefix_len=7, - range_ratio=0.3, - input_len=50, - output_len=20) + return Params( + num_requests=16, prefix_len=7, range_ratio=0.3, input_len=50, output_len=20 + ) def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: @@ -39,13 +40,15 @@ def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: return (req.prompt, req.prompt_len, req.expected_output_len) -def _collect_samples(dataset: RandomDataset, - tokenizer: PreTrainedTokenizerBase, - num_requests: int = 16, - prefix_len: int = 7, - range_ratio: float = 0.3, - input_len: int = 50, - output_len: int = 20) -> list[tuple[str, int, int]]: +def _collect_samples( + dataset: RandomDataset, + tokenizer: PreTrainedTokenizerBase, + num_requests: int = 16, + prefix_len: int = 7, + range_ratio: float = 0.3, + input_len: int = 50, + output_len: int = 20, +) -> list[tuple[str, int, int]]: samples = dataset.sample( tokenizer=tokenizer, num_requests=num_requests, @@ -59,8 +62,8 @@ def _collect_samples(dataset: RandomDataset, @pytest.mark.benchmark def test_random_dataset_same_seed( - hf_tokenizer: PreTrainedTokenizerBase, - random_dataset_params: Params) -> None: + hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params +) -> None: """Same seed should yield identical outputs, even if global RNGs change. 
This guards against accidental reliance on Python's random or np.random @@ -70,13 +73,15 @@ def test_random_dataset_same_seed( common_seed = 123 dataset_a = RandomDataset(random_seed=common_seed) dataset_b = RandomDataset(random_seed=common_seed) - a = _collect_samples(dataset_a, - hf_tokenizer, - num_requests=p.num_requests, - prefix_len=p.prefix_len, - range_ratio=p.range_ratio, - input_len=p.input_len, - output_len=p.output_len) + a = _collect_samples( + dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len, + ) # Perturb global RNG state to ensure isolation random.seed(999) @@ -84,43 +89,50 @@ def test_random_dataset_same_seed( np.random.seed(888) _ = [np.random.random() for _ in range(100)] - b = _collect_samples(dataset_b, - hf_tokenizer, - num_requests=p.num_requests, - prefix_len=p.prefix_len, - range_ratio=p.range_ratio, - input_len=p.input_len, - output_len=p.output_len) + b = _collect_samples( + dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len, + ) assert a == b + @pytest.mark.benchmark def test_random_dataset_different_seeds( - hf_tokenizer: PreTrainedTokenizerBase, - random_dataset_params: Params) -> None: + hf_tokenizer: PreTrainedTokenizerBase, random_dataset_params: Params +) -> None: """Different seeds should change outputs with overwhelming likelihood.""" p = random_dataset_params seed_a = 0 dataset_a = RandomDataset(random_seed=seed_a) - a = _collect_samples(dataset_a, - hf_tokenizer, - num_requests=p.num_requests, - prefix_len=p.prefix_len, - range_ratio=p.range_ratio, - input_len=p.input_len, - output_len=p.output_len) + a = _collect_samples( + dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len, + ) seed_b = 999 dataset_b = RandomDataset(random_seed=seed_b) # Perturb global RNG with same seed as dataset_a to ensure isolation random.seed(seed_a) np.random.seed(seed_a) - b = _collect_samples(dataset_b, - hf_tokenizer, - num_requests=p.num_requests, - prefix_len=p.prefix_len, - range_ratio=p.range_ratio, - input_len=p.input_len, - output_len=p.output_len) + b = _collect_samples( + dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len, + ) assert a != b @@ -128,6 +140,7 @@ def test_random_dataset_different_seeds( # RandomMultiModalDataset tests # ----------------------------- + def _mm_fingerprint_sample( req: SampleRequest, ) -> tuple[str, int, int, int, list[str]]: @@ -152,8 +165,13 @@ def _mm_fingerprint_sample( item_prefixes.append(f"video:{url[:22]}") else: item_prefixes.append("unknown:") - return (req.prompt, req.prompt_len, req.expected_output_len, len(items), - item_prefixes) + return ( + req.prompt, + req.prompt_len, + req.expected_output_len, + len(items), + item_prefixes, + ) def _collect_mm_samples( @@ -214,6 +232,7 @@ def test_random_mm_different_seeds( fb = [_mm_fingerprint_sample(s) for s in b] assert fa != fb + @pytest.mark.benchmark def test_random_mm_respects_limits( hf_tokenizer: PreTrainedTokenizerBase, @@ -271,9 +290,9 @@ def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None: for s in samples: assert s.multi_modal_data == [] + @pytest.mark.benchmark -def 
test_random_mm_num_items_per_prompt( - hf_tokenizer: PreTrainedTokenizerBase) -> None: +def test_random_mm_num_items_per_prompt(hf_tokenizer: PreTrainedTokenizerBase) -> None: ds = RandomMultiModalDataset(random_seed=0) # Fixed number of images per prompt # set num_mm_items_range_ratio to 0.0 @@ -300,7 +319,6 @@ def test_random_mm_num_items_per_prompt( def test_random_mm_bucket_config_not_mutated( hf_tokenizer: PreTrainedTokenizerBase, ) -> None: - ds = RandomMultiModalDataset(random_seed=0) # This bucket config is not normalized to sum to 1 # and has more buckets than requested images @@ -321,7 +339,6 @@ def test_random_mm_bucket_config_not_mutated( # Ensure the original dict content is unchanged assert original == snapshot - # Vary number of mm items per prompt # set num_mm_items_range_ratio to 0.5 samples_varying_items = _collect_mm_samples( diff --git a/tests/benchmarks/test_serve_cli.py b/tests/benchmarks/test_serve_cli.py index fafbef5f3718..90d685c966d3 100644 --- a/tests/benchmarks/test_serve_cli.py +++ b/tests/benchmarks/test_serve_cli.py @@ -11,9 +11,7 @@ @pytest.fixture(scope="module") def server(): - args = [ - "--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy" - ] + args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -46,6 +44,7 @@ def test_bench_serve(server): assert result.returncode == 0, f"Benchmark failed: {result.stderr}" + @pytest.mark.benchmark def test_bench_serve_chat(server): command = [ diff --git a/tests/benchmarks/test_throughput_cli.py b/tests/benchmarks/test_throughput_cli.py index b61e51db4fbe..a579b59e8af4 100644 --- a/tests/benchmarks/test_throughput_cli.py +++ b/tests/benchmarks/test_throughput_cli.py @@ -10,8 +10,18 @@ @pytest.mark.benchmark def test_bench_throughput(): command = [ - "vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len", - "32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy" + "vllm", + "bench", + "throughput", + "--model", + MODEL_NAME, + "--input-len", + "32", + "--output-len", + "1", + "--enforce-eager", + "--load-format", + "dummy", ] result = subprocess.run(command, capture_output=True, text=True) print(result.stdout) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index f25c367433f4..36bc832a1329 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -23,8 +23,7 @@ class LazyInitPass(InductorPass): and then immediately invoke it. """ - def __init__(self, pass_cls: type[VllmInductorPass], - vllm_config: VllmConfig): + def __init__(self, pass_cls: type[VllmInductorPass], vllm_config: VllmConfig): self.pass_cls = pass_cls self.vllm_config = weakref.proxy(vllm_config) # avoid cycle @@ -45,20 +44,18 @@ class TestBackend: Inductor config is default-initialized from VllmConfig.CompilationConfig. 
""" - def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], - None]]): + def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], None]]): self.custom_passes = list(passes) compile_config = get_current_vllm_config().compilation_config self.inductor_config = compile_config.inductor_compile_config - self.inductor_config['force_disable_caches'] = True - self.inductor_config['post_grad_custom_post_pass'] = self.post_pass + self.inductor_config["force_disable_caches"] = True + self.inductor_config["post_grad_custom_post_pass"] = self.post_pass def __call__(self, graph: fx.GraphModule, example_inputs): self.graph_pre_compile = deepcopy(graph) from torch._inductor.compile_fx import compile_fx - return compile_fx(graph, - example_inputs, - config_patches=self.inductor_config) + + return compile_fx(graph, example_inputs, config_patches=self.inductor_config) @with_pattern_match_debug def post_pass(self, graph: fx.Graph): @@ -82,8 +79,7 @@ def check_before_ops(self, ops: Sequence[OpOverload], fully_replaced=True): assert num_pre > 0, f"Op {op.name()} not found in pre-pass graph" assert num_pre > num_post, f"All nodes remain for op {op.name()}" if fully_replaced: - assert num_post == 0, \ - f"Unexpected op {op.name()} in post-pass graph" + assert num_post == 0, f"Unexpected op {op.name()} in post-pass graph" def check_after_ops(self, ops: Sequence[OpOverload]): for op in ops: diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 9906e49bb110..927c838ae74e 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -38,8 +38,8 @@ def temporary_environ(env_vars): MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"] for mla_backend in MLA_backends: test_params_full_cudagraph.append( - pytest.param( - ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend]))) + pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])) + ) # Qwen/Qwen2-1.5B-Instruct with other backends other_backend_configs = [ @@ -47,7 +47,8 @@ def temporary_environ(env_vars): ] for backend_config in other_backend_configs: test_params_full_cudagraph.append( - pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config))) + pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)) + ) @pytest.fixture(scope="class") @@ -55,8 +56,10 @@ def llm_pair(request): model, backend_config = request.param # Dynamically skip test if GPU capability is not met - if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ - != current_platform.get_device_capability(): + if ( + backend_config.specific_gpu_arch + and backend_config.specific_gpu_arch != current_platform.get_device_capability() + ): if backend_config.specific_gpu_arch == (9, 0): pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") elif backend_config.specific_gpu_arch == (10, 0): @@ -76,8 +79,7 @@ def llm_pair(request): trust_remote_code=True, max_model_len=1024, max_num_seqs=128, - compilation_config=\ - CompilationConfig(**backend_config.comp_config), + compilation_config=CompilationConfig(**backend_config.comp_config), generation_config="vllm", seed=42, ) @@ -113,20 +115,22 @@ class TestFullCUDAGraph: meaning there would be multiple LLM instances hogging memory simultaneously. 
""" - @pytest.mark.parametrize(("batch_size", "max_tokens"), [ - (1, 10), - (7, 10), - (16, 10), - (25, 10), - (32, 10), - (45, 10), - (64, 10), - (123, 10), - (8, 5), - (8, 30), - ]) - def test_full_cudagraph(self, batch_size, max_tokens, - llm_pair: tuple[LLM, LLM]): + @pytest.mark.parametrize( + ("batch_size", "max_tokens"), + [ + (1, 10), + (7, 10), + (16, 10), + (25, 10), + (32, 10), + (45, 10), + (64, 10), + (123, 10), + (8, 5), + (8, 30), + ], + ) + def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]): """ Test various batch sizes and max_tokens to ensure that the full cudagraph compilation works for padded cases too. @@ -137,26 +141,34 @@ def test_full_cudagraph(self, batch_size, max_tokens, prompts = ["the quick brown fox"] * batch_size # Use purely greedy decoding to avoid top-p truncation sensitivity # that can amplify tiny numeric differences across runtimes. - sampling_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens, - top_p=1.0) + sampling_params = SamplingParams( + temperature=0.0, max_tokens=max_tokens, top_p=1.0 + ) piecewise_responses = piecewise_llm.generate(prompts, sampling_params) full_responses = full_cudagraph_llm.generate(prompts, sampling_params) # Check that all responses are the same - for piecewise_res, full_res in zip(piecewise_responses, - full_responses): - assert piecewise_res.outputs[0].text.lower() == \ - full_res.outputs[0].text.lower() + for piecewise_res, full_res in zip(piecewise_responses, full_responses): + assert ( + piecewise_res.outputs[0].text.lower() + == full_res.outputs[0].text.lower() + ) @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") def test_full_cudagraph_with_invalid_backend(): - with temporary_environ({ - "VLLM_USE_V1": "1", - "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION" - # Flex_Attention is not supported with full cuda graph - }), pytest.raises(RuntimeError): - LLM(model="Qwen/Qwen2-1.5B-Instruct", - compilation_config=CompilationConfig(cudagraph_mode="FULL")) + with ( + temporary_environ( + { + "VLLM_USE_V1": "1", + "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION", + # Flex_Attention is not supported with full cuda graph + } + ), + pytest.raises(RuntimeError), + ): + LLM( + model="Qwen/Qwen2-1.5B-Instruct", + compilation_config=CompilationConfig(cudagraph_mode="FULL"), + ) diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index 5cfebfce9ea2..7372dc99bc79 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -10,10 +10,14 @@ from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter -from vllm.compilation.decorators import (ignore_torch_compile, - support_torch_compile) -from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, - VllmConfig, set_current_vllm_config) +from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile +from vllm.config import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import BatchDescriptor, set_forward_context # This import automatically registers `torch.ops.silly.attention` @@ -27,12 +31,7 @@ @support_torch_compile class ParentModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() 
def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -40,7 +39,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Attention(nn.Module): - def __init__(self, mlp_size: int, hidden_size: int) -> None: super().__init__() self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False) @@ -51,17 +49,21 @@ def __init__(self, mlp_size: int, hidden_size: int) -> None: nn.init.xavier_normal_( self.pre_attn.weight.data, generator=torch.Generator().manual_seed(RANDOM_SEED), - gain=0.001) + gain=0.001, + ) nn.init.xavier_normal_( self.post_attn.weight.data, generator=torch.Generator().manual_seed(RANDOM_SEED), - gain=0.001) + gain=0.001, + ) def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor: x_f32 = x.float() - return (x_f32 * torch.rsqrt( - torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) * - self.rms_norm_weight).to(x.dtype) + return ( + x_f32 + * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) + * self.rms_norm_weight + ).to(x.dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.pre_attn(x) @@ -76,14 +78,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @support_torch_compile class CompiledAttention(nn.Module): - - def __init__(self, - *, - mlp_size: int, - hidden_size: int, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__( + self, + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ) -> None: super().__init__() self.attn = Attention(mlp_size, hidden_size) @@ -93,21 +96,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @support_torch_compile class CompiledAttentionTwo(CompiledAttention): - def forward(self, x: torch.Tensor) -> torch.Tensor: return self.attn(x) + x @ignore_torch_compile class SimpleModelWithTwoGraphs(ParentModel): - - def __init__(self, - *, - mlp_size: int, - hidden_size: int, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__( + self, + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) # Test will fail without set_model_tag here with error: # "ValueError: too many values to unpack (expected 3)" @@ -142,32 +145,45 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch.inference_mode -def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor, - cudagraph_runtime_mode: CUDAGraphMode): +def run_model( + vllm_config: VllmConfig, + model: nn.Module, + inputs: torch.Tensor, + cudagraph_runtime_mode: CUDAGraphMode, +): with set_forward_context({}, vllm_config=vllm_config): # warmup for the model with cudagraph_mode NONE model(inputs) # simulate cudagraphs capturing - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): model(inputs[:2]) - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=1, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, + ), + ): model(inputs[:1]) # simulate cudagraphs replay - with set_forward_context({}, - vllm_config=vllm_config, - 
cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): output = model(inputs[:2]) output = output.cpu() @@ -178,82 +194,104 @@ def test_multi_graph_piecewise_compile_outputs_equal(): outputs = [] # piecewise compile - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + ) + ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): - model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='').eval().cuda() + model = ( + SimpleModelWithTwoGraphs( + mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix="", + ) + .eval() + .cuda() + ) # Pre-allocate memory for CUDAGraph which expects # static tensor addresses inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() with compilation_counter.expect( - num_graphs_seen=2, # two graphs for the model - num_piecewise_graphs_seen=6, - # attn_one, attn_two each has 3 piecewise graphs - # (pre attn, post attn, silly_attention) each - num_piecewise_capturable_graphs_seen=4, - # attn_one, attn_two has pre attn and post attn each, total=4 - num_backend_compilations=4, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=2, # two graphs for the model + num_piecewise_graphs_seen=6, + # attn_one, attn_two each has 3 piecewise graphs + # (pre attn, post attn, silly_attention) each + num_piecewise_capturable_graphs_seen=4, + # attn_one, attn_two has pre attn and post attn each, total=4 + num_backend_compilations=4, # num_piecewise_capturable_graphs_seen + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): - outputs.append( - run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) + outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # no compile or cudagraph - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.NO_COMPILATION, )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.NO_COMPILATION, + ) + ) cudagraph_runtime_mode = CUDAGraphMode.NONE with set_current_vllm_config(vllm_config): - model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='').eval().cuda() + model = ( + SimpleModelWithTwoGraphs( + mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix="", + ) + .eval() + .cuda() + ) with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, ): - outputs.append( - run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) + outputs.append(run_model(vllm_config, model, inputs, 
cudagraph_runtime_mode)) # piecewise compile without CUDA graph - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=False, - splitting_ops=["silly.attention"], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=False, + splitting_ops=["silly.attention"], + ) + ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): - model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, - hidden_size=HIDDEN_SIZE, - vllm_config=vllm_config, - prefix='').eval().cuda() + model = ( + SimpleModelWithTwoGraphs( + mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix="", + ) + .eval() + .cuda() + ) with compilation_counter.expect( - num_graphs_seen=2, - num_piecewise_graphs_seen=6, - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, - num_cudagraph_captured=0, # no cudagraph captured + num_graphs_seen=2, + num_piecewise_graphs_seen=6, + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=0, # no cudagraph captured ): - outputs.append( - run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) + outputs.append(run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) # Generally don't expect outputs with and without inductor # to be bitwise equivalent diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 41055f431569..920cd5a06c26 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -11,8 +11,13 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, - VllmConfig, set_current_vllm_config) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + VllmConfig, + set_current_vllm_config, +) from vllm.envs import VLLM_USE_V1 from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import is_torch_equal_or_newer @@ -23,12 +28,7 @@ @support_torch_compile class SillyModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -60,53 +60,65 @@ def _run_simple_model( expected_num_backend_compilations, expected_num_cudagraph_captured, ): - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - use_inductor=use_inductor, - splitting_ops=splitting_ops, - use_inductor_graph_partition=use_inductor_graph_partition, - cudagraph_copy_inputs=True, - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + use_inductor=use_inductor, + splitting_ops=splitting_ops, + use_inductor_graph_partition=use_inductor_graph_partition, + cudagraph_copy_inputs=True, + cudagraph_capture_sizes=[1, 2], + ) + ) with set_current_vllm_config(vllm_config): - model = SillyModel(vllm_config=vllm_config, prefix='') + model = SillyModel(vllm_config=vllm_config, prefix="") inputs = torch.randn(100).cuda() - with compilation_counter.expect( + with ( + compilation_counter.expect( num_graphs_seen=1, # one graph for the 
model num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen, - num_piecewise_capturable_graphs_seen= - expected_num_piecewise_capturable_graphs_seen, + num_piecewise_capturable_graphs_seen=expected_num_piecewise_capturable_graphs_seen, num_backend_compilations=expected_num_backend_compilations, num_cudagraph_captured=expected_num_cudagraph_captured, - ), set_forward_context(None, - vllm_config=vllm_config): # background context + ), + set_forward_context(None, vllm_config=vllm_config), + ): # background context # warm up with background context model(inputs) # capturing/replaying should under context of cudagraph dispatching with set_forward_context( - None, - vllm_config=vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, - batch_descriptor=BatchDescriptor(num_tokens=2, )): + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): model(torch.randn(2).cuda()) with set_forward_context( - None, - vllm_config=vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, - batch_descriptor=BatchDescriptor(num_tokens=1, )): + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor( + num_tokens=1, + ), + ): model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() reset_global_counter() with set_forward_context( - None, - vllm_config=vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, - batch_descriptor=BatchDescriptor(num_tokens=2, )): + None, + vllm_config=vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): output = model(input) assert get_global_counter() == 2 assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0])) @@ -120,12 +132,14 @@ def test_simple_piecewise_compile(use_inductor): splitting_ops=["silly.attention"], use_inductor_graph_partition=False, use_inductor=use_inductor, - expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1 - expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers - expected_num_backend_compilations= - 3, # num_piecewise_capturable_graphs_seen - expected_num_cudagraph_captured= - 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + # 2 * num_layers + 1 + expected_num_piecewise_graphs_seen=5, + # 1 + num_layers + expected_num_piecewise_capturable_graphs_seen=3, + # num_piecewise_capturable_graphs_seen + expected_num_backend_compilations=3, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + expected_num_cudagraph_captured=6, ) @@ -134,22 +148,19 @@ def test_simple_piecewise_compile(use_inductor): def test_simple_inductor_graph_partition(splitting_ops): assert VLLM_USE_V1 if not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available " - "in PyTorch 2.9+") + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") _run_simple_model( - # inductor graph partition automatically resets splitting_ops - # to be an empty list + # Inductor graph partition automatically resets splitting_ops to an empty list splitting_ops=splitting_ops, use_inductor_graph_partition=True, use_inductor=True, - expected_num_piecewise_graphs_seen= - 1, # since not splitting at fx graph level - expected_num_piecewise_capturable_graphs_seen= - 1, # since not splitting at fx graph level - expected_num_backend_compilations= - 1, # since not splitting at fx graph level - expected_num_cudagraph_captured= - 6, # inductor graph partition still 
captures 6 - # graph, same as fx graph partition. + # Since not splitting at fx graph level + expected_num_piecewise_graphs_seen=1, + # Since not splitting at fx graph level + expected_num_piecewise_capturable_graphs_seen=1, + # Since not splitting at fx graph level + expected_num_backend_compilations=1, + # Inductor graph partition still captures 6 graph, same as fx graph partition + expected_num_cudagraph_captured=6, ) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index cba7517647e5..e053367fb3d7 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -8,6 +8,7 @@ if the config `tractable_init` is set to True. Otherwise, the weights are initialized randomly with a fixed seed. """ + from dataclasses import dataclass from typing import Any, Optional @@ -17,8 +18,13 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, - VllmConfig, set_current_vllm_config) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import BatchDescriptor, set_forward_context # This import automatically registers `torch.ops.silly.attention` @@ -43,15 +49,14 @@ def compute_hash(self) -> str: factors.append((k, v)) factors.sort() import hashlib - return hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + + return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() def __post_init__(self): assert self.mlp_size >= self.hidden_size class LlamaMLP(nn.Module): - def __init__(self, config: LlamaConfig) -> None: super().__init__() self.gate_up_projection = nn.Linear( @@ -66,31 +71,31 @@ def __init__(self, config: LlamaConfig) -> None: ) if config.tractable_init: - nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size]) - nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:]) + nn.init.eye_(self.gate_up_projection.weight.data[: config.mlp_size]) + nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size :]) nn.init.eye_(self.down_projection.weight.data) else: - nn.init.xavier_normal_(self.gate_up_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) - nn.init.xavier_normal_(self.down_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) + nn.init.xavier_normal_( + self.gate_up_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) + nn.init.xavier_normal_( + self.down_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) def forward(self, x): # for tractable_init and positive input, this is # essentially an elementwise-square x = self.gate_up_projection(x) - x = x[:, :x.size(1) // 2] * torch.nn.functional.relu( - x[:, x.size(1) // 2:]) + x = x[:, : x.size(1) // 2] * torch.nn.functional.relu(x[:, x.size(1) // 2 :]) x = self.down_projection(x) return x class LlamaAttention(nn.Module): - def __init__(self, config: LlamaConfig) -> None: super().__init__() self.qkv_projection = nn.Linear( @@ -106,21 +111,25 @@ def __init__(self, config: LlamaConfig) -> None: ) if config.tractable_init: - nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size]) - 
nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 * - config.hidden_size]) - nn.init.eye_(self.qkv_projection.weight.data[2 * - config.hidden_size:]) + nn.init.eye_(self.qkv_projection.weight.data[: config.hidden_size]) + nn.init.eye_( + self.qkv_projection.weight.data[ + config.hidden_size : 2 * config.hidden_size + ] + ) + nn.init.eye_(self.qkv_projection.weight.data[2 * config.hidden_size :]) nn.init.eye_(self.output_projection.weight.data) else: - nn.init.xavier_normal_(self.qkv_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) - nn.init.xavier_normal_(self.output_projection.weight.data, - generator=torch.Generator().manual_seed( - config.random_seed), - gain=0.001) + nn.init.xavier_normal_( + self.qkv_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) + nn.init.xavier_normal_( + self.output_projection.weight.data, + generator=torch.Generator().manual_seed(config.random_seed), + gain=0.001, + ) def forward( self, @@ -144,7 +153,6 @@ def forward( class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig) -> None: super().__init__() self.self_attention = LlamaAttention(config) @@ -164,7 +172,7 @@ def forward( - if residual is not None, the outputs are: - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3 - hidden_states = (residual + 1) ** 2 - """ # noqa + """ # noqa if residual is None: residual = hidden_states hidden_states = hidden_states + 1 @@ -173,8 +181,9 @@ def forward( residual = hidden_states hidden_states = hidden_states + 1 - hidden_states = self.self_attention(positions=positions, - hidden_states=hidden_states) + hidden_states = self.self_attention( + positions=positions, hidden_states=hidden_states + ) hidden_states = hidden_states + residual residual = hidden_states @@ -186,20 +195,22 @@ def forward( @support_torch_compile class LlamaModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - config: LlamaConfig, - prefix: str = '', - **kwargs) -> None: + def __init__( + self, + *, + vllm_config: VllmConfig, + config: LlamaConfig, + prefix: str = "", + **kwargs, + ) -> None: super().__init__() self.embedding_tokens = nn.Embedding( num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, ) self.layers = nn.ModuleList( - [LlamaDecoderLayer(config) for _ in range(config.num_layers)]) + [LlamaDecoderLayer(config) for _ in range(config.num_layers)] + ) # this is the initial value of the hidden states self.embedding_tokens.weight.data.fill_(config.init_value) @@ -216,34 +227,39 @@ def forward( return hidden_states -def tractable_computation(input_ids: torch.Tensor, - positions: torch.Tensor, - config: LlamaConfig, - init_value: float = 1.0) -> torch.Tensor: - hidden_states = torch.ones(input_ids.size(0), - config.hidden_size, - device=input_ids.device, - dtype=input_ids.dtype) * init_value +def tractable_computation( + input_ids: torch.Tensor, + positions: torch.Tensor, + config: LlamaConfig, + init_value: float = 1.0, +) -> torch.Tensor: + hidden_states = ( + torch.ones( + input_ids.size(0), + config.hidden_size, + device=input_ids.device, + dtype=input_ids.dtype, + ) + * init_value + ) # first layer residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 - hidden_states = (residual + 1)**2 + hidden_states = (residual + 1) ** 2 # following layers for _ in range(config.num_layers - 1): hidden_states = hidden_states + 
residual residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3 - hidden_states = (residual + 1)**2 + hidden_states = (residual + 1) ** 2 return hidden_states @torch.inference_mode -def run_model(llama_config, - use_compile: bool, - use_inductor: bool, - split_attn: bool = False) -> torch.Tensor: - +def run_model( + llama_config, use_compile: bool, use_inductor: bool, split_attn: bool = False +) -> torch.Tensor: if use_compile: compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, @@ -256,54 +272,66 @@ def run_model(llama_config, cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE else: compilation_config = CompilationConfig( - level=CompilationLevel.NO_COMPILATION, ) + level=CompilationLevel.NO_COMPILATION, + ) cudagraph_runtime_mode = CUDAGraphMode.NONE - vllm_config = VllmConfig(compilation_config=compilation_config, - additional_config=llama_config) + vllm_config = VllmConfig( + compilation_config=compilation_config, additional_config=llama_config + ) with set_current_vllm_config(vllm_config): - model = LlamaModel(config=llama_config, - vllm_config=vllm_config, - prefix="").eval().cuda() + model = ( + LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="") + .eval() + .cuda() + ) - with set_forward_context({}, - vllm_config=vllm_config): # background context + with set_forward_context({}, vllm_config=vllm_config): # background context B = 16 # max batch size - input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() + input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda() positions = torch.arange(B).cuda() # warmup for the model with cudagraph_mode NONE model(input_ids, positions) # simulate cudagraphs capturing - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): model(input_ids[:2], positions[:2]) - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=1, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, + ), + ): model(input_ids[:1], positions[:1]) input_ids[:2].zero_() # simulate cudagraphs replay - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): output = model(input_ids[:2], positions[:2]) output = output.cpu() if llama_config.tractable_init: - expected_output = tractable_computation(input_ids[:2], - positions[:2], - llama_config).cpu() + expected_output = tractable_computation( + input_ids[:2], positions[:2], llama_config + ).cpu() assert torch.allclose(output, expected_output) else: @@ -314,27 +342,23 @@ def run_model(llama_config, def test_toy_llama(use_inductor: bool): # compare output with and without piecewise compilation - llama_config = LlamaConfig(hidden_size=128, - mlp_size=256, - vocab_size=128, - num_layers=12) + llama_config = LlamaConfig( + hidden_size=128, mlp_size=256, vocab_size=128, num_layers=12 + ) - tractable_config = 
LlamaConfig(hidden_size=128, - mlp_size=256, - vocab_size=128, - num_layers=2, - tractable_init=True) + tractable_config = LlamaConfig( + hidden_size=128, mlp_size=256, vocab_size=128, num_layers=2, tractable_init=True + ) outputs = [] with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, ): - outputs.append( - run_model(llama_config, use_inductor=False, use_compile=False)) + outputs.append(run_model(llama_config, use_inductor=False, use_compile=False)) run_model(tractable_config, use_inductor=False, use_compile=False) if use_inductor: @@ -343,41 +367,44 @@ def test_toy_llama(use_inductor: bool): kwargs = {"num_eager_compiles": 1, "num_inductor_compiles": 0} with compilation_counter.expect( - num_graphs_seen=1, # one graph for the model - num_piecewise_graphs_seen=1, - num_piecewise_capturable_graphs_seen=1, - num_backend_compilations=1, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured= - 2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen - **kwargs, + # One graph for the model + num_graphs_seen=1, + num_piecewise_graphs_seen=1, + num_piecewise_capturable_graphs_seen=1, + # num_piecewise_capturable_graphs_seen + num_backend_compilations=1, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_cudagraph_captured=2, + **kwargs, ): outputs.append( - run_model(llama_config, - use_inductor=use_inductor, - use_compile=True)) + run_model(llama_config, use_inductor=use_inductor, use_compile=True) + ) run_model(tractable_config, use_inductor=use_inductor, use_compile=True) with compilation_counter.expect( - num_graphs_seen=1, # one graph for the model - num_piecewise_graphs_seen=2 * llama_config.num_layers + - 1, # 2 * num_layers + 1 - num_piecewise_capturable_graphs_seen=1 + - llama_config.num_layers, # 1 + num_layers - num_backend_compilations=1 + - llama_config.num_layers, # num_piecewise_capturable_graphs_seen - num_cudagraph_captured=2 * - (1 + llama_config.num_layers - ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, # one graph for the model + num_piecewise_graphs_seen=2 * llama_config.num_layers + 1, # 2 * num_layers + 1 + num_piecewise_capturable_graphs_seen=1 + + llama_config.num_layers, # 1 + num_layers + num_backend_compilations=1 + + llama_config.num_layers, # num_piecewise_capturable_graphs_seen + num_cudagraph_captured=2 + * ( + 1 + llama_config.num_layers + ), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): outputs.append( - run_model(llama_config, - use_inductor=use_inductor, - use_compile=True, - split_attn=True)) - run_model(tractable_config, - use_inductor=use_inductor, - use_compile=True, - split_attn=True) + run_model( + llama_config, + use_inductor=use_inductor, + use_compile=True, + split_attn=True, + ) + ) + run_model( + tractable_config, use_inductor=use_inductor, use_compile=True, split_attn=True + ) for i in range(1, len(outputs)): assert torch.allclose(outputs[0], outputs[i]) @@ -388,17 +415,15 @@ def benchmark(): from triton.testing import do_bench # similar to llama 3.1-8B - llama_config = LlamaConfig(hidden_size=4096, - mlp_size=14336, - vocab_size=128 * 1024, - num_layers=32) + llama_config = LlamaConfig( + hidden_size=4096, mlp_size=14336, vocab_size=128 * 1024, num_layers=32 + ) # a tiny model 
to measure the overhead # of piecewise cudagraph - llama_config = LlamaConfig(hidden_size=40, - mlp_size=80, - vocab_size=128, - num_layers=2) + llama_config = LlamaConfig( + hidden_size=40, mlp_size=80, vocab_size=128, num_layers=2 + ) cudagraph_sizes = [1, 2, 4] + [i * 8 for i in range(1, 33)] @@ -424,12 +449,15 @@ def benchmark(): vllm_config = VllmConfig(compilation_config=compilation_config) with set_current_vllm_config(vllm_config): - model = LlamaModel(config=llama_config, - vllm_config=vllm_config, - prefix="").eval().cuda().to(torch.bfloat16) + model = ( + LlamaModel(config=llama_config, vllm_config=vllm_config, prefix="") + .eval() + .cuda() + .to(torch.bfloat16) + ) B = 256 # max batch size - input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() + input_ids = torch.randint(0, llama_config.vocab_size, (B,)).cuda() positions = torch.arange(B).cuda().to(torch.bfloat16) graphs = {} @@ -451,21 +479,26 @@ def benchmark(): # and use it later, because it will look up the name `b` in the # enclosing scope, and the value of `b` will always be 256. # it is fine here, because we only use the lambda function once. - runtime = do_bench(lambda: graphs[b][0] # noqa - (input_ids[:b], positions[:b])) # noqa + runtime = do_bench( + lambda: graphs[b][0]( # noqa + input_ids[:b], # noqa + positions[:b], # noqa + ) + ) piecewise_cudagraph_time[b] = runtime else: runtime = do_bench(lambda: graphs[b][0].replay()) # noqa - eager_runtime = do_bench( - lambda: model(input_ids[:b], positions[:b])) # noqa + eager_runtime = do_bench(lambda: model(input_ids[:b], positions[:b])) # noqa full_cudagraph_time[b] = runtime eager_time[b] = eager_runtime # print in tabular format print("batch size\teager mode\tfull cudagraph\tpiecewise cudagraph") for b in cudagraph_sizes: - print(f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}" - f"\t{piecewise_cudagraph_time[b]:.3f}") + print( + f"{b}\t{eager_time[b]:.3f}\t{full_cudagraph_time[b]:.3f}" + f"\t{piecewise_cudagraph_time[b]:.3f}" + ) if __name__ == "__main__": diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py index baedafbae99f..c0d3f908149f 100644 --- a/tests/compile/silly_attention.py +++ b/tests/compile/silly_attention.py @@ -31,8 +31,9 @@ def reset_global_counter(): _global_counter = 0 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: """ Unified attention implementation that depends on all inputs and affects the output. 
@@ -47,8 +48,9 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out.copy_(q + k + v) -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor +) -> None: """Fake implementation for testing""" return @@ -60,5 +62,5 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mutates_args=["out"], fake_impl=silly_attention_fake, target_lib=silly_lib, - tags=(torch._C.Tag.cudagraph_unsafe, ), + tags=(torch._C.Tag.cudagraph_unsafe,), ) diff --git a/tests/compile/test_async_tp.py b/tests/compile/test_async_tp.py index 1dc21365d557..03cd510eb5d0 100644 --- a/tests/compile/test_async_tp.py +++ b/tests/compile/test_async_tp.py @@ -8,18 +8,30 @@ import vllm.envs as envs from vllm.compilation.collective_fusion import AsyncTPPass -from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - PassConfig, VllmConfig) -from vllm.distributed import (tensor_model_parallel_all_gather, - tensor_model_parallel_reduce_scatter) -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.config import ( + CompilationConfig, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, +) +from vllm.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_reduce_scatter, +) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.platforms import current_platform from vllm.utils import update_environment_variables from ..models.registry import HF_EXAMPLE_MODELS -from ..utils import (compare_two_settings, create_new_process_for_each_test, - multi_gpu_test) +from ..utils import ( + compare_two_settings, + create_new_process_for_each_test, + multi_gpu_test, +) from .backend import TestBackend FP8_DTYPE = current_platform.fp8_dtype() @@ -33,21 +45,20 @@ class TestMMRSModel(torch.nn.Module): - def __init__(self, hidden_size=16, dtype=torch.float16): super().__init__() self.hidden_size = hidden_size self.dtype = dtype - self.gate_proj = torch.nn.Parameter(torch.empty( - (self.hidden_size * 2, hidden_size)), - requires_grad=False) + self.gate_proj = torch.nn.Parameter( + torch.empty((self.hidden_size * 2, hidden_size)), requires_grad=False + ) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) def forward(self, hidden_states): """ Forward pass implementing the mm + reduce scatter in the FX graph - + """ # Reshape input view = hidden_states.reshape(-1, self.hidden_size) @@ -66,14 +77,13 @@ def ops_in_model_after(self): class TestAGMMModel(torch.nn.Module): - def __init__(self, hidden_size=16, dtype=torch.float16): super().__init__() self.hidden_size = hidden_size self.dtype = dtype - self.weight = torch.nn.Parameter(torch.empty( - (hidden_size, hidden_size)), - requires_grad=False) + self.weight = torch.nn.Parameter( + torch.empty((hidden_size, hidden_size)), requires_grad=False + ) # Initialize weights torch.nn.init.normal_(self.weight, std=0.02) @@ -96,32 +106,35 @@ def ops_in_model_after(self): class _BaseScaledMMModel(torch.nn.Module): - def __init__(self, hidden_size=16, dtype=torch.float16): super().__init__() self.hidden_size = hidden_size self.dtype = dtype - self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\ - .contiguous().transpose(0, 1) + self.weight = ( + torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE) + .contiguous() + 
.transpose(0, 1) + ) # Initialize scale_b for _scaled_mm. self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32) class TestScaledMMRSModel(_BaseScaledMMModel): - def forward(self, input: torch.Tensor): """ Forward pass implementing the scaled_mm + reduce scatter in the FX graph - + """ fp8_input = input.to(FP8_DTYPE) scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) - scaled_mm = torch._scaled_mm(fp8_input, - self.weight, - scale_a=scale_a, - scale_b=self.scale_b, - out_dtype=self.dtype) + scaled_mm = torch._scaled_mm( + fp8_input, + self.weight, + scale_a=scale_a, + scale_b=self.scale_b, + out_dtype=self.dtype, + ) reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0) return reduce_scatter @@ -133,7 +146,6 @@ def ops_in_model_after(self): class TestAGScaledMMModel(_BaseScaledMMModel): - def forward(self, input: torch.Tensor): """ Forward pass implementing the all gather + scaled_mm in the FX graph @@ -143,11 +155,13 @@ def forward(self, input: torch.Tensor): all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0) scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) - scaled_mm = torch._scaled_mm(all_gather, - self.weight, - scale_a=scale_a, - scale_b=self.scale_b, - out_dtype=self.dtype) + scaled_mm = torch._scaled_mm( + all_gather, + self.weight, + scale_a=scale_a, + scale_b=self.scale_b, + out_dtype=self.dtype, + ) return scaled_mm def ops_in_model_before(self): @@ -158,20 +172,22 @@ def ops_in_model_after(self): class TestCutlassScaledMMRSModel(_BaseScaledMMModel): - def forward(self, input: torch.Tensor): """ Forward pass implementing the cutlass_scaled_mm + reduce scatter in the FX graph - + """ fp8_input = input.to(FP8_DTYPE) scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32) - mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]), - dtype=self.dtype, - device=input.device) - torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a, - self.scale_b, None) + mm_out = torch.empty( + (fp8_input.shape[0], self.weight.shape[1]), + dtype=self.dtype, + device=input.device, + ) + torch.ops._C.cutlass_scaled_mm( + mm_out, fp8_input, self.weight, scale_a, self.scale_b, None + ) reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0) return reduce_scatter @@ -183,10 +199,9 @@ def ops_in_model_after(self): class TestAGCutlassScaledMMModel(_BaseScaledMMModel): - def forward(self, input: torch.Tensor): """ - Forward pass implementing the all gather + cutlass_scaled_mm + Forward pass implementing the all gather + cutlass_scaled_mm in the FX graph """ # Reshape input @@ -195,11 +210,14 @@ def forward(self, input: torch.Tensor): scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32) - mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]), - dtype=self.dtype, - device=all_gather.device) - torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight, - scale_a, self.scale_b, None) + mm_out = torch.empty( + (all_gather.shape[0], self.weight.shape[1]), + dtype=self.dtype, + device=all_gather.device, + ) + torch.ops._C.cutlass_scaled_mm( + mm_out, all_gather, self.weight, scale_a, self.scale_b, None + ) return mm_out def ops_in_model_before(self): @@ -210,23 +228,37 @@ def ops_in_model_after(self): @multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("test_model", [ - TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel, - TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel -]) +@pytest.mark.parametrize( + "test_model", + [ + TestMMRSModel, + 
TestAGMMModel, + TestScaledMMRSModel, + TestAGScaledMMModel, + TestCutlassScaledMMRSModel, + TestAGCutlassScaledMMModel, + ], +) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [16]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") -def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): - if test_model in (TestScaledMMRSModel, TestAGScaledMMModel, - TestCutlassScaledMMRSModel, - TestAGCutlassScaledMMModel) and dtype == torch.float16: +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_async_tp_pass_replace( + test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype +): + if ( + test_model + in ( + TestScaledMMRSModel, + TestAGScaledMMModel, + TestCutlassScaledMMRSModel, + TestAGCutlassScaledMMModel, + ) + and dtype == torch.float16 + ): pytest.skip( - "Only bf16 high precision output types are supported for " \ + "Only bf16 high precision output types are supported for " "per-token (row-wise) scaling" ) @@ -235,19 +267,24 @@ def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int, def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda - torch.multiprocessing.spawn(fn, - args=(num_processes, test_model, - batch_size, seq_len, hidden_size, - dtype), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + nprocs=nprocs, + ) run_torch_spawn(async_tp_pass_on_test_model, num_processes) -def async_tp_pass_on_test_model(local_rank: int, world_size: int, - test_model_cls: torch.nn.Module, - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +def async_tp_pass_on_test_model( + local_rank: int, + world_size: int, + test_model_cls: torch.nn.Module, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, +): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -255,13 +292,15 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # initialize distributed init_distributed_environment() @@ -269,27 +308,28 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( - enable_async_tp=True, ), ) + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig( + enable_async_tp=True, + ), + ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. 
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model_name, - trust_remote_code=True, - dtype=dtype, - seed=42) + vllm_config.model_config = ModelConfig( + model=model_name, trust_remote_code=True, dtype=dtype, seed=42 + ) async_tp_pass = AsyncTPPass(vllm_config) backend = TestBackend(async_tp_pass) - model = test_model_cls(hidden_size, - dtype) # Pass dtype to model constructor + model = test_model_cls(hidden_size, dtype) # Pass dtype to model constructor - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - dtype=dtype, - requires_grad=False) + hidden_states = torch.randn( + (batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False + ) compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states) @@ -306,10 +346,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, @create_new_process_for_each_test() -@pytest.mark.parametrize("model_id", [ - "meta-llama/Llama-3.2-1B-Instruct", - "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8" -]) +@pytest.mark.parametrize( + "model_id", + ["meta-llama/Llama-3.2-1B-Instruct", "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"], +) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("async_tp_enabled", [True]) @pytest.mark.parametrize("distributed_backend", ["mp"]) @@ -342,12 +382,10 @@ def test_async_tp_pass_correctness( common_args.append("--enforce-eager") compilation_config = { - 'level': 3, - 'compile_sizes': [2, 4, 8], - 'splitting_ops': [], - 'pass_config': { - 'enable_async_tp': async_tp_enabled - }, + "level": 3, + "compile_sizes": [2, 4, 8], + "splitting_ops": [], + "pass_config": {"enable_async_tp": async_tp_enabled}, } async_tp_env = tp_env = { @@ -372,9 +410,6 @@ def test_async_tp_pass_correctness( "mp", ] - compare_two_settings(model_id, - async_tp_args, - tp_args, - async_tp_env, - tp_env, - method="generate") + compare_two_settings( + model_id, async_tp_args, tp_args, async_tp_env, tp_env, method="generate" + ) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index a1e5127ebeeb..4bcefb30b2e6 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -103,23 +103,28 @@ def test_compile_correctness( attn_backend = test_setting.attn_backend method = test_setting.method if cuda_device_count_stateless() < pp_size * tp_size: - pytest.skip(f"Need at least {pp_size}*{tp_size} CUDA gpus but got " - f"{cuda_device_count_stateless()}") + pytest.skip( + f"Need at least {pp_size}*{tp_size} CUDA gpus but got " + f"{cuda_device_count_stateless()}" + ) with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) final_args = [ - "--enforce-eager", *model_args, "-pp", - str(pp_size), "-tp", - str(tp_size) + "--enforce-eager", + *model_args, + "-pp", + str(pp_size), + "-tp", + str(tp_size), ] all_args: list[list[str]] = [] all_envs: list[dict[str, str] | None] = [] for level in [ - CompilationLevel.NO_COMPILATION, - CompilationLevel.PIECEWISE, + CompilationLevel.NO_COMPILATION, + CompilationLevel.PIECEWISE, ]: all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) @@ -130,14 +135,15 @@ def test_compile_correctness( model, all_args, all_envs, - method=method if method != "generate" else "generate_close") + method=method if method != "generate" else "generate_close", + ) all_envs.clear() all_args.clear() for level in [ - CompilationLevel.NO_COMPILATION, - CompilationLevel.DYNAMO_AS_IS, - CompilationLevel.DYNAMO_ONCE, + 
CompilationLevel.NO_COMPILATION, + CompilationLevel.DYNAMO_AS_IS, + CompilationLevel.DYNAMO_ONCE, ]: all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 17d3f0b37768..d055a41af4c4 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -9,11 +9,11 @@ def test_version(): - assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev') - assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev') - assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev') - assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev') - assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev') + assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev") + assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev") + assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev") + assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev") + assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev") def test_use_cudagraphs_dynamic(monkeypatch): @@ -21,7 +21,7 @@ def test_use_cudagraphs_dynamic(monkeypatch): vllm_config = VllmConfig() assert vllm_config.compilation_config.use_cudagraph - monkeypatch.setenv('VLLM_USE_V1', '0') + monkeypatch.setenv("VLLM_USE_V1", "0") vllm_config = VllmConfig() assert not vllm_config.compilation_config.use_cudagraph @@ -44,19 +44,23 @@ def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val): assert vllm.envs.VLLM_USE_V1 # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') - monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val) + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val) compilation_config = { "use_cudagraph": False, # speed things up a bit } with ( - compilation_counter.expect(num_cache_entries_updated=0, - num_compiled_artifacts_saved=0), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config=compilation_config, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect( + num_cache_entries_updated=0, num_compiled_artifacts_saved=0 + ), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config=compilation_config, + gpu_memory_utilization=0.4, + ) as _, + ): pass @@ -67,22 +71,25 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): assert vllm.envs.VLLM_USE_V1 # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") compilation_config = { "cudagraph_capture_sizes": [100], "use_cudagraph": enabled, } with ( - compilation_counter.expect( - num_graphs_seen=1, - num_gpu_runner_capture_triggers=1 if enabled else 0, - num_cudagraph_captured=13 if enabled else 0, - ), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config=compilation_config, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect( + num_graphs_seen=1, + num_gpu_runner_capture_triggers=1 if enabled else 0, + num_cudagraph_captured=13 if enabled else 0, + ), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config=compilation_config, + gpu_memory_utilization=0.4, + ) as _, + ): pass @@ -90,14 +97,17 @@ def 
test_use_cudagraphs(vllm_runner, monkeypatch, enabled): @pytest.mark.forked def test_dynamo_as_is(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(dynamo_as_is_count=1), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config={"level": 1}, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect(dynamo_as_is_count=1), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config={"level": 1}, + gpu_memory_utilization=0.4, + ) as _, + ): pass @@ -105,14 +115,16 @@ def test_dynamo_as_is(vllm_runner, monkeypatch): @pytest.mark.forked def test_no_compilation(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(num_graphs_seen=0, - dynamo_as_is_count=0), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - compilation_config={"level": 0}, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", + compilation_config={"level": 0}, + gpu_memory_utilization=0.4, + ) as _, + ): pass @@ -120,77 +132,73 @@ def test_no_compilation(vllm_runner, monkeypatch): @pytest.mark.forked def test_enforce_eager(vllm_runner, monkeypatch): # Disable multiprocessing so that the counter is in the same process - monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with ( - compilation_counter.expect(num_graphs_seen=0, - dynamo_as_is_count=0), - # loading the model causes compilation (if enabled) to happen - vllm_runner('facebook/opt-125m', - enforce_eager=True, - gpu_memory_utilization=0.4) as _): + compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0), + # loading the model causes compilation (if enabled) to happen + vllm_runner( + "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4 + ) as _, + ): pass def test_splitting_ops_dynamic(): # Default config config = VllmConfig() - assert config.compilation_config.cudagraph_mode == \ - CUDAGraphMode.FULL_AND_PIECEWISE + assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE assert config.compilation_config.splitting_ops_contain_attention() # When use_inductor_graph_partition=True - if _is_torch_equal_or_newer('2.9.0.dev'): + if _is_torch_equal_or_newer("2.9.0.dev"): # inductor graph partition is only available in PyTorch 2.9+. # this is a fast config check so we are not using pytest.skip. - config = VllmConfig(compilation_config=CompilationConfig( - use_inductor_graph_partition=True, - splitting_ops=["silly_attention"])) + config = VllmConfig( + compilation_config=CompilationConfig( + use_inductor_graph_partition=True, splitting_ops=["silly_attention"] + ) + ) # should ignore splitting_ops assert config.compilation_config.splitting_ops == [] # When attn_fusion pass enabled. 
- config = VllmConfig(compilation_config=CompilationConfig( - pass_config={ - "enable_attn_fusion": True, - "enable_noop": True - }, - custom_ops=["+quant_fp8"], - cudagraph_mode=CUDAGraphMode.PIECEWISE, - )) + config = VllmConfig( + compilation_config=CompilationConfig( + pass_config={"enable_attn_fusion": True, "enable_noop": True}, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + ) + ) assert config.compilation_config.splitting_ops == [] # cudagraph mode also fall back to FULL - assert config.compilation_config.cudagraph_mode == \ - CUDAGraphMode.FULL + assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL # splitting_ops can not contain attention ops when attn_fusion # pass enabled. with pytest.raises(AssertionError): - config = VllmConfig(compilation_config=CompilationConfig( - pass_config={ - "enable_attn_fusion": True, - "enable_noop": True - }, - custom_ops=["+quant_fp8"], - cudagraph_mode=CUDAGraphMode.PIECEWISE, - # work around for accessing all attntion ops - splitting_ops=CompilationConfig()._attention_ops, - )) + config = VllmConfig( + compilation_config=CompilationConfig( + pass_config={"enable_attn_fusion": True, "enable_noop": True}, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + # work around for accessing all attntion ops + splitting_ops=CompilationConfig()._attention_ops, + ) + ) # When both use_inductor_graph_partition and attn_fusion pass enabled. - if _is_torch_equal_or_newer('2.9.0.dev'): - config = VllmConfig(compilation_config=CompilationConfig( - use_inductor_graph_partition=True, - pass_config={ - "enable_attn_fusion": True, - "enable_noop": True - }, - custom_ops=["+quant_fp8"], - cudagraph_mode=CUDAGraphMode.PIECEWISE, - )) + if _is_torch_equal_or_newer("2.9.0.dev"): + config = VllmConfig( + compilation_config=CompilationConfig( + use_inductor_graph_partition=True, + pass_config={"enable_attn_fusion": True, "enable_noop": True}, + custom_ops=["+quant_fp8"], + cudagraph_mode=CUDAGraphMode.PIECEWISE, + ) + ) assert config.compilation_config.splitting_ops == [] # enable_attn_fusion is directly support under # use_inductor_graph_partition=True, and cudagraph_mode # is unchanged. 
- assert config.compilation_config.cudagraph_mode == \ - CUDAGraphMode.PIECEWISE + assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index d73586d53ff3..d7048821bb60 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -4,10 +4,15 @@ from torch import nn from vllm.compilation.counter import compilation_counter -from vllm.compilation.decorators import (ignore_torch_compile, - support_torch_compile) -from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, - CUDAGraphMode, VllmConfig, set_current_vllm_config) +from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile +from vllm.config import ( + CacheConfig, + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import BatchDescriptor, set_forward_context # This import automatically registers `torch.ops.silly.attention` @@ -18,32 +23,42 @@ @torch.inference_mode -def run_model(vllm_config: VllmConfig, model: nn.Module, - cudagraph_runtime_mode: CUDAGraphMode): +def run_model( + vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode +): with set_forward_context({}, vllm_config=vllm_config): # warmup for the model with cudagraph_mode NONE model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) # simulate cudagraphs capturing - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): model(torch.randn(2, MLP_SIZE).cuda()) - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=1, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=1, + ), + ): model(torch.randn(1, MLP_SIZE).cuda()) # simulate cudagraphs replay - with set_forward_context({}, - vllm_config=vllm_config, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=BatchDescriptor( - num_tokens=2, )): + with set_forward_context( + {}, + vllm_config=vllm_config, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=BatchDescriptor( + num_tokens=2, + ), + ): output = model(torch.randn(2, MLP_SIZE).cuda()) output = output.cpu() @@ -52,22 +67,21 @@ def run_model(vllm_config: VllmConfig, model: nn.Module, def test_ignore_torch_compile_decorator(): # piecewise - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + ) + ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE @support_torch_compile class A(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs + ) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -79,66 +93,60 @@ def 
forward(self, x: torch.Tensor) -> torch.Tensor: return x @ignore_torch_compile - class B(A): - ... + class B(A): ... @support_torch_compile - class C(B): - ... + class C(B): ... with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() # A has support_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) with set_current_vllm_config(vllm_config): - mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() + mod_B = B(vllm_config=vllm_config, prefix="").eval().cuda() # B's ignore_torch_compile should override A's support_torch_compile with compilation_counter.expect( - num_graphs_seen=0, - num_piecewise_graphs_seen=0, - num_piecewise_capturable_graphs_seen=0, - num_backend_compilations=0, - num_cudagraph_captured=0, + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, ): run_model(vllm_config, mod_B, cudagraph_runtime_mode) with set_current_vllm_config(vllm_config): - mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() + mod_C = C(vllm_config=vllm_config, prefix="").eval().cuda() # C's support_torch_compile should override B's ignore_torch_compile with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=3, - num_piecewise_capturable_graphs_seen=2, - num_backend_compilations=2, - num_cudagraph_captured=4, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, mod_C, cudagraph_runtime_mode) # Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=True -@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. - kv_sharing_fast_prefill) +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) class B(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -152,15 +160,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Only enable torch.compile if # vllm_config.cache_config.kv_sharing_fast_prefill=False -@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. 
- cache_config.kv_sharing_fast_prefill) +@support_torch_compile( + enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill +) class A(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = '', - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs) @@ -175,54 +179,60 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_conditional_compile_enable_if(): - vllm_config = VllmConfig(cache_config=CacheConfig( - kv_sharing_fast_prefill=True, ), - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + cache_config=CacheConfig( + kv_sharing_fast_prefill=True, + ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + ), + ) cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() # A has support_torch_compile but enable_if fn returns False # enalbe_if will be True for B, so we expect mod1 and mod2 # to be compiled with compilation_counter.expect( - num_graphs_seen=2, - num_piecewise_graphs_seen=6, - # 3 piecewise graphs per instance of B() - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, - num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=2, + num_piecewise_graphs_seen=6, + # 3 piecewise graphs per instance of B() + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) # Set kv_sharing_fast_prefill=False # which will cause A to be compiled and B to not be compiled - vllm_config = VllmConfig(cache_config=CacheConfig( - kv_sharing_fast_prefill=False, ), - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - splitting_ops=["silly.attention"], - cudagraph_capture_sizes=[1, 2], - )) + vllm_config = VllmConfig( + cache_config=CacheConfig( + kv_sharing_fast_prefill=False, + ), + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + ), + ) with set_current_vllm_config(vllm_config): - mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + mod_A = A(vllm_config=vllm_config, prefix="").eval().cuda() with compilation_counter.expect( - num_graphs_seen=1, - num_piecewise_graphs_seen=7, - # 3 attn ops and 4 non-attn ops - num_piecewise_capturable_graphs_seen=4, - num_backend_compilations=4, - num_cudagraph_captured=8, - # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + num_graphs_seen=1, + num_piecewise_graphs_seen=7, + # 3 attn ops and 4 non-attn ops + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): run_model(vllm_config, mod_A, cudagraph_runtime_mode) diff --git a/tests/compile/test_full_graph.py 
b/tests/compile/test_full_graph.py index 3ecda1a8ec33..8ccae4cfb9df 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -5,7 +5,7 @@ import logging import tempfile -from typing import Any, Optional, Union +from typing import Any, Union import pytest import torch @@ -14,54 +14,64 @@ from vllm import LLM, SamplingParams from vllm.attention.backends.registry import _Backend from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, - PassConfig) +from vllm.config import CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer from ..utils import create_new_process_for_each_test -def models_list(*, all: bool = True, keywords: Optional[list[str]] = None): +def models_list(*, all: bool = True, keywords: list[str] | None = None): TEST_MODELS: list[tuple[str, dict[str, Any]]] = [ ("facebook/opt-125m", {}), - ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { - "dtype": torch.float16, - }), - ("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", { - "dtype": torch.float16, - }), + ( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + { + "dtype": torch.float16, + }, + ), + ( + "neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", + { + "dtype": torch.float16, + }, + ), ("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}), ("meta-llama/Llama-3.2-1B-Instruct", {}), ] if all: - # TODO: figure out why this fails. if False and is_quant_method_supported("gguf"): # noqa: SIM223 - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", { - "quantization": "gguf" - })) + TEST_MODELS.append( + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {"quantization": "gguf"}) + ) if is_quant_method_supported("gptq"): - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", { - "quantization": "gptq" - })) + TEST_MODELS.append( + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {"quantization": "gptq"}) + ) if is_quant_method_supported("gptq_marlin"): - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", { - "quantization": "gptq_marlin" - })) + TEST_MODELS.append( + ( + "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", + {"quantization": "gptq_marlin"}, + ) + ) if is_quant_method_supported("gptq_marlin_24"): - TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", { - "quantization": "gptq_marlin_24" - })) + TEST_MODELS.append( + ( + "alexm-nm/tinyllama-24-marlin24-4bit-g128", + {"quantization": "gptq_marlin_24"}, + ) + ) if not current_platform.is_rocm() and is_quant_method_supported("awq"): - TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { - "quantization": "AWQ" - })) + TEST_MODELS.append( + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {"quantization": "AWQ"}) + ) if keywords is None: return TEST_MODELS @@ -95,22 +105,34 @@ def test_full_graph( "compilation_config, model_info", [ # additional compile sizes, only some of the models - (CompilationConfig(level=CompilationLevel.PIECEWISE, - compile_sizes=[1, 2]), model) + ( + CompilationConfig(level=CompilationLevel.PIECEWISE, compile_sizes=[1, 2]), + model, + ) for model in models_list(all=False) - ] + [ + ] + + [ # RMSNorm + quant fusion, only 8-bit quant models - (CompilationConfig(level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm"], - pass_config=PassConfig(enable_fusion=True, - enable_noop=True)), model) + ( + CompilationConfig( + level=CompilationLevel.PIECEWISE, + 
custom_ops=["+rms_norm"], + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ), + model, + ) for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) - ] + [ + ] + + [ # Test depyf integration works - (CompilationConfig(level=CompilationLevel.PIECEWISE, - debug_dump_path=tempfile.gettempdir()), - ("facebook/opt-125m", {})), - ] + [ + ( + CompilationConfig( + level=CompilationLevel.PIECEWISE, debug_dump_path=tempfile.gettempdir() + ), + ("facebook/opt-125m", {}), + ), + ] + + [ # graph inductor partition ( CompilationConfig( @@ -119,20 +141,24 @@ def test_full_graph( # torch._C.Tag.cudagraph_unsafe to specify splitting ops use_inductor_graph_partition=True, cudagraph_mode=CUDAGraphMode.PIECEWISE, - compile_sizes=[1, 2]), - model) for model in models_list(all=False) + compile_sizes=[1, 2], + ), + model, + ) + for model in models_list(all=False) if is_torch_equal_or_newer("2.9.0.dev") - ]) + ], +) # only test some of the models @create_new_process_for_each_test() def test_custom_compile_config( compilation_config: CompilationConfig, model_info: tuple[str, dict[str, Any]], ): - if (compilation_config.use_inductor_graph_partition - and not is_torch_equal_or_newer("2.9.0.dev")): - pytest.skip("inductor graph partition is only available " - "in PyTorch 2.9+") + if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer( + "2.9.0.dev" + ): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") model, model_kwargs = model_info print(f"MODEL={model}") @@ -156,8 +182,7 @@ def test_fp8_kv_scale_compile(optimization_level: int): def test_inductor_graph_partition_attn_fusion(caplog_vllm): if not is_torch_equal_or_newer("2.9.0.dev"): - pytest.skip("inductor graph partition is only available " - "in PyTorch 2.9+") + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" compilation_config = CompilationConfig( @@ -171,14 +196,16 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm): "kv_cache_dtype": "fp8", "max_model_len": 1024, } - with caplog_vllm.at_level( - logging.DEBUG), global_force_attn_backend_context_manager( - _Backend.FLASHINFER): + with ( + caplog_vllm.at_level(logging.DEBUG), + global_force_attn_backend_context_manager(_Backend.FLASHINFER), + ): run_model(compilation_config, model, model_kwargs) try: - assert ("Fused quantization onto 48 attention nodes" - in caplog_vllm.text), caplog_vllm.text + assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, ( + caplog_vllm.text + ) except AssertionError: # Note: this message is only triggered when the compilation goes # through the custom pass. 
Due to multiple layers of cache on @@ -189,8 +216,11 @@ def test_inductor_graph_partition_attn_fusion(caplog_vllm): assert "Fused quantization" not in caplog_vllm.text -def run_model(compile_config: Union[int, CompilationConfig], model: str, - model_kwargs: dict[str, Any]): +def run_model( + compile_config: Union[int, CompilationConfig], + model: str, + model_kwargs: dict[str, Any], +): prompts = [ "Hello, my name is", "The president of the United States is", diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 0c8d610bc9c5..ae17bc67b1fb 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -14,10 +14,8 @@ from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform @@ -28,7 +26,6 @@ class TestSiluMul(torch.nn.Module): - def __init__(self, hidden_size: int = 128): super().__init__() self.silu_and_mul = SiluAndMul() @@ -36,8 +33,7 @@ def __init__(self, hidden_size: int = 128): self.scale = torch.rand(1, dtype=torch.float32) if TEST_FP8: - self.w = torch.rand(hidden_size, - hidden_size).to(dtype=FP8_DTYPE).t() + self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.fp8_linear = Fp8LinearOp( act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR, @@ -46,17 +42,14 @@ def __init__(self, hidden_size: int = 128): def forward(self, x): y = self.silu_and_mul(x) if TEST_FP8: - x2 = self.fp8_linear.apply(y, - self.w, - self.wscale, - input_scale=self.wscale) + x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) return x2 else: return y def example_inputs(self, num_tokens=32, hidden_size=128): dtype = torch.float16 if TEST_FP8 else torch.float32 - return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), ) + return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),) def ops_in_model(self, do_fusion): if TEST_FP8 and do_fusion: @@ -69,7 +62,6 @@ def ops_not_in_model(self): class TestFusedAddRMSNorm(torch.nn.Module): - def __init__(self, hidden_size=16, intermediate_size=32): super().__init__() self.hidden_size = hidden_size @@ -78,10 +70,12 @@ def __init__(self, hidden_size=16, intermediate_size=32): dtype = torch.float16 if TEST_FP8 else torch.float32 self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size), dtype=dtype)) + torch.empty((intermediate_size, hidden_size), dtype=dtype) + ) self.norm = RMSNorm(intermediate_size, 1e-05) self.norm.weight = torch.nn.Parameter( - torch.ones(intermediate_size, dtype=dtype)) + torch.ones(intermediate_size, dtype=dtype) + ) torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -89,8 +83,7 @@ def __init__(self, hidden_size=16, intermediate_size=32): self.fp8_linear = Fp8LinearOp(act_quant_static=True) self.scale = torch.rand(1, dtype=torch.float32) - self.w = torch.rand(hidden_size, - intermediate_size).to(dtype=FP8_DTYPE).t() + self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() self.wscale = 
torch.rand(1, dtype=torch.float32) def forward(self, hidden_states, residual): @@ -120,10 +113,8 @@ def forward(self, hidden_states, residual): def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16): dtype = torch.float16 if TEST_FP8 else torch.float32 - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - dtype=dtype) - residual = torch.randn((batch_size * seq_len, hidden_size), - dtype=dtype) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) + residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) return (hidden_states, residual) def ops_in_model(self, do_fusion): @@ -137,12 +128,7 @@ def ops_not_in_model(self): class TestRotaryEmbedding(torch.nn.Module): - - def __init__(self, - head_dim=64, - rotary_dim=None, - max_position=2048, - base=10000): + def __init__(self, head_dim=64, rotary_dim=None, max_position=2048, base=10000): super().__init__() self.head_dim = head_dim self.rotary_dim = rotary_dim or head_dim @@ -173,21 +159,15 @@ def ops_not_in_model(self): class TestRotaryEmbeddingSliceScatter(torch.nn.Module): - - def __init__(self, - head_dim=64, - num_heads=4, - max_position=2048, - base=10000): + def __init__(self, head_dim=64, num_heads=4, max_position=2048, base=10000): super().__init__() self.head_dim = head_dim self.num_heads = num_heads self.hidden_size = head_dim * num_heads - self.qkv_proj = torch.nn.Linear(self.hidden_size, - self.hidden_size * 3, - bias=False, - dtype=torch.float16) + self.qkv_proj = torch.nn.Linear( + self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16 + ) self.rotary_emb = get_rope( self.head_dim, @@ -233,21 +213,24 @@ def ops_not_in_model(self): @pytest.mark.parametrize("model_class", MODELS) @pytest.mark.parametrize("do_fusion", [True, False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", - reason="Only test on CUDA") +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): torch.set_default_device("cuda") vllm_config = VllmConfig() vllm_config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True) + ) noop_pass = NoOpEliminationPass(vllm_config) fusion_pass = RMSNormQuantFusionPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] - if do_fusion else [noop_pass, cleanup_pass]) + passes = ( + [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass] + if do_fusion + else [noop_pass, cleanup_pass] + ) func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(*passes, func_pass) @@ -260,8 +243,7 @@ def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool): # check if the functionalization pass is applied for op in model.ops_in_model(do_fusion): find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) - is None) # noqa: E501 + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # make sure the ops were all de-functionalized found = dict() diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 3d8897d3f18b..7c2233643229 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -5,17 +5,26 @@ import torch import 
vllm.plugins -from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - RMSNormQuantFusionPass) +from vllm.compilation.fusion import ( + FUSED_OPS, + QUANT_OPS, + FusedRMSQuantKey, + RMSNormQuantFusionPass, +) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, - VllmConfig) +from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, QuantKey, ScaleDesc) + GroupShape, + QuantKey, + ScaleDesc, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity) + Fp8LinearOp, + cutlass_fp8_supported, + maybe_create_device_identity, +) from vllm.platforms import current_platform from ..utils import override_cutlass_fp8_supported @@ -25,9 +34,15 @@ class TestModel(torch.nn.Module): - - def __init__(self, hidden_size: int, eps: float, static: bool, - cuda_force_torch: bool, *args, **kwargs): + def __init__( + self, + hidden_size: int, + eps: float, + static: bool, + cuda_force_torch: bool, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.cuda_force_torch = cuda_force_torch self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] @@ -54,17 +69,15 @@ def forward(self, x): resid = torch.sqrt(x) y = self.norm[0](x) - x2 = self.fp8_linear.apply(y, - self.w[0], - self.wscale[0], - input_scale=self.scale[0]) + x2 = self.fp8_linear.apply( + y, self.w[0], self.wscale[0], input_scale=self.scale[0] + ) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = self.fp8_linear.apply(y2, - self.w[1], - self.wscale[1], - input_scale=self.scale[1]) + x3 = self.fp8_linear.apply( + y2, self.w[1], self.wscale[1], input_scale=self.scale[1] + ) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -74,7 +87,7 @@ def ops_in_model_before(self): def ops_in_model_after(self): return [ FUSED_OPS[FusedRMSQuantKey(self.key, False)], - FUSED_OPS[FusedRMSQuantKey(self.key, True)] + FUSED_OPS[FusedRMSQuantKey(self.key, True)], ] @@ -85,22 +98,27 @@ def ops_in_model_after(self): @pytest.mark.parametrize("static", [True, False]) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. 
-@pytest.mark.parametrize("cuda_force_torch", - [True, False] if cutlass_fp8_supported() else [True]) -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test on CUDA and ROCm") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, - cuda_force_torch): +@pytest.mark.parametrize( + "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" +) +def test_fusion_rmsnorm_quant( + dtype, hidden_size, num_tokens, eps, static, cuda_force_torch +): torch.set_default_device("cuda") torch.set_default_dtype(dtype) torch.manual_seed(1) maybe_create_device_identity() # needed for certain non-cutlass fp8 paths - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm", "+quant_fp8"], - pass_config=PassConfig(enable_fusion=True, enable_noop=True), - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + ) + ) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py index 60f32c863208..7e5c460db174 100644 --- a/tests/compile/test_fusion_all_reduce.py +++ b/tests/compile/test_fusion_all_reduce.py @@ -10,14 +10,24 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, - ModelConfig, PassConfig, VllmConfig) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - GroupShape, QuantFP8) + GroupShape, + QuantFP8, +) from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -26,7 +36,6 @@ class TestAllReduceRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -47,7 +56,6 @@ def ops_in_model_after(self): class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size @@ -68,25 +76,22 @@ def ops_in_model_after(self): class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = RMSNorm(hidden_size, eps) - self.quant_fp8 = QuantFP8(static=True, - group_shape=GroupShape.PER_TENSOR) + self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), - 
dtype=torch.float32) + self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) all_reduce = tensor_model_parallel_all_reduce(view) norm_output, residual_output = self.norm(all_reduce, residual) - torch.ops._C.static_scaled_fp8_quant(self.output, - norm_output.contiguous(), - self.scale) + torch.ops._C.static_scaled_fp8_quant( + self.output, norm_output.contiguous(), self.scale + ) return self.output, residual_output def ops_in_model_after(self): @@ -95,35 +100,33 @@ def ops_in_model_after(self): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.static_scaled_fp8_quant.default + torch.ops._C.static_scaled_fp8_quant.default, ] class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): - def __init__(self, hidden_size=16, token_num=16, eps=1e-6): super().__init__() self.hidden_size = hidden_size self.eps = eps self.norm = RMSNorm(hidden_size, eps) self.scale = torch.rand(1, dtype=torch.float32) - self.output = torch.empty((token_num, hidden_size), - dtype=torch.float32) + self.output = torch.empty((token_num, hidden_size), dtype=torch.float32) round_up = lambda x, y: (x + y - 1) // y * y rounded_m = round_up(token_num, 128) scale_n = hidden_size // 16 rounded_n = round_up(scale_n, 4) - self.output_scale = torch.empty((rounded_m, rounded_n // 4), - dtype=torch.int32) + self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32) def forward(self, hidden_states, residual): view = hidden_states.reshape(-1, self.hidden_size) all_reduce = tensor_model_parallel_all_reduce(view) norm_output, residual_output = self.norm(all_reduce, residual) norm_output = norm_output.reshape(-1, norm_output.shape[-1]) - torch.ops._C.scaled_fp4_quant(self.output, norm_output, - self.output_scale, self.scale) + torch.ops._C.scaled_fp4_quant( + self.output, norm_output, self.output_scale, self.scale + ) return self.output, residual_output, self.output_scale def ops_in_model_after(self): @@ -132,7 +135,7 @@ def ops_in_model_after(self): def ops_in_model_before(self): return [ torch.ops.vllm.all_reduce.default, - torch.ops._C.scaled_fp4_quant.default + torch.ops._C.scaled_fp4_quant.default, ] @@ -145,41 +148,55 @@ def ops_in_model_before(self): TestAllReduceFusedAddRMSNormStaticQuantFP8Model, # TODO: Enable with torch==2.8.0 # TestAllReduceFusedAddRMSNormStaticQuantFP4Model, - ]) + ], +) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("seq_len", [8]) @pytest.mark.parametrize("hidden_size", [16]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"), reason="flashinfer is not found or flashinfer " - "is not compiled with trtllm_allreduce_fusion") -def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): + "is not compiled with trtllm_allreduce_fusion", +) +def test_all_reduce_fusion_pass_replace( + test_model: torch.nn.Module, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, +): num_processes = 2 - if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model - and not 
current_platform.has_device_capability(100)): - pytest.skip("Skip as nvfp4 is only supported on " - "devices with compute capability 10.0 (Blackwell)") + if ( + test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model + and not current_platform.has_device_capability(100) + ): + pytest.skip( + "Skip as nvfp4 is only supported on " + "devices with compute capability 10.0 (Blackwell)" + ) def run_torch_spawn(fn, nprocs): - torch.multiprocessing.spawn(fn, - args=(num_processes, test_model, - batch_size, seq_len, hidden_size, - dtype), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype), + nprocs=nprocs, + ) run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes) -def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, - test_model_cls: torch.nn.Module, - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype): +def all_reduce_fusion_pass_on_test_model( + local_rank: int, + world_size: int, + test_model_cls: torch.nn.Module, + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, +): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -187,39 +204,42 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - custom_ops=["+rms_norm", "+quant_fp8"])) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm", "+quant_fp8"] + ) + ) vllm_config.compilation_config.pass_config = PassConfig( - enable_fi_allreduce_fusion=True, enable_noop=True) + enable_fi_allreduce_fusion=True, enable_noop=True + ) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. 
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model_name, - trust_remote_code=True, - dtype=dtype, - seed=42) + vllm_config.model_config = ModelConfig( + model=model_name, trust_remote_code=True, dtype=dtype, seed=42 + ) all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) - backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, - cleanup_pass) + backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass) token_num = batch_size * seq_len model = test_model_cls(hidden_size, token_num) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 077cf11d048a..1fd5c267650b 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -19,14 +19,23 @@ from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass -from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, - ModelConfig, PassConfig, SchedulerConfig, VllmConfig, - set_current_vllm_config) +from vllm.config import ( + CacheConfig, + CompilationConfig, + CompilationLevel, + ModelConfig, + PassConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym, kNvfp4Quant) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp) + QuantKey, + kFp8StaticTensorSym, + kNvfp4Quant, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer from vllm.v1.kv_cache_interface import AttentionSpec @@ -40,14 +49,16 @@ @pytest.mark.parametrize( - "model, quant_key", - [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) + "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)] +) @pytest.mark.parametrize("use_triton_fa", [True, False]) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="V0 attn quant fusion only on ROCm") -def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, - quant_key: QuantKey, use_triton_fa: bool): +@pytest.mark.skipif( + not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm" +) +def test_attention_fusion_v0( + example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool +): # Clean Dynamo cache to avoid reusing other test cases # (for some reason the reset at the end is not enough) torch._dynamo.reset() @@ -69,22 +80,24 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, backend="tests.compile.test_fusion_attn.backend_unfused", custom_ops=["+quant_fp8"], ) - vllm_config = VllmConfig(compilation_config=compile_config, - model_config=ModelConfig( - model=model, - dtype=torch.bfloat16, - )) + vllm_config = VllmConfig( + compilation_config=compile_config, + model_config=ModelConfig( + model=model, + dtype=torch.bfloat16, + ), + ) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) - llm = LLM(model, - enforce_eager=True, - compilation_config=compile_config, - 
gpu_memory_utilization=0.5, - max_model_len=2048) + llm = LLM( + model, + enforce_eager=True, + compilation_config=compile_config, + gpu_memory_utilization=0.5, + max_model_len=2048, + ) - sampling_params = SamplingParams(temperature=0.0, - max_tokens=10, - top_p=0.95) + sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95) unfused_output = llm.generate(prompts, sampling_params) backend_unfused = None # Reset backend to make sure llm gets released @@ -97,21 +110,25 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, backend="tests.compile.test_fusion_attn.backend", custom_ops=["+quant_fp8"], ) - vllm_config = VllmConfig(compilation_config=compile_config, - model_config=ModelConfig( - model=model, - dtype=torch.bfloat16, - )) + vllm_config = VllmConfig( + compilation_config=compile_config, + model_config=ModelConfig( + model=model, + dtype=torch.bfloat16, + ), + ) # AttnFusionPass needs attention layers to be registered in config upon init # so we initialize it during compilation. attn_pass = LazyInitPass(AttnFusionPass, vllm_config) backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass) - llm2 = LLM(model, - enforce_eager=True, - compilation_config=compile_config, - gpu_memory_utilization=0.5, - max_model_len=2048) + llm2 = LLM( + model, + enforce_eager=True, + compilation_config=compile_config, + gpu_memory_utilization=0.5, + max_model_len=2048, + ) # check support attn_fusion_supported = [ @@ -132,9 +149,9 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, for i in range(len(attn_nodes_pre)): assert attn_nodes_pre[i].kwargs["output_scale"] is None fused = attn_nodes_post[i].kwargs["output_scale"] is not None - assert fused == attn_fusion_supported[i], \ - f"Node {i} {'' if fused else 'not '} expected " \ - f"to have fused output quant" + assert fused == attn_fusion_supported[i], ( + f"Node {i} {'' if fused else 'not '} expected to have fused output quant" + ) # check outputs fused_output = llm2.generate(prompts, sampling_params) @@ -160,9 +177,16 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, class AttentionQuantPatternModel(torch.nn.Module): """Base model for AttentionQuantPattern fusion.""" - def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, - kv_cache_dtype: torch.dtype, device: torch.device, - vllm_config: VllmConfig, **kwargs): + def __init__( + self, + num_qo_heads: int, + num_kv_heads: int, + head_size: int, + kv_cache_dtype: torch.dtype, + device: torch.device, + vllm_config: VllmConfig, + **kwargs, + ): super().__init__() self.num_qo_heads = num_qo_heads self.num_kv_heads = num_kv_heads @@ -197,33 +221,30 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, device=self.device, ) - def build_attn_metadata(self, batch_size: int, use_hnd: bool) \ - -> AttentionMetadata: + def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata: """Initialize attention metadata.""" # Create common attn metadata - batch_spec = BatchSpec(seq_lens=[1] * batch_size, - query_lens=[1] * batch_size) + batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size) common_attn_metadata = create_common_attn_metadata( - batch_spec, - self.block_size, - self.device, - arange_block_indices=True) + batch_spec, self.block_size, self.device, arange_block_indices=True + ) - max_blocks = (max(batch_spec.seq_lens) + self.block_size - - 1) // self.block_size + max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // 
self.block_size num_blocks = batch_size * max_blocks # Create dummy KV cache for FlashInfer TRTLLM # - NHD: [num_blocks, block_size, num_kv_heads, head_size] # - HND: [num_blocks, num_kv_heads, block_size, head_size] - kv_cache = torch.zeros(num_blocks, - 2, - self.num_kv_heads, - self.block_size, - self.head_size, - dtype=self.kv_cache_dtype, - device=self.device) + kv_cache = torch.zeros( + num_blocks, + 2, + self.num_kv_heads, + self.block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device, + ) if current_platform.is_rocm(): # k/v as 1st dimention if use_hnd: @@ -239,7 +260,8 @@ def build_attn_metadata(self, batch_size: int, use_hnd: bool) \ # Build attn metadata self.attn_metadata = self.builder.build( - common_prefix_len=0, common_attn_metadata=common_attn_metadata) + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) return self.attn_metadata @@ -254,27 +276,30 @@ def __init__(self, *args, **kwargs): self.fp8_linear = Fp8LinearOp( act_quant_static=self.quant_key.scale.static, - act_quant_group_shape=self.quant_key.scale.group_shape) + act_quant_group_shape=self.quant_key.scale.group_shape, + ) hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( - "w", { - "weight": - torch.randn(hidden_size, hidden_size).to( - dtype=FP8_DTYPE, device=self.device).t(), - "wscale": - torch.tensor([1.0], dtype=torch.float32, device=self.device), - "scale": - torch.tensor([1.0], dtype=torch.float32, device=self.device), - }) + "w", + { + "weight": torch.randn(hidden_size, hidden_size) + .to(dtype=FP8_DTYPE, device=self.device) + .t(), + "wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device), + "scale": torch.tensor([1.0], dtype=torch.float32, device=self.device), + }, + ) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) - return self.fp8_linear.apply(input=attn_output, - weight=self.w["weight"], - weight_scale=self.w["wscale"], - input_scale=self.w["scale"]) + return self.fp8_linear.apply( + input=attn_output, + weight=self.w["weight"], + weight_scale=self.w["wscale"], + input_scale=self.w["scale"], + ) class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): @@ -287,42 +312,54 @@ def __init__(self, *args, **kwargs): hidden_size = self.num_qo_heads * self.head_size self.w = kwargs.get( - "w", { - "weight": - torch.randint(256, (hidden_size, hidden_size // 2), - dtype=FP4_DTYPE, - device=self.device), - "wscale_swizzled": - torch.randn(hidden_size, hidden_size // 16).to( - dtype=FP8_DTYPE, device=self.device), - "wscale": - torch.tensor([500], dtype=torch.float32, device=self.device), - "scale": - torch.tensor([0.002], dtype=torch.float32, device=self.device), - }) + "w", + { + "weight": torch.randint( + 256, + (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE, + device=self.device, + ), + "wscale_swizzled": torch.randn(hidden_size, hidden_size // 16).to( + dtype=FP8_DTYPE, device=self.device + ), + "wscale": torch.tensor([500], dtype=torch.float32, device=self.device), + "scale": torch.tensor([0.002], dtype=torch.float32, device=self.device), + }, + ) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) quant_output, output_block_scale = scaled_fp4_quant( - attn_output, 1 / self.w["scale"]) - return cutlass_scaled_fp4_mm(a=quant_output, - b=self.w["weight"], - block_scale_a=output_block_scale, - 
block_scale_b=self.w["wscale_swizzled"], - alpha=self.w["scale"] * self.w["wscale"], - out_dtype=attn_output.dtype) + attn_output, 1 / self.w["scale"] + ) + return cutlass_scaled_fp4_mm( + a=quant_output, + b=self.w["weight"], + block_scale_a=output_block_scale, + block_scale_b=self.w["wscale_swizzled"], + alpha=self.w["scale"] * self.w["wscale"], + out_dtype=attn_output.dtype, + ) if current_platform.is_cuda(): - MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", - TestAttentionFp8StaticQuantPatternModel), - ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", - TestAttentionNvfp4QuantPatternModel)] + MODELS = [ + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + TestAttentionFp8StaticQuantPatternModel, + ), + ( + "nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel, + ), + ] HEADS = [(64, 8), (40, 8)] elif current_platform.is_rocm(): - MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV", - TestAttentionFp8StaticQuantPatternModel)] + MODELS = [ + ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) + ] HEADS = [(32, 8), (40, 8)] else: MODELS = [] @@ -331,41 +368,53 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) @pytest.mark.parametrize("head_size", [128]) -@pytest.mark.parametrize("batch_size", - [7, 256, 533] if current_platform.is_cuda() else [8]) +@pytest.mark.parametrize( + "batch_size", [7, 256, 533] if current_platform.is_cuda() else [8] +) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("model_name, model_class", MODELS) -@pytest.mark.parametrize("backend", - [_Backend.FLASHINFER] if current_platform.is_cuda() - else [_Backend.TRITON_ATTN]) @pytest.mark.parametrize( - "split_attention", - [False, True] if current_platform.is_rocm() else [False]) + "backend", + [_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN], +) +@pytest.mark.parametrize( + "split_attention", [False, True] if current_platform.is_rocm() else [False] +) # TODO(boyuan): test inductor graph partition on rocm @pytest.mark.parametrize( "use_inductor_graph_partition", - [False] if current_platform.is_rocm() else [False, True]) -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test ROCm or CUDA") + [False] if current_platform.is_rocm() else [False, True], +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" +) @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif(current_platform.is_cuda() - and not current_platform.is_device_capability((10, 0)), - reason="On CUDA only test on SM100(Blackwell)") -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Only test ROCm or CUDA") -def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, - head_size: int, batch_size: int, - dtype: torch.dtype, model_name: str, - model_class: type[AttentionQuantPatternModel], - backend: _Backend, split_attention: bool, - use_inductor_graph_partition: bool, - monkeypatch, dist_init, caplog_vllm): +@pytest.mark.skipif( + current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)), + reason="On CUDA only test on SM100(Blackwell)", +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA" +) +def test_attention_quant_pattern( + num_qo_heads: int, + num_kv_heads: int, + head_size: int, + batch_size: int, + dtype: torch.dtype, + model_name: str, + 
model_class: type[AttentionQuantPatternModel], + backend: _Backend, + split_attention: bool, + use_inductor_graph_partition: bool, + monkeypatch, + dist_init, + caplog_vllm, +): """Test AttentionStaticQuantPattern fusion pass""" - if use_inductor_graph_partition and not is_torch_equal_or_newer( - "2.9.0.dev"): - pytest.skip("inductor graph partition is only available " - "in PyTorch 2.9+") + if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): + pytest.skip("inductor graph partition is only available in PyTorch 2.9+") monkeypatch.setenv("VLLM_USE_V1", "1") if split_attention: @@ -386,21 +435,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, custom_ops=["+quant_fp8"], use_inductor_graph_partition=use_inductor_graph_partition, ), - cache_config=CacheConfig(cache_dtype="fp8")) + cache_config=CacheConfig(cache_dtype="fp8"), + ) # Create test inputs - q = torch.randn(batch_size, - num_qo_heads * head_size, - dtype=dtype, - device=device) - k = torch.randn(batch_size, - num_kv_heads * head_size, - dtype=dtype, - device=device) - v = torch.randn(batch_size, - num_kv_heads * head_size, - dtype=dtype, - device=device) + q = torch.randn(batch_size, num_qo_heads * head_size, dtype=dtype, device=device) + k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device) + v = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, device=device) # Mark first dimension as dynamic for realistic testing torch._dynamo.mark_dynamic(q, 0) @@ -409,42 +450,53 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, # Run model directly without compilation and fusion vllm_config_unfused = copy.deepcopy(vllm_config) - with set_current_vllm_config(vllm_config_unfused), set_forward_context( - attn_metadata=None, vllm_config=vllm_config_unfused - ), global_force_attn_backend_context_manager(backend): - model_unfused = model_class(num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - kv_cache_dtype=FP8_DTYPE, - device=device, - vllm_config=vllm_config_unfused) + with ( + set_current_vllm_config(vllm_config_unfused), + set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused), + global_force_attn_backend_context_manager(backend), + ): + model_unfused = model_class( + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config_unfused, + ) model_unfused = model_unfused.to(device) forward_ctx = get_forward_context() forward_ctx.attn_metadata = model_unfused.build_attn_metadata( - batch_size, use_hnd=split_attention) + batch_size, use_hnd=split_attention + ) # Run model directly without compilation and fusion result_unfused = model_unfused(q, k, v) # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( - enable_attn_fusion=True, enable_noop=True) - with set_current_vllm_config(vllm_config), set_forward_context( - attn_metadata=None, vllm_config=vllm_config - ), global_force_attn_backend_context_manager(backend): - model_fused = model_class(num_qo_heads=num_qo_heads, - num_kv_heads=num_kv_heads, - head_size=head_size, - kv_cache_dtype=FP8_DTYPE, - device=device, - vllm_config=vllm_config, - w=model_unfused.w) + enable_attn_fusion=True, enable_noop=True + ) + with ( + set_current_vllm_config(vllm_config), + set_forward_context(attn_metadata=None, vllm_config=vllm_config), + global_force_attn_backend_context_manager(backend), + ): + model_fused = model_class( + 
num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config, + w=model_unfused.w, + ) model_fused = model_fused.to(device) forward_ctx = get_forward_context() forward_ctx.attn_metadata = model_fused.build_attn_metadata( - batch_size, use_hnd=split_attention) + batch_size, use_hnd=split_attention + ) # Create test backend with fusion passes enabled noop_pass = NoOpEliminationPass(vllm_config) @@ -454,9 +506,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass) # Compile model with fusion enabled - model_compiled = torch.compile(model_fused, - backend=test_backend, - fullgraph=True) + model_compiled = torch.compile( + model_fused, backend=test_backend, fullgraph=True + ) assert model_compiled.attn._o_scale_float is None result_fused_1 = model_compiled(q, k, v) @@ -471,49 +523,49 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, assert model_compiled.attn._o_scale_float is not None - torch.testing.assert_close(result_unfused, - result_fused_2, - atol=1e-2, - rtol=1e-2) + torch.testing.assert_close( + result_unfused, result_fused_2, atol=1e-2, rtol=1e-2 + ) # Check attn fusion support quant_key = model_class.quant_key attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key) for key, layer in - vllm_config.compilation_config.static_forward_context.items() + layer.impl.fused_output_quant_supported(quant_key) + for key, layer in vllm_config.compilation_config.static_forward_context.items() ] if any(attn_fusion_supported): # Check quantization ops in the graph before and after fusion - test_backend.check_before_ops([QUANT_OPS[quant_key]], - fully_replaced=True) + test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True) # access the underlying `AttnFusionPass` on the `LazyInitPass` assert attn_pass.pass_.matched_count == sum(attn_fusion_supported) # Check attention ops in the graph before and after fusion attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass)) - attn_nodes_post = list(find_op_nodes(ATTN_OP, - test_backend.graph_post_pass)) + attn_nodes_post = list(find_op_nodes(ATTN_OP, test_backend.graph_post_pass)) assert len(attn_nodes_pre) > 0, "Should have attention nodes before fusion" - assert len(attn_nodes_pre) == len(attn_nodes_post), \ + assert len(attn_nodes_pre) == len(attn_nodes_post), ( "Should have same number of attention nodes before and after fusion" - assert attn_nodes_pre[0].kwargs.get("output_scale") is None, \ + ) + assert attn_nodes_pre[0].kwargs.get("output_scale") is None, ( "Attention should not have output_scale before fusion" - assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \ + ) + assert attn_nodes_post[0].kwargs.get("output_scale") is not None, ( "Attention should have output_scale after fusion" + ) - assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \ + assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, ( "Attention should not have output_block_scale before fusion" + ) if quant_key.dtype == FP8_DTYPE: - assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \ + assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, ( "Attention should not have output_block_scale after FP8 fusion" + ) elif quant_key.dtype == FP4_DTYPE: - assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ - "Attention should have output_block_scale 
after FP4 fusion" # noqa: E501 + assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, ( + "Attention should have output_block_scale after FP4 fusion" + ) # Check that results are close - torch.testing.assert_close(result_unfused, - result_fused_1, - atol=1e-2, - rtol=1e-2) + torch.testing.assert_close(result_unfused, result_fused_1, atol=1e-2, rtol=1e-2) diff --git a/tests/compile/test_noop_elimination.py b/tests/compile/test_noop_elimination.py index 242d53131267..fda7f4e3bafa 100644 --- a/tests/compile/test_noop_elimination.py +++ b/tests/compile/test_noop_elimination.py @@ -6,14 +6,12 @@ import vllm from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, - VllmConfig) +from vllm.config import CompilationConfig, CompilationLevel, PassConfig, VllmConfig from .backend import TestBackend -@pytest.mark.parametrize("dtype", - [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) @pytest.mark.parametrize("num_tokens", [256, 1024]) @pytest.mark.parametrize("hidden_size", [64, 4096]) def test_noop_elimination(dtype, num_tokens, hidden_size): @@ -22,7 +20,6 @@ def test_noop_elimination(dtype, num_tokens, hidden_size): torch.manual_seed(1) class Model(torch.nn.Module): - def forward(self, x): # Chain of reshapes y = x.reshape(-1, 128, 32) @@ -32,7 +29,7 @@ def forward(self, x): # Final reshape that should remain b = a.reshape(-1, 128, 32) # No-op slice - c = b[0:b.shape[0]] + c = b[0 : b.shape[0]] # The pass should replace the result of this op with `c`. d = torch.slice_scatter( torch.ones_like(c), # Dummy tensor to be scattered into @@ -43,10 +40,12 @@ def forward(self, x): ) return d - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - pass_config=PassConfig(enable_noop=True), - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig(enable_noop=True), + ) + ) with vllm.config.set_current_vllm_config(vllm_config): noop_pass = NoOpEliminationPass(vllm_config) @@ -82,17 +81,18 @@ def test_non_noop_slice_preserved(): x = torch.randn(16, 16) class SliceModel(torch.nn.Module): - def forward(self, x): base = x.clone() src = torch.ones(15, 16) y = torch.slice_scatter(base, src, dim=0, start=0, end=-1) return x[0:-1, :], y - vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - pass_config=PassConfig(enable_noop=True), - )) + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + pass_config=PassConfig(enable_noop=True), + ) + ) with vllm.config.set_current_vllm_config(vllm_config): noop_pass = NoOpEliminationPass(vllm_config) backend = TestBackend(noop_pass) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 251cc46e9e98..ac561d2e8f84 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -28,7 +28,6 @@ def test_bad_callable(): # Pass that inherits from InductorPass class ProperPass(InductorPass): - def __call__(self, graph: torch.fx.graph.Graph) -> None: pass @@ -39,8 +38,7 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: ProperPass(), # Can also wrap callables in CallableInductorPass for compliance CallableInductorPass(simple_callable), - CallableInductorPass(simple_callable, - InductorPass.hash_source(__file__)) + 
CallableInductorPass(simple_callable, InductorPass.hash_source(__file__)), ], ) def test_pass_manager_uuid(callable): @@ -65,8 +63,9 @@ def test_pass_manager_uuid(callable): # UUID should be different due to config change config2 = copy.deepcopy(config) - config2.compilation_config.pass_config.enable_fusion = not \ - config2.compilation_config.pass_config.enable_fusion + config2.compilation_config.pass_config.enable_fusion = ( + not config2.compilation_config.pass_config.enable_fusion + ) pass_manager3 = PostGradPassManager() pass_manager3.configure(config2) pass_manager3.add(callable) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index b2734e915bbb..afb31cb95be0 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -12,14 +12,20 @@ from vllm.compilation.post_cleanup import PostCleanupPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.compilation.vllm_inductor_pass import VllmInductorPass -from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - PassConfig, VllmConfig) +from vllm.config import ( + CompilationConfig, + DeviceConfig, + ModelConfig, + PassConfig, + VllmConfig, +) from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -36,16 +42,15 @@ class TestModel(torch.nn.Module): - - def __init__(self, - hidden_size=16, - intermediate_size=32, - vllm_config: VllmConfig = None): + def __init__( + self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None + ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.gate_proj = torch.nn.Parameter( - torch.empty((intermediate_size, hidden_size))) + torch.empty((intermediate_size, hidden_size)) + ) self.norm = RMSNorm(intermediate_size, 1e-05) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -53,18 +58,18 @@ def __init__(self, def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph - + Args: hidden_states: Input tensor residual: Residual tensor from previous layer - + Returns: Tuple containing the output tensor """ # Reshape input view = hidden_states.reshape(-1, self.hidden_size) - #matrix multiplication + # matrix multiplication permute = self.gate_proj.permute(1, 0) mm = torch.mm(view, permute) @@ -82,7 +87,7 @@ def ops_in_model_before(self): def ops_in_model_after(self): return [ torch.ops.vllm.reduce_scatter.default, - torch.ops.vllm.all_gather.default + torch.ops.vllm.all_gather.default, ] def ops_in_model(self): @@ -90,18 +95,16 @@ def ops_in_model(self): class TestQuantModel(torch.nn.Module): - - def __init__(self, - hidden_size=16, - intermediate_size=32, - vllm_config: VllmConfig = None): + def __init__( + self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None + ): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.vllm_config = vllm_config - 
self.gate_proj = torch.nn.Parameter(torch.empty( - (intermediate_size, hidden_size)), - requires_grad=False) + self.gate_proj = torch.nn.Parameter( + torch.empty((intermediate_size, hidden_size)), requires_grad=False + ) self.norm = RMSNorm(intermediate_size, 1e-05) # Initialize weights torch.nn.init.normal_(self.gate_proj, std=0.02) @@ -111,25 +114,24 @@ def __init__(self, self.scale = torch.rand(1, dtype=torch.float32) # Create a weight that is compatible with torch._scaled_mm, # which expects a column-major layout. - self.w = torch.rand(hidden_size, - intermediate_size).to(dtype=FP8_DTYPE).t() + self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() self.wscale = torch.rand(1, dtype=torch.float32) def forward(self, hidden_states, residual): """ Forward pass implementing the operations in the FX graph - + Args: hidden_states: Input tensor residual: Residual tensor from previous layer - + Returns: Tuple containing the output tensor """ # Reshape input view = hidden_states.reshape(-1, self.hidden_size) - #matrix multiplication + # matrix multiplication permute = self.gate_proj.permute(1, 0) mm = torch.mm(view, permute) @@ -140,45 +142,51 @@ def forward(self, hidden_states, residual): norm_output, residual_output = self.norm(all_reduce, residual) # scaled_mm with static input quantization - fp8_linear_result = self.fp8_linear.apply(norm_output, - self.w, - self.wscale, - input_scale=self.scale.to( - norm_output.device)) + fp8_linear_result = self.fp8_linear.apply( + norm_output, + self.w, + self.wscale, + input_scale=self.scale.to(norm_output.device), + ) return fp8_linear_result, residual_output def ops_in_model_before(self): - ops_to_remove = [torch.ops.vllm.all_reduce.default - ] # Always removed by SP + ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP # The following are only removed if fusion happens - if self.vllm_config and self.vllm_config.compilation_config \ - .pass_config.enable_fusion: - ops_to_remove.extend([ - torch.ops._C.fused_add_rms_norm.default, - torch.ops._C.static_scaled_fp8_quant.default, - ]) + if ( + self.vllm_config + and self.vllm_config.compilation_config.pass_config.enable_fusion + ): + ops_to_remove.extend( + [ + torch.ops._C.fused_add_rms_norm.default, + torch.ops._C.static_scaled_fp8_quant.default, + ] + ) return ops_to_remove def ops_in_model_after(self): ops_to_add = [ torch.ops.vllm.reduce_scatter.default, - torch.ops.vllm.all_gather.default + torch.ops.vllm.all_gather.default, ] # The following is only added if fusion happens - if self.vllm_config and self.vllm_config.compilation_config \ - .pass_config.enable_fusion: - ops_to_add.append( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) + if ( + self.vllm_config + and self.vllm_config.compilation_config.pass_config.enable_fusion + ): + ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default) return ops_to_add def ops_in_model(self): - if self.vllm_config and self.vllm_config.compilation_config \ - .pass_config.enable_fusion: + if ( + self.vllm_config + and self.vllm_config.compilation_config.pass_config.enable_fusion + ): # If fusion happens, the fused op is the one # we check for (de)functionalization - return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - ] # noqa: E501 + return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default] else: # If no fusion, the original ops are checked return [ @@ -195,30 +203,47 @@ def ops_in_model(self): @pytest.mark.parametrize("hidden_size", [16]) 
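# A minimal, single-process sketch of the collective identity behind the op
# checks above (all_reduce removed; reduce_scatter plus all_gather added): an
# all-reduce across tensor-parallel ranks yields the same result as a
# reduce-scatter followed by an all-gather. The helpers naive_all_reduce,
# naive_reduce_scatter and naive_all_gather are illustrative stand-ins that
# simulate ranks with a plain list of tensors; they are not vLLM or
# torch.distributed APIs.
import torch


def naive_all_reduce(shards: list[torch.Tensor]) -> list[torch.Tensor]:
    # Every "rank" ends up with the elementwise sum of all shards.
    total = torch.stack(shards).sum(dim=0)
    return [total.clone() for _ in shards]


def naive_reduce_scatter(shards: list[torch.Tensor]) -> list[torch.Tensor]:
    # Each "rank" keeps only its slice of the summed tensor, split on dim 0
    # (the token dimension in the test models above).
    total = torch.stack(shards).sum(dim=0)
    return [chunk.clone() for chunk in total.chunk(len(shards), dim=0)]


def naive_all_gather(chunks: list[torch.Tensor]) -> list[torch.Tensor]:
    # Concatenating the per-rank slices restores the full summed tensor everywhere.
    full = torch.cat(chunks, dim=0)
    return [full.clone() for _ in chunks]


shards = [torch.randn(8, 16) for _ in range(2)]
for want, got in zip(naive_all_reduce(shards),
                     naive_all_gather(naive_reduce_scatter(shards))):
    torch.testing.assert_close(want, got)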
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("enable_fusion", [True, False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") -def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module], - batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype, - enable_fusion: bool): +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_sequence_parallelism_pass( + test_model_cls: type[torch.nn.Module], + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, + enable_fusion: bool, +): num_processes = 2 def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda - torch.multiprocessing.spawn(fn, - args=(num_processes, test_model_cls, - batch_size, seq_len, hidden_size, - dtype, enable_fusion), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=( + num_processes, + test_model_cls, + batch_size, + seq_len, + hidden_size, + dtype, + enable_fusion, + ), + nprocs=nprocs, + ) run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes) def sequence_parallelism_pass_on_test_model( - local_rank: int, world_size: int, - test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int, - hidden_size: int, dtype: torch.dtype, enable_fusion: bool): + local_rank: int, + world_size: int, + test_model_cls: type[torch.nn.Module], + batch_size: int, + seq_len: int, + hidden_size: int, + dtype: torch.dtype, + enable_fusion: bool, +): current_platform.seed_everything(0) device = torch.device(f"cuda:{local_rank}") @@ -226,13 +251,15 @@ def sequence_parallelism_pass_on_test_model( torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # initialize distributed init_distributed_environment() @@ -240,27 +267,28 @@ def sequence_parallelism_pass_on_test_model( # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( - enable_sequence_parallelism=True, - enable_fusion=enable_fusion, - enable_noop=True)) # NoOp needed for fusion + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig( + enable_sequence_parallelism=True, + enable_fusion=enable_fusion, + enable_noop=True, + ) + ) # NoOp needed for fusion vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config # in the vllm_config, it's not really used. 
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" - vllm_config.model_config = ModelConfig(model=model_name, - trust_remote_code=True, - dtype=dtype, - seed=42) + vllm_config.model_config = ModelConfig( + model=model_name, trust_remote_code=True, dtype=dtype, seed=42 + ) noop_pass = NoOpEliminationPass(vllm_config) sequence_parallelism_pass = SequenceParallelismPass(vllm_config) func_pass = FixFunctionalizationPass(vllm_config) cleanup_pass = PostCleanupPass(vllm_config) - passes_for_backend: list[VllmInductorPass] = \ - [noop_pass, sequence_parallelism_pass] + passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass] if enable_fusion: fusion_pass = RMSNormQuantFusionPass(vllm_config) @@ -271,12 +299,9 @@ def sequence_parallelism_pass_on_test_model( backend_no_func = TestBackend(*passes_for_backend) backend_func = TestBackend(*passes_for_backend, func_pass) - model = test_model_cls(hidden_size, - hidden_size * 2, - vllm_config=vllm_config) + model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config) - hidden_states = torch.randn((batch_size * seq_len, hidden_size), - dtype=dtype) + hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype) compiled_model_no_func = torch.compile(model, backend=backend_no_func) @@ -297,8 +322,7 @@ def sequence_parallelism_pass_on_test_model( # check if the functionalization pass is applied for op in model.ops_in_model(): find_auto_fn(backend_no_func.graph_post_pass.nodes, op) - assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, - op) is None # noqa: E501 + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None # make sure the ops were all de-functionalized found = dict() diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index c445f4dde2cc..16a4271655ef 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -8,20 +8,25 @@ import vllm.envs as envs from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant -# yapf conflicts with isort for this block -# yapf: disable from vllm.compilation.activation_quant_fusion import ( - FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass) -# yapf: enable + FUSED_OPS, + SILU_MUL_OP, + ActivationQuantFusionPass, +) from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.post_cleanup import PostCleanupPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, kFp8StaticTensorSym, kNvfp4Quant) + GroupShape, + kFp8StaticTensorSym, + kNvfp4Quant, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_fp8_supported) + Fp8LinearOp, + cutlass_fp8_supported, +) from vllm.platforms import current_platform from ..utils import override_cutlass_fp8_supported @@ -36,7 +41,6 @@ def is_nvfp4_supported(): class TestSiluMulFp8QuantModel(torch.nn.Module): - def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): super().__init__() self.silu_and_mul = SiluAndMul() @@ -53,10 +57,7 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): def forward(self, x): y = self.silu_and_mul(x) - 
x2 = self.fp8_linear.apply(y, - self.w, - self.wscale, - input_scale=self.wscale) + x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) return x2 def ops_in_model_before(self): @@ -67,11 +68,12 @@ def ops_in_model_after(self): class TestSiluMulNvfp4QuantModel(torch.nn.Module): - def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): super().__init__() from vllm.compilation.activation_quant_fusion import ( - silu_and_mul_nvfp4_quant_supported) + silu_and_mul_nvfp4_quant_supported, + ) + assert silu_and_mul_nvfp4_quant_supported self.silu_and_mul = SiluAndMul() @@ -88,12 +90,14 @@ def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs): def forward(self, x): y = self.silu_and_mul(x) y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale) - out = cutlass_scaled_fp4_mm(a=y_quant, - b=self.w, - block_scale_a=y_block_scale, - block_scale_b=self.w_block_scale, - alpha=self.alpha, - out_dtype=y.dtype) + out = cutlass_scaled_fp4_mm( + a=y_quant, + b=self.w, + block_scale_a=y_block_scale, + block_scale_b=self.w_block_scale, + alpha=self.alpha, + out_dtype=y.dtype, + ) return out def ops_in_model_before(self): @@ -108,16 +112,24 @@ def ops_in_model_after(self): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize( "model_class", - cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] - if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])) + cast( + list[type], + [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel] + if is_nvfp4_supported() + else [TestSiluMulFp8QuantModel], + ), +) # cuda_force_torch used to test torch code path on platforms that # cutlass_fp8_supported() == True. -@pytest.mark.parametrize("cuda_force_torch", - [True, False] if cutlass_fp8_supported() else [True]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], - reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class, - cuda_force_torch): +@pytest.mark.parametrize( + "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True] +) +@pytest.mark.skipif( + envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm" +) +def test_fusion_silu_and_mul_quant( + num_tokens, hidden_size, dtype, model_class, cuda_force_torch +): if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch: pytest.skip("Duplicate tests for NVFP4") @@ -129,17 +141,13 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class, # Reshape pass is needed for the fusion pass to work config = VllmConfig() config.compilation_config = CompilationConfig( - pass_config=PassConfig(enable_fusion=True, enable_noop=True)) + pass_config=PassConfig(enable_fusion=True, enable_noop=True) + ) fusion_pass = ActivationQuantFusionPass(config) - passes = [ - NoOpEliminationPass(config), fusion_pass, - PostCleanupPass(config) - ] + passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)] backend = TestBackend(*passes) - model = model_class(hidden_size=hidden_size, - cuda_force_torch=cuda_force_torch, - x=x) + model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x) # First dimension dynamic torch._dynamo.mark_dynamic(x, 0) @@ -155,10 +163,9 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class, elif model_class == TestSiluMulNvfp4QuantModel: atol, rtol = 1e-1, 1e-1 - torch.testing.assert_close(result[0].to(dtype=dtype), - result2[0].to(dtype=dtype), - 
atol=atol, - rtol=rtol) + torch.testing.assert_close( + result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol + ) assert fusion_pass.matched_count == 1 diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 5e39f6821d16..34db5a999cbd 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -10,7 +10,6 @@ class MyMod(torch.nn.Module): - def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): if cache is not None: return x + cache @@ -18,12 +17,12 @@ def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): class MyWrapper(TorchCompileWrapperWithCustomDispatcher): - def __init__(self, model): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__(compiled_callable, - compilation_level=CompilationLevel.DYNAMO_ONCE) + super().__init__( + compiled_callable, compilation_level=CompilationLevel.DYNAMO_ONCE + ) def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # this is the function to be compiled @@ -54,10 +53,8 @@ def test_torch_compile_wrapper(): # for new input, dispatch to the compiled code directly new_x = torch.tensor([3]) - assert wrapper(new_x, - None).item() == 6 # dispatch to the first compiled code - assert wrapper( - new_x, cache).item() == 5 # dispatch to the second compiled code + assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code + assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code for wrapper in wrappers: # make sure they have independent compiled codes diff --git a/tests/config/test_config_generation.py b/tests/config/test_config_generation.py index e37b6b95941e..61c3df0a2348 100644 --- a/tests/config/test_config_generation.py +++ b/tests/config/test_config_generation.py @@ -14,8 +14,9 @@ def test_cuda_empty_vs_unset_configs(monkeypatch: pytest.MonkeyPatch): """ def create_config(): - engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite", - trust_remote_code=True) + engine_args = EngineArgs( + model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True + ) return engine_args.create_engine_config() # Create config with CUDA_VISIBLE_DEVICES set normally @@ -34,16 +35,18 @@ def create_config(): empty_config_dict.pop("instance_id", None) assert deep_compare(normal_config_dict, empty_config_dict), ( - "Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\"" - " should be equivalent") + 'Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=""' + " should be equivalent" + ) def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch): # In testing, this method needs to be nested inside as ray does not # see the test module. 
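# What the runtime_env below exercises: Ray applies "env_vars" from a task's
# runtime_env to the worker process that runs it, so code executed remotely
# (here, engine-config creation) can observe those variables. A small
# self-contained sketch of that behavior, assuming a local Ray instance;
# read_test_var is an illustrative helper, not part of vLLM.
import os

import ray


@ray.remote
def read_test_var() -> str:
    return os.environ.get("TEST_ENV_VAR", "<unset>")


ray.init(ignore_reinit_error=True)
value = ray.get(
    read_test_var.options(
        runtime_env={"env_vars": {"TEST_ENV_VAR": "test_value"}}
    ).remote()
)
assert value == "test_value"
ray.shutdown()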
def create_config(): - engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite", - trust_remote_code=True) + engine_args = EngineArgs( + model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True + ) return engine_args.create_engine_config() config = create_config() @@ -51,6 +54,7 @@ def create_config(): assert parallel_config.ray_runtime_env is None import ray + ray.init() runtime_env = { @@ -59,13 +63,13 @@ def create_config(): }, } - config_ref = ray.remote(create_config).options( - runtime_env=runtime_env).remote() + config_ref = ray.remote(create_config).options(runtime_env=runtime_env).remote() config = ray.get(config_ref) parallel_config = config.parallel_config assert parallel_config.ray_runtime_env is not None - assert parallel_config.ray_runtime_env.env_vars().get( - "TEST_ENV_VAR") == "test_value" + assert ( + parallel_config.ray_runtime_env.env_vars().get("TEST_ENV_VAR") == "test_value" + ) ray.shutdown() diff --git a/tests/config/test_mp_reducer.py b/tests/config/test_mp_reducer.py index d4d4be293280..9c03f26c504e 100644 --- a/tests/config/test_mp_reducer.py +++ b/tests/config/test_mp_reducer.py @@ -16,13 +16,13 @@ def test_mp_reducer(monkeypatch): """ # Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value - monkeypatch.setenv('VLLM_USE_V1', '1') + monkeypatch.setenv("VLLM_USE_V1", "1") # Ensure transformers_modules is not in sys.modules - if 'transformers_modules' in sys.modules: - del sys.modules['transformers_modules'] + if "transformers_modules" in sys.modules: + del sys.modules["transformers_modules"] - with patch('multiprocessing.reducer.register') as mock_register: + with patch("multiprocessing.reducer.register") as mock_register: engine_args = AsyncEngineArgs( model="facebook/opt-125m", max_model_len=32, @@ -36,7 +36,8 @@ def test_mp_reducer(monkeypatch): ) assert mock_register.called, ( - "multiprocessing.reducer.register should have been called") + "multiprocessing.reducer.register should have been called" + ) vllm_config_registered = False for call_args in mock_register.call_args_list: @@ -45,8 +46,7 @@ def test_mp_reducer(monkeypatch): vllm_config_registered = True reducer_func = call_args[0][1] - assert callable( - reducer_func), "Reducer function should be callable" + assert callable(reducer_func), "Reducer function should be callable" break assert vllm_config_registered, ( diff --git a/tests/conftest.py b/tests/conftest.py index c61a8f8dd539..c03fd84ade1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,23 +30,27 @@ import torch.nn.functional as F from huggingface_hub import snapshot_download from PIL import Image -from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - BatchEncoding, BatchFeature) +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + BatchEncoding, + BatchFeature, +) from transformers.models.auto.auto_factory import _BaseAutoModelClass -from tests.models.utils import (TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs) +from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs from vllm import LLM, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config.model import (ConvertOption, RunnerOption, - _get_and_verify_dtype) +from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype from vllm.connections import global_http_connection -from vllm.distributed import (cleanup_dist_env_and_memory, - 
init_distributed_environment, - initialize_model_parallel) -from vllm.inputs import TextPrompt +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.multimodal.utils import fetch_image @@ -83,12 +87,13 @@ class ImageAssetPrompts(TypedDict): class ImageTestAssets(list[ImageAsset]): - def __init__(self) -> None: - super().__init__([ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), - ]) + super().__init__( + [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), + ] + ) def prompts(self, prompts: ImageAssetPrompts) -> list[str]: """ @@ -105,11 +110,12 @@ class VideoAssetPrompts(TypedDict): class VideoTestAssets(list[VideoAsset]): - def __init__(self) -> None: - super().__init__([ - VideoAsset("baby_reading"), - ]) + super().__init__( + [ + VideoAsset("baby_reading"), + ] + ) def prompts(self, prompts: VideoAssetPrompts) -> list[str]: return [prompts["baby_reading"]] @@ -121,12 +127,13 @@ class AudioAssetPrompts(TypedDict): class AudioTestAssets(list[AudioAsset]): - def __init__(self) -> None: - super().__init__([ - AudioAsset("mary_had_lamb"), - AudioAsset("winning_call"), - ]) + super().__init__( + [ + AudioAsset("mary_had_lamb"), + AudioAsset("winning_call"), + ] + ) def prompts(self, prompts: AudioAssetPrompts) -> list[str]: return [prompts["mary_had_lamb"], prompts["winning_call"]] @@ -221,6 +228,7 @@ def example_system_message() -> str: class DecoderPromptType(Enum): """For encoder/decoder models only.""" + CUSTOM = 1 NONE = 2 EMPTY_STR = 3 @@ -254,15 +262,13 @@ def audio_assets() -> AudioTestAssets: class HfRunner: - def get_default_device(self): from vllm.platforms import current_platform - return ("cpu" - if current_platform.is_cpu() else current_platform.device_type) + return "cpu" if current_platform.is_cpu() else current_platform.device_type def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: - if x is None or isinstance(x, (bool, )): + if x is None or isinstance(x, (bool,)): return x if device is None: @@ -290,8 +296,11 @@ def __init__( # Set this to avoid hanging issue default_torch_num_threads: Optional[int] = None, ) -> None: - init_ctx = (nullcontext() if default_torch_num_threads is None else - set_default_torch_num_threads(default_torch_num_threads)) + init_ctx = ( + nullcontext() + if default_torch_num_threads is None + else set_default_torch_num_threads(default_torch_num_threads) + ) with init_ctx: self._init( @@ -363,14 +372,15 @@ def _init( ) # in case some unquantized custom models are not in same dtype - if (getattr(model, "quantization_method", None) is None - and any(p.dtype != self.dtype - for p in model.parameters())): + if getattr(model, "quantization_method", None) is None and any( + p.dtype != self.dtype for p in model.parameters() + ): model = model.to(dtype=self.dtype) - if (getattr(model, "quantization_method", None) != "bitsandbytes" - and len({p.device - for p in model.parameters()}) < 2): + if ( + getattr(model, "quantization_method", None) != "bitsandbytes" + and len({p.device for p in model.parameters()}) < 2 + ): model = model.to(device=self.device) self.model = model @@ -385,6 +395,7 @@ def _init( # don't put this import at the top level # it will call torch.cuda.device_count() from transformers import AutoProcessor # noqa: F401 + self.processor = AutoProcessor.from_pretrained( model_name, torch_dtype=torch_dtype, @@ -472,10 +483,9 @@ def generate( audios: 
Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + all_inputs = self.get_inputs( + prompts, images=images, videos=videos, audios=audios + ) outputs: list[tuple[list[list[int]], list[str]]] = [] for inputs in all_inputs: @@ -502,16 +512,17 @@ def generate_greedy( audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[tuple[list[int], str]]: - outputs = self.generate(prompts, - do_sample=False, - max_new_tokens=max_tokens, - images=images, - videos=videos, - audios=audios, - **kwargs) + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + images=images, + videos=videos, + audios=audios, + **kwargs, + ) - return [(output_ids[0], output_str[0]) - for output_ids, output_str in outputs] + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] def generate_beam_search( self, @@ -522,21 +533,22 @@ def generate_beam_search( videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, ) -> list[tuple[list[list[int]], list[str]]]: - outputs = self.generate(prompts, - do_sample=False, - max_new_tokens=max_tokens, - num_beams=beam_width, - num_return_sequences=beam_width, - images=images, - videos=videos, - audios=audios) + outputs = self.generate( + prompts, + do_sample=False, + max_new_tokens=max_tokens, + num_beams=beam_width, + num_return_sequences=beam_width, + images=images, + videos=videos, + audios=audios, + ) for i in range(len(outputs)): output_ids, output_str = outputs[i] for j in range(len(output_ids)): output_ids[j] = [ - x for x in output_ids[j] - if x != self.tokenizer.pad_token_id + x for x in output_ids[j] if x != self.tokenizer.pad_token_id ] outputs[i] = (output_ids, output_str) return outputs @@ -550,10 +562,9 @@ def generate_greedy_logprobs( audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[list[torch.Tensor]]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + all_inputs = self.get_inputs( + prompts, images=images, videos=videos, audios=audios + ) all_logprobs: list[list[torch.Tensor]] = [] for inputs in all_inputs: @@ -566,8 +577,7 @@ def generate_greedy_logprobs( return_dict_in_generate=True, **kwargs, ) - seq_logprobs = self._hidden_states_to_seq_logprobs( - output.hidden_states) + seq_logprobs = self._hidden_states_to_seq_logprobs(output.hidden_states) all_logprobs.append(seq_logprobs) return all_logprobs @@ -631,10 +641,9 @@ def generate_greedy_logprobs_limit( videos: Optional[PromptVideoInput] = None, **kwargs: Any, ) -> list[TokensTextLogprobs]: - all_inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + all_inputs = self.get_inputs( + prompts, images=images, videos=videos, audios=audios + ) all_logprobs: list[list[dict[int, float]]] = [] all_output_ids: list[list[int]] = [] @@ -654,8 +663,7 @@ def generate_greedy_logprobs_limit( ( seq_logprobs_lst, output_len, - ) = self._hidden_states_to_logprobs(output.hidden_states, - num_logprobs) + ) = self._hidden_states_to_logprobs(output.hidden_states, num_logprobs) all_logprobs.append(seq_logprobs_lst) seq_ids = output.sequences[0] @@ -665,19 +673,16 @@ def generate_greedy_logprobs_limit( all_output_strs.append(self.tokenizer.decode(output_ids)) outputs = zip(all_output_ids, all_output_strs, all_logprobs) - return [(output_ids, output_str, output_logprobs) - for output_ids, output_str, output_logprobs in 
outputs] + return [ + (output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs + ] - def encode(self, prompts: list[str], *args, - **kwargs) -> list[list[torch.Tensor]]: + def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]: return self.model.encode(prompts, *args, **kwargs) - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - return self.model.predict(prompts, - *args, - convert_to_tensor=True, - **kwargs) + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: + return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs) def __enter__(self): return self @@ -728,10 +733,17 @@ def __init__( default_torch_num_threads: Optional[int] = None, **kwargs, ) -> None: - init_ctx = (nullcontext() if default_torch_num_threads is None else - set_default_torch_num_threads(default_torch_num_threads)) + init_ctx = ( + nullcontext() + if default_torch_num_threads is None + else set_default_torch_num_threads(default_torch_num_threads) + ) if not kwargs.get("compilation_config", None): + # Note(@tdoublep): This is set to 4 because some tests (e.g., hybrid + # model tests) may set max_num_seqs=4. If min cudagraph_capture_size is + # set to larger than max_num_seqs, then it will lead to *no* graphs + # being captured which can trigger edge cases that we don't handle yet. kwargs["compilation_config"] = {"cudagraph_capture_sizes": [4]} with init_ctx: @@ -760,17 +772,25 @@ def get_inputs( images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> list[TextPrompt]: - - if any(x is not None and len(x) != len(prompts) - for x in [images, videos, audios]): + ) -> list[dict[str, Any]]: + if any( + x is not None and len(x) != len(prompts) for x in [images, videos, audios] + ): raise ValueError( - "All non-None multimodal inputs must have the same length as " - "prompts") + "All non-None multimodal inputs must have the same length as prompts" + ) - inputs = [] + inputs = list[dict[str, Any]]() for i, prompt in enumerate(prompts): - multi_modal_data = {} + prompt_dict = dict[str, Any]() + if isinstance(prompt, str): + prompt_dict["prompt"] = prompt + elif isinstance(prompt, list): + prompt_dict["prompt_token_ids"] = prompt + else: + prompt_dict["prompt_embeds"] = prompt + + multi_modal_data = dict[str, Any]() if images is not None and (image := images[i]) is not None: multi_modal_data["image"] = image if videos is not None and (video := videos[i]) is not None: @@ -778,17 +798,10 @@ def get_inputs( if audios is not None and (audio := audios[i]) is not None: multi_modal_data["audio"] = audio - text_prompt_kwargs: dict[str, Any] = { - "multi_modal_data": multi_modal_data or None - } - if isinstance(prompt, str): - text_prompt_kwargs["prompt"] = prompt - elif isinstance(prompt, list): - text_prompt_kwargs["prompt_token_ids"] = prompt - else: - text_prompt_kwargs["prompt_embeds"] = prompt + if multi_modal_data: + prompt_dict["multi_modal_data"] = multi_modal_data - inputs.append(TextPrompt(**text_prompt_kwargs)) + inputs.append(prompt_dict) return inputs @@ -801,14 +814,11 @@ def generate( audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - req_outputs = self.llm.generate(inputs, - 
sampling_params=sampling_params, - **kwargs) + req_outputs = self.llm.generate( + inputs, sampling_params=sampling_params, **kwargs + ) outputs: list[tuple[list[list[int]], list[str]]] = [] for req_output in req_outputs: @@ -835,8 +845,9 @@ def _final_steps_generate_w_logprobs( output_str = sample.text output_ids = list(sample.token_ids) output_logprobs = sample.logprobs - outputs.append((output_ids, output_str, output_logprobs, - req_output.prompt_logprobs)) + outputs.append( + (output_ids, output_str, output_logprobs, req_output.prompt_logprobs) + ) return outputs def generate_w_logprobs( @@ -847,23 +858,22 @@ def generate_w_logprobs( audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) - - req_outputs = self.llm.generate(inputs, - sampling_params=sampling_params, - **kwargs) - - toks_str_logsprobs_prompt_logprobs = ( - self._final_steps_generate_w_logprobs(req_outputs)) + ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + req_outputs = self.llm.generate( + inputs, sampling_params=sampling_params, **kwargs + ) + + toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs( + req_outputs + ) # Omit prompt logprobs if not required by sampling params - return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] - if sampling_params.prompt_logprobs is None else - toks_str_logsprobs_prompt_logprobs) + return ( + [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None + else toks_str_logsprobs_prompt_logprobs + ) def generate_greedy( self, @@ -875,14 +885,15 @@ def generate_greedy( **kwargs: Any, ) -> list[tuple[list[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, - greedy_params, - images=images, - videos=videos, - audios=audios, - **kwargs) - return [(output_ids[0], output_str[0]) - for output_ids, output_str in outputs] + outputs = self.generate( + prompts, + greedy_params, + images=images, + videos=videos, + audios=audios, + **kwargs, + ) + return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] def generate_greedy_logprobs( self, @@ -896,22 +907,24 @@ def generate_greedy_logprobs( stop_token_ids: Optional[list[int]] = None, stop: Optional[list[str]] = None, **kwargs: Any, - ) -> Union[list[TokensTextLogprobs], - list[TokensTextLogprobsPromptLogprobs]]: + ) -> Union[list[TokensTextLogprobs], list[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs, prompt_logprobs=num_prompt_logprobs, stop_token_ids=stop_token_ids, - stop=stop) + stop=stop, + ) - return self.generate_w_logprobs(prompts, - greedy_logprobs_params, - images=images, - audios=audios, - videos=videos, - **kwargs) + return self.generate_w_logprobs( + prompts, + greedy_logprobs_params, + images=images, + audios=audios, + videos=videos, + **kwargs, + ) def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: """ @@ -920,10 +933,9 @@ def generate_prompt_perplexity(self, prompts: list[str]) -> list[float]: :param prompts: list of prompts to score :return: perplexity score of each prompt """ - outputs = self.generate_greedy_logprobs(prompts, - max_tokens=1, - 
num_logprobs=None, - num_prompt_logprobs=0) + outputs = self.generate_greedy_logprobs( + prompts, max_tokens=1, num_logprobs=None, num_prompt_logprobs=0 + ) perplexities = [] for output in outputs: @@ -952,15 +964,13 @@ def generate_beam_search( audios: Optional[PromptAudioInput] = None, concurrency_limit: Optional[int] = None, ) -> list[tuple[list[list[int]], list[str]]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) - - outputs = self.llm.beam_search(inputs, - BeamSearchParams(beam_width=beam_width, - max_tokens=max_tokens), - concurrency_limit=concurrency_limit) + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) + + outputs = self.llm.beam_search( + inputs, + BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens), + concurrency_limit=concurrency_limit, + ) returned_outputs = [] for output in outputs: token_ids = [x.tokens for x in output.sequences] @@ -972,17 +982,16 @@ def classify(self, prompts: list[str]) -> list[list[float]]: req_outputs = self.llm.classify(prompts) return [req_output.outputs.probs for req_output in req_outputs] - def embed(self, - prompts: list[str], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, - *args, - **kwargs) -> list[list[float]]: - inputs = self.get_inputs(prompts, - images=images, - videos=videos, - audios=audios) + def embed( + self, + prompts: list[str], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + *args, + **kwargs, + ) -> list[list[float]]: + inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) req_outputs = self.llm.embed(inputs, *args, **kwargs) return [req_output.outputs.embedding for req_output in req_outputs] @@ -1027,6 +1036,7 @@ def vllm_runner(): @pytest.fixture() def temporary_enable_log_propagate(): import logging + logger = logging.getLogger("vllm") logger.propagate = True yield @@ -1046,6 +1056,7 @@ def num_gpus_available(): in current process.""" from vllm.platforms import current_platform + return current_platform.device_count() @@ -1059,12 +1070,11 @@ def num_gpus_available(): def dummy_opt_path(): json_path = os.path.join(_dummy_opt_path, "config.json") if not os.path.exists(_dummy_opt_path): - snapshot_download(repo_id="facebook/opt-125m", - local_dir=_dummy_opt_path, - ignore_patterns=[ - "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack" - ]) + snapshot_download( + repo_id="facebook/opt-125m", + local_dir=_dummy_opt_path, + ignore_patterns=["*.bin", "*.bin.index.json", "*.pt", "*.h5", "*.msgpack"], + ) assert os.path.exists(json_path) with open(json_path) as f: config = json.load(f) @@ -1078,12 +1088,18 @@ def dummy_opt_path(): def dummy_llava_path(): json_path = os.path.join(_dummy_llava_path, "config.json") if not os.path.exists(_dummy_llava_path): - snapshot_download(repo_id="llava-hf/llava-1.5-7b-hf", - local_dir=_dummy_llava_path, - ignore_patterns=[ - "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack", "*.safetensors" - ]) + snapshot_download( + repo_id="llava-hf/llava-1.5-7b-hf", + local_dir=_dummy_llava_path, + ignore_patterns=[ + "*.bin", + "*.bin.index.json", + "*.pt", + "*.h5", + "*.msgpack", + "*.safetensors", + ], + ) assert os.path.exists(json_path) with open(json_path) as f: config = json.load(f) @@ -1097,12 +1113,18 @@ def dummy_llava_path(): def dummy_gemma2_embedding_path(): json_path = 
os.path.join(_dummy_gemma2_embedding_path, "config.json") if not os.path.exists(_dummy_gemma2_embedding_path): - snapshot_download(repo_id="BAAI/bge-multilingual-gemma2", - local_dir=_dummy_gemma2_embedding_path, - ignore_patterns=[ - "*.bin", "*.bin.index.json", "*.pt", "*.h5", - "*.msgpack", "*.safetensors" - ]) + snapshot_download( + repo_id="BAAI/bge-multilingual-gemma2", + local_dir=_dummy_gemma2_embedding_path, + ignore_patterns=[ + "*.bin", + "*.bin.index.json", + "*.pt", + "*.h5", + "*.msgpack", + "*.safetensors", + ], + ) assert os.path.exists(json_path) with open(json_path) as f: config = json.load(f) @@ -1115,10 +1137,9 @@ def dummy_gemma2_embedding_path(): # Add the flag `--optional` to allow run tests # that are marked with @pytest.mark.optional def pytest_addoption(parser): - parser.addoption("--optional", - action="store_true", - default=False, - help="run optional test") + parser.addoption( + "--optional", action="store_true", default=False, help="run optional test" + ) def pytest_collection_modifyitems(config, items): @@ -1186,7 +1207,6 @@ def _find_free_port() -> int: class LocalAssetServer: - address: str port: int server: Optional[http.server.ThreadingHTTPServer] @@ -1201,9 +1221,9 @@ def __init__(self, address: str = "127.0.0.1") -> None: def __enter__(self): self.port = _find_free_port() self.server = http.server.ThreadingHTTPServer( - (self.address, self.port), AssetHandler) - self.thread = threading.Thread(target=self.server.serve_forever, - daemon=True) + (self.address, self.port), AssetHandler + ) + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.thread.start() return self @@ -1237,7 +1257,7 @@ def get_image_asset(self, name: str) -> Image.Image: @pytest.fixture(scope="session") def local_asset_server() -> Generator[LocalAssetServer, None, None]: """ - Starts a thread based HTTP server bound to 127.0.0.1 on a random free port. + Starts a thread based HTTP server bound to 127.0.0.1 on a random free port. 
The server currently servers images at: http://127.0.0.1:<port>/<name>.<ext> """ diff --git a/tests/cuda/test_cuda_context.py b/tests/cuda/test_cuda_context.py index f973b284b87e..6336f2112c66 100644 --- a/tests/cuda/test_cuda_context.py +++ b/tests/cuda/test_cuda_context.py @@ -13,7 +13,7 @@ def check_cuda_context(): """Check CUDA driver context status""" try: - cuda = ctypes.CDLL('libcuda.so') + cuda = ctypes.CDLL("libcuda.so") device = ctypes.c_int() result = cuda.cuCtxGetDevice(ctypes.byref(device)) return (True, device.value) if result == 0 else (False, None) @@ -27,9 +27,11 @@ def run_cuda_test_in_thread(device_input, expected_device_id): # New thread should have no CUDA context initially valid_before, device_before = check_cuda_context() if valid_before: - return False, \ - "CUDA context should not exist in new thread, " \ - f"got device {device_before}" + return ( + False, + "CUDA context should not exist in new thread, " + f"got device {device_before}", + ) # Test setting CUDA context current_platform.set_device(device_input) @@ -39,8 +41,7 @@ def run_cuda_test_in_thread(device_input, expected_device_id): if not valid_after: return False, "CUDA context should be valid after set_cuda_context" if device_id != expected_device_id: - return False, \ - f"Expected device {expected_device_id}, got {device_id}" + return False, f"Expected device {expected_device_id}, got {device_id}" return True, "Success" except Exception as e: @@ -50,30 +51,30 @@ def run_cuda_test_in_thread(device_input, expected_device_id): class TestSetCudaContext: """Test suite for the set_cuda_context function.""" - @pytest.mark.skipif(not current_platform.is_cuda(), - reason="CUDA not available") - @pytest.mark.parametrize(argnames="device_input,expected_device_id", - argvalues=[ - (0, 0), - (torch.device('cuda:0'), 0), - ('cuda:0', 0), - ], - ids=["int", "torch_device", "string"]) - def test_set_cuda_context_parametrized(self, device_input, - expected_device_id): + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") + @pytest.mark.parametrize( + argnames="device_input,expected_device_id", + argvalues=[ + (0, 0), + (torch.device("cuda:0"), 0), + ("cuda:0", 0), + ], + ids=["int", "torch_device", "string"], + ) + def test_set_cuda_context_parametrized(self, device_input, expected_device_id): """Test setting CUDA context in isolated threads.""" with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(run_cuda_test_in_thread, device_input, - expected_device_id) + future = executor.submit( + run_cuda_test_in_thread, device_input, expected_device_id + ) success, message = future.result(timeout=30) assert success, message - @pytest.mark.skipif(not current_platform.is_cuda(), - reason="CUDA not available") + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") def test_set_cuda_context_invalid_device_type(self): """Test error handling for invalid device type.""" with pytest.raises(ValueError, match="Expected a cuda device"): - current_platform.set_device(torch.device('cpu')) + current_platform.set_device(torch.device("cpu")) if __name__ == "__main__": diff --git a/tests/detokenizer/test_disable_detokenization.py b/tests/detokenizer/test_disable_detokenization.py index ae06a985c7ec..a77626df5dc7 100644 --- a/tests/detokenizer/test_disable_detokenization.py +++ b/tests/detokenizer/test_disable_detokenization.py @@ -17,20 +17,16 @@ def test_computed_prefix_blocks(model: str): prompt = ( "You are a helpful assistant. 
How do I build a car from cardboard and " "paper clips? Is there an easy to follow video tutorial available " - "online for free?") + "online for free?" + ) llm = LLM(model=model) - sampling_params = SamplingParams(max_tokens=10, - temperature=0.0, - detokenize=False) + sampling_params = SamplingParams(max_tokens=10, temperature=0.0, detokenize=False) - outputs_no_detokenization = llm.generate(prompt, - sampling_params)[0].outputs[0] + outputs_no_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0] sampling_params.detokenize = True - outputs_with_detokenization = llm.generate(prompt, - sampling_params)[0].outputs[0] + outputs_with_detokenization = llm.generate(prompt, sampling_params)[0].outputs[0] - assert outputs_no_detokenization.text == '' - assert outputs_with_detokenization.text != '' - assert outputs_no_detokenization.token_ids == \ - outputs_with_detokenization.token_ids + assert outputs_no_detokenization.text == "" + assert outputs_with_detokenization.text != "" + assert outputs_no_detokenization.token_ids == outputs_with_detokenization.token_ids diff --git a/tests/detokenizer/test_min_tokens.py b/tests/detokenizer/test_min_tokens.py index 26003373c569..1f8e944695bd 100644 --- a/tests/detokenizer/test_min_tokens.py +++ b/tests/detokenizer/test_min_tokens.py @@ -8,15 +8,17 @@ from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer -PROMPT = "Hello, my name is Lee, and I'm a student in the " + \ - "college of engineering" +PROMPT = "Hello, my name is Lee, and I'm a student in the " + "college of engineering" -@pytest.mark.parametrize("min_tokens,stop,truth", [ - (0, None, " is Lee, and I'm a student in the college of engineering"), - (0, "e", " is L"), - (5, "e", " is Lee, and I'm a stud"), -]) +@pytest.mark.parametrize( + "min_tokens,stop,truth", + [ + (0, None, " is Lee, and I'm a student in the college of engineering"), + (0, "e", " is L"), + (5, "e", " is Lee, and I'm a stud"), + ], +) def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str): """Test for a specific min_tokens and stop. 
@@ -31,16 +33,18 @@ def test_min_tokens_with_stop(min_tokens: int, stop: str, truth: str): stop=stop, min_tokens=min_tokens, ) - request = EngineCoreRequest(request_id="", - prompt_token_ids=prompt_token_ids, - mm_features=None, - sampling_params=params, - pooling_params=None, - eos_token_id=None, - arrival_time=0.0, - lora_request=None, - cache_salt=None, - data_parallel_rank=None) + request = EngineCoreRequest( + request_id="", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + ) detokenizer = FastIncrementalDetokenizer(tokenizer, request) diff --git a/tests/detokenizer/test_stop_reason.py b/tests/detokenizer/test_stop_reason.py index 1ff679789c95..6565949cc50f 100644 --- a/tests/detokenizer/test_stop_reason.py +++ b/tests/detokenizer/test_stop_reason.py @@ -31,34 +31,39 @@ def test_stop_reason(vllm_model, example_prompts): llm = vllm_model.llm # test stop token - outputs = llm.generate(example_prompts, - sampling_params=SamplingParams( - ignore_eos=True, - seed=SEED, - max_tokens=MAX_TOKENS, - stop_token_ids=[stop_token_id])) + outputs = llm.generate( + example_prompts, + sampling_params=SamplingParams( + ignore_eos=True, + seed=SEED, + max_tokens=MAX_TOKENS, + stop_token_ids=[stop_token_id], + ), + ) for output in outputs: output = output.outputs[0] assert output.finish_reason == "stop" assert output.stop_reason == stop_token_id # test stop string - outputs = llm.generate(example_prompts, - sampling_params=SamplingParams( - ignore_eos=True, - seed=SEED, - max_tokens=MAX_TOKENS, - stop=".")) + outputs = llm.generate( + example_prompts, + sampling_params=SamplingParams( + ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop="." + ), + ) for output in outputs: output = output.outputs[0] assert output.finish_reason == "stop" assert output.stop_reason == STOP_STR # test EOS token - outputs = llm.generate(example_prompts, - sampling_params=SamplingParams( - seed=SEED, max_tokens=MAX_TOKENS)) + outputs = llm.generate( + example_prompts, + sampling_params=SamplingParams(seed=SEED, max_tokens=MAX_TOKENS), + ) for output in outputs: output = output.outputs[0] assert output.finish_reason == "length" or ( - output.finish_reason == "stop" and output.stop_reason is None) + output.finish_reason == "stop" and output.stop_reason is None + ) diff --git a/tests/detokenizer/test_stop_string_while_stop_model_terminates.py b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py index 9b32a2927f2d..5624332ef71d 100644 --- a/tests/detokenizer/test_stop_string_while_stop_model_terminates.py +++ b/tests/detokenizer/test_stop_string_while_stop_model_terminates.py @@ -14,7 +14,6 @@ def include_stop_str_in_output(request): class _DummyDetokenizer(BaseIncrementalDetokenizer): - def __init__(self, request: EngineCoreRequest): super().__init__(request) @@ -27,7 +26,8 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0): params = SamplingParams( stop=stop, include_stop_str_in_output=include_stop_str_in_output, - min_tokens=min_tokens) + min_tokens=min_tokens, + ) # Keep other fields minimal for unit test purposes. 
req = EngineCoreRequest( request_id="test", @@ -44,26 +44,25 @@ def _make_request(stop, include_stop_str_in_output: bool, min_tokens: int = 0): return req -def test_stop_string_while_stop_token_terminates( - include_stop_str_in_output: bool): +def test_stop_string_while_stop_token_terminates(include_stop_str_in_output: bool): """ This test verifies that the detokenizer correctly handles the case where the generated token sequence contains both: - a stop token - an <eos> token - + The detokenizer should respect the stop string and truncate the output accordingly. - + Imagine the following sequence: - "abcdeZ" is generated, where "Z" is the <eos> token. - "cd" is the stop string. - + If include_stop_str_in_output=False, the detokenizer should truncate the output to "ab" because the stop string "cd" is excluded. If include_stop_str_in_output=True, the detokenizer should include the stop string "cd" in the output, resulting in "abcd". - + This verifies the behavioral change introduced in BaseIncrementalDetokenizer where stop-string evaluation occurs before the early-return on @@ -78,8 +77,9 @@ def test_stop_string_while_stop_token_terminates( token_ids = [ord(c) for c in generated_text] # Create a request with the stop string and initialize the detokenizer. - req = _make_request(stop=[stop_string], - include_stop_str_in_output=include_stop_str_in_output) + req = _make_request( + stop=[stop_string], include_stop_str_in_output=include_stop_str_in_output + ) detok = _DummyDetokenizer(req) # Simulate that the last token ('Z') is a stop token (stop_terminated=True). @@ -99,5 +99,4 @@ def test_stop_string_while_stop_token_terminates( # get_next_output_text should return the full text when finished=True. # (Buffering only applies during streaming when finished=False.) 
- assert detok.get_next_output_text(finished=True, - delta=False) == expected_text + assert detok.get_next_output_text(finished=True, delta=False) == expected_text diff --git a/tests/detokenizer/test_stop_strings.py b/tests/detokenizer/test_stop_strings.py index 46f7d58c438c..70cc7e31b8ad 100644 --- a/tests/detokenizer/test_stop_strings.py +++ b/tests/detokenizer/test_stop_strings.py @@ -11,12 +11,14 @@ MAX_TOKENS = 200 -def _test_stopping(llm: LLM, - expected_output: str, - expected_reason: Any, - stop: Optional[list[str]] = None, - stop_token_ids: Optional[list[int]] = None, - include_in_output: bool = False) -> None: +def _test_stopping( + llm: LLM, + expected_output: str, + expected_reason: Any, + stop: Optional[list[str]] = None, + stop_token_ids: Optional[list[int]] = None, + include_in_output: bool = False, +) -> None: output = llm.generate( "A story about vLLM:\n", SamplingParams( @@ -25,7 +27,8 @@ def _test_stopping(llm: LLM, stop=stop, stop_token_ids=stop_token_ids, include_stop_str_in_output=include_in_output, - ))[0].outputs[0] + ), + )[0].outputs[0] assert output is not None assert output.text == expected_output @@ -33,17 +36,21 @@ def _test_stopping(llm: LLM, def _stop_basic(llm): - _test_stopping(llm, - stop=["."], - include_in_output=False, - expected_output="VLLM is a 100% volunteer organization", - expected_reason=".") + _test_stopping( + llm, + stop=["."], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=".", + ) - _test_stopping(llm, - stop=["."], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organization.", - expected_reason=".") + _test_stopping( + llm, + stop=["."], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization.", + expected_reason=".", + ) def _stop_multi_tokens(llm): @@ -52,45 +59,54 @@ def _stop_multi_tokens(llm): stop=["group of peo", "short"], include_in_output=False, expected_output="VLLM is a 100% volunteer organization. We are a ", - expected_reason="group of peo") + expected_reason="group of peo", + ) _test_stopping( llm, stop=["group of peo", "short"], include_in_output=True, - expected_output= - "VLLM is a 100% volunteer organization. We are a group of peo", - expected_reason="group of peo") + expected_output="VLLM is a 100% volunteer organization. 
We are a group of peo", + expected_reason="group of peo", + ) def _stop_partial_token(llm): - _test_stopping(llm, - stop=["gani"], - include_in_output=False, - expected_output="VLLM is a 100% volunteer or", - expected_reason="gani") + _test_stopping( + llm, + stop=["gani"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer or", + expected_reason="gani", + ) - _test_stopping(llm, - stop=["gani"], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organi", - expected_reason="gani") + _test_stopping( + llm, + stop=["gani"], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organi", + expected_reason="gani", + ) def _stop_token_id(llm): # token id 13013 => " organization" - _test_stopping(llm, - stop_token_ids=[13013], - include_in_output=False, - expected_output="VLLM is a 100% volunteer", - expected_reason=13013) - - _test_stopping(llm, - stop_token_ids=[13013], - include_in_output=True, - expected_output="VLLM is a 100% volunteer organization", - expected_reason=13013) + _test_stopping( + llm, + stop_token_ids=[13013], + include_in_output=False, + expected_output="VLLM is a 100% volunteer", + expected_reason=13013, + ) + + _test_stopping( + llm, + stop_token_ids=[13013], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=13013, + ) @pytest.mark.skip_global_cleanup diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index 7dc4a0cc3d58..47ceb45057c9 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -111,8 +111,7 @@ def __init__( self.last_seq = -1 self.decoder = msgspec.msgpack.Decoder(type=decode_type) - def receive_one(self, - timeout=1000) -> Union[tuple[int, SampleBatch], None]: + def receive_one(self, timeout=1000) -> Union[tuple[int, SampleBatch], None]: """Receive a single message with timeout""" if not self.sub.poll(timeout): return None @@ -135,8 +134,7 @@ def request_replay(self, start_seq: int, socket_idx: int = 0) -> None: self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big")) - def receive_replay(self, - socket_idx: int = 0) -> list[tuple[int, SampleBatch]]: + def receive_replay(self, socket_idx: int = 0) -> list[tuple[int, SampleBatch]]: """Receive replayed messages from a specific replay socket""" if not self.replay_sockets: raise ValueError("Replay sockets not initialized") diff --git a/tests/distributed/test_ca_buffer_sharing.py b/tests/distributed/test_ca_buffer_sharing.py index e2de462612b4..1ddce64f8e61 100644 --- a/tests/distributed/test_ca_buffer_sharing.py +++ b/tests/distributed/test_ca_buffer_sharing.py @@ -12,7 +12,8 @@ from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa - CustomAllreduce) + CustomAllreduce, +) # create a cpu process group for communicating metadata (ipc handle) dist.init_process_group(backend="gloo") @@ -52,7 +53,8 @@ assert ord(host_data[i]) == byte_value, ( f"Rank {rank} failed" f" to verify buffer {p}. 
Expected {byte_value}, " - f"got {ord(host_data[i])}") + f"got {ord(host_data[i])}" + ) print(f"Rank {rank} verified all buffers") diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 8d84cc2d0ffe..c61c4584d837 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -13,13 +13,19 @@ import ray import torch -from vllm.distributed import (broadcast_tensor_dict, get_pp_group, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, - tensor_model_parallel_reduce_scatter) +from vllm.distributed import ( + broadcast_tensor_dict, + get_pp_group, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter, +) -from ..utils import (init_test_distributed_environment, multi_gpu_test, - multi_process_parallel) +from ..utils import ( + init_test_distributed_environment, + multi_gpu_test, + multi_process_parallel, +) @ray.remote(num_gpus=1, max_calls=1) @@ -37,12 +43,11 @@ def all_reduce_test_worker( device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_elements = 8 all_tensors = [ - torch.arange(num_elements, dtype=torch.float32, device="cuda") * - (r + 1) for r in range(tp_size) + torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1) + for r in range(tp_size) ] expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) t = all_tensors[rank % tp_size] @@ -51,28 +56,31 @@ def all_reduce_test_worker( @ray.remote(num_gpus=1, max_calls=1) -def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int, - pp_size: int, rank: int, - distributed_init_port: str): +def reduce_scatter_test_worker( + monkeypatch: pytest.MonkeyPatch, + tp_size: int, + pp_size: int, + rank: int, + distributed_init_port: str, +): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs # they will be able to set the device to the correct GPU monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_elements = 8 all_tensors = [ - torch.arange(num_elements, dtype=torch.float32, device="cuda") * - (r + 1) for r in range(tp_size) + torch.arange(num_elements, dtype=torch.float32, device="cuda") * (r + 1) + for r in range(tp_size) ] index = rank % tp_size partition_size = num_elements // tp_size all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0) - expected = all_reduce[index * partition_size:(index + 1) * partition_size] + expected = all_reduce[index * partition_size : (index + 1) * partition_size] t = all_tensors[index] t = tensor_model_parallel_reduce_scatter(t, 0) torch.testing.assert_close(t, expected) @@ -92,8 +100,7 @@ def all_gather_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_dimensions = 3 tensor_size = list(range(2, num_dimensions + 2)) total_size = 1 @@ -101,8 +108,10 @@ def all_gather_test_worker( 
total_size *= s for all_gather_dimension in range(num_dimensions): all_tensors = [ - torch.arange(total_size, dtype=torch.float32, - device="cuda").reshape(tensor_size) * (r + 1) + torch.arange(total_size, dtype=torch.float32, device="cuda").reshape( + tensor_size + ) + * (r + 1) for r in range(tp_size) ] expected = torch.cat(all_tensors, dim=all_gather_dimension) @@ -125,8 +134,7 @@ def broadcast_tensor_dict_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { # device tensor "a": torch.arange(8, dtype=torch.float32, device="cuda"), @@ -134,10 +142,7 @@ def broadcast_tensor_dict_test_worker( "b": torch.arange(16, dtype=torch.int8, device="cpu"), "c": "test", "d": [1, 2, 3], - "e": { - "a": 1, - "b": 2 - }, + "e": {"a": 1, "b": 2}, # empty tensor "f": torch.tensor([], dtype=torch.float32, device="cuda"), } @@ -166,8 +171,7 @@ def send_recv_tensor_dict_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { # device tensor @@ -176,10 +180,7 @@ def send_recv_tensor_dict_test_worker( "b": torch.arange(16, dtype=torch.int8, device="cpu"), "c": "test", "d": [1, 2, 3], - "e": { - "a": 1, - "b": 2 - }, + "e": {"a": 1, "b": 2}, # empty tensor "f": torch.tensor([], dtype=torch.float32, device="cuda"), } @@ -211,8 +212,7 @@ def send_recv_test_worker( monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) size = 64 test_tensor = torch.arange(64, dtype=torch.float32, device="cuda") @@ -229,10 +229,10 @@ def send_recv_test_worker( @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("test_target", [ - all_reduce_test_worker, all_gather_test_worker, - broadcast_tensor_dict_test_worker -]) +@pytest.mark.parametrize( + "test_target", + [all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker], +) def test_multi_process_tensor_parallel( monkeypatch: pytest.MonkeyPatch, tp_size: int, @@ -244,7 +244,8 @@ def test_multi_process_tensor_parallel( @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("pp_size", [2]) @pytest.mark.parametrize( - "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker]) + "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker] +) def test_multi_process_pipeline_parallel( monkeypatch: pytest.MonkeyPatch, pp_size: int, @@ -256,11 +257,16 @@ def test_multi_process_pipeline_parallel( @multi_gpu_test(num_gpus=4) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pp_size", [2]) -@pytest.mark.parametrize("test_target", [ - send_recv_test_worker, send_recv_tensor_dict_test_worker, - all_reduce_test_worker, all_gather_test_worker, - broadcast_tensor_dict_test_worker -]) +@pytest.mark.parametrize( + "test_target", + [ + send_recv_test_worker, + send_recv_tensor_dict_test_worker, + all_reduce_test_worker, + 
all_gather_test_worker, + broadcast_tensor_dict_test_worker, + ], +) def test_multi_process_tensor_parallel_pipeline_parallel( tp_size: int, pp_size: int, diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 11685bc90c41..c8b6dc9781df 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -7,6 +7,7 @@ all workers in a node other than the head node, which can cause the test to fail. """ + import json import os from dataclasses import dataclass @@ -56,7 +57,8 @@ def __post_init__(self): raise ValueError( f"Length mismatch: distributed_backends " f"({len(self.distributed_backends)}) != " - f"vllm_major_versions ({len(self.vllm_major_versions)})") + f"vllm_major_versions ({len(self.vllm_major_versions)})" + ) @staticmethod def detailed( @@ -74,29 +76,39 @@ def detailed( for dcp_multiplier in [0.5, 1]: for chunked_prefill_val in [True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_multiplier * pp_base, - dcp_size=int(dcp_multiplier * - tp_base), - eager_mode=eager_mode_val, - chunked_prefill=chunked_prefill_val)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + dcp_size=int(dcp_multiplier * tp_base), + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val, + ) + ) return CPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp"], vllm_major_versions=["1"], runner=runner, - test_options=CPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=CPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) def iter_params(self, model_id: str): opts = self.test_options for parallel_setup in self.parallel_setups: - for backend, vllm_major_version in zip(self.distributed_backends, - self.vllm_major_versions): - yield (model_id, parallel_setup, backend, vllm_major_version, - self.runner, opts) + for backend, vllm_major_version in zip( + self.distributed_backends, self.vllm_major_versions + ): + yield ( + model_id, + parallel_setup, + backend, + vllm_major_version, + self.runner, + opts, + ) def _compare_cp_with_tp( @@ -148,8 +160,10 @@ def _compare_cp_with_tp( if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": - pytest.skip("Skipping multi-node pipeline parallel test for " - "multiprocessing distributed backend") + pytest.skip( + "Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend" + ) if multi_node_only and not VLLM_MULTI_NODE: pytest.skip("Not in multi-node setting") @@ -178,8 +192,7 @@ def _compare_cp_with_tp( common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) cp_env = tp_env = { - "VLLM_USE_V1": - vllm_major_version, # Note(hc): DCP only support V1 engine only + "VLLM_USE_V1": vllm_major_version, # Note(hc): DCP only support V1 engine only } cp_args = [ @@ -205,13 +218,15 @@ def _compare_cp_with_tp( ] try: - compare_two_settings(model_id, - cp_args, - tp_args, - cp_env, - tp_env, - method=method, - max_wait_seconds=720) + compare_two_settings( + model_id, + cp_args, + tp_args, + cp_env, + tp_env, + method=method, + max_wait_seconds=720, + ) except Exception: testing_ray_compiled_graph = cp_env is not None if testing_ray_compiled_graph and vllm_major_version == "0": @@ -224,9 +239,10 @@ def _compare_cp_with_tp( CP_TEXT_GENERATION_MODELS = { # [MLA attention only] - "deepseek-ai/DeepSeek-V2-Lite-Chat": - 
[CPTestSettings.detailed(), - CPTestSettings.detailed(tp_base=2)], + "deepseek-ai/DeepSeek-V2-Lite-Chat": [ + CPTestSettings.detailed(), + CPTestSettings.detailed(tp_base=2), + ], } CP_TEST_MODELS = [ @@ -237,11 +253,19 @@ def _compare_cp_with_tp( @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "runner", "test_options"), + ( + "model_id", + "parallel_setup", + "distributed_backend", + "vllm_major_version", + "runner", + "test_options", + ), [ - params for model_id, settings in CP_TEXT_GENERATION_MODELS.items() - for setting in settings for params in setting.iter_params(model_id) + params + for model_id, settings in CP_TEXT_GENERATION_MODELS.items() + for setting in settings + for params in setting.iter_params(model_id) if model_id in CP_TEST_MODELS ], ) @@ -255,12 +279,14 @@ def test_cp_generation( test_options: CPTestOptions, num_gpus_available, ): - _compare_cp_with_tp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - runner, - test_options, - num_gpus_available, - method="generate", - is_multimodal=False) + _compare_cp_with_tp( + model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False, + ) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 9212c04deec9..f6e274be9384 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -8,12 +8,14 @@ import torch import torch.distributed as dist -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.parallel_state import get_tp_group, graph_capture -from ..utils import (ensure_model_parallel_initialized, - init_test_distributed_environment, multi_process_parallel) +from ..utils import ( + ensure_model_parallel_initialized, + init_test_distributed_environment, + multi_process_parallel, +) random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] @@ -33,8 +35,7 @@ def graph_allreduce( m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) ensure_model_parallel_initialized(tp_size, pp_size) group = get_tp_group().device_group @@ -60,18 +61,15 @@ def graph_allreduce( for dtype in [torch.float32, torch.float16, torch.bfloat16]: with graph_capture(device=device) as graph_capture_context: # use integers so result matches NCCL exactly - inp1 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - inp2 = torch.randint(1, - 16, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) + inp1 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + inp2 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, - stream=graph_capture_context.stream): + with torch.cuda.graph(graph, stream=graph_capture_context.stream): for i in range(num_communication): out1 = tensor_model_parallel_all_reduce(inp1) # the input buffer is immediately modified to test @@ -96,8 +94,7 @@ def eager_allreduce( 
m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) # we use the first group to communicate once # and the second group to communicate twice @@ -132,5 +129,4 @@ def test_custom_allreduce( world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") - multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, - test_target) + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) diff --git a/tests/distributed/test_distributed_oot.py b/tests/distributed/test_distributed_oot.py index b93696e4be0e..ea7a88abda24 100644 --- a/tests/distributed/test_distributed_oot.py +++ b/tests/distributed/test_distributed_oot.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from ..entrypoints.openai.test_oot_registration import ( - run_and_test_dummy_opt_api_server) +from ..entrypoints.openai.test_oot_registration import run_and_test_dummy_opt_api_server def test_distributed_oot(dummy_opt_path: str): diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py index e47ccba99c81..79805a7cce53 100644 --- a/tests/distributed/test_eplb_algo.py +++ b/tests/distributed/test_eplb_algo.py @@ -10,10 +10,12 @@ def test_basic_rebalance(): """Test basic rebalancing functionality""" # Example from https://github.com/deepseek-ai/eplb - weight = torch.tensor([ - [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], - [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], - ]) + weight = torch.tensor( + [ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ] + ) num_layers = weight.shape[0] num_replicas = 16 @@ -21,45 +23,49 @@ def test_basic_rebalance(): num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify output shapes assert phy2log.shape == ( 2, 16, ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}" - assert (log2phy.shape[0] == 2 - ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" - assert ( - log2phy.shape[1] == 12 - ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" + assert log2phy.shape[0] == 2, ( + f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" + ) + assert log2phy.shape[1] == 12, ( + f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" + ) assert logcnt.shape == ( 2, 12, ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}" # Verify physical to logical expert mapping range is correct - assert torch.all(phy2log >= 0) and torch.all( - phy2log < 12), "Physical to logical mapping should be in range [0, 12)" + assert torch.all(phy2log >= 0) and torch.all(phy2log < 12), ( + "Physical to logical mapping should be in range [0, 12)" + ) # Verify expert count reasonableness - assert torch.all( - logcnt >= 1), "Each logical expert should have at least 1 replica" - assert ( - torch.sum(logcnt, dim=1).sum() == num_replicas * - num_layers), f"Total replicas should be {num_replicas * num_layers}" + assert torch.all(logcnt >= 1), "Each logical expert should 
have at least 1 replica" + assert torch.sum(logcnt, dim=1).sum() == num_replicas * num_layers, ( + f"Total replicas should be {num_replicas * num_layers}" + ) # Verify expected output - expected_phy2log = torch.tensor([ - [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], - [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], - ]) + expected_phy2log = torch.tensor( + [ + [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], + [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], + ] + ) assert torch.all(phy2log == expected_phy2log) - expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], - [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]) + expected_logcnt = torch.tensor( + [[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]] + ) assert torch.all(logcnt == expected_logcnt) @@ -71,9 +77,9 @@ def test_single_gpu_case(): num_nodes = 1 num_gpus = 1 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (1, 4) @@ -93,19 +99,19 @@ def test_equal_weights(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (1, 8) assert logcnt.shape == (1, 8) # With equal weights, each expert should have exactly one replica - assert torch.all( - logcnt == 1 - ), "With equal weights and no replication, " \ - "each expert should have exactly 1 replica" + assert torch.all(logcnt == 1), ( + "With equal weights and no replication, " + "each expert should have exactly 1 replica" + ) def test_extreme_weight_imbalance(): @@ -116,35 +122,37 @@ def test_extreme_weight_imbalance(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (1, 12) assert logcnt.shape == (1, 8) # Expert with highest weight (index 0) should have more replicas - assert ( - logcnt[0, 0] - > logcnt[0, 1]), "Expert with highest weight should have more replicas" + assert logcnt[0, 0] > logcnt[0, 1], ( + "Expert with highest weight should have more replicas" + ) def test_multiple_layers(): """Test multiple layers case""" - weight = torch.tensor([ - [10, 20, 30, 40, 50, 60], # First layer - [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) - [25, 25, 25, 25, 25, 25], # Third layer (equal weights) - ]) + weight = torch.tensor( + [ + [10, 20, 30, 40, 50, 60], # First layer + [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) + [25, 25, 25, 25, 25, 25], # Third layer (equal weights) + ] + ) num_replicas = 8 num_groups = 2 num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify shapes assert phy2log.shape == (3, 8) @@ -152,12 +160,12 @@ def test_multiple_layers(): # Verify expert allocation is reasonable for each layer for layer in range(3): - assert torch.all(phy2log[layer] >= 0) and torch.all( - phy2log[layer] < 6 - ), f"Layer {layer} physical to logical mapping" 
\ - "should be in range [0, 6)" - assert (torch.sum(logcnt[layer]) == num_replicas - ), f"Layer {layer} total replicas should be {num_replicas}" + assert torch.all(phy2log[layer] >= 0) and torch.all(phy2log[layer] < 6), ( + f"Layer {layer} physical to logical mappingshould be in range [0, 6)" + ) + assert torch.sum(logcnt[layer]) == num_replicas, ( + f"Layer {layer} total replicas should be {num_replicas}" + ) def test_parameter_validation(): @@ -179,17 +187,19 @@ def test_parameter_validation(): def test_small_scale_hierarchical(): """Test small-scale hierarchical load balancing""" - weight = torch.tensor([ - [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts - ]) + weight = torch.tensor( + [ + [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts + ] + ) num_replicas = 12 num_groups = 4 # 4 groups, 2 experts each num_nodes = 2 # 2 nodes num_gpus = 4 # 4 GPUs - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Verify basic constraints assert phy2log.shape == (1, 12) @@ -199,8 +209,9 @@ def test_small_scale_hierarchical(): # Expert with highest weight should have more replicas max_weight_expert = torch.argmax(weight[0]) - assert (logcnt[0, max_weight_expert] - >= 2), "Highest weight expert should have multiple replicas" + assert logcnt[0, max_weight_expert] >= 2, ( + "Highest weight expert should have multiple replicas" + ) def test_global_load_balance_fallback(): @@ -213,9 +224,9 @@ def test_global_load_balance_fallback(): num_nodes = 2 num_gpus = 4 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Should work normally, just using global load balancing strategy assert phy2log.shape == (1, 8) @@ -235,9 +246,9 @@ def test_device_compatibility(device): num_nodes = 1 num_gpus = 2 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) # Function will convert to CPU internally, but should handle different # device inputs normally @@ -250,7 +261,8 @@ def test_additional_cases(): # Test case 1: Large-scale distributed setup weight1 = torch.tensor( - [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]) + [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]] + ) phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) assert phy2log1.shape == (1, 24) @@ -258,10 +270,12 @@ def test_additional_cases(): assert torch.sum(logcnt1) == 24 # Test case 2: Different weight distributions - weight2 = torch.tensor([ - [200, 150, 100, 50, 25, 12], # Decreasing weights - [12, 25, 50, 100, 150, 200], # Increasing weights - ]) + weight2 = torch.tensor( + [ + [200, 150, 100, 50, 25, 12], # Decreasing weights + [12, 25, 50, 100, 150, 200], # Increasing weights + ] + ) phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) assert phy2log2.shape == (2, 10) @@ -274,19 +288,21 @@ def test_additional_cases(): if __name__ == "__main__": - weight = torch.tensor([ - [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], - [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], - ]) + weight = torch.tensor( + [ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 
197, 187, 157, 172, 86, 16, 27], + ] + ) num_replicas = 16 num_groups = 4 num_nodes = 2 num_gpus = 8 - phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, - num_groups, num_nodes, - num_gpus) + phy2log, log2phy, logcnt = rebalance_experts( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) print(phy2log) test_basic_rebalance() diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index de9ed1eabbac..7ca3d3d27b56 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -9,11 +9,12 @@ import torch import torch.distributed -from vllm.distributed.eplb.rebalance_execute import ( - rearrange_expert_weights_inplace) -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - get_tp_group, - init_distributed_environment) +from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + get_tp_group, + init_distributed_environment, +) from vllm.utils import update_environment_variables @@ -22,13 +23,13 @@ def distributed_run(fn, world_size): processes: list[multiprocessing.Process] = [] for i in range(number_of_processes): env: dict[str, str] = {} - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12345" + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -45,7 +46,7 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) - local_rank = os.environ['LOCAL_RANK'] + local_rank = os.environ["LOCAL_RANK"] device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) init_distributed_environment() @@ -60,20 +61,20 @@ def wrapped_fn(env): def create_expert_indices_with_redundancy( - num_layers: int, - num_logical_experts: int, - total_physical_experts: int, - redundancy_config: list[int], # redundancy for each logical expert + num_layers: int, + num_logical_experts: int, + total_physical_experts: int, + redundancy_config: list[int], # redundancy for each logical expert ) -> torch.Tensor: """ Create expert indices with redundancy. - + Args: num_layers: number of layers num_logical_experts: number of logical experts total_physical_experts: total number of physical experts redundancy_config: redundancy for each logical expert - + Returns: indices: Shape (num_layers, total_physical_experts) """ @@ -106,11 +107,11 @@ def create_expert_weights( ) -> list[list[torch.Tensor]]: """ Create fake expert weights tensor for testing. - + Use `arange` to generate predictable weights values, based on logical expert ID. All replicas of the same logical expert should have the same weights. 
- + Args: physical_to_logical_mapping: Shape (num_layers, num_local_experts) mapping[layer, physical_pos] = logical_expert_id @@ -120,27 +121,27 @@ def create_expert_weights( for layer in range(num_layers): layer_weights = [] for weight_idx, hidden_size in enumerate(hidden_sizes): - weight_tensor = torch.zeros(num_local_experts, - hidden_size, - device=device, - dtype=torch.float32) + weight_tensor = torch.zeros( + num_local_experts, hidden_size, device=device, dtype=torch.float32 + ) for local_expert in range(num_local_experts): # Get the logical expert ID for this physical expert global_pos = rank * num_local_experts + local_expert logical_expert_id = physical_to_logical_mapping[ - layer, global_pos].item() + layer, global_pos + ].item() # Generate weights based on logical expert ID # (so that all replicas of the same logical expert have the # same weights) - base_value = (logical_expert_id * 1000 + layer * 100 + - weight_idx * 10) - weight_tensor[local_expert] = torch.arange(base_value, - base_value + - hidden_size, - device=device, - dtype=torch.float32) + base_value = logical_expert_id * 1000 + layer * 100 + weight_idx * 10 + weight_tensor[local_expert] = torch.arange( + base_value, + base_value + hidden_size, + device=device, + dtype=torch.float32, + ) layer_weights.append(weight_tensor) expert_weights.append(layer_weights) @@ -182,12 +183,15 @@ def verify_expert_weights_after_shuffle( # Check if the weights are correct actual_weights = weight_tensor[local_expert] - expected_base = (expected_logical_expert * 1000 + layer * 100 + - weight_idx * 10) - expected_weights = torch.arange(expected_base, - expected_base + hidden_size, - device=actual_weights.device, - dtype=actual_weights.dtype) + expected_base = ( + expected_logical_expert * 1000 + layer * 100 + weight_idx * 10 + ) + expected_weights = torch.arange( + expected_base, + expected_base + hidden_size, + device=actual_weights.device, + dtype=actual_weights.dtype, + ) torch.testing.assert_close( actual_weights, @@ -195,7 +199,8 @@ def verify_expert_weights_after_shuffle( msg=f"Layer {layer}, weight {weight_idx}," f"local expert {local_expert}: " f"weights do not match. 
" - f"Expected logical expert {expected_logical_expert}") + f"Expected logical expert {expected_logical_expert}", + ) def verify_redundant_experts_have_same_weights( @@ -222,23 +227,23 @@ def verify_redundant_experts_have_same_weights( total_physical_experts, hidden_size, device=expert_weights[layer][weight_idx].device, - dtype=expert_weights[layer][weight_idx].dtype) + dtype=expert_weights[layer][weight_idx].dtype, + ) # Use all_gather to collect expert weights from current node # expert_weights[layer][weight_idx] shape: # [num_local_experts, hidden_size] local_weights = expert_weights[layer][ - weight_idx] # [num_local_experts, hidden_size] + weight_idx + ] # [num_local_experts, hidden_size] # Split tensor along dim 0 into a list for all_gather - gathered_weights_list = torch.chunk(gathered_weights, - world_size, - dim=0) + gathered_weights_list = torch.chunk(gathered_weights, world_size, dim=0) torch.distributed.all_gather( # Output list: each element corresponds to one rank's weights list(gathered_weights_list), - local_weights # Input: current rank's local weights + local_weights, # Input: current rank's local weights ) all_weights.append(gathered_weights) @@ -266,7 +271,8 @@ def verify_redundant_experts_have_same_weights( msg=f"Layer {layer}, weight {weight_idx}," f"logical expert {logical_expert_id}: " f"Physical expert {physical_pos} has different weights" - f"than expected") + f"than expected", + ) @pytest.mark.parametrize( @@ -290,10 +296,11 @@ def verify_redundant_experts_have_same_weights( # 4 GPU, 8 experts per GPU # 16 logical experts, 32 physical experts, 16 redundant experts (4, 8, 8, 16), - ]) -def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, - num_local_experts, - num_logical_experts): + ], +) +def test_rearrange_expert_weights_with_redundancy( + world_size, num_layers, num_local_experts, num_logical_experts +): """Test the functionality of rearranging expert weights with redundancy.""" if torch.cuda.device_count() < world_size: @@ -304,8 +311,8 @@ def worker_fn(): # Initialize model parallel (using tensor parallel as an entrypoint # to expert parallel) ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -316,8 +323,9 @@ def worker_fn(): hidden_sizes = [32, 64] # Two different weight matrices # Create old expert indices (with redundancy) - redundancy_config = create_redundancy_config(num_logical_experts, - total_physical_experts) + redundancy_config = create_redundancy_config( + num_logical_experts, total_physical_experts + ) old_indices = create_expert_indices_with_redundancy( num_layers, @@ -328,7 +336,8 @@ def worker_fn(): # Create new expert indices (with redundancy) new_redundancy_config = create_redundancy_config( - num_logical_experts, total_physical_experts) + num_logical_experts, total_physical_experts + ) new_indices = create_expert_indices_with_redundancy( num_layers, num_logical_experts, @@ -337,9 +346,9 @@ def worker_fn(): ) # Create expert weights - expert_weights = create_expert_weights(num_layers, num_local_experts, - hidden_sizes, ep_rank, device, - old_indices) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices + ) # Execute weight rearrangement rearrange_expert_weights_inplace( @@ -383,8 +392,8 @@ def test_rearrange_expert_weights_no_change(world_size): 
@worker_fn_wrapper def worker_fn(): ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -401,12 +410,12 @@ def worker_fn(): # Same indices - no change indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, - redundancy_config) + num_layers, num_logical_experts, total_physical_experts, redundancy_config + ) - expert_weights = create_expert_weights(num_layers, num_local_experts, - hidden_sizes, ep_rank, device, - indices) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, indices + ) # Save original weights original_weights = [] @@ -422,7 +431,8 @@ def worker_fn(): indices, # Same indices expert_weights, ep_group, - is_profile=False) + is_profile=False, + ) # Verify that the weights have not changed for layer in range(num_layers): @@ -430,8 +440,8 @@ def worker_fn(): torch.testing.assert_close( expert_weights[layer][weight_idx], original_weights[layer][weight_idx], - msg=f"Layer {layer}, weight {weight_idx} should remain " - f"unchanged") + msg=f"Layer {layer}, weight {weight_idx} should remain unchanged", + ) distributed_run(worker_fn, world_size) @@ -446,8 +456,8 @@ def test_rearrange_expert_weights_profile_mode(world_size): @worker_fn_wrapper def worker_fn(): ensure_model_parallel_initialized( - tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + tensor_model_parallel_size=world_size, pipeline_model_parallel_size=1 + ) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -460,21 +470,23 @@ def worker_fn(): hidden_sizes = [32] # Create different index distributions - old_redundancy = create_redundancy_config(num_logical_experts, - total_physical_experts) - new_redundancy = create_redundancy_config(num_logical_experts, - total_physical_experts) + old_redundancy = create_redundancy_config( + num_logical_experts, total_physical_experts + ) + new_redundancy = create_redundancy_config( + num_logical_experts, total_physical_experts + ) old_indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, - old_redundancy) + num_layers, num_logical_experts, total_physical_experts, old_redundancy + ) new_indices = create_expert_indices_with_redundancy( - num_layers, num_logical_experts, total_physical_experts, - new_redundancy) + num_layers, num_logical_experts, total_physical_experts, new_redundancy + ) - expert_weights = create_expert_weights(num_layers, num_local_experts, - hidden_sizes, ep_rank, device, - old_indices) + expert_weights = create_expert_weights( + num_layers, num_local_experts, hidden_sizes, ep_rank, device, old_indices + ) # Save original weights original_weights = [] @@ -490,7 +502,7 @@ def worker_fn(): new_indices, expert_weights, ep_group, - is_profile=True # Profile mode + is_profile=True, # Profile mode ) # In profile mode, the weights should remain unchanged @@ -499,6 +511,7 @@ def worker_fn(): torch.testing.assert_close( expert_weights[layer][weight_idx], original_weights[layer][weight_idx], - msg="In profile mode, the weights should remain unchanged") + msg="In profile mode, the weights should remain unchanged", + ) distributed_run(worker_fn, world_size) diff --git a/tests/distributed/test_events.py b/tests/distributed/test_events.py index 8be9ee0a1889..f06f6771a4a0 
100644 --- a/tests/distributed/test_events.py +++ b/tests/distributed/test_events.py @@ -6,24 +6,29 @@ import msgspec import pytest -from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory, - NullEventPublisher) +from vllm.distributed.kv_events import ( + EventBatch, + EventPublisherFactory, + NullEventPublisher, +) DP_RANK = 0 class EventSample( - msgspec.Struct, - tag=True, # type: ignore - array_like=True # type: ignore + msgspec.Struct, + tag=True, # type: ignore + array_like=True, # type: ignore ): """Test event for publisher testing""" + id: int value: str class SampleBatch(EventBatch): """Test event batch for publisher testing""" + events: list[EventSample] @@ -44,10 +49,8 @@ def test_basic_publishing(publisher, subscriber): seq, received = result assert seq == 0, "Sequence number mismatch" - assert received.ts == pytest.approx(test_batch.ts, - abs=0.1), ("Timestamp mismatch") - assert len(received.events) == len( - test_batch.events), ("Number of events mismatch") + assert received.ts == pytest.approx(test_batch.ts, abs=0.1), "Timestamp mismatch" + assert len(received.events) == len(test_batch.events), "Number of events mismatch" for i, event in enumerate(received.events): assert event.id == i, "Event id mismatch" @@ -88,9 +91,9 @@ def test_replay_mechanism(publisher, subscriber): assert len(replayed) > 0, "No replayed messages received" seqs = [seq for seq, _ in replayed] assert all(seq >= 10 for seq in seqs), "Replayed messages not in order" - assert seqs == list(range(min(seqs), - max(seqs) + - 1)), ("Replayed messages not consecutive") + assert seqs == list(range(min(seqs), max(seqs) + 1)), ( + "Replayed messages not consecutive" + ) def test_buffer_limit(publisher, subscriber, publisher_config): @@ -126,6 +129,7 @@ def test_topic_filtering(publisher_config): pub = EventPublisherFactory.create(publisher_config, DP_RANK) from .conftest import MockSubscriber + sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo") sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar") @@ -137,11 +141,13 @@ def test_topic_filtering(publisher_config): foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)] assert all(msg is not None for msg in foo_received), ( - "Subscriber with matching topic should receive messages") + "Subscriber with matching topic should receive messages" + ) bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)] assert all(msg is None for msg in bar_received), ( - "Subscriber with non-matching topic should receive no messages") + "Subscriber with non-matching topic should receive no messages" + ) finally: pub.shutdown() sub_foo.close() @@ -178,8 +184,7 @@ def publish_events(): publisher_thread.join() - assert len(received) >= num_batches * 0.9, ( - "We should have received most messages") + assert len(received) >= num_batches * 0.9, "We should have received most messages" seqs = [seq for seq, _ in received] assert sorted(seqs) == seqs, "Sequence numbers should be in order" @@ -209,13 +214,15 @@ def test_data_parallel_rank_tagging(publisher_config): # For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558 expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port expected_endpoint_1 = base_endpoint.replace( - ":5557", ":5558") # rank 1 gets port + 1 + ":5557", ":5558" + ) # rank 1 gets port + 1 else: # For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1 expected_endpoint_0 = base_endpoint # rank 0 gets base expected_endpoint_1 = base_endpoint + 
"_dp1" # rank 1 gets _dp1 from .conftest import MockSubscriber + sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic) sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic) @@ -241,15 +248,15 @@ def test_data_parallel_rank_tagging(publisher_config): # Verify DP rank tagging assert received_0.data_parallel_rank == 0, ( - f"Expected DP rank 0, got {received_0.data_parallel_rank}") + f"Expected DP rank 0, got {received_0.data_parallel_rank}" + ) assert received_1.data_parallel_rank == 1, ( - f"Expected DP rank 1, got {received_1.data_parallel_rank}") + f"Expected DP rank 1, got {received_1.data_parallel_rank}" + ) # Verify event content is correct - assert len( - received_0.events) == 2, "Wrong number of events from rank 0" - assert len( - received_1.events) == 3, "Wrong number of events from rank 1" + assert len(received_0.events) == 2, "Wrong number of events from rank 0" + assert len(received_1.events) == 3, "Wrong number of events from rank 1" finally: pub_0.shutdown() diff --git a/tests/distributed/test_expert_parallel.py b/tests/distributed/test_expert_parallel.py index f273f302e72e..94f0ece4971b 100644 --- a/tests/distributed/test_expert_parallel.py +++ b/tests/distributed/test_expert_parallel.py @@ -46,28 +46,24 @@ def detailed( ): return EPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=tp_base, - eager_mode=True, - chunked_prefill=False), - ParallelSetup(tp_size=2 * tp_base, - eager_mode=False, - chunked_prefill=True), - ParallelSetup(tp_size=2 * tp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=False), + ParallelSetup(tp_size=tp_base, eager_mode=False, chunked_prefill=True), + ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False), + ParallelSetup( + tp_size=2 * tp_base, eager_mode=False, chunked_prefill=True + ), + ParallelSetup( + tp_size=2 * tp_base, eager_mode=True, chunked_prefill=False + ), ], distributed_backends=["mp", "ray"], runner=runner, - test_options=EPTestOptions(trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, - load_format=load_format, - hf_overrides=hf_overrides), + test_options=EPTestOptions( + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides, + ), ) @staticmethod @@ -82,16 +78,16 @@ def fast( ): return EPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - eager_mode=True, - chunked_prefill=False), + ParallelSetup(tp_size=tp_base, eager_mode=True, chunked_prefill=False), ], distributed_backends=["mp"], runner=runner, - test_options=EPTestOptions(trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, - load_format=load_format, - hf_overrides=hf_overrides), + test_options=EPTestOptions( + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + load_format=load_format, + hf_overrides=hf_overrides, + ), ) def iter_params(self, model_name: str): @@ -99,17 +95,20 @@ def iter_params(self, model_name: str): for parallel_setup in self.parallel_setups: for distributed_backend in self.distributed_backends: - yield (model_name, parallel_setup, distributed_backend, - self.runner, opts) + yield ( + model_name, + parallel_setup, + distributed_backend, + self.runner, + opts, + ) # NOTE: You can adjust tp_base locally to fit the model in GPU # The 
values displayed here are only a rough indicator of the size of the model -# yapf: disable TEST_MODELS = { - "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast( - trust_remote_code=True), + "deepseek-ai/DeepSeek-V2-Lite-Chat": EPTestSettings.fast(trust_remote_code=True), "mistralai/Mixtral-8x7B-Instruct-v0.1": EPTestSettings.fast(tp_base=4), } @@ -191,22 +190,24 @@ def _compare_tp( ] try: - compare_two_settings(model_name, - ep_args, - tp_args, - ep_env, - tp_env, - method=method, - max_wait_seconds=360) + compare_two_settings( + model_name, + ep_args, + tp_args, + ep_env, + tp_env, + method=method, + max_wait_seconds=360, + ) except Exception: raise @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", "runner", - "test_options"), + ("model_name", "parallel_setup", "distributed_backend", "runner", "test_options"), [ - params for model_name, settings in TEST_MODELS.items() + params + for model_name, settings in TEST_MODELS.items() for params in settings.iter_params(model_name) ], ) @@ -219,10 +220,12 @@ def test_ep( test_options: EPTestOptions, num_gpus_available, ): - _compare_tp(model_name, - parallel_setup, - distributed_backend, - runner, - test_options, - num_gpus_available, - method="generate") + _compare_tp( + model_name, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="generate", + ) diff --git a/tests/distributed/test_expert_placement.py b/tests/distributed/test_expert_placement.py index a3b1b3193deb..cb9c8f507404 100644 --- a/tests/distributed/test_expert_placement.py +++ b/tests/distributed/test_expert_placement.py @@ -6,17 +6,13 @@ from vllm.model_executor.layers.fused_moe.layer import determine_expert_map -def verify_round_robin_pattern(expert_map, ep_rank, ep_size, - global_num_experts): +def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts): """Verify that the expert map follows the round_robin pattern.""" # Calculate expected local experts (supporting non-divisible cases) base_experts = global_num_experts // ep_size remainder = global_num_experts % ep_size - if ep_rank < remainder: - local_num_experts = base_experts + 1 - else: - local_num_experts = base_experts + local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts # Expected expert IDs for this rank in round_robin pattern # For non-divisible cases, ranks with extra experts start earlier @@ -30,24 +26,21 @@ def verify_round_robin_pattern(expert_map, ep_rank, ep_size, if global_expert_id in expected_expert_ids: local_expert_id = expert_map[global_expert_id] expected_local_id = expected_expert_ids.index(global_expert_id) - assert ( - local_expert_id == expected_local_id - ), f"Global expert {global_expert_id} should map to local expert " \ + assert local_expert_id == expected_local_id, ( + f"Global expert {global_expert_id} should map to local expert " f"{expected_local_id}, got {local_expert_id}" + ) else: - assert ( - expert_map[global_expert_id] == -1 - ), f"Global expert {global_expert_id} should not be mapped to " \ - f"this rank" + assert expert_map[global_expert_id] == -1, ( + f"Global expert {global_expert_id} should not be mapped to this rank" + ) # Verify that all local expert IDs are consecutive starting from 0 - local_expert_ids = [ - expert_map[global_id] for global_id in expected_expert_ids - ] + local_expert_ids = [expert_map[global_id] for global_id in expected_expert_ids] expected_local_ids = list(range(local_num_experts)) - assert ( - local_expert_ids == 
expected_local_ids - ), f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}" + assert local_expert_ids == expected_local_ids, ( + f"Expected local expert IDs {expected_local_ids}, got {local_expert_ids}" + ) @pytest.mark.parametrize("expert_placement_strategy", ["round_robin"]) @@ -78,8 +71,9 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size): for test_global_experts, test_ep_size in test_cases: # Ensure ep_size matches world_size - assert (test_ep_size == world_size - ), f"ep_size {test_ep_size} must equal world_size {world_size}" + assert test_ep_size == world_size, ( + f"ep_size {test_ep_size} must equal world_size {world_size}" + ) # Test each rank for ep_rank in range(world_size): @@ -98,21 +92,22 @@ def test_expert_placement_various_sizes(expert_placement_strategy, world_size): expert_placement_strategy=expert_placement_strategy, ) - assert ( - test_local_experts == expected_test_local - ), f"For {test_global_experts} experts on {test_ep_size} ranks, " \ - f"rank {ep_rank}: expected {expected_test_local} local" \ + assert test_local_experts == expected_test_local, ( + f"For {test_global_experts} experts on {test_ep_size} ranks, " + f"rank {ep_rank}: expected {expected_test_local} local" f"experts, got {test_local_experts}" + ) if test_expert_map is not None: - assert test_expert_map.shape == ( - test_global_experts, - ), f"Expected expert map shape ({test_global_experts},), " \ + assert test_expert_map.shape == (test_global_experts,), ( + f"Expected expert map shape ({test_global_experts},), " f"got {test_expert_map.shape}" + ) # Verify round_robin pattern for this test case - verify_round_robin_pattern(test_expert_map, ep_rank, - test_ep_size, test_global_experts) + verify_round_robin_pattern( + test_expert_map, ep_rank, test_ep_size, test_global_experts + ) @pytest.mark.parametrize("expert_placement_strategy", ["round_robin"]) @@ -147,28 +142,81 @@ def test_determine_expert_map_comprehensive(): # expert_placement_strategy, expected_local, expected_map_pattern) test_cases = [ # Round robin placement tests - (2, 0, 8, "round_robin", 4, [0, -1, 1, -1, 2, -1, 3, - -1]), # rank 0 gets even experts - (2, 1, 8, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, - 3]), # rank 1 gets odd experts - (2, 0, 9, "round_robin", 5, [0, -1, 1, -1, 2, -1, 3, -1, 4 - ]), # rank 0 gets 5 experts (even + last) - (2, 1, 9, "round_robin", 4, [-1, 0, -1, 1, -1, 2, -1, 3, - -1]), # rank 1 gets 4 experts (odd) - + ( + 2, + 0, + 8, + "round_robin", + 4, + [0, -1, 1, -1, 2, -1, 3, -1], + ), # rank 0 gets even experts + ( + 2, + 1, + 8, + "round_robin", + 4, + [-1, 0, -1, 1, -1, 2, -1, 3], + ), # rank 1 gets odd experts + ( + 2, + 0, + 9, + "round_robin", + 5, + [0, -1, 1, -1, 2, -1, 3, -1, 4], + ), # rank 0 gets 5 experts (even + last) + ( + 2, + 1, + 9, + "round_robin", + 4, + [-1, 0, -1, 1, -1, 2, -1, 3, -1], + ), # rank 1 gets 4 experts (odd) # 4-rank tests - (4, 0, 8, "round_robin", 2, [0, -1, -1, -1, 1, -1, -1, - -1]), # rank 0 gets experts 0, 4 - (4, 1, 8, "round_robin", 2, [-1, 0, -1, -1, -1, 1, -1, - -1]), # rank 1 gets experts 1, 5 - (4, 2, 8, "round_robin", 2, [-1, -1, 0, -1, -1, -1, 1, - -1]), # rank 2 gets experts 2, 6 - (4, 3, 8, "round_robin", 2, [-1, -1, -1, 0, -1, -1, -1, - 1]), # rank 3 gets experts 3, 7 + ( + 4, + 0, + 8, + "round_robin", + 2, + [0, -1, -1, -1, 1, -1, -1, -1], + ), # rank 0 gets experts 0, 4 + ( + 4, + 1, + 8, + "round_robin", + 2, + [-1, 0, -1, -1, -1, 1, -1, -1], + ), # rank 1 gets experts 1, 5 + ( + 4, + 2, + 8, + 
"round_robin", + 2, + [-1, -1, 0, -1, -1, -1, 1, -1], + ), # rank 2 gets experts 2, 6 + ( + 4, + 3, + 8, + "round_robin", + 2, + [-1, -1, -1, 0, -1, -1, -1, 1], + ), # rank 3 gets experts 3, 7 ] - for ep_size, ep_rank, global_num_experts, expert_placement_strategy, \ - expected_local, expected_map_pattern in test_cases: + for ( + ep_size, + ep_rank, + global_num_experts, + expert_placement_strategy, + expected_local, + expected_map_pattern, + ) in test_cases: local_num_experts, expert_map = determine_expert_map( ep_size=ep_size, ep_rank=ep_rank, @@ -176,19 +224,21 @@ def test_determine_expert_map_comprehensive(): expert_placement_strategy=expert_placement_strategy, ) - assert local_num_experts == expected_local, \ - f"ep_size={ep_size}, ep_rank={ep_rank}, " \ - f"global_num_experts={global_num_experts}, " \ - f"expert_placement_strategy={expert_placement_strategy}: " \ + assert local_num_experts == expected_local, ( + f"ep_size={ep_size}, ep_rank={ep_rank}, " + f"global_num_experts={global_num_experts}, " + f"expert_placement_strategy={expert_placement_strategy}: " f"expected {expected_local} local experts, got {local_num_experts}" + ) if expected_map_pattern is None: assert expert_map is None, "Expected expert_map to be None" else: assert expert_map is not None, "Expected expert_map to not be None" actual_map = expert_map.tolist() - assert actual_map == expected_map_pattern, \ - f"ep_size={ep_size}, ep_rank={ep_rank}, " \ - f"global_num_experts={global_num_experts}, " \ - f"expert_placement_strategy={expert_placement_strategy}: " \ + assert actual_map == expected_map_pattern, ( + f"ep_size={ep_size}, ep_rank={ep_rank}, " + f"global_num_experts={global_num_experts}, " + f"expert_placement_strategy={expert_placement_strategy}: " f"expected map {expected_map_pattern}, got {actual_map}" + ) diff --git a/tests/distributed/test_kvlayout.py b/tests/distributed/test_kvlayout.py index d447876f6cc7..b190b2820451 100644 --- a/tests/distributed/test_kvlayout.py +++ b/tests/distributed/test_kvlayout.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig, - VllmConfig, set_current_vllm_config) +from vllm.config import ( + DeviceConfig, + KVTransferConfig, + ModelConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_kv_connector_cache_layout) + get_kv_connector_cache_layout, +) from vllm.logger import init_logger logger = init_logger("test_expert_parallel") @@ -23,8 +29,9 @@ def test_get_kv_connector_cache_layout_with_lmcache_connector(): kv_connector="LMCacheConnectorV1", kv_role="kv_both", ) - vllm_config = VllmConfig(device_config=DeviceConfig("cpu"), - kv_transfer_config=kv_transfer_config) + vllm_config = VllmConfig( + device_config=DeviceConfig("cpu"), kv_transfer_config=kv_transfer_config + ) with set_current_vllm_config(vllm_config): # Test with default settings layout = get_kv_connector_cache_layout() @@ -37,9 +44,11 @@ def test_get_kv_connector_cache_layout_with_nixl_connector(): kv_role="kv_both", ) model_config = ModelConfig() - vllm_config = VllmConfig(device_config=DeviceConfig("cpu"), - model_config=model_config, - kv_transfer_config=kv_transfer_config) + vllm_config = VllmConfig( + device_config=DeviceConfig("cpu"), + model_config=model_config, + kv_transfer_config=kv_transfer_config, + ) with set_current_vllm_config(vllm_config): # Test with default settings layout = 
get_kv_connector_cache_layout() @@ -47,25 +56,22 @@ def test_get_kv_connector_cache_layout_with_nixl_connector(): def test_get_kv_connector_cache_layout_with_multi_connector(): - kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector", - kv_role="kv_both", - kv_connector_extra_config={ - "connectors": [{ - "kv_connector": - "SharedStorageConnector", - "kv_role": - "kv_both" - }, { - "kv_connector": - "NixlConnector", - "kv_role": - "kv_both" - }] - }) + kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [ + {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"}, + {"kv_connector": "NixlConnector", "kv_role": "kv_both"}, + ] + }, + ) model_config = ModelConfig() - vllm_config = VllmConfig(device_config=DeviceConfig("cpu"), - model_config=model_config, - kv_transfer_config=kv_transfer_config) + vllm_config = VllmConfig( + device_config=DeviceConfig("cpu"), + model_config=model_config, + kv_transfer_config=kv_transfer_config, + ) with set_current_vllm_config(vllm_config): # Test with default settings layout = get_kv_connector_cache_layout() diff --git a/tests/distributed/test_multi_node_assignment.py b/tests/distributed/test_multi_node_assignment.py index ef17a51fff0e..8d818edbb3bd 100644 --- a/tests/distributed/test_multi_node_assignment.py +++ b/tests/distributed/test_multi_node_assignment.py @@ -24,14 +24,13 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.skipif(not VLLM_MULTI_NODE, - reason="Need at least 2 nodes to run the test.") +@pytest.mark.skipif( + not VLLM_MULTI_NODE, reason="Need at least 2 nodes to run the test." +) def test_multi_node_assignment() -> None: - # NOTE: important to keep this class definition here # to let ray use cloudpickle to serialize it. 
class Actor: - def get_ip(self): return get_ip() @@ -41,8 +40,7 @@ def get_ip(self): current_ip = get_ip() workers = [] - for bundle_id, bundle in enumerate( - config.placement_group.bundle_specs): + for bundle_id, bundle in enumerate(config.placement_group.bundle_specs): if not bundle.get("GPU", 0): continue scheduling_strategy = PlacementGroupSchedulingStrategy( diff --git a/tests/distributed/test_nccl_symm_mem_allreduce.py b/tests/distributed/test_nccl_symm_mem_allreduce.py index ffc913742620..40dcf7567c92 100644 --- a/tests/distributed/test_nccl_symm_mem_allreduce.py +++ b/tests/distributed/test_nccl_symm_mem_allreduce.py @@ -11,15 +11,17 @@ import vllm.envs as envs from vllm.distributed import cleanup_dist_env_and_memory -from vllm.distributed.device_communicators.cuda_communicator import ( - CudaCommunicator) -from vllm.distributed.device_communicators.pynccl import ( - register_nccl_symmetric_ops) +from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator +from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops from vllm.distributed.device_communicators.pynccl_allocator import ( - get_nccl_mem_pool, is_symmetric_memory_enabled) -from vllm.distributed.parallel_state import (get_tp_group, - init_distributed_environment, - initialize_model_parallel) + get_nccl_mem_pool, + is_symmetric_memory_enabled, +) +from vllm.distributed.parallel_state import ( + get_tp_group, + init_distributed_environment, + initialize_model_parallel, +) from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -38,31 +40,32 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int): torch.cuda.set_device(device) torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - "RANK": str(local_rank), - "LOCAL_RANK": str(local_rank), - "WORLD_SIZE": str(world_size), - "MASTER_ADDR": "localhost", - "MASTER_PORT": "12345", - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) - cuda_communicator = typing.cast(CudaCommunicator, - get_tp_group().device_communicator) + cuda_communicator = typing.cast( + CudaCommunicator, get_tp_group().device_communicator + ) pynccl_comm = cuda_communicator.pynccl_comm if get_nccl_mem_pool() is None: - pytest.skip("NCCL allocator compilation failed " - "(probably missing NCCL headers).") + pytest.skip( + "NCCL allocator compilation failed (probably missing NCCL headers)." 
+ ) if not is_symmetric_memory_enabled(): pytest.skip("NCCL symmetric memory allreduce is disabled.") register_nccl_symmetric_ops(pynccl_comm) - input = torch.randint(1, - 23, (test_size_elements, ), - dtype=dtype, - device=device) + input = torch.randint(1, 23, (test_size_elements,), dtype=dtype, device=device) input_clone = input.clone() output = torch.ops.vllm.all_reduce_symmetric_with_copy(input) assert output is not None @@ -77,8 +80,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int): reason="NCCLSymmMemAllreduce is only available for CUDA platforms.", ) @pytest.mark.parametrize("world_size", [2]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size): if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") @@ -88,7 +90,5 @@ def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size): monkeypatch.setenv("NCCL_NVLS_ENABLE", "1") monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1") - mp.spawn(nccl_symm_mem_allreduce_worker, - args=(world_size, ), - nprocs=world_size) + mp.spawn(nccl_symm_mem_allreduce_worker, args=(world_size,), nprocs=world_size) cleanup_dist_env_and_memory() diff --git a/tests/distributed/test_node_count.py b/tests/distributed/test_node_count.py index e3c36ef5ef37..b48c025aa1a2 100644 --- a/tests/distributed/test_node_count.py +++ b/tests/distributed/test_node_count.py @@ -32,12 +32,15 @@ # Expected node count based on environment variable) expected = int(os.environ.get("NUM_NODES", "1")) - assert test_result == expected, \ - f"Expected {expected} nodes, got {test_result}" + assert test_result == expected, f"Expected {expected} nodes, got {test_result}" if pg == dist.group.WORLD: - print(f"Node count test passed! Got {test_result} nodes " - f"when using torch distributed!") + print( + f"Node count test passed! Got {test_result} nodes " + f"when using torch distributed!" + ) else: - print(f"Node count test passed! Got {test_result} nodes " - f"when using StatelessProcessGroup!") + print( + f"Node count test passed! Got {test_result} nodes " + f"when using StatelessProcessGroup!" + ) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index aa28ed9ce25e..7d55c40754b4 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -7,6 +7,7 @@ all workers in a node other than the head node, which can cause the test to fail. 
""" + import json import os from dataclasses import dataclass @@ -55,26 +56,17 @@ def detailed( ): return PPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - eager_mode=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - eager_mode=False), - ParallelSetup(tp_size=tp_base, - pp_size=2 * pp_base, - eager_mode=True), - ParallelSetup(tp_size=2 * tp_base, - pp_size=pp_base, - eager_mode=False), - ParallelSetup(tp_size=2 * tp_base, - pp_size=pp_base, - eager_mode=True), + ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=False), + ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=False), + ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, eager_mode=True), + ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=False), + ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, eager_mode=True), ], distributed_backends=["mp", "ray"], runner=runner, - test_options=PPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=PPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) @staticmethod @@ -86,17 +78,15 @@ def fast( multi_node_only: bool = False, load_format: Optional[str] = None, ): - return PPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - eager_mode=True), + ParallelSetup(tp_size=tp_base, pp_size=pp_base, eager_mode=True), ], distributed_backends=["mp"], runner=runner, - test_options=PPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=PPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) def iter_params(self, model_id: str): @@ -110,12 +100,11 @@ def iter_params(self, model_id: str): # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU # The values displayed here are only a rough indicator of the size of the model -# yapf: disable TEXT_GENERATION_MODELS = { # [Decoder-only] # Uses Llama # "BAAI/AquilaChat-7B": PPTestSettings.fast(), - "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"), # noqa: E501 + "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(load_format="dummy"), "baichuan-inc/Baichuan-7B": PPTestSettings.fast(), "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(), "bigscience/bloomz-1b1": PPTestSettings.fast(), @@ -149,7 +138,7 @@ def iter_params(self, model_id: str): # Uses Llama # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(), "state-spaces/mamba-130m-hf": PPTestSettings.fast(), - "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"), # noqa: E501 + "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(load_format="dummy"), "mosaicml/mpt-7b": PPTestSettings.fast(), "nvidia/Minitron-8B-Base": PPTestSettings.fast(), "allenai/OLMo-1B-hf": PPTestSettings.fast(), @@ -160,13 +149,15 @@ def iter_params(self, model_id: str): "adept/persimmon-8b-chat": PPTestSettings.fast(), "microsoft/phi-2": PPTestSettings.fast(), "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(), - "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed(multi_node_only=True, load_format="dummy"), # noqa: E501 + "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.detailed( + multi_node_only=True, load_format="dummy" + ), "Qwen/Qwen-7B-Chat": PPTestSettings.fast(), "Qwen/Qwen2.5-0.5B-Instruct": PPTestSettings.fast(), "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(), "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), "bigcode/starcoder2-3b": 
PPTestSettings.fast(), - "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"), # noqa: E501 + "upstage/solar-pro-preview-instruct": PPTestSettings.fast(load_format="dummy"), # FIXME: Cannot load tokenizer in latest transformers version. # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf` # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(), @@ -206,7 +197,6 @@ def iter_params(self, model_id: str): "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), "fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(), } -# yapf: enable # NOTE: You can update this on your local machine to run specific tests TEST_MODELS = [ @@ -281,8 +271,10 @@ def _compare_tp( if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": - pytest.skip("Skipping multi-node pipeline parallel test for " - "multiprocessing distributed backend") + pytest.skip( + "Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend" + ) if multi_node_only and not VLLM_MULTI_NODE: pytest.skip("Not in multi-node setting") @@ -357,20 +349,16 @@ def _compare_tp( "mp", ] - compare_two_settings(model_id, - pp_args, - tp_args, - pp_env, - tp_env, - method=method) + compare_two_settings(model_id, pp_args, tp_args, pp_env, tp_env, method=method) @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "runner", - "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"), [ - params for model_id, settings in TEXT_GENERATION_MODELS.items() - for params in settings.iter_params(model_id) if model_id in TEST_MODELS + params + for model_id, settings in TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in TEST_MODELS ], ) @create_new_process_for_each_test() @@ -382,22 +370,25 @@ def test_tp_language_generation( test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - runner, - test_options, - num_gpus_available, - method="generate", - is_multimodal=False) + _compare_tp( + model_id, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False, + ) @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "runner", - "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"), [ - params for model_id, settings in EMBEDDING_MODELS.items() - for params in settings.iter_params(model_id) if model_id in TEST_MODELS + params + for model_id, settings in EMBEDDING_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in TEST_MODELS ], ) @create_new_process_for_each_test() @@ -409,22 +400,25 @@ def test_tp_language_embedding( test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - runner, - test_options, - num_gpus_available, - method="encode", - is_multimodal=False) + _compare_tp( + model_id, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="encode", + is_multimodal=False, + ) @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "runner", - "test_options"), + ("model_id", "parallel_setup", "distributed_backend", "runner", "test_options"), [ - params for model_id, settings in MULTIMODAL_MODELS.items() - for params in settings.iter_params(model_id) if 
model_id in TEST_MODELS + params + for model_id, settings in MULTIMODAL_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in TEST_MODELS ], ) @create_new_process_for_each_test() @@ -436,11 +430,13 @@ def test_tp_multimodal_generation( test_options: PPTestOptions, num_gpus_available, ): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - runner, - test_options, - num_gpus_available, - method="generate", - is_multimodal=True) + _compare_tp( + model_id, + parallel_setup, + distributed_backend, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=True, + ) diff --git a/tests/distributed/test_pipeline_partition.py b/tests/distributed/test_pipeline_partition.py index 69ceedd345a8..4df6f43970d7 100644 --- a/tests/distributed/test_pipeline_partition.py +++ b/tests/distributed/test_pipeline_partition.py @@ -9,7 +9,6 @@ def test_custom_layer_partition(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: def _verify(partition_str, num_layers, pp_size, goldens): @@ -57,7 +56,8 @@ def _verify(partition_str, num_layers, pp_size, goldens): (5, 3, 0, (0, 2)), (5, 3, 1, (2, 4)), (5, 3, 2, (4, 5)), - ]) + ], +) def test_uneven_auto_partition( num_hidden_layers: int, pp_size: int, diff --git a/tests/distributed/test_pp_cudagraph.py b/tests/distributed/test_pp_cudagraph.py index 5ca65a0e8d2c..2c9f47464008 100644 --- a/tests/distributed/test_pp_cudagraph.py +++ b/tests/distributed/test_pp_cudagraph.py @@ -12,12 +12,18 @@ from typing_extensions import LiteralString -@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [ - (2, "JackFram/llama-160m"), -]) -@pytest.mark.parametrize("ATTN_BACKEND", [ - "FLASH_ATTN", -]) +@pytest.mark.parametrize( + "PP_SIZE, MODEL_NAME", + [ + (2, "JackFram/llama-160m"), + ], +) +@pytest.mark.parametrize( + "ATTN_BACKEND", + [ + "FLASH_ATTN", + ], +) @create_new_process_for_each_test() def test_pp_cudagraph( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index abfad9ebfe7d..4bab709fb589 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -9,13 +9,15 @@ import torch import torch.distributed -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - get_world_group, graph_capture, - init_distributed_environment) +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + get_world_group, + graph_capture, + init_distributed_environment, +) from vllm.utils import update_environment_variables @@ -24,13 +26,13 @@ def distributed_run(fn, world_size): processes: list[multiprocessing.Process] = [] for i in range(number_of_processes): env: dict[str, str] = {} - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = 
"12345" + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -47,7 +49,7 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) - local_rank = os.environ['LOCAL_RANK'] + local_rank = os.environ["LOCAL_RANK"] device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) init_distributed_environment() @@ -58,17 +60,18 @@ def wrapped_fn(env): @worker_fn_wrapper def worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) - tensor = torch.ones(16, 1024, 1024, - dtype=torch.float32).cuda(pynccl_comm.rank) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) tensor = pynccl_comm.all_reduce(tensor) torch.cuda.synchronize() assert torch.all(tensor == pynccl_comm.world_size).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl(): distributed_run(worker_fn, 2) @@ -78,7 +81,7 @@ def multiple_allreduce_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") groups = [ torch.distributed.new_group(ranks=[0, 1], backend="gloo"), - torch.distributed.new_group(ranks=[2, 3], backend="gloo") + torch.distributed.new_group(ranks=[2, 3], backend="gloo"), ] group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) @@ -95,8 +98,9 @@ def multiple_allreduce_worker_fn(): assert torch.all(tensor == 2).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_multiple_allreduce(): # this tests pynccl for multiple tp groups, in a standalone way # i.e. call `pynccl_comm.all_reduce` directly @@ -121,8 +125,9 @@ def multiple_allreduce_with_vllm_worker_fn(): assert torch.all(tensor == 2).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_multiple_allreduce_with_vllm(): # this tests pynccl for multiple tp groups, together with vllm # i.e. 
call `tensor_model_parallel_all_reduce` @@ -133,10 +138,11 @@ def test_pynccl_multiple_allreduce_with_vllm(): def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) # run something in the default stream to initialize torch engine - a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') + a = torch.ones((4, 4), device=f"cuda:{pynccl_comm.rank}") torch.cuda.synchronize() with torch.cuda.graph(graph): a_out = pynccl_comm.all_reduce(a) @@ -148,84 +154,90 @@ def worker_fn_with_cudagraph(): @worker_fn_wrapper def all_gather_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" num_elems = 1000 - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * num_elems - result = torch.zeros(num_elems * world_size, - dtype=torch.float32, - device=device) - - expected = torch.cat([ - torch.arange(num_elems, dtype=torch.float32) + r * num_elems - for r in range(world_size) - ]).to(device) + tensor = ( + torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems + ) + result = torch.zeros(num_elems * world_size, dtype=torch.float32, device=device) + + expected = torch.cat( + [ + torch.arange(num_elems, dtype=torch.float32) + r * num_elems + for r in range(world_size) + ] + ).to(device) pynccl_comm.all_gather(result, tensor) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_all_gather(): distributed_run(all_gather_worker_fn, 2) @worker_fn_wrapper def all_gatherv_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" assert world_size <= 8 sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] num_elems = sizes[rank] - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * 100 + tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100 result = torch.zeros(sum(sizes), dtype=torch.float32, device=device) - expected = torch.cat([ - torch.arange(sizes[r], dtype=torch.float32) + r * 100 - for r in range(world_size) - ]).to(device) + expected = torch.cat( + [ + torch.arange(sizes[r], dtype=torch.float32) + r * 100 + for r in range(world_size) + ] + ).to(device) pynccl_comm.all_gatherv(result, tensor, sizes=sizes) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." 
+) def test_pynccl_all_gatherv(): distributed_run(all_gatherv_worker_fn, 2) @worker_fn_wrapper def reduce_scatter_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" num_elems = 1000 - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * num_elems - assert (num_elems % world_size == 0) - result = torch.zeros(num_elems // world_size, - dtype=torch.float32, - device=device) + tensor = ( + torch.arange(num_elems, dtype=torch.float32, device=device) + rank * num_elems + ) + assert num_elems % world_size == 0 + result = torch.zeros(num_elems // world_size, dtype=torch.float32, device=device) # Calculate expected result for this rank's chunk scattered_size = num_elems // world_size @@ -233,34 +245,37 @@ def reduce_scatter_worker_fn(): torch.arange(num_elems, dtype=torch.float32) + r * num_elems for r in range(world_size) ] - expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size] - for tensor in all_tensors).to(device) + expected = sum( + tensor[rank * scattered_size : (rank + 1) * scattered_size] + for tensor in all_tensors + ).to(device) pynccl_comm.reduce_scatter(result, tensor) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_reduce_scatter(): distributed_run(reduce_scatter_worker_fn, 2) @worker_fn_wrapper def reduce_scatterv_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) rank = pynccl_comm.rank world_size = pynccl_comm.world_size - device = f'cuda:{pynccl_comm.rank}' + device = f"cuda:{pynccl_comm.rank}" assert world_size <= 8 sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] num_elems = sum(sizes) - tensor = torch.arange(num_elems, dtype=torch.float32, - device=device) + rank * 100 + tensor = torch.arange(num_elems, dtype=torch.float32, device=device) + rank * 100 result = torch.zeros(sizes[rank], dtype=torch.float32, device=device) # Calculate expected result for this rank's chunk @@ -278,41 +293,41 @@ def reduce_scatterv_worker_fn(): torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_reduce_scatterv(): distributed_run(reduce_scatterv_worker_fn, 2) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." 
+) def test_pynccl_with_cudagraph(): distributed_run(worker_fn_with_cudagraph, 2) @worker_fn_wrapper def send_recv_worker_fn(): - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) if pynccl_comm.rank == 0: - tensor = torch.ones(16, 1024, 1024, - dtype=torch.float32).cuda(pynccl_comm.rank) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) else: - tensor = torch.empty(16, 1024, 1024, - dtype=torch.float32).cuda(pynccl_comm.rank) + tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) if pynccl_comm.rank == 0: - pynccl_comm.send(tensor, - dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) + pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) else: - pynccl_comm.recv(tensor, - src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) torch.cuda.synchronize() assert torch.all(tensor == 1).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test." +) def test_pynccl_send_recv(): distributed_run(send_recv_worker_fn, 2) @@ -322,27 +337,20 @@ def multiple_send_recv_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") groups = [ torch.distributed.new_group(ranks=[0, 2], backend="gloo"), - torch.distributed.new_group(ranks=[1, 3], backend="gloo") + torch.distributed.new_group(ranks=[1, 3], backend="gloo"), ] group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) if torch.distributed.get_rank() == 0: tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) elif torch.distributed.get_rank() == 1: - tensor = 2 * torch.ones( - 16, 1024, 1024, dtype=torch.float32, device=device) + tensor = 2 * torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) else: - tensor = torch.empty(16, - 1024, - 1024, - dtype=torch.float32, - device=device) + tensor = torch.empty(16, 1024, 1024, dtype=torch.float32, device=device) if torch.distributed.get_rank() in [0, 1]: - pynccl_comm.send(tensor, - dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) + pynccl_comm.send(tensor, dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) else: - pynccl_comm.recv(tensor, - src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) torch.cuda.synchronize() if torch.distributed.get_rank() in [0, 2]: assert torch.all(tensor == 1).cpu().item() @@ -350,14 +358,16 @@ def multiple_send_recv_worker_fn(): assert torch.all(tensor == 2).cpu().item() -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." +) def test_pynccl_multiple_send_recv(): distributed_run(multiple_send_recv_worker_fn, 4) -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") +@pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test." 
+) def test_pynccl_broadcast(): distributed_run(broadcast_worker_fn, 4) @@ -366,19 +376,17 @@ def test_pynccl_broadcast(): def broadcast_worker_fn(): # Test broadcast for every root rank. # Essentially this is an all-gather operation. - pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, - device=get_world_group().device) + pynccl_comm = PyNcclCommunicator( + get_world_group().cpu_group, device=get_world_group().device + ) recv_tensors = [ - torch.empty(16, - 1024, - 1024, - dtype=torch.float32, - device=pynccl_comm.device) + torch.empty(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device) for i in range(pynccl_comm.world_size) ] - recv_tensors[pynccl_comm.rank] = torch.ones( - 16, 1024, 1024, dtype=torch.float32, - device=pynccl_comm.device) * pynccl_comm.rank + recv_tensors[pynccl_comm.rank] = ( + torch.ones(16, 1024, 1024, dtype=torch.float32, device=pynccl_comm.device) + * pynccl_comm.rank + ) for i in range(pynccl_comm.world_size): pynccl_comm.broadcast(recv_tensors[i], src=i) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index 6245ccbeca87..2df88377345d 100644 --- a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -8,20 +8,20 @@ import torch import torch.distributed as dist -from vllm.distributed.communication_op import ( # noqa - tensor_model_parallel_all_reduce) +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.parallel_state import get_tp_group, graph_capture from vllm.platforms import current_platform -from ..utils import (ensure_model_parallel_initialized, - init_test_distributed_environment, multi_process_parallel) +from ..utils import ( + ensure_model_parallel_initialized, + init_test_distributed_environment, + multi_process_parallel, +) torch.manual_seed(42) random.seed(44) # Size over 8MB is sufficient for custom quick allreduce. 
-test_sizes = [ - random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8) -] +test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)] for i, v in enumerate(test_sizes): test_sizes[i] -= v % 8 @@ -38,8 +38,7 @@ def graph_quickreduce( m.delenv("CUDA_VISIBLE_DEVICES", raising=False) device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) ensure_model_parallel_initialized(tp_size, pp_size) group = get_tp_group().device_group @@ -64,18 +63,15 @@ def graph_quickreduce( for sz in test_sizes: for dtype in [torch.float16, torch.bfloat16]: with graph_capture(device=device) as graph_capture_context: - inp1 = torch.randint(1, - 23, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) - inp2 = torch.randint(-23, - 1, (sz, ), - dtype=dtype, - device=torch.cuda.current_device()) + inp1 = torch.randint( + 1, 23, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + inp2 = torch.randint( + -23, 1, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, - stream=graph_capture_context.stream): + with torch.cuda.graph(graph, stream=graph_capture_context.stream): for _ in range(num_communication): out1 = tensor_model_parallel_all_reduce(inp1) dist.all_reduce(inp1, group=group) @@ -99,39 +95,42 @@ def eager_quickreduce( device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(tp_size, pp_size, rank, - distributed_init_port) + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) # Size over 8MB is sufficient for custom quick allreduce. 
sz = 16 * 1024 * 1024 fa = get_tp_group().device_communicator.qr_comm - inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], - dtype=torch.float16, - device=device) + inp = torch.tensor( + [1.0 * ((i) % 23) for i in range(sz)], dtype=torch.float16, device=device + ) out = fa.quick_all_reduce(inp) torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) - inp = torch.tensor([1.0 * ((i) % 23) for i in range(sz)], - dtype=torch.bfloat16, - device=device) + inp = torch.tensor( + [1.0 * ((i) % 23) for i in range(sz)], dtype=torch.bfloat16, device=device + ) out = fa.quick_all_reduce(inp) torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test quick allreduce for rocm") +@pytest.mark.skipif( + not current_platform.is_rocm(), reason="only test quick allreduce for rocm" +) @pytest.mark.parametrize("quant_mode", ["FP", "INT8", "INT6", "INT4"]) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) @pytest.mark.parametrize("test_target", [graph_quickreduce, eager_quickreduce]) -def test_custom_quick_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, - pipeline_parallel_size, test_target, - quant_mode): +def test_custom_quick_allreduce( + monkeypatch: pytest.MonkeyPatch, + tp_size, + pipeline_parallel_size, + test_target, + quant_mode, +): world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode) - multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, - test_target) + multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target) diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py index 94ad8f4f1213..baf75fd48c63 100644 --- a/tests/distributed/test_same_node.py +++ b/tests/distributed/test_same_node.py @@ -22,15 +22,13 @@ dist.broadcast_object_list(recv, src=0) ip, port = recv - stateless_pg = StatelessProcessGroup.create(ip, port, rank, - dist.get_world_size()) + stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) for pg in [dist.group.WORLD, stateless_pg]: test_result = all(in_the_same_node_as(pg, source_rank=0)) expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1" - assert test_result == expected, \ - f"Expected {expected}, got {test_result}" + assert test_result == expected, f"Expected {expected}, got {test_result}" if pg == dist.group.WORLD: print("Same node test passed! when using torch distributed!") else: diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index ded3d834faf0..82eaed66717c 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -7,6 +7,7 @@ all workers in a node other than the head node, which can cause the test to fail. 
""" + import json import os from dataclasses import dataclass @@ -56,7 +57,8 @@ def __post_init__(self): raise ValueError( f"Length mismatch: distributed_backends " f"({len(self.distributed_backends)}) != " - f"vllm_major_versions ({len(self.vllm_major_versions)})") + f"vllm_major_versions ({len(self.vllm_major_versions)})" + ) @staticmethod def detailed( @@ -72,18 +74,22 @@ def detailed( for pp_multiplier in [1, 2]: for chunked_prefill_val in [False, True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_multiplier * pp_base, - enable_fusion=False, - eager_mode=eager_mode_val, - chunked_prefill=chunked_prefill_val)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val, + ) + ) return SPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], runner=runner, - test_options=SPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=SPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) @staticmethod @@ -100,18 +106,22 @@ def fast( for pp_multiplier in [1, 2]: for chunked_prefill_val in [False, True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_multiplier * pp_base, - enable_fusion=False, - eager_mode=eager_mode_val, - chunked_prefill=chunked_prefill_val)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_multiplier * pp_base, + enable_fusion=False, + eager_mode=eager_mode_val, + chunked_prefill=chunked_prefill_val, + ) + ) return SPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], runner=runner, - test_options=SPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=SPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) @staticmethod @@ -126,28 +136,39 @@ def fp8_quant( parallel_setups = [] for fusion_val in [False, True]: parallel_setups.append( - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - enable_fusion=fusion_val, - eager_mode=True, - chunked_prefill=False)) + ParallelSetup( + tp_size=tp_base, + pp_size=pp_base, + enable_fusion=fusion_val, + eager_mode=True, + chunked_prefill=False, + ) + ) return SPTestSettings( parallel_setups=parallel_setups, distributed_backends=["mp", "ray"], vllm_major_versions=["1", "1"], runner=runner, - test_options=SPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), + test_options=SPTestOptions( + multi_node_only=multi_node_only, load_format=load_format + ), ) def iter_params(self, model_id: str): opts = self.test_options for parallel_setup in self.parallel_setups: - for backend, vllm_major_version in zip(self.distributed_backends, - self.vllm_major_versions): - yield (model_id, parallel_setup, backend, vllm_major_version, - self.runner, opts) + for backend, vllm_major_version in zip( + self.distributed_backends, self.vllm_major_versions + ): + yield ( + model_id, + parallel_setup, + backend, + vllm_major_version, + self.runner, + opts, + ) def _compare_sp( @@ -200,8 +221,10 @@ def _compare_sp( if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": - pytest.skip("Skipping multi-node pipeline parallel test for " - "multiprocessing distributed backend") + pytest.skip( + "Skipping multi-node pipeline parallel test for " + "multiprocessing distributed 
backend" + ) if multi_node_only and not VLLM_MULTI_NODE: pytest.skip("Not in multi-node setting") @@ -232,13 +255,13 @@ def _compare_sp( common_args.append("--skip-tokenizer-init") compilation_config = { - 'level': 3, - 'custom_ops': ["+rms_norm"], - 'compile_sizes': [4, 8], - 'pass_config': { - 'enable_sequence_parallelism': True, - 'enable_fusion': enable_fusion, - 'enable_noop': True, + "level": 3, + "custom_ops": ["+rms_norm"], + "compile_sizes": [4, 8], + "pass_config": { + "enable_sequence_parallelism": True, + "enable_fusion": enable_fusion, + "enable_noop": True, }, } @@ -270,12 +293,9 @@ def _compare_sp( ] try: - compare_two_settings(model_id, - tp_sp_args, - tp_args, - tp_sp_env, - tp_env, - method=method) + compare_two_settings( + model_id, tp_sp_args, tp_args, tp_sp_env, tp_env, method=method + ) except Exception: testing_ray_compiled_graph = tp_sp_env is not None if testing_ray_compiled_graph and vllm_major_version == "0": @@ -301,10 +321,17 @@ def _compare_sp( @pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "runner", "test_options"), + ( + "model_id", + "parallel_setup", + "distributed_backend", + "vllm_major_version", + "runner", + "test_options", + ), [ - params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() + params + for model_id, settings in SP_TEXT_GENERATION_MODELS.items() for params in settings.iter_params(model_id) if model_id in SP_TEST_MODELS ], @@ -319,12 +346,14 @@ def test_tp_sp_generation( test_options: SPTestOptions, num_gpus_available, ): - _compare_sp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - runner, - test_options, - num_gpus_available, - method="generate", - is_multimodal=False) + _compare_sp( + model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + runner, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False, + ) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index e1357b4a34e9..cdea1bfe8f28 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -26,13 +26,13 @@ def distributed_run(fn, world_size): processes = [] for i in range(number_of_processes): env = {} - env['RANK'] = str(i) - env['LOCAL_RANK'] = str(i) - env['WORLD_SIZE'] = str(number_of_processes) - env['LOCAL_WORLD_SIZE'] = str(number_of_processes) - env['MASTER_ADDR'] = 'localhost' - env['MASTER_PORT'] = '12345' - p = multiprocessing.Process(target=fn, args=(env, )) + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12345" + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -57,25 +57,23 @@ def wrapped_fn(env): @worker_fn_wrapper def worker_fn(): - rank = dist.get_rank() if rank == 0: port = get_open_port() - ip = '127.0.0.1' + ip = "127.0.0.1" dist.broadcast_object_list([ip, port], src=0) else: recv = [None, None] dist.broadcast_object_list(recv, src=0) ip, port = recv # type: ignore - stateless_pg = StatelessProcessGroup.create(ip, port, rank, - dist.get_world_size()) + stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) for pg in [dist.group.WORLD, stateless_pg]: - writer_rank = 2 broadcaster = MessageQueue.create_from_process_group( - pg, 40 * 1024, 2, writer_rank) + pg, 40 * 1024, 2, writer_rank + ) if rank == writer_rank: 
seed = random.randint(0, 1000) dist.broadcast_object_list([seed], writer_rank) diff --git a/tests/distributed/test_shm_buffer.py b/tests/distributed/test_shm_buffer.py index f70028b87960..c6ceab181ff5 100644 --- a/tests/distributed/test_shm_buffer.py +++ b/tests/distributed/test_shm_buffer.py @@ -5,7 +5,8 @@ import unittest from vllm.distributed.device_communicators.shm_object_storage import ( - SingleWriterShmRingBuffer) + SingleWriterShmRingBuffer, +) class TestSingleWriterShmRingBuffer(unittest.TestCase): @@ -25,18 +26,21 @@ def test_buffer_opening(self): """Test opening an existing buffer""" # First create a buffer self.ring_buffer = SingleWriterShmRingBuffer( - data_buffer_size=self.buffer_size, create=True) + data_buffer_size=self.buffer_size, create=True + ) # Then open it with another instance reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle()) self.assertFalse(reader_buffer.is_writer) - self.assertEqual(reader_buffer.shared_memory.name, - self.ring_buffer.shared_memory.name) + self.assertEqual( + reader_buffer.shared_memory.name, self.ring_buffer.shared_memory.name + ) def test_buffer_access(self): """Test accessing allocated buffers""" self.ring_buffer = SingleWriterShmRingBuffer( - data_buffer_size=self.buffer_size, create=True) + data_buffer_size=self.buffer_size, create=True + ) size = 100 address, monotonic_id = self.ring_buffer.allocate_buf(size) @@ -44,11 +48,11 @@ def test_buffer_access(self): # Write some test data test_data = b"Hello, World!" * 7 # 91 bytes with self.ring_buffer.access_buf(address) as (data_buf, metadata): - data_buf[0:len(test_data)] = test_data + data_buf[0 : len(test_data)] = test_data # Read it back with self.ring_buffer.access_buf(address) as (data_buf2, metadata2): - read_data = bytes(data_buf2[0:len(test_data)]) + read_data = bytes(data_buf2[0 : len(test_data)]) read_id = metadata2[0] self.assertEqual(read_data, test_data) @@ -58,7 +62,8 @@ def test_memory_error_on_full_buffer(self): """Test that MemoryError is raised when buffer is full""" small_buffer_size = 200 self.ring_buffer = SingleWriterShmRingBuffer( - data_buffer_size=small_buffer_size, create=True) + data_buffer_size=small_buffer_size, create=True + ) # Fill up the buffer self.ring_buffer.allocate_buf(100) @@ -72,7 +77,8 @@ def test_allocation_and_free(self): """Test allocation and freeing of buffers""" small_buffer_size = 200 self.ring_buffer = SingleWriterShmRingBuffer( - data_buffer_size=small_buffer_size, create=True) + data_buffer_size=small_buffer_size, create=True + ) size = 80 # Write some data @@ -81,7 +87,7 @@ def test_allocation_and_free(self): address, monotonic_id = self.ring_buffer.allocate_buf(size) with self.ring_buffer.access_buf(address) as (data_buf, metadata): data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use - data_buf[4:len(test_data) + 4] = test_data + data_buf[4 : len(test_data) + 4] = test_data print(self.ring_buffer.metadata) freed_ids = self.ring_buffer.free_buf(lambda *args: True) print(f" Freed IDs: {freed_ids}") @@ -90,7 +96,8 @@ def test_allocation_and_free(self): def test_clear_buffer(self): """Test clearing the buffer""" self.ring_buffer = SingleWriterShmRingBuffer( - data_buffer_size=self.buffer_size, create=True) + data_buffer_size=self.buffer_size, create=True + ) # Allocate some buffers for _ in range(3): @@ -121,8 +128,7 @@ def main(): # Manual demonstration try: print("Creating ring buffer...") - writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, - create=True) + writer_buffer = 
SingleWriterShmRingBuffer(data_buffer_size=2048, create=True) reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle()) print(f"Buffer created with name: {writer_buffer.shared_memory.name}") @@ -140,7 +146,7 @@ def main(): # Write some test data with writer_buffer.access_buf(address) as (data_buf, metadata): test_message = f"Test message {i}".encode() - data_buf[0:len(test_message)] = test_message + data_buf[0 : len(test_message)] = test_message except MemoryError as e: print(f" Failed to allocate {size} bytes: {e}") diff --git a/tests/distributed/test_shm_storage.py b/tests/distributed/test_shm_storage.py index 03495222bc1b..b9a5c22447fd 100644 --- a/tests/distributed/test_shm_storage.py +++ b/tests/distributed/test_shm_storage.py @@ -12,28 +12,33 @@ # Assuming these are imported from your module from vllm.distributed.device_communicators.shm_object_storage import ( - MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer) -from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, - MultiModalSharedField) + MsgpackSerde, + SingleWriterShmObjectStorage, + SingleWriterShmRingBuffer, +) +from vllm.multimodal.inputs import ( + MultiModalFieldElem, + MultiModalKwargsItem, + MultiModalSharedField, +) def _dummy_elem(modality: str, key: str, size: int): return MultiModalFieldElem( modality=modality, key=key, - data=torch.empty((size, ), dtype=torch.int8), + data=torch.empty((size,), dtype=torch.int8), field=MultiModalSharedField(1), ) def _dummy_item(modality: str, size_by_key: dict[str, int]): - return MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size) for key, size in size_by_key.items() - ]) + return MultiModalKwargsItem.from_elems( + [_dummy_elem(modality, key, size) for key, size in size_by_key.items()] + ) class TestSingleWriterShmObjectStorage(unittest.TestCase): - def setUp(self): """Set up test fixtures before each test method.""" ring_buffer = SingleWriterShmRingBuffer( @@ -208,8 +213,7 @@ def test_invalid_get_operations(self): with self.assertRaises(ValueError) as context: self.storage.get(address, monotonic_id + 100) - self.assertIn("has been modified or is invalid", \ - str(context.exception)) + self.assertIn("has been modified or is invalid", str(context.exception)) def test_clear_storage(self): """Test clearing the storage.""" @@ -234,8 +238,7 @@ def test_clear_storage(self): # Reader process function def reader_process(process_id, storage_handle, items_to_read): """Reader process that connects to existing shared memory and reads data.""" - reader_storage = SingleWriterShmObjectStorage.create_from_handle( - storage_handle) + reader_storage = SingleWriterShmObjectStorage.create_from_handle(storage_handle) print(f"Reader {process_id} started") @@ -276,11 +279,7 @@ def run_multiprocess_example(): # Test basic data types test_data = [ - ("user_data", { - "name": "Alice", - "age": 30, - "scores": [95, 87, 92] - }), + ("user_data", {"name": "Alice", "age": 30, "scores": [95, 87, 92]}), ("simple_string", "Hello, World!"), ("number", 42), ("list_data", [1, 2, 3, "four", 5.0]), @@ -301,8 +300,9 @@ def run_multiprocess_example(): # initialize lock for reader processes handle.reader_lock = Lock() for i in range(storage.n_readers): - p = multiprocessing.Process(target=reader_process, - args=(i, handle, stored_items)) + p = multiprocessing.Process( + target=reader_process, args=(i, handle, stored_items) + ) processes.append(p) p.start() diff --git a/tests/distributed/test_symm_mem_allreduce.py 
b/tests/distributed/test_symm_mem_allreduce.py index 83e1fe47aeec..e669b81b04f0 100644 --- a/tests/distributed/test_symm_mem_allreduce.py +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -14,11 +14,12 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed.communication_op import tensor_model_parallel_all_reduce -from vllm.distributed.device_communicators.cuda_communicator import ( - CudaCommunicator) -from vllm.distributed.parallel_state import (get_tp_group, - init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator +from vllm.distributed.parallel_state import ( + get_tp_group, + init_distributed_environment, + initialize_model_parallel, +) from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.platforms import current_platform @@ -32,8 +33,7 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue): monkeypatch = pytest.MonkeyPatch() - config = VllmConfig(parallel_config=ParallelConfig( - tensor_parallel_size=world_size)) + config = VllmConfig(parallel_config=ParallelConfig(tensor_parallel_size=world_size)) with monkeypatch.context() as m, set_current_vllm_config(config): m.delenv("CUDA_VISIBLE_DEVICES", raising=False) @@ -42,34 +42,34 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue): torch.cuda.set_device(device) torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) - cuda_communicator = typing.cast(CudaCommunicator, - get_tp_group().device_communicator) + cuda_communicator = typing.cast( + CudaCommunicator, get_tp_group().device_communicator + ) symm_mem_comm = cuda_communicator.symm_mem_comm if symm_mem_comm is None or symm_mem_comm.disabled: # can't use skip under multiprocessing q.put("SymmMemCommunicator is not available or disabled.") return - inp_direct_symm_mem = torch.randint(1, - 23, (test_size_elements, ), - dtype=dtype, - device=device) + inp_direct_symm_mem = torch.randint( + 1, 23, (test_size_elements,), dtype=dtype, device=device + ) if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): # can't use skip under multiprocessing - q.put( - "SymmMemCommunicator isn't used for this world and input size." 
- ) + q.put("SymmMemCommunicator isn't used for this world and input size.") return original_inp_direct_symm_mem = inp_direct_symm_mem.clone() @@ -78,42 +78,37 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue): group = get_tp_group().device_group dist.all_reduce(original_inp_direct_symm_mem, group=group) - torch.testing.assert_close(out_direct_symm_mem, - original_inp_direct_symm_mem, - atol=2.5, - rtol=0.1) + torch.testing.assert_close( + out_direct_symm_mem, original_inp_direct_symm_mem, atol=2.5, rtol=0.1 + ) # Test tensor_model_parallel_all_reduce which should use symm_mem - inp_tensor_parallel = torch.randint(-23, - 1, (test_size_elements, ), - dtype=dtype, - device=device) + inp_tensor_parallel = torch.randint( + -23, 1, (test_size_elements,), dtype=dtype, device=device + ) original_inp_tensor_parallel = inp_tensor_parallel.clone() - out_tensor_parallel = tensor_model_parallel_all_reduce( - inp_tensor_parallel) + out_tensor_parallel = tensor_model_parallel_all_reduce(inp_tensor_parallel) dist.all_reduce(original_inp_tensor_parallel, group=group) - torch.testing.assert_close(out_tensor_parallel, - original_inp_tensor_parallel, - atol=2.5, - rtol=0.1) + torch.testing.assert_close( + out_tensor_parallel, original_inp_tensor_parallel, atol=2.5, rtol=0.1 + ) @pytest.mark.skipif( not current_platform.is_cuda(), - reason="SymmMemAllreduce is only available for CUDA platforms.") + reason="SymmMemAllreduce is only available for CUDA platforms.", +) @pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("pipeline_parallel_size", [1]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") -def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, - pipeline_parallel_size): +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") +def test_symm_mem_allreduce( + monkeypatch: pytest.MonkeyPatch, tp_size, pipeline_parallel_size +): world_size = tp_size * pipeline_parallel_size if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") - q = mp.get_context('spawn').Queue() - mp.spawn(symm_mem_allreduce_worker, - args=(world_size, q), - nprocs=world_size) + q = mp.get_context("spawn").Queue() + mp.spawn(symm_mem_allreduce_worker, args=(world_size, q), nprocs=world_size) try: val = q.get(timeout=1) except queue.Empty: @@ -126,18 +121,20 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, @pytest.mark.skipif( not current_platform.is_cuda(), - reason="SymmMemAllreduce is only available for CUDA platforms.") -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], - reason="Only test on CUDA") + reason="SymmMemAllreduce is only available for CUDA platforms.", +) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch): world_size = 4 if world_size > torch.cuda.device_count(): pytest.skip("Not enough GPUs to run the test.") # Verify that the DataParallel runs without error - engine_args = EngineArgs(model="distilbert/distilgpt2", - enforce_eager=True, - enable_prefix_caching=True, - data_parallel_size=2, - tensor_parallel_size=2, - data_parallel_backend="mp") + engine_args = EngineArgs( + model="distilbert/distilgpt2", + enforce_eager=True, + enable_prefix_caching=True, + data_parallel_size=2, + tensor_parallel_size=2, + data_parallel_backend="mp", + ) LLMEngine.from_engine_args(engine_args) diff --git 
a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index 9f2c3eaec359..f415409d7b37 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -24,13 +24,15 @@ # set different `gpu_memory_utilization` and `swap_space` for different ranks, # to test if all ranks agree on the same kv cache configuration. -llm = LLM(model="facebook/opt-125m", - tensor_parallel_size=2, - pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), - distributed_executor_backend="external_launcher", - gpu_memory_utilization=random.uniform(0.7, 0.9), - swap_space=random.randint(1, 4), - seed=0) +llm = LLM( + model="facebook/opt-125m", + tensor_parallel_size=2, + pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), + distributed_executor_backend="external_launcher", + gpu_memory_utilization=random.uniform(0.7, 0.9), + swap_space=random.randint(1, 4), + seed=0, +) outputs = llm.generate(prompts, sampling_params) @@ -48,15 +50,14 @@ def test_consistent_across_ranks(obj): assert container[0] == obj -test_consistent_across_ranks( - llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) -test_consistent_across_ranks( - llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) # make sure we can access the model parameters from the calling process # of the `LLM` instance. -params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner. - model.parameters()) +params = list( + llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters() +) test_consistent_across_ranks(len(params)) # all ranks should have the same outputs @@ -65,5 +66,4 @@ def test_consistent_across_ranks(obj): generated_text = output.outputs[0].text test_consistent_across_ranks(prompt) test_consistent_across_ranks(generated_text) - print(f"Rank {torch_rank}, Prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + print(f"Rank {torch_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/distributed/test_torchrun_example_moe.py b/tests/distributed/test_torchrun_example_moe.py index 2d6b930fcc07..1aa7f1793570 100644 --- a/tests/distributed/test_torchrun_example_moe.py +++ b/tests/distributed/test_torchrun_example_moe.py @@ -24,23 +24,22 @@ if dp_size > 1: # distribute the prompts across the data parallel ranks - prompts = [ - prompt for idx, prompt in enumerate(prompts) - if idx % dp_size == dp_rank - ] + prompts = [prompt for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # set different `gpu_memory_utilization` and `swap_space` for different ranks, # to test if all ranks agree on the same kv cache configuration. 
-llm = LLM(model="microsoft/Phi-mini-MoE-instruct", - tensor_parallel_size=int(os.getenv("TP_SIZE", "1")), - pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")), - enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1, - distributed_executor_backend="external_launcher", - gpu_memory_utilization=random.uniform(0.7, 0.9), - swap_space=random.randint(1, 4), - seed=0) +llm = LLM( + model="microsoft/Phi-mini-MoE-instruct", + tensor_parallel_size=int(os.getenv("TP_SIZE", "1")), + pipeline_parallel_size=int(os.getenv("PP_SIZE", "1")), + enable_expert_parallel=int(os.getenv("ENABLE_EP", "0")) == 1, + distributed_executor_backend="external_launcher", + gpu_memory_utilization=random.uniform(0.7, 0.9), + swap_space=random.randint(1, 4), + seed=0, +) outputs = llm.generate(prompts, sampling_params) @@ -54,21 +53,18 @@ def test_consistent_across_ranks(obj): dist.broadcast_object_list([obj], src=group.ranks[0], group=cpu_group) else: container = [None] - dist.broadcast_object_list(container, - src=group.ranks[0], - group=cpu_group) + dist.broadcast_object_list(container, src=group.ranks[0], group=cpu_group) assert container[0] == obj -test_consistent_across_ranks( - llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) -test_consistent_across_ranks( - llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_cpu_blocks) +test_consistent_across_ranks(llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) # make sure we can access the model parameters from the calling process # of the `LLM` instance. -params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner. - model.parameters()) +params = list( + llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters() +) test_consistent_across_ranks(len(params)) # all ranks should have the same outputs @@ -77,5 +73,4 @@ def test_consistent_across_ranks(obj): generated_text = output.outputs[0].text test_consistent_across_ranks(prompt) test_consistent_across_ranks(generated_text) - print(f"Rank {group_rank}, Prompt: {prompt!r}, " - f"Generated text: {generated_text!r}") + print(f"Rank {group_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 0287ad94e388..2a6936fcd4c2 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -10,21 +10,22 @@ import vllm.envs as envs from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup -from vllm.utils import (cuda_device_count_stateless, get_open_port, - update_environment_variables) +from vllm.utils import ( + cuda_device_count_stateless, + get_open_port, + update_environment_variables, +) from ..utils import multi_gpu_test @ray.remote class _CUDADeviceCountStatelessTestActor: - def get_count(self): return cuda_device_count_stateless() def set_cuda_visible_devices(self, cuda_visible_devices: str): - update_environment_variables( - {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) def get_cuda_visible_devices(self): return envs.CUDA_VISIBLE_DEVICES @@ -34,10 +35,9 @@ def test_cuda_device_count_stateless(): """Test that cuda_device_count_stateless changes return value if CUDA_VISIBLE_DEVICES is changed.""" actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore - num_gpus=2).remote() - assert len( - sorted(ray.get( - 
actor.get_cuda_visible_devices.remote()).split(","))) == 2 + num_gpus=2 + ).remote() + assert len(sorted(ray.get(actor.get_cuda_visible_devices.remote()).split(","))) == 2 assert ray.get(actor.get_count.remote()) == 2 ray.get(actor.set_cuda_visible_devices.remote("0")) assert ray.get(actor.get_count.remote()) == 1 @@ -46,15 +46,13 @@ def test_cuda_device_count_stateless(): def cpu_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) if rank <= 2: - pg2 = StatelessProcessGroup.create(host="127.0.0.1", - port=port2, - rank=rank, - world_size=3) + pg2 = StatelessProcessGroup.create( + host="127.0.0.1", port=port2, rank=rank, world_size=3 + ) data = torch.tensor([rank]) data = pg1.broadcast_obj(data, src=2) assert data.item() == 2 @@ -68,16 +66,14 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2): def gpu_worker(rank, WORLD_SIZE, port1, port2): torch.cuda.set_device(rank) - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) pynccl1 = PyNcclCommunicator(pg1, device=rank) if rank <= 2: - pg2 = StatelessProcessGroup.create(host="127.0.0.1", - port=port2, - rank=rank, - world_size=3) + pg2 = StatelessProcessGroup.create( + host="127.0.0.1", port=port2, rank=rank, world_size=3 + ) pynccl2 = PyNcclCommunicator(pg2, device=rank) data = torch.tensor([rank]).cuda() pynccl1.all_reduce(data) @@ -96,10 +92,9 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2): def broadcast_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) if rank == 2: pg1.broadcast_obj("secret", src=2) else: @@ -109,10 +104,9 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2): def allgather_worker(rank, WORLD_SIZE, port1, port2): - pg1 = StatelessProcessGroup.create(host="127.0.0.1", - port=port1, - rank=rank, - world_size=WORLD_SIZE) + pg1 = StatelessProcessGroup.create( + host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE + ) data = pg1.all_gather_obj(rank) assert data == list(range(WORLD_SIZE)) pg1.barrier() @@ -121,7 +115,8 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2): @pytest.mark.skip(reason="This test is flaky and prone to hang.") @multi_gpu_test(num_gpus=4) @pytest.mark.parametrize( - "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]) + "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker] +) def test_stateless_process_group(worker): port1 = get_open_port() with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -129,12 +124,14 @@ def test_stateless_process_group(worker): port2 = get_open_port() WORLD_SIZE = 4 from multiprocessing import get_context + ctx = get_context("fork") processes = [] for i in range(WORLD_SIZE): rank = i processes.append( - ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))) + ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2)) + ) for p in processes: p.start() for p in processes: diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 33888f008f04..9d367349fc2e 100644 --- a/tests/engine/test_arg_utils.py +++ 
b/tests/engine/test_arg_utils.py @@ -10,22 +10,30 @@ import pytest from vllm.config import CompilationConfig, config -from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, - get_type, get_type_hints, is_not_builtin, - is_type, literal_to_kwargs, optional_type, - parse_type) +from vllm.engine.arg_utils import ( + EngineArgs, + contains_type, + get_kwargs, + get_type, + get_type_hints, + is_not_builtin, + is_type, + literal_to_kwargs, + optional_type, + parse_type, +) from vllm.utils import FlexibleArgumentParser -@pytest.mark.parametrize(("type", "value", "expected"), [ - (int, "42", 42), - (float, "3.14", 3.14), - (str, "Hello World!", "Hello World!"), - (json.loads, '{"foo":1,"bar":2}', { - "foo": 1, - "bar": 2 - }), -]) +@pytest.mark.parametrize( + ("type", "value", "expected"), + [ + (int, "42", 42), + (float, "3.14", 3.14), + (str, "Hello World!", "Hello World!"), + (json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}), + ], +) def test_parse_type(type, value, expected): parse_type_func = parse_type(type) assert parse_type_func(value) == expected @@ -37,50 +45,56 @@ def test_optional_type(): assert optional_type_func("42") == 42 -@pytest.mark.parametrize(("type_hint", "type", "expected"), [ - (int, int, True), - (int, float, False), - (list[int], list, True), - (list[int], tuple, False), - (Literal[0, 1], Literal, True), -]) +@pytest.mark.parametrize( + ("type_hint", "type", "expected"), + [ + (int, int, True), + (int, float, False), + (list[int], list, True), + (list[int], tuple, False), + (Literal[0, 1], Literal, True), + ], +) def test_is_type(type_hint, type, expected): assert is_type(type_hint, type) == expected -@pytest.mark.parametrize(("type_hints", "type", "expected"), [ - ({float, int}, int, True), - ({int, tuple}, int, True), - ({int, tuple[int]}, int, True), - ({int, tuple[int, ...]}, int, True), - ({int, tuple[int]}, float, False), - ({int, tuple[int, ...]}, float, False), - ({str, Literal["x", "y"]}, Literal, True), -]) +@pytest.mark.parametrize( + ("type_hints", "type", "expected"), + [ + ({float, int}, int, True), + ({int, tuple}, int, True), + ({int, tuple[int]}, int, True), + ({int, tuple[int, ...]}, int, True), + ({int, tuple[int]}, float, False), + ({int, tuple[int, ...]}, float, False), + ({str, Literal["x", "y"]}, Literal, True), + ], +) def test_contains_type(type_hints, type, expected): assert contains_type(type_hints, type) == expected -@pytest.mark.parametrize(("type_hints", "type", "expected"), [ - ({int, float}, int, int), - ({int, float}, str, None), - ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]), -]) +@pytest.mark.parametrize( + ("type_hints", "type", "expected"), + [ + ({int, float}, int, int), + ({int, float}, str, None), + ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]), + ], +) def test_get_type(type_hints, type, expected): assert get_type(type_hints, type) == expected -@pytest.mark.parametrize(("type_hints", "expected"), [ - ({Literal[1, 2]}, { - "type": int, - "choices": [1, 2] - }), - ({str, Literal["x", "y"]}, { - "type": str, - "metavar": ["x", "y"] - }), - ({Literal[1, "a"]}, Exception), -]) +@pytest.mark.parametrize( + ("type_hints", "expected"), + [ + ({Literal[1, 2]}, {"type": int, "choices": [1, 2]}), + ({str, Literal["x", "y"]}, {"type": str, "metavar": ["x", "y"]}), + ({Literal[1, "a"]}, Exception), + ], +) def test_literal_to_kwargs(type_hints, expected): context = nullcontext() if expected is Exception: @@ -123,22 +137,27 @@ class DummyConfig: """Nested config""" 
-@pytest.mark.parametrize(("type_hint", "expected"), [ - (int, False), - (DummyConfig, True), -]) +@pytest.mark.parametrize( + ("type_hint", "expected"), + [ + (int, False), + (DummyConfig, True), + ], +) def test_is_not_builtin(type_hint, expected): assert is_not_builtin(type_hint) == expected @pytest.mark.parametrize( - ("type_hint", "expected"), [ + ("type_hint", "expected"), + [ (Annotated[int, "annotation"], {int}), (Optional[int], {int, type(None)}), (Annotated[Optional[int], "annotation"], {int, type(None)}), (Optional[Annotated[int, "annotation"]], {int, type(None)}), ], - ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"]) + ids=["Annotated", "Optional", "Annotated_Optional", "Optional_Annotated"], +) def test_get_type_hints(type_hint, expected): assert get_type_hints(type_hint) == expected @@ -178,24 +197,16 @@ def test_get_kwargs(): ("arg", "expected"), [ (None, dict()), - ('{"video": {"num_frames": 123} }', { - "video": { - "num_frames": 123 - } - }), + ('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}), ( '{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa { - "video": { - "num_frames": 123, - "fps": 1.0, - "foo": "bar" - }, - "image": { - "foo": "bar" - } - }), - ]) + "video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, + "image": {"foo": "bar"}, + }, + ), + ], +) def test_media_io_kwargs_parser(arg, expected): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: @@ -230,24 +241,32 @@ def test_compilation_config(): assert args.compilation_config.level == 3 # set to string form of a dict - args = parser.parse_args([ - "-O", - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' - '"use_inductor": false}', - ]) - assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] - and not args.compilation_config.use_inductor) + args = parser.parse_args( + [ + "-O", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": false}', + ] + ) + assert ( + args.compilation_config.level == 3 + and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and not args.compilation_config.use_inductor + ) # set to string form of a dict - args = parser.parse_args([ - "--compilation-config=" - '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' - '"use_inductor": true}', - ]) - assert (args.compilation_config.level == 3 and - args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] - and args.compilation_config.use_inductor) + args = parser.parse_args( + [ + "--compilation-config=" + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], ' + '"use_inductor": true}', + ] + ) + assert ( + args.compilation_config.level == 3 + and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8] + and args.compilation_config.use_inductor + ) def test_prefix_cache_default(): @@ -255,8 +274,7 @@ def test_prefix_cache_default(): args = parser.parse_args([]) engine_args = EngineArgs.from_cli_args(args=args) - assert (not engine_args.enable_prefix_caching - ), "prefix caching defaults to off." + assert not engine_args.enable_prefix_caching, "prefix caching defaults to off." # with flag to turn it on. 
args = parser.parse_args(["--enable-prefix-caching"]) @@ -269,29 +287,15 @@ def test_prefix_cache_default(): assert not engine_args.enable_prefix_caching -# yapf: disable -@pytest.mark.parametrize(("arg", "expected", "option"), [ - (None, None, "mm-processor-kwargs"), - ("{}", {}, "mm-processor-kwargs"), - ( - '{"num_crops": 4}', - { - "num_crops": 4 - }, - "mm-processor-kwargs" - ), - ( - '{"foo": {"bar": "baz"}}', - { - "foo": - { - "bar": "baz" - } - }, - "mm-processor-kwargs" - ), -]) -# yapf: enable +@pytest.mark.parametrize( + ("arg", "expected", "option"), + [ + (None, None, "mm-processor-kwargs"), + ("{}", {}, "mm-processor-kwargs"), + ('{"num_crops": 4}', {"num_crops": 4}, "mm-processor-kwargs"), + ('{"foo": {"bar": "baz"}}', {"foo": {"bar": "baz"}}, "mm-processor-kwargs"), + ], +) def test_composite_arg_parser(arg, expected, option): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: @@ -303,8 +307,7 @@ def test_composite_arg_parser(arg, expected, option): def test_human_readable_model_len(): # `exit_on_error` disabled to test invalid values below - parser = EngineArgs.add_cli_args( - FlexibleArgumentParser(exit_on_error=False)) + parser = EngineArgs.add_cli_args(FlexibleArgumentParser(exit_on_error=False)) args = parser.parse_args([]) assert args.max_model_len is None diff --git a/tests/engine/test_short_mm_context.py b/tests/engine/test_short_mm_context.py index 9eb3dfc09224..54a88586d8ed 100644 --- a/tests/engine/test_short_mm_context.py +++ b/tests/engine/test_short_mm_context.py @@ -5,12 +5,12 @@ from ..conftest import IMAGE_ASSETS -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "USER: <image>\nWhat's the content of the image?\nASSISTANT:", - "cherry_blossom": - "USER: <image>\nWhat is the season?\nASSISTANT:", -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "USER: <image>\nWhat's the content of the image?\nASSISTANT:", + "cherry_blossom": "USER: <image>\nWhat is the season?\nASSISTANT:", + } +) models = ["llava-hf/llava-1.5-7b-hf"] @@ -19,8 +19,7 @@ def test_context_length_too_short(vllm_runner, image_assets, model): images = [asset.pil_image for asset in image_assets] - with pytest.raises(ValueError, - match="longer than the maximum model length"): + with pytest.raises(ValueError, match="longer than the maximum model length"): vllm_model = vllm_runner( model, max_model_len=128, # LLaVA has a feature size of 576 @@ -29,6 +28,6 @@ def test_context_length_too_short(vllm_runner, image_assets, model): ) with vllm_model: - vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]], - max_tokens=1, - images=[images[0]]) + vllm_model.generate_greedy( + [HF_IMAGE_PROMPTS[0]], max_tokens=1, images=[images[0]] + ) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index 7daf62595b1b..a52e1cb7df33 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -26,8 +26,10 @@ def sample_token_ids(): @pytest.fixture def sample_regex(): - return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" - r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + return ( + r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" + ) @pytest.fixture @@ -35,40 +37,27 @@ def sample_json_schema(): return { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, "skills": { "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 + "items": {"type": "string", "maxLength": 10}, + 
"minItems": 3, }, "work_history": { "type": "array", "items": { "type": "object", "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } + "company": {"type": "string"}, + "duration": {"type": "number"}, + "position": {"type": "string"}, }, - "required": ["company", "position"] - } - } + "required": ["company", "position"], + }, + }, }, - "required": ["name", "age", "skills", "work_history"] + "required": ["name", "age", "skills", "work_history"], } @@ -80,65 +69,54 @@ def sample_complex_json_schema(): "score": { "type": "integer", "minimum": 0, - "maximum": 100 # Numeric range + "maximum": 100, # Numeric range }, "grade": { "type": "string", - "pattern": "^[A-D]$" # Regex pattern + "pattern": "^[A-D]$", # Regex pattern }, "email": { "type": "string", - "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$", }, "tags": { "type": "array", "items": { "type": "string", - "pattern": - "^[a-z]{1,10}$" # Combining length and pattern restrictions - } - } + # Combining length and pattern restrictions + "pattern": "^[a-z]{1,10}$", + }, + }, }, - "required": ["score", "grade", "email", "tags"] + "required": ["score", "grade", "email", "tags"], } @pytest.fixture def sample_definition_json_schema(): return { - '$defs': { - 'Step': { - 'properties': { - 'explanation': { - 'title': 'Explanation', - 'type': 'string' - }, - 'output': { - 'title': 'Output', - 'type': 'string' - } + "$defs": { + "Step": { + "properties": { + "explanation": {"title": "Explanation", "type": "string"}, + "output": {"title": "Output", "type": "string"}, }, - 'required': ['explanation', 'output'], - 'title': 'Step', - 'type': 'object' + "required": ["explanation", "output"], + "title": "Step", + "type": "object", } }, - 'properties': { - 'steps': { - 'items': { - '$ref': '#/$defs/Step' - }, - 'title': 'Steps', - 'type': 'array' + "properties": { + "steps": { + "items": {"$ref": "#/$defs/Step"}, + "title": "Steps", + "type": "array", }, - 'final_answer': { - 'title': 'Final Answer', - 'type': 'string' - } + "final_answer": {"title": "Final Answer", "type": "string"}, }, - 'required': ['steps', 'final_answer'], - 'title': 'MathReasoning', - 'type': 'object' + "required": ["steps", "final_answer"], + "title": "MathReasoning", + "type": "object", } @@ -149,64 +127,71 @@ def sample_enum_json_schema(): "properties": { "status": { "type": "string", - "enum": ["active", "inactive", - "pending"] # Literal values using enum + "enum": ["active", "inactive", "pending"], # Literal values using enum }, "priority": { "type": "string", - "enum": ["low", "medium", "high", "critical"] + "enum": ["low", "medium", "high", "critical"], }, "category": { "type": "object", "properties": { "type": { "type": "string", - "enum": ["bug", "feature", "improvement"] + "enum": ["bug", "feature", "improvement"], }, "severity": { "type": "integer", - "enum": [1, 2, 3, 4, - 5] # Enum can also contain numbers - } + "enum": [1, 2, 3, 4, 5], # Enum can also contain numbers + }, }, - "required": ["type", "severity"] + "required": ["type", "severity"], }, "flags": { "type": "array", "items": { "type": "string", - "enum": ["urgent", "blocked", "needs_review", "approved"] - } - } + "enum": ["urgent", "blocked", "needs_review", "approved"], + }, + }, }, - "required": ["status", "priority", "category", "flags"] + "required": ["status", "priority", "category", "flags"], } @pytest.fixture def 
sample_structured_outputs_choices(): return [ - "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", - "Ruby", "Swift", "Kotlin" + "Python", + "Java", + "JavaScript", + "C++", + "C#", + "PHP", + "TypeScript", + "Ruby", + "Swift", + "Kotlin", ] @pytest.fixture def sample_sql_statements(): - return (""" + return """ start: select_statement select_statement: "SELECT" column "from" table "where" condition column: "col_1" | "col_2" table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") +""" @pytest.fixture(scope="session") def zephyr_lora_files(): """Download zephyr LoRA files once per test session.""" from huggingface_hub import snapshot_download + return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora") @@ -214,5 +199,5 @@ def zephyr_lora_files(): def opt125_lora_files() -> str: """Download opt-125m LoRA files once per test session.""" from huggingface_hub import snapshot_download - return snapshot_download( - repo_id="peft-internal-testing/opt-125m-dummy-lora") + + return snapshot_download(repo_id="peft-internal-testing/opt-125m-dummy-lora") diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py index 5d605e906e81..e2d107b60586 100644 --- a/tests/entrypoints/llm/test_accuracy.py +++ b/tests/entrypoints/llm/test_accuracy.py @@ -48,20 +48,23 @@ def run_test(model_name, more_args=None): measured_value = results["results"][TASK][FILTER] assert model_name in EXPECTED_VALUES, ( - f"Cannot find the expected value for the model {model_name=}") + f"Cannot find the expected value for the model {model_name=}" + ) expected_value = EXPECTED_VALUES[model_name] - assert (measured_value - RTOL < expected_value - and measured_value + RTOL > expected_value - ), f"Expected: {expected_value} | Measured: {measured_value}" + assert ( + measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" # TODO: [AlexM] Fix it with new CI/CD tests -TPU_TP_TEST_STR = "" #"tensor_parallel_size=4" +TPU_TP_TEST_STR = "" # "tensor_parallel_size=4" -@pytest.mark.skipif(not current_platform.is_cuda() - and not current_platform.is_tpu(), - reason="V1 is currently only supported on CUDA and TPU") +@pytest.mark.skipif( + not current_platform.is_cuda() and not current_platform.is_tpu(), + reason="V1 is currently only supported on CUDA and TPU", +) @pytest.mark.parametrize("model", MODEL_NAMES) def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch): """Run with the V1 Engine.""" @@ -82,12 +85,14 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch): run_test(model, more_args) -@pytest.mark.skipif(not current_platform.is_cuda() - and not current_platform.is_tpu(), - reason="V1 is currently only supported on CUDA and TPU") +@pytest.mark.skipif( + not current_platform.is_cuda() and not current_platform.is_tpu(), + reason="V1 is currently only supported on CUDA and TPU", +) @pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES) def test_lm_eval_accuracy_v1_engine_fp8_kv_cache( - model, monkeypatch: pytest.MonkeyPatch): + model, monkeypatch: pytest.MonkeyPatch +): """Run with the V1 Engine.""" with monkeypatch.context() as m: diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index bf460d0fb25d..b2a958a992a6 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -14,9 +14,7 @@ def text_llm(): # pytest caches the fixture so we use weakref.proxy to # enable 
garbage collection - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - seed=0) + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, seed=0) yield weakref.proxy(llm) @@ -28,14 +26,8 @@ def text_llm(): def test_chat(text_llm): prompt1 = "Explain the concept of entropy." messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt1}, ] outputs = text_llm.chat(messages) assert len(outputs) == 1 @@ -46,25 +38,13 @@ def test_multi_chat(text_llm): prompt2 = "Explain what among us is." conversation1 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt1}, ] conversation2 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt2 - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": prompt2}, ] messages = [conversation1, conversation2] @@ -94,26 +74,22 @@ def vision_llm(): cleanup_dist_env_and_memory() -@pytest.mark.parametrize("image_urls", - [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], - indirect=True) +@pytest.mark.parametrize( + "image_urls", [[TEST_IMAGE_ASSETS[0], TEST_IMAGE_ASSETS[1]]], indirect=True +) def test_chat_multi_image(vision_llm, image_urls: list[str]): - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "image_url", - "image_url": { - "url": image_url - } - } for image_url in image_urls), - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "image_url", "image_url": {"url": image_url}} + for image_url in image_urls + ), + {"type": "text", "text": "What's in this image?"}, + ], + } + ] outputs = vision_llm.chat(messages) assert len(outputs) >= 0 @@ -124,14 +100,8 @@ def test_llm_chat_tokenization_no_double_bos(text_llm): Check we get a single BOS token for llama chat. """ messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "Hello!" - }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello!"}, ] outputs = text_llm.chat(messages) assert len(outputs) == 1 @@ -167,14 +137,8 @@ def thinking_llm(): @pytest.mark.parametrize("enable_thinking", [True, False]) def test_chat_extra_kwargs(thinking_llm, enable_thinking): messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": "What is 1+1?" 
- }, + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "What is 1+1?"}, ] outputs = thinking_llm.chat( diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py index 3a13f8c979f2..937aa5c13246 100644 --- a/tests/entrypoints/llm/test_collective_rpc.py +++ b/tests/entrypoints/llm/test_collective_rpc.py @@ -23,9 +23,11 @@ def echo_rank(self): return self.rank monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - load_format="dummy", - tensor_parallel_size=tp_size, - distributed_executor_backend=backend) + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + load_format="dummy", + tensor_parallel_size=tp_size, + distributed_executor_backend=backend, + ) assert llm.collective_rpc(echo_rank) == list(range(tp_size)) diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 5af4327b65d0..af9cc0afd26b 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -29,11 +29,13 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True, + ) yield weakref.proxy(llm) @@ -81,11 +83,13 @@ def test_max_model_len(): outputs = llm.generate(PROMPTS, sampling_params) for output in outputs: num_total_tokens = len(output.prompt_token_ids) + len( - output.outputs[0].token_ids) - # Total tokens must not exceed max_model_len. + output.outputs[0].token_ids + ) + # Total tokens must not exceed max_model_len + 1 (the last token can be + # generated with the context length equal to the max model length) # It can be less if generation finishes due to other reasons (e.g., EOS) # before reaching the absolute model length limit. 
- assert num_total_tokens <= max_model_len + assert num_total_tokens <= max_model_len + 1 def test_log_stats(): diff --git a/tests/entrypoints/llm/test_gpu_utilization.py b/tests/entrypoints/llm/test_gpu_utilization.py index 533da9e6d6ea..896091533ad2 100644 --- a/tests/entrypoints/llm/test_gpu_utilization.py +++ b/tests/entrypoints/llm/test_gpu_utilization.py @@ -16,9 +16,8 @@ def test_gpu_memory_utilization(): # makes sure gpu_memory_utilization is per-instance limit, # not a global limit llms = [ - LLM(model="facebook/opt-125m", - gpu_memory_utilization=0.3, - enforce_eager=True) for i in range(3) + LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3, enforce_eager=True) + for i in range(3) ] for llm in llms: outputs = llm.generate(prompts, sampling_params) diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py index b219b33d1760..81126a4f16f9 100644 --- a/tests/entrypoints/llm/test_prompt_validation.py +++ b/tests/entrypoints/llm/test_prompt_validation.py @@ -8,12 +8,12 @@ def test_empty_prompt(): llm = LLM(model="openai-community/gpt2", enforce_eager=True) - with pytest.raises(ValueError, match='decoder prompt cannot be empty'): + with pytest.raises(ValueError, match="decoder prompt cannot be empty"): llm.generate([""]) @pytest.mark.skip_v1 def test_out_of_vocab_token(): llm = LLM(model="openai-community/gpt2", enforce_eager=True) - with pytest.raises(ValueError, match='out of vocabulary'): + with pytest.raises(ValueError, match="out of vocabulary"): llm.generate({"prompt_token_ids": [999999]}) diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index f8ed5dda260f..25e663f3af0e 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for HF_HUB_OFFLINE mode""" + import dataclasses import importlib import sys @@ -91,12 +92,11 @@ def disable_connect(*args, **kwargs): def _re_import_modules(): - hf_hub_module_names = [ - k for k in sys.modules if k.startswith("huggingface_hub") - ] + hf_hub_module_names = [k for k in sys.modules if k.startswith("huggingface_hub")] transformers_module_names = [ - k for k in sys.modules if k.startswith("transformers") - and not k.startswith("transformers_modules") + k + for k in sys.modules + if k.startswith("transformers") and not k.startswith("transformers_modules") ] reload_exception = None diff --git a/tests/entrypoints/openai/conftest.py b/tests/entrypoints/openai/conftest.py index 0ecdd4245df4..b40079d8dc3d 100644 --- a/tests/entrypoints/openai/conftest.py +++ b/tests/entrypoints/openai/conftest.py @@ -7,14 +7,14 @@ @pytest.fixture def mary_had_lamb(): - path = AudioAsset('mary_had_lamb').get_local_path() + path = AudioAsset("mary_had_lamb").get_local_path() with open(str(path), "rb") as f: yield f @pytest.fixture def winning_call(): - path = AudioAsset('winning_call').get_local_path() + path = AudioAsset("winning_call").get_local_path() with open(str(path), "rb") as f: yield f @@ -22,6 +22,6 @@ def winning_call(): @pytest.fixture def foscolo(): # Test translation it->en - path = AudioAsset('azacinto_foscolo').get_local_path() + path = AudioAsset("azacinto_foscolo").get_local_path() with open(str(path), "rb") as f: yield f diff --git a/tests/entrypoints/openai/correctness/test_lmeval.py 
b/tests/entrypoints/openai/correctness/test_lmeval.py index 624acd5ffde7..919b7793628e 100644 --- a/tests/entrypoints/openai/correctness/test_lmeval.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -44,14 +44,15 @@ def run_test(more_args): print(f"Running with: {args}") with RemoteOpenAIServer( - MODEL_NAME, args, - max_wait_seconds=MAX_WAIT_SECONDS) as remote_server: + MODEL_NAME, args, max_wait_seconds=MAX_WAIT_SECONDS + ) as remote_server: url = f"{remote_server.url_for('v1')}/completions" model_args = ( f"model={MODEL_NAME}," f"base_url={url}," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False" + ) results = lm_eval.simple_evaluate( model="local-completions", @@ -60,15 +61,18 @@ def run_test(more_args): ) measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" - - -@pytest.mark.skipif(not current_platform.is_cuda() - and not current_platform.is_tpu() - and not current_platform.is_xpu(), - reason="V1 currently only supported on CUDA, XPU and TPU") + assert ( + measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + +@pytest.mark.skipif( + not current_platform.is_cuda() + and not current_platform.is_tpu() + and not current_platform.is_xpu(), + reason="V1 currently only supported on CUDA, XPU and TPU", +) def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): """Run with the V1 Engine.""" diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 9122b7003bf9..7821ade63ac3 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -7,6 +7,7 @@ This simulates real work usage of the API and makes sure that the frontend and AsyncLLMEngine are working correctly. 
""" + import asyncio import io import time @@ -45,7 +46,8 @@ async def transcribe_audio(client, tokenizer, y, sr): # NOTE there's no streaming in transcriptions, can't measure ttft latency = end_time - start_time num_output_tokens = len( - tokenizer(transcription.text, add_special_tokens=False).input_ids) + tokenizer(transcription.text, add_special_tokens=False).input_ids + ) return latency, num_output_tokens, transcription.text @@ -73,8 +75,8 @@ async def process_dataset(model, client, data, concurrent_request): for sample in data: audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] task = asyncio.create_task( - bound_transcribe(sem, client, tokenizer, (audio, sr), - sample["text"])) + bound_transcribe(sem, client, tokenizer, (audio, sr), sample["text"]) + ) tasks.append(task) return await asyncio.gather(*tasks) @@ -98,34 +100,35 @@ def print_performance_metrics(results, total_time): def add_duration(sample): - y, sr = sample['audio']["array"], sample['audio']["sampling_rate"] - sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000 + y, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"] + sample["duration_ms"] = librosa.get_duration(y=y, sr=sr) * 1000 return sample -def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs): +def load_hf_dataset(dataset_repo: str, split="validation", **hf_kwargs): ## Load and filter the dataset dataset = load_dataset(dataset_repo, split=split, **hf_kwargs) - if 'duration_ms' not in dataset[0]: + if "duration_ms" not in dataset[0]: # compute duration to filter dataset = dataset.map(add_duration) # Whisper max supported duration - dataset = dataset.filter(lambda example: example['duration_ms'] < 30000) + dataset = dataset.filter(lambda example: example["duration_ms"] < 30000) return dataset -def run_evaluation(model: str, - client, - dataset, - max_concurrent_reqs: int, - n_examples: int = -1, - print_metrics: bool = True): +def run_evaluation( + model: str, + client, + dataset, + max_concurrent_reqs: int, + n_examples: int = -1, + print_metrics: bool = True, +): if n_examples > 0: dataset = dataset.select(range(n_examples)) start = time.perf_counter() - results = asyncio.run( - process_dataset(model, client, dataset, max_concurrent_reqs)) + results = asyncio.run(process_dataset(model, client, dataset, max_concurrent_reqs)) end = time.perf_counter() total_time = end - start print(f"Total Test Time: {total_time:.4f} seconds") @@ -135,8 +138,7 @@ def run_evaluation(model: str, predictions = [res[2] for res in results] references = [res[3] for res in results] wer = load("wer") - wer_score = 100 * wer.compute(references=references, - predictions=predictions) + wer_score = 100 * wer.compute(references=references, predictions=predictions) print("WER:", wer_score) return wer_score @@ -145,26 +147,25 @@ def run_evaluation(model: str, @pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) # Original dataset is 20GB+ in size, hence we use a pre-filtered slice. @pytest.mark.parametrize( - "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]) + "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"] +) # NOTE: Expected WER measured with equivalent hf.transformers args: # whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered. 
@pytest.mark.parametrize("expected_wer", [12.744980]) -def test_wer_correctness(model_name, - dataset_repo, - expected_wer, - n_examples=-1, - max_concurrent_request=None): +def test_wer_correctness( + model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None +): # TODO refactor to use `ASRDataset` - with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server: + with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server: dataset = load_hf_dataset(dataset_repo) if not max_concurrent_request: # No max concurrency - max_concurrent_request = n_examples if n_examples > 0\ - else len(dataset) + max_concurrent_request = n_examples if n_examples > 0 else len(dataset) client = remote_server.get_async_client() - wer = run_evaluation(model_name, client, dataset, - max_concurrent_request, n_examples) + wer = run_evaluation( + model_name, client, dataset, max_concurrent_request, n_examples + ) if expected_wer: torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) diff --git a/tests/entrypoints/openai/test_async_tokenization.py b/tests/entrypoints/openai/test_async_tokenization.py index 80261597b11a..5df859df42da 100644 --- a/tests/entrypoints/openai/test_async_tokenization.py +++ b/tests/entrypoints/openai/test_async_tokenization.py @@ -44,15 +44,11 @@ async def client(server): ids=["completion", "chat"], argnames=["create_func_gen", "content_body"], argvalues=[ - (lambda x: x.completions.create, { - "prompt": " ".join(['A'] * 10_000) - }), - (lambda x: x.chat.completions.create, { - "messages": [{ - "role": "user", - "content": " ".join(['A'] * 10_000) - }] - }), + (lambda x: x.completions.create, {"prompt": " ".join(["A"] * 10_000)}), + ( + lambda x: x.chat.completions.create, + {"messages": [{"role": "user", "content": " ".join(["A"] * 10_000)}]}, + ), ], ) async def test_with_and_without_truncate( @@ -65,15 +61,15 @@ async def test_with_and_without_truncate( body = {"model": MODEL_NAME, **content_body, "max_tokens": 10} num_requests = 10 - truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] * - (num_requests - num_requests // 2)) + truncate_prompt_tokens = [1000] * (num_requests // 2) + [None] * ( + num_requests - num_requests // 2 + ) random.shuffle(truncate_prompt_tokens) - bodies = [{ - **body, "extra_body": { - 'truncate_prompt_tokens': t - } - } for t in truncate_prompt_tokens] + bodies = [ + {**body, "extra_body": {"truncate_prompt_tokens": t}} + for t in truncate_prompt_tokens + ] async def get_status_code(**kwargs): try: diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 2d33d3c3a6b5..a96f0134c2ff 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -56,24 +56,18 @@ def base64_encoded_audio() -> dict[str, str]: @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) -async def test_single_chat_session_audio(client: openai.AsyncOpenAI, - model_name: str, audio_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] +async def test_single_chat_session_audio( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -82,13 +76,15 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=202, total_tokens=212 + ) message = choice.message message = chat_completion.choices[0].message @@ -110,56 +106,52 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) -async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI, - model_name: str, - audio_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": audio_url - }, - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] +async def test_error_on_invalid_audio_url_type( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": audio_url}, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # audio_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0) + _ = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) async def test_single_chat_session_audio_base64encoded( - client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: dict[str, str]): - - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": - f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" - } - }, - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + audio_url: str, + base64_encoded_audio: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": { + "url": f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -168,13 +160,15 @@ async def test_single_chat_session_audio_base64encoded( max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=202, total_tokens=212 + ) message = choice.message message = chat_completion.choices[0].message @@ -198,25 +192,26 @@ async def test_single_chat_session_audio_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]]) async def test_single_chat_session_input_audio( - client: openai.AsyncOpenAI, model_name: str, audio_url: str, - base64_encoded_audio: dict[str, str]): - messages = [{ - "role": - "user", - "content": [ - { - "type": "input_audio", - "input_audio": { - "data": base64_encoded_audio[audio_url], - "format": "wav" - } - }, - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + audio_url: str, + base64_encoded_audio: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": base64_encoded_audio[audio_url], + "format": "wav", + }, + }, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -224,13 +219,15 @@ async def test_single_chat_session_input_audio( messages=messages, max_completion_tokens=10, logprobs=True, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=202, total_tokens=212) + completion_tokens=10, prompt_tokens=202, total_tokens=212 + ) message = choice.message message = chat_completion.choices[0].message @@ -252,24 +249,18 @@ async def test_single_chat_session_input_audio( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) -async def test_chat_streaming_audio(client: openai.AsyncOpenAI, - model_name: str, audio_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] +async def test_chat_streaming_audio( + client: openai.AsyncOpenAI, model_name: str, audio_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": audio_url}}, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -309,27 +300,27 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS) -async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, - model_name: str, audio_url: str, - base64_encoded_audio: dict[str, - str]): - messages = [{ - "role": - "user", - "content": [ - { - "type": "input_audio", - "input_audio": { - "data": base64_encoded_audio[audio_url], - "format": "wav" - } - }, - { - "type": "text", - "text": "What's happening in this audio?" - }, - ], - }] +async def test_chat_streaming_input_audio( + client: openai.AsyncOpenAI, + model_name: str, + audio_url: str, + base64_encoded_audio: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": base64_encoded_audio[audio_url], + "format": "wav", + }, + }, + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -369,26 +360,23 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( - "audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]]) -async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str, - audio_urls: list[str]): - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "audio_url", - "audio_url": { - "url": audio_url - } - } for audio_url in audio_urls), - { - "type": "text", - "text": "What's happening in this audio?" 
- }, - ], - }] + "audio_urls", [TEST_AUDIO_URLS, TEST_AUDIO_URLS + [TEST_AUDIO_URLS[0]]] +) +async def test_multi_audio_input( + client: openai.AsyncOpenAI, model_name: str, audio_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "audio_url", "audio_url": {"url": audio_url}} + for audio_url in audio_urls + ), + {"type": "text", "text": "What's happening in this audio?"}, + ], + } + ] if len(audio_urls) > MAXIMUM_AUDIOS: with pytest.raises(openai.BadRequestError): # test multi-audio input diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index a55941976cd8..50ec87b4464f 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -16,9 +16,9 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def server_args(request: pytest.FixtureRequest) -> list[str]: - """ Provide extra arguments to the server via indirect parametrization + """Provide extra arguments to the server via indirect parametrization Usage: @@ -80,8 +80,10 @@ async def client(server): "server_args", [ pytest.param([], id="default-frontend-multiprocessing"), - pytest.param(["--disable-frontend-multiprocessing"], - id="disable-frontend-multiprocessing") + pytest.param( + ["--disable-frontend-multiprocessing"], + id="disable-frontend-multiprocessing", + ), ], indirect=True, ) @@ -97,8 +99,10 @@ async def test_show_version(server: RemoteOpenAIServer): "server_args", [ pytest.param([], id="default-frontend-multiprocessing"), - pytest.param(["--disable-frontend-multiprocessing"], - id="disable-frontend-multiprocessing") + pytest.param( + ["--disable-frontend-multiprocessing"], + id="disable-frontend-multiprocessing", + ), ], indirect=True, ) @@ -112,11 +116,13 @@ async def test_check_health(server: RemoteOpenAIServer): @pytest.mark.parametrize( "server_args", [ - pytest.param(["--max-model-len", "10100"], - id="default-frontend-multiprocessing"), + pytest.param( + ["--max-model-len", "10100"], id="default-frontend-multiprocessing" + ), pytest.param( ["--disable-frontend-multiprocessing", "--max-model-len", "10100"], - id="disable-frontend-multiprocessing") + id="disable-frontend-multiprocessing", + ), ], indirect=True, ) @@ -131,14 +137,16 @@ async def test_request_cancellation(server: RemoteOpenAIServer): # Request about 2 million tokens for _ in range(200): task = asyncio.create_task( - client.chat.completions.create(messages=chat_input, - model=MODEL_NAME, - max_tokens=10000, - extra_body={"min_tokens": 10000})) + client.chat.completions.create( + messages=chat_input, + model=MODEL_NAME, + max_tokens=10000, + extra_body={"min_tokens": 10000}, + ) + ) tasks.append(task) - done, pending = await asyncio.wait(tasks, - return_when=asyncio.ALL_COMPLETED) + done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) # Make sure all requests were sent to the server and timed out # (We don't want to hide other errors like 400s that would invalidate this @@ -151,16 +159,15 @@ async def test_request_cancellation(server: RemoteOpenAIServer): # If the server had not cancelled all the other requests, then it would not # be able to respond to this one within the timeout client = server.get_async_client(timeout=5) - response = await client.chat.completions.create(messages=chat_input, - model=MODEL_NAME, - max_tokens=10) + response = await client.chat.completions.create( + messages=chat_input, model=MODEL_NAME, max_tokens=10 + ) assert 
len(response.choices) == 1 @pytest.mark.asyncio async def test_request_wrong_content_type(server: RemoteOpenAIServer): - chat_input = [{"role": "user", "content": "Write a long story"}] client = server.get_async_client() @@ -169,17 +176,13 @@ async def test_request_wrong_content_type(server: RemoteOpenAIServer): messages=chat_input, model=MODEL_NAME, max_tokens=10000, - extra_headers={ - "Content-Type": "application/x-www-form-urlencoded" - }) + extra_headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) @pytest.mark.parametrize( "server_args", - [ - pytest.param(["--enable-server-load-tracking"], - id="enable-server-load-tracking") - ], + [pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")], indirect=True, ) @pytest.mark.asyncio @@ -202,7 +205,8 @@ def make_long_completion_request(): # Start the completion request in a background thread. completion_future = asyncio.create_task( - asyncio.to_thread(make_long_completion_request)) + asyncio.to_thread(make_long_completion_request) + ) # Give a short delay to ensure the request has started. await asyncio.sleep(0.1) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 3bdfef7b4adb..ed0b284bda62 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -23,14 +23,15 @@ @pytest.fixture(scope="module") def monkeypatch_module(): from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() yield mpatch mpatch.undo() @pytest.fixture(scope="module") -def server(monkeypatch_module, zephyr_lora_files): #noqa: F811 - monkeypatch_module.setenv('VLLM_USE_V1', '1') +def server(monkeypatch_module, zephyr_lora_files): # noqa: F811 + monkeypatch_module.setenv("VLLM_USE_V1", "1") args = [ # use half precision for speed and memory savings in CI environment @@ -68,20 +69,18 @@ async def client(server): [MODEL_NAME, "zephyr-lora"], ) async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] chat_completion = await client.chat.completions.create( model=model_name, messages=messages, max_completion_tokens=5, temperature=0.0, - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.logprobs is None @@ -94,13 +93,10 @@ async def test_no_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): [MODEL_NAME, "zephyr-lora"], ) async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -108,7 +104,8 @@ async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): max_completion_tokens=5, temperature=0.0, logprobs=True, - top_logprobs=0) + top_logprobs=0, + ) choice = chat_completion.choices[0] assert choice.logprobs is not None @@ -122,13 +119,10 @@ async def test_zero_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): [MODEL_NAME, "zephyr-lora"], ) async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -136,7 +130,8 @@ async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): max_completion_tokens=5, temperature=0.0, logprobs=True, - top_logprobs=5) + top_logprobs=5, + ) choice = chat_completion.choices[0] assert choice.logprobs is not None @@ -149,41 +144,39 @@ async def test_some_logprobs_chat(client: openai.AsyncOpenAI, model_name: str): "model_name", [MODEL_NAME, "zephyr-lora"], ) -async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] +async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, model_name: str): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] # Default max_logprobs is 20, so this should raise an error with pytest.raises((openai.BadRequestError, openai.APIError)): - stream = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - logprobs=True, - top_logprobs=21, - stream=True) + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + logprobs=True, + top_logprobs=21, + stream=True, + ) async for chunk in stream: ... 
with pytest.raises(openai.BadRequestError): - await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - logprobs=True, - top_logprobs=30, - stream=False) + await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + logprobs=True, + top_logprobs=30, + stream=False, + ) # the server should still work afterwards chat_completion = await client.chat.completions.create( - model=model_name, - messages=messages, - max_completion_tokens=10, - stream=False) + model=model_name, messages=messages, max_completion_tokens=10, stream=False + ) message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 0 @@ -193,27 +186,20 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, "model_name, prompt_logprobs", [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)], ) -async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str, - prompt_logprobs: Optional[int]): +async def test_prompt_logprobs_chat( + client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int] +): params: dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" - }], - "model": - model_name + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + {"role": "user", "content": "Where was it played?"}, + ], + "model": model_name, } if prompt_logprobs is not None: @@ -236,29 +222,21 @@ async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str): +async def test_more_than_one_prompt_logprobs_chat( + client: openai.AsyncOpenAI, model_name: str +): params: dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" - }], - "model": - model_name, - "extra_body": { - "prompt_logprobs": 1 - } + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020?"}, + { + "role": "assistant", + "content": "The Los Angeles Dodgers won the World Series in 2020.", + }, + {"role": "user", "content": "Where was it played?"}, + ], + "model": model_name, + "extra_body": {"prompt_logprobs": 1}, } completion_1 = await client.chat.completions.create(**params) @@ -275,15 +253,11 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME, "zephyr-lora"], ) -async def test_single_chat_session(client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" 
- }] +async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] # test single completion chat_completion = await client.chat.completions.create( @@ -291,14 +265,16 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, messages=messages, max_completion_tokens=10, logprobs=True, - top_logprobs=5) + top_logprobs=5, + ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=37, total_tokens=47) + completion_tokens=10, prompt_tokens=37, total_tokens=47 + ) message = choice.message assert message.content is not None and len(message.content) >= 10 @@ -323,13 +299,10 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, [MODEL_NAME, "zephyr-lora"], ) async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] # test single completion chat_completion = await client.chat.completions.create( @@ -371,15 +344,13 @@ async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str): "model_name", ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"], ) -async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is the capital of France?" 
- }] +async def test_chat_completion_stream_options( + client: openai.AsyncOpenAI, model_name: str +): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ] # Test stream=True, stream_options={"include_usage": False} stream = await client.chat.completions.create( @@ -388,23 +359,21 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, max_completion_tokens=10, temperature=0.0, stream=True, - stream_options={"include_usage": False}) + stream_options={"include_usage": False}, + ) async for chunk in stream: assert chunk.usage is None # Test stream=True, stream_options={"include_usage": True, # "continuous_usage_stats": False}} - stream = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": - True, - "continuous_usage_stats": - False - }) + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + stream=True, + stream_options={"include_usage": True, "continuous_usage_stats": False}, + ) async for chunk in stream: if chunk.choices[0].finish_reason is None: @@ -416,8 +385,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) + final_chunk.usage.prompt_tokens + final_chunk.usage.completion_tokens + ) assert final_chunk.choices == [] # Test stream=False, stream_options={"include_usage": None} @@ -428,7 +397,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, max_completion_tokens=10, temperature=0.0, stream=False, - stream_options={"include_usage": None}) + stream_options={"include_usage": None}, + ) # Test stream=False, stream_options={"include_usage": True} with pytest.raises(BadRequestError): @@ -438,7 +408,8 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, max_completion_tokens=10, temperature=0.0, stream=False, - stream_options={"include_usage": True}) + stream_options={"include_usage": True}, + ) # Test stream=True, stream_options={"include_usage": True, # "continuous_usage_stats": True} @@ -457,14 +428,17 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, last_completion_tokens = 0 async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 - assert last_completion_tokens == 0 or \ - chunk.usage.completion_tokens > last_completion_tokens or \ - ( - not chunk.choices and - chunk.usage.completion_tokens == last_completion_tokens - ) - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert ( + last_completion_tokens == 0 + or chunk.usage.completion_tokens > last_completion_tokens + or ( + not chunk.choices + and chunk.usage.completion_tokens == last_completion_tokens + ) + ) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) last_completion_tokens = chunk.usage.completion_tokens assert last_completion_tokens == 10 @@ -475,37 +449,36 @@ async def test_structured_outputs_choice_chat( client: openai.AsyncOpenAI, sample_structured_outputs_choices, ): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - 
"The best language for type-safe systems programming is " - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": "The best language for type-safe systems programming is ", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=10, temperature=0.7, extra_body=dict( - structured_outputs={"choice": sample_structured_outputs_choices})) + structured_outputs={"choice": sample_structured_outputs_choices} + ), + ) choice1 = chat_completion.choices[0].message.content assert choice1 in sample_structured_outputs_choices messages.append({"role": "assistant", "content": choice1}) - messages.append({ - "role": "user", - "content": "I disagree, pick another one" - }) + messages.append({"role": "user", "content": "I disagree, pick another one"}) chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=10, temperature=0.7, extra_body=dict( - structured_outputs={"choice": sample_structured_outputs_choices})) + structured_outputs={"choice": sample_structured_outputs_choices} + ), + ) choice2 = chat_completion.choices[0].message.content assert choice2 in sample_structured_outputs_choices assert choice1 != choice2 @@ -516,38 +489,35 @@ async def test_structured_outputs_json_chat( client: openai.AsyncOpenAI, sample_json_schema, ): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example JSON for an employee profile that " + f"fits this schema: {sample_json_schema}", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - extra_body=dict(structured_outputs={"json": sample_json_schema})) + extra_body=dict(structured_outputs={"json": sample_json_schema}), + ) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) jsonschema.validate(instance=json1, schema=sample_json_schema) messages.append({"role": "assistant", "content": message.content}) - messages.append({ - "role": - "user", - "content": - "Give me another one with a different name and age" - }) + messages.append( + {"role": "user", "content": "Give me another one with a different name and age"} + ) chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - extra_body=dict(structured_outputs={"json": sample_json_schema})) + extra_body=dict(structured_outputs={"json": sample_json_schema}), + ) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -561,21 +531,19 @@ async def test_structured_outputs_regex_chat( client: openai.AsyncOpenAI, sample_regex, ): - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example IP address with this regex: {sample_regex}" - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example IP address with this regex: {sample_regex}", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, 
max_completion_tokens=20, - extra_body=dict(structured_outputs={"regex": sample_regex})) + extra_body=dict(structured_outputs={"regex": sample_regex}), + ) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(sample_regex, ip1) is not None @@ -586,7 +554,8 @@ async def test_structured_outputs_regex_chat( model=MODEL_NAME, messages=messages, max_completion_tokens=20, - extra_body=dict(structured_outputs={"regex": sample_regex})) + extra_body=dict(structured_outputs={"regex": sample_regex}), + ) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(sample_regex, ip2) is not None @@ -595,40 +564,33 @@ async def test_structured_outputs_regex_chat( @pytest.mark.asyncio async def test_structured_outputs_type_error(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - "The best language for type-safe systems programming is " - }] + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": "The best language for type-safe systems programming is ", + }, + ] with pytest.raises(openai.BadRequestError): _ = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - extra_body=dict( - structured_outputs={"regex": { - 1: "Python", - 2: "C++" - }})) + extra_body=dict(structured_outputs={"regex": {1: "Python", 2: "C++"}}), + ) @pytest.mark.asyncio async def test_structured_outputs_choice_chat_logprobs( - client: openai.AsyncOpenAI, sample_structured_outputs_choices): - - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - "The best language for type-safe systems programming is " - }] + client: openai.AsyncOpenAI, sample_structured_outputs_choices +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": "The best language for type-safe systems programming is ", + }, + ] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, @@ -636,7 +598,9 @@ async def test_structured_outputs_choice_chat_logprobs( logprobs=True, top_logprobs=5, extra_body=dict( - structured_outputs={"choice": sample_structured_outputs_choices})) + structured_outputs={"choice": sample_structured_outputs_choices} + ), + ) assert chat_completion.choices[0].logprobs is not None assert chat_completion.choices[0].logprobs.content is not None @@ -652,29 +616,26 @@ async def test_named_tool_use( client: openai.AsyncOpenAI, sample_json_schema, ): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": ("Give an example JSON for an employee " - "profile using the specified tool.") - }] - tools = [{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema - } - }] - tool_choice = { - "type": "function", - "function": { - "name": "dummy_function_name" + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": ( + "Give an example JSON for an employee profile using the specified tool." 
+ ), + }, + ] + tools = [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, } - } + ] + tool_choice = {"type": "function", "function": {"name": "dummy_function_name"}} # non-streaming @@ -692,21 +653,20 @@ async def test_named_tool_use( jsonschema.validate(instance=json1, schema=sample_json_schema) messages.append({"role": "assistant", "content": json_string}) - messages.append({ - "role": - "user", - "content": - "Give me another one with a different name and age" - }) + messages.append( + {"role": "user", "content": "Give me another one with a different name and age"} + ) # streaming - stream = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tools=tools, - tool_choice=tool_choice, - stream=True) + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=tools, + tool_choice=tool_choice, + stream=True, + ) output = [] finish_reason_count = 0 @@ -728,64 +688,66 @@ async def test_named_tool_use( @pytest.mark.asyncio -async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI, - sample_json_schema): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": - "user", - "content": - f"Give an example JSON for an employee profile that " - f"fits this schema: {sample_json_schema}" - }] +async def test_inconsistent_tool_choice_and_tools( + client: openai.AsyncOpenAI, sample_json_schema +): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + { + "role": "user", + "content": f"Give an example JSON for an employee profile that " + f"fits this schema: {sample_json_schema}", + }, + ] with pytest.raises(openai.BadRequestError): - await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_completion_tokens=1000, - tool_choice={ - "type": "function", - "function": { - "name": - "dummy_function_name" - } - }) + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tool_choice={ + "type": "function", + "function": {"name": "dummy_function_name"}, + }, + ) with pytest.raises(openai.BadRequestError): await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, } - }], + ], tool_choice={ "type": "function", - "function": { - "name": "nondefined_function_name" - } - }) + "function": {"name": "nondefined_function_name"}, + }, + ) with pytest.raises(openai.BadRequestError): await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_completion_tokens=1000, - tools=[{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema + tools=[ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, } - }], - tool_choice={}) + ], + tool_choice={}, + ) @pytest.mark.asyncio @@ -793,13 +755,17 @@ async def test_response_format_json_object(client: 
openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": - "user", - "content": ('what is 1+1? please respond with a JSON object, ' - 'the format is {"result": 2}') - }], - response_format={"type": "json_object"}) + messages=[ + { + "role": "user", + "content": ( + "what is 1+1? please respond with a JSON object, " + 'the format is {"result": 2}' + ), + } + ], + response_format={"type": "json_object"}, + ) content = resp.choices[0].message.content assert content is not None @@ -815,10 +781,7 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": prompt - }], + messages=[{"role": "user", "content": prompt}], ) content = resp.choices[0].message.content assert content is not None @@ -829,10 +792,7 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": prompt - }], + messages=[{"role": "user", "content": prompt}], response_format={ "type": "json_schema", "json_schema": { @@ -840,13 +800,12 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): "schema": { "type": "object", "properties": { - "result": { - "type": "integer" - }, + "result": {"type": "integer"}, }, }, - } - }) + }, + }, + ) content = resp.choices[0].message.content assert content is not None @@ -859,13 +818,16 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): async def test_extra_fields_allowed(client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?", - "extra_field": "0", - }], # type: ignore + messages=[ + { + "role": "user", + "content": "what is 1+1?", + "extra_field": "0", + } + ], # type: ignore temperature=0, - seed=0) + seed=0, + ) content = resp.choices[0].message.content assert content is not None @@ -873,20 +835,23 @@ async def test_extra_fields_allowed(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_complex_message_content(client: openai.AsyncOpenAI): + content = [ + { + "type": "text", + "text": "what is 1+1? please provide the result without any other text.", + } + ] resp = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": - "user", - "content": [{ - "type": - "text", - "text": - "what is 1+1? please provide the result without any other text." - }] - }], + messages=[ + { + "role": "user", + "content": content, + } + ], temperature=0, - seed=0) + seed=0, + ) content = resp.choices[0].message.content assert content == "2" @@ -898,24 +863,27 @@ async def test_custom_role(client: openai.AsyncOpenAI): resp1 = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "my-custom-role", - "content": "what is 1+1?", - }], # type: ignore + messages=[ + { + "role": "my-custom-role", + "content": "what is 1+1?", + } + ], # type: ignore temperature=0, - seed=0) + seed=0, + ) resp2 = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "my-custom-role", - "content": [{ - "type": "text", - "text": "what is 1+1?" 
- }] - }], # type: ignore + messages=[ + { + "role": "my-custom-role", + "content": [{"type": "text", "text": "what is 1+1?"}], + } + ], # type: ignore temperature=0, - seed=0) + seed=0, + ) content1 = resp1.choices[0].message.content content2 = resp2.choices[0].message.content @@ -924,34 +892,32 @@ async def test_custom_role(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_long_seed(client: openai.AsyncOpenAI): - for seed in [ - torch.iinfo(torch.long).min - 1, - torch.iinfo(torch.long).max + 1 - ]: + for seed in [torch.iinfo(torch.long).min - 1, torch.iinfo(torch.long).max + 1]: with pytest.raises(BadRequestError) as exc_info: await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "system", - "content": "You are a helpful assistant.", - }], + messages=[ + { + "role": "system", + "content": "You are a helpful assistant.", + } + ], temperature=0, - seed=seed) + seed=seed, + ) - assert ("greater_than_equal" in exc_info.value.message - or "less_than_equal" in exc_info.value.message) + assert ( + "greater_than_equal" in exc_info.value.message + or "less_than_equal" in exc_info.value.message + ) @pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "what is 1+1?" - }] +async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): + messages = [ + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "what is 1+1?"}, + ] request_args = { "model": MODEL_NAME, @@ -963,8 +929,9 @@ async def test_invocations(server: RemoteOpenAIServer, chat_completion = await client.chat.completions.create(**request_args) - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() chat_output = chat_completion.model_dump() diff --git a/tests/entrypoints/openai/test_chat_echo.py b/tests/entrypoints/openai/test_chat_echo.py index ce965eb82924..a9c9c8e3dfe8 100644 --- a/tests/entrypoints/openai/test_chat_echo.py +++ b/tests/entrypoints/openai/test_chat_echo.py @@ -23,7 +23,7 @@ def server(): "--max-model-len", "4080", "--max-logprobs", # test prompt_logprobs equal to -1 - "151936" + "151936", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -46,27 +46,26 @@ class TestCase(NamedTuple): "test_case", [ TestCase(model_name=MODEL_NAME, echo=True), - TestCase(model_name=MODEL_NAME, echo=False) + TestCase(model_name=MODEL_NAME, echo=False), ], ) async def test_chat_session_with_echo_and_continue_final_message( - client: openai.AsyncOpenAI, test_case: TestCase): + client: openai.AsyncOpenAI, test_case: TestCase +): saying: str = "Here is a common saying about apple. 
An apple a day, keeps" # test echo with continue_final_message parameter chat_completion = await client.chat.completions.create( model=test_case.model_name, - messages=[{ - "role": "user", - "content": "tell me a common saying" - }, { - "role": "assistant", - "content": saying - }], + messages=[ + {"role": "user", "content": "tell me a common saying"}, + {"role": "assistant", "content": saying}, + ], extra_body={ "echo": test_case.echo, "continue_final_message": True, - "add_generation_prompt": False - }) + "add_generation_prompt": False, + }, + ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 @@ -83,13 +82,10 @@ async def test_chat_session_with_echo_and_continue_final_message( @pytest.mark.asyncio async def test_prompt_logprobs(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Beijing is the capital of which country?" - }] + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Beijing is the capital of which country?"}, + ] completion = await client.chat.completions.create( model=MODEL_NAME, @@ -103,13 +99,10 @@ async def test_prompt_logprobs(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_top_logprobs(client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Beijing is the capital of which country?" - }] + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Beijing is the capital of which country?"}, + ] completion = await client.chat.completions.create( model=MODEL_NAME, diff --git a/tests/entrypoints/openai/test_chat_logit_bias_validation.py b/tests/entrypoints/openai/test_chat_logit_bias_validation.py index 9fa7ab83555a..6539613ed17b 100644 --- a/tests/entrypoints/openai/test_chat_logit_bias_validation.py +++ b/tests/entrypoints/openai/test_chat_logit_bias_validation.py @@ -49,10 +49,7 @@ async def test_chat_logit_bias_valid(client): completion = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "Testing valid logit bias" - }], + messages=[{"role": "user", "content": "Testing valid logit bias"}], max_tokens=5, logit_bias={str(valid_token_id): 1.0}, ) @@ -69,10 +66,7 @@ async def test_chat_logit_bias_invalid(client): with pytest.raises(openai.BadRequestError) as excinfo: await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "Testing invalid logit bias" - }], + messages=[{"role": "user", "content": "Testing invalid logit bias"}], max_tokens=5, logit_bias={str(invalid_token_id): 1.0}, ) diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index ce90a67c0151..d1202a59752b 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -4,8 +4,7 @@ import pytest from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (apply_hf_chat_template, - load_chat_template) +from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer @@ -17,48 +16,54 @@ # Define models, templates, and their corresponding expected outputs MODEL_TEMPLATE_GENERATION_OUTPUT = [ - ("facebook/opt-125m", 
chatml_jinja_path, True, False, """<|im_start|>user + ( + "facebook/opt-125m", + chatml_jinja_path, + True, + False, + """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user What is the capital of<|im_end|> <|im_start|>assistant -"""), - ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user +""", + ), + ( + "facebook/opt-125m", + chatml_jinja_path, + False, + False, + """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user -What is the capital of"""), - ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user +What is the capital of""", + ), + ( + "facebook/opt-125m", + chatml_jinja_path, + False, + True, + """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user What is the capital of<|im_end|> <|im_start|>assistant -The capital of"""), +The capital of""", + ), ] TEST_MESSAGES = [ - { - 'role': 'user', - 'content': 'Hello' - }, - { - 'role': 'assistant', - 'content': 'Hi there!' - }, - { - 'role': 'user', - 'content': 'What is the capital of' - }, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "What is the capital of"}, ] -ASSISTANT_MESSAGE_TO_CONTINUE = { - 'role': 'assistant', - 'content': 'The capital of' -} +ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"} def test_load_chat_template(): @@ -68,8 +73,11 @@ def test_load_chat_template(): # Test assertions assert template_content is not None # Hard coded value for template_chatml.jinja - assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} + assert ( + template_content + == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %} {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 + ) def test_no_load_chat_template_filelike(): @@ -91,9 +99,11 @@ def test_no_load_chat_template_literallike(): @pytest.mark.parametrize( "model,template,add_generation_prompt,continue_final_message,expected_output", - MODEL_TEMPLATE_GENERATION_OUTPUT) -def test_get_gen_prompt(model, template, add_generation_prompt, - continue_final_message, expected_output): + MODEL_TEMPLATE_GENERATION_OUTPUT, +) +def test_get_gen_prompt( + model, template, add_generation_prompt, continue_final_message, expected_output +): model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -106,7 +116,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt, hf_overrides=model_info.hf_overrides, skip_tokenizer_init=model_info.skip_tokenizer_init, enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype) + dtype=model_info.dtype, + ) # Initialize the tokenizer tokenizer = get_tokenizer( @@ -119,7 +130,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt, mock_request = ChatCompletionRequest( model=model, messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE] - if continue_final_message else TEST_MESSAGES, + if continue_final_message + else TEST_MESSAGES, add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, ) 
@@ -138,4 +150,5 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Test assertion assert result == expected_output, ( f"The generated prompt does not match the expected output for " - f"model {model} and template {template}") + f"model {model} and template {template}" + ) diff --git a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py index 03730b67283c..e452b578ba22 100644 --- a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py +++ b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py @@ -14,9 +14,14 @@ @pytest.fixture(scope="module") def server(): # noqa: F811 args = [ - "--max-model-len", "8192", "--enforce-eager", "--reasoning-parser", - "deepseek_r1", "--enable-auto-tool-choice", "--tool-call-parser", - "hermes" + "--max-model-len", + "8192", + "--enforce-eager", + "--reasoning-parser", + "deepseek_r1", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -29,50 +34,46 @@ async def client(server): yield async_client -TOOLS = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, e.g. 'San Francisco'" +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. " + "'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that " + "the city is in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the city is" - " in, e.g. 'CA' which would mean 'California'" - }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"] - } + }, } -}] - -MESSAGES = [{ - "role": "user", - "content": "Hi! How are you doing today?" -}, { - "role": "assistant", - "content": "I'm doing well! How can I help you?" -}, { - "role": - "user", - "content": - "Can you tell me what the temperate will be in Dallas, in fahrenheit?" -}] +] + +MESSAGES = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! 
How can I help you?"}, + { + "role": "user", + "content": "Can you tell me what the temperate will be in Dallas, " + "in fahrenheit?", + }, +] FUNC_NAME = "get_current_weather" FUNC_ARGS = """{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}""" @@ -105,9 +106,7 @@ def extract_reasoning_and_calls(chunks: list): # test streaming @pytest.mark.asyncio -async def test_chat_streaming_of_tool_and_reasoning( - client: openai.AsyncOpenAI): - +async def test_chat_streaming_of_tool_and_reasoning(client: openai.AsyncOpenAI): stream = await client.chat.completions.create( model=MODEL_NAME, messages=MESSAGES, @@ -120,8 +119,7 @@ async def test_chat_streaming_of_tool_and_reasoning( async for chunk in stream: chunks.append(chunk) - reasoning_content, arguments, function_names = extract_reasoning_and_calls( - chunks) + reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks) assert len(reasoning_content) > 0 assert len(function_names) > 0 and function_names[0] == FUNC_NAME assert len(arguments) > 0 and arguments[0] == FUNC_ARGS @@ -130,7 +128,6 @@ async def test_chat_streaming_of_tool_and_reasoning( # test full generate @pytest.mark.asyncio async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI): - tool_calls = await client.chat.completions.create( model=MODEL_NAME, messages=MESSAGES, @@ -140,7 +137,5 @@ async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI): ) assert len(tool_calls.choices[0].message.reasoning_content) > 0 - assert tool_calls.choices[0].message.tool_calls[0].function.name \ - == FUNC_NAME - assert tool_calls.choices[0].message.tool_calls[0].function.arguments \ - == FUNC_ARGS + assert tool_calls.choices[0].message.tool_calls[0].function.name == FUNC_NAME + assert tool_calls.choices[0].message.tool_calls[0].function.arguments == FUNC_ARGS diff --git a/tests/entrypoints/openai/test_chunked_prompt.py b/tests/entrypoints/openai/test_chunked_prompt.py index c8160c5f2d0e..608e509e59e8 100644 --- a/tests/entrypoints/openai/test_chunked_prompt.py +++ b/tests/entrypoints/openai/test_chunked_prompt.py @@ -40,7 +40,8 @@ async def client(server): @pytest.mark.asyncio async def test_completion_stream_options_and_logprobs_with_long_prompts( - client: openai.AsyncOpenAI): + client: openai.AsyncOpenAI, +): # Test stream with long prompt prompt = "What is the capital of France?" * 400 @@ -62,8 +63,9 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts( async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 assert chunk.usage.completion_tokens >= 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) if not finished: tokens_received += 1 assert chunk.choices[0].text @@ -77,15 +79,13 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts( @pytest.mark.asyncio async def test_chat_completion_stream_options_and_logprobs_with_long_prompts( - client: openai.AsyncOpenAI): + client: openai.AsyncOpenAI, +): # Test stream with long prompt - messages = [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "What is the capital of France?" * 400 - }] + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?" 
* 400}, + ] stream = await client.chat.completions.create( model=MODEL_NAME, messages=messages, @@ -106,8 +106,9 @@ async def test_chat_completion_stream_options_and_logprobs_with_long_prompts( async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 assert chunk.usage.completion_tokens >= 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) if not finished: if chunk.choices[0].delta.content == "": diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index 9a1c0ea13b54..0b9d171aa481 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -5,8 +5,7 @@ import pytest -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.utils import FlexibleArgumentParser @@ -15,7 +14,7 @@ LORA_MODULE = { "name": "module2", "path": "/path/to/module2", - "base_model_name": "llama" + "base_model_name": "llama", } CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja" assert CHATML_JINJA_PATH.exists() @@ -31,45 +30,51 @@ def serve_parser(): def test_config_arg_parsing(serve_parser, cli_config_file): args = serve_parser.parse_args([]) assert args.port == 8000 - args = serve_parser.parse_args(['--config', cli_config_file]) + args = serve_parser.parse_args(["--config", cli_config_file]) assert args.port == 12312 - args = serve_parser.parse_args([ - '--config', - cli_config_file, - '--port', - '9000', - ]) + args = serve_parser.parse_args( + [ + "--config", + cli_config_file, + "--port", + "9000", + ] + ) assert args.port == 9000 - args = serve_parser.parse_args([ - '--port', - '9000', - '--config', - cli_config_file, - ]) + args = serve_parser.parse_args( + [ + "--port", + "9000", + "--config", + cli_config_file, + ] + ) assert args.port == 9000 ### Tests for LoRA module parsing def test_valid_key_value_format(serve_parser): # Test old format: name=path - args = serve_parser.parse_args([ - '--lora-modules', - 'module1=/path/to/module1', - ]) - expected = [LoRAModulePath(name='module1', path='/path/to/module1')] + args = serve_parser.parse_args( + [ + "--lora-modules", + "module1=/path/to/module1", + ] + ) + expected = [LoRAModulePath(name="module1", path="/path/to/module1")] assert args.lora_modules == expected def test_valid_json_format(serve_parser): # Test valid JSON format input - args = serve_parser.parse_args([ - '--lora-modules', - json.dumps(LORA_MODULE), - ]) + args = serve_parser.parse_args( + [ + "--lora-modules", + json.dumps(LORA_MODULE), + ] + ) expected = [ - LoRAModulePath(name='module2', - path='/path/to/module2', - base_model_name='llama') + LoRAModulePath(name="module2", path="/path/to/module2", base_model_name="llama") ] assert args.lora_modules == expected @@ -77,47 +82,53 @@ def test_valid_json_format(serve_parser): def test_invalid_json_format(serve_parser): # Test invalid JSON format input, missing closing brace with pytest.raises(SystemExit): - serve_parser.parse_args([ - '--lora-modules', '{"name": "module3", "path": "/path/to/module3"' - ]) + serve_parser.parse_args( + ["--lora-modules", '{"name": "module3", "path": "/path/to/module3"'] + ) def test_invalid_type_error(serve_parser): # Test type error when values are not JSON 
or key=value with pytest.raises(SystemExit): - serve_parser.parse_args([ - '--lora-modules', - 'invalid_format' # This is not JSON or key=value format - ]) + serve_parser.parse_args( + [ + "--lora-modules", + "invalid_format", # This is not JSON or key=value format + ] + ) def test_invalid_json_field(serve_parser): # Test valid JSON format but missing required fields with pytest.raises(SystemExit): - serve_parser.parse_args([ - '--lora-modules', - '{"name": "module4"}' # Missing required 'path' field - ]) + serve_parser.parse_args( + [ + "--lora-modules", + '{"name": "module4"}', # Missing required 'path' field + ] + ) def test_empty_values(serve_parser): # Test when no LoRA modules are provided - args = serve_parser.parse_args(['--lora-modules', '']) + args = serve_parser.parse_args(["--lora-modules", ""]) assert args.lora_modules == [] def test_multiple_valid_inputs(serve_parser): # Test multiple valid inputs (both old and JSON format) - args = serve_parser.parse_args([ - '--lora-modules', - 'module1=/path/to/module1', - json.dumps(LORA_MODULE), - ]) + args = serve_parser.parse_args( + [ + "--lora-modules", + "module1=/path/to/module1", + json.dumps(LORA_MODULE), + ] + ) expected = [ - LoRAModulePath(name='module1', path='/path/to/module1'), - LoRAModulePath(name='module2', - path='/path/to/module2', - base_model_name='llama') + LoRAModulePath(name="module1", path="/path/to/module1"), + LoRAModulePath( + name="module2", path="/path/to/module2", base_model_name="llama" + ), ] assert args.lora_modules == expected @@ -133,40 +144,46 @@ def test_enable_auto_choice_passes_without_tool_call_parser(serve_parser): def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser): """Ensure validation passes with tool choice enabled with a call parser""" - args = serve_parser.parse_args(args=[ - "--enable-auto-tool-choice", - "--tool-call-parser", - "mistral", - ]) + args = serve_parser.parse_args( + args=[ + "--enable-auto-tool-choice", + "--tool-call-parser", + "mistral", + ] + ) validate_parsed_serve_args(args) def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser): """Ensure validation fails if reasoning is enabled with auto tool choice""" - args = serve_parser.parse_args(args=[ - "--enable-auto-tool-choice", - "--reasoning-parser", - "deepseek_r1", - ]) + args = serve_parser.parse_args( + args=[ + "--enable-auto-tool-choice", + "--reasoning-parser", + "deepseek_r1", + ] + ) with pytest.raises(TypeError): validate_parsed_serve_args(args) def test_passes_with_reasoning_parser(serve_parser): - """Ensure validation passes if reasoning is enabled + """Ensure validation passes if reasoning is enabled with a reasoning parser""" - args = serve_parser.parse_args(args=[ - "--reasoning-parser", - "deepseek_r1", - ]) + args = serve_parser.parse_args( + args=[ + "--reasoning-parser", + "deepseek_r1", + ] + ) validate_parsed_serve_args(args) def test_chat_template_validation_for_happy_paths(serve_parser): """Ensure validation passes if the chat template exists""" args = serve_parser.parse_args( - args=["--chat-template", - CHATML_JINJA_PATH.absolute().as_posix()]) + args=["--chat-template", CHATML_JINJA_PATH.absolute().as_posix()] + ) validate_parsed_serve_args(args) @@ -179,8 +196,14 @@ def test_chat_template_validation_for_sad_paths(serve_parser): @pytest.mark.parametrize( "cli_args, expected_middleware", - [(["--middleware", "middleware1", "--middleware", "middleware2" - ], ["middleware1", "middleware2"]), ([], [])]) + [ + ( + ["--middleware", "middleware1", "--middleware", 
"middleware2"], + ["middleware1", "middleware2"], + ), + ([], []), + ], +) def test_middleware(serve_parser, cli_args, expected_middleware): """Ensure multiple middleware args are parsed properly""" args = serve_parser.parse_args(args=cli_args) diff --git a/tests/entrypoints/openai/test_collective_rpc.py b/tests/entrypoints/openai/test_collective_rpc.py index 37c0b7a900ac..cbd6b02f05dc 100644 --- a/tests/entrypoints/openai/test_collective_rpc.py +++ b/tests/entrypoints/openai/test_collective_rpc.py @@ -12,7 +12,6 @@ class TestWorkerExtension: - def get_model_name(self) -> str: """Test non-pydantic return type.""" return MODEL_NAME @@ -41,20 +40,18 @@ def server(): "tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension", ] with RemoteOpenAIServer( - MODEL_NAME, - args, - env_dict={ - "VLLM_SERVER_DEV_MODE": "1", - "CUDA_VISIBLE_DEVICES": "0" - }, + MODEL_NAME, + args, + env_dict={"VLLM_SERVER_DEV_MODE": "1", "CUDA_VISIBLE_DEVICES": "0"}, ) as remote_server: yield remote_server def test_get_model_name(server): """Test basic response""" - response = requests.post(server.url_for("collective_rpc"), - json={"method": "get_model_name"}) + response = requests.post( + server.url_for("collective_rpc"), json={"method": "get_model_name"} + ) assert response.status_code == 200 results = response.json() assert "results" in results @@ -63,8 +60,9 @@ def test_get_model_name(server): def test_return_none(server): """Test return none""" - response = requests.post(server.url_for("collective_rpc"), - json={"method": "return_none"}) + response = requests.post( + server.url_for("collective_rpc"), json={"method": "return_none"} + ) assert response.status_code == 200 results = response.json() assert results["results"] == [None] @@ -74,12 +72,10 @@ def test_echo_args_kwargs(server): """Test args, kwargs, and dict response""" args = ["arg1", "arg2"] kwargs = {"key1": "value1", "key2": "value2"} - response = requests.post(server.url_for("collective_rpc"), - json={ - "method": "echo_args_kwargs", - "args": args, - "kwargs": kwargs - }) + response = requests.post( + server.url_for("collective_rpc"), + json={"method": "echo_args_kwargs", "args": args, "kwargs": kwargs}, + ) assert response.status_code == 200 results = response.json() result = results["results"][0] diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index 4355603fcd70..e64f68cad7c8 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -25,15 +25,14 @@ "properties": { "city": { "type": "string", - "description": - "The city to find the weather for, e.g. 'Vienna'", + "description": "The city to find the weather for, e.g. " + "'Vienna'", "default": "Vienna", }, "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", + "type": "string", + "description": "The country that the city is in, e.g. " + "'Austria'", }, "unit": { "type": "string", @@ -62,8 +61,7 @@ "include_forecast": { "type": "boolean", "default": False, - "description": - "Whether to include a 24-hour forecast", + "description": "Whether to include a 24-hour forecast", "title": "Include Forecast", }, "language": { @@ -89,21 +87,18 @@ "properties": { "city": { "type": "string", - "description": - "The city to get the forecast for, e.g. 'Vienna'", + "description": "The city to get the forecast for, e.g. 
" + "'Vienna'", "default": "Vienna", }, "country": { - "type": - "string", - "description": - "The country that the city is in, e.g. 'Austria'", + "type": "string", + "description": "The country that the city is in, e.g. " + "'Austria'", }, "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", + "type": "integer", + "description": "Number of days to get the forecast for (1-7)", }, "unit": { "type": "string", @@ -118,19 +113,11 @@ ] messages = [ + {"role": "user", "content": "Hi! How are you doing today?"}, + {"role": "assistant", "content": "I'm doing well! How can I help you?"}, { "role": "user", - "content": "Hi! How are you doing today?" - }, - { - "role": "assistant", - "content": "I'm doing well! How can I help you?" - }, - { - "role": - "user", - "content": - "Can you tell me what the current weather is in Berlin and the "\ + "content": "Can you tell me what the current weather is in Berlin and the " "forecast for the next 5 days, in fahrenheit?", }, ] @@ -150,7 +137,7 @@ def server(): # noqa: F811 "--reasoning-parser", "qwen3", "--gpu-memory-utilization", - "0.4" + "0.4", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -166,18 +153,22 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("stream", [True, False]) -@pytest.mark.parametrize("tool_choice", [ - "auto", "required", { - "type": "function", - "function": { - "name": "get_current_weather" - } - } -]) +@pytest.mark.parametrize( + "tool_choice", + [ + "auto", + "required", + {"type": "function", "function": {"name": "get_current_weather"}}, + ], +) @pytest.mark.parametrize("enable_thinking", [True, False]) -async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, - stream: bool, tool_choice: Union[str, dict], - enable_thinking: bool): +async def test_function_tool_use( + client: openai.AsyncOpenAI, + model_name: str, + stream: bool, + tool_choice: Union[str, dict], + enable_thinking: bool, +): if not stream: # Non-streaming test chat_completion = await client.chat.completions.create( @@ -185,16 +176,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, model=model_name, tools=tools, tool_choice=tool_choice, - extra_body={ - "chat_template_kwargs": { - "enable_thinking": enable_thinking - } - }) + extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}}, + ) if enable_thinking: - assert chat_completion.choices[0].message.\ - reasoning_content is not None - assert chat_completion.choices[0].message.\ - reasoning_content != "" + assert chat_completion.choices[0].message.reasoning_content is not None + assert chat_completion.choices[0].message.reasoning_content != "" assert chat_completion.choices[0].message.tool_calls is not None assert len(chat_completion.choices[0].message.tool_calls) > 0 else: @@ -205,11 +191,8 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, tools=tools, tool_choice=tool_choice, stream=True, - extra_body={ - "chat_template_kwargs": { - "enable_thinking": enable_thinking - } - }) + extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}}, + ) output = [] async for chunk in output_stream: @@ -237,12 +220,11 @@ def k2_server(): # noqa: F811 ] # hack to test kimi_k2 tool use tool_id format. 
# avoid error in is_deepseek_mla check by setting kv_lora_rank=null - with RemoteOpenAIServer(MODEL_NAME, - args, - override_hf_configs={ - "model_type": 'kimi_k2', - 'kv_lora_rank': None - }) as remote_server: + with RemoteOpenAIServer( + MODEL_NAME, + args, + override_hf_configs={"model_type": "kimi_k2", "kv_lora_rank": None}, + ) as remote_server: yield remote_server @@ -256,20 +238,20 @@ async def k2_client(k2_server): @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("tool_choice", ["required"]) -async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str, - stream: bool, tool_choice: str): - +async def test_tool_id_kimi_k2( + k2_client: openai.AsyncOpenAI, model_name: str, stream: bool, tool_choice: str +): if not stream: # Non-streaming test chat_completion = await k2_client.chat.completions.create( - messages=messages, - model=model_name, - tools=tools, - tool_choice=tool_choice) + messages=messages, model=model_name, tools=tools, tool_choice=tool_choice + ) assert chat_completion.choices[0].message.tool_calls is not None assert len(chat_completion.choices[0].message.tool_calls) > 0 - assert chat_completion.choices[0].message.tool_calls[ - 0].id == 'functions.get_current_weather:0' + assert ( + chat_completion.choices[0].message.tool_calls[0].id + == "functions.get_current_weather:0" + ) else: # Streaming test output_stream = await k2_client.chat.completions.create( @@ -277,42 +259,45 @@ async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str, model=model_name, tools=tools, tool_choice=tool_choice, - stream=True) + stream=True, + ) output = [] async for chunk in output_stream: if chunk.choices and chunk.choices[0].delta.tool_calls: output.extend(chunk.choices[0].delta.tool_calls) for o in output: - assert o.id is None or o.id == 'functions.get_current_weather:0' + assert o.id is None or o.id == "functions.get_current_weather:0" @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("arguments", ["{}", '']) -async def test_no_args_tool_call(client: openai.AsyncOpenAI, model_name: str, - arguments: str): +@pytest.mark.parametrize("arguments", ["{}", ""]) +async def test_no_args_tool_call( + client: openai.AsyncOpenAI, model_name: str, arguments: str +): # Step 1: Define a tool that requires no parameters - tools = [{ - "type": "function", - "function": { - "name": "get_current_time", - "description": - "Get the current date and time. No parameters needed.", - "parameters": { - "type": "object", - "properties": {}, # No parameters - "required": [] # No required fields - } + tools = [ + { + "type": "function", + "function": { + "name": "get_current_time", + "description": "Get the current date and time. 
No parameters needed.", + "parameters": { + "type": "object", + "properties": {}, # No parameters + "required": [], # No required fields + }, + }, } - }] + ] messages = [{"role": "user", "content": "What time is it now?"}] # Step 2: Send user message and let model decide whether to call the tool response = await client.chat.completions.create( model=model_name, messages=messages, tools=tools, - tool_choice="auto" # Let model choose automatically + tool_choice="auto", # Let model choose automatically ) # Step 3: Check if model wants to call a tool @@ -328,11 +313,13 @@ async def test_no_args_tool_call(client: openai.AsyncOpenAI, model_name: str, messages.append(message) current_time = datetime.datetime.now() result = current_time.isoformat() - messages.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": result, - }) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + } + ) # Step 5: Send tool result back to model to continue conversation final_response = await client.chat.completions.create( model=model_name, diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index cad914282306..3ed98ffe0e39 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -9,6 +9,7 @@ import pytest import pytest_asyncio import torch + # downloading lora to test lora requests from openai import BadRequestError from transformers import AutoConfig @@ -23,8 +24,9 @@ @pytest.fixture(scope="module", params=["use-lora"]) -def default_server_args(request: pytest.FixtureRequest, - opt125_lora_files: str) -> list[str]: +def default_server_args( + request: pytest.FixtureRequest, opt125_lora_files: str +) -> list[str]: args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -42,18 +44,20 @@ def default_server_args(request: pytest.FixtureRequest, lora_module_1 = { "name": LORA_SERVING_MODEL_NAME, "path": opt125_lora_files, - "base_model_name": MODEL_NAME + "base_model_name": MODEL_NAME, } - args.extend([ - "--enable-lora", - "--lora-module", - json.dumps(lora_module_1), - "--max-lora-rank", - "64", - "--max-cpu-loras", - "2", - ]) + args.extend( + [ + "--enable-lora", + "--lora-module", + json.dumps(lora_module_1), + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + ] + ) return args @@ -67,7 +71,7 @@ def default_server_args(request: pytest.FixtureRequest, def _encode_embeds(embeds: torch.Tensor): buffer = io.BytesIO() torch.save(embeds, buffer) - return base64.b64encode(buffer.getvalue()).decode('utf-8') + return base64.b64encode(buffer.getvalue()).decode("utf-8") @pytest.fixture(scope="module") @@ -79,8 +83,7 @@ def example_prompt_embeds(hf_runner): return [_encode_embeds(item) for item in example_embeddings] -@pytest.fixture(scope="module", - params=["", "--disable-frontend-multiprocessing"]) +@pytest.fixture(scope="module", params=["", "--disable-frontend-multiprocessing"]) def server_with_prompt_embeds(default_server_args, request): if request.param: default_server_args.append(request.param) @@ -110,7 +113,8 @@ async def test_completions_with_prompt_embeds( prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) assert len(completion.choices[0].text) >= 1 assert completion.choices[0].prompt_logprobs is None @@ -120,7 
+124,8 @@ async def test_completions_with_prompt_embeds( prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, + ) assert len(completion.choices) == 2 assert len(completion.choices[0].text) >= 1 assert len(completion.choices[1].text) >= 1 @@ -131,7 +136,8 @@ async def test_completions_with_prompt_embeds( prompt="", # Add empty prompt as required parameter max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) single_output = single_completion.choices[0].text stream = await client_with_prompt_embeds.completions.create( @@ -140,7 +146,8 @@ async def test_completions_with_prompt_embeds( max_tokens=5, temperature=0.0, stream=True, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) chunks = [] finish_reason_count = 0 async for chunk in stream: @@ -159,12 +166,12 @@ async def test_completions_with_prompt_embeds( max_tokens=5, temperature=0.0, stream=True, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, + ) chunks_stream_embeds: list[list[str]] = [[], []] finish_reason_count = 0 async for chunk in stream: - chunks_stream_embeds[chunk.choices[0].index].append( - chunk.choices[0].text) + chunks_stream_embeds[chunk.choices[0].index].append(chunk.choices[0].text) if chunk.choices[0].finish_reason is not None: finish_reason_count += 1 assert finish_reason_count == 2 @@ -179,7 +186,8 @@ async def test_completions_with_prompt_embeds( prompt="This is a prompt", max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) assert len(completion.choices) == 2 completion_text_only = await client_with_prompt_embeds.completions.create( model=model_name, @@ -192,18 +200,18 @@ async def test_completions_with_prompt_embeds( prompt="", max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) # Embeddings responses should be handled first - assert completion_mixed.choices[0].text == completion_embeds_only.choices[ - 0].text - assert completion_mixed.choices[1].text == completion_text_only.choices[ - 0].text + assert completion_mixed.choices[0].text == completion_embeds_only.choices[0].text + assert completion_mixed.choices[1].text == completion_text_only.choices[0].text @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME]) async def test_completions_errors_with_prompt_embeds( - client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str +): # Test error case: invalid prompt_embeds with pytest.raises(BadRequestError): await client_with_prompt_embeds.completions.create( @@ -211,7 +219,8 @@ async def test_completions_errors_with_prompt_embeds( model=model_name, max_tokens=5, temperature=0.0, - extra_body={"prompt_embeds": "invalid_base64"}) + extra_body={"prompt_embeds": "invalid_base64"}, + ) @pytest.mark.asyncio @@ -233,7 +242,8 @@ async def test_completions_with_logprobs_and_prompt_embeds( temperature=0.0, echo=False, logprobs=logprobs_arg, - extra_body={"prompt_embeds": encoded_embeds}) + extra_body={"prompt_embeds": encoded_embeds}, + ) logprobs = completion.choices[0].logprobs assert 
logprobs is not None @@ -252,7 +262,8 @@ async def test_completions_with_logprobs_and_prompt_embeds( temperature=0.0, echo=False, logprobs=logprobs_arg, - extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}, + ) assert len(completion.choices) == 2 for choice in completion.choices: @@ -262,8 +273,7 @@ async def test_completions_with_logprobs_and_prompt_embeds( assert len(logprobs.token_logprobs) == 5 assert len(logprobs.top_logprobs) == 5 for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) == 5 @@ -280,8 +290,5 @@ async def test_prompt_logprobs_raises_error( prompt="", max_tokens=5, temperature=0.0, - extra_body={ - "prompt_embeds": encoded_embeds, - "prompt_logprobs": True - }, + extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True}, ) diff --git a/tests/entrypoints/openai/test_default_mm_loras.py b/tests/entrypoints/openai/test_default_mm_loras.py index b9c466a6fbeb..336bda81a9ef 100644 --- a/tests/entrypoints/openai/test_default_mm_loras.py +++ b/tests/entrypoints/openai/test_default_mm_loras.py @@ -16,8 +16,7 @@ # need a multimodal model for these tests. # Contains a modality specific lora alongside the base model -MULTIMODAL_MODEL_NAME = snapshot_download( - "microsoft/Phi-4-multimodal-instruct") +MULTIMODAL_MODEL_NAME = snapshot_download("microsoft/Phi-4-multimodal-instruct") AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora") ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go." 
# noqa: E501 @@ -25,7 +24,6 @@ @pytest.fixture(scope="module") def multimodal_server(): # noqa: F811 - args = [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -45,11 +43,12 @@ def multimodal_server(): # noqa: F811 "--gpu-memory-utilization", "0.8", "--default-mm-loras", - f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}", + f'{{"audio": "{AUDIO_LORA_PATH}"}}', ] - with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args, - max_wait_seconds=480) as remote_server: + with RemoteOpenAIServer( + MULTIMODAL_MODEL_NAME, args, max_wait_seconds=480 + ) as remote_server: yield remote_server @@ -70,25 +69,25 @@ async def test_default_mm_lora_chat_completions( multi_modal_client: openai.AsyncOpenAI, audio_assets: AudioTestAssets, ): - messages = [{ - "role": - "user", - "content": [{ - "type": "text", - "text": "Can you transcribe this audio?", - }, { - "type": "audio_url", - "audio_url": { - "url": audio_assets[0].url - }, - }] - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Can you transcribe this audio?", + }, + { + "type": "audio_url", + "audio_url": {"url": audio_assets[0].url}, + }, + ], + } + ] chat_completion = await multi_modal_client.chat.completions.create( - model=model_name, - messages=messages, - max_completion_tokens=128, - temperature=0.0) + model=model_name, messages=messages, max_completion_tokens=128, temperature=0.0 + ) assert len(chat_completion.choices) > 0 diff --git a/tests/entrypoints/openai/test_lora_adapters.py b/tests/entrypoints/openai/test_lora_adapters.py index 6f2addd3649d..379e7d36d9e1 100644 --- a/tests/entrypoints/openai/test_lora_adapters.py +++ b/tests/entrypoints/openai/test_lora_adapters.py @@ -20,26 +20,18 @@ BADREQUEST_CASES = [ ( "test_rank", - { - "r": 1024 - }, + {"r": 1024}, "is greater than max_lora_rank", ), ( "test_bias", - { - "bias": "all" - }, + {"bias": "all"}, "Adapter bias cannot be used without bias_enabled", ), - ("test_dora", { - "use_dora": True - }, "does not yet support DoRA"), + ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( "test_modules_to_save", - { - "modules_to_save": ["lm_head"] - }, + {"modules_to_save": ["lm_head"]}, "only supports modules_to_save being None", ), ] @@ -48,24 +40,23 @@ @pytest.fixture(scope="module") def monkeypatch_module(): from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() yield mpatch mpatch.undo() @pytest.fixture(scope="module", params=[True]) -def server_with_lora_modules_json(request, monkeypatch_module, - zephyr_lora_files): - +def server_with_lora_modules_json(request, monkeypatch_module, zephyr_lora_files): use_v1 = request.param assert use_v1 - monkeypatch_module.setenv('VLLM_USE_V1', '1') + monkeypatch_module.setenv("VLLM_USE_V1", "1") # Define the json format LoRA module configurations lora_module_1 = { "name": "zephyr-lora", "path": zephyr_lora_files, - "base_model_name": MODEL_NAME + "base_model_name": MODEL_NAME, } args = [ @@ -96,14 +87,12 @@ def server_with_lora_modules_json(request, monkeypatch_module, @pytest_asyncio.fixture async def client(server_with_lora_modules_json): - async with server_with_lora_modules_json.get_async_client( - ) as async_client: + async with server_with_lora_modules_json.get_async_client() as async_client: yield async_client @pytest.mark.asyncio -async def test_static_lora_lineage(client: openai.AsyncOpenAI, - zephyr_lora_files): +async def test_static_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files): models = await client.models.list() models = models.data 
served_model = models[0] @@ -111,22 +100,18 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI, assert served_model.id == MODEL_NAME assert served_model.root == MODEL_NAME assert served_model.parent is None - assert all(lora_model.root == zephyr_lora_files - for lora_model in lora_models) + assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models) assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" @pytest.mark.asyncio -async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, - zephyr_lora_files): - - response = await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "zephyr-lora-3", - "lora_path": zephyr_lora_files - }) +async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files): + response = await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "zephyr-lora-3", "lora_path": zephyr_lora_files}, + ) # Ensure adapter loads before querying /models assert "success" in response @@ -141,37 +126,37 @@ async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, @pytest.mark.asyncio async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI): with pytest.raises(openai.NotFoundError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "notfound", - "lora_path": "/not/an/adapter" - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "notfound", "lora_path": "/not/an/adapter"}, + ) @pytest.mark.asyncio -async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, - tmp_path): +async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, tmp_path): invalid_files = tmp_path / "invalid_files" invalid_files.mkdir() (invalid_files / "adapter_config.json").write_text("this is not json") with pytest.raises(openai.BadRequestError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "invalid-json", - "lora_path": str(invalid_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "invalid-json", "lora_path": str(invalid_files)}, + ) @pytest.mark.asyncio -@pytest.mark.parametrize("test_name,config_change,expected_error", - BADREQUEST_CASES) -async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path, - zephyr_lora_files, test_name: str, - config_change: dict, - expected_error: str): +@pytest.mark.parametrize("test_name,config_change,expected_error", BADREQUEST_CASES) +async def test_dynamic_lora_badrequests( + client: openai.AsyncOpenAI, + tmp_path, + zephyr_lora_files, + test_name: str, + config_change: dict, + expected_error: str, +): # Create test directory test_dir = tmp_path / test_name @@ -191,29 +176,28 @@ async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path, # Test loading the adapter with pytest.raises(openai.BadRequestError, match=expected_error): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": test_name, - "lora_path": str(test_dir) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": test_name, "lora_path": str(test_dir)}, + ) @pytest.mark.asyncio -async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path, - zephyr_lora_files): +async def test_multiple_lora_adapters( + client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files +): """Validate that many loras can be dynamically registered and inferenced with concurrently""" # This test file configures the 
server with --max-cpu-loras=2 and this test # will concurrently load 10 adapters, so it should flex the LRU cache async def load_and_run_adapter(adapter_name: str): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": adapter_name, - "lora_path": str(zephyr_lora_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)}, + ) for _ in range(3): await client.completions.create( model=adapter_name, @@ -223,8 +207,7 @@ async def load_and_run_adapter(adapter_name: str): lora_tasks = [] for i in range(10): - lora_tasks.append( - asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) + lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) results, _ = await asyncio.wait(lora_tasks) @@ -234,8 +217,8 @@ async def load_and_run_adapter(adapter_name: str): @pytest.mark.asyncio async def test_loading_invalid_adapters_does_not_break_others( - client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files): - + client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files +): invalid_files = tmp_path / "invalid_files" invalid_files.mkdir() (invalid_files / "adapter_config.json").write_text("this is not json") @@ -266,20 +249,18 @@ async def run_good_requests(client): # Run a bunch of bad adapter loads for _ in range(25): with suppress(openai.NotFoundError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "notfound", - "lora_path": "/not/an/adapter" - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "notfound", "lora_path": "/not/an/adapter"}, + ) for _ in range(25): with suppress(openai.BadRequestError): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "invalid", - "lora_path": str(invalid_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "invalid", "lora_path": str(invalid_files)}, + ) # Ensure all the running requests with lora adapters succeeded stop_good_requests_event.set() @@ -288,12 +269,11 @@ async def run_good_requests(client): assert not isinstance(r, Exception), f"Got exception {r}" # Ensure we can load another adapter and run it - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": "valid", - "lora_path": zephyr_lora_files - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": "valid", "lora_path": zephyr_lora_files}, + ) await client.completions.create( model="valid", prompt=["Hello there", "Foo bar bazz buzz"], @@ -310,12 +290,11 @@ async def test_beam_search_with_lora_adapters( """Validate that async beam search can be used with lora.""" async def load_and_run_adapter(adapter_name: str): - await client.post("load_lora_adapter", - cast_to=str, - body={ - "lora_name": adapter_name, - "lora_path": str(zephyr_lora_files) - }) + await client.post( + "load_lora_adapter", + cast_to=str, + body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)}, + ) for _ in range(3): await client.completions.create( model=adapter_name, @@ -326,8 +305,7 @@ async def load_and_run_adapter(adapter_name: str): lora_tasks = [] for i in range(3): - lora_tasks.append( - asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) + lora_tasks.append(asyncio.create_task(load_and_run_adapter(f"adapter_{i}"))) results, _ = await asyncio.wait(lora_tasks) diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index 45aa2070d0a2..2a15848ba447 
100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -12,8 +12,7 @@ from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.transformers_utils.tokenizer import get_tokenizer @@ -33,14 +32,14 @@ class MockHFConfig: @dataclass class MockModelConfig: """Minimal mock ModelConfig for testing.""" + model: str = MODEL_NAME tokenizer: str = MODEL_NAME trust_remote_code: bool = False tokenizer_mode: str = "auto" max_model_len: int = 100 tokenizer_revision: Optional[str] = None - multimodal_config: MultiModalConfig = field( - default_factory=MultiModalConfig) + multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig) hf_config: MockHFConfig = field(default_factory=MockHFConfig) logits_processor_pattern: Optional[str] = None diff_sampling_param: Optional[dict] = None @@ -55,17 +54,21 @@ def get_diff_sampling_param(self): class MockLoRAResolver(LoRAResolver): - - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: if lora_name == "test-lora": - return LoRARequest(lora_name="test-lora", - lora_int_id=1, - lora_local_path="/fake/path/test-lora") + return LoRARequest( + lora_name="test-lora", + lora_int_id=1, + lora_local_path="/fake/path/test-lora", + ) elif lora_name == "invalid-lora": - return LoRARequest(lora_name="invalid-lora", - lora_int_id=2, - lora_local_path="/fake/path/invalid-lora") + return LoRARequest( + lora_name="invalid-lora", + lora_int_id=2, + lora_local_path="/fake/path/invalid-lora", + ) return None @@ -96,8 +99,7 @@ async def mock_add_lora_side_effect(lora_request: LoRARequest): return True if lora_request.lora_name == "invalid-lora": # Simulate failure during addition (e.g. 
invalid format) - raise ValueError(f"Simulated failure adding LoRA: " - f"{lora_request.lora_name}") + raise ValueError(f"Simulated failure adding LoRA: {lora_request.lora_name}") return True mock_engine.add_lora = AsyncMock(side_effect=mock_add_lora_side_effect) @@ -106,31 +108,31 @@ async def mock_generate(*args, **kwargs): for _ in []: yield _ - mock_engine.generate = MagicMock(spec=AsyncLLM.generate, - side_effect=mock_generate) + mock_engine.generate = MagicMock(spec=AsyncLLM.generate, side_effect=mock_generate) mock_engine.generate.reset_mock() mock_engine.add_lora.reset_mock() mock_model_config = MockModelConfig() - models = OpenAIServingModels(engine_client=mock_engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config) + models = OpenAIServingModels( + engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + ) - serving_completion = OpenAIServingCompletion(mock_engine, - mock_model_config, - models, - request_logger=None) + serving_completion = OpenAIServingCompletion( + mock_engine, mock_model_config, models, request_logger=None + ) - serving_completion._process_inputs = AsyncMock(return_value=(MagicMock( - name="engine_request"), {})) + serving_completion._process_inputs = AsyncMock( + return_value=(MagicMock(name="engine_request"), {}) + ) return mock_engine, serving_completion @pytest.mark.asyncio -async def test_serving_completion_with_lora_resolver(mock_serving_setup, - monkeypatch): +async def test_serving_completion_with_lora_resolver(mock_serving_setup, monkeypatch): monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") mock_engine, serving_completion = mock_serving_setup @@ -152,14 +154,13 @@ async def test_serving_completion_with_lora_resolver(mock_serving_setup, assert called_lora_request.lora_name == lora_model_name mock_engine.generate.assert_called_once() - called_lora_request = mock_engine.generate.call_args[1]['lora_request'] + called_lora_request = mock_engine.generate.call_args[1]["lora_request"] assert isinstance(called_lora_request, LoRARequest) assert called_lora_request.lora_name == lora_model_name @pytest.mark.asyncio -async def test_serving_completion_resolver_not_found(mock_serving_setup, - monkeypatch): +async def test_serving_completion_resolver_not_found(mock_serving_setup, monkeypatch): monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") mock_engine, serving_completion = mock_serving_setup @@ -182,7 +183,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup, @pytest.mark.asyncio async def test_serving_completion_resolver_add_lora_fails( - mock_serving_setup, monkeypatch): + mock_serving_setup, monkeypatch +): monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true") mock_engine, serving_completion = mock_serving_setup diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index f0b61902eb56..711505c74bca 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -54,19 +54,22 @@ def default_server_args(): ] -@pytest.fixture(scope="module", - params=[ - "", - "--enable-chunked-prefill", - "--disable-frontend-multiprocessing", - f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", - ]) +@pytest.fixture( + scope="module", + params=[ + "", + "--enable-chunked-prefill", + "--disable-frontend-multiprocessing", + f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", + ], +) def server(use_v1, default_server_args, request): if request.param: 
default_server_args.append(request.param) - env_dict = dict(VLLM_USE_V1='1' if use_v1 else '0') - with RemoteOpenAIServer(MODEL_NAME, default_server_args, - env_dict=env_dict) as remote_server: + env_dict = dict(VLLM_USE_V1="1" if use_v1 else "0") + with RemoteOpenAIServer( + MODEL_NAME, default_server_args, env_dict=env_dict + ) as remote_server: yield remote_server @@ -87,30 +90,36 @@ async def client(server): # {metric_family: [(suffix, expected_value)]} EXPECTED_VALUES = { "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)], - "vllm:time_per_output_token_seconds": - [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))], + "vllm:time_per_output_token_seconds": [ + ("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1)) + ], "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)], - "vllm:request_prompt_tokens": - [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS)], - "vllm:request_generation_tokens": - [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS)], + "vllm:request_prompt_tokens": [ + ("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), + ("_count", _NUM_REQUESTS), + ], + "vllm:request_generation_tokens": [ + ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), + ("_count", _NUM_REQUESTS), + ], "vllm:request_params_n": [("_count", _NUM_REQUESTS)], "vllm:request_params_max_tokens": [ ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), - ("_count", _NUM_REQUESTS) + ("_count", _NUM_REQUESTS), + ], + "vllm:iteration_tokens_total": [ + ( + "_sum", + _NUM_REQUESTS + * (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST), + ), + ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), ], - "vllm:iteration_tokens_total": - [("_sum", _NUM_REQUESTS * - (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST)), - ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST)], - "vllm:prompt_tokens": [("_total", - _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], + "vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], "vllm:generation_tokens": [ ("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST) ], @@ -119,14 +128,16 @@ async def client(server): @pytest.mark.asyncio -async def test_metrics_counts(server: RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): +async def test_metrics_counts( + server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool +): for _ in range(_NUM_REQUESTS): # sending a request triggers the metrics to be logged. 
await client.completions.create( model=MODEL_NAME, prompt=_TOKENIZED_PROMPT, - max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST) + max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST, + ) response = requests.get(server.url_for("metrics")) print(response.text) @@ -134,9 +145,10 @@ async def test_metrics_counts(server: RemoteOpenAIServer, # Loop over all expected metric_families for metric_family, suffix_values_list in EXPECTED_VALUES.items(): - if ((use_v1 and metric_family not in EXPECTED_METRICS_V1) - or (not server.show_hidden_metrics - and metric_family in HIDDEN_DEPRECATED_METRICS)): + if (use_v1 and metric_family not in EXPECTED_METRICS_V1) or ( + not server.show_hidden_metrics + and metric_family in HIDDEN_DEPRECATED_METRICS + ): continue found_metric = False @@ -160,14 +172,15 @@ async def test_metrics_counts(server: RemoteOpenAIServer, assert sample.value == expected_value, ( f"{metric_name_w_suffix} expected value of " f"{expected_value} did not match found value " - f"{sample.value}") + f"{sample.value}" + ) break assert found_suffix, ( f"Did not find {metric_name_w_suffix} in prom endpoint" ) break - assert found_metric, (f"Did not find {metric_family} in prom endpoint") + assert found_metric, f"Did not find {metric_family} in prom endpoint" EXPECTED_METRICS = [ @@ -290,30 +303,30 @@ async def test_metrics_counts(server: RemoteOpenAIServer, @pytest.mark.asyncio -async def test_metrics_exist(server: RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): +async def test_metrics_exist( + server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool +): # sending a request triggers the metrics to be logged. - await client.completions.create(model=MODEL_NAME, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) + await client.completions.create( + model=MODEL_NAME, prompt="Hello, my name is", max_tokens=5, temperature=0.0 + ) response = requests.get(server.url_for("metrics")) assert response.status_code == HTTPStatus.OK - for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS): - if (metric in HIDDEN_DEPRECATED_METRICS - and not server.show_hidden_metrics): + for metric in EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS: + if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics: continue assert metric in response.text @pytest.mark.asyncio -async def test_abort_metrics_reset(server: RemoteOpenAIServer, - client: openai.AsyncClient, use_v1: bool): - - running_requests, waiting_requests, kv_cache_usage = ( - _get_running_metrics_from_api(server, use_v1)) +async def test_abort_metrics_reset( + server: RemoteOpenAIServer, client: openai.AsyncClient, use_v1: bool +): + running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( + server, use_v1 + ) # Expect no running requests or kvcache usage assert running_requests == 0 @@ -328,15 +341,18 @@ async def test_abort_metrics_reset(server: RemoteOpenAIServer, model=MODEL_NAME, prompt=_TOKENIZED_PROMPT, max_tokens=100, # Long generation to give time to abort - temperature=0.0)) + temperature=0.0, + ) + ) tasks.append(task) # Wait a bit for requests to start processing await asyncio.sleep(0.5) # Check that we have running requests - running_requests, waiting_requests, kv_cache_usage = ( - _get_running_metrics_from_api(server, use_v1)) + running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( + server, use_v1 + ) # Expect running requests and kvcache usage assert running_requests > 0 @@ -355,17 +371,18 @@ async def 
test_abort_metrics_reset(server: RemoteOpenAIServer, # Verify running and waiting requests counts and KV cache usage are zero running_requests_after, waiting_requests_after, kv_cache_usage_after = ( - _get_running_metrics_from_api(server, use_v1)) + _get_running_metrics_from_api(server, use_v1) + ) - assert running_requests_after == 0,\ - (f"Expected 0 running requests after abort, got " - f"{running_requests_after}") - assert waiting_requests_after == 0,\ - (f"Expected 0 waiting requests after abort, got " - f"{waiting_requests_after}") - assert kv_cache_usage_after == 0,\ - (f"Expected 0% KV cache usage after abort, got " - f"{kv_cache_usage_after}") + assert running_requests_after == 0, ( + f"Expected 0 running requests after abort, got {running_requests_after}" + ) + assert waiting_requests_after == 0, ( + f"Expected 0 waiting requests after abort, got {waiting_requests_after}" + ) + assert kv_cache_usage_after == 0, ( + f"Expected 0% KV cache usage after abort, got {kv_cache_usage_after}" + ) def _get_running_metrics_from_api(server: RemoteOpenAIServer, use_v1: bool): @@ -377,8 +394,9 @@ def _get_running_metrics_from_api(server: RemoteOpenAIServer, use_v1: bool): # Verify running and waiting requests counts and KV cache usage are zero running_requests, waiting_requests, kv_cache_usage = None, None, None - kv_cache_usage_metric = ("vllm:kv_cache_usage_perc" - if use_v1 else "vllm:gpu_cache_usage_perc") + kv_cache_usage_metric = ( + "vllm:kv_cache_usage_perc" if use_v1 else "vllm:gpu_cache_usage_perc" + ) for family in text_string_to_metric_families(response.text): if family.name == "vllm:num_requests_running": @@ -411,28 +429,31 @@ def test_metrics_exist_run_batch(use_v1: bool): port = "8001" server_url = f"http://{base_url}:{port}" - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(input_batch) input_file.flush() - proc = subprocess.Popen([ - sys.executable, - "-m", - "vllm.entrypoints.openai.run_batch", - "-i", - input_file.name, - "-o", - output_file.name, - "--model", - "intfloat/multilingual-e5-small", - "--enable-metrics", - "--url", - base_url, - "--port", - port, - ], - env={"VLLM_USE_V1": "1"}) + proc = subprocess.Popen( + [ + sys.executable, + "-m", + "vllm.entrypoints.openai.run_batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "intfloat/multilingual-e5-small", + "--enable-metrics", + "--url", + base_url, + "--port", + port, + ], + env={"VLLM_USE_V1": "1"}, + ) def is_server_up(url): try: diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index 4ee34b19dea3..7d2968d96506 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -52,6 +52,5 @@ async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files): lora_models = models[1:] assert served_model.id == MODEL_NAME assert served_model.root == MODEL_NAME - assert all(lora_model.root == zephyr_lora_files - for lora_model in lora_models) + assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py index f0ce50debe49..ba463be1d5cd 100644 --- a/tests/entrypoints/openai/test_oot_registration.py +++ 
b/tests/entrypoints/openai/test_oot_registration.py @@ -25,13 +25,10 @@ def run_and_test_dummy_opt_api_server(model, tp=1): client = server.get_client() completion = client.chat.completions.create( model=model, - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Hello!" - }], + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], temperature=0, ) generated_text = completion.choices[0].message.content diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 73f79ac28d11..64fdaf08893a 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -75,10 +75,11 @@ def no_invalid_types(case: schemathesis.models.Case): http://localhost:8000/v1/chat/completions """ # noqa: E501 if hasattr(case, "body") and isinstance(case.body, dict): - if ("messages" in case.body - and isinstance(case.body["messages"], list) - and len(case.body["messages"]) > 0): - + if ( + "messages" in case.body + and isinstance(case.body["messages"], list) + and len(case.body["messages"]) > 0 + ): for message in case.body["messages"]: if not isinstance(message, dict): continue @@ -86,10 +87,11 @@ def no_invalid_types(case: schemathesis.models.Case): # Check for invalid file type in tokenize endpoint if op.method.lower() == "post" and op.path == "/tokenize": content = message.get("content", []) - if (isinstance(content, list) and len(content) > 0 - and any( - item.get("type") == "file" - for item in content)): + if ( + isinstance(content, list) + and len(content) > 0 + and any(item.get("type") == "file" for item in content) + ): return False # Check for invalid tool_calls with non-function types @@ -106,10 +108,13 @@ def no_invalid_types(case: schemathesis.models.Case): # Causing a server error in EBNF grammar parsing # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421 structured_outputs = case.body.get("structured_outputs", {}) - grammar = structured_outputs.get("grammar") if isinstance( - structured_outputs, dict) else None + grammar = ( + structured_outputs.get("grammar") + if isinstance(structured_outputs, dict) + else None + ) - if grammar == '': + if grammar == "": # Allow None (will be handled as no grammar) # But skip empty strings return False @@ -133,9 +138,8 @@ def test_openapi_stateless(case: schemathesis.Case): timeout = { # requires a longer timeout - ("POST", "/v1/chat/completions"): - LONG_TIMEOUT_SECONDS, + ("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS, }.get(key, DEFAULT_TIMEOUT_SECONDS) - #No need to verify SSL certificate for localhost + # No need to verify SSL certificate for localhost case.call_and_validate(verify=False, timeout=timeout) diff --git a/tests/entrypoints/openai/test_optional_middleware.py b/tests/entrypoints/openai/test_optional_middleware.py index eb387998c2cc..b67d6147937d 100644 --- a/tests/entrypoints/openai/test_optional_middleware.py +++ b/tests/entrypoints/openai/test_optional_middleware.py @@ -37,7 +37,7 @@ def server(request: pytest.FixtureRequest): "--enforce-eager", "--max-num-seqs", "2", - *passed_params + *passed_params, ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -73,8 +73,9 @@ async def test_missing_api_token(server: RemoteOpenAIServer): ) @pytest.mark.asyncio async def test_passed_api_token(server: RemoteOpenAIServer): - response = 
requests.get(server.url_for("v1/models"), - headers={"Authorization": "Bearer test"}) + response = requests.get( + server.url_for("v1/models"), headers={"Authorization": "Bearer test"} + ) assert response.status_code == HTTPStatus.OK @@ -110,7 +111,8 @@ async def test_enable_request_id_header(server: RemoteOpenAIServer): ) @pytest.mark.asyncio async def test_custom_request_id_header(server: RemoteOpenAIServer): - response = requests.get(server.url_for("health"), - headers={"X-Request-Id": "Custom"}) + response = requests.get( + server.url_for("health"), headers={"X-Request-Id": "Custom"} + ) assert "X-Request-Id" in response.headers assert response.headers.get("X-Request-Id") == "Custom" diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index bb4c633e5e50..81e2b52dfa71 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -17,7 +17,7 @@ @pytest.fixture(scope="function", autouse=True) def use_v1_only(monkeypatch): - monkeypatch.setenv('VLLM_USE_V1', '1') + monkeypatch.setenv("VLLM_USE_V1", "1") @pytest.mark.asyncio @@ -28,15 +28,16 @@ async def test_empty_prompt(): client = remote_server.get_async_client() with pytest.raises( - openai.BadRequestError, - match= - "Either prompt or prompt_embeds must be provided and non-empty." + openai.BadRequestError, + match="Either prompt or prompt_embeds must be provided and non-empty.", ): - await client.completions.create(model=model_name, - prompt="", - max_tokens=5, - temperature=0.0, - extra_body={"prompt_embeds": []}) + await client.completions.create( + model=model_name, + prompt="", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": []}, + ) @pytest.mark.asyncio @@ -46,23 +47,23 @@ async def test_out_of_vocab_token_ids(): with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() - with pytest.raises(openai.BadRequestError, - match=re.compile('.*out of vocabulary.*').pattern): - await client.completions.create(model=model_name, - prompt=[999999], - max_tokens=5, - temperature=0.0) + with pytest.raises( + openai.BadRequestError, match=re.compile(".*out of vocabulary.*").pattern + ): + await client.completions.create( + model=model_name, prompt=[999999], max_tokens=5, temperature=0.0 + ) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) @pytest.mark.parametrize( - "layout", - [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr]) + "layout", [torch.strided, torch.sparse_coo, torch.sparse_csc, torch.sparse_csr] +) @pytest.mark.parametrize("seq_len", [2, 10]) @pytest.mark.parametrize("hidden_size", [2, 10]) -def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout, - seq_len: int, hidden_size: int): +def test_load_prompt_embeds( + dtype: torch.dtype, layout: torch.layout, seq_len: int, hidden_size: int +): # construct arbitrary tensors of various dtypes, layouts, and sizes. 
# We need to check against different layouts to make sure that if a user # uses sparse tensors to reduce the transmission size of prompt embeddings, @@ -92,6 +93,6 @@ def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout, loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"] assert loaded_tensor.device.type == "cpu" assert loaded_tensor.layout == torch.strided - torch.testing.assert_close(loaded_tensor, - tensor.to("cpu").to_dense(), - equal_nan=True) + torch.testing.assert_close( + loaded_tensor, tensor.to("cpu").to_dense(), equal_nan=True + ) diff --git a/tests/entrypoints/openai/test_response_api_mcp_tools.py b/tests/entrypoints/openai/test_response_api_mcp_tools.py index b0eb84712c19..653d44f20b44 100644 --- a/tests/entrypoints/openai/test_response_api_mcp_tools.py +++ b/tests/entrypoints/openai/test_response_api_mcp_tools.py @@ -13,6 +13,7 @@ @pytest.fixture(scope="module") def monkeypatch_module(): from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() yield mpatch mpatch.undo() @@ -36,8 +37,7 @@ def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch): with monkeypatch_module.context() as m: m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv") - m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", - "code_interpreter,container") + m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container") with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -57,23 +57,26 @@ async def mcp_enabled_client(mcp_enabled_server): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") -async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, - model_name: str): +async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: str): response = await mcp_enabled_client.responses.create( model=model_name, # TODO: Ideally should be able to set max tool calls # to prevent multi-turn, but it is not currently supported # would speed up the test - input=("What's the first 4 digits after the decimal point of " - "cube root of `19910212 * 20250910`? " - "Show only the digits. The python interpreter is not stateful " - "and you must print to see the output."), - tools=[{ - "type": "mcp", - "server_label": "code_interpreter", - # URL unused for DemoToolServer - "server_url": "http://localhost:8888" - }], + input=( + "What's the first 4 digits after the decimal point of " + "cube root of `19910212 * 20250910`? " + "Show only the digits. The python interpreter is not stateful " + "and you must print to see the output." 
+ ), + tools=[ + { + "type": "mcp", + "server_label": "code_interpreter", + # URL unused for DemoToolServer + "server_url": "http://localhost:8888", + } + ], ) assert response is not None assert response.status == "completed" @@ -83,23 +86,26 @@ async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") -async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, - model_name: str): +async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str): response = await mcp_disabled_client.responses.create( model=model_name, # TODO: Ideally should be able to set max tool calls # to prevent multi-turn, but it is not currently supported # would speed up the test - input=("What's the first 4 digits after the decimal point of " - "cube root of `19910212 * 20250910`? " - "Show only the digits. The python interpreter is not stateful " - "and you must print to see the output."), - tools=[{ - "type": "mcp", - "server_label": "code_interpreter", - # URL unused for DemoToolServer - "server_url": "http://localhost:8888" - }], + input=( + "What's the first 4 digits after the decimal point of " + "cube root of `19910212 * 20250910`? " + "Show only the digits. The python interpreter is not stateful " + "and you must print to see the output." + ), + tools=[ + { + "type": "mcp", + "server_label": "code_interpreter", + # URL unused for DemoToolServer + "server_url": "http://localhost:8888", + } + ], ) assert response is not None assert response.status == "completed" diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 400779064ef5..fb0035de67c2 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -15,21 +15,15 @@ @pytest.fixture(scope="module") -def monkeypatch_module(): - from _pytest.monkeypatch import MonkeyPatch - mpatch = MonkeyPatch() - yield mpatch - mpatch.undo() - - -@pytest.fixture(scope="module") -def server(monkeypatch_module: pytest.MonkeyPatch): +def server(): args = ["--enforce-eager", "--tool-server", "demo"] + env_dict = dict( + VLLM_ENABLE_RESPONSES_API_STORE="1", + PYTHON_EXECUTION_BACKEND="dangerously_use_uv", + ) - with monkeypatch_module.context() as m: - m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: + yield remote_server @pytest_asyncio.fixture @@ -94,22 +88,10 @@ async def test_chat(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, input=[ - { - "role": "system", - "content": "Respond in Korean." - }, - { - "role": "user", - "content": "Hello!" - }, - { - "role": "assistant", - "content": "Hello! How can I help you today?" - }, - { - "role": "user", - "content": "What is 13 * 24? Explain your answer." - }, + {"role": "system", "content": "Respond in Korean."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hello! How can I help you today?"}, + {"role": "user", "content": "What is 13 * 24? 
Explain your answer."}, ], ) assert response is not None @@ -124,10 +106,7 @@ async def test_chat_with_input_type(client: OpenAI, model_name: str): input=[ { "role": "user", - "content": [{ - "type": "input_text", - "text": "What is 13*24?" - }], + "content": [{"type": "input_text", "text": "What is 13*24?"}], }, ], ) @@ -141,14 +120,10 @@ async def test_structured_output(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, input=[ - { - "role": "system", - "content": "Extract the event information." - }, + {"role": "system", "content": "Extract the event information."}, { "role": "user", - "content": - "Alice and Bob are going to a science fair on Friday.", + "content": "Alice and Bob are going to a science fair on Friday.", }, ], text={ @@ -158,18 +133,9 @@ async def test_structured_output(client: OpenAI, model_name: str): "schema": { "type": "object", "properties": { - "name": { - "type": "string" - }, - "date": { - "type": "string" - }, - "participants": { - "type": "array", - "items": { - "type": "string" - } - }, + "name": {"type": "string"}, + "date": {"type": "string"}, + "participants": {"type": "array", "items": {"type": "string"}}, }, "required": ["name", "date", "participants"], "additionalProperties": False, @@ -319,11 +285,10 @@ async def test_streaming_types(client: OpenAI, model_name: str): stack_of_event_types = [] async for event in response: - if event.type == 'response.created': + if event.type == "response.created": stack_of_event_types.append(event.type) - elif event.type == 'response.completed': - assert stack_of_event_types[-1] == pairs_of_event_types[ - event.type] + elif event.type == "response.completed": + assert stack_of_event_types[-1] == pairs_of_event_types[event.type] stack_of_event_types.pop() if event.type.endswith("added"): stack_of_event_types.append(event.type) @@ -332,8 +297,7 @@ async def test_streaming_types(client: OpenAI, model_name: str): continue stack_of_event_types.append(event.type) elif event.type.endswith("done"): - assert stack_of_event_types[-1] == pairs_of_event_types[ - event.type] + assert stack_of_event_types[-1] == pairs_of_event_types[event.type] stack_of_event_types.pop() assert len(stack_of_event_types) == 0 @@ -345,7 +309,7 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): # TODO: Add back when web search and code interpreter are available in CI prompts = [ "tell me a story about a cat in 20 words", - # "What is 13 * 24? Use python to calculate the result.", + "What is 13 * 24? Use python to calculate the result.", # "When did Jensen found NVIDIA? 
Search it and answer the year only.", ] @@ -358,12 +322,7 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): # { # "type": "web_search_preview" # }, - # { - # "type": "code_interpreter", - # "container": { - # "type": "auto" - # } - # }, + {"type": "code_interpreter", "container": {"type": "auto"}}, ], stream=True, background=background, @@ -381,11 +340,12 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): # test vllm custom types are in the response if event.type in [ - "response.completed", "response.in_progress", - "response.created" + "response.completed", + "response.in_progress", + "response.created", ]: - assert 'input_messages' in event.response.model_extra - assert 'output_messages' in event.response.model_extra + assert "input_messages" in event.response.model_extra + assert "output_messages" in event.response.model_extra if current_event_mode != event.type: current_event_mode = event.type @@ -396,21 +356,21 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): assert event.item.id != current_item_id current_item_id = event.item.id elif event.type in [ - "response.output_text.delta", - "response.reasoning_text.delta" + "response.output_text.delta", + "response.reasoning_text.delta", ]: assert event.item_id == current_item_id # verify content_index_id is correct if event.type in [ - "response.content_part.added", - "response.reasoning_part.added" + "response.content_part.added", + "response.reasoning_part.added", ]: assert event.content_index != current_content_index current_content_index = event.content_index elif event.type in [ - "response.output_text.delta", - "response.reasoning_text.delta" + "response.output_text.delta", + "response.reasoning_text.delta", ]: assert event.content_index == current_content_index @@ -420,8 +380,10 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): print(f"{event.delta}", end="", flush=True) elif "response.code_interpreter_call_code.done" in event.type: print(f"Code: {event.code}", end="", flush=True) - elif ("response.output_item.added" in event.type - and event.item.type == "web_search_call"): + elif ( + "response.output_item.added" in event.type + and event.item.type == "web_search_call" + ): print(f"Web search: {event.item.action}", end="", flush=True) events.append(event) @@ -432,13 +394,13 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): if background: starting_after = 5 async with await client.responses.retrieve( - response_id=resp_id, - stream=True, - starting_after=starting_after) as stream: + response_id=resp_id, stream=True, starting_after=starting_after + ) as stream: counter = starting_after async for event in stream: counter += 1 assert event == events[counter] + assert counter == len(events) - 1 @pytest.mark.asyncio @@ -448,9 +410,7 @@ async def test_web_search(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, input="Who is the president of South Korea as of now?", - tools=[{ - "type": "web_search_preview" - }], + tools=[{"type": "web_search_preview"}], ) assert response is not None assert response.status == "completed" @@ -458,27 +418,29 @@ async def test_web_search(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") async def test_code_interpreter(client: OpenAI, model_name: str): response = await 
client.responses.create( model=model_name, # TODO: Ideally should be able to set max tool calls # to prevent multi-turn, but it is not currently supported # would speed up the test - input=("What's the first 4 digits after the decimal point of " - "cube root of `19910212 * 20250910`? " - "Show only the digits. The python interpreter is not stateful " - "and you must print to see the output."), - tools=[{ - "type": "code_interpreter", - "container": { - "type": "auto" - } - }], + input=( + "What's the first 4 digits after the decimal point of " + "cube root of `19910212 * 20250910`? " + "Show only the digits. The python interpreter is not stateful " + "and you must print to see the output." + ), + tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], + temperature=0.0, # More deterministic output in response ) assert response is not None assert response.status == "completed" assert response.usage.output_tokens_details.tool_output_tokens > 0 + for item in response.output: + if item.type == "message": + output_string = item.content[0].text + print("output_string: ", output_string, flush=True) + assert "5846" in output_string def get_weather(latitude, longitude): @@ -505,26 +467,23 @@ def call_function(name, args): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_function_calling(client: OpenAI, model_name: str): - tools = [{ - "type": "function", - "name": "get_weather", - "description": - "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": { - "type": "number" - }, - "longitude": { - "type": "number" + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, }, + "required": ["latitude", "longitude"], + "additionalProperties": False, }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - }] + "strict": True, + } + ] response = await client.responses.create( model=model_name, @@ -547,11 +506,13 @@ async def test_function_calling(client: OpenAI, model_name: str): response_2 = await client.responses.create( model=model_name, - input=[{ - "type": "function_call_output", - "call_id": tool_call.call_id, - "output": str(result), - }], + input=[ + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], tools=tools, previous_response_id=response.id, ) @@ -591,17 +552,12 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): { "type": "function", "name": "get_weather", - "description": - "Get current temperature for provided coordinates in celsius.", # noqa + "description": "Get current temperature for provided coordinates in celsius.", # noqa "parameters": { "type": "object", "properties": { - "latitude": { - "type": "number" - }, - "longitude": { - "type": "number" - }, + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, }, "required": ["latitude", "longitude"], "additionalProperties": False, @@ -612,8 +568,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, - input= - "Help me plan a trip to a random place. And tell me the weather there.", + input="Help me plan a trip to a random place. 
And tell me the weather there.", tools=tools, ) assert response is not None @@ -630,11 +585,13 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): response_2 = await client.responses.create( model=model_name, - input=[{ - "type": "function_call_output", - "call_id": tool_call.call_id, - "output": str(result), - }], + input=[ + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], tools=tools, previous_response_id=response.id, ) @@ -652,11 +609,13 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): response_3 = await client.responses.create( model=model_name, - input=[{ - "type": "function_call_output", - "call_id": tool_call.call_id, - "output": str(result), - }], + input=[ + { + "type": "function_call_output", + "call_id": tool_call.call_id, + "output": str(result), + } + ], tools=tools, previous_response_id=response_2.id, ) @@ -668,26 +627,23 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_function_calling_required(client: OpenAI, model_name: str): - tools = [{ - "type": "function", - "name": "get_weather", - "description": - "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": { - "type": "number" - }, - "longitude": { - "type": "number" + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, }, + "required": ["latitude", "longitude"], + "additionalProperties": False, }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - }] + "strict": True, + } + ] with pytest.raises(BadRequestError): await client.responses.create( @@ -717,31 +673,27 @@ async def test_system_message_with_tools(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_function_calling_full_history(client: OpenAI, model_name: str): - tools = [{ - "type": "function", - "name": "get_weather", - "description": - "Get current temperature for provided coordinates in celsius.", # noqa - "parameters": { - "type": "object", - "properties": { - "latitude": { - "type": "number" - }, - "longitude": { - "type": "number" + tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current temperature for provided coordinates in celsius.", # noqa + "parameters": { + "type": "object", + "properties": { + "latitude": {"type": "number"}, + "longitude": {"type": "number"}, }, + "required": ["latitude", "longitude"], + "additionalProperties": False, }, - "required": ["latitude", "longitude"], - "additionalProperties": False, - }, - "strict": True, - }] + "strict": True, + } + ] - input_messages = [{ - "role": "user", - "content": "What's the weather like in Paris today?" 
- }] + input_messages = [ + {"role": "user", "content": "What's the weather like in Paris today?"} + ] response = await client.responses.create( model=model_name, @@ -758,8 +710,7 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str): result = call_function(name, args) - input_messages.extend( - response.output) # append model's function call message + input_messages.extend(response.output) # append model's function call message input_messages.append( { # append result message "type": "function_call_output", @@ -780,12 +731,12 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_output_messages_enabled(client: OpenAI, model_name: str, - server): +async def test_output_messages_enabled(client: OpenAI, model_name: str, server): response = await client.responses.create( model=model_name, input="What is the capital of South Korea?", - extra_body={"enable_response_messages": True}) + extra_body={"enable_response_messages": True}, + ) assert response is not None assert response.status == "completed" diff --git a/tests/entrypoints/openai/test_return_token_ids.py b/tests/entrypoints/openai/test_return_token_ids.py index ff8f193fec55..60a80210fb76 100644 --- a/tests/entrypoints/openai/test_return_token_ids.py +++ b/tests/entrypoints/openai/test_return_token_ids.py @@ -50,13 +50,16 @@ async def test_basic_completion_with_emoji(server): # Check against the expected prompt token IDs tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) encoded_tokens = tokenizer.encode( - "Complete this sentence with emojis: I love coding 🚀") + "Complete this sentence with emojis: I love coding 🚀" + ) # Check that encoded_tokens is a subsequence of prompt_token_ids - assert any(completion.choices[0].prompt_token_ids[i:i + - len(encoded_tokens)] - == encoded_tokens for i in range( - len(completion.choices[0].prompt_token_ids) - - len(encoded_tokens) + 1)) + assert any( + completion.choices[0].prompt_token_ids[i : i + len(encoded_tokens)] + == encoded_tokens + for i in range( + len(completion.choices[0].prompt_token_ids) - len(encoded_tokens) + 1 + ) + ) # Verify token_ids field is present in the choice assert completion.choices[0].token_ids is not None @@ -86,44 +89,38 @@ async def test_basic_completion_with_emoji(server): @pytest.mark.asyncio async def test_chat_completion_with_tool_use(server): """Test chat completion with tool use (get_weather function).""" - tools = [{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": - "string", - "description": - "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The unit of temperature", + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature", + }, }, + "required": ["location"], }, - "required": ["location"], }, - }, - }] + } + ] async with server.get_async_client() as client: # Test with return_token_ids enabled response = await client.chat.completions.create( model=MODEL_NAME, messages=[ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "What's the weather like in Paris?" - }, + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather like in Paris?"}, ], tools=tools, tool_choice="auto", @@ -145,10 +142,11 @@ async def test_chat_completion_with_tool_use(server): tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) prompt_text = tokenizer.decode(response.prompt_token_ids) assert prompt_text.startswith( - "<|im_start|>system\nYou are a helpful assistant.") + "<|im_start|>system\nYou are a helpful assistant." + ) assert prompt_text.endswith( - "What's the weather like in Paris?<|im_end|>\n" - "<|im_start|>assistant\n") + "What's the weather like in Paris?<|im_end|>\n<|im_start|>assistant\n" + ) response_text = tokenizer.decode(response.choices[0].token_ids) assert response_text.startswith('<tool_call>\n{"name": "get_weather"') @@ -164,14 +162,8 @@ async def test_chat_completion_with_tool_use(server): response_without = await client.chat.completions.create( model=MODEL_NAME, messages=[ - { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": "user", - "content": "What's the weather like in Paris?" - }, + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What's the weather like in Paris?"}, ], tools=tools, tool_choice="auto", @@ -203,7 +195,7 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server): extra_body={ "return_token_ids": True, "return_tokens_as_token_ids": True, - "prompt_logprobs": 1 + "prompt_logprobs": 1, }, ) @@ -228,16 +220,17 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server): # The prompt_token_ids should match the prompt portion assert len(completion.choices[0].token_ids) < len(logprobs_token_ids) response_token_ids_length = len(completion.choices[0].token_ids) - assert logprobs_token_ids[-response_token_ids_length:] == \ - completion.choices[0].token_ids + assert ( + logprobs_token_ids[-response_token_ids_length:] + == completion.choices[0].token_ids + ) # Verify tokenizer consistency tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) # Decode prompt tokens if completion.choices[0].prompt_token_ids: - prompt_text = tokenizer.decode( - completion.choices[0].prompt_token_ids) + prompt_text = tokenizer.decode(completion.choices[0].prompt_token_ids) # The decoded prompt should match or close to original prompt assert "Hello, world" in prompt_text @@ -255,10 +248,7 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server): stream=True, echo=False, logprobs=1, - extra_body={ - "return_token_ids": True, - "return_tokens_as_token_ids": True - }, + extra_body={"return_token_ids": True, "return_tokens_as_token_ids": True}, ) # Collect streamed tokens @@ -287,14 +277,8 @@ async def test_comparison_with_prompt_logprobs_and_logprobs(server): async def test_chat_completion_with_emoji_and_token_ids(server): """Test chat completion with emojis to verify token_ids handling.""" chat_messages = [ - { - "role": "system", - "content": "You like to use emojis in 
your responses." - }, - { - "role": "user", - "content": "Repeat after me: I love cats 🐱" - }, + {"role": "system", "content": "You like to use emojis in your responses."}, + {"role": "user", "content": "Repeat after me: I love cats 🐱"}, ] async with server.get_async_client() as client: response = await client.chat.completions.create( @@ -319,15 +303,16 @@ async def test_chat_completion_with_emoji_and_token_ids(server): decoded_prompt = tokenizer.decode(response.prompt_token_ids) assert decoded_prompt.startswith( - "<|im_start|>system\nYou like to use emojis in your responses.") + "<|im_start|>system\nYou like to use emojis in your responses." + ) assert decoded_prompt.endswith( - "I love cats 🐱<|im_end|>\n<|im_start|>assistant\n") + "I love cats 🐱<|im_end|>\n<|im_start|>assistant\n" + ) decoded_response = tokenizer.decode(response.choices[0].token_ids) # The content should match the response text # except the ending <|im_end|> - assert decoded_response == response.choices[ - 0].message.content + "<|im_end|>" + assert decoded_response == response.choices[0].message.content + "<|im_end|>" # Test with streaming stream = await client.chat.completions.create( @@ -348,14 +333,14 @@ async def test_chat_completion_with_emoji_and_token_ids(server): assert chunk.prompt_token_ids is not None assert isinstance(chunk.prompt_token_ids, list) # Check the prompt_token_ids match the initial prompt - decoded_prompt_stream = tokenizer.decode( - chunk.prompt_token_ids) + decoded_prompt_stream = tokenizer.decode(chunk.prompt_token_ids) assert decoded_prompt_stream == decoded_prompt first_chunk = False else: chunk_dump = chunk.model_dump() - assert "prompt_token_ids" not in chunk_dump, \ + assert "prompt_token_ids" not in chunk_dump, ( "Subsequent chunks should not have prompt_token_ids" + ) if chunk.choices: if chunk.choices[0].delta.content: diff --git a/tests/entrypoints/openai/test_return_tokens_as_ids.py b/tests/entrypoints/openai/test_return_tokens_as_ids.py index ef9d5234f231..adbcc1f2430c 100644 --- a/tests/entrypoints/openai/test_return_tokens_as_ids.py +++ b/tests/entrypoints/openai/test_return_tokens_as_ids.py @@ -44,22 +44,19 @@ def server_fixture(request, default_server_args): # noqa: F811 with RemoteOpenAIServer(MODEL_NAME, args_with_flag) as remote_server: yield (remote_server, True) else: - with RemoteOpenAIServer(MODEL_NAME, - default_server_args) as remote_server: + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: yield (remote_server, False) @pytest.mark.asyncio @pytest.mark.parametrize("server_fixture", [True, False], indirect=True) -async def test_completion_return_tokens_as_token_ids_completion( - server_fixture): +async def test_completion_return_tokens_as_token_ids_completion(server_fixture): server, use_server_flag = server_fixture request_args = {} if not use_server_flag: request_args["return_tokens_as_token_ids"] = True async with server.get_async_client() as client: - completion = await client.completions.create( model=MODEL_NAME, # Include Unicode characters to test for dividing a single @@ -70,7 +67,8 @@ async def test_completion_return_tokens_as_token_ids_completion( temperature=0, max_tokens=10, logprobs=1, - extra_body=request_args) + extra_body=request_args, + ) text = completion.choices[0].text token_strs = completion.choices[0].logprobs.tokens @@ -104,22 +102,22 @@ async def test_chat_return_tokens_as_token_ids_completion(server_fixture): # Include Unicode characters to test for dividing a single # character across multiple tokens: 🎉 is [28705, 
31862] for the # Zephyr tokenizer - messages=[{ - "role": "system", - "content": "You like to respond in only emojis, like 🎉" - }, { - "role": "user", - "content": "Please write some emojis: 🐱🐶🎉" - }], + messages=[ + { + "role": "system", + "content": "You like to respond in only emojis, like 🎉", + }, + {"role": "user", "content": "Please write some emojis: 🐱🐶🎉"}, + ], temperature=0, max_tokens=8, logprobs=True, - extra_body=request_args) + extra_body=request_args, + ) text = response.choices[0].message.content tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) token_ids = [] for logprob_content in response.choices[0].logprobs.content: - token_ids.append( - int(logprob_content.token.removeprefix("token_id:"))) + token_ids.append(int(logprob_content.token.removeprefix("token_id:"))) assert tokenizer.decode(token_ids, skip_special_tokens=True) == text diff --git a/tests/entrypoints/openai/test_root_path.py b/tests/entrypoints/openai/test_root_path.py index 7b4966848b9d..6bcb80878f07 100644 --- a/tests/entrypoints/openai/test_root_path.py +++ b/tests/entrypoints/openai/test_root_path.py @@ -51,26 +51,31 @@ class TestCase(NamedTuple): model_name=MODEL_NAME, base_url=["v1"], # http://localhost:8000/v1 api_key=ERROR_API_KEY, - expected_error=openai.AuthenticationError), + expected_error=openai.AuthenticationError, + ), TestCase( model_name=MODEL_NAME, base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 api_key=ERROR_API_KEY, - expected_error=openai.AuthenticationError), + expected_error=openai.AuthenticationError, + ), TestCase( model_name=MODEL_NAME, base_url=["v1"], # http://localhost:8000/v1 api_key=API_KEY, - expected_error=None), + expected_error=None, + ), TestCase( model_name=MODEL_NAME, base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 api_key=API_KEY, - expected_error=None), + expected_error=None, + ), ], ) -async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer, - test_case: TestCase): +async def test_chat_session_root_path_with_api_key( + server: RemoteOpenAIServer, test_case: TestCase +): saying: str = "Here is a common saying about apple. 
An apple a day, keeps" ctx = contextlib.nullcontext() if test_case.expected_error is not None: @@ -79,20 +84,16 @@ async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer, client = openai.AsyncOpenAI( api_key=test_case.api_key, base_url=server.url_for(*test_case.base_url), - max_retries=0) + max_retries=0, + ) chat_completion = await client.chat.completions.create( model=test_case.model_name, - messages=[{ - "role": "user", - "content": "tell me a common saying" - }, { - "role": "assistant", - "content": saying - }], - extra_body={ - "continue_final_message": True, - "add_generation_prompt": False - }) + messages=[ + {"role": "user", "content": "tell me a common saying"}, + {"role": "assistant", "content": saying}, + ], + extra_body={"continue_final_message": True, "add_generation_prompt": False}, + ) assert chat_completion.id is not None assert len(chat_completion.choices) == 1 diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index e23f41e983b0..d31dadf90679 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -35,15 +35,24 @@ def test_empty_file(): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write("") input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "intfloat/multilingual-e5-small" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "intfloat/multilingual-e5-small", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" @@ -53,15 +62,24 @@ def test_empty_file(): def test_completions(): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(INPUT_BATCH) input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "NousResearch/Meta-Llama-3-8B-Instruct" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "NousResearch/Meta-Llama-3-8B-Instruct", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" @@ -77,30 +95,48 @@ def test_completions_invalid_input(): """ Ensure that we fail when the input doesn't conform to the openai api. 
""" - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(INVALID_INPUT_BATCH) input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "NousResearch/Meta-Llama-3-8B-Instruct" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "NousResearch/Meta-Llama-3-8B-Instruct", + ], + ) proc.communicate() proc.wait() assert proc.returncode != 0, f"{proc=}" def test_embeddings(): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(INPUT_EMBEDDING_BATCH) input_file.flush() - proc = subprocess.Popen([ - "vllm", "run-batch", "-i", input_file.name, "-o", output_file.name, - "--model", "intfloat/multilingual-e5-small" - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "intfloat/multilingual-e5-small", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" @@ -112,24 +148,26 @@ def test_embeddings(): BatchRequestOutput.model_validate_json(line) -@pytest.mark.parametrize("input_batch", - [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH]) +@pytest.mark.parametrize("input_batch", [INPUT_SCORE_BATCH, INPUT_RERANK_BATCH]) def test_score(input_batch): - with tempfile.NamedTemporaryFile( - "w") as input_file, tempfile.NamedTemporaryFile( - "r") as output_file: + with ( + tempfile.NamedTemporaryFile("w") as input_file, + tempfile.NamedTemporaryFile("r") as output_file, + ): input_file.write(input_batch) input_file.flush() - proc = subprocess.Popen([ - "vllm", - "run-batch", - "-i", - input_file.name, - "-o", - output_file.name, - "--model", - "BAAI/bge-reranker-v2-m3", - ], ) + proc = subprocess.Popen( + [ + "vllm", + "run-batch", + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + "BAAI/bge-reranker-v2-m3", + ], + ) proc.communicate() proc.wait() assert proc.returncode == 0, f"{proc=}" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 81683854e177..abe5a5f4ffc1 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -6,7 +6,7 @@ import asyncio from contextlib import suppress from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock, MagicMock import pytest @@ -15,8 +15,7 @@ from vllm.config.multimodal import MultiModalConfig from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.v1.engine.async_llm import AsyncLLM @@ -31,14 +30,17 @@ @pytest.fixture(scope="module") def monkeypatch_module(): from _pytest.monkeypatch import MonkeyPatch + mpatch = MonkeyPatch() yield mpatch mpatch.undo() -@pytest.fixture(scope="module", - params=[True, 
False], - ids=["with_tool_parser", "without_tool_parser"]) +@pytest.fixture( + scope="module", + params=[True, False], + ids=["with_tool_parser", "without_tool_parser"], +) def with_tool_parser(request) -> bool: return request.param @@ -56,21 +58,25 @@ def default_server_args(with_tool_parser: bool): "0.8", ] if with_tool_parser: - args.extend([ - "--tool-call-parser", - "openai", - "--enable-auto-tool-choice", - ]) + args.extend( + [ + "--tool-call-parser", + "openai", + "--enable-auto-tool-choice", + ] + ) return args @pytest.fixture(scope="module") -def gptoss_server(monkeypatch_module: pytest.MonkeyPatch, - default_server_args: list[str]): +def gptoss_server( + monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str] +): with monkeypatch_module.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") - with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, - default_server_args) as remote_server: + with RemoteOpenAIServer( + GPT_OSS_MODEL_NAME, default_server_args + ) as remote_server: yield remote_server @@ -81,44 +87,41 @@ async def gptoss_client(gptoss_server): @pytest.mark.asyncio -async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI, - with_tool_parser: bool): - tools = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string" - }, - "state": { - "type": "string" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], +async def test_gpt_oss_chat_tool_call_streaming( + gptoss_client: OpenAI, with_tool_parser: bool +): + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "state": {"type": "string"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, }, + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"], }, - }, - }] + } + ] messages = [ - { - "role": "user", - "content": "What is the weather in Dallas, TX?" 
- }, + {"role": "user", "content": "What is the weather in Dallas, TX?"}, ] stream = await gptoss_client.chat.completions.create( model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools if with_tool_parser else None, - stream=True) + stream=True, + ) name = None args_buf = "" @@ -143,43 +146,34 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI, @pytest.mark.asyncio -async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, - with_tool_parser: bool): +async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool): if not with_tool_parser: pytest.skip("skip non-tool for multi-turn tests") - tools = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string" - }, - "state": { - "type": "string" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "state": {"type": "string"}, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, }, + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"], }, - }, - }] + } + ] messages = [ - { - "role": "system", - "content": "you are a helpful assistant" - }, - { - "role": "user", - "content": "What is the weather in Dallas, TX with celsius?" - }, + {"role": "system", "content": "you are a helpful assistant"}, + {"role": "user", "content": "What is the weather in Dallas, TX with celsius?"}, ] first = await gptoss_client.chat.completions.create( @@ -197,10 +191,9 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, assert not first_msg.content messages.append({"role": "assistant", "content": args1}) - messages.append({ - "role": "user", - "content": "Now convert to celsius and return JSON only" - }) + messages.append( + {"role": "user", "content": "Now convert to celsius and return JSON only"} + ) second = await gptoss_client.chat.completions.create( model=GPT_OSS_MODEL_NAME, @@ -209,8 +202,9 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, temperature=0.0, ) second_msg = second.choices[0].message - assert (second_msg.content is not None and len(second_msg.content) > 0) or \ - (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) + assert (second_msg.content is not None and len(second_msg.content) > 0) or ( + second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0 + ) MODEL_NAME = "openai-community/gpt2" @@ -218,7 +212,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, CHAT_TEMPLATE = "Dummy chat template for testing {}" BASE_MODEL_PATHS = [ BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME), - BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT) + BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT), ] @@ -239,9 +233,9 @@ class MockModelConfig: multimodal_config = MultiModalConfig() hf_config = MockHFConfig() logits_processor_pattern = None - diff_sampling_param: Optional[dict] = None + diff_sampling_param: dict | None = None allowed_local_media_path: str = "" - allowed_media_domains: Optional[list[str]] = None + allowed_media_domains: list[str] | None = None encoder_config = None generation_config: str = "auto" media_io_kwargs: dict[str, 
dict[str, Any]] = field(default_factory=dict) @@ -251,21 +245,33 @@ def get_diff_sampling_param(self): return self.diff_sampling_param or {} -def _build_serving_chat(engine: AsyncLLM, - model_config: MockModelConfig) -> OpenAIServingChat: - models = OpenAIServingModels(engine_client=engine, - base_model_paths=BASE_MODEL_PATHS, - model_config=model_config) - serving_chat = OpenAIServingChat(engine, - model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) - - async def _fake_process_inputs(request_id, engine_prompt, sampling_params, - *, lora_request, trace_headers, priority): +def _build_serving_chat( + engine: AsyncLLM, model_config: MockModelConfig +) -> OpenAIServingChat: + models = OpenAIServingModels( + engine_client=engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=model_config, + ) + serving_chat = OpenAIServingChat( + engine, + model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) + + async def _fake_process_inputs( + request_id, + engine_prompt, + sampling_params, + *, + lora_request, + trace_headers, + priority, + ): return dict(engine_prompt), {} serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs) @@ -274,7 +280,6 @@ async def _fake_process_inputs(request_id, engine_prompt, sampling_params, @dataclass class MockEngine: - async def get_model_config(self): return MockModelConfig() @@ -284,13 +289,15 @@ async def _async_serving_chat_init(): model_config = await engine.get_model_config() models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS) - serving_completion = OpenAIServingChat(engine, - model_config, - models, - response_role="assistant", - chat_template=CHAT_TEMPLATE, - chat_template_content_format="auto", - request_logger=None) + serving_completion = OpenAIServingChat( + engine, + model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + ) return serving_completion @@ -336,10 +343,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" - }], + messages=[{"role": "user", "content": "what is 1+1?"}], ) with suppress(Exception): @@ -371,10 +375,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): # Test Case 1: No max_tokens specified in request req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" - }], + messages=[{"role": "user", "content": "what is 1+1?"}], ) with suppress(Exception): @@ -416,10 +417,7 @@ async def test_serving_chat_should_set_correct_max_tokens(): # Test case 1: No max_tokens specified, defaults to context_window req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" 
- }], + messages=[{"role": "user", "content": "what is 1+1?"}], ) with suppress(Exception): @@ -446,11 +444,10 @@ async def test_serving_chat_should_set_correct_max_tokens(): @pytest.mark.asyncio async def test_serving_chat_could_load_correct_generation_config(): - mock_model_config = MockModelConfig() mock_model_config.diff_sampling_param = { "temperature": 0.5, - "repetition_penalty": 1.05 + "repetition_penalty": 1.05, } mock_engine = MagicMock(spec=AsyncLLM) @@ -462,10 +459,7 @@ async def test_serving_chat_could_load_correct_generation_config(): req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" - }], + messages=[{"role": "user", "content": "what is 1+1?"}], ) with suppress(Exception): @@ -508,10 +502,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): # Test cache_salt req = ChatCompletionRequest( model=MODEL_NAME, - messages=[{ - "role": "user", - "content": "what is 1+1?" - }], + messages=[{"role": "user", "content": "what is 1+1?"}], ) # By default, cache_salt in the engine prompt is not set diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py index ba6f10891159..0c52270c13af 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -34,7 +34,8 @@ def serving() -> OpenAIServing: @pytest.mark.asyncio async def test_async_mistral_tokenizer_does_not_block_event_loop( - serving: OpenAIServing): + serving: OpenAIServing, +): expected_tokens = [1, 2, 3] # Mock the blocking version to sleep @@ -45,10 +46,9 @@ def mocked_apply_chat_template(*_args, **_kwargs): mock_tokenizer = Mock(spec=MistralTokenizer) mock_tokenizer.apply_chat_template.side_effect = mocked_apply_chat_template - task = serving._apply_mistral_chat_template_async(tokenizer=mock_tokenizer, - messages=[], - chat_template=None, - tools=[]) + task = serving._apply_mistral_chat_template_async( + tokenizer=mock_tokenizer, messages=[], chat_template=None, tools=[] + ) # Ensure the event loop is not blocked blocked_count = 0 @@ -66,4 +66,4 @@ def mocked_apply_chat_template(*_args, **_kwargs): # Ensure task completes tokens = await task assert tokens == expected_tokens, "Mocked blocking tokenizer was not called" - assert blocked_count == 0, ("Event loop blocked during tokenization") + assert blocked_count == 0, "Event loop blocked during tokenization" diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index bc6a0341f59f..ed9dedcc6f08 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -8,19 +8,20 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.openai.protocol import (ErrorResponse, - LoadLoRAAdapterRequest, - UnloadLoRAAdapterRequest) -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + LoadLoRAAdapterRequest, + UnloadLoRAAdapterRequest, +) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.lora.request import LoRARequest MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] -LORA_LOADING_SUCCESS_MESSAGE = ( - "Success: LoRA adapter '{lora_name}' added successfully.") +LORA_LOADING_SUCCESS_MESSAGE = "Success: LoRA adapter '{lora_name}' added 
successfully." LORA_UNLOADING_SUCCESS_MESSAGE = ( - "Success: LoRA adapter '{lora_name}' removed successfully.") + "Success: LoRA adapter '{lora_name}' removed successfully." +) async def _async_serving_models_init() -> OpenAIServingModels: @@ -29,10 +30,12 @@ async def _async_serving_models_init() -> OpenAIServingModels: # Set the max_model_len attribute to avoid missing attribute mock_model_config.max_model_len = 2048 - serving_models = OpenAIServingModels(engine_client=mock_engine_client, - base_model_paths=BASE_MODEL_PATHS, - model_config=mock_model_config, - lora_modules=None) + serving_models = OpenAIServingModels( + engine_client=mock_engine_client, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + lora_modules=None, + ) await serving_models.init_static_loras() return serving_models @@ -42,19 +45,18 @@ async def _async_serving_models_init() -> OpenAIServingModels: async def test_serving_model_name(): serving_models = await _async_serving_models_init() assert serving_models.model_name(None) == MODEL_NAME - request = LoRARequest(lora_name="adapter", - lora_path="/path/to/adapter2", - lora_int_id=1) + request = LoRARequest( + lora_name="adapter", lora_path="/path/to/adapter2", lora_int_id=1 + ) assert serving_models.model_name(request) == request.lora_name @pytest.mark.asyncio async def test_load_lora_adapter_success(): serving_models = await _async_serving_models_init() - request = LoadLoRAAdapterRequest(lora_name="adapter", - lora_path="/path/to/adapter2") + request = LoadLoRAAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2") response = await serving_models.load_lora_adapter(request) - assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') + assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter") assert len(serving_models.lora_requests) == 1 assert "adapter" in serving_models.lora_requests assert serving_models.lora_requests["adapter"].lora_name == "adapter" @@ -73,15 +75,16 @@ async def test_load_lora_adapter_missing_fields(): @pytest.mark.asyncio async def test_load_lora_adapter_duplicate(): serving_models = await _async_serving_models_init() - request = LoadLoRAAdapterRequest(lora_name="adapter1", - lora_path="/path/to/adapter1") + request = LoadLoRAAdapterRequest( + lora_name="adapter1", lora_path="/path/to/adapter1" + ) response = await serving_models.load_lora_adapter(request) - assert response == LORA_LOADING_SUCCESS_MESSAGE.format( - lora_name='adapter1') + assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name="adapter1") assert len(serving_models.lora_requests) == 1 - request = LoadLoRAAdapterRequest(lora_name="adapter1", - lora_path="/path/to/adapter1") + request = LoadLoRAAdapterRequest( + lora_name="adapter1", lora_path="/path/to/adapter1" + ) response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.error.type == "InvalidUserInput" @@ -92,15 +95,15 @@ async def test_load_lora_adapter_duplicate(): @pytest.mark.asyncio async def test_unload_lora_adapter_success(): serving_models = await _async_serving_models_init() - request = LoadLoRAAdapterRequest(lora_name="adapter1", - lora_path="/path/to/adapter1") + request = LoadLoRAAdapterRequest( + lora_name="adapter1", lora_path="/path/to/adapter1" + ) response = await serving_models.load_lora_adapter(request) assert len(serving_models.lora_requests) == 1 request = UnloadLoRAAdapterRequest(lora_name="adapter1") response = await serving_models.unload_lora_adapter(request) - assert 
response == LORA_UNLOADING_SUCCESS_MESSAGE.format( - lora_name='adapter1') + assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(lora_name="adapter1") assert len(serving_models.lora_requests) == 0 diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index 58d92f72dfae..cd7bb06ad320 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -34,11 +34,9 @@ def need_builtin_tool_call(self) -> bool: def render_for_completion(self): return [] - async def init_tool_sessions(self, tool_server, exit_stack, request_id, - mcp_tools): + async def init_tool_sessions(self, tool_server, exit_stack, request_id, mcp_tools): self.init_tool_sessions_called = True - self.init_tool_sessions_args = (tool_server, exit_stack, request_id, - mcp_tools) + self.init_tool_sessions_args = (tool_server, exit_stack, request_id, mcp_tools) async def cleanup_session(self) -> None: pass @@ -96,35 +94,31 @@ async def serving_responses_instance(self): return instance @pytest.mark.asyncio - async def test_initialize_tool_sessions(self, serving_responses_instance, - mock_context, mock_exit_stack): + async def test_initialize_tool_sessions( + self, serving_responses_instance, mock_context, mock_exit_stack + ): """Test that method works correctly with only MCP tools""" request = ResponsesRequest(input="test input", tools=[]) # Call the method await serving_responses_instance._initialize_tool_sessions( - request, mock_context, mock_exit_stack) + request, mock_context, mock_exit_stack + ) assert mock_context.init_tool_sessions_called is False # Create only MCP tools tools = [ - { - "type": "web_search_preview" - }, - { - "type": "code_interpreter", - "container": { - "type": "auto" - } - }, + {"type": "web_search_preview"}, + {"type": "code_interpreter", "container": {"type": "auto"}}, ] request = ResponsesRequest(input="test input", tools=tools) # Call the method await serving_responses_instance._initialize_tool_sessions( - request, mock_context, mock_exit_stack) + request, mock_context, mock_exit_stack + ) # Verify that init_tool_sessions was called assert mock_context.init_tool_sessions_called @@ -165,25 +159,20 @@ def test_validate_generator_input(self, serving_responses_instance): """Test _validate_generator_input with valid prompt length""" # Create an engine prompt with valid length (less than max_model_len) valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len - engine_prompt = EngineTokensPrompt( - prompt_token_ids=valid_prompt_token_ids) + engine_prompt = EngineTokensPrompt(prompt_token_ids=valid_prompt_token_ids) # Call the method - result = serving_responses_instance._validate_generator_input( - engine_prompt) + result = serving_responses_instance._validate_generator_input(engine_prompt) # Should return None for valid input assert result is None # create an invalid engine prompt - invalid_prompt_token_ids = list( - range(200)) # 100 tokens >= 100 max_model_len - engine_prompt = EngineTokensPrompt( - prompt_token_ids=invalid_prompt_token_ids) + invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len + engine_prompt = EngineTokensPrompt(prompt_token_ids=invalid_prompt_token_ids) # Call the method - result = serving_responses_instance._validate_generator_input( - engine_prompt) + result = serving_responses_instance._validate_generator_input(engine_prompt) # Should return an ErrorResponse assert result is not None diff --git 
a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 29a94c852bba..ff46df81d0ff 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -24,16 +24,13 @@ async def test_shutdown_on_engine_failure(): with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: async with remote_server.get_async_client() as client: - - with pytest.raises( - (openai.APIConnectionError, openai.InternalServerError)): + with pytest.raises((openai.APIConnectionError, openai.InternalServerError)): # Asking for lots of prompt logprobs will currently crash the # engine. This may change in the future when that bug is fixed prompt = "Hello " * 4000 await client.completions.create( - model=MODEL_NAME, - prompt=prompt, - extra_body={"prompt_logprobs": 10}) + model=MODEL_NAME, prompt=prompt, extra_body={"prompt_logprobs": 10} + ) # Now the server should shut down return_code = remote_server.proc.wait(timeout=8) diff --git a/tests/entrypoints/openai/test_skip_tokenizer.py b/tests/entrypoints/openai/test_skip_tokenizer.py index b469fc76fc7a..6998566c03d0 100644 --- a/tests/entrypoints/openai/test_skip_tokenizer.py +++ b/tests/entrypoints/openai/test_skip_tokenizer.py @@ -29,7 +29,7 @@ def server(): "--max-num-seqs", "32", "--model-impl", - "terratorch" + "terratorch", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -39,7 +39,6 @@ def server(): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_single_request(server: RemoteOpenAIServer, model_name: str): - pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) @@ -47,40 +46,39 @@ async def test_single_request(server: RemoteOpenAIServer, model_name: str): torch.save(pixel_values, buffer_tiff) buffer_tiff.seek(0) binary_data = buffer_tiff.read() - base64_tensor_embedding = base64.b64encode(binary_data).decode('utf-8') + base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8") buffer_coord = io.BytesIO() torch.save(location_coords, buffer_coord) buffer_coord.seek(0) binary_data = buffer_coord.read() - base64_coord_embedding = base64.b64encode(binary_data).decode('utf-8') + base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8") prompt = { - "model": - model_name, - "additional_data": { - "prompt_token_ids": [1] - }, - "encoding_format": - "base64", - "messages": [{ - "role": - "user", - "content": [{ - "type": "image_embeds", - "image_embeds": { - "pixel_values": base64_tensor_embedding, - "location_coords": base64_coord_embedding, - }, - }], - }] + "model": model_name, + "additional_data": {"prompt_token_ids": [1]}, + "encoding_format": "base64", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": { + "pixel_values": base64_tensor_embedding, + "location_coords": base64_coord_embedding, + }, + } + ], + } + ], } # test single pooling response = requests.post(server.url_for("pooling"), json=prompt) response.raise_for_status() - output = response.json()["data"][0]['data'] + output = response.json()["data"][0]["data"] np_response = np.frombuffer(base64.b64decode(output), dtype=np.float32) diff --git a/tests/entrypoints/openai/test_sleep.py b/tests/entrypoints/openai/test_sleep.py index 0dd6af17ef22..e07436f89d2d 100644 --- a/tests/entrypoints/openai/test_sleep.py +++ b/tests/entrypoints/openai/test_sleep.py @@ -20,14 +20,12 @@ def test_sleep_mode(): "--enable-sleep-mode", ] - with 
RemoteOpenAIServer(MODEL_NAME, - args, - env_dict={ - "VLLM_SERVER_DEV_MODE": "1", - "CUDA_VISIBLE_DEVICES": "0" - }) as remote_server: - response = requests.post(remote_server.url_for("sleep"), - params={"level": "1"}) + with RemoteOpenAIServer( + MODEL_NAME, + args, + env_dict={"VLLM_SERVER_DEV_MODE": "1", "CUDA_VISIBLE_DEVICES": "0"}, + ) as remote_server: + response = requests.post(remote_server.url_for("sleep"), params={"level": "1"}) assert response.status_code == 200 response = requests.get(remote_server.url_for("is_sleeping")) assert response.status_code == 200 @@ -40,12 +38,12 @@ def test_sleep_mode(): assert response.json().get("is_sleeping") is False # test wake up with tags - response = requests.post(remote_server.url_for("sleep"), - params={"level": "1"}) + response = requests.post(remote_server.url_for("sleep"), params={"level": "1"}) assert response.status_code == 200 - response = requests.post(remote_server.url_for("wake_up"), - params={"tags": ["weights"]}) + response = requests.post( + remote_server.url_for("wake_up"), params={"tags": ["weights"]} + ) assert response.status_code == 200 # is sleeping should be false after waking up any part of the engine @@ -53,8 +51,9 @@ def test_sleep_mode(): assert response.status_code == 200 assert response.json().get("is_sleeping") is True - response = requests.post(remote_server.url_for("wake_up"), - params={"tags": ["kv_cache"]}) + response = requests.post( + remote_server.url_for("wake_up"), params={"tags": ["kv_cache"]} + ) assert response.status_code == 200 response = requests.get(remote_server.url_for("is_sleeping")) diff --git a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py index 058e96f203c3..80b7cd9f4cbc 100644 --- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py +++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -11,7 +11,10 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model) + TensorizerConfig, + tensorize_lora_adapter, + tensorize_vllm_model, +) from ...utils import RemoteOpenAIServer @@ -29,21 +32,20 @@ def cleanup(): _cleanup() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def tmp_dir(): with tempfile.TemporaryDirectory() as path: yield path -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def model_uri(tmp_dir): yield f"{tmp_dir}/model.tensors" @pytest.fixture(scope="module") def tensorize_model_and_lora(tmp_dir, model_uri): - tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, - lora_dir=tmp_dir) + tensorizer_config = TensorizerConfig(tensorizer_uri=model_uri, lora_dir=tmp_dir) args = EngineArgs(model=MODEL_NAME) tensorize_lora_adapter(LORA_PATH, tensorizer_config) @@ -66,8 +68,11 @@ def server(model_uri, tensorize_model_and_lora): ## Start OpenAI API server args = [ - "--load-format", "tensorizer", "--served-model-name", MODEL_NAME, - "--enable-lora" + "--load-format", + "tensorizer", + "--served-model-name", + MODEL_NAME, + "--enable-lora", ] model_dir = os.path.dirname(model_uri) @@ -85,10 +90,9 @@ async def client(server): @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): _cleanup() - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) + completion = await client.completions.create( + model=model_name, 
prompt="Hello, my name is", max_tokens=5, temperature=0.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -97,4 +101,5 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): assert len(completion.choices[0].text) >= 5 assert completion.choices[0].finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) + completion_tokens=5, prompt_tokens=6, total_tokens=11 + ) diff --git a/tests/entrypoints/openai/test_token_in_token_out.py b/tests/entrypoints/openai/test_token_in_token_out.py index ed003939c44b..25eb5882be89 100644 --- a/tests/entrypoints/openai/test_token_in_token_out.py +++ b/tests/entrypoints/openai/test_token_in_token_out.py @@ -6,8 +6,7 @@ import pytest -from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf) +from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf from vllm.transformers_utils.tokenizer import get_tokenizer from ...utils import RemoteOpenAIServer @@ -23,7 +22,8 @@ def server(): MODEL_NAME, allow_patterns=["*"], cache_dir=MODEL_PATH, - ignore_patterns=["tokenizer*", "vocab*", "*.safetensors"]) + ignore_patterns=["tokenizer*", "vocab*", "*.safetensors"], + ) args = [ "--max-model-len", "2048", @@ -61,13 +61,14 @@ async def test_token_in_token_out_and_logprobs(server): ) # Verify all fields are present - assert (completion.choices[0].token_ids is not None - and 0 < len(completion.choices[0].token_ids) <= 20) + assert ( + completion.choices[0].token_ids is not None + and 0 < len(completion.choices[0].token_ids) <= 20 + ) assert completion.choices[0].prompt_token_ids is not None # Decode prompt tokens if completion.choices[0].prompt_token_ids: - prompt_text = tokenizer.decode( - completion.choices[0].prompt_token_ids) + prompt_text = tokenizer.decode(completion.choices[0].prompt_token_ids) # The decoded prompt should match or close to original prompt assert prompt_text == text diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index ecb7f50fa740..7fd32e1c7be1 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -53,19 +53,20 @@ async def test_tokenize_completions( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") for add_special in [False, True]: prompt = "vllm1 This is a test prompt." tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post(server.url_for("tokenize"), - json={ - "add_special_tokens": add_special, - "model": model_name, - "prompt": prompt - }) + response = requests.post( + server.url_for("tokenize"), + json={ + "add_special_tokens": add_special, + "model": model_name, + "prompt": prompt, + }, + ) response.raise_for_status() result = response.json() @@ -86,48 +87,39 @@ async def test_tokenize_chat( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") for add_generation in [False, True]: for add_special in [False, True]: - conversation = [{ - "role": "user", - "content": "Hi there!" - }, { - "role": "assistant", - "content": "Nice to meet you!" 
- }, { - "role": "user", - "content": "Can I ask a question? vllm1" - }] + conversation = [ + {"role": "user", "content": "Hi there!"}, + {"role": "assistant", "content": "Nice to meet you!"}, + {"role": "user", "content": "Can I ask a question? vllm1"}, + ] for continue_final in [False, True]: if add_generation and continue_final: continue if continue_final: - conversation.append({ - "role": "assistant", - "content": "Sure," - }) + conversation.append({"role": "assistant", "content": "Sure,"}) prompt = tokenizer.apply_chat_template( add_generation_prompt=add_generation, continue_final_message=continue_final, conversation=conversation, - tokenize=False) - tokens = tokenizer.encode(prompt, - add_special_tokens=add_special) - - response = requests.post(server.url_for("tokenize"), - json={ - "add_generation_prompt": - add_generation, - "continue_final_message": - continue_final, - "add_special_tokens": add_special, - "messages": conversation, - "model": model_name - }) + tokenize=False, + ) + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) + + response = requests.post( + server.url_for("tokenize"), + json={ + "add_generation_prompt": add_generation, + "continue_final_message": continue_final, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name, + }, + ) response.raise_for_status() result = response.json() @@ -148,41 +140,35 @@ async def test_tokenize_chat_with_tools( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") for add_generation in [False, True]: for add_special in [False, True]: - conversation = [{ - "role": - "user", - "content": - "What's the weather like in Paris today?", - }] - - tools = [{ - "type": "function", - "function": { - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string" - } + conversation = [ + { + "role": "user", + "content": "What's the weather like in Paris today?", + } + ] + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, }, }, - }, - }] + } + ] for continue_final in [False, True]: if add_generation and continue_final: continue if continue_final: - conversation.append({ - "role": "assistant", - "content": "Sure," - }) + conversation.append({"role": "assistant", "content": "Sure,"}) prompt = tokenizer.apply_chat_template( add_generation_prompt=add_generation, @@ -191,8 +177,7 @@ async def test_tokenize_chat_with_tools( tools=tools, tokenize=False, ) - tokens = tokenizer.encode(prompt, - add_special_tokens=add_special) + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) response = requests.post( server.url_for("tokenize"), @@ -225,17 +210,12 @@ async def test_tokenize_with_return_token_strs( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") prompt = "This is a token_strs test prompt! 
vllm1" response = requests.post( server.url_for("tokenize"), - json={ - "prompt": prompt, - "model": model_name, - "return_token_strs": True - }, + json={"prompt": prompt, "model": model_name, "return_token_strs": True}, ) response.raise_for_status() @@ -260,17 +240,14 @@ async def test_detokenize( model_name: str, tokenizer_name: str, ): - tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, tokenizer_mode="fast") prompt = "This is a test prompt. vllm1" tokens = tokenizer.encode(prompt, add_special_tokens=False) - response = requests.post(server.url_for("detokenize"), - json={ - "model": model_name, - "tokens": tokens - }) + response = requests.post( + server.url_for("detokenize"), json={"model": model_name, "tokens": tokens} + ) response.raise_for_status() assert response.json() == {"prompt": prompt} @@ -319,14 +296,15 @@ async def test_tokenizer_info_schema(server: RemoteOpenAIServer): } for field, expected_type in field_types.items(): if field in result and result[field] is not None: - assert isinstance( - result[field], - expected_type), (f"{field} should be {expected_type.__name__}") + assert isinstance(result[field], expected_type), ( + f"{field} should be {expected_type.__name__}" + ) @pytest.mark.asyncio async def test_tokenizer_info_added_tokens_structure( - server: RemoteOpenAIServer, ): + server: RemoteOpenAIServer, +): """Test added_tokens_decoder structure if present.""" response = requests.get(server.url_for("tokenizer_info")) response.raise_for_status() @@ -337,25 +315,23 @@ async def test_tokenizer_info_added_tokens_structure( assert isinstance(token_id, str), "Token IDs should be strings" assert isinstance(token_info, dict), "Token info should be a dict" assert "content" in token_info, "Token info should have content" - assert "special" in token_info, ( - "Token info should have special flag") - assert isinstance(token_info["special"], - bool), ("Special flag should be boolean") + assert "special" in token_info, "Token info should have special flag" + assert isinstance(token_info["special"], bool), ( + "Special flag should be boolean" + ) @pytest.mark.asyncio async def test_tokenizer_info_consistency_with_tokenize( - server: RemoteOpenAIServer, ): + server: RemoteOpenAIServer, +): """Test that tokenizer info is consistent with tokenization endpoint.""" info_response = requests.get(server.url_for("tokenizer_info")) info_response.raise_for_status() info = info_response.json() tokenize_response = requests.post( server.url_for("tokenize"), - json={ - "model": MODEL_NAME, - "prompt": "Hello world!" 
- }, + json={"model": MODEL_NAME, "prompt": "Hello world!"}, ) tokenize_response.raise_for_status() tokenize_result = tokenize_response.json() @@ -363,7 +339,8 @@ async def test_tokenizer_info_consistency_with_tokenize( tokenize_max_len = tokenize_result.get("max_model_len") if info_max_len and tokenize_max_len: assert info_max_len >= tokenize_max_len, ( - "Info max length should be >= tokenize max length") + "Info max length should be >= tokenize max length" + ) @pytest.mark.asyncio @@ -374,6 +351,5 @@ async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer): result = response.json() chat_template = result.get("chat_template") if chat_template: - assert isinstance(chat_template, - str), ("Chat template should be a string") + assert isinstance(chat_template, str), "Chat template should be a string" assert chat_template.strip(), "Chat template should not be empty" diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index 23c99da97ad3..6ef932392d09 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -17,8 +17,12 @@ MODEL_NAME = "openai/whisper-large-v3-turbo" SERVER_ARGS = ["--enforce-eager"] MISTRAL_FORMAT_ARGS = [ - "--tokenizer_mode", "mistral", "--config_format", "mistral", - "--load_format", "mistral" + "--tokenizer_mode", + "mistral", + "--config_format", + "mistral", + "--load_format", + "mistral", ] @@ -36,8 +40,8 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize( - "model_name", - ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"]) + "model_name", ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"] +) async def test_basic_audio(mary_had_lamb, model_name): server_args = ["--enforce-eager"] @@ -52,10 +56,11 @@ async def test_basic_audio(mary_had_lamb, model_name): file=mary_had_lamb, language="en", response_format="text", - temperature=0.0) + temperature=0.0, + ) out = json.loads(transcription) - out_text = out['text'] - out_usage = out['usage'] + out_text = out["text"] + out_usage = out["usage"] assert "Mary had a little lamb," in out_text assert out_usage["seconds"] == 16, out_usage["seconds"] @@ -74,8 +79,9 @@ async def test_basic_audio_gemma(foscolo): file=foscolo, language="it", response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] + temperature=0.0, + ) + out = json.loads(transcription)["text"] assert "da cui vergine nacque Venere" in out @@ -85,24 +91,21 @@ async def test_non_asr_model(winning_call): model_name = "JackFram/llama-68m" with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server: client = remote_server.get_async_client() - res = await client.audio.transcriptions.create(model=model_name, - file=winning_call, - language="en", - temperature=0.0) + res = await client.audio.transcriptions.create( + model=model_name, file=winning_call, language="en", temperature=0.0 + ) err = res.error assert err["code"] == 400 and not res.text - assert err[ - "message"] == "The model does not support Transcriptions API" + assert err["message"] == "The model does not support Transcriptions API" @pytest.mark.asyncio async def test_bad_requests(mary_had_lamb, client): # invalid language with pytest.raises(openai.BadRequestError): - await client.audio.transcriptions.create(model=MODEL_NAME, - file=mary_had_lamb, - language="hh", - temperature=0.0) + await client.audio.transcriptions.create( + model=MODEL_NAME, 
file=mary_had_lamb, language="hh", temperature=0.0 + ) @pytest.mark.asyncio @@ -114,17 +117,18 @@ async def test_long_audio_request(mary_had_lamb, client): repeated_audio = np.tile(audio, 10) # Repeated audio to buffer buffer = io.BytesIO() - sf.write(buffer, repeated_audio, sr, format='WAV') + sf.write(buffer, repeated_audio, sr, format="WAV") buffer.seek(0) transcription = await client.audio.transcriptions.create( model=MODEL_NAME, file=buffer, language="en", response_format="text", - temperature=0.0) + temperature=0.0, + ) out = json.loads(transcription) - out_text = out['text'] - out_usage = out['usage'] + out_text = out["text"] + out_usage = out["usage"] counts = out_text.count("Mary had a little lamb") assert counts == 10, counts assert out_usage["seconds"] == 161, out_usage["seconds"] @@ -135,10 +139,8 @@ async def test_completion_endpoints(client): # text to text model res = await client.chat.completions.create( model=MODEL_NAME, - messages=[{ - "role": "system", - "content": "You are a helpful assistant." - }]) + messages=[{"role": "system", "content": "You are a helpful assistant."}], + ) err = res.error assert err["code"] == 400 assert err["message"] == "The model does not support Chat Completions API" @@ -157,16 +159,19 @@ async def test_streaming_response(winning_call, client): file=winning_call, response_format="json", language="en", - temperature=0.0) - res = await client.audio.transcriptions.create(model=MODEL_NAME, - file=winning_call, - language="en", - temperature=0.0, - stream=True, - timeout=30) + temperature=0.0, + ) + res = await client.audio.transcriptions.create( + model=MODEL_NAME, + file=winning_call, + language="en", + temperature=0.0, + stream=True, + timeout=30, + ) # Reconstruct from chunks and validate async for chunk in res: - text = chunk.choices[0]['delta']['content'] + text = chunk.choices[0]["delta"]["content"] transcription += text assert transcription == res_no_stream.text @@ -180,9 +185,9 @@ async def test_stream_options(winning_call, client): language="en", temperature=0.0, stream=True, - extra_body=dict(stream_include_usage=True, - stream_continuous_usage_stats=True), - timeout=30) + extra_body=dict(stream_include_usage=True, stream_continuous_usage_stats=True), + timeout=30, + ) final = False continuous = True async for chunk in res: @@ -190,7 +195,7 @@ async def test_stream_options(winning_call, client): # final usage sent final = True else: - continuous = continuous and hasattr(chunk, 'usage') + continuous = continuous and hasattr(chunk, "usage") assert final and continuous @@ -198,27 +203,31 @@ async def test_stream_options(winning_call, client): async def test_sampling_params(mary_had_lamb, client): """ Compare sampling with params and greedy sampling to assert results - are different when extreme sampling parameters values are picked. + are different when extreme sampling parameters values are picked. 
""" transcription = await client.audio.transcriptions.create( model=MODEL_NAME, file=mary_had_lamb, language="en", temperature=0.8, - extra_body=dict(seed=42, - repetition_penalty=1.9, - top_k=12, - top_p=0.4, - min_p=0.5, - frequency_penalty=1.8, - presence_penalty=2.0)) + extra_body=dict( + seed=42, + repetition_penalty=1.9, + top_k=12, + top_p=0.4, + min_p=0.5, + frequency_penalty=1.8, + presence_penalty=2.0, + ), + ) greedy_transcription = await client.audio.transcriptions.create( model=MODEL_NAME, file=mary_had_lamb, language="en", temperature=0.0, - extra_body=dict(seed=42)) + extra_body=dict(seed=42), + ) assert greedy_transcription.text != transcription.text @@ -226,15 +235,16 @@ async def test_sampling_params(mary_had_lamb, client): @pytest.mark.asyncio async def test_audio_prompt(mary_had_lamb, client): prompt = "This is a speech, recorded in a phonograph." - #Prompts should not omit the part of original prompt while transcribing. + # Prompts should not omit the part of original prompt while transcribing. prefix = "The first words I spoke in the original phonograph" transcription = await client.audio.transcriptions.create( model=MODEL_NAME, file=mary_had_lamb, language="en", response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] + temperature=0.0, + ) + out = json.loads(transcription)["text"] assert prefix in out transcription_wprompt = await client.audio.transcriptions.create( model=MODEL_NAME, @@ -242,6 +252,7 @@ async def test_audio_prompt(mary_had_lamb, client): language="en", response_format="text", prompt=prompt, - temperature=0.0) - out_prompt = json.loads(transcription_wprompt)['text'] + temperature=0.0, + ) + out_prompt = json.loads(transcription_wprompt)["text"] assert prefix in out_prompt diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index eb7879927b9b..f35742e166fe 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import io + # imports for structured outputs tests import json @@ -17,8 +18,9 @@ SERVER_ARGS = ["--enforce-eager"] -@pytest.fixture(scope="module", - params=["openai/whisper-small", "google/gemma-3n-E2B-it"]) +@pytest.fixture( + scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"] +) def server(request): # Parametrize over model name with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server: @@ -38,9 +40,9 @@ async def test_non_asr_model(foscolo): model_name = "JackFram/llama-68m" with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server: client = remote_server.get_async_client() - res = await client.audio.translations.create(model=model_name, - file=foscolo, - temperature=0.0) + res = await client.audio.translations.create( + model=model_name, file=foscolo, temperature=0.0 + ) err = res.error assert err["code"] == 400 and not res.text assert err["message"] == "The model does not support Translations API" @@ -56,8 +58,9 @@ async def test_basic_audio(foscolo, client_and_model): response_format="text", # TODO remove `language="it"` once language detection is implemented extra_body=dict(language="it", to_language="en"), - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() + temperature=0.0, + ) + out = json.loads(translation)["text"].strip().lower() assert "greek sea" in out @@ -72,8 +75,9 @@ async def test_audio_prompt(foscolo, 
client_and_model): prompt=prompt, extra_body=dict(language="it", to_language="en"), response_format="text", - temperature=0.0) - out = json.loads(transcription)['text'] + temperature=0.0, + ) + out = json.loads(transcription)["text"] assert "Nor will I ever touch the sacred" not in out assert prompt not in out @@ -87,7 +91,8 @@ async def test_streaming_response(foscolo, client_and_model, server): file=foscolo, response_format="json", extra_body=dict(language="it", to_language="en", seed=42), - temperature=0.0) + temperature=0.0, + ) # Stream via HTTPX since OpenAI translation client doesn't expose streaming server, model_name = server @@ -104,16 +109,14 @@ async def test_streaming_response(foscolo, client_and_model, server): foscolo.seek(0) async with httpx.AsyncClient() as http_client: files = {"file": foscolo} - async with http_client.stream("POST", - url, - headers=headers, - data=data, - files=files) as response: + async with http_client.stream( + "POST", url, headers=headers, data=data, files=files + ) as response: async for line in response.aiter_lines(): if not line: continue if line.startswith("data: "): - line = line[len("data: "):] + line = line[len("data: ") :] if line.strip() == "[DONE]": break chunk = json.loads(line) @@ -124,9 +127,10 @@ async def test_streaming_response(foscolo, client_and_model, server): # NOTE There's a small non-deterministic issue here, likely in the attn # computation, which will cause a few tokens to be different, while still # being very close semantically. - assert sum([ - x == y for x, y in zip(res_stream, res_no_stream.text.split()) - ]) >= len(res_stream) * 0.9 + assert ( + sum([x == y for x, y in zip(res_stream, res_no_stream.text.split())]) + >= len(res_stream) * 0.9 + ) @pytest.mark.asyncio @@ -148,16 +152,14 @@ async def test_stream_options(foscolo, server): continuous = True async with httpx.AsyncClient() as http_client: files = {"file": foscolo} - async with http_client.stream("POST", - url, - headers=headers, - data=data, - files=files) as response: + async with http_client.stream( + "POST", url, headers=headers, data=data, files=files + ) as response: async for line in response.aiter_lines(): if not line: continue if line.startswith("data: "): - line = line[len("data: "):] + line = line[len("data: ") :] if line.strip() == "[DONE]": break chunk = json.loads(line) @@ -180,13 +182,14 @@ async def test_long_audio_request(foscolo, client_and_model): repeated_audio = np.tile(audio, 2) # Repeated audio to buffer buffer = io.BytesIO() - sf.write(buffer, repeated_audio, sr, format='WAV') + sf.write(buffer, repeated_audio, sr, format="WAV") buffer.seek(0) translation = await client.audio.translations.create( model=model_name, file=buffer, extra_body=dict(language="it", to_language="en"), response_format="text", - temperature=0.0) - out = json.loads(translation)['text'].strip().lower() + temperature=0.0, + ) + out = json.loads(translation)["text"].strip().lower() assert out.count("greek sea") == 2 diff --git a/tests/entrypoints/openai/test_video.py b/tests/entrypoints/openai/test_video.py index ad4dff00daaa..4c7d1c14ca17 100644 --- a/tests/entrypoints/openai/test_video.py +++ b/tests/entrypoints/openai/test_video.py @@ -58,24 +58,18 @@ def base64_encoded_video() -> dict[str, str]: @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_single_chat_session_video(client: openai.AsyncOpenAI, - model_name: str, video_url: str): - messages = [{ - "role": - "user", - 
"content": [ - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] +async def test_single_chat_session_video( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -84,13 +78,15 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=6287, total_tokens=6297) + completion_tokens=10, prompt_tokens=6287, total_tokens=6297 + ) message = choice.message message = chat_completion.choices[0].message @@ -112,54 +108,44 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_error_on_invalid_video_url_type(client: openai.AsyncOpenAI, - model_name: str, - video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": video_url - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] +async def test_error_on_invalid_video_url_type( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": video_url}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # video_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0) + _ = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI, - model_name: str, - video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] +async def test_single_chat_session_video_beamsearch( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -168,36 +154,38 @@ async def test_single_chat_session_video_beamsearch(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, top_logprobs=5, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[ - 0].message.content != chat_completion.choices[1].message.content + assert ( + chat_completion.choices[0].message.content + != chat_completion.choices[1].message.content + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) async def test_single_chat_session_video_base64encoded( - client: openai.AsyncOpenAI, model_name: str, video_url: str, - base64_encoded_video: dict[str, str]): - - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": - f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" - } - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + video_url: str, + base64_encoded_video: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -206,13 +194,15 @@ async def test_single_chat_session_video_base64encoded( max_completion_tokens=10, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=6287, total_tokens=6297) + completion_tokens=10, prompt_tokens=6287, total_tokens=6297 + ) message = choice.message message = chat_completion.choices[0].message @@ -236,58 +226,54 @@ async def test_single_chat_session_video_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) async def test_single_chat_session_video_base64encoded_beamsearch( - client: openai.AsyncOpenAI, model_name: str, video_url: str, - base64_encoded_video: dict[str, str]): - - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": - f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" - } - }, - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] + client: openai.AsyncOpenAI, + model_name: str, + video_url: str, + base64_encoded_video: dict[str, str], +): + messages = [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": { + "url": f"data:video/jpeg;base64,{base64_encoded_video[video_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, messages=messages, n=2, max_completion_tokens=10, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[ - 0].message.content != chat_completion.choices[1].message.content + assert ( + chat_completion.choices[0].message.content + != chat_completion.choices[1].message.content + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) -async def test_chat_streaming_video(client: openai.AsyncOpenAI, - model_name: str, video_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, - { - "type": "text", - "text": "What's in this video?" - }, - ], - }] +async def test_chat_streaming_video( + client: openai.AsyncOpenAI, model_name: str, video_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": video_url}}, + {"type": "text", "text": "What's in this video?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -327,27 +313,23 @@ async def test_chat_streaming_video(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize( - "video_urls", - [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))]) -async def test_multi_video_input(client: openai.AsyncOpenAI, model_name: str, - video_urls: list[str]): - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "video_url", - "video_url": { - "url": video_url - } - } for video_url in video_urls), - { - "type": "text", - "text": "What's in this video?" 
- }, - ], - }] + "video_urls", [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))] +) +async def test_multi_video_input( + client: openai.AsyncOpenAI, model_name: str, video_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "video_url", "video_url": {"url": video_url}} + for video_url in video_urls + ), + {"type": "text", "text": "What's in this video?"}, + ], + } + ] if len(video_urls) > MAXIMUM_VIDEOS: with pytest.raises(openai.BadRequestError): # test multi-video input diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index a324e8666605..5a15a352f45c 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -71,26 +71,30 @@ async def client(server): @pytest.fixture(scope="session") def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_asset: - encode_image_base64(local_asset_server.get_image_asset(image_asset)) + image_asset: encode_image_base64( + local_asset_server.get_image_asset(image_asset) + ) for image_asset in TEST_IMAGE_ASSETS } def get_hf_prompt_tokens(model_name, content, image_url): - processor = AutoProcessor.from_pretrained(model_name, - trust_remote_code=True, - num_crops=4) + processor = AutoProcessor.from_pretrained( + model_name, trust_remote_code=True, num_crops=4 + ) placeholder = "<|image_1|>\n" - messages = [{ - "role": "user", - "content": f"{placeholder}{content}", - }] + messages = [ + { + "role": "user", + "content": f"{placeholder}{content}", + } + ] images = [fetch_image(image_url)] prompt = processor.tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True) + messages, tokenize=False, add_generation_prompt=True + ) inputs = processor(prompt, images, return_tensors="pt") return inputs.input_ids.shape[1] @@ -99,25 +103,19 @@ def get_hf_prompt_tokens(model_name, content, image_url): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) -async def test_single_chat_session_image(client: openai.AsyncOpenAI, - model_name: str, image_url: str): +async def test_single_chat_session_image( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): content_text = "What's in this image?" 
- messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": content_text}, + ], + } + ] max_completion_tokens = 10 # test single completion @@ -127,17 +125,18 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, max_completion_tokens=max_completion_tokens, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" - hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, - image_url) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) assert chat_completion.usage == openai.types.CompletionUsage( completion_tokens=max_completion_tokens, prompt_tokens=hf_prompt_tokens, - total_tokens=hf_prompt_tokens + max_completion_tokens) + total_tokens=hf_prompt_tokens + max_completion_tokens, + ) message = choice.message message = chat_completion.choices[0].message @@ -159,55 +158,45 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) -async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI, - model_name: str, - image_url: str): +async def test_error_on_invalid_image_url_type( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): content_text = "What's in this image?" - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": image_url - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": image_url}, + {"type": "text", "text": content_text}, + ], + } + ] # image_url should be a dict {"url": "some url"}, not directly a string with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create(model=model_name, - messages=messages, - max_completion_tokens=10, - temperature=0.0) + _ = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.0, + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) -async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, - model_name: str, - image_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in this image?" 
- }, - ], - }] +async def test_single_chat_session_image_beamsearch( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, @@ -216,10 +205,13 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, max_completion_tokens=10, logprobs=True, top_logprobs=5, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 - assert chat_completion.choices[ - 0].message.content != chat_completion.choices[1].message.content + assert ( + chat_completion.choices[0].message.content + != chat_completion.choices[1].message.content + ) @pytest.mark.asyncio @@ -227,27 +219,27 @@ async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI, @pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) @pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) async def test_single_chat_session_image_base64encoded( - client: openai.AsyncOpenAI, model_name: str, raw_image_url: str, - image_url: str, base64_encoded_image: dict[str, str]): - + client: openai.AsyncOpenAI, + model_name: str, + raw_image_url: str, + image_url: str, + base64_encoded_image: dict[str, str], +): content_text = "What's in this image?" - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": - f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" - } - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": content_text}, + ], + } + ] max_completion_tokens = 10 # test single completion @@ -257,17 +249,18 @@ async def test_single_chat_session_image_base64encoded( max_completion_tokens=max_completion_tokens, logprobs=True, temperature=0.0, - top_logprobs=5) + top_logprobs=5, + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] assert choice.finish_reason == "length" - hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, - image_url) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) assert chat_completion.usage == openai.types.CompletionUsage( completion_tokens=max_completion_tokens, prompt_tokens=hf_prompt_tokens, - total_tokens=hf_prompt_tokens + max_completion_tokens) + total_tokens=hf_prompt_tokens + max_completion_tokens, + ) message = choice.message message = chat_completion.choices[0].message @@ -291,36 +284,37 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_ASSETS)))) async def test_single_chat_session_image_base64encoded_beamsearch( - client: openai.AsyncOpenAI, model_name: str, image_idx: int, - base64_encoded_image: dict[str, str]): + client: openai.AsyncOpenAI, + model_name: str, + image_idx: int, + base64_encoded_image: dict[str, str], +): # NOTE: This test also validates that we pass MM data through beam search raw_image_url = TEST_IMAGE_ASSETS[image_idx] expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx] - messages = [{ - "role": - "user", - "content": [ - { - "type": 
"image_url", - "image_url": { - "url": - f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" - } - }, - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}" # noqa: E501 + }, + }, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] chat_completion = await client.chat.completions.create( model=model_name, messages=messages, n=2, max_completion_tokens=10, temperature=0.0, - extra_body=dict(use_beam_search=True)) + extra_body=dict(use_beam_search=True), + ) assert len(chat_completion.choices) == 2 for actual, expected_str in zip(chat_completion.choices, expected_res): assert actual.message.content == expected_str @@ -329,24 +323,18 @@ async def test_single_chat_session_image_base64encoded_beamsearch( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) -async def test_chat_streaming_image(client: openai.AsyncOpenAI, - model_name: str, image_url: str): - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] +async def test_chat_streaming_image( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] # test single completion chat_completion = await client.chat.completions.create( @@ -388,26 +376,23 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, @pytest.mark.parametrize( "image_urls", [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], - indirect=True) -async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, - image_urls: list[str]): - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "image_url", - "image_url": { - "url": image_url - } - } for image_url in image_urls), - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] + indirect=True, +) +async def test_multi_image_input( + client: openai.AsyncOpenAI, model_name: str, image_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + {"type": "image_url", "image_url": {"url": image_url}} + for image_url in image_urls + ), + {"type": "text", "text": "What's in this image?"}, + ], + } + ] if len(image_urls) > MAXIMUM_IMAGES: with pytest.raises(openai.BadRequestError): # test multi-image input @@ -443,7 +428,8 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, @pytest.mark.parametrize( "image_urls", [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], - indirect=True) + indirect=True, +) async def test_completions_with_image( client: openai.AsyncOpenAI, model_name: str, @@ -452,13 +438,9 @@ async def test_completions_with_image( for image_url in image_urls: chat_completion = await client.chat.completions.create( messages=[ + {"role": "system", "content": "You are a helpful assistant."}, { - "role": "system", - "content": "You are a helpful assistant." 
- }, - { - "role": - "user", + "role": "user", "content": [ { "type": "text", @@ -468,7 +450,7 @@ async def test_completions_with_image( "type": "image_url", "image_url": { "url": image_url, - } + }, }, ], }, @@ -485,7 +467,8 @@ async def test_completions_with_image( @pytest.mark.parametrize( "image_urls", [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], - indirect=True) + indirect=True, +) async def test_completions_with_image_with_uuid( client: openai.AsyncOpenAI, model_name: str, @@ -494,13 +477,9 @@ async def test_completions_with_image_with_uuid( for image_url in image_urls: chat_completion = await client.chat.completions.create( messages=[ + {"role": "system", "content": "You are a helpful assistant."}, { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": - "user", + "role": "user", "content": [ { "type": "text", @@ -511,7 +490,7 @@ async def test_completions_with_image_with_uuid( "image_url": { "url": image_url, }, - "uuid": image_url + "uuid": image_url, }, ], }, @@ -525,34 +504,25 @@ async def test_completions_with_image_with_uuid( # Second request, with empty image but the same uuid. chat_completion_with_empty_image = await client.chat.completions.create( messages=[ + {"role": "system", "content": "You are a helpful assistant."}, { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": - "user", + "role": "user", "content": [ { "type": "text", "text": "Describe this image.", }, - { - "type": "image_url", - "image_url": {}, - "uuid": image_url - }, + {"type": "image_url", "image_url": {}, "uuid": image_url}, ], }, ], model=model_name, ) - assert chat_completion_with_empty_image.choices[ - 0].message.content is not None + assert chat_completion_with_empty_image.choices[0].message.content is not None assert isinstance( - chat_completion_with_empty_image.choices[0].message.content, str) - assert len( - chat_completion_with_empty_image.choices[0].message.content) > 0 + chat_completion_with_empty_image.choices[0].message.content, str + ) + assert len(chat_completion_with_empty_image.choices[0].message.content) > 0 @pytest.mark.asyncio @@ -564,13 +534,9 @@ async def test_completions_with_empty_image_with_uuid_without_cache_hit( with pytest.raises(openai.BadRequestError): _ = await client.chat.completions.create( messages=[ + {"role": "system", "content": "You are a helpful assistant."}, { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": - "user", + "role": "user", "content": [ { "type": "text", @@ -579,7 +545,7 @@ async def test_completions_with_empty_image_with_uuid_without_cache_hit( { "type": "image_url", "image_url": {}, - "uuid": "uuid_not_previously_seen" + "uuid": "uuid_not_previously_seen", }, ], }, @@ -593,7 +559,8 @@ async def test_completions_with_empty_image_with_uuid_without_cache_hit( @pytest.mark.parametrize( "image_urls", [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], - indirect=True) + indirect=True, +) async def test_completions_with_image_with_incorrect_uuid_format( client: openai.AsyncOpenAI, model_name: str, @@ -602,13 +569,9 @@ async def test_completions_with_image_with_incorrect_uuid_format( for image_url in image_urls: chat_completion = await client.chat.completions.create( messages=[ + {"role": "system", "content": "You are a helpful assistant."}, { - "role": "system", - "content": "You are a helpful assistant." 
- }, - { - "role": - "user", + "role": "user", "content": [ { "type": "text", diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py index 1da06be2eba9..38008dafe32b 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -6,8 +6,7 @@ import pytest from vllm.entrypoints.openai.protocol import ChatCompletionRequest -from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import ( - Hermes2ProToolParser) +from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.transformers_utils.tokenizer import AnyTokenizer from ....utils import RemoteOpenAIServer @@ -27,61 +26,64 @@ f"{LORA_MODEL}", ] -TOOLS = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": - "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, }, + "required": ["location"], }, - "required": ["location"], }, - }, -}] - -PRODUCT_TOOLS = [{ - "type": "function", - "function": { - "name": "get_product_info", - "description": "Get detailed information of a product based on its " - "product ID.", - "parameters": { - "type": "object", - "properties": { - "inserted": { - "type": "boolean", - "description": "inserted.", - }, - "product_id": { - "type": "integer", - "description": "The product ID of the product.", + } +] + +PRODUCT_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_product_info", + "description": "Get detailed information of a product based on its " + "product ID.", + "parameters": { + "type": "object", + "properties": { + "inserted": { + "type": "boolean", + "description": "inserted.", + }, + "product_id": { + "type": "integer", + "description": "The product ID of the product.", + }, }, + "required": ["product_id", "inserted"], }, - "required": ["product_id", "inserted"], }, - }, -}] + } +] MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}] -PRODUCT_MESSAGES = [{ - "role": - "user", - "content": - "Hi! Do you have any detailed information about the product id " - "7355608 and inserted true?", -}] +PRODUCT_MESSAGES = [ + { + "role": "user", + "content": "Hi! 
Do you have any detailed information about the product id " + "7355608 and inserted true?", + } +] @pytest.mark.asyncio @@ -150,7 +152,8 @@ async def test_streaming_tool_call(): tool_call_chunks[index]["name"] += tool_chunk.function.name if tool_chunk.function.arguments: tool_call_chunks[index]["arguments"] += ( - tool_chunk.function.arguments) + tool_chunk.function.arguments + ) assert len(tool_call_chunks) == 1 reconstructed_tool_call = tool_call_chunks[0] @@ -240,7 +243,8 @@ async def test_streaming_product_tool_call(): tool_call_chunks[index]["name"] += tool_chunk.function.name if tool_chunk.function.arguments: tool_call_chunks[index]["arguments"] += ( - tool_chunk.function.arguments) + tool_chunk.function.arguments + ) assert len(tool_call_chunks) == 1 reconstructed_tool_call = tool_call_chunks[0] @@ -291,9 +295,7 @@ def test_hermes_parser_streaming_just_forward_text( hermes_parser: Hermes2ProToolParser, any_chat_request: ChatCompletionRequest, ) -> None: - text = ( - """This is some prior text that has nothing to do with tool calling.""" - ) + text = """This is some prior text that has nothing to do with tool calling.""" tokens = qwen_tokenizer.encode(text) previous_text = "" delta_messages = [] @@ -348,8 +350,9 @@ def test_hermes_parser_streaming_failure_case_bug_19056( delta_messages.append(delta) assert delta_messages[0].tool_calls[0].function.name == "final_answer" - tool_call_args = "".join(delta.tool_calls[0].function.arguments or "" - for delta in delta_messages) + tool_call_args = "".join( + delta.tool_calls[0].function.arguments or "" for delta in delta_messages + ) assert tool_call_args == '{"trigger": true}' @@ -383,13 +386,13 @@ def test_hermes_parser_streaming( if delta is not None: delta_messages.append(delta) print(delta_messages) - assert (delta_messages[0].tool_calls[0].function.name == - "get_current_temperature") - tool_call_args = "".join(delta.tool_calls[0].function.arguments or "" - for delta in delta_messages) + assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature" + tool_call_args = "".join( + delta.tool_calls[0].function.arguments or "" for delta in delta_messages + ) assert tool_call_args == ( - '{"location":"San Francisco, California, United States", ' - '"unit": "celsius"}') + '{"location":"San Francisco, California, United States", "unit": "celsius"}' + ) def test_hermes_parser_non_streaming_no_tool_call( diff --git a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py index bd8e06513e13..bdd5344652c4 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py @@ -8,15 +8,18 @@ import pytest from tests.entrypoints.openai.tool_parsers.utils import ( - run_tool_extraction, run_tool_extraction_streaming) + run_tool_extraction, + run_tool_extraction_streaming, +) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager def make_tool_call(name, arguments): - return ToolCall(type="function", - function=FunctionCall(name=name, - arguments=json.dumps(arguments))) + return ToolCall( + type="function", + function=FunctionCall(name=name, arguments=json.dumps(arguments)), + ) # TODO: add reason prefix and suffix. 
@@ -29,70 +32,68 @@ def make_tool_call(name, arguments): ("How can I help you today?", [], "How can I help you today?"), # Single tool call, no content ( - "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>", #noqa: E501 + '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>', # noqa: E501 [ - make_tool_call("get_weather", { - "city": "San Francisco", - "metric": "celsius" - }) + make_tool_call( + "get_weather", {"city": "San Francisco", "metric": "celsius"} + ) ], - None), + None, + ), # Multiple tool calls ( - "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>", #noqa: E501 + '<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>', # noqa: E501 [ - make_tool_call("get_weather", { - "city": "San Francisco", - "metric": "celsius" - }), make_tool_call( - "register_user", { + "get_weather", {"city": "San Francisco", "metric": "celsius"} + ), + make_tool_call( + "register_user", + { "name": "John Doe", "age": 37, - "address": { - "city": "San Francisco", - "state": "CA" - }, + "address": {"city": "San Francisco", "state": "CA"}, "role": None, "passed_test": True, - "aliases": ["John", "Johnny"] - }) + "aliases": ["John", "Johnny"], + }, + ), ], - None), + None, + ), # Content before tool call ( - "I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>", #noqa: E501 + 'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>', # noqa: E501 [make_tool_call("get_weather", {"city": "Boston"})], - "I will call the tool now. "), + "I will call the tool now. 
", + ), # Content after tool call (should be stripped) ( - "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!", #noqa: E501 + '<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!', # noqa: E501 [make_tool_call("get_weather", {"city": "Seattle"})], - None), + None, + ), ( - "<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>", + '<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>', [ make_tool_call( - "complex_tool", - {"level1": { - "level2": { - "level3": { - "value": 123 - } - } - }}) + "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}} + ) ], None, ), - ]) -def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, - expected_content): + ], +) +def test_hunyuan_a13b_tool_parser_extract( + model_output, expected_tool_calls, expected_content +): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "hunyuan_a13b")(mock_tokenizer) - content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=False) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")( + mock_tokenizer + ) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=False + ) # align the random id. for idx in range(len(tool_calls)): @@ -102,49 +103,74 @@ def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, # Streaming test: simulate incremental output -@pytest.mark.parametrize("model_deltas,expected_tool_calls", [ - ([ - "<tool_calls>[{\"name\": \"get_weather\", ", - "\"arguments\": {\"city\": \"San Francisco\", ", - "\"metric\": \"celsius\"}}]", "</tool_calls>" - ], [ - make_tool_call("get_weather", { - "city": "San Francisco", - "metric": "celsius" - }) - ]), - ([ - "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":", - " {\"city\": \"Boston\"}", "}]", "</tool_calls>" - ], [make_tool_call("get_weather", {"city": "Boston"})]), - ([ - "", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":", - " {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>" - ], [make_tool_call("get_weather", {"city": "Boston"})]), - pytest.param([ - "<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ", - " {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}", - "]</tool_calls>" - ], [ - make_tool_call("complex_tool", - {"level1": { - "level2": { - "level3": { - "value": 123 - } - } - }}) +@pytest.mark.parametrize( + "model_deltas,expected_tool_calls", + [ + ( + [ + '<tool_calls>[{"name": "get_weather", ', + '"arguments": {"city": "San Francisco", ', + '"metric": "celsius"}}]', + "</tool_calls>", + ], + [ + make_tool_call( + "get_weather", {"city": "San Francisco", "metric": "celsius"} + ) + ], + ), + ( + [ + '<tool_calls>[{"name":', + ' "get_weather",', + ' "arguments":', + ' {"city": "Boston"}', + "}]", + "</tool_calls>", + ], + [make_tool_call("get_weather", {"city": "Boston"})], + ), + ( + [ + "", + '<tool_calls>[{"name":', + ' "get_weather",', + ' "arguments":', + ' {"city": "Boston"}', + "}]", + "</tool_calls>", + "\n</answer>", + ], + [make_tool_call("get_weather", {"city": "Boston"})], + ), + pytest.param( + [ + '<tool_calls>[{"name": "complex_tool",', + ' "arguments": ', + ' {"level1": {"level2": ', + '{"level3": {"value": 123}}}}}', + "]</tool_calls>", + ], + [ + 
make_tool_call( + "complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}} + ) + ], + marks=pytest.mark.xfail( + reason="stream parsing does not support nested json yet." + ), + ), ], - marks=pytest.mark.xfail( - reason="stream parsing not support nested json yet.")), -]) +) def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "hunyuan_a13b")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")( + mock_tokenizer + ) reconstructor = run_tool_extraction_streaming( - tool_parser, model_deltas, assert_one_tool_per_delta=False) + tool_parser, model_deltas, assert_one_tool_per_delta=False + ) # align the random id. for idx in range(len(reconstructor.tool_calls)): diff --git a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py index 09726c7e3e5b..c7a8ef83cf71 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py @@ -5,8 +5,7 @@ from transformers import AutoTokenizer from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation -from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import ( - Llama3JsonToolParser) +from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser @pytest.fixture @@ -18,8 +17,10 @@ def parser(): def test_extract_tool_calls_simple(parser): # Test with a simple tool call - model_output = ('Here is the result: {"name": "getOpenIncidentsTool", ' - '"parameters": {}} Would you like to know more?') + model_output = ( + 'Here is the result: {"name": "getOpenIncidentsTool", ' + '"parameters": {}} Would you like to know more?' 
+ ) result = parser.extract_tool_calls(model_output, None) assert isinstance(result, ExtractedToolCallInformation) @@ -34,8 +35,8 @@ def test_extract_tool_calls_simple(parser): def test_extract_tool_calls_with_arguments(parser): # Test with a tool call that has arguments model_output = ( - '{"name": "searchTool", "parameters": {"query": "test query", ' - '"limit": 10}}') + '{"name": "searchTool", "parameters": {"query": "test query", "limit": 10}}' + ) result = parser.extract_tool_calls(model_output, None) assert result.tools_called is True @@ -81,7 +82,8 @@ def test_extract_tool_calls_multiple_json(parser): model_output = ( '{"name": "searchTool", "parameters": {"query": "test1"}}; ' '{"name": "getOpenIncidentsTool", "parameters": {}}; ' - '{"name": "searchTool", "parameters": {"query": "test2"}}') + '{"name": "searchTool", "parameters": {"query": "test2"}}' + ) result = parser.extract_tool_calls(model_output, None) assert result.tools_called is True @@ -105,7 +107,8 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser): model_output = ( '{"name": "searchTool", "parameters": {"query": "test1"}} ; ' '{"name": "getOpenIncidentsTool", "parameters": {}} ; ' - '{"name": "searchTool", "parameters": {"query": "test2"}}') + '{"name": "searchTool", "parameters": {"query": "test2"}}' + ) result = parser.extract_tool_calls(model_output, None) assert result.tools_called is True @@ -118,11 +121,12 @@ def test_extract_tool_calls_multiple_json_with_whitespace(parser): def test_extract_tool_calls_multiple_json_with_surrounding_text(parser): # Test with multiple JSONs and surrounding text model_output = ( - 'Here are the results: ' + "Here are the results: " '{"name": "searchTool", "parameters": {"query": "test1"}}; ' '{"name": "getOpenIncidentsTool", "parameters": {}}; ' '{"name": "searchTool", "parameters": {"query": "test2"}} ' - 'Would you like to know more?') + "Would you like to know more?" 
+ ) result = parser.extract_tool_calls(model_output, None) assert result.tools_called is True diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py index 8c86b4889e15..94277980f229 100644 --- a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -6,7 +6,9 @@ import pytest from tests.entrypoints.openai.tool_parsers.utils import ( - run_tool_extraction, run_tool_extraction_streaming) + run_tool_extraction, + run_tool_extraction_streaming, +) from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager @@ -16,12 +18,14 @@ name="get_weather", arguments='{"city": "LA", "metric": "C"}', ) -MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', " - "age=9, " - "address={'city': 'LA', 'state': 'CA'}, " - "role=None, " - "passed_test=True, " - "aliases=['John', 'Johnny'])]") +MORE_TYPES_FUNCTION_OUTPUT = ( + "[register_user(name='Doe', " + "age=9, " + "address={'city': 'LA', 'state': 'CA'}, " + "role=None, " + "passed_test=True, " + "aliases=['John', 'Johnny'])]" +) MORE_TYPES_FUNCTION_CALL = FunctionCall( name="register_user", arguments='{"name": "Doe", ' @@ -34,7 +38,7 @@ PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]" PARAMETERLESS_FUNCTION_CALL = FunctionCall( name="get_weather", - arguments='{}', + arguments="{}", ) EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]" EMPTY_DICT_FUNCTION_CALL = FunctionCall( @@ -47,25 +51,28 @@ arguments='{"steps": []}', ) ESCAPED_STRING_FUNCTION_OUTPUT = ( - r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]") + r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]" +) ESCAPED_STRING_FUNCTION_CALL = FunctionCall( name="get_weather", arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', ) PYTHON_TAG_FUNCTION_OUTPUT = ( - "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>") + "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>" +) @pytest.mark.parametrize("streaming", [True, False]) def test_no_tool_call(streaming: bool): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) model_output = "How can I help you today?" 
- content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert content == model_output assert len(tool_calls) == 0 @@ -75,98 +82,139 @@ def test_no_tool_call(streaming: bool): test_str += "[get_weather(city='LA', metric='C')," test_str += "register_user(name='Doe', age=9)]" TEST_CASES = [ - pytest.param(True, - ESCAPED_STRING_FUNCTION_OUTPUT, - [ESCAPED_STRING_FUNCTION_CALL], - id="simple_streaming"), - pytest.param(False, - SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], - id="simple_nonstreaming"), - pytest.param(True, - MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], - id="more_types_streaming"), - pytest.param(False, - MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], - id="more_types_nonstreaming"), - pytest.param(True, - PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_streaming"), - pytest.param(False, - PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_nonstreaming"), - pytest.param(True, - EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_streaming"), - pytest.param(False, - EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_nonstreaming"), - pytest.param(True, - EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_streaming"), - pytest.param(False, - EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_nonstreaming"), - pytest.param(True, - ESCAPED_STRING_FUNCTION_OUTPUT, - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_streaming"), - pytest.param(False, - ESCAPED_STRING_FUNCTION_OUTPUT, - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_nonstreaming"), + pytest.param( + True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="simple_streaming", + ), + pytest.param( + False, SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], id="simple_nonstreaming" + ), + pytest.param( + True, + MORE_TYPES_FUNCTION_OUTPUT, + [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming", + ), + pytest.param( + False, + MORE_TYPES_FUNCTION_OUTPUT, + [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming", + ), + pytest.param( + True, + PARAMETERLESS_FUNCTION_OUTPUT, + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming", + ), + pytest.param( + False, + PARAMETERLESS_FUNCTION_OUTPUT, + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming", + ), + pytest.param( + True, + EMPTY_DICT_FUNCTION_OUTPUT, + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming", + ), + pytest.param( + False, + EMPTY_DICT_FUNCTION_OUTPUT, + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming", + ), + pytest.param( + True, + EMPTY_LIST_FUNCTION_OUTPUT, + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming", + ), + pytest.param( + False, + EMPTY_LIST_FUNCTION_OUTPUT, + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming", + ), + pytest.param( + True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming", + ), + pytest.param( + False, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_nonstreaming", + ), pytest.param( True, "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", [ SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), ], - 
id="parallel_calls_streaming"), + id="parallel_calls_streaming", + ), pytest.param( False, "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", [ SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), + ], + id="parallel_calls_nonstreaming", + ), + pytest.param( + True, + PYTHON_TAG_FUNCTION_OUTPUT, + [SIMPLE_FUNCTION_CALL], + id="python_tag_streaming", + ), + pytest.param( + False, + PYTHON_TAG_FUNCTION_OUTPUT, + [SIMPLE_FUNCTION_CALL], + id="python_tag_nonstreaming", + ), + pytest.param( + True, + test_str, + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), + ], + id="parallel_calls_streaming", + ), + pytest.param( + False, + "<|python_start|>[get_weather(city='LA', metric='C'), " + + "register_user(name='Doe', age=9)]", + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'), ], - id="parallel_calls_nonstreaming"), - pytest.param(True, - PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], - id="python_tag_streaming"), - pytest.param(False, - PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], - id="python_tag_nonstreaming"), - pytest.param(True, - test_str, [ - SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') - ], - id="parallel_calls_streaming"), - pytest.param(False, - "<|python_start|>[get_weather(city='LA', metric='C'), " + - "register_user(name='Doe', age=9)]", [ - SIMPLE_FUNCTION_CALL, - FunctionCall(name="register_user", - arguments='{"name": "Doe", "age": 9}') - ], - id="parallel_calls_nonstreaming"), + id="parallel_calls_nonstreaming", + ), ] -@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", - TEST_CASES) -def test_tool_call(streaming: bool, model_output: str, - expected_tool_calls: list[FunctionCall]): +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES) +def test_tool_call( + streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall] +): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) - content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert len(tool_calls) == len(expected_tool_calls) for actual, expected in zip(tool_calls, expected_tool_calls): @@ -176,8 +224,9 @@ def test_tool_call(streaming: bool, model_output: str, def test_streaming_tool_call_with_large_steps(): mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) model_output_deltas = [ "<|python_start|>[get_weather(city='LA', metric='C'), " "get_weather(), " @@ -185,7 +234,8 @@ def test_streaming_tool_call_with_large_steps(): ] reconstructor = run_tool_extraction_streaming( - tool_parser, model_output_deltas, assert_one_tool_per_delta=False) + tool_parser, model_output_deltas, assert_one_tool_per_delta=False + ) assert reconstructor.other_content == "" assert len(reconstructor.tool_calls) == 3 @@ -198,8 +248,9 @@ def test_streaming_tool_call_with_large_steps(): 
def test_regex_timeout_handling(streaming: bool): """test regex timeout is handled gracefully""" mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 @@ -207,10 +258,10 @@ def test_regex_timeout_handling(streaming: bool): mock_regex = MagicMock() mock_regex.match.side_effect = TimeoutError("Regex timeout") - with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex): - content, tool_calls = run_tool_extraction(tool_parser, - fake_problematic_input, - streaming=streaming) + with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex): + content, tool_calls = run_tool_extraction( + tool_parser, fake_problematic_input, streaming=streaming + ) # should treat as regular text when regex times out assert content == fake_problematic_input diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py index d83137472598..ccd6abbac4c9 100644 --- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py @@ -6,7 +6,9 @@ import pytest from tests.entrypoints.openai.tool_parsers.utils import ( - run_tool_extraction, run_tool_extraction_streaming) + run_tool_extraction, + run_tool_extraction_streaming, +) from vllm.entrypoints.openai.protocol import FunctionCall from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager @@ -22,7 +24,8 @@ "address={'city': 'San Francisco', 'state': 'CA'}, " "role=None, " "passed_test=True, " - "aliases=['John', 'Johnny'])") + "aliases=['John', 'Johnny'])" +) MORE_TYPES_FUNCTION_CALL = FunctionCall( name="register_user", arguments='{"name": "John Doe", ' @@ -35,7 +38,7 @@ PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()" PARAMETERLESS_FUNCTION_CALL = FunctionCall( name="get_weather", - arguments='{}', + arguments="{}", ) EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})" EMPTY_DICT_FUNCTION_CALL = FunctionCall( @@ -48,7 +51,8 @@ arguments='{"steps": []}', ) ESCAPED_STRING_FUNCTION_OUTPUT = ( - r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')") + r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')" +) ESCAPED_STRING_FUNCTION_CALL = FunctionCall( name="get_weather", arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', @@ -59,80 +63,118 @@ def test_no_tool_call(streaming: bool): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + mock_tokenizer + ) model_output = "How can I help you today?" 
- content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert content == model_output assert len(tool_calls) == 0 TEST_CASES = [ - pytest.param(True, - f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL], - id="simple_streaming"), - pytest.param(False, - f"[{SIMPLE_FUNCTION_OUTPUT}]", [SIMPLE_FUNCTION_CALL], - id="simple_nonstreaming"), - pytest.param(True, - f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL], - id="more_types_streaming"), - pytest.param(False, - f"[{MORE_TYPES_FUNCTION_OUTPUT}]", [MORE_TYPES_FUNCTION_CALL], - id="more_types_nonstreaming"), - pytest.param(True, - f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", - [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_streaming"), - pytest.param(False, - f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", - [PARAMETERLESS_FUNCTION_CALL], - id="parameterless_nonstreaming"), - pytest.param(True, - f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_streaming"), - pytest.param(False, - f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", [EMPTY_DICT_FUNCTION_CALL], - id="empty_dict_nonstreaming"), - pytest.param(True, - f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_streaming"), - pytest.param(False, - f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", [EMPTY_LIST_FUNCTION_CALL], - id="empty_list_nonstreaming"), - pytest.param(True, - f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_streaming"), - pytest.param(False, - f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", - [ESCAPED_STRING_FUNCTION_CALL], - id="escaped_string_nonstreaming"), - pytest.param(True, - f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", - [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], - id="parallel_calls_streaming"), - pytest.param(False, - f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", - [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], - id="parallel_calls_nonstreaming"), + pytest.param( + True, + f"[{SIMPLE_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL], + id="simple_streaming", + ), + pytest.param( + False, + f"[{SIMPLE_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL], + id="simple_nonstreaming", + ), + pytest.param( + True, + f"[{MORE_TYPES_FUNCTION_OUTPUT}]", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming", + ), + pytest.param( + False, + f"[{MORE_TYPES_FUNCTION_OUTPUT}]", + [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming", + ), + pytest.param( + True, + f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming", + ), + pytest.param( + False, + f"[{PARAMETERLESS_FUNCTION_OUTPUT}]", + [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming", + ), + pytest.param( + True, + f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming", + ), + pytest.param( + False, + f"[{EMPTY_DICT_FUNCTION_OUTPUT}]", + [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming", + ), + pytest.param( + True, + f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming", + ), + pytest.param( + False, + f"[{EMPTY_LIST_FUNCTION_OUTPUT}]", + [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming", + ), + pytest.param( + True, + f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming", + ), + pytest.param( + False, + f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]", + [ESCAPED_STRING_FUNCTION_CALL], 
+ id="escaped_string_nonstreaming", + ), + pytest.param( + True, + f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_streaming", + ), + pytest.param( + False, + f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", + [SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], + id="parallel_calls_nonstreaming", + ), ] -@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", - TEST_CASES) -def test_tool_call(streaming: bool, model_output: str, - expected_tool_calls: list[FunctionCall]): +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES) +def test_tool_call( + streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall] +): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + mock_tokenizer + ) - content, tool_calls = run_tool_extraction(tool_parser, - model_output, - streaming=streaming) + content, tool_calls = run_tool_extraction( + tool_parser, model_output, streaming=streaming + ) assert content is None assert len(tool_calls) == len(expected_tool_calls) @@ -144,7 +186,8 @@ def test_tool_call(streaming: bool, model_output: str, def test_streaming_tool_call_with_large_steps(): mock_tokenizer = MagicMock() tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")( - mock_tokenizer) + mock_tokenizer + ) model_output_deltas = [ "[get_weather(city='San", " Francisco', metric='celsius'), " @@ -153,7 +196,8 @@ def test_streaming_tool_call_with_large_steps(): ] reconstructor = run_tool_extraction_streaming( - tool_parser, model_output_deltas, assert_one_tool_per_delta=False) + tool_parser, model_output_deltas, assert_one_tool_per_delta=False + ) assert reconstructor.other_content == "" assert len(reconstructor.tool_calls) == 3 @@ -166,8 +210,9 @@ def test_streaming_tool_call_with_large_steps(): def test_regex_timeout_handling(streaming: bool): """test regex timeout is handled gracefully""" mock_tokenizer = MagicMock() - tool_parser: ToolParser = ToolParserManager.get_tool_parser( - "llama4_pythonic")(mock_tokenizer) + tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")( + mock_tokenizer + ) fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2 @@ -175,10 +220,10 @@ def test_regex_timeout_handling(streaming: bool): mock_regex = MagicMock() mock_regex.match.side_effect = TimeoutError("Regex timeout") - with patch.object(tool_parser, 'TOOL_CALL_REGEX', mock_regex): - content, tool_calls = run_tool_extraction(tool_parser, - fake_problematic_input, - streaming=streaming) + with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex): + content, tool_calls = run_tool_extraction( + tool_parser, fake_problematic_input, streaming=streaming + ) # should treat as regular text when regex times out assert content == fake_problematic_input diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index e1b41f45f554..cfa4d3584e70 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -4,15 +4,17 @@ from collections.abc import Iterable from typing import Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, + 
ToolCall, +) from vllm.entrypoints.openai.tool_parsers import ToolParser class StreamingToolReconstructor: - def __init__(self, assert_one_tool_per_delta: bool = True): self.tool_calls: list[ToolCall] = [] self.other_content: str = "" @@ -23,49 +25,60 @@ def append_delta(self, delta: DeltaMessage): self.other_content += delta.content else: assert delta.tool_calls, ( - "Streaming results should have either content or tool calls " - "(or both)") + "Streaming results should have either content or tool calls (or both)" + ) if self._assert_one_tool_per_delta: # Note: This isn't strictly required by the API and may not be # possible to adhere to depending on the token space and number of # tokens per streamed response from the model, but it is required # by tool_use tests, so we enforce it here by default also. assert len(delta.tool_calls) < 2, ( - "Streaming should include only one tool call per update.") + "Streaming should include only one tool call per update." + ) for call_delta in delta.tool_calls: assert call_delta.type is None or call_delta.type == "function", ( "Streaming tool calls should only emit function calls. Got " - f"{call_delta.type}") - current_tool_call = self.tool_calls[ - call_delta.index] if call_delta.index < len( - self.tool_calls) else None + f"{call_delta.type}" + ) + current_tool_call = ( + self.tool_calls[call_delta.index] + if call_delta.index < len(self.tool_calls) + else None + ) if current_tool_call: - assert (not call_delta.function.name), ( + assert not call_delta.function.name, ( "Streaming tool calls should emit the full function name " - f"exactly once. Got {call_delta.function.name}") - assert (not call_delta.id), ( + f"exactly once. Got {call_delta.function.name}" + ) + assert not call_delta.id, ( "Streaming tool calls must emit function id only once. Got " - f"{call_delta.id}") - assert (call_delta.index == len(self.tool_calls) - 1), ( + f"{call_delta.id}" + ) + assert call_delta.index == len(self.tool_calls) - 1, ( f"Incorrect index for tool delta. Got {call_delta.index}, " - f"expected {len(self.tool_calls) - 1}") - current_tool_call.function.arguments += ( - call_delta.function.arguments) + f"expected {len(self.tool_calls) - 1}" + ) + current_tool_call.function.arguments += call_delta.function.arguments else: assert call_delta.id is not None, ( - "Streaming tool calls must have an id on first appearance") + "Streaming tool calls must have an id on first appearance" + ) assert call_delta.function.name is not None, ( - "Streaming tool calls must have a function name on first " - "appearance") + "Streaming tool calls must have a function name on first appearance" + ) assert call_delta.index == len(self.tool_calls), ( f"Incorrect index for tool delta. 
Got {call_delta.index}, " - f"expected {len(self.tool_calls)}") + f"expected {len(self.tool_calls)}" + ) self.tool_calls.append( - ToolCall(id=call_delta.id, - function=FunctionCall( - name=call_delta.function.name, - arguments=call_delta.function.arguments - or ""))) + ToolCall( + id=call_delta.id, + function=FunctionCall( + name=call_delta.function.name, + arguments=call_delta.function.arguments or "", + ), + ) + ) def run_tool_extraction( @@ -80,11 +93,11 @@ def run_tool_extraction( tool_parser, model_output, request, - assert_one_tool_per_delta=assert_one_tool_per_delta) + assert_one_tool_per_delta=assert_one_tool_per_delta, + ) return reconstructor.other_content or None, reconstructor.tool_calls else: - extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, - request) + extracted = run_tool_extraction_nonstreaming(tool_parser, model_output, request) assert extracted.tools_called == bool(extracted.tool_calls) return extracted.content, extracted.tool_calls @@ -92,7 +105,7 @@ def run_tool_extraction( def run_tool_extraction_nonstreaming( tool_parser: ToolParser, model_output: str, - request: Union[ChatCompletionRequest, None] = None + request: Union[ChatCompletionRequest, None] = None, ) -> ExtractedToolCallInformation: request = request or ChatCompletionRequest(messages=[], model="test-model") return tool_parser.extract_tool_calls(model_output, request) @@ -106,7 +119,8 @@ def run_tool_extraction_streaming( ) -> StreamingToolReconstructor: request = request or ChatCompletionRequest(messages=[], model="test-model") reconstructor = StreamingToolReconstructor( - assert_one_tool_per_delta=assert_one_tool_per_delta) + assert_one_tool_per_delta=assert_one_tool_per_delta + ) previous_text = "" previous_tokens: list[int] = [] for delta in model_deltas: @@ -118,8 +132,14 @@ def run_tool_extraction_streaming( current_text = previous_text + delta current_tokens = previous_tokens + token_delta delta_message = tool_parser.extract_tool_calls_streaming( - previous_text, current_text, delta, previous_tokens, - current_tokens, token_delta, request) + previous_text, + current_text, + delta, + previous_tokens, + current_tokens, + token_delta, + request, + ) if delta_message is not None: reconstructor.append_delta(delta_message) previous_text = current_text diff --git a/tests/entrypoints/pooling/correctness/test_mteb_embed.py b/tests/entrypoints/pooling/correctness/test_mteb_embed.py index 12a4875bdacf..7f16638e51e2 100644 --- a/tests/entrypoints/pooling/correctness/test_mteb_embed.py +++ b/tests/entrypoints/pooling/correctness/test_mteb_embed.py @@ -5,8 +5,11 @@ import pytest from tests.models.language.pooling_mteb_test.mteb_utils import ( - MTEB_EMBED_TASKS, MTEB_EMBED_TOL, OpenAIClientMtebEncoder, - run_mteb_embed_task) + MTEB_EMBED_TASKS, + MTEB_EMBED_TOL, + OpenAIClientMtebEncoder, + run_mteb_embed_task, +) from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" @@ -17,10 +20,7 @@ @pytest.fixture(scope="module") def server(): - args = [ - "--runner", "pooling", "--enforce-eager", - "--disable-uvicorn-access-log" - ] + args = ["--runner", "pooling", "--enforce-eager", "--disable-uvicorn-access-log"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server diff --git a/tests/entrypoints/pooling/correctness/test_mteb_score.py b/tests/entrypoints/pooling/correctness/test_mteb_score.py index 7c059d16b386..1afe68b189db 100644 --- a/tests/entrypoints/pooling/correctness/test_mteb_score.py +++ 
b/tests/entrypoints/pooling/correctness/test_mteb_score.py @@ -5,8 +5,13 @@ import pytest from tests.models.language.pooling_mteb_test.mteb_utils import ( - MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL, - RerankClientMtebEncoder, ScoreClientMtebEncoder, run_mteb_rerank) + MTEB_RERANK_LANGS, + MTEB_RERANK_TASKS, + MTEB_RERANK_TOL, + RerankClientMtebEncoder, + ScoreClientMtebEncoder, + run_mteb_rerank, +) from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" @@ -17,10 +22,7 @@ @pytest.fixture(scope="module") def server(): - args = [ - "--runner", "pooling", "--enforce-eager", - "--disable-uvicorn-access-log" - ] + args = ["--runner", "pooling", "--enforce-eager", "--disable-uvicorn-access-log"] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -29,8 +31,7 @@ def server(): def test_mteb_score(server): url = server.url_for("score") encoder = ScoreClientMtebEncoder(MODEL_NAME, url) - vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, - MTEB_RERANK_LANGS) + vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) @@ -44,8 +45,7 @@ def test_mteb_score(server): def test_mteb_rerank(server): url = server.url_for("rerank") encoder = RerankClientMtebEncoder(MODEL_NAME, url) - vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, - MTEB_RERANK_LANGS) + vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) diff --git a/tests/entrypoints/pooling/llm/test_classify.py b/tests/entrypoints/pooling/llm/test_classify.py index ff5cea11a918..ae216c464a5b 100644 --- a/tests/entrypoints/pooling/llm/test_classify.py +++ b/tests/entrypoints/pooling/llm/test_classify.py @@ -19,12 +19,14 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0, + ) yield weakref.proxy(llm) @@ -35,26 +37,25 @@ def llm(): @pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(activation): outputs = llm.classify( - prompts, - pooling_params=PoolingParams(activation=activation), - use_tqdm=False) + prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False + ) return torch.tensor([x.outputs.probs for x in outputs]) default = get_outputs(activation=None) w_activation = get_outputs(activation=True) wo_activation = get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - softmax(wo_activation), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." 
+ ) + assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." + ) def test_encode_api(llm: LLM): diff --git a/tests/entrypoints/pooling/llm/test_embedding.py b/tests/entrypoints/pooling/llm/test_embedding.py index 485f04ed6d84..aa24a70fd18b 100644 --- a/tests/entrypoints/pooling/llm/test_embedding.py +++ b/tests/entrypoints/pooling/llm/test_embedding.py @@ -19,12 +19,14 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0, + ) yield weakref.proxy(llm) @@ -35,21 +37,20 @@ def llm(): @pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(normalize): - outputs = llm.embed(prompts, - pooling_params=PoolingParams(normalize=normalize), - use_tqdm=False) + outputs = llm.embed( + prompts, pooling_params=PoolingParams(normalize=normalize), use_tqdm=False + ) return torch.tensor([x.outputs.embedding for x in outputs]) default = get_outputs(normalize=None) w_normal = get_outputs(normalize=True) wo_normal = get_outputs(normalize=False) - assert torch.allclose(default, w_normal, - atol=1e-2), "Default should use normal." - assert not torch.allclose(w_normal, wo_normal, - atol=1e-2), "wo_normal should not use normal." - assert torch.allclose( - w_normal, F.normalize(wo_normal, p=2, dim=-1), - atol=1e-2), "w_normal should be close to normal(wo_normal)." + assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal." + assert not torch.allclose(w_normal, wo_normal, atol=1e-2), ( + "wo_normal should not use normal." + ) + assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), ( + "w_normal should be close to normal(wo_normal)." 
+ ) diff --git a/tests/entrypoints/pooling/llm/test_encode.py b/tests/entrypoints/pooling/llm/test_encode.py index eae3e234378f..d6aae99944f8 100644 --- a/tests/entrypoints/pooling/llm/test_encode.py +++ b/tests/entrypoints/pooling/llm/test_encode.py @@ -31,12 +31,14 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0, + ) yield weakref.proxy(llm) diff --git a/tests/entrypoints/pooling/llm/test_reward.py b/tests/entrypoints/pooling/llm/test_reward.py index 11d164c978a9..8312ff180b36 100644 --- a/tests/entrypoints/pooling/llm/test_reward.py +++ b/tests/entrypoints/pooling/llm/test_reward.py @@ -19,13 +19,15 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - trust_remote_code=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + trust_remote_code=True, + seed=0, + ) yield weakref.proxy(llm) @@ -36,21 +38,20 @@ def llm(): @pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(softmax): - outputs = llm.reward(prompts, - pooling_params=PoolingParams(softmax=softmax), - use_tqdm=False) + outputs = llm.reward( + prompts, pooling_params=PoolingParams(softmax=softmax), use_tqdm=False + ) return torch.cat([x.outputs.data for x in outputs]) default = get_outputs(softmax=None) w_softmax = get_outputs(softmax=True) wo_softmax = get_outputs(softmax=False) - assert torch.allclose(default, w_softmax, - atol=1e-2), "Default should use softmax." - assert not torch.allclose(w_softmax, wo_softmax, - atol=1e-2), "wo_softmax should not use softmax." - assert torch.allclose( - softmax(wo_softmax), w_softmax, - atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." + assert torch.allclose(default, w_softmax, atol=1e-2), "Default should use softmax." + assert not torch.allclose(w_softmax, wo_softmax, atol=1e-2), ( + "wo_softmax should not use softmax." + ) + assert torch.allclose(softmax(wo_softmax), w_softmax, atol=1e-2), ( + "w_softmax should be close to softmax(wo_softmax)." + ) diff --git a/tests/entrypoints/pooling/llm/test_score.py b/tests/entrypoints/pooling/llm/test_score.py index 447378f989d0..9bf74fce906b 100644 --- a/tests/entrypoints/pooling/llm/test_score.py +++ b/tests/entrypoints/pooling/llm/test_score.py @@ -17,12 +17,14 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True, - seed=0) + llm = LLM( + model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True, + seed=0, + ) yield weakref.proxy(llm) @@ -33,7 +35,6 @@ def llm(): @pytest.mark.skip_global_cleanup def test_pooling_params(llm: LLM): - def get_outputs(activation): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." 
@@ -42,18 +43,20 @@ def get_outputs(activation): text_1, text_2, pooling_params=PoolingParams(activation=activation), - use_tqdm=False) + use_tqdm=False, + ) return torch.tensor([x.outputs.score for x in outputs]) default = get_outputs(activation=None) w_activation = get_outputs(activation=True) wo_activation = get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - softmax(wo_activation), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." + ) diff --git a/tests/entrypoints/pooling/openai/test_classification.py b/tests/entrypoints/pooling/openai/test_classification.py index 26c2c8e6af17..92d40efad21c 100644 --- a/tests/entrypoints/pooling/openai/test_classification.py +++ b/tests/entrypoints/pooling/openai/test_classification.py @@ -28,21 +28,16 @@ def server(): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_single_input_classification(server: RemoteOpenAIServer, - model_name: str): +def test_single_input_classification(server: RemoteOpenAIServer, model_name: str): input_text = "This product was excellent and exceeded my expectations" classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": input_text - }, + json={"model": model_name, "input": input_text}, ) classification_response.raise_for_status() - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert output.object == "list" assert output.model == MODEL_NAME @@ -52,8 +47,7 @@ def test_single_input_classification(server: RemoteOpenAIServer, @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_multiple_inputs_classification(server: RemoteOpenAIServer, - model_name: str): +def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str): input_texts = [ "The product arrived on time and works perfectly", "I'm very satisfied with my purchase, would buy again", @@ -65,13 +59,9 @@ def test_multiple_inputs_classification(server: RemoteOpenAIServer, classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": input_texts - }, + json={"model": model_name, "input": input_texts}, ) - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert len(output.data) == len(input_texts) for i, item in enumerate(output.data): @@ -88,16 +78,11 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": long_text, - "truncate_prompt_tokens": 5 - }, + json={"model": model_name, "input": long_text, "truncate_prompt_tokens": 5}, ) classification_response.raise_for_status() - output = ClassificationResponse.model_validate( - classification_response.json()) + 
output = ClassificationResponse.model_validate(classification_response.json()) assert len(output.data) == 1 assert output.data[0].index == 0 @@ -107,15 +92,12 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, - model_name: str): +def test_invalid_truncate_prompt_tokens_error( + server: RemoteOpenAIServer, model_name: str +): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": "test", - "truncate_prompt_tokens": 513 - }, + json={"model": model_name, "input": "test", "truncate_prompt_tokens": 513}, ) error = classification_response.json() @@ -127,10 +109,7 @@ def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": "" - }, + json={"model": model_name, "input": ""}, ) error = classification_response.json() @@ -139,18 +118,13 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) -def test_batch_classification_empty_list(server: RemoteOpenAIServer, - model_name: str): +def test_batch_classification_empty_list(server: RemoteOpenAIServer, model_name: str): classification_response = requests.post( server.url_for("classify"), - json={ - "model": model_name, - "input": [] - }, + json={"model": model_name, "input": []}, ) classification_response.raise_for_status() - output = ClassificationResponse.model_validate( - classification_response.json()) + output = ClassificationResponse.model_validate(classification_response.json()) assert output.object == "list" assert isinstance(output.data, list) @@ -161,15 +135,17 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer, async def test_invocations(server: RemoteOpenAIServer): request_args = { "model": MODEL_NAME, - "input": "This product was excellent and exceeded my expectations" + "input": "This product was excellent and exceeded my expectations", } - classification_response = requests.post(server.url_for("classify"), - json=request_args) + classification_response = requests.post( + server.url_for("classify"), json=request_args + ) classification_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() classification_output = classification_response.json() @@ -177,10 +153,12 @@ async def test_invocations(server: RemoteOpenAIServer): assert classification_output.keys() == invocation_output.keys() for classification_data, invocation_data in zip( - classification_output["data"], invocation_output["data"]): + classification_output["data"], invocation_output["data"] + ): assert classification_data.keys() == invocation_data.keys() assert classification_data["probs"] == pytest.approx( - invocation_data["probs"], rel=0.01) + invocation_data["probs"], rel=0.01 + ) @pytest.mark.asyncio @@ -189,27 +167,26 @@ async def test_activation(server: RemoteOpenAIServer, model_name: str): input_text = ["This product was excellent and exceeded my expectations"] async def get_outputs(activation): - response = requests.post(server.url_for("classify"), - json={ - "model": model_name, - 
"input": input_text, - "activation": activation - }) + response = requests.post( + server.url_for("classify"), + json={"model": model_name, "input": input_text, "activation": activation}, + ) outputs = response.json() - return torch.tensor([x['probs'] for x in outputs["data"]]) + return torch.tensor([x["probs"] for x in outputs["data"]]) default = await get_outputs(activation=None) w_activation = await get_outputs(activation=True) wo_activation = await get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(F.softmax(wo_activation, dim=-1), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." + ) @pytest.mark.asyncio @@ -218,11 +195,7 @@ def test_pooling(server: RemoteOpenAIServer, model_name: str): # pooling api uses ALL pooling, which does not support chunked prefill. response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": "test", - "encoding_format": "float" - }, + json={"model": model_name, "input": "test", "encoding_format": "float"}, ) assert response.json()["error"]["type"] == "BadRequestError" diff --git a/tests/entrypoints/pooling/openai/test_embedding.py b/tests/entrypoints/pooling/openai/test_embedding.py index 37a10e79d4fc..6f6559a961a1 100644 --- a/tests/entrypoints/pooling/openai/test_embedding.py +++ b/tests/entrypoints/pooling/openai/test_embedding.py @@ -11,8 +11,7 @@ import torch import torch.nn.functional as F -from tests.models.language.pooling.embed_utils import ( - run_embedding_correctness_test) +from tests.models.language.pooling.embed_utils import run_embedding_correctness_test from tests.models.utils import check_embeddings_close from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse @@ -50,15 +49,13 @@ async def client(server): @pytest.fixture(scope="module") def hf_model(hf_runner): - with hf_runner(MODEL_NAME, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: + with hf_runner(MODEL_NAME, dtype=DTYPE, is_sentence_transformer=True) as hf_model: yield hf_model @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, - model_name: str): +async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): input_texts = [ "The chef prepared a delicious meal.", ] @@ -70,7 +67,8 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -90,7 +88,8 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not 
None assert len(embeddings.data) == 1 @@ -102,12 +101,12 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, - model_name: str): +async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str): # test list[str] input_texts = [ - "The cat sat on the mat.", "A feline was resting on a rug.", - "Stars twinkle brightly in the night sky." + "The cat sat on the mat.", + "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky.", ] embedding_response = await client.embeddings.create( model=model_name, @@ -115,7 +114,8 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 3 @@ -128,15 +128,20 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, run_embedding_correctness_test(hf_model, input_texts, vllm_outputs) # test list[list[int]] - input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], - [25, 32, 64, 77]] + input_tokens = [ + [4, 5, 7, 9, 20], + [15, 29, 499], + [24, 24, 24, 24, 24], + [25, 32, 64, 77], + ] embedding_response = await client.embeddings.create( model=model_name, input=input_tokens, encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 4 @@ -148,19 +153,23 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_conversation_embedding(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI, - model_name: str): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] +async def test_conversation_embedding( + server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str +): + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] chat_response = requests.post( server.url_for("v1/embeddings"), @@ -189,64 +198,66 @@ async def test_conversation_embedding(server: RemoteOpenAIServer, extra_body={"add_special_tokens": False}, ) completion_embeddings = EmbeddingResponse.model_validate( - completion_response.model_dump(mode="json")) + completion_response.model_dump(mode="json") + ) assert chat_embeddings.id is not None assert completion_embeddings.id is not None assert chat_embeddings.created <= completion_embeddings.created - assert chat_embeddings.model_dump( - exclude={"id", "created"}) == (completion_embeddings.model_dump( - exclude={"id", "created"})) + assert chat_embeddings.model_dump(exclude={"id", "created"}) == ( + completion_embeddings.model_dump(exclude={"id", "created"}) + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI, - model_name: 
str): +async def test_batch_base64_embedding( + hf_model, client: openai.AsyncOpenAI, model_name: str +): input_texts = [ "Hello my name is", - "The best thing about vLLM is that it supports many different models" + "The best thing about vLLM is that it supports many different models", ] - responses_float = await client.embeddings.create(input=input_texts, - model=model_name, - encoding_format="float") + responses_float = await client.embeddings.create( + input=input_texts, model=model_name, encoding_format="float" + ) float_data = [d.embedding for d in responses_float.data] run_embedding_correctness_test(hf_model, input_texts, float_data) - responses_base64 = await client.embeddings.create(input=input_texts, - model=model_name, - encoding_format="base64") + responses_base64 = await client.embeddings.create( + input=input_texts, model=model_name, encoding_format="base64" + ) base64_data = [] for data in responses_base64.data: base64_data.append( - np.frombuffer(base64.b64decode(data.embedding), - dtype="float32").tolist()) + np.frombuffer(base64.b64decode(data.embedding), dtype="float32").tolist() + ) run_embedding_correctness_test(hf_model, input_texts, base64_data) # Default response is float32 decoded from base64 by OpenAI Client - responses_default = await client.embeddings.create(input=input_texts, - model=model_name) + responses_default = await client.embeddings.create( + input=input_texts, model=model_name + ) default_data = [d.embedding for d in responses_default.data] run_embedding_correctness_test(hf_model, input_texts, default_data) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding_truncation(client: openai.AsyncOpenAI, - model_name: str): +async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str): input_texts = [ "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", ] # test single embedding embedding_response = await client.embeddings.create( - model=model_name, - input=input_texts, - extra_body={"truncate_prompt_tokens": 10}) + model=model_name, input=input_texts, extra_body={"truncate_prompt_tokens": 10} + ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -256,15 +267,34 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 10 input_tokens = [ - 1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728, - 9901, 340, 2229, 385, 340, 315, 28741, 28804, 2 + 1, + 24428, + 289, + 18341, + 26165, + 285, + 19323, + 283, + 289, + 26789, + 3871, + 28728, + 9901, + 340, + 2229, + 385, + 340, + 315, + 28741, + 28804, + 2, ] embedding_response = await client.embeddings.create( - model=model_name, - input=input_tokens, - extra_body={"truncate_prompt_tokens": 10}) + model=model_name, input=input_tokens, extra_body={"truncate_prompt_tokens": 10} + ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -276,8 +306,9 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, - model_name: str): +async def 
test_single_embedding_truncation_invalid( + client: openai.AsyncOpenAI, model_name: str +): input_texts = [ "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?", ] @@ -286,15 +317,17 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, response = await client.embeddings.create( model=model_name, input=input_texts, - extra_body={"truncate_prompt_tokens": 8193}) + extra_body={"truncate_prompt_tokens": 8193}, + ) assert "error" in response.object - assert "truncate_prompt_tokens value is greater than max_model_len. "\ - "Please, select a smaller truncation size." in response.message + assert ( + "truncate_prompt_tokens value is greater than max_model_len. " + "Please, select a smaller truncation size." in response.message + ) @pytest.mark.asyncio -async def test_invocations(server: RemoteOpenAIServer, - client: openai.AsyncOpenAI): +async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI): input_texts = [ "The chef prepared a delicious meal.", ] @@ -307,35 +340,43 @@ async def test_invocations(server: RemoteOpenAIServer, completion_response = await client.embeddings.create(**request_args) - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() completion_output = completion_response.model_dump() invocation_output = invocation_response.json() assert completion_output.keys() == invocation_output.keys() - for completion_data, invocation_data in zip(completion_output["data"], - invocation_output["data"]): + for completion_data, invocation_data in zip( + completion_output["data"], invocation_output["data"] + ): assert completion_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=[completion_data["embedding"]], - embeddings_1_lst=[invocation_data["embedding"]], - name_0="completion", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=[completion_data["embedding"]], + embeddings_1_lst=[invocation_data["embedding"]], + name_0="completion", + name_1="invocation", + ) @pytest.mark.asyncio async def test_invocations_conversation(server: RemoteOpenAIServer): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] request_args = { "model": MODEL_NAME, @@ -343,25 +384,28 @@ async def test_invocations_conversation(server: RemoteOpenAIServer): "encoding_format": "float", } - chat_response = requests.post(server.url_for("v1/embeddings"), - json=request_args) + chat_response = requests.post(server.url_for("v1/embeddings"), json=request_args) chat_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() chat_output = chat_response.json() invocation_output = invocation_response.json() assert chat_output.keys() == invocation_output.keys() - for chat_data, invocation_data in zip(chat_output["data"], - 
invocation_output["data"]): + for chat_data, invocation_data in zip( + chat_output["data"], invocation_output["data"] + ): assert chat_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=[chat_data["embedding"]], - embeddings_1_lst=[invocation_data["embedding"]], - name_0="chat", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=[chat_data["embedding"]], + embeddings_1_lst=[invocation_data["embedding"]], + name_0="chat", + name_1="invocation", + ) @pytest.mark.asyncio @@ -374,23 +418,22 @@ async def get_outputs(normalize): "model": MODEL_NAME, "input": input_text, "encoding_format": "float", - "normalize": normalize + "normalize": normalize, } - response = requests.post(server.url_for("v1/embeddings"), - json=request_args) + response = requests.post(server.url_for("v1/embeddings"), json=request_args) outputs = response.json() - return torch.tensor([x['embedding'] for x in outputs["data"]]) + return torch.tensor([x["embedding"] for x in outputs["data"]]) default = await get_outputs(normalize=None) w_normal = await get_outputs(normalize=True) wo_normal = await get_outputs(normalize=False) - assert torch.allclose(default, w_normal, - atol=1e-2), "Default should use normal." - assert not torch.allclose(w_normal, wo_normal, - atol=1e-2), "wo_normal should not use normal." - assert torch.allclose( - w_normal, F.normalize(wo_normal, p=2, dim=-1), - atol=1e-2), "w_normal should be close to normal(wo_normal)." + assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal." + assert not torch.allclose(w_normal, wo_normal, atol=1e-2), ( + "wo_normal should not use normal." + ) + assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), ( + "w_normal should be close to normal(wo_normal)." 
+ ) diff --git a/tests/entrypoints/pooling/openai/test_embedding_dimensions.py b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py index 3c7e88daa8ff..92df43d7dbdc 100644 --- a/tests/entrypoints/pooling/openai/test_embedding_dimensions.py +++ b/tests/entrypoints/pooling/openai/test_embedding_dimensions.py @@ -10,17 +10,18 @@ import pytest from tests.conftest import HfRunner -from tests.models.language.pooling.embed_utils import ( - run_embedding_correctness_test) +from tests.models.language.pooling.embed_utils import run_embedding_correctness_test from tests.models.utils import EmbedModelInfo from tests.utils import RemoteOpenAIServer from vllm.entrypoints.openai.protocol import EmbeddingResponse MODELS = [ EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - matryoshka_dimensions=[256]), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256], + ), ] input_texts = [ @@ -48,15 +49,14 @@ def server(model_info, dtype: str): dtype, "--enforce-eager", "--max-model-len", - "512" + "512", ] if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5": # Manually enable Matryoshka Embeddings - args.extend([ - "--trust_remote_code", "--hf_overrides", - '{"matryoshka_dimensions":[256]}' - ]) + args.extend( + ["--trust_remote_code", "--hf_overrides", '{"matryoshka_dimensions":[256]}'] + ) with RemoteOpenAIServer(model_info.name, args) as remote_server: yield remote_server @@ -64,14 +64,16 @@ def server(model_info, dtype: str): @pytest.fixture(scope="module") def hf_model(hf_runner, model_info, dtype: str): - with hf_runner(model_info.name, dtype=dtype, - is_sentence_transformer=True) as hf_model: + with hf_runner( + model_info.name, dtype=dtype, is_sentence_transformer=True + ) as hf_model: yield hf_model @pytest.mark.asyncio -async def test_matryoshka(model_info: EmbedModelInfo, - server: RemoteOpenAIServer, hf_model: HfRunner): +async def test_matryoshka( + model_info: EmbedModelInfo, server: RemoteOpenAIServer, hf_model: HfRunner +): client = server.get_async_client() async def make_request_and_correctness_test(dimensions): @@ -84,7 +86,8 @@ async def make_request_and_correctness_test(dimensions): encoding_format="float", ) embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 3 @@ -97,8 +100,7 @@ async def make_request_and_correctness_test(dimensions): assert len(embeddings.data[0].embedding) == dimensions vllm_outputs = [d.embedding for d in embeddings.data] - run_embedding_correctness_test(hf_model, prompts, vllm_outputs, - dimensions) + run_embedding_correctness_test(hf_model, prompts, vllm_outputs, dimensions) if model_info.is_matryoshka: valid_dimensions: list[Optional[int]] = [None] diff --git a/tests/entrypoints/pooling/openai/test_embedding_long_text.py b/tests/entrypoints/pooling/openai/test_embedding_long_text.py index ab5f765c28ed..f977c81a9084 100644 --- a/tests/entrypoints/pooling/openai/test_embedding_long_text.py +++ b/tests/entrypoints/pooling/openai/test_embedding_long_text.py @@ -31,7 +31,6 @@ def _generate_random_text(word_count: int) -> str: "that", "these", "those", - # Action verbs "create", "build", @@ -80,7 +79,6 @@ def _generate_random_text(word_count: int) -> str: "finish", "deliver", "provide", - # Technology and science nouns "system", 
"application", @@ -132,7 +130,6 @@ def _generate_random_text(word_count: int) -> str: "optimization", "performance", "efficiency", - # General nouns "project", "team", @@ -175,7 +172,7 @@ def _generate_random_text(word_count: int) -> str: "session", "meeting", "discussion", - "decision" + "decision", ] words = [] @@ -189,7 +186,7 @@ def _generate_random_text(word_count: int) -> str: result = [] for i, word in enumerate(words_list): result.append(word) - if ((i + 1) % random.randint(10, 20) == 0 and i < len(words_list) - 1): + if (i + 1) % random.randint(10, 20) == 0 and i < len(words_list) - 1: result[-1] += "." return " ".join(result) @@ -216,9 +213,11 @@ def server_with_chunked_processing(): "--enforce-eager", "--max-model-len", "512", # Set smaller max_model_len to trigger chunking mechanism - '--pooler-config', - ('{"pooling_type": "MEAN", "normalize": true, ' - '"enable_chunked_processing": true, "max_embed_len": 10000}'), + "--pooler-config", + ( + '{"pooling_type": "MEAN", "normalize": true, ' + '"enable_chunked_processing": true, "max_embed_len": 10000}' + ), "--gpu-memory-utilization", "0.8", ] @@ -230,23 +229,22 @@ def server_with_chunked_processing(): @pytest_asyncio.fixture async def client_with_chunked_processing(server_with_chunked_processing): """Create async client with chunking processing support.""" - async with server_with_chunked_processing.get_async_client( - ) as async_client: + async with server_with_chunked_processing.get_async_client() as async_client: yield async_client @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_long_text_embedding_1500_chars( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): - """Test embedding processing for ~1500 character long text + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): + """Test embedding processing for ~1500 character long text (~1028 tokens, exceeding 512 token limit).""" # Verify text length # Verify text has sufficient word count (approximately 1500 words) word_count = len(LONG_TEXT_1500_WORDS.split()) - assert word_count >= 1400, ( - f"Test text word count insufficient: {word_count} words") + assert word_count >= 1400, f"Test text word count insufficient: {word_count} words" # Send embedding request embedding_response = await client_with_chunked_processing.embeddings.create( @@ -257,12 +255,14 @@ async def test_long_text_embedding_1500_chars( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 - assert len(embeddings.data[0].embedding - ) == 384 # multilingual-e5-small embedding dimension + assert ( + len(embeddings.data[0].embedding) == 384 + ) # multilingual-e5-small embedding dimension assert embeddings.usage.completion_tokens == 0 # Due to chunked processing, token count should # reflect actual processed tokens @@ -274,26 +274,26 @@ async def test_long_text_embedding_1500_chars( # Verify embedding vector validity embedding_vector = embeddings.data[0].embedding - assert all( - isinstance(x, float) - for x in embedding_vector), "Embedding vector should contain floats" - assert not all( - x == 0 - for x in embedding_vector), "Embedding vector should not be all zeros" + assert all(isinstance(x, float) for x in embedding_vector), ( + "Embedding vector should contain floats" + ) + assert not all(x == 0 for x in embedding_vector), ( + "Embedding vector should 
not be all zeros" + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_long_text_embedding_2500_chars( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): """Test embedding processing for ~2500 character long text (~2048 tokens, requiring multiple chunks).""" # Verify text length # Verify text has sufficient word count (approximately 2500 words) word_count = len(LONG_TEXT_2500_WORDS.split()) - assert word_count >= 2300, ( - f"Test text word count insufficient: {word_count} words") + assert word_count >= 2300, f"Test text word count insufficient: {word_count} words" # Send embedding request embedding_response = await client_with_chunked_processing.embeddings.create( @@ -304,12 +304,14 @@ async def test_long_text_embedding_2500_chars( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 - assert len(embeddings.data[0].embedding - ) == 384 # multilingual-e5-small embedding dimension + assert ( + len(embeddings.data[0].embedding) == 384 + ) # multilingual-e5-small embedding dimension assert embeddings.usage.completion_tokens == 0 # Due to chunked processing, token count should # reflect actual processed tokens @@ -321,18 +323,19 @@ async def test_long_text_embedding_2500_chars( # Verify embedding vector validity embedding_vector = embeddings.data[0].embedding - assert all( - isinstance(x, float) - for x in embedding_vector), "Embedding vector should contain floats" - assert not all( - x == 0 - for x in embedding_vector), "Embedding vector should not be all zeros" + assert all(isinstance(x, float) for x in embedding_vector), ( + "Embedding vector should contain floats" + ) + assert not all(x == 0 for x in embedding_vector), ( + "Embedding vector should not be all zeros" + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_batch_long_text_embedding( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): """Test batch long text embedding processing.""" input_texts = [ @@ -350,7 +353,8 @@ async def test_batch_long_text_embedding( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 3 # Three input texts @@ -375,13 +379,16 @@ async def test_batch_long_text_embedding( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_chunked_vs_normal_consistency( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): """Test consistency between chunked and normal processing (using short text).""" # Use a short text within the 512 token limit - short_text = ("Artificial intelligence technology is changing our world, " - "bringing unprecedented opportunities and challenges.") + short_text = ( + "Artificial intelligence technology is changing our world, " + "bringing unprecedented opportunities and challenges." 
+ ) # Send embedding request embedding_response = await client_with_chunked_processing.embeddings.create( @@ -392,7 +399,8 @@ async def test_chunked_vs_normal_consistency( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -411,7 +419,8 @@ async def test_chunked_vs_normal_consistency( @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_chunked_processing_response_format( - client_with_chunked_processing: openai.AsyncOpenAI, model_name: str): + client_with_chunked_processing: openai.AsyncOpenAI, model_name: str +): """Test response format and structure during chunked processing.""" # Test with long text to trigger chunking @@ -423,7 +432,8 @@ async def test_chunked_processing_response_format( # Verify response structure embeddings = EmbeddingResponse.model_validate( - embedding_response.model_dump(mode="json")) + embedding_response.model_dump(mode="json") + ) assert embeddings.id is not None assert len(embeddings.data) == 1 @@ -433,8 +443,10 @@ async def test_chunked_processing_response_format( # Verify embedding vector properties embedding_vector = embeddings.data[0].embedding import math + vector_norm = math.sqrt(sum(x * x for x in embedding_vector)) # Check that the vector is normalized # (default behavior for most embedding models) assert 0.8 < vector_norm < 1.2, ( - f"Vector norm should be reasonable, actual: {vector_norm}") + f"Vector norm should be reasonable, actual: {vector_norm}" + ) diff --git a/tests/entrypoints/pooling/openai/test_pooling.py b/tests/entrypoints/pooling/openai/test_pooling.py index 9f58955cfb40..3439c556ccc4 100644 --- a/tests/entrypoints/pooling/openai/test_pooling.py +++ b/tests/entrypoints/pooling/openai/test_pooling.py @@ -46,11 +46,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): # test single pooling response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_texts, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_texts, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -66,11 +62,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): input_tokens = [1, 1, 1, 1, 1] response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_tokens, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_tokens, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -88,16 +80,13 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): # test list[str] input_texts = [ - "The cat sat on the mat.", "A feline was resting on a rug.", - "Stars twinkle brightly in the night sky." 
+ "The cat sat on the mat.", + "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky.", ] response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_texts, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_texts, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -110,15 +99,15 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.usage.total_tokens == 29 # test list[list[int]] - input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], - [25, 32, 64, 77]] + input_tokens = [ + [4, 5, 7, 9, 20], + [15, 29, 499], + [24, 24, 24, 24, 24], + [25, 32, 64, 77], + ] response = requests.post( server.url_for("pooling"), - json={ - "model": model_name, - "input": input_tokens, - "encoding_format": "float" - }, + json={"model": model_name, "input": input_tokens, "encoding_format": "float"}, ) response.raise_for_status() poolings = PoolingResponse.model_validate(response.json()) @@ -133,18 +122,21 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_conversation_pooling(server: RemoteOpenAIServer, - model_name: str): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] +async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str): + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] chat_response = requests.post( server.url_for("pooling"), @@ -180,24 +172,22 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, }, ) completions_response.raise_for_status() - completion_poolings = PoolingResponse.model_validate( - completions_response.json()) + completion_poolings = PoolingResponse.model_validate(completions_response.json()) assert chat_poolings.id is not None assert completion_poolings.id is not None assert chat_poolings.created <= completion_poolings.created - assert chat_poolings.model_dump( - exclude={"id", "created"}) == (completion_poolings.model_dump( - exclude={"id", "created"})) + assert chat_poolings.model_dump(exclude={"id", "created"}) == ( + completion_poolings.model_dump(exclude={"id", "created"}) + ) @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_batch_base64_pooling(server: RemoteOpenAIServer, - model_name: str): +async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str): input_texts = [ "Hello my name is", - "The best thing about vLLM is that it supports many different models" + "The best thing about vLLM is that it supports many different models", ] float_response = requests.post( @@ -210,9 +200,7 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ) float_response.raise_for_status() responses_float = PoolingResponse.model_validate(float_response.json()) - float_data = [ - np.array(d.data).squeeze(-1).tolist() for d in responses_float.data - ] + float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data] base64_response = requests.post( 
server.url_for("pooling"), @@ -228,13 +216,15 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, decoded_responses_base64_data = [] for data in responses_base64.data: decoded_responses_base64_data.append( - np.frombuffer(base64.b64decode(data.data), - dtype="float32").tolist()) - - check_embeddings_close(embeddings_0_lst=float_data, - embeddings_1_lst=decoded_responses_base64_data, - name_0="float32", - name_1="base64") + np.frombuffer(base64.b64decode(data.data), dtype="float32").tolist() + ) + + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=decoded_responses_base64_data, + name_0="float32", + name_1="base64", + ) # Default response is float32 decoded from base64 by OpenAI Client default_response = requests.post( @@ -250,10 +240,12 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, np.array(d.data).squeeze(-1).tolist() for d in responses_default.data ] - check_embeddings_close(embeddings_0_lst=float_data, - embeddings_1_lst=default_data, - name_0="float32", - name_1="default") + check_embeddings_close( + embeddings_0_lst=float_data, + embeddings_1_lst=default_data, + name_0="float32", + name_1="default", + ) @pytest.mark.asyncio @@ -268,39 +260,46 @@ async def test_invocations(server: RemoteOpenAIServer): "encoding_format": "float", } - completion_response = requests.post(server.url_for("pooling"), - json=request_args) + completion_response = requests.post(server.url_for("pooling"), json=request_args) completion_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() completion_output = completion_response.json() invocation_output = invocation_response.json() assert completion_output.keys() == invocation_output.keys() - for completion_data, invocation_data in zip(completion_output["data"], - invocation_output["data"]): + for completion_data, invocation_data in zip( + completion_output["data"], invocation_output["data"] + ): assert completion_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=completion_data["data"], - embeddings_1_lst=invocation_data["data"], - name_0="completion", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=completion_data["data"], + embeddings_1_lst=invocation_data["data"], + name_0="completion", + name_1="invocation", + ) @pytest.mark.asyncio async def test_invocations_conversation(server: RemoteOpenAIServer): - messages = [{ - "role": "user", - "content": "The cat sat on the mat.", - }, { - "role": "assistant", - "content": "A feline was resting on a rug.", - }, { - "role": "user", - "content": "Stars twinkle brightly in the night sky.", - }] + messages = [ + { + "role": "user", + "content": "The cat sat on the mat.", + }, + { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, + { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }, + ] request_args = { "model": MODEL_NAME, @@ -311,18 +310,22 @@ async def test_invocations_conversation(server: RemoteOpenAIServer): chat_response = requests.post(server.url_for("pooling"), json=request_args) chat_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() chat_output = 
chat_response.json() invocation_output = invocation_response.json() assert chat_output.keys() == invocation_output.keys() - for chat_data, invocation_data in zip(chat_output["data"], - invocation_output["data"]): + for chat_data, invocation_data in zip( + chat_output["data"], invocation_output["data"] + ): assert chat_data.keys() == invocation_data.keys() - check_embeddings_close(embeddings_0_lst=chat_data["data"], - embeddings_1_lst=invocation_data["data"], - name_0="chat", - name_1="invocation") + check_embeddings_close( + embeddings_0_lst=chat_data["data"], + embeddings_1_lst=invocation_data["data"], + name_0="chat", + name_1="invocation", + ) diff --git a/tests/entrypoints/pooling/openai/test_rerank.py b/tests/entrypoints/pooling/openai/test_rerank.py index 992cb5147ef0..9980fcff16c1 100644 --- a/tests/entrypoints/pooling/openai/test_rerank.py +++ b/tests/entrypoints/pooling/openai/test_rerank.py @@ -25,15 +25,18 @@ def server(): def test_rerank_texts(server: RemoteOpenAIServer, model_name: str): query = "What is the capital of France?" documents = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", ] - rerank_response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents, - }) + rerank_response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + }, + ) rerank_response.raise_for_status() rerank = RerankResponse.model_validate(rerank_response.json()) @@ -49,16 +52,14 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str): query = "What is the capital of France?" documents = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris.", "Cross-encoder models are neat" + "The capital of France is Paris.", + "Cross-encoder models are neat", ] - rerank_response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents, - "top_n": 2 - }) + rerank_response = requests.post( + server.url_for("rerank"), + json={"model": model_name, "query": query, "documents": documents, "top_n": 2}, + ) rerank_response.raise_for_status() rerank = RerankResponse.model_validate(rerank_response.json()) @@ -71,28 +72,26 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str): @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str): - query = "What is the capital of France?" * 100 documents = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", ] - rerank_response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents - }) + rerank_response = requests.post( + server.url_for("rerank"), + json={"model": model_name, "query": query, "documents": documents}, + ) assert rerank_response.status_code == 400 # Assert just a small fragments of the response - assert "Please reduce the length of the input." in \ - rerank_response.text + assert "Please reduce the length of the input." in rerank_response.text def test_invocations(server: RemoteOpenAIServer): query = "What is the capital of France?" documents = [ - "The capital of Brazil is Brasilia.", "The capital of France is Paris." 
+ "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", ] request_args = { @@ -101,23 +100,25 @@ def test_invocations(server: RemoteOpenAIServer): "documents": documents, } - rerank_response = requests.post(server.url_for("rerank"), - json=request_args) + rerank_response = requests.post(server.url_for("rerank"), json=request_args) rerank_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() rerank_output = rerank_response.json() invocation_output = invocation_response.json() assert rerank_output.keys() == invocation_output.keys() - for rerank_result, invocations_result in zip(rerank_output["results"], - invocation_output["results"]): + for rerank_result, invocations_result in zip( + rerank_output["results"], invocation_output["results"] + ): assert rerank_result.keys() == invocations_result.keys() assert rerank_result["relevance_score"] == pytest.approx( - invocations_result["relevance_score"], rel=0.05) + invocations_result["relevance_score"], rel=0.05 + ) # TODO: reset this tolerance to 0.01 once we find # an alternative to flash_attn with bfloat16 @@ -125,34 +126,36 @@ def test_invocations(server: RemoteOpenAIServer): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_activation(server: RemoteOpenAIServer, model_name: str): - async def get_outputs(activation): query = "What is the capital of France?" documents = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." + "The capital of France is Paris.", ] - response = requests.post(server.url_for("rerank"), - json={ - "model": model_name, - "query": query, - "documents": documents, - "activation": activation - }) + response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + "activation": activation, + }, + ) outputs = response.json() - return torch.tensor([x['relevance_score'] for x in outputs["results"]]) + return torch.tensor([x["relevance_score"] for x in outputs["results"]]) default = await get_outputs(activation=None) w_activation = await get_outputs(activation=True) wo_activation = await get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - F.sigmoid(wo_activation), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." 
+ ) diff --git a/tests/entrypoints/pooling/openai/test_score.py b/tests/entrypoints/pooling/openai/test_score.py index d676ecccbc87..ef213ab0ea18 100644 --- a/tests/entrypoints/pooling/openai/test_score.py +++ b/tests/entrypoints/pooling/openai/test_score.py @@ -12,14 +12,8 @@ from vllm.entrypoints.openai.protocol import ScoreResponse MODELS = [ - { - "name": "BAAI/bge-reranker-v2-m3", - "is_cross_encoder": True - }, - { - "name": "BAAI/bge-base-en-v1.5", - "is_cross_encoder": False - }, + {"name": "BAAI/bge-reranker-v2-m3", "is_cross_encoder": True}, + {"name": "BAAI/bge-base-en-v1.5", "is_cross_encoder": False}, ] DTYPE = "half" @@ -28,9 +22,7 @@ def run_transformers(hf_model, model, text_pairs): if model["is_cross_encoder"]: return hf_model.predict(text_pairs).tolist() else: - hf_embeddings = [ - hf_model.encode(text_pair) for text_pair in text_pairs - ] + hf_embeddings = [hf_model.encode(text_pair) for text_pair in text_pairs] return [ F.cosine_similarity(tensor(pair[0]), tensor(pair[1]), dim=0) for pair in hf_embeddings @@ -54,8 +46,9 @@ def server(model: dict[str, Any]): def runner(model: dict[str, Any], hf_runner): kwargs = { "dtype": DTYPE, - "is_cross_encoder" if model["is_cross_encoder"]\ - else "is_sentence_transformer": True + "is_cross_encoder" + if model["is_cross_encoder"] + else "is_sentence_transformer": True, } with hf_runner(model["name"], **kwargs) as hf_model: @@ -63,21 +56,23 @@ def runner(model: dict[str, Any], hf_runner): class TestModel: - - def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer, - model: dict[str, Any], runner): + def test_text_1_str_text_2_list( + self, server: RemoteOpenAIServer, model: dict[str, Any], runner + ): text_1 = "What is the capital of France?" text_2 = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." + "The capital of France is Paris.", ] - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) score_response.raise_for_status() score = ScoreResponse.model_validate(score_response.json()) @@ -93,23 +88,26 @@ def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer, for i in range(len(vllm_outputs)): assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) - def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer, - model: dict[str, Any], runner): + def test_text_1_list_text_2_list( + self, server: RemoteOpenAIServer, model: dict[str, Any], runner + ): text_1 = [ "What is the capital of the United States?", - "What is the capital of France?" + "What is the capital of France?", ] text_2 = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." 
+ "The capital of France is Paris.", ] - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) score_response.raise_for_status() score = ScoreResponse.model_validate(score_response.json()) @@ -125,17 +123,20 @@ def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer, for i in range(len(vllm_outputs)): assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) - def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer, - model: dict[str, Any], runner): + def test_text_1_str_text_2_str( + self, server: RemoteOpenAIServer, model: dict[str, Any], runner + ): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) score_response.raise_for_status() score = ScoreResponse.model_validate(score_response.json()) @@ -151,40 +152,41 @@ def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer, for i in range(len(vllm_outputs)): assert hf_outputs[i] == pytest.approx(vllm_outputs[i], rel=0.01) - def test_score_max_model_len(self, server: RemoteOpenAIServer, - model: dict[str, Any]): - + def test_score_max_model_len( + self, server: RemoteOpenAIServer, model: dict[str, Any] + ): text_1 = "What is the capital of France?" * 20 text_2 = [ "The capital of Brazil is Brasilia.", - "The capital of France is Paris." + "The capital of France is Paris.", ] - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + }, + ) assert score_response.status_code == 400 # Assert just a small fragments of the response - assert "Please reduce the length of the input." in \ - score_response.text + assert "Please reduce the length of the input." in score_response.text # Test truncation - score_response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - "truncate_prompt_tokens": 101 - }) + score_response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + "truncate_prompt_tokens": 101, + }, + ) assert score_response.status_code == 400 - assert "Please, select a smaller truncation size." in \ - score_response.text + assert "Please, select a smaller truncation size." in score_response.text - def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, - Any]): + def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." 
@@ -194,59 +196,61 @@ def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, "text_2": text_2, } - score_response = requests.post(server.url_for("score"), - json=request_args) + score_response = requests.post(server.url_for("score"), json=request_args) score_response.raise_for_status() - invocation_response = requests.post(server.url_for("invocations"), - json=request_args) + invocation_response = requests.post( + server.url_for("invocations"), json=request_args + ) invocation_response.raise_for_status() score_output = score_response.json() invocation_output = invocation_response.json() assert score_output.keys() == invocation_output.keys() - for score_data, invocation_data in zip(score_output["data"], - invocation_output["data"]): + for score_data, invocation_data in zip( + score_output["data"], invocation_output["data"] + ): assert score_data.keys() == invocation_data.keys() assert score_data["score"] == pytest.approx( - invocation_data["score"], rel=0.05) + invocation_data["score"], rel=0.05 + ) # TODO: reset this tolerance to 0.01 once we find # an alternative to flash_attn with bfloat16 - def test_activation(self, server: RemoteOpenAIServer, model: dict[str, - Any]): - + def test_activation(self, server: RemoteOpenAIServer, model: dict[str, Any]): def get_outputs(activation): text_1 = "What is the capital of France?" text_2 = "The capital of France is Paris." - response = requests.post(server.url_for("score"), - json={ - "model": model["name"], - "text_1": text_1, - "text_2": text_2, - "activation": activation - }) + response = requests.post( + server.url_for("score"), + json={ + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + "activation": activation, + }, + ) if response.status_code != 200: return response outputs = response.json() - return torch.tensor([x['score'] for x in outputs["data"]]) + return torch.tensor([x["score"] for x in outputs["data"]]) if model["is_cross_encoder"]: - default = get_outputs(activation=None) w_activation = get_outputs(activation=True) wo_activation = get_outputs(activation=False) - assert torch.allclose(default, w_activation, - atol=1e-2), "Default should use activation." - assert not torch.allclose( - w_activation, wo_activation, - atol=1e-2), "wo_activation should not use activation." - assert torch.allclose( - F.sigmoid(wo_activation), w_activation, atol=1e-2 - ), "w_activation should be close to activation(wo_activation)." + assert torch.allclose(default, w_activation, atol=1e-2), ( + "Default should use activation." + ) + assert not torch.allclose(w_activation, wo_activation, atol=1e-2), ( + "wo_activation should not use activation." + ) + assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), ( + "w_activation should be close to activation(wo_activation)." 
+ ) else: get_outputs(activation=None) diff --git a/tests/entrypoints/pooling/openai/test_truncation.py b/tests/entrypoints/pooling/openai/test_truncation.py index 6bdf5ce7c4a6..6889628dc914 100644 --- a/tests/entrypoints/pooling/openai/test_truncation.py +++ b/tests/entrypoints/pooling/openai/test_truncation.py @@ -54,12 +54,10 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": truncation_size, } - response = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) + response = await client.post(path="embeddings", cast_to=object, body={**kwargs}) assert response["usage"]["prompt_tokens"] == truncation_size @@ -70,12 +68,10 @@ async def test_zero_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": truncation_size, } - response = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) + response = await client.post(path="embeddings", cast_to=object, body={**kwargs}) assert response["usage"]["prompt_tokens"] == truncation_size @@ -86,7 +82,7 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": truncation_size, } with pytest.raises(openai.BadRequestError) as err: @@ -95,9 +91,11 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): assert err.value.status_code == 400 error_details = err.value.response.json()["error"] assert error_details["type"] == "BadRequestError" - expected_message = ("truncate_prompt_tokens value is " - "greater than max_model_len." - " Please, select a smaller truncation size.") + expected_message = ( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size." 
+ ) assert error_details["message"] == expected_message @@ -107,11 +105,9 @@ async def test_max_truncation_size(client: openai.AsyncOpenAI): kwargs: dict[str, Any] = { "model": MODEL_NAME, "input": input, - "truncate_prompt_tokens": truncation_size + "truncate_prompt_tokens": truncation_size, } - response = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) + response = await client.post(path="embeddings", cast_to=object, body={**kwargs}) assert response["usage"]["prompt_tokens"] == max_model_len diff --git a/tests/entrypoints/pooling/openai/test_vision_embedding.py b/tests/entrypoints/pooling/openai/test_vision_embedding.py index a30413bc3298..944392d66fa5 100644 --- a/tests/entrypoints/pooling/openai/test_vision_embedding.py +++ b/tests/entrypoints/pooling/openai/test_vision_embedding.py @@ -50,16 +50,15 @@ def server(): @pytest.fixture(scope="session") def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: - encode_image_base64(local_asset_server.get_image_asset(image_url)) + image_url: encode_image_base64(local_asset_server.get_image_asset(image_url)) for image_url in TEST_IMAGE_ASSETS } def get_hf_prompt_tokens(model_name, content, image_url): - processor = AutoProcessor.from_pretrained(model_name, - trust_remote_code=True, - num_crops=4) + processor = AutoProcessor.from_pretrained( + model_name, trust_remote_code=True, num_crops=4 + ) placeholder = "<|image_1|> " prompt = f"{placeholder}{content}" @@ -71,39 +70,28 @@ def get_hf_prompt_tokens(model_name, content, image_url): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) -async def test_image_embedding(server: RemoteOpenAIServer, model_name: str, - image_url: str): +async def test_image_embedding( + server: RemoteOpenAIServer, model_name: str, image_url: str +): content_text = "Represent the given image." 
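The hunk below only reflows the message payload used by this test; for reference, a standalone sketch of the chat-style embedding request it sends is shown here. The URL, model name, and image URL are placeholders (the test takes them from the server fixture, the parametrized MODEL_NAME, and TEST_IMAGE_ASSETS respectively):

import requests

BASE_URL = "http://localhost:8000"         # assumed; tests use server.url_for("v1/embeddings")
MODEL = "<multimodal-embedding-model>"     # placeholder for the parametrized MODEL_NAME
IMAGE_URL = "https://example.com/cat.jpg"  # placeholder image reachable by the server

# Chat-style content parts: one image_url part plus one text instruction,
# matching the structure built inside test_image_embedding.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": IMAGE_URL}},
            {"type": "text", "text": "Represent the given image."},
        ],
    }
]

resp = requests.post(
    f"{BASE_URL}/v1/embeddings",
    json={"model": MODEL, "messages": messages, "encoding_format": "float"},
)
resp.raise_for_status()
# The response follows the EmbeddingResponse schema validated by the test.
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))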
- messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": content_text}, + ], + } + ] response = requests.post( server.url_for("v1/embeddings"), - json={ - "model": model_name, - "messages": messages, - "encoding_format": "float" - }, + json={"model": model_name, "messages": messages, "encoding_format": "float"}, ) response.raise_for_status() embeddings = EmbeddingResponse.model_validate(response.json()) - hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, - image_url) + hf_prompt_tokens = get_hf_prompt_tokens(model_name, content_text, image_url) assert embeddings.id is not None assert len(embeddings.data) == 1 diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py index 34b05ad17b02..e548f52e1e94 100644 --- a/tests/entrypoints/test_api_server_process_manager.py +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -10,8 +10,7 @@ import pytest -from vllm.v1.utils import (APIServerProcessManager, - wait_for_completion_or_failure) +from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure # Global variables to control worker behavior WORKER_RUNTIME_SECONDS = 0.5 @@ -30,26 +29,22 @@ def api_server_args(): """Fixture to provide arguments for APIServerProcessManager.""" sock = socket.socket() return { - "target_server_fn": - mock_run_api_server_worker, - "listen_address": - "localhost:8000", - "sock": - sock, - "args": - "test_args", # Simple string to avoid pickling issues - "num_servers": - 3, + "target_server_fn": mock_run_api_server_worker, + "listen_address": "localhost:8000", + "sock": sock, + "args": "test_args", # Simple string to avoid pickling issues + "num_servers": 3, "input_addresses": [ - "tcp://127.0.0.1:5001", "tcp://127.0.0.1:5002", - "tcp://127.0.0.1:5003" + "tcp://127.0.0.1:5001", + "tcp://127.0.0.1:5002", + "tcp://127.0.0.1:5003", ], "output_addresses": [ - "tcp://127.0.0.1:6001", "tcp://127.0.0.1:6002", - "tcp://127.0.0.1:6003" + "tcp://127.0.0.1:6001", + "tcp://127.0.0.1:6002", + "tcp://127.0.0.1:6003", ], - "stats_update_address": - "tcp://127.0.0.1:7000", + "stats_update_address": "tcp://127.0.0.1:7000", } @@ -95,8 +90,9 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update): assert not proc.is_alive() -@patch("vllm.entrypoints.cli.serve.run_api_server_worker_proc", - mock_run_api_server_worker) +@patch( + "vllm.entrypoints.cli.serve.run_api_server_worker_proc", mock_run_api_server_worker +) def test_wait_for_completion_or_failure(api_server_args): """Test that wait_for_completion_or_failure works with failures.""" global WORKER_RUNTIME_SECONDS @@ -118,8 +114,7 @@ def run_with_exception_capture(): result["exception"] = e # Start a thread to run wait_for_completion_or_failure - wait_thread = threading.Thread(target=run_with_exception_capture, - daemon=True) + wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) wait_thread.start() # Let all processes run for a short time @@ -174,8 +169,7 @@ def test_normal_completion(api_server_args): # Verify all processes have terminated for i, proc in enumerate(manager.processes): - assert not proc.is_alive( - ), f"Process {i} still alive after terminate()" + assert not proc.is_alive(), f"Process {i} still alive after 
terminate()" # Now call wait_for_completion_or_failure # since all processes have already @@ -198,13 +192,13 @@ def test_external_process_monitoring(api_server_args): # Create and start the external process # (simulates local_engine_manager or coordinator) spawn_context = multiprocessing.get_context("spawn") - external_proc = spawn_context.Process(target=mock_run_api_server_worker, - name="MockExternalProcess") + external_proc = spawn_context.Process( + target=mock_run_api_server_worker, name="MockExternalProcess" + ) external_proc.start() # Create the class to simulate a coordinator class MockCoordinator: - def __init__(self, proc): self.proc = proc @@ -228,14 +222,14 @@ def close(self): def run_with_exception_capture(): try: - wait_for_completion_or_failure(api_server_manager=manager, - coordinator=mock_coordinator) + wait_for_completion_or_failure( + api_server_manager=manager, coordinator=mock_coordinator + ) except Exception as e: result["exception"] = e # Start a thread to run wait_for_completion_or_failure - wait_thread = threading.Thread(target=run_with_exception_capture, - daemon=True) + wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) wait_thread.start() # Terminate the external process to trigger a failure @@ -246,21 +240,23 @@ def run_with_exception_capture(): wait_thread.join(timeout=1.0) # The wait thread should have completed - assert not wait_thread.is_alive( - ), "wait_for_completion_or_failure thread still running" + assert not wait_thread.is_alive(), ( + "wait_for_completion_or_failure thread still running" + ) # Verify that an exception was raised with appropriate error message assert result["exception"] is not None, "No exception was raised" error_message = str(result["exception"]) - assert "died with exit code" in error_message, \ + assert "died with exit code" in error_message, ( f"Unexpected error message: {error_message}" - assert "MockExternalProcess" in error_message, \ + ) + assert "MockExternalProcess" in error_message, ( f"Error doesn't mention external process: {error_message}" + ) # Verify that all API server processes were terminated as a result for i, proc in enumerate(manager.processes): - assert not proc.is_alive( - ), f"API server process {i} was not terminated" + assert not proc.is_alive(), f"API server process {i} was not terminated" finally: # Clean up diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 85b305c2fa02..6e92419c4f67 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -6,24 +6,29 @@ from typing import Literal, Optional import pytest -from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy, - SpecialTokens) -from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo, - Tekkenizer) +from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens +from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.config import ModelConfig -from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, - parse_chat_messages, - parse_chat_messages_futures, - resolve_chat_template_content_format, - resolve_chat_template_kwargs, - resolve_hf_chat_template) +from vllm.entrypoints.chat_utils import ( + _try_extract_ast, + apply_mistral_chat_template, + load_chat_template, + parse_chat_messages, + parse_chat_messages_futures, + 
resolve_chat_template_content_format, + resolve_chat_template_kwargs, + resolve_hf_chat_template, +) from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict -from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, - encode_video_base64) +from vllm.multimodal.utils import ( + encode_audio_base64, + encode_image_base64, + encode_video_base64, +) from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer @@ -177,8 +182,7 @@ def _assert_mm_uuids( image_uuids = mm_uuids.get(modality) assert image_uuids is not None - assert isinstance(image_uuids, - list) and len(image_uuids) == media_count + assert isinstance(image_uuids, list) and len(image_uuids) == media_count assert image_uuids == expected_uuids else: @@ -190,10 +194,9 @@ def _assert_mm_uuids( def _assert_mm_data_inputs( - mm_data: Optional[MultiModalDataDict], - data_count: MultiModalDataCounts, - skipped_media_indices: Optional[dict[ - str, list]] = None, # modality -> list[int] + mm_data: Optional[MultiModalDataDict], + data_count: MultiModalDataCounts, + skipped_media_indices: Optional[dict[str, list]] = None, # modality -> list[int] ) -> None: assert mm_data is not None assert set(data_count.keys()) == (set(mm_data.keys())) @@ -204,8 +207,7 @@ def _assert_mm_data_inputs( assert isinstance(modality_data, list) and len(modality_data) == n if skipped_media_indices is not None: - skipped_media_indices_for_modality = skipped_media_indices.get( - modality) + skipped_media_indices_for_modality = skipped_media_indices.get(modality) assert skipped_media_indices_for_modality is not None for i in skipped_media_indices_for_modality: assert modality_data[i] is None @@ -217,31 +219,23 @@ def test_parse_chat_messages_single_image( image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in the image?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(mm_data, 1) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) @@ -253,32 +247,29 @@ def test_parse_chat_messages_single_image_with_uuid( ): image_uuid = str(hash(image_url)) conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url, + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid, }, - "uuid": image_uuid, - }, - { - "type": "text", - "text": "What's in the image?" - }, - ], - }], + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" 
- }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(mm_data, 1) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) @@ -290,30 +281,27 @@ def test_parse_chat_messages_single_empty_image_with_uuid( ): image_uuid = str(hash(image_url)) conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": None, - "uuid": image_uuid, - }, - { - "type": "text", - "text": "What's in the image?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0]) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) @@ -325,33 +313,30 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format( ): image_uuid = str(hash(image_url)) conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url, - "uuid": image_uuid, + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + "uuid": image_uuid, + }, + "bad_uuid_key": image_uuid, }, - "bad_uuid_key": image_uuid, - }, - { - "type": "text", - "text": "What's in the image?" - }, - ], - }], + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(mm_data, 1) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) @@ -365,41 +350,39 @@ def test_parse_chat_messages_multiple_images_with_uuids( image_uuid2 = "my_uuid_2" conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url, + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid1, }, - "uuid": image_uuid1, - }, - { - "type": "image_url", - "image_url": { - "url": image_url, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid2, }, - "uuid": image_uuid2, - }, - { - "type": "text", - "text": "What's in the image?" 
- }, - ], - }], + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in the image?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in the image?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) @@ -413,37 +396,35 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids( image_uuid2 = "my_uuid_2" conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": None, - "uuid": image_uuid1, - }, - { - "type": "image_url", - "image_url": None, - "uuid": image_uuid2, - }, - { - "type": "text", - "text": "What's in the image?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid1, + }, + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in the image?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in the image?", + } + ] _assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[0, 1]) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) @@ -457,39 +438,37 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids( image_uuid2 = "my_uuid_2" conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url, + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + "uuid": image_uuid1, }, - "uuid": image_uuid1, - }, - { - "type": "image_url", - "image_url": None, - "uuid": image_uuid2, - }, - { - "type": "text", - "text": "What's in the image?" - }, - ], - }], + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in the image?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in the image?", + } + ] _assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[1]) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) @@ -502,32 +481,27 @@ async def test_parse_chat_messages_single_image_with_uuid_async( ): image_uuid = str(hash(image_url)) conversation, mm_future, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, }, - "uuid": image_uuid, - }, - { - "type": "text", - "text": "What's in the image?" - }, - ], - }], + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" 
- }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(await mm_future, 1) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) @@ -540,33 +514,28 @@ async def test_parse_chat_messages_empty_image_with_uuid_async( ): image_uuid = str(hash(image_url)) conversation, mm_future, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": None, - "uuid": image_uuid, - }, - { - "type": "text", - "text": "What's in the image?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid, + }, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" - }] - _assert_mm_data_is_image_input(await mm_future, - 1, - skipped_image_indices=[0]) + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] + _assert_mm_data_is_image_input(await mm_future, 1, skipped_image_indices=[0]) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid]) @@ -580,39 +549,35 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async( image_uuid2 = "my_uuid_2" conversation, mm_future, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid1, }, - "uuid": image_uuid1, - }, - { - "type": "image_pil", - "image_pil": ImageAsset("cherry_blossom").pil_image, - "uuid": image_uuid2, - }, - { - "type": "text", - "text": "What's in these images?" - }, - ], - }], + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] _assert_mm_data_is_image_input(await mm_future, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) @@ -627,40 +592,36 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( image_uuid2 = "my_uuid_2" conversation, mm_future, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": None, - "uuid": image_uuid1, - }, - { - "type": "image_pil", - "image_pil": None, - "uuid": image_uuid2, - }, - { - "type": "text", - "text": "What's in these images?" 
- }, - ], - }], + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": None, + "uuid": image_uuid1, + }, + { + "type": "image_pil", + "image_pil": None, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?", - }] - _assert_mm_data_is_image_input(await mm_future, - 2, - skipped_image_indices=[0, 1]) + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] + _assert_mm_data_is_image_input(await mm_future, 2, skipped_image_indices=[0, 1]) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2]) @@ -673,38 +634,34 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( image_uuid2 = "my_uuid_2" conversation, mm_future, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, }, - }, - { - "type": "image_pil", - "image_pil": ImageAsset("cherry_blossom").pil_image, - "uuid": image_uuid2, - }, - { - "type": "text", - "text": "What's in these images?" - }, - ], - }], + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + "uuid": image_uuid2, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] _assert_mm_data_is_image_input(await mm_future, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, image_uuid2]) @@ -716,16 +673,10 @@ def test_parse_chat_messages_empty_system( # Test string format conversation, _, _ = parse_chat_messages( [ - { - "role": "system", - "content": "" - }, + {"role": "system", "content": ""}, { "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }], + "content": [{"type": "text", "text": "Who are you?"}], }, ], mistral_model_config, @@ -733,29 +684,17 @@ def test_parse_chat_messages_empty_system( content_format="string", ) assert conversation == [ - { - "role": "system", - "content": "" - }, - { - "role": "user", - "content": "Who are you?" - }, + {"role": "system", "content": ""}, + {"role": "user", "content": "Who are you?"}, ] # Test openai format conversation, _, _ = parse_chat_messages( [ - { - "role": "system", - "content": "" - }, + {"role": "system", "content": ""}, { "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" - }], + "content": [{"type": "text", "text": "Who are you?"}], }, ], mistral_model_config, @@ -763,20 +702,8 @@ def test_parse_chat_messages_empty_system( content_format="openai", ) assert conversation == [ - { - "role": "system", - "content": [{ - "type": "text", - "text": "" - }] - }, - { - "role": "user", - "content": [{ - "type": "text", - "text": "Who are you?" 
- }] - }, + {"role": "system", "content": [{"type": "text", "text": ""}]}, + {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]}, ] @@ -787,31 +714,23 @@ async def test_parse_chat_messages_single_image_async( image_url, ): conversation, mm_future, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in the image?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in the image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in the image?" - }] + assert conversation == [ + {"role": "user", "content": "<|image_1|>\nWhat's in the image?"} + ] _assert_mm_data_is_image_input(await mm_future, 1) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[None]) @@ -822,37 +741,30 @@ def test_parse_chat_messages_multiple_images( image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "image_pil", - "image_pil": ImageAsset("cherry_blossom").pil_image, - }, - { - "type": "text", - "text": "What's in these images?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @@ -863,30 +775,26 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( ): uuid = "abcd" conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_pil", - "image_pil": None, - "uuid": uuid - }, - { - "type": "text", - "text": "What's in this image?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_pil", "image_pil": None, "uuid": uuid}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in this image?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + } + ] _assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0]) _assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid]) @@ -897,30 +805,26 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( ): uuid = "abcd" conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_embeds", - "image_embeds": None, - "uuid": uuid - }, - { - "type": "text", - "text": "What's in this image?" 
- }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_embeds", "image_embeds": None, "uuid": uuid}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in this image?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + } + ] assert mm_data is not None assert "image" in mm_data assert mm_data["image"] is None @@ -934,30 +838,26 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( ): uuid = "abcd" conversation, mm_future, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "image_embeds", - "image_embeds": None, - "uuid": uuid - }, - { - "type": "text", - "text": "What's in this image?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_embeds", "image_embeds": None, "uuid": uuid}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": "user", - "content": "<|image_1|>\nWhat's in this image?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?", + } + ] mm_data = await mm_future assert mm_data is not None assert "image" in mm_data @@ -972,37 +872,30 @@ async def test_parse_chat_messages_multiple_images_async( image_url, ): conversation, mm_future, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "image_pil", - "image_pil": ImageAsset("cherry_blossom").pil_image, - }, - { - "type": "text", - "text": "What's in these images?" 
- }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "image_pil", + "image_pil": ImageAsset("cherry_blossom").pil_image, + }, + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] _assert_mm_data_is_image_input(await mm_future, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @@ -1013,40 +906,29 @@ def test_parse_chat_messages_placeholder_already_in_prompt( image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to <|image_2|>?", # noqa: E501 - }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "What's in <|image_1|> and how does it compare to <|image_2|>?", # noqa: E501 + }, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "What's in <|image_1|> and how does it compare to <|image_2|>?", - }] + assert conversation == [ + { + "role": "user", + "content": "What's in <|image_1|> and how does it compare to <|image_2|>?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @@ -1057,42 +939,32 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": - "text", - "text": - "What's in <|image_1|> and how does it compare to the other one?", # noqa: E501 - }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "What's in <|image_1|> and how does it compare to " + "the other one?", + }, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " - "other one?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_2|>\nWhat's in <|image_1|> and how does it compare to " + "the other one?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @@ -1105,39 +977,18 @@ def test_parse_chat_messages_multiple_images_across_messages( conversation, mm_data, mm_uuids = parse_chat_messages( [ { - "role": - "user", + "role": "user", "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in this image?" 
- }, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, ], }, + {"role": "assistant", "content": "Some stuff."}, { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": - "user", + "role": "user", "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What about this one?" - }, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What about this one?"}, ], }, ], @@ -1147,18 +998,9 @@ def test_parse_chat_messages_multiple_images_across_messages( ) assert conversation == [ - { - "role": "user", - "content": "<|image_1|>\nWhat's in this image?" - }, - { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": "user", - "content": "<|image_2|>\nWhat about this one?" - }, + {"role": "user", "content": "<|image_1|>\nWhat's in this image?"}, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "<|image_2|>\nWhat about this one?"}, ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @@ -1173,41 +1015,26 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages( conversation, mm_data, mm_uuids = parse_chat_messages( [ { - "role": - "user", + "role": "user", "content": [ { "type": "image_url", - "image_url": { - "url": image_url - }, + "image_url": {"url": image_url}, "uuid": image_uuid, }, - { - "type": "text", - "text": "What's in this image?" - }, + {"type": "text", "text": "What's in this image?"}, ], }, + {"role": "assistant", "content": "Some stuff."}, { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": - "user", + "role": "user", "content": [ { "type": "image_url", - "image_url": { - "url": image_url - }, + "image_url": {"url": image_url}, "uuid": image_uuid, }, - { - "type": "text", - "text": "What about this one?" - }, + {"type": "text", "text": "What about this one?"}, ], }, ], @@ -1217,18 +1044,9 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages( ) assert conversation == [ - { - "role": "user", - "content": "<|image_1|>\nWhat's in this image?" - }, - { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": "user", - "content": "<|image_2|>\nWhat about this one?" - }, + {"role": "user", "content": "<|image_1|>\nWhat's in this image?"}, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "<|image_2|>\nWhat about this one?"}, ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) @@ -1242,19 +1060,10 @@ def test_parse_chat_messages_context_text_format( [ { "role": "user", - "content": [{ - "type": "text", - "text": "What's in this text?" - }], - }, - { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": "user", - "content": "What about this one?" + "content": [{"type": "text", "text": "What's in this text?"}], }, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "What about this one?"}, ], phi3v_model_config, phi3v_tokenizer, @@ -1264,24 +1073,15 @@ def test_parse_chat_messages_context_text_format( assert conversation == [ { "role": "user", - "content": [{ - "type": "text", - "text": "What's in this text?" - }], + "content": [{"type": "text", "text": "What's in this text?"}], }, { "role": "assistant", - "content": [{ - "type": "text", - "text": "Some stuff." 
- }], + "content": [{"type": "text", "text": "Some stuff."}], }, { "role": "user", - "content": [{ - "type": "text", - "text": "What about this one?" - }], + "content": [{"type": "text", "text": "What about this one?"}], }, ] assert mm_data is None @@ -1300,34 +1100,26 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( ) with pytest.raises(ValueError, match="At most"): parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": image_url}, }, - }, - { - "type": "image_url", - "image_url": { - "url": image_url + { + "type": "image_url", + "image_url": {"url": image_url}, }, - }, - { - "type": "image_url", - "image_url": { - "url": image_url + { + "type": "image_url", + "image_url": {"url": image_url}, }, - }, - { - "type": "text", - "text": "What's in these images?" - }, - ], - }], + {"type": "text", "text": "What's in these images?"}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", @@ -1348,45 +1140,28 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( parse_chat_messages( [ { - "role": - "user", + "role": "user", "content": [ { "type": "image_url", - "image_url": { - "url": image_url - }, - }, - { - "type": "text", - "text": "What's in this image?" + "image_url": {"url": image_url}, }, + {"type": "text", "text": "What's in this image?"}, ], }, + {"role": "assistant", "content": "Some stuff."}, { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": - "user", + "role": "user", "content": [ { "type": "image_url", - "image_url": { - "url": image_url - }, + "image_url": {"url": image_url}, }, { "type": "image_url", - "image_url": { - "url": image_url - }, - }, - { - "type": "text", - "text": "What about these two?" + "image_url": {"url": image_url}, }, + {"type": "text", "text": "What about these two?"}, ], }, ], @@ -1402,30 +1177,27 @@ def test_parse_chat_messages_multiple_images_uncommon_input( image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - "What's in these images?", - { - "image_url": image_url - }, - { - "image_url": image_url - }, - ], - }], + [ + { + "role": "user", + "content": [ + "What's in these images?", + {"image_url": image_url}, + {"image_url": image_url}, + ], + } + ], phi3v_model_config, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "<|image_1|>\n<|image_2|>\nWhat's in these images?", - }] + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @@ -1436,48 +1208,33 @@ def test_parse_chat_messages_multiple_images_interleave( image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "I need you to compare this image", - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "and this one" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Do they have differences?" 
- }, - ], - }], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "and this one"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Do they have differences?"}, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?", - }] + assert conversation == [ + { + "role": "user", + "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + } + ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @@ -1489,48 +1246,33 @@ async def test_parse_chat_messages_multiple_images_interleave_async( image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "I need you to compare this image", - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "and this one" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Do they have differences?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "I need you to compare this image", + }, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "and this one"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Do they have differences?"}, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?", - }] + assert conversation == [ + { + "role": "user", + "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + } + ] _assert_mm_data_is_image_input(await mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) @@ -1543,50 +1285,41 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( ): image_uuid = str(hash(image_url)) conversation, mm_data, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "I need you to compare this image", - }, - { - "type": "image_url", - "image_url": { - "url": image_url + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "I need you to compare this image", }, - "uuid": image_uuid, - }, - { - "type": "text", - "text": "and this one" - }, - { - "type": "image_url", - "image_url": { - "url": image_url + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, }, - "uuid": image_uuid, - }, - { - "type": "text", - "text": "Do they have differences?" 
- }, - ], - }], + {"type": "text", "text": "and this one"}, + { + "type": "image_url", + "image_url": {"url": image_url}, + "uuid": image_uuid, + }, + {"type": "text", "text": "Do they have differences?"}, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?", - }] + assert conversation == [ + { + "role": "user", + "content": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + } + ] _assert_mm_data_is_image_input(await mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) @@ -1599,43 +1332,19 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( conversation, mm_data, mm_uuids = parse_chat_messages( [ { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Be accurate." - }, + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Be accurate."}, ], }, + {"role": "assistant", "content": "Some stuff."}, { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, ], }, ], @@ -1649,20 +1358,14 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( "role": "user", "content": "What's on this image?\n<|image_1|>\nBe accurate.", }, - { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": "user", - "content": "What's on this image?\n<|image_2|>" - }, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "What's on this image?\n<|image_2|>"}, ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None]) -def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave( # noqa: E501 +def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave( phi3v_model_config_mm_interleaved, phi3v_tokenizer, image_url, @@ -1671,43 +1374,25 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl conversation, mm_data, mm_uuids = parse_chat_messages( [ { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, + {"type": "text", "text": "What's on this image?"}, { "type": "image_url", - "image_url": { - "url": image_url - }, + "image_url": {"url": image_url}, "uuid": image_uuid, }, - { - "type": "text", - "text": "Be accurate." - }, + {"type": "text", "text": "Be accurate."}, ], }, + {"role": "assistant", "content": "Some stuff."}, { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" 
- }, + {"type": "text", "text": "What's on this image?"}, { "type": "image_url", - "image_url": { - "url": image_url - }, + "image_url": {"url": image_url}, "uuid": image_uuid, }, ], @@ -1723,14 +1408,8 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl "role": "user", "content": "What's on this image?\n<|image_1|>\nBe accurate.", }, - { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": "user", - "content": "What's on this image?\n<|image_2|>" - }, + {"role": "assistant", "content": "Some stuff."}, + {"role": "user", "content": "What's on this image?\n<|image_2|>"}, ] _assert_mm_data_is_image_input(mm_data, 2) _assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid]) @@ -1746,59 +1425,22 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( conversation, mm_data, mm_uuids = parse_chat_messages( [ { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "Now listen to this audio" - }, - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "Now listen to this audio"}, + {"type": "audio_url", "audio_url": {"url": audio_url}}, ], }, + {"role": "assistant", "content": "Some stuff."}, { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "And what's in the video?" - }, - { - "type": "video_url", - "video_url": { - "url": video_url - } - }, + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "And what's in the video?"}, + {"type": "video_url", "video_url": {"url": video_url}}, ], }, ], @@ -1809,35 +1451,25 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( assert conversation == [ { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 - }, - { - "role": "assistant", - "content": "Some stuff." 
+ "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nNow listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", }, + {"role": "assistant", "content": "Some stuff."}, { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nAnd what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", }, ] _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) - _assert_mm_uuids(mm_uuids, - 2, - modality="image", - expected_uuids=[None, None]) + _assert_mm_uuids(mm_uuids, 2, modality="image", expected_uuids=[None, None]) _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=[None]) _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) -def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave( # noqa: E501 +def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave( qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, image_url, @@ -1847,61 +1479,36 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl conversation, mm_data, mm_uuids = parse_chat_messages( [ { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, + {"type": "text", "text": "What's on this image?"}, { "type": "image_url", - "image_url": { - "url": image_url - }, + "image_url": {"url": image_url}, "uuid": "image_123", }, - { - "type": "text", - "text": "Now listen to this audio" - }, + {"type": "text", "text": "Now listen to this audio"}, { "type": "audio_url", - "audio_url": { - "url": audio_url - }, + "audio_url": {"url": audio_url}, "uuid": "audio_123", }, ], }, + {"role": "assistant", "content": "Some stuff."}, { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, + {"type": "text", "text": "What's on this image?"}, { "type": "image_url", - "image_url": { - "url": image_url - }, + "image_url": {"url": image_url}, "uuid": "image_123", }, - { - "type": "text", - "text": "And what's in the video?" - }, + {"type": "text", "text": "And what's in the video?"}, { "type": "video_url", - "video_url": { - "url": video_url - }, + "video_url": {"url": video_url}, "uuid": "video_123", }, ], @@ -1914,38 +1521,24 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl assert conversation == [ { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 - }, - { - "role": "assistant", - "content": "Some stuff." 
+ "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nNow listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", }, + {"role": "assistant", "content": "Some stuff."}, { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nAnd what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", }, ] _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) - _assert_mm_uuids(mm_uuids, - 2, - modality="image", - expected_uuids=["image_123", "image_123"]) - _assert_mm_uuids(mm_uuids, - 1, - modality="video", - expected_uuids=["video_123"]) - _assert_mm_uuids(mm_uuids, - 1, - modality="audio", - expected_uuids=["audio_123"]) + _assert_mm_uuids( + mm_uuids, 2, modality="image", expected_uuids=["image_123", "image_123"] + ) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=["audio_123"]) def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501 @@ -1958,22 +1551,15 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes conversation, mm_data, mm_uuids = parse_chat_messages( [ { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, + {"type": "text", "text": "What's on this image?"}, { "type": "image_url", "image_url": None, "uuid": "image_123", }, - { - "type": "text", - "text": "Now listen to this audio" - }, + {"type": "text", "text": "Now listen to this audio"}, { "type": "audio_url", "audio_url": None, @@ -1981,27 +1567,17 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes }, ], }, + {"role": "assistant", "content": "Some stuff."}, { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, + {"type": "text", "text": "What's on this image?"}, { "type": "image_url", "image_url": None, "uuid": "image_123", }, - { - "type": "text", - "text": "And what's in the video?" - }, + {"type": "text", "text": "And what's in the video?"}, { "type": "video_url", "video_url": None, @@ -2017,47 +1593,28 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes assert conversation == [ { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 - }, - { - "role": "assistant", - "content": "Some stuff." 
+ "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nNow listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", }, + {"role": "assistant", "content": "Some stuff."}, { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nAnd what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", }, ] - _assert_mm_data_inputs(mm_data, { - "image": 2, - "video": 1, - "audio": 1 - }, - skipped_media_indices={ - "image": [0, 1], - "video": [0], - "audio": [0] - }) - _assert_mm_uuids(mm_uuids, - 2, - modality="image", - expected_uuids=["image_123", "image_123"]) - _assert_mm_uuids(mm_uuids, - 1, - modality="video", - expected_uuids=["video_123"]) - _assert_mm_uuids(mm_uuids, - 1, - modality="audio", - expected_uuids=["audio_123"]) + _assert_mm_data_inputs( + mm_data, + {"image": 2, "video": 1, "audio": 1}, + skipped_media_indices={"image": [0, 1], "video": [0], "audio": [0]}, + ) + _assert_mm_uuids( + mm_uuids, 2, modality="image", expected_uuids=["image_123", "image_123"] + ) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=["audio_123"]) def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501 @@ -2070,59 +1627,28 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message conversation, mm_data, mm_uuids = parse_chat_messages( [ { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, + {"type": "text", "text": "What's on this image?"}, { "type": "image_url", - "image_url": { - "url": image_url - }, + "image_url": {"url": image_url}, "uuid": "image_123", }, - { - "type": "text", - "text": "Now listen to this audio" - }, - { - "type": "audio_url", - "audio_url": { - "url": audio_url - } - }, + {"type": "text", "text": "Now listen to this audio"}, + {"type": "audio_url", "audio_url": {"url": audio_url}}, ], }, + {"role": "assistant", "content": "Some stuff."}, { - "role": "assistant", - "content": "Some stuff." - }, - { - "role": - "user", + "role": "user", "content": [ - { - "type": "text", - "text": "What's on this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "And what's in the video?" - }, + {"type": "text", "text": "What's on this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "And what's in the video?"}, { "type": "video_url", - "video_url": { - "url": video_url - }, + "video_url": {"url": video_url}, "uuid": "video_123", }, ], @@ -2135,34 +1661,21 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message assert conversation == [ { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", # noqa: E501 - }, - { - "role": "assistant", - "content": "Some stuff." 
+ "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nNow listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>", }, + {"role": "assistant", "content": "Some stuff."}, { - "role": - "user", - "content": - "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" - "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", + "role": "user", + "content": "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>" + "\nAnd what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>", }, ] _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) - _assert_mm_uuids(mm_uuids, - 2, - modality="image", - expected_uuids=["image_123", None]) - _assert_mm_uuids(mm_uuids, - 1, - modality="video", - expected_uuids=["video_123"]) + _assert_mm_uuids(mm_uuids, 2, modality="image", expected_uuids=["image_123", None]) + _assert_mm_uuids(mm_uuids, 1, modality="video", expected_uuids=["video_123"]) _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) @@ -2172,36 +1685,25 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( image_url, ): with pytest.raises( - ValueError, - match=r"Found more '<|image_1|>' placeholders in input prompt " - "than actual multimodal data items.", + ValueError, + match=r"Found more '<|image_1|>' placeholders in input prompt " + "than actual multimodal data items.", ): parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": - "text", - "text": - "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 - "Do they have differences?", - }, - ], - }], + [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?", + }, + ], + } + ], phi3v_model_config_mm_interleaved, phi3v_tokenizer, content_format="string", @@ -2230,7 +1732,8 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): hf_overrides=model_info.hf_overrides, skip_tokenizer_init=model_info.skip_tokenizer_init, enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype) + dtype=model_info.dtype, + ) # Build the tokenizer tokenizer = get_tokenizer( @@ -2238,14 +1741,20 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): trust_remote_code=model_config.trust_remote_code, ) - tools = ([{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema, - }, - }] if use_tools else None) + tools = ( + [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ] + if use_tools + else None + ) # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( @@ -2263,33 +1772,38 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ( QWEN2VL_MODEL_ID, { - "add_vision_id", "add_generation_prompt", - "continue_final_message", "tools" + "add_vision_id", + "add_generation_prompt", + "continue_final_message", + "tools", }, ), ( QWEN3_MODEL_ID, { - "enable_thinking", 
"add_generation_prompt", - "continue_final_message", "tools" + "enable_thinking", + "add_generation_prompt", + "continue_final_message", + "tools", }, ), ], ) -def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, - expected_kwargs): +def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwargs): """checks that chat_template is a dict type for HF models.""" model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") - tools = ([{ - "type": "function", - "function": { - "name": "dummy_function_name", - "description": "This is a dummy function", - "parameters": sample_json_schema, - }, - }]) + tools = [ + { + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema, + }, + } + ] chat_template_kwargs = { # both unused @@ -2317,7 +1831,8 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, hf_overrides=model_info.hf_overrides, skip_tokenizer_init=model_info.skip_tokenizer_init, enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype) + dtype=model_info.dtype, + ) # Build the tokenizer tokenizer = get_tokenizer( @@ -2342,17 +1857,17 @@ def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, # NOTE: Qwen2-Audio default chat template is specially defined inside # processor class instead of using `tokenizer_config.json` -# yapf: disable @pytest.mark.parametrize( ("model", "expected_format"), - [(PHI3V_MODEL_ID, "string"), - (QWEN2VL_MODEL_ID, "openai"), - (QWEN25VL_MODEL_ID, "openai"), - (ULTRAVOX_MODEL_ID, "string"), - (QWEN2AUDIO_MODEL_ID, "openai"), - (LLAMA_GUARD_MODEL_ID, "openai")], + [ + (PHI3V_MODEL_ID, "string"), + (QWEN2VL_MODEL_ID, "openai"), + (QWEN25VL_MODEL_ID, "openai"), + (ULTRAVOX_MODEL_ID, "string"), + (QWEN2AUDIO_MODEL_ID, "openai"), + (LLAMA_GUARD_MODEL_ID, "openai"), + ], ) -# yapf: enable def test_resolve_content_format_hf_defined(model, expected_format): model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -2366,7 +1881,8 @@ def test_resolve_content_format_hf_defined(model, expected_format): hf_overrides=model_info.hf_overrides, skip_tokenizer_init=model_info.skip_tokenizer_init, enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype) + dtype=model_info.dtype, + ) tokenizer = get_tokenizer( model, @@ -2398,18 +1914,18 @@ def test_resolve_content_format_hf_defined(model, expected_format): assert resolved_format == expected_format -# yapf: disable @pytest.mark.parametrize( ("model", "expected_format"), - [("Salesforce/blip2-opt-2.7b", "string"), - ("facebook/chameleon-7b", "string"), - ("deepseek-ai/deepseek-vl2-tiny", "string"), - ("adept/fuyu-8b", "string"), - ("google/paligemma-3b-mix-224", "string"), - ("Qwen/Qwen-VL", "string"), - ("Qwen/Qwen-VL-Chat", "string")], + [ + ("Salesforce/blip2-opt-2.7b", "string"), + ("facebook/chameleon-7b", "string"), + ("deepseek-ai/deepseek-vl2-tiny", "string"), + ("adept/fuyu-8b", "string"), + ("google/paligemma-3b-mix-224", "string"), + ("Qwen/Qwen-VL", "string"), + ("Qwen/Qwen-VL-Chat", "string"), + ], ) -# yapf: enable def test_resolve_content_format_fallbacks(model, expected_format): model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -2423,7 +1939,8 @@ def test_resolve_content_format_fallbacks(model, expected_format): hf_overrides=model_info.hf_overrides, skip_tokenizer_init=model_info.skip_tokenizer_init, 
enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype) + dtype=model_info.dtype, + ) tokenizer = get_tokenizer( model_config.tokenizer, @@ -2455,30 +1972,30 @@ def test_resolve_content_format_fallbacks(model, expected_format): assert resolved_format == expected_format -# yapf: disable @pytest.mark.parametrize( ("template_path", "expected_format"), - [("template_alpaca.jinja", "string"), - ("template_baichuan.jinja", "string"), - ("template_chatglm.jinja", "string"), - ("template_chatglm2.jinja", "string"), - ("template_chatml.jinja", "string"), - ("template_dse_qwen2_vl.jinja", "openai"), - ("template_falcon_180b.jinja", "string"), - ("template_falcon.jinja", "string"), - ("template_inkbot.jinja", "string"), - ("template_teleflm.jinja", "string"), - ("template_vlm2vec_phi3v.jinja", "openai"), - ("template_vlm2vec_qwen2vl.jinja", "openai"), - ("tool_chat_template_granite_20b_fc.jinja", "string"), - ("tool_chat_template_hermes.jinja", "string"), - ("tool_chat_template_internlm2_tool.jinja", "string"), - ("tool_chat_template_llama3.1_json.jinja", "openai"), - ("tool_chat_template_llama3.2_json.jinja", "openai"), - ("tool_chat_template_mistral_parallel.jinja", "string"), - ("tool_chat_template_mistral.jinja", "string")], + [ + ("template_alpaca.jinja", "string"), + ("template_baichuan.jinja", "string"), + ("template_chatglm.jinja", "string"), + ("template_chatglm2.jinja", "string"), + ("template_chatml.jinja", "string"), + ("template_dse_qwen2_vl.jinja", "openai"), + ("template_falcon_180b.jinja", "string"), + ("template_falcon.jinja", "string"), + ("template_inkbot.jinja", "string"), + ("template_teleflm.jinja", "string"), + ("template_vlm2vec_phi3v.jinja", "openai"), + ("template_vlm2vec_qwen2vl.jinja", "openai"), + ("tool_chat_template_granite_20b_fc.jinja", "string"), + ("tool_chat_template_hermes.jinja", "string"), + ("tool_chat_template_internlm2_tool.jinja", "string"), + ("tool_chat_template_llama3.1_json.jinja", "openai"), + ("tool_chat_template_llama3.2_json.jinja", "openai"), + ("tool_chat_template_mistral_parallel.jinja", "string"), + ("tool_chat_template_mistral.jinja", "string"), + ], ) -# yapf: enable def test_resolve_content_format_examples(template_path, expected_format): model_config = ModelConfig( PHI3V_MODEL_ID, # Dummy @@ -2511,40 +2028,34 @@ def test_resolve_content_format_examples(template_path, expected_format): assert resolved_format == expected_format -def test_parse_chat_messages_include_thinking_chunk(mistral_model_config, - mistral_tokenizer): - messages = [{ - "role": - "system", - "content": [{ - "type": "text", - "text": "You are a helpful assistant." - }, { - "type": - "thinking", - "closed": - True, - "thinking": - "Only return the answer when you are confident." - }] - }, { - "role": "user", - "content": "What is 2+2?" - }, { - "role": - "assistant", - "content": [{ - "type": "text", - "text": "Let me think about it." 
- }, { - "type": "thinking", - "closed": True, - "thinking": "2+2 = 4" - }, { - "type": "text", - "text": "The answer is 4.", - }], - }] +def test_parse_chat_messages_include_thinking_chunk( + mistral_model_config, mistral_tokenizer +): + messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + { + "type": "thinking", + "closed": True, + "thinking": "Only return the answer when you are confident.", + }, + ], + }, + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me think about it."}, + {"type": "thinking", "closed": True, "thinking": "2+2 = 4"}, + { + "type": "text", + "text": "The answer is 4.", + }, + ], + }, + ] conversation_with_thinking, _, _ = parse_chat_messages( messages, @@ -2553,122 +2064,105 @@ def test_parse_chat_messages_include_thinking_chunk(mistral_model_config, content_format="openai", ) - expected_conversation = [{ - "role": - "system", - "content": [{ - "type": "text", - "text": "You are a helpful assistant." - }, { - "type": "text", - "text": "Only return the answer when you are confident." - }], - }, { - "role": - "user", - "content": [{ - "type": "text", - "text": "What is 2+2?" - }], - }, { - "role": - "assistant", - "content": [ - { - "type": "text", - "text": "Let me think about it." - }, - { - "type": "text", - "text": "2+2 = 4" - }, - { - "type": "text", - "text": "The answer is 4." - }, - ] - }] + expected_conversation = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + { + "type": "text", + "text": "Only return the answer when you are confident.", + }, + ], + }, + { + "role": "user", + "content": [{"type": "text", "text": "What is 2+2?"}], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me think about it."}, + {"type": "text", "text": "2+2 = 4"}, + {"type": "text", "text": "The answer is 4."}, + ], + }, + ] assert conversation_with_thinking == expected_conversation def test_apply_mistral_chat_template_thinking_chunk(): - # Moved import here to avoid yapf and isort conflicts - from vllm.entrypoints.chat_utils import apply_mistral_chat_template - messages = [{ - "role": - "system", - "content": [{ - "type": "text", - "text": "You are a helpful assistant." - }, { - "type": - "thinking", - "closed": - True, - "thinking": - "Only return the answer when you are confident." - }] - }, { - "role": "user", - "content": "What is 2+2?" - }, { - "role": - "assistant", - "content": [{ - "type": "text", - "text": "Let me think about it." - }, { - "type": "thinking", - "closed": True, - "thinking": "2+2 = 4" - }, { - "type": "text", - "text": "The answer is 4.", - }], - }, { - "role": "user", - "content": "Thanks, what is 3+3?" - }] + messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + { + "type": "thinking", + "closed": True, + "thinking": "Only return the answer when you are confident.", + }, + ], + }, + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me think about it."}, + {"type": "thinking", "closed": True, "thinking": "2+2 = 4"}, + { + "type": "text", + "text": "The answer is 4.", + }, + ], + }, + {"role": "user", "content": "Thanks, what is 3+3?"}, + ] # TODO(Julien): upon model release change to a tokenizer already configured. 
# ================================================================= mistral_tokenizer = MistralTokenizer.from_pretrained( - "mistralai/Devstral-Small-2507") + "mistralai/Devstral-Small-2507" + ) assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) # Add think special tokens to the tokenizer mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( - rank=35, is_control=True, token_str=SpecialTokens.begin_think.value) + rank=35, is_control=True, token_str=SpecialTokens.begin_think.value + ) mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( - rank=36, is_control=True, token_str=SpecialTokens.end_think.value) + rank=36, is_control=True, token_str=SpecialTokens.end_think.value + ) mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { k: v - for k, v in - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() + for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() if v not in {35, 36} } mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.begin_think.value] = 35 + SpecialTokens.begin_think.value + ] = 35 mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.end_think.value] = 36 + SpecialTokens.end_think.value + ] = 36 mistral_tokenizer.instruct.BEGIN_THINK = 35 mistral_tokenizer.instruct.END_THINK = 36 # ================================================================= - tokens_ids = apply_mistral_chat_template(mistral_tokenizer, - messages, - chat_template=None, - tools=None) + tokens_ids = apply_mistral_chat_template( + mistral_tokenizer, messages, chat_template=None, tools=None + ) string_tokens = mistral_tokenizer.mistral.decode( - tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP) + tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP + ) expected_tokens = ( r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the" r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]" r"[INST]What is 2+2?[/INST]" r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>" - r"[INST]Thanks, what is 3+3?[/INST]") + r"[INST]Thanks, what is 3+3?[/INST]" + ) assert string_tokens == expected_tokens @@ -2679,37 +2173,33 @@ def test_parse_chat_messages_single_empty_audio_with_uuid( ): audio_uuid = "abcd" conversation, mm_data, mm_uuids = parse_chat_messages( - [{ - "role": - "user", - "content": [ - { - "type": "input_audio", - "input_audio": {}, - "uuid": audio_uuid, - }, - { - "type": "text", - "text": "What does the audio say?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": {}, + "uuid": audio_uuid, + }, + {"type": "text", "text": "What does the audio say?"}, + ], + } + ], qwen2_audio_model_config, qwen2_audio_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the audio say?" 
- }] + assert conversation == [ + { + "role": "user", + "content": "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the " + "audio say?", + } + ] _assert_mm_data_inputs(mm_data, {"audio": 1}) - _assert_mm_uuids(mm_uuids, - 1, - modality="audio", - expected_uuids=[audio_uuid]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[audio_uuid]) @pytest.mark.asyncio @@ -2719,34 +2209,30 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async( ): audio_uuid = "abcd" conversation, mm_future, mm_uuids = parse_chat_messages_futures( - [{ - "role": - "user", - "content": [ - { - "type": "input_audio", - "input_audio": {}, - "uuid": audio_uuid, - }, - { - "type": "text", - "text": "What does the audio say?" - }, - ], - }], + [ + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": {}, + "uuid": audio_uuid, + }, + {"type": "text", "text": "What does the audio say?"}, + ], + } + ], qwen2_audio_model_config, qwen2_audio_tokenizer, content_format="string", ) - assert conversation == [{ - "role": - "user", - "content": - "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the audio say?" - }] + assert conversation == [ + { + "role": "user", + "content": "Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat does the " + "audio say?", + } + ] _assert_mm_data_inputs(await mm_future, {"audio": 1}) - _assert_mm_uuids(mm_uuids, - 1, - modality="audio", - expected_uuids=[audio_uuid]) + _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[audio_uuid]) diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py index 2afe9758c2ad..6ad18fa08bc4 100644 --- a/tests/entrypoints/test_context.py +++ b/tests/entrypoints/test_context.py @@ -48,10 +48,9 @@ def create_mock_request_output( ) -async def generate_mock_outputs(num_turns, - prompt_token_counts, - output_token_counts, - cached_token_counts=None): +async def generate_mock_outputs( + num_turns, prompt_token_counts, output_token_counts, cached_token_counts=None +): """Generate a sequence of mock RequestOutput objects to simulate multiple turns.""" if cached_token_counts is None: @@ -73,8 +72,9 @@ async def generate_mock_outputs(num_turns, @pytest.fixture def mock_parser(): """Set up a mock parser for tests.""" - with patch("vllm.entrypoints.context.get_streamable_parser_for_assistant" - ) as mock_parser_factory: + with patch( + "vllm.entrypoints.context.get_streamable_parser_for_assistant" + ) as mock_parser_factory: # Create a mock parser object parser = MagicMock() parser.messages = [] @@ -124,9 +124,9 @@ async def test_multi_turn_token_counting(): prompt_token_counts = [5, 15, 20] output_token_counts = [3, 4, 5] cached_token_counts = [0, 5, 15] - mock_generator = generate_mock_outputs(3, prompt_token_counts, - output_token_counts, - cached_token_counts) + mock_generator = generate_mock_outputs( + 3, prompt_token_counts, output_token_counts, cached_token_counts + ) # First turn - initial prompt and response mock_output1 = await async_next(mock_generator) @@ -251,7 +251,7 @@ async def test_single_turn_no_tool_output(): """Test that first turn never generates tool output tokens.""" context = HarmonyContext( messages=[], - available_tools=["browser"] # Tools available + available_tools=["browser"], # Tools available ) # Even with large prompt in first turn, no tool tokens should be counted @@ -333,21 +333,24 @@ async def test_streaming_multi_turn_token_counting(mock_parser): output_token_ids=[101], # Single token num_cached_tokens=0, finished=False, # Not end of 
message yet - )) + ) + ) # Second token of first turn context.append_output( create_mock_request_output( output_token_ids=[102], finished=False, - )) + ) + ) # Last token of first turn (finished=True signals end of message) context.append_output( create_mock_request_output( output_token_ids=[103], finished=True, # End of message - )) + ) + ) # Check token counts after first turn assert context.num_prompt_tokens == 3 # Initial prompt tokens @@ -362,25 +365,36 @@ async def test_streaming_multi_turn_token_counting(mock_parser): # First token of second turn context.append_output( create_mock_request_output( - prompt_token_ids=[1, 2, 3, 101, 102, 103, 4, - 5], # 8 tokens (includes previous) + prompt_token_ids=[ + 1, + 2, + 3, + 101, + 102, + 103, + 4, + 5, + ], # 8 tokens (includes previous) output_token_ids=[201], num_cached_tokens=3, # Some tokens cached finished=False, - )) + ) + ) # More tokens in reasoning channel context.append_output( create_mock_request_output( output_token_ids=[202], finished=False, - )) + ) + ) context.append_output( create_mock_request_output( output_token_ids=[203], finished=True, # End of reasoning message - )) + ) + ) # Check counts after second turn (reasoning message) assert context.num_prompt_tokens == 3 + 8 # Initial + second prompt @@ -399,18 +413,32 @@ async def test_streaming_multi_turn_token_counting(mock_parser): context.append_output( create_mock_request_output( prompt_token_ids=[ - 1, 2, 3, 101, 102, 103, 4, 5, 201, 202, 203, 6, 7 + 1, + 2, + 3, + 101, + 102, + 103, + 4, + 5, + 201, + 202, + 203, + 6, + 7, ], # 13 tokens output_token_ids=[301], num_cached_tokens=8, # More cached tokens finished=False, - )) + ) + ) context.append_output( create_mock_request_output( output_token_ids=[302], finished=True, - )) + ) + ) # Final token counts check assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts @@ -421,8 +449,9 @@ async def test_streaming_multi_turn_token_counting(mock_parser): # Additional tool tokens from third turn # Formula: this turn prompt - last turn prompt - last turn output additional_tool_tokens = 13 - 8 - 3 # = 2 - assert context.num_tool_output_tokens == expected_tool_tokens \ - + additional_tool_tokens + assert ( + context.num_tool_output_tokens == expected_tool_tokens + additional_tool_tokens + ) @pytest.mark.asyncio @@ -442,8 +471,7 @@ async def test_streaming_message_synchronization(mock_parser): recipient=Role.ASSISTANT, ) ] - context = StreamingHarmonyContext(messages=initial_messages, - available_tools=[]) + context = StreamingHarmonyContext(messages=initial_messages, available_tools=[]) # Verify initial state assert len(context._messages) == 1 @@ -461,9 +489,10 @@ async def test_streaming_message_synchronization(mock_parser): # This should trigger the message synchronization logic context.append_output( - create_mock_request_output(prompt_token_ids=[1, 2, 3], - output_token_ids=[101], - finished=False)) + create_mock_request_output( + prompt_token_ids=[1, 2, 3], output_token_ids=[101], finished=False + ) + ) # Verify that messages were synchronized assert len(context._messages) == 2 @@ -485,12 +514,13 @@ async def test_streaming_message_synchronization(mock_parser): author=Author(role=Role.ASSISTANT, name="assistant"), content=[TextContent(text="Response 4")], recipient=Role.USER, - )) + ) + ) # Create another output to trigger synchronization again - mock_output2 = create_mock_request_output(prompt_token_ids=[1, 2, 3], - output_token_ids=[102], - finished=True) + mock_output2 = create_mock_request_output( + 
prompt_token_ids=[1, 2, 3], output_token_ids=[102], finished=True + ) context.append_output(mock_output2) diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py index 1f55b1fba613..f93978c3e6e7 100644 --- a/tests/entrypoints/test_renderer.py +++ b/tests/entrypoints/test_renderer.py @@ -21,7 +21,6 @@ class MockModelConfig: class MockTokenizerResult: - def __init__(self, input_ids): self.input_ids = input_ids @@ -45,9 +44,11 @@ def mock_async_tokenizer(): @pytest.fixture def renderer(mock_model_config, mock_tokenizer): - return CompletionRenderer(model_config=mock_model_config, - tokenizer=mock_tokenizer, - async_tokenizer_pool={}) + return CompletionRenderer( + model_config=mock_model_config, + tokenizer=mock_tokenizer, + async_tokenizer_pool={}, + ) class TestRenderPrompt: @@ -57,7 +58,8 @@ class TestRenderPrompt: async def test_token_input(self, renderer): tokens = [101, 7592, 2088] results = await renderer.render_prompt( - prompt_or_prompts=tokens, config=RenderConfig(max_length=100)) + prompt_or_prompts=tokens, config=RenderConfig(max_length=100) + ) assert len(results) == 1 assert results[0]["prompt_token_ids"] == tokens @@ -66,7 +68,8 @@ async def test_token_input(self, renderer): async def test_token_list_input(self, renderer): token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] results = await renderer.render_prompt( - prompt_or_prompts=token_lists, config=RenderConfig(max_length=100)) + prompt_or_prompts=token_lists, config=RenderConfig(max_length=100) + ) assert len(results) == 3 assert results[0]["prompt_token_ids"] == [101, 7592, 2088] @@ -75,14 +78,12 @@ async def test_token_list_input(self, renderer): @pytest.mark.asyncio async def test_text_input(self, renderer, mock_async_tokenizer): - mock_async_tokenizer.return_value = MockTokenizerResult( - [101, 7592, 2088]) - renderer.async_tokenizer_pool[ - renderer.tokenizer] = mock_async_tokenizer + mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer results = await renderer.render_prompt( - prompt_or_prompts="Hello world", - config=RenderConfig(max_length=100)) + prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) + ) assert len(results) == 1 assert results[0]["prompt_token_ids"] == [101, 7592, 2088] @@ -90,15 +91,13 @@ async def test_text_input(self, renderer, mock_async_tokenizer): @pytest.mark.asyncio async def test_text_list_input(self, renderer, mock_async_tokenizer): - mock_async_tokenizer.return_value = MockTokenizerResult( - [101, 7592, 2088]) - renderer.async_tokenizer_pool[ - renderer.tokenizer] = mock_async_tokenizer + mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088]) + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer text_list_input = ["Hello world", "How are you?", "Good morning"] results = await renderer.render_prompt( - prompt_or_prompts=text_list_input, - config=RenderConfig(max_length=100)) + prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100) + ) assert len(results) == 3 for result in results: @@ -107,31 +106,31 @@ async def test_text_list_input(self, renderer, mock_async_tokenizer): @pytest.mark.asyncio async def test_no_truncation(self, renderer, mock_async_tokenizer): - mock_async_tokenizer.return_value = MockTokenizerResult( - [101, 7592, 2088]) - renderer.async_tokenizer_pool[ - renderer.tokenizer] = mock_async_tokenizer + mock_async_tokenizer.return_value = 
MockTokenizerResult([101, 7592, 2088]) + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer results = await renderer.render_prompt( - prompt_or_prompts="Hello world", - config=RenderConfig(max_length=100)) + prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) + ) assert len(results) == 1 call_args = mock_async_tokenizer.call_args - assert "truncation" not in call_args.kwargs or call_args.kwargs[ - "truncation"] is False + assert ( + "truncation" not in call_args.kwargs + or call_args.kwargs["truncation"] is False + ) @pytest.mark.asyncio async def test_truncation_positive(self, renderer, mock_async_tokenizer): mock_async_tokenizer.return_value = MockTokenizerResult( - [101, 7592, 2088]) # Truncated - renderer.async_tokenizer_pool[ - renderer.tokenizer] = mock_async_tokenizer + [101, 7592, 2088] + ) # Truncated + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer - results = await renderer.render_prompt(prompt_or_prompts="Hello world", - config=RenderConfig( - max_length=100, - truncate_prompt_tokens=50)) + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=100, truncate_prompt_tokens=50), + ) assert len(results) == 1 call_args = mock_async_tokenizer.call_args @@ -142,14 +141,14 @@ async def test_truncation_positive(self, renderer, mock_async_tokenizer): async def test_truncation_negative(self, renderer, mock_async_tokenizer): # Test that negative truncation uses model's max_model_len mock_async_tokenizer.return_value = MockTokenizerResult( - [101, 7592, 2088]) # Truncated to max_model_len - renderer.async_tokenizer_pool[ - renderer.tokenizer] = mock_async_tokenizer + [101, 7592, 2088] + ) # Truncated to max_model_len + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer - results = await renderer.render_prompt(prompt_or_prompts="Hello world", - config=RenderConfig( - max_length=200, - truncate_prompt_tokens=-1)) + results = await renderer.render_prompt( + prompt_or_prompts="Hello world", + config=RenderConfig(max_length=200, truncate_prompt_tokens=-1), + ) assert len(results) == 1 call_args = mock_async_tokenizer.call_args @@ -159,12 +158,11 @@ async def test_truncation_negative(self, renderer, mock_async_tokenizer): @pytest.mark.asyncio async def test_token_truncation_last_elements(self, renderer): # Test that token truncation keeps the last N elements - long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, - 109] # 10 tokens - results = await renderer.render_prompt(prompt_or_prompts=long_tokens, - config=RenderConfig( - max_length=100, - truncate_prompt_tokens=5)) + long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens + results = await renderer.render_prompt( + prompt_or_prompts=long_tokens, + config=RenderConfig(max_length=100, truncate_prompt_tokens=5), + ) assert len(results) == 1 # Should keep the last 5 tokens: [105, 106, 107, 108, 109] @@ -175,30 +173,30 @@ async def test_max_length_exceeded(self, renderer): long_tokens = list(range(150)) # Exceeds max_model_len=100 with pytest.raises(ValueError, match="maximum context length"): - await renderer.render_prompt(prompt_or_prompts=long_tokens, - config=RenderConfig(max_length=100)) + await renderer.render_prompt( + prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100) + ) @pytest.mark.asyncio async def test_no_tokenizer_for_text(self, mock_model_config): renderer_no_tokenizer = CompletionRenderer( - model_config=mock_model_config, - tokenizer=None, - 
async_tokenizer_pool={}) + model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={} + ) with pytest.raises(ValueError, match="No tokenizer available"): await renderer_no_tokenizer.render_prompt( - prompt_or_prompts="Hello world", - config=RenderConfig(max_length=100)) + prompt_or_prompts="Hello world", config=RenderConfig(max_length=100) + ) @pytest.mark.asyncio async def test_token_input_with_needs_detokenization( - self, renderer, mock_async_tokenizer): + self, renderer, mock_async_tokenizer + ): # When needs_detokenization=True for token inputs, renderer should # use the async tokenizer to decode and include the original text # in the returned prompt object. mock_async_tokenizer.decode = AsyncMock(return_value="decoded text") - renderer.async_tokenizer_pool[ - renderer.tokenizer] = mock_async_tokenizer + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer tokens = [1, 2, 3, 4] results = await renderer.render_prompt( @@ -213,7 +211,6 @@ async def test_token_input_with_needs_detokenization( class TestRenderEmbedPrompt: - def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes: """Helper to create base64-encoded tensor bytes""" buffer = io.BytesIO() @@ -244,9 +241,7 @@ async def test_multiple_prompt_embeds(self, renderer): torch.randn(8, 512, dtype=torch.float32), torch.randn(12, 512, dtype=torch.float32), ] - embed_bytes_list = [ - self._create_test_embed_bytes(t) for t in test_tensors - ] + embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors] results = await renderer.render_prompt_and_embeds( prompt_embeds=embed_bytes_list, @@ -307,13 +302,10 @@ async def test_prompt_embed_squeeze_batch_dim(self, renderer): assert results[0]["prompt_embeds"].shape == (10, 768) @pytest.mark.asyncio - async def test_both_prompts_and_embeds(self, renderer, - mock_async_tokenizer): + async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer): # Set up text tokenization - mock_async_tokenizer.return_value = MockTokenizerResult( - [101, 102, 103]) - renderer.async_tokenizer_pool[ - renderer.tokenizer] = mock_async_tokenizer + mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103]) + renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer # Create embed test_tensor = torch.randn(5, 256, dtype=torch.float32) diff --git a/tests/entrypoints/test_ssl_cert_refresher.py b/tests/entrypoints/test_ssl_cert_refresher.py index 33ad2cfd3a33..b56fbd9fee7e 100644 --- a/tests/entrypoints/test_ssl_cert_refresher.py +++ b/tests/entrypoints/test_ssl_cert_refresher.py @@ -11,7 +11,6 @@ class MockSSLContext(SSLContext): - def __init__(self): self.load_cert_chain_count = 0 self.load_ca_count = 0 @@ -34,7 +33,7 @@ def load_verify_locations( def create_file() -> str: - with tempfile.NamedTemporaryFile(dir='/tmp', delete=False) as f: + with tempfile.NamedTemporaryFile(dir="/tmp", delete=False) as f: return f.name diff --git a/tests/evals/gpt_oss/__init__.py b/tests/evals/gpt_oss/__init__.py index 0fec1fe5bcdf..208f01a7cb5e 100644 --- a/tests/evals/gpt_oss/__init__.py +++ b/tests/evals/gpt_oss/__init__.py @@ -1,2 +1,2 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project \ No newline at end of file +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/evals/gpt_oss/conftest.py b/tests/evals/gpt_oss/conftest.py index 35528c0a6a36..2f140ae2c8e9 100644 --- a/tests/evals/gpt_oss/conftest.py +++ b/tests/evals/gpt_oss/conftest.py @@ 
-8,11 +8,9 @@ def pytest_addoption(parser): """Add command line options for pytest.""" parser.addoption("--model", action="store", help="Model name to evaluate") - parser.addoption("--metric", - action="store", - type=float, - help="Expected metric threshold") - parser.addoption("--server-args", - action="store", - default="", - help="Additional server arguments") + parser.addoption( + "--metric", action="store", type=float, help="Expected metric threshold" + ) + parser.addoption( + "--server-args", action="store", default="", help="Additional server arguments" + ) diff --git a/tests/evals/gpt_oss/test_gpqa_correctness.py b/tests/evals/gpt_oss/test_gpqa_correctness.py index 07c04f00cd0d..151deaa059f0 100644 --- a/tests/evals/gpt_oss/test_gpqa_correctness.py +++ b/tests/evals/gpt_oss/test_gpqa_correctness.py @@ -25,9 +25,19 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float: # Build the command to run the evaluation cmd = [ - sys.executable, "-m", "gpt_oss.evals", "--eval", "gpqa", "--model", - model_name, "--reasoning-effort", "low", "--base-url", base_url, - "--n-threads", "200" + sys.executable, + "-m", + "gpt_oss.evals", + "--eval", + "gpqa", + "--model", + model_name, + "--reasoning-effort", + "low", + "--base-url", + base_url, + "--n-threads", + "200", ] try: @@ -37,7 +47,8 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float: text=True, capture_output=True, timeout=1800, # 30 minute timeout - env={"OPENAI_API_KEY": "dummy"}) + env={"OPENAI_API_KEY": "dummy"}, + ) print("Evaluation process output:\n", result.stdout) @@ -48,14 +59,16 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float: # If we still can't find it, raise an error raise ValueError( - f"Could not parse score from evaluation output:\n{result.stdout}") + f"Could not parse score from evaluation output:\n{result.stdout}" + ) except subprocess.TimeoutExpired as e: raise RuntimeError("Evaluation timed out") from e except subprocess.CalledProcessError as e: raise RuntimeError( f"Evaluation failed with exit code {e.returncode}:\n" - f"stdout: {e.stdout}\nstderr: {e.stderr}") from e + f"stdout: {e.stdout}\nstderr: {e.stderr}" + ) from e def test_gpqa_correctness(request): @@ -72,17 +85,20 @@ def test_gpqa_correctness(request): server_args = server_args_str.split() # Add standard server arguments - server_args.extend([ - "--trust-remote-code", - ]) + server_args.extend( + [ + "--trust-remote-code", + ] + ) print(f"Starting GPQA evaluation for model: {model_name}") print(f"Expected metric threshold: {expected_metric}") print(f"Server args: {' '.join(server_args)}") # Launch server and run evaluation - with RemoteOpenAIServer(model_name, server_args, - max_wait_seconds=1800) as remote_server: + with RemoteOpenAIServer( + model_name, server_args, max_wait_seconds=1800 + ) as remote_server: base_url = remote_server.url_for("v1") print(f"Server started at: {base_url}") @@ -96,6 +112,7 @@ def test_gpqa_correctness(request): # Verify metric is within tolerance assert measured_metric >= expected_metric - TOL, ( f"GPQA metric too low: {measured_metric:.4f} < " - f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}") + f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}" + ) print(f"✅ GPQA test passed for {model_name}") diff --git a/tests/evals/gsm8k/__init__.py b/tests/evals/gsm8k/__init__.py index 0fec1fe5bcdf..208f01a7cb5e 100644 --- a/tests/evals/gsm8k/__init__.py +++ b/tests/evals/gsm8k/__init__.py @@ -1,2 +1,2 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project \ No newline at end of file +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/evals/gsm8k/configs/models-blackwell.txt b/tests/evals/gsm8k/configs/models-blackwell.txt new file mode 100644 index 000000000000..e577645d60d6 --- /dev/null +++ b/tests/evals/gsm8k/configs/models-blackwell.txt @@ -0,0 +1,4 @@ +Qwen3-0.6B-FP8.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml +Qwen1.5-MoE-W4A16-CT.yaml +DeepSeek-V2-Lite-Instruct-FP8.yaml diff --git a/tests/evals/gsm8k/conftest.py b/tests/evals/gsm8k/conftest.py index d96b0a66ede2..1932a13cdfc6 100644 --- a/tests/evals/gsm8k/conftest.py +++ b/tests/evals/gsm8k/conftest.py @@ -6,13 +6,12 @@ def pytest_addoption(parser): """Add custom command line options.""" - parser.addoption("--config-list-file", - default="configs/models-small.txt", - help="File containing list of config files to test") - parser.addoption("--tp-size", - default=1, - type=int, - help="Tensor parallel size") + parser.addoption( + "--config-list-file", + default="configs/models-small.txt", + help="File containing list of config files to test", + ) + parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size") def pytest_generate_tests(metafunc): @@ -55,12 +54,10 @@ def pytest_generate_tests(metafunc): # Generate test parameters if config_files: - metafunc.parametrize(["config_filename", "tp_size"], - [(config_file, int(tp_size)) - for config_file in config_files], - ids=[ - f"{config_file.stem}-tp{tp_size}" - for config_file in config_files - ]) + metafunc.parametrize( + ["config_filename", "tp_size"], + [(config_file, int(tp_size)) for config_file in config_files], + ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files], + ) else: print("No config files found, test will be skipped") diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py index 7d0ce25f75dd..9edec7a78ca2 100644 --- a/tests/evals/gsm8k/gsm8k_eval.py +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -76,13 +76,15 @@ def get_answer_value(answer_str: str) -> int: return INVALID -async def call_vllm_api(session: aiohttp.ClientSession, - prompt: str, - temperature: float, - max_tokens: int, - stop: Optional[list[str]] = None, - url: Optional[str] = None, - seed: Optional[int] = None) -> str: +async def call_vllm_api( + session: aiohttp.ClientSession, + prompt: str, + temperature: float, + max_tokens: int, + stop: Optional[list[str]] = None, + url: Optional[str] = None, + seed: Optional[int] = None, +) -> str: """Call vLLM's OpenAI-compatible completions endpoint.""" data = { "prompt": prompt, @@ -94,8 +96,7 @@ async def call_vllm_api(session: aiohttp.ClientSession, data["seed"] = seed try: - async with session.post(f"{url}/v1/completions", - json=data) as response: + async with session.post(f"{url}/v1/completions", json=data) as response: response.raise_for_status() result = await response.json() return result["choices"][0]["text"] @@ -104,16 +105,18 @@ async def call_vllm_api(session: aiohttp.ClientSession, return "" -def evaluate_gsm8k(num_questions: int = 1319, - num_shots: int = 5, - max_tokens: int = 256, - host: str = "http://127.0.0.1", - port: int = 8000, - temperature: float = 0.0, - seed: Optional[int] = 42) -> dict[str, Union[float, int]]: +def evaluate_gsm8k( + num_questions: int = 1319, + num_shots: int = 5, + max_tokens: int = 256, + host: str = "http://127.0.0.1", + port: int = 8000, + temperature: float = 0.0, + seed: Optional[int] = 42, +) -> dict[str, Union[float, int]]: """ 
Evaluate GSM8K accuracy using vLLM serve endpoint. - + Returns dict with accuracy, invalid_rate, latency, etc. """ base_url = f"{host}:{port}" @@ -127,8 +130,10 @@ def evaluate_gsm8k(num_questions: int = 1319, # Build few-shot examples from train split (like lm-eval does) few_shot_examples = "" for i in range(num_shots): - few_shot_examples += (f"Question: {train_data[i]['question']}\n" - f"Answer: {train_data[i]['answer']}\n\n") + few_shot_examples += ( + f"Question: {train_data[i]['question']}\n" + f"Answer: {train_data[i]['answer']}\n\n" + ) # Prepare test questions and labels from test split questions = [] @@ -157,15 +162,15 @@ async def get_answer(session: aiohttp.ClientSession, i: int) -> str: states[i] = answer return answer - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( - total=600)) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=600) + ) as session: tasks = [get_answer(session, i) for i in range(num_questions)] await tqdm.gather(*tasks, desc="Evaluating") return states - print(f"Running GSM8K evaluation: {num_questions} questions, " - f"{num_shots}-shot") + print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot") tic = time.perf_counter() states = asyncio.run(run_async_evaluation()) @@ -191,36 +196,28 @@ async def get_answer(session: aiohttp.ClientSession, i: int) -> str: def main() -> None: - parser = argparse.ArgumentParser( - description="GSM8K evaluation for vLLM serve") - parser.add_argument("--num-shots", - type=int, - default=5, - help="Number of few-shot examples") - parser.add_argument("--num-questions", - type=int, - default=1319, - help="Number of questions to evaluate") - parser.add_argument("--max-tokens", - type=int, - default=256, - help="Max tokens for generation") - parser.add_argument("--host", - type=str, - default="http://127.0.0.1", - help="Host URL") + parser = argparse.ArgumentParser(description="GSM8K evaluation for vLLM serve") + parser.add_argument( + "--num-shots", type=int, default=5, help="Number of few-shot examples" + ) + parser.add_argument( + "--num-questions", + type=int, + default=1319, + help="Number of questions to evaluate", + ) + parser.add_argument( + "--max-tokens", type=int, default=256, help="Max tokens for generation" + ) + parser.add_argument("--host", type=str, default="http://127.0.0.1", help="Host URL") parser.add_argument("--port", type=int, default=8000, help="Port number") - parser.add_argument("--temperature", - type=float, - default=0.0, - help="Temperature for generation") - parser.add_argument("--seed", - type=int, - default=42, - help="Random seed for reproducibility") - parser.add_argument("--save-results", - type=str, - help="Save results to JSON file") + parser.add_argument( + "--temperature", type=float, default=0.0, help="Temperature for generation" + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for reproducibility" + ) + parser.add_argument("--save-results", type=str, help="Save results to JSON file") args = parser.parse_args() diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py index a12dd49dbea6..ce3ab8096b45 100644 --- a/tests/evals/gsm8k/test_gsm8k_correctness.py +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -63,9 +63,9 @@ def test_gsm8k_correctness_param(config_filename, tp_size): ] # Launch server and run evaluation - with RemoteOpenAIServer(eval_config["model_name"], - server_args, - max_wait_seconds=480) as remote_server: + with 
RemoteOpenAIServer( + eval_config["model_name"], server_args, max_wait_seconds=480 + ) as remote_server: server_url = remote_server.url_for("v1") results = launch_gsm8k_eval(eval_config, server_url, tp_size) @@ -85,6 +85,7 @@ def test_gsm8k_correctness_param(config_filename, tp_size): # Verify accuracy is within tolerance assert measured_accuracy >= expected_accuracy - RTOL, ( f"Accuracy too low: {measured_accuracy:.3f} < " - f"{expected_accuracy:.3f} - {RTOL:.3f}") + f"{expected_accuracy:.3f} - {RTOL:.3f}" + ) print(f"✅ GSM8K test passed for {eval_config['model_name']}") diff --git a/tests/kernels/allclose_default.py b/tests/kernels/allclose_default.py index 9d65159bf64f..6561e9556fa7 100644 --- a/tests/kernels/allclose_default.py +++ b/tests/kernels/allclose_default.py @@ -6,11 +6,7 @@ # Reference default values of atol and rtol are from # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67 default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5} -default_rtol = { - torch.float16: 1e-3, - torch.bfloat16: 1.6e-2, - torch.float: 1.3e-6 -} +default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6} def get_default_atol(output) -> float: diff --git a/tests/kernels/attention/conftest.py b/tests/kernels/attention/conftest.py index 88a2fb62b254..b080a71bd54e 100644 --- a/tests/kernels/attention/conftest.py +++ b/tests/kernels/attention/conftest.py @@ -3,8 +3,7 @@ import pytest -from vllm.utils import (create_kv_caches_with_random, - create_kv_caches_with_random_flash) +from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash @pytest.fixture() diff --git a/tests/kernels/attention/test_aiter_flash_attn.py b/tests/kernels/attention/test_aiter_flash_attn.py index 2d882bdf4066..88b21a9b84d6 100644 --- a/tests/kernels/attention/test_aiter_flash_attn.py +++ b/tests/kernels/attention/test_aiter_flash_attn.py @@ -39,7 +39,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -57,10 +57,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -74,11 +77,10 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="Only ROCm is supported") -@pytest.mark.parametrize("seq_lens", - [[(10, 1328), (5, 18), - (129, 463)], [(8, 523), (24, 37), (3, 2011)]]) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only ROCm is supported") +@pytest.mark.parametrize( + "seq_lens", [[(10, 1328), (5, 18), (129, 463)], [(8, 523), (24, 37), (3, 2011)]] +) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -109,34 +111,27 @@ def test_varlen_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = 
max(kv_lens) - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) - cu_seq_lens = torch.tensor([0] + kv_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) output = torch.empty_like(query) @@ -187,5 +182,7 @@ def test_varlen_with_paged_kv( atol, rtol = 2e-2, 2e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index c7abf652f111..16e544eb3cf9 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -42,9 +42,7 @@ USE_ALIBI = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] def ref_masked_attention( @@ -110,8 +108,7 @@ def ref_single_query_cached_kv_attention( # Create the ALiBi bias used in the paged attention kernel. 
position_ids = torch.arange(seq_len).int() alibi_bias = (position_ids - seq_len + 1).float() - alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( - 1, 1, -1) + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) out = ref_masked_attention(q, keys, values, scale, alibi_bias) out = out.view(num_query_heads, head_size) @@ -119,8 +116,8 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize( - "version", - ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) + "version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"] +) @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -143,13 +140,18 @@ def test_paged_attention( seed: int, device: str, ) -> None: - if ((kv_cache_dtype == "fp8" and head_size % 16) - or (version == "rocm" and head_size not in (64, 128))): + if (kv_cache_dtype == "fp8" and head_size % 16) or ( + version == "rocm" and head_size not in (64, 128) + ): pytest.skip() - if (version == "rocm" and current_platform.is_navi() - and (kv_cache_dtype == "fp8" or head_size != 128 - or block_size != 16 or use_alibi)): + if ( + version == "rocm" + and current_platform.is_navi() + and ( + kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi + ) + ): pytest.skip() global PARTITION_SIZE @@ -177,18 +179,24 @@ def test_paged_attention( block_tables_lst: list[list[int]] = [] for _ in range(num_seqs): block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] block_tables_lst.append(block_table) block_tables = torch.tensor(block_tables_lst, dtype=torch.int) # Create the KV caches. 
- key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) + key_caches, value_caches = kv_cache_factory( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale @@ -214,18 +222,37 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v1, - (output, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v1, + ( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + 0, + 0, + 0, + 64, + 0, + ), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]), + ) elif version in ("v2", "rocm"): if current_platform.is_rocm() and version == "rocm": PARTITION_SIZE = PARTITION_SIZE_ROCM - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -258,13 +285,34 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + ( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + 0, + 0, + 0, + 64, + 0, + ), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]), + ) else: ops.paged_attention_rocm( @@ -288,13 +336,30 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._rocm_C.paged_attention, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, None, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._rocm_C.paged_attention, + ( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + None, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]), + ) else: raise AssertionError(f"Unknown version: {version}") @@ -303,18 +368,17 @@ def test_paged_attention( if kv_cache_dtype == "fp8": # Convert cache data back to dtype. 
x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) + dequantized_key_cache = torch.empty( + size=key_cache_shape, dtype=dtype, device=device + ) ops.convert_fp8(dequantized_key_cache, key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) + dequantized_value_cache = torch.empty( + size=value_cache_shape, dtype=dtype, device=device + ) ops.convert_fp8(dequantized_value_cache, value_cache) value_cache = dequantized_value_cache @@ -367,8 +431,9 @@ def ref_multi_query_kv_attention( if alibi_bias: attn_mask = alibi_bias[i] else: - attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), - diagonal=1) + attn_mask = torch.triu( + torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1 + ) attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask.to(dtype=dtype) @@ -390,8 +455,9 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, @@ -413,13 +479,11 @@ def test_multi_query_kv_attention( scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads - qkv = torch.empty(num_tokens, - num_query_heads + 2 * num_kv_heads, - head_size, - dtype=dtype) + qkv = torch.empty( + num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype + ) qkv.uniform_(-scale, scale) - query, key, value = qkv.split( - [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1) num_queries_per_kv = num_query_heads // num_kv_heads if num_queries_per_kv > 1: @@ -429,8 +493,7 @@ def test_multi_query_kv_attention( alibi_bias = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, - seq_lens) + attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) output = torch.empty_like(query) start = 0 # Dynamic sequence length not supported with custom attn_bias. @@ -442,7 +505,8 @@ def test_multi_query_kv_attention( value[None, start:end], attn_bias=attn_bias[i], p=0.0, - scale=scale) + scale=scale, + ) output[start:end].copy_(out.view_as(query[start:end])) start += seq_len # xformers.AttentionBias to Tensor for use in reference impl. @@ -485,8 +549,9 @@ def test_multi_query_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." 
+) @torch.inference_mode() def test_multi_query_kv_attention_with_alibi( num_seqs: int, diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 0ff2517f7ba2..6b99ba7af50e 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -15,16 +15,18 @@ @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. - """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() # Define MLA and non-MLA backends separately DEVICE_MLA_BACKENDS = { "cuda": [ - "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA", - "CUTLASS_MLA" + "TRITON_MLA", + "FLASHMLA", + "FLASHINFER_MLA", + "FLASH_ATTN_MLA", + "CUTLASS_MLA", ], "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], "cpu": [], @@ -40,7 +42,7 @@ def clear_cache(): "cuda": [16, 64], # CUDA supports both standard and extended block sizes "hip": [16, 1], # HIP requires special handling for block_size=1 # "cpu": [16] # CPU uses fixed block size from test cases - "cpu": [] # FIXME(woosuk): Temporarily disable CPU tests + "cpu": [], # FIXME(woosuk): Temporarily disable CPU tests } @@ -48,12 +50,13 @@ def generate_params(): params = [] for use_mla in [True, False]: for device in ["cuda", "hip", "cpu"]: - backends = DEVICE_MLA_BACKENDS[ - device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device] + backends = ( + DEVICE_MLA_BACKENDS[device] + if use_mla + else DEVICE_REGULAR_ATTN_BACKENDS[device] + ) for name in backends: - block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [ - 16 - ] + block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16] for block_size in block_sizes: params.append( pytest.param( @@ -61,14 +64,13 @@ def generate_params(): name, use_mla, block_size, - id= - f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}" - )) + id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}", + ) + ) return params -@pytest.mark.parametrize("device, name, use_mla, block_size", - generate_params()) +@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) def test_env( device: str, name: str, @@ -83,14 +85,12 @@ def test_env( m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") if device == "cpu": - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): + with patch("vllm.attention.selector.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float16, None, block_size) assert backend.get_name() == "TORCH_SDPA" elif device == "hip": - with patch("vllm.attention.selector.current_platform", - RocmPlatform()): + with patch("vllm.attention.selector.current_platform", RocmPlatform()): if use_mla: # ROCm MLA backend logic: # - TRITON_MLA: supported when block_size != 1 @@ -101,44 +101,33 @@ def test_env( if name == "TRITON_MLA" and block_size == 1: # TRITON_MLA doesn't support block_size == 1 with pytest.raises(ValueError) as exc_info: - get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) - assert f"The selected backend, {name}" in str( - exc_info.value) + get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + assert f"The selected backend, {name}" in str(exc_info.value) elif name == "ROCM_AITER_MLA" and block_size != 1: # ROCM_AITER_MLA only supports block_size == 1 with pytest.raises(ValueError) as exc_info: - get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) - 
assert f"The selected backend, {name}" in str( - exc_info.value) + get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) + assert f"The selected backend, {name}" in str(exc_info.value) else: # Valid backend-block_size combination - backend = get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) expected = name assert backend.get_name() == expected else: - backend = get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) expected = "TRITON_ATTN" assert backend.get_name() == expected elif device == "cuda": - with patch("vllm.attention.selector.current_platform", - CudaPlatform()): + with patch("vllm.attention.selector.current_platform", CudaPlatform()): if use_mla: # CUDA MLA backend logic: # - CUTLASS_MLA: only supported with block_size == 128 @@ -152,28 +141,23 @@ def test_env( if name == "CUTLASS_MLA": if block_size != 128: # CUTLASS_MLA only supports block_size == 128 - pytest.skip( - "CUTLASS_MLA only supports block_size 128") + pytest.skip("CUTLASS_MLA only supports block_size 128") else: - backend = get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) expected = "CUTLASS_MLA" assert backend.get_name() == expected elif name == "FLASHINFER_MLA": if block_size not in [32, 64]: # FlashInfer MLA only supports block_size 32 or 64 pytest.skip( - "FlashInfer MLA only supports block_size 32 " - "or 64") + "FlashInfer MLA only supports block_size 32 or 64" + ) else: - backend = get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) expected = "FLASHINFER_MLA" assert backend.get_name() == expected elif name == "FLASHMLA": @@ -181,59 +165,48 @@ def test_env( # FlashMLA only supports block_size == 64 pytest.skip("FlashMLA only supports block_size 64") else: - from vllm.v1.attention.backends.mla.flashmla import ( # noqa: E501 - is_flashmla_supported) + from vllm.v1.attention.backends.mla.flashmla import ( + is_flashmla_supported, + ) + is_supported, _ = is_flashmla_supported() if not is_supported: - pytest.skip( - "FlashMLA not supported on this platform") + pytest.skip("FlashMLA not supported on this platform") else: - backend = get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) expected = name assert backend.get_name() == expected elif name == "FLASH_ATTN_MLA": - backend = get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: # TRITON_MLA or other fallback - backend = get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) expected = "TRITON_MLA" assert backend.get_name() == expected elif name == "FLASHINFER": - backend = get_attn_backend(16, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 16, torch.float16, None, block_size, use_mla=use_mla + ) expected = "FLASHINFER" assert 
backend.get_name() == expected elif name == "XFORMERS": - backend = get_attn_backend(32, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 32, torch.float16, None, block_size, use_mla=use_mla + ) expected = "XFORMERS" assert backend.get_name() == expected elif name == "FLASH_ATTN": - backend = get_attn_backend(32, - torch.float16, - None, - block_size, - use_mla=use_mla) + backend = get_attn_backend( + 32, torch.float16, None, block_size, use_mla=use_mla + ) expected = "FLASH_ATTN" assert backend.get_name() == expected @@ -248,14 +221,12 @@ def test_fp32_fallback( m.setenv("VLLM_USE_V1", "1") if device == "cpu": - with patch("vllm.attention.selector.current_platform", - CpuPlatform()): + with patch("vllm.attention.selector.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float32, None, 16) assert backend.get_name() == "TORCH_SDPA" elif device == "cuda": - with patch("vllm.attention.selector.current_platform", - CudaPlatform()): + with patch("vllm.attention.selector.current_platform", CudaPlatform()): backend = get_attn_backend(16, torch.float32, None, 16) assert backend.get_name() == "FLEX_ATTENTION" @@ -265,16 +236,16 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): # TODO: When testing for v1, pipe in `use_v1` as an argument to # get_attn_backend - pytest.skip("Skipping as current backend selector does not " \ - "handle fallbacks when a backend is set via env var.") + pytest.skip( + "Skipping as current backend selector does not " + "handle fallbacks when a backend is set via env var." + ) with monkeypatch.context() as m: m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch - monkeypatch.setattr(torch.cuda, - "get_device_capability", - lambda _=None: (7, 5)) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) backend = get_attn_backend(16, torch.float16, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL @@ -295,17 +266,17 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): # flash-attn is not installed import sys - original_module = sys.modules.get('vllm_flash_attn') - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None) + + original_module = sys.modules.get("vllm_flash_attn") + monkeypatch.setitem(sys.modules, "vllm_flash_attn", None) backend = get_attn_backend(16, torch.float16, None, 16) assert backend.get_name() != STR_FLASH_ATTN_VAL # Restore the original module if it existed if original_module is not None: - monkeypatch.setitem(sys.modules, 'vllm_flash_attn', - original_module) + monkeypatch.setitem(sys.modules, "vllm_flash_attn", original_module) else: - monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False) + monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False) # Unsupported head size backend = get_attn_backend(17, torch.float16, None, 16) @@ -314,8 +285,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): def test_invalid_env(monkeypatch: pytest.MonkeyPatch): """Test that invalid attention backend names raise ValueError.""" - with monkeypatch.context() as m, patch( - "vllm.attention.selector.current_platform", CudaPlatform()): + with ( + monkeypatch.context() as m, + patch("vllm.attention.selector.current_platform", CudaPlatform()), + ): m.setenv("VLLM_USE_V1", "1") m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 6e096a4c3999..f33a27d1fd85 100644 --- a/tests/kernels/attention/test_cache.py +++ 
b/tests/kernels/attention/test_cache.py @@ -10,7 +10,7 @@ from vllm import _custom_ops as ops from vllm.platforms import current_platform -COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] +COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")] DTYPES = [torch.bfloat16, torch.float] NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing @@ -32,9 +32,7 @@ NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] @@ -85,24 +83,33 @@ def test_copy_blocks( block_mapping.append((src, dst2)) # Create the KV caches. - key_caches, value_caches = kv_cache_factory(num_blocks, block_size, - num_layers, num_heads, - head_size, kv_cache_dtype, - dtype, seed, device) + key_caches, value_caches = kv_cache_factory( + num_blocks, + block_size, + num_layers, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) # Clone the KV caches. cloned_key_caches = [key_cache.clone() for key_cache in key_caches] cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device=device).view(-1, 2) - - opcheck(torch.ops._C_cache_ops.copy_blocks, - (key_caches, value_caches, block_mapping_tensor), - test_utils=DEFAULT_OPCHECK_TEST_UTILS, - cond=(head_size == HEAD_SIZES[0])) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device=device + ).view(-1, 2) + + opcheck( + torch.ops._C_cache_ops.copy_blocks, + (key_caches, value_caches, block_mapping_tensor), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + cond=(head_size == HEAD_SIZES[0]), + ) ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) # Run the reference implementation. @@ -115,8 +122,7 @@ def test_copy_blocks( # Compare the results. for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): torch.testing.assert_close(key_cache, cloned_key_cache) - for value_cache, cloned_value_cache in zip(value_caches, - cloned_value_caches): + for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches): torch.testing.assert_close(value_cache, cloned_value_cache) @@ -155,10 +161,17 @@ def test_reshape_and_cache( _, key, value = qkv.unbind(dim=1) # Create the KV caches. - key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, - num_heads, head_size, - kv_cache_dtype, dtype, seed, - device) + key_caches, value_caches = kv_cache_factory( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale @@ -176,12 +189,30 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Call the reshape_and_cache kernel. 
- opcheck(torch.ops._C_cache_ops.reshape_and_cache, - (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, - k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) - ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, k_scale, v_scale) + opcheck( + torch.ops._C_cache_ops.reshape_and_cache, + ( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0]), + ) + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) @@ -202,14 +233,12 @@ def test_reshape_and_cache( cloned_value_cache[block_idx, :, :, block_offset] = value[i] if kv_cache_dtype == "fp8": - torch.testing.assert_close(result_key_cache, - cloned_key_cache, - atol=0.001, - rtol=0.1) - torch.testing.assert_close(result_value_cache, - cloned_value_cache, - atol=0.001, - rtol=0.1) + torch.testing.assert_close( + result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1 + ) + torch.testing.assert_close( + result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1 + ) else: torch.testing.assert_close(key_cache, cloned_key_cache) torch.testing.assert_close(value_cache, cloned_value_cache) @@ -254,15 +283,8 @@ def test_reshape_and_cache_flash( # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping_lst = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) - qkv = torch.randn(num_tokens, - 3, - num_heads, - head_size, - dtype=dtype, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) + qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device) _, key, value = qkv.unbind(dim=1) # Create the KV caches. @@ -293,48 +315,73 @@ def permute_and_compact(x): # Clone the KV caches. if kv_cache_dtype == "fp8": - cloned_key_cache = torch.empty_like(key_cache_compact, - dtype=torch.float16) - ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(), - kv_cache_dtype) - cloned_value_cache = torch.empty_like(value_cache_compact, - dtype=torch.float16) - ops.convert_fp8(cloned_value_cache, value_cache_compact, - v_scale.item(), kv_cache_dtype) + cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16) + ops.convert_fp8( + cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype + ) + cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16) + ops.convert_fp8( + cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype + ) else: cloned_key_cache = key_cache_compact.clone() cloned_value_cache = value_cache_compact.clone() # Call the reshape_and_cache kernel. 
if implementation == "cuda": - opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, - (key, value, key_cache, value_cache, slot_mapping, - kv_cache_dtype, k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) - ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, k_scale, - v_scale) + opcheck( + torch.ops._C_cache_ops.reshape_and_cache_flash, + ( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ), + cond=(head_size == HEAD_SIZES[0]), + ) + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) elif implementation == "triton": from vllm.attention.ops.triton_reshape_and_cache_flash import ( - triton_reshape_and_cache_flash) - triton_reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, k_scale, - v_scale) + triton_reshape_and_cache_flash, + ) + + triton_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) key_cache_compact = permute_and_compact(key_cache) value_cache_compact = permute_and_compact(value_cache) if kv_cache_dtype == "fp8": - result_key_cache = torch.empty_like(key_cache_compact, - dtype=torch.float16) - ops.convert_fp8(result_key_cache, - key_cache_compact, - k_scale.item(), - kv_dtype=kv_cache_dtype) - result_value_cache = torch.empty_like(value_cache_compact, - dtype=torch.float16) - ops.convert_fp8(result_value_cache, - value_cache_compact, - v_scale.item(), - kv_dtype=kv_cache_dtype) + result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16) + ops.convert_fp8( + result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype + ) + result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16) + ops.convert_fp8( + result_value_cache, + value_cache_compact, + v_scale.item(), + kv_dtype=kv_cache_dtype, + ) # Run the reference implementation. 
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor") @@ -352,14 +399,12 @@ def permute_and_compact(x): cloned_value_cache[block_idx, :, block_offset, :] = value[i] if kv_cache_dtype == "fp8": - torch.testing.assert_close(result_key_cache, - cloned_key_cache, - atol=0.001, - rtol=0.1) - torch.testing.assert_close(result_value_cache, - cloned_value_cache, - atol=0.001, - rtol=0.1) + torch.testing.assert_close( + result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1 + ) + torch.testing.assert_close( + result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1 + ) else: torch.testing.assert_close(key_cache_compact, cloned_key_cache) torch.testing.assert_close(value_cache_compact, cloned_value_cache) @@ -396,8 +441,8 @@ def test_swap_blocks( current_platform.seed_everything(seed) - src_device = device if direction[0] == "cuda" else 'cpu' - dst_device = device if direction[1] == "cuda" else 'cpu' + src_device = device if direction[0] == "cuda" else "cpu" + dst_device = device if direction[1] == "cuda" else "cpu" src_blocks = random.sample(range(num_blocks), num_mappings) # For the same device, mapping must not overlap @@ -408,42 +453,62 @@ def test_swap_blocks( dst_blocks = random.sample(range(num_blocks), num_mappings) block_mapping = list(zip(src_blocks, dst_blocks)) - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device="cpu").view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device="cpu" + ).view(-1, 2) # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, - seed, src_device) + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + src_device, + ) # Create the KV caches on the second device. dist_key_caches, dist_value_caches = kv_cache_factory( - num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, - seed, dst_device) + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + seed, + dst_device, + ) src_key_caches_clone = src_key_caches[0].clone() src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. 
- do_opcheck = (head_size == HEAD_SIZES[0]) - opcheck(torch.ops._C_cache_ops.swap_blocks, - (src_key_caches[0], dist_key_caches[0], block_mapping_tensor), - cond=do_opcheck) - opcheck(torch.ops._C_cache_ops.swap_blocks, - (src_value_caches[0], dist_value_caches[0], block_mapping_tensor), - cond=do_opcheck) - - ops.swap_blocks(src_key_caches[0], dist_key_caches[0], - block_mapping_tensor) - ops.swap_blocks(src_value_caches[0], dist_value_caches[0], - block_mapping_tensor) + do_opcheck = head_size == HEAD_SIZES[0] + opcheck( + torch.ops._C_cache_ops.swap_blocks, + (src_key_caches[0], dist_key_caches[0], block_mapping_tensor), + cond=do_opcheck, + ) + opcheck( + torch.ops._C_cache_ops.swap_blocks, + (src_value_caches[0], dist_value_caches[0], block_mapping_tensor), + cond=do_opcheck, + ) + + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor) for src, dst in block_mapping: - torch.testing.assert_close(src_key_caches_clone[src].cpu(), - dist_key_caches[0][dst].cpu()) - torch.testing.assert_close(src_value_caches_clone[src].cpu(), - dist_value_caches[0][dst].cpu()) + torch.testing.assert_close( + src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu() + ) + torch.testing.assert_close( + src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu() + ) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -489,11 +554,9 @@ def _create_mla_cache( device: str, ) -> torch.Tensor: cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype - return torch.zeros(num_blocks, - block_size, - entry_size, - dtype=cache_dtype, - device=device) + return torch.zeros( + num_blocks, block_size, entry_size, dtype=cache_dtype, device=device + ) def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str): @@ -533,20 +596,16 @@ def test_concat_and_cache_mla( total_slots = num_blocks * block_size slot_mapping_lst = random.sample(range(total_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) - k_pe = torch.randn(num_tokens, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device) entry_size = kv_lora_rank + qk_rope_head_dim scale = torch.tensor(0.1, dtype=torch.float32, device=device) - kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + kv_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) for i in range(num_tokens): @@ -558,10 +617,7 @@ def test_concat_and_cache_mla( if kv_cache_dtype == "fp8": ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype) - ops.convert_fp8(ref_kv_cache, - ref_temp, - scale.item(), - kv_dtype=kv_cache_dtype) + ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype) else: ref_kv_cache = ref_temp @@ -571,24 +627,18 @@ def test_concat_and_cache_mla( test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, - kv_cache_dtype, scale) + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale) if kv_cache_dtype == "fp8": result_temp = torch.empty_like(kv_cache, dtype=torch.float16) - ops.convert_fp8(result_temp, - 
kv_cache.contiguous(), - scale.item(), - kv_dtype=kv_cache_dtype) + ops.convert_fp8( + result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype + ) expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16) - ops.convert_fp8(expected_temp, - ref_kv_cache, - scale.item(), - kv_dtype=kv_cache_dtype) - torch.testing.assert_close(result_temp, - expected_temp, - atol=0.001, - rtol=0.1) + ops.convert_fp8( + expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype + ) + torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1) else: torch.testing.assert_close(kv_cache, ref_kv_cache) @@ -620,24 +670,21 @@ def test_concat_and_cache_ds_mla( total_slots = num_blocks * block_size slot_mapping_lst = random.sample(range(total_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) - k_pe = torch.randn(num_tokens, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device) entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim) scale = torch.tensor(1.0, dtype=torch.float32, device=device) - kv_cache = _create_mla_cache(num_blocks, - block_size, - entry_size, - dtype=torch.uint8, - kv_cache_dtype=kv_cache_dtype, - device=device) + kv_cache = _create_mla_cache( + num_blocks, + block_size, + entry_size, + dtype=torch.uint8, + kv_cache_dtype=kv_cache_dtype, + device=device, + ) ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype) tile_data = torch.zeros(128, dtype=dtype, device=device) @@ -664,14 +711,16 @@ def test_concat_and_cache_ds_mla( manual_max = abs(tile_data_float[0]) for j in range(1, 128): manual_max = max(manual_max, abs(tile_data_float[j])) - tile_scale = manual_max / 448. 
+ tile_scale = manual_max / 448.0 ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale - ops.convert_fp8(ref_cache_slice[tile_start:tile_end], - tile_data, - tile_scale.item(), - kv_dtype="fp8") + ops.convert_fp8( + ref_cache_slice[tile_start:tile_end], + tile_data, + tile_scale.item(), + kv_dtype="fp8", + ) for j in range(qk_rope_head_dim): ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j] @@ -682,8 +731,7 @@ def test_concat_and_cache_ds_mla( test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, - kv_cache_dtype, scale) + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale) for i in range(num_tokens): slot = slot_mapping[i].item() @@ -694,12 +742,14 @@ def test_concat_and_cache_ds_mla( kv_nope = kv_cache_slice[:kv_lora_rank] ref_nope = ref_cache_slice[:kv_lora_rank] - kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank // - 4:kv_lora_rank // 4 + 4] - ref_scales = ref_cache_slice.view( - torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4] - kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] - ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] + kv_scales = kv_cache_slice.view(torch.float32)[ + kv_lora_rank // 4 : kv_lora_rank // 4 + 4 + ] + ref_scales = ref_cache_slice.view(torch.float32)[ + kv_lora_rank // 4 : kv_lora_rank // 4 + 4 + ] + kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :] + ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :] torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1) torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1) @@ -734,8 +784,9 @@ def test_copy_blocks_mla( kv_caches = [] for _ in range(num_layers): - kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + kv_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype) kv_caches.append(kv_cache) @@ -752,9 +803,9 @@ def test_copy_blocks_mla( dst2 = dst_blocks[2 * i + 1] block_mapping.append((src, dst1)) block_mapping.append((src, dst2)) - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device=device).view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device=device + ).view(-1, 2) for src, dst in block_mapping: for ref_cache in ref_caches: @@ -795,10 +846,12 @@ def test_swap_blocks_mla( entry_size = kv_lora_rank + qk_rope_head_dim - src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) - dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + src_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) + dst_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(src_cache, kv_cache_dtype) _fill_mla_cache(dst_cache, kv_cache_dtype) @@ -810,9 +863,9 @@ def test_swap_blocks_mla( remaining_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remaining_blocks, num_mappings) block_mapping = list(zip(src_blocks, dst_blocks)) - block_mapping_tensor = torch.tensor(block_mapping, - dtype=torch.int64, - device="cpu").view(-1, 2) + block_mapping_tensor = torch.tensor( + block_mapping, dtype=torch.int64, device="cpu" + ).view(-1, 2) opcheck( torch.ops._C_cache_ops.swap_blocks, @@ -827,7 +880,8 @@ def test_swap_blocks_mla( 
src_cache_clone[src].cpu(), dst_cache[dst].cpu(), msg=f"Block {src} from src should have been swapped to block " - f"{dst} in dst_cache.") + f"{dst} in dst_cache.", + ) @pytest.mark.parametrize("kv_lora_rank", [512]) @@ -840,32 +894,36 @@ def test_swap_blocks_mla( @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, - block_size, num_blocks, - max_seq_len, batch_size, dtype, - kv_cache_dtype, device): +def test_gather_and_maybe_dequant_cache_mla( + kv_lora_rank, + qk_rope_head_dim, + block_size, + num_blocks, + max_seq_len, + batch_size, + dtype, + kv_cache_dtype, + device, +): entry_size = kv_lora_rank + qk_rope_head_dim scale = torch.tensor(0.1, dtype=torch.float32, device=device) - src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + src_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) - seq_len_tensor = torch.randint(0, - max_seq_len + 1, (batch_size, ), - device=device) + seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device) total_tokens = seq_len_tensor.sum() - cu_seq_lens = torch.empty((batch_size + 1), - dtype=torch.int32, - device=device) + cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device) cu_seq_lens[0] = 0 cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) print("seq_len_tensor", seq_len_tensor) tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size - block_table = torch.empty((batch_size, num_blocks), - dtype=torch.int32, - device=device) + block_table = torch.empty( + (batch_size, num_blocks), dtype=torch.int32, device=device + ) for b in range(batch_size): perm = torch.randperm(num_blocks, device=device) @@ -893,10 +951,8 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, remaining = s - (tot - 1) * block_size last_block_data = src_cache[blocks[-1], :remaining, :] if kv_cache_dtype == "fp8": - dequantized_last_block = torch.empty_like(last_block_data, - dtype=dtype) - ops.convert_fp8(dequantized_last_block, last_block_data, - scale.item()) + dequantized_last_block = torch.empty_like(last_block_data, dtype=dtype) + ops.convert_fp8(dequantized_last_block, last_block_data, scale.item()) gathered_rows.append(dequantized_last_block) else: gathered_rows.append(last_block_data) @@ -907,14 +963,29 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, opcheck( torch.ops._C_cache_ops.gather_and_maybe_dequant_cache, - (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, - scale, None), + ( + src_cache, + dst, + block_table, + cu_seq_lens, + batch_size, + kv_cache_dtype, + scale, + None, + ), test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table, - cu_seq_lens, batch_size, kv_cache_dtype, - scale, None) + ops.gather_and_maybe_dequant_cache( + src_cache, + dst, + block_table, + cu_seq_lens, + batch_size, + kv_cache_dtype, + scale, + None, + ) torch.testing.assert_close(dst, expected) @@ -925,42 +996,46 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim, @pytest.mark.parametrize("max_seq_len", [512]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("kv_cache_dtype", - ["auto"]) # 
You can also test "fp8" if needed. +@pytest.mark.parametrize( + "kv_cache_dtype", ["auto"] +) # You can also test "fp8" if needed. @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size, - num_blocks, max_seq_len, batch_size, dtype, - kv_cache_dtype, device): +def test_cp_gather_cache_mla( + kv_lora_rank, + qk_rope_head_dim, + block_size, + num_blocks, + max_seq_len, + batch_size, + dtype, + kv_cache_dtype, + device, +): entry_size = kv_lora_rank + qk_rope_head_dim - src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + src_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype) - seq_len_tensor = torch.randint(0, - max_seq_len + 1, (batch_size, ), - device=device) + seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device) total_tokens = seq_len_tensor.sum() - cu_seq_lens = torch.empty((batch_size + 1), - dtype=torch.int32, - device=device) + cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device) cu_seq_lens[0] = 0 cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32) print("seq_len_tensor", seq_len_tensor) tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size - block_table = torch.empty((batch_size, num_blocks), - dtype=torch.int32, - device=device) + block_table = torch.empty( + (batch_size, num_blocks), dtype=torch.int32, device=device + ) for b in range(batch_size): perm = torch.randperm(num_blocks, device=device) block_table[b, :] = perm - dst = torch.zeros((total_tokens, entry_size), - dtype=src_cache.dtype, - device=device) + dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device) expected_batches = [] for b in range(batch_size): @@ -1016,20 +1091,16 @@ def test_concat_and_cache_mla_cpu( total_slots = num_blocks * block_size slot_mapping_lst = random.sample(range(total_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping_lst, - dtype=torch.long, - device=device) + slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) - k_pe = torch.randn(num_tokens, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device) entry_size = kv_lora_rank + qk_rope_head_dim scale = torch.tensor(0.1, dtype=torch.float32, device=device) - kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype, - kv_cache_dtype, device) + kv_cache = _create_mla_cache( + num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device + ) ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device) for i in range(num_tokens): @@ -1041,10 +1112,7 @@ def test_concat_and_cache_mla_cpu( if kv_cache_dtype == "fp8": ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype) - ops.convert_fp8(ref_kv_cache, - ref_temp, - scale.item(), - kv_dtype=kv_cache_dtype) + ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype) else: ref_kv_cache = ref_temp @@ -1054,6 +1122,5 @@ def test_concat_and_cache_mla_cpu( test_utils=DEFAULT_OPCHECK_TEST_UTILS, ) - ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, - kv_cache_dtype, scale) + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale) torch.testing.assert_close(kv_cache, ref_kv_cache) diff 
--git a/tests/kernels/attention/test_cascade_flash_attn.py b/tests/kernels/attention/test_cascade_flash_attn.py index 1e7e7e0a7f84..58e8bd592ba4 100755 --- a/tests/kernels/attention/test_cascade_flash_attn.py +++ b/tests/kernels/attention/test_cascade_flash_attn.py @@ -7,11 +7,12 @@ import torch from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import (cascade_attention, - merge_attn_states) -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - is_fa_version_supported) +from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states +from vllm.vllm_flash_attn import ( + fa_version_unsupported_reason, + flash_attn_varlen_func, + is_fa_version_supported, +) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 192, 256] @@ -37,21 +38,14 @@ def test_merge_kernel( assert num_query_heads % num_kv_heads == 0 # Prepare inputs. - prefix_output = torch.randn(num_tokens, - num_query_heads, - head_size, - dtype=dtype) - suffix_output = torch.randn(num_tokens, - num_query_heads, - head_size, - dtype=dtype) + prefix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype) + suffix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype) prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32) suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32) # Run the kernel. output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype) - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) # Reference implementation. max_lse = torch.maximum(prefix_lse, suffix_lse) @@ -97,8 +91,10 @@ def test_cascade( ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): - pytest.skip(f"Flash attention version {fa_version} not supported due " - f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + pytest.skip( + f"Flash attention version {fa_version} not supported due " + f'to: "{fa_version_unsupported_reason(fa_version)}"' + ) current_platform.seed_everything(0) @@ -107,11 +103,9 @@ def test_cascade( num_query_heads = num_heads[0] num_kv_heads = num_heads[1] assert num_query_heads % num_kv_heads == 0 - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) seq_lens, common_prefix_len = seq_lens_and_common_prefix @@ -122,26 +116,21 @@ def test_cascade( max_kv_len = max(kv_lens) total_num_query_tokens = sum(query_lens) - query = torch.randn(total_num_query_tokens, - num_query_heads, - head_size, - dtype=dtype) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + query = torch.randn(total_num_query_tokens, num_query_heads, head_size, dtype=dtype) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) assert common_prefix_len > 0 assert common_prefix_len % block_size == 0 num_common_kv_blocks 
= common_prefix_len // block_size # Make sure the first `num_common_kv_blocks` blocks are the same. - block_tables[:, :num_common_kv_blocks] = \ - block_tables[0, :num_common_kv_blocks] + block_tables[:, :num_common_kv_blocks] = block_tables[0, :num_common_kv_blocks] # Run the regular attention. ref_output = flash_attn_varlen_func( @@ -161,8 +150,7 @@ def test_cascade( # Run cascade attention. assert all(common_prefix_len < kv_len for kv_len in kv_lens) - cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], - dtype=torch.int32) + cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], dtype=torch.int32) prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32) suffix_kv_lens = kv_lens_tensor - common_prefix_len output = torch.empty_like(query) diff --git a/tests/kernels/attention/test_cutlass_mla_decode.py b/tests/kernels/attention/test_cutlass_mla_decode.py index 5078bd730a1a..dad1510ce532 100644 --- a/tests/kernels/attention/test_cutlass_mla_decode.py +++ b/tests/kernels/attention/test_cutlass_mla_decode.py @@ -12,33 +12,37 @@ from vllm.triton_utils import triton -def cal_diff(x: torch.Tensor, - y: torch.Tensor, - name: str, - use_fp8: bool = False, - diff_threshold: Optional[float] = None) -> None: +def cal_diff( + x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False, + diff_threshold: Optional[float] = None, +) -> None: x, y = x.double(), y.double() - cos_diff = 1 - 2 * (x * y).sum().item() / max( - (x * x + y * y).sum().item(), 1e-12) + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) if diff_threshold is not None: # directly compare the cos_diff with the threshold assert cos_diff < diff_threshold else: # use the default threshold - if (use_fp8): + if use_fp8: assert cos_diff < 1e-4 else: assert cos_diff < 1e-5 -CUTLASS_MLA_UNSUPPORTED_REASON = \ - "Cutlass MLA Requires compute capability of 10 or above." \ - if not current_platform.is_device_capability(100) \ +CUTLASS_MLA_UNSUPPORTED_REASON = ( + "Cutlass MLA Requires compute capability of 10 or above." + if not current_platform.is_device_capability(100) else "Cutlass MLA is supported" +) -@pytest.mark.skipif(not current_platform.has_device_capability(100), - reason=CUTLASS_MLA_UNSUPPORTED_REASON) +@pytest.mark.skipif( + not current_platform.has_device_capability(100), + reason=CUTLASS_MLA_UNSUPPORTED_REASON, +) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1]) @pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) @@ -54,40 +58,40 @@ def cal_diff(x: torch.Tensor, [ torch.bfloat16, # fp8 can have occasional precision-related failures. 
- pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2)) - ]) + pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2)), + ], +) @torch.inference_mode() -def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, - causal, varlen, torch_dtype): +def test_cutlass_mla_decode( + b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype +): device = torch.device("cuda:0") - if torch_dtype == torch.float8_e4m3fn: - init_dtype = torch.bfloat16 - else: - init_dtype = torch_dtype + init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(42) random.seed(42) - print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " - f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + print( + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}" + ) use_fp8 = torch_dtype == torch.float8_e4m3fn - scale = math.sqrt(d)**(-1) - cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) + scale = math.sqrt(d) ** (-1) + cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) if varlen: for i in range(b): - cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), - s_q) + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) total_seqlens = cache_seqlens.sum().item() max_seqlen = cache_seqlens.max().item() max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 q = torch.randn(b, s_q, h_q, d) - block_table = torch.arange(b * max_seqlen_pad // block_size, - dtype=torch.int32).view( - b, max_seqlen_pad // block_size) + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32 + ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) blocked_v = blocked_k[..., :dv] @@ -121,22 +125,29 @@ def cutlass_mla(): q_pe = q_pe_padded kv_cache_flat = blocked_k.squeeze(2) - device_properties = torch.cuda.get_device_properties( - torch.device("cuda:0")) + device_properties = torch.cuda.get_device_properties(torch.device("cuda:0")) sm_count = device_properties.multi_processor_count workspace_size = ops.sm100_cutlass_mla_get_workspace_size( - max_seqlen * block_size, b, sm_count, num_kv_splits=1) - workspace = torch.empty(workspace_size, - device="cuda", - dtype=torch.uint8) + max_seqlen * block_size, b, sm_count, num_kv_splits=1 + ) + workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8) out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype) - output_lse = torch.empty((b, MAX_HEADS), - dtype=torch.float32, - device=q_nope.device) - ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe, - kv_cache_flat, cache_seqlens, block_table, - workspace, scale, 1) + output_lse = torch.empty( + (b, MAX_HEADS), dtype=torch.float32, device=q_nope.device + ) + ops.sm100_cutlass_mla_decode( + out_ans, + output_lse, + q_nope, + q_pe, + kv_cache_flat, + cache_seqlens, + block_table, + workspace, + scale, + 1, + ) return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous() def scaled_dot_product_attention(query, key, value, is_causal=False): @@ -150,8 +161,7 @@ def scaled_dot_product_attention(query, key, value, is_causal=False): s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, - dtype=torch.bool).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, 
dtype=torch.bool).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -161,10 +171,16 @@ def scaled_dot_product_attention(query, key, value, is_causal=False): def ref_mla(): q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q - blocked_k_ = (blocked_k.to(torch.float) * - descale_k).to(init_dtype) if use_fp8 else blocked_k - blocked_v_ = (blocked_v.to(torch.float) * - descale_k).to(init_dtype) if use_fp8 else blocked_v + blocked_k_ = ( + (blocked_k.to(torch.float) * descale_k).to(init_dtype) + if use_fp8 + else blocked_k + ) + blocked_v_ = ( + (blocked_v.to(torch.float) * descale_k).to(init_dtype) + if use_fp8 + else blocked_v + ) out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): @@ -191,8 +207,9 @@ def ref_mla(): t = triton.testing.do_bench(cutlass_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + - b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( - b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", - f"{bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * ( + torch.finfo(torch_dtype).bits // 8 + ) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print( + f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s" + ) diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index 2d901e408b27..4873afa649c9 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -7,9 +7,14 @@ from vllm.platforms import current_platform from vllm.utils import cdiv, has_deep_gemm -from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits, - fp8_paged_mqa_logits, get_num_sms, - get_paged_mqa_logits_metadata) +from vllm.utils.deep_gemm import ( + _ceil_to_ue8m0, + calc_diff, + fp8_mqa_logits, + fp8_paged_mqa_logits, + get_num_sms, + get_paged_mqa_logits_metadata, +) def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: @@ -24,17 +29,18 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: device=x.device, dtype=torch.uint8, ) - x_fp8[:, :block_size * head_dim] = x_scaled.view( - num_blocks, block_size * head_dim).view(dtype=torch.uint8) - x_fp8[:, - block_size * head_dim:] = sf.view(num_blocks, - block_size).view(dtype=torch.uint8) + x_fp8[:, : block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim + ).view(dtype=torch.uint8) + x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view( + dtype=torch.uint8 + ) return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) def per_custom_dims_cast_to_fp8( - x: torch.Tensor, dims: tuple, - use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, dims: tuple, use_ue8m0: bool +) -> tuple[torch.Tensor, torch.Tensor]: excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) sf = x_amax / 448.0 @@ -69,10 +75,12 @@ def _ref_fp8_mqa_logits( q = q.float() k = k.float() - mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - >= cu_seqlen_ks[:, None]) - mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :] - < cu_seqlen_ke[:, None]) + mask_lo = ( + torch.arange(0, seq_len_kv, 
device="cuda")[None, :] >= cu_seqlen_ks[:, None] + ) + mask_hi = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] + ) mask = mask_lo & mask_hi score = torch.einsum("mhd,and->hmn", q, k) @@ -84,14 +92,15 @@ def _ref_fp8_mqa_logits( @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="SM90 and SM100 only") +@pytest.mark.skipif( + not current_platform.has_device_capability(90), reason="SM90 and SM100 only" +) def test_deepgemm_fp8_mqa_logits(): torch.manual_seed(0) random.seed(0) num_heads, head_dim = 32, 128 - for seq_len in (512, ): - for seq_len_kv in (1024, ): + for seq_len in (512,): + for seq_len_kv in (1024,): for disable_cp in (False, True): q = torch.randn( seq_len, @@ -100,24 +109,23 @@ def test_deepgemm_fp8_mqa_logits(): device="cuda", dtype=torch.bfloat16, ) - kv = torch.randn(seq_len_kv, - head_dim, - device="cuda", - dtype=torch.bfloat16) - weights = torch.randn(seq_len, - num_heads, - device="cuda", - dtype=torch.float32) + kv = torch.randn( + seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16 + ) + weights = torch.randn( + seq_len, num_heads, device="cuda", dtype=torch.float32 + ) if disable_cp: ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") - ke = torch.arange(seq_len, dtype=torch.int, - device="cuda") + (seq_len_kv - seq_len) + ke = torch.arange(seq_len, dtype=torch.int, device="cuda") + ( + seq_len_kv - seq_len + ) else: ks, ke = _generate_cp_test_data(seq_len, seq_len_kv) q_fp8 = q.to(torch.float8_e4m3fn) - kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False) + kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) ref_logits = _ref_fp8_mqa_logits( @@ -157,11 +165,10 @@ def _ref_fp8_paged_mqa_logits( context_lens_list = context_lens.tolist() for i in range(batch_size): context_len = context_lens_list[i] - q_offsets = torch.arange(context_len - next_n, - context_len, - device="cuda") - weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose( - 0, 1).contiguous()) + q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") + weight_slice = ( + weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() + ) for block_rk in range(cdiv(context_len, block_size)): block_idx = block_tables[i][block_rk] qx, kx = q[i], kv_cache[block_idx] @@ -170,28 +177,30 @@ def _ref_fp8_paged_mqa_logits( (block_rk + 1) * block_size, device="cuda", ) - mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] - <= q_offsets[:, None]) + mask = (k_offsets[None, :] < context_len) & ( + k_offsets[None, :] <= q_offsets[:, None] + ) s = torch.where( mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( - logits.dtype), + logits.dtype + ), float("-inf"), ) s = torch.relu(s) * weight_slice[..., None] s = s.sum(dim=0) logits[ - i * next_n:(i + 1) * next_n, - block_rk * block_size:(block_rk + 1) * block_size, - ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, - float("-inf")) + i * next_n : (i + 1) * next_n, + block_rk * block_size : (block_rk + 1) * block_size, + ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")) return logits @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") -@pytest.mark.skipif(not 
current_platform.has_device_capability(90), - reason="SM90 and SM100 only") +@pytest.mark.skipif( + not current_platform.has_device_capability(90), reason="SM90 and SM100 only" +) def test_deepgemm_fp8_paged_mqa_logits(): torch.manual_seed(0) random.seed(0) @@ -199,7 +208,7 @@ def test_deepgemm_fp8_paged_mqa_logits(): max_model_len = 4096 for batch_size, next_n in [(4, 1), (2, 2)]: for heads, index_dim in [(32, 128)]: - for avg_kv in (2048, ): + for avg_kv in (2048,): num_blocks, blocksize = max_model_len * 2, 64 q = torch.randn( @@ -218,12 +227,14 @@ def test_deepgemm_fp8_paged_mqa_logits(): dtype=torch.float32, ) - context_lens = (torch.randint(int(0.8 * avg_kv), - int(1.2 * avg_kv), - (batch_size, )).cuda().to( - torch.int32)) - max_block_len = ((context_lens.max().item() + blocksize - 1) // - blocksize * blocksize) + context_lens = ( + torch.randint(int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,)) + .cuda() + .to(torch.int32) + ) + max_block_len = ( + (context_lens.max().item() + blocksize - 1) // blocksize * blocksize + ) block_tables = torch.zeros( (batch_size, max_block_len), device="cuda", @@ -243,7 +254,8 @@ def test_deepgemm_fp8_paged_mqa_logits(): kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) schedule_metadata = get_paged_mqa_logits_metadata( - context_lens, blocksize, get_num_sms()) + context_lens, blocksize, get_num_sms() + ) logits = fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, @@ -263,15 +275,18 @@ def test_deepgemm_fp8_paged_mqa_logits(): max_model_len, ) - positions = (torch.arange(max_model_len, - device="cuda").unsqueeze(0).expand( - batch_size * next_n, -1)) - row_indices = ( - torch.arange(batch_size * next_n, device="cuda") // next_n) + positions = ( + torch.arange(max_model_len, device="cuda") + .unsqueeze(0) + .expand(batch_size * next_n, -1) + ) + row_indices = torch.arange(batch_size * next_n, device="cuda") // next_n next_n_offset = ( - torch.arange(batch_size * next_n, device="cuda") % next_n) - mask = positions <= (context_lens[row_indices] - next_n + - next_n_offset).unsqueeze(1) + torch.arange(batch_size * next_n, device="cuda") % next_n + ) + mask = positions <= ( + context_lens[row_indices] - next_n + next_n_offset + ).unsqueeze(1) logits = logits.masked_fill(~mask, 0) ref_logits = ref_logits.masked_fill(~mask, 0) diff --git a/tests/kernels/attention/test_flash_attn.py b/tests/kernels/attention/test_flash_attn.py index 2544703f8bf9..d39f0a593ed4 100644 --- a/tests/kernels/attention/test_flash_attn.py +++ b/tests/kernels/attention/test_flash_attn.py @@ -7,10 +7,12 @@ import torch from vllm.platforms import current_platform -from vllm.vllm_flash_attn import (fa_version_unsupported_reason, - flash_attn_varlen_func, - flash_attn_with_kvcache, - is_fa_version_supported) +from vllm.vllm_flash_attn import ( + fa_version_unsupported_reason, + flash_attn_varlen_func, + flash_attn_with_kvcache, + is_fa_version_supported, +) NUM_HEADS = [(4, 4), (8, 2)] HEAD_SIZES = [128, 256] @@ -44,7 +46,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -62,10 +64,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask 
= ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -106,11 +111,15 @@ def test_flash_attn_with_paged_kv( ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): - pytest.skip(f"Flash attention version {fa_version} not supported due " - f"to: \"{fa_version_unsupported_reason(fa_version)}\"") + pytest.skip( + f"Flash attention version {fa_version} not supported due " + f'to: "{fa_version_unsupported_reason(fa_version)}"' + ) if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): - pytest.skip("Flash attention with quantized inputs is only " - "supported on version 3 with bfloat16 base type") + pytest.skip( + "Flash attention with quantized inputs is only " + "supported on version 3 with bfloat16 base type" + ) current_platform.seed_everything(0) num_seqs = len(kv_lens) @@ -119,23 +128,19 @@ def test_flash_attn_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_kv_len = max(kv_lens) scale = head_size**-0.5 - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) q = query.unsqueeze(1) out = torch.empty_like(q) if use_out else None @@ -180,23 +185,27 @@ def test_flash_attn_with_paged_kv( if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window) - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - ref_output))}" + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + sliding_window=sliding_window, + ) + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("use_out", [True, False]) -@pytest.mark.parametrize("seq_lens", - [[(1, 1328), (5, 18), - (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) +@pytest.mark.parametrize( + "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] +) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -222,11 +231,15 @@ def test_varlen_with_paged_kv( ) -> None: torch.set_default_device("cuda") if not is_fa_version_supported(fa_version): - pytest.skip(f"Flash attention version {fa_version} not supported due " - f"to: 
\"{fa_version_unsupported_reason(fa_version)}\"") + pytest.skip( + f"Flash attention version {fa_version} not supported due " + f'to: "{fa_version_unsupported_reason(fa_version)}"' + ) if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2): - pytest.skip("Flash attention with quantized inputs is only " - "supported on version 3 with bfloat16 base type") + pytest.skip( + "Flash attention with quantized inputs is only " + "supported on version 3 with bfloat16 base type" + ) current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] @@ -236,30 +249,23 @@ def test_varlen_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) out = torch.empty_like(query) if use_out else None @@ -315,5 +321,7 @@ def test_varlen_with_paged_kv( atol, rtol = 1.5e-2, 1e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py index a821a74aba93..52cd10fdc5be 100644 --- a/tests/kernels/attention/test_flashinfer.py +++ b/tests/kernels/attention/test_flashinfer.py @@ -38,7 +38,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -56,10 +56,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -101,20 +104,16 @@ def test_flashinfer_decode_with_paged_kv( query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_value_cache = 
torch.randn(NUM_BLOCKS, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] @@ -135,9 +134,9 @@ def test_flashinfer_decode_with_paged_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - use_tensor_cores=True) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, "NHD", use_tensor_cores=True + ) wrapper.plan( kv_indptr, kv_indices, @@ -155,17 +154,21 @@ def test_flashinfer_decode_with_paged_kv( output = wrapper.run(query, key_value_cache) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window) - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + sliding_window=sliding_window, + ) + ( + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) @@ -196,16 +199,10 @@ def test_flashinfer_prefill_with_paged_kv( max_kv_len = max(kv_lens) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_value_cache = torch.randn(NUM_BLOCKS, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) @@ -215,10 +212,9 @@ def test_flashinfer_prefill_with_paged_kv( value_cache /= head_size**0.5 max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) qo_indptr = [0] kv_indptr = [0] @@ -242,8 +238,7 @@ def test_flashinfer_prefill_with_paged_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD") + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") wrapper.plan( qo_indptr, kv_indptr, @@ -264,17 +259,21 @@ def test_flashinfer_prefill_with_paged_kv( key_value_cache, ) - ref_output = ref_paged_attn(query=query, - 
key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - sliding_window=sliding_window) - torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + sliding_window=sliding_window, + ) + ( + torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]]) @@ -284,9 +283,13 @@ def test_flashinfer_prefill_with_paged_kv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", SOFT_CAPS) def test_flashinfer_prefill_with_paged_fp8_kv( - seq_lens: list[tuple[int, int]], num_heads: tuple[int, int], - head_size: int, dtype: torch.dtype, block_size: int, - soft_cap: Optional[float]) -> None: + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: pytest.skip("TODO: fix the accuracy issue") torch.set_default_device("cuda") current_platform.seed_everything(0) @@ -301,17 +304,11 @@ def test_flashinfer_prefill_with_paged_fp8_kv( kv_cache_dtype = torch.float8_e4m3fn - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) NUM_BLOCKS_FP8 = 2048 - key_value_cache = torch.randn(NUM_BLOCKS_FP8, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) key_cache /= head_size**0.5 value_cache /= head_size**0.5 @@ -319,15 +316,15 @@ def test_flashinfer_prefill_with_paged_fp8_kv( k_scale = key_cache.amax().item() / 448.0 v_scale = value_cache.amax().item() / 448.0 - kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], - dim=1).to(kv_cache_dtype) + kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], dim=1).to( + kv_cache_dtype + ) - assert (kv_cache_fp8.shape == key_value_cache.shape) + assert kv_cache_fp8.shape == key_value_cache.shape max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS_FP8, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) qo_indptr = [0] kv_indptr = [0] @@ -351,8 +348,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD") + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") wrapper.plan( qo_indptr, kv_indptr, @@ -369,19 +365,23 @@ def test_flashinfer_prefill_with_paged_fp8_kv( output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache.squeeze(1), - value_cache=value_cache.squeeze(1), - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - 
scale=scale, - soft_cap=soft_cap) + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache.squeeze(1), + value_cache=value_cache.squeeze(1), + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + ) del query del block_tables # verify prefill fp8 - torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @@ -414,12 +414,9 @@ def test_flashinfer_decode_with_paged_fp8_kv( query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) NUM_BLOCKS_FP8 = 2048 - key_value_cache = torch.randn(NUM_BLOCKS_FP8, - 2, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) key_cache /= head_size**0.5 value_cache /= head_size**0.5 @@ -429,14 +426,13 @@ def test_flashinfer_decode_with_paged_fp8_kv( key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype) value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype) - assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1) + assert key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1 kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS_FP8, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] @@ -457,32 +453,38 @@ def test_flashinfer_decode_with_paged_fp8_kv( kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.\ - BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", - use_tensor_cores=use_tensor_cores) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - q_data_type=dtype, - kv_data_type=kv_cache_dtype, - logits_soft_cap=soft_cap) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=kv_cache_dtype, + logits_soft_cap=soft_cap, + ) output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) - ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap) + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + ) # Temporary fix: Increasing the tolerance. 
Seems like a flashinfer issue - torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/attention/test_flashinfer_mla_decode.py b/tests/kernels/attention/test_flashinfer_mla_decode.py index 02225432f77f..0350136677c6 100644 --- a/tests/kernels/attention/test_flashinfer_mla_decode.py +++ b/tests/kernels/attention/test_flashinfer_mla_decode.py @@ -13,34 +13,29 @@ if not current_platform.has_device_capability(100): pytest.skip( reason="FlashInfer MLA Requires compute capability of 10 or above.", - allow_module_level=True) + allow_module_level=True, + ) def ref_mla( - out: Tensor, # (bs, num_heads, v_head_dim) - query: Tensor, # (bs, num_heads, head_dim) - kv_cache: Tensor, # (num_blocks, block_size, head_dim) - scale: float, - block_tables: Tensor, # (bs, max_num_blocks) - seq_lens: Tensor, # (bs,) + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) ): bs, num_heads, v_head_dim = out.shape head_dim = query.shape[2] for i in range(bs): # gather and flatten KV-cache - kv = kv_cache[ - block_tables[i]] # (max_num_blocks, block_size, head_dim) - kv = kv.view(1, -1, - head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim) v = kv[:, :, :v_head_dim] q = query[i].view(num_heads, 1, head_dim) - o = F.scaled_dot_product_attention(q, - kv, - v, - scale=scale, - enable_gqa=True) + o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True) out[i] = o.view(num_heads, v_head_dim) return out @@ -50,7 +45,7 @@ def ref_mla( @pytest.mark.parametrize("bs", [1, 2, 4, 16]) @pytest.mark.parametrize("block_size", [32, 64]) def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int): - torch.set_default_device('cuda') + torch.set_default_device("cuda") torch.manual_seed(42) # Deepseek R1 config @@ -59,11 +54,11 @@ def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int): qk_nope_head_dim = 128 qk_rope_head_dim = 64 qk_head_dim = kv_lora_rank + qk_rope_head_dim - scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5 + scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5 MAX_SEQ_LEN = 1024 - seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)] + seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1,)).item() for _ in range(bs)] seq_lens[-1] = MAX_SEQ_LEN max_seq_len = max(seq_lens) seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) @@ -86,12 +81,12 @@ def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int): block_id = 0 for i in range(bs): num_blocks_needed = blocks_per_seq[i] - block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id + - num_blocks_needed] + block_tables[i, :num_blocks_needed] = all_block_ids[ + block_id : block_id + num_blocks_needed + ] block_id += num_blocks_needed - kv_cache = torch.randn(block_tables.numel(), block_size, - qk_head_dim).to(dtype) + kv_cache = torch.randn(block_tables.numel(), block_size, qk_head_dim).to(dtype) q = torch.randn(bs, num_heads, qk_head_dim).to(dtype) out_ref = q.new_zeros(bs, num_heads, kv_lora_rank) diff --git 
a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index bd3ba554b32e..62d94f0bb751 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -6,15 +6,18 @@ import pytest import torch -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) from vllm.platforms import current_platform from vllm.utils import round_up if not current_platform.is_device_capability(100): - pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", - allow_module_level=True) + pytest.skip( + "This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True + ) FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FP8_DTYPE = current_platform.fp8_dtype() @@ -64,8 +67,9 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @torch.inference_mode def test_flashinfer_trtllm_decode_with_baseline( dtype: torch.dtype, - quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], - Optional[torch.dtype]], + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], batch_size: int, max_seq_lens: tuple[int, int], num_heads: tuple[int, int], @@ -106,7 +110,7 @@ def test_flashinfer_trtllm_decode_with_baseline( q_scale = 1.0 ref_query = query - kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_kv_len seq_lens = kv_lens @@ -122,10 +126,9 @@ def test_flashinfer_trtllm_decode_with_baseline( k_scale = v_scale = kv_scale max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (batch_size, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] @@ -147,20 +150,23 @@ def test_flashinfer_trtllm_decode_with_baseline( # Baseline Decode wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( - workspace_buffer, kv_layout, use_tensor_cores=True) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_qo_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - sm_scale=sm_scale, - q_data_type=dtype, - kv_data_type=dtype, - window_left=window_left, - logits_soft_cap=soft_cap) + workspace_buffer, kv_layout, use_tensor_cores=True + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + window_left=window_left, + logits_soft_cap=soft_cap, + ) output = torch.empty(ref_query.shape, dtype=dtype) wrapper.run(ref_query, ref_kv_cache, out=output) @@ -169,17 +175,21 @@ def test_flashinfer_trtllm_decode_with_baseline( if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) elif o_quant_dtype == FP4_DTYPE: - o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(output.flatten(), dim=-1)).to(torch.float32) + o_sf_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1) + ).to(torch.float32) # TRTLLM Decode if o_quant_dtype == FP4_DTYPE: output_trtllm = flashinfer.utils.FP4Tensor( - torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), - 
dtype=torch.uint8), - torch.empty((round_up(query.shape[0], 128), - round_up(query.shape[1] * query.shape[2] // 16, 4)), - dtype=torch.float8_e4m3fn), + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), ) else: output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) @@ -201,13 +211,12 @@ def test_flashinfer_trtllm_decode_with_baseline( output_trtllm = output_trtllm.to(dtype) * o_scale elif o_quant_dtype == FP4_DTYPE: output_trtllm.data = output_trtllm.data.reshape( - -1, query.shape[1] * query.shape[2] // 2) - output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, - output_trtllm.scale, - o_sf_scale, dtype, - query.device) - output_trtllm = output_trtllm.reshape(-1, query.shape[1], - query.shape[2]) + -1, query.shape[1] * query.shape[2] // 2 + ) + output_trtllm = dequantize_nvfp4_to_dtype( + output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device + ) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2]) if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: rtol, atol = 3e-1, 1e0 @@ -216,8 +225,10 @@ def test_flashinfer_trtllm_decode_with_baseline( else: rtol, atol = 1e-2, 2e-2 - torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - output_trtllm))}" + ( + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - output_trtllm))}", + ) @pytest.mark.parametrize("dtype", DTYPE) @@ -233,8 +244,9 @@ def test_flashinfer_trtllm_decode_with_baseline( @torch.inference_mode def test_flashinfer_trtllm_prefill_with_baseline( dtype: torch.dtype, - quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype], - Optional[torch.dtype]], + quant_dtypes: tuple[ + Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] + ], batch_size: int, max_seq_lens: tuple[int, int], num_heads: tuple[int, int], @@ -270,17 +282,16 @@ def test_flashinfer_trtllm_prefill_with_baseline( else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32) + q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) q_lens[-1] = max_q_len - q_indptr = torch.cat([ - torch.tensor([0], dtype=torch.int32), - torch.cumsum(q_lens, dim=0, dtype=torch.int32), - ]) - - query = torch.randn(torch.sum(q_lens).item(), - num_qo_heads, - head_size, - dtype=dtype) + q_indptr = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(q_lens, dim=0, dtype=torch.int32), + ] + ) + + query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) if q_quant_dtype == FP8_DTYPE: query, q_scale = to_float8(query) ref_query = query.to(dtype) * q_scale @@ -288,7 +299,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( q_scale = 1.0 ref_query = query - kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32) + kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_kv_len seq_lens = kv_lens + q_lens @@ -304,10 +315,9 @@ def test_flashinfer_trtllm_prefill_with_baseline( k_scale = v_scale = kv_scale max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables = torch.randint(0, - NUM_BLOCKS, - (batch_size, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, NUM_BLOCKS, (batch_size, 
max_num_blocks_per_seq), dtype=torch.int32 + ) kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] @@ -329,21 +339,24 @@ def test_flashinfer_trtllm_prefill_with_baseline( # Baseline Prefill wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout) - wrapper.plan(q_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - num_qo_heads, - num_kv_heads, - head_size, - block_size, - causal=True, - sm_scale=sm_scale, - q_data_type=dtype, - kv_data_type=dtype, - window_left=window_left, - logits_soft_cap=soft_cap) + workspace_buffer, kv_layout + ) + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_size, + block_size, + causal=True, + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + window_left=window_left, + logits_soft_cap=soft_cap, + ) output = torch.empty(ref_query.shape, dtype=dtype) wrapper.run(ref_query, ref_kv_cache, out=output) @@ -352,17 +365,21 @@ def test_flashinfer_trtllm_prefill_with_baseline( if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) elif o_quant_dtype == FP4_DTYPE: - o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(output.flatten(), dim=-1)).to(torch.float32) + o_sf_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1) + ).to(torch.float32) # TRTLLM Prefill if o_quant_dtype == FP4_DTYPE: output_trtllm = flashinfer.utils.FP4Tensor( - torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), - dtype=torch.uint8), - torch.empty((round_up(query.shape[0], 128), - round_up(query.shape[1] * query.shape[2] // 16, 4)), - dtype=torch.float8_e4m3fn), + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), ) else: output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) @@ -388,13 +405,12 @@ def test_flashinfer_trtllm_prefill_with_baseline( output_trtllm = output_trtllm.to(dtype) * o_scale elif o_quant_dtype == FP4_DTYPE: output_trtllm.data = output_trtllm.data.reshape( - -1, query.shape[1] * query.shape[2] // 2) - output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, - output_trtllm.scale, - o_sf_scale, dtype, - query.device) - output_trtllm = output_trtllm.reshape(-1, query.shape[1], - query.shape[2]) + -1, query.shape[1] * query.shape[2] // 2 + ) + output_trtllm = dequantize_nvfp4_to_dtype( + output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device + ) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2]) if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: rtol, atol = 4e-1, 1e0 @@ -405,5 +421,7 @@ def test_flashinfer_trtllm_prefill_with_baseline( else: rtol, atol = 1e-2, 1e-2 - torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ - f"{torch.max(torch.abs(output - output_trtllm))}" + ( + torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - output_trtllm))}", + ) diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index bddd7e5c50ed..2b6fd38e4f58 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -7,30 +7,33 @@ import pytest import torch -from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) +from vllm.attention.ops.flashmla import ( + flash_mla_with_kvcache, + 
get_mla_metadata, + is_flashmla_supported, +) from vllm.triton_utils import triton -def cal_diff(x: torch.Tensor, - y: torch.Tensor, - name: str, - use_fp8: bool = False) -> None: +def cal_diff( + x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False +) -> None: x, y = x.double(), y.double() - cos_diff = 1 - 2 * (x * y).sum().item() / max( - (x * x + y * y).sum().item(), 1e-12) - if (use_fp8): + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + if use_fp8: assert cos_diff < 1e-4 else: assert cos_diff < 1e-5 -FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ - if not is_flashmla_supported()[0] else "FlashMLA is supported" + +FLASH_MLA_UNSUPPORTED_REASON = ( + is_flashmla_supported()[1] + if not is_flashmla_supported()[0] + else "FlashMLA is supported" +) -@pytest.mark.skipif(not is_flashmla_supported()[0], - reason=FLASH_MLA_UNSUPPORTED_REASON) +@pytest.mark.skipif(not is_flashmla_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) @pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) @@ -41,47 +44,49 @@ def cal_diff(x: torch.Tensor, @pytest.mark.parametrize("block_size", [64]) @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("torch_dtype", - [torch.bfloat16, torch.float16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "torch_dtype", [torch.bfloat16, torch.float16, torch.float8_e4m3fn] +) @torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, - varlen, torch_dtype): +def test_flash_mla( + b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype +): device = torch.device("cuda:0") - if torch_dtype == torch.float8_e4m3fn: - init_dtype = torch.bfloat16 - else: - init_dtype = torch_dtype + init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype torch.set_default_dtype(init_dtype) torch.set_default_device(device) torch.cuda.set_device(device) torch.manual_seed(0) random.seed(0) - print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " - f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") + print( + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}" + ) use_fp8 = torch_dtype == torch.float8_e4m3fn - cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) + cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) if varlen: for i in range(b): - cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), - s_q) + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) total_seqlens = cache_seqlens.sum().item() max_seqlen = cache_seqlens.max().item() max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 q = torch.randn(b, s_q, h_q, d) - block_table = torch.arange(b * max_seqlen_pad // block_size, - dtype=torch.int32).view( - b, max_seqlen_pad // block_size) + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32 + ).view(b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, - d)[i, cache_seqlens[i].item():] = float("nan") + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = ( + float("nan") + ) blocked_v = blocked_k[..., :dv] tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens, s_q * h_q // h_kv, h_kv) + cache_seqlens, s_q * h_q // h_kv, h_kv + ) 
init_dtype = q.dtype if use_fp8: @@ -97,16 +102,18 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, descale_k = None def flash_mla(): - return flash_mla_with_kvcache(q, - blocked_k, - block_table, - cache_seqlens, - dv, - tile_scheduler_metadata, - num_splits, - causal=causal, - descale_q=descale_q, - descale_k=descale_k) + return flash_mla_with_kvcache( + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, + descale_q=descale_q, + descale_k=descale_k, + ) def scaled_dot_product_attention(query, key, value, is_causal=False): query = query.float() @@ -119,8 +126,7 @@ def scaled_dot_product_attention(query, key, value, is_causal=False): s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, - dtype=torch.bool).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -130,10 +136,16 @@ def scaled_dot_product_attention(query, key, value, is_causal=False): def ref_mla(): q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q - blocked_k_ = (blocked_k.to(torch.float) * - descale_k).to(init_dtype) if use_fp8 else blocked_k - blocked_v_ = (blocked_v.to(torch.float) * - descale_k).to(init_dtype) if use_fp8 else blocked_v + blocked_k_ = ( + (blocked_k.to(torch.float) * descale_k).to(init_dtype) + if use_fp8 + else blocked_k + ) + blocked_v_ = ( + (blocked_v.to(torch.float) * descale_k).to(init_dtype) + if use_fp8 + else blocked_v + ) out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): @@ -156,8 +168,9 @@ def ref_mla(): t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + - b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( - b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", - f"{bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * ( + torch.finfo(torch_dtype).bits // 8 + ) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print( + f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s" + ) diff --git a/tests/kernels/attention/test_flashmla_sparse.py b/tests/kernels/attention/test_flashmla_sparse.py index 9036e4e7800b..562ae3009e41 100644 --- a/tests/kernels/attention/test_flashmla_sparse.py +++ b/tests/kernels/attention/test_flashmla_sparse.py @@ -13,6 +13,7 @@ def _cuda_sm90_available() -> bool: def test_sparse_flashmla_metadata_smoke(): import vllm.attention.ops.flashmla as fm + ok, reason = fm.is_flashmla_supported() if not ok or not _cuda_sm90_available(): pytest.skip(reason or "SM90 not available") @@ -27,18 +28,21 @@ def test_sparse_flashmla_metadata_smoke(): cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) - tile_md, num_splits = fm.get_mla_metadata(cache_seqlens, - q_seq_per_hk, - num_heads_k, - num_heads_q=num_heads_q, - topk=topk, - is_fp8_kvcache=True) + tile_md, num_splits = fm.get_mla_metadata( + cache_seqlens, + q_seq_per_hk, + num_heads_k, + num_heads_q=num_heads_q, + topk=topk, + is_fp8_kvcache=True, + ) assert tile_md.dtype == torch.int32 assert num_splits.dtype == torch.int32 def test_sparse_flashmla_decode_smoke(): import 
vllm.attention.ops.flashmla as fm + ok, reason = fm.is_flashmla_supported() if not ok or not _cuda_sm90_available(): pytest.skip(reason or "SM90 not available") @@ -58,36 +62,42 @@ def test_sparse_flashmla_decode_smoke(): q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k # q_heads_per_hk = num_heads_q // num_heads_k cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) - tile_md, num_splits = fm.get_mla_metadata(cache_seqlens, - q_seq_per_hk, - num_heads_k, - num_heads_q=num_heads_q, - topk=topk, - is_fp8_kvcache=True) + tile_md, num_splits = fm.get_mla_metadata( + cache_seqlens, + q_seq_per_hk, + num_heads_k, + num_heads_q=num_heads_q, + topk=topk, + is_fp8_kvcache=True, + ) # Inputs - q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k), - dtype=torch.bfloat16, - device=device) - k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token), - dtype=torch.uint8, - device=device) - indices = torch.zeros((batch_size, seqlen_q, topk), - dtype=torch.int32, - device=device) - - block_table = torch.zeros((batch_size, 128), - dtype=torch.int32, - device=device) - out, lse = fm.flash_mla_with_kvcache(q, - k_cache, - block_table, - cache_seqlens, - head_dim_v, - tile_md, - num_splits, - indices=indices, - is_fp8_kvcache=True) + q = torch.zeros( + (batch_size, seqlen_q, num_heads_q, head_dim_k), + dtype=torch.bfloat16, + device=device, + ) + k_cache = torch.zeros( + (1, page_block_size, num_heads_k, bytes_per_token), + dtype=torch.uint8, + device=device, + ) + indices = torch.zeros( + (batch_size, seqlen_q, topk), dtype=torch.int32, device=device + ) + + block_table = torch.zeros((batch_size, 128), dtype=torch.int32, device=device) + out, lse = fm.flash_mla_with_kvcache( + q, + k_cache, + block_table, + cache_seqlens, + head_dim_v, + tile_md, + num_splits, + indices=indices, + is_fp8_kvcache=True, + ) assert out.shape[0] == batch_size assert out.shape[-1] == head_dim_v assert lse.shape[0] == batch_size @@ -95,6 +105,7 @@ def test_sparse_flashmla_decode_smoke(): def test_sparse_flashmla_prefill_smoke(): import vllm.attention.ops.flashmla as fm + ok, reason = fm.is_flashmla_supported() if not ok or not _cuda_sm90_available(): pytest.skip(reason or "SM90 not available") @@ -112,8 +123,7 @@ def test_sparse_flashmla_prefill_smoke(): kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device) indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device) - out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, - d_v) + out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, d_v) assert out.shape == (s_q, h_q, d_v) assert max_logits.shape == (s_q, h_q) assert lse.shape == (s_q, h_q) diff --git a/tests/kernels/attention/test_lightning_attn.py b/tests/kernels/attention/test_lightning_attn.py index de45ee1ed5cc..ec938caff2c6 100644 --- a/tests/kernels/attention/test_lightning_attn.py +++ b/tests/kernels/attention/test_lightning_attn.py @@ -4,8 +4,7 @@ import pytest import torch -from vllm.model_executor.layers.lightning_attn import ( - linear_decode_forward_triton) +from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton from vllm.platforms import current_platform NUM_HEADS = [4, 8] @@ -17,8 +16,8 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history): """Reference implementation of lightning attention core algorithm - - The difference from the main implementation is that this processes + + The difference from the main implementation is that this processes 
each step sequentially, instead of using parallelized triton kernels """ B, H, S, D = q.shape @@ -34,10 +33,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history): # More efficient implementation # Convert decay factors to matrix form - if ed.dim() == 1: - decay = torch.exp(-ed).view(1, -1, 1, 1) - else: - decay = torch.exp(-ed) + decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed) for b in range(B): for step in range(S): @@ -62,8 +58,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history): # The actual implementation returns a tensor of shape [B, H, 2, D, E] # where dimension 2 contains both KV and KV history kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E] - final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], - dim=2) # [B, H, 2, D, E] + final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], dim=2) # [B, H, 2, D, E] return output, final_kv_cache @@ -109,7 +104,7 @@ def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx): out_h = torch.matmul(q_bh, kv_new) # Update output and cache - output[b, h * D:(h + 1) * D] = out_h + output[b, h * D : (h + 1) * D] = out_h kv_caches[b, h] = kv_new return output @@ -135,12 +130,9 @@ def test_linear_decode_forward_triton( k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) - kv_caches = base * torch.randn(batch_size, - num_heads, - head_size, - head_size, - dtype=dtype, - device="cuda") + kv_caches = base * torch.randn( + batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + ) kv_caches_copy = kv_caches.clone() @@ -150,15 +142,14 @@ def test_linear_decode_forward_triton( slot_idx = torch.arange(batch_size, device="cuda") - triton_output = linear_decode_forward_triton(q, k, v, kv_caches, - slope_rate, slot_idx) + triton_output = linear_decode_forward_triton( + q, k, v, kv_caches, slope_rate, slot_idx + ) - reference_output = reference_linear_decode(q, k, v, kv_caches_copy, - slope_rate, slot_idx) - torch.testing.assert_close(triton_output, - reference_output, - rtol=1e-1, - atol=1e-1) + reference_output = reference_linear_decode( + q, k, v, kv_caches_copy, slope_rate, slot_idx + ) + torch.testing.assert_close(triton_output, reference_output, rtol=1e-1, atol=1e-1) torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1) assert triton_output.shape == (batch_size, num_heads * head_size) @@ -184,12 +175,9 @@ def test_linear_decode_forward_triton_with_padding( k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype) - kv_caches = base * torch.randn(batch_size, - num_heads, - head_size, - head_size, - dtype=dtype, - device="cuda") + kv_caches = base * torch.randn( + batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + ) kv_caches_copy = kv_caches.clone() @@ -199,14 +187,15 @@ def test_linear_decode_forward_triton_with_padding( slot_idx = torch.tensor([0, 1, -1, 2], device="cuda") - triton_output = linear_decode_forward_triton(q, k, v, kv_caches, - slope_rate, slot_idx) + triton_output = linear_decode_forward_triton( + q, k, v, kv_caches, slope_rate, slot_idx + ) - reference_output = reference_linear_decode(q, k, v, kv_caches_copy, - slope_rate, slot_idx) + reference_output = reference_linear_decode( + q, k, v, kv_caches_copy, slope_rate, slot_idx + ) - padding_mask = (slot_idx - != -1).unsqueeze(1).expand(-1, num_heads * head_size) + 
padding_mask = (slot_idx != -1).unsqueeze(1).expand(-1, num_heads * head_size) triton_masked = triton_output[padding_mask] reference_masked = reference_output[padding_mask] @@ -217,15 +206,11 @@ def test_linear_decode_forward_triton_with_padding( for i in range(batch_size): if valid_indices[i] > 0: - torch.testing.assert_close(kv_caches[i], - kv_caches_copy[i], - rtol=rtol, - atol=atol) + torch.testing.assert_close( + kv_caches[i], kv_caches_copy[i], rtol=rtol, atol=atol + ) - torch.testing.assert_close(triton_masked, - reference_masked, - rtol=rtol, - atol=atol) + torch.testing.assert_close(triton_masked, reference_masked, rtol=rtol, atol=atol) assert triton_output.shape == (batch_size, num_heads * head_size) @@ -249,39 +234,33 @@ def test_lightning_attention_reference( current_platform.seed_everything(42) base = 0.01 - q = base * torch.randn( - batch_size, num_heads, seq_len, head_size, dtype=dtype) - k = base * torch.randn( - batch_size, num_heads, seq_len, head_size, dtype=dtype) - v = base * torch.randn( - batch_size, num_heads, seq_len, head_size, dtype=dtype) + q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) + k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) + v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype) ed = torch.zeros(num_heads, device="cuda") for h in range(num_heads): ed[h] = 0.1 * (h + 1) - kv_history = base * torch.randn(batch_size, - num_heads, - head_size, - head_size, - dtype=dtype, - device="cuda") + kv_history = base * torch.randn( + batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda" + ) kv_history_clone = kv_history.clone() ref_output, ref_kv_cache = reference_lightning_attention( - q, k, v, ed, 256, kv_history) + q, k, v, ed, 256, kv_history + ) from vllm.model_executor.layers.lightning_attn import lightning_attention + actual_output, actual_kv_cache = lightning_attention( - q, k, v, ed, 256, kv_history_clone) + q, k, v, ed, 256, kv_history_clone + ) atol, rtol = 1.5e-1, 1.5e-1 torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol) - torch.testing.assert_close(ref_kv_cache, - actual_kv_cache, - rtol=rtol, - atol=atol) + torch.testing.assert_close(ref_kv_cache, actual_kv_cache, rtol=rtol, atol=atol) assert ref_output.shape == (batch_size, num_heads, seq_len, head_size) assert ref_kv_cache.shape == actual_kv_cache.shape diff --git a/tests/kernels/attention/test_merge_attn_states.py b/tests/kernels/attention/test_merge_attn_states.py index 9d1a301ebe30..eb9204dfaf15 100644 --- a/tests/kernels/attention/test_merge_attn_states.py +++ b/tests/kernels/attention/test_merge_attn_states.py @@ -7,19 +7,20 @@ from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda from vllm.attention.ops.triton_merge_attn_states import ( - merge_attn_states as merge_attn_states_triton) + merge_attn_states as merge_attn_states_triton, +) from vllm.platforms import current_platform # Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 # can be used to combine partial attention results (in the split-KV case) def merge_attn_states_torch( - output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] - suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] - output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] + output: 
torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] + output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS] ): p_lse = prefix_lse s_lse = suffix_lse @@ -32,15 +33,13 @@ def merge_attn_states_torch( s_lse = s_lse - max_lse p_lse_exp = torch.exp(p_lse) s_lse_exp = torch.exp(s_lse) - out_se = (p_lse_exp + s_lse_exp) + out_se = p_lse_exp + s_lse_exp if output_lse is not None: output_lse = torch.log(out_se) + max_lse p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] - p_scale = torch.transpose(p_scale, 0, - 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] - s_scale = torch.transpose(s_scale, 0, - 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] + s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] output = prefix_output * p_scale + suffix_output * s_scale return output, output_lse @@ -55,8 +54,10 @@ def merge_attn_states_torch( def generate_markdown_table(): global all_case_info - table_header = ("| tokens | heads | headsize | dtype " - "| device | torch | triton | cuda | speedup |") + table_header = ( + "| tokens | heads | headsize | dtype " + "| device | torch | triton | cuda | speedup |" + ) table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |" def shortly_dtype(dtype: torch.dtype) -> str: @@ -68,16 +69,26 @@ def shortly_device(device: str) -> str: print(table_header) print(table_separator) for info in all_case_info: - (num_tokens, num_heads, head_size, dtype, device, - avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel, - performance_improved) = info + ( + num_tokens, + num_heads, + head_size, + dtype, + device, + avg_time_torch_kernel, + avg_time_triton_kernel, + avg_time_cuda_kernel, + performance_improved, + ) = info dtype = shortly_dtype(dtype) device = shortly_device(device) - print(f"| {num_tokens} | {num_heads} | {head_size} " - f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms " - f"| {avg_time_triton_kernel:.5f}ms " - f"| {avg_time_cuda_kernel:.5f}ms " - f"| {performance_improved:.4f}x |") + print( + f"| {num_tokens} | {num_heads} | {head_size} " + f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms " + f"| {avg_time_triton_kernel:.5f}ms " + f"| {avg_time_cuda_kernel:.5f}ms " + f"| {performance_improved:.4f}x |" + ) @pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS) @@ -85,29 +96,28 @@ def shortly_device(device: str) -> str: @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("output_dtype", DTYPES) @torch.inference_mode() -def test_merge_attn_states(num_tokens: int, num_query_heads: int, - head_size: int, output_dtype: torch.dtype): +def test_merge_attn_states( + num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype +): if not current_platform.is_cuda(): - pytest.skip('Currently only support compare triton merge_attn_states ' - 'with custom cuda merge_attn_states kernel') + pytest.skip( + "Currently only support compare triton merge_attn_states " + "with custom cuda merge_attn_states kernel" + ) NUM_TOKENS = num_tokens NUM_HEADS = num_query_heads HEAD_SIZE = head_size - print(f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " - f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: 
{output_dtype}, " - f"Device: {current_platform.get_device_name()}") + print( + f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " + f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, " + f"Device: {current_platform.get_device_name()}" + ) # prefix_lse and suffix_lse contain inf and normal values - prefix_lse = torch.randn(NUM_HEADS, - NUM_TOKENS, - dtype=torch.float32, - device="cuda") - suffix_lse = torch.randn(NUM_HEADS, - NUM_TOKENS, - dtype=torch.float32, - device="cuda") + prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda") + suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda") # Generate boolean masks mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1 @@ -117,23 +127,23 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, mask_prefix = torch.logical_and(mask_prefix, ~combined_mask) mask_suffix = torch.logical_and(mask_suffix, ~combined_mask) - prefix_lse[mask_prefix] = float('inf') - suffix_lse[mask_suffix] = float('inf') + prefix_lse[mask_prefix] = float("inf") + suffix_lse[mask_suffix] = float("inf") # Other input tensors (need to be initialized but # no actual calculation needed) - output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), - dtype=output_dtype, - device="cuda") - output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS), - dtype=torch.float32, - device="cuda") - prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), - dtype=output_dtype, - device="cuda") - suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE), - dtype=output_dtype, - device="cuda") + output = torch.zeros( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) + output_lse = torch.zeros( + (NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda" + ) + prefix_output = torch.randn( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) + suffix_output = torch.randn( + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + ) warmup_times = 2 repeat_times = 20 @@ -149,15 +159,25 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, suffix_lse_torch = suffix_lse.clone() for _ in range(warmup_times): output_torch, output_lse_torch = merge_attn_states_torch( - output_torch, prefix_output, prefix_lse_torch, suffix_output, - suffix_lse_torch, output_lse_torch) + output_torch, + prefix_output, + prefix_lse_torch, + suffix_output, + suffix_lse_torch, + output_lse_torch, + ) torch.cuda.synchronize() for _ in range(repeat_times): start.record() output_torch, output_lse_torch = merge_attn_states_torch( - output_torch, prefix_output, prefix_lse_torch, suffix_output, - suffix_lse_torch, output_lse_torch) + output_torch, + prefix_output, + prefix_lse_torch, + suffix_output, + suffix_lse_torch, + output_lse_torch, + ) end.record() torch.cuda.synchronize() total_time_torch_kernel += start.elapsed_time(end) @@ -173,16 +193,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, end = torch.cuda.Event(enable_timing=True) for _ in range(warmup_times): - merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse, - suffix_output, suffix_lse, - output_lse_ref_triton) + merge_attn_states_triton( + output_ref_triton, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_ref_triton, + ) torch.cuda.synchronize() for _ in range(repeat_times): start.record() - merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse, - suffix_output, suffix_lse, - output_lse_ref_triton) + merge_attn_states_triton( + 
output_ref_triton, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_ref_triton, + ) end.record() torch.cuda.synchronize() total_time_triton_kernel += start.elapsed_time(end) @@ -195,14 +225,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, output_lse_cuda = output_lse.clone() for _ in range(warmup_times): - merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse_cuda) + merge_attn_states_cuda( + output_cuda, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_cuda, + ) torch.cuda.synchronize() for _ in range(repeat_times): start.record() - merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse_cuda) + merge_attn_states_cuda( + output_cuda, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + output_lse_cuda, + ) end.record() torch.cuda.synchronize() total_time_cuda_kernel += start.elapsed_time(end) @@ -213,8 +255,10 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int, performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel print(f" Torch time: {avg_time_torch_kernel:.6f}ms") print(f"Triton time: {avg_time_triton_kernel:.6f}ms") - print(f" CUDA time: {avg_time_cuda_kernel:.6f}ms, " - f"Performance: {performance_improved:.5f}x") + print( + f" CUDA time: {avg_time_cuda_kernel:.6f}ms, " + f"Performance: {performance_improved:.5f}x" + ) print("-" * 100) # 4. Correctness compare @@ -232,35 +276,45 @@ def diff(a: torch.Tensor, b: torch.Tensor): # states operation. output_ref = output_ref_triton output_lse_ref = output_lse_ref_triton - torch.testing.assert_close(output_cuda.float(), - output_ref.float(), - atol=1e-3, - rtol=rtol) + torch.testing.assert_close( + output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol + ) print("Output all match, max abs diff:") print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}") print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}") print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}") print("-" * 100) - torch.testing.assert_close(output_lse_cuda.float(), - output_lse_ref.float(), - atol=1e-3, - rtol=rtol) + torch.testing.assert_close( + output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol + ) print("Output LSE all match, max abs diff:") print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}") print(f" (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}") print(f" (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}") print("-" * 100) - print("All output values test passed! All inf values " - "are correctly replaced with -inf.") + print( + "All output values test passed! All inf values " + "are correctly replaced with -inf." 
+ ) print("-" * 100) device = current_platform.get_device_name() all_case_info.append( - (NUM_TOKENS, NUM_HEADS, HEAD_SIZE, output_dtype, device, - avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel, - performance_improved)) - if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * - len(NUM_QUERY_HEADS) * len(DTYPES)): + ( + NUM_TOKENS, + NUM_HEADS, + HEAD_SIZE, + output_dtype, + device, + avg_time_torch_kernel, + avg_time_triton_kernel, + avg_time_cuda_kernel, + performance_improved, + ) + ) + if len(all_case_info) == ( + len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES) + ): generate_markdown_table() diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index cea08e19f52d..14d1618bca3c 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -5,6 +5,7 @@ * Tests for MultiHeadAttention layer """ + from unittest.mock import patch import pytest @@ -21,11 +22,11 @@ @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. - """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() # Clear xformers availability cache import vllm.attention.layer as layer_module + layer_module.USE_XFORMERS_OPS = None @@ -37,49 +38,63 @@ def test_mha_attn_platform(device: str): torch.set_default_dtype(torch.float16) if device == "cpu": - with patch("vllm.attention.layer.current_platform", CpuPlatform()), \ - patch("vllm.model_executor.models.vision.current_platform", - CpuPlatform()): + with ( + patch("vllm.attention.layer.current_platform", CpuPlatform()), + patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()), + ): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA elif device == "hip": - with patch("vllm.attention.layer.current_platform", RocmPlatform()), \ - patch("vllm.model_executor.models.vision.current_platform", - RocmPlatform()): + with ( + patch("vllm.attention.layer.current_platform", RocmPlatform()), + patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()), + ): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA else: # Test CUDA with head_size=64 (divisible by 32) # - should use vLLM's FlashAttention - with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ - patch("vllm.model_executor.models.vision.current_platform", - CudaPlatform()): + with ( + patch("vllm.attention.layer.current_platform", CudaPlatform()), + patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), + ): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.FLASH_ATTN # Test CUDA with head_size=72 (not divisible by 32) # - with upstream FA not available # - should use xformers - with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ - patch("vllm.model_executor.models.vision.current_platform", - CudaPlatform()), \ - patch("vllm.attention.layer.check_upstream_fa_availability", - return_value=False): + with ( + patch("vllm.attention.layer.current_platform", CudaPlatform()), + patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), + patch( + "vllm.attention.layer.check_upstream_fa_availability", + return_value=False, + ), + ): attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.XFORMERS # Test CUDA with head_size=72 (not divisible 
by 32) # - with upstream FA available # - should use upstream FA - with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ - patch("vllm.model_executor.models.vision.current_platform", - CudaPlatform()), \ - patch("vllm.attention.layer.check_upstream_fa_availability", - return_value=True), \ - patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (), - { - 'flash_attn_varlen_func': lambda *args, **kwargs: None - })()}): + with ( + patch("vllm.attention.layer.current_platform", CudaPlatform()), + patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), + patch( + "vllm.attention.layer.check_upstream_fa_availability", return_value=True + ), + patch.dict( + "sys.modules", + { + "flash_attn": type( + "MockFlashAttn", + (), + {"flash_attn_varlen_func": lambda *args, **kwargs: None}, + )() + }, + ), + ): attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.FLASH_ATTN @@ -108,9 +123,11 @@ def ref_attention( NUM_KV_HEADS = [1] HEAD_SIZES = [64, 80] # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} -DTYPES = [ - torch.half, torch.bfloat16, torch.float -] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] +DTYPES = ( + [torch.half, torch.bfloat16, torch.float] + if not current_platform.is_rocm() + else [torch.half, torch.bfloat16] +) CUDA_DEVICES = ["cuda"] @@ -138,10 +155,9 @@ def test_mha_attn_forward( k = torch.randn(batch_size, seq_len, num_kv_heads * head_size) v = torch.randn(batch_size, seq_len, num_kv_heads * head_size) scale = 1.0 / head_size**0.5 - attn = MultiHeadAttention(num_heads, - head_size, - scale=scale, - num_kv_heads=num_kv_heads) + attn = MultiHeadAttention( + num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads + ) output = attn(q, k, v) assert num_heads % num_kv_heads == 0 diff --git a/tests/kernels/attention/test_mla_decode_cpu.py b/tests/kernels/attention/test_mla_decode_cpu.py index f8b307c595de..44f3e42e8714 100644 --- a/tests/kernels/attention/test_mla_decode_cpu.py +++ b/tests/kernels/attention/test_mla_decode_cpu.py @@ -11,30 +11,24 @@ def ref_mla( - out: Tensor, # (bs, num_heads, v_head_dim) - query: Tensor, # (bs, num_heads, head_dim) - kv_cache: Tensor, # (num_blocks, block_size, head_dim) - scale: float, - block_tables: Tensor, # (bs, max_num_blocks) - seq_lens: Tensor, # (bs,) + out: Tensor, # (bs, num_heads, v_head_dim) + query: Tensor, # (bs, num_heads, head_dim) + kv_cache: Tensor, # (num_blocks, block_size, head_dim) + scale: float, + block_tables: Tensor, # (bs, max_num_blocks) + seq_lens: Tensor, # (bs,) ): bs, num_heads, v_head_dim = out.shape head_dim = query.shape[2] for i in range(bs): # gather and flatten KV-cache - kv = kv_cache[ - block_tables[i]] # (max_num_blocks, block_size, head_dim) - kv = kv.view(1, -1, - head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) + kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim) + kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim) v = kv[:, :, :v_head_dim] q = query[i].view(num_heads, 1, head_dim) - o = F.scaled_dot_product_attention(q, - kv, - v, - scale=scale, - enable_gqa=True) + o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True) out[i] = o.view(num_heads, v_head_dim) return out @@ -63,18 +57,17 @@ def test_mla_decode_cpu( torch.set_default_dtype(dtype) torch.manual_seed(0) - scale = d**(-0.5) + scale = d ** (-0.5) if varlen: seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) seq_lens = 
seq_lens.clip(2).to(torch.int32) else: - seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) + seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32) max_seq_len = seq_lens.max().item() seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary? q = torch.randn(bs, h_q, d) - block_table = torch.arange(bs * seqlen_pad // block_size, - dtype=torch.int32) + block_table = torch.arange(bs * seqlen_pad // block_size, dtype=torch.int32) block_table = block_table.view(bs, seqlen_pad // block_size) kv_cache = torch.randn(block_table.numel(), block_size, d) @@ -82,8 +75,7 @@ def test_mla_decode_cpu( kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan") out_mla = q.new_zeros(bs, h_q, dv) - ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, - seq_lens) + ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, seq_lens) out_ref = q.new_zeros(bs, h_q, dv) ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) diff --git a/tests/kernels/attention/test_pack_unpack_triton.py b/tests/kernels/attention/test_pack_unpack_triton.py index 20c0b262b479..d2aa14738d9d 100644 --- a/tests/kernels/attention/test_pack_unpack_triton.py +++ b/tests/kernels/attention/test_pack_unpack_triton.py @@ -39,7 +39,7 @@ def test_pack_seq_basic_fp8(): start_idx = sum(lengths_list[:b]) seq_len = lengths_list[b] - expected_data = x[start_idx:start_idx + seq_len].to(torch.float32) + expected_data = x[start_idx : start_idx + seq_len].to(torch.float32) actual_data = packed[b, :seq_len].to(torch.float32) assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) @@ -62,7 +62,7 @@ def test_pack_seq_custom_padding_fp8(): # Check valid data for b in range(B): start_idx = b * 10 - expected_data = x[start_idx:start_idx + 10].to(torch.float32) + expected_data = x[start_idx : start_idx + 10].to(torch.float32) actual_data = result[b, :10].to(torch.float32) assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) @@ -73,9 +73,7 @@ def test_pack_seq_custom_padding_fp8(): elif pad_value > 0: assert torch.all(padded_data > 50) # Large positive values else: - assert torch.allclose(padded_data, - torch.zeros_like(padded_data), - atol=1e-2) + assert torch.allclose(padded_data, torch.zeros_like(padded_data), atol=1e-2) def test_pack_seq_default_negative_inf_padding_fp8(): @@ -93,7 +91,8 @@ def test_pack_seq_default_negative_inf_padding_fp8(): # Check that padding is large negative values (fp8 representation of -inf) padded_data = result[:, 10:].to(torch.float32) assert torch.all( - padded_data < -100) # fp8 -inf is represented as large negative number + padded_data < -100 + ) # fp8 -inf is represented as large negative number def test_pack_seq_edge_cases_fp8(): @@ -142,7 +141,7 @@ def test_pack_seq_different_block_sizes_fp8(): # Check that valid data is preserved (within fp8 precision) for b in range(B): start_idx = b * 25 - expected_data = x[start_idx:start_idx + 25].to(torch.float32) + expected_data = x[start_idx : start_idx + 25].to(torch.float32) actual_data = result[b, :25].to(torch.float32) assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) @@ -198,10 +197,7 @@ def test_pack_unpack_roundtrip_fp8(): # Unpack without explicit start locations (computed in kernel) unpacked_with_loc = unpack_seq_triton(packed, lengths) - assert_close(x_f32, - unpacked_with_loc.to(torch.float32), - rtol=1e-3, - atol=1e-2) + assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-3, atol=1e-2) def test_unpack_seq_triton_edge_cases_fp8(): @@ -216,10 +212,7 @@ def 
test_unpack_seq_triton_edge_cases_fp8(): packed = pack_seq_triton(x, lengths) unpacked = unpack_seq_triton(packed, lengths) assert unpacked.shape == x.shape - assert_close(x.to(torch.float32), - unpacked.to(torch.float32), - rtol=1e-1, - atol=1e-2) + assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) # Test with very short sequences x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1 @@ -228,10 +221,9 @@ def test_unpack_seq_triton_edge_cases_fp8(): packed = pack_seq_triton(x, lengths) unpacked = unpack_seq_triton(packed, lengths) # Only compare the first 3 elements that were actually packed - assert_close(x[:3].to(torch.float32), - unpacked.to(torch.float32), - rtol=1e-1, - atol=1e-2) + assert_close( + x[:3].to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2 + ) x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 x = x.to(dtype=dtype) @@ -239,7 +231,4 @@ def test_unpack_seq_triton_edge_cases_fp8(): packed = pack_seq_triton(x, lengths) unpacked = unpack_seq_triton(packed, lengths) assert unpacked.shape == x.shape - assert_close(x.to(torch.float32), - unpacked.to(torch.float32), - rtol=1e-1, - atol=1e-2) + assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index 0695f84aea1a..5ff2624cd7a4 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -12,8 +12,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask from tests.kernels.utils import make_alibi_bias -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) +from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -22,9 +21,7 @@ NUM_QUERIES_PER_KV = [1, 64] HEAD_SIZES = [24, 128] DTYPES = [torch.float16] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] SLIDING_WINDOW = [0, 16, 2048] KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] @@ -50,12 +47,10 @@ def test_contexted_kv_attention( device: str, op: Callable, ) -> None: - - if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( - 89): + if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): pytest.skip( - 'Triton limitation: fp8e4nv data type is not supported on CUDA' - ' arch < 89') + "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" + ) current_platform.seed_everything(0) torch.set_default_device(device) @@ -93,38 +88,29 @@ def test_contexted_kv_attention( cache_dtype = dtype else: cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) + k_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) + v_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, 
dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) + block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.long), - dim=0) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + ) for i in range(BS): for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: @@ -135,61 +121,71 @@ def test_contexted_kv_attention( end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) + k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc] + ) + v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc] + ) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() + k_cache = ( + k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_kv_heads, - head_size).permute(0, 2, 3, 1).contiguous() + v_cache = ( + v_cache.view(-1, block_size, num_kv_heads, head_size) + .permute(0, 2, 3, 1) + .contiguous() + ) k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + sliding_window=sliding_window, + ) torch.cuda.synchronize() start_time = time.time() - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - sliding_window=sliding_window) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + 
v_scale, + sliding_window=sliding_window, + ) torch.cuda.synchronize() end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) @@ -201,22 +197,24 @@ def test_contexted_kv_attention( # heads. # # see also: vllm/model_executor/layers/attention.py - query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, - query.shape[-1]) - key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, - num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], num_kv_heads, - num_queries_per_kv, value.shape[-1]) + query = query.view( + query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1] + ) + key = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + value = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] + ) query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - query_lens, seq_lens) + query_lens, seq_lens + ) if sliding_window > 0: - attn_bias = attn_bias.make_local_attention_from_bottomright( - sliding_window) + attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window) output_ref = xops.memory_efficient_attention_forward( query, key, @@ -239,7 +237,7 @@ def test_contexted_kv_attention( ) torch.cuda.synchronize() end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms") output_ref = output_ref.reshape(output.shape) atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) @@ -262,12 +260,10 @@ def test_contexted_kv_attention_alibi( device: str, op: Callable, ) -> None: - - if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( - 89): + if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89): pytest.skip( - 'Triton limitation: fp8e4nv data type is not supported on CUDA' - ' arch < 89') + "Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89" + ) current_platform.seed_everything(0) torch.set_default_device(device) @@ -280,9 +276,9 @@ def test_contexted_kv_attention_alibi( def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) @@ -290,17 +286,16 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + start=1, 
end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes alibi_slopes = _get_alibi_slopes(num_heads).to(device) @@ -328,38 +323,29 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: cache_dtype = dtype else: cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] - k_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) - v_cache = torch.zeros(cache_size, - block_size, - num_kv_heads, - head_size, - dtype=cache_dtype) + k_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) + v_cache = torch.zeros( + cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype + ) k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) + block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.long), - dim=0) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long), - dim=0) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0 + ) for i in range(BS): for j in range(query_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: @@ -370,82 +356,90 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_kv_heads, - head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) + k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc] + ) + v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc] + ) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() + k_cache = ( + k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8) + .permute(0, 2, 3, 1, 4) + .contiguous() + ) # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_kv_heads, - head_size).permute(0, 2, 3, 1).contiguous() + v_cache = ( + v_cache.view(-1, block_size, num_kv_heads, head_size) + .permute(0, 2, 3, 1) + .contiguous() + ) k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, 
device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes, + ) torch.cuda.synchronize() start_time = time.time() - op(query, - k, - v, - output, - kv_cache_dtype, - k_cache, - v_cache, - block_table, - b_start_loc, - b_seq_len, - MAX_CTX_LEN, - max_input_len, - k_scale, - v_scale, - alibi_slopes=alibi_slopes) + op( + query, + k, + v, + output, + kv_cache_dtype, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + MAX_CTX_LEN, + max_input_len, + k_scale, + v_scale, + alibi_slopes=alibi_slopes, + ) torch.cuda.synchronize() end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms") scale = float(1.0 / (head_size**0.5)) # NOTE(DefTruth): In order to reuse _make_alibi_bias function, # we have to pad query tensor before MQA/GQA expanding. if query.shape[0] != key.shape[0]: - query_pad = torch.empty(sum(seq_lens), - num_heads, - head_size, - dtype=dtype) + query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype) query_pad.uniform_(-1e-3, 1e-3) seq_start = 0 query_start = 0 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len - query_pad[seq_start:seq_end, ...] = torch.cat([ - torch.zeros( - seq_len - query_len, num_heads, head_size, dtype=dtype), - query[query_start:query_end, ...] - ], - dim=0) + query_pad[seq_start:seq_end, ...] = torch.cat( + [ + torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...], + ], + dim=0, + ) seq_start += seq_len query_start += query_len query = query_pad @@ -456,11 +450,12 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # heads. # # see also: vllm/model_executor/layers/attention.py - key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, - num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], num_kv_heads, - num_queries_per_kv, value.shape[-1]) + key = key[:, :, None, :].expand( + key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1] + ) + value = value[:, :, None, :].expand( + value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1] + ) # [seq, num_kv_heads, num_queries_per_kv, dk]=> # [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the # codebase. We save some time reshaping alibi matrix at runtime. 
@@ -483,24 +478,23 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
         seq_end = seq_start + seq_len
         query_end = query_start + query_len
-        out = xops.memory_efficient_attention_forward(query[:,
-                                                            seq_start:seq_end],
-                                                      key[:,
-                                                          seq_start:seq_end],
-                                                      value[:,
-                                                            seq_start:seq_end],
-                                                      attn_bias=attn_bias[i],
-                                                      p=0.0,
-                                                      scale=scale)
+        out = xops.memory_efficient_attention_forward(
+            query[:, seq_start:seq_end],
+            key[:, seq_start:seq_end],
+            value[:, seq_start:seq_end],
+            attn_bias=attn_bias[i],
+            p=0.0,
+            scale=scale,
+        )
         out = out.view_as(query[:, seq_start:seq_end]).view(
-            seq_len, num_heads, head_size)
-        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
-                                                         ...])
+            seq_len, num_heads, head_size
+        )
+        output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
         seq_start += seq_len
         query_start += query_len
     torch.cuda.synchronize()
     end_time = time.time()
-    print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
+    print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
     atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
     torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
 
@@ -532,9 +526,16 @@ def test_contexted_kv_attention_f32(
     device: str,
     op: Callable,
 ) -> None:
-    test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
-                                sliding_window, dtype, kv_cache_dtype, device,
-                                op)
+    test_contexted_kv_attention(
+        num_heads,
+        num_queries_per_kv,
+        head_size,
+        sliding_window,
+        dtype,
+        kv_cache_dtype,
+        device,
+        op,
+    )
 
 
 @pytest.mark.optional
@@ -555,5 +556,6 @@ def test_contexted_kv_attention_alibi_f32(
     device: str,
     op: Callable,
 ) -> None:
-    test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
-                                      dtype, kv_cache_dtype, device, op)
+    test_contexted_kv_attention_alibi(
+        num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
+    )
diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py
index a5b4bddaf475..a59230528770 100644
--- a/tests/kernels/attention/test_rocm_attention_selector.py
+++ b/tests/kernels/attention/test_rocm_attention_selector.py
@@ -11,8 +11,7 @@
 
 @pytest.fixture(autouse=True)
 def clear_cache():
-    """Clear lru cache to ensure each test case runs without caching.
-    """
+    """Clear lru cache to ensure each test case runs without caching."""
     _cached_get_attn_backend.cache_clear()
 
 
@@ -22,46 +21,29 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
 
         # Set the current platform to ROCm using monkeypatch
-        monkeypatch.setattr("vllm.attention.selector.current_platform",
-                            RocmPlatform())
+        monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
 
         # Test standard ROCm attention
         backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
-        assert (backend.get_name() == "ROCM_FLASH"
-                or backend.get_name() == "TRITON_ATTN")
+        assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
 
         # MLA test for deepseek related
 
         # change the attention backend to triton MLA
         m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
-        backend = get_attn_backend(576,
-                                   torch.bfloat16,
-                                   "auto",
-                                   16,
-                                   False,
-                                   use_mla=True)
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
        assert backend.get_name() == "TRITON_MLA"
 
         # If attention backend is None
         # If use_mla is true
         # The selected backend is triton MLA
         m.setenv(STR_BACKEND_ENV_VAR, None)
-        backend = get_attn_backend(576,
-                                   torch.bfloat16,
-                                   "auto",
-                                   16,
-                                   False,
-                                   use_mla=True)
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
         assert backend.get_name() == "TRITON_MLA"
 
         # change the attention backend to AITER MLA
         m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
-        backend = get_attn_backend(576,
-                                   torch.bfloat16,
-                                   "auto",
-                                   1,
-                                   False,
-                                   use_mla=True)
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
         assert backend.get_name() == "ROCM_AITER_MLA"
 
         # If attention backend is None
@@ -70,10 +52,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         # If use_mla is true
         # The selected backend is ROCM_AITER_MLA
         m.setenv(STR_BACKEND_ENV_VAR, None)
         m.setenv("VLLM_ROCM_USE_AITER", "1")
-        backend = get_attn_backend(576,
-                                   torch.bfloat16,
-                                   "auto",
-                                   1,
-                                   False,
-                                   use_mla=True)
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
         assert backend.get_name() == "ROCM_AITER_MLA"
diff --git a/tests/kernels/attention/test_triton_decode_attention.py b/tests/kernels/attention/test_triton_decode_attention.py
index 48aacac8376b..01ba0951b825 100644
--- a/tests/kernels/attention/test_triton_decode_attention.py
+++ b/tests/kernels/attention/test_triton_decode_attention.py
@@ -24,14 +24,12 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
     num_kv_splits = 8
 
     num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
-    req_to_page = torch.randint(0,
-                                CACHE_SIZE // PAGE_SIZE,
-                                (B, num_pages_per_batch, 1),
-                                device="cuda")
+    req_to_page = torch.randint(
+        0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
+    )
     req_to_token = req_to_page * PAGE_SIZE
     req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
-    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
-        1, 1, -1)
+    req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
     req_to_token = req_to_token.view(B, -1)
     req_to_token = req_to_token[:, :seq_len].contiguous()
 
@@ -48,7 +46,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
 
     lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
 
-    b_seq_len = torch.full((B, ), seq_len, device="cuda")
+    b_seq_len = torch.full((B,), seq_len, device="cuda")
 
     attn_logits = torch.empty(
         (B, H_Q, num_kv_splits, D_V + 1),
diff --git
a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py index 5cff29b15aa3..fba82cfdadbd 100644 --- a/tests/kernels/attention/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -14,9 +14,11 @@ BLOCK_SIZES = [16] DTYPES = [torch.bfloat16] -QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [ - None, torch.float8_e4m3fnuz -] +QDTYPES = ( + [None, torch.float8_e4m3fn] + if not current_platform.is_rocm() + else [None, torch.float8_e4m3fnuz] +) # one value large enough to test overflow in index calculation. # one value small enough to test the schema op check NUM_BLOCKS = [32768, 2048] @@ -42,7 +44,7 @@ def ref_paged_attn( for i in range(num_seqs): query_len = query_lens[i] kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] + q = query[start_idx : start_idx + query_len] q *= scale num_kv_blocks = (kv_len + block_size - 1) // block_size @@ -60,10 +62,13 @@ def ref_paged_attn( empty_mask = torch.ones(query_len, kv_len) mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) mask |= sliding_window_mask if soft_cap is not None and soft_cap > 0: attn = soft_cap * torch.tanh(attn / soft_cap) @@ -77,9 +82,9 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) -@pytest.mark.parametrize("seq_lens", - [[(1, 1328), (5, 18), - (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) +@pytest.mark.parametrize( + "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] +) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -111,30 +116,23 @@ def test_triton_unified_attn( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) - window_size = ((sliding_window - 1, 0) if sliding_window is not None else - (-1, -1)) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) scale = head_size**-0.5 - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) + query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) value_cache = torch.randn_like(key_cache) - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) + cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum( + dim=0, dtype=torch.int32 + ) kv_lens = torch.tensor(kv_lens, dtype=torch.int32) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) + block_tables = torch.randint( + 0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) output = torch.empty_like(query) @@ -188,5 +186,7 @@ def test_triton_unified_attn( atol, rtol = 1.5e-2, 1e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ - 
f"{torch.max(torch.abs(output - ref_output))}" + ( + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), + f"{torch.max(torch.abs(output - ref_output))}", + ) diff --git a/tests/kernels/core/test_activation.py b/tests/kernels/core/test_activation.py index ec5c60fd7b0e..e8777ec4f59e 100644 --- a/tests/kernels/core/test_activation.py +++ b/tests/kernels/core/test_activation.py @@ -8,19 +8,23 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol from tests.kernels.utils import opcheck -from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, - GeluAndMul, MulAndSilu, - NewGELU, QuickGELU, - SiluAndMul, SwigluOAIAndMul) +from vllm.model_executor.layers.activation import ( + FastGELU, + FatreluAndMul, + GeluAndMul, + MulAndSilu, + NewGELU, + QuickGELU, + SiluAndMul, + SwigluOAIAndMul, +) from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 13824] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize( @@ -73,24 +77,19 @@ def test_act_and_mul( out = layer(x) ref_out = layer.forward_native(x) if activation == "swigluoai_and_mul": - rtol = { - #For fp16, change the relative tolerance from 1e-3 to 2e-3 - torch.float16: - 2e-3, - torch.bfloat16: - 2e-2, - torch.float: - 1.3e-6 + # For fp16, change the relative tolerance from 1e-3 to 2e-3 + torch.float16: 2e-3, + torch.bfloat16: 2e-2, + torch.float: 1.3e-6, } def _get_rtol(output) -> float: return rtol[output.dtype] - torch.testing.assert_close(out, - ref_out, - atol=get_default_atol(out), - rtol=_get_rtol(out)) + torch.testing.assert_close( + out, ref_out, atol=get_default_atol(out), rtol=_get_rtol(out) + ) else: # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are # equivalent to the native PyTorch implementations, so we can do exact @@ -98,7 +97,7 @@ def _get_rtol(output) -> float: torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) if activation == "fatrelu": opcheck(fn, (out, x, threshold)) @@ -108,9 +107,14 @@ def _get_rtol(output) -> float: opcheck(fn, (out, x)) -@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast), - (NewGELU, torch.ops._C.gelu_new), - (QuickGELU, torch.ops._C.gelu_quick)]) +@pytest.mark.parametrize( + "activation", + [ + (FastGELU, torch.ops._C.gelu_fast), + (NewGELU, torch.ops._C.gelu_new), + (QuickGELU, torch.ops._C.gelu_quick), + ], +) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -132,10 +136,9 @@ def test_activation( fn = activation[1] out = layer(x) ref_out = layer.forward_native(x) - torch.testing.assert_close(out, - ref_out, - atol=get_default_atol(out), - rtol=get_default_rtol(out)) + torch.testing.assert_close( + out, ref_out, atol=get_default_atol(out), rtol=get_default_rtol(out) + ) out = torch.empty_like(x) opcheck(fn, (out, x)) diff --git a/tests/kernels/core/test_fused_quant_layernorm.py b/tests/kernels/core/test_fused_quant_layernorm.py index 19703b8a2f97..52133ec53d1d 100644 --- a/tests/kernels/core/test_fused_quant_layernorm.py +++ 
b/tests/kernels/core/test_fused_quant_layernorm.py @@ -24,9 +24,7 @@ ADD_RESIDUAL = [False, True] SCALE_UBS = [True, False] SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] EPS = 1e-6 @@ -34,13 +32,12 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: - return torch.as_tensor(x, dtype=torch.float32, device='cuda') + return torch.as_tensor(x, dtype=torch.float32, device="cuda") -def ref_rms_norm(rms_norm_layer: RMSNorm, - x: torch.Tensor, - residual: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, Optional[torch.Tensor]]: +def ref_rms_norm( + rms_norm_layer: RMSNorm, x: torch.Tensor, residual: Optional[torch.Tensor] +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = residual.clone() out, residual = rms_norm_layer.forward_native(x, residual) @@ -50,12 +47,13 @@ def ref_rms_norm(rms_norm_layer: RMSNorm, return out, residual -def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +def ref_dynamic_per_token_quant( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if scale_ub is not None: assert quant_dtype == torch.float8_e4m3fn @@ -64,9 +62,9 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, # Quant if quant_dtype == torch.float8_e4m3fn: - torch_out, scales = ops.scaled_fp8_quant(torch_out, - scale_ub=scale_ub, - use_per_token_if_dynamic=True) + torch_out, scales = ops.scaled_fp8_quant( + torch_out, scale_ub=scale_ub, use_per_token_if_dynamic=True + ) else: assert quant_dtype == torch.int8 torch_out, scales = ops.scaled_int8_quant(torch_out) @@ -74,38 +72,41 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, return torch_out, scales, residual -def ref_impl(rms_norm_layer: RMSNorm, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype, - residual, scale_ub) +def ref_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return ref_dynamic_per_token_quant( + rms_norm_layer, x, quant_dtype, residual, scale_ub + ) -def ops_dynamic_per_token_quant(weight: torch.Tensor, - x: torch.Tensor, - quant_dtype: torch.dtype, - residual: Optional[torch.Tensor], - scale_ub: Optional[torch.Tensor]) \ - -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +def ops_dynamic_per_token_quant( + weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if residual is not None: residual = residual.clone() - out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS, - quant_dtype, scale_ub, - residual) + out, scales = ops.rms_norm_dynamic_per_token_quant( + x, weight, EPS, quant_dtype, scale_ub, residual + ) return out, scales, residual 
-def ops_impl(weight: torch.Tensor,
-             x: torch.Tensor,
-             quant_dtype: torch.dtype,
-             residual: Optional[torch.Tensor],
-             scale_ub: Optional[torch.Tensor]) \
-        -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
-                                       scale_ub)
+def ops_impl(
+    weight: torch.Tensor,
+    x: torch.Tensor,
+    quant_dtype: torch.dtype,
+    residual: Optional[torch.Tensor],
+    scale_ub: Optional[torch.Tensor],
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, scale_ub)
 
 
 @pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
@@ -146,12 +147,14 @@ def test_rms_norm(
     residual = torch.randn_like(x) * scale if add_residual else None
     if scale_ub is not None:
         rms_x, _ = ref_rms_norm(layer, x, residual)
-        scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda')
+        scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device="cuda")
 
-    ref_out, ref_scales, ref_residual = \
-        ref_impl(layer, x, quant_dtype, residual, scale_ub)
-    ops_out, ops_scales, ops_residual = \
-        ops_impl(layer.weight, x, quant_dtype, residual, scale_ub)
+    ref_out, ref_scales, ref_residual = ref_impl(
+        layer, x, quant_dtype, residual, scale_ub
+    )
+    ops_out, ops_scales, ops_residual = ops_impl(
+        layer.weight, x, quant_dtype, residual, scale_ub
+    )
 
     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
@@ -160,15 +163,18 @@ def test_rms_norm(
         # big atol to account for round-off errors.
         assert torch.allclose(ref_out, ops_out, atol=1)
     else:
-        assert torch.allclose(ref_out.to(dtype=torch.float32),
-                              ops_out.to(dtype=torch.float32))
+        assert torch.allclose(
+            ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
+        )
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)
 
     output = torch.empty_like(x, dtype=quant_dtype)
-    scales = torch.empty((x.numel() // x.shape[-1], 1),
-                         device=x.device,
-                         dtype=torch.float32)
-
-    opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant,
-            (output, x, layer.weight, scales, 1e-5, scale_ub, residual))
+    scales = torch.empty(
+        (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
+    )
+
+    opcheck(
+        torch.ops._C.rms_norm_dynamic_per_token_quant,
+        (output, x, layer.weight, scales, 1e-5, scale_ub, residual),
+    )
diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py
index 53e6d793cf2f..7553d45e0057 100644
--- a/tests/kernels/core/test_layernorm.py
+++ b/tests/kernels/core/test_layernorm.py
@@ -11,13 +11,22 @@
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
-HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
-                8199]  # Arbitrary values for testing
+HIDDEN_SIZES = [
+    8,
+    768,
+    769,
+    770,
+    771,
+    5120,
+    5124,
+    5125,
+    5126,
+    8192,
+    8199,
+]  # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]
-CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
-]
+CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -63,11 +72,14 @@ def test_rms_norm(
     torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
 
     if residual is not None:
-        opcheck(torch.ops._C.fused_add_rms_norm,
-                (x, residual, layer.weight.data, layer.variance_epsilon))
+        opcheck(
+            torch.ops._C.fused_add_rms_norm,
+            (x, residual, layer.weight.data, layer.variance_epsilon),
+        )
else: - opcheck(torch.ops._C.rms_norm, - (out, x, layer.weight.data, layer.variance_epsilon)) + opcheck( + torch.ops._C.rms_norm, (out, x, layer.weight.data, layer.variance_epsilon) + ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -98,7 +110,8 @@ def test_poly_norm( opcheck( torch.ops._C.poly_norm, - (out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon)) + (out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon), + ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -144,7 +157,8 @@ def test_fused_rms_norm_quant( if add_residual: torch.ops._C.fused_add_rms_norm_static_fp8_quant( - out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6) + out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6 + ) # Unfused kernel is in-place so it goes second # Also use a separate clone of x to avoid modifying the input @@ -152,29 +166,32 @@ def test_fused_rms_norm_quant( x_unfused = x_unfused_base[..., :hidden_size] assert x_unfused.is_contiguous() != strided_input torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6) - torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(), - quant_scale_t) + torch.ops._C.static_scaled_fp8_quant( + out_quant, x_unfused.contiguous(), quant_scale_t + ) torch.cuda.synchronize() - torch.testing.assert_close(residual_fused, - residual, - atol=1e-2, - rtol=1e-2) + torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2) opcheck( torch.ops._C.fused_add_rms_norm_static_fp8_quant, - (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)) + (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6), + ) else: - torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight, - quant_scale_t, 1e-6) + torch.ops._C.rms_norm_static_fp8_quant( + out_quant_fused, x, weight, quant_scale_t, 1e-6 + ) torch.ops._C.rms_norm(out_norm, x, weight, 1e-6) - torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, - quant_scale_t) + torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t) - opcheck(torch.ops._C.rms_norm_static_fp8_quant, - (out_quant_fused, x, weight, quant_scale_t, 1e-6)) - - torch.testing.assert_close(out_quant.to(dtype=torch.float32), - out_quant_fused.to(dtype=torch.float32), - atol=1e-3, - rtol=1e-3) + opcheck( + torch.ops._C.rms_norm_static_fp8_quant, + (out_quant_fused, x, weight, quant_scale_t, 1e-6), + ) + + torch.testing.assert_close( + out_quant.to(dtype=torch.float32), + out_quant_fused.to(dtype=torch.float32), + atol=1e-3, + rtol=1e-3, + ) diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py index 5a903438f5e9..02b795721f46 100644 --- a/tests/kernels/core/test_mrope.py +++ b/tests/kernels/core/test_mrope.py @@ -14,25 +14,25 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int, - head_size: int, max_position_embeddings: int, - dtype: torch.dtype, device: torch.device): +def generate_test_data( + num_tokens: int, + num_q_heads: int, + num_kv_heads: int, + head_size: int, + max_position_embeddings: int, + dtype: torch.dtype, + device: torch.device, +): """Generate test data for given configuration.""" current_platform.seed_everything(42) # Create 2D positions (3, num_tokens) for multimodal case - positions = torch.randint(0, - max_position_embeddings // 4, (3, num_tokens), - device=device) + positions = torch.randint( + 0, max_position_embeddings // 4, (3, num_tokens), device=device + ) # Create query and key 
tensors - query = torch.randn(num_tokens, - num_q_heads * head_size, - dtype=dtype, - device=device) - key = torch.randn(num_tokens, - num_kv_heads * head_size, - dtype=dtype, - device=device) + query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device) + key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device) return positions, query, key @@ -59,7 +59,8 @@ class MRoPETestInfo(NamedTuple): Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), reason="Qwen3-VL only available after Transformers v4.57", ) - ]), + ], + ), MRoPETestInfo( model_name="Qwen/Qwen3-VL-30B-A3B-Instruct", marks=[ @@ -67,24 +68,33 @@ class MRoPETestInfo(NamedTuple): Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), reason="Qwen3-VL only available after Transformers v4.57", ) - ]), + ], + ), ] num_tokens_list = [11, 8192] -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Skipping CUDA/ROCm only tests.") -@pytest.mark.parametrize("model_info, model_name", [ - pytest.param(test_config, test_config.model_name, marks=test_config.marks) - for test_config in MODELS_TO_TEST -]) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests." +) +@pytest.mark.parametrize( + "model_info, model_name", + [ + pytest.param(test_config, test_config.model_name, marks=test_config.marks) + for test_config in MODELS_TO_TEST + ], +) @pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("num_tokens", num_tokens_list) -def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int, - dtype: torch.dtype, num_tokens: int): - +def test_mrope( + model_name: str, + model_info: MRoPETestInfo, + tp_size: int, + dtype: torch.dtype, + num_tokens: int, +): atol = model_info.atol rtol = model_info.rtol @@ -96,8 +106,11 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int, total_num_heads = config.num_attention_heads num_heads = total_num_heads // tp_size num_kv_heads = max(1, total_num_kv_heads // tp_size) - head_dim = (config.head_dim if hasattr(config, "head_dim") else - config.hidden_size // total_num_heads) + head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // total_num_heads + ) is_neox_style = True rope_theta = config.rope_theta @@ -117,9 +130,9 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int, # create q k v input tensors # create rotary pos emb input tensors - positions, query, key = generate_test_data(num_tokens, num_heads, - num_kv_heads, head_dim, - max_position, dtype, device) + positions, query, key = generate_test_data( + num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device + ) query_native, key_native = mrope_helper_class.forward_native( positions, @@ -137,19 +150,26 @@ def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int, torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol) -@pytest.mark.skipif(not current_platform.is_cuda_alike(), - reason="Skipping CUDA/ROCm only tests.") -@pytest.mark.parametrize("model_info, model_name", [ - pytest.param(test_config, test_config.model_name, marks=test_config.marks) - for test_config in MODELS_TO_TEST -]) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests." 
+) +@pytest.mark.parametrize( + "model_info, model_name", + [ + pytest.param(test_config, test_config.model_name, marks=test_config.marks) + for test_config in MODELS_TO_TEST + ], +) @pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("num_tokens", num_tokens_list) -def test_mrope_torch_compile_tracing(model_name: str, - model_info: MRoPETestInfo, tp_size: int, - dtype: torch.dtype, num_tokens: int): - +def test_mrope_torch_compile_tracing( + model_name: str, + model_info: MRoPETestInfo, + tp_size: int, + dtype: torch.dtype, + num_tokens: int, +): atol = model_info.atol rtol = model_info.rtol @@ -161,8 +181,11 @@ def test_mrope_torch_compile_tracing(model_name: str, total_num_heads = config.num_attention_heads num_heads = total_num_heads // tp_size num_kv_heads = max(1, total_num_kv_heads // tp_size) - head_dim = (config.head_dim if hasattr(config, "head_dim") else - config.hidden_size // total_num_heads) + head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // total_num_heads + ) is_neox_style = True rope_theta = config.rope_theta max_position = config.max_position_embeddings @@ -180,16 +203,16 @@ def test_mrope_torch_compile_tracing(model_name: str, ).to(device=device) # Generate test data - positions, query, key = generate_test_data(num_tokens, num_heads, - num_kv_heads, head_dim, - max_position, dtype, device) + positions, query, key = generate_test_data( + num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device + ) # Create a wrapper that makes the in-place function appear functional def functional_forward_cuda(pos, q, k): """Wrapper that converts in-place operation to functional style CUDA Graph does not support in-place operations. - This wrapper creates working copies of the + This wrapper creates working copies of the input tensors and modifies them. 
""" q_work = q.clone() # Create working copies @@ -206,11 +229,13 @@ def functional_forward_cuda(pos, q, k): ) try: - compiled_forward_cuda = torch.compile(functional_forward_cuda, - fullgraph=True, - backend="inductor", - mode="reduce-overhead", - dynamic=False) + compiled_forward_cuda = torch.compile( + functional_forward_cuda, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False, + ) # Run compiled version query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda( @@ -225,25 +250,16 @@ def functional_forward_cuda(pos, q, k): mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda) # Verify results - torch.testing.assert_close(query_compiled_cuda, - query_cuda, - atol=atol, - rtol=rtol) - torch.testing.assert_close(key_compiled_cuda, - key_cuda, - atol=atol, - rtol=rtol) - torch.testing.assert_close(query_compiled_cuda, - query_native, - atol=atol, - rtol=rtol) - torch.testing.assert_close(key_compiled_cuda, - key_native, - atol=atol, - rtol=rtol) + torch.testing.assert_close( + query_compiled_cuda, query_cuda, atol=atol, rtol=rtol + ) + torch.testing.assert_close(key_compiled_cuda, key_cuda, atol=atol, rtol=rtol) + torch.testing.assert_close( + query_compiled_cuda, query_native, atol=atol, rtol=rtol + ) + torch.testing.assert_close(key_compiled_cuda, key_native, atol=atol, rtol=rtol) print("✓ forward_cuda successfully traced with torch.compile inductor") except Exception as e: - pytest.fail( - f"forward_cuda failed to trace with torch.compile inductor: {e}") + pytest.fail(f"forward_cuda failed to trace with torch.compile inductor: {e}") diff --git a/tests/kernels/core/test_permute_cols.py b/tests/kernels/core/test_permute_cols.py index e18f6230dbce..1e264735cb3c 100644 --- a/tests/kernels/core/test_permute_cols.py +++ b/tests/kernels/core/test_permute_cols.py @@ -8,11 +8,11 @@ from vllm._custom_ops import permute_cols -@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)]) -@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("shape", [(1, 512), (544, 4096), (67, 8192)]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_permute_cols(shape, dtype): x = torch.randn(shape, dtype=dtype).cuda() perm = torch.randperm(x.shape[1]).to(torch.int).cuda() opcheck(torch.ops._C.permute_cols, (x, perm)) y = permute_cols(x, perm) - torch.testing.assert_close(y, x[:, perm]) \ No newline at end of file + torch.testing.assert_close(y, x[:, perm]) diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index 1235e3222a78..799e0a3f2a2b 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -19,30 +19,33 @@ BATCH_SIZES = [5] # Arbitrary values for testing SEQ_LENS = [11, 8192] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] USE_KEY = [True, False] -def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, - head_size: int) -> tuple[int, ...]: +def _get_flat_tensor_shape( + batch_size: int, seq_len: int, num_heads: int, head_size: int +) -> tuple[int, ...]: return (batch_size, seq_len, num_heads * head_size) # For testing sliced tensors -def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int, - head_size: int) -> tuple[int, ...]: +def _get_padded_tensor_shape( + batch_size: int, seq_len: int, 
num_heads: int, head_size: int +) -> tuple[int, ...]: return (batch_size, seq_len, num_heads, head_size + 64) -def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int, - head_size: int) -> tuple[int, ...]: +def _get_batch_tensor_shape( + batch_size: int, seq_len: int, num_heads: int, head_size: int +) -> tuple[int, ...]: return (batch_size, seq_len, num_heads, head_size) TENSORS_SHAPES_FN = [ - _get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape + _get_batch_tensor_shape, + _get_flat_tensor_shape, + _get_padded_tensor_shape, ] @@ -97,41 +100,63 @@ def test_rotary_embedding( ref_query, ref_key = rope.forward_native(positions, query, key) out_query, out_key = rope.forward(positions, query, key) # Compare the results. - torch.testing.assert_close(out_query, - ref_query, - atol=get_default_atol(out_query), - rtol=get_default_rtol(out_query)) + torch.testing.assert_close( + out_query, + ref_query, + atol=get_default_atol(out_query), + rtol=get_default_rtol(out_query), + ) if use_key: - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + torch.testing.assert_close( + out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key), + ) else: - assert ref_key is None and out_key is None, \ - "expected returned key to be None" + assert ref_key is None and out_key is None, "expected returned key to be None" @torch.inference_mode() def test_rope_module_cache(): MAX_POSITIONS = [123, 1234] BASES = [10000, 1000000] - ROPE_SCALINGS = (None, { - "rope_type": "linear", - "factor": (1, ) - }, { - "rope_type": "dynamic", - "factor": 1 - }) - settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, - ROPE_SCALINGS, DTYPES) + ROPE_SCALINGS = ( + None, + {"rope_type": "linear", "factor": (1,)}, + {"rope_type": "dynamic", "factor": 1}, + ) + settings = ( + HEAD_SIZES, + ROTARY_DIMS, + MAX_POSITIONS, + BASES, + IS_NEOX_STYLE, + ROPE_SCALINGS, + DTYPES, + ) rope_setting_id_map: dict[str, int] = {} for setting in product(*settings): - head_size, rotary_dim, max_position, base, \ - is_neox_stype, rope_scaling, dtype = setting + ( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) = setting if rotary_dim is None: rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_stype, rope_scaling, dtype) + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) # different settings cannot share the same rope module assert id(rope) not in rope_setting_id_map.values() assert all(x.dtype == dtype for x in rope.buffers()) @@ -139,11 +164,25 @@ def test_rope_module_cache(): rope_setting_id_map[str(setting)] = id(rope) for setting in product(*settings): - head_size, rotary_dim, max_position, base, \ - is_neox_stype, rope_scaling, dtype = setting + ( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) = setting if rotary_dim is None: rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_stype, rope_scaling, dtype) + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_stype, + rope_scaling, + dtype, + ) # check if cache take effect assert id(rope) == rope_setting_id_map[str(setting)] diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index 5857dd5ba3fa..0a292a3e2ae7 100644 --- 
a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -13,17 +13,20 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -def rotary_embedding_opcheck(rot, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None): +def rotary_embedding_opcheck( + rot, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, +): cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) # ops.rotary_embedding() is a in-place operation # that updates the query and key tensors. - opcheck(torch.ops._C.rotary_embedding, - (positions, query, key, rot.head_size, cos_sin_cache, - rot.is_neox_style)) + opcheck( + torch.ops._C.rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, rot.is_neox_style), + ) @pytest.mark.parametrize("device", ["cuda"]) @@ -34,26 +37,30 @@ def rotary_embedding_opcheck(rot, @pytest.mark.parametrize("seq_len", [11, 1024]) @pytest.mark.parametrize("use_key", [True, False]) @pytest.mark.parametrize("head_stride_is_contiguous", [True, False]) -def test_rotary_embedding_opcheck(dist_init, device, max_position, - is_neox_style, rotary_dim, head_size, - seq_len, use_key, head_stride_is_contiguous): +def test_rotary_embedding_opcheck( + dist_init, + device, + max_position, + is_neox_style, + rotary_dim, + head_size, + seq_len, + use_key, + head_stride_is_contiguous, +): batch_size = 1 base = 10000 num_heads = 7 - rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, torch.float32) + rot = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, torch.float32 + ) - positions = torch.randint(0, - max_position, (batch_size, seq_len), - device=device) + positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) head_stride = head_size + (64 if head_stride_is_contiguous else 0) - query = torch.randn(batch_size, - seq_len, - num_heads, - head_stride, - dtype=torch.float32, - device=device) + query = torch.randn( + batch_size, seq_len, num_heads, head_stride, dtype=torch.float32, device=device + ) key = torch.randn_like(query) if use_key else None query = query[..., :head_size] key = key[..., :head_size] if use_key else None @@ -64,5 +71,8 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, # [..., num_heads * head_dim] shape/layout if head_stride_is_contiguous: rotary_embedding_opcheck( - rot, positions, query.flatten(start_dim=-2), - key.flatten(start_dim=-2) if use_key else None) + rot, + positions, + query.flatten(start_dim=-2), + key.flatten(start_dim=-2) if use_key else None, + ) diff --git a/tests/kernels/core/test_uva.py b/tests/kernels/core/test_uva.py index c71215e4c646..73738175e5c7 100644 --- a/tests/kernels/core/test_uva.py +++ b/tests/kernels/core/test_uva.py @@ -5,20 +5,14 @@ from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.skipif(not is_uva_available(), reason="UVA is not available.") @pytest.mark.parametrize("device", CUDA_DEVICES) def test_cpu_write(device): torch.set_default_device(device) - cpu_tensor = torch.zeros(10, - 10, - device="cpu", - pin_memory=True, - dtype=torch.int32) + cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32) cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor) assert 
cuda_view.device.type == "cuda" @@ -40,11 +34,7 @@ def test_cpu_write(device): @pytest.mark.parametrize("device", CUDA_DEVICES) def test_gpu_write(device): torch.set_default_device(device) - cpu_tensor = torch.zeros(10, - 10, - device="cpu", - pin_memory=True, - dtype=torch.int32) + cpu_tensor = torch.zeros(10, 10, device="cpu", pin_memory=True, dtype=torch.int32) cuda_view = get_cuda_view_from_cpu_tensor(cpu_tensor) assert cuda_view.device.type == "cuda" @@ -59,4 +49,4 @@ def test_gpu_write(device): assert cpu_tensor[0, 0] == 2 assert cpu_tensor[2, 3] == 4 - assert cpu_tensor[4, 5] == -2 \ No newline at end of file + assert cpu_tensor[4, 5] == -2 diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py index 411bd9e904b0..fea6b94481b6 100644 --- a/tests/kernels/mamba/test_causal_conv1d.py +++ b/tests/kernels/mamba/test_causal_conv1d.py @@ -10,7 +10,9 @@ from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + causal_conv1d_fn, + causal_conv1d_update, +) from vllm.platforms import current_platform @@ -39,18 +41,15 @@ def causal_conv1d_ref( seqlen = x.shape[-1] dim, width = weight.shape if initial_states is None: - out = F.conv1d(x, - weight.unsqueeze(1), - bias, - padding=width - 1, - groups=dim) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) else: x = torch.cat([initial_states, x], dim=-1) out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) out = out[..., :seqlen] if return_final_states: final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in) # (batch, dim, width - 1) + dtype_in + ) # (batch, dim, width - 1) if final_states_out is not None: final_states_out.copy_(final_states) else: @@ -59,12 +58,9 @@ def causal_conv1d_ref( return (out, None) if not return_final_states else (out, final_states_out) -def causal_conv1d_update_ref(x, - conv_state, - weight, - bias=None, - activation=None, - cache_seqlens=None): +def causal_conv1d_update_ref( + x, conv_state, weight, bias=None, activation=None, cache_seqlens=None +): """ x: (batch, dim) or (batch, dim, seqlen) conv_state: (batch, dim, state_len), where state_len >= width - 1 @@ -91,24 +87,25 @@ def causal_conv1d_update_ref(x, assert weight.shape == (dim, width) if cache_seqlens is None: x_new = torch.cat([conv_state, x], dim=-1).to( - weight.dtype) # (batch, dim, state_len + seqlen) + weight.dtype + ) # (batch, dim, state_len + seqlen) conv_state.copy_(x_new[:, :, -state_len:]) else: width_idx = torch.arange( - -(width - 1), 0, dtype=torch.long, - device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand( - -1, dim, -1) - x_new = torch.cat([conv_state.gather(2, width_idx), x], - dim=-1).to(weight.dtype) - copy_idx = torch.arange( - seqlen, dtype=torch.long, - device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, - state_len).unsqueeze(1).expand(-1, dim, -1) + -(width - 1), 0, dtype=torch.long, device=x.device + ).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = ( + torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + ) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze( + 0 + ) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) 
conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, - groups=dim)[:, :, -seqlen:] + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[ + :, :, -seqlen: + ] if unsqueeze: out = out.squeeze(-1) return (out if activation is None else F.silu(out)).to(dtype=dtype_in) @@ -117,15 +114,17 @@ def causal_conv1d_update_ref(x, @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) -def causal_conv1d_opcheck_fn(x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - cu_seq_len: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID): +def causal_conv1d_opcheck_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + cu_seq_len: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +): """ x: (batch, dim, seqlen) weight: (dim, width) @@ -150,8 +149,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, @pytest.mark.parametrize("seqlen", [1]) @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, - itype): +def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: @@ -167,23 +165,26 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state.detach().clone() activation = None if not silu_activation else "silu" - out = causal_conv1d_update(x, - conv_state, - weight, - bias, - activation=activation) - out_ref = causal_conv1d_update_ref(x_ref, - conv_state_ref, - weight, - bias, - activation=activation) + + conv_state_indices = torch.arange(batch, dtype=torch.int32, device=device) + + out = causal_conv1d_update( + x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=conv_state_indices, + ) + out_ref = causal_conv1d_update_ref( + x_ref, conv_state_ref, weight, bias, activation=activation + ) assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("has_bias", [False, True]) @pytest.mark.parametrize("seqlen", [1, 3]) @@ -192,9 +193,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) @pytest.mark.parametrize("batch_size", [3]) -def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, - width, seqlen, has_bias, - silu_activation, itype): +def test_causal_conv1d_update_with_batch_gather( + batch_size, with_padding, dim, width, seqlen, has_bias, 
silu_activation, itype +): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: @@ -209,31 +210,30 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, total_entries = 10 * batch_size # x will be (batch, dim, seqlen) with contiguous along dim-axis - x = torch.randn(padded_batch_size, seqlen, dim, device=device, - dtype=itype).transpose(1, 2) + x = torch.randn( + padded_batch_size, seqlen, dim, device=device, dtype=itype + ).transpose(1, 2) x_ref = x.clone() conv_state_indices = torch.randperm(total_entries)[:batch_size].to( - dtype=torch.int32, device=device) - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) + dtype=torch.int32, device=device + ) + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[conv_state_indices] = False - padded_state_indices = torch.concat([ - conv_state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) - ], - dim=0) + padded_state_indices = torch.concat( + [ + conv_state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=0, + ) # conv_state will be (cache_lines, dim, state_len) # with contiguous along dim-axis - conv_state = torch.randn(total_entries, - width - 1, - dim, - device=device, - dtype=itype).transpose(1, 2) + conv_state = torch.randn( + total_entries, width - 1, dim, device=device, dtype=itype + ).transpose(1, 2) conv_state_for_padding_test = conv_state.clone() @@ -242,22 +242,23 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, conv_state_ref = conv_state[conv_state_indices, :].detach().clone() activation = None if not silu_activation else "silu" - out = causal_conv1d_update(x, - conv_state, - weight, - bias, - activation=activation, - conv_state_indices=padded_state_indices, - pad_slot_id=PAD_SLOT_ID) - out_ref = causal_conv1d_update_ref(x_ref[:batch_size], - conv_state_ref, - weight, - bias, - activation=activation) + out = causal_conv1d_update( + x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + ) + out_ref = causal_conv1d_update_ref( + x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation + ) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) - assert torch.equal(conv_state[unused_states_bool], - conv_state_for_padding_test[unused_states_bool]) + assert torch.equal( + conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool] + ) assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @@ -265,12 +266,13 @@ def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096]) -@pytest.mark.parametrize('dim', [64, 4096]) -@pytest.mark.parametrize('with_padding', [True, False]) -@pytest.mark.parametrize('batch', [4, 10]) -def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, - has_bias, silu_activation, itype): +@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096]) +@pytest.mark.parametrize("dim", [64, 4096]) +@pytest.mark.parametrize("with_padding", [True, False]) +@pytest.mark.parametrize("batch", [4, 10]) +def test_causal_conv1d_varlen( + batch, with_padding, dim, seqlen, width, 
has_bias, silu_activation, itype +): device = "cuda" torch.cuda.empty_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) @@ -288,19 +290,19 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, seqlens.append( torch.diff( - torch.cat( - [torch.tensor([-1]), eos_pos, - torch.tensor([seqlen - 1])])).tolist()) + torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])]) + ).tolist() + ) assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], - dim=0) + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) x = rearrange( torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype), - "b s d -> b d s")[:, 4096:4096 + dim, :] + "b s d -> b d s", + )[:, 4096 : 4096 + dim, :] weight = torch.randn(dim, width, device=device, dtype=itype) @@ -309,34 +311,34 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None activation = None if not silu_activation else "silu" - final_states = torch.randn(total_entries, - width - 1, - dim, - device=x.device, - dtype=x.dtype).transpose(1, 2) + final_states = torch.randn( + total_entries, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) final_states_ref = final_states.clone() - has_initial_states = torch.randint(0, - 2, (cumsum.shape[0] - 1, ), - dtype=torch.bool, - device=x.device) - state_indices = torch.randperm(total_entries, - dtype=torch.int32, - device=x.device)[:batch_size] - padded_state_indices = torch.concat([ - state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), - ], - dim=-1) - out = causal_conv1d_fn(x.squeeze(0), - weight, - bias=bias, - conv_states=final_states, - query_start_loc=cumsum.cuda(), - cache_indices=padded_state_indices, - has_initial_state=has_initial_states, - activation=activation, - pad_slot_id=PAD_SLOT_ID) + has_initial_states = torch.randint( + 0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device + ) + state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[ + :batch_size + ] + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1, + ) + out = causal_conv1d_fn( + x.squeeze(0), + weight, + bias=bias, + conv_states=final_states, + query_start_loc=cumsum.cuda(), + cache_indices=padded_state_indices, + has_initial_state=has_initial_states, + activation=activation, + pad_slot_id=PAD_SLOT_ID, + ) out_ref = [] out_ref_b = [] @@ -353,16 +355,20 @@ def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, bias_ref, activation=activation, return_final_states=True, - final_states_out=final_states_ref[ - padded_state_indices[i]].unsqueeze(0), - initial_states=final_states_ref[padded_state_indices[i]]. 
- unsqueeze(0) if has_initial_states[i] else None)) + final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0), + initial_states=final_states_ref[padded_state_indices[i]].unsqueeze(0) + if has_initial_states[i] + else None, + ) + ) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) out_ref_tensor = torch.cat(out_ref, dim=0) - assert torch.allclose(final_states[state_indices], - final_states_ref[state_indices], - rtol=rtol, - atol=atol) - unpadded_out = out[:, :out_ref_tensor.shape[-1]] + assert torch.allclose( + final_states[state_indices], + final_states_ref[state_indices], + rtol=rtol, + atol=atol, + ) + unpadded_out = out[:, : out_ref_tensor.shape[-1]] assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py index 16c310726ad1..d23daefa7b43 100644 --- a/tests/kernels/mamba/test_mamba_mixer2.py +++ b/tests/kernels/mamba/test_mamba_mixer2.py @@ -7,8 +7,10 @@ import torch from tests.utils import multi_gpu_test -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated from vllm.platforms import current_platform from vllm.utils import update_environment_variables @@ -24,14 +26,15 @@ (64, 2), (64, 4), # hidden_size be divisible by num_gpus (100, 5), # and n_groups must divide hidden_size - ]) + ], +) @pytest.mark.parametrize("dtype", [torch.float16]) def test_mixer2_gated_norm_multi_gpu( batch_size: int, seq_len: int, hidden_size_n_groups: tuple[int, int], dtype: torch.dtype, - device: str = 'cuda', + device: str = "cuda", ): hidden_size, n_groups = hidden_size_n_groups num_processes = 2 @@ -39,17 +42,19 @@ def test_mixer2_gated_norm_multi_gpu( def run_torch_spawn(fn, nprocs): # need to use torch.mp.spawn otherwise will have problems with # torch.distributed and cuda - torch.multiprocessing.spawn(fn, - args=( - num_processes, - batch_size, - seq_len, - hidden_size, - n_groups, - dtype, - device, - ), - nprocs=nprocs) + torch.multiprocessing.spawn( + fn, + args=( + num_processes, + batch_size, + seq_len, + hidden_size, + n_groups, + dtype, + device, + ), + nprocs=nprocs, + ) run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2) @@ -71,20 +76,22 @@ def mixer2_gated_norm_tensor_parallel( torch.set_default_device(device) torch.set_default_dtype(dtype) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) # initialize distributed init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) # create random weights an inputs - weight = torch.rand((hidden_size, ), dtype=dtype, device=device) + weight = torch.rand((hidden_size,), dtype=dtype, device=device) hidden_states = torch.randn(batch_size, seq_len, hidden_size) gate_states = torch.randn(batch_size, seq_len, hidden_size) @@ -97,14 +104,18 @@ def mixer2_gated_norm_tensor_parallel( # create gated-norm without TP to compute reference # - utilize mock patching to disable TP when - with (unittest.mock.patch( + with ( + 
unittest.mock.patch( "vllm.model_executor.layers.mamba.mamba_mixer2." "get_tensor_model_parallel_world_size", - return_value=1), - unittest.mock.patch( - "vllm.model_executor.layers.mamba.mamba_mixer2." - "get_tensor_model_parallel_rank", - return_value=0)): + return_value=1, + ), + unittest.mock.patch( + "vllm.model_executor.layers.mamba.mamba_mixer2." + "get_tensor_model_parallel_rank", + return_value=0, + ), + ): mixer_single_gpu = Mixer2RMSNormGated( full_hidden_size=hidden_size, full_n_groups=n_groups, @@ -115,12 +126,13 @@ def mixer2_gated_norm_tensor_parallel( # generate and compare N = hidden_size // world_size output = mixer( - hidden_states[..., local_rank * N:(local_rank + 1) * N], - gate_states[..., local_rank * N:(local_rank + 1) * N], + hidden_states[..., local_rank * N : (local_rank + 1) * N], + gate_states[..., local_rank * N : (local_rank + 1) * N], ) ref_output = mixer_single_gpu(hidden_states, gate_states) - torch.testing.assert_close(output, - ref_output[..., - local_rank * N:(local_rank + 1) * N], - atol=5e-3, - rtol=1e-3) + torch.testing.assert_close( + output, + ref_output[..., local_rank * N : (local_rank + 1) * N], + atol=5e-3, + rtol=1e-3, + ) diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index 4c32ae81b34c..9a6137239ebf 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -10,20 +10,15 @@ from vllm import _custom_ops as ops # noqa: F401 from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) + selective_scan_fn, + selective_state_update, +) from vllm.platforms import current_platform -def selective_state_update_ref(state, - x, - dt, - A, - B, - C, - D=None, - z=None, - dt_bias=None, - dt_softplus=False): +def selective_state_update_ref( + state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False +): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -73,16 +68,17 @@ def selective_state_update_ref(state, assert dt_bias.shape == (nheads, dim) dt = dt + dt_bias dt = F.softplus(dt) if dt_softplus else dt - dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * - A) # (batch, nheads, dim, dstate) - B = repeat(B, "b g n -> b (g h) n", - h=nheads // ngroups) # (batch, nheads, dstate) - C = repeat(C, "b g n -> b (g h) n", - h=nheads // ngroups) # (batch, nheads, dstate) + dA = torch.exp( + rearrange(dt, "b h d -> b h d 1") * A + ) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) dB = rearrange(dt, "b h d -> b h d 1") * rearrange( - B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) - state.copy_(state * dA + - dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + B, "b h n -> b h 1 n" + ) # (batch, nheads, dim, dstate) + state.copy_( + state * dA + dB * rearrange(x, "b h d -> b h d 1") + ) # (batch, dim, dstate out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) if D is not None: out += (x * D).to(out.dtype) @@ -92,18 +88,20 @@ def selective_state_update_ref(state, return out -def selective_scan_ref(u, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - return_last_state=False, - prev_state=None, - final_state_out=None): +def selective_scan_ref( + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + 
return_last_state=False, + prev_state=None, + final_state_out=None, +): """ u: r(B D L) delta: r(B D L) @@ -132,26 +130,26 @@ def selective_scan_ref(u, C = C.float() x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state ys = [] - deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A)) if not is_variable_B: - deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u) else: if B.dim() == 3: - deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u) else: B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u) if is_variable_C and C.dim() == 4: C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) for i in range(u.shape[2]): x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: - y = torch.einsum('bdn,dn->bd', x, C) + y = torch.einsum("bdn,dn->bd", x, C) else: if C.dim() == 3: - y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + y = torch.einsum("bdn,bn->bd", x, C[:, :, i]) else: - y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i]) if i == u.shape[2] - 1: if final_state_out is None: final_state_out = x @@ -166,20 +164,22 @@ def selective_scan_ref(u, return out if not return_last_state else (out, final_state_out) -def selective_scan_opcheck_fn(u, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - cu_seq_len=None, - cache_indices=None, - has_initial_state=None, - ssm_states=None, - pad_slot_id=PAD_SLOT_ID): +def selective_scan_opcheck_fn( + u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + cu_seq_len=None, + cache_indices=None, + has_initial_state=None, + ssm_states=None, + pad_slot_id=PAD_SLOT_ID, +): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). """ @@ -206,30 +206,55 @@ def selective_scan_opcheck_fn(u, # Disable test_autograd_registration for now as it seems to trigger # a bogus error. 
- opcheck(torch.ops._C.selective_scan_fwd, - (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len, - cache_indices, has_initial_state, ssm_states, pad_slot_id), - test_utils=["test_schema", "test_faketensor"]) - - -@pytest.mark.parametrize('wtype', [torch.float32]) -@pytest.mark.parametrize('itype', - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) -@pytest.mark.parametrize('has_delta_bias', [True]) -@pytest.mark.parametrize('delta_softplus', [True]) -@pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize('has_D', [True]) + opcheck( + torch.ops._C.selective_scan_fwd, + ( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cu_seq_len, + cache_indices, + has_initial_state, + ssm_states, + pad_slot_id, + ), + test_utils=["test_schema", "test_faketensor"], + ) + + +@pytest.mark.parametrize("wtype", [torch.float32]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("has_delta_bias", [True]) +@pytest.mark.parametrize("delta_softplus", [True]) +@pytest.mark.parametrize("has_z", [True]) +@pytest.mark.parametrize("has_D", [True]) @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) @pytest.mark.parametrize("scan_chunks", [1, 2, 3]) -def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, - has_z, has_delta_bias, delta_softplus, seqlen, itype, - wtype, scan_chunks): +def test_selective_scan( + is_variable_B, + is_variable_C, + varBC_groups, + has_D, + has_z, + has_delta_bias, + delta_softplus, + seqlen, + itype, + wtype, + scan_chunks, +): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable - device = 'cuda' + device = "cuda" rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 @@ -242,7 +267,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, batch_size = 1 dim = 4 dstate = 8 - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype) A_ref = A.clone() if not is_variable_B: B_shape = [dim, dstate] @@ -250,9 +275,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, B_shape = [batch_size, dstate, seqlen] else: B_shape = [batch_size, varBC_groups, dstate, seqlen] - B = torch.randn(B_shape, - device=device, - dtype=wtype if not is_variable_B else itype) + B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype) B_ref = B.clone() if not is_variable_C: C_shape = [dim, dstate] @@ -260,27 +283,27 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, C_shape = [batch_size, dstate, seqlen] else: C_shape = [batch_size, varBC_groups, dstate, seqlen] - C = torch.randn(C_shape, - device=device, - dtype=wtype if not is_variable_C else itype) + C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C_ref = C.clone() D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None D_ref = D.clone() - z = torch.randn(batch_size, dim, seqlen, device=device, - dtype=itype) if has_z else None + z = ( + torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + if has_z + else None + ) z_ref = z.clone() if has_z 
else None - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) - ) if has_delta_bias else None + delta_bias = ( + (0.5 * torch.rand(dim, device=device, dtype=torch.float32)) + if has_delta_bias + else None + ) u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) u_ref = u.clone() - delta = (0.5 * - torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) + delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype) delta_ref = delta.clone() state_shape = (batch_size, u.shape[1], int(A.shape[1])) - state = torch.randn(state_shape, - device=u.device, - dtype=itype, - requires_grad=False) + state = torch.randn(state_shape, device=u.device, dtype=itype, requires_grad=False) state_ref = state.clone() out = None out_ref = None @@ -312,9 +335,10 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, z=_z, delta_bias=delta_bias, delta_softplus=delta_softplus, - has_initial_state=torch.ones(batch_size, - device=u.device, - dtype=torch.bool) if c > 0 else None) + has_initial_state=torch.ones(batch_size, device=u.device, dtype=torch.bool) + if c > 0 + else None, + ) outs.append(out) if len(outs) > 1: out = torch.cat(outs, dim=-1) @@ -329,27 +353,29 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, z=z_ref, delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=True) + return_last_state=True, + ) assert out is not None and out_ref is not None assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) assert state is not None and state_ref is not None assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol) - selective_scan_opcheck_fn(u, - delta, - A, - B, - C, - D, - z, - delta_bias=delta_bias, - delta_softplus=delta_softplus, - ssm_states=state) + selective_scan_opcheck_fn( + u, + delta, + A, + B, + C, + D, + z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + ssm_states=state, + ) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) @@ -374,52 +400,47 @@ def test_selective_state_update(dim, dstate, has_z, itype): D = torch.randn(dim, device=device) z = torch.randn_like(x) if has_z else None state_ref = state.detach().clone() - selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - out=out) - out_ref = selective_state_update_ref(state_ref, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True) + selective_state_update( + state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out + ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize('wtype', [torch.float32]) -@pytest.mark.parametrize('itype', [torch.float32]) -@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("wtype", [torch.float32]) +@pytest.mark.parametrize("itype", [torch.float32]) +@pytest.mark.parametrize("seqlen", [1, 128, 129, 256, 512, 1024, 2048, 4096]) @pytest.mark.parametrize("return_last_state", [True]) -@pytest.mark.parametrize('has_delta_bias', [True]) 
-@pytest.mark.parametrize('delta_softplus', [True]) -@pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("has_delta_bias", [True]) +@pytest.mark.parametrize("delta_softplus", [True]) +@pytest.mark.parametrize("has_z", [True]) +@pytest.mark.parametrize("has_D", [True]) @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [False, True]) -def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, - varBC_groups, has_D, has_z, has_delta_bias, - delta_softplus, return_last_state, seqlen, - itype, wtype): +def test_selective_scan_varlen( + with_padding, + is_variable_B, + is_variable_C, + varBC_groups, + has_D, + has_z, + has_delta_bias, + delta_softplus, + return_last_state, + seqlen, + itype, + wtype, +): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable - device = 'cuda' + device = "cuda" rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 3e-2, 5e-2 @@ -443,72 +464,79 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( - torch.cat( - [torch.tensor([-1]), eos_pos, - torch.tensor([seqlen - 1])])).tolist()) + torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])]) + ).tolist() + ) assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], - dim=0).cuda() + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).cuda() dim = 4 dstate = 8 - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype) A_ref = A.clone() B_shape = [varBC_groups, dstate, seqlen] - B = torch.randn(B_shape, - device=device, - dtype=wtype if not is_variable_B else itype) + B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype) B_ref = B.clone() C_shape = [varBC_groups, dstate, seqlen] - C = torch.randn(C_shape, - device=device, - dtype=wtype if not is_variable_C else itype) + C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) C_ref = C.clone() D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None D_ref = D.clone() z = torch.randn(dim, seqlen, device=device, dtype=itype) z_ref = z.clone() - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) - ) if has_delta_bias else None + delta_bias = ( + (0.5 * torch.rand(dim, device=device, dtype=torch.float32)) + if has_delta_bias + else None + ) u = torch.randn(dim, seqlen, device=device, dtype=itype) u_ref = u.clone() - delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)) + delta = 0.5 * torch.rand(dim, seqlen, device=device, dtype=itype) delta_ref = delta.clone() out = None out_ref = None prev_state_shape = (total_entries, u.shape[0], int(A.shape[1])) - prev_state = torch.randn(prev_state_shape, - device=u.device, - dtype=itype, - requires_grad=False) + prev_state = torch.randn( + prev_state_shape, device=u.device, dtype=itype, requires_grad=False + ) prev_state_ref = 
prev_state.clone() - state_indices = torch.randperm(total_entries, - dtype=torch.int32, - device=u.device)[:batch_size] - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) + state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[ + :batch_size + ] + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[state_indices] = False - padded_state_indices = torch.concat([ - state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), - ], - dim=-1) - - has_initial_state = torch.randint(0, - 2, (cumsum.shape[0] - 1, ), - dtype=torch.bool, - device=u.device) - out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, padded_state_indices, - has_initial_state) + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1, + ) + + has_initial_state = torch.randint( + 0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=u.device + ) + out = selective_scan_fn( + u, + prev_state, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cumsum, + padded_state_indices, + has_initial_state, + ) outs_ref = [] splits = [ torch.split(var, seqlens[0], dim=-1) @@ -530,33 +558,46 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, delta_softplus=delta_softplus, return_last_state=return_last_state, prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) - if has_initial_state[i] else None, - final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze( - 0)) + if has_initial_state[i] + else None, + final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(0), + ) outs_ref.append(out_ref_s) out_ref = torch.cat(outs_ref, dim=-1)[0] - unpadded_out = out[:, :out_ref[0].shape[-1]] + unpadded_out = out[:, : out_ref[0].shape[-1]] print("Output diff max", (unpadded_out - out_ref).max()) print("Output diff mean", (unpadded_out - out_ref).mean()) print("Output state diff max", (prev_state - prev_state_ref).max()) print("Output state diff mean", (prev_state - prev_state_ref).mean()) assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol) assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol) - selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, padded_state_indices, - has_initial_state, prev_state) - - -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) + selective_scan_opcheck_fn( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + cumsum, + padded_state_indices, + has_initial_state, + prev_state, + ) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) -def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, - has_z, itype): +def test_selective_state_update_with_batch_indices( + with_padding, dim, dstate, has_z, itype +): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: @@ -571,17 +612,17 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, total_entries = 10 * batch_size 
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to( - dtype=torch.int32, device=device) - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) + dtype=torch.int32, device=device + ) + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) unused_states_bool[state_indices] = False - padded_state_indices = torch.concat([ - state_indices, - torch.as_tensor( - [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) - ], - dim=0) + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=0, + ) x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) out = torch.empty_like(x) dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) @@ -593,61 +634,60 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, z = torch.randn_like(x) if has_z else None state_ref = state[state_indices, :].clone() state_before = state.clone() - selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=padded_state_indices, - pad_slot_id=PAD_SLOT_ID, - out=out) - out_ref = selective_state_update_ref(state_ref, - x[:batch_size], - dt[:batch_size], - A, - B[:batch_size], - C[:batch_size], - D=D, - z=z[:batch_size], - dt_bias=dt_bias, - dt_softplus=True) + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + out=out, + ) + out_ref = selective_state_update_ref( + state_ref, + x[:batch_size], + dt[:batch_size], + A, + B[:batch_size], + C[:batch_size], + D=D, + z=z[:batch_size], + dt_bias=dt_bias, + dt_softplus=True, + ) print("Output diff max", (out[:batch_size] - out_ref).max()) print("Output diff mean", (out[:batch_size] - out_ref).mean()) print("Output state diff max", (state[state_indices, :] - state_ref).max()) - print("Output state diff mean", - (state[state_indices, :] - state_ref).mean()) + print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) # test padded entries stay the same if with_padding: - assert torch.equal(state_before[unused_states_bool], - state[unused_states_bool]) - assert torch.equal(x[batch_size + 1:], x[batch_size + 1:]) - assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:]) - assert torch.equal(B[batch_size + 1:], B[batch_size + 1:]) - assert torch.equal(C[batch_size + 1:], C[batch_size + 1:]) + assert torch.equal(state_before[unused_states_bool], state[unused_states_bool]) + assert torch.equal(x[batch_size + 1 :], x[batch_size + 1 :]) + assert torch.equal(dt[batch_size + 1 :], dt[batch_size + 1 :]) + assert torch.equal(B[batch_size + 1 :], B[batch_size + 1 :]) + assert torch.equal(C[batch_size + 1 :], C[batch_size + 1 :]) # test "real" entries - assert torch.allclose(state[state_indices, :], - state_ref, - rtol=rtol, - atol=atol) + assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("tie_hdim", [False, True]) @pytest.mark.parametrize("ngroups", [1, 2, 
4]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 4096]) def test_selective_state_update_with_heads_with_batch_indices( - dim, dstate, ngroups, has_z, tie_hdim, itype): + dim, dstate, ngroups, has_z, tie_hdim, itype +): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2) if itype == torch.bfloat16: @@ -659,71 +699,55 @@ def test_selective_state_update_with_heads_with_batch_indices( nheads = dim // headdim total_entries = 10 * batch_size - state = torch.randn(total_entries, - nheads, - headdim, - dstate, - dtype=itype, - device=device) + state = torch.randn( + total_entries, nheads, headdim, dstate, dtype=itype, device=device + ) state_indices = torch.randperm(total_entries)[:batch_size].to( - dtype=torch.int32, device=device) + dtype=torch.int32, device=device + ) x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) out = torch.empty_like(x) if not tie_hdim: - dt = torch.randn(batch_size, - nheads, - headdim, - device=device, - dtype=itype) + dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) dt_bias = torch.rand(nheads, headdim, device=device) - 4.0 A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0 D = torch.randn(nheads, headdim, device=device) else: - dt = repeat(torch.randn(batch_size, nheads, device=device, - dtype=itype), - "b h -> b h p", - p=headdim) - dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, - "h -> h p", - p=headdim) - A = repeat(-torch.rand(nheads, device=device) - 1.0, - "h -> h p n", - p=headdim, - n=dstate) + dt = repeat( + torch.randn(batch_size, nheads, device=device, dtype=itype), + "b h -> b h p", + p=headdim, + ) + dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim) + A = repeat( + -torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate + ) D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim) B = torch.randn(batch_size, ngroups, dstate, device=device) C = torch.randn(batch_size, ngroups, dstate, device=device) z = torch.randn_like(x) if has_z else None state_ref = state[state_indices, :].detach().clone() - selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=state_indices, - pad_slot_id=PAD_SLOT_ID, - out=out) - out_ref = selective_state_update_ref(state_ref, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True) + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices, + pad_slot_id=PAD_SLOT_ID, + out=out, + ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - assert torch.allclose(state[state_indices, :], - state_ref, - rtol=rtol, - atol=atol) + assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index 9798b27cae76..57dcb789e97b 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -7,10 +7,10 @@ from einops import rearrange, repeat from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - mamba_chunk_scan_combined_varlen) + 
mamba_chunk_scan_combined_varlen, +) from vllm.platforms import current_platform -from vllm.v1.attention.backends.mamba2_attn import ( - compute_varlen_chunk_metadata) +from vllm.v1.attention.backends.mamba2_attn import compute_varlen_chunk_metadata # Added by the IBM Team, 2024 @@ -22,12 +22,10 @@ def segsum(x): """Calculates segment sum.""" T = x.size(-1) x = repeat(x, "... d -> ... d e", e=T) - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), - diagonal=-1) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) x = x.masked_fill(~mask, 0) x_segsum = torch.cumsum(x, dim=-2) - mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), - diagonal=0) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) x_segsum = x_segsum.masked_fill(~mask, -torch.inf) return x_segsum @@ -46,8 +44,9 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): assert X.shape[1] % block_len == 0 # Rearrange into blocks/chunks - X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len) - for x in (X, A, B, C)) + X, A, B, C = ( + rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C) + ) A = rearrange(A, "b c l h -> b h c l") A_cumsum = torch.cumsum(A, dim=-1) @@ -74,7 +73,7 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): # 4. Compute state -> output conversion per chunk # (left term of low-rank factorization of off-diagonal blocks; C terms) state_decay_out = torch.exp(A_cumsum) - Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out) # Add output of intra-chunk and inter-chunk terms # (diagonal and off-diagonal blocks) @@ -82,42 +81,31 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): return Y, final_state -def generate_random_inputs(batch_size, - seqlen, - n_heads, - d_head, - itype, - device='cuda'): - +def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"): current_platform.seed_everything(0) - A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device))) + A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device)) dt = F.softplus( - torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - - 4) - X = torch.randn((batch_size, seqlen, n_heads, d_head), - dtype=itype, - device=device) - B = torch.randn((batch_size, seqlen, n_heads, d_head), - dtype=itype, - device=device) - C = torch.randn((batch_size, seqlen, n_heads, d_head), - dtype=itype, - device=device) + torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4 + ) + X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) return A, dt, X, B, C -def generate_continuous_batched_examples(example_lens_by_batch, - num_examples, - full_length, - last_taken, - exhausted, - n_heads, - d_head, - itype, - device='cuda', - return_naive_ref=True): - +def generate_continuous_batched_examples( + example_lens_by_batch, + num_examples, + full_length, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device="cuda", + return_naive_ref=True, +): # this function generates a random examples of certain length # and then cut according to "example_lens_by_batch" and feed # them in continuous batches to the kernels. 
@@ -126,23 +114,20 @@ def generate_continuous_batched_examples(example_lens_by_batch, # reference output. # generate the full-length example - A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, - d_head, itype) + A, dt, X, B, C = generate_random_inputs( + num_examples, full_length, n_heads, d_head, itype + ) if return_naive_ref: - Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), - A * dt, - B, - C, - block_len=full_length // - 4) + Y_min, final_state_min = ssd_minimal_discrete( + X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4 + ) # internal function that outputs a cont batch of examples # given a tuple of lengths for each example in the batch # e.g., example_lens=(8, 4) means take 8 samples from first eg, # 4 examples from second eg, etc def get_continuous_batch(example_lens: tuple[int, ...]): - indices = [] for i, x in enumerate(example_lens): c = last_taken.get(i, 0) @@ -150,8 +135,10 @@ def get_continuous_batch(example_lens: tuple[int, ...]): last_taken[i] = (c + x) % full_length exhausted[i] = last_taken[i] == 0 - return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices) - ]).unsqueeze(0) for x in (dt, X, B, C)) + return ( + torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0) + for x in (dt, X, B, C) + ) # internal function that maps "n" to the appropriate right boundary # value when forming continuous batches from examples of length given @@ -163,19 +150,20 @@ def end_boundary(n: int): IND_E = None for spec in example_lens_by_batch: - # get the (maybe partial) example seen in this cont batch dt2, X2, B2, C2 = get_continuous_batch(spec) # get the metadata - cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) - seq_idx = torch.zeros(cu_seqlens[-1], - dtype=torch.int32, - device=cu_seqlens.device) - for i, (srt, end) in enumerate(zip( + cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0) + seq_idx = torch.zeros( + cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device + ) + for i, (srt, end) in enumerate( + zip( cu_seqlens, cu_seqlens[1:], - )): + ) + ): seq_idx[srt:end] = i # for cont batch @@ -190,19 +178,21 @@ def end_boundary(n: int): X2 = X2.squeeze(0) B2 = B2.squeeze(0) C2 = C2.squeeze(0) - yield ([Y_min[s, IND_S[s]:IND_E[s]] - for s in range(num_examples)] if return_naive_ref else None, - cu_seqlens, seq_idx, (A, dt2, X2, B2, C2)) + yield ( + [Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)] + if return_naive_ref + else None, + cu_seqlens, + seq_idx, + (A, dt2, X2, B2, C2), + ) -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) @pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) @pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)]) -def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, - itype): - +def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype): # this tests the kernels on a single example (bs=1) # TODO: the bfloat16 case requires higher thresholds. To be investigated @@ -219,15 +209,16 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, # it is not an operational limitation. 
seqlen, chunk_size = seq_len_chunk_size - A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, - d_head, itype) + A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype) - Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, - B, C, chunk_size) + Y_min, final_state_min = ssd_minimal_discrete( + X * dt.unsqueeze(-1), A * dt, B, C, chunk_size + ) cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0) cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( - compute_varlen_chunk_metadata(cu_seqlens, chunk_size)) + compute_varlen_chunk_metadata(cu_seqlens, chunk_size) + ) # varlen has implicit batch=1 X = X.squeeze(0) dt = dt.squeeze(0) @@ -255,10 +246,12 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, # just test the last head # NOTE, in the kernel we always cast states to fp32 - torch.testing.assert_close(final_state[:, -1].to(torch.float32), - final_state_min[:, -1].to(torch.float32), - atol=atol, - rtol=rtol) + torch.testing.assert_close( + final_state[:, -1].to(torch.float32), + final_state_min[:, -1].to(torch.float32), + atol=atol, + rtol=rtol, + ) @pytest.mark.parametrize("itype", [torch.float32, torch.float16]) @@ -267,32 +260,40 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, @pytest.mark.parametrize( "seq_len_chunk_size_cases", [ - # small-ish chunk_size (8) (64, 8, 2, [(64, 32), (64, 32)]), (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary - (64, 8, 2, [(4, 4), (4, 4), (4, 4), - (4, 4)]), # chunk_size larger than cont batches - (64, 8, 5, [ - (64, 32, 16, 8, 8), - (8, 16, 32, 16, 8), - (8, 8, 16, 32, 16), - ]), # mode examples with varied lengths - + ( + 64, + 8, + 2, + [(4, 4), (4, 4), (4, 4), (4, 4)], + ), # chunk_size larger than cont batches + ( + 64, + 8, + 5, + [ + (64, 32, 16, 8, 8), + (8, 16, 32, 16, 8), + (8, 8, 16, 32, 16), + ], + ), # mode examples with varied lengths # large-ish chunk_size (256) - (64, 256, 1, [(5, ), (1, ), (1, ), - (1, )]), # irregular sizes with small sequences - (64, 256, 2, [(5, 30), (1, 2), (1, 2), - (1, 2)]), # irregular sizes with small sequences - + (64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences + ( + 64, + 256, + 2, + [(5, 30), (1, 2), (1, 2), (1, 2)], + ), # irregular sizes with small sequences # we also need to test some large seqlen # to catch errors with init states decay (768, 128, 2, [(138, 225), (138, 225)]), - ]) -def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, - itype): - + ], +) +def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype): # this test with multiple examples in a continuous batch # (i.e. 
chunked prefill) @@ -311,12 +312,17 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, states = None for Y_min, cu_seqlens, _token_seq_idx, ( - A, dt, X, B, C) in generate_continuous_batched_examples( - cases, num_examples, seqlen, last_taken, exhausted, n_heads, - d_head, itype): - + A, + dt, + X, + B, + C, + ) in generate_continuous_batched_examples( + cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype + ): cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( - compute_varlen_chunk_metadata(cu_seqlens, chunk_size)) + compute_varlen_chunk_metadata(cu_seqlens, chunk_size) + ) Y = torch.empty_like(X) new_states = mamba_chunk_scan_combined_varlen( @@ -337,9 +343,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, # just test the last in sequence for i in range(num_examples): - # just test one dim and dstate - Y_eg = Y[cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] + Y_eg = Y[cu_seqlens[i] : cu_seqlens[i + 1], 0, 0] Y_min_eg = Y_min[i][:, 0, 0] torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol) @@ -347,18 +352,20 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, states = new_states for i, clear in exhausted.items(): if clear: - states[i].fill_(0.) + states[i].fill_(0.0) exhausted[i] = False @pytest.mark.parametrize("chunk_size", [8, 256]) -@pytest.mark.parametrize("seqlens", [ - (16, 2, 8, 13), - (270, 88, 212, 203), - (16, 20), -]) +@pytest.mark.parametrize( + "seqlens", + [ + (16, 2, 8, 13), + (270, 88, 212, 203), + (16, 20), + ], +) def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): - # This test verifies the correctness of the chunked prefill implementation # in the mamba2 ssd kernels, by comparing concatenation (in the sequence # dimension) of chunked results with the full sequence result. @@ -387,21 +394,25 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): last_taken: dict = {} # map: eg -> pointer to last taken sample exhausted: dict = {} # map: eg -> boolean indicating example is exhausted _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next( - generate_continuous_batched_examples([seqlens], - num_sequences, - max_seqlen, - last_taken, - exhausted, - n_heads, - d_head, - itype, - return_naive_ref=False)) + generate_continuous_batched_examples( + [seqlens], + num_sequences, + max_seqlen, + last_taken, + exhausted, + n_heads, + d_head, + itype, + return_naive_ref=False, + ) + ) seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device) device = X.device ## full seqlen computation cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( - compute_varlen_chunk_metadata(cu_seqlens, chunk_size)) + compute_varlen_chunk_metadata(cu_seqlens, chunk_size) + ) Y_ref = torch.empty_like(X) state_ref = mamba_chunk_scan_combined_varlen( X, @@ -422,28 +433,35 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): ## chunked seqlen computation # first chunk chunked_seqlens = seqlens // 2 - chunked_cu_seqlens = torch.cat([ - torch.tensor([0], device=device), - torch.cumsum(chunked_seqlens, dim=0) - ], - dim=0) + chunked_cu_seqlens = torch.cat( + [torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], dim=0 + ) chunked_input_seq_len = chunked_cu_seqlens[-1] X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...] dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...] B_chunked = torch.zeros_like(B)[:chunked_input_seq_len, ...] 
C_chunked = torch.zeros_like(C)[:chunked_input_seq_len, ...] for i in range(num_sequences): - # fmt: off - chunk_f = lambda x, i: x[cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 + chunk_f = lambda x, i: x[ + cu_seqlens[i] : cu_seqlens[i] + chunked_seqlens[i], ... + ] - X_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501 - dt_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501 - B_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501 - C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501 - # fmt: on + X_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + X, i + ) + dt_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + dt, i + ) + B_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + B, i + ) + C_chunked[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...] = chunk_f( + C, i + ) cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( - compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size)) + compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size) + ) Y_partial = torch.empty_like(X_chunked) partial_state = mamba_chunk_scan_combined_varlen( X_chunked, @@ -463,33 +481,50 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): # remaining chunk remaining_chunked_seqlens = seqlens - chunked_seqlens - remaining_chunked_cu_seqlens = torch.cat([ - torch.tensor([0], device=device), - torch.cumsum(remaining_chunked_seqlens, dim=0) - ], - dim=0) + remaining_chunked_cu_seqlens = torch.cat( + [ + torch.tensor([0], device=device), + torch.cumsum(remaining_chunked_seqlens, dim=0), + ], + dim=0, + ) remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] - # fmt: off - remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501 - remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] # noqa: E501 - remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] # noqa: E501 - remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] + remaining_dt_chunked = torch.zeros_like(dt)[:remaining_chunked_input_seq_len, ...] + remaining_B_chunked = torch.zeros_like(B)[:remaining_chunked_input_seq_len, ...] + remaining_C_chunked = torch.zeros_like(C)[:remaining_chunked_input_seq_len, ...] for i in range(num_sequences): - remaining_chunk_f = lambda x, i: x[cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 - - remaining_X_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 - remaining_dt_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501 - remaining_B_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501 - remaining_C_chunked[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501 + remaining_chunk_f = lambda x, i: x[ + cu_seqlens[i] + chunked_seqlens[i] : cu_seqlens[i + 1], ... + ] + + remaining_X_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... 
+ ] = remaining_chunk_f(X, i) + remaining_dt_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(dt, i) + remaining_B_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(B, i) + remaining_C_chunked[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], ... + ] = remaining_chunk_f(C, i) # assert input chunking is correct - concat_chunk_f = lambda pt1, pt2, i: torch.cat([ - pt1[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], - pt2[remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], + concat_chunk_f = lambda pt1, pt2, i: torch.cat( + [ + pt1[chunked_cu_seqlens[i] : chunked_cu_seqlens[i + 1], ...], + pt2[ + remaining_chunked_cu_seqlens[i] : remaining_chunked_cu_seqlens[i + 1], + ..., + ], ], - dim=0) - concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0) # noqa: E501 - # fmt: on + dim=0, + ) + concat_batch_f = lambda pt1, pt2: torch.cat( + [concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=0 + ) assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X) assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt) @@ -497,8 +532,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C) cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = ( - compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens, - chunk_size)) + compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens, chunk_size) + ) Y_chunked = torch.empty_like(remaining_X_chunked) state_chunked = mamba_chunk_scan_combined_varlen( @@ -520,20 +555,22 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): # kernel chunked is same as kernel overall for i in range(num_sequences): - Y_seq = Y[cu_seqlens[i]:cu_seqlens[i + 1], ...] - Y_ref_seq = Y_ref[cu_seqlens[i]:cu_seqlens[i + 1], ...] + Y_seq = Y[cu_seqlens[i] : cu_seqlens[i + 1], ...] + Y_ref_seq = Y_ref[cu_seqlens[i] : cu_seqlens[i + 1], ...] 
torch.testing.assert_close( - Y_seq[:chunked_seqlens[i], ...], - Y_ref_seq[:chunked_seqlens[i], ...], + Y_seq[: chunked_seqlens[i], ...], + Y_ref_seq[: chunked_seqlens[i], ...], atol=atol, rtol=rtol, - msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023 + msg=lambda x, i=i: f"seq{i} output part1 " + x, + ) torch.testing.assert_close( - Y_seq[chunked_seqlens[i]:, ...], - Y_ref_seq[chunked_seqlens[i]:, ...], + Y_seq[chunked_seqlens[i] :, ...], + Y_ref_seq[chunked_seqlens[i] :, ...], atol=atol, rtol=rtol, - msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023 + msg=lambda x, i=i: f"seq{i} output part2 " + x, + ) state_seq = state_chunked[i] state_seq_ref = state_ref[i] @@ -542,4 +579,5 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): state_seq_ref, atol=atol, rtol=rtol, - msg=lambda x: f"seq{i} state " + x) # noqa: B023 + msg=lambda x, i=i: f"seq{i} state " + x, + ) diff --git a/tests/kernels/moe/modular_kernel_tools/cli_args.py b/tests/kernels/moe/modular_kernel_tools/cli_args.py index b95d87cd04f5..d46847fbf6a3 100644 --- a/tests/kernels/moe/modular_kernel_tools/cli_args.py +++ b/tests/kernels/moe/modular_kernel_tools/cli_args.py @@ -9,18 +9,19 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from .common import Config -from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, - MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) +from .mk_objects import ( + MK_ALL_PREPARE_FINALIZE_TYPES, + MK_FUSED_EXPERT_TYPES, + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, +) def make_config_arg_parser(description: str): - def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize: for pf in MK_ALL_PREPARE_FINALIZE_TYPES: if pf.__name__ == s: return pf - raise ValueError( - f"Cannot find a PrepareFinalize type that matches {s}") + raise ValueError(f"Cannot find a PrepareFinalize type that matches {s}") def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute: for fe in MK_FUSED_EXPERT_TYPES: @@ -45,15 +46,18 @@ def to_quant_torch_dtype(s: str) -> torch.dtype: "--pf-type", type=to_pf_class_type, required=True, - help=("Choose a PrepareFinalize Type : " - f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"), + help=( + "Choose a PrepareFinalize Type : " + f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}" + ), ) parser.add_argument( "--experts-type", type=to_experts_class_type, required=True, - help=(f"Choose a FusedExpert type : " - f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"), + help=( + f"Choose a FusedExpert type : {[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}" + ), ) parser.add_argument( "-m", @@ -74,66 +78,65 @@ def to_quant_torch_dtype(s: str) -> torch.dtype: default=1024, help="N dimension of the first fused-moe matmul", ) - parser.add_argument("--num-experts", - type=int, - default=32, - help="Global num experts") - parser.add_argument("--topk", - nargs="+", - type=int, - default=[4, 1], - help="num topk") + parser.add_argument( + "--num-experts", type=int, default=32, help="Global num experts" + ) + parser.add_argument("--topk", nargs="+", type=int, default=[4, 1], help="num topk") parser.add_argument( "--fused-moe-chunk-size", type=int, - help="Fused moe chunk size used for the non-batched fused experts impl." 
+ help="Fused moe chunk size used for the non-batched fused experts impl.", ) # Quant args - parser.add_argument("--quant-dtype", - type=to_quant_torch_dtype, - help="Quant datatype") - parser.add_argument("--per-token-quantized-activations", - action='store_true', - help=("The input activations must be per-token " - "quantized")) - parser.add_argument("--per-channel-quantized-weights", - action="store_true", - help="The weights must be per-channel quantized.") - parser.add_argument("--block-shape", - nargs="+", - type=int, - help="Quantization block shape") + parser.add_argument( + "--quant-dtype", type=to_quant_torch_dtype, help="Quant datatype" + ) + parser.add_argument( + "--per-token-quantized-activations", + action="store_true", + help=("The input activations must be per-token quantized"), + ) + parser.add_argument( + "--per-channel-quantized-weights", + action="store_true", + help="The weights must be per-channel quantized.", + ) + parser.add_argument( + "--block-shape", nargs="+", type=int, help="Quantization block shape" + ) # Torch trace profile generation args - parser.add_argument("--torch-trace-dir-path", - type=str, - default=None, - help="Get torch trace for single execution") + parser.add_argument( + "--torch-trace-dir-path", + type=str, + default=None, + help="Get torch trace for single execution", + ) return parser def _validate_args(args: argparse.Namespace): - if args.quant_dtype is not None: assert args.quant_dtype == torch.float8_e4m3fn if args.block_shape is not None: assert len(args.block_shape) == 2, ( - f"block shape must have 2 elements. got {args.block_shape}") + f"block shape must have 2 elements. got {args.block_shape}" + ) if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: - assert args.world_size == 1, ( - "Single GPU objects need world size set to 1") + assert args.world_size == 1, "Single GPU objects need world size set to 1" if args.torch_trace_dir_path is not None: from pathlib import Path + assert Path(args.torch_trace_dir_path).is_dir(), ( - f"Please create {args.torch_trace_dir_path}") + f"Please create {args.torch_trace_dir_path}" + ) def make_config(args: argparse.Namespace) -> Config: - _validate_args(args) quant_config = None @@ -142,7 +145,8 @@ def make_config(args: argparse.Namespace) -> Config: quant_dtype=args.quant_dtype, per_act_token_quant=args.per_token_quantized_activations, per_out_ch_quant=args.per_channel_quantized_weights, - block_shape=args.block_shape) + block_shape=args.block_shape, + ) return Config( Ms=args.m, @@ -156,4 +160,5 @@ def make_config(args: argparse.Namespace) -> Config: fused_experts_type=args.experts_type, fused_moe_chunk_size=args.fused_moe_chunk_size, world_size=args.world_size, - torch_trace_dir_path=args.torch_trace_dir_path) + torch_trace_dir_path=args.torch_trace_dir_path, + ) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index b5fcc4cd70bf..091fa4fafe21 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -8,20 +8,30 @@ import vllm._custom_ops as ops import vllm.model_executor.layers.fused_moe.modular_kernel as mk from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8 -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) from tests.kernels.utils import 
torch_experts from vllm.config import VllmConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig) + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx -from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts, - make_prepare_finalize, prepare_finalize_info) +from .mk_objects import ( + TestMoEQuantConfig, + expert_info, + make_fused_experts, + make_prepare_finalize, + prepare_finalize_info, +) from .parallel_utils import ProcessGroupInfo @@ -94,8 +104,7 @@ def is_per_act_token_quant(self) -> bool: @property def is_per_tensor_act_quant(self) -> bool: - return (not self.is_per_act_token_quant - and self.quant_block_shape is None) + return not self.is_per_act_token_quant and self.quant_block_shape is None @property def is_per_out_ch_quant(self) -> bool: @@ -134,23 +143,24 @@ def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]: if self.fused_moe_chunk_size is not None: env_dict.update( - {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) + {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)} + ) return vllm_config, env_dict def is_fp8_block_quantized(self): - return (self.quant_dtype == torch.float8_e4m3fn - and self.quant_block_shape is not None) + return ( + self.quant_dtype == torch.float8_e4m3fn + and self.quant_block_shape is not None + ) def is_batched_prepare_finalize(self): info = prepare_finalize_info(self.prepare_finalize_type) - return (mk.FusedMoEActivationFormat.BatchedExperts == - info.activation_format) + return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format def is_batched_fused_experts(self): info = expert_info(self.fused_experts_type) - return (mk.FusedMoEActivationFormat.BatchedExperts == - info.activation_format) + return mk.FusedMoEActivationFormat.BatchedExperts == info.activation_format def is_standard_fused_experts(self): info = expert_info(self.fused_experts_type) @@ -190,8 +200,10 @@ def needs_pplx(self): def needs_deep_ep(self): info = prepare_finalize_info(self.prepare_finalize_type) - return (info.backend == "deepep_high_throughput" - or info.backend == "deepep_low_latency") + return ( + info.backend == "deepep_high_throughput" + or info.backend == "deepep_low_latency" + ) def all2all_backend(self): info = prepare_finalize_info(self.prepare_finalize_type) @@ -211,20 +223,26 @@ def is_valid(self): return False # Check quantization sanity - if (int(self.is_per_act_token_quant) + - int(self.is_per_tensor_act_quant) + - int(self.quant_block_shape is not None)) > 1: + if ( + int(self.is_per_act_token_quant) + + int(self.is_per_tensor_act_quant) + + int(self.quant_block_shape is not None) + ) > 1: # invalid quant config return False # check type support if self.quant_dtype is None: - if (self.dtype not in self.pf_supported_types() - or self.dtype not in self.fe_supported_types()): + if ( + self.dtype not in self.pf_supported_types() + or self.dtype not in self.fe_supported_types() + ): return False else: - if (self.quant_dtype not in self.pf_supported_types() - or self.quant_dtype not in self.fe_supported_types()): + if ( + self.quant_dtype not in self.pf_supported_types() + or self.quant_dtype not in self.fe_supported_types() + ): return False # Check 
block quanization support @@ -261,18 +279,21 @@ class WeightTensors: def describe(self): s = "" s += "== Weight Tensors: \n" - s += f' - {_describe_tensor(self.w1, "w1")} \n' - s += f' - {_describe_tensor(self.w2, "w2")} \n' - s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' - s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' - s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n' - s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n' + s += f" - {_describe_tensor(self.w1, 'w1')} \n" + s += f" - {_describe_tensor(self.w2, 'w2')} \n" + s += f" - {_describe_tensor(self.w1_scale, 'w1_scale')} \n" + s += f" - {_describe_tensor(self.w2_scale, 'w2_scale')} \n" + s += f" - {_describe_tensor(self.w1_gs, 'w1_gs')} \n" + s += f" - {_describe_tensor(self.w2_gs, 'w2_gs')} \n" return s def is_quantized(self) -> bool: # or w1_scale is not None? - return (self.w1.dtype == torch.float8_e4m3fn - or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8) + return ( + self.w1.dtype == torch.float8_e4m3fn + or self.w1.dtype == torch.uint8 + or self.w1.dtype == torch.int8 + ) def to_current_device(self): device = torch.cuda.current_device() @@ -289,16 +310,13 @@ def to_current_device(self): if self.w2_gs is not None: self.w2_gs = self.w2_gs.to(device=device) - def slice_weights(self, rank: int, - num_local_experts: int) -> "WeightTensors": + def slice_weights(self, rank: int, num_local_experts: int) -> "WeightTensors": s = rank * num_local_experts e = s + num_local_experts w1 = self.w1[s:e, :, :] w2 = self.w2[s:e, :, :] - w1_scale = self.w1_scale[ - s:e, :, :] if self.w1_scale is not None else None - w2_scale = self.w2_scale[ - s:e, :, :] if self.w2_scale is not None else None + w1_scale = self.w1_scale[s:e, :, :] if self.w1_scale is not None else None + w2_scale = self.w2_scale[s:e, :, :] if self.w2_scale is not None else None w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None @@ -313,15 +331,12 @@ def make(config: Config) -> "WeightTensors": in_dtype=config.dtype, quant_dtype=config.quant_dtype, block_shape=config.quant_block_shape, - per_out_ch_quant=config. 
- is_per_act_token_quant, # or config.is_per_out_ch_quant + # or config.is_per_out_ch_quant + per_out_ch_quant=config.is_per_act_token_quant, + ) + return WeightTensors( + w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale, w1_gs=w1_gs, w2_gs=w2_gs ) - return WeightTensors(w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_gs=w1_gs, - w2_gs=w2_gs) @dataclass @@ -336,22 +351,22 @@ class RankTensors: def describe(self): s = "" s += "== Rank Tensors: \n" - s += f' - {_describe_tensor(self.hidden_states, "HS")} \n' - s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n' - s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n' - s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n' - s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n' + s += f" - {_describe_tensor(self.hidden_states, 'HS')} \n" + s += f" - {_describe_tensor(self.hidden_states_scale, 'HS_scale')} \n" + s += f" - {_describe_tensor(self.topk_weights, 'topk_weights')} \n" + s += f" - {_describe_tensor(self.topk_ids, 'topk_ids')} \n" + s += f" - {_describe_tensor(self.expert_map, 'expert_map')} \n" return s @staticmethod def make_hidden_states( - config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + config: Config, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Return hidden_states """ m, k, dtype = (config.M, config.K, config.dtype) - a = (torch.randn( - (m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0) + a = torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0 if config.quant_dtype is None: return a, None @@ -362,36 +377,29 @@ def make_hidden_states( # first - so further quantize and dequantize will yield the same # values. if config.is_per_tensor_act_quant: - a_q, a_scales = ops.scaled_fp8_quant( - a, use_per_token_if_dynamic=False) + a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=False) return a_q.float().mul(a_scales).to(dtype), a_scales if config.is_per_act_token_quant: - a_q, a_scales = ops.scaled_fp8_quant(a, - use_per_token_if_dynamic=True) + a_q, a_scales = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True) return a_q.float().mul(a_scales).to(dtype), None assert config.quant_block_shape is not None block_k = config.quant_block_shape[1] a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k) - return a_q.float().view( - (-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None + return a_q.float().view((-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to( + dtype + ), None @staticmethod def make(config: Config, pgi: ProcessGroupInfo): - dtype = config.dtype topk, m, _ = (config.topk, config.M, config.K) - hidden_states, hidden_states_scale = RankTensors.make_hidden_states( - config) + hidden_states, hidden_states_scale = RankTensors.make_hidden_states(config) - num_local_experts, global_num_experts = (config.num_local_experts, - config.E) - score = torch.randn((m, global_num_experts), - device="cuda", - dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, - False) + num_local_experts, global_num_experts = (config.num_local_experts, config.E) + score = torch.randn((m, global_num_experts), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, False) # distribute topk_ids evenly for mi in range(m): @@ -400,14 +408,15 @@ def make(config: Config, pgi: ProcessGroupInfo): expert_map = None if config.world_size > 1 and config.supports_expert_map(): - expert_map = torch.full((global_num_experts, ), - 
fill_value=-1, - dtype=torch.int32) + expert_map = torch.full( + (global_num_experts,), fill_value=-1, dtype=torch.int32 + ) s = pgi.rank * num_local_experts e = s + num_local_experts expert_map[s:e] = torch.tensor(list(range(num_local_experts))) - expert_map = expert_map.to(device=torch.cuda.current_device(), - dtype=torch.int32) + expert_map = expert_map.to( + device=torch.cuda.current_device(), dtype=torch.int32 + ) return RankTensors( hidden_states=hidden_states, @@ -418,9 +427,9 @@ def make(config: Config, pgi: ProcessGroupInfo): ) -def reference_moe_impl(config: Config, weights: WeightTensors, - rank_tensors: RankTensors) -> torch.Tensor: - +def reference_moe_impl( + config: Config, weights: WeightTensors, rank_tensors: RankTensors +) -> torch.Tensor: if config.quant_dtype == "nvfp4": quant_blocksize = 16 dtype = config.dtype @@ -433,8 +442,10 @@ def reference_moe_impl(config: Config, weights: WeightTensors, w2_blockscale = weights.w2_scale w2_gs = weights.w2_gs - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax( - rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) + / torch.amax(rank_tensors.hidden_states.flatten(), dim=-1) + ).to(torch.float32) assert w1_gs is not None assert w2_gs is not None @@ -447,14 +458,17 @@ def reference_moe_impl(config: Config, weights: WeightTensors, assert w2_blockscale.shape[2] % 4 == 0 a_fp4, a_scale_interleaved = ops.scaled_fp4_quant( - rank_tensors.hidden_states, a_global_scale) + rank_tensors.hidden_states, a_global_scale + ) - a = dequantize_nvfp4_to_dtype(a_fp4, - a_scale_interleaved, - a_global_scale, - dtype=dtype, - device=a_fp4.device, - block_size=quant_blocksize) + a = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=dtype, + device=a_fp4.device, + block_size=quant_blocksize, + ) e = w1_q.shape[0] n = w1_q.shape[1] // 2 @@ -464,18 +478,22 @@ def reference_moe_impl(config: Config, weights: WeightTensors, w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype) for idx in range(0, e): - w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], - w1_blockscale[idx], - w1_gs[idx], - dtype=dtype, - device=w1_q.device, - block_size=quant_blocksize) - w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=dtype, - device=w2_q.device, - block_size=quant_blocksize) + w1[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize, + ) + w2[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize, + ) a_scale = None w1_scale = None w2_scale = None @@ -493,27 +511,29 @@ def reference_moe_impl(config: Config, weights: WeightTensors, per_act_token_quant = config.is_per_act_token_quant block_shape = config.quant_block_shape - return torch_experts(a=a, - w1=w1, - w2=w2, - topk_weight=rank_tensors.topk_weights, - topk_ids=rank_tensors.topk_ids, - global_num_experts=config.E, - expert_map=None, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale, - quant_dtype=quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - apply_router_weights_on_input=config.topk == 1 - and config.supports_apply_weight_on_input()) + return torch_experts( + a=a, + w1=w1, + w2=w2, + topk_weight=rank_tensors.topk_weights, + topk_ids=rank_tensors.topk_ids, + global_num_experts=config.E, + expert_map=None, + 
w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + quant_dtype=quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + apply_router_weights_on_input=config.topk == 1 + and config.supports_apply_weight_on_input(), + ) def _make_gscale(num_experts: int) -> torch.Tensor: - return torch.ones((num_experts, ), - device=torch.cuda.current_device(), - dtype=torch.float32) + return torch.ones( + (num_experts,), device=torch.cuda.current_device(), dtype=torch.float32 + ) def make_modular_kernel( @@ -521,12 +541,12 @@ def make_modular_kernel( vllm_config: VllmConfig, quant_config: FusedMoEQuantConfig, ) -> mk.FusedMoEModularKernel: - def next_power_of_2(x): import math + if x == 0: return 1 - return 2**math.ceil(math.log2(x)) + return 2 ** math.ceil(math.log2(x)) # make moe config moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( @@ -546,9 +566,9 @@ def next_power_of_2(x): ) # make modular kernel - prepare_finalize = make_prepare_finalize(config.prepare_finalize_type, - config.all2all_backend(), moe, - quant_config) + prepare_finalize = make_prepare_finalize( + config.prepare_finalize_type, config.all2all_backend(), moe, quant_config + ) fused_experts = make_fused_experts( config.fused_experts_type, @@ -559,7 +579,8 @@ def next_power_of_2(x): ) modular_kernel = mk.FusedMoEModularKernel( - prepare_finalize=prepare_finalize, fused_experts=fused_experts) + prepare_finalize=prepare_finalize, fused_experts=fused_experts + ) return modular_kernel @@ -587,10 +608,8 @@ def run_modular_kernel( w1_scale=rank_weights.w1_scale, w2_scale=rank_weights.w2_scale, a1_scale=rank_tensors.hidden_states_scale, - g1_alphas=(1 / rank_weights.w1_gs) - if rank_weights.w1_gs is not None else None, - g2_alphas=(1 / rank_weights.w2_gs) - if rank_weights.w2_gs is not None else None, + g1_alphas=(1 / rank_weights.w1_gs) if rank_weights.w1_gs is not None else None, + g2_alphas=(1 / rank_weights.w2_gs) if rank_weights.w2_gs is not None else None, a1_gscale=gscale, a2_gscale=gscale, block_shape=config.quant_block_shape, @@ -603,38 +622,30 @@ def run_modular_kernel( # impls might update the tensor in place hidden_states = rank_tensors.hidden_states.clone() - topk_ids = rank_tensors.topk_ids.to( - mk.prepare_finalize.topk_indices_dtype()) + topk_ids = rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()) mk_kwargs = { - "hidden_states": - hidden_states, - "w1": - rank_weights.w1, - "w2": - rank_weights.w2, - "topk_weights": - rank_tensors.topk_weights, - "topk_ids": - topk_ids, - "expert_map": - rank_tensors.expert_map, - "global_num_experts": - config.E, - "apply_router_weight_on_input": - config.topk == 1 and config.supports_apply_weight_on_input(), + "hidden_states": hidden_states, + "w1": rank_weights.w1, + "w2": rank_weights.w2, + "topk_weights": rank_tensors.topk_weights, + "topk_ids": topk_ids, + "expert_map": rank_tensors.expert_map, + "global_num_experts": config.E, + "apply_router_weight_on_input": config.topk == 1 + and config.supports_apply_weight_on_input(), } num_tokens = rank_tensors.hidden_states.shape[0] - num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size, - device="cuda", - dtype=torch.int) + num_tokens_across_dp = torch.tensor( + [num_tokens] * config.world_size, device="cuda", dtype=torch.int + ) with set_forward_context( - None, - vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, + None, + vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, ): out = 
mk.forward(**mk_kwargs) diff --git a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py index c1037b60bf38..0ef306051c8a 100644 --- a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py +++ b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py @@ -10,14 +10,21 @@ from tqdm import tqdm from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG) +from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG from vllm.platforms import current_platform -from .common import (Config, RankTensors, WeightTensors, reference_moe_impl, - run_modular_kernel) -from .mk_objects import (MK_FUSED_EXPERT_TYPES, - MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS) +from .common import ( + Config, + RankTensors, + WeightTensors, + reference_moe_impl, + run_modular_kernel, +) +from .mk_objects import ( + MK_FUSED_EXPERT_TYPES, + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, + MK_QUANT_CONFIGS, +) from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config @@ -38,8 +45,9 @@ def rank_worker( # sanity check from vllm import envs + if config.fused_moe_chunk_size is not None: - assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() @@ -60,8 +68,7 @@ def rank_worker( rank_tensors = RankTensors.make(cfgx, pgi) # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, - rank_tensors) + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors) with set_current_vllm_config(vllm_config): ref_out = reference_moe_impl(cfgx, weights, rank_tensors) @@ -70,28 +77,27 @@ def rank_worker( def make_feature_matrix(csv_file_path: str): - from dataclasses import asdict import pandas as pd - def add_to_results(config: Config, - success: Result, - results_df: Optional[pd.DataFrame] = None): + def add_to_results( + config: Config, success: Result, results_df: Optional[pd.DataFrame] = None + ): config_dict = asdict(config) - config_dict['prepare_finalize_type'] = config_dict[ - 'prepare_finalize_type'].__name__ - config_dict['fused_experts_type'] = config_dict[ - 'fused_experts_type'].__name__ - config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant - quant_config_dict = config_dict['quant_config'] - del config_dict['quant_config'] + config_dict["prepare_finalize_type"] = config_dict[ + "prepare_finalize_type" + ].__name__ + config_dict["fused_experts_type"] = config_dict["fused_experts_type"].__name__ + config_dict["per_tensor_act_quant"] = config.is_per_tensor_act_quant + quant_config_dict = config_dict["quant_config"] + del config_dict["quant_config"] if quant_config_dict is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config_dict = asdict(quant_config) config_dict |= quant_config_dict - result_dict = config_dict | {'success': success.name} + result_dict = config_dict | {"success": success.name} result_df = pd.DataFrame([result_dict]) if results_df is None: @@ -112,22 +118,26 @@ def add_to_results(config: Config, Q_TYPES = MK_QUANT_CONFIGS combinations = list( - product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)) + product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES) + ) results_df: Optional[pd.DataFrame] = None for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in 
tqdm( - combinations): #noqa: E501 - config = Config(Ms=[m], - K=k, - N=n, - E=e, - topks=topks, - dtype=dtype, - prepare_finalize_type=pf_type, - fused_experts_type=experts_type, - quant_config=quant_config, - world_size=2, - fused_moe_chunk_size=None) + combinations + ): + config = Config( + Ms=[m], + K=k, + N=n, + E=e, + topks=topks, + dtype=dtype, + prepare_finalize_type=pf_type, + fused_experts_type=experts_type, + quant_config=quant_config, + world_size=2, + fused_moe_chunk_size=None, + ) success = None if config.is_valid(): @@ -135,9 +145,14 @@ def add_to_results(config: Config, try: weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() - parallel_launch_with_config(config.world_size, rank_worker, - vllm_config, env_dict, config, - weights) + parallel_launch_with_config( + config.world_size, + rank_worker, + vllm_config, + env_dict, + config, + weights, + ) success = Result.PASS except Exception as _: success = Result.FAIL @@ -150,25 +165,33 @@ def add_to_results(config: Config, results_df.to_csv(f"{csv_file_path}") -if __name__ == '__main__': +if __name__ == "__main__": import argparse from pathlib import Path - parser = argparse.ArgumentParser(description=( - "Make ModularKernel feature matrix \n" - "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501 - "-f ./feature_matrices/feature_matrix.csv")) - - parser.add_argument("-f", - "--feature-matrix-csv-file-path", - type=str, - required=True, - help="File name to Generate a .csv file") + + parser = argparse.ArgumentParser( + description=( + "Make ModularKernel feature matrix \n" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " # noqa: E501 + "-f ./feature_matrices/feature_matrix.csv" + ) + ) + + parser.add_argument( + "-f", + "--feature-matrix-csv-file-path", + type=str, + required=True, + help="File name to Generate a .csv file", + ) args = parser.parse_args() csv_path = args.feature_matrix_csv_file_path - assert csv_path.endswith( - 'csv'), f"Need a file path ending with .csv, got {csv_path}" - assert Path(csv_path).parent.is_dir( - ), f"Cannot find parent directory for {Path(csv_path).parent}" + assert csv_path.endswith("csv"), ( + f"Need a file path ending with .csv, got {csv_path}" + ) + assert Path(csv_path).parent.is_dir(), ( + f"Cannot find parent directory for {Path(csv_path).parent}" + ) make_feature_matrix(args.feature_matrix_csv_file_path) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 57a1da7b4b1a..566fb1e09d3b 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -8,24 +8,33 @@ # Fused experts and PrepareFinalize imports import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEQuantConfig) + BatchedDeepGemmExperts, +) +from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( + BatchedTritonOrDeepGemmExperts, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from 
vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts, NaiveBatchedExperts) -from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, - TritonExperts) + BatchedTritonExperts, + NaiveBatchedExperts, +) +from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase, TritonExperts from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported) + cutlass_fp4_supported, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_fp8_supported) + cutlass_fp8_supported, +) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.deep_gemm import is_deep_gemm_supported @@ -60,8 +69,7 @@ class ExpertInfo: needs_deep_gemm: bool = False -PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, - PrepareFinalizeInfo] = {} +PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {} EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {} MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = [] @@ -71,7 +79,10 @@ class ExpertInfo: standard_format = mk.FusedMoEActivationFormat.Standard batched_format = mk.FusedMoEActivationFormat.BatchedExperts common_float_types: list[Union[torch.dtype, str]] = [ - torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32 + torch.float8_e4m3fn, + torch.bfloat16, + torch.float16, + torch.float32, ] common_float_and_int_types = common_float_types + [torch.int8] nvfp4_types = ["nvfp4"] @@ -185,10 +196,12 @@ def expert_info(kind) -> ExpertInfo: # Disable on blackwell for now if has_deep_ep() and not current_platform.has_device_capability(100): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( + DeepEPHTPrepareAndFinalize, + ) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( + DeepEPLLPrepareAndFinalize, + ) register_prepare_and_finalize( DeepEPHTPrepareAndFinalize, @@ -208,7 +221,9 @@ def expert_info(kind) -> ExpertInfo: if has_pplx(): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, + ) + register_prepare_and_finalize( PplxPrepareAndFinalize, batched_format, @@ -217,13 +232,14 @@ def expert_info(kind) -> ExpertInfo: backend="pplx", ) -if (has_flashinfer_cutlass_fused_moe() - and current_platform.has_device_capability(100)): - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - FlashInferExperts) +if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability(100): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, + ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 FlashInferCutlassMoEPrepareAndFinalize, - create_flashinfer_prepare_finalize) + 
create_flashinfer_prepare_finalize, + ) register_prepare_and_finalize( FlashInferCutlassMoEPrepareAndFinalize, @@ -258,16 +274,18 @@ def expert_info(kind) -> ExpertInfo: needs_matching_quant=False, needs_deep_gemm=True, ) - register_experts( - DeepGemmExperts, - standard_format, - fp8_types, - blocked_quantization_support=True, - supports_chunking=True, - supports_expert_map=True, - needs_matching_quant=False, - needs_deep_gemm=True, - ), + ( + register_experts( + DeepGemmExperts, + standard_format, + fp8_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=False, + needs_deep_gemm=True, + ), + ) register_experts( BatchedTritonOrDeepGemmExperts, batched_format, @@ -290,8 +308,11 @@ def expert_info(kind) -> ExpertInfo: ) if cutlass_fp8_supported(): - from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8, - CutlassExpertsFp8) + from vllm.model_executor.layers.fused_moe import ( + CutlassBatchedExpertsFp8, + CutlassExpertsFp8, + ) + register_experts( CutlassExpertsFp8, standard_format, @@ -310,8 +331,8 @@ def expert_info(kind) -> ExpertInfo: ) if cutlass_fp4_supported(): - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassExpertsFp4) + from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4 + register_experts( CutlassExpertsFp4, standard_format, @@ -324,30 +345,40 @@ def expert_info(kind) -> ExpertInfo: MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [ None, # per-channel / per-column weights and per-tensor activations - TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=False, + block_shape=None, + ), # per-channel / per-column weights and per-token activations - TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=True, - per_act_token_quant=True, - block_shape=None), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=True, + block_shape=None, + ), # per-tensor weights and per-tensor activations - TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None, + ), # per-tensor weights and per-token activations - TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=True, - block_shape=None), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=True, + block_shape=None, + ), # block-quantized weights and 128 block per-token activations - TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=[128, 128]), + TestMoEQuantConfig( + quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=[128, 128], + ), # TODO (varun) : Should we test the following combinations ? 
# block-quantized weights and per-token activations # block-quantized weights and per-tensor activations @@ -355,10 +386,12 @@ def expert_info(kind) -> ExpertInfo: if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe(): MK_QUANT_CONFIGS += [ - TestMoEQuantConfig(quant_dtype="nvfp4", - per_out_ch_quant=False, - per_act_token_quant=False, - block_shape=None), + TestMoEQuantConfig( + quant_dtype="nvfp4", + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None, + ), ] @@ -370,12 +403,14 @@ def make_prepare_finalize( ) -> mk.FusedMoEPrepareAndFinalize: if backend != "naive" and backend is not None: prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize( - moe, quant_config) + moe, quant_config + ) assert prepare_finalize is not None return prepare_finalize elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: return create_flashinfer_prepare_finalize( - use_dp=moe.moe_parallel_config.dp_size > 1) + use_dp=moe.moe_parallel_config.dp_size > 1 + ) else: return MoEPrepareAndFinalizeNoEP() @@ -391,10 +426,10 @@ def make_cutlass_strides( n: int, k: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) return ab_strides1, ab_strides2, c_strides1, c_strides2 @@ -405,7 +440,6 @@ def make_fused_experts( num_dispatchers: int, N: int, ) -> mk.FusedMoEPermuteExpertsUnpermute: - batch_kwargs = { "max_num_tokens": moe.max_num_tokens, "num_dispatchers": num_dispatchers, diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py index 459b785e6504..7802129d3d48 100644 --- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -6,13 +6,11 @@ from typing import Any, Callable, Optional import torch -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec from vllm.config import VllmConfig, set_current_vllm_config -from vllm.distributed import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed import init_distributed_environment, initialize_model_parallel from vllm.utils import get_open_port ## Parallel Processes Utils @@ -30,10 +28,11 @@ class ProcessGroupInfo: device: torch.device -def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int, - local_rank: int): - +def _set_vllm_config( + vllm_config: VllmConfig, world_size: int, rank: int, local_rank: int +): import tempfile + temp_file = tempfile.mkstemp()[1] with set_current_vllm_config(vllm_config): @@ -46,13 +45,10 @@ def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int, ) initialize_model_parallel( - tensor_model_parallel_size=vllm_config.parallel_config. - tensor_parallel_size, - pipeline_model_parallel_size=vllm_config.parallel_config. 
- pipeline_parallel_size, + tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size, ) - cpu_group = torch.distributed.new_group(list(range(world_size)), - backend="gloo") + cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo") return cpu_group @@ -62,8 +58,7 @@ def _worker_parallel_launch( world_local_size: int, node_rank: int, init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, - P], None], + worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None], vllm_config: Optional[VllmConfig], env_dict: Optional[dict], *args: P.args, @@ -131,7 +126,8 @@ def parallel_launch_with_config( worker, vllm_config, env_dict, - ) + args, + ) + + args, nprocs=world_size, join=True, ) diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py index 0da6ee354352..48e5c4659b49 100644 --- a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -14,28 +14,31 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config -def do_profile(fn: Callable, - fn_kwargs: dict[Any, Any], - pgi: ProcessGroupInfo, - config: Config, - num_warmups: int = 5): +def do_profile( + fn: Callable, + fn_kwargs: dict[Any, Any], + pgi: ProcessGroupInfo, + config: Config, + num_warmups: int = 5, +): for _ in range(num_warmups): fn(**fn_kwargs) with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=True, - record_shapes=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + record_shapes=True, ) as tprof: fn(**fn_kwargs) torch.cuda.synchronize(torch.cuda.current_device()) # TODO (varun): Add a descriptive trace file name tprof.export_chrome_trace( - f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json") + f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json" + ) def profile_modular_kernel( @@ -82,6 +85,7 @@ def rank_worker( # sanity check from vllm import envs + if config.fused_moe_chunk_size is not None: assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE @@ -108,20 +112,25 @@ def rank_worker( def run(config: Config): weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() - parallel_launch_with_config(config.world_size, rank_worker, vllm_config, - env_dict, config, weights) + parallel_launch_with_config( + config.world_size, rank_worker, vllm_config, env_dict, config, weights + ) -if __name__ == '__main__': +if __name__ == "__main__": from .cli_args import make_config, make_config_arg_parser - parser = make_config_arg_parser(description=( - "Run single prepare-finalize & fused-experts combination test" - "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501 - "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" - )) + + parser = make_config_arg_parser( + description=( + "Run single prepare-finalize & fused-experts combination test" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " # noqa: E501 + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + ) + ) args = parser.parse_args() assert 
args.torch_trace_dir_path is not None, ( - "Please pass in a directory to store torch traces") + "Please pass in a directory to store torch traces" + ) config = make_config(args) run(config) diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index 1ad361ae0733..fb9e5df281f1 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -3,6 +3,7 @@ """ DeepEP test utilities """ + import dataclasses import os import traceback @@ -10,17 +11,18 @@ import torch from torch.distributed import ProcessGroup -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] +from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec from vllm.utils import get_open_port, has_deep_ep if has_deep_ep(): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( + DeepEPHTPrepareAndFinalize, + ) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( + DeepEPLLPrepareAndFinalize, + ) ## Parallel Processes Utils @@ -96,7 +98,8 @@ def parallel_launch( 0, f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}", worker, - ) + args, + ) + + args, nprocs=world_size, join=True, ) @@ -118,48 +121,57 @@ class DeepEPLLArgs: use_fp8_dispatch: bool -def make_deepep_ht_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - dp_size: int, - ht_args: DeepEPHTArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): - +def make_deepep_ht_a2a( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + ht_args: DeepEPHTArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +): import deep_ep # high throughput a2a num_nvl_bytes = 1024 * 1024 * 1024 # 1GB num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1 - buffer = deep_ep.Buffer(group=pg, - num_nvl_bytes=num_nvl_bytes, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=low_latency_mode, - num_qps_per_rank=num_qps_per_rank) - return DeepEPHTPrepareAndFinalize(buffer=buffer, - num_dispatchers=pgi.world_size, - dp_size=dp_size, - rank_expert_offset=pgi.rank * - ht_args.num_local_experts) - - -def make_deepep_ll_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - deepep_ll_args: DeepEPLLArgs, - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): + buffer = deep_ep.Buffer( + group=pg, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=low_latency_mode, + num_qps_per_rank=num_qps_per_rank, + ) + return DeepEPHTPrepareAndFinalize( + buffer=buffer, + num_dispatchers=pgi.world_size, + dp_size=dp_size, + rank_expert_offset=pgi.rank * ht_args.num_local_experts, + ) + +def make_deepep_ll_a2a( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + deepep_ll_args: DeepEPLLArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +): import deep_ep # low-latency a2a num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size, - pgi.world_size, deepep_ll_args.num_experts) + deepep_ll_args.max_tokens_per_rank, + deepep_ll_args.hidden_size, + pgi.world_size, + deepep_ll_args.num_experts, + ) - buffer 
= deep_ep.Buffer(group=pg, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=True, - num_qps_per_rank=deepep_ll_args.num_experts // - pgi.world_size) + buffer = deep_ep.Buffer( + group=pg, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=deepep_ll_args.num_experts // pgi.world_size, + ) return DeepEPLLPrepareAndFinalize( buffer=buffer, @@ -169,17 +181,20 @@ def make_deepep_ll_a2a(pg: ProcessGroup, ) -def make_deepep_a2a(pg: ProcessGroup, - pgi: ProcessGroupInfo, - dp_size: int, - deepep_ht_args: Optional[DeepEPHTArgs], - deepep_ll_args: Optional[DeepEPLLArgs], - q_dtype: Optional[torch.dtype] = None, - block_shape: Optional[list[int]] = None): +def make_deepep_a2a( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + deepep_ht_args: Optional[DeepEPHTArgs], + deepep_ll_args: Optional[DeepEPLLArgs], + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, +): if deepep_ht_args is not None: assert deepep_ll_args is None - return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype, - block_shape) + return make_deepep_ht_a2a( + pg, pgi, dp_size, deepep_ht_args, q_dtype, block_shape + ) assert deepep_ll_args is not None return make_deepep_ll_a2a(pg, pgi, deepep_ll_args, q_dtype, block_shape) diff --git a/tests/kernels/moe/test_batched_deepgemm.py b/tests/kernels/moe/test_batched_deepgemm.py index afec97e8cffd..59cecd60d3d6 100644 --- a/tests/kernels/moe/test_batched_deepgemm.py +++ b/tests/kernels/moe/test_batched_deepgemm.py @@ -5,13 +5,14 @@ import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) -from vllm.model_executor.layers.fused_moe.config import ( - fp8_w8a8_moe_quant_config) + BatchedDeepGemmExperts, +) +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize, BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) + BatchedPrepareAndFinalize, + BatchedTritonExperts, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported from .test_deepgemm import make_block_quant_fp8_weights @@ -19,15 +20,15 @@ BLOCK_SIZE = [128, 128] -@pytest.mark.skipif(not is_deep_gemm_supported(), - reason="Requires deep_gemm kernels") +@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") @pytest.mark.parametrize("E", [16, 32]) # number of experts @pytest.mark.parametrize("T", [256, 512]) # tokens per expert @pytest.mark.parametrize("K", [128, 256]) # hidden dim @pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert @pytest.mark.parametrize("topk", [2, 4]) -def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int, - monkeypatch): +def test_batched_deepgemm_vs_triton( + E: int, T: int, K: int, N: int, topk: int, monkeypatch +): """Compare BatchedDeepGemmExperts to BatchedTritonExperts.""" monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1") diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 7e79828937c7..09cede3fbcc7 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -7,14 +7,18 @@ import pytest import torch -from tests.kernels.moe.utils import (batched_moe, - make_quantized_test_activations, - make_test_weights, naive_batched_moe) +from 
tests.kernels.moe.utils import ( + batched_moe, + make_quantized_test_activations, + make_test_weights, + naive_batched_moe, +) from tests.kernels.quant_utils import native_batched_masked_quant_matmul from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - invoke_moe_batched_triton_kernel) + invoke_moe_batched_triton_kernel, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform from vllm.triton_utils import tl @@ -68,23 +72,32 @@ class BatchedMMTensors: @staticmethod def make_tensors(config: BatchedMMConfig): - A = torch.randn( - (config.num_experts, config.max_tokens_per_expert, config.K), + A = ( + torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.K), + device="cuda", + dtype=config.in_dtype, + ) + / 10 + ) + B = torch.randn( + (config.num_experts, config.N, config.K), device="cuda", - dtype=config.in_dtype) / 10 - B = torch.randn((config.num_experts, config.N, config.K), - device="cuda", - dtype=config.in_dtype) + dtype=config.in_dtype, + ) C = torch.zeros( (config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", - dtype=config.out_dtype) + dtype=config.out_dtype, + ) - num_expert_tokens = torch.randint(low=0, - high=config.max_tokens_per_expert, - size=(config.num_experts, ), - device="cuda", - dtype=torch.int32) + num_expert_tokens = torch.randint( + low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts,), + device="cuda", + dtype=torch.int32, + ) return BatchedMMTensors(A, B, C, num_expert_tokens) @@ -96,10 +109,15 @@ def make_tensors(config: BatchedMMConfig): @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) @pytest.mark.parametrize("block_shape", [None, [128, 128]]) @pytest.mark.parametrize("per_act_token_quant", [False, True]) -def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, - N: int, dtype: torch.dtype, - block_shape: Optional[list[int]], - per_act_token_quant: bool): +def test_batched_mm( + num_experts: int, + max_tokens_per_expert: int, + K: int, + N: int, + dtype: torch.dtype, + block_shape: Optional[list[int]], + per_act_token_quant: bool, +): current_platform.seed_everything(7) use_fp8_w8a8 = dtype == torch.float8_e4m3fn @@ -117,11 +135,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, act_dtype = dtype quant_dtype = None - num_expert_tokens = torch.randint(low=0, - high=max_tokens_per_expert, - size=(num_experts, ), - device="cuda", - dtype=torch.int32) + num_expert_tokens = torch.randint( + low=0, + high=max_tokens_per_expert, + size=(num_experts,), + device="cuda", + dtype=torch.int32, + ) A, A_q, A_scale = make_quantized_test_activations( num_experts, @@ -151,7 +171,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, compute_tl_dtype = { torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, - torch.float32: tl.float32 + torch.float32: tl.float32, }[test_output.dtype] assert A_q.dtype == B_q.dtype @@ -173,7 +193,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, config={ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32 + "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32, }, per_act_token_quant=per_act_token_quant, block_shape=block_shape, @@ -186,11 +206,16 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, num_expert_tokens, ) - 
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output, - num_expert_tokens, - A_scale, B_scale, - block_shape, - per_act_token_quant) + q_ref_output = native_batched_masked_quant_matmul( + A_q, + B_q, + q_ref_output, + num_expert_tokens, + A_scale, + B_scale, + block_shape, + per_act_token_quant, + ) rtol, atol = { torch.float16: (6e-2, 6e-2), @@ -308,12 +333,6 @@ def test_fused_moe_batched_experts( block_shape=block_shape, ) - torch.testing.assert_close(batched_output, - baseline_output, - atol=3e-2, - rtol=2e-2) + torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2) - torch.testing.assert_close(triton_output, - batched_output, - atol=2e-2, - rtol=2e-2) + torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index da383e18c372..b8cd3cb9200c 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -5,15 +5,21 @@ import torch from tests.kernels.moe.utils import make_test_quant_config, make_test_weights -from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, - native_w8a8_block_matmul) +from tests.kernels.quant_utils import ( + native_per_token_group_quant_fp8, + native_w8a8_block_matmul, +) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm_shape, deep_gemm_moe_fp8) + _valid_deep_gemm_shape, + deep_gemm_moe_fp8, +) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe) + fused_topk, + modular_triton_fused_moe, +) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used @@ -24,8 +30,7 @@ from deep_gemm import get_m_alignment_for_contiguous_layout if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -97,8 +102,7 @@ SEEDS = [0] -def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, - block_shape): +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape): """Fused moe with block-wise quantization using native torch.""" B, D = a.shape topk = topk_ids.size(1) @@ -114,23 +118,17 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) + inter_out = native_w8a8_block_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype + ) act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_fp8( - act_out, block_k) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k) + out[mask] = native_w8a8_block_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], block_shape, 
output_dtype=a.dtype + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) # Skip all tests if CUDA is not available @@ -149,8 +147,9 @@ def setup_cuda(): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, - monkeypatch): +def test_w8a8_block_fp8_fused_moe( + M, N, K, E, topk, block_size, dtype, seed, monkeypatch +): if topk > E: pytest.skip(f"Skipping test; topk={topk} > E={E}") @@ -188,12 +187,9 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, block_size, ) - out = fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - quant_config=quant_config) + out = fused_experts( + a, w1, w2, topk_weights, topk_ids, quant_config=quant_config + ) m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids) @@ -210,8 +206,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() -def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, - monkeypatch): +def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): if topk > E: pytest.skip(f"Skipping test: topk={topk} > E={E}") @@ -245,36 +240,38 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, # setup code in case we are able to revisit this later. use_compile = False - use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024 - and current_platform.is_cuda_alike()) + use_cudagraph = ( + chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike() + ) topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) # Set the context to avoid lots of warning spam. 
with set_current_vllm_config(vllm_config): - ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids, block_size) + ref_out = torch_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size + ) if use_compile: - deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8, - backend="inductor", - fullgraph=True) + deep_gemm_moe_fp8_fn = torch.compile( + deep_gemm_moe_fp8, backend="inductor", fullgraph=True + ) torch._dynamo.mark_dynamic(a, 0) torch._dynamo.mark_dynamic(topk_weights, 0) torch._dynamo.mark_dynamic(topk_ids, 0) else: deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8 - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) if use_cudagraph: out.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, - topk_ids) + out = deep_gemm_moe_fp8_fn( + a, w1, w2, w1_s, w2_s, topk_weights, topk_ids + ) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() diff --git a/tests/kernels/moe/test_block_int8.py b/tests/kernels/moe/test_block_int8.py index 041a13ca5585..74cc943714dd 100644 --- a/tests/kernels/moe/test_block_int8.py +++ b/tests/kernels/moe/test_block_int8.py @@ -5,16 +5,17 @@ import torch from tests.kernels.moe.utils import make_test_quant_config -from tests.kernels.quant_utils import (native_per_token_group_quant_int8, - native_w8a8_block_matmul) +from tests.kernels.quant_utils import ( + native_per_token_group_quant_int8, + native_w8a8_block_matmul, +) from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): - pytest.skip("INT8 Triton requires CUDA 7.0 or higher", - allow_module_level=True) + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -77,24 +78,18 @@ def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - inter_out = native_w8a8_block_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - block_shape, - output_dtype=a.dtype) + inter_out = native_w8a8_block_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype + ) act_out = SiluAndMul().forward_native(inter_out) - act_out_q, act_out_s = native_per_token_group_quant_int8( - act_out, block_k) + act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k) act_out = act_out.to(torch.float32) - out[mask] = native_w8a8_block_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - block_shape, - output_dtype=a.dtype) - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + out[mask] = native_w8a8_block_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) @pytest.fixture(autouse=True, scope="module") @@ -131,15 +126,19 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): # Set the context to avoid lots of warning spam. 
with set_current_vllm_config(vllm_config): - out = fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - quant_config=quant_config) - ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale, - quant_config.w2_scale, score, topk, - block_size) + out = fused_experts( + a, w1, w2, topk_weights, topk_ids, quant_config=quant_config + ) + ref_out = torch_w8a8_block_int8_moe( + a, + w1, + w2, + quant_config.w1_scale, + quant_config.w2_scale, + score, + topk, + block_size, + ) # Check results torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065) diff --git a/tests/kernels/moe/test_count_expert_num_tokens.py b/tests/kernels/moe/test_count_expert_num_tokens.py index 1768baaf1ca7..996a4538d105 100644 --- a/tests/kernels/moe/test_count_expert_num_tokens.py +++ b/tests/kernels/moe/test_count_expert_num_tokens.py @@ -15,7 +15,6 @@ @dataclasses.dataclass class TestTensors: - topk_ids: torch.Tensor expert_map: Optional[torch.Tensor] = None @@ -25,32 +24,31 @@ def to_device(self, device: str): self.expert_map = self.expert_map.to(device=device) @staticmethod - def make(num_tokens: int, num_topk: int, num_experts: int, device: str, - topk_ids_dtype: torch.dtype) -> "TestTensors": - + def make( + num_tokens: int, + num_topk: int, + num_experts: int, + device: str, + topk_ids_dtype: torch.dtype, + ) -> "TestTensors": # make topk ids - topk_ids = torch.empty((num_tokens, num_topk), - device=device, - dtype=torch.int64) + topk_ids = torch.empty((num_tokens, num_topk), device=device, dtype=torch.int64) for x in range(num_tokens): topk_ids[x] = torch.randperm(num_experts)[:num_topk] topk_ids = topk_ids.to(dtype=torch.int64) return TestTensors(topk_ids=topk_ids) - def with_ep_rank(self, ep_rank: int, num_global_experts: int, - num_local_experts: int, device: str): + def with_ep_rank( + self, ep_rank: int, num_global_experts: int, num_local_experts: int, device: str + ): # make an expert map - expert_map = torch.empty((num_global_experts), - device=device, - dtype=torch.int32) + expert_map = torch.empty((num_global_experts), device=device, dtype=torch.int32) expert_map.fill_(-1) s = ep_rank * num_local_experts e = s + num_local_experts - expert_map[s:e] = torch.tensor(list(range(num_local_experts)), - device=device) + expert_map[s:e] = torch.tensor(list(range(num_local_experts)), device=device) - return TestTensors(topk_ids=self.topk_ids.clone(), - expert_map=expert_map) + return TestTensors(topk_ids=self.topk_ids.clone(), expert_map=expert_map) def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor): @@ -68,49 +66,49 @@ def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor): expert_num_tokens[eid] += count -def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int, - num_experts: int, ep_size: int, - topk_ids_dtype: torch.dtype): - +def do_test_compute_expert_num_tokens( + num_tokens: int, + num_topk: int, + num_experts: int, + ep_size: int, + topk_ids_dtype: torch.dtype, +): assert num_topk <= num_experts - tt = TestTensors.make(num_tokens, - num_topk, - num_experts, - topk_ids_dtype=topk_ids_dtype, - device="cpu") + tt = TestTensors.make( + num_tokens, num_topk, num_experts, topk_ids_dtype=topk_ids_dtype, device="cpu" + ) num_global_experts = num_experts assert num_global_experts % ep_size == 0 num_local_experts = num_global_experts // ep_size for ep_rank in range(ep_size): - tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, - num_local_experts, "cpu") + tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, num_local_experts, "cpu") - 
ref_expert_num_tokens = torch.zeros((num_local_experts), - device="cpu", - dtype=torch.int32) + ref_expert_num_tokens = torch.zeros( + (num_local_experts), device="cpu", dtype=torch.int32 + ) ref_impl(tt_rank, ref_expert_num_tokens) ref_expert_num_tokens = ref_expert_num_tokens.to("cuda") tt_rank.to_device("cuda") # Test with expert_map triton_expert_num_tokens_w_emap = count_expert_num_tokens( - tt_rank.topk_ids, num_local_experts, tt_rank.expert_map) + tt_rank.topk_ids, num_local_experts, tt_rank.expert_map + ) # Test without expert map topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype) triton_expert_num_tokens_wo_emap = count_expert_num_tokens( - topk_ids, num_local_experts, expert_map=None) + topk_ids, num_local_experts, expert_map=None + ) - torch.testing.assert_close(ref_expert_num_tokens, - triton_expert_num_tokens_w_emap, - atol=0, - rtol=0) - torch.testing.assert_close(ref_expert_num_tokens, - triton_expert_num_tokens_wo_emap, - atol=0, - rtol=0) + torch.testing.assert_close( + ref_expert_num_tokens, triton_expert_num_tokens_w_emap, atol=0, rtol=0 + ) + torch.testing.assert_close( + ref_expert_num_tokens, triton_expert_num_tokens_wo_emap, atol=0, rtol=0 + ) @pytest.mark.parametrize("num_tokens", [1, 4, 8, 11, 127, 128, 3333, 7317]) @@ -118,22 +116,29 @@ def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int, @pytest.mark.parametrize("num_experts", [64]) @pytest.mark.parametrize("ep_size", [1, 2, 4]) @pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) -def test_compute_expert_num_tokens(num_tokens: int, num_topk: int, - num_experts: int, ep_size: int, - topk_ids_dtype: torch.dtype): - do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts, - ep_size, topk_ids_dtype) +def test_compute_expert_num_tokens( + num_tokens: int, + num_topk: int, + num_experts: int, + ep_size: int, + topk_ids_dtype: torch.dtype, +): + do_test_compute_expert_num_tokens( + num_tokens, num_topk, num_experts, ep_size, topk_ids_dtype + ) @pytest.mark.parametrize("numel", list(range(1, 8192, 111))) @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("ep_size", [2]) @pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) -def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int, - ep_size: int, - topk_ids_dtype: torch.dtype): - do_test_compute_expert_num_tokens(num_tokens=numel, - num_topk=1, - num_experts=num_experts, - ep_size=ep_size, - topk_ids_dtype=topk_ids_dtype) +def test_compute_expert_num_tokens_from_numel( + numel: int, num_experts: int, ep_size: int, topk_ids_dtype: torch.dtype +): + do_test_compute_expert_num_tokens( + num_tokens=numel, + num_topk=1, + num_experts=num_experts, + ep_size=ep_size, + topk_ids_dtype=topk_ids_dtype, + ) diff --git a/tests/kernels/moe/test_cutlass_grouped_gemm.py b/tests/kernels/moe/test_cutlass_grouped_gemm.py index 3b1618dacac7..4c60241bdb01 100644 --- a/tests/kernels/moe/test_cutlass_grouped_gemm.py +++ b/tests/kernels/moe/test_cutlass_grouped_gemm.py @@ -17,19 +17,24 @@ from vllm.utils.deep_gemm import per_block_cast_to_fp8 -@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ - (4, 8192, 7168, 4096), - (4, 8192, 2048, 7168), - (8, 4096, 7168, 4096), - (8, 4096, 2048, 7168), - (32, 1024, 7168, 4096), - (32, 1024, 2048, 7168), -]) +@pytest.mark.parametrize( + "num_groups, expected_m_per_group, k, n", + [ + (4, 8192, 7168, 4096), + (4, 8192, 2048, 7168), + (8, 4096, 7168, 4096), + (8, 4096, 2048, 7168), + (32, 1024, 7168, 4096), + (32, 1024, 2048, 7168), + ], +) 
@pytest.mark.parametrize("out_dtype", [torch.float16]) @pytest.mark.skipif( (lambda x: x is None or x.to_int() != 100)( - current_platform.get_device_capability()), - reason="Block Scaled Grouped GEMM is only supported on SM100.") + current_platform.get_device_capability() + ), + reason="Block Scaled Grouped GEMM is only supported on SM100.", +) def test_cutlass_grouped_gemm( num_groups: int, expected_m_per_group: int, @@ -40,8 +45,7 @@ def test_cutlass_grouped_gemm( device = "cuda" alignment = 128 group_ms = [ - int(expected_m_per_group * random.uniform(0.7, 1.3)) - for _ in range(num_groups) + int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups) ] m = sum([cdiv(m, alignment) * alignment for m in group_ms]) @@ -58,20 +62,22 @@ def test_cutlass_grouped_gemm( expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32) x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), - torch.empty((num_groups, cdiv(n, 128), k // 128), - device=device, - dtype=torch.float)) + y_fp8 = ( + torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty( + (num_groups, cdiv(n, 128), k // 128), device=device, dtype=torch.float + ), + ) for i in range(num_groups): y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128]) for i in range(num_groups): - a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] - a_scale = x_fp8[1][ep_offset[i]:ep_offset[i + 1]] + a = x_fp8[0][ep_offset[i] : ep_offset[i + 1]] + a_scale = x_fp8[1][ep_offset[i] : ep_offset[i + 1]] b = y_fp8[0][i].t() b_scale = y_fp8[1][i].t() baseline = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype) - ref_out[ep_offset[i]:ep_offset[i + 1]] = baseline + ref_out[ep_offset[i] : ep_offset[i + 1]] = baseline ops.cutlass_blockwise_scaled_grouped_mm( out, diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index ca6be767dab3..b82cea61bd4e 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -11,13 +11,15 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config) + FUSED_MOE_UNQUANTIZED_CONFIG, + fp8_w8a8_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8, run_cutlass_moe_fp8) -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, - fused_topk) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + cutlass_moe_fp8, + run_cutlass_moe_fp8, +) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.platforms import current_platform NUM_EXPERTS = [40, 64] @@ -39,12 +41,11 @@ (224, 3072, 1536), (32768, 1024, 1024), # These sizes trigger wrong answers. 
- #(7232, 2048, 5120), - #(40000, 2048, 5120), + # (7232, 2048, 5120), + # (40000, 2048, 5120), ] -vllm_config = VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1)) +vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 @@ -60,22 +61,25 @@ class MOETensors: c_strides2: torch.Tensor @staticmethod - def make_moe_tensors(m: int, k: int, n: int, e: int, - dtype: torch.dtype) -> "MOETensors": + def make_moe_tensors( + m: int, k: int, n: int, e: int, dtype: torch.dtype + ) -> "MOETensors": a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - return MOETensors(a=a, - w1=w1, - w2=w2, - ab_strides1=ab_strides1, - c_strides1=c_strides1, - ab_strides2=ab_strides2, - c_strides2=c_strides2) + ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) + return MOETensors( + a=a, + w1=w1, + w2=w2, + ab_strides1=ab_strides1, + c_strides1=c_strides1, + ab_strides2=ab_strides2, + c_strides2=c_strides2, + ) @dataclasses.dataclass @@ -93,9 +97,9 @@ class MOETensors8Bit(MOETensors): w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d @staticmethod - def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, - per_act_token: bool, - per_out_channel: bool) -> "MOETensors8Bit": + def make_moe_tensors_8bit( + m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool + ) -> "MOETensors8Bit": dtype = torch.half q_dtype = torch.float8_e4m3fn @@ -106,24 +110,21 @@ def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, k_b_scales = k if per_out_channel else 1 # Get the right scale for tests. 
a_q, a_scale = ops.scaled_fp8_quant( - moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token) + moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token + ) w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - moe_tensors_fp16.w1[expert], - use_per_token_if_dynamic=per_out_channel) + moe_tensors_fp16.w1[expert], use_per_token_if_dynamic=per_out_channel + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - moe_tensors_fp16.w2[expert], - use_per_token_if_dynamic=per_out_channel) + moe_tensors_fp16.w2[expert], use_per_token_if_dynamic=per_out_channel + ) # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d a_d = a_q.float().mul(a_scale).to(dtype) @@ -133,31 +134,37 @@ def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half() w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half() - return MOETensors8Bit(a=moe_tensors_fp16.a, - w1=moe_tensors_fp16.w1, - w2=moe_tensors_fp16.w2, - ab_strides1=moe_tensors_fp16.ab_strides1, - c_strides1=moe_tensors_fp16.c_strides1, - ab_strides2=moe_tensors_fp16.ab_strides2, - c_strides2=moe_tensors_fp16.c_strides2, - a_q=a_q, - w1_q=w1_q, - w2_q=w2_q, - a_scale=a_scale, - w1_scale=w1_scale, - w2_scale=w2_scale, - a_d=a_d, - w1_d=w1_d, - w2_d=w2_d) - - -def run_with_expert_maps(num_experts: int, num_local_experts: int, - **cutlass_moe_kwargs): - + return MOETensors8Bit( + a=moe_tensors_fp16.a, + w1=moe_tensors_fp16.w1, + w2=moe_tensors_fp16.w2, + ab_strides1=moe_tensors_fp16.ab_strides1, + c_strides1=moe_tensors_fp16.c_strides1, + ab_strides2=moe_tensors_fp16.ab_strides2, + c_strides2=moe_tensors_fp16.c_strides2, + a_q=a_q, + w1_q=w1_q, + w2_q=w2_q, + a_scale=a_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + a_d=a_d, + w1_d=w1_d, + w2_d=w2_d, + ) + + +def run_with_expert_maps( + num_experts: int, num_local_experts: int, **cutlass_moe_kwargs +): def slice_experts(): slice_params = [ - "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1", - "c_strides2" + "w1_q", + "w2_q", + "ab_strides1", + "ab_strides2", + "c_strides1", + "c_strides2", ] full_tensors = { k: v @@ -173,9 +180,7 @@ def slice_experts(): # make expert map expert_map = [-1] * num_experts expert_map[s:e] = list(range(num_local_experts)) - expert_map = torch.tensor(expert_map, - dtype=torch.int32, - device="cuda") + expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") # update cutlass moe arg with expert_map cutlass_moe_kwargs["expert_map"] = expert_map @@ -198,18 +203,26 @@ def slice_experts(): return out_tensor -def run_8_bit(moe_tensors: MOETensors8Bit, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - per_act_token: bool, - per_out_ch: bool, - num_local_experts: Optional[int] = None) -> torch.Tensor: - assert not any([ - t is None for t in [ - moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale, - moe_tensors.w2_scale, moe_tensors.a_scale +def run_8_bit( + moe_tensors: MOETensors8Bit, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + per_act_token: bool, + per_out_ch: bool, + num_local_experts: 
Optional[int] = None, +) -> torch.Tensor: + assert not any( + [ + t is None + for t in [ + moe_tensors.w1_q, + moe_tensors.w2_q, + moe_tensors.w1_scale, + moe_tensors.w2_scale, + moe_tensors.a_scale, + ] ] - ]) + ) quant_config = fp8_w8a8_moe_quant_config( w1_scale=moe_tensors.w1_scale, @@ -222,16 +235,16 @@ def run_8_bit(moe_tensors: MOETensors8Bit, ) kwargs = { - 'a': moe_tensors.a, - 'w1_q': moe_tensors.w1_q, # type: ignore[union-attr] - 'w2_q': moe_tensors.w2_q, # type: ignore[union-attr] - 'topk_weights': topk_weights, - 'topk_ids': topk_ids, - 'ab_strides1': moe_tensors.ab_strides1, - 'ab_strides2': moe_tensors.ab_strides2, - 'c_strides1': moe_tensors.c_strides1, - 'c_strides2': moe_tensors.c_strides2, - 'quant_config': quant_config, + "a": moe_tensors.a, + "w1_q": moe_tensors.w1_q, # type: ignore[union-attr] + "w2_q": moe_tensors.w2_q, # type: ignore[union-attr] + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "ab_strides1": moe_tensors.ab_strides1, + "ab_strides2": moe_tensors.ab_strides2, + "c_strides1": moe_tensors.c_strides1, + "c_strides2": moe_tensors.c_strides2, + "quant_config": quant_config, } num_experts = moe_tensors.w1.size(0) @@ -243,7 +256,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit, return run_with_expert_maps( num_experts, num_local_experts, # type: ignore[arg-type] - **kwargs) + **kwargs, + ) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @@ -253,8 +267,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_no_graph( m: int, n: int, @@ -269,25 +285,18 @@ def test_cutlass_moe_8_bit_no_graph( current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): - mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_ch) + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids, _ = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - triton_output = fused_experts(mt.a_d, - mt.w1_d, - mt.w2_d, - topk_weights, - topk_ids, - quant_config=quant_config) + triton_output = fused_experts( + mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config + ) if ep_size is not None: assert e % ep_size == 0, "Cannot distribute experts evenly" @@ -295,15 +304,15 @@ def test_cutlass_moe_8_bit_no_graph( else: number_local_experts = None - cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token, - per_out_ch, number_local_experts) + cutlass_output = run_8_bit( + mt, topk_weights, topk_ids, per_act_token, per_out_ch, number_local_experts + ) # Note 5.5 only needed for larger problem sizes, 5 works ok for # the rest. 
- torch.testing.assert_close(triton_output, - cutlass_output, - atol=5.5e-2, - rtol=1e-2) + torch.testing.assert_close( + triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2 + ) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) @@ -313,8 +322,10 @@ def test_cutlass_moe_8_bit_no_graph( @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_cuda_graph( m: int, n: int, @@ -330,39 +341,30 @@ def test_cutlass_moe_8_bit_cuda_graph( with set_current_vllm_config(vllm_config): dtype = torch.half - mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_ch) + mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - triton_output = fused_experts(mt.a_d, - mt.w1_d, - mt.w2_d, - topk_weights, - topk_ids, - quant_config=quant_config) + triton_output = fused_experts( + mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids, quant_config=quant_config + ) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - cutlass_output = run_8_bit(mt, topk_weights, topk_ids, - per_act_token, per_out_ch) + cutlass_output = run_8_bit( + mt, topk_weights, topk_ids, per_act_token, per_out_ch + ) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() - torch.testing.assert_close(triton_output, - cutlass_output, - atol=9e-2, - rtol=1e-2) + torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2) @pytest.mark.parametrize("m", [64]) @@ -375,8 +377,10 @@ def test_cutlass_moe_8_bit_cuda_graph( @pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_EP( m: int, n: int, @@ -388,8 +392,9 @@ def test_cutlass_moe_8_bit_EP( ep_size: int, monkeypatch, ): - test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token, - per_out_channel, monkeypatch, ep_size) + test_cutlass_moe_8_bit_no_graph( + m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size + ) LARGE_MNK_FACTORS = [ @@ -406,8 +411,10 @@ def test_cutlass_moe_8_bit_EP( @pytest.mark.parametrize("ep_size", [8]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_cutlass_moe_8_bit_EP_large( m: int, n: int, @@ -419,8 +426,9 @@ def test_cutlass_moe_8_bit_EP_large( ep_size: int, monkeypatch, ): - 
test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token, - per_out_channel, monkeypatch, ep_size) + test_cutlass_moe_8_bit_no_graph( + m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size + ) @pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)]) @@ -430,8 +438,10 @@ def test_cutlass_moe_8_bit_EP_large( @pytest.mark.parametrize("ep_size", [8]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) def test_run_cutlass_moe_fp8( m: int, n: int, @@ -444,14 +454,12 @@ def test_run_cutlass_moe_fp8( ): current_platform.seed_everything(7) with set_current_vllm_config(vllm_config): - mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, - per_out_channel) + mt = MOETensors8Bit.make_moe_tensors_8bit( + m, k, n, e, per_act_token, per_out_channel + ) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids, _ = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, score, topk, renormalize=False) # we want to make sure there is at least one token that's generated in # this expert shard and at least one token that's NOT generated in this # expert shard @@ -462,12 +470,12 @@ def test_run_cutlass_moe_fp8( workspace2_shape = (m * topk, max(n, k)) output_shape = (m, k) - workspace13 = torch.empty(prod(workspace13_shape), - device="cuda", - dtype=mt.a.dtype) - workspace2 = torch.empty(prod(workspace2_shape), - device="cuda", - dtype=mt.a.dtype) + workspace13 = torch.empty( + prod(workspace13_shape), device="cuda", dtype=mt.a.dtype + ) + workspace2 = torch.empty( + prod(workspace2_shape), device="cuda", dtype=mt.a.dtype + ) num_local_experts = e // ep_size start, end = 0, num_local_experts @@ -475,36 +483,55 @@ def test_run_cutlass_moe_fp8( expert_map[start:end] = list(range(num_local_experts)) expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") - ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64) + ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64) activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) - a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, - torch.float8_e4m3fn, - per_act_token) + a1q, a1q_scale = moe_kernel_quantize_input( + mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token + ) global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0) func = lambda output: run_cutlass_moe_fp8( - output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, - global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, - a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2, - workspace13, workspace2, None, mt.a.dtype, per_act_token, - per_out_channel, False, topk_weights) + output, + a1q, + mt.w1_q, + mt.w2_q, + topk_ids, + activation, + global_num_experts, + expert_map, + mt.w1_scale, + mt.w2_scale, + a1q_scale, + None, + ab_strides1, + 
ab_strides2,
+            c_strides1,
+            c_strides2,
+            workspace13,
+            workspace2,
+            None,
+            mt.a.dtype,
+            per_act_token,
+            per_out_channel,
+            False,
+            topk_weights,
+        )
 
         workspace13.random_()
-        output_random_workspace = torch.empty(output_shape,
-                                              device="cuda",
-                                              dtype=mt.a.dtype)
+        output_random_workspace = torch.empty(
+            output_shape, device="cuda", dtype=mt.a.dtype
+        )
         func(output_random_workspace)
 
         workspace13.fill_(0)
-        output_zero_workspace = torch.zeros(output_shape,
-                                            device="cuda",
-                                            dtype=mt.a.dtype)
+        output_zero_workspace = torch.zeros(
+            output_shape, device="cuda", dtype=mt.a.dtype
+        )
         func(output_zero_workspace)
 
-        torch.testing.assert_close(output_random_workspace,
-                                   output_zero_workspace,
-                                   atol=5e-3,
-                                   rtol=1e-3)
+        torch.testing.assert_close(
+            output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3
+        )
diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py
index ced5457d4f53..e68c5bfa5946 100644
--- a/tests/kernels/moe/test_deepep_deepgemm_moe.py
+++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -16,10 +16,11 @@
 
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
+    FusedMoEQuantConfig,
+    fp8_w8a8_moe_quant_config,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEModularKernel)
+from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_ep, has_deep_gemm
 from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
@@ -29,19 +30,20 @@
 from .utils import make_test_weights
 
 if has_deep_ep():
-    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
-        DeepEPHTPrepareAndFinalize)
-    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
-        DeepEPLLPrepareAndFinalize)
+    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
+        DeepEPHTPrepareAndFinalize,
+    )
+    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
+        DeepEPLLPrepareAndFinalize,
+    )
 
     from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
 
 if has_deep_gemm():
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
-        BatchedDeepGemmExperts)
-    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
-        DeepGemmExperts)
+        BatchedDeepGemmExperts,
+    )
+    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
 
 requires_deep_ep = pytest.mark.skipif(
     not has_deep_ep(),
@@ -58,9 +60,10 @@
 
 def next_power_of_2(x):
     import math
+
     if x == 0:
         return 1
-    return 2**math.ceil(math.log2(x))
+    return 2 ** math.ceil(math.log2(x))
 
 
 def make_block_quant_fp8_weights(
@@ -72,13 +75,9 @@
     """
     Return weights w1q, w2q, w1_scale, w2_scale
     """
-    (_, w1q, w1_scale, _), (_, w2q, w2_scale,
-                            _) = make_test_weights(e,
-                                                   n,
-                                                   k,
-                                                   torch.bfloat16,
-                                                   torch.float8_e4m3fn,
-                                                   block_shape=block_size)
+    (_, w1q, w1_scale, _), (_, w2q, w2_scale, _) = make_test_weights(
+        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
+    )
     return w1q, w2q, w1_scale, w2_scale
 
 
@@ -106,15 +105,15 @@ class TestTensors:
 
     @staticmethod
     def make(config: TestConfig, rank) -> "TestTensors":
-
         dtype = torch.bfloat16
         topk, m, k = (config.topk, config.m, config.k)
fp8_info = torch.finfo(torch.float8_e4m3fn) fp8_max, fp8_min = fp8_info.max, fp8_info.min - rank_tokens = torch.randn( - (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 + rank_tokens = ( + torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 + ) rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max) rank_token_scales = None @@ -122,25 +121,32 @@ def make(config: TestConfig, rank) -> "TestTensors": low=0, high=config.num_experts, size=(m, topk), - device=torch.cuda.current_device()).to(dtype=torch.int64) + device=torch.cuda.current_device(), + ).to(dtype=torch.int64) - topk_weights = torch.randn(topk_ids.shape, - dtype=torch.float32, - device=torch.cuda.current_device()) + topk_weights = torch.randn( + topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device() + ) - return TestTensors(rank_tokens=rank_tokens, - rank_token_scales=rank_token_scales, - topk=topk_ids, - topk_weights=topk_weights, - config=config) + return TestTensors( + rank_tokens=rank_tokens, + rank_token_scales=rank_token_scales, + topk=topk_ids, + topk_weights=topk_weights, + config=config, + ) def make_ll_modular_kernel( - pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int, - dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype], - test_config: TestConfig, - quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel: - + pg: ProcessGroup, + pgi: ProcessGroupInfo, + max_tokens_per_rank: int, + dp_size: int, + hidden_size: int, + q_dtype: Optional[torch.dtype], + test_config: TestConfig, + quant_config: FusedMoEQuantConfig, +) -> FusedMoEModularKernel: assert test_config.low_latency assert test_config.use_fp8_dispatch is not None @@ -153,26 +159,30 @@ def make_ll_modular_kernel( max_tokens_per_rank=max_tokens_per_rank, hidden_size=hidden_size, num_experts=test_config.num_experts, - use_fp8_dispatch=test_config.use_fp8_dispatch), + use_fp8_dispatch=test_config.use_fp8_dispatch, + ), q_dtype=q_dtype, - block_shape=test_config.block_size) + block_shape=test_config.block_size, + ) fused_experts = BatchedDeepGemmExperts( max_num_tokens=max_tokens_per_rank, num_dispatchers=pgi.world_size // dp_size, quant_config=quant_config, ) - mk = FusedMoEModularKernel(prepare_finalize=a2a, - fused_experts=fused_experts) + mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk def make_ht_modular_kernel( - pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, - num_local_experts: int, q_dtype: Optional[torch.dtype], - test_config: TestConfig, - quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel: - + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + num_local_experts: int, + q_dtype: Optional[torch.dtype], + test_config: TestConfig, + quant_config: FusedMoEQuantConfig, +) -> FusedMoEModularKernel: assert not test_config.low_latency assert test_config.use_fp8_dispatch is None @@ -183,76 +193,82 @@ def make_ht_modular_kernel( deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts), deepep_ll_args=None, q_dtype=q_dtype, - block_shape=test_config.block_size) + block_shape=test_config.block_size, + ) fused_experts = DeepGemmExperts(quant_config) - mk = FusedMoEModularKernel(prepare_finalize=a2a, - fused_experts=fused_experts) + mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk def make_modular_kernel( - pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, - num_local_experts: int, test_tensors: TestTensors, - quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel: 
- + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + num_local_experts: int, + test_tensors: TestTensors, + quant_config: FusedMoEQuantConfig, +) -> FusedMoEModularKernel: q_dtype = torch.float8_e4m3fn test_config = test_tensors.config mk: FusedMoEModularKernel # Make modular kernel if test_config.low_latency: - max_tokens_per_rank = max( - 64, next_power_of_2(test_tensors.rank_tokens.size(0))) + max_tokens_per_rank = max(64, next_power_of_2(test_tensors.rank_tokens.size(0))) hidden_size = test_tensors.rank_tokens.size(-1) - mk = make_ll_modular_kernel(pg=pg, - pgi=pgi, - max_tokens_per_rank=max_tokens_per_rank, - dp_size=dp_size, - hidden_size=hidden_size, - q_dtype=q_dtype, - test_config=test_config, - quant_config=quant_config) + mk = make_ll_modular_kernel( + pg=pg, + pgi=pgi, + max_tokens_per_rank=max_tokens_per_rank, + dp_size=dp_size, + hidden_size=hidden_size, + q_dtype=q_dtype, + test_config=test_config, + quant_config=quant_config, + ) else: - mk = make_ht_modular_kernel(pg, - pgi, - dp_size, - num_local_experts, - q_dtype, - test_config, - quant_config=quant_config) + mk = make_ht_modular_kernel( + pg, + pgi, + dp_size, + num_local_experts, + q_dtype, + test_config, + quant_config=quant_config, + ) return mk -def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, - dp_size: int, test_tensors: TestTensors, - w1: torch.Tensor, w2: torch.Tensor, - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor]) -> torch.Tensor: - +def deepep_deepgemm_moe_impl( + pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + test_tensors: TestTensors, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], +) -> torch.Tensor: test_config = test_tensors.config num_experts = test_config.num_experts num_local_experts = w1.size(0) def build_expert_map(): num_local_experts = w1.size(0) - expert_map = torch.full((num_experts, ), - fill_value=-1, - dtype=torch.int32) + expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32) s = pgi.rank * num_local_experts e = s + num_local_experts expert_map[s:e] = torch.tensor(list(range(num_local_experts))) - return expert_map.to(device=torch.cuda.current_device(), - dtype=torch.int32) + return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) quant_config = fp8_w8a8_moe_quant_config( w1_scale=w1_scale, w2_scale=w2_scale, # Low-Latency kernels can't dispatch scales. 
- a1_scale=(None if test_config.low_latency else - test_tensors.rank_token_scales), + a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales), block_shape=test_config.block_size, ) @@ -263,26 +279,35 @@ def build_expert_map(): dp_size=dp_size, num_local_experts=num_local_experts, test_tensors=test_tensors, - quant_config=quant_config) - - out = mk.forward(hidden_states=test_tensors.rank_tokens, - w1=w1, - w2=w2, - topk_weights=test_tensors.topk_weights, - topk_ids=test_tensors.topk, - inplace=False, - activation="silu", - global_num_experts=num_experts, - expert_map=build_expert_map(), - apply_router_weight_on_input=False) - return out + quant_config=quant_config, + ) + out = mk.forward( + hidden_states=test_tensors.rank_tokens, + w1=w1, + w2=w2, + topk_weights=test_tensors.topk_weights, + topk_ids=test_tensors.topk, + inplace=False, + activation="silu", + global_num_experts=num_experts, + expert_map=build_expert_map(), + apply_router_weight_on_input=False, + ) + return out -def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, - topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a1_scale: torch.Tensor, block_shape: list[int]): +def triton_impl( + a: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + block_shape: list[int], +): quant_config = fp8_w8a8_moe_quant_config( w1_scale=w1_scale, w2_scale=w2_scale, @@ -300,7 +325,8 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, quant_config=quant_config, # Make sure this is set to False so we # don't end up comparing the same implementation. - allow_deep_gemm=False) + allow_deep_gemm=False, + ) def _test_deepep_deepgemm_moe( @@ -321,22 +347,21 @@ def _test_deepep_deepgemm_moe( pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, pgi.rank) - block_shape = [ - w1.size(1) // w1_scale.size(1), - w1.size(2) // w1_scale.size(2) - ] + block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)] with set_current_vllm_config(VllmConfig()): # Reference - triton_moe = triton_impl(a=test_tensors.rank_tokens, - topk_ids=test_tensors.topk, - topk_weights=test_tensors.topk_weights, - w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=test_tensors.rank_token_scales, - block_shape=block_shape) + triton_moe = triton_impl( + a=test_tensors.rank_tokens, + topk_ids=test_tensors.topk, + topk_weights=test_tensors.topk_weights, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=test_tensors.rank_token_scales, + block_shape=block_shape, + ) # Slice experts for this rank. num_local_experts = config.num_experts // pgi.world_size @@ -390,10 +415,15 @@ def _test_deepep_deepgemm_moe( @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_deep_gemm_e8m0_used(), - reason="Skipping test for Blackwell DeepGEMM") -def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, - topk: int, world_dp_size: tuple[int, int]): +@pytest.mark.skipif( + is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM" +) +def test_ht_deepep_deepgemm_moe( + mnk: tuple[int, int, int], + num_experts: int, + topk: int, + world_dp_size: tuple[int, int], +): """ Tests for High-Throughput DeepEP + DeepGemm integration. 
""" @@ -409,21 +439,32 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, block_size = [block_m, block_m] world_size, dp_size = world_dp_size - config = TestConfig(topk=topk, - m=m, - k=k, - n=n, - num_experts=num_experts, - per_act_token_quant=False, - block_size=block_size, - low_latency=False, - use_fp8_dispatch=None) + config = TestConfig( + topk=topk, + m=m, + k=k, + n=n, + num_experts=num_experts, + per_act_token_quant=False, + block_size=block_size, + low_latency=False, + use_fp8_dispatch=None, + ) w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( - num_experts, n, k, block_size) + num_experts, n, k, block_size + ) - parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1, - w2, w1_scale, w2_scale) + parallel_launch( + world_size, + _test_deepep_deepgemm_moe, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + ) MNKs = [ @@ -448,8 +489,9 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_deep_gemm_e8m0_used(), - reason="Skipping test for Blackwell DeepGEMM") +@pytest.mark.skipif( + is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM" +) def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], num_experts: int, @@ -482,7 +524,16 @@ def test_ll_deepep_deepgemm_moe( ) w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( - num_experts, n, k, block_size) + num_experts, n, k, block_size + ) - parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1, - w2, w1_scale, w2_scale) + parallel_launch( + world_size, + _test_deepep_deepgemm_moe, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + ) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 54d3a62b03fc..a1dabea1f0c7 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -16,12 +16,11 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import TritonExperts from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.platforms import current_platform from vllm.utils import has_deep_ep @@ -29,10 +28,12 @@ from .parallel_utils import ProcessGroupInfo, parallel_launch if has_deep_ep(): - from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 - DeepEPHTPrepareAndFinalize) - from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 - DeepEPLLPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( + DeepEPHTPrepareAndFinalize, + ) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( + DeepEPLLPrepareAndFinalize, + ) from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a @@ -45,7 +46,7 @@ def make_weights( - e, n, k, dtype + e, n, k, dtype ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Return 
weights w1, w2, w1_scale, w2_scale @@ -64,17 +65,15 @@ def make_weights( k_b_scales = k w1_q = torch.empty_like(w1, dtype=dtype) w2_q = torch.empty_like(w2, dtype=dtype) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=True) + w1[expert], use_per_token_if_dynamic=True + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=True) + w2[expert], use_per_token_if_dynamic=True + ) return w1_q, w2_q, w1_scale, w2_scale @@ -100,24 +99,25 @@ class TestTensors: def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors": # TODO (varun) - check that float16 works ? assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn] - token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn - else config.dtype) - rank_tokens = torch.randn( - (config.m, config.k), device="cuda", dtype=token_dtype) / 10 + token_dtype = ( + torch.bfloat16 if config.dtype == torch.float8_e4m3fn else config.dtype + ) + rank_tokens = ( + torch.randn((config.m, config.k), device="cuda", dtype=token_dtype) / 10 + ) rank_token_scales = None - topk = torch.randint(low=0, - high=config.num_experts, - size=(config.m, config.topk), - device="cuda").to(dtype=torch.int64) - topk_weights = torch.randn(topk.shape, - dtype=torch.float32, - device="cuda") - return TestTensors(rank_tokens=rank_tokens, - rank_token_scales=rank_token_scales, - topk=topk, - topk_weights=topk_weights, - config=config) + topk = torch.randint( + low=0, high=config.num_experts, size=(config.m, config.topk), device="cuda" + ).to(dtype=torch.int64) + topk_weights = torch.randn(topk.shape, dtype=torch.float32, device="cuda") + return TestTensors( + rank_tokens=rank_tokens, + rank_token_scales=rank_token_scales, + topk=topk, + topk_weights=topk_weights, + config=config, + ) def make_modular_kernel( @@ -132,28 +132,33 @@ def make_modular_kernel( use_fp8_dispatch: bool, quant_config: FusedMoEQuantConfig, ) -> FusedMoEModularKernel: - ht_args: Optional[DeepEPHTArgs] = None ll_args: Optional[DeepEPLLArgs] = None if low_latency_mode: - ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK, - hidden_size=hidden_size, - num_experts=num_experts, - use_fp8_dispatch=use_fp8_dispatch) + ll_args = DeepEPLLArgs( + max_tokens_per_rank=MAX_TOKENS_PER_RANK, + hidden_size=hidden_size, + num_experts=num_experts, + use_fp8_dispatch=use_fp8_dispatch, + ) else: assert not use_fp8_dispatch, ( - "FP8 Dispatch is valid only for low-latency kernels") + "FP8 Dispatch is valid only for low-latency kernels" + ) ht_args = DeepEPHTArgs(num_local_experts=num_local_experts) - a2a : Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \ - make_deepep_a2a(pg = pg, - pgi = pgi, - dp_size = dp_size, - q_dtype = q_dtype, - block_shape = None, - deepep_ht_args = ht_args, - deepep_ll_args = ll_args) + a2a: Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = ( + make_deepep_a2a( + pg=pg, + pgi=pgi, + dp_size=dp_size, + q_dtype=q_dtype, + block_shape=None, + deepep_ht_args=ht_args, + deepep_ll_args=ll_args, + ) + ) num_dispatchers = pgi.world_size // dp_size @@ -167,8 +172,7 @@ def make_modular_kernel( else: fused_experts = 
TritonExperts(quant_config=quant_config) - mk = FusedMoEModularKernel(prepare_finalize=a2a, - fused_experts=fused_experts) + mk = FusedMoEModularKernel(prepare_finalize=a2a, fused_experts=fused_experts) return mk @@ -186,19 +190,15 @@ def deep_ep_moe_impl( use_fp8_dispatch: bool, per_act_token_quant: bool, ) -> torch.Tensor: - num_local_experts = w1.size(0) def build_expert_map(): num_local_experts = w1.size(0) - expert_map = torch.full((num_experts, ), - fill_value=-1, - dtype=torch.int32) + expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32) s = pgi.rank * num_local_experts e = s + num_local_experts expert_map[s:e] = torch.tensor(list(range(num_local_experts))) - return expert_map.to(device=torch.cuda.current_device(), - dtype=torch.int32) + return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) hidden_size = test_tensors.rank_tokens.size(1) is_quantized = w1.dtype == torch.float8_e4m3fn @@ -214,11 +214,12 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end] topk_chunk = test_tensors.topk[chunk_start:chunk_end] rank_token_scales_chunk = test_tensors.rank_token_scales - if rank_token_scales_chunk is not None and rank_token_scales_chunk.size( - 0) == total_num_tokens: + if ( + rank_token_scales_chunk is not None + and rank_token_scales_chunk.size(0) == total_num_tokens + ): # per act token - rank_token_scales_chunk = rank_token_scales_chunk[ - chunk_start:chunk_end] + rank_token_scales_chunk = rank_token_scales_chunk[chunk_start:chunk_end] quant_config = FusedMoEQuantConfig.make( q_dtype, @@ -230,26 +231,37 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): # Make modular kernel mk: FusedMoEModularKernel = make_modular_kernel( - pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts, - num_local_experts, q_dtype, use_fp8_dispatch, quant_config) - - out = mk.forward(hidden_states=rank_tokens_chunk, - w1=w1, - w2=w2, - topk_weights=topk_weights_chunk, - topk_ids=topk_chunk, - inplace=False, - activation="silu", - global_num_experts=num_experts, - expert_map=build_expert_map(), - apply_router_weight_on_input=False) + pg, + pgi, + low_latency_mode, + hidden_size, + dp_size, + num_experts, + num_local_experts, + q_dtype, + use_fp8_dispatch, + quant_config, + ) + + out = mk.forward( + hidden_states=rank_tokens_chunk, + w1=w1, + w2=w2, + topk_weights=topk_weights_chunk, + topk_ids=topk_chunk, + inplace=False, + activation="silu", + global_num_experts=num_experts, + expert_map=build_expert_map(), + apply_router_weight_on_input=False, + ) if not skip_result_store: - out_hidden_states[chunk_start:chunk_end, :].copy_( - out, non_blocking=True) + out_hidden_states[chunk_start:chunk_end, :].copy_(out, non_blocking=True) - max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK - if low_latency_mode else total_num_tokens) + max_num_tokens_per_dp = ( + MAX_TOKENS_PER_RANK if low_latency_mode else total_num_tokens + ) for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp): chunk_start = chunk_start_ @@ -258,9 +270,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_start = min(chunk_start, total_num_tokens - 1) chunk_end = min(chunk_end, total_num_tokens) - process_chunk(chunk_start, - chunk_end, - skip_result_store=chunk_start_ >= total_num_tokens) + process_chunk( + chunk_start, chunk_end, skip_result_store=chunk_start_ >= total_num_tokens + ) return out_hidden_states @@ -274,9 +286,11 @@ def torch_moe_impl( 
using_fp8_dispatch: bool, per_act_token_quant: bool, ): - - a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk, - test_tensors.topk_weights) + a, topk_ids, topk_weights = ( + test_tensors.rank_tokens, + test_tensors.topk, + test_tensors.topk_weights, + ) if using_fp8_dispatch: # The DeepEP implementation is requested to dispatch using FP8. # For numerical stability for testing, emulate the fp8 dispatch by @@ -284,8 +298,11 @@ def torch_moe_impl( assert not per_act_token_quant a = test_tensors.rank_tokens aq, aq_scale = per_token_group_quant_fp8(a, 128) - a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view( - a.shape).to(a.dtype) + a = ( + (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)) + .view(a.shape) + .to(a.dtype) + ) is_quantized = w1.dtype == torch.float8_e4m3fn a_dtype = a.dtype @@ -306,8 +323,9 @@ def torch_moe_impl( e_w = topk_weights[i][j] w1_e = w1[e] w2_e = w2[e] - o_i += (SiluAndMul() - (a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)) * e_w + o_i += ( + SiluAndMul()(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1) + ) * e_w if is_quantized: out = out.to(dtype=a_dtype) @@ -327,28 +345,36 @@ def _deep_ep_moe( use_fp8_dispatch: bool, per_act_token_quant: bool, ): - if not low_latency_mode: assert not use_fp8_dispatch, ( - "FP8 dispatch interface is available only in low-latency mode") + "FP8 dispatch interface is available only in low-latency mode" + ) is_quantized = w1.dtype == torch.float8_e4m3fn w1 = w1.to(device=torch.cuda.current_device()) w2 = w2.to(device=torch.cuda.current_device()) if is_quantized: w1_scale = w1_scale.to( # type: ignore - device=torch.cuda.current_device()) + device=torch.cuda.current_device() + ) w2_scale = w2_scale.to( # type: ignore - device=torch.cuda.current_device()) + device=torch.cuda.current_device() + ) pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, low_latency_mode) with set_current_vllm_config(VllmConfig()): # Reference - torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale, - w2_scale, use_fp8_dispatch, - per_act_token_quant) + torch_combined = torch_moe_impl( + test_tensors, + w1, + w2, + w1_scale, + w2_scale, + use_fp8_dispatch, + per_act_token_quant, + ) # Splice experts for this rank. 
num_local_experts = config.num_experts // pgi.world_size @@ -420,18 +446,23 @@ def test_deep_ep_moe( current_platform.seed_everything(7) world_size, dp_size = world_dp_size - config = TestConfig(dtype=dtype, - topk=topk, - m=m, - k=k, - n=n, - num_experts=num_experts) + config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) - parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, - config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch, - per_act_token_quant) + parallel_launch( + world_size, + _deep_ep_moe, + low_latency_mode, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + use_fp8_dispatch, + per_act_token_quant, + ) MNKs = [ @@ -467,8 +498,7 @@ def test_low_latency_deep_ep_moe( ): low_latency_mode = True - if (low_latency_mode - and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES): + if low_latency_mode and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES: pytest.skip( f"Skipping test as hidden size {k} is not in list of supported " f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}" @@ -476,15 +506,20 @@ def test_low_latency_deep_ep_moe( current_platform.seed_everything(7) world_size, dp_size = world_dp_size - config = TestConfig(dtype=dtype, - topk=topk, - m=m, - k=k, - n=n, - num_experts=num_experts) + config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts) w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) - parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, - config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch, - False) + parallel_launch( + world_size, + _deep_ep_moe, + low_latency_mode, + dp_size, + config, + w1, + w2, + w1_scale, + w2_scale, + use_fp8_dispatch, + False, + ) diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index d575b6d4ca62..cad0085d5ba6 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -11,14 +11,18 @@ import pytest import torch -from vllm.model_executor.layers.fused_moe.config import ( - fp8_w8a8_moe_quant_config) +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config + # vLLM fused-expert reference (Triton fallback + DeepGEMM option) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) -from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported, - per_block_cast_to_fp8) + per_token_group_quant_fp8, +) +from vllm.utils.deep_gemm import ( + calc_diff, + is_deep_gemm_supported, + per_block_cast_to_fp8, +) BLOCK_SIZE = [128, 128] @@ -37,8 +41,10 @@ def make_block_quant_fp8_weights( w2 shape: (E, K, N) """ dtype = torch.bfloat16 - fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo( - torch.float8_e4m3fn).min + fp8_max, fp8_min = ( + torch.finfo(torch.float8_e4m3fn).max, + torch.finfo(torch.float8_e4m3fn).min, + ) # bf16 reference weights w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10 @@ -54,24 +60,16 @@ def make_block_quant_fp8_weights( w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) - w1_s = torch.empty(e, - n_tiles_w1, - k_tiles_w1, - device="cuda", - dtype=torch.float32) - w2_s = torch.empty(e, - n_tiles_w2, - k_tiles_w2, - device="cuda", - dtype=torch.float32) + w1_s = torch.empty(e, n_tiles_w1, k_tiles_w1, 
device="cuda", dtype=torch.float32) + w2_s = torch.empty(e, n_tiles_w2, k_tiles_w2, device="cuda", dtype=torch.float32) for i in range(e): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i], - block_size=block_size, - use_ue8m0=True) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i], - block_size=block_size, - use_ue8m0=True) + w1[i], w1_s[i] = per_block_cast_to_fp8( + w1_bf16[i], block_size=block_size, use_ue8m0=True + ) + w2[i], w2_s[i] = per_block_cast_to_fp8( + w2_bf16[i], block_size=block_size, use_ue8m0=True + ) return w1, w2, w1_s, w2_s @@ -81,18 +79,17 @@ def run_single_case(m, n, k, topk, num_experts, block_size): Run one (M,N,K) configuration on a single GPU and assert DeepGEMM == Triton baseline within tolerance. """ - tokens_bf16 = torch.randn( - m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1) + tokens_bf16 = ( + torch.randn(m, k, device="cuda", dtype=torch.bfloat16) + .clamp_min_(-1) + .clamp_max_(1) + ) _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1]) # expert weight tensors - w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, - block_size) + w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size) - router_logits = torch.randn(m, - num_experts, - device="cuda", - dtype=torch.float32) + router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32) topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) @@ -147,15 +144,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size): @pytest.mark.parametrize(("m", "n", "k"), MNKs) @pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("num_experts", NUM_EXPERTS) -@pytest.mark.skipif(not is_deep_gemm_supported(), - reason="Requires deep_gemm kernels") +@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels") def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch): - with monkeypatch.context() as mp: mp.setenv("VLLM_USE_DEEP_GEMM", "1") _fused_moe_mod = importlib.import_module( - "vllm.model_executor.layers.fused_moe.fused_moe") + "vllm.model_executor.layers.fused_moe.fused_moe" + ) call_counter = {"cnt": 0} @@ -165,8 +161,7 @@ def _spy_deep_gemm_moe_fp8(*args, **kwargs): call_counter["cnt"] += 1 return orig_fn(*args, **kwargs) - monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", - _spy_deep_gemm_moe_fp8) + monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8", _spy_deep_gemm_moe_fp8) if topk > num_experts: pytest.skip(f"topk={topk} > num_experts={num_experts}") @@ -181,6 +176,7 @@ def _spy_deep_gemm_moe_fp8(*args, **kwargs): ) # ensure that the DeepGEMM path was indeed taken. - assert call_counter["cnt"] == 1, \ - f"DeepGEMM path was not executed during the test. " \ + assert call_counter["cnt"] == 1, ( + f"DeepGEMM path was not executed during the test. 
" f"Call counter: {call_counter['cnt']}" + ) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index c3be7f28fb24..0780232a8264 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -6,24 +6,28 @@ import torch from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.config import ( - fp8_w8a8_moe_quant_config) +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8, - register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, - swap_w13_to_w31) -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - input_to_float8) + apply_flashinfer_per_tensor_scale_fp8, + flashinfer_cutlass_moe_fp8, + register_moe_scaling_factors, + rotate_flashinfer_fp8_moe_weights, + swap_w13_to_w31, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8 from vllm.model_executor.models.llama4 import Llama4MoE from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -if not has_flashinfer_cutlass_fused_moe( -) or not current_platform.has_device_capability(100): - pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", - allow_module_level=True) +if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability( + 100 +): + pytest.skip( + "Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True, + ) NUM_EXPERTS = [16] TOP_KS = [1] @@ -39,8 +43,7 @@ (1, 4096, 5120), ] -vllm_config = VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1)) +vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 @@ -74,18 +77,17 @@ class TestData: layer: torch.nn.Module @staticmethod - def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, - reorder: bool) -> "TestData": - hidden_states = torch.randn( - (m, k), device="cuda", dtype=torch.bfloat16) / 10 + def make_moe_tensors_8bit( + m: int, k: int, n: int, e: int, reorder: bool + ) -> "TestData": + hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) # Scale to fp8 _, a1_scale = input_to_float8(hidden_states) a1_scale = 1.0 / a1_scale - a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to( - dtype=torch.float32) + a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32) w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13) w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) @@ -102,8 +104,7 @@ def make_moe_tensors_8bit(m: int, k: int, n: int, e: int, # flashinfer expects swapped rows for w13 layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if reorder: - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, - layer.w2_weight) + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) layer.custom_routing_function = Llama4MoE.custom_routing_function layer.intermediate_size_per_partition = n layer.ep_rank 
= 0 @@ -145,7 +146,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( top_k=topk, renormalize=False, custom_routing_function=Llama4MoE.custom_routing_function, - scoring_func="softmax") + scoring_func="softmax", + ) quant_config = fp8_w8a8_moe_quant_config( w1_scale=td.w13_weight_scale, @@ -178,12 +180,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( top_k=topk, num_expert_group=None, topk_group=None, - apply_router_weight_on_input=True) + apply_router_weight_on_input=True, + ) - torch.testing.assert_close(output, - flashinfer_output, - atol=5.5e-2, - rtol=1e-2) + torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2) @pytest.mark.skip( @@ -213,7 +213,8 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( top_k=topk, renormalize=False, custom_routing_function=Llama4MoE.custom_routing_function, - scoring_func="softmax") + scoring_func="softmax", + ) quant_config = fp8_w8a8_moe_quant_config( w1_scale=td.w13_weight_scale, @@ -250,7 +251,6 @@ def test_flashinfer_cutlass_moe_fp8_no_graph( apply_router_weight_on_input=True, ) - torch.testing.assert_close(output, - flashinfer_cutlass_output, - atol=5.5e-2, - rtol=1e-2) + torch.testing.assert_close( + output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2 + ) diff --git a/tests/kernels/moe/test_flashinfer_moe.py b/tests/kernels/moe/test_flashinfer_moe.py index 8bf096b798cb..18cfd4f79092 100644 --- a/tests/kernels/moe/test_flashinfer_moe.py +++ b/tests/kernels/moe/test_flashinfer_moe.py @@ -4,26 +4,33 @@ import torch from tests.kernels.moe.utils import make_test_quant_config -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe) + FlashInferExperts, + is_valid_flashinfer_cutlass_fused_moe, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -if not has_flashinfer_cutlass_fused_moe( -) or not current_platform.has_device_capability(100): - pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support", - allow_module_level=True) +if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability( + 100 +): + pytest.skip( + "Requires flashinfer_cutlass_fused_moe and nvfp4 support", + allow_module_level=True, + ) MNK_FACTORS = [ (2, 1024, 1024), @@ -44,13 +51,13 @@ @pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @torch.inference_mode() -def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype): +def test_flashinfer_fp4_moe_no_graph( + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype +): current_platform.seed_everything(7) with set_current_vllm_config( - 
VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 quant_blocksize = 16 @@ -66,10 +73,7 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q) @@ -87,16 +91,19 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ) # Reference check: - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) + ).to(torch.float32) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) _, m_k = a_fp4.shape - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_scale_interleaved, - a_global_scale, - dtype=a.dtype, - device=a.device, - block_size=quant_blocksize) + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize, + ) w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) @@ -104,23 +111,26 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, for idx in range(0, e): w1_d[idx] = dequantize_nvfp4_to_dtype( w1_q[idx], - quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]), + quant_config.w1_scale[idx], + (1 / quant_config.g1_alphas[idx]), dtype=dtype, device=w1_q.device, - block_size=quant_blocksize) + block_size=quant_blocksize, + ) w2_d[idx] = dequantize_nvfp4_to_dtype( w2_q[idx], - quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]), + quant_config.w2_scale[idx], + (1 / quant_config.g2_alphas[idx]), dtype=dtype, device=w2_q.device, - block_size=quant_blocksize) + block_size=quant_blocksize, + ) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) - torch.testing.assert_close(torch_output, - flashinfer_output, - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close( + torch_output, flashinfer_output, atol=1e-1, rtol=1e-1 + ) if __name__ == "__main__": diff --git a/tests/kernels/moe/test_gpt_oss_triton_kernels.py b/tests/kernels/moe/test_gpt_oss_triton_kernels.py index 024993c7677d..f78596d220bf 100644 --- a/tests/kernels/moe/test_gpt_oss_triton_kernels.py +++ b/tests/kernels/moe/test_gpt_oss_triton_kernels.py @@ -17,20 +17,21 @@ import triton_kernels.swiglu from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from triton_kernels.numerics import InFlexData -from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp, - upcast_from_mxfp) +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor_details import layout from triton_kernels.testing import assert_close from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize) + BatchedPrepareAndFinalize, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - 
BatchedOAITritonExperts, triton_kernel_moe_forward) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) + BatchedOAITritonExperts, + triton_kernel_moe_forward, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.utils import shuffle_weight from vllm.utils import round_up @@ -46,13 +47,11 @@ def deshuffle(w: torch.Tensor): def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): randbits = [torch.randperm(E) for _ in range(M)] x_list = [ - (-1)**i * - ((16384 + - ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16)) + (-1) ** i + * ((16384 + ((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16)) for i, bits in enumerate(randbits) ] - exp_data = torch.stack(x_list).to( - device="cuda") # simulating gate_output (M, E) + exp_data = torch.stack(x_list).to(device="cuda") # simulating gate_output (M, E) # create input tensor x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda") @@ -120,20 +119,21 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): value=0, ) - w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0), - mode="constant", - value=0) - w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0), - mode="constant", - value=0) + w1_bias_tri = F.pad( + w1_bias_tri, (0, w1_right_pad, 0, 0), mode="constant", value=0 + ) + w2_bias_tri = F.pad( + w2_bias_tri, (0, w2_right_pad, 0, 0), mode="constant", value=0 + ) x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0) - w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout( - mx_axis=1) + w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) w_scale_layout, w_scale_layout_opts = ( layout.make_default_matmul_mxfp4_w_scale_layout( - mx_axis=1, num_warps=num_warps)) + mx_axis=1, num_warps=num_warps + ) + ) w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1) w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1) @@ -141,29 +141,33 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1) w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1) - w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout, - **w_layout_opts) + w1_tri = convert_layout( + wrap_torch_tensor(w1_tri, FP4), w_layout, **w_layout_opts + ) w1_scale_tri = convert_layout( wrap_torch_tensor(w1_scale_tri), w_scale_layout, **w_scale_layout_opts, ) - w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout, - **w_layout_opts) + w2_tri = convert_layout( + wrap_torch_tensor(w2_tri, FP4), w_layout, **w_layout_opts + ) w2_scale_tri = convert_layout( wrap_torch_tensor(w2_scale_tri), w_scale_layout, **w_scale_layout_opts, ) - pc1 = PrecisionConfig(weight_scale=w1_scale_tri, - flex_ctx=FlexCtx(rhs_data=InFlexData())) - pc2 = PrecisionConfig(weight_scale=w2_scale_tri, - flex_ctx=FlexCtx(rhs_data=InFlexData())) + pc1 = PrecisionConfig( + weight_scale=w1_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData()) + ) + pc2 = PrecisionConfig( + weight_scale=w2_scale_tri, flex_ctx=FlexCtx(rhs_data=InFlexData()) + ) # tucuate so the rest can run properly - w1 = w1[..., :K, :2 * N] + w1 = w1[..., :K, : 2 * N] w2 = w2[..., :N, :K] w1 = deshuffle(w1) @@ -261,7 +265,8 @@ class Case: @pytest.mark.parametrize( ", ".join(f.name for f in fields(Case)), [ - tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + 
tuple(getattr(case, f.name) for f in fields(Case)) + for case in [ # Case(a_dtype="bf16", w_dtype="bf16"), # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"), Case(a_dtype="bf16", w_dtype="mx4") @@ -321,10 +326,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp): gating_output=exp_data, topk=topk, ) - assert_close(ref=out_ref, - tri=out_triton_monolithic, - maxtol=0.025, - rmstol=0.005) + assert_close(ref=out_ref, tri=out_triton_monolithic, maxtol=0.025, rmstol=0.005) def batched_moe( @@ -376,7 +378,8 @@ def batched_moe( @pytest.mark.parametrize( ", ".join(f.name for f in fields(Case)), [ - tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + tuple(getattr(case, f.name) for f in fields(Case)) + for case in [ # Case(a_dtype="bf16", w_dtype="bf16"), # Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"), Case(a_dtype="bf16", w_dtype="mx4") diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py index 646e763194fd..3f4f142be767 100644 --- a/tests/kernels/moe/test_grouped_topk.py +++ b/tests/kernels/moe/test_grouped_topk.py @@ -4,16 +4,20 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`. """ + import pytest import torch -from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk, - grouped_topk) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_grouped_topk, + grouped_topk, +) from vllm.platforms import current_platform -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test is skipped on non-CUDA platform.") +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) @pytest.mark.parametrize("n_token", [1, 33, 64]) @pytest.mark.parametrize("n_hidden", [1024, 2048]) @pytest.mark.parametrize("n_expert", [16]) @@ -23,23 +27,26 @@ @pytest.mark.parametrize("topk_group", [2]) @pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) @pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) -@pytest.mark.parametrize("dtype", - [torch.float16, torch.bfloat16, torch.float32]) -def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int, - n_hidden: int, n_expert: int, topk: int, - renormalize: bool, num_expert_group: int, - topk_group: int, scoring_func: str, - routed_scaling_factor: float, dtype: torch.dtype): +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_grouped_topk( + monkeypatch: pytest.MonkeyPatch, + n_token: int, + n_hidden: int, + n_expert: int, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str, + routed_scaling_factor: float, + dtype: torch.dtype, +): current_platform.seed_everything(0) - hidden_states = torch.randn((n_token, n_hidden), - dtype=dtype, - device="cuda") - gating_output = torch.randn((n_token, n_expert), - dtype=dtype, - device="cuda") - e_score_correction_bias = torch.randn((n_expert, ), - dtype=torch.float32, - device="cuda") + hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda") + gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda") + e_score_correction_bias = torch.randn( + (n_expert,), dtype=torch.float32, device="cuda" + ) with monkeypatch.context() as m: m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") @@ -52,7 +59,8 @@ def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int, topk_group=topk_group, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias) + 
e_score_correction_bias=e_score_correction_bias, + ) test_topk_weights, test_topk_ids = fused_grouped_topk( hidden_states=hidden_states, @@ -63,14 +71,11 @@ def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int, topk_group=topk_group, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + ) if renormalize: - torch.testing.assert_close(baseline_topk_weights, - test_topk_weights, - atol=2e-2, - rtol=0) - torch.testing.assert_close(baseline_topk_ids, - test_topk_ids, - atol=0, - rtol=0) + torch.testing.assert_close( + baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0 + ) + torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0) diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 1c7e62d7aa4c..9c4114523590 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -17,18 +17,29 @@ from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from ...utils import multi_gpu_test -from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, - reference_moe_impl, - run_modular_kernel) +from .modular_kernel_tools.common import ( + Config, + RankTensors, + WeightTensors, + reference_moe_impl, + run_modular_kernel, +) from .modular_kernel_tools.mk_objects import ( - MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, - MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig, - expert_info) -from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, - parallel_launch_with_config) + MK_FUSED_EXPERT_TYPES, + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, + MK_QUANT_CONFIGS, + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, + TestMoEQuantConfig, + expert_info, +) +from .modular_kernel_tools.parallel_utils import ( + ProcessGroupInfo, + parallel_launch_with_config, +) -has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx() - or has_flashinfer_cutlass_fused_moe()) +has_any_multi_gpu_package = ( + has_deep_ep() or has_deep_gemm() or has_pplx() or has_flashinfer_cutlass_fused_moe() +) meets_multi_gpu_requirements = pytest.mark.skipif( not has_any_multi_gpu_package, @@ -64,9 +75,9 @@ def rank_worker( # sanity check from vllm import envs + if base_config.fused_moe_chunk_size is not None: - assert ( - base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + assert base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE # get weights to this device weights.to_current_device() @@ -93,8 +104,7 @@ def rank_worker( rank_tensors = RankTensors.make(config, pgi) # modular kernel out - mk_out = run_modular_kernel(pgi, vllm_config, config, weights, - rank_tensors) + mk_out = run_modular_kernel(pgi, vllm_config, config, weights, rank_tensors) with set_current_vllm_config(vllm_config): ref_out = reference_moe_impl(config, weights, rank_tensors) @@ -115,10 +125,10 @@ def rank_worker( if len(exceptions) > 0: raise RuntimeError( f"{len(exceptions)} of {count} tests failed in child process, " - f"rank={pgi.rank}.") + f"rank={pgi.rank}." 
+ ) else: - print(f"{count} of {count} tests passed in child process, " - f"rank={pgi.rank}.") + print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.") def run(config: Config, verbose: bool): @@ -127,8 +137,9 @@ def run(config: Config, verbose: bool): weights: WeightTensors = WeightTensors.make(config) vllm_config, env_dict = config.make_env_data() - parallel_launch_with_config(config.world_size, rank_worker, vllm_config, - env_dict, config, weights, verbose) + parallel_launch_with_config( + config.world_size, rank_worker, vllm_config, env_dict, config, weights, verbose + ) Ms = [32, 64] @@ -149,8 +160,9 @@ def is_nyi_config(config: Config) -> bool: if info.needs_matching_quant: # The triton kernels expect both per-act-token-quant and # per-out-ch-quant or neither. - unsupported_quant_config = ((config.is_per_act_token_quant + - config.is_per_out_ch_quant) == 1) + unsupported_quant_config = ( + config.is_per_act_token_quant + config.is_per_out_ch_quant + ) == 1 return unsupported_quant_config return not info.supports_expert_map @@ -162,19 +174,25 @@ def is_nyi_config(config: Config) -> bool: @pytest.mark.parametrize("dtype", DTYPEs) @pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) @pytest.mark.parametrize( - "combination", - product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) + "combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) +) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [2]) @multi_gpu_test(num_gpus=2) @meets_multi_gpu_requirements def test_modular_kernel_combinations_multigpu( - k: int, n: int, e: int, dtype: torch.dtype, - quant_config: Optional[TestMoEQuantConfig], - combination: tuple[mk.FusedMoEPrepareAndFinalize, - mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): - + k: int, + n: int, + e: int, + dtype: torch.dtype, + quant_config: Optional[TestMoEQuantConfig], + combination: tuple[ + mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute + ], + fused_moe_chunk_size: Optional[int], + world_size: int, + pytestconfig, +): config = Config( Ms=Ms, K=k, @@ -195,7 +213,7 @@ def test_modular_kernel_combinations_multigpu( if is_nyi_config(config): pytest.skip(f"Tests config {config} is nyi. 
Skipping ...") - verbosity = pytestconfig.getoption('verbose') + verbosity = pytestconfig.getoption("verbose") run(config, verbosity > 0) @@ -205,16 +223,23 @@ def test_modular_kernel_combinations_multigpu( @pytest.mark.parametrize("dtype", DTYPEs) @pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) @pytest.mark.parametrize( - "combination", - product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) + "combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) +) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("world_size", [1]) def test_modular_kernel_combinations_singlegpu( - k: int, n: int, e: int, dtype: torch.dtype, - quant_config: Optional[TestMoEQuantConfig], - combination: tuple[mk.FusedMoEPrepareAndFinalize, - mk.FusedMoEPermuteExpertsUnpermute], - fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig): + k: int, + n: int, + e: int, + dtype: torch.dtype, + quant_config: Optional[TestMoEQuantConfig], + combination: tuple[ + mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute + ], + fused_moe_chunk_size: Optional[int], + world_size: int, + pytestconfig, +): config = Config( Ms=Ms, K=k, @@ -235,19 +260,21 @@ def test_modular_kernel_combinations_singlegpu( if is_nyi_config(config): pytest.skip(f"Tests config {config} is nyi. Skipping ...") - verbosity = pytestconfig.getoption('verbose') + verbosity = pytestconfig.getoption("verbose") run(config, verbosity > 0) -if __name__ == '__main__': +if __name__ == "__main__": # Ability to test individual PrepareAndFinalize and FusedExperts combination - from .modular_kernel_tools.cli_args import (make_config, - make_config_arg_parser) - parser = make_config_arg_parser(description=( - "Run single prepare-finalize & fused-experts combination test" - "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501 - "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" - )) + from .modular_kernel_tools.cli_args import make_config, make_config_arg_parser + + parser = make_config_arg_parser( + description=( + "Run single prepare-finalize & fused-experts combination test" + "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + ) + ) args = parser.parse_args() config = make_config(args) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 00835bec9a15..9354e819877a 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -4,6 +4,7 @@ Run `pytest tests/kernels/test_moe.py`. 
""" + import functools from typing import Callable, Optional, Union @@ -21,22 +22,32 @@ from vllm.distributed.parallel_state import init_distributed_environment from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config, - int8_w8a16_moe_quant_config) + FUSED_MOE_UNQUANTIZED_CONFIG, + int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, modular_triton_fused_moe) + fused_topk, + modular_triton_fused_moe, +) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe) + fused_moe as iterative_moe, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_permute_bias) + marlin_permute_bias, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like) + rand_marlin_weight_mxfp4_like, + rand_marlin_weight_nvfp4_like, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - marlin_quant_fp8_torch) + marlin_quant_fp8_torch, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - awq_marlin_quantize, marlin_quantize) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - quantize_weights) + awq_marlin_quantize, + marlin_quantize, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -87,13 +98,15 @@ def run_moe_test( if isinstance(baseline, torch.Tensor): baseline_output = baseline else: - baseline_output = baseline(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) + baseline_output = baseline( + a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) # Pad the weight if moe padding is enabled if padding: @@ -105,34 +118,35 @@ def run_moe_test( torch._dynamo.mark_dynamic(a, 0) torch._dynamo.mark_dynamic(score, 0) - test_output = moe_fn(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) + test_output = moe_fn( + a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) if use_cudagraph: test_output.fill_(0) stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - test_output = moe_fn(a, - w1, - w2, - score, - topk, - global_num_experts=global_num_experts, - expert_map=expert_map) + test_output = moe_fn( + a, + w1, + w2, + score, + topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) torch.cuda.synchronize() graph.replay() torch.cuda.synchronize() - torch.testing.assert_close(test_output, - baseline_output, - atol=atol, - rtol=rtol) + torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol) return baseline_output @@ -176,11 +190,8 @@ def test_fused_moe( if ep_size > 1: local_e = e // ep_size - e_ids = torch.randint(0, - e, (local_e, ), - device="cuda", - dtype=torch.int32) - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32) + e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, 
device="cuda", dtype=torch.int32) w1 = w1[e_ids] w2 = w2[e_ids] @@ -204,13 +215,15 @@ def m_fused_moe( expert_map: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - return m_fused_moe_fn(a, - w1, - w2, - topk_weights, - topk_ids, - global_num_experts=global_num_experts, - expert_map=expert_map) + return m_fused_moe_fn( + a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) fused_moe_fn = functools.partial(fused_moe, renormalize=False) @@ -234,19 +247,22 @@ def m_fused_moe( # setup code in case we are able to revisit this later. use_compile = False - use_cudagraph = (n >= 1024 and k >= 1024 - and current_platform.is_cuda_alike()) + use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike() with set_current_vllm_config(vllm_config): baseline_output = runner(torch_moe, iterative_moe) - runner(baseline_output, - fused_moe_fn, - use_compile=use_compile, - use_cudagraph=use_cudagraph) - runner(baseline_output, - m_fused_moe, - use_compile=use_compile, - use_cudagraph=use_cudagraph) + runner( + baseline_output, + fused_moe_fn, + use_compile=use_compile, + use_cudagraph=use_cudagraph, + ) + runner( + baseline_output, + m_fused_moe, + use_compile=use_compile, + use_cudagraph=use_cudagraph, + ) @pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS) @@ -257,9 +273,18 @@ def m_fused_moe( @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.parametrize("has_zp", [True, False]) @pytest.mark.parametrize("weight_bits", [4, 8]) -def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, - ep_size: int, dtype: torch.dtype, group_size: int, - has_zp: bool, weight_bits: int): +def test_fused_moe_wn16( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, + group_size: int, + has_zp: bool, + weight_bits: int, +): a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -274,35 +299,40 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, w1_ref = w1.clone() w2_ref = w2.clone() - w1_qweight = torch.empty((e, 2 * n, k // pack_factor), - device="cuda", - dtype=torch.uint8) - w2_qweight = torch.empty((e, k, n // pack_factor), - device="cuda", - dtype=torch.uint8) - w1_scales = torch.empty((e, 2 * n, k // group_size), - device="cuda", - dtype=dtype) - w2_scales = torch.empty((e, k, n // group_size), - device="cuda", - dtype=dtype) - w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size), - device="cuda", - dtype=torch.uint8) - w2_qzeros = torch.empty((e, k // pack_factor, n // group_size), - device="cuda", - dtype=torch.uint8) + w1_qweight = torch.empty( + (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8 + ) + w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8) + w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype) + w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype) + w1_qzeros = torch.empty( + (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8 + ) + w2_qzeros = torch.empty( + (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8 + ) for i in range(e * 2): expert_id = i % e if i // e == 0: - w, w_ref, w_qweight, w_scales, w_qzeros = \ - w1, w1_ref, w1_qweight, w1_scales, w1_qzeros + w, w_ref, w_qweight, w_scales, 
w_qzeros = ( + w1, + w1_ref, + w1_qweight, + w1_scales, + w1_qzeros, + ) else: - w, w_ref, w_qweight, w_scales, w_qzeros = \ - w2, w2_ref, w2_qweight, w2_scales, w2_qzeros + w, w_ref, w_qweight, w_scales, w_qzeros = ( + w2, + w2_ref, + w2_qweight, + w2_scales, + w2_qzeros, + ) weight, qweight, scales, qzeros = quantize_weights( - w[expert_id].T, quant_type, group_size, has_zp, False) + w[expert_id].T, quant_type, group_size, has_zp, False + ) weight = weight.T qweight = qweight.T.contiguous().to(torch.uint8) scales = scales.T @@ -321,11 +351,8 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, if ep_size > 1: local_e = e // ep_size - e_ids = torch.randint(0, - e, (local_e, ), - device="cuda", - dtype=torch.int32) - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32) + e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1_ref = w1_ref[e_ids] w2_ref = w2_ref[e_ids] @@ -344,28 +371,27 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, assert weight_bits == 8 quant_config_builder = int8_w8a16_moe_quant_config - quant_config = quant_config_builder(w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) + quant_config = quant_config_builder( + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size], + ) with set_current_vllm_config(vllm_config): - triton_output = fused_moe(a, - w1_qweight, - w2_qweight, - score, - topk, - renormalize=False, - global_num_experts=e, - expert_map=e_map, - quant_config=quant_config) - torch_output = torch_moe(a, - w1_ref, - w2_ref, - score, - topk, - expert_map=e_map) + triton_output = fused_moe( + a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + global_num_experts=e, + expert_map=e_map, + quant_config=quant_config, + ) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -373,16 +399,20 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("padding", [True, False]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) @torch.inference_mode() -def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool, - use_rocm_aiter: bool, monkeypatch): +def test_mixtral_moe( + dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch +): """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" # clear the cache before every test from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled, + ) + is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -390,17 +420,16 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool, if dtype == torch.float32: pytest.skip("AITER ROCm test skip for float32") - monkeypatch.setenv('RANK', "0") - monkeypatch.setenv('LOCAL_RANK', "0") - monkeypatch.setenv('WORLD_SIZE', "1") - monkeypatch.setenv('MASTER_ADDR', 'localhost') - 
monkeypatch.setenv('MASTER_PORT', '12345') + monkeypatch.setenv("RANK", "0") + monkeypatch.setenv("LOCAL_RANK", "0") + monkeypatch.setenv("WORLD_SIZE", "1") + monkeypatch.setenv("MASTER_ADDR", "localhost") + monkeypatch.setenv("MASTER_PORT", "12345") init_distributed_environment() # Instantiate our and huggingface's MoE blocks vllm_config.compilation_config.static_forward_context = dict() - with (set_current_vllm_config(vllm_config), - set_forward_context(None, vllm_config)): + with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config): config = MixtralConfig() hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") vllm_moe = MixtralMoE( @@ -416,27 +445,30 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool, # Load the weights vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): - weights = (hf_moe.experts[i].w1.weight.data, - hf_moe.experts[i].w3.weight.data) + weights = ( + hf_moe.experts[i].w1.weight.data, + hf_moe.experts[i].w3.weight.data, + ) vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] - hf_inputs = torch.randn( - (1, 64, config.hidden_size)).to(dtype).to("cuda") + hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") # vLLM uses 1D query [num_tokens, hidden_dim] vllm_inputs = hf_inputs.flatten(0, 1) # Pad the weight if moe padding is enabled if padding: - vllm_moe.experts.w13_weight = Parameter(F.pad( - vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., - 0:-128], - requires_grad=False) - vllm_moe.experts.w2_weight = Parameter(F.pad( - vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., - 0:-128], - requires_grad=False) + vllm_moe.experts.w13_weight = Parameter( + F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[ + ..., 0:-128 + ], + requires_grad=False, + ) + vllm_moe.experts.w2_weight = Parameter( + F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128], + requires_grad=False, + ) torch.cuda.synchronize() torch.cuda.empty_cache() @@ -451,21 +483,23 @@ def test_mixtral_moe(dist_init, dtype: torch.dtype, padding: bool, } if use_rocm_aiter: - # The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501 - # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501 - torch.testing.assert_close(hf_states.flatten(0, 1), - vllm_states, - rtol=0.01, - atol=100) + # The values of rtol and atol are set based on the tests in ROCM AITER package. 
+ # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 + torch.testing.assert_close( + hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100 + ) else: - torch.testing.assert_close(hf_states.flatten(0, 1), - vllm_states, - rtol=mixtral_moe_tol[dtype], - atol=mixtral_moe_tol[dtype]) + torch.testing.assert_close( + hf_states.flatten(0, 1), + vllm_states, + rtol=mixtral_moe_tol[dtype], + atol=mixtral_moe_tol[dtype], + ) def marlin_moe_generate_valid_test_cases(): import itertools + m_list = [1, 123, 666] n_list = [128, 1024] k_list = [256, 2048] @@ -484,16 +518,24 @@ def marlin_moe_generate_valid_test_cases(): ] is_k_full_list = [True, False] - all_combinations = itertools.product(m_list, n_list, k_list, e_list, - topk_list, ep_size_list, dtype_list, - group_size_list, act_order_list, - quant_type_list, is_k_full_list) - - def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order, - quant_type, is_k_full): + all_combinations = itertools.product( + m_list, + n_list, + k_list, + e_list, + topk_list, + ep_size_list, + dtype_list, + group_size_list, + act_order_list, + quant_type_list, + is_k_full_list, + ) - if quant_type == scalar_types.float8_e4m3fn and \ - group_size not in [-1, 128]: + def is_invalid( + m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full + ): + if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]: return False if quant_type == scalar_types.float4_e2m1f: if group_size not in [16, 32]: @@ -522,9 +564,10 @@ def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order, @pytest.mark.flaky(reruns=2) -@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size," - "act_order, quant_type, is_k_full"), - marlin_moe_generate_valid_test_cases()) +@pytest.mark.parametrize( + ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"), + marlin_moe_generate_valid_test_cases(), +) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, @@ -549,7 +592,7 @@ def test_fused_marlin_moe( if ep_size > 1: local_e = e // ep_size e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e] - e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) + e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32) e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) w1 = w1[e_ids] w2 = w2[e_ids] @@ -567,11 +610,13 @@ def test_fused_marlin_moe( for i in range(w1.shape[0]): if quant_type == scalar_types.float4_e2m1f: if group_size == 16: - w_ref1, qweight1, scales1, global_scale1 = \ + w_ref1, qweight1, scales1, global_scale1 = ( rand_marlin_weight_nvfp4_like(w1[i], group_size) + ) else: - w_ref1, qweight1, scales1 = \ - rand_marlin_weight_mxfp4_like(w1[i], group_size) + w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like( + w1[i], group_size + ) global_scale1 = None w_ref1_l.append(w_ref1.T) @@ -580,14 +625,14 @@ def test_fused_marlin_moe( if global_scale1 is not None: global_scale1_l.append(global_scale1) elif quant_type == scalar_types.float8_e4m3fn: - w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( - w1[i], group_size) + w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) scales1_l.append(scales1) elif has_zp: w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size) + w1[i].transpose(1, 0), quant_type, group_size + ) 
w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -595,9 +640,9 @@ def test_fused_marlin_moe( zeros1_l.append(zeros1) else: test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ - marlin_quantize(w1[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -624,11 +669,13 @@ def test_fused_marlin_moe( for i in range(w2.shape[0]): if quant_type == scalar_types.float4_e2m1f: if group_size == 16: - w_ref2, qweight2, scales2, global_scale2 = \ + w_ref2, qweight2, scales2, global_scale2 = ( rand_marlin_weight_nvfp4_like(w2[i], group_size) + ) else: - w_ref2, qweight2, scales2 = \ - rand_marlin_weight_mxfp4_like(w2[i], group_size) + w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like( + w2[i], group_size + ) global_scale2 = None w_ref2_l.append(w_ref2.T) @@ -637,14 +684,14 @@ def test_fused_marlin_moe( if global_scale2 is not None: global_scale2_l.append(global_scale2) elif quant_type == scalar_types.float8_e4m3fn: - w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( - w2[i], group_size) + w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) scales2_l.append(scales2) elif has_zp: w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size) + w2[i].transpose(1, 0), quant_type, group_size + ) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) @@ -652,9 +699,9 @@ def test_fused_marlin_moe( zeros2_l.append(zeros2) else: test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ - marlin_quantize(w2[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) @@ -675,12 +722,7 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, - w_ref1, - w_ref2, - score, - topk, - expert_map=e_map) + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, @@ -704,7 +746,8 @@ def test_fused_marlin_moe( w1_zeros=zeros1, w2_zeros=zeros2, quant_type_id=quant_type.id, - is_k_full=is_k_full) + is_k_full=is_k_full, + ) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) @@ -738,9 +781,9 @@ def test_fused_marlin_moe_with_bias(m): for i in range(w1.shape[0]): test_perm = torch.randperm(k) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ - marlin_quantize(w1[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm + ) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -767,9 +810,9 @@ def test_fused_marlin_moe_with_bias(m): for i in range(w2.shape[0]): test_perm = torch.randperm(n) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ - marlin_quantize(w2[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, 
act_order, test_perm + ) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) @@ -792,8 +835,7 @@ def test_fused_marlin_moe_with_bias(m): topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) with set_current_vllm_config(vllm_config): - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, - b_bias2) + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2) marlin_output = torch.ops.vllm.fused_marlin_moe( a, @@ -817,7 +859,8 @@ def test_fused_marlin_moe_with_bias(m): w1_zeros=zeros1, w2_zeros=zeros2, quant_type_id=quant_type.id, - is_k_full=is_k_full) + is_k_full=is_k_full, + ) torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) @@ -825,34 +868,36 @@ def test_fused_marlin_moe_with_bias(m): def test_moe_align_block_size_opcheck(): num_experts = 4 block_size = 4 - topk_ids = torch.randint(0, - num_experts, (3, 4), - dtype=torch.int32, - device='cuda') + topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda") max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids = torch.empty((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) - - opcheck(torch.ops._moe_C.moe_align_block_size, - (topk_ids, num_experts, block_size, sorted_ids, expert_ids, - num_tokens_post_pad)) + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) + + opcheck( + torch.ops._moe_C.moe_align_block_size, + ( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ), + ) @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): input = torch.randn((m, topk, k), device="cuda", dtype=dtype) diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py index 5dfc8d9fab32..f92526e74955 100644 --- a/tests/kernels/moe/test_moe_align_block_size.py +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -11,7 +11,8 @@ import torch from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) + moe_align_block_size, +) from vllm.platforms import current_platform from vllm.utils import round_up @@ -60,30 +61,33 @@ def _verify_expert_level_sorting( in topk_ids in the final sorted_ids however this does not impact quality. 
""" # Group tokens by expert from the golden implementation - golden_expert_tokens = _group_tokens_by_expert(golden_sorted_ids, - expert_ids, block_size, - valid_length, total_tokens) + golden_expert_tokens = _group_tokens_by_expert( + golden_sorted_ids, expert_ids, block_size, valid_length, total_tokens + ) - actual_expert_tokens = _group_tokens_by_expert(actual_sorted_ids, - expert_ids, block_size, - valid_length, total_tokens) + actual_expert_tokens = _group_tokens_by_expert( + actual_sorted_ids, expert_ids, block_size, valid_length, total_tokens + ) - assert set(golden_expert_tokens.keys()) == set( - actual_expert_tokens.keys()), ( - f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, " - f"actual={set(actual_expert_tokens.keys())}") + assert set(golden_expert_tokens.keys()) == set(actual_expert_tokens.keys()), ( + f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, " + f"actual={set(actual_expert_tokens.keys())}" + ) for expert_id in golden_expert_tokens: - golden_tokens = torch.tensor(golden_expert_tokens[expert_id], - device=actual_sorted_ids.device) - actual_tokens = torch.tensor(actual_expert_tokens[expert_id], - device=actual_sorted_ids.device) + golden_tokens = torch.tensor( + golden_expert_tokens[expert_id], device=actual_sorted_ids.device + ) + actual_tokens = torch.tensor( + actual_expert_tokens[expert_id], device=actual_sorted_ids.device + ) assert torch.equal( - torch.sort(golden_tokens)[0], - torch.sort(actual_tokens)[0]), ( - f"Expert {expert_id} token mismatch: " - f"golden={golden_expert_tokens[expert_id]}, " - f"actual={actual_expert_tokens[expert_id]}") + torch.sort(golden_tokens)[0], torch.sort(actual_tokens)[0] + ), ( + f"Expert {expert_id} token mismatch: " + f"golden={golden_expert_tokens[expert_id]}, " + f"actual={actual_expert_tokens[expert_id]}" + ) def torch_moe_align_block_size( @@ -104,40 +108,38 @@ def torch_moe_align_block_size( if pad_sorted_ids: max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) - flattened_token_indices = torch.arange(topk_ids.numel(), - device=topk_ids.device, - dtype=torch.int32) + flattened_token_indices = torch.arange( + topk_ids.numel(), device=topk_ids.device, dtype=torch.int32 + ) flattened_expert_ids = topk_ids.flatten() - sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, - stable=True) + sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, stable=True) sorted_token_indices = flattened_token_indices[sort_indices] - expert_token_counts = torch.zeros(num_experts, - dtype=torch.int64, - device=topk_ids.device) + expert_token_counts = torch.zeros( + num_experts, dtype=torch.int64, device=topk_ids.device + ) for expert_id in range(num_experts): mask = sorted_expert_ids == expert_id expert_token_counts[expert_id] = mask.sum() - expert_padded_counts = torch.zeros(num_experts, - dtype=torch.int64, - device=topk_ids.device) + expert_padded_counts = torch.zeros( + num_experts, dtype=torch.int64, device=topk_ids.device + ) for expert_id in range(num_experts): original_count = expert_token_counts[expert_id] if original_count > 0: expert_padded_counts[expert_id] = ( - (original_count + block_size - 1) // block_size) * block_size + (original_count + block_size - 1) // block_size + ) * block_size sorted_token_ids = torch.full( - (max_num_tokens_padded, ), + (max_num_tokens_padded,), topk_ids.numel(), dtype=torch.int32, device=topk_ids.device, ) max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size - expert_ids = torch.zeros(max_num_blocks, - 
dtype=torch.int32, - device=topk_ids.device) + expert_ids = torch.zeros(max_num_blocks, dtype=torch.int32, device=topk_ids.device) current_pos = 0 current_block = 0 @@ -147,20 +149,20 @@ def torch_moe_align_block_size( num_expert_tokens = expert_tokens.shape[0] if num_expert_tokens > 0: - sorted_token_ids[current_pos:current_pos + - num_expert_tokens] = (expert_tokens) + sorted_token_ids[current_pos : current_pos + num_expert_tokens] = ( + expert_tokens + ) expert_blocks_needed = expert_padded_counts[expert_id] // block_size - expert_ids[current_block:current_block + - expert_blocks_needed] = (expert_id) + expert_ids[current_block : current_block + expert_blocks_needed] = expert_id current_pos += expert_padded_counts[expert_id] current_block += expert_blocks_needed total_padded_tokens = expert_padded_counts.sum() - num_tokens_post_pad = torch.tensor([total_padded_tokens], - dtype=torch.int32, - device=topk_ids.device) + num_tokens_post_pad = torch.tensor( + [total_padded_tokens], dtype=torch.int32, device=topk_ids.device + ) if expert_map is not None: expert_ids = expert_map[expert_ids] @@ -173,37 +175,32 @@ def torch_moe_align_block_size( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("pad_sorted_ids", [False, True]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") -def test_moe_align_block_size(m: int, topk: int, num_experts: int, - block_size: int, pad_sorted_ids: bool): +def test_moe_align_block_size( + m: int, topk: int, num_experts: int, block_size: int, pad_sorted_ids: bool +): """Test moe_align_block_size without expert mapping""" topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32) for i in range(m): experts = torch.randperm(num_experts, device="cuda")[:topk] topk_ids[i] = experts - actual_sorted_ids, actual_expert_ids, actual_num_tokens = ( - moe_align_block_size( - topk_ids=topk_ids, - block_size=block_size, - num_experts=num_experts, - pad_sorted_ids=pad_sorted_ids, - )) + actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + pad_sorted_ids=pad_sorted_ids, + ) golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( torch_moe_align_block_size( topk_ids=topk_ids, block_size=block_size, num_experts=num_experts, pad_sorted_ids=pad_sorted_ids, - )) + ) + ) - torch.testing.assert_close(actual_num_tokens, - golden_num_tokens, - atol=0, - rtol=0) - torch.testing.assert_close(actual_expert_ids, - golden_expert_ids, - atol=0, - rtol=0) + torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0) + torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0) # For sorted_token_ids, verify block-level correctness rather than exact # order Tokens within each expert's blocks can be in any order, but expert @@ -219,16 +216,18 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int, total_tokens = m * topk assert actual_num_tokens.item() % block_size == 0, ( - "num_tokens_post_pad should be divisible by block_size") + "num_tokens_post_pad should be divisible by block_size" + ) assert actual_num_tokens.item() >= total_tokens, ( - "num_tokens_post_pad should be at least total_tokens") + "num_tokens_post_pad should be at least total_tokens" + ) valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens] assert len(valid_tokens) == total_tokens, ( - f"Should have exactly {total_tokens} valid tokens, " - f"got {len(valid_tokens)}") - assert (actual_expert_ids >= 
0).all() and ( - actual_expert_ids - < num_experts).all(), "expert_ids should contain valid expert indices" + f"Should have exactly {total_tokens} valid tokens, got {len(valid_tokens)}" + ) + assert (actual_expert_ids >= 0).all() and (actual_expert_ids < num_experts).all(), ( + "expert_ids should contain valid expert indices" + ) @pytest.mark.parametrize("m", [16, 32]) @@ -236,46 +235,37 @@ def test_moe_align_block_size(m: int, topk: int, num_experts: int, @pytest.mark.parametrize("num_experts", [8]) @pytest.mark.parametrize("block_size", [64]) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") -def test_moe_align_block_size_with_expert_map(m: int, topk: int, - num_experts: int, - block_size: int): +def test_moe_align_block_size_with_expert_map( + m: int, topk: int, num_experts: int, block_size: int +): """Test moe_align_block_size with expert mapping (EP scenario)""" topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32) for i in range(m): experts = torch.randperm(num_experts, device="cuda")[:topk] topk_ids[i] = experts - expert_map = torch.full((num_experts, ), - -1, - device="cuda", - dtype=torch.int32) + expert_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32) local_experts = list(range(0, num_experts, 2)) for i, expert_id in enumerate(local_experts): expert_map[expert_id] = i - actual_sorted_ids, actual_expert_ids, actual_num_tokens = ( - moe_align_block_size( - topk_ids=topk_ids, - block_size=block_size, - num_experts=num_experts, - expert_map=expert_map, - )) + actual_sorted_ids, actual_expert_ids, actual_num_tokens = moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + expert_map=expert_map, + ) golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( torch_moe_align_block_size( topk_ids=topk_ids, block_size=block_size, num_experts=num_experts, expert_map=expert_map, - )) - - torch.testing.assert_close(actual_num_tokens, - golden_num_tokens, - atol=0, - rtol=0) - torch.testing.assert_close(actual_expert_ids, - golden_expert_ids, - atol=0, - rtol=0) + ) + ) + + torch.testing.assert_close(actual_num_tokens, golden_num_tokens, atol=0, rtol=0) + torch.testing.assert_close(actual_expert_ids, golden_expert_ids, atol=0, rtol=0) _verify_expert_level_sorting( actual_sorted_ids, golden_sorted_ids, @@ -290,26 +280,25 @@ def test_moe_align_block_size_deterministic(): m, topk, num_experts, block_size = 128, 2, 32, 64 torch.manual_seed(42) - topk_ids = torch.randint(0, - num_experts, (m, topk), - device="cuda", - dtype=torch.int32) + topk_ids = torch.randint( + 0, num_experts, (m, topk), device="cuda", dtype=torch.int32 + ) # expect the results to be reproducible results = [] for _ in range(5): sorted_ids, expert_ids, num_tokens = moe_align_block_size( - topk_ids=topk_ids, block_size=block_size, num_experts=num_experts) - results.append( - (sorted_ids.clone(), expert_ids.clone(), num_tokens.clone())) + topk_ids=topk_ids, block_size=block_size, num_experts=num_experts + ) + results.append((sorted_ids.clone(), expert_ids.clone(), num_tokens.clone())) for i in range(1, len(results)): - assert torch.equal( - results[0][0], - results[i][0]), ("sorted_ids should be deterministic") - assert torch.equal( - results[0][1], - results[i][1]), ("expert_ids should be deterministic") - assert torch.equal( - results[0][2], - results[i][2]), ("num_tokens should be deterministic") + assert torch.equal(results[0][0], results[i][0]), ( + "sorted_ids should be deterministic" + ) + assert torch.equal(results[0][1], 
results[i][1]), ( + "expert_ids should be deterministic" + ) + assert torch.equal(results[0][2], results[i][2]), ( + "num_tokens should be deterministic" + ) diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index d71664d94b9c..a6214437d404 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -14,7 +14,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.layer import determine_expert_map from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - moe_permute, moe_permute_unpermute_supported, moe_unpermute) + moe_permute, + moe_permute_unpermute_supported, + moe_unpermute, +) from vllm.platforms import current_platform NUM_EXPERTS = [16, 64, 256] @@ -24,35 +27,34 @@ def torch_permute( - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - # token_expert_indices: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, - start_expert: int, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, - fill_invalid_expert: int = -1) -> list[torch.Tensor]: + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + # token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + start_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1, +) -> list[torch.Tensor]: n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] if expert_map is not None: - is_local_expert = (expert_map[topk_ids] != -1) - not_local_expert = (expert_map[topk_ids] == -1) - topk_ids = is_local_expert * ( - topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) - token_expert_indices = torch.arange(0, - n_token * topk, - dtype=torch.int32, - device=hidden_states.device).reshape( - (n_token, topk)) + is_local_expert = expert_map[topk_ids] != -1 + not_local_expert = expert_map[topk_ids] == -1 + topk_ids = is_local_expert * (topk_ids - start_expert) + not_local_expert * ( + topk_ids + n_expert + ) + token_expert_indices = torch.arange( + 0, n_token * topk, dtype=torch.int32, device=hidden_states.device + ).reshape((n_token, topk)) - sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), - stable=True) + sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True) dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices] - expert_first_token_offset = torch.zeros(n_local_expert + 1, - dtype=torch.int64, - device="cuda") + expert_first_token_offset = torch.zeros( + n_local_expert + 1, dtype=torch.int64, device="cuda" + ) idx = 0 for i in range(0, n_local_expert): cnt = 0 @@ -64,116 +66,133 @@ def torch_permute( _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map) valid_row_idx = [] if align_block_size is None: - - permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // - topk, ...] + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...] 
permuted_row_size = permuted_hidden_states.shape[0] - m_indices = torch.empty(permuted_row_size, - device="cuda", - dtype=torch.int32).fill_(fill_invalid_expert) + m_indices = torch.empty( + permuted_row_size, device="cuda", dtype=torch.int32 + ).fill_(fill_invalid_expert) for i in range(1, n_local_expert + 1): first_token_offset = expert_first_token_offset[i - 1] last_token_offset = expert_first_token_offset[i] m_indices[first_token_offset:last_token_offset] = i - 1 src_row_id2dst_row_id_map = torch.arange( - 0, n_token * topk, device="cuda", - dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) + 0, n_token * topk, device="cuda", dtype=torch.int32 + )[src2dst_idx].reshape((n_token, topk)) valid_row_idx += [i for i in range(expert_first_token_offset[-1])] - dst_row_id2src_row_id_map[ - expert_first_token_offset[-1]:] = n_token * topk + dst_row_id2src_row_id_map[expert_first_token_offset[-1] :] = n_token * topk return [ - permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices, - valid_row_idx + permuted_hidden_states, + expert_first_token_offset, + src_row_id2dst_row_id_map, + dst_row_id2src_row_id_map, + m_indices, + valid_row_idx, ] else: - permuted_row_size = (topk * n_token + n_expert * - (align_block_size - 1) + align_block_size - - 1) // align_block_size * align_block_size - permuted_idx = torch.full((permuted_row_size, ), - n_token * topk, - dtype=torch.int32, - device=hidden_states.device) - permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), - device="cuda", - dtype=hidden_states.dtype) - align_src_row_id2dst_row_id = torch.empty(n_token * topk, - device="cuda", - dtype=torch.int32) - align_expert_first_token_offset = torch.zeros_like( - expert_first_token_offset) - m_indices = torch.empty(permuted_row_size, - device="cuda", - dtype=torch.int32).fill_(fill_invalid_expert) + permuted_row_size = ( + (topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1) + // align_block_size + * align_block_size + ) + permuted_idx = torch.full( + (permuted_row_size,), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device, + ) + permuted_hidden_states = torch.empty( + (permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype + ) + align_src_row_id2dst_row_id = torch.empty( + n_token * topk, device="cuda", dtype=torch.int32 + ) + align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset) + m_indices = torch.empty( + permuted_row_size, device="cuda", dtype=torch.int32 + ).fill_(fill_invalid_expert) # get align_permuted_hidden_states, # valid row_idx and align_expert_first_token_offset for i in range(1, n_local_expert + 1): first_token_offset = expert_first_token_offset[i - 1] last_token_offset = expert_first_token_offset[i] n_token_in_expert = last_token_offset - first_token_offset - align_expert_first_token_offset[ - i] = align_expert_first_token_offset[ - i - 1] + (n_token_in_expert + align_block_size - - 1) // align_block_size * align_block_size + align_expert_first_token_offset[i] = ( + align_expert_first_token_offset[i - 1] + + (n_token_in_expert + align_block_size - 1) + // align_block_size + * align_block_size + ) align_first_token_offset = align_expert_first_token_offset[i - 1] align_last_token_offset = align_expert_first_token_offset[i] dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ - first_token_offset:first_token_offset + n_token_in_expert] + first_token_offset : first_token_offset + n_token_in_expert + ] # store token in 
current expert with align_first_token_offset - permuted_hidden_states[align_first_token_offset:\ - align_first_token_offset+n_token_in_expert,\ - ...] = hidden_states[\ - dst_row_id2src_row_id_in_expert // topk,\ - ...] - permuted_idx[align_first_token_offset:\ - align_first_token_offset+\ - n_token_in_expert] = dst_row_id2src_row_id_in_expert + permuted_hidden_states[ + align_first_token_offset : align_first_token_offset + n_token_in_expert, + ..., + ] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...] + permuted_idx[ + align_first_token_offset : align_first_token_offset + n_token_in_expert + ] = dst_row_id2src_row_id_in_expert # set current expert m_indices m_indices[align_first_token_offset:align_last_token_offset] = i - 1 valid_row_idx += [ - i for i in range(align_first_token_offset, - align_first_token_offset + n_token_in_expert) + i + for i in range( + align_first_token_offset, + align_first_token_offset + n_token_in_expert, + ) ] # get align_src_row_id2dst_row_id for i in range(n_token * topk): eid = sorted_topk_ids[i] - if (eid >= n_local_expert): + if eid >= n_local_expert: # check token not in local expert - align_src_row_id2dst_row_id[ - i] = align_expert_first_token_offset[-1] + align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1] continue first_token_offset = expert_first_token_offset[eid] align_first_token_offset = align_expert_first_token_offset[eid] token_offset = i - first_token_offset - align_src_row_id2dst_row_id[ - i] = align_first_token_offset + token_offset - align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\ - src2dst_idx].reshape((n_token, topk)) + align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset + align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape( + (n_token, topk) + ) return [ - permuted_hidden_states, align_expert_first_token_offset, - align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx + permuted_hidden_states, + align_expert_first_token_offset, + align_src_row_id2dst_row_id, + permuted_idx, + m_indices, + valid_row_idx, ] -def torch_unpermute(permuted_hidden_states: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - src_row_id2dst_row_id_map: torch.Tensor, - valid_row_idx: torch.Tensor, topk: int, - n_expert: int) -> torch.Tensor: +def torch_unpermute( + permuted_hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + src_row_id2dst_row_id_map: torch.Tensor, + valid_row_idx: torch.Tensor, + topk: int, + n_expert: int, +) -> torch.Tensor: # ignore invalid row n_hidden = permuted_hidden_states.shape[1] - mask = torch.zeros(permuted_hidden_states.shape[0], - dtype=bool, - device="cuda") + mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda") mask[valid_row_idx] = True permuted_hidden_states[~mask] = 0 permuted_hidden_states = permuted_hidden_states[ - src_row_id2dst_row_id_map.flatten(), ...] + src_row_id2dst_row_id_map.flatten(), ... 
+ ] permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden) - output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to( - permuted_hidden_states.dtype) + output = ( + (permuted_hidden_states * topk_weights.unsqueeze(2)) + .sum(1) + .to(permuted_hidden_states.dtype) + ) return output @@ -184,59 +203,76 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("align_block_size", [None, 128]) -def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, - n_expert: int, ep_size: int, dtype: torch.dtype, - align_block_size: Optional[int]): +def test_moe_permute_unpermute( + n_token: int, + n_hidden: int, + topk: int, + n_expert: int, + ep_size: int, + dtype: torch.dtype, + align_block_size: Optional[int], +): if not moe_permute_unpermute_supported(): pytest.skip("moe_permute_unpermute is not supported on this platform.") fill_invalid_expert = 0 ep_rank = np.random.randint(0, ep_size) expert_map = None n_local_expert = n_expert - if (ep_size != 1): - n_local_expert, expert_map = determine_expert_map( - ep_size, ep_rank, n_expert) + if ep_size != 1: + n_local_expert, expert_map = determine_expert_map(ep_size, ep_rank, n_expert) expert_map = expert_map.cuda() start_expert = n_local_expert * ep_rank current_platform.seed_everything(0) hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype) gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, gating_output, topk, False) - (gold_permuted_hidden_states, gold_expert_first_token_offset, - gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices, - valid_row_idx) = torch_permute( - hidden_states, - topk_ids, - # token_expert_indices, - topk, - n_expert, - n_local_expert, - start_expert, - expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert) + hidden_states, gating_output, topk, False + ) + ( + gold_permuted_hidden_states, + gold_expert_first_token_offset, + gold_inv_permuted_idx, + gold_permuted_idx, + gold_m_indices, + valid_row_idx, + ) = torch_permute( + hidden_states, + topk_ids, + # token_expert_indices, + topk, + n_expert, + n_local_expert, + start_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert, + ) - (permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx, - m_indices) = moe_permute(hidden_states=hidden_states, - a1q_scale=None, - topk_ids=topk_ids, - n_expert=n_expert, - n_local_expert=n_local_expert, - expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert) + ( + permuted_hidden_states, + _, + expert_first_token_offset, + inv_permuted_idx, + m_indices, + ) = moe_permute( + hidden_states=hidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=n_expert, + n_local_expert=n_local_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert, + ) # check expert_first_token_offset - torch.testing.assert_close(gold_expert_first_token_offset, - expert_first_token_offset, - atol=0, - rtol=0) + torch.testing.assert_close( + gold_expert_first_token_offset, expert_first_token_offset, atol=0, rtol=0 + ) # check src_row_id2dst_row_id_map - torch.testing.assert_close(gold_inv_permuted_idx.flatten(), - inv_permuted_idx, - atol=0, - rtol=0) + 
torch.testing.assert_close( + gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0 + ) # check mindice # current kernel usage assumes deepgemm requires align_block_size # when it's not provided then we don't compute m_indices (for cutlass) @@ -244,19 +280,28 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) # check permuted_hidden_states, only valid token - torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], - permuted_hidden_states[valid_row_idx], - atol=0, - rtol=0) + torch.testing.assert_close( + gold_permuted_hidden_states[valid_row_idx], + permuted_hidden_states[valid_row_idx], + atol=0, + rtol=0, + ) # add a random tensor to simulate group gemm - result0 = 0.5 * permuted_hidden_states + torch.randn_like( - permuted_hidden_states) + result0 = 0.5 * permuted_hidden_states + torch.randn_like(permuted_hidden_states) result4 = torch.empty_like(hidden_states) - moe_unpermute(result4, result0, topk_weights, inv_permuted_idx, - expert_first_token_offset) + moe_unpermute( + result4, result0, topk_weights, inv_permuted_idx, expert_first_token_offset + ) - gold4 = torch_unpermute(result0, topk_weights, topk_ids, - token_expert_indices, inv_permuted_idx, - valid_row_idx, topk, n_local_expert) + gold4 = torch_unpermute( + result0, + topk_weights, + topk_ids, + token_expert_indices, + inv_permuted_idx, + valid_row_idx, + topk, + n_local_expert, + ) # check unpermuted hidden torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py index 61d3311cc162..83241c000850 100644 --- a/tests/kernels/moe/test_mxfp4_moe.py +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -10,28 +10,40 @@ import torch from packaging import version -from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 - QuarkLinearMethod, QuarkW4A4MXFP4) -from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 - QuarkW4A4MXFp4MoEMethod) +from vllm.model_executor.layers.quantization.quark.quark import ( + QuarkLinearMethod, + QuarkW4A4MXFP4, +) +from vllm.model_executor.layers.quantization.quark.quark_moe import ( + QuarkW4A4MXFp4MoEMethod, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( - importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') + importlib.metadata.version("amd-quark") +) >= version.parse("0.8.99") -TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda( -) and current_platform.is_device_capability(100) +TRTLLM_GEN_MXFP4_AVAILABLE = ( + current_platform.is_cuda() and current_platform.is_device_capability(100) +) -HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda() - and current_platform.is_device_capability(90) - and has_flashinfer()) +HOPPER_MXFP4_BF16_AVAILABLE = ( + current_platform.is_cuda() + and current_platform.is_device_capability(90) + and has_flashinfer() +) if TRTLLM_GEN_MXFP4_AVAILABLE: - from flashinfer import (fp4_quantize, mxfp8_quantize, - next_positive_power_of_2, - reorder_rows_for_gated_act_gemm, shuffle_matrix_a, - shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) + from flashinfer import ( + fp4_quantize, + mxfp8_quantize, + next_positive_power_of_2, + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + trtllm_fp4_block_scale_moe, + ) from flashinfer.fp4_quantization import 
nvfp4_block_scale_interleave from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices @@ -48,21 +60,25 @@ def enable_pickle(monkeypatch): monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") -@pytest.mark.parametrize('model_case', [ - ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1), - ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8), - ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1) -]) -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, - reason="amd-quark>=0.9 is not available") +@pytest.mark.parametrize( + "model_case", + [ + ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1), + ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8), + ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1), + ], +) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): if torch.cuda.device_count() < model_case.tp: - pytest.skip(f"This test requires >={model_case.tp} gpus, got only " - f"{torch.cuda.device_count()}") + pytest.skip( + f"This test requires >={model_case.tp} gpus, got only " + f"{torch.cuda.device_count()}" + ) - with vllm_runner(model_case.model_id, - tensor_parallel_size=model_case.tp, - load_format="dummy") as llm: + with vllm_runner( + model_case.model_id, tensor_parallel_size=model_case.tp, load_format="dummy" + ) as llm: def check_model(model): layer = model.model.layers[0] @@ -72,21 +88,16 @@ def check_model(model): assert isinstance(qkv_proj.quant_method, QuarkLinearMethod) assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4) - assert isinstance(layer.mlp.experts.quant_method, - QuarkW4A4MXFp4MoEMethod) + assert isinstance(layer.mlp.experts.quant_method, QuarkW4A4MXFp4MoEMethod) if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4": llm.apply_model(check_model) - output = llm.generate_greedy("Today I am in the French Alps and", - max_tokens=20) + output = llm.generate_greedy("Today I am in the French Alps and", max_tokens=20) assert output -def swiglu(x, - alpha: float = 1.702, - beta: float = 1.0, - limit: Optional[float] = None): +def swiglu(x, alpha: float = 1.702, beta: float = 1.0, limit: Optional[float] = None): # Note we add an extra bias of 1 to the linear layer x_glu, x_linear = torch.chunk(x, 2, dim=-1) if limit is not None: @@ -96,24 +107,19 @@ def swiglu(x, return out_glu * (x_linear + beta) -fp4_lookup_table = [ - 0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6 -] +fp4_lookup_table = [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6] def mxfp4_dequantize(x, scale): assert x.dtype == torch.uint8 x = x.view(torch.uint8).to(torch.int32) - x_unpacked = torch.zeros(*x.shape[:-1], - x.shape[-1] * 2, - dtype=torch.int32, - device=x.device) + x_unpacked = torch.zeros( + *x.shape[:-1], x.shape[-1] * 2, dtype=torch.int32, device=x.device + ) x_unpacked[..., 0::2].copy_(x & 0xF) x_unpacked[..., 1::2].copy_((x >> 4) & 0xF) - x_float = torch.zeros(x_unpacked.shape, - dtype=torch.float32, - device=x.device) + x_float = torch.zeros(x_unpacked.shape, dtype=torch.float32, device=x.device) for i, val in enumerate(fp4_lookup_table): x_float[x_unpacked == i] = val @@ -162,9 +168,10 @@ def reference_moe( t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias t = swiglu(t, alpha=alpha, beta=beta, limit=limit) - if act_type == 'mxfp8': - t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16), - is_sf_swizzled_layout=False) + if act_type == "mxfp8": + t_quantized, 
t_scale = mxfp8_quantize( + t.to(torch.bfloat16), is_sf_swizzled_layout=False + ) t = mxfp8_dequantize(t_quantized, t_scale) # MLP #2 mlp2_weight = w2[expert_indices, ...] @@ -221,37 +228,53 @@ def tg_mxfp4_moe( transpose_optimized: bool = False, ) -> torch.Tensor: sf_block_size = 32 - assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts - and w13_weight.shape[1] == intermediate_size * 2 - and w13_weight.shape[2] == hidden_size // 2) - assert (w13_weight_scale.dim() == 3 - and w13_weight_scale.shape[0] == num_experts - and w13_weight_scale.shape[1] == intermediate_size * 2 - and w13_weight_scale.shape[2] == hidden_size // sf_block_size) - assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts - and w2_weight.shape[1] == hidden_size - and w2_weight.shape[2] == intermediate_size // 2) - assert (w2_weight_scale.dim() == 3 - and w2_weight_scale.shape[1] == hidden_size - and w2_weight_scale.shape[2] == intermediate_size // sf_block_size) - assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts - and w13_bias.shape[1] == intermediate_size * 2) - assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts - and w2_bias.shape[1] == hidden_size) + assert ( + w13_weight.dim() == 3 + and w13_weight.shape[0] == num_experts + and w13_weight.shape[1] == intermediate_size * 2 + and w13_weight.shape[2] == hidden_size // 2 + ) + assert ( + w13_weight_scale.dim() == 3 + and w13_weight_scale.shape[0] == num_experts + and w13_weight_scale.shape[1] == intermediate_size * 2 + and w13_weight_scale.shape[2] == hidden_size // sf_block_size + ) + assert ( + w2_weight.dim() == 3 + and w2_weight.shape[0] == num_experts + and w2_weight.shape[1] == hidden_size + and w2_weight.shape[2] == intermediate_size // 2 + ) + assert ( + w2_weight_scale.dim() == 3 + and w2_weight_scale.shape[1] == hidden_size + and w2_weight_scale.shape[2] == intermediate_size // sf_block_size + ) + assert ( + w13_bias.dim() == 2 + and w13_bias.shape[0] == num_experts + and w13_bias.shape[1] == intermediate_size * 2 + ) + assert ( + w2_bias.dim() == 2 + and w2_bias.shape[0] == num_experts + and w2_bias.shape[1] == hidden_size + ) # Swap w1 and w3 as the definition of # swiglu is different in the trtllm-gen w13_weight_scale_ = w13_weight_scale.clone() w13_weight_ = w13_weight.clone() w13_bias_ = w13_bias.clone() - w13_weight[:, :intermediate_size, :].copy_( - w13_weight_[:, intermediate_size:, :]) - w13_weight[:, intermediate_size:, :].copy_( - w13_weight_[:, :intermediate_size, :]) + w13_weight[:, :intermediate_size, :].copy_(w13_weight_[:, intermediate_size:, :]) + w13_weight[:, intermediate_size:, :].copy_(w13_weight_[:, :intermediate_size, :]) w13_weight_scale[:, :intermediate_size, :].copy_( - w13_weight_scale_[:, intermediate_size:, :]) + w13_weight_scale_[:, intermediate_size:, :] + ) w13_weight_scale[:, intermediate_size:, :].copy_( - w13_weight_scale_[:, :intermediate_size, :]) + w13_weight_scale_[:, :intermediate_size, :] + ) w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:]) w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size]) @@ -261,18 +284,23 @@ def tg_mxfp4_moe( w13_bias_interleaved = [] for i in range(num_experts): w13_weight_interleaved.append( - reorder_rows_for_gated_act_gemm(w13_weight[i].clone())) + reorder_rows_for_gated_act_gemm(w13_weight[i].clone()) + ) w13_weight_scale_interleaved.append( - reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())) + reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()) + ) 
w13_bias_interleaved.append( - reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, - 1))) + reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1)) + ) w13_weight = torch.stack(w13_weight_interleaved).reshape( - num_experts, 2 * intermediate_size, hidden_size // 2) + num_experts, 2 * intermediate_size, hidden_size // 2 + ) w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape( - num_experts, 2 * intermediate_size, hidden_size // 32) + num_experts, 2 * intermediate_size, hidden_size // 32 + ) w13_bias = torch.stack(w13_bias_interleaved).reshape( - num_experts, 2 * intermediate_size) + num_experts, 2 * intermediate_size + ) # Shuffle weights and scaling factors for transposed mma output gemm1_weights_shuffled = [] @@ -291,9 +319,11 @@ def tg_mxfp4_moe( w13_weight[i].view(torch.uint8), epilogue_tile_m, ) - gemm1_weights_shuffled.append(w13_weight[i].view( - torch.uint8)[permute_indices.to( - w13_weight.device)].contiguous()) + gemm1_weights_shuffled.append( + w13_weight[i] + .view(torch.uint8)[permute_indices.to(w13_weight.device)] + .contiguous() + ) # w13 scale shuffling permute_sf_indices = _maybe_get_cached_w2_permute_indices( _cache_permute_indices, @@ -302,26 +332,35 @@ def tg_mxfp4_moe( num_elts_per_sf=16, ) gemm1_scales_shuffled.append( - nvfp4_block_scale_interleave(w13_weight_scale[i].view( - torch.uint8)[permute_sf_indices.to( - w13_weight_scale.device)].contiguous())) + nvfp4_block_scale_interleave( + w13_weight_scale[i] + .view(torch.uint8)[permute_sf_indices.to(w13_weight_scale.device)] + .contiguous() + ) + ) # w13 bias shuffling permute_bias_indices = _maybe_get_cached_w2_permute_indices( _cache_permute_indices, w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m, ) - gemm1_bias_shuffled.append(w13_bias[i].clone().reshape( - -1, 1)[permute_bias_indices.to(w13_bias.device)].contiguous()) + gemm1_bias_shuffled.append( + w13_bias[i] + .clone() + .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)] + .contiguous() + ) # w2 weight shuffling permute_indices = _maybe_get_cached_w2_permute_indices( _cache_permute_indices, w2_weight[i].view(torch.uint8), epilogue_tile_m, ) - gemm2_weights_shuffled.append(w2_weight[i].view( - torch.uint8)[permute_indices.to( - w2_weight.device)].contiguous()) + gemm2_weights_shuffled.append( + w2_weight[i] + .view(torch.uint8)[permute_indices.to(w2_weight.device)] + .contiguous() + ) # w2 scale shuffling permute_sf_indices = _maybe_get_cached_w2_permute_indices( _cache_permute_indices, @@ -330,48 +369,65 @@ def tg_mxfp4_moe( num_elts_per_sf=16, ) gemm2_scales_shuffled.append( - nvfp4_block_scale_interleave(w2_weight_scale[i].view( - torch.uint8)[permute_sf_indices.to( - w2_weight_scale.device)].contiguous())) + nvfp4_block_scale_interleave( + w2_weight_scale[i] + .view(torch.uint8)[permute_sf_indices.to(w2_weight_scale.device)] + .contiguous() + ) + ) # w2 bias shuffling permute_indices = _maybe_get_cached_w2_permute_indices( _cache_permute_indices, w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m, ) - gemm2_bias_shuffled.append(w2_bias[i].clone().reshape( - -1, 1)[permute_indices.to(w2_bias.device)].contiguous()) + gemm2_bias_shuffled.append( + w2_bias[i] + .clone() + .reshape(-1, 1)[permute_indices.to(w2_bias.device)] + .contiguous() + ) else: for i in range(num_experts): gemm1_weights_shuffled.append( - shuffle_matrix_a(w13_weight[i].view(torch.uint8), - epilogue_tile_m)) + shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m) + ) gemm1_scales_shuffled.append( - 
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) + shuffle_matrix_sf_a( + w13_weight_scale[i].view(torch.uint8), epilogue_tile_m + ) + ) gemm2_weights_shuffled.append( - shuffle_matrix_a(w2_weight[i].view(torch.uint8), - epilogue_tile_m)) + shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m) + ) gemm2_scales_shuffled.append( - shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8), - epilogue_tile_m)) + shuffle_matrix_sf_a( + w2_weight_scale[i].view(torch.uint8), epilogue_tile_m + ) + ) gemm1_bias_shuffled.append( - shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)) + shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m) + ) gemm2_bias_shuffled.append( - shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)) + shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m) + ) w13_weight = torch.stack(gemm1_weights_shuffled) - w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape( - num_experts, 2 * intermediate_size, - hidden_size // sf_block_size).view(torch.float8_e4m3fn) + w13_weight_scale = ( + torch.stack(gemm1_scales_shuffled) + .reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size) + .view(torch.float8_e4m3fn) + ) w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1) w2_weight = torch.stack(gemm2_weights_shuffled) - w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape( - num_experts, hidden_size, - intermediate_size // sf_block_size).view(torch.float8_e4m3fn) + w2_weight_scale = ( + torch.stack(gemm2_scales_shuffled) + .reshape(num_experts, hidden_size, intermediate_size // sf_block_size) + .view(torch.float8_e4m3fn) + ) w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1) tg_result = trtllm_fp4_block_scale_moe( @@ -401,7 +457,8 @@ def tg_mxfp4_moe( routed_scaling_factor=None, tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts), routing_method_type=1, # renormalize - do_finalize=True)[0] + do_finalize=True, + )[0] return tg_result @@ -424,20 +481,21 @@ def check_accuracy(a, b, atol, rtol, percent): if mismatch_percent > 1 - percent: raise Exception( f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} " - f"(threshold: {1-percent:.4f})") + f"(threshold: {1 - percent:.4f})" + ) @pytest.mark.parametrize("topk", [1, 4]) @pytest.mark.parametrize("num_experts", [32, 128]) @pytest.mark.parametrize("num_tokens", [1, 128, 1024]) @pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) -@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), - (1.702, 1.0, 7.0)]) -@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16']) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)]) +@pytest.mark.parametrize("act_type", ["mxfp8", "bf16"]) @pytest.mark.parametrize("transpose_optimized", [False, True]) @pytest.mark.skipif( not TRTLLM_GEN_MXFP4_AVAILABLE, - reason="nvidia gpu and compute capability sm100 is required for this test") + reason="nvidia gpu and compute capability sm100 is required for this test", +) def test_trtllm_gen_mxfp4_fused_moe( topk: int, num_experts: int, @@ -452,45 +510,52 @@ def test_trtllm_gen_mxfp4_fused_moe( ): seed = 42 torch.manual_seed(seed) - hidden_states = torch.randn(num_tokens, - hidden_size, - device="cuda:0", - dtype=torch.bfloat16) - w13 = (torch.randn(num_experts, - intermediate_size * 2, - hidden_size, - device="cuda:0", - dtype=torch.bfloat16)) - w2 = (torch.randn(num_experts, - hidden_size, - intermediate_size, - device="cuda:0", - 
dtype=torch.bfloat16)) - bias13 = torch.randn(num_experts, intermediate_size * 2, - device="cuda:0") * 10 + hidden_states = torch.randn( + num_tokens, hidden_size, device="cuda:0", dtype=torch.bfloat16 + ) + w13 = torch.randn( + num_experts, + intermediate_size * 2, + hidden_size, + device="cuda:0", + dtype=torch.bfloat16, + ) + w2 = torch.randn( + num_experts, + hidden_size, + intermediate_size, + device="cuda:0", + dtype=torch.bfloat16, + ) + bias13 = torch.randn(num_experts, intermediate_size * 2, device="cuda:0") * 10 bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10 - router_logits = torch.rand(num_tokens, num_experts, - dtype=torch.float32).cuda() - - w13, w13_scale = fp4_quantize(w13, - torch.tensor(1.0, device="cuda:0"), - 32, - sf_use_ue8m0=True, - is_sf_swizzled_layout=False) + router_logits = torch.rand(num_tokens, num_experts, dtype=torch.float32).cuda() + + w13, w13_scale = fp4_quantize( + w13, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False, + ) w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape( - num_experts, intermediate_size * 2, hidden_size // 32) - w2, w2_scale = fp4_quantize(w2, - torch.tensor(1.0, device="cuda:0"), - 32, - sf_use_ue8m0=True, - is_sf_swizzled_layout=False) + num_experts, intermediate_size * 2, hidden_size // 32 + ) + w2, w2_scale = fp4_quantize( + w2, + torch.tensor(1.0, device="cuda:0"), + 32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=False, + ) w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape( - num_experts, hidden_size, intermediate_size // 32) - if act_type == 'mxfp8': + num_experts, hidden_size, intermediate_size // 32 + ) + if act_type == "mxfp8": hidden_states, hidden_states_scale = mxfp8_quantize( - hidden_states, is_sf_swizzled_layout=False) - hidden_states_scale = hidden_states_scale.view( - torch.float8_e4m3fn).reshape(-1) + hidden_states, is_sf_swizzled_layout=False + ) + hidden_states_scale = hidden_states_scale.view(torch.float8_e4m3fn).reshape(-1) else: hidden_states_scale = None @@ -500,9 +565,10 @@ def test_trtllm_gen_mxfp4_fused_moe( w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone()) bias13_ref = bias13 bias2_ref = bias2 - if act_type == 'mxfp8': - hidden_states_ref = mxfp8_dequantize( - hidden_states, hidden_states_scale).to(torch.float32) + if act_type == "mxfp8": + hidden_states_ref = mxfp8_dequantize(hidden_states, hidden_states_scale).to( + torch.float32 + ) else: hidden_states_ref = hidden_states.to(torch.float32) # Process tokens in chunks of 32 to reduce memory usage @@ -529,29 +595,31 @@ def test_trtllm_gen_mxfp4_fused_moe( # trtllm-gen result if alpha is not None: - alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) + alpha = torch.full((num_experts,), alpha, device=hidden_states.device) if limit is not None: - limit = torch.full((num_experts, ), limit, device=hidden_states.device) + limit = torch.full((num_experts,), limit, device=hidden_states.device) if beta is not None: - beta = torch.full((num_experts, ), beta, device=hidden_states.device) - tg_result = tg_mxfp4_moe(router_logits, - topk, - num_experts, - intermediate_size, - hidden_size, - hidden_states, - hidden_states_scale, - w13, - w13_scale, - bias13, - w2, - w2_scale, - bias2, - act_type, - alpha=alpha, - beta=beta, - limit=limit, - transpose_optimized=transpose_optimized) + beta = torch.full((num_experts,), beta, device=hidden_states.device) + tg_result = tg_mxfp4_moe( + router_logits, + topk, + num_experts, + intermediate_size, + hidden_size, + hidden_states, 
+ hidden_states_scale, + w13, + w13_scale, + bias13, + w2, + w2_scale, + bias2, + act_type, + alpha=alpha, + beta=beta, + limit=limit, + transpose_optimized=transpose_optimized, + ) # relatively loose check since the mxfp4 quantization is less accurate check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8) @@ -573,8 +641,7 @@ def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("num_tokens", [1, 128]) @pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) -@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), - (1.702, 1.0, 7.0)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)]) @pytest.mark.skipif( not HOPPER_MXFP4_BF16_AVAILABLE, reason="nvidia gpu sm90 and flashinfer are required for this test", @@ -593,52 +660,73 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( device = "cuda:0" # Inputs - hidden_states = torch.randn(num_tokens, - hidden_size, - device=device, - dtype=torch.bfloat16) + hidden_states = torch.randn( + num_tokens, hidden_size, device=device, dtype=torch.bfloat16 + ) # Random MXFP4 weights and scales (uint8), contiguous [w1; w3] w13_q = torch.randint( 0, - 256, (num_experts, 2 * intermediate_size, hidden_size // 2), + 256, + (num_experts, 2 * intermediate_size, hidden_size // 2), device=device, - dtype=torch.uint8) + dtype=torch.uint8, + ) w13_scale = torch.randint( 118, - 123, (num_experts, 2 * intermediate_size, hidden_size // 32), + 123, + (num_experts, 2 * intermediate_size, hidden_size // 32), device=device, - dtype=torch.uint8) + dtype=torch.uint8, + ) - w2_q = torch.randint(0, - 256, - (num_experts, hidden_size, intermediate_size // 2), - device=device, - dtype=torch.uint8) + w2_q = torch.randint( + 0, + 256, + (num_experts, hidden_size, intermediate_size // 2), + device=device, + dtype=torch.uint8, + ) w2_scale = torch.randint( 118, - 123, (num_experts, hidden_size, intermediate_size // 32), + 123, + (num_experts, hidden_size, intermediate_size // 32), device=device, - dtype=torch.uint8) + dtype=torch.uint8, + ) # Bias contiguous [b1; b3] - bias13 = (torch.randn(num_experts, - 2 * intermediate_size, - device=device, - dtype=torch.bfloat16) * 10) - bias2 = (torch.randn( - num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10) - router_logits = torch.rand(num_tokens, - num_experts, - dtype=torch.float32, - device=device) + bias13 = ( + torch.randn( + num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16 + ) + * 10 + ) + bias2 = ( + torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10 + ) + router_logits = torch.rand( + num_tokens, num_experts, dtype=torch.float32, device=device + ) w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape( - num_experts, 2 * intermediate_size, hidden_size) + num_experts, 2 * intermediate_size, hidden_size + ) w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape( - num_experts, hidden_size, intermediate_size) - ref = reference_moe(router_logits.to(torch.float32), topk, num_experts, - hidden_states.to(torch.float32), w13_ref, - bias13.to(torch.float32), w2_ref, - bias2.to(torch.float32), alpha, beta, limit, 'bf16') + num_experts, hidden_size, intermediate_size + ) + ref = reference_moe( + router_logits.to(torch.float32), + topk, + num_experts, + hidden_states.to(torch.float32), + w13_ref, + bias13.to(torch.float32), + w2_ref, + bias2.to(torch.float32), + alpha, + beta, + limit, + "bf16", 
+ ) from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe @@ -654,23 +742,24 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( w13_s_inter = _interleave_scales_lastdim_by4(w13_s) w2_s_inter = _interleave_scales_lastdim_by4(w2_scale) - routing_weights = torch.nn.functional.softmax(router_logits, - dim=1, - dtype=torch.float32) - token_final_scales, token_selected_experts = torch.topk(routing_weights, - topk, - dim=-1) - token_final_scales = (token_final_scales / - token_final_scales.sum(dim=-1, keepdim=True)) + routing_weights = torch.nn.functional.softmax( + router_logits, dim=1, dtype=torch.float32 + ) + token_final_scales, token_selected_experts = torch.topk( + routing_weights, topk, dim=-1 + ) + token_final_scales = token_final_scales / token_final_scales.sum( + dim=-1, keepdim=True + ) token_selected_experts = token_selected_experts.to(torch.int).contiguous() out = torch.empty_like(hidden_states, dtype=torch.bfloat16) if alpha is not None: - alpha = torch.full((num_experts, ), alpha, device=hidden_states.device) + alpha = torch.full((num_experts,), alpha, device=hidden_states.device) if beta is not None: - beta = torch.full((num_experts, ), beta, device=hidden_states.device) + beta = torch.full((num_experts,), beta, device=hidden_states.device) if limit is not None: - limit = torch.full((num_experts, ), limit, device=hidden_states.device) + limit = torch.full((num_experts,), limit, device=hidden_states.device) _ = flashinfer_cutlass_fused_moe( input=hidden_states, @@ -680,8 +769,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( fc2_expert_weights=w2_q, output_dtype=torch.bfloat16, output=out, - quant_scales=[w13_s_inter.to(torch.uint8), - w2_s_inter.to(torch.uint8)], + quant_scales=[w13_s_inter.to(torch.uint8), w2_s_inter.to(torch.uint8)], fc1_expert_biases=w13_b, fc2_expert_biases=bias2.to(torch.bfloat16), swiglu_alpha=alpha, @@ -702,11 +790,13 @@ def test_flashinfer_cutlass_mxfp4_fused_moe( @pytest.mark.parametrize("num_experts", [32]) @pytest.mark.parametrize("num_tokens", [1, 128]) @pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)]) -@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), - (1.702, 1.0, 7.0)]) +@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None), (1.702, 1.0, 7.0)]) @pytest.mark.skipif( - not (current_platform.is_cuda() - and current_platform.is_device_capability(100) and has_flashinfer()), + not ( + current_platform.is_cuda() + and current_platform.is_device_capability(100) + and has_flashinfer() + ), reason="NVIDIA GPU sm100 and flashinfer are required for this test", ) def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe( @@ -723,32 +813,43 @@ def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe( device = "cuda:0" # Inputs - hidden_states = torch.randn(num_tokens, - hidden_size, - device=device, - dtype=torch.bfloat16) + hidden_states = torch.randn( + num_tokens, hidden_size, device=device, dtype=torch.bfloat16 + ) # Float weights in w13 format [w1; w3] - w13 = (torch.randn(num_experts, - 2 * intermediate_size, - hidden_size, - device=device, - dtype=torch.bfloat16) / 10) - w2 = (torch.randn(num_experts, - hidden_size, - intermediate_size, - device=device, - dtype=torch.bfloat16) / 10) + w13 = ( + torch.randn( + num_experts, + 2 * intermediate_size, + hidden_size, + device=device, + dtype=torch.bfloat16, + ) + / 10 + ) + w2 = ( + torch.randn( + num_experts, + hidden_size, + intermediate_size, + device=device, + dtype=torch.bfloat16, + ) + / 10 + ) # Bias contiguous [b1; b3] - bias13 = (torch.randn(num_experts, - 2 * 
intermediate_size, - device=device, - dtype=torch.bfloat16) * 10) - bias2 = (torch.randn( - num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10) - router_logits = torch.rand(num_tokens, - num_experts, - dtype=torch.float32, - device=device) + bias13 = ( + torch.randn( + num_experts, 2 * intermediate_size, device=device, dtype=torch.bfloat16 + ) + * 10 + ) + bias2 = ( + torch.randn(num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10 + ) + router_logits = torch.rand( + num_tokens, num_experts, dtype=torch.float32, device=device + ) # Quantize weights to MXFP4 per expert (SM100 path) from flashinfer import mxfp4_quantize @@ -761,36 +862,56 @@ def quant_mxfp4_batches(a: torch.Tensor, e: int): sfs.append(sf) return torch.stack(qs), torch.stack(sfs) - def dequant_mxfp4_batches(mat_fp4: torch.Tensor, - scale_tensor: torch.Tensor): + def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor): num_batches = mat_fp4.size(0) scale_tensor = scale_tensor.view(num_batches, -1) from flashinfer import mxfp4_dequantize - return torch.stack([ - mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :]) - for b in range(num_batches) - ]) + + return torch.stack( + [ + mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :]) + for b in range(num_batches) + ] + ) w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts) w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts) # Reference result using dequantized tensors and reference_moe - w13_ref = dequant_mxfp4_batches( - w13_q.view(torch.uint8), - w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape( - num_experts, 2 * intermediate_size, hidden_size).to(device) - w2_ref = dequant_mxfp4_batches( - w2_q.view(torch.uint8), - w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape( - num_experts, hidden_size, intermediate_size).to(device) + w13_ref = ( + dequant_mxfp4_batches( + w13_q.view(torch.uint8), w13_scale.view(torch.uint8).reshape(-1) + ) + .to(torch.float32) + .reshape(num_experts, 2 * intermediate_size, hidden_size) + .to(device) + ) + w2_ref = ( + dequant_mxfp4_batches( + w2_q.view(torch.uint8), w2_scale.view(torch.uint8).reshape(-1) + ) + .to(torch.float32) + .reshape(num_experts, hidden_size, intermediate_size) + .to(device) + ) # Quantize activations for SM100 path and dequantize for reference hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32) # Reference uses BF16 input but quantizes intermediate activation to MXFP8 - ref = reference_moe(router_logits.to(torch.float32), topk, num_experts, - hidden_states.to(torch.float32), w13_ref, - bias13.to(torch.float32), w2_ref, - bias2.to(torch.float32), alpha, beta, limit, 'mxfp8') + ref = reference_moe( + router_logits.to(torch.float32), + topk, + num_experts, + hidden_states.to(torch.float32), + w13_ref, + bias13.to(torch.float32), + w2_ref, + bias2.to(torch.float32), + alpha, + beta, + limit, + "mxfp8", + ) # Prepare inputs for FlashInfer CUTLASS fused MoE from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe @@ -807,31 +928,28 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor, w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16) # Build routing for kernel - routing_weights = torch.nn.functional.softmax(router_logits, - dim=1, - dtype=torch.float32) - token_final_scales, token_selected_experts = torch.topk(routing_weights, - topk, - dim=-1) - token_final_scales = (token_final_scales / - token_final_scales.sum(dim=-1, keepdim=True)) + routing_weights = torch.nn.functional.softmax( + router_logits, dim=1, 
dtype=torch.float32 + ) + token_final_scales, token_selected_experts = torch.topk( + routing_weights, topk, dim=-1 + ) + token_final_scales = token_final_scales / token_final_scales.sum( + dim=-1, keepdim=True + ) token_selected_experts = token_selected_experts.to(torch.int).contiguous() out = torch.empty_like(hidden_states, dtype=torch.bfloat16) if alpha is not None: - alpha_t = torch.full((num_experts, ), - alpha, - device=hidden_states.device) + alpha_t = torch.full((num_experts,), alpha, device=hidden_states.device) else: alpha_t = None if beta is not None: - beta_t = torch.full((num_experts, ), beta, device=hidden_states.device) + beta_t = torch.full((num_experts,), beta, device=hidden_states.device) else: beta_t = None if limit is not None: - limit_t = torch.full((num_experts, ), - limit, - device=hidden_states.device) + limit_t = torch.full((num_experts,), limit, device=hidden_states.device) else: limit_t = None diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index a48bfeb10b2e..dae19c0b2b31 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -4,9 +4,11 @@ import torch from tests.kernels.moe.utils import make_test_weights -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) from tests.kernels.utils import torch_moe from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config @@ -16,8 +18,9 @@ from vllm.platforms import current_platform if not current_platform.has_device_capability(100): - pytest.skip("Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + "Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True + ) MNK_FACTORS = [ (2, 1024, 1024), @@ -38,36 +41,34 @@ @pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @torch.inference_mode() -def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype): +def test_cutlass_fp4_moe_no_graph( + m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype +): current_platform.seed_everything(7) with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): quant_blocksize = 16 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - (_, w1_q, w1_blockscale, - w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights( - e, - n, - k, - in_dtype=dtype, - quant_dtype="nvfp4", - block_shape=None, # use quant_blocksize? - per_out_ch_quant=False, - ) + (_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = ( + make_test_weights( + e, + n, + k, + in_dtype=dtype, + quant_dtype="nvfp4", + block_shape=None, # use quant_blocksize? 
+ per_out_ch_quant=False, + ) + ) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) - a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) - a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32) assert w1_gs is not None assert w2_gs is not None @@ -97,40 +98,44 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, ) # Reference check: - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a.flatten(), dim=-1) + ).to(torch.float32) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_scale_interleaved, - a_global_scale, - dtype=a.dtype, - device=a.device, - block_size=quant_blocksize) + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize, + ) w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) for idx in range(0, e): - w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], - w1_blockscale[idx], - w1_gs[idx], - dtype=dtype, - device=w1_q.device, - block_size=quant_blocksize) - w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], - w2_blockscale[idx], - w2_gs[idx], - dtype=dtype, - device=w2_q.device, - block_size=quant_blocksize) + w1_d[idx] = dequantize_nvfp4_to_dtype( + w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=dtype, + device=w1_q.device, + block_size=quant_blocksize, + ) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=dtype, + device=w2_q.device, + block_size=quant_blocksize, + ) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) - torch.testing.assert_close(torch_output, - cutlass_output, - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1) if __name__ == "__main__": diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 59126cef6adb..4c7c6c6a4f52 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -9,13 +9,10 @@ from tests.kernels.utils import torch_experts from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.config import ( - fp8_w8a8_moe_quant_config) -from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassBatchedExpertsFp8) +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassBatchedExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.platforms import current_platform from vllm.utils import cdiv @@ -24,9 +21,13 @@ try: from pplx_kernels import AllToAll - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - 
nvshmem_init) + from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, + ) + has_pplx = True except ImportError: has_pplx = False @@ -50,12 +51,12 @@ def chunk_by_rank(t, r, w): chunk = rank_chunk(num, r, w) rem = num % w if rem == 0 or r < rem: - return t[(r * chunk):(r + 1) * chunk].contiguous() + return t[(r * chunk) : (r + 1) * chunk].contiguous() else: long_chunks = (num // w + 1) * rem short_chunks = (r - rem) * chunk start = long_chunks + short_chunks - return t[start:start + chunk].contiguous() + return t[start : start + chunk].contiguous() def pplx_cutlass_moe( @@ -75,7 +76,9 @@ def pplx_cutlass_moe( group_name: Optional[str], ): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize) + PplxPrepareAndFinalize, + ) + assert torch.cuda.current_device() == pgi.local_rank num_tokens, hidden_dim = a.shape @@ -126,35 +129,40 @@ def pplx_cutlass_moe( ata, max_num_tokens=max_num_tokens, num_local_experts=num_local_experts, - num_dispatchers=num_dispatchers) - - ab_strides1 = torch.full((num_local_experts, ), - hidden_dim, - device="cuda", - dtype=torch.int64) - ab_strides2 = torch.full((num_local_experts, ), - intermediate_dim, - device="cuda", - dtype=torch.int64) - c_strides1 = torch.full((num_local_experts, ), - 2 * intermediate_dim, - device="cuda", - dtype=torch.int64) - c_strides2 = torch.full((num_local_experts, ), - hidden_dim, - device="cuda", - dtype=torch.int64) + num_dispatchers=num_dispatchers, + ) + + ab_strides1 = torch.full( + (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64 + ) + ab_strides2 = torch.full( + (num_local_experts,), intermediate_dim, device="cuda", dtype=torch.int64 + ) + c_strides1 = torch.full( + (num_local_experts,), 2 * intermediate_dim, device="cuda", dtype=torch.int64 + ) + c_strides2 = torch.full( + (num_local_experts,), hidden_dim, device="cuda", dtype=torch.int64 + ) experts = CutlassBatchedExpertsFp8( - num_local_experts, num_dispatchers, out_dtype, ab_strides1, - ab_strides2, c_strides1, c_strides2, + num_local_experts, + num_dispatchers, + out_dtype, + ab_strides1, + ab_strides2, + c_strides1, + c_strides2, fp8_w8a8_moe_quant_config( per_act_token_quant=per_act_token, per_out_ch_quant=per_out_ch, w1_scale=chunk_by_rank(w1_scale, rank, world_size), w2_scale=chunk_by_rank(w2_scale, rank, world_size), a1_scale=chunk_by_rank(a1_scale, rank, world_size) - if per_act_token else a1_scale[rank])) + if per_act_token + else a1_scale[rank], + ), + ) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, @@ -162,10 +170,10 @@ def pplx_cutlass_moe( ) a_chunk = chunk_by_rank(a, rank, world_size).to(device) - chunk_topk_weight = chunk_by_rank(topk_weights, rank, - world_size).to(device) - chunk_topk_ids = chunk_by_rank(topk_ids, rank, - world_size).to(torch.uint32).to(device) + chunk_topk_weight = chunk_by_rank(topk_weights, rank, world_size).to(device) + chunk_topk_ids = ( + chunk_by_rank(topk_ids, rank, world_size).to(torch.uint32).to(device) + ) out = fused_cutlass_experts( a_chunk, @@ -174,7 +182,7 @@ def pplx_cutlass_moe( chunk_topk_weight, chunk_topk_ids, global_num_experts=num_experts, - expert_map=None, #TODO + expert_map=None, # TODO ) torch.cuda.synchronize() @@ -210,35 +218,48 @@ def _pplx_moe( ): try: if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = ( + nvshmem_get_unique_id() + if pgi.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) 
torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) else: group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, - backend="gloo") + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name with set_current_vllm_config(vllm_config): - torch_output = torch_experts(a_full, w1_full, w2_full, - topk_weights, topk_ids) - pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, - w2_scale, topk_weights, topk_ids, - a1_scale, out_dtype, per_act_token, - per_out_ch, group_name) - - torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pplx_output.device) + torch_output = torch_experts( + a_full, w1_full, w2_full, topk_weights, topk_ids + ) + pplx_output = pplx_cutlass_moe( + pgi, + dp_size, + a, + w1, + w2, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + a1_scale, + out_dtype, + per_act_token, + per_out_ch, + group_name, + ) + + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to( + pplx_output.device + ) # Uncomment if more debugging is needed # print("PPLX OUT:", pplx_output) # print("TORCH OUT:", torch_output) - torch.testing.assert_close(pplx_output, - torch_output, - atol=0.05, - rtol=0) + torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0) finally: if use_internode: nvshmem_finalize() @@ -251,13 +272,15 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) -@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) # , [4, 2]]) @pytest.mark.parametrize("use_internode", [False]) @multi_gpu_test(num_gpus=2) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) @requires_pplx def test_cutlass_moe_pplx( m: int, @@ -273,7 +296,6 @@ def test_cutlass_moe_pplx( current_platform.seed_everything(7) with set_current_vllm_config(vllm_config): - dtype = torch.half a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0 @@ -283,22 +305,18 @@ def test_cutlass_moe_pplx( n_b_scales = 2 * n if per_out_ch else 1 k_b_scales = k if per_out_ch else 1 - w1_q = torch.empty((e, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn) w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((e, n_b_scales, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((e, k_b_scales, 1), - device="cuda", - dtype=torch.float32) + w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32) for expert in range(e): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( - w1[expert], use_per_token_if_dynamic=per_out_ch) + w1[expert], use_per_token_if_dynamic=per_out_ch + ) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( - w2[expert], use_per_token_if_dynamic=per_out_ch) + w2[expert], use_per_token_if_dynamic=per_out_ch + ) w1_d = torch.empty_like(w1) w2_d = torch.empty_like(w2) @@ -307,19 +325,35 @@ def test_cutlass_moe_pplx( w2_d[expert] = (w2_q[expert].float() * 
w2_scale[expert]).half() score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = fused_topk(a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) world_size, dp_size = world_dp_size - a_scale1 = torch.randn( - (m if per_act_token else 1, 1), device="cuda", - dtype=torch.float32) / 10.0 + a_scale1 = ( + torch.randn( + (m if per_act_token else 1, 1), device="cuda", dtype=torch.float32 + ) + / 10.0 + ) if not per_act_token: a_scale1 = a_scale1.repeat(world_size, 1) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q, - w1_scale, w2_scale, topk_weights, topk_ids, a_scale1, - dtype, a, w1_d, w2_d, per_act_token, per_out_ch, - use_internode) + parallel_launch( + world_size, + _pplx_moe, + dp_size, + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + a_scale1, + dtype, + a, + w1_d, + w2_d, + per_act_token, + per_out_ch, + use_internode, + ) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 4ca4a1e79c57..223f095c0b55 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -4,6 +4,7 @@ Run `pytest tests/kernels/test_pplx_moe.py`. """ + import copy import itertools import textwrap @@ -15,29 +16,34 @@ try: from pplx_kernels import AllToAll - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_finalize, nvshmem_get_unique_id, - nvshmem_init) + from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_finalize, + nvshmem_get_unique_id, + nvshmem_init, + ) + has_pplx = True except ImportError: has_pplx = False -from tests.kernels.moe.modular_kernel_tools.parallel_utils import ( - _set_vllm_config) -from tests.kernels.moe.utils import (make_shared_experts, make_test_weights, - naive_batched_moe) +from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config +from tests.kernels.moe.utils import ( + make_shared_experts, + make_test_weights, + naive_batched_moe, +) from tests.kernels.quant_utils import dequant from tests.kernels.utils import torch_experts from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_topk, override_config from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, +) from vllm.platforms import current_platform from vllm.utils import round_up @@ -59,7 +65,7 @@ PPLX_COMBOS = [ # TODO(bnell): figure out why this fails, seems to be test problem - #(1, 128, 128), + # (1, 128, 128), (2, 128, 512), (3, 1024, 2048), (4, 128, 128), @@ -91,17 +97,16 @@ def torch_prepare( num_tokens, hidden_dim = a.shape topk = topk_ids.shape[1] - tokens_per_expert = torch.bincount(topk_ids.view(-1), - minlength=num_experts) + tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts) assert tokens_per_expert.numel() == num_experts if max_num_tokens is None: max_num_tokens = 
int(tokens_per_expert.max().item()) - b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim), - dtype=a.dtype, - device=a.device) + b_a = torch.zeros( + (num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device + ) token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) @@ -109,28 +114,29 @@ def torch_prepare( for j in range(topk): expert_id = topk_ids[token, j] idx = token_counts[expert_id] - b_a[expert_id, idx:idx + 1, :] = a[token, :] + b_a[expert_id, idx : idx + 1, :] = a[token, :] token_counts[expert_id] = token_counts[expert_id] + 1 return b_a, tokens_per_expert -def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor, - topk_ids: torch.Tensor) -> torch.Tensor: +def torch_finalize( + b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor +) -> torch.Tensor: num_tokens = topk_ids.shape[0] num_experts = b_out.shape[0] K = b_out.shape[-1] out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) - expert_counts = torch.zeros(num_experts, - dtype=torch.int, - device=b_out.device) + expert_counts = torch.zeros(num_experts, dtype=torch.int, device=b_out.device) for token in range(num_tokens): expert_ids = topk_ids[token] for i in range(expert_ids.numel()): expert_id = expert_ids[i] idx = expert_counts[expert_id] - out[token, :] = out[token, :] + b_out[expert_id, idx:idx + - 1, :] * topk_weight[token, i] + out[token, :] = ( + out[token, :] + + b_out[expert_id, idx : idx + 1, :] * topk_weight[token, i] + ) expert_counts[expert_id] = expert_counts[expert_id] + 1 return out @@ -149,17 +155,18 @@ def torch_batched_moe( num_tokens, topk = topk_ids.shape _, max_num_tokens, K = b_a.shape assert num_experts == b_a.shape[0] and w2.shape[1] == K - out = torch.zeros((num_experts, max_num_tokens, K), - dtype=b_a.dtype, - device=b_a.device) - tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), - dtype=b_a.dtype, - device=b_a.device) + out = torch.zeros( + (num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device + ) + tmp = torch.empty( + (max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device + ) for expert in range(num_experts): num = tokens_per_expert[expert] if num > 0: torch.ops._C.silu_and_mul( - tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)) + tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1) + ) out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) return torch_finalize(out, topk_weight, topk_ids) @@ -186,20 +193,16 @@ def test_fused_moe_batched_experts( with set_current_vllm_config(vllm_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - baseline_output = torch_experts(a, w1, w2, topk_weight, - topk_ids) # only for baseline + baseline_output = torch_experts( + a, w1, w2, topk_weight, topk_ids + ) # only for baseline torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) batched_output = naive_batched_moe( - a, w1, w2, topk_weight, topk_ids) # pick torch_experts or this + a, w1, w2, topk_weight, topk_ids + ) # pick torch_experts or this - torch.testing.assert_close(baseline_output, - torch_output, - atol=2e-2, - rtol=0) - torch.testing.assert_close(baseline_output, - batched_output, - atol=2e-2, - rtol=0) + torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0) def create_pplx_prepare_finalize( @@ -217,7 +220,9 @@ def create_pplx_prepare_finalize( group_name: Optional[str], ): from 
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( - PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes) + PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes, + ) max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1) num_local_experts = rank_chunk(num_experts, 0, world_size) @@ -266,28 +271,31 @@ def rank_chunk(num: int, r: int, w: int) -> int: def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: chunk = rank_chunk(t.shape[0], r, w) - return t[(r * chunk):(r + 1) * chunk] + return t[(r * chunk) : (r + 1) * chunk] -def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int, - w: int) -> Optional[torch.Tensor]: +def maybe_chunk_by_rank( + t: Optional[torch.Tensor], r: int, w: int +) -> Optional[torch.Tensor]: if t is not None: return chunk_by_rank(t, r, w) else: return t -def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int, - w: int) -> Optional[torch.Tensor]: +def chunk_scales_by_rank( + t: Optional[torch.Tensor], r: int, w: int +) -> Optional[torch.Tensor]: if t is not None and t.numel() > 1: chunk = rank_chunk(t.shape[0], r, w) - return t[(r * chunk):(r + 1) * chunk] + return t[(r * chunk) : (r + 1) * chunk] else: return t -def chunk_scales(t: Optional[torch.Tensor], start: int, - end: int) -> Optional[torch.Tensor]: +def chunk_scales( + t: Optional[torch.Tensor], start: int, end: int +) -> Optional[torch.Tensor]: if t is not None and t.numel() > 1: return t[start:end] else: @@ -350,8 +358,7 @@ def pplx_prepare_finalize( device=device, ) - if (quant_dtype is not None and not per_act_token_quant - and block_shape is None): + if quant_dtype is not None and not per_act_token_quant and block_shape is None: a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) else: @@ -375,8 +382,7 @@ def pplx_prepare_finalize( ), ) - b_a = dummy_work( - dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) + b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)) prepare_finalize.finalize( out, @@ -410,15 +416,17 @@ def _pplx_prepare_finalize( ): try: if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = ( + nvshmem_get_unique_id() + if pgi.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) group_name = None else: group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, - backend="gloo") + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) @@ -426,22 +434,28 @@ def _pplx_prepare_finalize( a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0) - torch_output = (a_rep.view(m, topk, k) * - topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum( - dim=1) - - pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, - topk_ids, num_experts, quant_dtype, - block_shape, per_act_token_quant, - group_name) + torch_output = ( + a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype) + ).sum(dim=1) + + pplx_output = pplx_prepare_finalize( + pgi, + dp_size, + a, + topk_weight, + topk_ids, + num_experts, + quant_dtype, + block_shape, + per_act_token_quant, + group_name, + ) - torch_output = chunk_by_rank(torch_output, pgi.rank, - pgi.world_size).to(pgi.device) + torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to( + 
pgi.device + ) - torch.testing.assert_close(pplx_output, - torch_output, - atol=3e-2, - rtol=3e-2) + torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2) finally: if use_internode: nvshmem_finalize() @@ -491,9 +505,19 @@ def test_pplx_prepare_finalize_slow( a = torch.randn((m, k), device=device, dtype=act_dtype) / 10 score = torch.randn((m, e), device=device, dtype=act_dtype) - parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, - topk, e, quant_dtype, block_shape, per_act_token_quant, - use_internode) + parallel_launch( + world_size, + _pplx_prepare_finalize, + dp_size, + a, + score, + topk, + e, + quant_dtype, + block_shape, + per_act_token_quant, + use_internode, + ) def pplx_moe( @@ -517,7 +541,6 @@ def pplx_moe( use_cudagraphs: bool = True, shared_experts: Optional[torch.nn.Module] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - num_tokens, hidden_dim = a.shape num_experts = w1.shape[0] topk = topk_ids.shape[1] @@ -579,21 +602,23 @@ def pplx_moe( # large enough to trigger chunking. I'm leaving the flag and # setup code in case we are able to revisit this later. if use_compile: - _fused_experts = torch.compile(fused_experts, - backend='inductor', - fullgraph=True) + _fused_experts = torch.compile( + fused_experts, backend="inductor", fullgraph=True + ) torch._dynamo.mark_dynamic(a_chunk, 0) torch._dynamo.mark_dynamic(chunk_topk_weight, 0) torch._dynamo.mark_dynamic(chunk_topk_ids, 0) else: _fused_experts = fused_experts - out = _fused_experts(a_chunk, - w1_chunk, - w2_chunk, - chunk_topk_weight, - chunk_topk_ids, - global_num_experts=num_experts) + out = _fused_experts( + a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts, + ) if use_cudagraphs: if isinstance(out, tuple): @@ -604,12 +629,14 @@ def pplx_moe( stream = torch.cuda.Stream() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=stream): - out = _fused_experts(a_chunk, - w1_chunk, - w2_chunk, - chunk_topk_weight, - chunk_topk_ids, - global_num_experts=num_experts) + out = _fused_experts( + a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts, + ) torch.cuda.synchronize() graph.replay() @@ -640,15 +667,17 @@ def _pplx_moe( ): try: if use_internode: - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + uid = ( + nvshmem_get_unique_id() + if pgi.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) group_name = None else: group_ranks = list(range(pgi.world_size)) - cpu_group = torch.distributed.new_group(group_ranks, - backend="gloo") + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") group_name = cpu_group.group_name m, k = a.shape @@ -666,8 +695,7 @@ def _pplx_moe( w1_s = w1_s.to(device) if w1_s is not None else None w2_s = w2_s.to(device) if w2_s is not None else None - if (quant_dtype is not None and not per_act_token_quant - and block_shape is None): + if quant_dtype is not None and not per_act_token_quant and block_shape is None: a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) else: @@ -677,10 +705,7 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) - if shared_experts is not None: - shared_output = shared_experts(a) - else: 
- shared_output = None + shared_output = shared_experts(a) if shared_experts is not None else None torch_output = torch_experts( a, @@ -742,31 +767,27 @@ def _pplx_moe( if shared_output is not None: assert pplx_shared_output is not None chunked_shared_output = chunk_by_rank( - shared_output, pgi.rank, - pgi.world_size).to(pplx_shared_output.device) + shared_output, pgi.rank, pgi.world_size + ).to(pplx_shared_output.device) else: chunked_shared_output = None chunked_batch_output = chunk_by_rank( - batched_output, pgi.rank, pgi.world_size).to(pplx_output.device) + batched_output, pgi.rank, pgi.world_size + ).to(pplx_output.device) - torch.testing.assert_close(batched_output, - torch_output, - atol=3e-2, - rtol=3e-2) + torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2) - torch.testing.assert_close(pplx_output, - chunked_batch_output, - atol=3e-2, - rtol=3e-2) + torch.testing.assert_close( + pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2 + ) if shared_experts is not None: assert chunked_shared_output is not None assert pplx_shared_output is not None - torch.testing.assert_close(pplx_shared_output, - chunked_shared_output, - atol=3e-2, - rtol=3e-2) + torch.testing.assert_close( + pplx_shared_output, chunked_shared_output, atol=3e-2, rtol=3e-2 + ) finally: if use_internode: @@ -823,15 +844,33 @@ def test_pplx_moe_slow( per_out_ch_quant=per_act_token_quant, ) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e, - w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape, - use_internode) - + parallel_launch( + world_size, + _pplx_moe, + dp_size, + a, + w1, + w2, + score, + topk, + e, + w1_s, + w2_s, + quant_dtype, + per_act_token_quant, + block_shape, + use_internode, + ) -def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool, - use_shared_experts: bool, make_weights: bool, - test_fn: Callable): +def _pplx_test_loop( + pgi: ProcessGroupInfo, + dp_size: int, + use_internode: bool, + use_shared_experts: bool, + make_weights: bool, + test_fn: Callable, +): def format_result(msg, ex=None): if ex is not None: x = str(ex) @@ -850,12 +889,12 @@ def format_result(msg, ex=None): new_vllm_config = copy.deepcopy(vllm_config) new_vllm_config.parallel_config.data_parallel_size = pgi.world_size new_vllm_config.parallel_config.enable_expert_parallel = True - _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, - pgi.local_rank) + _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, pgi.local_rank) current_platform.seed_everything(7) - combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, - [False, True], [None, [128, 128]]) + combos = itertools.product( + PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]] + ) exceptions = [] count = 0 for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos: @@ -873,13 +912,11 @@ def format_result(msg, ex=None): f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, " f"dtype={dtype}, per_act_token={per_act_token_quant}, " f"block_shape={block_shape}, use_internode={use_internode}, " - f"use_shared_experts={use_shared_experts}") + f"use_shared_experts={use_shared_experts}" + ) - if not use_fp8_w8a8 and (per_act_token_quant - or block_shape is not None): - print( - f"{test_desc} - Skip quantization test for non-quantized type." 
- ) + if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None): + print(f"{test_desc} - Skip quantization test for non-quantized type.") continue if per_act_token_quant and block_shape is not None: @@ -934,10 +971,10 @@ def format_result(msg, ex=None): if len(exceptions) > 0: raise RuntimeError( f"{len(exceptions)} of {count} tests failed in child process, " - f"rank={pgi.rank}.") + f"rank={pgi.rank}." + ) else: - print(f"{count} of {count} tests passed in child process, " - f"rank={pgi.rank}.") + print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.") @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @@ -950,8 +987,15 @@ def test_pplx_prepare_finalize( ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size, - use_internode, False, False, _pplx_prepare_finalize) + parallel_launch( + world_size * dp_size, + _pplx_test_loop, + dp_size, + use_internode, + False, + False, + _pplx_prepare_finalize, + ) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) @@ -966,5 +1010,12 @@ def test_pplx_moe( ): current_platform.seed_everything(7) world_size, dp_size = world_dp_size - parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, - use_shared_experts, True, _pplx_moe) + parallel_launch( + world_size, + _pplx_test_loop, + dp_size, + use_internode, + use_shared_experts, + True, + _pplx_moe, + ) diff --git a/tests/kernels/moe/test_rocm_aiter_topk.py b/tests/kernels/moe/test_rocm_aiter_topk.py index 1c51c530c193..d4724d749fc9 100644 --- a/tests/kernels/moe/test_rocm_aiter_topk.py +++ b/tests/kernels/moe/test_rocm_aiter_topk.py @@ -24,13 +24,14 @@ pytestmark = pytest.mark.skipif( not (current_platform.is_rocm() and aiter_available), - reason="AITER ops are only available on ROCm with aiter package installed") + reason="AITER ops are only available on ROCm with aiter package installed", +) def test_rocm_aiter_biased_grouped_topk_custom_op_registration(): """Test that the custom op is correctly registered.""" # Check if the op exists in torch.ops.vllm - assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk') + assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk") # Check if the op is callable assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk) @@ -39,7 +40,7 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration(): def test_rocm_aiter_grouped_topk_custom_op_registration(): """Test that the custom op is correctly registered.""" # Check if the op exists in torch.ops.vllm - assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk') + assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk") # Check if the op is callable assert callable(torch.ops.vllm.rocm_aiter_grouped_topk) @@ -56,25 +57,29 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility(): renormalize = True scale_factor = 1.0 - gating_output = torch.randn((token, expert), - dtype=torch.bfloat16, - device="cuda") - e_score_correction_bias = torch.randn((expert, ), - dtype=torch.bfloat16, - device="cuda") + gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda") + e_score_correction_bias = torch.randn( + (expert,), dtype=torch.bfloat16, device="cuda" + ) device = gating_output.device topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), - dtype=torch.float32, - device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) # Define a function 
that uses the op - def biased_grouped_topk_fn(gating_output, e_score_correction_bias, - topk_weights, topk_ids): + def biased_grouped_topk_fn( + gating_output, e_score_correction_bias, topk_weights, topk_ids + ): return torch.ops.vllm.rocm_aiter_biased_grouped_topk( - gating_output, e_score_correction_bias, topk_weights, topk_ids, - num_expert_group, topk_group, renormalize, scale_factor) + gating_output, + e_score_correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + renormalize, + scale_factor, + ) # Verify the op's fake implementation torch.library.opcheck( @@ -84,51 +89,49 @@ def biased_grouped_topk_fn(gating_output, e_score_correction_bias, "num_expert_group": num_expert_group, "topk_group": topk_group, "need_renorm": renormalize, - "routed_scaling_factor": scale_factor + "routed_scaling_factor": scale_factor, }, - test_utils=("test_faketensor")) + test_utils=("test_faketensor"), + ) # Compile the function with appropriate settings - compiled_fn = torch.compile(biased_grouped_topk_fn, - fullgraph=True, - backend="inductor", - mode="reduce-overhead", - dynamic=False) - - topk_weights_original = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_original = torch.empty((token, topk), - dtype=torch.int32, - device=device) - - topk_weights_compiled = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_compiled = torch.empty((token, topk), - dtype=torch.int32, - device=device) + compiled_fn = torch.compile( + biased_grouped_topk_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False, + ) + + topk_weights_original = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device) + + topk_weights_compiled = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device) # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode) - biased_grouped_topk_fn(gating_output, e_score_correction_bias, - topk_weights_original, topk_ids_original) - compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled, - topk_ids_compiled) + biased_grouped_topk_fn( + gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original + ) + compiled_fn( + gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled + ) # Sort the results for comparison since the order might not be deterministic topk_ids_original, indices_original = torch.sort(topk_ids_original) - topk_weights_original = torch.gather(topk_weights_original, 1, - indices_original) + topk_weights_original = torch.gather(topk_weights_original, 1, indices_original) topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled) - topk_weights_compiled = torch.gather(topk_weights_compiled, 1, - indices_compiled) + topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled) # Verify results match - assert torch.allclose(topk_weights_original, - topk_weights_compiled, - rtol=1e-2, - atol=1e-2) + assert torch.allclose( + topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2 + ) assert torch.allclose(topk_ids_original, topk_ids_compiled) @@ -144,73 +147,73 @@ def test_rocm_aiter_grouped_topk_torch_compile_compatibility(): scoring_func = "softmax" scale_factor = 1.0 - gating_output = torch.randn((token, expert), - dtype=torch.bfloat16, - device="cuda") + gating_output = 
torch.randn((token, expert), dtype=torch.bfloat16, device="cuda") device = gating_output.device topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), - dtype=torch.float32, - device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) # Define a function that uses the op def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func): return torch.ops.vllm.rocm_aiter_grouped_topk( - gating_output, topk_weights, topk_ids, num_expert_group, - topk_group, renormalize, scoring_func, scale_factor) + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + renormalize, + scoring_func, + scale_factor, + ) # Verify the op's fake implementation - torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk, - (gating_output, topk_weights, topk_ids), - kwargs={ - "num_expert_group": num_expert_group, - "topk_group": topk_group, - "need_renorm": renormalize, - "scoring_func": scoring_func, - "routed_scaling_factor": scale_factor - }, - test_utils=("test_faketensor")) + torch.library.opcheck( + torch.ops.vllm.rocm_aiter_grouped_topk, + (gating_output, topk_weights, topk_ids), + kwargs={ + "num_expert_group": num_expert_group, + "topk_group": topk_group, + "need_renorm": renormalize, + "scoring_func": scoring_func, + "routed_scaling_factor": scale_factor, + }, + test_utils=("test_faketensor"), + ) # Compile the function with appropriate settings - compiled_fn = torch.compile(grouped_topk_fn, - fullgraph=True, - backend="inductor", - mode="reduce-overhead", - dynamic=False) - - topk_weights_original = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_original = torch.empty((token, topk), - dtype=torch.int32, - device=device) - - topk_weights_compiled = torch.empty((token, topk), - dtype=torch.float32, - device=device) - topk_ids_compiled = torch.empty((token, topk), - dtype=torch.int32, - device=device) + compiled_fn = torch.compile( + grouped_topk_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False, + ) + + topk_weights_original = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device) + + topk_weights_compiled = torch.empty( + (token, topk), dtype=torch.float32, device=device + ) + topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device) # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode) - grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original, - scoring_func) - compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, - scoring_func) + grouped_topk_fn( + gating_output, topk_weights_original, topk_ids_original, scoring_func + ) + compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func) # Sort the results for comparison since the order might not be deterministic topk_ids_original, indices_original = torch.sort(topk_ids_original) - topk_weights_original = torch.gather(topk_weights_original, 1, - indices_original) + topk_weights_original = torch.gather(topk_weights_original, 1, indices_original) topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled) - topk_weights_compiled = torch.gather(topk_weights_compiled, 1, - indices_compiled) + topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled) # Verify results match - assert torch.allclose(topk_weights_original, - topk_weights_compiled, - 
rtol=1e-2, - atol=1e-2) + assert torch.allclose( + topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2 + ) assert torch.allclose(topk_ids_original, topk_ids_compiled) diff --git a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py index 383b5ebfba9b..b6ca80e97e91 100644 --- a/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py +++ b/tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py @@ -5,7 +5,8 @@ import torch from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - silu_mul_fp8_quant_deep_gemm_cuda) + silu_mul_fp8_quant_deep_gemm_cuda, +) from vllm.platforms import current_platform from vllm.utils import cdiv @@ -34,7 +35,6 @@ (256, 16, 7168, fp8_dtype), (256, 32, 7168, fp8_dtype), (256, 64, 7168, fp8_dtype), - # Only add a few fnuz tests to help with long CI times. (8, 512, 7168, torch.float8_e4m3fnuz), (8, 1024, 7168, torch.float8_e4m3fnuz), @@ -52,15 +52,15 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): tokens_per_expert = torch.randint( low=T // 2, high=T, - size=(E, ), + size=(E,), dtype=torch.int32, device="cuda", ) # Run the Triton kernel - y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda(y, - tokens_per_expert, - group_size=group_size) + y_q, y_s = silu_mul_fp8_quant_deep_gemm_cuda( + y, tokens_per_expert, group_size=group_size + ) torch.cuda.synchronize() fp8_info = torch.finfo(fp8_dtype) @@ -75,9 +75,9 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): for e in range(E): nt = tokens_per_expert[e].item() - ref_s = torch.empty((T, cdiv(H, group_size)), - dtype=torch.float32, - device="cuda") + ref_s = torch.empty( + (T, cdiv(H, group_size)), dtype=torch.float32, device="cuda" + ) ref_q = torch.empty((T, H), dtype=fp8_dtype, device="cuda") for t in range(nt): @@ -87,14 +87,17 @@ def test_silu_mul_fp8_quant_deep_gemm(E, T, H, fp8_type): # process full groups n_full_groups = H // group_size if n_full_groups > 0: - data_grp = data[:n_full_groups * group_size].view( - n_full_groups, group_size) + data_grp = data[: n_full_groups * group_size].view( + n_full_groups, group_size + ) amax = data_grp.abs().amax(dim=1).clamp(min=eps) scale = amax / fp8_max - scaled = data[:n_full_groups * - group_size] / scale.repeat_interleave(group_size) - ref_q_row[:n_full_groups * group_size] = scaled.clamp( - fp8_min, fp8_max).to(fp8_dtype) + scaled = data[: n_full_groups * group_size] / scale.repeat_interleave( + group_size + ) + ref_q_row[: n_full_groups * group_size] = scaled.clamp( + fp8_min, fp8_max + ).to(fp8_dtype) ref_s[t, :n_full_groups] = scale # process remainder group diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index 1c31464b30e7..933cd9dbdeaa 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -11,13 +11,11 @@ from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe.config import ( - fp8_w8a8_moe_quant_config) +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config from vllm.platforms import current_platform if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() 
vllm_config.scheduler_config.max_num_seqs = 128 @@ -31,14 +29,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): B = B.to(torch.float32) assert A.shape[-1] == B.shape[-1], "Dimension mismatch" - assert B.ndim == 2 and B.is_contiguous( - ), "B must be a 2D contiguous tensor" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" # Reshape input M = A.numel() // A.shape[-1] B = B.t() # Transpose weight matrix N, K = B.shape - origin_C_shape = A.shape[:-1] + (K, ) + origin_C_shape = A.shape[:-1] + (K,) A = A.reshape(M, N) # As is per-token [M, 1], Bs is per-column [1, K] @@ -88,17 +85,17 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk): act_out = SiluAndMul().forward_native(inter_out) # Quantize activation output with per-token act_out_q, act_out_s = ops.scaled_fp8_quant( - act_out, use_per_token_if_dynamic=True) + act_out, use_per_token_if_dynamic=True + ) # Second MLP layer - out[mask] = native_w8a8_per_token_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - output_dtype=a.dtype) + out[mask] = native_w8a8_per_token_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype + ) # Apply routing weights and sum - return (out.view(B, -1, w2.shape[1]) * - topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) @pytest.fixture(autouse=True, scope="module") @@ -116,8 +113,10 @@ def setup_cuda(): SEEDS = [0] -@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed", - itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M, N, K, E, topk, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): torch.manual_seed(seed) @@ -133,12 +132,10 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): # Generate int8 weights w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 - w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, - max=fp8_max).to(torch.float8_e4m3fn) + w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 - w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, - max=fp8_max).to(torch.float8_e4m3fn) + w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) # Generate scale for each column (per-column quantization) w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale @@ -163,7 +160,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): ) # Check results - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.05 diff --git a/tests/kernels/moe/utils.py b/tests/kernels/moe/utils.py index 7a0feb6a2079..9466dacb0c11 100644 --- a/tests/kernels/moe/utils.py +++ b/tests/kernels/moe/utils.py @@ -6,17 +6,17 @@ import vllm._custom_ops as ops from tests.kernels.quant_utils import per_block_cast_to_int8 -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX) +from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts, 
fused_topk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + BatchedPrepareAndFinalize, + BatchedTritonExperts, + NaiveBatchedExperts, +) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils import round_up from vllm.utils.deep_gemm import per_block_cast_to_fp8 @@ -45,12 +45,7 @@ def triton_moe( a2_scale=a2_scale, ) - return fused_experts(a, - w1, - w2, - topk_weight, - topk_ids, - quant_config=quant_config) + return fused_experts(a, w1, w2, topk_weight, topk_ids, quant_config=quant_config) def batched_moe( @@ -80,10 +75,9 @@ def batched_moe( ) fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(max_num_tokens, - num_dispatchers=1, - num_local_experts=w1.shape[0], - rank=0), + BatchedPrepareAndFinalize( + max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 + ), BatchedTritonExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, @@ -121,10 +115,9 @@ def naive_batched_moe( ) fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(max_num_tokens, - num_dispatchers=1, - num_local_experts=w1.shape[0], - rank=0), + BatchedPrepareAndFinalize( + max_num_tokens, num_dispatchers=1, num_local_experts=w1.shape[0], rank=0 + ), NaiveBatchedExperts( max_num_tokens=max_num_tokens, num_dispatchers=1, @@ -135,8 +128,9 @@ def naive_batched_moe( return fused_experts(a, w1, w2, topk_weight, topk_ids) -def chunk_scales(scales: Optional[torch.Tensor], start: int, - end: int) -> Optional[torch.Tensor]: +def chunk_scales( + scales: Optional[torch.Tensor], start: int, end: int +) -> Optional[torch.Tensor]: if scales is not None: if scales.numel() == 1: return scales @@ -159,13 +153,15 @@ def make_quantized_test_activations( a_scale = None if quant_dtype is not None: - assert (quant_dtype == torch.float8_e4m3fn - or quant_dtype == torch.int8), "only fp8/int8 supported" + assert quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8, ( + "only fp8/int8 supported" + ) a_q = torch.zeros_like(a, dtype=quant_dtype) a_scale_l = [None] * E for e in range(E): a_q[e], a_scale_l[e] = moe_kernel_quantize_input( - a[e], None, quant_dtype, per_act_token_quant, block_shape) + a[e], None, quant_dtype, per_act_token_quant, block_shape + ) a_scale = torch.stack(a_scale_l) if not per_act_token_quant and block_shape is None: @@ -181,8 +177,11 @@ def moe_quantize_weights( per_token_quant: bool, block_shape: Optional[list[int]], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8 - or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported" + assert ( + quant_dtype == torch.float8_e4m3fn + or quant_dtype == torch.int8 + or quant_dtype == "nvfp4" + ), "only fp8/int8/nvfp4 supported" w_gs = None @@ -199,10 +198,12 @@ def moe_quantize_weights( else: if quant_dtype == torch.int8: w, w_s = ops.scaled_int8_quant( - w, w_s, use_per_token_if_dynamic=per_token_quant) + w, w_s, use_per_token_if_dynamic=per_token_quant + ) elif quant_dtype == torch.float8_e4m3fn: w, w_s = ops.scaled_fp8_quant( - w, w_s, 
use_per_token_if_dynamic=per_token_quant) + w, w_s, use_per_token_if_dynamic=per_token_quant + ) elif quant_dtype == "nvfp4": assert not per_token_quant w_amax = torch.abs(w).max().to(torch.float32) @@ -222,8 +223,7 @@ def make_test_weight( quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, per_out_ch_quant: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 w_gs = None @@ -233,7 +233,8 @@ def make_test_weight( w_gs_l = [None] * e for idx in range(e): w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights( - w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape) + w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape + ) w = torch.stack(w_l) w_s = torch.stack(w_s_l) @@ -264,26 +265,25 @@ def make_test_weights( quant_dtype: Union[torch.dtype, str, None] = None, block_shape: Optional[list[int]] = None, per_out_ch_quant: bool = False, -) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]], - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]]: +) -> tuple[ + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]], +]: return ( - make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, - per_out_ch_quant), - make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, - per_out_ch_quant), + make_test_weight( + e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant + ), + make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant), ) def per_token_cast_to_fp8( - x: torch.Tensor, - block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, block_size: int = 128 +) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape pad_size = (block_size - (n % block_size)) % block_size - x = torch.nn.functional.pad(x, - (0, pad_size), value=0) if pad_size > 0 else x + x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x x_view = x.view(m, -1, block_size) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) @@ -313,27 +313,31 @@ def make_test_quant_config( a1_gscale: Optional[torch.Tensor] = None a2_gscale: Optional[torch.Tensor] = None if quant_dtype == "nvfp4": - a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32) - a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32) + a1_gscale = torch.ones((e,), device="cuda", dtype=torch.float32) + a2_gscale = torch.ones((e,), device="cuda", dtype=torch.float32) a1_scale = a1_gscale a2_scale = a2_gscale else: a1_scale = None a2_scale = None - return w1, w2, FusedMoEQuantConfig.make( - quant_dtype, - per_act_token_quant=per_act_token_quant, - block_shape=block_shape, - w1_scale=w1_s, - w2_scale=w2_s, - a1_gscale=a1_gscale, - a2_gscale=a2_gscale, - a1_scale=a1_scale, - a2_scale=a2_scale, - # TODO: make sure this is handled properly - g1_alphas=(1 / w1_gs) if w1_gs is not None else None, - g2_alphas=(1 / w2_gs) if w2_gs is not None else None, + return ( + w1, + w2, + FusedMoEQuantConfig.make( + quant_dtype, + per_act_token_quant=per_act_token_quant, + block_shape=block_shape, + w1_scale=w1_s, + w2_scale=w2_s, 
+ a1_gscale=a1_gscale, + a2_gscale=a2_gscale, + a1_scale=a1_scale, + a2_scale=a2_scale, + # TODO: make sure this is handled properly + g1_alphas=(1 / w1_gs) if w1_gs is not None else None, + g2_alphas=(1 / w2_gs) if w2_gs is not None else None, + ), ) @@ -348,21 +352,23 @@ def fused_moe( global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, ) -> torch.Tensor: - topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk, - renormalize) - return fused_experts(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - global_num_experts=global_num_experts, - expert_map=expert_map, - quant_config=quant_config) + topk_weights, topk_ids, _ = fused_topk( + hidden_states, score.float(), topk, renormalize + ) + return fused_experts( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=quant_config, + ) # CustomOp? class BaselineMM(torch.nn.Module): - def __init__( self, b: torch.Tensor, @@ -372,15 +378,11 @@ def __init__( self.b = b.to(dtype=torch.float32) self.out_dtype = out_dtype - def forward( - self, - a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - return torch.mm(a.to(dtype=torch.float32), - self.b).to(self.out_dtype), None + def forward(self, a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return torch.mm(a.to(dtype=torch.float32), self.b).to(self.out_dtype), None class TestMLP(torch.nn.Module): - def __init__( self, w1: torch.Tensor, @@ -410,7 +412,6 @@ def make_naive_shared_experts( class RealMLP(torch.nn.Module): - def __init__( self, hidden_size: int, @@ -425,37 +426,48 @@ def __init__( w2_s: Optional[torch.Tensor] = None, ) -> None: from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, RowParallelLinear) + MergedColumnParallelLinear, + RowParallelLinear, + ) super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") + prefix=f"{prefix}.gate_up_proj", + ) self.gate_up_proj.register_parameter( - "weight", torch.nn.Parameter(w1, requires_grad=False)) + "weight", torch.nn.Parameter(w1, requires_grad=False) + ) self.gate_up_proj.register_parameter( - "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False)) + "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False) + ) self.gate_up_proj.register_parameter( - "input_scale", - None) #torch.nn.Parameter(None, requires_grad=False)) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + "input_scale", None + ) # torch.nn.Parameter(None, requires_grad=False)) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) self.down_proj.register_parameter( - "weight", torch.nn.Parameter(w2, requires_grad=False)) + "weight", torch.nn.Parameter(w2, requires_grad=False) + ) self.down_proj.register_parameter( - "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False)) + "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False) + ) self.down_proj.register_parameter( - "input_scale", - None) #torch.nn.Parameter(None, requires_grad=False)) + "input_scale", None + ) # torch.nn.Parameter(None, requires_grad=False)) if hidden_act != "silu": - raise 
ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -496,13 +508,6 @@ def make_shared_experts( w2_s = None quant_config = None - return RealMLP(K, - N, - w1, - w2, - "silu", - quant_config, - w1_s=w1_s, - w2_s=w2_s) + return RealMLP(K, N, w1, w2, "silu", quant_config, w1_s=w1_s, w2_s=w2_s) finally: torch.set_default_dtype(old_dtype) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 01a1ad2e7a0a..d892f2a5acc0 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -5,8 +5,7 @@ import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) +from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast from vllm.platforms import current_platform from vllm.utils import round_up @@ -17,25 +16,31 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: - return torch.as_tensor(x, dtype=torch.float32, device='cuda') + return torch.as_tensor(x, dtype=torch.float32, device="cuda") -def ref_dynamic_per_token_quant(x: torch.tensor, - quant_dtype: torch.dtype, - scale_ub: Optional[torch.tensor] = None) \ - -> tuple[torch.tensor, torch.tensor]: +def ref_dynamic_per_token_quant( + x: torch.tensor, quant_dtype: torch.dtype, scale_ub: Optional[torch.tensor] = None +) -> tuple[torch.tensor, torch.tensor]: assert quant_dtype in [torch.int8, FP8_DTYPE] if scale_ub is not None: assert quant_dtype == FP8_DTYPE - qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \ - else torch.finfo(quant_dtype) - qtype_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else qtype_traits.max - qtype_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else qtype_traits.min + qtype_traits = ( + torch.iinfo(quant_dtype) + if quant_dtype == torch.int8 + else torch.finfo(quant_dtype) + ) + qtype_traits_max = ( + ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else qtype_traits.max + ) + qtype_traits_min = ( + -ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else qtype_traits.min + ) qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) s_512 = as_float32_tensor(512.0) @@ -56,15 +61,13 @@ def ref_dynamic_per_token_quant(x: torch.tensor, iscales = as_float32_tensor(s_1 / scales) torch_out = as_float32_tensor(x) * iscales torch_out = torch_out.round() - torch_out = torch_out.clamp(qtype_traits_min, - qtype_traits_max).to(quant_dtype) + torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype) else: assert quant_dtype == FP8_DTYPE min_scaling_factor = s_1 / (qtype_max * s_512) scales = scales.clamp(min=min_scaling_factor) torch_out = as_float32_tensor(x) / scales - torch_out = torch_out.clamp(qtype_traits_min, - qtype_traits_max).to(quant_dtype) + torch_out = torch_out.clamp(qtype_traits_min, qtype_traits_max).to(quant_dtype) return torch_out, scales @@ -72,16 +75,20 @@ def ref_dynamic_per_token_quant(x: torch.tensor, # The int8 version is very similar. 
Incorporate the int8 version, like in # ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant # kernel -def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ - -> tuple[torch.tensor, torch.tensor]: - +def ref_dynamic_per_tensor_fp8_quant( + x: torch.tensor, +) -> tuple[torch.tensor, torch.tensor]: fp8_traits = torch.finfo(FP8_DTYPE) - fp8_traits_max = ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else fp8_traits.max - fp8_traits_min = -ROCM_FP8FNUZ_MAX if current_platform.is_rocm() \ - and current_platform.is_fp8_fnuz() \ - else fp8_traits.min + fp8_traits_max = ( + ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else fp8_traits.max + ) + fp8_traits_min = ( + -ROCM_FP8FNUZ_MAX + if current_platform.is_rocm() and current_platform.is_fp8_fnuz() + else fp8_traits.min + ) fp8_max = as_float32_tensor(fp8_traits_max) one = as_float32_tensor(1.0) @@ -92,9 +99,12 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ x_max = as_float32_tensor(x.abs().max()) ref_scale = x_max / fp8_max ref_iscale = one / ref_scale - ref_out = (as_float32_tensor(x) * ref_iscale).clamp( - fp8_traits_min, fp8_traits_max).to(FP8_DTYPE) - return ref_out, ref_scale.view((1, )) + ref_out = ( + (as_float32_tensor(x) * ref_iscale) + .clamp(fp8_traits_min, fp8_traits_max) + .to(FP8_DTYPE) + ) + return ref_out, ref_scale.view((1,)) def native_w8a8_block_matmul( @@ -126,7 +136,7 @@ def native_w8a8_block_matmul( M = A.numel() // A.shape[-1] N, K = B.shape - origin_C_shape = A.shape[:-1] + (N, ) + origin_C_shape = A.shape[:-1] + (N,) A = A.reshape(M, A.shape[-1]) As = As.reshape(M, As.shape[-1]) n_tiles = (N + block_n - 1) // block_n @@ -137,19 +147,19 @@ def native_w8a8_block_matmul( C_shape = (M, N) C = torch.zeros(C_shape, dtype=compute_type, device=A.device) - A_tiles = [ - A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles) - ] - B_tiles = [[ - B[ - j * block_n:min((j + 1) * block_n, N), - i * block_k:min((i + 1) * block_k, K), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - C_tiles = [ - C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles) + A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [ + [ + B[ + j * block_n : min((j + 1) * block_n, N), + i * block_k : min((i + 1) * block_k, K), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) ] - As_tiles = [As[:, i:i + 1] for i in range(k_tiles)] + C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] for i in range(k_tiles): for j in range(n_tiles): @@ -163,14 +173,14 @@ def native_w8a8_block_matmul( return C -def native_per_token_group_quant_fp8(x, - group_size, - eps=1e-10, - dtype=torch.float8_e4m3fn): +def native_per_token_group_quant_fp8( + x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn +): """Function to perform per-token-group quantization on an input tensor `x` using native torch.""" - assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must " - "be divisible by `group_size`") + assert x.shape[-1] % group_size == 0, ( + "the last dimension of `x` must be divisible by `group_size`" + ) assert x.is_contiguous(), "`x` is not contiguous" finfo = torch.finfo(dtype) @@ -178,28 +188,25 @@ def native_per_token_group_quant_fp8(x, fp8_max = finfo.max x_ = x.reshape(x.numel() // group_size, group_size) - amax = x_.abs().max(dim=-1, - 
keepdim=True)[0].clamp(min=eps).to(torch.float32) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) x_s = amax / fp8_max x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype) x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) return x_q, x_s -def native_per_token_group_quant_int8(x, - group_size, - eps=1e-10, - dtype=torch.int8): +def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8): """Function to perform per-token-group quantization on an input tensor `x` using native torch. It converts the tensor values into int8 values and returns the quantized tensor along with the scaling factor used for quantization. """ - assert (x.shape[-1] % group_size == 0 - ), "the last dimension of `x` must be divisible by `group_size`" + assert x.shape[-1] % group_size == 0, ( + "the last dimension of `x` must be divisible by `group_size`" + ) assert x.is_contiguous(), "`x` is not contiguous" iinfo = torch.iinfo(dtype) @@ -208,13 +215,13 @@ def native_per_token_group_quant_int8(x, x_ = x.reshape(x.numel() // group_size, group_size) # Use float32 for scale calculation for stability - amax = x_.abs().max(dim=-1, - keepdim=True)[0].clamp(min=eps).to(torch.float32) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) x_s = amax / int8_max - x_q = (x_.to(torch.float32) / x_s).round().clamp( - min=int8_min, max=int8_max).to(dtype) # Round before clamping + x_q = ( + (x_.to(torch.float32) / x_s).round().clamp(min=int8_min, max=int8_max).to(dtype) + ) # Round before clamping x_q = x_q.reshape(x.shape) - x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, )) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) return x_q, x_s @@ -229,9 +236,9 @@ def per_block_cast_to_int8( block_m, block_n = block_shape assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)), - dtype=x.dtype, - device=x.device) + x_padded = torch.zeros( + (round_up(m, block_m), round_up(n, block_n)), dtype=x.dtype, device=x.device + ) x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) @@ -269,8 +276,9 @@ def batched_dequant( assert t.shape[0] == scale.shape[0] out = torch.empty_like(t, dtype=out_dtype) for e in range(t.shape[0]): - out[e] = dequant(t[e], scale[e], block_shape, per_act_token_quant, - out_dtype) + out[e] = dequant( + t[e], scale[e], block_shape, per_act_token_quant, out_dtype + ) return out return t.to(out_dtype) @@ -294,15 +302,17 @@ def native_batched_masked_quant_matmul( num_tokens = num_expert_tokens_cpu[e] if A.dtype.itemsize == 1 and block_shape is not None: assert A_scale is not None and B_scale is not None - tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e], - block_shape, C.dtype) + tmp = native_w8a8_block_matmul( + A[e], B[e], A_scale[e], B_scale[e], block_shape, C.dtype + ) C[e, :num_tokens, :] = tmp[:num_tokens, :] elif A.dtype.itemsize == 1 and block_shape is None: assert A_scale is not None and B_scale is not None A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant) B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant) - C[e, :num_tokens, :] = ( - A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype) + C[e, :num_tokens, :] = (A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to( + C.dtype + ) else: assert 
A_scale is None assert B_scale is None diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py index fc4e12555018..50be6841560b 100644 --- a/tests/kernels/quantization/nvfp4_utils.py +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -8,8 +8,9 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max -kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], - dtype=torch.float32) +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): @@ -22,12 +23,9 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): return out[0:m, 0:k] -def dequantize_nvfp4_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): +def dequantize_nvfp4_to_dtype( + tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 +): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. assert tensor_fp4.dtype == torch.uint8 @@ -69,7 +67,8 @@ def break_fp4_bytes(a, dtype): def quant_nvfp4_tensor(a: torch.Tensor): - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.abs(a).max().to(torch.float32)) + a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to( + torch.float32 + ) a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale) return a_quant, a_block_scale, a_global_scale diff --git a/tests/kernels/quantization/test_allspark_gemm.py b/tests/kernels/quantization/test_allspark_gemm.py index 3de9cb364468..e5f056f04f8c 100644 --- a/tests/kernels/quantization/test_allspark_gemm.py +++ b/tests/kernels/quantization/test_allspark_gemm.py @@ -6,24 +6,25 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.allspark_utils import ( - ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, - ALLSPARK_AMPERE_N_ALIGN) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - quantize_weights) + ALLSPARK_AMPERE_K_ALIGN, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + ALLSPARK_AMPERE_N_ALIGN, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -def is_gptq_allspark_supported(min_capability: int, - max_capability: int) -> bool: +def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool: if not current_platform.is_cuda(): return False capability = current_platform.get_device_capability() assert capability is not None - return capability.to_int() >= min_capability \ - and capability.to_int() <= max_capability + return ( + capability.to_int() >= min_capability and capability.to_int() <= max_capability + ) MNK_FACTORS = [ @@ -43,7 +44,8 @@ def is_gptq_allspark_supported(min_capability: int, def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) def rand_data(shape, dtype=torch.float16): @@ -52,7 +54,8 @@ def rand_data(shape, dtype=torch.float16): @pytest.mark.skipif( not is_gptq_allspark_supported(80, 89), - reason="AllSpark Ampere kernel is not supported on this GPU type.") + reason="AllSpark Ampere kernel is not supported on this GPU type.", +) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) 
@pytest.mark.parametrize("group_size", [-1]) @pytest.mark.parametrize("has_zp", HAS_ZP_OPTS) @@ -67,8 +70,9 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): weight = rand_data((k, n), dtype=dtype) # Quantize (and apply act_order if provided) - w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128, - group_size, has_zp) + w_ref, qw, s, zp = quantize_weights( + weight, scalar_types.uint8b128, group_size, has_zp + ) qw = qw.to(torch.uint8) if has_zp: @@ -79,20 +83,42 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype): n_32align = (n + 32 - 1) // 32 * 32 - qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( - qw, s, zp, has_zp) - opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order, - (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, - n_32align)) - - opcheck(torch.ops._C.allspark_w8a16_gemm, - (input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count, - sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) - output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder, - n, group_size, sm_count, sm_version, - ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, - has_zp, True) + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp) + opcheck( + torch.ops._C.rearrange_kn_weight_as_n32k16_order, + (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, n_32align), + ) + + opcheck( + torch.ops._C.allspark_w8a16_gemm, + ( + input, + qw_reorder, + s_reorder, + zp_reorder, + n, + group_size, + sm_count, + sm_version, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp, + True, + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + output = ops.allspark_w8a16_gemm( + input, + qw_reorder, + s_reorder, + zp_reorder, + n, + group_size, + sm_count, + sm_version, + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + has_zp, + True, + ) output_ref = torch.matmul(input, w_ref) torch.cuda.synchronize() diff --git a/tests/kernels/quantization/test_awq.py b/tests/kernels/quantization/test_awq.py index bc0868123d82..efb62ca3799a 100644 --- a/tests/kernels/quantization/test_awq.py +++ b/tests/kernels/quantization/test_awq.py @@ -8,40 +8,42 @@ from vllm import _custom_ops as ops # noqa: F401 -@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"), - reason="AWQ is not supported on this GPU type.") +@pytest.mark.skipif( + not hasattr(torch.ops._C, "awq_dequantize"), + reason="AWQ is not supported on this GPU type.", +) def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_TRITON_AWQ", "0") - qweight = torch.randint(-2000000000, - 2000000000, (8192, 256), - device='cuda', - dtype=torch.int32) - scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16) - zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32) + qweight = torch.randint( + -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32 + ) + scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16) + zeros = torch.empty((64, 256), device="cuda", dtype=torch.int32) split_k_iters = 0 thx = 0 thy = 0 - opcheck(torch.ops._C.awq_dequantize, - (qweight, scales, zeros, split_k_iters, thx, thy)) + opcheck( + torch.ops._C.awq_dequantize, + (qweight, scales, zeros, split_k_iters, thx, thy), + ) @pytest.mark.skip(reason="Not working; needs investigation.") -@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"), - reason="AWQ is not supported on this GPU type.") +@pytest.mark.skipif( + not 
hasattr(torch.ops._C, "awq_gemm"), + reason="AWQ is not supported on this GPU type.", +) def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_TRITON_AWQ", "0") - input = torch.rand((2, 8192), device='cuda', dtype=torch.float16) - qweight = torch.randint(-2000000000, - 2000000000, (8192, 256), - device='cuda', - dtype=torch.int32) - scales = torch.randint(-2000000000, - 2000000000, (64, 256), - device='cuda', - dtype=torch.int32) - qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16) + input = torch.rand((2, 8192), device="cuda", dtype=torch.float16) + qweight = torch.randint( + -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32 + ) + scales = torch.randint( + -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32 + ) + qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16) split_k_iters = 8 - opcheck(torch.ops._C.awq_gemm, - (input, qweight, qzeros, scales, split_k_iters)) + opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters)) diff --git a/tests/kernels/quantization/test_awq_triton.py b/tests/kernels/quantization/test_awq_triton.py index 9354495642b2..069bd7435534 100644 --- a/tests/kernels/quantization/test_awq_triton.py +++ b/tests/kernels/quantization/test_awq_triton.py @@ -4,11 +4,15 @@ Run `pytest tests/kernels/quantization/test_awq_triton.py`. """ + import pytest import torch from vllm.model_executor.layers.quantization.awq_triton import ( - AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) + AWQ_TRITON_SUPPORTED_GROUP_SIZES, + awq_dequantize_triton, + awq_gemm_triton, +) from vllm.platforms import current_platform device = "cuda" @@ -33,23 +37,24 @@ def reverse_awq_order(t: torch.Tensor): # qweights - [R , C // 8], int32 # scales - [R // G, C ], float16 # zeros - [R // G, C // 8], int32 -def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, - qzeros: torch.Tensor, - group_size: int) -> torch.Tensor: - +def awq_dequantize_torch( + qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int +) -> torch.Tensor: if group_size == -1: group_size = qweight.shape[0] bits = 4 shifts = torch.arange(0, 32, bits, device=qzeros.device) - iweights = torch.bitwise_right_shift(qweight[:, :, None], - shifts[None, None, :]).to(torch.int8) + iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) iweights = iweights.view(iweights.shape[0], -1) - zeros = torch.bitwise_right_shift(qzeros[:, :, None], - shifts[None, None, :]).to(torch.int8) + zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to( + torch.int8 + ) zeros = zeros.view(qzeros.shape[0], -1) zeros = reverse_awq_order(zeros) @@ -70,7 +75,6 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, @pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128]) @pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) def test_dequantize(qweight_rows, qweight_cols, group_size): - if group_size == -1: group_size = qweight_rows @@ -84,25 +88,27 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): current_platform.seed_everything(0) - qweight = torch.randint(0, - torch.iinfo(torch.int32).max, - (qweight_rows, qweight_cols), - dtype=qweight_dtype, - device=device) - scales = torch.rand(scales_rows, - scales_cols, - dtype=scales_dtype, - device=device) - zeros = torch.randint(0, - torch.iinfo(torch.int32).max, - 
(zeros_rows, zeros_cols), - dtype=zeros_dtype, - device=device) + qweight = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + dtype=qweight_dtype, + device=device, + ) + scales = torch.rand(scales_rows, scales_cols, dtype=scales_dtype, device=device) + zeros = torch.randint( + 0, + torch.iinfo(torch.int32).max, + (zeros_rows, zeros_cols), + dtype=zeros_dtype, + device=device, + ) iweights_triton = awq_dequantize_triton(qweight, scales, zeros) - assert (not torch.any(torch.isinf(iweights_triton)) - and not torch.any(torch.isnan(iweights_triton))) + assert not torch.any(torch.isinf(iweights_triton)) and not torch.any( + torch.isnan(iweights_triton) + ) iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size) @@ -119,7 +125,6 @@ def test_dequantize(qweight_rows, qweight_cols, group_size): @pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("splitK", [1, 8]) def test_gemm(N, K, M, splitK, group_size): - if group_size == -1: group_size = K @@ -138,35 +143,29 @@ def test_gemm(N, K, M, splitK, group_size): current_platform.seed_everything(0) - input = torch.rand((input_rows, input_cols), - dtype=input_dtype, - device=device) - qweight = torch.randint(0, - torch.iinfo(torch.int32).max, - (qweight_rows, qweight_cols), - device=device) - qzeros = torch.randint(0, - torch.iinfo(torch.int32).max, - (qzeros_rows, qzeros_cols), - device=device) - scales = torch.rand((scales_rows, scales_cols), - dtype=scales_dtype, - device=device) - - output_triton = awq_gemm_triton(input, qweight, scales, qzeros, - split_k_iters) - - assert (not torch.any(torch.isinf(output_triton)) - and not torch.any(torch.isnan(output_triton))) + input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device) + qweight = torch.randint( + 0, torch.iinfo(torch.int32).max, (qweight_rows, qweight_cols), device=device + ) + qzeros = torch.randint( + 0, torch.iinfo(torch.int32).max, (qzeros_rows, qzeros_cols), device=device + ) + scales = torch.rand((scales_rows, scales_cols), dtype=scales_dtype, device=device) + + output_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) + + assert not torch.any(torch.isinf(output_triton)) and not torch.any( + torch.isnan(output_triton) + ) dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros) output_torch = torch.matmul(input, dequantized_weights) - assert (not torch.any(torch.isinf(output_torch)) - and not torch.any(torch.isnan(output_torch))) + assert not torch.any(torch.isinf(output_torch)) and not torch.any( + torch.isnan(output_torch) + ) - torch.testing.assert_close(output_triton.cpu(), - output_torch.cpu(), - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close( + output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1 + ) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index e02df540ce9d..a6dfb5428c52 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -7,20 +7,26 @@ import pytest import torch -from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, - native_w8a8_block_matmul) +from tests.kernels.quant_utils import ( + native_per_token_group_quant_fp8, + native_w8a8_block_matmul, +) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm) + cutlass_scaled_mm, + per_token_group_quant_fp8, 
+ w8a8_triton_block_scaled_mm, +) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import (fp8_gemm_nt, - get_col_major_tma_aligned_tensor, - per_block_cast_to_fp8) +from vllm.utils.deep_gemm import ( + fp8_gemm_nt, + get_col_major_tma_aligned_tensor, + per_block_cast_to_fp8, +) if current_platform.get_device_capability() < (9, 0): - pytest.skip("FP8 Triton requires CUDA 9.0 or higher", - allow_module_level=True) + pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -51,7 +57,8 @@ def setup_cuda(): @pytest.mark.parametrize( "num_tokens,d,dtype,group_size,seed", - itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS)) + itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS), +) @torch.inference_mode() def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): torch.manual_seed(seed) @@ -60,15 +67,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed): ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size) out, scale = per_token_group_quant_fp8(x, group_size) - assert torch.allclose(out.to(torch.float32), - ref_out.to(torch.float32), - rtol=0.15) + assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15) assert torch.allclose(scale, ref_scale) @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): torch.manual_seed(seed) @@ -89,14 +95,12 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) - out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 @@ -127,32 +131,32 @@ def test_w8a8_block_fp8_cutlass_matmul(): Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale # Hopper requires row-major format for scales - Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability( - 90) else Bs + Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs - A_fp8, As = per_token_group_quant_fp8(A_fp32, - block_size[1], - column_major_scales=False) + A_fp8, As = per_token_group_quant_fp8( + A_fp32, block_size[1], column_major_scales=False + ) # CUTLASS uses column-major format for scales A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8( - A_fp32, block_size[1], column_major_scales=True) + A_fp32, block_size[1], column_major_scales=True + ) - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) - out = cutlass_scaled_mm(A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, - block_size, out_dtype) + ref_out = 
native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + out = cutlass_scaled_mm( + A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype + ) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) -@pytest.mark.skipif(not has_deep_gemm(), - reason="DeepGemm kernels not available.") + itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS), +) +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -172,20 +176,20 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): As = As_fp8.to(torch.float32) Bs = Bs_fp8.to(torch.float32) - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) # Transpose earlier so that the testing will not trigger transposing kernels As_fp8 = get_col_major_tma_aligned_tensor(As_fp8) - out = torch.zeros((M, N), device='cuda', dtype=out_dtype) + out = torch.zeros((M, N), device="cuda", dtype=out_dtype) - assert As_fp8.shape == (M, (K + 127) // - 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" + assert As_fp8.shape == (M, (K + 127) // 128), ( + f"{As_fp8.shape} != {(M, (K + 127) // 128)}" + ) fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index fac82cf9c8b5..dabc10a122f7 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -10,12 +10,12 @@ from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.int8_utils import ( - w8a8_block_int8_matmul) + w8a8_block_int8_matmul, +) from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): - pytest.skip("INT8 Triton requires CUDA 7.0 or higher", - allow_module_level=True) + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) vllm_config = VllmConfig() vllm_config.scheduler_config.max_num_seqs = 128 @@ -36,8 +36,10 @@ def setup_cuda(): torch.set_default_device("cuda") -@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed", - itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M,N,K,block_size,out_dtype,seed", + itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): torch.manual_seed(seed) @@ -58,11 +60,10 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed): As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale Bs = torch.rand(n_tiles, k_tiles, 
dtype=torch.float32) * factor_for_scale - ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, - out_dtype) + ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.001 diff --git a/tests/kernels/quantization/test_cutlass_2of4_sparse.py b/tests/kernels/quantization/test_cutlass_2of4_sparse.py index ae61b3b3a28a..cfdb3658028a 100644 --- a/tests/kernels/quantization/test_cutlass_2of4_sparse.py +++ b/tests/kernels/quantization/test_cutlass_2of4_sparse.py @@ -11,12 +11,11 @@ from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8 from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - sparse_cutlass_supported) + sparse_cutlass_supported, +) from vllm.platforms import current_platform -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] capability = current_platform.get_device_capability() capability = capability[0] * 10 + capability[1] @@ -40,9 +39,7 @@ def prune_to_2_4(tensor): # Create binary mask mask = torch.zeros_like(reshaped) - mask.scatter_(dim=1, - index=indices, - src=torch.ones_like(indices, dtype=mask.dtype)) + mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) # Apply mask and reshape back pruned = reshaped * mask @@ -55,32 +52,31 @@ def prune_to_2_4(tensor): # This function checks that applying an identity matrix multiplication # to the compressed weights yields the original uncompressed weights. -def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor, - b_compressed: torch.Tensor, - b_metadata: torch.Tensor): - +def check_compress_decompress_invariance( + dtype: torch.dtype, + b: torch.Tensor, + b_compressed: torch.Tensor, + b_metadata: torch.Tensor, +): # For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the # same dtype as its inputs. This line addresses that constraint while # arbitrarily using bfloat16 for the int8/fp8 cases. 
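# Standalone sketch of the 2:4 structured-sparsity idea exercised by the tests
# in test_cutlass_2of4_sparse.py: within every group of four consecutive values
# along the last dimension, keep the two largest-magnitude entries and zero the
# rest (mirrors prune_to_2_4 above; assumes the row length is divisible by 4).
import torch

def prune_2_of_4_sketch(t: torch.Tensor) -> torch.Tensor:
    groups = t.reshape(-1, 4)
    _, idx = groups.abs().topk(2, dim=1)        # positions of the 2 largest |values| per group
    mask = torch.zeros_like(groups)
    mask.scatter_(dim=1, index=idx, src=torch.ones_like(idx, dtype=mask.dtype))
    return (groups * mask).reshape(t.shape)

x = torch.randn(2, 8)
print(prune_2_of_4_sketch(x))  # each block of 4 keeps at most 2 nonzeros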
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16 - eye = torch.eye(b.shape[0], device='cuda', dtype=dtype) - eye_scale = torch.ones(1, device='cuda', dtype=torch.float32) - b_decomp = ops.cutlass_scaled_sparse_mm(eye, - b_compressed, - b_metadata, - eye_scale, - eye_scale, - out_dtype=out_dtype) + eye = torch.eye(b.shape[0], device="cuda", dtype=dtype) + eye_scale = torch.ones(1, device="cuda", dtype=torch.float32) + b_decomp = ops.cutlass_scaled_sparse_mm( + eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype + ) torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp) def make_rand_sparse_tensors( - dtype: torch.dtype, m: int, n: int, k: int + dtype: torch.dtype, m: int, n: int, k: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') - b = torch.randn((n, k), device='cuda').t() + a = torch.randn((m, k), device="cuda") + b = torch.randn((n, k), device="cuda").t() if dtype == torch.int8: # ensure A and B aren't all zeros after rounding @@ -107,32 +103,25 @@ def make_rand_sparse_tensors( return b_compressed, e, a, b -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) # Test working with a subset of A and B for sparse matmul def test_cutlass_sparse_subset(): - big_m = 1024 m, n, k = 512, 512, 512 # Create tensors - b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, - big_m, n, k) + b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k) a = whole_a[0:m, 0:k] scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16 + ) + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) @@ -161,105 +150,87 @@ def test_cutlass_sparse_subset(): # Test working with a subset of A and B for sparse matmul -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) @pytest.mark.parametrize("m, n, k", MNK_FACTORS) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype], - use_bias: bool): - +def test_cutlass_sparse_gemm( + m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool +): # Create tensors b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32) scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32) - bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None + bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=dtype, - bias=bias) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, 
scale_b, out_dtype=dtype, bias=bias + ) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=dtype, - bias=bias) + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1) -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) @pytest.mark.parametrize("m, k, n", MNK_FACTORS) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("use_bias", [True, False]) def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool): - # Create tensors b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) - scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) - scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) out_dtype = torch.bfloat16 - bias = torch.rand( - (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None + bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + baseline = baseline_scaled_mm( + a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1) -@pytest.mark.skipif(not sparse_cutlass_supported(), - reason="Sparse CUTLASS is not supported on this GPU type.") +@pytest.mark.skipif( + not sparse_cutlass_supported(), + reason="Sparse CUTLASS is not supported on this GPU type.", +) @pytest.mark.parametrize("m,k,n", MNK_FACTORS) @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool, - per_out_ch: bool, use_bias: bool): - +def test_cutlass_sparse_int8_gemm( + m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool +): # Create tensors b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) - scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) - scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32)) + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) out_dtype = torch.bfloat16 - bias = torch.rand( - (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None - - out = ops.cutlass_scaled_sparse_mm(a, - b_comp, - e, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) - - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=out_dtype, - bias=bias) + bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None + + out = ops.cutlass_scaled_sparse_mm( + a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) + + 
baseline = baseline_scaled_mm( + a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias + ) torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0) diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 65320509e173..835c067e2f72 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -4,6 +4,7 @@ Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`. """ + import random import pytest @@ -36,9 +37,7 @@ (512, 24576, 128), ] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # -1 means full extent in that dimension TENSORWISE_GROUP_SHAPE = (-1, -1) @@ -60,18 +59,19 @@ def group_scale_helper(shape, group_shape): def scale_shape(shape, group_shape): assert len(shape) == len(group_shape) group_shape = group_scale_helper(shape, group_shape) - return tuple( - cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) - - -def cutlass_fp8_gemm_helper(m: int, - n: int, - k: int, - a_scale_group_shape: tuple, - b_scale_group_shape: tuple, - use_bias: bool, - out_dtype: type[torch.dtype] = torch.bfloat16, - device: str = "cuda"): + return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape))) + + +def cutlass_fp8_gemm_helper( + m: int, + n: int, + k: int, + a_scale_group_shape: tuple, + b_scale_group_shape: tuple, + use_bias: bool, + out_dtype: type[torch.dtype] = torch.bfloat16, + device: str = "cuda", +): # Test for a cutlass kernel with per-token activation quantization # and per-output channel weight quantization. a = to_fp8(torch.randn((m, k), device=device)) @@ -80,36 +80,34 @@ def cutlass_fp8_gemm_helper(m: int, a_scales_shape = scale_shape(a.shape, a_scale_group_shape) b_scales_shape = scale_shape(b.shape, b_scale_group_shape) - scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) - scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32) + scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32) # make scales M-major for blockwise quant, doesn't affect 1D scales scale_a = scale_a.t().contiguous().t() # make scales K-major for blockwise quant, doesn't affect 1D scales scale_b = scale_b.t().contiguous().t() - if use_bias: - bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 - else: - bias = None + bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1) - opcheck(torch.ops._C.cutlass_scaled_mm, - (out, a, b, scale_a, scale_b, bias)) + opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) -def cutlass_int8_gemm_helper(m: int, - n: int, - k: int, - a_scale_group_shape: tuple, - b_scale_group_shape: tuple, - use_bias: bool, - out_dtype: type[torch.dtype] = torch.bfloat16, - device: str = "cuda"): +def cutlass_int8_gemm_helper( + m: int, + n: int, + k: int, + a_scale_group_shape: tuple, + b_scale_group_shape: tuple, + use_bias: bool, + out_dtype: type[torch.dtype] = torch.bfloat16, + device: str = "cuda", +): # Test for a cutlass kernel with per-token activation quantization # and per-output 
channel weight quantization. a = to_int8(torch.randn((m, k), device=device) * 5) @@ -118,158 +116,202 @@ def cutlass_int8_gemm_helper(m: int, a_scales_shape = scale_shape(a.shape, a_scale_group_shape) b_scales_shape = scale_shape(b.shape, b_scale_group_shape) - scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) - scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32) + scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32) - if use_bias: - bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 - else: - bias = None + bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) - opcheck(torch.ops._C.cutlass_scaled_mm, - (out, a, b, scale_a, scale_b, bias)) + opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape, - b_scale_group_shape, use_bias: bool): - cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, - use_bias) +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm( + m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool +): + cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) -@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape", - [((1, 128), (128, 128))]) +@pytest.mark.parametrize( + "a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))] +) @pytest.mark.parametrize("use_bias", [False]) -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="FP8 blockwise is not supported on this GPU type.") -def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int, - a_scale_group_shape, - b_scale_group_shape, use_bias: bool): +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="FP8 blockwise is not supported on this GPU type.", +) +def test_cutlass_fp8_blockwise_scale_gemm( + m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool +): if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0: return if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0: return if m % 4 != 0 and current_platform.has_device_capability(100): return - cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, - use_bias) + cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) @pytest.mark.parametrize("m,n,k", MNK_FACTORS) 
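# Quick reference for the scale-group bookkeeping used throughout this file:
# a group shape entry of -1 means "full extent in that dimension", and the
# scale tensor shape is the ceil-division of the operand shape by the group
# shape (this mirrors scale_shape/group_scale_helper above). The (1, -1)
# per-token and (1, 128) blockwise shapes below are illustrative values, not
# constants copied from the test module.
from math import ceil

def scale_shape_sketch(shape, group_shape):
    group = tuple(shape[i] if g == -1 else g for i, g in enumerate(group_shape))
    return tuple(ceil(shape[i] / group[i]) for i in range(len(shape)))

a_shape = (512, 4096)                          # (M, K) activations
print(scale_shape_sketch(a_shape, (-1, -1)))   # tensorwise scale   -> (1, 1)
print(scale_shape_sketch(a_shape, (1, -1)))    # per-token scale    -> (512, 1)
print(scale_shape_sketch(a_shape, (1, 128)))   # 128-wide blockwise -> (512, 32)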
-@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape, - b_scale_group_shape, use_bias: bool): - cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, - use_bias) - - -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +def test_cutlass_int8_gemm( + m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool +): + cutlass_int8_gemm_helper( + m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape, - b_scale_group_shape, - out_dtype: type[torch.dtype], - use_bias: bool): - cutlass_int8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=out_dtype) - - -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +def test_cutlass_int8_gemm_output_dtype( + a_scale_group_shape, + b_scale_group_shape, + out_dtype: type[torch.dtype], + use_bias: bool, +): + cutlass_int8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=out_dtype, + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape, - b_scale_group_shape, - out_dtype: type[torch.dtype], - use_bias: bool): - cutlass_fp8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=out_dtype) - - -@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape", - [((1, 128), (128, 128))]) +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm_output_dtype( + a_scale_group_shape, + b_scale_group_shape, + out_dtype: type[torch.dtype], + use_bias: bool, +): + cutlass_fp8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=out_dtype, + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))] +) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) 
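# The assertions in this file compare the CUTLASS kernels against
# baseline_scaled_mm (imported from tests.kernels.utils, not shown in this
# patch). For tensorwise or broadcastable scales, a plain-PyTorch reference of
# the same idea is simply "dequantize, matmul, cast" -- an illustrative
# approximation, not the actual helper:
import torch

def reference_scaled_mm_sketch(a, b, scale_a, scale_b, out_dtype, bias=None):
    out = (scale_a * a.to(torch.float32)) @ (scale_b * b.to(torch.float32))
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

a = torch.randn(16, 32)               # (M, K)
b = torch.randn(32, 8)                # (K, N)
scale_a = torch.full((16, 1), 0.5)    # per-token scale, broadcast over K
scale_b = torch.full((1, 8), 2.0)     # per-output-channel scale
print(reference_scaled_mm_sketch(a, b, scale_a, scale_b, torch.bfloat16).shape)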
@pytest.mark.parametrize("use_bias", [False]) -@pytest.mark.skipif(not current_platform.has_device_capability(90), - reason="FP8 blockwise is not supported on this GPU type.") -def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape, - b_scale_group_shape, - out_dtype: type[torch.dtype], - use_bias: bool): - cutlass_fp8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=out_dtype) - - -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="FP8 blockwise is not supported on this GPU type.", +) +def test_cutlass_fp8_blockwise_scale_gemm_dtype( + a_scale_group_shape, + b_scale_group_shape, + out_dtype: type[torch.dtype], + use_bias: bool, +): + cutlass_fp8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=out_dtype, + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape, - use_bias: bool, device: str): - cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape, - b_scale_group_shape, use_bias, torch.bfloat16, - device) - - -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm_devices( + a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str +): + cutlass_fp8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + torch.bfloat16, + device, + ) + + +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape, - use_bias: bool, device: str): - cutlass_int8_gemm_helper(512, - 512, - 512, - a_scale_group_shape, - b_scale_group_shape, - use_bias, - out_dtype=torch.bfloat16, - device=device) +def test_cutlass_int8_gemm_devices( + a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str +): + cutlass_int8_gemm_helper( + 512, + 512, + 512, + a_scale_group_shape, + b_scale_group_shape, + use_bias, + out_dtype=torch.bfloat16, + device=device, + ) # For the following two tests: @@ -277,32 +319,42 @@ def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape, # of a large power of two. In any case, the kernel will have a naive fallback # when N and K are not divisible by 16. But M is the number of tokens and the # kernel must handle any M thrown at it. 
-@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.skipif(not current_platform.has_device_capability(89), - reason="FP8 is not supported on this GPU type.") -def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape, - use_bias: bool): +@pytest.mark.skipif( + not current_platform.has_device_capability(89), + reason="FP8 is not supported on this GPU type.", +) +def test_cutlass_fp8_gemm_m_sweep( + a_scale_group_shape, b_scale_group_shape, use_bias: bool +): for nk in range(32, 128, 32): for m in range(1, 128): - cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape, - b_scale_group_shape, use_bias) + cutlass_fp8_gemm_helper( + m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias + ) -@pytest.mark.parametrize("a_scale_group_shape", - [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) -@pytest.mark.parametrize("b_scale_group_shape", - [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]) +@pytest.mark.parametrize( + "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) +@pytest.mark.parametrize( + "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE] +) @pytest.mark.parametrize("use_bias", [True, False]) -def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape, - use_bias: bool): +def test_cutlass_int8_gemm_m_sweep( + a_scale_group_shape, b_scale_group_shape, use_bias: bool +): for nk in range(32, 128, 32): for m in range(1, 128): - cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape, - b_scale_group_shape, use_bias) + cutlass_int8_gemm_helper( + m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias + ) @pytest.mark.parametrize("m", [32, 64, 128]) @@ -310,8 +362,7 @@ def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape, @pytest.mark.parametrize("k", [64, 128, 256]) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.skip -def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, - out_dtype: torch.dtype): +def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype): # Currently, the test is failing because folding azp into # 16-bit bias loses too much precision scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 @@ -328,7 +379,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, b_dq = scale_b * bq_f32 - azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5 + azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5 azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8) azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding @@ -340,18 +391,17 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, J = torch.ones((1, k), device="cuda", dtype=torch.float32) azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype) assert azp_bias.shape == (1, n) - assert azp_bias[0, :].shape == (n, ) - - baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * ( - (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to( - dtype=out_dtype, device='cuda') - - out = ops.cutlass_scaled_mm(aq_i8, - bq_i8, - scale_a, - 
scale_b, - out_dtype=out_dtype, - bias=azp_bias[0, :]) + assert azp_bias[0, :].shape == (n,) + + baseline_q = ( + scale_a.to(device="cpu") + * scale_b.to(device="cpu") + * ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu")) + ).to(dtype=out_dtype, device="cuda") + + out = ops.cutlass_scaled_mm( + aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :] + ) torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0) torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0) @@ -362,8 +412,9 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("use_bias", [True, False]) @pytest.mark.parametrize("azp_per_token", [True, False]) -def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, - use_bias: bool, azp_per_token: bool): +def test_cutlass_int8_azp( + m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool +): m_azp = m if azp_per_token else 1 scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10 @@ -377,16 +428,12 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, bq_f32 = bq_i8.to(dtype=torch.float32) b_dq = scale_b * bq_f32 - azp_a = torch.rand( - (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5 + azp_a = torch.rand((m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5 azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8) azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32) - torch.testing.assert_close(a_dq, - scale_a * aq_f32 - azp_a, - rtol=1e-4, - atol=1e-3) + torch.testing.assert_close(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3) if use_bias: bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5 @@ -396,8 +443,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype) # int32 mm not supported on CUDA - a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu') - cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda') + a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device="cpu") + cq = (a_noazp_i32_cpu @ bq_i32.to(device="cpu")).to(device="cuda") baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype) # Hadamard is just the sum of the cols @@ -406,14 +453,14 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, func_bias = bias if use_bias else None if azp_per_token: - out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b, - out_dtype, azp_adj_i32, azp_i32, - func_bias) + out = ops.cutlass_scaled_mm_azp( + aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_adj_i32, azp_i32, func_bias + ) else: azp_with_adj_i32 = azp_i32 * azp_adj_i32 - out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b, - out_dtype, azp_with_adj_i32, None, - func_bias) + out = ops.cutlass_scaled_mm_azp( + aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_with_adj_i32, None, func_bias + ) # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4% # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05% @@ -423,13 +470,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol) if azp_per_token: - opcheck(torch.ops._C.cutlass_scaled_mm_azp, - (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, 
azp_i32, - func_bias)) + opcheck( + torch.ops._C.cutlass_scaled_mm_azp, + (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, func_bias), + ) else: - opcheck(torch.ops._C.cutlass_scaled_mm_azp, - (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, - func_bias)) + opcheck( + torch.ops._C.cutlass_scaled_mm_azp, + (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, func_bias), + ) # Test working with a subset of A and B @@ -445,23 +494,14 @@ def test_cutlass_subset(): scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 - out = ops.cutlass_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) - baseline = baseline_scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16) + out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16) + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) # Test to make sure cuda graphs work class CutlassLayer(torch.nn.Module): - def __init__(self, b, scale_a, scale_b, out_dtype): super().__init__() self.b = b @@ -470,8 +510,9 @@ def __init__(self, b, scale_a, scale_b, out_dtype): self.out_dtype = out_dtype def forward(self, a): - return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b, - self.out_dtype) + return ops.cutlass_scaled_mm( + a, self.b, self.scale_a, self.scale_b, self.out_dtype + ) @pytest.mark.parametrize("per_act_token", [True, False]) @@ -485,10 +526,8 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): m_a_scales = m if per_act_token else 1 n_b_scales = n if per_out_ch else 1 - scale_a = (torch.randn( - (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10) - scale_b = (torch.randn( - (1, n_b_scales), device="cuda", dtype=torch.float32) / 10) + scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, n_b_scales), device="cuda", dtype=torch.float32) / 10 # Construct a trivial model with a single layer that calls a CUTLASS kernel model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16) @@ -502,13 +541,14 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): out.zero_() g.replay() - baseline = torch.mm(scale_a * a.to(dtype=torch.float32), - scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) + baseline = torch.mm( + scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32) + ).to(torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) def test_cutlass_support_opcheck(): - opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) + opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability,)) @pytest.mark.parametrize("num_experts", [8, 64]) @@ -517,11 +557,13 @@ def test_cutlass_support_opcheck(): @pytest.mark.parametrize("use_bias", [False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( - current_platform.get_device_capability()), - reason="Grouped gemm is not supported on this GPU type.") -def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, - per_out_ch: bool, use_bias: bool): - + current_platform.get_device_capability() + ), + reason="Grouped gemm is not supported on this GPU type.", +) +def test_cutlass_fp8_group_gemm( + num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool +): # Device and dtype setup device = "cuda" out_dtype = torch.half @@ -533,13 
+575,9 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, b_scales_tensors = [] baseline_tensors = [] - expert_offsets = torch.zeros((num_experts + 1), - device=device, - dtype=torch.int64) + expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int64) - problem_sizes = torch.zeros((num_experts, 3), - device=device, - dtype=torch.int32) + problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32) if not per_act_token: one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32) @@ -566,75 +604,76 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, b_tensors.append(b_g) # Set up A/B scales - scale_b = torch.randn((1, n_b_scales), - device=device, - dtype=torch.float32) + scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32) b_scales_tensors.append(scale_b) if per_act_token: - scale_a = torch.randn((m_a_scales, 1), - device=device, - dtype=torch.float32) + scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32) a_scales_tensors.append(scale_a) else: scale_a = one_scale_a # Compute baseline result for this group - baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, - None) + baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None) baseline_tensors.append(baseline_g) - a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g), - device=device, - dtype=torch.float8_e4m3fn) - b_tensors_stacked = torch.empty((num_experts, n_g, k_g), - device=device, - dtype=torch.float8_e4m3fn) + a_tensors_stacked = torch.empty( + (expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn + ) + b_tensors_stacked = torch.empty( + (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn + ) for g in range(num_experts): - a_tensors_stacked[expert_offsets[g]:expert_offsets[g + - 1]] = a_tensors[g] + a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g] b_tensors_stacked[g] = b_tensors[g].t() b_tensors_stacked = b_tensors_stacked.transpose(1, 2) if per_act_token: a_scales_tensors_stacked = torch.empty( - (expert_offsets[num_experts], 1), - device=device, - dtype=torch.float32) + (expert_offsets[num_experts], 1), device=device, dtype=torch.float32 + ) for g in range(num_experts): - a_scales_tensors_stacked[ - expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g] + a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = ( + a_scales_tensors[g] + ) else: a_scales_tensors_stacked = one_scale_a - b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales), - device=device, - dtype=torch.float32) + b_scales_tensors_stacked = torch.empty( + (num_experts, n_b_scales), device=device, dtype=torch.float32 + ) for g in range(num_experts): b_scales_tensors_stacked[g] = b_scales_tensors[g] - out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g), - device=device, - dtype=out_dtype) - - ab_strides = torch.full((num_experts, ), - a_tensors_stacked.stride(0), - device="cuda", - dtype=torch.int64) - c_strides = torch.full((num_experts, ), - out_tensors_stacked.stride(0), - device="cuda", - dtype=torch.int64) - - ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked, - b_tensors_stacked, a_scales_tensors_stacked, - b_scales_tensors_stacked, expert_offsets[:-1], - problem_sizes, ab_strides, ab_strides, c_strides, - per_act_token, per_out_ch) + out_tensors_stacked = torch.zeros( + (expert_offsets[num_experts], n_g), device=device, dtype=out_dtype + ) + + ab_strides 
= torch.full( + (num_experts,), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64 + ) + c_strides = torch.full( + (num_experts,), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64 + ) + + ops.cutlass_moe_mm( + out_tensors_stacked, + a_tensors_stacked, + b_tensors_stacked, + a_scales_tensors_stacked, + b_scales_tensors_stacked, + expert_offsets[:-1], + problem_sizes, + ab_strides, + ab_strides, + c_strides, + per_act_token, + per_out_ch, + ) # Validate each group's result against the baseline for g in range(num_experts): baseline = baseline_tensors[g] - c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]] + c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4) diff --git a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py index f659408efe8c..a3d524fe90ed 100644 --- a/tests/kernels/quantization/test_cutlass_w4a8.py +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -13,7 +13,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) + pack_rows, + quantize_weights, +) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types @@ -24,16 +26,33 @@ # have kernels and some kernels support multiple quantization methods. IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 -MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672), - (13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096), - (64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096), - (1024, 4096, 8192), (1024, 8192, 4096)] +MNK_SHAPES = [ + (1, 128, 128), + (1, 512, 1024), + (1, 4096, 4096), + (1, 8192, 28672), + (13, 8192, 4096), + (26, 4096, 8192), + (64, 4096, 4096), + (64, 8192, 28672), + (257, 128, 4096), + (257, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), +] # TODO(czhu): get supported schedules from fn SCHEDULES = [ - '128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1', - '128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1', - '128x256_1x1x1', '128x256_2x1x1' + "128x16_1x1x1", + "256x16_1x1x1", + "128x32_1x1x1", + "256x32_1x1x1", + "128x64_1x1x1", + "256x64_1x1x1", + "128x128_1x1x1", + "256x128_1x1x1", + "128x256_1x1x1", + "128x256_2x1x1", ] @@ -60,19 +79,23 @@ class Tensors: # (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, # Ch Scales Type, Tok Scales Type) -TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype], - Optional[torch.dtype], bool] +TestTypeTuple = tuple[ + list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool +] TEST_TYPES = [ *( - TypeConfig(act_type=torch.float8_e4m3fn, - weight_type=w_type, - output_type=o_type, - group_scale_type=torch.float8_e4m3fn, - channel_scale_type=torch.float32, - token_scale_type=torch.float32) + TypeConfig( + act_type=torch.float8_e4m3fn, + weight_type=w_type, + output_type=o_type, + group_scale_type=torch.float8_e4m3fn, + channel_scale_type=torch.float32, + token_scale_type=torch.float32, + ) for w_type in [scalar_types.int4] # TODO(czhu): fp16 out type - for o_type in [torch.bfloat16]), + for o_type in [torch.bfloat16] + ), ] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel @@ -86,26 +109,28 @@ class Tensors: # For testing quantized linear kernels def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) - return 
tensor.clamp(min=finfo.min, - max=finfo.max).to(dtype=torch.float8_e4m3fn) + return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn) -def cutlass_quantize_and_pack(atype: torch.dtype, - w: torch.Tensor, - wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], - zero_points: bool = False): +def cutlass_quantize_and_pack( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False, +): assert wtype.is_integer(), "TODO: support floating point weights" - w_ref, w_q, w_s, w_zp = quantize_weights(w, - wtype, - group_size=group_size, - zero_points=zero_points) + w_ref, w_q, w_s, w_zp = quantize_weights( + w, wtype, group_size=group_size, zero_points=zero_points + ) # since scales are cast to fp8, we need to compute w_ref this way - w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to( - torch.float32).repeat_interleave(group_size, dim=0)).to(atype) + w_ref = ( + (w_q).to(torch.float32) + * w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0) + ).to(atype) # bit mask prevents sign extending int4 when packing w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape) @@ -117,12 +142,14 @@ def cutlass_quantize_and_pack(atype: torch.dtype, return w_ref, w_q_packed, w_s_packed, w_zp -def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig, - group_size: Optional[int]) -> Tensors: +def create_test_tensors( + shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int] +) -> Tensors: m, n, k = shape - print("create_test_tensors, shape:", shape, "types:", types, "group_size:", - group_size) + print( + "create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size + ) a = to_fp8(torch.randn((m, k), device="cuda")) w = to_fp8(torch.randn((k, n), device="cuda")) @@ -133,30 +160,34 @@ def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig, w = w.to(torch.float16) w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( - a.dtype, w, types.weight_type, types.group_scale_type, group_size, - False) + a.dtype, w, types.weight_type, types.group_scale_type, group_size, False + ) a_ref = a.to(torch.float32) w_ref = w_ref.to(torch.float32) # for the practical use case we need per-tok scales for fp8 activations - w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type) + w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type) # weights are already per-group quantized, use placeholder here - w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type) - - return Tensors(w_ref=w_ref, - a_ref=a_ref, - a=a, - w_q=w_q_packed, - w_g_s=w_s, - w_ch_s=w_ch_s, - w_tok_s=w_tok_s) + w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type) + + return Tensors( + w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s, + ) -def mm_test_helper(types: TypeConfig, - tensors: Tensors, - group_size: Optional[int] = None, - schedule: Optional[str] = None): +def mm_test_helper( + types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None, +): # CUTLASS upstream uses fp8 with fastaccum as reference # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406 output_ref = torch._scaled_mm( @@ -165,7 +196,8 @@ def mm_test_helper(types: TypeConfig, tensors.w_tok_s.unsqueeze(1), tensors.w_ch_s.unsqueeze(0), 
out_dtype=types.output_type, - use_fast_accum=True) + use_fast_accum=True, + ) output = ops.cutlass_w4a8_mm( a=tensors.a, @@ -179,17 +211,15 @@ def mm_test_helper(types: TypeConfig, print(output) print(output_ref) - torch.testing.assert_close(output, - output_ref.to(output.dtype), - rtol=1e-3, - atol=1e-3) + torch.testing.assert_close( + output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3 + ) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="CUTLASS W4A8 is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type." +) +@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) @pytest.mark.parametrize("schedule", SCHEDULES) def test_cutlass_w4a8(shape, types: TypeConfig, schedule): @@ -201,7 +231,6 @@ def test_cutlass_w4a8(shape, types: TypeConfig, schedule): # Test to make sure cuda graphs work class W4A8Layer(torch.nn.Module): - def __init__(self, **kwargs): super().__init__() self.kwargs = kwargs @@ -210,8 +239,9 @@ def forward(self, a): return ops.cutlass_w4a8_mm(a=a, **self.kwargs) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="CUTLASS W4A8 is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type." +) def test_w4a8_cuda_graph(): m, n, k = 512, 4096, 4096 @@ -224,10 +254,11 @@ def test_w4a8_cuda_graph(): zero_points = False w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( - a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points) + a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points + ) - w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32) - w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32) + w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32) + w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32) # Construct a trivial model with a single layer that calls the kernel model = W4A8Layer( @@ -244,7 +275,8 @@ def test_w4a8_cuda_graph(): w_tok_s.unsqueeze(1), w_ch_s.unsqueeze(0), out_dtype=torch.bfloat16, - use_fast_accum=True) + use_fast_accum=True, + ) # Run the model with a cuda graph stream = torch.cuda.Stream() diff --git a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py index 131086a5f703..1e5c7dafb0f5 100644 --- a/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py @@ -2,8 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch -from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, - convert_swizzled_to_linear, dequantize_nvfp4_to_dtype) +from nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + convert_swizzled_to_linear, + dequantize_nvfp4_to_dtype, +) from vllm import _custom_ops as ops from vllm.platforms import current_platform @@ -41,18 +45,12 @@ def get_ref_results( _, m_k = a_fp4.shape _, n_k = b_fp4.shape assert m_k == n_k - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_sf, - a_global_scale, - dtype=dtype, - device=device, - block_size=block_size) - b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, - b_sf, - b_global_scale, - dtype=dtype, - device=device, - block_size=block_size) + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, 
a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size + ) + b_in_dtype = dequantize_nvfp4_to_dtype( + b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size + ) return torch.matmul(a_in_dtype, b_in_dtype.t()) @@ -72,8 +70,7 @@ def test_flashinfer_nvfp4_gemm( autotune: bool, ) -> None: if backend == "trtllm" and dtype == torch.float16: - pytest.skip( - "Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") + pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations") current_platform.seed_everything(seed) m, n, packed_k = shape @@ -82,10 +79,12 @@ def test_flashinfer_nvfp4_gemm( a_dtype = torch.randn((m, k), dtype=dtype, device=device) b_dtype = torch.randn((n, k), dtype=dtype, device=device) - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) - b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1) + ).to(torch.float32) + b_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) + ).to(torch.float32) alpha = 1.0 / (a_global_scale * b_global_scale) # ops.scaled_fp4_quant returns swizzled scales, while weights # from checkpoints are in linear scales. @@ -113,14 +112,18 @@ def test_flashinfer_nvfp4_gemm( if backend == "trtllm": epilogue_tile_m = 128 - b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), - epilogue_tile_m) + b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m) b_scale_interleaved = convert_swizzled_to_linear( - b_scale_interleaved, n, k, block_size) - b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a( - b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape( - b_scale_interleaved.shape).view(torch.float8_e4m3fn)) + b_scale_interleaved, n, k, block_size + ) + b_scale_interleaved = ( + flashinfer.shuffle_matrix_sf_a( + b_scale_interleaved.view(torch.uint8), epilogue_tile_m + ) + .reshape(b_scale_interleaved.shape) + .view(torch.float8_e4m3fn) + ) with flashinfer.autotune(autotune): out = flashinfer_scaled_fp4_mm( @@ -133,7 +136,4 @@ def test_flashinfer_nvfp4_gemm( backend=backend, ) - torch.testing.assert_close(out, - expected_out.to(dtype=dtype), - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1) diff --git a/tests/kernels/quantization/test_flashinfer_scaled_mm.py b/tests/kernels/quantization/test_flashinfer_scaled_mm.py index 9f669c6df8bd..b30821b6895b 100644 --- a/tests/kernels/quantization/test_flashinfer_scaled_mm.py +++ b/tests/kernels/quantization/test_flashinfer_scaled_mm.py @@ -9,8 +9,7 @@ if not current_platform.has_device_capability(100): pytest.skip( - reason= - "Flashinfer FP8 gemms requires compute capability of 10.0 or above.", + reason="Flashinfer FP8 gemms requires compute capability of 10.0 or above.", allow_module_level=True, ) @@ -53,7 +52,7 @@ def test_flashinfer_fp8_gemm( ).to(dtype=dtype) if use_bias: - bias = torch.randn((n, ), dtype=dtype, device=device) + bias = torch.randn((n,), dtype=dtype, device=device) expected_out = expected_out + bias else: bias = None diff --git a/tests/kernels/quantization/test_fp8_quant.py b/tests/kernels/quantization/test_fp8_quant.py index c2e70ffb8d34..19aa21b96a57 100644 --- a/tests/kernels/quantization/test_fp8_quant.py +++ b/tests/kernels/quantization/test_fp8_quant.py @@ -5,9 +5,11 @@ 
import torch import vllm._custom_ops as ops -from tests.kernels.quant_utils import (FP8_DTYPE, - ref_dynamic_per_tensor_fp8_quant, - ref_dynamic_per_token_quant) +from tests.kernels.quant_utils import ( + FP8_DTYPE, + ref_dynamic_per_tensor_fp8_quant, + ref_dynamic_per_token_quant, +) from tests.kernels.utils import opcheck from vllm.platforms import current_platform @@ -18,23 +20,25 @@ SEEDS = [0] -def opcheck_fp8_quant(output, - input, - scale=None, - scale_ub=None, - use_per_token_if_dynamic=False): +def opcheck_fp8_quant( + output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False +): if scale is not None: opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale)) elif use_per_token_if_dynamic: - scale = torch.empty((input.shape[0], 1), - device=input.device, - dtype=torch.float32) - opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant, - (output, input, scale, scale_ub)) + scale = torch.empty( + (input.shape[0], 1), device=input.device, dtype=torch.float32 + ) + opcheck( + torch.ops._C.dynamic_per_token_scaled_fp8_quant, + (output, input, scale, scale_ub), + ) else: - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) + scale = torch.empty( + (input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32, + ) opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale)) @@ -44,30 +48,29 @@ def opcheck_fp8_quant(output, @pytest.mark.parametrize("scale_ub", SCALE_UBS) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, scale_ub: bool, - seed: int) -> None: +def test_dynamic_per_token_fp8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int +) -> None: current_platform.seed_everything(seed) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, - device="cuda") + 1e-6 # avoid nans + x = ( + torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6 + ) # avoid nans - scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \ - if scale_ub else None + scale_ub = ( + torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None + ) ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub) - ops_out, ops_scales = ops.scaled_fp8_quant(x, - scale_ub=scale_ub, - use_per_token_if_dynamic=True) + ops_out, ops_scales = ops.scaled_fp8_quant( + x, scale_ub=scale_ub, use_per_token_if_dynamic=True + ) torch.testing.assert_close(ref_scales, ops_scales) - torch.testing.assert_close(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + torch.testing.assert_close( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) - opcheck_fp8_quant(ops_out, - x, - None, - scale_ub, - use_per_token_if_dynamic=True) + opcheck_fp8_quant(ops_out, x, None, scale_ub, use_per_token_if_dynamic=True) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -75,8 +78,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: +def test_dynamic_per_tensor_fp8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int +) -> None: current_platform.seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") @@ -85,8 +89,9 
@@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, ops_out, ops_scale = ops.scaled_fp8_quant(x) torch.testing.assert_close(ref_scale, ops_scale) - torch.testing.assert_close(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + torch.testing.assert_close( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) opcheck_fp8_quant(ops_out, x) diff --git a/tests/kernels/quantization/test_fp8_quant_group.py b/tests/kernels/quantization/test_fp8_quant_group.py index 8f2bc6e3cee5..6628ac650fd5 100644 --- a/tests/kernels/quantization/test_fp8_quant_group.py +++ b/tests/kernels/quantization/test_fp8_quant_group.py @@ -6,8 +6,7 @@ import torch from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform @@ -18,13 +17,14 @@ (64, 1024, 64), # Medium (128, 2048, 128), # Large (8, 513, 64), # Non-divisible (native only) - ]) + ], +) @pytest.mark.parametrize("seed", [42]) @pytest.mark.parametrize("use_ue8m0", [True, False]) @torch.inference_mode() -def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, - group_size: int, seed: int, - use_ue8m0: bool) -> None: +def test_quantfp8_group_functionality( + batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool +) -> None: """Test QuantFP8 group quantization with various configurations. Tests both CUDA and native implementations, column-major scales, @@ -32,16 +32,17 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, """ current_platform.seed_everything(seed) - x = torch.randn( - (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + x = torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 expected_num_groups = (hidden_dim + group_size - 1) // group_size is_divisible = hidden_dim % group_size == 0 group_shape = GroupShape(1, group_size) - quant_op = QuantFP8(static=False, - group_shape=group_shape, - column_major_scales=False, - use_ue8m0=use_ue8m0) + quant_op = QuantFP8( + static=False, + group_shape=group_shape, + column_major_scales=False, + use_ue8m0=use_ue8m0, + ) # 1. Test native implementation (always available) x_quant_native, scales_native = quant_op.forward_native(x.clone()) @@ -49,10 +50,12 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int, assert scales_native.shape == (batch_size, expected_num_groups) # 2. 
Test column-major scales configuration - quant_op_col = QuantFP8(static=False, - group_shape=group_shape, - column_major_scales=True, - use_ue8m0=use_ue8m0) + quant_op_col = QuantFP8( + static=False, + group_shape=group_shape, + column_major_scales=True, + use_ue8m0=use_ue8m0, + ) _, scales_col = quant_op_col.forward_native(x.clone()) assert scales_col.shape == (batch_size, expected_num_groups) assert scales_col.stride(0) == 1 @@ -86,41 +89,48 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None: # Test with 3D input batch1, batch2, hidden_dim = 4, 8, 1024 - x_3d = torch.randn( - (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 + x_3d = ( + torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") + * 8 + ) group_shape = GroupShape(1, group_size) - quant_op = QuantFP8(static=False, - group_shape=group_shape, - column_major_scales=False, - use_ue8m0=use_ue8m0) + quant_op = QuantFP8( + static=False, + group_shape=group_shape, + column_major_scales=False, + use_ue8m0=use_ue8m0, + ) x_quant, scales = quant_op.forward_native(x_3d.clone()) assert x_quant.shape == x_3d.shape assert scales.shape == (batch1, batch2, hidden_dim // group_size) # Test column_major_scales with multi-dim - quant_op_col = QuantFP8(static=False, - group_shape=group_shape, - column_major_scales=True, - use_ue8m0=use_ue8m0) + quant_op_col = QuantFP8( + static=False, + group_shape=group_shape, + column_major_scales=True, + use_ue8m0=use_ue8m0, + ) _, scales_col = quant_op_col.forward_native(x_3d.clone()) assert scales_col.shape == (batch1, batch2, hidden_dim // group_size) # Test with 4D input batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256 - x_4d = torch.randn((batch1, batch2, batch3, hidden_dim), - dtype=torch.bfloat16, - device="cuda") * 8 + x_4d = ( + torch.randn( + (batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda" + ) + * 8 + ) x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone()) assert x_quant_4d.shape == x_4d.shape - assert scales_4d.shape == (batch1, batch2, batch3, - hidden_dim // group_size) + assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size) _, scales_4d_col = quant_op_col.forward_native(x_4d.clone()) - assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, - batch3) + assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, batch3) @pytest.mark.parametrize("seed", [42]) @@ -132,30 +142,24 @@ def test_quantfp8_group_edge_cases(seed: int) -> None: group_size = 64 # Test with single group (group_size >= hidden_dim) - x_small = torch.randn( - (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8 + x_small = torch.randn((batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8 group_shape = GroupShape(1, group_size) - quant_op = QuantFP8(static=False, - group_shape=group_shape, - column_major_scales=False) + quant_op = QuantFP8( + static=False, group_shape=group_shape, column_major_scales=False + ) x_quant_small, scales_small = quant_op.forward_native(x_small.clone()) assert x_quant_small.shape == x_small.shape assert scales_small.shape == (batch_size, 1) # Test with zero inputs - x_zero = torch.zeros((batch_size, 256), - dtype=torch.bfloat16, - device="cuda") + x_zero = torch.zeros((batch_size, 256), dtype=torch.bfloat16, device="cuda") x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone()) assert x_quant_zero.shape == x_zero.shape assert (scales_zero > 0).all(), "Scales should be clamped to minimum" # Test very large values - x_large 
= torch.full((batch_size, 256), - 1000.0, - dtype=torch.bfloat16, - device="cuda") + x_large = torch.full((batch_size, 256), 1000.0, dtype=torch.bfloat16, device="cuda") x_quant_large, scales_large = quant_op.forward_native(x_large.clone()) assert x_quant_large.shape == x_large.shape # FP8 max is typically 448 or 224, so scales should be > 1 diff --git a/tests/kernels/quantization/test_ggml.py b/tests/kernels/quantization/test_ggml.py index 07651fef39bf..0dc24187f2b3 100644 --- a/tests/kernels/quantization/test_ggml.py +++ b/tests/kernels/quantization/test_ggml.py @@ -13,33 +13,42 @@ def test_ggml_opcheck(quant_type): block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type] shape = [256, 1152] - qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) + qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8) m = qweight.shape[0] n = qweight.shape[1] // type_size * block_size - opcheck(torch.ops._C.ggml_dequantize, - (qweight, quant_type, m, n, torch.float16)) + opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n, torch.float16)) - x = torch.rand((m, 512), device='cuda', dtype=torch.float16) - opcheck(torch.ops._C.ggml_mul_mat_a8, - (qweight, x, quant_type, qweight.shape[0])) - opcheck(torch.ops._C.ggml_mul_mat_vec_a8, - (qweight, x, quant_type, qweight.shape[0])) + x = torch.rand((m, 512), device="cuda", dtype=torch.float16) + opcheck(torch.ops._C.ggml_mul_mat_a8, (qweight, x, quant_type, qweight.shape[0])) + opcheck( + torch.ops._C.ggml_mul_mat_vec_a8, (qweight, x, quant_type, qweight.shape[0]) + ) shape = [256, 1024, 336] - qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) - x = torch.rand((1, 1024), device='cuda', dtype=torch.float16) - sorted_token_ids = torch.arange(776, device='cuda') - expert_ids = torch.randint(0, 256, (194, ), device='cuda') - num_tokens_post_padded = torch.tensor([1], - dtype=torch.int64, - device='cuda') - - opcheck(torch.ops._C.ggml_moe_a8, - (x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded, - quant_type, qweight.shape[0], 1, x.shape[0])) + qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8) + x = torch.rand((1, 1024), device="cuda", dtype=torch.float16) + sorted_token_ids = torch.arange(776, device="cuda") + expert_ids = torch.randint(0, 256, (194,), device="cuda") + num_tokens_post_padded = torch.tensor([1], dtype=torch.int64, device="cuda") - topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32) + opcheck( + torch.ops._C.ggml_moe_a8, + ( + x, + qweight, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + quant_type, + qweight.shape[0], + 1, + x.shape[0], + ), + ) + + topk_ids = torch.zeros((1, 1), device="cuda", dtype=torch.int32) opcheck( torch.ops._C.ggml_moe_a8_vec, - (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0])) + (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]), + ) diff --git a/tests/kernels/quantization/test_gguf.py b/tests/kernels/quantization/test_gguf.py index 436d5cb64021..0988ba01759f 100644 --- a/tests/kernels/quantization/test_gguf.py +++ b/tests/kernels/quantization/test_gguf.py @@ -18,8 +18,8 @@ def get_gguf_sample_tensors( - hidden_size: int, - quant_type: GGMLQuantizationType) -> list[ReaderTensor]: + hidden_size: int, quant_type: GGMLQuantizationType +) -> list[ReaderTensor]: sample_dir = GGUF_SAMPLE filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" sample_file = Path(sample_dir) / filename @@ -27,8 +27,8 @@ def get_gguf_sample_tensors( def get_gguf_MoE_tensors( - 
hidden_size: int, - quant_type: GGMLQuantizationType) -> list[ReaderTensor]: + hidden_size: int, quant_type: GGMLQuantizationType +) -> list[ReaderTensor]: sample_dir = GGUF_SAMPLE_MOE filename = f"Quant_{quant_type.name}_{hidden_size}.gguf" sample_file = Path(sample_dir) / filename @@ -68,17 +68,20 @@ def get_gguf_MoE_tensors( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() -def test_dequantize(hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType): +def test_dequantize( + hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType +): tensors = get_gguf_sample_tensors(hidden_size, quant_type) for tensor in tensors: shape_str = tensor.name.split("_")[-1] shape = map(int, shape_str.split("x")) - ref_output = torch.tensor(dequantize(tensor.data, quant_type), - device="cuda").to(dtype) - output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"), - quant_type, *list(shape), dtype) + ref_output = torch.tensor( + dequantize(tensor.data, quant_type), device="cuda" + ).to(dtype) + output = ops.ggml_dequantize( + torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype + ) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2) @@ -87,20 +90,21 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() -def test_mmvq(hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType): +def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType): current_platform.seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((1, hidden_size), dtype=dtype, device="cuda") for tensor in tensors: - weight = torch.tensor(dequantize(tensor.data, quant_type), - device="cuda").to(dtype) + weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to( + dtype + ) ref_output = x @ weight.T qweight = torch.tensor(tensor.data, device="cuda") - output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, - qweight.shape[0]).to(dtype) + output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to( + dtype + ) torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) @@ -121,17 +125,23 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype, GGMLQuantizationType.Q4_0, GGMLQuantizationType.Q5_0, GGMLQuantizationType.Q8_0, - ]) + ], +) @torch.inference_mode() -def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType): +def test_mmq( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + quant_type: GGMLQuantizationType, +): current_platform.seed_everything(0) tensors = get_gguf_sample_tensors(hidden_size, quant_type) x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda") for tensor in tensors: - weight = torch.tensor(dequantize(tensor.data, quant_type), - device="cuda").to(dtype) + weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to( + dtype + ) ref_output = x @ weight.T qweight = torch.tensor(tensor.data, device="cuda") @@ -141,10 +151,9 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, # bfloat16 tends to accumulate and can greatly inflate rtol # since outputs are also very close to 0 rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1} - torch.testing.assert_close(output, - ref_output, - atol=atols[dtype], - rtol=rtols[dtype]) + 
torch.testing.assert_close( + output, ref_output, atol=atols[dtype], rtol=rtols[dtype] + ) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -153,35 +162,46 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() -def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, - quant_type: GGMLQuantizationType, top_k: int): +def test_moe( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + quant_type: GGMLQuantizationType, + top_k: int, +): current_platform.seed_everything(0) H, E = 1024, 256 x = torch.rand((num_tokens, H), dtype=dtype, device="cuda") topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype) - topk_ids = torch.randint(0, - E, (num_tokens, top_k), - device="cuda", - dtype=torch.int32) + topk_ids = torch.randint( + 0, E, (num_tokens, top_k), device="cuda", dtype=torch.int32 + ) tensors = get_gguf_MoE_tensors(hidden_size, quant_type) w13 = tensors[0] w2 = tensors[1] - w13_dequant = torch.tensor(dequantize(w13.data, quant_type), - device="cuda").to(dtype) - - w2_dequant = torch.tensor(dequantize(w2.data, quant_type), - device="cuda").to(dtype) - - output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"), - torch.tensor(w2.data, - device="cuda"), topk_weights, - topk_ids, quant_type, quant_type, "silu") - - ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights, - topk_ids).reshape(output.shape) + w13_dequant = torch.tensor(dequantize(w13.data, quant_type), device="cuda").to( + dtype + ) + + w2_dequant = torch.tensor(dequantize(w2.data, quant_type), device="cuda").to(dtype) + + output = _fused_moe_gguf( + x, + torch.tensor(w13.data, device="cuda"), + torch.tensor(w2.data, device="cuda"), + topk_weights, + topk_ids, + quant_type, + quant_type, + "silu", + ) + + ref_output = fused_experts( + x, w13_dequant, w2_dequant, topk_weights, topk_ids + ).reshape(output.shape) torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1) diff --git a/tests/kernels/quantization/test_gptq.py b/tests/kernels/quantization/test_gptq.py index 7fb57a1576bd..72e4194c1327 100644 --- a/tests/kernels/quantization/test_gptq.py +++ b/tests/kernels/quantization/test_gptq.py @@ -8,25 +8,22 @@ def test_gptq_shuffle_opcheck(): - weight = torch.randint(-2000000, - 2000000, (1792, 4096), - device='cuda', - dtype=torch.int32) - perm = torch.empty((0, ), device='cuda', dtype=torch.int32) + weight = torch.randint( + -2000000, 2000000, (1792, 4096), device="cuda", dtype=torch.int32 + ) + perm = torch.empty((0,), device="cuda", dtype=torch.int32) bit = 4 opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit)) def test_gptq_gemm_opcheck(): - a = torch.rand((240, 4096), device='cuda', dtype=torch.float16) - weight = torch.randint(-2000000, - 2000000, (512, 6144), - device='cuda', - dtype=torch.int32) - zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32) - scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16) - idx = torch.empty((0, ), device='cuda', dtype=torch.int32) + a = torch.rand((240, 4096), device="cuda", dtype=torch.float16) + weight = torch.randint( + -2000000, 2000000, (512, 6144), device="cuda", dtype=torch.int32 + ) + zeros = torch.zeros((32, 768), device="cuda", dtype=torch.int32) + scales = torch.rand((32, 6144), device="cuda", dtype=torch.float16) + idx = torch.empty((0,), device="cuda", dtype=torch.int32) use_exllama = True bit = 4 - opcheck(torch.ops._C.gptq_gemm, - (a, weight, 
zeros, scales, idx, use_exllama, bit)) + opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit)) diff --git a/tests/kernels/quantization/test_hadacore.py b/tests/kernels/quantization/test_hadacore.py index 127d68072e3f..3ccee9db048c 100644 --- a/tests/kernels/quantization/test_hadacore.py +++ b/tests/kernels/quantization/test_hadacore.py @@ -15,7 +15,8 @@ def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"): x = torch.eye(hidden_dim, dtype=dtype, device=device) hadamard = deterministic_hadamard_matrix( - hidden_dim, dtype=torch.float64, device="cuda") / math.sqrt(hidden_dim) + hidden_dim, dtype=torch.float64, device="cuda" + ) / math.sqrt(hidden_dim) y = ops.hadacore_transform(x.clone()) y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype) diff --git a/tests/kernels/quantization/test_int8_kernel.py b/tests/kernels/quantization/test_int8_kernel.py index f2271e6be542..0e31e9aabea8 100644 --- a/tests/kernels/quantization/test_int8_kernel.py +++ b/tests/kernels/quantization/test_int8_kernel.py @@ -11,12 +11,12 @@ from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_quant_int8) + per_token_quant_int8, +) from vllm.platforms import current_platform if current_platform.get_device_capability() < (7, 0): - pytest.skip("INT8 Triton requires CUDA 7.0 or higher", - allow_module_level=True) + pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): @@ -26,14 +26,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): B = B.to(torch.float32) assert A.shape[-1] == B.shape[-1], "Dimension mismatch" - assert B.ndim == 2 and B.is_contiguous( - ), "B must be a 2D contiguous tensor" + assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor" # Reshape input M = A.numel() // A.shape[-1] B = B.t() # Transpose weight matrix N, K = B.shape - origin_C_shape = A.shape[:-1] + (K, ) + origin_C_shape = A.shape[:-1] + (K,) A = A.reshape(M, N) # As is per-token [M, 1], Bs is per-column [1, K] @@ -43,8 +42,7 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): return C.reshape(origin_C_shape).to(output_dtype) -def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, - topk_ids): +def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, topk_ids): """This function performs fused moe with per-column int8 quantization using native torch.""" @@ -66,25 +64,22 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, mask = topk_ids == i if mask.sum(): # First MLP layer: note that a_s is now per-token - inter_out = native_w8a8_per_token_matmul(a_q[mask], - w1[i], - a_s[mask], - w1_s[i], - output_dtype=a.dtype) + inter_out = native_w8a8_per_token_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype + ) # Activation function act_out = SiluAndMul().forward_native(inter_out) # Quantize activation output with per-token act_out_q, act_out_s = per_token_quant_int8(act_out) # Second MLP layer - out[mask] = native_w8a8_per_token_matmul(act_out_q, - w2[i], - act_out_s, - w2_s[i], - output_dtype=a.dtype) + out[mask] = native_w8a8_per_token_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype + ) # Apply routing weights and sum - return (out.view(B, -1, w2.shape[1]) * - 
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) @pytest.fixture(autouse=True, scope="module") @@ -102,8 +97,10 @@ def setup_cuda(): SEEDS = [0] -@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed", - itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS)) +@pytest.mark.parametrize( + "M, N, K, E, topk, dtype, seed", + itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS), +) @torch.inference_mode() def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): torch.manual_seed(seed) @@ -130,8 +127,9 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weights, topk_ids = torch.topk(score, topk) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, - topk_weights, topk_ids) + ref_out = torch_w8a8_per_column_moe( + a, w1, w2, w1_s, w2_s, topk, topk_weights, topk_ids + ) quant_config = FusedMoEQuantConfig.make( torch.int8, @@ -151,7 +149,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): ) # Check results - rel_diff = (torch.mean( - torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / - torch.mean(torch.abs(ref_out.to(torch.float32)))) + rel_diff = torch.mean( + torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)) + ) / torch.mean(torch.abs(ref_out.to(torch.float32))) assert rel_diff < 0.05 diff --git a/tests/kernels/quantization/test_int8_quant.py b/tests/kernels/quantization/test_int8_quant.py index c1c9bf191d5b..48e947db5fa7 100644 --- a/tests/kernels/quantization/test_int8_quant.py +++ b/tests/kernels/quantization/test_int8_quant.py @@ -18,26 +18,24 @@ def opcheck_int8_quant_static(output, input, scale, azp=None): if azp is None: - opcheck(torch.ops._C.static_scaled_int8_quant, - (output, input, scale, None)) + opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, None)) else: - opcheck(torch.ops._C.static_scaled_int8_quant, - (output, input, scale, azp)) + opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, azp)) def opcheck_int8_quant_dynamic(output, input, symmetric=True): - scale = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) + scale = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) if symmetric: - opcheck(torch.ops._C.dynamic_scaled_int8_quant, - (output, input, scale, None)) + opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, None)) else: - azp = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.int32) - opcheck(torch.ops._C.dynamic_scaled_int8_quant, - (output, input, scale, azp)) + azp = torch.empty( + (input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.int32, + ) + opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, azp)) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -45,8 +43,9 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: +def test_dynamic_scaled_int8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int +) -> None: current_platform.seed_everything(seed) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 @@ -68,30 +67,31 @@ def 
test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() -def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int) -> None: +def test_dynamic_scaled_int8_azp_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int +) -> None: current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, - device="cuda") * 1000 - 300 + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300 x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True) x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True) # calculate scale and azp, and adjust the range scales = (x_token_max - x_token_min) / torch.tensor(255.0) - azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to( - torch.int32) + azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(torch.int32) - torch_out = ((x / scales).round() + azps).clamp( - int8_traits.min, int8_traits.max).to(torch.int8) - assert torch_out.min() >= int8_traits.min and torch_out.max( - ) <= int8_traits.max + torch_out = ( + ((x / scales).round() + azps) + .clamp(int8_traits.min, int8_traits.max) + .to(torch.int8) + ) + assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False) - if (not torch.allclose(scales_out, scales)): + if not torch.allclose(scales_out, scales): print(torch.argmax(torch.abs(scales_out - scales))) torch.testing.assert_close(scales_out, scales) # big atol to account for rounding errors @@ -108,17 +108,18 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("scale", SCALE) @torch.inference_mode() -def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int, - scale: float) -> None: +def test_static_scaled_int8_quant( + num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float +) -> None: current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") - out1 = (x / scale_arg).round().clamp(int8_traits.min, - int8_traits.max).to(torch.int8) + out1 = ( + (x / scale_arg).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) + ) out2, scale2, _ = scaled_int8_quant(x, scale_arg) assert scale2 is scale_arg @@ -135,24 +136,28 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("scale", SCALE) @pytest.mark.parametrize("azp", [-255, 54]) @torch.inference_mode() -def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int, - dtype: torch.dtype, seed: int, - scale: float, azp: int) -> None: +def test_static_scaled_int8_azp_quant( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + seed: int, + scale: float, + azp: int, +) -> None: current_platform.seed_everything(seed) int8_traits = torch.iinfo(torch.int8) - x = torch.rand(num_tokens, hidden_size, dtype=dtype, - device="cuda") * 1000 - 300 + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300 - out1 = ((x / scale).round() + azp).clamp(int8_traits.min, - int8_traits.max).to(torch.int8) + out1 = ( + ((x / 
scale).round() + azp) + .clamp(int8_traits.min, int8_traits.max) + .to(torch.int8) + ) scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda") azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda") - out2, scale2, azp2 = scaled_int8_quant(x, - scale_arg, - azp_arg, - symmetric=False) + out2, scale2, azp2 = scaled_int8_quant(x, scale_arg, azp_arg, symmetric=False) assert scale2 is scale_arg assert azp2 is azp_arg @@ -172,10 +177,7 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None: int32_traits = torch.iinfo(torch.int32) val = float(int32_traits.max if is_max else int32_traits.min) - x_vals = [[ - nextafter(val, inf), val + 1, val, val - 1, - nextafter(val, -inf) - ]] + x_vals = [[nextafter(val, inf), val + 1, val, val - 1, nextafter(val, -inf)]] x = torch.tensor(x_vals, dtype=torch.float32, device="cuda") # The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp) diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py index 50584f3f82d4..b32523bb85d9 100644 --- a/tests/kernels/quantization/test_machete_mm.py +++ b/tests/kernels/quantization/test_machete_mm.py @@ -15,15 +15,16 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.machete_utils import ( - query_machete_supported_group_sizes) + query_machete_supported_group_sizes, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) + pack_rows, + quantize_weights, +) from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] # TODO: in future PR refactor this and `is_quant_method_supported` in the kernel # unit tests to a common utility function. 
Currently the use of @@ -72,29 +73,38 @@ class Tensors: # Ch Scales Type, Tok Scales Type) # NOTE: None "Scale Type" means the act type is floating point # None "Output Type" means the output type is the same as the act type -TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype], - Optional[torch.dtype], bool] +TestTypeTuple = tuple[ + list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool +] TEST_TYPES = [ # GPTQ style - *(TypeConfig(act_type=a_type, - weight_type=w_type, - output_type=None, - group_scale_type=a_type, - group_zero_type=None, - channel_scale_type=None, - token_scale_type=None) - for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] - for a_type in [torch.float16, torch.bfloat16]), + *( + TypeConfig( + act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None, + ) + for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] + for a_type in [torch.float16, torch.bfloat16] + ), # AWQ style - *(TypeConfig(act_type=a_type, - weight_type=w_type, - output_type=None, - group_scale_type=a_type, - group_zero_type=a_type, - channel_scale_type=None, - token_scale_type=None) - for w_type in [scalar_types.uint4, scalar_types.uint8] - for a_type in [torch.float16, torch.bfloat16]), + *( + TypeConfig( + act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=a_type, + channel_scale_type=None, + token_scale_type=None, + ) + for w_type in [scalar_types.uint4, scalar_types.uint8] + for a_type in [torch.float16, torch.bfloat16] + ), # # QQQ style # *(TypeConfig(act_type=torch.int8, # weight_type=scalar_types.uint4b8, @@ -133,17 +143,18 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): return zps if zps is None else -1 * s * (zps.to(s.dtype)) -def group_size_valid(shape: tuple[int, int, int], - group_size: Optional[int]) -> bool: +def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool: return group_size is None or group_size == -1 or shape[2] % group_size == 0 -def machete_quantize_and_pack(atype: torch.dtype, - w: torch.Tensor, - wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], - zero_points: bool = False): +def machete_quantize_and_pack( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False, +): assert wtype.is_integer(), "TODO: support floating point weights" w_ref, w_q, w_s, w_zp = quantize_weights( @@ -152,7 +163,8 @@ def machete_quantize_and_pack(atype: torch.dtype, group_size=group_size, zero_points=zero_points, # to match how the kernel applies zps - ref_zero_points_after_scales=True) + ref_zero_points_after_scales=True, + ) w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) w_q = w_q.t().contiguous().t() # convert to col major @@ -163,15 +175,18 @@ def machete_quantize_and_pack(atype: torch.dtype, return w_ref, w_q_machete, w_s, w_zp -def create_test_tensors(shape: tuple[int, int, int], - types: TypeConfig, - group_size: Optional[int], - subset_stride_factor: Optional[int] = None) -> Tensors: +def create_test_tensors( + shape: tuple[int, int, int], + types: TypeConfig, + group_size: Optional[int], + subset_stride_factor: Optional[int] = None, +) -> Tensors: m, n, k = shape factor = subset_stride_factor or 1 - print("create_test_tensors, shape:", shape, "types:", types, "group_size:", - group_size) 
+ print( + "create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size + ) a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2) w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1) @@ -186,8 +201,13 @@ def create_test_tensors(shape: tuple[int, int, int], w = w.to(torch.float16) w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - a.dtype, w, types.weight_type, types.group_scale_type, group_size, - types.group_zero_type is not None) + a.dtype, + w, + types.weight_type, + types.group_scale_type, + group_size, + types.group_zero_type is not None, + ) if not a.dtype.is_floating_point: aiinfo = torch.iinfo(a.dtype) @@ -196,35 +216,47 @@ def create_test_tensors(shape: tuple[int, int, int], a_ref = a.to(torch.float32) w_ref = w_ref.to(torch.float32) - w_ch_s = None if types.channel_scale_type is None else\ - rand_data((n,), types.channel_scale_type) - w_tok_s = None if types.token_scale_type is None else\ - rand_data((m,), types.token_scale_type) + w_ch_s = ( + None + if types.channel_scale_type is None + else rand_data((n,), types.channel_scale_type) + ) + w_tok_s = ( + None + if types.token_scale_type is None + else rand_data((m,), types.token_scale_type) + ) - return Tensors(w_ref=w_ref, - a_ref=a_ref, - a=a, - w_q=w_q_packed, - w_g_s=w_s, - w_g_zp=maybe_convert_zeropoints(w_zp, w_s), - w_ch_s=w_ch_s, - w_tok_s=w_tok_s) + return Tensors( + w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_g_zp=maybe_convert_zeropoints(w_zp, w_s), + w_ch_s=w_ch_s, + w_tok_s=w_tok_s, + ) # None stype means scales use the same dtype as a -def machete_mm_test_helper(types: TypeConfig, - tensors: Tensors, - group_size: Optional[int] = None, - schedule: Optional[str] = None): +def machete_mm_test_helper( + types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None, +): output_ref = torch.matmul(tensors.a_ref, tensors.w_ref) output_ref_type = output_ref.dtype if tensors.w_ch_s is not None: - output_ref = (output_ref.to(tensors.w_ch_s.dtype) * - tensors.w_ch_s.unsqueeze(0)).to(output_ref_type) + output_ref = ( + output_ref.to(tensors.w_ch_s.dtype) * tensors.w_ch_s.unsqueeze(0) + ).to(output_ref_type) if tensors.w_tok_s is not None: - output_ref = (output_ref.to(tensors.w_tok_s.dtype) * - tensors.w_tok_s.unsqueeze(1)).to(output_ref_type) + output_ref = ( + output_ref.to(tensors.w_tok_s.dtype) * tensors.w_tok_s.unsqueeze(1) + ).to(output_ref_type) output = ops.machete_mm( a=tensors.a, @@ -245,23 +277,23 @@ def machete_mm_test_helper(types: TypeConfig, # Relax atol as our reduction dim becomes larger (more rounding error) # Relax atol when we have zeropoints since the way machete applies # zeropoints (after scales) causes noise around 0 - atol = 1 if tensors.w_g_zp is not None\ + atol = ( + 1 + if tensors.w_g_zp is not None else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1) + ) rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1 - torch.testing.assert_close(output, - output_ref.to(output.dtype), - rtol=rtol, - atol=atol) + torch.testing.assert_close( + output, output_ref.to(output.dtype), rtol=rtol, atol=atol + ) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." 
+) +@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_all_schedules(shape, types: TypeConfig): - group_sizes: list[Optional[int]] = [] if types.group_scale_type is None: group_sizes = [None] @@ -275,20 +307,20 @@ def test_machete_all_schedules(shape, types: TypeConfig): tensors = create_test_tensors(shape, types, group_size) print(f"MNK = {shape}") for schedule in ops.machete_supported_schedules( - types.act_type, - types.weight_type, - group_scales_type=types.group_scale_type, - group_zeros_type=types.group_scale_type, - out_type=types.output_type): + types.act_type, + types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_scale_type, + out_type=types.output_type, + ): print(f"Testing schedule {schedule}") machete_mm_test_helper(types, tensors, group_size, schedule) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) +@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x)) @pytest.mark.parametrize("types", TEST_TYPES) def test_machete_heuristic(shape, types: TypeConfig): group_sizes: list[Optional[int]] = [] @@ -306,19 +338,22 @@ def test_machete_heuristic(shape, types: TypeConfig): # Test working on other devices -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_machete_devices(device: str): group_size = 128 - type_config = TypeConfig(act_type=torch.float16, - weight_type=scalar_types.uint4b8, - output_type=None, - group_scale_type=torch.float16, - group_zero_type=None, - channel_scale_type=None, - token_scale_type=None) + type_config = TypeConfig( + act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None, + ) tensors = create_test_tensors((512, 4096, 4096), type_config, group_size) @@ -331,29 +366,30 @@ def test_machete_devices(device: str): # Test working with a subset of A and B -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." 
+) def test_machete_subset(): group_size = 128 - type_config = TypeConfig(act_type=torch.float16, - weight_type=scalar_types.uint4b8, - output_type=None, - group_scale_type=torch.float16, - group_zero_type=None, - channel_scale_type=None, - token_scale_type=None) - - tensors = create_test_tensors((512, 4096, 4096), - type_config, - group_size, - subset_stride_factor=2) + type_config = TypeConfig( + act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None, + ) + + tensors = create_test_tensors( + (512, 4096, 4096), type_config, group_size, subset_stride_factor=2 + ) machete_mm_test_helper(type_config, tensors, group_size) # Test to make sure cuda graphs work class MacheteLayer(torch.nn.Module): - def __init__(self, **kwargs): super().__init__() self.kwargs = kwargs @@ -362,8 +398,9 @@ def forward(self, a): return ops.machete_mm(a=a, **self.kwargs) -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") +@pytest.mark.skipif( + not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type." +) def test_machete_cuda_graph(): m, n, k = 512, 4096, 4096 @@ -375,7 +412,8 @@ def test_machete_cuda_graph(): zero_points = False w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - a.dtype, b, wtype, stype, group_size, zero_points) + a.dtype, b, wtype, stype, group_size, zero_points + ) # Construct a trivial model with a single layer that calls a machete kernel model = MacheteLayer( diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 0be020085bfa..0833115fcf30 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -4,6 +4,7 @@ Run `pytest tests/kernels/quantization/test_marlin_gemm.py`. 
""" + import pytest import torch @@ -11,24 +12,44 @@ from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) + GPTQ_MARLIN_24_MAX_PARALLEL, + GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, + GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, - query_marlin_supported_quant_types) + MARLIN_SUPPORTED_GROUP_SIZES, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + query_marlin_supported_quant_types, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like, - rand_marlin_weight_nvfp4_like) + FP4_MARLIN_SUPPORTED_GROUP_SIZES, + rand_marlin_weight_mxfp4_like, + rand_marlin_weight_nvfp4_like, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - marlin_quant_fp8_torch) + marlin_quant_fp8_torch, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize, - marlin_weights) + MarlinWorkspace, + awq_marlin_quantize, + get_weight_perm, + marlin_quantize, + marlin_weights, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( - marlin_24_quantize) + marlin_24_quantize, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) + awq_pack, + gptq_pack, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) from vllm.scalar_type import scalar_types ACT_ORDER_OPTS = [False, True] @@ -56,24 +77,27 @@ def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) def rand_data(shape, dtype=torch.float16): return torch.randn(shape, dtype=dtype, device="cuda") -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False, False)) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, - act_order, mnk_factors): +def test_gptq_marlin_repack( + k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors +): m_factor, n_factor, k_factor = mnk_factors size_k = k_chunk * k_factor @@ -96,7 +120,8 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - b_weight, quant_type, group_size, act_order) + b_weight, 
quant_type, group_size, act_order + ) # Pack to GPTQ format q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) @@ -109,11 +134,14 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, # Pack to Marlin format weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, - weight_perm) + marlin_q_w_1 = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, weight_perm + ) - opcheck(torch.ops._C.gptq_marlin_repack, - (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits)) + opcheck( + torch.ops._C.gptq_marlin_repack, + (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits), + ) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -128,16 +156,16 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(True)) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, - mnk_factors): +def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_k = k_chunk * k_factor @@ -152,21 +180,22 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, b_weight = rand_data((size_k, size_n)) # Quantize - w_ref, q_w, s, zp = quantize_weights(b_weight, - quant_type, - group_size, - zero_points=True) + w_ref, q_w, s, zp = quantize_weights( + b_weight, quant_type, group_size, zero_points=True + ) # Pack to AWQ format q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) # Pack to Marlin format weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, - weight_perm) + marlin_q_w_1 = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, weight_perm + ) - opcheck(torch.ops._C.awq_marlin_repack, - (q_w_awq, size_k, size_n, quant_type.size_bits)) + opcheck( + torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits) + ) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.awq_marlin_repack( @@ -180,23 +209,34 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) @pytest.mark.parametrize( - "group_size", - set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)) + "group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES) +) 
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) @pytest.mark.parametrize("dtype", DTYPES) -def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, - mnk_factors, act_order, is_k_full, use_atomic_add, - use_fp32_reduce, dtype): +def test_gptq_marlin_gemm( + k_chunk, + n_chunk, + quant_type, + group_size, + mnk_factors, + act_order, + is_k_full, + use_atomic_add, + use_fp32_reduce, + dtype, +): m_factor, n_factor, k_factor = mnk_factors has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] @@ -225,11 +265,13 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, return if group_size == 16: - w_ref, marlin_q_w, marlin_s, marlin_s2 = \ - rand_marlin_weight_nvfp4_like(b_weight.T, group_size) + w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like( + b_weight.T, group_size + ) else: - w_ref, marlin_q_w, marlin_s = \ - rand_marlin_weight_mxfp4_like(b_weight.T, group_size) + w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like( + b_weight.T, group_size + ) marlin_s2 = None g_idx = None @@ -240,8 +282,7 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, return if act_order: return - w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch( - b_weight.T, group_size) + w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size) g_idx = None sort_indices = None marlin_zp = None @@ -250,7 +291,8 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, if group_size == 16: return w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, quant_type, group_size) + b_weight, quant_type, group_size + ) g_idx = None sort_indices = None marlin_s2 = None @@ -258,18 +300,37 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, if group_size == 16: return w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, act_order) + b_weight, quant_type, group_size, act_order + ) marlin_zp = None marlin_s2 = None workspace = marlin_make_workspace_new(w_ref.device) - opcheck(torch.ops._C.gptq_marlin_gemm, - (a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp, - g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0], - b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, - use_fp32_reduce, False), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck( + torch.ops._C.gptq_marlin_gemm, + ( + a_input, + None, + marlin_q_w, + None, + marlin_s, + marlin_s2, + marlin_zp, + g_idx, + sort_indices, + workspace, + quant_type.id, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full, + use_atomic_add, + use_fp32_reduce, + False, + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) output = ops.gptq_marlin_gemm( a_input, @@ -302,23 +363,40 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size, # TODO: find better way to test this? 
@torch.compile(fullgraph=True) -def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s, scratch, quant_type, size_m, size_n, - size_k): - return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s, scratch, quant_type, size_m, - size_n, size_k) +def marlin_24_gemm_tester( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + scratch, + quant_type, + size_m, + size_n, + size_k, +): + return ops.gptq_marlin_24_gemm( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + scratch, + quant_type, + size_m, + size_n, + size_k, + ) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) @pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, - mnk_factors): +def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors): m_factor, n_factor, k_factor = mnk_factors size_m = m_factor @@ -328,19 +406,31 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size) + (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize( + b_weight, quant_type, group_size + ) - workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_MAX_PARALLEL) + workspace_24 = MarlinWorkspace( + size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL + ) output_ref = torch.matmul(a_input, w_24_ref) - opcheck(torch.ops._C.gptq_marlin_24_gemm, - (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, - workspace_24.scratch, quant_type.id, a_input.shape[0], - b_weight.shape[1], a_input.shape[1]), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck( + torch.ops._C.gptq_marlin_24_gemm, + ( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + workspace_24.scratch, + quant_type.id, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + ), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) output = marlin_24_gemm_tester( a_input, @@ -361,8 +451,10 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="Marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES) @@ -386,22 +478,22 @@ def test_hqq_marlin_gemm( a_input = rand_data((size_m, size_k)) dev = a_input.device - b_weight = torch.randint(0, - 10, (size_n, size_k), - dtype=torch.uint8, - device=dev) + b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev) scale = rand_data((size_n, size_k // group_size)) zero = rand_data((size_n, size_k // group_size)) 
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n) sort_indices = torch.empty(0, dtype=torch.int, device=dev) - marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, - 4).to(dev) - marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n, - group_size).to(dev) - marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n, - group_size).to(dev) + marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to( + dev + ) + marlin_s = marlin_permute_scales( + scale.transpose(1, 0), size_k, size_n, group_size + ).to(dev) + marlin_zp = marlin_permute_scales( + zero.transpose(1, 0), size_k, size_n, group_size + ).to(dev) g_idx = marlin_make_empty_g_idx(dev) g_idx_sort_indices = marlin_make_empty_g_idx(dev) @@ -433,8 +525,7 @@ def test_hqq_marlin_gemm( s_flat = scale.reshape(-1, 1) dequant = (b_flat - zp_flat) * s_flat - output_ref = torch.matmul(a_input, - dequant.reshape(b_weight.shape).transpose(1, 0)) + output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0)) torch.cuda.synchronize() @@ -451,11 +542,12 @@ def test_marlin_gemm_subset_input(): big_m = size_m * 2 big_k = size_k * 2 - a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8] + a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8] b_weight = rand_data((size_k, size_n)) w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, False) + b_weight, quant_type, group_size, False + ) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) workspace = marlin_make_workspace_new(a_input.device) @@ -497,12 +589,13 @@ def test_marlin_gemm_with_bias(size_m): size_k, size_n = 1024, 2048 a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - b_bias = rand_data((size_n, )) * 10 + b_bias = rand_data((size_n,)) * 10 marlin_bias = marlin_permute_bias(b_bias) w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, False) + b_weight, quant_type, group_size, False + ) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) workspace = marlin_make_workspace_new(a_input.device) diff --git a/tests/kernels/quantization/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py index 3a8f4c17598c..e9b091d06697 100644 --- a/tests/kernels/quantization/test_nvfp4_quant.py +++ b/tests/kernels/quantization/test_nvfp4_quant.py @@ -8,15 +8,27 @@ from vllm.scalar_type import scalar_types if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) DTYPES = [torch.float16, torch.bfloat16] SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)] -PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48), - (90, 128), (150, 128), (150, 48), (90, 80)] +PAD_SHAPES = [ + (90, 64), + (150, 64), + (128, 48), + (128, 80), + (150, 80), + (90, 48), + (90, 128), + (150, 128), + (150, 48), + (90, 80), +] SEEDS = [42] -CUDA_DEVICES = ['cuda:0'] +CUDA_DEVICES = ["cuda:0"] FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max @@ -31,7 +43,22 @@ # 0001 -> 0.5 # 0000 -> 0 E2M1_TO_FLOAT32 = [ - 0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6. 
+ 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + 0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ] BLOCK_SIZE = 16 @@ -74,8 +101,7 @@ def ref_nvfp4_quant(x, global_scale): assert x.ndim == 2 m, n = x.shape x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE)) - vec_max = torch.max(torch.abs(x), dim=-1, - keepdim=True)[0].to(torch.float32) + vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) scale = scale.to(torch.float8_e4m3fn).to(torch.float32) output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) @@ -131,7 +157,7 @@ def test_quantize_to_fp4( def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: dtype = torch.float16 current_platform.seed_everything(42) - torch.set_default_device('cuda:0') + torch.set_default_device("cuda:0") m, n = pad_shape diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index 67e041f2b71c..434564737c88 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -2,15 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch -from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype from vllm import _custom_ops as ops from vllm.platforms import current_platform if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) DTYPES = [torch.float16, torch.bfloat16] # m, n, k @@ -19,26 +20,31 @@ SHAPES.extend(PAD_SHAPES) SEEDS = [42] -CUDA_DEVICES = ['cuda:0'] +CUDA_DEVICES = ["cuda:0"] -def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale, - m, n, dtype, block_size, device): +def get_ref_results( + a_fp4, + b_fp4, + a_sf, + b_sf, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, +): _, m_k = a_fp4.shape _, n_k = b_fp4.shape - assert (m_k == n_k) - a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, - a_sf, - a_global_scale, - dtype=dtype, - device=device, - block_size=block_size) - b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, - b_sf, - b_global_scale, - dtype=dtype, - device=device, - block_size=block_size) + assert m_k == n_k + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size + ) + b_in_dtype = dequantize_nvfp4_to_dtype( + b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size + ) return torch.matmul(a_in_dtype, b_in_dtype.t()) @@ -60,25 +66,34 @@ def test_nvfp4_gemm( a_dtype = torch.randn((m, k), dtype=dtype, device=device) b_dtype = torch.randn((n, k), dtype=dtype, device=device) - a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32) - b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32) - alpha = 1. 
/ (a_global_scale * b_global_scale) + a_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1) + ).to(torch.float32) + b_global_scale = ( + (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1) + ).to(torch.float32) + alpha = 1.0 / (a_global_scale * b_global_scale) # ops.scaled_fp4_quant returns swizzled scales, while weights # from checkpoints are in linear scales. a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale) b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) # get_ref_results unswizzles the scales internally. - expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved, - b_scale_interleaved, a_global_scale, - b_global_scale, m, n, dtype, block_size, - device) - out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, - b_scale_interleaved, alpha, dtype) + expected_out = get_ref_results( + a_fp4, + b_fp4, + a_scale_interleaved, + b_scale_interleaved, + a_global_scale, + b_global_scale, + m, + n, + dtype, + block_size, + device, + ) + out = ops.cutlass_scaled_fp4_mm( + a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype + ) - torch.testing.assert_close(out, - expected_out.to(dtype=dtype), - atol=1e-1, - rtol=1e-1) + torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1) diff --git a/tests/kernels/quantization/test_per_token_group_quant.py b/tests/kernels/quantization/test_per_token_group_quant.py index 07f17d1efe64..7a6500454530 100644 --- a/tests/kernels/quantization/test_per_token_group_quant.py +++ b/tests/kernels/quantization/test_per_token_group_quant.py @@ -13,15 +13,15 @@ @pytest.mark.parametrize("scale_ue8m0", [False, True]) @pytest.mark.parametrize("group_size", [64, 128]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_per_token_group_quant_fp8(shape, column_major: bool, - scale_ue8m0: bool, group_size: int): +def test_per_token_group_quant_fp8( + shape, column_major: bool, scale_ue8m0: bool, group_size: int +): device = "cuda" torch.manual_seed(42) num_tokens, hidden_dim = shape - x = (torch.randn( - (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8) + x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8 # cuda path out_q, scale = fp8_utils.per_token_group_quant_fp8( @@ -53,8 +53,7 @@ def test_per_token_group_quant_int8(shape, group_size: int): torch.manual_seed(42) num_tokens, hidden_dim = shape - x = (torch.randn( - (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8) + x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8 # cuda path out_q, scale = int8_utils.per_token_group_quant_int8( diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 6de5fc9c5601..dc6557b93f05 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -63,12 +63,11 @@ @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") @torch.inference_mode() def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): torch.manual_seed(seed) - #TODO: Zero-centering the inputs causes errors for LLMM1! 
+ # TODO: Zero-centering the inputs causes errors for LLMM1! # Without that the numbers quickly saturate, and may # be giving false matches. A = torch.rand(n, k, dtype=dtype, device="cuda") @@ -83,14 +82,13 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) cu_count = current_platform.get_cu_count() - A = torch.rand(n, k, dtype=dtype, device="cuda") - .5 - B = torch.rand(m, k, dtype=dtype, device="cuda") - .5 + A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5 + B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5 ref_out = torch.nn.functional.linear(A, B) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count) @@ -101,16 +99,15 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) cu_count = current_platform.get_cu_count() xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas - A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier - B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier - BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5 + A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier + BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 ref_out = torch.nn.functional.linear(A, B, BIAS) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) @@ -121,16 +118,15 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) cu_count = current_platform.get_cu_count() xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas - A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier - B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier - BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5 + A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier + BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5 ref_out = torch.nn.functional.linear(A, B, BIAS) out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) @@ -143,7 +139,8 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif( not (current_platform.is_rocm() and current_platform.supports_fp8()), - reason="only test for rocm fp8") + reason="only test for rocm fp8", +) def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) @@ 
-153,13 +150,10 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) - ref_out = torch._scaled_mm(A, - B.t(), - out_dtype=dtype, - scale_a=scale_a, - scale_b=scale_b) - out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, - current_platform.get_cu_count()) + ref_out = torch._scaled_mm( + A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b + ) + out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count()) assert torch.allclose(out, ref_out, rtol=0.01) @@ -169,25 +163,24 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif( not (current_platform.is_rocm() and current_platform.supports_fp8()), - reason="only test for rocm fp8") + reason="only test for rocm fp8", +) def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas - A = (torch.rand(n, k, device="cuda") - .5) * xavier - B = (torch.rand(m, k, device="cuda") - .5) * xavier - BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5 + A = (torch.rand(n, k, device="cuda") - 0.5) * xavier + B = (torch.rand(m, k, device="cuda") - 0.5) * xavier + BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) - ref_out = torch._scaled_mm(A, - B.t(), - out_dtype=dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=BIAS) - out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, - current_platform.get_cu_count(), BIAS) + ref_out = torch._scaled_mm( + A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS + ) + out = ops.wvSplitKQ( + B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS + ) assert torch.allclose(out, ref_out, rtol=0.01) diff --git a/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py b/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py index a40d0c4ef122..4617464a3978 100644 --- a/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py +++ b/tests/kernels/quantization/test_silu_mul_nvfp4_quant.py @@ -3,16 +3,20 @@ import pytest import torch -from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, - FLOAT8_E4M3_MAX, - dequantize_nvfp4_to_dtype) +from tests.kernels.quantization.nvfp4_utils import ( + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype, +) from vllm._custom_ops import scaled_fp4_quant from vllm.model_executor.layers.activation import SiluAndMul from vllm.platforms import current_platform if not current_platform.has_device_capability(100): - pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", - allow_module_level=True) + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) FP4_DTYPE = torch.uint8 FP8_DTYPE = current_platform.fp8_dtype() @@ -30,24 +34,24 @@ def test_silu_mul_nvfp4_quant( shape: tuple[int, int], ) -> None: current_platform.seed_everything(42) - device = 'cuda:0' + device = "cuda:0" torch.set_default_device(device) x = torch.randn(shape, dtype=dtype) # ref op ref_output = SiluAndMul().forward_native(x) - ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / - torch.abs(ref_output).max().to(torch.float32)) - ref_output_quant, ref_block_scale = scaled_fp4_quant( - ref_output, ref_global_scale) + ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs( + ref_output + 
).max().to(torch.float32) + ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale) # fused op fused_output_quant = torch.empty_like(ref_output_quant) fused_block_scale = torch.empty_like(ref_block_scale) - torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant, - fused_block_scale, x, - ref_global_scale) + torch.ops._C.silu_and_mul_nvfp4_quant( + fused_output_quant, fused_block_scale, x, ref_global_scale + ) # check dtype assert ref_output_quant.dtype == FP4_DTYPE @@ -59,17 +63,14 @@ def test_silu_mul_nvfp4_quant( assert ref_block_scale.shape == fused_block_scale.shape # check dequantized output - ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant, - ref_block_scale, - ref_global_scale, dtype, - device) - fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant, - fused_block_scale, - ref_global_scale, dtype, - device) + ref_output_dequant = dequantize_nvfp4_to_dtype( + ref_output_quant, ref_block_scale, ref_global_scale, dtype, device + ) + fused_output_dequant = dequantize_nvfp4_to_dtype( + fused_output_quant, fused_block_scale, ref_global_scale, dtype, device + ) atol, rtol = 3e-1, 3e-1 - torch.testing.assert_close(ref_output_dequant, - fused_output_dequant, - atol=atol, - rtol=rtol) + torch.testing.assert_close( + ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol + ) diff --git a/tests/kernels/quantization/test_triton_scaled_mm.py b/tests/kernels/quantization/test_triton_scaled_mm.py index d8cfb5710dba..1026332d99f8 100644 --- a/tests/kernels/quantization/test_triton_scaled_mm.py +++ b/tests/kernels/quantization/test_triton_scaled_mm.py @@ -4,6 +4,7 @@ Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`. """ + import importlib from typing import Optional @@ -15,17 +16,19 @@ device = "cuda" triton_scaled_mm_module = importlib.import_module( - "vllm.model_executor.layers.quantization.compressed_tensors." - "triton_scaled_mm") + "vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm" +) triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm -def torch_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def torch_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: out = torch.mm(a.to(torch.float32), b.to(torch.float32)) out = scale_a * out out = scale_b.T * out @@ -44,20 +47,22 @@ def get_8bit_types(): # This test is to check regressions for int8 support on ROCm. 
-@pytest.mark.parametrize("model_path", [ - "neuralmagic/Llama-3.2-1B-quantized.w8a8", -]) +@pytest.mark.parametrize( + "model_path", + [ + "neuralmagic/Llama-3.2-1B-quantized.w8a8", + ], +) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [10]) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="Should only run on ROCm") -def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path, - max_tokens, num_logprobs): +@pytest.mark.skipif(not current_platform.is_rocm(), reason="Should only run on ROCm") +def test_rocm_compressed_tensors_w8a8( + vllm_runner, example_prompts, model_path, max_tokens, num_logprobs +): dtype = "bfloat16" with vllm_runner(model_path, dtype=dtype) as vllm_model: - vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, - num_logprobs) + vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs) MNK_FACTORS = [ @@ -76,10 +81,10 @@ def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path, @pytest.mark.parametrize("use_scalar_scale_a", [True, False]) @pytest.mark.parametrize("use_scalar_scale_b", [True, False]) @pytest.mark.parametrize("use_bias", [True, False]) -def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, - use_scalar_scale_b, use_bias): - is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t - ).is_floating_point() +def test_scaled_mm( + M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias +): + is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t).is_floating_point() current_platform.seed_everything(0) @@ -93,10 +98,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, # # So, the values here are kept small enough to avoid this situation. 
if is_floating_point_type(in_dtype): - a = (0.25 * torch.rand( - (M, K), dtype=torch.float32, device=device)).to(in_dtype) - b = (0.25 * torch.rand( - (K, N), dtype=torch.float32, device=device)).to(in_dtype) + a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype) + b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype) else: a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device) b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device) @@ -113,7 +116,7 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a, bias = None if use_bias: - bias = torch.rand((N, ), device=device, dtype=out_dtype) + bias = torch.rand((N,), device=device, dtype=out_dtype) c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) diff --git a/tests/kernels/test_apply_repetition_penalties.py b/tests/kernels/test_apply_repetition_penalties.py index 90380b872d6c..a4619f5846b1 100644 --- a/tests/kernels/test_apply_repetition_penalties.py +++ b/tests/kernels/test_apply_repetition_penalties.py @@ -4,8 +4,10 @@ import torch from tests.kernels.utils import opcheck -from vllm._custom_ops import (apply_repetition_penalties_cuda, - apply_repetition_penalties_torch) +from vllm._custom_ops import ( + apply_repetition_penalties_cuda, + apply_repetition_penalties_torch, +) from vllm.platforms import current_platform NUM_SEQS = [1, 2, 3, 4, 8, 13, 17, 32, 37, 256, 1023, 1024, 1025] @@ -21,8 +23,9 @@ @pytest.mark.parametrize("repetition_penalty", REPETITION_PENALTY_VALUES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test for checking CUDA kernel") +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test for checking CUDA kernel" +) @torch.inference_mode() def test_apply_repetition_penalties( num_seqs: int, @@ -32,7 +35,7 @@ def test_apply_repetition_penalties( seed: int, ) -> None: """ - Test the apply_repetition_penalties custom op + Test the apply_repetition_penalties custom op against a reference implementation. 
""" current_platform.seed_everything(seed) @@ -46,39 +49,40 @@ def test_apply_repetition_penalties( output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool) # Mark some tokens as repeated in prompt and output - prompt_indices = torch.randint(0, vocab_size, - (num_seqs, max(1, vocab_size // 200))) - output_indices = torch.randint(0, vocab_size, - (num_seqs, max(1, vocab_size // 200))) + prompt_indices = torch.randint(0, vocab_size, (num_seqs, max(1, vocab_size // 200))) + output_indices = torch.randint(0, vocab_size, (num_seqs, max(1, vocab_size // 200))) for i in range(num_seqs): prompt_mask[i, prompt_indices[i]] = True output_mask[i, output_indices[i]] = True # Create repetition penalties tensor - repetition_penalties = torch.full((num_seqs, ), - repetition_penalty, - dtype=dtype) + repetition_penalties = torch.full((num_seqs,), repetition_penalty, dtype=dtype) # Run all three implementations logits_torch = logits.clone() logits_cuda = logits.clone() - apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask, - repetition_penalties) - apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_torch( + logits_torch, prompt_mask, output_mask, repetition_penalties + ) + apply_repetition_penalties_cuda( + logits_cuda, prompt_mask, output_mask, repetition_penalties + ) # Compare all outputs to reference torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3) # Test the operator by applying the opcheck utility - opcheck(torch.ops._C.apply_repetition_penalties_, - (logits.clone(), prompt_mask, output_mask, repetition_penalties)) + opcheck( + torch.ops._C.apply_repetition_penalties_, + (logits.clone(), prompt_mask, output_mask, repetition_penalties), + ) -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test for checking CUDA kernel") +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test for checking CUDA kernel" +) @torch.inference_mode() def test_apply_repetition_penalties_zero_seqs() -> None: """ @@ -104,22 +108,24 @@ def test_apply_repetition_penalties_zero_seqs() -> None: # No tokens to mark as repeated since num_seqs=0 # Create repetition penalties tensor - repetition_penalties = torch.full((num_seqs, ), - repetition_penalty, - dtype=dtype) + repetition_penalties = torch.full((num_seqs,), repetition_penalty, dtype=dtype) # Run all three implementations logits_torch = logits.clone() logits_cuda = logits.clone() - apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask, - repetition_penalties) - apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_torch( + logits_torch, prompt_mask, output_mask, repetition_penalties + ) + apply_repetition_penalties_cuda( + logits_cuda, prompt_mask, output_mask, repetition_penalties + ) # Compare all outputs to reference torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3) # Test the operator by applying the opcheck utility - opcheck(torch.ops._C.apply_repetition_penalties_, - (logits.clone(), prompt_mask, output_mask, repetition_penalties)) + opcheck( + torch.ops._C.apply_repetition_penalties_, + (logits.clone(), prompt_mask, output_mask, repetition_penalties), + ) diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py index 39753c0cc15b..87002c72f6e1 100644 --- a/tests/kernels/test_flex_attention.py +++ b/tests/kernels/test_flex_attention.py @@ -9,11 +9,13 @@ import torch 
from packaging import version -from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, - create_standard_kv_cache_spec, - create_vllm_config) -from vllm.v1.attention.backends.flex_attention import ( - FlexAttentionMetadataBuilder) +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, +) +from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder from ..models.utils import check_embeddings_close, check_logprobs_close @@ -57,26 +59,32 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") set_seed(seed) - with vllm_runner(model_name, - runner="generate", - tensor_parallel_size=1, - num_gpu_blocks_override=128, - enforce_eager=True) as llm_flex: + with vllm_runner( + model_name, + runner="generate", + tensor_parallel_size=1, + num_gpu_blocks_override=128, + enforce_eager=True, + ) as llm_flex: output_flex = llm_flex.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs) + prompts, max_tokens, num_logprobs + ) # Run with default backend with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") set_seed(seed) - with vllm_runner(model_name, - runner="generate", - tensor_parallel_size=1, - num_gpu_blocks_override=128, - enforce_eager=True, - gpu_memory_utilization=0.85) as llm_default: + with vllm_runner( + model_name, + runner="generate", + tensor_parallel_size=1, + num_gpu_blocks_override=128, + enforce_eager=True, + gpu_memory_utilization=0.85, + ) as llm_default: output_default = llm_default.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs) + prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=output_flex, @@ -107,23 +115,27 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION") - with vllm_runner(model_name, - runner="pooling", - dtype=torch.bfloat16, - tensor_parallel_size=1, - max_model_len=100, - enforce_eager=True) as llm_flex: + with vllm_runner( + model_name, + runner="pooling", + dtype=torch.bfloat16, + tensor_parallel_size=1, + max_model_len=100, + enforce_eager=True, + ) as llm_flex: flex_outputs = llm_flex.embed(prompts) # Run with default backend with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - with vllm_runner(model_name, - runner="pooling", - dtype=torch.bfloat16, - tensor_parallel_size=1, - max_model_len=100, - enforce_eager=True) as llm_default: + with vllm_runner( + model_name, + runner="pooling", + dtype=torch.bfloat16, + tensor_parallel_size=1, + max_model_len=100, + enforce_eager=True, + ) as llm_default: default_outputs = llm_default.embed(prompts) check_embeddings_close( @@ -147,27 +159,29 @@ def test_block_mask_direct_vs_slow_path(): """ device = torch.device("cuda") - vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B", - block_size=16, - max_model_len=1024) + vllm_config = create_vllm_config( + model_name="meta-llama/Meta-Llama-3-8B", block_size=16, max_model_len=1024 + ) kv_cache_spec = create_standard_kv_cache_spec(vllm_config) # Use a mixed batch that will create groups spanning multiple sequences - batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256], - query_lens=[33, 5, 32, 64], - name="test_mixed_batch") + batch_spec = BatchSpec( + seq_lens=[35, 64, 128, 256], query_lens=[33, 5, 32, 64], name="test_mixed_batch" + ) common_attn_metadata = 
create_common_attn_metadata( - batch_spec, vllm_config.cache_config.block_size, device) + batch_spec, vllm_config.cache_config.block_size, device + ) - builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, - device) + builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device) - metadata_direct = builder.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + metadata_direct = builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) builder.direct_build = False - metadata_slow = builder.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + metadata_slow = builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) assert metadata_direct.block_mask is not None assert metadata_slow.block_mask is not None @@ -184,20 +198,20 @@ def test_block_mask_direct_vs_slow_path(): missing_details = [] for group_idx in range(num_groups): - direct_blocks = set( - direct_indices[group_idx, :direct_num[group_idx]].tolist()) - slow_blocks = set( - slow_indices[group_idx, :slow_num[group_idx]].tolist()) + direct_blocks = set(direct_indices[group_idx, : direct_num[group_idx]].tolist()) + slow_blocks = set(slow_indices[group_idx, : slow_num[group_idx]].tolist()) missing_blocks = slow_blocks - direct_blocks if missing_blocks: all_contained = False missing_details.append( - f"Group {group_idx}: missing {sorted(missing_blocks)}") + f"Group {group_idx}: missing {sorted(missing_blocks)}" + ) assert all_contained, ( - "Direct path is missing blocks required by slow path:\n" + - "\n".join(missing_details)) + "Direct path is missing blocks required by slow path:\n" + + "\n".join(missing_details) + ) if __name__ == "__main__": diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py index 803453a20d81..c79e6105e69f 100644 --- a/tests/kernels/test_fused_quant_activation.py +++ b/tests/kernels/test_fused_quant_activation.py @@ -13,13 +13,12 @@ NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] +CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] -def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, - scale: torch.Tensor) -> torch.Tensor: +def ref_impl( + silu_and_mul: SiluAndMul, x: torch.Tensor, scale: torch.Tensor +) -> torch.Tensor: silu_and_mul_out = silu_and_mul.forward_native(x) out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale) return out @@ -27,9 +26,7 @@ def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: out_shape = (x.shape[0], x.shape[1] // 2) - out = torch.empty(out_shape, - dtype=current_platform.fp8_dtype(), - device=x.device) + out = torch.empty(out_shape, dtype=current_platform.fp8_dtype(), device=x.device) torch.ops._C.silu_and_mul_quant(out, x, scale) return out @@ -57,7 +54,7 @@ def test_silu_and_mul( layer = SiluAndMul() # Make inputs - scale = (torch.randn((1), device=device, dtype=torch.float32)) + scale = torch.randn((1), device=device, dtype=torch.float32) x = torch.randn(num_tokens, hidden_size, dtype=dtype) ref_out = ref_impl(layer, x, scale) @@ -66,6 +63,7 @@ def test_silu_and_mul( assert ref_out.dtype == quant_dtype assert ops_out.dtype == quant_dtype assert ref_out.shape == ops_out.shape - assert 
torch.allclose(ref_out.to(dtype=torch.float32), - ops_out.to(dtype=torch.float32)) + assert torch.allclose( + ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32) + ) opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale)) diff --git a/tests/kernels/test_onednn.py b/tests/kernels/test_onednn.py index 198a8fdf0c33..9f78c177a81f 100644 --- a/tests/kernels/test_onednn.py +++ b/tests/kernels/test_onednn.py @@ -44,24 +44,27 @@ def ref_int8_scaled_mm( ): if azp is not None: a = a.to(dtype=torch.float32) - azp.to(dtype=torch.float32) - output = torch.mm((scale_a * a.to(dtype=torch.float32)), - (scale_b * b.to(dtype=torch.float32))) + output = torch.mm( + (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32)) + ) if bias is not None: output += bias.float() return output.to(dtype=output_type) -def onednn_int8_gemm_test_helper(primitive_cache_size: int, - m: int, - n: int, - k: int, - per_tensor_a_quant: bool, - per_tensor_b_quant: bool, - use_azp: bool, - use_bias: bool, - out_dtype: torch.dtype = torch.bfloat16, - device: str = "cpu"): +def onednn_int8_gemm_test_helper( + primitive_cache_size: int, + m: int, + n: int, + k: int, + per_tensor_a_quant: bool, + per_tensor_b_quant: bool, + use_azp: bool, + use_bias: bool, + out_dtype: torch.dtype = torch.bfloat16, + device: str = "cpu", +): # Test for a oneDNN kernel with per-tensor / per-token activation # quantization and per-tensor / per-output channel weight quantization. a = to_int8(torch.randn((m, k), device=device) * 5) @@ -70,8 +73,8 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int, a_scales_shape = (1, 1) if per_tensor_a_quant else (m, 1) b_scales_shape = (1, 1) if per_tensor_b_quant else (1, n) - scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32)) - scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32)) + scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32) + scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32) if use_azp: azp = torch.rand(a_scales_shape, dtype=torch.float32) * 10 + 1.5 @@ -81,10 +84,7 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int, azp = None azp_adj = None - if use_bias: - bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10 - else: - bias = None + bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None handler = ops.create_onednn_scaled_mm( b, @@ -105,20 +105,21 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int, # To test runtime bias setting out = torch.zeros((m, n), dtype=out_dtype) ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, None) - baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None, - out_dtype) + baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None, out_dtype) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) -def onednn_gemm_test_helper(primitive_cache_size: int, - m: int, - n: int, - k: int, - use_bias: bool, - use_stride: bool, - dtype: torch.dtype = torch.bfloat16, - device: str = "cpu"): +def onednn_gemm_test_helper( + primitive_cache_size: int, + m: int, + n: int, + k: int, + use_bias: bool, + use_stride: bool, + dtype: torch.dtype = torch.bfloat16, + device: str = "cpu", +): if use_stride: a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5 a = a[:, :k] @@ -128,7 +129,7 @@ def onednn_gemm_test_helper(primitive_cache_size: int, b = torch.rand((n, k), dtype=dtype, device=device) * 1.5 if use_bias: - bias = torch.rand((n, ), device=device, dtype=dtype) 
* 5 + bias = torch.rand((n,), device=device, dtype=dtype) * 5 bias_f32 = bias.float() else: bias = None @@ -140,16 +141,18 @@ def onednn_gemm_test_helper(primitive_cache_size: int, ) out = ops.onednn_mm(handler, a, bias) - baseline = torch.nn.functional.linear(a.float(), b.float(), - bias_f32).to(dtype=a.dtype) + baseline = torch.nn.functional.linear(a.float(), b.float(), bias_f32).to( + dtype=a.dtype + ) torch.testing.assert_close(out, baseline) if use_bias: # To test runtime bias setting out = ops.onednn_mm(handler, a, None) - baseline = torch.nn.functional.linear(a.float(), b.float(), - None).to(dtype=a.dtype) + baseline = torch.nn.functional.linear(a.float(), b.float(), None).to( + dtype=a.dtype + ) torch.testing.assert_close(out, baseline) diff --git a/tests/kernels/test_shuffle_rows.py b/tests/kernels/test_shuffle_rows.py index 7d02e1764e7d..c7de64066e87 100644 --- a/tests/kernels/test_shuffle_rows.py +++ b/tests/kernels/test_shuffle_rows.py @@ -14,20 +14,15 @@ @pytest.mark.parametrize("num_tokens", [1, 16, 64, 128, 256, 512, 1024]) @pytest.mark.parametrize("hidden_size", [128, 256, 512, 1024, 2048, 4096]) -@pytest.mark.parametrize("dtype", - [torch.float16, torch.bfloat16, torch.float32]) -def test_shuffle_rows_basic(num_tokens: int, hidden_size: int, - dtype: torch.dtype): +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_shuffle_rows_basic(num_tokens: int, hidden_size: int, dtype: torch.dtype): """Test basic functionality of shuffle_rows with various tensor sizes and dtypes.""" if not current_platform.is_cuda(): pytest.skip("shuffle_rows requires CUDA") # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) # Create a simple permutation map (identity mapping) dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) @@ -47,24 +42,18 @@ def test_shuffle_rows_basic(num_tokens: int, hidden_size: int, @pytest.mark.parametrize("num_tokens", [16, 64, 128]) @pytest.mark.parametrize("hidden_size", [128, 512, 1024]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_shuffle_rows_permutation(num_tokens: int, hidden_size: int, - dtype: torch.dtype): +def test_shuffle_rows_permutation( + num_tokens: int, hidden_size: int, dtype: torch.dtype +): """Test shuffle_rows with actual permutation.""" if not current_platform.is_cuda(): pytest.skip("shuffle_rows requires CUDA") # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) # Create a reverse permutation map - dst2src_map = torch.arange(num_tokens - 1, - -1, - -1, - device="cuda", - dtype=torch.int32) + dst2src_map = torch.arange(num_tokens - 1, -1, -1, device="cuda", dtype=torch.int32) # Test shuffle_rows output = shuffle_rows(input_tensor, dst2src_map) @@ -90,17 +79,13 @@ def test_shuffle_rows_expansion(num_tokens: int, hidden_size: int): dtype = torch.float16 # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) # Create a mapping that duplicates some tokens (expansion) expanded_size = num_tokens * 2 - dst2src_map = torch.randint(0, - num_tokens, (expanded_size, ), - device="cuda", - dtype=torch.int32) + dst2src_map = torch.randint( + 0, num_tokens, 
(expanded_size,), device="cuda", dtype=torch.int32 + ) # Test shuffle_rows output = shuffle_rows(input_tensor, dst2src_map) @@ -113,10 +98,9 @@ def test_shuffle_rows_expansion(num_tokens: int, hidden_size: int): # Verify that each output row matches the corresponding input row for i in range(expanded_size): src_idx = dst2src_map[i].item() - torch.testing.assert_close(output[i], - input_tensor[src_idx], - atol=1e-6, - rtol=1e-5) + torch.testing.assert_close( + output[i], input_tensor[src_idx], atol=1e-6, rtol=1e-5 + ) @pytest.mark.parametrize("num_tokens", [16, 64]) @@ -132,10 +116,7 @@ def test_shuffle_rows_random_permutation(num_tokens: int, hidden_size: int): torch.manual_seed(42) # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) # Create a random permutation map dst2src_map = torch.randperm(num_tokens, device="cuda", dtype=torch.int32) @@ -151,10 +132,9 @@ def test_shuffle_rows_random_permutation(num_tokens: int, hidden_size: int): # Verify that each output row matches the corresponding input row for i in range(num_tokens): src_idx = dst2src_map[i].item() - torch.testing.assert_close(output[i], - input_tensor[src_idx], - atol=1e-6, - rtol=1e-5) + torch.testing.assert_close( + output[i], input_tensor[src_idx], atol=1e-6, rtol=1e-5 + ) def test_shuffle_rows_edge_cases(): @@ -188,10 +168,7 @@ def test_shuffle_rows_moe_like_scenario(): topk = 2 # Simulate input tokens - input_tensor = torch.randn(batch_size, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) # Simulate expert assignment (each token goes to topk experts) # This creates a mapping where tokens are duplicated for multiple experts @@ -215,14 +192,12 @@ def test_shuffle_rows_moe_like_scenario(): for i in range(batch_size): for k in range(topk): output_idx = i * topk + k - torch.testing.assert_close(output[output_idx], - input_tensor[i], - atol=1e-6, - rtol=1e-5) + torch.testing.assert_close( + output[output_idx], input_tensor[i], atol=1e-6, rtol=1e-5 + ) -@pytest.mark.parametrize("dtype", - [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) def test_shuffle_rows_dtype_consistency(dtype: torch.dtype): """Test that shuffle_rows preserves dtype correctly.""" if not current_platform.is_cuda(): @@ -232,10 +207,7 @@ def test_shuffle_rows_dtype_consistency(dtype: torch.dtype): hidden_size = 512 # Create input tensor with specific dtype - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) # Test shuffle_rows @@ -257,10 +229,7 @@ def test_shuffle_rows_device_consistency(): dtype = torch.float16 # Create input tensor on CUDA - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) # Test shuffle_rows @@ -281,10 +250,7 @@ def test_shuffle_rows_contiguous_output(): dtype = torch.float16 # Create input tensor - input_tensor = torch.randn(num_tokens, - hidden_size, - device="cuda", - dtype=dtype) + input_tensor = torch.randn(num_tokens, hidden_size, device="cuda", dtype=dtype) 
dst2src_map = torch.arange(num_tokens, device="cuda", dtype=torch.int32) # Test shuffle_rows diff --git a/tests/kernels/test_triton_flash_attention.py b/tests/kernels/test_triton_flash_attention.py index 1c31cfb25e5a..4b0bbb992d2e 100644 --- a/tests/kernels/test_triton_flash_attention.py +++ b/tests/kernels/test_triton_flash_attention.py @@ -4,21 +4,24 @@ Run `pytest tests/kernels/test_triton_flash_attention.py`. """ + import pytest import torch -from vllm.attention.ops.triton_flash_attention import (SUPPORTED_LAYOUTS, - MetaData, - compute_alibi_tensor, - scale_fp8, - triton_attention_rocm) +from vllm.attention.ops.triton_flash_attention import ( + SUPPORTED_LAYOUTS, + MetaData, + compute_alibi_tensor, + scale_fp8, + triton_attention_rocm, +) from vllm.platforms import current_platform class ReferenceAttention: - - def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, - input_metadata): + def __init__( + self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata + ): self.Z = Z self.HQ = HQ self.HK = HK @@ -30,21 +33,23 @@ def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, self.input_metadata = input_metadata def fwd(self, q, k, v): - scores = torch.einsum('bhqd,bhkd->bhqk', q, - k).float() * self.input_metadata.sm_scale + scores = ( + torch.einsum("bhqd,bhkd->bhqk", q, k).float() * self.input_metadata.sm_scale + ) if self.input_metadata.causal: - mask = torch.tril(torch.ones(self.N_CTX_Q, - self.N_CTX_K, - device="cuda"), - diagonal=self.N_CTX_K - self.N_CTX_Q) + mask = torch.tril( + torch.ones(self.N_CTX_Q, self.N_CTX_K, device="cuda"), + diagonal=self.N_CTX_K - self.N_CTX_Q, + ) scores[:, :, mask == 0] = float("-inf") if self.input_metadata.bias is not None: scores += self.input_metadata.bias if self.use_alibi: - scores += compute_alibi_tensor(self.input_metadata.alibi_slopes, - self.N_CTX_Q, self.N_CTX_K) + scores += compute_alibi_tensor( + self.input_metadata.alibi_slopes, self.N_CTX_Q, self.N_CTX_K + ) p = torch.softmax(scores, dim=-1) if self.input_metadata.causal: @@ -54,31 +59,38 @@ def fwd(self, q, k, v): # should be out of the softmax. 
nan_mask = torch.isnan(p) p[nan_mask == 1] = 0 - ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(self.dtype), v) + ref_out = torch.einsum("bhqk,bhkd->bhqd", p.to(self.dtype), v) # compare - if self.input_metadata.layout == 'bshd': + if self.input_metadata.layout == "bshd": ref_out = ref_out.transpose(1, 2).clone() return ref_out def fwd_fp8(self, q_quantized, k_quantized, v_quantized): q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to( - self.dtype) + self.dtype + ) k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to( - self.dtype) + self.dtype + ) v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to( - self.dtype) + self.dtype + ) result = self.fwd(q, k, v) if self.input_metadata.o_scale is not None: result, _ = scale_fp8(result, self.input_metadata.o_scale) return result def fwd_fp8_kv(self, q, k_quantized, v_quantized): - k_descale, v_descale = (self.input_metadata.k_descale, - self.input_metadata.v_descale) - k_dequantized = (k_quantized.to(torch.float32) * - k_descale.to(torch.float32)).to(self.dtype) - v_dequantized = (v_quantized.to(torch.float32) * - v_descale.to(torch.float32)).to(self.dtype) + k_descale, v_descale = ( + self.input_metadata.k_descale, + self.input_metadata.v_descale, + ) + k_dequantized = ( + k_quantized.to(torch.float32) * k_descale.to(torch.float32) + ).to(self.dtype) + v_dequantized = ( + v_quantized.to(torch.float32) * v_descale.to(torch.float32) + ).to(self.dtype) return self.fwd(q, k_dequantized, v_dequantized) def varlen_fwd(self, q, k, v, is_mqa=False): @@ -86,29 +98,33 @@ def varlen_fwd(self, q, k, v, is_mqa=False): if is_mqa: # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so # the size aligns with Q. - k_ref = k.view(k.shape[0], k.shape[1], 1, - k.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) - v_ref = v.view(v.shape[0], v.shape[1], 1, - v.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) + k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand( + -1, -1, self.HQ // self.HK, -1 + ) + v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand( + -1, -1, self.HQ // self.HK, -1 + ) else: k_ref = k v_ref = v for i in range(0, self.input_metadata.num_contexts): - start_q, start_k = self.input_metadata.cu_seqlens_q[ - i], self.input_metadata.cu_seqlens_k[i] - end_q, end_k = self.input_metadata.cu_seqlens_q[ - i + 1], self.input_metadata.cu_seqlens_k[i + 1] + start_q, start_k = ( + self.input_metadata.cu_seqlens_q[i], + self.input_metadata.cu_seqlens_k[i], + ) + end_q, end_k = ( + self.input_metadata.cu_seqlens_q[i + 1], + self.input_metadata.cu_seqlens_k[i + 1], + ) k_curr = k_ref[start_k:end_k] v_curr = v_ref[start_k:end_k] if is_mqa: k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) - scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], - k_curr).float() - p = torch.softmax(scores * self.input_metadata.sm_scale, - dim=-1).half() - ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + scores = torch.einsum("qhd,khd->qhk", q[start_q:end_q], k_curr).float() + p = torch.softmax(scores * self.input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum("qhk,khd->qhd", p, v_curr) return ref_out @@ -123,8 +139,7 @@ def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False): # model. 
p_scale = None - o_scale = torch.rand(1, device="cuda", - requires_grad=False) if use_o_scale else None + o_scale = torch.rand(1, device="cuda", requires_grad=False) if use_o_scale else None return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale @@ -150,10 +165,10 @@ def input_helper( current_platform.seed_everything(0) # Initialize q, k, v - if layout == 'bhsd': + if layout == "bhsd": q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) - elif layout == 'bshd': + elif layout == "bshd": q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) @@ -161,69 +176,54 @@ def input_helper( # for n heads the set of slopes is the geometric sequence that starts # 2^(-8/n) alibi_slopes = torch.tensor( - [2**(-8 / HQ * i) for i in range(1, HQ + 1)], + [2 ** (-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, - device="cuda").repeat(Z, 1) + device="cuda", + ).repeat(Z, 1) else: alibi_slopes = None if use_bias: - bias = torch.randn((1, HQ, N_CTX_Q, N_CTX_K), - dtype=dtype, - device="cuda", - requires_grad=False) + bias = torch.randn( + (1, HQ, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda", requires_grad=False + ) else: bias = None - q = torch.randn(q_tensor_shape, - dtype=dtype, - device="cuda", - requires_grad=False) - k = torch.randn(k_tensor_shape, - dtype=dtype, - device="cuda", - requires_grad=False) - v = torch.randn(k_tensor_shape, - dtype=dtype, - device="cuda", - requires_grad=False) + q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) + k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) + v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=False) if is_fp8: - (q, k, v, q_descale, k_descale, v_descale, p_scale, - o_scale) = quantize_input(q, - k, - v, - use_o_scale=use_o_scale, - fp8_kv=fp8_kv) + (q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale) = quantize_input( + q, k, v, use_o_scale=use_o_scale, fp8_kv=fp8_kv + ) else: q_descale = k_descale = v_descale = p_scale = o_scale = None - input_metadata = MetaData(sm_scale=D_HEAD**-0.5, - max_seqlens_q=N_CTX_Q, - max_seqlens_k=N_CTX_K, - layout=layout, - alibi_slopes=alibi_slopes, - alibi_batch=Z, - alibi_nheads=HQ, - q_descale=q_descale, - k_descale=k_descale, - v_descale=v_descale, - p_scale=p_scale, - o_scale=o_scale, - bias=bias, - seqlen_q=N_CTX_Q, - seqlen_k=N_CTX_K) + input_metadata = MetaData( + sm_scale=D_HEAD**-0.5, + max_seqlens_q=N_CTX_Q, + max_seqlens_k=N_CTX_K, + layout=layout, + alibi_slopes=alibi_slopes, + alibi_batch=Z, + alibi_nheads=HQ, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + p_scale=p_scale, + o_scale=o_scale, + bias=bias, + seqlen_q=N_CTX_Q, + seqlen_k=N_CTX_K, + ) return q, k, v, input_metadata -def varlen_input_helper(Z, - HQ, - HK, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - equal_seqlens=False): +def varlen_input_helper( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False +): current_platform.seed_everything(0) # Random sequence lengths. 
Using N_CTX as kind of max of sum of individual @@ -231,66 +231,72 @@ def varlen_input_helper(Z, if not equal_seqlens: max_seqlens_q = N_CTX_Q // Z max_seqlens_k = N_CTX_K // Z - seqlens_q = torch.randint(1, - max_seqlens_q + 1, (Z, ), - dtype=torch.int32) - seqlens_k = torch.randint(1, - max_seqlens_k + 1, (Z, ), - dtype=torch.int32) + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z,), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z,), dtype=torch.int32) else: - seqlens_q = torch.full((Z, ), N_CTX_Q // Z) - seqlens_k = torch.full((Z, ), N_CTX_K // Z) + seqlens_q = torch.full((Z,), N_CTX_Q // Z) + seqlens_k = torch.full((Z,), N_CTX_K // Z) # Calculate cumulative sequence lengths - cu_seqlens_q = torch.cat([ - torch.tensor([0], dtype=torch.int32), - seqlens_q.cumsum(dim=0, dtype=torch.int32) - ]) - cu_seqlens_k = torch.cat([ - torch.tensor([0], dtype=torch.int32), - seqlens_k.cumsum(dim=0, dtype=torch.int32) - ]) + cu_seqlens_q = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + seqlens_q.cumsum(dim=0, dtype=torch.int32), + ] + ) + cu_seqlens_k = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + seqlens_k.cumsum(dim=0, dtype=torch.int32), + ] + ) cu_seqlens_q = cu_seqlens_q.to(device="cuda") cu_seqlens_k = cu_seqlens_k.to(device="cuda") # Initialize q, k, v with variable lengths total_q = cu_seqlens_q[-1].item() total_k = cu_seqlens_k[-1].item() - q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_() - k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_() - v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, - device="cuda").normal_(mean=0., std=0.5).requires_grad_() + q = ( + torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) sm_scale = D_HEAD**-0.5 input_metadata = MetaData(sm_scale=sm_scale) input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) return q, k, v, input_metadata -@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ - (1, 48, 12, 1, 1, 64), - (4, 4, 4, 128, 128, 65), - (16, 48, 48, 1, 1, 128), - (64, 48, 24, 3, 3, 128), - (4, 4, 4, 113, 123, 1), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_alibi', [True, False]) -@pytest.mark.parametrize('layout', ['bshd']) -def test_op_fwd(Z, - HQ, - HK, - N_CTX_Q, - N_CTX_K, - D_HEAD, - causal, - use_alibi, - layout, - dtype=torch.float16): +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (1, 48, 12, 1, 1, 64), + (4, 4, 4, 128, 128, 65), + (16, 48, 48, 1, 1, 128), + (64, 48, 24, 3, 3, 128), + (4, 4, 4, 113, 123, 1), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("use_alibi", [True, False]) +@pytest.mark.parametrize("layout", ["bshd"]) +def test_op_fwd( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16 +): current_platform.seed_everything(0) - q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, - dtype, layout, use_alibi, causal) + q, k, v, input_metadata = input_helper( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, use_alibi, causal + ) o = torch.empty_like(q) @@ -299,48 +305,50 @@ 
def test_op_fwd(Z, # Transpose here if layout is bshd so we have same reference code for all # layouts - if layout == 'bshd': + if layout == "bshd": q = q.transpose(1, 2).clone() k = k.transpose(1, 2).clone() v = v.transpose(1, 2).clone() # Replicate K and V if using MQA/GQA if HQ != HK: - k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], - k.shape[3]).expand(-1, -1, HQ // HK, -1, - -1).reshape(k.shape[0], -1, k.shape[2], - k.shape[3]) - v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], - v.shape[3]).expand(-1, -1, HQ // HK, -1, - -1).reshape(v.shape[0], -1, v.shape[2], - v.shape[3]) - - ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, - use_alibi, dtype, input_metadata) + k = ( + k.view(k.shape[0], k.shape[1], -1, k.shape[2], k.shape[3]) + .expand(-1, -1, HQ // HK, -1, -1) + .reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + ) + v = ( + v.view(v.shape[0], v.shape[1], -1, v.shape[2], v.shape[3]) + .expand(-1, -1, HQ // HK, -1, -1) + .reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + ) + + ref_impl = ReferenceAttention( + Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, input_metadata + ) ref_out = ref_impl.fwd(q, k, v) torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('layout', ['bhsd']) -@pytest.mark.parametrize('use_o_scale', [True, False]) -@pytest.mark.skipif(torch.cuda.get_device_capability() < (9, 0), - reason="Triton FP8 requires CUDA 9.0 or higher") -def test_op_fwd_fp8(Z, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - causal, - layout, - use_o_scale, - dtype=torch.float32): +@pytest.mark.parametrize( + "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("layout", ["bhsd"]) +@pytest.mark.parametrize("use_o_scale", [True, False]) +@pytest.mark.skipif( + torch.cuda.get_device_capability() < (9, 0), + reason="Triton FP8 requires CUDA 9.0 or higher", +) +def test_op_fwd_fp8( + Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, use_o_scale, dtype=torch.float32 +): current_platform.seed_everything(0) # Disable grad to save memory it won't run into OOM on CI machine. 
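The FP8 reference paths above (fwd_fp8 / fwd_fp8_kv) dequantize q/k/v by multiplying the quantized tensors with per-tensor descale factors before running the higher-precision reference. A rough standalone sketch of that round trip, assuming a simple symmetric per-tensor scaling scheme; quantize_per_tensor_fp8 is an illustrative name, and the kernel-side quantize_input may differ in detail:

import torch

def quantize_per_tensor_fp8(x: torch.Tensor):
    # Illustrative sketch: scale so the largest magnitude maps to the
    # fp8 e4m3 max, cast down, and return the inverse scale ("descale").
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = x.abs().max().float().clamp(min=1e-12)
    scale = finfo.max / amax
    x_q = (x.float() * scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    descale = 1.0 / scale  # analogous to the *_descale factors used above
    return x_q, descale

x = torch.randn(4, 8, dtype=torch.float16)
x_q, x_descale = quantize_per_tensor_fp8(x)
x_dq = (x_q.to(torch.float16) * x_descale).to(torch.float16)  # same dequant step as fwd_fp8
print((x - x_dq).abs().max())  # small residual quantization error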
@@ -358,95 +366,103 @@ def test_op_fwd_fp8(Z, causal=causal, layout=layout, is_fp8=True, - use_o_scale=use_o_scale) + use_o_scale=use_o_scale, + ) o = torch.empty_like(q_quantized) if use_o_scale else None - tri_out, _ = triton_attention_rocm(q_quantized, k_quantized, v_quantized, - o, input_metadata) + tri_out, _ = triton_attention_rocm( + q_quantized, k_quantized, v_quantized, o, input_metadata + ) - ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized) # compare - torch.testing.assert_close(ref_out.to(torch.float32), - tri_out.to(torch.float32), - atol=7e-2, - rtol=2e-1) - - -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), - (4, 4, 113, 123, 1), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('layout', ['bhsd']) -def test_op_fwd_fp8_kv(Z, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - causal, - layout, - dtype=torch.float32): + torch.testing.assert_close( + ref_out.to(torch.float32), tri_out.to(torch.float32), atol=7e-2, rtol=2e-1 + ) + + +@pytest.mark.parametrize( + "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("layout", ["bhsd"]) +def test_op_fwd_fp8_kv( + Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, dtype=torch.float32 +): current_platform.seed_everything(0) - q, k_quantized, v_quantized, input_metadata = input_helper(Z, - H, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - causal=causal, - layout=layout, - is_fp8=True, - fp8_kv=True) + q, k_quantized, v_quantized, input_metadata = input_helper( + Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + causal=causal, + layout=layout, + is_fp8=True, + fp8_kv=True, + ) o = torch.empty_like(q) - tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, - input_metadata) + tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, input_metadata) - ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized) torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1) -@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ - (4, 48, 1, 1, 64), - (4, 48, 1, 1, 128), - (4, 48, 3, 3, 128), - (4, 4, 128, 128, 65), -]) -@pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_bias', [True]) -@pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize( + "Z, H, N_CTX_Q, N_CTX_K, D_HEAD", + [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("use_bias", [True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): current_platform.seed_everything(0) - q, k, v, input_metadata = input_helper(Z, - H, - H, - N_CTX_Q, - N_CTX_K, - D_HEAD, - dtype, - layout='bhsd', - causal=causal, - use_bias=use_bias) + q, k, v, input_metadata = input_helper( + Z, + H, + H, + N_CTX_Q, 
+ N_CTX_K, + D_HEAD, + dtype, + layout="bhsd", + causal=causal, + use_bias=use_bias, + ) o = torch.empty_like(q) # triton implementation tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) - ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.fwd(q, k, v) # compare @@ -454,47 +470,47 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): # NOTE: Uses thd layout, so also tests thd. -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(1, 48, 256, 64), - (4, 48, 512, 64), - (16, 48, 512, 64), - (64, 48, 128, 128)]) -@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize( + "Z, H, N_CTX, D_HEAD", + [(1, 48, 256, 64), (4, 48, 512, 64), (16, 48, 512, 64), (64, 48, 128, 128)], +) +@pytest.mark.parametrize("causal", [True, False]) def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): - - q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, - D_HEAD, dtype) + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) tri_out = torch.empty_like(q) triton_attention_rocm(q, k, v, tri_out, input_metadata) - ref_impl = ReferenceAttention(Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, - input_metadata) + ref_impl = ReferenceAttention( + Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False) torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) # NOTE: Uses thd layout, so also tests thd. -@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), - (4, 48, 12, 256, 64), - (4, 48, 4, 512, 64), - (4, 64, 16, 128, 128)]) -@pytest.mark.parametrize('causal', [False]) -def test_op_varlen_mqa_fwd(Z, - HQ, - HK, - N_CTX, - D_HEAD, - causal, - dtype=torch.float16): - q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, - D_HEAD, dtype) +@pytest.mark.parametrize( + "Z, HQ, HK, N_CTX, D_HEAD", + [ + (2, 48, 24, 128, 64), + (4, 48, 12, 256, 64), + (4, 48, 4, 512, 64), + (4, 64, 16, 128, 128), + ], +) +@pytest.mark.parametrize("causal", [False]) +def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper( + Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype + ) tri_out = torch.empty_like(q) triton_attention_rocm(q, k, v, tri_out, input_metadata) - ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, - dtype, input_metadata) + ref_impl = ReferenceAttention( + Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, dtype, input_metadata + ) ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True) torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index db6f29c28c95..015424d9ee0f 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -17,10 +17,13 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.attention.backends.registry import _Backend from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) -from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, - STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils import ( + STR_BACKEND_ENV_VAR, + STR_FLASH_ATTN_VAL, + 
STR_XFORMERS_ATTN_VAL, + make_tensor_with_pad, +) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. @@ -39,7 +42,7 @@ class QKVInputs(NamedTuple): - ''' + """ Data structure for representing unpacked attention inputs, query/key/values and their sequence lengths. @@ -49,7 +52,7 @@ class QKVInputs(NamedTuple): num_heads x head_size) attention inputs * q_seq_lens: query sequence lengths list * kv_seq_lens: shared key/value sequence lengths list - ''' + """ query: torch.Tensor key: torch.Tensor @@ -59,7 +62,7 @@ class QKVInputs(NamedTuple): class QKVO(NamedTuple): - ''' + """ Data structure for representing unpacked attention inputs, alongside unpacked known-correct attention output @@ -69,14 +72,14 @@ class QKVO(NamedTuple): num_heads x head_size) attention inputs * ideal_output: unpacked (batch_size x padded_seq_len x num_heads x head_size) known-correct attention output - ''' + """ qkv: QKVInputs ideal_output: torch.Tensor class PackedQKVInputs(NamedTuple): - ''' + """ Data structure for representing packed attention inputs Attributes: @@ -88,7 +91,7 @@ class PackedQKVInputs(NamedTuple): packed tensor * q_seq_lens: query sequence lengths list * kv_seq_lens: shared key/value sequence lengths list - ''' + """ query: torch.Tensor key: torch.Tensor @@ -100,7 +103,7 @@ class PackedQKVInputs(NamedTuple): class PackedQKVO(NamedTuple): - ''' + """ Data structure for representing packed attention inputs, alongside packed known-correct attention output @@ -110,28 +113,28 @@ class PackedQKVO(NamedTuple): x head_size) attention inputs * ideal_output: packed (number_of_tokens x num_heads x head_size) known-correct attention output - ''' + """ packed_qkv: Optional[PackedQKVInputs] ideal_output: torch.Tensor class KVMemoryMap(NamedTuple): - ''' + """ Data structure for encapsulating KV cache memory mapping. 
Attributes: * block_tables: KV cache block tables * slot_mapping: mapping of sequence offset to physical address - ''' + """ block_tables: torch.Tensor slot_mapping: torch.Tensor class PhaseTestParameters(NamedTuple): - ''' + """ Data structure for encapsulating the test parameters for a given test "phase" (prefill or decode phase) and attention scenario (encoder, decoder-self, encoder/decoder-cross) @@ -143,7 +146,7 @@ class PhaseTestParameters(NamedTuple): output * kv_mmap: KV cache memory mapping, specific to this test phase & attention scenario - ''' + """ packed_qkvo: PackedQKVO kv_mmap: Optional[KVMemoryMap] @@ -153,41 +156,43 @@ def maybe_make_int_tensor( _list: Optional[list[int]], device: Union[torch.device, str], ) -> torch.Tensor: - ''' + """ Convert Python int list to a 1D int torch.Tensor on `device` Returns: * If _list is not None: 1D int torch.Tensor on `device` * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.int, device=device) + """ + return ( + None if _list is None else torch.tensor(_list, dtype=torch.int, device=device) + ) def maybe_make_long_tensor( _list: Optional[list[int]], device: Union[torch.device, str], ) -> torch.Tensor: - ''' + """ Convert Python int list to a 1D long torch.Tensor on `device` Returns: * If _list is not None: 1D long torch.Tensor on `device` * None otherwise - ''' - return None if _list is None else torch.tensor( - _list, dtype=torch.long, device=device) + """ + return ( + None if _list is None else torch.tensor(_list, dtype=torch.long, device=device) + ) def maybe_max(_list: Optional[list]) -> Optional[Number]: - ''' + """ Returns: * If _list is not None: max(_list) * None otherwise - ''' + """ return None if _list is None else max(_list) @@ -195,7 +200,7 @@ def make_causal_mask( q_max_seq_len: int, kv_max_seq_len: int, ) -> torch.Tensor: - ''' + """ Create a q_max_seq_len x kv_max_seq_len causal mask Arguments: @@ -206,19 +211,19 @@ def make_causal_mask( Returns: * 2D tensor, q_max_seq_len x kv_max_seq_len - ''' + """ # Create a matrix where entry (i, j) is True if i >= j mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1) # Replace True with float('-inf') and False with 0 - mask = mask.masked_fill(mask == 1, - float('-inf')).masked_fill(mask == 0, 0.0) + mask = mask.masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0) return mask -def override_backend_env_variable(mpatch: pytest.MonkeyPatch, - backend_name: str) -> None: - ''' +def override_backend_env_variable( + mpatch: pytest.MonkeyPatch, backend_name: str +) -> None: + """ Override the environment variable indicating the vLLM backend temporarily, using pytest monkeypatch to ensure that the env vars get reset once the test context exits. @@ -227,18 +232,20 @@ def override_backend_env_variable(mpatch: pytest.MonkeyPatch, * mpatch: pytest monkeypatch instance * backend_name: attention backend name to force - ''' + """ mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name) -def ref_masked_attention(query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - custom_mask: Optional[torch.Tensor] = None, - q_seq_lens: Optional[list] = None, - kv_seq_lens: Optional[list] = None) -> torch.Tensor: - ''' +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + custom_mask: Optional[torch.Tensor] = None, + q_seq_lens: Optional[list] = None, + kv_seq_lens: Optional[list] = None, +) -> torch.Tensor: + """ "Golden" masked attention reference. 
Supports two types of masking: * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out @@ -260,14 +267,14 @@ def ref_masked_attention(query: torch.Tensor, Returns: * Attention result, batch_size x q_padded_seq_len x num_heads x head_size - ''' + """ assert q_seq_lens is not None assert kv_seq_lens is not None batch_size = query.shape[0] - assert (len(q_seq_lens) == batch_size) - assert (len(kv_seq_lens) == batch_size) + assert len(q_seq_lens) == batch_size + assert len(kv_seq_lens) == batch_size attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float() @@ -303,7 +310,7 @@ def make_qkv( attn_type: AttentionType = AttentionType.ENCODER_DECODER, force_max_len: bool = False, ) -> tuple[QKVInputs, QKVInputs, QKVInputs]: - ''' + """ Construct QKV test tensors for self- and cross-attention. Generates three query/key/value triplets: @@ -340,14 +347,12 @@ def make_qkv( * Overall QKVInputs structure (containing full unpacked Q/K/V tensors) * Prefill QKVInputs structure (containing all but the last sequence offset) * Decode QKVInputs structure (containing all only the last sequence offset) - ''' + """ if force_max_len: q_seq_lens = [max_q_seq_len for _ in range(batch_size)] else: - q_seq_lens = [ - random.randint(2, max_q_seq_len) for _ in range(batch_size) - ] + q_seq_lens = [random.randint(2, max_q_seq_len) for _ in range(batch_size)] kv_seq_lens = None if force_kv_seq_lens is not None: kv_seq_lens = force_kv_seq_lens @@ -360,50 +365,44 @@ def make_qkv( if force_max_len: kv_seq_lens = [max_kv_seq_len] * batch_size else: - kv_seq_lens = [ - random.randint(2, max_kv_seq_len) for _ in range(batch_size) - ] - - query = torch.rand( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - key = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - value = torch.rand( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - - prefill_query = torch.zeros( - (batch_size, max_q_seq_len, num_heads, head_size)).to(device) - prefill_key = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - prefill_value = torch.zeros( - (batch_size, max_kv_seq_len, num_heads, head_size)).to(device) - - decode_query = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) + kv_seq_lens = [random.randint(2, max_kv_seq_len) for _ in range(batch_size)] + + query = torch.rand((batch_size, max_q_seq_len, num_heads, head_size)).to(device) + key = torch.rand((batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + value = torch.rand((batch_size, max_kv_seq_len, num_heads, head_size)).to(device) + + prefill_query = torch.zeros((batch_size, max_q_seq_len, num_heads, head_size)).to( + device + ) + prefill_key = torch.zeros((batch_size, max_kv_seq_len, num_heads, head_size)).to( + device + ) + prefill_value = torch.zeros((batch_size, max_kv_seq_len, num_heads, head_size)).to( + device + ) + + decode_query = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) - decode_value = torch.zeros( - (batch_size, 1, num_heads, head_size)).to(device) + decode_value = torch.zeros((batch_size, 1, num_heads, head_size)).to(device) - for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, - kv_seq_lens)): + for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens, kv_seq_lens)): query[bdx, q_seq_len:, :, :] = 0 key[bdx, kv_seq_len:, :, :] = 0 value[bdx, kv_seq_len:, :, :] = 0 - prefill_query[bdx, - 0:(q_seq_len - 1), :, :] = query[bdx, - 
0:(q_seq_len - 1), :, :] - prefill_key[bdx, - 0:(kv_seq_len - 1), :, :] = key[bdx, - 0:(kv_seq_len - 1), :, :] - prefill_value[bdx, 0:(kv_seq_len - - 1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :] - - decode_query[bdx, :, :, :] = query[bdx, - (q_seq_len - 1):q_seq_len, :, :] - decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :] - decode_value[bdx, :, :, :] = value[bdx, - (kv_seq_len - 1):kv_seq_len, :, :] + prefill_query[bdx, 0 : (q_seq_len - 1), :, :] = query[ + bdx, 0 : (q_seq_len - 1), :, : + ] + prefill_key[bdx, 0 : (kv_seq_len - 1), :, :] = key[ + bdx, 0 : (kv_seq_len - 1), :, : + ] + prefill_value[bdx, 0 : (kv_seq_len - 1), :, :] = value[ + bdx, 0 : (kv_seq_len - 1), :, : + ] + + decode_query[bdx, :, :, :] = query[bdx, (q_seq_len - 1) : q_seq_len, :, :] + decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1) : kv_seq_len, :, :] + decode_value[bdx, :, :, :] = value[bdx, (kv_seq_len - 1) : kv_seq_len, :, :] prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens] prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens] @@ -417,25 +416,29 @@ def make_qkv( key, value, q_seq_lens, - kv_seq_lens), + kv_seq_lens, + ), QKVInputs( prefill_query, # Prefill subset of QKV sequences prefill_key, prefill_value, prefill_q_seq_lens, - prefill_kv_seq_lens), + prefill_kv_seq_lens, + ), QKVInputs( decode_query, # Decode subset of KV sequences decode_key, decode_value, decode_q_seq_lens, - decode_kv_seq_lens)) + decode_kv_seq_lens, + ), + ) def pack_tensor( - unpacked_tensor: torch.Tensor, seq_lens: list[int], - device: Union[torch.device, str]) -> tuple[torch.Tensor, list[int]]: - ''' + unpacked_tensor: torch.Tensor, seq_lens: list[int], device: Union[torch.device, str] +) -> tuple[torch.Tensor, list[int]]: + """ Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an unpadded number_of_tokens x num_heads x head_size tensor, where number_of_tokens = sum(seq_lens) @@ -451,7 +454,7 @@ def pack_tensor( * packed_tensor: number_of_tokens x num_heads x head_size * start_loc_list: start idx of each batch elt in packed_tensor; [0] + list(itertools.accumulate(seq_lens)) - ''' + """ num_tok = sum(seq_lens) num_heads = unpacked_tensor.shape[-2] @@ -460,16 +463,15 @@ def pack_tensor( packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device) for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)): - - packed_tensor[start_loc:( - start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :] + packed_tensor[start_loc : (start_loc + seq_len), :, :] = unpacked_tensor[ + bdx, :seq_len, :, : + ] return packed_tensor, start_loc_list -def pack_qkv(qkv: QKVInputs, device: Union[torch.device, - str]) -> PackedQKVInputs: - ''' +def pack_qkv(qkv: QKVInputs, device: Union[torch.device, str]) -> PackedQKVInputs: + """ Individually pack each of Q, K and V, each with dimensions batch_size x padded_seq_len x num_heads x head_size, into respective number_of_tokens x num_heads x head_size tensors. 
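pack_tensor above flattens a padded batch into a token-major tensor plus per-sequence start offsets ([0] + the running sum of seq_lens). A compact standalone illustration of the same packing scheme; the names here are illustrative rather than the helper's API:

import itertools
import torch

def pack(unpacked: torch.Tensor, seq_lens: list[int]):
    # unpacked: (batch, padded_seq_len, num_heads, head_size)
    start_locs = [0] + list(itertools.accumulate(seq_lens))
    packed = torch.zeros(sum(seq_lens), *unpacked.shape[2:], dtype=unpacked.dtype)
    for b, (n, start) in enumerate(zip(seq_lens, start_locs)):
        packed[start:start + n] = unpacked[b, :n]  # drop the padding tail
    return packed, start_locs

x = torch.randn(2, 4, 8, 16)  # batch=2, padded len=4, 8 heads, head_size=16
packed, start_locs = pack(x, seq_lens=[2, 3])
assert packed.shape == (5, 8, 16) and start_locs == [0, 2, 5]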
@@ -488,28 +490,30 @@ def pack_qkv(qkv: QKVInputs, device: Union[torch.device, * Packed (number_of_tokens x num_heads x head_size) QKV inputs derived from unpacked inputs - ''' + """ if qkv.query is None: packed_query = None q_start_loc_list = None else: - packed_query, q_start_loc_list = pack_tensor(qkv.query, - qkv.q_seq_lens, - device=device) - packed_key, kv_start_loc_list = pack_tensor(qkv.key, - qkv.kv_seq_lens, - device=device) + packed_query, q_start_loc_list = pack_tensor( + qkv.query, qkv.q_seq_lens, device=device + ) + packed_key, kv_start_loc_list = pack_tensor(qkv.key, qkv.kv_seq_lens, device=device) packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device) return PackedQKVInputs( - packed_query, packed_key, packed_value, q_start_loc_list, + packed_query, + packed_key, + packed_value, + q_start_loc_list, kv_start_loc_list, (None if q_start_loc_list is None else qkv.q_seq_lens), - qkv.kv_seq_lens) + qkv.kv_seq_lens, + ) def make_backend(backend_name: str) -> AttentionBackend: - ''' + """ Construct the backend instance determined by the backend_name string argument. @@ -523,31 +527,33 @@ def make_backend(backend_name: str) -> AttentionBackend: Returns: * Backend instance - ''' + """ if backend_name == STR_XFORMERS_ATTN_VAL: - from vllm.v1.attention.backends.xformers import ( - XFormersAttentionBackend) + from vllm.v1.attention.backends.xformers import XFormersAttentionBackend + return XFormersAttentionBackend() if backend_name == STR_FLASH_ATTN_VAL: from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + return FlashAttentionBackend() if backend_name == "TRITON_ATTN": - from vllm.v1.attention.backends.triton_attn import ( - TritonAttentionBackend) + from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend + return TritonAttentionBackend() if backend_name == "FLEX_ATTENTION": - from vllm.v1.attention.backends.flex_attention import ( - FlexAttentionBackend) + from vllm.v1.attention.backends.flex_attention import FlexAttentionBackend + return FlexAttentionBackend() if backend_name == "TORCH_SDPA": from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend + return TorchSDPABackend() if backend_name == "FLASHINFER": from vllm.v1.attention.backends.flashinfer import FlashInferBackend + return FlashInferBackend() - raise AssertionError( - f"Unrecognized backend_name {backend_name} for unit test") + raise AssertionError(f"Unrecognized backend_name {backend_name} for unit test") def make_alibi_bias( @@ -565,7 +571,8 @@ def make_alibi_bias( attn_biases: list[Any] = [] num_heads = alibi_slopes.shape[0] assert num_heads >= num_kv_heads, ( - "ALiBi slopes expect at least as many heads as KV heads") + "ALiBi slopes expect at least as many heads as KV heads" + ) for seq_len in seq_lens: bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) @@ -591,9 +598,17 @@ def _make_metadata_tensors( context_lens: Optional[list[int]], encoder_seq_lens: Optional[list[int]], device: Union[torch.device, str], -) -> tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor], - torch.Tensor, torch.Tensor, Optional[int]]: - ''' +) -> tuple[ + torch.Tensor, + torch.Tensor, + Any, + Any, + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + Optional[int], +]: + """ Build scalar & tensor values required to build attention metadata structure. 
Arguments: @@ -613,48 +628,61 @@ def _make_metadata_tensors( * encoder_seq_lens_tensor: encoder seq_lens list, as tensor * encoder_seq_start_loc: start idx of each encoder sequence * max_encoder_seq_len: encoder seq_lens list, as tensor - ''' + """ seq_lens_tensor = maybe_make_int_tensor(seq_lens, device) context_lens_tensor = maybe_make_int_tensor(context_lens, device) max_context_len = maybe_max(context_lens) max_seq_len = maybe_max(seq_lens) encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device) - max_encoder_seq_len = (None if encoder_seq_lens is None else - max(encoder_seq_lens)) + max_encoder_seq_len = None if encoder_seq_lens is None else max(encoder_seq_lens) seq_start_loc = None if seq_lens_tensor is not None: - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=seq_lens_tensor.device) - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - - encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=encoder_seq_lens_tensor.device) - torch.cumsum(encoder_seq_lens_tensor, - dim=0, - dtype=encoder_seq_start_loc.dtype, - out=encoder_seq_start_loc[1:]) - - return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len, - seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc, - max_encoder_seq_len) - - -def make_kv_cache(num_blocks: int, - num_heads: int, - head_size: int, - block_size: int, - device: Union[torch.device, str], - backend: str, - default_val: float = 0.0) -> torch.Tensor: - ''' + seq_start_loc = torch.zeros( + seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=seq_lens_tensor.device, + ) + torch.cumsum( + seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:] + ) + + encoder_seq_start_loc = torch.zeros( + encoder_seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=encoder_seq_lens_tensor.device, + ) + torch.cumsum( + encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:], + ) + + return ( + seq_lens_tensor, + context_lens_tensor, + max_context_len, + max_seq_len, + seq_start_loc, + encoder_seq_lens_tensor, + encoder_seq_start_loc, + max_encoder_seq_len, + ) + + +def make_kv_cache( + num_blocks: int, + num_heads: int, + head_size: int, + block_size: int, + device: Union[torch.device, str], + backend: str, + default_val: float = 0.0, +) -> torch.Tensor: + """ Create a fake KV cache. Arguments: @@ -672,27 +700,29 @@ def make_kv_cache(num_blocks: int, * for backend 'XFORMERS' * kv_cache: 2 x num_blocks x block_size x num_heads x head_size * for backend 'FLASH_ATTN' - ''' - if backend == 'XFORMERS': - kv_cache = torch.rand( - (2, num_blocks, block_size * num_heads * head_size)).to(device) - elif backend == 'FLASH_ATTN': - kv_cache = torch.rand( - (2, num_blocks, block_size, num_heads, head_size)).to(device) + """ + if backend == "XFORMERS": + kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to( + device + ) + elif backend == "FLASH_ATTN": + kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to( + device + ) else: raise ValueError( - f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " - f"'FLASH_ATTN'.") + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." 
+ ) if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int: - ''' + """ Compute the minimum number of blocks required to hold num_tokens tokens, given block_size - ''' + """ return (num_tokens + block_size) // block_size @@ -704,9 +734,12 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]): return torch.tensor([], device=device) -def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int], - device: Union[torch.device, str]): - ''' +def split_slot_mapping( + slot_mapping_list: torch.Tensor, + seq_lens: list[int], + device: Union[torch.device, str], +): + """ Split a slot mapping into valid prefill- and decode-phase slot mappings. Context: @@ -744,28 +777,32 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: list[int], reflecting all N prefill prompts * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting all N decoded tokens - ''' + """ prefill_slot_mapping = [] decode_slot_mapping = [] base_idx = 0 for seq_len in seq_lens: - prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx + - seq_len - 1)]) + prefill_slot_mapping.extend( + slot_mapping_list[base_idx : (base_idx + seq_len - 1)] + ) decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1]) base_idx += seq_len - return (maybe_make_long_tensor(prefill_slot_mapping, device), - maybe_make_long_tensor(decode_slot_mapping, device)) + return ( + maybe_make_long_tensor(prefill_slot_mapping, device), + maybe_make_long_tensor(decode_slot_mapping, device), + ) def make_block_tables_slot_mapping( - block_size: int, - seq_lens: list[int], - device: Union[torch.device, str], - block_base_addr: int = 0) -> tuple[torch.Tensor, list[int], int]: - ''' + block_size: int, + seq_lens: list[int], + device: Union[torch.device, str], + block_base_addr: int = 0, +) -> tuple[torch.Tensor, list[int], int]: + """ Construct fake block tables & slot mappings. 
For a sequence with num_tokens tokens the minimum number @@ -802,12 +839,11 @@ def make_block_tables_slot_mapping( * block_tables_tensor: block table for sequence * slot_mapping_list: slot mapping for sequence * max_block_idx: the highest block address within this block table - ''' + """ # Provision minimum number of KV cache blocks num_blocks_list = [ - _num_tokens_to_min_blocks(num_tokens, block_size) - for num_tokens in seq_lens + _num_tokens_to_min_blocks(num_tokens, block_size) for num_tokens in seq_lens ] max_block_table_len = max(num_blocks_list) block_table_pad_tokens = 10 @@ -820,11 +856,11 @@ def make_block_tables_slot_mapping( max_block_idx = block_base_idx for sdx, num_tokens in enumerate(seq_lens): num_blocks = num_blocks_list[sdx] - block_table = list( - range(block_base_idx, block_base_idx - num_blocks, -1)) + block_table = list(range(block_base_idx, block_base_idx - num_blocks, -1)) for idx in range(num_tokens): - mapping_value = ( - idx % block_size) + block_table[idx // block_size] * block_size + mapping_value = (idx % block_size) + block_table[ + idx // block_size + ] * block_size slot_mapping_list.append(mapping_value) block_base_idx -= num_blocks @@ -848,9 +884,9 @@ def make_test_metadata( decoder_test_params: Optional[PhaseTestParameters], device: Union[torch.device, str], encoder_test_params: Optional[PhaseTestParameters] = None, - cross_test_params: Optional[PhaseTestParameters] = None + cross_test_params: Optional[PhaseTestParameters] = None, ) -> AttentionMetadata: - ''' + """ Construct fake attention metadata for a given test phase (prefill-phase or decode-phase). @@ -887,13 +923,12 @@ def make_test_metadata( Return: * AttentionMetadata structure - ''' + """ # Decoder self-attention memory mapping # decoder_test_params is None signals encoder-only # scenario, so kv_mmap is None - kv_mmap = (None - if decoder_test_params is None else decoder_test_params.kv_mmap) + kv_mmap = None if decoder_test_params is None else decoder_test_params.kv_mmap # This function constructs metadata assuming no chunked prefill, # i.e. 
100% prefill tokens or 100% decode tokens @@ -906,10 +941,11 @@ def make_test_metadata( # seq_lens is None signals encoder-only # scenario, in which case num_prefills_or_decodes and # num_prefill_or_decode_tokens are unused - num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens)) + num_prefills_or_decodes = None if seq_lens is None else len(seq_lens) - num_prefill_or_decode_tokens = (None if seq_lens is None else ( - sum(seq_lens) if is_prompt else len(seq_lens))) + num_prefill_or_decode_tokens = ( + None if seq_lens is None else (sum(seq_lens) if is_prompt else len(seq_lens)) + ) # Seems for non-prefix-caching scenarios context_lens # is never needed @@ -923,16 +959,13 @@ def make_test_metadata( # * Extract encoder input sequence lengths assert encoder_test_params.packed_qkvo.packed_qkv is not None encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens - num_encoder_tokens = (None if encoder_seq_lens is None else - (sum(encoder_seq_lens))) + num_encoder_tokens = ( + None if encoder_seq_lens is None else (sum(encoder_seq_lens)) + ) - if cross_test_params is None: - cross_kv_mmap = None - else: - # Encoder/decoder or encoder-only models only: - # * Extract *cross-attention* slot_mapping and block table - # (kv_mmap) - cross_kv_mmap = cross_test_params.kv_mmap + # For encoder/decoder or encoder-only models only, extract *cross-attention* + # slot_mapping and block table (kv_mmap) + cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap attn_backend_obj = make_backend(attn_backend.name) @@ -952,10 +985,9 @@ def make_test_metadata( encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len, - ) = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ) = _make_metadata_tensors( + seq_lens, context_lens, encoder_seq_lens, device=device + ) return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), @@ -975,10 +1007,13 @@ def make_test_metadata( encoder_seq_lens_tensor=encoder_seq_lens_tensor, encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=(None if cross_kv_mmap is None else - cross_kv_mmap.slot_mapping), - cross_block_tables=(None if cross_kv_mmap is None else - cross_kv_mmap.block_tables)) + cross_slot_mapping=( + None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping + ), + cross_block_tables=( + None if cross_kv_mmap is None else cross_kv_mmap.block_tables + ), + ) else: # not is_prompt # Decode-phase scenario @@ -1000,10 +1035,9 @@ def make_test_metadata( encoder_seq_lens_tensor, encoder_seq_start_loc, max_encoder_seq_len, - ) = _make_metadata_tensors(seq_lens, - context_lens, - encoder_seq_lens, - device=device) + ) = _make_metadata_tensors( + seq_lens, context_lens, encoder_seq_lens, device=device + ) return attn_backend_obj.make_metadata( num_prefills=num_prefills, @@ -1025,16 +1059,19 @@ def make_test_metadata( encoder_seq_lens_tensor=encoder_seq_lens_tensor, encoder_seq_start_loc=encoder_seq_start_loc, max_encoder_seq_len=max_encoder_seq_len, - cross_slot_mapping=(None if cross_kv_mmap is None else - cross_kv_mmap.slot_mapping), - cross_block_tables=(None if cross_kv_mmap is None else - cross_kv_mmap.block_tables)) - - -def assert_actual_matches_ideal(test_params: PhaseTestParameters, - output_under_test: torch.Tensor, - backend: str) -> None: - ''' + cross_slot_mapping=( + None if cross_kv_mmap is None else cross_kv_mmap.slot_mapping + ), + 
cross_block_tables=( + None if cross_kv_mmap is None else cross_kv_mmap.block_tables + ), + ) + + +def assert_actual_matches_ideal( + test_params: PhaseTestParameters, output_under_test: torch.Tensor, backend: str +) -> None: + """ Assert that observed output matches the ideal output contained in the test parameters data structure. @@ -1042,24 +1079,24 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, * test_params: Test parameters including packed ideal output * output_under_test: actually observed output value - ''' + """ ideal_output = test_params.packed_qkvo.ideal_output - if backend == 'XFORMERS': - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output)) + if backend == "XFORMERS": + torch.testing.assert_close( + ideal_output, output_under_test.view_as(ideal_output) + ) - elif backend == 'FLASH_ATTN': + elif backend == "FLASH_ATTN": # For FlashAttention override the accuracy thresholds to non default # values since we notice a higher difference between the ideal and # actual output. - torch.testing.assert_close(ideal_output, - output_under_test.view_as(ideal_output), - atol=0.01, - rtol=0.016) + torch.testing.assert_close( + ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016 + ) else: raise ValueError( - f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or " - f"'FLASH_ATTN'.") + f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." + ) # Copied/modified from torch._refs.__init__.py @@ -1073,19 +1110,15 @@ def fp8_allclose( """ Reference implementation of torch.allclose """ - torch._refs._check_close_args(name="torch.allclose", - a=a, - b=b, - rtol=rtol, - atol=atol) + torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) return bool( torch.all( - torch.isclose(a.double(), - b.double(), - rtol=rtol, - atol=atol, - equal_nan=equal_nan)).item()) + torch.isclose( + a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan + ) + ).item() + ) # Marlin MoE test utils @@ -1098,7 +1131,8 @@ def stack_and_dev(tensors: list[torch.Tensor]): def compute_max_diff(output, output_ref): return torch.mean(torch.abs(output - output_ref)) / torch.mean( - torch.abs(output_ref)) + torch.abs(output_ref) + ) def torch_experts( @@ -1120,10 +1154,11 @@ def torch_experts( block_shape: Optional[list[int]] = None, apply_router_weights_on_input: bool = False, ) -> torch.Tensor: - assert (global_num_experts == -1 - or (global_num_experts == w1.shape[0] and expert_map is None) - or (expert_map is not None - and global_num_experts == expert_map.shape[0])) + assert ( + global_num_experts == -1 + or (global_num_experts == w1.shape[0] and expert_map is None) + or (expert_map is not None and global_num_experts == expert_map.shape[0]) + ) M, K = a.shape topk = topk_ids.shape[1] @@ -1138,8 +1173,9 @@ def torch_experts( if a1_scale: assert not per_act_token_quant and block_shape is None - a, a_scale = moe_kernel_quantize_input(a, a1_scale, quant_dtype, - per_act_token_quant, block_shape) + a, a_scale = moe_kernel_quantize_input( + a, a1_scale, quant_dtype, per_act_token_quant, block_shape + ) num_experts = w1.shape[0] @@ -1159,31 +1195,35 @@ def torch_experts( tmp2 = SiluAndMul()(tmp1) out[mask] = tmp2 @ w2[i].transpose(0, 1) if b_bias2 is not None: - out[mask] = out[mask] + b_bias2[i].view(1, -1).to( - tmp1.dtype) + out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype) elif block_shape is not None: # block quantized - assert (a_scale is not None and w1_scale is not None - 
and w2_scale is not None) - tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], - w1_scale[i], block_shape, - out.dtype) + assert ( + a_scale is not None + and w1_scale is not None + and w2_scale is not None + ) + tmp1 = native_w8a8_block_matmul( + a[mask], w1[i], a_scale[mask], w1_scale[i], block_shape, out.dtype + ) if b_bias1 is not None: tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype) tmp2 = SiluAndMul()(tmp1) tmp2, b_scale = moe_kernel_quantize_input( - tmp2, a2_scale, quant_dtype, per_act_token_quant, - block_shape) + tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape + ) - out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale, - w2_scale[i], block_shape, - out.dtype) + out[mask] = native_w8a8_block_matmul( + tmp2, w2[i], b_scale, w2_scale[i], block_shape, out.dtype + ) if b_bias2 is not None: - out[mask] = out[mask] + b_bias2[i].view(1, -1).to( - tmp1.dtype) + out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype) else: - assert (a_scale is not None and w1_scale is not None - and w2_scale is not None) + assert ( + a_scale is not None + and w1_scale is not None + and w2_scale is not None + ) scales = a_scale if a_scale.numel() == 1 else a_scale[mask] tmp1 = a[mask].to(f32) * scales @@ -1195,37 +1235,50 @@ def torch_experts( tmp2 = SiluAndMul()(tmp1).to(out.dtype) tmp2, b_scale = moe_kernel_quantize_input( - tmp2, a2_scale, quant_dtype, per_act_token_quant, - block_shape) + tmp2, a2_scale, quant_dtype, per_act_token_quant, block_shape + ) assert b_scale is not None tmp2 = tmp2.to(f32) * b_scale w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) out[mask] = (tmp2 @ w2_dq).to(out.dtype) if b_bias2 is not None: - out[mask] = out[mask] + b_bias2[i].view(1, -1).to( - out.dtype) + out[mask] = out[mask] + b_bias2[i].view(1, -1).to(out.dtype) if apply_router_weights_on_input: return out else: - return (out.view(M, -1, w2.shape[1]).to(f32) * - topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype) - - -def torch_moe(a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - score: torch.Tensor, - topk: int, - b_bias1: Optional[torch.Tensor] = None, - b_bias2: Optional[torch.Tensor] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None) -> torch.Tensor: + return ( + (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1)) + .sum(dim=1) + .to(out.dtype) + ) + + +def torch_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, + b_bias1: Optional[torch.Tensor] = None, + b_bias2: Optional[torch.Tensor] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, +) -> torch.Tensor: score = torch.softmax(score, dim=-1, dtype=torch.float32) topk_weight, topk_ids = torch.topk(score, topk) - return torch_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts, - b_bias1, b_bias2, expert_map) + return torch_experts( + a, + w1, + w2, + topk_weight, + topk_ids, + global_num_experts, + b_bias1, + b_bias2, + expert_map, + ) def torch_moe_single(a, w, score, topk): @@ -1244,41 +1297,49 @@ def torch_moe_single(a, w, score, topk): # A special version of op check that has a restricted default set of test_utils # and a patched version of allclose that supports fp8 types. 
-def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, - torch._library.custom_ops.CustomOpDef], - args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, - *, - test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, - raise_exception: bool = True, - cond: bool = True) -> dict[str, str]: - with unittest.mock.patch('torch.allclose', new=fp8_allclose): - return torch.library.opcheck( - op, - args, - kwargs, - test_utils=test_utils, - raise_exception=raise_exception) if cond else {} +def opcheck( + op: Union[ + torch._ops.OpOverload, + torch._ops.OpOverloadPacket, + torch._library.custom_ops.CustomOpDef, + ], + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, + raise_exception: bool = True, + cond: bool = True, +) -> dict[str, str]: + with unittest.mock.patch("torch.allclose", new=fp8_allclose): + return ( + torch.library.opcheck( + op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception + ) + if cond + else {} + ) # For testing quantized linear kernels def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) def to_int8(tensor: torch.Tensor): return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) -def baseline_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - +def baseline_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: # We treat N-dimensional group scaling as extended numpy-style broadcasting # in numpy simply stretches dimensions with an extent of 1 to match # the target shape by repeating the data along that dimension (broadcasting) @@ -1297,16 +1358,19 @@ def group_broadcast(t, shape): for i, s in enumerate(shape): if t.shape[i] != s and t.shape[i] != 1: assert s % t.shape[i] == 0 - t = t.unsqueeze(i + 1)\ - .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\ - .flatten(i, i + 1) + t = ( + t.unsqueeze(i + 1) + .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :]) + .flatten(i, i + 1) + ) return t scale_a = group_broadcast(scale_a, a.shape) scale_b = group_broadcast(scale_b, b.shape) - output = torch.mm((scale_a * a.to(dtype=torch.float32)), - (scale_b * b.to(dtype=torch.float32))).to(out_dtype) + output = torch.mm( + (scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32)) + ).to(out_dtype) if bias is not None: output = output + bias diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py index ca2f04dabfc9..a61ccef70062 100644 --- a/tests/kv_transfer/test_lookup_buffer.py +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -8,8 +8,7 @@ from tqdm import tqdm from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( - SimpleBuffer) +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import SimpleBuffer from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe # TODO: the test depends on a lot of fields in the current implementation. 
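The group_broadcast helper inside baseline_scaled_mm above stretches a per-group scale tensor toward the full operand shape, numpy-broadcast style, by repeating each scale across its group; size-1 dims are left for ordinary broadcasting. A small standalone demo of that behaviour (the surrounding scaled-mm plumbing is omitted):

import torch

def group_broadcast(t: torch.Tensor, shape):
    # Any dim whose extent divides the target extent is repeated to match it.
    for i, s in enumerate(shape):
        if t.shape[i] != s and t.shape[i] != 1:
            assert s % t.shape[i] == 0
            t = (
                t.unsqueeze(i + 1)
                .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :])
                .flatten(i, i + 1)
            )
    return t

a = torch.ones(4, 3)
scale_a = torch.tensor([[2.0], [3.0]])        # one scale per group of two rows
print(group_broadcast(scale_a, a.shape) * a)  # rows 0-1 scaled by 2, rows 2-3 by 3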
@@ -17,7 +16,6 @@ def test_run(my_rank, buffer, device): - # buffer should be empty in the beginning if my_rank == 0: assert buffer.buffer_size == 0 @@ -27,7 +25,7 @@ def test_run(my_rank, buffer, device): # insert tokens = torch.tensor([1, 2, 3]).to(device) - roi = (tokens > 0) + roi = tokens > 0 if my_rank == 0: key = 2.0 * torch.ones([5, 6]).to(device) value = 3.0 * torch.ones([5, 6]).to(device) @@ -55,7 +53,6 @@ def test_run(my_rank, buffer, device): def stress_test(my_rank, buf, device): - torch.distributed.barrier() torch.manual_seed(100) @@ -66,7 +63,8 @@ def stress_test(my_rank, buf, device): torch.rand(100).to(device), # key torch.rand(100).to(device), # value torch.rand(100).to(device), # hidden - ) for i in tqdm(range(200)) + ) + for i in tqdm(range(200)) ] random.seed(my_rank) @@ -115,12 +113,11 @@ def stress_test(my_rank, buf, device): if __name__ == "__main__": - - my_rank = int(os.environ['RANK']) + my_rank = int(os.environ["RANK"]) torch.distributed.init_process_group( - backend='gloo', - init_method='tcp://localhost:12398', + backend="gloo", + init_method="tcp://localhost:12398", world_size=2, rank=my_rank, ) @@ -128,8 +125,8 @@ def stress_test(my_rank, buf, device): print(f"initialized! My rank is {my_rank}") config = KVTransferConfig( - kv_connector='P2pNcclConnector', - kv_buffer_device='cuda', + kv_connector="P2pNcclConnector", + kv_buffer_device="cuda", kv_buffer_size=1e9, kv_rank=my_rank, kv_role="kv_both", # this arg doesn't matter in this test @@ -160,4 +157,4 @@ def stress_test(my_rank, buf, device): buffer.close() data_pipe.close() cpu_pipe.close() - print('Done') + print("Done") diff --git a/tests/kv_transfer/test_module.py b/tests/kv_transfer/test_module.py index 7a04174870da..b9a28e4bceb7 100644 --- a/tests/kv_transfer/test_module.py +++ b/tests/kv_transfer/test_module.py @@ -9,21 +9,19 @@ def run_python_script(script_name, timeout): - script_name = f'kv_transfer/{script_name}' + script_name = f"kv_transfer/{script_name}" try: # Start both processes asynchronously using Popen process0 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": - "0"}, # Set the RANK environment variable for process 0 + env={"RANK": "0"}, # Set the RANK environment variable for process 0 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) process1 = subprocess.Popen( [sys.executable, script_name], - env={"RANK": - "1"}, # Set the RANK environment variable for process 1 + env={"RANK": "1"}, # Set the RANK environment variable for process 1 stdout=sys.stdout, # Pipe stdout to current stdout stderr=sys.stderr, # Pipe stderr to current stderr ) @@ -34,11 +32,9 @@ def run_python_script(script_name, timeout): # Check the return status of both processes if process0.returncode != 0: - pytest.fail( - f"Test {script_name} failed for RANK=0, {process0.returncode}") + pytest.fail(f"Test {script_name} failed for RANK=0, {process0.returncode}") if process1.returncode != 0: - pytest.fail( - f"Test {script_name} failed for RANK=1, {process1.returncode}") + pytest.fail(f"Test {script_name} failed for RANK=1, {process1.returncode}") except subprocess.TimeoutExpired: # If either process times out, terminate both and fail the test @@ -53,15 +49,14 @@ def run_python_script(script_name, timeout): @pytest.mark.parametrize( "script_name,timeout", [ - ("test_lookup_buffer.py", - 60), # Second test case with a 60-second timeout - ("test_send_recv.py", 120) # First test case with a 120-second timeout - ]) + ("test_lookup_buffer.py", 60), # Second 
test case with a 60-second timeout + ("test_send_recv.py", 120), # First test case with a 120-second timeout + ], +) def test_run_python_script(script_name, timeout): # Check the number of GPUs if torch.cuda.device_count() < 2: - pytest.skip( - f"Skipping test {script_name} because <2 GPUs are available") + pytest.skip(f"Skipping test {script_name} because <2 GPUs are available") # Run the test if there are at least 2 GPUs run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py index 99ad2b43aeac..5762224eff76 100644 --- a/tests/kv_transfer/test_send_recv.py +++ b/tests/kv_transfer/test_send_recv.py @@ -15,7 +15,7 @@ def test_run(my_rank, pipe): print(f"rank {my_rank} test_run starts....") # test run x = torch.tensor([1]).to(pipe.device) - y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device) + y = torch.tensor([[2.0, 3.0, 4.0, 8.0]]).to(pipe.device) if my_rank == 0: pipe.send_tensor(x) print(f"rank {my_rank} sent tensor x") @@ -53,9 +53,8 @@ def stress_test(my_rank, pipe): for i in tqdm(range(500)): mean = torch.rand(1).item() * 100 std = torch.rand(1).item() * 100 - size = torch.randint(900, 1000, (2, )) - x = torch.normal(mean * 1.0, std * 1.0, - size=size.tolist()).to(pipe.device) + size = torch.randint(900, 1000, (2,)) + x = torch.normal(mean * 1.0, std * 1.0, size=size.tolist()).to(pipe.device) # 5% probability of sending a None if torch.rand(1).item() < 0.05: @@ -96,20 +95,16 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.barrier() for i in tqdm(range(500)): - tensors = [] if my_rank == 0: # create tensor - tensors = [ - torch.rand(nelement).to(pipe.device) for _ in range(ntensor) - ] + tensors = [torch.rand(nelement).to(pipe.device) for _ in range(ntensor)] torch.distributed.barrier() if my_rank == 0: - t = torch.tensor([time.time()], - dtype=torch.float64).to(pipe.device) + t = torch.tensor([time.time()], dtype=torch.float64).to(pipe.device) for tensor in tensors: pipe.send_tensor(tensor) pipe.send_tensor(t) @@ -121,24 +116,23 @@ def latency_test(my_rank, pipe, nelement, ntensor): torch.distributed.barrier() - print('Latency test passed.') - print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') + print("Latency test passed.") + print("Latency:", torch.tensor(latencies).mean().item() * 1000, "ms") if __name__ == "__main__": - - my_rank = int(os.environ['RANK']) + my_rank = int(os.environ["RANK"]) torch.distributed.init_process_group( - backend='gloo', - init_method='tcp://localhost:12398', + backend="gloo", + init_method="tcp://localhost:12398", world_size=2, rank=my_rank, ) config = KVTransferConfig( - kv_connector='P2pNcclConnector', - kv_buffer_device='cuda', + kv_connector="P2pNcclConnector", + kv_buffer_device="cuda", kv_buffer_size=1e9, kv_rank=my_rank, kv_role="kv_both", # this arg doesn't matter in this test diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index b539a7bf5d76..f805a74a4dba 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -10,12 +10,16 @@ import torch.nn as nn from huggingface_hub import snapshot_download -from vllm.distributed import (cleanup_dist_env_and_memory, - init_distributed_environment, - initialize_model_parallel) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) +from vllm.model_executor.layers.linear import ( + 
ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.interfaces import SupportsLoRA @@ -47,11 +51,13 @@ def dist_init(): if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" - init_distributed_environment(world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend=backend) + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend=backend, + ) initialize_model_parallel(1, 1) yield cleanup_dist_env_and_memory(shutdown_ray=True) @@ -66,10 +72,9 @@ def dist_init_torch_only(): backend = "gloo" temp_file = tempfile.mkstemp()[1] - torch.distributed.init_process_group(world_size=1, - rank=0, - init_method=f"file://{temp_file}", - backend=backend) + torch.distributed.init_process_group( + world_size=1, rank=0, init_method=f"file://{temp_file}", backend=backend + ) class DummyLoRAModel(nn.Sequential, SupportsLoRA): @@ -79,24 +84,30 @@ class DummyLoRAModel(nn.Sequential, SupportsLoRA): @pytest.fixture def dummy_model() -> nn.Module: model = DummyLoRAModel( - OrderedDict([ - ("dense1", ColumnParallelLinear(764, 100)), - ("dense2", RowParallelLinear(100, 50)), - ( - "layer1", - nn.Sequential( - OrderedDict([ - ("dense1", ColumnParallelLinear(100, 10)), - ("dense2", RowParallelLinear(10, 50)), - ])), - ), - ("act2", nn.ReLU()), - ("output", ColumnParallelLinear(50, 10)), - ("outact", nn.Sigmoid()), - # Special handling for lm_head & sampler - ("lm_head", ParallelLMHead(512, 10)), - ("logits_processor", LogitsProcessor(512)), - ])) + OrderedDict( + [ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict( + [ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ] + ) + ), + ), + ("act2", nn.ReLU()), + ("output", ColumnParallelLinear(50, 10)), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("logits_processor", LogitsProcessor(512)), + ] + ) + ) model.config = MagicMock() model.embedding_modules = {"lm_head": "lm_head"} model.unpadded_vocab_size = 32000 @@ -106,24 +117,30 @@ def dummy_model() -> nn.Module: @pytest.fixture def dummy_model_gate_up() -> nn.Module: model = DummyLoRAModel( - OrderedDict([ - ("dense1", ColumnParallelLinear(764, 100)), - ("dense2", RowParallelLinear(100, 50)), - ( - "layer1", - nn.Sequential( - OrderedDict([ - ("dense1", ColumnParallelLinear(100, 10)), - ("dense2", RowParallelLinear(10, 50)), - ])), - ), - ("act2", nn.ReLU()), - ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), - ("outact", nn.Sigmoid()), - # Special handling for lm_head & sampler - ("lm_head", ParallelLMHead(512, 10)), - ("logits_processor", LogitsProcessor(512)), - ])) + OrderedDict( + [ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict( + [ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ] + ) + ), + ), + ("act2", nn.ReLU()), + ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("logits_processor", LogitsProcessor(512)), + ] + ) + ) model.config = 
MagicMock() model.packed_modules_mapping = { "gate_up_proj": [ diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py index 35d024575915..2f28253bce53 100644 --- a/tests/lora/test_add_lora.py +++ b/tests/lora/test_add_lora.py @@ -7,7 +7,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, +) from vllm.inputs import TextPrompt from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams @@ -26,14 +27,10 @@ def get_lora_requests(lora_path) -> list[LoRARequest]: return lora_requests -async def requests_processing_time(llm, - lora_requests: list[LoRARequest]) -> float: - - sampling_params = SamplingParams(n=1, - temperature=0.0, - top_p=1.0, - ignore_eos=True, - max_tokens=1) +async def requests_processing_time(llm, lora_requests: list[LoRARequest]) -> float: + sampling_params = SamplingParams( + n=1, temperature=0.0, top_p=1.0, ignore_eos=True, max_tokens=1 + ) generators = [] start = time.perf_counter() @@ -41,11 +38,11 @@ async def requests_processing_time(llm, for lora_request in lora_requests: lora_int_id = lora_request.lora_int_id generator = llm.generate( - prompt=TextPrompt(prompt=f"hello {lora_int_id}", - multi_modal_data=None), # type: ignore + prompt=TextPrompt(prompt=f"hello {lora_int_id}", multi_modal_data=None), # type: ignore sampling_params=sampling_params, lora_request=lora_request, - request_id=f"test{lora_int_id}") + request_id=f"test{lora_int_id}", + ) generators.append(generator) all_gens = merge_async_iterators(*generators) @@ -58,13 +55,13 @@ async def requests_processing_time(llm, @pytest.mark.asyncio async def test_add_lora(chatglm3_lora_files): - """ + """ The add_lora function is used to preload some LoRA adapters into the engine in anticipation of future requests using these adapters. To test this functionality, we use the async engine to process some requests - We do it twice, once with add_lora() preloading and once without. - We measure the request processing time in both cases and expect the time + We measure the request processing time in both cases and expect the time to be lesser in the case with add_lora() calls. """ lora_requests: list[LoRARequest] = get_lora_requests(chatglm3_lora_files) @@ -78,18 +75,18 @@ async def test_add_lora(chatglm3_lora_files): max_loras=max_loras, max_lora_rank=LORA_RANK, max_model_len=128, - gpu_memory_utilization=0.8, #avoid OOM + gpu_memory_utilization=0.8, # avoid OOM trust_remote_code=True, - enforce_eager=True) + enforce_eager=True, + ) # split lora_requests into 3 parts part_size = len(lora_requests) // 3 dummy_run_requests = lora_requests[:part_size] - warmup_run_requests = lora_requests[part_size:part_size * 2] - cold_run_requests = lora_requests[part_size * 2:] + warmup_run_requests = lora_requests[part_size : part_size * 2] + cold_run_requests = lora_requests[part_size * 2 :] async with build_async_engine_client_from_engine_args(engine_args) as llm: - # Dummy run - So any 1-time functionality like triton kernel compilation # is complete here. await requests_processing_time(llm, dummy_run_requests) @@ -101,18 +98,16 @@ async def test_add_lora(chatglm3_lora_files): # Test that all all_lora calls are successful. 
assert all(add_lora_results) - time_with_add_lora = await requests_processing_time( - llm, warmup_run_requests) + time_with_add_lora = await requests_processing_time(llm, warmup_run_requests) # Run without any warmup - time_cold_start = await requests_processing_time( - llm, cold_run_requests) + time_cold_start = await requests_processing_time(llm, cold_run_requests) - print(f"time hot-start {time_with_add_lora} vs " - f"time cold-start {time_cold_start} ") + print(f"time hot-start {time_with_add_lora} vs time cold-start {time_cold_start} ") assert time_with_add_lora < time_cold_start, ( f"time_with_add_lora={time_with_add_lora}, " f"time_cold_start={time_cold_start}" "The engine request processing time with LoRA pre-loading " - "must be less than the version that does on-demand LoRA loading.") + "must be less than the version that does on-demand LoRA loading." + ) diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index 5cffb8cfcc26..d8058c5f87a8 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -12,7 +12,7 @@ EXPECTED_LORA_OUTPUT = [ "SELECT count(*) FROM singer", - "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 + "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", "SELECT name , country , age FROM singer ORDER BY age", ] @@ -21,20 +21,24 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + query=( + "What is the average, minimum, and maximum " + "age of all singers from France?" + ) ), PROMPT_TEMPLATE.format( - query= - "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 + query=( + "Show name, country, age for all singers ordered " + "by age from the oldest to the youngest." + ) ), ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -47,13 +51,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: @create_new_process_for_each_test() def test_chatglm3_lora(chatglm3_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -66,15 +72,17 @@ def test_chatglm3_lora(chatglm3_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_chatglm3_lora_tp4(chatglm3_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=False, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -90,16 +98,18 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): # https://github.com/NVIDIA/nccl/issues/1790, set a lower value for # gpu_memory_utilization here because NCCL >= 2.26.3 seems to use # more GPU memory causing vLLM to OOM - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=True, - enable_chunked_prefill=True, - gpu_memory_utilization=0.85) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True, + enable_chunked_prefill=True, + gpu_memory_utilization=0.85, + ) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): assert output1[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_default_mm_loras.py b/tests/lora/test_default_mm_loras.py index f615ceda76b5..1a5b9ba3641d 100644 --- a/tests/lora/test_default_mm_loras.py +++ b/tests/lora/test_default_mm_loras.py @@ -32,15 +32,12 @@ "max_lora_rank": 320, "max_model_len": 12800, "gpu_memory_utilization": 0.8, - "limit_mm_per_prompt": { - "audio": 1 - }, + "limit_mm_per_prompt": {"audio": 1}, "enforce_eager": True, } -def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, - **kwargs): +def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, **kwargs): inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])] # Apply any additional kwargs as overrides to the base kwargs @@ -53,11 +50,11 @@ def run_test(vllm_runner, audio_assets, lora_request, expected_suffix, max_tokens=128, audios=audios, lora_request=lora_request, - ) for prompts, audios in inputs + ) + for prompts, audios in inputs ] - assert vllm_outputs_with_default_lora[-1][-1][-1].endswith( - expected_suffix) + assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(expected_suffix) def test_active_default_mm_lora( diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index ced0afc50cb9..695e06e7c1d6 100644 --- 
a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -12,32 +12,38 @@ import torch.nn.functional as F from vllm.config.lora import LoRAConfig -# yapf conflicts with isort for this block -# yapf: disable -from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - ColumnParallelLinearWithShardedLoRA, - LogitsProcessorWithLoRA, LoRAMapping, - MergedColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithLoRA, - MergedQKVParallelLinearWithShardedLoRA, - QKVParallelLinearWithLoRA, - QKVParallelLinearWithShardedLoRA, - ReplicatedLinearWithLoRA, - RowParallelLinearWithLoRA, - RowParallelLinearWithShardedLoRA, - VocabParallelEmbeddingWithLoRA) -# yapf: enable +from vllm.lora.layers import ( + BaseLayerWithLoRA, + ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, + LogitsProcessorWithLoRA, + LoRAMapping, + MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, + QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, + VocabParallelEmbeddingWithLoRA, +) from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.punica_wrapper import get_punica_wrapper -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) + ParallelLMHead, + VocabParallelEmbedding, + get_masked_input_and_mask, +) from vllm.model_executor.utils import set_random_seed from vllm.platforms import current_platform @@ -51,11 +57,14 @@ pytestmark = pytest.mark.skipif( not (current_platform.is_cuda_alike() or current_platform.is_cpu()), - reason="Backend not supported") + reason="Backend not supported", +) -DEVICES = ([ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] if current_platform.is_cuda_alike() else ["cpu"]) +DEVICES = ( + [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + if current_platform.is_cuda_alike() + else ["cpu"] +) # prefill stage(True) or decode stage(False) STAGES = [True, False] @@ -68,8 +77,8 @@ @pytest.fixture(autouse=True) def clean_cache_reset_device(reset_default_device): # Release any memory we might be holding on to. CI runs OOMs otherwise. - from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT, - _LORA_B_PTR_DICT) + from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT + _LORA_B_PTR_DICT.clear() _LORA_A_PTR_DICT.clear() @@ -79,13 +88,14 @@ def clean_cache_reset_device(reset_default_device): @pytest.fixture(autouse=True) def skip_cuda_with_stage_false(request): """ - On cuda-like platforms, we use the same kernels for prefill and decode + On cuda-like platforms, we use the same kernels for prefill and decode stage, and 'stage' is generally ignored, so we only need to test once. 
""" if current_platform.is_cuda_alike(): try: if hasattr(request.node, "callspec") and hasattr( - request.node.callspec, "params"): + request.node.callspec, "params" + ): params = request.node.callspec.params if "stage" in params and params["stage"] is False: pytest.skip("Skip test when stage=False") @@ -94,9 +104,9 @@ def skip_cuda_with_stage_false(request): yield -def get_random_id_to_index(num_loras: int, - num_slots: int, - log: bool = True) -> list[Optional[int]]: +def get_random_id_to_index( + num_loras: int, num_slots: int, log: bool = True +) -> list[Optional[int]]: """Creates a random lora_id_to_index mapping. Args: @@ -109,7 +119,8 @@ def get_random_id_to_index(num_loras: int, if num_loras > num_slots: raise ValueError( f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " - "num_loras must be less than or equal to num_slots.") + "num_loras must be less than or equal to num_slots." + ) slots: list[Optional[int]] = [None] * num_slots random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() @@ -158,19 +169,18 @@ def populate_loras( subloras: list[LoRALayerWeights] = [] sublora_len = layer_weights.shape[0] // repeats for i in range(repeats): - sublora = DummyLoRAManager( - layer_weights.device).init_random_lora( - module_name=f"fake_{i}", - weight=layer_weights, - generate_embeddings_tensor=generate_embeddings_tensor, - ) - sublora.lora_b = sublora.lora_b[(sublora_len * - i):(sublora_len * (i + 1)), :] + sublora = DummyLoRAManager(layer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[ + (sublora_len * i) : (sublora_len * (i + 1)), : + ] sublora.optimize() subloras.append(sublora) - lora = PackedLoRALayerWeights.pack( - subloras) if repeats > 1 else subloras[0] + lora = PackedLoRALayerWeights.pack(subloras) if repeats > 1 else subloras[0] layer.set_lora( slot_idx, @@ -191,7 +201,7 @@ def create_random_inputs( input_size: tuple[int, ...], input_range: tuple[float, float], input_type: torch.dtype = torch.int, - device: torch.device = "cuda" + device: torch.device = "cuda", ) -> tuple[list[torch.Tensor], list[int], list[int]]: """Creates random inputs. 
@@ -213,14 +223,15 @@ def create_random_inputs( for _ in range(num_inputs): if input_type == torch.int: inputs.append( - torch.randint(low=int(low), - high=int(high), - size=input_size, - device=device)) + torch.randint( + low=int(low), high=int(high), size=input_size, device=device + ) + ) else: inputs.append( - torch.rand(size=input_size, dtype=input_type, device=device) * - high + low) + torch.rand(size=input_size, dtype=input_type, device=device) * high + + low + ) lora_id = random.choice(active_lora_ids) index_mapping += [lora_id] * input_size[0] @@ -258,9 +269,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 + ) def create_random_embedding_layer(): embedding = VocabParallelEmbedding(vocab_size, 256) @@ -286,15 +297,18 @@ def create_random_embedding_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_embedding(torch.cat(inputs)) @@ -306,15 +320,12 @@ def create_random_embedding_layer(): input_, lora.lora_a.T, ) - result += (after_a @ lora.lora_b.T) + result += after_a @ lora.lora_b.T expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -324,24 +335,24 @@ def create_random_embedding_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_embedding(torch.cat(inputs)) expected_result = embedding(torch.cat(inputs)) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() @@ -351,9 +362,9 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", 
STAGES) -def test_embeddings_with_new_embeddings(dist_init, num_loras, device, - vocab_size, stage) -> None: - +def test_embeddings_with_new_embeddings( + dist_init, num_loras, device, vocab_size, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -361,9 +372,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16 + ) def create_random_embedding_layer(): embedding = VocabParallelEmbedding(vocab_size, 256) @@ -373,12 +384,12 @@ def create_random_embedding_layer(): expanded_embedding = VocabParallelEmbedding( vocab_size + lora_config.lora_extra_vocab_size * max_loras, 256, - org_num_embeddings=vocab_size) + org_num_embeddings=vocab_size, + ) expanded_embedding.weight.data[:vocab_size, :] = embedding_data # We need to deepcopy the embedding as it will be modified # in place - lora_embedding = VocabParallelEmbeddingWithLoRA( - deepcopy(expanded_embedding)) + lora_embedding = VocabParallelEmbeddingWithLoRA(deepcopy(expanded_embedding)) lora_embedding.create_lora_weights(max_loras, lora_config) return expanded_embedding, lora_embedding @@ -392,7 +403,8 @@ def create_random_embedding_layer(): id_to_index, layer=lora_embedding, layer_weights=torch.zeros( - (256, vocab_size + lora_config.lora_extra_vocab_size)), + (256, vocab_size + lora_config.lora_extra_vocab_size) + ), generate_embeddings_tensor=256, ) @@ -410,52 +422,53 @@ def create_random_embedding_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) original_inputs = deepcopy(inputs) # Force some of the inputs to be in the extended embeddings range # to guarantee that their behavior is tested. 
- for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): + for input_, original_input_, lora_id in zip( + inputs, original_inputs, prompt_mapping + ): embedding_id = lora_id - 1 input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) original_input_[-1] = vocab_size - input_[-2] = vocab_size + ( - (embedding_id + 1) * embeddings_tensor_len - 1) + input_[-2] = vocab_size + ((embedding_id + 1) * embeddings_tensor_len - 1) original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - expanded_embedding.weight[vocab_size:vocab_size + - (embeddings_tensor_len * - max_loras)] = torch.cat(embeddings_tensors) + expanded_embedding.weight[ + vocab_size : vocab_size + (embeddings_tensor_len * max_loras) + ] = torch.cat(embeddings_tensors) lora_result = lora_embedding(torch.cat(original_inputs)) expected_results: list[torch.Tensor] = [] - for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): + for input_, original_input_, lora_id in zip( + inputs, original_inputs, prompt_mapping + ): lora = lora_dict[lora_id] result = expanded_embedding(input_) after_a = F.embedding( original_input_, lora.lora_a.T, ) - result += (after_a @ lora.lora_b.T) + result += after_a @ lora.lora_b.T expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -465,24 +478,24 @@ def create_random_embedding_layer(): inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=[0], num_inputs=num_loras * 3, - input_size=(200, ), + input_size=(200,), input_range=(1, vocab_size), - device=device) + device=device, + ) original_inputs = deepcopy(inputs) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) lora_result = lora_embedding(torch.cat(original_inputs)) expected_result = expanded_embedding(torch.cat(inputs)) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() @@ -490,9 +503,9 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) -def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, - stage) -> None: - +def test_lm_head_logits_processor( + dist_init, num_loras, device, vocab_size, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -500,22 +513,25 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) + lora_config = LoRAConfig( + max_loras=max_loras, max_lora_rank=8, 
lora_dtype=torch.float16 + ) def _pretest(): - linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, - 1024, - vocab_size, - params_dtype=torch.float16) + linear = ParallelLMHead( + vocab_size + lora_config.lora_extra_vocab_size, + 1024, + vocab_size, + params_dtype=torch.float16, + ) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, vocab_size:] = 0 logits_processor = LogitsProcessor( - vocab_size + lora_config.lora_extra_vocab_size, vocab_size) + vocab_size + lora_config.lora_extra_vocab_size, vocab_size + ) lora_logits_processor = LogitsProcessorWithLoRA( - logits_processor, 1024, linear.weight.dtype, linear.weight.device, - None) + logits_processor, 1024, linear.weight.dtype, linear.weight.device, None + ) lora_logits_processor.create_lora_weights(max_loras, lora_config) return linear, logits_processor, lora_logits_processor @@ -542,10 +558,9 @@ def _pretest(): input_size=(1, 1024), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -556,25 +571,24 @@ def _pretest(): input_ = torch.rand(20, 1024) lora_result = lora_logits_processor._get_logits( - hidden_states=torch.cat(inputs), - lm_head=linear, - embedding_bias=None) + hidden_states=torch.cat(inputs), lm_head=linear, embedding_bias=None + ) original_lm_head = deepcopy(linear) - linear.weight[logits_processor. - org_vocab_size:logits_processor.org_vocab_size + - embeddings_tensor_len] = embeddings_tensor + linear.weight[ + logits_processor.org_vocab_size : logits_processor.org_vocab_size + + embeddings_tensor_len + ] = embeddings_tensor - logits_processor.org_vocab_size = (vocab_size + - lora_config.lora_extra_vocab_size) + logits_processor.org_vocab_size = vocab_size + lora_config.lora_extra_vocab_size expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): lora = lora_dict[lora_id] - result = logits_processor._get_logits(hidden_states=input_, - lm_head=linear, - embedding_bias=None) - result[:, vocab_size + embeddings_tensor_len:] = float("-inf") + result = logits_processor._get_logits( + hidden_states=input_, lm_head=linear, embedding_bias=None + ) + result[:, vocab_size + embeddings_tensor_len :] = float("-inf") result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) @@ -591,10 +605,9 @@ def _pretest(): input_size=(1, 1024), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -606,17 +619,16 @@ def _pretest(): lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), lm_head=original_lm_head, - embedding_bias=None)[:, :vocab_size] + embedding_bias=None, + )[:, :vocab_size] expected_result = logits_processor._get_logits( hidden_states=torch.cat(inputs), lm_head=original_lm_head, - embedding_bias=None) + embedding_bias=None, + ) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) 
@torch.inference_mode() @@ -629,7 +641,6 @@ def test_linear_replicated( device, stage, ) -> None: - if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -644,17 +655,17 @@ def test_linear_replicated( ) def create_random_linear_replicated_layer(): - - linear = ReplicatedLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) + linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = ReplicatedLinearWithLoRA(linear) lora_linear.create_lora_weights(max_loras, lora_config) - assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == 1) + assert ( + lora_linear.n_slices + == len(lora_linear.lora_a_stacked) + == len(lora_linear.lora_b_stacked) + == 1 + ) return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -676,10 +687,9 @@ def create_random_linear_replicated_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -692,7 +702,6 @@ def create_random_linear_replicated_layer(): expected_results: list[torch.Tensor] = [] for input_, lora_id in zip(inputs, prompt_mapping): - lora = lora_dict[lora_id] result = linear(input_)[0] result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling @@ -700,10 +709,7 @@ def create_random_linear_replicated_layer(): expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -716,22 +722,19 @@ def create_random_linear_replicated_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + punica_wrapper.update_metadata( + lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size + ) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() @@ -740,9 +743,9 @@ def create_random_linear_replicated_layer(): @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage) -> None: - +def test_linear_parallel( + dist_init, num_loras, orientation, fully_shard, device, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -759,25 +762,32 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, def create_random_linear_parallel_layer(): if orientation == "row": - linear = RowParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) + linear = 
RowParallelLinear( + 4096, 4096, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard - else RowParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + RowParallelLinearWithLoRA(linear) + if not fully_shard + else RowParallelLinearWithShardedLoRA(linear) + ) else: - linear = ColumnParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) + linear = ColumnParallelLinear( + 4096, 4096, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (ColumnParallelLinearWithLoRA(linear) - if not fully_shard else - ColumnParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + ColumnParallelLinearWithLoRA(linear) + if not fully_shard + else ColumnParallelLinearWithShardedLoRA(linear) + ) lora_linear.create_lora_weights(max_loras, lora_config) - assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == 1) + assert ( + lora_linear.n_slices + == len(lora_linear.lora_a_stacked) + == len(lora_linear.lora_b_stacked) + == 1 + ) return linear, lora_linear @@ -800,10 +810,9 @@ def create_random_linear_parallel_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, id_to_index, @@ -823,10 +832,7 @@ def create_random_linear_parallel_layer(): expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) # Check that resetting the lora weights succeeds @@ -839,22 +845,19 @@ def create_random_linear_parallel_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + punica_wrapper.update_metadata( + lora_mapping, id_to_index, max_loras, 512, lora_config.lora_extra_vocab_size + ) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @torch.inference_mode() @@ -863,9 +866,9 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage) -> None: - +def test_column_parallel_packed( + dist_init, num_loras, repeats, fully_shard, device, stage +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -882,33 +885,35 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, def create_column_parallel_packed_layer(): if repeats == 2: - linear = MergedColumnParallelLinear(4096, [4096] * repeats, - 
bias=False, - params_dtype=torch.float16) + linear = MergedColumnParallelLinear( + 4096, [4096] * repeats, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedColumnParallelLinearWithLoRA(linear) - if not fully_shard else - MergedColumnParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + MergedColumnParallelLinearWithLoRA(linear) + if not fully_shard + else MergedColumnParallelLinearWithShardedLoRA(linear) + ) elif repeats == 3: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) + linear = QKVParallelLinear( + 4096, 64, 32, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedQKVParallelLinearWithLoRA(linear) - if not fully_shard else - MergedQKVParallelLinearWithShardedLoRA(linear)) + lora_linear = ( + MergedQKVParallelLinearWithLoRA(linear) + if not fully_shard + else MergedQKVParallelLinearWithShardedLoRA(linear) + ) else: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) + linear = QKVParallelLinear( + 4096, 64, 32, bias=False, params_dtype=torch.float16 + ) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = QKVParallelLinearWithLoRA( - linear - ) if not fully_shard else QKVParallelLinearWithShardedLoRA(linear) + lora_linear = ( + QKVParallelLinearWithLoRA(linear) + if not fully_shard + else QKVParallelLinearWithShardedLoRA(linear) + ) @dataclass class FakeConfig: @@ -917,11 +922,15 @@ class FakeConfig: num_attention_heads = 32 n_slices = repeats - lora_linear.create_lora_weights(max_loras, - lora_config, - model_config=FakeConfig()) - assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( - lora_linear.lora_b_stacked) == n_slices) + lora_linear.create_lora_weights( + max_loras, lora_config, model_config=FakeConfig() + ) + assert ( + lora_linear.n_slices + == len(lora_linear.lora_a_stacked) + == len(lora_linear.lora_b_stacked) + == n_slices + ) return linear, lora_linear @@ -946,10 +955,9 @@ class FakeConfig: input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( lora_mapping, @@ -966,18 +974,14 @@ class FakeConfig: result = linear(input_)[0] subloras = sublora_dict[lora_id] for i, sublora in enumerate(subloras): - result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] * - (i + 1)] += ( - input_ @ sublora.lora_a.T @ sublora.lora_b.T * - sublora.scaling) + result[ + :, sublora.lora_b.shape[0] * i : sublora.lora_b.shape[0] * (i + 1) + ] += input_ @ sublora.lora_a.T @ sublora.lora_b.T * sublora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) for slot_idx in range(max_loras): lora_linear.reset_lora(slot_idx) @@ -988,10 +992,9 @@ class FakeConfig: input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) punica_wrapper.update_metadata( 
lora_mapping, @@ -1005,15 +1008,13 @@ class FakeConfig: expected_result = linear(torch.cat(inputs))[0] rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) + torch.testing.assert_close(lora_result, expected_result, rtol=rtol, atol=atol) @pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) @pytest.mark.parametrize( - "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))) + "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)) +) def test_vocab_parallel_embedding_indices(tp_size, seed): random.seed(seed) vocab_size = random.randint(4000, 64000) @@ -1031,20 +1032,24 @@ def test_vocab_parallel_embedding_indices(tp_size, seed): token_ids: list[int] = [] for tp_rank in range(tp_size): - with patch( + with ( + patch( "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", - return_value=tp_rank - ), patch( + return_value=tp_rank, + ), + patch( "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", - return_value=tp_size): + return_value=tp_size, + ), + ): vocab_embedding = VocabParallelEmbedding( - vocab_size, 1, org_num_embeddings=org_vocab_size) + vocab_size, 1, org_num_embeddings=org_vocab_size + ) vocab_size_padded = vocab_embedding.num_embeddings_padded shard_indices = vocab_embedding.shard_indices # Assert that the ranges are contiguous assert shard_indices.org_vocab_start_index == last_org_vocab_end_index - assert (shard_indices.added_vocab_start_index == - last_added_vocab_end_index) + assert shard_indices.added_vocab_start_index == last_added_vocab_end_index # Ensure that we are not exceeding the vocab size computed_vocab_size += shard_indices.num_elements_padded @@ -1053,22 +1058,39 @@ def test_vocab_parallel_embedding_indices(tp_size, seed): # Ensure that the ranges are not overlapping all_org_tokens.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) + range( + shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index + ) + ) all_added_tokens.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) + range( + shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index, + ) + ) token_ids.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_org_elements_padded - - shard_indices.num_org_elements)) + range( + shard_indices.org_vocab_start_index, shard_indices.org_vocab_end_index + ) + ) + token_ids.extend( + [-1] + * (shard_indices.num_org_elements_padded - shard_indices.num_org_elements) + ) + token_ids.extend( + range( + shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index, + ) + ) token_ids.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_added_elements_padded - - shard_indices.num_added_elements)) + [-1] + * ( + shard_indices.num_added_elements_padded + - shard_indices.num_added_elements + ) + ) last_org_vocab_end_index = shard_indices.org_vocab_end_index last_added_vocab_end_index = shard_indices.added_vocab_end_index @@ -1096,130 +1118,165 @@ def test_get_masked_input_and_mask(): x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) # base tp 1 case, no padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, 
- num_org_vocab_padding=0) + modified_x, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=0, + ) assert torch.equal(x, modified_x) # tp 2 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=0) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=0, + ) modified_x_rank_1, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=8, added_vocab_start_index=10, added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])) + num_org_vocab_padding=0, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]) + ) # tp 4 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=0) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=0) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=0, + ) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=0, + ) modified_x_rank_2, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=6, added_vocab_start_index=10, added_vocab_end_index=11, - num_org_vocab_padding=0) + num_org_vocab_padding=0, + ) modified_x_rank_3, _ = get_masked_input_and_mask( x, org_vocab_start_index=6, org_vocab_end_index=8, added_vocab_start_index=11, added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])) + num_org_vocab_padding=0, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]) + ) + assert torch.equal( + modified_x_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]) + ) + assert torch.equal( + modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]) + ) # base tp 1 case, with padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x, - torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])) 
+ modified_x, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=2, + ) + assert torch.equal( + modified_x, torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]) + ) # tp 2 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=2) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=2, + ) modified_x_rank_1, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=8, added_vocab_start_index=10, added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])) + num_org_vocab_padding=2, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]) + ) # tp 4 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=2) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=2) + modified_x_rank_0, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=2, + ) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=2, + ) modified_x_rank_2, _ = get_masked_input_and_mask( x, org_vocab_start_index=4, org_vocab_end_index=6, added_vocab_start_index=10, added_vocab_end_index=11, - num_org_vocab_padding=2) + num_org_vocab_padding=2, + ) modified_x_rank_3, _ = get_masked_input_and_mask( x, org_vocab_start_index=6, org_vocab_end_index=8, added_vocab_start_index=11, added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) + num_org_vocab_padding=2, + ) + assert torch.equal( + modified_x_rank_0, torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]) + ) + assert torch.equal( + modified_x_rank_1, torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]) + ) + assert torch.equal( + modified_x_rank_2, torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]) + ) + assert torch.equal( + modified_x_rank_3, torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]) + ) diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index a6770e6d32af..0d9431bd7aae 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -15,31 +15,32 @@ EXPECTED_LORA_OUTPUT = [ " SELECT icao FROM table_name_74 WHERE 
airport = 'lilongwe international airport' ", # noqa: E501 - " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 - " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 - " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' ", # noqa: E501 ] -def do_sample(llm: vllm.LLM, - lora_path: str, - lora_id: int, - tensorizer_config_dict: Union[dict, None] = None) -> list[str]: +def do_sample( + llm: vllm.LLM, + lora_path: str, + lora_id: int, + tensorizer_config_dict: Union[dict, None] = None, +) -> list[str]: prompts = [ "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? 
[/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]", # noqa: E501 ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=256, - skip_special_tokens=False, - stop=["[/assistant]"]) + sampling_params = vllm.SamplingParams( + temperature=0, max_tokens=256, skip_special_tokens=False, stop=["[/assistant]"] + ) if tensorizer_config_dict is not None: outputs = llm.generate( @@ -49,14 +50,19 @@ def do_sample(llm: vllm.LLM, str(lora_id), lora_id, lora_path, - tensorizer_config_dict=tensorizer_config_dict) - if lora_id else None) + tensorizer_config_dict=tensorizer_config_dict, + ) + if lora_id + else None, + ) else: outputs = llm.generate( prompts, sampling_params, lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + if lora_id + else None, + ) # Print the outputs. generated_texts: list[str] = [] for output in outputs: @@ -67,42 +73,51 @@ def do_sample(llm: vllm.LLM, return generated_texts -def generate_and_test(llm, - sql_lora_files, - tensorizer_config_dict: Union[dict, None] = None): +def generate_and_test( + llm, sql_lora_files, tensorizer_config_dict: Union[dict, None] = None +): print("lora adapter created") print("lora 1") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=1) == EXPECTED_LORA_OUTPUT + assert ( + do_sample( + llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=1, + ) + == EXPECTED_LORA_OUTPUT + ) print("lora 2") - assert do_sample(llm, - sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, - lora_id=2) == EXPECTED_LORA_OUTPUT + assert ( + do_sample( + llm, + sql_lora_files, + tensorizer_config_dict=tensorizer_config_dict, + lora_id=2, + ) + == EXPECTED_LORA_OUTPUT + ) print("removing lora") @create_new_process_for_each_test() def test_llama_lora(sql_lora_files): - llm = vllm.LLM( MODEL_PATH, tokenizer=sql_lora_files, enable_lora=True, # also test odd max_num_seqs max_num_seqs=13, - max_loras=4) + max_loras=4, + ) generate_and_test(llm, sql_lora_files) @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_llama_lora_tp4(sql_lora_files): - llm = vllm.LLM( MODEL_PATH, tokenizer=sql_lora_files, @@ -117,7 +132,6 @@ def test_llama_lora_tp4(sql_lora_files): @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): - llm = vllm.LLM( MODEL_PATH, tokenizer=sql_lora_files, @@ -132,9 +146,9 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): @multi_gpu_test(num_gpus=2) @create_new_process_for_each_test() -def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, - sql_lora_huggingface_id): - +def test_tp2_serialize_and_deserialize_lora( + tmp_path, sql_lora_files, sql_lora_huggingface_id +): # Run the tensorizing of the LoRA adapter and the model in a subprocess # to guarantee cleanup @@ -145,17 +159,28 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, lora_path = sql_lora_huggingface_id suffix = "test" try: - result = subprocess.run([ - 
sys.executable, - f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", - MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size", - str(tp_size), "serialize", "--serialized-directory", - str(tmp_path), "--suffix", suffix, "--serialization-kwargs", - '{"limit_cpu_concurrency": 4}' - ], - check=True, - capture_output=True, - text=True) + result = subprocess.run( + [ + sys.executable, + f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", + "--model", + MODEL_PATH, + "--lora-path", + lora_path, + "--tensor-parallel-size", + str(tp_size), + "serialize", + "--serialized-directory", + str(tmp_path), + "--suffix", + suffix, + "--serialization-kwargs", + '{"limit_cpu_concurrency": 4}', + ], + check=True, + capture_output=True, + text=True, + ) except subprocess.CalledProcessError as e: print("Tensorizing failed.") print("STDOUT:\n", e.stdout) @@ -167,21 +192,25 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, model_uri = tmp_path / "vllm" / model_ref / suffix / model_name tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri)) - loaded_llm = LLM(model=model_ref, - tokenizer=sql_lora_files, - load_format="tensorizer", - enable_lora=True, - enforce_eager=True, - model_loader_extra_config=tensorizer_config, - max_num_seqs=13, - tensor_parallel_size=2, - max_loras=2) + loaded_llm = LLM( + model=model_ref, + tokenizer=sql_lora_files, + load_format="tensorizer", + enable_lora=True, + enforce_eager=True, + model_loader_extra_config=tensorizer_config, + max_num_seqs=13, + tensor_parallel_size=2, + max_loras=2, + ) tc_as_dict = tensorizer_config.to_serializable() print("lora adapter created") print("lora 1") - assert do_sample(loaded_llm, - sql_lora_files, - tensorizer_config_dict=tc_as_dict, - lora_id=1) == EXPECTED_LORA_OUTPUT + assert ( + do_sample( + loaded_llm, sql_lora_files, tensorizer_config_dict=tc_as_dict, lora_id=1 + ) + == EXPECTED_LORA_OUTPUT + ) diff --git a/tests/lora/test_llm_with_multi_loras.py b/tests/lora/test_llm_with_multi_loras.py index 3d8dd512a201..269a1ade7734 100644 --- a/tests/lora/test_llm_with_multi_loras.py +++ b/tests/lora/test_llm_with_multi_loras.py @@ -5,6 +5,7 @@ 1. test multi loras service with tp >= 2 2. test multi loras request """ + import pytest from tests.utils import multi_gpu_test @@ -25,20 +26,14 @@ LORA_TEST_PROMPTS = ["What is GitHub?", "Hi, tell me about you"] LORA_TEST_EXPECTED = [ "GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.", # noqa: E501 - "I am Alice, an AI assistant developed by GitHub/Charent.", # noqa: E501 + "I am Alice, an AI assistant developed by GitHub/Charent.", ] def format_chatml_messages(prompt: str): return [ - { - "role": "system", - "content": "You are a helpful assistant." 
- }, - { - "role": "user", - "content": prompt - }, + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, ] @@ -57,7 +52,6 @@ def make_add_lora_request(name: str, path: str): @multi_gpu_test(num_gpus=2) def test_multi_loras_with_tp_sync(): - llm = LLM( model=MODEL_PATH, enable_lora=True, @@ -116,15 +110,17 @@ def call_llm_get_outputs(prompt: str, lora_name: str): def reload_lora(name: str): """ - reload a lora to simulate the case: - setting `VLLM_ALLOW_RUNTIME_LORA_UPDATING=true` + reload a lora to simulate the case: + setting `VLLM_ALLOW_RUNTIME_LORA_UPDATING=true` for dynamic lora loading and unloading """ remove_lora_response = llm.llm_engine.remove_lora( - lora_id=LORA_NAME_ID_MAP[name]) + lora_id=LORA_NAME_ID_MAP[name] + ) add_lora_response = llm.llm_engine.add_lora( - make_add_lora_request(name, LORA_NAME_PATH_MAP[name])) + make_add_lora_request(name, LORA_NAME_PATH_MAP[name]) + ) print(f"{remove_lora_response=}, {add_lora_response=}") @@ -134,7 +130,6 @@ def check_outputs(outputs: str, expected: str): assert outputs == expected for prompt, expected_output in zip(LORA_TEST_PROMPTS, LORA_TEST_EXPECTED): - output_text = call_llm_get_outputs(prompt, "Alice") check_outputs(output_text, expected_output) @@ -175,8 +170,7 @@ def test_multiple_lora_requests(): PROMPTS = ["Hello, my name is"] * 2 LORA_NAME = "Alice" lora_request = [ - LoRARequest(LORA_NAME + str(idx), idx + 1, - LORA_NAME_PATH_MAP[LORA_NAME]) + LoRARequest(LORA_NAME + str(idx), idx + 1, LORA_NAME_PATH_MAP[LORA_NAME]) for idx in range(len(PROMPTS)) ] # Multiple SamplingParams should be matched with each prompt diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index ebc0f26378d2..2219d470e91a 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -8,9 +8,7 @@ from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM from vllm.model_executor.models.utils import WeightsMapper -lora_lst = [ - "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b" -] +lora_lst = ["baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"] BAICHUAN_LORA_MODULES = [ "W_pack", "o_proj", @@ -37,8 +35,9 @@ def test_load_checkpoints( else: expected_lora_modules.append(module) if lora_name == "baichuan7B": - peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_lora_files, max_position_embeddings=4096 + ) # For the baichuan7B model, load it's LoRA, # and the test should pass. LoRAModel.from_local_checkpoint( @@ -48,13 +47,15 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) elif lora_name == "baichuan7B-zero": # Test that the target_modules contain prefix # such as "model.layers.0.self_atten.W_pack", and # the test should pass. 
- peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_zero_lora_files, max_position_embeddings=4096 + ) LoRAModel.from_local_checkpoint( baichuan_zero_lora_files, expected_lora_modules, @@ -62,12 +63,14 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) elif lora_name == "baichuan7B-zero-regex": # Test that the `target_modules` in the form of regular expressions, # such as `model\\..*(W_pack|o_proj)`, and the test should pass. - peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_regex_lora_files, max_position_embeddings=4096 + ) LoRAModel.from_local_checkpoint( baichuan_regex_lora_files, expected_lora_modules, @@ -75,13 +78,15 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) else: # For the baichuan7B model, load chatglm3-6b's LoRA, # and the test should raise the following error. expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501 - peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + chatglm3_lora_files, max_position_embeddings=4096 + ) with pytest.raises(ValueError, match=expected_error): LoRAModel.from_local_checkpoint( chatglm3_lora_files, @@ -90,11 +95,11 @@ def test_load_checkpoints( lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) def test_lora_weights_mapping(baichuan_lora_files): - packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules @@ -113,8 +118,9 @@ def test_lora_weights_mapping(baichuan_lora_files): ".layers.": ".baichuan_layers.", }, ) - peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + baichuan_lora_files, max_position_embeddings=4096 + ) lora_model = LoRAModel.from_local_checkpoint( baichuan_lora_files, expected_lora_modules, diff --git a/tests/lora/test_lora_functions.py b/tests/lora/test_lora_functions.py index 221d5237823c..e914393fee8a 100644 --- a/tests/lora/test_lora_functions.py +++ b/tests/lora/test_lora_functions.py @@ -3,11 +3,13 @@ """ Script to test add_lora, remove_lora, pin_lora, list_loras functions. """ + import pytest from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, +) from vllm.lora.request import LoRARequest from vllm.v1.engine.llm_engine import LLMEngine @@ -17,23 +19,24 @@ def make_lora_request(lora_id: int): - return LoRARequest(lora_name=f"{lora_id}", - lora_int_id=lora_id, - lora_path=LORA_MODULE_PATH) + return LoRARequest( + lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=LORA_MODULE_PATH + ) def test_lora_functions_sync(): - max_loras = 4 # Create engine in eager-mode. 
Due to high max_loras, the CI can # OOM during cuda-graph capture. - engine_args = EngineArgs(model=MODEL_PATH, - enable_lora=True, - max_loras=max_loras, - max_lora_rank=LORA_RANK, - max_model_len=128, - gpu_memory_utilization=0.8, - enforce_eager=True) + engine_args = EngineArgs( + model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) llm = LLMEngine.from_engine_args(engine_args) @@ -70,15 +73,16 @@ def run_check(fn, args, expected: list): @pytest.mark.asyncio async def test_lora_functions_async(): - max_loras = 4 - engine_args = AsyncEngineArgs(model=MODEL_PATH, - enable_lora=True, - max_loras=max_loras, - max_lora_rank=LORA_RANK, - max_model_len=128, - gpu_memory_utilization=0.8, - enforce_eager=True) + engine_args = AsyncEngineArgs( + model=MODEL_PATH, + enable_lora=True, + max_loras=max_loras, + max_lora_rank=LORA_RANK, + max_model_len=128, + gpu_memory_utilization=0.8, + enforce_eager=True, + ) async def run_check(fn, args, expected: list): await fn(args) diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index b46d81f1651a..7d20faef541a 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -11,8 +11,12 @@ # Provide absolute path and huggingface lora ids lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"] LLAMA_LORA_MODULES = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", - "lm_head" + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", ] @@ -40,7 +44,8 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request): lora_model_id=1, device="cpu", embedding_modules=embedding_modules, - embedding_padding_modules=embed_padding_modules) + embedding_padding_modules=embed_padding_modules, + ) # Assertions to ensure the model is loaded correctly assert lora_model is not None, "LoRAModel is not loaded correctly" diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 6f0a85231408..e7816031142e 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -10,16 +10,21 @@ from vllm.config import ModelConfig, VllmConfig from vllm.config.lora import LoRAConfig -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, - RowParallelLinearWithLoRA) +from vllm.lora.layers import ( + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + RowParallelLinearWithLoRA, +) from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager, - LRUCacheLoRAModelManager) +from vllm.lora.models import ( + LoRAMapping, + LoRAModel, + LoRAModelManager, + LRUCacheLoRAModelManager, +) from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, - WorkerLoRAManager) +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager from vllm.platforms import current_platform from .utils import create_peft_lora @@ -31,22 +36,25 @@ EMBEDDING_PADDING_MODULES = ["lm_head"] -DEVICES = ([ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] if current_platform.is_cuda_alike() else ["cpu"]) +DEVICES = ( + [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] + if current_platform.is_cuda_alike() + else ["cpu"] +) 
DEFAULT_DTYPE = torch.get_default_dtype() @pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): - tensors = load_file( - os.path.join(sql_lora_files, "adapter_model.safetensors")) + tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors")) new_embeddings = load_file( - os.path.join(sql_lora_files, "new_embeddings.safetensors")) + os.path.join(sql_lora_files, "new_embeddings.safetensors") + ) - peft_helper = PEFTHelper.from_local_dir(sql_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + sql_lora_files, max_position_embeddings=4096 + ) lora_model = LoRAModel.from_lora_tensors( 1, tensors, @@ -54,7 +62,8 @@ def test_from_lora_tensors(sql_lora_files, device): device=device, embeddings=new_embeddings, embedding_modules=EMBEDDING_MODULES, - embedding_padding_modules=EMBEDDING_PADDING_MODULES) + embedding_padding_modules=EMBEDDING_PADDING_MODULES, + ) for module_name, lora in lora_model.loras.items(): assert lora.module_name == module_name assert lora.rank == 8 @@ -63,22 +72,27 @@ def test_from_lora_tensors(sql_lora_files, device): assert lora.lora_b is not None assert lora.lora_a.device == torch.device(device) assert lora.lora_b.device == torch.device(device) - assert (lora.lora_a.shape[0] == lora.lora_b.shape[1] - ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + assert lora.lora_a.shape[0] == lora.lora_b.shape[1], ( + f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + ) assert lora.lora_a.shape[0] == 8 embeddings_module = next( - (k for k in EMBEDDING_MODULES if k in module_name), None) + (k for k in EMBEDDING_MODULES if k in module_name), None + ) if embeddings_module: assert torch.equal( lora.embeddings_tensor, new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( - device=lora.embeddings_tensor.device)) + device=lora.embeddings_tensor.device + ), + ) else: assert lora.embeddings_tensor is None -def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str], - device: torch.device) -> LoRAModel: +def create_lora( + lora_id: int, model: nn.Module, sub_modules: list[str], device: torch.device +) -> LoRAModel: loras: dict[str, LoRALayerWeights] = {} for name in sub_modules: w = model.get_submodule(name).weight @@ -110,8 +124,7 @@ def create_packed_lora( 8, 16, torch.rand([8, w.shape[1]], device=device), - torch.rand([w.shape[0] // len(replaced_module_names), 8], - device=device), + torch.rand([w.shape[0] // len(replaced_module_names), 8], device=device), ) return LoRAModel(lora_id, 8, loras) @@ -119,42 +132,42 @@ def create_packed_lora( def test_replace_submodules(dist_init, dummy_model): model = dummy_model manager = LoRAModelManager( - model, 1, 1, 1, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=8, - max_loras=8, - lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0])) + model, + 1, + 1, + 1, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE + ), + torch.device(DEVICES[0]), + ) model = manager.model - assert isinstance(model.get_submodule("dense1"), - ColumnParallelLinearWithLoRA) - assert isinstance(model.get_submodule("layer1.dense1"), - ColumnParallelLinearWithLoRA) + assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA) + assert isinstance( + model.get_submodule("layer1.dense1"), ColumnParallelLinearWithLoRA + ) assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA) - assert isinstance(model.get_submodule("layer1.dense2"), - RowParallelLinearWithLoRA) + assert 
isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA) @pytest.mark.parametrize("device", DEVICES) def test_lora_model_manager(dist_init, dummy_model, device): model = dummy_model - model_lora1 = create_lora(1, - model, ["layer1.dense1", "dense2", "lm_head"], - device=device) - model_lora2 = create_lora(2, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora3 = create_lora(3, - model, ["dense1", "dense2", "lm_head"], - device=device) - manager = LoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=3, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + model_lora1 = create_lora( + 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device + ) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) assert manager.activate_adapter(1) @@ -204,24 +217,21 @@ def test_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): model = dummy_model - model_lora1 = create_lora(1, - model, ["layer1.dense1", "dense2", "lm_head"], - device=device) - model_lora2 = create_lora(2, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora3 = create_lora(3, - model, ["dense1", "dense2", "lm_head"], - device=device) - manager = LRUCacheLoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=3, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + model_lora1 = create_lora( + 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device + ) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) assert manager.activate_adapter(1) @@ -297,27 +307,22 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager model = dummy_model - model_lora1 = create_lora(1, - model, ["layer1.dense1", "dense2", "lm_head"], - device=device) - model_lora2 = create_lora(2, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora3 = create_lora(3, - model, ["dense1", "dense2", "lm_head"], - device=device) - model_lora4 = create_lora(4, - model, ["dense1", "dense2", "lm_head"], - device=device) - manager = LRUCacheLoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=2, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + model_lora1 = create_lora( + 1, model, ["layer1.dense1", "dense2", "lm_head"], device=device + ) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device) + model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"], device=device) + 
manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity @@ -421,12 +426,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): @pytest.mark.parametrize("device", DEVICES) -def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, - tmp_path): - lora_config = LoRAConfig(max_lora_rank=8, - max_cpu_loras=4, - max_loras=4, - lora_dtype=DEFAULT_DTYPE) +def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_path): + lora_config = LoRAConfig( + max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE + ) dummy_lora_files = f"{tmp_path}/lora_adapter" os.makedirs(dummy_lora_files, exist_ok=True) @@ -438,13 +441,13 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, ) model_config = ModelConfig(max_model_len=16) - vllm_config = VllmConfig(model_config=model_config, - lora_config=lora_config) + vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config) vllm_config.scheduler_config.max_num_seqs = 4 vllm_config.scheduler_config.max_num_batched_tokens = 2 worker_adapter_manager = LRUCacheWorkerLoRAManager( - vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) + vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES + ) worker_adapter_manager.max_num_seqs = 4 worker_adapter_manager.max_num_batched_tokens = 2 @@ -452,52 +455,64 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, worker_adapter_manager.create_lora_manager(dummy_model) mapping = LoRAMapping([], []) - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("2", 2, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("3", 3, dummy_lora_files), - LoRARequest("4", 4, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3 assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("2", 2, dummy_lora_files), - LoRARequest("5", 5, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert 
worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("1", 1, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, dummy_lora_files), - LoRARequest("7", 7, dummy_lora_files), - LoRARequest("8", 8, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7 @@ -506,41 +521,40 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, # Over capacity with pytest.raises(RuntimeError): - worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, dummy_lora_files), - LoRARequest("11", 11, dummy_lora_files), - LoRARequest("12", 12, dummy_lora_files), - LoRARequest("13", 13, dummy_lora_files), - LoRARequest("14", 14, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.device == device - assert (worker_adapter_manager._adapter_manager.punica_wrapper.device == - device) + assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device @pytest.mark.parametrize("device", DEVICES) -def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, - tmp_path): +def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path): # Should remove every LoRA not specified in the request. 
- lora_config = LoRAConfig(max_lora_rank=8, - max_cpu_loras=4, - max_loras=4, - lora_dtype=DEFAULT_DTYPE) + lora_config = LoRAConfig( + max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE + ) model_config = ModelConfig(max_model_len=16) - vllm_config = VllmConfig(model_config=model_config, - lora_config=lora_config) + vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config) vllm_config.scheduler_config.max_num_seqs = 4 vllm_config.scheduler_config.max_num_batched_tokens = 2 - worker_adapter_manager = WorkerLoRAManager(vllm_config, device, - EMBEDDING_MODULES, - EMBEDDING_PADDING_MODULES) + worker_adapter_manager = WorkerLoRAManager( + vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES + ) worker_adapter_manager.vocab_size = ( - dummy_model_gate_up.unpadded_vocab_size - - lora_config.lora_extra_vocab_size) + dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size + ) worker_adapter_manager.create_lora_manager(dummy_model_gate_up) dummy_lora_files = f"{tmp_path}/lora_adapter" @@ -553,49 +567,61 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, ) mapping = LoRAMapping([], []) - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("2", 2, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("3", 3, dummy_lora_files), - LoRARequest("4", 4, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("3", 3, dummy_lora_files), + LoRARequest("4", 4, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 3, 4} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("2", 2, dummy_lora_files), - LoRARequest("5", 5, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("2", 2, dummy_lora_files), + LoRARequest("5", 5, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1, 2, 5} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 - worker_adapter_manager.set_active_adapters([ - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("1", 1, dummy_lora_files), - LoRARequest("1", 1, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + LoRARequest("1", 1, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {1} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None 
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None - worker_adapter_manager.set_active_adapters([ - LoRARequest("6", 6, dummy_lora_files), - LoRARequest("7", 7, dummy_lora_files), - LoRARequest("8", 8, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("6", 6, dummy_lora_files), + LoRARequest("7", 7, dummy_lora_files), + LoRARequest("8", 8, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.list_adapters() == {6, 7, 8} assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6 @@ -603,17 +629,19 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, # Over capacity with pytest.raises(RuntimeError): - worker_adapter_manager.set_active_adapters([ - LoRARequest("10", 10, dummy_lora_files), - LoRARequest("11", 11, dummy_lora_files), - LoRARequest("12", 12, dummy_lora_files), - LoRARequest("13", 13, dummy_lora_files), - LoRARequest("14", 14, dummy_lora_files) - ], mapping) + worker_adapter_manager.set_active_adapters( + [ + LoRARequest("10", 10, dummy_lora_files), + LoRARequest("11", 11, dummy_lora_files), + LoRARequest("12", 12, dummy_lora_files), + LoRARequest("13", 13, dummy_lora_files), + LoRARequest("14", 14, dummy_lora_files), + ], + mapping, + ) assert worker_adapter_manager.device == device - assert (worker_adapter_manager._adapter_manager.punica_wrapper.device == - device) + assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device @pytest.mark.parametrize("device", DEVICES) @@ -624,7 +652,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): model, module_name="gate_up_proj", replaced_module_names=["gate_proj", "up_proj"], - device=device) + device=device, + ) model_lora1 = create_packed_lora( 2, model, @@ -634,19 +663,21 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): empty_replaced_module_name="gate_proj", ) - manager = LoRAModelManager(model, - 2, - 2, - 2, - LoRAConfig(max_lora_rank=8, - max_cpu_loras=2, - max_loras=2, - lora_dtype=DEFAULT_DTYPE), - device=device) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE + ), + device=device, + ) model = manager.model - assert isinstance(model.get_submodule("gate_up_proj"), - MergedColumnParallelLinearWithLoRA) + assert isinstance( + model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA + ) # Verify packed lora is correct model_lora_clone = model_lora.clone(1) model_lora_clone1 = model_lora1.clone(1) @@ -659,21 +690,27 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device): packed_lora = model_lora.get_lora("gate_up_proj") assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) - torch.testing.assert_close(packed_lora.lora_a[0], - model_lora_clone.get_lora("gate_proj").lora_a) - torch.testing.assert_close(packed_lora.lora_b[0], - model_lora_clone.get_lora("gate_proj").lora_b) - torch.testing.assert_close(packed_lora.lora_a[1], - model_lora_clone.get_lora("up_proj").lora_a) - torch.testing.assert_close(packed_lora.lora_b[1], - model_lora_clone.get_lora("up_proj").lora_b) + torch.testing.assert_close( + packed_lora.lora_a[0], model_lora_clone.get_lora("gate_proj").lora_a + ) + torch.testing.assert_close( + packed_lora.lora_b[0], model_lora_clone.get_lora("gate_proj").lora_b + ) + torch.testing.assert_close( + packed_lora.lora_a[1], 
model_lora_clone.get_lora("up_proj").lora_a + ) + torch.testing.assert_close( + packed_lora.lora_b[1], model_lora_clone.get_lora("up_proj").lora_b + ) packed_lora1 = model_lora1.get_lora("gate_up_proj") assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights) assert packed_lora1.lora_a[0] is None assert packed_lora1.lora_b[0] is None - torch.testing.assert_close(packed_lora1.lora_a[1], - model_lora_clone1.get_lora("up_proj").lora_a) - torch.testing.assert_close(packed_lora1.lora_b[1], - model_lora_clone1.get_lora("up_proj").lora_b) + torch.testing.assert_close( + packed_lora1.lora_a[1], model_lora_clone1.get_lora("up_proj").lora_a + ) + torch.testing.assert_close( + packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b + ) diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 99fe951bbf07..ce98fe2f8613 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -15,7 +15,8 @@ PROMPT_TEMPLATE = ( "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" "(<image>./</image>)\nWhat is in the image?<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n") + "<|start_header_id|>assistant<|end_header_id|>\n\n" +) IMAGE_ASSETS = [ ImageAsset("stop_sign"), @@ -34,18 +35,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: stop_token_ids=[128001, 128009], # eos_id, eot_id ) - inputs = [{ - "prompt": PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in IMAGE_ASSETS] + inputs = [ + { + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in IMAGE_ASSETS + ] outputs = llm.generate( inputs, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, ) # Print the outputs. 
generated_texts: list[str] = [] @@ -58,7 +59,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: @pytest.mark.xfail( current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") + reason="MiniCPM-V dependency xformers incompatible with ROCm", +) def test_minicpmv_lora(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, @@ -68,10 +70,7 @@ def test_minicpmv_lora(minicpmv_lora_files): max_lora_rank=8, enforce_eager=True, max_model_len=2048, - limit_mm_per_prompt={ - "image": 2, - "video": 0 - }, + limit_mm_per_prompt={"image": 2, "video": 0}, trust_remote_code=True, ) output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) @@ -82,11 +81,13 @@ def test_minicpmv_lora(minicpmv_lora_files): assert EXPECTED_OUTPUT[i].startswith(output2[i]) -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @pytest.mark.xfail( current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") + reason="MiniCPM-V dependency xformers incompatible with ROCm", +) @create_new_process_for_each_test() def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( @@ -96,10 +97,7 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): max_loras=4, max_lora_rank=64, tensor_parallel_size=4, - limit_mm_per_prompt={ - "image": 2, - "video": 0 - }, + limit_mm_per_prompt={"image": 2, "video": 0}, trust_remote_code=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) @@ -107,11 +105,13 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @pytest.mark.xfail( current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") + reason="MiniCPM-V dependency xformers incompatible with ROCm", +) @create_new_process_for_each_test() def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( @@ -122,10 +122,7 @@ def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): max_lora_rank=8, tensor_parallel_size=4, trust_remote_code=True, - limit_mm_per_prompt={ - "image": 1, - "video": 0 - }, + limit_mm_per_prompt={"image": 1, "video": 0}, fully_sharded_loras=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 03e5d8d5d672..868ca51b3331 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -11,15 +11,15 @@ MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1" -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, - prompts: list[str]) -> list[str]: - +def do_sample( + llm: vllm.LLM, lora_path: str, lora_id: int, prompts: list[str] +) -> list[str]: sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -33,8 +33,11 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, @pytest.mark.parametrize("tp_size", [4]) def test_mixtral_lora(mixtral_lora_files, tp_size): """Original test, the LoRA model has the common target modules, not all""" - if torch.cuda.device_count( - ) < tp_size and tp_size > 1 and current_platform.is_cuda_alike(): + if ( + torch.cuda.device_count() < tp_size + and tp_size > 1 + and current_platform.is_cuda_alike() + ): pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") prompts = [ @@ -57,7 +60,11 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): "give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])", # noqa: E501 "inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", # noqa: E501 ] - assert do_sample(llm, mixtral_lora_files, lora_id=1, - prompts=prompts) == expected_lora_output - assert do_sample(llm, mixtral_lora_files, lora_id=2, - prompts=prompts) == expected_lora_output + assert ( + do_sample(llm, mixtral_lora_files, lora_id=1, prompts=prompts) + == expected_lora_output + ) + assert ( + do_sample(llm, mixtral_lora_files, lora_id=2, prompts=prompts) + == expected_lora_output + ) diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py index ffffb5d8eab9..2cc8bfe63495 100644 --- a/tests/lora/test_peft_helper.py +++ b/tests/lora/test_peft_helper.py @@ -13,34 +13,27 @@ ERROR_CASES = [ ( "test_rank", - { - "r": 1024 - }, + {"r": 1024}, "is greater than max_lora_rank", ), ( "test_bias", - { - "bias": "all" - }, + {"bias": "all"}, "Adapter bias cannot be used without bias_enabled", ), - ("test_dora", { - "use_dora": True - }, "does not yet support DoRA"), + ("test_dora", {"use_dora": True}, "does not yet support DoRA"), ( "test_modules_to_save", - { - "modules_to_save": ["lm_head"] - }, + {"modules_to_save": ["lm_head"]}, "only supports modules_to_save being None", ), ] def test_peft_helper_pass(sql_lora_files, tmp_path): - peft_helper = PEFTHelper.from_local_dir(sql_lora_files, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir( + sql_lora_files, max_position_embeddings=4096 + ) lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2) peft_helper.validate_legal(lora_config) assert peft_helper.r == 8 @@ -74,8 +67,7 @@ def test_peft_helper_pass(sql_lora_files, tmp_path): with open(config_path, "w") as f: json.dump(adapter_config, f) - peft_helper = PEFTHelper.from_local_dir(test_dir, - max_position_embeddings=4096) + peft_helper = PEFTHelper.from_local_dir(test_dir, max_position_embeddings=4096) peft_helper.validate_legal(lora_config) scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r) assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3 @@ -106,4 +98,5 @@ def test_peft_helper_error( # Test loading the adapter with pytest.raises(ValueError, match=expected_error): PEFTHelper.from_local_dir( - test_dir, max_position_embeddings=4096).validate_legal(lora_config) + test_dir, max_position_embeddings=4096 + ).validate_legal(lora_config) diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py index 14fa79ae5b44..e4df9751077d 100644 --- a/tests/lora/test_punica_ops.py +++ b/tests/lora/test_punica_ops.py @@ -21,11 +21,18 @@ def reset_device(reset_default_device): # Utility shrink and expand operations used as 
reference implementations. def sgmv_shrink_for_nslices( - nslices: int, inputs_tensor: torch.Tensor, - lora_weights_lst: list[torch.Tensor], out_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, - prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int, - num_tokens: int, scaling: float): + nslices: int, + inputs_tensor: torch.Tensor, + lora_weights_lst: list[torch.Tensor], + out_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + prompt_lora_mapping: torch.Tensor, + batches: int, + max_seq_length: int, + num_tokens: int, + scaling: float, +): """ Wrapper around torch_ops.sgmv_shrink that handles any nslices. """ @@ -44,15 +51,20 @@ def sgmv_shrink_for_nslices( ) -def sgmv_expand_for_nslices(nslices: int, hidden_size: int, - inputs_tensor: torch.Tensor, - lora_weights_lst: list[torch.Tensor], - out_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - prompt_lora_mapping: torch.Tensor, batches: int, - max_seq_length: int, num_tokens: int, - add_inputs: bool) -> None: +def sgmv_expand_for_nslices( + nslices: int, + hidden_size: int, + inputs_tensor: torch.Tensor, + lora_weights_lst: list[torch.Tensor], + out_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + prompt_lora_mapping: torch.Tensor, + batches: int, + max_seq_length: int, + num_tokens: int, + add_inputs: bool, +) -> None: """ Wrapper around torch_ops.sgmv_expand that handles any nslices. """ @@ -94,10 +106,17 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int, _dict_lock = Lock() -def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, - dtype: torch.dtype, device: str, seq_length: int, - scaling: float): +def check_lora_shrink_kernel( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seq_length: int, + scaling: float, +): """ Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink kernels. @@ -116,14 +135,19 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int, max_seq_length, token_nums = data.meta() # Setup metadata information for SGMV and reference kernels - sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor, - data.prompt_lora_mapping, batches, max_seq_length, - token_nums) + sgmv_meta_args = ( + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + ) # Setup metadata information for the LoRA kernel. - lora_meta = LoRAKernelMeta.make(max_loras=num_loras, - max_num_tokens=token_nums, - device='cuda') + lora_meta = LoRAKernelMeta.make( + max_loras=num_loras, max_num_tokens=token_nums, device="cuda" + ) lora_meta.prepare_tensors(data.token_lora_mapping) ref_out_tensor = data.ref_out_tensor @@ -154,10 +178,17 @@ def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int, assert_close(out_tensor, ref_out_tensor) -def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, - dtype: torch.dtype, device: str, seq_length: int, - add_inputs: bool): +def check_lora_expand_kernel( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + device: str, + seq_length: int, + add_inputs: bool, +): """ Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand kernels. 
@@ -177,14 +208,19 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, max_seq_length, token_nums = data.meta() # Setup metadata information for SGMV and reference kernels - sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor, - data.prompt_lora_mapping, batches, max_seq_length, - token_nums) + sgmv_meta_args = ( + data.b_seq_start_loc, + data.seq_len_tensor, + data.prompt_lora_mapping, + batches, + max_seq_length, + token_nums, + ) # Setup metadata information for the LoRA kernel. - lora_meta = LoRAKernelMeta.make(max_loras=num_loras, - max_num_tokens=token_nums, - device='cuda') + lora_meta = LoRAKernelMeta.make( + max_loras=num_loras, max_num_tokens=token_nums, device="cuda" + ) lora_meta.prepare_tensors(data.token_lora_mapping) # Setup output tensors @@ -194,21 +230,25 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, with _dict_lock: # lora_expand kernel _LORA_B_PTR_DICT.clear() - triton_ops.lora_expand(data.inputs_tensor, - data.lora_weights, - out_tensor, - *lora_meta.meta_args(token_nums=token_nums), - offset_start=0, - add_inputs=add_inputs) + triton_ops.lora_expand( + data.inputs_tensor, + data.lora_weights, + out_tensor, + *lora_meta.meta_args(token_nums=token_nums), + offset_start=0, + add_inputs=add_inputs, + ) # Reference - sgmv_expand_for_nslices(nslices, - hidden_size, - data.inputs_tensor, - data.lora_weights, - ref_out_tensor, - *sgmv_meta_args, - add_inputs=add_inputs) + sgmv_expand_for_nslices( + nslices, + hidden_size, + data.inputs_tensor, + data.lora_weights, + ref_out_tensor, + *sgmv_meta_args, + add_inputs=add_inputs, + ) assert_close(out_tensor, ref_out_tensor) @@ -299,7 +339,7 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, 128000, 128256, ] -#The size of TP +# The size of TP divisibility = [1, 2, 8, 16, 64] all_hidden_size = [] @@ -331,10 +371,10 @@ def check_lora_expand_kernel(batches: int, num_loras: int, rank: int, SEED = [0] -@pytest.mark.parametrize("batches", test_params['batches']) -@pytest.mark.parametrize("num_loras", test_params['num_loras']) -@pytest.mark.parametrize("rank", test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes']) +@pytest.mark.parametrize("batches", test_params["batches"]) +@pytest.mark.parametrize("num_loras", test_params["num_loras"]) +@pytest.mark.parametrize("rank", test_params["max_ranks"]) +@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"]) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", DEVICES) @@ -358,31 +398,35 @@ def test_kernels( current_platform.seed_everything(seed) if op_type == "shrink": - check_lora_shrink_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) + check_lora_shrink_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5, + ) else: - check_lora_expand_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) - - -@pytest.mark.parametrize("batches", hs_test_params['batches']) -@pytest.mark.parametrize("num_loras", hs_test_params['num_loras']) -@pytest.mark.parametrize("rank", hs_test_params['max_ranks']) -@pytest.mark.parametrize("hidden_size", 
hs_test_params['hidden_sizes']) + check_lora_expand_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True, + ) + + +@pytest.mark.parametrize("batches", hs_test_params["batches"]) +@pytest.mark.parametrize("num_loras", hs_test_params["num_loras"]) +@pytest.mark.parametrize("rank", hs_test_params["max_ranks"]) +@pytest.mark.parametrize("hidden_size", hs_test_params["hidden_sizes"]) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", DEVICES) @@ -406,22 +450,26 @@ def test_kernels_hidden_size( current_platform.seed_everything(seed) if op_type == "shrink": - check_lora_shrink_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) + check_lora_shrink_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5, + ) else: - check_lora_expand_kernel(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) + check_lora_expand_kernel( + batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True, + ) diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 2b54b2edd6a9..06e1b22ab56e 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -20,28 +20,27 @@ class ModelWithQuantization: MODELS: list[ModelWithQuantization] -#AWQ quantization is currently not supported in ROCm. +# AWQ quantization is currently not supported in ROCm. if current_platform.is_rocm(): MODELS = [ ModelWithQuantization( - model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - quantization="gptq"), + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", quantization="gptq" + ), ] else: MODELS = [ ModelWithQuantization( - model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", - quantization="awq"), + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", quantization="awq" + ), ModelWithQuantization( - model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - quantization="gptq"), + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", quantization="gptq" + ), ] -def do_sample(llm: vllm.LLM, - lora_path: str, - lora_id: int, - max_tokens: int = 256) -> list[str]: +def do_sample( + llm: vllm.LLM, lora_path: str, lora_id: int, max_tokens: int = 256 +) -> list[str]: raw_prompts = [ "Give me an orange-ish brown color", "Give me a neon pink color", @@ -52,14 +51,14 @@ def format_prompt_tuples(prompt): prompts = [format_prompt_tuples(p) for p in raw_prompts] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=max_tokens, - stop=["<|im_end|>"]) + sampling_params = vllm.SamplingParams( + temperature=0, max_tokens=max_tokens, stop=["<|im_end|>"] + ) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. 
generated_texts: list[str] = [] for output in outputs: @@ -72,18 +71,18 @@ def format_prompt_tuples(prompt): @pytest.mark.parametrize("model", MODELS) def test_quant_model_lora(tinyllama_lora_files, model): - llm = vllm.LLM( model=model.model_path, enable_lora=True, max_num_seqs=16, max_loras=4, max_model_len=400, - gpu_memory_utilization=0.2, #avoid OOM + gpu_memory_utilization=0.2, # avoid OOM quantization=model.quantization, trust_remote_code=True, enable_chunked_prefill=True, - tokenizer=tinyllama_lora_files) + tokenizer=tinyllama_lora_files, + ) if model.quantization is None: expected_lora_output = [ @@ -104,11 +103,11 @@ def test_quant_model_lora(tinyllama_lora_files, model): def expect_match(output, expected_output): # HACK: GPTQ lora outputs are just incredibly unstable. # Assert that the outputs changed. - if (model.quantization == "gptq" - and expected_output is expected_lora_output): + if model.quantization == "gptq" and expected_output is expected_lora_output: for i, o in enumerate(output): - assert o.startswith( - '#'), f"Expected example {i} to start with # but got {o}" + assert o.startswith("#"), ( + f"Expected example {i} to start with # but got {o}" + ) return assert output == expected_output @@ -116,17 +115,11 @@ def expect_match(output, expected_output): print("lora adapter created") print("lora 1") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=1, - max_tokens=max_tokens) + output = do_sample(llm, tinyllama_lora_files, lora_id=1, max_tokens=max_tokens) expect_match(output, expected_lora_output) print("lora 2") - output = do_sample(llm, - tinyllama_lora_files, - lora_id=2, - max_tokens=max_tokens) + output = do_sample(llm, tinyllama_lora_files, lora_id=2, max_tokens=max_tokens) expect_match(output, expected_lora_output) print("removing lora") @@ -136,8 +129,7 @@ def expect_match(output, expected_output): @pytest.mark.parametrize("model", MODELS) -def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, - model): +def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, model): if num_gpus_available < 2: pytest.skip(f"Not enough GPUs for tensor parallelism {2}") if model.quantization == "gptq": @@ -147,10 +139,11 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, enable_lora=True, max_num_seqs=16, max_loras=4, - gpu_memory_utilization=0.2, #avoid OOM + gpu_memory_utilization=0.2, # avoid OOM quantization=model.quantization, trust_remote_code=True, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + ) output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) del llm_tp1 @@ -162,9 +155,10 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, max_num_seqs=16, max_loras=4, tensor_parallel_size=2, - gpu_memory_utilization=0.2, #avoid OOM + gpu_memory_utilization=0.2, # avoid OOM quantization=model.quantization, - enable_chunked_prefill=True) + enable_chunked_prefill=True, + ) output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) del llm_tp2 diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py index 76f3bc0ebf89..894263bd0ba3 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -37,7 +37,8 @@ class Qwen2VLTester: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" "What is in the image?<|im_end|>\n" - "<|im_start|>assistant\n") + "<|im_start|>assistant\n" + ) def __init__(self, config: TestConfig): self.config = config @@ -56,68 +57,68 @@ 
def _initialize_llm(self) -> vllm.LLM: max_model_len=self.config.max_model_len, ) - def run_test(self, - images: list[ImageAsset], - expected_outputs: list[str], - lora_id: Optional[int] = None, - temperature: float = 0, - max_tokens: int = 5): - + def run_test( + self, + images: list[ImageAsset], + expected_outputs: list[str], + lora_id: Optional[int] = None, + temperature: float = 0, + max_tokens: int = 5, + ): sampling_params = vllm.SamplingParams( temperature=temperature, max_tokens=max_tokens, ) - inputs = [{ - "prompt": self.PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in images] - - lora_request = LoRARequest(str(lora_id), lora_id, - self.config.lora_path) - outputs = self.llm.generate(inputs, - sampling_params, - lora_request=lora_request) - generated_texts = [ - output.outputs[0].text.strip() for output in outputs + inputs = [ + { + "prompt": self.PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in images ] + lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path) + outputs = self.llm.generate(inputs, sampling_params, lora_request=lora_request) + generated_texts = [output.outputs[0].text.strip() for output in outputs] + # Validate outputs for generated, expected in zip(generated_texts, expected_outputs): - assert expected.startswith( - generated), f"Generated text {generated} doesn't " + assert expected.startswith(generated), ( + f"Generated text {generated} doesn't " + ) f"match expected pattern {expected}" - def run_beam_search_test(self, - images: list[ImageAsset], - expected_outputs: list[list[str]], - lora_id: Optional[int] = None, - temperature: float = 0, - beam_width: int = 2, - max_tokens: int = 5): - - beam_search_params = BeamSearchParams(beam_width=beam_width, - max_tokens=max_tokens, - temperature=temperature) - - inputs = [{ - "prompt": self.PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in images] - - lora_request = LoRARequest(str(lora_id), lora_id, - self.config.lora_path) - outputs = self.llm.beam_search(inputs, - beam_search_params, - lora_request=lora_request) + def run_beam_search_test( + self, + images: list[ImageAsset], + expected_outputs: list[list[str]], + lora_id: Optional[int] = None, + temperature: float = 0, + beam_width: int = 2, + max_tokens: int = 5, + ): + beam_search_params = BeamSearchParams( + beam_width=beam_width, max_tokens=max_tokens, temperature=temperature + ) + + inputs = [ + { + "prompt": self.PROMPT_TEMPLATE, + "multi_modal_data": {"image": asset.pil_image}, + } + for asset in images + ] + + lora_request = LoRARequest(str(lora_id), lora_id, self.config.lora_path) + outputs = self.llm.beam_search( + inputs, beam_search_params, lora_request=lora_request + ) for output_obj, expected_outs in zip(outputs, expected_outputs): output_texts = [seq.text for seq in output_obj.sequences] - assert output_texts == expected_outs, \ - f"Generated texts {output_texts} do not match expected {expected_outs}" # noqa: E501 + assert output_texts == expected_outs, ( + f"Generated texts {output_texts} do not match expected {expected_outs}" + ) # noqa: E501 TEST_IMAGES = [ @@ -144,27 +145,25 @@ def run_beam_search_test(self, @pytest.mark.xfail( current_platform.is_rocm(), - reason="Qwen2-VL dependency xformers incompatible with ROCm") + reason="Qwen2-VL dependency xformers incompatible with ROCm", +) def test_qwen2vl_lora(qwen2vl_lora_files): """Test Qwen 2.0 VL model with LoRA""" - config = 
TestConfig(model_path=QWEN2VL_MODEL_PATH, - lora_path=qwen2vl_lora_files) + config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files) tester = Qwen2VLTester(config) # Test with different LoRA IDs for lora_id in [1, 2]: - tester.run_test(TEST_IMAGES, - expected_outputs=EXPECTED_OUTPUTS, - lora_id=lora_id) + tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id) @pytest.mark.xfail( current_platform.is_rocm(), - reason="Qwen2-VL dependency xformers incompatible with ROCm") + reason="Qwen2-VL dependency xformers incompatible with ROCm", +) def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): """Test Qwen 2.0 VL model with LoRA through beam search.""" - config = TestConfig(model_path=QWEN2VL_MODEL_PATH, - lora_path=qwen2vl_lora_files) + config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files) tester = Qwen2VLTester(config) # Test with different LoRA IDs @@ -176,7 +175,8 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): tester.run_beam_search_test( [ImageAsset("cherry_blossom")], expected_outputs=EXPECTED_BEAM_SEARCH_OUTPUTS, - lora_id=lora_id) + lora_id=lora_id, + ) @pytest.mark.xfail( @@ -185,12 +185,9 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): ) def test_qwen25vl_lora(qwen25vl_lora_files): """Test Qwen 2.5 VL model with LoRA""" - config = TestConfig(model_path=QWEN25VL_MODEL_PATH, - lora_path=qwen25vl_lora_files) + config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files) tester = Qwen2VLTester(config) # Test with different LoRA IDs for lora_id in [1, 2]: - tester.run_test(TEST_IMAGES, - expected_outputs=EXPECTED_OUTPUTS, - lora_id=lora_id) + tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id) diff --git a/tests/lora/test_resolver.py b/tests/lora/test_resolver.py index 6c93e577611f..c70e58a375c7 100644 --- a/tests/lora/test_resolver.py +++ b/tests/lora/test_resolver.py @@ -12,13 +12,15 @@ class DummyLoRAResolver(LoRAResolver): """A dummy LoRA resolver for testing.""" - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: if lora_name == "test_lora": return LoRARequest( lora_name=lora_name, lora_path=f"/dummy/path/{base_model_name}/{lora_name}", - lora_int_id=abs(hash(lora_name))) + lora_int_id=abs(hash(lora_name)), + ) return None @@ -70,6 +72,5 @@ async def test_dummy_resolver_resolve(): assert result.lora_path == f"/dummy/path/{base_model_name}/{lora_name}" # Test failed resolution - result = await dummy_resolver.resolve_lora(base_model_name, - "nonexistent_lora") + result = await dummy_resolver.resolve_lora(base_model_name, "nonexistent_lora") assert result is None diff --git a/tests/lora/test_transformers_model.py b/tests/lora/test_transformers_model.py index 723f7a54778f..ea1f5f9c32c3 100644 --- a/tests/lora/test_transformers_model.py +++ b/tests/lora/test_transformers_model.py @@ -24,20 +24,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: prompts = [ PROMPT_TEMPLATE.format(query="How many singers do we have?"), PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + query="What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 ), PROMPT_TEMPLATE.format( - query= - "What are all distinct countries where singers above age 20 are from?" 
# noqa: E501 + query="What are all distinct countries where singers above age 20 are from?" # noqa: E501 ), ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( prompts, sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None, + ) # Print the outputs. generated_texts: list[str] = [] for output in outputs: @@ -49,13 +47,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: def test_ilama_lora(ilama_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=16, - trust_remote_code=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + trust_remote_code=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, ilama_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -65,20 +65,23 @@ def test_ilama_lora(ilama_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_ilama_lora_tp4(ilama_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=16, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=False, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, ilama_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): @@ -88,20 +91,23 @@ def test_ilama_lora_tp4(ilama_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] -@pytest.mark.skipif(current_platform.is_cuda_alike(), - reason="Skipping to avoid redundant model tests") +@pytest.mark.skipif( + current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" +) @multi_gpu_test(num_gpus=4) @create_new_process_for_each_test() def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=16, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=True, - enable_chunked_prefill=True) + llm = vllm.LLM( + MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True, + enable_chunked_prefill=True, + ) output1 = do_sample(llm, ilama_lora_files, lora_id=1) for i in range(len(EXPECTED_LORA_OUTPUT)): assert output1[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index b343bef0a920..aed91d98ddbd 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -9,8 +9,11 @@ from huggingface_hub.utils import HfHubHTTPError from torch import nn -from vllm.lora.utils import (get_adapter_absolute_path, - parse_fine_tuned_lora_name, replace_submodule) +from vllm.lora.utils import ( + get_adapter_absolute_path, + parse_fine_tuned_lora_name, + replace_submodule, +) from vllm.model_executor.models.utils 
import WeightsMapper @@ -24,10 +27,12 @@ class LoRANameParserTestConfig(NamedTuple): def test_parse_fine_tuned_lora_name_valid(): fixture = [ - LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight", - "lm_head", True, False), - LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight", - "lm_head", False, False), + LoRANameParserTestConfig( + "base_model.model.lm_head.lora_A.weight", "lm_head", True, False + ), + LoRANameParserTestConfig( + "base_model.model.lm_head.lora_B.weight", "lm_head", False, False + ), LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_A", "model.embed_tokens", @@ -71,7 +76,8 @@ def test_parse_fine_tuned_lora_name_valid(): True, False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", @@ -79,7 +85,8 @@ def test_parse_fine_tuned_lora_name_valid(): False, False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), LoRANameParserTestConfig( "model.layers.9.mlp.down_proj.lora_A.weight", @@ -87,7 +94,8 @@ def test_parse_fine_tuned_lora_name_valid(): True, False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), LoRANameParserTestConfig( "model.layers.9.mlp.down_proj.lora_B.weight", @@ -95,12 +103,14 @@ def test_parse_fine_tuned_lora_name_valid(): False, False, weights_mapper=WeightsMapper( - orig_to_new_prefix={"model.": "language_model.model."}), + orig_to_new_prefix={"model.": "language_model.model."} + ), ), ] for name, module_name, is_lora_a, is_bias, weights_mapper in fixture: - assert (module_name, is_lora_a, - is_bias) == parse_fine_tuned_lora_name(name, weights_mapper) + assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name( + name, weights_mapper + ) def test_parse_fine_tuned_lora_name_invalid(): @@ -115,22 +125,28 @@ def test_parse_fine_tuned_lora_name_invalid(): def test_replace_submodule(): model = nn.Sequential( - OrderedDict([ - ("dense1", nn.Linear(764, 100)), - ("act1", nn.ReLU()), - ("dense2", nn.Linear(100, 50)), - ( - "seq1", - nn.Sequential( - OrderedDict([ - ("dense1", nn.Linear(100, 10)), - ("dense2", nn.Linear(10, 50)), - ])), - ), - ("act2", nn.ReLU()), - ("output", nn.Linear(50, 10)), - ("outact", nn.Sigmoid()), - ])) + OrderedDict( + [ + ("dense1", nn.Linear(764, 100)), + ("act1", nn.ReLU()), + ("dense2", nn.Linear(100, 50)), + ( + "seq1", + nn.Sequential( + OrderedDict( + [ + ("dense1", nn.Linear(100, 10)), + ("dense2", nn.Linear(10, 50)), + ] + ) + ), + ), + ("act2", nn.ReLU()), + ("output", nn.Linear(50, 10)), + ("outact", nn.Sigmoid()), + ] + ) + ) sigmoid = nn.Sigmoid() @@ -143,52 +159,51 @@ def test_replace_submodule(): # Unit tests for get_adapter_absolute_path -@patch('os.path.isabs') +@patch("os.path.isabs") def test_get_adapter_absolute_path_absolute(mock_isabs): - path = '/absolute/path/to/lora' + path = "/absolute/path/to/lora" mock_isabs.return_value = True assert get_adapter_absolute_path(path) == path -@patch('os.path.expanduser') +@patch("os.path.expanduser") def test_get_adapter_absolute_path_expanduser(mock_expanduser): # Path with ~ that needs to be expanded - path = '~/relative/path/to/lora' - absolute_path = '/home/user/relative/path/to/lora' + path = 
"~/relative/path/to/lora" + absolute_path = "/home/user/relative/path/to/lora" mock_expanduser.return_value = absolute_path assert get_adapter_absolute_path(path) == absolute_path -@patch('os.path.exists') -@patch('os.path.abspath') +@patch("os.path.exists") +@patch("os.path.abspath") def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist): # Relative path that exists locally - path = 'relative/path/to/lora' - absolute_path = '/absolute/path/to/lora' + path = "relative/path/to/lora" + absolute_path = "/absolute/path/to/lora" mock_exist.return_value = True mock_abspath.return_value = absolute_path assert get_adapter_absolute_path(path) == absolute_path -@patch('huggingface_hub.snapshot_download') -@patch('os.path.exists') -def test_get_adapter_absolute_path_huggingface(mock_exist, - mock_snapshot_download): +@patch("huggingface_hub.snapshot_download") +@patch("os.path.exists") +def test_get_adapter_absolute_path_huggingface(mock_exist, mock_snapshot_download): # Hugging Face model identifier - path = 'org/repo' - absolute_path = '/mock/snapshot/path' + path = "org/repo" + absolute_path = "/mock/snapshot/path" mock_exist.return_value = False mock_snapshot_download.return_value = absolute_path assert get_adapter_absolute_path(path) == absolute_path -@patch('huggingface_hub.snapshot_download') -@patch('os.path.exists') -def test_get_adapter_absolute_path_huggingface_error(mock_exist, - mock_snapshot_download): +@patch("huggingface_hub.snapshot_download") +@patch("os.path.exists") +def test_get_adapter_absolute_path_huggingface_error( + mock_exist, mock_snapshot_download +): # Hugging Face model identifier with download error - path = 'org/repo' + path = "org/repo" mock_exist.return_value = False - mock_snapshot_download.side_effect = HfHubHTTPError( - "failed to query model info") + mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info") assert get_adapter_absolute_path(path) == path diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 9c47abf8f4dc..c97f8debd1b9 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -6,8 +6,14 @@ import tempfile from unittest.mock import patch -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) from vllm.config.load import LoadConfig from vllm.config.lora import LoRAConfig from vllm.lora.models import LoRAMapping @@ -19,12 +25,12 @@ @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): - def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): lora_mapping = LoRAMapping([], []) worker.model_runner.lora_manager.set_active_adapters( - lora_requests, lora_mapping) + lora_requests, lora_mapping + ) vllm_config = VllmConfig( model_config=ModelConfig( @@ -49,9 +55,9 @@ def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): swap_space=0, cache_dtype="auto", ), - lora_config=LoRAConfig(max_lora_rank=8, - max_cpu_loras=NUM_LORAS, - max_loras=NUM_LORAS), + lora_config=LoRAConfig( + max_lora_rank=8, max_cpu_loras=NUM_LORAS, max_loras=NUM_LORAS + ), ) worker = Worker( vllm_config=vllm_config, @@ -67,23 +73,22 @@ def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): assert worker.list_loras() == set() lora_requests = [ - LoRARequest(str(i + 1), i + 1, sql_lora_files) - for i in range(NUM_LORAS) + LoRARequest(str(i + 1), i + 1, 
sql_lora_files) for i in range(NUM_LORAS) ] set_active_loras(worker, lora_requests) assert worker.list_loras() == { - lora_request.lora_int_id - for lora_request in lora_requests + lora_request.lora_int_id for lora_request in lora_requests } for i in range(NUM_LORAS): random.seed(i) - iter_lora_requests = random.choices(lora_requests, - k=random.randint(1, NUM_LORAS)) + iter_lora_requests = random.choices( + lora_requests, k=random.randint(1, NUM_LORAS) + ) random.shuffle(iter_lora_requests) - iter_lora_requests = iter_lora_requests[:-random.randint(0, NUM_LORAS)] + iter_lora_requests = iter_lora_requests[: -random.randint(0, NUM_LORAS)] set_active_loras(worker, lora_requests) assert worker.list_loras().issuperset( - {lora_request.lora_int_id - for lora_request in iter_lora_requests}) + {lora_request.lora_int_id for lora_request in iter_lora_requests} + ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 0432a1a9bba0..b522aa6b0874 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -13,7 +13,6 @@ class DummyLoRAManager: - def __init__(self, device: torch.device = "cuda:0"): super().__init__() self._loras: dict[str, LoRALayerWeights] = {} @@ -36,12 +35,12 @@ def init_random_lora( module_name, rank=rank, lora_alpha=1, - lora_a=torch.rand([rank, weight.shape[1]], - dtype=weight.dtype, - device=self._device), - lora_b=torch.rand([weight.shape[0], rank], - dtype=weight.dtype, - device=self._device), + lora_a=torch.rand( + [rank, weight.shape[1]], dtype=weight.dtype, device=self._device + ), + lora_b=torch.rand( + [weight.shape[0], rank], dtype=weight.dtype, device=self._device + ), ) if generate_embeddings_tensor: lora.embeddings_tensor = torch.rand( @@ -146,27 +145,26 @@ def generate_data( op_type, device, ) -> PunicaTensors: - seq_len_tensor = torch.randint(seq_length, seq_length + 1, - (batches, )).to(device) + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) b_seq_start_loc = torch.cumsum( torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0, ).to(device) total_tokens = seq_len_tensor.sum() if op_type == "shrink": - inputs_tensor = torch.rand((total_tokens, hidden_size), - dtype=dtype).to(device) + inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device) lora_weights = torch.rand( (lora_nums, max_rank, hidden_size), # col-major dtype=dtype, ).to(device) # shrink op need atomic_add, so output is initinized by 0 - ref_out_tensor = torch.zeros((total_tokens, max_rank), - dtype=dtype, - device=inputs_tensor.device) + ref_out_tensor = torch.zeros( + (total_tokens, max_rank), dtype=dtype, device=inputs_tensor.device + ) # NOTE shrink kernel using torch.float32 as output type - our_out_tensor = torch.zeros((total_tokens, max_rank), - dtype=torch.float32).to(device) + our_out_tensor = torch.zeros((total_tokens, max_rank), dtype=torch.float32).to( + device + ) else: inputs_tensor = torch.rand( (total_tokens, max_rank), @@ -184,15 +182,16 @@ def generate_data( ).to(device) # Ensure the same input. 
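# Reference for what these Punica test tensors exercise (a minimal sketch
# inferred from the comments in this file, not the vLLM kernel API): "shrink"
# projects each token down to the LoRA rank and accumulates in float32, while
# "expand" completes y += a @ lora_b back to hidden size, which is why its
# reference output starts from random values rather than zeros.
import torch


def ref_shrink(x: torch.Tensor, lora_a: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, hidden_size), lora_a: (max_rank, hidden_size)
    return (x @ lora_a.t()).float()


def ref_expand(a: torch.Tensor, lora_b: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # a: (num_tokens, max_rank), lora_b: (hidden_size, max_rank),
    # y: (num_tokens, hidden_size); returns y + a @ lora_b^T
    return y + a @ lora_b.t()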
our_out_tensor = ref_out_tensor.clone() - lora_indices_tensor = torch.randint(0, - lora_nums - 1 if lora_nums > 1 else 1, - (batches, )).to(device) + lora_indices_tensor = torch.randint( + 0, lora_nums - 1 if lora_nums > 1 else 1, (batches,) + ).to(device) indices = torch.zeros((total_tokens), dtype=torch.long).to(device) current_offset = 0 for b_id in range(batches): lora_index = lora_indices_tensor[b_id] - indices[current_offset:current_offset + - seq_len_tensor[b_id]].copy_(lora_index) + indices[current_offset : current_offset + seq_len_tensor[b_id]].copy_( + lora_index + ) current_offset += seq_len_tensor[b_id].item() return PunicaTensors( @@ -217,8 +216,7 @@ def generate_data_for_expand_nslices( nslices, device, ) -> PunicaTensors: - seq_len_tensor = torch.randint(seq_length, seq_length + 1, - (batches, )).to(device) + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) b_seq_start_loc = torch.cumsum( torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0, @@ -234,22 +232,25 @@ def generate_data_for_expand_nslices( torch.rand( (lora_nums, hidden_size, max_rank), # col-major dtype=dtype, - ).to(device)) + ).to(device) + ) # expand op needs to complete y+=a@lora_b, so output is # initinized randomly - ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), - dtype=dtype).to(device) + ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices), dtype=dtype).to( + device + ) # Ensure the same input. our_out_tensor = ref_out_tensor.clone() - lora_indices_tensor = torch.randint(0, - lora_nums - 1 if lora_nums > 1 else 1, - (batches, )) + lora_indices_tensor = torch.randint( + 0, lora_nums - 1 if lora_nums > 1 else 1, (batches,) + ) indices = torch.zeros((total_tokens), dtype=torch.long).to(device) current_offset = 0 for b_id in range(batches): lora_index = lora_indices_tensor[b_id] - indices[current_offset:current_offset + - seq_len_tensor[b_id]] = (lora_index.item()) + indices[current_offset : current_offset + seq_len_tensor[b_id]] = ( + lora_index.item() + ) current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) @@ -276,8 +277,7 @@ def generate_data_for_nslices( op_type, device, ) -> PunicaTensors: - seq_len_tensor = torch.randint(seq_length, seq_length + 1, - (batches, )).to(device) + seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches,)).to(device) b_seq_start_loc = torch.cumsum( torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), dim=0, @@ -286,9 +286,7 @@ def generate_data_for_nslices( lora_weights_lst = [] if op_type == "shrink": - - inputs_tensor = torch.rand((total_tokens, hidden_size), - dtype=dtype).to(device) + inputs_tensor = torch.rand((total_tokens, hidden_size), dtype=dtype).to(device) for _ in range(nslices): if op_type == "shrink": @@ -296,7 +294,8 @@ def generate_data_for_nslices( torch.rand( (lora_nums, max_rank, hidden_size), # col-major dtype=dtype, - ).to(device)) + ).to(device) + ) # NOTE shrink kernel using torch.float32 as output type # shrink op need atomic_add, so output is initinized by 0 our_out_tensor = torch.zeros( @@ -313,23 +312,26 @@ def generate_data_for_nslices( torch.rand( (lora_nums, hidden_size, max_rank), # col-major dtype=dtype, - ).to(device)) + ).to(device) + ) # expand op needs to complete y+=a@lora_b, so output is # initinized randomly - our_out_tensor = torch.rand((total_tokens, hidden_size * nslices), - dtype=dtype).to(device) + our_out_tensor = torch.rand( + (total_tokens, hidden_size * nslices), 
dtype=dtype + ).to(device) # Ensure the same input. ref_out_tensor = our_out_tensor.clone() - lora_indices_tensor = torch.randint(0, - lora_nums - 1 if lora_nums > 1 else 1, - (batches, )) + lora_indices_tensor = torch.randint( + 0, lora_nums - 1 if lora_nums > 1 else 1, (batches,) + ) indices = torch.zeros((total_tokens), dtype=torch.long).to(device) current_offset = 0 for b_id in range(batches): lora_index = lora_indices_tensor[b_id] - indices[current_offset:current_offset + - seq_len_tensor[b_id]] = (lora_index.item()) + indices[current_offset : current_offset + seq_len_tensor[b_id]] = ( + lora_index.item() + ) current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) @@ -379,24 +381,20 @@ def create_peft_lora( } for module_name in target_modules: - module = model for attr in module_name.split("."): module = getattr(module, attr) if hasattr(module, "input_size") and hasattr(module, "output_size"): - in_features = module.input_size out_features = module.output_size - elif hasattr(module, "embedding_dim") and hasattr( - module, "num_embeddings"): + elif hasattr(module, "embedding_dim") and hasattr(module, "num_embeddings"): # ParallelLMHead in_features = module.embedding_dim out_features = module.num_embeddings else: - raise ValueError( - f"Unable to determine dimensions for module {module_name}") + raise ValueError(f"Unable to determine dimensions for module {module_name}") lora_A = torch.randn(rank, in_features, dtype=lora_dtype) diff --git a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py index 78d23acfec7c..cc899b77b5e9 100644 --- a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py +++ b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py @@ -8,24 +8,25 @@ import torch from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, fastsafetensors_weights_iterator, - safetensors_weights_iterator) + download_weights_from_hf, + fastsafetensors_weights_iterator, + safetensors_weights_iterator, +) def test_fastsafetensors_model_loader(): with tempfile.TemporaryDirectory() as tmpdir: huggingface_hub.constants.HF_HUB_OFFLINE = False - download_weights_from_hf("openai-community/gpt2", - allow_patterns=["*.safetensors"], - cache_dir=tmpdir) + download_weights_from_hf( + "openai-community/gpt2", allow_patterns=["*.safetensors"], cache_dir=tmpdir + ) safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) assert len(safetensors) > 0 fastsafetensors_tensors = {} hf_safetensors_tensors = {} - for name, tensor in fastsafetensors_weights_iterator( - safetensors, True): + for name, tensor in fastsafetensors_weights_iterator(safetensors, True): fastsafetensors_tensors[name] = tensor for name, tensor in safetensors_weights_iterator(safetensors, True): @@ -34,13 +35,10 @@ def test_fastsafetensors_model_loader(): assert len(fastsafetensors_tensors) == len(hf_safetensors_tensors) for name, fastsafetensors_tensor in fastsafetensors_tensors.items(): - fastsafetensors_tensor = fastsafetensors_tensor.to('cpu') - assert fastsafetensors_tensor.dtype == hf_safetensors_tensors[ - name].dtype - assert fastsafetensors_tensor.shape == hf_safetensors_tensors[ - name].shape - assert torch.all( - fastsafetensors_tensor.eq(hf_safetensors_tensors[name])) + fastsafetensors_tensor = fastsafetensors_tensor.to("cpu") + assert fastsafetensors_tensor.dtype == 
hf_safetensors_tensors[name].dtype + assert fastsafetensors_tensor.shape == hf_safetensors_tensors[name].shape + assert torch.all(fastsafetensors_tensor.eq(hf_safetensors_tensors[name])) if __name__ == "__main__": diff --git a/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py b/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py index e11e4c7289bc..3ad7308eeba2 100644 --- a/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py +++ b/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py @@ -8,11 +8,12 @@ import huggingface_hub.constants -from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf) -from vllm.transformers_utils.runai_utils import (ObjectStorageModel, - is_runai_obj_uri, - list_safetensors) +from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf +from vllm.transformers_utils.runai_utils import ( + ObjectStorageModel, + is_runai_obj_uri, + list_safetensors, +) def test_is_runai_obj_uri(): @@ -24,14 +25,14 @@ def test_is_runai_obj_uri(): def test_runai_list_safetensors_local(): with tempfile.TemporaryDirectory() as tmpdir: huggingface_hub.constants.HF_HUB_OFFLINE = False - download_weights_from_hf("openai-community/gpt2", - allow_patterns=["*.safetensors", "*.json"], - cache_dir=tmpdir) + download_weights_from_hf( + "openai-community/gpt2", + allow_patterns=["*.safetensors", "*.json"], + cache_dir=tmpdir, + ) safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) assert len(safetensors) > 0 - parentdir = [ - os.path.dirname(safetensor) for safetensor in safetensors - ][0] + parentdir = [os.path.dirname(safetensor) for safetensor in safetensors][0] files = list_safetensors(parentdir) assert len(safetensors) == len(files) @@ -50,9 +51,9 @@ def test_runai_pull_files_gcs(monkeypatch): # | cut -d":" -f2 | base64 -d | xxd -p expected_checksum = "f60dea775da1392434275b311b31a431" hasher = hashlib.new("md5") - with open(os.path.join(model.dir, filename), 'rb') as f: + with open(os.path.join(model.dir, filename), "rb") as f: # Read the file in chunks to handle large files efficiently - for chunk in iter(lambda: f.read(4096), b''): + for chunk in iter(lambda: f.read(4096), b""): hasher.update(chunk) actual_checksum = hasher.hexdigest() assert actual_checksum == expected_checksum diff --git a/tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py b/tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py index ee448c2ccb21..03691b4a472f 100644 --- a/tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py +++ b/tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py @@ -8,24 +8,25 @@ import torch from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, runai_safetensors_weights_iterator, - safetensors_weights_iterator) + download_weights_from_hf, + runai_safetensors_weights_iterator, + safetensors_weights_iterator, +) def test_runai_model_loader(): with tempfile.TemporaryDirectory() as tmpdir: huggingface_hub.constants.HF_HUB_OFFLINE = False - download_weights_from_hf("openai-community/gpt2", - allow_patterns=["*.safetensors"], - cache_dir=tmpdir) + download_weights_from_hf( + "openai-community/gpt2", allow_patterns=["*.safetensors"], cache_dir=tmpdir + ) safetensors = glob.glob(f"{tmpdir}/**/*.safetensors", recursive=True) assert len(safetensors) > 0 runai_model_streamer_tensors = {} hf_safetensors_tensors = {} - for name, 
tensor in runai_safetensors_weights_iterator( - safetensors, True): + for name, tensor in runai_safetensors_weights_iterator(safetensors, True): runai_model_streamer_tensors[name] = tensor for name, tensor in safetensors_weights_iterator(safetensors, True): diff --git a/tests/model_executor/model_loader/tensorizer_loader/conftest.py b/tests/model_executor/model_loader/tensorizer_loader/conftest.py index cc02d7ecf20b..add6d3742ff5 100644 --- a/tests/model_executor/model_loader/tensorizer_loader/conftest.py +++ b/tests/model_executor/model_loader/tensorizer_loader/conftest.py @@ -32,7 +32,6 @@ def cleanup(): @pytest.fixture() def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path): - def noop(*args, **kwargs): return None @@ -56,8 +55,7 @@ def model_path(model_ref, tmp_path): yield tmp_path / model_ref / "model.tensors" -def assert_from_collective_rpc(engine: LLM, closure: Callable, - closure_kwargs: dict): +def assert_from_collective_rpc(engine: LLM, closure: Callable, closure_kwargs: dict): res = engine.collective_rpc(method=closure, kwargs=closure_kwargs) return all(res) @@ -67,18 +65,13 @@ def assert_from_collective_rpc(engine: LLM, closure: Callable, # method. It's purely used as a dummy utility to run methods that test # Tensorizer functionality class DummyExecutor(UniProcExecutor): - def _init_executor(self) -> None: - """Initialize the worker and load the model. - """ - self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, - rpc_rank=0) - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) + """Initialize the worker and load the model.""" + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) + distributed_init_method = get_distributed_init_method(get_ip(), get_open_port()) local_rank = 0 # set local rank as the device index if specified - device_info = self.vllm_config.device_config.device.__str__().split( - ":") + device_info = self.vllm_config.device_config.device.__str__().split(":") if len(device_info) > 1: local_rank = int(device_info[1]) rank = 0 @@ -91,7 +84,7 @@ def _init_executor(self) -> None: is_driver_worker=is_driver_worker, ) self.mm_receiver_cache = None - self.collective_rpc("init_worker", args=([kwargs], )) + self.collective_rpc("init_worker", args=([kwargs],)) self.collective_rpc("init_device") @property @@ -99,5 +92,5 @@ def max_concurrent_batches(self) -> int: return 2 def shutdown(self): - if hasattr(self, 'thread_pool'): + if hasattr(self, "thread_pool"): self.thread_pool.shutdown(wait=False) diff --git a/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py index f50f04696738..57db1f98baed 100644 --- a/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py +++ b/tests/model_executor/model_loader/tensorizer_loader/test_tensorizer.py @@ -17,15 +17,16 @@ from tests.utils import VLLM_PATH, RemoteOpenAIServer from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs -# yapf: disable -from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, - TensorSerializer, - is_vllm_tensorized, - open_stream, - tensorize_vllm_model) +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, + TensorSerializer, + is_vllm_tensorized, + open_stream, + tensorize_vllm_model, +) from vllm.model_executor.model_loader.tensorizer_loader import ( - BLACKLISTED_TENSORIZER_ARGS) -# yapf: enable + BLACKLISTED_TENSORIZER_ARGS, +) from 
vllm.utils import PlaceholderModule from .conftest import DummyExecutor, assert_from_collective_rpc @@ -44,7 +45,7 @@ class TensorizerCaughtError(Exception): EXAMPLES_PATH = VLLM_PATH / "examples" -pytest_plugins = "pytest_asyncio", +pytest_plugins = ("pytest_asyncio",) prompts = [ "Hello, my name is", @@ -56,8 +57,7 @@ class TensorizerCaughtError(Exception): sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) -def patch_init_and_catch_error(self, obj, method_name, - expected_error: type[Exception]): +def patch_init_and_catch_error(self, obj, method_name, expected_error: type[Exception]): original = getattr(obj, method_name, None) if original is None: raise ValueError("Method '{}' not found.".format(method_name)) @@ -80,17 +80,19 @@ def assert_specific_tensorizer_error_is_raised( expected_error: type[Exception], ): with pytest.raises(TensorizerCaughtError): - executor.collective_rpc(patch_init_and_catch_error, - args=( - obj, - method_name, - expected_error, - )) + executor.collective_rpc( + patch_init_and_catch_error, + args=( + obj, + method_name, + expected_error, + ), + ) def is_curl_installed(): try: - subprocess.check_call(['curl', '--version']) + subprocess.check_call(["curl", "--version"]) return True except (subprocess.CalledProcessError, FileNotFoundError): return False @@ -99,13 +101,14 @@ def is_curl_installed(): def write_keyfile(keyfile_path: str): encryption_params = EncryptionParams.random() pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True) - with open(keyfile_path, 'wb') as f: + with open(keyfile_path, "wb") as f: f.write(encryption_params.key) @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_deserialized_encrypted_vllm_model_has_same_outputs( - model_ref, vllm_runner, tmp_path, model_path): + model_ref, vllm_runner, tmp_path, model_path +): args = EngineArgs(model=model_ref) with vllm_runner(model_ref) as vllm_model: key_path = tmp_path / model_ref / "model.key" @@ -113,29 +116,30 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( outputs = vllm_model.generate(prompts, sampling_params) - config_for_serializing = TensorizerConfig(tensorizer_uri=str(model_path), - encryption_keyfile=str(key_path)) + config_for_serializing = TensorizerConfig( + tensorizer_uri=str(model_path), encryption_keyfile=str(key_path) + ) tensorize_vllm_model(args, config_for_serializing) config_for_deserializing = TensorizerConfig( - tensorizer_uri=str(model_path), encryption_keyfile=str(key_path)) - - with vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=config_for_deserializing - ) as loaded_vllm_model: # noqa: E501 + tensorizer_uri=str(model_path), encryption_keyfile=str(key_path) + ) - deserialized_outputs = loaded_vllm_model.generate( - prompts, sampling_params) + with vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=config_for_deserializing, + ) as loaded_vllm_model: # noqa: E501 + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 assert outputs == deserialized_outputs -def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, - tmp_path, model_ref, - model_path): +def test_deserialized_hf_model_has_same_outputs( + hf_runner, vllm_runner, tmp_path, model_ref, model_path +): with hf_runner(model_ref) as hf_model: max_tokens = 50 outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) @@ -143,14 +147,17 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, 
serializer = TensorSerializer(stream) serializer.write_module(hf_model.model) - with vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig( - tensorizer_uri=str(model_path), - num_readers=1, - )) as loaded_hf_model: + with vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=str(model_path), + num_readers=1, + ), + ) as loaded_hf_model: deserialized_outputs = loaded_hf_model.generate_greedy( - prompts, max_tokens=max_tokens) + prompts, max_tokens=max_tokens + ) assert outputs == deserialized_outputs @@ -159,35 +166,37 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref): model = None try: model = vllm_runner( - model_ref, - model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + model_ref, model_loader_extra_config=TensorizerConfig(tensorizer_uri="test") + ) pytest.fail("Expected RuntimeError for extra config keys") except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: Unexpected extra config keys for load " - "format auto") in combined_output + assert ( + "ValueError: Unexpected extra config keys for load format auto" + ) in combined_output finally: del model gc.collect() torch.cuda.empty_cache() -def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, - model_ref): +def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, model_ref): model = None try: model = vllm_runner( model_ref, load_format="safetensors", - model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"), + ) pytest.fail("Expected RuntimeError for extra config keys") except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: Unexpected extra config keys " - "for load format safetensors") in combined_output + assert ( + "ValueError: Unexpected extra config keys for load format safetensors" + ) in combined_output finally: del model gc.collect() @@ -214,21 +223,24 @@ def test_tensorizer_with_tp_path_without_template(vllm_runner, capfd): except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert ("ValueError: For a sharded model, tensorizer_uri " - "should include a string format template like '%04d' " - "to be formatted with the rank " - "of the shard") in combined_output + assert ( + "ValueError: For a sharded model, tensorizer_uri " + "should include a string format template like '%04d' " + "to be formatted with the rank " + "of the shard" + ) in combined_output @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2 GPUs") def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( - vllm_runner, tmp_path): + vllm_runner, tmp_path +): model_ref = "EleutherAI/pythia-1.4b" # record outputs from un-sharded un-tensorized model with vllm_runner( - model_ref, - disable_custom_all_reduce=True, - enforce_eager=True, + model_ref, + disable_custom_all_reduce=True, + enforce_eager=True, ) as base_model: outputs = base_model.generate(prompts, sampling_params) @@ -254,21 +266,22 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( assert os.path.isfile(model_path % 1), "Serialization subprocess failed" with vllm_runner( - model_ref, - tensor_parallel_size=2, - load_format="tensorizer", - disable_custom_all_reduce=True, - enforce_eager=True, - model_loader_extra_config=tensorizer_config) as loaded_vllm_model: - deserialized_outputs = 
loaded_vllm_model.generate( - prompts, sampling_params) + model_ref, + tensor_parallel_size=2, + load_format="tensorizer", + disable_custom_all_reduce=True, + enforce_eager=True, + model_loader_extra_config=tensorizer_config, + ) as loaded_vllm_model: + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) assert outputs == deserialized_outputs @pytest.mark.flaky(reruns=3) -def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner, - tmp_path, model_path): +def test_vllm_tensorized_model_has_same_outputs( + model_ref, vllm_runner, tmp_path, model_path +): gc.collect() torch.cuda.empty_cache() config = TensorizerConfig(tensorizer_uri=str(model_path)) @@ -280,11 +293,10 @@ def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner, tensorize_vllm_model(args, config) assert is_vllm_tensorized(config) - with vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=config) as loaded_vllm_model: - deserialized_outputs = loaded_vllm_model.generate( - prompts, sampling_params) + with vllm_runner( + model_ref, load_format="tensorizer", model_loader_extra_config=config + ) as loaded_vllm_model: + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 assert outputs == deserialized_outputs @@ -314,15 +326,17 @@ def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref): def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path): - serialization_params = { "limit_cpu_concurrency": 2, } model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") - config = TensorizerConfig(tensorizer_uri=str(model_path), - serialization_kwargs=serialization_params) - llm = LLM(model=model_ref, ) + config = TensorizerConfig( + tensorizer_uri=str(model_path), serialization_kwargs=serialization_params + ) + llm = LLM( + model=model_ref, + ) def serialization_test(self, *args, **kwargs): # This is performed in the ephemeral worker process, so monkey-patching @@ -340,10 +354,13 @@ def tensorizer_serializer_wrapper(self, *args, **kwargs): return original(self, *args, **kwargs) tensorizer.serialization.TensorSerializer.__init__ = ( - tensorizer_serializer_wrapper) + tensorizer_serializer_wrapper + ) tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"]) - self.save_tensorized_model(tensorizer_config=tensorizer_config, ) + self.save_tensorized_model( + tensorizer_config=tensorizer_config, + ) return to_compare | original_dict == to_compare kwargs = {"tensorizer_config": config.to_serializable()} @@ -351,9 +368,7 @@ def tensorizer_serializer_wrapper(self, *args, **kwargs): assert assert_from_collective_rpc(llm, serialization_test, kwargs) -def test_assert_deserialization_kwargs_passed_to_tensor_deserializer( - tmp_path, capfd): - +def test_assert_deserialization_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): deserialization_kwargs = { "num_readers": "bar", # illegal value } @@ -364,8 +379,9 @@ def test_assert_deserialization_kwargs_passed_to_tensor_deserializer( model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") - config = TensorizerConfig(tensorizer_uri=str(model_path), - serialization_kwargs=serialization_params) + config = TensorizerConfig( + tensorizer_uri=str(model_path), serialization_kwargs=serialization_params + ) args = EngineArgs(model=model_ref) tensorize_vllm_model(args, config) @@ -393,7 +409,6 @@ def test_assert_deserialization_kwargs_passed_to_tensor_deserializer( def 
test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): - deserialization_kwargs = { "num_readers": 1, } @@ -404,8 +419,9 @@ def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") - config = TensorizerConfig(tensorizer_uri=str(model_path), - serialization_kwargs=serialization_params) + config = TensorizerConfig( + tensorizer_uri=str(model_path), serialization_kwargs=serialization_params + ) args = EngineArgs(model=model_ref) tensorize_vllm_model(args, config) @@ -441,16 +457,24 @@ async def test_serialize_and_serve_entrypoints(tmp_path): suffix = "test" try: - result = subprocess.run([ - sys.executable, - f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", - model_ref, "serialize", "--serialized-directory", - str(tmp_path), "--suffix", suffix, "--serialization-kwargs", - '{"limit_cpu_concurrency": 4}' - ], - check=True, - capture_output=True, - text=True) + result = subprocess.run( + [ + sys.executable, + f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", + "--model", + model_ref, + "serialize", + "--serialized-directory", + str(tmp_path), + "--suffix", + suffix, + "--serialization-kwargs", + '{"limit_cpu_concurrency": 4}', + ], + check=True, + capture_output=True, + text=True, + ) except subprocess.CalledProcessError as e: print("Tensorizing failed.") print("STDOUT:\n", e.stdout) @@ -470,14 +494,20 @@ async def test_serialize_and_serve_entrypoints(tmp_path): "deserialization_kwargs": { "verify_hash": True, "num_readers": 8, - } + }, } cmd = [ - "-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost", - "--load-format", "tensorizer", model_ref, + "-m", + "vllm.entrypoints.cli.main", + "serve", + "--host", + "localhost", + "--load-format", + "tensorizer", + model_ref, "--model-loader-extra-config", - json.dumps(model_loader_extra_config, indent=2) + json.dumps(model_loader_extra_config, indent=2), ] proc = await asyncio.create_subprocess_exec( @@ -500,17 +530,16 @@ async def test_serialize_and_serve_entrypoints(tmp_path): @pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS) -def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, - illegal_value): - +def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, illegal_value): serialization_params = { "limit_cpu_concurrency": 2, } model_ref = "facebook/opt-125m" model_path = tmp_path / (model_ref + ".tensors") - config = TensorizerConfig(tensorizer_uri=str(model_path), - serialization_kwargs=serialization_params) + config = TensorizerConfig( + tensorizer_uri=str(model_path), serialization_kwargs=serialization_params + ) args = EngineArgs(model=model_ref) tensorize_vllm_model(args, config) @@ -526,5 +555,6 @@ def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, except RuntimeError: out, err = capfd.readouterr() combined_output = out + err - assert (f"ValueError: {illegal_value} is not an allowed " - f"Tensorizer argument.") in combined_output + assert ( + f"ValueError: {illegal_value} is not an allowed Tensorizer argument." 
+ ) in combined_output diff --git a/tests/model_executor/model_loader/test_registry.py b/tests/model_executor/model_loader/test_registry.py index 639ee6db9270..020988ccac13 100644 --- a/tests/model_executor/model_loader/test_registry.py +++ b/tests/model_executor/model_loader/test_registry.py @@ -6,22 +6,19 @@ from vllm.config import ModelConfig from vllm.config.load import LoadConfig -from vllm.model_executor.model_loader import (get_model_loader, - register_model_loader) +from vllm.model_executor.model_loader import get_model_loader, register_model_loader from vllm.model_executor.model_loader.base_loader import BaseModelLoader @register_model_loader("custom_load_format") class CustomModelLoader(BaseModelLoader): - def __init__(self, load_config: LoadConfig) -> None: super().__init__(load_config) def download_model(self, model_config: ModelConfig) -> None: pass - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: pass diff --git a/tests/model_executor/model_loader/test_sharded_state_loader.py b/tests/model_executor/model_loader/test_sharded_state_loader.py index 785169f5d22e..5bb841bf2fa0 100644 --- a/tests/model_executor/model_loader/test_sharded_state_loader.py +++ b/tests/model_executor/model_loader/test_sharded_state_loader.py @@ -35,11 +35,13 @@ def test_filter_subtensors(): "b": torch.empty((2, 4)), "c": torch.empty((2, 4, 8)), } - state_dict.update({ - "x": state_dict["b"], - "y": state_dict["c"][1, 2, :], - "z": state_dict["c"][1, :, 4], - }) + state_dict.update( + { + "x": state_dict["b"], + "y": state_dict["c"][1, 2, :], + "z": state_dict["c"][1, :, 4], + } + ) filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") for key, tensor in filtered_state_dict.items(): @@ -49,8 +51,9 @@ def test_filter_subtensors(): @pytest.fixture(scope="module") def llama_3p2_1b_files(): - input_dir = snapshot_download("meta-llama/Llama-3.2-1B-Instruct", - ignore_patterns=["*.bin*", "original/*"]) + input_dir = snapshot_download( + "meta-llama/Llama-3.2-1B-Instruct", ignore_patterns=["*.bin*", "original/*"] + ) yield input_dir @@ -63,8 +66,7 @@ def _run_writer(input_dir, output_dir, weights_patterns, **kwargs): if is_v1_engine: # For V1 engine, we need to use engine_core.save_sharded_state print("Using V1 engine save path") - llm_sharded_writer.llm_engine.engine_core.save_sharded_state( - path=output_dir) + llm_sharded_writer.llm_engine.engine_core.save_sharded_state(path=output_dir) else: # For V0 engine print("Using V0 engine save path") @@ -74,8 +76,9 @@ def _run_writer(input_dir, output_dir, weights_patterns, **kwargs): # Copy metadata files to output directory for file in os.listdir(input_dir): if os.path.isdir(os.path.join(input_dir, file)): - shutil.copytree(os.path.join(input_dir, file), - os.path.join(output_dir, file)) + shutil.copytree( + os.path.join(input_dir, file), os.path.join(output_dir, file) + ) elif not any(fnmatch.fnmatch(file, ext) for ext in weights_patterns): shutil.copy(os.path.join(input_dir, file), output_dir) @@ -90,37 +93,42 @@ def _run_generate(input_dir, queue: mp.Queue, **kwargs): @pytest.mark.parametrize("enable_lora", [False, True]) @pytest.mark.parametrize("tp_size", [1, 2]) -def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, - llama_3p2_1b_files): +def test_sharded_state_loader( + enable_lora, tp_size, num_gpus_available, llama_3p2_1b_files +): if 
num_gpus_available < tp_size: pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") - weights_patterns = ("*.safetensors", ) + weights_patterns = ("*.safetensors",) gpu_memory_utilization = 0.8 input_dir = llama_3p2_1b_files ctx = mp.get_context("spawn") # Run in separate processes for memory & CUDA isolation with TemporaryDirectory() as output_dir: - p = ctx.Process(target=_run_writer, - args=(input_dir, output_dir, weights_patterns), - kwargs=dict( - tensor_parallel_size=tp_size, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=True, - )) + p = ctx.Process( + target=_run_writer, + args=(input_dir, output_dir, weights_patterns), + kwargs=dict( + tensor_parallel_size=tp_size, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=True, + ), + ) p.start() p.join() queue = ctx.Queue() - p = ctx.Process(target=_run_generate, - args=(input_dir, queue), - kwargs=dict( - enable_lora=enable_lora, - gpu_memory_utilization=gpu_memory_utilization, - tensor_parallel_size=tp_size, - )) + p = ctx.Process( + target=_run_generate, + args=(input_dir, queue), + kwargs=dict( + enable_lora=enable_lora, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size, + ), + ) p.start() # Call queue.get() before p.join() to prevent deadlock: # If p.join() is called before queue.get() and the queue is full, @@ -134,14 +142,16 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, queue = ctx.Queue() - p = ctx.Process(target=_run_generate, - args=(output_dir, queue), - kwargs=dict( - enable_lora=enable_lora, - gpu_memory_utilization=gpu_memory_utilization, - tensor_parallel_size=tp_size, - load_format="sharded_state", - )) + p = ctx.Process( + target=_run_generate, + args=(output_dir, queue), + kwargs=dict( + enable_lora=enable_lora, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size, + load_format="sharded_state", + ), + ) p.start() # Call queue.get() before p.join() to prevent deadlock: # If p.join() is called before queue.get() and the queue is full, diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 200b6ecd5852..12aad4cb8da0 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -7,16 +7,24 @@ from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.activation import (GeluAndMul, - ReLUSquaredActivation, - SiluAndMul) -from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func, - vllm_topk_softmax) +from vllm.model_executor.layers.activation import ( + GeluAndMul, + ReLUSquaredActivation, + SiluAndMul, +) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + dispatch_topk_func, + vllm_topk_softmax, +) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) -from vllm.model_executor.layers.layernorm import (RMSNorm, - dispatch_rocm_rmsnorm_func, - fused_add_rms_norm, rms_norm) + is_rocm_aiter_moe_enabled, +) +from vllm.model_executor.layers.layernorm import ( + RMSNorm, + dispatch_rocm_rmsnorm_func, + fused_add_rms_norm, + rms_norm, +) from vllm.platforms import current_platform RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] @@ -65,14 +73,21 @@ class Relu3(ReLUSquaredActivation): ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False), # All but RMSNorm ("all,-rms_norm", 4, True, [0, 1, 1, 
1], True), - ]) -def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool, - ops_enabled: list[int], default_on: bool): - custom_ops = env.split(',') if env else [] + ], +) +def test_enabled_ops( + env: Optional[str], + torch_level: int, + use_inductor: bool, + ops_enabled: list[int], + default_on: bool, +): + custom_ops = env.split(",") if env else [] vllm_config = VllmConfig( - compilation_config=CompilationConfig(use_inductor=bool(use_inductor), - level=torch_level, - custom_ops=custom_ops)) + compilation_config=CompilationConfig( + use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops + ) + ) with set_current_vllm_config(vllm_config): assert CustomOp.default_on() == default_on @@ -100,11 +115,13 @@ class SiluAndMul2(SiluAndMul): @pytest.mark.parametrize( - "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) + "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"] +) def test_enabled_ops_invalid(env: str): with pytest.raises(Exception): # noqa - vllm_config = VllmConfig(compilation_config=CompilationConfig( - custom_ops=env.split(","))) + vllm_config = VllmConfig( + compilation_config=CompilationConfig(custom_ops=env.split(",")) + ) with set_current_vllm_config(vllm_config): RMSNorm(1024).enabled() @@ -116,28 +133,38 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): is_rocm_aiter_moe_enabled.cache_clear() if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_topk_softmax) + rocm_aiter_topk_softmax, + ) + assert topk_func == rocm_aiter_topk_softmax else: assert topk_func == vllm_topk_softmax @pytest.mark.parametrize("add_residual", [True, False]) -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="AITER is a feature exclusive for ROCm") -def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype, - use_rocm_aiter: str, use_rocm_aiter_norm: str, - monkeypatch): +@pytest.mark.skipif( + not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm" +) +def test_rms_norm_dispatch( + add_residual: bool, + dtype: torch.dtype, + use_rocm_aiter: str, + use_rocm_aiter_norm: str, + monkeypatch, +): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype) - should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \ - and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES + should_use_rocm_aiter = ( + current_platform.is_rocm() + and int(use_rocm_aiter) + and int(use_rocm_aiter_norm) + and dtype in RMS_NORM_SUPPORTED_DTYPES + ) if add_residual and should_use_rocm_aiter: assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index c7b15c6ae118..489ac1e6475b 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -5,8 +5,12 @@ import pytest -from vllm.model_executor.layers.pooler import (CLSPool, DispatchPooler, - MeanPool, PoolingType) +from 
vllm.model_executor.layers.pooler import ( + CLSPool, + DispatchPooler, + MeanPool, + PoolingType, +) from vllm.model_executor.models.bert import BertEmbeddingModel from vllm.model_executor.models.roberta import RobertaEmbeddingModel from vllm.platforms import current_platform @@ -15,25 +19,28 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5") REVISION = os.environ.get("REVISION", "main") -MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME", - "intfloat/multilingual-e5-base") +MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME", "intfloat/multilingual-e5-base") REVISION_ROBERTA = os.environ.get("REVISION", "main") -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_model_loading_with_params(vllm_runner, monkeypatch): """ Test parameter weight loading with tp>1. """ # to use apply_model monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - with vllm_runner(model_name=MODEL_NAME, - revision=REVISION, - dtype="float16", - max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.embed("Write a short story about a robot that" - " dreams for the first time.\n") + with vllm_runner( + model_name=MODEL_NAME, + revision=REVISION, + dtype="float16", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + output = vllm_model.embed( + "Write a short story about a robot that dreams for the first time.\n" + ) model_config = vllm_model.llm.llm_engine.model_config model_tokenizer = vllm_model.llm.llm_engine.tokenizer @@ -60,20 +67,24 @@ def check_model(model): assert output -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_roberta_model_loading_with_params(vllm_runner, monkeypatch): """ Test parameter weight loading with tp>1. """ # to use apply_model monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - with vllm_runner(model_name=MODEL_NAME_ROBERTA, - revision=REVISION_ROBERTA, - dtype="float16", - max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.embed("Write a short story about a robot that" - " dreams for the first time.\n") + with vllm_runner( + model_name=MODEL_NAME_ROBERTA, + revision=REVISION_ROBERTA, + dtype="float16", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + output = vllm_model.embed( + "Write a short story about a robot that dreams for the first time.\n" + ) model_config = vllm_model.llm.llm_engine.model_config model_tokenizer = vllm_model.llm.llm_engine.tokenizer @@ -93,16 +104,16 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch): def check_model(model): assert isinstance(model, RobertaEmbeddingModel) assert isinstance(pooler := model.pooler, DispatchPooler) - assert isinstance(pooler.poolers_by_task["embed"].pooling, - MeanPool) + assert isinstance(pooler.poolers_by_task["embed"].pooling, MeanPool) vllm_model.apply_model(check_model) assert output -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch): """ Test loading roberta-base model with no lm_head. 
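# Context for the pooler assertions above (an illustrative reference, not the
# vLLM pooler implementation): CLS pooling keeps the first token's hidden
# state, while mean pooling averages the states of non-padded tokens.
import torch


def cls_pool(hidden: torch.Tensor) -> torch.Tensor:
    # hidden: (seq_len, hidden_size) -> (hidden_size,)
    return hidden[0]


def mean_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # mask: (seq_len,) with 1 for real tokens, 0 for padding
    mask = mask.unsqueeze(-1).to(hidden.dtype)
    return (hidden * mask).sum(dim=0) / mask.sum().clamp(min=1.0)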
@@ -110,11 +121,12 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner, monkeypatch): # to use apply_model monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") model_name = "FacebookAI/roberta-base" - with vllm_runner(model_name=model_name, - dtype="float16", - max_model_len=MAX_MODEL_LEN) as vllm_model: - output = vllm_model.embed("Write a short story about a robot that" - " dreams for the first time.\n") + with vllm_runner( + model_name=model_name, dtype="float16", max_model_len=MAX_MODEL_LEN + ) as vllm_model: + output = vllm_model.embed( + "Write a short story about a robot that dreams for the first time.\n" + ) assert vllm_model.llm.llm_engine.model_config.tokenizer == model_name diff --git a/tests/model_executor/test_weight_utils.py b/tests/model_executor/test_weight_utils.py index df625b8d6004..6dc120ddbac9 100644 --- a/tests/model_executor/test_weight_utils.py +++ b/tests/model_executor/test_weight_utils.py @@ -9,23 +9,24 @@ from huggingface_hub.utils import LocalEntryNotFoundError from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, enable_hf_transfer) + download_weights_from_hf, + enable_hf_transfer, +) def test_hf_transfer_auto_activation(): if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ: # in case it is already set, we can't test the auto activation - pytest.skip( - "HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation") + pytest.skip("HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation") enable_hf_transfer() try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa + HF_TRANSFER_ACTIVE = True except ImportError: HF_TRANSFER_ACTIVE = False - assert (huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == - HF_TRANSFER_ACTIVE) + assert huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == HF_TRANSFER_ACTIVE def test_download_weights_from_hf(): @@ -34,22 +35,30 @@ def test_download_weights_from_hf(): # if offline is set and model is not cached huggingface_hub.constants.HF_HUB_OFFLINE = True with pytest.raises(LocalEntryNotFoundError): - download_weights_from_hf("facebook/opt-125m", - allow_patterns=["*.safetensors", "*.bin"], - cache_dir=tmpdir) + download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir, + ) # download the model huggingface_hub.constants.HF_HUB_OFFLINE = False - download_weights_from_hf("facebook/opt-125m", - allow_patterns=["*.safetensors", "*.bin"], - cache_dir=tmpdir) + download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir, + ) # now it should work offline huggingface_hub.constants.HF_HUB_OFFLINE = True - assert download_weights_from_hf( - "facebook/opt-125m", - allow_patterns=["*.safetensors", "*.bin"], - cache_dir=tmpdir) is not None + assert ( + download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir, + ) + is not None + ) if __name__ == "__main__": diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 39c4dd735b72..3fc265194e2a 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -51,8 +51,9 @@ pytest.param( "google/gemma-1.1-2b-it", # gemma marks=[ - pytest.mark.core_model, pytest.mark.cpu_model, - pytest.mark.slow_test + pytest.mark.core_model, + pytest.mark.cpu_model, + pytest.mark.slow_test, ], ), pytest.param( @@ -65,8 +66,7 @@ pytest.param( 
"openbmb/MiniCPM3-4B", # fused_moe not supported on CPU - marks=[pytest.mark.core_model, - large_gpu_mark(min_gb=32)], + marks=[pytest.mark.core_model, large_gpu_mark(min_gb=32)], ), pytest.param( "facebook/opt-125m", # opt @@ -82,8 +82,9 @@ pytest.param( "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 marks=[ - pytest.mark.core_model, pytest.mark.cpu_model, - pytest.mark.slow_test + pytest.mark.core_model, + pytest.mark.cpu_model, + pytest.mark.slow_test, ], ), pytest.param( @@ -100,16 +101,25 @@ marks=[pytest.mark.cpu_model], ), pytest.param("swiss-ai/Apertus-8B-2509"), # apertus - ]) + ], +) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) @pytest.mark.parametrize("use_prompt_embeds", [True, False]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - max_tokens: int, num_logprobs: int, use_rocm_aiter: bool, - use_prompt_embeds: bool, monkeypatch) -> None: - +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + max_tokens: int, + num_logprobs: int, + use_rocm_aiter: bool, + use_prompt_embeds: bool, + monkeypatch, +) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") @@ -125,34 +135,37 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds - else None) + prompt_embeds: Optional[list[torch.Tensor]] = [] if use_prompt_embeds else None prompt_token_ids = [] for prompt in example_prompts: - token_ids = hf_model.tokenizer(prompt, - return_tensors="pt").input_ids.to( - hf_model.model.device) + token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids.to( + hf_model.model.device + ) prompt_token_ids.append(token_ids) if prompt_embeds is not None: - prompt_embeds.append(hf_model.model.get_input_embeddings()( - token_ids).squeeze(0)) + prompt_embeds.append( + hf_model.model.get_input_embeddings()(token_ids).squeeze(0) + ) with vllm_runner( - model, - tokenizer_name=model_info.tokenizer or model, - tokenizer_mode=model_info.tokenizer_mode, - trust_remote_code=model_info.trust_remote_code, - max_num_seqs=2, - enable_prompt_embeds=use_prompt_embeds, + model, + tokenizer_name=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + max_num_seqs=2, + enable_prompt_embeds=use_prompt_embeds, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) if prompt_embeds is not None: vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( - prompt_embeds, max_tokens, num_logprobs) + prompt_embeds, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py index 60a4bc14be88..246b893be315 100644 --- a/tests/models/language/generation/test_gemma.py +++ b/tests/models/language/generation/test_gemma.py @@ -11,17 +11,17 @@ def test_dummy_loader(vllm_runner, 
monkeypatch, model: str) -> None: with monkeypatch.context() as m: m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner( - model, - load_format="dummy", + model, + load_format="dummy", ) as llm: if model == "google/gemma-3-4b-it": normalizers = llm.llm.collective_rpc( - lambda self: self.model_runner.model.language_model.model. - normalizer.cpu().item()) + lambda self: self.model_runner.model.language_model.model.normalizer.cpu().item() # noqa: E501 + ) config = llm.llm.llm_engine.model_config.hf_config.text_config else: normalizers = llm.llm.collective_rpc( - lambda self: self.model_runner.model.model.normalizer.cpu( - ).item()) + lambda self: self.model_runner.model.model.normalizer.cpu().item() + ) config = llm.llm.llm_engine.model_config.hf_config assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3) diff --git a/tests/models/language/generation/test_granite.py b/tests/models/language/generation/test_granite.py index 2a39f78a708e..e569e75ff3a8 100644 --- a/tests/models/language/generation/test_granite.py +++ b/tests/models/language/generation/test_granite.py @@ -26,11 +26,13 @@ def test_models( ) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 9d67b46f2e3e..abedd15b0d7e 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable + import pytest from tests.models.registry import HF_EXAMPLE_MODELS @@ -8,7 +10,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams -from ...utils import check_logprobs_close +from ...utils import check_logprobs_close, check_outputs_equal # Mark all tests as hybrid pytestmark = pytest.mark.hybrid_model @@ -22,7 +24,7 @@ "tiiuae/falcon-mamba-tiny-dev", # mamba2-codestral in transformers is broken pending: # https://github.com/huggingface/transformers/pull/40861 - #"yujiepan/mamba2-codestral-v0.1-tiny-random", + # "yujiepan/mamba2-codestral-v0.1-tiny-random", ] HYBRID_MODELS = [ @@ -63,7 +65,6 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -73,11 +74,13 @@ def test_models( with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -107,13 +110,14 @@ def test_batching( for_loop_outputs = [] with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: for prompt in example_prompts: - single_output, = 
vllm_model.generate_greedy_logprobs([prompt], - max_tokens, - num_logprobs) + (single_output,) = vllm_model.generate_greedy_logprobs( + [prompt], max_tokens, num_logprobs + ) for_loop_outputs.append(single_output) batched_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=for_loop_outputs, @@ -132,8 +136,8 @@ def test_chunked_prefill_with_parallel_sampling( max_tokens: int, ) -> None: """ - Tests chunked prefill in conjunction with n > 1. - + Tests chunked prefill in conjunction with n > 1. + In this case, prefill is populated with decoding tokens and we test that it doesn't fail. @@ -141,16 +145,13 @@ def test_chunked_prefill_with_parallel_sampling( decoding steps inside a chunked prefill forward pass (where we have both prefill and decode together) """ - sampling_params = SamplingParams(n=3, - temperature=1, - seed=0, - max_tokens=max_tokens) + sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens) with vllm_runner( - model, - enable_chunked_prefill=True, - # forces prefill chunks with decoding - max_num_batched_tokens=MAX_NUM_SEQS * 3, - max_num_seqs=MAX_NUM_SEQS, + model, + enable_chunked_prefill=True, + # forces prefill chunks with decoding + max_num_batched_tokens=MAX_NUM_SEQS * 3, + max_num_seqs=MAX_NUM_SEQS, ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) @@ -168,10 +169,8 @@ def test_mamba_cache_cg_padding( batch size. If it's not, a torch RuntimeError will be raised because tensor dimensions aren't compatible. """ - vllm_config = EngineArgs(model=model, - trust_remote_code=True).create_engine_config() - while len(example_prompts) == vllm_config.pad_for_cudagraph( - len(example_prompts)): + vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config() + while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)): example_prompts.append(example_prompts[0]) try: @@ -181,7 +180,8 @@ def test_mamba_cache_cg_padding( pytest.fail( "Couldn't run batch size which is not equal to a Cuda Graph " "captured batch size. " - "Could be related to mamba cache not padded correctly") + "Could be related to mamba cache not padded correctly" + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -203,8 +203,10 @@ def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks( with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_model.generate_greedy([example_prompts[0]] * 100, 10) except ValueError: - pytest.fail("Hybrid inner state wasn't cleaned up properly between" - "steps finished requests registered unnecessarily ") + pytest.fail( + "Hybrid inner state wasn't cleaned up properly between" + "steps finished requests registered unnecessarily " + ) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @@ -213,10 +215,10 @@ def test_state_cleanup( example_prompts, model: str, ) -> None: - """ + """ This test is for verifying that the Hybrid state is cleaned up between steps. - + If it's not cleaned, an error would be expected. 
""" try: @@ -224,8 +226,10 @@ def test_state_cleanup( for _ in range(10): vllm_model.generate_greedy([example_prompts[0]] * 100, 1) except ValueError: - pytest.fail("Hybrid inner state wasn't cleaned up between states, " - "could be related to finished_requests_ids") + pytest.fail( + "Hybrid inner state wasn't cleaned up between states, " + "could be related to finished_requests_ids" + ) @multi_gpu_test(num_gpus=2) @@ -239,15 +243,19 @@ def test_distributed_correctness( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model, tensor_parallel_size=1, - max_num_seqs=MAX_NUM_SEQS) as vllm_model: + with vllm_runner( + model, tensor_parallel_size=1, max_num_seqs=MAX_NUM_SEQS + ) as vllm_model: vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model, tensor_parallel_size=2, - max_num_seqs=MAX_NUM_SEQS) as vllm_model: + with vllm_runner( + model, tensor_parallel_size=2, max_num_seqs=MAX_NUM_SEQS + ) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=vllm_outputs_tp_1, @@ -269,7 +277,6 @@ def test_full_cuda_graph( max_tokens: int, num_logprobs: int, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -279,11 +286,13 @@ def test_full_cuda_graph( with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -296,8 +305,9 @@ def test_full_cuda_graph( @pytest.mark.parametrize("model", FP32_STATE_MODELS) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("cache_dtype_param", - ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]) +@pytest.mark.parametrize( + "cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"] +) def test_fp32_cache_state( hf_runner, vllm_runner, @@ -308,7 +318,6 @@ def test_fp32_cache_state( num_logprobs: int, cache_dtype_param: str, ) -> None: - try: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") @@ -318,13 +327,15 @@ def test_fp32_cache_state( with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model, - max_num_seqs=MAX_NUM_SEQS, - **{cache_dtype_param: "float32"}) as vllm_model: + with vllm_runner( + model, max_num_seqs=MAX_NUM_SEQS, **{cache_dtype_param: "float32"} + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -332,3 +343,417 @@ def test_fp32_cache_state( name_0="hf", name_1="vllm", ) + + +# Helper functions for the APC tests +def _get_vllm_runner_params(model, max_model_len, tensor_parallel_size=1): + return { + "model_name": model, + "enable_prefix_caching": False, + "max_model_len": max_model_len, + 
"tensor_parallel_size": tensor_parallel_size, + "gpu_memory_utilization": 0.4, + } + + +def _get_vLLM_output( + vllm_runner, + kwargs, + prompts, + max_tokens, + num_logprobs, + num_repetitions=1, + vllm_model=None, +): + outs = [] + if vllm_model is None: + vllm_model = vllm_runner(**kwargs) + for _ in range(num_repetitions): + if num_logprobs < 0: + vllm_output = vllm_model.generate_greedy(prompts, max_tokens) + else: + vllm_output = vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs + ) + outs.append(vllm_output) + + return outs, vllm_model + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_single_prompt( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. + generated_prompts = [MULTIPLE * example_prompts[0]] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + vllm_runner_kwargs["enable_prefix_caching"] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + ) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_single_prompt_block_align_alignment( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. 
This custom prompt is used, as it causes the most issues + generated_prompts = ["The president of the United States is " * MULTIPLE] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + vllm_runner_kwargs["enable_prefix_caching"] = True + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size + + # In case the hybrid model does not have the + # "mamba_block_size" assume a fixed constant + if mamba_block_size is None: + mamba_block_size = 512 + + mamba_block_size_multiplier = 10 + for offsets in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]: + vllm_runner_kwargs["max_num_batched_tokens"] = ( + mamba_block_size_multiplier * mamba_block_size - offsets + ) + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + ) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_multiple_prompts_all_cached_outputs( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. 
+ generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + vllm_runner_kwargs["enable_prefix_caching"] = True + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + ) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_multiple_prompts_block_align_alignment( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. 
This custom prompt is used, as it causes the most issues + prompt_text = "The president of the United States is " + prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31] + generated_prompts = [prompt_text[offset:] * MULTIPLE for offset in prompt_offsets] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + vllm_runner_kwargs["enable_prefix_caching"] = True + with vllm_runner(**vllm_runner_kwargs) as vllm_model: + # Retrieve the default mamba state block size + mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size + + # In case the hybrid model does not have the + # "mamba_block_size" assume a fixed constant + if mamba_block_size is None: + mamba_block_size = 512 + + mamba_block_size_multiplier = 10 + for offsets in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]: + vllm_runner_kwargs["max_num_batched_tokens"] = ( + mamba_block_size_multiplier * mamba_block_size - offsets + ) + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + ) + + # Check alignment of the output logits when using APC + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) + + +@pytest.mark.parametrize("model", [HYBRID_MODELS[3]]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("n_repetitions", [2]) +# If num_logprobs is set to -1, then the stringent version +# of the test is executed using `check_outputs_equal` +# instead of `check_logprobs_close` +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_apc_multiple_prompts_partial_cached_outputs( + hf_runner, + vllm_runner, + example_prompts, + monkeypatch, + model: str, + max_tokens: int, + n_repetitions: int, + num_logprobs: int, + tensor_parallel_size: int, +) -> None: + try: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + model_info.check_transformers_version(on_fail="skip") + except ValueError: + pass + + compare_operator: Callable = ( + check_logprobs_close if num_logprobs > 0 else check_outputs_equal # type: ignore + ) + + MULTIPLE = 300 + + # Sample prompts. 
+ generated_prompts = [MULTIPLE * prompt for prompt in example_prompts] + + max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts) + vllm_runner_kwargs = _get_vllm_runner_params( + model, max_model_len, tensor_parallel_size=tensor_parallel_size + ) + vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32" + + vllm_outputs_no_cache, _ = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs + ) + + # Cache only part of all the prompts + vllm_runner_kwargs["enable_prefix_caching"] = True + vllm_outputs_partial_cache, vllm_model = _get_vLLM_output( + vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs + ) + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0][:3], + outputs_1_lst=vllm_outputs_partial_cache[0], + name_0="vllm_no_cache", + name_1="vllm_partial_cache", + ) + + vllm_outputs_cache_rep, _ = _get_vLLM_output( + vllm_runner, + vllm_runner_kwargs, + generated_prompts, + max_tokens, + num_logprobs, + n_repetitions, + vllm_model=vllm_model, + ) + + for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep): + # In the first repetition, the caches are filled + # In the second repetition, these caches are reused + + compare_operator( + outputs_0_lst=vllm_outputs_no_cache[0], + outputs_1_lst=vllm_outputs_cache_itn, + name_0="vllm_no_cache", + name_1=f"vllm_cache_it_{r_idx + 1}", + ) diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py index 845afbfa8a45..0ae83ec16020 100644 --- a/tests/models/language/generation/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -6,7 +6,9 @@ import pytest from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( - MistralToolCall, MistralToolParser) + MistralToolCall, + MistralToolParser, +) from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import MistralTokenizer @@ -33,136 +35,118 @@ ] # for function calling -TOOLS = [{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": - "string", - "description": - "The city to find the weather for, e.g. 'San Francisco'" - }, - "state": { - "type": - "string", - "description": - "the two-letter abbreviation for the state that the city is" - " in, e.g. 'CA' which would mean 'California'" +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. " + "'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that " + "the city is in, e.g. 
'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, }, - "unit": { - "type": "string", - "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } + "required": ["city", "state", "unit"], }, - "required": ["city", "state", "unit"] - } + }, }, -}, { - "type": "function", - "function": { - "name": "rewrite", - "description": "Rewrites text", - "parameters": { - "type": "object", - "required": [], - "properties": { - "text": { - "type": "string", - "description": "The input text to rewrite." - } - } - } - } -}] -MSGS = [ { - "role": "system", - "content": "You are an assistant." + "type": "function", + "function": { + "name": "rewrite", + "description": "Rewrites text", + "parameters": { + "type": "object", + "required": [], + "properties": { + "text": { + "type": "string", + "description": "The input text to rewrite.", + } + }, + }, + }, }, +] +MSGS = [ + {"role": "system", "content": "You are an assistant."}, { - "role": - "user", - "content": - "Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa + "role": "user", + "content": "Could you please rewrite the below article? \n\n My English needs " + "improvving, maybe I make errors.", }, { - "role": - "assistant", - "content": - "", - "tool_calls": [{ - "id": "bbc5b7ede", - "type": "function", - "function": { - "name": - "rewrite", - "arguments": - '{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "bbc5b7ede", + "type": "function", + "function": { + "name": "rewrite", + "arguments": '{"text":"My English needs improvving, maybe ' + 'I make errors."}', + }, } - }] + ], }, { "role": "tool", - "content": - "{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa + "content": '{"action":"rewrite","outcome":"My English needs improving, maybe ' + 'I make errors."}', "tool_call_id": "bbc5b7ede", - "name": "rewrite" + "name": "rewrite", }, { "role": "assistant", - "content": "---\n\nMy English needs improving, maybe I make errors" + "content": "---\n\nMy English needs improving, maybe I make errors", }, { - "role": - "user", - "content": ("Can you tell me what the temperate" - " will be in Dallas, in fahrenheit?") - } + "role": "user", + "content": ( + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" 
+ ), + }, ] SAMPLE_JSON_SCHEMA = { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, "skills": { "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 + "items": {"type": "string", "maxLength": 10}, + "minItems": 3, }, "work_history": { "type": "array", "items": { "type": "object", "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "number" - }, - "position": { - "type": "string" - } + "company": {"type": "string"}, + "duration": {"type": "number"}, + "position": {"type": "string"}, }, - "required": ["company", "position"] - } - } + "required": ["company", "position"], + }, + }, }, - "required": ["name", "age", "skills", "work_history"] + "required": ["name", "age", "skills", "work_history"], } @@ -170,17 +154,25 @@ @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: # TODO(sang): Sliding window should be tested separately. with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model, dtype=dtype, - tokenizer_mode="mistral") as vllm_model: + with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral") as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -194,27 +186,35 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, - max_tokens: int, num_logprobs: int) -> None: +def test_mistral_format( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: with vllm_runner( - model, - dtype=dtype, - tokenizer_mode="mistral", - load_format="mistral", - config_format="mistral", + model, + dtype=dtype, + tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", ) as mistral_format_model: mistral_format_outputs = mistral_format_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner( - model, - dtype=dtype, - tokenizer_mode="auto", - load_format="safetensors", - config_format="hf", + model, + dtype=dtype, + tokenizer_mode="auto", + load_format="safetensors", + config_format="hf", ) as hf_format_model: hf_format_outputs = hf_format_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_format_outputs, @@ -226,34 +226,35 @@ def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str, @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_mistral_symbolic_languages(vllm_runner, model: 
str, - dtype: str) -> None: - with vllm_runner(model, - dtype=dtype, - max_model_len=8192, - tokenizer_mode="mistral", - config_format="mistral", - load_format="mistral") as vllm_model: +def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str) -> None: + with vllm_runner( + model, + dtype=dtype, + max_model_len=8192, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral", + ) as vllm_model: for prompt in SYMBOLIC_LANG_PROMPTS: msg = {"role": "user", "content": prompt} - outputs = vllm_model.llm.chat([msg], - sampling_params=SAMPLING_PARAMS) + outputs = vllm_model.llm.chat([msg], sampling_params=SAMPLING_PARAMS) assert "�" not in outputs[0].outputs[0].text.strip() @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: - with vllm_runner(model, - dtype=dtype, - tokenizer_mode="mistral", - config_format="mistral", - load_format="mistral") as vllm_model: - + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral", + ) as vllm_model: msgs = copy.deepcopy(MSGS) - outputs = vllm_model.llm.chat(msgs, - tools=TOOLS, - sampling_params=SAMPLING_PARAMS) + outputs = vllm_model.llm.chat( + msgs, tools=TOOLS, sampling_params=SAMPLING_PARAMS + ) tokenizer = vllm_model.llm.get_tokenizer() tool_parser = MistralToolParser(tokenizer) @@ -265,10 +266,11 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: assert parsed_message.tools_called assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id) - assert parsed_message.tool_calls[ - 0].function.name == "get_current_weather" - assert parsed_message.tool_calls[ - 0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa + assert parsed_message.tool_calls[0].function.name == "get_current_weather" + assert ( + parsed_message.tool_calls[0].function.arguments + == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' + ) # noqa assert parsed_message.content is None @@ -297,17 +299,10 @@ def get_vocab(): "city": "Dallas", "state": "TX", "unit": "fahrenheit", - "sub_dict": { - "foo": "bar", - "inner": { - "x": 1, - "y": 2 - } - }, + "sub_dict": {"foo": "bar", "inner": {"x": 1, "y": 2}}, } - model_output = ( - f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}") + model_output = f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}" parsed = parser.extract_tool_calls(model_output, None) diff --git a/tests/models/language/generation/test_phimoe.py b/tests/models/language/generation/test_phimoe.py index 6c9cc2821c30..e640655784cc 100644 --- a/tests/models/language/generation/test_phimoe.py +++ b/tests/models/language/generation/test_phimoe.py @@ -15,62 +15,56 @@ def test_phimoe_routing_function(): from vllm.model_executor.models.phimoe import phimoe_routing_function + test_case = { 0: { - "hidden_states": - torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.float32, - requires_grad=False).view(4, 2), - "gating_output": - torch.tensor([0.1, 0.2, 0.3, 0.4], - dtype=torch.float32, - requires_grad=False), - "topk": - 2, - "renormalize": - False, + "hidden_states": torch.tensor( + [1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float32, requires_grad=False + ).view(4, 2), + "gating_output": torch.tensor( + [0.1, 0.2, 0.3, 0.4], dtype=torch.float32, requires_grad=False + ), + "topk": 2, + "renormalize": False, }, 1: { - "hidden_states": - torch.tensor([1, 2, 
3, 4, 5, 6, 7, 8], - dtype=torch.float32, - requires_grad=False).view(4, 2), - "gating_output": - torch.tensor([0.4, 0.2, 0.3, 0.4], - dtype=torch.float32, - requires_grad=False), - "topk": - 2, - "renormalize": - False, - } + "hidden_states": torch.tensor( + [1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float32, requires_grad=False + ).view(4, 2), + "gating_output": torch.tensor( + [0.4, 0.2, 0.3, 0.4], dtype=torch.float32, requires_grad=False + ), + "topk": 2, + "renormalize": False, + }, } ground_truth = { 0: { - "topk_weights": - torch.tensor([1., 1.], dtype=torch.float32, requires_grad=False), - "topk_ids": - torch.tensor([3, 2], dtype=torch.long, requires_grad=False), + "topk_weights": torch.tensor( + [1.0, 1.0], dtype=torch.float32, requires_grad=False + ), + "topk_ids": torch.tensor([3, 2], dtype=torch.long, requires_grad=False), }, 1: { - "topk_weights": - torch.tensor([0.5, 1.], dtype=torch.float32, requires_grad=False), - "topk_ids": - torch.tensor([0, 3], dtype=torch.long, requires_grad=False), - } + "topk_weights": torch.tensor( + [0.5, 1.0], dtype=torch.float32, requires_grad=False + ), + "topk_ids": torch.tensor([0, 3], dtype=torch.long, requires_grad=False), + }, } for test_id in test_case: topk_weights, topk_ids = phimoe_routing_function(**test_case[test_id]) - assert torch.allclose(topk_weights, - ground_truth[test_id]["topk_weights"]) + assert torch.allclose(topk_weights, ground_truth[test_id]["topk_weights"]) assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"]) -@pytest.mark.skipif(condition=current_platform.is_cpu(), - reason="This test takes a lot time to run on CPU, " - "and vllm CI's disk space is not enough for this model.") +@pytest.mark.skipif( + condition=current_platform.is_cpu(), + reason="This test takes a lot time to run on CPU, " + "and vllm CI's disk space is not enough for this model.", +) @large_gpu_test(min_gb=80) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -87,11 +81,13 @@ def test_models( ) -> None: with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, diff --git a/tests/models/language/generation_ppl_test/ppl_utils.py b/tests/models/language/generation_ppl_test/ppl_utils.py index 6225bbe3377b..dcef365e99e7 100644 --- a/tests/models/language/generation_ppl_test/ppl_utils.py +++ b/tests/models/language/generation_ppl_test/ppl_utils.py @@ -8,8 +8,7 @@ from datasets import load_dataset import tests.ci_envs as ci_envs -from tests.models.utils import (GenerateModelInfo, - TokensTextLogprobsPromptLogprobs) +from tests.models.utils import GenerateModelInfo, TokensTextLogprobsPromptLogprobs from vllm.logprobs import Logprob # See #24485 @@ -18,13 +17,14 @@ @torch.inference_mode -def wikitext_ppl_test(hf_runner, - vllm_runner, - model_info: GenerateModelInfo, - max_length=MAX_LENGTH, - vllm_extra_kwargs=None, - atol=PPL_TOL): - +def wikitext_ppl_test( + hf_runner, + vllm_runner, + model_info: GenerateModelInfo, + max_length=MAX_LENGTH, + vllm_extra_kwargs=None, + atol=PPL_TOL, +): # A model family has many models with the same architecture, # and we don't need to test each one. 
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: @@ -44,15 +44,16 @@ def wikitext_ppl_test(hf_runner, if ci_envs.VLLM_CI_HEAD_DTYPE is not None: if "hf_overrides" not in vllm_extra_kwargs: vllm_extra_kwargs["hf_overrides"] = {} - vllm_extra_kwargs["hf_overrides"][ - "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE - - with vllm_runner(model_info.name, - gpu_memory_utilization=0.7, - max_model_len=max_length, - max_num_seqs=1, - enforce_eager=True, - **vllm_extra_kwargs) as vllm_model: + vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + + with vllm_runner( + model_info.name, + gpu_memory_utilization=0.7, + max_model_len=max_length, + max_num_seqs=1, + enforce_eager=True, + **vllm_extra_kwargs, + ) as vllm_model: # Use max_num_seqs=1 to avoid OOM, # and avoid batch different requests together. @@ -60,7 +61,7 @@ def wikitext_ppl_test(hf_runner, # Confirm whether vllm is using the correct architecture if model_info.architecture: - assert (model_info.architecture in model_config.architectures) + assert model_info.architecture in model_config.architectures max_length = min(model_config.max_model_len - 1, max_length) stride = max_length @@ -74,12 +75,14 @@ def wikitext_ppl_test(hf_runner, end_loc = min(begin_loc + max_length, n_tokens) chunks.append(tokens[begin_loc:end_loc]) - outputs = vllm_model.generate_greedy_logprobs(prompts=chunks, - max_tokens=1, - num_logprobs=None, - num_prompt_logprobs=0, - use_tqdm=False) - nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu") + outputs = vllm_model.generate_greedy_logprobs( + prompts=chunks, + max_tokens=1, + num_logprobs=None, + num_prompt_logprobs=0, + use_tqdm=False, + ) + nll_sum = torch.tensor(0.0, dtype=torch.float32, device="cpu") n_tokens = 0 for output in outputs: output = cast(TokensTextLogprobsPromptLogprobs, output) @@ -94,7 +97,8 @@ def wikitext_ppl_test(hf_runner, token_log_probs.append(token_log_prob) neg_log_likelihood = -torch.tensor( - token_log_probs, dtype=torch.float32, device="cpu").sum() + token_log_probs, dtype=torch.float32, device="cpu" + ).sum() nll_sum += neg_log_likelihood n_tokens += len(token_log_probs) vllm_ppl = float(torch.exp(nll_sum / n_tokens)) @@ -104,14 +108,13 @@ def wikitext_ppl_test(hf_runner, # Accelerate ppl test by setting Transformers ppl score to a constant if model_info.hf_ppl is None: with hf_runner( - model_info.name, - dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype, + model_info.name, + dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype, ) as hf_model: - nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu") + nll_sum = torch.tensor(0.0, dtype=torch.float32, device="cpu") n_tokens = 0 for chunk in chunks: - inputs = hf_model.wrap_device( - {"input_ids": torch.tensor([chunk])}) + inputs = hf_model.wrap_device({"input_ids": torch.tensor([chunk])}) input_ids = inputs["input_ids"] outputs = hf_model.model(input_ids, labels=input_ids) neg_log_likelihood = outputs.loss diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index 86751e0a4d5f..261ab80ae86b 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -6,8 +6,7 @@ import pytest from tests.conftest import HfRunner -from tests.models.utils import (EmbedModelInfo, check_embeddings_close, - matryoshka_fy) +from tests.models.utils import EmbedModelInfo, check_embeddings_close, matryoshka_fy def run_embedding_correctness_test( @@ -29,12 +28,14 @@ def run_embedding_correctness_test( ) -def 
correctness_test_embed_models(hf_runner, - vllm_runner, - model_info: EmbedModelInfo, - example_prompts, - vllm_extra_kwargs=None, - hf_model_callback=None): +def correctness_test_embed_models( + hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + example_prompts, + vllm_extra_kwargs=None, + hf_model_callback=None, +): pytest.skip("Debug only, ci prefers to use mteb test.") # The example_prompts has ending "\n", for example: @@ -51,18 +52,16 @@ def correctness_test_embed_models(hf_runner, if model_info.hf_overrides is not None: vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - **vllm_extra_kwargs) as vllm_model: + with vllm_runner( + model_info.name, runner="pooling", max_model_len=None, **vllm_extra_kwargs + ) as vllm_model: vllm_outputs = vllm_model.embed(example_prompts) with hf_runner( - model_info.name, - dtype=model_info.hf_dtype, - is_sentence_transformer=True, + model_info.name, + dtype=model_info.hf_dtype, + is_sentence_transformer=True, ) as hf_model: - if hf_model_callback is not None: hf_model_callback(hf_model) diff --git a/tests/models/language/pooling/test_auto_prefix_cache_support.py b/tests/models/language/pooling/test_auto_prefix_cache_support.py index 15e24c59d1dd..e95119df95c7 100644 --- a/tests/models/language/pooling/test_auto_prefix_cache_support.py +++ b/tests/models/language/pooling/test_auto_prefix_cache_support.py @@ -4,8 +4,7 @@ import torch from transformers import AutoModelForSequenceClassification -from tests.models.language.pooling.embed_utils import ( - run_embedding_correctness_test) +from tests.models.language.pooling.embed_utils import run_embedding_correctness_test @pytest.mark.parametrize( @@ -20,28 +19,27 @@ def test_classify_models( model: str, dtype: str, ) -> None: - example_prompts = example_prompts * 2 - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - enable_prefix_caching=True) as vllm_model: + with vllm_runner( + model, max_model_len=512, dtype=dtype, enable_prefix_caching=True + ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching vllm_outputs = vllm_model.classify(example_prompts) - with hf_runner(model, - dtype=dtype, - auto_cls=AutoModelForSequenceClassification) as hf_model: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForSequenceClassification + ) as hf_model: hf_outputs = hf_model.classify(example_prompts) for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): hf_output = torch.tensor(hf_output) vllm_output = torch.tensor(vllm_output) - assert torch.allclose(hf_output, vllm_output, - 1e-3 if dtype == "float" else 1e-2) + assert torch.allclose( + hf_output, vllm_output, 1e-3 if dtype == "float" else 1e-2 + ) @pytest.mark.parametrize( @@ -59,18 +57,18 @@ def test_embed_models( example_prompts = [str(s).strip() for s in example_prompts] * 2 with vllm_runner( - model, - runner="pooling", - max_model_len=None, - enable_prefix_caching=True, + model, + runner="pooling", + max_model_len=None, + enable_prefix_caching=True, ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert cache_config.enable_prefix_caching vllm_outputs = vllm_model.embed(example_prompts) with hf_runner( - model, - is_sentence_transformer=True, + model, + is_sentence_transformer=True, ) as hf_model: run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) @@ -81,13 +79,14 @@ def test_embed_models( "intfloat/e5-small", 
"Alibaba-NLP/gte-Qwen2-1.5B-instruct", # is_causal == False "papluca/xlm-roberta-base-language-detection", - ]) + ], +) @pytest.mark.parametrize("dtype", ["half"]) -def test_non_causal_models(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str) -> None: - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - enable_prefix_caching=True) as vllm_model: +def test_non_causal_models( + hf_runner, vllm_runner, example_prompts, model: str, dtype: str +) -> None: + with vllm_runner( + model, max_model_len=512, dtype=dtype, enable_prefix_caching=True + ) as vllm_model: cache_config = vllm_model.llm.llm_engine.cache_config assert not cache_config.enable_prefix_caching diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 8e398830d39d..471826f214d0 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -10,15 +10,17 @@ @pytest.mark.parametrize( "model", [ - pytest.param("jason9693/Qwen2.5-1.5B-apeach", - marks=[ - pytest.mark.core_model, pytest.mark.cpu_model, - pytest.mark.slow_test - ]), + pytest.param( + "jason9693/Qwen2.5-1.5B-apeach", + marks=[ + pytest.mark.core_model, + pytest.mark.cpu_model, + pytest.mark.slow_test, + ], + ), ], ) -@pytest.mark.parametrize("dtype", - ["half"] if current_platform.is_rocm() else ["float"]) +@pytest.mark.parametrize("dtype", ["half"] if current_platform.is_rocm() else ["float"]) def test_models( hf_runner, vllm_runner, @@ -35,9 +37,9 @@ def test_models( with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) - with hf_runner(model, - dtype=dtype, - auto_cls=AutoModelForSequenceClassification) as hf_model: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForSequenceClassification + ) as hf_model: hf_outputs = hf_model.classify(example_prompts) # check logits difference @@ -48,5 +50,6 @@ def test_models( # the tolerance value of 1e-2 is selected based on the # half datatype tests in # tests/models/language/pooling/test_embedding.py - assert torch.allclose(hf_output, vllm_output, - 1e-3 if dtype == "float" else 1e-2) + assert torch.allclose( + hf_output, vllm_output, 1e-3 if dtype == "float" else 1e-2 + ) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 17513d1bb20d..c9574dca498e 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -18,20 +18,25 @@ # case won't pass because gte-Qwen2-1.5B-instruct will cache custom # model code with bidirectional attention. 
# [Decoder-only] - pytest.param("BAAI/bge-multilingual-gemma2", - marks=[pytest.mark.core_model, pytest.mark.slow_test]), + pytest.param( + "BAAI/bge-multilingual-gemma2", + marks=[pytest.mark.core_model, pytest.mark.slow_test], + ), pytest.param( "intfloat/e5-mistral-7b-instruct", # CPU v1 doesn't support sliding window - marks=[pytest.mark.core_model]), - pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", - marks=[pytest.mark.cpu_model]), + marks=[pytest.mark.core_model], + ), + pytest.param( + "ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.cpu_model] + ), # [Encoder-only] pytest.param( "BAAI/bge-base-en-v1.5", marks=[ - pytest.mark.core_model, pytest.mark.cpu_model, - pytest.mark.slow_test + pytest.mark.core_model, + pytest.mark.cpu_model, + pytest.mark.slow_test, ], ), pytest.param("sentence-transformers/all-MiniLM-L12-v2"), @@ -50,7 +55,6 @@ def test_models( model, monkeypatch, ) -> None: - if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm(): # ROCm Triton FA does not currently support sliding window attention # switch to use ROCm CK FA backend @@ -58,13 +62,14 @@ def test_models( vllm_extra_kwargs = {} if model == "ssmits/Qwen2-7B-Instruct-embed-base": - vllm_extra_kwargs["pooler_config"] = \ - PoolerConfig(pooling_type="MEAN", normalize=False) + vllm_extra_kwargs["pooler_config"] = PoolerConfig( + pooling_type="MEAN", normalize=False + ) max_model_len: Optional[int] = 512 if model in [ - "sentence-transformers/all-MiniLM-L12-v2", - "sentence-transformers/stsb-roberta-base-v2" + "sentence-transformers/all-MiniLM-L12-v2", + "sentence-transformers/stsb-roberta-base-v2", ]: max_model_len = None @@ -79,10 +84,9 @@ def test_models( with hf_runner(model, is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model, - runner="pooling", - max_model_len=max_model_len, - **vllm_extra_kwargs) as vllm_model: + with vllm_runner( + model, runner="pooling", max_model_len=max_model_len, **vllm_extra_kwargs + ) as vllm_model: vllm_outputs = vllm_model.embed(example_prompts) check_embeddings_close( diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index 17a55d916b1f..14308ac06c03 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -70,8 +70,9 @@ async def run_client_embeddings( def gritlm_instruction(instruction): - return ("<|user|>\n" + instruction + - "\n<|embed|>\n" if instruction else "<|embed|>\n") + return ( + "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n" + ) def get_test_data(): @@ -80,7 +81,8 @@ def get_test_data(): README.md in https://github.com/ContextualAI/gritlm """ q_instruction = gritlm_instruction( - "Given a scientific paper title, retrieve the paper's abstract", ) + "Given a scientific paper title, retrieve the paper's abstract", + ) queries = [ "Bitcoin: A Peer-to-Peer Electronic Cash System", "Generative Representational Instruction Tuning", @@ -114,9 +116,9 @@ def test_gritlm_offline_embedding(vllm_runner): queries, q_instruction, documents, d_instruction = get_test_data() with vllm_runner( - MODEL_NAME, - runner="pooling", - max_model_len=MAX_MODEL_LEN, + MODEL_NAME, + runner="pooling", + max_model_len=MAX_MODEL_LEN, ) as vllm_model: llm = vllm_model.llm @@ -161,9 +163,9 @@ def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner): input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" with vllm_runner( - 
MODEL_NAME, - runner="generate", - max_model_len=MAX_MODEL_LEN, + MODEL_NAME, + runner="generate", + max_model_len=MAX_MODEL_LEN, ) as vllm_model: llm = vllm_model.llm diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py index 9814cad48a80..91be6cd09d33 100644 --- a/tests/models/language/pooling/test_mm_classifier_conversion.py +++ b/tests/models/language/pooling/test_mm_classifier_conversion.py @@ -21,16 +21,18 @@ def test_idefics_multimodal( "The future of AI is", ] - with vllm_runner(model_name="HuggingFaceM4/Idefics3-8B-Llama3", - runner="pooling", - task="classify", - convert="classify", - load_format="dummy", - max_model_len=512, - enforce_eager=True, - tensor_parallel_size=1, - disable_log_stats=True, - dtype="bfloat16") as vllm_model: + with vllm_runner( + model_name="HuggingFaceM4/Idefics3-8B-Llama3", + runner="pooling", + task="classify", + convert="classify", + load_format="dummy", + max_model_len=512, + enforce_eager=True, + tensor_parallel_size=1, + disable_log_stats=True, + dtype="bfloat16", + ) as vllm_model: llm = vllm_model.get_llm() outputs = llm.classify(prompts) for output in outputs: @@ -38,19 +40,20 @@ def test_idefics_multimodal( def update_config(config): - config.text_config.update({ - "architectures": ["Gemma3ForSequenceClassification"], - "classifier_from_token": ["A", "B", "C", "D", "E"], - "method": - "no_post_processing", - "id2label": { - "A": "Chair", - "B": "Couch", - "C": "Table", - "D": "Bed", - "E": "Cupboard" - }, - }) + config.text_config.update( + { + "architectures": ["Gemma3ForSequenceClassification"], + "classifier_from_token": ["A", "B", "C", "D", "E"], + "method": "no_post_processing", + "id2label": { + "A": "Chair", + "B": "Couch", + "C": "Table", + "D": "Bed", + "E": "Cupboard", + }, + } + ) return config @@ -63,11 +66,10 @@ def test_gemma_multimodal( # switch to use ROCm CK FA backend monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - messages = [{ - "role": - "system", - "content": - """ + messages = [ + { + "role": "system", + "content": """ You are a helpful assistant. You will be given a product description which may also include an image. Classify the following product into one of the categories: @@ -78,38 +80,39 @@ def test_gemma_multimodal( D = bed E = cupboard - You'll answer with exactly one letter (A, B, C, D, or E).""" - }, { - "role": - "user", - "content": [{ - "type": "image_url", - "image_url": { - "url": - "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg" - } - }, { - "type": "text", - "text": "A fine 19th century piece of furniture." 
- }] - }] - - with vllm_runner(model_name="google/gemma-3-4b-it", - runner="pooling", - task="classify", - convert="classify", - load_format="auto", - hf_overrides=update_config, - pooler_config=PoolerConfig(pooling_type="LAST"), - max_model_len=512, - enforce_eager=True, - tensor_parallel_size=1, - disable_log_stats=True, - dtype="bfloat16") as vllm_model: + You'll answer with exactly one letter (A, B, C, D, or E).""", + }, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg" + }, + }, + {"type": "text", "text": "A fine 19th century piece of furniture."}, + ], + }, + ] + with vllm_runner( + model_name="google/gemma-3-4b-it", + runner="pooling", + task="classify", + convert="classify", + load_format="auto", + hf_overrides=update_config, + pooler_config=PoolerConfig(pooling_type="LAST"), + max_model_len=512, + enforce_eager=True, + tensor_parallel_size=1, + disable_log_stats=True, + dtype="bfloat16", + ) as vllm_model: llm = vllm_model.get_llm() prompts = llm.preprocess_chat(messages) result = llm.classify(prompts) assert result[0].outputs.probs[0] > 0.95 - assert all(c < 0.05 for c in result[0].outputs.probs[1:]) \ No newline at end of file + assert all(c < 0.05 for c in result[0].outputs.probs[1:]) diff --git a/tests/models/language/pooling/test_multilabel_classification_support.py b/tests/models/language/pooling/test_multilabel_classification_support.py index 45366f209414..472fee71711a 100644 --- a/tests/models/language/pooling/test_multilabel_classification_support.py +++ b/tests/models/language/pooling/test_multilabel_classification_support.py @@ -20,14 +20,15 @@ def test_classify_models( with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) - with hf_runner(model, - dtype=dtype, - auto_cls=AutoModelForSequenceClassification) as hf_model: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForSequenceClassification + ) as hf_model: hf_outputs = hf_model.classify(example_prompts) for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): hf_output = torch.tensor(hf_output) vllm_output = torch.tensor(vllm_output) - assert torch.allclose(hf_output, vllm_output, - 1e-3 if dtype == "float" else 1e-2) + assert torch.allclose( + hf_output, vllm_output, 1e-3 if dtype == "float" else 1e-2 + ) diff --git a/tests/models/language/pooling/test_nomic_max_model_len.py b/tests/models/language/pooling/test_nomic_max_model_len.py index c34c36fd9815..88f088c60327 100644 --- a/tests/models/language/pooling/test_nomic_max_model_len.py +++ b/tests/models/language/pooling/test_nomic_max_model_len.py @@ -7,10 +7,10 @@ MODELS = [ EmbedModelInfo("nomic-ai/nomic-embed-text-v1"), - #EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5"), - #EmbedModelInfo("nomic-ai/CodeRankEmbed"), + # EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5"), + # EmbedModelInfo("nomic-ai/CodeRankEmbed"), EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe"), - #EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long"), + # EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long"), ] rope_theta = 1000 @@ -21,23 +21,24 @@ @pytest.mark.parametrize("model_info", MODELS) def test_default(model_info, vllm_runner): - with vllm_runner(model_info.name, runner="pooling", - max_model_len=None) as vllm_model: + with vllm_runner( + model_info.name, runner="pooling", max_model_len=None + ) as vllm_model: model_config = 
vllm_model.llm.llm_engine.model_config if model_info.name == "nomic-ai/nomic-embed-text-v2-moe": # For nomic-embed-text-v2-moe the length is set to 512 # by sentence_bert_config.json. assert model_config.max_model_len == 512 else: - assert ( - model_config.max_model_len == original_max_position_embeddings) + assert model_config.max_model_len == original_max_position_embeddings @pytest.mark.parametrize("model_info", MODELS) def test_set_max_model_len_legal(model_info, vllm_runner): # set max_model_len <= 512 - with vllm_runner(model_info.name, runner="pooling", - max_model_len=256) as vllm_model: + with vllm_runner( + model_info.name, runner="pooling", max_model_len=256 + ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config assert model_config.max_model_len == 256 @@ -46,13 +47,12 @@ def test_set_max_model_len_legal(model_info, vllm_runner): # For nomic-embed-text-v2-moe the length is set to 512 # by sentence_bert_config.json. with pytest.raises(ValueError): - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=1024): + with vllm_runner(model_info.name, runner="pooling", max_model_len=1024): pass else: - with vllm_runner(model_info.name, runner="pooling", - max_model_len=1024) as vllm_model: + with vllm_runner( + model_info.name, runner="pooling", max_model_len=1024 + ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config assert model_config.max_model_len == 1024 @@ -61,17 +61,18 @@ def test_set_max_model_len_legal(model_info, vllm_runner): def test_set_max_model_len_illegal(model_info, vllm_runner): # set max_model_len > 2048 with pytest.raises(ValueError): - with vllm_runner(model_info.name, runner="pooling", - max_model_len=4096): + with vllm_runner(model_info.name, runner="pooling", max_model_len=4096): pass # set max_model_len > 2048 by hf_overrides hf_overrides = {"max_model_len": 4096} with pytest.raises(ValueError): - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=None, + hf_overrides=hf_overrides, + ): pass @@ -82,16 +83,14 @@ def test_use_rope_scaling_legal(model_info, vllm_runner): "rope_scaling": { "rope_type": "yarn", "factor": factor, - "original_max_position_embeddings": - original_max_position_embeddings + "original_max_position_embeddings": original_max_position_embeddings, }, - "max_model_len": max_model_len + "max_model_len": max_model_len, } - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, runner="pooling", max_model_len=None, hf_overrides=hf_overrides + ): pass @@ -102,16 +101,17 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner): "rope_scaling": { "rope_type": "yarn", "factor": factor, - "original_max_position_embeddings": - original_max_position_embeddings - } + "original_max_position_embeddings": original_max_position_embeddings, + }, } # illegal max_model_len with pytest.raises(ValueError): - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=max_model_len + 1, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=max_model_len + 1, + hf_overrides=hf_overrides, + ): pass hf_overrides = { @@ -119,15 +119,16 @@ def test_use_rope_scaling_illegal(model_info, vllm_runner): "rope_scaling": { "rope_type": "yarn", "factor": factor, - "original_max_position_embeddings": - original_max_position_embeddings + 
"original_max_position_embeddings": original_max_position_embeddings, }, - "max_model_len": max_model_len + 1 + "max_model_len": max_model_len + 1, } # illegal max_model_len by hf_overrides with pytest.raises(ValueError): - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - hf_overrides=hf_overrides): + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=None, + hf_overrides=hf_overrides, + ): pass diff --git a/tests/models/language/pooling/test_pooler_config_init_behaviour.py b/tests/models/language/pooling/test_pooler_config_init_behaviour.py index 9b3fbd6a6cd0..674bf02b7b98 100644 --- a/tests/models/language/pooling/test_pooler_config_init_behaviour.py +++ b/tests/models/language/pooling/test_pooler_config_init_behaviour.py @@ -10,10 +10,7 @@ @pytest.mark.parametrize( "model", - [ - "jason9693/Qwen2.5-1.5B-apeach", - "papluca/xlm-roberta-base-language-detection" - ], + ["jason9693/Qwen2.5-1.5B-apeach", "papluca/xlm-roberta-base-language-detection"], ) @pytest.mark.parametrize("dtype", ["half"]) def test_classify_models_using_activation( @@ -23,30 +20,32 @@ def test_classify_models_using_activation( model: str, dtype: str, ) -> None: - with vllm_runner( - model, - max_model_len=512, - dtype=dtype, - pooler_config=PoolerConfig(activation=False)) as vllm_model: + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(activation=False), + ) as vllm_model: wo_activation_out = vllm_model.classify(example_prompts) with vllm_runner( - model, - max_model_len=512, - dtype=dtype, - pooler_config=PoolerConfig(activation=True)) as vllm_model: + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(activation=True), + ) as vllm_model: w_activation_out = vllm_model.classify(example_prompts) - for wo_activation, w_activation in zip(wo_activation_out, - w_activation_out): + for wo_activation, w_activation in zip(wo_activation_out, w_activation_out): wo_activation = torch.tensor(wo_activation) w_activation = torch.tensor(w_activation) - assert not torch.allclose(wo_activation, w_activation, - atol=1e-2), "pooler_config is not working" - assert torch.allclose(softmax(wo_activation), w_activation, - 1e-3 if dtype == "float" else 1e-2) + assert not torch.allclose(wo_activation, w_activation, atol=1e-2), ( + "pooler_config is not working" + ) + assert torch.allclose( + softmax(wo_activation), w_activation, 1e-3 if dtype == "float" else 1e-2 + ) @pytest.mark.parametrize( @@ -63,26 +62,28 @@ def test_embed_models_using_normalize( model: str, dtype: str, ) -> None: - with vllm_runner( - model, - max_model_len=512, - dtype=dtype, - pooler_config=PoolerConfig(normalize=False)) as vllm_model: + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(normalize=False), + ) as vllm_model: wo_normalize = torch.tensor(vllm_model.embed(example_prompts)) - with vllm_runner(model, - max_model_len=512, - dtype=dtype, - pooler_config=PoolerConfig(normalize=True)) as vllm_model: + with vllm_runner( + model, + max_model_len=512, + dtype=dtype, + pooler_config=PoolerConfig(normalize=True), + ) as vllm_model: w_normalize = torch.tensor(vllm_model.embed(example_prompts)) - assert not torch.allclose( - wo_normalize, w_normalize, - atol=1e-2), "pooler_config normalize is not working" + assert not torch.allclose(wo_normalize, w_normalize, atol=1e-2), ( + "pooler_config normalize is not working" + ) assert torch.allclose( - F.normalize(wo_normalize, p=2, dim=-1), w_normalize, - atol=1e-2), "w_normal should be close to 
normal(wo_normal)." + F.normalize(wo_normalize, p=2, dim=-1), w_normalize, atol=1e-2 + ), "w_normal should be close to normal(wo_normal)." @pytest.mark.parametrize( @@ -99,25 +100,26 @@ def test_reward_models_using_softmax( model: str, dtype: str, ) -> None: - - with vllm_runner(model, - max_model_len=1024, - dtype=dtype, - pooler_config=PoolerConfig(softmax=False)) as vllm_model: + with vllm_runner( + model, + max_model_len=1024, + dtype=dtype, + pooler_config=PoolerConfig(softmax=False), + ) as vllm_model: wo_softmax = vllm_model.encode(example_prompts) - with vllm_runner(model, - max_model_len=1024, - dtype=dtype, - pooler_config=PoolerConfig(softmax=True)) as vllm_model: + with vllm_runner( + model, max_model_len=1024, dtype=dtype, pooler_config=PoolerConfig(softmax=True) + ) as vllm_model: w_softmax = vllm_model.encode(example_prompts) for wo, w in zip(wo_softmax, w_softmax): wo = torch.tensor(wo) w = torch.tensor(w) - assert not torch.allclose( - wo, w, atol=1e-2), "pooler_config softmax is not working" - assert torch.allclose( - softmax(wo), w, - atol=1e-2), "w_softmax should be close to softmax(wo_softmax)." + assert not torch.allclose(wo, w, atol=1e-2), ( + "pooler_config softmax is not working" + ) + assert torch.allclose(softmax(wo), w, atol=1e-2), ( + "w_softmax should be close to softmax(wo_softmax)." + ) diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index 4ac91b5aed50..46504d025c26 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -16,10 +16,8 @@ def math_step_prompts(): # ruff: noqa: E501 data = { - "system": - "Please reason step by step, and put your final answer within \\boxed{}. ", - "query": - "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?", + "system": "Please reason step by step, and put your final answer within \\boxed{}. ", + "query": "Sue lives in a fun neighborhood. One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard. Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?", "response": [ "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.", "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. 
Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.", @@ -27,16 +25,16 @@ def math_step_prompts(): "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30}).", ], } - answer = "<extra_0>".join(data['response']) + "<extra_0>" + answer = "<extra_0>".join(data["response"]) + "<extra_0>" prompt = f"<im_start>system\n{data['system']}<im_end>\n<im_start>user\n{data['query']}<im_end>\n<im_start>assistant\n{answer}<im_end><|endoftext|>" return [prompt] def step_reward_patch_hf_model(hf_model: HfRunner): - # Patch the hf_runner to use the step reward function - def make_step_rewards(logits: torch.Tensor, - token_masks: torch.Tensor) -> list[list[float]]: + def make_step_rewards( + logits: torch.Tensor, token_masks: torch.Tensor + ) -> list[list[float]]: probabilities = F.softmax(logits, dim=-1) probabilities = probabilities * token_masks.unsqueeze(-1) @@ -54,7 +52,7 @@ def reward(prompts: list[str]) -> list[list[float]]: outputs = hf_model.model(input_ids=input_ids) step_sep_id = hf_model.tokenizer.encode("<extra_0>")[0] - token_masks = (input_ids == step_sep_id) + token_masks = input_ids == step_sep_id return make_step_rewards(outputs[0], token_masks) hf_model.reward = reward # type: ignore[attr-defined] @@ -65,8 +63,10 @@ def reward(prompts: list[str]) -> list[list[float]]: @pytest.mark.parametrize( "model", [ - pytest.param("Qwen/Qwen2.5-Math-PRM-7B", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param( + "Qwen/Qwen2.5-Math-PRM-7B", + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), ], ) @pytest.mark.parametrize("dtype", ["half"]) @@ -78,8 +78,9 @@ def test_prm_models( dtype: str, monkeypatch, ) -> None: - check_transformers_version("Qwen/Qwen2.5-Math-PRM-7B", - max_transformers_version="4.53.2") + check_transformers_version( + "Qwen/Qwen2.5-Math-PRM-7B", max_transformers_version="4.53.2" + ) if current_platform.is_cpu(): pytest.skip("CPU only supports V1") diff --git a/tests/models/language/pooling/test_scoring.py b/tests/models/language/pooling/test_scoring.py index ef9d5530cde1..416a43070f0e 100644 --- a/tests/models/language/pooling/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -37,10 +37,9 @@ def test_cross_encoder_1_to_1(vllm_runner, hf_runner, model_name): with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict([text_pair]).tolist() - with vllm_runner(model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) assert len(vllm_outputs) == 1 @@ -58,10 +57,9 @@ def test_cross_encoder_1_to_N(vllm_runner, hf_runner, model_name): with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) assert len(vllm_outputs) == 2 @@ -80,10 +78,9 @@ def test_cross_encoder_N_to_N(vllm_runner, hf_runner, model_name): with hf_runner(model_name, dtype=DTYPE, 
is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) assert len(vllm_outputs) == 2 @@ -101,17 +98,15 @@ def emb_model_name(request): def test_embedding_1_to_1(vllm_runner, hf_runner, emb_model_name): text_pair = [TEXTS_1[0], TEXTS_2[0]] - with hf_runner(emb_model_name, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: + with hf_runner( + emb_model_name, dtype=DTYPE, is_sentence_transformer=True + ) as hf_model: hf_embeddings = hf_model.encode(text_pair) - hf_outputs = [ - F.cosine_similarity(*map(torch.tensor, hf_embeddings), dim=0) - ] + hf_outputs = [F.cosine_similarity(*map(torch.tensor, hf_embeddings), dim=0)] - with vllm_runner(emb_model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + emb_model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) assert len(vllm_outputs) == 1 @@ -126,20 +121,18 @@ def test_embedding_1_to_N(vllm_runner, hf_runner, emb_model_name): [TEXTS_1[0], TEXTS_2[1]], ] - with hf_runner(emb_model_name, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: - hf_embeddings = [ - hf_model.encode(text_pair) for text_pair in text_pairs - ] + with hf_runner( + emb_model_name, dtype=DTYPE, is_sentence_transformer=True + ) as hf_model: + hf_embeddings = [hf_model.encode(text_pair) for text_pair in text_pairs] hf_outputs = [ F.cosine_similarity(*map(torch.tensor, pair), dim=0) for pair in hf_embeddings ] - with vllm_runner(emb_model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + emb_model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) assert len(vllm_outputs) == 2 @@ -155,20 +148,18 @@ def test_embedding_N_to_N(vllm_runner, hf_runner, emb_model_name): [TEXTS_1[1], TEXTS_2[1]], ] - with hf_runner(emb_model_name, dtype=DTYPE, - is_sentence_transformer=True) as hf_model: - hf_embeddings = [ - hf_model.encode(text_pair) for text_pair in text_pairs - ] + with hf_runner( + emb_model_name, dtype=DTYPE, is_sentence_transformer=True + ) as hf_model: + hf_embeddings = [hf_model.encode(text_pair) for text_pair in text_pairs] hf_outputs = [ F.cosine_similarity(*map(torch.tensor, pair), dim=0) for pair in hf_embeddings ] - with vllm_runner(emb_model_name, - runner="pooling", - dtype=DTYPE, - max_model_len=None) as vllm_model: + with vllm_runner( + emb_model_name, runner="pooling", dtype=DTYPE, max_model_len=None + ) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) assert len(vllm_outputs) == 2 diff --git a/tests/models/language/pooling/test_token_classification.py b/tests/models/language/pooling/test_token_classification.py index fd5e48a8b144..4849f1ec4d36 100644 --- a/tests/models/language/pooling/test_token_classification.py +++ b/tests/models/language/pooling/test_token_classification.py @@ -21,9 +21,9 @@ def test_models( with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) - with hf_runner(model, - dtype=dtype, - auto_cls=AutoModelForTokenClassification) as hf_model: + with hf_runner( + model, dtype=dtype, 
auto_cls=AutoModelForTokenClassification + ) as hf_model: tokenizer = hf_model.tokenizer hf_outputs = [] for prompt in example_prompts: diff --git a/tests/models/language/pooling/test_truncation_control.py b/tests/models/language/pooling/test_truncation_control.py index c6ef899958a0..f1870ddbee51 100644 --- a/tests/models/language/pooling/test_truncation_control.py +++ b/tests/models/language/pooling/test_truncation_control.py @@ -20,51 +20,57 @@ field.""" -def test_smaller_truncation_size(vllm_runner, - model_name=MODEL_NAME, - input_str=input_str): - +def test_smaller_truncation_size( + vllm_runner, model_name=MODEL_NAME, input_str=input_str +): truncate_prompt_tokens = 10 - with vllm_runner(model_name, runner="pooling", - max_model_len=max_model_len) as vllm_model: + with vllm_runner( + model_name, runner="pooling", max_model_len=max_model_len + ) as vllm_model: vllm_output = vllm_model.llm.embed( - input_str, truncate_prompt_tokens=truncate_prompt_tokens) + input_str, truncate_prompt_tokens=truncate_prompt_tokens + ) prompt_tokens = vllm_output[0].prompt_token_ids assert len(prompt_tokens) == truncate_prompt_tokens -def test_max_truncation_size(vllm_runner, - model_name=MODEL_NAME, - input_str=input_str): +def test_max_truncation_size(vllm_runner, model_name=MODEL_NAME, input_str=input_str): truncate_prompt_tokens = -1 - with vllm_runner(model_name, runner="pooling", - max_model_len=max_model_len) as vllm_model: + with vllm_runner( + model_name, runner="pooling", max_model_len=max_model_len + ) as vllm_model: vllm_output = vllm_model.llm.embed( - input_str, truncate_prompt_tokens=truncate_prompt_tokens) + input_str, truncate_prompt_tokens=truncate_prompt_tokens + ) prompt_tokens = vllm_output[0].prompt_token_ids assert len(prompt_tokens) == max_model_len -def test_bigger_truncation_size(vllm_runner, - model_name=MODEL_NAME, - input_str=input_str): - +def test_bigger_truncation_size( + vllm_runner, model_name=MODEL_NAME, input_str=input_str +): truncate_prompt_tokens = max_model_len + 1 - with pytest.raises(ValueError), vllm_runner( - model_name, runner="pooling", - max_model_len=max_model_len) as vllm_model: - + with ( + pytest.raises(ValueError), + vllm_runner( + model_name, runner="pooling", max_model_len=max_model_len + ) as vllm_model, + ): llm_output = vllm_model.llm.embed( - input_str, truncate_prompt_tokens=truncate_prompt_tokens) + input_str, truncate_prompt_tokens=truncate_prompt_tokens + ) - assert llm_output == f"""truncate_prompt_tokens value + assert ( + llm_output + == f"""truncate_prompt_tokens value ({truncate_prompt_tokens}) is greater than max_model_len ({max_model_len}). 
Please, select a smaller truncation size.""" + ) diff --git a/tests/models/language/pooling_mteb_test/mteb_utils.py b/tests/models/language/pooling_mteb_test/mteb_utils.py index 7b3c02fbbd9f..a4a7f1b48d3d 100644 --- a/tests/models/language/pooling_mteb_test/mteb_utils.py +++ b/tests/models/language/pooling_mteb_test/mteb_utils.py @@ -12,8 +12,7 @@ import torch import tests.ci_envs as ci_envs -from tests.models.utils import (EmbedModelInfo, RerankModelInfo, - check_embeddings_close) +from tests.models.utils import EmbedModelInfo, RerankModelInfo, check_embeddings_close # Most embedding models on the STS12 task (See #17175): # - Model implementation and minor changes in tensor dtype @@ -30,7 +29,6 @@ class VllmMtebEncoder(mteb.Encoder): - def __init__(self, vllm_model): super().__init__() self.llm = vllm_model @@ -53,8 +51,7 @@ def encode( def predict( self, - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: @@ -64,17 +61,15 @@ def predict( queries = [s[0] for s in sentences] corpus = [s[1] for s in sentences] - outputs = self.llm.score(queries, - corpus, - truncate_prompt_tokens=-1, - use_tqdm=False) + outputs = self.llm.score( + queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False + ) scores = np.array(outputs) scores = scores[np.argsort(r)] return scores class OpenAIClientMtebEncoder(mteb.Encoder): - def __init__(self, model_name: str, client): super().__init__() self.model_name = model_name @@ -87,8 +82,9 @@ def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray: r = self.rng.permutation(len(sentences)) sentences = [sentences[i] for i in r] - embeddings = self.client.embeddings.create(model=self.model_name, - input=sentences) + embeddings = self.client.embeddings.create( + model=self.model_name, input=sentences + ) outputs = [d.embedding for d in embeddings.data] embeds = np.array(outputs) embeds = embeds[np.argsort(r)] @@ -96,7 +92,6 @@ def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray: class ScoreClientMtebEncoder(mteb.Encoder): - def __init__(self, model_name: str, url): super().__init__() self.model_name = model_name @@ -105,8 +100,7 @@ def __init__(self, model_name: str, url): def predict( self, - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: @@ -122,27 +116,30 @@ def predict( return scores def get_score(self, query, corpus): - response = requests.post(self.url, - json={ - "model": self.model_name, - "text_1": query, - "text_2": corpus, - "truncate_prompt_tokens": -1, - }).json() - return response['data'][0]["score"] + response = requests.post( + self.url, + json={ + "model": self.model_name, + "text_1": query, + "text_2": corpus, + "truncate_prompt_tokens": -1, + }, + ).json() + return response["data"][0]["score"] class RerankClientMtebEncoder(ScoreClientMtebEncoder): - def get_score(self, query, corpus): - response = requests.post(self.url, - json={ - "model": self.model_name, - "query": query, - "documents": [corpus], - "truncate_prompt_tokens": -1, - }).json() - return response['results'][0]["relevance_score"] + response = requests.post( + self.url, + json={ + "model": self.model_name, + "query": query, + "documents": [corpus], + "truncate_prompt_tokens": -1, + }, + ).json() + return response["results"][0]["relevance_score"] def run_mteb_embed_task(encoder, 
tasks): @@ -161,12 +158,14 @@ def run_mteb_embed_task(encoder, tasks): return main_score -def mteb_test_embed_models(hf_runner, - vllm_runner, - model_info: EmbedModelInfo, - vllm_extra_kwargs=None, - hf_model_callback=None, - atol=MTEB_EMBED_TOL): +def mteb_test_embed_models( + hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + vllm_extra_kwargs=None, + hf_model_callback=None, + atol=MTEB_EMBED_TOL, +): # A model family has many models with the same architecture, # and we don't need to test each one. if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: @@ -187,15 +186,15 @@ def mteb_test_embed_models(hf_runner, if ci_envs.VLLM_CI_HEAD_DTYPE is not None: if "hf_overrides" not in vllm_extra_kwargs: vllm_extra_kwargs["hf_overrides"] = {} - vllm_extra_kwargs["hf_overrides"][ - "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE - - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - enforce_eager=True, - **vllm_extra_kwargs) as vllm_model: - + vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=None, + enforce_eager=True, + **vllm_extra_kwargs, + ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config # Confirm whether vllm is using the correct architecture @@ -204,28 +203,29 @@ def mteb_test_embed_models(hf_runner, # Confirm whether vllm uses the correct default_pooling_type, which # relates to whether chunked prefill and prefix caching are enabled - assert (model_config._model_info.default_pooling_type == - model_info.default_pooling_type) + assert ( + model_config._model_info.default_pooling_type + == model_info.default_pooling_type + ) - vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), - MTEB_EMBED_TASKS) + vllm_main_score = run_mteb_embed_task( + VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS + ) vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype head_dtype = model_config.head_dtype # Test embed_dims, isnan and whether to use normalize - vllm_outputs = vllm_model.embed(example_prompts, - truncate_prompt_tokens=-1) + vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1) assert not torch.any(torch.isnan(torch.tensor(vllm_outputs))) # Accelerate mteb test by setting # SentenceTransformers mteb score to a constant if model_info.mteb_score is None: with hf_runner( - model_info.name, - is_sentence_transformer=True, - dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype, + model_info.name, + is_sentence_transformer=True, + dtype=ci_envs.VLLM_CI_HF_DTYPE or model_info.hf_dtype, ) as hf_model: - # e.g. 
setting default parameters for the encode method of hf_runner if hf_model_callback is not None: hf_model_callback(hf_model) @@ -247,8 +247,7 @@ def mteb_test_embed_models(hf_runner, st_dtype = "Constant" print("Model:", model_info.name) - print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", - vllm_main_score) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) @@ -282,26 +281,21 @@ def run_mteb_rerank(cross_encoder, tasks, languages): top_k=10, save_predictions=True, output_folder=f"{results_folder}/stage2", - previous_results= - f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json", + previous_results=f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json", encode_kwargs={"show_progress_bar": False}, ) main_score = results[0].scores["test"][0]["main_score"] return main_score -def mteb_test_rerank_models_hf(hf_runner, - model_name, - hf_dtype="float32", - hf_model_callback=None): - with hf_runner(model_name, is_cross_encoder=True, - dtype=hf_dtype) as hf_model: - +def mteb_test_rerank_models_hf( + hf_runner, model_name, hf_dtype="float32", hf_model_callback=None +): + with hf_runner(model_name, is_cross_encoder=True, dtype=hf_dtype) as hf_model: original_predict = hf_model.predict def _predict( - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt *args, **kwargs, ): @@ -315,20 +309,22 @@ def _predict( if hf_model_callback is not None: hf_model_callback(hf_model) - st_main_score = run_mteb_rerank(hf_model, - tasks=MTEB_RERANK_TASKS, - languages=MTEB_RERANK_LANGS) + st_main_score = run_mteb_rerank( + hf_model, tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS + ) st_dtype = next(hf_model.model.model.parameters()).dtype return st_main_score, st_dtype -def mteb_test_rerank_models(hf_runner, - vllm_runner, - model_info: RerankModelInfo, - vllm_extra_kwargs=None, - hf_model_callback=None, - vllm_mteb_encoder=VllmMtebEncoder, - atol=MTEB_RERANK_TOL): +def mteb_test_rerank_models( + hf_runner, + vllm_runner, + model_info: RerankModelInfo, + vllm_extra_kwargs=None, + hf_model_callback=None, + vllm_mteb_encoder=VllmMtebEncoder, + atol=MTEB_RERANK_TOL, +): # A model family has many models with the same architecture, # and we don't need to test each one. 
if not ci_envs.VLLM_CI_NO_SKIP and not model_info.enable_test: @@ -346,33 +342,37 @@ def mteb_test_rerank_models(hf_runner, if ci_envs.VLLM_CI_HEAD_DTYPE is not None: if "hf_overrides" not in vllm_extra_kwargs: vllm_extra_kwargs["hf_overrides"] = {} - vllm_extra_kwargs["hf_overrides"][ - "head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE - - with vllm_runner(model_info.name, - runner="pooling", - max_model_len=None, - max_num_seqs=8, - enforce_eager=True, - **vllm_extra_kwargs) as vllm_model: - + vllm_extra_kwargs["hf_overrides"]["head_dtype"] = ci_envs.VLLM_CI_HEAD_DTYPE + + with vllm_runner( + model_info.name, + runner="pooling", + max_model_len=None, + max_num_seqs=8, + enforce_eager=True, + **vllm_extra_kwargs, + ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config # Confirm whether vllm is using the correct architecture if model_info.architecture: - assert (model_info.architecture in model_config.architectures) + assert model_info.architecture in model_config.architectures # Score API is only enabled for num_labels == 1 assert model_config.hf_config.num_labels == 1 # Confirm whether vllm uses the correct default_pooling_type, which # relates to whether chunked prefill and prefix caching are enabled - assert (model_config._model_info.default_pooling_type == - model_info.default_pooling_type) + assert ( + model_config._model_info.default_pooling_type + == model_info.default_pooling_type + ) - vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model), - tasks=MTEB_RERANK_TASKS, - languages=MTEB_RERANK_LANGS) + vllm_main_score = run_mteb_rerank( + vllm_mteb_encoder(vllm_model), + tasks=MTEB_RERANK_TASKS, + languages=MTEB_RERANK_LANGS, + ) vllm_dtype = model_config.dtype head_dtype = model_config.head_dtype @@ -380,14 +380,14 @@ def mteb_test_rerank_models(hf_runner, # SentenceTransformers mteb score to a constant if model_info.mteb_score is None: st_main_score, st_dtype = mteb_test_rerank_models_hf( - hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback) + hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback + ) else: st_main_score = model_info.mteb_score st_dtype = "Constant" print("Model:", model_info.name) - print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", - vllm_main_score) + print("VLLM:", f"dtype:{vllm_dtype}", f"head_dtype:{head_dtype}", vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) diff --git a/tests/models/language/pooling_mteb_test/test_baai.py b/tests/models/language/pooling_mteb_test/test_baai.py index e131c9b1038d..bad13e245714 100644 --- a/tests/models/language/pooling_mteb_test/test_baai.py +++ b/tests/models/language/pooling_mteb_test/test_baai.py @@ -2,67 +2,76 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from tests.models.language.pooling.embed_utils import ( - correctness_test_embed_models) -from tests.models.utils import (CLSPoolingEmbedModelInfo, - CLSPoolingRerankModelInfo, EmbedModelInfo, - LASTPoolingEmbedModelInfo, RerankModelInfo) +from tests.models.language.pooling.embed_utils import correctness_test_embed_models +from tests.models.utils import ( + CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, + EmbedModelInfo, + LASTPoolingEmbedModelInfo, + RerankModelInfo, +) from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ ########## BertModel - CLSPoolingEmbedModelInfo("BAAI/bge-base-en", - architecture="BertModel", - mteb_score=0.779336792, - 
enable_test=True), - CLSPoolingEmbedModelInfo("BAAI/bge-base-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-small-en", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-small-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-en", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-noinstruct", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-base-en-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-base-zh-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-small-en-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-small-zh-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-en-v1.5", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("BAAI/bge-large-zh-v1.5", - architecture="BertModel", - enable_test=False), + CLSPoolingEmbedModelInfo( + "BAAI/bge-base-en", + architecture="BertModel", + mteb_score=0.779336792, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-base-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-small-en", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-small-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-en", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-zh-noinstruct", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-base-en-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-base-zh-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-small-en-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-small-zh-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-en-v1.5", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "BAAI/bge-large-zh-v1.5", architecture="BertModel", enable_test=False + ), ########## XLMRobertaModel - CLSPoolingEmbedModelInfo("BAAI/bge-m3", - architecture="XLMRobertaModel", - mteb_score=0.787343078, - enable_test=True), + CLSPoolingEmbedModelInfo( + "BAAI/bge-m3", + architecture="XLMRobertaModel", + mteb_score=0.787343078, + enable_test=True, + ), ########## Qwen2Model - LASTPoolingEmbedModelInfo("BAAI/bge-code-v1", - architecture="Qwen2Model", - mteb_score=0.75724465, - dtype="float32", - enable_test=True), + LASTPoolingEmbedModelInfo( + "BAAI/bge-code-v1", + architecture="Qwen2Model", + mteb_score=0.75724465, + dtype="float32", + enable_test=True, + ), ] RERANK_MODELS = [ @@ -71,33 +80,35 @@ "BAAI/bge-reranker-base", architecture="XLMRobertaForSequenceClassification", mteb_score=0.32398, - enable_test=True), + enable_test=True, + ), CLSPoolingRerankModelInfo( "BAAI/bge-reranker-large", architecture="XLMRobertaForSequenceClassification", - enable_test=False), + enable_test=False, + ), 
CLSPoolingRerankModelInfo( "BAAI/bge-reranker-v2-m3", architecture="XLMRobertaForSequenceClassification", - enable_test=False) + enable_test=False, + ), ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: +def test_rerank_models_mteb( + hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py index 1eca2a2c0abd..9e95dd74c397 100644 --- a/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py +++ b/tests/models/language/pooling_mteb_test/test_bge_reranker_v2_gemma.py @@ -8,53 +8,50 @@ from tests.conftest import HfRunner from tests.models.language.pooling_mteb_test.mteb_utils import ( - VllmMtebEncoder, mteb_test_rerank_models) + VllmMtebEncoder, + mteb_test_rerank_models, +) from tests.models.utils import LASTPoolingRerankModelInfo, RerankModelInfo RERANK_MODELS = [ - LASTPoolingRerankModelInfo("BAAI/bge-reranker-v2-gemma", - architecture="GemmaForSequenceClassification", - mteb_score=0.33757, - hf_overrides={ - "architectures": - ["GemmaForSequenceClassification"], - "classifier_from_token": ["Yes"], - "method": - "no_post_processing", - }), + LASTPoolingRerankModelInfo( + "BAAI/bge-reranker-v2-gemma", + architecture="GemmaForSequenceClassification", + mteb_score=0.33757, + hf_overrides={ + "architectures": ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": "no_post_processing", + }, + ), ] PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." 
# noqa: E501 class GemmaRerankerHfRunner(HfRunner): - - def __init__(self, - model_name: str, - dtype: str = "auto", - *args: Any, - **kwargs: Any) -> None: + def __init__( + self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any + ) -> None: from transformers import AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, - padding_side='left') + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes") @torch.no_grad() - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: def get_inputs(pairs, tokenizer, prompt=None): if prompt is None: prompt = PROMPT sep = "\n" - prompt_inputs = tokenizer(prompt, - return_tensors=None, - add_special_tokens=False)["input_ids"] - sep_inputs = tokenizer(sep, - return_tensors=None, - add_special_tokens=False)["input_ids"] + prompt_inputs = tokenizer( + prompt, return_tensors=None, add_special_tokens=False + )["input_ids"] + sep_inputs = tokenizer(sep, return_tensors=None, add_special_tokens=False)[ + "input_ids" + ] inputs = [] for query, passage in pairs: query_inputs = tokenizer( @@ -78,8 +75,7 @@ def get_inputs(pairs, tokenizer, prompt=None): return_token_type_ids=False, add_special_tokens=False, ) - item["input_ids"] = item[ - "input_ids"] + sep_inputs + prompt_inputs + item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs item["attention_mask"] = [1] * len(item["input_ids"]) inputs.append(item) return tokenizer.pad( @@ -95,14 +91,19 @@ def get_inputs(pairs, tokenizer, prompt=None): inputs = inputs.to(self.model.device) _n_tokens = inputs["input_ids"].shape[1] logits = self.model(**inputs, return_dict=True).logits - _scores = (logits[:, -1, - self.yes_loc].view(-1, ).float().sigmoid()) + _scores = ( + logits[:, -1, self.yes_loc] + .view( + -1, + ) + .float() + .sigmoid() + ) scores.append(_scores[0].item()) return torch.Tensor(scores) class GemmaMtebEncoder(VllmMtebEncoder): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.query_template = "A: {query}\n" @@ -110,12 +111,10 @@ def __init__(self, *args, **kwargs): def predict( self, - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt + sentences: list[tuple[str, str, Optional[str]]], # query, corpus, prompt *args, **kwargs, ) -> np.ndarray: - _sentences = [] for query, corpus, prompt in sentences: query = self.query_template.format(query=query) @@ -127,8 +126,9 @@ def predict( @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - - mteb_test_rerank_models(GemmaRerankerHfRunner, - vllm_runner, - model_info, - vllm_mteb_encoder=GemmaMtebEncoder) + mteb_test_rerank_models( + GemmaRerankerHfRunner, + vllm_runner, + model_info, + vllm_mteb_encoder=GemmaMtebEncoder, + ) diff --git a/tests/models/language/pooling_mteb_test/test_cross_encoder.py b/tests/models/language/pooling_mteb_test/test_cross_encoder.py index ad320fae0c85..638ffc7a62b0 100644 --- a/tests/models/language/pooling_mteb_test/test_cross_encoder.py +++ b/tests/models/language/pooling_mteb_test/test_cross_encoder.py @@ -2,22 +2,30 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from tests.models.utils import (CLSPoolingRerankModelInfo, - 
LASTPoolingRerankModelInfo, RerankModelInfo) +from tests.models.utils import ( + CLSPoolingRerankModelInfo, + LASTPoolingRerankModelInfo, + RerankModelInfo, +) from .mteb_utils import mteb_test_rerank_models RERANK_MODELS = [ - CLSPoolingRerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2", - mteb_score=0.32898, - architecture="BertForSequenceClassification"), - LASTPoolingRerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", - mteb_score=0.25736, - architecture="Qwen3ForSequenceClassification") + CLSPoolingRerankModelInfo( + "cross-encoder/ms-marco-TinyBERT-L-2-v2", + mteb_score=0.32898, + architecture="BertForSequenceClassification", + ), + LASTPoolingRerankModelInfo( + "tomaarsen/Qwen3-Reranker-0.6B-seq-cls", + mteb_score=0.25736, + architecture="Qwen3ForSequenceClassification", + ), ] @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: +def test_rerank_models_mteb( + hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling_mteb_test/test_gte.py b/tests/models/language/pooling_mteb_test/test_gte.py index 9ae43fd05bf7..a22821fd65b5 100644 --- a/tests/models/language/pooling_mteb_test/test_gte.py +++ b/tests/models/language/pooling_mteb_test/test_gte.py @@ -3,74 +3,93 @@ import pytest -from tests.models.language.pooling.embed_utils import ( - correctness_test_embed_models) -from tests.models.utils import (CLSPoolingEmbedModelInfo, - CLSPoolingRerankModelInfo, EmbedModelInfo, - LASTPoolingEmbedModelInfo, RerankModelInfo) +from tests.models.language.pooling.embed_utils import correctness_test_embed_models +from tests.models.utils import ( + CLSPoolingEmbedModelInfo, + CLSPoolingRerankModelInfo, + EmbedModelInfo, + LASTPoolingEmbedModelInfo, + RerankModelInfo, +) from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models MODELS = [ ########## BertModel - CLSPoolingEmbedModelInfo("thenlper/gte-large", - mteb_score=0.76807651, - architecture="BertModel", - enable_test=True), - CLSPoolingEmbedModelInfo("thenlper/gte-base", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("thenlper/gte-small", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("thenlper/gte-large-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("thenlper/gte-base-zh", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("thenlper/gte-small-zh", - architecture="BertModel", - enable_test=False), + CLSPoolingEmbedModelInfo( + "thenlper/gte-large", + mteb_score=0.76807651, + architecture="BertModel", + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-base", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-small", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-large-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-base-zh", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "thenlper/gte-small-zh", architecture="BertModel", enable_test=False + ), ########### NewModel # These three architectures are almost the same, but not exactly the same. 
# For example, # - whether to use token_type_embeddings # - whether to use context expansion # So only test one (the most widely used) model - CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-multilingual-base", - architecture="GteNewModel", - mteb_score=0.775074696, - hf_overrides={"architectures": ["GteNewModel"]}, - enable_test=True), - CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", - architecture="GteNewModel", - hf_overrides={"architectures": ["GteNewModel"]}, - enable_test=False), - CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", - architecture="GteNewModel", - hf_overrides={"architectures": ["GteNewModel"]}, - enable_test=False), + CLSPoolingEmbedModelInfo( + "Alibaba-NLP/gte-multilingual-base", + architecture="GteNewModel", + mteb_score=0.775074696, + hf_overrides={"architectures": ["GteNewModel"]}, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Alibaba-NLP/gte-base-en-v1.5", + architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "Alibaba-NLP/gte-large-en-v1.5", + architecture="GteNewModel", + hf_overrides={"architectures": ["GteNewModel"]}, + enable_test=False, + ), ########### Qwen2ForCausalLM - LASTPoolingEmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", - mteb_score=0.758473459018872, - architecture="Qwen2ForCausalLM", - enable_test=True), + LASTPoolingEmbedModelInfo( + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", + mteb_score=0.758473459018872, + architecture="Qwen2ForCausalLM", + enable_test=True, + ), ########## ModernBertModel - CLSPoolingEmbedModelInfo("Alibaba-NLP/gte-modernbert-base", - mteb_score=0.748193353, - architecture="ModernBertModel", - enable_test=True), + CLSPoolingEmbedModelInfo( + "Alibaba-NLP/gte-modernbert-base", + mteb_score=0.748193353, + architecture="ModernBertModel", + enable_test=True, + ), ########## Qwen3ForCausalLM - LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", - mteb_score=0.771163695, - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=True), - LASTPoolingEmbedModelInfo("Qwen/Qwen3-Embedding-4B", - architecture="Qwen3ForCausalLM", - dtype="float32", - enable_test=False), + LASTPoolingEmbedModelInfo( + "Qwen/Qwen3-Embedding-0.6B", + mteb_score=0.771163695, + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=True, + ), + LASTPoolingEmbedModelInfo( + "Qwen/Qwen3-Embedding-4B", + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=False, + ), ] RERANK_MODELS = [ @@ -79,31 +98,32 @@ "Alibaba-NLP/gte-reranker-modernbert-base", mteb_score=0.33386, architecture="ModernBertForSequenceClassification", - enable_test=True), + enable_test=True, + ), CLSPoolingRerankModelInfo( "Alibaba-NLP/gte-multilingual-reranker-base", mteb_score=0.33062, architecture="GteNewForSequenceClassification", hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, - enable_test=True), + enable_test=True, + ), ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: 
EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) @pytest.mark.parametrize("model_info", RERANK_MODELS) -def test_rerank_models_mteb(hf_runner, vllm_runner, - model_info: RerankModelInfo) -> None: +def test_rerank_models_mteb( + hf_runner, vllm_runner, model_info: RerankModelInfo +) -> None: mteb_test_rerank_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/language/pooling_mteb_test/test_intfloat.py b/tests/models/language/pooling_mteb_test/test_intfloat.py index 0d6026898ad4..1d078db69236 100644 --- a/tests/models/language/pooling_mteb_test/test_intfloat.py +++ b/tests/models/language/pooling_mteb_test/test_intfloat.py @@ -2,50 +2,55 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from tests.models.language.pooling.embed_utils import ( - correctness_test_embed_models) +from tests.models.language.pooling.embed_utils import correctness_test_embed_models from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo from .mteb_utils import mteb_test_embed_models MODELS = [ ########## BertModel - CLSPoolingEmbedModelInfo("intfloat/e5-small", - architecture="BertModel", - mteb_score=0.742285423, - enable_test=True), - CLSPoolingEmbedModelInfo("intfloat/e5-base", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("intfloat/e5-large", - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-small", - architecture="BertModel", - enable_test=False), + CLSPoolingEmbedModelInfo( + "intfloat/e5-small", + architecture="BertModel", + mteb_score=0.742285423, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "intfloat/e5-base", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "intfloat/e5-large", architecture="BertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "intfloat/multilingual-e5-small", architecture="BertModel", enable_test=False + ), ########## XLMRobertaModel - CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-base", - architecture="XLMRobertaModel", - mteb_score=0.779325955, - enable_test=True), - CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large", - architecture="XLMRobertaModel", - enable_test=False), - CLSPoolingEmbedModelInfo("intfloat/multilingual-e5-large-instruct", - architecture="XLMRobertaModel", - enable_test=False), + CLSPoolingEmbedModelInfo( + "intfloat/multilingual-e5-base", + architecture="XLMRobertaModel", + mteb_score=0.779325955, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "intfloat/multilingual-e5-large", + architecture="XLMRobertaModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "intfloat/multilingual-e5-large-instruct", + architecture="XLMRobertaModel", + enable_test=False, + ), ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, 
diff --git a/tests/models/language/pooling_mteb_test/test_jina.py b/tests/models/language/pooling_mteb_test/test_jina.py
index 0a77a78bb31b..0a712b2542f3 100644
--- a/tests/models/language/pooling_mteb_test/test_jina.py
+++ b/tests/models/language/pooling_mteb_test/test_jina.py
@@ -5,60 +5,68 @@
 import pytest
 
 from tests.models.language.pooling.embed_utils import (
-    check_embeddings_close, correctness_test_embed_models, matryoshka_fy)
-from tests.models.utils import (CLSPoolingEmbedModelInfo,
-                                CLSPoolingRerankModelInfo, EmbedModelInfo,
-                                RerankModelInfo)
+    check_embeddings_close,
+    correctness_test_embed_models,
+    matryoshka_fy,
+)
+from tests.models.utils import (
+    CLSPoolingEmbedModelInfo,
+    CLSPoolingRerankModelInfo,
+    EmbedModelInfo,
+    RerankModelInfo,
+)
 from vllm import PoolingParams
 
 from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
 
 EMBEDDING_MODELS = [
-    CLSPoolingEmbedModelInfo("jinaai/jina-embeddings-v3",
-                             mteb_score=0.824413164,
-                             architecture="XLMRobertaModel",
-                             is_matryoshka=True)
+    CLSPoolingEmbedModelInfo(
+        "jinaai/jina-embeddings-v3",
+        mteb_score=0.824413164,
+        architecture="XLMRobertaModel",
+        is_matryoshka=True,
+    )
 ]
 
 RERANK_MODELS = [
     CLSPoolingRerankModelInfo(
         "jinaai/jina-reranker-v2-base-multilingual",
         mteb_score=0.33643,
-        architecture="XLMRobertaForSequenceClassification")
+        architecture="XLMRobertaForSequenceClassification",
+    )
 ]
 
 
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
-def test_embed_models_mteb(hf_runner, vllm_runner,
-                           model_info: EmbedModelInfo) -> None:
-
+def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
     def hf_model_callback(model):
         model.encode = partial(model.encode, task="text-matching")
 
-    mteb_test_embed_models(hf_runner,
-                           vllm_runner,
-                           model_info,
-                           hf_model_callback=hf_model_callback)
+    mteb_test_embed_models(
+        hf_runner, vllm_runner, model_info, hf_model_callback=hf_model_callback
+    )
 
 
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
-def test_embed_models_correctness(hf_runner, vllm_runner,
-                                  model_info: EmbedModelInfo,
-                                  example_prompts) -> None:
-
+def test_embed_models_correctness(
+    hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts
+) -> None:
     def hf_model_callback(model):
         model.encode = partial(model.encode, task="text-matching")
 
-    correctness_test_embed_models(hf_runner,
-                                  vllm_runner,
-                                  model_info,
-                                  example_prompts,
-                                  hf_model_callback=hf_model_callback)
+    correctness_test_embed_models(
+        hf_runner,
+        vllm_runner,
+        model_info,
+        example_prompts,
+        hf_model_callback=hf_model_callback,
+    )
 
 
 @pytest.mark.parametrize("model_info", RERANK_MODELS)
-def test_rerank_models_mteb(hf_runner, vllm_runner,
-                            model_info: RerankModelInfo) -> None:
+def test_rerank_models_mteb(
+    hf_runner, vllm_runner, model_info: RerankModelInfo
+) -> None:
     mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
 
 
@@ -81,32 +89,32 @@ def test_matryoshka(
     example_prompts = [str(s).strip() for s in example_prompts]
 
     with hf_runner(
-            model_info.name,
-            dtype=dtype,
-            is_sentence_transformer=True,
+        model_info.name,
+        dtype=dtype,
+        is_sentence_transformer=True,
     ) as hf_model:
         hf_outputs = hf_model.encode(example_prompts, task="text-matching")
         hf_outputs = matryoshka_fy(hf_outputs, dimensions)
 
-    with vllm_runner(model_info.name,
-                     runner="pooling",
-                     dtype=dtype,
-                     max_model_len=None) as vllm_model:
+    with vllm_runner(
+        model_info.name, runner="pooling", dtype=dtype, max_model_len=None
+    ) as vllm_model:
         assert
vllm_model.llm.llm_engine.model_config.is_matryoshka matryoshka_dimensions = ( - vllm_model.llm.llm_engine.model_config.matryoshka_dimensions) + vllm_model.llm.llm_engine.model_config.matryoshka_dimensions + ) assert matryoshka_dimensions is not None if dimensions not in matryoshka_dimensions: with pytest.raises(ValueError): vllm_model.embed( - example_prompts, - pooling_params=PoolingParams(dimensions=dimensions)) + example_prompts, pooling_params=PoolingParams(dimensions=dimensions) + ) else: vllm_outputs = vllm_model.embed( - example_prompts, - pooling_params=PoolingParams(dimensions=dimensions)) + example_prompts, pooling_params=PoolingParams(dimensions=dimensions) + ) check_embeddings_close( embeddings_0_lst=hf_outputs, diff --git a/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py index 05ebb4ec4d3f..fd04dc199023 100644 --- a/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py +++ b/tests/models/language/pooling_mteb_test/test_mxbai_rerank.py @@ -17,46 +17,45 @@ } RERANK_MODELS = [ - LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", - architecture="Qwen2ForSequenceClassification", - hf_overrides=mxbai_rerank_hf_overrides, - mteb_score=0.273, - enable_test=True), - LASTPoolingRerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", - architecture="Qwen2ForSequenceClassification", - hf_overrides=mxbai_rerank_hf_overrides, - enable_test=False) + LASTPoolingRerankModelInfo( + "mixedbread-ai/mxbai-rerank-base-v2", + architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, + mteb_score=0.273, + enable_test=True, + ), + LASTPoolingRerankModelInfo( + "mixedbread-ai/mxbai-rerank-large-v2", + architecture="Qwen2ForSequenceClassification", + hf_overrides=mxbai_rerank_hf_overrides, + enable_test=False, + ), ] class MxbaiRerankerHfRunner(HfRunner): - - def __init__(self, - model_name: str, - dtype: str = "auto", - *args: Any, - **kwargs: Any) -> None: + def __init__( + self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any + ) -> None: from transformers import AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, - padding_side='left') + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.yes_loc = self.tokenizer.convert_tokens_to_ids("1") self.no_loc = self.tokenizer.convert_tokens_to_ids("0") - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: def process_inputs(pairs): - inputs = self.tokenizer(pairs, - padding=False, - truncation='longest_first', - return_attention_mask=False) - for i, ele in enumerate(inputs['input_ids']): - inputs['input_ids'][i] = ele - inputs = self.tokenizer.pad(inputs, - padding=True, - return_tensors="pt") + inputs = self.tokenizer( + pairs, + padding=False, + truncation="longest_first", + return_attention_mask=False, + ) + for i, ele in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = ele + inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt") for key in inputs: inputs[key] = inputs[key].to(self.model.device) return inputs diff --git a/tests/models/language/pooling_mteb_test/test_nomic.py b/tests/models/language/pooling_mteb_test/test_nomic.py index 61512fd0dff1..c54a43052483 100644 --- a/tests/models/language/pooling_mteb_test/test_nomic.py +++ 
b/tests/models/language/pooling_mteb_test/test_nomic.py @@ -3,39 +3,42 @@ import pytest -from tests.models.language.pooling.embed_utils import ( - correctness_test_embed_models) +from tests.models.language.pooling.embed_utils import correctness_test_embed_models from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo from .mteb_utils import mteb_test_embed_models MODELS = [ - CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1", - architecture="NomicBertModel", - mteb_score=0.737568559, - enable_test=True), - CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", - architecture="NomicBertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("nomic-ai/CodeRankEmbed", - architecture="NomicBertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", - architecture="NomicBertModel", - mteb_score=0.715488912, - enable_test=True) + CLSPoolingEmbedModelInfo( + "nomic-ai/nomic-embed-text-v1", + architecture="NomicBertModel", + mteb_score=0.737568559, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "nomic-ai/nomic-embed-text-v1.5", + architecture="NomicBertModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "nomic-ai/CodeRankEmbed", architecture="NomicBertModel", enable_test=False + ), + CLSPoolingEmbedModelInfo( + "nomic-ai/nomic-embed-text-v2-moe", + architecture="NomicBertModel", + mteb_score=0.715488912, + enable_test=True, + ), ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py index 65403081dc0f..00e99f44cfdb 100644 --- a/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py +++ b/tests/models/language/pooling_mteb_test/test_qwen3_reranker.py @@ -18,46 +18,45 @@ } RERANK_MODELS = [ - LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-0.6B", - architecture="Qwen3ForSequenceClassification", - mteb_score=0.25736, - hf_overrides=qwen3_reranker_hf_overrides, - enable_test=True), - LASTPoolingRerankModelInfo("Qwen/Qwen3-Reranker-4B", - architecture="Qwen3ForSequenceClassification", - hf_overrides=qwen3_reranker_hf_overrides, - enable_test=False) + LASTPoolingRerankModelInfo( + "Qwen/Qwen3-Reranker-0.6B", + architecture="Qwen3ForSequenceClassification", + mteb_score=0.25736, + hf_overrides=qwen3_reranker_hf_overrides, + enable_test=True, + ), + LASTPoolingRerankModelInfo( + "Qwen/Qwen3-Reranker-4B", + architecture="Qwen3ForSequenceClassification", + hf_overrides=qwen3_reranker_hf_overrides, + enable_test=False, + ), ] class Qwen3RerankerHfRunner(HfRunner): - - def __init__(self, - model_name: str, - dtype: str = "auto", - *args: Any, - **kwargs: Any) -> None: + def __init__( + self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any + ) -> None: from transformers import 
AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) - self.tokenizer = AutoTokenizer.from_pretrained(model_name, - padding_side='left') + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.token_false_id = self.tokenizer.convert_tokens_to_ids("no") self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes") - def predict(self, prompts: list[list[str]], *args, - **kwargs) -> torch.Tensor: - + def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: def process_inputs(pairs): - inputs = self.tokenizer(pairs, - padding=False, - truncation='longest_first', - return_attention_mask=False) - for i, ele in enumerate(inputs['input_ids']): - inputs['input_ids'][i] = ele - inputs = self.tokenizer.pad(inputs, - padding=True, - return_tensors="pt") + inputs = self.tokenizer( + pairs, + padding=False, + truncation="longest_first", + return_attention_mask=False, + ) + for i, ele in enumerate(inputs["input_ids"]): + inputs["input_ids"][i] = ele + inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt") for key in inputs: inputs[key] = inputs[key].to(self.model.device) return inputs @@ -82,20 +81,18 @@ def compute_logits(inputs): @pytest.mark.parametrize("model_info", RERANK_MODELS) def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", RERANK_MODELS) @multi_gpu_test(num_gpus=2) -def test_rerank_models_mteb_tp(vllm_runner, - model_info: RerankModelInfo) -> None: - +def test_rerank_models_mteb_tp(vllm_runner, model_info: RerankModelInfo) -> None: assert model_info.architecture == "Qwen3ForSequenceClassification" vllm_extra_kwargs: dict[str, Any] = { "tensor_parallel_size": 2, } - mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, - vllm_extra_kwargs) + mteb_test_rerank_models( + Qwen3RerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs + ) diff --git a/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py index 91bad2c4e42f..3c30628aeaa4 100644 --- a/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling_mteb_test/test_snowflake_arctic_embed.py @@ -3,62 +3,75 @@ import pytest -from tests.models.language.pooling.embed_utils import ( - correctness_test_embed_models) +from tests.models.language.pooling.embed_utils import correctness_test_embed_models from tests.models.utils import CLSPoolingEmbedModelInfo, EmbedModelInfo from .mteb_utils import mteb_test_embed_models MODELS = [ - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", - is_matryoshka=False, - architecture="BertModel", - mteb_score=0.714927797, - enable_test=True), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-s", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long", - is_matryoshka=False, - architecture="NomicBertModel", - mteb_score=0.681146831, - enable_test=True), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l", - is_matryoshka=False, - architecture="BertModel", - enable_test=False), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", 
- is_matryoshka=True, - architecture="BertModel", - mteb_score=0.649088363, - enable_test=True), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-l-v2.0", - is_matryoshka=True, - architecture="XLMRobertaModel", - mteb_score=0.712258299, - enable_test=True), - CLSPoolingEmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v2.0", - is_matryoshka=True, - architecture="GteModel", - mteb_score=0.706622444, - enable_test=True), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-xs", + is_matryoshka=False, + architecture="BertModel", + mteb_score=0.714927797, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-s", + is_matryoshka=False, + architecture="BertModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m", + is_matryoshka=False, + architecture="BertModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-long", + is_matryoshka=False, + architecture="NomicBertModel", + mteb_score=0.681146831, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-l", + is_matryoshka=False, + architecture="BertModel", + enable_test=False, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + architecture="BertModel", + mteb_score=0.649088363, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-l-v2.0", + is_matryoshka=True, + architecture="XLMRobertaModel", + mteb_score=0.712258299, + enable_test=True, + ), + CLSPoolingEmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v2.0", + is_matryoshka=True, + architecture="GteModel", + mteb_score=0.706622444, + enable_test=True, + ), ] @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) @pytest.mark.parametrize("model_info", MODELS) -def test_embed_models_correctness(hf_runner, vllm_runner, - model_info: EmbedModelInfo, - example_prompts) -> None: - correctness_test_embed_models(hf_runner, vllm_runner, model_info, - example_prompts) +def test_embed_models_correctness( + hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts +) -> None: + correctness_test_embed_models(hf_runner, vllm_runner, model_info, example_prompts) diff --git a/tests/models/language/pooling_mteb_test/test_st_projector.py b/tests/models/language/pooling_mteb_test/test_st_projector.py index bd493e7e2ba0..91b1ef828d0d 100644 --- a/tests/models/language/pooling_mteb_test/test_st_projector.py +++ b/tests/models/language/pooling_mteb_test/test_st_projector.py @@ -2,8 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from tests.models.utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo, - LASTPoolingEmbedModelInfo) +from tests.models.utils import ( + CLSPoolingEmbedModelInfo, + EmbedModelInfo, + LASTPoolingEmbedModelInfo, +) from .mteb_utils import mteb_test_embed_models @@ -15,15 +18,15 @@ mteb_score=0.688611955, enable_test=True, ), - LASTPoolingEmbedModelInfo("google/embeddinggemma-300m", - architecture="Gemma3TextModel", - mteb_score=0.7473819294684156, - enable_test=True) + LASTPoolingEmbedModelInfo( + "google/embeddinggemma-300m", + architecture="Gemma3TextModel", + mteb_score=0.7473819294684156, + enable_test=True, + ), ] 
@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS) -def test_embed_models_mteb(hf_runner, vllm_runner, - model_info: EmbedModelInfo) -> None: - +def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: mteb_test_embed_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index c378ef670f91..c57ccd62fe6c 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -3,27 +3,40 @@ """Common tests for testing .generate() functionality for single / multiple image, embedding, and video support for different VLMs in vLLM. """ + import math import os from collections import defaultdict from pathlib import PosixPath import pytest -from transformers import (AutoModel, AutoModelForImageTextToText, - AutoModelForTextToWaveform) +from transformers import ( + AutoModel, + AutoModelForImageTextToText, + AutoModelForTextToWaveform, +) from vllm.platforms import current_platform from vllm.utils import identity -from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner, - ImageTestAssets, VideoTestAssets, VllmRunner) -from ....utils import (create_new_process_for_each_test, large_gpu_mark, - multi_gpu_marks) +from ....conftest import ( + IMAGE_ASSETS, + AudioTestAssets, + HfRunner, + ImageTestAssets, + VideoTestAssets, + VllmRunner, +) +from ....utils import create_new_process_for_each_test, large_gpu_mark, multi_gpu_marks from ...utils import check_outputs_equal from .vlm_utils import custom_inputs, model_utils, runners from .vlm_utils.case_filtering import get_parametrized_options -from .vlm_utils.types import (CustomTestOptions, ExpandableVLMTestArgs, - VLMTestInfo, VLMTestType) +from .vlm_utils.types import ( + CustomTestOptions, + ExpandableVLMTestArgs, + VLMTestInfo, + VLMTestType, +) # This hack is needed for phi3v & paligemma models # ROCm Triton FA can run into shared memory issues with these models, @@ -32,18 +45,17 @@ if current_platform.is_rocm(): os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" -# yapf: disable COMMON_BROADCAST_SETTINGS = { "test_type": VLMTestType.IMAGE, "dtype": "half", "max_tokens": 5, "tensor_parallel_size": 2, "hf_model_kwargs": {"device_map": "auto"}, - "image_size_factors": [(.25, 0.5, 1.0)], + "image_size_factors": [(0.25, 0.5, 1.0)], "distributed_executor_backend": ( "ray", "mp", - ) + ), } ### Test configuration for specific models @@ -83,22 +95,20 @@ #### Core tests to always run in the CI "llava": VLMTestInfo( models=["llava-hf/llava-1.5-7b-hf"], - test_type=( - VLMTestType.EMBEDDING, - VLMTestType.IMAGE, - VLMTestType.CUSTOM_INPUTS - ), + test_type=(VLMTestType.EMBEDDING, VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS), prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", convert_assets_to_embeddings=model_utils.get_llava_embeddings, max_model_len=4096, auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( - formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:" - ), - limit_mm_per_prompt={"image": 4}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( + formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:" + ), + limit_mm_per_prompt={"image": 4}, + ) + ], # TODO: Revert to "auto" when CPU backend can use torch > 2.6 
dtype="bfloat16" if current_platform.is_cpu() else "auto", marks=[pytest.mark.core_model, pytest.mark.cpu_model], @@ -107,27 +117,27 @@ models=["google/paligemma-3b-mix-224"], test_type=VLMTestType.IMAGE, prompt_formatter=identity, - img_idx_to_prompt = lambda idx: "", + img_idx_to_prompt=lambda idx: "", # Paligemma uses its own sample prompts because the default one fails - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "caption es", - "cherry_blossom": "What is in the picture?", - }), + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "caption es", + "cherry_blossom": "What is in the picture?", + } + ), auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, dtype="bfloat16", - marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501 + marks=[ + pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask") + ], ), "qwen2_5_vl": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], - test_type=( - VLMTestType.IMAGE, - VLMTestType.MULTI_IMAGE, - VLMTestType.VIDEO - ), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 - video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501 + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", + video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -137,17 +147,13 @@ ), "qwen2_5_omni": VLMTestInfo( models=["Qwen/Qwen2.5-Omni-3B"], - test_type=( - VLMTestType.IMAGE, - VLMTestType.MULTI_IMAGE, - VLMTestType.VIDEO - ), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>", # noqa: E501 - video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501 + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>", + video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", max_model_len=4096, max_num_seqs=2, - num_logprobs= 6 if current_platform.is_cpu() else 5, + num_logprobs=6 if current_platform.is_cpu() else 5, auto_cls=AutoModelForTextToWaveform, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner, @@ -155,9 +161,9 @@ marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "ultravox": VLMTestInfo( - models = ["fixie-ai/ultravox-v0_5-llama-3_2-1b"], + models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"], test_type=VLMTestType.AUDIO, - prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + prompt_formatter=lambda audio_prompt: 
f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 audio_idx_to_prompt=lambda idx: "<|audio|>", max_model_len=4096, max_num_seqs=2, @@ -171,9 +177,11 @@ "llava-onevision-transformers": VLMTestInfo( models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], test_type=VLMTestType.IMAGE, - prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 max_model_len=16384, - hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ), auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, image_size_factors=[(0.25, 0.5, 1.0)], @@ -188,7 +196,7 @@ "idefics3-transformers": VLMTestInfo( models=["HuggingFaceTB/SmolVLM-256M-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 img_idx_to_prompt=lambda idx: "<image>", max_model_len=8192, max_num_seqs=2, @@ -204,8 +212,8 @@ "qwen2_5_vl-transformers": VLMTestInfo( models=["Qwen/Qwen2.5-VL-3B-Instruct"], test_type=VLMTestType.IMAGE, - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -214,24 +222,24 @@ vllm_runner_kwargs={ "model_impl": "transformers", }, - # FIXME: Investigate mrope issue - marks=[large_gpu_mark(min_gb=32), - pytest.mark.skip(reason="Mrope issue")], + marks=[large_gpu_mark(min_gb=32)], ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<vlm_image>Please describe the image shortly.", - "cherry_blossom": "<vlm_image>Please infer the season with reason.", # noqa: E501 - }), - multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<vlm_image>Please describe the image shortly.", + "cherry_blossom": "<vlm_image>Please infer the season with reason.", + } + ), + multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", stop_str=["<|im_end|>"], image_size_factors=[(0.10, 0.15)], max_tokens=64, @@ -240,12 +248,14 @@ "aya_vision": VLMTestInfo( 
models=["CohereForAI/aya-vision-8b"], test_type=(VLMTestType.IMAGE), - prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>What is the season?", # noqa: E501 - }), - multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>What's the content in the center of the image?", + "cherry_blossom": "<image>What is the season?", + } + ), + multi_image_prompt="<image><image>Describe the two images in detail.", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -254,12 +264,14 @@ "aya_vision-multi_image": VLMTestInfo( models=["CohereForAI/aya-vision-8b"], test_type=(VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>What is the season?", # noqa: E501 - }), - multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{img_prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>What's the content in the center of the image?", + "cherry_blossom": "<image>What is the season?", + } + ), + multi_image_prompt="<image><image>Describe the two images in detail.", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -284,27 +296,29 @@ max_num_seqs=2, auto_cls=AutoModelForImageTextToText, # For chameleon, we only compare the sequences - vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], - hf_output_post_proc = lambda hf_output, model: hf_output[:2], + vllm_output_post_proc=lambda vllm_output, model: vllm_output[:2], + hf_output_post_proc=lambda hf_output, model: hf_output[:2], comparator=check_outputs_equal, max_tokens=8, dtype="bfloat16", ), "deepseek_vl_v2": VLMTestInfo( - models=["Isotr0py/deepseek-vl2-tiny"], # model repo using dynamic module + models=["Isotr0py/deepseek-vl2-tiny"], # model repo using dynamic module test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ", # noqa: E501 max_model_len=4096, max_num_seqs=2, - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>\nPlease infer the season with reason in details.", # noqa: E501 - }), - multi_image_prompt="image_1:<image>\nimage_2:<image>\nWhich image can we see the car and the tower?", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>\nWhat's the content in the center of the image?", + "cherry_blossom": "<image>\nPlease infer the season with reason 
in details.", # noqa: E501 + } + ), + multi_image_prompt="image_1:<image>\nimage_2:<image>\nWhich image can we see the car and the tower?", # noqa: E501 patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner, hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output, - stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], # noqa: E501 - image_size_factors=[(), (1.0, ), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)], + stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"], + image_size_factors=[(), (1.0,), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)], ), "fuyu": VLMTestInfo( models=["adept/fuyu-8b"], @@ -323,11 +337,13 @@ "gemma3": VLMTestInfo( models=["google/gemma-3-4b-it"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<start_of_image>What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<start_of_image>What is the season?", # noqa: E501 - }), + prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<start_of_image>What's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "<start_of_image>What is the season?", + } + ), multi_image_prompt="<start_of_image><start_of_image>Describe the two images in detail.", # noqa: E501 max_model_len=4096, max_num_seqs=2, @@ -339,11 +355,13 @@ "glm4v": VLMTestInfo( models=["zai-org/glm-4v-9b"], test_type=VLMTestType.IMAGE, - prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<|begin_of_image|><|endoftext|><|end_of_image|>What's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<|begin_of_image|><|endoftext|><|end_of_image|>What is the season?", # noqa: E501 - }), + prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<|begin_of_image|><|endoftext|><|end_of_image|>What's the content in the center of the image?", # noqa: E501 + "cherry_blossom": "<|begin_of_image|><|endoftext|><|end_of_image|>What is the season?", # noqa: E501 + } + ), max_model_len=2048, max_num_seqs=2, get_stop_token_ids=lambda tok: [151329, 151336, 151338], @@ -358,9 +376,9 @@ "glm4_1v": VLMTestInfo( models=["zai-org/GLM-4.1V-9B-Thinking"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", # noqa: E501 - video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", + img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", + video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", max_model_len=2048, max_num_seqs=2, get_stop_token_ids=lambda tok: [151329, 151336, 151338], @@ -377,23 +395,27 @@ max_num_seqs=2, auto_cls=AutoModelForImageTextToText, patch_hf_runner=model_utils.glm4_1v_patch_hf_runner, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.video_with_metadata_glm4_1v(), - limit_mm_per_prompt={"video": 1}, - )], + custom_test_opts=[ + CustomTestOptions( + 
inputs=custom_inputs.video_with_metadata_glm4_1v(), + limit_mm_per_prompt={"video": 1}, + ) + ], marks=[large_gpu_mark(min_gb=32)], ), "h2ovl": VLMTestInfo( - models = [ + models=[ "h2oai/h2ovl-mississippi-800m", "h2oai/h2ovl-mississippi-2b", ], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>\nWhat is the season?", - }), + prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>\nWhat's the content in the center of the image?", + "cherry_blossom": "<image>\nWhat is the season?", + } + ), multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.", # noqa: E501 max_model_len=8192, use_tokenizer_eos=True, @@ -403,7 +425,7 @@ "idefics3": VLMTestInfo( models=["HuggingFaceTB/SmolVLM-256M-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 img_idx_to_prompt=lambda idx: "<image>", max_model_len=8192, max_num_seqs=2, @@ -418,11 +440,13 @@ # "OpenGVLab/Mono-InternVL-2B", ], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>\nWhat is the season?", - }), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>\nWhat's the content in the center of the image?", + "cherry_blossom": "<image>\nWhat is the season?", + } + ), multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.", # noqa: E501 max_model_len=4096, use_tokenizer_eos=True, @@ -433,7 +457,7 @@ "OpenGVLab/InternVL3-1B", ], test_type=VLMTestType.VIDEO, - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 video_idx_to_prompt=lambda idx: "<video>", max_model_len=8192, use_tokenizer_eos=True, @@ -446,7 +470,7 @@ VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO, ), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>", video_idx_to_prompt=lambda idx: "<video>", max_model_len=8192, @@ -456,7 +480,7 @@ "kimi_vl": VLMTestInfo( models=["moonshotai/Kimi-VL-A3B-Instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>", # noqa: E501 + prompt_formatter=lambda img_prompt: 
f"<|im_user|>user<|im_middle|>{img_prompt}<|im_end|><|im_assistant|>assistant<|im_middle|>", # noqa: E501 img_idx_to_prompt=lambda _: "<|media_start|>image<|media_content|><|media_pad|><|media_end|>", # noqa: E501 max_model_len=8192, max_num_seqs=2, @@ -467,11 +491,11 @@ ), "llama4": VLMTestInfo( models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"], - prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501 img_idx_to_prompt=lambda _: "<|image|>", test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), distributed_executor_backend="mp", - image_size_factors=[(.25, 0.5, 1.0)], + image_size_factors=[(0.25, 0.5, 1.0)], hf_model_kwargs={"device_map": "auto"}, max_model_len=8192, max_num_seqs=4, @@ -487,28 +511,34 @@ max_model_len=10240, auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( - formatter=lambda img_prompt: f"[INST] {img_prompt} [/INST]" - ), - limit_mm_per_prompt={"image": 4}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( + formatter=lambda img_prompt: f"[INST] {img_prompt} [/INST]" + ), + limit_mm_per_prompt={"image": 4}, + ) + ], ), "llava_onevision": VLMTestInfo( models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], test_type=VLMTestType.CUSTOM_INPUTS, - prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 num_video_frames=16, max_model_len=16384, - hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ), auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs( - formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - ), - limit_mm_per_prompt={"video": 4}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs( + formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + ), + limit_mm_per_prompt={"video": 4}, + ) + ], ), "llava_next_video": VLMTestInfo( models=["llava-hf/LLaVA-NeXT-Video-7B-hf"], @@ -550,7 +580,9 @@ img_idx_to_prompt=lambda idx: "(<image>./</image>)\n", max_model_len=4096, max_num_seqs=2, - get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids( + ["<|im_end|>", "<|endoftext|>"] + ), hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner, # FIXME: https://huggingface.co/openbmb/MiniCPM-o-2_6/discussions/49 @@ -563,13 +595,15 @@ img_idx_to_prompt=lambda idx: "(<image>./</image>)\n", max_model_len=4096, max_num_seqs=2, - 
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids( + ["<|im_end|>", "<|endoftext|>"] + ), hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, ), "minimax_vl_01": VLMTestInfo( models=["MiniMaxAI/MiniMax-VL-01"], - prompt_formatter=lambda img_prompt: f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>", # noqa: E501 img_idx_to_prompt=lambda _: "<image>", test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), max_model_len=8192, @@ -591,8 +625,8 @@ "ovis1_6-gemma2": VLMTestInfo( models=["AIDC-AI/Ovis1.6-Gemma2-9B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", max_model_len=4096, max_num_seqs=2, dtype="half", @@ -604,8 +638,8 @@ "ovis2": VLMTestInfo( models=["AIDC-AI/Ovis2-1B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", max_model_len=4096, max_num_seqs=2, dtype="half", @@ -615,13 +649,9 @@ ), "ovis2_5": VLMTestInfo( models=["AIDC-AI/Ovis2.5-2B"], - test_type=( - VLMTestType.IMAGE, - VLMTestType.MULTI_IMAGE, - VLMTestType.VIDEO - ), - prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<image>\n", # noqa: E501 + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<image>\n", video_idx_to_prompt=lambda idx: "<video>\n", max_model_len=4096, max_num_seqs=2, @@ -633,7 +663,7 @@ "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|end|>\n<|assistant|>\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|end|>\n<|assistant|>\n", # noqa: E501 img_idx_to_prompt=lambda idx: f"<|image_{idx}|>\n", max_model_len=4096, max_num_seqs=2, @@ -668,15 +698,11 @@ ), "qwen2_vl": VLMTestInfo( models=["Qwen/Qwen2-VL-2B-Instruct"], - test_type=( - VLMTestType.IMAGE, - VLMTestType.MULTI_IMAGE, - VLMTestType.VIDEO - ), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", 
# noqa: E501 - video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501 - multi_image_prompt="Picture 1: <vlm_image>\nPicture 2: <vlm_image>\nDescribe these two images with one paragraph respectively.", # noqa: E501 + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", + video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", + multi_image_prompt="Picture 1: <vlm_image>\nPicture 2: <vlm_image>\nDescribe these two images with one paragraph respectively.", # noqa: E501 max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -687,12 +713,14 @@ "skywork_r1v": VLMTestInfo( models=["Skywork/Skywork-R1V-38B"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|><think>\n", # noqa: E501 - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501 - "cherry_blossom": "<image>\nWhat is the season?", - }), - multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|><think>\n", # noqa: E501 + single_image_prompts=IMAGE_ASSETS.prompts( + { + "stop_sign": "<image>\nWhat's the content in the center of the image?", + "cherry_blossom": "<image>\nWhat is the season?", + } + ), + multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", max_model_len=4096, use_tokenizer_eos=True, patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner, @@ -724,9 +752,9 @@ VLMTestType.MULTI_IMAGE, VLMTestType.VIDEO, ), - prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 - video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", + video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -739,11 +767,11 @@ prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForImageTextToText, - vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], - hf_output_post_proc = lambda hf_output, model: hf_output[:2], + vllm_output_post_proc=lambda vllm_output, model: vllm_output[:2], + hf_output_post_proc=lambda hf_output, model: hf_output[:2], comparator=check_outputs_equal, marks=multi_gpu_marks(num_gpus=2), - **COMMON_BROADCAST_SETTINGS # type: ignore + **COMMON_BROADCAST_SETTINGS, # type: ignore ), "llava-broadcast": VLMTestInfo( models=["llava-hf/llava-1.5-7b-hf"], @@ -752,7 +780,7 @@ auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, marks=multi_gpu_marks(num_gpus=2), - **COMMON_BROADCAST_SETTINGS # type: ignore + 
**COMMON_BROADCAST_SETTINGS, # type: ignore ), "llava_next-broadcast": VLMTestInfo( models=["llava-hf/llava-v1.6-mistral-7b-hf"], @@ -761,12 +789,12 @@ auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, marks=multi_gpu_marks(num_gpus=2), - **COMMON_BROADCAST_SETTINGS # type: ignore + **COMMON_BROADCAST_SETTINGS, # type: ignore ), ### Custom input edge-cases for specific models "intern_vl-diff-patches": VLMTestInfo( models=["OpenGVLab/InternVL2-2B"], - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=4096, use_tokenizer_eos=True, @@ -775,7 +803,8 @@ CustomTestOptions( inputs=inp, limit_mm_per_prompt={"image": 2}, - ) for inp in custom_inputs.different_patch_input_cases_internvl() + ) + for inp in custom_inputs.different_patch_input_cases_internvl() ], ), "llava_onevision-multiple-images": VLMTestInfo( @@ -784,14 +813,18 @@ max_model_len=16384, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ), vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( - formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 - ), - limit_mm_per_prompt={"image": 4}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( + formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + ), + limit_mm_per_prompt={"image": 4}, + ) + ], ), # regression test for https://github.com/vllm-project/vllm/issues/15122 "qwen2_5_vl-windows-attention": VLMTestInfo( @@ -801,13 +834,14 @@ max_num_seqs=2, auto_cls=AutoModelForImageTextToText, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.windows_attention_image_qwen2_5_vl(), - limit_mm_per_prompt={"image": 1}, - )], + custom_test_opts=[ + CustomTestOptions( + inputs=custom_inputs.windows_attention_image_qwen2_5_vl(), + limit_mm_per_prompt={"image": 1}, + ) + ], ), } -# yapf: enable def _mark_splits( @@ -828,7 +862,7 @@ def _mark_splits( new_test_settings = dict[str, VLMTestInfo]() for i in range(num_groups): - models_in_group = models[i * split_size:(i + 1) * split_size] + models_in_group = models[i * split_size : (i + 1) * split_size] for model in models_in_group: for info in test_infos_by_model[model]: @@ -859,7 +893,8 @@ def _mark_splits( VLM_TEST_SETTINGS, test_type=VLMTestType.IMAGE, create_new_process_for_each_test=False, - )) + ), +) def test_single_image_models( tmp_path: PosixPath, model_type: str, @@ -885,7 +920,8 @@ def test_single_image_models( VLM_TEST_SETTINGS, test_type=VLMTestType.MULTI_IMAGE, create_new_process_for_each_test=False, - )) + ), +) def test_multi_image_models( tmp_path: PosixPath, model_type: str, @@ -911,7 +947,8 @@ def test_multi_image_models( VLM_TEST_SETTINGS, test_type=VLMTestType.EMBEDDING, create_new_process_for_each_test=False, - )) + ), +) def test_image_embedding_models( 
model_type: str, test_case: ExpandableVLMTestArgs, @@ -935,7 +972,8 @@ def test_image_embedding_models( VLM_TEST_SETTINGS, test_type=VLMTestType.VIDEO, create_new_process_for_each_test=False, - )) + ), +) def test_video_models( model_type: str, test_case: ExpandableVLMTestArgs, @@ -959,7 +997,8 @@ def test_video_models( VLM_TEST_SETTINGS, test_type=VLMTestType.AUDIO, create_new_process_for_each_test=False, - )) + ), +) def test_audio_models( model_type: str, test_case: ExpandableVLMTestArgs, @@ -983,7 +1022,8 @@ def test_audio_models( VLM_TEST_SETTINGS, test_type=VLMTestType.CUSTOM_INPUTS, create_new_process_for_each_test=False, - )) + ), +) def test_custom_inputs_models( model_type: str, test_case: ExpandableVLMTestArgs, @@ -1006,7 +1046,8 @@ def test_custom_inputs_models( VLM_TEST_SETTINGS, test_type=VLMTestType.IMAGE, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() def test_single_image_models_heavy( tmp_path: PosixPath, @@ -1033,7 +1074,8 @@ def test_single_image_models_heavy( VLM_TEST_SETTINGS, test_type=VLMTestType.MULTI_IMAGE, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() def test_multi_image_models_heavy( tmp_path: PosixPath, @@ -1060,7 +1102,8 @@ def test_multi_image_models_heavy( VLM_TEST_SETTINGS, test_type=VLMTestType.EMBEDDING, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() def test_image_embedding_models_heavy( model_type: str, @@ -1085,7 +1128,8 @@ def test_image_embedding_models_heavy( VLM_TEST_SETTINGS, test_type=VLMTestType.VIDEO, create_new_process_for_each_test=True, - )) + ), +) def test_video_models_heavy( model_type: str, test_case: ExpandableVLMTestArgs, @@ -1109,7 +1153,8 @@ def test_video_models_heavy( VLM_TEST_SETTINGS, test_type=VLMTestType.AUDIO, create_new_process_for_each_test=True, - )) + ), +) def test_audio_models_heavy( model_type: str, test_case: ExpandableVLMTestArgs, @@ -1133,7 +1178,8 @@ def test_audio_models_heavy( VLM_TEST_SETTINGS, test_type=VLMTestType.CUSTOM_INPUTS, create_new_process_for_each_test=True, - )) + ), +) @create_new_process_for_each_test() def test_custom_inputs_models_heavy( model_type: str, diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index c1305e0ae31c..ef08b1916aa5 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -10,8 +10,7 @@ from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest -from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, - VllmRunner) +from ....conftest import AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close @@ -64,50 +63,49 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). 
# max_model_len should be greater than image_feature_size with vllm_runner( - model, - runner="generate", - max_model_len=max_model_len, - max_num_seqs=1, - dtype=dtype, - limit_mm_per_prompt={"audio": 1}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=64, - enforce_eager=True, + model, + runner="generate", + max_model_len=max_model_len, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"audio": 1}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=64, + enforce_eager=True, ) as vllm_model: lora_request = LoRARequest("audio", 1, audio_lora_path) vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - audios=audios, - lora_request=lora_request) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=audios, + lora_request=lora_request, + ) for prompts, audios in inputs ] - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: - + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSpeechSeq2Seq) as hf_model: hf_processor = hf_model.processor eos_token_id = hf_processor.tokenizer.eos_token_id hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - audios=[audios], - eos_token_id=eos_token_id) + hf_model.generate_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs=num_logprobs, + audios=[audios], + eos_token_id=eos_token_id, + ) for prompts, audios in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(output) for output in vllm_outputs - ], + outputs_1_lst=[vllm_to_hf_output(output) for output in vllm_outputs], name_0="hf", name_1="vllm", ) @@ -118,9 +116,16 @@ def run_test( @pytest.mark.parametrize("max_model_len", [2048]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, model: str, - audio_assets: AudioTestAssets, dtype: str, max_model_len: int, - max_tokens: int, num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + model: str, + audio_assets: AudioTestAssets, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") diff --git a/tests/models/multimodal/generation/test_interleaved.py b/tests/models/multimodal/generation/test_interleaved.py index 1ef56af33a09..a773db19825e 100644 --- a/tests/models/multimodal/generation/test_interleaved.py +++ b/tests/models/multimodal/generation/test_interleaved.py @@ -28,8 +28,7 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None: give the same result. 
""" - image_cherry = convert_image_mode( - ImageAsset("cherry_blossom").pil_image, "RGB") + image_cherry = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") image_stop = convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB") images = [image_cherry, image_stop] video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays @@ -47,29 +46,30 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None: ), ] - with vllm_runner(model, - runner="generate", - dtype=dtype, - limit_mm_per_prompt={"image": 2}, - max_model_len=32768, - max_num_seqs=2, - tensor_parallel_size=1, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, + runner="generate", + dtype=dtype, + limit_mm_per_prompt={"image": 2}, + max_model_len=32768, + max_num_seqs=2, + tensor_parallel_size=1, + enforce_eager=True, + ) as vllm_model: vllm_outputs_per_case = [ - vllm_model.generate_greedy(prompts, - max_tokens, - images=images, - videos=videos) + vllm_model.generate_greedy( + prompts, max_tokens, images=images, videos=videos + ) for prompts, images, videos in inputs ] all_results = [output[0][1] for output in vllm_outputs_per_case] - outputs = [(total_str, total_str.find("assistant\n") + len("assistant\n")) - for total_str in all_results] - prompt_lengths = [prompt_len for _, prompt_len in outputs] - generated_strs = [ - total_str[prompt_len:] for total_str, prompt_len in outputs + outputs = [ + (total_str, total_str.find("assistant\n") + len("assistant\n")) + for total_str in all_results ] + prompt_lengths = [prompt_len for _, prompt_len in outputs] + generated_strs = [total_str[prompt_len:] for total_str, prompt_len in outputs] interleaved_prompt_len, noninterleaved_prompt_len = prompt_lengths interleaved_output_str, noninterleaved_output_str = generated_strs diff --git a/tests/models/multimodal/generation/test_maverick.py b/tests/models/multimodal/generation/test_maverick.py index bacc9ef94f49..2f9b09f4026c 100644 --- a/tests/models/multimodal/generation/test_maverick.py +++ b/tests/models/multimodal/generation/test_maverick.py @@ -18,13 +18,11 @@ import pytest import torch from safetensors.torch import save_file -from transformers import (AutoConfig, AutoProcessor, AutoTokenizer, - GenerationConfig) +from transformers import AutoConfig, AutoProcessor, AutoTokenizer, GenerationConfig from vllm import LLM, SamplingParams from vllm.v1.executor.abstract import Executor -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec) +from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, FullAttentionSpec from ....utils import multi_gpu_test @@ -93,8 +91,7 @@ def get_rope_layers_config(model_path: str) -> list[int]: def create_reduced_maverick_model( - original_model_name: - str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + original_model_name: str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", output_dir: str = "/tmp/reduced_maverick", text_layers: int = 4, num_experts: int = 4, @@ -118,7 +115,8 @@ def create_reduced_maverick_model( print( f"Creating reduced Maverick model with {text_layers} text layers and " - f"{vision_layers} vision layers...") + f"{vision_layers} vision layers..." + ) # Create output directory output_path = Path(output_dir) @@ -126,19 +124,23 @@ def create_reduced_maverick_model( if force_recreate: shutil.rmtree(output_path) else: - print(f"Output directory {output_dir} already exists. " - "Use --force-recreate to overwrite.") + print( + f"Output directory {output_dir} already exists. 
" + "Use --force-recreate to overwrite." + ) return str(output_path) output_path.mkdir(parents=True, exist_ok=True) try: print("Loading original model configuration...") - original_config = AutoConfig.from_pretrained(original_model_name, - trust_remote_code=True) + original_config = AutoConfig.from_pretrained( + original_model_name, trust_remote_code=True + ) print("Creating reduced configuration...") - reduced_config = create_reduced_config(original_config, text_layers, - num_experts, vision_layers) + reduced_config = create_reduced_config( + original_config, text_layers, num_experts, vision_layers + ) config_path = output_path / "config.json" with open(config_path, "w") as f: @@ -149,8 +151,7 @@ def create_reduced_maverick_model( copy_tokenizer_files(original_model_name, output_path) print("Creating reduced safetensors files...") - create_reduced_safetensors(original_config, reduced_config, - output_path) + create_reduced_safetensors(original_config, reduced_config, output_path) print("Creating preprocessor config...") create_preprocessor_config(original_config, output_path) @@ -173,9 +174,9 @@ def create_reduced_maverick_model( raise -def create_reduced_config(original_config: Any, text_layers: int, - num_experts: int, - vision_layers: int) -> dict[str, Any]: +def create_reduced_config( + original_config: Any, text_layers: int, num_experts: int, vision_layers: int +) -> dict[str, Any]: """Create a reduced configuration based on the original.""" # Convert config to dictionary @@ -185,23 +186,18 @@ def create_reduced_config(original_config: Any, text_layers: int, if "text_config" in config_dict: original_text_layers = config_dict["text_config"]["num_hidden_layers"] config_dict["text_config"]["num_hidden_layers"] = text_layers - print( - f"Reduced text layers from {original_text_layers} to {text_layers}" - ) + print(f"Reduced text layers from {original_text_layers} to {text_layers}") original_num_experts = config_dict["text_config"]["num_local_experts"] config_dict["text_config"]["num_local_experts"] = num_experts - print( - f"Reduced num experts from {original_num_experts} to {num_experts}" - ) + print(f"Reduced num experts from {original_num_experts} to {num_experts}") hidden_dim_divisor = 4 original_hidden_size = config_dict["text_config"]["hidden_size"] new_hidden_size = original_hidden_size // hidden_dim_divisor config_dict["text_config"]["hidden_size"] = new_hidden_size - print(f"Reduced hidden size from {original_hidden_size} to " - f"{new_hidden_size}") + print(f"Reduced hidden size from {original_hidden_size} to {new_hidden_size}") original_head_dim = config_dict["text_config"]["head_dim"] new_head_dim = original_head_dim // hidden_dim_divisor @@ -210,15 +206,12 @@ def create_reduced_config(original_config: Any, text_layers: int, # Reduce vision layers if "vision_config" in config_dict: - original_vision_layers = config_dict["vision_config"][ - "num_hidden_layers"] + original_vision_layers = config_dict["vision_config"]["num_hidden_layers"] config_dict["vision_config"]["num_hidden_layers"] = vision_layers - print(f"Reduced vision layers from {original_vision_layers} " - f"to {vision_layers}") + print(f"Reduced vision layers from {original_vision_layers} to {vision_layers}") # Update model name to indicate it's a reduced version - config_dict["_name_or_path"] = ( - f"reduced_maverick_{text_layers}t_{vision_layers}v") + config_dict["_name_or_path"] = f"reduced_maverick_{text_layers}t_{vision_layers}v" return config_dict @@ -227,16 +220,16 @@ def 
copy_tokenizer_files(original_model_name: str, output_path: Path) -> None: """Copy tokenizer files from the original model.""" try: - tokenizer = AutoTokenizer.from_pretrained(original_model_name, - trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + original_model_name, trust_remote_code=True + ) tokenizer.save_pretrained(output_path) print("Tokenizer files copied successfully") except Exception as e: print(f"Warning: Could not copy tokenizer files: {e}") -def create_preprocessor_config(original_config: Any, - output_path: Path) -> None: +def create_preprocessor_config(original_config: Any, output_path: Path) -> None: """Create preprocessor_config.json for multimodal model.""" # Try to load the original preprocessor config @@ -254,9 +247,9 @@ def create_preprocessor_config(original_config: Any, raise -def create_reduced_safetensors(original_config: Any, reduced_config: dict[str, - Any], - output_path: Path) -> None: +def create_reduced_safetensors( + original_config: Any, reduced_config: dict[str, Any], output_path: Path +) -> None: """Create safetensors files with weights for the reduced model.""" print("Generating synthetic weights for reduced model...") @@ -279,8 +272,7 @@ def create_reduced_safetensors(original_config: Any, reduced_config: dict[str, save_weights_to_safetensors(weights, output_path) -def create_text_model_weights( - text_config: dict[str, Any]) -> dict[str, torch.Tensor]: +def create_text_model_weights(text_config: dict[str, Any]) -> dict[str, torch.Tensor]: """Create synthetic weights for the text model with MoE structure.""" weights = {} @@ -291,19 +283,18 @@ def create_text_model_weights( intermediate_size_mlp = text_config["intermediate_size_mlp"] num_layers = text_config["num_hidden_layers"] num_attention_heads = text_config["num_attention_heads"] - num_key_value_heads = text_config.get("num_key_value_heads", - num_attention_heads) + num_key_value_heads = text_config.get("num_key_value_heads", num_attention_heads) # MoE specific parameters num_experts = text_config.get("num_local_experts") - assert (num_experts - is not None), "num_local_experts must be specified for MoE" + assert num_experts is not None, "num_local_experts must be specified for MoE" head_dim = hidden_size // num_attention_heads # Embedding layers weights["language_model.model.embed_tokens.weight"] = torch.randn( - vocab_size, hidden_size, dtype=torch.float16) + vocab_size, hidden_size, dtype=torch.float16 + ) # Transformer layers for layer_idx in range(num_layers): @@ -312,95 +303,105 @@ def create_text_model_weights( # Self-attention weights (separate q, k, v projections) weights[f"{layer_prefix}.self_attn.q_proj.weight"] = torch.randn( - hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16) + hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.k_proj.weight"] = torch.randn( - hidden_size, num_key_value_heads * head_dim, dtype=torch.bfloat16) + hidden_size, num_key_value_heads * head_dim, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.v_proj.weight"] = torch.randn( - num_key_value_heads * head_dim, hidden_size, dtype=torch.bfloat16) + num_key_value_heads * head_dim, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.o_proj.weight"] = torch.randn( - hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16) + hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16 + ) print("Self-attention weights created.") # Feed-forward weights - MoE pattern based on 
interleave_moe_layer_step # For interleave_moe_layer_step=2: layers 1,3,5,... are MoE, layers # 0,2,4,... are dense interleave_step = text_config.get("interleave_moe_layer_step", 1) - is_moe_layer = (interleave_step > 0 - and (layer_idx + 1) % interleave_step == 0) + is_moe_layer = interleave_step > 0 and (layer_idx + 1) % interleave_step == 0 if is_moe_layer: # MoE layer structure # 1. Router weights - weights[ - f"{layer_prefix}.feed_forward.router.weight"] = torch.randn( - num_experts, hidden_size, dtype=torch.float16) + weights[f"{layer_prefix}.feed_forward.router.weight"] = torch.randn( + num_experts, hidden_size, dtype=torch.float16 + ) # 2. Individual expert weights (not fused) for expert_idx in range(num_experts): - expert_prefix = ( - f"{layer_prefix}.feed_forward.experts.{expert_idx}") + expert_prefix = f"{layer_prefix}.feed_forward.experts.{expert_idx}" weights[f"{expert_prefix}.gate_proj.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{expert_prefix}.up_proj.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{expert_prefix}.down_proj.weight"] = torch.randn( - hidden_size, intermediate_size, dtype=torch.bfloat16) + hidden_size, intermediate_size, dtype=torch.bfloat16 + ) # Expert weight scales (FP8 quantization) - weights[ - f"{expert_prefix}.gate_proj.weight_scale"] = torch.ones( - intermediate_size, 1, dtype=torch.bfloat16) + weights[f"{expert_prefix}.gate_proj.weight_scale"] = torch.ones( + intermediate_size, 1, dtype=torch.bfloat16 + ) weights[f"{expert_prefix}.up_proj.weight_scale"] = torch.ones( - intermediate_size, 1, dtype=torch.bfloat16) - weights[ - f"{expert_prefix}.down_proj.weight_scale"] = torch.ones( - hidden_size, 1, dtype=torch.bfloat16) + intermediate_size, 1, dtype=torch.bfloat16 + ) + weights[f"{expert_prefix}.down_proj.weight_scale"] = torch.ones( + hidden_size, 1, dtype=torch.bfloat16 + ) # 3. 
Shared expert weights shared_expert_prefix = f"{layer_prefix}.feed_forward.shared_expert" weights[f"{shared_expert_prefix}.gate_proj.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{shared_expert_prefix}.up_proj.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{shared_expert_prefix}.down_proj.weight"] = torch.randn( - hidden_size, intermediate_size, dtype=torch.bfloat16) + hidden_size, intermediate_size, dtype=torch.bfloat16 + ) print(f"MoE feed-forward weights created for layer {layer_idx}.") else: # Dense layer structure - weights[f"{layer_prefix}.feed_forward.gate_proj.weight"] = ( - torch.randn(intermediate_size_mlp, - hidden_size, - dtype=torch.bfloat16)) - weights[f"{layer_prefix}.feed_forward.up_proj.weight"] = ( - torch.randn(intermediate_size_mlp, - hidden_size, - dtype=torch.bfloat16)) - weights[f"{layer_prefix}.feed_forward.down_proj.weight"] = ( - torch.randn(hidden_size, - intermediate_size_mlp, - dtype=torch.bfloat16)) + weights[f"{layer_prefix}.feed_forward.gate_proj.weight"] = torch.randn( + intermediate_size_mlp, hidden_size, dtype=torch.bfloat16 + ) + weights[f"{layer_prefix}.feed_forward.up_proj.weight"] = torch.randn( + intermediate_size_mlp, hidden_size, dtype=torch.bfloat16 + ) + weights[f"{layer_prefix}.feed_forward.down_proj.weight"] = torch.randn( + hidden_size, intermediate_size_mlp, dtype=torch.bfloat16 + ) print(f"Dense feed-forward weights created for layer {layer_idx}.") # Layer norms weights[f"{layer_prefix}.input_layernorm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) - weights[ - f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) + weights[f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones( + hidden_size, dtype=torch.bfloat16 + ) print("Layer norms created.") # Final layer norm and output projection weights["language_model.model.norm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights["language_model.lm_head.weight"] = torch.randn( - vocab_size, hidden_size, dtype=torch.bfloat16) + vocab_size, hidden_size, dtype=torch.bfloat16 + ) return weights def create_vision_model_weights( - vision_config: dict[str, Any]) -> dict[str, torch.Tensor]: + vision_config: dict[str, Any], +) -> dict[str, torch.Tensor]: """Create synthetic weights for the vision model.""" weights = {} @@ -414,47 +415,62 @@ def create_vision_model_weights( layer_prefix = f"vision_model.model.layers.{layer_idx}" weights[f"{layer_prefix}.self_attn.q_proj.weight"] = torch.randn( - hidden_size, hidden_size, dtype=torch.bfloat16) + hidden_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.q_proj.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.k_proj.weight"] = torch.randn( - hidden_size, hidden_size, dtype=torch.bfloat16) + hidden_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.k_proj.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.v_proj.weight"] = torch.randn( - hidden_size, hidden_size, dtype=torch.bfloat16) + hidden_size, hidden_size, dtype=torch.bfloat16 + ) 
weights[f"{layer_prefix}.self_attn.v_proj.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.o_proj.weight"] = torch.randn( - hidden_size, hidden_size, dtype=torch.bfloat16) + hidden_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.self_attn.o_proj.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.mlp.fc1.weight"] = torch.randn( - intermediate_size, hidden_size, dtype=torch.bfloat16) + intermediate_size, hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.mlp.fc1.bias"] = torch.zeros( - intermediate_size, dtype=torch.bfloat16) + intermediate_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.mlp.fc2.weight"] = torch.randn( - hidden_size, intermediate_size, dtype=torch.bfloat16) + hidden_size, intermediate_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.mlp.fc2.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.input_layernorm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.input_layernorm.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) - weights[ - f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) + weights[f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones( + hidden_size, dtype=torch.bfloat16 + ) weights[f"{layer_prefix}.post_attention_layernorm.bias"] = torch.zeros( - hidden_size, dtype=torch.bfloat16) + hidden_size, dtype=torch.bfloat16 + ) return weights def create_shared_weights( - text_config: dict[str, Any], - vision_config: dict[str, Any]) -> dict[str, torch.Tensor]: + text_config: dict[str, Any], vision_config: dict[str, Any] +) -> dict[str, torch.Tensor]: """Create weights for shared components (vision-language connector)""" weights = {} @@ -464,13 +480,15 @@ def create_shared_weights( # Vision-language connector (projects vision features to text space) weights["multi_modal_projector.linear_1.weight"] = torch.randn( - text_hidden_size, projector_input_dim, dtype=torch.bfloat16) + text_hidden_size, projector_input_dim, dtype=torch.bfloat16 + ) return weights -def save_weights_to_safetensors(weights: dict[str, torch.Tensor], - output_path: Path) -> None: +def save_weights_to_safetensors( + weights: dict[str, torch.Tensor], output_path: Path +) -> None: """Save weights to safetensors files and create index.""" # Determine how to shard the weights @@ -507,18 +525,18 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor], else: # Multiple shards for i, shard in enumerate(shards): - filename = f"model-{i+1:05d}-of-{len(shards):05d}.safetensors" + filename = f"model-{i + 1:05d}-of-{len(shards):05d}.safetensors" save_file(shard, output_path / filename) for name in shard: weight_map[name] = filename - print(f"Saved shard {i+1}/{len(shards)}: {filename}") + print(f"Saved shard {i + 1}/{len(shards)}: {filename}") # Create index file index_data = { "metadata": { - "total_size": - sum(tensor.numel() * tensor.element_size() - for tensor in weights.values()) + "total_size": sum( + tensor.numel() * tensor.element_size() for tensor in weights.values() + ) }, "weight_map": weight_map, } @@ -528,8 +546,9 @@ def save_weights_to_safetensors(weights: dict[str, torch.Tensor], json.dump(index_data, f, indent=2) print(f"Created index 
file: {index_path}") - print(f"Total model size: " - f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB") + print( + f"Total model size: {index_data['metadata']['total_size'] / (1024**3):.2f} GB" + ) def check_attention_spec_interleaved_rope( @@ -540,8 +559,7 @@ def check_attention_spec_interleaved_rope( ): """Check that the attention spec is correct.""" assert isinstance(llm.llm_engine.model_executor, Executor) - kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs( - ) + kv_cache_specs_per_rank = llm.llm_engine.model_executor.get_kv_cache_specs() for rank in range(num_ranks): kv_cache_specs = kv_cache_specs_per_rank[rank] assert len(kv_cache_specs.keys()) == num_attention_layers @@ -551,16 +569,14 @@ def check_attention_spec_interleaved_rope( else: expected_spec = ChunkedLocalAttentionSpec assert isinstance( - kv_cache_specs[ - f"language_model.model.layers.{i}.self_attn.attn"], - expected_spec) + kv_cache_specs[f"language_model.model.layers.{i}.self_attn.attn"], + expected_spec, + ) def run_reduced_model(llm: LLM, should_profile: bool = False) -> None: """Test the created reduced model with vLLM.""" - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - max_tokens=50) + sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50) if should_profile: llm.start_profile() @@ -571,15 +587,15 @@ def run_reduced_model(llm: LLM, should_profile: bool = False) -> None: print("Test generation successful!") for output in outputs: print(f"Prompt: {output.prompt}") - print(f"Output: " - f"{output.outputs[0].text}") + print(f"Output: {output.outputs[0].text}") print("-" * 40) @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( "original_model_name,text_layers,num_experts,vision_layers,", - [("meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", 4, 4, 2)]) + [("meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", 4, 4, 2)], +) @pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("tp,ep", [(2, True)]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -640,7 +656,8 @@ def main(): import argparse parser = argparse.ArgumentParser( - description="Create a reduced-layer Maverick model") + description="Create a reduced-layer Maverick model" + ) parser.add_argument( "--output-dir", default="/tmp/reduced_maverick", @@ -652,10 +669,7 @@ def main(): default=4, help="Number of text transformer layers", ) - parser.add_argument("--num-experts", - type=int, - default=4, - help="Number of experts") + parser.add_argument("--num-experts", type=int, default=4, help="Number of experts") parser.add_argument( "--vision-layers", type=int, @@ -667,12 +681,12 @@ def main(): action="store_true", help="Force recreation if output directory exists", ) - parser.add_argument("--test", - action="store_true", - help="Test the created model with vLLM") - parser.add_argument("--profile", - action="store_true", - help="Profile the created model with vLLM") + parser.add_argument( + "--test", action="store_true", help="Test the created model with vLLM" + ) + parser.add_argument( + "--profile", action="store_true", help="Profile the created model with vLLM" + ) parser.add_argument( "--test-original", action="store_true", @@ -687,16 +701,18 @@ def main(): args = parser.parse_args() if args.test: - test_dummy_maverick(original_model_name=args.original_model, - output_dir=args.output_dir, - text_layers=args.text_layers, - num_experts=args.num_experts, - vision_layers=args.vision_layers, - force_recreate=args.force_recreate, 
- tp=2, - ep=True, - enforce_eager=True, - profile=args.profile) + test_dummy_maverick( + original_model_name=args.original_model, + output_dir=args.output_dir, + text_layers=args.text_layers, + num_experts=args.num_experts, + vision_layers=args.vision_layers, + force_recreate=args.force_recreate, + tp=2, + ep=True, + enforce_eager=True, + profile=args.profile, + ) if args.test_original: run_maverick_serving(args.original_model) diff --git a/tests/models/multimodal/generation/test_phi4_multimodal.py b/tests/models/multimodal/generation/test_phi4_multimodal.py index db8984d8656f..132c69285c5c 100644 --- a/tests/models/multimodal/generation/test_phi4_multimodal.py +++ b/tests/models/multimodal/generation/test_phi4_multimodal.py @@ -14,26 +14,35 @@ from vllm.multimodal.image import rescale_image_size from vllm.platforms import current_platform -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, - PromptImageInput, VllmRunner) +from ....conftest import ( + IMAGE_ASSETS, + HfRunner, + PromptAudioInput, + PromptImageInput, + VllmRunner, +) from ....utils import large_gpu_test from ...utils import check_logprobs_close -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 - "cherry_blossom": - "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 -}) -HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 + "cherry_blossom": "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 + } +) +HF_MULTIIMAGE_IMAGE_PROMPT = ( + "<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +) -model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct", - revision="refs/pr/70") +model_path = snapshot_download( + "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70" +) # Since the vision-lora and speech-lora co-exist with the base model, # we have to manually specify the path of the lora weights. vision_lora_path = os.path.join(model_path, "vision-lora") -speech_question = os.path.join(model_path, "examples", - "what_is_shown_in_this_image.wav") +speech_question = os.path.join( + model_path, "examples", "what_is_shown_in_this_image.wav" +) models = [model_path] target_dtype = "half" @@ -48,8 +57,7 @@ def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - inputs: Sequence[tuple[list[str], PromptImageInput, - Optional[PromptAudioInput]]], + inputs: Sequence[tuple[list[str], PromptImageInput, Optional[PromptAudioInput]]], model: str, *, max_model_len: int, @@ -75,28 +83,30 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). 
# max_model_len should be greater than image_feature_size with vllm_runner( - model, - task="generate", - max_model_len=max_model_len, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=320, - gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI - enforce_eager=True, - trust_remote_code=False, + model, + task="generate", + max_model_len=max_model_len, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=320, + gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI + enforce_eager=True, + trust_remote_code=False, ) as vllm_model: lora_request = LoRARequest("vision", 1, vision_lora_path) vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - lora_request=lora_request) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + lora_request=lora_request, + ) for prompts, images, audios in inputs ] @@ -108,17 +118,18 @@ def run_test( hf_processor = hf_model.processor eos_token_id = hf_processor.tokenizer.eos_token_id hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - eos_token_id=eos_token_id) + hf_model.generate_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + eos_token_id=eos_token_id, + ) for prompts, images, audios in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, @@ -145,16 +156,27 @@ def run_test( @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - None, - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + ( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + None, + ) + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] run_test( hf_runner, @@ -189,16 +211,26 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @pytest.mark.parametrize("max_model_len", [25600]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, - size_factors, dtype: str, max_model_len: int, - max_tokens: int, num_logprobs: int) -> None: +def test_multi_images_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + 
dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] inputs_per_case = [ ( [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors], + [ + [rescale_image_size(image, factor) for image in images] + for factor in size_factors + ], None, ), ] @@ -222,10 +254,15 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, - max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: - +def test_vision_speech_models( + hf_runner, + vllm_runner, + model, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: # use the example speech question so that the model outputs are reasonable audio = librosa.load(speech_question, sr=16000) image = ImageAsset("cherry_blossom").pil_image.convert("RGB") diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index 77e2b90dd5e9..e69d44c6a131 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -17,31 +17,39 @@ from vllm.multimodal.image import convert_image_mode, rescale_image_size from vllm.platforms import current_platform -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, - PromptImageInput, VllmRunner) +from ....conftest import ( + IMAGE_ASSETS, + HfRunner, + PromptAudioInput, + PromptImageInput, + VllmRunner, +) from ....utils import large_gpu_test from ...utils import check_logprobs_close -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 - "cherry_blossom": - "<|user|>\n<|image_1|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 -}) -HF_MULTIIMAGE_IMAGE_PROMPT = "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "<|user|>\n<|image_1|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 + "cherry_blossom": "<|user|>\n<|image_1|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 + } +) +HF_MULTIIMAGE_IMAGE_PROMPT = ( + "<|user|>\n<|image_1|>\n<|image_2|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 +) model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct") # Since the vision-lora and speech-lora co-exist with the base model, # we have to manually specify the path of the lora weights. 
vision_lora_path = os.path.join(model_path, "vision-lora") -speech_question = os.path.join(model_path, "examples", - "what_is_shown_in_this_image.wav") +speech_question = os.path.join( + model_path, "examples", "what_is_shown_in_this_image.wav" +) models = [model_path] -def vllm_to_hf_output(vllm_output: tuple[list[int], str, - Optional[SampleLogprobs]], - model: str): +def vllm_to_hf_output( + vllm_output: tuple[list[int], str, Optional[SampleLogprobs]], model: str +): """Sanitize vllm output to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -71,8 +79,7 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str, def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - inputs: Sequence[tuple[list[str], PromptImageInput, - Optional[PromptAudioInput]]], + inputs: Sequence[tuple[list[str], PromptImageInput, Optional[PromptAudioInput]]], model: str, *, max_model_len: int, @@ -98,27 +105,29 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). # max_model_len should be greater than image_feature_size with vllm_runner( - model, - runner="generate", - max_model_len=max_model_len, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=320, - gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI - enforce_eager=True, + model, + runner="generate", + max_model_len=max_model_len, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enable_lora=True, + max_lora_rank=320, + gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI + enforce_eager=True, ) as vllm_model: lora_request = LoRARequest("vision", 1, vision_lora_path) vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - lora_request=lora_request) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + audios=audios, + lora_request=lora_request, + ) for prompts, images, audios in inputs ] @@ -127,42 +136,36 @@ def run_test( pytest.skip("HF impl is not compatible with current transformers") hf_model_kwargs = {"_attn_implementation": "sdpa"} - with hf_runner(model, dtype=dtype, - model_kwargs=hf_model_kwargs) as hf_model: - + with hf_runner(model, dtype=dtype, model_kwargs=hf_model_kwargs) as hf_model: hf_processor = hf_model.processor eos_token_id = hf_processor.tokenizer.eos_token_id - def patch_hf_processor(*args, - text="", - images=None, - audio=None, - sampling_rate=None, - **kwargs): + def patch_hf_processor( + *args, text="", images=None, audio=None, sampling_rate=None, **kwargs + ): audios = None if audio is not None and sampling_rate is not None: audios = [(audio, sampling_rate)] - return hf_processor(*args, - text=text, - images=images, - audios=audios, - **kwargs) + return hf_processor( + *args, text=text, images=images, audios=audios, **kwargs + ) hf_model.processor = patch_hf_processor hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - eos_token_id=eos_token_id, - num_logits_to_keep=0) + hf_model.generate_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images, + 
audios=audios, + eos_token_id=eos_token_id, + num_logits_to_keep=0, + ) for prompts, images, audios in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, @@ -189,16 +192,27 @@ def patch_hf_processor(*args, @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, - dtype: str, max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: +def test_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - None, - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + ( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + None, + ) + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] run_test( hf_runner, @@ -233,16 +247,26 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @pytest.mark.parametrize("max_model_len", [25600]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, - size_factors, dtype: str, max_model_len: int, - max_tokens: int, num_logprobs: int) -> None: +def test_multi_images_models( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: images = [asset.pil_image for asset in image_assets] inputs_per_case = [ ( [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors], + [ + [rescale_image_size(image, factor) for image in images] + for factor in size_factors + ], None, ), ] @@ -266,10 +290,15 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, @pytest.mark.parametrize("max_model_len", [12800]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str, - max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: - +def test_vision_speech_models( + hf_runner, + vllm_runner, + model, + dtype: str, + max_model_len: int, + max_tokens: int, + num_logprobs: int, +) -> None: # use the example speech question so that the model outputs are reasonable audio = librosa.load(speech_question, sr=None) image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB") diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index 715b08ef90e5..db0effdaf666 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -37,33 +37,33 @@ def _create_msg_format(urls: list[str]) -> list[dict[str, Any]]: - return [{ - "role": - "user", - "content": [{ - "type": "text", - "text": PROMPT, - }] + [{ - "type": "image_url", - "image_url": { - "url": url - } - } for url in urls], - }] + return [ + { + 
"role": "user", + "content": [ + { + "type": "text", + "text": PROMPT, + } + ] + + [{"type": "image_url", "image_url": {"url": url}} for url in urls], + } + ] def _create_msg_format_hf(urls: list[str]) -> list[dict[str, Any]]: - return [{ - "role": - "user", - "content": [{ - "type": "text", - "content": PROMPT, - }, *({ - "type": "image", - "image": download_image(url) - } for url in urls)], - }] + return [ + { + "role": "user", + "content": [ + { + "type": "text", + "content": PROMPT, + }, + *({"type": "image", "image": download_image(url)} for url in urls), + ], + } + ] def _create_engine_inputs(urls: list[str]) -> TokensPrompt: @@ -125,11 +125,17 @@ def _dump_outputs_w_logprobs( outputs: OutputsLogprobs, filename: "StrPath", ) -> None: - json_data = [(tokens, text, [{ - k: asdict(v) - for k, v in token_logprobs.items() - } for token_logprobs in (logprobs or [])]) - for tokens, text, logprobs in outputs] + json_data = [ + ( + tokens, + text, + [ + {k: asdict(v) for k, v in token_logprobs.items()} + for token_logprobs in (logprobs or []) + ], + ) + for tokens, text, logprobs in outputs + ] with open(filename, "w") as f: json.dump(json_data, f) @@ -139,28 +145,35 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: with open(filename, "rb") as f: json_data = json.load(f) - return [(tokens, text, [{ - int(k): Logprob(**v) - for k, v in token_logprobs.items() - } for token_logprobs in logprobs]) for tokens, text, logprobs in json_data] + return [ + ( + tokens, + text, + [ + {int(k): Logprob(**v) for k, v in token_logprobs.items()} + for token_logprobs in logprobs + ], + ) + for tokens, text, logprobs in json_data + ] @large_gpu_test(min_gb=80) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN) @pytest.mark.parametrize("dtype", ["bfloat16"]) -def test_chat(vllm_runner, max_model_len: int, model: str, dtype: str, - local_asset_server) -> None: - EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs( - FIXTURE_LOGPROBS_CHAT[model]) +def test_chat( + vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server +) -> None: + EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model]) with vllm_runner( - model, - dtype=dtype, - tokenizer_mode="mistral", - load_format="mistral", - config_format="mistral", - max_model_len=max_model_len, - limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, + model, + dtype=dtype, + tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", + max_model_len=max_model_len, + limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, ) as vllm_model: outputs = [] @@ -180,7 +193,9 @@ def test_chat(vllm_runner, max_model_len: int, model: str, dtype: str, for i in range(len(logprobs)): assert logprobs[i][-1] is None logprobs[i] = logprobs[i][:-1] - check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS, - outputs_1_lst=logprobs, - name_0="h100_ref", - name_1="output") + check_logprobs_close( + outputs_0_lst=EXPECTED_CHAT_LOGPROBS, + outputs_1_lst=logprobs, + name_0="h100_ref", + name_1="output", + ) diff --git a/tests/models/multimodal/generation/test_qwen2_5_vl.py b/tests/models/multimodal/generation/test_qwen2_5_vl.py index 1dc3188d60bd..1a7d854352ae 100644 --- a/tests/models/multimodal/generation/test_qwen2_5_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_5_vl.py @@ -17,14 +17,15 @@ def qwen2_5_vl_chat_template(*query): return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501 
-VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ - "baby_reading": - qwen2_5_vl_chat_template( - VIDEO_PLACEHOLDER, - "Describe this video with a short sentence ", - "(no more than 20 words)", - ), -}) +VIDEO_PROMPTS = VIDEO_ASSETS.prompts( + { + "baby_reading": qwen2_5_vl_chat_template( + VIDEO_PLACEHOLDER, + "Describe this video with a short sentence ", + "(no more than 20 words)", + ), + } +) @pytest.mark.core_model @@ -33,10 +34,15 @@ def qwen2_5_vl_chat_template(*query): @pytest.mark.parametrize("num_frames", [16]) @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) -def test_qwen2_5_vl_evs_functionality(vllm_runner, video_assets, model, - video_pruning_rate: float, - num_frames: int, dtype: str, - max_tokens: int) -> None: +def test_qwen2_5_vl_evs_functionality( + vllm_runner, + video_assets, + model, + video_pruning_rate: float, + num_frames: int, + dtype: str, + max_tokens: int, +) -> None: """Test EVS (Efficient Video Sampling) functionality with different pruning rates. """ @@ -51,19 +57,18 @@ def test_qwen2_5_vl_evs_functionality(vllm_runner, video_assets, model, videos = [sampled_vids[0]] # Initialize model with EVS configuration - with vllm_runner(model, - runner="generate", - max_model_len=4000, - max_num_seqs=1, - dtype=dtype, - limit_mm_per_prompt={"video": 1}, - tensor_parallel_size=1, - video_pruning_rate=video_pruning_rate) as vllm_model: - + with vllm_runner( + model, + runner="generate", + max_model_len=4000, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"video": 1}, + tensor_parallel_size=1, + video_pruning_rate=video_pruning_rate, + ) as vllm_model: # Generate output - this should not crash - outputs = vllm_model.generate_greedy(prompts, - max_tokens, - videos=videos) + outputs = vllm_model.generate_greedy(prompts, max_tokens, videos=videos) # Basic validation that we got a response assert len(outputs) == 1 @@ -83,10 +88,15 @@ def test_qwen2_5_vl_evs_functionality(vllm_runner, video_assets, model, @pytest.mark.parametrize("num_frames", [16]) @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) -def test_qwen2_5_vl_evs_batched_videos(vllm_runner, video_assets, model, - video_pruning_rate: float, - num_frames: int, dtype: str, - max_tokens: int) -> None: +def test_qwen2_5_vl_evs_batched_videos( + vllm_runner, + video_assets, + model, + video_pruning_rate: float, + num_frames: int, + dtype: str, + max_tokens: int, +) -> None: """Test EVS functionality with batched videos. 
This test validates that: @@ -102,23 +112,21 @@ def test_qwen2_5_vl_evs_batched_videos(vllm_runner, video_assets, model, # Test batched videos prompts = [VIDEO_PROMPTS[0], VIDEO_PROMPTS[0]] - videos = [sampled_vids[0], - sampled_vids[0]] # Use same video twice for testing + videos = [sampled_vids[0], sampled_vids[0]] # Use same video twice for testing # Initialize model with EVS configuration - with vllm_runner(model, - runner="generate", - max_model_len=4000, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"video": 2}, - tensor_parallel_size=1, - video_pruning_rate=video_pruning_rate) as vllm_model: - + with vllm_runner( + model, + runner="generate", + max_model_len=4000, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"video": 2}, + tensor_parallel_size=1, + video_pruning_rate=video_pruning_rate, + ) as vllm_model: # Generate output - this should not crash - outputs = vllm_model.generate_greedy(prompts, - max_tokens, - videos=videos) + outputs = vllm_model.generate_greedy(prompts, max_tokens, videos=videos) # Basic validation that we got responses for both videos assert len(outputs) == 2 diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index c8a3513ac7ad..a8f0ba870185 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -11,8 +11,13 @@ from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import rescale_video_size, sample_frames_from_video -from ....conftest import (IMAGE_ASSETS, VIDEO_ASSETS, PromptImageInput, - PromptVideoInput, VllmRunner) +from ....conftest import ( + IMAGE_ASSETS, + VIDEO_ASSETS, + PromptImageInput, + PromptVideoInput, + VllmRunner, +) from ...utils import check_logprobs_close @@ -34,28 +39,29 @@ def qwen2_vl_chat_template(*query): return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501 -IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - qwen2_vl_chat_template( - IMAGE_PLACEHOLDER, - "What is the biggest text's content in this image?", - ), - "cherry_blossom": - qwen2_vl_chat_template( - IMAGE_PLACEHOLDER, - "What is the season shown in this image? ", - "Reply with a short sentence (no more than 20 words)", - ), -}) - -VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ - "baby_reading": - qwen2_vl_chat_template( - VIDEO_PLACEHOLDER, - "Describe this video with a short sentence ", - "(no more than 20 words)", - ), -}) +IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": qwen2_vl_chat_template( + IMAGE_PLACEHOLDER, + "What is the biggest text's content in this image?", + ), + "cherry_blossom": qwen2_vl_chat_template( + IMAGE_PLACEHOLDER, + "What is the season shown in this image? 
", + "Reply with a short sentence (no more than 20 words)", + ), + } +) + +VIDEO_PROMPTS = VIDEO_ASSETS.prompts( + { + "baby_reading": qwen2_vl_chat_template( + VIDEO_PLACEHOLDER, + "Describe this video with a short sentence ", + "(no more than 20 words)", + ), + } +) MULTIIMAGE_PROMPT = qwen2_vl_chat_template( IMAGE_PLACEHOLDER, @@ -77,17 +83,19 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict): def batch_make_image_embeddings( - image_batches: list[Union[Image.Image, list[Image.Image]]], processor, - llm: VllmRunner) -> list[Qwen2VLPromptImageEmbeddingInput]: + image_batches: list[Union[Image.Image, list[Image.Image]]], + processor, + llm: VllmRunner, +) -> list[Qwen2VLPromptImageEmbeddingInput]: """batched image embeddings for Qwen2-VL - This will infer all images' embeddings in a single batch, + This will infer all images' embeddings in a single batch, and split the result according to input batches. image_batches: - Single-image batches: `list[Image.Image]` - Multiple-image batches: `list[list[Image.Image]]]` - + returns: `list[Qwen2VLPromptImageEmbeddingInput]` """ @@ -108,9 +116,9 @@ def batch_make_image_embeddings( # image to pixel values image_processor = processor.image_processor - preprocess_result = image_processor \ - .preprocess(images=images, return_tensors="pt") \ - .data + preprocess_result = image_processor.preprocess( + images=images, return_tensors="pt" + ).data pixel_values = preprocess_result["pixel_values"] image_grid_thw = preprocess_result["image_grid_thw"] @@ -119,12 +127,13 @@ def get_image_embeds(model): with torch.no_grad(): visual = model.visual - pixel_values_on_device = pixel_values.to(visual.device, - dtype=visual.dtype) - image_grid_thw_on_device = image_grid_thw.to(visual.device, - dtype=torch.int64) - return visual(pixel_values_on_device, - grid_thw=image_grid_thw_on_device).cpu() + pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype) + image_grid_thw_on_device = image_grid_thw.to( + visual.device, dtype=torch.int64 + ) + return visual( + pixel_values_on_device, grid_thw=image_grid_thw_on_device + ).cpu() image_embeds = torch.concat(llm.apply_model(get_image_embeds)) @@ -137,16 +146,21 @@ def get_image_embeds(model): merge_size = image_processor.merge_size cur_batch_embed_len = sum( grid_thw.prod(-1) // merge_size // merge_size - for grid_thw in image_grid_thw[image_counter:image_counter + - cur_batch_image_count]) + for grid_thw in image_grid_thw[ + image_counter : image_counter + cur_batch_image_count + ] + ) - result.append({ - "image_embeds": - image_embeds[embed_counter:embed_counter + cur_batch_embed_len], - "image_grid_thw": - image_grid_thw[image_counter:image_counter + - cur_batch_image_count], - }) + result.append( + { + "image_embeds": image_embeds[ + embed_counter : embed_counter + cur_batch_embed_len + ], + "image_grid_thw": image_grid_thw[ + image_counter : image_counter + cur_batch_image_count + ], + } + ) embed_counter += cur_batch_embed_len image_counter += cur_batch_image_count @@ -160,13 +174,13 @@ def get_image_embeds(model): def batch_make_video_embeddings( - video_batches: PromptVideoInput, processor, - llm: VllmRunner) -> list[Qwen2VLPromptVideoEmbeddingInput]: + video_batches: PromptVideoInput, processor, llm: VllmRunner +) -> list[Qwen2VLPromptVideoEmbeddingInput]: """batched video embeddings for Qwen2-VL A NDArray represents a single video's all frames. 
- This will infer all videos' embeddings in a single batch, + This will infer all videos' embeddings in a single batch, and split the result according to input batches. video_batches: @@ -191,9 +205,9 @@ def batch_make_video_embeddings( # video to pixel values image_processor = processor.image_processor - preprocess_result = image_processor \ - .preprocess(images=None, videos=videos, return_tensors="pt") \ - .data + preprocess_result = image_processor.preprocess( + images=None, videos=videos, return_tensors="pt" + ).data pixel_values = preprocess_result["pixel_values_videos"] video_grid_thw = preprocess_result["video_grid_thw"] @@ -202,12 +216,13 @@ def get_image_embeds(model): with torch.no_grad(): visual = model.visual - pixel_values_on_device = pixel_values.to(visual.device, - dtype=visual.dtype) - video_grid_thw_on_device = video_grid_thw.to(visual.device, - dtype=torch.int64) - return visual(pixel_values_on_device, - grid_thw=video_grid_thw_on_device).cpu() + pixel_values_on_device = pixel_values.to(visual.device, dtype=visual.dtype) + video_grid_thw_on_device = video_grid_thw.to( + visual.device, dtype=torch.int64 + ) + return visual( + pixel_values_on_device, grid_thw=video_grid_thw_on_device + ).cpu() video_embeds = torch.concat(llm.apply_model(get_image_embeds)) @@ -220,16 +235,21 @@ def get_image_embeds(model): merge_size = image_processor.merge_size cur_batch_embed_len = sum( grid_thw.prod(-1) // merge_size // merge_size - for grid_thw in video_grid_thw[video_counter:video_counter + - cur_batch_video_count]) + for grid_thw in video_grid_thw[ + video_counter : video_counter + cur_batch_video_count + ] + ) - result.append({ - "video_embeds": - video_embeds[embed_counter:embed_counter + cur_batch_embed_len], - "video_grid_thw": - video_grid_thw[video_counter:video_counter + - cur_batch_video_count], - }) + result.append( + { + "video_embeds": video_embeds[ + embed_counter : embed_counter + cur_batch_embed_len + ], + "video_grid_thw": video_grid_thw[ + video_counter : video_counter + cur_batch_video_count + ], + } + ) embed_counter += cur_batch_embed_len video_counter += cur_batch_video_count @@ -263,25 +283,24 @@ def run_embedding_input_test( # max_model_len should be greater than image_feature_size with vllm_runner( - model, - runner="generate", - max_model_len=4000, - max_num_seqs=3, - dtype=dtype, - limit_mm_per_prompt={ - "image": mm_limit, - "video": mm_limit - }, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - default_torch_num_threads=1, + model, + runner="generate", + max_model_len=4000, + max_num_seqs=3, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit, "video": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + default_torch_num_threads=1, ) as vllm_model: outputs_per_case_for_original_input = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images or None, - videos=videos or None) + vllm_model.generate_greedy_logprobs( + prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images or None, + videos=videos or None, + ) for prompts, images, videos in inputs ] @@ -290,17 +309,19 @@ def run_embedding_input_test( prompts, max_tokens, num_logprobs=num_logprobs, - images=batch_make_image_embeddings( - images, processor, vllm_model) if images else None, - videos=batch_make_video_embeddings( - videos, processor, vllm_model) if videos else None) + images=batch_make_image_embeddings(images, 
processor, vllm_model)
+                if images
+                else None,
+                videos=batch_make_video_embeddings(videos, processor, vllm_model)
+                if videos
+                else None,
+            )
             for prompts, images, videos in inputs
         ]
 
-    for outputs_for_original_input, \
-        outputs_for_embeddings_input \
-        in zip(outputs_per_case_for_original_input,
-               outputs_per_case_for_embeddings_input):
+    for outputs_for_original_input, outputs_for_embeddings_input in zip(
+        outputs_per_case_for_original_input, outputs_per_case_for_embeddings_input
+    ):
         check_logprobs_close(
             outputs_0_lst=outputs_for_original_input,
             outputs_1_lst=outputs_for_embeddings_input,
@@ -325,17 +346,26 @@ def run_embedding_input_test(
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
-def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
-                                         size_factors, dtype, max_tokens,
-                                         num_logprobs, monkeypatch) -> None:
+def test_qwen2_vl_image_embeddings_input(
+    vllm_runner,
+    image_assets,
+    model,
+    size_factors,
+    dtype,
+    max_tokens,
+    num_logprobs,
+    monkeypatch,
+) -> None:
     images = [asset.pil_image for asset in image_assets]
-    inputs_per_case: list[tuple[
-        list[str], PromptImageInput, PromptVideoInput]] = [(
+    inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [
+        (
             [prompt for _ in size_factors],
             [rescale_image_size(image, factor) for factor in size_factors],
             [],
-        ) for image, prompt in zip(images, IMAGE_PROMPTS)]
+        )
+        for image, prompt in zip(images, IMAGE_PROMPTS)
+    ]
 
     run_embedding_input_test(
         vllm_runner,
@@ -366,21 +396,27 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
-def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets,
-                                                  model, size_factors,
-                                                  dtype: str, max_tokens: int,
-                                                  num_logprobs: int) -> None:
+def test_qwen2_vl_multiple_image_embeddings_input(
+    vllm_runner,
+    image_assets,
+    model,
+    size_factors,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
     images = [asset.pil_image for asset in image_assets]
-    inputs_per_case: list[tuple[list[str], PromptImageInput,
-                                PromptVideoInput]] = [(
-                                    [MULTIIMAGE_PROMPT for _ in size_factors],
-                                    [[
-                                        rescale_image_size(image, factor)
-                                        for image in images
-                                    ] for factor in size_factors],
-                                    [],
-                                )]
+    inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [
+        (
+            [MULTIIMAGE_PROMPT for _ in size_factors],
+            [
+                [rescale_image_size(image, factor) for image in images]
+                for factor in size_factors
+            ],
+            [],
+        )
+    ]
 
     run_embedding_input_test(
         vllm_runner,
@@ -410,22 +446,29 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets,
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
-def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
-                                         size_factors, dtype: str,
-                                         max_tokens: int,
-                                         num_logprobs: int) -> None:
+def test_qwen2_vl_video_embeddings_input(
+    vllm_runner,
+    video_assets,
+    model,
+    size_factors,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
     num_frames = 4
     sampled_vids = [
         sample_frames_from_video(asset.np_ndarrays, num_frames)
        for asset in video_assets
    ]
-    inputs_per_case: list[tuple[
-        list[str], PromptImageInput, PromptVideoInput]] = [(
+    inputs_per_case: list[tuple[list[str], PromptImageInput, PromptVideoInput]] = [
+        (
             [prompt for _ in size_factors],
             [],
             [rescale_video_size(video, factor) for factor in size_factors],
-        ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)]
+        )
+        for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)
+    ]
 
     run_embedding_input_test(
         vllm_runner,
diff --git a/tests/models/multimodal/generation/test_ultravox.py b/tests/models/multimodal/generation/test_ultravox.py
index e7e7bd3154a1..6bfec6c2c8d3 100644
--- a/tests/models/multimodal/generation/test_ultravox.py
+++ b/tests/models/multimodal/generation/test_ultravox.py
@@ -15,12 +15,12 @@
 
 MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 
-AUDIO_PROMPTS = AUDIO_ASSETS.prompts({
-    "mary_had_lamb":
-    "Transcribe this into English.",
-    "winning_call":
-    "What is happening in this audio clip?",
-})
+AUDIO_PROMPTS = AUDIO_ASSETS.prompts(
+    {
+        "mary_had_lamb": "Transcribe this into English.",
+        "winning_call": "What is happening in this audio clip?",
+    }
+)
 
 MULTI_AUDIO_PROMPT = "Describe each of the audios above."
 
@@ -33,7 +33,7 @@
     "enable_chunked_prefill": True,
     "max_num_seqs": 2,
     # Use a very small limit to exercise chunked prefill.
-    "max_num_batched_tokens": 16
+    "max_num_batched_tokens": 16,
 }
 
 
@@ -43,27 +43,33 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
     for key, value in params_kwargs.items():
         if isinstance(value, bool):
             if value:
-                args.append(f"--{key.replace('_','-')}")
+                args.append(f"--{key.replace('_', '-')}")
         else:
-            args.append(f"--{key.replace('_','-')}={value}")
+            args.append(f"--{key.replace('_', '-')}={value}")
     return args
 
 
-@pytest.fixture(params=[
-    pytest.param({}, marks=pytest.mark.cpu_model),
-    pytest.param(CHUNKED_PREFILL_KWARGS),
-])
+@pytest.fixture(
+    params=[
+        pytest.param({}, marks=pytest.mark.cpu_model),
+        pytest.param(CHUNKED_PREFILL_KWARGS),
+    ]
+)
 def server(request, audio_assets: AudioTestAssets):
     args = [
-        "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "4096",
+        "--enforce-eager",
         "--limit-mm-per-prompt",
-        json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
+        json.dumps({"audio": len(audio_assets)}),
+        "--trust-remote-code",
     ] + params_kwargs_to_cli_args(request.param)
 
-    with RemoteOpenAIServer(MODEL_NAME,
-                            args,
-                            env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
-                                      "30"}) as remote_server:
+    with RemoteOpenAIServer(
+        MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"}
+    ) as remote_server:
         yield remote_server
 
 
@@ -77,12 +83,11 @@ def _get_prompt(audio_count, question, placeholder):
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     placeholder = f"{placeholder}\n" * audio_count
 
-    return tokenizer.apply_chat_template([{
-        'role': 'user',
-        'content': f"{placeholder}{question}"
-    }],
-                                         tokenize=False,
-                                         add_generation_prompt=True)
+    return tokenizer.apply_chat_template(
+        [{"role": "user", "content": f"{placeholder}{question}"}],
+        tokenize=False,
+        add_generation_prompt=True,
+    )
 
 
 def run_multi_audio_test(
@@ -99,19 +104,21 @@ def run_multi_audio_test(
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
 
-    with vllm_runner(model,
-                     dtype=dtype,
-                     enforce_eager=True,
-                     limit_mm_per_prompt={
-                         "audio":
-                         max((len(audio) for _, audio in prompts_and_audios))
-                     },
-                     **kwargs) as vllm_model:
+    with vllm_runner(
+        model,
+        dtype=dtype,
+        enforce_eager=True,
+        limit_mm_per_prompt={
+            "audio": max((len(audio) for _, audio in prompts_and_audios))
+        },
+        **kwargs,
+    ) as vllm_model:
         vllm_outputs =
vllm_model.generate_greedy_logprobs( [prompt for prompt, _ in prompts_and_audios], max_tokens, num_logprobs=num_logprobs, - audios=[audios for _, audios in prompts_and_audios]) + audios=[audios for _, audios in prompts_and_audios], + ) # The HuggingFace model doesn't support multiple audios yet, so # just assert that some tokens were generated. @@ -122,21 +129,25 @@ def run_multi_audio_test( @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("vllm_kwargs", [ - pytest.param({}, marks=pytest.mark.cpu_model), - pytest.param(CHUNKED_PREFILL_KWARGS), -]) -def test_models_with_multiple_audios(vllm_runner, - audio_assets: AudioTestAssets, dtype: str, - max_tokens: int, num_logprobs: int, - vllm_kwargs: dict) -> None: - - vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT, - VLLM_PLACEHOLDER) +@pytest.mark.parametrize( + "vllm_kwargs", + [ + pytest.param({}, marks=pytest.mark.cpu_model), + pytest.param(CHUNKED_PREFILL_KWARGS), + ], +) +def test_models_with_multiple_audios( + vllm_runner, + audio_assets: AudioTestAssets, + dtype: str, + max_tokens: int, + num_logprobs: int, + vllm_kwargs: dict, +) -> None: + vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT, VLLM_PLACEHOLDER) run_multi_audio_test( vllm_runner, - [(vllm_prompt, [audio.audio_and_sample_rate - for audio in audio_assets])], + [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])], MODEL_NAME, dtype=dtype, max_tokens=max_tokens, @@ -149,28 +160,25 @@ def test_models_with_multiple_audios(vllm_runner, async def test_online_serving(client, audio_assets: AudioTestAssets): """Exercises online serving with/without chunked prefill enabled.""" - messages = [{ - "role": - "user", - "content": [ - *[{ - "type": "audio_url", - "audio_url": { - "url": audio.url - } - } for audio in audio_assets], - { - "type": - "text", - "text": - f"What's happening in these {len(audio_assets)} audio clips?" 
- }, - ], - }] - - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10) + messages = [ + { + "role": "user", + "content": [ + *[ + {"type": "audio_url", "audio_url": {"url": audio.url}} + for audio in audio_assets + ], + { + "type": "text", + "text": f"What's happening in these {len(audio_assets)} audio clips?", # noqa: E501 + }, + ], + } + ] + + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, messages=messages, max_tokens=10 + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] diff --git a/tests/models/multimodal/generation/test_voxtral.py b/tests/models/multimodal/generation/test_voxtral.py index b4439dfe020c..d27b3ab5ff47 100644 --- a/tests/models/multimodal/generation/test_voxtral.py +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -6,8 +6,12 @@ import pytest import pytest_asyncio from mistral_common.audio import Audio -from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, - TextChunk, UserMessage) +from mistral_common.protocol.instruct.messages import ( + AudioChunk, + RawAudio, + TextChunk, + UserMessage, +) from vllm.transformers_utils.tokenizer import MistralTokenizer @@ -17,8 +21,12 @@ MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507" MISTRAL_FORMAT_ARGS = [ - "--tokenizer_mode", "mistral", "--config_format", "mistral", - "--load_format", "mistral" + "--tokenizer_mode", + "mistral", + "--config_format", + "mistral", + "--load_format", + "mistral", ] @@ -30,10 +38,9 @@ def server(request, audio_assets: AudioTestAssets): json.dumps({"audio": len(audio_assets)}), ] + MISTRAL_FORMAT_ARGS - with RemoteOpenAIServer(MODEL_NAME, - args, - env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": - "30"}) as remote_server: + with RemoteOpenAIServer( + MODEL_NAME, args, env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": "30"} + ) as remote_server: yield remote_server @@ -64,15 +71,17 @@ def _get_prompt(audio_assets, question): @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models_with_multiple_audios(vllm_runner, - audio_assets: AudioTestAssets, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: +def test_models_with_multiple_audios( + vllm_runner, + audio_assets: AudioTestAssets, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT) run_multi_audio_test( vllm_runner, - [(vllm_prompt, [audio.audio_and_sample_rate - for audio in audio_assets])], + [(vllm_prompt, [audio.audio_and_sample_rate for audio in audio_assets])], MODEL_NAME, dtype=dtype, max_tokens=max_tokens, @@ -92,23 +101,17 @@ def asset_to_chunk(asset): return audio_dict audio_chunks = [asset_to_chunk(asset) for asset in audio_assets] - messages = [{ - "role": - "user", - "content": [ - *audio_chunks, - { - "type": - "text", - "text": - f"What's happening in these {len(audio_assets)} audio clips?" - }, - ], - }] - - chat_completion = await client.chat.completions.create(model=MODEL_NAME, - messages=messages, - max_tokens=10) + text = f"What's happening in these {len(audio_assets)} audio clips?" 
+ messages = [ + { + "role": "user", + "content": [*audio_chunks, {"type": "text", "text": text}], + } + ] + + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, messages=messages, max_tokens=10 + ) assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index e0e9980b8833..766f09b0d320 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -12,8 +12,7 @@ PROMPTS = [ { - "prompt": - "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + "prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", "multi_modal_data": { "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate, }, @@ -25,9 +24,8 @@ "audio": AudioAsset("winning_call").audio_and_sample_rate, }, }, - "decoder_prompt": - "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", - } + "decoder_prompt": "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + }, ] EXPECTED = { @@ -41,7 +39,7 @@ " is June and the third base. They're going to wave him in. The throw" " to the plate will be late. The Mariners are going to play for the" " American League Championship. I don't believe it. It just continues" - " by all five." + " by all five.", ], "openai/whisper-small": [ " The first words I spoke in the original pornograph. A little piece" @@ -51,7 +49,7 @@ " comes joy. Here is Junior to third base. They're gonna wave him" " in. The throw to the plate will be late. The Mariners are going to" " play for the American League Championship. I don't believe it. It" - " just continues. My, oh my." + " just continues. My, oh my.", ], "openai/whisper-medium": [ " The first words I spoke in the original phonograph, a little piece" @@ -62,7 +60,7 @@ " Jorgen at third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh" - " my." + " my.", ], "openai/whisper-large-v3": [ " The first words I spoke in the original phonograph, a little piece" @@ -73,7 +71,7 @@ " Junior to third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh," - " my." + " my.", ], "openai/whisper-large-v3-turbo": [ " The first words I spoke in the original phonograph, a little piece" @@ -84,8 +82,8 @@ " Junior to third base. They're going to wave him in. The throw to the" " plate will be late. The Mariners are going to play for the American" " League Championship. I don't believe it. It just continues. My, oh," - " my." 
- ] + " my.", + ], } @@ -100,11 +98,11 @@ def run_test( expected_list = EXPECTED[model] * 10 with vllm_runner( - model, - dtype="half", - max_model_len=448, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, + model, + dtype="half", + max_model_len=448, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, ) as vllm_model: llm = vllm_model.llm diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index 133d5d6ee2ef..096931cca09f 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Helpers for building inputs that can be leveraged for different test types. -""" +"""Helpers for building inputs that can be leveraged for different test types.""" + from collections.abc import Iterable from pathlib import PosixPath from typing import Callable, Optional, Union @@ -10,20 +10,30 @@ from vllm.multimodal.audio import AudioResampler from vllm.multimodal.image import rescale_image_size -from vllm.multimodal.video import (rescale_video_size, resize_video, - sample_frames_from_video) +from vllm.multimodal.video import ( + rescale_video_size, + resize_video, + sample_frames_from_video, +) from .....conftest import AudioTestAssets, ImageTestAssets, VideoTestAssets -from .types import (SINGLE_AUDIO_BASE_PROMPT, SINGLE_IMAGE_BASE_PROMPTS, - TEST_AUDIO_PLACEHOLDER, TEST_IMG_PLACEHOLDER, - TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT, - ImageSizeWrapper, PromptWithMultiModalInput, SizeType, - VLMTestInfo) - - -def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int], - str], - test_placeholder: str) -> str: +from .types import ( + SINGLE_AUDIO_BASE_PROMPT, + SINGLE_IMAGE_BASE_PROMPTS, + TEST_AUDIO_PLACEHOLDER, + TEST_IMG_PLACEHOLDER, + TEST_VIDEO_PLACEHOLDER, + VIDEO_BASE_PROMPT, + ImageSizeWrapper, + PromptWithMultiModalInput, + SizeType, + VLMTestInfo, +) + + +def replace_test_placeholder( + prompt: str, mm_idx_to_prompt: Callable[[int], str], test_placeholder: str +) -> str: """Given a prompt, replaces each test placeholder with the model-specific tag. """ @@ -35,11 +45,13 @@ def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int], return img_prompt -def get_model_prompts(base_prompts: Iterable[str], - img_idx_to_prompt: Optional[Callable[[int], str]], - video_idx_to_prompt: Optional[Callable[[int], str]], - audio_idx_to_prompt: Optional[Callable[[int], str]], - prompt_formatter: Callable[[str], str]) -> list[str]: +def get_model_prompts( + base_prompts: Iterable[str], + img_idx_to_prompt: Optional[Callable[[int], str]], + video_idx_to_prompt: Optional[Callable[[int], str]], + audio_idx_to_prompt: Optional[Callable[[int], str]], + prompt_formatter: Callable[[str], str], +) -> list[str]: """Given a model-agnostic base prompt and test configuration for a model(s) to be tested, update the media placeholders and apply the prompt formatting to get the test prompt string for this model. 
@@ -56,19 +68,19 @@ def get_model_prompts(base_prompts: Iterable[str], # Replace the multimodal placeholders in the base prompt with # the correct ones for the model that we are testing if img_idx_to_prompt: - base_prompt = replace_test_placeholder(base_prompt, - img_idx_to_prompt, - TEST_IMG_PLACEHOLDER) + base_prompt = replace_test_placeholder( + base_prompt, img_idx_to_prompt, TEST_IMG_PLACEHOLDER + ) if video_idx_to_prompt: - base_prompt = replace_test_placeholder(base_prompt, - video_idx_to_prompt, - TEST_VIDEO_PLACEHOLDER) + base_prompt = replace_test_placeholder( + base_prompt, video_idx_to_prompt, TEST_VIDEO_PLACEHOLDER + ) if audio_idx_to_prompt: - base_prompt = replace_test_placeholder(base_prompt, - audio_idx_to_prompt, - TEST_AUDIO_PLACEHOLDER) + base_prompt = replace_test_placeholder( + base_prompt, audio_idx_to_prompt, TEST_AUDIO_PLACEHOLDER + ) # Apply the prompt formatter to wrap the base prompt with # the correct media placeholders to get the model test prompt @@ -84,14 +96,15 @@ def build_single_image_inputs_from_test_info( tmp_path: Optional[PosixPath] = None, ) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: - raise ValueError( - "Prompt formatter must be set to build single image inputs") + raise ValueError("Prompt formatter must be set to build single image inputs") - model_prompts = get_model_prompts(test_info.single_image_prompts, - test_info.img_idx_to_prompt, - test_info.video_idx_to_prompt, - test_info.audio_idx_to_prompt, - test_info.prompt_formatter) + model_prompts = get_model_prompts( + test_info.single_image_prompts, + test_info.img_idx_to_prompt, + test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, + test_info.prompt_formatter, + ) # For models that require a local path / URL encoded in the image; export # assets and encode into tmp_path for this test. 
This should be avoided @@ -110,8 +123,8 @@ def build_single_image_inputs_from_test_info( def build_single_image_inputs( - images, model_prompts, - size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: + images, model_prompts, size_wrapper: ImageSizeWrapper +) -> list[PromptWithMultiModalInput]: # For every image / prompt pair, get a pair containing two lists of # length size_factors, where the first contains duplicates of the model # prompt [str], and the second contains copies of the image after being @@ -125,7 +138,8 @@ def build_single_image_inputs( apply_image_size_scaling(image, size, size_wrapper.type) for size in size_wrapper.data ], - ) for image, prompt in zip(images, model_prompts) + ) + for image, prompt in zip(images, model_prompts) ] @@ -136,14 +150,15 @@ def build_multi_image_inputs_from_test_info( tmp_path: Optional[PosixPath] = None, ) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: - raise ValueError( - "Prompt formatter must be set to build multi image inputs") + raise ValueError("Prompt formatter must be set to build multi image inputs") - model_prompts = get_model_prompts([test_info.multi_image_prompt], - test_info.img_idx_to_prompt, - test_info.video_idx_to_prompt, - test_info.audio_idx_to_prompt, - test_info.prompt_formatter) + model_prompts = get_model_prompts( + [test_info.multi_image_prompt], + test_info.img_idx_to_prompt, + test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, + test_info.prompt_formatter, + ) if test_info.prompt_path_encoder is not None: if tmp_path is None: @@ -164,16 +179,20 @@ def build_multi_image_inputs_from_test_info( def build_multi_image_inputs( - image_lists, model_prompts, - size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: + image_lists, model_prompts, size_wrapper: ImageSizeWrapper +) -> list[PromptWithMultiModalInput]: return [ PromptWithMultiModalInput( prompts=[prompt for _ in size_wrapper.data], - image_data=[[ - apply_image_size_scaling(image, size, size_wrapper.type) - for image in images - ] for size in size_wrapper.data], - ) for images, prompt in zip(image_lists, model_prompts) + image_data=[ + [ + apply_image_size_scaling(image, size, size_wrapper.type) + for image in images + ] + for size in size_wrapper.data + ], + ) + for images, prompt in zip(image_lists, model_prompts) ] @@ -185,10 +204,10 @@ def build_embedding_inputs_from_test_info( # These conditions will always be true if invoked through filtering, # but we still check them in case this is ever called directly if test_info.prompt_formatter is None: - raise ValueError( - "Prompt formatter must be set to build image embedding inputs") - if size_wrapper.type != SizeType.SIZE_FACTOR or not \ - all(factor == 1.0 for factor in size_wrapper.data): + raise ValueError("Prompt formatter must be set to build image embedding inputs") + if size_wrapper.type != SizeType.SIZE_FACTOR or not all( + factor == 1.0 for factor in size_wrapper.data + ): raise ValueError("Embedding tests require constant (1.0) size factors") if test_info.convert_assets_to_embeddings is None: raise ValueError("No conversion func for getting embeddings found") @@ -209,8 +228,7 @@ def build_embedding_inputs_from_test_info( assert len(images) == len(model_prompts) inputs = build_single_image_inputs(images, model_prompts, size_wrapper) - vllm_embeddings = build_single_image_inputs(embeds, model_prompts, - size_wrapper) + vllm_embeddings = build_single_image_inputs(embeds, model_prompts, size_wrapper) return inputs, vllm_embeddings @@ -235,21 
+253,22 @@ def build_video_inputs_from_test_info( for asset in video_assets ] - video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE - else rescale_video_size) + video_scaler = ( + resize_video if size_wrapper.type == SizeType.FIXED_SIZE else rescale_video_size + ) return [ PromptWithMultiModalInput( prompts=[prompt for _ in size_wrapper.data], - video_data=[ - video_scaler(video, size) for size in size_wrapper.data - ], - ) for video, prompt in zip(sampled_vids, model_prompts) + video_data=[video_scaler(video, size) for size in size_wrapper.data], + ) + for video, prompt in zip(sampled_vids, model_prompts) ] -def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], - size_type: SizeType): +def apply_image_size_scaling( + image, size: Union[float, tuple[int, int]], size_type: SizeType +): """Applies a size scaler to one image; this can be an image size factor, which scales the image while maintaining the aspect ratio""" # Special case for embeddings; if it's a tensor, it's only valid if we @@ -285,13 +304,16 @@ def build_audio_inputs_from_test_info( method="librosa", ) audios = [asset.audio_and_sample_rate for asset in audio_assets] - resampled_audios = [( - resampler.resample( - audio, - orig_sr=sr, - ), - int(resampler.target_sr), - ) for audio, sr in audios] + resampled_audios = [ + ( + resampler.resample( + audio, + orig_sr=sr, + ), + int(resampler.target_sr), + ) + for audio, sr in audios + ] return [ PromptWithMultiModalInput( diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py index 1edb51213534..77e478e53c1f 100644 --- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py +++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py @@ -4,19 +4,28 @@ modality, getting all combinations (similar to pytest's parametrization), handling multimodal placeholder substitution, and so on. """ + import itertools from collections import OrderedDict from collections.abc import Iterable import pytest -from .types import (EMBEDDING_SIZE_FACTORS, ExpandableVLMTestArgs, - ImageSizeWrapper, SizeType, VLMTestInfo, VLMTestType) +from .types import ( + EMBEDDING_SIZE_FACTORS, + ExpandableVLMTestArgs, + ImageSizeWrapper, + SizeType, + VLMTestInfo, + VLMTestType, +) def get_filtered_test_settings( - test_settings: dict[str, VLMTestInfo], test_type: VLMTestType, - new_proc_per_test: bool) -> dict[str, VLMTestInfo]: + test_settings: dict[str, VLMTestInfo], + test_type: VLMTestType, + new_proc_per_test: bool, +) -> dict[str, VLMTestInfo]: """Given the dict of potential test settings to run, return a subdict of tests who have the current test type enabled with the matching val for fork_per_test. 
@@ -25,7 +34,8 @@ def get_filtered_test_settings( def matches_test_type(test_info: VLMTestInfo, test_type: VLMTestType): return test_info.test_type == test_type or ( isinstance(test_info.test_type, Iterable) - and test_type in test_info.test_type) + and test_type in test_info.test_type + ) matching_tests = {} for test_name, test_info in test_settings.items(): @@ -36,68 +46,74 @@ def matches_test_type(test_info: VLMTestInfo, test_type: VLMTestType): assert test_info.convert_assets_to_embeddings is not None # Custom test inputs need to explicitly define the mm limit/inputs if matches_test_type(test_info, VLMTestType.CUSTOM_INPUTS): - assert (test_info.custom_test_opts is not None - and isinstance(test_info.custom_test_opts, Iterable)) + assert test_info.custom_test_opts is not None and isinstance( + test_info.custom_test_opts, Iterable + ) # For all types besides custom inputs, we need a prompt formatter else: assert test_info.prompt_formatter is not None # Everything looks okay; keep if this is correct proc handling - if (test_info.distributed_executor_backend - is not None) == new_proc_per_test: + if ( + test_info.distributed_executor_backend is not None + ) == new_proc_per_test: matching_tests[test_name] = test_info return matching_tests -def get_parametrized_options(test_settings: dict[str, VLMTestInfo], - test_type: VLMTestType, - create_new_process_for_each_test: bool): +def get_parametrized_options( + test_settings: dict[str, VLMTestInfo], + test_type: VLMTestType, + create_new_process_for_each_test: bool, +): """Converts all of our VLMTestInfo into an expanded list of parameters. This is similar to nesting pytest parametrize calls, but done directly through an itertools product so that each test can set things like size factors etc, while still running in isolated test cases. """ matching_tests = get_filtered_test_settings( - test_settings, test_type, create_new_process_for_each_test) + test_settings, test_type, create_new_process_for_each_test + ) # Ensure that something is wrapped as an iterable it's not already - ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e, ) + ensure_wrapped = lambda e: e if isinstance(e, (list, tuple)) else (e,) def get_model_type_cases(model_type: str, test_info: VLMTestInfo): # This is essentially the same as nesting a bunch of mark.parametrize # decorators, but we do it programmatically to allow overrides for on # a per-model basis, while still being able to execute each of these # as individual test cases in pytest. 
- iter_kwargs = OrderedDict([ - ("model", ensure_wrapped(test_info.models)), - ("max_tokens", ensure_wrapped(test_info.max_tokens)), - ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), - ("dtype", ensure_wrapped(test_info.dtype)), - ("distributed_executor_backend", - ensure_wrapped(test_info.distributed_executor_backend)), - ]) + iter_kwargs = OrderedDict( + [ + ("model", ensure_wrapped(test_info.models)), + ("max_tokens", ensure_wrapped(test_info.max_tokens)), + ("num_logprobs", ensure_wrapped(test_info.num_logprobs)), + ("dtype", ensure_wrapped(test_info.dtype)), + ( + "distributed_executor_backend", + ensure_wrapped(test_info.distributed_executor_backend), + ), + ] + ) # num_frames is video only if test_type == VLMTestType.VIDEO: - iter_kwargs["num_video_frames"] = ensure_wrapped( - test_info.num_video_frames) + iter_kwargs["num_video_frames"] = ensure_wrapped(test_info.num_video_frames) # No sizes passed for custom inputs, since inputs are directly provided if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO): wrapped_sizes = get_wrapped_test_sizes(test_info, test_type) if wrapped_sizes is None: - raise ValueError( - f"Sizes must be set for test type {test_type}") + raise ValueError(f"Sizes must be set for test type {test_type}") iter_kwargs["size_wrapper"] = wrapped_sizes - #Otherwise expand the custom test options instead + # Otherwise expand the custom test options instead elif test_type == VLMTestType.CUSTOM_INPUTS: if test_info.custom_test_opts is None: raise ValueError("Test has type CUSTOM_INPUTS, but none given") iter_kwargs["custom_test_opts"] = test_info.custom_test_opts - # yapf: disable # Wrap all model cases in a pytest parameter & pass marks through return [ pytest.param( @@ -105,10 +121,10 @@ def get_model_type_cases(model_type: str, test_info: VLMTestInfo): ExpandableVLMTestArgs( **{k: v for k, v in zip(iter_kwargs.keys(), case)} ), - marks=test_info.marks if test_info.marks is not None else [] - ) for case in list(itertools.product(*iter_kwargs.values())) + marks=test_info.marks if test_info.marks is not None else [], + ) + for case in list(itertools.product(*iter_kwargs.values())) ] - # yapf: enable # Get a list per model type, where each entry contains a tuple of all of # that model type's cases, then flatten them into the top level so that @@ -121,8 +137,8 @@ def get_model_type_cases(model_type: str, test_info: VLMTestInfo): def get_wrapped_test_sizes( - test_info: VLMTestInfo, - test_type: VLMTestType) -> tuple[ImageSizeWrapper, ...]: + test_info: VLMTestInfo, test_type: VLMTestType +) -> tuple[ImageSizeWrapper, ...]: """Given a test info which may have size factors or fixed sizes, wrap them and combine them into an iterable, each of which will be used in parameter expansion. 
@@ -133,18 +149,18 @@ def get_wrapped_test_sizes( """ # If it is an embedding test, we always use the EMBEDDING_SIZE_FACTORS if test_type == VLMTestType.EMBEDDING: - return tuple([ - ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) - for factor in EMBEDDING_SIZE_FACTORS - ]) + return tuple( + [ + ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) + for factor in EMBEDDING_SIZE_FACTORS + ] + ) # Audio and Custom inputs have preprocessed inputs elif test_type in (VLMTestType.AUDIO, VLMTestType.CUSTOM_INPUTS): return tuple() - size_factors = test_info.image_size_factors \ - if test_info.image_size_factors else [] - fixed_sizes = test_info.image_sizes \ - if test_info.image_sizes else [] + size_factors = test_info.image_size_factors if test_info.image_size_factors else [] + fixed_sizes = test_info.image_sizes if test_info.image_sizes else [] wrapped_factors = [ ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) @@ -152,8 +168,7 @@ def get_wrapped_test_sizes( ] wrapped_sizes = [ - ImageSizeWrapper(type=SizeType.FIXED_SIZE, data=size) - for size in fixed_sizes + ImageSizeWrapper(type=SizeType.FIXED_SIZE, data=size) for size in fixed_sizes ] return tuple(wrapped_factors + wrapped_sizes) diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index 11d44120b875..0c11f5f9b082 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Core test implementation to be shared across modalities.""" + from typing import Any, Callable, Optional import torch @@ -70,22 +71,23 @@ def run_test( if model_info.hf_overrides: vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides if model_info.skip_tokenizer_init: - vllm_runner_kwargs_[ - "skip_tokenizer_init"] = model_info.skip_tokenizer_init + vllm_runner_kwargs_["skip_tokenizer_init"] = model_info.skip_tokenizer_init if vllm_runner_kwargs: vllm_runner_kwargs_.update(vllm_runner_kwargs) - with vllm_runner(model, - max_model_len=max_model_len, - max_num_seqs=max_num_seqs, - dtype=dtype, - limit_mm_per_prompt=limit_mm_per_prompt, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=enforce_eager, - runner=runner, - **vllm_runner_kwargs_) as vllm_model: + with vllm_runner( + model, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, + dtype=dtype, + limit_mm_per_prompt=limit_mm_per_prompt, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=enforce_eager, + runner=runner, + **vllm_runner_kwargs_, + ) as vllm_model: tokenizer = vllm_model.llm.get_tokenizer() vllm_kwargs: dict[str, Any] = {} @@ -95,21 +97,19 @@ def run_test( vllm_kwargs["stop"] = stop_str for prompts, image_data, video_data, audio_data in vllm_inputs: - mm_data = dict(images=image_data, - videos=video_data, - audios=audio_data) + mm_data = dict(images=image_data, videos=video_data, audios=audio_data) vllm_kwargs_with_mm_data = vllm_kwargs | mm_data vllm_output = vllm_model.generate_greedy_logprobs( prompts, max_tokens, num_logprobs=num_logprobs, - **vllm_kwargs_with_mm_data) + **vllm_kwargs_with_mm_data, + ) vllm_outputs_per_mm.append(vllm_output) - hf_model = hf_runner(model, - dtype=dtype, - auto_cls=auto_cls, - model_kwargs=hf_model_kwargs) + hf_model = hf_runner( + model, 
dtype=dtype, auto_cls=auto_cls, model_kwargs=hf_model_kwargs + ) # Some models need to patch things like the model processor, e.g., internvl if patch_hf_runner is not None: @@ -129,16 +129,15 @@ def run_test( hf_kwargs["stop_strings"] = stop_str for prompts, image_data, video_data, audio_data in inputs: - mm_data = dict(images=image_data, - videos=video_data, - audios=audio_data) + mm_data = dict(images=image_data, videos=video_data, audios=audio_data) hf_kwargs_with_mm_data = hf_kwargs | mm_data hf_output = hf_model.generate_greedy_logprobs_limit( prompts, max_tokens, num_logprobs=num_logprobs, tokenizer=tokenizer, - **hf_kwargs_with_mm_data) + **hf_kwargs_with_mm_data, + ) hf_outputs_per_mm.append(hf_output) # Apply output processing / sanitation to the vLLM and HF runner results @@ -150,8 +149,7 @@ def run_test( second_runner_processor=vllm_output_post_proc, ) - for hf_outputs, vllm_outputs in zip(hf_outputs_per_mm, - vllm_outputs_per_mm): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_mm, vllm_outputs_per_mm): # This is usually check_logprobs_close, but it's passed through to # allow things like check_outputs_equal where needed comparator( @@ -171,15 +169,19 @@ def process_runner_outputs( ): """Applies the runner processor(s) to the runner outputs, if any.""" if first_runner_processor is not None: - first_runner_outputs = process_outputs(first_runner_processor, model, - first_runner_outputs) + first_runner_outputs = process_outputs( + first_runner_processor, model, first_runner_outputs + ) if second_runner_processor is not None: - second_runner_outputs = process_outputs(second_runner_processor, model, - second_runner_outputs) + second_runner_outputs = process_outputs( + second_runner_processor, model, second_runner_outputs + ) return first_runner_outputs, second_runner_outputs def process_outputs(output_processor, model, outputs_per_image): """Applies a model specific post-processor function to a runner's output""" - return [[output_processor(res, model) for res in outputs] - for outputs in outputs_per_image] + return [ + [output_processor(res, model) for res in outputs] + for outputs in outputs_per_image + ] diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index e369416fc49c..8f2f8bba39ca 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -1,12 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom input builders for edge-cases in different models.""" + from typing import Callable from vllm.assets.image import ImageAsset from vllm.multimodal.image import rescale_image_size -from vllm.multimodal.video import (rescale_video_size, resize_video, - sample_frames_from_video) +from vllm.multimodal.video import ( + rescale_video_size, + resize_video, + sample_frames_from_video, +) from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS from .builders import build_multi_image_inputs, build_single_image_inputs @@ -15,7 +19,7 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): """Builds inputs for multi-image (varied sizes/aspect ratio) testing. - + Args: formatter: model-specific prompt formatter. 
""" @@ -41,7 +45,7 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): stop_sign, rescale_image_size(stop_sign, 0.25), cherry_blossom.resize((183, 488)), - cherry_blossom.resize((488, 183)) + cherry_blossom.resize((488, 183)), ], cherry_blossom, ] @@ -54,10 +58,11 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): ] -def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], - num_frames: int = 16): +def multi_video_multi_aspect_ratio_inputs( + formatter: Callable[[str], str], num_frames: int = 16 +): """Builds inputs for multi-video (varied sizes/aspect ratio) testing. - + Args: formatter: model-specific prompt formatter. """ @@ -81,7 +86,7 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], video, rescale_video_size(video, 0.25), resize_video(video, (183, 488)), - resize_video(video, (488, 183)) + resize_video(video, (488, 183)), ], video, ] @@ -96,7 +101,9 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], def different_patch_input_cases_internvl(): images = [asset.pil_image.resize((896, 896)) for asset in IMAGE_ASSETS] - formatter = lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 + formatter = ( + lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 + ) single_img_prompts = [ "<image>\nWhat's the content in the center of the image?", "<image>\nWhat is the season?", @@ -115,14 +122,14 @@ def different_patch_input_cases_internvl(): def windows_attention_image_qwen2_5_vl(): - # image from regression issue: https://github.com/vllm-project/vllm/issues/15122 # noqa: E501 image = ImageAsset("hato").pil_image question = "Describe the image." img_prompt = "<|vision_start|><|image_pad|><|vision_end|>" - prompt = (f"<|im_start|>User\n{img_prompt}{question}<|im_end|>\n" - "<|im_start|>assistant\n") + prompt = ( + f"<|im_start|>User\n{img_prompt}{question}<|im_end|>\n<|im_start|>assistant\n" + ) wrapped_sf = ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=[0.5]) return build_single_image_inputs([image], [prompt], wrapped_sf) @@ -136,8 +143,9 @@ def video_with_metadata_glm4_1v(): formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n" scales = [0.1, 0.2, 0.25] - video_input = [[(rescale_video_size(video_array, scale), metadata)] - for scale in scales] + video_input = [ + [(rescale_video_size(video_array, scale), metadata)] for scale in scales + ] prompts = [formatted_prompt] * len(video_input) return [ diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index ba55450ec8a9..f924bea9f495 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -4,6 +4,7 @@ for manipulating the input / output of HF & vLLM test runners, which are typically specific to a small subset of models. 
""" + import types from pathlib import PosixPath from typing import Optional, Union @@ -15,8 +16,13 @@ import regex as re import torch from PIL.Image import Image -from transformers import (AutoConfig, AutoTokenizer, BatchFeature, - GenerationConfig, GenerationMixin) +from transformers import ( + AutoConfig, + AutoTokenizer, + BatchFeature, + GenerationConfig, + GenerationMixin, +) from transformers.video_utils import VideoMetadata from vllm.logprobs import SampleLogprobs @@ -27,8 +33,7 @@ ####### vLLM output processors functions -def blip2_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def blip2_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [blip2 models] to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -42,8 +47,7 @@ def blip2_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs -def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [fuyu models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -53,8 +57,8 @@ def fuyu_vllm_to_hf_output(vllm_output: RunnerOutput, def qwen_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [qwen models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -64,8 +68,8 @@ def qwen_vllm_to_hf_output( def qwen2_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [qwen2 models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -75,8 +79,8 @@ def qwen2_vllm_to_hf_output( def kimiv_vl_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: """Sanitize vllm output [kimi_vl models] to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -85,23 +89,25 @@ def kimiv_vl_vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs -def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def llava_image_vllm_to_hf_output( + vllm_output: RunnerOutput, model: str +) -> RunnerOutput: config = AutoConfig.from_pretrained(model) mm_token_id = config.image_token_index return _llava_vllm_to_hf_output(vllm_output, model, mm_token_id) def llava_video_vllm_to_hf_output( - vllm_output: RunnerOutput, - model: str) -> tuple[list[int], str, Optional[SampleLogprobs]]: + vllm_output: RunnerOutput, model: str +) -> tuple[list[int], str, Optional[SampleLogprobs]]: config = AutoConfig.from_pretrained(model) mm_token_id = config.video_token_index return _llava_vllm_to_hf_output(vllm_output, model, mm_token_id) -def _llava_vllm_to_hf_output(vllm_output: RunnerOutput, model: str, - mm_token_id: int) -> RunnerOutput: +def _llava_vllm_to_hf_output( + vllm_output: RunnerOutput, model: str, mm_token_id: int +) -> RunnerOutput: """Sanitize vllm output [Llava models] to be comparable with hf output.""" output_ids, 
output_str, out_logprobs = vllm_output @@ -109,7 +115,8 @@ def _llava_vllm_to_hf_output(vllm_output: RunnerOutput, model: str, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != mm_token_id or output_ids[idx - 1] != mm_token_id ] @@ -128,8 +135,9 @@ def llava_onevision_hf_model_kwargs(model: str) -> dict: return config.to_dict() -def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def llava_onevision_vllm_to_hf_output( + vllm_output: RunnerOutput, model: str +) -> RunnerOutput: """Sanitize vllm output [llava-onevision] to compare with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -140,7 +148,8 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != video_token_id or output_ids[idx - 1] != video_token_id ] @@ -151,8 +160,7 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs -def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [mantis] to compare with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -161,8 +169,7 @@ def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, return output_ids, hf_output_str, out_logprobs -def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [phi3v] to be comparable with hf output.""" _, output_str, out_logprobs = vllm_output @@ -180,8 +187,7 @@ def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs -def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, - model: str) -> RunnerOutput: +def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output to be comparable with hf output.""" output_ids, output_str, out_logprobs = vllm_output @@ -192,7 +198,8 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, eos_token_id = tokenizer.eos_token_id hf_output_ids = [ - token_id for idx, token_id in enumerate(output_ids) + token_id + for idx, token_id in enumerate(output_ids) if token_id != image_token_id or output_ids[idx - 1] != image_token_id ] @@ -205,46 +212,40 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, ####### Post-processors for HF outputs -def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def deepseekvl2_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<|end▁of▁sentence|>"): output_str = output_str.split("<|end▁of▁sentence|>")[0] return output_ids, output_str, out_logprobs -def idefics3_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def idefics3_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<end_of_utterance>"): output_str = output_str.split("<end_of_utterance>")[0] return output_ids, output_str, out_logprobs -def 
smolvlm_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def smolvlm_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: # Based on Idefics3 return idefics3_trunc_hf_output(hf_output, model) -def minicpmv_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def minicpmv_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<|eot_id|>"): output_str = output_str.split("<|eot_id|>")[0] return output_ids, output_str, out_logprobs -def minimax_vl_01_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def minimax_vl_01_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<end_of_sentence>"): output_str = output_str.split("<end_of_sentence>")[0] return output_ids, output_str, out_logprobs -def ultravox_trunc_hf_output(hf_output: RunnerOutput, - model: str) -> RunnerOutput: +def ultravox_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output tokenizer = AutoTokenizer.from_pretrained(model) @@ -262,8 +263,8 @@ def get_llava_embeddings(image_assets: ImageTestAssets): ####### Prompt path encoders for models that need models on disk def qwen_prompt_path_encoder( - tmp_path: PosixPath, prompt: str, - assets: Union[list[ImageAsset], ImageTestAssets]) -> str: + tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset], ImageTestAssets] +) -> str: """Given a temporary dir path, export one or more image assets into the tempdir & replace its contents with the local path to the string so that the HF version of Qwen-VL can resolve the path and load the image in its @@ -313,8 +314,9 @@ def processor(*args, text="", images=None, **kwargs): return BatchFeature(data=inputs, tensor_type="pt") hf_model.processor = processor - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language.model.embed_tokens + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language.model.embed_tokens + ) return hf_model @@ -357,11 +359,10 @@ def processor(*args, text="", images=None, **kwargs): assert len(contents) == len(images) return hf_processor.apply_chat_template( - [{ - "role": "user", - "image": image, - "content": content - } for image, content in zip(images, contents)], + [ + {"role": "user", "image": image, "content": content} + for image, content in zip(images, contents) + ], add_generation_prompt=True, tokenize=True, return_dict=True, @@ -369,8 +370,9 @@ def processor(*args, text="", images=None, **kwargs): ) hf_model.processor = processor - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.transformer.output_layer + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.transformer.output_layer + ) return hf_model @@ -387,10 +389,9 @@ def processor(*args, videos=None, **kwargs): else: video_metadata = None - return hf_processor(*args, - videos=videos, - video_metadata=video_metadata, - **kwargs) + return hf_processor( + *args, videos=videos, video_metadata=video_metadata, **kwargs + ) hf_model.processor = processor return hf_model @@ -406,8 +407,9 @@ def __init__(self, hf_runner: HfRunner): self.num_image_token = hf_runner.model.num_image_token self.tokenizer = hf_runner.tokenizer - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + hf_runner.model_name, 
trust_remote_code=True + ) self.vision_config = self.config.vision_config self.use_thumbnail = self.config.use_thumbnail self.use_msac = self.config.use_msac @@ -415,13 +417,14 @@ def __init__(self, hf_runner: HfRunner): self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, list[Image]], - **kwargs): - # yapf: disable + def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): from vllm.model_executor.models.h2ovl import ( - IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values_h2ovl) + IMG_CONTEXT, + IMG_END, + IMG_START, + image_to_pixel_values_h2ovl, + ) - # yapf: enable images = [images] if isinstance(images, Image) else images pixel_values = [ image_to_pixel_values_h2ovl( @@ -431,29 +434,26 @@ def __call__(self, text: str, images: Union[Image, list[Image]], max_num=self.max_num, use_thumbnail=self.use_thumbnail, use_msac=self.use_msac, - ) for image in images - ] - num_patches_list = [ - pixel_value.shape[0] for pixel_value in pixel_values + ) + for image in images ] + num_patches_list = [pixel_value.shape[0] for pixel_value in pixel_values] pixel_values = torch.cat(pixel_values, dim=0) for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('<image>', image_tokens, 1) + text = text.replace("<image>", image_tokens, 1) prompt = self.tokenizer(text, return_tensors="pt") prompt.update({"pixel_values": pixel_values}) return prompt - img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( - "<IMG_CONTEXT>") + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") hf_model.model.img_context_token_id = img_context_token_id hf_model.processor = H2OVLProcessor(hf_model) - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() - hf_model.model.generate = types.MethodType(_internvl_generate, - hf_model.model) + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language_model.get_output_embeddings() + ) + hf_model.model.generate = types.MethodType(_internvl_generate, hf_model.model) return hf_model @@ -467,19 +467,23 @@ def __init__(self, hf_runner: HfRunner): self.num_image_token = hf_runner.model.num_image_token self.tokenizer = hf_runner.tokenizer - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + hf_runner.model_name, trust_remote_code=True + ) self.vision_config = self.config.vision_config self.use_thumbnail = self.config.use_thumbnail self.min_num = self.config.min_dynamic_patch self.max_num = self.config.max_dynamic_patch self.image_size = self.vision_config.image_size - def __call__(self, text: str, images: Union[Image, list[Image]], - **kwargs): + def __call__(self, text: str, images: Union[Image, list[Image]], **kwargs): from vllm.model_executor.models.skyworkr1v import ( - IMG_CONTEXT, IMG_END, IMG_START, - image_to_pixel_values_skyworkr1v) + IMG_CONTEXT, + IMG_END, + IMG_START, + image_to_pixel_values_skyworkr1v, + ) + images = [images] if isinstance(images, Image) else images pixel_values = [ image_to_pixel_values_skyworkr1v( @@ -488,29 +492,26 @@ def __call__(self, text: str, images: Union[Image, list[Image]], min_num=self.min_num, max_num=self.max_num, use_thumbnail=self.use_thumbnail, - ) for 
image in images - ] - num_patches_list = [ - pixel_value.shape[0] for pixel_value in pixel_values + ) + for image in images ] + num_patches_list = [pixel_value.shape[0] for pixel_value in pixel_values] pixel_values = torch.cat(pixel_values, dim=0) for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('<image>', image_tokens, 1) + text = text.replace("<image>", image_tokens, 1) prompt = self.tokenizer(text, return_tensors="pt") prompt.update({"pixel_values": pixel_values}) return prompt - img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( - "<IMG_CONTEXT>") + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") hf_model.model.img_context_token_id = img_context_token_id hf_model.processor = SkyworkR1VProcessor(hf_model) - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() - hf_model.model.generate = types.MethodType(_internvl_generate, - hf_model.model) + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language_model.get_output_embeddings() + ) + hf_model.model.generate = types.MethodType(_internvl_generate, hf_model.model) return hf_model @@ -524,8 +525,9 @@ def __init__(self, hf_runner: HfRunner): self.num_image_token = hf_runner.model.num_image_token self.tokenizer = hf_runner.tokenizer - self.config = AutoConfig.from_pretrained(hf_runner.model_name, - trust_remote_code=True) + self.config = AutoConfig.from_pretrained( + hf_runner.model_name, trust_remote_code=True + ) self.vision_config = self.config.vision_config self.use_thumbnail = self.config.use_thumbnail self.min_num = self.config.min_dynamic_patch @@ -540,8 +542,13 @@ def __call__( **kwargs, ): from vllm.model_executor.models.internvl import ( - IMG_CONTEXT, IMG_END, IMG_START, - image_to_pixel_values_internvl, video_to_pixel_values_internvl) + IMG_CONTEXT, + IMG_END, + IMG_START, + image_to_pixel_values_internvl, + video_to_pixel_values_internvl, + ) + images = [images] if isinstance(images, Image) else images videos = [videos] if isinstance(videos, np.ndarray) else videos if images is not None: @@ -552,7 +559,8 @@ def __call__( min_num=self.min_num, max_num=self.max_num, use_thumbnail=self.use_thumbnail, - ) for image in images + ) + for image in images ] num_patches_images = [ pixel_value.shape[0] for pixel_value in pixel_values_images @@ -568,7 +576,8 @@ def __call__( min_num=1, max_num=1, use_thumbnail=False, - ) for video in videos + ) + for video in videos ] num_patches_videos = [ pixel_value.shape[0] for pixel_value in pixel_values_videos @@ -580,38 +589,37 @@ def __call__( while ("<image>" in text) or ("<video>" in text): image_index = text.find("<image>") video_index = text.find("<video>") - if image_index == -1 or (video_index > -1 - and video_index < image_index): + if image_index == -1 or ( + video_index > -1 and video_index < image_index + ): num_patches = num_patches_videos.pop(0) pixel_values.append(pixel_values_videos.pop(0)) - context_tokens = IMG_START + \ - IMG_CONTEXT * self.num_image_token + IMG_END - video_tokens = ''.join([ - f'Frame{i+1}: {context_tokens}' - for i in range(num_patches) - ]) - text = text.replace('<video>', video_tokens, 1) + context_tokens = ( + IMG_START + IMG_CONTEXT * self.num_image_token + IMG_END + ) + video_tokens = "".join( + [f"Frame{i + 1}: {context_tokens}" for i in 
range(num_patches)] + ) + text = text.replace("<video>", video_tokens, 1) else: num_patches = num_patches_images.pop(0) pixel_values.append(pixel_values_images.pop(0)) - context_tokens = IMG_CONTEXT * self.num_image_token \ - * num_patches + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('<image>', image_tokens, 1) + text = text.replace("<image>", image_tokens, 1) pixel_values = torch.cat(pixel_values, dim=0) prompt = self.tokenizer(text, return_tensors="pt") prompt.update({"pixel_values": pixel_values}) return prompt - img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( - "<IMG_CONTEXT>") + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") hf_model.model.img_context_token_id = img_context_token_id hf_model.processor = InternVLProcessor(hf_model) - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() - hf_model.model.generate = types.MethodType(_internvl_generate, - hf_model.model) + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.language_model.get_output_embeddings() + ) + hf_model.model.generate = types.MethodType(_internvl_generate, hf_model.model) return hf_model @@ -631,7 +639,7 @@ def _internvl_generate( input_embeds = input_embeds.reshape(B * N, C) input_ids = input_ids.reshape(B * N) - selected = (input_ids == self.img_context_token_id) + selected = input_ids == self.img_context_token_id assert selected.sum() != 0 input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) @@ -778,8 +786,9 @@ def _generate(self, max_new_tokens=None, do_sample=None, **kwargs): def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for Ovis2.""" - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.llm.get_output_embeddings() + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.llm.get_output_embeddings() + ) def processor(*args, text="", images=None, **kwargs): text_tokenizer = hf_model.model.get_text_tokenizer() @@ -787,8 +796,7 @@ def processor(*args, text="", images=None, **kwargs): prompt_start_and_end = { "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), - "llama": - ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "llama": ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"), } for start, end in prompt_start_and_end.values(): @@ -797,7 +805,8 @@ def processor(*args, text="", images=None, **kwargs): break prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs( - text_or_conversations=text, images=images) + text_or_conversations=text, images=images + ) attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id) inputs = { @@ -813,8 +822,9 @@ def processor(*args, text="", images=None, **kwargs): def ovis2_5_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for Ovis2.""" - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.llm.get_output_embeddings() + hf_model.model.get_output_embeddings = ( + lambda: hf_model.model.llm.get_output_embeddings() + ) def processor(*args, text="", images=None, videos=None, **kwargs): if images is None: @@ -825,13 +835,11 @@ def processor(*args, text="", images=None, videos=None, **kwargs): videos = [] else: videos = [videos] if isinstance(videos, np.ndarray) else videos - videos = 
[[PIL.Image.fromarray(frame) for frame in vid] - for vid in videos] + videos = [[PIL.Image.fromarray(frame) for frame in vid] for vid in videos] prompt_start_and_end = { "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"), - "llama": - ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), + "llama": ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"), "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"), } for start, end in prompt_start_and_end.values(): @@ -842,21 +850,20 @@ def processor(*args, text="", images=None, videos=None, **kwargs): images_message = [{"type": "image", "image": img} for img in images] videos_message = [{"type": "video", "video": vid} for vid in videos] - messages = [{ - "role": - "user", - "content": [ - *images_message, - *videos_message, - { - "type": "text", - "text": text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + *images_message, + *videos_message, + {"type": "text", "text": text}, + ], + } + ] input_ids, pixel_values, grid_thws = hf_model.model.preprocess_inputs( - messages=messages, enable_thinking=True) + messages=messages, enable_thinking=True + ) inputs = { "inputs": input_ids, "pixel_values": pixel_values, diff --git a/tests/models/multimodal/generation/vlm_utils/runners.py b/tests/models/multimodal/generation/vlm_utils/runners.py index 562f89df1347..c91ae117b558 100644 --- a/tests/models/multimodal/generation/vlm_utils/runners.py +++ b/tests/models/multimodal/generation/vlm_utils/runners.py @@ -3,23 +3,34 @@ """Entrypoints for wrapping the core run_test implementation for specific test types / modalities. """ + from pathlib import PosixPath -from .....conftest import (AudioTestAssets, HfRunner, ImageTestAssets, - VideoTestAssets, VllmRunner) +from .....conftest import ( + AudioTestAssets, + HfRunner, + ImageTestAssets, + VideoTestAssets, + VllmRunner, +) from . 
import builders, core from .types import ExpandableVLMTestArgs, VLMTestInfo ####### Entrypoints for running different test types -def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets): +def run_single_image_test( + *, + tmp_path: PosixPath, + model_test_info: VLMTestInfo, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): assert test_case.size_wrapper is not None inputs = builders.build_single_image_inputs_from_test_info( - model_test_info, image_assets, test_case.size_wrapper, tmp_path) + model_test_info, image_assets, test_case.size_wrapper, tmp_path + ) core.run_test( hf_runner=hf_runner, @@ -31,17 +42,23 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"image": 1}, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) -def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets): +def run_multi_image_test( + *, + tmp_path: PosixPath, + model_test_info: VLMTestInfo, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): assert test_case.size_wrapper is not None inputs = builders.build_multi_image_inputs_from_test_info( - model_test_info, image_assets, test_case.size_wrapper, tmp_path) + model_test_info, image_assets, test_case.size_wrapper, tmp_path + ) core.run_test( hf_runner=hf_runner, @@ -53,17 +70,22 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"image": len(image_assets)}, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) -def run_embedding_test(*, model_test_info: VLMTestInfo, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - image_assets: ImageTestAssets): +def run_embedding_test( + *, + model_test_info: VLMTestInfo, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + image_assets: ImageTestAssets, +): assert test_case.size_wrapper is not None inputs, vllm_embeddings = builders.build_embedding_inputs_from_test_info( - model_test_info, image_assets, test_case.size_wrapper) + model_test_info, image_assets, test_case.size_wrapper + ) core.run_test( hf_runner=hf_runner, @@ -76,7 +98,8 @@ def run_embedding_test(*, model_test_info: VLMTestInfo, limit_mm_per_prompt={"image": 1}, vllm_embeddings=vllm_embeddings, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) def run_video_test( @@ -90,8 +113,11 @@ def run_video_test( assert test_case.size_wrapper is not None assert test_case.num_video_frames is not None inputs = builders.build_video_inputs_from_test_info( - model_test_info, video_assets, test_case.size_wrapper, - 
test_case.num_video_frames) + model_test_info, + video_assets, + test_case.size_wrapper, + test_case.num_video_frames, + ) core.run_test( hf_runner=hf_runner, @@ -103,7 +129,8 @@ def run_video_test( num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"video": len(video_assets)}, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) def run_audio_test( @@ -114,8 +141,7 @@ def run_audio_test( vllm_runner: type[VllmRunner], audio_assets: AudioTestAssets, ): - inputs = builders.build_audio_inputs_from_test_info( - model_test_info, audio_assets) + inputs = builders.build_audio_inputs_from_test_info(model_test_info, audio_assets) core.run_test( hf_runner=hf_runner, @@ -127,13 +153,17 @@ def run_audio_test( num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"audio": 1}, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) -def run_custom_inputs_test(*, model_test_info: VLMTestInfo, - test_case: ExpandableVLMTestArgs, - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner]): +def run_custom_inputs_test( + *, + model_test_info: VLMTestInfo, + test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], +): # Custom test cases can provide inputs directly, but they need to # explicitly provided a CustomTestConfig, which wraps the inputs and # the limit_mm_per_prompt @@ -155,4 +185,5 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt=limit_mm_per_prompt, distributed_executor_backend=test_case.distributed_executor_backend, - **model_test_info.get_non_parametrized_runner_kwargs()) + **model_test_info.get_non_parametrized_runner_kwargs(), + ) diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 6a82bdfc4cf2..bb34d1cc6dad 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Types for writing multimodal model tests.""" + from collections.abc import Iterable from enum import Enum from pathlib import PosixPath @@ -15,9 +16,16 @@ from vllm.logprobs import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer -from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset, - ImageTestAssets, PromptAudioInput, PromptImageInput, - PromptVideoInput) +from .....conftest import ( + AUDIO_ASSETS, + IMAGE_ASSETS, + HfRunner, + ImageAsset, + ImageTestAssets, + PromptAudioInput, + PromptImageInput, + PromptVideoInput, +) from ....utils import check_logprobs_close # meta image tag; will be replaced by the appropriate tag for the model @@ -25,28 +33,31 @@ TEST_VIDEO_PLACEHOLDER = "<vlm_video>" TEST_AUDIO_PLACEHOLDER = "<lmm_audio>" -# yapf: disable -SINGLE_IMAGE_BASE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": f"{TEST_IMG_PLACEHOLDER}What's the content of the image?", - "cherry_blossom": f"{TEST_IMG_PLACEHOLDER}What is the season?", -}) -SINGLE_AUDIO_BASE_PROMPT = AUDIO_ASSETS.prompts({ - "mary_had_lamb": f"{TEST_AUDIO_PLACEHOLDER}Transcribe this audio into English.", # noqa: E501 - "winning_call": f"{TEST_AUDIO_PLACEHOLDER}What 
is happening in this audio clip?", # noqa: E501 -}) +SINGLE_IMAGE_BASE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": f"{TEST_IMG_PLACEHOLDER}What's the content of the image?", + "cherry_blossom": f"{TEST_IMG_PLACEHOLDER}What is the season?", + } +) +SINGLE_AUDIO_BASE_PROMPT = AUDIO_ASSETS.prompts( + { + "mary_had_lamb": f"{TEST_AUDIO_PLACEHOLDER}Transcribe this audio into English.", # noqa: E501 + "winning_call": f"{TEST_AUDIO_PLACEHOLDER}What is happening in this audio clip?", # noqa: E501 + } +) MULTI_IMAGE_BASE_PROMPT = f"Image-1: {TEST_IMG_PLACEHOLDER}Image-2: {TEST_IMG_PLACEHOLDER}Describe the two images in detail.\n" # noqa: E501 VIDEO_BASE_PROMPT = f"{TEST_VIDEO_PLACEHOLDER}Why is this video funny?" -IMAGE_SIZE_FACTORS = [(), (1.0, ), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] -EMBEDDING_SIZE_FACTORS = [(), (1.0, ), (1.0, 1.0, 1.0)] +IMAGE_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0), (0.25, 0.5, 1.0)] +EMBEDDING_SIZE_FACTORS = [(), (1.0,), (1.0, 1.0, 1.0)] RunnerOutput = tuple[list[int], str, Optional[SampleLogprobs]] -# yapf: enable class PromptWithMultiModalInput(NamedTuple): """Holds the multimodal input for a single test case.""" + prompts: list[str] image_data: Optional[PromptImageInput] = None video_data: Optional[PromptVideoInput] = None @@ -100,8 +111,9 @@ class VLMTestInfo(NamedTuple): # Function for converting ImageAssets to image embeddings; # We need to define this explicitly for embedding tests - convert_assets_to_embeddings: Optional[Callable[[ImageTestAssets], - list[torch.Tensor]]] = None + convert_assets_to_embeddings: Optional[ + Callable[[ImageTestAssets], list[torch.Tensor]] + ] = None # Exposed options for vLLM runner; we change these in a several tests, # but the defaults are derived from VllmRunner & the engine defaults @@ -156,8 +168,8 @@ class VLMTestInfo(NamedTuple): # for Qwen-VL, which requires encoding the image path / url into the prompt # for HF runner prompt_path_encoder: Optional[ - Callable[[PosixPath, str, Union[list[ImageAsset], ImageTestAssets]], - str]] = None # noqa: E501 + Callable[[PosixPath, str, Union[list[ImageAsset], ImageTestAssets]], str] + ] = None # noqa: E501 # Allows configuring a test to run with custom inputs custom_test_opts: Optional[list[CustomTestOptions]] = None @@ -190,6 +202,7 @@ def get_non_parametrized_runner_kwargs(self): class ExpandableVLMTestArgs(NamedTuple): """The expanded kwargs which correspond to a single test case.""" + model: str max_tokens: int num_logprobs: int diff --git a/tests/models/multimodal/pooling/test_clip.py b/tests/models/multimodal/pooling/test_clip.py new file mode 100644 index 000000000000..b8c6c4abace9 --- /dev/null +++ b/tests/models/multimodal/pooling/test_clip.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import CLIPModel + +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ...utils import check_embeddings_close + +HF_TEXT_PROMPTS = [ + "a photo of a stop sign", + "a photo of a cherry blossom", +] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "", + "cherry_blossom": "", + } +) + +MODELS = ["openai/clip-vit-base-patch32"] + + +def _run_test( + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + input_texts: list[str], + input_images: PromptImageInput, + model: str, + *, + dtype: str, +) -> None: + # NOTE: take care of the order. run vLLM first, and then run HF. 
+ # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True, max_model_len=77 + ) as vllm_model: + vllm_outputs = vllm_model.embed(input_texts, images=input_images) + + with hf_runner(model, dtype=dtype, auto_cls=CLIPModel) as hf_model: + all_inputs = hf_model.get_inputs(input_texts, images=input_images) + + all_outputs = [] + for inputs in all_inputs: + if "pixel_values" in inputs: + inputs.pop("input_ids") + pooled_output = hf_model.model.get_image_features( + **hf_model.wrap_device(inputs) + ).squeeze(0) + else: + pooled_output = hf_model.model.get_text_features( + **hf_model.wrap_device(inputs) + ).squeeze(0) + + all_outputs.append(pooled_output.tolist()) + + hf_outputs = all_outputs + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, # type: ignore + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [ + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, + model, + dtype=dtype, + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_models_text_image_no_crash( + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + texts = [HF_TEXT_PROMPTS[0]] + images = [image_assets[0].pil_image] + + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True, max_model_len=77 + ) as vllm_model: + with pytest.raises(ValueError, match="not both"): + vllm_model.embed(texts, images=images) + + # Should still be able to run subsequent requests + vllm_model.embed(texts) + vllm_model.embed([""], images=images) diff --git a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py index f152ded3fb23..7f30b1f299ba 100644 --- a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py +++ b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py @@ -17,18 +17,21 @@ # T -> X ( "Query: Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501, - Image.new("RGB", (56, 56))), + Image.new("RGB", (56, 56)), + ), # T -> X - ("Query: Retrieve an image of this caption: cherry blossom", - Image.new("RGB", (56, 56))), + ( + "Query: Retrieve an image of this caption: cherry blossom", + Image.new("RGB", (56, 56)), + ), ] -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "What is shown in this image?", - "cherry_blossom": - "What is shown in this image?" 
-}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "What is shown in this image?", + "cherry_blossom": "What is shown in this image?", + } +) MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"] @@ -36,34 +39,30 @@ def get_messages(image: Image.Image, text: str, embed_text: bool): # assert False, 'remember to use outer [] as required' if embed_text: - messages = [{ - "role": - "user", - "content": [ - { - "type": "image", - "image": Image.new("RGB", (56, 56)), - "resized_height": 1, - "resized_width": 1 - }, # need a dummy image here for an easier process. - { - "type": "text", - "text": text - }, - ] - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": Image.new("RGB", (56, 56)), + "resized_height": 1, + "resized_width": 1, + }, # need a dummy image here for an easier process. + {"type": "text", "text": text}, + ], + } + ] else: - messages = [{ - "role": - "user", - "content": [{ - "type": "image", - "image": image - }, { - "type": "text", - "text": text - }] - }] + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": text}, + ], + } + ] return messages @@ -71,8 +70,10 @@ def apply_chat_template_and_add_eos( messages: list[dict], apply_chat_template_fn: Callable, ): - prompt = apply_chat_template_fn( - messages, tokenize=False, add_generation_prompt=True) + "<|endoftext|>" + prompt = ( + apply_chat_template_fn(messages, tokenize=False, add_generation_prompt=True) + + "<|endoftext|>" + ) return prompt @@ -86,16 +87,14 @@ def _run_test( *, dtype: str, ) -> None: - '''SET PYTHONPATH''' + """SET PYTHONPATH""" # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, - runner="pooling", - dtype=dtype, - enforce_eager=True, - max_model_len=8192) as vllm_model: + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True, max_model_len=8192 + ) as vllm_model: tokenizer = vllm_model.llm.get_tokenizer() texts = [ # this is necessary because vllm_model.embed will not apply any @@ -105,25 +104,25 @@ def _run_test( apply_chat_template_and_add_eos( get_messages(image, text, False), apply_chat_template_fn=tokenizer.apply_chat_template, - ) for text, image in zip(input_texts, input_images) + ) + for text, image in zip(input_texts, input_images) # vllm will replace the pad token with the actual image, # which may be a placeholder image, later. 
] vllm_outputs = vllm_model.embed(texts, images=input_images) hf_outputs = [] - with hf_runner(model, - dtype=dtype, - auto_cls=Qwen2VLForConditionalGeneration) as hf_model: - + with hf_runner( + model, dtype=dtype, auto_cls=Qwen2VLForConditionalGeneration + ) as hf_model: prompts = [] - for text, image, embed_text in zip(input_texts, input_images, - embed_texts): + for text, image, embed_text in zip(input_texts, input_images, embed_texts): # dse requires non-standard input processing # because it needs an image_pad token messages = get_messages(image, text, embed_text) prompt = apply_chat_template_and_add_eos( - messages, hf_model.processor.apply_chat_template) + messages, hf_model.processor.apply_chat_template + ) prompts.append(prompt) @@ -145,9 +144,9 @@ def _run_test( return_dict=True, output_hidden_states=True, ) - pooled_output = F.normalize(outputs.hidden_states[-1][0, -1], - p=2, - dim=-1) + pooled_output = F.normalize( + outputs.hidden_states[-1][0, -1], p=2, dim=-1 + ) all_outputs.append(pooled_output.tolist()) @@ -170,8 +169,9 @@ def test_models_text( model: str, dtype: str, ) -> None: - input_texts_images = [(text, image_placeholder) - for text, image_placeholder in HF_TEXT_PROMPTS] + input_texts_images = [ + (text, image_placeholder) for text, image_placeholder in HF_TEXT_PROMPTS + ] input_texts = [text for text, _ in input_texts_images] input_images = [image for _, image in input_texts_images] embed_texts = [True] * len(input_texts) @@ -198,8 +198,7 @@ def test_models_image( dtype: str, ) -> None: input_texts_images = [ - (text, asset.pil_image) - for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) ] input_texts = [text for text, _ in input_texts_images] input_images = [image for _, image in input_texts_images] diff --git a/tests/models/multimodal/pooling/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py index 3e2be34a50ad..b474e851319a 100644 --- a/tests/models/multimodal/pooling/test_intern_vit.py +++ b/tests/models/multimodal/pooling/test_intern_vit.py @@ -29,7 +29,7 @@ def run_intern_vit_test( img_processor = CLIPImageProcessor.from_pretrained(model) images = [asset.pil_image for asset in image_assets] pixel_values = [ - img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype) + img_processor(images, return_tensors="pt").pixel_values.to(torch_dtype) for images in images ] @@ -37,15 +37,16 @@ def run_intern_vit_test( if not getattr(config, "norm_type", None): config.norm_type = "rms_norm" - hf_model = AutoModel.from_pretrained(model, - torch_dtype=torch_dtype, - trust_remote_code=True).to("cuda") + hf_model = AutoModel.from_pretrained( + model, torch_dtype=torch_dtype, trust_remote_code=True + ).to("cuda") hf_outputs_per_image = [ hf_model(pixel_value.to("cuda")).last_hidden_state for pixel_value in pixel_values ] from vllm.model_executor.models.intern_vit import InternVisionModel + vllm_model = InternVisionModel(config) vllm_model.load_weights(hf_model.state_dict().items()) @@ -54,22 +55,23 @@ def run_intern_vit_test( vllm_model = vllm_model.to("cuda", torch_dtype) vllm_outputs_per_image = [ - vllm_model(pixel_values=pixel_value.to("cuda")) - for pixel_value in pixel_values + vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values ] del vllm_model cleanup_dist_env_and_memory() cos_similar = nn.CosineSimilarity(dim=-1) - for vllm_output, hf_output in zip(vllm_outputs_per_image, - hf_outputs_per_image): + for vllm_output, 
hf_output in zip(vllm_outputs_per_image, hf_outputs_per_image): assert cos_similar(vllm_output, hf_output).mean() > 0.99 -@pytest.mark.parametrize("model_id", [ - "OpenGVLab/InternViT-300M-448px", - "OpenGVLab/InternViT-6B-448px-V1-5", -]) +@pytest.mark.parametrize( + "model_id", + [ + "OpenGVLab/InternViT-300M-448px", + "OpenGVLab/InternViT-6B-448px-V1-5", + ], +) @pytest.mark.parametrize("dtype", ["half"]) def test_models(dist_init, image_assets, model_id, dtype: str) -> None: run_intern_vit_test( diff --git a/tests/models/multimodal/pooling/test_jinavl_reranker.py b/tests/models/multimodal/pooling/test_jinavl_reranker.py index 7ad7a8d284cb..853f56618290 100644 --- a/tests/models/multimodal/pooling/test_jinavl_reranker.py +++ b/tests/models/multimodal/pooling/test_jinavl_reranker.py @@ -29,7 +29,6 @@ def vllm_reranker( query_type: str = "text", doc_type: str = "text", ): - def create_image_param(url: str) -> ChatCompletionContentPartImageParam: return {"type": "image_url", "image_url": {"url": f"{url}"}} @@ -38,23 +37,25 @@ def create_image_param(url: str) -> ChatCompletionContentPartImageParam: query = query_strs elif query_type == "image": query = ScoreMultiModalParam( - content=[create_image_param(url) for url in query_strs]) + content=[create_image_param(url) for url in query_strs] + ) documents: Union[list[str], ScoreMultiModalParam] if doc_type == "text": documents = document_strs elif doc_type == "image": documents = ScoreMultiModalParam( - content=[create_image_param(url) for url in document_strs]) + content=[create_image_param(url) for url in document_strs] + ) with vllm_runner( - model_name, - runner="pooling", - dtype=dtype, - max_num_seqs=2, - max_model_len=2048, - mm_processor_kwargs=mm_processor_kwargs, - limit_mm_per_prompt=limit_mm_per_prompt, + model_name, + runner="pooling", + dtype=dtype, + max_num_seqs=2, + max_model_len=2048, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt=limit_mm_per_prompt, ) as vllm_model: outputs = vllm_model.llm.score(query, documents) @@ -78,16 +79,15 @@ def hf_reranker( data_pairs = [[query_strs[0], d] for d in document_strs] with hf_runner( - model_name, - dtype=dtype, - trust_remote_code=True, - auto_cls=AutoModel, - model_kwargs={"key_mapping": checkpoint_to_hf_mapper}, + model_name, + dtype=dtype, + trust_remote_code=True, + auto_cls=AutoModel, + model_kwargs={"key_mapping": checkpoint_to_hf_mapper}, ) as hf_model: - return hf_model.model.compute_score(data_pairs, - max_length=2048, - query_type=query_type, - doc_type=doc_type) + return hf_model.model.compute_score( + data_pairs, max_length=2048, query_type=query_type, doc_type=doc_type + ) # Visual Documents Reranking @@ -100,10 +100,12 @@ def test_model_text_image(hf_runner, vllm_runner, model_name, dtype): "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png", ] - hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, - "text", "image") - vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, - documents, "text", "image") + hf_outputs = hf_reranker( + hf_runner, model_name, dtype, query, documents, "text", "image" + ) + vllm_outputs = vllm_reranker( + vllm_runner, model_name, dtype, query, documents, "text", "image" + ) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) @@ -127,10 +129,12 @@ def test_model_text_text(hf_runner, vllm_runner, model_name, dtype): lower computational requirements.""", # noqa: E501 
"数据提取么?为什么不用正则啊,你用正则不就全解决了么?", ] - hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, - "text", "text") - vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, - documents, "text", "text") + hf_outputs = hf_reranker( + hf_runner, model_name, dtype, query, documents, "text", "text" + ) + vllm_outputs = vllm_reranker( + vllm_runner, model_name, dtype, query, documents, "text", "text" + ) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) @@ -157,10 +161,12 @@ def test_model_image_text(hf_runner, vllm_runner, model_name, dtype): "数据提取么?为什么不用正则啊,你用正则不就全解决了么?", ] - hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, - "image", "text") - vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, - documents, "image", "text") + hf_outputs = hf_reranker( + hf_runner, model_name, dtype, query, documents, "image", "text" + ) + vllm_outputs = vllm_reranker( + vllm_runner, model_name, dtype, query, documents, "image", "text" + ) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) @@ -178,10 +184,12 @@ def test_model_image_image(hf_runner, vllm_runner, model_name, dtype): "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png", ] - hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, - "image", "image") - vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, - documents, "image", "image") + hf_outputs = hf_reranker( + hf_runner, model_name, dtype, query, documents, "image", "image" + ) + vllm_outputs = vllm_reranker( + vllm_runner, model_name, dtype, query, documents, "image", "image" + ) assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) diff --git a/tests/models/multimodal/pooling/test_llava_next.py b/tests/models/multimodal/pooling/test_llava_next.py index 50826677581d..2053ce399483 100644 --- a/tests/models/multimodal/pooling/test_llava_next.py +++ b/tests/models/multimodal/pooling/test_llava_next.py @@ -24,9 +24,10 @@ # built with LAPACK support. 
pytestmark = pytest.mark.skipif( not current_platform.is_cuda(), - reason="Llava Next model uses op that is only supported in CUDA") + reason="Llava Next model uses op that is only supported in CUDA", +) -llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 +llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501 HF_TEXT_PROMPTS = [ # T -> X @@ -34,18 +35,21 @@ "The label of the object is stop sign\nSummary above sentence in one word: " # noqa: E501 ), # T -> X - llama3_template.format( - "cherry blossom\nSummary above sentence in one word: "), + llama3_template.format("cherry blossom\nSummary above sentence in one word: "), ] -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - # I -> X - "stop_sign": - llama3_template.format("<image>\nSummary above image in one word: "), - # I -> X - "cherry_blossom": - llama3_template.format("<image>\nSummary above image in one word: "), -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + # I -> X + "stop_sign": llama3_template.format( + "<image>\nSummary above image in one word: " + ), + # I -> X + "cherry_blossom": llama3_template.format( + "<image>\nSummary above image in one word: " + ), + } +) MODELS = ["royokong/e5-v"] @@ -63,23 +67,22 @@ def _run_test( # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, - runner="pooling", - dtype=dtype, - max_model_len=4096, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, runner="pooling", dtype=dtype, max_model_len=4096, enforce_eager=True + ) as vllm_model: vllm_outputs = vllm_model.embed(input_texts, images=input_images) - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForImageTextToText) as hf_model: + with hf_runner( + model, dtype=dtype, auto_cls=AutoModelForImageTextToText + ) as hf_model: # Patch the issue where generation_config.json is missing - hf_model.processor.patch_size = \ - hf_model.model.config.vision_config.patch_size + hf_model.processor.patch_size = hf_model.model.config.vision_config.patch_size # Patch the issue where image_token_id # exceeds the maximum allowed vocab size hf_model.model.resize_token_embeddings( - hf_model.model.language_model.vocab_size + 1) + hf_model.model.language_model.vocab_size + 1 + ) all_inputs = hf_model.get_inputs(input_texts, images=input_images) @@ -91,8 +94,7 @@ def _run_test( return_dict=True, output_hidden_states=True, ) - pooled_output = F.normalize(outputs.hidden_states[-1][0, -1, :], - dim=-1) + pooled_output = F.normalize(outputs.hidden_states[-1][0, -1, :], dim=-1) all_outputs.append(pooled_output.tolist()) @@ -142,8 +144,7 @@ def test_models_image( dtype: str, ) -> None: input_texts_images = [ - (text, asset.pil_image) - for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) ] input_texts = [text for text, _ in input_texts_images] input_images = [image for _, image in input_texts_images] diff --git a/tests/models/multimodal/pooling/test_phi3v.py b/tests/models/multimodal/pooling/test_phi3v.py index f918a0bd781e..c799a5bd3e1e 100644 --- a/tests/models/multimodal/pooling/test_phi3v.py +++ b/tests/models/multimodal/pooling/test_phi3v.py @@ -19,14 +19,14 @@ "Retrieve an image of this caption: cherry blossom", ] 
-HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - # T + I -> X - "stop_sign": - "<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501 - # I -> X - "cherry_blossom": - "<|image_1|> Represent the given image for classification", # noqa: E501 -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + # T + I -> X + "stop_sign": "<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501 + # I -> X + "cherry_blossom": "<|image_1|> Represent the given image for classification", # noqa: E501 + } +) MODELS = ["TIGER-Lab/VLM2Vec-Full"] @@ -44,14 +44,14 @@ def _run_test( # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - with vllm_runner(model, runner="pooling", dtype=dtype, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, runner="pooling", dtype=dtype, enforce_eager=True + ) as vllm_model: vllm_outputs = vllm_model.embed(input_texts, images=input_images) # use eager mode for hf runner, since phi3_v didn't work with flash_attn hf_model_kwargs = {"_attn_implementation": "eager"} - with hf_runner(model, dtype=dtype, - model_kwargs=hf_model_kwargs) as hf_model: + with hf_runner(model, dtype=dtype, model_kwargs=hf_model_kwargs) as hf_model: all_inputs = hf_model.get_inputs(input_texts, images=input_images) all_outputs = [] @@ -114,18 +114,21 @@ def test_models_image( dtype: str, ) -> None: input_texts_images = [ - (text, asset.pil_image) - for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + (text, asset.pil_image) for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) ] # add cases for special_tokens - input_texts_images.append(( - "\n<s><|user|>\n <|image_1|>\n\t <s>" - "Represent the given image for classification<|end|>" - "\n<|assistant|>\n", - Image.open( - get_vllm_public_assets(filename="cherry_blossom.jpg", - s3_prefix=VLM_IMAGES_DIR)), - )) + input_texts_images.append( + ( + "\n<s><|user|>\n <|image_1|>\n\t <s>" + "Represent the given image for classification<|end|>" + "\n<|assistant|>\n", + Image.open( + get_vllm_public_assets( + filename="cherry_blossom.jpg", s3_prefix=VLM_IMAGES_DIR + ) + ), + ) + ) input_texts = [text for text, _ in input_texts_images] input_images = [image for _, image in input_texts_images] diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index 7309660ea526..abf4150a9132 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -19,25 +19,25 @@ def _run_test( vllm_runner: type[VllmRunner], model: str, ) -> None: - prompt = [ { # This model deals with no text input "prompt_token_ids": [1], "multi_modal_data": generate_test_mm_data(), - } for _ in range(10) + } + for _ in range(10) ] with vllm_runner( - model, - runner="pooling", - dtype="half", - enforce_eager=True, - skip_tokenizer_init=True, - # Limit the maximum number of sequences to avoid the - # test going OOM during the warmup run - max_num_seqs=32, - default_torch_num_threads=1, + model, + runner="pooling", + dtype="half", + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during the warmup run + max_num_seqs=32, + default_torch_num_threads=1, ) as vllm_model: 
vllm_model.encode(prompt) diff --git a/tests/models/multimodal/pooling/test_radio.py b/tests/models/multimodal/pooling/test_radio.py index 27b9fe369e80..80f594021ca8 100644 --- a/tests/models/multimodal/pooling/test_radio.py +++ b/tests/models/multimodal/pooling/test_radio.py @@ -34,9 +34,9 @@ def run_radio_test( # Using `self.get_nearest_supported_resolution`, for assets 432x642 the # nearest supported resolution is 432x640. pixel_values = [ - img_processor( - image, - return_tensors='pt').pixel_values.to(torch_dtype)[:, :, :, :640] + img_processor(image, return_tensors="pt").pixel_values.to(torch_dtype)[ + :, :, :, :640 + ] for image in images ] @@ -51,32 +51,33 @@ def run_radio_test( hf_model.eval() hf_outputs_per_image = [ - hf_model(pixel_value.to("cuda")).features - for pixel_value in pixel_values + hf_model(pixel_value.to("cuda")).features for pixel_value in pixel_values ] - radio_config = RadioConfig(model_name=config.args["model"], - reg_tokens=config.args["register_multiple"]) + radio_config = RadioConfig( + model_name=config.args["model"], reg_tokens=config.args["register_multiple"] + ) vllm_model = RadioModel(radio_config) vllm_model.load_weights(hf_model.state_dict()) vllm_model = vllm_model.to("cuda", torch_dtype) vllm_outputs_per_image = [ - vllm_model(pixel_values=pixel_value.to("cuda")) - for pixel_value in pixel_values + vllm_model(pixel_values=pixel_value.to("cuda")) for pixel_value in pixel_values ] del vllm_model, hf_model cleanup_dist_env_and_memory() cos_similar = nn.CosineSimilarity(dim=-1) - for vllm_output, hf_output in zip(vllm_outputs_per_image, - hf_outputs_per_image): + for vllm_output, hf_output in zip(vllm_outputs_per_image, hf_outputs_per_image): assert cos_similar(vllm_output, hf_output).mean() > 0.99 -@pytest.mark.parametrize("model_id", [ - "nvidia/C-RADIOv2-H", -]) +@pytest.mark.parametrize( + "model_id", + [ + "nvidia/C-RADIOv2-H", + ], +) @pytest.mark.parametrize("dtype", ["half"]) def test_radio(dist_init, image_assets, model_id, dtype: str) -> None: run_radio_test( diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index e8c28afee7e3..d9d85f7e0c00 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -6,22 +6,27 @@ import numpy as np import pytest -from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, - UserMessage) +from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image from vllm.config import ModelConfig -from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions, - ImageDummyOptions, VideoDummyOptions) +from vllm.config.multimodal import ( + AudioDummyOptions, + BaseDummyOptions, + ImageDummyOptions, + VideoDummyOptions, +) from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs -from vllm.multimodal.processing import (BaseMultiModalProcessor, - InputProcessingContext) -from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, - cached_tokenizer_from_config, - encode_tokens) +from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + MistralTokenizer, + cached_tokenizer_from_config, + encode_tokens, +) from 
....multimodal.utils import random_audio, random_image, random_video from ...registry import HF_EXAMPLE_MODELS @@ -36,14 +41,17 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: # GLM4.1V doesn't support multiple videos video = mm_data["video"] num_frames = len(video) - mm_data["video"] = (video, { - "total_num_frames": num_frames, - "fps": num_frames, - "duration": 1, - "frames_indices": [i for i in range(num_frames)], - "video_backend": "opencv", - "do_sample_frames": True, - }) + mm_data["video"] = ( + video, + { + "total_num_frames": num_frames, + "fps": num_frames, + "duration": 1, + "frames_indices": [i for i in range(num_frames)], + "video_backend": "opencv", + "do_sample_frames": True, + }, + ) return mm_data @@ -102,7 +110,8 @@ def _test_processing_correctness( mm_processor_cache_gb=2048, skip_tokenizer_init=model_info.skip_tokenizer_init, enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype) + dtype=model_info.dtype, + ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) factories = MULTIMODAL_REGISTRY._processor_factories[model_cls] @@ -145,27 +154,22 @@ def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: input_to_hit = { "image": Image.new("RGB", size=(128, 128)), "video": np.zeros((4, 128, 128, 3), dtype=np.uint8), - "audio": (np.zeros((512, )), 16000), + "audio": (np.zeros((512,)), 16000), } input_factory = { - "image": - partial(random_image, rng, min_wh=128, max_wh=256), - "video": - partial(random_video, - rng, - min_frames=2, - max_frames=16, - min_wh=128, - max_wh=256), - "audio": - partial(random_audio, rng, min_len=512, max_len=1024, sr=16000), + "image": partial(random_image, rng, min_wh=128, max_wh=256), + "video": partial( + random_video, rng, min_frames=2, max_frames=16, min_wh=128, max_wh=256 + ), + "audio": partial(random_audio, rng, min_len=512, max_len=1024, sr=16000), } for batch_idx in range(num_batches): mm_data = { - k: - [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) - for _ in range(rng.randint(limit + 1))] + k: [ + (input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) + for _ in range(rng.randint(limit + 1)) + ] for k, limit in limit_mm_per_prompt_ints.items() } @@ -174,12 +178,16 @@ def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: # Mistral chat outputs tokens directly, rather than text prompts if isinstance(tokenizer, MistralTokenizer): images = mm_data.get("image", []) - request = ChatCompletionRequest(messages=[ - UserMessage(content=[ - TextChunk(text=""), - *(ImageChunk(image=image) for image in images), - ]), - ]) + request = ChatCompletionRequest( + messages=[ + UserMessage( + content=[ + TextChunk(text=""), + *(ImageChunk(image=image) for image in images), + ] + ), + ] + ) res = tokenizer.mistral.encode_chat_completion(request) prompt = res.tokens else: @@ -303,93 +311,92 @@ def _test_processing_correctness_one( baseline_text_result, baseline_tokenized_result, ignore_mm_keys=ignore_mm_keys, - msg=f"Failed ({batch_idx=}, {text_prompt=}, " - f"{token_prompt=}, {mm_data=})", + msg=f"Failed ({batch_idx=}, {text_prompt=}, {token_prompt=}, {mm_data=})", ) _assert_inputs_equal( cached_text_result, cached_tokenized_result, ignore_mm_keys=ignore_mm_keys, - msg=f"Failed ({batch_idx=}, {text_prompt=}, " - f"{token_prompt=}, {mm_data=})", + msg=f"Failed ({batch_idx=}, {text_prompt=}, {token_prompt=}, {mm_data=})", ) -# yapf: disable -@pytest.mark.parametrize("model_id", [ - "rhymes-ai/Aria", - "CohereForAI/aya-vision-8b", - 
"Salesforce/blip2-opt-2.7b", - "facebook/chameleon-7b", - "CohereLabs/command-a-vision-07-2025", - "deepseek-ai/deepseek-vl2-tiny", - "baidu/ERNIE-4.5-VL-28B-A3B-PT", - "adept/fuyu-8b", - "google/gemma-3-4b-it", - "google/gemma-3n-E2B-it", - "zai-org/glm-4v-9b", - "zai-org/GLM-4.1V-9B-Thinking", - "zai-org/GLM-4.5V", - "ibm-granite/granite-speech-3.3-2b", - "h2oai/h2ovl-mississippi-800m", - "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", - "HuggingFaceM4/Idefics3-8B-Llama3", - "internlm/Intern-S1", - "OpenGVLab/InternVL2-1B", - "OpenGVLab/InternVL3-1B", - "OpenGVLab/InternVL3_5-1B", - "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", - "OpenGVLab/InternVL3_5-30B-A3B", - "Kwai-Keye/Keye-VL-8B-Preview", - "Kwai-Keye/Keye-VL-1_5-8B", - "moonshotai/Kimi-VL-A3B-Instruct", - "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "llava-hf/llava-1.5-7b-hf", - "llava-hf/llava-v1.6-mistral-7b-hf", - "llava-hf/LLaVA-NeXT-Video-7B-hf", - "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", - "TIGER-Lab/Mantis-8B-siglip-llama3", - "mispeech/midashenglm-7b", - "openbmb/MiniCPM-Llama3-V-2_5", - "openbmb/MiniCPM-o-2_6", - "openbmb/MiniCPM-V-2_6", - "MiniMaxAI/MiniMax-VL-01", - "allenai/Molmo-7B-D-0924", - "allenai/Molmo-7B-O-0924", - "nvidia/NVLM-D-72B", - "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", - "AIDC-AI/Ovis1.6-Gemma2-9B", - "AIDC-AI/Ovis1.6-Llama3.2-3B", - "AIDC-AI/Ovis2-1B", - "AIDC-AI/Ovis2.5-2B", - "google/paligemma-3b-mix-224", - "google/paligemma2-3b-ft-docci-448", - "microsoft/Phi-3.5-vision-instruct", - "microsoft/Phi-4-multimodal-instruct", - "mistralai/Pixtral-12B-2409", - "mistral-community/pixtral-12b", - "Qwen/Qwen-VL-Chat", - "Qwen/Qwen2-VL-2B-Instruct", - "Qwen/Qwen2.5-VL-3B-Instruct", - "Qwen/Qwen2-Audio-7B-Instruct", - "Qwen/Qwen2.5-Omni-3B", - "Qwen/Qwen3-VL-4B-Instruct", - "Qwen/Qwen3-VL-30B-A3B-Instruct", - "YannQi/R-4B", - "Skywork/Skywork-R1V-38B", - "HuggingFaceTB/SmolVLM2-2.2B-Instruct", - "stepfun-ai/step3", - "fixie-ai/ultravox-v0_5-llama-3_2-1b", - "openai/whisper-large-v3", - "omni-research/Tarsier-7b", - "omni-research/Tarsier2-Recap-7b", - "mistralai/Voxtral-Mini-3B-2507", -]) +@pytest.mark.parametrize( + "model_id", + [ + "rhymes-ai/Aria", + "CohereForAI/aya-vision-8b", + "Salesforce/blip2-opt-2.7b", + "facebook/chameleon-7b", + "CohereLabs/command-a-vision-07-2025", + "deepseek-ai/deepseek-vl2-tiny", + "baidu/ERNIE-4.5-VL-28B-A3B-PT", + "adept/fuyu-8b", + "google/gemma-3-4b-it", + "google/gemma-3n-E2B-it", + "zai-org/glm-4v-9b", + "zai-org/GLM-4.1V-9B-Thinking", + "zai-org/GLM-4.5V", + "ibm-granite/granite-speech-3.3-2b", + "h2oai/h2ovl-mississippi-800m", + "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", + "HuggingFaceM4/Idefics3-8B-Llama3", + "internlm/Intern-S1", + "OpenGVLab/InternVL2-1B", + "OpenGVLab/InternVL3-1B", + "OpenGVLab/InternVL3_5-1B", + "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", + "OpenGVLab/InternVL3_5-30B-A3B", + "Kwai-Keye/Keye-VL-8B-Preview", + "Kwai-Keye/Keye-VL-1_5-8B", + "moonshotai/Kimi-VL-A3B-Instruct", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "llava-hf/llava-1.5-7b-hf", + "llava-hf/llava-v1.6-mistral-7b-hf", + "llava-hf/LLaVA-NeXT-Video-7B-hf", + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + "TIGER-Lab/Mantis-8B-siglip-llama3", + "mispeech/midashenglm-7b", + "openbmb/MiniCPM-Llama3-V-2_5", + "openbmb/MiniCPM-o-2_6", + "openbmb/MiniCPM-V-2_6", + "MiniMaxAI/MiniMax-VL-01", + "allenai/Molmo-7B-D-0924", + "allenai/Molmo-7B-O-0924", + "nvidia/NVLM-D-72B", + "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", + "AIDC-AI/Ovis1.6-Gemma2-9B", + 
"AIDC-AI/Ovis1.6-Llama3.2-3B", + "AIDC-AI/Ovis2-1B", + "AIDC-AI/Ovis2.5-2B", + "google/paligemma-3b-mix-224", + "google/paligemma2-3b-ft-docci-448", + "microsoft/Phi-3.5-vision-instruct", + "microsoft/Phi-4-multimodal-instruct", + "mistralai/Pixtral-12B-2409", + "mistral-community/pixtral-12b", + "Qwen/Qwen-VL-Chat", + "Qwen/Qwen2-VL-2B-Instruct", + "Qwen/Qwen2.5-VL-3B-Instruct", + "Qwen/Qwen2-Audio-7B-Instruct", + "Qwen/Qwen2.5-Omni-3B", + "Qwen/Qwen3-VL-4B-Instruct", + "Qwen/Qwen3-VL-30B-A3B-Instruct", + "YannQi/R-4B", + "Skywork/Skywork-R1V-38B", + "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + "stepfun-ai/step3", + "fixie-ai/ultravox-v0_5-llama-3_2-1b", + "openai/whisper-large-v3", + "omni-research/Tarsier-7b", + "omni-research/Tarsier2-Recap-7b", + "mistralai/Voxtral-Mini-3B-2507", + ], +) @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) @pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("simplify_rate", [1.0]) -# yapf: enable def test_processing_correctness( model_id: str, hit_rate: float, diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index 070ddcd89ee9..553a5f719bd3 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ -24,7 +24,8 @@ # post-sampled frames (expected behavior) (-1, 1, 5), (-1, 2, 10), - ]) + ], +) def test_processor_override( model_id: str, expected_toks_per_frame: int, @@ -55,10 +56,8 @@ def test_processor_override( # Ensure we have the right number of placeholders per num_crops size hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) video_token_id = tokenizer.convert_tokens_to_ids(hf_processor.video_token) - video_tok_count = processed_inputs["prompt_token_ids"].count( - video_token_id) - grid_t, _, _ = processed_inputs["mm_kwargs"].get_data( - )["video_grid_thw"][0] + video_tok_count = processed_inputs["prompt_token_ids"].count(video_token_id) + grid_t, _, _ = processed_inputs["mm_kwargs"].get_data()["video_grid_thw"][0] assert grid_t == expected_grid_t assert video_tok_count == expected_toks_per_frame * grid_t @@ -71,7 +70,7 @@ def test_video_loader_consistency( fps: int, ): """ - Ensure dynamic video loader (pre-sampled by loader) and normal video + Ensure dynamic video loader (pre-sampled by loader) and normal video loader (post-sampled by processor) produce same video processing outputs. 
""" ctx = build_model_context( @@ -91,7 +90,8 @@ def test_video_loader_consistency( static_video, static_metadata = OpenCVVideoBackend.load_bytes(video_bytes) dynamic_video, dynamic_metadata = OpenCVDynamicVideoBackend.load_bytes( - video_bytes, fps=fps) + video_bytes, fps=fps + ) # pre-sampled loader shouldn't read all frames assert len(dynamic_video) < len(static_video) @@ -99,12 +99,11 @@ def test_video_loader_consistency( static_mm_data = {"video": [(static_video, static_metadata)]} dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]} - static_outputs = processor.apply(prompt, static_mm_data, - hf_processor_mm_kwargs) - dynamic_outputs = processor.apply(prompt, dynamic_mm_data, - hf_processor_mm_kwargs) + static_outputs = processor.apply(prompt, static_mm_data, hf_processor_mm_kwargs) + dynamic_outputs = processor.apply(prompt, dynamic_mm_data, hf_processor_mm_kwargs) - assert static_outputs["prompt_token_ids"] == dynamic_outputs[ - "prompt_token_ids"] - assert static_outputs["mm_kwargs"].get_data( - ) == dynamic_outputs["mm_kwargs"].get_data() + assert static_outputs["prompt_token_ids"] == dynamic_outputs["prompt_token_ids"] + assert ( + static_outputs["mm_kwargs"].get_data() + == dynamic_outputs["mm_kwargs"].get_data() + ) diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 1adfe21352c4..bd21d4008fa7 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for H2OVL's multimodal preprocessing kwargs.""" + from collections.abc import Mapping from typing import Optional @@ -23,8 +24,10 @@ def _get_expected_num_patches( min_num: int, max_num: int, ): - from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets, - get_h2ovl_target_ratios) + from vllm.model_executor.models.h2ovl import ( + calculate_h2ovl_targets, + get_h2ovl_target_ratios, + ) width, height = image.size @@ -101,24 +104,27 @@ def _run_check( total_expected_num_patches = sum( _get_expected_num_patches(config, image, len(images), min_num, max_num) - for image in images) + for image in images + ) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"].get_data( - )["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data()["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches -@pytest.mark.parametrize("model_id", [ - "h2oai/h2ovl-mississippi-800m", - "h2oai/h2ovl-mississippi-2b", -]) +@pytest.mark.parametrize( + "model_id", + [ + "h2oai/h2ovl-mississippi-800m", + "h2oai/h2ovl-mississippi-2b", + ], +) @pytest.mark.parametrize( "size_factors", [ @@ -165,10 +171,7 @@ def test_processor_override( _run_check( processor, - [ - rescale_image_size(image_assets[0].pil_image, f) - for f in size_factors - ], + [rescale_image_size(image_assets[0].pil_image, f) for f in size_factors], min_num, max_num, hf_processor_mm_kwargs, diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py index d3a55993e558..351b9d018eec 100644 --- 
a/tests/models/multimodal/processing/test_idefics3.py +++ b/tests/models/multimodal/processing/test_idefics3.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for Idefics3's multimodal preprocessing kwargs.""" + import pytest from transformers import Idefics3Config @@ -11,14 +12,13 @@ @pytest.mark.parametrize("model_id", ["HuggingFaceM4/Idefics3-8B-Llama3"]) -# yapf: disable @pytest.mark.parametrize( ("mm_processor_kwargs", "expected_toks_per_img"), [ ({"size": {"longest_edge": 364}}, 169), ({"size": {"longest_edge": 728}}, 169 * (2**2 + 1)), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( @@ -42,8 +42,11 @@ def test_processor_override( hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs # Build the image str / prompt based on the number of images we pass - placeholders = "<image>" if num_imgs == 1 else "\n".join( - f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1)) + placeholders = ( + "<image>" + if num_imgs == 1 + else "\n".join(f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1)) + ) prompt = f"<|begin_of_text|>User:{placeholders}\n<end_of_utterance>\nAssistant:" # noqa: E501 # Build mm_data @@ -57,8 +60,7 @@ def test_processor_override( # Ensure the placeholders format are correct hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"]) - assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[ - "input_ids"][0] + assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0] # Ensure we have the right number of placeholders per num_crops size image_token_id = ctx.get_hf_config().image_token_id diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index e4f25f5ac712..6f6529cb9401 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for InternVL's multimodal preprocessing kwargs.""" + from collections.abc import Mapping from typing import Optional @@ -24,7 +25,9 @@ def _get_expected_num_patches( max_num: int, ): from vllm.model_executor.models.internvl import ( - calculate_internvl_targets, get_internvl_target_ratios) + calculate_internvl_targets, + get_internvl_target_ratios, + ) width, height = image.size @@ -61,15 +64,15 @@ def _run_check( total_expected_num_patches = sum( _get_expected_num_patches(config, image, len(images), min_num, max_num) - for image in images) + for image in images + ) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"].get_data( - )["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data()["pixel_values_flat"].shape assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches @@ -122,10 +125,7 @@ def test_processor_override( _run_check( processor, - [ - rescale_image_size(image_assets[0].pil_image, f) - for f in 
size_factors - ], + [rescale_image_size(image_assets[0].pil_image, f) for f in size_factors], min_num, max_num, hf_processor_mm_kwargs, diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index bea4f43567ee..4c0791ea3cec 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -11,8 +11,7 @@ from ...utils import build_model_context -@pytest.mark.parametrize("model_id", - ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) +@pytest.mark.parametrize("model_id", ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) @pytest.mark.parametrize("mm_processor_kwargs", [{}]) @pytest.mark.parametrize("num_imgs", [1, 5]) @pytest.mark.parametrize("mm_processor_cache_gb", [0, 4]) @@ -38,13 +37,14 @@ def test_processor_override( hf_processor = processor.info.get_hf_processor() vocab = tokenizer.get_vocab() - prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \ - + "<|image|>" * num_imgs \ + prompt = ( + "<|begin_of_text|><|header_start|>user<|header_end|>" + + "<|image|>" * num_imgs + "<|eot|><|header_start|>assistant<|header_end|>" + ) mm_data = { "image": [ - image_assets[(i % len(image_assets))].pil_image - for i in range(num_imgs) + image_assets[(i % len(image_assets))].pil_image for i in range(num_imgs) ] } if tokenized_prompt: @@ -64,22 +64,23 @@ def test_processor_override( if tiles_x * tiles_y > 1: num_x_separators += (tiles_x - 1) * tiles_y num_y_separators += tiles_y - assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \ - == num_x_separators - assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \ - == num_y_separators + assert prompt_token_ids.count(vocab[hf_processor.tile_token]) == num_x_separators + assert ( + prompt_token_ids.count(vocab[hf_processor.tile_global_token]) + == num_y_separators + ) # image token offsets img_locs = processed_inputs["mm_placeholders"].get("image", []) assert len(img_locs) == num_imgs - assert [img_loc.offset for img_loc in img_locs] == \ - [i for i, v in enumerate(prompt_token_ids) \ - if v == config.boi_token_index] + assert [img_loc.offset for img_loc in img_locs] == [ + i for i, v in enumerate(prompt_token_ids) if v == config.boi_token_index + ] # patch sizes and masks - num_patches_per_chunk = processor.info.get_patch_per_chunk( - config.vision_config) - assert prompt_token_ids.count(config.image_token_index) \ + num_patches_per_chunk = processor.info.get_patch_per_chunk(config.vision_config) + assert ( + prompt_token_ids.count(config.image_token_index) == sum(mm_data["patches_per_image"]) * num_patches_per_chunk - assert len(mm_data["pixel_values"]) \ - == sum(mm_data["patches_per_image"]) + ) + assert len(mm_data["pixel_values"]) == sum(mm_data["patches_per_image"]) diff --git a/tests/models/multimodal/processing/test_llava_next.py b/tests/models/multimodal/processing/test_llava_next.py index ca34d1d758a4..ffe7ca17b5d6 100644 --- a/tests/models/multimodal/processing/test_llava_next.py +++ b/tests/models/multimodal/processing/test_llava_next.py @@ -22,8 +22,9 @@ def _validate_image_max_tokens_one( image_size: ImageSize, ) -> None: info = processor.info - feature_size = info.get_num_image_tokens(image_width=image_size.width, - image_height=image_size.height) + feature_size = info.get_num_image_tokens( + image_width=image_size.width, image_height=image_size.height + ) try: assert feature_size <= max_tokens, f"{feature_size} <= {max_tokens}" @@ -31,8 +32,9 @@ def _validate_image_max_tokens_one( 
failed_size_excs.append((image_size, exc)) -@pytest.mark.skip("This test takes around 5 minutes to run. " - "Comment this out to run it manually.") +@pytest.mark.skip( + "This test takes around 5 minutes to run. Comment this out to run it manually." +) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) def test_processor_max_tokens(model_id): ctx = build_model_context( @@ -66,9 +68,9 @@ def test_processor_max_tokens(model_id): pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) @@ -94,8 +96,10 @@ def _validate_image_prompt_replacements_one( # NOTE: There is a BOS token assert first_placeholder.offset == 1 - assert first_placeholder.length == ( - len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs + assert ( + first_placeholder.length + == (len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs + ) except Exception as exc: failed_size_excs.append((image_size, exc)) @@ -122,9 +126,9 @@ def _test_image_prompt_replacements( pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) @@ -138,11 +142,17 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) - image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), - (488, 183), (2560, 1669)] + image_ratios = [ + (171, 152), + (184, 161), + (198, 176), + (333, 296), + (369, 328), + (488, 183), + (2560, 1669), + ] image_sizes = [ - size for w, h in image_ratios - for size in [ImageSize(w, h), ImageSize(h, w)] + size for w, h in image_ratios for size in [ImageSize(w, h), ImageSize(h, w)] ] _test_image_prompt_replacements( @@ -152,8 +162,9 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) -@pytest.mark.skip("This test takes around 2 hours to run. " - "Comment this out to run it manually.") +@pytest.mark.skip( + "This test takes around 2 hours to run. Comment this out to run it manually." 
+) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("num_imgs", [1]) def test_processor_prompt_replacements_all(model_id, num_imgs): diff --git a/tests/models/multimodal/processing/test_llava_onevision.py b/tests/models/multimodal/processing/test_llava_onevision.py index e6344c4e7e6f..f5c552fe6476 100644 --- a/tests/models/multimodal/processing/test_llava_onevision.py +++ b/tests/models/multimodal/processing/test_llava_onevision.py @@ -22,8 +22,9 @@ def _validate_image_max_tokens_one( image_size: ImageSize, ) -> None: info = processor.info - feature_size = info.get_num_image_tokens(image_width=image_size.width, - image_height=image_size.height) + feature_size = info.get_num_image_tokens( + image_width=image_size.width, image_height=image_size.height + ) try: assert feature_size <= max_tokens, f"{feature_size} <= {max_tokens}" @@ -31,10 +32,10 @@ def _validate_image_max_tokens_one( failed_size_excs.append((image_size, exc)) -@pytest.mark.skip("This test takes around 5 minutes to run. " - "Comment this out to run it manually.") -@pytest.mark.parametrize("model_id", - ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.skip( + "This test takes around 5 minutes to run. Comment this out to run it manually." +) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) def test_processor_max_tokens(model_id): ctx = build_model_context( model_id, @@ -67,9 +68,9 @@ def test_processor_max_tokens(model_id): pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) @@ -94,8 +95,10 @@ def _validate_image_prompt_replacements_one( first_placeholder = image_placeholders[0] assert first_placeholder.offset == 0 - assert first_placeholder.length == len( - processed_inputs["prompt_token_ids"]) // num_imgs + assert ( + first_placeholder.length + == len(processed_inputs["prompt_token_ids"]) // num_imgs + ) except Exception as exc: failed_size_excs.append((image_size, exc)) @@ -121,14 +124,13 @@ def _test_image_prompt_replacements( pqdm(image_sizes, validate_one, n_jobs=8, desc="Validating image sizes") if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) -@pytest.mark.parametrize("model_id", - ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) @pytest.mark.parametrize("num_imgs", [1, 2]) def test_processor_prompt_replacements_regression(model_id, num_imgs): ctx = build_model_context( @@ -138,11 +140,17 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) - image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), - (488, 183), (2560, 1669)] + image_ratios = [ + (171, 152), + (184, 161), + (198, 176), + (333, 296), + (369, 328), + (488, 183), + (2560, 1669), + ] image_sizes = [ - size for w, h in image_ratios - for size in [ImageSize(w, h), ImageSize(h, w)] + size for w, h in image_ratios for size in [ImageSize(w, h), 
ImageSize(h, w)] ] _test_image_prompt_replacements( @@ -152,10 +160,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) -@pytest.mark.skip("This test takes around 2 hours to run. " - "Comment this out to run it manually.") -@pytest.mark.parametrize("model_id", - ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.skip( + "This test takes around 2 hours to run. Comment this out to run it manually." +) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) @pytest.mark.parametrize("num_imgs", [1]) def test_processor_prompt_replacements_all(model_id, num_imgs): ctx = build_model_context( diff --git a/tests/models/multimodal/processing/test_minimax_vl_01.py b/tests/models/multimodal/processing/test_minimax_vl_01.py index 9387212e3f10..11e000123511 100644 --- a/tests/models/multimodal/processing/test_minimax_vl_01.py +++ b/tests/models/multimodal/processing/test_minimax_vl_01.py @@ -61,17 +61,17 @@ def _test_image_prompt_replacements( num_imgs: int, image_sizes: list[ImageSize], ) -> None: - failed_size_excs = list[tuple[ImageSize, Exception]]() for size in image_sizes: - _validate_image_prompt_replacements_one(processor, num_imgs, - failed_size_excs, size) + _validate_image_prompt_replacements_one( + processor, num_imgs, failed_size_excs, size + ) if failed_size_excs: - msg = "Found failing image sizes:" \ - + "\n========\n".join(f"[{size}]\n{exc}" - for size, exc in failed_size_excs) + msg = "Found failing image sizes:" + "\n========\n".join( + f"[{size}]\n{exc}" for size, exc in failed_size_excs + ) raise AssertionError(msg) @@ -85,11 +85,17 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs): ) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) - image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328), - (488, 183), (2560, 1669)] + image_ratios = [ + (171, 152), + (184, 161), + (198, 176), + (333, 296), + (369, 328), + (488, 183), + (2560, 1669), + ] image_sizes = [ - size for w, h in image_ratios - for size in [ImageSize(w, h), ImageSize(h, w)] + size for w, h in image_ratios for size in [ImageSize(w, h), ImageSize(h, w)] ] _test_image_prompt_replacements( diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py index a155ada35e92..e5ff2d1391b6 100644 --- a/tests/models/multimodal/processing/test_mllama4.py +++ b/tests/models/multimodal/processing/test_mllama4.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for mllama's multimodal preprocessing and profiling.""" + import pytest from torch import prod from transformers import Llama4Config @@ -47,14 +48,17 @@ def test_profiling(model_id: str, max_model_len: int): image_size = hf_config.vision_config.image_size patch_size = hf_config.vision_config.patch_size downsample_ratio = int( - round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))) - tokens_per_patch = ((image_size // patch_size)**2) // downsample_ratio + round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2)) + ) + tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio chunks_per_image = prod(mm_data["patches_per_image"]) total_num_patches = chunks_per_image * tokens_per_patch - num_tiles = mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][ - 1] # x-y separator tokens - total_tokens = total_num_patches.item() + num_tiles.item( - ) + 3 # image start, image, image end + num_tiles = 
( + mm_data["aspect_ratios"][0][0] * mm_data["aspect_ratios"][0][1] + ) # x-y separator tokens + total_tokens = ( + total_num_patches.item() + num_tiles.item() + 3 + ) # image start, image, image end profiled_tokens = profiler.get_mm_max_contiguous_tokens( max_model_len, @@ -63,5 +67,6 @@ def test_profiling(model_id: str, max_model_len: int): assert total_tokens == profiled_tokens["image"] assert total_tokens == sum( - placeholder.length for placeholder in - decoder_dummy_data.multi_modal_placeholders["image"]) + placeholder.length + for placeholder in decoder_dummy_data.multi_modal_placeholders["image"] + ) diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py index d9f1965a053d..6ff6f396fa33 100644 --- a/tests/models/multimodal/processing/test_nemotron_vl.py +++ b/tests/models/multimodal/processing/test_nemotron_vl.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for Nemotron-Nano-VL's multimodal preprocessing kwargs.""" + from collections.abc import Mapping from typing import Optional @@ -24,7 +25,9 @@ def _get_expected_num_patches( max_num: int, ): from vllm.model_executor.models.nemotron_vl import ( - calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios) + calculate_nemotron_vl_targets, + get_nemotron_vl_target_ratios, + ) width, height = image.size @@ -63,22 +66,21 @@ def _run_check( total_expected_num_patches = sum( _get_expected_num_patches(config, image, len(images), min_num, max_num) - for image in images) + for image in images + ) print(total_expected_num_patches) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("<image>") img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"].get_data( - )["pixel_values_flat"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data()["pixel_values_flat"].shape print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape) assert img_tok_count == 256 * total_expected_num_patches assert pixel_shape[0] == total_expected_num_patches -@pytest.mark.parametrize("model_id", - ["nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"]) +@pytest.mark.parametrize("model_id", ["nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"]) @pytest.mark.parametrize( "size_factors", [ @@ -125,10 +127,7 @@ def test_processor_override( _run_check( processor, - [ - rescale_image_size(image_assets[0].pil_image, f) - for f in size_factors - ], + [rescale_image_size(image_assets[0].pil_image, f) for f in size_factors], min_num, max_num, hf_processor_mm_kwargs, diff --git a/tests/models/multimodal/processing/test_phi3v.py b/tests/models/multimodal/processing/test_phi3v.py index 1f3646f79486..8faff2611e6f 100644 --- a/tests/models/multimodal/processing/test_phi3v.py +++ b/tests/models/multimodal/processing/test_phi3v.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for phi3v's multimodal preprocessing kwargs.""" + import pytest from vllm.multimodal import MULTIMODAL_REGISTRY @@ -10,7 +11,6 @@ @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"]) -# yapf: disable @pytest.mark.parametrize( ("mm_processor_kwargs", "expected_toks_per_img"), [ @@ -18,8 +18,8 @@ ({"num_crops": 16}, 1921), # the default num_crops of 
phi-3.5-vision is 4 ({}, 757), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( diff --git a/tests/models/multimodal/processing/test_phi4mm.py b/tests/models/multimodal/processing/test_phi4mm.py index f16d261c2c6a..5391555c2667 100644 --- a/tests/models/multimodal/processing/test_phi4mm.py +++ b/tests/models/multimodal/processing/test_phi4mm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for phi4mm's multimodal preprocessing kwargs.""" + import pytest from vllm.multimodal import MULTIMODAL_REGISTRY @@ -10,7 +11,6 @@ @pytest.mark.parametrize("model_id", ["microsoft/Phi-4-multimodal-instruct"]) -# yapf: disable @pytest.mark.parametrize( ("mm_processor_kwargs", "expected_toks_per_img"), [ @@ -18,8 +18,8 @@ ({"dynamic_hd": 16}, 4433), # the default num_crops of phi-4-multimodal is 36 ({}, 9585), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( @@ -46,8 +46,7 @@ def test_processor_override( img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" - image_size = ctx.get_hf_config( - ).embd_layer["image_embd_layer"]["crop_size"] + image_size = ctx.get_hf_config().embd_layer["image_embd_layer"]["crop_size"] dummy_image_size = (image_size * 7, image_size * 7) dummy_image = image_assets[0].pil_image.resize(dummy_image_size) mm_data = {"image": [dummy_image] * num_imgs} @@ -56,5 +55,6 @@ def test_processor_override( # Ensure we have the right number of placeholders per num_crops size img_tok_count = processed_inputs["prompt_token_ids"].count( - _IMAGE_PLACEHOLDER_TOKEN_ID) + _IMAGE_PLACEHOLDER_TOKEN_ID + ) assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index 985f4188fdb6..9f4cdb6789b2 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -10,13 +10,13 @@ @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) -# yapf: disable @pytest.mark.parametrize( - ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), [ + ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), + [ ({}, 1426, (5704, 1176)), ({"min_pixels": 64**2, "max_pixels": 512**2}, 330, (1320, 1176)), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( @@ -48,8 +48,7 @@ def test_processor_override( hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token) img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) - pixel_shape = processed_inputs["mm_kwargs"].get_data( - )["pixel_values"].shape + pixel_shape = processed_inputs["mm_kwargs"].get_data()["pixel_values"].shape assert img_tok_count == expected_toks_per_img * num_imgs assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs diff --git a/tests/models/multimodal/processing/test_smolvlm.py b/tests/models/multimodal/processing/test_smolvlm.py index af8f983388c6..6f77d5516d14 100644 --- a/tests/models/multimodal/processing/test_smolvlm.py +++ 
b/tests/models/multimodal/processing/test_smolvlm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for smolvlm's multimodal preprocessing kwargs.""" + import pytest from transformers import SmolVLMConfig @@ -11,14 +12,13 @@ @pytest.mark.parametrize("model_id", ["HuggingFaceTB/SmolVLM2-2.2B-Instruct"]) -# yapf: disable @pytest.mark.parametrize( ("mm_processor_kwargs", "expected_toks_per_img"), [ ({"max_image_size": {"longest_edge": 384}}, 1377), ({"max_image_size": {"longest_edge": 768}}, 405), - ]) -# yapf: enable + ], +) @pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("kwargs_on_init", [True, False]) def test_processor_override( @@ -42,8 +42,11 @@ def test_processor_override( hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs # Build the image str / prompt based on the number of images we pass - placeholders = "<image>" if num_imgs == 1 else "\n".join( - f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1)) + placeholders = ( + "<image>" + if num_imgs == 1 + else "\n".join(f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1)) + ) prompt = f"<|im_start|>User:{placeholders}\n<end_of_utterance>\nAssistant:" # noqa: E501 # Build mm_data @@ -57,8 +60,7 @@ def test_processor_override( # Ensure the placeholders format are correct hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"]) - assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[ - "input_ids"][0] + assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0] # Ensure we have the right number of placeholders per num_crops size image_token_id = ctx.get_hf_config().image_token_id diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 6061e4538c95..2c4d109c3687 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -9,23 +9,29 @@ import numpy as np import pytest import torch.nn as nn -from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, - UserMessage) +from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from PIL import Image from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config -from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions, - ImageDummyOptions, VideoDummyOptions) -from vllm.distributed import (cleanup_dist_env_and_memory, - init_distributed_environment, - initialize_model_parallel) +from vllm.config.multimodal import ( + AudioDummyOptions, + BaseDummyOptions, + ImageDummyOptions, + VideoDummyOptions, +) +from vllm.distributed import ( + cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - supports_multimodal) +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + supports_multimodal, +) from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs -from vllm.multimodal.processing import (BaseMultiModalProcessor, - InputProcessingContext) +from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from 
vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils import is_list_of @@ -48,13 +54,15 @@ } ImageInput = list[Image.Image] -VideoInput = Union[list[Image.Image], list[np.ndarray], - list[tuple[np.ndarray, dict[str, Any]]]] +VideoInput = Union[ + list[Image.Image], list[np.ndarray], list[tuple[np.ndarray, dict[str, Any]]] +] AudioInput = list[tuple[np.ndarray, int]] -def _resize_data(_data: Union[Image.Image, np.ndarray], - size_factor: float) -> Union[Image.Image, np.ndarray]: +def _resize_data( + _data: Union[Image.Image, np.ndarray], size_factor: float +) -> Union[Image.Image, np.ndarray]: assert size_factor <= 1, "Size factor must be less than 1" # Image input if isinstance(_data, Image.Image): @@ -74,20 +82,18 @@ def _resize_data(_data: Union[Image.Image, np.ndarray], return _data[..., :T, :H, :W, :C] # Audio input elif isinstance(_data, np.ndarray) and _data.ndim == 1: - return _data[:int(len(_data) * size_factor)] + return _data[: int(len(_data) * size_factor)] raise AssertionError("This line should be unreachable.") def resize_mm_data( - data: Union[ImageInput, VideoInput, AudioInput], - size_factors: tuple[float, - ...]) -> Union[ImageInput, VideoInput, AudioInput]: - size_factors = size_factors[:len(data)] + data: Union[ImageInput, VideoInput, AudioInput], size_factors: tuple[float, ...] +) -> Union[ImageInput, VideoInput, AudioInput]: + size_factors = size_factors[: len(data)] if is_list_of(data, (Image.Image, np.ndarray, list)): return [_resize_data(d, s) for d, s in zip(data, size_factors)] elif is_list_of(data, tuple): - return [(_resize_data(d, s), meta) - for (d, meta), s in zip(data, size_factors)] + return [(_resize_data(d, s), meta) for (d, meta), s in zip(data, size_factors)] raise ValueError("Unsupported multimodal data type.") @@ -116,12 +122,16 @@ def create_batched_mm_kwargs( # Mistral chat outputs tokens directly, rather than text prompts if model_config.tokenizer_mode == "mistral": images = resized_mm_data.get("image", []) - request = ChatCompletionRequest(messages=[ - UserMessage(content=[ - TextChunk(text=""), - *(ImageChunk(image=image) for image in images), - ]), - ]) + request = ChatCompletionRequest( + messages=[ + UserMessage( + content=[ + TextChunk(text=""), + *(ImageChunk(image=image) for image in images), + ] + ), + ] + ) tokenizer = processing_info.get_tokenizer() res = tokenizer.mistral.encode_chat_completion(request) prompt = res.tokens @@ -133,10 +143,7 @@ def create_batched_mm_kwargs( hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, tokenization_kwargs=processor_inputs.tokenization_kwargs, )["mm_kwargs"].require_data() - items = [ - item for modality in supported_mm_limits - for item in mm_kwargs[modality] - ] + items = [item for modality in supported_mm_limits for item in mm_kwargs[modality]] return group_mm_kwargs_by_modality( items, merge_by_field_config=model_cls.merge_by_field_config, @@ -167,15 +174,17 @@ def initialize_dummy_model( cleanup_dist_env_and_memory() -def get_model_id_to_test( - model_arch_list: Iterable[str]) -> list[tuple[str, str]]: +def get_model_id_to_test(model_arch_list: Iterable[str]) -> list[tuple[str, str]]: filtered_results = [] for model_arch in model_arch_list: model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) if model_info.extras and model_arch in ARCH_NEEDS_EXTRAS: available_repos = list( - map(lambda model_id: (model_arch, model_id), - [model_info.default, *model_info.extras.values()])) + map( + lambda 
model_id: (model_arch, model_id), + [model_info.default, *model_info.extras.values()], + ) + ) filtered_results.extend(available_repos) else: filtered_results.append((model_arch, model_info.default)) @@ -183,8 +192,8 @@ def get_model_id_to_test( @pytest.mark.parametrize( - "model_arch, model_id", - get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys())) + "model_arch, model_id", get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()) +) def test_model_tensor_schema(model_arch: str, model_id: str): if model_arch in ARCH_TO_SKIP: pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}") @@ -193,12 +202,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str): model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) model_info.check_available_online(on_fail="skip") - model_info.check_transformers_version(on_fail="skip", - check_max_version=False) + model_info.check_transformers_version(on_fail="skip", check_max_version=False) - hf_overrides_fn = partial(dummy_hf_overrides, - model_arch=model_arch, - exist_overrides=model_info.hf_overrides) + hf_overrides_fn = partial( + dummy_hf_overrides, + model_arch=model_arch, + exist_overrides=model_info.hf_overrides, + ) model_config = ModelConfig( model_id, @@ -256,8 +266,11 @@ def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions: with initialize_dummy_model(model_cls, model_config) as model: for modality, _, mm_kwargs in create_batched_mm_kwargs( - model_cls, model_config, processor): + model_cls, model_config, processor + ): for method_name in inputs_parse_methods: - print(f"Testing `{method_name}` with modality={modality} " - f"and mm_kwargs{list(mm_kwargs.keys())}") + print( + f"Testing `{method_name}` with modality={modality} " + f"and mm_kwargs{list(mm_kwargs.keys())}" + ) getattr(model, method_name)(modality=modality, **mm_kwargs) diff --git a/tests/models/multimodal/processing/test_transformers.py b/tests/models/multimodal/processing/test_transformers.py index c0e043ade736..e2a2186f470b 100644 --- a/tests/models/multimodal/processing/test_transformers.py +++ b/tests/models/multimodal/processing/test_transformers.py @@ -7,9 +7,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY -# yapf: disable -@pytest.mark.parametrize("model_id", - ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) def test_multimodal_processor(model_id): model_config = ModelConfig( model=model_id, @@ -18,9 +16,9 @@ def test_multimodal_processor(model_id): mm_processor = MULTIMODAL_REGISTRY.create_processor(model_config) - image_pil = ImageAsset('cherry_blossom').pil_image + image_pil = ImageAsset("cherry_blossom").pil_image mm_data = {"image": image_pil} - str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501 + str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501 str_processed_inputs = mm_processor.apply( prompt=str_prompt, mm_data=mm_data, @@ -28,8 +26,23 @@ def test_multimodal_processor(model_id): ) ids_prompt = [ - 151644, 872, 220, 151646, 198, 3838, 374, 279, 2213, 315, 419, 2168, - 30, 151645, 151644, 77091, 198 + 151644, + 872, + 220, + 151646, + 198, + 3838, + 374, + 279, + 2213, + 315, + 419, + 2168, + 30, + 151645, + 151644, + 77091, + 198, ] ids_processed_inputs = mm_processor.apply( prompt=ids_prompt, @@ -37,5 +50,7 @@ def test_multimodal_processor(model_id): hf_processor_mm_kwargs={}, ) - assert 
(str_processed_inputs["prompt_token_ids"] - == ids_processed_inputs["prompt_token_ids"]) + assert ( + str_processed_inputs["prompt_token_ids"] + == ids_processed_inputs["prompt_token_ids"] + ) diff --git a/tests/models/multimodal/test_mapping.py b/tests/models/multimodal/test_mapping.py index caf1966ab513..2179cf33a573 100644 --- a/tests/models/multimodal/test_mapping.py +++ b/tests/models/multimodal/test_mapping.py @@ -19,7 +19,7 @@ def create_repo_dummy_weights(repo: str) -> Iterable[tuple[str, torch.Tensor]]: """Create weights from safetensors checkpoint metadata""" metadata = try_get_safetensors_metadata(repo) weight_names = list(metadata.weight_map.keys()) - with torch.device('meta'): + with torch.device("meta"): return ((name, torch.empty(0)) for name in weight_names) @@ -61,7 +61,8 @@ def test_hf_model_weights_mapper(model_arch: str): hf_overrides=model_info.hf_overrides, skip_tokenizer_init=model_info.skip_tokenizer_init, enforce_eager=model_info.enforce_eager, - dtype=model_info.dtype) + dtype=model_info.dtype, + ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) original_weights = create_repo_dummy_weights(model_id) @@ -83,6 +84,7 @@ def test_hf_model_weights_mapper(model_arch: str): weights_missing = ref_weight_names - weight_names weights_unmapped = weight_names - ref_weight_names - assert (not weights_missing and not weights_unmapped), ( + assert not weights_missing and not weights_unmapped, ( f"Following weights are not mapped correctly: {weights_unmapped}, " - f"Missing expected weights: {weights_missing}.") + f"Missing expected weights: {weights_missing}." + ) diff --git a/tests/models/quantization/test_awq.py b/tests/models/quantization/test_awq.py index e741e4ad90a0..c4c10832ede3 100644 --- a/tests/models/quantization/test_awq.py +++ b/tests/models/quantization/test_awq.py @@ -11,12 +11,12 @@ from ...conftest import IMAGE_ASSETS, ImageTestAssets, VllmRunner from ..utils import check_logprobs_close -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ - "stop_sign": - "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 - "cherry_blossom": - "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 -}) +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( + { + "stop_sign": "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + "cherry_blossom": "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + } +) def run_awq_test( @@ -34,10 +34,13 @@ def run_awq_test( ): images = [asset.pil_image for asset in image_assets] - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + inputs_per_image = [ + ( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) + for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. 
@@ -46,42 +49,41 @@ def run_awq_test( # max_model_len should be greater than image_feature_size with vllm_runner( - source_model, - max_model_len=4096, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True, - default_torch_num_threads=1, + source_model, + max_model_len=4096, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + default_torch_num_threads=1, ) as vllm_model: source_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs=num_logprobs, images=images + ) for prompts, images in inputs_per_image ] with vllm_runner( - quant_model, - quantization="awq", - max_model_len=4096, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True, - default_torch_num_threads=1, + quant_model, + quantization="awq", + max_model_len=4096, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + default_torch_num_threads=1, ) as vllm_model: quant_outputs_per_image = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs=num_logprobs, images=images + ) for prompts, images in inputs_per_image ] - for source_outputs, quant_outputs in zip(source_outputs_per_image, - quant_outputs_per_image): + for source_outputs, quant_outputs in zip( + source_outputs_per_image, quant_outputs_per_image + ): # TODO: Check whether using original CLIPVisionModel can improve # consistency against HF check_logprobs_close( @@ -113,9 +115,16 @@ def run_awq_test( @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @torch.inference_mode() -def test_awq_models(vllm_runner, image_assets, source_model, quant_model, - size_factors, dtype, max_tokens, num_logprobs) -> None: - +def test_awq_models( + vllm_runner, + image_assets, + source_model, + quant_model, + size_factors, + dtype, + max_tokens, + num_logprobs, +) -> None: run_awq_test( vllm_runner, image_assets, diff --git a/tests/models/quantization/test_bitblas.py b/tests/models/quantization/test_bitblas.py index 754ac9a29a13..f516cc2724a6 100644 --- a/tests/models/quantization/test_bitblas.py +++ b/tests/models/quantization/test_bitblas.py @@ -7,9 +7,10 @@ bitblas/GPTQ models are in the top 3 selections of each other. Note: bitblas internally uses locks to synchronize the threads. This can -result in very slight nondeterminism for bitblas. As a result, we re-run the +result in very slight nondeterminism for bitblas. As a result, we re-run the test up to 3 times to see if we pass. 
""" + from dataclasses import dataclass import pytest @@ -24,8 +25,10 @@ class ModelPair: model_pairs = [ - ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas", - model_gptq="hxbgsyxh/opt-125m-4bit-128g"), + ModelPair( + model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas", + model_gptq="hxbgsyxh/opt-125m-4bit-128g", + ), ] @@ -43,16 +46,19 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model_pair.model_bitblas, - dtype=dtype, - quantization="bitblas") as bitblas_model: + with vllm_runner( + model_pair.model_bitblas, dtype=dtype, quantization="bitblas" + ) as bitblas_model: bitblas_outputs = bitblas_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model_pair.model_gptq, dtype=dtype, - quantization="gptq") as gptq_model: + with vllm_runner( + model_pair.model_gptq, dtype=dtype, quantization="gptq" + ) as gptq_model: gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/quantization/test_bitsandbytes.py b/tests/models/quantization/test_bitsandbytes.py index 25fc44fee90d..5e0421af1c17 100644 --- a/tests/models/quantization/test_bitsandbytes.py +++ b/tests/models/quantization/test_bitsandbytes.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -'''Tests whether bitsandbytes computation is enabled correctly. +"""Tests whether bitsandbytes computation is enabled correctly. Run `pytest tests/quantization/test_bitsandbytes.py`. -''' +""" import pytest from transformers import BitsAndBytesConfig @@ -15,8 +15,10 @@ models_4bit_to_test = [ ("facebook/opt-125m", "quantize opt model inflight"), - ("mistralai/Mistral-7B-Instruct-v0.3", - "quantize inflight model with both HF and Mistral format weights") + ( + "mistralai/Mistral-7B-Instruct-v0.3", + "quantize inflight model with both HF and Mistral format weights", + ), ] models_4bit_to_embedding_test = [ @@ -28,72 +30,84 @@ ] models_pre_qaunt_4bit_to_test = [ - ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', - 'read pre-quantized 4-bit FP4 model'), - ('poedator/opt-125m-bnb-4bit', 'read pre-quantized 4-bit NF4 opt model'), + ( + "PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed", + "read pre-quantized 4-bit FP4 model", + ), + ("poedator/opt-125m-bnb-4bit", "read pre-quantized 4-bit NF4 opt model"), ] models_pre_quant_8bit_to_test = [ - ('meta-llama/Llama-Guard-3-8B-INT8', - 'read pre-quantized llama 8-bit model'), + ("meta-llama/Llama-Guard-3-8B-INT8", "read pre-quantized llama 8-bit model"), ("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"), ] -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name, description", models_4bit_to_test) -def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: - - hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( - load_in_4bit=True)) - validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], - model_name, False, hf_model_kwargs) - - -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - 
reason='bitsandbytes is not supported on this GPU type.') -@pytest.mark.parametrize("model_name, description", - models_pre_qaunt_4bit_to_test) -def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: +def test_load_4bit_bnb_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_4bit=True)) + validate_generated_texts( + hf_runner, vllm_runner, example_prompts[:1], model_name, False, hf_model_kwargs + ) - validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], - model_name, True) +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) +@pytest.mark.parametrize("model_name, description", models_pre_qaunt_4bit_to_test) +def test_load_pre_quant_4bit_bnb_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + validate_generated_texts( + hf_runner, vllm_runner, example_prompts[:1], model_name, True + ) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') -@pytest.mark.parametrize("model_name, description", - models_pre_quant_8bit_to_test) -def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: - validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], - model_name, True) +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) +@pytest.mark.parametrize("model_name, description", models_pre_quant_8bit_to_test) +def test_load_8bit_bnb_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + validate_generated_texts( + hf_runner, vllm_runner, example_prompts[:1], model_name, True + ) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name, description", models_4bit_to_test) @multi_gpu_test(num_gpus=2) -def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: - - hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( - load_in_4bit=True)) - validate_generated_texts(hf_runner, - vllm_runner, - example_prompts[:1], - model_name, - False, - hf_model_kwargs, - vllm_tp_size=2) - - -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') +def test_load_tp_4bit_bnb_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_4bit=True)) + validate_generated_texts( + hf_runner, + vllm_runner, + example_prompts[:1], + model_name, + False, + hf_model_kwargs, + vllm_tp_size=2, + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name, description", models_4bit_to_test) @multi_gpu_test(num_gpus=2) def test_load_pp_4bit_bnb_model(model_name, description) -> None: @@ -115,30 +129,37 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None: compare_two_settings(model_name, common_args, pp_args) 
-@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test) -def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts, - model_name, description) -> None: - - hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=True, - )) - with vllm_runner(model_name, - quantization='bitsandbytes', - enforce_eager=False, - default_torch_num_threads=1) as llm: - vllm_outputs = llm.generate_greedy_logprobs(example_prompts, - max_tokens=32, - num_logprobs=5) - - with hf_runner(model_name, - model_kwargs=hf_model_kwargs, - default_torch_num_threads=1) as llm: +def test_4bit_bnb_moe_model( + hf_runner, vllm_runner, example_prompts, model_name, description +) -> None: + hf_model_kwargs = dict( + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + ) + ) + with vllm_runner( + model_name, + quantization="bitsandbytes", + enforce_eager=False, + default_torch_num_threads=1, + ) as llm: + vllm_outputs = llm.generate_greedy_logprobs( + example_prompts, max_tokens=32, num_logprobs=5 + ) + + with hf_runner( + model_name, model_kwargs=hf_model_kwargs, default_torch_num_threads=1 + ) as llm: transformers_outputs = llm.generate_greedy_logprobs_limit( - example_prompts, max_tokens=32, num_logprobs=5) + example_prompts, max_tokens=32, num_logprobs=5 + ) check_logprobs_close( outputs_0_lst=transformers_outputs, outputs_1_lst=vllm_outputs, @@ -147,10 +168,11 @@ def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts, ) -@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), - reason='bitsandbytes is not supported on this GPU type.') -@pytest.mark.parametrize("model_name, description", - models_4bit_to_embedding_test) +@pytest.mark.skipif( + not is_quant_method_supported("bitsandbytes"), + reason="bitsandbytes is not supported on this GPU type.", +) +@pytest.mark.parametrize("model_name, description", models_4bit_to_embedding_test) @pytest.mark.parametrize("dtype", ["half"]) def test_4bit_bnb_embedding_model( model_name, @@ -160,7 +182,6 @@ def test_4bit_bnb_embedding_model( example_prompts, dtype: str, ) -> None: - # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" # sentence_transformers will strip the input texts, see: @@ -170,22 +191,23 @@ def test_4bit_bnb_embedding_model( example_prompts = [str(s).strip() for s in example_prompts] # Inflight 4bit quantization - with vllm_runner(model_name, - runner="pooling", - dtype=dtype, - gpu_memory_utilization=0.5, - quantization="bitsandbytes", - default_torch_num_threads=1) as vllm_model: + with vllm_runner( + model_name, + runner="pooling", + dtype=dtype, + gpu_memory_utilization=0.5, + quantization="bitsandbytes", + default_torch_num_threads=1, + ) as vllm_model: vllm_outputs = vllm_model.embed(example_prompts) - hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( - load_in_4bit=True)) + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(load_in_4bit=True)) with hf_runner( - model_name, - dtype=dtype, - model_kwargs=hf_model_kwargs, - is_sentence_transformer=True, - default_torch_num_threads=1, + model_name, + 
dtype=dtype, + model_kwargs=hf_model_kwargs, + is_sentence_transformer=True, + default_torch_num_threads=1, ) as hf_model: hf_outputs = hf_model.encode(example_prompts) @@ -210,23 +232,25 @@ def log_generated_texts(prompts, outputs, runner_name): return logged_texts -def validate_generated_texts(hf_runner, - vllm_runner, - prompts, - model_name, - pre_quant=False, - hf_model_kwargs=None, - vllm_tp_size=1, - max_tokens=8): - +def validate_generated_texts( + hf_runner, + vllm_runner, + prompts, + model_name, + pre_quant=False, + hf_model_kwargs=None, + vllm_tp_size=1, + max_tokens=8, +): # NOTE: run vLLM first, as it requires a clean process # when using distributed inference - with vllm_runner(model_name, - quantization=None if pre_quant else 'bitsandbytes', - tensor_parallel_size=vllm_tp_size, - enforce_eager=False, - default_torch_num_threads=1) as llm: - + with vllm_runner( + model_name, + quantization=None if pre_quant else "bitsandbytes", + tensor_parallel_size=vllm_tp_size, + enforce_eager=False, + default_torch_num_threads=1, + ) as llm: vllm_outputs = llm.generate_greedy(prompts, max_tokens) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") @@ -234,9 +258,9 @@ def validate_generated_texts(hf_runner, hf_model_kwargs = {} # Run with HF runner - with hf_runner(model_name, - model_kwargs=hf_model_kwargs, - default_torch_num_threads=1) as llm: + with hf_runner( + model_name, model_kwargs=hf_model_kwargs, default_torch_num_threads=1 + ) as llm: hf_outputs = llm.generate_greedy(prompts, max_tokens) hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") @@ -245,8 +269,10 @@ def validate_generated_texts(hf_runner, hf_str = hf_log["generated_text"] vllm_str = vllm_log["generated_text"] prompt = hf_log["prompt"] - assert hf_str == vllm_str, (f"Model: {model_name}" - f"Mismatch between HF and vLLM outputs:\n" - f"Prompt: {prompt}\n" - f"HF Output: '{hf_str}'\n" - f"vLLM Output: '{vllm_str}'") + assert hf_str == vllm_str, ( + f"Model: {model_name}" + f"Mismatch between HF and vLLM outputs:\n" + f"Prompt: {prompt}\n" + f"HF Output: '{hf_str}'\n" + f"vLLM Output: '{vllm_str}'" + ) diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py index bb8ae741b614..55b149ae5da7 100644 --- a/tests/models/quantization/test_fp8.py +++ b/tests/models/quantization/test_fp8.py @@ -5,6 +5,7 @@ """Tests fp8 models against ground truth generation Note: these tests will only pass on L4 GPU. """ + import pytest from tests.quantization.utils import is_quant_method_supported @@ -14,21 +15,33 @@ from ..utils import check_logprobs_close -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="fp8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.", +) @pytest.mark.parametrize( "kv_cache_dtype,base_model,test_model", [ # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors. - ("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct", - "nm-testing/Llama-3.2-1B-Instruct-FP8-KV"), + ( + "fp8_e4m3", + "meta-llama/Llama-3.2-1B-Instruct", + "nm-testing/Llama-3.2-1B-Instruct-FP8-KV", + ), # Test BF16 checkpoint w. fp8_e5m2 kv-cache. - ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct"), + ( + "fp8_e5m2", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + ), # Test BF16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json. 
- ("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct") - ]) + ( + "fp8_e4m3", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + ), + ], +) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("enforce_eager", [True]) @@ -54,38 +67,39 @@ def test_models( """ if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm(): - pytest.skip( - f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") + pytest.skip(f"{kv_cache_dtype} is currently not supported on ROCm/HIP.") if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype, None): pytest.skip(f"{kv_cache_dtype} is not supported on this platform.") with monkeypatch.context() as m: - m.setenv("TOKENIZERS_PARALLELISM", 'true') + m.setenv("TOKENIZERS_PARALLELISM", "true") m.setenv(STR_BACKEND_ENV_VAR, backend) MAX_MODEL_LEN = 1024 NUM_LOG_PROBS = 8 with vllm_runner( - base_model, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - kv_cache_dtype="auto", + base_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype="auto", ) as vllm_model: baseline_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) with vllm_runner( - test_model, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) check_logprobs_close( outputs_0_lst=baseline_outputs, @@ -96,15 +110,18 @@ def test_models( @pytest.mark.cpu_model -@pytest.mark.skipif(not current_platform.is_cpu(), - reason="test for the CPU backend.") +@pytest.mark.skipif(not current_platform.is_cpu(), reason="test for the CPU backend.") @pytest.mark.parametrize( "kv_cache_dtype,base_model,test_model", [ # Test BF16 checkpoint w. fp8_e5m2 kv-cache. - ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct"), - ]) + ( + "fp8_e5m2", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + ), + ], +) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) def test_cpu_models( @@ -121,28 +138,30 @@ def test_cpu_models( numerical sensitive kernels. 
""" with monkeypatch.context() as m: - m.setenv("TOKENIZERS_PARALLELISM", 'true') + m.setenv("TOKENIZERS_PARALLELISM", "true") MAX_MODEL_LEN = 1024 NUM_LOG_PROBS = 8 with vllm_runner( - base_model, - max_model_len=MAX_MODEL_LEN, - dtype="bfloat16", - kv_cache_dtype="auto", + base_model, + max_model_len=MAX_MODEL_LEN, + dtype="bfloat16", + kv_cache_dtype="auto", ) as vllm_model: baseline_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) with vllm_runner( - test_model, - max_model_len=MAX_MODEL_LEN, - dtype="bfloat16", - kv_cache_dtype=kv_cache_dtype, + test_model, + max_model_len=MAX_MODEL_LEN, + dtype="bfloat16", + kv_cache_dtype=kv_cache_dtype, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, NUM_LOG_PROBS) + example_prompts, max_tokens, NUM_LOG_PROBS + ) check_logprobs_close( outputs_0_lst=baseline_outputs, diff --git a/tests/models/quantization/test_gguf.py b/tests/models/quantization/test_gguf.py index 3e77d3e71039..5e2438857aee 100644 --- a/tests/models/quantization/test_gguf.py +++ b/tests/models/quantization/test_gguf.py @@ -100,35 +100,37 @@ def check_model_outputs( ): tokenizer = AutoTokenizer.from_pretrained(model.original_model) if tokenizer.chat_template is not None: - messages = [[{ - 'role': 'user', - 'content': prompt - }] for prompt in prompts] - prompts = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + messages = [[{"role": "user", "content": prompt}] for prompt in prompts] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) # Run gguf model. - with vllm_runner(model_name=model.gguf_model, - enforce_eager=True, - tokenizer_name=model.original_model, - dtype=dtype, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=tp_size) as gguf_model: + with vllm_runner( + model_name=model.gguf_model, + enforce_eager=True, + tokenizer_name=model.original_model, + dtype=dtype, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tp_size, + ) as gguf_model: gguf_outputs = gguf_model.generate_greedy_logprobs( - prompts[:-1], max_tokens, num_logprobs) + prompts[:-1], max_tokens, num_logprobs + ) # Run unquantized model. # Should run with tp=1, otherwise the test will stuck at # nccl initialization. 
with vllm_runner( - model_name=model.original_model, - enforce_eager=True, # faster tests - dtype=dtype, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1) as original_model: + model_name=model.original_model, + enforce_eager=True, # faster tests + dtype=dtype, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + ) as original_model: original_outputs = original_model.generate_greedy_logprobs( - prompts[:-1], max_tokens, num_logprobs) + prompts[:-1], max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=original_outputs, @@ -138,12 +140,14 @@ def check_model_outputs( ) -@pytest.mark.skipif(not is_quant_method_supported("gguf"), - reason="gguf is not supported on this GPU type.") -@pytest.mark.parametrize("model", [ - pytest.param(test_config, marks=test_config.marks) - for test_config in MODELS -]) +@pytest.mark.skipif( + not is_quant_method_supported("gguf"), + reason="gguf is not supported on this GPU type.", +) +@pytest.mark.parametrize( + "model", + [pytest.param(test_config, marks=test_config.marks) for test_config in MODELS], +) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @@ -157,12 +161,15 @@ def test_models( num_logprobs: int, tp_size: int, ) -> None: - check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens, - num_logprobs, tp_size) + check_model_outputs( + vllm_runner, example_prompts, model, dtype, max_tokens, num_logprobs, tp_size + ) -@pytest.mark.skipif(not is_quant_method_supported("gguf"), - reason="gguf is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gguf"), + reason="gguf is not supported on this GPU type.", +) @pytest.mark.parametrize("model", [LLAMA_CONFIG]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [8]) @@ -178,5 +185,6 @@ def test_distributed( num_logprobs: int, tp_size: int, ) -> None: - check_model_outputs(vllm_runner, example_prompts, model, dtype, max_tokens, - num_logprobs, tp_size) + check_model_outputs( + vllm_runner, example_prompts, model, dtype, max_tokens, num_logprobs, tp_size + ) diff --git a/tests/models/quantization/test_gptq_bitblas.py b/tests/models/quantization/test_gptq_bitblas.py index c3aed77525de..b29c5e769ce8 100644 --- a/tests/models/quantization/test_gptq_bitblas.py +++ b/tests/models/quantization/test_gptq_bitblas.py @@ -7,9 +7,10 @@ bitblas/GPTQ models are in the top 3 selections of each other. Note: bitblas internally uses locks to synchronize the threads. This can -result in very slight nondeterminism for bitblas. As a result, we re-run the +result in very slight nondeterminism for bitblas. As a result, we re-run the test up to 3 times to see if we pass. 
""" + from dataclasses import dataclass import pytest @@ -41,16 +42,19 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model_pair.model_gptq, - dtype=dtype, - quantization="bitblas") as bitblas_model: + with vllm_runner( + model_pair.model_gptq, dtype=dtype, quantization="bitblas" + ) as bitblas_model: bitblas_outputs = bitblas_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model_pair.model_gptq, dtype=dtype, - quantization="gptq") as gptq_model: + with vllm_runner( + model_pair.model_gptq, dtype=dtype, quantization="gptq" + ) as gptq_model: gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/quantization/test_gptq_marlin.py b/tests/models/quantization/test_gptq_marlin.py index db70a3bd2c04..cf52ae39214d 100644 --- a/tests/models/quantization/test_gptq_marlin.py +++ b/tests/models/quantization/test_gptq_marlin.py @@ -9,6 +9,7 @@ result in very slight nondeterminism for Marlin. As a result, we re-run the test up to 3 times to see if we pass. """ + import os import pytest @@ -26,20 +27,20 @@ MODELS = [ # act_order==True, group_size=128 ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"), - # 8-bit, act_order==True, group_size=channelwise ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"), - # 4-bit, act_order==True, group_size=128 - ("TechxGenus/gemma-1.1-2b-it-GPTQ", "main") + ("TechxGenus/gemma-1.1-2b-it-GPTQ", "main"), ] @pytest.mark.flaky(reruns=3) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin") - or current_platform.is_rocm() - or not current_platform.is_cuda(), - reason="gptq_marlin is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin") + or current_platform.is_rocm() + or not current_platform.is_cuda(), + reason="gptq_marlin is not supported on this GPU type.", +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [32]) @@ -55,29 +56,34 @@ def test_models( model_name, revision = model # Run marlin. - with vllm_runner(model_name=model_name, - revision=revision, - dtype=dtype, - quantization="marlin", - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1) as gptq_marlin_model: - + with vllm_runner( + model_name=model_name, + revision=revision, + dtype=dtype, + quantization="marlin", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + ) as gptq_marlin_model: gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( - example_prompts[:-1], max_tokens, num_logprobs) + example_prompts[:-1], max_tokens, num_logprobs + ) _ROPE_DICT.clear() # clear rope cache to avoid rope dtype error # Run gptq. # The naive gptq kernel doesn't support bf16 yet. # Here we always compare fp16/bf16 gpt marlin kernel # to fp16 gptq kernel. 
- with vllm_runner(model_name=model_name, - revision=revision, - dtype="half", - quantization="gptq", - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1) as gptq_model: + with vllm_runner( + model_name=model_name, + revision=revision, + dtype="half", + quantization="gptq", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + ) as gptq_model: gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts[:-1], max_tokens, num_logprobs) + example_prompts[:-1], max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/quantization/test_gptq_marlin_24.py b/tests/models/quantization/test_gptq_marlin_24.py index 9b86ae95ba5c..85426ee5b089 100644 --- a/tests/models/quantization/test_gptq_marlin_24.py +++ b/tests/models/quantization/test_gptq_marlin_24.py @@ -6,6 +6,7 @@ As a result, in this test, we just confirm that the top selected tokens of the Marlin/GPTQ models are in the top 3 selections of each other. """ + from dataclasses import dataclass import pytest @@ -24,15 +25,18 @@ class ModelPair: model_pairs = [ # 4-bit, group_size == 128 - ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128", - model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128"), + ModelPair( + model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128", + ), # # 4-bit, group_size == channelwise # ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-channelwise", # model_gptq="alexm-nm/tinyllama-24-gptq-4bit-channelwise"), - # 8-bit, group_size == 128 - ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128", - model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128"), + ModelPair( + model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128", + ), # # 8-bit, group_size == channelwise # ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-channelwise", # model_gptq="alexm-nm/tinyllama-24-gptq-8bit-channelwise"), @@ -40,10 +44,12 @@ class ModelPair: @pytest.mark.flaky(reruns=2) -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24") - or current_platform.is_rocm() - or not current_platform.is_cuda(), - reason="Marlin24 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin_24") + or current_platform.is_rocm() + or not current_platform.is_cuda(), + reason="Marlin24 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [8]) @@ -56,16 +62,19 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - with vllm_runner(model_pair.model_marlin, - dtype=dtype, - quantization="gptq_marlin_24") as marlin_24_model: + with vllm_runner( + model_pair.model_marlin, dtype=dtype, quantization="gptq_marlin_24" + ) as marlin_24_model: marlin_24_outputs = marlin_24_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) - with vllm_runner(model_pair.model_gptq, dtype=dtype, - quantization="gptq") as gptq_model: + with vllm_runner( + model_pair.model_gptq, dtype=dtype, quantization="gptq" + ) as gptq_model: gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/quantization/test_modelopt.py b/tests/models/quantization/test_modelopt.py index 
e23d4d9d211d..db3af972bb77 100644 --- a/tests/models/quantization/test_modelopt.py +++ b/tests/models/quantization/test_modelopt.py @@ -5,6 +5,7 @@ """Tests Model Optimizer fp8 models against ground truth generation Note: these tests will only pass on H100 """ + import os import pytest @@ -22,13 +23,13 @@ EXPECTED_STRS_MAP = { "nvidia/Llama-3.1-8B-Instruct-FP8": [ "You're referring to VLLM, a high-performance Large Language Model (LLM) inference and", - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'The comparison between artificial intelligence (AI) and human intelligence in terms of processing information is a complex and', + "Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ", + "The comparison between artificial intelligence (AI) and human intelligence in terms of processing information is a complex and", 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - '**The Spark of Imagination**\n\nZeta-5, a sleek and efficient robot, whir', - 'The COVID-19 pandemic has had a profound impact on global economic structures and business models, leading to', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** 「早起きは早く獲物をとる' + "**The Spark of Imagination**\n\nZeta-5, a sleek and efficient robot, whir", + "The COVID-19 pandemic has had a profound impact on global economic structures and business models, leading to", + "The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of", + "Here are the translations:\n\n**Japanese:** 「早起きは早く獲物をとる", ] } @@ -39,10 +40,12 @@ # the hardware being run on. # Disabled to prevent it from breaking the build @pytest.mark.skip( - reason= - "Prevent unstable test based on golden strings from breaking the build.") -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="fp8 is not supported on this GPU type.") + reason="Prevent unstable test based on golden strings from breaking the build." 
+) +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name", MODELS) def test_models(example_prompts, model_name) -> None: llm = LLM( @@ -55,12 +58,11 @@ def test_models(example_prompts, model_name) -> None: tokenizer = AutoTokenizer.from_pretrained(model_name) formatted_prompts = [ - tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - tokenize=False, - add_generation_prompt=True) + tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + ) for prompt in example_prompts ] params = SamplingParams(max_tokens=20, temperature=0) @@ -78,4 +80,5 @@ def test_models(example_prompts, model_name) -> None: generated_str = generations[i] expected_str = expected_strs[i] assert expected_str == generated_str, ( - f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}" + ) diff --git a/tests/models/quantization/test_mxfp4.py b/tests/models/quantization/test_mxfp4.py index 7b8a334bbc36..d598e405be81 100644 --- a/tests/models/quantization/test_mxfp4.py +++ b/tests/models/quantization/test_mxfp4.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # flake8: noqa -"""Tests Quark mxfp4 models against ground truth generation -""" +"""Tests Quark mxfp4 models against ground truth generation""" + import pytest from vllm import LLM, SamplingParams @@ -11,13 +11,13 @@ EXPECTED_STRS_MAP = { "amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [ - '\n### Key Features\n\n* **High-throughput Inference**: vLL', - '\nArtificial intelligence (AI) has evolved significantly since its inception in the 1', - 'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been', - 'A neural network is a machine learning model inspired by the structure of the human brain. It consists of', - '\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol', - '\nThe COVID-19 pandemic has had a profound impact on global economic structures and business', - 'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th', + "\n### Key Features\n\n* **High-throughput Inference**: vLL", + "\nArtificial intelligence (AI) has evolved significantly since its inception in the 1", + "Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been", + "A neural network is a machine learning model inspired by the structure of the human brain. 
It consists of", + "\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol", + "\nThe COVID-19 pandemic has had a profound impact on global economic structures and business", + "The Mona Lisa painting, created by Leonardo da Vinci in the early 16th", " everybody knows this proverbial saying, but did you know that it's not entirely accurate?", ] } @@ -38,4 +38,5 @@ def test_models(example_prompts, model_name) -> None: output_str = output.outputs[0].text expected_str = EXPECTED_STRS_MAP[model_name][i] assert expected_str == output_str, ( - f"Expected: {expected_str!r}\nvLLM: {output_str!r}") + f"Expected: {expected_str!r}\nvLLM: {output_str!r}" + ) diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index b3c217e729e4..9f45f142d68b 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -4,6 +4,7 @@ """Tests Model Optimizer nvfp4 models against ground truth generation Note: these tests will only pass on B200 """ + import os from typing import List @@ -21,14 +22,14 @@ EXPECTED_STRS_MAP = { "nvidia/Llama-3.3-70B-Instruct-FP4": [ - 'vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process', - 'A neural network is a type of machine learning model inspired by the structure and function of the human brain', - 'In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n* Japanese: (Sasuga no tori ga miwa o ts' + "vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference", + "Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ", + "Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process", + "A neural network is a type of machine learning model inspired by the structure and function of the human brain", + "In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push", + "The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading", + "The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of", + "Here are the translations:\n\n* Japanese: (Sasuga no tori ga miwa o ts", ] } @@ -39,11 +40,13 @@ # the hardware being run on. # Disabled to prevent it from breaking the build @pytest.mark.skip( - reason= - "Prevent unstable test based on golden strings from breaking the build " - " and test input model being too large and hanging the system.") -@pytest.mark.skipif(not is_quant_method_supported("modelopt_fp4"), - reason="modelopt_fp4 is not supported on this GPU type.") + reason="Prevent unstable test based on golden strings from breaking the build " + " and test input model being too large and hanging the system." 
+) +@pytest.mark.skipif( + not is_quant_method_supported("modelopt_fp4"), + reason="modelopt_fp4 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_name", MODELS) def test_models(example_prompts, model_name) -> None: llm = LLM( @@ -56,12 +59,11 @@ def test_models(example_prompts, model_name) -> None: tokenizer = AutoTokenizer.from_pretrained(model_name) formatted_prompts = [ - tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - tokenize=False, - add_generation_prompt=True) + tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + ) for prompt in example_prompts ] params = SamplingParams(max_tokens=20, temperature=0) @@ -79,4 +81,5 @@ def test_models(example_prompts, model_name) -> None: generated_str = generations[i] expected_str = expected_strs[i] assert expected_str == generated_str, ( - f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}" + ) diff --git a/tests/models/registry.py b/tests/models/registry.py index 86a835975227..e1d9f1d1dd74 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -114,8 +114,10 @@ def check_transformers_version( If the installed transformers version does not meet the requirements, perform the given action. """ - if (self.min_transformers_version is None - and self.max_transformers_version is None): + if ( + self.min_transformers_version is None + and self.max_transformers_version is None + ): return None current_version = TRANSFORMERS_VERSION @@ -125,11 +127,17 @@ def check_transformers_version( msg = f"`transformers=={current_version}` installed, but `transformers" # Only check the base version for the min/max version, otherwise preview # models cannot be run because `x.yy.0.dev0`<`x.yy.0` - if (check_min_version and min_version - and Version(cur_base_version) < Version(min_version)): + if ( + check_min_version + and min_version + and Version(cur_base_version) < Version(min_version) + ): msg += f">={min_version}` is required to run this model." - elif (check_max_version and max_version - and Version(cur_base_version) > Version(max_version)): + elif ( + check_max_version + and max_version + and Version(cur_base_version) > Version(max_version) + ): msg += f"<={max_version}` is required to run this model." 
else: return None @@ -161,429 +169,625 @@ def check_available_online( pytest.skip(msg) -# yapf: disable _TEXT_GENERATION_EXAMPLE_MODELS = { # [Decoder-only] - "ApertusForCausalLM": _HfExamplesInfo("swiss-ai/Apertus-8B-2509", - min_transformers_version="4.56.0", - trust_remote_code=True), - "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", - trust_remote_code=True), - "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", - trust_remote_code=True), + "ApertusForCausalLM": _HfExamplesInfo( + "swiss-ai/Apertus-8B-2509", + min_transformers_version="4.56.0", + trust_remote_code=True, + ), + "AquilaModel": _HfExamplesInfo("BAAI/AquilaChat-7B", trust_remote_code=True), + "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True), "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base"), - "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", - trust_remote_code=True), - "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", - trust_remote_code=True), - "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", - trust_remote_code=True), - "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", - trust_remote_code=True), - "BailingMoeV2ForCausalLM": _HfExamplesInfo("inclusionAI/Ling-mini-2.0", - trust_remote_code=True), - "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B-v1", - min_transformers_version="4.55.3", - extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 - "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", - {"1b": "bigscience/bloomz-1b1"}), - "ChatGLMModel": _HfExamplesInfo("zai-org/chatglm3-6b", - trust_remote_code=True, - max_transformers_version="4.48"), - "ChatGLMForConditionalGeneration": _HfExamplesInfo("thu-coai/ShieldLM-6B-chatglm3", # noqa: E501 - trust_remote_code=True), - "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01", - trust_remote_code=True), - "Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501 - trust_remote_code=True), - "CwmForCausalLM": _HfExamplesInfo("facebook/cwm", # noqa: E501 - trust_remote_code=True, - is_available_online=False), + "ArcticForCausalLM": _HfExamplesInfo( + "Snowflake/snowflake-arctic-instruct", trust_remote_code=True + ), + "BaiChuanForCausalLM": _HfExamplesInfo( + "baichuan-inc/Baichuan-7B", trust_remote_code=True + ), + "BaichuanForCausalLM": _HfExamplesInfo( + "baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True + ), + "BailingMoeForCausalLM": _HfExamplesInfo( + "inclusionAI/Ling-lite-1.5", trust_remote_code=True + ), + "BailingMoeV2ForCausalLM": _HfExamplesInfo( + "inclusionAI/Ling-mini-2.0", trust_remote_code=True + ), + "BambaForCausalLM": _HfExamplesInfo( + "ibm-ai-platform/Bamba-9B-v1", + min_transformers_version="4.55.3", + extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}, + ), + "BloomForCausalLM": _HfExamplesInfo( + "bigscience/bloom-560m", {"1b": "bigscience/bloomz-1b1"} + ), + "ChatGLMModel": _HfExamplesInfo( + "zai-org/chatglm3-6b", trust_remote_code=True, max_transformers_version="4.48" + ), + "ChatGLMForConditionalGeneration": _HfExamplesInfo( + "thu-coai/ShieldLM-6B-chatglm3", + trust_remote_code=True, + ), + "CohereForCausalLM": _HfExamplesInfo( + "CohereForAI/c4ai-command-r-v01", trust_remote_code=True + ), + "Cohere2ForCausalLM": _HfExamplesInfo( + "CohereForAI/c4ai-command-r7b-12-2024", + trust_remote_code=True, + ), + "CwmForCausalLM": _HfExamplesInfo( + "facebook/cwm", + trust_remote_code=True, + 
is_available_online=False, + ), "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"), - "DeciLMForCausalLM": _HfExamplesInfo("nvidia/Llama-3_3-Nemotron-Super-49B-v1", # noqa: E501 - trust_remote_code=True), + "DeciLMForCausalLM": _HfExamplesInfo( + "nvidia/Llama-3_3-Nemotron-Super-49B-v1", + trust_remote_code=True, + ), "DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"), - "DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat", # noqa: E501 - trust_remote_code=True), - "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501 - trust_remote_code=True), + "DeepseekV2ForCausalLM": _HfExamplesInfo( + "deepseek-ai/DeepSeek-V2-Lite-Chat", + trust_remote_code=True, + ), + "DeepseekV3ForCausalLM": _HfExamplesInfo( + "deepseek-ai/DeepSeek-V3", + trust_remote_code=True, + ), "DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"), - "Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT", - min_transformers_version="4.54"), - "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", - min_transformers_version="4.54"), - "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", - trust_remote_code=True), - "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B", - min_transformers_version="4.54"), - "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 + "Ernie4_5ForCausalLM": _HfExamplesInfo( + "baidu/ERNIE-4.5-0.3B-PT", min_transformers_version="4.54" + ), + "Ernie4_5_MoeForCausalLM": _HfExamplesInfo( + "baidu/ERNIE-4.5-21B-A3B-PT", min_transformers_version="4.54" + ), + "ExaoneForCausalLM": _HfExamplesInfo( + "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True + ), + "Exaone4ForCausalLM": _HfExamplesInfo( + "LGAI-EXAONE/EXAONE-4.0-32B", min_transformers_version="4.54" + ), + "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), - "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), + "FalconH1ForCausalLM": _HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), - "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it", - min_transformers_version="4.53"), + "Gemma3nForCausalLM": _HfExamplesInfo( + "google/gemma-3n-E2B-it", min_transformers_version="4.53" + ), "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"), "Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"), - "Glm4MoeForCausalLM": _HfExamplesInfo("zai-org/GLM-4.5", - min_transformers_version="4.54"), # noqa: E501 - "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", - {"alias": "gpt2"}), - "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder", - extras={"tiny": "bigcode/tiny_starcoder_py"}, # noqa: E501 - min_transformers_version="4.55.1", - transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501 - "GPTJForCausalLM": _HfExamplesInfo("Milos/slovak-gpt-j-405M", - {"6b": "EleutherAI/gpt-j-6b"}), - "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-70m", - {"1b": "EleutherAI/pythia-1.4b"}), + "Glm4MoeForCausalLM": _HfExamplesInfo( + "zai-org/GLM-4.5", min_transformers_version="4.54" + ), + "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}), + "GPTBigCodeForCausalLM": 
_HfExamplesInfo( + "bigcode/starcoder", + extras={"tiny": "bigcode/tiny_starcoder_py"}, + min_transformers_version="4.55.1", + transformers_version_reason="HF model broken in 4.55.0", + ), + "GPTJForCausalLM": _HfExamplesInfo( + "Milos/slovak-gpt-j-405M", {"6b": "EleutherAI/gpt-j-6b"} + ), + "GPTNeoXForCausalLM": _HfExamplesInfo( + "EleutherAI/pythia-70m", {"1b": "EleutherAI/pythia-1.4b"} + ), "GptOssForCausalLM": _HfExamplesInfo("lmsys/gpt-oss-20b-bf16"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), - "GraniteMoeHybridForCausalLM": _HfExamplesInfo("ibm-granite/granite-4.0-tiny-preview", # noqa: E501 - min_transformers_version="4.55.3"), - "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts"), # noqa: E501 - "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", - trust_remote_code=True), - "HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct", - trust_remote_code=True), + "GraniteMoeHybridForCausalLM": _HfExamplesInfo( + "ibm-granite/granite-4.0-tiny-preview", + min_transformers_version="4.55.3", + ), + "GraniteMoeSharedForCausalLM": _HfExamplesInfo( + "ibm-research/moe-7b-1b-active-shared-experts" + ), + "Grok1ModelForCausalLM": _HfExamplesInfo( + "hpcai-tech/grok-1", trust_remote_code=True + ), + "HunYuanMoEV1ForCausalLM": _HfExamplesInfo( + "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True + ), # TODO: Remove is_available_online once their config.json is fixed - "HunYuanDenseV1ForCausalLM":_HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124", - trust_remote_code=True, - is_available_online=False), - "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", - trust_remote_code=True), - "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b", - trust_remote_code=True), - "InternLM2VEForCausalLM": _HfExamplesInfo("OpenGVLab/Mono-InternVL-2B", - trust_remote_code=True), - "InternLM3ForCausalLM": _HfExamplesInfo("internlm/internlm3-8b-instruct", - trust_remote_code=True), + "HunYuanDenseV1ForCausalLM": _HfExamplesInfo( + "tencent/Hunyuan-7B-Instruct-0124", + trust_remote_code=True, + is_available_online=False, + ), + "InternLMForCausalLM": _HfExamplesInfo( + "internlm/internlm-chat-7b", trust_remote_code=True + ), + "InternLM2ForCausalLM": _HfExamplesInfo( + "internlm/internlm2-chat-7b", trust_remote_code=True + ), + "InternLM2VEForCausalLM": _HfExamplesInfo( + "OpenGVLab/Mono-InternVL-2B", trust_remote_code=True + ), + "InternLM3ForCausalLM": _HfExamplesInfo( + "internlm/internlm3-8b-instruct", trust_remote_code=True + ), "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), - "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", - min_transformers_version="4.55.3", - extras={ - "tiny": "ai21labs/Jamba-tiny-dev", - "random": "ai21labs/Jamba-tiny-random", # noqa: E501 - }), - "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B", - min_transformers_version="4.54"), - "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", - extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 - "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 - "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"}), # noqa: E501 - "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", - is_available_online=False), - "Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 - is_available_online=False), - "LongcatFlashForCausalLM": _HfExamplesInfo - 
("meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True), + "JambaForCausalLM": _HfExamplesInfo( + "ai21labs/AI21-Jamba-1.5-Mini", + min_transformers_version="4.55.3", + extras={ + "tiny": "ai21labs/Jamba-tiny-dev", + "random": "ai21labs/Jamba-tiny-random", + }, + ), + "Lfm2ForCausalLM": _HfExamplesInfo( + "LiquidAI/LFM2-1.2B", min_transformers_version="4.54" + ), + "LlamaForCausalLM": _HfExamplesInfo( + "meta-llama/Llama-3.2-1B-Instruct", + extras={ + "guard": "meta-llama/Llama-Guard-3-1B", + "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", + "fp8": "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8", + }, + ), + "LLaMAForCausalLM": _HfExamplesInfo( + "decapoda-research/llama-7b-hf", is_available_online=False + ), + "Llama4ForCausalLM": _HfExamplesInfo( + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + is_available_online=False, + ), + "LongcatFlashForCausalLM": _HfExamplesInfo( + "meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True + ), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), - "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1", - min_transformers_version="4.55.3", - extras={ - "random": "yujiepan/mamba2-codestral-v0.1-tiny-random", # noqa: E501 - }), - "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), # noqa: E501 - "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16", - trust_remote_code=True), - "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", - trust_remote_code=True), + "Mamba2ForCausalLM": _HfExamplesInfo( + "mistralai/Mamba-Codestral-7B-v0.1", + min_transformers_version="4.55.3", + extras={ + "random": "yujiepan/mamba2-codestral-v0.1-tiny-random", + }, + ), + "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"), + "MiniCPMForCausalLM": _HfExamplesInfo( + "openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True + ), + "MiniCPM3ForCausalLM": _HfExamplesInfo( + "openbmb/MiniCPM3-4B", trust_remote_code=True + ), "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf"), - "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", - trust_remote_code=True, - revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501 - "MiniMaxM1ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-M1-40k", - trust_remote_code=True), + "MiniMaxText01ForCausalLM": _HfExamplesInfo( + "MiniMaxAI/MiniMax-Text-01", + trust_remote_code=True, + revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3", + ), + "MiniMaxM1ForCausalLM": _HfExamplesInfo( + "MiniMaxAI/MiniMax-M1-40k", trust_remote_code=True + ), "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"), - "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501 - {"tiny": "TitanML/tiny-mixtral"}), # noqa: E501 + "MixtralForCausalLM": _HfExamplesInfo( + "mistralai/Mixtral-8x7B-Instruct-v0.1", + {"tiny": "TitanML/tiny-mixtral"}, + ), "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False), "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"), "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"), - "NemotronHForCausalLM": _HfExamplesInfo("nvidia/Nemotron-H-8B-Base-8K", - trust_remote_code=True), + "NemotronHForCausalLM": _HfExamplesInfo( + "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True + ), "OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"), "Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"), "Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"), "OlmoeForCausalLM": 
_HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"), - "OPTForCausalLM": _HfExamplesInfo("facebook/opt-125m", - {"1b": "facebook/opt-iml-max-1.3b"}), - "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", - trust_remote_code=True), + "OPTForCausalLM": _HfExamplesInfo( + "facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"} + ), + "OrionForCausalLM": _HfExamplesInfo( + "OrionStarAI/Orion-14B-Chat", trust_remote_code=True + ), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), - "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", - trust_remote_code=True), - "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", - max_transformers_version="4.55.4", - transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 - trust_remote_code=True), - "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", - max_transformers_version="4.53", - transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 - trust_remote_code=True), - "Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct", - extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501 + "PhiMoEForCausalLM": _HfExamplesInfo( + "microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True + ), + "Plamo2ForCausalLM": _HfExamplesInfo( + "pfnet/plamo-2-1b", + max_transformers_version="4.55.4", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + trust_remote_code=True, + ), + "QWenLMHeadModel": _HfExamplesInfo( + "Qwen/Qwen-7B-Chat", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + trust_remote_code=True, + ), + "Qwen2ForCausalLM": _HfExamplesInfo( + "Qwen/Qwen2-0.5B-Instruct", extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"} + ), "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), - "Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", - extras={"tiny-random": "tiny-random/qwen3-next-moe"}, # noqa: E501 - min_transformers_version="4.56.3"), + "Qwen3NextForCausalLM": _HfExamplesInfo( + "Qwen/Qwen3-Next-80B-A3B-Instruct", + extras={"tiny-random": "tiny-random/qwen3-next-moe"}, + min_transformers_version="4.56.3", + ), "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), - "SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501 - trust_remote_code=True, - is_available_online=False), + "SeedOssForCausalLM": _HfExamplesInfo( + "ByteDance-Seed/Seed-OSS-36B-Instruct", + trust_remote_code=True, + is_available_online=False, + ), "SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"), - "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 + "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"), - "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3", - trust_remote_code=True), - "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct", - trust_remote_code=True), - 
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B", - trust_remote_code=True), - "TeleFLMForCausalLM": _HfExamplesInfo("CofeAI/FLM-2-52B-Instruct-2407", - trust_remote_code=True), - "XverseForCausalLM": _HfExamplesInfo("xverse/XVERSE-7B-Chat", - tokenizer="meta-llama/Llama-2-7b", - trust_remote_code=True), + "Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3", trust_remote_code=True), + "SolarForCausalLM": _HfExamplesInfo( + "upstage/solar-pro-preview-instruct", trust_remote_code=True + ), + "TeleChat2ForCausalLM": _HfExamplesInfo( + "Tele-AI/TeleChat2-3B", trust_remote_code=True + ), + "TeleFLMForCausalLM": _HfExamplesInfo( + "CofeAI/FLM-2-52B-Instruct-2407", trust_remote_code=True + ), + "XverseForCausalLM": _HfExamplesInfo( + "xverse/XVERSE-7B-Chat", + tokenizer="meta-llama/Llama-2-7b", + trust_remote_code=True, + ), "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), - "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", - trust_remote_code=True), + "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True), "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"), } _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), - "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), # noqa: E501 + "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), - "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", - trust_remote_code=True), - "GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5", - trust_remote_code=True, - hf_overrides={"architectures": ["GteNewModel"]}), # noqa: E501 - "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward", - trust_remote_code=True), - "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501 + "GteModel": _HfExamplesInfo( + "Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True + ), + "GteNewModel": _HfExamplesInfo( + "Alibaba-NLP/gte-base-en-v1.5", + trust_remote_code=True, + hf_overrides={"architectures": ["GteNewModel"]}, + ), + "InternLM2ForRewardModel": _HfExamplesInfo( + "internlm/internlm2-1_8b-reward", trust_remote_code=True + ), + "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), - "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", - trust_remote_code=True), - "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", - trust_remote_code=True), # noqa: E501 + "ModernBertModel": _HfExamplesInfo( + "Alibaba-NLP/gte-modernbert-base", trust_remote_code=True + ), + "NomicBertModel": _HfExamplesInfo( + "nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True + ), "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), - "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B", - max_transformers_version="4.53", - transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 - "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B", - max_transformers_version="4.53", - transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 - "RobertaModel": 
_HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 - "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 - "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # noqa: E501 + "Qwen2ForRewardModel": _HfExamplesInfo( + "Qwen/Qwen2.5-Math-RM-72B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + ), + "Qwen2ForProcessRewardModel": _HfExamplesInfo( + "Qwen/Qwen2.5-Math-PRM-7B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501 + ), + "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), + "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), + "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), # [Multimodal] + "CLIPModel": _HfExamplesInfo("openai/clip-vit-base-patch32"), "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), - "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", - trust_remote_code=True), - "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 - "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 - dtype=torch.float16, - enforce_eager=True, - skip_tokenizer_init=True, - # This is to avoid the model - # going OOM in CI - max_num_seqs=32, - ), - "Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501 - dtype=torch.float16, - enforce_eager=True, - skip_tokenizer_init=True, - # This is to avoid the model going OOM in CI - max_num_seqs=32, - ), + "Phi3VForCausalLM": _HfExamplesInfo( + "TIGER-Lab/VLM2Vec-Full", trust_remote_code=True + ), + "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), + "PrithviGeoSpatialMAE": _HfExamplesInfo( + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # This is to avoid the model + # going OOM in CI + max_num_seqs=32, + ), + "Terratorch": _HfExamplesInfo( + "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # This is to avoid the model going OOM in CI + max_num_seqs=32, + ), } _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { # [Decoder-only] - "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 - + "GPT2ForSequenceClassification": _HfExamplesInfo( + "nie3e/sentiment-polish-gpt2-small" + ), # [Cross-encoder] - "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 + "BertForSequenceClassification": _HfExamplesInfo( + "cross-encoder/ms-marco-MiniLM-L-6-v2" + ), "BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"), - "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501 - trust_remote_code=True, - hf_overrides={ - "architectures": ["GteNewForSequenceClassification"]}),# noqa: E501 - "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501 - "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 - "XLMRobertaForSequenceClassification": 
_HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 + "GteNewForSequenceClassification": _HfExamplesInfo( + "Alibaba-NLP/gte-multilingual-reranker-base", + trust_remote_code=True, + hf_overrides={"architectures": ["GteNewForSequenceClassification"]}, + ), + "ModernBertForSequenceClassification": _HfExamplesInfo( + "Alibaba-NLP/gte-reranker-modernbert-base" + ), + "RobertaForSequenceClassification": _HfExamplesInfo( + "cross-encoder/quora-roberta-base" + ), + "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), } _AUTOMATIC_CONVERTED_MODELS = { # Use as_seq_cls_model for automatic conversion - "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 - hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 - "classifier_from_token": ["Yes"], # noqa: E501 - "method": "no_post_processing"}), # noqa: E501 - "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 - "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 - "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 + "GemmaForSequenceClassification": _HfExamplesInfo( + "BAAI/bge-reranker-v2-gemma", + hf_overrides={ + "architectures": ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": "no_post_processing", + }, + ), + "LlamaForSequenceClassification": _HfExamplesInfo( + "Skywork/Skywork-Reward-V2-Llama-3.2-1B" + ), + "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), + "Qwen3ForSequenceClassification": _HfExamplesInfo( + "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" + ), } _MULTIMODAL_EXAMPLE_MODELS = { # [Decoder-only] "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), - "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501 - "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501 - extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501 - "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501 - "Cohere2VisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/command-a-vision-07-2025"), # noqa: E501 - "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 - extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 - max_transformers_version="4.48", # noqa: E501 - transformers_version_reason="HF model is not compatible.", # noqa: E501 - hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 - "DotsOCRForCausalLM": _HfExamplesInfo("rednote-hilab/dots.ocr", - trust_remote_code=True), + "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), + "Blip2ForConditionalGeneration": _HfExamplesInfo( + "Salesforce/blip2-opt-2.7b", + extras={"6b": "Salesforce/blip2-opt-6.7b"}, + ), + "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), + "Cohere2VisionForConditionalGeneration": _HfExamplesInfo( + "CohereLabs/command-a-vision-07-2025" + ), + "DeepseekVLV2ForCausalLM": _HfExamplesInfo( + "deepseek-ai/deepseek-vl2-tiny", + extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, + max_transformers_version="4.48", + transformers_version_reason="HF model is not compatible.", + hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}, + ), + "DotsOCRForCausalLM": _HfExamplesInfo( + "rednote-hilab/dots.ocr", 
trust_remote_code=True + ), "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), - "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501 - trust_remote_code=True), + "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo( + "baidu/ERNIE-4.5-VL-28B-A3B-PT", + trust_remote_code=True, + ), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), - "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501 - min_transformers_version="4.53"), - "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501 - "GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b", - trust_remote_code=True, - hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 - "Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501 - "Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V", - min_transformers_version="4.56"), # noqa: E501 - "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", - trust_remote_code=True, - extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 - max_transformers_version="4.48", # noqa: E501 - transformers_version_reason="HF model is not compatible."), # noqa: E501 - "HCXVisionForCausalLM": _HfExamplesInfo("naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", # noqa: E501 - trust_remote_code=True), - "Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501 - {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, # noqa: E501 - min_transformers_version="4.56", - transformers_version_reason="HF model broken in 4.55"), # noqa: E501 - "InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1", - trust_remote_code=True), # noqa: E501 - "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", - extras={"2B": "OpenGVLab/InternVL2-2B", - "3.0": "OpenGVLab/InternVL3-1B", # noqa: E501 - "3.5-qwen3": "OpenGVLab/InternVL3_5-1B", # noqa: E501 - "3.5-qwen3moe": "OpenGVLab/InternVL3_5-30B-A3B", # noqa: E501 - "3.5-gptoss": "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview"}, # noqa: E501 - trust_remote_code=True), - "InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501 - "KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501 - trust_remote_code=True), - "KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-1_5-8B", # noqa: E501 - trust_remote_code=True), - "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 - extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 - trust_remote_code=True), - "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 - max_model_len=10240, - extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, # noqa: E501 - ), - "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", - extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501 - "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501 - "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501 - "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501 - "LlavaOnevisionForConditionalGeneration": 
_HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 - "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501 - max_transformers_version="4.48", # noqa: E501 - transformers_version_reason="HF model is not compatible.", # noqa: E501 - hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501 - "MiDashengLMModel": _HfExamplesInfo("mispeech/midashenglm-7b", - trust_remote_code=True), - "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", - trust_remote_code=True), - "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", - extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4", "4.5": "openbmb/MiniCPM-V-4_5"}, # noqa: E501 - trust_remote_code=True), - "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501 - trust_remote_code=True, - v0_only=True), - "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501 - extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501 - "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", - max_transformers_version="4.48", - transformers_version_reason="Incorrectly-detected `tensorflow` import.", # noqa: E501 - extras={"olmo": "allenai/Molmo-7B-O-0924"}, # noqa: E501 - trust_remote_code=True), - "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", - trust_remote_code=True), - "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501 - trust_remote_code=True), - "NemotronH_Nano_VL_V2": _HfExamplesInfo("nano_vl_dummy", - is_available_online=False, - trust_remote_code=True), - "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True, - max_transformers_version="4.53", - transformers_version_reason="HF model is not compatible", # noqa: E501 - extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", - "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501 - "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", - trust_remote_code=True), - "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501 - extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 - "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", - trust_remote_code=True, - max_transformers_version="4.48", - transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 - extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}), # noqa: E501 - "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", - trust_remote_code=True), - "Phi4MultimodalForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct", # noqa: E501 - revision="refs/pr/70"), - "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501 - tokenizer_mode="mistral"), - "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL", - extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501 - trust_remote_code=True, - hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501 - "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 - "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 - "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 - max_model_len=4096), + "Gemma3nForConditionalGeneration": _HfExamplesInfo( + "google/gemma-3n-E2B-it", + 
min_transformers_version="4.53", + ), + "GraniteSpeechForConditionalGeneration": _HfExamplesInfo( + "ibm-granite/granite-speech-3.3-2b" + ), + "GLM4VForCausalLM": _HfExamplesInfo( + "zai-org/glm-4v-9b", + trust_remote_code=True, + hf_overrides={"architectures": ["GLM4VForCausalLM"]}, + ), + "Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), + "Glm4vMoeForConditionalGeneration": _HfExamplesInfo( + "zai-org/GLM-4.5V", min_transformers_version="4.56" + ), + "H2OVLChatModel": _HfExamplesInfo( + "h2oai/h2ovl-mississippi-800m", + trust_remote_code=True, + extras={"2b": "h2oai/h2ovl-mississippi-2b"}, + max_transformers_version="4.48", + transformers_version_reason="HF model is not compatible.", + ), + "HCXVisionForCausalLM": _HfExamplesInfo( + "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", + trust_remote_code=True, + ), + "Idefics3ForConditionalGeneration": _HfExamplesInfo( + "HuggingFaceM4/Idefics3-8B-Llama3", + {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, + min_transformers_version="4.56", + transformers_version_reason="HF model broken in 4.55", + ), + "InternS1ForConditionalGeneration": _HfExamplesInfo( + "internlm/Intern-S1", trust_remote_code=True + ), + "InternVLChatModel": _HfExamplesInfo( + "OpenGVLab/InternVL2-1B", + extras={ + "2B": "OpenGVLab/InternVL2-2B", + "3.0": "OpenGVLab/InternVL3-1B", + "3.5-qwen3": "OpenGVLab/InternVL3_5-1B", + "3.5-qwen3moe": "OpenGVLab/InternVL3_5-30B-A3B", + "3.5-gptoss": "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", + }, + trust_remote_code=True, + ), + "InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), + "KeyeForConditionalGeneration": _HfExamplesInfo( + "Kwai-Keye/Keye-VL-8B-Preview", + trust_remote_code=True, + ), + "KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo( + "Kwai-Keye/Keye-VL-1_5-8B", + trust_remote_code=True, + ), + "KimiVLForConditionalGeneration": _HfExamplesInfo( + "moonshotai/Kimi-VL-A3B-Instruct", + extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, + trust_remote_code=True, + ), + "Llama4ForConditionalGeneration": _HfExamplesInfo( + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + max_model_len=10240, + extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, + ), + "LlavaForConditionalGeneration": _HfExamplesInfo( + "llava-hf/llava-1.5-7b-hf", + extras={ + "mistral": "mistral-community/pixtral-12b", + "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic", + }, + ), + "LlavaNextForConditionalGeneration": _HfExamplesInfo( + "llava-hf/llava-v1.6-mistral-7b-hf" + ), + "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo( + "llava-hf/LLaVA-NeXT-Video-7B-hf" + ), + "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo( + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf" + ), + "MantisForConditionalGeneration": _HfExamplesInfo( + "TIGER-Lab/Mantis-8B-siglip-llama3", + max_transformers_version="4.48", + transformers_version_reason="HF model is not compatible.", + hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, + ), + "MiDashengLMModel": _HfExamplesInfo( + "mispeech/midashenglm-7b", trust_remote_code=True + ), + "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True), + "MiniCPMV": _HfExamplesInfo( + "openbmb/MiniCPM-Llama3-V-2_5", + extras={ + "2.6": "openbmb/MiniCPM-V-2_6", + "4.0": "openbmb/MiniCPM-V-4", + "4.5": "openbmb/MiniCPM-V-4_5", + }, + trust_remote_code=True, + ), + "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo( + "MiniMaxAI/MiniMax-VL-01", + trust_remote_code=True, + v0_only=True, + ), + 
"Mistral3ForConditionalGeneration": _HfExamplesInfo( + "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}, + ), + "MolmoForCausalLM": _HfExamplesInfo( + "allenai/Molmo-7B-D-0924", + max_transformers_version="4.48", + transformers_version_reason="Incorrectly-detected `tensorflow` import.", + extras={"olmo": "allenai/Molmo-7B-O-0924"}, + trust_remote_code=True, + ), + "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True), + "Llama_Nemotron_Nano_VL": _HfExamplesInfo( + "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", + trust_remote_code=True, + ), + "NemotronH_Nano_VL_V2": _HfExamplesInfo( + "nano_vl_dummy", is_available_online=False, trust_remote_code=True + ), + "Ovis": _HfExamplesInfo( + "AIDC-AI/Ovis2-1B", + trust_remote_code=True, + max_transformers_version="4.53", + transformers_version_reason="HF model is not compatible", + extras={ + "1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B", + "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B", + }, + ), + "Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", trust_remote_code=True), + "PaliGemmaForConditionalGeneration": _HfExamplesInfo( + "google/paligemma-3b-mix-224", + extras={"v2": "google/paligemma2-3b-ft-docci-448"}, + ), + "Phi3VForCausalLM": _HfExamplesInfo( + "microsoft/Phi-3-vision-128k-instruct", + trust_remote_code=True, + max_transformers_version="4.48", + transformers_version_reason="Use of deprecated imports which have been removed.", # noqa: E501 + extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}, + ), + "Phi4MMForCausalLM": _HfExamplesInfo( + "microsoft/Phi-4-multimodal-instruct", trust_remote_code=True + ), + "Phi4MultimodalForCausalLM": _HfExamplesInfo( + "microsoft/Phi-4-multimodal-instruct", + revision="refs/pr/70", + ), + "PixtralForConditionalGeneration": _HfExamplesInfo( + "mistralai/Pixtral-12B-2409", + tokenizer_mode="mistral", + ), + "QwenVLForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen-VL", + extras={"chat": "Qwen/Qwen-VL-Chat"}, + trust_remote_code=True, + hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, + ), + "Qwen2AudioForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen2-Audio-7B-Instruct" + ), + "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), + "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen2.5-VL-3B-Instruct", + max_model_len=4096, + ), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), - "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 - "Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501 - max_model_len=4096, - min_transformers_version="4.57", - is_available_online=False), - "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501 - max_model_len=4096, - min_transformers_version="4.57", - is_available_online=False), - "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", - trust_remote_code=True), - "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", - trust_remote_code=True), - "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501 - min_transformers_version="4.56", - transformers_version_reason="HF model broken in 4.55"), # noqa: E501 - "Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3", - trust_remote_code=True), - "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 - 
trust_remote_code=True), - "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), # noqa: E501 - "Tarsier2ForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier2-Recap-7b", # noqa: E501 - hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}), # noqa: E501 + "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), + "Qwen3VLForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen3-VL-4B-Instruct", + max_model_len=4096, + min_transformers_version="4.57", + is_available_online=False, + ), + "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo( + "Qwen/Qwen3-VL-30B-A3B-Instruct", + max_model_len=4096, + min_transformers_version="4.57", + is_available_online=False, + ), + "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", trust_remote_code=True), + "SkyworkR1VChatModel": _HfExamplesInfo( + "Skywork/Skywork-R1V-38B", trust_remote_code=True + ), + "SmolVLMForConditionalGeneration": _HfExamplesInfo( + "HuggingFaceTB/SmolVLM2-2.2B-Instruct", + min_transformers_version="4.56", + transformers_version_reason="HF model broken in 4.55", + ), + "Step3VLForConditionalGeneration": _HfExamplesInfo( + "stepfun-ai/step3", trust_remote_code=True + ), + "UltravoxModel": _HfExamplesInfo( + "fixie-ai/ultravox-v0_5-llama-3_2-1b", + trust_remote_code=True, + ), + "TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), + "Tarsier2ForConditionalGeneration": _HfExamplesInfo( + "omni-research/Tarsier2-Recap-7b", + hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}, + ), "VoxtralForConditionalGeneration": _HfExamplesInfo( "mistralai/Voxtral-Mini-3B-2507", min_transformers_version="4.54", @@ -591,80 +795,120 @@ def check_available_online( is_available_online=False, ), # [Encoder-decoder] - "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 + "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # [Cross-encoder] - "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 + "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), } _SPECULATIVE_DECODING_EXAMPLE_MODELS = { - "MedusaModel": _HfExamplesInfo("JackFram/llama-68m", - speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501 + "MedusaModel": _HfExamplesInfo( + "JackFram/llama-68m", speculative_model="abhigoyal/vllm-medusa-llama-68m-random" + ), # Temporarily disabled. # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. 
- # "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", - # speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 - "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", - speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 - trust_remote_code=True), - "EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random", - speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501 - trust_remote_code=True), - "EagleLlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B-Instruct", # noqa: E501 - trust_remote_code=True, - speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", - tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 - "Eagle3LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.1-8B-Instruct", # noqa: E501 - trust_remote_code=True, - speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", # noqa: E501 - tokenizer="meta-llama/Llama-3.1-8B-Instruct", - use_original_num_layers=True, - max_model_len=10240), - "LlamaForCausalLMEagle3": _HfExamplesInfo("Qwen/Qwen3-8B", # noqa: E501 - trust_remote_code=True, - speculative_model="AngelSlim/Qwen3-8B_eagle3", # noqa: E501 - tokenizer="Qwen/Qwen3-8B", - use_original_num_layers=True), + # "MLPSpeculatorPreTrainedModel": _HfExamplesInfo( + # "JackFram/llama-160m", + # speculative_model="ibm-ai-platform/llama-160m-accelerator" + # ), + "DeepSeekMTPModel": _HfExamplesInfo( + "luccafong/deepseek_mtp_main_random", + speculative_model="luccafong/deepseek_mtp_draft_random", + trust_remote_code=True, + ), + "EagleDeepSeekMTPModel": _HfExamplesInfo( + "eagle618/deepseek-v3-random", + speculative_model="eagle618/eagle-deepseek-v3-random", + trust_remote_code=True, + ), + "EagleLlamaForCausalLM": _HfExamplesInfo( + "meta-llama/Meta-Llama-3-8B-Instruct", + trust_remote_code=True, + speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct", + ), + "Eagle3LlamaForCausalLM": _HfExamplesInfo( + "meta-llama/Llama-3.1-8B-Instruct", + trust_remote_code=True, + speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + tokenizer="meta-llama/Llama-3.1-8B-Instruct", + use_original_num_layers=True, + max_model_len=10240, + ), + "LlamaForCausalLMEagle3": _HfExamplesInfo( + "Qwen/Qwen3-8B", + trust_remote_code=True, + speculative_model="AngelSlim/Qwen3-8B_eagle3", + tokenizer="Qwen/Qwen3-8B", + use_original_num_layers=True, + ), "EagleLlama4ForCausalLM": _HfExamplesInfo( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", trust_remote_code=True, speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", - tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 - "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16", - trust_remote_code=True, - is_available_online=False, - speculative_model="openbmb/MiniCPM-2B-sft-bf16", - tokenizer="openbmb/MiniCPM-2B-sft-bf16"), - "ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", - trust_remote_code=True, - speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"), - "Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5", - speculative_model="zai-org/GLM-4.5", - min_transformers_version="4.56", - is_available_online=False), + tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct", + ), + "EagleMiniCPMForCausalLM": _HfExamplesInfo( + "openbmb/MiniCPM-1B-sft-bf16", + trust_remote_code=True, + is_available_online=False, + speculative_model="openbmb/MiniCPM-2B-sft-bf16", + tokenizer="openbmb/MiniCPM-2B-sft-bf16", + ), + "ErnieMTPModel": 
_HfExamplesInfo( + "baidu/ERNIE-4.5-21B-A3B-PT", + trust_remote_code=True, + speculative_model="baidu/ERNIE-4.5-21B-A3B-PT", + ), + "Glm4MoeMTPModel": _HfExamplesInfo( + "zai-org/GLM-4.5", + speculative_model="zai-org/GLM-4.5", + min_transformers_version="4.56", + is_available_online=False, + ), "LongCatFlashMTPModel": _HfExamplesInfo( "meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True, - speculative_model="meituan-longcat/LongCat-Flash-Chat"), - "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", - trust_remote_code=True, - speculative_model="XiaomiMiMo/MiMo-7B-RL"), + speculative_model="meituan-longcat/LongCat-Flash-Chat", + ), + "MiMoMTPModel": _HfExamplesInfo( + "XiaomiMiMo/MiMo-7B-RL", + trust_remote_code=True, + speculative_model="XiaomiMiMo/MiMo-7B-RL", + ), "Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo( "Qwen/Qwen2.5-VL-7B-Instruct", - speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"), - "Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct", - min_transformers_version="4.56.3"), + speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl", + ), + "Qwen3NextMTP": _HfExamplesInfo( + "Qwen/Qwen3-Next-80B-A3B-Instruct", min_transformers_version="4.56.3" + ), } _TRANSFORMERS_BACKEND_MODELS = { - "TransformersEmbeddingModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0"), # noqa: E501 - "TransformersForSequenceClassification": _HfExamplesInfo("papluca/xlm-roberta-base-language-detection", min_transformers_version="4.57.0.dev0"), # noqa: E501 - "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 + "TransformersEmbeddingModel": _HfExamplesInfo( + "BAAI/bge-base-en-v1.5", min_transformers_version="4.57.0.dev0" + ), + "TransformersForSequenceClassification": _HfExamplesInfo( + "papluca/xlm-roberta-base-language-detection", + min_transformers_version="4.57.0.dev0", + ), + "TransformersForCausalLM": _HfExamplesInfo( + "hmellor/Ilama-3.2-1B", trust_remote_code=True + ), "TransformersForMultimodalLM": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), - "TransformersMoEForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0"), # noqa: E501 - "TransformersMoEForMultimodalLM": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0"), # noqa: E501 - "TransformersMoEEmbeddingModel": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501 - "TransformersMoEForSequenceClassification": _HfExamplesInfo("Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0"), # noqa: E501 + "TransformersMoEForCausalLM": _HfExamplesInfo( + "allenai/OLMoE-1B-7B-0924", min_transformers_version="4.57.0.dev0" + ), + "TransformersMoEForMultimodalLM": _HfExamplesInfo( + "Qwen/Qwen3-VL-30B-A3B-Instruct", min_transformers_version="4.57.0.dev0" + ), + "TransformersMoEEmbeddingModel": _HfExamplesInfo( + "Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0" + ), + "TransformersMoEForSequenceClassification": _HfExamplesInfo( + "Qwen/Qwen3-30B-A3B", min_transformers_version="4.57.0.dev0" + ), } _EXAMPLE_MODELS = { @@ -687,7 +931,12 @@ def get_supported_archs(self) -> Set[str]: return self.hf_models.keys() def get_hf_info(self, model_arch: str) -> _HfExamplesInfo: - return self.hf_models[model_arch] + try: + return self.hf_models[model_arch] + except KeyError: + raise ValueError( + f"No example model defined for {model_arch}; please update this file." 
+ ) from None def find_hf_info(self, model_id: str) -> _HfExamplesInfo: for info in self.hf_models.values(): @@ -699,7 +948,9 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: if any(extra == model_id for extra in info.extras.values()): return info - raise ValueError(f"No example model defined for {model_id}") + raise ValueError( + f"No example model defined for {model_id}; please update this file." + ) HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 1db0dc3da922..f501798ffa36 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -8,13 +8,19 @@ from vllm import LLM from vllm.utils import GiB_bytes -from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config, - get_kv_cache_configs) +from vllm.v1.core.kv_cache_utils import ( + generate_scheduler_kv_cache_config, + get_kv_cache_configs, +) from vllm.v1.engine.core import EngineCore as V1EngineCore from ..utils import create_new_process_for_each_test -from .registry import (_TRANSFORMERS_BACKEND_MODELS, AUTO_EXAMPLE_MODELS, - HF_EXAMPLE_MODELS, HfExampleModels) +from .registry import ( + _TRANSFORMERS_BACKEND_MODELS, + AUTO_EXAMPLE_MODELS, + HF_EXAMPLE_MODELS, + HfExampleModels, +) from .utils import dummy_hf_overrides # This minimal list of model architectures is smaller than the total list of @@ -24,23 +30,32 @@ # generation, sequence classification, causal LM, ranking, chat, reward model, # multimodal, geospatial, voice, embedding, MTP) MINIMAL_MODEL_ARCH_LIST = [ - "LlavaForConditionalGeneration", "Llama4ForConditionalGeneration", - "BertForSequenceClassification", "Gemma3nForCausalLM", "JinaVLForRanking", - "InternVLChatModel", "InternLM2ForRewardModel", - "TransformersForMultimodalLM", "PrithviGeoSpatialMAE", "UltravoxModel", - "DeepSeekMTPModel", "XLMRobertaModel" + "LlavaForConditionalGeneration", + "Llama4ForConditionalGeneration", + "BertForSequenceClassification", + "Gemma3nForCausalLM", + "JinaVLForRanking", + "InternVLChatModel", + "InternLM2ForRewardModel", + "TransformersForMultimodalLM", + "PrithviGeoSpatialMAE", + "UltravoxModel", + "DeepSeekMTPModel", + "XLMRobertaModel", ] # This list is the complement of the minimal list above. The intention is that # this list of models is only tested in a "special case" i.e. most PRs should # not test these models -OTHER_MODEL_ARCH_LIST = (set(HF_EXAMPLE_MODELS.get_supported_archs()) - - set(MINIMAL_MODEL_ARCH_LIST)) +OTHER_MODEL_ARCH_LIST = set(HF_EXAMPLE_MODELS.get_supported_archs()) - set( + MINIMAL_MODEL_ARCH_LIST +) @create_new_process_for_each_test() -def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, - EXAMPLE_MODELS: HfExampleModels): +def can_initialize( + model_arch: str, monkeypatch: pytest.MonkeyPatch, EXAMPLE_MODELS: HfExampleModels +): """The reason for using create_new_process_for_each_test is to avoid the WARNING: "We must use the 'spawn' multiprocessing start method. 
Overriding @@ -53,12 +68,12 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") - hf_overrides_fn = partial(dummy_hf_overrides, - model_arch=model_arch, - exist_overrides=model_info.hf_overrides, - use_original_num_layers=getattr( - model_info, 'use_original_num_layers', - False)) + hf_overrides_fn = partial( + dummy_hf_overrides, + model_arch=model_arch, + exist_overrides=model_info.hf_overrides, + use_original_num_layers=getattr(model_info, "use_original_num_layers", False), + ) # Avoid calling model.forward() def _initialize_kv_caches_v1(self, vllm_config): @@ -68,14 +83,15 @@ def _initialize_kv_caches_v1(self, vllm_config): kv_cache_specs, [10 * GiB_bytes], ) - scheduler_kv_cache_config = generate_scheduler_kv_cache_config( - kv_cache_configs) + scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config return 1, 0, scheduler_kv_cache_config - with (patch.object(V1EngineCore, "_initialize_kv_caches", - _initialize_kv_caches_v1), monkeypatch.context() as m): + with ( + patch.object(V1EngineCore, "_initialize_kv_caches", _initialize_kv_caches_v1), + monkeypatch.context() as m, + ): if model_info.v0_only: # NOTE(woosuk): skip the test for V0-only models return @@ -97,21 +113,24 @@ def _initialize_kv_caches_v1(self, vllm_config): speculative_config={ "model": model_info.speculative_model, "num_speculative_tokens": 1, - } if model_info.speculative_model else None, + } + if model_info.speculative_model + else None, trust_remote_code=model_info.trust_remote_code, max_model_len=model_info.max_model_len, # these tests seem to produce leftover memory gpu_memory_utilization=0.80, load_format="dummy", model_impl="transformers" - if model_arch in _TRANSFORMERS_BACKEND_MODELS else "vllm", + if model_arch in _TRANSFORMERS_BACKEND_MODELS + else "vllm", hf_overrides=hf_overrides_fn, - max_num_seqs=model_info.max_num_seqs) + max_num_seqs=model_info.max_num_seqs, + ) @pytest.mark.parametrize("model_arch", MINIMAL_MODEL_ARCH_LIST) -def test_can_initialize_small_subset(model_arch: str, - monkeypatch: pytest.MonkeyPatch): +def test_can_initialize_small_subset(model_arch: str, monkeypatch: pytest.MonkeyPatch): """Test initializing small subset of supported models""" if model_arch == "Lfm2ForCausalLM": pytest.skip("Skipping until test supports V1-only models") @@ -119,10 +138,9 @@ def test_can_initialize_small_subset(model_arch: str, @pytest.mark.parametrize("model_arch", OTHER_MODEL_ARCH_LIST) -def test_can_initialize_large_subset(model_arch: str, - monkeypatch: pytest.MonkeyPatch): +def test_can_initialize_large_subset(model_arch: str, monkeypatch: pytest.MonkeyPatch): """Test initializing large subset of supported models - + This test covers the complement of the tests covered in the "small subset" test. 
""" @@ -131,8 +149,6 @@ def test_can_initialize_large_subset(model_arch: str, can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) -@pytest.mark.parametrize("model_arch", - AUTO_EXAMPLE_MODELS.get_supported_archs()) -def test_implicit_converted_models(model_arch: str, - monkeypatch: pytest.MonkeyPatch): +@pytest.mark.parametrize("model_arch", AUTO_EXAMPLE_MODELS.get_supported_archs()) +def test_implicit_converted_models(model_arch: str, monkeypatch: pytest.MonkeyPatch): can_initialize(model_arch, monkeypatch, AUTO_EXAMPLE_MODELS) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 4aa7bb729789..15e94eef4aa0 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -50,9 +50,9 @@ def test_oot_registration_embedding( with monkeypatch.context() as m: m.setenv("VLLM_PLUGINS", "register_dummy_model") prompts = ["Hello, my name is", "The text does not matter"] - llm = LLM(model=dummy_gemma2_embedding_path, - load_format="dummy", - max_model_len=2048) + llm = LLM( + model=dummy_gemma2_embedding_path, load_format="dummy", max_model_len=2048 + ) outputs = llm.embed(prompts) for output in outputs: @@ -69,27 +69,28 @@ def test_oot_registration_multimodal( ): with monkeypatch.context() as m: m.setenv("VLLM_PLUGINS", "register_dummy_model") - prompts = [{ - "prompt": "What's in the image?<image>", - "multi_modal_data": { - "image": image + prompts = [ + { + "prompt": "What's in the image?<image>", + "multi_modal_data": {"image": image}, }, - }, { - "prompt": "Describe the image<image>", - "multi_modal_data": { - "image": image + { + "prompt": "Describe the image<image>", + "multi_modal_data": {"image": image}, }, - }] + ] sampling_params = SamplingParams(temperature=0) - llm = LLM(model=dummy_llava_path, - load_format="dummy", - max_num_seqs=1, - trust_remote_code=True, - gpu_memory_utilization=0.98, - max_model_len=4096, - enforce_eager=True, - limit_mm_per_prompt={"image": 1}) + llm = LLM( + model=dummy_llava_path, + load_format="dummy", + max_num_seqs=1, + trust_remote_code=True, + gpu_memory_utilization=0.98, + max_model_len=4096, + enforce_eager=True, + limit_mm_per_prompt={"image": 1}, + ) first_token = llm.get_tokenizer().decode(0) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index f67d4017eeee..9017a0fd9140 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -6,16 +6,22 @@ import pytest import torch.cuda -from vllm.model_executor.models import (is_pooling_model, - is_text_generation_model, - supports_multimodal) -from vllm.model_executor.models.adapters import (as_embedding_model, - as_reward_model, - as_seq_cls_model) -from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS, - _SPECULATIVE_DECODING_MODELS, - _TEXT_GENERATION_MODELS, - ModelRegistry) +from vllm.model_executor.models import ( + is_pooling_model, + is_text_generation_model, + supports_multimodal, +) +from vllm.model_executor.models.adapters import ( + as_embedding_model, + as_reward_model, + as_seq_cls_model, +) +from vllm.model_executor.models.registry import ( + _MULTIMODAL_MODELS, + _SPECULATIVE_DECODING_MODELS, + _TEXT_GENERATION_MODELS, + ModelRegistry, +) from vllm.platforms import current_platform from ..utils import create_new_process_for_each_test @@ -34,8 +40,7 @@ def test_registry_imports(model_arch): if model_arch in _SPECULATIVE_DECODING_MODELS: return # Ignore these models which do not have a unified format 
- if (model_arch in _TEXT_GENERATION_MODELS - or model_arch in _MULTIMODAL_MODELS): + if model_arch in _TEXT_GENERATION_MODELS or model_arch in _MULTIMODAL_MODELS: assert is_text_generation_model(model_cls) # All vLLM models should be convertible to a pooling model @@ -48,13 +53,16 @@ def test_registry_imports(model_arch): @create_new_process_for_each_test() -@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [ - ("LlamaForCausalLM", False, False, False), - ("LlavaForConditionalGeneration", True, True, False), - ("BertForSequenceClassification", False, False, True), - ("RobertaForSequenceClassification", False, False, True), - ("XLMRobertaForSequenceClassification", False, False, True), -]) +@pytest.mark.parametrize( + "model_arch,is_mm,init_cuda,is_ce", + [ + ("LlamaForCausalLM", False, False, False), + ("LlavaForConditionalGeneration", True, True, False), + ("BertForSequenceClassification", False, False, True), + ("RobertaForSequenceClassification", False, False, True), + ("XLMRobertaForSequenceClassification", False, False, True), + ], +) def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): model_info = ModelRegistry._try_inspect_model_cls(model_arch) assert model_info is not None @@ -70,7 +78,8 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): warnings.warn( "This model no longer initializes CUDA on import. " "Please test using a different one.", - stacklevel=2) + stacklevel=2, + ) @create_new_process_for_each_test() @@ -82,7 +91,8 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): # ("MLPSpeculatorPreTrainedModel", False, False), ("DeepseekV2ForCausalLM", True, False), ("Qwen2VLForConditionalGeneration", True, True), - ]) + ], +) def test_registry_is_pp(model_arch, is_pp, init_cuda): model_info = ModelRegistry._try_inspect_model_cls(model_arch) assert model_info is not None @@ -97,13 +107,16 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda): warnings.warn( "This model no longer initializes CUDA on import. 
" "Please test using a different one.", - stacklevel=2) + stacklevel=2, + ) def test_hf_registry_coverage(): - untested_archs = (ModelRegistry.get_supported_archs() - - HF_EXAMPLE_MODELS.get_supported_archs()) + untested_archs = ( + ModelRegistry.get_supported_archs() - HF_EXAMPLE_MODELS.get_supported_archs() + ) assert not untested_archs, ( "Please add the following architectures to " - f"`tests/models/registry.py`: {untested_archs}") + f"`tests/models/registry.py`: {untested_archs}" + ) diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py index 842e37ea26f6..cadce5d2b2bb 100644 --- a/tests/models/test_terratorch.py +++ b/tests/models/test_terratorch.py @@ -11,32 +11,33 @@ "model", [ "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", - "mgazz/Prithvi_v2_eo_300_tl_unet_agb" + "mgazz/Prithvi_v2_eo_300_tl_unet_agb", ], ) def test_inference( vllm_runner: type[VllmRunner], model: str, ) -> None: - pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) - prompt = dict(prompt_token_ids=[1], - multi_modal_data=dict(pixel_values=pixel_values, - location_coords=location_coords)) + prompt = dict( + prompt_token_ids=[1], + multi_modal_data=dict( + pixel_values=pixel_values, location_coords=location_coords + ), + ) with vllm_runner( - model, - runner="pooling", - dtype="half", - enforce_eager=True, - skip_tokenizer_init=True, - # Limit the maximum number of sequences to avoid the - # test going OOM during the warmup run - max_num_seqs=32, - default_torch_num_threads=1, + model, + runner="pooling", + dtype="half", + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during the warmup run + max_num_seqs=32, + default_torch_num_threads=1, ) as vllm_model: - vllm_output = vllm_model.llm.encode(prompt) assert torch.equal( - torch.isnan(vllm_output[0].outputs.data).any(), - torch.tensor(False)) + torch.isnan(vllm_output[0].outputs.data).any(), torch.tensor(False) + ) diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index bd443575127f..b434c0955be7 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Test the functionality of the Transformers backend.""" + from typing import Any, Optional, Union import pytest @@ -60,14 +61,16 @@ def check_implementation( @pytest.mark.skipif( current_platform.is_rocm(), - reason="Llama-3.2-1B-Instruct, Ilama-3.2-1B produce memory access fault.") + reason="Llama-3.2-1B-Instruct, Ilama-3.2-1B produce memory access fault.", +) @pytest.mark.parametrize( "model,model_impl", [ ("meta-llama/Llama-3.2-1B-Instruct", "transformers"), ("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE ("allenai/OLMoE-1B-7B-0924", "transformers"), # MoE - ]) # trust_remote_code=True by default + ], +) # trust_remote_code=True by default def test_models( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], @@ -77,29 +80,32 @@ def test_models( ) -> None: import transformers from packaging.version import Version + installed = Version(transformers.__version__) required = Version("4.57.0.dev0") if model == "allenai/OLMoE-1B-7B-0924" and installed < required: - pytest.skip("MoE models with the Transformers backend require " - f"transformers>={required}, but got {installed}") + pytest.skip( + "MoE models with the Transformers backend require 
" + f"transformers>={required}, but got {installed}" + ) - check_implementation(hf_runner, - vllm_runner, - example_prompts, - model, - model_impl=model_impl) + check_implementation( + hf_runner, vllm_runner, example_prompts, model, model_impl=model_impl + ) def test_hybrid_attention(vllm_runner: type[VllmRunner]) -> None: prompts, _, _ = prep_prompts(4, (800, 801)) kwargs_ref = {"max_model_len": 8192, "enforce_eager": True} kwargs_test = {"model_impl": "transformers", **kwargs_ref} - check_implementation(vllm_runner, - vllm_runner, - prompts, - model="hmellor/tiny-random-Gemma2ForCausalLM", - kwargs_ref=kwargs_ref, - kwargs_test=kwargs_test) + check_implementation( + vllm_runner, + vllm_runner, + prompts, + model="hmellor/tiny-random-Gemma2ForCausalLM", + kwargs_ref=kwargs_ref, + kwargs_test=kwargs_test, + ) @multi_gpu_test(num_gpus=2) @@ -109,23 +115,28 @@ def test_distributed( example_prompts, ): kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2} - check_implementation(hf_runner, - vllm_runner, - example_prompts, - "meta-llama/Llama-3.2-1B-Instruct", - kwargs_test=kwargs) + check_implementation( + hf_runner, + vllm_runner, + example_prompts, + "meta-llama/Llama-3.2-1B-Instruct", + kwargs_test=kwargs, + ) -@pytest.mark.parametrize("model, quantization_kwargs", [ - ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}), - ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}), - ( - "meta-llama/Llama-3.2-1B-Instruct", - { - "quantization": "bitsandbytes", - }, - ), -]) +@pytest.mark.parametrize( + "model, quantization_kwargs", + [ + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}), + ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}), + ( + "meta-llama/Llama-3.2-1B-Instruct", + { + "quantization": "bitsandbytes", + }, + ), + ], +) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) def test_quantization( @@ -136,27 +147,34 @@ def test_quantization( max_tokens: int, num_logprobs: int, ) -> None: - if (current_platform.is_rocm() - and quantization_kwargs.get("quantization", "") == "bitsandbytes"): - pytest.skip( - "bitsandbytes quantization is currently not supported in rocm.") + if ( + current_platform.is_rocm() + and quantization_kwargs.get("quantization", "") == "bitsandbytes" + ): + pytest.skip("bitsandbytes quantization is currently not supported in rocm.") with vllm_runner( - model, model_impl="auto", enforce_eager=True, - **quantization_kwargs) as vllm_model: # type: ignore[arg-type] + model, + model_impl="auto", + enforce_eager=True, + **quantization_kwargs, # type: ignore[arg-type] + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs) + example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs + ) with vllm_runner( - model, - model_impl="transformers", - enforce_eager=True, - **quantization_kwargs) as vllm_model: # type: ignore[arg-type] + model, + model_impl="transformers", + enforce_eager=True, + **quantization_kwargs, # type: ignore[arg-type] + ) as vllm_model: model_config = vllm_model.llm.llm_engine.model_config assert model_config.using_transformers_backend() transformers_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs) + example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs + ) check_logprobs_close( outputs_0_lst=transformers_outputs, @@ -172,22 +190,24 @@ def test_quantization( # Layers live in `layers` "Qwen/Qwen3-Embedding-0.6B", # Layers live in `model.layers` - 
"meta-llama/Llama-3.2-1B-Instruct" + "meta-llama/Llama-3.2-1B-Instruct", ], ) def test_embed_loading(vllm_runner, model): - with vllm_runner(model, - max_model_len=1024, - enforce_eager=True, - runner="pooling", - model_impl="transformers") as model_test: + with vllm_runner( + model, + max_model_len=1024, + enforce_eager=True, + runner="pooling", + model_impl="transformers", + ) as model_test: model_config = model_test.llm.llm_engine.model_config assert model_config.using_transformers_backend() @pytest.mark.parametrize( - "arch", - ["TransformersEmbeddingModel", "TransformersForSequenceClassification"]) + "arch", ["TransformersEmbeddingModel", "TransformersForSequenceClassification"] +) def test_pooling(hf_runner, vllm_runner, example_prompts, arch): model = get_model(arch) @@ -202,6 +222,7 @@ def test_pooling(hf_runner, vllm_runner, example_prompts, arch): hf_kwargs["is_sentence_transformer"] = True elif arch == "TransformersForSequenceClassification": from transformers import AutoModelForSequenceClassification + hf_kwargs["auto_cls"] = AutoModelForSequenceClassification # The example_prompts has ending "\n", for example: @@ -212,8 +233,10 @@ def test_pooling(hf_runner, vllm_runner, example_prompts, arch): # So we need to strip the input texts to avoid test failing. example_prompts = [str(s).strip() for s in example_prompts] - with (vllm_runner(model, **vllm_kwargs) as - vllm_model, hf_runner(model, **hf_kwargs) as hf_model): + with ( + vllm_runner(model, **vllm_kwargs) as vllm_model, + hf_runner(model, **hf_kwargs) as hf_model, + ): model_config = vllm_model.llm.llm_engine.model_config assert model_config.using_transformers_backend() diff --git a/tests/models/test_utils.py b/tests/models/test_utils.py index 9b87b1a9d46c..7cc4ee3c1856 100644 --- a/tests/models/test_utils.py +++ b/tests/models/test_utils.py @@ -10,7 +10,6 @@ class ModuleWithBatchNorm(torch.nn.Module): - def __init__(self): super().__init__() self.bn = torch.nn.BatchNorm1d(2) @@ -20,7 +19,6 @@ def forward(self, x): class ModuleWithNestedBatchNorm(torch.nn.Module): - def __init__(self): super().__init__() self.nested_mod = ModuleWithBatchNorm() @@ -67,9 +65,11 @@ def weight_generator(): new_mod = ModuleWithNestedBatchNorm() assert not torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) assert not torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var + ) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 loader = AutoWeightsLoader(new_mod) @@ -77,9 +77,9 @@ def weight_generator(): # Ensure the stats are updated assert torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) - assert torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) + assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 @@ -101,9 +101,11 @@ def weight_generator(): new_mod = ModuleWithNestedBatchNorm() assert not torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) assert not torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var + ) 
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."]) @@ -111,9 +113,9 @@ def weight_generator(): # Ensure the stats are updated assert torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) - assert torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) + assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 @@ -137,9 +139,11 @@ def weight_generator(): new_mod = ModuleWithNestedBatchNorm() assert not torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) assert not torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var + ) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0 loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."]) @@ -147,7 +151,7 @@ def weight_generator(): # Ensure the stats are updated assert torch.all( - new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean) - assert torch.all( - new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) + new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean + ) + assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var) assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1 diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 60ea2447e984..b323bca79f4e 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -8,11 +8,16 @@ from tests.utils import multi_gpu_test from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.models.vision import ( - get_load_balance_assignment, resolve_visual_encoder_outputs, - run_dp_sharded_mrope_vision_model, run_dp_sharded_vision_model) + get_load_balance_assignment, + resolve_visual_encoder_outputs, + run_dp_sharded_mrope_vision_model, + run_dp_sharded_vision_model, +) from vllm.platforms import current_platform from vllm.utils import get_open_port, update_environment_variables @@ -20,8 +25,7 @@ @pytest.mark.parametrize( - ("select_layers", "num_layers_loaded", "max_possible_layers", - "expected_features"), + ("select_layers", "num_layers_loaded", "max_possible_layers", "expected_features"), [ # All layers loaded ([1, 10], 10, 10, [1, 10]), @@ -29,16 +33,15 @@ # Some layers not loaded ([1, 10], 10, 20, [1, 10]), ([-20, -11], 10, 20, [1, 10]), - ]) -def test_resolve_visual_encoder_outputs(select_layers, num_layers_loaded, - max_possible_layers, - expected_features): + ], +) +def test_resolve_visual_encoder_outputs( + select_layers, num_layers_loaded, max_possible_layers, expected_features +): """ Test that offsets are correctly handled for vision feature layers. 
""" - encoder_outputs = [ - torch.tensor([idx]) for idx in range(num_layers_loaded + 1) - ] + encoder_outputs = [torch.tensor([idx]) for idx in range(num_layers_loaded + 1)] output_tensor = resolve_visual_encoder_outputs( encoder_outputs=encoder_outputs, post_layer_norm=None, @@ -85,10 +88,11 @@ def test_run_dp_sharded_vision_model(batch_size: int): ) -def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, - batch_size: int, master_port: int): +def run_dp_sharded_vision_model_vs_direct( + local_rank: int, world_size: int, batch_size: int, master_port: int +): """ - Test that run_dp_sharded_vision_model produces the same results as + Test that run_dp_sharded_vision_model produces the same results as calling the model directly. """ @@ -99,13 +103,15 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, current_platform.set_device(device) torch.set_default_device(device) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + } + ) # initialize distributed init_distributed_environment() @@ -141,28 +147,45 @@ def run_dp_sharded_vision_model_vs_direct(local_rank: int, world_size: int, [ # Empty input ([], 2, [], [0, 0], [0, 0], "empty input"), - # Fewer samples than GPUs - ([100, 200], 4, [1, 0], [1, 1, 0, 0], [200, 100, 0, 0 - ], "fewer samples than GPUs"), - + ( + [100, 200], + 4, + [1, 0], + [1, 1, 0, 0], + [200, 100, 0, 0], + "fewer samples than GPUs", + ), # Single GPU ([100, 200, 300], 1, [2, 1, 0], [3], [600], "single GPU"), - # Balanced assignment - ([100, 100, 100, 100 - ], 2, [0, 2, 1, 3], [2, 2], [200, 200], "balanced assignment"), - + ( + [100, 100, 100, 100], + 2, + [0, 2, 1, 3], + [2, 2], + [200, 200], + "balanced assignment", + ), # Unbalanced sizes - this one is trickier since the algorithm is greedy - ([1000, 100, 200, 50], 2, [0, 2, 1, 3 - ], [1, 3], [1000, 350], "unbalanced sizes"), + ( + [1000, 100, 200, 50], + 2, + [0, 2, 1, 3], + [1, 3], + [1000, 350], + "unbalanced sizes", + ), ], ) -def test_get_load_balance_assignment_cases(sizes, num_gpus, - expected_shuffle_indices, - expected_gpu_sample_counts, - expected_grouped_sizes_per_gpu, - test_description): +def test_get_load_balance_assignment_cases( + sizes, + num_gpus, + expected_shuffle_indices, + expected_gpu_sample_counts, + expected_grouped_sizes_per_gpu, + test_description, +): """Test get_load_balance_assignment with various input cases.""" result = get_load_balance_assignment(sizes, num_gpus=num_gpus) (shuffle_indices, gpu_sample_counts, grouped_sizes_per_gpu) = result @@ -188,8 +211,7 @@ def __init__(self, spatial_merge_size: int = 2, out_hidden_size: int = 64): self.out_hidden_size = out_hidden_size self.linear = torch.nn.Linear(768, out_hidden_size) - def forward(self, pixel_values: torch.Tensor, - grid_thw_list: list[list[int]]): + def forward(self, pixel_values: torch.Tensor, grid_thw_list: list[list[int]]): """Simple forward pass that simulates spatial merging.""" # Apply linear transformation embeddings = self.linear(pixel_values) @@ -212,8 +234,9 @@ def forward(self, pixel_values: torch.Tensor, merged_patches = num_patches // merge_factor if merged_patches > 0: # Reshape and average to simulate merging - reshaped = 
image_patches[:merged_patches * merge_factor].view( - merged_patches, merge_factor, -1) + reshaped = image_patches[: merged_patches * merge_factor].view( + merged_patches, merge_factor, -1 + ) merged = reshaped.mean(dim=1) merged_embeddings.append(merged) @@ -222,9 +245,11 @@ def forward(self, pixel_values: torch.Tensor, if merged_embeddings: return torch.cat(merged_embeddings, dim=0) else: - return torch.empty((0, self.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype) + return torch.empty( + (0, self.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) @multi_gpu_test(num_gpus=2) @@ -250,12 +275,11 @@ def test_run_dp_sharded_mrope_vision_model(batch_size: int): ) -def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, - world_size: int, - batch_size: int, - master_port: int): +def run_dp_sharded_mrope_vision_model_vs_direct( + local_rank: int, world_size: int, batch_size: int, master_port: int +): """ - Test that run_dp_sharded_mrope_vision_model produces the same results as + Test that run_dp_sharded_mrope_vision_model produces the same results as calling the model directly. """ # Set random seed for reproducibility @@ -264,13 +288,15 @@ def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, current_platform.set_device(device) torch.set_default_device(device) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + } + ) # initialize distributed init_distributed_environment() @@ -303,10 +329,9 @@ def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, # Run the model through the sharded function with torch.inference_mode(): - sharded_output = run_dp_sharded_mrope_vision_model(vision_model, - pixel_values, - grid_thw_list, - rope_type="rope_3d") + sharded_output = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list, rope_type="rope_3d" + ) sharded_output = torch.cat(sharded_output, dim=0) # Check that the world size is set up correctly @@ -317,10 +342,7 @@ def run_dp_sharded_mrope_vision_model_vs_direct(local_rank: int, # Check that the outputs have the same shape assert direct_output.shape == sharded_output.shape # Check that the outputs are close (they should be identical) - assert torch.allclose(direct_output, - sharded_output, - rtol=1e-5, - atol=1e-5) + assert torch.allclose(direct_output, sharded_output, rtol=1e-5, atol=1e-5) @multi_gpu_test(num_gpus=2) @@ -334,20 +356,23 @@ def test_run_dp_sharded_mrope_vision_model_empty_input(): def run_dp_sharded_mrope_vision_model_empty_input_worker( - local_rank: int, world_size: int, master_port: int): + local_rank: int, world_size: int, master_port: int +): """Test run_dp_sharded_mrope_vision_model with empty input.""" # Set up distributed environment device = f"{current_platform.device_name}:{local_rank}" current_platform.set_device(device) torch.set_default_device(device) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + 
"MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) @@ -360,10 +385,9 @@ def run_dp_sharded_mrope_vision_model_empty_input_worker( # Should handle empty input gracefully with torch.inference_mode(): - output = run_dp_sharded_mrope_vision_model(vision_model, - pixel_values, - grid_thw_list, - rope_type="rope_3d") + output = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list, rope_type="rope_3d" + ) assert len(output) == 0 @@ -379,7 +403,8 @@ def test_run_dp_sharded_mrope_vision_model_uneven_load(): def run_dp_sharded_mrope_vision_model_uneven_load_worker( - local_rank: int, world_size: int, master_port: int): + local_rank: int, world_size: int, master_port: int +): """Test run_dp_sharded_mrope_vision_model with uneven load distribution.""" # Set up distributed environment current_platform.seed_everything(123) @@ -387,13 +412,15 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker( current_platform.set_device(device) torch.set_default_device(device) - update_environment_variables({ - 'RANK': str(local_rank), - 'LOCAL_RANK': str(local_rank), - 'WORLD_SIZE': str(world_size), - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': str(master_port), - }) + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": str(master_port), + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=world_size) @@ -401,7 +428,7 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker( # Create images with very different sizes grid_thw_list = [ [1, 2, 2], # Small: 4 patches - [1, 8, 8], # Large: 64 patches + [1, 8, 8], # Large: 64 patches [1, 3, 3], # Medium: 9 patches ] @@ -416,15 +443,15 @@ def run_dp_sharded_mrope_vision_model_uneven_load_worker( # Should handle uneven distribution without errors with torch.inference_mode(): - output_tuple = run_dp_sharded_mrope_vision_model(vision_model, - pixel_values, - grid_thw_list, - rope_type="rope_3d") + output_tuple = run_dp_sharded_mrope_vision_model( + vision_model, pixel_values, grid_thw_list, rope_type="rope_3d" + ) # Verify output shape is reasonable merge_factor = vision_model.spatial_merge_size**2 expected_output_patches = list( - math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list) + math.prod(grid_thw) // merge_factor for grid_thw in grid_thw_list + ) for i, output in enumerate(output_tuple): assert output.shape[0] == expected_output_patches[i] @@ -445,8 +472,9 @@ def test_simple_mrope_vision_model_spatial_merge(spatial_merge_size: int): pixel_values_list.append(image_pixels) pixel_values = torch.cat(pixel_values_list, dim=0) - vision_model = SimpleMRopeVisionModel( - spatial_merge_size=spatial_merge_size).to(device) + vision_model = SimpleMRopeVisionModel(spatial_merge_size=spatial_merge_size).to( + device + ) with torch.inference_mode(): output = vision_model(pixel_values, grid_thw_list) diff --git a/tests/models/utils.py b/tests/models/utils.py index 50936114865a..c20e50ff1bff 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -33,16 +33,18 @@ def check_outputs_equal( """ assert len(outputs_0_lst) == len(outputs_1_lst) - for prompt_idx, (outputs_0, - outputs_1) in enumerate(zip(outputs_0_lst, - outputs_1_lst)): + for prompt_idx, (outputs_0, outputs_1) in enumerate( + zip(outputs_0_lst, outputs_1_lst) + ): output_ids_0, output_str_0 = 
outputs_0 output_ids_1, output_str_1 = outputs_1 # The text and token outputs should exactly match - fail_msg = (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + fail_msg = ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}" + ) assert output_str_0 == output_str_1, fail_msg assert output_ids_0 == output_ids_1, fail_msg @@ -54,9 +56,9 @@ def check_outputs_equal( # * List of top sample logprobs for each sampled token # # Assumes prompt logprobs were not requested. -TokensTextLogprobs = tuple[list[int], str, Optional[Union[list[dict[int, - float]], - SampleLogprobs]]] +TokensTextLogprobs = tuple[ + list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]] +] # Allow for tokens to be represented as str's rather than IDs; # tuple of @@ -65,9 +67,9 @@ def check_outputs_equal( # * Optional list of top sample logprobs for each sampled token # # Assumes prompt logprobs were not requested. -TextTextLogprobs = tuple[list[str], str, Optional[Union[list[dict[str, float]], - list[dict[str, - Logprob]]]]] +TextTextLogprobs = tuple[ + list[str], str, Optional[Union[list[dict[str, float]], list[dict[str, Logprob]]]] +] # Representation of generated sequence as a tuple of # * Token ID list @@ -77,18 +79,21 @@ def check_outputs_equal( # # Allows prompt logprobs to be requested. TokensTextLogprobsPromptLogprobs = tuple[ - list[int], str, Optional[Union[list[dict[int, float]], SampleLogprobs]], - Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]]] + list[int], + str, + Optional[Union[list[dict[int, float]], SampleLogprobs]], + Optional[Union[list[Optional[dict[int, float]]], PromptLogprobs]], +] def check_logprobs_close( *, - outputs_0_lst: Sequence[Union[TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs, - TextTextLogprobs]], - outputs_1_lst: Sequence[Union[TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs, - TextTextLogprobs]], + outputs_0_lst: Sequence[ + Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs] + ], + outputs_1_lst: Sequence[ + Union[TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, TextTextLogprobs] + ], name_0: str, name_1: str, num_outputs_0_skip_tokens: int = 0, @@ -128,9 +133,9 @@ def check_logprobs_close( assert len(outputs_0_lst) == len(outputs_1_lst) # Loop through responses to each prompt. 
- for prompt_idx, (outputs_0, - outputs_1) in enumerate(zip(outputs_0_lst, - outputs_1_lst)): + for prompt_idx, (outputs_0, outputs_1) in enumerate( + zip(outputs_0_lst, outputs_1_lst) + ): assert len(outputs_0) == len(outputs_1) if len(outputs_0) == 3: assert len(outputs_1) == 3 @@ -155,17 +160,18 @@ def check_logprobs_close( ) = outputs_1 # Test prompt logprobs closeness - if (prompt_logprobs_0 is not None - and prompt_logprobs_1 is not None): + if prompt_logprobs_0 is not None and prompt_logprobs_1 is not None: # Both sequences' prompt logprobs lists are not `None`` # (although individual list elements may be `None`); # for each token's logprobs: for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate( - zip(prompt_logprobs_0, prompt_logprobs_1)): + zip(prompt_logprobs_0, prompt_logprobs_1) + ): fail_msg = ( f"Prompt logprobs test:" f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}" - f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}") + f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}" + ) if logprobs_elem_0 is None: # If the seq 0 token's logprobs are `None`, @@ -176,20 +182,24 @@ def check_logprobs_close( # the seq 1 token's logprobs must not be `None` assert logprobs_elem_1 is not None, fail_msg # Logprobs check: top-k token choices must be the same - assert (set(logprobs_elem_0.keys()) == set( - logprobs_elem_1.keys())), fail_msg + assert set(logprobs_elem_0.keys()) == set( + logprobs_elem_1.keys() + ), fail_msg else: # Both sequence logprobs lists must be `None` - fail_msg = (f"Prompt logprobs test:" - f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}" - f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}") + fail_msg = ( + f"Prompt logprobs test:" + f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}" + f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}" + ) - assert (prompt_logprobs_0 is None - and prompt_logprobs_1 is None), fail_msg + assert prompt_logprobs_0 is None and prompt_logprobs_1 is None, fail_msg else: - raise ValueError(f"Outputs tuple must have 3 or 4 elements but " - f"{len(outputs_0)} elements were provided: " - f"{outputs_0}") + raise ValueError( + f"Outputs tuple must have 3 or 4 elements but " + f"{len(outputs_0)} elements were provided: " + f"{outputs_0}" + ) if logprobs_0 is None: logprobs_0 = [None] * len(output_ids_0) @@ -206,9 +216,9 @@ def check_logprobs_close( logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:] # Loop through generated tokens. 
- for idx, (output_id_0, - output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): - + for idx, (output_id_0, output_id_1) in enumerate( + zip(output_ids_0, output_ids_1) + ): is_tok_mismatch = output_id_0 != output_id_1 # If generated tokens don't match @@ -223,7 +233,8 @@ def check_logprobs_close( f"Test{prompt_idx}:" f"\nMatched tokens:\t{output_ids_0[:idx]}" f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}" - f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}") + f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}" + ) assert logprobs_elem_0 is not None, fail_msg assert logprobs_elem_1 is not None, fail_msg @@ -244,9 +255,11 @@ def check_logprobs_close( if output_str_0 != output_str_1 and warn_on_mismatch: # The token outputs exactly match, # so the text outputs should exactly match as well - fail_msg = (f"Test{prompt_idx}:" - f"\n{name_0}:\t{output_str_0!r}" - f"\n{name_1}:\t{output_str_1!r}") + fail_msg = ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}" + ) with warnings.catch_warnings(): # This ensures that repeated warnings are shown @@ -317,18 +330,22 @@ def check_embeddings_close( assert len(embeddings_0_lst) == len(embeddings_1_lst) for prompt_idx, (embeddings_0, embeddings_1) in enumerate( - zip(embeddings_0_lst, embeddings_1_lst)): + zip(embeddings_0_lst, embeddings_1_lst) + ): assert len(embeddings_0) == len(embeddings_1), ( - f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}") + f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}" + ) - sim = F.cosine_similarity(torch.tensor(embeddings_0), - torch.tensor(embeddings_1), - dim=0) + sim = F.cosine_similarity( + torch.tensor(embeddings_0), torch.tensor(embeddings_1), dim=0 + ) - fail_msg = (f"Test{prompt_idx}:" - f"\nCosine similarity: \t{sim:.4f}" - f"\n{name_0}:\t{embeddings_0[:16]!r}" - f"\n{name_1}:\t{embeddings_1[:16]!r}") + fail_msg = ( + f"Test{prompt_idx}:" + f"\nCosine similarity: \t{sim:.4f}" + f"\n{name_0}:\t{embeddings_0[:16]!r}" + f"\n{name_1}:\t{embeddings_1[:16]!r}" + ) assert sim >= 1 - tol, fail_msg @@ -413,20 +430,19 @@ def dummy_hf_overrides( # Ensure at least 2 expert per group # Since `grouped_topk` assumes top-2 - n_group = getattr(text_config, 'n_group', None) + n_group = getattr(text_config, "n_group", None) num_experts = n_group * 2 if n_group is not None else 2 # we use three layers for Gemma-3n to check # both normal layer and kv_shared_layer if use_original_num_layers: # Use the original number of layers from the config - num_layers = getattr(text_config, 'num_layers', 1) - num_hidden_layers = getattr(text_config, 'num_hidden_layers', 1) + num_layers = getattr(text_config, "num_layers", 1) + num_hidden_layers = getattr(text_config, "num_hidden_layers", 1) else: # Use minimal layers for testing num_layers = 1 - num_hidden_layers = (3 if model_arch - == "Gemma3nForConditionalGeneration" else 1) + num_hidden_layers = 3 if model_arch == "Gemma3nForConditionalGeneration" else 1 update_dict = { "num_layers": num_layers, @@ -440,53 +456,63 @@ class DummyConfig: # Only set MoE related config when the model has MoE layers. # Otherwise all models detected as MoE by _get_transformers_backend_cls. 
if ModelConfig.get_num_experts(DummyConfig) > 0: - update_dict.update({ - "num_experts": num_experts, - "num_experts_per_tok": 2, - "num_local_experts": num_experts, - # Otherwise there will not be any expert layers - "first_k_dense_replace": 0, - # To avoid OOM on DeepSeek-V3 - "n_routed_experts": num_experts, - }) + update_dict.update( + { + "num_experts": num_experts, + "num_experts_per_tok": 2, + "num_local_experts": num_experts, + # Otherwise there will not be any expert layers + "first_k_dense_replace": 0, + # To avoid OOM on DeepSeek-V3 + "n_routed_experts": num_experts, + } + ) # Update num_hidden_layers for non-Longcat architectures - if model_arch != "LongcatFlashForCausalLM" \ - and model_arch != "LongCatFlashMTPModel": + if model_arch != "LongcatFlashForCausalLM" and model_arch != "LongCatFlashMTPModel": update_dict["num_hidden_layers"] = num_hidden_layers text_config.update(update_dict) if hasattr(hf_config, "vision_config"): - hf_config.vision_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) + hf_config.vision_config.update( + { + "num_layers": 1, + "num_hidden_layers": 1, + } + ) # e.g.: ibm-granite/granite-speech-3.3-2b if hasattr(hf_config, "encoder_config"): - hf_config.encoder_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - }) + hf_config.encoder_config.update( + { + "num_layers": 1, + "num_hidden_layers": 1, + } + ) # e.g.: Qwen/Qwen2-Audio-7B-Instruct if hasattr(hf_config, "audio_config"): - hf_config.audio_config.update({ - "num_layers": 1, - "num_hidden_layers": 1, - "encoder_layers": 1, - }) + hf_config.audio_config.update( + { + "num_layers": 1, + "num_hidden_layers": 1, + "encoder_layers": 1, + } + ) return hf_config -def check_transformers_version(model: str, - min_transformers_version: Optional[str] = None, - max_transformers_version: Optional[str] = None): +def check_transformers_version( + model: str, + min_transformers_version: Optional[str] = None, + max_transformers_version: Optional[str] = None, +): from .registry import _HfExamplesInfo - return _HfExamplesInfo(model, - min_transformers_version=min_transformers_version, - max_transformers_version=max_transformers_version - ).check_transformers_version(on_fail="skip") + return _HfExamplesInfo( + model, + min_transformers_version=min_transformers_version, + max_transformers_version=max_transformers_version, + ).check_transformers_version(on_fail="skip") diff --git a/tests/multimodal/test_audio.py b/tests/multimodal/test_audio.py index ba39af845041..189b319e5fcd 100644 --- a/tests/multimodal/test_audio.py +++ b/tests/multimodal/test_audio.py @@ -8,9 +8,12 @@ import numpy as np import pytest -from vllm.multimodal.audio import (AudioMediaIO, AudioResampler, - resample_audio_librosa, - resample_audio_scipy) +from vllm.multimodal.audio import ( + AudioMediaIO, + AudioResampler, + resample_audio_librosa, + resample_audio_scipy, +) @pytest.fixture @@ -21,12 +24,10 @@ def dummy_audio(): def test_resample_audio_librosa(dummy_audio): with patch("vllm.multimodal.audio.librosa.resample") as mock_resample: mock_resample.return_value = dummy_audio * 2 - out = resample_audio_librosa(dummy_audio, - orig_sr=44100, - target_sr=22050) - mock_resample.assert_called_once_with(dummy_audio, - orig_sr=44100, - target_sr=22050) + out = resample_audio_librosa(dummy_audio, orig_sr=44100, target_sr=22050) + mock_resample.assert_called_once_with( + dummy_audio, orig_sr=44100, target_sr=22050 + ) assert np.all(out == dummy_audio * 2) @@ -40,8 +41,7 @@ def test_resample_audio_scipy(dummy_audio): assert 
np.all(out_same == dummy_audio) -@pytest.mark.xfail( - reason="resample_audio_scipy is buggy for non-integer ratios") +@pytest.mark.xfail(reason="resample_audio_scipy is buggy for non-integer ratios") def test_resample_audio_scipy_non_integer_ratio(dummy_audio): out = resample_audio_scipy(dummy_audio, orig_sr=5, target_sr=3) @@ -54,13 +54,12 @@ def test_resample_audio_scipy_non_integer_ratio(dummy_audio): def test_audio_resampler_librosa_calls_resample(dummy_audio): resampler = AudioResampler(target_sr=22050, method="librosa") - with patch( - "vllm.multimodal.audio.resample_audio_librosa") as mock_resample: + with patch("vllm.multimodal.audio.resample_audio_librosa") as mock_resample: mock_resample.return_value = dummy_audio out = resampler.resample(dummy_audio, orig_sr=44100) - mock_resample.assert_called_once_with(dummy_audio, - orig_sr=44100, - target_sr=22050) + mock_resample.assert_called_once_with( + dummy_audio, orig_sr=44100, target_sr=22050 + ) assert np.all(out == dummy_audio) @@ -69,9 +68,9 @@ def test_audio_resampler_scipy_calls_resample(dummy_audio): with patch("vllm.multimodal.audio.resample_audio_scipy") as mock_resample: mock_resample.return_value = dummy_audio out = resampler.resample(dummy_audio, orig_sr=44100) - mock_resample.assert_called_once_with(dummy_audio, - orig_sr=44100, - target_sr=22050) + mock_resample.assert_called_once_with( + dummy_audio, orig_sr=44100, target_sr=22050 + ) assert np.all(out == dummy_audio) diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 48e88e7c0175..fe983990b90c 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -8,15 +8,20 @@ from vllm.config import ModelConfig, ParallelConfig, VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.cache import (MultiModalCache, - MultiModalProcessorCacheItem, - MultiModalProcessorCacheItemMetadata, - engine_receiver_cache_from_config, - processor_cache_from_config) +from vllm.multimodal.cache import ( + MultiModalCache, + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + engine_receiver_cache_from_config, + processor_cache_from_config, +) from vllm.multimodal.hasher import MultiModalHasher -from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, - MultiModalKwargsItems, - MultiModalSharedField) +from vllm.multimodal.inputs import ( + MultiModalFieldElem, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalSharedField, +) from vllm.multimodal.processing import PromptInsertion pytestmark = pytest.mark.cpu_test @@ -30,9 +35,9 @@ def _dummy_elem( rng: Optional[np.random.RandomState] = None, ): if rng is None: - data = torch.empty((size, ), dtype=torch.int8) + data = torch.empty((size,), dtype=torch.int8) else: - data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8)) + data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8)) return MultiModalFieldElem( modality=modality, @@ -48,10 +53,9 @@ def _dummy_item( *, rng: Optional[np.random.RandomState] = None, ): - return MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size, rng=rng) - for key, size in size_by_key.items() - ]) + return MultiModalKwargsItem.from_elems( + [_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()] + ) def _dummy_items( @@ -59,31 +63,35 @@ def _dummy_items( *, rng: Optional[np.random.RandomState] = None, ): - return MultiModalKwargsItems.from_seq([ - _dummy_item(modality, size_by_key, rng=rng) - for modality, size_by_key in 
size_by_key_modality.items() - ]) + return MultiModalKwargsItems.from_seq( + [ + _dummy_item(modality, size_by_key, rng=rng) + for modality, size_by_key in size_by_key_modality.items() + ] + ) -# yapf: disable @pytest.mark.parametrize( ("item", "expected_size"), [ (_dummy_item("a", {"a1": 100}), 100), (_dummy_item("a", {"a1": 100, "a2": 110}), 210), (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501 - (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460), # noqa: E501 + ( + _dummy_items( + {"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}} + ).get_data(), + 460, + ), # noqa: E501 ], ) -# yapf: enable def test_cache_item_size(item, expected_size): cache = MultiModalCache.get_lru_cache(2048, type(item)) cache[""] = item assert cache.currsize == expected_size - prompt_update = PromptInsertion("dummy", "target", "insertion") \ - .resolve(0) + prompt_update = PromptInsertion("dummy", "target", "insertion").resolve(0) cache[""] = MultiModalProcessorCacheItem(item, [prompt_update]) assert cache.currsize == expected_size @@ -100,9 +108,9 @@ def _create_vllm_config( return VllmConfig( model_config=ModelConfig( model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", - mm_processor_cache_gb=mm_processor_cache_gb), - parallel_config=ParallelConfig( - data_parallel_size=1 if enable_ipc else 2), + mm_processor_cache_gb=mm_processor_cache_gb, + ), + parallel_config=ParallelConfig(data_parallel_size=1 if enable_ipc else 2), ) @@ -118,11 +126,9 @@ def _compare_caches( seed: int = 0, ): cache_0_p0 = processor_cache_from_config(config_0, MULTIMODAL_REGISTRY) - cache_0_p1 = engine_receiver_cache_from_config(config_0, - MULTIMODAL_REGISTRY) + cache_0_p1 = engine_receiver_cache_from_config(config_0, MULTIMODAL_REGISTRY) cache_1_p0 = processor_cache_from_config(config_1, MULTIMODAL_REGISTRY) - cache_1_p1 = engine_receiver_cache_from_config(config_1, - MULTIMODAL_REGISTRY) + cache_1_p1 = engine_receiver_cache_from_config(config_1, MULTIMODAL_REGISTRY) cache_size_gb = max( config_0.model_config.multimodal_config.mm_processor_cache_gb, @@ -136,8 +142,7 @@ def _compare_caches( for _ in range(int(item_capacity / hit_rate)) ] all_hashes = [ - MultiModalHasher.hash_kwargs(item=item.get_data()) - for item in all_items + MultiModalHasher.hash_kwargs(item=item.get_data()) for item in all_items ] # Should not be used since there is nothing to convert to text @@ -156,7 +161,8 @@ def _compare_caches( for _ in range(is_cached_calls_per_iter): cache_0_p0.is_cached(selected_hashes) cache_0_p0_out = [ - item for item, _ in cache_0_p0.get_and_update( + item + for item, _ in cache_0_p0.get_and_update( [(item, prompt_update.content) for item in selected_items], selected_hashes, ) @@ -168,7 +174,8 @@ def _compare_caches( for _ in range(is_cached_calls_per_iter): cache_1_p0.is_cached(selected_hashes) cache_1_p0_out = [ - item for item, _ in cache_1_p0.get_and_update( + item + for item, _ in cache_1_p0.get_and_update( [(item, prompt_update.content) for item in selected_items], selected_hashes, ) @@ -177,14 +184,12 @@ def _compare_caches( if cache_0_p1 is None: cache_0_p1_out = cache_0_p0_out else: - cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, - selected_hashes) + cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, selected_hashes) if cache_1_p1 is None: cache_1_p1_out = cache_1_p0_out else: - cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, - selected_hashes) + cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, 
selected_hashes) assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}" diff --git a/tests/multimodal/test_hasher.py b/tests/multimodal/test_hasher.py index 46aba1b75f77..29064f273783 100644 --- a/tests/multimodal/test_hasher.py +++ b/tests/multimodal/test_hasher.py @@ -90,8 +90,6 @@ def test_hash_image_exif_id(): hasher = MultiModalHasher # first image has UUID in ImageID, so it should hash to that UUID - assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs( - image=id.bytes) + assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs(image=id.bytes) # second image has non-UUID in ImageID, so it should hash to the image data - assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs( - image=image2a) + assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs(image=image2a) diff --git a/tests/multimodal/test_image.py b/tests/multimodal/test_image.py index 2f21ad969e74..329a5b0494cb 100644 --- a/tests/multimodal/test_image.py +++ b/tests/multimodal/test_image.py @@ -43,8 +43,7 @@ def test_rgba_to_rgb(): def test_rgba_to_rgb_custom_background(tmp_path): """Test RGBA to RGB conversion with custom background colors.""" # Create a simple RGBA image with transparent and opaque pixels - rgba_image = Image.new("RGBA", (10, 10), - (255, 0, 0, 255)) # Red with full opacity + rgba_image = Image.new("RGBA", (10, 10), (255, 0, 0, 255)) # Red with full opacity # Make top-left quadrant transparent for i in range(5): @@ -94,7 +93,7 @@ def test_rgba_to_rgb_custom_background(tmp_path): assert blue_numpy[0][0][2] == 255 # B # Test 4: Test with load_bytes method - with open(test_image_path, 'rb') as f: + with open(test_image_path, "rb") as f: image_data = f.read() image_io_green = ImageMediaIO(rgba_background_color=(0, 255, 0)) @@ -111,39 +110,47 @@ def test_rgba_background_color_validation(): """Test that invalid rgba_background_color values are properly rejected.""" # Test invalid types - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color="255,255,255") - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=255) # Test wrong number of elements - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255, 255)) - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255, 255, 255, 255)) # Test non-integer values - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255.0, 255.0, 255.0)) - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255, "255", 255)) # Test out of range values - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + 
ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(256, 255, 255)) - with pytest.raises(ValueError, - match="rgba_background_color must be a list or tuple"): + with pytest.raises( + ValueError, match="rgba_background_color must be a list or tuple" + ): ImageMediaIO(rgba_background_color=(255, -1, 255)) # Test that valid values work diff --git a/tests/multimodal/test_inputs.py b/tests/multimodal/test_inputs.py index f35935d14ff2..88e92bee3a29 100644 --- a/tests/multimodal/test_inputs.py +++ b/tests/multimodal/test_inputs.py @@ -9,8 +9,7 @@ pytestmark = pytest.mark.cpu_test -def assert_nested_tensors_equal(expected: NestedTensors, - actual: NestedTensors): +def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors): assert type(expected) == type(actual) # noqa: E721 if isinstance(expected, torch.Tensor): assert torch.equal(expected, actual) @@ -19,8 +18,9 @@ def assert_nested_tensors_equal(expected: NestedTensors, assert_nested_tensors_equal(expected_item, actual_item) -def assert_multimodal_inputs_equal(expected: MultiModalKwargs, - actual: MultiModalKwargs): +def assert_multimodal_inputs_equal( + expected: MultiModalKwargs, actual: MultiModalKwargs +): assert set(expected.keys()) == set(actual.keys()) for key in expected: assert_nested_tensors_equal(expected[key], actual[key]) @@ -52,19 +52,10 @@ def test_multimodal_input_batch_nested_tensors(): a = torch.rand([2, 3]) b = torch.rand([2, 3]) c = torch.rand([2, 3]) - result = MultiModalKwargs.batch([{ - "image": [a] - }, { - "image": [b] - }, { - "image": [c] - }]) - assert_multimodal_inputs_equal(result, { - "image": - torch.stack([a.unsqueeze(0), - b.unsqueeze(0), - c.unsqueeze(0)]) - }) + result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b]}, {"image": [c]}]) + assert_multimodal_inputs_equal( + result, {"image": torch.stack([a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)])} + ) def test_multimodal_input_batch_heterogeneous_lists(): @@ -73,8 +64,8 @@ def test_multimodal_input_batch_heterogeneous_lists(): c = torch.rand([1, 2, 3]) result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}]) assert_multimodal_inputs_equal( - result, - {"image": [torch.stack([a, b]), c.unsqueeze(0)]}) + result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]} + ) def test_multimodal_input_batch_multiple_batchable_lists(): @@ -84,9 +75,8 @@ def test_multimodal_input_batch_multiple_batchable_lists(): d = torch.rand([1, 2, 3]) result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}]) assert_multimodal_inputs_equal( - result, - {"image": torch.stack([torch.stack([a, b]), - torch.stack([c, d])])}) + result, {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])} + ) def test_multimodal_input_batch_mixed_stacking_depths(): diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 7aa51acff350..a542b068a42b 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -9,17 +9,18 @@ from vllm.config import ModelConfig from vllm.multimodal import MULTIMODAL_REGISTRY -# yapf conflicts with isort for this block -# yapf: disable -from vllm.multimodal.processing import (InputProcessingContext, - PlaceholderFeaturesInfo, - PromptIndexTargets, PromptInsertion, - PromptReplacement, apply_text_matches, - apply_token_matches, - find_mm_placeholders, - iter_token_matches, - replace_token_matches) -# yapf: enable +from vllm.multimodal.processing import ( + InputProcessingContext, + 
PlaceholderFeaturesInfo, + PromptIndexTargets, + PromptInsertion, + PromptReplacement, + apply_text_matches, + apply_token_matches, + find_mm_placeholders, + iter_token_matches, + replace_token_matches, +) from vllm.multimodal.profiling import MultiModalProfiler from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -28,7 +29,6 @@ pytestmark = pytest.mark.cpu_test -# yapf: disable @pytest.mark.parametrize( ("token_ids", "match_ids", "expected"), [ @@ -38,34 +38,34 @@ [32000, 32000, 32000], [32000], [ - { "start_idx": 0, "end_idx": 1 }, - { "start_idx": 1, "end_idx": 2 }, - { "start_idx": 2, "end_idx": 3 }, + {"start_idx": 0, "end_idx": 1}, + {"start_idx": 1, "end_idx": 2}, + {"start_idx": 2, "end_idx": 3}, ], ), ( [32000, 32000, 32000], [32000, 32000], - [{ "start_idx": 0, "end_idx": 2 }], + [{"start_idx": 0, "end_idx": 2}], ), ( [32000, 32000, 32000], [32000, 32000, 32000], - [{ "start_idx": 0, "end_idx": 3 }], + [{"start_idx": 0, "end_idx": 3}], ), ( [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], [28747, 32000], [ - { "start_idx": 1, "end_idx": 3 }, - { "start_idx": 6, "end_idx": 8 }, + {"start_idx": 1, "end_idx": 3}, + {"start_idx": 6, "end_idx": 8}, ], ), ( [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], [28747, 32000, 32000, 32000], [ - { "start_idx": 1, "end_idx": 5 }, + {"start_idx": 1, "end_idx": 5}, ], ), ( @@ -76,14 +76,13 @@ ], ) @pytest.mark.parametrize("start_idx", [0, 4, 8]) -# yapf: enable def test_iter_token_matches(token_ids, match_ids, expected, start_idx): - result = list(iter_token_matches(token_ids, match_ids, - start_idx=start_idx)) + result = list(iter_token_matches(token_ids, match_ids, start_idx=start_idx)) # Manually constructed results - assert [item._asdict() for item in result - ] == [item for item in expected if item["start_idx"] >= start_idx] + assert [item._asdict() for item in result] == [ + item for item in expected if item["start_idx"] >= start_idx + ] # Invariants match_lens = [end - start for start, end in result] @@ -91,7 +90,6 @@ def test_iter_token_matches(token_ids, match_ids, expected, start_idx): assert all(match_len == len(match_ids) for match_len in match_lens) -# yapf: disable @pytest.mark.parametrize( ("token_ids", "match_ids", "new_ids", "expected"), [ @@ -135,7 +133,6 @@ def test_iter_token_matches(token_ids, match_ids, expected, start_idx): ), ], ) -# yapf: enable def test_replace_token_matches(token_ids, match_ids, new_ids, expected): result = replace_token_matches(token_ids, match_ids, new_ids) @@ -143,7 +140,6 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected): assert result == expected -# yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "expected_by_key"), [ @@ -160,11 +156,11 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected): "pattern_1": [], "pattern_2": [], "pattern_3": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_4": [], "pattern_5": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], }, ), @@ -180,26 +176,26 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected): }, { "pattern_1": [ - { "start_idx": 0, "end_idx": 1 }, - { "start_idx": 1, "end_idx": 2 }, - { "start_idx": 2, "end_idx": 3 }, - { "start_idx": 3, "end_idx": 4 }, + {"start_idx": 0, "end_idx": 1}, + {"start_idx": 1, "end_idx": 2}, + {"start_idx": 2, "end_idx": 3}, + {"start_idx": 3, "end_idx": 4}, ], "pattern_2": [ - { "start_idx": 0, "end_idx": 2 }, - { "start_idx": 2, 
"end_idx": 4 }, + {"start_idx": 0, "end_idx": 2}, + {"start_idx": 2, "end_idx": 4}, ], "pattern_3": [ - { "start_idx": 0, "end_idx": 3 }, + {"start_idx": 0, "end_idx": 3}, ], "pattern_4": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_5": [ - { "start_idx": 1, "end_idx": 1 }, + {"start_idx": 1, "end_idx": 1}, ], "pattern_6": [ - { "start_idx": 4, "end_idx": 4 }, + {"start_idx": 4, "end_idx": 4}, ], }, ), @@ -215,26 +211,25 @@ def test_replace_token_matches(token_ids, match_ids, new_ids, expected): }, { "pattern_1": [ - { "start_idx": 1, "end_idx": 3 }, - { "start_idx": 6, "end_idx": 8 }, + {"start_idx": 1, "end_idx": 3}, + {"start_idx": 6, "end_idx": 8}, ], "pattern_2": [ - { "start_idx": 1, "end_idx": 5 }, + {"start_idx": 1, "end_idx": 5}, ], "pattern_3": [], "pattern_4": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_5": [], "pattern_6": [ - { "start_idx": 10, "end_idx": 10 }, + {"start_idx": 10, "end_idx": 10}, ], }, ), ], ) @pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) -# yapf: enable def test_find_token_matches( prompt, target_by_key, @@ -266,7 +261,6 @@ def test_find_token_matches( } == expected_by_key -# yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "expected_by_key"), [ @@ -282,16 +276,16 @@ def test_find_token_matches( "pattern_5": PromptIndexTargets.end(), }, { - "pattern_1": [{ "start_idx": 0, "end_idx": 0 }], + "pattern_1": [{"start_idx": 0, "end_idx": 0}], "pattern_2": [], "pattern_3": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_4": [], "pattern_5": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], - } + }, ), ( "<image><image><image><image>", @@ -305,26 +299,26 @@ def test_find_token_matches( }, { "pattern_1": [ - { "start_idx": 0, "end_idx": 7 }, - { "start_idx": 7, "end_idx": 14 }, - { "start_idx": 14, "end_idx": 21 }, - { "start_idx": 21, "end_idx": 28 }, + {"start_idx": 0, "end_idx": 7}, + {"start_idx": 7, "end_idx": 14}, + {"start_idx": 14, "end_idx": 21}, + {"start_idx": 21, "end_idx": 28}, ], "pattern_2": [ - { "start_idx": 0, "end_idx": 14 }, - { "start_idx": 14, "end_idx": 28 }, + {"start_idx": 0, "end_idx": 14}, + {"start_idx": 14, "end_idx": 28}, ], "pattern_3": [ - { "start_idx": 0, "end_idx": 21 }, + {"start_idx": 0, "end_idx": 21}, ], "pattern_4": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_5": [ - { "start_idx": 7, "end_idx": 7 }, + {"start_idx": 7, "end_idx": 7}, ], "pattern_6": [ - { "start_idx": 28, "end_idx": 28 }, + {"start_idx": 28, "end_idx": 28}, ], }, ), @@ -340,21 +334,21 @@ def test_find_token_matches( }, { "pattern_1": [ - { "start_idx": 0, "end_idx": 13 }, - { "start_idx": 27, "end_idx": 40 }, + {"start_idx": 0, "end_idx": 13}, + {"start_idx": 27, "end_idx": 40}, ], "pattern_2": [ - { "start_idx": 0, "end_idx": 27 }, + {"start_idx": 0, "end_idx": 27}, ], "pattern_3": [], "pattern_4": [ - { "start_idx": 0, "end_idx": 0 }, + {"start_idx": 0, "end_idx": 0}, ], "pattern_5": [ - { "start_idx": 13, "end_idx": 13 }, + {"start_idx": 13, "end_idx": 13}, ], "pattern_6": [ - { "start_idx": 48, "end_idx": 48 }, + {"start_idx": 48, "end_idx": 48}, ], }, ), @@ -368,22 +362,21 @@ def test_find_token_matches( }, { "pattern_1": [ - { "start_idx": 0, "end_idx": 9 }, - { "start_idx": 16, "end_idx": 25 }, + {"start_idx": 0, "end_idx": 9}, + {"start_idx": 16, "end_idx": 25}, ], "pattern_2": [ - { "start_idx": 0, "end_idx": 16 }, - { 
"start_idx": 16, "end_idx": 32 }, + {"start_idx": 0, "end_idx": 16}, + {"start_idx": 16, "end_idx": 32}, ], "pattern_3": [ - { "start_idx": 0, "end_idx": 25 }, + {"start_idx": 0, "end_idx": 25}, ], }, ), ], ) @pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) -# yapf: enable def test_find_text_matches( prompt, target_by_key, @@ -415,7 +408,6 @@ def test_find_text_matches( } == expected_by_key -# yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ @@ -543,9 +535,8 @@ def test_find_text_matches( }, }, ), - ] + ], ) -# yapf: enable def test_find_update_text( prompt, target_by_key, @@ -556,13 +547,15 @@ def test_find_update_text( mock_tokenizer = cast(AnyTokenizer, object()) for ( - update_type, - expected_by_mm_count, + update_type, + expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): for mm_count, expected in expected_by_mm_count.items(): mm_prompt_updates = { - key: [[update_type(key, target, repl_by_key[key]).resolve(i)] - for i in range(mm_count)] + key: [ + [update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count) + ] for key, target in target_by_key.items() } @@ -583,7 +576,6 @@ def test_find_update_text( assert new_prompt == expected -# yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ @@ -609,8 +601,43 @@ def test_find_update_text( { PromptInsertion: { 0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], - 1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501 - 2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501 + 1: [ + 1, + 9833, + 28747, + 32000, + 32000, + 32000, + 9833, + 28747, + 32000, + 32000, + 918, + 1550, + 918, + 1550, + ], # noqa: E501 + 2: [ + 1, + 9833, + 28747, + 32000, + 32000, + 32000, + 32000, + 32000, + 9833, + 28747, + 32000, + 32000, + 918, + 1550, + 918, + 1550, + 1550, + 918, + 1550, + ], # noqa: E501 }, PromptReplacement: { 0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], @@ -713,9 +740,8 @@ def test_find_update_text( }, }, ), - ] + ], ) -# yapf: enable def test_find_update_tokens( prompt, target_by_key, @@ -726,13 +752,15 @@ def test_find_update_tokens( mock_tokenizer = cast(AnyTokenizer, object()) for ( - update_type, - expected_by_mm_count, + update_type, + expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): for mm_count, expected in expected_by_mm_count.items(): mm_prompt_updates = { - key: [[update_type(key, target, repl_by_key[key]).resolve(i)] - for i in range(mm_count)] + key: [ + [update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count) + ] for key, target in target_by_key.items() } @@ -753,7 +781,6 @@ def test_find_update_tokens( assert new_prompt == expected -# yapf: disable @pytest.mark.parametrize( "repl_by_key", [ @@ -790,8 +817,7 @@ def test_find_update_tokens( is_embed=None, ), ], - } - + }, ), ( [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], @@ -822,7 +848,7 @@ def test_find_update_tokens( ), ], # No match for pattern_4 as it has lower priority than pattern_1 - } + }, ), ( [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], @@ -861,12 +887,11 @@ def test_find_update_tokens( is_embed=None, ), ], - } + }, ), - ] + ], ) @pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement]) -# 
yapf: enable def test_find_mm_placeholders( repl_by_key, prompt, @@ -893,8 +918,15 @@ def test_find_mm_placeholders( @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("limit", "num_supported", "is_valid"), - [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), - (2, 1, False), (2, 2, True)], + [ + (0, 0, True), + (0, 1, True), + (1, 0, False), + (1, 1, True), + (1, 2, True), + (2, 1, False), + (2, 2, True), + ], ) def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): limit_mm_per_prompt = {"image": limit} @@ -909,10 +941,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): profiler = MultiModalProfiler(processor) - if is_valid: - exc_ctx = nullcontext() - else: - exc_ctx = pytest.raises(ValueError, match="At most") + exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") with exc_ctx: profiler.get_decoder_dummy_data( @@ -924,8 +953,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("num_images", "limit", "is_valid"), - [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True), - (2, 1, False), (2, 2, True)], + [ + (0, 0, True), + (0, 1, True), + (1, 0, False), + (1, 1, True), + (1, 2, True), + (2, 1, False), + (2, 2, True), + ], ) def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): limit_mm_per_prompt = {"image": limit} @@ -946,10 +982,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): else: mm_data = {"image": [image] * num_images} - if is_valid: - exc_ctx = nullcontext() - else: - exc_ctx = pytest.raises(ValueError, match="At most") + exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") with exc_ctx: processor.apply( @@ -960,7 +993,6 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): class DummyProcessor: - def __init__(self, a: int = 0, b: int = 0) -> None: super().__init__() @@ -976,7 +1008,6 @@ def __call__( return dict(a=a, c=c) -# yapf: disable @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy @pytest.mark.parametrize( ("config_kwargs", "inference_kwargs", "expected_kwargs"), @@ -990,7 +1021,6 @@ def __call__( ({"b": 1, "c": 1}, {}, {"a": 0, "b": 1}), ], ) -# yapf: enable def test_hf_processor_init_kwargs( model_id, config_kwargs, @@ -1014,7 +1044,6 @@ def test_hf_processor_init_kwargs( assert getattr(processor, k) == v -# yapf: disable @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy @pytest.mark.parametrize( ("config_kwargs", "inference_kwargs", "expected_kwargs"), @@ -1028,7 +1057,6 @@ def test_hf_processor_init_kwargs( ({"b": 1, "c": 1}, {}, {"a": 0, "c": 1}), ], ) -# yapf: enable def test_hf_processor_call_kwargs( model_id, config_kwargs, diff --git a/tests/multimodal/test_registry.py b/tests/multimodal/test_registry.py index 01fbe9a52b77..3b01bda7f54c 100644 --- a/tests/multimodal/test_registry.py +++ b/tests/multimodal/test_registry.py @@ -19,22 +19,16 @@ [ ("Qwen/Qwen2-0.5B-Instruct", {}, False), ("Qwen/Qwen2.5-VL-3B-Instruct", {}, True), - ("Qwen/Qwen2.5-VL-3B-Instruct", { - "image": 0, - "video": 0 - }, False), - ("Qwen/Qwen2.5-VL-3B-Instruct", { - "image": 0 - }, True), + ("Qwen/Qwen2.5-VL-3B-Instruct", {"image": 0, "video": 0}, False), + ("Qwen/Qwen2.5-VL-3B-Instruct", {"image": 0}, True), ], ) 
@pytest.mark.core_model def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected): - """Test supports_multimodal_inputs returns correct boolean for various + """Test supports_multimodal_inputs returns correct boolean for various configs.""" ctx = build_model_context( model_id, limit_mm_per_prompt=limit_mm_per_prompt, ) - assert MULTIMODAL_REGISTRY.supports_multimodal_inputs( - ctx.model_config) is expected \ No newline at end of file + assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(ctx.model_config) is expected diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index d1a7882a4c37..ea795fcbbde5 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -30,7 +30,6 @@ @pytest.fixture(scope="module") def url_images(local_asset_server) -> dict[str, Image.Image]: - return { image_url: local_asset_server.get_image_asset(image_url) for image_url in TEST_IMAGE_ASSETS @@ -39,10 +38,10 @@ def url_images(local_asset_server) -> dict[str, Image.Image]: def get_supported_suffixes() -> tuple[str, ...]: # We should at least test the file types mentioned in GPT-4 with Vision - OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif') + OPENAI_SUPPORTED_SUFFIXES = (".png", ".jpeg", ".jpg", ".webp", ".gif") # Additional file types that are supported by us - EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff') + EXTRA_SUPPORTED_SUFFIXES = (".bmp", ".tiff") return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES @@ -64,14 +63,16 @@ async def test_fetch_image_http(image_url: str): @pytest.mark.asyncio @pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS) @pytest.mark.parametrize("suffix", get_supported_suffixes()) -async def test_fetch_image_base64(url_images: dict[str, Image.Image], - raw_image_url: str, suffix: str): +async def test_fetch_image_base64( + url_images: dict[str, Image.Image], raw_image_url: str, suffix: str +): connector = MediaConnector( # Domain restriction should not apply to data URLs. 
allowed_media_domains=[ "www.bogotobogo.com", "github.com", - ]) + ] + ) url_image = url_images[raw_image_url] try: @@ -80,14 +81,14 @@ async def test_fetch_image_base64(url_images: dict[str, Image.Image], try: mime_type = mimetypes.types_map[suffix] except KeyError: - pytest.skip('No MIME type') + pytest.skip("No MIME type") with NamedTemporaryFile(suffix=suffix) as f: try: url_image.save(f.name) except Exception as e: - if e.args[0] == 'cannot write mode RGBA as JPEG': - pytest.skip('Conversion not supported') + if e.args[0] == "cannot write mode RGBA as JPEG": + pytest.skip("Conversion not supported") raise @@ -113,30 +114,36 @@ async def test_fetch_image_local_files(image_url: str): local_connector = MediaConnector(allowed_local_media_path=temp_dir) origin_image = connector.fetch_image(image_url) - origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)), - quality=100, - icc_profile=origin_image.info.get('icc_profile')) + origin_image.save( + os.path.join(temp_dir, os.path.basename(image_url)), + quality=100, + icc_profile=origin_image.info.get("icc_profile"), + ) image_async = await local_connector.fetch_image_async( - f"file://{temp_dir}/{os.path.basename(image_url)}") + f"file://{temp_dir}/{os.path.basename(image_url)}" + ) image_sync = local_connector.fetch_image( - f"file://{temp_dir}/{os.path.basename(image_url)}") + f"file://{temp_dir}/{os.path.basename(image_url)}" + ) # Check that the images are equal assert not ImageChops.difference(image_sync, image_async).getbbox() with pytest.raises(ValueError, match="must be a subpath"): await local_connector.fetch_image_async( - f"file://{temp_dir}/../{os.path.basename(image_url)}") + f"file://{temp_dir}/../{os.path.basename(image_url)}" + ) with pytest.raises(RuntimeError, match="Cannot load local files"): await connector.fetch_image_async( - f"file://{temp_dir}/../{os.path.basename(image_url)}") + f"file://{temp_dir}/../{os.path.basename(image_url)}" + ) with pytest.raises(ValueError, match="must be a subpath"): local_connector.fetch_image( - f"file://{temp_dir}/../{os.path.basename(image_url)}") + f"file://{temp_dir}/../{os.path.basename(image_url)}" + ) with pytest.raises(RuntimeError, match="Cannot load local files"): - connector.fetch_image( - f"file://{temp_dir}/../{os.path.basename(image_url)}") + connector.fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}") @pytest.mark.asyncio @@ -149,18 +156,19 @@ async def test_fetch_image_local_files_with_space_in_name(image_url: str): origin_image = connector.fetch_image(image_url) filename = "file name with space.jpg" - origin_image.save(os.path.join(temp_dir, filename), - quality=100, - icc_profile=origin_image.info.get('icc_profile')) + origin_image.save( + os.path.join(temp_dir, filename), + quality=100, + icc_profile=origin_image.info.get("icc_profile"), + ) try: image_async = await local_connector.fetch_image_async( - f"file://{temp_dir}/{filename}") - image_sync = local_connector.fetch_image( - f"file://{temp_dir}/{filename}") + f"file://{temp_dir}/{filename}" + ) + image_sync = local_connector.fetch_image(f"file://{temp_dir}/{filename}") except FileNotFoundError as e: - pytest.fail( - "Failed to fetch image with space in name: {}".format(e)) + pytest.fail("Failed to fetch image with space in name: {}".format(e)) # Check that the images are equal assert not ImageChops.difference(image_sync, image_async).getbbox() @@ -183,9 +191,12 @@ async def test_fetch_image_error_conversion(): @pytest.mark.parametrize("num_frames", [-1, 32, 1800]) async def 
test_fetch_video_http(video_url: str, num_frames: int): connector = MediaConnector( - media_io_kwargs={"video": { - "num_frames": num_frames, - }}) + media_io_kwargs={ + "video": { + "num_frames": num_frames, + } + } + ) video_sync, metadata_sync = connector.fetch_video(video_url) video_async, metadata_async = await connector.fetch_video_async(video_url) @@ -198,8 +209,11 @@ async def test_fetch_video_http(video_url: str, num_frames: int): @pytest.mark.parametrize("max_duration", [1, 60, 1800]) @pytest.mark.parametrize("requested_fps", [2, 24]) async def test_fetch_video_http_with_dynamic_loader( - video_url: str, max_duration: int, requested_fps: int, - monkeypatch: pytest.MonkeyPatch): + video_url: str, + max_duration: int, + requested_fps: int, + monkeypatch: pytest.MonkeyPatch, +): with monkeypatch.context() as m: m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic") connector = MediaConnector( @@ -208,18 +222,17 @@ async def test_fetch_video_http_with_dynamic_loader( "max_duration": max_duration, "requested_fps": requested_fps, } - }) + } + ) video_sync, metadata_sync = connector.fetch_video(video_url) - video_async, metadata_async = await connector.fetch_video_async( - video_url) + video_async, metadata_async = await connector.fetch_video_async(video_url) assert np.array_equal(video_sync, video_async) assert metadata_sync == metadata_async assert metadata_sync["video_backend"] == "opencv_dynamic" -# yapf: disable @pytest.mark.parametrize( "case", [ @@ -250,7 +263,6 @@ async def test_fetch_video_http_with_dynamic_loader( ("image", 0), ], ), - # Two modalities ## Internally sorted dict( @@ -262,7 +274,7 @@ async def test_fetch_video_http_with_dynamic_loader( "audio": [ PlaceholderRange(offset=0, length=2), PlaceholderRange(offset=2, length=3), - ] + ], }, expected_modality_idxs=[ ("audio", 0), @@ -281,7 +293,7 @@ async def test_fetch_video_http_with_dynamic_loader( "audio": [ PlaceholderRange(offset=5, length=2), PlaceholderRange(offset=11, length=4), - ] + ], }, expected_modality_idxs=[ ("image", 0), @@ -300,7 +312,7 @@ async def test_fetch_video_http_with_dynamic_loader( "audio": [ PlaceholderRange(offset=11, length=4), PlaceholderRange(offset=5, length=2), - ] + ], }, expected_modality_idxs=[ ("image", 1), @@ -309,7 +321,6 @@ async def test_fetch_video_http_with_dynamic_loader( ("audio", 0), ], ), - # Three modalities ## Internally sorted dict( @@ -325,7 +336,7 @@ async def test_fetch_video_http_with_dynamic_loader( PlaceholderRange(offset=3, length=4), PlaceholderRange(offset=7, length=5), PlaceholderRange(offset=12, length=6), - ] + ], }, expected_modality_idxs=[ ("audio", 0), @@ -349,7 +360,7 @@ async def test_fetch_video_http_with_dynamic_loader( ], "video": [ PlaceholderRange(offset=8, length=5), - ] + ], }, expected_modality_idxs=[ ("image", 0), @@ -372,7 +383,7 @@ async def test_fetch_video_http_with_dynamic_loader( ], "video": [ PlaceholderRange(offset=8, length=5), - ] + ], }, expected_modality_idxs=[ ("image", 0), @@ -384,7 +395,6 @@ async def test_fetch_video_http_with_dynamic_loader( ), ], ) -# yapf: enable def test_argsort_mm_positions(case): mm_positions = case["mm_positions"] expected_modality_idxs = case["expected_modality_idxs"] @@ -399,13 +409,16 @@ def test_argsort_mm_positions(case): @pytest.mark.parametrize("num_frames", [-1, 32, 1800]) async def test_allowed_media_domains(video_url: str, num_frames: int): connector = MediaConnector( - media_io_kwargs={"video": { - "num_frames": num_frames, - }}, + media_io_kwargs={ + "video": { + "num_frames": num_frames, + } 
+ }, allowed_media_domains=[ "www.bogotobogo.com", "github.com", - ]) + ], + ) video_sync, metadata_sync = connector.fetch_video(video_url) video_async, metadata_async = await connector.fetch_video_async(video_url) diff --git a/tests/multimodal/test_video.py b/tests/multimodal/test_video.py index 1bdbb5a10a6d..6572616769a9 100644 --- a/tests/multimodal/test_video.py +++ b/tests/multimodal/test_video.py @@ -12,8 +12,7 @@ from vllm.assets.base import get_vllm_public_assets from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list from vllm.multimodal.image import ImageMediaIO -from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader, - VideoMediaIO) +from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader, VideoMediaIO from .utils import cosine_similarity, create_video_from_image, normalize_image @@ -26,7 +25,6 @@ @VIDEO_LOADER_REGISTRY.register("test_video_loader_1") class TestVideoLoader1(VideoLoader): - @classmethod def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: return FAKE_OUTPUT_1 @@ -34,7 +32,6 @@ def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: @VIDEO_LOADER_REGISTRY.register("test_video_loader_2") class TestVideoLoader2(VideoLoader): - @classmethod def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: return FAKE_OUTPUT_2 @@ -57,13 +54,10 @@ def test_video_loader_type_doesnt_exist(): @VIDEO_LOADER_REGISTRY.register("assert_10_frames_1_fps") class Assert10Frames1FPSVideoLoader(VideoLoader): - @classmethod - def load_bytes(cls, - data: bytes, - num_frames: int = -1, - fps: float = -1.0, - **kwargs) -> npt.NDArray: + def load_bytes( + cls, data: bytes, num_frames: int = -1, fps: float = -1.0, **kwargs + ) -> npt.NDArray: assert num_frames == 10, "bad num_frames" assert fps == 1.0, "bad fps" return FAKE_OUTPUT_2 @@ -79,11 +73,8 @@ def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch): _ = videoio.load_bytes(b"test") videoio = VideoMediaIO( - imageio, **{ - "num_frames": 10, - "fps": 1.0, - "not_used": "not_used" - }) + imageio, **{"num_frames": 10, "fps": 1.0, "not_used": "not_used"} + ) _ = videoio.load_bytes(b"test") with pytest.raises(AssertionError, match="bad num_frames"): @@ -106,8 +97,9 @@ def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str): Test all functions that use OpenCV for video I/O return RGB format. Both RGB and grayscale videos are tested. 
""" - image_path = get_vllm_public_assets(filename="stop_sign.jpg", - s3_prefix="vision_model_images") + image_path = get_vllm_public_assets( + filename="stop_sign.jpg", s3_prefix="vision_model_images" + ) image = Image.open(image_path) with tempfile.TemporaryDirectory() as tmpdir: if not is_color: @@ -127,21 +119,24 @@ def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str): frames = video_to_ndarrays(video_path) for frame in frames: - sim = cosine_similarity(normalize_image(np.array(frame)), - normalize_image(np.array(image))) + sim = cosine_similarity( + normalize_image(np.array(frame)), normalize_image(np.array(image)) + ) assert np.sum(np.isnan(sim)) / sim.size < 0.001 assert np.nanmean(sim) > 0.99 pil_frames = video_to_pil_images_list(video_path) for frame in pil_frames: - sim = cosine_similarity(normalize_image(np.array(frame)), - normalize_image(np.array(image))) + sim = cosine_similarity( + normalize_image(np.array(frame)), normalize_image(np.array(image)) + ) assert np.sum(np.isnan(sim)) / sim.size < 0.001 assert np.nanmean(sim) > 0.99 io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path)) for frame in io_frames: - sim = cosine_similarity(normalize_image(np.array(frame)), - normalize_image(np.array(image))) + sim = cosine_similarity( + normalize_image(np.array(frame)), normalize_image(np.array(image)) + ) assert np.sum(np.isnan(sim)) / sim.size < 0.001 assert np.nanmean(sim) > 0.99 diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py index 9a58292f9f4a..485bde939f69 100644 --- a/tests/multimodal/utils.py +++ b/tests/multimodal/utils.py @@ -8,7 +8,7 @@ def random_image(rng: np.random.RandomState, min_wh: int, max_wh: int): - w, h = rng.randint(min_wh, max_wh, size=(2, )) + w, h = rng.randint(min_wh, max_wh, size=(2,)) arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8) return Image.fromarray(arr) @@ -21,7 +21,7 @@ def random_video( max_wh: int, ): num_frames = rng.randint(min_frames, max_frames) - w, h = rng.randint(min_wh, max_wh, size=(2, )) + w, h = rng.randint(min_wh, max_wh, size=(2,)) return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8) @@ -66,14 +66,13 @@ def create_video_from_image( return video_path -def cosine_similarity(A: npt.NDArray, - B: npt.NDArray, - axis: int = -1) -> npt.NDArray: +def cosine_similarity(A: npt.NDArray, B: npt.NDArray, axis: int = -1) -> npt.NDArray: """Compute cosine similarity between two vectors.""" - return (np.sum(A * B, axis=axis) / - (np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis))) + return np.sum(A * B, axis=axis) / ( + np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis) + ) def normalize_image(image: npt.NDArray) -> npt.NDArray: """Normalize image to [0, 1] range.""" - return image.astype(np.float32) / 255.0 \ No newline at end of file + return image.astype(np.float32) / 255.0 diff --git a/tests/plugins/lora_resolvers/test_filesystem_resolver.py b/tests/plugins/lora_resolvers/test_filesystem_resolver.py index 3e2c2577da66..cd98efdd1390 100644 --- a/tests/plugins/lora_resolvers/test_filesystem_resolver.py +++ b/tests/plugins/lora_resolvers/test_filesystem_resolver.py @@ -13,11 +13,10 @@ PA_NAME = "swapnilbp/llama_tweet_ptune" -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def adapter_cache(request, tmpdir_factory): # Create dir that mimics the structure of the adapter cache - adapter_cache = tmpdir_factory.mktemp( - request.module.__name__) / "adapter_cache" + adapter_cache = tmpdir_factory.mktemp(request.module.__name__) / 
"adapter_cache" return adapter_cache diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index 42874f0398f0..a2a8d0ec9aba 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -9,7 +9,7 @@ import tempfile import urllib.request from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Union import albumentations import numpy as np @@ -20,14 +20,15 @@ from terratorch.datamodules import Sen1Floods11NonGeoDataModule from vllm.config import VllmConfig -from vllm.entrypoints.openai.protocol import (IOProcessorRequest, - IOProcessorResponse) +from vllm.entrypoints.openai.protocol import IOProcessorRequest, IOProcessorResponse from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput -from vllm.plugins.io_processors.interface import (IOProcessor, - IOProcessorInput, - IOProcessorOutput) +from vllm.plugins.io_processors.interface import ( + IOProcessor, + IOProcessorInput, + IOProcessorOutput, +) from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput @@ -42,35 +43,25 @@ datamodule_config: DataModuleConfig = { "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], - "batch_size": - 16, - "constant_scale": - 0.0001, - "data_root": - "/dccstor/geofm-finetuning/datasets/sen1floods11", - "drop_last": - True, - "no_data_replace": - 0.0, - "no_label_replace": - -1, - "num_workers": - 8, + "batch_size": 16, + "constant_scale": 0.0001, + "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11", + "drop_last": True, + "no_data_replace": 0.0, + "no_label_replace": -1, + "num_workers": 8, "test_transform": [ - albumentations.Resize(always_apply=False, - height=448, - interpolation=1, - p=1, - width=448), - albumentations.pytorch.ToTensorV2(transpose_mask=False, - always_apply=True, - p=1.0), + albumentations.Resize( + always_apply=False, height=448, interpolation=1, p=1, width=448 + ), + albumentations.pytorch.ToTensorV2( + transpose_mask=False, always_apply=True, p=1.0 + ), ], } -def save_geotiff(image: torch.Tensor, meta: dict, - out_format: str) -> str | bytes: +def save_geotiff(image: torch.Tensor, meta: dict, out_format: str) -> str | bytes: """Save multi-band image in Geotiff file. Args: @@ -107,9 +98,9 @@ def _convert_np_uint8(float_image: torch.Tensor): def read_geotiff( - file_path: Optional[str] = None, - path_type: Optional[str] = None, - file_data: Optional[bytes] = None, + file_path: str | None = None, + path_type: str | None = None, + file_data: bytes | None = None, ) -> tuple[torch.Tensor, dict, tuple[float, float] | None]: """Read all bands from *file_path* and return image + meta info. 
@@ -123,8 +114,8 @@ def read_geotiff( if all([x is None for x in [file_path, path_type, file_data]]): raise Exception("All input fields to read_geotiff are None") - write_to_file: Optional[bytes] = None - path: Optional[str] = None + write_to_file: bytes | None = None + path: str | None = None if file_data is not None: # with tempfile.NamedTemporaryFile() as tmpfile: # tmpfile.write(file_data) @@ -171,9 +162,9 @@ def read_geotiff( def load_image( data: Union[list[str]], path_type: str, - mean: Optional[list[float]] = None, - std: Optional[list[float]] = None, - indices: Optional[Union[list[int], None]] = None, + mean: list[float] | None = None, + std: list[float] | None = None, + indices: Union[list[int], None] | None = None, ): """Build an input example by loading images in *file_paths*. @@ -219,8 +210,11 @@ def load_image( if len(julian_day) == 3: julian_day = int(julian_day) else: - julian_day = (datetime.datetime.strptime( - julian_day, "%m%d").timetuple().tm_yday) + julian_day = ( + datetime.datetime.strptime(julian_day, "%m%d") + .timetuple() + .tm_yday + ) temporal_coords.append([year, julian_day]) except Exception: logger.exception("Could not extract timestamp for %s", file) @@ -233,11 +227,9 @@ def load_image( class PrithviMultimodalDataProcessor(IOProcessor): - indices = [0, 1, 2, 3, 4, 5] def __init__(self, vllm_config: VllmConfig): - super().__init__(vllm_config) self.datamodule = Sen1Floods11NonGeoDataModule( @@ -264,8 +256,7 @@ def parse_request(self, request: Any) -> IOProcessorInput: return image_prompt if isinstance(request, IOProcessorRequest): if not hasattr(request, "data"): - raise ValueError( - "missing 'data' field in OpenAIBaseModel Request") + raise ValueError("missing 'data' field in OpenAIBaseModel Request") request_data = request.data @@ -277,7 +268,8 @@ def parse_request(self, request: Any) -> IOProcessorInput: raise ValueError("Unable to parse request") def output_to_response( - self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + self, plugin_output: IOProcessorOutput + ) -> IOProcessorResponse: return IOProcessorResponse( request_id=plugin_output.request_id, data=plugin_output, @@ -286,10 +278,9 @@ def output_to_response( def pre_process( self, prompt: IOProcessorInput, - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, ) -> Union[PromptType, Sequence[PromptType]]: - image_data = dict(prompt) if request_id: @@ -309,10 +300,8 @@ def pre_process( input_data = input_data / 10000 # Convert to range 0-1 self.original_h, self.original_w = input_data.shape[-2:] - pad_h = (self.img_size - - (self.original_h % self.img_size)) % self.img_size - pad_w = (self.img_size - - (self.original_w % self.img_size)) % self.img_size + pad_h = (self.img_size - (self.original_h % self.img_size)) % self.img_size + pad_w = (self.img_size - (self.original_w % self.img_size)) % self.img_size input_data = np.pad( input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), @@ -320,9 +309,9 @@ def pre_process( ) batch = torch.tensor(input_data) - windows = batch.unfold(3, self.img_size, - self.img_size).unfold(4, self.img_size, - self.img_size) + windows = batch.unfold(3, self.img_size, self.img_size).unfold( + 4, self.img_size, self.img_size + ) self.h1, self.w1 = windows.shape[3:5] windows = rearrange( windows, @@ -332,8 +321,11 @@ def pre_process( ) # Split into batches if number of windows > batch_size - num_batches = (windows.shape[0] // self.batch_size - if windows.shape[0] > self.batch_size else 1) + num_batches = ( + windows.shape[0] // 
self.batch_size + if windows.shape[0] > self.batch_size + else 1 + ) windows = torch.tensor_split(windows, num_batches, dim=0) if temporal_coords: @@ -349,25 +341,27 @@ def pre_process( for window in windows: # Apply standardization window = self.datamodule.test_transform( - image=window.squeeze().numpy().transpose(1, 2, 0)) + image=window.squeeze().numpy().transpose(1, 2, 0) + ) window = self.datamodule.aug(window)["image"] - prompts.append({ - "prompt_token_ids": [1], - "multi_modal_data": { - "pixel_values": window.to(torch.float16)[0], - "location_coords": location_coords.to(torch.float16), - }, - }) + prompts.append( + { + "prompt_token_ids": [1], + "multi_modal_data": { + "pixel_values": window.to(torch.float16)[0], + "location_coords": location_coords.to(torch.float16), + }, + } + ) return prompts def post_process( self, model_output: Sequence[PoolingRequestOutput], - request_id: Optional[str] = None, + request_id: str | None = None, **kwargs, ) -> IOProcessorOutput: - pred_imgs_list = [] if request_id and (request_id in self.requests_cache): @@ -399,7 +393,7 @@ def post_process( ) # Cut padded area back to original size - pred_imgs = pred_imgs[..., :self.original_h, :self.original_w] + pred_imgs = pred_imgs[..., : self.original_h, : self.original_w] # Squeeze (batch size 1) pred_imgs = pred_imgs[0] @@ -407,10 +401,10 @@ def post_process( if not self.meta_data: raise ValueError("No metadata available for the current task") self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) - out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data, - out_format) + out_data = save_geotiff( + _convert_np_uint8(pred_imgs), self.meta_data, out_format + ) - return ImageRequestOutput(type=out_format, - format="tiff", - data=out_data, - request_id=request_id) + return ImageRequestOutput( + type=out_format, format="tiff", data=out_data, request_id=request_id + ) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py index d4c6628211fb..21a5c3754c36 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/types.py @@ -16,12 +16,10 @@ class DataModuleConfig(TypedDict): no_data_replace: float no_label_replace: int num_workers: int - test_transform: list[ - albumentations.core.transforms_interface.BasicTransform] + test_transform: list[albumentations.core.transforms_interface.BasicTransform] class ImagePrompt(BaseModel): - data_format: Literal["b64_json", "bytes", "url", "path"] """ This is the data type for the input image @@ -45,7 +43,7 @@ class ImagePrompt(BaseModel): class ImageRequestOutput(BaseModel): """ - The output data of an image request to vLLM. + The output data of an image request to vLLM. 
Args: type (str): The data content type [path, object] diff --git a/tests/plugins/vllm_add_dummy_model/setup.py b/tests/plugins/vllm_add_dummy_model/setup.py index 6307bb63897a..eeffac5d3edd 100644 --- a/tests/plugins/vllm_add_dummy_model/setup.py +++ b/tests/plugins/vllm_add_dummy_model/setup.py @@ -3,10 +3,11 @@ from setuptools import setup -setup(name='vllm_add_dummy_model', - version='0.1', - packages=['vllm_add_dummy_model'], - entry_points={ - 'vllm.general_plugins': - ["register_dummy_model = vllm_add_dummy_model:register"] - }) +setup( + name="vllm_add_dummy_model", + version="0.1", + packages=["vllm_add_dummy_model"], + entry_points={ + "vllm.general_plugins": ["register_dummy_model = vllm_add_dummy_model:register"] + }, +) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py index b2085b01c45c..457187e4b492 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py @@ -19,5 +19,4 @@ def register(): ) if "MyLlava" not in ModelRegistry.get_supported_archs(): - ModelRegistry.register_model("MyLlava", - "vllm_add_dummy_model.my_llava:MyLlava") + ModelRegistry.register_model("MyLlava", "vllm_add_dummy_model.my_llava:MyLlava") diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index fc654f20fff2..a22a10eab47d 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -15,7 +15,6 @@ class MyGemma2Embedding(nn.Module): - is_pooling_model = True hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) @@ -23,19 +22,23 @@ class MyGemma2Embedding(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - self.model = Gemma2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Gemma2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": Pooler.for_encode(pooler_config), - "embed": Pooler.for_embed(pooler_config), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def forward( self, @@ -58,8 +61,8 @@ def forward( return torch.zeros_like(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - weights = self.hf_to_vllm_mapper.apply(weights) - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) + weights = ( + (name, data) for name, data in weights if not name.startswith("lm_head.") + ) return self.model.load_weights(weights) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index b431ad1ed092..9e6f5c3a77e3 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -5,20 +5,22 @@ import torch -from vllm.model_executor.models.llava import 
(LlavaDummyInputsBuilder, - LlavaForConditionalGeneration, - LlavaMultiModalProcessor, - LlavaProcessingInfo) +from vllm.model_executor.models.llava import ( + LlavaDummyInputsBuilder, + LlavaForConditionalGeneration, + LlavaMultiModalProcessor, + LlavaProcessingInfo, +) from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor, - info=LlavaProcessingInfo, - dummy_inputs=LlavaDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + LlavaMultiModalProcessor, + info=LlavaProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder, +) class MyLlava(LlavaForConditionalGeneration): - - def compute_logits(self, - hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # this dummy model always predicts the first token logits = super().compute_logits(hidden_states) if logits is not None: diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py index a6fafff98e9c..c02299f5d44f 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py @@ -9,9 +9,7 @@ class MyOPTForCausalLM(OPTForCausalLM): - - def compute_logits(self, - hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # this dummy model always predicts the first token logits = super().compute_logits(hidden_states) if logits is not None: diff --git a/tests/plugins/vllm_add_dummy_platform/setup.py b/tests/plugins/vllm_add_dummy_platform/setup.py index a531826628cd..b976dddb7fb5 100644 --- a/tests/plugins/vllm_add_dummy_platform/setup.py +++ b/tests/plugins/vllm_add_dummy_platform/setup.py @@ -4,13 +4,15 @@ from setuptools import setup setup( - name='vllm_add_dummy_platform', - version='0.1', - packages=['vllm_add_dummy_platform'], + name="vllm_add_dummy_platform", + version="0.1", + packages=["vllm_add_dummy_platform"], entry_points={ - 'vllm.platform_plugins': [ + "vllm.platform_plugins": [ "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa ], - "vllm.general_plugins": - ["dummy_custom_ops = vllm_add_dummy_platform:register_ops"], - }) + "vllm.general_plugins": [ + "dummy_custom_ops = vllm_add_dummy_platform:register_ops" + ], + }, +) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py index e38fb2fbf934..f2d516f52b8b 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.backends.placeholder_attn import ( - PlaceholderAttentionBackend) +from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend class DummyAttentionBackend(PlaceholderAttentionBackend): - @staticmethod def get_name() -> str: return "Dummy_Backend" diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py index 1fcc3fc66617..b73028574526 100644 --- 
a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py @@ -15,6 +15,5 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.addition_config = True - def forward_oot(self, *args, - **kwargs) -> tuple[torch.Tensor, torch.Tensor]: + def forward_oot(self, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor]: return super().forward_oot(*args, **kwargs) diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py index 30d721304b5c..90cb461a6caf 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py @@ -24,7 +24,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # Activate custom ops for v1. compilation_config.custom_ops = ["all"] - def get_attn_backend_cls(self, backend_name, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla, - has_sink, use_sparse): + def get_attn_backend_cls( + self, + backend_name, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + ): return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 3567a701a3af..912b32755e80 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -50,7 +50,6 @@ async def test_prithvi_mae_plugin_online( server: RemoteOpenAIServer, model_name: str, ): - request_payload_url = { "data": { "data": image_url, @@ -60,7 +59,7 @@ async def test_prithvi_mae_plugin_online( }, "priority": 0, "model": model_name, - "softmax": False + "softmax": False, } ret = requests.post( @@ -77,8 +76,8 @@ async def test_prithvi_mae_plugin_online( plugin_data = parsed_response.data assert all( - plugin_data.get(attr) - for attr in ["type", "format", "data", "request_id"]) + plugin_data.get(attr) for attr in ["type", "format", "data", "request_id"] + ) # We just check that the output is a valid base64 string. # Raises an exception and fails the test if the string is corrupted. @@ -87,7 +86,6 @@ async def test_prithvi_mae_plugin_online( @pytest.mark.parametrize("model_name", [MODEL_NAME]) def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): - img_prompt = dict( data=image_url, data_format="url", @@ -98,16 +96,16 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): pooling_params = PoolingParams(task="encode", softmax=False) with vllm_runner( - model_name, - runner="pooling", - skip_tokenizer_init=True, - trust_remote_code=True, - enforce_eager=True, - # Limit the maximum number of parallel requests - # to avoid the model going OOM in CI. - max_num_seqs=1, - model_impl="terratorch", - io_processor_plugin="prithvi_to_tiff", + model_name, + runner="pooling", + skip_tokenizer_init=True, + trust_remote_code=True, + enforce_eager=True, + # Limit the maximum number of parallel requests + # to avoid the model going OOM in CI. 
+ max_num_seqs=1, + model_impl="terratorch", + io_processor_plugin="prithvi_to_tiff", ) as llm_runner: pooler_output = llm_runner.get_llm().encode( img_prompt, @@ -117,8 +115,8 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str): # verify the output is formatted as expected for this plugin assert all( - hasattr(output, attr) - for attr in ["type", "format", "data", "request_id"]) + hasattr(output, attr) for attr in ["type", "format", "data", "request_id"] + ) # We just check that the output is a valid base64 string. # Raises an exception and fails the test if the string is corrupted. diff --git a/tests/plugins_tests/test_platform_plugins.py b/tests/plugins_tests/test_platform_plugins.py index 1d7e4475011d..4dace171a8d3 100644 --- a/tests/plugins_tests/test_platform_plugins.py +++ b/tests/plugins_tests/test_platform_plugins.py @@ -10,29 +10,38 @@ def test_platform_plugins(): # simulate workload by running an example import runpy + current_file = __file__ import os + example_file = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(current_file))), - "examples", "offline_inference/basic/basic.py") + "examples", + "offline_inference/basic/basic.py", + ) runpy.run_path(example_file) # check if the plugin is loaded correctly from vllm.platforms import _init_trace, current_platform + assert current_platform.device_name == "DummyDevice", ( f"Expected DummyDevice, got {current_platform.device_name}, " "possibly because current_platform is imported before the plugin" - f" is loaded. The first import:\n{_init_trace}") + f" is loaded. The first import:\n{_init_trace}" + ) def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch): # simulate workload by running an example load_general_plugins() from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + layer = RotaryEmbedding(16, 16, 16, 16, True, torch.float16) assert layer.__class__.__name__ == "DummyRotaryEmbedding", ( f"Expected DummyRotaryEmbedding, got {layer.__class__.__name__}, " - "possibly because the custom op is not registered correctly.") + "possibly because the custom op is not registered correctly." + ) assert hasattr(layer, "addition_config"), ( "Expected DummyRotaryEmbedding to have an 'addition_config' attribute, " - "which is set by the custom op.") + "which is set by the custom op." 
+ ) diff --git a/tests/plugins_tests/test_scheduler_plugins.py b/tests/plugins_tests/test_scheduler_plugins.py index 099869a82ad2..1c37d6a39261 100644 --- a/tests/plugins_tests/test_scheduler_plugins.py +++ b/tests/plugins_tests/test_scheduler_plugins.py @@ -10,7 +10,6 @@ class DummyV1Scheduler(Scheduler): - def schedule(self): raise Exception("Exception raised by DummyV1Scheduler") @@ -23,7 +22,6 @@ def test_scheduler_plugins_v1(monkeypatch: pytest.MonkeyPatch): m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") with pytest.raises(Exception) as exception_info: - engine_args = EngineArgs( model="facebook/opt-125m", enforce_eager=True, # reduce test time @@ -36,5 +34,4 @@ def test_scheduler_plugins_v1(monkeypatch: pytest.MonkeyPatch): engine.add_request("0", "foo", sampling_params) engine.step() - assert str( - exception_info.value) == "Exception raised by DummyV1Scheduler" + assert str(exception_info.value) == "Exception raised by DummyV1Scheduler" diff --git a/tests/quantization/reference_mxfp4.py b/tests/quantization/reference_mxfp4.py index 2ef251933f68..d84659ed035e 100644 --- a/tests/quantization/reference_mxfp4.py +++ b/tests/quantization/reference_mxfp4.py @@ -14,14 +14,15 @@ FLOAT4_EXP_BIAS = 1 FLOAT4_MANTISSA_BITS = 1 -FLOAT16_VAL_TO_ADD = (1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)) -FLOAT16_SIGN_EXPONENT_MASK = (( - (1 << (FLOAT16_EXP_BITS + 1)) - 1) << FLOAT16_MANTISSA_BITS) +FLOAT16_VAL_TO_ADD = 1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1) +FLOAT16_SIGN_EXPONENT_MASK = ( + (1 << (FLOAT16_EXP_BITS + 1)) - 1 +) << FLOAT16_MANTISSA_BITS -BFLOAT16_VAL_TO_ADD = (1 << - (BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)) -BFLOAT16_SIGN_EXPONENT_MASK = (( - (1 << (BFLOAT16_EXP_BITS + 1)) - 1) << BFLOAT16_MANTISSA_BITS) +BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1) +BFLOAT16_SIGN_EXPONENT_MASK = ( + (1 << (BFLOAT16_EXP_BITS + 1)) - 1 +) << BFLOAT16_MANTISSA_BITS def e8m0_to_half(scale, half_dtype: torch.dtype): @@ -30,19 +31,19 @@ def e8m0_to_half(scale, half_dtype: torch.dtype): scale_exp = scale.to(torch.int16) - 127 # This can be implemented with bitwise operations in a proper kernel. - scale_half = 2.0**(scale_exp.to(torch.float)) + scale_half = 2.0 ** (scale_exp.to(torch.float)) return scale_half.to(half_dtype) -def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype, - half_exp_bias: int, half_mantissa_bits: int): +def upcast_fp4_to_fp16_or_bf16( + val, float_dtype: torch.dtype, half_exp_bias: int, half_mantissa_bits: int +): assert val.dtype == torch.uint8 - unpacked = torch.zeros(*val.shape[:-1], - val.shape[-1] * 2, - dtype=torch.uint8, - device=val.device) + unpacked = torch.zeros( + *val.shape[:-1], val.shape[-1] * 2, dtype=torch.uint8, device=val.device + ) unpacked[..., 1::2] = (val >> 4) & 0x0F # Extract high 4 bits. unpacked[..., ::2] = val & 0x0F # Extract low 4 bits. 
@@ -72,8 +73,11 @@ def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype, new_exp = new_exp.to(torch.int32) sign = sign.to(torch.int32) - qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + ( - new_mantissa << (half_mantissa_bits - 1)) + qdq_val = ( + (sign << 15) + + (new_exp << half_mantissa_bits) + + (new_mantissa << (half_mantissa_bits - 1)) + ) assert qdq_val.max() <= 65535 assert qdq_val.min() >= 0 @@ -84,8 +88,9 @@ def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype, return result -def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor, - float_dtype: torch.dtype) -> torch.Tensor: +def dq_mxfp4_torch( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype +) -> torch.Tensor: assert x.dtype == torch.uint8 assert scale.dtype == torch.uint8 @@ -98,10 +103,12 @@ def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor, scale_half = e8m0_to_half(scale, half_dtype=float_dtype) - x_half = upcast_fp4_to_fp16_or_bf16(x, - float_dtype=float_dtype, - half_exp_bias=half_exp_bias, - half_mantissa_bits=half_mantissa_bits) + x_half = upcast_fp4_to_fp16_or_bf16( + x, + float_dtype=float_dtype, + half_exp_bias=half_exp_bias, + half_mantissa_bits=half_mantissa_bits, + ) x_half = x_half.reshape(*x_half.shape[:-1], -1, 32) x_half = x_half * scale_half[..., None] @@ -110,8 +117,9 @@ def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor, return x_half -def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, - half_exp_bias: int): +def fp16_to_fp4_simulate( + val, half_mantissa_bits: int, half_exp_bits: int, half_exp_bias: int +): # Casts an fp16/bf16 input to the restricted values of float4_e2m1, # that is to say [0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, # -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]. @@ -119,7 +127,7 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, float_type = val.dtype # "rshift_cuda" not implemented for 'UInt16' - val_view = val.view(torch.int16) #.to(torch.int32) + val_view = val.view(torch.int16) # .to(torch.int32) exp = val_view >> half_mantissa_bits exp = exp & ((1 << half_exp_bits) - 1) @@ -147,23 +155,15 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, tail = mantissa_plus_one & ((1 << tail_bits) - 1) - round_close = (tail < half) # round towards 0 - round_away = (tail > half) # round away from 0 + round_close = tail < half # round towards 0 + round_away = tail > half # round away from 0 tie = tail == half - new_mantissa_close = torch.zeros(val.shape, - device=val.device, - dtype=torch.bool) - new_exp_close = torch.zeros(val.shape, - device=val.device, - dtype=torch.uint16) + new_mantissa_close = torch.zeros(val.shape, device=val.device, dtype=torch.bool) + new_exp_close = torch.zeros(val.shape, device=val.device, dtype=torch.uint16) - new_mantissa_away = torch.zeros(val.shape, - device=val.device, - dtype=torch.bool) - new_exp_away = torch.zeros(val.shape, - device=val.device, - dtype=torch.uint16) + new_mantissa_away = torch.zeros(val.shape, device=val.device, dtype=torch.bool) + new_exp_away = torch.zeros(val.shape, device=val.device, dtype=torch.uint16) new_exp_tie = torch.zeros(val.shape, device=val.device, dtype=torch.uint16) @@ -202,27 +202,29 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, new_exp_tie = (exp > (half_exp_bias - 2)) * (exp + (mantissa_last == 1)) # Gather round up, round down and tie. 
- new_exp = round_away * new_exp_away \ - + round_close * new_exp_close \ - + tie * new_exp_tie + new_exp = ( + round_away * new_exp_away + round_close * new_exp_close + tie * new_exp_tie + ) - new_mantissa = round_away * new_mantissa_away \ - + round_close * new_mantissa_close + new_mantissa = round_away * new_mantissa_away + round_close * new_mantissa_close # if new_exp > 3: # new_mantissa = 1 - new_mantissa = new_mantissa + (new_exp > - (2 + half_exp_bias)) * (new_mantissa == 0) + new_mantissa = new_mantissa + (new_exp > (2 + half_exp_bias)) * (new_mantissa == 0) # Clamp the exponent to acceptable values. new_exp = (new_exp >= (half_exp_bias - 2)) * torch.clamp( - new_exp, half_exp_bias - 2, half_exp_bias + 2) + new_exp, half_exp_bias - 2, half_exp_bias + 2 + ) sign = sign.to(torch.int32) new_mantissa = new_mantissa.to(torch.int32) - qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + ( - new_mantissa << (half_mantissa_bits - 1)) + qdq_val = ( + (sign << 15) + + (new_exp << half_mantissa_bits) + + (new_mantissa << (half_mantissa_bits - 1)) + ) assert qdq_val.max() <= 65535 assert qdq_val.min() >= 0 @@ -233,8 +235,9 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, return result -def qdq_mxfp4_torch(x: torch.Tensor, - scale_calculation_mode: str = "even") -> torch.Tensor: +def qdq_mxfp4_torch( + x: torch.Tensor, scale_calculation_mode: str = "even" +) -> torch.Tensor: half_dtype = x.dtype if half_dtype == torch.float16: @@ -258,8 +261,7 @@ def qdq_mxfp4_torch(x: torch.Tensor, block_max = block_max.view(torch.uint16).to(torch.int32) - block_max_uint = torch.bitwise_and(block_max + val_to_add, - sign_exponent_mask) + block_max_uint = torch.bitwise_and(block_max + val_to_add, sign_exponent_mask) assert block_max_uint.max() <= 65535 assert block_max_uint.min() >= 0 @@ -268,20 +270,23 @@ def qdq_mxfp4_torch(x: torch.Tensor, block_max = block_max_uint.view(half_dtype) - scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to( - torch.int32) - 2 + scale_exp = ( + FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to(torch.int32) - 2 + ) scale_exp = torch.clamp(scale_exp, 0, 2 * FLOAT8_E8M0_MAX_EXP) - scale = 2.0**(scale_exp - FLOAT8_E8M0_MAX_EXP) + scale = 2.0 ** (scale_exp - FLOAT8_E8M0_MAX_EXP) scale = scale.to(half_dtype) x = x / scale[..., None] - x_fp4 = fp16_to_fp4_simulate(x, - half_exp_bits=half_exp_bits, - half_mantissa_bits=half_mantissa_bits, - half_exp_bias=half_exp_bias) + x_fp4 = fp16_to_fp4_simulate( + x, + half_exp_bits=half_exp_bits, + half_mantissa_bits=half_mantissa_bits, + half_exp_bias=half_exp_bias, + ) x_fp4 = x_fp4 * scale[..., None] return x_fp4.reshape(*x_fp4.shape[:-2], -1) diff --git a/tests/quantization/test_auto_round.py b/tests/quantization/test_auto_round.py index 1c41d904b816..69632ae6cac7 100644 --- a/tests/quantization/test_auto_round.py +++ b/tests/quantization/test_auto_round.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Test model set-up and inference for quantized HF models supported - on the AutoRound. +on the AutoRound. - Validating the configuration and printing results for manual checking. +Validating the configuration and printing results for manual checking. - Run `pytest tests/quantization/test_auto_round.py`. +Run `pytest tests/quantization/test_auto_round.py`. 
""" import pytest @@ -14,18 +14,19 @@ MODELS = [ "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc", ##auto_round:auto_gptq - "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound" ##auto_round:auto_awq + "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound", ##auto_round:auto_awq ] -@pytest.mark.skipif(not current_platform.is_cpu() - and not current_platform.is_xpu() - and not current_platform.is_cuda(), - reason="only supports CPU/XPU/CUDA backend.") +@pytest.mark.skipif( + not current_platform.is_cpu() + and not current_platform.is_xpu() + and not current_platform.is_cuda(), + reason="only supports CPU/XPU/CUDA backend.", +) @pytest.mark.parametrize("model", MODELS) def test_auto_round(vllm_runner, model): with vllm_runner(model) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=8) + output = llm.generate_greedy(["The capital of France is"], max_tokens=8) assert output print(f"{output[0][1]}") diff --git a/tests/quantization/test_blackwell_moe.py b/tests/quantization/test_blackwell_moe.py index 906693a1f401..218763bc627d 100644 --- a/tests/quantization/test_blackwell_moe.py +++ b/tests/quantization/test_blackwell_moe.py @@ -11,8 +11,9 @@ from vllm.platforms import current_platform if not current_platform.is_device_capability(100): - pytest.skip("This test only runs on Blackwell GPUs (SM100).", - allow_module_level=True) + pytest.skip( + "This test only runs on Blackwell GPUs (SM100).", allow_module_level=True + ) os.environ["FLASHINFER_NVCC_THREADS"] = "16" @@ -22,7 +23,6 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None): - # Server arguments extra_args = extra_args if extra_args is not None else [] server_args = [ @@ -40,10 +40,11 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None): # Launch server and make a simple request with RemoteOpenAIServer( - model, - server_args, - max_wait_seconds=1000, # Due to FlashInfer compile - override_hf_configs=dummy_hf_overrides) as server: + model, + server_args, + max_wait_seconds=1000, # Due to FlashInfer compile + override_hf_configs=dummy_hf_overrides, + ) as server: client = server.get_client() # Make a simple request to verify the server works completion = client.completions.create( @@ -59,20 +60,21 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None): ## Llama4 ## -@pytest.mark.skip(reason=( - "RuntimeError: run_moe() Expected a value of type " - "'Optional[List[Tensor]]' for argument '_9' but instead found type " - "'list'.")) -def test_llama4_fp8_tensor_moe_flashinfer_cutlass( - monkeypatch: pytest.MonkeyPatch): +@pytest.mark.skip( + reason=( + "RuntimeError: run_moe() Expected a value of type " + "'Optional[List[Tensor]]' for argument '_9' but instead found type " + "'list'." 
+ ) +) +def test_llama4_fp8_tensor_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8") @pytest.mark.skip(reason="Works, but takes too long to run") -def test_llama4_fp8_tensor_moe_flashinfer_trtllm( - monkeypatch: pytest.MonkeyPatch): +def test_llama4_fp8_tensor_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8") @@ -100,24 +102,25 @@ def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch): can_initialize("deepseek-ai/DeepSeek-V3.1") -@pytest.mark.skip(reason=("Known issue: lack of kernel support. " - "Expected failure: assert self.block_quant is None")) -def test_deepseek_fp8_block_moe_flashinfer_cutlass( - monkeypatch: pytest.MonkeyPatch): +@pytest.mark.skip( + reason=( + "Known issue: lack of kernel support. " + "Expected failure: assert self.block_quant is None" + ) +) +def test_deepseek_fp8_block_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") can_initialize("deepseek-ai/DeepSeek-V3.1") -def test_deepseek_fp8_block_moe_flashinfer_trtllm( - monkeypatch: pytest.MonkeyPatch): +def test_deepseek_fp8_block_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency") can_initialize("deepseek-ai/DeepSeek-V3.1") -def test_deepseek_nvfp4_moe_flashinfer_cutlass( - monkeypatch: pytest.MonkeyPatch): +def test_deepseek_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1") monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput") can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2") @@ -138,13 +141,11 @@ def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch): can_initialize("openai/gpt-oss-20b") -def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass( - monkeypatch: pytest.MonkeyPatch): +def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1") can_initialize("openai/gpt-oss-20b") -def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm( - monkeypatch: pytest.MonkeyPatch): +def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1") can_initialize("openai/gpt-oss-20b") diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index af8c7ec3b482..824d927724e0 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -13,18 +13,25 @@ from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensors24, CompressedTensorsLinearMethod, - CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8, - CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, - CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, - CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) + CompressedTensors24, + CompressedTensorsLinearMethod, + 
CompressedTensorsW4A4Fp4, + CompressedTensorsW4A8Fp8, + CompressedTensorsW4A16Fp4, + CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16, +) from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp) +from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported) + cutlass_fp4_supported, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - sparse_cutlass_supported) + sparse_cutlass_supported, +) from vllm.platforms import current_platform # AITER only supports per-channel-per-channel INT8 gemm @@ -32,7 +39,7 @@ # It does not support mix precision MM and mix quantization scheme. ROCM_AITER_SUPPORTED_INT8_MODEL = [ "neuralmagic/Llama-3.2-1B-quantized.w8a8", - "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2" + "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", ] # TritonScaledMMLinearKernel only supports symmetric quantization. @@ -80,8 +87,10 @@ def enable_pickle(monkeypatch): def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args): model_path, strategy, quant_type, shape_0, is_symmetric = model_args - if current_platform.is_rocm( - ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: + if ( + current_platform.is_rocm() + and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL + ): pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") with vllm_runner(model_path, enforce_eager=True) as llm: @@ -106,14 +115,10 @@ def zp_valid(zp: Optional[torch.Tensor]): assert zp_valid(gate_up_proj.input_zero_point) assert zp_valid(down_proj.input_zero_point) - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) - assert isinstance(o_proj.quant_method, - CompressedTensorsLinearMethod) - assert isinstance(gate_up_proj.quant_method, - CompressedTensorsLinearMethod) - assert isinstance(down_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8) assert qkv_proj.scheme.strategy == strategy @@ -151,7 +156,8 @@ def zp_valid(zp: Optional[torch.Tensor]): @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [10]) @pytest.mark.parametrize( - "use_aiter", [True, False] if current_platform.is_rocm() else [False]) + "use_aiter", [True, False] if current_platform.is_rocm() else [False] +) def test_compressed_tensors_w8a8_logprobs( hf_runner, vllm_runner, @@ -162,15 +168,15 @@ def test_compressed_tensors_w8a8_logprobs( use_aiter, monkeypatch, ): - - if current_platform.is_rocm( - ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: + if ( + current_platform.is_rocm() + and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL + ): pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") if use_aiter: if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL: - pytest.skip( - f"Skip model {model_path} as it is not support by aiter.") + pytest.skip(f"Skip model 
{model_path} as it is not support by aiter.") # this will enable VLLM_ROCM_USE_AITER_LINEAR monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -178,18 +184,20 @@ def test_compressed_tensors_w8a8_logprobs( # skip language translation prompt for the static per tensor models if model_path in ( - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", - "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym", + "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym", ): example_prompts = example_prompts[0:-1] with hf_runner(model_path, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) with vllm_runner(model_path, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts, max_tokens, num_logprobs + ) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -225,7 +233,8 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner): ], ) @pytest.mark.parametrize( - "use_aiter", [True, False] if current_platform.is_rocm() else [False]) + "use_aiter", [True, False] if current_platform.is_rocm() else [False] +) def test_compressed_tensors_w8a8_dynamic_per_token( vllm_runner, model_args, @@ -234,14 +243,15 @@ def test_compressed_tensors_w8a8_dynamic_per_token( ): model_path, strategy = model_args - if current_platform.is_rocm( - ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL: + if ( + current_platform.is_rocm() + and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL + ): pytest.skip(f"Skip model {model_path} as it is not support on ROCm.") if use_aiter: if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL: - pytest.skip( - f"Skip model {model_path} as it is not support by aiter.") + pytest.skip(f"Skip model {model_path} as it is not support by aiter.") # this will enable VLLM_ROCM_USE_AITER_LINEAR monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -252,8 +262,7 @@ def check_model(model): qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8) assert not qkv_proj.scheme.is_static_input_scheme assert qkv_proj.scheme.strategy == strategy @@ -267,21 +276,60 @@ def check_model(model): @pytest.mark.parametrize( "wNa16_args", - [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8, - True, False), - ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8, True, - False), - ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4, - True, False), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", "group", 128, - 8, False, False), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel", - "channel", None, 8, False, False), - ("nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder", - "group", 128, 8, False, True)], + [ + ( + "nm-testing/tinyllama-oneshot-w4a16-channel-v2", + "channel", + None, + 8, + True, + False, + ), + ( + "nm-testing/tinyllama-oneshot-w4a16-group128-v2", + "group", + 128, + 8, + True, + False, + ), + ( + "nm-testing/tinyllama-oneshot-w8a16-per-channel", + "channel", + None, + 4, + True, + False, + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-awq-group128-asym256", + "group", + 128, + 8, + 
False, + False, + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-Channel", + "channel", + None, + 8, + False, + False, + ), + ( + "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder", + "group", + 128, + 8, + False, + True, + ), + ], +) +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="The tests are skipped on non-CUDA platform." ) -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="The tests are skipped on non-CUDA platform.") def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args with vllm_runner(model) as llm: @@ -290,13 +338,11 @@ def check_model(model): layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16) assert qkv_proj.scheme.strategy == strategy - assert qkv_proj.scheme.group_size == (-1 - if group is None else group) + assert qkv_proj.scheme.group_size == (-1 if group is None else group) assert qkv_proj.scheme.pack_factor == pack_factor assert qkv_proj.scheme.symmetric == symmetric @@ -308,8 +354,9 @@ def check_model(model): assert output -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test is skipped on non-CUDA platform.") +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) def test_compressed_tensors_w4a16_marlin24(vllm_runner): model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t" with vllm_runner(model_path) as llm: @@ -319,8 +366,7 @@ def check_model(model): qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24) assert qkv_proj.weight_packed.dtype is torch.int32 @@ -339,8 +385,7 @@ def check_model(model): qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance( qkv_proj.scheme, (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8), @@ -362,9 +407,11 @@ def check_model(model): @pytest.mark.skipif( not current_platform.is_kv_cache_dtype_supported("fp8", None), - reason="FP8 KV cache is not supported on this device.") -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test is skipped on non-CUDA platform.") + reason="FP8 KV cache is not supported on this device.", +) +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." 
+) def test_compressed_tensors_kv_cache(vllm_runner): model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: @@ -376,10 +423,7 @@ def test_compressed_tensors_kv_cache(vllm_runner): not sparse_cutlass_supported(), reason="Sparse FP8 is not yet supported on this GPU type.", ) -def _test_2of4_quant_models(qkv_proj, - weight_strategy, - input_strategy, - format="dense"): +def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy, format="dense"): assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensors24) @@ -393,8 +437,7 @@ def _test_2of4_quant_models(qkv_proj, @pytest.mark.skipif( - not current_platform.is_cuda() - or not current_platform.has_device_capability(90), + not current_platform.is_cuda() or not current_platform.has_device_capability(90), reason="Sparse FP8 is not yet supported on this GPU type.", ) @pytest.mark.parametrize( @@ -441,8 +484,7 @@ def check_model(model): @pytest.mark.skipif( - not current_platform.is_cuda() - or not current_platform.has_device_capability(90), + not current_platform.is_cuda() or not current_platform.has_device_capability(90), reason="Sparse FP8 is not yet supported on this GPU type.", ) @pytest.mark.parametrize( @@ -603,17 +645,14 @@ def check_model(model): layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensors24) assert qkv_proj.scheme.weight_quant is None assert qkv_proj.scheme.input_quant is None assert not qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map - sparsity_map = ( - qkv_proj.quant_method.quantization_config.sparsity_scheme_map - ) # noqa: E501 + sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 assert sparsity_map.get("Linear").format == "dense" assert sparsity_map.get("Linear").sparsity_structure == "2:4" @@ -629,7 +668,8 @@ def check_model(model): reason="Cutlass is not yet supported on this GPU type.", ) @pytest.mark.parametrize( - "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]) + "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")] +) def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4): model = args_2of4 with vllm_runner(model) as llm: @@ -638,17 +678,14 @@ def check_model(model): layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensors24) assert qkv_proj.scheme.weight_quant is None assert qkv_proj.scheme.input_quant is None assert not qkv_proj.scheme.quantized assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map - sparsity_map = ( - qkv_proj.quant_method.quantization_config.sparsity_scheme_map - ) # noqa: E501 + sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501 assert sparsity_map.get("Linear").format == "sparse-24-bitmask" assert sparsity_map.get("Linear").sparsity_structure == "2:4" @@ -661,9 +698,11 @@ def check_model(model): @pytest.mark.parametrize( "args", - [("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", - CompressedTensorsW4A16Fp4), - 
("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4)]) + [ + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4), + ], +) def test_compressed_tensors_nvfp4(vllm_runner, args): model, scheme = args with vllm_runner(model, enforce_eager=True) as llm: @@ -672,11 +711,12 @@ def check_model(model): layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) - if isinstance(qkv_proj.scheme, scheme) or isinstance( - qkv_proj.scheme, - CompressedTensorsW4A16Fp4) and not cutlass_fp4_supported(): + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + if ( + isinstance(qkv_proj.scheme, scheme) + or isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4) + and not cutlass_fp4_supported() + ): assert True else: raise AssertionError("FP4 Scheme Mismatch") @@ -690,13 +730,13 @@ def check_model(model): @pytest.mark.skipif( - not current_platform.is_cuda() - or not current_platform.has_device_capability(90), + not current_platform.is_cuda() or not current_platform.has_device_capability(90), reason="W4A8 FP8 is not yet supported on this GPU type.", ) -@pytest.mark.parametrize("args", [ - ("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8) -]) +@pytest.mark.parametrize( + "args", + [("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8)], +) def test_compressed_tensors_w4a8_fp8(vllm_runner, args): model, scheme = args with vllm_runner(model, enforce_eager=True) as llm: @@ -710,8 +750,7 @@ def check_model(model): down_proj = layer.mlp.down_proj for proj in (qkv_proj, o_proj, gate_up_proj, down_proj): - assert isinstance(proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(proj.scheme, scheme) assert proj.weight_packed.dtype is torch.int32 @@ -725,22 +764,27 @@ def check_model(model): assert output -@pytest.mark.skipif(not current_platform.is_cuda(), - reason="This test is skipped on non-CUDA platform.") -@pytest.mark.parametrize("model,prompt,exp_perplexity", [ - ( - "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", - "Flat is better than nested.\nSparse is better than dense.", - 150.0, - ), - ( - "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", - "Flat is better than nested.\nSparse is better than dense.", - 150.0, - ), -]) -def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, - exp_perplexity): +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." 
+) +@pytest.mark.parametrize( + "model,prompt,exp_perplexity", + [ + ( + "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), + ( + "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", + "Flat is better than nested.\nSparse is better than dense.", + 150.0, + ), + ], +) +def test_compressed_tensors_transforms_perplexity( + vllm_runner, model, prompt, exp_perplexity +): with vllm_runner(model, enforce_eager=True) as llm: perplexity = llm.generate_prompt_perplexity([prompt])[0] print(perplexity) @@ -750,26 +794,24 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt, def test_compressed_tensors_fp8_block_enabled(vllm_runner): model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK" with vllm_runner(model_path) as llm: - fp8_dtype = current_platform.fp8_dtype() def check_model(model): layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj - assert isinstance(qkv_proj.quant_method, - CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8) - assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear, - W8A8BlockFp8LinearOp) + assert isinstance( + qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp + ) assert qkv_proj.weight.dtype is fp8_dtype assert qkv_proj.weight_scale.dtype is torch.float32 assert len(qkv_proj.weight.shape) == 2 assert len(qkv_proj.weight_scale.shape) == 2 - input_quant_op = \ - qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op + input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op assert isinstance(input_quant_op, QuantFP8) assert input_quant_op._forward_method == input_quant_op.forward_cuda diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index 1843bffd2115..797b565b91af 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -33,7 +33,6 @@ class ModelPair: ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"), ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"), ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"), - # AUTOAWQ ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq_marlin"), ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"), @@ -55,4 +54,5 @@ def test_auto_gptq(model_arg_exptype: tuple[str, None, str]) -> None: assert found_quantization_type == expected_type, ( f"Expected quant_type == {expected_type} for {model_path}, " f"but found {found_quantization_type} " - f"for no --quantization {quantization_arg} case") + f"for no --quantization {quantization_arg} case" + ) diff --git a/tests/quantization/test_cpu_offload.py b/tests/quantization/test_cpu_offload.py index 08d9573ecf0b..25d1dc59f617 100644 --- a/tests/quantization/test_cpu_offload.py +++ b/tests/quantization/test_cpu_offload.py @@ -1,77 +1,108 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Expanded quantized model tests for CPU offloading -# Base tests: tests/basic_correctness/test_cpu_offload.py - -import pytest - -from tests.quantization.utils import is_quant_method_supported - -from ..utils import compare_two_settings - - -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="fp8 is not supported on this GPU type.") -def test_cpu_offload_fp8(): - # Test quantization of an unquantized checkpoint - compare_two_settings("meta-llama/Llama-3.2-1B-Instruct", - ["--quantization", 
"fp8"], - ["--quantization", "fp8", "--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test loading a quantized checkpoint - compare_two_settings("neuralmagic/Qwen2-1.5B-Instruct-FP8", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - - -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="gptq_marlin is not supported on this GPU type.") -def test_cpu_offload_gptq(monkeypatch): - # This quant method is sensitive to dummy weights, so we force real weights - monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto') - # Test GPTQ Marlin - compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test GPTQ - compare_two_settings("Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", - ["--quantization", "gptq"], - ["--quantization", "gptq", "--cpu-offload-gb", "1"], - max_wait_seconds=480) - - -@pytest.mark.skipif(not is_quant_method_supported("awq_marlin"), - reason="awq_marlin is not supported on this GPU type.") -def test_cpu_offload_awq(monkeypatch): - # This quant method is sensitive to dummy weights, so we force real weights - monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto') - # Test AWQ Marlin - compare_two_settings("Qwen/Qwen2-1.5B-Instruct-AWQ", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test AWQ - compare_two_settings("Qwen/Qwen2-1.5B-Instruct-AWQ", - ["--quantization", "awq"], - ["--quantization", "awq", "--cpu-offload-gb", "1"], - max_wait_seconds=480) - - -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="gptq_marlin is not supported on this GPU type.") -def test_cpu_offload_compressed_tensors(monkeypatch): - # This quant method is sensitive to dummy weights, so we force real weights - monkeypatch.setenv('VLLM_TEST_FORCE_LOAD_FORMAT', 'auto') - # Test wNa16 - compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test w4a16_marlin24 - compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", - [], ["--cpu-offload-gb", "1"], - max_wait_seconds=480) - # Test w8a8 - compare_two_settings( - "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [], - ["--cpu-offload-gb", "1"], - max_wait_seconds=480) +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Expanded quantized model tests for CPU offloading +# Base tests: tests/basic_correctness/test_cpu_offload.py + +import pytest + +from tests.quantization.utils import is_quant_method_supported + +from ..utils import compare_two_settings + + +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="fp8 is not supported on this GPU type.", +) +def test_cpu_offload_fp8(): + # Test quantization of an unquantized checkpoint + compare_two_settings( + "meta-llama/Llama-3.2-1B-Instruct", + ["--quantization", "fp8"], + ["--quantization", "fp8", "--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + # Test loading a quantized checkpoint + compare_two_settings( + "neuralmagic/Qwen2-1.5B-Instruct-FP8", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="gptq_marlin is not supported on this GPU type.", +) +def test_cpu_offload_gptq(monkeypatch): + # This quant method is sensitive to dummy weights, so we force real weights + monkeypatch.setenv("VLLM_TEST_FORCE_LOAD_FORMAT", "auto") + # Test GPTQ Marlin + compare_two_settings( + "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", + [], + ["--cpu-offload-gb", 
"1"], + max_wait_seconds=480, + ) + # Test GPTQ + compare_two_settings( + "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4", + ["--quantization", "gptq"], + ["--quantization", "gptq", "--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("awq_marlin"), + reason="awq_marlin is not supported on this GPU type.", +) +def test_cpu_offload_awq(monkeypatch): + # This quant method is sensitive to dummy weights, so we force real weights + monkeypatch.setenv("VLLM_TEST_FORCE_LOAD_FORMAT", "auto") + # Test AWQ Marlin + compare_two_settings( + "Qwen/Qwen2-1.5B-Instruct-AWQ", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + # Test AWQ + compare_two_settings( + "Qwen/Qwen2-1.5B-Instruct-AWQ", + ["--quantization", "awq"], + ["--quantization", "awq", "--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + + +@pytest.mark.skipif( + not is_quant_method_supported("gptq_marlin"), + reason="gptq_marlin is not supported on this GPU type.", +) +def test_cpu_offload_compressed_tensors(monkeypatch): + # This quant method is sensitive to dummy weights, so we force real weights + monkeypatch.setenv("VLLM_TEST_FORCE_LOAD_FORMAT", "auto") + # Test wNa16 + compare_two_settings( + "nm-testing/tinyllama-oneshot-w4a16-channel-v2", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + # Test w4a16_marlin24 + compare_two_settings( + "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) + # Test w8a8 + compare_two_settings( + "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", + [], + ["--cpu-offload-gb", "1"], + max_wait_seconds=480, + ) diff --git a/tests/quantization/test_experts_int8.py b/tests/quantization/test_experts_int8.py index 1e3e69e008bd..2a72f734e431 100644 --- a/tests/quantization/test_experts_int8.py +++ b/tests/quantization/test_experts_int8.py @@ -2,9 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # flake8: noqa -"""Tests experts_int8 quantization startup and generation, +"""Tests experts_int8 quantization startup and generation, doesn't test correctness """ + import pytest from tests.quantization.utils import is_quant_method_supported @@ -14,8 +15,10 @@ MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"] -@pytest.mark.skipif(not is_quant_method_supported("experts_int8"), - reason="ExpertsInt8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("experts_int8"), + reason="ExpertsInt8 is not supported on this GPU type.", +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [10]) @@ -30,6 +33,5 @@ def test_model_experts_int8_startup( model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_transformers_version(on_fail="skip") - with vllm_runner(model, dtype=dtype, - quantization="experts_int8") as vllm_model: + with vllm_runner(model, dtype=dtype, quantization="experts_int8") as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index db53061cf2d1..6b9a33059815 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -4,13 +4,16 @@ Run `pytest tests/quantization/test_fp8.py --forked`. 
""" + import pytest import torch from tests.quantization.utils import is_quant_method_supported from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod, - Fp8LinearMethod) +from vllm.model_executor.layers.quantization.fp8 import ( + Fp8KVCacheMethod, + Fp8LinearMethod, +) from vllm.platforms import current_platform MODELS = [ @@ -20,15 +23,18 @@ ] -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_id", MODELS) @pytest.mark.parametrize("force_marlin", [False, True]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, - use_rocm_aiter: bool, monkeypatch) -> None: - + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_model_load_and_run( + vllm_runner, model_id: str, force_marlin: bool, use_rocm_aiter: bool, monkeypatch +) -> None: if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -50,13 +56,17 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, ] -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("model_id", KV_CACHE_MODELS) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, - use_rocm_aiter: bool, monkeypatch): + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_kv_cache_model_load_and_run( + vllm_runner, model_id: str, use_rocm_aiter: bool, monkeypatch +): if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -93,14 +103,22 @@ def check_model(model): print(outputs[0][1]) -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) @pytest.mark.parametrize("force_marlin", [False, True]) @pytest.mark.parametrize( - "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) -def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, - use_rocm_aiter: bool, monkeypatch) -> None: + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_load_fp16_model( + vllm_runner, + kv_cache_dtype: str, + force_marlin: bool, + use_rocm_aiter: bool, + monkeypatch, +) -> None: if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -110,9 +128,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, if force_marlin: monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1") - with vllm_runner("facebook/opt-125m", - quantization="fp8", - kv_cache_dtype=kv_cache_dtype) as llm: + with vllm_runner( + "facebook/opt-125m", quantization="fp8", kv_cache_dtype=kv_cache_dtype + ) as llm: def check_model(model): fc1 = model.model.decoder.layers[0].fc1 @@ -139,26 +157,29 @@ def check_model(model): pytest.skip( "Skip `test_load_fp16_model`. 
" "It only runs on ROCm platform with FP8 compute." - " e.g. MI300X and above.") + " e.g. MI300X and above." + ) else: # unsupported platform - pytest.skip("Skip `test_load_fp16_model`. " - "It only runs on CUDA and ROCm platform.") + pytest.skip( + "Skip `test_load_fp16_model`. " + "It only runs on CUDA and ROCm platform." + ) llm.apply_model(check_model) -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("fp8"), + reason="FP8 is not supported on this GPU type.", +) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_scaled_fp8_quant(dtype) -> None: - def quantize_ref(tensor, inv_scale): # The reference implementation that fully aligns to # the kernel being tested. finfo = torch.finfo(torch.float8_e4m3fn) scale = inv_scale.reciprocal() - qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, - max=finfo.max) + qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max) qweight = qweight.to(torch.float8_e4m3fn) return qweight @@ -177,26 +198,23 @@ def per_tensor_dequantize(tensor, inv_scale, dtype): # Reference dynamic quantizaton y = quantize_ref(x, inv_scale) - torch.testing.assert_close(ref_y, - per_tensor_dequantize(y, inv_scale, dtype)) + torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype)) # Static quantization y, _ = ops.scaled_fp8_quant(x, inv_scale) - torch.testing.assert_close(ref_y, - per_tensor_dequantize(y, inv_scale, dtype)) + torch.testing.assert_close(ref_y, per_tensor_dequantize(y, inv_scale, dtype)) # Padding y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17) assert y.shape[0] == 17 torch.testing.assert_close( ref_y, - per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, - dtype)) + per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, dtype), + ) # non-contiguous input with padding m, n, padded_stride = 975, 512, 576 - padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * - 13).to(dtype) + padded_tensor = (torch.randn(size=(m, padded_stride), device="cuda") * 13).to(dtype) x_nc = padded_tensor[:, :n] # shape (m, n) with stride (padded_stride, 1) assert not x_nc.is_contiguous() @@ -209,19 +227,21 @@ def per_tensor_dequantize(tensor, inv_scale, dtype): # reference dynamic quantization y_nc = quantize_ref(x_nc, inv_scale_nc) torch.testing.assert_close( - ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)) + ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype) + ) # static quantization y_nc, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc) torch.testing.assert_close( - ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype)) + ref_y_nc, per_tensor_dequantize(y_nc, inv_scale_nc, dtype) + ) # padding after non-contiguous input quantization - y_nc_pad, _ = ops.scaled_fp8_quant(x_nc, - inv_scale_nc, - num_token_padding=m + 10) + y_nc_pad, _ = ops.scaled_fp8_quant(x_nc, inv_scale_nc, num_token_padding=m + 10) assert y_nc_pad.shape[0] == m + 10 torch.testing.assert_close( ref_y_nc, - per_tensor_dequantize(torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), - inv_scale_nc, dtype)) + per_tensor_dequantize( + torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype + ), + ) diff --git a/tests/quantization/test_gptq_dynamic.py b/tests/quantization/test_gptq_dynamic.py index 00a5946ed015..c71f4b815611 100644 --- a/tests/quantization/test_gptq_dynamic.py +++ b/tests/quantization/test_gptq_dynamic.py @@ -10,10 +10,10 @@ from 
vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinLinearMethod) +from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinLinearMethod from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_dynamic_override) + get_dynamic_override, +) PROMPT = "On the surface of Mars, we found" @@ -21,56 +21,59 @@ # The second layer is quantized using bits=8, group_size=32 # All other layers (layer index >= 2) are not quantized MODEL_QUANT = [ - ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", - True), - ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", - False), + ("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue", True), + ( + "ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse", + False, + ), ] @pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT) -def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool, - monkeypatch): +def test_gptq_with_dynamic( + vllm_runner, model_id: str, use_marlin_kernel: bool, monkeypatch +): # `LLM.apply_model` requires pickling a function. monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else ( - GPTQLinearMethod) + linear_method_cls = ( + GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod) + ) with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as llm: def check_model(model): for name, submodule in model.named_modules(): if name == "lm_head": - assert isinstance(submodule.quant_method, - linear_method_cls) - elif name == 'model.layers.0.self_attn.qkv_proj': + assert isinstance(submodule.quant_method, linear_method_cls) + elif name == "model.layers.0.self_attn.qkv_proj": # The first layer is quantized using bits=4, group_size=128 # desc_act=True - assert isinstance(submodule.quant_method, - linear_method_cls) + assert isinstance(submodule.quant_method, linear_method_cls) config = submodule.quant_method.quant_config assert config.weight_bits == 4 assert config.group_size == 128 assert config.desc_act - elif name == 'model.layers.1.self_attn.qkv_proj': + elif name == "model.layers.1.self_attn.qkv_proj": # The second layer is quantized using bits=8, group_size=32 # desc_act=False - assert isinstance(submodule.quant_method, - linear_method_cls) + assert isinstance(submodule.quant_method, linear_method_cls) config = submodule.quant_method.quant_config - assert get_dynamic_override(config, - layer_name=name, - key="bits") == 8 - assert get_dynamic_override(config, - layer_name=name, - key="group_size") == 32 + assert ( + get_dynamic_override(config, layer_name=name, key="bits") == 8 + ) + assert ( + get_dynamic_override(config, layer_name=name, key="group_size") + == 32 + ) assert not get_dynamic_override( - config, layer_name=name, key="desc_act") - elif (name == 'model.layers.2.self_attn.qkv_proj' - or name == 'model.layers.2.mlp.gate_up_proj'): + config, layer_name=name, key="desc_act" + ) + elif ( + name == "model.layers.2.self_attn.qkv_proj" + or name == "model.layers.2.mlp.gate_up_proj" + ): # All other layers (layer index >= 2) are not quantized - assert isinstance(submodule.quant_method, - UnquantizedLinearMethod) + assert isinstance(submodule.quant_method, UnquantizedLinearMethod) llm.apply_model(check_model) diff --git 
a/tests/quantization/test_ipex_quant.py b/tests/quantization/test_ipex_quant.py index 34b1b6c2e5b6..ae9b1df3377d 100644 --- a/tests/quantization/test_ipex_quant.py +++ b/tests/quantization/test_ipex_quant.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Test model set-up and inference for quantized HF models supported - on the CPU/GPU backend using IPEX (including AWQ/GPTQ). - - Validating the configuration and printing results for manual checking. +on the CPU/GPU backend using IPEX (including AWQ/GPTQ). - Run `pytest tests/quantization/test_ipex_quant.py`. +Validating the configuration and printing results for manual checking. + +Run `pytest tests/quantization/test_ipex_quant.py`. """ import pytest @@ -19,14 +19,14 @@ DTYPE = ["bfloat16"] -@pytest.mark.skipif(not current_platform.is_cpu() - and not current_platform.is_xpu(), - reason="only supports Intel CPU/XPU backend.") +@pytest.mark.skipif( + not current_platform.is_cpu() and not current_platform.is_xpu(), + reason="only supports Intel CPU/XPU backend.", +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", DTYPE) def test_ipex_quant(vllm_runner, model, dtype): with vllm_runner(model, dtype=dtype) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output print(output) diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index e69d4ad349c3..bae8b7f7d535 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -9,10 +9,10 @@ import torch from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinLinearMethod) +from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinLinearMethod from vllm.model_executor.layers.vocab_parallel_embedding import ( - UnquantizedEmbeddingMethod) + UnquantizedEmbeddingMethod, +) PROMPT = "On the surface of Mars, we found" @@ -31,20 +31,20 @@ def test_lm_head( ) -> None: # `LLM.apply_model` requires pickling a function. 
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - with vllm_runner(model_id, dtype=torch.float16, - max_model_len=2048) as vllm_model: + with vllm_runner(model_id, dtype=torch.float16, max_model_len=2048) as vllm_model: def check_model(model): lm_head_layer = model.lm_head if lm_head_quantized: - assert isinstance(lm_head_layer.quant_method, - (GPTQLinearMethod, GPTQMarlinLinearMethod)) + assert isinstance( + lm_head_layer.quant_method, + (GPTQLinearMethod, GPTQMarlinLinearMethod), + ) else: - assert isinstance(lm_head_layer.quant_method, - UnquantizedEmbeddingMethod) + assert isinstance( + lm_head_layer.quant_method, UnquantizedEmbeddingMethod + ) vllm_model.apply_model(check_model) - print( - vllm_model.generate_greedy(["Hello my name is"], - max_tokens=10)[0][1]) + print(vllm_model.generate_greedy(["Hello my name is"], max_tokens=10)[0][1]) diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py index e7174be73626..8abf65d29784 100644 --- a/tests/quantization/test_modelopt.py +++ b/tests/quantization/test_modelopt.py @@ -19,21 +19,26 @@ def enable_pickle(monkeypatch): monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") -@pytest.mark.skipif(not is_quant_method_supported("modelopt"), - reason="ModelOpt FP8 is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("modelopt"), + reason="ModelOpt FP8 is not supported on this GPU type.", +) def test_modelopt_fp8_checkpoint_setup(vllm_runner): """Test ModelOpt FP8 checkpoint loading and structure validation.""" # TODO: provide a small publicly available test checkpoint - model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/" - "TinyLlama-1.1B-Chat-v1.0-fp8-0710") + model_path = ( + "/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/" + "TinyLlama-1.1B-Chat-v1.0-fp8-0710" + ) # Skip test if checkpoint doesn't exist if not os.path.exists(model_path): - pytest.skip(f"Test checkpoint not found at {model_path}. " - "This test requires a local ModelOpt FP8 checkpoint.") + pytest.skip( + f"Test checkpoint not found at {model_path}. " + "This test requires a local ModelOpt FP8 checkpoint." 
+ ) - with vllm_runner(model_path, quantization="modelopt", - enforce_eager=True) as llm: + with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm: def check_model(model): layer = model.model.layers[0] @@ -45,11 +50,12 @@ def check_model(model): # Check that ModelOpt quantization method is properly applied from vllm.model_executor.layers.quantization.modelopt import ( - ModelOptFp8LinearMethod) + ModelOptFp8LinearMethod, + ) + assert isinstance(qkv_proj.quant_method, ModelOptFp8LinearMethod) assert isinstance(o_proj.quant_method, ModelOptFp8LinearMethod) - assert isinstance(gate_up_proj.quant_method, - ModelOptFp8LinearMethod) + assert isinstance(gate_up_proj.quant_method, ModelOptFp8LinearMethod) assert isinstance(down_proj.quant_method, ModelOptFp8LinearMethod) # Check weight dtype is FP8 @@ -59,23 +65,23 @@ def check_model(model): assert down_proj.weight.dtype == torch.float8_e4m3fn # Check scales are present and have correct dtype - assert hasattr(qkv_proj, 'weight_scale') - assert hasattr(qkv_proj, 'input_scale') + assert hasattr(qkv_proj, "weight_scale") + assert hasattr(qkv_proj, "input_scale") assert qkv_proj.weight_scale.dtype == torch.float32 assert qkv_proj.input_scale.dtype == torch.float32 - assert hasattr(o_proj, 'weight_scale') - assert hasattr(o_proj, 'input_scale') + assert hasattr(o_proj, "weight_scale") + assert hasattr(o_proj, "input_scale") assert o_proj.weight_scale.dtype == torch.float32 assert o_proj.input_scale.dtype == torch.float32 - assert hasattr(gate_up_proj, 'weight_scale') - assert hasattr(gate_up_proj, 'input_scale') + assert hasattr(gate_up_proj, "weight_scale") + assert hasattr(gate_up_proj, "input_scale") assert gate_up_proj.weight_scale.dtype == torch.float32 assert gate_up_proj.input_scale.dtype == torch.float32 - assert hasattr(down_proj, 'weight_scale') - assert hasattr(down_proj, 'input_scale') + assert hasattr(down_proj, "weight_scale") + assert hasattr(down_proj, "input_scale") assert down_proj.weight_scale.dtype == torch.float32 assert down_proj.input_scale.dtype == torch.float32 diff --git a/tests/quantization/test_ptpc_fp8.py b/tests/quantization/test_ptpc_fp8.py index 088b68510cff..e8ea4148585b 100644 --- a/tests/quantization/test_ptpc_fp8.py +++ b/tests/quantization/test_ptpc_fp8.py @@ -4,18 +4,19 @@ Run `pytest tests/quantization/test_ptpc_fp8.py --forked`. """ + import pytest import torch from tests.quantization.utils import is_quant_method_supported from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod -from vllm.model_executor.layers.quantization.ptpc_fp8 import ( - PTPCFp8LinearMethod) +from vllm.model_executor.layers.quantization.ptpc_fp8 import PTPCFp8LinearMethod from vllm.platforms import current_platform UNSUPPORTED_STR = ( "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only " - "support output dtype of bfloat16. torch.float16 is specified.") + "support output dtype of bfloat16. torch.float16 is specified." 
+) @pytest.fixture(scope="function", autouse=True) @@ -24,18 +25,21 @@ def enable_pickle(monkeypatch): monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") -@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"), - reason="PTPC FP8 is not supported on this GPU type.") -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="This test is for ROCm GPU.") +@pytest.mark.skipif( + not is_quant_method_supported("ptpc_fp8"), + reason="PTPC FP8 is not supported on this GPU type.", +) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="This test is for ROCm GPU.") @pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"]) @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"]) def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None: try: - llm = vllm_runner("facebook/opt-125m", - dtype=dtype, - quantization="ptpc_fp8", - kv_cache_dtype=kv_cache_dtype) + llm = vllm_runner( + "facebook/opt-125m", + dtype=dtype, + quantization="ptpc_fp8", + kv_cache_dtype=kv_cache_dtype, + ) except AssertionError as e: if str(e) == UNSUPPORTED_STR: # If the error message matches, the test passes diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index 930f4acb328f..6c047259c177 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -19,23 +19,27 @@ from packaging import version from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 - QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8) + QuarkLinearMethod, + QuarkW8A8Fp8, + QuarkW8A8Int8, +) from vllm.platforms import current_platform from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse( - importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') + importlib.metadata.version("amd-quark") +) >= version.parse("0.8.99") if QUARK_MXFP4_AVAILABLE: - from quark.torch.export.nn.modules.realquantizer import ( - StaticScaledRealQuantizer) + from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer from quark.torch.kernel import mx as mx_kernel from quark.torch.quantization.config.config import FP4PerGroupSpec try: huggingface_hub.list_repo_refs( - "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ") + "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ" + ) HF_HUB_AMD_ORG_ACCESS = True except huggingface_hub.errors.RepositoryNotFoundError: HF_HUB_AMD_ORG_ACCESS = False @@ -47,13 +51,13 @@ def enable_pickle(monkeypatch): monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") -@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8']) -@pytest.mark.parametrize('tp', [1]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +@pytest.mark.parametrize("tp", [1]) def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp): model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test" - with vllm_runner(model_path, - kv_cache_dtype=kv_cache_dtype, - tensor_parallel_size=tp) as llm: + with vllm_runner( + model_path, kv_cache_dtype=kv_cache_dtype, tensor_parallel_size=tp + ) as llm: def check_model(model): layer = model.model.layers[0] @@ -74,7 +78,7 @@ def check_model(model): assert output -@pytest.mark.parametrize('tp', [1]) +@pytest.mark.parametrize("tp", [1]) def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp): model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts" with vllm_runner(model_path, tensor_parallel_size=tp) as llm: @@ -89,8 +93,7 @@ def 
check_model(model): if isinstance(qkv_proj.scheme, QuarkW8A8Fp8): assert qkv_proj.weight.dtype is current_platform.fp8_dtype() - assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[ - 1] + assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1] assert qkv_proj.weight_scale.shape[1] == 1 llm.apply_model(check_model) @@ -99,7 +102,7 @@ def check_model(model): assert output -@pytest.mark.parametrize('tp', [1]) +@pytest.mark.parametrize("tp", [1]) def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp): model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test" with vllm_runner(model_path, tensor_parallel_size=tp) as llm: @@ -125,16 +128,18 @@ def test_quark_fp8_parity(vllm_runner): llm_kwargs = { "tensor_parallel_size": 1, "enforce_eager": True, - "gpu_memory_utilization": 0.1 + "gpu_memory_utilization": 0.1, } - with (vllm_runner(quark_model_id, **llm_kwargs) as - quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle): + with ( + vllm_runner(quark_model_id, **llm_kwargs) as quark_handle, + vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle, + ): def get_state_dict(model): return {k: v.cpu() for k, v in model.state_dict().items()} - quark_state_dict, = quark_handle.apply_model(get_state_dict) - fp8_state_dict, = fp8_handle.apply_model(get_state_dict) + (quark_state_dict,) = quark_handle.apply_model(get_state_dict) + (fp8_state_dict,) = fp8_handle.apply_model(get_state_dict) assert fp8_state_dict.keys() == quark_state_dict.keys() @@ -164,16 +169,17 @@ def get_model_args(self) -> str: # Private model. GSM8KAccuracyTestConfig( model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant", - excepted_value=0.96), + excepted_value=0.96, + ), ] @pytest.mark.parametrize("config", ACCURACY_CONFIGS) -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, - reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif( not HF_HUB_AMD_ORG_ACCESS, - reason="Read access to huggingface.co/amd is required for this test.") + reason="Read access to huggingface.co/amd is required for this test.", +) def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig): if torch.cuda.device_count() < 8: pytest.skip( @@ -195,28 +201,26 @@ def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig): EXPECTED_VALUE = config.excepted_value measured_value = results["results"][task]["exact_match,strict-match"] - assert (measured_value - rtol < EXPECTED_VALUE - and measured_value + rtol > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - rtol < EXPECTED_VALUE + and measured_value + rtol > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" del os.environ["VLLM_USE_TRITON_FLASH_ATTN"] -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, - reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("scalings", - [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) -def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, - scalings: list[int]): +@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) +def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[int]): torch.manual_seed(0) hidden_size = 64 * 32 - inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - - 0.5) * 
2 + inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2 for i in range(hidden_size // 32): - inp[:, i * 32:(i + 1) * - 32] = inp[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)] + inp[:, i * 32 : (i + 1) * 32] = ( + inp[:, i * 32 : (i + 1) * 32] * scalings[i % len(scalings)] + ) inp_kernel = inp.clone() inp_kernel_clone = inp_kernel.clone() @@ -225,20 +229,20 @@ def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, res_torch = qdq_mxfp4_torch(inp_kernel, "even") for i in range(hidden_size // 32): - assert torch.all(torch.isfinite(res_hip[:, i * 32:(i + 1) * 32])) - assert torch.all(torch.isfinite(res_torch[:, i * 32:(i + 1) * 32])) + assert torch.all(torch.isfinite(res_hip[:, i * 32 : (i + 1) * 32])) + assert torch.all(torch.isfinite(res_torch[:, i * 32 : (i + 1) * 32])) - torch.testing.assert_close(res_hip[:, i * 32:(i + 1) * 32], - res_torch[:, i * 32:(i + 1) * 32]) + torch.testing.assert_close( + res_hip[:, i * 32 : (i + 1) * 32], res_torch[:, i * 32 : (i + 1) * 32] + ) -@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, - reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("scalings", - [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) -def test_mxfp4_dequant_kernel_match_quark(float_dtype: torch.dtype, - scalings: list[int]): +@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) +def test_mxfp4_dequant_kernel_match_quark( + float_dtype: torch.dtype, scalings: list[int] +): qspec = FP4PerGroupSpec( ch_axis=-1, group_size=32, @@ -265,8 +269,9 @@ def test_mxfp4_dequant_kernel_match_quark(float_dtype: torch.dtype, # Make it so that different groups have different scales. for i in range(hidden_size // 32): - w[:, i * 32:(i + 1) * - 32] = w[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)] + w[:, i * 32 : (i + 1) * 32] = ( + w[:, i * 32 : (i + 1) * 32] * scalings[i % len(scalings)] + ) observer(w) scale, _ = observer._calculate_qparams() diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py index 03fe59d7e3bf..b70c2ee7fe2e 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py @@ -6,18 +6,25 @@ Run `pytest tests/quantization/test_register_quantization_config.py`. 
""" + from typing import Any, Optional import pytest import torch import torch.nn.functional as F -from vllm.model_executor.layers.linear import LinearBase # noqa: E501 -from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.linear import ( + LinearBase, # noqa: E501 + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import ( - QuantizationMethods, get_quantization_config, register_quantization_config) + QuantizationMethods, + get_quantization_config, + register_quantization_config, +) from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizationConfig) + QuantizationConfig, +) class FakeQuantLinearMethod(UnquantizedLinearMethod): @@ -28,10 +35,12 @@ def __init__(self, num_bits: int = 8) -> None: super().__init__() self.num_bits = num_bits - def apply(self, - layer: "torch.nn.Module", - x: "torch.Tensor", - bias: Optional["torch.Tensor"] = None) -> "torch.Tensor": + def apply( + self, + layer: "torch.nn.Module", + x: "torch.Tensor", + bias: Optional["torch.Tensor"] = None, + ) -> "torch.Tensor": """Perform fake quantization before the linear layer.""" # Calculate the scales dynamically @@ -40,8 +49,11 @@ def apply(self, scales = (max_val - min_val) / (2**self.num_bits - 1) # Fake quantize the input - quant_x = torch.clamp(torch.round(x / scales), -2**(self.num_bits - 1), - 2**(self.num_bits - 1) - 1) + quant_x = torch.clamp( + torch.round(x / scales), + -(2 ** (self.num_bits - 1)), + 2 ** (self.num_bits - 1) - 1, + ) dequant_x = quant_x * scales return F.linear(dequant_x, layer.weight, bias) @@ -79,8 +91,9 @@ def from_config(cls, config: dict[str, Any]) -> "CustomQuantConfig": """Create a config class from the model's quantization config.""" return CustomQuantConfig(num_bits=config.get("num_bits", 8)) - def get_quant_method(self, layer: "torch.nn.Module", - prefix: str) -> Optional["FakeQuantLinearMethod"]: + def get_quant_method( + self, layer: "torch.nn.Module", prefix: str + ) -> Optional["FakeQuantLinearMethod"]: """Get the quantize method to use for the quantized layer.""" if isinstance(layer, LinearBase): return FakeQuantLinearMethod(num_bits=self.num_bits) @@ -99,18 +112,20 @@ def test_register_quantization_config(): register_quantization_config("custom_quant")(CustomQuantConfig) -@pytest.mark.parametrize(argnames="model", - argvalues=[ - "meta-llama/Llama-3.2-1B-Instruct", - ]) +@pytest.mark.parametrize( + argnames="model", + argvalues=[ + "meta-llama/Llama-3.2-1B-Instruct", + ], +) def test_custom_quant(vllm_runner, model, monkeypatch): """Test infer with the custom quantization method.""" # `LLM.apply_model` requires pickling a function. monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - with vllm_runner(model_name=model, - quantization="custom_quant", - enforce_eager=True) as llm: + with vllm_runner( + model_name=model, quantization="custom_quant", enforce_eager=True + ) as llm: def check_model(model): layer = model.model.layers[0] diff --git a/tests/quantization/test_rtn.py b/tests/quantization/test_rtn.py index bc2b468f97d8..370625ed3479 100644 --- a/tests/quantization/test_rtn.py +++ b/tests/quantization/test_rtn.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright © 2025, Oracle and/or its affiliates. 
-"""Tests RTN quantization startup and generation, +"""Tests RTN quantization startup and generation, doesn't test correctness """ + import pytest from tests.quantization.utils import is_quant_method_supported @@ -14,8 +15,10 @@ ] -@pytest.mark.skipif(not is_quant_method_supported("rtn"), - reason="RTN is not supported on this GPU type.") +@pytest.mark.skipif( + not is_quant_method_supported("rtn"), + reason="RTN is not supported on this GPU type.", +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [10]) @@ -27,6 +30,5 @@ def test_model_rtn_startup( dtype: str, max_tokens: int, ) -> None: - with vllm_runner(model, dtype=dtype, quantization="rtn") as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py index 37cf7ef8417b..45ee94119bbb 100644 --- a/tests/quantization/test_torchao.py +++ b/tests/quantization/test_torchao.py @@ -13,12 +13,13 @@ @pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available") def test_pre_quantized_model(vllm_runner): - with vllm_runner("drisspg/fp8-opt-125m", - quantization="torchao", - dtype="bfloat16", - enforce_eager=True) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + "drisspg/fp8-opt-125m", + quantization="torchao", + dtype="bfloat16", + enforce_eager=True, + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output @@ -28,17 +29,18 @@ def test_pre_quantized_model(vllm_runner): [ "cuda:0", # {"": "cuda"}, - ]) -def test_opt_125m_int8wo_model_loading_with_params(vllm_runner, - pt_load_map_location): + ], +) +def test_opt_125m_int8wo_model_loading_with_params(vllm_runner, pt_load_map_location): torch._dynamo.reset() model_name = "jerryzh168/opt-125m-int8wo-partial-quant" - with vllm_runner(model_name=model_name, - quantization="torchao", - dtype="bfloat16", - pt_load_map_location=pt_load_map_location) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location=pt_load_map_location, + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output @@ -47,12 +49,13 @@ def test_opt_125m_int8wo_model_loading_with_params(vllm_runner, def test_opt_125m_int4wo_model_per_module_quant(vllm_runner): torch._dynamo.reset() model_name = "jerryzh168/opt-125m-int4wo-per-module" - with vllm_runner(model_name=model_name, - quantization="torchao", - dtype="bfloat16", - pt_load_map_location="cuda:0") as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0", + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output @@ -61,12 +64,13 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner): def test_qwenvl_int8wo_model_loading_with_params(vllm_runner): torch._dynamo.reset() model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao" - with vllm_runner(model_name=model_name, - quantization="torchao", - dtype="bfloat16", - pt_load_map_location="cuda:0") as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + model_name=model_name, + quantization="torchao", + 
dtype="bfloat16", + pt_load_map_location="cuda:0", + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output @@ -75,17 +79,18 @@ def test_qwenvl_int8wo_model_loading_with_params(vllm_runner): @pytest.mark.skip( reason="since torchao nightly is only compatible with torch nightly" "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip " - "torchao tests that requires newer versions (0.14.0.dev+) for now") + "torchao tests that requires newer versions (0.14.0.dev+) for now" +) def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner): torch._dynamo.reset() - model_name = ("torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2" - "-0.14.0.dev") - with vllm_runner(model_name=model_name, - quantization="torchao", - dtype="bfloat16", - pt_load_map_location="cuda:0") as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + model_name = "torchao-testing/opt-125m-AWQConfig-Int4WeightOnlyConfig-v2-0.14.0.dev" + with vllm_runner( + model_name=model_name, + quantization="torchao", + dtype="bfloat16", + pt_load_map_location="cuda:0", + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output @@ -101,22 +106,24 @@ def test_on_the_fly_quant_config_dict_json(vllm_runner): import json from torchao.core.config import config_to_dict - from torchao.quantization import ( - Float8DynamicActivationFloat8WeightConfig, PerRow) + from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow torchao_quant_config = Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow()) + granularity=PerRow() + ) hf_overrides = { - "quantization_config_dict_json": - json.dumps(config_to_dict(torchao_quant_config)) + "quantization_config_dict_json": json.dumps( + config_to_dict(torchao_quant_config) + ) } - with vllm_runner(model_name=model_name, - dtype="bfloat16", - pt_load_map_location="cuda:0", - quantization="torchao", - hf_overrides=hf_overrides) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + model_name=model_name, + dtype="bfloat16", + pt_load_map_location="cuda:0", + quantization="torchao", + hf_overrides=hf_overrides, + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output @@ -132,8 +139,7 @@ def test_on_the_fly_quant_config_file(vllm_runner): from tempfile import NamedTemporaryFile from torchao.core.config import config_to_dict - from torchao.quantization import ( - Float8DynamicActivationFloat8WeightConfig, PerRow) + from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) @@ -144,13 +150,14 @@ def test_on_the_fly_quant_config_file(vllm_runner): config_file_name = str(f.name) hf_overrides = {"quantization_config_file": config_file_name} - with vllm_runner(model_name=model_name, - dtype="bfloat16", - pt_load_map_location="cuda:0", - quantization="torchao", - hf_overrides=hf_overrides) as llm: - output = llm.generate_greedy(["The capital of France is"], - max_tokens=32) + with vllm_runner( + model_name=model_name, + dtype="bfloat16", + pt_load_map_location="cuda:0", + quantization="torchao", + hf_overrides=hf_overrides, + ) as llm: + output = llm.generate_greedy(["The capital of France is"], max_tokens=32) assert output @@ -160,17 +167,18 @@ def test_reload_weights(): import json from torchao.core.config import config_to_dict - from 
torchao.quantization import ( - Float8DynamicActivationFloat8WeightConfig, PerRow) + from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow from vllm import LLM, SamplingParams torchao_quant_config = Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow()) + granularity=PerRow() + ) hf_overrides = { - "quantization_config_dict_json": - json.dumps(config_to_dict(torchao_quant_config)) + "quantization_config_dict_json": json.dumps( + config_to_dict(torchao_quant_config) + ) } llm = LLM( @@ -182,12 +190,9 @@ def test_reload_weights(): hf_overrides=hf_overrides, ) # Update load format from `dummy` to `auto` - llm.collective_rpc("update_config", - args=({ - "load_config": { - "load_format": "auto" - } - }, )) + llm.collective_rpc( + "update_config", args=({"load_config": {"load_format": "auto"}},) + ) # Now reload real weights inplace llm.collective_rpc("reload_weights") prompts = [ diff --git a/tests/reasoning/test_base_thinking_reasoning_parser.py b/tests/reasoning/test_base_thinking_reasoning_parser.py index 6a939dcfc2c9..ddda50fe770a 100644 --- a/tests/reasoning/test_base_thinking_reasoning_parser.py +++ b/tests/reasoning/test_base_thinking_reasoning_parser.py @@ -44,9 +44,7 @@ def test_tokenizer(): # Add custom test tokens test_tokens = ["<test:think>", "</test:think>", "<alt:start>", "<alt:end>"] existing_tokens = set(tokenizer.get_vocab().keys()) - new_tokens = [ - token for token in test_tokens if token not in existing_tokens - ] + new_tokens = [token for token in test_tokens if token not in existing_tokens] if new_tokens: tokenizer.add_tokens(new_tokens) return tokenizer @@ -54,8 +52,8 @@ def test_tokenizer(): class TestBaseThinkingReasoningParserInit: """ - Test initialization and basic properties of - BaseThinkingReasoningParser. + Test initialization and basic properties of + BaseThinkingReasoningParser. 
""" def test_successful_initialization(self, test_tokenizer): @@ -76,7 +74,6 @@ def test_initialization_with_missing_tokens(self, test_tokenizer): # Create a parser with tokens not in vocabulary class MissingTokenParser(BaseThinkingReasoningParser): - @property def start_token(self) -> str: return "<missing:start>" @@ -85,15 +82,15 @@ def start_token(self) -> str: def end_token(self) -> str: return "<missing:end>" - with pytest.raises(RuntimeError, - match="could not locate think start/end tokens"): + with pytest.raises( + RuntimeError, match="could not locate think start/end tokens" + ): MissingTokenParser(test_tokenizer) def test_initialization_with_empty_tokens(self, test_tokenizer): """Test that initialization fails with empty token strings.""" class EmptyTokenParser(BaseThinkingReasoningParser): - @property def start_token(self) -> str: return "" @@ -102,8 +99,9 @@ def start_token(self) -> str: def end_token(self) -> str: return "" - with pytest.raises(ValueError, - match="start_token and end_token must be defined"): + with pytest.raises( + ValueError, match="start_token and end_token must be defined" + ): EmptyTokenParser(test_tokenizer) @@ -158,10 +156,8 @@ def test_extract_reasoning_content_with_both_tokens(self, test_tokenizer): parser = TestThinkingReasoningParser(test_tokenizer) request = ChatCompletionRequest(messages=[], model="test-model") - model_output = ("<test:think>This is reasoning" - "</test:think>This is content") - reasoning, content = parser.extract_reasoning_content( - model_output, request) + model_output = "<test:think>This is reasoning</test:think>This is content" + reasoning, content = parser.extract_reasoning_content(model_output, request) assert reasoning == "This is reasoning" assert content == "This is content" @@ -171,9 +167,8 @@ def test_extract_reasoning_content_only_end_token(self, test_tokenizer): parser = TestThinkingReasoningParser(test_tokenizer) request = ChatCompletionRequest(messages=[], model="test-model") - model_output = ("This is reasoning</test:think>This is content") - reasoning, content = parser.extract_reasoning_content( - model_output, request) + model_output = "This is reasoning</test:think>This is content" + reasoning, content = parser.extract_reasoning_content(model_output, request) assert reasoning == "This is reasoning" assert content == "This is content" @@ -184,8 +179,7 @@ def test_extract_reasoning_content_no_end_token(self, test_tokenizer): request = ChatCompletionRequest(messages=[], model="test-model") model_output = "This is just content" - reasoning, content = parser.extract_reasoning_content( - model_output, request) + reasoning, content = parser.extract_reasoning_content(model_output, request) assert reasoning == "This is just content" assert content is None @@ -196,8 +190,7 @@ def test_extract_reasoning_content_empty_output(self, test_tokenizer): request = ChatCompletionRequest(messages=[], model="test-model") model_output = "" - reasoning, content = parser.extract_reasoning_content( - model_output, request) + reasoning, content = parser.extract_reasoning_content(model_output, request) assert reasoning == "" assert content is None @@ -207,9 +200,8 @@ def test_extract_reasoning_content_only_tokens(self, test_tokenizer): parser = TestThinkingReasoningParser(test_tokenizer) request = ChatCompletionRequest(messages=[], model="test-model") - model_output = ("<test:think></test:think>") - reasoning, content = parser.extract_reasoning_content( - model_output, request) + model_output = "<test:think></test:think>" + reasoning, 
content = parser.extract_reasoning_content(model_output, request) assert reasoning == "" assert content is None @@ -221,19 +213,24 @@ class TestBaseThinkingReasoningParserStreaming: @pytest.mark.parametrize("streaming", [True, False]) def test_simple_reasoning_extraction(self, test_tokenizer, streaming): """ - Test basic reasoning extraction in both - streaming and non-streaming modes. + Test basic reasoning extraction in both + streaming and non-streaming modes. """ parser = TestThinkingReasoningParser(test_tokenizer) model_output = [ - "<test:think>", "Some ", "reasoning ", "content", "</test:think>", - "Final ", "answer" + "<test:think>", + "Some ", + "reasoning ", + "content", + "</test:think>", + "Final ", + "answer", ] - reasoning, content = run_reasoning_extraction(parser, - model_output, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, model_output, streaming=streaming + ) assert reasoning == "Some reasoning content" assert content == "Final answer" @@ -252,9 +249,7 @@ def test_streaming_with_incremental_deltas(self, test_tokenizer): "answer", ] - reasoning, content = run_reasoning_extraction(parser, - deltas, - streaming=True) + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) assert reasoning == "Some reasoning content" assert content == "Final answer" @@ -271,9 +266,7 @@ def test_streaming_with_start_token(self, test_tokenizer): "Answer", ] - reasoning, content = run_reasoning_extraction(parser, - deltas, - streaming=True) + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) assert reasoning == "Some reasoning" assert content == "Answer" @@ -290,9 +283,7 @@ def test_streaming_no_end_token(self, test_tokenizer): "end", ] - reasoning, content = run_reasoning_extraction(parser, - deltas, - streaming=True) + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) assert reasoning == "Some reasoning without end" assert content is None @@ -309,9 +300,7 @@ def test_streaming_only_end_token(self, test_tokenizer): "Final", ] - reasoning, content = run_reasoning_extraction(parser, - deltas, - streaming=True) + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) assert reasoning == "Reasoning content" assert content == "Final" @@ -319,29 +308,27 @@ def test_streaming_only_end_token(self, test_tokenizer): class TestBaseThinkingReasoningParserMultipleImplementations: """ - Test that multiple implementations of - BaseThinkingReasoningParser work correctly. + Test that multiple implementations of + BaseThinkingReasoningParser work correctly. """ def test_different_token_implementations(self, test_tokenizer): """ - Test that different implementations - with different tokens work independently. + Test that different implementations + with different tokens work independently. 
""" parser1 = TestThinkingReasoningParser(test_tokenizer) parser2 = TestThinkingReasoningParserAlt(test_tokenizer) # Test parser1 - model_output1 = ("Reasoning1</test:think>Content1") - reasoning1, content1 = run_reasoning_extraction( - parser1, [model_output1]) + model_output1 = "Reasoning1</test:think>Content1" + reasoning1, content1 = run_reasoning_extraction(parser1, [model_output1]) assert reasoning1 == "Reasoning1" assert content1 == "Content1" # Test parser2 model_output2 = "Reasoning2<alt:end>Content2" - reasoning2, content2 = run_reasoning_extraction( - parser2, [model_output2]) + reasoning2, content2 = run_reasoning_extraction(parser2, [model_output2]) assert reasoning2 == "Reasoning2" assert content2 == "Content2" @@ -359,7 +346,7 @@ def test_multiple_end_tokens(self, test_tokenizer): """Test behavior with multiple end tokens.""" parser = TestThinkingReasoningParser(test_tokenizer) - model_output = ("First</test:think>Middle</test:think>Last") + model_output = "First</test:think>Middle</test:think>Last" reasoning, content = run_reasoning_extraction(parser, [model_output]) # Should stop at first end token @@ -370,8 +357,7 @@ def test_nested_tokens(self, test_tokenizer): """Test behavior with nested-like token patterns.""" parser = TestThinkingReasoningParser(test_tokenizer) - model_output = ("<test:think>Outer" - "<test:think>Inner</test:think>Content") + model_output = "<test:think>Outer<test:think>Inner</test:think>Content" reasoning, content = run_reasoning_extraction(parser, [model_output]) # Should process normally, start from first start token @@ -382,11 +368,9 @@ def test_malformed_tokens(self, test_tokenizer): """Test behavior with malformed token-like strings.""" parser = TestThinkingReasoningParser(test_tokenizer) - model_output = ("<test:thinking>Not a real token" - "</test:thinking>Content") + model_output = "<test:thinking>Not a real token</test:thinking>Content" reasoning, content = run_reasoning_extraction(parser, [model_output]) # Should treat as regular content since tokens don't match exactly - assert reasoning == ("<test:thinking>Not a real token" - "</test:thinking>Content") + assert reasoning == ("<test:thinking>Not a real token</test:thinking>Content") assert content is None diff --git a/tests/reasoning/test_deepseekr1_reasoning_parser.py b/tests/reasoning/test_deepseekr1_reasoning_parser.py index 987f3c48de0c..946d01c123c5 100644 --- a/tests/reasoning/test_deepseekr1_reasoning_parser.py +++ b/tests/reasoning/test_deepseekr1_reasoning_parser.py @@ -259,15 +259,15 @@ def test_reasoning( output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"]) # decode everything to tokens output_tokens: list[str] = [ - deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token]) - for token in output + deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token]) for token in output ] - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(deepseek_r1_qwen_tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + deepseek_r1_qwen_tokenizer + ) - reasoning, content = run_reasoning_extraction(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] @@ -281,7 +281,8 @@ def test_reasoning( if param_dict["content"] is not None: content = parser.extract_content_ids(output_ids) assert content == 
deepseek_r1_qwen_tokenizer.convert_tokens_to_ids( - deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"])) + deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]) + ) else: content = parser.extract_content_ids(output) assert content == [] diff --git a/tests/reasoning/test_glm4_moe_reasoning_parser.py b/tests/reasoning/test_glm4_moe_reasoning_parser.py index 4c5ec2c9b408..0a8595a00fcb 100644 --- a/tests/reasoning/test_glm4_moe_reasoning_parser.py +++ b/tests/reasoning/test_glm4_moe_reasoning_parser.py @@ -54,8 +54,7 @@ def glm45_tokenizer(): "is_reasoning_end": True, } MULTILINE_REASONING = { - "output": - "<think>This is a reasoning\nsection</think>This is the rest\nThat", + "output": "<think>This is a reasoning\nsection</think>This is the rest\nThat", "reasoning_content": "This is a reasoning\nsection", "content": "This is the rest\nThat", "is_reasoning_end": True, @@ -158,12 +157,12 @@ def glm45_tokenizer(): REASONING_END_TEST_CASES = [ pytest.param(STILL_REASONING_PROMPT, False, id="still_reasoning"), pytest.param(DONE_REASONING_PROMPT, True, id="done_reasoning"), - pytest.param(MULTI_TURN_STILL_REASONING_PROMPT, - False, - id="multi_turn_still_reasoning"), - pytest.param(MULTI_TURN_DONE_REASONING_PROMPT, - True, - id="multi_turn_done_reasoning") + pytest.param( + MULTI_TURN_STILL_REASONING_PROMPT, False, id="multi_turn_still_reasoning" + ), + pytest.param( + MULTI_TURN_DONE_REASONING_PROMPT, True, id="multi_turn_done_reasoning" + ), ] @@ -177,12 +176,13 @@ def test_reasoning( output_tokens: list[str] = [ glm45_tokenizer.convert_tokens_to_string([token]) for token in output ] - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(glm45_tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + glm45_tokenizer + ) - reasoning, content = run_reasoning_extraction(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] @@ -193,10 +193,12 @@ def test_reasoning( @pytest.mark.parametrize("prompt, is_reasoning_end", REASONING_END_TEST_CASES) -def test_is_reasoning_end_full_prompt(prompt: str, is_reasoning_end: bool, - glm45_tokenizer): - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(glm45_tokenizer) +def test_is_reasoning_end_full_prompt( + prompt: str, is_reasoning_end: bool, glm45_tokenizer +): + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + glm45_tokenizer + ) tokens = glm45_tokenizer.tokenize(prompt) token_ids = glm45_tokenizer.convert_tokens_to_ids(tokens) check_is_reasoning_end = parser.is_reasoning_end(token_ids) diff --git a/tests/reasoning/test_granite_reasoning_parser.py b/tests/reasoning/test_granite_reasoning_parser.py index 38cab73a45f2..de1663408d72 100644 --- a/tests/reasoning/test_granite_reasoning_parser.py +++ b/tests/reasoning/test_granite_reasoning_parser.py @@ -11,8 +11,7 @@ START_RESPONSE = "Here is my response:" SIMPLE_REASONING = { - "output": - f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", # noqa: E501 "reasoning_content": "This is a reasoning section", "content": "This is the rest", } @@ -27,14 +26,12 @@ "content": "This is content", } MULTIPLE_LINES = { - "output": - 
f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "output": f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } REASONING_WITH_THINK = { - "output": - f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", # noqa: E501 "reasoning_content": "This is a reasoning section", "content": "This is the rest", } @@ -44,8 +41,7 @@ "content": None, } MULTIPLE_LINES_WITH_THINK = { - "output": - f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "output": f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } @@ -137,12 +133,13 @@ def test_reasoning( output_tokens: list[str] = [ tokenizer.convert_tokens_to_string([token]) for token in output ] - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + tokenizer + ) - reasoning, content = run_reasoning_extraction(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] @@ -229,18 +226,15 @@ def test_reasoning( ## The Response is ongoing, and the delta mixes reasoning content / content STREAMING_10 = { "previous_text": "Here is my thought process: foo", - "current_text": - "Here is my thought process: foo bar Here is my response: baz", + "current_text": "Here is my thought process: foo bar Here is my response: baz", "delta_text": " bar Here is my response: baz", "reasoning_content": " bar ", "content": " baz", } # The delta text starts a new substring that might be a response special seq STREAMING_11 = { - "previous_text": - "Here is my thought process: This is a reasoning section ", - "current_text": - "Here is my thought process: This is a reasoning section Here", + "previous_text": "Here is my thought process: This is a reasoning section ", + "current_text": "Here is my thought process: This is a reasoning section Here", "delta_text": "Here", "reasoning_content": None, "content": None, @@ -320,14 +314,17 @@ def test_reasoning( @pytest.mark.parametrize("param_dict", STREAMING_SUBCASES) def test_streaming_subcases(param_dict): # Get all of the token IDs - previous_token_ids = tokenizer.encode( - param_dict["previous_text"] - ) if param_dict["previous_text"] is not None else [] + previous_token_ids = ( + tokenizer.encode(param_dict["previous_text"]) + if param_dict["previous_text"] is not None + else [] + ) current_token_ids = tokenizer.encode(param_dict["current_text"]) delta_token_ids = tokenizer.encode(param_dict["delta_text"]) - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + tokenizer + ) response = parser.extract_reasoning_content_streaming( previous_text=param_dict["previous_text"], @@ -339,8 +336,7 @@ def test_streaming_subcases(param_dict): ) # Streaming currently expects at least one of reasoning content / content, # so the response should return None in that case. 
- if param_dict["reasoning_content"] is None and param_dict[ - "content"] is None: + if param_dict["reasoning_content"] is None and param_dict["content"] is None: assert response is None else: assert isinstance(response, DeltaMessage) diff --git a/tests/reasoning/test_hunyuan_reasoning_parser.py b/tests/reasoning/test_hunyuan_reasoning_parser.py index f9238267f02e..b7e3ea73ccde 100644 --- a/tests/reasoning/test_hunyuan_reasoning_parser.py +++ b/tests/reasoning/test_hunyuan_reasoning_parser.py @@ -13,15 +13,13 @@ END_RESPONSE = "\n</answer>" NO_REASONING_QUICK_THROUGHT = { - "output": - f"{START_REASONING}{START_RESPONSE}This is the rest{END_RESPONSE}", #noqa: E501 + "output": f"{START_REASONING}{START_RESPONSE}This is the rest{END_RESPONSE}", # noqa: E501 "reasoning_content": None, "content": "This is the rest", } SIMPLE_REASONING = { - "output": - f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest{END_RESPONSE}", #noqa: E501 + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest{END_RESPONSE}", # noqa: E501 "reasoning_content": "This is a reasoning section", "content": "This is the rest", } @@ -42,14 +40,12 @@ "content": "This is content", } MULTIPLE_LINES = { - "output": - f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "output": f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } REASONING_WITH_THINK = { - "output": - f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", # noqa: E501 "reasoning_content": "This is a reasoning section", "content": "This is the rest", } @@ -59,8 +55,7 @@ "content": None, } MULTIPLE_LINES_WITH_THINK = { - "output": - f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "output": f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", "reasoning_content": "This\nThat", "content": "This is the rest\nThat", } @@ -122,9 +117,7 @@ NO_REASONING, id="no_reasoning_streaming", ), - pytest.param(True, - NO_REASONING_QUICK_THROUGHT, - id="no_reasoning_quick_stream"), + pytest.param(True, NO_REASONING_QUICK_THROUGHT, id="no_reasoning_quick_stream"), pytest.param( True, MULTIPLE_LINES, @@ -148,8 +141,9 @@ ] # Global tokenizer initialization to avoid repeated loading -tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct", - trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained( + "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True +) @pytest.mark.parametrize("streaming, param_dict", TEST_CASES) @@ -162,12 +156,13 @@ def test_reasoning( output_tokens: list[str] = [ tokenizer.convert_tokens_to_string([token]) for token in output ] - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + tokenizer + ) - reasoning, content = run_reasoning_extraction(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] diff --git a/tests/reasoning/test_mistral_reasoning_parser.py b/tests/reasoning/test_mistral_reasoning_parser.py index 91a22f6f5d72..96107c0c1193 100644 --- 
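The granite and hunyuan cases above all reduce to the same contract: text before the end-of-reasoning marker is reasoning, text after it is content, and output that never reaches the marker is reasoning only. Below is a minimal standalone sketch of that contract using generic `<think>`/`</think>` markers; it is not any specific parser from `vllm.reasoning`.

```python
# Sketch of the non-streaming extraction contract exercised by the parser
# tests above. Illustrative only -- not vLLM's ReasoningParser implementation.
from typing import Optional

START_THINK = "<think>"
END_THINK = "</think>"


def extract_reasoning(model_output: str) -> tuple[str, Optional[str]]:
    if START_THINK in model_output:
        # Ignore everything up to and including the start token.
        model_output = model_output.split(START_THINK, 1)[1]
    if END_THINK not in model_output:
        # No end token: the whole output is treated as reasoning.
        return model_output, None
    reasoning, content = model_output.split(END_THINK, 1)
    return reasoning, content or None


assert extract_reasoning("<think>plan</think>answer") == ("plan", "answer")
assert extract_reasoning("plan</think>answer") == ("plan", "answer")
assert extract_reasoning("no marker at all") == ("no marker at all", None)
```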
a/tests/reasoning/test_mistral_reasoning_parser.py +++ b/tests/reasoning/test_mistral_reasoning_parser.py @@ -3,8 +3,7 @@ import pytest from mistral_common.tokens.tokenizers.base import SpecialTokens -from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo, - Tekkenizer) +from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer from tests.reasoning.utils import run_reasoning_extraction_mistral from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -18,23 +17,27 @@ def mistral_tokenizer(): # TODO(Julien): upon model release change to a tokenizer already configured. # ================================================================= mistral_tokenizer = MistralTokenizer.from_pretrained( - "mistralai/Devstral-Small-2507") + "mistralai/Devstral-Small-2507" + ) assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) # Add think special tokens to the tokenizer mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( - rank=35, is_control=True, token_str=SpecialTokens.begin_think.value) + rank=35, is_control=True, token_str=SpecialTokens.begin_think.value + ) mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( - rank=36, is_control=True, token_str=SpecialTokens.end_think.value) + rank=36, is_control=True, token_str=SpecialTokens.end_think.value + ) mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { k: v - for k, v in - mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() + for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() if v not in {35, 36} } mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.begin_think.value] = 35 + SpecialTokens.begin_think.value + ] = 35 mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ - SpecialTokens.end_think.value] = 36 + SpecialTokens.end_think.value + ] = 36 mistral_tokenizer.instruct.BEGIN_THINK = 35 mistral_tokenizer.instruct.END_THINK = 36 # ================================================================= @@ -290,39 +293,45 @@ def test_mistral_reasoning( if index_think != -1: output_before_think = output[:index_think] output_tokens += mistral_tokenizer.tokenizer.encode( - output_before_think, False, False) + output_before_think, False, False + ) output_tokens += [mistral_tokenizer.instruct.BEGIN_THINK] if index_end_think != -1: - output_middle = output[index_think + len_think:index_end_think] - output_after_think = output[index_end_think + len_end_think:] + output_middle = output[index_think + len_think : index_end_think] + output_after_think = output[index_end_think + len_end_think :] output_tokens += mistral_tokenizer.tokenizer.encode( - output_middle, False, False) + output_middle, False, False + ) output_tokens += [mistral_tokenizer.instruct.END_THINK] output_tokens += mistral_tokenizer.tokenizer.encode( - output_after_think, False, False) + output_after_think, False, False + ) else: - output_middle = output[index_think + len_think:] + output_middle = output[index_think + len_think :] output_tokens += mistral_tokenizer.tokenizer.encode( - output_middle, False, False) + output_middle, False, False + ) elif index_end_think != -1: output_before_think = output[:index_end_think] - output_after_think = output[index_end_think + len_end_think:] + output_after_think = output[index_end_think + len_end_think :] output_tokens += mistral_tokenizer.tokenizer.encode( - output_before_think, False, False) + output_before_think, False, False + ) output_tokens += [mistral_tokenizer.instruct.END_THINK] 
output_tokens += mistral_tokenizer.tokenizer.encode( - output_after_think, False, False) + output_after_think, False, False + ) else: - output_tokens += mistral_tokenizer.tokenizer.encode( - output, False, False) + output_tokens += mistral_tokenizer.tokenizer.encode(output, False, False) - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(mistral_tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + mistral_tokenizer + ) - reasoning, content = run_reasoning_extraction_mistral(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction_mistral( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] @@ -335,7 +344,8 @@ def test_mistral_reasoning( if param_dict["content"] is not None: content = parser.extract_content_ids(output_tokens) assert content == mistral_tokenizer.tokenizer.encode( - param_dict["content"], bos=False, eos=False) + param_dict["content"], bos=False, eos=False + ) else: content = parser.extract_content_ids(output_tokens) assert content == [] diff --git a/tests/reasoning/test_olmo3_reasoning_parser.py b/tests/reasoning/test_olmo3_reasoning_parser.py new file mode 100644 index 000000000000..4a2eca994610 --- /dev/null +++ b/tests/reasoning/test_olmo3_reasoning_parser.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "olmo3" +START_REASONING = "<think>" +END_REASONING = "</think>" + +NO_REASONING = { + "output": f"{START_REASONING}{END_REASONING}No thoughts, head empty!", + "reasoning_content": None, + "content": "No thoughts, head empty!", +} + +NO_REASONING_WITH_NEWLINE = { + "output": f"{START_REASONING}\n{END_REASONING}\n\nNo thoughts, head empty!", + "reasoning_content": "\n", + "content": "\n\nNo thoughts, head empty!", +} + +SIMPLE_REASONING = { + "output": f"{START_REASONING}This is a reasoning section{END_REASONING}This is the rest", # noqa: E501 + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} + +SIMPLE_REASONING_WITH_NEWLINE = { + "output": f"{START_REASONING} Look!\n\nI'm thinking...{END_REASONING}\nThis is the rest", # noqa: E501 + "reasoning_content": " Look!\n\nI'm thinking...", + "content": "\nThis is the rest", +} + +SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES = { + "output": f"{START_REASONING}\nLook!\nI'm thinking...\n\n{END_REASONING}\n\n\nThis is the rest", # noqa: E501 + "reasoning_content": "\nLook!\nI'm thinking...\n\n", + "content": "\n\n\nThis is the rest", +} + +NO_REASONING_ONLY_END_THINK = { + "output": f"{END_REASONING}\n\nNo thoughts, head empty!", + "reasoning_content": None, + "content": "\n\nNo thoughts, head empty!", +} + +REASONING_ONLY_END_THINK = { + "output": f"The user is asking me not to think.{END_REASONING}No thoughts!", + "reasoning_content": "The user is asking me not to think.", + "content": "No thoughts!", +} + +TEST_CASES = [ + pytest.param( + False, # not streaming + NO_REASONING, + id="no_reasoning", + ), + pytest.param( + False, # not streaming + NO_REASONING_WITH_NEWLINE, + id="no_reasoning_with_newline", + ), + pytest.param( + False, # not streaming + SIMPLE_REASONING, + id="simple_reasoning", + ), + 
pytest.param( + False, # not streaming + SIMPLE_REASONING_WITH_NEWLINE, + id="simple_reasoning_with_newline", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES, + id="simple_reasoning_with_multiple_newlines", + ), + pytest.param( + False, # not streaming + NO_REASONING_ONLY_END_THINK, + id="no_reasoning_only_end_think", + ), + pytest.param( + False, # not streaming + REASONING_ONLY_END_THINK, + id="yes_reasoning_only_end_think", + ), + pytest.param( + True, # enable streaming + NO_REASONING, + id="no_reasoning_streaming", + ), + pytest.param( + True, # enable streaming + NO_REASONING_WITH_NEWLINE, + id="no_reasoning_with_newline_streaming", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING_WITH_NEWLINE, + id="simple_reasoning_with_newline_streaming", + ), + pytest.param( + True, # enable streaming + SIMPLE_REASONING_WITH_MULTIPLE_NEWLINES, + id="simple_reasoning_with_multiple_newlines_streaming", + ), + pytest.param( + True, # enable streaming + NO_REASONING_ONLY_END_THINK, + id="no_reasoning_only_end_think_streaming", + ), + pytest.param( + True, # enable streaming + REASONING_ONLY_END_THINK, + id="yes_reasoning_only_end_think_streaming", + ), +] + +# Global tokenizer initialization to avoid repeated loading +tokenizer = AutoTokenizer.from_pretrained("allenai/dolma2-tokenizer") + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict[str, str], +): + output = tokenizer.tokenize(param_dict["output"]) + + # decode everything to tokens + model_output: list[str] = [ + tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) + parser: ReasoningParser = parser_cls(tokenizer) + + reasoning, content = run_reasoning_extraction( + reasoning_parser=parser, model_output=model_output, streaming=streaming + ) + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] diff --git a/tests/reasoning/test_qwen3_reasoning_parser.py b/tests/reasoning/test_qwen3_reasoning_parser.py index 2d5557d5cdc1..c06e40d72de2 100644 --- a/tests/reasoning/test_qwen3_reasoning_parser.py +++ b/tests/reasoning/test_qwen3_reasoning_parser.py @@ -50,8 +50,7 @@ def qwen3_tokenizer(): "content": None, } MULTILINE_REASONING = { - "output": - "<think>This is a reasoning\nsection</think>This is the rest\nThat", + "output": "<think>This is a reasoning\nsection</think>This is the rest\nThat", "reasoning_content": "This is a reasoning\nsection", "content": "This is the rest\nThat", } @@ -131,12 +130,13 @@ def test_reasoning( output_tokens: list[str] = [ qwen3_tokenizer.convert_tokens_to_string([token]) for token in output ] - parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( - parser_name)(qwen3_tokenizer) + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + qwen3_tokenizer + ) - reasoning, content = run_reasoning_extraction(parser, - output_tokens, - streaming=streaming) + reasoning, content = run_reasoning_extraction( + parser, output_tokens, streaming=streaming + ) assert reasoning == param_dict["reasoning_content"] assert content == param_dict["content"] diff --git a/tests/reasoning/test_seedoss_reasoning_parser.py b/tests/reasoning/test_seedoss_reasoning_parser.py index bb5dc0f4ffe4..b356b8545f41 100644 --- 
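These parser tests simulate streaming by tokenizing the full completion and replaying it one decoded token string at a time. A small sketch of that trick follows, assuming the `allenai/dolma2-tokenizer` checkpoint used in the olmo3 test above is available locally or can be downloaded.

```python
# How the parser tests above simulate streaming: tokenize the full completion,
# then replay it one decoded token string at a time, as if each were a delta.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("allenai/dolma2-tokenizer")

full_output = "<think>Some reasoning</think>The answer"
tokens = tokenizer.tokenize(full_output)

# Each element is what a streaming client would receive as one delta chunk;
# the exact splits depend on the tokenizer's vocabulary.
deltas = [tokenizer.convert_tokens_to_string([tok]) for tok in tokens]
print(deltas)
```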
a/tests/reasoning/test_seedoss_reasoning_parser.py +++ b/tests/reasoning/test_seedoss_reasoning_parser.py @@ -57,14 +57,10 @@ def seedoss_tokenizer(): "is_reasoning_end": True, } WITH_START_TOKEN: dict[str, Any] = { - "output": ("<seed:think>This is a reasoning section" - "</seed:think>This is the rest"), - "reasoning_content": - "This is a reasoning section", - "content": - "This is the rest", - "is_reasoning_end": - True, + "output": ("<seed:think>This is a reasoning section</seed:think>This is the rest"), + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, } ONLY_END_TOKEN: dict[str, Any] = { "output": "Some reasoning</seed:think>This is the rest", @@ -96,7 +92,8 @@ def test_simple_reasoning(seedoss_tokenizer, streaming): parser = parser_cls(seedoss_tokenizer) reasoning, content = run_reasoning_extraction( - parser, [cast(str, SIMPLE_REASONING["output"])], streaming=streaming) + parser, [cast(str, SIMPLE_REASONING["output"])], streaming=streaming + ) assert reasoning == SIMPLE_REASONING["reasoning_content"] assert content == SIMPLE_REASONING["content"] @@ -109,7 +106,8 @@ def test_complete_reasoning(seedoss_tokenizer, streaming): parser = parser_cls(seedoss_tokenizer) reasoning, content = run_reasoning_extraction( - parser, [cast(str, COMPLETE_REASONING["output"])], streaming=streaming) + parser, [cast(str, COMPLETE_REASONING["output"])], streaming=streaming + ) assert reasoning == COMPLETE_REASONING["reasoning_content"] assert content == COMPLETE_REASONING["content"] @@ -122,7 +120,8 @@ def test_no_content(seedoss_tokenizer, streaming): parser = parser_cls(seedoss_tokenizer) reasoning, content = run_reasoning_extraction( - parser, [cast(str, NO_CONTENT["output"])], streaming=streaming) + parser, [cast(str, NO_CONTENT["output"])], streaming=streaming + ) assert reasoning == NO_CONTENT["reasoning_content"] assert content == NO_CONTENT["content"] @@ -135,7 +134,8 @@ def test_multiple_lines(seedoss_tokenizer, streaming): parser = parser_cls(seedoss_tokenizer) reasoning, content = run_reasoning_extraction( - parser, [cast(str, MULTIPLE_LINES["output"])], streaming=streaming) + parser, [cast(str, MULTIPLE_LINES["output"])], streaming=streaming + ) assert reasoning == MULTIPLE_LINES["reasoning_content"] assert content == MULTIPLE_LINES["content"] @@ -148,7 +148,8 @@ def test_with_start_token(seedoss_tokenizer, streaming): parser = parser_cls(seedoss_tokenizer) reasoning, content = run_reasoning_extraction( - parser, [cast(str, WITH_START_TOKEN["output"])], streaming=streaming) + parser, [cast(str, WITH_START_TOKEN["output"])], streaming=streaming + ) assert reasoning == WITH_START_TOKEN["reasoning_content"] assert content == WITH_START_TOKEN["content"] @@ -157,14 +158,15 @@ def test_with_start_token(seedoss_tokenizer, streaming): @pytest.mark.parametrize("streaming", [True, False]) def test_only_end_token(seedoss_tokenizer, streaming): """ - Test reasoning extraction with only end token - (SeedOSS typical behavior). + Test reasoning extraction with only end token + (SeedOSS typical behavior). 
""" parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name) parser = parser_cls(seedoss_tokenizer) reasoning, content = run_reasoning_extraction( - parser, [cast(str, ONLY_END_TOKEN["output"])], streaming=streaming) + parser, [cast(str, ONLY_END_TOKEN["output"])], streaming=streaming + ) assert reasoning == ONLY_END_TOKEN["reasoning_content"] assert content == ONLY_END_TOKEN["content"] @@ -177,7 +179,8 @@ def test_no_tokens(seedoss_tokenizer, streaming): parser = parser_cls(seedoss_tokenizer) reasoning, content = run_reasoning_extraction( - parser, [cast(str, NO_TOKENS["output"])], streaming=streaming) + parser, [cast(str, NO_TOKENS["output"])], streaming=streaming + ) assert reasoning == NO_TOKENS["reasoning_content"] assert content == NO_TOKENS["content"] @@ -225,13 +228,9 @@ def test_streaming_delta_processing(seedoss_tokenizer): parser = parser_cls(seedoss_tokenizer) # Test streaming with incremental tokens - deltas = [ - "Some ", "reasoning ", "content", "</seed:think>", "Final ", "answer" - ] + deltas = ["Some ", "reasoning ", "content", "</seed:think>", "Final ", "answer"] - reasoning, content = run_reasoning_extraction(parser, - deltas, - streaming=True) + reasoning, content = run_reasoning_extraction(parser, deltas, streaming=True) assert reasoning == "Some reasoning content" assert content == "Final answer" diff --git a/tests/reasoning/utils.py b/tests/reasoning/utils.py index 9af5fa5addbc..788136e99681 100644 --- a/tests/reasoning/utils.py +++ b/tests/reasoning/utils.py @@ -3,14 +3,12 @@ from typing import Optional, Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.reasoning import ReasoningParser from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer class StreamingReasoningReconstructor: - def __init__(self): self.reasoning_content = None self.other_content = None @@ -19,8 +17,8 @@ def append_delta(self, delta: DeltaMessage): # content and the reasoning content should not be present # at the same time assert delta.content is None or delta.reasoning_content is None, ( - "Both content and reasoning content are present in the " - "delta message") + "Both content and reasoning content are present in the delta message" + ) if delta.content is not None: if self.other_content is None: self.other_content = delta.content @@ -51,7 +49,8 @@ def run_reasoning_extraction( ) else: reasoning, content = run_reasoning_extraction_nonstreaming( - reasoning_parser, model_output, request) + reasoning_parser, model_output, request + ) return reasoning, content @@ -61,8 +60,9 @@ def run_reasoning_extraction_mistral( request: Union[ChatCompletionRequest, None] = None, streaming: bool = False, ) -> tuple[Optional[str], Optional[str]]: - assert isinstance(reasoning_parser.model_tokenizer, - MistralTokenizer), type(reasoning_parser.model_tokenizer) + assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type( + reasoning_parser.model_tokenizer + ) if streaming: reconstructor = run_reasoning_extraction_streaming_mistral( reasoning_parser, @@ -75,9 +75,11 @@ def run_reasoning_extraction_mistral( ) else: str_output = reasoning_parser.model_tokenizer.convert_ids_to_tokens( - model_output) + model_output + ) reasoning, content = run_reasoning_extraction_nonstreaming( - reasoning_parser, str_output, request) + reasoning_parser, str_output, request + ) return reasoning, content @@ -88,7 +90,8 @@ def 
run_reasoning_extraction_nonstreaming( ) -> tuple[Optional[str], Optional[str]]: request = request or ChatCompletionRequest(messages=[], model="test-model") return reasoning_parser.extract_reasoning_content( - model_output=''.join(model_output), request=request) + model_output="".join(model_output), request=request + ) def run_reasoning_extraction_streaming( @@ -128,16 +131,16 @@ def run_reasoning_extraction_streaming_mistral( model_deltas: list[int], request: Union[ChatCompletionRequest, None] = None, ) -> StreamingReasoningReconstructor: - assert isinstance(reasoning_parser.model_tokenizer, - MistralTokenizer), type(reasoning_parser.model_tokenizer) + assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type( + reasoning_parser.model_tokenizer + ) request = request or ChatCompletionRequest(messages=[], model="test-model") reconstructor = StreamingReasoningReconstructor() previous_text = "" previous_tokens: list[int] = [] for model_delta in model_deltas: token_delta = [model_delta] - delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens( - [model_delta])[0] + delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens([model_delta])[0] current_text = previous_text + delta current_tokens = previous_tokens + token_delta delta_message = reasoning_parser.extract_reasoning_content_streaming( diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 2960ffcbd9ea..78f5ab3e2d19 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -36,19 +36,21 @@ def test_beam_search_single_input( ) -> None: example_prompts = example_prompts[:1] with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, - max_tokens) + hf_outputs = hf_model.generate_beam_search( + example_prompts, beam_width, max_tokens + ) with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_beam_search(example_prompts, - beam_width, max_tokens) + vllm_outputs = vllm_model.generate_beam_search( + example_prompts, beam_width, max_tokens + ) for i in range(len(example_prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] vllm_output_ids, vllm_output_texts = vllm_outputs[i] - for j, (hf_text, - vllm_text) in enumerate(zip(hf_output_texts, - vllm_output_texts)): + for j, (hf_text, vllm_text) in enumerate( + zip(hf_output_texts, vllm_output_texts) + ): print(f">>>{j}-th hf output:") print(hf_text) print(f">>>{j}-th vllm output:") @@ -56,8 +58,8 @@ def test_beam_search_single_input( assert len(hf_output_ids) == len(vllm_output_ids) for j in range(len(hf_output_ids)): assert hf_output_ids[j] == vllm_output_ids[j], ( - f"Test{i} output{j}:\nHF: {hf_output_ids}\n" - f"vLLM: {vllm_output_ids}") + f"Test{i} output{j}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}" + ) @pytest.mark.skip_v1 # FIXME: This fails on V1 right now. @@ -76,30 +78,29 @@ def test_beam_search_with_concurrency_limit( ) -> None: # example_prompts[1]&[3]&[7] fails due to unknown reason even without # concurrency limit. skip them for now. 
- example_prompts = (example_prompts[:8]) + example_prompts = example_prompts[:8] concurrency_limit = 2 assert len(example_prompts) > concurrency_limit with vllm_runner(model, dtype=dtype) as vllm_model: outputs_with_limit = vllm_model.generate_beam_search( - example_prompts, - beam_width, - max_tokens, - concurrency_limit=concurrency_limit) + example_prompts, beam_width, max_tokens, concurrency_limit=concurrency_limit + ) outputs_without_limit = [] for i in range(0, len(example_prompts), concurrency_limit): outputs_without_limit.extend( vllm_model.generate_beam_search( - example_prompts[i:i + concurrency_limit], beam_width, - max_tokens)) + example_prompts[i : i + concurrency_limit], beam_width, max_tokens + ) + ) correct = True for i in range(len(example_prompts)): output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i] - output_ids_without_limit, output_texts_without_limit = ( - outputs_without_limit[i]) + output_ids_without_limit, output_texts_without_limit = outputs_without_limit[i] for j, (text_with_limit, text_without_limit) in enumerate( - zip(output_texts_with_limit, output_texts_without_limit)): + zip(output_texts_with_limit, output_texts_without_limit) + ): print(f">>>{j}-th with limit output:") print(text_with_limit) print(f">>>{j}-th without limit output:") @@ -107,8 +108,10 @@ def test_beam_search_with_concurrency_limit( assert len(output_ids_with_limit) == len(output_ids_without_limit) for j in range(len(output_ids_with_limit)): if output_ids_with_limit[j] != output_ids_without_limit[j]: - print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n" - f"-limit: {output_ids_without_limit}") + print( + f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n" + f"-limit: {output_ids_without_limit}" + ) correct = False assert correct @@ -131,11 +134,10 @@ def test_beam_search_passes_multimodal_data( model = "Qwen/Qwen2-Audio-7B-Instruct" audio_seq = "<|audio_bos|><|AUDIO|><|audio_eos|>" prompts = [ - f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" #noqa: E501 + f"<|im_start|>user\n{audio_seq}Can you transcribe this?<|im_end|>\n<|im_start|>assistant\n" # noqa: E501 ] - with hf_runner(model, dtype=dtype, - auto_cls=AutoModelForSeq2SeqLM) as hf_model: + with hf_runner(model, dtype=dtype, auto_cls=AutoModelForSeq2SeqLM) as hf_model: audio_token_id = hf_model.config.audio_token_index eos_token_id = hf_model.tokenizer.eos_token_id # <|im_end|> hf_outputs = hf_model.generate_beam_search( @@ -153,17 +155,15 @@ def test_beam_search_passes_multimodal_data( audios=audios, ) - seq_with_no_audio_toks = lambda seq: [ - tok for tok in seq if tok != audio_token_id - ] + seq_with_no_audio_toks = lambda seq: [tok for tok in seq if tok != audio_token_id] for i in range(len(prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] vllm_output_ids, vllm_output_texts = vllm_outputs[i] - for j, (hf_text, - vllm_text) in enumerate(zip(hf_output_texts, - vllm_output_texts)): + for j, (hf_text, vllm_text) in enumerate( + zip(hf_output_texts, vllm_output_texts) + ): print(f">>>{j}-th hf output [NOTE: special tokens are filtered]:") print(hf_text) print(f">>>{j}-th vllm output:") @@ -176,12 +176,10 @@ def test_beam_search_passes_multimodal_data( # token to match features, while the vLLM helper maintains the # single audio token in the input text filtered_hf_output_ids = seq_with_no_audio_toks(hf_output_ids[j]) - filtered_vllm_output_ids = seq_with_no_audio_toks( - vllm_output_ids[j]) + filtered_vllm_output_ids = 
seq_with_no_audio_toks(vllm_output_ids[j]) # HF output IDs may contain the end of sequence - if len(filtered_hf_output_ids - ) == len(filtered_vllm_output_ids) + 1: + if len(filtered_hf_output_ids) == len(filtered_vllm_output_ids) + 1: assert filtered_hf_output_ids[-1] == eos_token_id filtered_hf_output_ids = filtered_hf_output_ids[:-1] diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py index 1d77d37a5d58..d1609b24cc5a 100644 --- a/tests/samplers/test_ignore_eos.py +++ b/tests/samplers/test_ignore_eos.py @@ -25,11 +25,11 @@ def test_ignore_eos( max_tokens: int, ) -> None: with vllm_runner(model, dtype=dtype) as vllm_model: - sampling_params = SamplingParams(max_tokens=max_tokens, - ignore_eos=True) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) for prompt in example_prompts: ignore_eos_output = vllm_model.llm.generate( - prompt, sampling_params=sampling_params) + prompt, sampling_params=sampling_params + ) output_length = len(ignore_eos_output[0].outputs[0].token_ids) assert output_length == max_tokens diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index 128e8f552a16..42aebcd52414 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -5,6 +5,7 @@ Run `pytest tests/samplers/test_no_bad_words.py`. """ + from typing import Optional import pytest @@ -16,7 +17,7 @@ @pytest.fixture(autouse=True) def v1(monkeypatch): """Only run on vLLM v1.""" - monkeypatch.setenv('VLLM_USE_V1', '1') + monkeypatch.setenv("VLLM_USE_V1", "1") def _generate( @@ -49,25 +50,24 @@ class TestOneTokenBadWord: TARGET_TOKEN = "you" def setup_method(self, method): - self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL, - add_prefix_space=True) + self.tokenizer = AutoTokenizer.from_pretrained( + self.MODEL, add_prefix_space=True + ) self.num_prompt_tokens = len(self._encode(self.PROMPT)) - self.target_token_id = self._encode(self.TARGET_TOKEN, - add_special_tokens=False)[0] + self.target_token_id = self._encode( + self.TARGET_TOKEN, add_special_tokens=False + )[0] def test_one_token_bad_word(self, vllm_runner): with vllm_runner(self.MODEL) as llm: output_token_ids = self._generate(llm) assert output_token_ids[0] == self.target_token_id - output_token_ids = self._generate(llm, - bad_words=[self.TARGET_TOKEN]) + output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN]) assert self.target_token_id not in output_token_ids - def _generate(self, - llm: LLM, - bad_words: Optional[list[str]] = None) -> list[int]: + def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]: return _generate( llm=llm, prompt=self.PROMPT, @@ -75,11 +75,8 @@ def _generate(self, bad_words=bad_words, ) - def _encode(self, - prompt: str, - add_special_tokens: bool = True) -> list[int]: - return self.tokenizer(prompt, - add_special_tokens=add_special_tokens).input_ids + def _encode(self, prompt: str, add_special_tokens: bool = True) -> list[int]: + return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids class TestTwoTokenBadWord: @@ -92,72 +89,80 @@ class TestTwoTokenBadWord: NEIGHBOUR_TOKEN2 = "older" def setup_method(self, method): - self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL, - add_prefix_space=True) + self.tokenizer = AutoTokenizer.from_pretrained( + self.MODEL, add_prefix_space=True + ) self.num_prompt_tokens = len(self._encode(self.PROMPT)) - self.target_token_id1 = self._encode(self.TARGET_TOKEN1, - add_special_tokens=False)[0] - 
self.target_token_id2 = self._encode(self.TARGET_TOKEN2, - add_special_tokens=False)[0] - self.neighbour_token_id2 = self._encode(self.NEIGHBOUR_TOKEN2, - add_special_tokens=False)[0] + self.target_token_id1 = self._encode( + self.TARGET_TOKEN1, add_special_tokens=False + )[0] + self.target_token_id2 = self._encode( + self.TARGET_TOKEN2, add_special_tokens=False + )[0] + self.neighbour_token_id2 = self._encode( + self.NEIGHBOUR_TOKEN2, add_special_tokens=False + )[0] def test_two_token_bad_word(self, vllm_runner): with vllm_runner(self.MODEL, dtype="half") as llm: output_token_ids = self._generate(llm) assert output_token_ids[:2] == [ - self.target_token_id1, self.target_token_id2 + self.target_token_id1, + self.target_token_id2, ] - output_token_ids = self._generate(llm, - bad_words=[self.TARGET_TOKEN1]) + output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN1]) assert self.target_token_id1 not in output_token_ids - output_token_ids = self._generate(llm, - bad_words=[self.TARGET_TOKEN2]) + output_token_ids = self._generate(llm, bad_words=[self.TARGET_TOKEN2]) assert output_token_ids[0] == self.target_token_id1 assert self.target_token_id2 not in output_token_ids output_token_ids = self._generate( - llm, bad_words=[f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}']) + llm, bad_words=[f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}"] + ) assert output_token_ids[0] == self.target_token_id1 assert output_token_ids[:2] != [ - self.target_token_id1, self.target_token_id2 + self.target_token_id1, + self.target_token_id2, ] assert not self._contains( - output_token_ids, - [self.target_token_id1, self.target_token_id2]) + output_token_ids, [self.target_token_id1, self.target_token_id2] + ) # Model dependent behaviour assert output_token_ids[:2] == [ - self.target_token_id1, self.neighbour_token_id2 + self.target_token_id1, + self.neighbour_token_id2, ] output_token_ids = self._generate( llm, bad_words=[ - f'{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}', - f'{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}' - ]) + f"{self.TARGET_TOKEN1} {self.TARGET_TOKEN2}", + f"{self.TARGET_TOKEN1} {self.NEIGHBOUR_TOKEN2}", + ], + ) assert output_token_ids[0] == self.target_token_id1 assert output_token_ids[:2] != [ - self.target_token_id1, self.target_token_id2 + self.target_token_id1, + self.target_token_id2, ] assert not self._contains( - output_token_ids, - [self.target_token_id1, self.target_token_id2]) + output_token_ids, [self.target_token_id1, self.target_token_id2] + ) assert output_token_ids[:2] != [ - self.target_token_id1, self.neighbour_token_id2 + self.target_token_id1, + self.neighbour_token_id2, ] assert not self._contains( - output_token_ids, - [self.target_token_id1, self.neighbour_token_id2]) - assert ((self.target_token_id2 in output_token_ids) - or (self.neighbour_token_id2 in output_token_ids)) - - def _generate(self, - llm: LLM, - bad_words: Optional[list[str]] = None) -> list[int]: + output_token_ids, [self.target_token_id1, self.neighbour_token_id2] + ) + assert (self.target_token_id2 in output_token_ids) or ( + self.neighbour_token_id2 in output_token_ids + ) + + def _generate(self, llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]: return _generate( llm=llm, prompt=self.PROMPT, @@ -187,8 +192,5 @@ def _contains(sequence: list[int], subsequence: list[int]) -> bool: return False - def _encode(self, - prompt: str, - add_special_tokens: bool = True) -> list[int]: - return self.tokenizer(prompt, - add_special_tokens=add_special_tokens).input_ids + def _encode(self, prompt: str, 
add_special_tokens: bool = True) -> list[int]: + return self.tokenizer(prompt, add_special_tokens=add_special_tokens).input_ids diff --git a/tests/samplers/test_ranks.py b/tests/samplers/test_ranks.py index 220a4a53f467..1359e6403e4c 100644 --- a/tests/samplers/test_ranks.py +++ b/tests/samplers/test_ranks.py @@ -20,25 +20,27 @@ def test_ranks( num_top_logprobs = 5 num_prompt_logprobs = 5 - with vllm_runner(model, dtype=dtype, - max_logprobs=num_top_logprobs) as vllm_model: - + with vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) as vllm_model: ## Test greedy logprobs ranks vllm_sampling_params = SamplingParams( temperature=0.0, top_p=1.0, max_tokens=max_tokens, logprobs=num_top_logprobs, - prompt_logprobs=num_prompt_logprobs) - vllm_results = vllm_model.generate_w_logprobs(example_prompts, - vllm_sampling_params) + prompt_logprobs=num_prompt_logprobs, + ) + vllm_results = vllm_model.generate_w_logprobs( + example_prompts, vllm_sampling_params + ) ## Test non-greedy logprobs ranks - sampling_params = SamplingParams(temperature=1.0, - top_p=1.0, - max_tokens=max_tokens, - logprobs=num_top_logprobs, - prompt_logprobs=num_prompt_logprobs) + sampling_params = SamplingParams( + temperature=1.0, + top_p=1.0, + max_tokens=max_tokens, + logprobs=num_top_logprobs, + prompt_logprobs=num_prompt_logprobs, + ) res = vllm_model.generate_w_logprobs(example_prompts, sampling_params) for result in vllm_results: diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py index 87d799a5fed7..5ce6e1593b5c 100644 --- a/tests/speculative_decoding/speculators/test_eagle3.py +++ b/tests/speculative_decoding/speculators/test_eagle3.py @@ -7,19 +7,26 @@ from vllm.model_executor.models.interfaces import supports_eagle3 -@pytest.mark.parametrize("model_path", [ - pytest.param( - "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized", - id="llama3-eagle3-speculator"), - pytest.param( - "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized", - id="qwen3-eagle3-speculator"), - pytest.param( - "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", - id="qwen3-eagle3-speculator-w4a16-verifier"), -]) -def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path, - monkeypatch): +@pytest.mark.parametrize( + "model_path", + [ + pytest.param( + "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized", + id="llama3-eagle3-speculator", + ), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized", + id="qwen3-eagle3-speculator", + ), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", + id="qwen3-eagle3-speculator-w4a16-verifier", + ), + ], +) +def test_eagle3_speculators_model( + vllm_runner, example_prompts, model_path, monkeypatch +): """ Test Eagle3 speculators models properly initialize speculative decoding. 
@@ -40,18 +47,19 @@ def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path, vllm_config = vllm_model.llm.llm_engine.vllm_config - assert isinstance(vllm_config.speculative_config, SpeculativeConfig), \ + assert isinstance(vllm_config.speculative_config, SpeculativeConfig), ( "Speculative config should be initialized for speculators model" + ) spec_config = vllm_config.speculative_config - assert spec_config.num_speculative_tokens > 0, \ - (f"Expected positive speculative tokens, " - f"got {spec_config.num_speculative_tokens}") + assert spec_config.num_speculative_tokens > 0, ( + f"Expected positive speculative tokens, " + f"got {spec_config.num_speculative_tokens}" + ) - assert spec_config.model == model_path, \ + assert spec_config.model == model_path, ( f"Draft model should be {model_path}, got {spec_config.model}" + ) - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens=20) - assert vllm_outputs, \ - f"No outputs generated for speculators model {model_path}" + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20) + assert vllm_outputs, f"No outputs generated for speculators model {model_path}" diff --git a/tests/standalone_tests/lazy_imports.py b/tests/standalone_tests/lazy_imports.py index 21bcb6b822d1..ddcdd2a51ab9 100644 --- a/tests/standalone_tests/lazy_imports.py +++ b/tests/standalone_tests/lazy_imports.py @@ -37,4 +37,5 @@ def any_module_imported(): assert not any_module_imported(), ( f"Some the modules in {module_names} are imported. To see the first" - f" import location, run the test with `use_blame=True`.") + f" import location, run the test with `use_blame=True`." +) diff --git a/tests/test_config.py b/tests/test_config.py index 90d0c78c451f..f3d40a7d8081 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -23,8 +23,8 @@ def test_compile_config_repr_succeeds(): # test that repr(config) succeeds val = repr(config) - assert 'VllmConfig' in val - assert 'inductor_passes' in val + assert "VllmConfig" in val + assert "inductor_passes" in val @dataclass @@ -51,8 +51,7 @@ def test_get_field(): @dataclass class _TestNestedConfig: - a: _TestConfigFields = field( - default_factory=lambda: _TestConfigFields(a=0)) + a: _TestConfigFields = field(default_factory=lambda: _TestConfigFields(a=0)) def test_update_config(): @@ -79,20 +78,19 @@ def test_update_config(): # Can remove once --task option is fully deprecated @pytest.mark.parametrize( - ("model_id", "expected_runner_type", "expected_convert_type", - "expected_task"), + ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"), [ ("distilbert/distilgpt2", "generate", "none", "generate"), ("intfloat/multilingual-e5-small", "pooling", "none", "embed"), ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"), - ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", - "classify"), + ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "reward"), ("openai/whisper-small", "generate", "none", "transcription"), ], ) -def test_auto_task(model_id, expected_runner_type, expected_convert_type, - expected_task): +def test_auto_task( + model_id, expected_runner_type, expected_convert_type, expected_task +): config = ModelConfig(model_id, task="auto") assert config.runner_type == expected_runner_type @@ -101,20 +99,19 @@ def test_auto_task(model_id, expected_runner_type, expected_convert_type, # Can remove once --task option is fully deprecated @pytest.mark.parametrize( - 
("model_id", "expected_runner_type", "expected_convert_type", - "expected_task"), + ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"), [ ("distilbert/distilgpt2", "pooling", "embed", "embed"), ("intfloat/multilingual-e5-small", "pooling", "embed", "embed"), ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"), - ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify", - "classify"), + ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify", "classify"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed", "embed"), ("openai/whisper-small", "pooling", "embed", "embed"), ], ) -def test_score_task(model_id, expected_runner_type, expected_convert_type, - expected_task): +def test_score_task( + model_id, expected_runner_type, expected_convert_type, expected_task +): config = ModelConfig(model_id, task="score") assert config.runner_type == expected_runner_type @@ -123,14 +120,14 @@ def test_score_task(model_id, expected_runner_type, expected_convert_type, # Can remove once --task option is fully deprecated @pytest.mark.parametrize( - ("model_id", "expected_runner_type", "expected_convert_type", - "expected_task"), + ("model_id", "expected_runner_type", "expected_convert_type", "expected_task"), [ ("openai/whisper-small", "generate", "none", "transcription"), ], ) -def test_transcription_task(model_id, expected_runner_type, - expected_convert_type, expected_task): +def test_transcription_task( + model_id, expected_runner_type, expected_convert_type, expected_task +): config = ModelConfig(model_id, task="transcription") assert config.runner_type == expected_runner_type @@ -200,8 +197,9 @@ def test_disable_sliding_window(model_id_expected): assert model_config.max_model_len == expected -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_get_pooling_config(): model_id = "sentence-transformers/all-MiniLM-L12-v2" model_config = ModelConfig(model_id) @@ -211,8 +209,9 @@ def test_get_pooling_config(): assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." +) def test_get_pooling_config_from_args(): model_id = "sentence-transformers/all-MiniLM-L12-v2" pooler_config = PoolerConfig(pooling_type="CLS", normalize=True) @@ -227,16 +226,18 @@ def test_get_pooling_config_from_args(): ("tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "LAST", "LAST"), # LLM ("intfloat/e5-small", "CLS", "MEAN"), # BertModel ("Qwen/Qwen2.5-Math-RM-72B", "ALL", "ALL"), # reward - ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP") # step reward - ]) + ("Qwen/Qwen2.5-Math-PRM-7B", "STEP", "STEP"), # step reward + ], +) def test_default_pooling_type(model_id, default_pooling_type, pooling_type): model_config = ModelConfig(model_id) assert model_config._model_info.default_pooling_type == default_pooling_type assert model_config.pooler_config.pooling_type == pooling_type -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Xformers backend is not supported on ROCm.") +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." 
+) def test_get_bert_tokenization_sentence_transformer_config(): model_id = "BAAI/bge-base-en-v1.5" bge_model_config = ModelConfig(model_id) @@ -264,17 +265,18 @@ def test_rope_customization(): "rope_theta": TEST_ROPE_THETA, }, ) - assert getattr(llama_model_config.hf_config, "rope_scaling", - None) == TEST_ROPE_SCALING - assert getattr(llama_model_config.hf_config, "rope_theta", - None) == TEST_ROPE_THETA + assert ( + getattr(llama_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING + ) + assert getattr(llama_model_config.hf_config, "rope_theta", None) == TEST_ROPE_THETA assert llama_model_config.max_model_len == 16384 longchat_model_config = ModelConfig("lmsys/longchat-13b-16k") # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config assert all( longchat_model_config.hf_config.rope_scaling.get(key) == value - for key, value in LONGCHAT_ROPE_SCALING.items()) + for key, value in LONGCHAT_ROPE_SCALING.items() + ) assert longchat_model_config.max_model_len == 16384 longchat_model_config = ModelConfig( @@ -283,28 +285,37 @@ def test_rope_customization(): "rope_scaling": TEST_ROPE_SCALING, }, ) - assert getattr(longchat_model_config.hf_config, "rope_scaling", - None) == TEST_ROPE_SCALING + assert ( + getattr(longchat_model_config.hf_config, "rope_scaling", None) + == TEST_ROPE_SCALING + ) assert longchat_model_config.max_model_len == 4096 -@pytest.mark.skipif(current_platform.is_rocm(), - reason="Encoder Decoder models not supported on ROCm.") -@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [ - ("facebook/opt-125m", False), - ("openai/whisper-tiny", True), - ("meta-llama/Llama-3.2-1B-Instruct", False), -]) +@pytest.mark.skipif( + current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm." 
+) +@pytest.mark.parametrize( + ("model_id", "is_encoder_decoder"), + [ + ("facebook/opt-125m", False), + ("openai/whisper-tiny", True), + ("meta-llama/Llama-3.2-1B-Instruct", False), + ], +) def test_is_encoder_decoder(model_id, is_encoder_decoder): config = ModelConfig(model_id) assert config.is_encoder_decoder == is_encoder_decoder -@pytest.mark.parametrize(("model_id", "uses_mrope"), [ - ("facebook/opt-125m", False), - ("Qwen/Qwen2-VL-2B-Instruct", True), -]) +@pytest.mark.parametrize( + ("model_id", "uses_mrope"), + [ + ("facebook/opt-125m", False), + ("Qwen/Qwen2-VL-2B-Instruct", True), + ], +) def test_uses_mrope(model_id, uses_mrope): config = ModelConfig(model_id) @@ -338,7 +349,8 @@ def test_generation_config_loading(): model_config = ModelConfig( model_id, generation_config="auto", - override_generation_config=override_generation_config) + override_generation_config=override_generation_config, + ) override_result = correct_generation_config.copy() override_result.update(override_generation_config) @@ -350,17 +362,19 @@ def test_generation_config_loading(): model_config = ModelConfig( model_id, generation_config="vllm", - override_generation_config=override_generation_config) + override_generation_config=override_generation_config, + ) assert model_config.get_diff_sampling_param() == override_generation_config -@pytest.mark.parametrize("pt_load_map_location", [ - "cuda", - { - "": "cuda" - }, -]) +@pytest.mark.parametrize( + "pt_load_map_location", + [ + "cuda", + {"": "cuda"}, + ], +) def test_load_config_pt_load_map_location(pt_load_map_location): load_config = LoadConfig(pt_load_map_location=pt_load_map_location) config = VllmConfig(load_config=load_config) @@ -369,15 +383,18 @@ def test_load_config_pt_load_map_location(pt_load_map_location): @pytest.mark.parametrize( - ("model_id", "max_model_len", "expected_max_len", "should_raise"), [ + ("model_id", "max_model_len", "expected_max_len", "should_raise"), + [ ("BAAI/bge-reranker-base", None, 512, False), ("BAAI/bge-reranker-base", 256, 256, False), ("BAAI/bge-reranker-base", 513, 512, True), ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", None, 131072, False), ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", 131073, 131072, True), - ]) -def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len, - should_raise): + ], +) +def test_get_and_verify_max_len( + model_id, max_model_len, expected_max_len, should_raise +): """Test get_and_verify_max_len with different configurations.""" model_config = ModelConfig(model_id) @@ -398,11 +415,14 @@ def __init__(self, model: str, tokenizer: str): self.model_weights = None -@pytest.mark.parametrize("s3_url", [ - "s3://example-bucket-1/model/", - "s3://example-bucket-2/model/", -]) -@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files') +@pytest.mark.parametrize( + "s3_url", + [ + "s3://example-bucket-1/model/", + "s3://example-bucket-2/model/", + ], +) +@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files") def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url): """Test that S3 URLs create deterministic local directories for model and tokenizer.""" @@ -414,22 +434,24 @@ def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url): ModelConfig.maybe_pull_model_tokenizer_for_runai(config1, s3_url, s3_url) # Check that model and tokenizer point to existing directories - assert os.path.exists( - config1.model), f"Model directory does not exist: {config1.model}" - assert os.path.isdir( - config1.model), f"Model path is not a directory: 
{config1.model}" - assert os.path.exists( - config1.tokenizer - ), f"Tokenizer directory does not exist: {config1.tokenizer}" - assert os.path.isdir( - config1.tokenizer - ), f"Tokenizer path is not a directory: {config1.tokenizer}" + assert os.path.exists(config1.model), ( + f"Model directory does not exist: {config1.model}" + ) + assert os.path.isdir(config1.model), ( + f"Model path is not a directory: {config1.model}" + ) + assert os.path.exists(config1.tokenizer), ( + f"Tokenizer directory does not exist: {config1.tokenizer}" + ) + assert os.path.isdir(config1.tokenizer), ( + f"Tokenizer path is not a directory: {config1.tokenizer}" + ) # Verify that the paths are different from the original S3 URL - assert config1.model != s3_url, ( - "Model path should be converted to local directory") + assert config1.model != s3_url, "Model path should be converted to local directory" assert config1.tokenizer != s3_url, ( - "Tokenizer path should be converted to local directory") + "Tokenizer path should be converted to local directory" + ) # Store the original paths created_model_dir = config1.model @@ -440,27 +462,31 @@ def test_s3_url_model_tokenizer_paths(mock_pull_files, s3_url): ModelConfig.maybe_pull_model_tokenizer_for_runai(config2, s3_url, s3_url) # Check that the new directories exist - assert os.path.exists( - config2.model), f"Model directory does not exist: {config2.model}" - assert os.path.isdir( - config2.model), f"Model path is not a directory: {config2.model}" - assert os.path.exists( - config2.tokenizer - ), f"Tokenizer directory does not exist: {config2.tokenizer}" - assert os.path.isdir( - config2.tokenizer - ), f"Tokenizer path is not a directory: {config2.tokenizer}" + assert os.path.exists(config2.model), ( + f"Model directory does not exist: {config2.model}" + ) + assert os.path.isdir(config2.model), ( + f"Model path is not a directory: {config2.model}" + ) + assert os.path.exists(config2.tokenizer), ( + f"Tokenizer directory does not exist: {config2.tokenizer}" + ) + assert os.path.isdir(config2.tokenizer), ( + f"Tokenizer path is not a directory: {config2.tokenizer}" + ) # Verify that the paths are deterministic (same as before) assert config2.model == created_model_dir, ( f"Model paths are not deterministic. " - f"Original: {created_model_dir}, New: {config2.model}") + f"Original: {created_model_dir}, New: {config2.model}" + ) assert config2.tokenizer == create_tokenizer_dir, ( f"Tokenizer paths are not deterministic. " - f"Original: {create_tokenizer_dir}, New: {config2.tokenizer}") + f"Original: {create_tokenizer_dir}, New: {config2.tokenizer}" + ) -@patch('vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files') +@patch("vllm.transformers_utils.runai_utils.ObjectStorageModel.pull_files") def test_s3_url_different_models_create_different_directories(mock_pull_files): """Test that different S3 URLs create different local directories.""" # Mock pull_files to avoid actually downloading files during tests @@ -479,16 +505,16 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files): # Verify that different URLs produce different directories assert config1.model != config2.model, ( f"Different S3 URLs should create different model directories. " - f"URL1 model: {config1.model}, URL2 model: {config2.model}") + f"URL1 model: {config1.model}, URL2 model: {config2.model}" + ) assert config1.tokenizer != config2.tokenizer, ( f"Different S3 URLs should create different tokenizer directories. 
" f"URL1 tokenizer: {config1.tokenizer}, " - f"URL2 tokenizer: {config2.tokenizer}") + f"URL2 tokenizer: {config2.tokenizer}" + ) # Verify that both sets of directories exist assert os.path.exists(config1.model) and os.path.isdir(config1.model) - assert os.path.exists(config1.tokenizer) and os.path.isdir( - config1.tokenizer) + assert os.path.exists(config1.tokenizer) and os.path.isdir(config1.tokenizer) assert os.path.exists(config2.model) and os.path.isdir(config2.model) - assert os.path.exists(config2.tokenizer) and os.path.isdir( - config2.tokenizer) + assert os.path.exists(config2.tokenizer) and os.path.isdir(config2.tokenizer) diff --git a/tests/test_embedded_commit.py b/tests/test_embedded_commit.py index b9593e2a3b7c..687a15446fc2 100644 --- a/tests/test_embedded_commit.py +++ b/tests/test_embedded_commit.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import vllm - - -def test_embedded_commit_defined(): - assert hasattr(vllm, "__version__") - assert hasattr(vllm, "__version_tuple__") - assert vllm.__version__ != "dev" - assert vllm.__version_tuple__ != (0, 0, "dev") +import vllm + + +def test_embedded_commit_defined(): + assert hasattr(vllm, "__version__") + assert hasattr(vllm, "__version_tuple__") + assert vllm.__version__ != "dev" + assert vllm.__version_tuple__ != (0, 0, "dev") diff --git a/tests/test_envs.py b/tests/test_envs.py index f81a6e2e415c..62d529c36360 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -14,68 +14,71 @@ class TestEnvWithChoices: def test_default_value_returned_when_env_not_set(self): """Test default is returned when env var is not set.""" - env_func = env_with_choices("NONEXISTENT_ENV", "default", - ["option1", "option2"]) + env_func = env_with_choices( + "NONEXISTENT_ENV", "default", ["option1", "option2"] + ) assert env_func() == "default" def test_none_default_returned_when_env_not_set(self): """Test that None is returned when env not set and default is None.""" - env_func = env_with_choices("NONEXISTENT_ENV", None, - ["option1", "option2"]) + env_func = env_with_choices("NONEXISTENT_ENV", None, ["option1", "option2"]) assert env_func() is None def test_valid_value_returned_case_sensitive(self): """Test that valid value is returned in case sensitive mode.""" with patch.dict(os.environ, {"TEST_ENV": "option1"}): - env_func = env_with_choices("TEST_ENV", - "default", ["option1", "option2"], - case_sensitive=True) + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=True + ) assert env_func() == "option1" def test_valid_lowercase_value_returned_case_insensitive(self): """Test that lowercase value is accepted in case insensitive mode.""" with patch.dict(os.environ, {"TEST_ENV": "option1"}): - env_func = env_with_choices("TEST_ENV", - "default", ["OPTION1", "OPTION2"], - case_sensitive=False) + env_func = env_with_choices( + "TEST_ENV", "default", ["OPTION1", "OPTION2"], case_sensitive=False + ) assert env_func() == "option1" def test_valid_uppercase_value_returned_case_insensitive(self): """Test that uppercase value is accepted in case insensitive mode.""" with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}): - env_func = env_with_choices("TEST_ENV", - "default", ["option1", "option2"], - case_sensitive=False) + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=False + ) assert env_func() == "OPTION1" def test_invalid_value_raises_error_case_sensitive(self): """Test that invalid 
value raises ValueError in case sensitive mode.""" with patch.dict(os.environ, {"TEST_ENV": "invalid"}): - env_func = env_with_choices("TEST_ENV", - "default", ["option1", "option2"], - case_sensitive=True) - with pytest.raises(ValueError, - match="Invalid value 'invalid' for TEST_ENV"): + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=True + ) + with pytest.raises( + ValueError, match="Invalid value 'invalid' for TEST_ENV" + ): env_func() def test_case_mismatch_raises_error_case_sensitive(self): """Test that case mismatch raises ValueError in case sensitive mode.""" with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}): - env_func = env_with_choices("TEST_ENV", - "default", ["option1", "option2"], - case_sensitive=True) - with pytest.raises(ValueError, - match="Invalid value 'OPTION1' for TEST_ENV"): + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=True + ) + with pytest.raises( + ValueError, match="Invalid value 'OPTION1' for TEST_ENV" + ): env_func() def test_invalid_value_raises_error_case_insensitive(self): """Test that invalid value raises ValueError when case insensitive.""" with patch.dict(os.environ, {"TEST_ENV": "invalid"}): - env_func = env_with_choices("TEST_ENV", - "default", ["option1", "option2"], - case_sensitive=False) - with pytest.raises(ValueError, - match="Invalid value 'invalid' for TEST_ENV"): + env_func = env_with_choices( + "TEST_ENV", "default", ["option1", "option2"], case_sensitive=False + ) + with pytest.raises( + ValueError, match="Invalid value 'invalid' for TEST_ENV" + ): env_func() def test_callable_choices_resolved_correctly(self): @@ -96,8 +99,9 @@ def get_choices(): with patch.dict(os.environ, {"TEST_ENV": "invalid"}): env_func = env_with_choices("TEST_ENV", "default", get_choices) - with pytest.raises(ValueError, - match="Invalid value 'invalid' for TEST_ENV"): + with pytest.raises( + ValueError, match="Invalid value 'invalid' for TEST_ENV" + ): env_func() @@ -106,84 +110,78 @@ class TestEnvListWithChoices: def test_default_list_returned_when_env_not_set(self): """Test that default list is returned when env var is not set.""" - env_func = env_list_with_choices("NONEXISTENT_ENV", - ["default1", "default2"], - ["option1", "option2"]) + env_func = env_list_with_choices( + "NONEXISTENT_ENV", ["default1", "default2"], ["option1", "option2"] + ) assert env_func() == ["default1", "default2"] def test_empty_default_list_returned_when_env_not_set(self): """Test that empty default list is returned when env not set.""" - env_func = env_list_with_choices("NONEXISTENT_ENV", [], - ["option1", "option2"]) + env_func = env_list_with_choices("NONEXISTENT_ENV", [], ["option1", "option2"]) assert env_func() == [] def test_single_valid_value_parsed_correctly(self): """Test that single valid value is parsed correctly.""" with patch.dict(os.environ, {"TEST_ENV": "option1"}): - env_func = env_list_with_choices("TEST_ENV", [], - ["option1", "option2"]) + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) assert env_func() == ["option1"] def test_multiple_valid_values_parsed_correctly(self): """Test that multiple valid values are parsed correctly.""" with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}): - env_func = env_list_with_choices("TEST_ENV", [], - ["option1", "option2"]) + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) assert env_func() == ["option1", "option2"] def test_values_with_whitespace_trimmed(self): """Test that 
values with whitespace are trimmed correctly.""" with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}): - env_func = env_list_with_choices("TEST_ENV", [], - ["option1", "option2"]) + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) assert env_func() == ["option1", "option2"] def test_empty_values_filtered_out(self): """Test that empty values are filtered out.""" with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}): - env_func = env_list_with_choices("TEST_ENV", [], - ["option1", "option2"]) + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) assert env_func() == ["option1", "option2"] def test_empty_string_returns_default(self): """Test that empty string returns default.""" with patch.dict(os.environ, {"TEST_ENV": ""}): - env_func = env_list_with_choices("TEST_ENV", ["default"], - ["option1", "option2"]) + env_func = env_list_with_choices( + "TEST_ENV", ["default"], ["option1", "option2"] + ) assert env_func() == ["default"] def test_only_commas_returns_default(self): """Test that string with only commas returns default.""" with patch.dict(os.environ, {"TEST_ENV": ",,,"}): - env_func = env_list_with_choices("TEST_ENV", ["default"], - ["option1", "option2"]) + env_func = env_list_with_choices( + "TEST_ENV", ["default"], ["option1", "option2"] + ) assert env_func() == ["default"] def test_case_sensitive_validation(self): """Test case sensitive validation.""" with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}): - env_func = env_list_with_choices("TEST_ENV", [], - ["option1", "option2"], - case_sensitive=True) - with pytest.raises(ValueError, - match="Invalid value 'OPTION2' in TEST_ENV"): + env_func = env_list_with_choices( + "TEST_ENV", [], ["option1", "option2"], case_sensitive=True + ) + with pytest.raises(ValueError, match="Invalid value 'OPTION2' in TEST_ENV"): env_func() def test_case_insensitive_validation(self): """Test case insensitive validation.""" with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}): - env_func = env_list_with_choices("TEST_ENV", [], - ["option1", "option2"], - case_sensitive=False) + env_func = env_list_with_choices( + "TEST_ENV", [], ["option1", "option2"], case_sensitive=False + ) assert env_func() == ["OPTION1", "option2"] def test_invalid_value_in_list_raises_error(self): """Test that invalid value in list raises ValueError.""" with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}): - env_func = env_list_with_choices("TEST_ENV", [], - ["option1", "option2"]) - with pytest.raises(ValueError, - match="Invalid value 'invalid' in TEST_ENV"): + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) + with pytest.raises(ValueError, match="Invalid value 'invalid' in TEST_ENV"): env_func() def test_callable_choices_resolved_correctly(self): @@ -204,13 +202,11 @@ def get_choices(): with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}): env_func = env_list_with_choices("TEST_ENV", [], get_choices) - with pytest.raises(ValueError, - match="Invalid value 'invalid' in TEST_ENV"): + with pytest.raises(ValueError, match="Invalid value 'invalid' in TEST_ENV"): env_func() def test_duplicate_values_preserved(self): """Test that duplicate values in the list are preserved.""" with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}): - env_func = env_list_with_choices("TEST_ENV", [], - ["option1", "option2"]) + env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"]) assert env_func() == ["option1", "option1", "option2"] diff 
--git a/tests/test_inputs.py b/tests/test_inputs.py index b61b95bc4333..77379cc8de90 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -3,17 +3,19 @@ import pytest +from vllm.config import ModelConfig from vllm.inputs import zip_enc_dec_prompts -from vllm.inputs.parse import parse_and_batch_prompt +from vllm.inputs.parse import parse_raw_prompts +from vllm.inputs.preprocess import InputPreprocessor pytestmark = pytest.mark.cpu_test STRING_INPUTS = [ - '', - 'foo', - 'foo bar', - 'foo baz bar', - 'foo bar qux baz', + "", + "foo", + "foo bar", + "foo baz bar", + "foo bar qux baz", ] TOKEN_INPUTS = [ @@ -31,52 +33,105 @@ ] -def test_parse_single_batch_empty(): +def test_parse_raw_single_batch_empty(): with pytest.raises(ValueError, match="at least one prompt"): - parse_and_batch_prompt([]) + parse_raw_prompts([]) with pytest.raises(ValueError, match="at least one prompt"): - parse_and_batch_prompt([[]]) + parse_raw_prompts([[]]) -@pytest.mark.parametrize('string_input', STRING_INPUTS) -def test_parse_single_batch_string_consistent(string_input: str): - assert parse_and_batch_prompt(string_input) \ - == parse_and_batch_prompt([string_input]) +@pytest.mark.parametrize("string_input", STRING_INPUTS) +def test_parse_raw_single_batch_string_consistent(string_input: str): + assert parse_raw_prompts(string_input) == parse_raw_prompts([string_input]) -@pytest.mark.parametrize('token_input', TOKEN_INPUTS) -def test_parse_single_batch_token_consistent(token_input: list[int]): - assert parse_and_batch_prompt(token_input) \ - == parse_and_batch_prompt([token_input]) +@pytest.mark.parametrize("token_input", TOKEN_INPUTS) +def test_parse_raw_single_batch_token_consistent(token_input: list[int]): + assert parse_raw_prompts(token_input) == parse_raw_prompts([token_input]) -@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES) -def test_parse_single_batch_string_slice(inputs_slice: slice): - assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ - == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) +@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES) +def test_parse_raw_single_batch_string_slice(inputs_slice: slice): + assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] == parse_raw_prompts( + STRING_INPUTS[inputs_slice] + ) -# yapf: disable -@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [ - (None, [{}, {}]), - ({}, [{}, {}]), - ({"foo": 100}, [{"foo": 100}, {"foo": 100}]), - ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]), -]) -# yapf: enable +@pytest.mark.parametrize( + "mm_processor_kwargs,expected_mm_kwargs", + [ + (None, [{}, {}]), + ({}, [{}, {}]), + ({"foo": 100}, [{"foo": 100}, {"foo": 100}]), + ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]), + ], +) def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): """Test mm_processor_kwargs init for zipping enc/dec prompts.""" - encoder_prompts = ['An encoder prompt', 'Another encoder prompt'] - decoder_prompts = ['A decoder prompt', 'Another decoder prompt'] - zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts, - mm_processor_kwargs) + encoder_prompts = ["An encoder prompt", "Another encoder prompt"] + decoder_prompts = ["A decoder prompt", "Another decoder prompt"] + zipped_prompts = zip_enc_dec_prompts( + encoder_prompts, decoder_prompts, mm_processor_kwargs + ) assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts) - for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts, - expected_mm_kwargs, - 
zipped_prompts): + for enc, dec, exp_kwargs, zipped in zip( + encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts + ): assert isinstance(zipped, dict) assert len(zipped.keys()) == 3 - assert zipped['encoder_prompt'] == enc - assert zipped['decoder_prompt'] == dec - assert zipped['mm_processor_kwargs'] == exp_kwargs + assert zipped["encoder_prompt"] == enc + assert zipped["decoder_prompt"] == dec + assert zipped["mm_processor_kwargs"] == exp_kwargs + + +@pytest.mark.parametrize( + "model_id", + [ + "facebook/opt-125m", + ], +) +@pytest.mark.parametrize( + "prompt", + [ + { + "prompt": "", + "multi_modal_data": {"dummy": []}, + }, + { + "prompt_token_ids": [], + "multi_modal_data": {"dummy": []}, + }, + ], +) +def test_preprocessor_text_no_mm_inputs(model_id, prompt): + model_config = ModelConfig(model=model_id) + input_preprocessor = InputPreprocessor(model_config) + + with pytest.raises(ValueError, match="does not support multimodal inputs"): + input_preprocessor.preprocess(prompt) + + +@pytest.mark.parametrize( + "model_id", + [ + "facebook/chameleon-7b", + ], +) +@pytest.mark.parametrize( + "prompt", + [ + "", + {"prompt_token_ids": []}, + ], +) +def test_preprocessor_always_mm_code_path(model_id, prompt): + model_config = ModelConfig(model=model_id) + input_preprocessor = InputPreprocessor(model_config) + tokenizer = input_preprocessor.tokenizer + + # HF processor adds sep token + sep_token_id = tokenizer.vocab[tokenizer.sep_token] + + processed_inputs = input_preprocessor.preprocess(prompt) + assert sep_token_id in processed_inputs["prompt_token_ids"] diff --git a/tests/test_logger.py b/tests/test_logger.py index 0bfb449cdf21..ec368d4897b5 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -16,8 +16,13 @@ import pytest from vllm.entrypoints.logger import RequestLogger -from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger, - enable_trace_function_call, init_logger) +from vllm.logger import ( + _DATE_FORMAT, + _FORMAT, + _configure_vllm_root_logger, + enable_trace_function_call, + init_logger, +) from vllm.logging_utils import NewLineFormatter from vllm.logging_utils.dump_input import prepare_object_to_dump @@ -129,8 +134,7 @@ def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(): with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write("---\nloggers: []\nversion: 1") logging_config_file.flush() - with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name): + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name): with pytest.raises(JSONDecodeError) as ex_info: _configure_vllm_root_logger() assert ex_info.type == JSONDecodeError @@ -138,24 +142,24 @@ def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(): @patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) -@pytest.mark.parametrize("unexpected_config", ( - "Invalid string", - [{ - "version": 1, - "loggers": [] - }], - 0, -)) +@pytest.mark.parametrize( + "unexpected_config", + ( + "Invalid string", + [{"version": 1, "loggers": []}], + 0, + ), +) def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( - unexpected_config: Any): + unexpected_config: Any, +): """This test calls _configure_vllm_root_logger again to test custom logging config behavior, however it fails before any change in behavior or configuration occurs.""" with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(unexpected_config)) 
logging_config_file.flush() - with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name): + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name): with pytest.raises(ValueError) as ex_info: _configure_vllm_root_logger() assert ex_info.type == ValueError # noqa: E721 @@ -174,14 +178,15 @@ def test_custom_logging_config_is_parsed_and_used_when_provided(): "propagate": False, } }, - "version": 1 + "version": 1, } with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(valid_logging_config)) logging_config_file.flush() - with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name), patch( - "vllm.logger.dictConfig") as dict_config_mock: + with ( + patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name), + patch("vllm.logger.dictConfig") as dict_config_mock, + ): _configure_vllm_root_logger() dict_config_mock.assert_called_with(valid_logging_config) @@ -197,19 +202,19 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): "handlers": [], } }, - "version": 1 + "version": 1, } with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: logging_config_file.write(json.dumps(valid_logging_config)) logging_config_file.flush() - with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", - logging_config_file.name): + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name): with pytest.raises(RuntimeError) as ex_info: _configure_vllm_root_logger() assert ex_info.type is RuntimeError expected_message_snippet = ( "VLLM_CONFIGURE_LOGGING evaluated to false, but " - "VLLM_LOGGING_CONFIG_PATH was given.") + "VLLM_LOGGING_CONFIG_PATH was given." + ) assert expected_message_snippet in str(ex_info) # Remember! 
The root logger is assumed to have been configured as @@ -223,11 +228,11 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): def test_prepare_object_to_dump(): - str_obj = 'str' + str_obj = "str" assert prepare_object_to_dump(str_obj) == "'str'" list_obj = [1, 2, 3] - assert prepare_object_to_dump(list_obj) == '[1, 2, 3]' + assert prepare_object_to_dump(list_obj) == "[1, 2, 3]" dict_obj = {"a": 1, "b": "b"} assert prepare_object_to_dump(dict_obj) in [ @@ -236,9 +241,9 @@ def test_prepare_object_to_dump(): ] set_obj = {1, 2, 3} - assert prepare_object_to_dump(set_obj) == '[1, 2, 3]' + assert prepare_object_to_dump(set_obj) == "[1, 2, 3]" - tuple_obj = ('a', 'b', 'c') + tuple_obj = ("a", "b", "c") assert prepare_object_to_dump(tuple_obj) == "['a', 'b', 'c']" class CustomEnum(enum.Enum): @@ -253,8 +258,7 @@ class CustomClass: a: int b: str - assert (prepare_object_to_dump(CustomClass( - 1, "b")) == "CustomClass(a=1, b='b')") + assert prepare_object_to_dump(CustomClass(1, "b")) == "CustomClass(a=1, b='b')" def test_request_logger_log_outputs(): @@ -467,7 +471,7 @@ def test_request_logger_log_outputs_integration(): def test_streaming_complete_logs_full_text_content(): """Test that streaming complete logging includes - full accumulated text, not just token count.""" + full accumulated text, not just token count.""" mock_logger = MagicMock() with patch("vllm.entrypoints.logger.logger", mock_logger): diff --git a/tests/test_outputs.py b/tests/test_outputs.py index 46da83a428e5..7b234884c569 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -9,11 +9,13 @@ def test_request_output_forward_compatible(): - output = RequestOutput(request_id="test_request_id", - prompt="test prompt", - prompt_token_ids=[1, 2, 3], - prompt_logprobs=None, - outputs=[], - finished=False, - example_arg_added_in_new_version="some_value") + output = RequestOutput( + request_id="test_request_id", + prompt="test prompt", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[], + finished=False, + example_arg_added_in_new_version="some_value", + ) assert output is not None diff --git a/tests/test_pooling_params.py b/tests/test_pooling_params.py index 52c03015483c..e3561ac3a577 100644 --- a/tests/test_pooling_params.py +++ b/tests/test_pooling_params.py @@ -8,9 +8,11 @@ EMBEDDING_MODELS = [ EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False), - EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5", - is_matryoshka=True, - matryoshka_dimensions=[256]), + EmbedModelInfo( + "Snowflake/snowflake-arctic-embed-m-v1.5", + is_matryoshka=True, + matryoshka_dimensions=[256], + ), ] @@ -65,8 +67,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo): if model_info.is_matryoshka: assert model_info.matryoshka_dimensions is not None - pooling_params = PoolingParams( - dimensions=model_info.matryoshka_dimensions[0]) + pooling_params = PoolingParams(dimensions=model_info.matryoshka_dimensions[0]) pooling_params.verify(task=task, model_config=model_config) diff --git a/tests/test_regression.py b/tests/test_regression.py index f5f1ed8e805e..8a9829e4dba5 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -6,6 +6,7 @@ will never happen again. 
""" + import gc import pytest @@ -18,12 +19,12 @@ def test_duplicated_ignored_sequence_group(): """https://github.com/vllm-project/vllm/issues/1655""" - sampling_params = SamplingParams(temperature=0.01, - top_p=0.1, - max_tokens=256) - llm = LLM(model="distilbert/distilgpt2", - max_num_batched_tokens=4096, - tensor_parallel_size=1) + sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=256) + llm = LLM( + model="distilbert/distilgpt2", + max_num_batched_tokens=4096, + tensor_parallel_size=1, + ) prompts = ["This is a short prompt", "This is a very long prompt " * 1000] outputs = llm.generate(prompts, sampling_params=sampling_params) @@ -31,12 +32,12 @@ def test_duplicated_ignored_sequence_group(): def test_max_tokens_none(): - sampling_params = SamplingParams(temperature=0.01, - top_p=0.1, - max_tokens=None) - llm = LLM(model="distilbert/distilgpt2", - max_num_batched_tokens=4096, - tensor_parallel_size=1) + sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) + llm = LLM( + model="distilbert/distilgpt2", + max_num_batched_tokens=4096, + tensor_parallel_size=1, + ) prompts = ["Just say hello!"] outputs = llm.generate(prompts, sampling_params=sampling_params) diff --git a/tests/test_routing_simulator.py b/tests/test_routing_simulator.py index 77501f4bddc2..5a162fa8f791 100644 --- a/tests/test_routing_simulator.py +++ b/tests/test_routing_simulator.py @@ -13,7 +13,9 @@ import torch from vllm.model_executor.layers.fused_moe.routing_simulator import ( - DistributionBasedRouting, RoutingSimulator) + DistributionBasedRouting, + RoutingSimulator, +) @pytest.fixture @@ -60,10 +62,10 @@ def test_basic_functionality( ), f"Wrong ids shape for {strategy}" # Check that expert IDs are valid - assert (topk_ids.min() - >= 0), f"Invalid expert ID (negative) for {strategy}" - assert (topk_ids.max() - < num_experts), f"Invalid expert ID (too large) for {strategy}" + assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}" + assert topk_ids.max() < num_experts, ( + f"Invalid expert ID (too large) for {strategy}" + ) def test_routing_strategy_integration(monkeypatch, device): @@ -102,19 +104,20 @@ def test_routing_strategy_integration(monkeypatch, device): top_k=top_k, use_grouped_topk=False, renormalize=True, - indices_type=torch.long) + indices_type=torch.long, + ) # Verify output shapes - assert topk_weights.shape == ( - num_tokens, top_k), f"Wrong weights shape for {strategy}" - assert topk_ids.shape == (num_tokens, - top_k), f"Wrong ids shape for {strategy}" + assert topk_weights.shape == (num_tokens, top_k), ( + f"Wrong weights shape for {strategy}" + ) + assert topk_ids.shape == (num_tokens, top_k), f"Wrong ids shape for {strategy}" # Verify expert IDs are valid - assert topk_ids.min( - ) >= 0, f"Invalid expert ID (negative) for {strategy}" - assert topk_ids.max( - ) < num_experts, f"Invalid expert ID (too large) for {strategy}" + assert topk_ids.min() >= 0, f"Invalid expert ID (negative) for {strategy}" + assert topk_ids.max() < num_experts, ( + f"Invalid expert ID (too large) for {strategy}" + ) def test_distribution_based_routing_with_custom_strategy(): @@ -123,9 +126,7 @@ def test_distribution_based_routing_with_custom_strategy(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Register custom distribution-based strategy - custom_strategy = DistributionBasedRouting(distribution="normal", - mean=2.0, - std=0.5) + custom_strategy = DistributionBasedRouting(distribution="normal", mean=2.0, std=0.5) 
RoutingSimulator.register_strategy("custom_normal", custom_strategy) # Test data @@ -142,7 +143,8 @@ def test_distribution_based_routing_with_custom_strategy(): hidden_states=hidden_states, router_logits=router_logits, strategy_name="custom_normal", - top_k=top_k) + top_k=top_k, + ) # Check output shapes assert topk_weights.shape == (num_tokens, top_k) @@ -165,7 +167,8 @@ def test_instance_compatibility(): hidden_states=hidden_states, router_logits=router_logits, strategy_name="uniform_random", - top_k=2) + top_k=2, + ) assert topk_weights.shape == (10, 2) assert topk_ids.shape == (10, 2) diff --git a/tests/test_scalartype.py b/tests/test_scalartype.py index ef4aef3afc2e..5361efbbdf6f 100644 --- a/tests/test_scalartype.py +++ b/tests/test_scalartype.py @@ -7,21 +7,24 @@ from vllm.scalar_type import scalar_types -@pytest.mark.parametrize("type_tuple", ( - (-8, 7, scalar_types.int4), - (0, 15, scalar_types.uint4), - (-8, 7, scalar_types.uint4b8), - (-128, 127, scalar_types.uint8b128), - (-6., 6., scalar_types.float4_e2m1f), - (-28., 28., scalar_types.float6_e3m2f), - (torch.int8, scalar_types.int8), - (torch.uint8, scalar_types.uint8), - (torch.float8_e5m2, scalar_types.float8_e5m2), - (torch.float8_e4m3fn, scalar_types.float8_e4m3fn), - (torch.bfloat16, scalar_types.float16_e8m7), - (torch.float16, scalar_types.float16_e5m10), -), - ids=lambda x: str(x)) +@pytest.mark.parametrize( + "type_tuple", + ( + (-8, 7, scalar_types.int4), + (0, 15, scalar_types.uint4), + (-8, 7, scalar_types.uint4b8), + (-128, 127, scalar_types.uint8b128), + (-6.0, 6.0, scalar_types.float4_e2m1f), + (-28.0, 28.0, scalar_types.float6_e3m2f), + (torch.int8, scalar_types.int8), + (torch.uint8, scalar_types.uint8), + (torch.float8_e5m2, scalar_types.float8_e5m2), + (torch.float8_e4m3fn, scalar_types.float8_e4m3fn), + (torch.bfloat16, scalar_types.float16_e8m7), + (torch.float16, scalar_types.float16_e5m10), + ), + ids=lambda x: str(x), +) def test_scalar_type_min_max(type_tuple): print(type_tuple) if len(type_tuple) == 3: diff --git a/tests/test_seed_behavior.py b/tests/test_seed_behavior.py index e9138b9e8eb6..adc8a1a4bf08 100644 --- a/tests/test_seed_behavior.py +++ b/tests/test_seed_behavior.py @@ -1,25 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import random - -import numpy as np -import torch - -from vllm.platforms.interface import Platform - - -def test_seed_behavior(): - # Test with a specific seed - Platform.seed_everything(42) - random_value_1 = random.randint(0, 100) - np_random_value_1 = np.random.randint(0, 100) - torch_random_value_1 = torch.randint(0, 100, (1, )).item() - - Platform.seed_everything(42) - random_value_2 = random.randint(0, 100) - np_random_value_2 = np.random.randint(0, 100) - torch_random_value_2 = torch.randint(0, 100, (1, )).item() - - assert random_value_1 == random_value_2 - assert np_random_value_1 == np_random_value_2 - assert torch_random_value_1 == torch_random_value_2 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + +import numpy as np +import torch + +from vllm.platforms.interface import Platform + + +def test_seed_behavior(): + # Test with a specific seed + Platform.seed_everything(42) + random_value_1 = random.randint(0, 100) + np_random_value_1 = np.random.randint(0, 100) + torch_random_value_1 = torch.randint(0, 100, (1,)).item() + + Platform.seed_everything(42) + random_value_2 = random.randint(0, 100) + np_random_value_2 = np.random.randint(0, 100) + 
torch_random_value_2 = torch.randint(0, 100, (1,)).item() + + assert random_value_1 == random_value_2 + assert np_random_value_1 == np_random_value_2 + assert torch_random_value_1 == torch_random_value_2 diff --git a/tests/test_sequence.py b/tests/test_sequence.py index da9826ff0505..27af05bec22d 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -7,7 +7,6 @@ def test_sequence_intermediate_tensors_equal(): - class AnotherIntermediateTensors(IntermediateTensors): pass @@ -20,22 +19,31 @@ class AnotherIntermediateTensors(IntermediateTensors): assert empty_intermediate_tensors_1 == empty_intermediate_tensors_2 different_key_intermediate_tensors_1 = IntermediateTensors( - {"1": torch.zeros([2, 4], dtype=torch.int32)}) + {"1": torch.zeros([2, 4], dtype=torch.int32)} + ) difference_key_intermediate_tensors_2 = IntermediateTensors( - {"2": torch.zeros([2, 4], dtype=torch.int32)}) - assert (different_key_intermediate_tensors_1 - != difference_key_intermediate_tensors_2) + {"2": torch.zeros([2, 4], dtype=torch.int32)} + ) + assert different_key_intermediate_tensors_1 != difference_key_intermediate_tensors_2 same_key_different_value_intermediate_tensors_1 = IntermediateTensors( - {"1": torch.zeros([2, 4], dtype=torch.int32)}) + {"1": torch.zeros([2, 4], dtype=torch.int32)} + ) same_key_different_value_intermediate_tensors_2 = IntermediateTensors( - {"1": torch.zeros([2, 5], dtype=torch.int32)}) - assert (same_key_different_value_intermediate_tensors_1 - != same_key_different_value_intermediate_tensors_2) + {"1": torch.zeros([2, 5], dtype=torch.int32)} + ) + assert ( + same_key_different_value_intermediate_tensors_1 + != same_key_different_value_intermediate_tensors_2 + ) same_key_same_value_intermediate_tensors_1 = IntermediateTensors( - {"1": torch.zeros([2, 4], dtype=torch.int32)}) + {"1": torch.zeros([2, 4], dtype=torch.int32)} + ) same_key_same_value_intermediate_tensors_2 = IntermediateTensors( - {"1": torch.zeros([2, 4], dtype=torch.int32)}) - assert (same_key_same_value_intermediate_tensors_1 == - same_key_same_value_intermediate_tensors_2) + {"1": torch.zeros([2, 4], dtype=torch.int32)} + ) + assert ( + same_key_same_value_intermediate_tensors_1 + == same_key_same_value_intermediate_tensors_2 + ) diff --git a/tests/test_triton_utils.py b/tests/test_triton_utils.py index ebb69e627e95..7fe0a5d9c517 100644 --- a/tests/test_triton_utils.py +++ b/tests/test_triton_utils.py @@ -5,8 +5,7 @@ import types from unittest import mock -from vllm.triton_utils.importing import (TritonLanguagePlaceholder, - TritonPlaceholder) +from vllm.triton_utils.importing import TritonLanguagePlaceholder, TritonPlaceholder def test_triton_placeholder_is_module(): @@ -52,8 +51,7 @@ def foo(x): def bar(x): return x - @triton.heuristics( - {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64}) + @triton.heuristics({"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64}) def baz(x): return x @@ -89,6 +87,7 @@ def test_no_triton_fallback(): # mock triton not being installed with mock.patch.dict(sys.modules, {"triton": None}): from vllm.triton_utils import HAS_TRITON, tl, triton + assert HAS_TRITON is False assert triton.__class__.__name__ == "TritonPlaceholder" assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder" diff --git a/tests/test_version.py b/tests/test_version.py index fd07abb59b1f..928f742f1de8 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -31,7 +31,8 @@ def test_version_tuple(): ((1, 0, 0), "1.-1", True), ((1, 0, 0), "0.9", False), ((1, 0, 0), 
"0.17", False), - ]) + ], +) def test_prev_minor_version_was(version_tuple, version_str, expected): with patch("vllm.version.__version_tuple__", version_tuple): assert version._prev_minor_version_was(version_str) == expected diff --git a/tests/test_vllm_port.py b/tests/test_vllm_port.py index 88e1efd8fdbb..68bd511635dc 100644 --- a/tests/test_vllm_port.py +++ b/tests/test_vllm_port.py @@ -23,14 +23,17 @@ def test_get_vllm_port_valid(): def test_get_vllm_port_invalid(): """Test when VLLM_PORT is set to a non-integer value.""" - with (patch.dict(os.environ, {"VLLM_PORT": "abc"}, clear=True), - pytest.raises(ValueError, match="must be a valid integer")): + with ( + patch.dict(os.environ, {"VLLM_PORT": "abc"}, clear=True), + pytest.raises(ValueError, match="must be a valid integer"), + ): get_vllm_port() def test_get_vllm_port_uri(): """Test when VLLM_PORT is set to a URI.""" - with (patch.dict(os.environ, {"VLLM_PORT": "tcp://localhost:5678"}, - clear=True), - pytest.raises(ValueError, match="appears to be a URI")): + with ( + patch.dict(os.environ, {"VLLM_PORT": "tcp://localhost:5678"}, clear=True), + pytest.raises(ValueError, match="appears to be a URI"), + ): get_vllm_port() diff --git a/tests/tokenization/test_cached_tokenizer.py b/tests/tokenization/test_cached_tokenizer.py index 07217611ea4d..074039f9e513 100644 --- a/tests/tokenization/test_cached_tokenizer.py +++ b/tests/tokenization/test_cached_tokenizer.py @@ -6,17 +6,16 @@ import pytest from transformers import AutoTokenizer -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - get_cached_tokenizer) +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_cached_tokenizer @pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"]) def test_cached_tokenizer(model_id: str): - reference_tokenizer = AutoTokenizer.from_pretrained(model_id, - trust_remote_code=True) + reference_tokenizer = AutoTokenizer.from_pretrained( + model_id, trust_remote_code=True + ) reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"}) - reference_tokenizer.add_special_tokens( - {"additional_special_tokens": ["<SEP>"]}) + reference_tokenizer.add_special_tokens({"additional_special_tokens": ["<SEP>"]}) cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer)) _check_consistency(cached_tokenizer, reference_tokenizer) @@ -32,13 +31,13 @@ def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer): # Cached attributes assert target.all_special_ids == expected.all_special_ids assert target.all_special_tokens == expected.all_special_tokens - assert (target.all_special_tokens_extended == - expected.all_special_tokens_extended) + assert target.all_special_tokens_extended == expected.all_special_tokens_extended assert target.get_vocab() == expected.get_vocab() assert len(target) == len(expected) # Other attributes - assert getattr(target, "padding_side", - None) == getattr(expected, "padding_side", None) + assert getattr(target, "padding_side", None) == getattr( + expected, "padding_side", None + ) assert target.encode("prompt") == expected.encode("prompt") diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index fe6c313d2966..14dcab7707d4 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -5,15 +5,16 @@ from typing import Any, Optional import pytest -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from 
vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, - IncrementalDetokenizer, - SlowIncrementalDetokenizer) +from vllm.v1.engine.detokenizer import ( + FastIncrementalDetokenizer, + IncrementalDetokenizer, + SlowIncrementalDetokenizer, +) SPECIAL_TOKS_TRUTH = [ "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>", # noqa @@ -45,33 +46,35 @@ ] -def _run_incremental_decode(tokenizer, - all_input_ids, - skip_special_tokens: bool, - starting_index: int, - spaces_between_special_tokens: bool = True, - fast: Optional[bool] = None): - +def _run_incremental_decode( + tokenizer, + all_input_ids, + skip_special_tokens: bool, + starting_index: int, + spaces_between_special_tokens: bool = True, + fast: Optional[bool] = None, +): prompt_token_ids = all_input_ids[:starting_index] params = SamplingParams( skip_special_tokens=skip_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens, ) - request = EngineCoreRequest(request_id="", - prompt_token_ids=prompt_token_ids, - mm_features=None, - sampling_params=params, - pooling_params=None, - eos_token_id=None, - arrival_time=0.0, - lora_request=None, - cache_salt=None, - data_parallel_rank=None) + request = EngineCoreRequest( + request_id="", + prompt_token_ids=prompt_token_ids, + mm_features=None, + sampling_params=params, + pooling_params=None, + eos_token_id=None, + arrival_time=0.0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + ) if fast is None: - detokenizer = IncrementalDetokenizer.from_new_request( - tokenizer, request) + detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request) elif fast: detokenizer = FastIncrementalDetokenizer(tokenizer, request) else: @@ -88,9 +91,11 @@ def _run_incremental_decode(tokenizer, @pytest.fixture def tokenizer(tokenizer_name): - return (MistralTokenizer.from_pretrained(tokenizer_name) - if "mistral" in tokenizer_name else - AutoTokenizer.from_pretrained(tokenizer_name)) + return ( + MistralTokenizer.from_pretrained(tokenizer_name) + if "mistral" in tokenizer_name + else AutoTokenizer.from_pretrained(tokenizer_name) + ) @pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"]) @@ -102,7 +107,8 @@ def tokenizer(tokenizer_name): "ပုံပြင်လေးပြောပြပါ", # Using "URGENCY" since "CY" has token id 130282 "URGENCY🌶️", - ]) + ], +) def test_mistral_edge_case(tokenizer, truth): """Test for a specific edge cases with V3-Tekken MistralTokenizer. 
@@ -115,7 +121,8 @@ def test_mistral_edge_case(tokenizer, truth): tokenizer, all_input_ids, skip_special_tokens=True, - starting_index=starting_index) + starting_index=starting_index, + ) assert decoded_text == truth assert out_ids == all_input_ids[starting_index:] @@ -124,8 +131,10 @@ def test_mistral_edge_case(tokenizer, truth): def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: if "mistral" in tokenizer_name: yield ( - True if request.param else - pytest.skip("mistral doesn't support skip_special_tokens=False")) + True + if request.param + else pytest.skip("mistral doesn't support skip_special_tokens=False") + ) else: yield bool(request.param) @@ -136,8 +145,14 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True) @pytest.mark.parametrize("spaces_between_special_tokens", (True, False)) @pytest.mark.parametrize("fast", (True, False)) -def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, - spaces_between_special_tokens, fast): +def test_decode_streaming( + tokenizer, + truth, + with_prompt, + skip_special_tokens, + spaces_between_special_tokens, + fast, +): if fast and not isinstance(tokenizer, PreTrainedTokenizerFast): pytest.skip() @@ -146,30 +161,35 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, if not fast and isinstance(tokenizer, PreTrainedTokenizerFast): # Fix up inconsistency in fast/slow tokenizer behaviour. - tokenizer.add_special_tokens({ - "additional_special_tokens": [ - at for at in - tokenizer._tokenizer.get_added_tokens_decoder().values() - if at.special - ] - }) - - extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \ + tokenizer.add_special_tokens( + { + "additional_special_tokens": [ + at + for at in tokenizer._tokenizer.get_added_tokens_decoder().values() + if at.special + ] + } + ) + + extra_decode_args = ( + {} + if not isinstance(tokenizer, PreTrainedTokenizer) else {"spaces_between_special_tokens": spaces_between_special_tokens} + ) truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids if tokenizer.bos_token_id is not None: truth_tokens.insert(0, tokenizer.bos_token_id) truth_tokens.append(tokenizer.eos_token_id) - new_truth = tokenizer.decode(truth_tokens, - skip_special_tokens=skip_special_tokens, - **extra_decode_args) + new_truth = tokenizer.decode( + truth_tokens, skip_special_tokens=skip_special_tokens, **extra_decode_args + ) if with_prompt: num_prompt_tokens = len( - tokenizer(truth[:len(truth) // 2], - add_special_tokens=False).input_ids) + tokenizer(truth[: len(truth) // 2], add_special_tokens=False).input_ids + ) if tokenizer.bos_token_id is not None: num_prompt_tokens += 1 @@ -177,11 +197,13 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, generated_input_ids = truth_tokens[num_prompt_tokens:] all_input_ids = prompt_input_ids + generated_input_ids starting_index = len(prompt_input_ids) - prompt = tokenizer.decode(prompt_input_ids, - skip_special_tokens=skip_special_tokens, - **extra_decode_args) + prompt = tokenizer.decode( + prompt_input_ids, + skip_special_tokens=skip_special_tokens, + **extra_decode_args, + ) - generated = new_truth[len(prompt):] + generated = new_truth[len(prompt) :] else: generated = new_truth starting_index = 0 @@ -193,7 +215,8 @@ def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens, skip_special_tokens=skip_special_tokens, 
starting_index=starting_index, spaces_between_special_tokens=spaces_between_special_tokens, - fast=fast) + fast=fast, + ) assert decoded_text == generated assert out_ids == all_input_ids[starting_index:] @@ -206,11 +229,13 @@ def test_oov_decode(tokenizer, fast): pytest.skip() decoded_text, out_ids = _run_incremental_decode( - tokenizer, [len(tokenizer)], + tokenizer, + [len(tokenizer)], skip_special_tokens=True, starting_index=0, spaces_between_special_tokens=True, - fast=fast) + fast=fast, + ) - assert decoded_text == '' + assert decoded_text == "" assert out_ids == [len(tokenizer)] diff --git a/tests/tokenization/test_do_lower_case.py b/tests/tokenization/test_do_lower_case.py index 7aa655e1c3b4..8aff50b351e3 100644 --- a/tests/tokenization/test_do_lower_case.py +++ b/tests/tokenization/test_do_lower_case.py @@ -13,6 +13,6 @@ def test_special_tokens(tokenizer_name: str, n_tokens: int): tokenizer = get_tokenizer(tokenizer_name, revision="main") - prompts = '[UNK]' * n_tokens + prompts = "[UNK]" * n_tokens prompt_token_ids = tokenizer.encode(prompts) assert len(prompt_token_ids) == n_tokens + 2 diff --git a/tests/tokenization/test_get_eos.py b/tests/tokenization/test_get_eos.py index d8288429351c..921d77b1b335 100644 --- a/tests/tokenization/test_get_eos.py +++ b/tests/tokenization/test_get_eos.py @@ -5,6 +5,7 @@ only get the `eos_token_id` from the tokenizer as defined by {meth}`vllm.LLMEngine._get_eos_token_id`. """ + from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.tokenizer import get_tokenizer @@ -15,8 +16,7 @@ def test_get_llama3_eos_token(): tokenizer = get_tokenizer(model_name) assert tokenizer.eos_token_id == 128009 - generation_config = try_get_generation_config(model_name, - trust_remote_code=False) + generation_config = try_get_generation_config(model_name, trust_remote_code=False) assert generation_config is not None assert generation_config.eos_token_id == [128001, 128008, 128009] @@ -27,7 +27,6 @@ def test_get_blip2_eos_token(): tokenizer = get_tokenizer(model_name) assert tokenizer.eos_token_id == 2 - generation_config = try_get_generation_config(model_name, - trust_remote_code=False) + generation_config = try_get_generation_config(model_name, trust_remote_code=False) assert generation_config is not None assert generation_config.eos_token_id == 50118 diff --git a/tests/tokenization/test_mistral_tokenizer.py b/tests/tokenization/test_mistral_tokenizer.py index 69b3c6294284..a034188387d0 100644 --- a/tests/tokenization/test_mistral_tokenizer.py +++ b/tests/tokenization/test_mistral_tokenizer.py @@ -2,187 +2,206 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest -from mistral_common.protocol.instruct.messages import (AssistantMessage, - ToolMessage, - UserMessage) +from mistral_common.protocol.instruct.messages import ( + AssistantMessage, + ToolMessage, + UserMessage, +) from mistral_common.protocol.instruct.request import ChatCompletionRequest -from mistral_common.protocol.instruct.tool_calls import (Function, - FunctionCall, Tool, - ToolCall) +from mistral_common.protocol.instruct.tool_calls import ( + Function, + FunctionCall, + Tool, + ToolCall, +) from vllm.transformers_utils.tokenizers.mistral import ( - make_mistral_chat_completion_request) + make_mistral_chat_completion_request, +) @pytest.mark.parametrize( "openai_request,expected_mistral_request", - [( - { - "messages": [{ - "role": "user", - "content": "What is the current local date and time?", - }], - "tools": [{ - "type": 
"function", - "function": { - "description": "Fetch the current local date and time.", - "name": "get_current_time", - }, - }], - }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What is the current local date and time?") - ], - tools=[ - Tool( - type="function", - function=Function( - name="get_current_time", - description="Fetch the current local date and time.", - parameters={}, - ), - ) - ], + [ + ( + { + "messages": [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + "tools": [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + }, + } + ], + }, + ChatCompletionRequest( + messages=[ + UserMessage(content="What is the current local date and time?") + ], + tools=[ + Tool( + type="function", + function=Function( + name="get_current_time", + description="Fetch the current local date and time.", + parameters={}, + ), + ) + ], + ), ), - ), - ( - { - "messages": - [{ - "role": "user", - "content": "What is the current local date and time?", - }], - "tools": [{ - "type": "function", - "function": { - "description": "Fetch the current local date and time.", - "name": "get_current_time", - "parameters": None, - }, - }], - }, - ChatCompletionRequest( - messages=[ - UserMessage( - content="What is the current local date and time?") - ], - tools=[ - Tool( - type="function", - function=Function( - name="get_current_time", - description="Fetch the current local date and time.", - parameters={}, - ), - ) - ], - ), - )], + ( + { + "messages": [ + { + "role": "user", + "content": "What is the current local date and time?", + } + ], + "tools": [ + { + "type": "function", + "function": { + "description": "Fetch the current local date and time.", + "name": "get_current_time", + "parameters": None, + }, + } + ], + }, + ChatCompletionRequest( + messages=[ + UserMessage(content="What is the current local date and time?") + ], + tools=[ + Tool( + type="function", + function=Function( + name="get_current_time", + description="Fetch the current local date and time.", + parameters={}, + ), + ) + ], + ), + ), + ], ) -def test_make_mistral_chat_completion_request(openai_request, - expected_mistral_request): +def test_make_mistral_chat_completion_request(openai_request, expected_mistral_request): actual_request = make_mistral_chat_completion_request( - openai_request["messages"], openai_request["tools"]) + openai_request["messages"], openai_request["tools"] + ) assert actual_request == expected_mistral_request # Tool use with list content and reasoning_content -@pytest.mark.parametrize("openai_request,expected_mistral_request", [( - { - "messages": [ - { - "role": "user", - "content": "What's the weather in Paris?", - }, +@pytest.mark.parametrize( + "openai_request,expected_mistral_request", + [ + ( { - "role": - "assistant", - "reasoning_content": - None, - "content": - None, - "tool_calls": [{ - "id": "call123", - "type": "function", - "function": { - "name": "get_weather", - "arguments": '{"city": "Paris"}', + "messages": [ + { + "role": "user", + "content": "What's the weather in Paris?", }, - }], - }, - { - "role": "tool", - "content": [{ - "type": "text", - "text": "Rainy" - }], - "name": "get_weather", - "tool_call_id": "call123", - }, - ], - "tools": [{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Gets the current weather in a city.", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": 
"The city name" - } + { + "role": "assistant", + "reasoning_content": None, + "content": None, + "tool_calls": [ + { + "id": "call123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"city": "Paris"}', + }, + } + ], + }, + { + "role": "tool", + "content": [{"type": "text", "text": "Rainy"}], + "name": "get_weather", + "tool_call_id": "call123", }, - "required": ["city"], - }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Gets the current weather in a city.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, + }, + } + ], }, - }], - }, - ChatCompletionRequest( - messages=[ - UserMessage(content="What's the weather in Paris?"), - AssistantMessage( - content=None, - tool_calls=[ - ToolCall( - id="call123", - function=FunctionCall( + ChatCompletionRequest( + messages=[ + UserMessage(content="What's the weather in Paris?"), + AssistantMessage( + content=None, + tool_calls=[ + ToolCall( + id="call123", + function=FunctionCall( + name="get_weather", + arguments='{"city": "Paris"}', + ), + ) + ], + ), + ToolMessage( + content="Rainy", + tool_call_id="call123", + name="get_weather", + ), + ], + tools=[ + Tool( + type="function", + function=Function( name="get_weather", - arguments='{"city": "Paris"}', + description="Gets the current weather in a city.", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name", + } + }, + "required": ["city"], + }, ), ) ], ), - ToolMessage( - content="Rainy", - tool_call_id="call123", - name="get_weather", - ), - ], - tools=[ - Tool( - type="function", - function=Function( - name="get_weather", - description="Gets the current weather in a city.", - parameters={ - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - } - }, - "required": ["city"], - }, - ), - ) - ], - ), -)]) + ) + ], +) def test_make_mistral_chat_completion_request_list_content( - openai_request, expected_mistral_request): + openai_request, expected_mistral_request +): actual_request = make_mistral_chat_completion_request( - openai_request["messages"], openai_request["tools"]) + openai_request["messages"], openai_request["tools"] + ) assert actual_request == expected_mistral_request diff --git a/tests/tokenization/test_tokenizer.py b/tests/tokenization/test_tokenizer.py index 09a3638fd2ed..e86bb03883b5 100644 --- a/tests/tokenization/test_tokenizer.py +++ b/tests/tokenization/test_tokenizer.py @@ -19,5 +19,5 @@ def test_tokenizer_revision(tokenizer_name: str): assert isinstance(tokenizer, PreTrainedTokenizerBase) # Assume that "never" branch always does not exist - with pytest.raises(OSError, match='not a valid git identifier'): + with pytest.raises(OSError, match="not a valid git identifier"): get_tokenizer(tokenizer_name, revision="never") diff --git a/tests/tokenization/test_tokenizer_registry.py b/tests/tokenization/test_tokenizer_registry.py index 68d4b416b4c9..de67c3e798c4 100644 --- a/tests/tokenization/test_tokenizer_registry.py +++ b/tests/tokenization/test_tokenizer_registry.py @@ -4,15 +4,13 @@ from typing import TYPE_CHECKING, Any, Optional, Union from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.transformers_utils.tokenizer_base import (TokenizerBase, - TokenizerRegistry) +from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry if 
TYPE_CHECKING: from vllm.entrypoints.chat_utils import ChatCompletionMessageParam class TestTokenizer(TokenizerBase): - @classmethod def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer": return TestTokenizer() @@ -85,23 +83,23 @@ def encode_one( ) -> list[int]: raise NotImplementedError() - def encode(self, - text: str, - add_special_tokens: Optional[bool] = None) -> list[int]: + def encode(self, text: str, add_special_tokens: Optional[bool] = None) -> list[int]: raise NotImplementedError() - def apply_chat_template(self, - messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, - **kwargs) -> list[int]: + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs, + ) -> list[int]: raise NotImplementedError() def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError() - def decode(self, - ids: Union[list[int], int], - skip_special_tokens: bool = True) -> str: + def decode( + self, ids: Union[list[int], int], skip_special_tokens: bool = True + ) -> str: raise NotImplementedError() def convert_ids_to_tokens( @@ -113,9 +111,9 @@ def convert_ids_to_tokens( def test_customized_tokenizer(): - TokenizerRegistry.register("test_tokenizer", - "tests.tokenization.test_tokenizer_registry", - "TestTokenizer") + TokenizerRegistry.register( + "test_tokenizer", "tests.tokenization.test_tokenizer_registry", "TestTokenizer" + ) tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer") assert isinstance(tokenizer, TestTokenizer) diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py index 510b54790cd9..ff9cdeeb7375 100644 --- a/tests/tool_use/conftest.py +++ b/tests/tool_use/conftest.py @@ -13,13 +13,13 @@ # select models to test based on command line arguments def pytest_addoption(parser): - parser.addoption("--models", - nargs="+", - help="Specify one or more models to test") - parser.addoption("--extended", - action="store_true", - default=False, - help="invoke extended tests requiring large GPUs") + parser.addoption("--models", nargs="+", help="Specify one or more models to test") + parser.addoption( + "--extended", + action="store_true", + default=False, + help="invoke extended tests requiring large GPUs", + ) # for each server config, download the model and return the config @@ -29,8 +29,10 @@ def server_config(request): models = request.config.getoption("--models") config_keys_to_test = [ - key for key in CONFIGS if (models is None or key in models) and ( - extended or not CONFIGS[key].get("extended", False)) + key + for key in CONFIGS + if (models is None or key in models) + and (extended or not CONFIGS[key].get("extended", False)) ] config_key = request.param @@ -40,8 +42,9 @@ def server_config(request): config = CONFIGS[config_key] if current_platform.is_rocm() and not config.get("supports_rocm", True): - pytest.skip("The {} model can't be tested on the ROCm platform".format( - config["model"])) + pytest.skip( + "The {} model can't be tested on the ROCm platform".format(config["model"]) + ) # download model and tokenizer using transformers snapshot_download(config["model"]) @@ -53,8 +56,9 @@ def server_config(request): def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] - with RemoteOpenAIServer(model, ARGS + args_for_model, - max_wait_seconds=480) as server: + with RemoteOpenAIServer( + model, ARGS + args_for_model, max_wait_seconds=480 + ) as server: 
yield server diff --git a/tests/tool_use/mistral/conftest.py b/tests/tool_use/mistral/conftest.py index e9dddccdc8c0..9b0a6eb27fca 100644 --- a/tests/tool_use/mistral/conftest.py +++ b/tests/tool_use/mistral/conftest.py @@ -17,8 +17,9 @@ def server_config(request): config = CONFIGS[request.param] if current_platform.is_rocm() and not config.get("supports_rocm", True): - pytest.skip("The {} model can't be tested on the ROCm platform".format( - config["model"])) + pytest.skip( + "The {} model can't be tested on the ROCm platform".format(config["model"]) + ) # download model and tokenizer using transformers snapshot_download(config["model"]) @@ -30,8 +31,9 @@ def server_config(request): def server(request, server_config: ServerConfig): model = server_config["model"] args_for_model = server_config["arguments"] - with RemoteOpenAIServer(model, ARGS + args_for_model, - max_wait_seconds=480) as server: + with RemoteOpenAIServer( + model, ARGS + args_for_model, max_wait_seconds=480 + ) as server: yield server diff --git a/tests/tool_use/mistral/test_mistral_tool_calls.py b/tests/tool_use/mistral/test_mistral_tool_calls.py index 9bf6863f3f2b..3c4a543abe41 100644 --- a/tests/tool_use/mistral/test_mistral_tool_calls.py +++ b/tests/tool_use/mistral/test_mistral_tool_calls.py @@ -19,12 +19,12 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL], tool_choice=WEATHER_TOOL, - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.finish_reason != "tool_calls" # "stop" or "length" assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 1 + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1 assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral diff --git a/tests/tool_use/mistral/utils.py b/tests/tool_use/mistral/utils.py index 7a026cd9bb61..13a234f8e26b 100644 --- a/tests/tool_use/mistral/utils.py +++ b/tests/tool_use/mistral/utils.py @@ -18,17 +18,16 @@ class ServerConfig(TypedDict, total=False): CONFIGS: dict[str, ServerConfig] = { "mistral": { - "model": - "mistralai/Mistral-7B-Instruct-v0.3", + "model": "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ - "--tokenizer-mode", "mistral", - "--ignore-patterns=\"consolidated.safetensors\"" + "--tokenizer-mode", + "mistral", + '--ignore-patterns="consolidated.safetensors"', ], - "system_prompt": - "You are a helpful assistant with access to tools. If a tool" + "system_prompt": "You are a helpful assistant with access to tools. If a tool" " that you have would be helpful to answer a user query, " "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." 
+ "to the user's question - just respond to it normally.", }, } diff --git a/tests/tool_use/test_chat_completion_request_validations.py b/tests/tool_use/test_chat_completion_request_validations.py index a30c58b09fe8..50cd9e4279b2 100644 --- a/tests/tool_use/test_chat_completion_request_validations.py +++ b/tests/tool_use/test_chat_completion_request_validations.py @@ -8,68 +8,56 @@ def test_chat_completion_request_with_no_tools(): # tools key is not present - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - }) - assert request.tool_choice == 'none' + request = ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + } + ) + assert request.tool_choice == "none" # tools key is None - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - 'tools': - None - }) - assert request.tool_choice == 'none' + request = ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + "tools": None, + } + ) + assert request.tool_choice == "none" # tools key present but empty - request = ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - 'tools': [] - }) - assert request.tool_choice == 'none' + request = ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + "tools": [], + } + ) + assert request.tool_choice == "none" -@pytest.mark.parametrize('tool_choice', ['auto', 'required']) +@pytest.mark.parametrize("tool_choice", ["auto", "required"]) def test_chat_completion_request_with_tool_choice_but_no_tools(tool_choice): - with pytest.raises(ValueError, - match="When using `tool_choice`, `tools` must be set."): - ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - 'tool_choice': - tool_choice - }) - - with pytest.raises(ValueError, - match="When using `tool_choice`, `tools` must be set."): - ChatCompletionRequest.model_validate({ - 'messages': [{ - 'role': 'user', - 'content': 'Hello' - }], - 'model': - 'facebook/opt-125m', - 'tool_choice': - tool_choice, - 'tools': - None - }) + with pytest.raises( + ValueError, match="When using `tool_choice`, `tools` must be set." + ): + ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + "tool_choice": tool_choice, + } + ) + + with pytest.raises( + ValueError, match="When using `tool_choice`, `tools` must be set." 
+ ): + ChatCompletionRequest.model_validate( + { + "messages": [{"role": "user", "content": "Hello"}], + "model": "facebook/opt-125m", + "tool_choice": tool_choice, + "tools": None, + } + ) diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py index 8c01c86e29f2..425d3879985e 100644 --- a/tests/tool_use/test_chat_completions.py +++ b/tests/tool_use/test_chat_completions.py @@ -4,16 +4,21 @@ import openai import pytest -from .utils import (MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL, ServerConfig, - ensure_system_prompt) +from .utils import ( + MESSAGES_WITHOUT_TOOLS, + WEATHER_TOOL, + ServerConfig, + ensure_system_prompt, +) # test: make sure chat completions without tools provided work even when tools # are enabled. This makes sure tool call chat templates work, AND that the tool # parser stream processing doesn't change the output of the model. @pytest.mark.asyncio -async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, - server_config: ServerConfig): +async def test_chat_completion_without_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +): models = await client.models.list() model_name: str = models.data[0].id chat_completion = await client.chat.completions.create( @@ -21,7 +26,8 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, temperature=0, max_completion_tokens=150, model=model_name, - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason output_text = chat_completion.choices[0].message.content @@ -32,8 +38,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, assert stop_reason != "tool_calls" # check to make sure no tool calls were returned - assert (choice.message.tool_calls is None - or len(choice.message.tool_calls) == 0) + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 # make the same request, streaming stream = await client.chat.completions.create( @@ -55,7 +60,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, # make sure the role is assistant if delta.role: assert not role_sent - assert delta.role == 'assistant' + assert delta.role == "assistant" role_sent = True if delta.content: @@ -80,8 +85,9 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI, # tools, to make sure we can still get normal chat completion responses # and that they won't be parsed as tools @pytest.mark.asyncio -async def test_chat_completion_with_tools(client: openai.AsyncOpenAI, - server_config: ServerConfig): +async def test_chat_completion_with_tools( + client: openai.AsyncOpenAI, server_config: ServerConfig +): models = await client.models.list() model_name: str = models.data[0].id chat_completion = await client.chat.completions.create( @@ -90,19 +96,19 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI, max_completion_tokens=150, model=model_name, tools=[WEATHER_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason output_text = chat_completion.choices[0].message.content # check to make sure we got text assert output_text is not None - assert stop_reason != 'tool_calls' + assert stop_reason != "tool_calls" assert len(output_text) > 0 # check to make sure no tool calls were returned - assert (choice.message.tool_calls is None - or len(choice.message.tool_calls) == 0) + assert choice.message.tool_calls is None or 
len(choice.message.tool_calls) == 0 # make the same request, streaming stream = await client.chat.completions.create( @@ -125,7 +131,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI, # make sure the role is assistant if delta.role: - assert delta.role == 'assistant' + assert delta.role == "assistant" role_sent = True if delta.content: @@ -142,6 +148,6 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI, assert role_sent assert finish_reason_count == 1 assert chunk.choices[0].finish_reason == stop_reason - assert chunk.choices[0].finish_reason != 'tool_calls' + assert chunk.choices[0].finish_reason != "tool_calls" assert len(chunks) assert "".join(chunks) == output_text diff --git a/tests/tool_use/test_deepseekv31_tool_parser.py b/tests/tool_use/test_deepseekv31_tool_parser.py index 5f6b266d3aa1..9b7e71b49c05 100644 --- a/tests/tool_use/test_deepseekv31_tool_parser.py +++ b/tests/tool_use/test_deepseekv31_tool_parser.py @@ -21,23 +21,28 @@ def parser(deepseekv31_tokenizer): def test_extract_tool_calls_with_tool(parser): model_output = ( - "normal text" + "<|tool▁calls▁begin|>" + - "<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>" + - "<|tool▁calls▁end|>") + "normal text" + + "<|tool▁calls▁begin|>" + + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>' + + "<|tool▁calls▁end|>" + ) result = parser.extract_tool_calls(model_output, None) assert result.tools_called assert len(result.tool_calls) == 1 assert result.tool_calls[0].function.name == "foo" - assert result.tool_calls[0].function.arguments == "{\"x\":1}" + assert result.tool_calls[0].function.arguments == '{"x":1}' assert result.content == "normal text" def test_extract_tool_calls_with_multiple_tools(parser): model_output = ( - "some prefix text" + "<|tool▁calls▁begin|>" + - "<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>" + - "<|tool▁call▁begin|>bar<|tool▁sep|>{\"y\":2}<|tool▁call▁end|>" + - "<|tool▁calls▁end|>" + " some suffix text") + "some prefix text" + + "<|tool▁calls▁begin|>" + + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>' + + '<|tool▁call▁begin|>bar<|tool▁sep|>{"y":2}<|tool▁call▁end|>' + + "<|tool▁calls▁end|>" + + " some suffix text" + ) result = parser.extract_tool_calls(model_output, None) @@ -45,10 +50,10 @@ def test_extract_tool_calls_with_multiple_tools(parser): assert len(result.tool_calls) == 2 assert result.tool_calls[0].function.name == "foo" - assert result.tool_calls[0].function.arguments == "{\"x\":1}" + assert result.tool_calls[0].function.arguments == '{"x":1}' assert result.tool_calls[1].function.name == "bar" - assert result.tool_calls[1].function.arguments == "{\"y\":2}" + assert result.tool_calls[1].function.arguments == '{"y":2}' # prefix is content assert result.content == "some prefix text" diff --git a/tests/tool_use/test_glm4_moe_tool_parser.py b/tests/tool_use/test_glm4_moe_tool_parser.py index bb8c36fb13ad..6f1f6671d9b3 100644 --- a/tests/tool_use/test_glm4_moe_tool_parser.py +++ b/tests/tool_use/test_glm4_moe_tool_parser.py @@ -27,12 +27,14 @@ def glm4_moe_tool_parser(glm4_moe_tokenizer): return Glm4MoeModelToolParser(glm4_moe_tokenizer) -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, 
expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert isinstance(actual_tool_call.id, str) assert len(actual_tool_call.id) > 0 @@ -47,7 +49,8 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): model_output = "This is a test" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -73,14 +76,18 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): <arg_value>fahrenheit</arg_value> </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], None, ), @@ -102,22 +109,30 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): <arg_value>fahrenheit</arg_value> </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )), - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit", - }), - )), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + } + ), + ) + ), ], None, ), @@ -131,14 +146,18 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): <arg_value>celsius</arg_value> </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Seattle", - "state": "WA", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Seattle", + "state": "WA", + "unit": "celsius", + } + ), + ) + ) ], "I'll help you check the weather.", ), @@ -152,37 +171,51 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): <arg_value>celsius</arg_value> </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "New York", - "state": "NY", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "New York", + "state": "NY", + "unit": "celsius", + } + ), + ) + ) ], None, ), - ("""I will help you get the weather.<tool_call>get_weather + ( + """I will help you get the weather.<tool_call>get_weather <arg_key>city</arg_key> <arg_value>Beijing</arg_value> <arg_key>date</arg_key> <arg_value>2025-08-01</arg_value> - </tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "city": "Beijing", - "date": "2025-08-01", - }), - )) - ], "I will help you get the weather."), + </tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "city": "Beijing", + "date": "2025-08-01", + } + ), + ) + ) + ], + "I will help you 
get the weather.", + ), ], ) -def test_extract_tool_calls(glm4_moe_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls( + glm4_moe_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -202,7 +235,8 @@ def test_extract_tool_calls_with_thinking_tags(glm4_moe_tool_parser): </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 @@ -224,7 +258,8 @@ def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser): </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] # Should handle malformed XML gracefully # The parser should either extract what it can or return no tool calls @@ -239,12 +274,12 @@ def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser): </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 - assert extracted_tool_calls.tool_calls[ - 0].function.name == "get_current_time" + assert extracted_tool_calls.tool_calls[0].function.name == "get_current_time" # Empty arguments should result in empty JSON object assert extracted_tool_calls.tool_calls[0].function.arguments == "{}" @@ -270,7 +305,8 @@ def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser): </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 2 @@ -321,8 +357,7 @@ def test_streaming_basic_functionality(glm4_moe_tool_parser): # The result behavior depends on the streaming state # This test mainly ensures no exceptions are thrown - assert result is None or hasattr(result, 'tool_calls') or hasattr( - result, 'content') + assert result is None or hasattr(result, "tool_calls") or hasattr(result, "content") def test_streaming_no_tool_calls(glm4_moe_tool_parser): @@ -341,7 +376,7 @@ def test_streaming_no_tool_calls(glm4_moe_tool_parser): # Should return the delta text as content assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == " without any tool calls." 
@@ -367,7 +402,7 @@ def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser): # Should return content when no tool call tokens are detected assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == "get the weather.<tool_call>" @@ -383,7 +418,8 @@ def test_extract_tool_calls_special_characters(glm4_moe_tool_parser): </tool_call>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 @@ -404,7 +440,8 @@ def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser): <arg_value>2025-08-01</arg_value>""" extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] # Incomplete tool calls should not be extracted assert not extracted_tool_calls.tools_called diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_use/test_jamba_tool_parser.py index 8f819301e264..44d42bbd72b0 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_use/test_jamba_tool_parser.py @@ -9,8 +9,7 @@ import pytest from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, - ToolCall) +from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall from vllm.entrypoints.openai.tool_parsers import JambaToolParser from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer @@ -30,12 +29,14 @@ def jamba_tool_parser(jamba_tokenizer): return JambaToolParser(jamba_tokenizer) -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert isinstance(actual_tool_call.id, str) assert len(actual_tool_call.id) > 16 @@ -44,10 +45,9 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], def stream_delta_message_generator( - jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, - model_output: str) -> Generator[DeltaMessage, None, None]: - all_token_ids = jamba_tokenizer.encode(model_output, - add_special_tokens=False) + jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, model_output: str +) -> Generator[DeltaMessage, None, None]: + all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False) previous_text = "" previous_tokens = None @@ -56,18 +56,19 @@ def stream_delta_message_generator( for i, delta_token in enumerate(all_token_ids): delta_token_ids = [delta_token] previous_token_ids = all_token_ids[:i] - current_token_ids = all_token_ids[:i + 1] - - (new_tokens, delta_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=jamba_tokenizer, - all_input_ids=current_token_ids, - prev_tokens=previous_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=False, - spaces_between_special_tokens=True, - ) + current_token_ids = 
all_token_ids[: i + 1] + + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=jamba_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) current_text = previous_text + delta_text @@ -84,8 +85,9 @@ def stream_delta_message_generator( yield delta_message previous_text = current_text - previous_tokens = previous_tokens + new_tokens if previous_tokens\ - else new_tokens + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) prefix_offset = new_prefix_offset read_offset = new_read_offset @@ -93,7 +95,8 @@ def stream_delta_message_generator( def test_extract_tool_calls_no_tools(jamba_tool_parser): model_output = "This is a test" extracted_tool_calls = jamba_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -108,54 +111,63 @@ def test_extract_tool_calls_no_tools(jamba_tool_parser): argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ ( - ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) ], - None), + None, + ), ( - ''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) ], - " Sure! let me call the tool for you."), + " Sure! 
let me call the tool for you.", + ), ( - ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))), - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Orlando", "state": "FL", "unit": "fahrenheit"} + ), + ) + ), ], - None) + None, + ), ], ) -def test_extract_tool_calls(jamba_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls( + jamba_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = jamba_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -172,63 +184,75 @@ def test_extract_tool_calls(jamba_tool_parser, model_output, ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ - ('''This is a test''', [], '''This is a test'''), + ("""This is a test""", [], """This is a test"""), ( - ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) ], - " "), + " ", + ), ( - ''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) ], - " Sure! let me call the tool for you."), + " Sure! 
let me call the tool for you.", + ), ( - ''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501 + """ <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))), - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit" - }))) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Orlando", "state": "FL", "unit": "fahrenheit"} + ), + ) + ), ], - " ") + " ", + ), ], ) -def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer, - model_output, expected_tool_calls, - expected_content): - other_content: str = '' +def test_extract_tool_calls_streaming( + jamba_tool_parser, + jamba_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + other_content: str = "" function_names: list[str] = [] function_args_strs: list[str] = [] tool_call_idx: int = -1 tool_call_ids: list[Optional[str]] = [] for delta_message in stream_delta_message_generator( - jamba_tool_parser, jamba_tokenizer, model_output): + jamba_tool_parser, jamba_tokenizer, model_output + ): # role should never be streamed from tool parser assert not delta_message.role @@ -264,18 +288,22 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer, # make sure they're a string and then add them to the list assert isinstance(tool_call.function.arguments, str) - function_args_strs[ - tool_call.index] += tool_call.function.arguments + function_args_strs[tool_call.index] += tool_call.function.arguments assert other_content == expected_content actual_tool_calls = [ - ToolCall(id=tool_call_id, - function=FunctionCall( - name=function_name, - arguments=partial_json_parser.ensure_json( - function_args_str, Allow.OBJ | Allow.STR))) + ToolCall( + id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR + ), + ), + ) for tool_call_id, function_name, function_args_str in zip( - tool_call_ids, function_names, function_args_strs) + tool_call_ids, function_names, function_args_strs + ) ] assert_tool_calls(actual_tool_calls, expected_tool_calls) diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_use/test_kimi_k2_tool_parser.py index ad9af6361802..43feae4d865e 100644 --- a/tests/tool_use/test_kimi_k2_tool_parser.py +++ b/tests/tool_use/test_kimi_k2_tool_parser.py @@ -26,27 +26,31 @@ def kimi_k2_tool_parser(kimi_k2_tokenizer): return KimiK2ToolParser(kimi_k2_tokenizer) -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, 
expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): - + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert actual_tool_call.type == "function" assert actual_tool_call.function == expected_tool_call.function # assert tool call id format assert actual_tool_call.id.startswith("functions.") - assert actual_tool_call.id.split(':')[-1].isdigit() - assert actual_tool_call.id.split('.')[1].split( - ':')[0] == expected_tool_call.function.name + assert actual_tool_call.id.split(":")[-1].isdigit() + assert ( + actual_tool_call.id.split(".")[1].split(":")[0] + == expected_tool_call.function.name + ) def test_extract_tool_calls_no_tools(kimi_k2_tool_parser): model_output = "This is a test" extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -63,14 +67,18 @@ def test_extract_tool_calls_no_tools(kimi_k2_tool_parser): """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|> functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""", [ - ToolCall(id='functions.get_weather:0', - function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "city": "Beijing", - }, ), - ), - type='function') + ToolCall( + id="functions.get_weather:0", + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "city": "Beijing", + }, + ), + ), + type="function", + ) ], "I'll help you check the weather. ", ), @@ -79,31 +87,41 @@ def test_extract_tool_calls_no_tools(kimi_k2_tool_parser): functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|> functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""", [ - ToolCall(id='functions.get_weather:0', - function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "city": "Beijing", - }, ), - ), - type='function'), - ToolCall(id='functions.get_weather:1', - function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "city": "Shanghai", - }, ), - ), - type='function') + ToolCall( + id="functions.get_weather:0", + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "city": "Beijing", + }, + ), + ), + type="function", + ), + ToolCall( + id="functions.get_weather:1", + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "city": "Shanghai", + }, + ), + ), + type="function", + ), ], "I'll help you check the weather. 
", ), ], ) -def test_extract_tool_calls(kimi_k2_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls( + kimi_k2_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -118,15 +136,14 @@ def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser): functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""" extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called # Should extract only the valid JSON tool calls assert len(extracted_tool_calls.tool_calls) == 2 - assert extracted_tool_calls.tool_calls[ - 0].function.name == "invalid_get_weather" - assert extracted_tool_calls.tool_calls[ - 1].function.name == "valid_get_weather" + assert extracted_tool_calls.tool_calls[0].function.name == "invalid_get_weather" + assert extracted_tool_calls.tool_calls[1].function.name == "valid_get_weather" def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser): @@ -136,13 +153,13 @@ def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser): functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""" extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called # Should extract only the valid JSON tool calls assert len(extracted_tool_calls.tool_calls) == 1 - assert extracted_tool_calls.tool_calls[ - 0].function.name == "valid_get_weather" + assert extracted_tool_calls.tool_calls[0].function.name == "valid_get_weather" def test_streaming_basic_functionality(kimi_k2_tool_parser): @@ -170,8 +187,7 @@ def test_streaming_basic_functionality(kimi_k2_tool_parser): # The result might be None or contain tool call information # This depends on the internal state management - if result is not None and hasattr(result, - 'tool_calls') and result.tool_calls: + if result is not None and hasattr(result, "tool_calls") and result.tool_calls: assert len(result.tool_calls) >= 0 @@ -191,5 +207,5 @@ def test_streaming_no_tool_calls(kimi_k2_tool_parser): # Should return the delta text as content assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == " without any tool calls." 
diff --git a/tests/tool_use/test_minimax_tool_parser.py b/tests/tool_use/test_minimax_tool_parser.py index 7aa19c9a51c9..8610656fa288 100644 --- a/tests/tool_use/test_minimax_tool_parser.py +++ b/tests/tool_use/test_minimax_tool_parser.py @@ -7,8 +7,11 @@ import pytest -from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionToolsParam, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser from vllm.transformers_utils.tokenizer import get_tokenizer @@ -31,60 +34,48 @@ def minimax_tool_parser(minimax_tokenizer): @pytest.fixture def sample_tools(): return [ - ChatCompletionToolsParam(type="function", - function={ - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - }, - "state": { - "type": "string", - "description": - "The state code" - }, - "unit": { - "type": "string", - "enum": - ["fahrenheit", "celsius"] - } - }, - "required": ["city", "state"] - } - }), - ChatCompletionToolsParam(type="function", - function={ - "name": "calculate_area", - "description": - "Calculate area of a shape", - "parameters": { - "type": "object", - "properties": { - "shape": { - "type": "string" - }, - "dimensions": { - "type": "object" - }, - "precision": { - "type": "integer" - } - } - } - }) + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + "state": {"type": "string", "description": "The state code"}, + "unit": {"type": "string", "enum": ["fahrenheit", "celsius"]}, + }, + "required": ["city", "state"], + }, + }, + ), + ChatCompletionToolsParam( + type="function", + function={ + "name": "calculate_area", + "description": "Calculate area of a shape", + "parameters": { + "type": "object", + "properties": { + "shape": {"type": "string"}, + "dimensions": {"type": "object"}, + "precision": {"type": "integer"}, + }, + }, + }, + ), ] -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert isinstance(actual_tool_call.id, str) assert len(actual_tool_call.id) > 16 @@ -95,7 +86,8 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], def test_extract_tool_calls_no_tools(minimax_tool_parser): model_output = "This is a test" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -116,14 +108,18 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}} </tool_calls>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": 
"Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], None, ), @@ -133,22 +129,30 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}} </tool_calls>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )), - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit", - }), - )), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + } + ), + ) + ), ], None, ), @@ -157,14 +161,18 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): {"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}} </tool_calls>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Seattle", - "state": "WA", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Seattle", + "state": "WA", + "unit": "celsius", + } + ), + ) + ) ], "I'll help you check the weather.", ), @@ -173,14 +181,18 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): {"name": "get_current_weather", "arguments": {"city": "New York", "state": "NY", "unit": "celsius"}} </tool_calls>""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "New York", - "state": "NY", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "New York", + "state": "NY", + "unit": "celsius", + } + ), + ) + ) ], None, ), @@ -188,22 +200,28 @@ def test_extract_tool_calls_no_tools(minimax_tool_parser): """<tool_calls> {"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA"}}""", [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Boston", - "state": "MA", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Boston", + "state": "MA", + } + ), + ) + ) ], None, ), ], ) -def test_extract_tool_calls(minimax_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_calls( + minimax_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -221,8 +239,7 @@ def test_preprocess_model_output_with_thinking_tags(minimax_tool_parser): {"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"}} </tool_calls>""" - processed_output = minimax_tool_parser.preprocess_model_output( - model_output) + processed_output = 
minimax_tool_parser.preprocess_model_output(model_output) # The tool call within thinking tags should be removed assert "fake_tool" not in processed_output @@ -244,12 +261,12 @@ def test_extract_tool_calls_with_thinking_tags(minimax_tool_parser): </tool_calls>""" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 - assert extracted_tool_calls.tool_calls[ - 0].function.name == "get_current_weather" + assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather" # Content extraction is based on the position of the first <tool_calls> in the original model_output # Since preprocessing removes tool calls within thinking tags, the actual first <tool_calls> is the external one @@ -270,14 +287,14 @@ def test_extract_tool_calls_invalid_json(minimax_tool_parser): </tool_calls>""" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called # Should extract only the valid JSON tool calls assert len(extracted_tool_calls.tool_calls) == 2 assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool" - assert extracted_tool_calls.tool_calls[ - 1].function.name == "another_valid_tool" + assert extracted_tool_calls.tool_calls[1].function.name == "another_valid_tool" def test_extract_tool_calls_missing_name_or_arguments(minimax_tool_parser): @@ -290,14 +307,14 @@ def test_extract_tool_calls_missing_name_or_arguments(minimax_tool_parser): </tool_calls>""" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called # Should extract only the valid tool calls with both name and arguments assert len(extracted_tool_calls.tool_calls) == 2 assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool" - assert extracted_tool_calls.tool_calls[ - 1].function.name == "another_valid_tool" + assert extracted_tool_calls.tool_calls[1].function.name == "another_valid_tool" def test_streaming_basic_functionality(minimax_tool_parser): @@ -326,8 +343,7 @@ def test_streaming_basic_functionality(minimax_tool_parser): # The result might be None or contain tool call information # This depends on the internal state management - if result is not None and hasattr(result, - 'tool_calls') and result.tool_calls: + if result is not None and hasattr(result, "tool_calls") and result.tool_calls: assert len(result.tool_calls) >= 0 @@ -352,7 +368,7 @@ def test_streaming_with_content_before_tool_calls(minimax_tool_parser): request=None, ) - if result is not None and hasattr(result, 'content'): + if result is not None and hasattr(result, "content"): # Should contain some content assert result.content is not None @@ -373,7 +389,7 @@ def test_streaming_no_tool_calls(minimax_tool_parser): # Should return the delta text as content assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == " without any tool calls." 
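
The minimax streaming tests that follow drive extract_tool_calls_streaming stage by stage, and each stage's delta_text is simply the suffix of current_text that was absent from previous_text. A minimal sketch of that slicing pattern, with invented stage strings rather than the tests' actual prompts:

# Hypothetical streamed outputs; each iteration derives the delta as the
# newly appended suffix, mirroring the loop in
# test_streaming_arguments_incremental_output below.
stages = [
    "Hello",
    "Hello world",
    'Hello world<tool_calls>\n{"name": "get_weather"}',
]

for i, current_text in enumerate(stages):
    previous_text = stages[i - 1] if i > 0 else ""
    delta_text = current_text[len(previous_text) :]
    print(f"stage {i}: delta={delta_text!r}")
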
@@ -399,8 +415,7 @@ def test_streaming_with_thinking_tags(minimax_tool_parser): # The preprocessing should remove tool calls from thinking tags # and only process the real tool call - if result is not None and hasattr(result, - 'tool_calls') and result.tool_calls: + if result is not None and hasattr(result, "tool_calls") and result.tool_calls: for tool_call in result.tool_calls: assert tool_call.function.name != "ignored" @@ -419,7 +434,8 @@ def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser): </tool_calls>""" extracted_tool_calls = minimax_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] # Multiline JSON is currently not supported, should return no tools called assert not extracted_tool_calls.tools_called @@ -449,7 +465,7 @@ def test_streaming_arguments_incremental_output(minimax_tool_parser): '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', # Stage 6: Tool calls closed '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool', - '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool_calls>' + '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool_calls>', ] function_name_sent = False @@ -457,8 +473,7 @@ def test_streaming_arguments_incremental_output(minimax_tool_parser): for i, current_text in enumerate(stages): previous_text = stages[i - 1] if i > 0 else "" - delta_text = current_text[len(previous_text - ):] if i > 0 else current_text + delta_text = current_text[len(previous_text) :] if i > 0 else current_text result = minimax_tool_parser.extract_tool_calls_streaming( previous_text=previous_text, @@ -473,30 +488,27 @@ def test_streaming_arguments_incremental_output(minimax_tool_parser): print(f"Stage {i}: Current text: {repr(current_text)}") print(f"Stage {i}: Delta text: {repr(delta_text)}") - if result is not None and hasattr(result, - 'tool_calls') and result.tool_calls: + if result is not None and hasattr(result, "tool_calls") and result.tool_calls: tool_call = result.tool_calls[0] # Check if function name is sent (should happen only once) if tool_call.function and tool_call.function.name: assert tool_call.function.name == "get_current_weather" function_name_sent = True - print( - f"Stage {i}: Function name sent: {tool_call.function.name}" - ) + print(f"Stage {i}: Function name sent: {tool_call.function.name}") # Check if arguments are sent incrementally if tool_call.function and tool_call.function.arguments: args_fragment = tool_call.function.arguments - print( - f"Stage {i}: Got arguments fragment: {repr(args_fragment)}" - ) + print(f"Stage {i}: Got arguments fragment: {repr(args_fragment)}") # For incremental output, each fragment should be new content only # The fragment should not contain all previous content if i >= 2 and previous_args_content: # After we start getting arguments # The new fragment should not be identical to or contain all previous content - assert args_fragment != previous_args_content, f"Fragment should be incremental, not cumulative: {args_fragment}" + assert args_fragment != previous_args_content, ( + f"Fragment should be incremental, not cumulative: {args_fragment}" + ) # If this is truly incremental, the fragment should be relatively small # compared to the complete 
arguments so far @@ -520,7 +532,9 @@ def test_streaming_arguments_delta_only(minimax_tool_parser): minimax_tool_parser.streamed_args_for_tool = [] # Simulate two consecutive calls with growing arguments - call1_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1"}}' + call1_text = ( + '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1"}}' + ) call2_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1", "param2": "value2"}}' print(f"Call 1 text: {repr(call1_text)}") @@ -538,7 +552,7 @@ def test_streaming_arguments_delta_only(minimax_tool_parser): ) print(f"Result 1: {result1}") - if result1 and hasattr(result1, 'tool_calls') and result1.tool_calls: + if result1 and hasattr(result1, "tool_calls") and result1.tool_calls: for i, tc in enumerate(result1.tool_calls): print(f" Tool call {i}: {tc}") @@ -554,13 +568,12 @@ def test_streaming_arguments_delta_only(minimax_tool_parser): ) print(f"Result 2: {result2}") - if result2 and hasattr(result2, 'tool_calls') and result2.tool_calls: + if result2 and hasattr(result2, "tool_calls") and result2.tool_calls: for i, tc in enumerate(result2.tool_calls): print(f" Tool call {i}: {tc}") # Verify the second call only returns the delta - if result2 is not None and hasattr(result2, - 'tool_calls') and result2.tool_calls: + if result2 is not None and hasattr(result2, "tool_calls") and result2.tool_calls: tool_call = result2.tool_calls[0] if tool_call.function and tool_call.function.arguments: args_delta = tool_call.function.arguments @@ -568,17 +581,21 @@ def test_streaming_arguments_delta_only(minimax_tool_parser): # Should only contain the new part, not the full arguments # The delta should be something like ', "param2": "value2"}' or just '"param2": "value2"' - assert ', "param2": "value2"}' in args_delta or '"param2": "value2"' in args_delta, f"Expected delta containing param2, got: {args_delta}" + assert ( + ', "param2": "value2"}' in args_delta + or '"param2": "value2"' in args_delta + ), f"Expected delta containing param2, got: {args_delta}" # Should NOT contain the previous parameter data - assert '"param1": "value1"' not in args_delta, f"Arguments delta should not contain previous data: {args_delta}" + assert '"param1": "value1"' not in args_delta, ( + f"Arguments delta should not contain previous data: {args_delta}" + ) # The delta should be relatively short (incremental, not cumulative) - expected_max_length = len( - ', "param2": "value2"}') + 10 # Some tolerance - assert len( - args_delta - ) <= expected_max_length, f"Delta seems too long (possibly cumulative): {args_delta}" + expected_max_length = len(', "param2": "value2"}') + 10 # Some tolerance + assert len(args_delta) <= expected_max_length, ( + f"Delta seems too long (possibly cumulative): {args_delta}" + ) print("✓ Delta validation passed") else: @@ -605,40 +622,39 @@ def test_streaming_openai_compatibility(minimax_tool_parser): # Test scenario: simple buffering without complex tool call context test_cases: list[dict[str, Any]] = [ { - 'stage': 'Token: <', - 'previous': '', - 'current': '<', - 'delta': '<', - 'expected_content': None, # Should be buffered + "stage": "Token: <", + "previous": "", + "current": "<", + "delta": "<", + "expected_content": None, # Should be buffered }, { - 'stage': 'Token: tool_calls>', - 'previous': '<', - 'current': '<tool_calls>', - 'delta': 'tool_calls>', - 'expected_content': None, # Complete tag, should not output + "stage": "Token: tool_calls>", + "previous": "<", + "current": "<tool_calls>", 
+ "delta": "tool_calls>", + "expected_content": None, # Complete tag, should not output }, { - 'stage': 'Regular content', - 'previous': 'Hello', - 'current': 'Hello world', - 'delta': ' world', - 'expected_content': ' world', # Normal content should pass through + "stage": "Regular content", + "previous": "Hello", + "current": "Hello world", + "delta": " world", + "expected_content": " world", # Normal content should pass through }, { - 'stage': 'Content with end tag start', - 'previous': 'Text', - 'current': 'Text content</tool_', - 'delta': ' content</tool_', - 'expected_content': - ' content', # Content part output, </tool_ buffered + "stage": "Content with end tag start", + "previous": "Text", + "current": "Text content</tool_", + "delta": " content</tool_", + "expected_content": " content", # Content part output, </tool_ buffered }, { - 'stage': 'Complete end tag', - 'previous': 'Text content</tool_', - 'current': 'Text content</tool_calls>', - 'delta': 'calls>', - 'expected_content': None, # Complete close tag, should not output + "stage": "Complete end tag", + "previous": "Text content</tool_", + "current": "Text content</tool_calls>", + "delta": "calls>", + "expected_content": None, # Complete close tag, should not output }, ] @@ -649,9 +665,9 @@ def test_streaming_openai_compatibility(minimax_tool_parser): print(f"Delta: {repr(test_case['delta'])}") result = minimax_tool_parser.extract_tool_calls_streaming( - previous_text=test_case['previous'], - current_text=test_case['current'], - delta_text=test_case['delta'], + previous_text=test_case["previous"], + current_text=test_case["current"], + delta_text=test_case["delta"], previous_token_ids=[], current_token_ids=[], delta_token_ids=[], @@ -661,15 +677,18 @@ def test_streaming_openai_compatibility(minimax_tool_parser): print(f"Result: {result}") # Check expected content - if test_case['expected_content'] is None: - assert result is None or not getattr(result, 'content', None), \ + if test_case["expected_content"] is None: + assert result is None or not getattr(result, "content", None), ( f"Stage {i}: Expected no content, got {result}" + ) print("✓ No content output as expected") else: - assert result is not None and hasattr(result, 'content'), \ + assert result is not None and hasattr(result, "content"), ( f"Stage {i}: Expected content, got {result}" - assert result.content == test_case['expected_content'], \ + ) + assert result.content == test_case["expected_content"], ( f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}" + ) print(f"✓ Content matches: {repr(result.content)}") print("✓ Streaming test with buffering completed successfully") @@ -690,35 +709,26 @@ def test_streaming_thinking_tag_buffering(minimax_tool_parser): # Test scenario: tool calls within thinking tags should be ignored test_cases: list[dict[str, Any]] = [ { - 'stage': 'Start thinking', - 'previous': '', - 'current': '<think>I need to use a tool. <tool_calls>', - 'delta': '<think>I need to use a tool. <tool_calls>', - 'expected_content': - '<think>I need to use a tool. <tool_calls>', # Should pass through as content + "stage": "Start thinking", + "previous": "", + "current": "<think>I need to use a tool. <tool_calls>", + "delta": "<think>I need to use a tool. <tool_calls>", + "expected_content": "<think>I need to use a tool. <tool_calls>", # Should pass through as content }, { - 'stage': - 'Tool call in thinking', - 'previous': - '<think>I need to use a tool. <tool_calls>', - 'current': - '<think>I need to use a tool. 
<tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', - 'delta': - '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', - 'expected_content': - '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', # </tool_calls> should be preserved in thinking tags + "stage": "Tool call in thinking", + "previous": "<think>I need to use a tool. <tool_calls>", + "current": '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', + "delta": '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', + "expected_content": '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>', # </tool_calls> should be preserved in thinking tags }, { - 'stage': 'Real tool call after thinking', - 'previous': - '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>', - 'current': - '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>\n<tool_calls>', - 'delta': '\n<tool_calls>', - 'expected_content': - '\n', # Should output '\n' and suppress <tool_calls> - } + "stage": "Real tool call after thinking", + "previous": '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>', + "current": '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>\n<tool_calls>', + "delta": "\n<tool_calls>", + "expected_content": "\n", # Should output '\n' and suppress <tool_calls> + }, ] for i, test_case in enumerate(test_cases): @@ -728,9 +738,9 @@ def test_streaming_thinking_tag_buffering(minimax_tool_parser): print(f"Delta: {repr(test_case['delta'])}") result = minimax_tool_parser.extract_tool_calls_streaming( - previous_text=test_case['previous'], - current_text=test_case['current'], - delta_text=test_case['delta'], + previous_text=test_case["previous"], + current_text=test_case["current"], + delta_text=test_case["delta"], previous_token_ids=[], current_token_ids=[], delta_token_ids=[], @@ -740,25 +750,32 @@ def test_streaming_thinking_tag_buffering(minimax_tool_parser): print(f"Result: {result}") # Check expected content - if 'expected_content' in test_case: - if test_case['expected_content'] is None: - assert result is None or not getattr(result, 'content', None), \ + if "expected_content" in test_case: + if test_case["expected_content"] is None: + assert result is None or not getattr(result, "content", None), ( f"Stage {i}: Expected no content, got {result}" + ) else: - assert result is not None and hasattr(result, 'content'), \ + assert result is not None and hasattr(result, "content"), ( f"Stage {i}: Expected content, got {result}" - assert result.content == test_case['expected_content'], \ + ) + assert result.content == test_case["expected_content"], ( f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}" + ) print(f"✓ Content matches: {repr(result.content)}") # Check tool calls - if test_case.get('expected_tool_call'): - assert result is not None and hasattr(result, 'tool_calls') and result.tool_calls, \ - f"Stage {i}: Expected tool call, got {result}" + if test_case.get("expected_tool_call"): + assert ( + result is not None + and hasattr(result, "tool_calls") + and result.tool_calls + ), f"Stage {i}: Expected tool call, got {result}" tool_call = 
result.tool_calls[0] - assert tool_call.function.name == "real_tool", \ + assert tool_call.function.name == "real_tool", ( f"Expected real_tool, got {tool_call.function.name}" + ) print(f"✓ Real tool call detected: {tool_call.function.name}") print("✓ Thinking tag buffering test completed successfully") @@ -784,104 +801,79 @@ def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser): # Complex scenario: tools inside thinking tags and multiple tools in one group test_stages: list[dict[str, Any]] = [ { - 'stage': 'Initial content', - 'previous': '', - 'current': 'Let me help you with this task.', - 'delta': 'Let me help you with this task.', - 'expected_content': 'Let me help you with this task.', - 'expected_tool_calls': 0, + "stage": "Initial content", + "previous": "", + "current": "Let me help you with this task.", + "delta": "Let me help you with this task.", + "expected_content": "Let me help you with this task.", + "expected_tool_calls": 0, }, { - 'stage': 'Start thinking tag', - 'previous': 'Let me help you with this task.', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.', - 'delta': '<think>I need to analyze this situation first.', - 'expected_content': - '<think>I need to analyze this situation first.', - 'expected_tool_calls': 0, + "stage": "Start thinking tag", + "previous": "Let me help you with this task.", + "current": "Let me help you with this task.<think>I need to analyze this situation first.", + "delta": "<think>I need to analyze this situation first.", + "expected_content": "<think>I need to analyze this situation first.", + "expected_tool_calls": 0, }, { - 'stage': 'Tool call inside thinking tag starts', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>', - 'delta': '<tool_calls>', - 'expected_content': - '<tool_calls>', # Inside thinking tags, tool tags should be preserved as content - 'expected_tool_calls': 0, + "stage": "Tool call inside thinking tag starts", + "previous": "Let me help you with this task.<think>I need to analyze this situation first.", + "current": "Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>", + "delta": "<tool_calls>", + "expected_content": "<tool_calls>", # Inside thinking tags, tool tags should be preserved as content + "expected_tool_calls": 0, }, { - 'stage': 'Complete tool call inside thinking tag', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', - 'delta': - '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', - 'expected_content': - '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', - 'expected_tool_calls': - 0, # Tools inside thinking tags should be ignored + "stage": "Complete tool call inside thinking tag", + "previous": "Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>", + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + "delta": '\n{"name": 
"internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + "expected_content": '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + "expected_tool_calls": 0, # Tools inside thinking tags should be ignored }, { - 'stage': 'End thinking tag', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', - 'delta': '</think>', - 'expected_content': '</think>', - 'expected_tool_calls': 0, + "stage": "End thinking tag", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', + "delta": "</think>", + "expected_content": "</think>", + "expected_tool_calls": 0, }, { - 'stage': 'Multiple tools group starts', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', - 'delta': - '\nNow I need to get weather information and calculate area.<tool_calls>', - 'expected_content': - '\nNow I need to get weather information and calculate area.', # <tool_calls> should be filtered - 'expected_tool_calls': 0, + "stage": "Multiple tools group starts", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', + "delta": "\nNow I need to get weather information and calculate area.<tool_calls>", + "expected_content": "\nNow I need to get weather information and calculate area.", # <tool_calls> should be filtered + "expected_tool_calls": 0, }, { - 'stage': 'First tool in group', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', - 
'delta': - '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', - 'expected_content': - None, # No content should be output when tool call is in progress - 'expected_tool_calls': 1, - 'expected_tool_name': 'get_current_weather', + "stage": "First tool in group", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + "delta": '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + "expected_content": None, # No content should be output when tool call is in progress + "expected_tool_calls": 1, + "expected_tool_name": "get_current_weather", }, { - 'stage': 'Second tool in group', - 'previous': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', - 'delta': - '\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', - 'expected_content': None, - 'expected_tool_calls': 1, - 'expected_tool_name': 'calculate_area', + "stage": "Second tool in group", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + "delta": '\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + "expected_content": None, + "expected_tool_calls": 1, + "expected_tool_name": "calculate_area", }, { - 'stage': 'Complete tool calls group', - 'previous': - 
'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', - 'current': - 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}</tool_calls>', - 'delta': '</tool_calls>', - 'expected_content': None, - 'expected_tool_calls': 0, - } + "stage": "Complete tool calls group", + "previous": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}', + "current": 'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}</tool_calls>', + "delta": "</tool_calls>", + "expected_content": None, + "expected_tool_calls": 0, + }, ] tool_calls_count = 0 @@ -895,9 +887,9 @@ def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser): print(f"Delta: {repr(test_case['delta'])}") result = minimax_tool_parser.extract_tool_calls_streaming( - previous_text=test_case['previous'], - current_text=test_case['current'], - delta_text=test_case['delta'], + previous_text=test_case["previous"], + current_text=test_case["current"], + delta_text=test_case["delta"], previous_token_ids=[], current_token_ids=[], delta_token_ids=[], @@ -907,53 +899,64 @@ def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser): print(f"Result: {result}") # Check expected content - if test_case['expected_content'] is None: - assert result is None or not getattr(result, 'content', None), \ + if test_case["expected_content"] is None: + assert result is None or not getattr(result, "content", None), ( f"Stage {i}: Expected no content output, got {result}" + ) print("✓ No content output as expected") else: - assert result is not None and hasattr(result, 'content'), \ + assert result is not None and hasattr(result, "content"), ( f"Stage {i}: Expected content output, got {result}" - assert result.content == test_case['expected_content'], \ + ) + assert result.content == test_case["expected_content"], ( f"Stage {i}: Expected content {repr(test_case['expected_content'])}, got {repr(result.content)}" + 
) print(f"✓ Content matches: {repr(result.content)}") # Check tool calls - expected_tool_calls = test_case['expected_tool_calls'] - actual_tool_calls = len(result.tool_calls) if result and hasattr( - result, 'tool_calls') and result.tool_calls else 0 + expected_tool_calls = test_case["expected_tool_calls"] + actual_tool_calls = ( + len(result.tool_calls) + if result and hasattr(result, "tool_calls") and result.tool_calls + else 0 + ) if expected_tool_calls > 0: - assert actual_tool_calls >= expected_tool_calls, \ + assert actual_tool_calls >= expected_tool_calls, ( f"Stage {i}: Expected at least {expected_tool_calls} tool calls, got {actual_tool_calls}" + ) - if 'expected_tool_name' in test_case: + if "expected_tool_name" in test_case: # Find the tool call with the expected name found_tool_call = None for tool_call in result.tool_calls: - if tool_call.function.name == test_case[ - 'expected_tool_name']: + if tool_call.function.name == test_case["expected_tool_name"]: found_tool_call = tool_call break - assert found_tool_call is not None, \ + assert found_tool_call is not None, ( f"Stage {i}: Expected tool name {test_case['expected_tool_name']} not found in tool calls: {[tc.function.name for tc in result.tool_calls]}" + ) print(f"✓ Tool call correct: {found_tool_call.function.name}") # Ensure tools inside thinking tags are not called - assert found_tool_call.function.name != "internal_analysis", \ + assert found_tool_call.function.name != "internal_analysis", ( f"Stage {i}: Tool 'internal_analysis' inside thinking tags should not be called" + ) tool_calls_count += actual_tool_calls print(f"✓ Detected {actual_tool_calls} tool calls") else: - assert actual_tool_calls == 0, \ + assert actual_tool_calls == 0, ( f"Stage {i}: Expected no tool calls, got {actual_tool_calls}" + ) # Verify overall results print("\n=== Test Summary ===") print(f"Total tool calls count: {tool_calls_count}") - assert tool_calls_count >= 2, f"Expected at least 2 valid tool calls (outside thinking tags), but got {tool_calls_count}" + assert tool_calls_count >= 2, ( + f"Expected at least 2 valid tool calls (outside thinking tags), but got {tool_calls_count}" + ) print("✓ Complex streaming test completed:") print(" - ✓ Tools inside thinking tags correctly ignored") @@ -987,8 +990,8 @@ def test_streaming_character_by_character_output(minimax_tool_parser): # Stream character by character for i in range(1, len(complete_text) + 1): current_text = complete_text[:i] - previous_text = complete_text[:i - 1] if i > 1 else "" - delta_text = complete_text[i - 1:i] + previous_text = complete_text[: i - 1] if i > 1 else "" + delta_text = complete_text[i - 1 : i] # Show progress every 50 characters if i % 50 == 0 or i == len(complete_text): @@ -1007,36 +1010,35 @@ def test_streaming_character_by_character_output(minimax_tool_parser): # Collect results if result is not None: - if hasattr(result, 'content') and result.content: + if hasattr(result, "content") and result.content: content_fragments.append(result.content) # Log important content fragments if any( - keyword in result.content for keyword in - ['<think>', '</think>', '<tool_calls>', '</tool_calls>']): - print( - f" Char {i}: Content fragment: {repr(result.content)}" - ) - - if hasattr(result, 'tool_calls') and result.tool_calls: + keyword in result.content + for keyword in [ + "<think>", + "</think>", + "<tool_calls>", + "</tool_calls>", + ] + ): + print(f" Char {i}: Content fragment: {repr(result.content)}") + + if hasattr(result, "tool_calls") and result.tool_calls: for 
tool_call in result.tool_calls: tool_info = { - 'character_position': - i, - 'function_name': - tool_call.function.name - if tool_call.function else None, - 'arguments': - tool_call.function.arguments - if tool_call.function else None, + "character_position": i, + "function_name": tool_call.function.name + if tool_call.function + else None, + "arguments": tool_call.function.arguments + if tool_call.function + else None, } tool_calls_detected.append(tool_info) - print( - f" Char {i}: Tool call detected: {tool_call.function.name}" - ) + print(f" Char {i}: Tool call detected: {tool_call.function.name}") if tool_call.function.arguments: - print( - f" Arguments: {repr(tool_call.function.arguments)}" - ) + print(f" Arguments: {repr(tool_call.function.arguments)}") # Verify results print("\n=== Streaming Test Results ===") @@ -1044,68 +1046,74 @@ def test_streaming_character_by_character_output(minimax_tool_parser): print(f"Total tool calls detected: {len(tool_calls_detected)}") # Reconstruct content from fragments - reconstructed_content = ''.join(content_fragments) + reconstructed_content = "".join(content_fragments) print(f"Reconstructed content length: {len(reconstructed_content)}") # Verify thinking tags content is preserved - assert '<think>' in reconstructed_content, "Opening thinking tag should be preserved in content" - assert '</think>' in reconstructed_content, "Closing thinking tag should be preserved in content" + assert "<think>" in reconstructed_content, ( + "Opening thinking tag should be preserved in content" + ) + assert "</think>" in reconstructed_content, ( + "Closing thinking tag should be preserved in content" + ) # Verify that tool calls inside thinking tags are NOT extracted as actual tool calls thinking_tool_calls = [ - tc for tc in tool_calls_detected - if tc['function_name'] == 'internal_analysis' + tc for tc in tool_calls_detected if tc["function_name"] == "internal_analysis" ] - assert len( - thinking_tool_calls - ) == 0, f"Tool calls inside thinking tags should be ignored, but found: {thinking_tool_calls}" + assert len(thinking_tool_calls) == 0, ( + f"Tool calls inside thinking tags should be ignored, but found: {thinking_tool_calls}" + ) # Verify that real tool calls outside thinking tags ARE extracted weather_tool_calls = [ - tc for tc in tool_calls_detected - if tc['function_name'] == 'get_current_weather' + tc for tc in tool_calls_detected if tc["function_name"] == "get_current_weather" ] area_tool_calls = [ - tc for tc in tool_calls_detected - if tc['function_name'] == 'calculate_area' + tc for tc in tool_calls_detected if tc["function_name"] == "calculate_area" ] print(tool_calls_detected) - assert len(weather_tool_calls - ) > 0, "get_current_weather tool call should be detected" - assert len( - area_tool_calls) > 0, "calculate_area tool call should be detected" + assert len(weather_tool_calls) > 0, ( + "get_current_weather tool call should be detected" + ) + assert len(area_tool_calls) > 0, "calculate_area tool call should be detected" # Verify tool call arguments are properly streamed - weather_args_found = any(tc['arguments'] for tc in weather_tool_calls - if tc['arguments']) - area_args_found = any(tc['arguments'] for tc in area_tool_calls - if tc['arguments']) + weather_args_found = any( + tc["arguments"] for tc in weather_tool_calls if tc["arguments"] + ) + area_args_found = any(tc["arguments"] for tc in area_tool_calls if tc["arguments"]) print(f"Weather tool call with arguments: {weather_args_found}") print(f"Area tool call with arguments: 
{area_args_found}") # Verify content before and after tool calls - assert 'I\'ll help you with the weather analysis.' in reconstructed_content, "Initial content should be preserved" - assert 'Here are the results.' in reconstructed_content, "Final content should be preserved" + assert "I'll help you with the weather analysis." in reconstructed_content, ( + "Initial content should be preserved" + ) + assert "Here are the results." in reconstructed_content, ( + "Final content should be preserved" + ) # Verify that <tool_calls> and </tool_calls> tags are not included in the final content # (they should be filtered out when not inside thinking tags) content_outside_thinking = reconstructed_content # Remove thinking tag content to check content outside - if '<think>' in content_outside_thinking and '</think>' in content_outside_thinking: - start_think = content_outside_thinking.find('<think>') - end_think = content_outside_thinking.find('</think>') + len('</think>') - content_outside_thinking = content_outside_thinking[: - start_think] + content_outside_thinking[ - end_think:] + if "<think>" in content_outside_thinking and "</think>" in content_outside_thinking: + start_think = content_outside_thinking.find("<think>") + end_think = content_outside_thinking.find("</think>") + len("</think>") + content_outside_thinking = ( + content_outside_thinking[:start_think] + + content_outside_thinking[end_think:] + ) # Outside thinking tags, tool_calls tags should be filtered - tool_calls_in_content = content_outside_thinking.count('<tool_calls>') - assert tool_calls_in_content == 0, f"<tool_calls> tags should be filtered from content outside thinking tags, but found {tool_calls_in_content}" - - print( - "\n=== Character-by-character streaming test completed successfully ===" + tool_calls_in_content = content_outside_thinking.count("<tool_calls>") + assert tool_calls_in_content == 0, ( + f"<tool_calls> tags should be filtered from content outside thinking tags, but found {tool_calls_in_content}" ) + + print("\n=== Character-by-character streaming test completed successfully ===") print("✓ Tool calls inside thinking tags correctly ignored") print("✓ Tool calls outside thinking tags correctly detected") print("✓ Content properly streamed and reconstructed") @@ -1113,8 +1121,7 @@ def test_streaming_character_by_character_output(minimax_tool_parser): print("✓ Character-level streaming works correctly") -def test_streaming_character_by_character_simple_tool_call( - minimax_tool_parser): +def test_streaming_character_by_character_simple_tool_call(minimax_tool_parser): """Test character-by-character streaming for a simple tool call scenario.""" # Reset streaming state reset_streaming_state(minimax_tool_parser) @@ -1131,8 +1138,8 @@ def test_streaming_character_by_character_simple_tool_call( for i in range(1, len(simple_text) + 1): current_text = simple_text[:i] - previous_text = simple_text[:i - 1] if i > 1 else "" - delta_text = simple_text[i - 1:i] + previous_text = simple_text[: i - 1] if i > 1 else "" + delta_text = simple_text[i - 1 : i] result = minimax_tool_parser.extract_tool_calls_streaming( previous_text=previous_text, @@ -1145,19 +1152,17 @@ def test_streaming_character_by_character_simple_tool_call( ) if result: - if hasattr(result, 'content') and result.content: + if hasattr(result, "content") and result.content: content_parts.append(result.content) print( f" Char {i} ({repr(delta_text)}): Content: {repr(result.content)}" ) - if hasattr(result, 'tool_calls') and result.tool_calls: + if 
hasattr(result, "tool_calls") and result.tool_calls: for tool_call in result.tool_calls: if tool_call.function and tool_call.function.name: tool_name_sent = True - print( - f" Char {i}: Tool name: {tool_call.function.name}" - ) + print(f" Char {i}: Tool name: {tool_call.function.name}") if tool_call.function and tool_call.function.arguments: tool_args_sent = True print( @@ -1165,12 +1170,14 @@ def test_streaming_character_by_character_simple_tool_call( ) # Verify basic expectations - reconstructed_content = ''.join(content_parts) + reconstructed_content = "".join(content_parts) print(f"Final reconstructed content: {repr(reconstructed_content)}") assert tool_name_sent, "Tool name should be sent during streaming" assert tool_args_sent, "Tool arguments should be sent during streaming" - assert "Let me check the weather." in reconstructed_content, "Initial content should be preserved" + assert "Let me check the weather." in reconstructed_content, ( + "Initial content should be preserved" + ) print("✓ Simple character-by-character test passed") @@ -1190,8 +1197,8 @@ def test_streaming_character_by_character_with_buffering(minimax_tool_parser): for i in range(1, len(buffering_text) + 1): current_text = buffering_text[:i] - previous_text = buffering_text[:i - 1] if i > 1 else "" - delta_text = buffering_text[i - 1:i] + previous_text = buffering_text[: i - 1] if i > 1 else "" + delta_text = buffering_text[i - 1 : i] result = minimax_tool_parser.extract_tool_calls_streaming( previous_text=previous_text, @@ -1203,16 +1210,18 @@ def test_streaming_character_by_character_with_buffering(minimax_tool_parser): request=None, ) - if result and hasattr(result, 'content') and result.content: + if result and hasattr(result, "content") and result.content: all_content.append(result.content) print(f" Char {i} ({repr(delta_text)}): {repr(result.content)}") - final_content = ''.join(all_content) + final_content = "".join(all_content) print(f"Final content: {repr(final_content)}") # The parser should handle the edge case where </tool_calls> appears before <tool_calls> assert "Hello" in final_content, "Initial 'Hello' should be preserved" - assert "world" in final_content, "Content after false closing tag should be preserved" + assert "world" in final_content, ( + "Content after false closing tag should be preserved" + ) assert "done" in final_content, "Final content should be preserved" print("✓ Buffering character-by-character test passed") diff --git a/tests/tool_use/test_openai_tool_parser.py b/tests/tool_use/test_openai_tool_parser.py index 2551c41c6275..f6223f3fdce4 100644 --- a/tests/tool_use/test_openai_tool_parser.py +++ b/tests/tool_use/test_openai_tool_parser.py @@ -4,9 +4,15 @@ import json import pytest -from openai_harmony import (Conversation, DeveloperContent, - HarmonyEncodingName, Message, Role, SystemContent, - load_harmony_encoding) +from openai_harmony import ( + Conversation, + DeveloperContent, + HarmonyEncodingName, + Message, + Role, + SystemContent, + load_harmony_encoding, +) from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser @@ -37,8 +43,9 @@ def assert_tool_calls( ): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert isinstance(actual_tool_call.id, str) assert len(actual_tool_call.id) > 16 # Default from 
protocol.py assert actual_tool_call.type == "function" @@ -46,20 +53,25 @@ def assert_tool_calls( def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding): - convo = Conversation.from_messages([ - Message.from_role_and_content( - Role.SYSTEM, - SystemContent.new(), - ), - Message.from_role_and_content( - Role.DEVELOPER, - DeveloperContent.new().with_instructions("Talk like a pirate!")), - Message.from_role_and_content(Role.USER, "Arrr, how be you?"), - Message.from_role_and_content(Role.ASSISTANT, - "This is a test").with_channel("final") - ]) + convo = Conversation.from_messages( + [ + Message.from_role_and_content( + Role.SYSTEM, + SystemContent.new(), + ), + Message.from_role_and_content( + Role.DEVELOPER, + DeveloperContent.new().with_instructions("Talk like a pirate!"), + ), + Message.from_role_and_content(Role.USER, "Arrr, how be you?"), + Message.from_role_and_content( + Role.ASSISTANT, "This is a test" + ).with_channel("final"), + ] + ) token_ids = harmony_encoding.render_conversation_for_completion( - convo, Role.ASSISTANT) + convo, Role.ASSISTANT + ) extracted_info = openai_tool_parser.extract_tool_calls( "", request=None, @@ -70,26 +82,32 @@ def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding): assert extracted_info.content == "This is a test" -@pytest.mark.parametrize("tool_args", [ - '{"location": "Tokyo"}', - '{\n"location": "Tokyo"\n}', -]) -def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding, - tool_args): - convo = Conversation.from_messages([ - Message.from_role_and_content(Role.USER, - "What is the weather in Tokyo?"), - Message.from_role_and_content( - Role.ASSISTANT, - 'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501 - ).with_channel("analysis"), - Message.from_role_and_content( - Role.ASSISTANT, - tool_args).with_channel("commentary").with_recipient( - "functions.get_current_weather").with_content_type("json"), - ]) +@pytest.mark.parametrize( + "tool_args", + [ + '{"location": "Tokyo"}', + '{\n"location": "Tokyo"\n}', + ], +) +def test_extract_tool_calls_single_tool( + openai_tool_parser, harmony_encoding, tool_args +): + convo = Conversation.from_messages( + [ + Message.from_role_and_content(Role.USER, "What is the weather in Tokyo?"), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" 
We need to use get_current_weather tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content(Role.ASSISTANT, tool_args) + .with_channel("commentary") + .with_recipient("functions.get_current_weather") + .with_content_type("json"), + ] + ) token_ids = harmony_encoding.render_conversation_for_completion( - convo, Role.ASSISTANT) + convo, Role.ASSISTANT + ) extracted_info = openai_tool_parser.extract_tool_calls( "", @@ -98,10 +116,12 @@ def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding, ) assert extracted_info.tools_called expected_tool_calls = [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({"location": "Tokyo"}), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + ) + ) ] assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) assert extracted_info.content is None @@ -111,33 +131,39 @@ def test_extract_tool_calls_multiple_tools( openai_tool_parser, harmony_encoding, ): - convo = Conversation.from_messages([ - Message.from_role_and_content( - Role.USER, "What is the weather in Tokyo based on where I'm at?"), - Message.from_role_and_content( - Role.ASSISTANT, - 'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501 - ).with_channel("analysis"), - Message.from_role_and_content( - Role.ASSISTANT, - '{"location": "Tokyo"}').with_channel("commentary").with_recipient( - "functions.get_current_weather").with_content_type("json"), - Message.from_role_and_content( - Role.ASSISTANT, - '{"location": "Tokyo"}').with_channel("commentary").with_recipient( - "functions.get_user_location").with_content_type("json"), - Message.from_role_and_content( - Role.ASSISTANT, '{"location": "Tokyo"}').with_channel( - "commentary").with_recipient("functions.no_content_type"), - Message.from_role_and_content(Role.ASSISTANT, "foo").with_channel( - "commentary").with_recipient("functions.not_json_no_content_type"), - Message.from_role_and_content( - Role.ASSISTANT, '{}').with_channel("commentary").with_recipient( - "functions.empty_args").with_content_type("json"), - Message.from_role_and_content( - Role.ASSISTANT, '').with_channel("commentary").with_recipient( - "functions.no_args").with_content_type("json"), - ]) + convo = Conversation.from_messages( + [ + Message.from_role_and_content( + Role.USER, "What is the weather in Tokyo based on where I'm at?" + ), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" based on their location. 
We need to use get_current_weather tool and get_user_location tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') + .with_channel("commentary") + .with_recipient("functions.get_current_weather") + .with_content_type("json"), + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') + .with_channel("commentary") + .with_recipient("functions.get_user_location") + .with_content_type("json"), + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') + .with_channel("commentary") + .with_recipient("functions.no_content_type"), + Message.from_role_and_content(Role.ASSISTANT, "foo") + .with_channel("commentary") + .with_recipient("functions.not_json_no_content_type"), + Message.from_role_and_content(Role.ASSISTANT, "{}") + .with_channel("commentary") + .with_recipient("functions.empty_args") + .with_content_type("json"), + Message.from_role_and_content(Role.ASSISTANT, "") + .with_channel("commentary") + .with_recipient("functions.no_args") + .with_content_type("json"), + ] + ) token_ids = harmony_encoding.render_conversation_for_completion( convo, Role.ASSISTANT, @@ -150,30 +176,42 @@ def test_extract_tool_calls_multiple_tools( ) assert extracted_info.tools_called expected_tool_calls = [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({"location": "Tokyo"}), - )), - ToolCall(function=FunctionCall( - name="get_user_location", - arguments=json.dumps({"location": "Tokyo"}), - )), - ToolCall(function=FunctionCall( - name="no_content_type", - arguments=json.dumps({"location": "Tokyo"}), - )), - ToolCall(function=FunctionCall( - name="not_json_no_content_type", - arguments="foo", - )), - ToolCall(function=FunctionCall( - name="empty_args", - arguments=json.dumps({}), - )), - ToolCall(function=FunctionCall( - name="no_args", - arguments="", - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + ) + ), + ToolCall( + function=FunctionCall( + name="get_user_location", + arguments=json.dumps({"location": "Tokyo"}), + ) + ), + ToolCall( + function=FunctionCall( + name="no_content_type", + arguments=json.dumps({"location": "Tokyo"}), + ) + ), + ToolCall( + function=FunctionCall( + name="not_json_no_content_type", + arguments="foo", + ) + ), + ToolCall( + function=FunctionCall( + name="empty_args", + arguments=json.dumps({}), + ) + ), + ToolCall( + function=FunctionCall( + name="no_args", + arguments="", + ) + ), ] assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) assert extracted_info.content is None @@ -184,20 +222,24 @@ def test_extract_tool_calls_with_content( harmony_encoding, ): final_content = "This tool call will get the weather." - convo = Conversation.from_messages([ - Message.from_role_and_content( - Role.USER, "What is the weather in Tokyo based on where I'm at?"), - Message.from_role_and_content( - Role.ASSISTANT, - 'User asks: "What is the weather in Tokyo?" based on their location. 
We need to use get_current_weather tool and get_user_location tool.', # noqa: E501 - ).with_channel("analysis"), - Message.from_role_and_content( - Role.ASSISTANT, - '{"location": "Tokyo"}').with_channel("commentary").with_recipient( - "functions.get_current_weather").with_content_type("json"), - Message.from_role_and_content(Role.ASSISTANT, - final_content).with_channel("final"), - ]) + convo = Conversation.from_messages( + [ + Message.from_role_and_content( + Role.USER, "What is the weather in Tokyo based on where I'm at?" + ), + Message.from_role_and_content( + Role.ASSISTANT, + 'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501 + ).with_channel("analysis"), + Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}') + .with_channel("commentary") + .with_recipient("functions.get_current_weather") + .with_content_type("json"), + Message.from_role_and_content(Role.ASSISTANT, final_content).with_channel( + "final" + ), + ] + ) token_ids = harmony_encoding.render_conversation_for_completion( convo, Role.ASSISTANT, @@ -210,10 +252,12 @@ def test_extract_tool_calls_with_content( ) assert extracted_info.tools_called expected_tool_calls = [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({"location": "Tokyo"}), - )), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({"location": "Tokyo"}), + ) + ), ] assert_tool_calls(extracted_info.tool_calls, expected_tool_calls) assert extracted_info.content == final_content diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index fff20c68d621..159966365ec4 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -7,9 +7,13 @@ import openai import pytest -from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, - MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, - WEATHER_TOOL, ServerConfig) +from .utils import ( + MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + SEARCH_TOOL, + WEATHER_TOOL, + ServerConfig, +) # test: getting the model to generate parallel tool calls (streaming/not) @@ -17,12 +21,15 @@ # may be added in the future. e.g. llama 3.1 models are not designed to support # parallel tool calls. 
@pytest.mark.asyncio -async def test_parallel_tool_calls(client: openai.AsyncOpenAI, - server_config: ServerConfig): - +async def test_parallel_tool_calls( + client: openai.AsyncOpenAI, server_config: ServerConfig +): if not server_config.get("supports_parallel", True): - pytest.skip("The {} model doesn't support parallel tool calls".format( - server_config["model"])) + pytest.skip( + "The {} model doesn't support parallel tool calls".format( + server_config["model"] + ) + ) models = await client.models.list() model_name: str = models.data[0].id @@ -32,7 +39,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, max_completion_tokens=200, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason @@ -69,7 +77,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, max_completion_tokens=200, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - stream=True) + stream=True, + ) role_name: Optional[str] = None finish_reason_count: int = 0 @@ -80,24 +89,22 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, tool_call_id_count: int = 0 async for chunk in stream: - # if there's a finish reason make sure it's tools if chunk.choices[0].finish_reason: finish_reason_count += 1 - assert chunk.choices[0].finish_reason == 'tool_calls' + assert chunk.choices[0].finish_reason == "tool_calls" # if a role is being streamed make sure it wasn't already set to # something else if chunk.choices[0].delta.role: - assert not role_name or role_name == 'assistant' - role_name = 'assistant' + assert not role_name or role_name == "assistant" + role_name = "assistant" # if a tool call is streamed make sure there's exactly one # (based on the request parameters streamed_tool_calls = chunk.choices[0].delta.tool_calls if streamed_tool_calls and len(streamed_tool_calls) > 0: - # make sure only one diff is present - correct even for parallel assert len(streamed_tool_calls) == 1 tool_call = streamed_tool_calls[0] @@ -110,8 +117,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, # if a tool call ID is streamed, make sure one hasn't been already if tool_call.id: tool_call_id_count += 1 - assert (isinstance(tool_call.id, str) - and (len(tool_call.id) >= 9)) + assert isinstance(tool_call.id, str) and (len(tool_call.id) >= 9) # if parts of the function start being streamed if tool_call.function: @@ -125,32 +131,32 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, # make sure they're a string and then add them to the list assert isinstance(tool_call.function.arguments, str) - tool_call_args[ - tool_call.index] += tool_call.function.arguments + tool_call_args[tool_call.index] += tool_call.function.arguments assert finish_reason_count == 1 - assert role_name == 'assistant' + assert role_name == "assistant" - assert (len(non_streamed_tool_calls) == len(tool_call_names) == - len(tool_call_args)) + assert len(non_streamed_tool_calls) == len(tool_call_names) == len(tool_call_args) for i in range(2): assert non_streamed_tool_calls[i].function.name == tool_call_names[i] streamed_args = json.loads(tool_call_args[i]) - non_streamed_args = json.loads( - non_streamed_tool_calls[i].function.arguments) + non_streamed_args = json.loads(non_streamed_tool_calls[i].function.arguments) assert streamed_args == non_streamed_args # test: providing parallel tool calls back to the model to get a response # (streaming/not) @pytest.mark.asyncio -async def 
test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI, - server_config: ServerConfig): - +async def test_parallel_tool_calls_with_results( + client: openai.AsyncOpenAI, server_config: ServerConfig +): if not server_config.get("supports_parallel", True): - pytest.skip("The {} model doesn't support parallel tool calls".format( - server_config["model"])) + pytest.skip( + "The {} model doesn't support parallel tool calls".format( + server_config["model"] + ) + ) models = await client.models.list() model_name: str = models.data[0].id @@ -160,14 +166,14 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI, max_completion_tokens=200, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.finish_reason != "tool_calls" # "stop" or "length" assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 0 + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 assert choice.message.content is not None assert "98" in choice.message.content # Dallas temp in tool response assert "78" in choice.message.content # Orlando temp in tool response @@ -179,7 +185,8 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - stream=True) + stream=True, + ) chunks: list[str] = [] finish_reason_count = 0 diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index ade089e8246e..20fa3b08c7b9 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -7,14 +7,17 @@ import pytest -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaMessage, FunctionCall, - ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( - Qwen3CoderToolParser) -from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import ( - Qwen3XMLToolParser) + Qwen3CoderToolParser, +) +from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer @@ -39,8 +42,7 @@ def qwen3_xml_tool_parser(qwen3_tokenizer): @pytest.fixture(params=["original", "xml"]) -def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, - request): +def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request): """Parameterized fixture that provides both parser types for testing""" if request.param == "original": return qwen3_tool_parser @@ -51,76 +53,63 @@ def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, @pytest.fixture def sample_tools(): return [ - ChatCompletionToolsParam(type="function", - function={ - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "The city name" - }, - "state": { - "type": "string", - "description": - "The state code" - }, - "unit": { - "type": "string", - "enum": - ["fahrenheit", "celsius"] - } - }, - "required": ["city", "state"] - } - }), - 
ChatCompletionToolsParam(type="function", - function={ - "name": "calculate_area", - "description": - "Calculate area of a shape", - "parameters": { - "type": "object", - "properties": { - "shape": { - "type": "string" - }, - "dimensions": { - "type": "object" - }, - "precision": { - "type": "integer" - } - } - } - }) + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + "state": {"type": "string", "description": "The state code"}, + "unit": {"type": "string", "enum": ["fahrenheit", "celsius"]}, + }, + "required": ["city", "state"], + }, + }, + ), + ChatCompletionToolsParam( + type="function", + function={ + "name": "calculate_area", + "description": "Calculate area of a shape", + "parameters": { + "type": "object", + "properties": { + "shape": {"type": "string"}, + "dimensions": {"type": "object"}, + "precision": {"type": "integer"}, + }, + }, + }, + ), ] -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): # Qwen3 parser doesn't generate IDs during extraction assert actual_tool_call.type == "function" - assert ( - actual_tool_call.function.name == expected_tool_call.function.name) - assert (json.loads(actual_tool_call.function.arguments) == json.loads( - expected_tool_call.function.arguments)) + assert actual_tool_call.function.name == expected_tool_call.function.name + assert json.loads(actual_tool_call.function.arguments) == json.loads( + expected_tool_call.function.arguments + ) def stream_delta_message_generator( qwen3_tool_parser, qwen3_tokenizer: AnyTokenizer, model_output: str, - request: Optional[ChatCompletionRequest] = None + request: Optional[ChatCompletionRequest] = None, ) -> Generator[DeltaMessage, None, None]: - all_token_ids = qwen3_tokenizer.encode(model_output, - add_special_tokens=False) + all_token_ids = qwen3_tokenizer.encode(model_output, add_special_tokens=False) previous_text = "" previous_tokens = None @@ -129,18 +118,19 @@ def stream_delta_message_generator( for i, delta_token in enumerate(all_token_ids): delta_token_ids = [delta_token] previous_token_ids = all_token_ids[:i] - current_token_ids = all_token_ids[:i + 1] - - (new_tokens, delta_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=qwen3_tokenizer, - all_input_ids=current_token_ids, - prev_tokens=previous_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=False, - spaces_between_special_tokens=True, - ) + current_token_ids = all_token_ids[: i + 1] + + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=qwen3_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) current_text = previous_text + delta_text @@ -157,8 +147,9 @@ def stream_delta_message_generator( yield delta_message previous_text = current_text - previous_tokens = (previous_tokens + 
- new_tokens if previous_tokens else new_tokens) + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) prefix_offset = new_prefix_offset read_offset = new_read_offset @@ -166,7 +157,8 @@ def stream_delta_message_generator( def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized): model_output = "This is a test response without any tool calls" extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -182,7 +174,8 @@ def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized): ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ - ('''<tool_call> + ( + """<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -194,16 +187,21 @@ def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized): fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) - ], None), - ('''Sure! Let me check the weather for you.<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) + ], + None, + ), + ( + """Sure! Let me check the weather for you.<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -215,16 +213,21 @@ def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized): fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) - ], "Sure! Let me check the weather for you."), - ('''<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) + ], + "Sure! 
Let me check the weather for you.", + ), + ( + """<tool_call> <function=calculate_area> <parameter=shape> rectangle @@ -237,18 +240,25 @@ def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized): 2 </parameter> </function> -</tool_call>''', [ - ToolCall(function=FunctionCall(name="calculate_area", - arguments=json.dumps({ - "shape": "rectangle", - "dimensions": { - "width": 10, - "height": 20 - }, - "precision": 2 - }))) - ], None), - ('''<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="calculate_area", + arguments=json.dumps( + { + "shape": "rectangle", + "dimensions": {"width": 10, "height": 20}, + "precision": 2, + } + ), + ) + ) + ], + None, + ), + ( + """<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -273,23 +283,29 @@ def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized): fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))), - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit" - }))) - ], None), - ('''Let me calculate that area for you.<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Orlando", "state": "FL", "unit": "fahrenheit"} + ), + ) + ), + ], + None, + ), + ( + """Let me calculate that area for you.<tool_call> <function=calculate_area> <parameter=shape> circle @@ -301,26 +317,36 @@ def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized): 3 </parameter> </function> -</tool_call>''', [ - ToolCall(function=FunctionCall(name="calculate_area", - arguments=json.dumps({ - "shape": "circle", - "dimensions": { - "radius": 15.5 - }, - "precision": 3 - }))) - ], "Let me calculate that area for you."), +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="calculate_area", + arguments=json.dumps( + { + "shape": "circle", + "dimensions": {"radius": 15.5}, + "precision": 3, + } + ), + ) + ) + ], + "Let me calculate that area for you.", + ), ], ) -def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools, - model_output, expected_tool_calls, - expected_content): - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) +def test_extract_tool_calls( + qwen3_tool_parser_parametrized, + sample_tools, + model_output, + expected_tool_calls, + expected_content, +): + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( - model_output, request=request) + model_output, request=request + ) assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -328,60 +354,51 @@ def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools, assert extracted_tool_calls.content == expected_content -def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser_parametrized, - sample_tools): +def test_extract_tool_calls_fallback_no_tags( + qwen3_tool_parser_parametrized, sample_tools +): """Test fallback parsing when XML tags are missing""" - model_output = '''<function=get_current_weather> + model_output = 
"""<function=get_current_weather> <parameter=city> Dallas </parameter> <parameter=state> TX </parameter> -</function>''' +</function>""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( - model_output, request=request) + model_output, request=request + ) assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 - assert (extracted_tool_calls.tool_calls[0].function.name == - "get_current_weather") + assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather" def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): """Test parameter type conversion based on tool schema""" tools = [ - ChatCompletionToolsParam(type="function", - function={ - "name": "test_types", - "parameters": { - "type": "object", - "properties": { - "int_param": { - "type": "integer" - }, - "float_param": { - "type": "float" - }, - "bool_param": { - "type": "boolean" - }, - "str_param": { - "type": "string" - }, - "obj_param": { - "type": "object" - } - } - } - }) + ChatCompletionToolsParam( + type="function", + function={ + "name": "test_types", + "parameters": { + "type": "object", + "properties": { + "int_param": {"type": "integer"}, + "float_param": {"type": "float"}, + "bool_param": {"type": "boolean"}, + "str_param": {"type": "string"}, + "obj_param": {"type": "object"}, + }, + }, + }, + ) ] - model_output = '''<tool_call> + model_output = """<tool_call> <function=test_types> <parameter=int_param> 42 @@ -399,11 +416,12 @@ def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): {"key": "value"} </parameter> </function> -</tool_call>''' +</tool_call>""" request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( - model_output, request=request) + model_output, request=request + ) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) assert args["int_param"] == 42 @@ -425,7 +443,8 @@ def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ ("This is a test without tools", [], "This is a test without tools"), - ('''<tool_call> + ( + """<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -437,16 +456,21 @@ def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) - ], None), - ('''Sure! Let me check the weather for you.<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) + ], + None, + ), + ( + """Sure! Let me check the weather for you.<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -458,16 +482,21 @@ def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): fahrenheit </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))) - ], "Sure! 
Let me check the weather for you."), - ('''<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ) + ], + "Sure! Let me check the weather for you.", + ), + ( + """<tool_call> <function=calculate_area> <parameter=shape> rectangle @@ -480,18 +509,25 @@ def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): 2 </parameter> </function> -</tool_call>''', [ - ToolCall(function=FunctionCall(name="calculate_area", - arguments=json.dumps({ - "shape": "rectangle", - "dimensions": { - "width": 10, - "height": 20 - }, - "precision": 2 - }))) - ], None), - ('''<tool_call> +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="calculate_area", + arguments=json.dumps( + { + "shape": "rectangle", + "dimensions": {"width": 10, "height": 20}, + "precision": 2, + } + ), + ) + ) + ], + None, + ), + ( + """<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -516,24 +552,30 @@ def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): celsius </parameter> </function> -</tool_call>''', [ - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit" - }))), - ToolCall( - function=FunctionCall(name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "celsius" - }))) - ], None), +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Dallas", "state": "TX", "unit": "fahrenheit"} + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "Orlando", "state": "FL", "unit": "celsius"} + ), + ) + ), + ], + None, + ), # Added tool_with_typed_params test case - ('''Let me calculate that area for you.<tool_call> + ( + """Let me calculate that area for you.<tool_call> <function=calculate_area> <parameter=shape> circle @@ -545,33 +587,42 @@ def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): 3 </parameter> </function> -</tool_call>''', [ - ToolCall(function=FunctionCall(name="calculate_area", - arguments=json.dumps({ - "shape": "circle", - "dimensions": { - "radius": 15.5 - }, - "precision": 3 - }))) - ], "Let me calculate that area for you."), +</tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="calculate_area", + arguments=json.dumps( + { + "shape": "circle", + "dimensions": {"radius": 15.5}, + "precision": 3, + } + ), + ) + ) + ], + "Let me calculate that area for you.", + ), ], ) -def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized, - qwen3_tokenizer, sample_tools, - model_output, expected_tool_calls, - expected_content): +def test_extract_tool_calls_streaming( + qwen3_tool_parser_parametrized, + qwen3_tokenizer, + sample_tools, + model_output, + expected_tool_calls, + expected_content, +): """Test incremental streaming behavior including typed parameters""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) - other_content = '' + other_content = "" tool_states = {} # Track state per tool index for delta_message in stream_delta_message_generator( - qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, - request): + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request + ): 
# role should never be streamed from tool parser assert not delta_message.role @@ -588,7 +639,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized, "id": None, "name": None, "arguments": "", - "type": None + "type": None, } # First chunk should have id, name, and type @@ -607,8 +658,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized, if tool_call.function.arguments is not None: # Accumulate arguments incrementally - tool_states[idx][ - "arguments"] += tool_call.function.arguments + tool_states[idx]["arguments"] += tool_call.function.arguments # Verify final content assert other_content == (expected_content or "") # Handle None case @@ -632,10 +682,11 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized, def test_extract_tool_calls_missing_closing_parameter_tag( - qwen3_tool_parser_parametrized, sample_tools): + qwen3_tool_parser_parametrized, sample_tools +): """Test handling of missing closing </parameter> tag""" # Using get_current_weather from sample_tools but with malformed XML - model_output = '''Let me check the weather for you: + model_output = """Let me check the weather for you: <tool_call> <function=get_current_weather> <parameter=city> @@ -647,21 +698,19 @@ def test_extract_tool_calls_missing_closing_parameter_tag( fahrenheit </parameter> </function> -</tool_call>''' +</tool_call>""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( - model_output, request=request) + model_output, request=request + ) # The parser should handle the malformed XML gracefully assert extracted_tool_calls.tools_called assert len(extracted_tool_calls.tool_calls) == 1 # Verify the function name is correct - assert extracted_tool_calls.tool_calls[ - 0].function.name == "get_current_weather" + assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather" # Verify the arguments are parsed despite the missing closing tag args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) @@ -675,10 +724,11 @@ def test_extract_tool_calls_missing_closing_parameter_tag( def test_extract_tool_calls_streaming_missing_closing_tag( - qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools): + qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools +): """Test streaming with missing closing </parameter> tag""" # Using get_current_weather from sample_tools but with malformed XML - model_output = '''Let me check the weather for you: + model_output = """Let me check the weather for you: <tool_call> <function=get_current_weather> <parameter=city> @@ -690,19 +740,16 @@ def test_extract_tool_calls_streaming_missing_closing_tag( fahrenheit </parameter> </function> -</tool_call>''' +</tool_call>""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) - other_content = '' + other_content = "" tool_states = {} for delta_message in stream_delta_message_generator( - qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, - request): - + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request + ): if delta_message.content: other_content += delta_message.content @@ -715,7 +762,7 @@ def test_extract_tool_calls_streaming_missing_closing_tag( "id": None, "name": None, "arguments": "", - "type": None + "type": None, } if 
tool_call.id: @@ -730,8 +777,7 @@ def test_extract_tool_calls_streaming_missing_closing_tag( tool_states[idx]["name"] = tool_call.function.name if tool_call.function.arguments is not None: - tool_states[idx][ - "arguments"] += tool_call.function.arguments + tool_states[idx]["arguments"] += tool_call.function.arguments # Verify content was streamed assert "Let me check the weather for you:" in other_content @@ -752,9 +798,10 @@ def test_extract_tool_calls_streaming_missing_closing_tag( def test_extract_tool_calls_streaming_incremental( - qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools): + qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools +): """Test that streaming is truly incremental""" - model_output = '''I'll check the weather.<tool_call> + model_output = """I'll check the weather.<tool_call> <function=get_current_weather> <parameter=city> Dallas @@ -763,16 +810,14 @@ def test_extract_tool_calls_streaming_incremental( TX </parameter> </function> -</tool_call>''' +</tool_call>""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) chunks = [] for delta_message in stream_delta_message_generator( - qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, - request): + qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request + ): chunks.append(delta_message) # Should have multiple chunks @@ -787,7 +832,7 @@ def test_extract_tool_calls_streaming_incremental( for chunk in chunks: if chunk.tool_calls and chunk.tool_calls[0].id: header_found = True - assert (chunk.tool_calls[0].function.name == "get_current_weather") + assert chunk.tool_calls[0].function.name == "get_current_weather" assert chunk.tool_calls[0].type == "function" # Empty initially assert chunk.tool_calls[0].function.arguments == "" @@ -811,46 +856,40 @@ def test_extract_tool_calls_streaming_incremental( def test_extract_tool_calls_complex_type_with_single_quote( - qwen3_tool_parser_parametrized): + qwen3_tool_parser_parametrized, +): """Test parameter type conversion based on tool schema""" tools = [ - ChatCompletionToolsParam(type="function", - function={ - "name": "test_types", - "parameters": { - "type": "object", - "properties": { - "int_param": { - "type": "integer" - }, - "float_param": { - "type": "float" - }, - "bool_param": { - "type": "boolean" - }, - "str_param": { - "type": "string" - }, - "obj_param": { - "type": "object" - } - } - } - }) + ChatCompletionToolsParam( + type="function", + function={ + "name": "test_types", + "parameters": { + "type": "object", + "properties": { + "int_param": {"type": "integer"}, + "float_param": {"type": "float"}, + "bool_param": {"type": "boolean"}, + "str_param": {"type": "string"}, + "obj_param": {"type": "object"}, + }, + }, + }, + ) ] - model_output = '''<tool_call> + model_output = """<tool_call> <function=test_types> <parameter=obj_param> {'key': 'value'} </parameter> </function> -</tool_call>''' +</tool_call>""" request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( - model_output, request=request) + model_output, request=request + ) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) assert args["obj_param"] == {"key": "value"} diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py index 5100b5ac120b..eddb5a9b9f5e 100644 --- 
a/tests/tool_use/test_seed_oss_tool_parser.py +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -8,10 +8,13 @@ import pytest -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaMessage, FunctionCall, - ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer @@ -45,51 +48,56 @@ def sample_tools(): "properties": { "location": { "type": "string", - "description": - "City and country e.g. Bogotá, Colombia" + "description": "City and country e.g. Bogotá, Colombia", }, "unit": { "type": "string", - "description": "this is the unit of temperature" - } + "description": "this is the unit of temperature", + }, }, "required": ["location"], - "additionalProperties": False + "additionalProperties": False, }, "returns": { "type": "object", "properties": { "temperature": { "type": "number", - "description": "temperature in celsius" + "description": "temperature in celsius", } }, "required": ["temperature"], - "additionalProperties": False + "additionalProperties": False, }, - "strict": True - }), + "strict": True, + }, + ), ] -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): # Seed-OSS tool call will not generate id assert actual_tool_call.type == "function" assert actual_tool_call.function == expected_tool_call.function assert actual_tool_call.function.name == expected_tool_call.function.name - assert actual_tool_call.function.arguments == expected_tool_call.function.arguments + assert ( + actual_tool_call.function.arguments == expected_tool_call.function.arguments + ) def test_extract_tool_calls_no_tools(seed_oss_tool_parser): model_output = "This is a test response without any tool calls" extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] @@ -104,17 +112,24 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser): ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ - ("""<seed:tool_call>\n<function=get_weather>\n""" - """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", - [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "location": "Barcelona, Spain", - }, ), - ), - type='function') - ], None), + ( + """<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + }, + ), + ), + type="function", + ) + ], + None, + ), ( """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ 
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ @@ -131,13 +146,17 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser): """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" """\n</seed:tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "location": "Barcelona, Spain", - }, ), - ), - type='function') + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + }, + ), + ), + type="function", + ) ], """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ @@ -169,15 +188,18 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser): """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps( - { - "location": "Barcelona, Spain", - "unit": "celsius", - }, ), - ), - type='function') + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, + ), + ), + type="function", + ) ], """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ """First, I need to remember the function I can use: get_weather. The function requires a """ @@ -196,13 +218,17 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser): ), ], ) -def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output, - expected_tool_calls, expected_content): - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) +def test_extract_tool_calls( + seed_oss_tool_parser, + sample_tools, + model_output, + expected_tool_calls, + expected_content, +): + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls( - model_output, request=request) # type: ignore[arg-type] + model_output, request=request + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -225,7 +251,7 @@ def test_streaming_tool_calls_no_tools(seed_oss_tool_parser): # Should return the delta text as content assert result is not None - assert hasattr(result, 'content') + assert hasattr(result, "content") assert result.content == " without any tool calls." 
@@ -233,10 +259,9 @@ def stream_delta_message_generator( seed_oss_tool_parser: SeedOssToolParser, seed_oss_tokenizer: AnyTokenizer, model_output: str, - request: Optional[ChatCompletionRequest] = None + request: Optional[ChatCompletionRequest] = None, ) -> Generator[DeltaMessage, None, None]: - all_token_ids = seed_oss_tokenizer.encode(model_output, - add_special_tokens=False) + all_token_ids = seed_oss_tokenizer.encode(model_output, add_special_tokens=False) previous_text = "" previous_tokens = None @@ -245,18 +270,19 @@ def stream_delta_message_generator( for i, delta_token in enumerate(all_token_ids): delta_token_ids = [delta_token] previous_token_ids = all_token_ids[:i] - current_token_ids = all_token_ids[:i + 1] - - (new_tokens, delta_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=seed_oss_tokenizer, - all_input_ids=current_token_ids, - prev_tokens=previous_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=False, - spaces_between_special_tokens=True, - ) + current_token_ids = all_token_ids[: i + 1] + + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=seed_oss_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) current_text = previous_text + delta_text @@ -273,8 +299,9 @@ def stream_delta_message_generator( yield delta_message previous_text = current_text - previous_tokens = (previous_tokens + - new_tokens if previous_tokens else new_tokens) + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) prefix_offset = new_prefix_offset read_offset = new_read_offset @@ -287,22 +314,27 @@ def stream_delta_message_generator( ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ - ("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" - """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" - """<seed:tool_call>\n<function=get_weather>\n""" - """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", - [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "location": "Barcelona, Spain", - }, ), - ), - type='function') - ], - """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" - """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" - ), + ( + """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""" + """<seed:tool_call>\n<function=get_weather>\n""" + """<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""", + [ + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + }, + ), + ), + type="function", + ) + ], + """<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n""" + """The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""", + ), ( """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ """question. The user wants to know the weather in Barcelona, Spain. 
Looking at the functions available, """ @@ -319,13 +351,17 @@ def stream_delta_message_generator( """<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>""" """\n</seed:tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps({ - "location": "Barcelona, Spain", - }, ), - ), - type='function') + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + }, + ), + ), + type="function", + ) ], """<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """ """question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """ @@ -357,15 +393,18 @@ def stream_delta_message_generator( """temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>""" """Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""", [ - ToolCall(function=FunctionCall( - name="get_weather", - arguments=json.dumps( - { - "location": "Barcelona, Spain", - "unit": "celsius", - }, ), - ), - type='function') + ToolCall( + function=FunctionCall( + name="get_weather", + arguments=json.dumps( + { + "location": "Barcelona, Spain", + "unit": "celsius", + }, + ), + ), + type="function", + ) ], """<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """ """First, I need to remember the function I can use: get_weather. The function requires a """ @@ -384,19 +423,23 @@ def stream_delta_message_generator( ), ], ) -def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, - sample_tools, model_output, expected_tool_calls, - expected_content): +def test_streaming_tool_calls( + seed_oss_tool_parser, + seed_oss_tokenizer, + sample_tools, + model_output, + expected_tool_calls, + expected_content, +): """Test incremental streaming behavior""" - request = ChatCompletionRequest(model=MODEL, - messages=[], - tools=sample_tools) + request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) - other_content = '' + other_content = "" tool_states = {} # Track state per tool index for delta_message in stream_delta_message_generator( - seed_oss_tool_parser, seed_oss_tokenizer, model_output, request): + seed_oss_tool_parser, seed_oss_tokenizer, model_output, request + ): # role should never be streamed from tool parser assert not delta_message.role @@ -413,7 +456,7 @@ def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, "id": None, "name": None, "arguments": "", - "type": None + "type": None, } # First chunk should have id, name, and type @@ -432,8 +475,7 @@ def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer, if tool_call.function.arguments is not None: # Accumulate arguments incrementally - tool_states[idx][ - "arguments"] += tool_call.function.arguments + tool_states[idx]["arguments"] += tool_call.function.arguments # Verify final content assert other_content == expected_content diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index 53ba03a0ae10..64186aaac6a7 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -7,8 +7,12 @@ import openai import pytest -from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, - SEARCH_TOOL, WEATHER_TOOL) +from .utils import ( + MESSAGES_ASKING_FOR_TOOLS, + MESSAGES_WITH_TOOL_RESPONSE, + SEARCH_TOOL, + 
WEATHER_TOOL, +) # test: request a chat completion that should return tool calls, so we know they @@ -23,17 +27,18 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): max_completion_tokens=100, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] stop_reason = chat_completion.choices[0].finish_reason tool_calls = chat_completion.choices[0].message.tool_calls # make sure a tool call is present - assert choice.message.role == 'assistant' + assert choice.message.role == "assistant" assert tool_calls is not None assert len(tool_calls) == 1 - assert tool_calls[0].type == 'function' + assert tool_calls[0].type == "function" assert tool_calls[0].function is not None assert isinstance(tool_calls[0].id, str) assert len(tool_calls[0].id) >= 9 @@ -54,7 +59,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert stop_reason == "tool_calls" function_name: Optional[str] = None - function_args_str: str = '' + function_args_str: str = "" tool_call_id: Optional[str] = None role_name: Optional[str] = None finish_reason_count: int = 0 @@ -67,20 +72,21 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): max_completion_tokens=100, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - stream=True) + stream=True, + ) async for chunk in stream: assert chunk.choices[0].index == 0 if chunk.choices[0].finish_reason: finish_reason_count += 1 - assert chunk.choices[0].finish_reason == 'tool_calls' + assert chunk.choices[0].finish_reason == "tool_calls" # if a role is being streamed make sure it wasn't already set to # something else if chunk.choices[0].delta.role: - assert not role_name or role_name == 'assistant' - role_name = 'assistant' + assert not role_name or role_name == "assistant" + role_name = "assistant" # if a tool call is streamed make sure there's exactly one # (based on the request parameters @@ -108,7 +114,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): function_args_str += tool_call.function.arguments assert finish_reason_count == 1 - assert role_name == 'assistant' + assert role_name == "assistant" assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9) # validate the name and arguments @@ -148,14 +154,14 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): max_completion_tokens=100, model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], - logprobs=False) + logprobs=False, + ) choice = chat_completion.choices[0] assert choice.finish_reason != "tool_calls" # "stop" or "length" assert choice.message.role == "assistant" - assert choice.message.tool_calls is None \ - or len(choice.message.tool_calls) == 0 + assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0 assert choice.message.content is not None assert "98" in choice.message.content # the temperature from the response @@ -166,7 +172,8 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): model=model_name, tools=[WEATHER_TOOL, SEARCH_TOOL], logprobs=False, - stream=True) + stream=True, + ) chunks: list[str] = [] finish_reason_count = 0 diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index 7c63816cd6f5..d52c141f6210 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -8,8 +8,10 @@ import regex as re from pydantic import TypeAdapter -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - 
ChatCompletionToolsParam) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, +) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat pytestmark = pytest.mark.cpu_test @@ -24,18 +26,16 @@ "type": "object", "properties": { "city": { - "type": - "string", - "description": - "The city to find the weather for" + "type": "string", + "description": "The city to find the weather for" ", e.g. 'San Francisco'", }, }, "required": ["city"], - "additionalProperties": False + "additionalProperties": False, }, }, - "strict": True + "strict": True, }, { "type": "function", @@ -46,35 +46,34 @@ "type": "object", "properties": { "city": { - "type": - "string", - "description": - "The city to get the forecast for, e.g. 'New York'", + "type": "string", + "description": "The city to get the forecast for, e.g. " + "'New York'", }, "days": { - "type": - "integer", - "description": - "Number of days to get the forecast for (1-7)", + "type": "integer", + "description": "Number of days to get the forecast for (1-7)", }, }, "required": ["city", "days"], - "additionalProperties": False + "additionalProperties": False, }, }, - "strict": True + "strict": True, }, ] -def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, - should_match: bool): +def _compile_and_check( + tools: list[ChatCompletionToolsParam], sample_output, should_match: bool +): self = MagicMock(tool_choice="required", tools=tools) schema = ChatCompletionRequest._get_json_schema_from_tool(self) assert isinstance(schema, dict) # use build_regex_from_schema used in JSONLogitsProcessor to create Guide from outlines_core.json_schema import build_regex_from_schema + regex = build_regex_from_schema(json.dumps(schema)) compiled = re.compile(regex) matches = compiled.fullmatch(json.dumps(sample_output)) is not None @@ -83,65 +82,31 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, VALID_TOOL_OUTPUTS = [ - ([{ - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }], True), - ([{ - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }, { - "name": "get_current_weather", - "parameters": { - "city": "Berlin" - } - }], True), - ([{ - "name": "get_forecast", - "parameters": { - "city": "Vienna", - "days": 7 - } - }], True), - ([{ - "name": "get_forecast", - "parameters": { - "city": "Vienna", - "days": 7 - } - }, { - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }], True), - ([{ - "name": "get_forecast", - "parameters": { - "city": "Vienna", - "days": 7 - } - }, { - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }, { - "name": "get_forecast", - "parameters": { - "city": "Berlin", - "days": 7 - } - }, { - "name": "get_current_weather", - "parameters": { - "city": "Berlin" - } - }], True), + ([{"name": "get_current_weather", "parameters": {"city": "Vienna"}}], True), + ( + [ + {"name": "get_current_weather", "parameters": {"city": "Vienna"}}, + {"name": "get_current_weather", "parameters": {"city": "Berlin"}}, + ], + True, + ), + ([{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}], True), + ( + [ + {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}, + {"name": "get_current_weather", "parameters": {"city": "Vienna"}}, + ], + True, + ), + ( + [ + {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}, + {"name": "get_current_weather", "parameters": {"city": "Vienna"}}, + {"name": "get_forecast", 
"parameters": {"city": "Berlin", "days": 7}}, + {"name": "get_current_weather", "parameters": {"city": "Berlin"}}, + ], + True, + ), ] VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS] @@ -149,92 +114,100 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, @pytest.mark.parametrize( "sample_output, should_match", - VALID_TOOL_OUTPUTS + [ + VALID_TOOL_OUTPUTS + + [ (None, False), ([], False), # empty list cannot be generated ({}, False), # empty object cannot be generated ([{}], False), # list with empty object cannot be generated ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather" - }], - False), + [ + { # function without required parameters cannot be generated + "name": "get_current_weather" + } + ], + False, + ), ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather", - "parameters": {} - }], - False), + [ + { # function without required parameters cannot be generated + "name": "get_current_weather", + "parameters": {}, + } + ], + False, + ), ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather", - "parameters": None - }], - False), + [ + { # function without required parameters cannot be generated + "name": "get_current_weather", + "parameters": None, + } + ], + False, + ), ( { # tool call without lists cannot be generated "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } + "parameters": {"city": "Vienna"}, }, - False), + False, + ), ( - [{ # tool call with extra parameters cannot be generated - "name": "get_current_weather", - "parameters": { - "city": "Vienna", - "extra": "value" + [ + { # tool call with extra parameters cannot be generated + "name": "get_current_weather", + "parameters": {"city": "Vienna", "extra": "value"}, } - }], - False), + ], + False, + ), ( - [{ # tool call where parameters are first cannot be generated - "parameters": { - "city": "Vienna" - }, - "name": "get_current_weather" - }], - False), + [ + { # tool call where parameters are first cannot be generated + "parameters": {"city": "Vienna"}, + "name": "get_current_weather", + } + ], + False, + ), ( - [{ # tool call without all required parameters cannot be generated - "name": "get_forecast", - "parameters": { - "city": "Vienna" + [ + { # tool call without all required parameters cannot be generated + "name": "get_forecast", + "parameters": {"city": "Vienna"}, } - }], - False), + ], + False, + ), ( # tool call with incorrect name/parameters cannot be generated - [{ - "name": "get_weather", - "parameters": { - "city": "Vienna", - "days": 7 - } - }], False), + [{"name": "get_weather", "parameters": {"city": "Vienna", "days": 7}}], + False, + ), ( # tool call with both valid and empty function cannot be generated - [{ - "name": "get_current_weather", - "parameters": { - "city": "Vienna" - } - }, {}], False), - ]) + [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}, {}], + False, + ), + ], +) def test_structured_outputs_json(sample_output, should_match): - _compile_and_check(tools=TypeAdapter( - list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS), - sample_output=sample_output, - should_match=should_match) + _compile_and_check( + tools=TypeAdapter(list[ChatCompletionToolsParam]).validate_python( + EXAMPLE_TOOLS + ), + sample_output=sample_output, + should_match=should_match, + ) -def update_parameters_none( - tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam: +def update_parameters_none(tool: 
ChatCompletionToolsParam) -> ChatCompletionToolsParam: tool.function.parameters = None return tool def update_parameters_empty_dict( - tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam: + tool: ChatCompletionToolsParam, +) -> ChatCompletionToolsParam: tool.function.parameters = {} return tool @@ -247,48 +220,60 @@ def update_parameters_empty_dict( ({}, False), # empty object cannot be generated ([{}], False), # list with empty object cannot be generated ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather" - }], - False), + [ + { # function without required parameters cannot be generated + "name": "get_current_weather" + } + ], + False, + ), ( - [{ # function without required parameters cannot be generated - "name": "get_current_weather", - "parameters": None - }], - False), + [ + { # function without required parameters cannot be generated + "name": "get_current_weather", + "parameters": None, + } + ], + False, + ), ( - [{ # function with extra parameters cannot be generated - "name": "get_current_weather", - "parameters": { - "extra": "value" + [ + { # function with extra parameters cannot be generated + "name": "get_current_weather", + "parameters": {"extra": "value"}, } - }], - False), + ], + False, + ), ( - [{ # only function with empty parameters object is valid - "name": "get_current_weather", - "parameters": {} - }], - True), - ]) + [ + { # only function with empty parameters object is valid + "name": "get_current_weather", + "parameters": {}, + } + ], + True, + ), + ], +) @pytest.mark.parametrize( - "update_parameters", - [update_parameters_none, update_parameters_empty_dict]) -def test_structured_outputs_json_without_parameters(sample_output, - should_match, - update_parameters): + "update_parameters", [update_parameters_none, update_parameters_empty_dict] +) +def test_structured_outputs_json_without_parameters( + sample_output, should_match, update_parameters +): updated_tools = [deepcopy(EXAMPLE_TOOLS[0])] - tools = TypeAdapter( - list[ChatCompletionToolsParam]).validate_python(updated_tools) + tools = TypeAdapter(list[ChatCompletionToolsParam]).validate_python(updated_tools) tools = list(map(update_parameters, tools)) - assert all([ - tool.function.parameters is None or tool.function.parameters == {} - for tool in tools - ]) - _compile_and_check(tools=tools, - sample_output=sample_output, - should_match=should_match) + assert all( + [ + tool.function.parameters is None or tool.function.parameters == {} + for tool in tools + ] + ) + _compile_and_check( + tools=tools, sample_output=sample_output, should_match=should_match + ) @pytest.mark.parametrize("output", VALID_TOOLS) @@ -306,7 +291,7 @@ def test_streaming_output_valid(output, empty_params, delta_len): function_name_returned = False messages = [] for i in range(0, len(output_json), delta_len): - delta_text = output_json[i:i + delta_len] + delta_text = output_json[i : i + delta_len] current_text = previous_text + delta_text delta_message, function_name_returned = ( @@ -315,7 +300,9 @@ def test_streaming_output_valid(output, empty_params, delta_len): previous_text=previous_text, current_text=current_text, delta_text=delta_text, - function_name_returned=function_name_returned)) + function_name_returned=function_name_returned, + ) + ) if delta_message: messages.append(delta_message) @@ -329,10 +316,12 @@ def test_streaming_output_valid(output, empty_params, delta_len): if len(combined_messages) > 1: combined_messages += "}," - combined_messages += '{"name": "' + \ - 
message.tool_calls[0].function.name + \ - '", "parameters": ' + \ - message.tool_calls[0].function.arguments + combined_messages += ( + '{"name": "' + + message.tool_calls[0].function.name + + '", "parameters": ' + + message.tool_calls[0].function.arguments + ) else: combined_messages += message.tool_calls[0].function.arguments combined_messages += "}]" diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py index 94e2a37cbf63..bdac878db4e7 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_use/test_xlam_tool_parser.py @@ -7,9 +7,12 @@ import pytest -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, FunctionCall, - ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers import xLAMToolParser from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer @@ -30,12 +33,14 @@ def xlam_tool_parser(xlam_tokenizer): return xLAMToolParser(xlam_tokenizer) -def assert_tool_calls(actual_tool_calls: list[ToolCall], - expected_tool_calls: list[ToolCall]): +def assert_tool_calls( + actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall] +): assert len(actual_tool_calls) == len(expected_tool_calls) - for actual_tool_call, expected_tool_call in zip(actual_tool_calls, - expected_tool_calls): + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): assert isinstance(actual_tool_call.id, str) assert len(actual_tool_call.id) > 16 @@ -49,8 +54,7 @@ def stream_delta_message_generator( model_output: str, request: Optional[ChatCompletionRequest] = None, ) -> Generator[DeltaMessage, None, None]: - all_token_ids = xlam_tokenizer.encode(model_output, - add_special_tokens=False) + all_token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False) previous_text = "" previous_tokens = None @@ -59,18 +63,19 @@ def stream_delta_message_generator( for i, delta_token in enumerate(all_token_ids): delta_token_ids = [delta_token] previous_token_ids = all_token_ids[:i] - current_token_ids = all_token_ids[:i + 1] - - (new_tokens, delta_text, new_prefix_offset, - new_read_offset) = (detokenize_incrementally( - tokenizer=xlam_tokenizer, - all_input_ids=current_token_ids, - prev_tokens=previous_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=False, - spaces_between_special_tokens=True, - )) + current_token_ids = all_token_ids[: i + 1] + + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=xlam_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + ) current_text = previous_text + delta_text @@ -87,8 +92,9 @@ def stream_delta_message_generator( yield delta_message previous_text = current_text - previous_tokens = (previous_tokens + - new_tokens if previous_tokens else new_tokens) + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) prefix_offset = new_prefix_offset read_offset = new_read_offset @@ -96,7 +102,8 @@ def stream_delta_message_generator( def test_extract_tool_calls_no_tools(xlam_tool_parser): model_output = "This is a test" extracted_tool_calls = xlam_tool_parser.extract_tool_calls( 
- model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert not extracted_tool_calls.tools_called assert extracted_tool_calls.tool_calls == [] assert extracted_tool_calls.content == model_output @@ -115,87 +122,113 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser): ( """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )), - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit", - }), - )), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + } + ), + ) + ), ], None, ), ( """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "<think>I'll help you with that.</think>", ), ( """I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "I'll help you with that.", ), ( """I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "I'll check the weather for you.", ), ( """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "I'll help you check the weather.", ), ], ) -def test_extract_tool_calls(xlam_tool_parser, model_output, - expected_tool_calls, expected_content): 
+def test_extract_tool_calls( + xlam_tool_parser, model_output, expected_tool_calls, expected_content +): extracted_tool_calls = xlam_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -210,25 +243,30 @@ def test_extract_tool_calls(xlam_tool_parser, model_output, ( """[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Seattle", - "state": "WA", - "unit": "celsius", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Seattle", + "state": "WA", + "unit": "celsius", + } + ), + ) + ) ], None, ), ], ) -def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output, - expected_tool_calls, - expected_content): +def test_extract_tool_calls_list_structure( + xlam_tool_parser, model_output, expected_tool_calls, expected_content +): """Test extraction of tool calls when the model outputs a list-structured tool call.""" # noqa: E501 extracted_tool_calls = xlam_tool_parser.extract_tool_calls( - model_output, request=None) # type: ignore[arg-type] + model_output, request=None + ) # type: ignore[arg-type] assert extracted_tool_calls.tools_called assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) @@ -239,20 +277,25 @@ def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output, # Test for preprocess_model_output method def test_preprocess_model_output(xlam_tool_parser): # Test with list structure - model_output = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 + model_output = ( + """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 + ) content, potential_tool_calls = xlam_tool_parser.preprocess_model_output( - model_output) + model_output + ) assert content is None assert potential_tool_calls == model_output # Test with thinking tag model_output = """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 content, potential_tool_calls = xlam_tool_parser.preprocess_model_output( - model_output) + model_output + ) assert content == "<think>I'll help you with that.</think>" assert ( - potential_tool_calls == - '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]') + potential_tool_calls + == '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]' + ) # Test with JSON code block model_output = """I'll help you with that. @@ -260,14 +303,16 @@ def test_preprocess_model_output(xlam_tool_parser): [{"name": "get_current_weather", "arguments": {"city": "Seattle"}}] ```""" content, potential_tool_calls = xlam_tool_parser.preprocess_model_output( - model_output) + model_output + ) assert content == "I'll help you with that." 
assert "get_current_weather" in potential_tool_calls # Test with no tool calls model_output = """I'll help you with that.""" content, potential_tool_calls = xlam_tool_parser.preprocess_model_output( - model_output) + model_output + ) assert content == model_output assert potential_tool_calls is None @@ -281,7 +326,9 @@ def test_streaming_with_list_structure(xlam_tool_parser): xlam_tool_parser.current_tool_id = -1 # Simulate receiving a message with list structure - current_text = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 + current_text = ( + """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501 + ) # First call to set up the tool xlam_tool_parser.extract_tool_calls_streaming( @@ -295,8 +342,7 @@ def test_streaming_with_list_structure(xlam_tool_parser): ) # Make sure the tool is set up correctly - assert (xlam_tool_parser.current_tool_id - >= 0), "Tool index should be initialized" + assert xlam_tool_parser.current_tool_id >= 0, "Tool index should be initialized" # Manually set up the state for sending the tool name xlam_tool_parser.current_tools_sent = [False] @@ -332,78 +378,102 @@ def test_streaming_with_list_structure(xlam_tool_parser): ( """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )), - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Orlando", - "state": "FL", - "unit": "fahrenheit", - }), - )), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + } + ), + ) + ), ], "", ), ( """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "<think>I'll help you with that.</think>", ), ( """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "", ), ( """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", 
+ arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "", ), ( """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501 [ - ToolCall(function=FunctionCall( - name="get_current_weather", - arguments=json.dumps({ - "city": "Dallas", - "state": "TX", - "unit": "fahrenheit", - }), - )) + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + } + ), + ) + ) ], "I can help with that.", ), @@ -421,7 +491,8 @@ def test_extract_tool_calls_streaming_incremental( chunks = [] for delta_message in stream_delta_message_generator( - xlam_tool_parser, xlam_tokenizer, model_output, request): + xlam_tool_parser, xlam_tokenizer, model_output, request + ): chunks.append(delta_message) # Should have multiple chunks @@ -433,8 +504,9 @@ def test_extract_tool_calls_streaming_incremental( for chunk in chunks: if chunk.tool_calls and chunk.tool_calls[0].id: header_found = True - assert (chunk.tool_calls[0].function.name == - expected_first_tool.function.name) + assert ( + chunk.tool_calls[0].function.name == expected_first_tool.function.name + ) assert chunk.tool_calls[0].type == "function" # Arguments may be empty initially or None if chunk.tool_calls[0].function.arguments is not None: @@ -446,11 +518,13 @@ def test_extract_tool_calls_streaming_incremental( # Should have chunks with incremental arguments arg_chunks = [] for chunk in chunks: - if (chunk.tool_calls and chunk.tool_calls[0].function.arguments - and chunk.tool_calls[0].function.arguments != "" - and chunk.tool_calls[0].index == - 0 # Only collect arguments from the first tool call - ): + if ( + chunk.tool_calls + and chunk.tool_calls[0].function.arguments + and chunk.tool_calls[0].function.arguments != "" + and chunk.tool_calls[0].index + == 0 # Only collect arguments from the first tool call + ): arg_chunks.append(chunk.tool_calls[0].function.arguments) # Arguments should be streamed incrementally diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index a17fab9aecbc..835d07608e40 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -4,8 +4,7 @@ from copy import deepcopy from typing import Any, Optional -from openai.types.chat import (ChatCompletionMessageParam, - ChatCompletionToolParam) +from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam from typing_extensions import TypedDict from tests.utils import VLLM_PATH @@ -20,8 +19,9 @@ class ServerConfig(TypedDict, total=False): extended: Optional[bool] # tests do not run in CI automatically -def patch_system_prompt(messages: list[dict[str, Any]], - system_prompt: str) -> list[dict[str, Any]]: +def patch_system_prompt( + messages: list[dict[str, Any]], system_prompt: str +) -> list[dict[str, Any]]: new_messages = deepcopy(messages) if new_messages[0]["role"] == "system": new_messages[0]["content"] = system_prompt @@ -30,8 +30,9 @@ def patch_system_prompt(messages: list[dict[str, Any]], return new_messages -def ensure_system_prompt(messages: list[dict[str, Any]], - config: ServerConfig) -> list[dict[str, Any]]: +def ensure_system_prompt( + messages: list[dict[str, Any]], config: ServerConfig +) -> list[dict[str, Any]]: prompt = config.get("system_prompt") if prompt: return patch_system_prompt(messages, prompt) @@ -42,92 +43,102 @@ def ensure_system_prompt(messages: list[dict[str, Any]], # 
universal args for all models go here. also good if you need to test locally # and change type or KV cache quantization or something. ARGS: list[str] = [ - "--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs", - "256" + "--enable-auto-tool-choice", + "--max-model-len", + "1024", + "--max-num-seqs", + "256", ] CONFIGS: dict[str, ServerConfig] = { "hermes": { - "model": - "NousResearch/Hermes-3-Llama-3.1-8B", + "model": "NousResearch/Hermes-3-Llama-3.1-8B", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "hermes", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "hermes", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja"), ], - "system_prompt": - "You are a helpful assistant with access to tools. If a tool" + "system_prompt": "You are a helpful assistant with access to tools. If a tool" " that you have would be helpful to answer a user query, " "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." + "to the user's question - just respond to it normally.", }, "llama": { - "model": - "meta-llama/Meta-Llama-3.1-8B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "llama3_json", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_llama3.1_json.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "llama3_json", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama3.1_json.jinja"), ], - "supports_parallel": - False, + "supports_parallel": False, }, "llama3.2": { - "model": - "meta-llama/Llama-3.2-3B-Instruct", + "model": "meta-llama/Llama-3.2-3B-Instruct", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "llama3_json", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_llama3.2_json.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "llama3_json", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama3.2_json.jinja"), ], - "supports_parallel": - False, + "supports_parallel": False, }, "llama4": { - "model": - "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "model": "meta-llama/Llama-4-Scout-17B-16E-Instruct", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "llama4_pythonic", "--chat-template", - str(VLLM_PATH / - "examples/tool_chat_template_llama4_pythonic.jinja"), "-tp", - "4" + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "llama4_pythonic", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama4_pythonic.jinja"), + "-tp", + "4", ], - "supports_parallel": - False, - "extended": - True + "supports_parallel": False, + "extended": True, }, "llama4_json": { - "model": - "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "model": "meta-llama/Llama-4-Scout-17B-16E-Instruct", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", "-tp", "4", - "--distributed-executor-backend", "mp", "--tool-call-parser", - "llama4_json", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "-tp", + "4", + 
"--distributed-executor-backend", + "mp", + "--tool-call-parser", + "llama4_json", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja"), ], - "supports_parallel": - True, - "extended": - True + "supports_parallel": True, + "extended": True, }, "mistral": { - "model": - "mistralai/Mistral-7B-Instruct-v0.3", + "model": "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "mistral", "--chat-template", + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "mistral", + "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"), - "--ignore-patterns=\"consolidated.safetensors\"" + '--ignore-patterns="consolidated.safetensors"', ], - "system_prompt": - "You are a helpful assistant with access to tools. If a tool" + "system_prompt": "You are a helpful assistant with access to tools. If a tool" " that you have would be helpful to answer a user query, " "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " - "to the user's question - just respond to it normally." + "to the user's question - just respond to it normally.", }, # V1 Test: Passing locally but failing in CI. This runs the # V0 Engine because of CPU offloading. Need to debug why. @@ -146,49 +157,50 @@ def ensure_system_prompt(messages: list[dict[str, Any]], # False, # }, "granite-3.0-8b": { - "model": - "ibm-granite/granite-3.0-8b-instruct", + "model": "ibm-granite/granite-3.0-8b-instruct", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "granite", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_granite.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "granite", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_granite.jinja"), ], }, "granite-3.1-8b": { - "model": - "ibm-granite/granite-3.1-8b-instruct", + "model": "ibm-granite/granite-3.1-8b-instruct", "arguments": [ "--enforce-eager", "--no-enable-prefix-caching", "--tool-call-parser", "granite", ], - "supports_parallel": - True, + "supports_parallel": True, }, "internlm": { - "model": - "internlm/internlm2_5-7b-chat", + "model": "internlm/internlm2_5-7b-chat", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "internlm", "--chat-template", - str(VLLM_PATH / - "examples/tool_chat_template_internlm2_tool.jinja"), - "--trust_remote_code" + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "internlm", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_internlm2_tool.jinja"), + "--trust_remote_code", ], - "supports_parallel": - False, + "supports_parallel": False, }, "toolACE": { - "model": - "Team-ACE/ToolACE-8B", + "model": "Team-ACE/ToolACE-8B", "arguments": [ - "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "pythonic", "--chat-template", - str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja") + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "pythonic", + "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja"), ], - "supports_parallel": - True, + "supports_parallel": True, }, } @@ -201,37 +213,31 @@ def ensure_system_prompt(messages: list[dict[str, Any]], "type": "object", "properties": { "city": { - "type": - "string", - "description": - "The city to find the weather for, " - 
"e.g. 'San Francisco'" + "type": "string", + "description": "The city to find the weather for, " + "e.g. 'San Francisco'", }, "state": { - "type": - "string", - "description": - "must the two-letter abbreviation for the state " + "type": "string", + "description": "must the two-letter abbreviation for the state " "that the city is in, e.g. 'CA' which would " - "mean 'California'" + "mean 'California'", }, "unit": { "type": "string", "description": "The unit to fetch the temperature in", - "enum": ["celsius", "fahrenheit"] - } - } - } - } + "enum": ["celsius", "fahrenheit"], + }, + }, + }, + }, } SEARCH_TOOL: ChatCompletionToolParam = { "type": "function", "function": { - "name": - "web_search", - "description": - "Search the internet and get a summary of the top " + "name": "web_search", + "description": "Search the internet and get a summary of the top " "10 webpages. Should only be used if you don't know " "the answer to a user query, and the results are likely" "to be able to be found with a web search", @@ -239,124 +245,98 @@ def ensure_system_prompt(messages: list[dict[str, Any]], "type": "object", "properties": { "search_term": { - "type": - "string", - "description": - "The term to use in the search. This should" + "type": "string", + "description": "The term to use in the search. This should" "ideally be keywords to search for, not a" - "natural-language question" + "natural-language question", } }, - "required": ["search_term"] - } - } + "required": ["search_term"], + }, + }, } -MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "Hi! How are you?" -}, { - "role": - "assistant", - "content": - "I'm doing great! How can I assist you?" -}, { - "role": - "user", - "content": - "Can you tell me a joke please?" -}] +MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [ + {"role": "user", "content": "Hi! How are you?"}, + {"role": "assistant", "content": "I'm doing great! How can I assist you?"}, + {"role": "user", "content": "Can you tell me a joke please?"}, +] -MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas in Fahrenheit?" -}] +MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [ + {"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"} +] -MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas in Fahrenheit?" -}, { - "role": - "assistant", - "tool_calls": [{ - "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Dallas", "state": "TX", ' - '"unit": "fahrenheit"}' - } - }] -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "content": - "The weather in Dallas is 98 degrees fahrenheit, with partly" - "cloudy skies and a low chance of rain." 
-}] +MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [ + {"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": WEATHER_TOOL["function"]["name"], + "arguments": '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": "The weather in Dallas is 98 degrees fahrenheit, with partly" + "cloudy skies and a low chance of rain.", + }, +] -MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas and Orlando, Florida in " - "Fahrenheit?" -}] +MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [ + { + "role": "user", + "content": "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?", + } +] -MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{ - "role": - "user", - "content": - "What is the weather in Dallas, Texas and Orlando, Florida in " - "Fahrenheit?" -}, { - "role": - "assistant", - "tool_calls": [{ - "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Dallas", "state": "TX", ' - '"unit": "fahrenheit"}' - } - }, { - "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", - "type": "function", - "function": { - "name": - WEATHER_TOOL["function"]["name"], - "arguments": - '{"city": "Orlando", "state": "Fl", ' - '"unit": "fahrenheit"}' - } - }] -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-03e6481b146e408e9523d9c956696295", - "content": - "The weather in Dallas TX is 98 degrees fahrenheit with mostly " - "cloudy skies and a chance of rain in the evening." -}, { - "role": - "tool", - "tool_call_id": - "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", - "content": - "The weather in Orlando FL is 78 degrees fahrenheit with clear" - "skies." 
-}] +MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [ + { + "role": "user", + "content": "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": WEATHER_TOOL["function"]["name"], + "arguments": '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}', + }, + }, + { + "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "type": "function", + "function": { + "name": WEATHER_TOOL["function"]["name"], + "arguments": '{"city": "Orlando", "state": "Fl", ' + '"unit": "fahrenheit"}', + }, + }, + ], + }, + { + "role": "tool", + "tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": "The weather in Dallas TX is 98 degrees fahrenheit with mostly " + "cloudy skies and a chance of rain in the evening.", + }, + { + "role": "tool", + "tool_call_id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "content": "The weather in Orlando FL is 78 degrees fahrenheit with clear" + "skies.", + }, +] diff --git a/tests/tools/test_config_validator.py b/tests/tools/test_config_validator.py index b0475894a114..22d838d27264 100644 --- a/tests/tools/test_config_validator.py +++ b/tests/tools/test_config_validator.py @@ -7,11 +7,11 @@ from tools.validate_config import validate_ast -_TestConfig1 = ''' +_TestConfig1 = """ @config class _TestConfig1: pass -''' +""" _TestConfig2 = ''' @config @@ -21,12 +21,12 @@ class _TestConfig2: """docstring""" ''' -_TestConfig3 = ''' +_TestConfig3 = """ @config @dataclass class _TestConfig3: a: int = 1 -''' +""" _TestConfig4 = ''' @config @@ -37,12 +37,15 @@ class _TestConfig4: ''' -@pytest.mark.parametrize(("test_config", "expected_error"), [ - (_TestConfig1, "must be a dataclass"), - (_TestConfig2, "must have a default"), - (_TestConfig3, "must have a docstring"), - (_TestConfig4, "must use a single Literal"), -]) +@pytest.mark.parametrize( + ("test_config", "expected_error"), + [ + (_TestConfig1, "must be a dataclass"), + (_TestConfig2, "must have a default"), + (_TestConfig3, "must have a docstring"), + (_TestConfig4, "must use a single Literal"), + ], +) def test_config(test_config, expected_error): tree = ast.parse(test_config) with pytest.raises(Exception, match=expected_error): diff --git a/tests/tpu/lora/test_lora.py b/tests/tpu/lora/test_lora.py index 5196a92cb727..5999c9cf1e0e 100644 --- a/tests/tpu/lora/test_lora.py +++ b/tests/tpu/lora/test_lora.py @@ -29,17 +29,20 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch): def setup_vllm(num_loras: int, tp: int) -> vllm.LLM: - return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", - max_model_len=256, - max_num_seqs=8, - tensor_parallel_size=tp, - enable_lora=True, - max_loras=num_loras, - max_lora_rank=8) + return vllm.LLM( + model="Qwen/Qwen2.5-3B-Instruct", + max_model_len=256, + max_num_seqs=8, + tensor_parallel_size=tp, + enable_lora=True, + max_loras=num_loras, + max_lora_rank=8, + ) -TPU_TENSOR_PARALLEL_SIZES = [1, tpu.num_available_chips() - ] if tpu.num_available_chips() > 1 else [1] +TPU_TENSOR_PARALLEL_SIZES = ( + [1, tpu.num_available_chips()] if tpu.num_available_chips() > 1 else [1] +) @pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES) @@ -55,12 +58,19 @@ def test_single_lora(tp: int): prompt = "What is 1+1? 
\n" lora_request = LoRARequest( - "lora_adapter_1", 1, - "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter") - output = llm.generate(prompt, - sampling_params=vllm.SamplingParams(max_tokens=256, - temperature=0), - lora_request=lora_request)[0].outputs[0].text + "lora_adapter_1", + 1, + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter", + ) + output = ( + llm.generate( + prompt, + sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), + lora_request=lora_request, + )[0] + .outputs[0] + .text + ) answer = output.strip()[0] @@ -73,13 +83,12 @@ def test_lora_hotswapping(tp: int): """ This test ensures we can run multiple LoRA adapters on the TPU backend, even if we only have space to store 1. - + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. """ - lora_name_template = \ - "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" lora_requests = [ LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) for i in range(1, 5) @@ -90,10 +99,15 @@ def test_lora_hotswapping(tp: int): prompt = "What is 1+1? \n" for i, req in enumerate(lora_requests): - output = llm.generate(prompt, - sampling_params=vllm.SamplingParams( - max_tokens=256, temperature=0), - lora_request=req)[0].outputs[0].text + output = ( + llm.generate( + prompt, + sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), + lora_request=req, + )[0] + .outputs[0] + .text + ) answer = output.strip()[0] assert answer.isdigit() @@ -105,12 +119,11 @@ def test_multi_lora(tp: int): """ This test ensures we can run multiple LoRA adapters on the TPU backend, when we have enough space to store all of them. - + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. """ - lora_name_template = \ - "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" lora_requests = [ LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) for i in range(1, 5) @@ -121,10 +134,15 @@ def test_multi_lora(tp: int): prompt = "What is 1+1? \n" for i, req in enumerate(lora_requests): - output = llm.generate(prompt, - sampling_params=vllm.SamplingParams( - max_tokens=256, temperature=0), - lora_request=req)[0].outputs[0].text + output = ( + llm.generate( + prompt, + sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), + lora_request=req, + )[0] + .outputs[0] + .text + ) answer = output.strip()[0] diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 448b8b2bc094..5acfa484f0c1 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -26,16 +26,15 @@ def test_tpu_compilation(): # Currently, top-p sampling is disabled. `top_p` should be 1.0. 
N = 1 - sampling_params = SamplingParams(temperature=0.7, - top_p=1.0, - n=N, - max_tokens=16) + sampling_params = SamplingParams(temperature=0.7, top_p=1.0, n=N, max_tokens=16) - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - max_num_batched_tokens=256, - max_model_len=256, - max_num_seqs=32, - enforce_eager=False) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + max_num_batched_tokens=256, + max_model_len=256, + max_num_seqs=32, + enforce_eager=False, + ) outputs = llm.generate(prompts, sampling_params) for output, answer in zip(outputs, answers): @@ -45,7 +44,8 @@ def test_tpu_compilation(): assert generated_text.startswith(answer) compiled_codes = sorted( - glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py"))) + glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py")) + ) for i, compiled_code in enumerate(compiled_codes): print("{} file: {}".format(i + 1, compiled_code)) @@ -66,9 +66,10 @@ def extract_compiled_index(s): # Check all the compilations are as expected. The dump files include the # captured graph for the forward function of the nn.Module. - compiled_fns = sorted(glob.glob( - os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")), - key=lambda s: extract_compiled_index(s)) + compiled_fns = sorted( + glob.glob(os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")), + key=lambda s: extract_compiled_index(s), + ) for i, compiled_fn in enumerate(compiled_fns): print("{} file: {}".format(i + 1, compiled_fn)) @@ -82,4 +83,4 @@ def extract_compiled_index(s): # ragged_paged_attention with open(compiled_fns[1]) as f: content = f.read() - assert (kv_cache_prefix in content and attn_prefix in content) + assert kv_cache_prefix in content and attn_prefix in content diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 9c90df1b7701..102e5ddf16d6 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -15,17 +15,20 @@ def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_RPC_TIMEOUT", "30000") - compare_two_settings("Qwen/Qwen2.5-1.5B-Instruct", - arg1=[ - "--max-model-len=256", - "--max-num-seqs=32", - "--enforce-eager", - f"-O{CompilationLevel.DYNAMO_ONCE}", - ], - arg2=[ - "--max-model-len=256", "--max-num-seqs=32", - "--enforce-eager", - f"-O{CompilationLevel.DYNAMO_AS_IS}" - ], - env1={}, - env2={}) + compare_two_settings( + "Qwen/Qwen2.5-1.5B-Instruct", + arg1=[ + "--max-model-len=256", + "--max-num-seqs=32", + "--enforce-eager", + f"-O{CompilationLevel.DYNAMO_ONCE}", + ], + arg2=[ + "--max-model-len=256", + "--max-num-seqs=32", + "--enforce-eager", + f"-O{CompilationLevel.DYNAMO_AS_IS}", + ], + env1={}, + env2={}, + ) diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py index 1e5d9d923d00..e3236d20bf67 100644 --- a/tests/tpu/test_moe_pallas.py +++ b/tests/tpu/test_moe_pallas.py @@ -4,17 +4,15 @@ Run `pytest tests/kernels/moe/test_moe_pallas.py`. 
""" + import pytest import torch import torch_xla -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.layers.fused_moe.moe_pallas import ( - fused_moe as pallas_moe) +from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe as pallas_moe from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as torch_moe) -# yapf: enable + fused_moe as torch_moe, +) from vllm.platforms import current_platform if not current_platform.is_tpu(): @@ -43,6 +41,7 @@ def test_pallas_moe( dtype: torch.dtype, ): import torch_xla.core.xla_model as xm + with torch.device(xm.xla_device()): a = torch.randn((m, k), dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10 diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index 8d9fbd280317..151be5f17fe8 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -17,15 +17,15 @@ class GSM8KAccuracyTestConfig: expected_value: float def get_model_args(self) -> str: - return (f"pretrained={self.model_name}," - "max_model_len=4096,max_num_seqs=32") + return f"pretrained={self.model_name},max_model_len=4096,max_num_seqs=32" # NOTE: Accuracy scores measured on GPUs. ACCURACY_CONFIGS = [ GSM8KAccuracyTestConfig( model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - expected_value=0.76), # no bias + expected_value=0.76, + ), # no bias # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU, # so only one of these tests can run in a single call to pytest. As # a follow-up, move this into the LM-EVAL section of the CI. @@ -37,7 +37,6 @@ def get_model_args(self) -> str: @pytest.mark.parametrize("config", ACCURACY_CONFIGS) def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): - results = lm_eval.simple_evaluate( model="vllm", model_args=config.get_model_args(), @@ -47,6 +46,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): EXPECTED_VALUE = config.expected_value measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/transformers_utils/test_config_parser_registry.py b/tests/transformers_utils/test_config_parser_registry.py index 13c654e05d2a..9372cb9d46d3 100644 --- a/tests/transformers_utils/test_config_parser_registry.py +++ b/tests/transformers_utils/test_config_parser_registry.py @@ -7,26 +7,25 @@ import pytest from transformers import PretrainedConfig -from vllm.transformers_utils.config import (get_config_parser, - register_config_parser) +from vllm.transformers_utils.config import get_config_parser, register_config_parser from vllm.transformers_utils.config_parser_base import ConfigParserBase @register_config_parser("custom_config_parser") class CustomConfigParser(ConfigParserBase): - - def parse(self, - model: Union[str, Path], - trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - **kwargs) -> tuple[dict, PretrainedConfig]: + def parse( + self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: raise NotImplementedError def test_register_config_parser(): - 
assert isinstance(get_config_parser("custom_config_parser"), - CustomConfigParser) + assert isinstance(get_config_parser("custom_config_parser"), CustomConfigParser) def test_invalid_config_parser(): diff --git a/tests/utils.py b/tests/utils.py index ffdc0f732543..b853542c241f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,20 +33,29 @@ import vllm.envs as envs from tests.models.utils import TextTextLogprobs -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) +from vllm.distributed import ( + ensure_model_parallel_initialized, + init_distributed_environment, +) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.model_executor.model_loader import get_model_loader from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer -from vllm.utils import (FlexibleArgumentParser, GB_bytes, - cuda_device_count_stateless, get_open_port) +from vllm.utils import ( + FlexibleArgumentParser, + GB_bytes, + cuda_device_count_stateless, + get_open_port, +) if current_platform.is_rocm(): - from amdsmi import (amdsmi_get_gpu_vram_usage, - amdsmi_get_processor_handles, amdsmi_init, - amdsmi_shut_down) + from amdsmi import ( + amdsmi_get_gpu_vram_usage, + amdsmi_get_processor_handles, + amdsmi_init, + amdsmi_shut_down, + ) @contextmanager def _nvml(): @@ -56,9 +65,12 @@ def _nvml(): finally: amdsmi_shut_down() elif current_platform.is_cuda(): - from vllm.third_party.pynvml import (nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, nvmlInit, - nvmlShutdown) + from vllm.third_party.pynvml import ( + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlInit, + nvmlShutdown, + ) @contextmanager def _nvml(): @@ -81,14 +93,14 @@ def _nvml(): class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key - def _start_server(self, model: str, vllm_serve_args: list[str], - env_dict: Optional[dict[str, str]]) -> None: - """Subclasses override this method to customize server process launch - """ + def _start_server( + self, model: str, vllm_serve_args: list[str], env_dict: Optional[dict[str, str]] + ) -> None: + """Subclasses override this method to customize server process launch""" env = os.environ.copy() # the current process might initialize cuda, # to be safe, we should use spawn method - env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn' + env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" if env_dict is not None: env.update(env_dict) serve_cmd = ["vllm", "serve", model, *vllm_serve_args] @@ -100,41 +112,42 @@ def _start_server(self, model: str, vllm_serve_args: list[str], stderr=sys.stderr, ) - def __init__(self, - model: str, - vllm_serve_args: list[str], - *, - env_dict: Optional[dict[str, str]] = None, - seed: Optional[int] = 0, - auto_port: bool = True, - max_wait_seconds: Optional[float] = None, - override_hf_configs: Optional[dict[str, Any]] = None) -> None: + def __init__( + self, + model: str, + vllm_serve_args: list[str], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None, + override_hf_configs: Optional[dict[str, Any]] = None, + ) -> None: if auto_port: if "-p" in vllm_serve_args or "--port" in vllm_serve_args: - raise ValueError("You have manually specified the port " - "when `auto_port=True`.") + raise ValueError( + "You have manually specified the port when `auto_port=True`." 
+ ) # No need for a port if using unix sockets if "--uds" not in vllm_serve_args: # Don't mutate the input args - vllm_serve_args = vllm_serve_args + [ - "--port", str(get_open_port()) - ] + vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())] if seed is not None: if "--seed" in vllm_serve_args: - raise ValueError("You have manually specified the seed " - f"when `seed={seed}`.") + raise ValueError( + f"You have manually specified the seed when `seed={seed}`." + ) vllm_serve_args = vllm_serve_args + ["--seed", str(seed)] if override_hf_configs is not None: vllm_serve_args = vllm_serve_args + [ "--hf-overrides", - json.dumps(override_hf_configs) + json.dumps(override_hf_configs), ] - parser = FlexibleArgumentParser( - description="vLLM's remote OpenAI server.") + parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.") subparsers = parser.add_subparsers(required=False, dest="subparser") parser = ServeSubcommand().subparser_init(subparsers) args = parser.parse_args(["--model", model, *vllm_serve_args]) @@ -143,11 +156,10 @@ def __init__(self, self.host = None self.port = None else: - self.host = str(args.host or 'localhost') + self.host = str(args.host or "localhost") self.port = int(args.port) - self.show_hidden_metrics = \ - args.show_hidden_metrics_for_version is not None + self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None # download the model before starting the server to avoid timeout is_local = os.path.isdir(model) @@ -161,8 +173,7 @@ def __init__(self, self._start_server(model, vllm_serve_args, env_dict) max_wait_seconds = max_wait_seconds or 240 - self._wait_for_server(url=self.url_for("health"), - timeout=max_wait_seconds) + self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds) def __enter__(self): return self @@ -182,8 +193,11 @@ def _poll(self) -> Optional[int]: def _wait_for_server(self, *, url: str, timeout: float): # run health check start = time.time() - client = (httpx.Client(transport=httpx.HTTPTransport( - uds=self.uds)) if self.uds else requests) + client = ( + httpx.Client(transport=httpx.HTTPTransport(uds=self.uds)) + if self.uds + else requests + ) while True: try: if client.get(url).status_code == 200: @@ -199,13 +213,15 @@ def _wait_for_server(self, *, url: str, timeout: float): time.sleep(0.5) if time.time() - start > timeout: - raise RuntimeError( - "Server failed to start in time.") from None + raise RuntimeError("Server failed to start in time.") from None @property def url_root(self) -> str: - return (f"http://{self.uds.split('/')[-1]}" - if self.uds else f"http://{self.host}:{self.port}") + return ( + f"http://{self.uds.split('/')[-1]}" + if self.uds + else f"http://{self.host}:{self.port}" + ) def url_for(self, *parts: str) -> str: return self.url_root + "/" + "/".join(parts) @@ -223,42 +239,47 @@ def get_client(self, **kwargs): def get_async_client(self, **kwargs): if "timeout" not in kwargs: kwargs["timeout"] = 600 - return openai.AsyncOpenAI(base_url=self.url_for("v1"), - api_key=self.DUMMY_API_KEY, - max_retries=0, - **kwargs) + return openai.AsyncOpenAI( + base_url=self.url_for("v1"), + api_key=self.DUMMY_API_KEY, + max_retries=0, + **kwargs, + ) class RemoteOpenAIServerCustom(RemoteOpenAIServer): """Launch test server with custom child process""" - def _start_server(self, model: str, vllm_serve_args: list[str], - env_dict: Optional[dict[str, str]]) -> None: + def _start_server( + self, model: str, vllm_serve_args: list[str], env_dict: Optional[dict[str, str]] + ) -> None: 
self.proc: Process = Process( - target=self.child_process_fxn, - args=(env_dict, model, - vllm_serve_args)) # type: ignore[assignment] + target=self.child_process_fxn, args=(env_dict, model, vllm_serve_args) + ) # type: ignore[assignment] self.proc.start() - def __init__(self, - model: str, - vllm_serve_args: list[str], - child_process_fxn: Callable[ - [Optional[dict[str, str]], str, list[str]], None], - *, - env_dict: Optional[dict[str, str]] = None, - seed: Optional[int] = 0, - auto_port: bool = True, - max_wait_seconds: Optional[float] = None) -> None: + def __init__( + self, + model: str, + vllm_serve_args: list[str], + child_process_fxn: Callable[[Optional[dict[str, str]], str, list[str]], None], + *, + env_dict: Optional[dict[str, str]] = None, + seed: Optional[int] = 0, + auto_port: bool = True, + max_wait_seconds: Optional[float] = None, + ) -> None: """Store custom child process function then invoke superclass constructor which will indirectly launch it.""" self.child_process_fxn = child_process_fxn - super().__init__(model=model, - vllm_serve_args=vllm_serve_args, - env_dict=env_dict, - seed=seed, - auto_port=auto_port, - max_wait_seconds=max_wait_seconds) + super().__init__( + model=model, + vllm_serve_args=vllm_serve_args, + env_dict=env_dict, + seed=seed, + auto_port=auto_port, + max_wait_seconds=max_wait_seconds, + ) def _poll(self) -> Optional[int]: return self.proc.exitcode @@ -280,17 +301,18 @@ def _test_completion( results = [] # test with text prompt - completion = client.completions.create(model=model, - prompt=prompt, - max_tokens=5, - temperature=0.0) - - results.append({ - "test": "single_completion", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, - }) + completion = client.completions.create( + model=model, prompt=prompt, max_tokens=5, temperature=0.0 + ) + + results.append( + { + "test": "single_completion", + "text": completion.choices[0].text, + "finish_reason": completion.choices[0].finish_reason, + "usage": completion.usage, + } + ) # test using token IDs completion = client.completions.create( @@ -300,43 +322,42 @@ def _test_completion( temperature=0.0, ) - results.append({ - "test": "token_ids", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, - }) + results.append( + { + "test": "token_ids", + "text": completion.choices[0].text, + "finish_reason": completion.choices[0].finish_reason, + "usage": completion.usage, + } + ) # test seeded random sampling - completion = client.completions.create(model=model, - prompt=prompt, - max_tokens=5, - seed=33, - temperature=1.0) - - results.append({ - "test": "seeded_sampling", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, - }) + completion = client.completions.create( + model=model, prompt=prompt, max_tokens=5, seed=33, temperature=1.0 + ) + + results.append( + { + "test": "seeded_sampling", + "text": completion.choices[0].text, + "finish_reason": completion.choices[0].finish_reason, + "usage": completion.usage, + } + ) # test seeded random sampling with multiple prompts - completion = client.completions.create(model=model, - prompt=[prompt, prompt], - max_tokens=5, - seed=33, - temperature=1.0) - - results.append({ - "test": - "seeded_sampling", - "text": [choice.text for choice in completion.choices], - "finish_reason": - [choice.finish_reason for choice in completion.choices], - "usage": - 
completion.usage, - }) + completion = client.completions.create( + model=model, prompt=[prompt, prompt], max_tokens=5, seed=33, temperature=1.0 + ) + + results.append( + { + "test": "seeded_sampling", + "text": [choice.text for choice in completion.choices], + "finish_reason": [choice.finish_reason for choice in completion.choices], + "usage": completion.usage, + } + ) # test simple list batch = client.completions.create( @@ -346,11 +367,13 @@ def _test_completion( temperature=0.0, ) - results.append({ - "test": "simple_list", - "text0": batch.choices[0].text, - "text1": batch.choices[1].text, - }) + results.append( + { + "test": "simple_list", + "text0": batch.choices[0].text, + "text1": batch.choices[1].text, + } + ) # test streaming batch = client.completions.create( @@ -367,10 +390,12 @@ def _test_completion( choice = chunk.choices[0] texts[choice.index] += choice.text - results.append({ - "test": "streaming", - "texts": texts, - }) + results.append( + { + "test": "streaming", + "texts": texts, + } + ) return results @@ -383,19 +408,19 @@ def _test_completion_close( results = [] # test with text prompt - completion = client.completions.create(model=model, - prompt=prompt, - max_tokens=1, - logprobs=5, - temperature=0.0) + completion = client.completions.create( + model=model, prompt=prompt, max_tokens=1, logprobs=5, temperature=0.0 + ) logprobs = completion.choices[0].logprobs.top_logprobs[0] logprobs = {k: round(v, 2) for k, v in logprobs.items()} - results.append({ - "test": "completion_close", - "logprobs": logprobs, - }) + results.append( + { + "test": "completion_close", + "logprobs": logprobs, + } + ) return results @@ -407,26 +432,21 @@ def _test_chat( ): results = [] - messages = [{ - "role": "user", - "content": [{ - "type": "text", - "text": prompt - }] - }] + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] # test with text prompt - chat_response = client.chat.completions.create(model=model, - messages=messages, - max_tokens=5, - temperature=0.0) - - results.append({ - "test": "completion_close", - "text": chat_response.choices[0].message.content, - "finish_reason": chat_response.choices[0].finish_reason, - "usage": chat_response.usage, - }) + chat_response = client.chat.completions.create( + model=model, messages=messages, max_tokens=5, temperature=0.0 + ) + + results.append( + { + "test": "completion_close", + "text": chat_response.choices[0].message.content, + "finish_reason": chat_response.choices[0].finish_reason, + "usage": chat_response.usage, + } + ) return results @@ -445,11 +465,13 @@ def _test_embeddings( encoding_format="float", ) - results.append({ - "test": "single_embedding", - "embedding": embeddings.data[0].embedding, - "usage": embeddings.usage, - }) + results.append( + { + "test": "single_embedding", + "embedding": embeddings.data[0].embedding, + "usage": embeddings.usage, + } + ) return results @@ -462,74 +484,75 @@ def _test_image_text( results = [] # test pure text input - messages = [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "How do you feel today?" 
- }, - ], - }] - - chat_completion = client.chat.completions.create(model=model_name, - messages=messages, - temperature=0.0, - max_tokens=1, - logprobs=True, - top_logprobs=5) + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "How do you feel today?"}, + ], + } + ] + + chat_completion = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0.0, + max_tokens=1, + logprobs=True, + top_logprobs=5, + ) top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs for x in top_logprobs: x.logprob = round(x.logprob, 2) - results.append({ - "test": "pure_text", - "logprobs": top_logprobs, - }) - - messages = [{ - "role": - "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { - "type": "text", - "text": "What's in this image?" - }, - ], - }] - - chat_completion = client.chat.completions.create(model=model_name, - messages=messages, - temperature=0.0, - max_tokens=1, - logprobs=True, - top_logprobs=5) + results.append( + { + "test": "pure_text", + "logprobs": top_logprobs, + } + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": image_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] + + chat_completion = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0.0, + max_tokens=1, + logprobs=True, + top_logprobs=5, + ) top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs - results.append({ - "test": "text_image", - "logprobs": top_logprobs, - }) + results.append( + { + "test": "text_image", + "logprobs": top_logprobs, + } + ) return results -def compare_two_settings(model: str, - arg1: list[str], - arg2: list[str], - env1: Optional[dict[str, str]] = None, - env2: Optional[dict[str, str]] = None, - *, - method: str = "generate", - max_wait_seconds: Optional[float] = None) -> None: +def compare_two_settings( + model: str, + arg1: list[str], + arg2: list[str], + env1: Optional[dict[str, str]] = None, + env2: Optional[dict[str, str]] = None, + *, + method: str = "generate", + max_wait_seconds: Optional[float] = None, +) -> None: """ Launch API server with two different sets of arguments/environments and compare the results of the API calls. @@ -551,12 +574,14 @@ def compare_two_settings(model: str, ) -def compare_all_settings(model: str, - all_args: list[list[str]], - all_envs: list[Optional[dict[str, str]]], - *, - method: str = "generate", - max_wait_seconds: Optional[float] = None) -> None: +def compare_all_settings( + model: str, + all_args: list[list[str]], + all_envs: list[Optional[dict[str, str]]], + *, + method: str = "generate", + max_wait_seconds: Optional[float] = None, +) -> None: """ Launch API server with several different sets of arguments/environments and compare the results of the API calls with the first set of arguments. 
@@ -606,21 +631,22 @@ def compare_all_settings(model: str, args = args + ["--load-format", envs.VLLM_TEST_FORCE_LOAD_FORMAT] compare_results: list = [] results = ref_results if i == 0 else compare_results - with RemoteOpenAIServer(model, - args, - env_dict=env, - max_wait_seconds=max_wait_seconds) as server: + with RemoteOpenAIServer( + model, args, env_dict=env, max_wait_seconds=max_wait_seconds + ) as server: client = server.get_client() # test models list models = client.models.list() models = models.data served_model = models[0] - results.append({ - "test": "models_list", - "id": served_model.id, - "root": served_model.root, - }) + results.append( + { + "test": "models_list", + "id": served_model.id, + "root": served_model.root, + } + ) if method == "generate": results += _test_completion(client, model, prompt, token_ids) @@ -630,8 +656,9 @@ def compare_all_settings(model: str, results += _test_chat(client, model, prompt) elif method == "generate_with_image": results += _test_image_text( - client, model, - "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png" + client, + model, + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ) elif method == "encode": results += _test_embeddings(client, model, prompt) @@ -644,8 +671,7 @@ def compare_all_settings(model: str, ref_envs = all_envs[0] compare_args = all_args[i] compare_envs = all_envs[i] - for ref_result, compare_result in zip(ref_results, - compare_results): + for ref_result, compare_result in zip(ref_results, compare_results): ref_result = copy.deepcopy(ref_result) compare_result = copy.deepcopy(compare_result) if "embedding" in ref_result and method == "encode": @@ -656,7 +682,8 @@ def compare_all_settings(model: str, ) assert sim >= 0.999, ( f"Embedding for {model=} are not the same.\n" - f"cosine_similarity={sim}\n") + f"cosine_similarity={sim}\n" + ) del ref_result["embedding"] del compare_result["embedding"] assert ref_result == compare_result, ( @@ -664,7 +691,8 @@ def compare_all_settings(model: str, f"{ref_args=} {ref_envs=}\n" f"{compare_args=} {compare_envs=}\n" f"{ref_result=}\n" - f"{compare_result=}\n") + f"{compare_result=}\n" + ) def init_test_distributed_environment( @@ -679,7 +707,8 @@ def init_test_distributed_environment( world_size=pp_size * tp_size, rank=rank, distributed_init_method=distributed_init_method, - local_rank=local_rank) + local_rank=local_rank, + ) ensure_model_parallel_initialized(tp_size, pp_size) @@ -701,13 +730,17 @@ def multi_process_parallel( os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1" ray.init( runtime_env={ - "working_dir": - VLLM_PATH, + "working_dir": VLLM_PATH, "excludes": [ - "build", ".git", "cmake-build-*", "shellcheck", "dist", - "ep_kernels_workspace" - ] - }) + "build", + ".git", + "cmake-build-*", + "shellcheck", + "dist", + "ep_kernels_workspace", + ], + } + ) distributed_init_port = get_open_port() refs = [] @@ -719,7 +752,8 @@ def multi_process_parallel( pp_size, rank, distributed_init_port, - ), ) + ), + ) ray.get(refs) ray.shutdown() @@ -748,11 +782,13 @@ def get_physical_device_indices(devices): @_nvml() -def wait_for_gpu_memory_to_clear(*, - devices: list[int], - threshold_bytes: Optional[int] = None, - threshold_ratio: Optional[float] = None, - timeout_s: float = 120) -> None: +def wait_for_gpu_memory_to_clear( + *, + devices: list[int], + threshold_bytes: Optional[int] = None, + threshold_ratio: Optional[float] = None, + timeout_s: float = 120, +) -> None: assert threshold_bytes is not None or threshold_ratio is not None # Use nvml 
instead of pytorch to reduce measurement error from torch cuda # context. @@ -773,29 +809,33 @@ def wait_for_gpu_memory_to_clear(*, gb_used = mem_info.used / 2**30 gb_total = mem_info.total / 2**30 output_raw[device] = (gb_used, gb_total) - output[device] = f'{gb_used:.02f}/{gb_total:.02f}' + output[device] = f"{gb_used:.02f}/{gb_total:.02f}" - print('gpu memory used/total (GiB): ', end='') + print("gpu memory used/total (GiB): ", end="") for k, v in output.items(): - print(f'{k}={v}; ', end='') - print('') + print(f"{k}={v}; ", end="") + print("") if threshold_bytes is not None: is_free = lambda used, total: used <= threshold_bytes / 2**30 - threshold = f"{threshold_bytes/2**30} GiB" + threshold = f"{threshold_bytes / 2**30} GiB" else: is_free = lambda used, total: used / total <= threshold_ratio threshold = f"{threshold_ratio:.2f}" dur_s = time.time() - start_time if all(is_free(used, total) for used, total in output_raw.values()): - print(f'Done waiting for free GPU memory on devices {devices=} ' - f'({threshold=}) {dur_s=:.02f}') + print( + f"Done waiting for free GPU memory on devices {devices=} " + f"({threshold=}) {dur_s=:.02f}" + ) break if dur_s >= timeout_s: - raise ValueError(f'Memory of devices {devices=} not free after ' - f'{dur_s=:.02f} ({threshold=})') + raise ValueError( + f"Memory of devices {devices=} not free after " + f"{dur_s=:.02f} ({threshold=})" + ) time.sleep(5) @@ -803,8 +843,7 @@ def wait_for_gpu_memory_to_clear(*, _P = ParamSpec("_P") -def fork_new_process_for_each_test( - func: Callable[_P, None]) -> Callable[_P, None]: +def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]: """Decorator to fork a new process for each test function. See https://github.com/vllm-project/vllm/issues/7053 for more details. """ @@ -818,11 +857,15 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: # Create a unique temporary file to store exception info from child # process. Use test function name and process ID to avoid collisions. - with tempfile.NamedTemporaryFile( + with ( + tempfile.NamedTemporaryFile( delete=False, - mode='w+b', + mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", - suffix=".exc") as exc_file, ExitStack() as delete_after: + suffix=".exc", + ) as exc_file, + ExitStack() as delete_after, + ): exc_file_path = exc_file.name delete_after.callback(os.remove, exc_file_path) @@ -840,6 +883,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: os._exit(0) except Exception as e: import traceback + tb_string = traceback.format_exc() # Try to serialize the exception object first @@ -847,18 +891,18 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: try: # First, try to pickle the actual exception with # its traceback. - exc_to_serialize = {'pickled_exception': e} + exc_to_serialize = {"pickled_exception": e} # Test if it can be pickled cloudpickle.dumps(exc_to_serialize) except (Exception, KeyboardInterrupt): # Fall back to string-based approach. exc_to_serialize = { - 'exception_type': type(e).__name__, - 'exception_msg': str(e), - 'traceback': tb_string, + "exception_type": type(e).__name__, + "exception_msg": str(e), + "traceback": tb_string, } try: - with open(exc_file_path, 'wb') as f: + with open(exc_file_path, "wb") as f: cloudpickle.dump(exc_to_serialize, f) except Exception: # Fallback: just print the traceback. 
@@ -870,8 +914,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: pgid = os.getpgid(pid) _pid, _exitcode = os.waitpid(pid, 0) # ignore SIGTERM signal itself - old_signal_handler = signal.signal(signal.SIGTERM, - signal.SIG_IGN) + old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) # kill all child processes os.killpg(pgid, signal.SIGTERM) # restore the signal handler @@ -880,12 +923,15 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: # Try to read the exception from the child process exc_info = {} if os.path.exists(exc_file_path): - with contextlib.suppress(Exception), \ - open(exc_file_path, 'rb') as f: + with ( + contextlib.suppress(Exception), + open(exc_file_path, "rb") as f, + ): exc_info = cloudpickle.load(f) - if (original_exception := - exc_info.get('pickled_exception')) is not None: + if ( + original_exception := exc_info.get("pickled_exception") + ) is not None: # Re-raise the actual exception object if it was # successfully pickled. assert isinstance(original_exception, Exception) @@ -903,33 +949,33 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: raise AssertionError( f"function {func.__name__} failed when called with" f" args {args} and kwargs {kwargs}" - f" (exit code: {_exitcode})") from None + f" (exit code: {_exitcode})" + ) from None return wrapper -def spawn_new_process_for_each_test( - f: Callable[_P, None]) -> Callable[_P, None]: - """Decorator to spawn a new process for each test function. - """ +def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]: + """Decorator to spawn a new process for each test function.""" @functools.wraps(f) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: # Check if we're already in a subprocess - if os.environ.get('RUNNING_IN_SUBPROCESS') == '1': + if os.environ.get("RUNNING_IN_SUBPROCESS") == "1": # If we are, just run the function directly return f(*args, **kwargs) import torch.multiprocessing as mp + with suppress(RuntimeError): - mp.set_start_method('spawn') + mp.set_start_method("spawn") # Get the module module_name = f.__module__ # Create a process with environment variable set env = os.environ.copy() - env['RUNNING_IN_SUBPROCESS'] = '1' + env["RUNNING_IN_SUBPROCESS"] = "1" with tempfile.TemporaryDirectory() as tempdir: output_filepath = os.path.join(tempdir, "new_process.tmp") @@ -939,29 +985,29 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: cmd = [sys.executable, "-m", f"{module_name}"] - returned = subprocess.run(cmd, - input=input_bytes, - capture_output=True, - env=env) + returned = subprocess.run( + cmd, input=input_bytes, capture_output=True, env=env + ) # check if the subprocess is successful try: returned.check_returncode() except Exception as e: # wrap raised exception to provide more information - raise RuntimeError(f"Error raised in subprocess:\n" - f"{returned.stderr.decode()}") from e + raise RuntimeError( + f"Error raised in subprocess:\n{returned.stderr.decode()}" + ) from e return wrapper def create_new_process_for_each_test( - method: Optional[Literal["spawn", "fork"]] = None + method: Optional[Literal["spawn", "fork"]] = None, ) -> Callable[[Callable[_P, None]], Callable[_P, None]]: """Creates a decorator that runs each test function in a new process. Args: - method: The process creation method. Can be either "spawn" or "fork". + method: The process creation method. Can be either "spawn" or "fork". If not specified, it defaults to "spawn" on ROCm and XPU platforms and "fork" otherwise. 
@@ -972,8 +1018,7 @@ def create_new_process_for_each_test( use_spawn = current_platform.is_rocm() or current_platform.is_xpu() method = "spawn" if use_spawn else "fork" - assert method in ["spawn", - "fork"], "Method must be either 'spawn' or 'fork'" + assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" if method == "fork": return fork_new_process_for_each_test @@ -1057,7 +1102,7 @@ async def completions_with_server_args( max_wait_seconds: int = 240, max_tokens: Union[int, list] = 5, ) -> list[Completion]: - '''Construct a remote OpenAI server, obtain an async client to the + """Construct a remote OpenAI server, obtain an async client to the server & invoke the completions API to obtain completions. Args: @@ -1073,7 +1118,7 @@ async def completions_with_server_args( Returns: OpenAI Completion instance - ''' + """ if isinstance(max_tokens, int): max_tokens = [max_tokens] * len(prompts) @@ -1081,17 +1126,21 @@ async def completions_with_server_args( assert len(max_tokens) == len(prompts) outputs = None - with RemoteOpenAIServer(model_name, - server_cli_args, - max_wait_seconds=max_wait_seconds) as server: + with RemoteOpenAIServer( + model_name, server_cli_args, max_wait_seconds=max_wait_seconds + ) as server: client = server.get_async_client() - outputs = [ client.completions.create(model=model_name, - prompt=[p], - temperature=0, - stream=False, - max_tokens=max_tok, - logprobs=num_logprobs) \ - for p, max_tok in zip(prompts, max_tokens) ] + outputs = [ + client.completions.create( + model=model_name, + prompt=[p], + temperature=0, + stream=False, + max_tokens=max_tok, + logprobs=num_logprobs, + ) + for p, max_tok in zip(prompts, max_tokens) + ] outputs = await asyncio.gather(*outputs) assert outputs is not None, "Completion API call failed." @@ -1100,24 +1149,31 @@ async def completions_with_server_args( def get_client_text_generations(completions: list[Completion]) -> list[str]: - '''Extract generated tokens from the output of a + """Extract generated tokens from the output of a request made to an Open-AI-protocol completions endpoint. 
- ''' + """ assert all([len(x.choices) == 1 for x in completions]) return [x.choices[0].text for x in completions] def get_client_text_logprob_generations( - completions: list[Completion]) -> list[TextTextLogprobs]: - '''Operates on the output of a request made to an Open-AI-protocol + completions: list[Completion], +) -> list[TextTextLogprobs]: + """Operates on the output of a request made to an Open-AI-protocol completions endpoint; obtains top-rank logprobs for each token in each {class}`SequenceGroup` - ''' + """ text_generations = get_client_text_generations(completions) - text = ''.join(text_generations) - return [(text_generations, text, - (None if x.logprobs is None else x.logprobs.top_logprobs)) - for completion in completions for x in completion.choices] + text = "".join(text_generations) + return [ + ( + text_generations, + text, + (None if x.logprobs is None else x.logprobs.top_logprobs), + ) + for completion in completions + for x in completion.choices + ] def has_module_attribute(module_name, attribute_name): @@ -1138,6 +1194,7 @@ def get_attn_backend_list_based_on_platform() -> list[str]: attn_backend_list = ["TRITON_ATTN"] try: import aiter # noqa: F401 + attn_backend_list.append("FLASH_ATTN") except Exception: print("Skip FLASH_ATTN on ROCm as aiter is not installed") @@ -1152,8 +1209,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]: @contextmanager def override_cutlass_fp8_supported(value: bool): with patch( - "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported", - return_value=value): + "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported", + return_value=value, + ): yield @@ -1174,8 +1232,10 @@ def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): for _ in range(batch_size): idx = random.randint(30, 90) indices.append(idx) - prompt = "```python\n# We set a number of variables, " + \ - f"x{idx} will be important later\n" + prompt = ( + "```python\n# We set a number of variables, " + + f"x{idx} will be important later\n" + ) ln = random.randint(*ln_range) for k in range(30, ln): v = random.randint(10, 99) @@ -1188,10 +1248,9 @@ def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): return prompts, answer, indices -def check_answers(indices: list[int], - answer: list[int], - outputs: list[str], - accept_rate: float = 0.7): +def check_answers( + indices: list[int], answer: list[int], outputs: list[str], accept_rate: float = 0.7 +): answer2 = [int(text[0:2].strip()) for text in outputs] print(list(zip(indices, zip(answer, answer2)))) numok = 0 diff --git a/tests/utils_/test_gc_utils.py b/tests/utils_/test_gc_utils.py index 265761b069ca..f1d0de87c81b 100644 --- a/tests/utils_/test_gc_utils.py +++ b/tests/utils_/test_gc_utils.py @@ -3,8 +3,11 @@ from dataclasses import dataclass from typing import Any -from vllm.utils.gc_utils import (GCDebugConfig, _compute_detailed_type, - _compute_top_gc_collected_objects) +from vllm.utils.gc_utils import ( + GCDebugConfig, + _compute_detailed_type, + _compute_top_gc_collected_objects, +) @dataclass @@ -21,38 +24,51 @@ def __len__(self) -> int: def test_compute_detailed_type(): - assert _compute_detailed_type( - Normal(v=8)) == "<class 'tests.utils_.test_gc_utils.Normal'>" + assert ( + _compute_detailed_type(Normal(v=8)) + == "<class 'tests.utils_.test_gc_utils.Normal'>" + ) assert _compute_detailed_type([1, 2, 3]) == "<class 'list'>(size:3)" assert _compute_detailed_type({4, 5}) == "<class 'set'>(size:2)" assert 
_compute_detailed_type({6: 7}) == "<class 'dict'>(size:1)" - assert _compute_detailed_type(ListWrapper( - vs=[])) == "<class 'tests.utils_.test_gc_utils.ListWrapper'>(size:0)" + assert ( + _compute_detailed_type(ListWrapper(vs=[])) + == "<class 'tests.utils_.test_gc_utils.ListWrapper'>(size:0)" + ) def test_compute_top_gc_collected_objects(): - objects: list[Any] = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], - {13, 14}, { - 15: 16, - 17: 18 - }, - Normal(v=19), - Normal(v=20), - Normal(v=21)] + objects: list[Any] = [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12], + {13, 14}, + {15: 16, 17: 18}, + Normal(v=19), + Normal(v=20), + Normal(v=21), + ] assert _compute_top_gc_collected_objects(objects, top=-1) == "" assert _compute_top_gc_collected_objects(objects, top=0) == "" - assert _compute_top_gc_collected_objects( - objects, top=1) == " 4:<class 'list'>(size:3)" - assert _compute_top_gc_collected_objects(objects, top=2) == "\n".join([ - " 4:<class 'list'>(size:3)", - " 3:<class 'tests.utils_.test_gc_utils.Normal'>" - ]) - assert _compute_top_gc_collected_objects(objects, top=3) == "\n".join([ - " 4:<class 'list'>(size:3)", - " 3:<class 'tests.utils_.test_gc_utils.Normal'>", - " 1:<class 'set'>(size:2)" - ]) + assert ( + _compute_top_gc_collected_objects(objects, top=1) + == " 4:<class 'list'>(size:3)" + ) + assert _compute_top_gc_collected_objects(objects, top=2) == "\n".join( + [ + " 4:<class 'list'>(size:3)", + " 3:<class 'tests.utils_.test_gc_utils.Normal'>", + ] + ) + assert _compute_top_gc_collected_objects(objects, top=3) == "\n".join( + [ + " 4:<class 'list'>(size:3)", + " 3:<class 'tests.utils_.test_gc_utils.Normal'>", + " 1:<class 'set'>(size:2)", + ] + ) def test_gc_debug_config(): @@ -64,6 +80,6 @@ def test_gc_debug_config(): assert config.enabled assert config.top_objects == -1 - config = GCDebugConfig("{\"top_objects\":5}") + config = GCDebugConfig('{"top_objects":5}') assert config.enabled assert config.top_objects == 5 diff --git a/tests/utils_/test_tensor_schema.py b/tests/utils_/test_tensor_schema.py index 6aa781c1564d..c86bed75472c 100644 --- a/tests/utils_/test_tensor_schema.py +++ b/tests/utils_/test_tensor_schema.py @@ -6,37 +6,38 @@ from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs +from vllm.model_executor.models.hyperclovax_vision import HCXVisionVideoPixelInputs from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs def test_tensor_schema_valid_tensor(): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 3, 32, 32), + pixel_values=torch.randn(16, 64, 3, 32, 32), image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_optional_fields(): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 3, 32, 32), + pixel_values=torch.randn(16, 64, 3, 32, 32), image_sizes=None, ) - Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), ) + Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32)) def test_tensor_schema_constant_dim_failure(): with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4 + pixel_values=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4 image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_invalid_types_in_list(): - with pytest.raises(ValueError, match="is not a torch.Tensor"): + with pytest.raises(TypeError, match="is not one of the expected types"): Phi3VImagePixelInputs( - data=[ + pixel_values=[ 
torch.randn(64, 3, 32, 32), "not_a_tensor", torch.randn(64, 3, 32, 32), @@ -48,27 +49,29 @@ def test_tensor_schema_invalid_types_in_list(): def test_tensor_schema_rank_mismatch(): with pytest.raises(ValueError, match="has rank 3 but expected 5"): Phi3VImagePixelInputs( - data=torch.randn(16, 64, 3), + pixel_values=torch.randn(16, 64, 3), image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_missing_required_field(): - with pytest.raises(ValueError, match="Required field 'data' is missing"): - Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), ) + with pytest.raises(ValueError, match="Required field 'pixel_values' is missing"): + Phi3VImagePixelInputs( + image_sizes=torch.randint(0, 256, (16, 2)), + ) def test_tensor_schema_symbolic_dim_mismatch(): with pytest.raises(ValueError, match="expected 'bn'=12, got 16"): Phi3VImagePixelInputs( - data=torch.randn(12, 64, 3, 32, 32), + pixel_values=torch.randn(12, 64, 3, 32, 32), image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_list_tensor_valid(): Phi3VImagePixelInputs( - data=[torch.randn(64, 3, 32, 32) for _ in range(16)], + pixel_values=[torch.randn(64, 3, 32, 32) for _ in range(16)], image_sizes=torch.randint(0, 256, (16, 2)), ) @@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid(): def test_tensor_schema_variable_patch_counts_valid(): # Each image has a different number of patches (p) # Each tensor has shape (p, 3, 32, 32) - data = [ - torch.randn(16, 3, 32, 32), # p = 16 - torch.randn(32, 3, 32, 32), # p = 32 - torch.randn(64, 3, 32, 32), # p = 64 - ] - image_sizes = torch.randint(0, 256, (3, 2)) # bn = 3 Phi3VImagePixelInputs( - data=data, - image_sizes=image_sizes, + pixel_values=[ + torch.randn(16, 3, 32, 32), # p = 16 + torch.randn(32, 3, 32, 32), # p = 32 + torch.randn(64, 3, 32, 32), # p = 64 + ], + image_sizes=torch.randint(0, 256, (3, 2)), # bn = 3 ) def test_tensor_schema_tuple_tensor_valid(): Phi3VImagePixelInputs( - data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)), + pixel_values=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)), image_sizes=torch.randint(0, 256, (16, 2)), ) +def test_tensor_schema_double_nested_tensors(): + x = torch.rand(4, 3, 32, 32) + y = torch.rand(2, 3, 32, 32) + + HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y])) + + def test_tensor_schema_inconsistent_shapes_in_list(): with pytest.raises(ValueError, match="contains inconsistent shapes"): Phi3VImagePixelInputs( - data=[torch.randn(64, 3, 32, 32), - torch.randn(64, 3, 16, 16)] + - [torch.randn(64, 3, 32, 32) for _ in range(14)], + pixel_values=[ + torch.randn(64, 3, 32, 32), + torch.randn(64, 3, 16, 16), + *(torch.randn(64, 3, 32, 32) for _ in range(14)), + ], image_sizes=torch.randint(0, 256, (16, 2)), ) def test_tensor_schema_empty_list(): - with pytest.raises(ValueError, match="is an empty list"): + with pytest.raises(ValueError, match="is an empty sequence"): Phi3VImagePixelInputs( - data=[], + pixel_values=[], image_sizes=torch.randint(0, 256, (0, 2)), ) @@ -117,39 +127,33 @@ def test_tensor_schema_validation_disabled_skips_shape_check(): # This should NOT raise, because validation is turned off # This would normally fail (dim[2] should be 3, not 4) Phi3VImagePixelInputs( - data=torch.randn(16, 64, 4, 32, 32), + pixel_values=torch.randn(16, 64, 4, 32, 32), image_sizes=torch.randint(0, 256, (16, 2)), validate=False, ) def test_tensor_schema_with_valid_resolve_binding_dims(): - data = torch.randn(16, 64, 3, 336, 336) # h=336, w=336 + pixel_values = torch.randn(16, 
64, 3, 336, 336) # h=336, w=336 image_sizes = torch.randint(0, 256, (16, 2)) Phi3VImagePixelInputs( - data=data, + pixel_values=pixel_values, image_sizes=image_sizes, - resolve_bindings={ - "h": 336, - "w": 336 - }, + resolve_bindings={"h": 336, "w": 336}, ) def test_tensor_schema_with_invalid_resolve_binding_dims(): - data = torch.randn(16, 64, 3, 36, 36) # h=36, w=36 + pixel_values = torch.randn(16, 64, 3, 36, 36) # h=36, w=36 image_sizes = torch.randint(0, 256, (16, 2)) # Should raise because 'h' and 'w' don't match resolve bindings with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"): Phi3VImagePixelInputs( - data=data, + pixel_values=pixel_values, image_sizes=image_sizes, - resolve_bindings={ - "h": 336, - "w": 336 - }, + resolve_bindings={"h": 336, "w": 336}, ) diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index bdd92cc8e35e..71c82feac36b 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -21,24 +21,41 @@ from vllm_test_utils.monitor import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.transformers_utils.detokenizer_utils import ( - convert_ids_list_to_tokens) +from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens # isort: off from vllm.utils import ( - CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, - PlaceholderModule, bind_kv_cache, common_broadcastable_dtype, - current_stream, deprecate_kwargs, get_open_port, get_tcp_uri, - is_lossless_cast, join_host_port, make_zmq_path, make_zmq_socket, - memory_profiling, merge_async_iterators, sha256, split_host_port, - split_zmq_path, supports_kw, swap_dict_values, unique_filepath) + CacheInfo, + FlexibleArgumentParser, + LRUCache, + MemorySnapshot, + PlaceholderModule, + bind_kv_cache, + common_broadcastable_dtype, + current_stream, + deprecate_kwargs, + get_open_port, + get_tcp_uri, + is_lossless_cast, + join_host_port, + make_zmq_path, + make_zmq_socket, + memory_profiling, + merge_async_iterators, + sha256, + split_host_port, + split_zmq_path, + supports_kw, + swap_dict_values, + unique_filepath, +) + # isort: on from ..utils import create_new_process_for_each_test, error_on_warning @pytest.mark.asyncio async def test_merge_async_iterators(): - async def mock_async_iterator(idx: int): try: while True: @@ -72,7 +89,6 @@ async def stream_output(generator: AsyncIterator[tuple[int, str]]): def test_deprecate_kwargs_always(): - @deprecate_kwargs("old_arg", is_deprecated=True) def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -85,7 +101,6 @@ def dummy(*, old_arg: object = None, new_arg: object = None): def test_deprecate_kwargs_never(): - @deprecate_kwargs("old_arg", is_deprecated=False) def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -120,7 +135,6 @@ def dummy(*, old_arg: object = None, new_arg: object = None): def test_deprecate_kwargs_additional_message(): - @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -145,99 +159,107 @@ def test_get_open_port(monkeypatch: pytest.MonkeyPatch): @pytest.fixture def parser(): parser = FlexibleArgumentParser() - parser.add_argument('--image-input-type', - choices=['pixel_values', 'image_features']) - parser.add_argument('--model-name') - parser.add_argument('--batch-size', type=int) - parser.add_argument('--enable-feature', action='store_true') - parser.add_argument('--hf-overrides', type=json.loads) - 
parser.add_argument('-O', '--compilation-config', type=json.loads) + parser.add_argument( + "--image-input-type", choices=["pixel_values", "image_features"] + ) + parser.add_argument("--model-name") + parser.add_argument("--batch-size", type=int) + parser.add_argument("--enable-feature", action="store_true") + parser.add_argument("--hf-overrides", type=json.loads) + parser.add_argument("-O", "--compilation-config", type=json.loads) return parser @pytest.fixture def parser_with_config(): parser = FlexibleArgumentParser() - parser.add_argument('serve') - parser.add_argument('model_tag', nargs='?') - parser.add_argument('--model', type=str) - parser.add_argument('--served-model-name', type=str) - parser.add_argument('--config', type=str) - parser.add_argument('--port', type=int) - parser.add_argument('--tensor-parallel-size', type=int) - parser.add_argument('--trust-remote-code', action='store_true') + parser.add_argument("serve") + parser.add_argument("model_tag", nargs="?") + parser.add_argument("--model", type=str) + parser.add_argument("--served-model-name", type=str) + parser.add_argument("--config", type=str) + parser.add_argument("--port", type=int) + parser.add_argument("--tensor-parallel-size", type=int) + parser.add_argument("--trust-remote-code", action="store_true") return parser def test_underscore_to_dash(parser): - args = parser.parse_args(['--image_input_type', 'pixel_values']) - assert args.image_input_type == 'pixel_values' + args = parser.parse_args(["--image_input_type", "pixel_values"]) + assert args.image_input_type == "pixel_values" def test_mixed_usage(parser): - args = parser.parse_args([ - '--image_input_type', 'image_features', '--model-name', - 'facebook/opt-125m' - ]) - assert args.image_input_type == 'image_features' - assert args.model_name == 'facebook/opt-125m' + args = parser.parse_args( + ["--image_input_type", "image_features", "--model-name", "facebook/opt-125m"] + ) + assert args.image_input_type == "image_features" + assert args.model_name == "facebook/opt-125m" def test_with_equals_sign(parser): args = parser.parse_args( - ['--image_input_type=pixel_values', '--model-name=facebook/opt-125m']) - assert args.image_input_type == 'pixel_values' - assert args.model_name == 'facebook/opt-125m' + ["--image_input_type=pixel_values", "--model-name=facebook/opt-125m"] + ) + assert args.image_input_type == "pixel_values" + assert args.model_name == "facebook/opt-125m" def test_with_int_value(parser): - args = parser.parse_args(['--batch_size', '32']) + args = parser.parse_args(["--batch_size", "32"]) assert args.batch_size == 32 - args = parser.parse_args(['--batch-size', '32']) + args = parser.parse_args(["--batch-size", "32"]) assert args.batch_size == 32 def test_with_bool_flag(parser): - args = parser.parse_args(['--enable_feature']) + args = parser.parse_args(["--enable_feature"]) assert args.enable_feature is True - args = parser.parse_args(['--enable-feature']) + args = parser.parse_args(["--enable-feature"]) assert args.enable_feature is True def test_invalid_choice(parser): with pytest.raises(SystemExit): - parser.parse_args(['--image_input_type', 'invalid_choice']) + parser.parse_args(["--image_input_type", "invalid_choice"]) def test_missing_required_argument(parser): - parser.add_argument('--required-arg', required=True) + parser.add_argument("--required-arg", required=True) with pytest.raises(SystemExit): parser.parse_args([]) def test_cli_override_to_config(parser_with_config, cli_config_file): - args = parser_with_config.parse_args([ - 'serve', 
'mymodel', '--config', cli_config_file, - '--tensor-parallel-size', '3' - ]) + args = parser_with_config.parse_args( + ["serve", "mymodel", "--config", cli_config_file, "--tensor-parallel-size", "3"] + ) assert args.tensor_parallel_size == 3 - args = parser_with_config.parse_args([ - 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config', - cli_config_file - ]) + args = parser_with_config.parse_args( + ["serve", "mymodel", "--tensor-parallel-size", "3", "--config", cli_config_file] + ) assert args.tensor_parallel_size == 3 assert args.port == 12312 - args = parser_with_config.parse_args([ - 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config', - cli_config_file, '--port', '666' - ]) + args = parser_with_config.parse_args( + [ + "serve", + "mymodel", + "--tensor-parallel-size", + "3", + "--config", + cli_config_file, + "--port", + "666", + ] + ) assert args.tensor_parallel_size == 3 assert args.port == 666 def test_config_args(parser_with_config, cli_config_file): args = parser_with_config.parse_args( - ['serve', 'mymodel', '--config', cli_config_file]) + ["serve", "mymodel", "--config", cli_config_file] + ) assert args.tensor_parallel_size == 2 assert args.trust_remote_code @@ -245,22 +267,31 @@ def test_config_args(parser_with_config, cli_config_file): def test_config_file(parser_with_config): with pytest.raises(FileNotFoundError): parser_with_config.parse_args( - ['serve', 'mymodel', '--config', 'test_config.yml']) + ["serve", "mymodel", "--config", "test_config.yml"] + ) with pytest.raises(ValueError): parser_with_config.parse_args( - ['serve', 'mymodel', '--config', './data/test_config.json']) + ["serve", "mymodel", "--config", "./data/test_config.json"] + ) with pytest.raises(ValueError): - parser_with_config.parse_args([ - 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config', - '--batch-size', '32' - ]) + parser_with_config.parse_args( + [ + "serve", + "mymodel", + "--tensor-parallel-size", + "3", + "--config", + "--batch-size", + "32", + ] + ) def test_no_model_tag(parser_with_config, cli_config_file): with pytest.raises(ValueError): - parser_with_config.parse_args(['serve', '--config', cli_config_file]) + parser_with_config.parse_args(["serve", "--config", cli_config_file]) def test_dict_args(parser): @@ -323,7 +354,7 @@ def test_dict_args(parser): }, "key14": { "key15": "-minus.and.dot", - } + }, } assert parsed_args.compilation_config == { "level": 1, @@ -357,7 +388,6 @@ def test_duplicate_dict_args(caplog_vllm, parser): assert "-O.level" in caplog_vllm.text -# yapf: enable @pytest.mark.parametrize( "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", [ @@ -375,24 +405,28 @@ def test_duplicate_dict_args(caplog_vllm, parser): (lambda foo, **kwargs: None, "something_else", False, True, True), (lambda foo, **kwargs: None, "kwargs", True, True, False), (lambda foo, **kwargs: None, "foo", True, True, False), - ]) -# yapf: disable -def test_supports_kw(callable,kw_name,requires_kw_only, - allow_var_kwargs,is_supported): - assert supports_kw( + ], +) +def test_supports_kw( + callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported +): + assert ( + supports_kw( callable=callable, kw_name=kw_name, requires_kw_only=requires_kw_only, - allow_var_kwargs=allow_var_kwargs - ) == is_supported + allow_var_kwargs=allow_var_kwargs, + ) + == is_supported + ) @create_new_process_for_each_test() def test_memory_profiling(): # Fake out some model loading + inference memory usage to test profiling # Memory used by other processes will show up as cuda 
usage outside of torch - from vllm.distributed.device_communicators.cuda_wrapper import ( - CudaRTLibrary) + from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary + lib = CudaRTLibrary() # 512 MiB allocation outside of this instance handle1 = lib.cudaMalloc(512 * 1024 * 1024) @@ -401,9 +435,9 @@ def test_memory_profiling(): # load weights - weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32) + weights = torch.randn(128, 1024, 1024, device="cuda", dtype=torch.float32) - weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB + weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB def measure_current_non_torch(): free, total = torch.cuda.mem_get_info() @@ -412,11 +446,14 @@ def measure_current_non_torch(): current_non_torch = current_used - current_torch return current_non_torch - with memory_profiling(baseline_snapshot=baseline_snapshot, - weights_memory=weights_memory) as result, \ - monitor(measure_current_non_torch) as monitored_values: + with ( + memory_profiling( + baseline_snapshot=baseline_snapshot, weights_memory=weights_memory + ) as result, + monitor(measure_current_non_torch) as monitored_values, + ): # make a memory spike, 1 GiB - spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32) + spike = torch.randn(256, 1024, 1024, device="cuda", dtype=torch.float32) del spike # Add some extra non-torch memory 256 MiB (simulate NCCL) @@ -431,7 +468,7 @@ def measure_current_non_torch(): # 5% tolerance is caused by cuda runtime. # we cannot control cuda runtime in the granularity of bytes, # which causes a small error (<10 MiB in practice) - non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa + non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa assert abs(non_torch_ratio - 1) <= 0.05 assert result.torch_peak_increase == 1024 * 1024 * 1024 del weights @@ -443,87 +480,84 @@ def test_bind_kv_cache(): from vllm.attention import Attention ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), - 'layers.1.self_attn': Attention(32, 128, 0.1), - 'layers.2.self_attn': Attention(32, 128, 0.1), - 'layers.3.self_attn': Attention(32, 128, 0.1), + "layers.0.self_attn": Attention(32, 128, 0.1), + "layers.1.self_attn": Attention(32, 128, 0.1), + "layers.2.self_attn": Attention(32, 128, 0.1), + "layers.3.self_attn": Attention(32, 128, 0.1), } kv_cache = [ - torch.zeros((1, )), - torch.zeros((1, )), - torch.zeros((1, )), - torch.zeros((1, )), + torch.zeros((1,)), + torch.zeros((1,)), + torch.zeros((1,)), + torch.zeros((1,)), ] bind_kv_cache(ctx, [kv_cache]) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0] - assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1] - assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2] - assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3] + assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0] + assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1] + assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[2] + assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[3] + def test_bind_kv_cache_kv_sharing(): from vllm.attention import Attention ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), - 'layers.1.self_attn': Attention(32, 128, 0.1), - 'layers.2.self_attn': Attention(32, 128, 0.1), - 'layers.3.self_attn': Attention(32, 128, 0.1), + "layers.0.self_attn": Attention(32, 128, 0.1), + "layers.1.self_attn": Attention(32, 128, 0.1), + "layers.2.self_attn": Attention(32, 128, 0.1), + "layers.3.self_attn": Attention(32, 128, 0.1), } 
kv_cache = [ - torch.zeros((1, )), - torch.zeros((1, )), - torch.zeros((1, )), - torch.zeros((1, )), + torch.zeros((1,)), + torch.zeros((1,)), + torch.zeros((1,)), + torch.zeros((1,)), ] shared_kv_cache_layers = { - 'layers.2.self_attn': 'layers.1.self_attn', - 'layers.3.self_attn': 'layers.0.self_attn' + "layers.2.self_attn": "layers.1.self_attn", + "layers.3.self_attn": "layers.0.self_attn", } bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0] - assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1] - assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1] - assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0] + assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0] + assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache[1] + assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache[1] + assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache[0] + def test_bind_kv_cache_non_attention(): from vllm.attention import Attention # example from Jamba PP=2 ctx = { - 'model.layers.20.attn': Attention(32, 128, 0.1), - 'model.layers.28.attn': Attention(32, 128, 0.1), + "model.layers.20.attn": Attention(32, 128, 0.1), + "model.layers.28.attn": Attention(32, 128, 0.1), } kv_cache = [ - torch.zeros((1, )), - torch.zeros((1, )), + torch.zeros((1,)), + torch.zeros((1,)), ] bind_kv_cache(ctx, [kv_cache]) - assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0] - assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1] + assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache[0] + assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache[1] def test_bind_kv_cache_pp(): with patch("vllm.utils.cuda_device_count_stateless", lambda: 2): # this test runs with 1 GPU, but we simulate 2 GPUs - cfg = VllmConfig( - parallel_config=ParallelConfig(pipeline_parallel_size=2)) + cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2)) with set_current_vllm_config(cfg): from vllm.attention import Attention ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), + "layers.0.self_attn": Attention(32, 128, 0.1), } - kv_cache = [ - [torch.zeros((1, ))], - [torch.zeros((1, ))] - ] + kv_cache = [[torch.zeros((1,))], [torch.zeros((1,))]] bind_kv_cache(ctx, kv_cache) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0] - assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0] + assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache[0][0] + assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0] class TestLRUCache(LRUCache): - def _on_remove(self, key, value): if not hasattr(self, "_remove_counter"): self._remove_counter = 0 @@ -645,7 +679,6 @@ def test_lru_cache(): assert 6 in cache -# yapf: disable @pytest.mark.parametrize( ("src_dtype", "tgt_dtype", "expected_result"), [ @@ -679,12 +712,10 @@ def test_lru_cache(): (torch.complex64, torch.complex32, False), ], ) -# yapf: enable def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result -# yapf: disable @pytest.mark.parametrize( ("dtypes", "expected_result"), [ @@ -694,7 +725,6 @@ def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result): ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501 ], ) -# yapf: enable def test_common_broadcastable_dtype(dtypes, expected_result): assert common_broadcastable_dtype(dtypes) == expected_result @@ -739,7 +769,6 @@ def build_ctx(): _ = placeholder_attr.module -# yapf: disable 
@pytest.mark.parametrize( "obj,key1,key2", [ @@ -749,8 +778,8 @@ def build_ctx(): ({1: "a", 2: "b"}, 1, 3), # Tests for both keys do not exist ({1: "a", 2: "b"}, 3, 4), - ]) -# yapf: enable + ], +) def test_swap_dict_values(obj, key1, key2): original_obj = obj.copy() swap_dict_values(obj, key1, key2) @@ -764,26 +793,30 @@ def test_swap_dict_values(obj, key1, key2): assert key1 not in obj -def test_model_specification(parser_with_config, cli_config_file, - cli_config_file_with_model): +def test_model_specification( + parser_with_config, cli_config_file, cli_config_file_with_model +): # Test model in CLI takes precedence over config args = parser_with_config.parse_args( - ['serve', 'cli-model', '--config', cli_config_file_with_model]) - assert args.model_tag == 'cli-model' - assert args.served_model_name == 'mymodel' + ["serve", "cli-model", "--config", cli_config_file_with_model] + ) + assert args.model_tag == "cli-model" + assert args.served_model_name == "mymodel" # Test model from config file works - args = parser_with_config.parse_args([ - 'serve', - '--config', - cli_config_file_with_model, - ]) - assert args.model == 'config-model' - assert args.served_model_name == 'mymodel' + args = parser_with_config.parse_args( + [ + "serve", + "--config", + cli_config_file_with_model, + ] + ) + assert args.model == "config-model" + assert args.served_model_name == "mymodel" # Test no model specified anywhere raises error with pytest.raises(ValueError, match="No model specified!"): - parser_with_config.parse_args(['serve', '--config', cli_config_file]) + parser_with_config.parse_args(["serve", "--config", cli_config_file]) # Test using --model option raises error # with pytest.raises( @@ -797,47 +830,52 @@ def test_model_specification(parser_with_config, cli_config_file, # Test using --model option back-compatibility # (when back-compatibility ends, the above test should be uncommented # and the below test should be removed) - args = parser_with_config.parse_args([ - 'serve', - '--tensor-parallel-size', - '2', - '--model', - 'my-model', - '--trust-remote-code', - '--port', - '8001', - ]) + args = parser_with_config.parse_args( + [ + "serve", + "--tensor-parallel-size", + "2", + "--model", + "my-model", + "--trust-remote-code", + "--port", + "8001", + ] + ) assert args.model is None assert args.tensor_parallel_size == 2 assert args.trust_remote_code is True assert args.port == 8001 - args = parser_with_config.parse_args([ - 'serve', - '--tensor-parallel-size=2', - '--model=my-model', - '--trust-remote-code', - '--port=8001', - ]) + args = parser_with_config.parse_args( + [ + "serve", + "--tensor-parallel-size=2", + "--model=my-model", + "--trust-remote-code", + "--port=8001", + ] + ) assert args.model is None assert args.tensor_parallel_size == 2 assert args.trust_remote_code is True assert args.port == 8001 # Test other config values are preserved - args = parser_with_config.parse_args([ - 'serve', - 'cli-model', - '--config', - cli_config_file_with_model, - ]) + args = parser_with_config.parse_args( + [ + "serve", + "cli-model", + "--config", + cli_config_file_with_model, + ] + ) assert args.tensor_parallel_size == 2 assert args.trust_remote_code is True assert args.port == 12312 -@pytest.mark.parametrize("input", [(), ("abc", ), (None, ), - (None, bool, [1, 2, 3])]) +@pytest.mark.parametrize("input", [(), ("abc",), (None,), (None, bool, [1, 2, 3])]) def test_sha256(input: tuple): digest = sha256(input) assert digest is not None @@ -851,7 +889,7 @@ def test_sha256(input: tuple): assert digest == 
sha256(input) # hashing different input, returns different value - assert digest != sha256(input + (1, )) + assert digest != sha256(input + (1,)) @pytest.mark.parametrize( @@ -861,7 +899,8 @@ def test_sha256(input: tuple): ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")), ("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address ("inproc://some_identifier", ("inproc", "some_identifier", "")), - ]) + ], +) def test_split_zmq_path(path, expected): assert split_zmq_path(path) == expected @@ -873,7 +912,8 @@ def test_split_zmq_path(path, expected): "tcp://127.0.0.1", # Missing port "tcp://[::1]", # Missing port for IPv6 "tcp://:5555", # Missing host - ]) + ], +) def test_split_zmq_path_invalid(invalid_path): with pytest.raises(ValueError): split_zmq_path(invalid_path) @@ -895,8 +935,9 @@ def test_make_zmq_socket_ipv6(): zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type) # Verify that the IPV6 option is set - assert zsock.getsockopt( - zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" + assert zsock.getsockopt(zmq.IPV6) == 1, ( + "IPV6 option should be enabled for IPv6 addresses" + ) # Clean up zsock.close() @@ -983,15 +1024,14 @@ def test_convert_ids_list_to_tokens(): tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") token_ids = tokenizer.encode("Hello, world!") # token_ids = [9707, 11, 1879, 0] - assert tokenizer.convert_ids_to_tokens(token_ids) == [ - 'Hello', ',', 'Ġworld', '!' - ] + assert tokenizer.convert_ids_to_tokens(token_ids) == ["Hello", ",", "Ġworld", "!"] tokens = convert_ids_list_to_tokens(tokenizer, token_ids) - assert tokens == ['Hello', ',', ' world', '!'] + assert tokens == ["Hello", ",", " world", "!"] def test_current_stream_multithread(): import threading + if not torch.cuda.is_available(): pytest.skip("CUDA not available") @@ -1010,13 +1050,18 @@ def child_thread_func(): child_thread.start() try: - assert thread_stream_ready.wait( - timeout=5), "Child thread failed to enter stream context in time" + assert thread_stream_ready.wait(timeout=5), ( + "Child thread failed to enter stream context in time" + ) main_current_stream = current_stream() - assert main_current_stream != child_stream, "Main thread's current_stream was contaminated by child thread" - assert main_current_stream == main_default_stream, "Main thread's current_stream is not the default stream" + assert main_current_stream != child_stream, ( + "Main thread's current_stream was contaminated by child thread" + ) + assert main_current_stream == main_default_stream, ( + "Main thread's current_stream is not the default stream" + ) # Notify child thread it can exit thread_can_exit.set() @@ -1034,7 +1079,7 @@ def test_load_config_file(tmp_path): "enable-logging": True, "list-arg": ["item1", "item2"], "port": 12323, - "tensor-parallel-size": 4 + "tensor-parallel-size": 4, } # Write the configuration data to a temporary YAML file diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 24cdd8afbb3b..188482e071ee 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for v1 attention backends without GPUModelRunner dependency.""" + from functools import partial from typing import Optional, Union @@ -8,21 +9,30 @@ import torch from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from 
tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, - create_standard_kv_cache_spec, - create_vllm_config, - get_attention_backend) +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend, +) from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - set_kv_cache_layout) +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + set_kv_cache_layout, +) from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, _Backend.FLASHINFER, _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, _Backend.TREE_ATTN, "FLEX_ATTENTION_SLOW" + _Backend.FLASH_ATTN, + _Backend.FLASHINFER, + _Backend.FLEX_ATTENTION, + _Backend.TRITON_ATTN, + _Backend.TREE_ATTN, + "FLEX_ATTENTION_SLOW", ] # Remove flashinfer from the list if it's not available @@ -49,42 +59,38 @@ def _convert_dtype_to_torch(dtype): # Define common batch configurations BATCH_SPECS = { - "small_decode": - BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), - "small_prefill": - BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), - "mixed_small": - BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), - "medium_decode": - BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], - query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), - "medium_prefill": - BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), - "mixed_medium": - BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], - query_lens=[1, 1, 1, 7, 7, 7]), - "large_decode": - BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), - "large_prefill": - BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), - "single_decode": - BatchSpec(seq_lens=[1024], query_lens=[1]), - "single_prefill": - BatchSpec(seq_lens=[1024], query_lens=[64]), + "small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": BatchSpec( + seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1], + ), + "medium_prefill": BatchSpec( + seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16] + ), + "mixed_medium": BatchSpec( + seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7] + ), + "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), } def create_and_prepopulate_kv_cache( - k_contexts: list[torch.Tensor], - v_contexts: list[torch.Tensor], - block_size: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - num_blocks: int, - common_attn_metadata: CommonAttentionMetadata, - randomize_blocks: bool = True) -> torch.Tensor: + k_contexts: list[torch.Tensor], + v_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True, +) -> torch.Tensor: """Create and 
prepopulate a KV cache with context data. Args: @@ -106,20 +112,18 @@ def create_and_prepopulate_kv_cache( """ batch_size = len(k_contexts) seq_lens = common_attn_metadata.seq_lens_cpu - query_lens = common_attn_metadata.query_start_loc_cpu[ - 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + query_lens = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) context_lens = common_attn_metadata.num_computed_tokens_cpu block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping # Create KV cache - kv_cache = torch.empty(2, - num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype, - device=device) + kv_cache = torch.empty( + 2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device + ) kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size) # Populate the cache with the context tokens @@ -168,8 +172,8 @@ def create_and_prepopulate_kv_cache( start = common_attn_metadata.query_start_loc_cpu[i] end = common_attn_metadata.query_start_loc_cpu[i + 1] slot_mapping[start:end] = block_table[ - i, - block_indices] * block_size + token_inter_block_offsets.to(device) + i, block_indices + ] * block_size + token_inter_block_offsets.to(device) return kv_cache @@ -222,20 +226,19 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): # Return mock parameters for a single layer head_size = vllm_config.model_config.get_head_size() return { - layer_name: - PerLayerParameters( + layer_name: PerLayerParameters( window_left=-1, # No sliding window logits_soft_cap=0.0, # No soft cap - sm_scale=1.0 / (head_size**0.5) # Standard scale + sm_scale=1.0 / (head_size**0.5), # Standard scale ) for layer_name in layer_names } with unittest.mock.patch( - 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters', - mock_get_per_layer_parameters): - builder = builder_cls(kv_cache_spec, layer_names, vllm_config, - device) + "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters", + mock_get_per_layer_parameters, + ): + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -252,9 +255,11 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): # Instantiate implementation num_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() scale = 1.0 / (head_size**0.5) impl = impl_cls( @@ -274,13 +279,9 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): # Run forward pass # NOTE: The query, key, and value are already shaped correctly # in the calling test function. - output = impl.forward(mock_layer, - query, - key, - value, - kv_cache, - attn_metadata, - output=output) + output = impl.forward( + mock_layer, query, key, value, kv_cache, attn_metadata, output=output + ) return output @@ -311,10 +312,12 @@ def _test_backend_correctness( 5. Comparing the vLLM backend's output to the ground-truth SDPA output. 
""" current_platform.seed_everything(42) - vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens), - block_size=block_size, - num_gpu_blocks=8192) + vllm_config = create_vllm_config( + model_name=model, + max_model_len=max(batch_spec.seq_lens), + block_size=block_size, + num_gpu_blocks=8192, + ) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -324,9 +327,11 @@ def _test_backend_correctness( seq_lens = batch_spec.seq_lens query_lens = batch_spec.query_lens num_q_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() sliding_window = vllm_config.model_config.get_sliding_window() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) @@ -344,21 +349,9 @@ def _test_backend_correctness( context_len = s_len - q_len # Generate Q, K, V for the whole sequence to be used in SDPA - q = torch.randn(q_len, - num_q_heads, - head_size, - dtype=dtype, - device=device) - k_full = torch.randn(s_len, - num_kv_heads, - head_size, - dtype=dtype, - device=device) - v_full = torch.randn(s_len, - num_kv_heads, - head_size, - dtype=dtype, - device=device) + q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device) + k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) + v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) # SDPA expects (N, H, L, D), so unsqueeze batch and permute q_sdpa_in = q.unsqueeze(0).transpose(1, 2) @@ -368,7 +361,8 @@ def _test_backend_correctness( if num_q_heads != num_kv_heads: assert num_q_heads % num_kv_heads == 0, ( f"num_q_heads ({num_q_heads}) must be divisible by " - f"num_kv_heads ({num_kv_heads})") + f"num_kv_heads ({num_kv_heads})" + ) repeats = num_q_heads // num_kv_heads k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1) v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1) @@ -378,18 +372,17 @@ def _test_backend_correctness( kv_len = s_len final_mask_mod = partial(mask_mod, context_len=context_len) - block_mask = create_block_mask(final_mask_mod, - B=None, - H=None, - Q_LEN=q_len, - KV_LEN=kv_len, - device=device) - sdpa_out_i = flex_attention(q_sdpa_in, - k_sdpa_in, - v_sdpa_in, - block_mask=block_mask, - scale=scale, - enable_gqa=True) + block_mask = create_block_mask( + final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device=device + ) + sdpa_out_i = flex_attention( + q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + block_mask=block_mask, + scale=scale, + enable_gqa=True, + ) all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0)) @@ -408,7 +401,8 @@ def _test_backend_correctness( sdpa_output = torch.cat(all_sdpa_outputs, dim=0) common_attn_metadata = create_common_attn_metadata( - batch_spec, vllm_config.cache_config.block_size, device) + batch_spec, vllm_config.cache_config.block_size, device + ) # 3. Simulate Paged KV Cache and a realistic slot_mapping kv_cache = create_and_prepopulate_kv_cache( @@ -421,7 +415,8 @@ def _test_backend_correctness( device=device, num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000, common_attn_metadata=common_attn_metadata, - randomize_blocks=True) + randomize_blocks=True, + ) # 4. 
Run vLLM backends and compare # Note: flex_attention has known Triton kernel compatibility issues @@ -437,8 +432,9 @@ def _test_backend_correctness( kv_cache_for_backend = kv_cache.transpose(0, 1) # For FlashInfer default to HND layout and - kv_cache_for_backend = kv_cache_for_backend.transpose( - 2, 3).contiguous().transpose(2, 3) + kv_cache_for_backend = ( + kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) + ) set_kv_cache_layout("HND") backend_output = run_attention_backend( @@ -458,32 +454,45 @@ def _test_backend_correctness( # Check shape and dtype consistency assert backend_output.shape == sdpa_output.shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_output.shape}") + f"SDPA shape {sdpa_output.shape}" + ) assert backend_output.dtype == sdpa_output.dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_output.dtype}") + f"SDPA dtype {sdpa_output.dtype}" + ) assert torch.isfinite(backend_output).all(), ( - f"[{backend_name}] produced non-finite values") + f"[{backend_name}] produced non-finite values" + ) # Check numerical similarity def error_msg(msg: str, backend_name: str): - return (f"[{backend_name}] output differs from SDPA baseline. " - f"{msg}") - - torch.testing.assert_close(backend_output, - sdpa_output, - rtol=rtol, - atol=atol, - msg=partial(error_msg, - backend_name=backend_name)) - - -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium", "large_decode", "large_prefill", - "single_decode", "single_prefill" -]) + return f"[{backend_name}] output differs from SDPA baseline. {msg}" + + torch.testing.assert_close( + backend_output, + sdpa_output, + rtol=rtol, + atol=atol, + msg=partial(error_msg, backend_name=backend_name), + ) + + +@pytest.mark.parametrize( + "batch_spec_name", + [ + "small_decode", + "small_prefill", + "mixed_small", + "medium_decode", + "medium_prefill", + "mixed_medium", + "large_decode", + "large_prefill", + "single_decode", + "single_prefill", + ], +) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) def test_causal_backend_correctness(batch_spec_name: str, model: str): """Test backend's correctness with causal attention.""" @@ -499,33 +508,33 @@ def causal_mask_mod( return (q_idx + context_len) >= kv_idx batch_spec = BATCH_SPECS[batch_spec_name] - LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION] - if is_torch_equal_or_newer("2.9.0.dev0") else []) + LARGE_BLOCK_BACKENDS = ( + [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + ) SMALL_BLOCK_BACKENDS = [ x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS ] - _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, - causal_mask_mod) + _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod) # Fast FlexAttention needs to run with block_size=128 if LARGE_BLOCK_BACKENDS: - _test_backend_correctness(batch_spec, - model, - LARGE_BLOCK_BACKENDS, - causal_mask_mod, - block_size=128) + _test_backend_correctness( + batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128 + ) SLIDING_WINDOW_BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN, - "FLEX_ATTENTION_SLOW" + _Backend.FLASH_ATTN, + _Backend.FLEX_ATTENTION, + _Backend.TRITON_ATTN, + "FLEX_ATTENTION_SLOW", ] -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_medium", "large_decode", - "large_prefill" -]) 
+@pytest.mark.parametrize( + "batch_spec_name", + ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"], +) @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"]) def test_sliding_window_backend_correctness(batch_spec_name: str, model: str): """Test backend's correctness with sliding window attention.""" @@ -544,25 +553,28 @@ def sliding_window_mask_mod( return causal_mask & window_mask batch_spec = BATCH_SPECS[batch_spec_name] - model_config = ModelConfig(model=model, - max_model_len=max(batch_spec.seq_lens)) + model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens)) sliding_window = model_config.get_sliding_window() - sliding_window_mask_mod_fn = partial(sliding_window_mask_mod, - sliding_window=sliding_window) + sliding_window_mask_mod_fn = partial( + sliding_window_mask_mod, sliding_window=sliding_window + ) - LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION] - if is_torch_equal_or_newer("2.9.0.dev0") else []) + LARGE_BLOCK_BACKENDS = ( + [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + ) SMALL_BLOCK_BACKENDS = [ - x for x in SLIDING_WINDOW_BACKENDS_TO_TEST - if x not in LARGE_BLOCK_BACKENDS + x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS ] - _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, - sliding_window_mask_mod_fn) + _test_backend_correctness( + batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn + ) # Fast FlexAttention needs to run with block_size=128 if LARGE_BLOCK_BACKENDS: - _test_backend_correctness(batch_spec, - model, - LARGE_BLOCK_BACKENDS, - sliding_window_mask_mod_fn, - block_size=128) + _test_backend_correctness( + batch_spec, + model, + LARGE_BLOCK_BACKENDS, + sliding_window_mask_mod_fn, + block_size=128, + ) diff --git a/tests/v1/attention/test_attention_backends_selection.py b/tests/v1/attention/test_attention_backends_selection.py index 59e562814946..6464bb52a4ea 100644 --- a/tests/v1/attention/test_attention_backends_selection.py +++ b/tests/v1/attention/test_attention_backends_selection.py @@ -9,17 +9,16 @@ from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.short_conv import ShortConv -from vllm.model_executor.models.minimax_text_01 import ( - MiniMaxText01LinearAttention) +from vllm.model_executor.models.minimax_text_01 import MiniMaxText01LinearAttention from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.short_conv_attn import ( - ShortConvAttentionBackend) +from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend @pytest.mark.parametrize( - "layer_class, init_kwargs, expected_backend, expected_mamba_type", [ + "layer_class, init_kwargs, expected_backend, expected_mamba_type", + [ ( MambaMixer, dict( @@ -77,9 +76,11 @@ ShortConvAttentionBackend, "short_conv", ), - ]) -def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs, - expected_backend, expected_mamba_type): + ], +) +def test_mamba_layers_get_attn_backend( + dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type +): """Test that Mamba-like layers return the correct attention backend.""" layer = layer_class(**init_kwargs) @@ -88,17 +89,23 @@ def 
test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs, assert layer.mamba_type == expected_mamba_type -@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [ - (MambaMixer, Mamba1AttentionBackend, "mamba1"), - (MambaMixer2, Mamba2AttentionBackend, "mamba2"), - (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"), - (ShortConv, ShortConvAttentionBackend, "short_conv"), -]) -def test_mamba_layers_have_unified_interface(layer_class, expected_backend, - expected_mamba_type): - """Test that all Mamba layers have the unified get_attn_backend +@pytest.mark.parametrize( + "layer_class,expected_backend,expected_mamba_type", + [ + (MambaMixer, Mamba1AttentionBackend, "mamba1"), + (MambaMixer2, Mamba2AttentionBackend, "mamba2"), + (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"), + (ShortConv, ShortConvAttentionBackend, "short_conv"), + ], +) +def test_mamba_layers_have_unified_interface( + layer_class, expected_backend, expected_mamba_type +): + """Test that all Mamba layers have the unified get_attn_backend interface.""" - assert hasattr(layer_class, 'get_attn_backend'), ( - f"{layer_class.__name__} should have get_attn_backend method") - assert hasattr(layer_class, 'mamba_type'), ( - f"{layer_class.__name__} should have mamba_type property") + assert hasattr(layer_class, "get_attn_backend"), ( + f"{layer_class.__name__} should have get_attn_backend method" + ) + assert hasattr(layer_class, "mamba_type"), ( + f"{layer_class.__name__} should have mamba_type property" + ) diff --git a/tests/v1/attention/test_attention_splitting.py b/tests/v1/attention/test_attention_splitting.py index d81f3da7e9cd..6335d2a7db5e 100644 --- a/tests/v1/attention/test_attention_splitting.py +++ b/tests/v1/attention/test_attention_splitting.py @@ -6,11 +6,13 @@ from tests.v1.attention.test_attention_backends import BATCH_SPECS from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata -from vllm.v1.attention.backends.utils import (UBatchSlice, - _make_metadata_with_slice, - slice_query_start_locs, - split_attn_metadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + UBatchSlice, + _make_metadata_with_slice, + slice_query_start_locs, + split_attn_metadata, + split_decodes_and_prefills, +) from vllm.v1.worker.ubatch_splitting import create_ubatch_slices @@ -79,9 +81,7 @@ def small_decode_metadata(): """Create metadata for small decode batch""" batch_spec = BATCH_SPECS["small_decode"] device = torch.device("cpu") - return create_common_attn_metadata(batch_spec, - block_size=16, - device=device) + return create_common_attn_metadata(batch_spec, block_size=16, device=device) @pytest.fixture @@ -89,9 +89,7 @@ def large_decode_metadata(): """Create metadata for small decode batch""" batch_spec = BATCH_SPECS["large_decode"] device = torch.device("cpu") - return create_common_attn_metadata(batch_spec, - block_size=16, - device=device) + return create_common_attn_metadata(batch_spec, block_size=16, device=device) @pytest.fixture @@ -99,9 +97,7 @@ def mixed_small_metadata(): """Create metadata for mixed small batch""" batch_spec = BATCH_SPECS["mixed_small"] device = torch.device("cpu") - return create_common_attn_metadata(batch_spec, - block_size=16, - device=device) + return create_common_attn_metadata(batch_spec, block_size=16, device=device) # Tests for _make_metadata_with_slice @@ -122,8 +118,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata): def 
test_make_metadata_with_slice_mixed_batch(mixed_small_metadata): """Test slicing mixed batch metadata""" - ubatch_slice = UBatchSlice(slice(1, 3), - slice(1, 7)) # Requests 1-3, tokens 1-7 + ubatch_slice = UBatchSlice(slice(1, 3), slice(1, 7)) # Requests 1-3, tokens 1-7 result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata) @@ -140,8 +135,7 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): mid_point = num_tokens // 2 ubatch_slices = [ UBatchSlice(slice(0, mid_point), slice(0, mid_point)), - UBatchSlice(slice(mid_point, num_tokens), slice(mid_point, - num_tokens)), + UBatchSlice(slice(mid_point, num_tokens), slice(mid_point, num_tokens)), ] results = split_attn_metadata(ubatch_slices, large_decode_metadata) @@ -159,26 +153,30 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) -def apply_split_decodes_and_prefills(query_lens: list[int], - decode_threshold: int, - require_uniform: bool): +def apply_split_decodes_and_prefills( + query_lens: list[int], decode_threshold: int, require_uniform: bool +): """Helper function to apply split_decodes_and_prefills and return the results.""" device = torch.device("cpu") seq_lens = [10 * (i + 1) for i in range(len(query_lens))] - common_metadata = create_common_attn_metadata(BatchSpec( - seq_lens=seq_lens, query_lens=query_lens), - block_size=16, - device=device) - return split_decodes_and_prefills(common_metadata, - decode_threshold=decode_threshold, - require_uniform=require_uniform) + common_metadata = create_common_attn_metadata( + BatchSpec(seq_lens=seq_lens, query_lens=query_lens), + block_size=16, + device=device, + ) + return split_decodes_and_prefills( + common_metadata, + decode_threshold=decode_threshold, + require_uniform=require_uniform, + ) def test_split_decodes_and_prefills_nonuniform_all_ones(): query_lens = [1, 1, 1] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - apply_split_decodes_and_prefills(query_lens, 1, False)) + apply_split_decodes_and_prefills(query_lens, 1, False) + ) assert num_decodes == 3 assert num_prefills == 0 assert num_decode_tokens == 3 @@ -188,7 +186,8 @@ def test_split_decodes_and_prefills_nonuniform_all_ones(): def test_split_decodes_and_prefills_nonuniform_all_short_decodes(): query_lens = [1, 2, 1, 3, 2, 1, 2] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - apply_split_decodes_and_prefills(query_lens, 3, False)) + apply_split_decodes_and_prefills(query_lens, 3, False) + ) assert num_decodes == 7 assert num_prefills == 0 assert num_decode_tokens == sum(query_lens) @@ -198,7 +197,8 @@ def test_split_decodes_and_prefills_nonuniform_all_short_decodes(): def test_split_decodes_and_prefills_nonuniform_all_prefills(): query_lens = [4, 5, 6, 7] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - apply_split_decodes_and_prefills(query_lens, 3, False)) + apply_split_decodes_and_prefills(query_lens, 3, False) + ) assert num_decodes == 0 assert num_prefills == 4 assert num_decode_tokens == 0 @@ -208,7 +208,8 @@ def test_split_decodes_and_prefills_nonuniform_all_prefills(): def test_split_decodes_and_prefills_nonuniform_mixed_batch(): query_lens = [2, 1, 3, 4, 5, 6, 7, 8] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - apply_split_decodes_and_prefills(query_lens, 4, False)) + apply_split_decodes_and_prefills(query_lens, 4, False) + ) assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4 assert num_prefills == 4 # 5, 6, 
7, 8 are all > 4 assert num_decode_tokens == 10 # 2 + 1 + 3 + 4 @@ -218,7 +219,8 @@ def test_split_decodes_and_prefills_nonuniform_mixed_batch(): def test_split_decodes_and_prefills_uniform_all_ones(): query_lens = [1, 1, 1] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - apply_split_decodes_and_prefills(query_lens, 1, True)) + apply_split_decodes_and_prefills(query_lens, 1, True) + ) assert num_decodes == 3 assert num_prefills == 0 assert num_decode_tokens == 3 @@ -228,7 +230,8 @@ def test_split_decodes_and_prefills_uniform_all_ones(): def test_split_decodes_and_prefills_uniform_all_short_decodes(): query_lens = [2, 2, 1, 3, 2, 1, 2] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - apply_split_decodes_and_prefills(query_lens, 3, True)) + apply_split_decodes_and_prefills(query_lens, 3, True) + ) assert num_decodes == 2 assert num_prefills == 5 assert num_decode_tokens == 4 @@ -238,7 +241,8 @@ def test_split_decodes_and_prefills_uniform_all_short_decodes(): def test_split_decodes_and_prefills_uniform_all_prefills(): query_lens = [4, 5, 6, 7] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - apply_split_decodes_and_prefills(query_lens, 3, True)) + apply_split_decodes_and_prefills(query_lens, 3, True) + ) assert num_decodes == 0 assert num_prefills == 4 assert num_decode_tokens == 0 @@ -248,7 +252,8 @@ def test_split_decodes_and_prefills_uniform_all_prefills(): def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes(): query_lens = [2, 2, 2, 4, 5, 6, 7, 8] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - apply_split_decodes_and_prefills(query_lens, 4, True)) + apply_split_decodes_and_prefills(query_lens, 4, True) + ) assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform assert num_prefills == 5 # 4, 5, 6, 7, 8 are all > 4 assert num_decode_tokens == 6 # 2 + 2 + 2 @@ -258,7 +263,8 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes(): def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes(): query_lens = [2, 1, 2, 4, 5, 6, 7, 8] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - apply_split_decodes_and_prefills(query_lens, 4, True)) + apply_split_decodes_and_prefills(query_lens, 4, True) + ) assert num_decodes == 1 # only the first 2 is taken as decode assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform assert num_decode_tokens == 2 # only the first 2 @@ -274,17 +280,15 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes(): ([32, 40], [8, 8], 4, 1, 2), ], ) -def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point, - expected_first_reqs, - expected_second_reqs): +def test_prefill_split_across_ubatches( + seq_lens, query_lens, split_point, expected_first_reqs, expected_second_reqs +): """Test splitting a prefill across ubatches""" import numpy as np device = torch.device("cpu") batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens) - common = create_common_attn_metadata(batch_spec, - block_size=16, - device=device) + common = create_common_attn_metadata(batch_spec, block_size=16, device=device) num_scheduled_tokens = np.array(query_lens, dtype=np.int32) qsl_np = common.query_start_loc_cpu.numpy() @@ -307,19 +311,19 @@ def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point, # Identify which request is split and how many tokens are in the first chunk split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1) 
tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx]) - orig_q_lens = (common.query_start_loc_cpu[1:] - - common.query_start_loc_cpu[:-1]) + orig_q_lens = common.query_start_loc_cpu[1:] - common.query_start_loc_cpu[:-1] # Check query length continuity: first-chunk + second-chunk == original qlen # First ubatch last request query length - qlen_first_last = int(first_meta.query_start_loc_cpu[-1] - - first_meta.query_start_loc_cpu[-2]) + qlen_first_last = int( + first_meta.query_start_loc_cpu[-1] - first_meta.query_start_loc_cpu[-2] + ) # Second ubatch first request query length - qlen_second_first = int(second_meta.query_start_loc_cpu[1] - - second_meta.query_start_loc_cpu[0]) + qlen_second_first = int( + second_meta.query_start_loc_cpu[1] - second_meta.query_start_loc_cpu[0] + ) assert qlen_first_last == tokens_in_first_chunk - assert qlen_first_last + qlen_second_first == int( - orig_q_lens[split_req_idx]) + assert qlen_first_last + qlen_second_first == int(orig_q_lens[split_req_idx]) # Check seq_lens adjustments # Context lengths per original request diff --git a/tests/v1/attention/test_chunked_local_attention.py b/tests/v1/attention/test_chunked_local_attention.py index be77256a0d2f..faace3473a28 100644 --- a/tests/v1/attention/test_chunked_local_attention.py +++ b/tests/v1/attention/test_chunked_local_attention.py @@ -7,8 +7,7 @@ import torch from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata -from vllm.v1.attention.backends.utils import ( - make_local_attention_virtual_batches) +from vllm.v1.attention.backends.utils import make_local_attention_virtual_batches @dataclass @@ -46,21 +45,24 @@ class LocalAttentionTestData: [17, 17], # local-batch 5, (batch 1, starting from k[16]) [20, 21], # local-batch 6, (batch 2, starting from k[4]) [22, 23], # local-batch 7, (batch 2, starting from k[8]) - ]), + ], + ), # Case where block indices are not clipped to block table ncols-1 # because tokens_in_last_block == attn_chunk_size - LocalAttentionTestData(batch_spec=BatchSpec( - query_lens=[8], - seq_lens=[12], + LocalAttentionTestData( + batch_spec=BatchSpec( + query_lens=[8], + seq_lens=[12], + ), + attn_chunk_size=4, + block_size=2, + expected_q_seqlens=[4, 4], + expected_k_seqlens=[4, 4], + expected_local_block_table=[ + [2, 3], + [4, 5], + ], ), - attn_chunk_size=4, - block_size=2, - expected_q_seqlens=[4, 4], - expected_k_seqlens=[4, 4], - expected_local_block_table=[ - [2, 3], - [4, 5], - ]), # Case where all kv_seq positions are involved in attn LocalAttentionTestData( batch_spec=BatchSpec( @@ -76,7 +78,8 @@ class LocalAttentionTestData: [0, 1], [2, 3], [4, 4], - ]), + ], + ), # Case where attn_chunk_size > kv_seq_len # so no extra mini virtual batches are created LocalAttentionTestData( @@ -97,7 +100,8 @@ class LocalAttentionTestData: # is calculated as (attn_chunk_size // block_size) expected_local_block_table=[ [0, 1, 2, 2, 2], - ]), + ], + ), # Block size equal to chunk size # Expect single page per batch in local batch table LocalAttentionTestData( @@ -118,7 +122,8 @@ class LocalAttentionTestData: [1], # local-batch 1, (batch 0, starting from k[4]) [2], # local-batch 1, (batch 0, starting from k[0]) [3], # local-batch 1, (batch 0, starting from k[4]) - ]), + ], + ), # Case where query falls in the second attention chunk # k_toks > 0 1 2 3 4 # q_toks v _____________ @@ -128,17 +133,19 @@ class LocalAttentionTestData: # 3 | 1 1 1 1 # 4 | 1 # where tokens 0,1,2,3 have been pre-computed - LocalAttentionTestData(batch_spec=BatchSpec( - query_lens=[1], - 
seq_lens=[5], + LocalAttentionTestData( + batch_spec=BatchSpec( + query_lens=[1], + seq_lens=[5], + ), + attn_chunk_size=4, + block_size=2, + expected_q_seqlens=[1], + expected_k_seqlens=[1], + expected_local_block_table=[ + [2, 2], + ], ), - attn_chunk_size=4, - block_size=2, - expected_q_seqlens=[1], - expected_k_seqlens=[1], - expected_local_block_table=[ - [2, 2], - ]), ] @@ -165,9 +172,9 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): ) # Call the function - result = make_local_attention_virtual_batches(attn_chunk_size, - common_attn_metadata, - block_size) + result = make_local_attention_virtual_batches( + attn_chunk_size, common_attn_metadata, block_size + ) # Convert to numpy for easier comparison actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy()) @@ -184,13 +191,11 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): np.testing.assert_array_equal(actual_q_seqlens, expected_q_seqlens) np.testing.assert_array_equal(actual_k_seqlens, expected_k_seqlens) - expected_block_table_tensor =\ - torch.tensor(expected_local_block_table, - dtype=torch.int32, - device=device) + expected_block_table_tensor = torch.tensor( + expected_local_block_table, dtype=torch.int32, device=device + ) print(f"Expected block table:\n{expected_block_table_tensor}") print(f"Actual block table:\n{result.block_table_tensor}") - torch.testing.assert_close(result.block_table_tensor, - expected_block_table_tensor) + torch.testing.assert_close(result.block_table_tensor, expected_block_table_tensor) diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index f2d0a5b2407a..debaa6a5e009 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -1,15 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for v1 MLA backends without GPUModelRunner dependency.""" + from typing import Optional, Union import pytest import torch -from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, - create_standard_kv_cache_spec, - create_vllm_config, - get_attention_backend) +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend, +) from vllm import _custom_ops as ops from vllm.attention.backends.registry import _Backend from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv @@ -17,13 +21,14 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.CUTLASS_MLA, _Backend.FLASHMLA, _Backend.FLASH_ATTN_MLA, - _Backend.TRITON_MLA + _Backend.CUTLASS_MLA, + _Backend.FLASHMLA, + _Backend.FLASH_ATTN_MLA, + _Backend.TRITON_MLA, ] # Remove CUTLASS_MLA from the list if not using sm100 -if not torch.cuda.is_available() or torch.cuda.get_device_properties( - 0).major < 10: +if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) torch.manual_seed(42) @@ -46,45 +51,41 @@ def _convert_dtype_to_torch(dtype): # Define common batch configurations BATCH_SPECS = { - "small_decode": - BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), - "small_prefill": - BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), - "mixed_small": - BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), - "medium_decode": - BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], - query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), - 
"medium_prefill": - BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), - "mixed_medium": - BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], - query_lens=[1, 1, 1, 7, 7, 7]), - "large_decode": - BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), - "large_prefill": - BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), - "single_decode": - BatchSpec(seq_lens=[1024], query_lens=[1]), - "single_prefill": - BatchSpec(seq_lens=[1024], query_lens=[64]), + "small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": BatchSpec( + seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1], + ), + "medium_prefill": BatchSpec( + seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16] + ), + "mixed_medium": BatchSpec( + seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7] + ), + "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), } def create_and_prepopulate_kv_cache( - kv_c_contexts: list[torch.Tensor], - k_pe_contexts: list[torch.Tensor], - block_size: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - num_blocks: int, - common_attn_metadata: CommonAttentionMetadata, - randomize_blocks: bool = True, - kv_cache_dtype: Optional[str] = None, - scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor: + kv_c_contexts: list[torch.Tensor], + k_pe_contexts: list[torch.Tensor], + block_size: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True, + kv_cache_dtype: Optional[str] = None, + scale: Union[float, torch.Tensor] = 1.0, +) -> torch.Tensor: """Create and prepopulate an MLA KV cache with context data. - + Args: kv_c_contexts: List of latent KV context tensors for each sequence k_pe_contexts: List of key positional embedding context tensors @@ -95,21 +96,23 @@ def create_and_prepopulate_kv_cache( device: Device to create the cache on num_blocks: Total number of blocks in the cache common_attn_metadata: Common attention metadata - randomize_blocks: Whether to randomly permute blocks + randomize_blocks: Whether to randomly permute blocks or use sequential order kv_cache_dtype: Optional kv cache dtype string. When set to "fp8_ds_mla" the cache is populated using the fp8 DeepSeek MLA layout via concat_and_cache_mla. scale: Scaling factor forwarded to concat_and_cache_mla when the fp8 cache layout is requested. 
- + Returns: MLA KV cache tensor """ batch_size = len(kv_c_contexts) seq_lens = common_attn_metadata.seq_lens_cpu - query_lens = common_attn_metadata.query_start_loc_cpu[ - 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + query_lens = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) context_lens = common_attn_metadata.num_computed_tokens_cpu block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping @@ -118,27 +121,26 @@ def create_and_prepopulate_kv_cache( if use_fp8_ds_mla: if not kv_c_contexts: - raise ValueError("kv_c_contexts cannot be empty when using" - " fp8_ds_mla cache dtype") + raise ValueError( + "kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype" + ) kv_lora_rank = kv_c_contexts[0].shape[-1] rope_dim = k_pe_contexts[0].shape[-1] entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim - kv_cache = torch.zeros(num_blocks, - block_size, - entry_size, - dtype=torch.uint8, - device=device) - scale_tensor = (scale - if isinstance(scale, torch.Tensor) else torch.tensor( - scale, dtype=torch.float32, device=device)) + kv_cache = torch.zeros( + num_blocks, block_size, entry_size, dtype=torch.uint8, device=device + ) + scale_tensor = ( + scale + if isinstance(scale, torch.Tensor) + else torch.tensor(scale, dtype=torch.float32, device=device) + ) scale_tensor = scale_tensor.to(device=device, dtype=torch.float32) else: # Create MLA KV cache: (num_blocks, block_size, head_size) - kv_cache = torch.empty(num_blocks, - block_size, - head_size, - dtype=dtype, - device=device) + kv_cache = torch.empty( + num_blocks, block_size, head_size, dtype=dtype, device=device + ) kv_cache_flat = kv_cache.view(-1, head_size) # Populate the cache with the context tokens @@ -154,8 +156,7 @@ def create_and_prepopulate_kv_cache( start = start_block_idx * block_size if use_fp8_ds_mla: - slots = torch.arange(context_len, device=device, - dtype=torch.long) + start + slots = torch.arange(context_len, device=device, dtype=torch.long) + start ops.concat_and_cache_mla( kv_c_context, k_pe_context.squeeze(1), @@ -165,8 +166,7 @@ def create_and_prepopulate_kv_cache( scale=scale_tensor, ) else: - kv_context = torch.cat( - [kv_c_context, k_pe_context.squeeze(1)], dim=-1) + kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) end = start + kv_context.shape[0] kv_cache_flat[start:end, ...] = kv_context @@ -177,15 +177,14 @@ def create_and_prepopulate_kv_cache( # Permute the context blocks (excluding block 0 which is null) if randomize_blocks: - perm = torch.randperm( - blocks_end - 1) + 1 # Random permutation starting from block 1 + perm = ( + torch.randperm(blocks_end - 1) + 1 + ) # Random permutation starting from block 1 else: - perm = torch.arange( - 1, blocks_end) # Sequential order starting from block 1 + perm = torch.arange(1, blocks_end) # Sequential order starting from block 1 inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) - inv_perm[1:] = torch.argsort( - perm) + 1 # Add 1 to account for starting from block 1 + inv_perm[1:] = torch.argsort(perm) + 1 # Add 1 to account for starting from block 1 kv_cache[1:blocks_end, ...] = kv_cache[perm, ...] 
# Construct the right block table @@ -206,8 +205,8 @@ def create_and_prepopulate_kv_cache( start = common_attn_metadata.query_start_loc_cpu[i] end = common_attn_metadata.query_start_loc_cpu[i + 1] slot_mapping[start:end] = block_table[ - i, - block_indices] * block_size + token_inter_block_offsets.to(device) + i, block_indices + ] * block_size + token_inter_block_offsets.to(device) return kv_cache @@ -221,15 +220,23 @@ def __init__(self, device: torch.device): self._v_scale = torch.tensor(1.0, device=device) -def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, - layer_names: list[str], vllm_config, - device: torch.device, - common_attn_metadata: CommonAttentionMetadata, - query: torch.Tensor, kv_c: torch.Tensor, - k_pe: torch.Tensor, kv_cache: torch.Tensor, - kv_lora_rank: int, qk_nope_head_dim: int, - qk_rope_head_dim: int, v_head_dim: int, - mock_kv_b_proj) -> torch.Tensor: +def run_attention_backend( + backend: _Backend, + kv_cache_spec: FullAttentionSpec, + layer_names: list[str], + vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, + kv_c: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + kv_lora_rank: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + mock_kv_b_proj, +) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" builder_cls, impl_cls = get_attention_backend(backend) @@ -243,9 +250,11 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, # Instantiate MLA implementation num_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() scale = 1.0 / (head_size**0.5) impl = impl_cls( @@ -275,30 +284,35 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, # Create mock layer and output buffer mock_layer = MockAttentionLayer(device) num_tokens = query.shape[0] - output = torch.empty(num_tokens, - num_heads * v_head_dim, - dtype=query.dtype, - device=query.device) + output = torch.empty( + num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device + ) # Run forward pass # NOTE: The query, key, and value are already shaped correctly # in the calling test function. - output = impl.forward(mock_layer, - query, - kv_c, - k_pe, - kv_cache, - attn_metadata, - output=output) + output = impl.forward( + mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output + ) return output -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium", "large_decode", "large_prefill", - "single_decode", "single_prefill" -]) +@pytest.mark.parametrize( + "batch_spec_name", + [ + "small_decode", + "small_prefill", + "mixed_small", + "medium_decode", + "medium_prefill", + "mixed_medium", + "large_decode", + "large_prefill", + "single_decode", + "single_prefill", + ], +) @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) def test_backend_correctness(dist_init, batch_spec_name: str, model: str): """ @@ -317,9 +331,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): 5. Comparing the vLLM backend's output to the ground-truth SDPA output. 
""" batch_spec = BATCH_SPECS[batch_spec_name] - vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens), - num_gpu_blocks=2048) + vllm_config = create_vllm_config( + model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048 + ) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -329,7 +343,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): seq_lens = batch_spec.seq_lens query_lens = batch_spec.query_lens num_q_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) block_size = vllm_config.cache_config.block_size @@ -338,8 +353,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): qk_nope_head_dim = 128 v_head_dim = 128 total_head_size = kv_lora_rank + qk_rope_head_dim - assert kv_lora_rank + qk_rope_head_dim == head_size, \ + assert kv_lora_rank + qk_rope_head_dim == head_size, ( f"MLA dimensions don't match: {total_head_size} != {head_size}" + ) scale = 1.0 / (total_head_size**0.5) # 2. Generate data and compute SDPA reference output for MLA @@ -348,16 +364,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): kv_c_contexts, k_pe_contexts = [], [] # Create shared MLA weight matrices for consistency across all sequences - W_UK = torch.randn(kv_lora_rank, - num_q_heads, - qk_nope_head_dim, - dtype=dtype, - device=device) - W_UV = torch.randn(kv_lora_rank, - num_q_heads, - v_head_dim, - dtype=dtype, - device=device) + W_UK = torch.randn( + kv_lora_rank, num_q_heads, qk_nope_head_dim, dtype=dtype, device=device + ) + W_UV = torch.randn( + kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device + ) kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) for i, backend in enumerate(BACKENDS_TO_TEST): @@ -371,24 +383,19 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # Generate MLA tensors # Q has both nope and rope components: # [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim] - q_c = torch.randn(q_len, - num_q_heads, - qk_nope_head_dim + qk_rope_head_dim, - dtype=dtype, - device=device) + q_c = torch.randn( + q_len, + num_q_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device, + ) # KV_C (latent K/V): [s_len, kv_lora_rank] - kv_c_full = torch.randn(s_len, - kv_lora_rank, - dtype=dtype, - device=device) + kv_c_full = torch.randn(s_len, kv_lora_rank, dtype=dtype, device=device) # K_PE (rope component): [s_len, 1, qk_rope_head_dim] - k_pe_full = torch.randn(s_len, - 1, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device) # Determine if this is decode or prefill is_decode = [] @@ -404,8 +411,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # Transform q_nope to latent space: q_nope @ W_UK # q_nope: [1, num_heads, qk_nope_head_dim] # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] - ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, - W_UK) # [1, num_heads, kv_lora_rank] + ql_nope = torch.einsum( + "qnh,lnh->qnl", q_nope, W_UK + ) # [1, num_heads, kv_lora_rank] # Build MQA attention inputs # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] @@ -431,25 +439,24 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): v_sdpa_in = 
v_mqa.unsqueeze(0).transpose(1, 2) sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale + ) sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze( - 0) # [1, num_heads, kv_lora_rank] + 0 + ) # [1, num_heads, kv_lora_rank] # Project back to output space: sdpa_out @ W_UV - sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, - W_UV) + sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, W_UV) sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2) ####################################################### # Prefill path: MHA-style attention with full sequence # Apply kv_b_proj to the full kv_c tensor kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight) - k_nope_full, v_full = kv_nope_full.split( - [qk_nope_head_dim, v_head_dim], dim=-1) + k_nope_full, v_full = kv_nope_full.split([qk_nope_head_dim, v_head_dim], dim=-1) # Build attention inputs for full sequence - q_mha = torch.cat([q_nope, q_pe], - dim=-1) # [q_len, num_heads, total_dim] + q_mha = torch.cat([q_nope, q_pe], dim=-1) # [q_len, num_heads, total_dim] k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) @@ -468,7 +475,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # Single attention call with custom mask sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale + ) sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0) sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) @@ -497,22 +505,25 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # Create mock kv_b_proj using the same weights as reference implementation from vllm.model_executor.layers.linear import ColumnParallelLinear - mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, - output_size=num_q_heads * - (qk_nope_head_dim + v_head_dim), - bias=False).to(device=device, - dtype=dtype) + + mock_kv_b_proj = ColumnParallelLinear( + input_size=kv_lora_rank, + output_size=num_q_heads * (qk_nope_head_dim + v_head_dim), + bias=False, + ).to(device=device, dtype=dtype) # Set the mock weights to match our reference implementation # Reshape W_UK and W_UV to match the expected kv_b_proj format # [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim] kv_b_proj_weight = kv_b_proj_weight.view( - kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)) + kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim) + ) mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) # Create metadata using original batch spec common_attn_metadata = create_common_attn_metadata( - batch_spec, vllm_config.cache_config.block_size, device) + batch_spec, vllm_config.cache_config.block_size, device + ) # 3. Simulate Paged KV Cache and a realistic slot_mapping kv_cache = create_and_prepopulate_kv_cache( @@ -524,41 +535,56 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): device=device, num_blocks=vllm_config.cache_config.num_gpu_blocks, common_attn_metadata=common_attn_metadata, - randomize_blocks=True) + randomize_blocks=True, + ) # 4. 
Run vLLM backends and compare for i, backend_name in enumerate(BACKENDS_TO_TEST): backend_output = run_attention_backend( - backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, - common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, - kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, - mock_kv_b_proj) + backend_name, + kv_cache_spec, + ["placeholder"], + vllm_config, + device, + common_attn_metadata, + query_vllm, + kv_c_vllm, + k_pe_vllm, + kv_cache, + kv_lora_rank, + qk_nope_head_dim, + qk_rope_head_dim, + v_head_dim, + mock_kv_b_proj, + ) # Check shape and dtype consistency assert backend_output.shape == sdpa_outputs[i].shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_outputs[i].shape}") + f"SDPA shape {sdpa_outputs[i].shape}" + ) assert backend_output.dtype == sdpa_outputs[i].dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_outputs[i].dtype}") + f"SDPA dtype {sdpa_outputs[i].dtype}" + ) assert torch.isfinite(backend_output).all(), ( - f"[{backend_name}] produced non-finite values") + f"[{backend_name}] produced non-finite values" + ) # Check numerical similarity rtol = 1e-2 atol = 5e-1 - max_diff = torch.max(torch.abs(backend_output - - sdpa_outputs[i])).item() + max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item() max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_outputs[i]) / - torch.abs(sdpa_outputs[i])).item() - all_close = torch.allclose(backend_output, - sdpa_outputs[i], - rtol=rtol, - atol=atol) + torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i]) + ).item() + all_close = torch.allclose( + backend_output, sdpa_outputs[i], rtol=rtol, atol=atol + ) assert all_close, ( f"[{backend_name}] output differs from SDPA baseline. 
" - f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})" + ) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index ddad9342fad0..f84951485310 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -10,18 +10,26 @@ import torch from tests.v1.attention.test_mla_backends import ( - BATCH_SPECS, BatchSpec, MockAttentionLayer, - create_and_prepopulate_kv_cache) -from tests.v1.attention.utils import (create_common_attn_metadata, - create_standard_kv_cache_spec, - create_vllm_config) + BATCH_SPECS, + BatchSpec, + MockAttentionLayer, + create_and_prepopulate_kv_cache, +) +from tests.v1.attention.utils import ( + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, +) from vllm import _custom_ops as ops from vllm.attention.ops import flashmla from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.utils import cdiv from vllm.v1.attention.backends.mla.flashmla_sparse import ( - FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata, - FlashMLASparseImpl, FlashMLASparseMetadata) + FlashMLASparseBackend, + FlashMLASparseDecodeAndContextMetadata, + FlashMLASparseImpl, + FlashMLASparseMetadata, +) from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks SPARSE_BACKEND_BATCH_SPECS = { @@ -35,41 +43,42 @@ ] } -SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2, - query_lens=[256] * 2) +SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec( + seq_lens=[1024] * 2, query_lens=[256] * 2 +) SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec( - seq_lens=[256] * 2, query_lens=[256] * 2) + seq_lens=[256] * 2, query_lens=[256] * 2 +) def _dequantize_fp8_ds_mla_entry( - cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, - dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: + cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, dtype: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: """Dequantize a single fp8_ds_mla cache entry back to latent + rope.""" # The first kv_lora_rank bytes store FP8 latent values with one scale per # 128 element tile written as float32 right after the latent payload. 
- scales = cache_slice.view(torch.float32)[kv_lora_rank // - 4:kv_lora_rank // 4 + 4] - latent = torch.empty(kv_lora_rank, - dtype=torch.float16, - device=cache_slice.device) + scales = cache_slice.view(torch.float32)[kv_lora_rank // 4 : kv_lora_rank // 4 + 4] + latent = torch.empty(kv_lora_rank, dtype=torch.float16, device=cache_slice.device) for tile_idx in range(4): tile_start = tile_idx * 128 tile_end = tile_start + 128 - ops.convert_fp8(latent[tile_start:tile_end], - cache_slice[tile_start:tile_end], - float(scales[tile_idx].item()), - kv_dtype="fp8") + ops.convert_fp8( + latent[tile_start:tile_end], + cache_slice[tile_start:tile_end], + float(scales[tile_idx].item()), + kv_dtype="fp8", + ) latent = latent.to(dtype) rope_offset = kv_lora_rank // 2 + 8 - rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim] + rope_vals = cache_slice.view(dtype)[rope_offset : rope_offset + rope_dim] return latent, rope_vals.clone() def _quantize_dequantize_fp8_ds_mla( - kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, - scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: """Round-trip kv_c/k_pe though the fp8_ds_mla cache layout.""" if kv_c.numel() == 0: @@ -81,21 +90,14 @@ def _quantize_dequantize_fp8_ds_mla( num_blocks = max(1, math.ceil(num_tokens / block_size)) entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim - tmp_cache = torch.zeros(num_blocks, - block_size, - entry_size, - dtype=torch.uint8, - device=kv_c.device) - slot_mapping = torch.arange(num_tokens, - dtype=torch.long, - device=kv_c.device) - - ops.concat_and_cache_mla(kv_c, - k_pe, - tmp_cache, - slot_mapping, - kv_cache_dtype="fp8_ds_mla", - scale=scale) + tmp_cache = torch.zeros( + num_blocks, block_size, entry_size, dtype=torch.uint8, device=kv_c.device + ) + slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_c.device) + + ops.concat_and_cache_mla( + kv_c, k_pe, tmp_cache, slot_mapping, kv_cache_dtype="fp8_ds_mla", scale=scale + ) dequant_kv_c = torch.empty_like(kv_c) dequant_k_pe = torch.empty_like(k_pe) @@ -106,7 +108,8 @@ def _quantize_dequantize_fp8_ds_mla( block_offset = slot % block_size cache_slice = tmp_cache[block_idx, block_offset] latent, rope_vals = _dequantize_fp8_ds_mla_entry( - cache_slice, kv_lora_rank, rope_dim, kv_c.dtype) + cache_slice, kv_lora_rank, rope_dim, kv_c.dtype + ) dequant_kv_c[token_idx] = latent dequant_k_pe[token_idx] = rope_vals @@ -123,10 +126,9 @@ def test_sparse_backend_metadata_registration(): dtype_list = backend.get_supported_dtypes() assert torch.bfloat16 in dtype_list - shape = backend.get_kv_cache_shape(num_blocks=2, - block_size=64, - num_kv_heads=1, - head_size=576) + shape = backend.get_kv_cache_shape( + num_blocks=2, block_size=64, num_kv_heads=1, head_size=576 + ) assert shape == (2, 64, 576) @@ -141,13 +143,10 @@ def test_sparse_decode_metadata_filters_prefill_indices(): indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32) - context_indices, new_token_indices = metadata.filter_prefill_indices( - indices) + context_indices, new_token_indices = metadata.filter_prefill_indices(indices) - expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], - dtype=torch.int32) - expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], - dtype=torch.int32) + expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], dtype=torch.int32) + expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], dtype=torch.int32) 
assert torch.equal(context_indices, expected_context) assert torch.equal(new_token_indices, expected_new_tokens) @@ -162,14 +161,9 @@ def test_sparse_impl_zero_fills_when_metadata_missing(): kv_cache = torch.zeros((1, 1, 1)) output = torch.ones((2, 4)) - result = FlashMLASparseImpl.forward(impl, - dummy_layer, - q, - k_c, - k_pe, - kv_cache, - attn_metadata=None, - output=output) + result = FlashMLASparseImpl.forward( + impl, dummy_layer, q, k_c, k_pe, kv_cache, attn_metadata=None, output=output + ) assert result is output assert torch.all(result == 0) @@ -177,8 +171,7 @@ def test_sparse_impl_zero_fills_when_metadata_missing(): @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) -def test_sparse_backend_decode_correctness(dist_init, batch_name, - kv_cache_dtype): +def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype): if not torch.cuda.is_available(): pytest.skip("CUDA is required for sparse MLA decode test") @@ -203,14 +196,13 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, vllm_config = create_vllm_config( model_name="deepseek-ai/DeepSeek-V2-Lite-Chat", max_model_len=max_seqlen, - num_gpu_blocks=max(2048, - cdiv(total_cache_tokens, block_size) + 1), - block_size=block_size) + num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1), + block_size=block_size, + ) model_config = vllm_config.model_config model_config.hf_config = SimpleNamespace( - attn_module_list_cfg=[{ - "topk_tokens": topk_tokens - }]) + attn_module_list_cfg=[{"topk_tokens": topk_tokens}] + ) model_config.hf_text_config = SimpleNamespace( q_lora_rank=None, kv_lora_rank=kv_lora_rank, @@ -221,13 +213,13 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, ) model_config.dtype = dtype model_config.get_num_attention_heads = MethodType( - lambda self, parallel_config: num_heads, model_config) - model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1, - model_config) - model_config.get_head_size = MethodType(lambda self: head_size, - model_config) - model_config.get_sliding_window = MethodType(lambda self: None, - model_config) + lambda self, parallel_config: num_heads, model_config + ) + model_config.get_num_kv_heads = MethodType( + lambda self, parallel_config: 1, model_config + ) + model_config.get_head_size = MethodType(lambda self: head_size, model_config) + model_config.get_sliding_window = MethodType(lambda self: None, model_config) kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -236,16 +228,10 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, scale = 1.0 / math.sqrt(head_size) # Shared MLA projection weights to keep reference and backend in sync - W_UK = torch.randn(kv_lora_rank, - num_heads, - qk_nope_head_dim, - dtype=dtype, - device=device) - W_UV = torch.randn(kv_lora_rank, - num_heads, - v_head_dim, - dtype=dtype, - device=device) + W_UK = torch.randn( + kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device + ) + W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device) # Build synthetic decode-only workload seq_lens = batch_spec.seq_lens @@ -262,17 +248,15 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, q_len = query_lens[i] ctx_len = s_len - q_len - q_c = torch.rand(q_len, - num_heads, - qk_nope_head_dim + qk_rope_head_dim, - dtype=dtype, - device=device) + q_c = torch.rand( + q_len, + num_heads, + qk_nope_head_dim + 
qk_rope_head_dim, + dtype=dtype, + device=device, + ) kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device) - k_pe_full = torch.rand(s_len, - 1, - qk_rope_head_dim, - dtype=dtype, - device=device) + k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device) kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla( kv_c_full, @@ -298,7 +282,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) sdpa_out = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale + ) sdpa_out = sdpa_out.transpose(1, 2).squeeze(0) sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV) @@ -307,8 +292,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, all_q_vllm.append(q_c) all_kv_c_vllm.append(kv_c_full[ctx_len:]) all_k_pe_vllm.append(k_pe_full[ctx_len:]) - kv_c_contexts.append(kv_c_full[:ctx_len + 1]) - k_pe_contexts.append(k_pe_full[:ctx_len + 1]) + kv_c_contexts.append(kv_c_full[: ctx_len + 1]) + k_pe_contexts.append(k_pe_full[: ctx_len + 1]) query_vllm = torch.cat(all_q_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) @@ -321,7 +306,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, batch_spec, vllm_config.cache_config.block_size, device, - arange_block_indices=True) + arange_block_indices=True, + ) kv_cache = create_and_prepopulate_kv_cache( kv_c_contexts=kv_c_contexts, @@ -339,31 +325,31 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, builder_cls = FlashMLASparseBackend.get_builder_cls() builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device) - metadata = builder.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + metadata = builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) - starts = np.asarray(common_attn_metadata.query_start_loc_cpu, - dtype=np.int32) + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) seg_lengths = np.diff(starts) positions = np.arange(starts[-1], dtype=np.int32) - np.repeat( - starts[:-1], seg_lengths) + starts[:-1], seg_lengths + ) seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32) prefix_lengths = seq_lengths - seg_lengths positions += np.repeat(prefix_lengths, seg_lengths) pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32) topk = metadata.topk_tokens - debug_indices = torch.arange(topk, device=device, - dtype=torch.int32).unsqueeze(0) + debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0) token_positions = pos_gpu.unsqueeze(1) - causal_mask = (debug_indices <= token_positions) - debug_indices = torch.where(causal_mask, debug_indices, - torch.full_like(debug_indices, -1)) + causal_mask = debug_indices <= token_positions + debug_indices = torch.where( + causal_mask, debug_indices, torch.full_like(debug_indices, -1) + ) # FlashMLASparseImpl now reads top-k indices from the indexer-provided # buffer, so emulate that contract with a simple namespace mock. 
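The causal clamp applied to debug_indices just above can be checked in isolation; a minimal sketch (illustrative, not part of the diff) before the namespace mock is built:

import torch

# Per-token top-k indices are clamped to -1 wherever they would point past
# the token's own position, keeping the debug indices causal.
topk = 4
positions = torch.tensor([0, 2, 5], dtype=torch.int32)              # token positions
debug_indices = torch.arange(topk, dtype=torch.int32).unsqueeze(0)  # [1, topk]
mask = debug_indices <= positions.unsqueeze(1)                      # [num_tokens, topk]
masked = torch.where(mask, debug_indices, torch.full_like(debug_indices, -1))
# The token at position 0 may only attend to index 0; later tokens see more.
assert masked.tolist() == [[0, -1, -1, -1], [0, 1, 2, -1], [0, 1, 2, 3]]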
- debug_indices = debug_indices.expand(metadata.num_actual_tokens, - -1).clone() + debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone() mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices) ok, reason = flashmla.is_flashmla_supported() @@ -372,59 +358,54 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) kv_b_proj_weight = kv_b_proj_weight.view( - kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)) + kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim) + ) - mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, - output_size=num_heads * - (qk_nope_head_dim + v_head_dim), - bias=False).to(device=device, - dtype=dtype) + mock_kv_b_proj = ColumnParallelLinear( + input_size=kv_lora_rank, + output_size=num_heads * (qk_nope_head_dim + v_head_dim), + bias=False, + ).to(device=device, dtype=dtype) mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous()) impl_cls = FlashMLASparseBackend.get_impl_cls() - impl = impl_cls(num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=1, - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype=vllm_config.cache_config.cache_dtype, - logits_soft_cap=None, - attn_type="decoder", - kv_sharing_target_layer_name=None, - q_lora_rank=None, - kv_lora_rank=kv_lora_rank, - qk_nope_head_dim=qk_nope_head_dim, - qk_rope_head_dim=qk_rope_head_dim, - qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, - v_head_dim=v_head_dim, - kv_b_proj=mock_kv_b_proj, - indexer=mock_indexer) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + indexer=mock_indexer, + ) impl.process_weights_after_loading(dtype) layer = MockAttentionLayer(device) - out_buffer = torch.empty(metadata.num_actual_tokens, - num_heads * v_head_dim, - dtype=dtype, - device=device) - - backend_output = impl.forward(layer, - query_vllm, - kv_c_vllm, - k_pe_vllm, - kv_cache, - metadata, - output=out_buffer) + out_buffer = torch.empty( + metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device + ) + + backend_output = impl.forward( + layer, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, metadata, output=out_buffer + ) assert backend_output.shape == sdpa_reference.shape assert backend_output.dtype == sdpa_reference.dtype assert torch.isfinite(backend_output).all() - torch.testing.assert_close(backend_output, - sdpa_reference, - rtol=0.5, - atol=0.5) + torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5) @pytest.mark.parametrize( diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 2bea45210ff3..f30a6628b1bf 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -9,9 +9,17 @@ import torch from vllm.attention.backends.registry import _Backend -from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, - LoadConfig, ModelConfig, ModelDType, ParallelConfig, - SchedulerConfig, VllmConfig) +from vllm.config import ( + CacheConfig, + CompilationConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ModelDType, + 
ParallelConfig, + SchedulerConfig, + VllmConfig, +) from vllm.platforms import current_platform from vllm.utils import resolve_obj_by_qualname from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -21,6 +29,7 @@ @dataclass class BatchSpec: """Specification for a batch configuration (workload shape only).""" + seq_lens: list[int] query_lens: list[int] @@ -38,26 +47,25 @@ def compute_num_tokens(self): def create_common_attn_metadata( - batch_spec: BatchSpec, - block_size: int, - device: torch.device, - max_block_idx: int = 1000, - arange_block_indices: bool = False) -> CommonAttentionMetadata: + batch_spec: BatchSpec, + block_size: int, + device: torch.device, + max_block_idx: int = 1000, + arange_block_indices: bool = False, +) -> CommonAttentionMetadata: """Create CommonAttentionMetadata from a BatchSpec and ModelParams.""" # Create query start locations - query_start_loc = torch.zeros(batch_spec.batch_size + 1, - dtype=torch.int32, - device=device) - query_start_loc[1:] = torch.tensor(batch_spec.query_lens, - dtype=torch.int32, - device=device).cumsum(0) + query_start_loc = torch.zeros( + batch_spec.batch_size + 1, dtype=torch.int32, device=device + ) + query_start_loc[1:] = torch.tensor( + batch_spec.query_lens, dtype=torch.int32, device=device + ).cumsum(0) query_start_loc_cpu = query_start_loc.cpu() num_tokens = batch_spec.compute_num_tokens() # Create sequence lengths - seq_lens = torch.tensor(batch_spec.seq_lens, - dtype=torch.int32, - device=device) + seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device) seq_lens_cpu = seq_lens.cpu() max_seq_len = int(seq_lens_cpu.max()) @@ -72,24 +80,23 @@ def create_common_attn_metadata( max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size if arange_block_indices: num_blocks = batch_spec.batch_size * max_blocks - block_table_tensor = torch.arange(num_blocks, - dtype=torch.int32, - device=device).view( - batch_spec.batch_size, - max_blocks) - slot_mapping = torch.arange(num_tokens, - dtype=torch.int64, - device=device).view(num_tokens) + block_table_tensor = torch.arange( + num_blocks, dtype=torch.int32, device=device + ).view(batch_spec.batch_size, max_blocks) + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device).view( + num_tokens + ) else: - block_table_tensor = torch.randint(0, - max_block_idx, - (batch_spec.batch_size, max_blocks), - dtype=torch.int32, - device=device) - slot_mapping = torch.randint(0, - max_block_idx, (num_tokens, ), - dtype=torch.int64, - device=device) + block_table_tensor = torch.randint( + 0, + max_block_idx, + (batch_spec.batch_size, max_blocks), + dtype=torch.int32, + device=device, + ) + slot_mapping = torch.randint( + 0, max_block_idx, (num_tokens,), dtype=torch.int64, device=device + ) # Calculate max query length max_query_len = max(batch_spec.query_lens) @@ -121,31 +128,21 @@ def get_attention_backend(backend_name: _Backend): Tuple of (backend_builder_class, backend_impl_class) """ backend_map = { - _Backend.FLASH_ATTN: - ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - if current_platform.is_cuda() else - "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" - ), - _Backend.FLASHINFER: - "vllm.v1.attention.backends.flashinfer.FlashInferBackend", - _Backend.FLEX_ATTENTION: - "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", - _Backend.TRITON_ATTN: - "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", - _Backend.TREE_ATTN: - 
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", - _Backend.XFORMERS: - "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", - _Backend.CUTLASS_MLA: - "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", - _Backend.FLASHMLA: - "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", - _Backend.FLASH_ATTN_MLA: - "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", - _Backend.FLASHINFER_MLA: - "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", - _Backend.TRITON_MLA: - "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", + _Backend.FLASH_ATTN: ( + "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + if current_platform.is_cuda() + else "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend" + ), + _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend", + _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501 + _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501 + _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", + _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501 + _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501 + _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", + _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501 + _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501 + _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501 } if backend_name not in backend_map: @@ -160,29 +157,31 @@ def get_attention_backend(backend_name: _Backend): pytest.skip(f"{backend_name} not available: {e}") -def create_standard_kv_cache_spec( - vllm_config: VllmConfig) -> FullAttentionSpec: +def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec: """Create a FullAttentionSpec from ModelParams only.""" return FullAttentionSpec( block_size=vllm_config.cache_config.block_size, num_kv_heads=vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config), + vllm_config.parallel_config + ), head_size=vllm_config.model_config.get_head_size(), dtype=vllm_config.model_config.dtype, sliding_window=vllm_config.model_config.get_sliding_window(), ) -def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", - tensor_parallel_size: int = 1, - max_model_len: int = 1024, - dtype: Union[ModelDType, torch.dtype] = "auto", - num_gpu_blocks: int = 1000, - block_size: int = 16, - max_num_seqs: int = 256, - max_num_batched_tokens: int = 8192, - enable_chunked_prefill: bool = True, - add_mock_model_methods: bool = True) -> VllmConfig: +def create_vllm_config( + model_name: str = "meta-llama/Meta-Llama-3-8B", + tensor_parallel_size: int = 1, + max_model_len: int = 1024, + dtype: Union[ModelDType, torch.dtype] = "auto", + num_gpu_blocks: int = 1000, + block_size: int = 16, + max_num_seqs: int = 256, + max_num_batched_tokens: int = 8192, + enable_chunked_prefill: bool = True, + add_mock_model_methods: bool = True, +) -> VllmConfig: """Create a VllmConfig for testing with reasonable defaults.""" model_config = ModelConfig( @@ -205,7 +204,8 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", cache_config.num_cpu_blocks = 0 parallel_config = 
ParallelConfig( - tensor_parallel_size=tensor_parallel_size, ) + tensor_parallel_size=tensor_parallel_size, + ) scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, @@ -223,15 +223,17 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", # but some backends expect to query the model for layer-specific # parameters import types - model_config.get_num_layers = types.MethodType(lambda self: 1, - model_config) + + model_config.get_num_layers = types.MethodType(lambda self: 1, model_config) model_config.get_sliding_window_for_layer = types.MethodType( - lambda self, i: None, model_config) + lambda self, i: None, model_config + ) model_config.get_logits_soft_cap_for_layer = types.MethodType( - lambda self, i: 0.0, model_config) + lambda self, i: 0.0, model_config + ) model_config.get_sm_scale_for_layer = types.MethodType( - lambda self, i: 1.0 / model_config.get_head_size()**0.5, - model_config) + lambda self, i: 1.0 / model_config.get_head_size() ** 0.5, model_config + ) return VllmConfig( model_config=model_config, @@ -244,12 +246,14 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ) -def create_dummy_kv_cache(block_size: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: +def create_dummy_kv_cache( + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int = 100, +) -> torch.Tensor: """Create a dummy KV cache tensor for testing.""" kv_cache = torch.randn( num_blocks, @@ -258,7 +262,8 @@ def create_dummy_kv_cache(block_size: int, num_kv_heads, head_size, dtype=dtype, - device=device) + device=device, + ) return kv_cache @@ -273,75 +278,80 @@ class BackendConfig: # Define all backend configurations of full cudagraph to be tested full_cg_backend_configs = { # FA3 on Hopper - "FA3": - BackendConfig(name="FA3", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", - "VLLM_FLASH_ATTN_VERSION": "3", - "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", - }, - comp_config={ - "cudagraph_mode": "FULL", - }, - specific_gpu_arch=(9, 0)), + "FA3": BackendConfig( + name="FA3", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", + "VLLM_FLASH_ATTN_VERSION": "3", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, + comp_config={ + "cudagraph_mode": "FULL", + }, + specific_gpu_arch=(9, 0), + ), # FlashMLA on Hopper - "FlashMLA": - BackendConfig(name="FlashMLA", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASHMLA", - }, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }, - specific_gpu_arch=(9, 0)), + "FlashMLA": BackendConfig( + name="FlashMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASHMLA", + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + specific_gpu_arch=(9, 0), + ), # Cutlass MLA on Blackwell - "CutlassMLA": - BackendConfig( + "CutlassMLA": BackendConfig( name="CutlassMLA", env_vars={ "VLLM_USE_V1": "1", "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA", - "FORCE_NUM_KV_SPLITS": - "1", # TODO: remove this when hang issue is fixed + "FORCE_NUM_KV_SPLITS": "1", # TODO: remove this when hang issue is fixed }, comp_config={ "cudagraph_mode": "FULL_AND_PIECEWISE", }, - specific_gpu_arch=(10, 0)), + specific_gpu_arch=(10, 0), + ), # FlashAttention MLA on Hopper - "FlashAttentionMLA": - BackendConfig(name="FlashAttentionMLA", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", - "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", - }, - comp_config={ - 
"cudagraph_mode": "FULL_DECODE_ONLY", - }, - specific_gpu_arch=(9, 0)), + "FlashAttentionMLA": BackendConfig( + name="FlashAttentionMLA", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, + comp_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + }, + specific_gpu_arch=(9, 0), + ), # FA2 - "FA2": - BackendConfig(name="FA2", - env_vars={ - "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", - "VLLM_FLASH_ATTN_VERSION": "2", - "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", - }, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), + "FA2": BackendConfig( + name="FA2", + env_vars={ + "VLLM_ATTENTION_BACKEND": "FLASH_ATTN", + "VLLM_FLASH_ATTN_VERSION": "2", + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16", + }, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + ), # Triton Attention - "TritonAttn": - BackendConfig(name="TritonAttn", - env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), + "TritonAttn": BackendConfig( + name="TritonAttn", + env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + ), # FlashInfer - "FlashInfer": - BackendConfig(name="FlashInfer", - env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, - comp_config={ - "cudagraph_mode": "FULL_AND_PIECEWISE", - }), + "FlashInfer": BackendConfig( + name="FlashInfer", + env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"}, + comp_config={ + "cudagraph_mode": "FULL_AND_PIECEWISE", + }, + ), } diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index 8ffe2e57b532..6d870b5640df 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -15,14 +15,12 @@ def _make_model_runner_output( - scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput: + scheduler_output: SchedulerOutput, +) -> ModelRunnerOutput: req_ids = list(scheduler_output.num_scheduled_tokens.keys()) return ModelRunnerOutput( req_ids=req_ids, - req_id_to_index={ - req_id: i - for i, req_id in enumerate(req_ids) - }, + req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)}, sampled_token_ids=[[i] for i in range(len(req_ids))], logprobs=None, prompt_logprobs_dict={}, @@ -75,8 +73,7 @@ def abort_request(): if not abort_order: return req = requests[abort_order.pop(0)] - scheduler.finish_requests(req.request_id, - RequestStatus.FINISHED_ABORTED) + scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED) while sched_outputs: # Abort a scheduled request. @@ -112,8 +109,7 @@ def abort_request(): if not abort_order: return req = requests[abort_order.pop(0)] - scheduler.finish_requests(req.request_id, - RequestStatus.FINISHED_ABORTED) + scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED) while sched_outputs: # Abort a scheduled request. 
@@ -135,15 +131,19 @@ def test_prefix_caching_for_prefill_dedup(): CHUNK_SIZE = 1000 BLOCK_SIZE = 16 num_prompt_tokens = 100 - scheduler = create_scheduler(async_scheduling=True, - max_num_batched_tokens=CHUNK_SIZE, - enable_prefix_caching=True, - block_size=BLOCK_SIZE) - requests = create_requests(num_requests=5, - num_tokens=num_prompt_tokens, - max_tokens=3, - same_prompt=True, - block_size=BLOCK_SIZE) + scheduler = create_scheduler( + async_scheduling=True, + max_num_batched_tokens=CHUNK_SIZE, + enable_prefix_caching=True, + block_size=BLOCK_SIZE, + ) + requests = create_requests( + num_requests=5, + num_tokens=num_prompt_tokens, + max_tokens=3, + same_prompt=True, + block_size=BLOCK_SIZE, + ) requests_copy = requests.copy() # Two requests with the same prompt. @@ -185,14 +185,18 @@ def test_prefix_caching_for_multi_turn(): BLOCK_SIZE = 16 num_prompt_tokens = 100 num_output_tokens = 200 - scheduler = create_scheduler(async_scheduling=True, - max_num_batched_tokens=CHUNK_SIZE, - enable_prefix_caching=True, - block_size=BLOCK_SIZE) - requests = create_requests(num_requests=5, - num_tokens=num_prompt_tokens, - max_tokens=num_output_tokens, - block_size=BLOCK_SIZE) + scheduler = create_scheduler( + async_scheduling=True, + max_num_batched_tokens=CHUNK_SIZE, + enable_prefix_caching=True, + block_size=BLOCK_SIZE, + ) + requests = create_requests( + num_requests=5, + num_tokens=num_prompt_tokens, + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE, + ) for req in requests: scheduler.add_request(req) @@ -212,14 +216,16 @@ def test_prefix_caching_for_multi_turn(): # Create next-turn requests whose prompts are the full output of the # previous turn. - next_turn_requests = create_requests(num_requests=5, - num_tokens=num_prompt_tokens + - num_output_tokens, - max_tokens=num_output_tokens, - block_size=BLOCK_SIZE) + next_turn_requests = create_requests( + num_requests=5, + num_tokens=num_prompt_tokens + num_output_tokens, + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE, + ) for i, req in enumerate(next_turn_requests): - req.prompt_token_ids = (requests[i].prompt_token_ids + - list(requests[i].output_token_ids)) + req.prompt_token_ids = requests[i].prompt_token_ids + list( + requests[i].output_token_ids + ) req._all_token_ids = req.prompt_token_ids.copy() req.all_token_ids = ConstantList(req._all_token_ids) req.block_hashes = [] @@ -233,5 +239,4 @@ def test_prefix_caching_for_multi_turn(): # Make sure the next-turn requests get prefix cache hit by the previous # requests. 
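The block-aligned expectation checked by the assert that follows comes from the fact that only whole blocks can be prefix-cached; with a 300-token next-turn prompt (100 prompt plus 200 generated tokens) and 16-token blocks, at most 288 tokens can hit. Illustrative arithmetic:

BLOCK_SIZE = 16
num_prompt_tokens = 300                        # e.g. 100 prompt + 200 generated tokens
aligned = num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE
assert aligned == 288                          # largest block-aligned prefix that can be reused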
for req in next_turn_requests: - assert (req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * - BLOCK_SIZE) + assert req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py index 6ef15b337ef0..8a52b5bd7897 100644 --- a/tests/v1/core/test_encoder_cache_manager.py +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -10,7 +10,6 @@ # ------------------ Mock Classes ------------------ # class MockRequest: - def __init__(self, request_id, mm_hashes, token_counts): self.request_id = request_id self._token_counts = token_counts @@ -20,8 +19,7 @@ def __init__(self, request_id, mm_hashes, token_counts): data=None, modality="image", identifier=mm_hash, - mm_position=PlaceholderRange(offset=0, - length=self._token_counts[i]), + mm_position=PlaceholderRange(offset=0, length=self._token_counts[i]), ) self.mm_features.append(feature) @@ -167,8 +165,7 @@ def test_schedule_request_multi_images_respect_space_limit(): num_tokens_to_schedule += req.get_num_encoder_tokens(0) compute_budget -= req.get_num_encoder_tokens(0) - assert not manager.can_allocate(req, 1, compute_budget, - num_tokens_to_schedule) + assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) def test_schedule_request_multi_images_respect_compute_limit(): @@ -180,5 +177,4 @@ def test_schedule_request_multi_images_respect_compute_limit(): num_tokens_to_schedule += req.get_num_encoder_tokens(0) compute_budget -= req.get_num_encoder_tokens(0) - assert not manager.can_allocate(req, 1, compute_budget, - num_tokens_to_schedule) + assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 09f43a793db2..aed00a60aeb4 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -8,30 +8,43 @@ import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.config import ModelConfig, SchedulerConfig, VllmConfig -from vllm.multimodal.inputs import (MultiModalFeatureSpec, - MultiModalKwargsItem, PlaceholderRange) +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes, sha256, sha256_cbor from vllm.v1.core.kv_cache_manager import KVCacheManager -# disable yapf here as it formats differently than isort such that both fail -# yapf: disable from vllm.v1.core.kv_cache_utils import ( - BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, - estimate_max_model_len, generate_block_hash_extra_keys, - generate_scheduler_kv_cache_config, get_kv_cache_configs, - get_max_concurrency_for_kv_cache_config, get_request_block_hasher, - hash_block_tokens, init_none_hash, is_kv_cache_spec_uniform, - make_block_hash_with_group_id) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, MLAAttentionSpec, - SlidingWindowSpec, - UniformTypeKVCacheSpecs) + BlockHash, + FreeKVCacheBlockQueue, + KVCacheBlock, + PrefixCachingMetrics, + estimate_max_model_len, + generate_block_hash_extra_keys, + generate_scheduler_kv_cache_config, + get_kv_cache_configs, + get_max_concurrency_for_kv_cache_config, + get_request_block_hasher, + hash_block_tokens, + init_none_hash, + is_kv_cache_spec_uniform, + make_block_hash_with_group_id, +) +from vllm.v1.kv_cache_interface import ( 
+ FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + KVCacheTensor, + MLAAttentionSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request -# yapf: enable - pytestmark = pytest.mark.cpu_test @@ -62,42 +75,49 @@ def make_request( data=MultiModalKwargsItem.dummy("dummy_m"), mm_position=position, identifier=identifier, - modality="image") + modality="image", + ) mm_features.append(mm_feature) - return Request(request_id=request_id, - prompt_token_ids=prompt_token_ids, - mm_features=mm_features if mm_features else None, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - block_hasher=get_request_block_hasher(block_size, hash_fn)) - - -def new_kv_cache_spec(block_size=16, - num_kv_heads=2, - head_size=64, - dtype=torch.float32, - sliding_window=None): - return FullAttentionSpec(block_size=block_size, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - sliding_window=sliding_window) - - -def new_sliding_window_spec(block_size=16, - num_kv_heads=2, - head_size=64, - dtype=torch.float32, - sliding_window=1): - return SlidingWindowSpec(block_size=block_size, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - sliding_window=sliding_window) + return Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + mm_features=mm_features if mm_features else None, + sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn), + ) + + +def new_kv_cache_spec( + block_size=16, + num_kv_heads=2, + head_size=64, + dtype=torch.float32, + sliding_window=None, +): + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + sliding_window=sliding_window, + ) + + +def new_sliding_window_spec( + block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, sliding_window=1 +): + return SlidingWindowSpec( + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + sliding_window=sliding_window, + ) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @@ -106,7 +126,7 @@ def test_none_hash(monkeypatch, hash_fn): # case 1: PYTHONHASHSEED is not set, use random with monkeypatch.context() as m: - m.delenv('PYTHONHASHSEED', raising=False) + m.delenv("PYTHONHASHSEED", raising=False) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None @@ -115,16 +135,15 @@ def test_none_hash(monkeypatch, hash_fn): # case 2: PYTHONHASHSEED is set, use the seed and hash_fn with monkeypatch.context() as m: - m.setenv('PYTHONHASHSEED', 'python hash seed') + m.setenv("PYTHONHASHSEED", "python hash seed") reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) - assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH + assert hash_fn("python hash seed") == reloaded_kv_cache_utils.NONE_HASH def test_kv_cache_block(): - # Test KVCacheBlock initialization block = KVCacheBlock(block_id=0) assert block.block_id == 0 @@ -192,10 +211,8 @@ def 
test_free_kv_cache_block_queue_operations(): for _ in range(4): queue.popleft() assert queue.num_free_blocks == 0 - assert (queue.fake_free_list_head.next_free_block - is queue.fake_free_list_tail) - assert (queue.fake_free_list_tail.prev_free_block - is queue.fake_free_list_head) + assert queue.fake_free_list_head.next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is queue.fake_free_list_head # Attempt to pop from an empty queue with pytest.raises(ValueError) as e: @@ -211,10 +228,8 @@ def test_free_kv_cache_block_queue_append_n(): # fake_head->fake_tail queue.append_n([]) assert queue.num_free_blocks == 0 - assert (queue.fake_free_list_head.next_free_block - is queue.fake_free_list_tail) - assert (queue.fake_free_list_tail.prev_free_block - is queue.fake_free_list_head) + assert queue.fake_free_list_head.next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is queue.fake_free_list_head # Append 1 block # fake_head->b0->fake_tail queue.append_n(blocks[0:1]) @@ -263,15 +278,18 @@ def test_free_kv_cache_block_queue_append_n(): # fake_head->fake_tail invalid_queue.append_n(blocks[0:1]) assert invalid_queue.num_free_blocks == 0 - assert (invalid_queue.fake_free_list_head.next_free_block == - invalid_queue.fake_free_list_tail) + assert ( + invalid_queue.fake_free_list_head.next_free_block + == invalid_queue.fake_free_list_tail + ) def test_free_kv_cache_block_queue_popleft_n(): blocks = [KVCacheBlock(block_id=i) for i in range(6)] # Create an empty FreeKVCacheBlockQueue with these blocks queue = FreeKVCacheBlockQueue( - [blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]]) + [blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]] + ) assert queue.num_free_blocks == 6 assert queue.fake_free_list_head.next_free_block is blocks[1] assert blocks[1].prev_free_block is queue.fake_free_list_head @@ -345,8 +363,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks(): # Append a block back and check again queue.append(block_to_remove) - assert queue.get_all_free_blocks() == \ - blocks[1:2] + blocks[3:] + [block_to_remove] + assert queue.get_all_free_blocks() == blocks[1:2] + blocks[3:] + [block_to_remove] def test_generate_block_hash_extra_keys(): @@ -362,12 +379,12 @@ def test_generate_block_hash_extra_keys(): # Test with no extra keys extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0) - assert extra_keys == ("hash1", ) + assert extra_keys == ("hash1",) assert next_mm_idx == 1 # Test with partial overlap extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 3, 8, 0) - assert extra_keys == ("hash1", ) + assert extra_keys == ("hash1",) assert next_mm_idx == 1 # Test with no overlap @@ -377,7 +394,7 @@ def test_generate_block_hash_extra_keys(): # Test with multiple extra keys extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 15, 0) - assert extra_keys == ('hash1', 'hash2') + assert extra_keys == ("hash1", "hash2") assert next_mm_idx == 2 @@ -405,9 +422,9 @@ def test_generate_block_hash_extra_keys_cache_salt(): # salt is added for the first token extra_keys, _ = generate_block_hash_extra_keys(request, 0, 1, 0) - assert extra_keys == ('salt', ) + assert extra_keys == ("salt",) extra_keys, _ = generate_block_hash_extra_keys(request, 0, 10, 0) - assert extra_keys == ('salt', ) + assert extra_keys == ("salt",) # no salt added for other tokens extra_keys, _ = generate_block_hash_extra_keys(request, 1, 2, 0) @@ -427,8 +444,7 @@ def 
test_generate_block_hash_extra_keys_cache_salt(): ) # Test with no extra keys - extra_keys, next_mm_idx = generate_block_hash_extra_keys( - request_mm, 0, 5, 0) + extra_keys, next_mm_idx = generate_block_hash_extra_keys(request_mm, 0, 5, 0) assert extra_keys == ("hash1", "salt") assert next_mm_idx == 1 @@ -439,8 +455,9 @@ def test_hash_block_tokens(hash_fn): curr_block_token_ids = (1, 2, 3) extra_keys = ("key1", "key2") - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - curr_block_token_ids, extra_keys) + block_hash = hash_block_tokens( + hash_fn, parent_block_hash, curr_block_token_ids, extra_keys + ) expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys)) assert block_hash == expected @@ -461,10 +478,8 @@ def test_request_block_hasher(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert block_hashes[0] == hash_fn( - (kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", ))) - assert block_hashes[1] == hash_fn( - (block_hashes[0], (3, 4, 5), ("hash2", ))) + assert block_hashes[0] == hash_fn((kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1",))) + assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), ("hash2",))) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @@ -509,8 +524,7 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert block_hashes[0] == hash_fn( - (kv_cache_utils.NONE_HASH, (0, 1, 2), None)) + assert block_hashes[0] == hash_fn((kv_cache_utils.NONE_HASH, (0, 1, 2), None)) assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None)) @@ -587,27 +601,36 @@ def test_get_kv_cache_configs_multiple_workers(): vllm_config = VllmConfig(model_config=model_config) ref_kv_cache_spec = new_kv_cache_spec() - same_kv_cache_specs = [{ - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }, { - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }] + same_kv_cache_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + ] # Basic case. All things are the same. 
- kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [ - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 10 - ]) + kv_cache_configs = get_kv_cache_configs( + vllm_config, + same_kv_cache_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -616,10 +639,12 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -629,18 +654,24 @@ def test_get_kv_cache_configs_multiple_workers(): # Different available memory. This is the case for TP. # Use the smallest memory available. - kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [ - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 20 - ]) + kv_cache_configs = get_kv_cache_configs( + vllm_config, + same_kv_cache_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 20, + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -649,10 +680,12 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -661,25 +694,32 @@ def test_get_kv_cache_configs_multiple_workers(): ] # Different KV cache specs. This is the case for PP. - different_layer_specs = [{ - "layer1": new_kv_cache_spec(), - }, { - "layer2": new_kv_cache_spec(), - "layer3": new_kv_cache_spec(), - }] + different_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + }, + { + "layer2": new_kv_cache_spec(), + "layer3": new_kv_cache_spec(), + }, + ] # Different workers have different layers. 
kv_cache_configs = get_kv_cache_configs( - vllm_config, different_layer_specs, [ + vllm_config, + different_layer_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 10 - ]) + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer1"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer1"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), @@ -688,10 +728,12 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer3"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer2", "layer3"], new_kv_cache_spec()), @@ -700,33 +742,43 @@ def test_get_kv_cache_configs_multiple_workers(): ] # Some layers are the same, some are different. This is the case for TP+PP - tp_pp_kv_cache_specs = [{ - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }, { - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }, { - "layer3": new_kv_cache_spec(), - }, { - "layer3": new_kv_cache_spec(), - }] + tp_pp_kv_cache_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer3": new_kv_cache_spec(), + }, + { + "layer3": new_kv_cache_spec(), + }, + ] kv_cache_configs = get_kv_cache_configs( - vllm_config, tp_pp_kv_cache_specs, [ + vllm_config, + tp_pp_kv_cache_specs, + [ ref_kv_cache_spec.page_size_bytes * 2 * 10, ref_kv_cache_spec.page_size_bytes * 2 * 10, ref_kv_cache_spec.page_size_bytes * 2 * 10, ref_kv_cache_spec.page_size_bytes * 2 * 10, - ]) + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -735,10 +787,12 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -747,8 +801,9 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer3"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer3"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer3"], 
ref_kv_cache_spec), @@ -757,8 +812,9 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, - shared_by=["layer3"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 20, shared_by=["layer3"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), @@ -768,26 +824,34 @@ def test_get_kv_cache_configs_multiple_workers(): # Different workers have different types of layers. This is the case for # hybrid models + PP. - different_type_layer_specs = [{ - "layer1": new_kv_cache_spec(), - "layer2": new_kv_cache_spec(), - }, { - "layer3": new_sliding_window_spec(), - "layer4": new_sliding_window_spec(), - }] + different_type_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, + { + "layer3": new_sliding_window_spec(), + "layer4": new_sliding_window_spec(), + }, + ] kv_cache_configs = get_kv_cache_configs( - vllm_config, different_type_layer_specs, [ + vllm_config, + different_type_layer_specs, + [ ref_kv_cache_spec.page_size_bytes * 2 * 10, ref_kv_cache_spec.page_size_bytes * 2 * 10, - ]) + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer2"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer1"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), @@ -797,41 +861,50 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer3"]), - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer4"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer3"] + ), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, shared_by=["layer4"] + ), ], kv_cache_groups=[ KVCacheGroupSpec([], ref_kv_cache_spec), - KVCacheGroupSpec(["layer3", "layer4"], - new_sliding_window_spec()), + KVCacheGroupSpec(["layer3", "layer4"], new_sliding_window_spec()), ], ), ] # When divided into multiple KVCacheGroups, need to ensure the number of # layers per group is similar. 
- different_type_layer_specs = [{ - "layer1": new_kv_cache_spec(), - "layer2": new_sliding_window_spec(), - "layer3": new_sliding_window_spec(), - }, { - "layer4": new_kv_cache_spec(), - "layer5": new_sliding_window_spec(), - "layer6": new_sliding_window_spec(), - }] + different_type_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + "layer2": new_sliding_window_spec(), + "layer3": new_sliding_window_spec(), + }, + { + "layer4": new_kv_cache_spec(), + "layer5": new_sliding_window_spec(), + "layer6": new_sliding_window_spec(), + }, + ] kv_cache_configs = get_kv_cache_configs( - vllm_config, different_type_layer_specs, [ + vllm_config, + different_type_layer_specs, + [ ref_kv_cache_spec.page_size_bytes * 10, ref_kv_cache_spec.page_size_bytes * 10, - ]) + ], + ) assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer1", "layer2", "layer3"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1", "layer2", "layer3"], + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer1"], ref_kv_cache_spec), @@ -842,8 +915,10 @@ def test_get_kv_cache_configs_multiple_workers(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, - shared_by=["layer4", "layer5", "layer6"]), + KVCacheTensor( + size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer4", "layer5", "layer6"], + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer4"], ref_kv_cache_spec), @@ -854,16 +929,23 @@ def test_get_kv_cache_configs_multiple_workers(): ] # Have conflicting layers. Need to raise an error. - conflicting_layer_specs = [{ - "layer1": new_kv_cache_spec(), - }, { - "layer1": new_sliding_window_spec(), - }] + conflicting_layer_specs = [ + { + "layer1": new_kv_cache_spec(), + }, + { + "layer1": new_sliding_window_spec(), + }, + ] with pytest.raises(AssertionError): - get_kv_cache_configs(vllm_config, conflicting_layer_specs, [ - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ref_kv_cache_spec.page_size_bytes * 2 * 10, - ]) + get_kv_cache_configs( + vllm_config, + conflicting_layer_specs, + [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ], + ) def test_merge_kv_cache_spec(): @@ -908,14 +990,16 @@ def test_merge_kv_cache_spec(): ] with pytest.raises(ValueError): different_sliding_window_layer_specs[0].merge( - different_sliding_window_layer_specs) + different_sliding_window_layer_specs + ) same_sliding_window_layer_specs = [ new_kv_cache_spec(num_kv_heads=32, sliding_window=1), new_kv_cache_spec(num_kv_heads=32, sliding_window=1), ] merged_layer_spec = same_sliding_window_layer_specs[0].merge( - same_sliding_window_layer_specs) + same_sliding_window_layer_specs + ) assert merged_layer_spec.sliding_window == 1 same_sliding_window_layer_spec_with_none = [ @@ -923,7 +1007,8 @@ def test_merge_kv_cache_spec(): new_kv_cache_spec(num_kv_heads=32, sliding_window=None), ] merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge( - same_sliding_window_layer_spec_with_none) + same_sliding_window_layer_spec_with_none + ) assert merged_layer_spec.sliding_window == 1 @@ -960,12 +1045,13 @@ def test_is_kv_cache_spec_uniform(): @pytest.mark.parametrize( - ("model_id", "max_model_len", "want_estimated_max_len"), [ + ("model_id", "max_model_len", "want_estimated_max_len"), + [ ("Qwen/Qwen1.5-7B", 16385, 16384), ("Qwen/Qwen1.5-7B", 16383, 16383), - ]) -def test_estimate_max_model_len(model_id, 
max_model_len, - want_estimated_max_len): + ], +) +def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len): # Create a VllmConfig model_config = ModelConfig( model_id, @@ -991,8 +1077,9 @@ def test_estimate_max_model_len(model_id, max_model_len, dtype=torch.float16, ) # Estimate the maximum model length, 16384 model_len need 8GB - estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, - 8 * GiB_bytes) + estimated_max_len = estimate_max_model_len( + vllm_config, kv_cache_spec, 8 * GiB_bytes + ) assert estimated_max_len == want_estimated_max_len @@ -1006,8 +1093,9 @@ def test_get_max_concurrency_for_kv_cache_config(): dtype="float16", max_model_len=max_model_len, ) - scheduler_config = SchedulerConfig(max_num_batched_tokens=1024, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=1024, enable_chunked_prefill=True + ) vllm_config = VllmConfig( model_config=model_config, @@ -1033,38 +1121,39 @@ def test_get_max_concurrency_for_kv_cache_config(): num_blocks=int(1024 * 1.5), kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec([f"layer_{i}" for i in range(32)], - full_attention_spec), + KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), ], ) max_concurrency_full_attention = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config_full_attention) + vllm_config, kv_cache_config_full_attention + ) assert max_concurrency_full_attention == 1.5 kv_cache_config_sliding_window = KVCacheConfig( num_blocks=129 * 3, kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec([f"layer_{i}" for i in range(32)], - sliding_window_spec), + KVCacheGroupSpec([f"layer_{i}" for i in range(32)], sliding_window_spec), ], ) max_concurrency_sliding_window = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config_sliding_window) + vllm_config, kv_cache_config_sliding_window + ) assert max_concurrency_sliding_window == 3 kv_cache_config_hybrid_model = KVCacheConfig( num_blocks=(1024 + 129) * 3, kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec([f"layer_{i}" for i in range(32)], - full_attention_spec), - KVCacheGroupSpec([f"layer_{i}" for i in range(32, 64)], - sliding_window_spec), + KVCacheGroupSpec([f"layer_{i}" for i in range(32)], full_attention_spec), + KVCacheGroupSpec( + [f"layer_{i}" for i in range(32, 64)], sliding_window_spec + ), ], ) max_concurrency_hybrid_model = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config_hybrid_model) + vllm_config, kv_cache_config_hybrid_model + ) assert max_concurrency_hybrid_model == 3 @@ -1077,8 +1166,7 @@ def test_allocate_with_lookahead(): KVCacheTensor(size=100, shared_by=["layer1"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], - new_kv_cache_spec(block_size=block_size)), + KVCacheGroupSpec(["layer1"], new_kv_cache_spec(block_size=block_size)), ], ) @@ -1091,8 +1179,7 @@ def test_allocate_with_lookahead(): ) # Test case 1: Requires additional lookahead tokens - kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100) + kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -1101,8 +1188,7 @@ def test_allocate_with_lookahead(): assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks # Test case 2: With precomputed blocks - kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100) + kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) # 
required_blocks = ceil((3 + 2) /4) = 2 blocks = kv_cache_manager.allocate_slots( request, @@ -1113,8 +1199,7 @@ def test_allocate_with_lookahead(): # Test case 3: With precomputed blocks # required_blocks = ceil((3 + 4) / 4) = 2 - kv_cache_manager = KVCacheManager(kv_cache_config=config, - max_model_len=100) + kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -1131,82 +1216,78 @@ def test_get_kv_cache_config_one_worker(): mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 # all layers are full attention -> single group kv_cache_specs_full = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), } kv_cache_config_full = get_kv_cache_configs( - vllm_config, [kv_cache_specs_full], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_full], [mem_per_block_per_layer * 2 * 32] + )[0] print(kv_cache_config_full) assert kv_cache_config_full == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], + ) # all layers are sliding window -> single group kv_cache_specs_sliding = { - 'layer_1': new_sliding_window_spec(), - 'layer_2': new_sliding_window_spec(), + "layer_1": new_sliding_window_spec(), + "layer_2": new_sliding_window_spec(), } kv_cache_config_sliding = get_kv_cache_configs( - vllm_config, [kv_cache_specs_sliding], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_sliding], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_sliding == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec()) - ]) + ], + ) # full + sliding, but disable_hybrid_kv_cache_manager vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = True kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], - new_kv_cache_spec(sliding_window=1)), + KVCacheGroupSpec( + ["layer_1", 
"layer_2"], new_kv_cache_spec(sliding_window=1) + ), ], ) vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False # full + sliding, with hybrid_kv_cache_manager kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=64, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 64, - shared_by=["layer_1", "layer_2"]), + KVCacheTensor( + size=mem_per_block_per_layer * 64, shared_by=["layer_1", "layer_2"] + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1"], new_kv_cache_spec()), @@ -1216,112 +1297,113 @@ def test_get_kv_cache_config_one_worker(): # 2 full + 4 sliding, 2 layers per group kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(), - 'layer_3': new_sliding_window_spec(), - 'layer_4': new_sliding_window_spec(), - 'layer_5': new_sliding_window_spec(), - 'layer_6': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), + "layer_3": new_sliding_window_spec(), + "layer_4": new_sliding_window_spec(), + "layer_5": new_sliding_window_spec(), + "layer_6": new_sliding_window_spec(), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_3", "layer_4"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_5", "layer_6"]), + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_1", "layer_3", "layer_4"], + ), + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_5", "layer_6"], + ), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer_3", "layer_5"], - new_sliding_window_spec()), - KVCacheGroupSpec(["layer_4", "layer_6"], - new_sliding_window_spec()), + KVCacheGroupSpec(["layer_3", "layer_5"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer_4", "layer_6"], new_sliding_window_spec()), ], ) # 3 full + 7 sliding, pad to 3 full + 9 sliding kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(), - 'layer_3': new_kv_cache_spec(), - 'layer_4': new_sliding_window_spec(), - 'layer_5': new_sliding_window_spec(), - 'layer_6': new_sliding_window_spec(), - 'layer_7': new_sliding_window_spec(), - 'layer_8': new_sliding_window_spec(), - 'layer_9': new_sliding_window_spec(), - 'layer_10': new_sliding_window_spec(), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(), + "layer_3": new_kv_cache_spec(), + "layer_4": new_sliding_window_spec(), + "layer_5": new_sliding_window_spec(), + "layer_6": new_sliding_window_spec(), + "layer_7": new_sliding_window_spec(), + "layer_8": new_sliding_window_spec(), + "layer_9": new_sliding_window_spec(), + "layer_10": new_sliding_window_spec(), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 3 * 32])[0] + vllm_config, 
[kv_cache_specs_hybrid], [mem_per_block_per_layer * 3 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ KVCacheTensor( size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_4", "layer_5", "layer_6"]), + shared_by=["layer_1", "layer_4", "layer_5", "layer_6"], + ), KVCacheTensor( size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_7", "layer_8", "layer_9"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_3", "layer_10"]), + shared_by=["layer_2", "layer_7", "layer_8", "layer_9"], + ), + KVCacheTensor( + size=mem_per_block_per_layer * 32, shared_by=["layer_3", "layer_10"] + ), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], - new_kv_cache_spec()), - KVCacheGroupSpec(["layer_4", "layer_7", "layer_10"], - new_sliding_window_spec()), - KVCacheGroupSpec(["layer_5", "layer_8"], - new_sliding_window_spec()), - KVCacheGroupSpec(["layer_6", "layer_9"], - new_sliding_window_spec()), + KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], new_kv_cache_spec()), + KVCacheGroupSpec( + ["layer_4", "layer_7", "layer_10"], new_sliding_window_spec() + ), + KVCacheGroupSpec(["layer_5", "layer_8"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer_6", "layer_9"], new_sliding_window_spec()), ], ) # different hidden size kv_cache_specs_hybrid = { - 'layer_1': new_kv_cache_spec(head_size=128), - 'layer_2': new_kv_cache_spec(head_size=64), + "layer_1": new_kv_cache_spec(head_size=128), + "layer_2": new_kv_cache_spec(head_size=64), } kv_cache_config_hybrid = get_kv_cache_configs( - vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 3 * 32])[0] + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 3 * 32] + )[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 32 * 2, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 32 * 2, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, shared_by=["layer_2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], - UniformTypeKVCacheSpecs( - block_size=16, - kv_cache_specs=kv_cache_specs_hybrid)) - ]) + KVCacheGroupSpec( + ["layer_1", "layer_2"], + UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs_hybrid + ), + ) + ], + ) # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 kv_cache_config_override_blocks = get_kv_cache_configs( - vllm_config, [kv_cache_specs_full], - [mem_per_block_per_layer * 2 * 32])[0] + vllm_config, [kv_cache_specs_full], [mem_per_block_per_layer * 2 * 32] + )[0] assert kv_cache_config_override_blocks == KVCacheConfig( num_blocks=16, kv_cache_tensors=[ - KVCacheTensor(size=mem_per_block_per_layer * 16, - shared_by=["layer_1"]), - KVCacheTensor(size=mem_per_block_per_layer * 16, - shared_by=["layer_2"]), + KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 16, shared_by=["layer_2"]), ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], + ) def test_get_kv_cache_configs_attention_free(): @@ -1340,42 +1422,44 @@ def test_get_kv_cache_configs_attention_free(): def test_generate_uniform_type_kv_cache_specs(): # All layers are full attention, can be merged 
kv_cache_specs = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(head_size=128), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(head_size=128), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec == UniformTypeKVCacheSpecs( - block_size=16, kv_cache_specs=kv_cache_specs) + block_size=16, kv_cache_specs=kv_cache_specs + ) # Full attention + sliding window, cannot be merged kv_cache_specs = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_sliding_window_spec(sliding_window=1), + "layer_1": new_kv_cache_spec(), + "layer_2": new_sliding_window_spec(sliding_window=1), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec is None # different order of full attention + sliding window, cannot be merged kv_cache_specs = { - 'layer_1': new_sliding_window_spec(sliding_window=1), - 'layer_2': new_kv_cache_spec(), + "layer_1": new_sliding_window_spec(sliding_window=1), + "layer_2": new_kv_cache_spec(), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec is None # Same-size sliding window, can be merged kv_cache_specs = { - 'layer_1': new_sliding_window_spec(sliding_window=1), - 'layer_2': new_sliding_window_spec(sliding_window=1, head_size=128), + "layer_1": new_sliding_window_spec(sliding_window=1), + "layer_2": new_sliding_window_spec(sliding_window=1, head_size=128), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec == UniformTypeKVCacheSpecs( - block_size=16, kv_cache_specs=kv_cache_specs) + block_size=16, kv_cache_specs=kv_cache_specs + ) # different block sizes, cannot be merged kv_cache_specs = { - 'layer_1': new_kv_cache_spec(block_size=16), - 'layer_2': new_kv_cache_spec(block_size=32), + "layer_1": new_kv_cache_spec(block_size=16), + "layer_2": new_kv_cache_spec(block_size=32), } uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) assert uniform_spec is None @@ -1383,38 +1467,39 @@ def test_generate_uniform_type_kv_cache_specs(): def test_generate_scheduler_kv_cache_config(): kv_cache_specs = { - 'layer_1': new_kv_cache_spec(), - 'layer_2': new_kv_cache_spec(head_size=128), + "layer_1": new_kv_cache_spec(), + "layer_2": new_kv_cache_spec(head_size=128), } kv_cache_configs = [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec(['layer_1', 'layer_2'], - UniformTypeKVCacheSpecs( - block_size=16, - kv_cache_specs=kv_cache_specs)), + KVCacheGroupSpec( + ["layer_1", "layer_2"], + UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs + ), + ), ], ) ] - scheduler_kv_cache_config = generate_scheduler_kv_cache_config( - kv_cache_configs) + scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) assert scheduler_kv_cache_config == KVCacheConfig( num_blocks=10, kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec()) - ], + kv_cache_groups=[KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())], ) def new_mla_spec(cache_dtype_str=None): - return MLAAttentionSpec(block_size=16, - num_kv_heads=16, - head_size=64, - dtype=torch.float32, - cache_dtype_str=cache_dtype_str) + return MLAAttentionSpec( + block_size=16, + num_kv_heads=16, + head_size=64, + dtype=torch.float32, + cache_dtype_str=cache_dtype_str, + ) def test_merge_mla_spec(): diff --git a/tests/v1/core/test_kv_sharing.py b/tests/v1/core/test_kv_sharing.py index 31a74101faf9..328f2640f218 100644 --- 
a/tests/v1/core/test_kv_sharing.py +++ b/tests/v1/core/test_kv_sharing.py @@ -26,8 +26,7 @@ def test_initialize_kv_cache_for_kv_sharing_different_attn_groups(): # However, if they have different attention backends, they will be # placed in different attention groups for KV cache group 0 kv_cache_groups = [ - KVCacheGroupSpec(["model.layers.0", "model.layers.1"], - new_kv_cache_spec()), + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], new_kv_cache_spec()), ] add_kv_sharing_layers_to_kv_cache_groups( @@ -38,7 +37,10 @@ def test_initialize_kv_cache_for_kv_sharing_different_attn_groups(): # Check that the layers were added to the correct KV cache group assert len(kv_cache_groups) == 1 assert kv_cache_groups[0].layer_names == [ - "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" + "model.layers.0", + "model.layers.1", + "model.layers.2", + "model.layers.3", ] @@ -53,8 +55,7 @@ def test_initialize_kv_cache_for_kv_sharing_same_attn_groups(): } kv_cache_groups = [ - KVCacheGroupSpec(["model.layers.0", "model.layers.1"], - new_kv_cache_spec()), + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], new_kv_cache_spec()), ] add_kv_sharing_layers_to_kv_cache_groups( @@ -65,14 +66,17 @@ def test_initialize_kv_cache_for_kv_sharing_same_attn_groups(): # Check that the layers were added to the correct KV cache group assert len(kv_cache_groups) == 1 assert kv_cache_groups[0].layer_names == [ - "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" + "model.layers.0", + "model.layers.1", + "model.layers.2", + "model.layers.3", ] def test_initialize_kv_cache_for_kv_sharing_no_attn_groups(): """ Test KV sharing set up when no attention groups are provided. - This is the case for the TPU model runner, which doesn't have + This is the case for the TPU model runner, which doesn't have support for attention groups yet. 
""" shared_kv_cache_layers = { @@ -92,9 +96,5 @@ def test_initialize_kv_cache_for_kv_sharing_no_attn_groups(): # Check that the layers were added to the correct KV cache group assert len(kv_cache_groups) == 2 - assert kv_cache_groups[0].layer_names == [ - "model.layers.0", "model.layers.2" - ] - assert kv_cache_groups[1].layer_names == [ - "model.layers.1", "model.layers.3" - ] + assert kv_cache_groups[0].layer_names == ["model.layers.0", "model.layers.2"] + assert kv_cache_groups[1].layer_names == ["model.layers.1", "model.layers.3"] diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 93ad4d8080e6..d08c1bcc57bd 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -10,20 +10,32 @@ import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved -from vllm.multimodal.inputs import (MultiModalFeatureSpec, - MultiModalKwargsItem, PlaceholderRange) +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.sampling_params import SamplingParams from vllm.utils import sha256, sha256_cbor from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, get_block_hash, - get_group_id, - get_request_block_hasher, - hash_block_tokens, init_none_hash, - make_block_hash_with_group_id) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, SlidingWindowSpec) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashWithGroupId, + KVCacheBlock, + get_block_hash, + get_group_id, + get_request_block_hasher, + hash_block_tokens, + init_none_hash, + make_block_hash_with_group_id, +) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + SlidingWindowSpec, +) pytestmark = pytest.mark.cpu_test @@ -56,19 +68,21 @@ def make_request( data=MultiModalKwargsItem.dummy("dummy_m"), mm_position=position, identifier=identifier, - modality="image") + modality="image", + ) mm_features.append(mm_feature) - return Request(request_id=request_id, - prompt_token_ids=prompt_token_ids, - mm_features=mm_features if mm_features else None, - sampling_params=SamplingParams( - max_tokens=17, prompt_logprobs=prompt_logprobs), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - block_hasher=get_request_block_hasher(block_size, hash_fn)) + return Request( + request_id=request_id, + prompt_token_ids=prompt_token_ids, + mm_features=mm_features if mm_features else None, + sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn), + ) def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: @@ -84,8 +98,9 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: ) -def make_kv_cache_config_hybrid_model(block_size: int, - num_blocks: int) -> KVCacheConfig: +def make_kv_cache_config_hybrid_model( + block_size: int, num_blocks: int +) -> KVCacheConfig: return KVCacheConfig( num_blocks=num_blocks, kv_cache_tensors=[], @@ -96,19 +111,15 @@ def make_kv_cache_config_hybrid_model(block_size: int, ), KVCacheGroupSpec( ["layer2"], - 
SlidingWindowSpec(block_size, - 1, - 1, - torch.float32, - sliding_window=2 * block_size), + SlidingWindowSpec( + block_size, 1, 1, torch.float32, sliding_window=2 * block_size + ), ), KVCacheGroupSpec( ["layer3"], - SlidingWindowSpec(block_size, - 1, - 1, - torch.float32, - sliding_window=2 * block_size), + SlidingWindowSpec( + block_size, 1, 1, torch.float32, sliding_window=2 * block_size + ), ), ], ) @@ -116,7 +127,6 @@ def make_kv_cache_config_hybrid_model(block_size: int, @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_prefill(hash_fn): - block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -136,17 +146,16 @@ def test_prefill(hash_fn): assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) # Check full block metadata parent_block_hash = None for block_id in (1, 2, 3): - block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - block_tokens) + block_tokens = tuple(all_token_ids[(block_id - 1) * 16 : block_id * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) blk_hash = manager.block_pool.blocks[block_id].block_hash assert blk_hash is not None assert get_block_hash(blk_hash) == block_hash @@ -155,24 +164,23 @@ def test_prefill(hash_fn): parent_block_hash = block_hash # Check partial block metadata - for block_id in (4, ): + for block_id in (4,): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids, block_size, - hash_fn) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], ) + assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([5], ) + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([5],) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -191,30 +199,27 @@ def test_prefill(hash_fn): # [unique_req1 (5)] # [common (3, 2, 1)] assert [ - b.block_id - for b in manager.block_pool.free_block_queue.get_all_free_blocks() + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Cache hit in the common prefix when the original block is already free. 
# Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 - req2 = make_request("2", common_token_ids + unique_token_ids, block_size, - hash_fn) + req2 = make_request("2", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], ) + assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req2, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([6], ) + blocks = manager.allocate_slots( + req2, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([6],) # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. assert free_block_queue.num_free_blocks == 6 - assert all( - [b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks()]) + assert all([b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks()]) assert len([b for b in free_block_queue.get_all_free_blocks()]) == 6 manager.free(req2) @@ -224,19 +229,23 @@ def test_prefill(hash_fn): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 16 * 10, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req3, 16 * 10, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) # This block ID order also checks the eviction order. - assert blocks is not None and blocks.get_block_ids() == ([ - 7, 8, 9, 10, 4, 5, 6, 3, 2, 1 - ], ) + assert blocks is not None and blocks.get_block_ids() == ( + [7, 8, 9, 10, 4, 5, 6, 3, 2, 1], + ) assert free_block_queue.num_free_blocks == 0 - assert (free_block_queue.fake_free_list_head.next_free_block - is free_block_queue.fake_free_list_tail) - assert (free_block_queue.fake_free_list_tail.prev_free_block - is free_block_queue.fake_free_list_head) + assert ( + free_block_queue.fake_free_list_head.next_free_block + is free_block_queue.fake_free_list_tail + ) + assert ( + free_block_queue.fake_free_list_tail.prev_free_block + is free_block_queue.fake_free_list_head + ) def test_prefill_hybrid_model(): @@ -261,20 +270,20 @@ def test_prefill_hybrid_model(): assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [ - 5, 6, 7, 8 - ], [9, 10, 11, 12]) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ( + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + ) # Check full block metadata parent_block_hash = None - for length, block_ids in zip((1, 2, 3), - ((1, 5, 9), (2, 6, 10), (3, 7, 11))): - block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - block_tokens) + for length, block_ids in zip((1, 2, 3), ((1, 5, 9), (2, 6, 10), (3, 7, 11))): + block_tokens = tuple(all_token_ids[(length - 1) * 16 : length * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) for 
group_id, block_id in enumerate(block_ids): blk_hash = manager.block_pool.blocks[block_id].block_hash assert blk_hash is not None @@ -291,17 +300,15 @@ def test_prefill_hybrid_model(): # Cache hit in the common prefix # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids, block_size, - hash_fn) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, - 7], [0, 10, 11]) + assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15]) for block_per_group in computed_blocks.blocks: for block in block_per_group: @@ -313,55 +320,70 @@ def test_prefill_hybrid_model(): manager.free(req1) cached_block_hash_to_block_bak = copy.copy( - manager.block_pool.cached_block_hash_to_block._cache) + manager.block_pool.cached_block_hash_to_block._cache + ) - def test_partial_request_hit(request_id: str, - hash_to_evict: list[BlockHashWithGroupId], - expect_hit_length: int): - req = make_request(request_id, common_token_ids + unique_token_ids, - block_size, sha256) + def test_partial_request_hit( + request_id: str, + hash_to_evict: list[BlockHashWithGroupId], + expect_hit_length: int, + ): + req = make_request( + request_id, common_token_ids + unique_token_ids, block_size, sha256 + ) for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block._cache.pop( - hash_with_group_id) + manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert len(req.block_hashes) == 3 assert num_computed_tokens == expect_hit_length * block_size for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size for hash_with_group_id in hash_to_evict: - manager.block_pool.cached_block_hash_to_block._cache[ - hash_with_group_id] = cached_block_hash_to_block_bak[ - hash_with_group_id] + manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = ( + cached_block_hash_to_block_bak[hash_with_group_id] + ) manager.free(req) # Evict the blocks outside sliding window, does not affect the hit length. - test_partial_request_hit("2", [ - make_block_hash_with_group_id(block_hashes[0], 1), - make_block_hash_with_group_id(block_hashes[0], 2) - ], 3) + test_partial_request_hit( + "2", + [ + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), + ], + 3, + ) # Evict the first block of full attention, makes total cache miss. test_partial_request_hit( - "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0) + "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0 + ) # Evict the last block of all layers, reduces the hit length to 2. 
- test_partial_request_hit("4", [ - make_block_hash_with_group_id(block_hashes[2], 0), - make_block_hash_with_group_id(block_hashes[2], 1), - make_block_hash_with_group_id(block_hashes[2], 2), - ], 2) + test_partial_request_hit( + "4", + [ + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[2], 1), + make_block_hash_with_group_id(block_hashes[2], 2), + ], + 2, + ) # Evict the last block of full attention, reduces the hit length to 2. test_partial_request_hit( - "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2) + "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2 + ) # Evict the last block of sliding window, reduces the hit length to 2. test_partial_request_hit( - "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2) + "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2 + ) # Evict the last block of sliding window, reduces the hit length to 2. test_partial_request_hit( - "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2) + "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2 + ) # Evict different set of blocks for full attention and sliding window makes # total cache miss. @@ -369,20 +391,24 @@ def test_partial_request_hit(request_id: str, # The cache hit length of sliding window is 2 * block_size. # Then it is cache miss as the two type of layers # have different hit length. - test_partial_request_hit("8", [ - make_block_hash_with_group_id(block_hashes[2], 0), - make_block_hash_with_group_id(block_hashes[0], 1), - make_block_hash_with_group_id(block_hashes[0], 2), - ], 0) + test_partial_request_hit( + "8", + [ + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), + ], + 0, + ) def test_prefill_plp(): - '''Test prefill with APC and some prompt logprobs (plp) requests. + """Test prefill with APC and some prompt logprobs (plp) requests. 1. Schedule plp request and validate APC block allocation 2. Schedule non-plp request and validate blocks 3. 
Schedule plp request; no hit should occur; validate blocks - ''' + """ block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -400,28 +426,23 @@ def test_prefill_plp(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", - all_token_ids, - block_size, - hash_fn, - prompt_logprobs=5) + req0 = make_request("0", all_token_ids, block_size, hash_fn, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) req0_block_hashes = [b.block_hash for b in blocks.blocks[0]] # Check full block metadata parent_block_hash = None for block_id in (1, 2, 3): - block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) - block_hash = hash_block_tokens(hash_fn, parent_block_hash, - block_tokens) - blk_hash = (manager.block_pool.blocks[block_id].block_hash) + block_tokens = tuple(all_token_ids[(block_id - 1) * 16 : block_id * 16]) + block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) + blk_hash = manager.block_pool.blocks[block_id].block_hash assert blk_hash is not None assert get_block_hash(blk_hash) == block_hash assert get_group_id(blk_hash) == 0 @@ -429,7 +450,7 @@ def test_prefill_plp(): parent_block_hash = block_hash # Check partial block metadata - for block_id in (4, ): + for block_id in (4,): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 @@ -437,17 +458,16 @@ def test_prefill_plp(): # Cache hit in the common prefix when the original block is still in use. 
# Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids, block_size, - hash_fn) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 - assert computed_blocks.get_block_ids() == ([1, 2, 3], ) + assert computed_blocks.get_block_ids() == ([1, 2, 3],) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 - blocks = manager.allocate_slots(req1, num_new_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([5], ) + blocks = manager.allocate_slots( + req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([5],) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -465,30 +485,27 @@ def test_prefill_plp(): # [unique_req1 (5)] # [common (3, 2, 1)] assert [ - b.block_id - for b in manager.block_pool.free_block_queue.get_all_free_blocks() + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [6, 7, 8, 9, 10, 4, 5, 3, 2, 1] # Request #2 is a prompt-logprobs request: # NO cache hit in the common prefix; duplicates request #0 cached blocks unique_token_ids = [3] * 6 - req2 = make_request("2", - common_token_ids + unique_token_ids, - block_size, - hash_fn, - prompt_logprobs=5) + req2 = make_request( + "2", common_token_ids + unique_token_ids, block_size, hash_fn, prompt_logprobs=5 + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(req2.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req2, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes - assert block_ids != ([1, 2, 3, 4], ) + assert block_ids != ([1, 2, 3, 4],) # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. @@ -512,26 +529,29 @@ def test_decode(): # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 - req0 = make_request("0", common_token_ids + unique_token_ids, block_size, - sha256) + req0 = make_request("0", common_token_ids + unique_token_ids, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 55, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) + blocks = manager.allocate_slots( + req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) # Append slots without allocating a new block. 
req0.num_computed_tokens = 55 for _ in range(4): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 4, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 4, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - assert manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id][-1].block_hash is None + assert ( + manager.coordinator.single_type_managers[0] + .req_to_blocks[req0.request_id][-1] + .block_hash + is None + ) # Append slots with allocating a new block. req0.num_computed_tokens = 59 @@ -539,14 +559,22 @@ def test_decode(): # the preallocated block. for _ in range(9 + 10): req0.append_output_token_ids(7) - new_blocks = manager.allocate_slots(req0, 19, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 19, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 1 - assert manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id][-2].block_hash is not None - assert manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id][-1].block_hash is None + assert ( + manager.coordinator.single_type_managers[0] + .req_to_blocks[req0.request_id][-2] + .block_hash + is not None + ) + assert ( + manager.coordinator.single_type_managers[0] + .req_to_blocks[req0.request_id][-1] + .block_hash + is None + ) def test_evict(): @@ -562,22 +590,22 @@ def test_evict(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, 5 * 16 + 7, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req0, 5 * 16 + 7, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) # 5 full + 1 partial assert blocks is not None and len(blocks.blocks[0]) == 6 # 3 blocks. - req1 = make_request("1", list(range(last_token_id, - last_token_id + 3 * 16)), block_size, - sha256) + req1 = make_request( + "1", list(range(last_token_id, last_token_id + 3 * 16)), block_size, sha256 + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 3 * 16, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req1, 3 * 16, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks last_token_id += 3 * 16 @@ -588,19 +616,18 @@ def test_evict(): manager.free(req1) assert manager.block_pool.free_block_queue.num_free_blocks == 10 assert [ - b.block_id - for b in manager.block_pool.free_block_queue.get_all_free_blocks() + b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. 
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert computed_blocks.get_block_ids() == ([1, 2], ) + assert computed_blocks.get_block_ids() == ([1, 2],) assert num_computed_tokens == 2 * 16 - blocks = manager.allocate_slots(req2, 3, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([10], ) + blocks = manager.allocate_slots( + req2, 3, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([10],) assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -622,9 +649,9 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 1 # Deallocate the block. @@ -636,13 +663,12 @@ def test_hash_block_correct_reuse(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req, num_tokens - 1, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req, num_tokens - 1, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 1 - assert manager.block_pool.blocks[blocks.blocks[0] - [0].block_id].block_hash is None + assert manager.block_pool.blocks[blocks.blocks[0][0].block_id].block_hash is None def test_computed_blocks_not_evicted(): @@ -663,21 +689,22 @@ def test_computed_blocks_not_evicted(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req0, num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req0, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 1 # Allocate another block. 
- req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), - block_size, sha256) + req1 = make_request( + "1", list(range(num_tokens, num_tokens * 2)), block_size, sha256 + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req1, num_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 @@ -693,9 +720,12 @@ def test_computed_blocks_not_evicted(): assert computed_blocks.blocks[0][0].block_id == 1 assert num_computed_tokens == block_size - blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req2, + num_tokens * 2 - num_tokens, + len(computed_blocks.blocks[0]) * 16, + computed_blocks, + ) assert blocks is not None and len(blocks.blocks[0]) == 1 assert blocks.blocks[0][0].block_id == 2 @@ -711,29 +741,29 @@ def test_basic_prefix_caching_disabled(): enable_caching=False, ) - req1 = make_request("1", list(range(10)), block_size, - sha256) # 2 blocks and some more + req1 = make_request( + "1", list(range(10)), block_size, sha256 + ) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req1, 10, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req1, 10, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 3 # Free the blocks. manager.free(req1) # No caching. - req2 = make_request("2", list(range(16)), block_size, - sha256) # shared prefix + req2 = make_request("2", list(range(16)), block_size, sha256) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req2, 16, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req2, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None and len(blocks.blocks[0]) == 4 # New requests should not have any blocks. 
@@ -741,9 +771,9 @@ def test_basic_prefix_caching_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - blocks = manager.allocate_slots(req3, 4, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req3, 4, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert not blocks @@ -842,24 +872,41 @@ def test_cache_blocks_multi_group(): # Block hash 1: hit for group 0 and 1 # Block hash 2: hit for group 1 - assert block_pool.get_cached_block(req.block_hashes[0], - kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(req.block_hashes[1], - kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(req.block_hashes[2], - kv_cache_group_ids=[0]) is None - assert block_pool.get_cached_block(req.block_hashes[0], - kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(req.block_hashes[1], - kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(req.block_hashes[2], - kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(req.block_hashes[0], - kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(req.block_hashes[1], - kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(req.block_hashes[2], - kv_cache_group_ids=[0, 1]) is None + assert ( + block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0]) is None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0, 1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0, 1]) + is not None + ) + assert ( + block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0, 1]) + is None + ) def test_mm_prefix_caching(): @@ -889,16 +936,16 @@ def test_mm_prefix_caching(): # A unique image plus some text tokens. 
unique_token_ids = [-1] * 7 + [100] * 4 all_token_ids = common_token_ids + unique_token_ids - mm_positions = common_mm_positions + [ - PlaceholderRange(offset=48, length=7) - ] + mm_positions = common_mm_positions + [PlaceholderRange(offset=48, length=7)] mm_hashes = common_mm_hashes + ["ccc"] - req0 = make_request("0", - all_token_ids, - block_size, - sha256, - mm_positions=mm_positions, - mm_hashes=mm_hashes) + req0 = make_request( + "0", + all_token_ids, + block_size, + sha256, + mm_positions=mm_positions, + mm_hashes=mm_hashes, + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes @@ -907,47 +954,55 @@ def test_mm_prefix_caching(): block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), - ("aaa", ))) + (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), ("aaa",)) + ) assert block_hashes[1] == sha256( - (block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]), - ("aaa", "bbb"))) + ( + block_hashes[0], + tuple(all_token_ids[block_size : block_size * 2]), + ("aaa", "bbb"), + ) + ) assert block_hashes[2] == sha256( - (block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]), - ("bbb", ))) + ( + block_hashes[1], + tuple(all_token_ids[block_size * 2 : block_size * 3]), + ("bbb",), + ) + ) - blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks.get_block_ids() == ([1, 2, 3, 4],) req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 5, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 assert len(block_hashes) == 4 assert block_hashes[3] == sha256( - (block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5), - ("ccc", ))) + (block_hashes[2], tuple(all_token_ids[3 * block_size :] + [8] * 5), ("ccc",)) + ) # Cache hit. 
unique_token_ids = [-1] * 7 + [200] * 5 all_token_ids = common_token_ids + unique_token_ids - mm_positions = common_mm_positions + [ - PlaceholderRange(offset=48, length=7) - ] + mm_positions = common_mm_positions + [PlaceholderRange(offset=48, length=7)] mm_hashes = common_mm_hashes + ["ccc"] - req1 = make_request("1", - all_token_ids, - block_size, - sha256, - mm_positions=mm_positions, - mm_hashes=mm_hashes) + req1 = make_request( + "1", + all_token_ids, + block_size, + sha256, + mm_positions=mm_positions, + mm_hashes=mm_hashes, + ) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(computed_blocks.blocks[0]) == 3 assert num_computed_tokens == 3 * 16 @@ -977,30 +1032,33 @@ def test_cache_key_salting(): block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", ))) + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1",)) + ) assert block_hashes[1] == sha256( - (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + (block_hashes[0], tuple(token_ids[block_size : block_size * 2]), None) + ) assert block_hashes[2] == sha256( - (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), - None)) + (block_hashes[1], tuple(token_ids[block_size * 2 : block_size * 3]), None) + ) - blocks = manager.allocate_slots(req0, 59, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + blocks = manager.allocate_slots( + req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert blocks is not None - assert blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks.get_block_ids() == ([1, 2, 3, 4],) req0.num_computed_tokens = 59 # Append slots without allocating a new block. for _ in range(5): req0.append_output_token_ids(8) - new_blocks = manager.allocate_slots(req0, 5, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + new_blocks = manager.allocate_slots( + req0, 5, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 assert len(block_hashes) == 4 assert block_hashes[3] == sha256( - (block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None)) + (block_hashes[2], tuple(token_ids[3 * block_size :] + [8] * 5), None) + ) # Test cache hit with a new request that has the same salt. 
token_ids = common_token_ids + [4] * 11 @@ -1019,12 +1077,14 @@ def test_cache_key_salting(): block_hashes = req2.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", ))) + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2",)) + ) assert block_hashes[1] == sha256( - (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + (block_hashes[0], tuple(token_ids[block_size : block_size * 2]), None) + ) assert block_hashes[2] == sha256( - (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), - None)) + (block_hashes[1], tuple(token_ids[block_size * 2 : block_size * 3]), None) + ) def test_prefill_not_enough_free_blocks_with_computed_blocks(): @@ -1047,22 +1107,24 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req0, 48, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req0, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[ - req0.request_id] + req0.request_id + ] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 - manager.allocate_slots(req1, 48, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req1, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[ - req1.request_id] + req1.request_id + ] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) @@ -1075,9 +1137,12 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req2, block_size * 2, - len(computed_blocks.blocks[0]) * block_size, - computed_blocks) + manager.allocate_slots( + req2, + block_size * 2, + len(computed_blocks.blocks[0]) * block_size, + computed_blocks, + ) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). @@ -1088,9 +1153,12 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 # Req3 cannot be allocated. - assert manager.allocate_slots(req3, 48, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) is None + assert ( + manager.allocate_slots( + req3, 48, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + is None + ) # Block 0-2 are used by Req 1. assert {block.ref_cnt for block in block_part1[:3]} == {1} # Block 3-5 are free. 
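The block-hash assertions in the hunks above (mm prefix caching and cache-key salting) follow one chaining scheme: each full block's hash folds in the previous block's hash (NONE_HASH for the first block), the block's token ids, and optional extra keys such as the cache salt on the first block or overlapping mm hashes. A minimal sketch of that chain, assuming a generic sha256 over pickled tuples and a placeholder NONE_HASH; mm extra keys are omitted and the helper below is hypothetical, not vLLM's implementation:

import hashlib
import pickle

NONE_HASH = b"\x00" * 32  # placeholder sentinel, not vLLM's actual value


def sha256(obj) -> bytes:
    # Stand-in for the hash function parameterizing these tests.
    return hashlib.sha256(pickle.dumps(obj)).digest()


def chain_block_hashes(token_ids, block_size, salt=None):
    # Hash only full blocks; a partial tail block gets no hash yet.
    hashes, parent = [], NONE_HASH
    num_full_tokens = len(token_ids) - len(token_ids) % block_size
    for start in range(0, num_full_tokens, block_size):
        block = tuple(token_ids[start:start + block_size])
        extra = (salt,) if salt is not None and start == 0 else None
        parent = sha256((parent, block, extra))
        hashes.append(parent)
    return hashes


# 48 common tokens plus an 11-token tail -> 3 full blocks of size 16 are hashed.
assert len(chain_block_hashes(list(range(48)) + [4] * 11, 16, salt="salt1")) == 3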
@@ -1110,7 +1178,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, block_size, sha256) blocks = manager.allocate_slots(req0, 55) - assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], ) + assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4],) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -1118,10 +1186,10 @@ def test_reset_prefix_cache(): computed_blocks, _ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 - blocks = manager.allocate_slots(req1, 7, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) - assert blocks is not None and blocks.get_block_ids() == ([5], ) + blocks = manager.allocate_slots( + req1, 7, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks is not None and blocks.get_block_ids() == ([5],) # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() @@ -1152,9 +1220,9 @@ def test_prefix_cache_stats_disabled(): computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - manager.allocate_slots(req, 16, - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, 16, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) manager.reset_prefix_cache() # Ensure prefix_cache_stats remains None @@ -1191,19 +1259,14 @@ def test_maybe_evict_cached_block(): # Evict block1 pool._maybe_evict_cached_block(block1) assert pool.cached_block_hash_to_block._cache == { - block_hash0: { - block0.block_id: block0, - block3.block_id: block3 - }, + block_hash0: {block0.block_id: block0, block3.block_id: block3}, block_hash2: block2, } # Evict block0: block_hash0 entry should NOT be removed, as block3 # also use the same hash pool._maybe_evict_cached_block(block0) assert pool.cached_block_hash_to_block._cache == { - block_hash0: { - block3.block_id: block3 - }, + block_hash0: {block3.block_id: block3}, block_hash2: block2, } # Evict block2 @@ -1236,8 +1299,11 @@ def test_kv_cache_events(blocks_to_cache: int): events = manager.take_events() block = events[-1] - assert (len(block.block_hashes) == blocks_to_cache == len( - manager.block_pool.cached_block_hash_to_block)) + assert ( + len(block.block_hashes) + == blocks_to_cache + == len(manager.block_pool.cached_block_hash_to_block) + ) assert len(block.token_ids) == block.block_size * len(block.block_hashes) assert len(manager.block_pool.kv_event_queue) == 0 @@ -1254,9 +1320,12 @@ def test_kv_cache_events(blocks_to_cache: int): for blocks in events[:-1]: assert blocks.block_hashes[0] in stored_block_hash assert len(events) == blocks_to_cache + 1 - assert (isinstance(events[-2], BlockRemoved)) - assert (len(events[-1].block_hashes) == blocks_to_cache == len( - manager.block_pool.cached_block_hash_to_block)) + assert isinstance(events[-2], BlockRemoved) + assert ( + len(events[-1].block_hashes) + == blocks_to_cache + == len(manager.block_pool.cached_block_hash_to_block) + ) # All Blocks Cleared # Should see a single all blocks cleared event @@ -1285,9 +1354,9 @@ def test_eagle_enabled_removes_last_block(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, len(token_ids), 
len(computed_blocks.blocks[0]) * 16, computed_blocks + ) manager.free(req) # New request with same tokens + Eagle enabled @@ -1316,9 +1385,9 @@ def test_eagle_with_partial_blocks(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks + ) manager.free(req) # New request with Eagle enabled @@ -1343,7 +1412,7 @@ def test_eagle_with_sliding_window(): KVCacheConfig( num_blocks=10, kv_cache_tensors=[], - kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)], + kv_cache_groups=[KVCacheGroupSpec(["layer"], sliding_window_spec)], ), max_model_len=8192, enable_caching=True, @@ -1356,9 +1425,9 @@ def test_eagle_with_sliding_window(): # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) - manager.allocate_slots(req, len(token_ids), - len(computed_blocks.blocks[0]) * 16, - computed_blocks) + manager.allocate_slots( + req, len(token_ids), len(computed_blocks.blocks[0]) * 16, computed_blocks + ) # record the block hash of the first block in the request for later use block_hash_first_block = req.block_hashes[0] assert block_hash_first_block is not None @@ -1372,14 +1441,20 @@ def test_eagle_with_sliding_window(): assert num_tokens == 1 * block_size # Evict the first block in the request - assert manager.block_pool.get_cached_block( - block_hash_first_block, kv_cache_group_ids=[0]) is not None + assert ( + manager.block_pool.get_cached_block( + block_hash_first_block, kv_cache_group_ids=[0] + ) + is not None + ) manager.block_pool.cached_block_hash_to_block._cache.pop( - make_block_hash_with_group_id(block_hash_first_block, 0)) + make_block_hash_with_group_id(block_hash_first_block, 0) + ) # New request - req_after_evict = make_request("partial_eagle_after_evict", token_ids, - block_size, sha256) + req_after_evict = make_request( + "partial_eagle_after_evict", token_ids, block_size, sha256 + ) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. 
But after dropping the last matched block due to eagle, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 3de6dffc3395..e78cced2d2db 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -7,15 +7,27 @@ import pytest import torch -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import (MultiModalFeatureSpec, - MultiModalKwargsItem, PlaceholderRange) +from vllm.config import ( + CacheConfig, + KVTransferConfig, + ModelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, +) from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager @@ -43,8 +55,7 @@ def test_finish_request(): scheduler.add_request(request) for i, request in enumerate(requests): - scheduler.finish_requests(request.request_id, - RequestStatus.FINISHED_ABORTED) + scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_ABORTED) assert request.request_id not in scheduler.requests assert len(scheduler.waiting) == 9 - i @@ -56,23 +67,25 @@ def test_get_num_unfinished_requests(): scheduler.add_request(request) for i, request in enumerate(requests): - scheduler.finish_requests(request.request_id, - RequestStatus.FINISHED_STOPPED) + scheduler.finish_requests(request.request_id, RequestStatus.FINISHED_STOPPED) assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1 -@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [ - (None, None), - (True, 5), -]) -def test_schedule(enable_prefix_caching: Optional[bool], - prompt_logprobs: Optional[int]): - '''Test scheduling. +@pytest.mark.parametrize( + "enable_prefix_caching, prompt_logprobs", + [ + (None, None), + (True, 5), + ], +) +def test_schedule( + enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int] +): + """Test scheduling. 
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs - ''' + """ scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching) - requests = create_requests(num_requests=10, - prompt_logprobs=prompt_logprobs) + requests = create_requests(num_requests=10, prompt_logprobs=prompt_logprobs) for request in requests: scheduler.add_request(request) @@ -94,8 +107,7 @@ def test_schedule(enable_prefix_caching: Optional[bool], def test_schedule_multimodal_requests(): scheduler = create_scheduler(model="llava-hf/llava-1.5-7b-hf") - mm_positions = [[PlaceholderRange(offset=i, length=100)] - for i in range(10)] + mm_positions = [[PlaceholderRange(offset=i, length=100)] for i in range(10)] requests = create_requests( num_requests=10, num_tokens=200, @@ -128,8 +140,7 @@ def test_schedule_partial_requests(): model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, ) - mm_positions = [[PlaceholderRange(offset=100, length=600)] - for _ in range(3)] + mm_positions = [[PlaceholderRange(offset=100, length=600)] for _ in range(3)] requests = create_requests( num_requests=3, num_tokens=800, @@ -152,10 +163,7 @@ def test_schedule_partial_requests(): # The third request is also scheduled partially. # The <img> tokens are not scheduled because of the encoder budget. assert output.num_scheduled_tokens[requests[2].request_id] == 100 - req_to_index = { - request.request_id: i - for i, request in enumerate(requests) - } + req_to_index = {request.request_id: i for i, request in enumerate(requests)} model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, @@ -191,9 +199,9 @@ def test_no_mm_input_chunking(): max_model_len=2048, ) mm_positions = [[PlaceholderRange(offset=400, length=800)]] - requests = create_requests(num_requests=1, - num_tokens=1200, - mm_positions=mm_positions) + requests = create_requests( + num_requests=1, num_tokens=1200, mm_positions=mm_positions + ) for request in requests: scheduler.add_request(request) @@ -204,10 +212,7 @@ def test_no_mm_input_chunking(): # We want to only see the 400 text tokens at the start scheduled assert output.num_scheduled_tokens[requests[0].request_id] == 400 - req_to_index = { - request.request_id: i - for i, request in enumerate(requests) - } + req_to_index = {request.request_id: i for i, request in enumerate(requests)} model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, @@ -267,10 +272,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): assert output.num_scheduled_tokens[requests[1].request_id] == 400 # The third request is also scheduled partially - 1024 - 400 - 400 = 224. 
assert output.num_scheduled_tokens[requests[2].request_id] == 224 - req_to_index = { - request.request_id: i - for i, request in enumerate(requests) - } + req_to_index = {request.request_id: i for i, request in enumerate(requests)} model_runner_output = ModelRunnerOutput( req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, @@ -311,8 +313,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): assert len(output2.finished_req_ids) == 0 assert output2.num_scheduled_tokens[requests[0].request_id] == 1 assert output2.num_scheduled_tokens[requests[1].request_id] == 1 - assert output2.num_scheduled_tokens[ - requests[2].request_id] == 800 - 224 - 224 + assert output2.num_scheduled_tokens[requests[2].request_id] == 800 - 224 - 224 def test_stop_via_update_from_output(): @@ -330,34 +331,31 @@ def test_stop_via_update_from_output(): scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={ - requests[0].request_id: 1, - requests[1].request_id: 2 - }, + num_scheduled_tokens={requests[0].request_id: 1, requests[1].request_id: 2}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={ requests[0].request_id: [], - requests[1].request_id: [10] + requests[1].request_id: [10], }, num_common_prefix_blocks=0, finished_req_ids=set(), free_encoder_mm_hashes=[], structured_output_request_ids={}, - grammar_bitmask=None) + grammar_bitmask=None, + ) model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[EOS_TOKEN_ID], - [10, - 11]], # First request hits EOS, second continues + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[ + [EOS_TOKEN_ID], + [10, 11], + ], # First request hits EOS, second continues logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) @@ -371,9 +369,7 @@ def test_stop_via_update_from_output(): # Test case 2: Stop on custom stop token scheduler = create_scheduler(num_speculative_tokens=2) - requests = create_requests(num_requests=2, - max_tokens=10, - stop_token_ids=[42, 43]) + requests = create_requests(num_requests=2, max_tokens=10, stop_token_ids=[42, 43]) for req in requests: req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req @@ -383,15 +379,12 @@ def test_stop_via_update_from_output(): scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 2 - }, + num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 2}, total_num_scheduled_tokens=5, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={ requests[0].request_id: [10, 42], - requests[1].request_id: [13] + requests[1].request_id: [13], }, num_common_prefix_blocks=0, finished_req_ids=set(), @@ -402,15 +395,12 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[10, 42, 12], - [13, 14]], # First request hits stop token + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits 
stop token logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) @@ -435,15 +425,12 @@ def test_stop_via_update_from_output(): scheduler_output = SchedulerOutput( scheduled_new_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={ - requests[0].request_id: 3, - requests[1].request_id: 1 - }, + num_scheduled_tokens={requests[0].request_id: 3, requests[1].request_id: 1}, total_num_scheduled_tokens=4, scheduled_encoder_inputs={}, scheduled_spec_decode_tokens={ requests[0].request_id: [10, 11], - requests[1].request_id: [] + requests[1].request_id: [], }, num_common_prefix_blocks=0, finished_req_ids=set(), @@ -454,15 +441,12 @@ def test_stop_via_update_from_output(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, - sampled_token_ids=[[10, 11, 12], - [13]], # First request exceeds max_tokens + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) @@ -471,8 +455,7 @@ def test_stop_via_update_from_output(): assert scheduler.running[0].request_id == requests[1].request_id assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED assert requests[0].request_id in scheduler.finished_req_ids - assert list(requests[0].output_token_ids) == [10, 11 - ] # Truncated to max_tokens + assert list(requests[0].output_token_ids) == [10, 11] # Truncated to max_tokens assert list(requests[1].output_token_ids) == [13] # Test case 4: Ignore EOS flag @@ -489,14 +472,13 @@ def test_stop_via_update_from_output(): num_scheduled_tokens={requests[0].request_id: 3}, total_num_scheduled_tokens=3, scheduled_encoder_inputs={}, - scheduled_spec_decode_tokens={ - requests[0].request_id: [EOS_TOKEN_ID, 10] - }, + scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]}, num_common_prefix_blocks=0, finished_req_ids=set(), free_encoder_mm_hashes=[], structured_output_request_ids={}, - grammar_bitmask=None) + grammar_bitmask=None, + ) model_output = ModelRunnerOutput( req_ids=[requests[0].request_id], @@ -504,7 +486,8 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], logprobs=None, prompt_logprobs_dict={}, - pooler_output=[]) + pooler_output=[], + ) scheduler.update_from_output(scheduler_output, model_output) @@ -514,12 +497,16 @@ def test_stop_via_update_from_output(): assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] -@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [ - (None, None), - (True, 5), -]) -def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], - prompt_logprobs: Optional[int]): +@pytest.mark.parametrize( + "enable_prefix_caching, prompt_logprobs", + [ + (None, None), + (True, 5), + ], +) +def test_schedule_concurrent_batches( + enable_prefix_caching: Optional[bool], prompt_logprobs: Optional[int] +): scheduler = create_scheduler( max_num_batched_tokens=1024, max_num_seqs=2, @@ -535,15 +522,13 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], scheduler.add_request(requests[0]) scheduler_output0 = scheduler.schedule() assert len(scheduler_output0.scheduled_new_reqs) == 1 - assert 
scheduler_output0.num_scheduled_tokens[ - requests[0].request_id] == 512 + assert scheduler_output0.num_scheduled_tokens[requests[0].request_id] == 512 # The first request is still running, so only schedule the second request. scheduler.add_request(requests[1]) scheduler_output1 = scheduler.schedule() assert len(scheduler_output1.scheduled_new_reqs) == 1 - assert scheduler_output1.num_scheduled_tokens[ - requests[1].request_id] == 512 + assert scheduler_output1.num_scheduled_tokens[requests[1].request_id] == 512 # Model output of the first request. model_runner_output = ModelRunnerOutput( @@ -577,10 +562,12 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], def test_preempt_during_execution(): # NOTE(woosuk): The actual number of available blocks is 10 instead of 11 # because block 0 is reserved as the null block. - scheduler = create_scheduler(max_num_batched_tokens=100, - block_size=16, - num_blocks=11, - enable_prefix_caching=False) + scheduler = create_scheduler( + max_num_batched_tokens=100, + block_size=16, + num_blocks=11, + enable_prefix_caching=False, + ) requests = create_requests(num_requests=2, num_tokens=80, block_size=16) # Schedule the first request. @@ -637,13 +624,16 @@ def test_preempt_during_execution(): [ ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch - ([[1, 2], [3]], [[1, 2, 5], [3, 4]], - (2, 3, 3, [2, 1])), # multiple sequences + ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (2, 3, 3, [2, 1])), # multiple sequences ([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence ([[]], [[5]], (0, 0, 0, [0])), # empty sequence - ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]], - (2, 6, 3, [2, 1, 0])), # multiple mismatches - ]) + ( + [[1, 2, 3], [4, 5, 6]], + [[1, 2, 7], [4, 8]], + (2, 6, 3, [2, 1, 0]), + ), # multiple mismatches + ], +) def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): """Test scheduling behavior with speculative decoding. 
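The expected tuples in the parametrization above can be derived by hand: a request's accepted draft tokens are the longest prefix of its spec_tokens that matches the tokens sampled for it, and the per-position list sums acceptances across requests. A small hypothetical helper (not part of vLLM) that reproduces those tuples under that assumption:

def spec_decode_expected(spec_tokens, output_tokens):
    # (num requests with drafts, total draft tokens, accepted, accepted per position)
    num_drafts = sum(1 for drafts in spec_tokens if drafts)
    num_draft_tokens = sum(len(drafts) for drafts in spec_tokens)
    max_len = max((len(drafts) for drafts in spec_tokens), default=0)
    per_position = [0] * max(max_len, 1)
    num_accepted = 0
    for drafts, sampled in zip(spec_tokens, output_tokens):
        for i, token in enumerate(drafts):
            if i < len(sampled) and sampled[i] == token:
                per_position[i] += 1
                num_accepted += 1
            else:
                break  # acceptance stops at the first mismatch
    return num_drafts, num_draft_tokens, num_accepted, per_position


assert spec_decode_expected([[1, 2, 3]], [[1, 5]]) == (1, 3, 1, [1, 0, 0])
assert spec_decode_expected([[1, 2], [3]], [[1, 2, 5], [3, 4]]) == (2, 3, 3, [2, 1])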
@@ -678,8 +668,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): prompt_logprobs_dict={}, pooler_output=[], ) - engine_core_outputs = scheduler.update_from_output(output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output(output, model_runner_output) draft_token_ids = DraftTokenIds(req_ids, spec_tokens) scheduler.update_draft_token_ids(draft_token_ids) @@ -694,20 +683,23 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): # No draft or accepted tokens counted yet assert not engine_core_outputs or ( - engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None) + engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None + ) # Schedule the speculated tokens for validation output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 0 # The sampled token and speculated tokens - assert output.total_num_scheduled_tokens == \ - len(requests) + sum(len(ids) for ids in spec_tokens) + assert output.total_num_scheduled_tokens == len(requests) + sum( + len(ids) for ids in spec_tokens + ) for i in range(len(requests)): req_id = requests[i].request_id assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i]) if spec_tokens[i]: - assert len(output.scheduled_spec_decode_tokens[req_id]) == \ - len(spec_tokens[i]) + assert len(output.scheduled_spec_decode_tokens[req_id]) == len( + spec_tokens[i] + ) else: assert req_id not in output.scheduled_spec_decode_tokens @@ -719,11 +711,11 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): prompt_logprobs_dict={}, pooler_output=[], ) - engine_core_outputs = scheduler.update_from_output(output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output(output, model_runner_output) - scheduler_stats = engine_core_outputs[0].scheduler_stats \ - if engine_core_outputs else None + scheduler_stats = ( + engine_core_outputs[0].scheduler_stats if engine_core_outputs else None + ) if expected[0] == 0: assert scheduler_stats.spec_decoding_stats is None else: @@ -763,18 +755,25 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. EXPECTED_TOTAL_BLOCKS = num_tokens // block_size for req in requests: - blocks = (scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[req.request_id]) + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].req_to_blocks[req.request_id] hashes = req.block_hashes - assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS) + assert ( + scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].num_cached_block[req.request_id] + == EXPECTED_TOTAL_BLOCKS + ) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS # Make sure we actually touched all the blocks. 
BLOCKS_PER_REQ = num_tokens / block_size - assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == - num_total_blocks - num_requests * BLOCKS_PER_REQ) + assert ( + scheduler.kv_cache_manager.block_pool.get_num_free_blocks() + == num_total_blocks - num_requests * BLOCKS_PER_REQ + ) def _step_until_done( @@ -813,25 +812,28 @@ def test_kv_connector_basic(): enable_prefix_caching=True, use_kv_connector=True, ) - NUM_TOTAL_BLOCKS = ( - scheduler.kv_cache_manager.block_pool.get_num_free_blocks()) + NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks() BLOCK_SIZE = scheduler.cache_config.block_size # Mock External Cache Hit. NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, False) + NUM_MATCHED_NEW_TOKENS, + False, + ) ###################################################### # FIRST SET OF REQUESTS - External Hit Only NUM_REQUESTS = 2 NUM_TOKENS = NUM_MATCHED_NEW_TOKENS * 2 MAX_TOKENS = 3 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS, - block_size=BLOCK_SIZE) + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, + ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -858,15 +860,17 @@ def test_kv_connector_basic(): ) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, - NUM_REQUESTS, NUM_TOTAL_BLOCKS) + _assert_right_kv_cache_manager( + scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS + ) # Continue Generation until done. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) _ = scheduler.schedule() # Confirm we clean up the memory properly. - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_TOTAL_BLOCKS + assert ( + scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS + ) ###################################################### # SECOND SET OF REQUESTS - Local And External Hit @@ -874,10 +878,12 @@ def test_kv_connector_basic(): # We will get a local prefix cache hit for the first # NUM_TOKENS_PREFIX tokens since they are used above. NUM_TOKENS = NUM_TOKENS_PREFIX * 2 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS, - block_size=BLOCK_SIZE) + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, + ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -901,19 +907,23 @@ def test_kv_connector_basic(): output=output, num_requests=NUM_REQUESTS, # Just the incremental tokens after local + remote cache hit. - expected_num_scheduled_tokens=(NUM_TOKENS - NUM_TOKENS_PREFIX - - NUM_MATCHED_NEW_TOKENS)) + expected_num_scheduled_tokens=( + NUM_TOKENS - NUM_TOKENS_PREFIX - NUM_MATCHED_NEW_TOKENS + ), + ) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, - NUM_REQUESTS, NUM_TOTAL_BLOCKS) + _assert_right_kv_cache_manager( + scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS + ) # Continue Generation until done. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) _ = scheduler.schedule() # Confirm we clean up the memory properly. 
- assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_TOTAL_BLOCKS + assert ( + scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_TOTAL_BLOCKS + ) def test_kv_connector_unable_to_allocate(): @@ -934,17 +944,21 @@ def test_kv_connector_unable_to_allocate(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, False) + NUM_MATCHED_NEW_TOKENS, + False, + ) # Create two requests. The second request will not be able to # allocate slots because it will not have enough blocks. NUM_REQUESTS = 2 NUM_TOKENS = (NUM_BLOCKS // 2 + 1) * BLOCK_SIZE MAX_TOKENS = 2 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS, - block_size=BLOCK_SIZE) + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, + ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -963,33 +977,33 @@ def test_kv_connector_unable_to_allocate(): # Just one request should be running. output = scheduler.schedule() - _assert_right_scheduler_output(output, - num_requests=1, - expected_num_scheduled_tokens=NUM_TOKENS - - NUM_MATCHED_NEW_TOKENS) + _assert_right_scheduler_output( + output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, + ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 # All memory should be freed, with one request waiting. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_BLOCKS - 1 + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 # Just one request should be running. output = scheduler.schedule() - _assert_right_scheduler_output(output, - num_requests=1, - expected_num_scheduled_tokens=NUM_TOKENS - - NUM_MATCHED_NEW_TOKENS) + _assert_right_scheduler_output( + output, + num_requests=1, + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, + ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 # All memory should be freed, with no requests waiting / running. _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_BLOCKS - 1 + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 0 @@ -1014,7 +1028,9 @@ def test_kv_connector_handles_preemption(): NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE scheduler.connector.get_num_new_matched_tokens = Mock(name="method") scheduler.connector.get_num_new_matched_tokens.return_value = ( - NUM_MATCHED_NEW_TOKENS, False) + NUM_MATCHED_NEW_TOKENS, + False, + ) # Create two requests. 
# Both can be scheduled at first, but the second request @@ -1022,10 +1038,12 @@ def test_kv_connector_handles_preemption(): NUM_REQUESTS = 2 NUM_TOKENS = BLOCK_SIZE * 2 + 1 MAX_TOKENS = BLOCK_SIZE * 2 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS, - block_size=BLOCK_SIZE) + requests = create_requests( + num_requests=NUM_REQUESTS, + num_tokens=NUM_TOKENS, + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, + ) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -1048,7 +1066,8 @@ def test_kv_connector_handles_preemption(): output, # 2 remote kv cache hits. num_requests=2, - expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS) + expected_num_scheduled_tokens=NUM_TOKENS - NUM_MATCHED_NEW_TOKENS, + ) assert len(scheduler.running) == 2 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) @@ -1058,7 +1077,8 @@ def test_kv_connector_handles_preemption(): output, # no connector_metadata num_requests=0, - expected_num_scheduled_tokens=1) + expected_num_scheduled_tokens=1, + ) assert len(scheduler.running) == 2 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) @@ -1068,7 +1088,8 @@ def test_kv_connector_handles_preemption(): output, # no connector_metadata num_requests=0, - expected_num_scheduled_tokens=1) + expected_num_scheduled_tokens=1, + ) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) @@ -1081,14 +1102,14 @@ def test_kv_connector_handles_preemption(): output, # no connector_metadata num_requests=0, - expected_num_scheduled_tokens=1) + expected_num_scheduled_tokens=1, + ) assert len(scheduler.waiting) == 1 assert len(scheduler.running) == 1 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) assert len(scheduler.running) == 0 # All memory should be freed since nothing is running. - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_BLOCKS - 1 + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 # Restarts the preempted request - generate 3rd token. # This will have a local and remote cache hit. @@ -1113,22 +1134,19 @@ def test_kv_connector_handles_preemption(): output, # no connector_metadata num_requests=0, - expected_num_scheduled_tokens=1) + expected_num_scheduled_tokens=1, + ) assert len(scheduler.running) == 1 _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) assert len(scheduler.running) == 0 # All memory should be freed since nothing is running. - assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ - == NUM_BLOCKS - 1 + assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(scheduler.running) - }, + req_id_to_index={req.request_id: i for i, req in enumerate(scheduler.running)}, sampled_token_ids=[[1000]] * len(scheduler.running), logprobs=None, prompt_logprobs_dict={}, @@ -1149,14 +1167,24 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. 
- num_cached_block) == 0 + assert ( + len( + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks + ) + == 0 + ) + assert ( + len( + scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].num_cached_block + ) + == 0 + ) num_free_blocks = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - assert num_free_blocks == ( - scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks + ) + assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. @@ -1176,9 +1204,9 @@ def test_memory_leak(): NUM_REQUESTS = 5 NUM_TOKENS = 10 MAX_TOKENS = 10 - requests = create_requests(num_requests=NUM_REQUESTS, - num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + requests = create_requests( + num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS + ) # Add each request. for request in requests: @@ -1212,7 +1240,7 @@ def create_scheduler_with_priority( max_model_len: Optional[int] = None, num_speculative_tokens: Optional[int] = None, ) -> Scheduler: - '''Create scheduler with priority policy enabled. + """Create scheduler with priority policy enabled. Args: model: model under test @@ -1224,7 +1252,7 @@ def create_scheduler_with_priority( Returns: {class}`Scheduler` instance with priority scheduling - ''' + """ if max_model_len is None: max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( @@ -1243,9 +1271,11 @@ def create_scheduler_with_priority( seed=42, ) # Cache config, optionally force APC - kwargs_cache = ({} if enable_prefix_caching is None else { - 'enable_prefix_caching': enable_prefix_caching - }) + kwargs_cache = ( + {} + if enable_prefix_caching is None + else {"enable_prefix_caching": enable_prefix_caching} + ) cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, @@ -1253,16 +1283,21 @@ def create_scheduler_with_priority( cache_dtype="auto", **kwargs_cache, ) - kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": "local_storage"}, - ) if use_kv_connector else None + kv_transfer_config = ( + KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) + if use_kv_connector + else None + ) speculative_config: Optional[SpeculativeConfig] = None if num_speculative_tokens is not None: speculative_config = SpeculativeConfig( - model="ngram", num_speculative_tokens=num_speculative_tokens) + model="ngram", num_speculative_tokens=num_speculative_tokens + ) vllm_config = VllmConfig( scheduler_config=scheduler_config, @@ -1275,9 +1310,9 @@ def create_scheduler_with_priority( num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + KVCacheGroupSpec( + ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ) ], ) cache_config.num_gpu_blocks = num_blocks @@ -1290,15 +1325,16 @@ def create_scheduler_with_priority( def create_requests_with_priority( - num_requests: int, - priorities: list[int], - arrival_times: Optional[list[float]] = None, - num_tokens: int = 10, - mm_positions: Optional[list[list[PlaceholderRange]]] = 
None, - max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None, - starting_idx: int = 0): + num_requests: int, + priorities: list[int], + arrival_times: Optional[list[float]] = None, + num_tokens: int = 10, + mm_positions: Optional[list[list[PlaceholderRange]]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None, + starting_idx: int = 0, +): """Create requests with specified priorities and arrival times.""" assert len(priorities) == num_requests if arrival_times is not None: @@ -1306,10 +1342,12 @@ def create_requests_with_priority( else: arrival_times = [float(i) for i in range(num_requests)] - sampling_params = SamplingParams(ignore_eos=False, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - prompt_logprobs=prompt_logprobs) + sampling_params = SamplingParams( + ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs, + ) requests = [] for i in range(num_requests): mm_features = [] @@ -1321,7 +1359,8 @@ def create_requests_with_priority( data=MultiModalKwargsItem.dummy("dummy_m"), mm_position=position, identifier=identifier, - modality="image") + modality="image", + ) mm_features.append(mm_feature) request = Request( @@ -1347,9 +1386,9 @@ def test_priority_scheduling_basic_ordering(): # Priority 0 (highest), 1, 2 (lowest) priorities = [2, 0, 1] # Add in non-priority order arrival_times = [1.0, 2.0, 3.0] # All different arrival times - requests = create_requests_with_priority(num_requests=3, - priorities=priorities, - arrival_times=arrival_times) + requests = create_requests_with_priority( + num_requests=3, priorities=priorities, arrival_times=arrival_times + ) # Add requests in non-priority order for request in requests: @@ -1375,9 +1414,9 @@ def test_priority_scheduling_arrival_time_tiebreaker(): # Create requests with same priority but different arrival times priorities = [1, 1, 1] # All same priority arrival_times = [3.0, 1.0, 2.0] # Different arrival times - requests = create_requests_with_priority(num_requests=3, - priorities=priorities, - arrival_times=arrival_times) + requests = create_requests_with_priority( + num_requests=3, priorities=priorities, arrival_times=arrival_times + ) # Add requests in non-arrival order for request in requests: @@ -1402,9 +1441,9 @@ def test_priority_scheduling_mixed_priority_and_arrival(): # Create requests with mixed priorities and arrival times priorities = [2, 1, 1, 0] # Mixed priorities arrival_times = [1.0, 3.0, 2.0, 4.0] # Mixed arrival times - requests = create_requests_with_priority(num_requests=4, - priorities=priorities, - arrival_times=arrival_times) + requests = create_requests_with_priority( + num_requests=4, priorities=priorities, arrival_times=arrival_times + ) # Add requests for request in requests: @@ -1441,7 +1480,7 @@ def test_priority_scheduling_preemption(): num_requests=2, priorities=[5, 5], # Low priority arrival_times=[1.0, 2.0], - num_tokens=30 # Large enough to consume significant memory + num_tokens=30, # Large enough to consume significant memory ) # Add and schedule low priority requests @@ -1455,8 +1494,7 @@ def test_priority_scheduling_preemption(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in low_priority_requests], req_id_to_index={ - req.request_id: i - for i, req in enumerate(low_priority_requests) + req.request_id: i for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in 
low_priority_requests], logprobs=None, @@ -1474,7 +1512,7 @@ def test_priority_scheduling_preemption(): num_requests=1, priorities=[0], # High priority arrival_times=[3.0], - num_tokens=30 # Large enough to require significant memory + num_tokens=30, # Large enough to require significant memory )[0] scheduler.add_request(high_priority_request) @@ -1515,10 +1553,8 @@ def test_priority_scheduling_no_preemption_when_space_available(): # Add two low-priority running requests low_priority_requests = create_requests_with_priority( - num_requests=2, - priorities=[5, 5], - arrival_times=[1.0, 2.0], - num_tokens=30) + num_requests=2, priorities=[5, 5], arrival_times=[1.0, 2.0], num_tokens=30 + ) for request in low_priority_requests: scheduler.add_request(request) @@ -1527,8 +1563,7 @@ def test_priority_scheduling_no_preemption_when_space_available(): model_output = ModelRunnerOutput( req_ids=[req.request_id for req in low_priority_requests], req_id_to_index={ - req.request_id: i - for i, req in enumerate(low_priority_requests) + req.request_id: i for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], logprobs=None, @@ -1538,10 +1573,9 @@ def test_priority_scheduling_no_preemption_when_space_available(): scheduler.update_from_output(output, model_output) # Add high-priority request - high_priority_request = create_requests_with_priority(num_requests=1, - priorities=[0], - arrival_times=[3.0], - num_tokens=30)[0] + high_priority_request = create_requests_with_priority( + num_requests=1, priorities=[0], arrival_times=[3.0], num_tokens=30 + )[0] scheduler.add_request(high_priority_request) @@ -1569,7 +1603,8 @@ def test_priority_scheduling_preemption_victim_selection(): num_requests=3, priorities=[3, 2, 0], # Different priorities: low, medium, high arrival_times=[1.0, 2.0, 3.0], - num_tokens=10) + num_tokens=10, + ) # Add all requests for request in requests: @@ -1608,7 +1643,8 @@ def test_priority_scheduling_equal_priority_preemption(): num_requests=3, priorities=[2, 2, 2], # Same priority arrival_times=[3.0, 1.0, 2.0], # Different arrival times - num_tokens=10) + num_tokens=10, + ) # Add all requests for request in requests: @@ -1644,7 +1680,8 @@ def test_priority_scheduling_waiting_queue_order(): num_requests=4, priorities=[3, 1, 2, 0], # Mixed priorities arrival_times=[1.0, 2.0, 3.0, 4.0], - num_tokens=10) + num_tokens=10, + ) # Add all requests for request in requests: @@ -1679,9 +1716,9 @@ def test_priority_scheduling_fcfs_fallback(): # Create requests with same priority but different arrival times priorities = [1, 1, 1, 1] # All same priority arrival_times = [4.0, 1.0, 3.0, 2.0] # Different arrival times - requests = create_requests_with_priority(num_requests=4, - priorities=priorities, - arrival_times=arrival_times) + requests = create_requests_with_priority( + num_requests=4, priorities=priorities, arrival_times=arrival_times + ) # Add requests for request in requests: @@ -1711,7 +1748,8 @@ def test_priority_scheduling_with_limited_slots(): num_requests=4, priorities=[3, 1, 2, 0], # Mixed priorities arrival_times=[1.0, 2.0, 3.0, 4.0], - num_tokens=10) + num_tokens=10, + ) # Add all requests for request in requests: @@ -1749,10 +1787,12 @@ def test_priority_scheduling_heap_property(): # Add requests in random priority order priorities = [5, 1, 8, 3, 2, 7, 4, 6] arrival_times = [float(i) for i in range(len(priorities))] - requests = create_requests_with_priority(num_requests=len(priorities), - priorities=priorities, - 
arrival_times=arrival_times, - num_tokens=10) + requests = create_requests_with_priority( + num_requests=len(priorities), + priorities=priorities, + arrival_times=arrival_times, + num_tokens=10, + ) # Add all requests for request in requests: @@ -1779,8 +1819,7 @@ def test_priority_scheduling_heap_property(): scheduler.update_from_output(output, model_output) # Finish the request to make room for the next one - scheduler.finish_requests(req.req_id, - RequestStatus.FINISHED_STOPPED) + scheduler.finish_requests(req.req_id, RequestStatus.FINISHED_STOPPED) # Verify requests were scheduled in priority order (lowest value first) expected_priorities = sorted(priorities) @@ -1879,10 +1918,7 @@ def test_priority_scheduling_preemption_when_out_of_kv(): requests = [request_low, request_high] model_output = ModelRunnerOutput( req_ids=[req.request_id for req in requests], - req_id_to_index={ - req.request_id: i - for i, req in enumerate(requests) - }, + req_id_to_index={req.request_id: i for i, req in enumerate(requests)}, sampled_token_ids=[[100] for _ in requests], # spec_token_ids=None, logprobs=None, @@ -1913,11 +1949,12 @@ def test_priority_scheduling_preemption_when_out_of_kv(): # Encoder-decoder models should always have it disabled (False, True, False), (True, True, False), - ]) + ], +) def test_chunked_prefill_disabled_for_encoder_decoder( - enable_chunked_prefill: bool, is_encoder_decoder: bool, - expect_enabled: bool) -> None: - """Validate that chunked prefill is appropriately disabled for + enable_chunked_prefill: bool, is_encoder_decoder: bool, expect_enabled: bool +) -> None: + """Validate that chunked prefill is appropriately disabled for encoder-decoder models.""" scheduler_config = SchedulerConfig( enable_chunked_prefill=enable_chunked_prefill, @@ -1931,18 +1968,20 @@ def test_chunked_prefill_disabled_for_encoder_decoder( f.name for f in dataclasses.fields(scheduler_config) ] _validate_chunked_prefill_settings_for_encoder_decoder( - scheduler_config, is_encoder_decoder, expect_enabled) + scheduler_config, is_encoder_decoder, expect_enabled + ) # Ensure it is retained in VllmConfig, even after its post-init. 
vllm_config = VllmConfig(scheduler_config=scheduler_config) _validate_chunked_prefill_settings_for_encoder_decoder( - vllm_config.scheduler_config, is_encoder_decoder, expect_enabled) + vllm_config.scheduler_config, is_encoder_decoder, expect_enabled + ) def _validate_chunked_prefill_settings_for_encoder_decoder( - scheduler_config: SchedulerConfig, is_encoder_decoder: bool, - expect_enabled: bool) -> None: - """Validate chunked prefill settings in the scheduler config for + scheduler_config: SchedulerConfig, is_encoder_decoder: bool, expect_enabled: bool +) -> None: + """Validate chunked prefill settings in the scheduler config for encoder-decoder models.""" assert scheduler_config.chunked_prefill_enabled is expect_enabled assert scheduler_config.enable_chunked_prefill is expect_enabled diff --git a/tests/v1/core/test_scheduler_e2e.py b/tests/v1/core/test_scheduler_e2e.py index bd0320baef87..6983c3b92f6b 100644 --- a/tests/v1/core/test_scheduler_e2e.py +++ b/tests/v1/core/test_scheduler_e2e.py @@ -15,13 +15,15 @@ @pytest.fixture(scope="module") def llm() -> LLM: - return LLM(MODEL, - enforce_eager=True, - enable_prefix_caching=True, - long_prefill_token_threshold=2, - max_num_batched_tokens=6, - max_num_seqs=3, - block_size=16) + return LLM( + MODEL, + enforce_eager=True, + enable_prefix_caching=True, + long_prefill_token_threshold=2, + max_num_batched_tokens=6, + max_num_seqs=3, + block_size=16, + ) def test_concurrent_partial_prefill(llm): diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index 166be8bda05e..a27f32938c08 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -7,27 +7,28 @@ import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - make_block_hash_with_group_id) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + KVCacheBlock, + make_block_hash_with_group_id, +) from vllm.v1.core.single_type_kv_cache_manager import ( - ChunkedLocalAttentionManager, SlidingWindowManager) -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - SlidingWindowSpec) + ChunkedLocalAttentionManager, + SlidingWindowManager, +) +from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, SlidingWindowSpec pytestmark = pytest.mark.cpu_test def get_sliding_window_manager(sliding_window_spec, block_pool): - return SlidingWindowManager(sliding_window_spec, - block_pool, - kv_cache_group_id=0) + return SlidingWindowManager(sliding_window_spec, block_pool, kv_cache_group_id=0) -def get_chunked_local_attention_manager(chunked_local_attention_spec, - block_pool): - return ChunkedLocalAttentionManager(chunked_local_attention_spec, - block_pool, - kv_cache_group_id=0) +def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool): + return ChunkedLocalAttentionManager( + chunked_local_attention_spec, block_pool, kv_cache_group_id=0 + ) def test_chunked_local_attention_possible_cached_prefix(): @@ -41,8 +42,9 @@ def test_chunked_local_attention_possible_cached_prefix(): ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) - manager = get_chunked_local_attention_manager(chunked_local_attention_spec, - block_pool) + manager = get_chunked_local_attention_manager( + chunked_local_attention_spec, block_pool + ) def run_one_case(block_is_cached, tail_token, expect_length): block_hash_list = [ @@ -52,12 +54,14 @@ def run_one_case(block_is_cached, 
tail_token, expect_length): block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks - for i, (block_hash, - is_cached) in enumerate(zip(block_hash_list, block_is_cached)): + for i, (block_hash, is_cached) in enumerate( + zip(block_hash_list, block_is_cached) + ): if is_cached: block_pool.cached_block_hash_to_block.insert( make_block_hash_with_group_id(block_hash, 0), - block_pool.blocks[i + 10]) + block_pool.blocks[i + 10], + ) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, @@ -65,11 +69,14 @@ def run_one_case(block_is_cached, tail_token, expect_length): kv_cache_group_ids=[0], block_pool=block_pool, kv_cache_spec=chunked_local_attention_spec, - use_eagle=False)[0] + use_eagle=False, + )[0] assert len(computed_blocks) == expect_length - assert all(block == block_pool.null_block - for block in computed_blocks[:(expect_length - 1) // 2]) + assert all( + block == block_pool.null_block + for block in computed_blocks[: (expect_length - 1) // 2] + ) run_one_case([True], 0, 1) run_one_case([True], 1, 1) @@ -115,12 +122,14 @@ def run_one_case(block_is_cached, expect_length): block_pool.cached_block_hash_to_block._cache.clear() # Mock the block pool with the cached blocks - for i, (block_hash, - is_cached) in enumerate(zip(block_hash_list, block_is_cached)): + for i, (block_hash, is_cached) in enumerate( + zip(block_hash_list, block_is_cached) + ): if is_cached: block_pool.cached_block_hash_to_block.insert( make_block_hash_with_group_id(block_hash, 0), - block_pool.blocks[i + 10]) + block_pool.blocks[i + 10], + ) computed_blocks = manager.find_longest_cache_hit( block_hashes=block_hash_list, @@ -128,16 +137,18 @@ def run_one_case(block_is_cached, expect_length): kv_cache_group_ids=[0], block_pool=block_pool, kv_cache_spec=sliding_window_spec, - use_eagle=False)[0] + use_eagle=False, + )[0] assert len(computed_blocks) == expect_length - assert all(block == block_pool.null_block - for block in computed_blocks[:expect_length - 2]) + assert all( + block == block_pool.null_block + for block in computed_blocks[: expect_length - 2] + ) for i in range(2): if i < expect_length: block_index = expect_length - i - 1 - assert computed_blocks[ - block_index].block_id == block_index + 10 + assert computed_blocks[block_index].block_id == block_index + 10 run_one_case([False] * 10, 0) run_one_case([True], 1) @@ -146,17 +157,16 @@ def run_one_case(block_is_cached, expect_length): run_one_case([True, True, False], 2) run_one_case([True, True, True], 3) run_one_case([True, True, True, False], 3) - run_one_case([ - True, True, False, True, False, False, True, True, False, True, True, - True - ], 12) - run_one_case([ - True, True, False, True, False, False, True, True, False, False, False - ], 8) - run_one_case([ - True, True, False, True, False, False, True, True, False, False, False, - True - ], 8) + run_one_case( + [True, True, False, True, False, False, True, True, False, True, True, True], 12 + ) + run_one_case( + [True, True, False, True, False, False, True, True, False, False, False], 8 + ) + run_one_case( + [True, True, False, True, False, False, True, True, False, False, False, True], + 8, + ) def test_chunked_local_attention_remove_skipped_blocks(): @@ -176,8 +186,8 @@ def test_chunked_local_attention_remove_skipped_blocks(): def id_to_block_table(ids) -> list[KVCacheBlock]: return [ - KVCacheBlock(id_) - if id_ != null_block_id else block_pool.null_block for id_ in ids + KVCacheBlock(id_) if id_ != null_block_id else 
block_pool.null_block + for id_ in ids ] def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): @@ -188,7 +198,17 @@ def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): assert block.block_id == id_ original_block_ids = [ - 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 + 1000, + 1001, + 1002, + 1003, + 1004, + 1005, + 1006, + 1007, + 1008, + 1009, + 1010, ] block_table = id_to_block_table(original_block_ids) manager.req_to_blocks["test"] = block_table @@ -227,8 +247,8 @@ def test_sliding_window_remove_skipped_blocks(): def id_to_block_table(ids) -> list[KVCacheBlock]: return [ - KVCacheBlock(id_) - if id_ != null_block_id else block_pool.null_block for id_ in ids + KVCacheBlock(id_) if id_ != null_block_id else block_pool.null_block + for id_ in ids ] def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): @@ -239,7 +259,17 @@ def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): assert block.block_id == id_ original_block_ids = [ - 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 + 1000, + 1001, + 1002, + 1003, + 1004, + 1005, + 1006, + 1007, + 1008, + 1009, + 1010, ] block_table = id_to_block_table(original_block_ids) manager.req_to_blocks["test"] = block_table @@ -289,13 +319,16 @@ def test_get_num_blocks_to_allocate(): block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) manager = get_sliding_window_manager(sliding_window_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] - cached_blocks_2 = [block_pool.null_block for _ in range(5) - ] + [KVCacheBlock(i + 1) for i in range(5)] + cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ + KVCacheBlock(i + 1) for i in range(5) + ] - assert manager.get_num_blocks_to_allocate("1", 20 * block_size, - cached_blocks_1) == 20 - assert manager.get_num_blocks_to_allocate("2", 20 * block_size, - cached_blocks_2) == 15 + assert ( + manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20 + ) + assert ( + manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 + ) def test_chunked_local_attention_get_num_blocks_to_allocate(): @@ -311,10 +344,13 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) manager = get_chunked_local_attention_manager(attention_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] - cached_blocks_2 = [block_pool.null_block for _ in range(5) - ] + [KVCacheBlock(i + 1) for i in range(5)] + cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ + KVCacheBlock(i + 1) for i in range(5) + ] - assert manager.get_num_blocks_to_allocate("1", 20 * block_size, - cached_blocks_1) == 20 - assert manager.get_num_blocks_to_allocate("2", 20 * block_size, - cached_blocks_2) == 15 + assert ( + manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20 + ) + assert ( + manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 + ) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index d343141cdf4c..75ef1a5ec165 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -4,18 +4,29 @@ import torch -from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, - SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.multimodal.inputs import (MultiModalFeatureSpec, - MultiModalKwargsItem, PlaceholderRange) +from vllm.config import ( + CacheConfig, + KVTransferConfig, + ModelConfig, + 
SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) +from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.sampling_params import SamplingParams from vllm.utils import sha256 -from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, - init_none_hash) +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, +) from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager @@ -37,7 +48,7 @@ def create_scheduler( skip_tokenizer_init: bool = False, async_scheduling: bool = False, ) -> Union[Scheduler, AsyncScheduler]: - '''Create scheduler under test. + """Create scheduler under test. Args: model: model under test @@ -49,7 +60,7 @@ def create_scheduler( Returns: {class}`Scheduler` instance - ''' + """ if max_model_len is None: max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( @@ -69,9 +80,11 @@ def create_scheduler( skip_tokenizer_init=skip_tokenizer_init, ) # Cache config, optionally force APC - kwargs_cache = ({} if enable_prefix_caching is None else { - 'enable_prefix_caching': enable_prefix_caching - }) + kwargs_cache = ( + {} + if enable_prefix_caching is None + else {"enable_prefix_caching": enable_prefix_caching} + ) cache_config = CacheConfig( block_size=block_size, gpu_memory_utilization=0.9, @@ -79,16 +92,21 @@ def create_scheduler( cache_dtype="auto", **kwargs_cache, ) - kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": "local_storage"}, - ) if use_kv_connector else None + kv_transfer_config = ( + KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) + if use_kv_connector + else None + ) speculative_config: Optional[SpeculativeConfig] = None if num_speculative_tokens is not None: speculative_config = SpeculativeConfig( - model="ngram", num_speculative_tokens=num_speculative_tokens) + model="ngram", num_speculative_tokens=num_speculative_tokens + ) vllm_config = VllmConfig( scheduler_config=scheduler_config, @@ -101,9 +119,9 @@ def create_scheduler( num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + KVCacheGroupSpec( + ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ) ], ) cache_config.num_gpu_blocks = num_blocks @@ -135,10 +153,12 @@ def create_requests( _none_hash_initialized = True block_hasher = get_request_block_hasher(block_size, sha256) - sampling_params = SamplingParams(ignore_eos=False, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - prompt_logprobs=prompt_logprobs) + sampling_params = SamplingParams( + ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs, + ) requests = [] for i in range(num_requests): mm_features = [] @@ -152,11 +172,11 @@ def create_requests( data=MultiModalKwargsItem.dummy("dummy_m"), mm_position=position, identifier=identifier, - modality="image") + 
modality="image", + ) mm_features.append(mm_feature) - prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * - num_tokens) + prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens request = Request( request_id=f"{i}", prompt_token_ids=prompt_token_ids, diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index b6b85e4440d0..59841a446db3 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -9,8 +9,14 @@ from tests.utils import create_new_process_for_each_test from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, - ParallelConfig, SchedulerConfig, VllmConfig) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.platforms import current_platform from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher @@ -18,7 +24,6 @@ # Helper MLP for testing class SimpleMLP(nn.Module): - def __init__(self): super().__init__() self.fc1 = nn.Linear(10, 10) @@ -28,8 +33,9 @@ def forward(self, x): return self.fc2(self.fc1(x)) -def _create_vllm_config(compilation_config: CompilationConfig, - max_num_seqs: int = 8) -> MagicMock: +def _create_vllm_config( + compilation_config: CompilationConfig, max_num_seqs: int = 8 +) -> MagicMock: mock_config = MagicMock(spec=VllmConfig) mock_config.compilation_config = compilation_config mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs) @@ -43,7 +49,6 @@ def _create_vllm_config(compilation_config: CompilationConfig, class TestCudagraphDispatcher: - @pytest.mark.parametrize( "case_id,cudagraph_mode_str,compilation_level", [ @@ -55,18 +60,21 @@ class TestCudagraphDispatcher: (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION), # Test case 3: Piecewise for all (3, "PIECEWISE", CompilationLevel.PIECEWISE), - ]) + ], + ) def test_dispatcher(self, cudagraph_mode_str, compilation_level): # Setup dispatcher - comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str, - level=compilation_level, - cudagraph_capture_sizes=[1, 8]) + comp_config = CompilationConfig( + cudagraph_mode=cudagraph_mode_str, + level=compilation_level, + cudagraph_capture_sizes=[1, 8], + ) config = _create_vllm_config(comp_config, max_num_seqs=8) dispatcher = CudagraphDispatcher(config) dispatcher.initialize_cudagraph_keys( - cudagraph_mode=comp_config.cudagraph_mode, - uniform_decode_query_len=1) + cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1 + ) # Verify the key is initialized correctly if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: @@ -114,8 +122,7 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_level): # 4. 
Cascade attention should have a fall back mode desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False) - rt_mode, key = dispatcher.dispatch(desc_full_exact, - use_cascade_attn=True) + rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True) if "PIECEWISE" in cudagraph_mode_str: # string contains check assert rt_mode == CUDAGraphMode.PIECEWISE assert key == desc_full_exact.non_uniform @@ -125,7 +132,6 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_level): @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") class TestCUDAGraphWrapper: - def setup_method(self): self.vllm_config = _create_vllm_config(CompilationConfig()) self.model = SimpleMLP().to("cuda") @@ -134,26 +140,30 @@ def setup_method(self): @create_new_process_for_each_test("spawn") def test_capture_and_replay(self): - wrapper = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + wrapper = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) batch_descriptor = BatchDescriptor(num_tokens=10) # 0. global warmup - with set_forward_context(attn_metadata=None, - vllm_config=self.vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - batch_descriptor=None): + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None, + ): wrapper(self.input_tensor) # 1. Capture - with set_forward_context( + with ( + set_forward_context( attn_metadata=None, vllm_config=self.vllm_config, cudagraph_runtime_mode=CUDAGraphMode.FULL, - batch_descriptor=batch_descriptor),\ - patch("torch.cuda.graph", - wraps=torch.cuda.graph) as mock_cuda_graph: + batch_descriptor=batch_descriptor, + ), + patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph, + ): output1 = wrapper(self.input_tensor) # capturing phase should generate a zero output assert torch.allclose(output1, torch.zeros_like(output1)) @@ -164,13 +174,17 @@ def test_capture_and_replay(self): assert entry.cudagraph is not None # 2. 
Replay - with set_forward_context( + with ( + set_forward_context( attn_metadata=None, vllm_config=self.vllm_config, cudagraph_runtime_mode=CUDAGraphMode.FULL, - batch_descriptor=batch_descriptor),\ - patch.object(entry.cudagraph, 'replay', - wraps=entry.cudagraph.replay) as mock_replay: + batch_descriptor=batch_descriptor, + ), + patch.object( + entry.cudagraph, "replay", wraps=entry.cudagraph.replay + ) as mock_replay, + ): output2 = wrapper(self.input_tensor) mock_replay.assert_called_once() @@ -180,20 +194,23 @@ def test_capture_and_replay(self): @create_new_process_for_each_test("spawn") def test_bypass_on_mode_mismatch(self): - wrapper = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + wrapper = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) batch_descriptor = BatchDescriptor(num_tokens=10) - with set_forward_context( + with ( + set_forward_context( attn_metadata=None, vllm_config=self.vllm_config, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, - batch_descriptor=batch_descriptor), \ - patch('torch.cuda.graph', - wraps=torch.cuda.graph) as mock_cuda_graph, \ - patch.object(self.model, 'forward', - wraps=self.model.forward) as mock_forward: + batch_descriptor=batch_descriptor, + ), + patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph, + patch.object( + self.model, "forward", wraps=self.model.forward + ) as mock_forward, + ): wrapper(self.input_tensor) mock_cuda_graph.assert_not_called() mock_forward.assert_called_once() @@ -201,18 +218,20 @@ def test_bypass_on_mode_mismatch(self): @create_new_process_for_each_test("spawn") def test_bypass_on_mode_none(self): - wrapper = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + wrapper = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) batch_descriptor = BatchDescriptor(num_tokens=10) - with set_forward_context( + with ( + set_forward_context( attn_metadata=None, vllm_config=self.vllm_config, cudagraph_runtime_mode=CUDAGraphMode.NONE, - batch_descriptor=batch_descriptor), \ - patch('torch.cuda.graph', - wraps=torch.cuda.graph) as mock_cuda_graph: + batch_descriptor=batch_descriptor, + ), + patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph, + ): wrapper(self.input_tensor) mock_cuda_graph.assert_not_called() assert not wrapper.concrete_cudagraph_entries @@ -220,38 +239,44 @@ def test_bypass_on_mode_none(self): @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") class TestCudagraphIntegration: - def setup_method(self): # only FULL mode for non-uniform batches - self.comp_config = CompilationConfig(level=CompilationLevel.PIECEWISE, - cudagraph_mode="FULL", - cudagraph_capture_sizes=[10, 20]) + self.comp_config = CompilationConfig( + level=CompilationLevel.PIECEWISE, + cudagraph_mode="FULL", + cudagraph_capture_sizes=[10, 20], + ) self.vllm_config = _create_vllm_config(self.comp_config) self.dispatcher = CudagraphDispatcher(self.vllm_config) self.dispatcher.initialize_cudagraph_keys( - self.comp_config.cudagraph_mode, uniform_decode_query_len=1) + self.comp_config.cudagraph_mode, uniform_decode_query_len=1 + ) - def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode, - batch_descriptor): + def _run_and_monitor_call( + self, wrapper, input_tensor, runtime_mode, batch_descriptor + ): """Helper to run a single call and monitor the action.""" - with patch('torch.cuda.graph', - wraps=torch.cuda.graph) as mock_graph_context, \ - 
patch.object(wrapper, 'runnable', - wraps=wrapper.runnable) as mock_runnable: + with ( + patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context, + patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable, + ): + entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None) - entry = wrapper.concrete_cudagraph_entries.get( - batch_descriptor, None) - - context = set_forward_context(attn_metadata=None, - vllm_config=self.vllm_config, - cudagraph_runtime_mode=runtime_mode, - batch_descriptor=batch_descriptor) + context = set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=runtime_mode, + batch_descriptor=batch_descriptor, + ) mock_replay = MagicMock() if entry and entry.cudagraph: - with context, \ - patch.object(entry.cudagraph, 'replay', - new_callable=MagicMock) as mock_replay: + with ( + context, + patch.object( + entry.cudagraph, "replay", new_callable=MagicMock + ) as mock_replay, + ): wrapper(input_tensor) else: with context: @@ -272,8 +297,7 @@ def _run_and_monitor_call(self, wrapper, input_tensor, runtime_mode, @create_new_process_for_each_test("spawn") def test_capture_replay_bypass_logic(self): model = SimpleMLP().to("cuda") - full_wrapper = CUDAGraphWrapper(model, self.vllm_config, - CUDAGraphMode.FULL) + full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL) max_bs = 16 persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda") input_1 = persistent_input_buffer[:1] @@ -285,75 +309,79 @@ def test_capture_replay_bypass_logic(self): desc_3_unseen = BatchDescriptor(num_tokens=3) # 0. global warmup - with set_forward_context(attn_metadata=None, - vllm_config=self.vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - batch_descriptor=None): + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None, + ): full_wrapper(input_1) rt_mode, key = self.dispatcher.dispatch(desc_1) # 1. Capture first shape - action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, - key) + action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) assert action == "capture_global" # 2. Replay first shape - action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, - key) + action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) assert action == "replay" rt_mode, key = self.dispatcher.dispatch(desc_2) # 3. Capture second shape - action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, - key) + action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key) assert action == "capture_global" # 4. Replay second shape - action = self._run_and_monitor_call(full_wrapper, input_2, - CUDAGraphMode.FULL, desc_2) + action = self._run_and_monitor_call( + full_wrapper, input_2, CUDAGraphMode.FULL, desc_2 + ) assert action == "replay" # 5. 
Bypass if no key match rt_mode, key = self.dispatcher.dispatch(desc_3_unseen) assert rt_mode == CUDAGraphMode.NONE - action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, - key) + action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key) assert action == "bypass" # capture unseen shape is not allowed after disable set_cudagraph_capturing_enabled(False) with pytest.raises(RuntimeError): - self._run_and_monitor_call(full_wrapper, input_3, - CUDAGraphMode.FULL, desc_3_unseen) + self._run_and_monitor_call( + full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen + ) set_cudagraph_capturing_enabled(True) @create_new_process_for_each_test("spawn") def test_nested_wrappers(self): """Tests a scenario with a PIECEWISE wrapper inside a FULL one.""" model = SimpleMLP().to("cuda") - full_wrapper = CUDAGraphWrapper(model, self.vllm_config, - CUDAGraphMode.FULL) + full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL) input_1 = torch.randn(1, 10, device="cuda") # Setup: Inner model is wrapped with PIECEWISE, outer with FULL inner_model = SimpleMLP().to("cuda") - piecewise_wrapper = CUDAGraphWrapper(inner_model, self.vllm_config, - CUDAGraphMode.PIECEWISE) + piecewise_wrapper = CUDAGraphWrapper( + inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE + ) inner_model.forward = MagicMock(wraps=inner_model.forward) outer_model = SimpleMLP().to("cuda") # When outer model is called, it calls the piecewise_wrapper - outer_model.forward = MagicMock(wraps=outer_model.forward, - side_effect=piecewise_wrapper) - full_wrapper = CUDAGraphWrapper(outer_model, self.vllm_config, - CUDAGraphMode.FULL) + outer_model.forward = MagicMock( + wraps=outer_model.forward, side_effect=piecewise_wrapper + ) + full_wrapper = CUDAGraphWrapper( + outer_model, self.vllm_config, CUDAGraphMode.FULL + ) desc_1 = BatchDescriptor(num_tokens=1) # 0. global warmup - with set_forward_context(attn_metadata=None, - vllm_config=self.vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - batch_descriptor=None): + with set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None, + ): full_wrapper(input_1) # --- Test runtime mode FULL--- @@ -361,8 +389,9 @@ def test_nested_wrappers(self): # The inner mock should be called once inside the graph capture. outer_model.forward.reset_mock() inner_model.forward.reset_mock() - action = self._run_and_monitor_call(full_wrapper, input_1, - CUDAGraphMode.FULL, desc_1) + action = self._run_and_monitor_call( + full_wrapper, input_1, CUDAGraphMode.FULL, desc_1 + ) assert action == "capture_global" assert outer_model.forward.call_count == 1 assert inner_model.forward.call_count == 1 @@ -370,8 +399,9 @@ def test_nested_wrappers(self): # Run again. Expect outer wrapper to replay. # The outer model should NOT be called because the whole graph # is replayed. - action = self._run_and_monitor_call(full_wrapper, input_1, - CUDAGraphMode.FULL, desc_1) + action = self._run_and_monitor_call( + full_wrapper, input_1, CUDAGraphMode.FULL, desc_1 + ) assert action == "replay" assert outer_model.forward.call_count == 1 # No new call assert inner_model.forward.call_count == 1 @@ -382,16 +412,18 @@ def test_nested_wrappers(self): # Run with PIECEWISE mode context. # Expect outer wrapper to bypass and call inner wrapper. # Inner wrapper should capture. 
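A recurring rewrite in this file replaces backslash-continued chains of context managers with a single parenthesized `with (...)` group, as in `_run_and_monitor_call` above. A self-contained sketch of that form (Python 3.10+), patching standard-library functions rather than vLLM objects:

import os
from unittest.mock import patch

# Sketch only: several context managers grouped in one parenthesized
# `with` block instead of backslash continuations.
with (
    patch("os.path.exists", return_value=True) as mock_exists,
    patch("os.path.isdir", return_value=False) as mock_isdir,
):
    assert os.path.exists("/definitely/not/real")
    assert not os.path.isdir("/definitely/not/real")
    mock_exists.assert_called_once()
    mock_isdir.assert_called_once()

The parenthesized form needs no line-continuation characters, so the formatter can wrap the managers freely.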
- action = self._run_and_monitor_call(full_wrapper, input_1, - CUDAGraphMode.PIECEWISE, desc_1) + action = self._run_and_monitor_call( + full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1 + ) assert action == "capture_global" assert outer_model.forward.call_count == 1 assert inner_model.forward.call_count == 1 # Run again with PIECEWISE. # Outer bypasses, inner replays. - action = self._run_and_monitor_call(full_wrapper, input_1, - CUDAGraphMode.PIECEWISE, desc_1) + action = self._run_and_monitor_call( + full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1 + ) assert action == "bypass" assert outer_model.forward.call_count == 2 assert inner_model.forward.call_count == 1 diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py index c4116247bb7c..77d5c5d87fc1 100644 --- a/tests/v1/cudagraph/test_cudagraph_mode.py +++ b/tests/v1/cudagraph/test_cudagraph_mode.py @@ -45,10 +45,8 @@ def temporary_environ(env_vars): ] -@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", - combo_cases_1) -def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, - supported): +@pytest.mark.parametrize("backend_name, cudagraph_mode, supported", combo_cases_1) +def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supported): if backend_name == "FlashInfer": try: import flashinfer # noqa: F401 @@ -56,8 +54,10 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, pytest.skip("FlashInfer is not installed") backend_config = backend_configs[backend_name] # Dynamically skip test if GPU capability is not met - if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\ - != current_platform.get_device_capability(): + if ( + backend_config.specific_gpu_arch + and backend_config.specific_gpu_arch != current_platform.get_device_capability() + ): pytest.skip("Only Hopper GPUs support FA3 and FlashMLA") env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars} @@ -66,13 +66,16 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, if not supported: stack.enter_context(pytest.raises(Exception)) - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - max_num_seqs=256, - trust_remote_code=True, - gpu_memory_utilization=0.45, - max_model_len=1024, - compilation_config=CompilationConfig( - level=3, cudagraph_mode=cudagraph_mode)) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + max_num_seqs=256, + trust_remote_code=True, + gpu_memory_utilization=0.45, + max_model_len=1024, + compilation_config=CompilationConfig( + level=3, cudagraph_mode=cudagraph_mode + ), + ) llm.generate(["Hello, my name is"] * 10) # when above code raises, `llm` may be undefined, so we need to catch that try: @@ -93,10 +96,13 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, ("FA2", "FULL", 0, True), # no compilation + full cudagraph ("FA2", "FULL", 3, True), # piecewise compilation + full cudagraph ("FA2", "PIECEWISE", 0, False), # no compilation + piecewise cudagraph - ("FA2", "PIECEWISE", 3, - True), # piecewise compilation + piecewise cudagraph - ("FA2", "FULL_AND_PIECEWISE", 0, - False), # piecewise cudagraph not supported without piecewise compilation + ("FA2", "PIECEWISE", 3, True), # piecewise compilation + piecewise cudagraph + ( + "FA2", + "FULL_AND_PIECEWISE", + 0, + False, + ), # piecewise cudagraph not supported without piecewise compilation ("FA2", "FULL_AND_PIECEWISE", 3, True), ("FA2", "FULL_DECODE_ONLY", 0, True), ("FA2", "FULL_DECODE_ONLY", 3, True), @@ 
-105,11 +111,11 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, ] -@pytest.mark.parametrize("backend_name,cudagraph_mode,compilation_level,"\ - "supported", combo_cases_2) +@pytest.mark.parametrize( + "backend_name,cudagraph_mode,compilation_level,supported", combo_cases_2 +) def test_cudagraph_compilation_combo(combo_case): - backend_name, cudagraph_mode, compilation_level, supported\ - = combo_case + backend_name, cudagraph_mode, compilation_level, supported = combo_case env_vars = {"VLLM_USE_V1": "1", **backend_configs[backend_name].env_vars} @@ -117,13 +123,16 @@ def test_cudagraph_compilation_combo(combo_case): if not supported: stack.enter_context(pytest.raises(Exception)) - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - max_num_seqs=256, - trust_remote_code=True, - gpu_memory_utilization=0.45, - max_model_len=1024, - compilation_config=CompilationConfig( - level=compilation_level, cudagraph_mode=cudagraph_mode)) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + max_num_seqs=256, + trust_remote_code=True, + gpu_memory_utilization=0.45, + max_model_len=1024, + compilation_config=CompilationConfig( + level=compilation_level, cudagraph_mode=cudagraph_mode + ), + ) llm.generate(["Hello, my name is"] * 10) # when above code raises, `llm` may be undefined, so we need to catch that try: diff --git a/tests/v1/distributed/test_async_llm_dp.py b/tests/v1/distributed/test_async_llm_dp.py index cef0f362cff8..75314dc37303 100644 --- a/tests/v1/distributed/test_async_llm_dp.py +++ b/tests/v1/distributed/test_async_llm_dp.py @@ -30,34 +30,38 @@ async def generate( - engine: AsyncLLM, - request_id: str, - prompt: PromptType, - output_kind: RequestOutputKind, - max_tokens: int, - prompt_logprobs: Optional[int] = None, - data_parallel_rank: Optional[int] = None) -> tuple[int, str]: + engine: AsyncLLM, + request_id: str, + prompt: PromptType, + output_kind: RequestOutputKind, + max_tokens: int, + prompt_logprobs: Optional[int] = None, + data_parallel_rank: Optional[int] = None, +) -> tuple[int, str]: # Ensure generate doesn't complete too fast for cancellation test. await asyncio.sleep(0.2) count = 0 - sampling_params = SamplingParams(max_tokens=max_tokens, - ignore_eos=True, - output_kind=output_kind, - temperature=0, - prompt_logprobs=prompt_logprobs) - async for out in engine.generate(request_id=request_id, - prompt=prompt, - sampling_params=sampling_params, - data_parallel_rank=data_parallel_rank): - + sampling_params = SamplingParams( + max_tokens=max_tokens, + ignore_eos=True, + output_kind=output_kind, + temperature=0, + prompt_logprobs=prompt_logprobs, + ) + async for out in engine.generate( + request_id=request_id, + prompt=prompt, + sampling_params=sampling_params, + data_parallel_rank=data_parallel_rank, + ): num_tokens = len(out.outputs[0].token_ids) if output_kind == RequestOutputKind.DELTA: count += num_tokens else: count = num_tokens - await asyncio.sleep(0.) 
+ await asyncio.sleep(0.0) return count, request_id @@ -72,9 +76,9 @@ async def generate( @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"]) @pytest.mark.parametrize("async_scheduling", [True, False]) @pytest.mark.asyncio -async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str, - async_scheduling: bool): - +async def test_load( + output_kind: RequestOutputKind, data_parallel_backend: str, async_scheduling: bool +): stats_loggers = {} @dataclass @@ -85,25 +89,26 @@ class SimpleStatsLogger(StatLoggerBase): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): stats_loggers[engine_index] = self - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ): if iteration_stats: - self.finished_req_count += len( - iteration_stats.finished_requests) + self.finished_req_count += len(iteration_stats.finished_requests) def log_engine_initialized(self): self.init_count += 1 with ExitStack() as after: - prompt = "This is a test of data parallel" engine_args.data_parallel_backend = data_parallel_backend engine_args.async_scheduling = async_scheduling - engine = AsyncLLM.from_engine_args(engine_args, - stat_loggers=[SimpleStatsLogger]) + engine = AsyncLLM.from_engine_args( + engine_args, stat_loggers=[SimpleStatsLogger] + ) after.callback(engine.shutdown) NUM_REQUESTS = 100 @@ -116,20 +121,23 @@ def log_engine_initialized(self): for request_id in request_ids: tasks.append( asyncio.create_task( - generate(engine, request_id, prompt, output_kind, - NUM_EXPECTED_TOKENS))) + generate( + engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS + ) + ) + ) # Short sleep to ensure that requests are distributed. await asyncio.sleep(0.01) # Confirm that we got all the EXPECTED tokens from the requests. 
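test_load above fans out one task per request and then, as the next hunk shows, collects them with asyncio.wait(..., return_when=asyncio.FIRST_EXCEPTION), cancelling whatever is still pending once something fails. A minimal sketch of that pattern with a dummy workload standing in for engine.generate:

import asyncio

async def fake_generate(request_id: str, num_tokens: int):
    # Stand-in for the real generate() coroutine.
    await asyncio.sleep(0.01)
    return num_tokens, request_id

async def main():
    expected = 10
    tasks = [
        asyncio.create_task(fake_generate(f"request-{i}", expected))
        for i in range(100)
    ]
    # Stop waiting as soon as any task raises; the rest get cancelled.
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
    for task in pending:
        task.cancel()
    for task in done:
        num_generated, request_id = await task
        assert num_generated == expected, (
            f"{request_id} produced {num_generated} but expected {expected}"
        )

asyncio.run(main())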
- done, pending = await asyncio.wait(tasks, - return_when=asyncio.FIRST_EXCEPTION) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) for task in pending: task.cancel() for task in done: num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( f"{request_id} generated {num_generated_tokens} but " - f"expected {NUM_EXPECTED_TOKENS}") + f"expected {NUM_EXPECTED_TOKENS}" + ) assert not engine.output_processor.has_unfinished_requests() @@ -153,5 +161,6 @@ def log_engine_initialized(self): for sl in stats_loggers.values(): slogger: SimpleStatsLogger = sl - assert slogger.finished_req_count > NUM_REQUESTS // ( - DP_SIZE + 1), f"requests are imbalanced: {stats_loggers}" + assert slogger.finished_req_count > NUM_REQUESTS // (DP_SIZE + 1), ( + f"requests are imbalanced: {stats_loggers}" + ) diff --git a/tests/v1/distributed/test_external_lb_dp.py b/tests/v1/distributed/test_external_lb_dp.py index 862a76f3c4e2..912f8cffe7f6 100644 --- a/tests/v1/distributed/test_external_lb_dp.py +++ b/tests/v1/distributed/test_external_lb_dp.py @@ -26,12 +26,14 @@ class ExternalLBServerManager: """Manages data parallel vLLM server instances for external load balancer testing.""" - def __init__(self, - model_name: str, - dp_size: int, - api_server_count: int, - base_server_args: list, - tp_size: int = TP_SIZE): + def __init__( + self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + tp_size: int = TP_SIZE, + ): self.model_name = model_name self.dp_size = dp_size self.tp_size = tp_size @@ -47,20 +49,22 @@ def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: server_args = self.base_server_args.copy() # Add external LB specific arguments - server_args.extend([ - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-rank", - str(rank), - "--data-parallel-size-local", - "1", - "--tensor-parallel-size", - str(self.tp_size), - "--port", - str(8000 + rank), # Different port for each rank - "--api-server-count", - str(self.api_server_count), - ]) + server_args.extend( + [ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-rank", + str(rank), + "--data-parallel-size-local", + "1", + "--tensor-parallel-size", + str(self.tp_size), + "--port", + str(8000 + rank), # Different port for each rank + "--api-server-count", + str(self.api_server_count), + ] + ) # Use a thread to start each server to allow parallel initialization def start_server(r: int, sargs: list[str]): @@ -71,25 +75,24 @@ def start_server(r: int, sargs: list[str]): sargs, auto_port=False, env_dict={ - "VLLM_SERVER_DEV_MODE": - "1", - current_platform.device_control_env_var: - ",".join( - str( - current_platform. 
- device_id_to_physical_device_id(i)) - for i in range(r * TP_SIZE, (r + 1) * TP_SIZE)) - }) + "VLLM_SERVER_DEV_MODE": "1", + current_platform.device_control_env_var: ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(r * TP_SIZE, (r + 1) * TP_SIZE) + ), + }, + ) server.__enter__() - print(f"Server rank {r} started successfully with " - f"{self.api_server_count} API servers") + print( + f"Server rank {r} started successfully with " + f"{self.api_server_count} API servers" + ) self.servers.append((server, sargs)) except Exception as e: print(f"Failed to start server rank {r}: {e}") raise - thread = threading.Thread(target=start_server, - args=(rank, server_args)) + thread = threading.Thread(target=start_server, args=(rank, server_args)) thread.start() self.server_threads.append(thread) @@ -132,9 +135,9 @@ def default_server_args(): @pytest.fixture(scope="module", params=[1, 4]) def server_manager(request, default_server_args): api_server_count = request.param - server_manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE, - api_server_count, - default_server_args) + server_manager = ExternalLBServerManager( + MODEL_NAME, DP_SIZE, api_server_count, default_server_args + ) with server_manager: yield server_manager @@ -174,18 +177,16 @@ def test_external_lb_server_info(server_manager): # `n_reqs` is set so that there is a good chance each server # receives at least one request n_reqs = 2 * api_server_count * api_server_count - parallel_configs = [ - _get_parallel_config(server) for _ in range(n_reqs) - ] - api_process_counts = [ - c["_api_process_count"] for c in parallel_configs - ] + parallel_configs = [_get_parallel_config(server) for _ in range(n_reqs)] + api_process_counts = [c["_api_process_count"] for c in parallel_configs] api_process_ranks = [c["_api_process_rank"] for c in parallel_configs] - assert all(c == api_server_count - for c in api_process_counts), api_process_counts - assert all(0 <= r < api_server_count - for r in api_process_ranks), api_process_ranks + assert all(c == api_server_count for c in api_process_counts), ( + api_process_counts + ) + assert all(0 <= r < api_server_count for r in api_process_ranks), ( + api_process_ranks + ) @pytest.mark.asyncio @@ -193,16 +194,15 @@ def test_external_lb_server_info(server_manager): "model_name", [MODEL_NAME], ) -async def test_external_lb_single_completion(clients: list[ - openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], - model_name: str) -> None: - +async def test_external_lb_single_completion( + clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: async def make_request(client: openai.AsyncOpenAI): completion = await client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=10, - temperature=1.0) + model=model_name, prompt="Hello, my name is", max_tokens=10, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -256,11 +256,14 @@ async def make_request(client: openai.AsyncOpenAI): _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) print( f"Successfully completed external LB test with {len(clients)} servers " - f"(API server count: {api_server_count})") + f"(API server count: 
{api_server_count})" + ) @pytest.mark.asyncio @@ -268,9 +271,11 @@ async def make_request(client: openai.AsyncOpenAI): "model_name", [MODEL_NAME], ) -async def test_external_lb_completion_streaming(clients: list[ - openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], - model_name: str) -> None: +async def test_external_lb_completion_streaming( + clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: prompt = "What is an LLM?" async def make_streaming_request(client: openai.AsyncOpenAI): @@ -284,11 +289,9 @@ async def make_streaming_request(client: openai.AsyncOpenAI): single_output = single_completion.choices[0].text # Perform the streaming request - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -299,16 +302,15 @@ async def make_streaming_request(client: openai.AsyncOpenAI): last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." + assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." 
+ ) return True # Indicate success for this request # Test single request to each server @@ -324,10 +326,7 @@ async def make_streaming_request(client: openai.AsyncOpenAI): all_tasks = [] for i, client in enumerate(clients): - tasks = [ - make_streaming_request(client) - for _ in range(num_requests_per_server) - ] + tasks = [make_streaming_request(client) for _ in range(num_requests_per_server)] all_tasks.extend(tasks) results = await asyncio.gather(*all_tasks) @@ -339,10 +338,7 @@ async def make_streaming_request(client: openai.AsyncOpenAI): # Second burst of streaming requests all_tasks = [] for i, client in enumerate(clients): - tasks = [ - make_streaming_request(client) - for _ in range(num_requests_per_server) - ] + tasks = [make_streaming_request(client) for _ in range(num_requests_per_server)] all_tasks.extend(tasks) results = await asyncio.gather(*all_tasks) @@ -351,7 +347,11 @@ async def make_streaming_request(client: openai.AsyncOpenAI): _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) - print(f"Successfully completed external LB streaming test with " - f"{len(clients)} servers (API server count: {api_server_count})") + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed external LB streaming test with " + f"{len(clients)} servers (API server count: {api_server_count})" + ) diff --git a/tests/v1/distributed/test_hybrid_lb_dp.py b/tests/v1/distributed/test_hybrid_lb_dp.py index 21d8009a6dbb..aa25130752a4 100644 --- a/tests/v1/distributed/test_hybrid_lb_dp.py +++ b/tests/v1/distributed/test_hybrid_lb_dp.py @@ -28,17 +28,19 @@ class HybridLBServerManager: - """Manages hybrid data parallel vLLM server instances where each node - runs a single logical API server that balances requests only to the + """Manages hybrid data parallel vLLM server instances where each node + runs a single logical API server that balances requests only to the DP engines running on that same node.""" - def __init__(self, - model_name: str, - dp_size: int, - api_server_count: int, - base_server_args: list, - dp_size_local: int = DP_SIZE_LOCAL, - tp_size: int = TP_SIZE): + def __init__( + self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + dp_size_local: int = DP_SIZE_LOCAL, + tp_size: int = TP_SIZE, + ): self.model_name = model_name self.dp_size = dp_size self.dp_size_local = dp_size_local @@ -59,25 +61,27 @@ def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: start_rank = node_id * self.dp_size_local # Add hybrid LB specific arguments - server_args.extend([ - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - str(self.dp_size_local), - "--data-parallel-start-rank", - str(start_rank), - "--data-parallel-hybrid-lb", # Enable hybrid LB mode - "--tensor-parallel-size", - str(self.tp_size), - "--port", - str(8000 + node_id), # Different port for each node - "--api-server-count", - str(self.api_server_count), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + server_args.extend( + [ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_size_local), + "--data-parallel-start-rank", + str(start_rank), + "--data-parallel-hybrid-lb", # Enable hybrid LB mode + "--tensor-parallel-size", + str(self.tp_size), + "--port", + str(8000 + node_id), # Different 
port for each node + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ] + ) # Use a thread to start each server to allow parallel initialization def start_server(node: int, sargs: list[str]): @@ -93,26 +97,25 @@ def start_server(node: int, sargs: list[str]): sargs, auto_port=False, env_dict={ - "VLLM_SERVER_DEV_MODE": - "1", - current_platform.device_control_env_var: - ",".join( - str( - current_platform. - device_id_to_physical_device_id(i)) - for i in range(gpu_start, gpu_end)) - }) + "VLLM_SERVER_DEV_MODE": "1", + current_platform.device_control_env_var: ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(gpu_start, gpu_end) + ), + }, + ) server.__enter__() - print(f"Hybrid LB node {node} started successfully with " - f"{self.dp_size_local} local DP ranks and " - f"{self.api_server_count} API servers") + print( + f"Hybrid LB node {node} started successfully with " + f"{self.dp_size_local} local DP ranks and " + f"{self.api_server_count} API servers" + ) self.servers.append((server, sargs)) except Exception as e: print(f"Failed to start hybrid LB node {node}: {e}") raise - thread = threading.Thread(target=start_server, - args=(node_id, server_args)) + thread = threading.Thread(target=start_server, args=(node_id, server_args)) thread.start() self.server_threads.append(thread) @@ -155,10 +158,14 @@ def default_server_args(): @pytest.fixture(scope="module", params=[1, 4]) def server_manager(request, default_server_args): api_server_count = request.param - server_manager = HybridLBServerManager(MODEL_NAME, DP_SIZE, - api_server_count, - default_server_args, DP_SIZE_LOCAL, - TP_SIZE) + server_manager = HybridLBServerManager( + MODEL_NAME, + DP_SIZE, + api_server_count, + default_server_args, + DP_SIZE_LOCAL, + TP_SIZE, + ) with server_manager: yield server_manager @@ -198,18 +205,16 @@ def test_hybrid_dp_server_info(server_manager): # `n_reqs` is set so that there is a good chance each server # receives at least one request n_reqs = 2 * api_server_count * api_server_count - parallel_configs = [ - _get_parallel_config(server) for _ in range(n_reqs) - ] - api_process_counts = [ - c["_api_process_count"] for c in parallel_configs - ] + parallel_configs = [_get_parallel_config(server) for _ in range(n_reqs)] + api_process_counts = [c["_api_process_count"] for c in parallel_configs] api_process_ranks = [c["_api_process_rank"] for c in parallel_configs] - assert all(c == api_server_count - for c in api_process_counts), api_process_counts - assert all(0 <= r < api_server_count - for r in api_process_ranks), api_process_ranks + assert all(c == api_server_count for c in api_process_counts), ( + api_process_counts + ) + assert all(0 <= r < api_server_count for r in api_process_ranks), ( + api_process_ranks + ) @pytest.mark.asyncio @@ -217,17 +222,15 @@ def test_hybrid_dp_server_info(server_manager): "model_name", [MODEL_NAME], ) -async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI], - servers: list[tuple[RemoteOpenAIServer, - list[str]]], - model_name: str) -> None: - +async def test_hybrid_lb_completion( + clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: async def make_request(client: openai.AsyncOpenAI): completion = await client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=1.0) + model=model_name, prompt="Hello, my name is", 
max_tokens=5, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -251,9 +254,7 @@ async def make_request(client: openai.AsyncOpenAI): for i, client in enumerate(clients): result = await make_request(client) assert result is not None - print( - f"Hybrid LB node {i} handled single completion request successfully" - ) + print(f"Hybrid LB node {i} handled single completion request successfully") await asyncio.sleep(0.5) @@ -284,8 +285,10 @@ async def make_request(client: openai.AsyncOpenAI): _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) print( f"Successfully completed hybrid LB test with {len(clients)} nodes " f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})" @@ -302,9 +305,11 @@ async def make_request(client: openai.AsyncOpenAI): "model_name", [MODEL_NAME], ) -async def test_hybrid_lb_completion_streaming(clients: list[ - openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], - model_name: str) -> None: +async def test_hybrid_lb_completion_streaming( + clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: prompt = "What is an LLM?" async def make_streaming_request(client: openai.AsyncOpenAI): @@ -318,11 +323,9 @@ async def make_streaming_request(client: openai.AsyncOpenAI): single_output = single_completion.choices[0].text # Perform the streaming request - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -333,25 +336,22 @@ async def make_streaming_request(client: openai.AsyncOpenAI): last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." + assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." 
+ ) return True # Indicate success for this request # Test single request to each node for i, client in enumerate(clients): result = await make_streaming_request(client) assert result is not None - print( - f"Hybrid LB node {i} handled single streaming request successfully" - ) + print(f"Hybrid LB node {i} handled single streaming request successfully") await asyncio.sleep(0.5) @@ -382,11 +382,15 @@ async def make_streaming_request(client: openai.AsyncOpenAI): _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) - print(f"Successfully completed hybrid LB streaming test with " - f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, " - f"API server count: {api_server_count})") + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed hybrid LB streaming test with " + f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, " + f"API server count: {api_server_count})" + ) # Check request balancing within each node for i, (server, _) in enumerate(servers): diff --git a/tests/v1/distributed/test_internal_lb_dp.py b/tests/v1/distributed/test_internal_lb_dp.py index 3f9defd13dea..452d3682e65d 100644 --- a/tests/v1/distributed/test_internal_lb_dp.py +++ b/tests/v1/distributed/test_internal_lb_dp.py @@ -31,66 +31,71 @@ class MultinodeInternalLBServerManager: """Manages multi-node data parallel vLLM server instances for internal load balancer testing using --headless mode.""" - def __init__(self, - model_name: str, - dp_size: int, - api_server_count: int, - base_server_args: list, - dp_per_node: int = 1, - tp_size: int = TP_SIZE): + def __init__( + self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + dp_per_node: int = 1, + tp_size: int = TP_SIZE, + ): self.model_name = model_name self.dp_size = dp_size self.dp_per_node = dp_per_node self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[Optional[tuple[RemoteOpenAIServer, - list[str]]]] = [None] * (dp_size // - dp_per_node) + self.servers: list[Optional[tuple[RemoteOpenAIServer, list[str]]]] = [None] * ( + dp_size // dp_per_node + ) self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: """Start all server instances for multi-node internal LB mode.""" - for server_idx, rank in enumerate( - range(0, self.dp_size, self.dp_per_node)): + for server_idx, rank in enumerate(range(0, self.dp_size, self.dp_per_node)): # Create server args for this specific rank server_args = self.base_server_args.copy() if rank == 0: # Head node - runs API server and first DP rank - server_args.extend([ - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - str(self.dp_per_node), - "--tensor-parallel-size", - str(self.tp_size), - "--port", - "8000", # Single endpoint for all requests - "--api-server-count", - str(self.api_server_count), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + server_args.extend( + [ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_per_node), + "--tensor-parallel-size", + str(self.tp_size), + "--port", + "8000", # Single endpoint for all requests + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", 
+ "13345", + ] + ) else: # Secondary nodes - run in headless mode - server_args.extend([ - "--headless", - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - str(self.dp_per_node), - "--data-parallel-start-rank", - str(rank), - "--tensor-parallel-size", - str(self.tp_size), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + server_args.extend( + [ + "--headless", + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_per_node), + "--data-parallel-start-rank", + str(rank), + "--tensor-parallel-size", + str(self.tp_size), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ] + ) # Use a thread to start each server to allow parallel initialization def start_server(sidx: int, r: int, sargs: list[str]): @@ -102,20 +107,19 @@ def start_server(sidx: int, r: int, sargs: list[str]): sargs, auto_port=False, env_dict={ - "VLLM_SERVER_DEV_MODE": - "1", - current_platform.device_control_env_var: - ",".join( - str( - current_platform. - device_id_to_physical_device_id(i)) - for i in range(r, r + gpus_per_node)) - }) + "VLLM_SERVER_DEV_MODE": "1", + current_platform.device_control_env_var: ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(r, r + gpus_per_node) + ), + }, + ) server.__enter__() if r == 0: print( f"Head node (rank {r}) started successfully with " - f"{self.api_server_count} API servers") + f"{self.api_server_count} API servers" + ) else: print(f"Headless node (rank {r}) started successfully") self.servers[sidx] = (server, sargs) @@ -124,8 +128,9 @@ def start_server(sidx: int, r: int, sargs: list[str]): traceback.print_exc() raise - thread = threading.Thread(target=start_server, - args=(server_idx, rank, server_args)) + thread = threading.Thread( + target=start_server, args=(server_idx, rank, server_args) + ) thread.start() self.server_threads.append(thread) @@ -157,19 +162,20 @@ class APIOnlyServerManager: """Manages API-only server (Node 0) and headless engines server (Node 1) for testing separated API server and engine configuration.""" - def __init__(self, - model_name: str, - dp_size: int, - api_server_count: int, - base_server_args: list, - tp_size: int = TP_SIZE): + def __init__( + self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + tp_size: int = TP_SIZE, + ): self.model_name = model_name self.dp_size = dp_size self.tp_size = tp_size self.api_server_count = api_server_count self.base_server_args = base_server_args - self.servers: list[Optional[tuple[RemoteOpenAIServer, - list[str]]]] = [None] * 2 + self.servers: list[Optional[tuple[RemoteOpenAIServer, list[str]]]] = [None] * 2 self.server_threads: list[threading.Thread] = [] def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: @@ -177,38 +183,42 @@ def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: # Start API-only server (Node 0) - no engines, only API server api_server_args = self.base_server_args.copy() - api_server_args.extend([ - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - "0", # No engines on this node - "--tensor-parallel-size", - str(self.tp_size), - "--port", - "8000", - "--api-server-count", - str(self.api_server_count), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + api_server_args.extend( + [ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + "0", # No engines on this 
node + "--tensor-parallel-size", + str(self.tp_size), + "--port", + "8000", + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ] + ) # Start headless engines server (Node 1) - all engines, no API server engines_server_args = self.base_server_args.copy() - engines_server_args.extend([ - "--headless", - "--data-parallel-size", - str(self.dp_size), - "--data-parallel-size-local", - str(self.dp_size), # All engines on this node - "--tensor-parallel-size", - str(self.tp_size), - "--data-parallel-address", - "127.0.0.1", - "--data-parallel-rpc-port", - "13345", - ]) + engines_server_args.extend( + [ + "--headless", + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_size), # All engines on this node + "--tensor-parallel-size", + str(self.tp_size), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ] + ) # Use threads to start both servers in parallel def start_api_server(): @@ -220,10 +230,13 @@ def start_api_server(): env_dict={ "VLLM_SERVER_DEV_MODE": "1", # No GPUs needed for API-only server - }) + }, + ) server.__enter__() - print(f"API-only server started successfully with " - f"{self.api_server_count} API servers") + print( + f"API-only server started successfully with " + f"{self.api_server_count} API servers" + ) self.servers[0] = (server, api_server_args) except Exception as e: print(f"Failed to start API-only server: {e}") @@ -236,16 +249,17 @@ def start_engines_server(): engines_server_args, auto_port=False, env_dict={ - current_platform.device_control_env_var: - ",".join( - str( - current_platform. - device_id_to_physical_device_id(i)) - for i in range(self.dp_size * self.tp_size)) - }) + current_platform.device_control_env_var: ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(self.dp_size * self.tp_size) + ) + }, + ) server.__enter__() - print(f"Headless engines server started successfully with " - f"{self.dp_size} engines") + print( + f"Headless engines server started successfully with " + f"{self.dp_size} engines" + ) self.servers[1] = (server, engines_server_args) except Exception as e: print(f"Failed to start headless engines server: {e}") @@ -301,11 +315,14 @@ def default_server_args(): @pytest.fixture(scope="module", params=[1, 4]) def server_manager(request, default_server_args): api_server_count = request.param - server_manager = MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE, - api_server_count, - default_server_args, - DP_SIZE // NUM_NODES, - TP_SIZE) + server_manager = MultinodeInternalLBServerManager( + MODEL_NAME, + DP_SIZE, + api_server_count, + default_server_args, + DP_SIZE // NUM_NODES, + TP_SIZE, + ) with server_manager: yield server_manager @@ -320,8 +337,9 @@ def servers(server_manager): def api_only_servers(request, default_server_args): """Fixture for API-only server + headless engines configuration.""" api_server_count = request.param - with APIOnlyServerManager(MODEL_NAME, DP_SIZE, api_server_count, - default_server_args, TP_SIZE) as server_list: + with APIOnlyServerManager( + MODEL_NAME, DP_SIZE, api_server_count, default_server_args, TP_SIZE + ) as server_list: yield server_list @@ -335,8 +353,7 @@ async def client(servers: list[tuple[RemoteOpenAIServer, list[str]]]): @pytest_asyncio.fixture -async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer, - list[str]]]): +async def api_only_client(api_only_servers: 
list[tuple[RemoteOpenAIServer, list[str]]]): """Client fixture for API-only server configuration.""" # Connect to the API-only server (first server in the list) api_server = api_only_servers[0][0] @@ -360,16 +377,12 @@ def test_multinode_dp_server_info(server_manager): # `n_reqs` is set so that there is a good chance each server # receives at least one request n_reqs = 2 * api_server_count * api_server_count - parallel_configs = [ - _get_parallel_config(head_server) for _ in range(n_reqs) - ] + parallel_configs = [_get_parallel_config(head_server) for _ in range(n_reqs)] api_process_counts = [c["_api_process_count"] for c in parallel_configs] api_process_ranks = [c["_api_process_rank"] for c in parallel_configs] - assert all(c == api_server_count - for c in api_process_counts), api_process_counts - assert all(0 <= r < api_server_count - for r in api_process_ranks), api_process_ranks + assert all(c == api_server_count for c in api_process_counts), api_process_counts + assert all(0 <= r < api_server_count for r in api_process_ranks), api_process_ranks @pytest.mark.asyncio @@ -377,17 +390,15 @@ def test_multinode_dp_server_info(server_manager): "model_name", [MODEL_NAME], ) -async def test_multinode_dp_completion(client: openai.AsyncOpenAI, - servers: list[tuple[RemoteOpenAIServer, - list[str]]], - model_name: str) -> None: - +async def test_multinode_dp_completion( + client: openai.AsyncOpenAI, + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: async def make_request(): completion = await client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=1.0) + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -410,9 +421,7 @@ async def make_request(): # Test single request result = await make_request() assert result is not None - print( - "Multi-node internal LB handled single completion request successfully" - ) + print("Multi-node internal LB handled single completion request successfully") await asyncio.sleep(0.5) @@ -441,10 +450,14 @@ async def make_request(): _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) - print(f"Successfully completed multi-node internal LB test with " - f"{len(servers)} DP ranks (API server count: {api_server_count})") + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed multi-node internal LB test with " + f"{len(servers)} DP ranks (API server count: {api_server_count})" + ) # Check request balancing via Prometheus metrics head_server = servers[0][0] @@ -456,11 +469,11 @@ async def make_request(): "model_name", [MODEL_NAME], ) -async def test_multinode_dp_completion_streaming(client: openai.AsyncOpenAI, - servers: list[ - tuple[RemoteOpenAIServer, - list[str]]], - model_name: str) -> None: +async def test_multinode_dp_completion_streaming( + client: openai.AsyncOpenAI, + servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: prompt = "What is an LLM?" 
async def make_streaming_request(): @@ -474,11 +487,9 @@ async def make_streaming_request(): single_output = single_completion.choices[0].text # Perform the streaming request - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -489,23 +500,21 @@ async def make_streaming_request(): last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." + assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." + ) return True # Indicate success for this request # Test single streaming request result = await make_streaming_request() assert result is not None - print( - "Multi-node internal LB handled single streaming request successfully") + print("Multi-node internal LB handled single streaming request successfully") await asyncio.sleep(0.5) @@ -535,10 +544,14 @@ async def make_streaming_request(): _, server_args = servers[0] api_server_count = ( - server_args.count('--api-server-count') - and server_args[server_args.index('--api-server-count') + 1] or 1) - print(f"Successfully completed multi-node internal LB streaming test with " - f"{len(servers)} DP ranks (API server count: {api_server_count})") + server_args.count("--api-server-count") + and server_args[server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed multi-node internal LB streaming test with " + f"{len(servers)} DP ranks (API server count: {api_server_count})" + ) # Check request balancing via Prometheus metrics head_server = servers[0][0] @@ -551,17 +564,16 @@ async def make_streaming_request(): [MODEL_NAME], ) async def test_api_only_multinode_dp_completion( - api_only_client: openai.AsyncOpenAI, - api_only_servers: list[tuple[RemoteOpenAIServer, - list[str]]], model_name: str) -> None: + api_only_client: openai.AsyncOpenAI, + api_only_servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: """Test API-only server with all engines on separate headless server.""" async def make_request(): completion = await api_only_client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=1.0) + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -614,11 +626,14 @@ async def make_request(): api_server, api_server_args = api_only_servers[0] api_server_count = ( - api_server_args.count('--api-server-count') - and 
api_server_args[api_server_args.index('--api-server-count') + 1] - or 1) - print(f"Successfully completed API-only multi-node test with {DP_SIZE} " - f"engines on headless server (API server count: {api_server_count})") + api_server_args.count("--api-server-count") + and api_server_args[api_server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed API-only multi-node test with {DP_SIZE} " + f"engines on headless server (API server count: {api_server_count})" + ) # Check request balancing via Prometheus metrics check_request_balancing(api_server, DP_SIZE) @@ -630,9 +645,10 @@ async def make_request(): [MODEL_NAME], ) async def test_api_only_multinode_dp_completion_streaming( - api_only_client: openai.AsyncOpenAI, - api_only_servers: list[tuple[RemoteOpenAIServer, - list[str]]], model_name: str) -> None: + api_only_client: openai.AsyncOpenAI, + api_only_servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str, +) -> None: """Test API-only server streaming with all engines on separate headless server.""" prompt = "What is an LLM?" @@ -648,11 +664,9 @@ async def make_streaming_request(): single_output = single_completion.choices[0].text # Perform the streaming request - stream = await api_only_client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await api_only_client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -663,16 +677,15 @@ async def make_streaming_request(): last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." + assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." 
+ ) return True # Indicate success for this request # Test single streaming request @@ -707,11 +720,14 @@ async def make_streaming_request(): _, api_server_args = api_only_servers[0] api_server_count = ( - api_server_args.count('--api-server-count') - and api_server_args[api_server_args.index('--api-server-count') + 1] - or 1) - print(f"Successfully completed API-only streaming test with {DP_SIZE} " - f"engines on headless server (API server count: {api_server_count})") + api_server_args.count("--api-server-count") + and api_server_args[api_server_args.index("--api-server-count") + 1] + or 1 + ) + print( + f"Successfully completed API-only streaming test with {DP_SIZE} " + f"engines on headless server (API server count: {api_server_count})" + ) # Check request balancing via Prometheus metrics api_server = api_only_servers[0][0] diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index 5022347a87a4..5f26c2f1c651 100644 --- a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -14,8 +14,10 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend): prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:" if attn_backend == "FLASHINFER": - pytest.skip("This test is failing with FlashInfer backend and " - "needs investigation. See issue #25679.") + pytest.skip( + "This test is failing with FlashInfer backend and " + "needs investigation. See issue #25679." + ) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") diff --git a/tests/v1/e2e/test_context_length.py b/tests/v1/e2e/test_context_length.py index 67a6c7be4432..4076b1fa0615 100644 --- a/tests/v1/e2e/test_context_length.py +++ b/tests/v1/e2e/test_context_length.py @@ -4,15 +4,22 @@ end-to-end tests for context length corner cases of vLLM v1 model runner versus HuggingFace's transformers. -This test verifies the following behavior: allow a prefill that fills the -model's maximum context length and then request a single new token. +This test verifies the following behavior: allow prefill and decodes on the +model's maximum context length ``max_model_len`` and get one more token. Test strategy -- Build a textual prompt that tokenizes to exactly ``max_model_len`` tokens. -- Run vLLM generation requesting a single new token (max_tokens=1). -- Run HF generation on the same prompt requesting a single token too. +- Build a prompt consisting of exactly ``prompt_len`` tokens. +- Run vLLM generation requesting ``max_tokens`` new tokens. +- Run HF generation on the same prompt requesting the same number of tokens. - Assert both return the same number of generated tokens and the same ids. +Test cases +- Prefill a prompt of ``max_model_len`` (2048) and request a single token which +will be sampled after the prefill (context length ``max_model_len``). +- Prefill a prompt of ``max_model_len`` - 1 (2047) and request two tokens where +the 1st will be sampled after the prefill and the 2nd after the first decode +(context length ``max_model_len``). 
+ """ import pytest @@ -27,11 +34,16 @@ @create_new_process_for_each_test() @pytest.mark.parametrize("model", ["JackFram/llama-160m"]) -@pytest.mark.parametrize("max_model_len", [2048]) -@pytest.mark.parametrize("max_tokens", [1]) -def test_prefill_max_context_length( +@pytest.mark.parametrize( + "prompt_len, max_tokens", + [ + (2048, 1), # prompt_len = max_model_len + (2047, 2), # prompt_len = max_model_len - 1 + ], +) +def test_max_context_length( model: str, - max_model_len: int, + prompt_len: int, max_tokens: int, ) -> None: """Compare vLLM and HuggingFace when the prompt already fills the @@ -42,8 +54,8 @@ def test_prefill_max_context_length( single token when given the same inputs. """ - # Construct a prompt of size max_model_len - prompt_ids = [[43] * max_model_len] + # Construct a prompt of size prompt_len + prompt_ids = [[43] * prompt_len] # Generate max_tokens new tokens deterministically. sampling_params = [ @@ -54,6 +66,7 @@ def test_prefill_max_context_length( llm = LLM( model=model, tokenizer=model, + max_model_len=2048, max_num_seqs=1, tensor_parallel_size=1, ) @@ -79,7 +92,10 @@ def test_prefill_max_context_length( ) # HF returns the prompt + generated tokens. Slice off the prompt. - hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):] + hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]) :] + + # check that exactly max_tokens tokens were generated with vLLM and HF + assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens # check that vLLM outputs (token ids) match HF outputs # Note: for simplicity don't pass detokenized string diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index 5b0c15472251..c9018ee177e8 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -26,12 +26,14 @@ class TestConfig: [ "bigcode/starcoder2-3b", # sliding window only "google/gemma-3-1b-it", # sliding window + full attention - ]) + ], +) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False]) -def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed, - disable_hybrid_kv_cache_manager): +def test_sliding_window_retrieval( + monkeypatch, model, batch_size, seed, disable_hybrid_kv_cache_manager +): """ The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then asks for value of one of them (which is outside the sliding window). 
@@ -44,33 +46,38 @@ def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed, test_config = model_config[model] llm = LLM( - model=model, - disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager) + model=model, disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager + ) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) - prompts, answer, indices = prep_prompts(batch_size, - ln_range=test_config.ln_range) + prompts, answer, indices = prep_prompts( + batch_size, ln_range=test_config.ln_range + ) check_length(prompts, llm, test_config.sliding_window) # Fresh generation responses = llm.generate(prompts, sampling_params) - check_answers(indices, - answer, - [response.outputs[0].text for response in responses], - accept_rate=1.0) + check_answers( + indices, + answer, + [response.outputs[0].text for response in responses], + accept_rate=1.0, + ) # Re-generate with the same prompts to test prefix caching responses = llm.generate(prompts, sampling_params) - check_answers(indices, - answer, - [response.outputs[0].text for response in responses], - accept_rate=1.0) + check_answers( + indices, + answer, + [response.outputs[0].text for response in responses], + accept_rate=1.0, + ) def check_length(prompts: list[str], llm: LLM, sliding_window: int): """ - Check if the prompt length is valid, i.e., longer than the sliding window + Check if the prompt length is valid, i.e., longer than the sliding window size and shorter than the model's max length. Args: @@ -80,9 +87,9 @@ def check_length(prompts: list[str], llm: LLM, sliding_window: int): """ tokenizer = llm.get_tokenizer() max_model_len = llm.llm_engine.model_config.max_model_len - assert any( - len(tokenizer.encode(prompt)) > sliding_window - for prompt in prompts), "Prompt is too short for test" - assert all( - len(tokenizer.encode(prompt)) <= max_model_len - for prompt in prompts), "Prompt is too long for test" + assert any(len(tokenizer.encode(prompt)) > sliding_window for prompt in prompts), ( + "Prompt is too short for test" + ) + assert all(len(tokenizer.encode(prompt)) <= max_model_len for prompt in prompts), ( + "Prompt is too long for test" + ) diff --git a/tests/v1/e2e/test_kv_sharing_fast_prefill.py b/tests/v1/e2e/test_kv_sharing_fast_prefill.py index 6bc9b2b1d82d..b9052d8a58b8 100644 --- a/tests/v1/e2e/test_kv_sharing_fast_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_fast_prefill.py @@ -76,7 +76,9 @@ def test_kv_sharing_fast_prefill( # managing buffers for cudagraph cudagraph_copy_inputs=True, level=CompilationLevel.PIECEWISE - if not enforce_eager else CompilationLevel.NO_COMPILATION) + if not enforce_eager + else CompilationLevel.NO_COMPILATION, + ) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -94,21 +96,21 @@ def test_kv_sharing_fast_prefill( cleanup(llm, compilation_config) - llm = LLM(model="google/gemma-3n-E2B-it", - enforce_eager=enforce_eager, - compilation_config=compilation_config, - seed=SEED, - kv_sharing_fast_prefill=True) + llm = LLM( + model="google/gemma-3n-E2B-it", + enforce_eager=enforce_eager, + compilation_config=compilation_config, + seed=SEED, + kv_sharing_fast_prefill=True, + ) optimized_responses = llm.generate(test_prompts, sampling_params) cleanup(llm, compilation_config) misses = 0 - for ref_response, optimized_response in zip(ref_responses, - optimized_responses): - if ref_response.outputs[0].text != optimized_response.outputs[ - 0].text: + for ref_response, optimized_response in zip(ref_responses, optimized_responses): + if 
ref_response.outputs[0].text != optimized_response.outputs[0].text: misses += 1 assert misses == 0 diff --git a/tests/v1/e2e/test_min_tokens.py b/tests/v1/e2e/test_min_tokens.py index f013425cb59d..f15982b7e5f3 100644 --- a/tests/v1/e2e/test_min_tokens.py +++ b/tests/v1/e2e/test_min_tokens.py @@ -46,29 +46,36 @@ def __init__( self.expected_exact_len = expected_exact_len def __str__(self): - return (f"{self.name}: min={self.min_tokens}, " - f"max={self.max_tokens}, stop={self.stop}") + return ( + f"{self.name}: min={self.min_tokens}, " + f"max={self.max_tokens}, stop={self.stop}" + ) # Test scenarios covering all critical cases MIN_TOKENS_TEST_CASES = [ # === BASIC FUNCTIONALITY (should work) === - MinTokensTestCase(name="basic_min_tokens_no_stop", - min_tokens=8, - max_tokens=20, - stop=None, - expected_min_len=8), - MinTokensTestCase(name="min_tokens_zero", - min_tokens=0, - max_tokens=10, - stop=None, - expected_min_len=0), - MinTokensTestCase(name="min_equals_max_no_stop", - min_tokens=15, - max_tokens=15, - stop=None, - expected_exact_len=15), - + MinTokensTestCase( + name="basic_min_tokens_no_stop", + min_tokens=8, + max_tokens=20, + stop=None, + expected_min_len=8, + ), + MinTokensTestCase( + name="min_tokens_zero", + min_tokens=0, + max_tokens=10, + stop=None, + expected_min_len=0, + ), + MinTokensTestCase( + name="min_equals_max_no_stop", + min_tokens=15, + max_tokens=15, + stop=None, + expected_exact_len=15, + ), # === STOP STRINGS WITH MIN_TOKENS === # These tests expose the detokenizer bug where stop strings # bypass min_tokens @@ -94,9 +101,11 @@ def __str__(self): expected_min_len=5, ), marks=pytest.mark.xfail( - reason=("Known bug #21987: stop strings bypass min_tokens " - "(fixed by PR #22014)"), - strict=False), + reason=( + "Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)" + ), + strict=False, + ), id="min_tokens_with_comprehensive_stops", ), pytest.param( @@ -108,12 +117,13 @@ def __str__(self): expected_min_len=3, ), marks=pytest.mark.xfail( - reason=("Known bug #21987: stop strings bypass min_tokens " - "(fixed by PR #22014)"), - strict=False), + reason=( + "Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)" + ), + strict=False, + ), id="min_tokens_with_simple_char_stop", ), - # === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) === # These test the MinTokensLogitsProcessor handling of EOS tokens pytest.param( @@ -125,26 +135,26 @@ def __str__(self): expected_exact_len=20, ), marks=pytest.mark.xfail( - reason= - ("Potential logits-processor bug: EOS tokens may bypass min_tokens" - ), + reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"), strict=False, ), id="min_equals_max_eos_only", ), - # === EDGE CASES === - MinTokensTestCase(name="large_min_tokens", - min_tokens=50, - max_tokens=60, - stop=None, - expected_min_len=50), + MinTokensTestCase( + name="large_min_tokens", + min_tokens=50, + max_tokens=60, + stop=None, + expected_min_len=50, + ), MinTokensTestCase( name="min_tokens_with_empty_stop_list", min_tokens=5, max_tokens=15, stop=[], # Empty stop list - expected_min_len=5), + expected_min_len=5, + ), ] @@ -170,25 +180,27 @@ def get_token_count(output: RequestOutput) -> int: return len(output.outputs[0].token_ids) -def assert_min_tokens_satisfied(output: RequestOutput, - test_case: MinTokensTestCase) -> None: +def assert_min_tokens_satisfied( + output: RequestOutput, test_case: MinTokensTestCase +) -> None: """Assert that min_tokens requirement is satisfied""" token_count = 
get_token_count(output) - stop_reason = (output.outputs[0].stop_reason - if output.outputs else "no output") + stop_reason = output.outputs[0].stop_reason if output.outputs else "no output" if test_case.expected_exact_len is not None: # Exact length requirement assert token_count == test_case.expected_exact_len, ( f"Expected exactly {test_case.expected_exact_len} tokens, " f"got {token_count} tokens. " - f"Stop reason: {stop_reason}") + f"Stop reason: {stop_reason}" + ) else: # Minimum length requirement assert token_count >= (test_case.expected_min_len or 0), ( f"Expected at least {test_case.expected_min_len} tokens, " f"got {token_count} tokens. " - f"Stop reason: {stop_reason}") + f"Stop reason: {stop_reason}" + ) @pytest.mark.parametrize( @@ -199,13 +211,13 @@ def assert_min_tokens_satisfied(output: RequestOutput, def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase): """ Comprehensive test for min_tokens functionality in V1 engine. - + This test covers all critical scenarios for min_tokens: - Basic functionality (should work) - Stop strings with min_tokens (known bug) - EOS tokens with min_tokens (potential bug) - Edge cases - + Args: llm_v1: V1 LLM instance test_case: Test scenario parameters @@ -218,7 +230,7 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase): max_tokens=test_case.max_tokens, stop=test_case.stop, temperature=GREEDY, - include_stop_str_in_output=True # Include stop strings for debugging + include_stop_str_in_output=True, # Include stop strings for debugging ) # Use simple prompt. Comprehensive stop lists should catch any generation @@ -250,13 +262,11 @@ def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase): def test_min_tokens_basic_functionality(llm_v1: LLM): """ Test basic min_tokens functionality without stop conditions. - + This is a baseline test that should always pass and validates that min_tokens works correctly in the simple case. """ - sampling_params = SamplingParams(min_tokens=10, - max_tokens=20, - temperature=GREEDY) + sampling_params = SamplingParams(min_tokens=10, max_tokens=20, temperature=GREEDY) prompt = "Once upon a time" outputs = llm_v1.generate([prompt], sampling_params) @@ -269,17 +279,16 @@ def test_min_tokens_basic_functionality(llm_v1: LLM): @pytest.mark.xfail( - reason=("Known bug #21987: stop strings bypass min_tokens " - "(fixed by PR #22014)"), + reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"), strict=False, ) def test_min_tokens_stop_strings_bug(llm_v1: LLM): """ Test the specific bug where stop strings bypass min_tokens. - + This test specifically reproduces the bug Calvin is fixing in PR #22014. It should fail until that fix is merged. - + Strategy: Use guaranteed stop characters that will appear in any generated text. """ @@ -291,7 +300,8 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM): # Common letter; likely appears early stop=["e"], temperature=GREEDY, - include_stop_str_in_output=True) + include_stop_str_in_output=True, + ) # Simple prompt that will generate text containing "e" prompt = "The quick brown fox" @@ -308,23 +318,25 @@ def test_min_tokens_stop_strings_bug(llm_v1: LLM): # This assertion should fail due to the bug - if stop string is found early, # the model should still continue generating until min_tokens is reached - stop_reason = (outputs[0].outputs[0].stop_reason - if outputs[0].outputs else "no output") - assert token_count >= 15, ("Bug confirmed: " - f"{token_count} tokens < min_tokens=15. 
" - f"Reason: {stop_reason}. " - f"Text: {repr(generated_text)}") + stop_reason = ( + outputs[0].outputs[0].stop_reason if outputs[0].outputs else "no output" + ) + assert token_count >= 15, ( + "Bug confirmed: " + f"{token_count} tokens < min_tokens=15. " + f"Reason: {stop_reason}. " + f"Text: {repr(generated_text)}" + ) @pytest.mark.xfail( - reason=("Known bug #21987: stop strings bypass min_tokens " - "(fixed by PR #22014)"), + reason=("Known bug #21987: stop strings bypass min_tokens (fixed by PR #22014)"), strict=False, ) def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): """ Guaranteed test for stop strings bypassing min_tokens bug. - + Strategy: Use very low temperature and multiple common stop strings to virtually guarantee early detection, combined with long min_tokens to ensure the bug is exposed regardless of model behavior. @@ -337,7 +349,8 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): # Use multiple very common patterns - at least one will appear stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"], temperature=GREEDY, - include_stop_str_in_output=True) + include_stop_str_in_output=True, + ) # Simple prompt that will generate some text prompt = "The cat" @@ -346,8 +359,7 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): assert len(outputs) == 1 token_count = get_token_count(outputs[0]) generated_text = outputs[0].outputs[0].text if outputs[0].outputs else "" - stop_reason = (outputs[0].outputs[0].stop_reason - if outputs[0].outputs else "unknown") + stop_reason = outputs[0].outputs[0].stop_reason if outputs[0].outputs else "unknown" print(f"Generated text: {repr(generated_text)}") print(f"Token count: {token_count}") @@ -357,21 +369,23 @@ def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM): # will trigger early termination before min_tokens=50 is reached # It's virtually impossible to generate 50 tokens without hitting # at least one of: e, a, i, o, u, space, t, n, s, r - finish_reason = (outputs[0].outputs[0].finish_reason - if outputs[0].outputs else "unknown") + finish_reason = ( + outputs[0].outputs[0].finish_reason if outputs[0].outputs else "unknown" + ) print(f"Finish reason: {finish_reason}") if finish_reason == "stop": - assert token_count >= 50, ("Bug confirmed: " - f"{token_count} tokens < min_tokens=50. " - f"Reason: {finish_reason}. " - f"Text: {repr(generated_text)}") + assert token_count >= 50, ( + "Bug confirmed: " + f"{token_count} tokens < min_tokens=50. " + f"Reason: {finish_reason}. " + f"Text: {repr(generated_text)}" + ) @pytest.mark.xfail( - reason=( - "Potential logits-processor bug: EOS tokens may bypass min_tokens"), + reason=("Potential logits-processor bug: EOS tokens may bypass min_tokens"), strict=False, ) def test_min_tokens_eos_behavior(llm_v1: LLM): @@ -404,8 +418,14 @@ def test_min_tokens_eos_behavior(llm_v1: LLM): finish_no_min = choice_no_min.finish_reason stop_no_min = choice_no_min.stop_reason - print("[no-min] tokens=", len(ids_no_min), " finish=", finish_no_min, - " stop_reason=", stop_no_min) + print( + "[no-min] tokens=", + len(ids_no_min), + " finish=", + finish_no_min, + " stop_reason=", + stop_no_min, + ) assert finish_no_min == "stop", ( f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}" @@ -414,7 +434,8 @@ def test_min_tokens_eos_behavior(llm_v1: LLM): "For EOS-based stop (no user stop strings), stop_reason should be None." 
) assert len(ids_no_min) < max_toks, ( - f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}") + f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}" + ) # Case 2: WITH min_tokens sp_with_min = SamplingParams( @@ -430,23 +451,31 @@ def test_min_tokens_eos_behavior(llm_v1: LLM): finish_with_min = choice_with_min.finish_reason stop_with_min = choice_with_min.stop_reason - print("[with-min] tokens=", len(ids_with_min), " finish=", finish_with_min, - " stop_reason=", stop_with_min) + print( + "[with-min] tokens=", + len(ids_with_min), + " finish=", + finish_with_min, + " stop_reason=", + stop_with_min, + ) # Exact length reached; EOS should have been blocked assert len(ids_with_min) == max_toks, ( - f"Expected exactly {max_toks} tokens with min_tokens; " - f"got {len(ids_with_min)}") + f"Expected exactly {max_toks} tokens with min_tokens; got {len(ids_with_min)}" + ) assert finish_with_min == "length", ( - f"Expected finish_reason 'length'; got {finish_with_min}") + f"Expected finish_reason 'length'; got {finish_with_min}" + ) assert eos_token_id not in ids_with_min, ( - "EOS token id should not appear when min_tokens prevents early EOS.") + "EOS token id should not appear when min_tokens prevents early EOS." + ) def test_min_tokens_validation(): """ Test that SamplingParams correctly validates min_tokens parameters. - + This tests the parameter validation logic in SamplingParams. """ # Valid cases @@ -456,14 +485,14 @@ def test_min_tokens_validation(): # Invalid cases with pytest.raises( - ValueError, - match="min_tokens must be greater than or equal to 0", + ValueError, + match="min_tokens must be greater than or equal to 0", ): SamplingParams(min_tokens=-1, max_tokens=10) with pytest.raises( - ValueError, - match="min_tokens must be less than or equal to max_tokens", + ValueError, + match="min_tokens must be less than or equal to max_tokens", ): SamplingParams(min_tokens=15, max_tokens=10) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 8f048775352e..9ed9cd7950a9 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -48,19 +48,17 @@ def get_test_prompts(mm_enabled: bool): give no other output than that simple sentence without quotes. """ elif kind == "mm": - placeholders = [{ - "type": "image_url", - "image_url": { - "url": - f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" - }, - }] + placeholders = [ + { + "type": "image_url", + "image_url": { + "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" + }, + } + ] prompt = [ *placeholders, - { - "type": "text", - "text": "The meaning of the image is" - }, + {"type": "text", "text": "The meaning of the image is"}, ] else: raise ValueError(f"Unknown prompt type: {kind}") @@ -84,10 +82,10 @@ def test_ngram_correctness( sampling_config: SamplingParams, model_name: str, ): - ''' + """ Compare the outputs of an original LLM and a speculative LLM should be the same when using ngram speculative decoding. 
- ''' + """ test_prompts = get_test_prompts(mm_enabled=False) ref_llm = LLM(model=model_name, max_model_len=1024) @@ -129,32 +127,77 @@ def test_ngram_correctness( ["model_setup", "mm_enabled"], [ (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), - pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct", - "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), - False, - marks=pytest.mark.skip(reason="Skipping due to its " \ - "head_dim not being a a multiple of 32")), - (("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), - (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), - pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - False, - marks=large_gpu_mark(min_gb=80)), # works on 4x H100 - pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - True, - marks=large_gpu_mark(min_gb=80)), # works on 4x H100 - (("eagle", "eagle618/deepseek-v3-random", - "eagle618/eagle-deepseek-v3-random", 1), False), + pytest.param( + ( + "eagle3", + "Qwen/Qwen2.5-VL-7B-Instruct", + "Rayzl/qwen2.5-vl-7b-eagle3-sgl", + 1, + ), + False, + marks=pytest.mark.skip( + reason="Skipping due to its head_dim not being a a multiple of 32" + ), + ), + ( + ( + "eagle", + "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", + 1, + ), + False, + ), + ( + ( + "eagle3", + "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + 1, + ), + False, + ), + pytest.param( + ( + "eagle", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + 4, + ), + False, + marks=large_gpu_mark(min_gb=80), + ), # works on 4x H100 + pytest.param( + ( + "eagle", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + 4, + ), + True, + marks=large_gpu_mark(min_gb=80), + ), # works on 4x H100 + ( + ( + "eagle", + "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", + 1, + ), + False, + ), ], ids=[ - "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3", - "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle" - ]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) + "qwen3_eagle3", + "qwen2_5_vl_eagle3", + "llama3_eagle", + "llama3_eagle3", + "llama4_eagle", + "llama4_eagle_mm", + "deepseek_eagle", + ], +) +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, @@ -166,15 +209,16 @@ def test_eagle_correctness( # TODO: Fix this flaky test pytest.skip( "TREE_ATTN is flaky in the test disable for now until it can be " - "resolved (see https://github.com/vllm-project/vllm/issues/22922)") + "resolved (see https://github.com/vllm-project/vllm/issues/22922)" + ) # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) - ''' + """ Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. 
model_setup: (method, model_name, eagle_model_name, tp_size) - ''' + """ with monkeypatch.context() as m: if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN": # Scout requires default backend selection @@ -185,18 +229,20 @@ def test_eagle_correctness( m.setenv("VLLM_MLA_DISABLE", "1") m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): m.setenv("VLLM_ROCM_USE_AITER", "1") method, model_name, spec_model_name, tp_size = model_setup - ref_llm = LLM(model=model_name, - max_model_len=2048, - tensor_parallel_size=tp_size) + ref_llm = LLM( + model=model_name, max_model_len=2048, tensor_parallel_size=tp_size + ) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm torch.cuda.empty_cache() @@ -233,11 +279,14 @@ def test_eagle_correctness( cleanup_dist_env_and_memory() -@pytest.mark.parametrize(["model_setup", "mm_enabled"], [ - (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False), - (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False), -], - ids=["mimo", "deepseek"]) +@pytest.mark.parametrize( + ["model_setup", "mm_enabled"], + [ + (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False), + (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False), + ], + ids=["mimo", "deepseek"], +) def test_mtp_correctness( monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, @@ -246,21 +295,23 @@ def test_mtp_correctness( ): # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) - ''' + """ Compare the outputs of a original LLM and a speculative LLM should be the same when using MTP speculative decoding. 
model_setup: (method, model_name, tp_size) - ''' + """ with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_MLA_DISABLE", "1") method, model_name, tp_size = model_setup - ref_llm = LLM(model=model_name, - max_model_len=2048, - tensor_parallel_size=tp_size, - trust_remote_code=True) + ref_llm = LLM( + model=model_name, + max_model_len=2048, + tensor_parallel_size=tp_size, + trust_remote_code=True, + ) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm torch.cuda.empty_cache() diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py index a73a9a6999f7..c5c5d35b83c3 100644 --- a/tests/v1/engine/conftest.py +++ b/tests/v1/engine/conftest.py @@ -5,12 +5,15 @@ import torch from transformers import AutoTokenizer -from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, - NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN, - TOKENIZER_NAME, - DummyOutputProcessorTestVectors, - generate_dummy_prompt_logprobs_tensors, - generate_dummy_sample_logprobs) +from tests.v1.engine.utils import ( + NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, + PROMPT_LEN, + TOKENIZER_NAME, + DummyOutputProcessorTestVectors, + generate_dummy_prompt_logprobs_tensors, + generate_dummy_sample_logprobs, +) from vllm.engine.arg_utils import EngineArgs from ...distributed.conftest import publisher_config, random_port # noqa: F401 @@ -31,9 +34,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config() # Tokenize prompts under test & create dummy generated tokens - prompt_tokens = [ - tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS - ] + prompt_tokens = [tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS] generation_tokens = [ tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS ] @@ -42,9 +43,7 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: tokenizer.decode(prompt_tokens, skip_special_tokens=True) for prompt_tokens in prompt_tokens ] - prompt_strings_len = [ - len(prompt_string) for prompt_string in prompt_strings - ] + prompt_strings_len = [len(prompt_string) for prompt_string in prompt_strings] return DummyOutputProcessorTestVectors( tokenizer=tokenizer, vllm_config=vllm_config, @@ -58,7 +57,8 @@ def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors: for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len) ], prompt_logprobs=[], - generation_logprobs=[]) + generation_logprobs=[], + ) @pytest.fixture @@ -76,12 +76,16 @@ def dummy_test_vectors() -> DummyOutputProcessorTestVectors: generate_dummy_sample_logprobs( sampled_tokens_list=tokens_list, num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST, - tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens + tokenizer=dtv.tokenizer, + ) + for tokens_list in dtv.generation_tokens ] dtv.prompt_logprobs = [ generate_dummy_prompt_logprobs_tensors( prompt_tokens_list=tokens_list, num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST, - tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens + tokenizer=dtv.tokenizer, + ) + for tokens_list in dtv.prompt_tokens ] return dtv diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index aca546600d0b..3e30d28111c8 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -21,16 +21,16 @@ from vllm.v1.metrics.loggers import LoggingStatLogger if not 
current_platform.is_cuda(): - pytest.skip(reason="V1 currently only supported on CUDA.", - allow_module_level=True) + pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) TEXT_ENGINE_ARGS = AsyncEngineArgs( model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, ) -VISION_ENGINE_ARGS = AsyncEngineArgs(model="Qwen/Qwen2-VL-2B-Instruct", - enforce_eager=True) +VISION_ENGINE_ARGS = AsyncEngineArgs( + model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True +) TEXT_PROMPT = "Hello my name is Robert and" @@ -38,12 +38,11 @@ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" "What is in the image?<|im_end|>\n" - "<|im_start|>assistant\n") + "<|im_start|>assistant\n" +) VISION_PROMPT = { "prompt": VISION_PROMPT_TEMPLATE, - "multi_modal_data": { - "image": ImageAsset("stop_sign").pil_image - }, + "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image}, } @@ -70,10 +69,9 @@ async def generate( n=n, prompt_logprobs=prompt_logprobs, ) - async for out in engine.generate(request_id=request_id, - prompt=prompt, - sampling_params=sampling_params): - + async for out in engine.generate( + request_id=request_id, prompt=prompt, sampling_params=sampling_params + ): num_tokens = sum(len(output.token_ids) for output in out.outputs) if output_kind == RequestOutputKind.DELTA: count += num_tokens @@ -89,7 +87,8 @@ async def generate( @pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) @pytest.mark.parametrize( "engine_args,prompt", [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)], @@ -121,25 +120,29 @@ async def test_load( for request_id in request_ids: tasks.append( asyncio.create_task( - generate(engine, request_id, prompt, output_kind, - NUM_EXPECTED_TOKENS))) + generate( + engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS + ) + ) + ) # Confirm that we got all the EXPECTED tokens from the requests. - done, pending = await asyncio.wait(tasks, - return_when=asyncio.FIRST_EXCEPTION) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) for task in pending: task.cancel() for task in done: num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( f"{request_id} generated {num_generated_tokens} but " - f"expected {NUM_EXPECTED_TOKENS}") + f"expected {NUM_EXPECTED_TOKENS}" + ) assert not engine.output_processor.has_unfinished_requests() @pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) @pytest.mark.parametrize( "engine_args,prompt", [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)], @@ -151,7 +154,6 @@ async def test_abort( engine_args: AsyncEngineArgs, prompt: PromptType, ): - with monkeypatch.context() as m, ExitStack() as after: m.setenv("VLLM_USE_V1", "1") @@ -170,14 +172,17 @@ async def test_abort( # Create concurrent requests. 
tasks: list[asyncio.Task] = [] for idx, request_id in enumerate(request_ids): - max_tokens = (NUM_EXPECTED_TOKENS_LONG if - (idx - in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS) + max_tokens = ( + NUM_EXPECTED_TOKENS_LONG + if (idx in REQUEST_IDS_TO_ABORT) + else NUM_EXPECTED_TOKENS + ) n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 tasks.append( asyncio.create_task( - generate(engine, request_id, prompt, output_kind, - max_tokens, n))) + generate(engine, request_id, prompt, output_kind, max_tokens, n) + ) + ) # API server cancels requests when they disconnect. for idx in REQUEST_IDS_TO_ABORT: @@ -197,7 +202,8 @@ async def test_abort( expected_tokens = NUM_EXPECTED_TOKENS * n assert num_generated_tokens == expected_tokens, ( f"{request_id} generated {num_generated_tokens} but " - f"expected {expected_tokens}") + f"expected {expected_tokens}" + ) # Make sure all aborted requests were really aborted. assert not engine.output_processor.has_unfinished_requests() @@ -205,21 +211,21 @@ async def test_abort( # Confirm we can do another generation. request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}" task = asyncio.create_task( - generate(engine, request_id, prompt, output_kind, - NUM_EXPECTED_TOKENS)) + generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS) + ) num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS assert not engine.output_processor.has_unfinished_requests() @pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) @pytest.mark.asyncio async def test_multi_abort( monkeypatch: pytest.MonkeyPatch, output_kind: RequestOutputKind, ): - with monkeypatch.context() as m, ExitStack() as after: m.setenv("VLLM_USE_V1", "1") @@ -238,14 +244,19 @@ async def test_multi_abort( # Create concurrent requests. 
tasks: list[asyncio.Task] = [] for idx, request_id in enumerate(request_ids): - max_tokens = (NUM_EXPECTED_TOKENS_LONG if - (idx - in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS) + max_tokens = ( + NUM_EXPECTED_TOKENS_LONG + if (idx in REQUEST_IDS_TO_ABORT) + else NUM_EXPECTED_TOKENS + ) n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 tasks.append( asyncio.create_task( - generate(engine, request_id, TEXT_PROMPT, output_kind, - max_tokens, n))) + generate( + engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n + ) + ) + ) # Let requests start await asyncio.sleep(0.5) @@ -261,25 +272,26 @@ async def test_multi_abort( for idx, result in enumerate(results): if idx in REQUEST_IDS_TO_ABORT: # Aborted requests should return partial results - assert isinstance( - result, tuple - ), f"Request {idx} should have completed with partial results" + assert isinstance(result, tuple), ( + f"Request {idx} should have completed with partial results" + ) num_generated_tokens, request_id = result # Should have generated some tokens before abort assert num_generated_tokens > 0, ( - f"Aborted request " - f"{request_id} should have generated some tokens") + f"Aborted request {request_id} should have generated some tokens" + ) else: # Non-aborted requests should complete normally - assert isinstance( - result, - tuple), f"Request {idx} should have completed successfully" + assert isinstance(result, tuple), ( + f"Request {idx} should have completed successfully" + ) num_generated_tokens, request_id = result n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1 expected_tokens = NUM_EXPECTED_TOKENS * n assert num_generated_tokens == expected_tokens, ( f"{request_id} generated {num_generated_tokens} but " - f"expected {expected_tokens}") + f"expected {expected_tokens}" + ) # Make sure all aborted requests were cleaned up assert not engine.output_processor.has_unfinished_requests() @@ -297,7 +309,6 @@ async def test_finished_flag( engine_args: AsyncEngineArgs, prompt: PromptType, ): - with monkeypatch.context() as m, ExitStack() as after: m.setenv("VLLM_USE_V1", "1") @@ -314,9 +325,9 @@ async def test_finished_flag( ) outputs = [ out - async for out in engine.generate(request_id="request-33", - prompt=prompt, - sampling_params=sampling_params) + async for out in engine.generate( + request_id="request-33", prompt=prompt, sampling_params=sampling_params + ) ] # Assert only the last output has the finished flag set @@ -329,9 +340,9 @@ async def test_finished_flag( [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)], ) @pytest.mark.asyncio -async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch, - engine_args: AsyncEngineArgs, - prompt: PromptType): +async def test_mid_stream_cancellation( + monkeypatch: pytest.MonkeyPatch, engine_args: AsyncEngineArgs, prompt: PromptType +): """Test that requests can be cancelled mid-stream.""" with monkeypatch.context() as m, ExitStack() as after: m.setenv("VLLM_USE_V1", "1") @@ -358,7 +369,9 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch, RequestOutputKind.DELTA, NUM_TOKENS, cancel_after=NUM_EXPECTED_TOKENS, - ))) + ) + ) + ) # Wait for all tasks to complete results = await asyncio.gather(*tasks) @@ -367,7 +380,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch, for num_generated_tokens, request_id in results: assert num_generated_tokens == NUM_EXPECTED_TOKENS, ( f"{request_id} generated {num_generated_tokens} tokens but " - f"expected to cancel after {NUM_EXPECTED_TOKENS}") + f"expected to cancel 
after {NUM_EXPECTED_TOKENS}" + ) # Make sure no requests are left hanging assert not engine.output_processor.has_unfinished_requests() @@ -375,15 +389,16 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch, # Confirm we can reuse the request id after the cancellations. request_id = request_ids[0] task = asyncio.create_task( - generate(engine, request_id, prompt, RequestOutputKind.DELTA, - NUM_EXPECTED_TOKENS)) + generate( + engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS + ) + ) num_generated_tokens, request_id = await task assert num_generated_tokens == NUM_EXPECTED_TOKENS assert not engine.output_processor.has_unfinished_requests() class MockLoggingStatLogger(LoggingStatLogger): - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): super().__init__(vllm_config, engine_index) self.log = MagicMock() @@ -410,8 +425,7 @@ async def test_customize_loggers(monkeypatch): stat_loggers = engine.logger_manager.per_engine_logger_dict assert len(stat_loggers) == 1 - assert len( - stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger + assert len(stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger stat_loggers[0][0].log.assert_called_once() @@ -424,24 +438,30 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch): engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) after.callback(engine.shutdown) - sampling_params = SamplingParams(max_tokens=100, - output_kind=RequestOutputKind.DELTA, - temperature=1.0, - seed=33) + sampling_params = SamplingParams( + max_tokens=100, + output_kind=RequestOutputKind.DELTA, + temperature=1.0, + seed=33, + ) # Test with valid DP rank. - async for _ in engine.generate(request_id="request-34", - prompt=TEXT_PROMPT, - sampling_params=sampling_params, - data_parallel_rank=0): + async for _ in engine.generate( + request_id="request-34", + prompt=TEXT_PROMPT, + sampling_params=sampling_params, + data_parallel_rank=0, + ): pass # Test with out-of-range DP rank. 
with pytest.raises(ValueError): - async for _ in engine.generate(request_id="request-35", - prompt=TEXT_PROMPT, - sampling_params=sampling_params, - data_parallel_rank=1): + async for _ in engine.generate( + request_id="request-35", + prompt=TEXT_PROMPT, + sampling_params=sampling_params, + data_parallel_rank=1, + ): pass @@ -465,10 +485,14 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch): await engine.check_health() # Test 2: Mock the errored property to simulate a dead engine - with patch.object(type(engine), - 'errored', - new_callable=lambda: property(lambda self: True) - ), pytest.raises(EngineDeadError): + with ( + patch.object( + type(engine), + "errored", + new_callable=lambda: property(lambda self: True), + ), + pytest.raises(EngineDeadError), + ): await engine.check_health() # Test 3: Verify healthy engine still works after mock @@ -476,7 +500,8 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize( - "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) + "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) @pytest.mark.asyncio async def test_abort_final_output( monkeypatch: pytest.MonkeyPatch, @@ -504,8 +529,8 @@ async def test_abort_final_output( outputs: list[RequestOutput] = [] generated = asyncio.create_task( - collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, - outputs)) + collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs) + ) # Let it generate some tokens await asyncio.sleep(0.5) @@ -525,14 +550,13 @@ async def test_abort_final_output( assert final_output.outputs[0].stop_reason is None # Verify num_cached_tokens is set correctly - assert hasattr(final_output, 'num_cached_tokens') + assert hasattr(final_output, "num_cached_tokens") assert final_output.num_cached_tokens >= 0 # If we got intermediate outputs, verify they are consistent if output_kind == RequestOutputKind.DELTA: # For DELTA, sum all intermediate tokens should <= final tokens - token_count = sum( - len(output.outputs[0].token_ids) for output in outputs) + token_count = sum(len(output.outputs[0].token_ids) for output in outputs) assert token_count > 0 # This would ordinarily be 0, but could end up > 0 if the # final abort is coalesced with another chunk in the output queue. @@ -554,9 +578,9 @@ async def collect_outputs( ) -> Optional[RequestOutput]: """Helper to collect outputs and return the final one.""" final_output: Optional[RequestOutput] = None - async for output in engine.generate(request_id=request_id, - prompt=prompt, - sampling_params=sampling_params): + async for output in engine.generate( + request_id=request_id, prompt=prompt, sampling_params=sampling_params + ): if not output.finished: outputs_list.append(output) final_output = output diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index 23ec3673b10b..f6b10fa67b3b 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -22,8 +22,9 @@ def test_prefix_caching_from_cli(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) args = parser.parse_args([]) vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() - assert (vllm_config.cache_config.enable_prefix_caching - ), "V1 turns on prefix caching by default." + assert vllm_config.cache_config.enable_prefix_caching, ( + "V1 turns on prefix caching by default." + ) # Turn it off possible with flag. 
args = parser.parse_args(["--no-enable-prefix-caching"]) @@ -41,8 +42,7 @@ def test_prefix_caching_from_cli(): # set hash algorithm to sha256_cbor args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"]) vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() - assert vllm_config.cache_config.prefix_caching_hash_algo == \ - "sha256_cbor" + assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256_cbor" # set hash algorithm to sha256 args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"]) @@ -57,10 +57,10 @@ def test_prefix_caching_from_cli(): def test_defaults_with_usage_context(): engine_args = EngineArgs(model="facebook/opt-125m") - vllm_config: VllmConfig = engine_args.create_engine_config( - UsageContext.LLM_CLASS) + vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS) from vllm.platforms import current_platform + device_name = current_platform.get_device_name().lower() if "h100" in device_name or "h200" in device_name: # For H100 and H200, we use larger default values. @@ -76,7 +76,6 @@ def test_defaults_with_usage_context(): assert vllm_config.scheduler_config.max_num_batched_tokens == default_llm_tokens # noqa: E501 engine_args = EngineArgs(model="facebook/opt-125m") - vllm_config = engine_args.create_engine_config( - UsageContext.OPENAI_API_SERVER) + vllm_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER) assert vllm_config.scheduler_config.max_num_seqs == default_max_num_seqs assert vllm_config.scheduler_config.max_num_batched_tokens == default_server_tokens # noqa: E501 diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 17b136aa4273..28d7854ab5d2 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -22,8 +22,7 @@ from ...utils import create_new_process_for_each_test, multi_gpu_test if not current_platform.is_cuda(): - pytest.skip(reason="V1 currently only supported on CUDA.", - allow_module_level=True) + pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME) @@ -48,7 +47,6 @@ def make_request() -> EngineCoreRequest: @create_new_process_for_each_test() def test_engine_core(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") """Setup the EngineCore.""" @@ -57,14 +55,13 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=True) + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=True + ) """Test basic request lifecycle.""" # First request. - engine_core.add_request( - *engine_core.preprocess_add_request(make_request())) + engine_core.add_request(*engine_core.preprocess_add_request(make_request())) assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.running) == 0 @@ -73,8 +70,7 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): assert len(engine_core.scheduler.running) == 1 # Second request. 
- engine_core.add_request( - *engine_core.preprocess_add_request(make_request())) + engine_core.add_request(*engine_core.preprocess_add_request(make_request())) assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.running) == 1 @@ -83,10 +79,8 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): assert len(engine_core.scheduler.running) == 2 # Add two requests in a row. - engine_core.add_request( - *engine_core.preprocess_add_request(make_request())) - engine_core.add_request( - *engine_core.preprocess_add_request(make_request())) + engine_core.add_request(*engine_core.preprocess_add_request(make_request())) + engine_core.add_request(*engine_core.preprocess_add_request(make_request())) assert len(engine_core.scheduler.waiting) == 2 assert len(engine_core.scheduler.running) == 2 @@ -196,9 +190,9 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=True) + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=True + ) """Test basic request lifecycle.""" # First request. request: EngineCoreRequest = make_request() @@ -238,17 +232,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): Test that the engine can handle multiple concurrent batches. """ - def make_request_with_max_tokens(req_id: str, - max_tokens: int) -> EngineCoreRequest: + def make_request_with_max_tokens(req_id: str, max_tokens: int) -> EngineCoreRequest: request = make_request() request.request_id = req_id request.sampling_params.max_tokens = max_tokens return request class DummyExecutor(UniProcExecutor): - - def initialize_from_config( - self, kv_cache_configs: list[KVCacheConfig]) -> None: + def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: super().initialize_from_config(kv_cache_configs) # Create a thread pool with a single worker @@ -265,8 +256,7 @@ def execute_model( assert non_block def _execute(): - output = self.collective_rpc("execute_model", - args=(scheduler_output, )) + output = self.collective_rpc("execute_model", args=(scheduler_output,)) # Make a copy because output[0] may be reused # by the next batch. return copy.deepcopy(output[0]) @@ -279,7 +269,7 @@ def max_concurrent_batches(self) -> int: return 2 def shutdown(self): - if hasattr(self, 'thread_pool'): + if hasattr(self, "thread_pool"): self.thread_pool.shutdown(wait=False) with monkeypatch.context() as m: @@ -297,9 +287,9 @@ def shutdown(self): ) vllm_config = engine_args.create_engine_config() with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - log_stats=False, - executor_class=DummyExecutor) + engine_core = EngineCore( + vllm_config=vllm_config, log_stats=False, executor_class=DummyExecutor + ) assert engine_core.batch_queue is not None # Add two requests in a row. Each request have 12 prompt tokens. @@ -314,8 +304,7 @@ def shutdown(self): scheduler_output = engine_core.batch_queue[-1][1] assert scheduler_output.num_scheduled_tokens["0"] == 10 # num_computed_tokens should have been updated immediately. 
- assert engine_core.scheduler.requests[ - req0.request_id].num_computed_tokens == 10 + assert engine_core.scheduler.requests[req0.request_id].num_computed_tokens == 10 # Schedule Batch 2: (2, req0), (8, req1) assert engine_core.step_with_batch_queue()[0] == {} @@ -366,8 +355,10 @@ def shutdown(self): assert output is not None assert len(output[0].outputs) == 1 if req_id in engine_core.scheduler.requests: - assert engine_core.scheduler.requests[ - req_id].num_tokens == expected_num_tokens[req_id] + assert ( + engine_core.scheduler.requests[req_id].num_tokens + == expected_num_tokens[req_id] + ) expected_num_tokens[req_id] += 1 req_id = (req_id + 1) % 2 @@ -391,17 +382,19 @@ def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch): executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=True) + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=True + ) def get_worker_cache_config_field(worker, key: str): return getattr(worker.cache_config, key) num_gpu_blocks = engine_core.collective_rpc( - get_worker_cache_config_field, args=("num_gpu_blocks", )) + get_worker_cache_config_field, args=("num_gpu_blocks",) + ) num_cpu_blocks = engine_core.collective_rpc( - get_worker_cache_config_field, args=("num_cpu_blocks", )) + get_worker_cache_config_field, args=("num_cpu_blocks",) + ) assert all(x is not None for x in num_gpu_blocks) assert all(x is not None for x in num_cpu_blocks) @@ -417,40 +410,35 @@ def test_engine_core_invalid_request_id_type(monkeypatch: pytest.MonkeyPatch): executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): - engine_core = EngineCore(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=True) + engine_core = EngineCore( + vllm_config=vllm_config, executor_class=executor_class, log_stats=True + ) # Test with UUID object (common mistake) uuid_request = make_request() uuid_request.request_id = uuid.uuid4() # UUID object instead of string - with pytest.raises(TypeError, - match="request_id must be a string, got.*UUID"): - engine_core.add_request( - *engine_core.preprocess_add_request(uuid_request)) + with pytest.raises(TypeError, match="request_id must be a string, got.*UUID"): + engine_core.add_request(*engine_core.preprocess_add_request(uuid_request)) # Test with integer int_request = make_request() int_request.request_id = 12345 - with pytest.raises(TypeError, - match="request_id must be a string, got.*int"): - engine_core.add_request( - *engine_core.preprocess_add_request(int_request)) + with pytest.raises(TypeError, match="request_id must be a string, got.*int"): + engine_core.add_request(*engine_core.preprocess_add_request(int_request)) # Test with None none_request = make_request() none_request.request_id = None - with pytest.raises(TypeError, - match="request_id must be a string, got.*NoneType"): - engine_core.add_request( - *engine_core.preprocess_add_request(none_request)) + with pytest.raises( + TypeError, match="request_id must be a string, got.*NoneType" + ): + engine_core.add_request(*engine_core.preprocess_add_request(none_request)) # Verify engine is still functional after errors valid_request = make_request() - engine_core.add_request( - *engine_core.preprocess_add_request(valid_request)) + engine_core.add_request(*engine_core.preprocess_add_request(valid_request)) assert len(engine_core.scheduler.waiting) == 1 assert 
len(engine_core.scheduler.running) == 0 diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 10adac9bab5f..90284fc54d06 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -17,16 +17,14 @@ from tests.utils import multi_gpu_test from vllm import SamplingParams -from vllm.distributed.kv_events import (BlockStored, KVEventBatch, - ZmqEventPublisher) +from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublisher from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext from vllm.utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore -from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, - SyncMPClient) +from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient from vllm.v1.engine.utils import CoreEngineProcManager from vllm.v1.executor.abstract import Executor @@ -34,8 +32,7 @@ from ...utils import create_new_process_for_each_test if not current_platform.is_cuda(): - pytest.skip(reason="V1 currently only supported on CUDA.", - allow_module_level=True) + pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True) MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME) @@ -44,8 +41,8 @@ def make_request( - params: SamplingParams, - prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest: + params: SamplingParams, prompt_tokens_ids: Optional[list[int]] = None +) -> EngineCoreRequest: if not prompt_tokens_ids: prompt_tokens_ids = PROMPT_TOKENS @@ -64,7 +61,6 @@ def make_request( def loop_until_done(client: EngineCoreClient, outputs: dict): - while True: engine_core_outputs = client.get_output().outputs @@ -82,7 +78,6 @@ def loop_until_done(client: EngineCoreClient, outputs: dict): async def loop_until_done_async(client: EngineCoreClient, outputs: dict): - while True: engine_core_outputs = (await client.get_output_async()).outputs @@ -100,7 +95,6 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict): async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict): - while True: engine_core_outputs = (await client.get_output_async()).outputs @@ -119,10 +113,9 @@ async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict): # Dummy utility function to monkey-patch into engine core. 
-def echo(self, - msg: str, - err_msg: Optional[str] = None, - sleep: Optional[float] = None) -> str: +def echo( + self, msg: str, err_msg: Optional[str] = None, sleep: Optional[float] = None +) -> str: print(f"echo util function called: {msg}, {err_msg}") if sleep is not None: time.sleep(sleep) @@ -133,9 +126,9 @@ def echo(self, @create_new_process_for_each_test() @pytest.mark.parametrize("multiprocessing_mode", [True, False]) -def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, - multiprocessing_mode: bool): - +def test_engine_core_client( + monkeypatch: pytest.MonkeyPatch, multiprocessing_mode: bool +): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -143,8 +136,7 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, m.setattr(EngineCore, "echo", echo, raising=False) engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) - vllm_config = engine_args.create_engine_config( - UsageContext.UNKNOWN_CONTEXT) + vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -172,7 +164,8 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, for req_id in request_ids: assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{outputs[req_id]=}, {MAX_TOKENS=}") + f"{outputs[req_id]=}, {MAX_TOKENS=}" + ) """Abort Request Cycle.""" # Note: this code pathway will only work for multiprocessing @@ -191,10 +184,12 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, for idx, req_id in enumerate(request_ids): if idx % 2 == 0: assert len(outputs[req_id]) < MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + f"{len(outputs[req_id])=}, {MAX_TOKENS=}" + ) else: assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + f"{len(outputs[req_id])=}, {MAX_TOKENS=}" + ) """Abort after request is finished.""" # Note: this code pathway will only work for multiprocessing @@ -202,7 +197,7 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, request = requests[0] client.add_request(request) - time.sleep(10.) + time.sleep(10.0) client.abort_requests([request.request_id]) @@ -222,7 +217,6 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, @pytest.mark.asyncio(loop_scope="function") async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -231,7 +225,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + usage_context=UsageContext.UNKNOWN_CONTEXT + ) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -261,7 +256,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): for req_id in request_ids: assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{outputs[req_id]=}, {MAX_TOKENS=}") + f"{outputs[req_id]=}, {MAX_TOKENS=}" + ) """Abort Request Cycle.""" # Add requests to the engine. 
@@ -277,10 +273,12 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): for idx, req_id in enumerate(request_ids): if idx % 2 == 0: assert len(outputs[req_id]) < MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + f"{len(outputs[req_id])=}, {MAX_TOKENS=}" + ) else: assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + f"{len(outputs[req_id])=}, {MAX_TOKENS=}" + ) """Utility method invocation""" core_client: AsyncMPClient = client @@ -296,8 +294,8 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): # Test that cancelling the utility call doesn't destabilize the # engine. util_task = asyncio.create_task( - core_client.call_utility_async("echo", "testarg2", None, - 0.5)) # sleep for 0.5 sec + core_client.call_utility_async("echo", "testarg2", None, 0.5) + ) # sleep for 0.5 sec await asyncio.sleep(0.05) cancelled = util_task.cancel() assert cancelled @@ -305,9 +303,9 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): # Ensure client is still functional. The engine runs utility # methods in a single thread so this request won't be processed # until the cancelled sleeping one is complete. - result = await asyncio.wait_for(core_client.call_utility_async( - "echo", "testarg3"), - timeout=1.0) + result = await asyncio.wait_for( + core_client.call_utility_async("echo", "testarg3"), timeout=1.0 + ) assert result == "testarg3" finally: client.shutdown() @@ -353,8 +351,7 @@ def echo_dc_nested( msg: str, structure_type: str = "list_of_dicts", ) -> Any: - print(f"echo dc nested util function called: {msg}, " - f"structure: {structure_type}") + print(f"echo dc nested util function called: {msg}, structure: {structure_type}") val = None if msg is None else MyDataclass(msg) if structure_type == "list_of_dicts": # noqa @@ -373,8 +370,8 @@ def echo_dc_nested( @pytest.mark.asyncio(loop_scope="function") async def test_engine_core_client_util_method_custom_return( - monkeypatch: pytest.MonkeyPatch): - + monkeypatch: pytest.MonkeyPatch, +): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -386,7 +383,8 @@ async def test_engine_core_client_util_method_custom_return( engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + usage_context=UsageContext.UNKNOWN_CONTEXT + ) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -402,22 +400,17 @@ async def test_engine_core_client_util_method_custom_return( # Test utility method returning custom / non-native data type. 
core_client: AsyncMPClient = client - result = await core_client.call_utility_async( - "echo_dc", "testarg2", False) - assert isinstance(result, - MyDataclass) and result.message == "testarg2" - result = await core_client.call_utility_async( - "echo_dc", "testarg2", True) + result = await core_client.call_utility_async("echo_dc", "testarg2", False) + assert isinstance(result, MyDataclass) and result.message == "testarg2" + result = await core_client.call_utility_async("echo_dc", "testarg2", True) assert isinstance(result, list) and all( - isinstance(r, MyDataclass) and r.message == "testarg2" - for r in result) + isinstance(r, MyDataclass) and r.message == "testarg2" for r in result + ) # Test returning None and list of Nones - result = await core_client.call_utility_async( - "echo_dc", None, False) + result = await core_client.call_utility_async("echo_dc", None, False) assert result is None - result = await core_client.call_utility_async( - "echo_dc", None, True) + result = await core_client.call_utility_async("echo_dc", None, True) assert isinstance(result, list) and all(r is None for r in result) finally: @@ -426,8 +419,8 @@ async def test_engine_core_client_util_method_custom_return( @pytest.mark.asyncio(loop_scope="function") async def test_engine_core_client_util_method_custom_dict_return( - monkeypatch: pytest.MonkeyPatch): - + monkeypatch: pytest.MonkeyPatch, +): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -439,7 +432,8 @@ async def test_engine_core_client_util_method_custom_dict_return( engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + usage_context=UsageContext.UNKNOWN_CONTEXT + ) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -457,22 +451,21 @@ async def test_engine_core_client_util_method_custom_dict_return( # Test single object return result = await core_client.call_utility_async( - "echo_dc_dict", "testarg3", False) - assert isinstance(result, - MyDataclass) and result.message == "testarg3" + "echo_dc_dict", "testarg3", False + ) + assert isinstance(result, MyDataclass) and result.message == "testarg3" # Test dict return with custom value types result = await core_client.call_utility_async( - "echo_dc_dict", "testarg3", True) + "echo_dc_dict", "testarg3", True + ) assert isinstance(result, dict) and len(result) == 3 for key, val in result.items(): assert key in ["key1", "key2", "key3"] - assert isinstance(val, - MyDataclass) and val.message == "testarg3" + assert isinstance(val, MyDataclass) and val.message == "testarg3" # Test returning dict with None values - result = await core_client.call_utility_async( - "echo_dc_dict", None, True) + result = await core_client.call_utility_async("echo_dc_dict", None, True) assert isinstance(result, dict) and len(result) == 3 for key, val in result.items(): assert key in ["key1", "key2", "key3"] @@ -484,8 +477,8 @@ async def test_engine_core_client_util_method_custom_dict_return( @pytest.mark.asyncio(loop_scope="function") async def test_engine_core_client_util_method_nested_structures( - monkeypatch: pytest.MonkeyPatch): - + monkeypatch: pytest.MonkeyPatch, +): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -497,7 +490,8 @@ async def test_engine_core_client_util_method_nested_structures( engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True) vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + 
usage_context=UsageContext.UNKNOWN_CONTEXT + ) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -514,42 +508,48 @@ async def test_engine_core_client_util_method_nested_structures( # Test list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}] result = await core_client.call_utility_async( - "echo_dc_nested", "nested1", "list_of_dicts") + "echo_dc_nested", "nested1", "list_of_dicts" + ) assert isinstance(result, list) and len(result) == 2 for i, item in enumerate(result): assert isinstance(item, dict) if i == 0: assert "a" in item and "b" in item - assert isinstance( - item["a"], - MyDataclass) and item["a"].message == "nested1" - assert isinstance( - item["b"], - MyDataclass) and item["b"].message == "nested1" + assert ( + isinstance(item["a"], MyDataclass) + and item["a"].message == "nested1" + ) + assert ( + isinstance(item["b"], MyDataclass) + and item["b"].message == "nested1" + ) else: assert "c" in item and "d" in item - assert isinstance( - item["c"], - MyDataclass) and item["c"].message == "nested1" - assert isinstance( - item["d"], - MyDataclass) and item["d"].message == "nested1" + assert ( + isinstance(item["c"], MyDataclass) + and item["c"].message == "nested1" + ) + assert ( + isinstance(item["d"], MyDataclass) + and item["d"].message == "nested1" + ) # Test dict of lists: {"list1": [val, val], "list2": [val, val]} result = await core_client.call_utility_async( - "echo_dc_nested", "nested2", "dict_of_lists") + "echo_dc_nested", "nested2", "dict_of_lists" + ) assert isinstance(result, dict) and len(result) == 2 assert "list1" in result and "list2" in result for key, lst in result.items(): assert isinstance(lst, list) and len(lst) == 2 for item in lst: - assert isinstance( - item, MyDataclass) and item.message == "nested2" + assert isinstance(item, MyDataclass) and item.message == "nested2" # Test deeply nested: {"outer": [{"inner": [val, val]}, # {"inner": [val]}]} result = await core_client.call_utility_async( - "echo_dc_nested", "nested3", "deep_nested") + "echo_dc_nested", "nested3", "deep_nested" + ) assert isinstance(result, dict) and "outer" in result outer_list = result["outer"] assert isinstance(outer_list, list) and len(outer_list) == 2 @@ -560,21 +560,22 @@ async def test_engine_core_client_util_method_nested_structures( inner_list1 = inner_dict1["inner"] assert isinstance(inner_list1, list) and len(inner_list1) == 2 for item in inner_list1: - assert isinstance(item, - MyDataclass) and item.message == "nested3" + assert isinstance(item, MyDataclass) and item.message == "nested3" # Second dict in outer list should have "inner" with 1 item inner_dict2 = outer_list[1] assert isinstance(inner_dict2, dict) and "inner" in inner_dict2 inner_list2 = inner_dict2["inner"] assert isinstance(inner_list2, list) and len(inner_list2) == 1 - assert isinstance( - inner_list2[0], - MyDataclass) and inner_list2[0].message == "nested3" + assert ( + isinstance(inner_list2[0], MyDataclass) + and inner_list2[0].message == "nested3" + ) # Test with None values in nested structures result = await core_client.call_utility_async( - "echo_dc_nested", None, "list_of_dicts") + "echo_dc_nested", None, "list_of_dicts" + ) assert isinstance(result, list) and len(result) == 2 for item in result: assert isinstance(item, dict) @@ -595,7 +596,6 @@ def test_kv_cache_events( multiprocessing_mode: bool, publisher_config, ): - with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") block_size = 16 @@ -609,8 +609,7 @@ def test_kv_cache_events( ) 
engine_args.kv_events_config = publisher_config - vllm_config = engine_args.create_engine_config( - UsageContext.UNKNOWN_CONTEXT) + vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -622,9 +621,9 @@ def test_kv_cache_events( log_stats=False, ) endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") - subscriber = MockSubscriber(endpoint, - topic=publisher_config.topic, - decode_type=KVEventBatch) + subscriber = MockSubscriber( + endpoint, topic=publisher_config.topic, decode_type=KVEventBatch + ) try: custom_tokens = list(range(num_blocks * block_size)) @@ -641,22 +640,25 @@ def test_kv_cache_events( seq, received = result assert seq == 0, "Sequence number mismatch" - assert (len(received.events) == 1 - ), "We should have exactly one BlockStored event" + assert len(received.events) == 1, ( + "We should have exactly one BlockStored event" + ) event = received.events[0] - assert isinstance( - event, BlockStored), "We should have a BlockStored event" - assert (len(event.block_hashes) == num_blocks - ), "We should have a BlockStored event with 2 block_hashes" - assert (event.block_size == block_size - ), "Block size should be the same as the block size" - assert (event.parent_block_hash - is None), "Parent block hash should be None" + assert isinstance(event, BlockStored), "We should have a BlockStored event" + assert len(event.block_hashes) == num_blocks, ( + "We should have a BlockStored event with 2 block_hashes" + ) + assert event.block_size == block_size, ( + "Block size should be the same as the block size" + ) + assert event.parent_block_hash is None, "Parent block hash should be None" assert event.lora_id is None, "Lora id should be None" - assert (len(event.token_ids) == num_blocks * block_size - ), "Token ids should be the same as the custom tokens" - assert (event.token_ids == custom_tokens - ), "Token ids should be the same as the custom tokens" + assert len(event.token_ids) == num_blocks * block_size, ( + "Token ids should be the same as the custom tokens" + ) + assert event.token_ids == custom_tokens, ( + "Token ids should be the same as the custom tokens" + ) finally: client.shutdown() subscriber.close() @@ -674,7 +676,6 @@ async def test_kv_cache_events_dp( multiprocessing_mode: bool, publisher_config, ): - with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") block_size = 16 @@ -692,8 +693,7 @@ async def test_kv_cache_events_dp( ) engine_args.kv_events_config = publisher_config - vllm_config = engine_args.create_engine_config( - UsageContext.UNKNOWN_CONTEXT) + vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT) executor_class = Executor.get_class(vllm_config) with set_default_torch_num_threads(1): @@ -710,13 +710,12 @@ async def test_kv_cache_events_dp( base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") endpoints = [] for i in range(dp_size): - offset_endpoint = ZmqEventPublisher.offset_endpoint_port( - base_endpoint, i) + offset_endpoint = ZmqEventPublisher.offset_endpoint_port(base_endpoint, i) endpoints.append(offset_endpoint) - subscriber = MockSubscriber(endpoints, - topic=publisher_config.topic, - decode_type=KVEventBatch) + subscriber = MockSubscriber( + endpoints, topic=publisher_config.topic, decode_type=KVEventBatch + ) try: custom_tokens = list(range(num_blocks * block_size)) @@ -734,15 +733,12 @@ async def test_kv_cache_events_dp( await asyncio.sleep(0.1) # Initialize outputs dict for all 
requests - outputs: dict[str, list] = { - req_id: [] - for req_id in all_request_ids - } + outputs: dict[str, list] = {req_id: [] for req_id in all_request_ids} print("processing requests...") - await asyncio.wait_for(loop_until_fully_done_async( - client, outputs), - timeout=20.0) + await asyncio.wait_for( + loop_until_fully_done_async(client, outputs), timeout=20.0 + ) # Receive from subscriber until no more messages print("collecting results...") @@ -755,13 +751,11 @@ async def test_kv_cache_events_dp( results.append(result) # Collect all events and data_parallel_ranks from all results - all_dp_ranks = [ - received.data_parallel_rank for (_, received) in results - ] + all_dp_ranks = [received.data_parallel_rank for (_, received) in results] unique_dps = set(all_dp_ranks) - assert ( - len(unique_dps) == 2 - ), f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}" + assert len(unique_dps) == 2, ( + f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}" + ) finally: client.shutdown() @@ -770,7 +764,6 @@ async def test_kv_cache_events_dp( @pytest.mark.timeout(20) def test_startup_failure(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m, pytest.raises(Exception) as e_info: m.setenv("VLLM_USE_V1", "1") @@ -787,7 +780,8 @@ def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs): t = time.time() engine_args = EngineArgs(model=MODEL_NAME) vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.UNKNOWN_CONTEXT) + usage_context=UsageContext.UNKNOWN_CONTEXT + ) executor_class = Executor.get_class(vllm_config) print(f"VllmConfig creation took {time.time() - t:.2f} seconds.") @@ -815,8 +809,7 @@ def kill_first_child(): @create_new_process_for_each_test() -def test_engine_core_proc_instantiation_cuda_empty( - monkeypatch: pytest.MonkeyPatch): +def test_engine_core_proc_instantiation_cuda_empty(monkeypatch: pytest.MonkeyPatch): """ Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES is empty. This ensures the engine frontend does not need access to GPUs. 
@@ -833,17 +826,13 @@ def create_mock_executor(vllm_config): # Only implement the methods that are actually called during init from vllm.v1.kv_cache_interface import FullAttentionSpec - mock_spec = FullAttentionSpec(block_size=16, - num_kv_heads=1, - head_size=64, - dtype=torch.float16) - - mock_executor.get_kv_cache_specs.return_value = [{ - "default": mock_spec - }] - mock_executor.determine_available_memory.return_value = [ - 1024 * 1024 * 1024 - ] + + mock_spec = FullAttentionSpec( + block_size=16, num_kv_heads=1, head_size=64, dtype=torch.float16 + ) + + mock_executor.get_kv_cache_specs.return_value = [{"default": mock_spec}] + mock_executor.determine_available_memory.return_value = [1024 * 1024 * 1024] mock_executor.initialize_from_config.return_value = None mock_executor.max_concurrent_batches = 1 @@ -857,19 +846,22 @@ def create_mock_executor(vllm_config): from vllm.v1.engine.utils import EngineZmqAddresses - def mock_startup_handshake(self, handshake_socket, local_client, - headless, parallel_config): - return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"], - outputs=["tcp://127.0.0.1:5556"], - coordinator_input=None, - coordinator_output=None) + def mock_startup_handshake( + self, handshake_socket, local_client, headless, parallel_config + ): + return EngineZmqAddresses( + inputs=["tcp://127.0.0.1:5555"], + outputs=["tcp://127.0.0.1:5556"], + coordinator_input=None, + coordinator_output=None, + ) # Background processes are not important here m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake) vllm_config = EngineArgs( - model="deepseek-ai/DeepSeek-V2-Lite", - trust_remote_code=True).create_engine_config() + model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True + ).create_engine_config() engine_core_proc = EngineCoreProc( vllm_config=vllm_config, local_client=True, diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py index f3d8e13088b0..77e67d54e587 100644 --- a/tests/v1/engine/test_fast_incdec_prefix_err.py +++ b/tests/v1/engine/test_fast_incdec_prefix_err.py @@ -40,23 +40,139 @@ def test_fast_inc_detok_invalid_utf8_err_case(): detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request) - assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \ + assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", ( "Should use FastIncrementalDetokenizer by default" + ) # Process tokens incrementally test_tokens = [ - 236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908, - 147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292, - 827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418, - 569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118, - 35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140, - 236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654, - 236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654, - 236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817, - 4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509, - 19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398, - 432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745, - 2555, 513, 236789, 602, 31118, 569 + 236840, + 107, + 138, + 236782, + 107, + 140, + 236775, + 6265, + 1083, + 623, + 121908, + 147418, + 827, + 107, + 140, + 236775, + 6265, + 236779, + 2084, + 1083, + 623, + 203292, + 827, + 107, + 140, + 236775, + 6265, + 236779, + 7777, + 
1083, + 623, + 121908, + 147418, + 569, + 537, + 236789, + 65880, + 569, + 537, + 236789, + 62580, + 853, + 115693, + 210118, + 35178, + 16055, + 1270, + 759, + 215817, + 4758, + 1925, + 1117, + 827, + 107, + 140, + 236775, + 5654, + 1083, + 623, + 110733, + 46291, + 827, + 107, + 140, + 236775, + 5654, + 236779, + 2084, + 1083, + 623, + 136955, + 56731, + 827, + 107, + 140, + 236775, + 5654, + 236779, + 7777, + 1083, + 623, + 194776, + 2947, + 496, + 109811, + 1608, + 890, + 215817, + 4758, + 1925, + 1117, + 2789, + 432, + 398, + 602, + 31118, + 569, + 124866, + 134772, + 509, + 19478, + 1640, + 33779, + 236743, + 236770, + 236819, + 236825, + 236771, + 432, + 398, + 432, + 237167, + 827, + 107, + 140, + 236775, + 77984, + 1083, + 623, + 2709, + 236745, + 2555, + 513, + 236789, + 602, + 31118, + 569, ] output = "" @@ -66,9 +182,9 @@ def test_fast_inc_detok_invalid_utf8_err_case(): finished = i == len(test_tokens) - 1 output += detokenizer.get_next_output_text(finished, delta=True) - -# fmt: off - assert output == r'''[ + assert ( + output + == r"""[ { "source": "Résultats", "source_type": "CONCEPT", @@ -76,4 +192,5 @@ def test_fast_inc_detok_invalid_utf8_err_case(): "target": "Israël", "target_type": "ORGANIZATION", "target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »", - "relationship": "Obtention d'un niveau de''' + "relationship": "Obtention d'un niveau de""" + ) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 7529c3780ec2..a19ba562136f 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -3,7 +3,7 @@ from __future__ import annotations import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import pytest @@ -43,7 +43,8 @@ def _vllm_model( # env var adjustment via monkeypatch scope="function", # Prefix caching - params=[False, True]) + params=[False, True], +) def vllm_model(vllm_runner, request, monkeypatch): """VllmRunner test fixture parameterized by APC True/False.""" with _vllm_model(request.param, vllm_runner, monkeypatch) as vllm_model: @@ -62,21 +63,22 @@ def vllm_model_apc(vllm_runner, monkeypatch): # env var adjustment via monkeypatch scope="function", # Prefix caching - params=[False, True]) + params=[False, True], +) def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch): """VllmRunner test fixture with APC.""" with _vllm_model( - request.param, - vllm_runner, - monkeypatch, - skip_tokenizer_init=True, + request.param, + vllm_runner, + monkeypatch, + skip_tokenizer_init=True, ) as vllm_model: yield vllm_model def _get_test_sampling_params( prompt_list: list[str], - seed: Optional[int] = 42, + seed: int | None = 42, structured_outputs: bool = False, ) -> tuple[list[SamplingParams], list[int]]: """Generate random sampling params for a batch.""" @@ -97,9 +99,11 @@ def get_mostly_n_gt1() -> int: top_p=0.95, n=n, seed=seed, - structured_outputs=StructuredOutputsParams( - regex="[0-9]+") if structured_outputs else None, - ) for n in n_list + structured_outputs=StructuredOutputsParams(regex="[0-9]+") + if structured_outputs + else None, + ) + for n in n_list ], n_list @@ -132,23 +136,20 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None: for out, n in zip(outputs, n_list): completion_counts: dict[str, int] = {} # Assert correct number of completions - assert len(out.outputs) == n, ( - f"{len(out.outputs)} completions; {n} expected.") + assert len(out.outputs) == n, 
f"{len(out.outputs)} completions; {n} expected." for idx in range(n): comp = out.outputs[idx] # Assert correct completion indices - assert comp.index == idx, (f"Index {comp.index}; expected {idx}.") + assert comp.index == idx, f"Index {comp.index}; expected {idx}." text = comp.text completion_counts[text] = completion_counts.get(text, 0) + 1 # Assert unique completions if len(completion_counts) != n: - repeats = { - txt: num - for (txt, num) in completion_counts.items() if num > 1 - } + repeats = {txt: num for (txt, num) in completion_counts.items() if num > 1} raise AssertionError( f"{len(completion_counts)} unique completions; expected" - f" {n}. Repeats: {repeats}") + f" {n}. Repeats: {repeats}" + ) def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): @@ -162,13 +163,12 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): } monkeypatch.setenv("VLLM_USE_V1", "1") with vllm_runner( - MODEL, - speculative_config=speculative_config, - disable_log_stats=False, + MODEL, + speculative_config=speculative_config, + disable_log_stats=False, ) as vllm_model: llm: LLM = vllm_model.llm - sampling_params = SamplingParams(temperature=0.0, - max_tokens=max_tokens) + sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = llm.generate(example_prompts, sampling_params) n_prompts = len(example_prompts) @@ -192,15 +192,14 @@ def find_metric(name) -> list[Metric]: num_requests_running = find_metric("vllm:num_requests_running") assert len(num_requests_running) == 1 assert isinstance(num_requests_running[0], Gauge) - assert num_requests_running[0].value == .0 + assert num_requests_running[0].value == 0.0 generation_tokens = find_metric("vllm:generation_tokens") assert len(generation_tokens) == 1 assert isinstance(generation_tokens[0], Counter) assert generation_tokens[0].value == total_tokens - request_generation_tokens = find_metric( - "vllm:request_generation_tokens") + request_generation_tokens = find_metric("vllm:request_generation_tokens") assert len(request_generation_tokens) == 1 assert isinstance(request_generation_tokens[0], Histogram) assert "+Inf" in request_generation_tokens[0].buckets @@ -209,15 +208,15 @@ def find_metric(name) -> list[Metric]: assert request_generation_tokens[0].sum == total_tokens num_accepted_tokens_per_pos = find_metric( - "vllm:spec_decode_num_accepted_tokens_per_pos") + "vllm:spec_decode_num_accepted_tokens_per_pos" + ) assert len(num_accepted_tokens_per_pos) == 1 assert isinstance(num_accepted_tokens_per_pos[0], Vector) assert len(num_accepted_tokens_per_pos[0].values) == 5 @pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"]) -def test_skip_tokenizer_initialization(model: str, - monkeypatch: pytest.MonkeyPatch): +def test_skip_tokenizer_initialization(model: str, monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_V1", "1") # This test checks if the flag skip_tokenizer_init skips the initialization # of tokenizer and detokenizer. 
The generated output is expected to contain @@ -232,8 +231,9 @@ def test_skip_tokenizer_initialization(model: str, with pytest.raises(ValueError, match="cannot pass text prompts when"): llm.generate("abc", sampling_params) - outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, - sampling_params=sampling_params) + outputs = llm.generate( + {"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params + ) assert len(outputs) > 0 completions = outputs[0].outputs assert len(completions) > 0 diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 72c0a9a13e23..9ebf7f09503e 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -7,19 +7,20 @@ import pytest -from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, - NUM_SAMPLE_LOGPROBS_UNDER_TEST, - STOP_STRINGS, - DummyOutputProcessorTestVectors, - MockEngineCore) +from tests.v1.engine.utils import ( + NUM_PROMPT_LOGPROBS_UNDER_TEST, + NUM_SAMPLE_LOGPROBS_UNDER_TEST, + STOP_STRINGS, + DummyOutputProcessorTestVectors, + MockEngineCore, +) from vllm import PoolingParams from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.output_processor import (OutputProcessor, - RequestOutputCollector) +from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.metrics.stats import IterationStats @@ -40,33 +41,34 @@ def _ref_convert_id_to_token( @pytest.mark.parametrize( - "request_output_kind", - [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) -def test_incremental_detokenization(request_output_kind: RequestOutputKind, - dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, - log_stats=False) - engine_core = MockEngineCore( - tokens_list=dummy_test_vectors.generation_tokens) + "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) +def test_incremental_detokenization( + request_output_kind: RequestOutputKind, dummy_test_vectors +): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) + engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens) # Make N requests. requests = [ - EngineCoreRequest(request_id=f"request-{idx}", - prompt_token_ids=prompt_tokens, - mm_features=None, - eos_token_id=None, - arrival_time=0, - lora_request=None, - cache_salt=None, - data_parallel_rank=None, - sampling_params=SamplingParams( - skip_special_tokens=False, - spaces_between_special_tokens=False, - output_kind=request_output_kind, - stop=[], - include_stop_str_in_output=False, - ), - pooling_params=None) + EngineCoreRequest( + request_id=f"request-{idx}", + prompt_token_ids=prompt_tokens, + mm_features=None, + eos_token_id=None, + arrival_time=0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + sampling_params=SamplingParams( + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=request_output_kind, + stop=[], + include_stop_str_in_output=False, + ), + pooling_params=None, + ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] @@ -102,8 +104,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, # Confirmed tracked values matches what we expected. 
for idx, (ref_gen_str, ref_gen_toks) in enumerate( - zip(dummy_test_vectors.generation_strings, - dummy_test_vectors.generation_tokens)): + zip(dummy_test_vectors.generation_strings, dummy_test_vectors.generation_tokens) + ): gen_str = gen_strings[f"request-{idx}"] gen_toks = gen_tokens[f"request-{idx}"] @@ -134,9 +136,11 @@ def _validate_logprobs( ref_prompt_logprobs = dtv.prompt_logprobs[req_idx] if num_sample_logprobs is not None: # Validate sample logprobs - assert logprobs is not None, (f"Request {req_id} requires sample" - " logprobs but sample logprobs are" - " None.") + assert logprobs is not None, ( + f"Request {req_id} requires sample" + " logprobs but sample logprobs are" + " None." + ) # Require num sampled tokens to match num # sampled logprobs - especially important # to check since the detokenizer can cause @@ -147,44 +151,51 @@ def _validate_logprobs( assert num_new_tokens == len_sample_logprobs, ( f"Request {req_id} has {num_new_tokens}" " completion tokens but has" - f" {len_sample_logprobs} sample logprobs.") + f" {len_sample_logprobs} sample logprobs." + ) ref_cumulative_logprob = 0.0 - for idx, (sampled_token, - pos_logprob_dict) in enumerate(zip(new_tokens, - logprobs)): + for idx, (sampled_token, pos_logprob_dict) in enumerate( + zip(new_tokens, logprobs) + ): # Break out the reference log probability value & # logprob token id tensors associated with this # position in the completion. Also break out the # sampled token ranks - (ref_pos_logprob_toks, ref_pos_logprob_vals, - ref_sampled_token_rank) = ref_logprobs[idx] + (ref_pos_logprob_toks, ref_pos_logprob_vals, ref_sampled_token_rank) = ( + ref_logprobs[idx] + ) # For each position in the completion sequence, # ensure the actual sampled token is among the # logprobs assert sampled_token in pos_logprob_dict, ( f"Sampled token {sampled_token} not" - f" present in logprob at index {idx}") + f" present in logprob at index {idx}" + ) # Validate number of sample logprobs num_lp_toks = len(pos_logprob_dict) - assert (num_lp_toks == num_sample_logprobs - or num_lp_toks == num_sample_logprobs + - 1), ("Valid numbers of sample logprobs are" - f" {num_sample_logprobs} or" - f" {num_sample_logprobs+1} but" - f" {num_lp_toks} logprobs found at" - f" position {idx}. Logprobs dict:" - f" {pos_logprob_dict}") + assert ( + num_lp_toks == num_sample_logprobs + or num_lp_toks == num_sample_logprobs + 1 + ), ( + "Valid numbers of sample logprobs are" + f" {num_sample_logprobs} or" + f" {num_sample_logprobs + 1} but" + f" {num_lp_toks} logprobs found at" + f" position {idx}. Logprobs dict:" + f" {pos_logprob_dict}" + ) # Validate sampled token logprob rank smp_lp = pos_logprob_dict[sampled_token] smp_lp_rank = smp_lp.rank - assert (ref_sampled_token_rank == smp_lp_rank), ( + assert ref_sampled_token_rank == smp_lp_rank, ( "Sampled token logprob rank" f" {smp_lp_rank} does not match" " correct value" f" {ref_sampled_token_rank}" - f" in Logprob {smp_lp}") + f" in Logprob {smp_lp}" + ) # Validate that the logprob processor yields # the correct log probabilities and valid @@ -198,7 +209,8 @@ def _validate_logprobs( ref_tok_id = ref_pos_logprob_toks[jdx] assert ref_tok_id in pos_logprob_dict, ( f"Expected token {ref_tok_id} to be" - f" in logprob dict but it is not.") + f" in logprob dict but it is not." 
+ ) # Extract actually-generated logprob # info @@ -208,40 +220,43 @@ def _validate_logprobs( # A "top" (rank 1) logprob must be # present - rank_one_appears = (True - if lp_rank == 1 else rank_one_appears) + rank_one_appears = True if lp_rank == 1 else rank_one_appears # Rank must be >= 1 - assert lp_rank >= 1, (f"Logprob {lp} has invalid" - f" rank {lp_rank} < 1." - f" Logprob dict: {pos_logprob_dict}") + assert lp_rank >= 1, ( + f"Logprob {lp} has invalid" + f" rank {lp_rank} < 1." + f" Logprob dict: {pos_logprob_dict}" + ) # Validate log probability assert math.isclose(lp_val, ref_lp_val), ( f"Token id {ref_tok_id} appears in logprobs dict" f" at position {idx} in completion with log" f" probability {lp_val} but {ref_lp_val} was" - f" expected. Logprob: {lp}") + f" expected. Logprob: {lp}" + ) - assert rank_one_appears, (f"No Logprob has rank 1" - " in the following Logprob" - f" dict: {pos_logprob_dict}") + assert rank_one_appears, ( + f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}" + ) # Validate logprobs detokenization for lp_tok in pos_logprob_dict: # Confirm that sample logprob decoded token matches # the logprob token id at this sequence position decoded_token = pos_logprob_dict[lp_tok].decoded_token - ref_decoded_token = _ref_convert_id_to_token( - dtv.tokenizer, lp_tok) + ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, lp_tok) assert decoded_token == ref_decoded_token, ( f"Sampled logprob token id {lp_tok} decodes to" f" {ref_decoded_token} but Logprob decoded" f" token is {decoded_token} instead" - f" (at position {idx})") + f" (at position {idx})" + ) - ref_cumulative_logprob += pos_logprob_dict[ - sampled_token].logprob + ref_cumulative_logprob += pos_logprob_dict[sampled_token].logprob # Assert that cumulative logprobs are correct assert math.isclose(cumulative_logprob, ref_cumulative_logprob) else: @@ -254,7 +269,8 @@ def _validate_logprobs( assert prompt_logprobs is not None, ( f"Request {req_id} requires prompt" " logprobs but prompt logprobs are" - " None.") + " None." + ) # Require num prompt tokens to match num # prompt logprobs num_prompt_tokens = len(prompt_token_ids) @@ -262,56 +278,70 @@ def _validate_logprobs( assert num_prompt_tokens == len_prompt_logprobs, ( f"Request {req_id} has {num_prompt_tokens}" " prompt tokens but has" - f" {len_prompt_logprobs} prompt logprobs.") + f" {len_prompt_logprobs} prompt logprobs." + ) # First prompt logprob is None first_plp_dict = prompt_logprobs[0] assert first_plp_dict is None, ( f"Request {req_id} first prompt logprob" f" should be None but has following value" - f" instead: {first_plp_dict}") + f" instead: {first_plp_dict}" + ) # Break out the reference prompt log prob value & # logprob token id matrices for the whole prompt. # Also break out the prompt token rank vector - (ref_prompt_logprob_toks, ref_prompt_logprob_vals, - ref_prompt_token_ranks) = ref_prompt_logprobs + ( + ref_prompt_logprob_toks, + ref_prompt_logprob_vals, + ref_prompt_token_ranks, + ) = ref_prompt_logprobs for idx, (prompt_token, pos_logprob_dict) in enumerate( - zip(prompt_token_ids[1:], prompt_logprobs[1:])): - + zip(prompt_token_ids[1:], prompt_logprobs[1:]) + ): # Break out the reference prompt log prob value # vector, prompt logprob token id vector, and # prompt token rank at the current position. 
- (ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals, - ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :], - ref_prompt_logprob_vals[idx, :], - ref_prompt_token_ranks[idx]) + ( + ref_pos_prompt_logprob_toks, + ref_pos_prompt_logprob_vals, + ref_pos_prompt_token_rank, + ) = ( + ref_prompt_logprob_toks[idx, :], + ref_prompt_logprob_vals[idx, :], + ref_prompt_token_ranks[idx], + ) # For each position in the prompt sequence, # ensure the actual prompt token is among the # logprobs assert prompt_token in pos_logprob_dict, ( - f"Prompt token {prompt_token} not" - f" present in logprob at index {idx}") + f"Prompt token {prompt_token} not present in logprob at index {idx}" + ) # Validate number of prompt logprobs num_plp_toks = len(pos_logprob_dict) - assert (num_plp_toks == num_prompt_logprobs - or num_plp_toks == num_prompt_logprobs + - 1), ("Valid numbers of prompt logprobs are" - f" {num_prompt_logprobs} or" - f" {num_prompt_logprobs+1} but" - f" {num_plp_toks} logprobs found at" - f" position {idx}. Logprobs dict:" - f" {pos_logprob_dict}") + assert ( + num_plp_toks == num_prompt_logprobs + or num_plp_toks == num_prompt_logprobs + 1 + ), ( + "Valid numbers of prompt logprobs are" + f" {num_prompt_logprobs} or" + f" {num_prompt_logprobs + 1} but" + f" {num_plp_toks} logprobs found at" + f" position {idx}. Logprobs dict:" + f" {pos_logprob_dict}" + ) # Validate prompt token logprob rank prmpt_tok_lp = pos_logprob_dict[prompt_token] prmpt_tok_lp_rank = prmpt_tok_lp.rank ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank - assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), ( + assert ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank, ( "Prompt token logprob rank" f" {prmpt_tok_lp_rank} does not match" " correct value" f" {ref_prmpt_tok_lp_rank}" - f" in Logprob {prmpt_tok_lp}") + f" in Logprob {prmpt_tok_lp}" + ) # Validate that the logprob processor yields # the correct prompt log probs and valid @@ -325,7 +355,8 @@ def _validate_logprobs( ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx]) assert ref_tok_id in pos_logprob_dict, ( f"Expected token {ref_tok_id} to be" - f" in logprob dict but it is not.") + f" in logprob dict but it is not." + ) # Extract actually-generated logprob # info @@ -335,87 +366,93 @@ def _validate_logprobs( # A "top" (rank 1) logprob must be # present - rank_one_appears = (True - if plp_rank == 1 else rank_one_appears) + rank_one_appears = True if plp_rank == 1 else rank_one_appears # Rank must be >= 1 assert plp_rank >= 1, ( f"Logprob {plp} has invalid" f" rank {plp_rank} < 1." - f" Logprob dict: {pos_logprob_dict}") + f" Logprob dict: {pos_logprob_dict}" + ) # Validate log probability assert math.isclose(plp_val, ref_plp_val), ( f"Token id {ref_tok_id} appears in logprobs dict" f" at position {idx} in completion with log" f" probability {plp_val} but {ref_plp_val} was" - f" expected. Logprob: {plp}") + f" expected. 
Logprob: {plp}" + ) - assert rank_one_appears, (f"No Logprob has rank 1" - " in the following Logprob" - f" dict: {pos_logprob_dict}") + assert rank_one_appears, ( + f"No Logprob has rank 1" + " in the following Logprob" + f" dict: {pos_logprob_dict}" + ) # Validate prompt logprob detokenization for plp_tok in pos_logprob_dict: # Confirm that prompt logprob decoded token matches # the logprob token id at this sequence position decoded_token = pos_logprob_dict[plp_tok].decoded_token - ref_decoded_token = _ref_convert_id_to_token( - dtv.tokenizer, plp_tok) + ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, plp_tok) assert decoded_token == ref_decoded_token, ( f"Prompt logprob token id {plp_tok} decodes to" f" {ref_decoded_token} but Logprob decoded" f" token is {decoded_token} instead" - f" (at position {idx})") + f" (at position {idx})" + ) else: # Prompt logprobs disabled for this request assert prompt_logprobs is None @pytest.mark.parametrize( - "request_output_kind", - [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) -@pytest.mark.parametrize("num_sample_logprobs", - [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) -@pytest.mark.parametrize("num_prompt_logprobs", - [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) -def test_logprobs_processor(request_output_kind: RequestOutputKind, - num_sample_logprobs: Optional[int], - num_prompt_logprobs: Optional[int], - dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, - log_stats=False) + "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY] +) +@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +@pytest.mark.parametrize("num_prompt_logprobs", [None, NUM_PROMPT_LOGPROBS_UNDER_TEST]) +def test_logprobs_processor( + request_output_kind: RequestOutputKind, + num_sample_logprobs: Optional[int], + num_prompt_logprobs: Optional[int], + dummy_test_vectors, +): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) engine_core = MockEngineCore( tokens_list=dummy_test_vectors.generation_tokens, - generated_logprobs_raw=None if num_sample_logprobs is None else - dummy_test_vectors.generation_logprobs, + generated_logprobs_raw=None + if num_sample_logprobs is None + else dummy_test_vectors.generation_logprobs, prompt_logprobs_raw=None - if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs) + if num_prompt_logprobs is None + else dummy_test_vectors.prompt_logprobs, + ) # Make N requests. 
request_id_list = [ - f"request-{idx}" - for idx in range(len(dummy_test_vectors.prompt_strings)) + f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings)) ] requests = [ - EngineCoreRequest(request_id=request_id_list[idx], - prompt_token_ids=prompt_tokens, - mm_features=None, - eos_token_id=None, - arrival_time=0, - lora_request=None, - cache_salt=None, - data_parallel_rank=None, - sampling_params=SamplingParams( - skip_special_tokens=False, - spaces_between_special_tokens=False, - output_kind=request_output_kind, - stop=[], - include_stop_str_in_output=False, - logprobs=num_sample_logprobs, - prompt_logprobs=num_prompt_logprobs, - ), - pooling_params=None) + EngineCoreRequest( + request_id=request_id_list[idx], + prompt_token_ids=prompt_tokens, + mm_features=None, + eos_token_id=None, + arrival_time=0, + lora_request=None, + cache_salt=None, + data_parallel_rank=None, + sampling_params=SamplingParams( + skip_special_tokens=False, + spaces_between_special_tokens=False, + output_kind=request_output_kind, + stop=[], + include_stop_str_in_output=False, + logprobs=num_sample_logprobs, + prompt_logprobs=num_prompt_logprobs, + ), + pooling_params=None, + ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] @@ -446,7 +483,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, prompt_logprobs = request_output.prompt_logprobs logprobs = request_output.outputs[0].logprobs gen_cumulative_logprobs[request_id] = request_output.outputs[ - 0].cumulative_logprob + 0 + ].cumulative_logprob if request_id not in gen_logprobs: # Start tracking sample and prompt logprobs for this request gen_tokens[request_id] = new_tokens @@ -463,10 +501,16 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, plp.extend(prompt_logprobs) # Confirmed tracked logprobs match what we expect - _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, - gen_cumulative_logprobs, dummy_test_vectors, - request_id_list, num_sample_logprobs, - num_prompt_logprobs) + _validate_logprobs( + gen_tokens, + gen_logprobs, + gen_prompt_logprobs, + gen_cumulative_logprobs, + dummy_test_vectors, + request_id_list, + num_sample_logprobs, + num_prompt_logprobs, + ) assert output_processor.get_num_unfinished_requests() == 0 assert not output_processor.has_unfinished_requests() @@ -474,15 +518,23 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, @pytest.mark.parametrize( "include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs", - [(False, "stop_token_ids", False, None), - (True, "stop_token_ids", False, None), - (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST), - (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST), - (False, "eos_token_id", False, None), (True, "eos_token_id", False, None), - (False, "eos_token_id", True, None)]) -def test_stop_token(include_stop_str_in_output: bool, - num_sample_logprobs: Optional[int], stop_token_type: str, - ignore_eos: bool, dummy_test_vectors): + [ + (False, "stop_token_ids", False, None), + (True, "stop_token_ids", False, None), + (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST), + (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST), + (False, "eos_token_id", False, None), + (True, "eos_token_id", False, None), + (False, "eos_token_id", True, None), + ], +) +def test_stop_token( + include_stop_str_in_output: bool, + num_sample_logprobs: Optional[int], + stop_token_type: str, + ignore_eos: bool, + dummy_test_vectors, +): """Test output 
processor EOS/stop token handling. Send mock engine core request to mock engine core and pass core outputs @@ -523,9 +575,10 @@ def test_stop_token(include_stop_str_in_output: bool, dummy_test_vectors: dummy engine core outputs and other data structures """ model_id = dummy_test_vectors.tokenizer.name_or_path - if model_id != 'meta-llama/Llama-3.2-1B': - raise AssertionError("Test requires meta-llama/Llama-3.2-1B but " - f"{model_id} is in use.") + if model_id != "meta-llama/Llama-3.2-1B": + raise AssertionError( + f"Test requires meta-llama/Llama-3.2-1B but {model_id} is in use." + ) do_logprobs = num_sample_logprobs is not None # EOS under test; if False, stop_token_ids under test is_eos_test = stop_token_type == "eos_token_id" @@ -536,18 +589,16 @@ def test_stop_token(include_stop_str_in_output: bool, ) # '<|end_of_text|>' stop_token_ids = [128009] if not is_eos_test else None # '<|eot_id|>' - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, - log_stats=False) + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) # Dummy engine core outputs, with control tokens suffixed to test stops - suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids) + suffix_token = [eos_token_id] if is_eos_test else stop_token_ids assert suffix_token is not None and isinstance(suffix_token[0], int) generation_string = dummy_test_vectors.generation_strings[0] - generation_tokens = (dummy_test_vectors.generation_tokens[0] + - 2 * suffix_token) + generation_tokens = dummy_test_vectors.generation_tokens[0] + 2 * suffix_token if do_logprobs: - generation_logprobs = ( - dummy_test_vectors.generation_logprobs[0] + - 2 * [dummy_test_vectors.generation_logprobs[0][-1]]) + generation_logprobs = dummy_test_vectors.generation_logprobs[0] + 2 * [ + dummy_test_vectors.generation_logprobs[0][-1] + ] prompt_string = dummy_test_vectors.prompt_strings[0] prompt_tokens = dummy_test_vectors.prompt_tokens[0] engine_core = MockEngineCore( @@ -556,7 +607,8 @@ def test_stop_token(include_stop_str_in_output: bool, prompt_logprobs_raw=None, eos_token_id=eos_token_id, stop_token_ids=stop_token_ids, - ignore_eos=ignore_eos) + ignore_eos=ignore_eos, + ) # Make request. request_id = "request-0" @@ -580,7 +632,8 @@ def test_stop_token(include_stop_str_in_output: bool, prompt_logprobs=None, ignore_eos=ignore_eos, ), - pooling_params=None) + pooling_params=None, + ) # Add request to the detokenizer. output_processor.add_request(request, prompt_string) @@ -605,7 +658,7 @@ def test_stop_token(include_stop_str_in_output: bool, # Update tracking. 
request_output = request_outputs[0] if request_output.finished: - finish_reason = ("length" if is_eos_ignore_test else "stop") + finish_reason = "length" if is_eos_ignore_test else "stop" assert request_output.outputs[0].finish_reason == finish_reason gen_string += request_output.outputs[0].text @@ -614,7 +667,7 @@ def test_stop_token(include_stop_str_in_output: bool, gen_logprobs.extend(request_output.outputs[0].logprobs) # Validate generated text - control_token = '<|end_of_text|>' if is_eos_test else '<|eot_id|>' + control_token = "<|end_of_text|>" if is_eos_test else "<|eot_id|>" if is_eos_ignore_test: # Length-based stop; expect full string ref_str = generation_string + 2 * control_token @@ -624,14 +677,15 @@ def test_stop_token(include_stop_str_in_output: bool, else: # Stop token triggered but not in output ref_str = generation_string - assert gen_string == ref_str, (f"{gen_string=}, {ref_str=}") + assert gen_string == ref_str, f"{gen_string=}, {ref_str=}" if do_logprobs: # Validate number of sample logprobs num_tokens = len(gen_tokens) num_logprobs = len(gen_logprobs) assert num_tokens == num_logprobs, ( - f"Token count ({num_tokens}) != logprobs count ({num_logprobs})") + f"Token count ({num_tokens}) != logprobs count ({num_logprobs})" + ) # Check requests are finished assert output_processor.get_num_unfinished_requests() == 0 @@ -639,22 +693,24 @@ def test_stop_token(include_stop_str_in_output: bool, @pytest.mark.parametrize("include_stop_str_in_output", [True, False]) -@pytest.mark.parametrize("num_sample_logprobs", - [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) -def test_stop_string(include_stop_str_in_output: bool, - num_sample_logprobs: Optional[int], dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, - log_stats=False) +@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST]) +def test_stop_string( + include_stop_str_in_output: bool, + num_sample_logprobs: Optional[int], + dummy_test_vectors, +): + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False) engine_core = MockEngineCore( tokens_list=dummy_test_vectors.generation_tokens, generated_logprobs_raw=dummy_test_vectors.generation_logprobs - if num_sample_logprobs else None, - prompt_logprobs_raw=None) + if num_sample_logprobs + else None, + prompt_logprobs_raw=None, + ) # Make N requests. request_id_list = [ - f"request-{idx}" - for idx in range(len(dummy_test_vectors.prompt_strings)) + f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings)) ] requests = [ EngineCoreRequest( @@ -675,7 +731,8 @@ def test_stop_string(include_stop_str_in_output: bool, logprobs=num_sample_logprobs, prompt_logprobs=None, ), - pooling_params=None) + pooling_params=None, + ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] @@ -715,7 +772,8 @@ def test_stop_string(include_stop_str_in_output: bool, prompt_logprobs = request_output.prompt_logprobs logprobs = request_output.outputs[0].logprobs gen_cumulative_logprobs[request_id] = request_output.outputs[ - 0].cumulative_logprob + 0 + ].cumulative_logprob if request_id not in gen_strings: gen_strings[request_id] = new_text gen_tokens[request_id] = new_tokens @@ -733,8 +791,8 @@ def test_stop_string(include_stop_str_in_output: bool, # Confirmed tracked values matches what we expected. 
for idx, (ref_gen_str, stop_str) in enumerate( - zip(dummy_test_vectors.generation_strings, STOP_STRINGS)): - + zip(dummy_test_vectors.generation_strings, STOP_STRINGS) + ): # Request should be aborted. request_id = f"request-{idx}" assert request_id in aborted @@ -748,24 +806,28 @@ def test_stop_string(include_stop_str_in_output: bool, ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str if include_stop_str_in_output: - assert gen_str == ref_str_inc_stop, ( - f"{gen_str=}, {ref_str_inc_stop=}") + assert gen_str == ref_str_inc_stop, f"{gen_str=}, {ref_str_inc_stop=}" else: - assert gen_str == ref_str_exc_stop, ( - f"{gen_str=}, {ref_str_exc_stop=}") + assert gen_str == ref_str_exc_stop, f"{gen_str=}, {ref_str_exc_stop=}" # Confirmed tracked logprobs match what we expect - _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs, - gen_cumulative_logprobs, dummy_test_vectors, - request_id_list, num_sample_logprobs, None) + _validate_logprobs( + gen_tokens, + gen_logprobs, + gen_prompt_logprobs, + gen_cumulative_logprobs, + dummy_test_vectors, + request_id_list, + num_sample_logprobs, + None, + ) assert output_processor.get_num_unfinished_requests() == 0 assert not output_processor.has_unfinished_requests() def test_iteration_stats(dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, - log_stats=True) + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True) engine_core = MockEngineCore(dummy_test_vectors.generation_tokens) engine_core_timestamp = time.monotonic() @@ -782,7 +844,8 @@ def test_iteration_stats(dummy_test_vectors): data_parallel_rank=None, sampling_params=SamplingParams(), pooling_params=None, - ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) + ) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] # Add all requests except one to the OutputProcessor. @@ -794,12 +857,13 @@ def test_iteration_stats(dummy_test_vectors): # First iteration has 2 prefills. outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() - output_processor.process_outputs(outputs, engine_core_timestamp, - iteration_stats) - total_prompt_tokens = sum([ - len(prompt_tokens) - for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active] - ]) + output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) + total_prompt_tokens = sum( + [ + len(prompt_tokens) + for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active] + ] + ) assert iteration_stats.num_prompt_tokens == total_prompt_tokens assert iteration_stats.num_generation_tokens == num_active @@ -807,8 +871,7 @@ def test_iteration_stats(dummy_test_vectors): # Just decodes in this step. 
outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() - output_processor.process_outputs(outputs, engine_core_timestamp, - iteration_stats) + output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) assert iteration_stats.num_prompt_tokens == 0 assert iteration_stats.num_generation_tokens == num_active @@ -818,8 +881,7 @@ def test_iteration_stats(dummy_test_vectors): num_active += 1 outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() - output_processor.process_outputs(outputs, engine_core_timestamp, - iteration_stats) + output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1]) assert iteration_stats.num_prompt_tokens == total_prompt_tokens @@ -828,8 +890,7 @@ def test_iteration_stats(dummy_test_vectors): # Just decodes in this step. outputs = engine_core.get_outputs()[:num_active] iteration_stats = IterationStats() - output_processor.process_outputs(outputs, engine_core_timestamp, - iteration_stats) + output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats) assert iteration_stats.num_prompt_tokens == 0 assert iteration_stats.num_generation_tokens == num_active @@ -853,16 +914,13 @@ def make_outputs() -> list[RequestOutput]: text=TEXT, token_ids=[idx], cumulative_logprob=(idx + 1 * 1.0), - logprobs=[{ - "a": idx, - "b": idx - }], - finish_reason="length" if - (idx == NUM_REQS - 1) else None, + logprobs=[{"a": idx, "b": idx}], + finish_reason="length" if (idx == NUM_REQS - 1) else None, ) ], finished=(idx == NUM_REQS - 1), - ) for idx in range(NUM_REQS) + ) + for idx in range(NUM_REQS) ] collector = RequestOutputCollector(RequestOutputKind.DELTA) @@ -888,8 +946,7 @@ def make_outputs() -> list[RequestOutput]: assert not output.finished # Text, token_ids, and logprobs should get merged. assert output.outputs[0].text == TEXT * num_to_put - for tok_0, tok_1 in zip(output.outputs[0].token_ids, - list(range(num_to_put))): + for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))): assert tok_0 == tok_1 assert len(output.outputs[0].logprobs) == num_to_put @@ -910,8 +967,7 @@ def make_outputs() -> list[RequestOutput]: assert output.outputs[0].finish_reason == "length" # Text, token_ids, and logprobs should get merged. 
assert output.outputs[0].text == TEXT * num_to_put - for tok_0, tok_1 in zip(output.outputs[0].token_ids, - list(range(num_to_put))): + for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))): assert tok_0 == tok_1 assert len(output.outputs[0].logprobs) == num_to_put @@ -1003,8 +1059,7 @@ async def test_cumulative_output_collector_n(): @pytest.mark.parametrize("runner", ["generate", "pooling"]) def test_abort_requests(runner: str, dummy_test_vectors): - output_processor = OutputProcessor(dummy_test_vectors.tokenizer, - log_stats=True) + output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True) requests = [ EngineCoreRequest( request_id=f"request-{idx}", @@ -1016,9 +1071,9 @@ def test_abort_requests(runner: str, dummy_test_vectors): cache_salt=None, data_parallel_rank=None, sampling_params=SamplingParams() if runner == "generate" else None, - pooling_params=PoolingParams( - task="embed") if runner == "pooling" else None, - ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) + pooling_params=PoolingParams(task="embed") if runner == "pooling" else None, + ) + for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] for request in requests: diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py index 3a7bcb957182..2f73756ff615 100644 --- a/tests/v1/engine/test_processor_multi_modal_uuids.py +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -16,35 +16,33 @@ # Mock processor for testing -def _mk_processor(monkeypatch, - *, - mm_cache_gb: float = 4.0, - enable_prefix_caching: bool = True) -> Processor: +def _mk_processor( + monkeypatch, *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True +) -> Processor: """ Create a Processor instance with minimal configuration suitable for unit tests without accessing external resources. """ - monkeypatch.setattr(ModelConfig, - "try_get_generation_config", - lambda self: {}, - raising=True) - monkeypatch.setattr(ModelConfig, - "__post_init__", - lambda self, *args: None, - raising=True) - monkeypatch.setattr(ModelConfig, - "verify_with_parallel_config", - lambda self, parallel_config: None, - raising=True) - monkeypatch.setattr(processor_mod, - "processor_cache_from_config", - lambda vllm_config, mm_registry: None, - raising=True) - - monkeypatch.setattr(VllmConfig, - "__post_init__", - lambda self: None, - raising=True) + monkeypatch.setattr( + ModelConfig, "try_get_generation_config", lambda self: {}, raising=True + ) + monkeypatch.setattr( + ModelConfig, "__post_init__", lambda self, *args: None, raising=True + ) + monkeypatch.setattr( + ModelConfig, + "verify_with_parallel_config", + lambda self, parallel_config: None, + raising=True, + ) + monkeypatch.setattr( + processor_mod, + "processor_cache_from_config", + lambda vllm_config, mm_registry: None, + raising=True, + ) + + monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True) model_config = ModelConfig( skip_tokenizer_init=True, @@ -57,21 +55,17 @@ def _mk_processor(monkeypatch, # Minimal multimodal_config to satisfy references in # Processor.process_inputs. 
class _MockMMConfig: - def __init__(self, gb: float): self.mm_processor_cache_gb = gb - model_config.multimodal_config = _MockMMConfig( - mm_cache_gb) # type: ignore[attr-defined] + model_config.multimodal_config = _MockMMConfig(mm_cache_gb) # type: ignore[attr-defined] vllm_config = VllmConfig( model_config=model_config, cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching), device_config=DeviceConfig(device="cpu"), ) - # Pass tokenizer=None; InputPreprocessor handles None when - # skip_tokenizer_init is True. - return Processor(vllm_config, tokenizer=None) # type: ignore[arg-type] + return Processor(vllm_config) def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): @@ -79,13 +73,9 @@ def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): prompt = { "prompt": "USER: <image>\nDescribe\nASSISTANT:", - "multi_modal_data": { - "image": [cherry_pil_image, stop_pil_image] - }, + "multi_modal_data": {"image": [cherry_pil_image, stop_pil_image]}, # Mismatch: 2 items but only 1 uuid provided - "multi_modal_uuids": { - "image": ["hash_cherry"] - }, + "multi_modal_uuids": {"image": ["hash_cherry"]}, } with pytest.raises(ValueError, match="must have same length as data"): @@ -104,16 +94,13 @@ def test_multi_modal_uuids_missing_modality_raises(monkeypatch): # Two modalities provided in data "multi_modal_data": { "image": [cherry_pil_image], - "video": [baby_reading_np_ndarrays] + "video": [baby_reading_np_ndarrays], }, # Only image uuids provided; video missing should raise - "multi_modal_uuids": { - "image": ["hash_cherry"] - }, + "multi_modal_uuids": {"image": ["hash_cherry"]}, } - with pytest.raises(ValueError, - match="must be provided if multi_modal_data"): + with pytest.raises(ValueError, match="must be provided if multi_modal_data"): processor.process_inputs( request_id="req-2", prompt=prompt, # type: ignore[arg-type] @@ -130,28 +117,28 @@ def test_multi_modal_uuids_missing_modality_raises(monkeypatch): ], ) def test_multi_modal_uuids_accepts_none_and_passes_through( - monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool): - processor = _mk_processor(monkeypatch, - mm_cache_gb=mm_cache_gb, - enable_prefix_caching=enable_prefix_caching) + monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool +): + processor = _mk_processor( + monkeypatch, + mm_cache_gb=mm_cache_gb, + enable_prefix_caching=enable_prefix_caching, + ) # Capture the overrides passed to InputPreprocessor.preprocess captured: dict[str, object] = {} - def fake_preprocess(prompt, - *, - tokenization_kwargs=None, - lora_request=None, - mm_uuids=None): + def fake_preprocess( + prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None + ): captured["mm_uuids"] = mm_uuids # Minimal processed inputs for decoder-only flow return {"type": "token", "prompt_token_ids": [1]} # Monkeypatch only the bound preprocess method on this instance - monkeypatch.setattr(processor.input_preprocessor, - "preprocess", - fake_preprocess, - raising=True) + monkeypatch.setattr( + processor.input_preprocessor, "preprocess", fake_preprocess, raising=True + ) # Use a consistent two-image scenario across all configurations mm_uuids = {"image": [None, "hash_stop"], "video": None} @@ -176,24 +163,19 @@ def fake_preprocess(prompt, def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): # When both processor cache is 0 and prefix caching disabled, the # processor builds overrides from request id instead of using user UUIDs. 
- processor = _mk_processor(monkeypatch, - mm_cache_gb=0.0, - enable_prefix_caching=False) + processor = _mk_processor(monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False) captured: dict[str, object] = {} - def fake_preprocess(prompt, - *, - tokenization_kwargs=None, - lora_request=None, - mm_uuids=None): + def fake_preprocess( + prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None + ): captured["mm_uuids"] = mm_uuids return {"type": "token", "prompt_token_ids": [1]} - monkeypatch.setattr(processor.input_preprocessor, - "preprocess", - fake_preprocess, - raising=True) + monkeypatch.setattr( + processor.input_preprocessor, "preprocess", fake_preprocess, raising=True + ) request_id = "req-42" mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"} diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py index 689b2c95f927..9b720f6eb668 100644 --- a/tests/v1/engine/utils.py +++ b/tests/v1/engine/utils.py @@ -82,11 +82,12 @@ def _create_random_top_logprob_test_matrix( def _create_random_top_token_test_vector( - num_logprobs: int, - lower: int, - upper: int, - sampled_token_id: int, - adjust_num_logprobs: bool = True) -> tuple[torch.Tensor, int]: + num_logprobs: int, + lower: int, + upper: int, + sampled_token_id: int, + adjust_num_logprobs: bool = True, +) -> tuple[torch.Tensor, int]: """Create a random vector of top logprob token indices Use to create fake sample logprobs for testing. The sampled token @@ -127,8 +128,9 @@ def _create_random_top_token_test_vector( # Check if the sampled_token_id occurs in choice_tensor[1:] if sampled_token_id in choice_tensor[1:]: - sampled_token_rank = (choice_tensor[1:] == sampled_token_id).nonzero( - as_tuple=True)[0].item() + sampled_token_rank = ( + (choice_tensor[1:] == sampled_token_id).nonzero(as_tuple=True)[0].item() + ) else: # If not found, assign a random int between num_logprobs and 50700 sampled_token_rank = random.randint(num_logprobs, 50700) @@ -164,9 +166,12 @@ def _create_random_top_token_test_matrix( num_elements = shape[0] * shape[1] choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower matrix = torch.cat( - (torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1), - choice_tensor.view(shape)), - dim=1) + ( + torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1), + choice_tensor.view(shape), + ), + dim=1, + ) # Initialize the tensor for storing the ranks prompt_token_ranks = torch.empty(shape[0], dtype=torch.int) @@ -174,8 +179,7 @@ def _create_random_top_token_test_matrix( # Iterate over each row to check presence of # tokens_list[rdx] and determine its index for rdx in range(shape[0]): - row = matrix[rdx, - 1:] # Skip the first column as it contains the token list + row = matrix[rdx, 1:] # Skip the first column as it contains the token list token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0] if token_index.numel() > 0: prompt_token_ranks[rdx] = token_index.item() @@ -229,19 +233,21 @@ def generate_dummy_sample_logprobs( ( token_vector, sampled_token_rank, - ) = _create_random_top_token_test_vector(num_logprobs, 0, - len(tokenizer.vocab) - 1, - sampled_token_id) + ) = _create_random_top_token_test_vector( + num_logprobs, 0, len(tokenizer.vocab) - 1, sampled_token_id + ) res.append( - (token_vector, - _create_random_top_logprob_test_vector(num_logprobs + 1, -100, - 0), sampled_token_rank)) + ( + token_vector, + _create_random_top_logprob_test_vector(num_logprobs + 1, -100, 0), + sampled_token_rank, + ) + ) # Convert tensors in the list tuples to Python lists 
res_list_format = [ - (log_probs_tensor.tolist(), token_ids_tensor.tolist(), - sampled_token_rank) + (log_probs_tensor.tolist(), token_ids_tensor.tolist(), sampled_token_rank) for log_probs_tensor, token_ids_tensor, sampled_token_rank in res ] @@ -282,18 +288,24 @@ def generate_dummy_prompt_logprobs_tensors( token_vector, prompt_token_ranks, ) = _create_random_top_token_test_matrix( - (num_prompt_logprobs, num_logprobs), 0, - len(tokenizer.vocab) - 1, prompt_tokens_list[1:]) + (num_prompt_logprobs, num_logprobs), + 0, + len(tokenizer.vocab) - 1, + prompt_tokens_list[1:], + ) return LogprobsTensors( token_vector, _create_random_top_logprob_test_matrix( - (num_prompt_logprobs, num_logprobs + 1), -100, 0), - prompt_token_ranks) + (num_prompt_logprobs, num_logprobs + 1), -100, 0 + ), + prompt_token_ranks, + ) @dataclass class DummyOutputProcessorTestVectors: """Dummy test vectors for output processor tests""" + tokenizer: GeneralTokenizerType vllm_config: EngineArgs full_tokens: list[list[int]] # Prompt + generated tokens @@ -320,9 +332,9 @@ def __init__( # For each request, for each sampled token offset, # a tuple of # (list of topk token ids, list of sample logprob vals, rank) - generated_logprobs_raw: Optional[list[list[tuple[list[int], - list[float], - int]]]] = None, + generated_logprobs_raw: Optional[ + list[list[tuple[list[int], list[float], int]]] + ] = None, # For each request, a tuple of # (prompt logprob val matrix, prompt logprob tok id matrix); # each matrix has dimensions @@ -355,7 +367,8 @@ def get_outputs(self) -> list[EngineCoreOutput]: if do_logprobs: assert self.generated_logprobs_raw is not None (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = ( - self.generated_logprobs_raw[req_idx][token_idx]) + self.generated_logprobs_raw[req_idx][token_idx] + ) logprobs = LogprobsLists( [logprobs_token_ids_], [logprobs_], diff --git a/tests/v1/entrypoints/conftest.py b/tests/v1/entrypoints/conftest.py index 46b953fe3743..40b9d1fe850c 100644 --- a/tests/v1/entrypoints/conftest.py +++ b/tests/v1/entrypoints/conftest.py @@ -26,8 +26,10 @@ def sample_token_ids(): @pytest.fixture def sample_regex(): - return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" - r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + return ( + r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" + ) # Note: Ensure this only uses attributes compatible with xgrammar @@ -36,53 +38,44 @@ def sample_json_schema(): return { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, "skills": { "type": "array", "items": { "type": "string", - } + }, }, "grade": { "type": "string", - "pattern": "^[A-D]$" # Regex pattern + "pattern": "^[A-D]$", # Regex pattern }, "email": { "type": "string", - "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$", }, "work_history": { "type": "array", "items": { "type": "object", "properties": { - "company": { - "type": "string" - }, + "company": {"type": "string"}, "duration": { "type": "number", "minimum": 0.0, "maximum": 100.0, # Numeric range }, - "position": { - "type": "string" - } + "position": {"type": "string"}, }, "required": ["company", "duration", "position"], - "additionalProperties": False + "additionalProperties": False, }, "minItems": 0, - "maxItems": 3 - } + "maxItems": 3, + }, }, - "required": - ["name", "age", "skills", "grade", "email", "work_history"], - "additionalProperties": 
False + "required": ["name", "age", "skills", "grade", "email", "work_history"], + "additionalProperties": False, } @@ -94,67 +87,60 @@ def unsupported_json_schema(): "properties": { "score": { "type": "integer", - "multipleOf": 5 # Numeric multiple + "multipleOf": 5, # Numeric multiple }, "tags": { "type": "array", - "items": { - "type": "string", - "minLength": 10, - "maxLength": 20 - } - } + "items": {"type": "string", "minLength": 10, "maxLength": 20}, + }, }, "required": ["score", "tags"], - "additionalProperties": False + "additionalProperties": False, } @pytest.fixture def sample_definition_json_schema(): return { - '$defs': { - 'Step': { - 'properties': { - 'explanation': { - 'title': 'Explanation', - 'type': 'string' - }, - 'output': { - 'title': 'Output', - 'type': 'string' - } + "$defs": { + "Step": { + "properties": { + "explanation": {"title": "Explanation", "type": "string"}, + "output": {"title": "Output", "type": "string"}, }, - 'required': ['explanation', 'output'], - 'title': 'Step', - 'type': 'object' + "required": ["explanation", "output"], + "title": "Step", + "type": "object", } }, - 'properties': { - 'steps': { - 'items': { - '$ref': '#/$defs/Step' - }, - 'title': 'Steps', - 'type': 'array' + "properties": { + "steps": { + "items": {"$ref": "#/$defs/Step"}, + "title": "Steps", + "type": "array", }, - 'final_answer': { - 'title': 'Final Answer', - 'type': 'string' - } + "final_answer": {"title": "Final Answer", "type": "string"}, }, - 'required': ['steps', 'final_answer'], - 'title': 'MathReasoning', - 'type': 'object', - "additionalProperties": False + "required": ["steps", "final_answer"], + "title": "MathReasoning", + "type": "object", + "additionalProperties": False, } @pytest.fixture def sample_structured_outputs_choices(): return [ - "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", - "Ruby", "Swift", "Kotlin" + "Python", + "Java", + "JavaScript", + "C++", + "C#", + "PHP", + "TypeScript", + "Ruby", + "Swift", + "Kotlin", ] @@ -172,11 +158,11 @@ def sample_sql_ebnf(): @pytest.fixture def sample_sql_lark(): - return (""" + return """ start: select_statement select_statement: "SELECT" column "from" table "where" condition column: "col_1" | "col_2" table: "table_1" | "table_2" condition: column "=" number number: "1" | "2" -""") +""" diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 83493e25b7a6..d4c33f6cbbe2 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -22,8 +22,11 @@ from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager -from vllm.sampling_params import (GuidedDecodingParams, SamplingParams, - StructuredOutputsParams) +from vllm.sampling_params import ( + GuidedDecodingParams, + SamplingParams, + StructuredOutputsParams, +) if TYPE_CHECKING: from vllm.config import TokenizerMode @@ -44,22 +47,18 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), - ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", - None), + ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), 
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None), - #FIXME: This tests are flaky on CI thus disabled. Tracking in Issue #24402 + # FIXME: This tests are flaky on CI thus disabled. Tracking in Issue #24402 # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), - #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), - ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", - NGRAM_SPEC_CONFIG), - ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", - NGRAM_SPEC_CONFIG), + # ("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), + ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG), + ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), - ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", - EAGLE_SPEC_CONFIG) + ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", EAGLE_SPEC_CONFIG), ] PARAMS_MODELS_TOKENIZER_MODE = [ @@ -82,19 +81,16 @@ class CarDescription(BaseModel): def test_guided_decoding_deprecated(): - with pytest.warns(DeprecationWarning, - match="GuidedDecodingParams is deprecated.*"): + with pytest.warns(DeprecationWarning, match="GuidedDecodingParams is deprecated.*"): guided_decoding = GuidedDecodingParams(json_object=True) structured_outputs = StructuredOutputsParams(json_object=True) assert fields(guided_decoding) == fields(structured_outputs) - with pytest.warns(DeprecationWarning, - match="guided_decoding is deprecated.*"): + with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"): sp1 = SamplingParams(guided_decoding=guided_decoding) - with pytest.warns(DeprecationWarning, - match="guided_decoding is deprecated.*"): + with pytest.warns(DeprecationWarning, match="guided_decoding is deprecated.*"): sp2 = SamplingParams.from_optional(guided_decoding=guided_decoding) assert sp1 == sp2 @@ -104,7 +100,8 @@ def test_guided_decoding_deprecated(): @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( "model_name, backend, tokenizer_mode, speculative_config", - PARAMS_MODELS_BACKENDS_TOKENIZER_MODE) + PARAMS_MODELS_BACKENDS_TOKENIZER_MODE, +) def test_structured_output( monkeypatch: pytest.MonkeyPatch, sample_json_schema: dict[str, Any], @@ -125,15 +122,17 @@ def test_structured_output( # Use a single LLM instance for several scenarios to # speed up the test suite. - llm = LLM(model=model_name, - enforce_eager=True, - max_model_len=1024, - structured_outputs_config=dict(backend=backend, - disable_any_whitespace=backend - in {"xgrammar", "guidance"}), - seed=120, - tokenizer_mode=tokenizer_mode, - speculative_config=speculative_config) + llm = LLM( + model=model_name, + enforce_eager=True, + max_model_len=1024, + structured_outputs_config=dict( + backend=backend, disable_any_whitespace=backend in {"xgrammar", "guidance"} + ), + seed=120, + tokenizer_mode=tokenizer_mode, + speculative_config=speculative_config, + ) # # Test 1: Generate JSON output based on a provided schema @@ -141,11 +140,14 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, - structured_outputs=StructuredOutputsParams(json=sample_json_schema)) + structured_outputs=StructuredOutputsParams(json=sample_json_schema), + ) - prompt = ("Give an example JSON for an employee profile that fits this " - "schema. Make the response as short as possible. 
Schema: " - f"{sample_json_schema}") + prompt = ( + "Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}" + ) outputs = llm.generate( [prompt] * 2, sampling_params=sampling_params, @@ -161,7 +163,7 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None - if backend != 'lm-format-enforcer': + if backend != "lm-format-enforcer": assert "\n" not in generated_text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") try: @@ -169,7 +171,8 @@ def test_structured_output( except json.JSONDecodeError as e: pytest.fail( f"Invalid JSON from backend={backend}: {generated_text!r}\n" - f"Schema: {sample_json_schema}\nError: {e}") + f"Schema: {sample_json_schema}\nError: {e}" + ) jsonschema.validate(instance=output_json, schema=sample_json_schema) # @@ -180,14 +183,18 @@ def test_structured_output( temperature=1.0, max_tokens=4096, n=2, - structured_outputs=StructuredOutputsParams(json_object=True)) + structured_outputs=StructuredOutputsParams(json_object=True), + ) - outputs = llm.generate(prompts=( - "Generate a JSON object with curly braces for a person with " - "name and age fields for John Smith who is 31 years old. " - "Make the response as short as possible."), - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate( + prompts=( + "Generate a JSON object with curly braces for a person with " + "name and age fields for John Smith who is 31 years old. " + "Make the response as short as possible." + ), + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None for output in outputs: @@ -209,25 +216,30 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, - structured_outputs=StructuredOutputsParams( - json=unsupported_json_schema)) + structured_outputs=StructuredOutputsParams(json=unsupported_json_schema), + ) if backend.startswith("xgrammar"): - with pytest.raises(ValueError, - match="The provided JSON schema contains features " - "not supported by xgrammar."): - - prompt = (f"Give an example JSON for an employee profile that " - f"fits this schema: {unsupported_json_schema}. " - f"Make the response as short as possible.") + with pytest.raises( + ValueError, + match="The provided JSON schema contains features " + "not supported by xgrammar.", + ): + prompt = ( + f"Give an example JSON for an employee profile that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible." + ) llm.generate( [prompt] * 2, sampling_params=sampling_params, use_tqdm=True, ) else: - prompt = (f"Give an example JSON object for a grade that " - f"fits this schema: {unsupported_json_schema}. " - f"Make the response as short as possible.") + prompt = ( + f"Give an example JSON object for a grade that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible." + ) outputs = llm.generate( prompt, sampling_params=sampling_params, @@ -253,12 +265,14 @@ def test_structured_output( temperature=0.8, top_p=0.95, max_tokens=1000, - structured_outputs=StructuredOutputsParams( - grammar=sample_sql_ebnf)) + structured_outputs=StructuredOutputsParams(grammar=sample_sql_ebnf), + ) outputs = llm.generate( - ("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. 
Make the response as short as " - "possible."), + ( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible." + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -273,8 +287,7 @@ def test_structured_output( assert generated_text is not None # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( - " ", "") + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") assert generated_text.strip() == ground_truth @@ -287,12 +300,14 @@ def test_structured_output( temperature=0.8, top_p=0.95, max_tokens=1000, - structured_outputs=StructuredOutputsParams( - grammar=sample_sql_lark)) + structured_outputs=StructuredOutputsParams(grammar=sample_sql_lark), + ) outputs = llm.generate( - ("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible." + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -308,12 +323,12 @@ def test_structured_output( # use Lark to parse the output, and make sure it's a valid parse tree from lark import Lark + parser = Lark(sample_sql_lark) parser.parse(generated_text) # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( - " ", "") + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "") assert generated_text.strip() == ground_truth @@ -326,13 +341,15 @@ def test_structured_output( temperature=0.8, top_p=0.95, max_tokens=1000, - structured_outputs=StructuredOutputsParams( - grammar="not a grammar")) + structured_outputs=StructuredOutputsParams(grammar="not a grammar"), + ) with pytest.raises(ValueError, match="Failed to convert the grammar "): llm.generate( - ("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short " - "as possible."), + ( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short " + "as possible." + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -343,10 +360,13 @@ def test_structured_output( sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - structured_outputs=StructuredOutputsParams(regex=sample_regex)) + structured_outputs=StructuredOutputsParams(regex=sample_regex), + ) - prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. " - f"Make the response as short as possible.") + prompt = ( + f"Give an example IPv4 address with this regex: {sample_regex}. " + f"Make the response as short as possible." + ) outputs = llm.generate( [prompt] * 2, sampling_params=sampling_params, @@ -371,11 +391,15 @@ def test_structured_output( temperature=0.8, top_p=0.95, structured_outputs=StructuredOutputsParams( - choice=sample_structured_outputs_choices)) + choice=sample_structured_outputs_choices + ), + ) outputs = llm.generate( - ("The best language for type-safe systems programming is " - "(Make the response as short as possible.) "), + ( + "The best language for type-safe systems programming is " + "(Make the response as short as possible.) 
" + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -397,12 +421,15 @@ def test_structured_output( sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - structured_outputs=StructuredOutputsParams(json=json_schema)) + structured_outputs=StructuredOutputsParams(json=json_schema), + ) outputs = llm.generate( - ("Generate a JSON with the brand, model and car_type of the most " - "iconic car from the 90's. Make the response as short as " - "possible."), + ( + "Generate a JSON with the brand, model and car_type of the most " + "iconic car from the 90's. Make the response as short as " + "possible." + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -422,7 +449,8 @@ def test_structured_output( except json.JSONDecodeError as e: pytest.fail( f"Invalid JSON from backend={backend}: {generated_text!r}\n" - f"Schema: {json_schema}\nError: {e}") + f"Schema: {json_schema}\nError: {e}" + ) jsonschema.validate(instance=output_json, schema=json_schema) # @@ -436,21 +464,24 @@ def test_structured_output( "description": { "type": "string", "maxLength": max_length, - "minLength": min_length + "minLength": min_length, } }, "required": ["description"], - "additionalProperties": False + "additionalProperties": False, } sampling_params = SamplingParams( temperature=1.0, max_tokens=4096, - structured_outputs=StructuredOutputsParams(json=json_schema)) + structured_outputs=StructuredOutputsParams(json=json_schema), + ) outputs = llm.generate( - ("Generate a description of a frog using 50 characters. " - "Make the response as short as possible."), + ( + "Generate a description of a frog using 50 characters. " + "Make the response as short as possible." + ), sampling_params=sampling_params, use_tqdm=True, ) @@ -470,7 +501,8 @@ def test_structured_output( except json.JSONDecodeError as e: pytest.fail( f"Invalid JSON from backend={backend}: {generated_text!r}\n" - f"Schema: {json_schema}\nError: {e}") + f"Schema: {json_schema}\nError: {e}" + ) jsonschema.validate(instance=output_json, schema=json_schema) if backend not in ["outlines", "lm-format-enforcer"]: @@ -478,29 +510,28 @@ def test_structured_output( # Test 11: Generate structured output using structural_tag format # structural_tag_config = { - "type": - "structural_tag", - "structures": [{ - "begin": "<function=get_weather>", - "schema": { - "type": "object", - "properties": { - "city": { - "type": "string" - } + "type": "structural_tag", + "structures": [ + { + "begin": "<function=get_weather>", + "schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "additionalProperties": False, }, - "additionalProperties": False - }, - "end": "</function>" - }], - "triggers": ["<function="] + "end": "</function>", + } + ], + "triggers": ["<function="], } sampling_params = SamplingParams( temperature=0.0, max_tokens=4096, structured_outputs=StructuredOutputsParams( - structural_tag=json.dumps(structural_tag_config))) + structural_tag=json.dumps(structural_tag_config) + ), + ) prompt = """ You have access to the following function to retrieve the weather in a city: @@ -542,9 +573,7 @@ def test_structured_output( """ # Change this once other backends support structural_tag - outputs = llm.generate(prompt, - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True) assert outputs is not None for output in outputs: @@ -554,12 +583,13 @@ def test_structured_output( assert generated_text is not None # Search for function call pattern in the response - 
function_call_pattern = r'<function=get_weather>(.*?)</function>' + function_call_pattern = r"<function=get_weather>(.*?)</function>" matches = re.findall(function_call_pattern, generated_text) if not matches: - print(f"Warning: No function calls found in response: " - f"{generated_text!r}") + print( + f"Warning: No function calls found in response: {generated_text!r}" + ) continue # Take the first function call if multiple are found @@ -570,16 +600,22 @@ def test_structured_output( assert isinstance(json_content["city"], str) print(f"Found valid function call: {generated_text!r}") except (json.JSONDecodeError, AssertionError) as e: - pytest.fail("Invalid function call format: " - f"{generated_text!r}\nError: {str(e)}") + pytest.fail( + f"Invalid function call format: {generated_text!r}\nError: {str(e)}" + ) @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( "model_name, backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 [ - ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", - "deepseek_r1", NGRAM_SPEC_CONFIG), + ( + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "xgrammar", + "auto", + "deepseek_r1", + NGRAM_SPEC_CONFIG, + ), ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None), ], ) @@ -605,27 +641,25 @@ def test_structured_output_with_reasoning_matrices( enforce_eager=bool(not current_platform.is_tpu()), max_model_len=1024, max_num_seqs=16, - structured_outputs_config=dict(backend=backend, - disable_any_whitespace=backend - in {"xgrammar", "guidance"}, - reasoning_parser=reasoning_parser), + structured_outputs_config=dict( + backend=backend, + disable_any_whitespace=backend in {"xgrammar", "guidance"}, + reasoning_parser=reasoning_parser, + ), tokenizer_mode=tokenizer_mode, speculative_config=speculative_config, ) tokenizer = llm.get_tokenizer() reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)( - tokenizer=tokenizer) + tokenizer=tokenizer + ) reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?" 
# noqa: E501 reasoning_schema = { "type": "object", - "properties": { - "result": { - "type": "integer" - } - }, + "properties": {"result": {"type": "integer"}}, "required": ["result"], - "additionalProperties": False + "additionalProperties": False, } if "Qwen3" in model_name: reasoning_prompt += "<think>\n" @@ -646,11 +680,8 @@ def test_structured_output_with_reasoning_matrices( assert output is not None and isinstance(output, RequestOutput) prompt = output.prompt generated_text = output.outputs[0].text - reasoning_content, content = run_reasoning_extraction( - reasoner, [generated_text]) - print( - f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}" - ) + reasoning_content, content = run_reasoning_extraction(reasoner, [generated_text]) + print(f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}") assert content is not None and reasoning_content is not None output_json = json.loads(content) @@ -658,8 +689,7 @@ def test_structured_output_with_reasoning_matrices( @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("model_name, tokenizer_mode", - PARAMS_MODELS_TOKENIZER_MODE) +@pytest.mark.parametrize("model_name, tokenizer_mode", PARAMS_MODELS_TOKENIZER_MODE) def test_structured_output_auto_mode( monkeypatch: pytest.MonkeyPatch, unsupported_json_schema: dict[str, Any], @@ -668,30 +698,32 @@ def test_structured_output_auto_mode( ): monkeypatch.setenv("VLLM_USE_V1", "1") - llm = LLM(model=model_name, - max_model_len=1024, - structured_outputs_config=dict(backend="auto"), - tokenizer_mode=tokenizer_mode) + llm = LLM( + model=model_name, + max_model_len=1024, + structured_outputs_config=dict(backend="auto"), + tokenizer_mode=tokenizer_mode, + ) sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - structured_outputs=StructuredOutputsParams( - json=unsupported_json_schema)) + structured_outputs=StructuredOutputsParams(json=unsupported_json_schema), + ) prompts = ( "Give an example JSON object for a grade " "that fits this schema: " - f"{unsupported_json_schema}. Make the response as short as possible.") + f"{unsupported_json_schema}. Make the response as short as possible." + ) # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically. - outputs = llm.generate(prompts, - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True) # Make sure `auto` backend handling doesn't mess up sampling_params # and that we can reuse it without error. 
outputs.extend( - llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True)) + llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True) + ) assert outputs is not None for output in outputs: @@ -710,27 +742,24 @@ def test_structured_output_auto_mode( def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_USE_V1", "1") - llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", - max_model_len=1024, - structured_outputs_config=dict( - backend="guidance", - disable_any_whitespace=True, - disable_additional_properties=True)) + llm = LLM( + model="Qwen/Qwen2.5-1.5B-Instruct", + max_model_len=1024, + structured_outputs_config=dict( + backend="guidance", + disable_any_whitespace=True, + disable_additional_properties=True, + ), + ) schema = { - 'type': 'object', - 'properties': { - 'a1': { - 'type': 'string' - }, - 'a2': { - 'type': 'string' - }, - 'a3': { - 'type': 'string' - } + "type": "object", + "properties": { + "a1": {"type": "string"}, + "a2": {"type": "string"}, + "a3": {"type": "string"}, }, - 'required': ['a1', 'a2', 'a3'], + "required": ["a1", "a2", "a3"], } prompt = ( @@ -738,18 +767,19 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. " "Make the response as short as possible." - "<|im_end|>\n<|im_start|>assistant\n") + "<|im_end|>\n<|im_start|>assistant\n" + ) def generate_with_backend(backend): structured_outputs_params = StructuredOutputsParams( json=schema, backend=backend, disable_any_whitespace=True, - disable_additional_properties=True) + disable_additional_properties=True, + ) sampling_params = SamplingParams( - temperature=0, - max_tokens=256, - structured_outputs=structured_outputs_params) + temperature=0, max_tokens=256, structured_outputs=structured_outputs_params + ) outputs = llm.generate(prompt, sampling_params=sampling_params) assert outputs is not None @@ -794,16 +824,18 @@ def test_structured_output_batched_with_non_structured_outputs_requests( structured_outputs_prompt = ( "Give an example JSON for an employee profile that fits this " "schema. Make the response as short as possible. 
Schema: " - f"{sample_json_schema}") + f"{sample_json_schema}" + ) non_structured_outputs_prompt = "The diameter of the Earth in kilometers is " prompts = [structured_outputs_prompt, non_structured_outputs_prompt] sampling_params = [ - SamplingParams(temperature=1.0, - max_tokens=400, - structured_outputs=StructuredOutputsParams( - json=sample_json_schema)), + SamplingParams( + temperature=1.0, + max_tokens=400, + structured_outputs=StructuredOutputsParams(json=sample_json_schema), + ), # No max tokens, temp=0 to assert on contents SamplingParams( seed=42, @@ -812,9 +844,9 @@ def test_structured_output_batched_with_non_structured_outputs_requests( ), ] - outputs = llm.generate(prompts=prompts, - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate( + prompts=prompts, sampling_params=sampling_params, use_tqdm=True + ) assert outputs is not None @@ -837,8 +869,7 @@ def test_structured_output_batched_with_non_structured_outputs_requests( # First prompt is structured outputs, expect valid JSON assert "\n" not in generated_text output_json = json.loads(generated_text) - jsonschema.validate(instance=output_json, - schema=sample_json_schema) + jsonschema.validate(instance=output_json, schema=sample_json_schema) else: # Second prompt is not structured outputs, expect valid output # Cannot assert on exact output, but we can expect it to be factual diff --git a/tests/v1/entrypoints/openai/responses/conftest.py b/tests/v1/entrypoints/openai/responses/conftest.py index 2d677a00b646..ad7594a3dd6d 100644 --- a/tests/v1/entrypoints/openai/responses/conftest.py +++ b/tests/v1/entrypoints/openai/responses/conftest.py @@ -23,9 +23,9 @@ def default_server_args(): @pytest.fixture(scope="module") def server_with_store(default_server_args): with RemoteOpenAIServer( - MODEL_NAME, - default_server_args, - env_dict={"VLLM_ENABLE_RESPONSES_API_STORE": "1"}, + MODEL_NAME, + default_server_args, + env_dict={"VLLM_ENABLE_RESPONSES_API_STORE": "1"}, ) as remote_server: yield remote_server diff --git a/tests/v1/entrypoints/openai/responses/test_basic.py b/tests/v1/entrypoints/openai/responses/test_basic.py index 2ee1004493a1..dd3a563e9570 100644 --- a/tests/v1/entrypoints/openai/responses/test_basic.py +++ b/tests/v1/entrypoints/openai/responses/test_basic.py @@ -36,24 +36,14 @@ async def test_instructions(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_chat(client: openai.AsyncOpenAI): - response = await client.responses.create(input=[ - { - "role": "system", - "content": "Finish the answer with QED." - }, - { - "role": "user", - "content": "What is 5 * 3?" - }, - { - "role": "assistant", - "content": "15. QED." - }, - { - "role": "user", - "content": "Multiply the result by 2." - }, - ], ) + response = await client.responses.create( + input=[ + {"role": "system", "content": "Finish the answer with QED."}, + {"role": "user", "content": "What is 5 * 3?"}, + {"role": "assistant", "content": "15. QED."}, + {"role": "user", "content": "Multiply the result by 2."}, + ], + ) print(response) output_text = response.output[-1].content[0].text @@ -63,15 +53,14 @@ async def test_chat(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_chat_with_input_type(client: openai.AsyncOpenAI): - response = await client.responses.create(input=[ - { - "role": "user", - "content": [{ - "type": "input_text", - "text": "Hello!" 
- }], - }, - ], ) + response = await client.responses.create( + input=[ + { + "role": "user", + "content": [{"type": "input_text", "text": "Hello!"}], + }, + ], + ) print(response) assert response.status == "completed" @@ -99,6 +88,6 @@ async def test_streaming(client: openai.AsyncOpenAI): assert isinstance(events[0], openai_responses_types.ResponseCreatedEvent) assert any( isinstance(event, openai_responses_types.ResponseTextDeltaEvent) - for event in events) - assert isinstance(events[-1], - openai_responses_types.ResponseCompletedEvent) + for event in events + ) + assert isinstance(events[-1], openai_responses_types.ResponseCompletedEvent) diff --git a/tests/v1/entrypoints/openai/responses/test_image.py b/tests/v1/entrypoints/openai/responses/test_image.py index 3ed36ca678c0..980d83b787e7 100644 --- a/tests/v1/entrypoints/openai/responses/test_image.py +++ b/tests/v1/entrypoints/openai/responses/test_image.py @@ -38,9 +38,9 @@ def default_image_server_args(): @pytest.fixture(scope="module") def image_server(default_image_server_args): with RemoteOpenAIServer( - MODEL_NAME, - default_image_server_args, - env_dict={"VLLM_ENABLE_RESPONSES_API_STORE": "1"}, + MODEL_NAME, + default_image_server_args, + env_dict={"VLLM_ENABLE_RESPONSES_API_STORE": "1"}, ) as remote_server: yield remote_server @@ -54,8 +54,7 @@ async def client(image_server): @pytest.fixture(scope="session") def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_url: - encode_image_base64(local_asset_server.get_image_asset(image_url)) + image_url: encode_image_base64(local_asset_server.get_image_asset(image_url)) for image_url in TEST_IMAGE_ASSETS } @@ -63,24 +62,23 @@ def base64_encoded_image(local_asset_server) -> dict[str, str]: @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True) -async def test_single_chat_session_image(client: openai.AsyncOpenAI, - model_name: str, image_url: str): +async def test_single_chat_session_image( + client: openai.AsyncOpenAI, model_name: str, image_url: str +): content_text = "What's in this image?" - messages = [{ - "role": - "user", - "content": [ - { - "type": "input_image", - "image_url": image_url, - "detail": "auto", - }, - { - "type": "input_text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": image_url, + "detail": "auto", + }, + {"type": "input_text", "text": content_text}, + ], + } + ] # test image url response = await client.responses.create( @@ -100,22 +98,19 @@ async def test_single_chat_session_image_base64encoded( base64_encoded_image: dict[str, str], ): content_text = "What's in this image?" 
- messages = [{ - "role": - "user", - "content": [ - { - "type": "input_image", - "image_url": - f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}", - "detail": "auto", - }, - { - "type": "input_text", - "text": content_text - }, - ], - }] + messages = [ + { + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}", # noqa: E501 + "detail": "auto", + }, + {"type": "input_text", "text": content_text}, + ], + } + ] # test image base64 response = await client.responses.create( model=model_name, @@ -129,24 +124,27 @@ async def test_single_chat_session_image_base64encoded( @pytest.mark.parametrize( "image_urls", [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))], - indirect=True) -async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, - image_urls: list[str]): - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "input_image", - "image_url": image_url, - "detail": "auto", - } for image_url in image_urls), - { - "type": "input_text", - "text": "What's in this image?" - }, - ], - }] + indirect=True, +) +async def test_multi_image_input( + client: openai.AsyncOpenAI, model_name: str, image_urls: list[str] +): + messages = [ + { + "role": "user", + "content": [ + *( + { + "type": "input_image", + "image_url": image_url, + "detail": "auto", + } + for image_url in image_urls + ), + {"type": "input_text", "text": "What's in this image?"}, + ], + } + ] if len(image_urls) > MAXIMUM_IMAGES: with pytest.raises(openai.BadRequestError): # test multi-image input @@ -157,10 +155,12 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, # the server should still work afterwards response = await client.responses.create( model=model_name, - input=[{ - "role": "user", - "content": "What's the weather like in Paris today?", - }], + input=[ + { + "role": "user", + "content": "What's the weather like in Paris today?", + } + ], ) assert len(response.output_text) > 0 else: diff --git a/tests/v1/entrypoints/openai/responses/test_stateful.py b/tests/v1/entrypoints/openai/responses/test_stateful.py index a2d581ef7ced..6f7edb6bd7e7 100644 --- a/tests/v1/entrypoints/openai/responses/test_stateful.py +++ b/tests/v1/entrypoints/openai/responses/test_stateful.py @@ -24,8 +24,7 @@ async def test_store(client: openai.AsyncOpenAI): assert response.status == "completed" # The response should not be found. - with pytest.raises(openai.NotFoundError, - match="Response with id .* not found."): + with pytest.raises(openai.NotFoundError, match="Response with id .* not found."): await client.responses.retrieve(response.id) @@ -53,8 +52,8 @@ async def test_background(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_background_error(client: openai.AsyncOpenAI): with pytest.raises( - openai.BadRequestError, - match="background can only be used when `store` is true"): + openai.BadRequestError, match="background can only be used when `store` is true" + ): _ = await client.responses.create( input="What is 13 * 24?", background=True, @@ -87,8 +86,9 @@ async def test_cancel_completed(client: openai.AsyncOpenAI): response = await client.responses.create(input="Hello") assert response.status == "completed" - with pytest.raises(openai.BadRequestError, - match="Cannot cancel a synchronous response."): + with pytest.raises( + openai.BadRequestError, match="Cannot cancel a synchronous response." 
+ ): await client.responses.cancel(response.id) @@ -97,7 +97,8 @@ async def test_previous_response_id(client: openai.AsyncOpenAI): response1 = await client.responses.create( instructions="You are tested on your ability to retrieve the correct " "information from the previous response.", - input="Hello, my name is John.") + input="Hello, my name is John.", + ) response2 = await client.responses.create( input="Actually, my name is not John. My real name is Mark.", @@ -118,7 +119,8 @@ async def test_two_responses_with_same_prev_id(client: openai.AsyncOpenAI): response1 = await client.responses.create( instructions="You are tested on your ability to retrieve the correct " "information from the previous response.", - input="Hello, my name is John.") + input="Hello, my name is John.", + ) # Both response 2 and 3 use response 1 as the previous response. response2 = client.responses.create( diff --git a/tests/v1/entrypoints/openai/responses/test_structured_output.py b/tests/v1/entrypoints/openai/responses/test_structured_output.py index c4c43a87b601..db8b87768e44 100644 --- a/tests/v1/entrypoints/openai/responses/test_structured_output.py +++ b/tests/v1/entrypoints/openai/responses/test_structured_output.py @@ -11,14 +11,10 @@ async def test_structured_output(client: openai.AsyncOpenAI): response = await client.responses.create( input=[ - { - "role": "system", - "content": "Extract the event information." - }, + {"role": "system", "content": "Extract the event information."}, { "role": "user", - "content": - "Alice and Bob are going to a science fair on Friday.", + "content": "Alice and Bob are going to a science fair on Friday.", }, ], text={ @@ -28,18 +24,9 @@ async def test_structured_output(client: openai.AsyncOpenAI): "schema": { "type": "object", "properties": { - "event_name": { - "type": "string" - }, - "date": { - "type": "string" - }, - "participants": { - "type": "array", - "items": { - "type": "string" - } - }, + "event_name": {"type": "string"}, + "date": {"type": "string"}, + "participants": {"type": "array", "items": {"type": "string"}}, }, "required": ["event_name", "date", "participants"], "additionalProperties": False, @@ -65,7 +52,6 @@ async def test_structured_output(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_structured_output_with_parse(client: openai.AsyncOpenAI): - class CalendarEvent(BaseModel): event_name: str date: str diff --git a/tests/v1/entrypoints/openai/test_chat_completion.py b/tests/v1/entrypoints/openai/test_chat_completion.py index 9aa285aa9b18..522c72b55955 100644 --- a/tests/v1/entrypoints/openai/test_chat_completion.py +++ b/tests/v1/entrypoints/openai/test_chat_completion.py @@ -40,8 +40,7 @@ async def client(server): "model_name", [MODEL_NAME], ) -async def test_invalid_json_schema(client: openai.AsyncOpenAI, - model_name: str) -> None: +async def test_invalid_json_schema(client: openai.AsyncOpenAI, model_name: str) -> None: invalid_json_schema = { "$defs": { "CarType": { @@ -51,35 +50,29 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, } }, "properties": { - "brand": { - "title": "Brand", - "type": "string" - }, - "model": { - "title": "Model", - "type": "string" - }, - "car_type": { - "$ref": "#/$defs/CarType" - }, + "brand": {"title": "Brand", "type": "string"}, + "model": {"title": "Model", "type": "string"}, + "car_type": {"$ref": "#/$defs/CarType"}, "foo": "bar", }, "required": ["brand", "model", "car_type"], "title": "CarDescription", "type": "object", } - prompt = ("Generate a JSON with the brand, model and 
car_type of" - "the most iconic car from the 90's") + prompt = ( + "Generate a JSON with the brand, model and car_type of" + "the most iconic car from the 90's" + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.chat.completions.create( model=model_name, - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={"structured_outputs": { - "json": invalid_json_schema - }}, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + extra_body={"structured_outputs": {"json": invalid_json_schema}}, ) @@ -89,23 +82,22 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, [MODEL_NAME], ) async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): - prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. Example result:" - "alan.turing@enigma.com\n") + prompt = ( + "Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n" + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.chat.completions.create( model=model_name, - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={ - "structured_outputs": { - "regex": r"[.*" - }, - "stop": ["\n"] - }, + messages=[ + { + "role": "user", + "content": prompt, + } + ], + extra_body={"structured_outputs": {"regex": r"[.*"}, "stop": ["\n"]}, ) @@ -129,18 +121,20 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): number ::= "1 " | "2 " """ - prompt = ("Generate an SQL query to show the 'username' and 'email'" - "from the 'users' table.") + prompt = ( + "Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table." + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.chat.completions.create( model=model_name, - messages=[{ - "role": "user", - "content": prompt, - }], - extra_body={ - "structured_outputs": { - "grammar": invalid_simplified_sql_grammar + messages=[ + { + "role": "user", + "content": prompt, } + ], + extra_body={ + "structured_outputs": {"grammar": invalid_simplified_sql_grammar} }, ) diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 9090beb4bbd2..35287f5b979a 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -31,12 +31,13 @@ def default_server_args(): ] -@pytest.fixture(scope="module", - params=[["--no-enable-prefix-caching"], - [ - "--no-enable-prefix-caching", - "--disable-frontend-multiprocessing" - ]]) +@pytest.fixture( + scope="module", + params=[ + ["--no-enable-prefix-caching"], + ["--no-enable-prefix-caching", "--disable-frontend-multiprocessing"], + ], +) def server(default_server_args, request): if request.param: default_server_args = default_server_args + request.param @@ -55,12 +56,10 @@ async def client(server): "model_name", [MODEL_NAME], ) -async def test_single_completion(client: openai.AsyncOpenAI, - model_name: str) -> None: - completion = await client.completions.create(model=model_name, - prompt="Hello, my name is", - max_tokens=5, - temperature=0.0) +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str) -> None: + completion = await client.completions.create( + model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=0.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -69,7 
+68,8 @@ async def test_single_completion(client: openai.AsyncOpenAI, assert len(choice.text) >= 5 assert choice.finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, prompt_tokens=6, total_tokens=11) + completion_tokens=5, prompt_tokens=6, total_tokens=11 + ) # test using token IDs completion = await client.completions.create( @@ -147,11 +147,12 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): "model_name", [MODEL_NAME], ) -async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, - model_name: str) -> None: - +async def test_too_many_completion_logprobs( + client: openai.AsyncOpenAI, model_name: str +) -> None: with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs + (openai.BadRequestError, openai.APIError) + ): # test using token IDs await client.completions.create( model=model_name, prompt=[0, 0, 0, 0, 0], @@ -163,7 +164,8 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, ) ... with pytest.raises( - (openai.BadRequestError, openai.APIError)): # test using token IDs + (openai.BadRequestError, openai.APIError) + ): # test using token IDs stream = await client.completions.create( model=model_name, prompt=[0, 0, 0, 0, 0], @@ -188,13 +190,13 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), - (MODEL_NAME, 0), - (MODEL_NAME, 1), - (MODEL_NAME, None)]) -async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, - model_name: str, - prompt_logprobs: Optional[int]): +@pytest.mark.parametrize( + "model_name, prompt_logprobs", + [(MODEL_NAME, -1), (MODEL_NAME, 0), (MODEL_NAME, 1), (MODEL_NAME, None)], +) +async def test_prompt_logprobs_completion( + client: openai.AsyncOpenAI, model_name: str, prompt_logprobs: Optional[int] +): params: dict = { "prompt": ["A robot may not injure another robot", "My name is"], "model": model_name, @@ -223,8 +225,9 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_completion_streaming(client: openai.AsyncOpenAI, - model_name: str) -> None: +async def test_completion_streaming( + client: openai.AsyncOpenAI, model_name: str +) -> None: prompt = "What is an LLM?" single_completion = await client.completions.create( @@ -234,11 +237,9 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, temperature=0.0, ) single_output = single_completion.choices[0].text - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 async for chunk in stream: @@ -257,8 +258,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_parallel_no_streaming(client: openai.AsyncOpenAI, - model_name: str): +async def test_parallel_no_streaming(client: openai.AsyncOpenAI, model_name: str): """Parallel sampling without streaming. A single request output contains a list of completions. """ @@ -268,27 +268,26 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI, max_tokens = 50 # we want some to finish earlier than others # High temperature to maximize chance of unique completions. 
- completion = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=max_tokens, - n=n, - temperature=1.0, - stream=False, - logprobs=0, - seed=42) + completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + temperature=1.0, + stream=False, + logprobs=0, + seed=42, + ) # Assert `n` completions num_completions = len(completion.choices) - assert num_completions == n, ( - f"Num completions {num_completions} but expected {n}.") + assert num_completions == n, f"Num completions {num_completions} but expected {n}." completion_repeats: dict[str, int] = {} output_token_lengths = set() for idx, choice in enumerate(completion.choices): # Assert correct completion index & some finish reason. - assert choice.index == idx, ( - f"Index {choice.index} but expected {idx}.") - assert choice.finish_reason is not None, ( - "None finish_reason is invalid.") + assert choice.index == idx, f"Index {choice.index} but expected {idx}." + assert choice.finish_reason is not None, "None finish_reason is invalid." text = choice.text completion_repeats[text] = completion_repeats.get(text, 0) + 1 output_token_lengths.add(len(choice.logprobs.tokens)) @@ -297,13 +296,10 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI, # Assert `n` unique completions num_unique = len(completion_repeats) if num_unique != n: - repeats = { - txt: num - for (txt, num) in completion_repeats.items() if num > 1 - } + repeats = {txt: num for (txt, num) in completion_repeats.items() if num > 1} raise AssertionError( - f"Expected {n} unique completions, got {num_unique};" - f" repeats: {repeats}.") + f"Expected {n} unique completions, got {num_unique}; repeats: {repeats}." + ) @pytest.mark.asyncio @@ -321,13 +317,15 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): n = 3 max_tokens = 50 # we want some to finish earlier than others - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=max_tokens, - n=n, - temperature=1.0, - stream=True, - seed=42) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + temperature=1.0, + stream=True, + seed=42, + ) chunks: list[list[str]] = [[] for _ in range(n)] finish_reason_count = 0 async for chunk in stream: @@ -338,7 +336,8 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): finish_reason_count += 1 # Assert `n` completions with correct finish reasons assert finish_reason_count == n, ( - f"Expected {n} completions with valid indices and finish_reason.") + f"Expected {n} completions with valid indices and finish_reason." + ) completion_repeats: dict[str, int] = {} chunk_lengths = set() for chunk in chunks: @@ -346,7 +345,8 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): # Assert correct number of completion tokens chunk_lengths.add(chunk_len) assert chunk_len <= max_tokens, ( - f"max_tokens={max_tokens} but chunk len is {chunk_len}.") + f"max_tokens={max_tokens} but chunk len is {chunk_len}." 
+ ) text = "".join(chunk) completion_repeats[text] = completion_repeats.get(text, 0) + 1 print(text) @@ -355,12 +355,10 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): # Assert `n` unique completions num_unique = len(completion_repeats) if num_unique != n: - repeats = { - txt: num - for (txt, num) in completion_repeats.items() if num > 1 - } - raise AssertionError(f"{num_unique} unique completions, expected {n};" - f" repeats: {repeats}") + repeats = {txt: num for (txt, num) in completion_repeats.items() if num > 1} + raise AssertionError( + f"{num_unique} unique completions, expected {n}; repeats: {repeats}" + ) @pytest.mark.asyncio @@ -368,53 +366,55 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): "model_name", [MODEL_NAME], ) -async def test_completion_stream_options(client: openai.AsyncOpenAI, - model_name: str): +async def test_completion_stream_options(client: openai.AsyncOpenAI, model_name: str): prompt = "What is the capital of France?" # Test stream=True, stream_options= # {"include_usage": False, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - False, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": False, + }, + ) async for chunk in stream: assert chunk.usage is None # Test stream=True, stream_options= # {"include_usage": False, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": False, - "continuous_usage_stats": - True, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": False, + "continuous_usage_stats": True, + }, + ) async for chunk in stream: assert chunk.usage is None # Test stream=True, stream_options= # {"include_usage": True, "continuous_usage_stats": False} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - False, - }) + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": False, + }, + ) async for chunk in stream: if chunk.choices[0].finish_reason is None: assert chunk.usage is None @@ -425,57 +425,63 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) + final_chunk.usage.prompt_tokens + final_chunk.usage.completion_tokens + ) assert final_chunk.choices == [] # Test stream=True, stream_options= # {"include_usage": True, "continuous_usage_stats": True} - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={ - "include_usage": True, - "continuous_usage_stats": - True, - }) + stream = await 
client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": True, + }, + ) async for chunk in stream: assert chunk.usage is not None assert chunk.usage.prompt_tokens > 0 assert chunk.usage.completion_tokens > 0 - assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + - chunk.usage.completion_tokens) + assert chunk.usage.total_tokens == ( + chunk.usage.prompt_tokens + chunk.usage.completion_tokens + ) if chunk.choices[0].finish_reason is not None: final_chunk = await stream.__anext__() assert final_chunk.usage is not None assert final_chunk.usage.prompt_tokens > 0 assert final_chunk.usage.completion_tokens > 0 assert final_chunk.usage.total_tokens == ( - final_chunk.usage.prompt_tokens + - final_chunk.usage.completion_tokens) + final_chunk.usage.prompt_tokens + final_chunk.usage.completion_tokens + ) assert final_chunk.choices == [] # Test stream=False, stream_options= # {"include_usage": None} with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": None}) + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}, + ) # Test stream=False, stream_options= # {"include_usage": True} with pytest.raises(BadRequestError): - await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": True}) + await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}, + ) # Test stream=False, stream_options= # {"continuous_usage_stats": None} @@ -486,7 +492,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, stream=False, - stream_options={"continuous_usage_stats": None}) + stream_options={"continuous_usage_stats": None}, + ) # Test stream=False, stream_options= # {"continuous_usage_stats": True} @@ -497,7 +504,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, max_tokens=5, temperature=0.0, stream=False, - stream_options={"continuous_usage_stats": True}) + stream_options={"continuous_usage_stats": True}, + ) @pytest.mark.asyncio @@ -528,15 +536,19 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): extra_body=dict( # NOTE: this has to be true for n > 1 in vLLM, but # not necessary for official client. 
- use_beam_search=True), + use_beam_search=True + ), ) assert len(batch.choices) == 4 - assert batch.choices[0].text != batch.choices[ - 1].text, "beam search should be different" - assert batch.choices[0].text == batch.choices[ - 2].text, "two copies of the same prompt should be the same" - assert batch.choices[1].text == batch.choices[ - 3].text, "two copies of the same prompt should be the same" + assert batch.choices[0].text != batch.choices[1].text, ( + "beam search should be different" + ) + assert batch.choices[0].text == batch.choices[2].text, ( + "two copies of the same prompt should be the same" + ) + assert batch.choices[1].text == batch.choices[3].text, ( + "two copies of the same prompt should be the same" + ) # test streaming batch = await client.completions.create( @@ -560,31 +572,30 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): [MODEL_NAME], ) @pytest.mark.parametrize("logprobs_arg", [1, 0]) -async def test_echo_logprob_completion(client: openai.AsyncOpenAI, - model_name: str, logprobs_arg: int): +async def test_echo_logprob_completion( + client: openai.AsyncOpenAI, model_name: str, logprobs_arg: int +): tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) # test using text and token IDs for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): - completion = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - echo=True, - logprobs=logprobs_arg) - - prompt_text = tokenizer.decode(prompt) if isinstance(prompt, - list) else prompt + completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=logprobs_arg, + ) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, list) else prompt assert re.search(r"^" + prompt_text, completion.choices[0].text) logprobs = completion.choices[0].logprobs assert logprobs is not None assert len(logprobs.text_offset) > 5 - assert (len(logprobs.token_logprobs) > 5 - and logprobs.token_logprobs[0] is None) - assert (len(logprobs.top_logprobs) > 5 - and logprobs.top_logprobs[0] is None) + assert len(logprobs.token_logprobs) > 5 and logprobs.token_logprobs[0] is None + assert len(logprobs.top_logprobs) > 5 and logprobs.top_logprobs[0] is None for top_logprobs in logprobs.top_logprobs[1:]: - assert max(logprobs_arg, - 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 assert len(logprobs.tokens) > 5 @@ -593,8 +604,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_invalid_json_schema(client: openai.AsyncOpenAI, - model_name: str) -> None: +async def test_invalid_json_schema(client: openai.AsyncOpenAI, model_name: str) -> None: invalid_json_schema = { "$defs": { "CarType": { @@ -604,32 +614,24 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, } }, "properties": { - "brand": { - "title": "Brand", - "type": "string" - }, - "model": { - "title": "Model", - "type": "string" - }, - "car_type": { - "$ref": "#/$defs/CarType" - }, + "brand": {"title": "Brand", "type": "string"}, + "model": {"title": "Model", "type": "string"}, + "car_type": {"$ref": "#/$defs/CarType"}, "foo": "bar", }, "required": ["brand", "model", "car_type"], "title": "CarDescription", "type": "object", } - prompt = ("Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's") + prompt = ( + "Generate a JSON with the brand, model and 
car_type of" + "the most iconic car from the 90's" + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.completions.create( model=model_name, prompt=prompt, - extra_body={"structured_outputs": { - "json": invalid_json_schema - }}, + extra_body={"structured_outputs": {"json": invalid_json_schema}}, ) @@ -639,20 +641,17 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI, [MODEL_NAME], ) async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str): - prompt = ("Generate an email address for Alan Turing, who works in Enigma." - "End in .com and new line. Example result:" - "alan.turing@enigma.com\n") + prompt = ( + "Generate an email address for Alan Turing, who works in Enigma." + "End in .com and new line. Example result:" + "alan.turing@enigma.com\n" + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.completions.create( model=model_name, prompt=prompt, - extra_body={ - "structured_outputs": { - "regex": r"[.*" - }, - "stop": ["\n"] - }, + extra_body={"structured_outputs": {"regex": r"[.*"}, "stop": ["\n"]}, ) @@ -676,29 +675,29 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): number ::= "1 " | "2 " """ - prompt = ("Generate an SQL query to show the 'username' and 'email'" - "from the 'users' table.") + prompt = ( + "Generate an SQL query to show the 'username' and 'email'" + "from the 'users' table." + ) with pytest.raises((openai.BadRequestError, openai.APIError)): await client.completions.create( model=model_name, prompt=prompt, extra_body={ - "structured_outputs": { - "grammar": invalid_simplified_sql_grammar - } + "structured_outputs": {"grammar": invalid_simplified_sql_grammar} }, ) @pytest.mark.asyncio -async def test_completion_with_empty_prompt_embeds( - client: openai.AsyncOpenAI) -> None: +async def test_completion_with_empty_prompt_embeds(client: openai.AsyncOpenAI) -> None: """Test completion with empty prompt embeds.""" payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []} headers: dict[str, str] = {"Content-Type": "application/json"} # base_url = http://localhost:8000/v1/completions - response = requests.post(f"{client.base_url}completions", - headers=headers, - json=payload) + response = requests.post( + f"{client.base_url}completions", headers=headers, json=payload + ) assert response.status_code == 200, ( - f"Expected status code 200, got {response.status_code}. ") + f"Expected status code 200, got {response.status_code}. 
" + ) diff --git a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py index 41f1d02bf787..3c2b3de33958 100644 --- a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py +++ b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py @@ -37,9 +37,9 @@ def default_image_embeds_server_args() -> list[str]: @pytest.fixture(scope="module") def server_with_image_embeds(default_image_embeds_server_args): - with RemoteOpenAIServer(MODEL_NAME, - default_image_embeds_server_args, - max_wait_seconds=600) as remote_server: + with RemoteOpenAIServer( + MODEL_NAME, default_image_embeds_server_args, max_wait_seconds=600 + ) as remote_server: yield remote_server @@ -57,7 +57,7 @@ def encode_image_embedding_to_base64(image_embedding) -> str: torch.save(image_embedding, buffer) buffer.seek(0) binary_data = buffer.read() - base64_image_embedding = base64.b64encode(binary_data).decode('utf-8') + base64_image_embedding = base64.b64encode(binary_data).decode("utf-8") return base64_image_embedding @@ -75,19 +75,13 @@ async def test_completions_with_image_embeds( base64_image_embedding = encode_image_embedding_to_base64(image_embeds) chat_completion = await client_with_image_embeds.chat.completions.create( messages=[ + {"role": "system", "content": "You are a helpful assistant."}, { - "role": "system", - "content": "You are a helpful assistant." - }, - { - "role": - "user", + "role": "user", "content": [ { - "type": - "text", - "text": - "Describe these images separately. For each image," + "type": "text", + "text": "Describe these images separately. For each image," "reply with a short sentence (no more than 10 words).", }, { diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py index 35f75191d9c8..55328f0cf0f0 100644 --- a/tests/v1/entrypoints/openai/test_multi_api_servers.py +++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py @@ -50,16 +50,13 @@ async def client(server): "model_name", [MODEL_NAME], ) -async def test_single_completion(client: openai.AsyncOpenAI, - server: RemoteOpenAIServer, - model_name: str) -> None: - +async def test_single_completion( + client: openai.AsyncOpenAI, server: RemoteOpenAIServer, model_name: str +) -> None: async def make_request(): completion = await client.completions.create( - model=model_name, - prompt="Hello, my name is", - max_tokens=10, - temperature=1.0) + model=model_name, prompt="Hello, my name is", max_tokens=10, temperature=1.0 + ) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 1 @@ -108,9 +105,9 @@ async def make_request(): "model_name", [MODEL_NAME], ) -async def test_completion_streaming(client: openai.AsyncOpenAI, - server: RemoteOpenAIServer, - model_name: str) -> None: +async def test_completion_streaming( + client: openai.AsyncOpenAI, server: RemoteOpenAIServer, model_name: str +) -> None: prompt = "What is an LLM?" 
async def make_streaming_request(): @@ -124,11 +121,9 @@ async def make_streaming_request(): single_output = single_completion.choices[0].text # Perform the streaming request - stream = await client.completions.create(model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True) + stream = await client.completions.create( + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True + ) chunks: list[str] = [] finish_reason_count = 0 last_chunk = None @@ -139,16 +134,15 @@ async def make_streaming_request(): last_chunk = chunk # Keep track of the last chunk # finish reason should only return in the last block for OpenAI API - assert finish_reason_count == 1, ( - "Finish reason should appear exactly once.") - assert last_chunk is not None, ( - "Stream should have yielded at least one chunk.") - assert last_chunk.choices[ - 0].finish_reason == "length", "Finish reason should be 'length'." + assert finish_reason_count == 1, "Finish reason should appear exactly once." + assert last_chunk is not None, "Stream should have yielded at least one chunk." + assert last_chunk.choices[0].finish_reason == "length", ( + "Finish reason should be 'length'." + ) # Check that the combined text matches the non-streamed version. - assert "".join( - chunks - ) == single_output, "Streamed output should match non-streamed output." + assert "".join(chunks) == single_output, ( + "Streamed output should match non-streamed output." + ) return True # Indicate success for this request # Test single request @@ -162,9 +156,9 @@ async def make_streaming_request(): tasks = [make_streaming_request() for _ in range(num_requests)] results = await asyncio.gather(*tasks) - assert len( - results - ) == num_requests, f"Expected {num_requests} results, got {len(results)}" + assert len(results) == num_requests, ( + f"Expected {num_requests} results, got {len(results)}" + ) assert all(results), "Not all streaming requests completed successfully." await asyncio.sleep(0.5) @@ -172,9 +166,9 @@ async def make_streaming_request(): tasks = [make_streaming_request() for _ in range(num_requests)] results = await asyncio.gather(*tasks) - assert len( - results - ) == num_requests, f"Expected {num_requests} results, got {len(results)}" + assert len(results) == num_requests, ( + f"Expected {num_requests} results, got {len(results)}" + ) assert all(results), "Not all streaming requests completed successfully." # Check request balancing via Prometheus metrics if DP_SIZE > 1 diff --git a/tests/v1/executor/test_executor.py b/tests/v1/executor/test_executor.py index 4e83e2f9d4b6..c8bcd62d6680 100644 --- a/tests/v1/executor/test_executor.py +++ b/tests/v1/executor/test_executor.py @@ -14,19 +14,19 @@ from vllm.v1.executor.multiproc_executor import MultiprocExecutor -class Mock: - ... +class Mock: ... class CustomMultiprocExecutor(MultiprocExecutor): - - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None, - non_block: bool = False, - unique_reply_rank: Optional[int] = None) -> list[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + unique_reply_rank: Optional[int] = None, + ) -> list[Any]: # Drop marker to show that this was run with open(".marker", "w"): ... 
@@ -47,17 +47,22 @@ def test_custom_executor_type_checking(): ) LLMEngine.from_engine_args(engine_args) with pytest.raises(ValueError): - engine_args = AsyncEngineArgs(model=MODEL, - gpu_memory_utilization=0.2, - max_model_len=8192, - distributed_executor_backend=Mock) + engine_args = AsyncEngineArgs( + model=MODEL, + gpu_memory_utilization=0.2, + max_model_len=8192, + distributed_executor_backend=Mock, + ) AsyncLLM.from_engine_args(engine_args) -@pytest.mark.parametrize("distributed_executor_backend", [ - CustomMultiprocExecutor, - "tests.v1.executor.test_executor.CustomMultiprocExecutor" -]) +@pytest.mark.parametrize( + "distributed_executor_backend", + [ + CustomMultiprocExecutor, + "tests.v1.executor.test_executor.CustomMultiprocExecutor", + ], +) def test_custom_executor(distributed_executor_backend, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -82,10 +87,13 @@ def test_custom_executor(distributed_executor_backend, tmp_path): os.chdir(cwd) -@pytest.mark.parametrize("distributed_executor_backend", [ - CustomMultiprocExecutorAsync, - "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync" -]) +@pytest.mark.parametrize( + "distributed_executor_backend", + [ + CustomMultiprocExecutorAsync, + "tests.v1.executor.test_executor.CustomMultiprocExecutorAsync", + ], +) def test_custom_executor_async(distributed_executor_backend, tmp_path): cwd = os.path.abspath(".") os.chdir(tmp_path) @@ -103,9 +111,9 @@ def test_custom_executor_async(distributed_executor_backend, tmp_path): sampling_params = SamplingParams(max_tokens=1) async def t(): - stream = engine.generate(request_id="0", - prompt="foo", - sampling_params=sampling_params) + stream = engine.generate( + request_id="0", prompt="foo", sampling_params=sampling_params + ) async for x in stream: ... diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py index 5cc6fcfd9ac9..db1c757521f0 100644 --- a/tests/v1/generation/test_batch_invariance.py +++ b/tests/v1/generation/test_batch_invariance.py @@ -72,25 +72,22 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): Notes: - Use seeded stochastic sampling with a fixed seed to test determinism. - Outputs are intentionally longer and sampled at higher temperature/top_p - to produce a more random-sounding phrase, yet remain deterministic by + to produce a more random-sounding phrase, yet remain deterministic by seed. - Keep max_tokens and max_model_len bounded for speed and memory use. """ - seed = int(os.getenv("VLLM_TEST_SEED", "12345")) - random.seed(seed) + random.seed(12345) # Allow overrides from environment (useful for CI tuning) # "facebook/opt-125m" is too small, doesn't reliably test determinism model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5")) - max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128")) - min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024")) - max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048")) - assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle." + batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64")) + assert batch_size >= 2, "Batch size should be >= 2 to mix needle." # Keep GPU memory usage low to avoid startup allocation failures. 
- gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4")) - max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120")) + gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3")) + max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096")) swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4")) # Sampling parameters: longer outputs with a more random-sounding @@ -106,7 +103,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): seed=20240919, ) - needle_prompt = ("There once was a ") + needle_prompt = "There once was a " llm_bs1 = None llm_bsN = None @@ -114,7 +111,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): # Engine with bs=1 behavior llm_bs1 = LLM_with_max_seqs( model=model, - max_num_seqs=128, + max_num_seqs=1, gpu_memory_utilization=gpu_mem_util, max_model_len=max_model_len, swap_space=swap_space_gb, @@ -129,7 +126,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): # Engine with larger batch limit (e.g., 64) llm_bsN = LLM_with_max_seqs( model=model, - max_num_seqs=128, + max_num_seqs=batch_size, gpu_memory_utilization=gpu_mem_util, max_model_len=max_model_len, swap_space=swap_space_gb, @@ -138,17 +135,15 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): mismatches = 0 for trial in range(num_trials): - # Create a batch of size `max_batch_size` and insert the needle at + # Create a batch of size `batch_size` and insert the needle at # a random index prompts: list[str] = [] - batch_size = random.randint(max_batch_size // 2, max_batch_size) needle_pos = random.randint(0, batch_size - 1) for i in range(batch_size): if i == needle_pos: prompts.append(needle_prompt) else: - prompts.append( - _random_prompt(min_random_prompt, max_random_prompt)) + prompts.append(_random_prompt()) # Generate with the larger-batch engine outputs = llm_bsN.generate(prompts, sampling) @@ -159,19 +154,20 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): text = needle_output.outputs[0].text if text != baseline_text: - print( - f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n") mismatches += 1 passes = num_trials - mismatches # Dump how many passed vs failed - print(f"[determinism] total={num_trials}, passed={passes}, " - f"failed={mismatches}, max_batch_size={max_batch_size}") + print( + f"[determinism] total={num_trials}, passed={passes}, " + f"failed={mismatches}, batch_size={batch_size}" + ) if mismatches > 0: pytest.fail( f"Nondeterministic outputs detected: {mismatches} failed out " - f"of {num_trials} trials (max_batch_size={max_batch_size}).") + f"of {num_trials} trials (batch_size={batch_size})." 
+ ) finally: # Ensure engines are shutdown to free GPU/VRAM across test sessions @@ -203,14 +199,8 @@ def _extract_step_logprobs(request_output): not torch.cuda.is_available(), reason="Requires CUDA to match production inference path.", ) -@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"]) -def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): - - backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) - os.environ["VLLM_ATTENTION_BACKEND"] = backend - - seed = int(os.getenv("VLLM_TEST_SEED", "12345")) - random.seed(seed) +def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2(): + # model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m") model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) @@ -224,15 +214,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): prompts = [ "The capital of France is", "The capital of Germany is", - _random_prompt(10, 1024), - _random_prompt(10, 1024), - _random_prompt(10, 1024), - _random_prompt(10, 1024), - _random_prompt(10, 1024), ] sp = SamplingParams( - temperature=0.6, + temperature=0.0, top_p=1.0, max_tokens=8, # Seed shouldn't matter at temperature=0, but keeping it stable anyway. @@ -247,36 +232,43 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): assert len(outs) == 1 step_logprobs = _extract_step_logprobs(outs[0]) if step_logprobs is None: - pytest.skip("Logits are not available on RequestOutput; " - "enable logprobs return to run this test.") + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." + ) bs1_logprobs_per_prompt.append(step_logprobs) - # BS=N: run prompts in a batch and collect logprobs per step for each + # BS=2: run prompts in a batch and collect logprobs per step for each # prompt. outs_batched = llm.generate(prompts, sp, use_tqdm=False) assert len(outs_batched) == len(prompts) - bsN_logprobs_per_prompt = [] + bs2_logprobs_per_prompt = [] for o in outs_batched: step_logprobs = _extract_step_logprobs(o) if step_logprobs is None: - pytest.skip("Logits are not available on RequestOutput; " - "enable logprobs return to run this test.") - bsN_logprobs_per_prompt.append(step_logprobs) - - # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs. - for i, (logprobs_bs1, logprobs_bsN) in enumerate( - zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)): - assert len(logprobs_bs1) == len(logprobs_bsN), ( + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." + ) + bs2_logprobs_per_prompt.append(step_logprobs) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs. + for i, (logprobs_bs1, logprobs_bs2) in enumerate( + zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt) + ): + assert len(logprobs_bs1) == len(logprobs_bs2), ( f"Different number of generation steps for prompt index {i}: " - f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)") - for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): + f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)" + ) + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)): assert a.shape == b.shape, ( - f"Logits shape mismatch at prompt {i}, step {t}: " - f"{a.shape} vs {b.shape}") + f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}" + ) # Bitwise exact equality. 
- assert torch.equal( - a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} " - f"(dtype={a.dtype}, shape={a.shape}).") + assert torch.equal(a, b), ( + f"Bitwise logprobs mismatch at prompt {i}, step {t} " + f"(dtype={a.dtype}, shape={a.shape})." + ) def LLM_with_max_seqs( diff --git a/tests/v1/kv_connector/nixl_integration/test_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_accuracy.py index e5d66ffeeeb2..b301968e5bf8 100644 --- a/tests/v1/kv_connector/nixl_integration/test_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_accuracy.py @@ -12,12 +12,12 @@ RTOL = 0.03 # Model-specific expected values -EXPECTED_VALUES = { - "Qwen/Qwen3-0.6B": 0.41, - "deepseek-ai/deepseek-vl2-small": 0.59 -} +EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.41, "deepseek-ai/deepseek-vl2-small": 0.59} -SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501 +SIMPLE_PROMPT = ( + "The best part about working on vLLM is that I got to meet so many people across " + "various different organizations like UCB, Google, and Meta which means", +) # Get model name from environment variable MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B") @@ -25,8 +25,7 @@ def run_simple_prompt(): client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL) - completion = client.completions.create(model=MODEL_NAME, - prompt=SIMPLE_PROMPT) + completion = client.completions.create(model=MODEL_NAME, prompt=SIMPLE_PROMPT) print("-" * 50) print(f"Completion results for {MODEL_NAME}:") @@ -38,9 +37,11 @@ def test_accuracy(): """Run the end to end accuracy test.""" run_simple_prompt() - model_args = (f"model={MODEL_NAME}," - f"base_url={BASE_URL}/completions," - f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + model_args = ( + f"model={MODEL_NAME}," + f"base_url={BASE_URL}/completions," + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False" + ) results = lm_eval.simple_evaluate( model="local-completions", @@ -52,11 +53,14 @@ def test_accuracy(): expected_value = EXPECTED_VALUES.get(MODEL_NAME) if expected_value is None: - print(f"Warning: No expected value found for {MODEL_NAME}. " - "Skipping accuracy check.") + print( + f"Warning: No expected value found for {MODEL_NAME}. " + "Skipping accuracy check." 
+ ) print(f"Measured value: {measured_value}") return - assert (measured_value - RTOL < expected_value - and measured_value + RTOL > expected_value - ), f"Expected: {expected_value} | Measured: {measured_value}" + assert ( + measured_value - RTOL < expected_value + and measured_value + RTOL > expected_value + ), f"Expected: {expected_value} | Measured: {measured_value}" diff --git a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py index 697e101c3592..caa4aab870ab 100644 --- a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py @@ -43,37 +43,39 @@ def check_vllm_server(url: str, timeout=5, retries=3) -> bool: if response.status_code == 200: return True else: - print(f"Attempt {attempt + 1}: Server returned status code " - "{response.status_code}") + print( + f"Attempt {attempt + 1}: Server returned status code " + "{response.status_code}" + ) except requests.exceptions.RequestException as e: print(f"Attempt {attempt + 1}: Error connecting to server: {e}") time.sleep(1) # Wait before retrying return False -def run_simple_prompt(base_url: str, model_name: str, input_prompt: str, - use_chat_endpoint: bool) -> str: +def run_simple_prompt( + base_url: str, model_name: str, input_prompt: str, use_chat_endpoint: bool +) -> str: client = openai.OpenAI(api_key="EMPTY", base_url=base_url) if use_chat_endpoint: completion = client.chat.completions.create( model=model_name, - messages=[{ - "role": "user", - "content": [{ - "type": "text", - "text": input_prompt - }] - }], + messages=[ + {"role": "user", "content": [{"type": "text", "text": input_prompt}]} + ], max_completion_tokens=MAX_OUTPUT_LEN, temperature=0.0, - seed=42) + seed=42, + ) return completion.choices[0].message.content else: - completion = client.completions.create(model=model_name, - prompt=input_prompt, - max_tokens=MAX_OUTPUT_LEN, - temperature=0.0, - seed=42) + completion = client.completions.create( + model=model_name, + prompt=input_prompt, + max_tokens=MAX_OUTPUT_LEN, + temperature=0.0, + seed=42, + ) return completion.choices[0].text @@ -90,7 +92,8 @@ def main(): "--service_url", # Name of the first argument type=str, required=True, - help="The vLLM service URL.") + help="The vLLM service URL.", + ) parser.add_argument( "--model_name", # Name of the first argument @@ -127,28 +130,30 @@ def main(): if not os.path.exists(args.file_name): raise ValueError( f"In disagg mode, the output file {args.file_name} from " - "non-disagg. baseline does not exist.") + "non-disagg. baseline does not exist." 
+ ) service_url = f"{args.service_url}/v1" if not check_vllm_server(health_check_url): - raise RuntimeError( - f"vllm server: {args.service_url} is not ready yet!") + raise RuntimeError(f"vllm server: {args.service_url} is not ready yet!") output_strs = dict() for i, prompt in enumerate(SAMPLE_PROMPTS): - use_chat_endpoint = (i % 2 == 1) - output_str = run_simple_prompt(base_url=service_url, - model_name=args.model_name, - input_prompt=prompt, - use_chat_endpoint=use_chat_endpoint) + use_chat_endpoint = i % 2 == 1 + output_str = run_simple_prompt( + base_url=service_url, + model_name=args.model_name, + input_prompt=prompt, + use_chat_endpoint=use_chat_endpoint, + ) print(f"Prompt: {prompt}, output: {output_str}") output_strs[prompt] = output_str if args.mode == "baseline": # baseline: save outputs try: - with open(args.file_name, 'w') as json_file: + with open(args.file_name, "w") as json_file: json.dump(output_strs, json_file, indent=4) except OSError as e: print(f"Error writing to file: {e}") diff --git a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py index 8439e30be154..268a1845a2bb 100644 --- a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py +++ b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py @@ -12,8 +12,7 @@ PROXY_PORT = os.getenv("PROXY_PORT", None) if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: - raise ValueError( - "Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") + raise ValueError("Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.") LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501 PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501 @@ -41,13 +40,13 @@ def test_edge_cases(): # (1) Check that we can handle a very short prompt, # less than the length of the block size. - completion = proxy_client.completions.create(model=MODEL, - prompt=SHORT_PROMPT, - temperature=0) + completion = proxy_client.completions.create( + model=MODEL, prompt=SHORT_PROMPT, temperature=0 + ) proxy_response = completion.choices[0].text - completion = prefill_client.completions.create(model=MODEL, - prompt=SHORT_PROMPT, - temperature=0) + completion = prefill_client.completions.create( + model=MODEL, prompt=SHORT_PROMPT, temperature=0 + ) prefill_response = completion.choices[0].text print(f"SMALL PROMPT: {proxy_response=}") assert proxy_response == prefill_response @@ -55,27 +54,27 @@ def test_edge_cases(): # (2) Check that we can handle a full prefix cache # hit on the D worker but not on the P worker. # (2a): prime the D worker. 
- completion = decode_client.completions.create(model=MODEL, - prompt=PROMPT, - temperature=0) + completion = decode_client.completions.create( + model=MODEL, prompt=PROMPT, temperature=0 + ) decode_response = completion.choices[0].text # (2b): send via the P/D setup - completion = proxy_client.completions.create(model=MODEL, - prompt=PROMPT, - temperature=0) + completion = proxy_client.completions.create( + model=MODEL, prompt=PROMPT, temperature=0 + ) proxy_response = completion.choices[0].text print(f"FULL CACHE HIT: {proxy_response=}") assert proxy_response == decode_response # (3) Check that we can handle a partial prefix cache # hit on the D worker. - completion = proxy_client.completions.create(model=MODEL, - prompt=LONG_PROMPT, - temperature=0) + completion = proxy_client.completions.create( + model=MODEL, prompt=LONG_PROMPT, temperature=0 + ) proxy_response = completion.choices[0].text - completion = prefill_client.completions.create(model=MODEL, - prompt=LONG_PROMPT, - temperature=0) + completion = prefill_client.completions.create( + model=MODEL, prompt=LONG_PROMPT, temperature=0 + ) prefill_response = completion.choices[0].text print(f"PARTIAL CACHE HIT: {proxy_response=}") assert proxy_response == prefill_response diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 905ae0ea7172..37d70510fe25 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -27,49 +27,45 @@ async def lifespan(app: FastAPI): # Create prefill clients for i, (host, port) in enumerate(global_args.prefiller_instances): - prefiller_base_url = f'http://{host}:{port}/v1' - app.state.prefill_clients.append({ - 'client': - httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), - 'host': - host, - 'port': - port, - 'id': - i - }) + prefiller_base_url = f"http://{host}:{port}/v1" + app.state.prefill_clients.append( + { + "client": httpx.AsyncClient(timeout=None, base_url=prefiller_base_url), + "host": host, + "port": port, + "id": i, + } + ) # Create decode clients for i, (host, port) in enumerate(global_args.decoder_instances): - decoder_base_url = f'http://{host}:{port}/v1' - app.state.decode_clients.append({ - 'client': - httpx.AsyncClient(timeout=None, base_url=decoder_base_url), - 'host': - host, - 'port': - port, - 'id': - i - }) + decoder_base_url = f"http://{host}:{port}/v1" + app.state.decode_clients.append( + { + "client": httpx.AsyncClient(timeout=None, base_url=decoder_base_url), + "host": host, + "port": port, + "id": i, + } + ) # Initialize round-robin iterators - app.state.prefill_iterator = itertools.cycle( - range(len(app.state.prefill_clients))) - app.state.decode_iterator = itertools.cycle( - range(len(app.state.decode_clients))) + app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients))) + app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients))) - print(f"Initialized {len(app.state.prefill_clients)} prefill clients " - f"and {len(app.state.decode_clients)} decode clients.") + print( + f"Initialized {len(app.state.prefill_clients)} prefill clients " + f"and {len(app.state.decode_clients)} decode clients." 
+ ) yield # Shutdown: Close all clients for client_info in app.state.prefill_clients: - await client_info['client'].aclose() + await client_info["client"].aclose() for client_info in app.state.decode_clients: - await client_info['client'].aclose() + await client_info["client"].aclose() # Update FastAPI app initialization to use lifespan @@ -83,43 +79,38 @@ def parse_args(): parser.add_argument("--host", type=str, default="localhost") # For prefiller instances - parser.add_argument("--prefiller-hosts", - "--prefiller-host", - type=str, - nargs="+", - default=["localhost"]) - parser.add_argument("--prefiller-ports", - "--prefiller-port", - type=int, - nargs="+", - default=[8100]) + parser.add_argument( + "--prefiller-hosts", + "--prefiller-host", + type=str, + nargs="+", + default=["localhost"], + ) + parser.add_argument( + "--prefiller-ports", "--prefiller-port", type=int, nargs="+", default=[8100] + ) # For decoder instances - parser.add_argument("--decoder-hosts", - "--decoder-host", - type=str, - nargs="+", - default=["localhost"]) - parser.add_argument("--decoder-ports", - "--decoder-port", - type=int, - nargs="+", - default=[8200]) + parser.add_argument( + "--decoder-hosts", "--decoder-host", type=str, nargs="+", default=["localhost"] + ) + parser.add_argument( + "--decoder-ports", "--decoder-port", type=int, nargs="+", default=[8200] + ) args = parser.parse_args() # Validate and pair hosts with ports if len(args.prefiller_hosts) != len(args.prefiller_ports): raise ValueError( - "Number of prefiller hosts must match number of prefiller ports") + "Number of prefiller hosts must match number of prefiller ports" + ) if len(args.decoder_hosts) != len(args.decoder_ports): - raise ValueError( - "Number of decoder hosts must match number of decoder ports") + raise ValueError("Number of decoder hosts must match number of decoder ports") # Create tuples of (host, port) for each service type - args.prefiller_instances = list( - zip(args.prefiller_hosts, args.prefiller_ports)) + args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports)) args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) return args @@ -136,29 +127,30 @@ def get_next_client(app, service_type: str): Returns: The next client to use """ - if service_type == 'prefill': + if service_type == "prefill": client_idx = next(app.state.prefill_iterator) return app.state.prefill_clients[client_idx] - elif service_type == 'decode': + elif service_type == "decode": client_idx = next(app.state.decode_iterator) return app.state.decode_clients[client_idx] else: raise ValueError(f"Unknown service type: {service_type}") -async def send_request_to_service(client_info: dict, endpoint: str, - req_data: dict, request_id: str): +async def send_request_to_service( + client_info: dict, endpoint: str, req_data: dict, request_id: str +): """ Send a request to a service using a client from the pool. 
""" req_data = req_data.copy() - req_data['kv_transfer_params'] = { + req_data["kv_transfer_params"] = { "do_remote_decode": True, "do_remote_prefill": False, "remote_engine_id": None, "remote_block_ids": None, "remote_host": None, - "remote_port": None + "remote_port": None, } req_data["stream"] = False req_data["max_tokens"] = 1 @@ -168,31 +160,31 @@ async def send_request_to_service(client_info: dict, endpoint: str, del req_data["stream_options"] headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } - response = await client_info['client'].post(endpoint, - json=req_data, - headers=headers) + response = await client_info["client"].post( + endpoint, json=req_data, headers=headers + ) response.raise_for_status() return response -async def stream_service_response(client_info: dict, endpoint: str, - req_data: dict, request_id: str): +async def stream_service_response( + client_info: dict, endpoint: str, req_data: dict, request_id: str +): """ Asynchronously stream response from a service using a client from the pool. """ headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } - async with client_info['client'].stream("POST", - endpoint, - json=req_data, - headers=headers) as response: + async with client_info["client"].stream( + "POST", endpoint, json=req_data, headers=headers + ) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): yield chunk @@ -204,40 +196,39 @@ async def _handle_completions(api: str, request: Request): request_id = str(uuid.uuid4()) # Get the next prefill client in round-robin fashion - prefill_client_info = get_next_client(request.app, 'prefill') + prefill_client_info = get_next_client(request.app, "prefill") # Send request to prefill service - response = await send_request_to_service(prefill_client_info, api, - req_data, request_id) + response = await send_request_to_service( + prefill_client_info, api, req_data, request_id + ) # Extract the needed fields response_json = response.json() - kv_transfer_params = response_json.get('kv_transfer_params', {}) + kv_transfer_params = response_json.get("kv_transfer_params", {}) if kv_transfer_params: req_data["kv_transfer_params"] = kv_transfer_params # Get the next decode client in round-robin fashion - decode_client_info = get_next_client(request.app, 'decode') + decode_client_info = get_next_client(request.app, "decode") logger.debug("Using %s %s", prefill_client_info, decode_client_info) # Stream response from decode service async def generate_stream(): - async for chunk in stream_service_response(decode_client_info, - api, - req_data, - request_id=request_id): + async for chunk in stream_service_response( + decode_client_info, api, req_data, request_id=request_id + ): yield chunk - return StreamingResponse(generate_stream(), - media_type="application/json") + return StreamingResponse(generate_stream(), media_type="application/json") except Exception as e: import sys import traceback + exc_info = sys.exc_info() - print("Error occurred in disagg prefill proxy server" - f" - {api} endpoint") + print(f"Error occurred in disagg prefill proxy server - {api} endpoint") print(e) print("".join(traceback.format_exception(*exc_info))) raise @@ -259,13 +250,14 @@ async def healthcheck(): return { "status": "ok", "prefill_instances": len(app.state.prefill_clients), - "decode_instances": len(app.state.decode_clients) + "decode_instances": 
len(app.state.decode_clients), } -if __name__ == '__main__': +if __name__ == "__main__": global global_args global_args = parse_args() import uvicorn + uvicorn.run(app, host=global_args.host, port=global_args.port) diff --git a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py index fe6296cf12ea..0bb67b574fa1 100644 --- a/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py +++ b/tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py @@ -2,12 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa: E501 - SharedStorageConnectorMetadata) + SharedStorageConnectorMetadata, +) from vllm.distributed.kv_transfer.kv_transfer_state import ( - ensure_kv_transfer_initialized, get_kv_transfer_group) + ensure_kv_transfer_initialized, + get_kv_transfer_group, +) from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput -from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin) +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin # Importing utils registers TestSharedStorageConnector with the factory from .utils import create_vllm_config @@ -34,7 +36,7 @@ def test_kv_connector_mixin_clears_metadata(): vllm_config = create_vllm_config() vllm_config.kv_transfer_config.kv_connector = "TestSharedStorageConnector" vllm_config.kv_transfer_config.kv_role = "kv_both" - vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = ("unit") + vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit" # Initialize the global connector instance ensure_kv_transfer_initialized(vllm_config) @@ -46,7 +48,8 @@ def test_kv_connector_mixin_clears_metadata(): # Invoke the no-forward path which uses the mixin context manager KVConnectorModelRunnerMixin.kv_connector_no_forward( - scheduler_output, vllm_config) + scheduler_output, vllm_config + ) # Verify clear_connector_metadata was called on the connector connector = get_kv_transfer_group() diff --git a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py index 549e85875025..0902fbfe85f3 100644 --- a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py +++ b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py @@ -9,17 +9,19 @@ from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.request import Request, RequestStatus -from .utils import (create_model_runner_output, create_request, - create_scheduler, create_vllm_config) +from .utils import ( + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) def _make_get_num_new_matched_tokens( req_num_new_matched_tokens: dict[str, int], async_load, ) -> Callable[[Request, int], tuple[int, bool]]: - - def get_num_new_matched_tokens(request: Request, - _: int) -> tuple[int, bool]: + def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]: value = req_num_new_matched_tokens.get(request.request_id, 0) return value, async_load @@ -33,9 +35,7 @@ def scheduler(): @pytest.mark.parametrize( - "num_prompt_blocks," - "num_external_computed_blocks," - "invalid_block_idxs", + "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs", [ (100, 99, {0, 98}), (100, 99, {50, 98}), @@ -51,8 +51,7 @@ def test_async_load_failure( assert num_prompt_blocks >= num_external_computed_blocks num_prompt_tokens = 
num_prompt_blocks * scheduler.block_size - num_external_computed_tokens = (num_external_computed_blocks * - scheduler.block_size) + num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size request1 = create_request(num_tokens=num_prompt_tokens) scheduler.add_request(request=request1) @@ -71,8 +70,8 @@ def test_async_load_failure( scheduler.connector = Mock() scheduler.connector.get_num_new_matched_tokens.side_effect = ( - _make_get_num_new_matched_tokens(req_num_new_matched_tokens, - async_load=True)) + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True) + ) scheduler.connector.take_events.return_value = () scheduler_output = scheduler.schedule() @@ -84,14 +83,14 @@ def test_async_load_failure( assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 # Simulate a failure in loading some of request2 blocks. - (req2_block_ids, ) = scheduler.kv_cache_manager.get_block_ids( - request2.request_id) + (req2_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request2.request_id) invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs} model_runner_output = create_model_runner_output( reqs=[], finished_recving={request1.request_id, request3.request_id}, invalid_block_ids=invalid_block_ids, - use_eos=True) + use_eos=True, + ) scheduler.update_from_output(scheduler_output, model_runner_output) @@ -100,8 +99,9 @@ def test_async_load_failure( assert len(scheduler.waiting) == 3 for request in scheduler.waiting: if request.request_id == request2.request_id: - assert request.num_computed_tokens == (min_invalid_block_idx * - scheduler.block_size) + assert request.num_computed_tokens == ( + min_invalid_block_idx * scheduler.block_size + ) else: assert request.num_computed_tokens == 0 assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS @@ -110,9 +110,7 @@ def test_async_load_failure( @pytest.mark.parametrize( - "num_prompt_blocks," - "num_external_computed_blocks," - "invalid_block_idxs", + "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs", [ (100, 99, {0, 98}), (100, 99, {50, 98}), @@ -128,8 +126,7 @@ def test_sync_load_failure( assert num_prompt_blocks >= num_external_computed_blocks num_prompt_tokens = num_prompt_blocks * scheduler.block_size - num_external_computed_tokens = (num_external_computed_blocks * - scheduler.block_size) + num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size request1 = create_request(num_tokens=num_prompt_tokens) scheduler.add_request(request=request1) @@ -148,8 +145,8 @@ def test_sync_load_failure( scheduler.connector = Mock() scheduler.connector.get_num_new_matched_tokens.side_effect = ( - _make_get_num_new_matched_tokens(req_num_new_matched_tokens, - async_load=False)) + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False) + ) scheduler.connector.request_finished.return_value = (False, None) scheduler.connector.take_events.return_value = () @@ -165,8 +162,7 @@ def test_sync_load_failure( assert len(scheduler.running) == 3 assert len(scheduler_output.scheduled_new_reqs) == 3 for request in scheduler_output.scheduled_new_reqs: - assert request.num_computed_tokens == expected_computed_tokens[ - request.req_id] + assert request.num_computed_tokens == expected_computed_tokens[request.req_id] assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 # Simulate a failure in loading some of request2 blocks. 
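A minimal, self-contained sketch of the side_effect pattern these failure-recovery tests rely on: a factory builds a closure that maps each request to the number of externally computed tokens it should report, and that closure is installed on the mocked connector. FakeRequest, make_get_num_new_matched_tokens' exact shape here, and the token counts are illustrative assumptions, not the vLLM types or values used in the tests above.

from dataclasses import dataclass
from unittest.mock import Mock


@dataclass
class FakeRequest:
    # Illustrative stand-in for vllm.v1.request.Request; only request_id is needed.
    request_id: str


def make_get_num_new_matched_tokens(per_request: dict[str, int], async_load: bool):
    # Closure installed as Mock.side_effect: returns (num_external_tokens, is_async)
    # for a given request, defaulting to 0 tokens for unknown request ids.
    def get_num_new_matched_tokens(request, _num_computed_tokens):
        return per_request.get(request.request_id, 0), async_load

    return get_num_new_matched_tokens


connector = Mock()
# 1584 is an arbitrary placeholder for num_external_computed_blocks * block_size.
connector.get_num_new_matched_tokens.side_effect = make_get_num_new_matched_tokens(
    {"req-2": 1584}, async_load=False
)

assert connector.get_num_new_matched_tokens(FakeRequest("req-2"), 0) == (1584, False)
assert connector.get_num_new_matched_tokens(FakeRequest("req-1"), 0) == (0, False)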
@@ -175,14 +171,16 @@ def test_sync_load_failure( model_runner_output = create_model_runner_output( [request1, request2, request3], invalid_block_ids=invalid_block_ids, - use_eos=True) + use_eos=True, + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 assert scheduler.running[0].request_id == request2.request_id assert scheduler.running[0].num_computed_tokens == ( - min(invalid_block_idxs) * scheduler.block_size) + min(invalid_block_idxs) * scheduler.block_size + ) assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 assert scheduler.connector.request_finished.call_count == 2 @@ -205,19 +203,19 @@ def test_sync_load_failure_with_shared_blocks( num_common_prefix_blocks: int, invalid_block_idxs: set[int], ): - assert (num_prompt_blocks >= num_external_computed_blocks >= - num_common_prefix_blocks) + assert num_prompt_blocks >= num_external_computed_blocks >= num_common_prefix_blocks num_prompt_tokens = num_prompt_blocks * scheduler.block_size - num_external_computed_tokens = (num_external_computed_blocks * - scheduler.block_size) + num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size common_prefix_len = num_common_prefix_blocks * scheduler.block_size - request1 = create_request(num_tokens=num_prompt_tokens, - common_prefix_len=common_prefix_len) + request1 = create_request( + num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len + ) scheduler.add_request(request=request1) - request2 = create_request(num_tokens=num_prompt_tokens, - common_prefix_len=common_prefix_len) + request2 = create_request( + num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len + ) scheduler.add_request(request=request2) # Mock KV connector method. @@ -228,8 +226,8 @@ def test_sync_load_failure_with_shared_blocks( scheduler.connector = Mock() scheduler.connector.get_num_new_matched_tokens.side_effect = ( - _make_get_num_new_matched_tokens(req_num_new_matched_tokens, - async_load=False)) + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False) + ) scheduler.connector.take_events.return_value = () scheduler_output = scheduler.schedule() @@ -243,17 +241,15 @@ def test_sync_load_failure_with_shared_blocks( assert len(scheduler.running) == 2 assert len(scheduler_output.scheduled_new_reqs) == 2 for request in scheduler_output.scheduled_new_reqs: - assert request.num_computed_tokens == expected_computed_tokens[ - request.req_id] + assert request.num_computed_tokens == expected_computed_tokens[request.req_id] assert scheduler.connector.get_num_new_matched_tokens.call_count == 2 # Simulate a failure in loading some of the shared blocks. 
req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0] invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs} model_runner_output = create_model_runner_output( - [request1, request2], - invalid_block_ids=invalid_block_ids, - use_eos=True) + [request1, request2], invalid_block_ids=invalid_block_ids, use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) @@ -266,15 +262,14 @@ def test_sync_load_failure_with_shared_blocks( assert len(scheduler.running) == 2 for request in scheduler.running: - assert request.num_computed_tokens == expected_computed_tokens[ - request.request_id] + assert ( + request.num_computed_tokens == expected_computed_tokens[request.request_id] + ) assert scheduler.connector.get_num_new_matched_tokens.call_count == 2 @pytest.mark.parametrize( - "num_prompt_blocks," - "num_external_computed_blocks," - "invalid_block_idxs", + "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs", [ (100, 99, {0, 50, 98}), (100, 99, {98, 50, 0}), @@ -289,8 +284,7 @@ def test_async_progressive_load_failure( assert num_prompt_blocks >= num_external_computed_blocks num_prompt_tokens = num_prompt_blocks * scheduler.block_size - num_external_computed_tokens = (num_external_computed_blocks * - scheduler.block_size) + num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size request = create_request(num_tokens=num_prompt_tokens) scheduler.add_request(request=request) @@ -303,8 +297,8 @@ def test_async_progressive_load_failure( scheduler.connector = Mock() scheduler.connector.get_num_new_matched_tokens.side_effect = ( - _make_get_num_new_matched_tokens(req_num_new_matched_tokens, - async_load=True)) + _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True) + ) scheduler.connector.take_events.return_value = () scheduler_output = scheduler.schedule() @@ -318,24 +312,24 @@ def test_async_progressive_load_failure( min_invalid_block_idx = max(invalid_block_idxs) + 1 # Simulate failures when progressively loading request blocks. 
for invalid_block_idx in invalid_block_idxs: - (req_block_ids, ) = scheduler.kv_cache_manager.get_block_ids( - request.request_id) + (req_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request.request_id) invalid_block_ids = {req_block_ids[invalid_block_idx]} model_runner_output = create_model_runner_output( reqs=[], finished_recving=set(), invalid_block_ids=invalid_block_ids, - use_eos=True) + use_eos=True, + ) scheduler.update_from_output(scheduler_output, model_runner_output) min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx) assert len(scheduler.waiting) == 1 - assert scheduler.waiting.peek_request( - ).request_id == request.request_id - assert request.num_computed_tokens == (min_invalid_block_idx * - scheduler.block_size) + assert scheduler.waiting.peek_request().request_id == request.request_id + assert request.num_computed_tokens == ( + min_invalid_block_idx * scheduler.block_size + ) assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert scheduler.failed_recving_kv_req_ids == {request.request_id} assert scheduler.connector.get_num_new_matched_tokens.call_count == 1 diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index b1780d8a9af8..74ae3ca9a863 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -52,29 +52,26 @@ def test_multi_shared_storage_connector_consistency(): kv_connector="MultiConnector", kv_role="kv_both", kv_connector_extra_config={ - "connectors": [{ - "kv_connector": - "TestSharedStorageConnector", - "kv_role": - "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_1_path), - "name": "storage1", + "connectors": [ + { + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_1_path), + "name": "storage1", + }, + "kv_connector_module_path": "tests.v1.kv_connector.unit.utils", }, - "kv_connector_module_path": - "tests.v1.kv_connector.unit.utils", - }, { - "kv_connector": - "TestSharedStorageConnector", - "kv_role": - "kv_both", - "kv_connector_extra_config": { - "shared_storage_path": str(storage_2_path), - "name": "storage2", + { + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_2_path), + "name": "storage2", + }, + "kv_connector_module_path": "tests.v1.kv_connector.unit.utils", }, - "kv_connector_module_path": - "tests.v1.kv_connector.unit.utils", - }] + ] }, ) @@ -93,14 +90,16 @@ def test_multi_shared_storage_connector_consistency(): local_subdirs = list(storage_1_path.iterdir()) external_subdirs = list(storage_2_path.iterdir()) - assert len( - local_subdirs - ) > 0, f"Local storage path {storage_1_path} is empty after generation." + assert len(local_subdirs) > 0, ( + f"Local storage path {storage_1_path} is empty after generation." + ) assert len(external_subdirs) > 0, ( - f"External storage path {storage_2_path} is empty after generation.") + f"External storage path {storage_2_path} is empty after generation." 
+ ) assert len(local_subdirs) == len(external_subdirs), ( f"Mismatch in number of cache entries: " - f"Local={len(local_subdirs)}, External={len(external_subdirs)}") + f"Local={len(local_subdirs)}, External={len(external_subdirs)}" + ) # The subdirectories should correspond to the prompt hashes # Since prompts are the same, the hash directories should be the same name @@ -113,29 +112,39 @@ def test_multi_shared_storage_connector_consistency(): # Compare the contents of each corresponding cache directory for subdir_name in local_subdir_names: print(f"Comparing contents of cache directory: {subdir_name}") - assert _compare_directories(storage_1_path / subdir_name, - storage_2_path / subdir_name), \ - (f"Contents differ for cache directory '{subdir_name}' between " - f"{storage_1_path} and {storage_2_path}") + assert _compare_directories( + storage_1_path / subdir_name, storage_2_path / subdir_name + ), ( + f"Contents differ for cache directory '{subdir_name}' between " + f"{storage_1_path} and {storage_2_path}" + ) events = get_connector_events() # get_num_new_matched_tokens and update_state_after_alloc will be called # on each connector in turn. assert events["storage1-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[0] 0", + "build_connector_meta", ] assert events["storage1-WORKER"][:5] == [ - 'register_kv_caches', 'bind_connector_metadata', 'start_load_kv', - 'wait_for_layer_load', 'save_kv_layer' + "register_kv_caches", + "bind_connector_metadata", + "start_load_kv", + "wait_for_layer_load", + "save_kv_layer", ] assert events["storage2-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[0] 0", + "build_connector_meta", ] assert events["storage2-WORKER"][:5] == [ - 'register_kv_caches', 'bind_connector_metadata', 'start_load_kv', - 'wait_for_layer_load', 'save_kv_layer' + "register_kv_caches", + "bind_connector_metadata", + "start_load_kv", + "wait_for_layer_load", + "save_kv_layer", ] # Reset prefix cache or else we'll just get the tokens back from there. @@ -151,12 +160,14 @@ def test_multi_shared_storage_connector_consistency(): # on that one but with zero blocks for others (first nonzero match is # chosen). assert events["storage1-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[7] 96', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[7] 96", + "build_connector_meta", ] assert events["storage2-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[0] 0", + "build_connector_meta", ] # Delete storage1 connector state @@ -175,12 +186,14 @@ def test_multi_shared_storage_connector_consistency(): # a hit, so update_state_after_alloc will only be called with allocated # blocks for the second connector. 
assert events["storage1-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[0] 0', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[0] 0", + "build_connector_meta", ] assert events["storage2-SCHEDULER"][:3] == [ - 'get_num_new_matched_tokens 0', - 'update_state_after_alloc num_blocks=[7] 96', 'build_connector_meta' + "get_num_new_matched_tokens 0", + "update_state_after_alloc num_blocks=[7] 96", + "build_connector_meta", ] # Clean up @@ -191,15 +204,14 @@ def test_multi_shared_storage_connector_consistency(): def get_connector_events() -> dict[str, list[str]]: # Read in connector events and reset the files. import glob + event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log") connector_events = {} for fname in event_files: name = fname.split("connector_")[1].split("_events.log")[0] try: with open(fname, "r+") as f: - connector_events[name] = [ - line.strip() for line in f if line.strip() - ] + connector_events[name] = [line.strip() for line in f if line.strip()] f.truncate(0) except Exception as e: print(f"[ERROR] Could not read connector events for {name}: {e}") @@ -211,5 +223,5 @@ def test_engine_id_conflict(): configs = [KVTransferConfig() for _ in range(2)] ids = [config.engine_id for config in configs] assert ids[0] != ids[1], ( - "Engine IDs should be different for different configs. " - f"Got {ids}") + f"Engine IDs should be different for different configs. Got {ids}" + ) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 21953b5533ec..a1f53cb25563 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -19,20 +19,28 @@ from vllm import LLM from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( - MultiKVConnectorStats) + MultiKVConnectorStats, +) from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( - KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, - NixlConnectorWorker, NixlKVConnectorStats) + KVConnectorRole, + NixlAgentMetadata, + NixlConnector, + NixlConnectorMetadata, + NixlConnectorWorker, + NixlKVConnectorStats, +) from vllm.distributed.kv_transfer.kv_transfer_state import ( - ensure_kv_transfer_shutdown, has_kv_transfer_group) + ensure_kv_transfer_shutdown, + has_kv_transfer_group, +) from vllm.forward_context import ForwardContext from vllm.platforms.interface import Platform from vllm.sampling_params import SamplingParams from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +from vllm.v1.request import RequestStatus from .utils import create_request, create_scheduler, create_vllm_config @@ -41,14 +49,14 @@ def clear_kv_transfer(): """ The test cases in this file use `VLLM_ENABLE_V1_MULTIPROCESSING=0`, - causing the global variable `_KV_CONNECTOR_AGENT` + causing the global variable `_KV_CONNECTOR_AGENT` to be assigned but never deleted. 
- Since the current pytest process does not terminate and instead + Since the current pytest process does not terminate and instead continues running tests from other files, - this global variable remains in memory and interferes + this global variable remains in memory and interferes with test cases in other modules. - + So we use this fixture to ensure that the global variable `_KV_CONNECTOR_AGENT` is properly cleaned up after each test. """ @@ -57,11 +65,12 @@ def clear_kv_transfer(): ensure_kv_transfer_shutdown() -def get_default_xfer_telemetry(xferDurationS: float = 1, - postDurationS: float = 1, - totalBytes: int = 1, - descCount: int = 1) -> dict: - +def get_default_xfer_telemetry( + xferDurationS: float = 1, + postDurationS: float = 1, + totalBytes: int = 1, + descCount: int = 1, +) -> dict: class AttributeDict(dict): __slots__ = () __getattr__ = dict.__getitem__ @@ -82,7 +91,7 @@ class FakeNixlWrapper: We don't inherit from nixl._api.nixl_agent because nixl may not be installed. - + Note: The complete source of this class is also used in the `_make_fake_nixl_pkg` function to create a fake nixl package for Ray workers. @@ -93,8 +102,7 @@ class FakeNixlWrapper: def __init__(self, agent_name: str, *args, **kwargs): self._cycles_before_xfer_done = 0 - self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( - lambda: 0) + self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(lambda: 0) def get_reg_descs(self, caches_data, memory_type: str) -> list: return [str(uuid.uuid4()) for _ in caches_data] @@ -122,8 +130,7 @@ def get_new_notifs(self) -> dict[str, list[bytes]]: return {} def check_xfer_state(self, handle: int) -> str: - if self._check_xfer_state_cycles[ - handle] >= self._cycles_before_xfer_done: + if self._check_xfer_state_cycles[handle] >= self._cycles_before_xfer_done: return "DONE" self._check_xfer_state_cycles[handle] += 1 return "PROC" @@ -140,13 +147,15 @@ def remove_remote_agent(self, agent: str) -> None: def send_notif(self, agent_name: str, notif_msg: bytes) -> None: pass - def make_prepped_xfer(self, - xfer_type: str, - local_xfer_side_handle: int, - local_block_descs_ids: list[int], - remote_xfer_side_handle: int, - remote_block_descs_ids: list[int], - notif_msg: Optional[bytes] = None) -> int: + def make_prepped_xfer( + self, + xfer_type: str, + local_xfer_side_handle: int, + local_block_descs_ids: list[int], + remote_xfer_side_handle: int, + remote_block_descs_ids: list[int], + notif_msg: Optional[bytes] = None, + ) -> int: return uuid.uuid4().int def transfer(self, handle: int) -> str: @@ -167,7 +176,7 @@ def set_cycles_before_xfer_done(self, cycles: int): def _make_fake_nixl_pkg(): """Context manager that creates a temporary package making `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper. - + Automatically cleans up the temporary directory when done. 
""" with tempfile.TemporaryDirectory() as td: @@ -213,10 +222,12 @@ def test_basic_interface(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) request_id = request.request_id scheduler.add_request(request) @@ -232,8 +243,11 @@ def test_basic_interface(): req_meta = kv_connector_metadata.reqs_to_recv[request_id] for block_id, block in zip( - req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[request_id]): + req_meta.local_block_ids, + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ], + ): assert block_id == block.block_id @@ -253,11 +267,13 @@ def test_prompt_less_than_block_size(): NUM_TOKENS = int(BLOCK_SIZE * 0.5) # Request will have 1 partial remote block. - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True, - num_remote_blocks=1) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + num_remote_blocks=1, + ) scheduler.add_request(request) scheduler_output = scheduler.schedule() @@ -270,15 +286,15 @@ def test_prompt_less_than_block_size(): class FakeNixlConnectorWorker(NixlConnectorWorker): - REMOTE_ENGINE_ID = "remote_engine" def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs): super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency - def _nixl_handshake(self, host: str, port: int, remote_tp_size: int, - expected_engine_id: str) -> dict[int, str]: + def _nixl_handshake( + self, host: str, port: int, remote_tp_size: int, expected_engine_id: str + ) -> dict[int, str]: # Mimic slow _nixl_handshake, as well as bypass zmq communication. time.sleep(self._hand_shake_latency) # These should've been done in register_kv_caches(), called by @@ -303,21 +319,23 @@ def _nixl_handshake(self, host: str, port: int, remote_tp_size: int, # is started. We mock HND here. kv_cache_layout="HND", ), - remote_tp_size=remote_tp_size) + remote_tp_size=remote_tp_size, + ) return {0: remote_agent_name} class TestNixlHandshake: - @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, + ) def test_multi_xfer_one_engine( self, # dist_init is a fixture that initializes the distributed environment. - dist_init): + dist_init, + ): """Test case where multiple xfers are initiated to the same engine. - + This test triggers the connector to load remote KV for the same `request_id`. The transfer is not done immediately due to `set_cycles_before_xfer_done`, so there is a state where there are @@ -331,9 +349,9 @@ def test_multi_xfer_one_engine( # Test worker role in decode server. 
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id, hand_shake_latency=0) - assert isinstance(connector.connector_worker.nixl_wrapper, - FakeNixlWrapper) + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper) connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3) num_xfers = 4 while True: @@ -344,21 +362,19 @@ def test_multi_xfer_one_engine( num_xfers -= 1 metadata.add_new_req( request_id=request_id, - local_block_ids=[ - num_xfers + 1, num_xfers + 2, num_xfers + 3 - ], + local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3], kv_transfer_params={ - "remote_block_ids": - [num_xfers + 4, num_xfers + 5, num_xfers + 6], - "remote_engine_id": - FakeNixlConnectorWorker.REMOTE_ENGINE_ID, - "remote_host": - "localhost", - "remote_port": - 1234, - "remote_tp_size": - 1, - }) + "remote_block_ids": [ + num_xfers + 4, + num_xfers + 5, + num_xfers + 6, + ], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": 1, + }, + ) connector.bind_connector_metadata(metadata) # Mimic maybe_setup_kv_connector in gpu_model_runner. @@ -370,8 +386,9 @@ def test_multi_xfer_one_engine( _before_load = time.perf_counter() connector.start_load_kv(dummy_ctx) _after_load = time.perf_counter() - assert _after_load - _before_load < 0.1, "start_load_kv took " \ - f"{_after_load - _before_load} seconds" + assert _after_load - _before_load < 0.1, ( + f"start_load_kv took {_after_load - _before_load} seconds" + ) # Mimic get_finished_kv_transfers in gpu_model_runner. _, done_recving = connector.get_finished(finished_req_ids=set()) @@ -383,20 +400,25 @@ def test_multi_xfer_one_engine( @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) - @pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [ - (1, 1), - (2, 1), - (4, 2), - (4, 4), - ]) + FakeNixlWrapper, + ) + @pytest.mark.parametrize( + "decode_tp_size, prefill_tp_size", + [ + (1, 1), + (2, 1), + (4, 2), + (4, 4), + ], + ) def test_async_load_kv( - self, - # Fixture that initializes the distributed environment. - dist_init, - # Simulate consumer-producer TP sizes. - decode_tp_size, - prefill_tp_size): + self, + # Fixture that initializes the distributed environment. + dist_init, + # Simulate consumer-producer TP sizes. + decode_tp_size, + prefill_tp_size, + ): """Test that NixlConnector's start_load_kv should be non-blocking.""" vllm_config = create_vllm_config() @@ -405,18 +427,20 @@ def test_async_load_kv( # Test worker role in decode server. 
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id) + vllm_config, connector.engine_id + ) metadata = NixlConnectorMetadata() - metadata.add_new_req(request_id="id", - local_block_ids=[1, 2, 3], - kv_transfer_params={ - "remote_block_ids": [4, 5, 6], - "remote_engine_id": - FakeNixlConnectorWorker.REMOTE_ENGINE_ID, - "remote_host": "localhost", - "remote_port": 1234, - "remote_tp_size": prefill_tp_size, - }) + metadata.add_new_req( + request_id="id", + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": prefill_tp_size, + }, + ) connector.bind_connector_metadata(metadata) timeout = 2.5 @@ -430,8 +454,9 @@ def test_async_load_kv( _before_load = time.perf_counter() connector.start_load_kv(dummy_ctx) _after_load = time.perf_counter() - assert _after_load - _before_load < 0.1, "start_load_kv took " \ - f"{_after_load - _before_load} seconds" + assert _after_load - _before_load < 0.1, ( + f"start_load_kv took {_after_load - _before_load} seconds" + ) time.sleep(0.5) # backoff for the async handshake to complete. connector.bind_connector_metadata(NixlConnectorMetadata()) _, done_recving = connector.get_finished(finished_req_ids=set()) @@ -441,11 +466,13 @@ def test_async_load_kv( @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, + ) def test_concurrent_load_kv( self, # dist_init is a fixture that initializes the distributed environment. - dist_init): + dist_init, + ): """Test that multiple start_load_kv calls should occur concurrently.""" vllm_config = create_vllm_config() @@ -453,20 +480,22 @@ def test_concurrent_load_kv( # Test worker role in decode server. connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id) + vllm_config, connector.engine_id + ) metadata = NixlConnectorMetadata() total_reqs = 5 for i in range(total_reqs): - metadata.add_new_req(request_id=f"id_{i}", - local_block_ids=[1, 2, 3], - kv_transfer_params={ - "remote_block_ids": [4, 5, 6], - "remote_engine_id": - FakeNixlConnectorWorker.REMOTE_ENGINE_ID, - "remote_host": "localhost", - "remote_port": 1234, - "remote_tp_size": 1, - }) + metadata.add_new_req( + request_id=f"id_{i}", + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": 1, + }, + ) connector.bind_connector_metadata(metadata) timeout = 2.5 * total_reqs @@ -481,8 +510,9 @@ def test_concurrent_load_kv( _before_load = time.perf_counter() connector.start_load_kv(dummy_ctx) _after_load = time.perf_counter() - assert _after_load - _before_load < 0.1, "start_load_kv took " \ - f"{_after_load - _before_load} seconds" + assert _after_load - _before_load < 0.1, ( + f"start_load_kv took {_after_load - _before_load} seconds" + ) time.sleep(0.5) # backoff for the async handshake to complete. 
connector.bind_connector_metadata(NixlConnectorMetadata()) _, done_recving = connector.get_finished(finished_req_ids=set()) @@ -494,7 +524,8 @@ def test_concurrent_load_kv( @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, + ) def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): """ Verify that adding a remote agent fails if kv_cache_layout differs. @@ -505,12 +536,14 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): # Mock TP world size to 2 to force heterogeneous TP when # remote_tp_size=1 with patch( - "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 - return_value=2): + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", # noqa: E501 + return_value=2, + ): # Initialize connector and worker (with fake NIXL wrapper) connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id, hand_shake_latency=0) + vllm_config, connector.engine_id, hand_shake_latency=0 + ) worker = connector.connector_worker # Minimal local registration params used by add_remote_agent @@ -520,8 +553,7 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): worker.dst_num_blocks[worker.engine_id] = worker.num_blocks # Metadata with different kv_cache_layout than local worker - mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \ - else "NHD" + mismatched_layout = "HND" if worker.kv_cache_layout != "HND" else "NHD" meta = NixlAgentMetadata( engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, agent_metadata=FakeNixlWrapper.AGENT_METADATA, @@ -544,16 +576,17 @@ def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init): # the rest of the tests. @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, +) def test_kv_connector_stats(dist_init): """Test that KV transfer stats are properly recorded and retrieved.""" vllm_config = create_vllm_config() # Test worker role in decode server. 
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) - connector.connector_worker = FakeNixlConnectorWorker(vllm_config, - connector.engine_id, - hand_shake_latency=0) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) # Verify that xfer_stats starts empty initial_stats = connector.get_kv_connector_stats() @@ -562,16 +595,17 @@ def test_kv_connector_stats(dist_init): # Create transfer metadata request_id = "test_req_for_stats" metadata = NixlConnectorMetadata() - metadata.add_new_req(request_id=request_id, - local_block_ids=[1, 2, 3], - kv_transfer_params={ - "remote_block_ids": [4, 5, 6], - "remote_engine_id": - FakeNixlConnectorWorker.REMOTE_ENGINE_ID, - "remote_host": "localhost", - "remote_port": 1234, - "remote_tp_size": 1, - }) + metadata.add_new_req( + request_id=request_id, + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID, + "remote_host": "localhost", + "remote_port": 1234, + "remote_tp_size": 1, + }, + ) connector.bind_connector_metadata(metadata) # Start the transfer @@ -592,8 +626,7 @@ def test_kv_connector_stats(dist_init): _, done_recving = connector.get_finished(finished_req_ids=set()) if len(done_recving) > 0 and request_id in done_recving: break - time.sleep( - 0.1) # Small delay to allow background handshake to complete + time.sleep(0.1) # Small delay to allow background handshake to complete else: assert "Transfer did not complete within expected iterations" @@ -612,7 +645,7 @@ def test_kv_connector_stats(dist_init): def test_kv_connector_stats_aggregation(): """ - Test KV transfer stats aggregation across TP ranks using + Test KV transfer stats aggregation across TP ranks using KVOutputAggregator (used by MultiprocExecutor). """ @@ -635,18 +668,16 @@ def test_kv_connector_stats_aggregation(): worker2_stats.record_transfer(stats) # Worker 3: 3 transfers - stats = get_default_xfer_telemetry(xferDurationS=2, - postDurationS=2, - totalBytes=2, - descCount=2) + stats = get_default_xfer_telemetry( + xferDurationS=2, postDurationS=2, totalBytes=2, descCount=2 + ) worker3_stats.record_transfer(stats) worker3_stats.record_transfer(stats) worker3_stats.record_transfer(stats) # Create ModelRunnerOutput instances for each worker worker_outputs = [] - for i, worker_stats in enumerate( - [worker1_stats, worker2_stats, worker3_stats]): + for i, worker_stats in enumerate([worker1_stats, worker2_stats, worker3_stats]): output = ModelRunnerOutput( req_ids=[f"req_{i}"], req_id_to_index={f"req_{i}": 0}, @@ -656,17 +687,19 @@ def test_kv_connector_stats_aggregation(): pooler_output=[None], kv_connector_output=KVConnectorOutput( finished_sending=set([f"req_{i}_send"]) - if i < 2 else None, # Workers 0,1 finished sending + if i < 2 + else None, # Workers 0,1 finished sending finished_recving=set([f"req_{i}_recv"]) - if i > 0 else None, # Workers 1,2 finished receiving + if i > 0 + else None, # Workers 1,2 finished receiving kv_connector_stats=worker_stats, - )) + ), + ) worker_outputs.append(output) # Use the real aggregation mechanism (like MultiprocExecutor.execute_model) aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) - kv_connector_stats = \ - aggregated_output.kv_connector_output.kv_connector_stats + kv_connector_stats = aggregated_output.kv_connector_output.kv_connector_stats assert isinstance(kv_connector_stats, NixlKVConnectorStats) # Number of total transfers across all workers. 
assert kv_connector_stats.num_successful_transfers == 6 @@ -690,7 +723,6 @@ def test_multi_kv_connector_stats_aggregation(): # Mock a KVConnectorStats class for testing aggregation over connectors. @dataclass class FooKVConnectorStats(KVConnectorStats): - def reset(self): self.data = {"num_foo_transfers": 0} @@ -702,15 +734,12 @@ def record_transfer(self): def is_empty(self) -> bool: return self.data["num_foo_transfers"] == 0 - def aggregate(self, - other: "FooKVConnectorStats") -> "FooKVConnectorStats": + def aggregate(self, other: "FooKVConnectorStats") -> "FooKVConnectorStats": if not other.is_empty(): - self.data["num_foo_transfers"] += other.data[ - "num_foo_transfers"] + self.data["num_foo_transfers"] += other.data["num_foo_transfers"] return self - def make_multi_stats(nixl_count: int, - foo_count: int) -> MultiKVConnectorStats: + def make_multi_stats(nixl_count: int, foo_count: int) -> MultiKVConnectorStats: data: dict[str, KVConnectorStats] = {} if nixl_count > 0: nixl_stats = NixlKVConnectorStats() @@ -746,13 +775,11 @@ def make_multi_stats(nixl_count: int, worker_outputs.append(output) aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0) - kv_connector_stats = \ - aggregated_output.kv_connector_output.kv_connector_stats + kv_connector_stats = aggregated_output.kv_connector_output.kv_connector_stats assert isinstance(kv_connector_stats, MultiKVConnectorStats) # Validate per-connector totals across workers - assert isinstance(kv_connector_stats["NixlConnector"], - NixlKVConnectorStats) + assert isinstance(kv_connector_stats["NixlConnector"], NixlKVConnectorStats) assert kv_connector_stats["NixlConnector"].num_successful_transfers == 5 assert isinstance(kv_connector_stats["FooConnector"], FooKVConnectorStats) assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6 @@ -761,11 +788,12 @@ def make_multi_stats(nixl_count: int, @pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, +) def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): """ Test lifecycle of an aborted Remote Prefill request hitting the timeout. - -----> P + -----> P | {process request} <-/--- | {result is NOT delivered, eg proxy is down} | @@ -822,39 +850,38 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int): sampling_params = SamplingParams( temperature=0.0, max_tokens=1, - extra_args={"kv_transfer_params": remote_prefill_opts}) + extra_args={"kv_transfer_params": remote_prefill_opts}, + ) scheduler = llm.llm_engine.engine_core.engine_core.scheduler req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks + 0 + ].req_to_blocks padding = "Just making this request a little longer so that we're sure " "we're not hitting the small-request lower bound beneath which we don't " "actually trigger the whole kv transfer, but rather just recompute the " "blocks on D." - _ = llm.generate([f"What is the capital of Japan? {padding}"], - sampling_params) + _ = llm.generate([f"What is the capital of Japan? {padding}"], sampling_params) # Request finished but not freed - assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks + assert "0" in scheduler.finished_req_ids and "0" in req_to_blocks # Some other request, 0 still not freed - _ = llm.generate([f"What is the capital of Italy? 
{padding}"], - sampling_params) - assert '0' in req_to_blocks - assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks + _ = llm.generate([f"What is the capital of Italy? {padding}"], sampling_params) + assert "0" in req_to_blocks + assert "1" in scheduler.finished_req_ids and "1" in req_to_blocks # Wait for timeout and trigger another scheduler loop time.sleep(timeout) - _ = llm.generate([f"What is the capital of France? {padding}"], - sampling_params) + _ = llm.generate([f"What is the capital of France? {padding}"], sampling_params) # Request-0 times out and is cleared! - assert '0' not in req_to_blocks + assert "0" not in req_to_blocks def test_register_kv_caches(dist_init): """ Test that register_kv_caches() properly calls nixl_wrapper methods with correct data. - + This test verifies: 1. nixl_wrapper.get_reg_descs() is called with caches_data containing tensor metadata @@ -865,10 +892,9 @@ def test_register_kv_caches(dist_init): vllm_config = create_vllm_config() # Create test kv cache tensors using proper backend shape - kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2, - block_size=16, - num_kv_heads=4, - head_size=64) + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16) kv_caches = { @@ -878,21 +904,30 @@ def test_register_kv_caches(dist_init): } # Store tensor info for validation - expected_tensor_size = shared_tensor[0].element_size( - ) * shared_tensor[0].numel() + expected_tensor_size = shared_tensor[0].element_size() * shared_tensor[0].numel() expected_base_addrs = [ - shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(), - unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr() + shared_tensor[0].data_ptr(), + shared_tensor[1].data_ptr(), + unique_tensor[0].data_ptr(), + unique_tensor[1].data_ptr(), ] - with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \ - patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ - patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501 - + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" + ) as mock_nixl_wrapper, + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" + ), + ): # noqa: E501 # Create connector connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector.connector_worker = FakeNixlConnectorWorker( - vllm_config, connector.engine_id, hand_shake_latency=0) + vllm_config, connector.engine_id, hand_shake_latency=0 + ) # Get the mock instance mock_wrapper_instance = mock_nixl_wrapper.return_value @@ -908,12 +943,13 @@ def test_register_kv_caches(dist_init): for i, cache_entry in enumerate(caches_data): base_addr, size, _tp_rank, _ = cache_entry - assert size == expected_tensor_size, \ - f"Entry {i}: Expected tensor size {expected_tensor_size}, " \ - f"got {size}" - assert base_addr == expected_base_addrs[i], \ - f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \ + assert size == expected_tensor_size, ( + f"Entry {i}: Expected tensor size {expected_tensor_size}, got {size}" + ) + assert base_addr == expected_base_addrs[i], ( + f"Entry {i}: Expected base address 
{expected_base_addrs[i]}, " f"got {base_addr}" + ) # Verify get_xfer_descs was called with blocks_data assert mock_wrapper_instance.get_xfer_descs.called @@ -921,16 +957,17 @@ def test_register_kv_caches(dist_init): # Validate blocks_data structure and size expected_blocks_count = 8 - assert len(blocks_data) == expected_blocks_count, \ - f"Expected {expected_blocks_count} blocks, " \ - f"got {len(blocks_data)}" + assert len(blocks_data) == expected_blocks_count, ( + f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}" + ) expected_block_len = expected_tensor_size // 2 for i, block_entry in enumerate(blocks_data): block_start_addr, block_len, tp_rank = block_entry - assert block_len == expected_block_len, \ - f"Block entry {i}: Expected block len {expected_block_len}, " \ + assert block_len == expected_block_len, ( + f"Block entry {i}: Expected block len {expected_block_len}, " f"got {block_len}" + ) class FakePlatform(Platform): @@ -939,24 +976,26 @@ class FakePlatform(Platform): @classmethod def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: """ - Returns a mapping from device_type to a tuple of supported + Returns a mapping from device_type to a tuple of supported kv_buffer_device for nixl. """ - return {'oot': ('oot', )} + return {"oot": ("oot",)} @classmethod def get_nixl_memory_type(cls) -> Optional[str]: """ Returns the nixl memory type for the current platform. """ - return 'VRAM' + return "VRAM" -@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [ - ("oot", "VRAM"), -]) -def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, - nixl_memory_type): +@pytest.mark.parametrize( + "kv_buffer_device, nixl_memory_type", + [ + ("oot", "VRAM"), + ], +) +def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, nixl_memory_type): """ Test that register_kv_caches() passes the correct memory types from the config to the nixl_wrapper. 
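For context on the size checks in test_register_kv_caches above, the sketch below walks through the same arithmetic on a toy KV cache. It assumes the (2, num_blocks, block_size, num_kv_heads, head_size) layout that FlashAttentionBackend.get_kv_cache_shape produces for the test's parameters; it is an illustration, not part of the PR's test code.

import torch

# Same parameters as the test: 2 blocks of 16 tokens, 4 KV heads, head size 64.
num_blocks, block_size, num_kv_heads, head_size = 2, 16, 4, 64
kv_cache = torch.zeros(
    2, num_blocks, block_size, num_kv_heads, head_size, dtype=torch.float16
)

# One registration region per K/V half: its byte size is element_size * numel,
# which is what the test compares against the entries passed to get_reg_descs().
region_bytes = kv_cache[0].element_size() * kv_cache[0].numel()
assert region_bytes == 2 * num_blocks * block_size * num_kv_heads * head_size  # fp16 = 2 B

# Each region is carved into num_blocks equally sized transfer descriptors,
# matching expected_block_len = expected_tensor_size // 2 in the test.
block_bytes = region_bytes // num_blocks
assert block_bytes * num_blocks == region_bytes

# With two cache tensors (shared + unique), the per-block descriptor count is
# 2 tensors * 2 halves * num_blocks = 8, matching expected_blocks_count above.
num_tensors = 2
assert num_tensors * 2 * num_blocks == 8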
@@ -965,15 +1004,30 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, # Override the default memory types in the config vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( - _NIXL_SUPPORTED_DEVICE) - _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices()) + _NIXL_SUPPORTED_DEVICE, + ) - with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \ - patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \ - patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \ - patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \ - patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE): # noqa: E501 + _NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices()) + with ( + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", + FakePlatform, + ), + patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", + _NIXL_SUPPORTED_DEVICE, + ), + ): # noqa: E501 # Create connector and replace its worker with a fake one for isolation connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) @@ -984,22 +1038,23 @@ def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device, @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", - FakeNixlWrapper) + FakeNixlWrapper, +) def test_shutdown_cleans_up_resources(dist_init): """Test that shutdown() properly cleans up all resources.""" vllm_config = create_vllm_config() - worker = NixlConnectorWorker(vllm_config, - vllm_config.kv_transfer_config.engine_id) + worker = NixlConnectorWorker(vllm_config, vllm_config.kv_transfer_config.engine_id) nixl_wrapper = worker.nixl_wrapper - with patch.object(worker, '_handshake_initiation_executor') as mock_exec, \ - patch.object(worker, '_nixl_handshake_listener_t') as mock_listener, \ - patch.object(nixl_wrapper, 'release_xfer_handle') as mock_rel_xfer, \ - patch.object(nixl_wrapper, 'release_dlist_handle') as mock_rel_dlist, \ - patch.object(nixl_wrapper, 'remove_remote_agent') as mock_rem_agent, \ - patch.object(nixl_wrapper, 'deregister_memory') as mock_dereg: - + with ( + patch.object(worker, "_handshake_initiation_executor") as mock_exec, + patch.object(worker, "_nixl_handshake_listener_t") as mock_listener, + patch.object(nixl_wrapper, "release_xfer_handle") as mock_rel_xfer, + patch.object(nixl_wrapper, "release_dlist_handle") as mock_rel_dlist, + patch.object(nixl_wrapper, "remove_remote_agent") as mock_rem_agent, + patch.object(nixl_wrapper, "deregister_memory") as mock_dereg, + ): worker._recving_transfers = {"req1": [(123, time.perf_counter())]} worker.src_xfer_side_handle = 456 worker.dst_xfer_side_handles = {"engine1": 789} @@ -1023,3 +1078,69 @@ def test_shutdown_cleans_up_resources(dist_init): assert mock_dereg.call_count == 2 mock_dereg.assert_any_call("desc1") mock_dereg.assert_any_call("desc2") + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + 
FakeNixlWrapper, +) +def test_aborted_request_removed_from_worker_in_batch(dist_init): + """ + Create and schedule a request so that P adds it to in-batch tracking via + the real scheduler, then simulate an abort (request not in next scheduler + iteration) and verify the worker no longer tracks it as in-batch. + """ + vllm_config = create_vllm_config() + + scheduler = create_scheduler(vllm_config) + # KVConnector Worker in P + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + connector.connector_worker = FakeNixlConnectorWorker( + vllm_config, connector.engine_id, hand_shake_latency=0 + ) + + # Create a request that triggers do_remote_decode so that + # the scheduler adds it to reqs_in_batch + req = create_request(request_id=1, do_remote_decode=True, max_tokens=1) + scheduler.add_request(req) + + # First scheduling pass - examinate build_connector_meta output + sched_out = scheduler.schedule() + kv_meta = sched_out.kv_connector_metadata + assert kv_meta is not None + assert isinstance(kv_meta, NixlConnectorMetadata) + assert req.request_id in kv_meta.reqs_in_batch + + #### Model Runner start #### + # Bind scheduler-produced metadata and start worker processing. + connector.bind_connector_metadata(kv_meta) + + dummy_ctx = ForwardContext( + no_compile_layers={}, + attn_metadata={}, + virtual_engine=0, + ) + connector.start_load_kv(dummy_ctx) + + # Ensure it was tracked by the worker + assert req.request_id in connector.connector_worker._reqs_to_process + + #### Model Runner end #### + + # Abort request - request_finished call in connector scheduler + scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED) + # Second scheduling pass - build metadata with aborted request + sched_out2 = scheduler.schedule() + kv_meta2 = sched_out2.kv_connector_metadata + assert kv_meta2 is not None + assert isinstance(kv_meta2, NixlConnectorMetadata) + assert req.request_id not in kv_meta2.reqs_in_batch + + # Bind empty/abort metadata and run worker step + #### Model Runner start #### + connector.bind_connector_metadata(kv_meta2) + connector.start_load_kv(dummy_ctx) + + # After abort, the worker should not keep tracking it as "in-batch" + assert req.request_id not in connector.connector_worker._reqs_to_process + #### Model Runner end #### diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index f728b25d7834..46a5c097094e 100644 --- a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -14,27 +14,42 @@ from vllm.distributed.kv_events import BlockRemoved, BlockStored from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import ( - OffloadingConnector, OffloadingConnectorMetadata) + OffloadingConnector, + OffloadingConnectorMetadata, +) from vllm.forward_context import ForwardContext from vllm.utils import sha256 -from vllm.v1.core.kv_cache_utils import (BlockHash, get_request_block_hasher, - init_none_hash) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + get_request_block_hasher, + init_none_hash, +) from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent, - OffloadingManager, PrepareStoreOutput) +from vllm.v1.kv_offload.abstract import ( + LoadStoreSpec, + OffloadingEvent, + OffloadingManager, + PrepareStoreOutput, +) from vllm.v1.kv_offload.mediums import GPULoadStoreSpec 
from vllm.v1.kv_offload.spec import OffloadingSpec -from vllm.v1.kv_offload.worker.worker import (OffloadingHandler, - TransferResult, TransferSpec) +from vllm.v1.kv_offload.worker.worker import ( + OffloadingHandler, + TransferResult, + TransferSpec, +) from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput from vllm.v1.request import Request -from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler, - create_vllm_config) +from .utils import ( + EOS_TOKEN_ID, + create_model_runner_output, + create_scheduler, + create_vllm_config, +) class MockLoadStoreSpec(LoadStoreSpec): - def __init__(self, block_hashes: Iterable[BlockHash]): self.block_hashes: list[BlockHash] = list(block_hashes) @@ -47,7 +62,6 @@ def __repr__(self) -> str: class MockOffloadingHandler(OffloadingHandler): - def __init__(self): self.completed_transfers: list[TransferResult] = [] self.completed_specs: list[TransferSpec] = [] @@ -64,14 +78,14 @@ def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: class MockOffloadingSpec(OffloadingSpec): - def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) self.manager = MagicMock(spec=OffloadingManager) self.manager.lookup.return_value = 0 - self.manager.prepare_load = lambda block_hashes: (MockLoadStoreSpec( - block_hashes)) + self.manager.prepare_load = lambda block_hashes: ( + MockLoadStoreSpec(block_hashes) + ) self.handler = MockOffloadingHandler() def get_manager(self) -> OffloadingManager: @@ -79,9 +93,7 @@ def get_manager(self) -> OffloadingManager: def get_handlers( self, _ - ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], - OffloadingHandler]]: - + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler @@ -98,35 +110,35 @@ class TransferSummary: class RequestRunner: - - def __init__(self, offloaded_block_size: int, gpu_block_size: int, - num_gpu_blocks: int): + def __init__( + self, offloaded_block_size: int, gpu_block_size: int, num_gpu_blocks: int + ): self.offloaded_block_size: int = offloaded_block_size self.gpu_block_size: int = gpu_block_size self.num_gpu_blocks: int = num_gpu_blocks self.req_id: int = -1 - vllm_config = create_vllm_config(block_size=gpu_block_size, - max_num_batched_tokens=1000) + vllm_config = create_vllm_config( + block_size=gpu_block_size, max_num_batched_tokens=1000 + ) vllm_config.kv_transfer_config = KVTransferConfig( kv_connector="OffloadingConnector", kv_role="kv_both", kv_connector_extra_config={ "spec_name": "MockOffloadingSpec", - "spec_module_path": - "tests.v1.kv_connector.unit.test_offloading_connector", + "spec_module_path": "tests.v1.kv_connector.unit.test_offloading_connector", # noqa: E501 "block_size": offloaded_block_size, - }) + }, + ) - self.scheduler: Scheduler = create_scheduler(vllm_config, - num_blocks=num_gpu_blocks) - self.worker_connector = OffloadingConnector(vllm_config, - KVConnectorRole.WORKER) + self.scheduler: Scheduler = create_scheduler( + vllm_config, num_blocks=num_gpu_blocks + ) + self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER) # register worker kv_caches to enable OffloadingWorker creations - self.worker_connector.register_kv_caches( - kv_caches={"a": torch.empty(0)}) + self.worker_connector.register_kv_caches(kv_caches={"a": torch.empty(0)}) # extract connector of scheduler scheduler_connector = self.scheduler.connector @@ -166,9 +178,9 @@ def 
__init__(self, offloaded_block_size: int, gpu_block_size: int, init_none_hash(sha256) self._block_hasher = get_request_block_hasher(gpu_block_size, sha256) - self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={}, - attn_metadata={}, - virtual_engine=0) + self._dummy_ctx: ForwardContext = ForwardContext( + no_compile_layers={}, attn_metadata={}, virtual_engine=0 + ) def new_request(self, token_ids: list[int]): assert not self.scheduler.requests @@ -189,8 +201,7 @@ def _wait_for_transfers(self): block_size_factor = self.offloaded_block_size // self.gpu_block_size while self.pending_loads_count or self.pending_stores_count: - for transfer_spec in ( - self.offloading_spec.get_completed_transfers()): + for transfer_spec in self.offloading_spec.get_completed_transfers(): src_spec, dst_spec = transfer_spec if isinstance(src_spec, GPULoadStoreSpec): @@ -207,8 +218,7 @@ def _wait_for_transfers(self): gpu_block_indices: list[int] = [] for block_id in gpu_spec.block_ids: - gpu_block_indices.append( - self.gpu_block_index[block_id.item()]) + gpu_block_indices.append(self.gpu_block_index[block_id.item()]) # list of (block_hash, sub_block_offset) offload_addresses: list[Any] = [] @@ -220,23 +230,26 @@ def _wait_for_transfers(self): assert len(gpu_block_indices) == len(offload_addresses) self.completed_stores.append( - TransferSummary(gpu_block_indices, offload_addresses)) + TransferSummary(gpu_block_indices, offload_addresses) + ) self.pending_stores_count -= 1 else: - remainder_sub_block_count = (len(offload_addresses) - - len(gpu_block_indices)) + remainder_sub_block_count = len(offload_addresses) - len( + gpu_block_indices + ) assert remainder_sub_block_count >= 0 assert remainder_sub_block_count < block_size_factor - offload_addresses = offload_addresses[ - remainder_sub_block_count:] + offload_addresses = offload_addresses[remainder_sub_block_count:] self.completed_loads.append( - TransferSummary(gpu_block_indices, offload_addresses)) + TransferSummary(gpu_block_indices, offload_addresses) + ) self.pending_loads_count -= 1 def _update_gpu_block_idx(self): - for blocks in (self.scheduler.kv_cache_manager.coordinator. 
- single_type_managers[0].req_to_blocks.values()): + for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].req_to_blocks.values(): for block_idx, block in enumerate(blocks): self.gpu_block_index[block.block_id] = block_idx @@ -259,23 +272,20 @@ def _run(self, decoded_tokens: list[int]): kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None - assert isinstance(kv_connector_metadata, - OffloadingConnectorMetadata) + assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) self.pending_loads_count += len(kv_connector_metadata.reqs_to_load) - self.pending_stores_count += len( - kv_connector_metadata.reqs_to_store) + self.pending_stores_count += len(kv_connector_metadata.reqs_to_store) - self.worker_connector.bind_connector_metadata( - kv_connector_metadata) + self.worker_connector.bind_connector_metadata(kv_connector_metadata) self.worker_connector.start_load_kv(self._dummy_ctx) if scheduler_output.total_num_scheduled_tokens > 0: self.worker_connector.wait_for_save() - finished_sending, finished_recving = ( - self.worker_connector.get_finished( - scheduler_output.finished_req_ids)) + finished_sending, finished_recving = self.worker_connector.get_finished( + scheduler_output.finished_req_ids + ) self.worker_connector.clear_connector_metadata() @@ -283,13 +293,13 @@ def _run(self, decoded_tokens: list[int]): reqs=self.scheduler.running, finished_sending=finished_sending, finished_recving=finished_recving, - token_id=token_id) + token_id=token_id, + ) if self.scheduler.running: token_id = next(tokens_iter, None) - self.scheduler.update_from_output(scheduler_output, - model_runner_output) + self.scheduler.update_from_output(scheduler_output, model_runner_output) self._wait_for_transfers() @@ -300,24 +310,24 @@ def _run(self, decoded_tokens: list[int]): while self.scheduler.requests: scheduler_output = self.scheduler.schedule() - finished_sending, finished_recving = ( - self.worker_connector.get_finished( - scheduler_output.finished_req_ids)) + finished_sending, finished_recving = self.worker_connector.get_finished( + scheduler_output.finished_req_ids + ) assert not finished_recving model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( - finished_sending=finished_sending) + finished_sending=finished_sending + ) - self.scheduler.update_from_output(scheduler_output, - model_runner_output) + self.scheduler.update_from_output(scheduler_output, model_runner_output) def run( - self, - decoded_tokens: list[int], - expected_stored_gpu_block_indexes: tuple[int, ...] = (), - expected_loaded_gpu_block_indexes: tuple[int, ...] = (), + self, + decoded_tokens: list[int], + expected_stored_gpu_block_indexes: tuple[int, ...] = (), + expected_loaded_gpu_block_indexes: tuple[int, ...] = (), ): """ Runs multiple engine (scheduler + worker) steps. 
@@ -337,23 +347,23 @@ def run( loaded_gpu_block_indexes: set[int] = set() for transfer in self.completed_loads: for gpu_block_idx, offloaded_address in zip( - transfer.gpu_block_indices, transfer.offload_addresses): + transfer.gpu_block_indices, transfer.offload_addresses + ): loaded_gpu_block_indexes.add(gpu_block_idx) assert gpu_block_idx == self.offloaded[offloaded_address] - assert ( - set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes) + assert set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes self.completed_loads.clear() stored_gpu_block_indexes: set[int] = set() for transfer in self.completed_stores: for gpu_block_idx, offloaded_address in zip( - transfer.gpu_block_indices, transfer.offload_addresses): + transfer.gpu_block_indices, transfer.offload_addresses + ): stored_gpu_block_indexes.add(gpu_block_idx) self.offloaded[offloaded_address] = gpu_block_idx - assert ( - set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes) + assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes self.completed_stores.clear() @@ -362,9 +372,11 @@ def request_runner(): runners = [] def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks): - runner = RequestRunner(offloaded_block_size=offloaded_block_size, - gpu_block_size=gpu_block_size, - num_gpu_blocks=num_gpu_blocks) + runner = RequestRunner( + offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks, + ) runners.append(runner) return runner @@ -386,15 +398,18 @@ def test_offloading_connector(request_runner): num_gpu_blocks = 100 block_size_factor = offloaded_block_size // gpu_block_size - runner = request_runner(offloaded_block_size=offloaded_block_size, - gpu_block_size=gpu_block_size, - num_gpu_blocks=num_gpu_blocks) + runner = request_runner( + offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks, + ) # 3 blocks, store just the middle block (skip first and last) # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8] runner.new_request(token_ids=[0] * offloaded_block_size * 3) - runner.manager.prepare_store.side_effect = \ + runner.manager.prepare_store.side_effect = ( lambda block_hashes: generate_store_output(list(block_hashes)[1:2]) + ) runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5)) # add block missing 1 token -> no offload @@ -402,21 +417,24 @@ def test_offloading_connector(request_runner): runner.manager.prepare_store.assert_not_called() # +1 token -> single block, fail prepare_store - runner.manager.prepare_store.side_effect = \ - lambda block_hashes: None + runner.manager.prepare_store.side_effect = lambda block_hashes: None runner.run(decoded_tokens=[0]) runner.manager.prepare_store.assert_called() # 1 more block, now set block_hashes_to_store = [] - runner.manager.prepare_store.side_effect = \ + runner.manager.prepare_store.side_effect = ( lambda block_hashes: generate_store_output([]) + ) runner.run(decoded_tokens=[0] * offloaded_block_size) # 1 more block, now check touch was called with all 6 blocks - runner.manager.prepare_store.side_effect = \ + runner.manager.prepare_store.side_effect = ( lambda block_hashes: generate_store_output(block_hashes) - runner.run(decoded_tokens=[0] * offloaded_block_size, - expected_stored_gpu_block_indexes=(15, 16, 17)) + ) + runner.run( + decoded_tokens=[0] * offloaded_block_size, + expected_stored_gpu_block_indexes=(15, 16, 17), + ) runner.manager.touch.assert_called() block_hashes1 = 
list(runner.manager.touch.call_args.args[0]) assert len(block_hashes1) == 6 @@ -426,9 +444,10 @@ def test_offloading_connector(request_runner): # create a new request differing only on the last token runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1]) - runner.run(decoded_tokens=[0], - expected_stored_gpu_block_indexes=tuple( - range(6 * block_size_factor))) + runner.run( + decoded_tokens=[0], + expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)), + ) runner.manager.touch.assert_called() block_hashes2 = list(runner.manager.touch.call_args.args[0]) assert len(block_hashes2) == 6 @@ -441,17 +460,20 @@ def test_offloading_connector(request_runner): runner.run(decoded_tokens=[EOS_TOKEN_ID]) # full_block_tokens - num_computed_tokens < offloaded_block_size - runner.new_request(token_ids=[0] * gpu_block_size + [1] * - (offloaded_block_size - gpu_block_size)) - runner.manager.prepare_store.side_effect = \ + runner.new_request( + token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size) + ) + runner.manager.prepare_store.side_effect = ( lambda block_hashes: generate_store_output([]) + ) runner.run(decoded_tokens=[EOS_TOKEN_ID]) runner.manager.lookup.assert_not_called() # single block lookup with no hits runner.new_request(token_ids=[1] * offloaded_block_size) - runner.manager.prepare_store.side_effect = \ + runner.manager.prepare_store.side_effect = ( lambda block_hashes: generate_store_output([]) + ) runner.run(decoded_tokens=[EOS_TOKEN_ID]) runner.manager.lookup.assert_called() assert len(list(runner.manager.lookup.call_args.args[0])) == 1 @@ -459,34 +481,37 @@ def test_offloading_connector(request_runner): # single block lookup with a hit runner.scheduler.reset_prefix_cache() runner.new_request(token_ids=[0] * offloaded_block_size) - runner.manager.prepare_store.side_effect = \ + runner.manager.prepare_store.side_effect = ( lambda block_hashes: generate_store_output([]) + ) runner.manager.lookup.return_value = 1 - runner.run(decoded_tokens=[EOS_TOKEN_ID], - expected_loaded_gpu_block_indexes=(0, 1, 2)) + runner.run( + decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2) + ) # single block lookup with a hit in a middle block - runner.new_request(token_ids=[0] * offloaded_block_size * 2 + - [1] * offloaded_block_size) - runner.manager.prepare_store.side_effect = \ + runner.new_request( + token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size + ) + runner.manager.prepare_store.side_effect = ( lambda block_hashes: generate_store_output([]) + ) runner.manager.lookup.return_value = 1 - runner.run(decoded_tokens=[EOS_TOKEN_ID], - expected_loaded_gpu_block_indexes=(3, 4, 5)) + runner.run( + decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5) + ) # test take_events def to_hashes(int_hashes: list[int]) -> list[BlockHash]: return [BlockHash(str(i).encode()) for i in int_hashes] def take_events() -> Iterable[OffloadingEvent]: - yield OffloadingEvent(block_hashes=to_hashes([1, 2, 3]), - block_size=16, - medium="A", - removed=False) - yield OffloadingEvent(block_hashes=to_hashes([4, 5, 6]), - block_size=32, - medium="B", - removed=True) + yield OffloadingEvent( + block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False + ) + yield OffloadingEvent( + block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True + ) runner.manager.take_events.side_effect = take_events events = list(runner.scheduler_connector.take_events()) diff --git 
a/tests/v1/kv_connector/unit/test_output_aggreagator.py b/tests/v1/kv_connector/unit/test_output_aggreagator.py index 8c85732297f2..d05cbe1a2fd4 100644 --- a/tests/v1/kv_connector/unit/test_output_aggreagator.py +++ b/tests/v1/kv_connector/unit/test_output_aggreagator.py @@ -12,22 +12,25 @@ class DummyModelRunnerOutput(ModelRunnerOutput): - - def __init__(self, - finished_sending: Optional[set[str]] = None, - finished_recving: Optional[set[str]] = None, - invalid_block_ids: Optional[set[int]] = None): + def __init__( + self, + finished_sending: Optional[set[str]] = None, + finished_recving: Optional[set[str]] = None, + invalid_block_ids: Optional[set[int]] = None, + ): self.kv_connector_output = KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, - invalid_block_ids=invalid_block_ids or set()) + invalid_block_ids=invalid_block_ids or set(), + ) def __repr__(self): return ( f"DummyModelRunnerOutput(" f"finished_sending={self.kv_connector_output.finished_sending}," f"finished_recving={self.kv_connector_output.finished_recving})" - f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})") + f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})" + ) def test_aggregate_workers_output(): @@ -44,8 +47,9 @@ def test_aggregate_workers_output(): assert aggregated.finished_recving is None assert not aggregated.invalid_block_ids - output1 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) + output1 = DummyModelRunnerOutput( + finished_sending={"req1"}, finished_recving={"req2"} + ) output2 = DummyModelRunnerOutput(invalid_block_ids={1}) aggregated = aggregator.aggregate([output1, output2]) @@ -57,26 +61,27 @@ def test_aggregate_workers_output(): assert aggregated.invalid_block_ids == {1} output1 = DummyModelRunnerOutput(invalid_block_ids={2}) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}) + output2 = DummyModelRunnerOutput(finished_sending={"req1"}) aggregated = aggregator.aggregate([output1, output2]) assert aggregated is output1 aggregated = aggregated.kv_connector_output - assert aggregated.finished_sending == {'req1'} + assert aggregated.finished_sending == {"req1"} assert aggregated.finished_recving is None assert aggregated.invalid_block_ids == {2} output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4}) - output2 = DummyModelRunnerOutput(finished_recving={'req2'}, - invalid_block_ids={4, 5}) + output2 = DummyModelRunnerOutput( + finished_recving={"req2"}, invalid_block_ids={4, 5} + ) aggregated = aggregator.aggregate([output1, output2]) assert aggregated is output1 aggregated = aggregated.kv_connector_output assert aggregated.finished_sending is None - assert aggregated.finished_recving == {'req2'} + assert aggregated.finished_recving == {"req2"} assert aggregated.invalid_block_ids == {3, 4, 5} @@ -104,8 +109,9 @@ def test_async_aggregate_workers_output(): future2 = Future() result_future = aggregator.async_aggregate([future1, future2]) - output1 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) + output1 = DummyModelRunnerOutput( + finished_sending={"req1"}, finished_recving={"req2"} + ) output2 = DummyModelRunnerOutput(invalid_block_ids={1}) future1.set_result(output1) future2.set_result(output2) @@ -123,7 +129,7 @@ def test_async_aggregate_workers_output(): result_future = aggregator.async_aggregate([future1, future2]) output1 = DummyModelRunnerOutput(invalid_block_ids={2}) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}) + output2 = 
DummyModelRunnerOutput(finished_sending={"req1"}) future1.set_result(output1) future2.set_result(output2) @@ -131,7 +137,7 @@ def test_async_aggregate_workers_output(): aggregated = result_future.result() assert aggregated is output1 aggregated = aggregated.kv_connector_output - assert aggregated.finished_sending == {'req1'} + assert aggregated.finished_sending == {"req1"} assert aggregated.finished_recving is None assert aggregated.invalid_block_ids == {2} @@ -140,8 +146,9 @@ def test_async_aggregate_workers_output(): result_future = aggregator.async_aggregate([future1, future2]) output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4}) - output2 = DummyModelRunnerOutput(finished_recving={'req2'}, - invalid_block_ids={4, 5}) + output2 = DummyModelRunnerOutput( + finished_recving={"req2"}, invalid_block_ids={4, 5} + ) future1.set_result(output1) future2.set_result(output2) @@ -150,5 +157,5 @@ def test_async_aggregate_workers_output(): assert aggregated is output1 aggregated = aggregated.kv_connector_output assert aggregated.finished_sending is None - assert aggregated.finished_recving == {'req2'} + assert aggregated.finished_recving == {"req2"} assert aggregated.invalid_block_ids == {3, 4, 5} diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index e2c4d05bba71..e0404186eb2d 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -7,8 +7,13 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput from vllm.v1.request import FinishReason, RequestStatus -from .utils import (assert_scheduler_empty, create_model_runner_output, - create_request, create_scheduler, create_vllm_config) +from .utils import ( + assert_scheduler_empty, + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) pytestmark = pytest.mark.cpu_test @@ -24,11 +29,13 @@ def test_basic_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - max_tokens=1, - num_tokens=NUM_TOKENS, - do_remote_decode=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) scheduler.add_request(request) request_id = request.request_id @@ -43,8 +50,9 @@ def test_basic_lifecycle(): model_runner_output = create_model_runner_output(reqs=[request]) # (1c): update_from_output() - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) # Ensure the request is finished after 1 token. assert request.is_finished() @@ -60,7 +68,8 @@ def test_basic_lifecycle(): # ... but blocks should not be freed. 
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_id] + 0 + ].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 @@ -92,7 +101,8 @@ def test_basic_lifecycle(): # (3b): execute_model() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( - finished_sending={request_id}) + finished_sending={request_id} + ) # (3c): update_from_output() scheduler.update_from_output(scheduler_output, model_runner_output) @@ -110,11 +120,13 @@ def test_short_prompt_lifecycle(): # Not enough tokens for full block. BLOCK_SIZE = vllm_config.cache_config.block_size NUM_TOKENS = BLOCK_SIZE // 2 - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - max_tokens=1, - num_tokens=NUM_TOKENS, - do_remote_decode=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + max_tokens=1, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) scheduler.add_request(request) @@ -132,14 +144,15 @@ def test_short_prompt_lifecycle(): eco = scheduler.update_from_output(scheduler_output, model_runner_output) kv_transfer_params = eco[0].outputs[0].kv_transfer_params - assert (len(kv_transfer_params["remote_block_ids"]) == 1) + assert len(kv_transfer_params["remote_block_ids"]) == 1 # Confirm we do not have any memory leaks after req lifecycle. # We need to mark sending finish to clear data for persistent batch. scheduler_output = scheduler.schedule() # Use create_model_runner_output to pass kv_connector_output along model_runner_output = create_model_runner_output( - reqs=[request], finished_sending={request.request_id}) + reqs=[request], finished_sending={request.request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert_scheduler_empty(scheduler) @@ -155,14 +168,15 @@ def test_prefix_cache_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 3 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_normal = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS) + request_normal = create_request( + request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS + ) scheduler.add_request(request_normal) scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_normal], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_normal], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.schedule() scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) @@ -174,10 +188,12 @@ def test_prefix_cache_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS -= 1 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_remote = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_decode=True) + request_remote = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) scheduler.add_request(request_remote) scheduler_output = scheduler.schedule() @@ -187,14 +203,13 @@ def test_prefix_cache_lifecycle(): # Ensure we send all block ids, including the partial blocks, # even if there is a cache hit. - assert (len( - kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + - 1)) + assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1) # STEP (2): Ensure it is freed. 
scheduler_output = scheduler.schedule() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( - finished_sending={request_remote.request_id}) + finished_sending={request_remote.request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 196483d76e87..b9588ebcd211 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -7,8 +7,13 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput from vllm.v1.request import FinishReason, RequestStatus -from .utils import (assert_scheduler_empty, create_model_runner_output, - create_request, create_scheduler, create_vllm_config) +from .utils import ( + assert_scheduler_empty, + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) pytestmark = pytest.mark.cpu_test @@ -24,12 +29,15 @@ def test_basic_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) START_FREE_BLOCK_QUEUE_SIZE = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks + ) - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) scheduler.add_request(request) request_id = request.request_id @@ -48,16 +56,16 @@ def test_basic_lifecycle(): # Req waiting for KVs with no computed/scheduled toks ... assert len(scheduler.waiting) == 1 assert request in scheduler.waiting - assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) - assert (request.num_computed_tokens == 0) + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert request.num_computed_tokens == 0 # ... but should have (uncached) blocks allocated to it. block_pool = scheduler.kv_cache_manager.block_pool - assert (block_pool.free_block_queue.num_free_blocks - < START_FREE_BLOCK_QUEUE_SIZE) + assert block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE assert len(block_pool.cached_block_hash_to_block) == 0 blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_id] + 0 + ].req_to_blocks[request_id] for block in blocks: assert block._block_hash is None @@ -65,8 +73,9 @@ def test_basic_lifecycle(): model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT # (1c): update_from_output() - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) assert not engine_core_outputs or not engine_core_outputs[0].outputs # STEP (2): @@ -78,13 +87,15 @@ def test_basic_lifecycle(): # (2b): forward(): request finishes recv. 
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( - finished_recving={request_id}) + finished_recving={request_id} + ) # (2c): update_from_output(): - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) assert len(scheduler.waiting) == 1 - assert (request_id in scheduler.finished_recving_kv_req_ids) + assert request_id in scheduler.finished_recving_kv_req_ids # STEP (3): # (3a): schedule(): this should actually schedule. @@ -94,10 +105,11 @@ def test_basic_lifecycle(): # Confirm the block are actually allocated. num_hashed_blocks = 0 blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_id] + 0 + ].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 - num_hashed_blocks += (1 if block._block_hash is not None else 0) + num_hashed_blocks += 1 if block._block_hash is not None else 0 assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS # Confirm the rest of the prompt is scheduled in this step. @@ -105,7 +117,7 @@ def test_basic_lifecycle(): num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] num_computed_tokens = scheduled_req.num_computed_tokens total_prompt_tokens = len(scheduled_req.prompt_token_ids) - assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + assert num_scheduled_tokens == total_prompt_tokens - num_computed_tokens # (3b): execute_model() model_runner_output = create_model_runner_output([request]) @@ -115,8 +127,9 @@ def test_basic_lifecycle(): # Step (4): Hit EOS. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output([request], use_eos=True) - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) scheduler.schedule() outputs = engine_core_outputs[0].outputs @@ -137,10 +150,12 @@ def test_interleaved_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_remote = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True) + request_remote = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) request_local_a = create_request( request_id=2, block_size=BLOCK_SIZE, @@ -169,8 +184,7 @@ def test_interleaved_lifecycle(): assert len(scheduler_output.scheduled_new_reqs) == 1 assert scheduler_output.scheduled_cached_reqs.num_reqs == 1 - model_runner_output = create_model_runner_output( - [request_local_a, request_local_b]) + model_runner_output = create_model_runner_output([request_local_a, request_local_b]) scheduler.update_from_output(scheduler_output, model_runner_output) # STEP 3: continue running, KVs not arrived yet. 
@@ -181,7 +195,8 @@ def test_interleaved_lifecycle(): assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( - reqs=[request_local_a, request_local_b]) + reqs=[request_local_a, request_local_b] + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 1 @@ -196,8 +211,8 @@ def test_interleaved_lifecycle(): assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( - [request_local_a, request_local_b], - finished_recving={request_remote.request_id}) + [request_local_a, request_local_b], finished_recving={request_remote.request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) # STEP 5: RECVed KVs are sent to ModelRunner. @@ -208,7 +223,8 @@ def test_interleaved_lifecycle(): assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 model_runner_output = create_model_runner_output( - [request_local_a, request_local_b, request_remote]) + [request_local_a, request_local_b, request_remote] + ) scheduler.update_from_output(scheduler_output, model_runner_output) # STEP 6: Hit EOS and free. @@ -273,15 +289,17 @@ def test_no_spurious_prefix_caching(): assert len(scheduler.waiting) == 1 local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_local.request_id] + 0 + ].req_to_blocks[request_local.request_id] remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_remote.request_id] + 0 + ].req_to_blocks[request_remote.request_id] # Local should have cached blocks (but not all due to preallocate). num_hashed_blocks = 0 for block in local_blocks: assert block.ref_cnt == 1 - num_hashed_blocks += (1 if block._block_hash is not None else 0) + num_hashed_blocks += 1 if block._block_hash is not None else 0 assert num_hashed_blocks > 0 # Remote blocks should not be cached. @@ -301,10 +319,12 @@ def test_full_block_prompt(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) - request = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS, - do_remote_prefill=True) + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_prefill=True, + ) scheduler.add_request(request) request_id = request.request_id @@ -312,8 +332,11 @@ def test_full_block_prompt(): # STEP (1): Initialize a recv. scheduler_output = scheduler.schedule() # All blocks should be allocated. - num_blocks = len(scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[request_id]) + num_blocks = len( + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ] + ) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT scheduler.update_from_output(scheduler_output, model_runner_output) @@ -322,22 +345,25 @@ def test_full_block_prompt(): scheduler_output = scheduler.schedule() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.kv_connector_output = KVConnectorOutput( - finished_recving={request_id}) + finished_recving={request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.waiting) == 1 - assert (request_id in scheduler.finished_recving_kv_req_ids) + assert request_id in scheduler.finished_recving_kv_req_ids # # STEP (3): Run as usual. 
scheduler_output = scheduler.schedule() # We need to recompute the final token of the prompt to generate # the first new token, so we should not have a new block. - num_blocks = len(scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[request_id]) + num_blocks = len( + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_id + ] + ) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS - assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == - NUM_TOKENS - 1) - assert (scheduler_output.num_scheduled_tokens[request_id] == 1) + assert scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1 + assert scheduler_output.num_scheduled_tokens[request_id] == 1 model_runner_output = create_model_runner_output([request]) scheduler.update_from_output(scheduler_output, model_runner_output) @@ -345,8 +371,9 @@ def test_full_block_prompt(): # # Step (4): Hit EOS. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output([request], use_eos=True) - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output( + scheduler_output, model_runner_output + ) scheduler.schedule() outputs = engine_core_outputs[0].outputs @@ -375,13 +402,15 @@ def test_cannot_schedule_after_recv(): NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) - request_normal = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS_LOCAL) - request_remote = create_request(request_id=2, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS_REMOTE, - do_remote_prefill=True) + request_normal = create_request( + request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL + ) + request_remote = create_request( + request_id=2, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_REMOTE, + do_remote_prefill=True, + ) # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode). scheduler.add_request(request_normal) @@ -402,7 +431,8 @@ def test_cannot_schedule_after_recv(): # Step 3: finish recving (5 blocks in use) scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output( - reqs=[request_normal], finished_recving={request_remote.request_id}) + reqs=[request_normal], finished_recving={request_remote.request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 @@ -411,7 +441,8 @@ def test_cannot_schedule_after_recv(): # because the transfer is completed. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output( - reqs=[request_normal, request_remote]) + reqs=[request_normal, request_remote] + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 0 @@ -426,8 +457,9 @@ def test_cannot_schedule_after_recv(): # Step 6: finish the request, free it. scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_normal], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_normal], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 @@ -436,16 +468,19 @@ def test_cannot_schedule_after_recv(): # request is retrieved from preempted list. 
scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_remote]) - assert (scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] == - NUM_PROMPT_BLOCKS * BLOCK_SIZE) + assert ( + scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] + == NUM_PROMPT_BLOCKS * BLOCK_SIZE + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 # Step 8: free everything. scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_remote], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_remote], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) @@ -470,13 +505,15 @@ def test_cannot_recv(): NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) - request_normal = create_request(request_id=1, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS_LOCAL) - request_remote = create_request(request_id=2, - block_size=BLOCK_SIZE, - num_tokens=NUM_TOKENS_REMOTE, - do_remote_prefill=True) + request_normal = create_request( + request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL + ) + request_remote = create_request( + request_id=2, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_REMOTE, + do_remote_prefill=True, + ) # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode). scheduler.add_request(request_normal) @@ -495,12 +532,13 @@ def test_cannot_recv(): assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 # Should not have KV transfer in progress. - assert (request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS) + assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS # Step 3: finish the request, free it. scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_normal], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_normal], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 @@ -511,12 +549,13 @@ def test_cannot_recv(): scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 - assert (request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS # Step 5: finish recving (5 blocks in use) scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output( - reqs=[], finished_recving={request_remote.request_id}) + reqs=[], finished_recving={request_remote.request_id} + ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 @@ -530,8 +569,9 @@ def test_cannot_recv(): # Step 7: free everything. 
scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_remote], - use_eos=True) + model_runner_output = create_model_runner_output( + reqs=[request_remote], use_eos=True + ) scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) diff --git a/tests/v1/kv_connector/unit/test_shared_storage_connector.py b/tests/v1/kv_connector/unit/test_shared_storage_connector.py index 6be261e45cb0..e7013a794a8c 100644 --- a/tests/v1/kv_connector/unit/test_shared_storage_connector.py +++ b/tests/v1/kv_connector/unit/test_shared_storage_connector.py @@ -37,16 +37,22 @@ def _list_path(path): return list(path.iterdir()) -def run_test(tmp_path, processor, llm: LLM, question: str, - image_urls: list[Image], expected_len: int, info: str): +def run_test( + tmp_path, + processor, + llm: LLM, + question: str, + image_urls: list[Image], + expected_len: int, + info: str, +): """ One individual test to process the prompt and output base on 1 set of input Then check if the length in the storage path matches the expected length `info` introduces details or purpose of the individual test """ print(f"***info: {info}***") - print( - f"**Expected storage path length after llm generate: {expected_len}**") + print(f"**Expected storage path length after llm generate: {expected_len}**") process_prompt(processor, llm, question, image_urls) print(f"Path matched expected length: {_check_path_len(tmp_path)}") @@ -54,51 +60,42 @@ def run_test(tmp_path, processor, llm: LLM, question: str, assert _check_path_len(tmp_path) == expected_len, ( f"Expect storage path length {expected_len} ;", - f"but end up {_check_path_len(tmp_path)} instead. ", f"Info: {info}") + f"but end up {_check_path_len(tmp_path)} instead. ", + f"Info: {info}", + ) -def process_prompt(processor, llm: LLM, question: str, - image_urls: list[Image]): +def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]): """ Form the prompt based on the text and image input, then llm generate output """ - placeholders = [{ - "type": "image_url", - "image_url": { - "url": f"data:image;base64,{encode_image_base64(image_pil)}" + placeholders = [ + { + "type": "image_url", + "image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"}, } - } for image_pil in image_urls] + for image_pil in image_urls + ] messages = [ - { - "role": "system", - "content": "You are a helpful assistant." 
- }, + {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": [ *placeholders, - { - "type": "text", - "text": question - }, + {"type": "text", "text": question}, ], }, ] - prompt = processor.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) outputs = llm.generate( { - "prompt": - prompt, - **({ - "multi_modal_data": { - "image": [*image_urls] - } - } if image_urls else {}) + "prompt": prompt, + **({"multi_modal_data": {"image": [*image_urls]}} if image_urls else {}), }, sampling_params=SAMPLING_PARAMS, ) @@ -114,7 +111,7 @@ def process_prompt(processor, llm: LLM, question: str, def test_shared_storage_connector_hashes(tmp_path): """ Tests that SharedStorageConnector saves KV to the storage locations - with proper hashes; that are unique for inputs with identical text but + with proper hashes; that are unique for inputs with identical text but different images (same size), or same multiple images but different orders. """ # Using tmp_path as the storage path to store KV @@ -124,7 +121,8 @@ def test_shared_storage_connector_hashes(tmp_path): kv_transfer_config = KVTransferConfig( kv_connector="SharedStorageConnector", kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": str(tmp_path)}) + kv_connector_extra_config={"shared_storage_path": str(tmp_path)}, + ) engine_args = EngineArgs( model=MODEL_NAME, @@ -157,56 +155,88 @@ def test_shared_storage_connector_hashes(tmp_path): # Prepare the input cases input_cases = [ - InputCase(text=TEXT_PROMPTS[0], - img=[image_1], - expected_len=1, - info="image_1 single input the first time."), - InputCase(text=TEXT_PROMPTS[0], - img=[image_2], - expected_len=2, - info=("image_2 single input the first time. " - "It is in same pixel size with image_1, yet it " - "should be able to form a new unique hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[image_1], - expected_len=2, - info=("image_1 single input the 2nd time. " - "It should not form another new hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[image_2], - expected_len=2, - info=("image_2 single input the 2nd time. " - "It should not form another new hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[image_1, image_2], - expected_len=3, - info="image_1 with image_2 input the first time."), - InputCase(text=TEXT_PROMPTS[0], - img=[image_2, image_1], - expected_len=4, - info="The image order is swapped. Should form new hash."), - InputCase(text=TEXT_PROMPTS[0], - img=[image_1, image_2], - expected_len=4, - info=("[image_1, image_2] input the 2nd time. " - "It should not form another new hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[image_2, image_1], - expected_len=4, - info=("[image_2, image_1] input the 2nd time. " - "It should not form another new hash.")), - InputCase(text=TEXT_PROMPTS[0], - img=[], - expected_len=5, - info="Pure text input test as a case-control"), - InputCase(text=TEXT_PROMPTS[0], - img=[], - expected_len=5, - info="Identical pure text input as a case-control"), - InputCase(text=TEXT_PROMPTS[1], - img=[], - expected_len=6, - info="Another pure text input as a case-control"), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_1], + expected_len=1, + info="image_1 single input the first time.", + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_2], + expected_len=2, + info=( + "image_2 single input the first time. 
" + "It is in same pixel size with image_1, yet it " + "should be able to form a new unique hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_1], + expected_len=2, + info=( + "image_1 single input the 2nd time. " + "It should not form another new hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_2], + expected_len=2, + info=( + "image_2 single input the 2nd time. " + "It should not form another new hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_1, image_2], + expected_len=3, + info="image_1 with image_2 input the first time.", + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_2, image_1], + expected_len=4, + info="The image order is swapped. Should form new hash.", + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_1, image_2], + expected_len=4, + info=( + "[image_1, image_2] input the 2nd time. " + "It should not form another new hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[image_2, image_1], + expected_len=4, + info=( + "[image_2, image_1] input the 2nd time. " + "It should not form another new hash." + ), + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[], + expected_len=5, + info="Pure text input test as a case-control", + ), + InputCase( + text=TEXT_PROMPTS[0], + img=[], + expected_len=5, + info="Identical pure text input as a case-control", + ), + InputCase( + text=TEXT_PROMPTS[1], + img=[], + expected_len=6, + info="Another pure text input as a case-control", + ), ] # Run tests diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 3928cdc37b9d..24c0bd51216d 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -8,19 +8,27 @@ import torch from vllm import SamplingParams -from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, - ModelConfig, SchedulerConfig, VllmConfig) -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) +from vllm.config import ( + CacheConfig, + DeviceConfig, + KVTransferConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa - SharedStorageConnector) + SharedStorageConnector, +) from vllm.utils import sha256 from vllm.v1.core.kv_cache_manager import KVCacheBlocks -from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, - init_none_hash) +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash from vllm.v1.core.sched.scheduler import Scheduler -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec) +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, +) from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request from vllm.v1.structured_output import StructuredOutputManager @@ -42,14 +50,24 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. 
- num_cached_block) == 0 + assert ( + len( + scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks + ) + == 0 + ) + assert ( + len( + scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0 + ].num_cached_block + ) + == 0 + ) num_free_blocks = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) - assert num_free_blocks == ( - scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) + scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks + ) + assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1) # NOTE(rob): just the ref count on blocks will be 0. The hash # value, etc will remain since we lazily evict for prefix cache. @@ -90,11 +108,13 @@ def create_vllm_config( kv_connector="NixlConnector", kv_role="kv_both", ) - return VllmConfig(scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - device_config=DeviceConfig("cpu")) + return VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + device_config=DeviceConfig("cpu"), + ) def create_scheduler( @@ -107,9 +127,9 @@ def create_scheduler( num_blocks=num_blocks, # A large number of blocks to hold all requests kv_cache_tensors=[], kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + KVCacheGroupSpec( + ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ) ], ) vllm_config.cache_config.num_gpu_blocks = num_blocks @@ -151,16 +171,16 @@ def create_request( if do_remote_decode: assert not do_remote_prefill - kv_transfer_params = dict(do_remote_prefill=False, - do_remote_decode=True) + kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True) elif do_remote_prefill: - kv_transfer_params = dict(do_remote_prefill=True, - do_remote_decode=False, - remote_engine_id="my-engine-id", - remote_block_ids=list( - range(num_remote_blocks)), - remote_host="my-host", - remote_port=1234) + kv_transfer_params = dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_engine_id="my-engine-id", + remote_block_ids=list(range(num_remote_blocks)), + remote_host="my-host", + remote_port=1234, + ) max_tokens = 1 if do_remote_decode else max_tokens sampling_params = SamplingParams(max_tokens=max_tokens) @@ -200,13 +220,19 @@ def create_model_runner_output( sampled_token = EOS_TOKEN_ID if use_eos else token_id sampled_token_ids = [[sampled_token] for _ in req_ids] - kv_connector_output = None if ( - finished_sending is None and finished_recving is None - and invalid_block_ids is None) else KVConnectorOutput( + kv_connector_output = ( + None + if ( + finished_sending is None + and finished_recving is None + and invalid_block_ids is None + ) + else KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, invalid_block_ids=invalid_block_ids or set(), ) + ) # Make output data structure. 
return ModelRunnerOutput( @@ -221,22 +247,30 @@ def create_model_runner_output( class TestSharedStorageConnector(SharedStorageConnector): - def __init__(self, config: VllmConfig, role): self.name = config.kv_transfer_config.kv_connector_extra_config["name"] self._connector = SharedStorageConnector(config, role) self.call_record: dict[str, int] = defaultdict(int) # Use a unique temp file per connector - self._event_file = tempfile.gettempdir( - ) + f"/connector_{self.name}-{self.role.name}_events.log" + self._event_file = ( + tempfile.gettempdir() + + f"/connector_{self.name}-{self.role.name}_events.log" + ) # Start with an empty file with open(self._event_file, "w") as _: pass def __getattribute__(self, name): - if name in ("_connector", "call_record", "name", "_event_file", - "__class__", "__dict__", "__getattribute__", - "__init__"): # avoid recursion + if name in ( + "_connector", + "call_record", + "name", + "_event_file", + "__class__", + "__dict__", + "__getattribute__", + "__init__", + ): # avoid recursion return object.__getattribute__(self, name) if not hasattr(self._connector, name): return object.__getattribute__(self, name) @@ -255,21 +289,20 @@ def wrapper(*args, **kwargs): if isinstance(arg, int): to_log.append(str(arg)) elif isinstance(arg, KVCacheBlocks): - to_log.append( - f"num_blocks={[len(b) for b in arg.blocks]}") + to_log.append(f"num_blocks={[len(b) for b in arg.blocks]}") # Log the event as a line to the file try: with open(self._event_file, "a") as f: - f.write(' '.join(to_log) + "\n") + f.write(" ".join(to_log) + "\n") except Exception as e: - print(f"[ERROR] Could not log event {name} " - f"for {self.name}: {e}") + print(f"[ERROR] Could not log event {name} for {self.name}: {e}") return attr(*args, **kwargs) return wrapper return attr -KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__, - TestSharedStorageConnector.__name__) +KVConnectorFactory.register_connector( + "TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__ +) diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py index 0edb9513e3ff..81b57f1ca0c8 100644 --- a/tests/v1/kv_offload/test_cpu_gpu.py +++ b/tests/v1/kv_offload/test_cpu_gpu.py @@ -22,7 +22,7 @@ NUM_LAYERS = [4] DTYPES = [torch.bfloat16] SEEDS = [0] -CUDA_DEVICES = ['cuda:0'] +CUDA_DEVICES = ["cuda:0"] NUM_MAPPINGS = [3] @@ -56,35 +56,35 @@ def test_transfer( current_platform.seed_everything(seed) # create per-layer GPU KV caches - attn_backends_list = [ - FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend - ] + attn_backends_list = [FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend] gpu_caches = {} attn_backends = {} for i in range(num_layers): - layer_name = f'layer {i}' + layer_name = f"layer {i}" attn_backend = attn_backends_list[i % len(attn_backends_list)] attn_backends[layer_name] = attn_backend gpu_cache_shape = attn_backend.get_kv_cache_shape( - num_gpu_blocks, gpu_block_size, num_heads, head_size) - gpu_caches[layer_name] = torch.rand(gpu_cache_shape, - dtype=dtype, - device=device) + num_gpu_blocks, gpu_block_size, num_heads, head_size + ) + gpu_caches[layer_name] = torch.rand(gpu_cache_shape, dtype=dtype, device=device) # create handler cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size - handler = CpuGpuOffloadingHandler(attn_backends=attn_backends, - gpu_block_size=gpu_block_size, - cpu_block_size=cpu_block_size, - num_cpu_blocks=num_cpu_blocks, - gpu_caches=gpu_caches) + handler = CpuGpuOffloadingHandler( + 
attn_backends=attn_backends, + gpu_block_size=gpu_block_size, + cpu_block_size=cpu_block_size, + num_cpu_blocks=num_cpu_blocks, + gpu_caches=gpu_caches, + ) # select block mappings - gpu_blocks = random.sample(range(num_gpu_blocks), - num_mappings * gpu_blocks_per_cpu_block) + gpu_blocks = random.sample( + range(num_gpu_blocks), num_mappings * gpu_blocks_per_cpu_block + ) cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings) # convert cpu blocks to gpu block size @@ -96,9 +96,10 @@ def test_transfer( # maybe skip a GPU block to test writing to the middle of a CPU block if gpu_to_cpu: - gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1:] + gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1 :] cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[ - gpu_blocks_per_cpu_block - 1:] + gpu_blocks_per_cpu_block - 1 : + ] # set transfer direction if gpu_to_cpu: @@ -124,8 +125,9 @@ def test_transfer( # build dst -> src mapping dst_to_src = {} - for src_block, dst_block in zip(src_blocks_in_gpu_block_size, - dst_blocks_in_gpu_block_size): + for src_block, dst_block in zip( + src_blocks_in_gpu_block_size, dst_blocks_in_gpu_block_size + ): dst_to_src[dst_block] = src_block # build transfer specs @@ -157,8 +159,11 @@ def test_transfer( for dst_block in range(dst_size_in_gpu_blocks): src_block_candidate = dst_to_src.get(dst_block) for src_cache, dst_cache, orig_dst_cache, kv_dim in zip( - src_kv_caches, dst_kv_caches, orig_dst_caches, - handler.kv_dim_before_num_blocks): + src_kv_caches, + dst_kv_caches, + orig_dst_caches, + handler.kv_dim_before_num_blocks, + ): if kv_dim: # iterate over key, value for i in range(2): @@ -166,12 +171,14 @@ def test_transfer( expected_value = src_cache[i][src_block_candidate] else: expected_value = orig_dst_cache[i][dst_block] - torch.testing.assert_close(dst_cache[i][dst_block].cpu(), - expected_value.cpu()) + torch.testing.assert_close( + dst_cache[i][dst_block].cpu(), expected_value.cpu() + ) else: if src_block_candidate is not None: expected_value = src_cache[src_block_candidate] else: expected_value = orig_dst_cache[dst_block] - torch.testing.assert_close(dst_cache[dst_block].cpu(), - expected_value.cpu()) + torch.testing.assert_close( + dst_cache[dst_block].cpu(), expected_value.cpu() + ) diff --git a/tests/v1/kv_offload/test_cpu_manager.py b/tests/v1/kv_offload/test_cpu_manager.py index cdee7811d85b..57884f846b51 100644 --- a/tests/v1/kv_offload/test_cpu_manager.py +++ b/tests/v1/kv_offload/test_cpu_manager.py @@ -7,8 +7,11 @@ import numpy as np from vllm.v1.core.kv_cache_utils import BlockHash -from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent, - PrepareStoreOutput) +from vllm.v1.kv_offload.abstract import ( + LoadStoreSpec, + OffloadingEvent, + PrepareStoreOutput, +) from vllm.v1.kv_offload.backends.cpu import CPUBackend from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager from vllm.v1.kv_offload.mediums import CPULoadStoreSpec @@ -26,31 +29,38 @@ def to_hashes(int_hashes: list[int]) -> list[BlockHash]: def verify_store_output( - prepare_store_output: Optional[PrepareStoreOutput], - expected_prepare_store_output: ExpectedPrepareStoreOutput): + prepare_store_output: Optional[PrepareStoreOutput], + expected_prepare_store_output: ExpectedPrepareStoreOutput, +): assert prepare_store_output is not None - assert (prepare_store_output.block_hashes_to_store == to_hashes( - expected_prepare_store_output.block_hashes_to_store)) - assert (prepare_store_output.block_hashes_evicted == to_hashes( - 
expected_prepare_store_output.block_hashes_evicted)) + assert prepare_store_output.block_hashes_to_store == to_hashes( + expected_prepare_store_output.block_hashes_to_store + ) + assert prepare_store_output.block_hashes_evicted == to_hashes( + expected_prepare_store_output.block_hashes_evicted + ) store_spec = prepare_store_output.store_spec assert isinstance(store_spec, CPULoadStoreSpec) - expected_array = np.array(expected_prepare_store_output.store_block_ids, - dtype=np.int64) + expected_array = np.array( + expected_prepare_store_output.store_block_ids, dtype=np.int64 + ) assert np.array_equal(expected_array, store_spec.block_ids) -def verify_load_output(prepare_load_output: LoadStoreSpec, - expected_prepare_load_output: list[int]): +def verify_load_output( + prepare_load_output: LoadStoreSpec, expected_prepare_load_output: list[int] +): assert isinstance(prepare_load_output, CPULoadStoreSpec) expected_array = np.array(expected_prepare_load_output, dtype=np.int64) assert np.array_equal(expected_array, prepare_load_output.block_ids) -def verify_events(events: Iterable[OffloadingEvent], - block_size: int, - expected_stores: tuple[set[int], ...] = (), - expected_evictions: tuple[set[int], ...] = ()): +def verify_events( + events: Iterable[OffloadingEvent], + block_size: int, + expected_stores: tuple[set[int], ...] = (), + expected_evictions: tuple[set[int], ...] = (), +): stores: list[set[BlockHash]] = [] evictions: list[set[BlockHash]] = [] for event in events: @@ -61,8 +71,7 @@ def verify_events(events: Iterable[OffloadingEvent], else: stores.append(set(event.block_hashes)) - def to_hash_sets( - int_sets: tuple[set[int], ...]) -> tuple[set[BlockHash], ...]: + def to_hash_sets(int_sets: tuple[set[int], ...]) -> tuple[set[BlockHash], ...]: return tuple([set(to_hashes(list(int_set))) for int_set in int_sets]) assert tuple(evictions) == to_hash_sets(expected_evictions) @@ -86,7 +95,8 @@ def test_cpu_manager(): block_hashes_to_store=[1, 2], store_block_ids=[0, 1], block_hashes_evicted=[], - )) + ), + ) # lookup [1, 2] -> not ready assert cpu_manager.lookup(to_hashes([1, 2])) == 0 @@ -96,9 +106,9 @@ def test_cpu_manager(): # complete store [1, 2] cpu_manager.complete_store(to_hashes([1, 2])) - verify_events(cpu_manager.take_events(), - block_size=block_size, - expected_stores=({1, 2}, )) + verify_events( + cpu_manager.take_events(), block_size=block_size, expected_stores=({1, 2},) + ) # lookup [1, 2] assert cpu_manager.lookup(to_hashes([1])) == 1 @@ -113,12 +123,13 @@ def test_cpu_manager(): block_hashes_to_store=[3, 4, 5], store_block_ids=[2, 3, 0], block_hashes_evicted=[1], - )) + ), + ) # verify eviction event - verify_events(cpu_manager.take_events(), - block_size=block_size, - expected_evictions=({1}, )) + verify_events( + cpu_manager.take_events(), block_size=block_size, expected_evictions=({1},) + ) # prepare store with no space assert cpu_manager.prepare_store(to_hashes([1, 6])) is None @@ -144,7 +155,8 @@ def test_cpu_manager(): block_hashes_to_store=[6, 7, 8], store_block_ids=[3, 2, 1], block_hashes_evicted=[2, 3, 4], - )) + ), + ) # complete store [6, 7, 8] cpu_manager.complete_store(to_hashes([6, 7, 8])) @@ -160,7 +172,8 @@ def test_cpu_manager(): block_hashes_to_store=[9], store_block_ids=[1], block_hashes_evicted=[8], - )) + ), + ) # complete store [7, 9] with failure cpu_manager.complete_store(to_hashes([7, 9]), success=False) @@ -169,7 +182,9 @@ def test_cpu_manager(): assert cpu_manager.lookup(to_hashes([7])) == 1 assert cpu_manager.lookup(to_hashes([9])) == 0 - 
verify_events(cpu_manager.take_events(), - block_size=block_size, - expected_stores=({3, 4, 5}, {6, 7, 8}), - expected_evictions=({2, 3, 4}, {8})) + verify_events( + cpu_manager.take_events(), + block_size=block_size, + expected_stores=({3, 4, 5}, {6, 7, 8}), + expected_evictions=({2, 3, 4}, {8}), + ) diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py index fc8ca09bea3d..0d90cc715fd4 100644 --- a/tests/v1/kv_offload/test_cpu_offloading.py +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -20,10 +20,7 @@ def test_cpu_offloading(cpu_block_size: int) -> None: kv_transfer_config = KVTransferConfig( kv_connector="OffloadingConnector", kv_role="kv_both", - kv_connector_extra_config={ - "num_cpu_blocks": 100, - "block_size": cpu_block_size - }, + kv_connector_extra_config={"num_cpu_blocks": 100, "block_size": cpu_block_size}, ) llm = LLM( diff --git a/tests/v1/kv_offload/test_worker.py b/tests/v1/kv_offload/test_worker.py index 6cf8aa0875d6..6fcd408f3c59 100644 --- a/tests/v1/kv_offload/test_worker.py +++ b/tests/v1/kv_offload/test_worker.py @@ -1,17 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.v1.kv_offload.abstract import LoadStoreSpec -from vllm.v1.kv_offload.worker.worker import (OffloadingHandler, - OffloadingWorker, TransferResult, - TransferSpec) +from vllm.v1.kv_offload.worker.worker import ( + OffloadingHandler, + OffloadingWorker, + TransferResult, + TransferSpec, +) class LoadStoreSpec1(LoadStoreSpec): - - def __init__(self, - submit_success: bool = True, - async_success: bool = True, - exception: bool = False): + def __init__( + self, + submit_success: bool = True, + async_success: bool = True, + exception: bool = False, + ): self.finished = False self.submit_success = submit_success self.async_success = async_success @@ -26,7 +30,6 @@ def __repr__(self): class LoadStoreSpec2(LoadStoreSpec): - @staticmethod def medium() -> str: return "2" @@ -36,7 +39,6 @@ def __repr__(self): class OffloadingHandler1To2(OffloadingHandler): - def __init__(self): self.transfers: dict[int, LoadStoreSpec1] = {} @@ -63,7 +65,6 @@ def get_finished(self) -> list[TransferResult]: class OffloadingHandler2To1(OffloadingHandler): - def __init__(self): self.transfers: dict[int, LoadStoreSpec1] = {} @@ -144,9 +145,9 @@ def test_offloading_worker(): assert 7 in handler2to1.transfers # verify result of 3rd and 4th transfers - assert (sorted(worker.get_finished()) == [(3, False), (4, True)]) + assert sorted(worker.get_finished()) == [(3, False), (4, True)] # complete 6th and 7th transfers src6.finished = True dst7.finished = True - assert (sorted(worker.get_finished()) == [(6, True), (7, True)]) + assert sorted(worker.get_finished()) == [(6, True), (7, True)] diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py index 43caef79b02f..34997b7e7a43 100644 --- a/tests/v1/logits_processors/test_correctness.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -10,24 +10,28 @@ import torch from tests.utils import create_new_process_for_each_test -from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits, - create_penalty_tensor, - create_prompt_tokens_tensor, - fake_apply_logitsprocs, - fake_update_logitsprocs_state) +from tests.v1.sample.utils import ( + LogitsprocsTestFakes, + create_fake_logits, + create_penalty_tensor, + create_prompt_tokens_tensor, + fake_apply_logitsprocs, + fake_update_logitsprocs_state, +) 
from vllm.config import VllmConfig from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available -# yapf: disable -from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder, - LogitBiasLogitsProcessor, - LogitsProcessor, - MinPLogitsProcessor, - MinTokensLogitsProcessor, - MoveDirectionality, - build_logitsprocs) -# yapf: enable +from vllm.v1.sample.logits_processor import ( + BatchUpdate, + BatchUpdateBuilder, + LogitBiasLogitsProcessor, + LogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + MoveDirectionality, + build_logitsprocs, +) from vllm.v1.sample.metadata import SamplingMetadata PIN_MEMORY_AVAILABLE = is_pin_memory_available() @@ -49,9 +53,10 @@ class LogitsProcsRequestParams: """Encapsulates key params for a single request in a batch. - + Params can be customized based on the enabled logitproc """ + workload_index: int logitproc_type: LogitprocType # Logitproc enabled, specified by str id out_tokens: list[int] # Output tokens required for min tokens test @@ -64,14 +69,13 @@ def __init__(self, workload_index: int, logitproc_type: LogitprocType): # Number of output tokens is randomly 0 or twice the min-tokens # threshold which will be used in testing. Output token values # don't matter *for these tests* so use 0 as a dummy value - self.out_tokens = ([0] * - (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))) + self.out_tokens = [0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)) self.prompt_tokens = [] self.params = _sampling_params_from_logitproc(logitproc_type) def __str__(self): """For debugging""" - summ = ', '.join(f'{k}={v}' for k, v in vars(self).items()) + summ = ", ".join(f"{k}={v}" for k, v in vars(self).items()) return f"MyClass({summ})" @@ -86,12 +90,13 @@ def _generate_fake_sampling_metadata( prompt_token_ids: list[list[int]] = [] for _ in range(batch_size): output_token_ids.append( - np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) + np.random.randint(0, vocab_size, size=num_output_tokens).tolist() + ) prompt_token_ids.append( - np.random.randint(0, - vocab_size, - size=np.random.randint( - 1, MAX_NUM_PROMPT_TOKENS)).tolist()) + np.random.randint( + 0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS) + ).tolist() + ) logitsprocs = build_logitsprocs( vllm_config=VllmConfig(), device=device, @@ -99,15 +104,16 @@ def _generate_fake_sampling_metadata( is_pooling_model=False, ) fake_sampling_metadata = SamplingMetadata( - temperature=torch.full((batch_size, ), 0.0), + temperature=torch.full((batch_size,), 0.0), all_greedy=True, all_random=False, top_p=None, top_k=None, generators={}, max_num_logprobs=0, - prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids, - vocab_size, device), + prompt_token_ids=create_prompt_tokens_tensor( + prompt_token_ids, vocab_size, device + ), output_token_ids=output_token_ids, frequency_penalties=create_penalty_tensor(batch_size, 0.0, device), presence_penalties=create_penalty_tensor(batch_size, 0.0, device), @@ -115,7 +121,8 @@ def _generate_fake_sampling_metadata( no_penalties=True, allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=logitsprocs) + logitsprocs=logitsprocs, + ) return fake_sampling_metadata @@ -127,15 +134,15 @@ def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes: fake_logits[i, 0] = 10.0 # High logit for first token fake_logits[i, 1:] = 1e-2 # Others remain low sampling_metadata = _generate_fake_sampling_metadata( - 
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) return LogitsprocsTestFakes( logits=fake_logits, sampling_metadata=sampling_metadata, ) -def _sampling_params_from_logitproc( - logitproc_type: LogitprocType) -> SamplingParams: +def _sampling_params_from_logitproc(logitproc_type: LogitprocType) -> SamplingParams: """Customize request SamplingParams for a specified logitproc""" # SamplingParams for req with no logitproc kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0} @@ -150,7 +157,7 @@ def _generate_mixed_logitsprocs_batch_params( ) -> list[LogitsProcsRequestParams]: """Define key params for a batch of requests with a different logitproc enabled per request. - + The batch will have `reqs_per_logitproc` repeats for all `logitsprocs_types` under test, including the case where no logitsproc is enabled. The batch is randomly shuffled. The @@ -173,7 +180,8 @@ def _generate_mixed_logitsprocs_batch_params( return [ LogitsProcsRequestParams( workload_index=idx, - logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc]) + logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc], + ) for idx, pdx in enumerate(batch_perm) ] @@ -185,10 +193,12 @@ def _raise_error_invalid( step_idx: int, err_cls: type[Exception] = ValueError, ) -> None: - raise err_cls(f"Validation failed for step={step_idx}, " - f"batch_index={batch_index}, " - f"workload_index={request_params.workload_index}, " - f"req_params={request_params}. Reason: {msg_suffix}") + raise err_cls( + f"Validation failed for step={step_idx}, " + f"batch_index={batch_index}, " + f"workload_index={request_params.workload_index}, " + f"req_params={request_params}. Reason: {msg_suffix}" + ) def _logit_bias_params(kwargs: dict) -> None: @@ -208,8 +218,7 @@ def _logit_bias_validate( ) -> None: """Validate logit bias logitproc applied correctly""" logit_bias = request_params.params.logit_bias - logits_old = ( - test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()) + logits_old = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu() logits_new = logits_new[batch_index].cpu() for token_id in range(VOCAB_SIZE): logit_old_value = logits_old[token_id] @@ -218,22 +227,28 @@ def _logit_bias_validate( bias_value = logit_bias[token_id] exp_value = bias_value + logit_old_value if logit_new_value != pytest.approx(exp_value): - _raise_error_invalid(msg_suffix=( - f"Biased token {token_id} logit value {logit_new_value} " - f"does not match expected value {exp_value} " - f"given bias {bias_value}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Biased token {token_id} logit value {logit_new_value} " + f"does not match expected value {exp_value} " + f"given bias {bias_value}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) else: if logit_new_value != pytest.approx(logit_old_value): - _raise_error_invalid(msg_suffix=( - f"Unbiased token {token_id} logit value {logit_new_value} " - f"does not match expected value {logit_old_value}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Unbiased token {token_id} logit value {logit_new_value} " + f"does not match expected value {logit_old_value}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) def _min_p_params(kwargs: dict) -> None: @@ -259,26 +274,27 @@ def 
_min_p_validate( msg_suffix="Invalid: dominant token 0 masked (-inf)", batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: if request_params.params.min_p > 0.0: # Non-dominant tokens should be masked when min_p > 0 if logits_for_token != -float("inf"): _raise_error_invalid( - msg_suffix= - f"Invalid: non-dominant token {token_id} not masked", + msg_suffix=f"Invalid: non-dominant token {token_id} not masked", batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: # No masking when min_p is 0 if logits_for_token == -float("inf"): _raise_error_invalid( - msg_suffix= - f"Invalid: token {token_id} masked when min_p=0.0", + msg_suffix=f"Invalid: token {token_id} masked when min_p=0.0", batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) def _min_tokens_params(kwargs: dict) -> None: @@ -303,7 +319,8 @@ def _min_tokens_validate( min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD ref_all_stop_token_ids = request_params.params.all_stop_token_ids mt_lp: MinTokensLogitsProcessor = next( - test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor)) + test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor) + ) assert isinstance(mt_lp, MinTokensLogitsProcessor) min_tok = mt_lp.min_toks.get(batch_index, None) @@ -312,38 +329,50 @@ def _min_tokens_validate( (_, out_tok, all_stop_token_ids) = min_tok num_out_tokens = len(out_tok) if num_out_tokens != ref_num_out_tokens: - _raise_error_invalid(msg_suffix=( - "Number of output tokens in min-token logit processor " - f"request metadata ({num_out_tokens}) does not match " - f"reference ({ref_num_out_tokens})."), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + "Number of output tokens in min-token logit processor " + f"request metadata ({num_out_tokens}) does not match " + f"reference ({ref_num_out_tokens})." + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) if ref_all_stop_token_ids != all_stop_token_ids: - _raise_error_invalid(msg_suffix=( - "Stop token ids do not match reference; all_stop_token_ids: " - f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: " - f"{sorted(ref_all_stop_token_ids)}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + "Stop token ids do not match reference; all_stop_token_ids: " + f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: " + f"{sorted(ref_all_stop_token_ids)}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) if min_reached: - _raise_error_invalid(msg_suffix=( - "Expected min-tokens request with min reached, but batch " - "index is recognized by min-tokens logits processor."), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx, - err_cls=RuntimeError) + _raise_error_invalid( + msg_suffix=( + "Expected min-tokens request with min reached, but batch " + "index is recognized by min-tokens logits processor." 
+ ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + err_cls=RuntimeError, + ) elif not min_reached: - _raise_error_invalid(msg_suffix=( - "Expected min-tokens request with min not reached, but batch " - "index is not recognized by min-tokens logits processor."), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx, - err_cls=RuntimeError) + _raise_error_invalid( + msg_suffix=( + "Expected min-tokens request with min not reached, but batch " + "index is not recognized by min-tokens logits processor." + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + err_cls=RuntimeError, + ) # Validate min-token logits for token_id in range(VOCAB_SIZE): @@ -351,21 +380,27 @@ def _min_tokens_validate( if token_id in ref_all_stop_token_ids and not min_reached: if logits_for_token != -float("inf"): _raise_error_invalid( - msg_suffix=(f"Token {token_id} is a stop token and " - "the sequence has not reached min length, " - "but the token is not masked " - f"(logit={logits_for_token})"), + msg_suffix=( + f"Token {token_id} is a stop token and " + "the sequence has not reached min length, " + "but the token is not masked " + f"(logit={logits_for_token})" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: if logits_for_token == -float("inf"): _raise_error_invalid( - msg_suffix=(f"Token {token_id} should not be masked but " - f"is (output len={ref_num_out_tokens})"), + msg_suffix=( + f"Token {token_id} should not be masked but " + f"is (output len={ref_num_out_tokens})" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) def _none_validate( @@ -377,52 +412,58 @@ def _none_validate( step_idx: int, ) -> None: """Validate that no logits processors are applied""" - logits = ( - test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()) + logits = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu() ref_logits = logits_new[batch_index] if not torch.all(ref_logits == logits): - mismatch_toks = (ref_logits - != logits).nonzero(as_tuple=True)[0].tolist() + mismatch_toks = (ref_logits != logits).nonzero(as_tuple=True)[0].tolist() mismatch_strs = [] for token in mismatch_toks: val = float(logits[token]) ref_val = float(ref_logits[token]) mismatch_strs.append(f"({token=},{val=},{ref_val=})") - _raise_error_invalid(msg_suffix=( - f"Unexpected modification of logits: {','.join(mismatch_strs)}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Unexpected modification of logits: {','.join(mismatch_strs)}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) class LogitsprocTestHelpers(NamedTuple): """Supports setting up and validating logitsprocs unit tests.""" + eval_fxn: Callable gen_request_fxn: Optional[Callable] = None logitsprocs_test_mapping = { - STR_NO_LOGITPROC: - LogitsprocTestHelpers(eval_fxn=_none_validate), - LogitBiasLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params, - eval_fxn=_logit_bias_validate), - MinPLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_min_p_params, - eval_fxn=_min_p_validate), - MinTokensLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params, - eval_fxn=_min_tokens_validate), + STR_NO_LOGITPROC: LogitsprocTestHelpers(eval_fxn=_none_validate), + LogitBiasLogitsProcessor: 
LogitsprocTestHelpers( + gen_request_fxn=_logit_bias_params, eval_fxn=_logit_bias_validate + ), + MinPLogitsProcessor: LogitsprocTestHelpers( + gen_request_fxn=_min_p_params, eval_fxn=_min_p_validate + ), + MinTokensLogitsProcessor: LogitsprocTestHelpers( + gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate + ), } def _get_test_cases() -> list[list[str]]: """Each test case is a set of logitsprocs""" logitsprocs_types = list(logitsprocs_test_mapping.keys()) - return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC] - for logitproc_type in logitsprocs_types - if logitproc_type != STR_NO_LOGITPROC - ] + [logitsprocs_types] + return ( + [[STR_NO_LOGITPROC]] + + [ + [logitproc_type, STR_NO_LOGITPROC] + for logitproc_type in logitsprocs_types + if logitproc_type != STR_NO_LOGITPROC + ] + + [logitsprocs_types] + ) def _generate_fake_step_update( @@ -440,11 +481,18 @@ def _generate_fake_step_update( # Other 50%: add a limited number of reqs (less than the number # of workload reqs remaining, less than an arbitrary max) # If no workload reqs remain: 100% of steps have 0 adds - num_step_add = random.choice([ - 0, - random.randint(1, min(max_add_remove_per_step, - workload_reqs_remaining)) - ]) if workload_reqs_remaining else 0 + num_step_add = ( + random.choice( + [ + 0, + random.randint( + 1, min(max_add_remove_per_step, workload_reqs_remaining) + ), + ] + ) + if workload_reqs_remaining + else 0 + ) # 50% of steps: remove no requests # Other 50%: remove a limited number of reqs (less than the number @@ -452,9 +500,11 @@ def _generate_fake_step_update( # If persistent batch is empty: 100% of steps have 0 removals until # more requests are added. Assume that removed requests are always # drawn from the current batch, before new adds - num_step_remove = random.choice([ - 0, random.randint(1, min(max_add_remove_per_step, batch_size)) - ]) if batch_size else 0 + num_step_remove = ( + random.choice([0, random.randint(1, min(max_add_remove_per_step, batch_size))]) + if batch_size + else 0 + ) num_step_add_replace = min(num_step_add, num_step_remove) @@ -463,23 +513,34 @@ def _generate_fake_step_update( batch_update_builder.removed_append(removal) # Get added requests from workload - for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]: + for add_req_params in workload_params[wdx : (wdx + num_step_add_replace)]: # Replace as many removed requests as possible with added requests add_remove_idx = batch_update_builder.pop_removed() batch_update_builder.added.append( - (add_remove_idx, add_req_params.params, - add_req_params.prompt_tokens, add_req_params.out_tokens)) + ( + add_remove_idx, + add_req_params.params, + add_req_params.prompt_tokens, + add_req_params.out_tokens, + ) + ) persistent_batch[add_remove_idx] = add_req_params # Append remaining added requests to end of batch - add_reqs_append = workload_params[(wdx + - num_step_add_replace):(wdx + - num_step_add)] - batch_update_builder.added.extend([ - (adx + batch_size, add_req_params.params, add_req_params.prompt_tokens, - add_req_params.out_tokens) - for adx, add_req_params in enumerate(add_reqs_append) - ]) + add_reqs_append = workload_params[ + (wdx + num_step_add_replace) : (wdx + num_step_add) + ] + batch_update_builder.added.extend( + [ + ( + adx + batch_size, + add_req_params.params, + add_req_params.prompt_tokens, + add_req_params.out_tokens, + ) + for adx, add_req_params in enumerate(add_reqs_append) + ] + ) persistent_batch.extend(add_reqs_append) pre_condense_batch_size = len(persistent_batch) wdx += 
num_step_add # Update workload offset @@ -488,8 +549,10 @@ def _generate_fake_step_update( last_nonempty_index = pre_condense_batch_size - 1 condensed_to_idxs = set() while batch_update_builder.removed: - if (last_nonempty_index in batch_update_builder.removed - or last_nonempty_index in condensed_to_idxs): + if ( + last_nonempty_index in batch_update_builder.removed + or last_nonempty_index in condensed_to_idxs + ): last_nonempty_index -= 1 continue # last_nonempty_index is the highest persistent batch index that was @@ -504,11 +567,10 @@ def _generate_fake_step_update( # move last_nonempty_index -> first_empty_index batch_update_builder.pop_removed() condensed_to_idxs.add(first_empty_index) - persistent_batch[first_empty_index] = persistent_batch[ - last_nonempty_index] + persistent_batch[first_empty_index] = persistent_batch[last_nonempty_index] batch_update_builder.moved.append( - (last_nonempty_index, first_empty_index, - MoveDirectionality.UNIDIRECTIONAL)) + (last_nonempty_index, first_empty_index, MoveDirectionality.UNIDIRECTIONAL) + ) last_nonempty_index -= 1 @@ -524,18 +586,21 @@ def _generate_fake_step_update( k = random.randint(0, condensed_batch_size // 2) idxs = list(range(condensed_batch_size)) random.shuffle(idxs) - swaps = [ - tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k) - ] - batch_update_builder.moved.extend([ - (sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps - ]) + swaps = [tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)] + batch_update_builder.moved.extend( + [(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps] + ) for adx, bdx in swaps: - persistent_batch[adx], persistent_batch[bdx] = persistent_batch[ - bdx], persistent_batch[adx] - - return (batch_update_builder.get_and_reset(condensed_batch_size), wdx, - workload_size - wdx) + persistent_batch[adx], persistent_batch[bdx] = ( + persistent_batch[bdx], + persistent_batch[adx], + ) + + return ( + batch_update_builder.get_and_reset(condensed_batch_size), + wdx, + workload_size - wdx, + ) def _assert_valid( @@ -550,8 +615,10 @@ def _assert_valid( # Trivial case of empty persistent batch assert len(persistent_batch) == 0 if logits_w_lp.shape[0] != 0: - raise ValueError("Fake persistent batch is empty but logitsprocs " - f"output batch has shape {logits_w_lp.shape}") + raise ValueError( + "Fake persistent batch is empty but logitsprocs " + f"output batch has shape {logits_w_lp.shape}" + ) return # Validate logits for each fake request @@ -560,36 +627,40 @@ def _assert_valid( # Invoke the appropriate validation function for # the logitproc employed by this request fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn - fxn(test_fakes=test_fakes, + fxn( + test_fakes=test_fakes, persistent_batch=persistent_batch, logits_new=logits_w_lp, batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) @create_new_process_for_each_test() @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC]) @pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases()) -def test_logitsprocs(device: str, reqs_per_logitproc: int, - logitsprocs_under_test: list[str]): +def test_logitsprocs( + device: str, reqs_per_logitproc: int, logitsprocs_under_test: list[str] +): random.seed(40) torch.set_default_device(device) # Define a shuffled batch of requests which individually use a different # logitproc, or no logitproc at all workload_params = 
_generate_mixed_logitsprocs_batch_params( - reqs_per_logitproc=reqs_per_logitproc, - logitsprocs_types=logitsprocs_under_test) + reqs_per_logitproc=reqs_per_logitproc, logitsprocs_types=logitsprocs_under_test + ) workload_size = len(workload_params) # Create fake test data structures for testing. test_fakes = _generate_test_fakes(workload_size, device) wdx = 0 # Next request index in workload to add - persistent_batch: list[LogitsProcsRequestParams] = [ - ] # Persistent batch state, as list of workload indices + persistent_batch: list[ + LogitsProcsRequestParams + ] = [] # Persistent batch state, as list of workload indices # Generate fake removed request indices from current persistent # batch before adds diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py index 891f55a14633..f57a21dce516 100644 --- a/tests/v1/logits_processors/test_custom_offline.py +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -7,32 +7,40 @@ import pytest from tests.utils import create_new_process_for_each_test -# yapf: disable -from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, - DUMMY_LOGITPROC_FQCN, - DUMMY_LOGITPROC_MODULE, - MAX_TOKENS, MODEL_NAME, - POOLING_MODEL_NAME, TEMP_GREEDY, - CustomLogitprocSource, - DummyLogitsProcessor, - WrappedPerReqLogitsProcessor, - dummy_module) +from tests.v1.logits_processors.utils import ( + DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, + MODEL_NAME, + POOLING_MODEL_NAME, + TEMP_GREEDY, + CustomLogitprocSource, + DummyLogitsProcessor, + WrappedPerReqLogitsProcessor, + dummy_module, + prompts, +) from tests.v1.logits_processors.utils import entry_points as fake_entry_points -from tests.v1.logits_processors.utils import prompts -# yapf: enable from vllm import LLM, SamplingParams -from vllm.v1.sample.logits_processor import (STR_POOLING_REJECTS_LOGITSPROCS, - LogitsProcessor) +from vllm.v1.sample.logits_processor import ( + STR_POOLING_REJECTS_LOGITSPROCS, + LogitsProcessor, +) # Create a mixture of requests which do and don't utilize the dummy logitproc sampling_params_list = [ - SamplingParams(temperature=TEMP_GREEDY, - max_tokens=MAX_TOKENS, - extra_args={DUMMY_LOGITPROC_ARG: 128}), + SamplingParams( + temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 128}, + ), SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), - SamplingParams(temperature=TEMP_GREEDY, - max_tokens=MAX_TOKENS, - extra_args={DUMMY_LOGITPROC_ARG: 67}), + SamplingParams( + temperature=TEMP_GREEDY, + max_tokens=MAX_TOKENS, + extra_args={DUMMY_LOGITPROC_ARG: 67}, + ), SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS), ] @@ -49,7 +57,7 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: 2. Server has *not* loaded dummy logitproc; test that all requests behave as if logitproc is *not* operating (output matches reference `LLM` output.) 
- + Args: kwargs: `LLM` constructor kwargs logitproc_loaded: server has loaded dummy logitproc if True @@ -73,7 +81,8 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: # Validate outputs for bdx, (out_lp, out_ref, params) in enumerate( - zip(outputs_logitproc, outputs_ref, sampling_params_list)): + zip(outputs_logitproc, outputs_ref, sampling_params_list) + ): lp_toks = out_lp.outputs[0].token_ids if logitproc_loaded and params.extra_args: # This request exercises custom logitproc; validate that logitproc @@ -81,8 +90,8 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: target_token = params.extra_args[DUMMY_LOGITPROC_ARG] if not all(x == target_token for x in lp_toks): raise AssertionError( - f"Request {bdx} generated {lp_toks}, should all be " - f"{target_token}") + f"Request {bdx} generated {lp_toks}, should all be {target_token}" + ) else: # This request does not exercise custom logitproc (or custom # logitproc is not enabled on this server); validate against @@ -90,16 +99,15 @@ def _run_test(kwargs: dict, logitproc_loaded: bool) -> None: ref_toks = out_ref.outputs[0].token_ids if lp_toks != ref_toks: raise AssertionError( - f"Request {bdx} generated {lp_toks}, should match " - f"{ref_toks}") + f"Request {bdx} generated {lp_toks}, should match {ref_toks}" + ) @create_new_process_for_each_test() @pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource)) -def test_custom_logitsprocs(monkeypatch, - logitproc_source: CustomLogitprocSource): +def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource): """Test offline Python interface for passing custom logitsprocs - + Construct an `LLM` instance which loads a custom logitproc that has a well-defined behavior (mask out all tokens except one `target_token`) @@ -118,7 +126,7 @@ def test_custom_logitsprocs(monkeypatch, instance output * Logitproc passed in via {entrypoint, class object, fully-qualified class name (FQCN)} - test that dummy logitproc is utilized correctly when - provided via any of these three possible sources + provided via any of these three possible sources Args: monkeypatch: for setting env vars @@ -142,6 +150,7 @@ def test_custom_logitsprocs(monkeypatch, # Scenario: vLLM loads a logitproc from a preconfigured entrypoint # To that end, mock a dummy logitproc entrypoint import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore # fork is required for workers to see entrypoint patch @@ -165,7 +174,7 @@ def test_custom_logitsprocs(monkeypatch, @create_new_process_for_each_test() def test_custom_logitsprocs_req(monkeypatch): """Test passing request-level logits processor to offline Python interface - + Wrap a request-level logits processor to create a batch level logits processor that has a well-defined behavior (mask out all tokens except one `target_token`) @@ -190,18 +199,23 @@ def test_custom_logitsprocs_req(monkeypatch): # Test that logitproc info is passed to workers monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") random.seed(40) - _run_test({"logits_processors": [WrappedPerReqLogitsProcessor]}, - logitproc_loaded=True) + _run_test( + {"logits_processors": [WrappedPerReqLogitsProcessor]}, logitproc_loaded=True + ) @create_new_process_for_each_test() -@pytest.mark.parametrize("logitproc_source", [ - CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT, - CustomLogitprocSource.LOGITPROC_SOURCE_FQCN, - CustomLogitprocSource.LOGITPROC_SOURCE_CLASS, -]) +@pytest.mark.parametrize( + "logitproc_source", + [ + 
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT, + CustomLogitprocSource.LOGITPROC_SOURCE_FQCN, + CustomLogitprocSource.LOGITPROC_SOURCE_CLASS, + ], +) def test_pooling_rejects_custom_logitsprocs( - monkeypatch, logitproc_source: CustomLogitprocSource): + monkeypatch, logitproc_source: CustomLogitprocSource +): """Validate that vLLM engine initialization properly rejects custom logitsprocs when the model is a pooling model. @@ -233,6 +247,7 @@ def test_pooling_rejects_custom_logitsprocs( # Patch in dummy logitproc entrypoint import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore # fork is required for entrypoint patch to be visible to workers, @@ -245,10 +260,8 @@ def test_pooling_rejects_custom_logitsprocs( gpu_memory_utilization=0.1, ) # Require that no logitsprocs have been loaded - assert sum([ - 1 for _ in llm.llm_engine.model_executor.driver_worker.worker. - model_runner.input_batch.logitsprocs.all - ]) == 0 + worker = llm.llm_engine.model_executor.driver_worker.worker + assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0 return kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {} diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py index a01a479e5b24..9c5b4ff0ba17 100644 --- a/tests/v1/logits_processors/test_custom_online.py +++ b/tests/v1/logits_processors/test_custom_online.py @@ -10,18 +10,18 @@ import pytest import pytest_asyncio -from tests.utils import (RemoteOpenAIServerCustom, - create_new_process_for_each_test) -# yapf: disable -from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, - DUMMY_LOGITPROC_FQCN, - DUMMY_LOGITPROC_MODULE, - MAX_TOKENS, MODEL_NAME, - TEMP_GREEDY, dummy_module) +from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_test +from tests.v1.logits_processors.utils import ( + DUMMY_LOGITPROC_ARG, + DUMMY_LOGITPROC_FQCN, + DUMMY_LOGITPROC_MODULE, + MAX_TOKENS, + MODEL_NAME, + TEMP_GREEDY, + dummy_module, + prompts, +) from tests.v1.logits_processors.utils import entry_points as fake_entry_points -from tests.v1.logits_processors.utils import prompts - -# yapf: enable def _server_with_logitproc_entrypoint( @@ -33,11 +33,12 @@ def _server_with_logitproc_entrypoint( # Patch `entry_points` to inject logitproc entrypoint import importlib.metadata + importlib.metadata.entry_points = fake_entry_points # type: ignore from vllm.entrypoints.cli import main # fork is required for workers to see entrypoint patch - os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork" if env_dict is not None: os.environ.update(env_dict) @@ -55,10 +56,11 @@ def _server_with_logitproc_module( # Patch `modules` to inject dummy logitproc module from vllm.entrypoints.cli import main + sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module # fork is required for workers to see entrypoint patch - os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork" + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork" if env_dict is not None: os.environ.update(env_dict) @@ -80,8 +82,9 @@ def default_server_args(): ] -@pytest.fixture(scope="function", - params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]]) +@pytest.fixture( + scope="function", params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]] +) def server(default_server_args, request, monkeypatch): """Consider two server configurations: (1) --logits-processors cli arg specifies dummy logits processor via fully- @@ -102,8 +105,7 
@@ def server(default_server_args, request, monkeypatch): args = default_server_args _server_fxn = _server_with_logitproc_entrypoint - with RemoteOpenAIServerCustom(MODEL_NAME, args, - _server_fxn) as remote_server: + with RemoteOpenAIServerCustom(MODEL_NAME, args, _server_fxn) as remote_server: yield remote_server @@ -133,7 +135,7 @@ async def client(server): ) async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): """Test custom logitsprocs when starting OpenAI server from CLI - + Launch vLLM OpenAI-compatible server, configured to load a custom logitproc that has a well-defined behavior (mask out all tokens except one `target_token`). @@ -157,9 +159,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): # For requests which activate the dummy logitproc, choose one of # two `target_token` values which are known not to be EOS tokens request_keyword_args["extra_body"] = { - "vllm_xargs": { - DUMMY_LOGITPROC_ARG: target_token - } + "vllm_xargs": {DUMMY_LOGITPROC_ARG: target_token} } batch = await client.completions.create( model=model_name, @@ -173,8 +173,7 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str): choices: openai.types.CompletionChoice = batch.choices toks = choices[0].logprobs.tokens if not all([x == toks[0] for x in toks]): - raise AssertionError( - f"Generated {toks} should all be {toks[0]}") + raise AssertionError(f"Generated {toks} should all be {toks[0]}") # Alternate whether to activate dummy logitproc for each request use_dummy_logitproc = not use_dummy_logitproc diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index d3b7f314da09..9a1d5505a5f9 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -10,10 +10,13 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, - AdapterLogitsProcessor, - BatchUpdate, LogitsProcessor, - RequestLogitsProcessor) +from vllm.v1.sample.logits_processor import ( + LOGITSPROCS_GROUP, + AdapterLogitsProcessor, + BatchUpdate, + LogitsProcessor, + RequestLogitsProcessor, +) from vllm.v1.sample.logits_processor.builtin import process_dict_updates logger = init_logger(__name__) @@ -30,6 +33,7 @@ class CustomLogitprocSource(Enum): """How to source a logitproc for testing purposes""" + LOGITPROC_SOURCE_NONE = auto() # No custom logitproc LOGITPROC_SOURCE_ENTRYPOINT = auto() # Via entrypoint LOGITPROC_SOURCE_FQCN = auto() # Via fully-qualified class name (FQCN) @@ -48,8 +52,9 @@ class CustomLogitprocSource(Enum): class DummyLogitsProcessor(LogitsProcessor): """Fake logit processor to support unit testing and examples""" - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): self.req_info: dict[int, int] = {} def is_argmax_invariant(self) -> bool: @@ -60,8 +65,8 @@ def update_state(self, batch_update: Optional[BatchUpdate]): process_dict_updates( self.req_info, batch_update, - lambda params, _, __: params.extra_args and - (params.extra_args.get("target_token")), + lambda params, _, __: params.extra_args + and (params.extra_args.get("target_token")), ) def apply(self, logits: torch.Tensor) -> torch.Tensor: @@ -69,16 +74,16 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits # Save target values before 
modification - cols = torch.tensor(list(self.req_info.values()), - dtype=torch.long, - device=logits.device) - rows = torch.tensor(list(self.req_info.keys()), - dtype=torch.long, - device=logits.device) + cols = torch.tensor( + list(self.req_info.values()), dtype=torch.long, device=logits.device + ) + rows = torch.tensor( + list(self.req_info.keys()), dtype=torch.long, device=logits.device + ) values_to_keep = logits[rows, cols].clone() # Mask all but target tokens - logits[rows] = float('-inf') + logits[rows] = float("-inf") logits[rows, cols] = values_to_keep return logits @@ -154,14 +159,17 @@ def new_req_logits_processor( Returns: `Callable` request logits processor, or None """ - target_token: Optional[ - Any] = params.extra_args and params.extra_args.get("target_token") + target_token: Optional[Any] = params.extra_args and params.extra_args.get( + "target_token" + ) if target_token is None: return None if not isinstance(target_token, int): logger.warning( "target_token value %s is not int; not applying logits" - " processor to request.", target_token) + " processor to request.", + target_token, + ) return None return DummyPerReqLogitsProcessor(target_token) diff --git a/tests/v1/metrics/test_engine_logger_apis.py b/tests/v1/metrics/test_engine_logger_apis.py index e6a4d0a2a2e8..bf780b1f36ad 100644 --- a/tests/v1/metrics/test_engine_logger_apis.py +++ b/tests/v1/metrics/test_engine_logger_apis.py @@ -46,23 +46,22 @@ def log_stats_enabled_engine_args(): @pytest.mark.asyncio -async def test_async_llm_replace_default_loggers( - log_stats_enabled_engine_args): +async def test_async_llm_replace_default_loggers(log_stats_enabled_engine_args): """ RayPrometheusStatLogger should replace the default PrometheusStatLogger """ - engine = AsyncLLM.from_engine_args(log_stats_enabled_engine_args, - stat_loggers=[RayPrometheusStatLogger]) - assert isinstance(engine.logger_manager.prometheus_logger, - RayPrometheusStatLogger) + engine = AsyncLLM.from_engine_args( + log_stats_enabled_engine_args, stat_loggers=[RayPrometheusStatLogger] + ) + assert isinstance(engine.logger_manager.prometheus_logger, RayPrometheusStatLogger) engine.shutdown() @pytest.mark.asyncio async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args): """ - It's still possible to use custom stat loggers exclusively by passing + It's still possible to use custom stat loggers exclusively by passing disable_log_stats=True in addition to a list of custom stat loggers. 
""" # Create engine_args with disable_log_stats=True for this test @@ -70,12 +69,14 @@ async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args): disabled_log_engine_args.disable_log_stats = True # Disable default loggers; pass custom stat logger to the constructor - engine = AsyncLLM.from_engine_args(disabled_log_engine_args, - stat_loggers=[DummyStatLogger]) + engine = AsyncLLM.from_engine_args( + disabled_log_engine_args, stat_loggers=[DummyStatLogger] + ) assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1 - assert isinstance(engine.logger_manager.per_engine_logger_dict[0][0], - DummyStatLogger) + assert isinstance( + engine.logger_manager.per_engine_logger_dict[0][0], DummyStatLogger + ) # log_stats is still True, since custom stat loggers are used assert engine.log_stats diff --git a/tests/v1/metrics/test_metrics_reader.py b/tests/v1/metrics/test_metrics_reader.py index 16bca359fc2f..1c90e6d33527 100644 --- a/tests/v1/metrics/test_metrics_reader.py +++ b/tests/v1/metrics/test_metrics_reader.py @@ -4,8 +4,13 @@ import prometheus_client import pytest -from vllm.v1.metrics.reader import (Counter, Gauge, Histogram, Vector, - get_metrics_snapshot) +from vllm.v1.metrics.reader import ( + Counter, + Gauge, + Histogram, + Vector, + get_metrics_snapshot, +) pytestmark = pytest.mark.cpu_test @@ -20,10 +25,12 @@ def test_registry(monkeypatch): @pytest.mark.parametrize("num_engines", [1, 4]) def test_gauge_metric(test_registry, num_engines): - g = prometheus_client.Gauge("vllm:test_gauge", - "Test gauge metric", - labelnames=["model", "engine_index"], - registry=test_registry) + g = prometheus_client.Gauge( + "vllm:test_gauge", + "Test gauge metric", + labelnames=["model", "engine_index"], + registry=test_registry, + ) for i in range(num_engines): g.labels(model="foo", engine_index=str(i)).set(98.5) @@ -41,10 +48,12 @@ def test_gauge_metric(test_registry, num_engines): @pytest.mark.parametrize("num_engines", [1, 4]) def test_counter_metric(test_registry, num_engines): - c = prometheus_client.Counter("vllm:test_counter", - "Test counter metric", - labelnames=["model", "engine_index"], - registry=test_registry) + c = prometheus_client.Counter( + "vllm:test_counter", + "Test counter metric", + labelnames=["model", "engine_index"], + registry=test_registry, + ) for i in range(num_engines): c.labels(model="bar", engine_index=str(i)).inc(19) @@ -62,11 +71,13 @@ def test_counter_metric(test_registry, num_engines): @pytest.mark.parametrize("num_engines", [1, 4]) def test_histogram_metric(test_registry, num_engines): - h = prometheus_client.Histogram("vllm:test_histogram", - "Test histogram metric", - labelnames=["model", "engine_index"], - buckets=[10, 20, 30, 40, 50], - registry=test_registry) + h = prometheus_client.Histogram( + "vllm:test_histogram", + "Test histogram metric", + labelnames=["model", "engine_index"], + buckets=[10, 20, 30, 40, 50], + registry=test_registry, + ) for i in range(num_engines): hist = h.labels(model="blaa", engine_index=str(i)) hist.observe(42) @@ -97,7 +108,8 @@ def test_vector_metric(test_registry, num_engines): "vllm:spec_decode_num_accepted_tokens_per_pos", "Vector-like counter metric", labelnames=["position", "model", "engine_index"], - registry=test_registry) + registry=test_registry, + ) for i in range(num_engines): c.labels(position="0", model="llama", engine_index=str(i)).inc(10) c.labels(position="1", model="llama", engine_index=str(i)).inc(5) diff --git a/tests/v1/metrics/test_ray_metrics.py 
b/tests/v1/metrics/test_ray_metrics.py index 0c9f83f049e4..c844330bb466 100644 --- a/tests/v1/metrics/test_ray_metrics.py +++ b/tests/v1/metrics/test_ray_metrics.py @@ -8,8 +8,7 @@ from vllm.config import ModelDType from vllm.sampling_params import SamplingParams from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM -from vllm.v1.metrics.ray_wrappers import (RayPrometheusMetric, - RayPrometheusStatLogger) +from vllm.v1.metrics.ray_wrappers import RayPrometheusMetric, RayPrometheusStatLogger @pytest.fixture(scope="function", autouse=True) @@ -17,7 +16,7 @@ def use_v1_only(monkeypatch): """ The change relies on V1 APIs, so set VLLM_USE_V1=1. """ - monkeypatch.setenv('VLLM_USE_V1', '1') + monkeypatch.setenv("VLLM_USE_V1", "1") MODELS = [ @@ -34,24 +33,23 @@ def test_engine_log_metrics_ray( dtype: ModelDType, max_tokens: int, ) -> None: - """ Simple smoke test, verifying this can be used without exceptions. + """Simple smoke test, verifying this can be used without exceptions. Need to start a Ray cluster in order to verify outputs.""" @ray.remote(num_gpus=1) class EngineTestActor: - async def run(self): # Set environment variable inside the Ray actor since environment # variables from pytest fixtures don't propagate to Ray actors - os.environ['VLLM_USE_V1'] = '1' + os.environ["VLLM_USE_V1"] = "1" - engine_args = AsyncEngineArgs(model=model, - dtype=dtype, - disable_log_stats=False, - enforce_eager=True) + engine_args = AsyncEngineArgs( + model=model, dtype=dtype, disable_log_stats=False, enforce_eager=True + ) engine = AsyncLLM.from_engine_args( - engine_args, stat_loggers=[RayPrometheusStatLogger]) + engine_args, stat_loggers=[RayPrometheusStatLogger] + ) for i, prompt in enumerate(example_prompts): results = engine.generate( @@ -73,32 +71,40 @@ def test_sanitized_opentelemetry_name(): # Only a-z, A-Z, 0-9, _, test valid characters are preserved valid_name = "valid_metric_123_abcDEF" - assert RayPrometheusMetric._get_sanitized_opentelemetry_name( - valid_name) == valid_name + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(valid_name) == valid_name + ) # Test dash, dot, are replaced name_with_dash_dot = "metric-name.test" expected = "metric_name_test" - assert RayPrometheusMetric._get_sanitized_opentelemetry_name( - name_with_dash_dot) == expected + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(name_with_dash_dot) + == expected + ) # Test colon is replaced with underscore name_with_colon = "metric:name" expected = "metric_name" - assert RayPrometheusMetric._get_sanitized_opentelemetry_name( - name_with_colon) == expected + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(name_with_colon) + == expected + ) # Test multiple invalid characters are replaced name_with_invalid = "metric:name@with#special%chars" expected = "metric_name_with_special_chars" - assert RayPrometheusMetric._get_sanitized_opentelemetry_name( - name_with_invalid) == expected + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(name_with_invalid) + == expected + ) # Test mixed valid and invalid characters complex_name = "vllm:engine_stats/time.latency_ms-99p" expected = "vllm_engine_stats_time_latency_ms_99p" - assert RayPrometheusMetric._get_sanitized_opentelemetry_name( - complex_name) == expected + assert ( + RayPrometheusMetric._get_sanitized_opentelemetry_name(complex_name) == expected + ) # Test empty string assert RayPrometheusMetric._get_sanitized_opentelemetry_name("") == "" diff --git a/tests/v1/metrics/test_stats.py 
b/tests/v1/metrics/test_stats.py new file mode 100644 index 000000000000..67a2d1739b6b --- /dev/null +++ b/tests/v1/metrics/test_stats.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.v1.metrics.stats import IterationStats + + +def test_iteration_stats_repr(): + iteration_stats = IterationStats() + iteration_stats.iteration_timestamp = 0 + expected_repr = ( + "IterationStats(" + "iteration_timestamp=0, " + "num_generation_tokens=0, " + "num_prompt_tokens=0, " + "num_preempted_reqs=0, " + "finished_requests=[], " + "max_num_generation_tokens_iter=[], " + "n_params_iter=[], " + "time_to_first_tokens_iter=[], " + "inter_token_latencies_iter=[], " + "waiting_lora_adapters={}, " + "running_lora_adapters={})" + ) + assert repr(iteration_stats) == expected_repr diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 71aa9e3d379c..f83bc90778b0 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -9,9 +9,12 @@ import torch from tests.v1.sample.utils import ( - BatchLogprobsComposition, BatchLogprobsSpecType, + BatchLogprobsComposition, + BatchLogprobsSpecType, assert_incr_detok_str_matches_non_incr_detok_str, - compute_correct_cumulative_logprob, get_test_batch) + compute_correct_cumulative_logprob, + get_test_batch, +) from vllm import SamplingParams from vllm.config import LogprobsMode @@ -29,22 +32,23 @@ @pytest.fixture( scope="module", # Parameterize APC - params=[False, True]) + params=[False, True], +) def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]: with vllm_runner( - MODEL, - dtype=DTYPE, - max_logprobs=7, - # Very small number of batched tokens to ensure - # that we test chunking. - max_num_batched_tokens=16, - max_num_seqs=16, - max_model_len=128, - enforce_eager=True, - #TODO: enable this once we support it for - # prompt logprobs. - enable_prefix_caching=request.param, - gpu_memory_utilization=0.4, # up to 2 alive concurrently + MODEL, + dtype=DTYPE, + max_logprobs=7, + # Very small number of batched tokens to ensure + # that we test chunking. + max_num_batched_tokens=16, + max_num_seqs=16, + max_model_len=128, + enforce_eager=True, + # TODO: enable this once we support it for + # prompt logprobs. 
+ enable_prefix_caching=request.param, + gpu_memory_utilization=0.4, # up to 2 alive concurrently ) as vllm_model: yield vllm_model @@ -96,8 +100,8 @@ def _repeat_logprob_config( num_test_prompts = len(test_prompts) # Make sure there is a logprobs configuration for each test prompt logprob_prompt_logprob_list = list( - itertools.islice(itertools.cycle(logprob_prompt_logprob_list), - num_test_prompts)) + itertools.islice(itertools.cycle(logprob_prompt_logprob_list), num_test_prompts) + ) # Now the number of prompts should match the number of sample params combos assert num_test_prompts == len(logprob_prompt_logprob_list) return logprob_prompt_logprob_list @@ -115,24 +119,28 @@ def _run_and_validate( do_apc: bool, ) -> None: vllm_results = vllm_model.llm.generate( - test_prompts, sampling_params=vllm_sampling_params) + test_prompts, sampling_params=vllm_sampling_params + ) for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip( - vllm_results, hf_logprobs, hf_outputs, - logprob_prompt_logprob_list): - + vllm_results, hf_logprobs, hf_outputs, logprob_prompt_logprob_list + ): # Extract request-level (prompt)logprobs config num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob # Test whether sampled token output is consistent between vLLM and HF # vLLM prompt+completion should match HF output if temperature == 0.0: - assert (vllm_result.prompt_token_ids + - vllm_result.outputs[0].token_ids == hf_output[0]) + assert ( + vllm_result.prompt_token_ids + vllm_result.outputs[0].token_ids + == hf_output[0] + ) else: # Sampled tokens won't match if not greedy - assert (vllm_result.prompt_token_ids == hf_output[0] - [:len(vllm_result.prompt_token_ids)]) + assert ( + vllm_result.prompt_token_ids + == hf_output[0][: len(vllm_result.prompt_token_ids)] + ) # Validate sample logprobs if num_top_logprobs is not None: @@ -141,8 +149,9 @@ def _run_and_validate( # correct assert vllm_result.outputs[0].logprobs is not None assert len(vllm_result.outputs[0].logprobs) == max_tokens - for logprobs, token_id in zip(vllm_result.outputs[0].logprobs, - vllm_result.outputs[0].token_ids): + for logprobs, token_id in zip( + vllm_result.outputs[0].logprobs, vllm_result.outputs[0].token_ids + ): assert logprobs is not None # Confirm that the output token appears among the logprobs @@ -159,23 +168,26 @@ def _run_and_validate( if num_top_logprobs > 0: # We should have an entry for each of the topk ranks all_ranks = {lp.rank for lp in logprobs.values()} - assert all(r in all_ranks - for r in range(1, num_top_logprobs + 1)) + assert all(r in all_ranks for r in range(1, num_top_logprobs + 1)) output_text = vllm_result.outputs[0].text output_string_from_most_likely_tokens_lst: list[str] = [] for top_logprobs in vllm_result.outputs[0].logprobs: top_logprob = next(iter(top_logprobs.values())) output_string_from_most_likely_tokens_lst.append( - top_logprob.decoded_token) + top_logprob.decoded_token + ) output_string_from_most_likely_tokens = "".join( - output_string_from_most_likely_tokens_lst) + output_string_from_most_likely_tokens_lst + ) assert_incr_detok_str_matches_non_incr_detok_str( - output_text, output_string_from_most_likely_tokens, + output_text, + output_string_from_most_likely_tokens, "The output text from the top logprob for each token " "position should be the same as the output text in the " - "result.") + "result.", + ) # Compare vLLM sample logprobs to HF vllm_sample_logprobs = vllm_result.outputs[0].logprobs @@ -187,11 +199,12 @@ def _run_and_validate( logprob, 
hf_logprob[i][-1][token_id].item(), atol=1e-2, - rtol=1e-2) - assert isinstance( - sample_logprob.decoded_token, - str), ("The token should be decoded by the time it is" - " returned to the user.") + rtol=1e-2, + ) + assert isinstance(sample_logprob.decoded_token, str), ( + "The token should be decoded by the time it is" + " returned to the user." + ) # At this point we know the sample logprobs are correct for this # request. Validate that cumulative_logprob is actually the sum. @@ -201,7 +214,8 @@ def _run_and_validate( vllm_result.outputs[0].cumulative_logprob, compute_correct_cumulative_logprob(vllm_result.outputs[0]), atol=1e-6, - rtol=1e-6) + rtol=1e-6, + ) else: # Logprobs disabled for this request; should be None assert vllm_result.outputs[0].logprobs is None @@ -214,17 +228,17 @@ def _run_and_validate( assert vllm_result.prompt_logprobs[0] is None # - Prompt logprobs are returned for all indices in # the prompt - assert len(vllm_result.prompt_logprobs) == len( - vllm_result.prompt_token_ids) + assert len(vllm_result.prompt_logprobs) == len(vllm_result.prompt_token_ids) for prompt_logprobs, prompt_token_id in zip( - vllm_result.prompt_logprobs[1:], - vllm_result.prompt_token_ids[1:]): + vllm_result.prompt_logprobs[1:], vllm_result.prompt_token_ids[1:] + ): assert prompt_logprobs is not None # Confirm that the prompt token appears among the logprobs assert prompt_token_id in prompt_logprobs - token_in_topk = prompt_logprobs[ - prompt_token_id].rank <= num_top_prompt_logprobs + token_in_topk = ( + prompt_logprobs[prompt_token_id].rank <= num_top_prompt_logprobs + ) # If the prompt token is not included in the top K # logprob, it can return 1 more data @@ -236,8 +250,9 @@ def _run_and_validate( if num_top_prompt_logprobs > 0: # We should have an entry for each of the topk ranks all_ranks = {lp.rank for lp in prompt_logprobs.values()} - assert all(r in all_ranks - for r in range(1, num_top_prompt_logprobs + 1)) + assert all( + r in all_ranks for r in range(1, num_top_prompt_logprobs + 1) + ) # Compare prompt logprobs to HF # The first prompt logprob is always None, so we compare it from @@ -249,19 +264,24 @@ def _run_and_validate( logprob.logprob, hf_logprob[0][i][token_id].item(), atol=2e-2, - rtol=2e-2) + rtol=2e-2, + ) else: assert vllm_result.prompt_logprobs is None -@pytest.mark.parametrize("batch_logprobs_composition", - [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT]) +@pytest.mark.parametrize( + "batch_logprobs_composition", [NONE, SAMPLE, PROMPT, SAMPLE_PROMPT] +) @pytest.mark.parametrize("temperature", [0.0, 2.0]) def test_get_logprobs_and_prompt_logprobs( - hf_model, vllm_model, - batch_logprobs_composition: BatchLogprobsComposition, - temperature: float, example_prompts: list[str], - monkeypatch: pytest.MonkeyPatch) -> None: + hf_model, + vllm_model, + batch_logprobs_composition: BatchLogprobsComposition, + temperature: float, + example_prompts: list[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: """Test V1 Engine logprobs & prompt logprobs Exercise a variety of combinations of `logprobs` and `prompt_logprobs` @@ -291,8 +311,9 @@ def test_get_logprobs_and_prompt_logprobs( with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching - if do_apc and (temperature < 2.0 - or batch_logprobs_composition != SAMPLE_PROMPT): + if do_apc and ( + temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT + ): # Skip some test-cases to save time. 
pytest.skip() test_prompts = example_prompts @@ -309,19 +330,21 @@ def test_get_logprobs_and_prompt_logprobs( # Batch has mixed sample params # (different logprobs/prompt logprobs combos) - logprob_prompt_logprob_list = get_test_batch( - batch_logprobs_composition) + logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition) # Ensure that each test prompt has a logprob config for testing logprob_prompt_logprob_list = _repeat_logprob_config( - test_prompts, logprob_prompt_logprob_list) + test_prompts, logprob_prompt_logprob_list + ) # Generate SamplingParams vllm_sampling_params = [ - SamplingParams(max_tokens=max_tokens, - logprobs=num_lp, - prompt_logprobs=num_plp, - temperature=temperature, - seed=1984) + SamplingParams( + max_tokens=max_tokens, + logprobs=num_lp, + prompt_logprobs=num_plp, + temperature=temperature, + seed=1984, + ) for num_lp, num_plp in logprob_prompt_logprob_list ] for _ in range(2 if do_apc else 1): @@ -334,7 +357,8 @@ def test_get_logprobs_and_prompt_logprobs( logprob_prompt_logprob_list=logprob_prompt_logprob_list, temperature=temperature, max_tokens=max_tokens, - do_apc=do_apc) + do_apc=do_apc, + ) def test_max_logprobs(monkeypatch: pytest.MonkeyPatch): @@ -351,19 +375,18 @@ def test_max_logprobs(monkeypatch: pytest.MonkeyPatch): enable_prefix_caching=False, # 2 other llms alive during whole session gpu_memory_utilization=0.15, - max_model_len=256) + max_model_len=256, + ) vllm_sampling_params = SamplingParams(logprobs=1) # should pass runner.generate(["Hello world"], sampling_params=vllm_sampling_params) bad_sampling_params = SamplingParams(logprobs=2) with pytest.raises(ValueError): - runner.generate(["Hello world"], - sampling_params=bad_sampling_params) + runner.generate(["Hello world"], sampling_params=bad_sampling_params) -def test_none_logprobs(vllm_model, example_prompts, - monkeypatch: pytest.MonkeyPatch): +def test_none_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch): """Engine should return `logprobs` and `prompt_logprobs` as `None` Args: @@ -388,14 +411,12 @@ def test_none_logprobs(vllm_model, example_prompts, for i in range(len(results_logprobs_none)): # Check sample logprobs are None assert results_logprobs_none[i].outputs[0].logprobs is None - assert results_logprobs_none[i].outputs[ - 0].cumulative_logprob is None + assert results_logprobs_none[i].outputs[0].cumulative_logprob is None # Check prompt logprobs are None assert results_logprobs_none[i].prompt_logprobs is None -def test_zero_logprobs(vllm_model, example_prompts, - monkeypatch: pytest.MonkeyPatch): +def test_zero_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch): """Engine should return sampled token and prompt token logprobs Args: @@ -406,12 +427,12 @@ def test_zero_logprobs(vllm_model, example_prompts, m.setenv("VLLM_USE_V1", "1") max_tokens = 5 - sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens, - logprobs=0, - prompt_logprobs=0, - temperature=0.0) + sampling_params_logprobs_zero = SamplingParams( + max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0 + ) results_logprobs_zero = vllm_model.llm.generate( - example_prompts, sampling_params=sampling_params_logprobs_zero) + example_prompts, sampling_params=sampling_params_logprobs_zero + ) for i in range(len(results_logprobs_zero)): # Check that there is one sample logprob dict for each @@ -422,8 +443,7 @@ def test_zero_logprobs(vllm_model, example_prompts, prompt_token_ids = results_logprobs_zero[i].prompt_token_ids assert logprobs is not None 
assert len(sampled_token_ids) == len(logprobs) - assert results_logprobs_zero[i].outputs[ - 0].cumulative_logprob is not None + assert results_logprobs_zero[i].outputs[0].cumulative_logprob is not None # Check that there is one prompt logprob dict for each # prompt token assert prompt_logprobs is not None @@ -444,13 +464,15 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): enable_prefix_caching=False, # 2 other llms alive during whole session gpu_memory_utilization=0.15, - max_model_len=256) + max_model_len=256, + ) - sampling_params_logprobs_all = SamplingParams(max_tokens=5, - logprobs=-1, - prompt_logprobs=-1) + sampling_params_logprobs_all = SamplingParams( + max_tokens=5, logprobs=-1, prompt_logprobs=-1 + ) results_logprobs_all = runner.llm.generate( - example_prompts, sampling_params=sampling_params_logprobs_all) + example_prompts, sampling_params=sampling_params_logprobs_all + ) vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size() for i in range(len(results_logprobs_all)): @@ -466,13 +488,13 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch): @pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode)) -def test_logprobs_mode(logprobs_mode: LogprobsMode, - monkeypatch: pytest.MonkeyPatch): +def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch): """Test with LLM engine with different logprobs_mode. For logprobs, we should have non-positive values. For logits, we should expect at least one positive values. """ from vllm import LLM + with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -483,10 +505,10 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode, # 2 other llms alive during whole session gpu_memory_utilization=0.05, max_model_len=16, - logprobs_mode=logprobs_mode) + logprobs_mode=logprobs_mode, + ) vllm_sampling_params = SamplingParams(logprobs=1) - results = llm.generate(["Hello world"], - sampling_params=vllm_sampling_params) + results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params) total_token_with_logprobs = 0 positive_values = 0 diff --git a/tests/v1/sample/test_logprobs_e2e.py b/tests/v1/sample/test_logprobs_e2e.py index 7f41355ff7ce..b3233e50fbf1 100644 --- a/tests/v1/sample/test_logprobs_e2e.py +++ b/tests/v1/sample/test_logprobs_e2e.py @@ -15,22 +15,23 @@ MODEL = "meta-llama/Llama-3.2-1B-Instruct" MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False,gpu_memory_utilization=0.8" # noqa: E501 SERVER_ARGS = [ - "--enforce_eager", "--no_enable_prefix_caching", - "--gpu-memory-utilization=0.8" + "--enforce_eager", + "--no_enable_prefix_caching", + "--gpu-memory-utilization=0.8", ] NUM_CONCURRENT = 100 def test_prompt_logprobs_e2e(): - results = lm_eval.simple_evaluate(model="vllm", - model_args=MODEL_ARGS, - tasks=TASK, - batch_size="auto") + results = lm_eval.simple_evaluate( + model="vllm", model_args=MODEL_ARGS, tasks=TASK, batch_size="auto" + ) measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" def test_prompt_logprobs_e2e_server(): @@ -40,7 +41,8 @@ def test_prompt_logprobs_e2e_server(): model_args = ( f"model={MODEL}," f"base_url={url}," - 
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False") + f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False" + ) results = lm_eval.simple_evaluate( model="local-completions", @@ -49,6 +51,7 @@ def test_prompt_logprobs_e2e_server(): ) measured_value = results["results"][TASK][FILTER] - assert (measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + assert ( + measured_value - RTOL < EXPECTED_VALUE + and measured_value + RTOL > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 4e912f98f376..36e2e2698810 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -9,8 +9,7 @@ from vllm.platforms import current_platform from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, - RejectionSampler) +from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata DEVICE = current_platform.device_type @@ -21,10 +20,11 @@ def rejection_sampler(): return RejectionSampler() -def create_logits_tensor(output_token_ids: list[list[int]], - vocab_size: int = 100) -> torch.Tensor: +def create_logits_tensor( + output_token_ids: list[list[int]], vocab_size: int = 100 +) -> torch.Tensor: """Helper function to create logits tensor that - will produce desired token ids on argmax""" + will produce desired token ids on argmax""" token_ids = [tokens[:-1] for tokens in output_token_ids] num_total_tokens = sum(len(tokens) for tokens in token_ids) logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE) @@ -44,8 +44,8 @@ def create_sampling_metadata( generators: Optional[dict[int, Any]] = None, ) -> SamplingMetadata: """Create a v1 sampling metadata object with all_greedy set - to the given value. Either all greedy or all random sampling - is used. + to the given value. Either all greedy or all random sampling + is used. 
""" generators = generators or {} if all_greedy: @@ -81,10 +81,10 @@ def test_perfect_match(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = torch.tensor([output_tokens[0][-1]], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -93,9 +93,7 @@ def test_perfect_match(rejection_sampler): bonus_token_ids=bonus_token_tensor, sampling_metadata=metadata, ) - expected = torch.tensor([[1, 2, 3, 4]], - dtype=torch.int, - device=logits.device) + expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) @@ -106,10 +104,10 @@ def test_early_mismatch(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = torch.tensor([output_tokens[0][-1]], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -129,15 +127,16 @@ def test_early_mismatch(rejection_sampler): def test_multiple_sequences(rejection_sampler): """Test handling multiple sequences of speculated tokens""" spec_tokens = [[1, 2], [3]] - output_tokens = [[1, 2, 5], [3, - 4]] # Two sequences with bonus tokens 5 and 4 + output_tokens = [[1, 2, 5], [3, 4]] # Two sequences with bonus tokens 5 and 4 metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor( - [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -146,9 +145,9 @@ def test_multiple_sequences(rejection_sampler): bonus_token_ids=bonus_token_tensor, sampling_metadata=metadata, ) - expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], - dtype=torch.int, - device=logits.device) + expected = torch.tensor( + [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device + ) assert torch.equal(output, expected) @@ -159,10 +158,10 @@ def test_single_token_sequence(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = torch.tensor([output_tokens[0][-1]], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -182,10 +181,10 @@ def test_empty_sequence(rejection_sampler): metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = 
torch.tensor([output_tokens[0][-1]], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -201,15 +200,16 @@ def test_empty_sequence(rejection_sampler): def test_multiple_mismatches(rejection_sampler): """Test handling multiple sequences with mismatches""" spec_tokens = [[1, 2, 3], [4, 5, 6]] - output_tokens = [[1, 2, 7, 6], [4, 8, 6, - 9]] # Mismatches in both sequences + output_tokens = [[1, 2, 7, 6], [4, 8, 6, 9]] # Mismatches in both sequences metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor( - [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -219,8 +219,10 @@ def test_multiple_mismatches(rejection_sampler): sampling_metadata=metadata, ) expected = torch.tensor( - [[1, 2, 7, PLACEHOLDER_TOKEN_ID], - [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], + [ + [1, 2, 7, PLACEHOLDER_TOKEN_ID], + [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID], + ], dtype=torch.int, device=logits.device, ) @@ -232,18 +234,23 @@ def test_multiple_mismatches(rejection_sampler): [ ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch - ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]], - [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches - ]) -def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, - expected): + ( + [[1, 2], [3, 4]], + [[1, 5, 6], [3, 4, 7]], + [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]], + ), # Mixed matches + ], +) +def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expected): """Parametrized test for various matching scenarios""" metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) - bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], - device=logits.device) - spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, - device=logits.device) + bonus_token_tensor = torch.tensor( + [tokens[-1] for tokens in output_tokens], device=logits.device + ) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + spec_tokens, device=logits.device + ) output = rejection_sampler( spec_decode_metadata, @@ -252,9 +259,7 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, bonus_token_ids=bonus_token_tensor, sampling_metadata=metadata, ) - expected_tensor = torch.tensor(expected, - dtype=torch.int, - device=logits.device) + expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) assert torch.equal(output, expected_tensor) @@ -273,22 +278,15 @@ def test_deterministic_when_seeded( n_rep: int, ): num_tokens = batch_size * k - draft_probs = torch.rand(num_tokens, - vocab_size, - dtype=torch.float32, - device=DEVICE) + draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE) draft_probs = F.softmax(draft_probs, dim=-1) target_logits = 
torch.rand_like(draft_probs) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64, - device=DEVICE) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64, - device=DEVICE) + bonus_token_ids = torch.randint( + low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64, device=DEVICE + ) + draft_token_ids = torch.randint( + low=0, high=vocab_size, size=(batch_size, k), dtype=torch.int64, device=DEVICE + ) seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded @@ -296,17 +294,17 @@ def test_deterministic_when_seeded( for _ in range(n_rep): seeded_seqs = { i: torch.Generator(device=DEVICE).manual_seed(i) - for i in range(batch_size) if seeded_mask[i] + for i in range(batch_size) + if seeded_mask[i] } - temperature = torch.ones(batch_size, - dtype=torch.float32, - device=DEVICE) - sampling_metadata = create_sampling_metadata(all_greedy=False, - temperature=temperature, - generators=seeded_seqs) + temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) + sampling_metadata = create_sampling_metadata( + all_greedy=False, temperature=temperature, generators=seeded_seqs + ) spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=DEVICE) + draft_token_ids.tolist(), device=DEVICE + ) rep_result = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, @@ -352,8 +350,7 @@ def test_rejection_sampling_approximates_target_distribution(): num_reference_probs = 100 # Prepare draft, target, and reference probability distributions - draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), - dim=-1) + draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), dim=-1) target_logits = torch.rand(vocab_size, dtype=torch.float32) target_probs = F.softmax(target_logits, dim=-1) reference_probs = F.softmax( @@ -368,38 +365,48 @@ def test_rejection_sampling_approximates_target_distribution(): for num_samples in sample_sizes: # Sample using rejection sampling. rej_sample_probs = estimate_rejection_sampling_pdf( - draft_probs, target_logits, k, vocab_size, num_samples) + draft_probs, target_logits, k, vocab_size, num_samples + ) rej_sample_probs = rej_sample_probs.to(DEVICE) # Average distance from reference probs. 
- reference_vs_rejsample_dist = torch.dist( - reference_probs, - rej_sample_probs).item() / reference_probs.shape[0] - target_vs_rejsample_dist = torch.dist(target_probs, - rej_sample_probs).item() + reference_vs_rejsample_dist = ( + torch.dist(reference_probs, rej_sample_probs).item() + / reference_probs.shape[0] + ) + target_vs_rejsample_dist = torch.dist(target_probs, rej_sample_probs).item() distance_wrt_reference.append(reference_vs_rejsample_dist) distance_wrt_target.append(target_vs_rejsample_dist) relative_change_in_distance_wrt_target = get_ratio_first_to_last( - distance_wrt_target) + distance_wrt_target + ) relative_change_in_distance_wrt_reference = get_ratio_first_to_last( - distance_wrt_reference) + distance_wrt_reference + ) - print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} " - f"{reference_vs_rejsample_dist=:.05f}") - print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} " - f"{relative_change_in_distance_wrt_reference=:.02f}") + print( + f"{num_samples=} {target_vs_rejsample_dist=:.05f} " + f"{reference_vs_rejsample_dist=:.05f}" + ) + print( + f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} " + f"{relative_change_in_distance_wrt_reference=:.02f}" + ) relative_change_in_distance_wrt_target = get_ratio_first_to_last( - distance_wrt_target) + distance_wrt_target + ) relative_change_in_distance_wrt_reference = get_ratio_first_to_last( - distance_wrt_reference) + distance_wrt_reference + ) expected_improvement_multiplier = 20 - assert (relative_change_in_distance_wrt_target - > relative_change_in_distance_wrt_reference * - expected_improvement_multiplier) + assert ( + relative_change_in_distance_wrt_target + > relative_change_in_distance_wrt_reference * expected_improvement_multiplier + ) def get_ratio_first_to_last(elements: list[float]) -> float: @@ -427,28 +434,29 @@ def estimate_rejection_sampling_pdf( rejection_sampler = RejectionSampler() num_tokens = num_samples * k # Repeat draft probs num_samples * k times. - draft_probs = draft_probs.reshape(1, 1, - vocab_size).repeat(num_samples, k, 1) + draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1) # Repeat target probs num_tokens times. target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1) # Randomly sample draft token ids from draft probs. - draft_token_ids = torch.multinomial(draft_probs[:, 0, :], - num_samples=k, - replacement=True).reshape( - num_samples, k) + draft_token_ids = torch.multinomial( + draft_probs[:, 0, :], num_samples=k, replacement=True + ).reshape(num_samples, k) draft_probs = draft_probs.view(num_tokens, vocab_size) # Bonus tokens not used but required. 
- bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, - device=DEVICE).repeat(num_samples, 1) + bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat( + num_samples, 1 + ) temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE) - sampling_metadata = create_sampling_metadata(all_greedy=False, - temperature=temperature) + sampling_metadata = create_sampling_metadata( + all_greedy=False, temperature=temperature + ) spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids.tolist(), device=bonus_token_ids.device) + draft_token_ids.tolist(), device=bonus_token_ids.device + ) output_token_ids = rejection_sampler( spec_decode_metadata, draft_probs=draft_probs, @@ -458,11 +466,12 @@ def estimate_rejection_sampling_pdf( ) output_token_ids = output_token_ids[:, :-1].flatten() - hist = torch.histogram(output_token_ids.to(dtype=torch.float, - device="cpu"), - bins=vocab_size, - range=(0, vocab_size), - density=True) + hist = torch.histogram( + output_token_ids.to(dtype=torch.float, device="cpu"), + bins=vocab_size, + range=(0, vocab_size), + density=True, + ) return hist.hist @@ -480,9 +489,9 @@ def _test_masked_logits( num_tokens = batch_size * num_draft_tokens # Create random draft probabilities. - draft_probs = torch.rand((num_tokens, vocab_size), - dtype=torch.float32, - device=DEVICE) + draft_probs = torch.rand( + (num_tokens, vocab_size), dtype=torch.float32, device=DEVICE + ) draft_probs = F.softmax(draft_probs, dim=-1) # Randomly sample draft token ids from draft probs @@ -491,9 +500,7 @@ def _test_masked_logits( draft_token_ids = draft_token_ids.tolist() # Bonus tokens not used but required - bonus_token_ids = torch.zeros((batch_size, 1), - dtype=torch.int64, - device=DEVICE) + bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE) # Create spec decode metadata spec_decode_metadata = SpecDecodeMetadata.make_dummy( @@ -531,8 +538,7 @@ def test_top_k(rejection_sampler, top_k): # Randomly create top-k indices. 
top_k_indices = [ - torch.randperm(vocab_size, device=DEVICE)[:top_k] - for _ in range(num_tokens) + torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens) ] top_k_indices = torch.stack(top_k_indices) @@ -550,9 +556,7 @@ def test_top_k(rejection_sampler, top_k): sampling_metadata = create_sampling_metadata( all_greedy=False, temperature=temperature, - top_k=torch.tensor([top_k] * batch_size, - device=DEVICE, - dtype=torch.int64), + top_k=torch.tensor([top_k] * batch_size, device=DEVICE, dtype=torch.int64), ) _test_masked_logits( @@ -595,9 +599,7 @@ def test_top_p(rejection_sampler, top_p): sampling_metadata = create_sampling_metadata( all_greedy=False, temperature=temperature, - top_p=torch.tensor([top_p] * batch_size, - device=DEVICE, - dtype=torch.float32), + top_p=torch.tensor([top_p] * batch_size, device=DEVICE, dtype=torch.float32), ) _test_masked_logits( diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index 6ff000043265..5b34e27e79ac 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -29,12 +29,12 @@ def _create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: return fake_logits -def _create_penalty_tensor(batch_size: int, penalty_value: float, - device: torch.device) -> torch.Tensor: - return torch.full((batch_size, ), - fill_value=penalty_value, - dtype=torch.float, - device=device) +def _create_penalty_tensor( + batch_size: int, penalty_value: float, device: torch.device +) -> torch.Tensor: + return torch.full( + (batch_size,), fill_value=penalty_value, dtype=torch.float, device=device + ) def _create_prompt_tokens_tensor( @@ -62,9 +62,9 @@ def _create_allowed_token_ids( if i % 2 == 1: continue if mask is None: - mask = torch.zeros((batch_size, vocab_size), - dtype=torch.bool, - device=device) + mask = torch.zeros( + (batch_size, vocab_size), dtype=torch.bool, device=device + ) start = min(i, vocab_size - 1) end = min(i + num_allowed_token_ids, vocab_size - 1) mask[i, start:end] = True @@ -80,9 +80,9 @@ def _create_bad_words_token_ids( for batch_idx in range(batch_size): token_ids_single_batch = [] for bad_words_length in bad_words_lengths: - token_ids = np.random.choice(vocab_size, - size=bad_words_length, - replace=True).tolist() + token_ids = np.random.choice( + vocab_size, size=bad_words_length, replace=True + ).tolist() token_ids_single_batch.append(token_ids) bad_words_token_ids[batch_idx] = token_ids_single_batch if batch_size >= 2: @@ -95,26 +95,27 @@ def _create_bad_words_token_ids( # Returns all last tokens of bad word sequences that share the same prefix # as `given_prefix` (excluding the last token). 
def _collect_suffixes_with_same_prefix( - given_prefix: list[int], - bad_words_token_ids: list[list[int]]) -> list[int]: + given_prefix: list[int], bad_words_token_ids: list[list[int]] +) -> list[int]: return [bwt[-1] for bwt in bad_words_token_ids if bwt[:-1] == given_prefix] # generate a valid token id that is not in bad_words_token_ids -def _generate_valid_token_id(bad_words_token_ids: list[list[int]], - vocab_size: int) -> int: +def _generate_valid_token_id( + bad_words_token_ids: list[list[int]], vocab_size: int +) -> int: forbidden_start_tokens = set() for bad_word in bad_words_token_ids: forbidden_start_tokens.add(bad_word[0]) # Get a safe token that's not in forbidden starts - safe_token_candidates = list( - set(range(vocab_size)) - forbidden_start_tokens) + safe_token_candidates = list(set(range(vocab_size)) - forbidden_start_tokens) # Pick a random safe token return np.random.choice(safe_token_candidates) def _update_output_token_ids_for_bad_words( - metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]: + metadata: SamplingMetadata, vocab_size: int +) -> dict[int, list[int]]: bad_words_last_tokens = {} for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items(): output_token_ids = metadata.output_token_ids[batch_idx] @@ -132,12 +133,13 @@ def _update_output_token_ids_for_bad_words( # Collect all last tokens from other bad words # that share this prefix bad_words_last_token.extend( - _collect_suffixes_with_same_prefix( - prefix, bad_words_token_ids)) + _collect_suffixes_with_same_prefix(prefix, bad_words_token_ids) + ) break # Maximum one update to output_token_ids else: # Make sure no accidental match to bad words output_token_ids[-1] = _generate_valid_token_id( - bad_words_token_ids, vocab_size) + bad_words_token_ids, vocab_size + ) bad_words_last_tokens[batch_idx] = bad_words_last_token return bad_words_last_tokens @@ -152,22 +154,24 @@ def _create_default_sampling_metadata( prompt_token_ids: list[list[int]] = [] for _ in range(batch_size): output_token_ids.append( - np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) + np.random.randint(0, vocab_size, size=num_output_tokens).tolist() + ) prompt_token_ids.append( - np.random.randint(0, - vocab_size, - size=np.random.randint( - 1, MAX_NUM_PROMPT_TOKENS)).tolist()) + np.random.randint( + 0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS) + ).tolist() + ) fake_sampling_metadata = SamplingMetadata( - temperature=torch.full((batch_size, ), 0.0), + temperature=torch.full((batch_size,), 0.0), all_greedy=True, all_random=False, top_p=None, top_k=None, generators={}, max_num_logprobs=0, - prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, - vocab_size, device), + prompt_token_ids=_create_prompt_tokens_tensor( + prompt_token_ids, vocab_size, device + ), output_token_ids=output_token_ids, frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device), presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), @@ -181,8 +185,8 @@ def _create_default_sampling_metadata( def _create_weighted_output_token_list( - batch_size: int, - vocab_size: int) -> tuple[list[list[int]], list[list[int]]]: + batch_size: int, vocab_size: int +) -> tuple[list[list[int]], list[list[int]]]: """ Creates an output token list where each token occurs a distinct number of times. 
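The next hunk reformats the body of `_create_weighted_output_token_list`; the construction it wraps is worth spelling out, since the frequency-penalty test relies on it: the i-th distinct token id is repeated i + 1 times, so the last id in the returned sorted list is the most frequent one. A minimal sketch under that assumption (illustrative name, not the test fixture itself):

    import numpy as np

    def make_weighted_output(vocab_size: int) -> tuple[list[int], list[int]]:
        # A few distinct ids; the id at position i occurs (i + 1) times,
        # so occurrence counts strictly increase along `distinct`.
        distinct = np.random.choice(
            vocab_size, size=np.random.randint(1, 10), replace=False
        ).tolist()
        output = [tok for i, tok in enumerate(distinct) for _ in range(i + 1)]
        return output, distinct
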
@@ -203,14 +207,13 @@ def _create_weighted_output_token_list( output_token_ids: list[list[int]] = [] sorted_token_ids_in_output: list[list[int]] = [] for _ in range(batch_size): - distinct_token_ids = np.random.choice(vocab_size, - size=np.random.randint(1, 10), - replace=False).tolist() + distinct_token_ids = np.random.choice( + vocab_size, size=np.random.randint(1, 10), replace=False + ).tolist() sorted_token_ids_in_output.append(distinct_token_ids) output_token_ids_for_batch = [] for index, token_id in enumerate(distinct_token_ids): - output_token_ids_for_batch.extend( - [token_id for _ in range(index + 1)]) + output_token_ids_for_batch.extend([token_id for _ in range(index + 1)]) output_token_ids.append(output_token_ids_for_batch) return output_token_ids, sorted_token_ids_in_output @@ -218,8 +221,9 @@ def _create_weighted_output_token_list( @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("presence_penalty", [-2.0, 2.0]) -def test_sampler_presence_penalty(device: str, batch_size: int, - presence_penalty: float): +def test_sampler_presence_penalty( + device: str, batch_size: int, presence_penalty: float +): """ Test to verify that if presence penalty is enabled then tokens are penalized as per their presence in the existing output. @@ -229,10 +233,12 @@ def test_sampler_presence_penalty(device: str, batch_size: int, # logit value. fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) output_token_ids = sampling_metadata.output_token_ids sampling_metadata.presence_penalties = _create_penalty_tensor( - batch_size, presence_penalty, torch.device(device)) + batch_size, presence_penalty, torch.device(device) + ) sampling_metadata.no_penalties = False sampler = Sampler() logits = sampler.apply_penalties(fake_logits, sampling_metadata) @@ -263,8 +269,9 @@ def test_sampler_presence_penalty(device: str, batch_size: int, @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0]) -def test_sampler_frequency_penalty(device: str, batch_size: int, - frequency_penalty: float): +def test_sampler_frequency_penalty( + device: str, batch_size: int, frequency_penalty: float +): """ Test to verify that if frequency penalty is enabled then tokens are penalized as per their frequency of occurrence. @@ -274,14 +281,15 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, # logit value. 
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) sampling_metadata.frequency_penalties = _create_penalty_tensor( - batch_size, frequency_penalty, torch.device(device)) - output_token_ids, sorted_token_ids_in_output = \ - _create_weighted_output_token_list( - batch_size, - VOCAB_SIZE, - ) + batch_size, frequency_penalty, torch.device(device) + ) + output_token_ids, sorted_token_ids_in_output = _create_weighted_output_token_list( + batch_size, + VOCAB_SIZE, + ) sampling_metadata.output_token_ids = output_token_ids sampling_metadata.no_penalties = False sampler = Sampler() @@ -290,18 +298,17 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, for batch_idx in range(batch_size): non_penalized_token_id = logits[batch_idx].argmax().item() penalized_token_id = logits[batch_idx].argmin().item() - distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[ - batch_idx] + distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[batch_idx] most_frequent_token_id = distinct_sorted_token_ids_in_output[ - len(distinct_sorted_token_ids_in_output) - 1] + len(distinct_sorted_token_ids_in_output) - 1 + ] if frequency_penalty > 0: # If `frequency_penalty` is set to > 0, it indicates # a preference for new tokens over existing ones. Verify that the # non-penalized token ID is not present in the output, while the # most penalized token is the one that occurs most frequently in # the output. - assert (non_penalized_token_id - not in distinct_sorted_token_ids_in_output) + assert non_penalized_token_id not in distinct_sorted_token_ids_in_output assert penalized_token_id == most_frequent_token_id elif frequency_penalty < 0: # If `frequency_penalty` is set to < 0, it indicates @@ -316,8 +323,9 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("repetition_penalty", [0.1, 1.9]) -def test_sampler_repetition_penalty(device: str, batch_size: int, - repetition_penalty: float): +def test_sampler_repetition_penalty( + device: str, batch_size: int, repetition_penalty: float +): """ Test to verify that when the repetition penalty is enabled, tokens are penalized based on their presence in the prompt or the existing @@ -328,9 +336,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, # logit value. 
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) sampling_metadata.repetition_penalties = _create_penalty_tensor( - batch_size, repetition_penalty, torch.device(device)) + batch_size, repetition_penalty, torch.device(device) + ) sampling_metadata.no_penalties = False sampler = Sampler() logits = sampler.apply_penalties(fake_logits, sampling_metadata) @@ -338,32 +348,40 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, for batch_idx in range(batch_size): non_penalized_token_id = logits[batch_idx].argmax().item() penalized_token_id = logits[batch_idx].argmin().item() - prompt_tokens = sampling_metadata.prompt_token_ids[ - batch_idx][:].tolist() + prompt_tokens = sampling_metadata.prompt_token_ids[batch_idx][:].tolist() output_tokens = sampling_metadata.output_token_ids[batch_idx] if repetition_penalty > 1.0: # If `repetition_penalty` > 1.0, verify that the non-penalized # token ID has not been seen before, while the penalized token ID # exists either in the prompt or the output. - assert (non_penalized_token_id not in prompt_tokens - and non_penalized_token_id not in output_tokens) - assert (penalized_token_id in prompt_tokens - or penalized_token_id in output_tokens) + assert ( + non_penalized_token_id not in prompt_tokens + and non_penalized_token_id not in output_tokens + ) + assert ( + penalized_token_id in prompt_tokens + or penalized_token_id in output_tokens + ) elif repetition_penalty < 1.0: # If `repetition_penalty` < 1.0, verify that the penalized # token ID has not been seen before, while the non-penalized # token ID exists either in the prompt or the output. - assert (penalized_token_id not in prompt_tokens - and penalized_token_id not in output_tokens) - assert (non_penalized_token_id in prompt_tokens - or non_penalized_token_id in output_tokens) + assert ( + penalized_token_id not in prompt_tokens + and penalized_token_id not in output_tokens + ) + assert ( + non_penalized_token_id in prompt_tokens + or non_penalized_token_id in output_tokens + ) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2]) -def test_sampler_allowed_token_ids(device: str, batch_size: int, - num_allowed_token_ids: int): +def test_sampler_allowed_token_ids( + device: str, batch_size: int, num_allowed_token_ids: int +): """ Test to verify that when the repetition penalty is enabled, tokens are penalized based on their presence in the prompt or the existing @@ -374,7 +392,8 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int, # logit value. 
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) mask = _create_allowed_token_ids( batch_size=batch_size, vocab_size=VOCAB_SIZE, @@ -394,17 +413,19 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int, start = min(batch_idx, VOCAB_SIZE - 1) end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1) if token_id >= start and token_id < end: - assert logits_for_req[token_id] == -float( - "inf"), f"{batch_idx}, {token_id}" + assert logits_for_req[token_id] == -float("inf"), ( + f"{batch_idx}, {token_id}" + ) else: assert logits_for_req[token_id] != -float("inf") @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) -@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)]) -def test_sampler_bad_words(device: str, batch_size: int, - bad_words_lengths: tuple[int, ...]): +@pytest.mark.parametrize("bad_words_lengths", [(1,), (1, 3), (2, 2)]) +def test_sampler_bad_words( + device: str, batch_size: int, bad_words_lengths: tuple[int, ...] +): """ Test to verify that when the bad words restriction is present, tokens are penalized based on their match with the bad words. @@ -414,19 +435,24 @@ def test_sampler_bad_words(device: str, batch_size: int, # logit value. fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids( - batch_size, VOCAB_SIZE, bad_words_lengths) + batch_size, VOCAB_SIZE, bad_words_lengths + ) bad_words_last_tokens = _update_output_token_ids_for_bad_words( - sampling_metadata, VOCAB_SIZE) + sampling_metadata, VOCAB_SIZE + ) sampler = Sampler() logits = sampler.apply_bad_words(fake_logits, sampling_metadata) logits = logits.cpu() for batch_idx in range(batch_size): logits_for_req = logits[batch_idx] for token_id in range(VOCAB_SIZE): - if (batch_idx in bad_words_last_tokens - and token_id in bad_words_last_tokens[batch_idx]): + if ( + batch_idx in bad_words_last_tokens + and token_id in bad_words_last_tokens[batch_idx] + ): assert logits_for_req[token_id] == -float("inf") else: assert logits_for_req[token_id] != -float("inf") diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index f53e1e1c485d..24f9397cc4c6 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -66,9 +66,9 @@ def test_stop(llm): # Output should not contain the stop word. 
assert len(new_split_text) == STOP_IDX - params = SamplingParams(temperature=0, - stop=split_text[STOP_IDX], - include_stop_str_in_output=True) + params = SamplingParams( + temperature=0, stop=split_text[STOP_IDX], include_stop_str_in_output=True + ) output = llm.generate(PROMPT, params) new_split_text = output[0].outputs[0].text.split() @@ -103,8 +103,8 @@ def test_detokenize_false(llm): assert len(output[0].outputs[0].text) == 0 output = llm.generate( - PROMPT, SamplingParams(detokenize=False, logprobs=3, - prompt_logprobs=3)) + PROMPT, SamplingParams(detokenize=False, logprobs=3, prompt_logprobs=3) + ) assert len(output[0].outputs[0].token_ids) > 0 assert len(output[0].outputs[0].text) == 0 @@ -131,8 +131,7 @@ def test_bad_words(llm): assert bad_words_1 not in new_text bad_words_2 = new_text.split()[-1] - params = SamplingParams(temperature=0, - bad_words=[bad_words_1, bad_words_2]) + params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2]) output = llm.generate(PROMPT, params) new_text = output[0].outputs[0].text assert bad_words_1 not in new_text @@ -158,8 +157,7 @@ def test_allowed_token_ids(llm): TOKEN_ID = 10 allowed_token_ids = [TOKEN_ID] - output = llm.generate(PROMPT, - SamplingParams(allowed_token_ids=allowed_token_ids)) + output = llm.generate(PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids)) assert output[0].outputs[0].token_ids[-1] == TOKEN_ID # Reject empty allowed_token_ids. diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index ccf38c31d39e..c70cbebe22ca 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -5,8 +5,10 @@ from torch import Generator from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, - is_flashinfer_available) +from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p, + is_flashinfer_available, +) DEVICE = current_platform.device_type @@ -30,19 +32,18 @@ def reset_default_device(): def test_topk_impl_equivalence(): - torch.set_default_device(DEVICE) generator = Generator(device=DEVICE).manual_seed(33) logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) # Random top-k values between 1 and 9. - k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator) + k = torch.randint(1, 10, (BATCH_SIZE,), generator=generator) # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). k.masked_fill_( - torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool), - VOCAB_SIZE) + torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=bool), VOCAB_SIZE + ) # Top-k only implementation result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) @@ -55,7 +56,7 @@ def test_topk_impl_equivalence(): def test_flashinfer_sampler(): - ''' + """ This test verifies that the FlashInfer top-k and top-p sampling implementation produces the same results as the Python implementation. @@ -63,11 +64,10 @@ def test_flashinfer_sampler(): top-p prob renorm (it did provide fused sampling but we cannot compare sampling results due to randomness), so we will compare the probability renormed consequently by top-k and then top-p of FlashInfer implementation. 
- ''' + """ if not FLASHINFER_ENABLED: - pytest.skip( - "FlashInfer not installed or not available on this platform.") + pytest.skip("FlashInfer not installed or not available on this platform.") torch.set_default_device(DEVICE) generator = Generator(device=DEVICE).manual_seed(42) @@ -76,23 +76,21 @@ def test_flashinfer_sampler(): logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator) # Generate various top-k and top-p values - k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator) - p_values = torch.rand( - (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5 # range in [0.5, 1.0] + k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator) + p_values = ( + torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5 + ) # range in [0.5, 1.0] # Sometimes disable top-k (k=vocab_size) k_values.masked_fill_( - torch.randint(0, - 2, (BATCH_SIZE, ), - generator=generator, - dtype=torch.bool), VOCAB_SIZE) + torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), + VOCAB_SIZE, + ) # Sometimes disable top-p (p=1.0) p_values.masked_fill_( - torch.randint(0, - 2, (BATCH_SIZE, ), - generator=generator, - dtype=torch.bool), 1.0) + torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0 + ) python_logits = apply_top_k_top_p( logits=logits.clone(), @@ -113,5 +111,6 @@ def test_flashinfer_sampler(): ) # Compare the results - assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \ + assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), ( "FlashInfer and Python sampling implementations do not match!" + ) diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index e33efb413d02..0f1214e9745c 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -16,6 +16,7 @@ class BatchLogprobsComposition(Enum): """Types of logprobs configs to include in test batch""" + NONE = 0 SAMPLE = 1 PROMPT = 2 @@ -26,10 +27,10 @@ class BatchLogprobsComposition(Enum): def get_test_batch( - batch_logprobs_composition: BatchLogprobsComposition + batch_logprobs_composition: BatchLogprobsComposition, ) -> BatchLogprobsSpecType: """Generate logprobs configs for a batch of requests - + A given request's logprobs configuration is (1) num_sample_logprobs and (2) num_prompt_logprobs. The batch logprobs configuration is the list of request logprobs configs. @@ -101,7 +102,7 @@ def assert_incr_detok_str_matches_non_incr_detok_str( msg: str, ) -> None: """Compare incrementally detok. text to non-incrementally detok. text - + Fail if the strings mismatch after non-alphanumeric characters are stripped out. 
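The hunk below is formatting-only for `assert_incr_detok_str_matches_non_incr_detok_str`; the check itself is unchanged. As a minimal, self-contained sketch of what it asserts (illustrative helper name, using the same regex that appears in the reformatted code), the comparison reduces to stripping every non-alphanumeric run from both strings and requiring equality:

    import re

    def detok_strings_match(incremental: str, non_incremental: str) -> bool:
        # Drop runs of non-alphanumeric characters (whitespace, punctuation)
        # before comparing, mirroring the r"[^a-zA-Z0-9]+" substitution below.
        rgx = r"[^a-zA-Z0-9]+"
        return re.sub(rgx, "", incremental) == re.sub(rgx, "", non_incremental)
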
@@ -120,15 +121,15 @@ def assert_incr_detok_str_matches_non_incr_detok_str( tokens msg: error message if `assert` fails """ - rgx = r'[^a-zA-Z0-9]+' - assert (re.sub(rgx, '', incremental_detokenization_str) == re.sub( - rgx, '', non_incremental_detokenization_str)), (msg) + rgx = r"[^a-zA-Z0-9]+" + assert re.sub(rgx, "", incremental_detokenization_str) == re.sub( + rgx, "", non_incremental_detokenization_str + ), msg -def compute_correct_cumulative_logprob( - completion_output: CompletionOutput) -> float: +def compute_correct_cumulative_logprob(completion_output: CompletionOutput) -> float: """Compute known-good value for evaluating cumulative logprob - + Args: completion_output: completion output from engine @@ -146,12 +147,12 @@ def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: return fake_logits -def create_penalty_tensor(batch_size: int, penalty_value: float, - device: torch.device) -> torch.Tensor: - return torch.full((batch_size, ), - fill_value=penalty_value, - dtype=torch.float, - device=device) +def create_penalty_tensor( + batch_size: int, penalty_value: float, device: torch.device +) -> torch.Tensor: + return torch.full( + (batch_size,), fill_value=penalty_value, dtype=torch.float, device=device + ) def create_prompt_tokens_tensor( @@ -170,6 +171,7 @@ def create_prompt_tokens_tensor( class LogitsprocsTestFakes(NamedTuple): """Wraps fake data structures to support testing""" + logits: torch.Tensor sampling_metadata: SamplingMetadata @@ -178,15 +180,16 @@ def get_logitsprocs_by_cls( cls: type[LogitsProcessor], ) -> Iterator[LogitsProcessor]: """Yield logits processors of a specific class. - + Args: cls: :class:`LogitsProcessor` subclass Returns: Iterator over logits processors """ - return (lp for lp in self.sampling_metadata.logitsprocs.all - if isinstance(lp, cls)) + return ( + lp for lp in self.sampling_metadata.logitsprocs.all if isinstance(lp, cls) + ) def get_logitsprocs(self) -> Iterator[LogitsProcessor]: """Iterator over all logits processors.""" @@ -208,8 +211,7 @@ def fake_apply_logitsprocs( slice_indices: list[int], ) -> torch.Tensor: """Imitate application of logits processors in engine core""" - logits = test_fakes.logits[torch.tensor(slice_indices, - dtype=torch.long)].clone() + logits = test_fakes.logits[torch.tensor(slice_indices, dtype=torch.long)].clone() for processor in test_fakes.get_logitsprocs(): logits = processor.apply(logits) return logits diff --git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py index 682d84dc23d1..d94357827864 100644 --- a/tests/v1/shutdown/test_delete.py +++ b/tests/v1/shutdown/test_delete.py @@ -5,8 +5,10 @@ import pytest from tests.utils import wait_for_gpu_memory_to_clear -from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, - SHUTDOWN_TEST_TIMEOUT_SEC) +from tests.v1.shutdown.utils import ( + SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC, +) from vllm import LLM, SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import RequestOutputKind @@ -21,8 +23,9 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("send_one_request", [False, True]) -async def test_async_llm_delete(model: str, tensor_parallel_size: int, - send_one_request: bool) -> None: +async def test_async_llm_delete( + model: str, tensor_parallel_size: int, send_one_request: bool +) -> None: """Test that AsyncLLM frees GPU memory upon deletion. AsyncLLM always uses an MP client. 
@@ -34,19 +37,21 @@ async def test_async_llm_delete(model: str, tensor_parallel_size: int, if cuda_device_count_stateless() < tensor_parallel_size: pytest.skip(reason="Not enough CUDA devices") - engine_args = AsyncEngineArgs(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + engine_args = AsyncEngineArgs( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) # Instantiate AsyncLLM; make request to complete any deferred # initialization; then delete instance async_llm = AsyncLLM.from_engine_args(engine_args) if send_one_request: async for _ in async_llm.generate( - "Hello my name is", - request_id="abc", - sampling_params=SamplingParams( - max_tokens=1, output_kind=RequestOutputKind.DELTA)): + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=1, output_kind=RequestOutputKind.DELTA + ), + ): pass del async_llm @@ -62,9 +67,13 @@ async def test_async_llm_delete(model: str, tensor_parallel_size: int, @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("enable_multiprocessing", [True]) @pytest.mark.parametrize("send_one_request", [False, True]) -def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int, - enable_multiprocessing: bool, - send_one_request: bool) -> None: +def test_llm_delete( + monkeypatch, + model: str, + tensor_parallel_size: int, + enable_multiprocessing: bool, + send_one_request: bool, +) -> None: """Test that LLM frees GPU memory upon deletion. TODO(andy) - LLM without multiprocessing. @@ -83,12 +92,13 @@ def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int, # Instantiate LLM; make request to complete any deferred # initialization; then delete instance - llm = LLM(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + llm = LLM( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) if send_one_request: - llm.generate("Hello my name is", - sampling_params=SamplingParams(max_tokens=1)) + llm.generate( + "Hello my name is", sampling_params=SamplingParams(max_tokens=1) + ) del llm # Confirm all the processes are cleaned up. 
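The test_delete.py diff above exercises a construct, use once, delete, then verify pattern. A minimal sketch of that flow, assuming a CUDA build of vLLM is installed; `torch.cuda.memory_allocated` and the 1 GiB bound are illustrative stand-ins for the repo's `wait_for_gpu_memory_to_clear` helper and `SHUTDOWN_TEST_THRESHOLD_BYTES`, whose exact values are not shown here:

    import gc

    import torch
    from vllm import LLM, SamplingParams

    def delete_and_check_memory(model: str) -> None:
        # Instantiate, send one tiny request to flush deferred initialization,
        # then drop the engine and confirm GPU memory is actually released.
        llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=1)
        llm.generate("Hello my name is", sampling_params=SamplingParams(max_tokens=1))
        del llm
        gc.collect()
        torch.cuda.empty_cache()
        assert torch.cuda.memory_allocated() < 1 * 2**30
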
diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py index 523b7ee23115..383348e88540 100644 --- a/tests/v1/shutdown/test_forward_error.py +++ b/tests/v1/shutdown/test_forward_error.py @@ -7,8 +7,10 @@ import pytest from tests.utils import wait_for_gpu_memory_to_clear -from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, - SHUTDOWN_TEST_TIMEOUT_SEC) +from tests.v1.shutdown.utils import ( + SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC, +) from vllm import LLM, AsyncEngineArgs, SamplingParams from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.models.llama import LlamaForCausalLM @@ -26,8 +28,10 @@ def evil_forward(self, *args, **kwargs): if not hasattr(self, "num_calls"): self.num_calls = 0 - if (self.num_calls == NUMBER_OF_GOOD_PASSES - and get_tensor_model_parallel_rank() == 0): + if ( + self.num_calls == NUMBER_OF_GOOD_PASSES + and get_tensor_model_parallel_rank() == 0 + ): raise Exception("Simulated illegal memory access on Rank 0!") self.num_calls += 1 @@ -37,10 +41,11 @@ def evil_forward(self, *args, **kwargs): @pytest.mark.asyncio @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("model", MODELS) -async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int, - model: str) -> None: +async def test_async_llm_model_error( + monkeypatch, tensor_parallel_size: int, model: str +) -> None: """Test that AsyncLLM propagates a forward pass error and frees memory. - + AsyncLLM always uses an MP client. """ if cuda_device_count_stateless() < tensor_parallel_size: @@ -49,15 +54,15 @@ async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int, # Monkeypatch an error in the model. monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward) - engine_args = AsyncEngineArgs(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + engine_args = AsyncEngineArgs( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) async_llm = AsyncLLM.from_engine_args(engine_args) async def generate(request_id: str): - generator = async_llm.generate("Hello my name is", - request_id=request_id, - sampling_params=SamplingParams()) + generator = async_llm.generate( + "Hello my name is", request_id=request_id, sampling_params=SamplingParams() + ) try: async for _ in generator: pass @@ -77,9 +82,9 @@ async def generate(request_id: str): # We should not be able to make another request. with pytest.raises(EngineDeadError): - async for _ in async_llm.generate("Hello my name is", - request_id="abc", - sampling_params=SamplingParams()): + async for _ in async_llm.generate( + "Hello my name is", request_id="abc", sampling_params=SamplingParams() + ): raise Exception("We should not get here.") # Confirm all the processes are cleaned up. @@ -98,8 +103,9 @@ async def generate(request_id: str): @pytest.mark.parametrize("enable_multiprocessing", [True]) @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("model", MODELS) -def test_llm_model_error(monkeypatch, tensor_parallel_size: int, - enable_multiprocessing: bool, model: str) -> None: +def test_llm_model_error( + monkeypatch, tensor_parallel_size: int, enable_multiprocessing: bool, model: str +) -> None: """Test that LLM propagates a forward pass error and frees memory. 
TODO(andy) - LLM without multiprocessing; LLM with multiprocessing and >1 rank @@ -108,19 +114,17 @@ def test_llm_model_error(monkeypatch, tensor_parallel_size: int, pytest.skip(reason="Not enough CUDA devices") with monkeypatch.context() as m: - MP_VALUE = "1" if enable_multiprocessing else "0" m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) # Monkeypatch an error in the model. m.setattr(LlamaForCausalLM, "forward", evil_forward) - llm = LLM(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + llm = LLM( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) - with pytest.raises( - EngineDeadError if enable_multiprocessing else Exception): + with pytest.raises(EngineDeadError if enable_multiprocessing else Exception): llm.generate("Hello my name is Robert and I") # Confirm all the processes are cleaned up. diff --git a/tests/v1/shutdown/test_processor_error.py b/tests/v1/shutdown/test_processor_error.py index a077d48fecbb..013b929e3df6 100644 --- a/tests/v1/shutdown/test_processor_error.py +++ b/tests/v1/shutdown/test_processor_error.py @@ -30,9 +30,9 @@ async def test_async_llm_processor_error(model: str) -> None: async def generate(request_id: str): # [] is not allowed and will raise a ValueError in Processor. - generator = async_llm.generate(TokensPrompt([]), - request_id=request_id, - sampling_params=SamplingParams()) + generator = async_llm.generate( + TokensPrompt([]), request_id=request_id, sampling_params=SamplingParams() + ) try: async for _ in generator: pass @@ -55,11 +55,12 @@ async def generate(request_id: str): EXPECTED_TOKENS = 5 outputs = [] async for out in async_llm.generate( - "Hello my name is", - request_id="abc", - sampling_params=SamplingParams( - max_tokens=EXPECTED_TOKENS, - output_kind=RequestOutputKind.DELTA)): + "Hello my name is", + request_id="abc", + sampling_params=SamplingParams( + max_tokens=EXPECTED_TOKENS, output_kind=RequestOutputKind.DELTA + ), + ): outputs.append(out) generated_tokens = [] diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py index 88fc5297aaf5..019c0c4d7cf0 100644 --- a/tests/v1/shutdown/test_startup_error.py +++ b/tests/v1/shutdown/test_startup_error.py @@ -5,8 +5,10 @@ import pytest from tests.utils import wait_for_gpu_memory_to_clear -from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES, - SHUTDOWN_TEST_TIMEOUT_SEC) +from tests.v1.shutdown.utils import ( + SHUTDOWN_TEST_THRESHOLD_BYTES, + SHUTDOWN_TEST_TIMEOUT_SEC, +) from vllm import LLM from vllm.distributed import get_tensor_model_parallel_rank from vllm.engine.arg_utils import AsyncEngineArgs @@ -30,9 +32,9 @@ def evil_method(self, *args, **kwargs): @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) -def test_async_llm_startup_error(monkeypatch, model: str, - tensor_parallel_size: int, - failing_method: str) -> None: +def test_async_llm_startup_error( + monkeypatch, model: str, tensor_parallel_size: int, failing_method: str +) -> None: """Test that AsyncLLM propagates an __init__ error & frees memory. Test profiling (forward()) and load weights failures. AsyncLLM always uses an MP client. @@ -43,9 +45,9 @@ def test_async_llm_startup_error(monkeypatch, model: str, # Monkeypatch an error in the model. 
monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) - engine_args = AsyncEngineArgs(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + engine_args = AsyncEngineArgs( + model=model, enforce_eager=True, tensor_parallel_size=tensor_parallel_size + ) # Confirm we get an exception. with pytest.raises(Exception, match="initialization failed"): @@ -63,9 +65,13 @@ def test_async_llm_startup_error(monkeypatch, model: str, @pytest.mark.parametrize("tensor_parallel_size", [2, 1]) @pytest.mark.parametrize("enable_multiprocessing", [True]) @pytest.mark.parametrize("failing_method", ["forward", "load_weights"]) -def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int, - enable_multiprocessing: bool, - failing_method: str) -> None: +def test_llm_startup_error( + monkeypatch, + model: str, + tensor_parallel_size: int, + enable_multiprocessing: bool, + failing_method: str, +) -> None: """Test that LLM propagates an __init__ error and frees memory. Test profiling (forward()) and load weights failures. TODO(andy) - LLM without multiprocessing. @@ -76,7 +82,6 @@ def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int, pytest.skip(reason="Not enough CUDA devices") with monkeypatch.context() as m: - MP_VALUE = "1" if enable_multiprocessing else "0" m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE) @@ -84,12 +89,16 @@ def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int, monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method) with pytest.raises( - Exception, - match="initialization failed" - if enable_multiprocessing else "Simulated Error in startup!"): - _ = LLM(model=model, - enforce_eager=True, - tensor_parallel_size=tensor_parallel_size) + Exception, + match="initialization failed" + if enable_multiprocessing + else "Simulated Error in startup!", + ): + _ = LLM( + model=model, + enforce_eager=True, + tensor_parallel_size=tensor_parallel_size, + ) # Confirm all the processes are cleaned up. 
wait_for_gpu_memory_to_clear( diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 938c6543e9b0..4c490f2188aa 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -8,13 +8,22 @@ import torch from tests.utils import get_attn_backend_list_based_on_platform -from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, - create_standard_kv_cache_spec, - get_attention_backend) +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + get_attention_backend, +) from vllm.attention.backends.registry import _Backend -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VllmConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) from vllm.config.load import LoadConfig from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform @@ -32,9 +41,7 @@ def _create_proposer( num_speculative_tokens: int, speculative_token_tree: Optional[list[tuple[int, ...]]] = None, ) -> EagleProposer: - model_config = ModelConfig(model=model_dir, - runner="generate", - max_model_len=100) + model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) # Choose model directory based on method draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir @@ -60,10 +67,10 @@ def _create_proposer( device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig()) + scheduler_config=SchedulerConfig(), + ) - return EagleProposer(vllm_config=vllm_config, - device=current_platform.device_type) + return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) def test_prepare_next_token_ids(): @@ -82,7 +89,7 @@ def test_prepare_next_token_ids(): query_lens=[num_speculative_tokens + 1] * num_requests, ) - req_ids = [f"req_{i+1}" for i in range(num_requests)] + req_ids = [f"req_{i + 1}" for i in range(num_requests)] mock_input_batch = mock.MagicMock(spec=InputBatch) mock_input_batch.req_ids = req_ids mock_input_batch.num_reqs = num_requests @@ -101,24 +108,26 @@ def test_prepare_next_token_ids(): [0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled [0, 1, 2, 3, 4], # all accepted, "4" sampled [-1, -1, -1, -1, -1], # sampling skipped, use backup token "30" - [-1, -1, -1, -1, -1] # this request will be discarded + [-1, -1, -1, -1, -1], # this request will be discarded ] - sampled_token_ids_tensor = torch.tensor(sampled_token_ids, - dtype=torch.int32, - device=device) - sampled_token_ids_cpu = [[i for i in seq if i != -1] - for seq in sampled_token_ids] + sampled_token_ids_tensor = torch.tensor( + sampled_token_ids, dtype=torch.int32, device=device + ) + sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids] expected_next_token_ids_cpu = [1, 4, 30, 40] - expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu, - dtype=torch.int32, - device=device) + expected_next_token_ids_tensor = torch.tensor( + expected_next_token_ids_cpu, dtype=torch.int32, device=device + ) proposer = _create_proposer("eagle", num_speculative_tokens) next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu( - sampled_token_ids_cpu, mock_requests, mock_input_batch, - mock_num_scheduled_tokens) + 
sampled_token_ids_cpu, + mock_requests, + mock_input_batch, + mock_num_scheduled_tokens, + ) assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor) @@ -131,19 +140,23 @@ def test_prepare_next_token_ids(): discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device) num_discarded_reqs = 1 - expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0], - dtype=torch.int32, - device=device) + expected_valid_sampled_tokens_count = torch.tensor( + [2, 5, 0, 0], dtype=torch.int32, device=device + ) - next_token_ids_from_padded, valid_sampled_tokens_count = \ + next_token_ids_from_padded, valid_sampled_tokens_count = ( proposer.prepare_next_token_ids_padded( - common_attn_metadata, sampled_token_ids_tensor, mock_requests, - mock_input_batch, discarded_req_indices, num_discarded_reqs) + common_attn_metadata, + sampled_token_ids_tensor, + mock_requests, + mock_input_batch, + discarded_req_indices, + num_discarded_reqs, + ) + ) - assert torch.equal(next_token_ids_from_padded, - expected_next_token_ids_tensor) - assert torch.equal(valid_sampled_tokens_count, - expected_valid_sampled_tokens_count) + assert torch.equal(next_token_ids_from_padded, expected_next_token_ids_tensor) + assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count) def test_prepare_inputs(): @@ -183,21 +196,27 @@ def test_prepare_inputs(): sampled_token_ids = [ [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN], [ - ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, - REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN + ACCEPT_TOKEN, + ACCEPT_TOKEN, + ACCEPT_TOKEN, + REJECT_TOKEN, + REJECT_TOKEN, + REJECT_TOKEN, + BONUS_TOKEN, ], - [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN] + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN], + ] + sampled_token_ids = [ + [i for i in seq if i != REJECT_TOKEN] for seq in sampled_token_ids ] - sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN] - for seq in sampled_token_ids] # Expected calculations: # query_len_per_req = [4, 7, 5] # num_tokens_per_req = [3, 4, 3] (after subtracting rejected tokens) # Expected cumulative counts: [0, 3, 7, 10] - expected_cu_num_tokens = torch.tensor([0, 3, 7, 10], - dtype=torch.int32, - device=device) + expected_cu_num_tokens = torch.tensor( + [0, 3, 7, 10], dtype=torch.int32, device=device + ) # Expected token indices (mapped from original positions): # First request: indices 0, 1, 2 (keeping first 3 from positions 0-3) @@ -214,17 +233,18 @@ def test_prepare_inputs(): 7, # Second request: 4 tokens (7-3) 11, 12, - 13 # Third request: 3 tokens (5-2) + 13, # Third request: 3 tokens (5-2) ], dtype=torch.int32, - device=device) + device=device, + ) proposer = _create_proposer("eagle", 1) updated_metadata, token_indices = proposer.prepare_inputs( - common_attn_metadata, sampled_token_ids, num_draft_tokens) + common_attn_metadata, sampled_token_ids, num_draft_tokens + ) - assert torch.equal(updated_metadata.query_start_loc, - expected_cu_num_tokens) + assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens) assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() assert torch.equal(token_indices, expected_token_indices) @@ -249,12 +269,12 @@ def test_prepare_inputs_padded(): device = torch.device(current_platform.device_type) - expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.int32, - device=device) - expected_token_indices_to_sample = torch.tensor([1, 5, 6], - dtype=torch.int32, - device=device) 
+ expected_token_indices = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device + ) + expected_token_indices_to_sample = torch.tensor( + [1, 5, 6], dtype=torch.int32, device=device + ) num_speculative_tokens = 2 batch_spec = BatchSpec( @@ -269,9 +289,9 @@ def test_prepare_inputs_padded(): ) # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9] - expected_query_start_loc = torch.tensor([0, 3, 6, 9], - dtype=torch.int32, - device=device) + expected_query_start_loc = torch.tensor( + [0, 3, 6, 9], dtype=torch.int32, device=device + ) spec_decode_metadata = SpecDecodeMetadata.make_dummy( draft_token_ids=[[0] * num_speculative_tokens] * 3, device=device, @@ -280,43 +300,48 @@ def test_prepare_inputs_padded(): # num_rejected_tokens = [1, 0, 2] # num_draft_tokens = [2, 2, 2] # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens - valid_sampled_tokens_count = torch.tensor([2, 3, 1], - dtype=torch.int32, - device=device) + valid_sampled_tokens_count = torch.tensor( + [2, 3, 1], dtype=torch.int32, device=device + ) proposer = _create_proposer("eagle", num_speculative_tokens) - output_metadata, token_indices, token_indices_to_sample = \ + output_metadata, token_indices, token_indices_to_sample = ( proposer.prepare_inputs_padded( - common_attn_metadata, - spec_decode_metadata, - valid_sampled_tokens_count) + common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count + ) + ) assert output_metadata.max_query_len == 3 - assert torch.equal(output_metadata.query_start_loc, - expected_query_start_loc) + assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc) assert torch.equal(token_indices, expected_token_indices) - assert torch.equal(token_indices_to_sample, - expected_token_indices_to_sample) + assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) @pytest.mark.parametrize("method", ["eagle", "eagle3"]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("pp_size", [1, 2]) @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False]) -@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') -@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') -@mock.patch('vllm.v1.spec_decode.eagle.get_model') -def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, - attn_backend, pp_size, use_distinct_embed_tokens, - monkeypatch): - +@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") +@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") +@mock.patch("vllm.v1.spec_decode.eagle.get_model") +def test_load_model( + mock_get_model, + mock_get_layers, + mock_get_pp_group, + method, + attn_backend, + pp_size, + use_distinct_embed_tokens, + monkeypatch, +): monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -335,20 +360,20 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, # Setup mocks for attention layers 
target_attn_layers = { "target_attn_1": mock.MagicMock(), - "target_attn_2": mock.MagicMock() + "target_attn_2": mock.MagicMock(), } target_indx_layers: dict[str, mock.MagicMock] = {} # Draft model has one extra attention layer compared to target model - all_attn_layers = { - **target_attn_layers, "draft_extra_attn": mock.MagicMock() - } + all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()} all_indx_layers: dict[str, mock.MagicMock] = {} # Make mock_get_layers return different values for each call mock_get_layers.side_effect = [ - target_attn_layers, target_indx_layers, all_attn_layers, - all_indx_layers + target_attn_layers, + target_indx_layers, + all_attn_layers, + all_indx_layers, ] # Setup mock for pp group to return the appropriate value for world size @@ -367,6 +392,7 @@ class _TargetModelStub(LlamaForCausalLM): target_model.model.embed_tokens.weight.shape = (131072, 4096) from vllm.model_executor.models import SupportsMultiModal + assert not isinstance(target_model, SupportsMultiModal) if method == "eagle": @@ -388,30 +414,30 @@ class _TargetModelStub(LlamaForCausalLM): # Verify that the embed tokens are set correctly # If pp_size is > 1, the embed tokens should be distinct if pp_size > 1 or use_distinct_embed_tokens: - assert proposer.model.model.embed_tokens != \ - target_model.model.embed_tokens + assert proposer.model.model.embed_tokens != target_model.model.embed_tokens else: # When pp_size is 1 and the draft and target models have # embed_tokens of the same shape, they should be shared. - assert proposer.model.model.embed_tokens == \ - target_model.model.embed_tokens + assert proposer.model.model.embed_tokens == target_model.model.embed_tokens @pytest.mark.parametrize("method", ["eagle", "eagle3"]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) - if (attn_backend == "TREE_ATTN"): - pytest.skip("TREE_ATTN is tested separately in test_propose_tree" - "because it requires special input mocking.") + if attn_backend == "TREE_ATTN": + pytest.skip( + "TREE_ATTN is tested separately in test_propose_tree" + "because it requires special input mocking." 
+ ) if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -498,31 +524,22 @@ def create_deterministic_logits(token_ids): device=device, ) - target_token_ids = torch.randint(0, - vocab_size, (total_tokens, ), - device=device) - target_positions = torch.cat([ - torch.arange(seq_len_1, device=device), - torch.arange(seq_len_2, device=device) - ]) - target_hidden_states = torch.randn(total_tokens, - hidden_size, - device=device) - next_token_ids = torch.randint(0, - vocab_size, (batch_size, ), - dtype=torch.int32, - device=device) + target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) + target_positions = torch.cat( + [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)] + ) + target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) sampling_metadata = mock.MagicMock() if attn_backend == "FLASH_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN) elif attn_backend == "TRITON_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.TRITON_ATTN) + attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TRITON_ATTN) elif attn_backend == "TREE_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN) else: raise ValueError(f"Unsupported attention backend: {attn_backend}") @@ -536,18 +553,22 @@ def create_deterministic_logits(token_ids): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \ - attn_metadata_builder + proposer.runner.attn_groups[0][ + 0 + ].get_metadata_builder.return_value = attn_metadata_builder proposer._get_attention_metadata_builder = mock.MagicMock( - return_value=attn_metadata_builder) + return_value=attn_metadata_builder + ) - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=None, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata) + result = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + ) assert result.shape == (batch_size, num_speculative_tokens) @@ -556,13 +577,14 @@ def create_deterministic_logits(token_ids): # Example for num_speculative_tokens=1: # [[42], [60]] expected_tokens = torch.tensor( - [[base_token_ids[0]], [base_token_ids[1]]], device=device) + [[base_token_ids[0]], [base_token_ids[1]]], device=device + ) else: # Example for num_speculative_tokens=3: # [[42, 43, 44], [60, 61, 62]] - expected_tokens = torch.zeros((batch_size, num_speculative_tokens), - dtype=torch.int64, - device=device) + expected_tokens = torch.zeros( + (batch_size, num_speculative_tokens), dtype=torch.int64, device=device + ) for i in range(batch_size): for j in range(num_speculative_tokens): expected_tokens[i, j] = base_token_ids[i] + j @@ -574,12 +596,12 @@ def 
create_deterministic_logits(token_ids): @pytest.mark.parametrize( "spec_token_tree", [ - [(0, )], # A single token - [(0, ), (0, 0), (0, 0, 0)], # Chain - [(0, ), (1, ), (2, )], # Parallel - [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), - (2, 1)], # Tree - ]) + [(0,)], # A single token + [(0,), (0, 0), (0, 0, 0)], # Chain + [(0,), (1,), (2,)], # Parallel + [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], # Tree + ], +) def test_propose_tree(spec_token_tree): # Get GPU device. device = torch.device(current_platform.device_type) @@ -594,9 +616,9 @@ def test_propose_tree(spec_token_tree): num_speculative_tokens = len(spec_token_tree) # Create proposer first so we can use its actual hidden_size. - proposer = _create_proposer("eagle", - num_speculative_tokens, - speculative_token_tree=spec_token_tree) + proposer = _create_proposer( + "eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree + ) # Get the hidden_size from the proposer to ensure consistency. hidden_size = proposer.hidden_size @@ -617,32 +639,31 @@ def create_deterministic_logits(token_ids, k: int): model_mock = mock.MagicMock() # Mock the model forward calls. - forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device), - torch.zeros(total_tokens, hidden_size, device=device))] + forward_returns = [ + ( + torch.zeros(total_tokens, hidden_size, device=device), + torch.zeros(total_tokens, hidden_size, device=device), + ) + ] for cu_num_drafts in proposer.cu_drafts_per_level: - h_logits = torch.zeros(batch_size * cu_num_drafts, - hidden_size, - device=device) - h_states = torch.zeros(batch_size * cu_num_drafts, - hidden_size, - device=device) + h_logits = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device) + h_states = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device) forward_returns.append((h_logits, h_states)) model_mock.side_effect = forward_returns # Mock the compute_logits calls. - cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level, - dtype=torch.int32, - device=device) + cu_num_drafts_tensor = torch.tensor( + [0] + proposer.cu_drafts_per_level, dtype=torch.int32, device=device + ) logits_returns = [] for level, num_children in enumerate(proposer.child_drafts_per_level): token_ids = base_token_ids + cu_num_drafts_tensor[level] - level_num_drafts = cu_num_drafts_tensor[ - level + 1] - cu_num_drafts_tensor[level] + level_num_drafts = cu_num_drafts_tensor[level + 1] - cu_num_drafts_tensor[level] level_logits = [] for i in range(level_num_drafts // num_children): level_logits.append( - create_deterministic_logits(token_ids + i * num_children, - num_children)) + create_deterministic_logits(token_ids + i * num_children, num_children) + ) logits_returns.append(torch.stack(level_logits, dim=1)) model_mock.compute_logits.side_effect = logits_returns @@ -664,29 +685,23 @@ def create_deterministic_logits(token_ids, k: int): # Mock runner for attention metadata building. 
proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builders = [ - attn_metadata_builder - ] - proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \ - attn_metadata_builder + proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder] + proposer.runner.attn_groups[0][ + 0 + ].get_metadata_builder.return_value = attn_metadata_builder proposer._get_attention_metadata_builder = mock.MagicMock( - return_value=attn_metadata_builder) + return_value=attn_metadata_builder + ) # Setup inputs for the proposer. - target_token_ids = torch.randint(0, - vocab_size, (total_tokens, ), - device=device) - target_positions = torch.cat([ - torch.arange(seq_len_1, device=device), - torch.arange(seq_len_2, device=device) - ]) - target_hidden_states = torch.randn(total_tokens, - hidden_size, - device=device) - next_token_ids = torch.randint(0, - vocab_size, (batch_size, ), - dtype=torch.int32, - device=device) + target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) + target_positions = torch.cat( + [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)] + ) + target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) batch_spec = BatchSpec( seq_lens=seq_lens, query_lens=seq_lens, @@ -699,19 +714,22 @@ def create_deterministic_logits(token_ids, k: int): sampling_metadata = mock.MagicMock() # Propose draft tokens. - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=None, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata) + result = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + ) assert result.shape == (batch_size, num_speculative_tokens) # The tokens are expected to be consecutive integers starting # from the base token IDs. expected_tokens = base_token_ids[:, None] + torch.arange( - num_speculative_tokens, dtype=torch.int64, device=device) + num_speculative_tokens, dtype=torch.int64, device=device + ) # Verify that the draft tokens match our expectations. 
assert torch.equal(result, expected_tokens) diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py index f93593f2d482..647887812f8a 100644 --- a/tests/v1/spec_decode/test_max_len.py +++ b/tests/v1/spec_decode/test_max_len.py @@ -33,17 +33,19 @@ def test_ngram_max_len(num_speculative_tokens: int): @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) -def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch, - num_speculative_tokens: int, attn_backend: str): +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) +def test_eagle_max_len( + monkeypatch: pytest.MonkeyPatch, num_speculative_tokens: int, attn_backend: str +): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): m.setenv("VLLM_ROCM_USE_AITER", "1") diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index dc4a56c66de6..d7d9ef07e46c 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -6,13 +6,22 @@ import pytest import torch -from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, - create_standard_kv_cache_spec, - get_attention_backend) +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + get_attention_backend, +) from vllm.attention.backends.registry import _Backend -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VllmConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) from vllm.config.load import LoadConfig from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform @@ -23,10 +32,9 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: """Create an MTP proposer with unified model configuration.""" - model_config = ModelConfig(model=mimo_7b_dir, - runner="generate", - max_model_len=100, - trust_remote_code=True) + model_config = ModelConfig( + model=mimo_7b_dir, runner="generate", max_model_len=100, trust_remote_code=True + ) speculative_config = SpeculativeConfig( target_model_config=model_config, @@ -43,17 +51,16 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig()) + scheduler_config=SchedulerConfig(), + ) - return EagleProposer(vllm_config=vllm_config, - device=current_platform.device_type) + return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) -@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') -@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') -@mock.patch('vllm.v1.spec_decode.eagle.get_model') -def test_mtp_load_model_unified(mock_get_model, mock_get_layers, - mock_get_pp_group): 
+@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") +@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") +@mock.patch("vllm.v1.spec_decode.eagle.get_model") +def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_group): """Test MTP-specific model loading with unified model approach.""" # Setup mocks @@ -67,8 +74,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, all_indexer_layers: dict = {} mock_get_layers.side_effect = [ - target_attn_layers, target_indexer_layers, all_attn_layers, - all_indexer_layers + target_attn_layers, + target_indexer_layers, + all_attn_layers, + all_indexer_layers, ] mock_pp_group = mock.MagicMock() @@ -116,17 +125,13 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch): # MTP returns hidden states directly if num_speculative_tokens == 1: - model_mock.return_value = torch.zeros(total_tokens, - hidden_size, - device=device) + model_mock.return_value = torch.zeros(total_tokens, hidden_size, device=device) else: # Multiple forward passes for multi-token speculation forward_returns = [] for i in range(num_speculative_tokens): if i == 0: - h_states = torch.zeros(total_tokens, - hidden_size, - device=device) + h_states = torch.zeros(total_tokens, hidden_size, device=device) else: h_states = torch.zeros(batch_size, hidden_size, device=device) forward_returns.append(h_states) @@ -140,7 +145,8 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): if num_speculative_tokens == 1: model_mock.compute_logits.return_value = create_deterministic_logits( - batch_size, vocab_size, 42) + batch_size, vocab_size, 42 + ) else: logits_returns = [ create_deterministic_logits(batch_size, vocab_size, 42 + i) @@ -153,24 +159,21 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): # Prepare inputs batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) - common_attn_metadata = create_common_attn_metadata(batch_spec, - block_size=16, - device=device) - - target_token_ids = torch.randint(0, - vocab_size, (total_tokens, ), - device=device) - target_positions = torch.cat([ - torch.arange(seq_lens[0], device=device), - torch.arange(seq_lens[1], device=device) - ]) - target_hidden_states = torch.randn(total_tokens, - hidden_size, - device=device) - next_token_ids = torch.randint(0, - vocab_size, (batch_size, ), - dtype=torch.int32, - device=device) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + + target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) + target_positions = torch.cat( + [ + torch.arange(seq_lens[0], device=device), + torch.arange(seq_lens[1], device=device), + ] + ) + target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) sampling_metadata = mock.MagicMock() # Setup attention metadata @@ -187,13 +190,15 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): proposer.attn_metadata_builder = attn_metadata_builder # Run propose - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=None, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata) + result = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + 
target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + ) # Verify the model was called correctly assert model_mock.called diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 344d19c60db7..692c39282c37 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -4,77 +4,75 @@ from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig from vllm.v1.spec_decode.ngram_proposer import ( - NgramProposer, _find_longest_matched_ngram_and_propose_tokens) + NgramProposer, + _find_longest_matched_ngram_and_propose_tokens, +) def test_find_longest_matched_ngram_and_propose_tokens(): tokens = np.array([1, 2, 3, 4, 1, 2, 3, 5, 6]) result = _find_longest_matched_ngram_and_propose_tokens( - origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=2) + origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2 + ) assert len(result) == 0 tokens = np.array([1, 2, 3, 4, 1, 2, 3]) np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=3), - np.array([4, 1, 2])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3 + ), + np.array([4, 1, 2]), + ) np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=2), np.array([4, 1])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=2 + ), + np.array([4, 1]), + ) np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=1, - max_ngram=1, - max_model_len=1024, - k=3), - np.array([4, 1, 2])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=3 + ), + np.array([4, 1, 2]), + ) np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=1, - max_ngram=1, - max_model_len=1024, - k=2), np.array([4, 1])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2 + ), + np.array([4, 1]), + ) tokens = np.array([1, 3, 6, 2, 3, 4, 1, 2, 3]) np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=2, - max_ngram=2, - max_model_len=1024, - k=3), - np.array([4, 1, 2])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=2, max_ngram=2, max_model_len=1024, k=3 + ), + np.array([4, 1, 2]), + ) # Return on the first match np.testing.assert_array_equal( - _find_longest_matched_ngram_and_propose_tokens(origin_tokens=tokens, - min_ngram=1, - max_ngram=1, - max_model_len=1024, - k=2), np.array([6, 2])) + _find_longest_matched_ngram_and_propose_tokens( + origin_tokens=tokens, min_ngram=1, max_ngram=1, max_model_len=1024, k=2 + ), + np.array([6, 2]), + ) def test_ngram_proposer(): - def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: # Dummy model config. Just to set max_model_len. 
model_config = ModelConfig(model="facebook/opt-125m") return NgramProposer( - vllm_config=VllmConfig(model_config=model_config, - speculative_config=SpeculativeConfig( - prompt_lookup_min=min_n, - prompt_lookup_max=max_n, - num_speculative_tokens=k, - method="ngram", - ))) + vllm_config=VllmConfig( + model_config=model_config, + speculative_config=SpeculativeConfig( + prompt_lookup_min=min_n, + prompt_lookup_max=max_n, + num_speculative_tokens=k, + method="ngram", + ), + ) + ) # No match. token_ids_cpu = np.array([[1, 2, 3, 4, 5]]) @@ -133,8 +131,7 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: assert np.array_equal(result, np.array([[1, 2]])) # Not [5, 2]] # Multiple 3-gram matched, but always pick the first one. - token_ids_cpu = np.array( - [[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) + token_ids_cpu = np.array([[1, 2, 3, 100, 1, 2, 3, 200, 1, 2, 3, 300, 1, 2, 3]]) result = get_ngram_proposer(min_n=3, max_n=3, k=2).propose( sampled_token_ids=[[0]], req_ids=["0"], @@ -191,6 +188,5 @@ def get_ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: spec_decode_unsupported_reqs=(), ) assert len(result[0]) == 2 - assert np.array_equal(result[0], - np.array([middle_integer + 2, middle_integer + 3])) + assert np.array_equal(result[0], np.array([middle_integer + 2, middle_integer + 3])) assert np.array_equal(result[1], np.array([])) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index ebb9a3d97861..a46e8e3ec755 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -6,9 +6,11 @@ import torch -from tests.v1.attention.utils import (create_standard_kv_cache_spec, - create_vllm_config, - get_attention_backend) +from tests.v1.attention.utils import ( + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend, +) from vllm.attention.backends.registry import _Backend from vllm.config import ParallelConfig, SpeculativeConfig from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -42,10 +44,11 @@ def forward_attention( num_kv_heads = k.shape[-2] # Initialize the query and KV sequence lengths. query_start_loc = q_len * torch.arange( - batch_size + 1, device=q.device, dtype=torch.int32) + batch_size + 1, device=q.device, dtype=torch.int32 + ) query_lens = torch.diff(query_start_loc) seq_lens = torch.full( - (batch_size, ), + (batch_size,), seqlen_k, device=q.device, dtype=torch.int32, @@ -55,14 +58,13 @@ def forward_attention( max_query_len = q_len num_actual_tokens = query_start_loc[-1] - softmax_scale = q.shape[-1]**(-0.5) + softmax_scale = q.shape[-1] ** (-0.5) layer = MockAttentionLayer() # Build common metadata. model_name = "meta-llama/Meta-Llama-3-8B" builder_cls, impl_cls = get_attention_backend(backend) - vllm_config = create_vllm_config(model_name=model_name, - max_model_len=max(seq_lens)) + vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens)) if spec_token_tree is not None: # Create speculative config if token tree is specified. 
vllm_config.speculative_config = SpeculativeConfig( @@ -71,7 +73,8 @@ def forward_attention( model=model_name, method="eagle", num_speculative_tokens=num_spec_tokens, - speculative_token_tree=spec_token_tree) + speculative_token_tree=spec_token_tree, + ) kv_cache_spec = create_standard_kv_cache_spec(vllm_config) builder = builder_cls(kv_cache_spec, [], vllm_config, q.device) common_attn_metadata = CommonAttentionMetadata( @@ -128,8 +131,7 @@ def test_tree_attn_correctness() -> None: device = "cuda" tree_attn_masks = { # Chain. - "[(0,), (0, 0), (0, 0, 0)]": - torch.tensor( + "[(0,), (0, 0), (0, 0, 0)]": torch.tensor( [ [1, 0, 0, 0], [1, 1, 0, 0], @@ -140,8 +142,7 @@ def test_tree_attn_correctness() -> None: dtype=torch.int32, ), # Tree. - "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": - torch.tensor( + "[(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]": torch.tensor( [ [1, 0, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0], @@ -202,8 +203,7 @@ def test_tree_attn_correctness() -> None: device=q.device, dtype=torch.bfloat16, ) - num_alloc_blocks_per_batch = math.ceil(seqlen_k / - block_size) + num_alloc_blocks_per_batch = math.ceil(seqlen_k / block_size) block_table = torch.zeros( (batch_size, max_blocks_per_batch), device=q.device, @@ -217,11 +217,10 @@ def test_tree_attn_correctness() -> None: ) if randomize_blocks: # Randomize the block ids. - block_ids = block_ids[torch.randperm( - block_ids.numel())] - block_table[:, : - num_alloc_blocks_per_batch] = block_ids.view( - -1, num_alloc_blocks_per_batch) + block_ids = block_ids[torch.randperm(block_ids.numel())] + block_table[:, :num_alloc_blocks_per_batch] = block_ids.view( + -1, num_alloc_blocks_per_batch + ) # Set up the slot mapping for the input KVs. tree_positions = sequence_position + torch.arange( @@ -231,7 +230,8 @@ def test_tree_attn_correctness() -> None: dtype=torch.int64, ).repeat(batch_size, 1) tree_slot_mapping = _gen_slot_mapping( - tree_positions, block_table, block_size) + tree_positions, block_table, block_size + ) # Compute attention for the tree. tree_attn_output = forward_attention( @@ -253,8 +253,7 @@ def test_tree_attn_correctness() -> None: for q_index in range(tree_size_q): # Get the q, k, and v for the branch. branch_mask = tree_attn_mask[q_index, :] - branch_indices = torch.nonzero(branch_mask, - as_tuple=True)[0] + branch_indices = torch.nonzero(branch_mask, as_tuple=True)[0] q_len = branch_indices.shape[0] q_branch = q[:, branch_indices] k_branch = k[:, branch_indices] @@ -268,7 +267,8 @@ def test_tree_attn_correctness() -> None: dtype=torch.int64, ).repeat(batch_size, 1) branch_slot_mapping = _gen_slot_mapping( - branch_positions, block_table, block_size) + branch_positions, block_table, block_size + ) # Compute flash attention for the branch. flash_attn_output = forward_attention( @@ -287,16 +287,19 @@ def test_tree_attn_correctness() -> None: tree_attn_output[:, branch_indices], flash_attn_output, atol=7.81e-3, - ), (f"outputs are not close for " + ), ( + f"outputs are not close for " f"batch_size: {batch_size}, " f"num_heads: {num_heads}, " f"sequence_position: {sequence_position}, " f"tree_attn_mask: {tree_attn_mask}, " - f"q_index: {q_index}.") + f"q_index: {q_index}." 
+ ) -def _gen_slot_mapping(positions: torch.Tensor, block_table: torch.Tensor, - block_size: int): +def _gen_slot_mapping( + positions: torch.Tensor, block_table: torch.Tensor, block_size: int +): block_indices = positions // block_size blocks = block_table.gather(dim=1, index=block_indices) return (blocks * block_size + positions % block_size).view(-1) diff --git a/tests/v1/structured_output/test_utils.py b/tests/v1/structured_output/test_utils.py index 0e2658304d12..b285658af3d1 100644 --- a/tests/v1/structured_output/test_utils.py +++ b/tests/v1/structured_output/test_utils.py @@ -4,7 +4,8 @@ import pytest from vllm.v1.structured_output.backend_xgrammar import ( - has_xgrammar_unsupported_json_features) + has_xgrammar_unsupported_json_features, +) pytestmark = pytest.mark.cpu_test @@ -12,82 +13,41 @@ @pytest.fixture def unsupported_string_schemas(): return [ - { - "type": "string", - "format": "email" - }, + {"type": "string", "format": "email"}, ] @pytest.fixture def unsupported_integer_schemas(): return [ - { - "type": "integer", - "multipleOf": 120 - }, + {"type": "integer", "multipleOf": 120}, ] @pytest.fixture def unsupported_number_schemas(): return [ - { - "type": "number", - "multipleOf": 120 - }, + {"type": "number", "multipleOf": 120}, ] @pytest.fixture def unsupported_array_schemas(): return [ - { - "type": "array", - "uniqueItems": True - }, - { - "type": "array", - "contains": { - "type": "string" - } - }, - { - "type": "array", - "minContains": 1 - }, - { - "type": "array", - "maxContains": 5 - }, + {"type": "array", "uniqueItems": True}, + {"type": "array", "contains": {"type": "string"}}, + {"type": "array", "minContains": 1}, + {"type": "array", "maxContains": 5}, ] @pytest.fixture def unsupported_object_schemas(): return [ - { - "type": "object", - "minProperties": 1 - }, - { - "type": "object", - "maxProperties": 5 - }, - { - "type": "object", - "propertyNames": { - "pattern": "^[a-z]+$" - } - }, - { - "type": "object", - "patternProperties": { - "^S": { - "type": "string" - } - } - }, + {"type": "object", "minProperties": 1}, + {"type": "object", "maxProperties": 5}, + {"type": "object", "propertyNames": {"pattern": "^[a-z]+$"}}, + {"type": "object", "patternProperties": {"^S": {"type": "string"}}}, ] @@ -96,75 +56,50 @@ def supported_schema(): return { "type": "object", "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, - "status": { - "type": "string" - }, - "scores": { - "type": "array", - "items": { - "type": "number" - } - }, - "car_type": { - "type": "string", - "enum": ["sedan", "suv", "truck"] - }, - "car_brand": { - "type": "string", - "pattern": "^[a-zA-Z]+$" - }, - "short_description": { - "type": "string", - "maxLength": 50 - }, - "mileage": { - "type": "number", - "minimum": 0, - "maximum": 1000000 - }, + "name": {"type": "string"}, + "age": {"type": "integer"}, + "status": {"type": "string"}, + "scores": {"type": "array", "items": {"type": "number"}}, + "car_type": {"type": "string", "enum": ["sedan", "suv", "truck"]}, + "car_brand": {"type": "string", "pattern": "^[a-zA-Z]+$"}, + "short_description": {"type": "string", "maxLength": 50}, + "mileage": {"type": "number", "minimum": 0, "maximum": 1000000}, "model_year": { "type": "integer", "exclusiveMinimum": 1900, - "exclusiveMaximum": 2100 - }, - "long_description": { - "type": "string", - "minLength": 50, - "maxLength": 2000 + "exclusiveMaximum": 2100, }, + "long_description": {"type": "string", "minLength": 50, "maxLength": 2000}, "address": { "type": "object", 
"properties": { - "street": { - "type": "string" - }, - "city": { - "type": "string" - } - } - } - } + "street": {"type": "string"}, + "city": {"type": "string"}, + }, + }, + }, } -@pytest.mark.parametrize("schema_type", [ - "unsupported_string_schemas", "unsupported_integer_schemas", - "unsupported_number_schemas", "unsupported_array_schemas", - "unsupported_object_schemas" -]) +@pytest.mark.parametrize( + "schema_type", + [ + "unsupported_string_schemas", + "unsupported_integer_schemas", + "unsupported_number_schemas", + "unsupported_array_schemas", + "unsupported_object_schemas", + ], +) def test_unsupported_json_features_by_type(schema_type, request): schemas = request.getfixturevalue(schema_type) for schema in schemas: - assert has_xgrammar_unsupported_json_features( - schema), f"Schema should be unsupported: {schema}" + assert has_xgrammar_unsupported_json_features(schema), ( + f"Schema should be unsupported: {schema}" + ) def test_supported_json_features(supported_schema): - assert not has_xgrammar_unsupported_json_features( - supported_schema), "Schema should be supported" + assert not has_xgrammar_unsupported_json_features(supported_schema), ( + "Schema should be supported" + ) diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 74aa20a2f7f9..5d3bb924590a 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -17,7 +17,6 @@ def test_reject_bad_config(monkeypatch): def test_unsupported_configs(monkeypatch): - with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index 5d467687c308..a306a2b040d3 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -9,18 +9,21 @@ import pytest import torch -from vllm.multimodal.inputs import (MultiModalBatchedField, - MultiModalFieldElem, MultiModalFlatField, - MultiModalKwargsItem, - MultiModalKwargsItems, - MultiModalSharedField, NestedTensors) +from vllm.multimodal.inputs import ( + MultiModalBatchedField, + MultiModalFieldElem, + MultiModalFlatField, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalSharedField, + NestedTensors, +) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder pytestmark = pytest.mark.cpu_test class UnrecognizedType(UserDict): - def __init__(self, an_int: int): super().__init__() self.an_int = an_int @@ -47,10 +50,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") obj = MyType( - tensor1=torch.randint(low=0, - high=100, - size=(1024, ), - dtype=torch.int32), + tensor1=torch.randint(low=0, high=100, size=(1024,), dtype=torch.int32), a_string="hello", list_of_tensors=[ torch.rand((1, 10), dtype=torch.float32), @@ -58,8 +58,9 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch): torch.tensor(1984), # test scalar too # Make sure to test bf16 which numpy doesn't support. 
torch.rand((3, 5, 1000), dtype=torch.bfloat16), - torch.tensor([float("-inf"), float("inf")] * 1024, - dtype=torch.bfloat16), + torch.tensor( + [float("-inf"), float("inf")] * 1024, dtype=torch.bfloat16 + ), ], numpy_array=np.arange(512), unrecognized=UnrecognizedType(33), @@ -103,22 +104,24 @@ class MyRequest(msgspec.Struct): def test_multimodal_kwargs(): - e1 = MultiModalFieldElem("audio", "a0", - torch.zeros(1000, dtype=torch.bfloat16), - MultiModalBatchedField()) + e1 = MultiModalFieldElem( + "audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField() + ) e2 = MultiModalFieldElem( "video", "v0", [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], - MultiModalFlatField( - [[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0), + MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0), + ) + e3 = MultiModalFieldElem( + "image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4) ) - e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000, - dtype=torch.int32), - MultiModalSharedField(4)) e4 = MultiModalFieldElem( - "image", "i1", torch.zeros(1000, dtype=torch.int32), - MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2)) + "image", + "i1", + torch.zeros(1000, dtype=torch.int32), + MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2), + ) audio = MultiModalKwargsItem.from_elems([e1]) video = MultiModalKwargsItem.from_elems([e2]) image = MultiModalKwargsItem.from_elems([e3, e4]) @@ -164,16 +167,14 @@ def assert_equal(obj1: MyType, obj2: MyType): assert torch.equal(obj1.tensor1, obj2.tensor1) assert obj1.a_string == obj2.a_string assert all( - torch.equal(a, b) - for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors)) + torch.equal(a, b) for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors) + ) assert np.array_equal(obj1.numpy_array, obj2.numpy_array) assert obj1.unrecognized.an_int == obj2.unrecognized.an_int assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor) assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor) - assert torch.equal(obj1.small_non_contig_tensor, - obj2.small_non_contig_tensor) - assert torch.equal(obj1.large_non_contig_tensor, - obj2.large_non_contig_tensor) + assert torch.equal(obj1.small_non_contig_tensor, obj2.small_non_contig_tensor) + assert torch.equal(obj1.large_non_contig_tensor, obj2.large_non_contig_tensor) assert torch.equal(obj1.empty_tensor, obj2.empty_tensor) @@ -210,8 +211,9 @@ def test_tensor_serialization(): decoded = decoder.decode(encoded) # Verify the decoded tensor matches the original - assert torch.allclose( - tensor, decoded), "Decoded tensor does not match the original tensor." + assert torch.allclose(tensor, decoded), ( + "Decoded tensor does not match the original tensor." + ) def test_numpy_array_serialization(): @@ -229,13 +231,12 @@ def test_numpy_array_serialization(): decoded = decoder.decode(encoded) # Verify the decoded array matches the original - assert np.allclose( - array, - decoded), "Decoded numpy array does not match the original array." + assert np.allclose(array, decoded), ( + "Decoded numpy array does not match the original array." 
+ ) class CustomClass: - def __init__(self, value): self.value = value @@ -244,7 +245,8 @@ def __eq__(self, other): def test_custom_class_serialization_allowed_with_pickle( - monkeypatch: pytest.MonkeyPatch): + monkeypatch: pytest.MonkeyPatch, +): """Test that serializing a custom class succeeds when allow_pickle=True.""" with monkeypatch.context() as m: @@ -261,8 +263,7 @@ def test_custom_class_serialization_allowed_with_pickle( decoded = decoder.decode(encoded) # Verify the decoded object matches the original - assert obj == decoded, ( - "Decoded object does not match the original object.") + assert obj == decoded, "Decoded object does not match the original object." def test_custom_class_serialization_disallowed_without_pickle(): diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py index 865b58bc7f4b..1518987ded04 100644 --- a/tests/v1/tpu/test_basic.py +++ b/tests/v1/tpu/test_basic.py @@ -4,6 +4,7 @@ Run `pytest tests/v1/tpu/test_basic.py`. """ + from __future__ import annotations from typing import TYPE_CHECKING @@ -32,8 +33,9 @@ # TENSOR_PARALLEL_SIZES = [1, 4] -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a basic test for TPU only") +@pytest.mark.skipif( + not current_platform.is_tpu(), reason="This is a basic test for TPU only" +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) @@ -46,32 +48,36 @@ def test_basic( tensor_parallel_size: int, max_num_seqs: int, ) -> None: - prompt = "The next numbers of the sequence " + ", ".join( - str(i) for i in range(1024)) + " are:" + prompt = ( + "The next numbers of the sequence " + + ", ".join(str(i) for i in range(1024)) + + " are:" + ) example_prompts = [prompt] with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") with vllm_runner( - model, - # Note: max_num_batched_tokens == 1024 is needed here to - # actually test chunked prompt - max_num_batched_tokens=1024, - max_model_len=8192, - gpu_memory_utilization=0.7, - max_num_seqs=max_num_seqs, - tensor_parallel_size=tensor_parallel_size) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) + model, + # Note: max_num_batched_tokens == 1024 is needed here to + # actually test chunked prompt + max_num_batched_tokens=1024, + max_model_len=8192, + gpu_memory_utilization=0.7, + max_num_seqs=max_num_seqs, + tensor_parallel_size=tensor_parallel_size, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) output = vllm_outputs[0][1] assert "1024" in output or "0, 1" in output @pytest.mark.skip(reason="Temporarily disabled due to timeout") -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a basic test for TPU only") +@pytest.mark.skipif( + not current_platform.is_tpu(), reason="This is a basic test for TPU only" +) @pytest.mark.parametrize("max_tokens", [8]) @pytest.mark.parametrize("max_num_seqs", [16]) def test_phi3( @@ -96,9 +102,9 @@ def test_phi3( with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - with vllm_runner(model, - max_num_batched_tokens=256, - max_num_seqs=max_num_seqs) as vllm_model: + with vllm_runner( + model, max_num_batched_tokens=256, max_num_seqs=max_num_seqs + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) # vllm_outputs is a list of tuples whose first element is the token id # and the second element is the output (including the prompt). 
@@ -110,10 +116,11 @@ def test_phi3( TP_SIZE_8 = 8 -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a test for TPU only") -@pytest.mark.skipif(tpu.num_available_chips() < TP_SIZE_8, - reason=f"This test requires {TP_SIZE_8} TPU chips.") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only") +@pytest.mark.skipif( + tpu.num_available_chips() < TP_SIZE_8, + reason=f"This test requires {TP_SIZE_8} TPU chips.", +) def test_gemma3_27b_with_text_input_and_tp( vllm_runner: type[VllmRunner], monkeypatch: pytest.MonkeyPatch, @@ -137,10 +144,11 @@ def test_gemma3_27b_with_text_input_and_tp( m.setenv("VLLM_USE_V1", "1") with vllm_runner( - model, - max_num_batched_tokens=256, - max_num_seqs=max_num_seqs, - tensor_parallel_size=tensor_parallel_size) as vllm_model: + model, + max_num_batched_tokens=256, + max_num_seqs=max_num_seqs, + tensor_parallel_size=tensor_parallel_size, + ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) # vllm_outputs is a list of tuples whose first element is the token id # and the second element is the output (including the prompt). @@ -149,8 +157,9 @@ def test_gemma3_27b_with_text_input_and_tp( assert answer in generated_text -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a basic test for TPU only") +@pytest.mark.skipif( + not current_platform.is_tpu(), reason="This is a basic test for TPU only" +) def test_w8a8_quantization( vllm_runner: type[VllmRunner], monkeypatch: pytest.MonkeyPatch, @@ -160,22 +169,25 @@ def test_w8a8_quantization( tensor_parallel_size = 1 max_num_seqs = 4 - prompt = "The next numbers of the sequence " + ", ".join( - str(i) for i in range(1024)) + " are:" + prompt = ( + "The next numbers of the sequence " + + ", ".join(str(i) for i in range(1024)) + + " are:" + ) example_prompts = [prompt] with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") with vllm_runner( - model, - max_num_batched_tokens=64, - max_model_len=4096, - gpu_memory_utilization=0.7, - max_num_seqs=max_num_seqs, - tensor_parallel_size=tensor_parallel_size) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens) + model, + max_num_batched_tokens=64, + max_model_len=4096, + gpu_memory_utilization=0.7, + max_num_seqs=max_num_seqs, + tensor_parallel_size=tensor_parallel_size, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) output = vllm_outputs[0][1] assert "1024" in output or "0, 1" in output diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py index acb607247d75..99d5f98351ad 100644 --- a/tests/v1/tpu/test_kv_cache_update_kernel.py +++ b/tests/v1/tpu/test_kv_cache_update_kernel.py @@ -10,61 +10,69 @@ from vllm.platforms import current_platform -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a test for TPU only") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only") @pytest.mark.parametrize("page_size", [32, 33]) @pytest.mark.parametrize("combined_kv_head_num", [2, 16]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("num_slices_per_block", [4, 8]) -def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int, - head_dim: int, num_slices_per_block: int): +def test_kv_cache_update_kernel( + page_size: int, combined_kv_head_num: int, head_dim: int, num_slices_per_block: int +): page_num = 1000 padded_num_tokens = 128 kv_cache_cpu = torch.zeros( (page_num * 
page_size, combined_kv_head_num, head_dim), dtype=torch.bfloat16, - device="cpu") + device="cpu", + ) kv_cache_xla = kv_cache_cpu.to(torch_xla.device()) new_kv_cpu = torch.randn( (padded_num_tokens, combined_kv_head_num, head_dim), dtype=torch.bfloat16, - device="cpu") + device="cpu", + ) new_kv_xla = new_kv_cpu.to(torch_xla.device()) - slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], - dtype=np.int32) + slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32) num_kv_update_slices = len(slice_lens) - kv_cache_start_indices = np.array([ - page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6, - page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3 - ], - dtype=np.int32) + kv_cache_start_indices = np.array( + [ + page_size * 2 - 7, + page_size * 2, + page_size * 3, + page_size * 4 + 6, + page_size * 5 + 7, + page_size * 6 + 8, + page_size * 15 + 3, + ], + dtype=np.int32, + ) new_kv_cache_indices = np.concatenate( - [np.array([0], dtype=np.int32), - np.cumsum(slice_lens[:-1])]) + [np.array([0], dtype=np.int32), np.cumsum(slice_lens[:-1])] + ) slot_mapping = np.stack( - [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1) + [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1 + ) slot_mapping = np.transpose(slot_mapping) - slot_mapping_cpu = torch.tensor(slot_mapping, - device="cpu", - dtype=torch.int32) + slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32) slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) - num_kv_update_slices_xla = torch.tensor([num_kv_update_slices], - device=torch_xla.device(), - dtype=torch.int32) + num_kv_update_slices_xla = torch.tensor( + [num_kv_update_slices], device=torch_xla.device(), dtype=torch.int32 + ) torch_xla.sync() torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( - new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla, - page_size, num_slices_per_block) + new_kv_xla, + slot_mapping_xla, + kv_cache_xla, + num_kv_update_slices_xla, + page_size, + num_slices_per_block, + ) kv_cache_xla.copy_(new_kv_cache_xla) torch_xla.sync() - for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, - slice_lens): - kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :] + for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, slice_lens): + kv_cache_cpu[ci : ci + sl, :, :] = new_kv_cpu[ni : ni + sl, :, :] - assert torch.allclose(kv_cache_xla.cpu(), - kv_cache_cpu, - atol=1e-4, - rtol=1e-4) + assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu, atol=1e-4, rtol=1e-4) diff --git a/tests/v1/tpu/test_mha_attn.py b/tests/v1/tpu/test_mha_attn.py index 9d690851b70e..5debdf85bea8 100644 --- a/tests/v1/tpu/test_mha_attn.py +++ b/tests/v1/tpu/test_mha_attn.py @@ -19,8 +19,7 @@ @pytest.fixture(autouse=True) def clear_cache(): - """Clear lru cache to ensure each test case runs without caching. 
- """ + """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() @@ -49,8 +48,7 @@ def ref_attention( HEAD_SIZES = [64, 80] -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This test needs a TPU") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") @pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -68,19 +66,12 @@ def test_mha_attn_forward( current_platform.seed_everything(0) # These are expected to be f32 q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device) - k = torch.randn(batch_size, - seq_len, - num_kv_heads * head_size, - device=device) - v = torch.randn(batch_size, - seq_len, - num_kv_heads * head_size, - device=device) + k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) + v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) scale = 1.0 / head_size**0.5 - attn = MultiHeadAttention(num_heads, - head_size, - scale=scale, - num_kv_heads=num_kv_heads) + attn = MultiHeadAttention( + num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads + ) output = attn(q, k, v) assert num_heads % num_kv_heads == 0 diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py index 9947fcbe7313..5bf823417d4d 100644 --- a/tests/v1/tpu/test_multimodal.py +++ b/tests/v1/tpu/test_multimodal.py @@ -14,38 +14,32 @@ @pytest.fixture(scope="session") def base64_encoded_image(local_asset_server) -> dict[str, str]: return { - image_asset: - encode_image_base64(local_asset_server.get_image_asset(image_asset)) + image_asset: encode_image_base64( + local_asset_server.get_image_asset(image_asset) + ) for image_asset in TEST_IMAGE_ASSETS } @pytest.mark.asyncio -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This test needs a TPU") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") @pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"]) -async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, - str]): - +async def test_basic_vision(model_name: str, base64_encoded_image: dict[str, str]): pytest.skip("Skip this test until it's fixed.") def whats_in_this_image_msg(b64): - return [{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this image?" - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{b64}" + return [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{b64}"}, }, - }, - ], - }] + ], + } + ] server_args = [ "--max-model-len", @@ -62,19 +56,20 @@ def whats_in_this_image_msg(b64): ] # Server will pre-compile on first startup (takes a long time). 
- with RemoteOpenAIServer(model_name, server_args, - max_wait_seconds=600) as remote_server: + with RemoteOpenAIServer( + model_name, server_args, max_wait_seconds=600 + ) as remote_server: client: openai.AsyncOpenAI = remote_server.get_async_client() # Other requests now should be much faster for image_url in TEST_IMAGE_ASSETS: image_base64 = base64_encoded_image[image_url] - chat_completion_from_base64 = await client.chat.completions\ - .create( + chat_completion_from_base64 = await client.chat.completions.create( model=model_name, messages=whats_in_this_image_msg(image_base64), max_completion_tokens=24, - temperature=0.0) + temperature=0.0, + ) result = chat_completion_from_base64 assert result choice = result.choices[0] diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index 1bc8dff317a7..0a994e99bade 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -5,8 +5,7 @@ import torch from vllm.attention.backends.abstract import AttentionType -from vllm.v1.attention.backends.pallas import (PallasAttentionBackendImpl, - PallasMetadata) +from vllm.v1.attention.backends.pallas import PallasAttentionBackendImpl, PallasMetadata def test_ragged_paged_attention(): @@ -53,14 +52,14 @@ class FakeAttentionLayer: max_num_reqs = 8 max_num_blocks_per_req = 8 num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32) - block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req), - dtype=torch.int32) - context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32) + block_tables = torch.zeros( + (max_num_reqs, max_num_blocks_per_req), dtype=torch.int32 + ) + context_lens = torch.ones((max_num_reqs,), dtype=torch.int32) query_lens = [1] * max_num_reqs - query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.int32), - dim=0, - dtype=torch.int32) + query_start_loc = torch.cumsum( + torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32 + ) num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, @@ -72,8 +71,7 @@ class FakeAttentionLayer: num_slices_per_kv_cache_update_block=8, ) - with patch("torch.ops.xla.ragged_paged_attention" - ) as mock_ragged_paged_attention: + with patch("torch.ops.xla.ragged_paged_attention") as mock_ragged_paged_attention: attn_impl.forward( layer=layer, query=query, diff --git a/tests/v1/tpu/test_perf.py b/tests/v1/tpu/test_perf.py index f4a2d5ac853a..e8cc396f970e 100644 --- a/tests/v1/tpu/test_perf.py +++ b/tests/v1/tpu/test_perf.py @@ -4,6 +4,7 @@ Run `pytest tests/v1/tpu/test_perf.py`. 
""" + from __future__ import annotations import time @@ -37,7 +38,6 @@ class TestParams: # open(/dev/vfio/0): Device or resource busy: Device or resource busy; # Couldn't open iommu group /dev/vfio/0 # => Investigate - # TestParams( # model="Qwen/Qwen2.5-1.5B-Instruct", # num_prompts=1, @@ -59,16 +59,14 @@ class TestParams: num_prompts=64, prefix_len=500, decode_len=50, - # commit id: ccb246776d93ef105904a8ec015b3587240a1183 # tpu: v5lite (old vllm CI/CD) # expected_avg_time=1.4, # err_tol=0.30, - # (This is the active CI/CD instance) # commit id: ccb246776d93ef105904a8ec015b3587240a1183 # tpu: v6e (current vllm CI/CD) - expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH= + expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH= err_tol=0.20, ), ] @@ -81,44 +79,50 @@ class TestParams: GPU_UTIL = 0.9 -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This is a basic performance test for TPU only") +@pytest.mark.skipif( + not current_platform.is_tpu(), + reason="This is a basic performance test for TPU only", +) @pytest.mark.parametrize("params", TEST_PARAMS) def test_perf( vllm_runner: type[VllmRunner], monkeypatch: pytest.MonkeyPatch, params: TestParams, ) -> None: - tokenizer = get_tokenizer(params.model, - tokenizer_mode="auto", - trust_remote_code=True) + tokenizer = get_tokenizer( + params.model, tokenizer_mode="auto", trust_remote_code=True + ) prompts = [] for i in range(params.num_prompts): - prefix_token_ids = np.random.randint(0, - tokenizer.vocab_size, - size=params.prefix_len).tolist() + prefix_token_ids = np.random.randint( + 0, tokenizer.vocab_size, size=params.prefix_len + ).tolist() prompt = tokenizer.decode(prefix_token_ids) prompts.append(prompt) print( "-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format( - len(prompts), params.prefix_len, params.decode_len)) + len(prompts), params.prefix_len, params.decode_len + ) + ) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - sampling_params = SamplingParams(max_tokens=params.decode_len, - temperature=1.0, - min_p=0.0) - - with vllm_runner(params.model, - max_num_batched_tokens=MAX_MODEL_LEN, - max_model_len=MAX_MODEL_LEN, - max_num_seqs=MAX_NUM_SEQS, - gpu_memory_utilization=GPU_UTIL, - enforce_eager=False, - tensor_parallel_size=1) as vllm_model: + sampling_params = SamplingParams( + max_tokens=params.decode_len, temperature=1.0, min_p=0.0 + ) + + with vllm_runner( + params.model, + max_num_batched_tokens=MAX_MODEL_LEN, + max_model_len=MAX_MODEL_LEN, + max_num_seqs=MAX_NUM_SEQS, + gpu_memory_utilization=GPU_UTIL, + enforce_eager=False, + tensor_parallel_size=1, + ) as vllm_model: print(" -- Warmup / Compile") for i in range(NUM_WARMUPS): _ = vllm_model.generate(prompts, sampling_params) @@ -133,14 +137,18 @@ def test_perf( avg_time = sum(times) / len(times) print(" -- avg_time = {}".format(avg_time)) - print(" -- expected_avg_time = {} with err_tol = {}".format( - params.expected_avg_time, params.err_tol)) + print( + " -- expected_avg_time = {} with err_tol = {}".format( + params.expected_avg_time, params.err_tol + ) + ) diff = avg_time - params.expected_avg_time ok = diff < params.err_tol if diff < -params.err_tol: - print(" !! WARNING !! Performance has improved by {}, " - "it may be necessary to fine-tune the " - "expected_avg_time = {}".format( - -diff, params.expected_avg_time)) + print( + " !! WARNING !! Performance has improved by {}, " + "it may be necessary to fine-tune the " + "expected_avg_time = {}".format(-diff, params.expected_avg_time) + ) assert ok, " !! 
ERROR !! Regression detected" diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py index fa950e5f7f85..58f6292b05a7 100644 --- a/tests/v1/tpu/test_sampler.py +++ b/tests/v1/tpu/test_sampler.py @@ -10,21 +10,20 @@ @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"]) -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This test needs a TPU") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") def test_sampler_different(model_name: str): """ - Test significantly different sampling params to assert the model produces + Test significantly different sampling params to assert the model produces different results. """ - llm = LLM(model_name, - enforce_eager=False, - max_num_seqs=1, - max_model_len=512, - max_num_batched_tokens=256) - prompts = [ - "Write a short story about a robot that dreams for the first time." - ] + llm = LLM( + model_name, + enforce_eager=False, + max_num_seqs=1, + max_model_len=512, + max_num_batched_tokens=256, + ) + prompts = ["Write a short story about a robot that dreams for the first time."] sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64) output = llm.generate(prompts, sampling_params) @@ -47,7 +46,9 @@ def test_sampler_different(model_name: str): max_tokens=64, # Vary number of ks top_k=random.randint(4, 12), - top_p=random.random()) for _ in range(B) + top_p=random.random(), + ) + for _ in range(B) ] # Make sure first two reqs have the same K/P sampling_params[0] = sampling_params[1] @@ -61,20 +62,18 @@ def test_sampler_different(model_name: str): @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"]) # TODO TPU will appear busy if we fan-out test params here @pytest.mark.parametrize("n_prompts", [1]) -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="This test needs a TPU") +@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") def test_logprobs(model_name: str, n_prompts: int): """ Request top logprobs with different sampling settings and check - that results contains the requested number, ordered ascendingly. + that results contains the requested number, ordered ascendingly. """ def check_num_logprobs(logprobs, expected_num: int): for step in logprobs: prev_logp = 1.0 # order by rank - sorted_step = dict( - sorted(step.items(), key=lambda item: item[1].rank)) + sorted_step = dict(sorted(step.items(), key=lambda item: item[1].rank)) # Can contain the sampled token assert len(step) == expected_num or len(step) == expected_num + 1 @@ -84,23 +83,23 @@ def check_num_logprobs(logprobs, expected_num: int): prev_logp = logp.logprob assert logp.rank == rankno + 1 - llm = LLM(model_name, - enforce_eager=False, - max_num_seqs=1, - max_model_len=128, - max_num_batched_tokens=128) + llm = LLM( + model_name, + enforce_eager=False, + max_num_seqs=1, + max_model_len=128, + max_num_batched_tokens=128, + ) prompts = [ "Write a short story about a robot that dreams for the first time." 
] * n_prompts - greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,\ - logprobs=4) - regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\ - logprobs=4) - topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,\ - logprobs=4, top_k=12, top_p=0.5) + greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4) + regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64, logprobs=4) + topkp_sampling_params = SamplingParams( + temperature=0.4, max_tokens=64, logprobs=4, top_k=12, top_p=0.5 + ) - for sp in [greedy_sampling_params, regular_sampling_params, \ - topkp_sampling_params]: + for sp in [greedy_sampling_params, regular_sampling_params, topkp_sampling_params]: output = llm.generate(prompts, sp) for o in output: check_num_logprobs(o.outputs[0].logprobs, 4) diff --git a/tests/v1/tpu/test_spmd_model_weight_loading.py b/tests/v1/tpu/test_spmd_model_weight_loading.py index ad234df0c8ed..be866bf90a79 100644 --- a/tests/v1/tpu/test_spmd_model_weight_loading.py +++ b/tests/v1/tpu/test_spmd_model_weight_loading.py @@ -9,14 +9,18 @@ import torch_xla.runtime as xr from vllm.config import set_current_vllm_config -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + init_distributed_environment, +) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.model_loader.tpu import TPUModelLoader def _setup_environment(model): - engine_args = EngineArgs(model=model, ) + engine_args = EngineArgs( + model=model, + ) vllm_config = engine_args.create_engine_config() with set_current_vllm_config(vllm_config): temp_file = tempfile.mkstemp()[1] @@ -25,7 +29,8 @@ def _setup_environment(model): 0, local_rank=0, distributed_init_method=f"file://{temp_file}", - backend="gloo") + backend="gloo", + ) # Under single worker mode, full model is init first and then # partitioned using GSPMD. ensure_model_parallel_initialized(1, 1) @@ -42,7 +47,7 @@ def _get_spmd_mesh(): num_devices = xr.global_runtime_device_count() mesh_shape = (num_devices, 1) device_ids = np.array(range(num_devices)) - MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y')) + MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y")) return MESH @@ -53,15 +58,17 @@ def _get_spmd_mesh(): # Skip large models due to CI runner disk space limitations # "meta-llama/Llama-3.1-8B-Instruct", # "meta-llama/Llama-3.1-70B-Instruct", - ]) + ], +) def test_tpu_model_loader(model): # Skip the 70B test if there are less than 8 chips # TODO: Query using torch xla API, the query API is not working # with SPMD now. However, This test is running under SPMD mode. - if '70B' in model and xr.global_runtime_device_count() < 8: + if "70B" in model and xr.global_runtime_device_count() < 8: pytest.skip( "Skipping 70B model if the TPU VM has less than 8 chips to \ - avoid OOM.") + avoid OOM." 
+ ) vllm_config = _setup_environment(model) loader = TPUModelLoader(load_config=vllm_config.load_config) diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index 665cf8cd2629..c2fc24442c7c 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -10,8 +10,7 @@ from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p # isort: off -from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as - apply_top_k_top_p_tpu) +from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu # isort: on if not current_platform.is_tpu(): @@ -30,11 +29,10 @@ def test_topk_equivalence_to_native_impl(): logits = torch.rand((BATCH_SIZE, VOCAB_SIZE)) # Random top-k values between 1 and 10. - k = torch.randint(1, 10, (BATCH_SIZE, )) + k = torch.randint(1, 10, (BATCH_SIZE,)) # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). - k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), - VOCAB_SIZE) + k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE) result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None) @@ -50,15 +48,13 @@ def test_topp_result_sums_past_p(): probs = logits.softmax(dim=-1) # Random top-p values between 0 and 1. - p = torch.rand((BATCH_SIZE, )) + p = torch.rand((BATCH_SIZE,)) # Set p=1 for ~50% of requests in the batch (top-p disabled). - p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1) + p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1) no_op_k = torch.tensor([VOCAB_SIZE]) - logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), - k=no_op_k, - p=p) + logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p) # Verify that the masked logit's probability sums to at least p. probs.masked_fill_(logits_masked.isinf(), 0) @@ -72,16 +68,16 @@ def test_topp_result_sums_past_p(): def test_topp_basic(): with torch.device(xm.xla_device()): - logits = torch.tensor([[math.log(0.2), - math.log(0.3), - math.log(0.5)], - [math.log(0.5), - math.log(0.1), - math.log(0.4)]]) + logits = torch.tensor( + [ + [math.log(0.2), math.log(0.3), math.log(0.5)], + [math.log(0.5), math.log(0.1), math.log(0.4)], + ] + ) - result = apply_top_k_top_p_tpu(logits=logits.clone(), - k=torch.tensor([3, 3]), - p=torch.tensor([0.79, 0.79])) + result = apply_top_k_top_p_tpu( + logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79]) + ) torch_xla.sync() @@ -94,16 +90,16 @@ def test_topp_basic(): def test_topp_select_all(): with torch.device(xm.xla_device()): - logits = torch.tensor([[math.log(0.2), - math.log(0.3), - math.log(0.5)], - [math.log(0.5), - math.log(0.1), - math.log(0.4)]]) + logits = torch.tensor( + [ + [math.log(0.2), math.log(0.3), math.log(0.5)], + [math.log(0.5), math.log(0.1), math.log(0.4)], + ] + ) - result = apply_top_k_top_p_tpu(logits=logits.clone(), - k=torch.tensor([3, 3]), - p=torch.tensor([1.0, 1.0])) + result = apply_top_k_top_p_tpu( + logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0]) + ) torch_xla.sync() @@ -114,14 +110,12 @@ def test_topp_with_ties(): with torch.device(xm.xla_device()): # Input has multiple math.log(0.3). 
logits = torch.tensor( - [[math.log(0.3), - math.log(0.3), - math.log(0.3), - math.log(0.1)]]) + [[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]] + ) - result = apply_top_k_top_p_tpu(logits=logits.clone(), - k=torch.tensor([4]), - p=torch.tensor([0.2])) + result = apply_top_k_top_p_tpu( + logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2]) + ) torch_xla.sync() @@ -135,17 +129,17 @@ def test_topp_with_ties(): def test_both_topk_topp(): with torch.device(xm.xla_device()): - logits = torch.tensor([[math.log(0.2), - math.log(0.3), - math.log(0.5)], - [math.log(0.5), - math.log(0.1), - math.log(0.4)]]) + logits = torch.tensor( + [ + [math.log(0.2), math.log(0.3), math.log(0.5)], + [math.log(0.5), math.log(0.1), math.log(0.4)], + ] + ) # Set k=1 for the first batch. - result = apply_top_k_top_p_tpu(logits=logits.clone(), - k=torch.tensor([1, 3]), - p=torch.tensor([0.79, 0.79])) + result = apply_top_k_top_p_tpu( + logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79]) + ) torch_xla.sync() diff --git a/tests/v1/tpu/test_tpu_int8.py b/tests/v1/tpu/test_tpu_int8.py index f39a8021a29e..50001567a958 100644 --- a/tests/v1/tpu/test_tpu_int8.py +++ b/tests/v1/tpu/test_tpu_int8.py @@ -4,11 +4,11 @@ Run `pytest tests/quantization/test_tpu_int8.py`. """ + import pytest from vllm.model_executor.layers.linear import LinearBase -from vllm.model_executor.layers.quantization.tpu_int8 import ( - TPUInt8LinearMethod) +from vllm.model_executor.layers.quantization.tpu_int8 import TPUInt8LinearMethod from vllm.platforms import current_platform from ...models.registry import HF_EXAMPLE_MODELS @@ -16,8 +16,9 @@ MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] -@pytest.mark.skipif(not current_platform.is_tpu(), - reason="TPU Int8 is only enabled for TPUs.") +@pytest.mark.skipif( + not current_platform.is_tpu(), reason="TPU Int8 is only enabled for TPUs." 
+) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [10]) @@ -26,20 +27,28 @@ [ # w8a8 dynamic activation { - 'quantization_config': { - 'quant_method': 'tpu_int8', - 'activation_scheme': 'dynamic' + "quantization_config": { + "quant_method": "tpu_int8", + "activation_scheme": "dynamic", } } - ]) -def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int, - hf_overrides: dict, monkeypatch) -> None: + ], +) +def test_model_tpu_int8( + vllm_runner, + model: str, + dtype: str, + max_tokens: int, + hf_overrides: dict, + monkeypatch, +) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_transformers_version(on_fail="skip") - activation_scheme = hf_overrides.get('quantization_config', - {}).get('activation_scheme') - quantize_activation = activation_scheme == 'dynamic' + activation_scheme = hf_overrides.get("quantization_config", {}).get( + "activation_scheme" + ) + quantize_activation = activation_scheme == "dynamic" # Allows using apply_model monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") diff --git a/tests/v1/tpu/test_tpu_qkv_linear.py b/tests/v1/tpu/test_tpu_qkv_linear.py index 46fa1193881f..098d92550542 100644 --- a/tests/v1/tpu/test_tpu_qkv_linear.py +++ b/tests/v1/tpu/test_tpu_qkv_linear.py @@ -9,8 +9,10 @@ import torch_xla.runtime as xr from vllm.config import set_current_vllm_config -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, + init_distributed_environment, +) from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.layers.linear import QKVParallelLinear @@ -36,7 +38,8 @@ def setup_environment(): 0, local_rank=0, distributed_init_method=f"file://{temp_file}", - backend="gloo") + backend="gloo", + ) ensure_model_parallel_initialized(1, 1) yield @@ -51,7 +54,7 @@ def _get_spmd_mesh(): num_devices = xr.global_runtime_device_count() mesh_shape = (num_devices, 1) device_ids = np.array(range(num_devices)) - MESH = xs.Mesh(device_ids, mesh_shape, ('x', 'y')) + MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y")) return MESH @@ -59,7 +62,7 @@ def _get_spmd_mesh(): # `xr.use_spmd()` will set a global state, and this state is not reversible. # Therefore, non-SPMD tests should be run before SPMD tests. 
@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()]) -@pytest.mark.parametrize("device", ['cpu', 'xla']) +@pytest.mark.parametrize("device", ["cpu", "xla"]) @torch.no_grad() def test_xla_qkv_linear(bias, mesh, device): torch.manual_seed(123) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 4f4a9c7db88a..df9fcdc37fa3 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -4,18 +4,25 @@ import pytest from vllm.attention.layer import Attention -from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig, - set_current_vllm_config) +from vllm.config import ( + CacheConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes -from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, - get_kv_cache_configs) -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, - SchedulerOutput) +from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.worker.tpu_model_runner import ( - TPUModelRunner, _get_padded_num_reqs_with_upper_limit, - _get_padded_token_len, _get_req_paddings, _get_token_paddings) + TPUModelRunner, + _get_padded_num_reqs_with_upper_limit, + _get_padded_token_len, + _get_req_paddings, + _get_token_paddings, +) def get_vllm_config(): @@ -67,10 +74,11 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_features=[], sampling_params=SamplingParams(), pooling_params=PoolingParams(), - block_ids=([0], ), # block_ids should be tuple[list[int]] + block_ids=([0],), # block_ids should be tuple[list[int]] num_computed_tokens=0, lora_request=None, - )) + ) + ) num_scheduled_tokens[req_id] = 3 total_num_scheduled_tokens += num_scheduled_tokens[req_id] @@ -99,7 +107,7 @@ def _is_req_added(model_runner, req_id: str) -> bool: def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: """Check if the request state block IDs match the block table. - + This function handles both legacy BlockTable and new MultiGroupBlockTable structures for backward compatibility. """ @@ -206,7 +214,7 @@ def test_update_states_request_resumed(model_runner): req_ids=[req_id], resumed_from_preemption=[False], new_token_ids=[[]], - new_block_ids=[([], )], + new_block_ids=[([],)], num_computed_tokens=[0], ) @@ -303,27 +311,23 @@ def test_get_paddings(): # Bucketed padding min_token_size, max_token_size, padding_gap = 16, 512, 64 expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, - padding_gap) + actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) # Bucketed padding with max_token_size not a power of two. max_token_size = 317 expected_paddings = [16, 32, 64, 128, 192, 256, 320] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, - padding_gap) + actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) assert actual_paddings == expected_paddings # Exponential padding. 
max_token_size, padding_gap = 1024, 0 expected_paddings = [16, 32, 64, 128, 256, 512, 1024] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, - padding_gap) + actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) assert actual_paddings == expected_paddings # Exponential padding with max_token_size not a power of two. max_token_size = 317 expected_paddings = [16, 32, 64, 128, 256, 512] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, - padding_gap) + actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) assert actual_paddings == expected_paddings @@ -350,32 +354,31 @@ def test_get_req_paddings(): assert _get_req_paddings(8, 36) == [8, 16, 32, 36] -def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order( - model_runner): +def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(model_runner): layer_0 = "model.layers.0.self_attn.attn" layer_1 = "model.layers.1.self_attn.attn" error_msg = f"{layer_1} must come before the current layer" vllm_config = model_runner.vllm_config - with pytest.raises(ValueError, match=error_msg), \ - set_current_vllm_config(vllm_config): + with ( + pytest.raises(ValueError, match=error_msg), + set_current_vllm_config(vllm_config), + ): fwd_context = { # initialization below will fail because target layer is invalid; # the target layer needs to come before layer 1 - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, kv_sharing_target_layer_name=layer_1, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -387,25 +390,25 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner): invalid_layer = "model.layers.0.cross_attn.attn" error_msg = f"{invalid_layer} is not a valid Attention layer in the model" vllm_config = model_runner.vllm_config - with pytest.raises(ValueError, match=error_msg), \ - set_current_vllm_config(vllm_config): + with ( + pytest.raises(ValueError, match=error_msg), + set_current_vllm_config(vllm_config), + ): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, # invalid layer: cross_attn.atn doesn't exist! 
kv_sharing_target_layer_name=invalid_layer, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -416,26 +419,26 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner): layer_1 = "model.layers.1.self_attn.attn" error_msg = f"{layer_1} cannot be the same as the current layer" vllm_config = model_runner.vllm_config - with pytest.raises(ValueError, match=error_msg), \ - set_current_vllm_config(vllm_config): + with ( + pytest.raises(ValueError, match=error_msg), + set_current_vllm_config(vllm_config), + ): fwd_context = { # initialization below will fail because target layer is invalid; # the target layer needs to come before layer 1 - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, kv_sharing_target_layer_name=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -447,20 +450,18 @@ def test_init_kv_cache_without_kv_sharing(): vllm_config = get_vllm_config() with set_current_vllm_config(vllm_config): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -475,17 +476,17 @@ def test_init_kv_cache_without_kv_sharing(): available_memory = 20 * GiB_bytes # page size for each layer KV can be calculated as # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim) - # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB + # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers) - kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], - [available_memory])[0] + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 2 assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2 - max_context_len =\ - estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) # max context len with KV sharing should be 2x as large as without # max_context_len = available_memory / (page_size / block_size) / num_caches # max_context_len = 5GB / (512KB / 128) / 2 = 655360 @@ -495,8 +496,9 @@ def test_init_kv_cache_without_kv_sharing(): # this will only allocate 2 block worth of memory (2 * 512kb) kv_cache_config.num_blocks = 1 for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - kv_cache_tensor.size = ( - kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes) + kv_cache_tensor.size = kv_cache_spec[ + kv_cache_tensor.shared_by[0] + ].page_size_bytes model_runner.initialize_kv_cache(kv_cache_config) @@ -518,21 +520,19 @@ def test_init_kv_cache_with_kv_sharing_valid(): vllm_config = get_vllm_config() with set_current_vllm_config(vllm_config): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=128, scale=1.0, prefix=layer_1, kv_sharing_target_layer_name="model.layers.0.self_attn.attn", - ) + ), } # suppress var not used error 
assert fwd_context is not None @@ -550,24 +550,23 @@ def test_init_kv_cache_with_kv_sharing_valid(): # with KV sharing, we can allocate (available_mem//page_size//1) blocks # which is twice as many as without KV sharing num_expected_blocks = 2 * 20480 # 20GB / 512KB - kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], - [available_memory])[0] + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache # compared to no KV sharing assert kv_cache_config.kv_cache_tensors[0].size == available_memory - max_context_len =\ - estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) # max context len with KV sharing should be 2x as large as without assert max_context_len == (2 * 655360) # important: override tensor size to prevent large mem alloc during test # this will only allocate 1 block worth of memory (512kb) kv_cache_config.num_blocks = 1 - kv_cache_config.kv_cache_tensors[0].size =\ - kv_cache_spec[layer_0].page_size_bytes + kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes model_runner.initialize_kv_cache(kv_cache_config) diff --git a/tests/v1/tracing/test_tracing.py b/tests/v1/tracing/test_tracing.py index da8655f95e19..e7767aceec55 100644 --- a/tests/v1/tracing/test_tracing.py +++ b/tests/v1/tracing/test_tracing.py @@ -12,20 +12,23 @@ import grpc import pytest from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( - ExportTraceServiceResponse) + ExportTraceServiceResponse, +) from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import ( - TraceServiceServicer, add_TraceServiceServicer_to_server) + TraceServiceServicer, + add_TraceServiceServicer_to_server, +) from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue -from opentelemetry.sdk.environment_variables import ( - OTEL_EXPORTER_OTLP_TRACES_INSECURE) +from opentelemetry.sdk.environment_variables import OTEL_EXPORTER_OTLP_TRACES_INSECURE from vllm import LLM, SamplingParams from vllm.tracing import SpanAttributes FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" -FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value', - 'array_value'] +FieldName = Literal[ + "bool_value", "string_value", "int_value", "double_value", "array_value" +] def decode_value(value: AnyValue): @@ -34,8 +37,9 @@ def decode_value(value: AnyValue): "string_value": (lambda v: v.string_value), "int_value": (lambda v: v.int_value), "double_value": (lambda v: v.double_value), - "array_value": - (lambda v: [decode_value(item) for item in v.array_value.values]), + "array_value": ( + lambda v: [decode_value(item) for item in v.array_value.values] + ), } for field, decoder in field_decoders.items(): if value.HasField(field): @@ -48,7 +52,6 @@ def decode_attributes(attributes: Iterable[KeyValue]): class FakeTraceService(TraceServiceServicer): - def __init__(self): self.request = None self.evt = threading.Event() @@ -86,10 +89,12 @@ def test_traces( max_tokens=256, ) model = "facebook/opt-125m" - llm = LLM(model=model, - otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, - gpu_memory_utilization=0.3, - disable_log_stats=False) + llm = LLM( + model=model, + otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, + gpu_memory_utilization=0.3, + disable_log_stats=False, + ) 
prompts = ["This is a short prompt"] outputs = llm.generate(prompts, sampling_params=sampling_params) print(f"test_traces outputs is : {outputs}") @@ -98,40 +103,48 @@ def test_traces( if not trace_service.evt.wait(timeout): raise TimeoutError( f"The fake trace service didn't receive a trace within " - f"the {timeout} seconds timeout") + f"the {timeout} seconds timeout" + ) request = trace_service.request assert len(request.resource_spans) == 1, ( - f"Expected 1 resource span, " - f"but got {len(request.resource_spans)}") + f"Expected 1 resource span, but got {len(request.resource_spans)}" + ) assert len(request.resource_spans[0].scope_spans) == 1, ( f"Expected 1 scope span, " - f"but got {len(request.resource_spans[0].scope_spans)}") + f"but got {len(request.resource_spans[0].scope_spans)}" + ) assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( f"Expected 1 span, " - f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") + f"but got {len(request.resource_spans[0].scope_spans[0].spans)}" + ) attributes = decode_attributes( - request.resource_spans[0].scope_spans[0].spans[0].attributes) + request.resource_spans[0].scope_spans[0].spans[0].attributes + ) # assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE - ) == sampling_params.temperature - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p - assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS - ) == sampling_params.max_tokens - assert attributes.get( - SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( - outputs[0].prompt_token_ids) + assert attributes.get(SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id + assert ( + attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE) + == sampling_params.temperature + ) + assert ( + attributes.get(SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p + ) + assert ( + attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) + == sampling_params.max_tokens + ) + assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n + assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( + outputs[0].prompt_token_ids + ) completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) - assert attributes.get( - SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens + assert ( + attributes.get(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) + == completion_tokens + ) assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0 - assert attributes.get( - SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0 + assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0 assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0 diff --git a/tests/v1/utils.py b/tests/v1/utils.py index b3f560c11e8f..993ad8a947d0 100644 --- a/tests/v1/utils.py +++ b/tests/v1/utils.py @@ -9,10 +9,9 @@ # Prometheus metrics utilities for testing -def get_prometheus_metrics( - server: RemoteOpenAIServer) -> dict[str, dict[str, float]]: +def get_prometheus_metrics(server: RemoteOpenAIServer) -> dict[str, dict[str, float]]: """Fetch and parse Prometheus metrics from the /metrics endpoint. - + Returns: Dict mapping metric names to their values grouped by labels. 
For example: {"vllm:request_success": { @@ -27,14 +26,14 @@ def get_prometheus_metrics( # Regex patterns for Prometheus metrics metric_with_labels = re.compile( - r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$') - metric_simple = re.compile( - r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$') + r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$" + ) + metric_simple = re.compile(r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$") - for line in response.text.split('\n'): + for line in response.text.split("\n"): line = line.strip() # Skip comments and empty lines - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue # Try to match metric with labels first @@ -45,7 +44,7 @@ def get_prometheus_metrics( value = float(value_str) if metric_name not in metrics: metrics[metric_name] = {} - metrics[metric_name][f'{{{labels_part}}}'] = value + metrics[metric_name][f"{{{labels_part}}}"] = value except ValueError: continue else: @@ -57,7 +56,7 @@ def get_prometheus_metrics( value = float(value_str) if metric_name not in metrics: metrics[metric_name] = {} - metrics[metric_name][''] = value + metrics[metric_name][""] = value except ValueError: continue @@ -67,10 +66,9 @@ def get_prometheus_metrics( return {} -def get_engine_request_counts( - metrics: dict[str, dict[str, float]]) -> dict[str, float]: +def get_engine_request_counts(metrics: dict[str, dict[str, float]]) -> dict[str, float]: """Extract request counts per engine from Prometheus metrics. - + Returns: Dict mapping engine indices to request counts. For example: {"0": 15.0, "1": 12.0} @@ -95,7 +93,7 @@ def get_engine_request_counts( def check_request_balancing(server: RemoteOpenAIServer, dp_size: int): """Check request balancing via Prometheus metrics if dp_size > 1. - + Args: server: The RemoteOpenAIServer instance dp_size: Number of data parallel ranks @@ -114,7 +112,8 @@ def check_request_balancing(server: RemoteOpenAIServer, dp_size: int): assert len(engines_with_requests) == dp_size, ( f"Expected requests to be distributed across multiple engines," f" but only engine(s) {engines_with_requests} received " - f"requests. Engine counts: {engine_counts}") + f"requests. 
Engine counts: {engine_counts}" + ) # Verify that the load is reasonably balanced # (no engine should handle all requests) @@ -122,4 +121,5 @@ def check_request_balancing(server: RemoteOpenAIServer, dp_size: int): for count in engine_counts.values(): assert count > total_requests // (dp_size + 1), ( - f"requests are imbalanced: {engine_counts}") + f"requests are imbalanced: {engine_counts}" + ) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 98700ff73fd1..c834577f1adb 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -29,14 +29,11 @@ MAX_NUM_PROMPT_TOKENS = 64 -def _compare_objs(obj1, - obj2, - skip: Sequence = ("logitsprocs", "batch_update_builder")): +def _compare_objs(obj1, obj2, skip: Sequence = ("logitsprocs", "batch_update_builder")): attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) - attr_names = set([ - a[0] for a in attrs - if not (a[0].startswith('__') and a[0].endswith('__')) - ]) + attr_names = set( + [a[0] for a in attrs if not (a[0].startswith("__") and a[0].endswith("__"))] + ) for attr_name in attr_names: if attr_name in skip: continue @@ -47,7 +44,7 @@ def _compare_objs(obj1, is_same = False if isinstance(a, torch.Tensor): if a.numel() == 0 or b.numel() == 0: - is_same = (a.numel() == 0 and b.numel() == 0) + is_same = a.numel() == 0 and b.numel() == 0 elif torch.allclose(a, b): is_same = True elif isinstance(a, np.ndarray): @@ -64,12 +61,14 @@ def _compare_objs(obj1, is_same = True elif isinstance(a, CpuGpuBuffer): is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu) - assert is_same, f"Attribute {attr_name} is different"\ - f" in {obj1} and {obj2}: {a} != {b}" + assert is_same, ( + f"Attribute {attr_name} is different in {obj1} and {obj2}: {a} != {b}" + ) -def _remove_requests(input_batch: InputBatch, batch_size: int, - reqs: list[CachedRequestState]) -> set[str]: +def _remove_requests( + input_batch: InputBatch, batch_size: int, reqs: list[CachedRequestState] +) -> set[str]: """ Remove some requests randomly from the batch and returns set of request removed @@ -109,10 +108,9 @@ def _construct_expected_sampling_metadata( temperature = [0.0 for _ in range(num_reqs)] min_tokens = {} logit_bias = [None] * num_reqs - allowed_token_ids_mask = torch.zeros(num_reqs, - VOCAB_SIZE, - dtype=torch.bool, - device=device) + allowed_token_ids_mask = torch.zeros( + num_reqs, VOCAB_SIZE, dtype=torch.bool, device=device + ) bad_words_token_ids = {} for req in reqs: if req.req_id not in req_ids_retained: @@ -120,35 +118,40 @@ def _construct_expected_sampling_metadata( index_in_input_batch = req_id_index_in_input_batch[req.req_id] output_token_ids[index_in_input_batch] = req.output_token_ids prompt_token_ids[index_in_input_batch] = req.prompt_token_ids - presence_penalties[ - index_in_input_batch] = req.sampling_params.presence_penalty + presence_penalties[index_in_input_batch] = req.sampling_params.presence_penalty frequency_penalties[index_in_input_batch] = ( - req.sampling_params.frequency_penalty) + req.sampling_params.frequency_penalty + ) repetition_penalties[index_in_input_batch] = ( - req.sampling_params.repetition_penalty) + req.sampling_params.repetition_penalty + ) top_k[index_in_input_batch] = req.sampling_params.top_k top_p[index_in_input_batch] = req.sampling_params.top_p temperature[index_in_input_batch] = req.sampling_params.temperature min_tokens[index_in_input_batch] = ( req.sampling_params.min_tokens, - 
req.sampling_params.all_stop_token_ids) + req.sampling_params.all_stop_token_ids, + ) logit_bias[index_in_input_batch] = req.sampling_params.logit_bias if req.sampling_params.allowed_token_ids: allowed_token_ids_mask[index_in_input_batch][ - req.sampling_params.allowed_token_ids] = True + req.sampling_params.allowed_token_ids + ] = True if req.sampling_params.bad_words_token_ids: - bad_words_token_ids[ - index_in_input_batch] = req.sampling_params.bad_words_token_ids + bad_words_token_ids[index_in_input_batch] = ( + req.sampling_params.bad_words_token_ids + ) return SamplingMetadata( - temperature=torch.tensor(temperature, dtype=torch.float, - device=device), + temperature=torch.tensor(temperature, dtype=torch.float, device=device), all_greedy=False, all_random=True, - top_p=None if all(x == 1.0 for x in top_p) else torch.tensor( - top_p, dtype=torch.float, device=device), - top_k=None if all(x == 0 for x in top_k) else torch.tensor( - top_k, dtype=torch.int, device=device), + top_p=None + if all(x == 1.0 for x in top_p) + else torch.tensor(top_p, dtype=torch.float, device=device), + top_k=None + if all(x == 0 for x in top_k) + else torch.tensor(top_k, dtype=torch.int, device=device), generators={}, max_num_logprobs=0, prompt_token_ids=make_tensor_with_pad( @@ -157,19 +160,21 @@ def _construct_expected_sampling_metadata( device=torch.device(device), dtype=torch.int64, ), - frequency_penalties=torch.tensor(frequency_penalties, - dtype=torch.float, - device=device), - presence_penalties=torch.tensor(presence_penalties, - dtype=torch.float, - device=device), - repetition_penalties=torch.tensor(repetition_penalties, - dtype=torch.float, - device=device), + frequency_penalties=torch.tensor( + frequency_penalties, dtype=torch.float, device=device + ), + presence_penalties=torch.tensor( + presence_penalties, dtype=torch.float, device=device + ), + repetition_penalties=torch.tensor( + repetition_penalties, dtype=torch.float, device=device + ), output_token_ids=output_token_ids, - no_penalties=(all(x == 0 for x in presence_penalties) - and all(x == 0 for x in frequency_penalties) - and all(x == 1 for x in repetition_penalties)), + no_penalties=( + all(x == 0 for x in presence_penalties) + and all(x == 0 for x in frequency_penalties) + and all(x == 1 for x in repetition_penalties) + ), allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=bad_words_token_ids, logitsprocs=LogitsProcessors(), @@ -185,8 +190,7 @@ def _create_sampling_params(): frequency_penalty=np.random.uniform(-2.0, 2.0), min_tokens=np.random.randint(1, 10), stop_token_ids=[ - np.random.randint(0, VOCAB_SIZE) - for _ in range(np.random.randint(10)) + np.random.randint(0, VOCAB_SIZE) for _ in range(np.random.randint(10)) ], logit_bias={0: np.random.uniform(-3.0, 3.0)}, ) @@ -207,7 +211,7 @@ def _construct_cached_request_state(req_id_suffix: int): sampling_params=_create_sampling_params(), pooling_params=None, mm_features=[], - block_ids=([], ), + block_ids=([],), generator=None, num_computed_tokens=len(output_token_ids), output_token_ids=output_token_ids, @@ -262,19 +266,18 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): # Create expected output. 
expected_sampling_metadata = _construct_expected_sampling_metadata( - reqs, - req_ids_retained, - input_batch.req_id_to_index, - device=torch.device(device)) + reqs, req_ids_retained, input_batch.req_id_to_index, device=torch.device(device) + ) def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: - return (t1 is None - and t2 is None) or (t1 is not None and t2 is not None - and torch.allclose(t1, t2)) + return (t1 is None and t2 is None) or ( + t1 is not None and t2 is not None and torch.allclose(t1, t2) + ) # Assert the actual and expected output. - assert torch.allclose(expected_sampling_metadata.temperature, - sampling_metadata.temperature) + assert torch.allclose( + expected_sampling_metadata.temperature, sampling_metadata.temperature + ) assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p) assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k) assert torch.allclose( @@ -289,25 +292,29 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: expected_sampling_metadata.repetition_penalties, sampling_metadata.repetition_penalties, ) - assert torch.allclose(expected_sampling_metadata.prompt_token_ids, - sampling_metadata.prompt_token_ids) - assert (expected_sampling_metadata.output_token_ids == - sampling_metadata.output_token_ids) - assert expected_sampling_metadata.no_penalties == \ - sampling_metadata.no_penalties + assert torch.allclose( + expected_sampling_metadata.prompt_token_ids, sampling_metadata.prompt_token_ids + ) + assert ( + expected_sampling_metadata.output_token_ids + == sampling_metadata.output_token_ids + ) + assert expected_sampling_metadata.no_penalties == sampling_metadata.no_penalties if sampling_metadata.allowed_token_ids_mask: assert torch.allclose( expected_sampling_metadata.allowed_token_ids_mask, - sampling_metadata.allowed_token_ids_mask) - assert expected_sampling_metadata.bad_words_token_ids == \ - sampling_metadata.bad_words_token_ids + sampling_metadata.allowed_token_ids_mask, + ) + assert ( + expected_sampling_metadata.bad_words_token_ids + == sampling_metadata.bad_words_token_ids + ) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("swap_list", [((0, 1), )]) -def test_swap_states_in_input_batch(device: str, batch_size: int, - swap_list: list): +@pytest.mark.parametrize("swap_list", [((0, 1),)]) +def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: list): """ Tests the logic for managing sampling metadata in the InputBatch. 
@@ -352,8 +359,10 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, reordered_reqs = reqs.copy() for swap_pair in swap_list: - reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \ - reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]] + reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = ( + reordered_reqs[swap_pair[1]], + reordered_reqs[swap_pair[0]], + ) input_batch.swap_states(swap_pair[0], swap_pair[1]) for req_index in range(batch_size): diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 23d7ce4cefa3..ef2956bd3ec2 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -6,20 +6,30 @@ import torch from vllm.attention import Attention -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VllmConfig, set_current_vllm_config) -from vllm.distributed.parallel_state import (init_distributed_environment, - initialize_model_parallel) +from vllm.config import ( + CacheConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, + set_current_vllm_config, +) +from vllm.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes, update_environment_variables -from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, - get_kv_cache_configs) -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, - SchedulerOutput) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor) +from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheTensor, +) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -35,8 +45,7 @@ def initialize_kv_cache(runner: GPUModelRunner): """ attn_spec = FullAttentionSpec( block_size=BLOCK_SIZE, - num_kv_heads=runner.model_config.get_num_kv_heads( - runner.parallel_config), + num_kv_heads=runner.model_config.get_num_kv_heads(runner.parallel_config), head_size=runner.model_config.get_head_size(), dtype=runner.kv_cache_dtype, ) @@ -58,9 +67,7 @@ def initialize_kv_cache(runner: GPUModelRunner): device=runner.device, pin_memory=runner.pin_memory, vocab_size=runner.model_config.get_vocab_size(), - block_sizes=[ - kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size - ], + block_sizes=[kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size], ) runner.initialize_attn_backend(kv_cache_config) @@ -98,8 +105,9 @@ def model_runner(): model_config = vllm_config.model_config num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config) head_size = model_config.get_head_size() - vllm_config.compilation_config.static_forward_context[ - "layer.0"] = Attention(num_heads, head_size, 0.1) + vllm_config.compilation_config.static_forward_context["layer.0"] = Attention( + num_heads, head_size, 0.1 + ) runner = GPUModelRunner(vllm_config, DEVICE) initialize_kv_cache(runner) return runner @@ -120,10 +128,11 @@ def _schedule_new_request(*req_ids: str) -> 
SchedulerOutput: mm_features=[], sampling_params=SamplingParams(), pooling_params=None, - block_ids=([0], ), + block_ids=([0],), num_computed_tokens=0, lora_request=None, - )) + ) + ) num_scheduled_tokens[req_id] = 3 total_num_scheduled_tokens += num_scheduled_tokens[req_id] @@ -150,22 +159,22 @@ def _is_req_added(model_runner, req_id: str) -> bool: return req_id in model_runner.requests -def _is_sampling_metadata_changed(model_runner, - sampling_metadata_before: SamplingMetadata): - return model_runner.input_batch.sampling_metadata is not ( - sampling_metadata_before) +def _is_sampling_metadata_changed( + model_runner, sampling_metadata_before: SamplingMetadata +): + return model_runner.input_batch.sampling_metadata is not (sampling_metadata_before) def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: req_index = model_runner.input_batch.req_id_to_index[req_id] block_table = model_runner.input_batch.block_table[0] req_state = model_runner.requests[req_id] - if block_table.num_blocks_per_row[req_index] != len( - req_state.block_ids[0]): + if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids[0]): return False num_blocks = block_table.num_blocks_per_row[req_index] - return (block_table.block_table.np[req_index, :num_blocks] == - req_state.block_ids[0]).all() + return ( + block_table.block_table.np[req_index, :num_blocks] == req_state.block_ids[0] + ).all() def test_update_states_new_request(model_runner, dist_init): @@ -248,7 +257,7 @@ def test_update_states_request_resumed(model_runner, dist_init): req_ids=[req_id], resumed_from_preemption=[False], new_token_ids=[[]], - new_block_ids=([[0]], ), + new_block_ids=([[0]],), num_computed_tokens=[0], num_output_tokens=[0], ) @@ -281,46 +290,58 @@ def test_get_nans_in_logits(model_runner, dist_init): scheduler_output = _schedule_new_request(*req_ids) model_runner._update_states(scheduler_output) - logits = torch.tensor([ - [1.0, 2.0, 3.0], - [3.0, 2.0, 1.0], - ], device=DEVICE) + logits = torch.tensor( + [ + [1.0, 2.0, 3.0], + [3.0, 2.0, 1.0], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) assert result == {"req_0": 0, "req_1": 0} - logits = torch.tensor([ - [1.0, float('nan'), 3.0], - [4.0, float('nan'), float('nan')], - ], - device=DEVICE) + logits = torch.tensor( + [ + [1.0, float("nan"), 3.0], + [4.0, float("nan"), float("nan")], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) assert result == {"req_0": 1, "req_1": 2} - logits = torch.tensor([ - [1.0, 2.0, 3.0], - [4.0, float('nan'), float('nan')], - ], - device=DEVICE) + logits = torch.tensor( + [ + [1.0, 2.0, 3.0], + [4.0, float("nan"), float("nan")], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) assert result == {"req_0": 0, "req_1": 2} result = model_runner._get_nans_in_logits(logits=None) assert result == {"req_0": 0, "req_1": 0} - logits = torch.tensor([ - [1.0, float('nan'), 3.0], - ], device=DEVICE) + logits = torch.tensor( + [ + [1.0, float("nan"), 3.0], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) - assert result == {'req_0': 1, 'req_1': 0} - - logits = torch.tensor([ - [float('nan'), float('nan'), 2.0], - [1.0, 2.0, 3.0], - [float('nan'), 2.0, 3.0], - ], - device=DEVICE) + assert result == {"req_0": 1, "req_1": 0} + + logits = torch.tensor( + [ + [float("nan"), float("nan"), 2.0], + [1.0, 2.0, 3.0], + [float("nan"), 2.0, 3.0], + ], + device=DEVICE, + ) result = model_runner._get_nans_in_logits(logits) - assert result == {'req_0': 
2, 'req_1': 0} + assert result == {"req_0": 2, "req_1": 0} def test_update_states_no_changes(model_runner, dist_init): @@ -398,11 +419,13 @@ def test_update_states_request_unscheduled(model_runner, dist_init): def test_kv_cache_stride_order(monkeypatch, model_runner): # This test checks if GPUModelRunner initializes correctly when an attention # backend enforces a non-default KV cache stride order. - n_heads = model_runner.model_config.get_num_kv_heads( - model_runner.parallel_config) + n_heads = model_runner.model_config.get_num_kv_heads(model_runner.parallel_config) expected_kv_cache_shape = [ - 2, NUM_BLOCKS, BLOCK_SIZE, n_heads, - model_runner.model_config.get_head_size() + 2, + NUM_BLOCKS, + BLOCK_SIZE, + n_heads, + model_runner.model_config.get_head_size(), ] # TODO mla test default_stride = tuple(range(5)) @@ -415,8 +438,9 @@ def rnd_stride_order(test_stride=test_stride): # Patch the attention backend class and re-trigger the KV cache creation for attn_group in model_runner._attn_group_iterator(): attn_backend = attn_group.backend - monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order", - rnd_stride_order) + monkeypatch.setattr( + attn_backend, "get_kv_cache_stride_order", rnd_stride_order + ) model_runner.attn_groups = [] model_runner.kv_caches = [] @@ -448,14 +472,13 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): model_runner_2.update_config({"load_config": {"load_format": "dummy"}}) model_runner_2.load_model() # Initial model loading with dummy weights assert str(model_runner.get_model().state_dict()) != str( - model_runner_2.get_model().state_dict()) - model_runner_2.update_config( - {"load_config": { - "load_format": original_load_format - }}) + model_runner_2.get_model().state_dict() + ) + model_runner_2.update_config({"load_config": {"load_format": original_load_format}}) model_runner_2.reload_weights() # Load real weights inplace assert str(model_runner.get_model().state_dict()) == str( - model_runner_2.get_model().state_dict()) + model_runner_2.get_model().state_dict() + ) def test_reload_weights_before_load_model(model_runner): @@ -472,21 +495,19 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): fwd_context = { # initialization below will fail because target layer is invalid; # the target layer needs to come before layer 1 - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, kv_sharing_target_layer_name=layer_1, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -500,22 +521,20 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): error_msg = f"{invalid_layer} is not a valid Attention layer in the model" with pytest.raises(ValueError, match=error_msg): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, # invalid layer: cross_attn.atn doesn't exist! 
kv_sharing_target_layer_name=invalid_layer, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -530,21 +549,19 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current(): fwd_context = { # initialization below will fail because target layer is invalid; # the target layer needs to come before layer 1 - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, kv_sharing_target_layer_name=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -557,20 +574,18 @@ def test_init_kv_cache_without_kv_sharing(): vllm_config = get_vllm_config() with set_current_vllm_config(vllm_config): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, - ) + ), } # suppress var not used error assert fwd_context is not None @@ -585,15 +600,15 @@ def test_init_kv_cache_without_kv_sharing(): available_memory = 20 * GiB_bytes # page size for layer 0's kv_cache_spec is 32KB num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers) - kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], - [available_memory])[0] + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 2 assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2 - max_context_len =\ - estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) # max context len with KV sharing should be 2x as large as without assert max_context_len == 1310720 @@ -601,8 +616,9 @@ def test_init_kv_cache_without_kv_sharing(): # this will only allocate 2 block worth of memory (2 * 32kb) kv_cache_config.num_blocks = 1 for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - kv_cache_tensor.size = ( - kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes) + kv_cache_tensor.size = kv_cache_spec[ + kv_cache_tensor.shared_by[0] + ].page_size_bytes runner.initialize_kv_cache(kv_cache_config) @@ -625,21 +641,19 @@ def test_init_kv_cache_with_kv_sharing_valid(): vllm_config = get_vllm_config() with set_current_vllm_config(vllm_config): fwd_context = { - layer_0: - Attention( + layer_0: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_0, ), - layer_1: - Attention( + layer_1: Attention( num_heads=8, head_size=64, scale=1.0, prefix=layer_1, kv_sharing_target_layer_name="model.layers.0.self_attn.attn", - ) + ), } # suppress var not used error assert fwd_context is not None @@ -657,24 +671,23 @@ def test_init_kv_cache_with_kv_sharing_valid(): # with KV sharing, we can allocate (available_mem//page_size//1) blocks # which is twice as many as without KV sharing num_expected_blocks = 655360 # 20GB / 32KB - kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], - [available_memory])[0] + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache # 
compared to no KV sharing assert kv_cache_config.kv_cache_tensors[0].size == available_memory - max_context_len =\ - estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) # max context len with KV sharing should be 2x as large as without assert max_context_len == 2 * 1310720 # important: override tensor size to prevent large mem alloc during test # this will only allocate 1 block worth of memory (32kb) kv_cache_config.num_blocks = 1 - kv_cache_config.kv_cache_tensors[0].size =\ - kv_cache_spec[layer_0].page_size_bytes + kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes runner.initialize_kv_cache(kv_cache_config) kv_cache_config_after_init = runner.kv_cache_config @@ -687,30 +700,30 @@ def test_init_kv_cache_with_kv_sharing_valid(): # check layer 1 added to kv cache group's layer names assert len(kv_cache_config_after_init.kv_cache_groups) == 1 assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2 - assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ - 0] == layer_0 - assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[ - 1] == layer_1 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[0] == layer_0 + assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1 def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): - ''' + """ The GPU model runner creates different views into the KVCacheTensors for the attention and mamba layers (via _reshape_kv_cache_tensors function). This test verifies that the views are compatible: writing a mamba block will not corrupt an attention block and vice versa - ''' + """ current_platform.seed_everything(42) - update_environment_variables({ - 'RANK': "0", - 'LOCAL_RANK': "0", - 'WORLD_SIZE': "1", - 'MASTER_ADDR': 'localhost', - 'MASTER_PORT': '12345', - }) + update_environment_variables( + { + "RANK": "0", + "LOCAL_RANK": "0", + "WORLD_SIZE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) init_distributed_environment() initialize_model_parallel(tensor_model_parallel_size=1) torch.set_default_dtype(torch.float16) @@ -751,8 +764,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): fwd_context = {} for key in [layer_0, layer_1]: fwd_context[key] = Attention( - num_heads=model_config.get_num_attention_heads( - parallel_config), + num_heads=model_config.get_num_attention_heads(parallel_config), num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), scale=1.0, @@ -760,13 +772,12 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): ) for key in [layer_2, layer_3, layer_4, layer_5]: fwd_context[key] = MambaMixer2( - hidden_size = hf_config.hidden_size, - ssm_state_size = hf_config.mamba_d_state, - conv_kernel_size = hf_config.mamba_d_conv, - intermediate_size = hf_config.mamba_expand *\ - hf_config.hidden_size, - use_conv_bias = hf_config.mamba_conv_bias, - use_bias = hf_config.mamba_proj_bias, + hidden_size=hf_config.hidden_size, + ssm_state_size=hf_config.mamba_d_state, + conv_kernel_size=hf_config.mamba_d_conv, + intermediate_size=hf_config.mamba_expand * hf_config.hidden_size, + use_conv_bias=hf_config.mamba_conv_bias, + use_bias=hf_config.mamba_proj_bias, n_groups=hf_config.mamba_n_groups, num_heads=hf_config.mamba_n_heads, head_dim=hf_config.mamba_d_head, @@ -781,15 +792,15 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): vllm_ctx = 
vllm_config.compilation_config.static_forward_context with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") runner = GPUModelRunner(vllm_config, DEVICE) kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes - kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], - [available_memory])[0] + kv_cache_config = get_kv_cache_configs( + vllm_config, [kv_cache_spec], [available_memory] + )[0] runner.initialize_kv_cache(kv_cache_config) # random partition of blocks @@ -798,7 +809,7 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): num_blocks = kv_cache_config.num_blocks ind = np.arange(num_blocks) np.random.shuffle(ind) - blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):] + blocks0, blocks1 = ind[: (num_blocks // 2)], ind[(num_blocks // 2) :] attn_shape = vllm_ctx[layer_0].kv_cache[0].shape conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape @@ -807,34 +818,40 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): # assert we are using FlashInfer assert attn_shape[0] == num_blocks - attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]), - device=DEVICE, - fill_value=3.33) - conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]), - device=DEVICE, - fill_value=6.66) - ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]), - device=DEVICE, - fill_value=9.99) + attn_blocks_constant = torch.full( + (len(blocks0), *attn_shape[1:]), device=DEVICE, fill_value=3.33 + ) + conv_blocks_constant = torch.full( + (len(blocks1), *conv_shape[1:]), device=DEVICE, fill_value=6.66 + ) + ssm_blocks_constant = torch.full( + (len(blocks1), *ssm_shape[1:]), device=DEVICE, fill_value=9.99 + ) # fill all attention blocks with constant for layer in [layer_0, layer_1]: - vllm_ctx[layer].kv_cache[0][ - blocks0, :] = attn_blocks_constant.detach().clone() + vllm_ctx[layer].kv_cache[0][blocks0, :] = ( + attn_blocks_constant.detach().clone() + ) # fill all mamba blocks with constant for layer in [layer_2, layer_3, layer_4, layer_5]: - vllm_ctx[layer].kv_cache[0][0][ - blocks1, :] = conv_blocks_constant.detach().clone() - vllm_ctx[layer].kv_cache[0][1][ - blocks1, :] = ssm_blocks_constant.detach().clone() + vllm_ctx[layer].kv_cache[0][0][blocks1, :] = ( + conv_blocks_constant.detach().clone() + ) + vllm_ctx[layer].kv_cache[0][1][blocks1, :] = ( + ssm_blocks_constant.detach().clone() + ) # verify attention and mamba contents are correct for layer in [layer_0, layer_1]: - assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :], - attn_blocks_constant) + assert torch.equal( + vllm_ctx[layer].kv_cache[0][blocks0, :], attn_blocks_constant + ) for layer in [layer_2, layer_3, layer_4, layer_5]: - assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :], - conv_blocks_constant) - assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :], - ssm_blocks_constant) + assert torch.equal( + vllm_ctx[layer].kv_cache[0][0][blocks1, :], conv_blocks_constant + ) + assert torch.equal( + vllm_ctx[layer].kv_cache[0][1][blocks1, :], ssm_blocks_constant + ) diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py index fd0e630ce178..f987b09e603e 100644 --- a/tests/v1/worker/test_utils.py +++ b/tests/v1/worker/test_utils.py @@ -10,32 +10,28 @@ def test_bind_kv_cache(): from vllm.attention import Attention ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), - 'layers.1.self_attn': Attention(32, 128, 0.1), - 'layers.2.self_attn': Attention(32, 128, 0.1), - 'layers.3.self_attn': 
Attention(32, 128, 0.1), + "layers.0.self_attn": Attention(32, 128, 0.1), + "layers.1.self_attn": Attention(32, 128, 0.1), + "layers.2.self_attn": Attention(32, 128, 0.1), + "layers.3.self_attn": Attention(32, 128, 0.1), } kv_cache = { - 'layers.0.self_attn': torch.zeros((1, )), - 'layers.1.self_attn': torch.zeros((1, )), - 'layers.2.self_attn': torch.zeros((1, )), - 'layers.3.self_attn': torch.zeros((1, )), + "layers.0.self_attn": torch.zeros((1,)), + "layers.1.self_attn": torch.zeros((1,)), + "layers.2.self_attn": torch.zeros((1,)), + "layers.3.self_attn": torch.zeros((1,)), } runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache(kv_cache, ctx, runner_kv_caches) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[ - 'layers.0.self_attn'] - assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[ - 'layers.1.self_attn'] - assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[ - 'layers.2.self_attn'] - assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[ - 'layers.3.self_attn'] + assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache["layers.0.self_attn"] + assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache["layers.1.self_attn"] + assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache["layers.2.self_attn"] + assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache["layers.3.self_attn"] - assert runner_kv_caches[0] is kv_cache['layers.0.self_attn'] - assert runner_kv_caches[1] is kv_cache['layers.1.self_attn'] - assert runner_kv_caches[2] is kv_cache['layers.2.self_attn'] - assert runner_kv_caches[3] is kv_cache['layers.3.self_attn'] + assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"] + assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"] + assert runner_kv_caches[2] is kv_cache["layers.2.self_attn"] + assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"] def test_bind_kv_cache_non_attention(): @@ -43,21 +39,19 @@ def test_bind_kv_cache_non_attention(): # example from Jamba PP=2 ctx = { - 'model.layers.20.attn': Attention(32, 128, 0.1), - 'model.layers.28.attn': Attention(32, 128, 0.1), + "model.layers.20.attn": Attention(32, 128, 0.1), + "model.layers.28.attn": Attention(32, 128, 0.1), } kv_cache = { - 'model.layers.20.attn': torch.zeros((1, )), - 'model.layers.28.attn': torch.zeros((1, )), + "model.layers.20.attn": torch.zeros((1,)), + "model.layers.28.attn": torch.zeros((1,)), } runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache(kv_cache, ctx, runner_kv_caches) - assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[ - 'model.layers.20.attn'] - assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[ - 'model.layers.28.attn'] + assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache["model.layers.20.attn"] + assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache["model.layers.28.attn"] - assert runner_kv_caches[0] is kv_cache['model.layers.20.attn'] - assert runner_kv_caches[1] is kv_cache['model.layers.28.attn'] + assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"] + assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"] diff --git a/tests/v1/worker/test_worker_memory_snapshot.py b/tests/v1/worker/test_worker_memory_snapshot.py index 6faa6bcc591c..cbfb9a8dc0b6 100644 --- a/tests/v1/worker/test_worker_memory_snapshot.py +++ b/tests/v1/worker/test_worker_memory_snapshot.py @@ -13,8 +13,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.utils import MemorySnapshot -from vllm.v1.worker.gpu_worker import (Worker, - init_worker_distributed_environment) +from vllm.v1.worker.gpu_worker import 
Worker, init_worker_distributed_environment # Global queue to track operation order across processes _QUEUE: Optional[Queue] = None @@ -28,11 +27,11 @@ def track_operation(operation: str, rank: int): def make_operation_tracker(operation_name: str, original_func): """Create a mock function that tracks when an operation is called. - + Args: operation_name: Name to use when tracking this operation original_func: The original function to wrap - + Returns: A wrapper function that tracks the operation and calls the original """ @@ -45,8 +44,13 @@ def wrapper(*args, **kwargs): return wrapper -def worker_process(rank: int, world_size: int, distributed_init_method: str, - queue: Queue, error_queue: Queue): +def worker_process( + rank: int, + world_size: int, + distributed_init_method: str, + queue: Queue, + error_queue: Queue, +): """Worker process that initializes a GPU worker with proper tracking.""" global _QUEUE _QUEUE = queue @@ -58,9 +62,9 @@ def worker_process(rank: int, world_size: int, distributed_init_method: str, os.environ["WORLD_SIZE"] = str(world_size) # Create vLLM config with small model - vllm_config = EngineArgs(model="facebook/opt-125m", - tensor_parallel_size=2, - load_format="dummy").create_engine_config() + vllm_config = EngineArgs( + model="facebook/opt-125m", tensor_parallel_size=2, load_format="dummy" + ).create_engine_config() # Create worker worker = Worker( @@ -77,19 +81,22 @@ def worker_process(rank: int, world_size: int, distributed_init_method: str, # Apply minimal patches to track operation order init_patch = patch( - 'vllm.v1.worker.gpu_worker.init_worker_distributed_environment', - side_effect=make_operation_tracker("init_distributed", - original_init_worker)) + "vllm.v1.worker.gpu_worker.init_worker_distributed_environment", + side_effect=make_operation_tracker( + "init_distributed", original_init_worker + ), + ) memory_patch = patch.object( - MemorySnapshot, '__init__', - make_operation_tracker("memory_snapshot", - original_memory_snapshot_init)) - all_reduce_patch = patch('torch.distributed.all_reduce', - side_effect=make_operation_tracker( - "nccl_all_reduce", original_all_reduce)) + MemorySnapshot, + "__init__", + make_operation_tracker("memory_snapshot", original_memory_snapshot_init), + ) + all_reduce_patch = patch( + "torch.distributed.all_reduce", + side_effect=make_operation_tracker("nccl_all_reduce", original_all_reduce), + ) with init_patch, memory_patch, all_reduce_patch: - # Initialize device (this is where we test the order) worker.init_device() @@ -104,13 +111,14 @@ def worker_process(rank: int, world_size: int, distributed_init_method: str, raise -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs for tensor parallelism") +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Need at least 2 GPUs for tensor parallelism" +) def test_init_distributed_is_called_before_memory_snapshot(): """Test that distributed env is setup before memory snapshot. - - This test makes sure during worker initialization, the initial memory - snapshot is taken after distributed env is setup to include all the buffers + + This test makes sure during worker initialization, the initial memory + snapshot is taken after distributed env is setup to include all the buffers allocated by distributed env. 
""" world_size = 2 @@ -127,9 +135,16 @@ def test_init_distributed_is_called_before_memory_snapshot(): # Start worker processes processes = [] for rank in range(world_size): - p = ctx.Process(target=worker_process, - args=(rank, world_size, distributed_init_method, - operation_queue, error_queue)) + p = ctx.Process( + target=worker_process, + args=( + rank, + world_size, + distributed_init_method, + operation_queue, + error_queue, + ), + ) p.start() processes.append(p) @@ -168,7 +183,8 @@ def test_init_distributed_is_called_before_memory_snapshot(): assert init_distributed < nccl_all_reduce < memory_snapshot, ( f"Rank {rank}: init_distributed (index {init_distributed}) " f"must happen before nccl_all_reduce (index {nccl_all_reduce}) " - f"and memory_snapshot (index {memory_snapshot})") + f"and memory_snapshot (index {memory_snapshot})" + ) # Clean up os.unlink(distributed_init_method.replace("file://", "")) diff --git a/tests/vllm_test_utils/setup.py b/tests/vllm_test_utils/setup.py index 83be8bdce85c..4cb66b556e5a 100644 --- a/tests/vllm_test_utils/setup.py +++ b/tests/vllm_test_utils/setup.py @@ -4,7 +4,7 @@ from setuptools import setup setup( - name='vllm_test_utils', - version='0.1', - packages=['vllm_test_utils'], + name="vllm_test_utils", + version="0.1", + packages=["vllm_test_utils"], ) diff --git a/tests/vllm_test_utils/vllm_test_utils/blame.py b/tests/vllm_test_utils/vllm_test_utils/blame.py index 49fd083ef19c..e2cab92ea22b 100644 --- a/tests/vllm_test_utils/vllm_test_utils/blame.py +++ b/tests/vllm_test_utils/vllm_test_utils/blame.py @@ -26,7 +26,7 @@ def blame(func: Callable) -> Generator[BlameResult, None, None]: ```python with blame(lambda: some_condition()) as result: # do something - + if result.found: print(result.trace_stack) """ @@ -34,7 +34,7 @@ def blame(func: Callable) -> Generator[BlameResult, None, None]: def _trace_calls(frame, event, arg=None): nonlocal result - if event in ['call', 'return']: + if event in ["call", "return"]: # for every function call or return try: # Temporarily disable the trace function diff --git a/tests/vllm_test_utils/vllm_test_utils/monitor.py b/tests/vllm_test_utils/vllm_test_utils/monitor.py index 9454221b273e..e2f1212ed554 100644 --- a/tests/vllm_test_utils/vllm_test_utils/monitor.py +++ b/tests/vllm_test_utils/vllm_test_utils/monitor.py @@ -19,8 +19,8 @@ class MonitoredValues(Generic[_T]): @contextlib.contextmanager def monitor( - measure_func: Callable[[], - _T]) -> Generator[MonitoredValues[_T], None, None]: + measure_func: Callable[[], _T], +) -> Generator[MonitoredValues[_T], None, None]: """ Trace the function calls to continuously monitor the change of a value. @@ -28,23 +28,23 @@ def monitor( Usage: ```python - def measure_func(): - ... # measure the current value + ... # measure the current value return current_value + with monitor(measure_func) as monitored_values: # do something - - monitored_values.values # all changes of the values - monitored_values.trace_stacks # trace stacks of every change + + monitored_values.values # all changes of the values + monitored_values.trace_stacks # trace stacks of every change ``` """ monitored_values = MonitoredValues[_T]() def _trace_calls(frame, event, arg=None): nonlocal monitored_values - if event in ['line']: + if event in ["line"]: # triggered by every line of Python code. # only Python functions will trigger it, # c/cpp functions will not trigger it. 
@@ -53,11 +53,14 @@ def _trace_calls(frame, event, arg=None): sys.settrace(None) # do a measurement current_value = measure_func() - if len(monitored_values.values - ) == 0 or current_value != monitored_values.values[-1]: + if ( + len(monitored_values.values) == 0 + or current_value != monitored_values.values[-1] + ): monitored_values.values.append(current_value) - monitored_values.trace_stacks.append("".join( - traceback.format_stack())) + monitored_values.trace_stacks.append( + "".join(traceback.format_stack()) + ) # Re-enable the trace function sys.settrace(_trace_calls) except NameError: diff --git a/tests/weight_loading/test_weight_loading.py b/tests/weight_loading/test_weight_loading.py index 3aabae099073..658773068208 100644 --- a/tests/weight_loading/test_weight_loading.py +++ b/tests/weight_loading/test_weight_loading.py @@ -9,35 +9,39 @@ from vllm.platforms import current_platform MAX_MODEL_LEN = 1024 -MODEL_NAME = os.environ.get("MODEL_NAME", - "robertgshaw2/zephyr-7b-beta-channelwise-gptq") +MODEL_NAME = os.environ.get( + "MODEL_NAME", "robertgshaw2/zephyr-7b-beta-channelwise-gptq" +) REVISION = os.environ.get("REVISION", "main") QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin") MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "80") @pytest.mark.skipif( - MODEL_NAME == "casperhansen/deepseek-coder-v2-instruct-awq", - reason="OOM in the CI") + MODEL_NAME == "casperhansen/deepseek-coder-v2-instruct-awq", reason="OOM in the CI" +) @pytest.mark.skipif( not current_platform.has_device_capability(int(MIN_CAPABILITY)), - reason="Current system does not have minimum capability.") + reason="Current system does not have minimum capability.", +) def test_weight_loading(vllm_runner): """ Test parameter weight loading with tp>1. """ # MoE models need fp16. - NEEDS_FP16 = (QUANTIZATION == "gptq" or MODEL_NAME - == "nm-testing/test-w4a16-mixtral-actorder-group") + NEEDS_FP16 = ( + QUANTIZATION == "gptq" + or MODEL_NAME == "nm-testing/test-w4a16-mixtral-actorder-group" + ) with vllm_runner( - model_name=MODEL_NAME, - revision=REVISION, - dtype=torch.half if NEEDS_FP16 else "auto", - quantization=None if QUANTIZATION == "None" else QUANTIZATION, - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=2) as model: - + model_name=MODEL_NAME, + revision=REVISION, + dtype=torch.half if NEEDS_FP16 else "auto", + quantization=None if QUANTIZATION == "None" else QUANTIZATION, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=2, + ) as model: output = model.generate_greedy("Hello world!", max_tokens=20) print(output) assert output diff --git a/tools/check_init_lazy_imports.py b/tools/check_init_lazy_imports.py index e8e6f07cc33f..9255aa17db6a 100644 --- a/tools/check_init_lazy_imports.py +++ b/tools/check_init_lazy_imports.py @@ -17,12 +17,16 @@ INIT_PATH: Final = REPO_ROOT / "vllm" / "__init__.py" # If you need to add items to whitelist, do it here. 
-ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset({ - "vllm.env_override", -}) -ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset({ - ".version", -}) +ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset( + { + "vllm.env_override", + } +) +ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset( + { + ".version", + } +) def _is_internal(name: str | None, *, level: int = 0) -> bool: @@ -34,8 +38,7 @@ def _is_internal(name: str | None, *, level: int = 0) -> bool: def _fail(violations: Iterable[tuple[int, str]]) -> None: - print("ERROR: Disallowed eager imports in vllm/__init__.py:\n", - file=sys.stderr) + print("ERROR: Disallowed eager imports in vllm/__init__.py:\n", file=sys.stderr) for lineno, msg in violations: print(f" Line {lineno}: {msg}", file=sys.stderr) sys.exit(1) @@ -48,7 +51,6 @@ def main() -> None: violations: list[tuple[int, str]] = [] class Visitor(ast.NodeVisitor): - def __init__(self) -> None: super().__init__() self._in_type_checking = False @@ -56,10 +58,10 @@ def __init__(self) -> None: def visit_If(self, node: ast.If) -> None: guard_is_type_checking = False test = node.test - if isinstance(test, ast.Attribute) and isinstance( - test.value, ast.Name): - guard_is_type_checking = (test.value.id == "typing" - and test.attr == "TYPE_CHECKING") + if isinstance(test, ast.Attribute) and isinstance(test.value, ast.Name): + guard_is_type_checking = ( + test.value.id == "typing" and test.attr == "TYPE_CHECKING" + ) elif isinstance(test, ast.Name): guard_is_type_checking = test.id == "TYPE_CHECKING" @@ -79,24 +81,28 @@ def visit_Import(self, node: ast.Import) -> None: return for alias in node.names: module_name = alias.name - if _is_internal( - module_name) and module_name not in ALLOWED_IMPORTS: - violations.append(( - node.lineno, - f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501 - )) + if _is_internal(module_name) and module_name not in ALLOWED_IMPORTS: + violations.append( + ( + node.lineno, + f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501 + ) + ) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if self._in_type_checking: return module_as_written = ("." * node.level) + (node.module or "") - if _is_internal( - node.module, level=node.level - ) and module_as_written not in ALLOWED_FROM_MODULES: - violations.append(( - node.lineno, - f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501 - )) + if ( + _is_internal(node.module, level=node.level) + and module_as_written not in ALLOWED_FROM_MODULES + ): + violations.append( + ( + node.lineno, + f"from '{module_as_written}' import ... 
must be inside typing.TYPE_CHECKING", # noqa: E501 + ) + ) Visitor().visit(tree) diff --git a/tools/check_spdx_header.py b/tools/check_spdx_header.py index ced10ba9097b..1fcca12519ff 100644 --- a/tools/check_spdx_header.py +++ b/tools/check_spdx_header.py @@ -7,6 +7,7 @@ class SPDXStatus(Enum): """SPDX header status enumeration""" + EMPTY = "empty" # empty __init__.py COMPLETE = "complete" MISSING_LICENSE = "missing_license" # Only has copyright line @@ -16,7 +17,8 @@ class SPDXStatus(Enum): FULL_SPDX_HEADER = ( "# SPDX-License-Identifier: Apache-2.0\n" - "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project") + "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" +) LICENSE_LINE = "# SPDX-License-Identifier: Apache-2.0" COPYRIGHT_LINE = "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project" # noqa: E501 @@ -123,8 +125,9 @@ def main(): continue # Collect all files that need fixing - all_files_to_fix = (files_missing_both + files_missing_copyright + - files_missing_license) + all_files_to_fix = ( + files_missing_both + files_missing_copyright + files_missing_license + ) if all_files_to_fix: print("The following files are missing the SPDX header:") if files_missing_both: diff --git a/tools/check_triton_import.py b/tools/check_triton_import.py index c01d9d4ab079..1b83074fe0d2 100644 --- a/tools/check_triton_import.py +++ b/tools/check_triton_import.py @@ -23,8 +23,7 @@ def is_allowed_file(current_file: str) -> bool: def is_forbidden_import(line: str) -> bool: stripped = line.strip() - return bool( - FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES + return bool(FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES def parse_diff(diff: str) -> list[str]: @@ -42,24 +41,24 @@ def parse_diff(diff: str) -> list[str]: elif line.startswith("@@"): match = re.search(r"\+(\d+)", line) if match: - current_lineno = int( - match.group(1)) - 1 # next "+ line" is here + current_lineno = int(match.group(1)) - 1 # next "+ line" is here elif line.startswith("+") and not line.startswith("++"): current_lineno += 1 code_line = line[1:] if is_forbidden_import(code_line): violations.append( - f"{current_file}:{current_lineno}: {code_line.strip()}") + f"{current_file}:{current_lineno}: {code_line.strip()}" + ) return violations def get_diff(diff_type: str) -> str: if diff_type == "staged": return subprocess.check_output( - ["git", "diff", "--cached", "--unified=0"], text=True) + ["git", "diff", "--cached", "--unified=0"], text=True + ) elif diff_type == "unstaged": - return subprocess.check_output(["git", "diff", "--unified=0"], - text=True) + return subprocess.check_output(["git", "diff", "--unified=0"], text=True) else: raise ValueError(f"Unknown diff_type: {diff_type}") @@ -75,8 +74,10 @@ def main(): print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr) if all_violations: - print("❌ Forbidden direct `import triton` detected." - " ➤ Use `from vllm.triton_utils import triton` instead.\n") + print( + "❌ Forbidden direct `import triton` detected." 
+ " ➤ Use `from vllm.triton_utils import triton` instead.\n" + ) for v in all_violations: print(f"❌ {v}") return 1 diff --git a/tools/enforce_regex_import.py b/tools/enforce_regex_import.py index 63ceee5829ab..69f43cadc767 100644 --- a/tools/enforce_regex_import.py +++ b/tools/enforce_regex_import.py @@ -7,24 +7,23 @@ import regex as re -FORBIDDEN_PATTERNS = re.compile( - r'^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)') +FORBIDDEN_PATTERNS = re.compile(r"^\s*(?:import\s+re(?:$|\s|,)|from\s+re\s+import)") ALLOWED_PATTERNS = [ - re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'), - re.compile(r'^\s*import\s+regex\s*$'), + re.compile(r"^\s*import\s+regex\s+as\s+re\s*$"), + re.compile(r"^\s*import\s+regex\s*$"), ] def get_staged_python_files() -> list[str]: try: result = subprocess.run( - ['git', 'diff', '--cached', '--name-only', '--diff-filter=AM'], + ["git", "diff", "--cached", "--name-only", "--diff-filter=AM"], capture_output=True, text=True, - check=True) - files = result.stdout.strip().split( - '\n') if result.stdout.strip() else [] - return [f for f in files if f.endswith('.py')] + check=True, + ) + files = result.stdout.strip().split("\n") if result.stdout.strip() else [] + return [f for f in files if f.endswith(".py")] except subprocess.CalledProcessError: return [] @@ -33,13 +32,14 @@ def is_forbidden_import(line: str) -> bool: line = line.strip() return bool( FORBIDDEN_PATTERNS.match(line) - and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS)) + and not any(pattern.match(line) for pattern in ALLOWED_PATTERNS) + ) def check_file(filepath: str) -> list[tuple[int, str]]: violations = [] try: - with open(filepath, encoding='utf-8') as f: + with open(filepath, encoding="utf-8") as f: for line_num, line in enumerate(f, 1): if is_forbidden_import(line): violations.append((line_num, line.strip())) @@ -72,9 +72,7 @@ def main() -> int: if total_violations > 0: print(f"\n💡 Found {total_violations} violation(s).") print("❌ Please replace 'import re' with 'import regex as re'") - print( - " Also replace 'from re import ...' with 'from regex import ...'" - ) # noqa: E501 + print(" Also replace 'from re import ...' with 'from regex import ...'") # noqa: E501 print("✅ Allowed imports:") print(" - import regex as re") print(" - import regex") # noqa: E501 diff --git a/tools/generate_cmake_presets.py b/tools/generate_cmake_presets.py index 4869a71307e4..85847c2c0fe8 100644 --- a/tools/generate_cmake_presets.py +++ b/tools/generate_cmake_presets.py @@ -12,8 +12,7 @@ # most reliable source of truth for vLLM's build. from torch.utils.cpp_extension import CUDA_HOME except ImportError: - print("Warning: PyTorch not found. " - "Falling back to CUDA_HOME environment variable.") + print("Warning: PyTorch not found. 
Falling back to CUDA_HOME environment variable.") CUDA_HOME = os.environ.get("CUDA_HOME") @@ -27,8 +26,7 @@ def get_cpu_cores(): return multiprocessing.cpu_count() -def generate_presets(output_path="CMakeUserPresets.json", - force_overwrite=False): +def generate_presets(output_path="CMakeUserPresets.json", force_overwrite=False): """Generates the CMakeUserPresets.json file.""" print("Attempting to detect your system configuration...") @@ -39,8 +37,7 @@ def generate_presets(output_path="CMakeUserPresets.json", prospective_path = os.path.join(CUDA_HOME, "bin", "nvcc") if os.path.exists(prospective_path): nvcc_path = prospective_path - print("Found nvcc via torch.utils.cpp_extension.CUDA_HOME: " - f"{nvcc_path}") + print(f"Found nvcc via torch.utils.cpp_extension.CUDA_HOME: {nvcc_path}") if not nvcc_path: nvcc_path = which("nvcc") @@ -50,7 +47,8 @@ def generate_presets(output_path="CMakeUserPresets.json", if not nvcc_path: nvcc_path_input = input( "Could not automatically find 'nvcc'. Please provide the full " - "path to nvcc (e.g., /usr/local/cuda/bin/nvcc): ") + "path to nvcc (e.g., /usr/local/cuda/bin/nvcc): " + ) nvcc_path = nvcc_path_input.strip() print(f"Using NVCC path: {nvcc_path}") @@ -63,12 +61,13 @@ def generate_presets(output_path="CMakeUserPresets.json", "Could not automatically find Python executable. Please provide " "the full path to your Python executable for vLLM development " "(typically from your virtual environment, e.g., " - "/home/user/venvs/vllm/bin/python): ") + "/home/user/venvs/vllm/bin/python): " + ) python_executable = input(python_executable_prompt).strip() if not python_executable: raise ValueError( - "Could not determine Python executable. Please provide it " - "manually.") + "Could not determine Python executable. Please provide it manually." + ) print(f"Using Python executable: {python_executable}") @@ -76,20 +75,23 @@ def generate_presets(output_path="CMakeUserPresets.json", cpu_cores = get_cpu_cores() nvcc_threads = min(4, cpu_cores) cmake_jobs = max(1, cpu_cores // nvcc_threads) - print(f"Detected {cpu_cores} CPU cores. " - f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}.") + print( + f"Detected {cpu_cores} CPU cores. " + f"Setting NVCC_THREADS={nvcc_threads} and CMake jobs={cmake_jobs}." + ) # Get vLLM project root (assuming this script is in vllm/tools/) - project_root = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..")) + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) print(f"VLLM project root detected as: {project_root}") # Ensure python_executable path is absolute or resolvable if not os.path.isabs(python_executable) and which(python_executable): python_executable = os.path.abspath(which(python_executable)) elif not os.path.isabs(python_executable): - print(f"Warning: Python executable '{python_executable}' is not an " - "absolute path and not found in PATH. CMake might not find it.") + print( + f"Warning: Python executable '{python_executable}' is not an " + "absolute path and not found in PATH. CMake might not find it." + ) cache_variables = { "CMAKE_CUDA_COMPILER": nvcc_path, @@ -122,24 +124,20 @@ def generate_presets(output_path="CMakeUserPresets.json", configure_preset["generator"] = "Ninja" cache_variables["CMAKE_JOB_POOLS"] = f"compile={cmake_jobs}" else: - print("Ninja not found, using default generator. " - "Build may be slower.") + print("Ninja not found, using default generator. 
Build may be slower.") presets = { - "version": - 6, + "version": 6, # Keep in sync with CMakeLists.txt and requirements/build.txt - "cmakeMinimumRequired": { - "major": 3, - "minor": 26, - "patch": 1 - }, + "cmakeMinimumRequired": {"major": 3, "minor": 26, "patch": 1}, "configurePresets": [configure_preset], - "buildPresets": [{ - "name": "release", - "configurePreset": "release", - "jobs": cmake_jobs, - }], + "buildPresets": [ + { + "name": "release", + "configurePreset": "release", + "jobs": cmake_jobs, + } + ], } output_file_path = os.path.join(project_root, output_path) @@ -148,10 +146,12 @@ def generate_presets(output_path="CMakeUserPresets.json", if force_overwrite: print(f"Overwriting existing file '{output_file_path}'") else: - overwrite = input( - f"'{output_file_path}' already exists. Overwrite? (y/N): " - ).strip().lower() - if overwrite != 'y': + overwrite = ( + input(f"'{output_file_path}' already exists. Overwrite? (y/N): ") + .strip() + .lower() + ) + if overwrite != "y": print("Generation cancelled.") return @@ -160,11 +160,9 @@ def generate_presets(output_path="CMakeUserPresets.json", json.dump(presets, f, indent=4) print(f"Successfully generated '{output_file_path}'") print("\nTo use this preset:") - print( - f"1. Ensure you are in the vLLM root directory: cd {project_root}") + print(f"1. Ensure you are in the vLLM root directory: cd {project_root}") print("2. Initialize CMake: cmake --preset release") - print("3. Build+install: cmake --build --preset release " - "--target install") + print("3. Build+install: cmake --build --preset release --target install") except OSError as e: print(f"Error writing file: {e}") @@ -175,7 +173,7 @@ def generate_presets(output_path="CMakeUserPresets.json", parser.add_argument( "--force-overwrite", action="store_true", - help="Force overwrite existing CMakeUserPresets.json without prompting" + help="Force overwrite existing CMakeUserPresets.json without prompting", ) args = parser.parse_args() diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index c97a5b0b6c71..bceb894a7a5f 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -17,44 +17,48 @@ # add to this list if absolutely necessary and after careful security review. 
ALLOWED_FILES = { # pickle - 'vllm/v1/serial_utils.py', - 'vllm/v1/executor/multiproc_executor.py', - 'vllm/multimodal/hasher.py', - 'vllm/transformers_utils/config.py', - 'vllm/model_executor/models/registry.py', - 'tests/utils_/test_utils.py', - 'tests/tokenization/test_cached_tokenizer.py', - 'vllm/distributed/utils.py', - 'vllm/distributed/parallel_state.py', - 'vllm/distributed/device_communicators/all_reduce_utils.py', - 'vllm/distributed/device_communicators/shm_broadcast.py', - 'vllm/distributed/device_communicators/shm_object_storage.py', - 'benchmarks/kernels/graph_machete_bench.py', - 'benchmarks/kernels/benchmark_lora.py', - 'benchmarks/kernels/benchmark_machete.py', - 'benchmarks/fused_kernels/layernorm_rms_benchmarks.py', - 'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py', - 'benchmarks/cutlass_benchmarks/sparse_benchmarks.py', + "vllm/v1/serial_utils.py", + "vllm/v1/executor/multiproc_executor.py", + "vllm/multimodal/hasher.py", + "vllm/transformers_utils/config.py", + "vllm/model_executor/models/registry.py", + "tests/utils_/test_utils.py", + "tests/tokenization/test_cached_tokenizer.py", + "vllm/distributed/utils.py", + "vllm/distributed/parallel_state.py", + "vllm/distributed/device_communicators/all_reduce_utils.py", + "vllm/distributed/device_communicators/shm_broadcast.py", + "vllm/distributed/device_communicators/shm_object_storage.py", + "benchmarks/kernels/graph_machete_bench.py", + "benchmarks/kernels/benchmark_lora.py", + "benchmarks/kernels/benchmark_machete.py", + "benchmarks/fused_kernels/layernorm_rms_benchmarks.py", + "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py", + "benchmarks/cutlass_benchmarks/sparse_benchmarks.py", # cloudpickle - 'vllm/executor/mp_distributed_executor.py', - 'vllm/executor/ray_distributed_executor.py', - 'vllm/entrypoints/llm.py', - 'tests/utils.py', + "vllm/executor/mp_distributed_executor.py", + "vllm/executor/ray_distributed_executor.py", + "vllm/entrypoints/llm.py", + "tests/utils.py", # pickle and cloudpickle - 'vllm/utils/__init__.py', + "vllm/utils/__init__.py", } -PICKLE_RE = re.compile(r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)" - r"|from\s+(pickle|cloudpickle)\s+import\b)") +PICKLE_RE = re.compile( + r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)" + r"|from\s+(pickle|cloudpickle)\s+import\b)" +) def scan_file(path: str) -> int: - with open(path, encoding='utf-8') as f: + with open(path, encoding="utf-8") as f: for i, line in enumerate(f, 1): if PICKLE_RE.match(line): - print(f"{path}:{i}: " - "\033[91merror:\033[0m " # red color - "Found pickle/cloudpickle import") + print( + f"{path}:{i}: " + "\033[91merror:\033[0m " # red color + "Found pickle/cloudpickle import" + ) return 1 return 0 @@ -92,13 +96,13 @@ def test_regex(): for i, (line, should_match) in enumerate(test_cases): result = bool(PICKLE_RE.match(line)) assert result == should_match, ( - f"Test case {i} failed: '{line}' " - f"(expected {should_match}, got {result})") + f"Test case {i} failed: '{line}' (expected {should_match}, got {result})" + ) print("All regex tests passed.") -if __name__ == '__main__': - if '--test-regex' in sys.argv: +if __name__ == "__main__": + if "--test-regex" in sys.argv: test_regex() else: sys.exit(main()) diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 039cf6075f63..22ee08535bdd 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -94,11 +94,15 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]: return file_groups -def mypy(targets: list[str], python_version: 
Optional[str], - follow_imports: Optional[str], file_group: str) -> int: +def mypy( + targets: list[str], + python_version: Optional[str], + follow_imports: Optional[str], + file_group: str, +) -> int: """ Run mypy on the given targets. - + Args: targets: List of files or directories to check. python_version: Python version to use (e.g., "3.10") or None to use @@ -131,8 +135,9 @@ def main(): for file_group, changed_files in file_groups.items(): follow_imports = None if ci and file_group == "" else "skip" if changed_files: - returncode |= mypy(changed_files, python_version, follow_imports, - file_group) + returncode |= mypy( + changed_files, python_version, follow_imports, file_group + ) return returncode diff --git a/tools/profiler/nsys_profile_tools/gputrc2graph.py b/tools/profiler/nsys_profile_tools/gputrc2graph.py index 42dfede9e987..fd237c0b214a 100755 --- a/tools/profiler/nsys_profile_tools/gputrc2graph.py +++ b/tools/profiler/nsys_profile_tools/gputrc2graph.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ - This generates gpu kernel analysis output from nsys rep. Will call nsys - stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate - csv and html output for analysis +This generates gpu kernel analysis output from nsys rep. Will call nsys +stats -r cuda_gpu_kern_trace, get non-overlapped gpu cycles, then generate +csv and html output for analysis """ + import argparse import logging import os @@ -16,13 +17,13 @@ # helper data class for annotating kernels def load_engine_model(): - """ returns engine_model built from all json files in the current dir """ + """returns engine_model built from all json files in the current dir""" import glob import json + engine_model = {} - json_files = glob.glob( - os.path.join(os.path.dirname(__file__) or ".", "*.json")) + json_files = glob.glob(os.path.join(os.path.dirname(__file__) or ".", "*.json")) for fname in json_files: with open(fname, encoding="utf-8") as f: engine_model.update(json.load(f)) @@ -30,54 +31,54 @@ def load_engine_model(): class GPUTrace2Graph: - """ - Parses output of nsys report, generates csv and bar chart output + """ + Parses output of nsys report, generates csv and bar chart output """ def __init__(self): import pandas as pd # avoid importing till needed + self.pd = pd self.pd.options.mode.copy_on_write = True # helper functions for generating trace->summary csvs def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file): - logger.info('loading %s', in_file) + logger.info("loading %s", in_file) df = self.pd.read_csv( - in_file, - usecols=['Start (ns)', 'Duration (ns)', 'Device', 'Strm', 'Name']) - df['End (ns)'] = df['Start (ns)'] + df['Duration (ns)'] + in_file, usecols=["Start (ns)", "Duration (ns)", "Device", "Strm", "Name"] + ) + df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"] df = self.sum_non_overlapping_intervals(df) # get ready to print table with elapsed times per kernel - df['Instances'] = 1 - df_sum = df.groupby('Name', as_index=False).agg({ - 'Elapsed Time (ns)': 'sum', - 'Duration (ns)': 'sum', - 'Instances': 'size' - }) + df["Instances"] = 1 + df_sum = df.groupby("Name", as_index=False).agg( + {"Elapsed Time (ns)": "sum", "Duration (ns)": "sum", "Instances": "size"} + ) # generate csv - df_sum['Total Time (sec)'] = df_sum['Duration (ns)'] / 1e9 - df_sum['Elapsed Time (sec)'] = df_sum['Elapsed Time (ns)'] / 1e9 - df_sum = df_sum.sort_values(by='Elapsed Time (sec)', ascending=False) - 
df_sum[['Elapsed Time (sec)', 'Total Time (sec)', 'Instances', - 'Name']].to_csv(out_file, index=False) + df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9 + df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9 + df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False) + df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances", "Name"]].to_csv( + out_file, index=False + ) def sum_non_overlapping_intervals(self, df): - """ - returns new sorted df with Elapsed Time (ns) column using - vectorized operations + """ + returns new sorted df with Elapsed Time (ns) column using + vectorized operations """ logger.info("sorting %s trace records by start time", str(df.shape)) # Sort by start time and reset index - df = df.sort_values(by='Start (ns)').reset_index(drop=True) + df = df.sort_values(by="Start (ns)").reset_index(drop=True) # Initialize elapsed time as duration - df['Elapsed Time (ns)'] = df['Duration (ns)'] + df["Elapsed Time (ns)"] = df["Duration (ns)"] # Get numpy arrays for faster operations - starts = df['Start (ns)'].values - ends = df['End (ns)'].values + starts = df["Start (ns)"].values + ends = df["End (ns)"].values # Keep track of current interval end current_end = ends[0] @@ -85,16 +86,17 @@ def sum_non_overlapping_intervals(self, df): # Update current_end for overlapping intervals for i in range(1, len(df)): if i % display_units == 0: - print(f'processing trace: {int(i/len(df) * 100)} %', end="\r") + print(f"processing trace: {int(i / len(df) * 100)} %", end="\r") if starts[i] <= current_end: if ends[i] > current_end: # Partial overlap - df.iloc[i, df.columns.get_loc('Elapsed Time (ns)' - )] = ends[i] - current_end + df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = ( + ends[i] - current_end + ) current_end = ends[i] else: # Complete overlap - df.iloc[i, df.columns.get_loc('Elapsed Time (ns)')] = 0 + df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = 0 else: # No overlap current_end = ends[i] @@ -103,147 +105,167 @@ def sum_non_overlapping_intervals(self, df): # functions for generating html files def make_html(self, df, output_dir, title): - """ make html graph from df """ + """make html graph from df""" import plotly.express as px + if df.empty: return - output_name = output_dir + '/result' + output_name = output_dir + "/result" if not title: - title = 'Model_Engine' - x = 'Model_Engine' - y = 'Elapsed Time (sec)' - color = 'Category' + title = "Model_Engine" + x = "Model_Engine" + y = "Elapsed Time (sec)" + color = "Category" """ generate kernel mapping table """ # Sort Model_Engine categories by last field after underscore - df['Model_Engine'] = self.pd.Categorical( - df['Model_Engine'], - sorted(df['Model_Engine'].unique(), - key=lambda x: x.split('_')[-1])) - df[['Model_Engine', color, 'Instances', 'Name', - y]].sort_values(by=color).to_csv(f'{output_name}.csv', index=False) - graph = px.histogram(df.round(2), - x=x, - y=y, - title=(f'{y} for {title}'), - color=color, - text_auto=True) + df["Model_Engine"] = self.pd.Categorical( + df["Model_Engine"], + sorted(df["Model_Engine"].unique(), key=lambda x: x.split("_")[-1]), + ) + df[["Model_Engine", color, "Instances", "Name", y]].sort_values( + by=color + ).to_csv(f"{output_name}.csv", index=False) + graph = px.histogram( + df.round(2), + x=x, + y=y, + title=(f"{y} for {title}"), + color=color, + text_auto=True, + ) # wrap x axis labels graph.update_xaxes(automargin=True) - graph.write_html(f'{output_name}.html') + graph.write_html(f"{output_name}.html") """ Generate data table 
with columns per Model_Engine into result.html """ - pivot_df = df.pivot_table(values='Elapsed Time (sec)', - index='Category', - columns='Model_Engine', - aggfunc='sum', - observed=False).round(2) + pivot_df = df.pivot_table( + values="Elapsed Time (sec)", + index="Category", + columns="Model_Engine", + aggfunc="sum", + observed=False, + ).round(2) # Add sum row at bottom - pivot_df.loc['total_elapsed_sec'] = pivot_df.sum() - pivot_df.fillna('').to_html('temp.html') - with (open(f'{output_name}.html', 'a', encoding='utf-8') as - outfile, open('temp.html', encoding='utf-8') as infile): + pivot_df.loc["total_elapsed_sec"] = pivot_df.sum() + pivot_df.fillna("").to_html("temp.html") + with ( + open(f"{output_name}.html", "a", encoding="utf-8") as outfile, + open("temp.html", encoding="utf-8") as infile, + ): outfile.write(infile.read()) - os.remove('temp.html') + os.remove("temp.html") - print(f'Finished generating: \n' - f' {output_name}.html for stack bar chart \n' - f' {output_name}.csv for Kernel-Category mapping') + print( + f"Finished generating: \n" + f" {output_name}.html for stack bar chart \n" + f" {output_name}.csv for Kernel-Category mapping" + ) def anno_gpu_kernname(self, df, mapping): - """ add "Category" column """ + """add "Category" column""" def anno_gpu_kernname_helper(name): for kern_name, val in mapping.items(): if re.search(kern_name, name): return val - df['Category'] = df['Name'].apply(anno_gpu_kernname_helper) + df["Category"] = df["Name"].apply(anno_gpu_kernname_helper) def make_nongpu_row(self, df, nongpu_sec): - """ this will append non-gpu time entry at end of df """ + """this will append non-gpu time entry at end of df""" nongpu_row = self.pd.DataFrame([df.iloc[-1]]) - nongpu_row['Category'] = nongpu_row['Name'] = 'CPU(non-GPU)' - nongpu_row['Instances'] = 1 - nongpu_row['Elapsed Time (sec)'] = nongpu_sec - return (nongpu_row) + nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)" + nongpu_row["Instances"] = 1 + nongpu_row["Elapsed Time (sec)"] = nongpu_sec + return nongpu_row def is_valid_file(self, base_file): - """ asserts if base_file is non-existent or is empty """ - assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, \ - f"{base_file} doesn't exist or is empty" + """asserts if base_file is non-existent or is empty""" + assert os.path.isfile(base_file) and os.path.getsize(base_file) > 0, ( + f"{base_file} doesn't exist or is empty" + ) def should_gen_file(self, new_file, base_file): - """ figure out if new file should be generated from base_file """ + """figure out if new file should be generated from base_file""" self.is_valid_file(base_file) - if (os.path.exists(new_file) - and (os.path.getmtime(new_file) > os.path.getmtime(base_file)) - and (os.path.getsize(base_file) > 0)): - logger.info('reusing %s', new_file) + if ( + os.path.exists(new_file) + and (os.path.getmtime(new_file) > os.path.getmtime(base_file)) + and (os.path.getsize(base_file) > 0) + ): + logger.info("reusing %s", new_file) return False else: - logger.info('generating %s', new_file) + logger.info("generating %s", new_file) return True def gen_sum_file(self, file, nsys_cmd): - """ - generates sum file from nsys trace with times per kernel and - returns the name of the sum file + """ + generates sum file from nsys trace with times per kernel and + returns the name of the sum file """ import subprocess + file_dir = os.path.dirname(file) file_name = os.path.basename(file) if not file_dir: - file_dir = '.' + file_dir = "." 
# Walk through trace and get the total non-overlapped time - nsys_stats_file = f'{file_dir}/{file_name}_cuda_gpu_trace.csv' - sum_file = f'{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv' + nsys_stats_file = f"{file_dir}/{file_name}_cuda_gpu_trace.csv" + sum_file = f"{file_dir}/{file_name}_cuda_gpu_kernel_tracesum.csv" if self.should_gen_file(nsys_stats_file, file): cmd = [ - nsys_cmd, 'stats', '-r', 'cuda_gpu_trace', file, '-o', - f'{file_dir}/{file_name}' + nsys_cmd, + "stats", + "-r", + "cuda_gpu_trace", + file, + "-o", + f"{file_dir}/{file_name}", ] - cmd_str = ' '.join(cmd) - logger.info('+ %s', cmd_str) + cmd_str = " ".join(cmd) + logger.info("+ %s", cmd_str) # estimate time based on calibrated 240M/min file_size_mb = os.path.getsize(file) / 1e6 logger.info( - 'nsys stats for %.2f MB file expected to take %.2f min', - file_size_mb, file_size_mb / 240) + "nsys stats for %.2f MB file expected to take %.2f min", + file_size_mb, + file_size_mb / 240, + ) try: subprocess.run(cmd, check=True) except Exception: - logger.error("%s failed; Use --nsys_cmd to specify nsys path", - cmd_str) + logger.error("%s failed; Use --nsys_cmd to specify nsys path", cmd_str) exit(1) - logger.info('generating non-overalapped sum %s', sum_file) + logger.info("generating non-overalapped sum %s", sum_file) self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file) self.is_valid_file(sum_file) - logger.info('Finished generating %s', sum_file) + logger.info("Finished generating %s", sum_file) return sum_file def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model): - """ generates graph and csv file from in_file into out_dir """ + """generates graph and csv file from in_file into out_dir""" # Initialize an empty DataFrame to store combined data combined_df = self.pd.DataFrame() for idx, (file, engine, model, total_sec) in enumerate(in_file): file_dir = os.path.dirname(file) file_name = os.path.basename(file) if not file_dir: - file_dir = '.' + file_dir = "." sum_file = self.gen_sum_file(file, nsys_cmd) # read kernel summary file df = self.pd.read_csv(sum_file) # annotate kernel to their categories - assert engine_model.get(engine), f'engine {engine} unknown' - assert engine_model[engine].get(model), f'model {model} unknown' + assert engine_model.get(engine), f"engine {engine} unknown" + assert engine_model[engine].get(model), f"model {model} unknown" # remove nsys-rep from file_name for shorter x-label - file_name = file_name.replace('.nsys-rep', '') - df['Model_Engine'] = f'{model}_{engine}_{file_name}_{idx}' + file_name = file_name.replace(".nsys-rep", "") + df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}" self.anno_gpu_kernname(df, engine_model[engine][model]) # patch in non-gpu time - gpu_sec = round(df['Elapsed Time (sec)'].sum(), 1) + gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1) total_sec = round(float(total_sec), 1) if total_sec < gpu_sec: logger.warning( @@ -256,7 +278,7 @@ def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model): df = self.pd.concat([df, nongpu_row], ignore_index=True) combined_df = self.pd.concat([combined_df, df], ignore_index=True) if out_dir is None: - out_dir = '.' + out_dir = "." 
else: os.makedirs(out_dir, exist_ok=True) # generate html file @@ -264,50 +286,59 @@ def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model): def parse_tuple(s): - return tuple(s.split(',')) + return tuple(s.split(",")) def main(): - logging.basicConfig(format=('%(asctime)s - %(levelname)s - %(message)s'), - level=logging.INFO) + logging.basicConfig( + format=("%(asctime)s - %(levelname)s - %(message)s"), level=logging.INFO + ) parser = argparse.ArgumentParser( description=( - 'Process nsys rep and generate kernel non-overlapped cycles. \n' - 'Example:\n' + "Process nsys rep and generate kernel non-overlapped cycles. \n" + "Example:\n" "gputrc2graph.py --in_file d1.nsys-rep,vllm,llama,100 \n" "d2.nsys-rep,vllm,gpt-oss,102 " - "--out_dir results/ --title \"Model=gpt-oss vLLM chart\""), - formatter_class=argparse.RawDescriptionHelpFormatter) + '--out_dir results/ --title "Model=gpt-oss vLLM chart"' + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) # load supported engine_model engine_model_supported = load_engine_model() # Get a string representation of supported engine/model combinations - engine_model_supported_str = ', '.join( + engine_model_supported_str = ", ".join( f"{engine}:[{', '.join(models.keys())}]" - for engine, models in engine_model_supported.items()) + for engine, models in engine_model_supported.items() + ) parser.add_argument( - '--in_file', + "--in_file", type=parse_tuple, - nargs='+', + nargs="+", help=( - 'list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) ' - 'separated by space. Elapsed_nonprofiled_sec is runtime without ' - 'profiling used to calculate non-gpu time. Specify 0 to use ' - 'elapsed time from nsys-rep but that might inflate non-gpu time. ' - f'Available engine:[model] are: {engine_model_supported_str} ' - f'Example: --infile d1.nsys-rep,vllm,llama,100 ' - 'd2.nsys-rep,vllm,gpt-oss,102'), - required=True) - parser.add_argument('--out_dir', help=('output dir for result.csv/html')) - parser.add_argument('--title', help=('title for html chart')) - parser.add_argument('--nsys_cmd', - help=('nsys cmd, e.g. /usr/bin/nsys, Default: nsys'), - default="nsys") + "list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) " + "separated by space. Elapsed_nonprofiled_sec is runtime without " + "profiling used to calculate non-gpu time. Specify 0 to use " + "elapsed time from nsys-rep but that might inflate non-gpu time. " + f"Available engine:[model] are: {engine_model_supported_str} " + f"Example: --infile d1.nsys-rep,vllm,llama,100 " + "d2.nsys-rep,vllm,gpt-oss,102" + ), + required=True, + ) + parser.add_argument("--out_dir", help=("output dir for result.csv/html")) + parser.add_argument("--title", help=("title for html chart")) + parser.add_argument( + "--nsys_cmd", + help=("nsys cmd, e.g. 
/usr/bin/nsys, Default: nsys"), + default="nsys", + ) args = parser.parse_args() gputrace = GPUTrace2Graph() - gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd, - engine_model_supported) + gputrace.gen_graph( + args.in_file, args.out_dir, args.title, args.nsys_cmd, engine_model_supported + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/profiler/print_layerwise_table.py b/tools/profiler/print_layerwise_table.py index 209c3a576aee..d7a24a598593 100644 --- a/tools/profiler/print_layerwise_table.py +++ b/tools/profiler/print_layerwise_table.py @@ -29,48 +29,50 @@ def get_entries(node, curr_depth=0): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--json-trace", - type=str, - required=True, - help="json trace file output by " - "examples/offline_inference/profiling.py") - parser.add_argument("--phase", - type=str, - required=True, - help="The phase to print the table for. This is either" - "prefill or decode_n, where n is the decode step " - "number") - parser.add_argument("--table", - type=str, - choices=["summary", "model"], - default="summary", - help="Which table to print, the summary table or the " - "layerwise model table") + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by examples/offline_inference/profiling.py", + ) + parser.add_argument( + "--phase", + type=str, + required=True, + help="The phase to print the table for. This is either" + "prefill or decode_n, where n is the decode step " + "number", + ) + parser.add_argument( + "--table", + type=str, + choices=["summary", "model"], + default="summary", + help="Which table to print, the summary table or the layerwise model table", + ) args = parser.parse_args() with open(args.json_trace) as f: profile_data = json.load(f) - assert args.phase in profile_data, \ - (f"Cannot find phase {args.phase} in profile data. Choose one among" - f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa + assert args.phase in profile_data, ( + f"Cannot find phase {args.phase} in profile data. 
Choose one among" + f"{[x for x in profile_data if 'prefill' in x or 'decode' in x]}" + ) # noqa if args.table == "summary": entries_and_depths = flatten_entries( - SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) - column_widths = dict(name=80, - cuda_time_us=12, - pct_cuda_time=12, - invocations=15) + SummaryStatsEntry, profile_data[args.phase]["summary_stats"] + ) + column_widths = dict(name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15) elif args.table == "model": entries_and_depths = flatten_entries( - ModelStatsEntry, profile_data[args.phase]["model_stats"]) - column_widths = dict(name=60, - cpu_time_us=12, - cuda_time_us=12, - pct_cuda_time=12, - trace=60) + ModelStatsEntry, profile_data[args.phase]["model_stats"] + ) + column_widths = dict( + name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60 + ) # indent entry names based on the depth entries = [] @@ -78,7 +80,8 @@ def get_entries(node, curr_depth=0): entry.name = indent_string( entry.name, indent=depth, - indent_style=lambda indent: "|" + "-" * indent + " ") + indent_style=lambda indent: "|" + "-" * indent + " ", + ) entries.append(entry) TablePrinter(type(entries[0]), column_widths).print_table(entries) diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index 30d6547073d3..cdab004366f9 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -18,17 +18,18 @@ def largest_dist_from_leaf(node: dict, depth: int = 0): if len(node["children"]) == 0: return depth - return max([ - largest_dist_from_leaf(child, depth=depth + 1) - for child in node["children"] - ]) - - -def get_entries_at_depth(depth: int, - entries_and_traces: list[tuple[Any, Any]], - node: dict, - curr_depth: int = 0, - trace=()): + return max( + [largest_dist_from_leaf(child, depth=depth + 1) for child in node["children"]] + ) + + +def get_entries_at_depth( + depth: int, + entries_and_traces: list[tuple[Any, Any]], + node: dict, + curr_depth: int = 0, + trace=(), +): # assert that the query is at kernel or module level assert depth == -1 or depth == -2 @@ -40,21 +41,18 @@ def get_entries_at_depth(depth: int, if largest_dist_from_leaf(node) == (abs(depth) - 1): entries_and_traces.append((node["entry"], trace)) - trace = (node["entry"]["name"], ) + trace + trace = (node["entry"]["name"],) + trace for child in node["children"]: - get_entries_at_depth(depth, - entries_and_traces, - child, - curr_depth=curr_depth + 1, - trace=trace) + get_entries_at_depth( + depth, entries_and_traces, child, curr_depth=curr_depth + 1, trace=trace + ) def fold_nodes(root: dict, nodes_to_fold: list[str]): - stack: list[dict] = [root] while len(stack) != 0: node = stack.pop() - if node['entry']['name'] in nodes_to_fold: + if node["entry"]["name"] in nodes_to_fold: node["children"] = [] continue for child in node["children"]: @@ -76,9 +74,7 @@ def trim_string_back(string: str, width: int) -> str: def shorten_plot_legend_strings(legend, max_char_len: int): for t in legend.get_texts(): - t.set_text( - trim_string_back(abbreviate_known_names(t.get_text()), - max_char_len)) + t.set_text(trim_string_back(abbreviate_known_names(t.get_text()), max_char_len)) def abbreviate_known_names(name: str) -> str: @@ -108,15 +104,21 @@ def all_the_same(items) -> bool: names.add(entry["name"]) for name in non_unique_names: - entries_and_traces_with_name = [(entry, trace) - for entry, trace in entries_and_traces - if entry["name"] == name] + 
entries_and_traces_with_name = [ + (entry, trace) + for entry, trace in entries_and_traces + if entry["name"] == name + ] - zipped_traces = list( - zip(*[trace for _, trace in entries_and_traces_with_name])) + zipped_traces = list(zip(*[trace for _, trace in entries_and_traces_with_name])) first_trace_difference = next( - (i for i, trace_eles in enumerate(zipped_traces) - if not all_the_same(trace_eles)), None) + ( + i + for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles) + ), + None, + ) if first_trace_difference is None: # can't create a unique name, leave the names as they @@ -124,34 +126,32 @@ def all_the_same(items) -> bool: continue for entry, trace in entries_and_traces_with_name: - entry["name"] = " <- ".join((entry["name"], ) + - trace[:first_trace_difference + 1]) + entry["name"] = " <- ".join( + (entry["name"],) + trace[: first_trace_difference + 1] + ) ## Operation grouping utils #### -''' +""" Group operations in the given dataframe by some high-level ops like, - gemms - attention - rms_norm etc. -''' +""" def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: - def is_rms_norm(op_name: str): if "rms_norm_kernel" in op_name: return True def is_attention_block(op_name: str): - if "flash_fwd" in op_name or \ - "reshape_and_cache_flash_kernel" in op_name: + if "flash_fwd" in op_name or "reshape_and_cache_flash_kernel" in op_name: return True def is_quant(op_name: str): - if "scaled_fp8_quant" in op_name or \ - "scaled_int8_quant" in op_name: + if "scaled_fp8_quant" in op_name or "scaled_int8_quant" in op_name: return True # LoRA ops @@ -168,24 +168,27 @@ def is_bgmv_expand(op_name: str): return "bgmv_expand" in op_name def is_cutlass_gemm_op(op_name: str): - return "void cutlass::Kernel" in op_name or \ - "void cutlass::device_kernel" in op_name + return ( + "void cutlass::Kernel" in op_name + or "void cutlass::device_kernel" in op_name + ) def is_gemm_op(op_name: str): if is_quant(op_name): return False - return is_cutlass_gemm_op(op_name) or \ - "xmma_gemm" in op_name or \ - "gemv2T_kernel" in op_name or \ - "splitKreduce" in op_name or \ - "s16816gemm" in op_name + return ( + is_cutlass_gemm_op(op_name) + or "xmma_gemm" in op_name + or "gemv2T_kernel" in op_name + or "splitKreduce" in op_name + or "s16816gemm" in op_name + ) def is_elementwise_op(op_name: str): return "elementwise_kernel" in op_name def is_mem_op(op_name: str): - return "memcpy" in op_name.lower() or \ - "memset" in op_name.lower() + return "memcpy" in op_name.lower() or "memset" in op_name.lower() def is_vocab_embedding_op(op_name: str): return "vocabparallelembed" in op_name.lower() @@ -195,17 +198,15 @@ def is_nccl_op(op_name: str): return "nccl" in op_name.lower() def is_nccl_all_reduce(op_name: str): - return is_nccl_op(op_name) and \ - ("all_reduce" in op_name.lower() or \ - "allreduce" in op_name.lower()) + return is_nccl_op(op_name) and ( + "all_reduce" in op_name.lower() or "allreduce" in op_name.lower() + ) def is_nccl_gather(op_name: str): - return is_nccl_op(op_name) and \ - "gather" in op_name.lower() + return is_nccl_op(op_name) and "gather" in op_name.lower() def is_nccl_broadcast(op_name: str): - return is_nccl_op(op_name) and \ - "broadcast" in op_name.lower() + return is_nccl_op(op_name) and "broadcast" in op_name.lower() # Reduce ops types def is_cross_device_reduce_1stage(op_name: str): @@ -269,114 +270,122 @@ def is_reduce_kernel(op_name: str): ops = list(filter(lambda x: x not in nccl_other_ops, ops)) cross_device_reduce_1stage_ops = list( - 
filter(lambda x: is_cross_device_reduce_1stage(x), ops)) + filter(lambda x: is_cross_device_reduce_1stage(x), ops) + ) ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops)) cross_device_reduce_2stage_ops = list( - filter(lambda x: is_cross_device_reduce_2stage(x), ops)) + filter(lambda x: is_cross_device_reduce_2stage(x), ops) + ) ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops)) - custom_ar_all_reduce_ops = list( - filter(lambda x: is_custom_ar_all_reduce(x), ops)) + custom_ar_all_reduce_ops = list(filter(lambda x: is_custom_ar_all_reduce(x), ops)) ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops)) reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops)) ops = list(filter(lambda x: x not in reduce_kernel_ops, ops)) if len(attention_ops): - trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) + trace_df["attention"] = trace_df[attention_ops].agg("sum", axis=1) if len(quant_ops): - trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + trace_df["quant_ops"] = trace_df[quant_ops].agg("sum", axis=1) if len(sgmv_shrink_ops): - trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum", - axis=1) + trace_df["sgmv_shrink_ops"] = trace_df[sgmv_shrink_ops].agg("sum", axis=1) if len(sgmv_expand_ops): - trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum", - axis=1) + trace_df["sgmv_expand_ops"] = trace_df[sgmv_expand_ops].agg("sum", axis=1) if len(bgmv_shrink_ops): - trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum", - axis=1) + trace_df["bgmv_shrink_ops"] = trace_df[bgmv_shrink_ops].agg("sum", axis=1) if len(bgmv_expand_ops): - trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum", - axis=1) + trace_df["bgmv_expand_ops"] = trace_df[bgmv_expand_ops].agg("sum", axis=1) if len(cutlass_gemm_ops): - trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum", - axis=1) + trace_df["cutlass_gemm_ops"] = trace_df[cutlass_gemm_ops].agg("sum", axis=1) if len(gemm_ops): - trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1) + trace_df["gemm_ops"] = trace_df[gemm_ops].agg("sum", axis=1) if len(rms_norm_ops): - trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1) + trace_df["rms_norm_ops"] = trace_df[rms_norm_ops].agg("sum", axis=1) if len(vocab_embed_ops): - trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", - axis=1) + trace_df["vocab_embed_ops"] = trace_df[vocab_embed_ops].agg("sum", axis=1) if len(mem_ops): - trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1) + trace_df["mem_ops"] = trace_df[mem_ops].agg("sum", axis=1) if len(elementwise_ops): - trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", - axis=1) + trace_df["elementwise_ops"] = trace_df[elementwise_ops].agg("sum", axis=1) if len(nccl_all_reduce_ops): - trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg( - "sum", axis=1) + trace_df["nccl_all_reduce_ops"] = trace_df[nccl_all_reduce_ops].agg( + "sum", axis=1 + ) if len(nccl_gather_ops): - trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum", - axis=1) + trace_df["nccl_gather_ops"] = trace_df[nccl_gather_ops].agg("sum", axis=1) if len(nccl_broadcast_ops): - trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg( - "sum", axis=1) + trace_df["nccl_broadcast_ops"] = trace_df[nccl_broadcast_ops].agg("sum", axis=1) if len(nccl_other_ops): - trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum", - axis=1) + trace_df["nccl_other_ops"] = 
trace_df[nccl_other_ops].agg("sum", axis=1) if len(cross_device_reduce_1stage_ops): - trace_df['cross_device_reduce_1stage_ops'] = trace_df[ - cross_device_reduce_1stage_ops].agg("sum", axis=1) + trace_df["cross_device_reduce_1stage_ops"] = trace_df[ + cross_device_reduce_1stage_ops + ].agg("sum", axis=1) if len(cross_device_reduce_2stage_ops): - trace_df['cross_device_reduce_2stage_ops'] = trace_df[ - cross_device_reduce_2stage_ops].agg("sum", axis=1) + trace_df["cross_device_reduce_2stage_ops"] = trace_df[ + cross_device_reduce_2stage_ops + ].agg("sum", axis=1) if len(custom_ar_all_reduce_ops): - trace_df['custom_ar_all_reduce_ops'] = trace_df[ - custom_ar_all_reduce_ops].agg("sum", axis=1) + trace_df["custom_ar_all_reduce_ops"] = trace_df[custom_ar_all_reduce_ops].agg( + "sum", axis=1 + ) if len(reduce_kernel_ops): - trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum", - axis=1) - - trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops + - sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops + - cutlass_gemm_ops + gemm_ops + rms_norm_ops + - vocab_embed_ops + mem_ops + elementwise_ops + - nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops + - nccl_other_ops + cross_device_reduce_1stage_ops + - cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops + - reduce_kernel_ops, - axis=1, - inplace=True) + trace_df["reduce_kernel_ops"] = trace_df[reduce_kernel_ops].agg("sum", axis=1) + + trace_df.drop( + attention_ops + + quant_ops + + sgmv_shrink_ops + + sgmv_expand_ops + + bgmv_shrink_ops + + bgmv_expand_ops + + cutlass_gemm_ops + + gemm_ops + + rms_norm_ops + + vocab_embed_ops + + mem_ops + + elementwise_ops + + nccl_all_reduce_ops + + nccl_gather_ops + + nccl_broadcast_ops + + nccl_other_ops + + cross_device_reduce_1stage_ops + + cross_device_reduce_2stage_ops + + custom_ar_all_reduce_ops + + reduce_kernel_ops, + axis=1, + inplace=True, + ) return trace_df ## Data plotting utils #### -def plot_trace_df(traces_df: pd.DataFrame, - plot_metric: str, - plot_title: str, - output: Optional[Path] = None): - +def plot_trace_df( + traces_df: pd.DataFrame, + plot_metric: str, + plot_title: str, + output: Optional[Path] = None, +): def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: phase_df = traces_df.query(f'phase == "{phase}"') - descs = phase_df['phase_desc'].to_list() + descs = phase_df["phase_desc"].to_list() assert all([desc == descs[0] for desc in descs]) return descs[0] - phases = traces_df['phase'].unique() + phases = traces_df["phase"].unique() phase_descs = [get_phase_description(traces_df, p) for p in phases] - traces_df = traces_df.pivot_table(index="phase", - columns="name", - values=plot_metric, - aggfunc="sum") + traces_df = traces_df.pivot_table( + index="phase", columns="name", values=plot_metric, aggfunc="sum" + ) traces_df = group_trace_by_operations(traces_df) @@ -396,20 +405,19 @@ def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: # Write the values as text on the bars for bar in ax.patches: if bar.get_height() != 0: - ax.text(bar.get_x() + bar.get_width() / 2, - bar.get_height() / 2 + bar.get_y(), - f"{round(bar.get_height(), 2)}", - ha='center', - color='w', - weight='bold', - size=5) + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() / 2 + bar.get_y(), + f"{round(bar.get_height(), 2)}", + ha="center", + color="w", + weight="bold", + size=5, + ) # Setup legend handles, labels = plt.gca().get_legend_handles_labels() - legend = fig.legend(handles, - labels, - loc='center left', - 
bbox_to_anchor=(1, 1)) + legend = fig.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 1)) shorten_plot_legend_strings(legend, 50) # Setup labels and title @@ -417,21 +425,20 @@ def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: ax.set_ylabel(plot_metric) plt.suptitle(plot_title) - plt.savefig(output, bbox_inches='tight') + plt.savefig(output, bbox_inches="tight") print("Created: ", output) def main( - json_trace: Path, - output_directory: Path, - depth: int, # Fetch/Plot operations at this depth of the Json tree - plot_metric: str, - make_names_unique: bool, - top_k: int, - json_nodes_to_fold: list[str]): - + json_trace: Path, + output_directory: Path, + depth: int, # Fetch/Plot operations at this depth of the Json tree + plot_metric: str, + make_names_unique: bool, + top_k: int, + json_nodes_to_fold: list[str], +): def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame: - def get_entries_and_traces(key: str): entries_and_traces: list[tuple[Any, Any]] = [] for root in profile_json[key]["summary_stats"]: @@ -441,16 +448,14 @@ def get_entries_and_traces(key: str): get_entries_at_depth(depth, entries_and_traces, root) return entries_and_traces - def keep_only_top_entries(df: pd.DataFrame, - metric: str, - top_k: int = 9) -> pd.DataFrame: - df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, - ["name"]] = "others" + def keep_only_top_entries( + df: pd.DataFrame, metric: str, top_k: int = 9 + ) -> pd.DataFrame: + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others" return df def get_phase_description(key: str) -> str: - num_running_seqs = profile_json[key]['metadata'][ - 'num_running_seqs'] + num_running_seqs = profile_json[key]["metadata"]["num_running_seqs"] if num_running_seqs is not None: return f"{key}-seqs-{num_running_seqs}" else: @@ -466,20 +471,24 @@ def get_phase_description(key: str) -> str: # To pandas dataframe trace_dfs = list( - map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), - traces)) + map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), traces) + ) # Respect top_k if top_k: trace_dfs = list( map( lambda trace_df: keep_only_top_entries( - trace_df, "cuda_time_us", top_k), trace_dfs)) + trace_df, "cuda_time_us", top_k + ), + trace_dfs, + ) + ) # Fill in information about the step-keys for trace_df, step_key in zip(trace_dfs, step_keys): - trace_df['phase'] = step_key - trace_df['phase_desc'] = get_phase_description(step_key) + trace_df["phase"] = step_key + trace_df["phase_desc"] = get_phase_description(step_key) # Combine all data frames so they can be put in a single plot traces_df = pd.concat(trace_dfs) @@ -492,17 +501,23 @@ def get_phase_description(key: str) -> str: def make_plot_title_suffix(profile_json: dict) -> str: context = profile_json["context"] - sparsity = context.get('sparsity', None) - run_type = \ - f'Run {context["num_steps"]} steps' if context['num_steps'] else \ - (f'Complete {context["complete_num_requests_per_step"]} per ' - f'step; Run till completion') - return (f"{context['engine_args']['model']}\n" - f"Batch={context['batch_size']}, " - f"PromptLen={context['prompt_len']}, " - f"NumGpus={context['engine_args']['tensor_parallel_size']}" - f"{', Sparsity ' + sparsity if sparsity else ''}\n" - f"Run Type: {run_type}") + sparsity = context.get("sparsity", None) + run_type = ( + f"Run {context['num_steps']} steps" + if context["num_steps"] + else ( + f"Complete {context['complete_num_requests_per_step']} per " + f"step; Run till completion" + ) + 
) + return ( + f"{context['engine_args']['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"NumGpus={context['engine_args']['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}\n" + f"Run Type: {run_type}" + ) profile_json = None with open(json_trace) as f: @@ -511,14 +526,14 @@ def make_plot_title_suffix(profile_json: dict) -> str: # Get all `llm.generate.step()` profile step_traces = list(profile_json.keys()) - assert (step_traces[0] == 'context') + assert step_traces[0] == "context" step_traces = step_traces[1:] # have only prefill and decodes prefills = list(filter(lambda x: "prefill" in x, step_traces)) all_decodes = list(filter(lambda x: "decode" in x, step_traces)) assert len(prefills) + len(all_decodes) == len(step_traces) assert len(prefills) == 1 - decodes = all_decodes[::args.step_plot_interval] + decodes = all_decodes[:: args.step_plot_interval] if decodes[-1] != all_decodes[-1]: # Always have the last decode decodes.append(all_decodes[-1]) @@ -528,48 +543,63 @@ def make_plot_title_suffix(profile_json: dict) -> str: plot_title_suffix = make_plot_title_suffix(profile_json) - plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix, - output_directory / Path("prefill.png")) - plot_trace_df(decode_traces, plot_metric, "decodes " + plot_title_suffix, - output_directory / Path("decode_steps.png")) + plot_trace_df( + prefill_traces, + plot_metric, + "prefill " + plot_title_suffix, + output_directory / Path("prefill.png"), + ) + plot_trace_df( + decode_traces, + plot_metric, + "decodes " + plot_title_suffix, + output_directory / Path("decode_steps.png"), + ) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--json-trace", - type=str, - required=True, - help="json trace file output by \ - examples/offline_inference/profiling.py") - parser.add_argument("--output-directory", - type=str, - required=False, - help="Directory to output plots") - parser.add_argument("--level", - type=str, - default="module", - choices=["module", "kernel"]) - parser.add_argument("--top-k", - type=int, - default=12, - help="Only graph the top `top_k` entries by time.") - parser.add_argument("--fold-json-node", - nargs='+', - default=['Sampler', 'LogitsProcessor'], - help='Do not plot the children of these nodes. Let, \ + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by \ + examples/offline_inference/profiling.py", + ) + parser.add_argument( + "--output-directory", type=str, required=False, help="Directory to output plots" + ) + parser.add_argument( + "--level", type=str, default="module", choices=["module", "kernel"] + ) + parser.add_argument( + "--top-k", + type=int, + default=12, + help="Only graph the top `top_k` entries by time.", + ) + parser.add_argument( + "--fold-json-node", + nargs="+", + default=["Sampler", "LogitsProcessor"], + help="Do not plot the children of these nodes. Let, \ the node represent the aggregate of all its \ - children') - parser.add_argument("--plot-metric", - type=str, - default="cuda_time_ms", - help='Metric to plot. some options are cuda_time_ms, \ - pct_cuda_time') + children", + ) + parser.add_argument( + "--plot-metric", + type=str, + default="cuda_time_ms", + help="Metric to plot. 
some options are cuda_time_ms, \ + pct_cuda_time", + ) parser.add_argument( "--step-plot-interval", type=int, default=4, - help="For every `step_plot_interval` steps, plot 1 step") + help="For every `step_plot_interval` steps, plot 1 step", + ) args = parser.parse_args() @@ -583,11 +613,19 @@ def make_plot_title_suffix(profile_json: dict) -> str: else: raise Exception(f"Unexpected level value ({args.level})") - output_directory = args.output_directory if args.output_directory else Path( - args.json_trace).parent + output_directory = ( + args.output_directory if args.output_directory else Path(args.json_trace).parent + ) if not os.path.exists(output_directory): os.makedirs(output_directory) - main(Path(args.json_trace), output_directory, depth, args.plot_metric, - make_names_unique, args.top_k, args.fold_json_node) + main( + Path(args.json_trace), + output_directory, + depth, + args.plot_metric, + make_names_unique, + args.top_k, + args.fold_json_node, + ) diff --git a/tools/report_build_time_ninja.py b/tools/report_build_time_ninja.py index 7386cdd9f724..fe3f352fe153 100644 --- a/tools/report_build_time_ninja.py +++ b/tools/report_build_time_ninja.py @@ -83,9 +83,9 @@ def WeightedDuration(self): """ # Allow for modest floating-point errors epsilon = 0.000002 - if (self.weighted_duration > self.Duration() + epsilon): - print('{} > {}?'.format(self.weighted_duration, self.Duration())) - assert (self.weighted_duration <= self.Duration() + epsilon) + if self.weighted_duration > self.Duration() + epsilon: + print("{} > {}?".format(self.weighted_duration, self.Duration())) + assert self.weighted_duration <= self.Duration() + epsilon return self.weighted_duration def DescribeTargets(self): @@ -93,10 +93,10 @@ def DescribeTargets(self): # Some build steps generate dozens of outputs - handle them sanely. # The max_length was chosen so that it can fit most of the long # single-target names, while minimizing word wrapping. - result = ', '.join(self.targets) + result = ", ".join(self.targets) max_length = 65 if len(result) > max_length: - result = result[:max_length] + '...' + result = result[:max_length] + "..." return result @@ -106,12 +106,13 @@ def ReadTargets(log, show_all): The result is a list of Target objects.""" header = log.readline() - assert header == '# ninja log v5\n', \ - 'unrecognized ninja log version {!r}'.format(header) + assert header == "# ninja log v5\n", "unrecognized ninja log version {!r}".format( + header + ) targets_dict = {} last_end_seen = 0.0 for line in log: - parts = line.strip().split('\t') + parts = line.strip().split("\t") if len(parts) != 5: # If ninja.exe is rudely halted then the .ninja_log file may be # corrupt. Silently continue. @@ -150,17 +151,17 @@ def ReadTargets(log, show_all): def GetExtension(target, extra_patterns): """Return the file extension that best represents a target. - For targets that generate multiple outputs it is important to return a - consistent 'canonical' extension. Ultimately the goal is to group build steps - by type.""" + For targets that generate multiple outputs it is important to return a + consistent 'canonical' extension. Ultimately the goal is to group build steps + by type.""" for output in target.targets: if extra_patterns: - for fn_pattern in extra_patterns.split(';'): - if fnmatch.fnmatch(output, '*' + fn_pattern + '*'): + for fn_pattern in extra_patterns.split(";"): + if fnmatch.fnmatch(output, "*" + fn_pattern + "*"): return fn_pattern # Not a true extension, but a good grouping. 
- if output.endswith('type_mappings'): - extension = 'type_mappings' + if output.endswith("type_mappings"): + extension = "type_mappings" break # Capture two extensions if present. For example: file.javac.jar should @@ -170,26 +171,26 @@ def GetExtension(target, extra_patterns): extension = ext2 + ext1 # Preserve the order in the file name. if len(extension) == 0: - extension = '(no extension found)' + extension = "(no extension found)" - if ext1 in ['.pdb', '.dll', '.exe']: - extension = 'PEFile (linking)' + if ext1 in [".pdb", ".dll", ".exe"]: + extension = "PEFile (linking)" # Make sure that .dll and .exe are grouped together and that the # .dll.lib files don't cause these to be listed as libraries break - if ext1 in ['.so', '.TOC']: - extension = '.so (linking)' + if ext1 in [".so", ".TOC"]: + extension = ".so (linking)" # Attempt to identify linking, avoid identifying as '.TOC' break # Make sure .obj files don't get categorized as mojo files - if ext1 in ['.obj', '.o']: + if ext1 in [".obj", ".o"]: break # Jars are the canonical output of java targets. - if ext1 == '.jar': + if ext1 == ".jar": break # Normalize all mojo related outputs to 'mojo'. - if output.count('.mojom') > 0: - extension = 'mojo' + if output.count(".mojom") > 0: + extension = "mojo" break return extension @@ -214,8 +215,8 @@ def SummarizeEntries(entries, extra_step_types): if target.end > latest: latest = target.end total_cpu_time += target.Duration() - task_start_stop_times.append((target.start, 'start', target)) - task_start_stop_times.append((target.end, 'stop', target)) + task_start_stop_times.append((target.start, "start", target)) + task_start_stop_times.append((target.end, "stop", target)) length = latest - earliest weighted_total = 0.0 @@ -241,10 +242,10 @@ def SummarizeEntries(entries, extra_step_types): if num_running > 0: # Update the total weighted time up to this moment. last_weighted_time += (time - last_time) / float(num_running) - if action_name == 'start': + if action_name == "start": # Record the total weighted task time when this task starts. running_tasks[target] = last_weighted_time - if action_name == 'stop': + if action_name == "stop": # Record the change in the total weighted task time while this task # ran. weighted_duration = last_weighted_time - running_tasks[target] @@ -252,13 +253,16 @@ def SummarizeEntries(entries, extra_step_types): weighted_total += weighted_duration del running_tasks[target] last_time = time - assert (len(running_tasks) == 0) + assert len(running_tasks) == 0 # Warn if the sum of weighted times is off by more than half a second. if abs(length - weighted_total) > 500: - print('Warning: Possible corrupt ninja log, results may be ' - 'untrustworthy. Length = {:.3f}, weighted total = {:.3f}'.format( - length, weighted_total)) + print( + "Warning: Possible corrupt ninja log, results may be " + "untrustworthy. Length = {:.3f}, weighted total = {:.3f}".format( + length, weighted_total + ) + ) entries_by_ext = defaultdict(list) for target in entries: @@ -266,32 +270,38 @@ def SummarizeEntries(entries, extra_step_types): entries_by_ext[extension].append(target) for key, values in entries_by_ext.items(): - print(' Longest build steps for {}:'.format(key)) + print(" Longest build steps for {}:".format(key)) values.sort(key=lambda x: x.WeightedDuration()) for target in values[-long_count:]: print( - ' {:8.1f} weighted s to build {} ({:.1f} s elapsed time)'. 
- format(target.WeightedDuration(), target.DescribeTargets(), - target.Duration())) - - print(' {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x ' - 'parallelism)'.format(length, total_cpu_time, - total_cpu_time * 1.0 / length)) - print(' {} build steps completed, average of {:1.2f}/s'.format( - len(entries), - len(entries) / (length))) + " {:8.1f} weighted s to build {} ({:.1f} s elapsed time)".format( + target.WeightedDuration(), + target.DescribeTargets(), + target.Duration(), + ) + ) + + print( + " {:.1f} s weighted time ({:.1f} s elapsed time sum, {:1.1f}x " + "parallelism)".format(length, total_cpu_time, total_cpu_time * 1.0 / length) + ) + print( + " {} build steps completed, average of {:1.2f}/s".format( + len(entries), len(entries) / (length) + ) + ) def main(): - log_file = '.ninja_log' + log_file = ".ninja_log" parser = argparse.ArgumentParser() - parser.add_argument('-C', dest='build_directory', help='Build directory.') + parser.add_argument("-C", dest="build_directory", help="Build directory.") parser.add_argument( - '-s', - '--step-types', - help='semicolon separated fnmatch patterns for build-step grouping') - parser.add_argument('--log-file', - help="specific ninja log file to analyze.") + "-s", + "--step-types", + help="semicolon separated fnmatch patterns for build-step grouping", + ) + parser.add_argument("--log-file", help="specific ninja log file to analyze.") args, _extra_args = parser.parse_known_args() if args.build_directory: log_file = os.path.join(args.build_directory, log_file) @@ -300,17 +310,16 @@ def main(): if args.step_types: # Make room for the extra build types. global long_ext_count - long_ext_count += len(args.step_types.split(';')) + long_ext_count += len(args.step_types.split(";")) try: with open(log_file) as log: entries = ReadTargets(log, False) SummarizeEntries(entries, args.step_types) except OSError: - print('Log file {!r} not found, no build summary created.'.format( - log_file)) + print("Log file {!r} not found, no build summary created.".format(log_file)) return errno.ENOENT -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/tools/validate_config.py b/tools/validate_config.py index f6439fa9ada5..d779edabc841 100644 --- a/tools/validate_config.py +++ b/tools/validate_config.py @@ -38,10 +38,12 @@ def pairwise(iterable): # Consider each pair of nodes. for a, b in pairwise(cls_node.body): # Must be an assignment then a constant string. - if (not isinstance(a, (ast.Assign, ast.AnnAssign)) - or not isinstance(b, ast.Expr) - or not isinstance(b.value, ast.Constant) - or not isinstance(b.value.value, str)): + if ( + not isinstance(a, (ast.Assign, ast.AnnAssign)) + or not isinstance(b, ast.Expr) + or not isinstance(b.value, ast.Constant) + or not isinstance(b.value.value, str) + ): continue doc = inspect.cleandoc(b.value.value) @@ -61,25 +63,27 @@ def pairwise(iterable): class ConfigValidator(ast.NodeVisitor): - - def __init__(self): - ... + def __init__(self): ... 
def visit_ClassDef(self, node): # Validate class with both @config and @dataclass decorators decorators = [ - id for d in node.decorator_list if (isinstance(d, ast.Name) and ( - (id := d.id) == 'config' or id == 'dataclass')) or - (isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and - (id := d.func.id) == 'dataclass')) + id + for d in node.decorator_list + if ( + isinstance(d, ast.Name) + and ((id := d.id) == "config" or id == "dataclass") + ) + or ( + isinstance(d, ast.Call) + and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass") + ) ] - if set(decorators) == {'config', 'dataclass'}: + if set(decorators) == {"config", "dataclass"}: validate_class(node) - elif set(decorators) == {'config'}: - fail( - f"Class {node.name} with config decorator must be a dataclass.", - node) + elif set(decorators) == {"config"}: + fail(f"Class {node.name} with config decorator must be a dataclass.", node) self.generic_visit(node) @@ -93,9 +97,11 @@ def validate_class(class_node: ast.ClassDef): # Skip ClassVar and InitVar # see https://docs.python.org/3/library/dataclasses.html#class-variables # and https://docs.python.org/3/library/dataclasses.html#init-only-variables - if (isinstance(stmt.annotation, ast.Subscript) - and isinstance(stmt.annotation.value, ast.Name) - and stmt.annotation.value.id in {"ClassVar", "InitVar"}): + if ( + isinstance(stmt.annotation, ast.Subscript) + and isinstance(stmt.annotation.value, ast.Name) + and stmt.annotation.value.id in {"ClassVar", "InitVar"} + ): continue if isinstance(stmt.target, ast.Name): @@ -103,22 +109,30 @@ def validate_class(class_node: ast.ClassDef): if stmt.value is None: fail( f"Field '{field_name}' in {class_node.name} must have " - "a default value.", stmt) + "a default value.", + stmt, + ) if field_name not in attr_docs: fail( f"Field '{field_name}' in {class_node.name} must have " - "a docstring.", stmt) + "a docstring.", + stmt, + ) - if isinstance(stmt.annotation, ast.Subscript) and \ - isinstance(stmt.annotation.value, ast.Name) \ - and stmt.annotation.value.id == "Union" and \ - isinstance(stmt.annotation.slice, ast.Tuple): + if ( + isinstance(stmt.annotation, ast.Subscript) + and isinstance(stmt.annotation.value, ast.Name) + and stmt.annotation.value.id == "Union" + and isinstance(stmt.annotation.slice, ast.Tuple) + ): args = stmt.annotation.slice.elts literal_args = [ - arg for arg in args - if isinstance(arg, ast.Subscript) and isinstance( - arg.value, ast.Name) and arg.value.id == "Literal" + arg + for arg in args + if isinstance(arg, ast.Subscript) + and isinstance(arg.value, ast.Name) + and arg.value.id == "Literal" ] if len(literal_args) > 1: fail( @@ -126,7 +140,9 @@ def validate_class(class_node: ast.ClassDef): "use a single " "Literal type. 
Please use 'Literal[Literal1, " "Literal2]' instead of 'Union[Literal1, Literal2]'" - ".", stmt) + ".", + stmt, + ) def validate_ast(tree: ast.stmt): diff --git a/use_existing_torch.py b/use_existing_torch.py index 76480f3e58fe..fd4caa69ec9c 100644 --- a/use_existing_torch.py +++ b/use_existing_torch.py @@ -3,7 +3,7 @@ import glob -requires_files = glob.glob('requirements/*.txt') +requires_files = glob.glob("requirements/*.txt") requires_files += ["pyproject.toml"] for file in requires_files: print(f">>> cleaning {file}") @@ -11,11 +11,11 @@ lines = f.readlines() if "torch" in "".join(lines).lower(): print("removed:") - with open(file, 'w') as f: + with open(file, "w") as f: for line in lines: - if 'torch' not in line.lower(): + if "torch" not in line.lower(): f.write(line) else: print(line.strip()) print(f"<<< done cleaning {file}") - print() \ No newline at end of file + print() diff --git a/vllm/__init__.py b/vllm/__init__.py index 3a5c1b1ce0da..b9c868de6886 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -48,12 +48,18 @@ from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry - from vllm.outputs import (ClassificationOutput, - ClassificationRequestOutput, CompletionOutput, - EmbeddingOutput, EmbeddingRequestOutput, - PoolingOutput, PoolingRequestOutput, - RequestOutput, ScoringOutput, - ScoringRequestOutput) + from vllm.outputs import ( + ClassificationOutput, + ClassificationRequestOutput, + CompletionOutput, + EmbeddingOutput, + EmbeddingRequestOutput, + PoolingOutput, + PoolingRequestOutput, + RequestOutput, + ScoringOutput, + ScoringRequestOutput, + ) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -68,8 +74,7 @@ def __getattr__(name: str) -> typing.Any: module = import_module(module_name, __package__) return getattr(module, attr_name) else: - raise AttributeError( - f'module {__package__} has no attribute {name}') + raise AttributeError(f"module {__package__} has no attribute {name}") __all__ = [ diff --git a/vllm/_bc_linter.py b/vllm/_bc_linter.py index 52a95dbee186..af68396af0b5 100644 --- a/vllm/_bc_linter.py +++ b/vllm/_bc_linter.py @@ -9,13 +9,11 @@ @overload -def bc_linter_skip(obj: T) -> T: - ... +def bc_linter_skip(obj: T) -> T: ... @overload -def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: - ... +def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ... def bc_linter_skip(obj: Any = None, *, reason: str | None = None): @@ -34,13 +32,11 @@ def _wrap(x: T) -> T: @overload -def bc_linter_include(obj: T) -> T: - ... +def bc_linter_include(obj: T) -> T: ... @overload -def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: - ... +def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ... 
def bc_linter_include(obj: Any = None, *, reason: str | None = None): diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f07fa1e4e7be..b8cbb1ad90a6 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib from typing import TYPE_CHECKING, Optional, Union import torch @@ -13,16 +12,8 @@ logger = init_logger(__name__) -if not current_platform.is_tpu() and not current_platform.is_xpu(): - try: - import vllm._C - except ImportError as e: - logger.warning("Failed to import from vllm._C with %r", e) - -supports_moe_ops = False -with contextlib.suppress(ImportError): - import vllm._moe_C # noqa: F401 - supports_moe_ops = True +current_platform.import_core_kernels() +supports_moe_ops = current_platform.try_import_moe_kernels() if TYPE_CHECKING: @@ -58,11 +49,26 @@ def paged_attention_v1( blocksparse_head_sliding_step: int = 0, ) -> None: torch.ops._C.paged_attention_v1( - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step) + out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) def paged_attention_v2( @@ -90,11 +96,29 @@ def paged_attention_v2( blocksparse_head_sliding_step: int = 0, ) -> None: torch.ops._C.paged_attention_v2( - out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, - blocksparse_block_size, blocksparse_head_sliding_step) + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) def paged_attention_rocm( @@ -119,12 +143,28 @@ def paged_attention_rocm( fp8_out_scale: Optional[torch.Tensor] = None, mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16", ) -> None: - torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, - scale, block_tables, seq_lens, - query_start_loc, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, - v_scale, fp8_out_scale, mfma_type) + torch.ops._rocm_C.paged_attention( + out, + exp_sum, + max_logits, + tmp_out, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + query_start_loc, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + fp8_out_scale, + mfma_type, + ) def mla_decode_kvcache_cpu( @@ -135,19 +175,23 @@ def mla_decode_kvcache_cpu( block_tables: torch.Tensor, seq_lens: torch.Tensor, ) -> None: - torch.ops._C_cpu.mla_decode_kvcache(out, query, kv_cache, scale, - block_tables, seq_lens) + torch.ops._C_cpu.mla_decode_kvcache( + out, query, kv_cache, scale, block_tables, seq_lens + ) # merge attn 
states ops -def merge_attn_states(output: torch.Tensor, - prefix_output: torch.Tensor, - prefix_lse: torch.Tensor, - suffix_output: torch.Tensor, - suffix_lse: torch.Tensor, - output_lse: Optional[torch.Tensor] = None) -> None: - torch.ops._C.merge_attn_states(output, output_lse, prefix_output, - prefix_lse, suffix_output, suffix_lse) +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + torch.ops._C.merge_attn_states( + output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse + ) def convert_vertical_slash_indexes( @@ -166,33 +210,43 @@ def convert_vertical_slash_indexes( nnz_vertical = vertical_indexes.size(2) num_rows = (context_size + block_size_M - 1) // block_size_M - block_count = torch.zeros(batch_size, - num_heads, - num_rows, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - block_offset = torch.zeros(batch_size, - num_heads, - num_rows, - nnz_slash, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - column_count = torch.zeros(batch_size, - num_heads, - num_rows, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - column_index = torch.zeros(batch_size, - num_heads, - num_rows, - nnz_vertical, - dtype=q_seqlens.dtype, - device=q_seqlens.device) + block_count = torch.zeros( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + block_offset = torch.zeros( + batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) + column_count = torch.zeros( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + column_index = torch.zeros( + batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) torch.ops._C.convert_vertical_slash_indexes( - block_count, block_offset, column_count, column_index, q_seqlens, - kv_seqlens, vertical_indexes, slash_indexes, context_size, - block_size_M, block_size_N, causal) + block_count, + block_offset, + column_count, + column_index, + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + context_size, + block_size_M, + block_size_N, + causal, + ) return block_count, block_offset, column_count, column_index @@ -215,33 +269,45 @@ def convert_vertical_slash_indexes_mergehead( nnz_vertical = vertical_indexes.size(2) num_rows = (context_size + block_size_M - 1) // block_size_M - block_count = torch.empty(batch_size, - num_heads, - num_rows, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - block_offset = torch.empty(batch_size, - num_heads, - num_rows, - nnz_slash, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - column_count = torch.empty(batch_size, - num_heads, - num_rows, - dtype=q_seqlens.dtype, - device=q_seqlens.device) - column_index = torch.empty(batch_size, - num_heads, - num_rows, - nnz_vertical, - dtype=q_seqlens.dtype, - device=q_seqlens.device) + block_count = torch.empty( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + block_offset = torch.empty( + batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) + column_count = torch.empty( + batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device + ) + column_index = torch.empty( + batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device, + ) 
torch.ops._C.convert_vertical_slash_indexes_mergehead( - block_count, block_offset, column_count, column_index, q_seqlens, - kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count, - slash_indices_count, context_size, block_size_M, block_size_N, causal) + block_count, + block_offset, + column_count, + column_index, + q_seqlens, + kv_seqlens, + vertical_indexes, + slash_indexes, + vertical_indices_count, + slash_indices_count, + context_size, + block_size_M, + block_size_N, + causal, + ) return block_count, block_offset, column_count, column_index @@ -254,53 +320,71 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - torch.ops._C.rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox) + torch.ops._C.rotary_embedding( + positions, query, key, head_size, cos_sin_cache, is_neox + ) # layer norm ops -def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - epsilon: float) -> None: +def rms_norm( + out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float +) -> None: # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input input_contiguous = input.contiguous() torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon) -def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, epsilon: float) -> None: +def fused_add_rms_norm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float +) -> None: torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) -def poly_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - bias: torch.Tensor, epsilon: float) -> None: +def poly_norm( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + epsilon: float, +) -> None: # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input input_contiguous = input.contiguous() torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon) def apply_repetition_penalties_torch( - logits: torch.Tensor, prompt_mask: torch.Tensor, - output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: + logits: torch.Tensor, + prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> None: repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( - 1, logits.size(1)) + 1, logits.size(1) + ) # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. - penalties = torch.where(prompt_mask | output_mask, repetition_penalties, - 1.0) + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0) # If logits are positive, divide by penalty, otherwise multiply by penalty. 
scaling = torch.where(logits > 0, 1.0 / penalties, penalties) logits *= scaling def apply_repetition_penalties_cuda( - logits: torch.Tensor, prompt_mask: torch.Tensor, - output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: - torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask, - repetition_penalties) + logits: torch.Tensor, + prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> None: + torch.ops._C.apply_repetition_penalties_( + logits, prompt_mask, output_mask, repetition_penalties + ) -def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, - output_mask: torch.Tensor, - repetition_penalties: torch.Tensor) -> None: +def apply_repetition_penalties( + logits: torch.Tensor, + prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> None: """Apply repetition penalties to logits in-place. Args: @@ -310,11 +394,13 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, repetition_penalties: The repetition penalties of shape (num_seqs, ). """ if logits.is_cuda and logits.is_contiguous(): - apply_repetition_penalties_cuda(logits, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_cuda( + logits, prompt_mask, output_mask, repetition_penalties + ) else: - apply_repetition_penalties_torch(logits, prompt_mask, output_mask, - repetition_penalties) + apply_repetition_penalties_torch( + logits, prompt_mask, output_mask, repetition_penalties + ) # fused quant layer norm ops @@ -324,128 +410,172 @@ def rms_norm_dynamic_per_token_quant( epsilon: float, quant_dtype: torch.dtype, scale_ub: Optional[torch.Tensor] = None, - residual: Optional[torch.Tensor] = None + residual: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=quant_dtype) - scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) + scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) - torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight, - scales, epsilon, scale_ub, - residual) + torch.ops._C.rms_norm_dynamic_per_token_quant( + output, input, weight, scales, epsilon, scale_ub, residual + ) return output, scales # quantization ops # awq -def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, - zeros: torch.Tensor, split_k_iters: int, thx: int, - thy: int) -> torch.Tensor: +def awq_dequantize( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + split_k_iters: int, + thx: int, + thy: int, +) -> torch.Tensor: if envs.VLLM_USE_TRITON_AWQ: from vllm.model_executor.layers.quantization.awq_triton import ( - awq_dequantize_triton) + awq_dequantize_triton, + ) + return awq_dequantize_triton(qweight, scales, zeros) - return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, - thx, thy) + return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) -def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, - scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: +def awq_gemm( + input: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + split_k_iters: int, +) -> torch.Tensor: if envs.VLLM_USE_TRITON_AWQ: - from vllm.model_executor.layers.quantization.awq_triton import ( - awq_gemm_triton) + from 
vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton + return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # gptq -def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, use_exllama: bool, - bit: int) -> torch.Tensor: - return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit) +def gptq_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, + use_exllama: bool, + bit: int, +) -> torch.Tensor: + return torch.ops._C.gptq_gemm( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit + ) if hasattr(torch.ops._C, "gptq_gemm"): @register_fake("_C::gptq_gemm") - def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, - b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, - use_exllama: bool, bit: int) -> torch.Tensor: - return torch.empty((a.size(0), b_q_weight.size(1)), - dtype=a.dtype, - device=a.device) + def _gptq_gemm_fake( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, + use_exllama: bool, + bit: int, + ) -> torch.Tensor: + return torch.empty( + (a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device + ) -def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, - bit: int) -> None: +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) # marlin_24 -def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, b_q_type: ScalarType, - size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, b_q_type.id, size_m, - size_n, size_k) +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return torch.ops._C.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): @register_fake("_C::gptq_marlin_24_gemm") - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: + def _gptq_marlin_24_gemm_fake( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + ) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake("_C::gptq_marlin_gemm") - def _gptq_marlin_gemm_fake(a: torch.Tensor, - c: Optional[torch.Tensor], - b_q_weight: torch.Tensor, - b_bias: Optional[torch.Tensor], - b_scales: torch.Tensor, - global_scale: Optional[torch.Tensor], - b_zeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], - workspace: torch.Tensor, - b_q_type_id: int, - size_m: 
torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool = True, - use_atomic_add: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + def _gptq_marlin_gemm_fake( + a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, + ) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake("_C::awq_dequantize") - def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, - zeros: torch.Tensor, split_k_iters: torch.SymInt, - thx: int, thy: int) -> torch.Tensor: + def _awq_dequantize_fake( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + split_k_iters: torch.SymInt, + thx: int, + thy: int, + ) -> torch.Tensor: in_c = qweight.size(0) qout_c = qweight.size(1) out_c = qout_c * 8 - return torch.empty((in_c, out_c), - dtype=scales.dtype, - device=scales.device) + return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device) @register_fake("_C::awq_gemm") - def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, - qzeros: torch.Tensor, scales: torch.Tensor, - split_k_iters: torch.SymInt) -> torch.Tensor: + def _awq_gemm_fake( + input: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + split_k_iters: torch.SymInt, + ) -> torch.Tensor: num_in_feats = input.size(0) - return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8), - dtype=input.dtype, - device=input.device).sum(0) + return torch.empty( + (split_k_iters, num_in_feats, qweight.size(1) * 8), + dtype=input.dtype, + device=input.device, + ).sum(0) @register_fake("_C::machete_mm") def machete_mm_fake( @@ -467,22 +597,25 @@ def machete_mm_fake( @register_fake("_C::machete_prepack_B") def machete_prepack_B_fake( - b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, - group_scales_type: Optional[torch.dtype]) -> torch.Tensor: - return torch.empty_like(b_q_weight, - memory_format=torch.contiguous_format) + b_q_weight: torch.Tensor, + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + ) -> torch.Tensor: + return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) @register_fake("_C::cutlass_w4a8_mm") def cutlass_w4a8_mm_fake( - a: torch.Tensor, - # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b - b_q: torch.Tensor, - b_group_scales: torch.Tensor, - b_group_size: int, - b_channel_scales: torch.Tensor, - a_token_scales: torch.Tensor, - out_type: Optional[torch.dtype] = None, - maybe_schedule: Optional[str] = None) -> torch.Tensor: + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None, + ) -> torch.Tensor: m = a.size(0) n = b_q.size(1) out_dtype = out_type if out_type is not None else torch.bfloat16 @@ -500,15 +633,19 @@ def 
cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: if hasattr(torch.ops._C, "allspark_w8a16_gemm"): @register_fake("_C::allspark_w8a16_gemm") - def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, - b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], - n: torch.SymInt, group_size: torch.SymInt, - sm_count: torch.SymInt, - sm_version: torch.SymInt, - CUBLAS_M_THRESHOLD: torch.SymInt, - has_zp: bool, - n32k16_reorder: bool) -> torch.Tensor: + def _allspark_w8a16_gemm_fake( + a: torch.Tensor, + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + n: torch.SymInt, + group_size: torch.SymInt, + sm_count: torch.SymInt, + sm_version: torch.SymInt, + CUBLAS_M_THRESHOLD: torch.SymInt, + has_zp: bool, + n32k16_reorder: bool, + ) -> torch.Tensor: m = a.size(0) return torch.empty((m, n), device=a.device, dtype=a.dtype) @@ -517,11 +654,12 @@ def _allspark_w8a16_gemm_fake(a: torch.Tensor, b_qweight: torch.Tensor, @register_fake("_C::ggml_dequantize") def _ggml_dequantize_fake( - W: torch.Tensor, - quant_type: int, - m: torch.SymInt, - n: torch.SymInt, - dtype: Optional[torch.dtype] = None) -> torch.Tensor: + W: torch.Tensor, + quant_type: int, + m: torch.SymInt, + n: torch.SymInt, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: return torch.empty((m, n), dtype=torch.float16, device=W.device) @register_fake("_C::ggml_mul_mat_vec_a8") @@ -556,9 +694,7 @@ def _ggml_moe_a8_fake( tokens: torch.SymInt, ) -> torch.Tensor: tokens = X.size(0) - return torch.empty((tokens * top_k, row), - dtype=torch.float16, - device=W.device) + return torch.empty((tokens * top_k, row), dtype=torch.float16, device=W.device) if hasattr(torch.ops._C, "ggml_moe_a8_vec"): @@ -574,9 +710,7 @@ def _ggml_moe_a8_vec_fake( tokens: torch.SymInt, ) -> torch.Tensor: tokens = X.size(0) - return torch.empty((tokens * top_k, row), - dtype=X.dtype, - device=W.device) + return torch.empty((tokens * top_k, row), dtype=X.dtype, device=W.device) # cutlass @@ -593,20 +727,23 @@ def cutlass_blockwise_scaled_grouped_mm( problem_sizes: torch.Tensor, expert_offsets: torch.Tensor, ): - torch.ops._C.cutlass_blockwise_scaled_grouped_mm(output, a, b, scales_a, - scales_b, problem_sizes, - expert_offsets) + torch.ops._C.cutlass_blockwise_scaled_grouped_mm( + output, a, b, scales_a, scales_b, problem_sizes, expert_offsets + ) -def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, - block_scale_a: torch.Tensor, - block_scale_b: torch.Tensor, alpha: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: +def cutlass_scaled_fp4_mm( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype, +) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 m, n = a.shape[0], b.shape[0] out = torch.empty((m, n), dtype=out_dtype, device=a.device) - torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, - alpha) + torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, alpha) return out @@ -615,16 +752,17 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_scaled_mm_supports_block_fp8( - cuda_device_capability) + return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: 
torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ `cutlass_scaled_mm` implements a fused version of `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` @@ -647,70 +785,65 @@ def cutlass_scaled_mm(a: torch.Tensor, scale_a.shape * [1, 128] == a.shape scale_b.shape * [128, 128] == b.shape """ - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.numel( - ) == b.shape[1] and bias.dtype == out_dtype + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype # Massage the input to be 2D target_shape = (*a.shape[:-1], b.shape[1]) a = a.view(-1, a.shape[-1]) - cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + cutlass_compatible_b = b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 if current_platform.is_rocm() or not cutlass_compatible_b: from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa - triton_scaled_mm) + triton_scaled_mm, + ) + out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) else: - out = torch.empty((a.shape[0], b.shape[1]), - dtype=out_dtype, - device=a.device) + out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) return out.view(*target_shape) -def cutlass_scaled_mm_azp(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ :param azp_adj: In the per-tensor case, this should include the azp. Always per-channel. :param azp: Only set in the per-token case. Per-token if set. 
""" - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.numel( - ) == b.shape[1] and bias.dtype == out_dtype + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype # Massage the input to be 2D target_shape = (*a.shape[:-1], b.shape[1]) a = a.view(-1, a.shape[-1]) assert azp is None or azp.numel() == a.shape[0] - out = torch.empty((a.shape[0], b.shape[1]), - dtype=out_dtype, - device=a.device) - torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, - azp, bias) + out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) + torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) return out.view(*target_shape) def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_sparse_scaled_mm_supported( - cuda_device_capability) + return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability) def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) -def cutlass_sparse_compress(a: torch.Tensor) \ - -> tuple[torch.Tensor, torch.Tensor]: +def cutlass_sparse_compress(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Compresses a sparse matrix for use with Cutlass sparse operations. @@ -741,26 +874,25 @@ def cutlass_sparse_compress(a: torch.Tensor) \ - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor. - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`. """ - assert (a.dtype in [ - torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16 - ]) - assert (a.is_contiguous()) + assert a.dtype in [torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16] + assert a.is_contiguous() # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4 elemsPerMetaElem = 4 - assert (a.shape[1] % (2 * elemsPerMetaElem) == 0) + assert a.shape[1] % (2 * elemsPerMetaElem) == 0 return torch.ops._C.cutlass_sparse_compress(a) def cutlass_scaled_sparse_mm( - a: torch.Tensor, - bt_nzs: torch.Tensor, - bt_meta: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + a: torch.Tensor, + bt_nzs: torch.Tensor, + bt_meta: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ Performs a scaled sparse matrix multiplication using Cutlass. @@ -784,31 +916,33 @@ def cutlass_scaled_sparse_mm( Returns: - The result of the scaled sparse matrix multiplication. 
""" - assert (bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == bt_nzs.shape[0] \ - and bias.dtype == out_dtype + assert bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == bt_nzs.shape[0] and bias.dtype == out_dtype m = a.shape[0] n = bt_nzs.shape[0] out = torch.empty((m, n), dtype=out_dtype, device=a.device) - torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta, scale_a, - scale_b, bias) + torch.ops._C.cutlass_scaled_sparse_mm( + out, a, bt_nzs, bt_meta, scale_a, scale_b, bias + ) return out -def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, - expert_offsets: torch.Tensor, - problem_sizes1: torch.Tensor, - problem_sizes2: torch.Tensor, - input_permutation: torch.Tensor, - output_permutation: torch.Tensor, - num_experts: int, - n: int, - k: int, - blockscale_offsets: Optional[torch.Tensor] = None): +def get_cutlass_moe_mm_data( + topk_ids: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + input_permutation: torch.Tensor, + output_permutation: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None, +): """ Prepare data necessary to perform CUTLASS grouped matrix multiplications used in CUTLASS-based fused MoE. @@ -832,22 +966,29 @@ def get_cutlass_moe_mm_data(topk_ids: torch.Tensor, computed with expert E is blockscale_offsets[E + 1] - blockscale_offsets[E] """ - return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, - problem_sizes1, problem_sizes2, - input_permutation, - output_permutation, - num_experts, n, k, - blockscale_offsets) + return torch.ops._C.get_cutlass_moe_mm_data( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + input_permutation, + output_permutation, + num_experts, + n, + k, + blockscale_offsets, + ) def get_cutlass_moe_mm_problem_sizes( - topk_ids: torch.Tensor, - problem_sizes1: torch.Tensor, - problem_sizes2: torch.Tensor, - num_experts: int, - n: int, - k: int, - blockscale_offsets: Optional[torch.Tensor] = None): + topk_ids: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + num_experts: int, + n: int, + k: int, + blockscale_offsets: Optional[torch.Tensor] = None, +): """ Compute only the per-expert problem sizes needed by the two grouped matrix multiplications used in CUTLASS-based fused MoE. @@ -858,8 +999,8 @@ def get_cutlass_moe_mm_problem_sizes( used in the fused MoE operation. """ return torch.ops._C.get_cutlass_moe_mm_problem_sizes( - topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, - blockscale_offsets) + topk_ids, problem_sizes1, problem_sizes2, num_experts, n, k, blockscale_offsets + ) def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): @@ -868,25 +1009,31 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): This is used in MoE to permute the input tensor before performing grouped matrix multiplications. 
""" num_tokens_permuted = dst2src_map.shape[0] - output_tensor = torch.empty((num_tokens_permuted, input_tensor.shape[1]), - device=input_tensor.device, - dtype=input_tensor.dtype) + output_tensor = torch.empty( + (num_tokens_permuted, input_tensor.shape[1]), + device=input_tensor.device, + dtype=input_tensor.dtype, + ) torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor) return output_tensor -def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor, - problem_sizes1: torch.Tensor, - problem_sizes2: torch.Tensor, - expert_num_tokens: torch.Tensor, - num_local_experts: int, padded_m: int, n: int, - k: int): +def get_cutlass_pplx_moe_mm_data( + expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + expert_num_tokens: torch.Tensor, + num_local_experts: int, + padded_m: int, + n: int, + k: int, +): """ Prepare data necessary to perform CUTLASS grouped matrix multiplications used in CUTLASS-based fused MoE. The function takes in expert_num_tokens (token count per expert) and - non_zero_expert_idxs (consecutive indices of experts with non-zero token + non_zero_expert_idxs (consecutive indices of experts with non-zero token counts) and uses them to compute: - expert_offsets: Indices that mark at which token index each expert begins its computation. @@ -895,16 +1042,31 @@ def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor, the fused MoE operation. """ return torch.ops._C.get_cutlass_pplx_moe_mm_data( - expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, - num_local_experts, padded_m, n, k) + expert_offsets, + problem_sizes1, + problem_sizes2, + expert_num_tokens, + num_local_experts, + padded_m, + n, + k, + ) -def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - b_tensors: torch.Tensor, a_scales: torch.Tensor, - b_scales: torch.Tensor, expert_offsets: torch.Tensor, - problem_sizes: torch.Tensor, a_strides: torch.Tensor, - b_strides: torch.Tensor, c_strides: torch.Tensor, - per_act_token: bool, per_out_ch: bool): +def cutlass_moe_mm( + out_tensors: torch.Tensor, + a_tensors: torch.Tensor, + b_tensors: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + expert_offsets: torch.Tensor, + problem_sizes: torch.Tensor, + a_strides: torch.Tensor, + b_strides: torch.Tensor, + c_strides: torch.Tensor, + per_act_token: bool, + per_out_ch: bool, +): """ A single grouped matrix multiplication used in CUTLASS-based fused MoE. The function executes fp8-quantized OUT = AB matrix multiplication. @@ -916,17 +1078,33 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, MMs used in the fused MoE operation. - a/b/c_strides: The data strides passed to grouped matrix multiplication. 
""" - return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, - a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, - c_strides, per_act_token, per_out_ch) + return torch.ops._C.cutlass_moe_mm( + out_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + expert_offsets, + problem_sizes, + a_strides, + b_strides, + c_strides, + per_act_token, + per_out_ch, + ) -def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - b_tensors: torch.Tensor, a_scales: torch.Tensor, - b_scales: torch.Tensor, alphas: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, sf_offsets: torch.Tensor): +def cutlass_fp4_moe_mm( + out_tensors: torch.Tensor, + a_tensors: torch.Tensor, + b_tensors: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + alphas: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, + sf_offsets: torch.Tensor, +): """ An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs the gemms for each combination based on the specified problem sizes. @@ -943,132 +1121,202 @@ def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. """ - return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors, - a_scales, b_scales, alphas, - problem_sizes, expert_offsets, - sf_offsets) + return torch.ops._C.cutlass_fp4_group_mm( + out_tensors, + a_tensors, + b_tensors, + a_scales, + b_scales, + alphas, + problem_sizes, + expert_offsets, + sf_offsets, + ) # gptq_marlin -def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, - num_bits) +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) # gptq_marlin -def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) -def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: +def gptq_marlin_moe_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), - device=b_q_weight.device, - dtype=b_q_weight.dtype) + output = torch.empty( + (num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype, + ) for e in range(num_experts): - output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], - size_k, size_n, num_bits) + output[e] = torch.ops._C.gptq_marlin_repack( + b_q_weight[e], perm[e], size_k, size_n, num_bits + ) return output -def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: +def awq_marlin_moe_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: 
num_experts = b_q_weight.shape[0] assert size_k % 16 == 0 - output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), - device=b_q_weight.device, - dtype=b_q_weight.dtype) + output = torch.empty( + (num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype, + ) for e in range(num_experts): - output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, - size_n, num_bits) + output[e] = torch.ops._C.awq_marlin_repack( + b_q_weight[e], size_k, size_n, num_bits + ) return output -def gptq_marlin_gemm(a: torch.Tensor, - c: Optional[torch.Tensor], - b_q_weight: torch.Tensor, - b_bias: Optional[torch.Tensor], - b_scales: torch.Tensor, - global_scale: Optional[torch.Tensor], - b_zeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool = True, - use_atomic_add: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.ops._C.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales, - global_scale, b_zeros, g_idx, perm, - workspace, b_q_type.id, size_m, - size_n, size_k, is_k_full, - use_atomic_add, use_fp32_reduce, - is_zp_float) +def gptq_marlin_gemm( + a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_bias: Optional[torch.Tensor], + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return torch.ops._C.gptq_marlin_gemm( + a, + c, + b_q_weight, + b_bias, + b_scales, + global_scale, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) # machete def machete_supported_schedules( - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: Optional[torch.dtype], - group_zeros_type: Optional[torch.dtype] = None, - channel_scales_type: Optional[torch.dtype] = None, - token_scales_type: Optional[torch.dtype] = None, - out_type: Optional[torch.dtype] = None) -> list[str]: + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None, +) -> list[str]: return torch.ops._C.machete_supported_schedules( - a_type, b_type.id, group_scales_type, group_zeros_type, - channel_scales_type, token_scales_type, out_type) + a_type, + b_type.id, + group_scales_type, + group_zeros_type, + channel_scales_type, + token_scales_type, + out_type, + ) def machete_mm( - a: torch.Tensor, - # b_q Should be the tensor returned by machete_prepack_B - b_q: torch.Tensor, - b_type: ScalarType, - out_type: Optional[torch.dtype] = None, - b_group_scales: Optional[torch.Tensor] = None, - b_group_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - b_channel_scales: Optional[torch.Tensor] = None, - a_token_scales: Optional[torch.Tensor] = None, - schedule: Optional[str] = None) -> torch.Tensor: - return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, 
b_group_scales, - b_group_zeros, b_group_size, - b_channel_scales, a_token_scales, schedule) + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None, +) -> torch.Tensor: + return torch.ops._C.machete_mm( + a, + b_q, + b_type.id, + out_type, + b_group_scales, + b_group_zeros, + b_group_size, + b_channel_scales, + a_token_scales, + schedule, + ) def machete_prepack_B( - b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, - group_scales_type: Optional[torch.dtype]) -> torch.Tensor: - return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id, - group_scales_type) + b_q_weight: torch.Tensor, + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], +) -> torch.Tensor: + return torch.ops._C.machete_prepack_B( + b_q_weight, a_type, b_type.id, group_scales_type + ) # CUTLASS W4A8 def cutlass_w4a8_mm( - a: torch.Tensor, - # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b - b_q: torch.Tensor, - b_group_scales: torch.Tensor, - b_group_size: int, - b_channel_scales: torch.Tensor, - a_token_scales: torch.Tensor, - out_type: Optional[torch.dtype] = None, - maybe_schedule: Optional[str] = None) -> torch.Tensor: - return torch.ops._C.cutlass_w4a8_mm(a, b_q, b_group_scales, b_group_size, - b_channel_scales, a_token_scales, - out_type, maybe_schedule) + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None, +) -> torch.Tensor: + return torch.ops._C.cutlass_w4a8_mm( + a, + b_q, + b_group_scales, + b_group_size, + b_channel_scales, + a_token_scales, + out_type, + maybe_schedule, + ) def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor: @@ -1082,8 +1330,7 @@ def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") - def _permute_cols_fake(a: torch.Tensor, - perm: torch.Tensor) -> torch.Tensor: + def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) @@ -1093,8 +1340,8 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: # fp4 def scaled_fp4_quant( - input: torch.Tensor, - input_global_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + input: torch.Tensor, input_global_scale: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP4 and return quantized tensor and scale. @@ -1114,18 +1361,17 @@ def scaled_fp4_quant( in the sizzled layout. """ assert not current_platform.is_rocm() - assert input.ndim >= 1, ( - f'input.ndim needs to be >= 1, but got {input.ndim}.') + assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." 
other_dims = 1 if input.ndim == 1 else -1 input = input.reshape(other_dims, input.shape[-1]) m, n = input.shape block_size = 16 device = input.device - assert n % block_size == 0, ( - f'last dim has to be multiple of 16, but got {n}.') + assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}." assert input.dtype in (torch.float16, torch.bfloat16), ( - f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.') + f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." + ) # Two fp4 values will be packed into an uint8. output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) @@ -1139,12 +1385,11 @@ def scaled_fp4_quant( rounded_m = round_up(m, 128) scale_n = n // block_size rounded_n = round_up(scale_n, 4) - output_scale = torch.empty((rounded_m, rounded_n // 4), - device=device, - dtype=torch.int32) + output_scale = torch.empty( + (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 + ) - torch.ops._C.scaled_fp4_quant(output, input, output_scale, - input_global_scale) + torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale) output_scale = output_scale.view(torch.float8_e4m3fn) return output, output_scale @@ -1170,7 +1415,8 @@ def scaled_fp4_experts_quant( """ assert not current_platform.is_rocm() assert input_tensor.ndim == 2, ( - f'input.ndim needs to be == 2, but got {input_tensor.ndim}.') + f"input.ndim needs to be == 2, but got {input_tensor.ndim}." + ) # Control the maximum number of tokens per expert supported by the # NVFP4 MoE Expert Quantization. This is used to prevent the kernel @@ -1179,26 +1425,33 @@ def scaled_fp4_experts_quant( MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE m_numtopk, k = input_tensor.shape - assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), ( + assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, ( f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" f"{MAX_TOKENS_PER_EXPERT})" f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use" - f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.") + f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value." + ) scales_k = k // 16 padded_k = (scales_k + (4 - 1)) // 4 # output is uint8 and packed fp4 values - output = torch.empty(m_numtopk, - k // 2, - device=input_tensor.device, - dtype=torch.uint8) - output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk, - padded_k, - dtype=torch.int32, - device=input_tensor.device) - torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor, - input_global_scale, expert_offsets, - blockscale_offsets) + output = torch.empty( + m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8 + ) + output_scales = torch.empty( + MAX_TOKENS_PER_EXPERT * topk, + padded_k, + dtype=torch.int32, + device=input_tensor.device, + ) + torch.ops._C.scaled_fp4_experts_quant( + output, + output_scales, + input_tensor, + input_global_scale, + expert_offsets, + blockscale_offsets, + ) output_scales = output_scales.view(torch.float8_e4m3fn) return output, output_scales @@ -1236,7 +1489,7 @@ def scaled_fp8_quant( scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) + assert input.ndim == 2 shape: Union[tuple[int, int], torch.Size] = input.shape # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz out_dtype: torch.dtype = current_platform.fp8_dtype() @@ -1245,17 +1498,15 @@ def scaled_fp8_quant( if output is None: output = torch.empty(shape, device=input.device, dtype=out_dtype) else: - assert num_token_padding is None, \ - "padding not supported if output passed in" + assert num_token_padding is None, "padding not supported if output passed in" assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) torch.ops._C.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) + output, input, scale, scale_ub + ) else: scale = torch.empty(1, device=input.device, dtype=torch.float32) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) @@ -1268,10 +1519,10 @@ def scaled_fp8_quant( # gptq allspark def allspark_repack_weight( - qweight: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor] = None, - has_zp: bool = False + qweight: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor] = None, + has_zp: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format @@ -1293,38 +1544,61 @@ def allspark_repack_weight( N = qweight.shape[1] N_32align = (N + 32 - 1) // 32 * 32 - qweight_reorder = torch.empty((N_32align, K), - device=qweight.device, - dtype=qweight.dtype) - scale_reorder = torch.empty((1, N_32align), - device=scale.device, - dtype=scale.dtype) + qweight_reorder = torch.empty( + (N_32align, K), device=qweight.device, dtype=qweight.dtype + ) + scale_reorder = torch.empty((1, N_32align), device=scale.device, dtype=scale.dtype) zero_point_reorder = None if has_zp: assert zero_point is not None, ( - "zero_point must be provided for asymmetric quantization.") - zero_point_reorder = torch.empty((1, N_32align), - device=zero_point.device, - dtype=zero_point.dtype) + "zero_point must be provided for asymmetric quantization." 
+ ) + zero_point_reorder = torch.empty( + (1, N_32align), device=zero_point.device, dtype=zero_point.dtype + ) torch.ops._C.rearrange_kn_weight_as_n32k16_order( - qweight, scale, zero_point, has_zp, qweight_reorder, scale_reorder, - zero_point_reorder, K, N, N_32align) + qweight, + scale, + zero_point, + has_zp, + qweight_reorder, + scale_reorder, + zero_point_reorder, + K, + N, + N_32align, + ) return qweight_reorder, scale_reorder, zero_point_reorder -def allspark_w8a16_gemm(a: torch.Tensor, b_qweight: torch.Tensor, - b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], n: int, - group_size: int, sm_count: int, sm_version: int, - CUBLAS_M_THRESHOLD: int, has_zp: bool, - n32k16_reorder: bool) -> torch.Tensor: - - return torch.ops._C.allspark_w8a16_gemm(a, b_qweight, b_scales, b_qzeros, - n, group_size, sm_count, - sm_version, CUBLAS_M_THRESHOLD, - has_zp, n32k16_reorder) +def allspark_w8a16_gemm( + a: torch.Tensor, + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + n: int, + group_size: int, + sm_count: int, + sm_version: int, + CUBLAS_M_THRESHOLD: int, + has_zp: bool, + n32k16_reorder: bool, +) -> torch.Tensor: + return torch.ops._C.allspark_w8a16_gemm( + a, + b_qweight, + b_scales, + b_qzeros, + n, + group_size, + sm_count, + sm_version, + CUBLAS_M_THRESHOLD, + has_zp, + n32k16_reorder, + ) # int8 @@ -1332,7 +1606,7 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True + symmetric: bool = True, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -1351,26 +1625,27 @@ def scaled_int8_quant( output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - assert symmetric == ( - azp - is None), "azp must only be provided for asymmetric quantization." + assert symmetric == (azp is None), ( + "azp must only be provided for asymmetric quantization." + ) torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
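+    # One fp32 scale is computed per token (row); when symmetric=False an
+    # int32 azp (asymmetric zero point) is produced per token as well.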
- input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - torch.ops._C.dynamic_scaled_int8_quant(output, input.contiguous(), - input_scales, input_azp) + input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant( + output, input.contiguous(), input_scales, input_azp + ) return output, input_scales, input_azp # gguf -def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, - dtype: Optional[torch.dtype]) -> torch.Tensor: +def ggml_dequantize( + W: torch.Tensor, quant_type: int, m: int, n: int, dtype: Optional[torch.dtype] +) -> torch.Tensor: return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype) @@ -1403,9 +1678,17 @@ def ggml_moe_a8( top_k: int, tokens: int, ) -> torch.Tensor: - return torch.ops._C.ggml_moe_a8(X, W, sorted_token_ids, expert_ids, - num_tokens_post_padded, quant_type, row, - top_k, tokens) + return torch.ops._C.ggml_moe_a8( + X, + W, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + quant_type, + row, + top_k, + tokens, + ) def ggml_moe_a8_vec( @@ -1417,8 +1700,7 @@ def ggml_moe_a8_vec( row: torch.SymInt, tokens: torch.SymInt, ) -> torch.Tensor: - return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, - tokens) + return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, tokens) def ggml_moe_get_block_size(quant_type: int) -> int: @@ -1426,44 +1708,61 @@ def ggml_moe_get_block_size(quant_type: int) -> int: # mamba -def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, - B: torch.Tensor, C: torch.Tensor, - D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], - delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, - query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], - ssm_states: torch.Tensor, pad_slot_id: int): - torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, - delta_softplus, query_start_loc, - cache_indices, has_initial_state, - ssm_states, pad_slot_id) +def selective_scan_fwd( + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D_: Optional[torch.Tensor], + z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, + pad_slot_id: int, +): + torch.ops._C.selective_scan_fwd( + u, + delta, + A, + B, + C, + D_, + z_, + delta_bias_, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + ssm_states, + pad_slot_id, + ) # ROCm skinny gemms -def LLMM1(a: torch.Tensor, b: torch.Tensor, - rows_per_block: int) -> torch.Tensor: +def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor: return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) -def wvSplitK(a: torch.Tensor, - b: torch.Tensor, - cu_count: int, - bias: torch.Tensor = None) -> torch.Tensor: +def wvSplitK( + a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None +) -> torch.Tensor: return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) -def wvSplitKQ(a: torch.Tensor, - b: torch.Tensor, - out_dtype: 
torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - cu_count: int, - bias: torch.Tensor = None) -> torch.Tensor: - out = torch.empty((b.shape[0], a.shape[0]), - dtype=out_dtype, - device=b.device) +def wvSplitKQ( + a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + cu_count: int, + bias: torch.Tensor = None, +) -> torch.Tensor: + out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device) torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count) return out @@ -1473,118 +1772,212 @@ def moe_sum(input: torch.Tensor, output: torch.Tensor): torch.ops._moe_C.moe_sum(input, output) -def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, - block_size: int, sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor) -> None: - torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) +def moe_align_block_size( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + torch.ops._moe_C.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + ) -def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, - b_qweight: torch.Tensor, b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], - topk_weights: Optional[torch.Tensor], - sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor, top_k: int, - BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, - bit: int) -> torch.Tensor: +def moe_wna16_gemm( + input: torch.Tensor, + output: torch.Tensor, + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, + top_k: int, + BLOCK_SIZE_M: int, + BLOCK_SIZE_N: int, + BLOCK_SIZE_K: int, + bit: int, +) -> torch.Tensor: if not current_platform.is_cuda(): raise NotImplementedError( - "The optimized moe_wna16_gemm kernel is only " - "available on CUDA platforms") - torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales, - b_qzeros, topk_weights, sorted_token_ids, - experts_ids, num_tokens_post_pad, top_k, - BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, - bit) + "The optimized moe_wna16_gemm kernel is only available on CUDA platforms" + ) + torch.ops._moe_C.moe_wna16_gemm( + input, + output, + b_qweight, + b_scales, + b_qzeros, + topk_weights, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + top_k, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + bit, + ) -def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor) -> None: - torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, token_expert_indices, - gating_output) +def topk_softmax( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, +) -> None: + torch.ops._moe_C.topk_softmax( + topk_weights, topk_ids, token_expert_indices, gating_output + ) -def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor, - num_expert_group: int, topk_group: int, topk: int, - renormalize: bool, routed_scaling_factor: float): +def grouped_topk( + scores: torch.Tensor, + 
scores_with_bias: torch.Tensor, + num_expert_group: int, + topk_group: int, + topk: int, + renormalize: bool, + routed_scaling_factor: float, +): if not current_platform.is_cuda(): - raise NotImplementedError("The fused grouped_topk kernel is only " - "available on CUDA platforms") - return torch.ops._moe_C.grouped_topk(scores, scores_with_bias, - num_expert_group, topk_group, topk, - renormalize, routed_scaling_factor) - - -def moe_wna16_marlin_gemm(input: torch.Tensor, output: Optional[torch.Tensor], - b_qweight: torch.Tensor, - b_bias: Optional[torch.Tensor], - b_scales: torch.Tensor, - global_scale: Optional[torch.Tensor], - b_qzeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], - workspace: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_past_padded: torch.Tensor, - topk_weights: torch.Tensor, moe_block_size: int, - top_k: int, mul_topk_weights: bool, is_ep: bool, - b_q_type: ScalarType, size_m: int, size_n: int, - size_k: int, is_k_full: bool, use_atomic_add: bool, - use_fp32_reduce: bool, - is_zp_float: bool) -> torch.Tensor: + raise NotImplementedError( + "The fused grouped_topk kernel is only available on CUDA platforms" + ) + return torch.ops._moe_C.grouped_topk( + scores, + scores_with_bias, + num_expert_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) + + +def moe_wna16_marlin_gemm( + input: torch.Tensor, + output: Optional[torch.Tensor], + b_qweight: torch.Tensor, + b_bias: Optional[torch.Tensor], + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, + top_k: int, + mul_topk_weights: bool, + is_ep: bool, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool, +) -> torch.Tensor: return torch.ops._moe_C.moe_wna16_marlin_gemm( - input, output, b_qweight, b_bias, b_scales, global_scale, b_qzeros, - g_idx, perm, workspace, sorted_token_ids, expert_ids, - num_tokens_past_padded, topk_weights, moe_block_size, top_k, - mul_topk_weights, is_ep, b_q_type.id, size_m, size_n, size_k, - is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float) + input, + output, + b_qweight, + b_bias, + b_scales, + global_scale, + b_qzeros, + g_idx, + perm, + workspace, + sorted_token_ids, + expert_ids, + num_tokens_past_padded, + topk_weights, + moe_block_size, + top_k, + mul_topk_weights, + is_ep, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): @register_fake("_moe_C::marlin_gemm_moe") - def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, - sorted_ids: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, b_scales: torch.Tensor, - b_zero_points: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, size_k: torch.SymInt, - is_k_full: bool, num_experts: int, topk: int, - moe_block_size: int, replicate_input: bool, - apply_weights: bool) -> torch.Tensor: - return torch.empty((size_m, topk, size_n), - dtype=a.dtype, - device=a.device) + def 
marlin_gemm_moe_fake( + a: torch.Tensor, + b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + b_scales: torch.Tensor, + b_zero_points: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + num_experts: int, + topk: int, + moe_block_size: int, + replicate_input: bool, + apply_weights: bool, + ) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device) @register_fake("_moe_C::moe_wna16_marlin_gemm") - def moe_wna16_marlin_gemm_fake(input: torch.Tensor, - output: Optional[torch.Tensor], - b_qweight: torch.Tensor, - b_scales: torch.Tensor, - b_qzeros: Optional[torch.Tensor], - g_idx: Optional[torch.Tensor], - perm: Optional[torch.Tensor], - workspace: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_past_padded: torch.Tensor, - topk_weights: torch.Tensor, - moe_block_size: int, top_k: int, - mul_topk_weights: bool, is_ep: bool, - b_q_type: ScalarType, size_m: int, - size_n: int, size_k: int, is_k_full: bool, - use_atomic_add: bool, use_fp32_reduce: bool, - is_zp_float: bool) -> torch.Tensor: - return torch.empty((size_m * top_k, size_n), - dtype=input.dtype, - device=input.device) + def moe_wna16_marlin_gemm_fake( + input: torch.Tensor, + output: Optional[torch.Tensor], + b_qweight: torch.Tensor, + b_scales: torch.Tensor, + b_qzeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_past_padded: torch.Tensor, + topk_weights: torch.Tensor, + moe_block_size: int, + top_k: int, + mul_topk_weights: bool, + is_ep: bool, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool, + ) -> torch.Tensor: + return torch.empty( + (size_m * top_k, size_n), dtype=input.dtype, device=input.device + ) def reshape_and_cache( @@ -1597,9 +1990,16 @@ def reshape_and_cache( k_scale: torch.Tensor, v_scale: torch.Tensor, ) -> None: - torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype, k_scale, v_scale) + torch.ops._C_cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) def reshape_and_cache_flash( @@ -1612,10 +2012,16 @@ def reshape_and_cache_flash( k_scale: torch.Tensor, v_scale: torch.Tensor, ) -> None: - torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype, k_scale, - v_scale) + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) def concat_and_cache_mla( @@ -1626,65 +2032,80 @@ def concat_and_cache_mla( kv_cache_dtype: str, scale: torch.Tensor, ) -> None: - torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, - slot_mapping, kv_cache_dtype, - scale) + torch.ops._C_cache_ops.concat_and_cache_mla( + kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale + ) -def copy_blocks(key_caches: list[torch.Tensor], - value_caches: list[torch.Tensor], - block_mapping: torch.Tensor) -> None: +def copy_blocks( + key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: 
torch.Tensor, +) -> None: torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) -def copy_blocks_mla(kv_caches: list[torch.Tensor], - block_mapping: torch.Tensor) -> None: +def copy_blocks_mla(kv_caches: list[torch.Tensor], block_mapping: torch.Tensor) -> None: torch.ops._C_cache_ops.copy_blocks_mla(kv_caches, block_mapping) -def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: torch.Tensor) -> None: +def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor +) -> None: torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) -def convert_fp8(output: torch.Tensor, - input: torch.Tensor, - scale: float = 1.0, - kv_dtype: str = "fp8") -> None: +def convert_fp8( + output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" +) -> None: torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) def gather_and_maybe_dequant_cache( - src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - batch_size: int, - kv_cache_dtype: str, - scale: torch.Tensor, - seq_starts: Optional[torch.Tensor] = None) -> None: + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + kv_cache_dtype: str, + scale: torch.Tensor, + seq_starts: Optional[torch.Tensor] = None, +) -> None: torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( - src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype, - scale, seq_starts) + src_cache, + dst, + block_table, + cu_seq_lens, + batch_size, + kv_cache_dtype, + scale, + seq_starts, + ) -def cp_gather_cache(src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - batch_size: int, - seq_starts: Optional[torch.Tensor] = None) -> None: - torch.ops._C_cache_ops.cp_gather_cache(src_cache, dst, block_table, - cu_seq_lens, batch_size, seq_starts) +def cp_gather_cache( + src_cache: torch.Tensor, + dst: torch.Tensor, + block_table: torch.Tensor, + cu_seq_lens: torch.Tensor, + batch_size: int, + seq_starts: Optional[torch.Tensor] = None, +) -> None: + torch.ops._C_cache_ops.cp_gather_cache( + src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts + ) -def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - quant_block_size: int, - kv_cache_dtype: str) -> None: - torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, - quant_block_size, - kv_cache_dtype) +def indexer_k_quant_and_cache( + k: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + quant_block_size: int, + kv_cache_dtype: str, +) -> None: + torch.ops._C_cache_ops.indexer_k_quant_and_cache( + k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype + ) def get_device_attribute(attribute: int, device: int) -> int: @@ -1694,20 +2115,30 @@ def get_device_attribute(attribute: int, device: int) -> int: def get_max_shared_memory_per_block_device_attribute(device: int) -> int: # ruff: noqa: E501 return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( - device) + device + ) # custom ar -def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor, - rank: int, fully_connected: bool) -> int: - return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, - fully_connected) +def init_custom_ar( + ipc_tensors: list[torch.Tensor], + rank_data: torch.Tensor, + rank: int, + fully_connected: bool, +) -> int: + return 
torch.ops._C_custom_ar.init_custom_ar( + ipc_tensors, rank_data, rank, fully_connected + ) -def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, - reg_buffer_sz_bytes: int) -> None: - torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, - reg_buffer_sz_bytes) +def all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + reg_buffer: int, + reg_buffer_sz_bytes: int, +) -> None: + torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) def dispose(fa: int) -> None: @@ -1726,8 +2157,9 @@ def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]: return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers(fa: int, handles: list[list[int]], - offsets: list[list[int]]) -> None: +def register_graph_buffers( + fa: int, handles: list[list[int]], offsets: list[list[int]] +) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) @@ -1744,9 +2176,9 @@ def free_shared_buffer(ptr: int) -> None: # quick all reduce -def init_custom_qr(rank: int, - world_size: int, - qr_max_size: Optional[int] = None) -> int: +def init_custom_qr( + rank: int, world_size: int, qr_max_size: Optional[int] = None +) -> int: return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size) @@ -1754,13 +2186,14 @@ def qr_destroy(fa: int) -> None: torch.ops._C_custom_ar.qr_destroy(fa) -def qr_all_reduce(fa: int, - inp: torch.Tensor, - out: torch.Tensor, - quant_level: int, - cast_bf2half: bool = False) -> None: - torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level, - cast_bf2half) +def qr_all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + quant_level: int, + cast_bf2half: bool = False, +) -> None: + torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half) def qr_get_handle(fa: int) -> torch.Tensor: @@ -1790,9 +2223,9 @@ def get_flash_mla_metadata( tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ - return torch.ops._C.get_flash_mla_metadata(cache_seqlens, - num_heads_per_head_k, - num_heads_k) + return torch.ops._C.get_flash_mla_metadata( + cache_seqlens, num_heads_per_head_k, num_heads_k + ) def flash_mla_with_kvcache( @@ -1823,7 +2256,7 @@ def flash_mla_with_kvcache( softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. 
""" if softmax_scale is None: - softmax_scale = q.shape[-1]**(-0.5) + softmax_scale = q.shape[-1] ** (-0.5) out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( q, k_cache, @@ -1839,35 +2272,53 @@ def flash_mla_with_kvcache( return out, softmax_lse -def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor, - q_nope: torch.Tensor, q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - seq_lens: torch.Tensor, page_table: torch.Tensor, - workspace: torch.Tensor, scale: float, - num_kv_splits: int) -> torch.Tensor: - torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe, - kv_c_and_k_pe_cache, seq_lens, - page_table, workspace, scale, - num_kv_splits) +def sm100_cutlass_mla_decode( + out: torch.Tensor, + lse: torch.Tensor, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + workspace: torch.Tensor, + scale: float, + num_kv_splits: int, +) -> torch.Tensor: + torch.ops._C.sm100_cutlass_mla_decode( + out, + lse, + q_nope, + q_pe, + kv_c_and_k_pe_cache, + seq_lens, + page_table, + workspace, + scale, + num_kv_splits, + ) return out -def sm100_cutlass_mla_get_workspace_size(max_seq_len: int, num_batches: int, - sm_count: int, - num_kv_splits: int) -> int: +def sm100_cutlass_mla_get_workspace_size( + max_seq_len: int, num_batches: int, sm_count: int, num_kv_splits: int +) -> int: return torch.ops._C.sm100_cutlass_mla_get_workspace_size( - max_seq_len, num_batches, sm_count, num_kv_splits) + max_seq_len, num_batches, sm_count, num_kv_splits + ) if hasattr(torch.ops._C, "weight_packed_linear"): @register_fake("_C::weight_packed_linear") - def weight_packed_linear_fake(mat1: torch.Tensor, mat2: torch.Tensor, - bias: Optional[torch.Tensor], - is_vnni: bool) -> torch.Tensor: - return torch.empty((mat1.size(0), mat2.size(0)), - dtype=mat1.dtype, - device=mat2.device) + def weight_packed_linear_fake( + mat1: torch.Tensor, + mat2: torch.Tensor, + bias: Optional[torch.Tensor], + is_vnni: bool, + ) -> torch.Tensor: + return torch.empty( + (mat1.size(0), mat2.size(0)), dtype=mat1.dtype, device=mat2.device + ) if hasattr(torch.ops._C, "fused_experts_cpu"): @@ -1909,7 +2360,6 @@ def int8_scaled_mm_with_quant_fake( class CPUDNNLGEMMHandler: - def __init__(self) -> None: self.handler: Optional[int] = None self.n = -1 @@ -1920,10 +2370,11 @@ def __del__(self): torch.ops._C.release_dnnl_matmul_handler(self.handler) -if hasattr(torch.ops._C, "create_onednn_mm_handler"): - _supports_onednn = True -else: - _supports_onednn = False +_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler")) + + +def is_onednn_acl_supported(): + return torch.ops._C.is_onednn_acl_supported() def create_onednn_mm( @@ -1933,7 +2384,8 @@ def create_onednn_mm( handler = CPUDNNLGEMMHandler() handler.k, handler.n = weight.size() handler.handler = torch.ops._C.create_onednn_mm_handler( - weight, primitive_cache_size) + weight, primitive_cache_size + ) return handler @@ -1943,8 +2395,9 @@ def onednn_mm( bias: Optional[torch.Tensor], ) -> torch.Tensor: output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) - torch.ops._C.onednn_mm(output, x.reshape(-1, dnnl_handler.k), bias, - dnnl_handler.handler) + torch.ops._C.onednn_mm( + output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler + ) return output @@ -1960,15 +2413,17 @@ def create_onednn_scaled_mm( handler = CPUDNNLGEMMHandler() handler.k, handler.n = weight.size() handler.handler = torch.ops._C.create_onednn_scaled_mm_handler( - weight, 
weight_scales, output_type, dynamic_quant, use_azp, - primitive_cache_size) + weight, weight_scales, output_type, dynamic_quant, use_azp, primitive_cache_size + ) return handler -def onednn_scaled_int8_quant(input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True): +def onednn_scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +): """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -1988,20 +2443,16 @@ def onednn_scaled_int8_quant(input: torch.Tensor, input = input.view((token_num, input.shape[-1])) if scale is not None: # static-per-tensor quantization. - assert symmetric == ( - azp - is None), "azp must only be provided for asymmetric quantization." + assert symmetric == (azp is None), ( + "azp must only be provided for asymmetric quantization." + ) torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. - input_scales = torch.empty((token_num, 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) + input_scales = torch.empty((token_num, 1), device=input.device, dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) return output, input_scales, input_azp @@ -2014,8 +2465,9 @@ def onednn_scaled_mm( input_zp_adj: Optional[torch.Tensor], bias: Optional[torch.Tensor], ) -> torch.Tensor: - torch.ops._C.onednn_scaled_mm(output, x, input_scale, input_zp, - input_zp_adj, bias, dnnl_handler.handler) + torch.ops._C.onednn_scaled_mm( + output, x, input_scale, input_zp, input_zp_adj, bias, dnnl_handler.handler + ) return output @@ -2028,7 +2480,7 @@ def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor: Note that sylvester hadamard transforms are also symmetric, which means that this function is also applies the (transpose <=> inverse) transform. 
- + :param x: value to be transformed inplace :param inplace: modify value in place :return: value after transformation @@ -2039,6 +2491,5 @@ def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor: if hasattr(torch.ops._C, "hadacore_transform"): @register_fake("_C::hadacore_transform") - def _hadacore_transform_fake(x: torch.Tensor, - inplace: bool) -> torch.Tensor: + def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor: return torch.empty_like(x) if not inplace else x diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index 9d2eda482fcf..1f458f940a28 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -17,10 +17,10 @@ class ipex_ops: - @staticmethod def _reshape_activation_tensor( - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: num = x.size(0) d = x.size(1) // 2 x = x.reshape(num, 2, d) @@ -144,20 +144,26 @@ def rotary_embedding( is_neox: bool, ) -> None: rot_dim = cos_sin_cache.size(1) - ipex.llm.functional.rotary_embedding_batched(positions, query, key, - head_size, cos_sin_cache, - is_neox, rot_dim) + ipex.llm.functional.rotary_embedding_batched( + positions, query, key, head_size, cos_sin_cache, is_neox, rot_dim + ) @staticmethod - def rms_norm(input: torch.Tensor, weight: torch.Tensor, - epsilon: float) -> torch.Tensor: + def rms_norm( + input: torch.Tensor, weight: torch.Tensor, epsilon: float + ) -> torch.Tensor: return ipex.llm.functional.rms_norm(input, weight, epsilon) @staticmethod - def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, epsilon: float) -> None: - tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None, - epsilon, True) + def fused_add_rms_norm( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + ) -> None: + tmp = ipex.llm.functional.add_rms_norm( + residual, input, weight, None, epsilon, True + ) input.copy_(tmp) @staticmethod @@ -186,22 +192,43 @@ def varlen_attention( raise ValueError("IPEX CPU does not support logits_soft_cap") assert alibi_slopes is None assert window_size_left < 0 and window_size_right < 0 - ipex.llm.functional.varlen_attention(query.contiguous(), - key.contiguous(), - value.contiguous(), out, - seqlen_q.int(), - seqlen_k.int(), max_seqlen_q, - max_seqlen_k, pdropout, - softmax_scale, zero_tensors, - is_causal, return_softmax, - gen_) + ipex.llm.functional.varlen_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + out, + seqlen_q.int(), + seqlen_k.int(), + max_seqlen_q, + max_seqlen_k, + pdropout, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + ) else: # XPU build ipex.llm.functional.varlen_attention( - query.contiguous(), key.contiguous(), value.contiguous(), out, - seqlen_q.int(), seqlen_k.int(), alibi_slopes, max_seqlen_q, - max_seqlen_k, pdropout, softmax_scale, zero_tensors, is_causal, - return_softmax, gen_, window_size_left, window_size_right, - logits_soft_cap) + query.contiguous(), + key.contiguous(), + value.contiguous(), + out, + seqlen_q.int(), + seqlen_k.int(), + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + pdropout, + softmax_scale, + zero_tensors, + is_causal, + return_softmax, + gen_, + window_size_left, + window_size_right, + logits_soft_cap, + ) @staticmethod def reshape_and_cache( @@ -216,7 +243,8 @@ def reshape_and_cache( ) -> None: assert kv_cache_dtype == "auto" ipex.llm.modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, 
slot_mapping) + key, value, key_cache, value_cache, slot_mapping + ) @staticmethod def reshape_and_cache_flash( @@ -232,8 +260,15 @@ def reshape_and_cache_flash( v_scale_float: float = 1.0, ) -> None: ipex.llm.modules.PagedAttention.reshape_and_cache_flash( - key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, - k_scale_float, v_scale_float) + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale_float, + v_scale_float, + ) @staticmethod def flash_attn_varlen_func( @@ -265,10 +300,12 @@ def flash_attn_varlen_func( if cu_seqlens_k is None: # cu_seqlens_k is not used in ipex kernel. cu_seqlens_k = torch.cumsum(seqused_k, dim=0) - cu_seqlens_k = torch.cat([ - torch.tensor([0], device=seqused_k.device, dtype=torch.int32), - cu_seqlens_k - ]).to(torch.int32) + cu_seqlens_k = torch.cat( + [ + torch.tensor([0], device=seqused_k.device, dtype=torch.int32), + cu_seqlens_k, + ] + ).to(torch.int32) real_window_size: tuple[int, int] if window_size is None: @@ -298,36 +335,38 @@ def flash_attn_varlen_func( @staticmethod def get_scheduler_metadata( - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads_q, - num_heads_kv, - headdim, - cache_seqlens: torch.Tensor, - qkv_dtype=torch.bfloat16, - headdim_v=None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - page_size: Optional[int] = None, - max_seqlen_k_new=0, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - has_softcap=False, - num_splits=0, # Can be tuned for speed - pack_gqa=None, # Can be tuned for speed - sm_margin=0, # Can be tuned if some SMs are used for communication + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads_q, + num_heads_kv, + headdim, + cache_seqlens: torch.Tensor, + qkv_dtype=torch.bfloat16, + headdim_v=None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new=0, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + has_softcap=False, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication ) -> None: logger.warning_once( - "get_scheduler_metadata is not implemented for ipex_ops, " - "returning None.") + "get_scheduler_metadata is not implemented for ipex_ops, returning None." + ) return None @staticmethod - def copy_blocks(key_caches: list[torch.Tensor], - value_caches: list[torch.Tensor], - block_mapping: torch.Tensor) -> None: + def copy_blocks( + key_caches: list[torch.Tensor], + value_caches: list[torch.Tensor], + block_mapping: torch.Tensor, + ) -> None: torch.xpu.copy_blocks( # type: ignore key_caches, value_caches, @@ -335,8 +374,9 @@ def copy_blocks(key_caches: list[torch.Tensor], ) @staticmethod - def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: torch.Tensor) -> None: + def swap_blocks( + src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor + ) -> None: torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore @staticmethod @@ -350,7 +390,7 @@ def scaled_fp8_quant( ) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. 
- + This function is designed for both static and dynamic quantization: If you provide the scale, it will use static scaling and if you omit it, the scale will be determined dynamically. Currently, XPU platform @@ -367,13 +407,13 @@ def scaled_fp8_quant( of the output to at least this value. use_per_token_if_dynamic: Whether to do per_tensor or per_token in the dynamic quantization case. - + Returns: tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. """ # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) + assert input.ndim == 2 shape: Union[tuple[int, int], torch.Size] = input.shape out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: @@ -381,12 +421,14 @@ def scaled_fp8_quant( if output is None: output = torch.empty(shape, device=input.device, dtype=out_dtype) else: - assert num_token_padding is None, \ + assert num_token_padding is None, ( "padding not supported if output passed in" + ) assert output.dtype == out_dtype assert scale is None, "only dynamic fp8 quantization supported on XPU" assert not use_per_token_if_dynamic, ( - "per token dynamic fp8 quantization not supported on XPU") + "per token dynamic fp8 quantization not supported on XPU" + ) scale = torch.zeros(1, device=input.device, dtype=torch.float32) torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale) diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py index 1c16230849bc..61c2dbf55fe3 100644 --- a/vllm/assets/audio.py +++ b/vllm/assets/audio.py @@ -32,13 +32,11 @@ def filename(self) -> str: @property def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]: - audio_path = get_vllm_public_assets(filename=self.filename, - s3_prefix=ASSET_DIR) + audio_path = get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR) return librosa.load(audio_path, sr=None) def get_local_path(self) -> Path: - return get_vllm_public_assets(filename=self.filename, - s3_prefix=ASSET_DIR) + return get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR) @property def url(self) -> str: diff --git a/vllm/assets/base.py b/vllm/assets/base.py index 31cde431b5b6..409bfc18ff8c 100644 --- a/vllm/assets/base.py +++ b/vllm/assets/base.py @@ -20,8 +20,7 @@ def get_cache_dir() -> Path: @lru_cache -def get_vllm_public_assets(filename: str, - s3_prefix: Optional[str] = None) -> Path: +def get_vllm_public_assets(filename: str, s3_prefix: Optional[str] = None) -> Path: """ Download an asset file from ``s3://vllm-public-assets`` and return the path to the downloaded file. 
@@ -36,6 +35,7 @@ def get_vllm_public_assets(filename: str, global_http_connection.download_file( f"{VLLM_S3_BUCKET_URL}/{filename}", asset_path, - timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT) + timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT, + ) return asset_path diff --git a/vllm/assets/image.py b/vllm/assets/image.py index 4639a11187d0..c1a0f2b9cc29 100644 --- a/vllm/assets/image.py +++ b/vllm/assets/image.py @@ -12,12 +12,21 @@ VLM_IMAGES_DIR = "vision_model_images" -ImageAssetName = Literal["stop_sign", "cherry_blossom", "hato", - "2560px-Gfp-wisconsin-madison-the-nature-boardwalk", - "Grayscale_8bits_palette_sample_image", - "1280px-Venn_diagram_rgb", "RGBA_comp", "237-400x300", - "231-200x300", "27-500x500", "17-150x600", - "handelsblatt-preview", "paper-11"] +ImageAssetName = Literal[ + "stop_sign", + "cherry_blossom", + "hato", + "2560px-Gfp-wisconsin-madison-the-nature-boardwalk", + "Grayscale_8bits_palette_sample_image", + "1280px-Venn_diagram_rgb", + "RGBA_comp", + "237-400x300", + "231-200x300", + "27-500x500", + "17-150x600", + "handelsblatt-preview", + "paper-11", +] @dataclass(frozen=True) @@ -28,12 +37,12 @@ def get_path(self, ext: str) -> Path: """ Return s3 path for given image. """ - return get_vllm_public_assets(filename=f"{self.name}.{ext}", - s3_prefix=VLM_IMAGES_DIR) + return get_vllm_public_assets( + filename=f"{self.name}.{ext}", s3_prefix=VLM_IMAGES_DIR + ) @property def pil_image(self, ext="jpg") -> Image.Image: - image_path = self.get_path(ext) return Image.open(image_path) @@ -42,7 +51,7 @@ def image_embeds(self) -> torch.Tensor: """ Image embeddings, only used for testing purposes with llava 1.5. """ - image_path = self.get_path('pt') + image_path = self.get_path("pt") return torch.load(image_path, map_location="cpu", weights_only=True) def read_bytes(self, ext: str) -> bytes: diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 5c9e403c4b91..6b2ca8f867e0 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -65,13 +65,14 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: frames = np.stack(frames) if len(frames) < num_frames: - raise ValueError(f"Could not read enough frames from video file {path}" - f" (expected {num_frames} frames, got {len(frames)})") + raise ValueError( + f"Could not read enough frames from video file {path}" + f" (expected {num_frames} frames, got {len(frames)})" + ) return frames -def video_to_pil_images_list(path: str, - num_frames: int = -1) -> list[Image.Image]: +def video_to_pil_images_list(path: str, num_frames: int = -1) -> list[Image.Image]: frames = video_to_ndarrays(path, num_frames) return [Image.fromarray(frame) for frame in frames] @@ -139,7 +140,7 @@ def metadata(self) -> dict[str, Any]: def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: """ Read audio data from the video asset, used in Qwen2.5-Omni examples. 
- + See also: examples/offline_inference/qwen2_5_omni/only_thinker.py """ return librosa.load(self.video_path, sr=sampling_rate)[0] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 1b37bd1f6100..dd35165d5415 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index b49e1c007c57..bb2f36271103 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar +from typing import Generic, Optional, Protocol, TypeVar import torch @@ -14,6 +14,7 @@ class AttentionType: Attention type. Use string to be compatible with `torch.compile`. """ + DECODER = "decoder" """Decoder attention between previous layer Q/K/V.""" ENCODER = "encoder" @@ -26,6 +27,7 @@ class AttentionType: class AttentionBackend(ABC): """Abstract class for attention backends.""" + # For some attention backends, we allocate an output tensor before # calling the custom op. When piecewise cudagraph is enabled, this # makes sure the output tensor is allocated inside the cudagraph. @@ -46,12 +48,12 @@ def get_name() -> str: @staticmethod @abstractmethod - def get_impl_cls() -> Type["AttentionImpl"]: + def get_impl_cls() -> type["AttentionImpl"]: raise NotImplementedError @staticmethod @abstractmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: + def get_metadata_cls() -> type["AttentionMetadata"]: raise NotImplementedError @classmethod @@ -71,11 +73,11 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, cache_dtype_str: str = "auto", - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: raise NotImplementedError @staticmethod - def get_kv_cache_stride_order() -> Tuple[int, ...]: + def get_kv_cache_stride_order() -> tuple[int, ...]: raise NotImplementedError @classmethod @@ -91,7 +93,6 @@ class AttentionMetadata: class AttentionLayer(Protocol): - _q_scale: torch.Tensor _k_scale: torch.Tensor _v_scale: torch.Tensor @@ -107,12 +108,10 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - ... + ) -> torch.Tensor: ... class AttentionImpl(ABC, Generic[T]): - # Whether the attention impl can return the softmax lse for decode. # Some features like decode context parallelism require the softmax lse. 
can_return_lse_for_decode: bool = False @@ -129,14 +128,16 @@ def __new__(cls, *args, **kwargs): self = super().__new__(cls) try: from vllm.distributed.parallel_state import get_dcp_group + self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group except AssertionError: # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 - self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \ - and self.can_return_lse_for_decode + self.need_to_return_lse_for_decode = ( + self.dcp_world_size > 1 and self.can_return_lse_for_decode + ) return self @abstractmethod @@ -146,7 +147,7 @@ def __init__( head_size: int, scale: float, num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, + alibi_slopes: Optional[list[float]] = None, sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", logits_soft_cap: Optional[float] = None, @@ -183,7 +184,6 @@ def fused_output_quant_supported(self, quant_key: QuantKey): class MLAAttentionImpl(AttentionImpl[T], Generic[T]): - @abstractmethod def forward( self, diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 6b8d97be7050..46a87bdd1f7e 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention backend utils""" + from dataclasses import dataclass from typing import Optional diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ac34f279d0b5..6994debd4589 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import Callable, List, Optional + +from typing import Callable, Optional import torch import torch.nn as nn @@ -14,9 +15,11 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group, - is_v1_kv_transfer_group) +from vllm.distributed.kv_transfer import ( + get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -24,8 +27,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform from vllm.utils import GiB_bytes, direct_register_custom_op @@ -33,7 +35,7 @@ logger = init_logger(__name__) USE_XFORMERS_OPS = None try: - tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, ) + tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe,) except AttributeError: tag_cudagraph_unsafe = () # type: ignore[assignment] @@ -43,8 +45,7 @@ def check_xformers_availability(): if USE_XFORMERS_OPS is not None: return 
USE_XFORMERS_OPS - if current_platform.is_cuda() and current_platform.has_device_capability( - 100): + if current_platform.is_cuda() and current_platform.has_device_capability(100): # Xformers FA is not compatible with B200 USE_XFORMERS_OPS = False else: @@ -64,30 +65,36 @@ def check_xformers_availability(): def check_upstream_fa_availability(dtype: torch.dtype): - if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda( - ) and current_platform.has_device_capability(80): + if ( + dtype in (torch.float16, torch.bfloat16) + and current_platform.is_cuda() + and current_platform.has_device_capability(80) + ): from transformers.utils import is_flash_attn_2_available + return is_flash_attn_2_available() if current_platform.is_rocm(): from importlib.util import find_spec + return find_spec("flash_attn") is not None return False def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, - use_upstream_fa: bool) -> tuple[_Backend, Callable]: - if attn_backend != _Backend.FLASH_ATTN and \ - attn_backend != _Backend.ROCM_AITER_FA and \ - check_upstream_fa_availability(torch.get_default_dtype()): + attn_backend: _Backend, use_upstream_fa: bool +) -> tuple[_Backend, Callable]: + if ( + attn_backend != _Backend.FLASH_ATTN + and attn_backend != _Backend.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype()) + ): attn_backend = _Backend.FLASH_ATTN use_upstream_fa = True - if current_platform.is_rocm() and \ - attn_backend == _Backend.FLASH_ATTN: + if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN: use_upstream_fa = True - if (attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}): + if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: @@ -119,7 +126,7 @@ def __init__( head_size: int, scale: float, num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, + alibi_slopes: Optional[list[float]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, logits_soft_cap: Optional[float] = None, @@ -156,9 +163,9 @@ def __init__( calculate_kv_scales = False if num_kv_heads is None: num_kv_heads = num_heads - assert num_heads % num_kv_heads == 0, \ - f"num_heads ({num_heads}) is not " \ - f"divisible by num_kv_heads ({num_kv_heads})" + assert num_heads % num_kv_heads == 0, ( + f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})" + ) # The default k/v_scale is set to 1.0. This is ignored # when kv-cache is not fp8, and should be used with @@ -193,16 +200,19 @@ def __init__( self.sliding_window = sliding_window self.has_sink = extra_impl_args.get("sinks") is not None - quant_method = quant_config.get_quant_method( - self, prefix=prefix) if quant_config else None + quant_method = ( + quant_config.get_quant_method(self, prefix=prefix) if quant_config else None + ) if quant_method is not None and not isinstance( - quant_method, UnquantizedLinearMethod): + quant_method, UnquantizedLinearMethod + ): assert isinstance(quant_method, BaseKVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 # checkpoint config and become the "auto" behavior if self.kv_cache_dtype == "fp8_e5m2": - raise ValueError("fp8_e5m2 kv-cache is not supported with " - "fp8 checkpoints.") + raise ValueError( + "fp8_e5m2 kv-cache is not supported with fp8 checkpoints." 
+ ) # If quantization is enabled, we make "k_scale" and "v_scale" # parameters so that it can be loaded from the model checkpoint. # The k/v_scale will then be converted back to native float32 @@ -214,21 +224,32 @@ def __init__( # weight and activation dtype. dtype = torch.get_default_dtype() if attn_backend is None: - self.attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla=use_mla, - has_sink=self.has_sink, - use_sparse=use_sparse) + self.attn_backend = get_attn_backend( + head_size, + dtype, + kv_cache_dtype, + block_size, + use_mla=use_mla, + has_sink=self.has_sink, + use_sparse=use_sparse, + ) else: self.attn_backend = attn_backend impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **extra_impl_args) + self.impl = impl_cls( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **extra_impl_args, + ) self.backend = backend_name_to_enum(self.attn_backend.get_name()) self.dtype = dtype @@ -258,37 +279,39 @@ def __init__( # by bind_kv_cache # this variable will not be accessed if use_direct_call is True self.kv_cache = [ - torch.tensor([]) for _ in range(get_current_vllm_config( - ).parallel_config.pipeline_parallel_size) + torch.tensor([]) + for _ in range( + get_current_vllm_config().parallel_config.pipeline_parallel_size + ) ] try: - self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, - dtype=torch.float32) - self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, - dtype=torch.float32) - self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, - dtype=torch.float32) + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) except torch.cuda.OutOfMemoryError as e: - logger.error( - "Failed to initialize attention q/k/v range constants: %s", e) + logger.error("Failed to initialize attention q/k/v range constants: %s", e) if torch.cuda.is_available(): logger.debug("CUDA device: %s", torch.cuda.current_device()) - logger.debug("Allocated: %.2f GiB", - torch.cuda.memory_allocated() / GiB_bytes) - logger.debug("Reserved: %.2f GiB", - torch.cuda.memory_reserved() / GiB_bytes) + logger.debug( + "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes + ) + logger.debug( + "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes + ) raise RuntimeError( "Failed to initialize q/k/v range constants. " "This may be caused by insufficient memory to allocate " - "kv cache.") from e + "kv cache." + ) from e # for attn backends supporting query quantization self.query_quant = None - if self.kv_cache_dtype.startswith( - "fp8") and self.attn_backend.supports_quant_query_input: - self.query_quant = QuantFP8(static=True, - group_shape=GroupShape.PER_TENSOR) + if ( + self.kv_cache_dtype.startswith("fp8") + and self.attn_backend.supports_quant_query_input + ): + self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR) def forward( self, @@ -310,8 +333,7 @@ def forward( `vllm.forward_context.get_forward_context().attn_metadata`. 
""" if self.calculate_kv_scales: - torch.ops.vllm.maybe_calc_kv_scales(query, key, value, - self.layer_name) + torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) output_dtype = query.dtype if self.query_quant is not None: @@ -324,11 +346,8 @@ def forward( query, _ = self.query_quant(query, self._q_scale) if self.use_output: - output_shape = (output_shape - if output_shape is not None else query.shape) - output = torch.zeros(output_shape, - dtype=output_dtype, - device=query.device) + output_shape = output_shape if output_shape is not None else query.shape + output = torch.zeros(output_shape, dtype=output_dtype, device=query.device) hidden_size = output_shape[-1] # We skip reshaping query, key and value tensors for the MLA # backend since these tensors have different semantics and are @@ -349,16 +368,13 @@ def forward( if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward(self, - query, - key, - value, - self_kv_cache, - attn_metadata, - output=output) + self.impl.forward( + self, query, key, value, self_kv_cache, attn_metadata, output=output + ) else: torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name) + query, key, value, output, self.layer_name + ) return output.view(-1, hidden_size) else: if self.use_direct_call: @@ -367,11 +383,13 @@ def forward( if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] - return self.impl.forward(self, query, key, value, - self_kv_cache, attn_metadata) + return self.impl.forward( + self, query, key, value, self_kv_cache, attn_metadata + ) else: return torch.ops.vllm.unified_attention( - query, key, value, self.layer_name) + query, key, value, self.layer_name + ) def calc_kv_scales(self, query, key, value): self._q_scale.copy_(torch.abs(query).max() / self.q_range) @@ -396,12 +414,11 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): self.impl.process_weights_after_loading(act_dtype) # FlashInfer requires attention sinks to be float32 - if (self.backend == _Backend.FLASHINFER - and hasattr(self.impl, 'sinks')): + if self.backend == _Backend.FLASHINFER and hasattr(self.impl, "sinks"): from vllm.v1.attention.backends.flashinfer import FlashInferImpl + assert isinstance(self.impl, FlashInferImpl) - if (self.impl.sinks is not None - and self.impl.sinks.dtype != torch.float32): + if self.impl.sinks is not None and self.impl.sinks.dtype != torch.float32: self.impl.sinks = self.impl.sinks.to(torch.float32) def get_attn_backend(self) -> type[AttentionBackend]: @@ -417,16 +434,21 @@ def __init__( head_size: int, scale: float, num_kv_heads: Optional[int] = None, - ): + # This has no effect, it is only here to make it easier to swap + # between Attention and MultiHeadAttention + prefix: str = "", + ) -> None: super().__init__() self.num_heads = num_heads self.head_size = head_size self.scale = scale self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.layer_name = prefix - assert self.num_heads % self.num_kv_heads == 0, \ - f"num_heads ({self.num_heads}) is not " \ + assert self.num_heads % self.num_kv_heads == 0, ( + f"num_heads ({self.num_heads}) is not " f"divisible by num_kv_heads ({self.num_kv_heads})" + ) self.num_queries_per_kv = self.num_heads // self.num_kv_heads # During model initialization, the default dtype is set as the model @@ -445,38 +467,43 @@ def 
__init__( # currently, only torch_sdpa is supported on xpu self.attn_backend = _Backend.TORCH_SDPA else: + self.attn_backend = ( + backend + if backend + in { + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.PALLAS, + _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN, + } + else _Backend.TORCH_SDPA + ) - self.attn_backend = backend if backend in { - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.PALLAS, - _Backend.ROCM_AITER_FA, - _Backend.FLASH_ATTN, - } else _Backend.TORCH_SDPA - - self.attn_backend, self._flash_attn_varlen_func \ - = maybe_get_vit_flash_attn_backend( + self.attn_backend, self._flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( self.attn_backend, use_upstream_fa, ) + ) - if (self.attn_backend == _Backend.XFORMERS - and not check_xformers_availability()): + if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability(): self.attn_backend = _Backend.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } # this condition is just to make sure that the # use_upstream_fa in the log is correct - if current_platform.is_rocm() \ - and self.attn_backend == _Backend.FLASH_ATTN: + if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: use_upstream_fa = True logger.info_once( f"MultiHeadAttention attn_backend: {self.attn_backend}, " - f"use_upstream_fa: {use_upstream_fa}") + f"use_upstream_fa: {use_upstream_fa}" + ) def forward( self, @@ -484,7 +511,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, ) -> torch.Tensor: - """Input shape: + """Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size) """ @@ -501,14 +528,12 @@ def forward( value = torch.repeat_interleave(value, num_repeat, dim=2) if self.is_flash_attn_backend: - cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, - step=q_len, - dtype=torch.int32, - device=query.device) - cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, - step=kv_len, - dtype=torch.int32, - device=key.device) + cu_seqlens_q = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device + ) + cu_seqlens_k = torch.arange( + 0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device + ) out = self._flash_attn_varlen_func( query.flatten(0, 1), @@ -523,29 +548,24 @@ def forward( elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops - out = xops.memory_efficient_attention_forward(query, - key, - value, - scale=self.scale) + out = xops.memory_efficient_attention_forward( + query, key, value, scale=self.scale + ) elif self.attn_backend == _Backend.TORCH_SDPA: - query, key, value = (x.transpose(1, 2) - for x in (query, key, value)) - out = F.scaled_dot_product_attention(query, - key, - value, - scale=self.scale) + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = out.transpose(1, 2) elif self.attn_backend == _Backend.PALLAS: - query, key, value = (x.transpose(1, 2) - for x in (query, key, value)) + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) from torch_xla.experimental.custom_kernel import flash_attention + out = flash_attention(query, key, value, sm_scale=self.scale) out = out.transpose(1, 2) else: # ViT attention hasn't supported this backend yet raise NotImplementedError( - f"ViT attention hasn't supported {self.attn_backend} " - f"backend yet.") + f"ViT attention 
hasn't supported {self.attn_backend} backend yet." + ) return out.reshape(bsz, q_len, -1) @@ -566,7 +586,7 @@ def wait_for_kv_layer_from_connector(layer_name: str): def maybe_save_kv_layer_to_connector( layer_name: str, - kv_cache_layer: List[torch.Tensor], + kv_cache_layer: list[torch.Tensor], ): if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): return @@ -578,8 +598,7 @@ def maybe_save_kv_layer_to_connector( if attn_metadata is None: return assert isinstance(attn_metadata, dict) - connector.save_kv_layer(layer_name, kv_cache_layer, - attn_metadata[layer_name]) + connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name]) def maybe_calc_kv_scales( @@ -588,7 +607,6 @@ def maybe_calc_kv_scales( value: torch.Tensor, layer_name: str, ) -> None: - forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -596,7 +614,8 @@ def maybe_calc_kv_scales( attn_metadata = attn_metadata[layer_name] if attn_metadata is None or not getattr( - attn_metadata, 'enable_kv_scales_calculation', False): + attn_metadata, "enable_kv_scales_calculation", False + ): return self = forward_context.no_compile_layers[layer_name] @@ -634,8 +653,7 @@ def unified_attention( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - output = self.impl.forward(self, query, key, value, kv_cache, - attn_metadata) + output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output @@ -674,15 +692,17 @@ def unified_attention_with_output( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - self.impl.forward(self, - query, - key, - value, - kv_cache, - attn_metadata, - output=output, - output_scale=output_scale, - output_block_scale=output_block_scale) + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + ) maybe_save_kv_layer_to_connector(layer_name, kv_cache) diff --git a/vllm/attention/layers/chunked_local_attention.py b/vllm/attention/layers/chunked_local_attention.py index 2d11b2238e78..3d37e901605f 100644 --- a/vllm/attention/layers/chunked_local_attention.py +++ b/vllm/attention/layers/chunked_local_attention.py @@ -1,19 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import ClassVar, List, Optional +from typing import ClassVar, Optional import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, CommonAttentionMetadata, - make_local_attention_virtual_batches, subclass_attention_backend) + AttentionCGSupport, + CommonAttentionMetadata, + make_local_attention_virtual_batches, + subclass_attention_backend, +) from ..layer import Attention @@ -29,39 +31,42 @@ def create_chunked_local_attention_backend( underlying_builder = underlying_attn_backend.get_builder_cls() class 
ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.NEVER - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> AttentionMetadata: + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: common_attn_metadata = make_local_attention_virtual_batches( - attention_chunk_size, common_attn_metadata, block_size) - return super().build(common_prefix_len, common_attn_metadata, - fast_build) + attention_chunk_size, common_attn_metadata, block_size + ) + return super().build(common_prefix_len, common_attn_metadata, fast_build) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=ChunkedLocalAttentionBuilder) + builder_cls=ChunkedLocalAttentionBuilder, + ) return attn_backend class ChunkedLocalAttention(Attention): - - def __init__(self, - num_heads: int, - head_size: int, - scale: float, - attention_chunk_size: int, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - kv_sharing_target_layer_name: Optional[str] = None, - prefix: str = ""): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + attention_chunk_size: int, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[list[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + kv_sharing_target_layer_name: Optional[str] = None, + prefix: str = "", + ): dtype = torch.get_default_dtype() if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype @@ -71,12 +76,13 @@ def __init__(self, block_size = 16 if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend(head_size, dtype, - kv_cache_dtype, - block_size) + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) attn_backend = create_chunked_local_attention_backend( - underlying_attn_backend, attention_chunk_size, block_size) + underlying_attn_backend, attention_chunk_size, block_size + ) else: # in v0 the local attention is handled inside the backends attn_backend = None @@ -91,4 +97,5 @@ def __init__(self, quant_config=quant_config, prefix=prefix, kv_sharing_target_layer_name=kv_sharing_target_layer_name, - attn_backend=attn_backend) + attn_backend=attn_backend, + ) diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py index 9400c5bffa38..fb7004f86538 100644 --- a/vllm/attention/layers/cross_attention.py +++ b/vllm/attention/layers/cross_attention.py @@ -8,33 +8,40 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig, VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - subclass_attention_backend) +from vllm.v1.attention.backends.utils import ( + 
CommonAttentionMetadata, + subclass_attention_backend, +) from vllm.v1.kv_cache_interface import CrossAttentionSpec logger = init_logger(__name__) def _get_max_encoder_len(vllm_config: "VllmConfig") -> int: - """Gets the max number of encoder input tokens from the config. - """ + """Gets the max number of encoder input tokens from the config.""" sc = vllm_config.scheduler_config - assert sc and isinstance(sc.max_num_encoder_input_tokens, int), \ + assert sc and isinstance(sc.max_num_encoder_input_tokens, int), ( "max_num_encoder_input_tokens must be int for enc-dec models" + ) return sc.max_num_encoder_input_tokens -def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray, - block_table_tensor: torch.Tensor, - kv_cache_spec: CrossAttentionSpec, - device: torch.device) -> torch.Tensor: +def _get_cross_slot_mapping( + encoder_seq_lens: np.ndarray, + block_table_tensor: torch.Tensor, + kv_cache_spec: CrossAttentionSpec, + device: torch.device, +) -> torch.Tensor: """Get cross-attention slot mappings.""" block_size = kv_cache_spec.block_size @@ -58,9 +65,7 @@ def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray, needed_block_ids = req_block_ids[:num_blocks_needed] # All needed blocks are allocated - i_values = torch.arange(encoder_seq_len, - dtype=torch.int64, - device=device) + i_values = torch.arange(encoder_seq_len, dtype=torch.int64, device=device) block_indices = i_values // block_size block_offsets = i_values % block_size block_numbers = needed_block_ids[block_indices] @@ -76,42 +81,48 @@ def _get_cross_slot_mapping(encoder_seq_lens: np.ndarray, @functools.lru_cache def create_cross_attention_backend( - underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + underlying_attn_backend: AttentionBackend, +) -> type[AttentionBackend]: prefix = "CrossAttention_" underlying_builder = underlying_attn_backend.get_builder_cls() class CrossAttentionBuilder(underlying_builder): # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> AttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: new_metadata = copy(common_attn_metadata) new_metadata.causal = False max_encoder_len = _get_max_encoder_len(self.vllm_config) new_metadata.max_seq_len = max_encoder_len new_metadata.seq_lens = torch.full( - (new_metadata.num_reqs, ), + (new_metadata.num_reqs,), max_encoder_len, dtype=torch.int32, device=self.device, ) new_metadata.seq_lens_cpu = torch.full( - (new_metadata.num_reqs, ), + (new_metadata.num_reqs,), max_encoder_len, dtype=torch.int32, device="cpu", ) new_metadata.slot_mapping = _get_cross_slot_mapping( - new_metadata.encoder_seq_lens, new_metadata.block_table_tensor, - self.kv_cache_spec, self.device) + new_metadata.encoder_seq_lens, + new_metadata.block_table_tensor, + self.kv_cache_spec, + self.device, + ) return super().build(common_prefix_len, new_metadata, fast_build) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=CrossAttentionBuilder) + builder_cls=CrossAttentionBuilder, + ) return attn_backend @@ -122,13 +133,15 @@ class CrossAttention(Attention): Handles attention between decoder queries and encoder keys/values. 
""" - def __init__(self, - num_heads: int, - head_size: int, - scale: float, - cache_config: Optional[CacheConfig] = None, - attn_type: Optional[str] = None, - **kwargs): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs, + ): dtype = torch.get_default_dtype() if cache_config is not None: @@ -139,24 +152,26 @@ def __init__(self, block_size = 16 if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend(head_size, dtype, - kv_cache_dtype, - block_size) + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) - attn_backend = create_cross_attention_backend( - underlying_attn_backend) + attn_backend = create_cross_attention_backend(underlying_attn_backend) else: # in v0 cross attention is handled inside the backends attn_backend = None if attn_type is not None: assert attn_type == AttentionType.ENCODER_DECODER, ( - "CrossAttention only supports AttentionType.ENCODER_DECODER") - - super().__init__(num_heads=num_heads, - head_size=head_size, - scale=scale, - cache_config=cache_config, - attn_backend=attn_backend, - attn_type=AttentionType.ENCODER_DECODER, - **kwargs) + "CrossAttention only supports AttentionType.ENCODER_DECODER" + ) + + super().__init__( + num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_DECODER, + **kwargs, + ) diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py index cea05df5b96d..f49f195563dc 100644 --- a/vllm/attention/layers/encoder_only_attention.py +++ b/vllm/attention/layers/encoder_only_attention.py @@ -7,36 +7,45 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionMetadata, + AttentionType, +) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - subclass_attention_backend) +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + subclass_attention_backend, +) @functools.lru_cache def create_encoder_only_attention_backend( - underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: + underlying_attn_backend: AttentionBackend, +) -> type[AttentionBackend]: prefix = "EncoderOnlyAttention_" underlying_builder = underlying_attn_backend.get_builder_cls() class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> AttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: new_common_attn_metadata = copy(common_attn_metadata) new_common_attn_metadata.causal = False - return super().build(common_prefix_len, new_common_attn_metadata, - fast_build) + return super().build( + common_prefix_len, new_common_attn_metadata, fast_build + ) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=EncoderOnlyAttentionBuilder) + builder_cls=EncoderOnlyAttentionBuilder, + ) return attn_backend @@ -46,13 +55,15 @@ 
class EncoderOnlyAttention(Attention): Encoder attention is a special case that doesn't need a KV Cache. """ - def __init__(self, - num_heads: int, - head_size: int, - scale: float, - cache_config: Optional[CacheConfig] = None, - attn_type: Optional[str] = None, - **kwargs): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + cache_config: Optional[CacheConfig] = None, + attn_type: Optional[str] = None, + **kwargs, + ): dtype = torch.get_default_dtype() if cache_config is not None: @@ -63,24 +74,28 @@ def __init__(self, block_size = 16 if envs.VLLM_USE_V1: - underlying_attn_backend = get_attn_backend(head_size, dtype, - kv_cache_dtype, - block_size) + underlying_attn_backend = get_attn_backend( + head_size, dtype, kv_cache_dtype, block_size + ) attn_backend = create_encoder_only_attention_backend( - underlying_attn_backend) + underlying_attn_backend + ) else: # in v0 encoder only attention is handled inside the backends attn_backend = None if attn_type is not None: - assert attn_type == AttentionType.ENCODER_ONLY, \ + assert attn_type == AttentionType.ENCODER_ONLY, ( "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY" - - super().__init__(num_heads=num_heads, - head_size=head_size, - scale=scale, - cache_config=cache_config, - attn_backend=attn_backend, - attn_type=AttentionType.ENCODER_ONLY, - **kwargs) + ) + + super().__init__( + num_heads=num_heads, + head_size=head_size, + scale=scale, + cache_config=cache_config, + attn_backend=attn_backend, + attn_type=AttentionType.ENCODER_ONLY, + **kwargs, + ) diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py index bf4b06512a3c..aa791fe97006 100644 --- a/vllm/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/attention/ops/chunked_prefill_paged_decode.py @@ -25,72 +25,73 @@ def cdiv_fn(x, y): @triton.jit def kernel_paged_attention_2d( - output_ptr, # [num_tokens, num_query_heads, head_size] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] - value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - out_scale_inv, - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - num_queries_per_kv_padded: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - output_stride_0: tl.int64, # int - output_stride_1: tl.int64, # int, should be equal to head_size - BLOCK_SIZE: tl.constexpr, # int - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - x: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.int64, # int - stride_k_cache_4: tl.int64, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.int64, # int - filter_by_query_len: tl.constexpr, # bool - query_start_len_ptr, # [num_seqs+1] - USE_SINKS: tl.constexpr, # bool - USE_FP8: tl.constexpr, - FP8_MIN: tl.constexpr = float8_info.min, - FP8_MAX: tl.constexpr = 
float8_info.max): + output_ptr, # [num_tokens, num_query_heads, head_size] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + out_scale_inv, + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + num_queries_per_kv_padded: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + BLOCK_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + x: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.int64, # int + stride_k_cache_4: tl.int64, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.int64, # int + filter_by_query_len: tl.constexpr, # bool + query_start_len_ptr, # [num_seqs+1] + USE_SINKS: tl.constexpr, # bool + USE_FP8: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, +): seq_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) if filter_by_query_len: cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) - cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + - 1) - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index + cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index if cur_batch_query_len > 1: return else: cur_batch_in_all_start_index = seq_idx query_head_idx = kv_head_idx * num_queries_per_kv + tl.arange( - 0, num_queries_per_kv_padded) + 0, num_queries_per_kv_padded + ) - query_offset = (cur_batch_in_all_start_index * query_stride_0 + - query_head_idx[:, None] * query_stride_1) + query_offset = ( + cur_batch_in_all_start_index * query_stride_0 + + query_head_idx[:, None] * query_stride_1 + ) head_mask = query_head_idx < (kv_head_idx + 1) * num_queries_per_kv head_mask = head_mask & (query_head_idx < num_query_heads) - dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, - 0).to(tl.int1) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1) # Q : (num_queries_per_kv, HEAD_SIZE,) Q = tl.load( @@ -102,9 +103,7 @@ def kernel_paged_attention_2d( block_table_offset = seq_idx * block_table_stride if not USE_SINKS: - M = tl.full([num_queries_per_kv_padded], - float("-inf"), - dtype=tl.float32) + M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32) else: M = tl.load( sink_ptr + query_head_idx, @@ -113,43 +112,43 @@ def kernel_paged_attention_2d( ).to(dtype=tl.float32) L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32) - acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], - dtype=tl.float32) + acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED], dtype=tl.float32) # 
sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_head_idx, - mask=head_mask, - other=0.0) + alibi_slope = tl.load( + alibi_slopes_ptr + query_head_idx, mask=head_mask, other=0.0 + ) num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) # iterate through tiles for j in range(0, num_blocks): - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j) offs_n = tl.arange(0, BLOCK_SIZE) offs_d = tl.arange(0, HEAD_SIZE_PADDED) - v_offset = (physical_block_idx * stride_v_cache_0 + - kv_head_idx * stride_v_cache_1 + - offs_d[None, :] * stride_v_cache_2 + - offs_n[:, None] * stride_v_cache_3) + v_offset = ( + physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_d[None, :] * stride_v_cache_2 + + offs_n[:, None] * stride_v_cache_3 + ) - k_offset = (physical_block_idx * stride_k_cache_0 + - kv_head_idx * stride_k_cache_1 + - (offs_d[:, None] // x) * stride_k_cache_2 + - offs_n[None, :] * stride_k_cache_3 + - (offs_d[:, None] % x) * stride_k_cache_4) + k_offset = ( + physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + (offs_d[:, None] // x) * stride_k_cache_2 + + offs_n[None, :] * stride_k_cache_3 + + (offs_d[:, None] % x) * stride_k_cache_4 + ) # K : (HEAD_SIZE, BLOCK_SIZE) - K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None], - other=0.0) + K_load = tl.load(key_cache_ptr + k_offset, mask=dim_mask[:, None], other=0.0) if K_load.dtype.is_fp8(): K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) @@ -157,9 +156,7 @@ def kernel_paged_attention_2d( K = K_load # V : (BLOCK_SIZE, HEAD_SIZE) - V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :], - other=0.0) + V_load = tl.load(value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0) if V_load.dtype.is_fp8(): V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) @@ -171,15 +168,13 @@ def kernel_paged_attention_2d( seq_mask = seq_offset[None, :] < boundary # S : (num_queries_per_kv, BLOCK_SIZE,) - S = tl.where(head_mask[:, None] & seq_mask, 0.0, - float("-inf")).to(tl.float32) + S = tl.where(head_mask[:, None] & seq_mask, 0.0, float("-inf")).to(tl.float32) S += scale * tl.dot(Q, K) context_len = seq_len - 1 if SLIDING_WINDOW > 0: - S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, - -10000) + S = tl.where((context_len - seq_offset) < SLIDING_WINDOW, S, -10000) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -213,12 +208,13 @@ def kernel_paged_attention_2d( acc = acc * tl.load(out_scale_inv) acc = tl.clamp(acc, FP8_MIN, FP8_MAX) - output_offset = (cur_batch_in_all_start_index * output_stride_0 + - query_head_idx * output_stride_1) + output_offset = ( + cur_batch_in_all_start_index * output_stride_0 + + query_head_idx * output_stride_1 + ) tl.store( - output_ptr + output_offset[:, None] + - tl.arange(0, HEAD_SIZE_PADDED)[None, :], + output_ptr + output_offset[:, None] + tl.arange(0, HEAD_SIZE_PADDED)[None, :], acc, mask=dim_mask[None, :] & head_mask[:, None], ) @@ -246,9 +242,8 @@ def chunked_prefill_paged_decode( # Optional tensor for sinks sinks=None, ): - if sm_scale is None: - sm_scale = 1.0 / (query.shape[1]**0.5) + sm_scale = 1.0 / (query.shape[1] ** 0.5) use_alibi_slopes = alibi_slopes is not None @@ -302,10 +297,10 @@ def chunked_prefill_paged_decode( key_cache = key_cache.view(target_dtype) value_cache = value_cache.view(target_dtype) - num_queries_per_kv_padded = 
max(triton.next_power_of_2(num_queries_per_kv), - 16) + num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv), 16) from vllm.platforms.rocm import use_rocm_custom_paged_attention + use_custom = use_rocm_custom_paged_attention( query.dtype, head_size, @@ -319,13 +314,13 @@ def chunked_prefill_paged_decode( ) if use_custom: _PARTITION_SIZE_ROCM = 256 - max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) // - _PARTITION_SIZE_ROCM) + max_num_partitions = ( + max_seq_len + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM assert _PARTITION_SIZE_ROCM % block_size == 0 total_num_seq = block_table.shape[0] tmp_output = torch.empty( - size=(total_num_seq, num_query_heads, max_num_partitions, - head_size), + size=(total_num_seq, num_query_heads, max_num_partitions, head_size), dtype=query.dtype, device=output.device, ) @@ -358,10 +353,12 @@ def chunked_prefill_paged_decode( fp8_out_scale=output_scale, ) else: - kernel_paged_attention_2d[( - num_seqs, - num_kv_heads, - )]( + kernel_paged_attention_2d[ + ( + num_seqs, + num_kv_heads, + ) + ]( output_ptr=output, query_ptr=query, key_cache_ptr=key_cache, @@ -373,8 +370,7 @@ def chunked_prefill_paged_decode( scale=sm_scale, k_scale=k_scale, v_scale=v_scale, - out_scale_inv=1.0 / - output_scale if output_scale is not None else 1.0, + out_scale_inv=1.0 / output_scale if output_scale is not None else 1.0, num_query_heads=num_query_heads, num_queries_per_kv=num_queries_per_kv, num_queries_per_kv_padded=num_queries_per_kv_padded, diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index e659f1f3eae9..097fbae68cda 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -7,11 +7,21 @@ @triton.jit -def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr, - vlse_ptr, outputs_stride_B, outputs_stride_H, - outputs_stride_D, lses_stride_N, lses_stride_B, - lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr, - N_ROUNDED: tl.constexpr): +def _correct_attn_cp_out_kernel( + outputs_ptr, + new_output_ptr, + lses_ptr, + vlse_ptr, + outputs_stride_B, + outputs_stride_H, + outputs_stride_D, + lses_stride_N, + lses_stride_B, + lses_stride_H, + lse_idx, + HEAD_DIM: tl.constexpr, + N_ROUNDED: tl.constexpr, +): """ Apply the all-gathered lses to correct each local rank's attention output. 
we still need perform a cross-rank reduction to obtain the @@ -33,12 +43,15 @@ def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr, num_n_offsets = tl.arange(0, N_ROUNDED) # shape = [N] - lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \ - lses_stride_B + head_idx * lses_stride_H + lse_offsets = ( + num_n_offsets * lses_stride_N + + batch_idx * lses_stride_B + + head_idx * lses_stride_H + ) # calc final lse lse = tl.load(lses_ptr + lse_offsets) - lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse) + lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse) lse_max = tl.max(lse, axis=0) lse -= lse_max lse_exp = tl.exp(lse) @@ -50,18 +63,23 @@ def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr, tl.store(vlse_ptr + lse_offsets, lse) # shape = [D] - output_offsets = batch_idx * outputs_stride_B + \ - head_idx * outputs_stride_H + \ - d_offsets * outputs_stride_D + output_offsets = ( + batch_idx * outputs_stride_B + + head_idx * outputs_stride_H + + d_offsets * outputs_stride_D + ) # correct output - lse_offset = lse_idx * lses_stride_N + batch_idx * \ - lses_stride_B + head_idx * lses_stride_H + lse_offset = ( + lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H + ) lse_tmp = tl.load(lses_ptr + lse_offset) lse_finally = lse_tmp - lse lse_finally = tl.where( - (lse_finally != lse_finally) | (lse_finally == float('inf')), - -float('inf'), lse_finally) + (lse_finally != lse_finally) | (lse_finally == float("inf")), + -float("inf"), + lse_finally, + ) factor = tl.exp(lse_finally) output = tl.load(outputs_ptr + output_offsets) output = output * factor @@ -70,8 +88,7 @@ def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr, class CPTritonContext: - """ The CPTritonContext is used to avoid recompilation of the Triton JIT. - """ + """The CPTritonContext is used to avoid recompilation of the Triton JIT.""" def __init__(self): self.inner_kernel = None @@ -84,8 +101,8 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args): def correct_attn_out( - out: torch.Tensor, lses: torch.Tensor, cp_rank: int, - ctx: CPTritonContext) -> tuple[torch.Tensor, torch.Tensor]: + out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext +) -> tuple[torch.Tensor, torch.Tensor]: """Correct the attention output using the all-gathered lses. 
Args: @@ -103,22 +120,22 @@ def correct_attn_out( lse = torch.empty_like(lses[0]) grid = (out.shape[0], out.shape[1], 1) - regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), - cp_rank) + regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(), cp_rank) const_args = { "HEAD_DIM": out.shape[-1], "N_ROUNDED": lses.shape[0], } - ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, - **const_args) + ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args) return out, lse -def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, - cp_attn_lse: torch.Tensor, - cp_group: GroupCoordinator, - ctx: CPTritonContext = None): +def cp_lse_ag_out_rs( + cp_attn_out: torch.Tensor, + cp_attn_lse: torch.Tensor, + cp_group: GroupCoordinator, + ctx: CPTritonContext = None, +): """ cp_attn_out: [ B, H, D ] cp_attn_lse: [ B, H ] @@ -129,9 +146,11 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, if ctx is None: ctx = CPTritonContext() - lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape, - dtype=cp_attn_lse.dtype, - device=cp_attn_lse.device) + lses = torch.empty( + (cp_group.world_size,) + cp_attn_lse.shape, + dtype=cp_attn_lse.dtype, + device=cp_attn_lse.device, + ) cp_attn_lse = cp_attn_lse.contiguous() lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses) @@ -142,15 +161,15 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, @triton.jit def _pack_seq_kernel( - x_ptr, # [N, D] - out_ptr, # [B, Lmax, D] - lengths_ptr, # *i32, [B] - N: tl.constexpr, - D: tl.constexpr, - Lmax: tl.constexpr, - PAD_VALUE: tl.constexpr, - BLOCK_T: tl.constexpr, # timesteps per program - BLOCK_D: tl.constexpr # features per program + x_ptr, # [N, D] + out_ptr, # [B, Lmax, D] + lengths_ptr, # *i32, [B] + N: tl.constexpr, + D: tl.constexpr, + Lmax: tl.constexpr, + PAD_VALUE: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr, # features per program ): pid_b = tl.program_id(0) # batch id pid_t = tl.program_id(1) # block over time dimension @@ -176,8 +195,7 @@ def _pack_seq_kernel( x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :] # out_ptr: row-major [B, Lmax, D] - out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, - None] * D + off_d[None, :] + out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] # Initialize with PAD (cast will occur as needed based on out_ptr dtype) d_mask = off_d[None, :] < D @@ -189,21 +207,23 @@ def _pack_seq_kernel( tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask) -def pack_seq_triton(x: torch.Tensor, - lengths: torch.Tensor, - pad_value: float = -float('inf'), - block_t: int = 64, - block_d: int = 64) -> torch.Tensor: +def pack_seq_triton( + x: torch.Tensor, + lengths: torch.Tensor, + pad_value: float = -float("inf"), + block_t: int = 64, + block_d: int = 64, +) -> torch.Tensor: """ Pack sequences of different lengths into a batched tensor. - + Args: x: [N, ...] - input tensor where N is total number of tokens lengths: [B] - sequence lengths for each batch pad_value: value to use for padding block_t: block size for time dimension block_d: block size for feature dimension - + Returns: packed: [B, Lmax, ...] 
- packed tensor """ @@ -226,17 +246,19 @@ def pack_seq_triton(x: torch.Tensor, out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype) grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) - _pack_seq_kernel[grid](x_reshaped, - out, - lengths.int(), - N, - D, - Lmax, - PAD_VALUE=float(pad_value), - BLOCK_T=block_t, - BLOCK_D=block_d, - num_warps=4, - num_stages=2) + _pack_seq_kernel[grid]( + x_reshaped, + out, + lengths.int(), + N, + D, + Lmax, + PAD_VALUE=float(pad_value), + BLOCK_T=block_t, + BLOCK_D=block_d, + num_warps=4, + num_stages=2, + ) # Reshape output back to original dimensions (except first dimension) if len(original_shape) > 2: @@ -248,14 +270,14 @@ def pack_seq_triton(x: torch.Tensor, @triton.jit def _unpack_seq_triton_kernel( - packed_ptr, # [B, Lmax, D] - out_ptr, # [N, D] - lengths_ptr, # *i32, [B] - B: tl.constexpr, - Lmax: tl.constexpr, - D: tl.constexpr, - BLOCK_T: tl.constexpr, # timesteps per program - BLOCK_D: tl.constexpr # features per program + packed_ptr, # [B, Lmax, D] + out_ptr, # [N, D] + lengths_ptr, # *i32, [B] + B: tl.constexpr, + Lmax: tl.constexpr, + D: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr, # features per program ): pid_b = tl.program_id(0) # batch id pid_t = tl.program_id(1) # block over time dimension @@ -278,8 +300,7 @@ def _unpack_seq_triton_kernel( # Pointers # packed_ptr: row-major [B, Lmax, D] - packed_row_ptr = packed_ptr + (pid_b * Lmax + - off_t)[:, None] * D + off_d[None, :] + packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] # out_ptr: row-major [N, D] out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :] @@ -290,20 +311,22 @@ def _unpack_seq_triton_kernel( tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask) -def unpack_seq_triton(packed_tensor: torch.Tensor, - lengths: torch.Tensor, - block_t: int = 64, - block_d: int = 64) -> torch.Tensor: +def unpack_seq_triton( + packed_tensor: torch.Tensor, + lengths: torch.Tensor, + block_t: int = 64, + block_d: int = 64, +) -> torch.Tensor: """ Unpack a packed decode query tensor back to the original format. Efficient Triton implementation. - + Args: packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton lengths: [B] - sequence lengths for each batch block_t: block size for time dimension block_d: block size for feature dimension - + Returns: unpacked_tensor: [N, ...] 
where N = sum(lengths) """ @@ -321,25 +344,25 @@ def unpack_seq_triton(packed_tensor: torch.Tensor, # Calculate total number of elements N = int(lengths.sum().item()) - out = torch.empty((N, D), - device=packed_tensor.device, - dtype=packed_tensor.dtype) + out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype) grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) - _unpack_seq_triton_kernel[grid](packed_reshaped, - out, - lengths.int(), - B, - Lmax, - D, - BLOCK_T=block_t, - BLOCK_D=block_d, - num_warps=4, - num_stages=2) + _unpack_seq_triton_kernel[grid]( + packed_reshaped, + out, + lengths.int(), + B, + Lmax, + D, + BLOCK_T=block_t, + BLOCK_D=block_d, + num_warps=4, + num_stages=2, + ) # Reshape output back to original dimensions (except first dimension) if len(original_shape) > 3: - output_shape = (N, ) + original_shape[2:] + output_shape = (N,) + original_shape[2:] out = out.reshape(output_shape) return out diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 9654f9f6775a..0fe01a51ec62 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py -from typing import Optional, Tuple +from typing import Optional import torch @@ -13,6 +13,7 @@ if current_platform.is_cuda(): try: import vllm._flashmla_C # noqa: F401 + _flashmla_C_AVAILABLE = True except ImportError: _flashmla_C_AVAILABLE = False @@ -22,6 +23,7 @@ if current_platform.is_cuda(): try: import vllm._flashmla_extension_C # noqa: F401 + _flashmla_extension_C_AVAILABLE = True except ImportError: _flashmla_extension_C_AVAILABLE = False @@ -29,7 +31,7 @@ _flashmla_extension_C_AVAILABLE = False -def is_flashmla_supported() -> Tuple[bool, Optional[str]]: +def is_flashmla_supported() -> tuple[bool, Optional[str]]: """ Return: is_supported_flag, unsupported_reason (optional). """ @@ -38,42 +40,51 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]: if current_platform.get_device_capability()[0] != 9: return False, "FlashMLA is only supported on Hopper devices." if not _flashmla_C_AVAILABLE: - return False, "vllm._flashmla_C is not available, likely was not "\ - "compiled due to insufficient nvcc version or a supported arch "\ - "(only sm90a currently) was not in the list of target arches to "\ - "compile for." + return ( + False, + "vllm._flashmla_C is not available, likely was not " + "compiled due to insufficient nvcc version or a supported arch " + "(only sm90a currently) was not in the list of target arches to " + "compile for.", + ) return True, None def get_mla_metadata( - cache_seqlens: torch.Tensor, - num_q_tokens_per_head_k: int, - num_heads_k: int, - num_heads_q: Optional[int] = None, - is_fp8_kvcache: bool = False, - topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + cache_seqlens: torch.Tensor, + num_q_tokens_per_head_k: int, + num_heads_k: int, + num_heads_q: Optional[int] = None, + is_fp8_kvcache: bool = False, + topk: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: - cache_seqlens: (batch_size), dtype torch.int32. - - num_q_tokens_per_head_k: + - num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. - num_heads_k: The number of k heads. - - num_heads_q: - The number of q heads. + - num_heads_q: + The number of q heads. 
This argument is optional when sparse attention is not enabled - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. - - topk: If not None, sparse attention will be enabled, - and only tokens in the `indices` array + - topk: If not None, sparse attention will be enabled, + and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. Returns: - - tile_scheduler_metadata: + - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32. """ return torch.ops._flashmla_C.get_mla_decoding_metadata( - cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, - is_fp8_kvcache, topk) + cache_seqlens, + num_q_tokens_per_head_k, + num_heads_k, + num_heads_q, + is_fp8_kvcache, + topk, + ) def flash_mla_with_kvcache( @@ -90,7 +101,7 @@ def flash_mla_with_kvcache( descale_k: Optional[torch.Tensor] = None, is_fp8_kvcache: bool = False, indices: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Arguments: - q: (batch_size, seq_len_q, num_heads_q, head_dim). @@ -98,26 +109,26 @@ def flash_mla_with_kvcache( - block_table: (batch_size, max_num_blocks_per_seq), torch.int32. - cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head dimension of v. - - tile_scheduler_metadata: - (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, + - tile_scheduler_metadata: + (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. - - num_splits: + - num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. - - softmax_scale: float. - The scale of QK^T before applying softmax. + - softmax_scale: float. + The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). - causal: bool. Whether to apply causal attention mask. - - descale_q: (batch_size), + - descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. - - descale_k: (batch_size), + - descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. - - is_fp8_kvcache: bool. - Whether the k_cache and v_cache are in fp8 format. + - is_fp8_kvcache: bool. + Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md - - indices: (batch_size, seq_len_q, topk), torch.int32. - If not None, sparse attention will be enabled, - and only tokens in the `indices` array will be attended to. - Invalid indices should be set to -1 or numbers >= total_seq_len_kv. + - indices: (batch_size, seq_len_q, topk), torch.int32. + If not None, sparse attention will be enabled, + and only tokens in the `indices` array will be attended to. + Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md. Returns: @@ -125,26 +136,44 @@ def flash_mla_with_kvcache( - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ if softmax_scale is None: - softmax_scale = q.shape[-1]**(-0.5) + softmax_scale = q.shape[-1] ** (-0.5) if indices is not None: # NOTE (zyongye): sparse attention is also causal # since it only attend to the tokens before # but here `causal` should not be specified - assert not causal, \ - "causal must be `false` if sparse attention is enabled." 
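# A minimal sketch (Python) of the decode-time metadata step feeding flash_mla_with_kvcache,
# assuming FlashMLA support is available; the batch sizes, 16 query heads and 1 KV head below
# are illustrative example values, not names defined in this module.
import torch

supported, reason = is_flashmla_supported()
assert supported, reason

cache_seqlens = torch.tensor([17, 33, 255, 1024], dtype=torch.int32, device="cuda")
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    num_q_tokens_per_head_k=1 * 16 // 1,  # seq_len_q * num_heads_q // num_heads_k
    num_heads_k=1,
)
# Both outputs are then forwarded unchanged to flash_mla_with_kvcache together with
# q, k_cache, block_table, cache_seqlens and head_dim_v.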
- assert (descale_q is None) == ( - descale_k is None - ), "descale_q and descale_k should be both None or both not None" + assert not causal, "causal must be `false` if sparse attention is enabled." + assert (descale_q is None) == (descale_k is None), ( + "descale_q and descale_k should be both None or both not None" + ) if indices is None and q.element_size() == 1: out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8( - q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, - causal, tile_scheduler_metadata, num_splits, descale_q, descale_k) + q, + k_cache, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + descale_q, + descale_k, + ) else: out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( - q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, - causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache, - indices) + q, + k_cache, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + is_fp8_kvcache, + indices, + ) return out, softmax_lse @@ -154,28 +183,27 @@ def flash_mla_sparse_prefill( indices: torch.Tensor, sm_scale: float, d_v: int = 512, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Sparse attention prefill kernel Args: - q: [s_q, h_q, d_qk], bfloat16 - kv: [s_kv, h_kv, d_qk], bfloat16 - - indices: [s_q, h_kv, topk], int32. + - indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv - sm_scale: float - d_v: The dimension of value vectors. Can only be 512 Returns: - (output, max_logits, lse) - About the definition of output, + About the definition of output, max_logits and lse, please refer to README.md - output: [s_q, h_q, d_v], bfloat16 - max_logits: [s_q, h_q], float - lse: [s_q, h_q], float, 2-based log-sum-exp """ - results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, - sm_scale, d_v) + results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, sm_scale, d_v) return results diff --git a/vllm/attention/ops/merge_attn_states.py b/vllm/attention/ops/merge_attn_states.py index 5cb1a47394cf..79800eb40766 100644 --- a/vllm/attention/ops/merge_attn_states.py +++ b/vllm/attention/ops/merge_attn_states.py @@ -15,7 +15,6 @@ def merge_attn_states( suffix_lse: torch.Tensor, output_lse: Optional[torch.Tensor] = None, ) -> None: - # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel # is not support for FP8 dtype, fallback to use Triton kernel. 
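# A minimal pure-PyTorch reference sketch of what the fused CUDA/Triton merge kernels
# compute, assuming outputs of shape [tokens, heads, head_dim] and LSEs of shape
# [heads, tokens]: the two partial softmax results are combined with per-token,
# per-head weights derived from their log-sum-exp values.
import torch

def merge_attn_states_reference(prefix_output, prefix_lse, suffix_output, suffix_lse):
    max_lse = torch.maximum(prefix_lse, suffix_lse)
    p = torch.exp(prefix_lse - max_lse)
    s = torch.exp(suffix_lse - max_lse)
    output_lse = max_lse + torch.log(p + s)            # combined log-sum-exp
    w_p = (p / (p + s)).transpose(0, 1).unsqueeze(-1)  # [tokens, heads, 1]
    w_s = (s / (p + s)).transpose(0, 1).unsqueeze(-1)
    merged = w_p * prefix_output + w_s * suffix_output
    return merged, output_lse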
def supported_dtypes(o: torch.Tensor) -> bool: @@ -31,13 +30,19 @@ def supported_headdim(o: torch.Tensor) -> bool: return headdim % 4 == 0 return headdim % 8 == 0 - if (current_platform.is_cuda() and supported_dtypes(output) - and supported_headdim(output)): + if ( + current_platform.is_cuda() + and supported_dtypes(output) + and supported_headdim(output) + ): from vllm._custom_ops import merge_attn_states - return merge_attn_states(output, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse) + + return merge_attn_states( + output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse + ) else: - from vllm.attention.ops.triton_merge_attn_states import ( - merge_attn_states) - return merge_attn_states(output, prefix_output, prefix_lse, - suffix_output, suffix_lse, output_lse) + from vllm.attention.ops.triton_merge_attn_states import merge_attn_states + + return merge_attn_states( + output, prefix_output, prefix_lse, suffix_output, suffix_lse, output_lse + ) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 539b57e41de7..4db7d1a3a325 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Optional import torch @@ -24,6 +24,7 @@ @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. seq_lens_tensor: Optional[torch.Tensor] @@ -39,9 +40,8 @@ class PagedAttentionMetadata: class PagedAttention: - @staticmethod - def get_supported_head_sizes() -> List[int]: + def get_supported_head_sizes() -> list[int]: return [32, 64, 80, 96, 112, 120, 128, 192, 256] @staticmethod @@ -51,7 +51,7 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, cache_dtype_str: str = "auto", - ) -> Tuple[int, ...]: + ) -> tuple[int, ...]: return (2, num_blocks, block_size * num_kv_heads * head_size) @staticmethod @@ -59,13 +59,12 @@ def split_kv_cache( kv_cache: torch.Tensor, num_kv_heads: int, head_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: x = 16 // kv_cache.element_size() num_blocks = kv_cache.shape[1] key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, - -1, x) + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x) value_cache = kv_cache[1] value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) return key_cache, value_cache @@ -115,16 +114,17 @@ def forward_decode( if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: # use blocksparse paged attention block_size = value_cache.size(-1) - assert (blocksparse_block_size > 0 and - blocksparse_block_size % block_size == 0), \ - (f"{blocksparse_block_size=} needs to be a multiple of" - f"{block_size=} used in block_tables.") + assert ( + blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0 + ), ( + f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables." 
+ ) output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) + max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -132,8 +132,9 @@ def forward_decode( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = (max_seq_len <= 8192 - and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + use_v1 = max_seq_len <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512 + ) if use_v1: # Run PagedAttention V1. @@ -254,7 +255,7 @@ def swap_blocks( @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], + kv_caches: list[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] diff --git a/vllm/attention/ops/pallas_kv_cache_update.py b/vllm/attention/ops/pallas_kv_cache_update.py index d75983bd407d..d0d836cc6aa5 100644 --- a/vllm/attention/ops/pallas_kv_cache_update.py +++ b/vllm/attention/ops/pallas_kv_cache_update.py @@ -33,10 +33,12 @@ def _kv_cache_update_kernel( # Copy from new_kv_hbm_ref to scratch for i in range(num_slices_per_block): offset_i = i + block_idx * num_slices_per_block - new_kv_start = jax.lax.select(offset_i < num_slices_ref[0], - slices_ref[1, offset_i], 0) - length = jax.lax.select(offset_i < num_slices_ref[0], - slices_ref[2, offset_i], 0) + new_kv_start = jax.lax.select( + offset_i < num_slices_ref[0], slices_ref[1, offset_i], 0 + ) + length = jax.lax.select( + offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0 + ) async_copy = pltpu.make_async_copy( new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...], scratch.at[i, pl.ds(0, length), ...], @@ -52,10 +54,12 @@ def _kv_cache_update_kernel( async_copies.clear() for i in range(num_slices_per_block): offset_i = i + block_idx * num_slices_per_block - kv_cache_start = jax.lax.select(offset_i < num_slices_ref[0], - slices_ref[0, offset_i], 0) - length = jax.lax.select(offset_i < num_slices_ref[0], - slices_ref[2, offset_i], 0) + kv_cache_start = jax.lax.select( + offset_i < num_slices_ref[0], slices_ref[0, offset_i], 0 + ) + length = jax.lax.select( + offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0 + ) async_copy = pltpu.make_async_copy( scratch.at[i, pl.ds(0, length), ...], kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...], @@ -72,12 +76,14 @@ def _kv_cache_update_kernel( static_argnames=["page_size", "num_slices_per_block"], ) def kv_cache_update( - new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim] - slices: jax. - Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) - kv_cache: jax. 
- Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim] - num_kv_update_slices: jax.Array, # [1] + # [total_num_token, num_combined_kv_heads, head_dim] + new_kv: jax.Array, + # [3, slices], list of (kv_cache_start, new_kv_start, slice_len) + slices: jax.Array, + # [total_num_pages * page_size, num_combined_kv_heads, head_dim] + kv_cache: jax.Array, + # [1] + num_kv_update_slices: jax.Array, *, page_size: int = 32, num_slices_per_block: int = 8, @@ -114,7 +120,7 @@ def kv_cache_update( num_scalar_prefetch=len(scalar_prefetches), in_specs=in_specs, out_specs=out_specs, - grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ), + grid=(cdiv(num_kv_update_slices[0], num_slices_per_block),), scratch_shapes=scratch_shapes, ), out_shape=out_shape, diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 7e5c2b6c62e9..addf1d9dea73 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -34,62 +34,63 @@ # key=["BLOCK_SIZE", "MAX_Q_LEN", "MAX_CTX_LEN"] # ) @triton.jit -def _fwd_kernel(Q, - K, - V, - K_cache, - V_cache, - sink_ptr, - B_Loc, - sm_scale, - k_scale, - v_scale, - out_scale_inv, - B_Start_Loc, - B_Seqlen, - x: tl.constexpr, - Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_k_cache_bs, - stride_k_cache_h, - stride_k_cache_d, - stride_k_cache_bl: tl.constexpr, - stride_k_cache_x, - stride_v_cache_bs, - stride_v_cache_h, - stride_v_cache_d, - stride_v_cache_bl, - num_queries_per_kv: tl.constexpr, - IN_PRECISION: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DMODEL_PADDED: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, - num_unroll_cache: tl.constexpr, - num_unroll_request: tl.constexpr, - SKIP_DECODE: tl.constexpr, - USE_SINKS: tl.constexpr, - USE_FP8: tl.constexpr, - MAX_Q_LEN: tl.constexpr = 0, - MAX_CTX_LEN: tl.constexpr = 0, - FP8_MIN: tl.constexpr = float8_info.min, - FP8_MAX: tl.constexpr = float8_info.max): - +def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + sink_ptr, + B_Loc, + sm_scale, + k_scale, + v_scale, + out_scale_inv, + B_Start_Loc, + B_Seqlen, + x: tl.constexpr, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl: tl.constexpr, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: tl.constexpr, + IN_PRECISION: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL_PADDED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + num_unroll_cache: tl.constexpr, + num_unroll_request: tl.constexpr, + SKIP_DECODE: tl.constexpr, + USE_SINKS: tl.constexpr, + USE_FP8: tl.constexpr, + MAX_Q_LEN: tl.constexpr = 0, + MAX_CTX_LEN: tl.constexpr = 0, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, +): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -99,8 +100,7 @@ def _fwd_kernel(Q, cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) 
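# A small plain-Python sketch of the bookkeeping performed at this point in the kernel,
# using illustrative arrays that mirror the B_Start_Loc / B_Seqlen arguments.
import numpy as np

b_start_loc = np.array([0, 3, 8, 9])  # cumulative query offsets, length batch + 1
b_seq_len = np.array([10, 12, 4])     # total tokens (context + new) per sequence

for cur_batch in range(len(b_seq_len)):
    query_len = b_start_loc[cur_batch + 1] - b_start_loc[cur_batch]
    ctx_len = b_seq_len[cur_batch] - query_len  # tokens already in the paged KV cache
    # the kernel first attends the query_len new tokens to the ctx_len cached tokens
    # (no causal mask needed there), then to the new tokens themselves under a causal mask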
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len if SKIP_DECODE and cur_batch_query_len == 1: @@ -120,17 +120,21 @@ def _fwd_kernel(Q, # [M]; starts at current position in query offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) # [M,D] - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, - 0).to(tl.int1) # [D] - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_query_len), - other=0.0) # [M,D] + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + + dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to( + tl.int1 + ) # [D] + + q = tl.load( + Q + off_q, + mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len), + other=0.0, + ) # [M,D] # initialize pointer to m and l if not USE_SINKS: @@ -146,32 +150,43 @@ def _fwd_kernel(Q, acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D] # compute query against context (no causal mask here) - for start_n in tl.range(0, cur_batch_ctx_len, BLOCK_SIZE, \ - loop_unroll_factor=num_unroll_cache): + for start_n in tl.range( + 0, cur_batch_ctx_len, BLOCK_SIZE, loop_unroll_factor=num_unroll_cache + ): start_n = tl.multiple_of(start_n, BLOCK_SIZE) # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - (start_n // BLOCK_SIZE) * stride_b_loc_s).to(tl.int64) + bn = tl.load( + B_Loc + + cur_batch * stride_b_loc_b + + (start_n // BLOCK_SIZE) * stride_b_loc_s + ).to(tl.int64) # [D,BLOCK_SIZE] off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) + bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_bs_n[None, :]) % BLOCK_SIZE) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x + ) # [BLOCK_SIZE,D] - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - offs_bs_n[:, None] * stride_v_cache_bl) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + offs_bs_n[:, None] * stride_v_cache_bl + ) - if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ - BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + if ( + start_n + BLOCK_SIZE > cur_batch_ctx_len + or BLOCK_DMODEL != BLOCK_DMODEL_PADDED + ): k_load = tl.load( K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] + mask=dim_mask[:, None] + & ((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len), + other=0.0, + ) # [D,N] else: k_load = tl.load(K_cache + off_k) @@ -182,8 +197,9 @@ def _fwd_kernel(Q, qk = tl.zeros([BLOCK_M, BLOCK_SIZE], dtype=tl.float32) # [M,N] qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) + qk = tl.where( + (start_n + offs_bs_n[None, :]) < cur_batch_ctx_len, qk, 
float("-inf") + ) qk *= sm_scale if SLIDING_WINDOW > 0: # (cur_batch_ctx_len + offs_m[:, None]) are the positions of @@ -197,9 +213,12 @@ def _fwd_kernel(Q, # sliding window may lead to the entire row being masked. # This then makes m_ij contain -inf, which causes NaNs in # exp(). - qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - - (start_n + offs_bs_n[None, :]) < SLIDING_WINDOW, qk, - -10000) + qk = tl.where( + (cur_batch_ctx_len + offs_m[:, None]) - (start_n + offs_bs_n[None, :]) + < SLIDING_WINDOW, + qk, + -10000, + ) # compute running maximum m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) @@ -209,13 +228,16 @@ def _fwd_kernel(Q, acc = acc * alpha[:, None] # update acc - if start_n + BLOCK_SIZE > cur_batch_ctx_len or \ - BLOCK_DMODEL != BLOCK_DMODEL_PADDED: + if ( + start_n + BLOCK_SIZE > cur_batch_ctx_len + or BLOCK_DMODEL != BLOCK_DMODEL_PADDED + ): v_load = tl.load( V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), - other=0.0) # [N,D] + mask=dim_mask[None, :] + & ((start_n + offs_bs_n[:, None]) < cur_batch_ctx_len), + other=0.0, + ) # [N,D] else: v_load = tl.load(V_cache + off_v) @@ -230,10 +252,16 @@ def _fwd_kernel(Q, l_i = l_i * alpha + l_ij m_i = m_ij - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) + off_k = ( + offs_n[None, :] * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd + ) + off_v = ( + offs_n[:, None] * stride_vbs + + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd + ) k_ptrs = K + off_k v_ptrs = V + off_v @@ -241,27 +269,32 @@ def _fwd_kernel(Q, block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) # compute query against itself (with causal mask) - for start_n in tl.range(0, \ - block_mask * (start_m + 1) * BLOCK_M, BLOCK_N, \ - loop_unroll_factor=num_unroll_request): + for start_n in tl.range( + 0, + block_mask * (start_m + 1) * BLOCK_M, + BLOCK_N, + loop_unroll_factor=num_unroll_request, + ): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_query_len), - other=0.0) + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] + & ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk *= sm_scale # apply causal mask - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) if SLIDING_WINDOW > 0: qk = tl.where( offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW, - qk, -10000) + qk, + -10000, + ) # compute running maximum m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) @@ -271,11 +304,12 @@ def _fwd_kernel(Q, acc = acc * alpha[:, None] # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_query_len), - other=0.0) + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] + & ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0, + ) p = p.to(v.dtype) acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) @@ -286,15 
+320,18 @@ def _fwd_kernel(Q, acc = acc / l_i[:, None] # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o if USE_FP8: acc = acc * tl.load(out_scale_inv) acc = tl.clamp(acc, FP8_MIN, FP8_MAX) - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) + tl.store( + out_ptrs, acc, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len) + ) return @@ -357,12 +394,17 @@ def _fwd_kernel_flash_attn_v2( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - q = tl.load(Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0, + ) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -372,26 +414,36 @@ def _fwd_kernel_flash_attn_v2( for start_n in range(0, cur_batch_ctx_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0).to(tl.int64) + bn = tl.load( + B_Loc + + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0, + ).to(tl.int64) off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) + bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x + ) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl + ) + k = tl.load( + K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) + qk = tl.where( + (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf") + ) qk *= sm_scale # -- compute m_ij, p, l_ij @@ -410,9 +462,11 @@ def _fwd_kernel_flash_attn_v2( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, - other=0.0) + v = tl.load( + V_cache + off_v, + 
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -420,30 +474,34 @@ def _fwd_kernel_flash_attn_v2( l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) + off_k = ( + offs_n[None, :] * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd + ) + off_v = ( + offs_n[:, None] * stride_vbs + + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd + ) k_ptrs = K + off_k v_ptrs = V + off_v - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -461,11 +519,11 @@ def _fwd_kernel_flash_attn_v2( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + - (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -475,12 +533,15 @@ def _fwd_kernel_flash_attn_v2( # acc /= l_i[:, None] # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + tl.store( + out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len + ) return @@ -545,8 +606,7 @@ def _fwd_kernel_alibi( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) - cur_batch_query_len = (cur_batch_in_all_stop_index - - cur_batch_in_all_start_index) + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len if SKIP_DECODE and cur_batch_query_len == 1: @@ -558,16 +618,22 @@ def _fwd_kernel_alibi( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + - cur_head * stride_qh + offs_d[None, :] * stride_qd) - - dim_mask = tl.where( - tl.arange(0, 
BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) - - q = tl.load(Q + off_q, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + + dim_mask = tl.where(tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to( + tl.int1 + ) + + q = tl.load( + Q + off_q, + mask=dim_mask[None, :] + & (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0, + ) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -580,23 +646,31 @@ def _fwd_kernel_alibi( for start_n in range(0, cur_batch_ctx_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + - ((start_n + offs_n) // block_size) * stride_b_loc_s, - mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0).to(tl.int64) + bn = tl.load( + B_Loc + + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0, + ).to(tl.int64) off_k = ( - bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + - (offs_d[:, None] // x) * stride_k_cache_d + - ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + - (offs_d[:, None] % x) * stride_k_cache_x) - off_v = (bn[:, None] * stride_v_cache_bs + - cur_kv_head * stride_v_cache_h + - offs_d[None, :] * stride_v_cache_d + - (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) - k_load = tl.load(K_cache + off_k, - mask=dim_mask[:, None] & - ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) # [D,N] + bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x + ) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl + ) + k_load = tl.load( + K_cache + off_k, + mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0, + ) # [D,N] if k_load.dtype.is_fp8(): k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) @@ -605,16 +679,20 @@ def _fwd_kernel_alibi( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, - float("-inf")) + qk = tl.where( + (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf") + ) qk *= sm_scale # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope + alibi = ( + tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None] + ) * alibi_slope alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, - float("-inf")) + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, + float("-inf"), + ) qk += alibi alibi_start_k += BLOCK_N @@ -634,30 +712,36 @@ def _fwd_kernel_alibi( # acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v_load = tl.load(V_cache + off_v, - mask=dim_mask[None, :] & - ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) + v_load = tl.load( + V_cache + off_v, + mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + 
other=0.0, + ) if v_load.dtype.is_fp8(): v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: v = v_load p = p.to(v.dtype) - acc = tl.dot(p, v, acc=acc, input_precision='ieee') + acc = tl.dot(p, v, acc=acc, input_precision="ieee") # update m_i and l_i l_i = l_i_new m_i = m_i_new - off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + - offs_d[:, None] * stride_kd) - off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + - offs_d[None, :] * stride_vd) + off_k = ( + offs_n[None, :] * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd + ) + off_v = ( + offs_n[:, None] * stride_vbs + + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd + ) k_ptrs = K + off_k v_ptrs = V + off_v - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) # init alibi alibi_slope = tl.load(Alibi_slopes + cur_head) @@ -672,22 +756,25 @@ def _fwd_kernel_alibi( # -- compute qk ---- k = tl.load( k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) + mask=dim_mask[:, None] + & ((start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk = tl.dot(q, k, acc=qk, input_precision='ieee') + qk = tl.dot(q, k, acc=qk, input_precision="ieee") qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, - float("-inf")) + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) # load alibi - alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - - alibi_start_q[:, None]) * alibi_slope + alibi = ( + tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - alibi_start_q[:, None] + ) * alibi_slope alibi = tl.where( - (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), alibi, - float("-inf")) + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, + float("-inf"), + ) qk += alibi alibi_start_k += BLOCK_N @@ -709,12 +796,13 @@ def _fwd_kernel_alibi( # update acc v = tl.load( v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) - < cur_batch_seq_len - cur_batch_ctx_len), - other=0.0) + mask=dim_mask[None, :] + & ((start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0, + ) p = p.to(v.dtype) - acc = tl.dot(p, v, acc=acc, input_precision='ieee') + acc = tl.dot(p, v, acc=acc, input_precision="ieee") # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -722,45 +810,51 @@ def _fwd_kernel_alibi( acc = acc / l_i[:, None] # initialize pointers to output - off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + - cur_head * stride_oh + offs_d[None, :] * stride_od) + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) out_ptrs = Out + off_o - tl.store(out_ptrs, - acc, - mask=dim_mask[None, :] & - (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + tl.store( + out_ptrs, + acc, + mask=dim_mask[None, :] + & (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + ) return @torch.inference_mode() -def context_attention_fwd(q, - k, - v, - o, - kv_cache_dtype: str, - k_cache, - v_cache, - b_loc, - b_start_loc, - b_seq_len, - max_seq_len, - max_input_len, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - 
alibi_slopes=None, - sliding_window=None, - sm_scale=None, - skip_decode=False, - fp8_out_scale=None, - sinks=None): - +def context_attention_fwd( + q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + max_seq_len, + max_input_len, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + alibi_slopes=None, + sliding_window=None, + sm_scale=None, + skip_decode=False, + fp8_out_scale=None, + sinks=None, +): q_dtype_is_f32 = q.dtype is torch.float32 # Turing does have tensor core for float32 multiplication # use ieee as fallback for triton kernels work. There is also # warning on vllm/config.py to inform users this fallback # implementation - IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None + IN_PRECISION = "ieee" if IS_TURING and q_dtype_is_f32 else None # Conversion of FP8 Tensor from uint8 storage to # appropriate torch.dtype for interpretation by Triton @@ -778,10 +872,15 @@ def context_attention_fwd(q, k_cache = k_cache.view(target_dtype) v_cache = v_cache.view(target_dtype) - if (k_cache.dtype == torch.uint8 - or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): - raise ValueError("kv_cache_dtype='auto' unsupported for\ - FP8 KV Cache prefill kernel") + if ( + k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 + and kv_cache_dtype == "auto" + ): + raise ValueError( + "kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel" + ) # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -843,13 +942,11 @@ def context_attention_fwd(q, k_cache.stride(1), k_cache.stride(2), k_cache.stride(3), - k_cache.stride( - 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x] v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), - v_cache.stride( - 3), #[num_blocks, num_kv_heads, head_size, block_size] + v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size] num_queries_per_kv=num_queries_per_kv, IN_PRECISION=IN_PRECISION, BLOCK_M=BLOCK, @@ -867,8 +964,7 @@ def context_attention_fwd(q, if current_platform.is_rocm(): extra_kargs = {"kpack": 1, "waves_per_eu": 2} - grid = lambda META: (batch, head, - triton.cdiv(max_input_len, META["BLOCK_M"])) + grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"])) _fwd_kernel[grid]( q, k, @@ -903,12 +999,11 @@ def context_attention_fwd(q, k_cache.stride(1), k_cache.stride(2), k_cache.stride(3), - k_cache.stride( - 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + k_cache.stride(4), # [num_blocks, num_kv_heads, head_size/x, block_size, x] v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), - v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + v_cache.stride(3), # [num_blocks, num_kv_heads, head_size, block_size] BLOCK_SIZE=v_cache.shape[3], num_queries_per_kv=num_queries_per_kv, IN_PRECISION=IN_PRECISION, @@ -924,5 +1019,6 @@ def context_attention_fwd(q, num_warps=4, num_stages=1, USE_SINKS=sinks is not None, - **extra_kargs) + **extra_kargs, + ) return diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index d91cda255ff3..c358b5971f86 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -9,18 +9,16 @@ from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer -def get_aiter_mla_metadata(max_batch_size: int, block_size: int, - max_block_per_batch: int, - device: torch.device) -> tuple[torch.Tensor, ...]: - 
paged_kv_indices = torch.zeros(max_batch_size * max_block_per_batch, - dtype=torch.int32, - device=device) - paged_kv_indptr = torch.zeros(max_batch_size + 1, - dtype=torch.int32, - device=device) - paged_kv_last_page_lens = torch.full((max_batch_size, ), - block_size, - dtype=torch.int32) +def get_aiter_mla_metadata( + max_batch_size: int, block_size: int, max_block_per_batch: int, device: torch.device +) -> tuple[torch.Tensor, ...]: + paged_kv_indices = torch.zeros( + max_batch_size * max_block_per_batch, dtype=torch.int32, device=device + ) + paged_kv_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int32, device=device) + paged_kv_last_page_lens = torch.full( + (max_batch_size,), block_size, dtype=torch.int32 + ) qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device) return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr @@ -37,18 +35,18 @@ def aiter_mla_decode_fwd( kv_last_page_lens: Optional[torch.Tensor] = None, logit_cap: float = 0.0, ): - - torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, - kv_buffer.view( - -1, 1, 1, q.shape[-1]), - o, - qo_indptr, - max_seqlen_qo, - kv_indptr, - kv_indices, - kv_last_page_lens, - sm_scale=sm_scale, - logit_cap=logit_cap) + torch.ops.vllm.rocm_aiter_mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + max_seqlen_qo, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) def mla_decode_fwd_impl( @@ -65,16 +63,18 @@ def mla_decode_fwd_impl( ) -> None: from aiter.mla import mla_decode_fwd - mla_decode_fwd(q, - kv_buffer.view(-1, 1, 1, q.shape[-1]), - o, - qo_indptr, - kv_indptr, - kv_indices, - kv_last_page_lens, - max_seqlen_qo, - sm_scale=sm_scale, - logit_cap=logit_cap) + mla_decode_fwd( + q, + kv_buffer.view(-1, 1, 1, q.shape[-1]), + o, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale=sm_scale, + logit_cap=logit_cap, + ) def mla_decode_fwd_fake( @@ -96,9 +96,11 @@ def mla_decode_fwd_fake( if is_torch_equal_or_newer("2.7.0"): tags = () else: - tags = (torch.Tag.needs_fixed_stride_order, ), - direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", - op_func=mla_decode_fwd_impl, - mutates_args=["o"], - fake_impl=mla_decode_fwd_fake, - tags=tags) + tags = ((torch.Tag.needs_fixed_stride_order,),) + direct_register_custom_op( + op_name="rocm_aiter_mla_decode_fwd", + op_func=mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=mla_decode_fwd_fake, + tags=tags, + ) diff --git a/vllm/attention/ops/rocm_aiter_paged_attn.py b/vllm/attention/ops/rocm_aiter_paged_attn.py index 2a0336de8cf7..069cfcaf00aa 100644 --- a/vllm/attention/ops/rocm_aiter_paged_attn.py +++ b/vllm/attention/ops/rocm_aiter_paged_attn.py @@ -13,7 +13,6 @@ class AITERPagedAttention(PagedAttention): - @staticmethod def write_to_paged_cache( key: torch.Tensor, @@ -26,19 +25,31 @@ def write_to_paged_cache( v_scale: torch.Tensor, ) -> None: if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]: - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, slot_mapping, - kv_cache_dtype, k_scale, - v_scale) + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) else: - kv_cache_torch_dtype = (FP8_DTYPE - if "fp8" in kv_cache_dtype else torch.int8) + kv_cache_torch_dtype = FP8_DTYPE if "fp8" in kv_cache_dtype else torch.int8 key_cache = key_cache.view(kv_cache_torch_dtype) value_cache = value_cache.view(kv_cache_torch_dtype) 
rocm_aiter.reshape_and_cache_with_pertoken_quant( - key, value, key_cache, value_cache, k_scale, v_scale, - slot_mapping.flatten(), True) + key, + value, + key_cache, + value_cache, + k_scale, + v_scale, + slot_mapping.flatten(), + True, + ) @staticmethod def forward_decode( @@ -78,7 +89,8 @@ def forward_decode( blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_vert_stride=blocksparse_vert_stride, blocksparse_block_size=blocksparse_block_size, - blocksparse_head_sliding_step=blocksparse_head_sliding_step) + blocksparse_head_sliding_step=blocksparse_head_sliding_step, + ) if "fp8" in kv_cache_dtype: key_cache = key_cache.view(current_platform.fp8_dtype()) @@ -87,16 +99,26 @@ def forward_decode( if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: # use blocksparse paged attention block_size = value_cache.size(-1) - assert (blocksparse_block_size > 0 and - blocksparse_block_size % block_size == 0), \ - (f"{blocksparse_block_size=} needs to be a multiple of" - f"{block_size=} used in block_tables.") + assert ( + blocksparse_block_size > 0 and blocksparse_block_size % block_size == 0 + ), ( + f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables." + ) output = torch.empty_like(query) block_size = value_cache.shape[3] max_num_blocks_per_seq = cdiv(max_seq_len, block_size) - rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables, - seq_lens, max_num_blocks_per_seq, k_scale, - v_scale, output) + rocm_aiter.pa_fwd_asm( + query, + key_cache, + value_cache, + block_tables, + seq_lens, + max_num_blocks_per_seq, + k_scale, + v_scale, + output, + ) return output diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py index 7f5a678615cf..aebc2e63cff6 100644 --- a/vllm/attention/ops/triton_decode_attention.py +++ b/vllm/attention/ops/triton_decode_attention.py @@ -42,10 +42,11 @@ # Only print the following warnings when triton version < 3.2.0. # The issue won't affect performance or accuracy. -if version.parse(triton.__version__) < version.parse('3.2.0'): +if version.parse(triton.__version__) < version.parse("3.2.0"): logger.warning( "The following error message 'operation scheduled before its operands' " - "can be ignored.") + "can be ignored." 
+ ) @triton.jit @@ -101,8 +102,7 @@ def _fwd_kernel_stage1( kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, - cur_batch_seq_len) + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) e_max = -float("inf") e_sum = 0.0 @@ -112,14 +112,18 @@ def _fwd_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + - offs_n // PAGE_SIZE, + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0, ) kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE - offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + - cur_kv_head * stride_buf_kh + offs_d[None, :]) + offs_buf_k = ( + kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) k = tl.load( K_Buffer + offs_buf_k, mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), @@ -133,8 +137,11 @@ def _fwd_kernel_stage1( qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) - offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + - cur_kv_head * stride_buf_vh + offs_dv[None, :]) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) v = tl.load( V_Buffer + offs_buf_v, mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), @@ -150,8 +157,12 @@ def _fwd_kernel_stage1( e_sum = e_sum * re_scale + tl.sum(p, 0) e_max = n_e_max - offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + - split_kv_id * stride_mid_os + offs_dv) + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) tl.store( Att_Out + offs_mid_o, @@ -159,8 +170,12 @@ def _fwd_kernel_stage1( mask=(mask_dv), ) - offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + - split_kv_id * stride_mid_os + Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv + ) tl.store( Att_Out + offs_mid_o_1, @@ -282,25 +297,22 @@ def _fwd_grouped_kernel_stage1( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_req_idx = cur_batch - offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[ - None, :] - q = tl.load(Q + offs_q, - mask=(mask_h[:, None]) & (mask_d[None, :]), - other=0.0) + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) mask_dpe = offs_dpe < Lk - off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + - offs_dpe[None, :]) - qpe = tl.load(Q + off_qpe, - mask=(mask_h[:, None]) & (mask_dpe[None, :]), - other=0.0) + off_qpe = ( + cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] + ) + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 + ) kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, - cur_batch_seq_len) + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) @@ -310,14 +322,18 @@ def 
_fwd_grouped_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_page_number = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + - offs_n // PAGE_SIZE, + Req_to_tokens + + stride_req_to_tokens_b * cur_batch_req_idx + + offs_n // PAGE_SIZE, mask=offs_n < split_kv_end, other=0, ) kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE - offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + - cur_kv_head * stride_buf_kh + offs_d[:, None]) + offs_buf_k = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) k = tl.load( K_Buffer + offs_buf_k, mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), @@ -325,13 +341,14 @@ def _fwd_grouped_kernel_stage1( ) qk = tl.dot(q, k.to(q.dtype)) if BLOCK_DPE > 0: - offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + - cur_kv_head * stride_buf_kh + - offs_dpe[:, None]) + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) kpe = tl.load( K_Buffer + offs_buf_kpe, - mask=(offs_n[None, :] < split_kv_end) & - (mask_dpe[:, None]), + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), other=0.0, ) qk += tl.dot(qpe, kpe.to(qpe.dtype)) @@ -340,11 +357,15 @@ def _fwd_grouped_kernel_stage1( if logit_cap > 0: qk = logit_cap * tanh(qk / logit_cap) - qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), - qk, float("-inf")) + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") + ) - offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + - cur_kv_head * stride_buf_vh + offs_dv[None, :]) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) v = tl.load( V_Buffer + offs_buf_v, mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), @@ -360,9 +381,12 @@ def _fwd_grouped_kernel_stage1( e_sum = e_sum * re_scale + tl.sum(p, 1) e_max = n_e_max - offs_mid_o = (cur_batch * stride_mid_ob + - cur_head[:, None] * stride_mid_oh + - split_kv_id * stride_mid_os + offs_dv[None, :]) + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] + ) tl.store( Att_Out + offs_mid_o, @@ -370,8 +394,12 @@ def _fwd_grouped_kernel_stage1( mask=(mask_h[:, None]) & (mask_dv[None, :]), ) - offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + - split_kv_id * stride_mid_os + Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv + ) tl.store( Att_Out + offs_mid_o_1, @@ -427,11 +455,7 @@ def _decode_grouped_att_m_fwd( if is_hip_: # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py - extra_kargs = { - "waves_per_eu": 1, - "matrix_instr_nonkdim": 16, - "kpack": 2 - } + extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} num_stages = 1 _fwd_grouped_kernel_stage1[grid]( @@ -504,13 +528,12 @@ def _fwd_kernel_stage2( for split_kv_id in range(0, NUM_KV_SPLITS): kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, - cur_batch_seq_len) + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) if split_kv_end > split_kv_start: - tv = 
tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, - mask=mask_d, - other=0.0) + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) n_e_max = tl.maximum(tlogic, e_max) @@ -553,11 +576,7 @@ def _decode_softmax_reducev_fwd( if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py - extra_kargs = { - "waves_per_eu": 4, - "matrix_instr_nonkdim": 16, - "kpack": 2 - } + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} grid = (batch, head_num) _fwd_kernel_stage2[grid]( @@ -606,8 +625,9 @@ def decode_attention_fwd_normal( page_size, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len, - num_kv_splits) + _decode_softmax_reducev_fwd( + attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits + ) def decode_attention_fwd_grouped( @@ -636,8 +656,9 @@ def decode_attention_fwd_grouped( page_size, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, lse, v_buffer, b_seq_len, - num_kv_splits) + _decode_softmax_reducev_fwd( + attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits + ) def decode_attention_fwd( diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 49070e4c7ae6..c0ab35d07b1f 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -55,16 +55,16 @@ def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): @triton.jit def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, - stride).to(tl.uint32) + rng_offsets = dropout_offsets( + philox_seed, philox_offset, dropout_p, m, n, stride + ).to(tl.uint32) # TODO: use tl.randint for better performance return tl.rand(philox_seed, rng_offsets) @triton.jit def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, - stride) + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) rng_keep = rng_output > dropout_p return rng_keep @@ -74,9 +74,9 @@ def load_fn(block_ptr, first, second, pad): if first and second: tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) else: tensor = tl.load(block_ptr) return tensor @@ -145,9 +145,7 @@ def _attn_fwd_inner( # if not is_modulo_mn. last step might get wasted but that is okay. # check if this masking works for that case. 
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], - actual_seqlen_k, - dtype=tl.int32) + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) @@ -160,8 +158,9 @@ def _attn_fwd_inner( if USE_FP8: qk *= qk_scale if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS - and (n_extra_tokens != 0), "zero") + bias = load_fn( + bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero" + ) # While bias is added after multiplying qk with sm_scale, our # optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. @@ -173,9 +172,12 @@ def _attn_fwd_inner( # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = (batch_philox_offset + - start_m * BLOCK_M * actual_seqlen_k + start_n - - BLOCK_N) + philox_offset = ( + batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + + start_n + - BLOCK_N + ) keep = dropout_mask( philox_seed, philox_offset, @@ -187,8 +189,7 @@ def _attn_fwd_inner( if RETURN_ENCODED_SOFTMAX: tl.store( encoded_softmax_block_ptr, - tl.where(keep, p, - -p).to(encoded_softmax_block_ptr.type.element_ty), + tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty), ) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: @@ -221,89 +222,57 @@ def _attn_fwd_inner( if bias_ptr is not None: bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, BLOCK_N)) + encoded_softmax_block_ptr = tl.advance( + encoded_softmax_block_ptr, (0, BLOCK_N) + ) return acc, l_i, m_i def get_cdna_autotune_configs(): return [ triton.Config( - { - 'BLOCK_M': 256, - 'BLOCK_N': 64, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 256, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 128, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_M': 256, - 'BLOCK_N': 128, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 256, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 64, - 'waves_per_eu': 1, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, num_stages=1, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 64, - 'waves_per_eu': 3, - 'PRE_LOAD_V': True - }, + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": True}, num_stages=1, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_M': 128, - 'BLOCK_N': 64, - 'waves_per_eu': 3, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False}, num_stages=1, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_M': 64, - 'BLOCK_N': 64, - 'waves_per_eu': 4, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 4, "PRE_LOAD_V": False}, num_stages=1, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 32, - 'waves_per_eu': 4, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, 
"BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}, num_stages=1, - num_warps=8), + num_warps=8, + ), # TODO: This config fails with head_size not pow2 with data mismatches. # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: # triton.Config( # { @@ -315,47 +284,31 @@ def get_cdna_autotune_configs(): # num_stages=1, # num_warps=4, # ), - ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + ], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"] def get_rdna_autotune_configs(): return [ triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 32, - 'waves_per_eu': 4, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 4, "PRE_LOAD_V": False}, num_stages=1, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 32, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 16, - 'waves_per_eu': 4, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 4, "PRE_LOAD_V": False}, num_stages=1, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_M': 32, - 'BLOCK_N': 16, - 'waves_per_eu': 2, - 'PRE_LOAD_V': False - }, + {"BLOCK_M": 32, "BLOCK_N": 16, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, - num_warps=2), + num_warps=2, + ), # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: # triton.Config( # { @@ -385,7 +338,7 @@ def get_rdna_autotune_configs(): # }, # num_stages=1, # num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + ], ["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL", "USE_FP8"] def get_autotune_configs(): @@ -501,15 +454,17 @@ def attn_fwd( # This captures the decrease in n_blocks if we have a rectangular attn # matrix n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N + ) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) # If we have no blocks after adjusting for seqlen deltas, this WG is # part of the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) + o_offset = ( + off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + ) O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(seqlen_q, BLOCK_DMODEL), @@ -545,8 +500,7 @@ def attn_fwd( padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL # Compute pointers for all the tensors used in this kernel. 
- q_offset = (off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm Q_block_ptr = tl.make_block_ptr( base=Q + q_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), @@ -555,8 +509,7 @@ def attn_fwd( block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0), ) - k_offset = (off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn K_block_ptr = tl.make_block_ptr( base=K + k_offset, shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), @@ -565,8 +518,7 @@ def attn_fwd( block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1), ) - v_offset = (off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk V_block_ptr = tl.make_block_ptr( base=V + v_offset, shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), @@ -587,9 +539,9 @@ def attn_fwd( else: bias_ptr = None if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base \ - + (off_z * HQ + off_h_q) \ - * seqlen_q * seqlen_k + batch_philox_offset = ( + philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k + ) else: batch_philox_offset = 0 # We can ask to return the dropout mask without actually doing any dropout. @@ -692,8 +644,9 @@ def attn_fwd( if bias_ptr is not None: bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, n_full_blocks)) + encoded_softmax_block_ptr = tl.advance( + encoded_softmax_block_ptr, (0, n_full_blocks) + ) acc, l_i, m_i = _attn_fwd_inner( acc, l_i, @@ -749,13 +702,12 @@ def attn_fwd( acc = acc.to(Out.type.element_ty) if IS_CAUSAL: # noqa: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) + out_mask_boundary = tl.full( + (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32 + ) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] - >= out_mask_boundary[None, :]) - z = tl.zeros((1, ), tl.float32) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = tl.zeros((1,), tl.float32) acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m @@ -772,8 +724,7 @@ def attn_fwd( # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), @@ -821,7 +772,6 @@ def check_args( class _attention(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -847,8 +797,7 @@ def forward( def check_and_convert(t, scale): if t.dtype != float8: descale = 1.0 / scale - ts = (t * descale).clamp(min=float8_info.min, - max=float8_info.max) + ts = (t * descale).clamp(min=float8_info.min, max=float8_info.max) return ts.to(float8) else: return t @@ -923,8 +872,7 @@ def check_and_convert(t, scale): bias_strides = (0, 0, 0, 0) p_descale = 1.0 / p_scale - o_descale = 1.0 / fp8_out_scale.item( - ) if fp8_out_scale is not None else 1.0 + o_descale = 1.0 / fp8_out_scale.item() if fp8_out_scale is not None else 1.0 arg_max_seqlens_q = 0 if 
on_gfx1x() else max_seqlens_q arg_max_seqlens_k = 0 if on_gfx1x() else max_seqlens_k diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py index 56d78ed5ea6e..d29f92f8cecb 100644 --- a/vllm/attention/ops/triton_merge_attn_states.py +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -61,8 +61,8 @@ def merge_attn_states_kernel( # If we see an inf assume FA2 and convert inf to -inf for consistency # and correctness. Inf generally doesn't make sense in this context outside # of undefined-behavior/FA2-case, so I think this a safe assumption. - p_lse = float('-inf') if p_lse == float('inf') else p_lse - s_lse = float('-inf') if s_lse == float('inf') else s_lse + p_lse = float("-inf") if p_lse == float("inf") else p_lse + s_lse = float("-inf") if s_lse == float("inf") else s_lse max_lse = tl.maximum(p_lse, s_lse) p_lse = p_lse - max_lse @@ -70,7 +70,7 @@ def merge_attn_states_kernel( # Will reuse precomputed Exp values for scale factor computation. p_se = tl.exp(p_lse) s_se = tl.exp(s_lse) - out_se = (p_se + s_se) + out_se = p_se + s_se if OUTPUT_LSE: out_lse = tl.log(out_se) + max_lse @@ -78,12 +78,20 @@ def merge_attn_states_kernel( head_arange = tl.arange(0, PADDED_HEAD_SIZE) head_mask = head_arange < HEAD_SIZE - p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - mask=head_mask) + p_out = tl.load( + prefix_output + + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + + head_arange, + mask=head_mask, + ) + s_out = tl.load( + suffix_output + + token_idx * num_heads * HEAD_SIZE + + head_idx * HEAD_SIZE + + head_arange, + mask=head_mask, + ) # NOTE(woosuk): Be careful with the numerical stability. # We should compute the scale first, and then multiply it with the output. 
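The NOTE above is the crux of merge_attn_states_kernel: each partial attention output is rescaled by exp(lse - max_lse) / sum_of_exps before the two are added, and the scale is computed first rather than dividing the summed output afterwards. Below is a minimal NumPy sketch of that per-(token, head) merge, for illustration only; merge_partial_attn is a hypothetical name and is not part of this patch.

import numpy as np

def merge_partial_attn(p_out, p_lse, s_out, s_lse):
    # Treat FA2's +inf LSE (fully masked block) as -inf, mirroring the kernel.
    p_lse = -np.inf if p_lse == np.inf else p_lse
    s_lse = -np.inf if s_lse == np.inf else s_lse
    max_lse = max(p_lse, s_lse)
    p_se = np.exp(p_lse - max_lse)
    s_se = np.exp(s_lse - max_lse)
    out_se = p_se + s_se
    # Compute the scales first, then apply them to the outputs
    # (the numerical-stability point made in the NOTE above).
    p_scale = p_se / out_se
    s_scale = s_se / out_se
    out = p_out * p_scale + s_out * s_scale
    out_lse = np.log(out_se) + max_lse
    return out, out_lse

The kernel in this file applies the same arithmetic per (token, head), vectorized over the padded head dimension, and stores out_lse only when OUTPUT_LSE is set.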
@@ -91,7 +99,8 @@ def merge_attn_states_kernel( p_scale = p_se / out_se s_scale = s_se / out_se out = p_out * p_scale + s_out * s_scale - tl.store(output + token_idx * num_heads * HEAD_SIZE + - head_idx * HEAD_SIZE + head_arange, - out, - mask=head_mask) + tl.store( + output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange, + out, + mask=head_mask, + ) diff --git a/vllm/attention/ops/triton_reshape_and_cache_flash.py b/vllm/attention/ops/triton_reshape_and_cache_flash.py index 0d82935bb418..bbcd560ad56e 100644 --- a/vllm/attention/ops/triton_reshape_and_cache_flash.py +++ b/vllm/attention/ops/triton_reshape_and_cache_flash.py @@ -29,7 +29,6 @@ def reshape_and_cache_kernel_flash( # tune parameters TILE_SIZE: tl.constexpr, ): - token_idx = tl.program_id(axis=0) slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64) if slot_idx < 0: @@ -49,21 +48,20 @@ def reshape_and_cache_kernel_flash( tgt_idx = block_idx * block_stride + block_offset * page_stride # [TILE_SIZE] - key_load = tl.load(key_ptr + src_key_idx + tile_pos, - mask=tile_pos < (num_heads * head_size)) + key_load = tl.load( + key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size) + ) if FP8_KV_CACHE: - if key_load.dtype.is_fp8(): - key_tile = key_load - else: - # tl.store will do the correct implicit cast to fp8, - # based on the key_cache_ptr.dtype.element_ty - key_tile = key_load / tl.load(k_scale) + # tl.store will do the correct implicit cast to fp8, + # based on the key_cache_ptr.dtype.element_ty + key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale) else: key_tile = key_load # [TILE_SIZE] - value_load = tl.load(value_ptr + src_value_idx + tile_pos, - mask=tile_pos < (num_heads * head_size)) + value_load = tl.load( + value_ptr + src_value_idx + tile_pos, mask=tile_pos < (num_heads * head_size) + ) if FP8_KV_CACHE: if value_load.dtype.is_fp8(): value_tile = value_load @@ -88,16 +86,16 @@ def reshape_and_cache_kernel_flash( def triton_reshape_and_cache_flash( - key: torch.Tensor, # [num_tokens, num_heads, head_size] - value: torch.Tensor, # [num_tokens, num_heads, head_size] - # [num_blocks, block_size, num_heads, head_size] + key: torch.Tensor, # [num_tokens, num_heads, head_size] + value: torch.Tensor, # [num_tokens, num_heads, head_size] + # [num_blocks, block_size, num_heads, head_size] key_cache: torch.Tensor, - # [num_blocks, block_size, num_heads, head_size] - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, # [num_tokens] - kv_cache_dtype: str, # "auto", "fp8" - k_scale: torch.Tensor, # float32 - v_scale: torch.Tensor, # float32 + # [num_blocks, block_size, num_heads, head_size] + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, # [num_tokens] + kv_cache_dtype: str, # "auto", "fp8" + k_scale: torch.Tensor, # float32 + v_scale: torch.Tensor, # float32 ): num_tokens = key.shape[0] num_heads = key.shape[1] @@ -113,27 +111,36 @@ def triton_reshape_and_cache_flash( head_stride = key_cache.stride()[2] assert head_stride == head_size, "only continous heads are supported" - assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), \ + assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." 
- kv_cache_torch_dtype = current_platform.fp8_dtype() if \ - kv_cache_dtype.startswith("fp8") else key_cache.dtype + ) + kv_cache_torch_dtype = ( + current_platform.fp8_dtype() + if kv_cache_dtype.startswith("fp8") + else key_cache.dtype + ) - if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith( - "fp8"): + if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): # to avoid erounous implicit cast in triton kernel (tl.store to uint8) # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) key_cache = key_cache.view(kv_cache_torch_dtype) value_cache = value_cache.view(kv_cache_torch_dtype) - assert kv_cache_dtype != torch.uint8, "explicit fp8 cast and store to "\ + assert kv_cache_dtype != torch.uint8, ( + "explicit fp8 cast and store to " "uint8 is not supported by triton reshape_and_cache_flash" + ) FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ - torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8, - torch.float8_e4m3fnuz], \ - "unsupported dtype of KV cache tensor, got "\ - "{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " \ - "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz." + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.uint8, + torch.float8_e4m3fnuz, + ], ( + "unsupported dtype of KV cache tensor, got " + "{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, " + "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz." + ) # heuristics instead of autotuning TILE_SIZE = min(2048, triton.next_power_of_2(n)) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 9e7cafc17428..565be1c39bec 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -31,8 +31,13 @@ def apply_softcap(S, x): @triton.jit -def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, - BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): +def find_seq_idx( + query_start_len_ptr, + target_idx, + num_seqs, + BLOCK_Q: tl.constexpr, + use_q_block_mode: tl.constexpr, +): left: tl.int32 = 0 right = num_seqs while left < right: @@ -100,19 +105,18 @@ def kernel_unified_attention_2d( q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) - seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, - BLOCK_Q, True) + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) - q_block_start_idx = tl.load(query_start_len_ptr + - seq_idx) // BLOCK_Q + seq_idx + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: return @@ -123,10 +127,12 @@ def kernel_unified_attention_2d( query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + \ - offs_m % num_queries_per_kv - query_offset = (query_offset_0[:, None] * query_stride_0 + - query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + query_offset_1 = kv_head_idx * 
num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) @@ -161,19 +167,24 @@ def kernel_unified_attention_2d( # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, - mask=query_mask_1, - other=0.0) + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) # query-query attention bias if USE_QQ_BIAS: - qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 - ) # shape: [BLOCK_M] + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] # compute the length of the longest sequence prefix spanned by any # query token in the current q_block (q_block_local_idx) - max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( - BLOCK_M - 1) // num_queries_per_kv + 1 + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) # adjust for potential padding in the last q_block by considering the # actual sequence length @@ -211,23 +222,30 @@ def kernel_unified_attention_2d( seq_offset = j * TILE_SIZE + offs_t tile_mask = seq_offset < max_seq_prefix_len - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + - seq_offset // BLOCK_SIZE).to(tl.int64) + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) - v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + - kv_head_idx * stride_v_cache_2 + - offs_d[None, :] * stride_v_cache_3 + - (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) - k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + - kv_head_idx * stride_k_cache_2 + - offs_d[:, None] * stride_k_cache_3 + - (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) # K : (HEAD_SIZE, TILE_SIZE) - K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], - other=0.0) + K_load = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + ) if K_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -238,9 +256,11 @@ def kernel_unified_attention_2d( K = K_load # V : (TILE_SIZE, HEAD_SIZE) - V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :] & tile_mask[:, None], - other=0.0) + V_load = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + ) if V_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -260,12 +280,16 @@ def kernel_unified_attention_2d( if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, - S, float("-inf")) + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) if SLIDING_WINDOW > 0: - S = tl.where((context_len + query_pos[:, None] - seq_offset) - < SLIDING_WINDOW, S, float("-inf")) + S = tl.where( + (context_len + 
query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -315,9 +339,11 @@ def kernel_unified_attention_2d( acc = acc * tl.load(out_scale) acc = tl.clamp(acc, FP8_MIN, FP8_MAX) - output_offset = (query_offset_0[:, None] * output_stride_0 + - query_offset_1[:, None] * output_stride_1 + - offs_d[None, :]) + output_offset = ( + query_offset_0[:, None] * output_stride_0 + + query_offset_1[:, None] * output_stride_1 + + offs_d[None, :] + ) tl.store( output_ptr + output_offset, @@ -328,68 +354,67 @@ def kernel_unified_attention_2d( @triton.jit def kernel_unified_attention_3d( - segm_output_ptr, - # [num_tokens, num_query_heads, num_segments, head_size] - segm_max_ptr, # [num_tokens, num_query_heads, num_segments] - segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] - query_ptr, # [num_tokens, num_query_heads, head_size] - key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] - value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] - sink_ptr, # [num_query_heads] - block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] - seq_lens_ptr, # [num_seqs] - alibi_slopes_ptr, # [num_query_heads] - qq_bias_ptr, # [num_query_tokens, num_query_tokens] - scale, # float32 - k_scale, # float32 - v_scale, # float32 - softcap, # float32 - num_query_heads: tl.constexpr, # int - num_queries_per_kv: tl.constexpr, # int - block_table_stride: tl.int64, # int - query_stride_0: tl.int64, # int - query_stride_1: tl.int64, # int, should be equal to head_size - qq_bias_stride_0: tl.int64, # int - BLOCK_SIZE: tl.constexpr, # int - TILE_SIZE: tl.constexpr, # int, must be power of 2 - HEAD_SIZE: tl.constexpr, # int - HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 - USE_ALIBI_SLOPES: tl.constexpr, # bool - USE_QQ_BIAS: tl.constexpr, # bool - USE_SOFTCAP: tl.constexpr, # bool - USE_SINKS: tl.constexpr, # bool - SLIDING_WINDOW: tl.constexpr, # int - stride_k_cache_0: tl.int64, # int - stride_k_cache_1: tl.int64, # int - stride_k_cache_2: tl.int64, # int - stride_k_cache_3: tl.constexpr, # int - stride_v_cache_0: tl.int64, # int - stride_v_cache_1: tl.int64, # int - stride_v_cache_2: tl.int64, # int - stride_v_cache_3: tl.constexpr, # int - query_start_len_ptr, # [num_seqs+1] - BLOCK_Q: tl.constexpr, # int - num_seqs: tl.int32, - BLOCK_M: tl.constexpr, # int - NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + segm_output_ptr, + # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x] + value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + scale, # float32 + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_query_heads: tl.constexpr, # int + num_queries_per_kv: tl.constexpr, # int + block_table_stride: tl.int64, # int + query_stride_0: tl.int64, # int + query_stride_1: tl.int64, # int, should be equal to head_size + qq_bias_stride_0: tl.int64, # int + BLOCK_SIZE: tl.constexpr, # int + TILE_SIZE: tl.constexpr, # int, must be power of 2 + HEAD_SIZE: tl.constexpr, # int + HEAD_SIZE_PADDED: tl.constexpr, # 
int, must be power of 2 + USE_ALIBI_SLOPES: tl.constexpr, # bool + USE_QQ_BIAS: tl.constexpr, # bool + USE_SOFTCAP: tl.constexpr, # bool + USE_SINKS: tl.constexpr, # bool + SLIDING_WINDOW: tl.constexpr, # int + stride_k_cache_0: tl.int64, # int + stride_k_cache_1: tl.int64, # int + stride_k_cache_2: tl.int64, # int + stride_k_cache_3: tl.constexpr, # int + stride_v_cache_0: tl.int64, # int + stride_v_cache_1: tl.int64, # int + stride_v_cache_2: tl.int64, # int + stride_v_cache_3: tl.constexpr, # int + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + num_seqs: tl.int32, + BLOCK_M: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) - seq_idx = find_seq_idx(query_start_len_ptr, q_block_global_idx, num_seqs, - BLOCK_Q, True) + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) - q_block_start_idx = tl.load(query_start_len_ptr + - seq_idx) // BLOCK_Q + seq_idx + q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx q_block_local_idx = q_block_global_idx - q_block_start_idx cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx) cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_index \ - - cur_batch_in_all_start_index + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: return @@ -410,10 +435,12 @@ def kernel_unified_attention_3d( query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos - query_offset_1 = kv_head_idx * num_queries_per_kv + \ - offs_m % num_queries_per_kv - query_offset = (query_offset_0[:, None] * query_stride_0 + - query_offset_1[:, None] * query_stride_1 + offs_d[None, :]) + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) @@ -448,19 +475,24 @@ def kernel_unified_attention_3d( # alibi slope for this head if USE_ALIBI_SLOPES: - alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1, - mask=query_mask_1, - other=0.0) + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1, mask=query_mask_1, other=0.0 + ) # query-query attention bias if USE_QQ_BIAS: - qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 - ) # shape: [BLOCK_M] + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] # compute the length of the longest sequence prefix spanned by any # query token in the current q_block (q_block_local_idx) - max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( - BLOCK_M - 1) // num_queries_per_kv + 1 + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) # adjust for potential padding in the last q_block by considering the # actual sequence length @@ -473,29 +505,36 @@ def kernel_unified_attention_3d( # iterate through tiles within current segment for j in range( - segm_idx * tiles_per_segment, - min((segm_idx + 1) * tiles_per_segment, num_tiles), + segm_idx * tiles_per_segment, + min((segm_idx + 1) * 
tiles_per_segment, num_tiles), ): seq_offset = j * TILE_SIZE + offs_t tile_mask = seq_offset < max_seq_prefix_len - physical_block_idx = tl.load(block_tables_ptr + block_table_offset + - seq_offset // BLOCK_SIZE).to(tl.int64) + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) - v_offset = (physical_block_idx[:, None] * stride_v_cache_0 + - kv_head_idx * stride_v_cache_2 + - offs_d[None, :] * stride_v_cache_3 + - (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1) + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) - k_offset = (physical_block_idx[None, :] * stride_k_cache_0 + - kv_head_idx * stride_k_cache_2 + - offs_d[:, None] * stride_k_cache_3 + - (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1) + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) # K : (HEAD_SIZE, TILE_SIZE) - K_load = tl.load(key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], - other=0.0) + K_load = tl.load( + key_cache_ptr + k_offset, + mask=dim_mask[:, None] & tile_mask[None, :], + other=0.0, + ) if K_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -506,9 +545,11 @@ def kernel_unified_attention_3d( K = K_load # V : (TILE_SIZE, HEAD_SIZE) - V_load = tl.load(value_cache_ptr + v_offset, - mask=dim_mask[None, :] & tile_mask[:, None], - other=0.0) + V_load = tl.load( + value_cache_ptr + v_offset, + mask=dim_mask[None, :] & tile_mask[:, None], + other=0.0, + ) if V_load.dtype.is_fp8(): if Q.dtype.is_fp8(): @@ -527,12 +568,16 @@ def kernel_unified_attention_3d( if USE_SOFTCAP: S = apply_softcap(S, softcap) - S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, - S, float("-inf")) + S = tl.where( + query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf") + ) if SLIDING_WINDOW > 0: - S = tl.where((context_len + query_pos[:, None] - seq_offset) - < SLIDING_WINDOW, S, float("-inf")) + S = tl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) if USE_ALIBI_SLOPES: S += alibi_slope[:, None] * (seq_offset - context_len) @@ -577,29 +622,31 @@ def kernel_unified_attention_3d( acc += tl.dot(P.to(V.dtype), V) segm_output_offset = ( - query_offset_0[:, None].to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - segm_idx * HEAD_SIZE_PADDED + tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + query_offset_0[:, None].to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + segm_idx * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) tl.store( segm_output_ptr + segm_output_offset, acc, mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], ) - segm_offset = (query_offset_0.to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ) + - query_offset_1 * NUM_SEGMENTS_PER_SEQ + segm_idx) + segm_offset = ( + query_offset_0.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + + segm_idx + ) tl.store(segm_max_ptr + segm_offset, M, mask=query_mask_0 & query_mask_1) - tl.store(segm_expsum_ptr + segm_offset, - L, - 
mask=query_mask_0 & query_mask_1) + tl.store(segm_expsum_ptr + segm_offset, L, mask=query_mask_0 & query_mask_1) @triton.jit def reduce_segments( output_ptr, # [num_tokens, num_query_heads, head_size] segm_output_ptr, - #[num_tokens, num_query_heads, max_num_segments, head_size] + # [num_tokens, num_query_heads, max_num_segments, head_size] segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] seq_lens_ptr, # [num_seqs] @@ -622,8 +669,9 @@ def reduce_segments( query_token_idx = tl.program_id(0) query_head_idx = tl.program_id(1) - seq_idx = find_seq_idx(query_start_len_ptr, query_token_idx, num_seqs, - BLOCK_Q, False) + seq_idx = find_seq_idx( + query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False + ) # sequence len for this particular sequence seq_len = tl.load(seq_lens_ptr + seq_idx) @@ -635,34 +683,32 @@ def reduce_segments( # create masks for subsequent loads act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( - [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32) - dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, - 0).to(tl.int1) + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32 + ) + dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1, 0).to(tl.int1) # load segment maxima - segm_offset = (query_token_idx.to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ) + - query_head_idx * NUM_SEGMENTS_PER_SEQ + - tl.arange(0, NUM_SEGMENTS_PER_SEQ)) - segm_max = tl.load(segm_max_ptr + segm_offset, - mask=segm_mask, - other=float("-inf")) + segm_offset = ( + query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ) + ) + segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf")) overall_max = tl.max(segm_max) # load and rescale segment exp sums - segm_expsum = tl.load(segm_expsum_ptr + segm_offset, - mask=segm_mask, - other=0.0) + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0) segm_expsum = segm_expsum * tl.exp(segm_max - overall_max) overall_expsum = tl.sum(segm_expsum) # load, rescale, and add segment attention outputs segm_output_offset = ( - query_token_idx.to(tl.int64) * - (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + - tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + - tl.arange(0, HEAD_SIZE_PADDED)[None, :]) + query_token_idx.to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE_PADDED) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE_PADDED + + tl.arange(0, HEAD_SIZE_PADDED)[None, :] + ) segm_output = tl.load( segm_output_ptr + segm_output_offset, mask=segm_mask[:, None] & dim_mask[None, :], @@ -678,9 +724,11 @@ def reduce_segments( acc = tl.clamp(acc, FP8_MIN, FP8_MAX) # write result - output_offset = (query_token_idx * output_stride_0 + - query_head_idx * output_stride_1 + - tl.arange(0, HEAD_SIZE_PADDED)) + output_offset = ( + query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE_PADDED) + ) tl.store(output_ptr + output_offset, acc, mask=dim_mask) @@ -707,13 +755,11 @@ def unified_attention( # Optional tensor for sinks sinks=None, ): - assert causal, "Only causal attention is supported" assert q_descale is 
None, "Q scales not supported" if sinks is not None: - assert sinks.shape[0] == q.shape[1], \ - "Sinks must be num_query_heads size" + assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size" use_alibi_slopes = alibi_slopes is not None use_qq_bias = qq_bias is not None @@ -725,8 +771,9 @@ def unified_attention( num_queries_per_kv = num_query_heads // num_kv_heads head_size = q.shape[2] - BLOCK_M = 16 if num_queries_per_kv <= 16 else triton.next_power_of_2( - num_queries_per_kv) + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + ) BLOCK_Q = BLOCK_M // num_queries_per_kv # Ideally we would launch with kernel with: @@ -748,10 +795,12 @@ def unified_attention( # if batch contains a prefill if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128: - kernel_unified_attention_2d[( - total_num_q_blocks, - num_kv_heads, - )]( + kernel_unified_attention_2d[ + ( + total_num_q_blocks, + num_kv_heads, + ) + ]( output_ptr=out, query_ptr=q, key_cache_ptr=k, @@ -825,52 +874,51 @@ def unified_attention( device=q.device, ) - kernel_unified_attention_3d[( - total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( - segm_output_ptr=segm_output, - segm_max_ptr=segm_max, - segm_expsum_ptr=segm_expsum, - query_ptr=q, - key_cache_ptr=k, - value_cache_ptr=v, - sink_ptr=sinks, - block_tables_ptr=block_table, - seq_lens_ptr=seqused_k, - alibi_slopes_ptr=alibi_slopes, - qq_bias_ptr=qq_bias, - scale=softmax_scale, - k_scale=k_descale, - v_scale=v_descale, - softcap=softcap, - num_query_heads=num_query_heads, - num_queries_per_kv=num_queries_per_kv, - block_table_stride=block_table.stride(0), - query_stride_0=q.stride(0), - query_stride_1=q.stride(1), - qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, - BLOCK_SIZE=block_size, - TILE_SIZE=TILE_SIZE_DECODE, - HEAD_SIZE=head_size, - HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), - USE_ALIBI_SLOPES=use_alibi_slopes, - USE_QQ_BIAS=use_qq_bias, - USE_SOFTCAP=(softcap > 0), - USE_SINKS=(sinks is not None), - SLIDING_WINDOW=(1 + window_size[0]), - stride_k_cache_0=k.stride(0), - stride_k_cache_1=k.stride(1), - stride_k_cache_2=k.stride(2), - stride_k_cache_3=k.stride(3), - stride_v_cache_0=v.stride(0), - stride_v_cache_1=v.stride(1), - stride_v_cache_2=v.stride(2), - stride_v_cache_3=v.stride(3), - query_start_len_ptr=cu_seqlens_q, - BLOCK_Q=BLOCK_Q, - num_seqs=num_seqs, - BLOCK_M=BLOCK_M, - NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, - ) + kernel_unified_attention_3d[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + scale=softmax_scale, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_query_heads=num_query_heads, + num_queries_per_kv=num_queries_per_kv, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + TILE_SIZE=TILE_SIZE_DECODE, + HEAD_SIZE=head_size, + HEAD_SIZE_PADDED=triton.next_power_of_2(head_size), + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=(1 + window_size[0]), + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + 
stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + num_seqs=num_seqs, + BLOCK_M=BLOCK_M, + NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, + ) reduce_segments[(q.shape[0], num_query_heads)]( output_ptr=out, segm_output_ptr=segm_output, @@ -879,8 +927,7 @@ def unified_attention( seq_lens_ptr=seqused_k, num_seqs=num_seqs, num_query_heads=num_query_heads, - out_scale_inv=1 / - output_scale if output_scale is not None else 1.0, + out_scale_inv=1 / output_scale if output_scale is not None else 1.0, output_stride_0=out.stride(0), output_stride_1=out.stride(1), block_table_stride=block_table.stride(0), diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index d3214fecfa70..effd35444d54 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass from functools import cache -from typing import Generator, Optional, Union +from typing import Optional, Union import torch @@ -29,12 +30,11 @@ def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: loaded. """ assert backend_name is not None - return _Backend[backend_name] if backend_name in _Backend.__members__ else \ - None + return _Backend[backend_name] if backend_name in _Backend.__members__ else None def get_env_variable_attn_backend() -> Optional[_Backend]: - ''' + """ Get the backend override specified by the vLLM attention backend environment variable, if one is specified. @@ -42,10 +42,9 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: * _Backend enum value if an override is specified * None otherwise - ''' + """ backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return (None - if backend_name is None else backend_name_to_enum(backend_name)) + return None if backend_name is None else backend_name_to_enum(backend_name) # Global state allows a particular choice of backend @@ -59,7 +58,7 @@ def get_env_variable_attn_backend() -> Optional[_Backend]: def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: - ''' + """ Force all attention operations to use a specified backend. Passing `None` for the argument re-enables automatic @@ -68,16 +67,16 @@ def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: Arguments: * attn_backend: backend selection (None to revert to auto) - ''' + """ global forced_attn_backend forced_attn_backend = attn_backend def get_global_forced_attn_backend() -> Optional[_Backend]: - ''' + """ Get the currently-forced choice of attention backend, or None if auto-selection is currently enabled. 
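The `reduce_segments` kernel reformatted earlier in this hunk merges the per-segment partial results written by `kernel_unified_attention_3d` with the usual log-sum-exp rescaling: each segment's unnormalized output and exp-sum are scaled by `exp(segm_max - overall_max)` before summing and normalizing. A minimal NumPy sketch of that merge, for reference only — the names and shapes are stand-ins for the Triton tensors, not vLLM APIs:

```python
# Illustrative merge of per-segment softmax partials (not the Triton kernel).
import numpy as np

def merge_segments(segm_out, segm_max, segm_expsum):
    """segm_out: [num_segments, head_size] value sums scaled by exp(score - segm_max[i]);
    segm_max: [num_segments] per-segment maxima;
    segm_expsum: [num_segments] per-segment sums of exp(score - segm_max[i])."""
    overall_max = segm_max.max()
    rescale = np.exp(segm_max - overall_max)        # bring every segment to a common max
    overall_expsum = (segm_expsum * rescale).sum()  # merged softmax denominator
    merged = (segm_out * rescale[:, None]).sum(axis=0)
    return merged / overall_expsum                  # normalized attention output

# Sanity check against a single-pass softmax-weighted sum.
rng = np.random.default_rng(0)
scores = rng.normal(size=12)                        # attention logits for one (token, head)
values = rng.normal(size=(12, 4))                   # matching value vectors
weights = np.exp(scores - scores.max())
ref = (weights / weights.sum()) @ values

segments = np.split(np.arange(12), 3)
segm_max = np.array([scores[s].max() for s in segments])
segm_expsum = np.array([np.exp(scores[s] - m).sum() for s, m in zip(segments, segm_max)])
segm_out = np.stack([np.exp(scores[s] - m) @ values[s] for s, m in zip(segments, segm_max)])
assert np.allclose(merge_segments(segm_out, segm_max, segm_expsum), ref)
```

Splitting the key/value range into segments this way lets decode-heavy batches parallelize across segments while still recovering the exact softmax-weighted output.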
- ''' + """ return forced_attn_backend @@ -110,26 +109,27 @@ def is_attn_backend_supported( assert isinstance(attn_backend, type) # TODO: Update the interface once V0 is removed - if get_supported_head_sizes := getattr(attn_backend, - "get_supported_head_sizes", None): + if get_supported_head_sizes := getattr( + attn_backend, "get_supported_head_sizes", None + ): is_head_size_supported = head_size in get_supported_head_sizes() - elif validate_head_size := getattr(attn_backend, "validate_head_size", - None): + elif validate_head_size := getattr(attn_backend, "validate_head_size", None): try: validate_head_size(head_size) is_head_size_supported = True except Exception: is_head_size_supported = False else: - raise NotImplementedError(f"{attn_backend.__name__} does not support " - "head size validation") + raise NotImplementedError( + f"{attn_backend.__name__} does not support head size validation" + ) - if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", - None): + if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", None): is_dtype_supported = dtype in get_supported_dtypes() else: - raise NotImplementedError(f"{attn_backend.__name__} does not support " - "dtype validation") + raise NotImplementedError( + f"{attn_backend.__name__} does not support dtype validation" + ) return _IsSupported( can_import=True, @@ -175,15 +175,13 @@ def _cached_get_attn_backend( has_sink: bool = False, use_sparse: bool = False, ) -> type[AttentionBackend]: - # Check whether a particular choice of backend was # previously forced. # # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND # ENVIRONMENT VARIABLE. selected_backend = None - backend_by_global_setting: Optional[_Backend] = ( - get_global_forced_attn_backend()) + backend_by_global_setting: Optional[_Backend] = get_global_forced_attn_backend() if backend_by_global_setting is not None: selected_backend = backend_by_global_setting else: @@ -195,29 +193,41 @@ def _cached_get_attn_backend( "The suffix '_VLLM_V1' in the environment variable " "%s is no longer necessary as V0 backends have been " "deprecated. Please remove this suffix from your " - "environment variable setting.", STR_BACKEND_ENV_VAR) - backend_by_env_var = backend_by_env_var.removesuffix( - "_VLLM_V1") + "environment variable setting.", + STR_BACKEND_ENV_VAR, + ) + backend_by_env_var = backend_by_env_var.removesuffix("_VLLM_V1") selected_backend = backend_name_to_enum(backend_by_env_var) if selected_backend is None: raise ValueError( f"Invalid attention backend: '{backend_by_env_var}'. 
" - f"Valid backends are: {list(_Backend.__members__.keys())}") + f"Valid backends are: {list(_Backend.__members__.keys())}" + ) # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( - selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, - use_mla, has_sink, use_sparse) + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + ) if not attention_cls: raise ValueError( - f"Invalid attention backend for {current_platform.device_name}") + f"Invalid attention backend for {current_platform.device_name}" + ) return resolve_obj_by_qualname(attention_cls) @contextmanager def global_force_attn_backend_context_manager( - attn_backend: _Backend) -> Generator[None, None, None]: - ''' + attn_backend: _Backend, +) -> Generator[None, None, None]: + """ Globally force a vLLM attention backend override within a context manager, reverting the global attention backend override to its prior state upon exiting the context @@ -230,7 +240,7 @@ def global_force_attn_backend_context_manager( Returns: * Generator - ''' + """ # Save the current state of the global backend override (if any) original_value = get_global_forced_attn_backend() diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index dc0af7e28e3e..e13afd46ee96 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -10,11 +10,12 @@ if current_platform.is_cuda(): from vllm import _custom_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash - from vllm.vllm_flash_attn import (flash_attn_varlen_func, - get_scheduler_metadata) + from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops as ops + reshape_and_cache_flash = ops.reshape_and_cache_flash flash_attn_varlen_func = ops.flash_attn_varlen_func get_scheduler_metadata = ops.get_scheduler_metadata @@ -23,18 +24,23 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: # import here to avoid circular dependencies from vllm.platforms import current_platform + if current_platform.is_xpu(): return 2 try: from vllm.vllm_flash_attn.flash_attn_interface import ( - fa_version_unsupported_reason, is_fa_version_supported) + fa_version_unsupported_reason, + is_fa_version_supported, + ) + device_capability = current_platform.get_device_capability() assert device_capability is not None # 1. default version depending on platform - fa_version = 3 if (device_capability.major == 9 - and is_fa_version_supported(3)) else 2 + fa_version = ( + 3 if (device_capability.major == 9 and is_fa_version_supported(3)) else 2 + ) # 2. override if passed by environment if envs.VLLM_FLASH_ATTN_VERSION is not None: @@ -45,17 +51,22 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: if device_capability.major == 10 and fa_version == 3: logger.warning_once( "Cannot use FA version 3 on Blackwell platform " - "defaulting to FA version 2.") + "defaulting to FA version 2." + ) fa_version = 2 if requires_alibi and fa_version == 3: - logger.warning_once("Cannot use FA version 3 with ALiBi, " - "defaulting to FA version 2.") + logger.warning_once( + "Cannot use FA version 3 with ALiBi, defaulting to FA version 2." 
+ ) fa_version = 2 if not is_fa_version_supported(fa_version): - logger.error("Cannot use FA version %d is not supported due to %s", - fa_version, fa_version_unsupported_reason(fa_version)) + logger.error( + "Cannot use FA version %d is not supported due to %s", + fa_version, + fa_version_unsupported_reason(fa_version), + ) assert is_fa_version_supported(fa_version) return fa_version @@ -64,18 +75,25 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: def flash_attn_supports_fp8() -> bool: - return get_flash_attn_version() == 3 and \ - current_platform.get_device_capability().major == 9 + return ( + get_flash_attn_version() == 3 + and current_platform.get_device_capability().major == 9 + ) def flash_attn_supports_mla(): from vllm.platforms import current_platform + if current_platform.is_cuda(): try: from vllm.vllm_flash_attn.flash_attn_interface import ( - is_fa_version_supported) - return is_fa_version_supported(3) \ + is_fa_version_supported, + ) + + return ( + is_fa_version_supported(3) and current_platform.get_device_capability()[0] == 9 + ) except (ImportError, AssertionError): pass return False diff --git a/vllm/attention/utils/kv_sharing_utils.py b/vllm/attention/utils/kv_sharing_utils.py index b4ae8bdf4d76..93af5bf7e13f 100644 --- a/vllm/attention/utils/kv_sharing_utils.py +++ b/vllm/attention/utils/kv_sharing_utils.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -def validate_kv_sharing_target(current_layer_name, target_layer_name, - static_forward_context): - error_msg = (f"Specified KV sharing target layer for {current_layer_name} " - f"is not valid: target layer {target_layer_name} ") +def validate_kv_sharing_target( + current_layer_name, target_layer_name, static_forward_context +): + error_msg = ( + f"Specified KV sharing target layer for {current_layer_name} " + f"is not valid: target layer {target_layer_name} " + ) if current_layer_name == target_layer_name: - raise ValueError(error_msg + - "cannot be the same as the current layer.") + raise ValueError(error_msg + "cannot be the same as the current layer.") if target_layer_name not in static_forward_context: from vllm.model_executor.models.utils import extract_layer_index @@ -20,14 +22,12 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name, if current_layer_idx <= target_layer_idx: raise ValueError(error_msg + "must come before the current layer.") else: - raise ValueError(error_msg + - "is not a valid Attention layer in the model.") + raise ValueError(error_msg + "is not a valid Attention layer in the model.") # Currently KV sharing is only supported between layers of the same type - target_layer_attn_type = static_forward_context[ - target_layer_name].attn_type + target_layer_attn_type = static_forward_context[target_layer_name].attn_type expected = static_forward_context[current_layer_name].attn_type if target_layer_attn_type != expected: raise ValueError( - error_msg + - f"must be the same type as the current layer ({expected}).") + error_msg + f"must be the same type as the current layer ({expected})." + ) diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 01124872e98c..e0ba863b9210 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -18,6 +18,7 @@ class BeamSearchSequence: The text field is optional and will only be filled when the sequence is about to be returned to the user. """ + # The tokens include the prompt. 
tokens: list[int] logprobs: list[dict[int, Logprob]] @@ -36,11 +37,11 @@ class BeamSearchOutput: It contains the list of the best beam search sequences. The length of the list is equal to the beam width. """ + sequences: list[BeamSearchSequence] class BeamSearchInstance: - def __init__( self, prompt_tokens: list[int], @@ -79,9 +80,9 @@ def get_beam_search_score( def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): - def sort_beams_key(x: BeamSearchSequence) -> float: - return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, - length_penalty) + return get_beam_search_score( + x.tokens, x.cum_logprob, eos_token_id, length_penalty + ) return sort_beams_key diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index bf0defc24542..e955b15e87fe 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -11,6 +11,7 @@ - HuggingFace - VisionArena """ + import argparse import ast import base64 @@ -77,9 +78,7 @@ class SampleRequest: prompt: Union[str, list[str]] prompt_len: int expected_output_len: int - multi_modal_data: Optional[ - Union[MultiModalDataDict, dict, list[dict]] - ] = None + multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None lora_request: Optional[LoRARequest] = None request_id: Optional[str] = None @@ -97,6 +96,8 @@ def __init__( self, dataset_path: Optional[str] = None, random_seed: int = DEFAULT_SEED, + disable_shuffle: bool = False, + **kwargs, ) -> None: """ Initialize the BenchmarkDataset with an optional dataset path and random @@ -111,16 +112,15 @@ def __init__( self.dataset_path = dataset_path # Set the random seed, ensuring that a None value is replaced with the # default seed. - self.random_seed = (random_seed - if random_seed is not None else self.DEFAULT_SEED) + self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED + self.disable_shuffle = disable_shuffle self.data = None def apply_multimodal_chat_transformation( - self, - prompt: str, - mm_content: Optional[ - Union[MultiModalDataDict, dict, list[dict]] - ] = None) -> list[dict]: + self, + prompt: str, + mm_content: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None, + ) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. This method is used for chat models that expect a specific conversation @@ -134,8 +134,8 @@ def apply_multimodal_chat_transformation( content.append(mm_content) else: raise TypeError( - "Could not process multimodal content of type: " + - f"{type(mm_content)}" + "Could not process multimodal content of type: " + + f"{type(mm_content)}" ) return [{"role": "user", "content": content}] @@ -150,8 +150,7 @@ def load_data(self) -> None: NotImplementedError: If a subclass does not implement this method. """ # TODO (jenniferzhao): add support for downloading data - raise NotImplementedError( - "load_data must be implemented in subclasses.") + raise NotImplementedError("load_data must be implemented in subclasses.") def get_random_lora_request( self, @@ -187,10 +186,13 @@ def get_random_lora_request( return lora_request @abstractmethod - def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int, - request_id_prefix: str = "", - no_oversample: bool = False) -> list[SampleRequest]: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, + ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. 
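`BenchmarkDataset.__init__` now accepts `disable_shuffle`, and the loaders later in this diff guard their `random.shuffle` calls with `getattr(self, "disable_shuffle", False)`. A minimal sketch of that pattern, using a toy dataset class rather than the real loaders:

```python
# Sketch of the shuffle guard the loaders adopt in this diff; the class and data
# here are stand-ins, only the pattern mirrors the real code.
import random

class ToyDataset:
    def __init__(self, random_seed: int = 0, disable_shuffle: bool = False) -> None:
        self.random_seed = random_seed
        self.disable_shuffle = disable_shuffle
        self.data = list(range(10))

    def load_data(self) -> None:
        random.seed(self.random_seed)
        # getattr keeps older subclasses working even if they never set the flag.
        if not getattr(self, "disable_shuffle", False):
            random.shuffle(self.data)

ordered = ToyDataset(disable_shuffle=True)
ordered.load_data()
assert ordered.data == list(range(10))  # deterministic ordering preserved
```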
@@ -229,8 +231,7 @@ def maybe_oversample_requests( """ if no_oversample: - logger.info("Skipping oversampling. " \ - "Total samples: %d.", len(requests)) + logger.info("Skipping oversampling. Total samples: %d.", len(requests)) return if len(requests) < num_requests: @@ -242,14 +243,15 @@ def maybe_oversample_requests( req.request_id = request_id_prefix + str(len(requests) + i) additional.append(req) requests.extend(additional) - logger.info("Oversampled requests to reach %d total samples.", - num_requests) + logger.info("Oversampled requests to reach %d total samples.", num_requests) ids = [req.request_id for req in requests] if len(ids) != len(set(ids)): - raise ValueError("Duplicate request_id found in the sampled " - "requests. Please ensure that each request_id " - "is unique.") + raise ValueError( + "Duplicate request_id found in the sampled " + "requests. Please ensure that each request_id " + "is unique." + ) # ----------------------------------------------------------------------------- @@ -274,14 +276,14 @@ def is_valid_sequence( """ # Check for invalid conditions prompt_too_short = prompt_len < min_len - output_too_short = (not skip_min_output_len_check) and (output_len - < min_len) + output_too_short = (not skip_min_output_len_check) and (output_len < min_len) prompt_too_long = prompt_len > max_prompt_len combined_too_long = (prompt_len + output_len) > max_total_len # Return True if none of the invalid conditions are met - return not (prompt_too_short or output_too_short or prompt_too_long - or combined_too_long) + return not ( + prompt_too_short or output_too_short or prompt_too_long or combined_too_long + ) @cache @@ -313,28 +315,30 @@ def process_image(image: Any) -> Mapping[str, Any]: Raises: ValueError: If the input is not a supported type. """ - if isinstance(image, dict) and 'bytes' in image: - image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, dict) and "bytes" in image: + image = Image.open(BytesIO(image["bytes"])) if isinstance(image, Image.Image): image = convert_image_mode(image, "RGB") with io.BytesIO() as image_data: image.save(image_data, format="JPEG") - image_base64 = base64.b64encode( - image_data.getvalue()).decode("utf-8") + image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") return { "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - }, + "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, } if isinstance(image, str): - image_url = (image if image.startswith( - ("http://", "https://", "file://")) else f"file://{image}") + image_url = ( + image + if image.startswith(("http://", "https://", "file://")) + else f"file://{image}" + ) return {"type": "image_url", "image_url": {"url": image_url}} - raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" - " or str or dictionary with raw image bytes.") + raise ValueError( + f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes." + ) def process_video(video: Any) -> Mapping[str, Any]: @@ -353,19 +357,20 @@ def process_video(video: Any) -> Mapping[str, Any]: Raises: ValueError: If the input is not a supported type. 
""" - if isinstance(video, dict) and 'bytes' in video: - video_bytes = video['bytes'] + if isinstance(video, dict) and "bytes" in video: + video_bytes = video["bytes"] video_base64 = base64.b64encode(video_bytes).decode("utf-8") return { "type": "video_url", - "video_url": { - "url": f"data:video/mp4;base64,{video_base64}" - }, + "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, } if isinstance(video, str): - video_url = (video if video.startswith( - ("http://", "https://", "file://")) else f"file://{video}") + video_url = ( + video + if video.startswith(("http://", "https://", "file://")) + else f"file://{video}" + ) return {"type": "video_url", "video_url": {"url": video_url}} raise ValueError( @@ -385,8 +390,8 @@ def gen_prompt_decode_to_target_len( Ensure decoded-then-encoded prompt length matches the target token length. This function decodes an initial token sequence to text and re-encodes it - , iteratively adjusting the token sequence length to match a target. - This is necessary because some tokenizers do not guarantee a 1:1 mapping + , iteratively adjusting the token sequence length to match a target. + This is necessary because some tokenizers do not guarantee a 1:1 mapping between consecutive tokens and the decoded-then-encoded sequence length. For example, for GPT2Tokenizer: [6880, 6881] -> ['Ġcalls', 'here'] -> @@ -398,14 +403,12 @@ def gen_prompt_decode_to_target_len( token_mismatch = 0 while True: prompt = tokenizer.decode(token_sequence) - token_sequence = tokenizer.encode( - prompt, add_special_tokens=add_special_tokens - ) + token_sequence = tokenizer.encode(prompt, add_special_tokens=add_special_tokens) if remain_num_try <= 0: if len(token_sequence) != target_token_len: token_mismatch = len(token_sequence) - target_token_len break - + if len(token_sequence) == target_token_len: break elif len(token_sequence) < target_token_len: @@ -429,10 +432,12 @@ def gen_prompt_decode_to_target_len( return prompt, token_sequence, token_mismatch + # ----------------------------------------------------------------------------- # Random Dataset Implementation (Synthetic Data) # ----------------------------------------------------------------------------- + class RandomDataset(BenchmarkDataset): """ Synthetic text-only dataset for serving/throughput benchmarks. @@ -446,6 +451,7 @@ class RandomDataset(BenchmarkDataset): - Decode then re-encode/truncate to ensure prompt token counts match. - Uses numpy.default_rng seeded with random_seed for reproducible sampling. """ + # Default values copied from benchmark_serving.py for the random dataset. DEFAULT_PREFIX_LEN = 0 DEFAULT_RANGE_RATIO = 0.0 @@ -472,7 +478,6 @@ def sample( batchsize: int = 1, **kwargs, ) -> list[SampleRequest]: - input_lens, output_lens, offsets = self.get_sampling_params( num_requests, range_ratio, input_len, output_len, tokenizer ) @@ -484,7 +489,7 @@ def sample( requests = [] token_mismatch_total = 0 for i in range(num_requests): - prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501 + prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501 tokenizer=tokenizer, prefix_token_ids=prefix_token_ids, prefix_len=prefix_len, @@ -517,7 +522,7 @@ def sample( ) ) requests = batch_requests - + if token_mismatch_total != 0: sign = "more" if token_mismatch_total > 0 else "fewer" logger.warning( @@ -538,8 +543,7 @@ def get_prefix( Get the prefix for the dataset. 
""" return ( - self._rng.integers( - 0, tokenizer.vocab_size, size=prefix_len).tolist() + self._rng.integers(0, tokenizer.vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [] ) @@ -571,8 +575,7 @@ def get_sampling_params( if input_low > input_high: raise ValueError( - "Invalid input sampling interval: " - f"low={input_low} > high={input_high}" + f"Invalid input sampling interval: low={input_low} > high={input_high}" ) if output_low > output_high: raise ValueError( @@ -588,12 +591,9 @@ def get_sampling_params( output_high, ) - input_lens = self._rng.integers(input_low, input_high + 1, - size=num_requests) - output_lens = self._rng.integers(output_low, output_high + 1, - size=num_requests) - offsets = self._rng.integers(0, tokenizer.vocab_size, - size=num_requests) + input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests) + output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests) + offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests) return input_lens, output_lens, offsets def generate_token_sequence( @@ -620,18 +620,19 @@ def generate_token_sequence( the encoded sequence is truncated before being decoded again. """ # Build the inner sequence by sampling sequentially from the vocab - inner_seq = ((offset + index + np.arange(input_len)) - % vocab_size).tolist() + inner_seq = ((offset + index + np.arange(input_len)) % vocab_size).tolist() token_sequence = prefix_token_ids + inner_seq # Decode, then re-encode and truncate to preserve token count invariants total_input_len = prefix_len + int(input_len) - prompt, adjusted_token_sequence, token_mismatch = gen_prompt_decode_to_target_len( # noqa: E501 - tokenizer=tokenizer, - token_sequence=token_sequence, - target_token_len=total_input_len, - add_special_tokens=False, - rng=self._rng, + prompt, adjusted_token_sequence, token_mismatch = ( + gen_prompt_decode_to_target_len( + tokenizer=tokenizer, + token_sequence=token_sequence, + target_token_len=total_input_len, + add_special_tokens=False, + rng=self._rng, + ) ) total_input_len = len(adjusted_token_sequence) return prompt, total_input_len, token_mismatch @@ -641,6 +642,7 @@ def generate_token_sequence( # MultiModalDataset Implementation # ----------------------------------------------------------------------------- + class RandomMultiModalDataset(RandomDataset): """ Synthetic multimodal dataset (text + images) that extends RandomDataset. @@ -687,7 +689,6 @@ class RandomMultiModalDataset(RandomDataset): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - def generate_synthetic_image(self, width: int, height: int) -> Image.Image: """Generate synthetic PIL image with random RGB values. @@ -704,9 +705,7 @@ def generate_synthetic_image(self, width: int, height: int) -> Image.Image: ) return Image.fromarray(random_pixels) - def generate_synthetic_video(self, width: int, - height: int, - num_frames: int) -> Any: + def generate_synthetic_video(self, width: int, height: int, num_frames: int) -> Any: """Generate synthetic video with random values. TODO: Finish this method. 
@@ -722,8 +721,9 @@ def map_config_to_modality(self, config: tuple[int, int, int]) -> str: else: raise ValueError(f"Invalid multimodal item configuration: {config}") - def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], - float]) -> dict[tuple[int, int, int], float]: + def normalize_bucket_config( + self, bucket_config: dict[tuple[int, int, int], float] + ) -> dict[tuple[int, int, int], float]: """ Remove zero probability entries and normalize the bucket config to sum to 1. @@ -735,16 +735,17 @@ def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], bucket_config = {k: v for k, v in bucket_config.items() if v > 0} # if bucket config is empty, raise error if not bucket_config: - raise ValueError("Got invalid bucket config. " - "Bucket config values must be non-zero.") + raise ValueError( + "Got invalid bucket config. Bucket config values must be non-zero." + ) # Normalize the remaining bucket config to sum to 1 total = sum(bucket_config.values()) return {k: v / total for k, v in bucket_config.items()} - - def generate_mm_item(self, - mm_item_config: tuple[int, int, int], - ) -> Mapping[str, Any]: + def generate_mm_item( + self, + mm_item_config: tuple[int, int, int], + ) -> Mapping[str, Any]: """ Create synthetic images and videos and apply process_image/process_video respectively. @@ -753,18 +754,17 @@ def generate_mm_item(self, """ if self.map_config_to_modality(mm_item_config) == "image": - return process_image(self.generate_synthetic_image( - mm_item_config[1], - mm_item_config[0])) + return process_image( + self.generate_synthetic_image(mm_item_config[1], mm_item_config[0]) + ) elif self.map_config_to_modality(mm_item_config) == "video": - return process_video(self.generate_synthetic_video( - mm_item_config[1], - mm_item_config[0], - mm_item_config[2])) + return process_video( + self.generate_synthetic_video( + mm_item_config[1], mm_item_config[0], mm_item_config[2] + ) + ) else: - raise ValueError(f"Invalid multimodal item configuration: " - f"{mm_item_config}") - + raise ValueError(f"Invalid multimodal item configuration: {mm_item_config}") def get_mm_item_sampling_params( self, @@ -785,49 +785,53 @@ def get_mm_item_sampling_params( # get modality from bucket config modality = self.map_config_to_modality(k) if modality not in limit_mm_per_prompt: - raise ValueError(f"Modality {modality} is not in " - f"limit_mm_per_prompt: " - f"{limit_mm_per_prompt.keys()}") + raise ValueError( + f"Modality {modality} is not in " + f"limit_mm_per_prompt: " + f"{limit_mm_per_prompt.keys()}" + ) # Remove zero probability entries # and normalize bucket config to sum to 1 bucket_config = self.normalize_bucket_config(bucket_config) logger.info( - "Normalized bucket config: %s", bucket_config, + "Normalized bucket config: %s", + bucket_config, ) # Only consider limit per prompt for modalities in bucket config - allowed_modalities = {self.map_config_to_modality(cfg) - for cfg in bucket_config} + allowed_modalities = {self.map_config_to_modality(cfg) for cfg in bucket_config} limit_mm_per_prompt = { - k: v for k, v in limit_mm_per_prompt.items() - if k in allowed_modalities} + k: v for k, v in limit_mm_per_prompt.items() if k in allowed_modalities + } if not limit_mm_per_prompt: - raise ValueError("No valid limits for modalities present in " - "bucket_config.") + raise ValueError("No valid limits for modalities present in bucket_config.") logger.info( - "Updated mm-limit-per-prompt: %s", limit_mm_per_prompt, + "Updated mm-limit-per-prompt: %s", + 
limit_mm_per_prompt, ) # Get max and min num mm items and ensure # it is at most the sum of limit_mm_per_prompt for all modalities max_num_mm_items = min( sum(limit_mm_per_prompt.values()), - math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)) + math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)), ) # Ensure min num mm items is at least 0 min_num_mm_items = max( - 0, - math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) + 0, math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) ) # Raise error if min num mm items is greater than max num mm items if min_num_mm_items > max_num_mm_items: - raise ValueError(f"Min num mm items is greater than max mm items: " - f"{min_num_mm_items} > {max_num_mm_items}") + raise ValueError( + f"Min num mm items is greater than max mm items: " + f"{min_num_mm_items} > {max_num_mm_items}" + ) logger.info( "Sampling number of multimodal items from [%s, %s]", - min_num_mm_items, max_num_mm_items, + min_num_mm_items, + max_num_mm_items, ) return ( @@ -843,7 +847,7 @@ def get_mm_item_iterator( max_num_mm_items: int, bucket_config: dict[tuple[int, int, int], float], limit_mm_per_prompt: dict[str, int], - ) -> Iterator[tuple[int,int, int]]: + ) -> Iterator[tuple[int, int, int]]: """ Iterator over the multimodal items for each request whose size is between min_num_mm_items and max_num_mm_items. @@ -867,22 +871,20 @@ def get_mm_item_iterator( if request_num_mm_items == 0: return # Initialize modality counters - modality_counter = {self.map_config_to_modality(k): 0 - for k in bucket_config} + modality_counter = {self.map_config_to_modality(k): 0 for k in bucket_config} # Copy the bucket config to avoid modifying the original bucket_config_copy = bucket_config.copy() # Loop over the number of multimodal items to sample while sum(modality_counter.values()) < request_num_mm_items: # Sample a multimodal item config - mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), - p=list(bucket_config_copy.values())) + mm_item_config = self._rng.choice( + list(bucket_config_copy.keys()), p=list(bucket_config_copy.values()) + ) modality = self.map_config_to_modality(mm_item_config) # Check that modality count is less than limit per prompt if modality_counter[modality] < limit_mm_per_prompt[modality]: modality_counter[modality] += 1 - yield ( - mm_item_config - ) + yield (mm_item_config) else: # If the counter is greater than the limit per prompt # set all multimodal items of this modality to 0 @@ -893,14 +895,12 @@ def get_mm_item_iterator( # This should not happen as request_num_mm_items is at most # the sum of limit_mm_per_prompt for all modalities if all(v == 0 for v in bucket_config_copy.values()): - logger.warning("Exhausted all multimodal items " - "of modality %s", - modality) + logger.warning( + "Exhausted all multimodal items of modality %s", modality + ) break # Renormalize the bucket config - bucket_config_copy = self.normalize_bucket_config( - bucket_config_copy) - + bucket_config_copy = self.normalize_bucket_config(bucket_config_copy) def sample( self, @@ -915,18 +915,21 @@ def sample( limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, - bucket_config: dict[tuple[int, int, int], float] = - DEFAULT_MM_ITEM_BUCKET_CONFIG, + bucket_config: dict[ + tuple[int, int, int], float + ] = DEFAULT_MM_ITEM_BUCKET_CONFIG, enable_multimodal_chat: bool = 
DEFAULT_ENABLE_MULTIMODAL_CHAT, **kwargs, ) -> list[SampleRequest]: - # NOTE: Video sampling is WIP. Raise error if video is in bucket config # and probability is non-zero. - if any(self.map_config_to_modality(cfg) == "video" and p > 0 - for cfg, p in bucket_config.items()): - raise NotImplementedError("Video sampling not implemented; " - "set its probability to 0.") + if any( + self.map_config_to_modality(cfg) == "video" and p > 0 + for cfg, p in bucket_config.items() + ): + raise NotImplementedError( + "Video sampling not implemented; set its probability to 0." + ) # Get the sampling parameters for the dataset input_lens, output_lens, offsets = self.get_sampling_params( @@ -952,7 +955,7 @@ def sample( mm_requests = [] token_mismatch_total = 0 for i in range(num_requests): - prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501 + prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501 tokenizer=tokenizer, prefix_token_ids=prefix_token_ids, prefix_len=prefix_len, @@ -970,17 +973,21 @@ def sample( limit_mm_per_prompt, ) - mm_content = cast(list[dict[str, Any]], [ - self.generate_mm_item(mm_item_config) - for mm_item_config in mm_item_iterator - ]) + mm_content = cast( + list[dict[str, Any]], + [ + self.generate_mm_item(mm_item_config) + for mm_item_config in mm_item_iterator + ], + ) if enable_multimodal_chat: # NOTE: For now this option is only provided for completeness # given that the serve.py benchmark currently does not use it. mm_chat_prompt: Any = prompt mm_chat_prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt, mm_content + ) sample_request = SampleRequest( prompt=mm_chat_prompt, prompt_len=total_input_len, @@ -1011,6 +1018,7 @@ def sample( return mm_requests + # ----------------------------------------------------------------------------- # ShareGPT Dataset Implementation # ----------------------------------------------------------------------------- @@ -1034,11 +1042,13 @@ def load_data(self) -> None: self.data = json.load(f) # Filter entries with at least two conversation turns. 
self.data = [ - entry for entry in self.data + entry + for entry in self.data if "conversations" in entry and len(entry["conversations"]) >= 2 ] random.seed(self.random_seed) - random.shuffle(self.data) + if not getattr(self, "disable_shuffle", False): + random.shuffle(self.data) def sample( self, @@ -1063,16 +1073,17 @@ def sample( ) lora_request = self.get_random_lora_request( - max_loras=max_loras, lora_path=lora_path) + max_loras=max_loras, lora_path=lora_path + ) prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids prompt_len = len(prompt_ids) - new_output_len = (len(completion_ids) - if output_len is None else output_len) - if not is_valid_sequence(prompt_len, - new_output_len, - skip_min_output_len_check=output_len - is not None): + new_output_len = len(completion_ids) if output_len is None else output_len + if not is_valid_sequence( + prompt_len, + new_output_len, + skip_min_output_len_check=output_len is not None, + ): continue if image_path := entry.get("image"): mm_content = process_image(image_path) @@ -1081,8 +1092,7 @@ def sample( else: mm_content = None if enable_multimodal_chat: - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) samples.append( SampleRequest( prompt=prompt, @@ -1091,23 +1101,24 @@ def sample( lora_request=lora_request, multi_modal_data=mm_content, request_id=request_id_prefix + str(ind), - )) + ) + ) ind += 1 - self.maybe_oversample_requests(samples, - num_requests, - request_id_prefix, - no_oversample) + self.maybe_oversample_requests( + samples, num_requests, request_id_prefix, no_oversample + ) return samples class _ValidateDatasetArgs(argparse.Action): """Argparse action to validate dataset name and path compatibility.""" + def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, values) # Get current values of both dataset_name and dataset_path - dataset_name = getattr(namespace, 'dataset_name', 'random') - dataset_path = getattr(namespace, 'dataset_path', None) + dataset_name = getattr(namespace, "dataset_name", "random") + dataset_path = getattr(namespace, "dataset_path", None) # Validate the combination if dataset_name == "random" and dataset_path is not None: @@ -1133,8 +1144,15 @@ def add_dataset_parser(parser: FlexibleArgumentParser): default="random", action=_ValidateDatasetArgs, choices=[ - "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", - "custom", "prefix_repetition", "spec_bench" + "sharegpt", + "burstgpt", + "sonnet", + "random", + "random-mm", + "hf", + "custom", + "prefix_repetition", + "spec_bench", ], help="Name of the dataset to benchmark on.", ) @@ -1154,14 +1172,17 @@ def add_dataset_parser(parser: FlexibleArgumentParser): parser.add_argument( "--no-oversample", action="store_true", - help="Do not oversample if the dataset has " \ - "fewer samples than num-prompts.", + help="Do not oversample if the dataset has fewer samples than num-prompts.", ) parser.add_argument( "--skip-chat-template", action="store_true", - help= - "Skip applying chat template to prompt for datasets that support it.", + help="Skip applying chat template to prompt for datasets that support it.", + ) + parser.add_argument( + "--disable-shuffle", + action="store_true", + help="Disable shuffling of dataset samples for deterministic ordering.", ) # group for dataset specific arguments @@ -1170,8 +1191,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): 
"--custom-output-len", type=int, default=256, - help= - "Number of output tokens per request, used only for custom dataset.", + help="Number of output tokens per request, used only for custom dataset.", ) spec_bench_group = parser.add_argument_group("spec bench dataset options") @@ -1179,15 +1199,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--spec-bench-output-len", type=int, default=256, - help= - "Num of output tokens per request, used only for spec bench dataset.", + help="Num of output tokens per request, used only for spec bench dataset.", ) spec_bench_group.add_argument( "--spec-bench-category", type=str, default=None, - help= - "Category for spec bench dataset. If None, use all categories.", + help="Category for spec bench dataset. If None, use all categories.", ) sonnet_group = parser.add_argument_group("sonnet dataset options") @@ -1195,22 +1213,19 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--sonnet-input-len", type=int, default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", + help="Number of input tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-output-len", type=int, default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", + help="Number of output tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-prefix-len", type=int, default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", + help="Number of prefix tokens per request, used only for sonnet dataset.", ) sharegpt_group = parser.add_argument_group("sharegpt dataset options") @@ -1227,15 +1242,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--blazedit-min-distance", type=float, default=0.0, - help= - "Minimum distance for blazedit dataset. Min: 0, Max: 1.0", + help="Minimum distance for blazedit dataset. Min: 0, Max: 1.0", ) blazedit_group.add_argument( "--blazedit-max-distance", type=float, default=1.0, - help= - "Maximum distance for blazedit dataset. Min: 0, Max: 1.0", + help="Maximum distance for blazedit dataset. Min: 0, Max: 1.0", ) random_group = parser.add_argument_group("random dataset options") @@ -1243,15 +1256,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--random-input-len", type=int, default=1024, - help= - "Number of input tokens per request, used only for random sampling.", + help="Number of input tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-output-len", type=int, default=128, - help= - "Number of output tokens per request, used only for random sampling.", + help="Number of output tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-range-ratio", @@ -1266,24 +1277,26 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "--random-prefix-len", type=int, default=0, - help=("Number of fixed prefix tokens before the random context " - "in a request. " - "The total input length is the sum of `random-prefix-len` and " - "a random " - "context length sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]."), + help=( + "Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]." 
+ ), ) random_group.add_argument( "--random-batch-size", type=int, default=1, - help=("Batch size for random sampling. " - "Only used for embeddings benchmark."), + help=("Batch size for random sampling. Only used for embeddings benchmark."), ) # random multimodal dataset options random_mm_group = parser.add_argument_group( - "random multimodal dataset options extended from random dataset") + "random multimodal dataset options extended from random dataset" + ) random_mm_group.add_argument( "--random-mm-base-items-per-request", type=int, @@ -1315,7 +1328,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT, help=( "Per-modality hard caps for items attached per request, e.g. " - "'{\"image\": 3, \"video\": 0}'. The sampled per-request item " + '\'{"image": 3, "video": 0}\'. The sampled per-request item ' "count is clamped to the sum of these limits. When a modality " "reaches its cap, its buckets are excluded and probabilities are " "renormalized." @@ -1332,8 +1345,11 @@ def normalize(d: dict) -> dict[tuple[int, int, int], float]: if isinstance(key, str): with suppress(Exception): key = ast.literal_eval(key) - if not (isinstance(key, tuple) and len(key) == 3 - and all(isinstance(x, int) for x in key)): + if not ( + isinstance(key, tuple) + and len(key) == 3 + and all(isinstance(x, int) for x in key) + ): raise ValueError( f"Invalid bucket key {k!r}. Expected tuple (H, W, T)." ) @@ -1372,14 +1388,12 @@ def normalize(d: dict) -> dict[tuple[int, int, int], float]: ) hf_group = parser.add_argument_group("hf dataset options") - hf_group.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - hf_group.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-subset", type=str, default=None, help="Subset of the HF dataset." + ) + hf_group.add_argument( + "--hf-split", type=str, default=None, help="Split of the HF dataset." + ) hf_group.add_argument( "--hf-name", type=str, @@ -1399,7 +1413,8 @@ def normalize(d: dict) -> dict[tuple[int, int, int], float]: ) prefix_repetition_group = parser.add_argument_group( - "prefix repetition dataset options") + "prefix repetition dataset options" + ) prefix_repetition_group.add_argument( "--prefix-repetition-prefix-len", type=int, @@ -1431,12 +1446,13 @@ def normalize(d: dict) -> dict[tuple[int, int, int], float]: def get_samples(args, tokenizer) -> list[SampleRequest]: - if not hasattr(args, "request_id_prefix"): args.request_id_prefix = "" if args.dataset_name == "custom": - dataset = CustomDataset(dataset_path=args.dataset_path) + dataset = CustomDataset( + dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle + ) input_requests = dataset.sample( num_requests=args.num_prompts, tokenizer=tokenizer, @@ -1447,7 +1463,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ) elif args.dataset_name == "sonnet": - dataset = SonnetDataset(dataset_path=args.dataset_path) + dataset = SonnetDataset( + dataset_path=args.dataset_path, disable_shuffle=args.disable_shuffle + ) # For the "sonnet" dataset, formatting depends on the backend. if args.backend == "openai-chat": input_requests = dataset.sample( @@ -1462,7 +1480,8 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") + "Tokenizer/model must have chat template for sonnet dataset." 
+ ) input_requests = dataset.sample( num_requests=args.num_prompts, input_len=args.sonnet_input_len, @@ -1516,8 +1535,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: dataset_class = AIMODataset args.hf_split = "train" elif ( - args.dataset_path - in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS # noqa: E501 + args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS # noqa: E501 or args.hf_name in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS ): dataset_class = NextEditPredictionDataset @@ -1549,26 +1567,31 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: args.hf_split = "val" args.hf_subset = None else: - supported_datasets = set([ - dataset_name for cls in HuggingFaceDataset.__subclasses__() - for dataset_name in cls.SUPPORTED_DATASET_PATHS - ]) + supported_datasets = set( + [ + dataset_name + for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ] + ) raise ValueError( f"Unsupported dataset path: {args.dataset_path}. " "Huggingface dataset only supports dataset_path" f" from one of following: {supported_datasets}. " "Please consider contributing if you would " - "like to add support for additional dataset formats.") + "like to add support for additional dataset formats." + ) if dataset_class.IS_MULTIMODAL and args.backend not in [ - "openai-chat", - "openai-audio", + "openai-chat", + "openai-audio", ]: # multi-modal benchmark is only available on OpenAI Chat # endpoint-type. raise ValueError( "Multi-modal content is only supported on 'openai-chat' and " - "'openai-audio' backends.") + "'openai-audio' backends." + ) input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, @@ -1576,6 +1599,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: random_seed=args.seed, no_stream=args.no_stream, hf_name=args.hf_name, + disable_shuffle=args.disable_shuffle, ).sample( num_requests=args.num_prompts, tokenizer=tokenizer, @@ -1583,15 +1607,17 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: request_id_prefix=args.request_id_prefix, no_oversample=args.no_oversample, skip_chat_template=args.skip_chat_template, - **hf_kwargs + **hf_kwargs, ) else: # For datasets that follow a similar structure, use a mapping. 
dataset_mapping = { - "spec_bench": - lambda: SpecBench(dataset_path=args.dataset_path, - category=args.spec_bench_category).sample( + "spec_bench": lambda: SpecBench( + dataset_path=args.dataset_path, + category=args.spec_bench_category, + disable_shuffle=args.disable_shuffle, + ).sample( num_requests=args.num_prompts, tokenizer=tokenizer, output_len=args.spec_bench_output_len, @@ -1599,7 +1625,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: no_oversample=args.no_oversample, ), "sharegpt": lambda: ShareGPTDataset( - random_seed=args.seed, dataset_path=args.dataset_path + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -1608,7 +1636,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: no_oversample=args.no_oversample, ), "burstgpt": lambda: BurstGPTDataset( - random_seed=args.seed, dataset_path=args.dataset_path + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -1616,7 +1646,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: no_oversample=args.no_oversample, ), "random": lambda: RandomDataset( - random_seed=args.seed, dataset_path=args.dataset_path + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -1628,9 +1660,10 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: batchsize=args.random_batch_size, no_oversample=args.no_oversample, ), - "random-mm": - lambda: RandomMultiModalDataset( - random_seed=args.seed, dataset_path=args.dataset_path + "random-mm": lambda: RandomMultiModalDataset( + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -1645,9 +1678,10 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: request_id_prefix=args.request_id_prefix, no_oversample=args.no_oversample, ), - "prefix_repetition": - lambda: PrefixRepetitionRandomDataset( - random_seed=args.seed, dataset_path=args.dataset_path + "prefix_repetition": lambda: PrefixRepetitionRandomDataset( + random_seed=args.seed, + dataset_path=args.dataset_path, + disable_shuffle=args.disable_shuffle, ).sample( tokenizer=tokenizer, num_requests=args.num_prompts, @@ -1662,8 +1696,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: try: # Enforce endpoint compatibility for multimodal datasets. - if args.dataset_name == "random-mm" and args.backend not in [ - "openai-chat"]: + if args.dataset_name == "random-mm" and args.backend not in ["openai-chat"]: raise ValueError( "Multi-modal content (images) is only supported on " "'openai-chat' backend." @@ -1708,8 +1741,7 @@ def load_data(self) -> None: # Load the JSONL file if self.dataset_path.endswith(".jsonl"): - jsonl_data = pd.read_json(path_or_buf=self.dataset_path, - lines=True) + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True) # check if the JSONL file has a 'prompt' column if "prompt" not in jsonl_data.columns: @@ -1723,10 +1755,12 @@ def load_data(self) -> None: self.data.append(row.to_dict()) else: raise NotImplementedError( - "Only JSONL format is supported for CustomDataset.") + "Only JSONL format is supported for CustomDataset." 
+ ) random.seed(self.random_seed) - random.shuffle(self.data) + if not getattr(self, "disable_shuffle", False): + random.shuffle(self.data) def sample( self, @@ -1745,9 +1779,11 @@ def sample( self.num_available_samples = len(self.data) if num_requests <= 0: num_requests = self.num_available_samples - logger.info("num_requests is set to 0 or negative, " - "so using all available samples: %d", - num_requests) + logger.info( + "num_requests is set to 0 or negative, " + "so using all available samples: %d", + num_requests, + ) sampled_requests = [] for i, item in enumerate(self.data): @@ -1758,10 +1794,7 @@ def sample( # apply template if not skip_chat_template: prompt = tokenizer.apply_chat_template( - [{ - "role": "user", - "content": prompt - }], + [{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False, ) @@ -1773,9 +1806,11 @@ def sample( prompt_len=prompt_len, expected_output_len=output_len, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -1790,7 +1825,7 @@ class SpecBench(CustomDataset): Implements the SpecBench dataset: https://github.com/hemingkx/Spec-Bench Download the dataset using: wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl - """ # noqa: E501 + """ # noqa: E501 def __init__(self, **kwargs) -> None: self.category = kwargs.pop("category", None) @@ -1804,8 +1839,7 @@ def load_data(self) -> None: self.data = [] # Load the JSONL file - jsonl_data = pd.read_json(path_or_buf=self.dataset_path, - lines=True) + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, lines=True) # check if the JSONL file has a 'turns' column if "turns" not in jsonl_data.columns: @@ -1813,12 +1847,13 @@ def load_data(self) -> None: for _, row in jsonl_data.iterrows(): # sample only from a specific category if specified - if (not self.category) or (self.category == row['category']): + if (not self.category) or (self.category == row["category"]): prompt = row["turns"][0] self.data.append({"prompt": prompt}) random.seed(self.random_seed) - random.shuffle(self.data) + if not getattr(self, "disable_shuffle", False): + random.shuffle(self.data) def sample(self, **kwargs) -> list: # leverage CustomDataset sample @@ -1829,6 +1864,7 @@ def sample(self, **kwargs) -> list: # Sonnet Dataset Implementation # ----------------------------------------------------------------------------- + @deprecated( "SonnetDataset is deprecated and will be removed in a future version.", ) @@ -1870,20 +1906,20 @@ def sample( ) -> list: # Calculate average token length for a poem line. tokenized_lines = [tokenizer(line).input_ids for line in self.data] - avg_len = sum(len(tokens) - for tokens in tokenized_lines) / len(tokenized_lines) + avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines) # Build the base prompt. 
base_prompt = "Pick as many lines as you can from these poem lines:\n" base_msg = [{"role": "user", "content": base_prompt}] - base_fmt = tokenizer.apply_chat_template(base_msg, - add_generation_prompt=True, - tokenize=False) + base_fmt = tokenizer.apply_chat_template( + base_msg, add_generation_prompt=True, tokenize=False + ) base_offset = len(tokenizer(base_fmt).input_ids) if input_len <= base_offset: raise ValueError( f"'input_len' must be higher than the base prompt length " - f"({base_offset}).") + f"({base_offset})." + ) # Determine how many poem lines to use. num_input_lines = round((input_len - base_offset) / avg_len) @@ -1893,22 +1929,24 @@ def sample( samples = [] ind = 0 while len(samples) < num_requests: - extra_lines = random.choices(self.data, - k=num_input_lines - num_prefix_lines) + extra_lines = random.choices( + self.data, k=num_input_lines - num_prefix_lines + ) prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" msg = [{"role": "user", "content": prompt}] prompt_formatted = tokenizer.apply_chat_template( - msg, add_generation_prompt=True, tokenize=False) + msg, add_generation_prompt=True, tokenize=False + ) prompt_len = len(tokenizer(prompt_formatted).input_ids) if prompt_len <= input_len: samples.append( SampleRequest( - prompt=prompt_formatted - if return_prompt_formatted else prompt, + prompt=prompt_formatted if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, - request_id=request_id_prefix + str(ind), - )) + request_id=request_id_prefix + str(ind), + ) + ) ind += 1 return samples @@ -1929,7 +1967,9 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.load_data() - def load_data(self, ): + def load_data( + self, + ): if self.dataset_path is None: raise ValueError("dataset_path must be provided for loading data.") @@ -1943,8 +1983,7 @@ def load_data(self, ): def _sample_loaded_data(self, num_requests: int) -> list: if num_requests <= len(self.data): - data = self.data.sample(n=num_requests, - random_state=self.random_seed) + data = self.data.sample(n=num_requests, random_state=self.random_seed) else: data = self.data.sample( n=num_requests, @@ -1970,7 +2009,8 @@ def sample( input_len = int(data[i][2]) output_len = int(data[i][3]) lora_req = self.get_random_lora_request( - max_loras=max_loras, lora_path=lora_path) + max_loras=max_loras, lora_path=lora_path + ) vocab_size = tokenizer.vocab_size # Generate a synthetic prompt: a list of token IDs computed as (i + # j) modulo vocab_size. 
@@ -1983,7 +2023,8 @@ def sample( expected_output_len=output_len, lora_request=lora_req, request_id=request_id_prefix + str(i), - )) + ) + ) return samples @@ -2020,7 +2061,8 @@ def load_data(self) -> None: split=self.dataset_split, streaming=self.load_stream, ) - self.data = self.data.shuffle(seed=self.random_seed) + if not getattr(self, "disable_shuffle", False): + self.data = self.data.shuffle(seed=self.random_seed) # ----------------------------------------------------------------------------- @@ -2030,22 +2072,25 @@ def load_data(self) -> None: class ConversationDataset(HuggingFaceDataset): """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { - 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + "lmms-lab/LLaVA-OneVision-Data", + "Aeala/ShareGPT_Vicuna_unfiltered", } IS_MULTIMODAL = True - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - request_id_prefix: str = "", - no_oversample: bool = False, - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list: # Filter examples with at least 2 conversations - filtered_data = self.data.filter( - lambda x: len(x["conversations"]) >= 2) + filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) sampled_requests = [] ind = 0 dynamic_output = output_len is None @@ -2062,17 +2107,14 @@ def sample(self, completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence( - prompt_len, completion_len): + if dynamic_output and not is_valid_sequence(prompt_len, completion_len): continue - mm_content = process_image( - item["image"]) if "image" in item else None + mm_content = process_image(item["image"]) if "image" in item else None if enable_multimodal_chat: # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len and output len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, @@ -2080,10 +2122,12 @@ def sample(self, expected_output_len=output_len, multi_modal_data=mm_content, request_id=request_id_prefix + str(ind), - )) + ) + ) ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -2099,10 +2143,8 @@ class VisionArenaDataset(HuggingFaceDataset): DEFAULT_OUTPUT_LEN = 128 SUPPORTED_DATASET_PATHS = { - "lmarena-ai/VisionArena-Chat": - lambda x: x["conversation"][0][0]["content"], - "lmarena-ai/vision-arena-bench-v0.1": - lambda x: x["turns"][0][0]["content"] + "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"], } IS_MULTIMODAL = True @@ -2116,8 +2158,7 @@ def sample( no_oversample: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = 
output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: @@ -2132,8 +2173,7 @@ def sample( # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, @@ -2141,9 +2181,11 @@ def sample( expected_output_len=output_len, multi_modal_data=mm_content, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -2155,10 +2197,9 @@ class MMVUDataset(HuggingFaceDataset): DEFAULT_OUTPUT_LEN = 128 SUPPORTED_DATASET_PATHS = { - "yale-nlp/MMVU": - lambda x: x["question"] + " " + ( - " ".join(f"{k}.{v}" for k, v in x["choices"].items()) - ), + "yale-nlp/MMVU": lambda x: x["question"] + + " " + + (" ".join(f"{k}.{v}" for k, v in x["choices"].items())), } def sample( @@ -2171,8 +2212,7 @@ def sample( no_oversample: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: @@ -2187,8 +2227,7 @@ def sample( # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, @@ -2196,9 +2235,11 @@ def sample( expected_output_len=output_len, multi_modal_data=mm_content, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -2222,17 +2263,18 @@ class InstructCoderDataset(HuggingFaceDataset): "likaixin/InstructCoder", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - skip_chat_template: bool = False, - request_id_prefix: str = "", - no_oversample: bool = False, - **kwargs) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: @@ -2245,10 +2287,7 @@ def sample(self, # apply template if not skip_chat_template: prompt = tokenizer.apply_chat_template( - [{ - "role": "user", - "content": prompt - }], + [{"role": "user", "content": prompt}], 
add_generation_prompt=True, tokenize=False, ) @@ -2260,9 +2299,11 @@ def sample(self, prompt_len=prompt_len, expected_output_len=output_len, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -2297,8 +2338,7 @@ def sample( no_oversample: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for i, item in enumerate(self.data): @@ -2309,10 +2349,7 @@ def sample( # apply template if not skip_chat_template: prompt = tokenizer.apply_chat_template( - [{ - "role": "user", - "content": prompt - }], + [{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False, ) @@ -2324,9 +2361,11 @@ def sample( prompt_len=prompt_len, expected_output_len=output_len, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -2366,8 +2405,7 @@ def sample( max_distance: float = 1.0, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for i, item in enumerate(self.data): @@ -2393,15 +2431,12 @@ def sample( Change request: {change_request} -Please generate the new code file in the "New file" section below.""" # noqa: E501 +Please generate the new code file in the "New file" section below.""" # noqa: E501 # apply template if not skip_chat_template: prompt = tokenizer.apply_chat_template( - [{ - "role": "user", - "content": prompt - }], + [{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False, ) @@ -2414,9 +2449,11 @@ def sample( prompt_len=prompt_len, expected_output_len=output_len, request_id=request_id_prefix + str(i), - )) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + ) + ) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -2430,18 +2467,22 @@ class AIMODataset(HuggingFaceDataset): """ Dataset class for processing a AIMO dataset with reasoning questions. 
""" + SUPPORTED_DATASET_PATHS = { - "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", - "AI-MO/NuminaMath-CoT" + "AI-MO/aimo-validation-aime", + "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - request_id_prefix: str = "", - no_oversample: bool = False, - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ) -> list: sampled_requests = [] ind = 0 dynamic_output = output_len is None @@ -2449,7 +2490,7 @@ def sample(self, for item in self.data: if len(sampled_requests) >= num_requests: break - prompt, completion = item['problem'], item["solution"] + prompt, completion = item["problem"], item["solution"] prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids @@ -2457,10 +2498,9 @@ def sample(self, completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence(prompt_len, - completion_len, - max_prompt_len=2048, - max_total_len=32000): + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000 + ): continue sampled_requests.append( SampleRequest( @@ -2469,10 +2509,12 @@ def sample(self, expected_output_len=output_len, multi_modal_data=None, request_id=request_id_prefix + str(ind), - )) + ) + ) ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -2494,12 +2536,12 @@ def sample(self, ### Response: -""" # noqa: E501 +""" # noqa: E501 def _format_zeta_prompt( - sample: dict, - original_start_marker: str = "<|editable_region_start|>") -> dict: + sample: dict, original_start_marker: str = "<|editable_region_start|>" +) -> dict: """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. 
This function formats examples from the NEP dataset @@ -2542,10 +2584,14 @@ class NextEditPredictionDataset(HuggingFaceDataset): "zed-industries/zeta": _format_zeta_prompt, } - def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, - request_id_prefix: str = "", - no_oversample: bool = False, - **kwargs): + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + no_oversample: bool = False, + **kwargs, + ): formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.hf_name) if formatting_prompt_func is None: raise ValueError(f"Unsupported dataset path: {self.hf_name}") @@ -2557,15 +2603,16 @@ def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, prompt=sample["prompt"], prompt_len=len(tokenizer(sample["prompt"]).input_ids), expected_output_len=len( - tokenizer(sample["expected_output"]).input_ids), + tokenizer(sample["expected_output"]).input_ids + ), request_id=request_id_prefix + str(i), - )) + ) + ) if len(samples) >= num_requests: break - self.maybe_oversample_requests(samples, - num_requests, - request_id_prefix, - no_oversample) + self.maybe_oversample_requests( + samples, num_requests, request_id_prefix, no_oversample + ) return samples @@ -2606,8 +2653,7 @@ class ASRDataset(HuggingFaceDataset): IS_MULTIMODAL = True # TODO Whisper-specific. Abstract interface when more models are supported. - TRANSCRIPTION_PREAMBLE = ( - "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>") + TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" skip_long_audios: bool = True def sample( @@ -2619,8 +2665,7 @@ def sample( no_oversample: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] @@ -2645,7 +2690,8 @@ def sample( expected_output_len=output_len, multi_modal_data=mm_content, request_id=request_id_prefix + str(ind), - )) + ) + ) ind += 1 if skipped: logger.warning( @@ -2654,8 +2700,9 @@ def sample( " what Whisper supports.", skipped, ) - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -2739,8 +2786,9 @@ def sample( ) ind += 1 - self.maybe_oversample_requests(sampled_requests, num_requests, - request_id_prefix, no_oversample) + self.maybe_oversample_requests( + sampled_requests, num_requests, request_id_prefix, no_oversample + ) return sampled_requests @@ -2789,10 +2837,9 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]: """Generate tokens that decode and re-encode to exactly target_length.""" # Generate random tokens - tokens = np.random.randint( - 0, vocab_size, size=target_length).tolist() + tokens = np.random.randint(0, vocab_size, size=target_length).tolist() - _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len( # noqa: E501 + _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len( # noqa: E501 tokenizer=tokenizer, token_sequence=tokens, target_token_len=target_length, @@ -2806,7 +2853,9 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]: prefix_tokens = _generate_exact_length_tokens(prefix_len) for _ in range(prompts_per_prefix): - suffix_tokens, 
token_mistmatch = _generate_exact_length_tokens(suffix_len) # noqa: E501 + suffix_tokens, token_mistmatch = _generate_exact_length_tokens( + suffix_len + ) token_mismatch_total += token_mistmatch combined_tokens = prefix_tokens + suffix_tokens prompt = tokenizer.decode(combined_tokens) @@ -2829,7 +2878,8 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]: abs(token_mismatch_total), sign, ) - random.shuffle(requests) + if not getattr(self, "disable_shuffle", False): + random.shuffle(requests) return requests @@ -2843,6 +2893,7 @@ class MMStarDataset(HuggingFaceDataset): Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar refer to: https://github.com/sgl-project/SpecForge/pull/106 """ + DEFAULT_OUTPUT_LEN = 128 SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"} IS_MULTIMODAL = True @@ -2858,8 +2909,7 @@ def sample( **kwargs, ) -> list[SampleRequest]: # If --hf-output-len is not set, use the default output length. - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests: list[SampleRequest] = [] for ind, item in enumerate(self.data): diff --git a/vllm/benchmarks/latency.py b/vllm/benchmarks/latency.py index 05378ec74d2f..7692697fe768 100644 --- a/vllm/benchmarks/latency.py +++ b/vllm/benchmarks/latency.py @@ -13,20 +13,20 @@ from tqdm import tqdm import vllm.envs as envs -from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, - write_to_json) +from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.sampling_params import BeamSearchParams -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={"latency": results["latencies"]}, - extra_info={k: results[k] - for k in ["avg_latency", "percentiles"]}) + extra_info={k: results[k] for k in ["avg_latency", "percentiles"]}, + ) if pt_records: pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" write_to_json(pt_file, pt_records) @@ -49,10 +49,9 @@ def add_cli_args(parser: argparse.ArgumentParser): default=10, help="Number of iterations to run for warmup.", ) - parser.add_argument("--num-iters", - type=int, - default=30, - help="Number of iterations to run.") + parser.add_argument( + "--num-iters", type=int, default=30, help="Number of iterations to run." + ) parser.add_argument( "--profile", action="store_true", @@ -67,8 +66,10 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) @@ -81,7 +82,8 @@ def main(args: argparse.Namespace): if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: raise OSError( "The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. " - "Please set it to a valid path to use torch profiler.") + "Please set it to a valid path to use torch profiler." 
+ ) engine_args = EngineArgs.from_cli_args(args) # Lazy import to avoid importing LLM when the bench command is not selected. @@ -91,9 +93,11 @@ def main(args: argparse.Namespace): # the engine will automatically process the request in multiple batches. llm = LLM(**dataclasses.asdict(engine_args)) assert llm.llm_engine.model_config.max_model_len >= ( - args.input_len + - args.output_len), ("Please ensure that max_model_len is greater than" - " the sum of input_len and output_len.") + args.input_len + args.output_len + ), ( + "Please ensure that max_model_len is greater than" + " the sum of input_len and output_len." + ) sampling_params = SamplingParams( n=args.n, @@ -103,18 +107,16 @@ def main(args: argparse.Namespace): max_tokens=args.output_len, detokenize=not args.disable_detokenize, ) - dummy_prompt_token_ids = np.random.randint(10000, - size=(args.batch_size, - args.input_len)) - dummy_prompts: list[PromptType] = [{ - "prompt_token_ids": batch - } for batch in dummy_prompt_token_ids.tolist()] + dummy_prompt_token_ids = np.random.randint( + 10000, size=(args.batch_size, args.input_len) + ) + dummy_prompts: list[PromptType] = [ + {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist() + ] def llm_generate(): if not args.use_beam_search: - llm.generate(dummy_prompts, - sampling_params=sampling_params, - use_tqdm=False) + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) else: llm.beam_search( dummy_prompts, diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py index 725b7df8b187..425a171c3c06 100644 --- a/vllm/benchmarks/lib/endpoint_request_func.py +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -62,6 +62,7 @@ def add_chunk(self, chunk_bytes: bytes) -> list[str]: @dataclass class RequestFuncInput: """The input for the request function.""" + prompt: str api_url: str prompt_len: int @@ -80,13 +81,13 @@ class RequestFuncInput: @dataclass class RequestFuncOutput: """The output of the request function including metrics.""" + generated_text: str = "" success: bool = False latency: float = 0.0 output_tokens: int = 0 ttft: float = 0.0 # Time to first token - itl: list[float] = field( - default_factory=list) # list of inter-token latencies + itl: list[float] = field(default_factory=list) # list of inter-token latencies tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" @@ -99,8 +100,7 @@ def __call__( request_func_input: RequestFuncInput, session: aiohttp.ClientSession, pbar: Optional[tqdm] = None, - ) -> Awaitable[RequestFuncOutput]: - ... + ) -> Awaitable[RequestFuncOutput]: ... async def async_request_openai_completions( @@ -118,13 +118,14 @@ async def async_request_openai_completions( The output of the request function. """ api_url = request_func_input.api_url - assert api_url.endswith( - ("completions", "profile") - ), "OpenAI Completions API URL must end with 'completions' or 'profile'." + assert api_url.endswith(("completions", "profile")), ( + "OpenAI Completions API URL must end with 'completions' or 'profile'." 
+ ) payload = { "model": request_func_input.model_name - if request_func_input.model_name else request_func_input.model, + if request_func_input.model_name + else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "repetition_penalty": 1.0, @@ -139,9 +140,7 @@ async def async_request_openai_completions( payload["ignore_eos"] = request_func_input.ignore_eos if request_func_input.extra_body: payload.update(request_func_input.extra_body) - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} if request_func_input.extra_headers: headers |= request_func_input.extra_headers if request_func_input.request_id: @@ -155,8 +154,7 @@ async def async_request_openai_completions( output.start_time = st most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: first_chunk_received = False handler = StreamedResponseHandler() @@ -195,21 +193,20 @@ async def async_request_openai_completions( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp generated_text += text or "" elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + output.output_tokens = usage.get("completion_tokens") if first_chunk_received: output.success = True else: output.success = False output.error = ( "Never received a valid chunk to calculate TTFT." - "This response will be marked as failed!") + "This response will be marked as failed!" + ) output.generated_text = generated_text output.latency = most_recent_timestamp - st else: @@ -232,7 +229,8 @@ async def async_request_openai_chat_completions( ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith(("chat/completions", "profile")), ( - "OpenAI Chat Completions API URL must end with 'chat/completions'.") + "OpenAI Chat Completions API URL must end with 'chat/completions'." 
+ ) content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: @@ -243,25 +241,18 @@ async def async_request_openai_chat_completions( content.append(mm_content) else: raise TypeError( - "multi_modal_content must be a dict or list[dict] " - "for openai-chat" + "multi_modal_content must be a dict or list[dict] for openai-chat" ) payload = { - "model": - request_func_input.model_name - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "messages": [ - { - "role": "user", - "content": content - }, + {"role": "user", "content": content}, ], - "temperature": - 0.0, - "max_completion_tokens": - request_func_input.output_len, - "stream": - True, + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, "stream_options": { "include_usage": True, }, @@ -288,8 +279,7 @@ async def async_request_openai_chat_completions( output.start_time = st most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: handler = StreamedResponseHandler() async for chunk_bytes in response.content.iter_any(): @@ -320,13 +310,11 @@ async def async_request_openai_chat_completions( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) generated_text += content or "" elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + output.output_tokens = usage.get("completion_tokens") most_recent_timestamp = timestamp @@ -356,27 +344,22 @@ async def async_request_openai_audio( api_url = request_func_input.api_url assert api_url.endswith(("transcriptions", "translations")), ( - "OpenAI Chat Completions API URL must end with 'transcriptions' ") + "OpenAI Chat Completions API URL must end with 'transcriptions' " + ) "or `translations`." 
content = [{"type": "text", "text": request_func_input.prompt}] payload = { - "model": - request_func_input.model_name - if request_func_input.model_name else request_func_input.model, - "temperature": - 0.0, - "max_completion_tokens": - request_func_input.output_len, - "stream": - True, - "language": - "en", + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "language": "en", # Flattened due to multipart/form-data - "stream_include_usage": - True, - "stream_continuous_usage_stats": - True, + "stream_include_usage": True, + "stream_continuous_usage_stats": True, } if request_func_input.extra_body: payload.update(request_func_input.extra_body) @@ -413,9 +396,9 @@ def to_bytes(y, sr): output.start_time = st most_recent_timestamp = st try: - async with session.post(url=api_url, - data=form, - headers=headers) as response: + async with session.post( + url=api_url, data=form, headers=headers + ) as response: if response.status == 200: handler = StreamedResponseHandler() @@ -426,15 +409,13 @@ def to_bytes(y, sr): messages = handler.add_chunk(chunk_bytes) for message in messages: - chunk = message.decode("utf-8").removeprefix( - "data: ") + chunk = message.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) if choices := data.get("choices"): - content = choices[0]["delta"].get( - "content") + content = choices[0]["delta"].get("content") # First token if ttft == 0.0: ttft = timestamp - st @@ -443,12 +424,14 @@ def to_bytes(y, sr): # Decoding phase else: output.itl.append( - timestamp - most_recent_timestamp) + timestamp - most_recent_timestamp + ) generated_text += content or "" elif usage := data.get("usage"): output.output_tokens = usage.get( - "completion_tokens") + "completion_tokens" + ) most_recent_timestamp = timestamp @@ -474,9 +457,9 @@ async def async_request_openai_embeddings( pbar: Optional[tqdm] = None, ): api_url = request_func_input.api_url - assert api_url.endswith( - "embeddings" - ), "OpenAI Embeddings API URL must end with 'embeddings'." + assert api_url.endswith("embeddings"), ( + "OpenAI Embeddings API URL must end with 'embeddings'." 
+ ) headers = { "Content-Type": "application/json", @@ -492,19 +475,13 @@ async def async_request_openai_embeddings( st = time.perf_counter() output.start_time = st try: - async with session.post( - url=api_url, - headers=headers, - json=payload - ) as response: + async with session.post(url=api_url, headers=headers, json=payload) as response: if response.status == 200: output.latency = time.perf_counter() - st data = await response.json() output.success = True output.generated_text = "" - output.prompt_len = data.get( - "usage", {}).get( - "prompt_tokens", 0) + output.prompt_len = data.get("usage", {}).get("prompt_tokens", 0) else: output.success = False output.error = response.reason or "" @@ -527,7 +504,7 @@ async def async_request_openai_embeddings( } OPENAI_COMPATIBLE_BACKENDS = [ - k for k, v in ASYNC_REQUEST_FUNCS.items() - if v in (async_request_openai_completions, - async_request_openai_chat_completions) + k + for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, async_request_openai_chat_completions) ] diff --git a/vllm/benchmarks/lib/ready_checker.py b/vllm/benchmarks/lib/ready_checker.py index 87fc16b55012..5649faf05597 100644 --- a/vllm/benchmarks/lib/ready_checker.py +++ b/vllm/benchmarks/lib/ready_checker.py @@ -8,8 +8,7 @@ import aiohttp from tqdm.asyncio import tqdm -from .endpoint_request_func import (RequestFunc, RequestFuncInput, - RequestFuncOutput) +from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput async def wait_for_endpoint( @@ -21,30 +20,29 @@ async def wait_for_endpoint( ) -> RequestFuncOutput: """ Wait for an endpoint to become available before starting benchmarks. - + Args: request_func: The async request function to call test_input: The RequestFuncInput to test with timeout_seconds: Maximum time to wait in seconds (default: 10 minutes) retry_interval: Time between retries in seconds (default: 5 seconds) - + Returns: RequestFuncOutput: The successful response - + Raises: ValueError: If the endpoint doesn't become available within the timeout """ deadline = time.perf_counter() + timeout_seconds output = RequestFuncOutput(success=False) print(f"Waiting for endpoint to become up in {timeout_seconds} seconds") - + with tqdm( - total=timeout_seconds, + total=timeout_seconds, bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining", unit="s", ) as pbar: - - while True: + while True: # update progress bar remaining = deadline - time.perf_counter() elapsed = timeout_seconds - remaining @@ -58,16 +56,17 @@ async def wait_for_endpoint( # ping the endpoint using request_func try: output = await request_func( - request_func_input=test_input, session=session) + request_func_input=test_input, session=session + ) if output.success: pbar.close() return output except aiohttp.ClientConnectorError: pass - + # retry after a delay sleep_duration = min(retry_interval, remaining) if sleep_duration > 0: await asyncio.sleep(sleep_duration) - + return output diff --git a/vllm/benchmarks/lib/utils.py b/vllm/benchmarks/lib/utils.py index 0c27687dcf16..32e9db499007 100644 --- a/vllm/benchmarks/lib/utils.py +++ b/vllm/benchmarks/lib/utils.py @@ -8,9 +8,9 @@ from typing import Any -def convert_to_pytorch_benchmark_format(args: argparse.Namespace, - metrics: dict[str, list], - extra_info: dict[str, Any]) -> list: +def convert_to_pytorch_benchmark_format( + args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any] +) -> list: """ Save the benchmark results in the format used by PyTorch OSS 
benchmark with on metric per record @@ -38,12 +38,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, }, } - tp = record["benchmark"]["extra_info"]["args"].get( - "tensor_parallel_size") + tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size") # Save tensor_parallel_size parameter if it's part of the metadata if not tp and "tensor_parallel_size" in extra_info: - record["benchmark"]["extra_info"]["args"][ - "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = ( + extra_info["tensor_parallel_size"] + ) records.append(record) @@ -51,7 +51,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, class InfEncoder(json.JSONEncoder): - def clear_inf(self, o: Any): if isinstance(o, dict): return { diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 2371bbf27079..cad1d2eb2c6a 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -15,6 +15,7 @@ --request-rate <request_rate. Default inf> \ --num-prompts <num_prompts. Default 1000> """ + import argparse import asyncio import gc @@ -36,20 +37,22 @@ from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase -from vllm.benchmarks.datasets import (SampleRequest, add_dataset_parser, - get_samples) +from vllm.benchmarks.datasets import SampleRequest, add_dataset_parser, get_samples from vllm.benchmarks.lib.endpoint_request_func import ( - ASYNC_REQUEST_FUNCS, OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, - RequestFuncOutput) + ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, + RequestFuncInput, + RequestFuncOutput, +) from vllm.benchmarks.lib.ready_checker import wait_for_endpoint -from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, - write_to_json) +from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.transformers_utils.tokenizer import get_tokenizer MILLISECONDS_TO_SECONDS_CONVERSION = 1000 -TERM_PLOTLIB_AVAILABLE = ((importlib.util.find_spec("termplotlib") is not None) - and (shutil.which("gnuplot") is not None)) +TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) and ( + shutil.which("gnuplot") is not None +) class TaskType(Enum): @@ -110,8 +113,11 @@ def _get_current_request_rate( total_requests: int, request_rate: float, ) -> float: - if (ramp_up_strategy and ramp_up_start_rps is not None - and ramp_up_end_rps is not None): + if ( + ramp_up_strategy + and ramp_up_start_rps is not None + and ramp_up_end_rps is not None + ): progress = request_index / max(total_requests - 1, 1) if ramp_up_strategy == "linear": increase = (ramp_up_end_rps - ramp_up_start_rps) * progress @@ -158,10 +164,10 @@ async def get_request( The ending request rate for ramp-up. """ assert burstiness > 0, ( - f"A positive burstiness factor is expected, but given {burstiness}.") + f"A positive burstiness factor is expected, but given {burstiness}." 
+ ) # Convert to list to get length for ramp-up calculations - if isinstance(input_requests, - Iterable) and not isinstance(input_requests, list): + if isinstance(input_requests, Iterable) and not isinstance(input_requests, list): input_requests = list(input_requests) total_requests = len(input_requests) @@ -172,8 +178,13 @@ async def get_request( delay_ts = [] for request_index, request in enumerate(input_requests): current_request_rate = _get_current_request_rate( - ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps, - request_index, total_requests, request_rate) + ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate, + ) request_rates.append(current_request_rate) if current_request_rate == float("inf"): delay_ts.append(0) @@ -213,8 +224,8 @@ async def get_request( def calculate_metrics_for_embeddings( - outputs: list[RequestFuncOutput], dur_s: float, - selected_percentiles: list[float]) -> EmbedBenchmarkMetrics: + outputs: list[RequestFuncOutput], dur_s: float, selected_percentiles: list[float] +) -> EmbedBenchmarkMetrics: """Calculate the metrics for the embedding requests. Args: @@ -238,7 +249,8 @@ def calculate_metrics_for_embeddings( warnings.warn( "All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.", - stacklevel=2) + stacklevel=2, + ) metrics = EmbedBenchmarkMetrics( completed=completed, total_input=total_input, @@ -247,8 +259,9 @@ def calculate_metrics_for_embeddings( mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], ) return metrics @@ -294,8 +307,10 @@ def calculate_metrics( # bundled together # Note : this may inflate the output token count slightly output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + tokenizer( + outputs[i].generated_text, add_special_tokens=False + ).input_ids + ) actual_output_lens.append(output_len) total_input += input_requests[i].prompt_len tpot = 0 @@ -318,16 +333,19 @@ def calculate_metrics( if "ttft" in goodput_config_dict: valid_metrics.append(ttfts) - slo_values.append(goodput_config_dict["ttft"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "tpot" in goodput_config_dict: valid_metrics.append(all_tpots) - slo_values.append(goodput_config_dict["tpot"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "e2el" in goodput_config_dict: valid_metrics.append(e2els) - slo_values.append(goodput_config_dict["e2el"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) for req_metric in zip(*valid_metrics): is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) @@ -338,7 +356,8 @@ def calculate_metrics( warnings.warn( "All requests failed. 
This is likely due to a misconfiguration " "on the benchmark arguments.", - stacklevel=2) + stacklevel=2, + ) # Calculate max output tokens per second metric max_output_tokens_per_s = 0.0 @@ -347,10 +366,10 @@ def calculate_metrics( # Find the time range across all successful requests successful_outputs = [output for output in outputs if output.success] if successful_outputs: - min_start_time = min(output.start_time - for output in successful_outputs) - max_end_time = max(output.start_time + output.latency - for output in successful_outputs) + min_start_time = min(output.start_time for output in successful_outputs) + max_end_time = max( + output.start_time + output.latency for output in successful_outputs + ) # Create second buckets (ceiling to ensure we capture all time) duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1 @@ -374,8 +393,9 @@ def calculate_metrics( # Track concurrent requests for each second this request was active request_start_second = int(output.start_time - min_start_time) - request_end_second = int((output.start_time + output.latency) - - min_start_time) + request_end_second = int( + (output.start_time + output.latency) - min_start_time + ) for second in range(request_start_second, request_end_second + 1): concurrent_requests_per_second[second] += 1 @@ -384,18 +404,22 @@ def calculate_metrics( # concurrent requests if len(tokens_per_second) > 0: max_output_tokens_per_s = float(np.max(tokens_per_second)) - max_concurrent_requests = int( - np.max(concurrent_requests_per_second)) + max_concurrent_requests = int(np.max(concurrent_requests_per_second)) if TERM_PLOTLIB_AVAILABLE: import termplotlib as tpl + fig = tpl.figure() - fig.plot(np.arange(len(tokens_per_second)), - tokens_per_second, - title="Output tokens per second") - fig.plot(np.arange(len(concurrent_requests_per_second)), - concurrent_requests_per_second, - title="Concurrent requests per second") + fig.plot( + np.arange(len(tokens_per_second)), + tokens_per_second, + title="Output tokens per second", + ) + fig.plot( + np.arange(len(concurrent_requests_per_second)), + concurrent_requests_per_second, + title="Concurrent requests per second", + ) fig.show() else: print("tip: install termplotlib and gnuplot to plot the metrics") @@ -408,27 +432,31 @@ def calculate_metrics( request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) * - 1000, # ttfts is empty if streaming is not supported by the endpoint + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by the endpoint std_ttft_ms=np.std(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) - for p in selected_percentiles], + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles + ], mean_tpot_ms=np.mean(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) - for p in selected_percentiles], + percentiles_tpot_ms=[ + (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles + ], mean_itl_ms=np.mean(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) - for p in selected_percentiles], + percentiles_itl_ms=[ + (p, np.percentile(itls 
or 0, p) * 1000) for p in selected_percentiles + ], mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], max_output_tokens_per_s=max_output_tokens_per_s, max_concurrent_requests=max_concurrent_requests, ) @@ -462,8 +490,11 @@ async def benchmark( ramp_up_end_rps: Optional[int] = None, ready_check_timeout_sec: int = 600, ): - task_type = (TaskType.EMBEDDING if api_url.endswith("/v1/embeddings") else - TaskType.GENERATION) + task_type = ( + TaskType.EMBEDDING + if api_url.endswith("/v1/embeddings") + else TaskType.GENERATION + ) if endpoint_type in ASYNC_REQUEST_FUNCS: if task_type == TaskType.EMBEDDING: request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"] @@ -498,10 +529,14 @@ async def benchmark( input_requests[0].multi_modal_data, ) - assert (test_mm_content is None or isinstance(test_mm_content, dict) - or (isinstance(test_mm_content, list) - and all(isinstance(item, dict) for item in test_mm_content)) - ), "multi_modal_data must be a dict or list[dict]" + assert ( + test_mm_content is None + or isinstance(test_mm_content, dict) + or ( + isinstance(test_mm_content, list) + and all(isinstance(item, dict) for item in test_mm_content) + ) + ), "multi_modal_data must be a dict or list[dict]" test_input = RequestFuncInput( model=model_id, model_name=model_name, @@ -527,7 +562,8 @@ async def benchmark( raise ValueError( "Initial test run failed - Please make sure benchmark " "arguments are correctly specified. " - f"Error: {test_output.error}") + f"Error: {test_output.error}" + ) else: print("Initial test run completed. Starting main benchmark run...") else: @@ -536,33 +572,38 @@ async def benchmark( if lora_modules: # For each input request, choose a LoRA module at random. 
lora_modules = iter( - [random.choice(lora_modules) for _ in range(len(input_requests))]) + [random.choice(lora_modules) for _ in range(len(input_requests))] + ) if profile: print("Starting profiler...") - profile_input = RequestFuncInput(model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_headers=extra_headers, - extra_body=extra_body) - profile_output = await request_func(request_func_input=profile_input, - session=session) + profile_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_headers=extra_headers, + extra_body=extra_body, + ) + profile_output = await request_func( + request_func_input=profile_input, session=session + ) if profile_output.success: print("Profiler started") - distribution = ("Poisson process" - if burstiness == 1.0 else "Gamma distribution") + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" if ramp_up_strategy is not None: print(f"Traffic ramp-up strategy: {ramp_up_strategy}.") - print(f"Will increase RPS from {ramp_up_start_rps} to " - f"{ramp_up_end_rps} RPS over the duration of the benchmark.") + print( + f"Will increase RPS from {ramp_up_start_rps} to " + f"{ramp_up_end_rps} RPS over the duration of the benchmark." + ) else: print(f"Traffic request rate: {request_rate}") @@ -575,18 +616,17 @@ async def benchmark( # and it will simplify the code in limited_request_func. 
# semaphore = (asyncio.Semaphore(max_concurrency) # if max_concurrency else contextlib.nullcontext()) - semaphore = (asyncio.Semaphore(max_concurrency) - if max_concurrency else None) + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None async def limited_request_func(request_func_input, session, pbar): if semaphore is None: - return await request_func(request_func_input=request_func_input, - session=session, - pbar=pbar) + return await request_func( + request_func_input=request_func_input, session=session, pbar=pbar + ) async with semaphore: - return await request_func(request_func_input=request_func_input, - session=session, - pbar=pbar) + return await request_func( + request_func_input=request_func_input, session=session, pbar=pbar + ) benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] @@ -595,23 +635,27 @@ async def limited_request_func(request_func_input, session, pbar): last_int_rps = -1 if ramp_up_strategy is not None and ramp_up_start_rps is not None: last_int_rps = ramp_up_start_rps - rps_change_events.append({ - "rps": last_int_rps, - "timestamp": datetime.now().isoformat(), - }) + rps_change_events.append( + { + "rps": last_int_rps, + "timestamp": datetime.now().isoformat(), + } + ) async for request, current_request_rate in get_request( - input_requests, request_rate, burstiness, ramp_up_strategy, - ramp_up_start_rps, ramp_up_end_rps): + input_requests, + request_rate, + burstiness, + ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + ): if ramp_up_strategy is not None: current_int_rps = int(current_request_rate) if current_int_rps > last_int_rps: timestamp = datetime.now().isoformat() for rps_val in range(last_int_rps + 1, current_int_rps + 1): - rps_change_events.append({ - "rps": rps_val, - "timestamp": timestamp - }) + rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) last_int_rps = current_int_rps prompt, prompt_len, output_len, mm_content, request_id = ( request.prompt, @@ -641,9 +685,11 @@ async def limited_request_func(request_func_input, session, pbar): ) tasks.append( asyncio.create_task( - limited_request_func(request_func_input=request_func_input, - session=session, - pbar=pbar))) + limited_request_func( + request_func_input=request_func_input, session=session, pbar=pbar + ) + ) + ) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if pbar is not None: @@ -668,35 +714,48 @@ async def limited_request_func(request_func_input, session, pbar): ) actual_output_lens = 0 - print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) if max_concurrency is not None: - print("{:<40} {:<10}".format("Maximum request concurrency:", - max_concurrency)) - if request_rate != float('inf'): - print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", - request_rate)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", - benchmark_duration)) + print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency)) + if request_rate != float("inf"): + print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", request_rate)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) if isinstance(metrics, BenchmarkMetrics): - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) - 
print("{:<40} {:<10.2f}".format("Request throughput (req/s):", - metrics.request_throughput)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) if goodput_config_dict: - print("{:<40} {:<10.2f}".format("Request goodput (req/s):", - metrics.request_goodput)) + print( + "{:<40} {:<10.2f}".format( + "Request goodput (req/s):", metrics.request_goodput + ) + ) if isinstance(metrics, BenchmarkMetrics): - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) - print("{:<40} {:<10.2f}".format( - "Peak output token throughput (tok/s):", - metrics.max_output_tokens_per_s)) - print("{:<40} {:<10.2f}".format("Peak concurrent requests:", - metrics.max_concurrent_requests)) - print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", - metrics.total_token_throughput)) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Peak output token throughput (tok/s):", metrics.max_output_tokens_per_s + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Peak concurrent requests:", metrics.max_concurrent_requests + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total Token throughput (tok/s):", metrics.total_token_throughput + ) + ) if isinstance(metrics, BenchmarkMetrics): result = { @@ -705,8 +764,7 @@ async def limited_request_func(request_func_input, session, pbar): "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "request_goodput": - metrics.request_goodput if goodput_config_dict else None, + "request_goodput": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], @@ -744,30 +802,36 @@ def process_one_metric( # metric. 
if metric_attribute_name not in selected_percentile_metrics: return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) - print("{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"))) - print("{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"))) + print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"), + ) + ) result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms") + metrics, f"mean_{metric_attribute_name}_ms" + ) result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms") + metrics, f"median_{metric_attribute_name}_ms" + ) result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms") - for p, value in getattr(metrics, - f"percentiles_{metric_attribute_name}_ms"): + metrics, f"std_{metric_attribute_name}_ms" + ) + for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", - value)) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value if task_type == TaskType.GENERATION: process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") + process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)") process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") @@ -783,8 +847,9 @@ def process_one_metric( output_len=test_output_len, logprobs=logprobs, ) - profile_output = await request_func(request_func_input=profile_input, - session=session) + profile_output = await request_func( + request_func_input=profile_input, session=session + ) if profile_output.success: print("Profiler stopped") @@ -803,12 +868,14 @@ def check_goodput_args(args): raise ValueError( f"Invalid metric name found, {slo_name}: {slo_val}. " "The service level objective name should be one of " - f"{str(VALID_NAMES)}. ") + f"{str(VALID_NAMES)}. " + ) if slo_val < 0: raise ValueError( f"Invalid value found, {slo_name}: {slo_val}. " "The service level objective value should be " - "non-negative.") + "non-negative." + ) return goodput_config_dict @@ -821,31 +888,42 @@ def parse_goodput(slo_pairs): except ValueError as err: raise argparse.ArgumentTypeError( "Invalid format found for service level objectives. " - "Specify service level objectives for goodput as \"KEY:VALUE\" " + 'Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is a " - "number in milliseconds.") from err + "number in milliseconds." 
+ ) from err return goodput_config_dict -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any], - file_name: str) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any], file_name: str +) -> None: metrics = [ - "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", - "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", - "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + "median_ttft_ms", + "mean_ttft_ms", + "std_ttft_ms", + "p99_ttft_ms", + "mean_tpot_ms", + "median_tpot_ms", + "std_tpot_ms", + "p99_tpot_ms", + "median_itl_ms", + "mean_itl_ms", + "std_itl_ms", + "p99_itl_ms", ] # These raw data might be useful, but they are rather big. They can be added # later if needed ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] pt_records = convert_to_pytorch_benchmark_format( args=args, - metrics={k: [results[k]] - for k in metrics if k in results}, + metrics={k: [results[k]] for k in metrics if k in results}, extra_info={ k: results[k] - for k in results if k not in metrics and k not in ignored_metrics - }) + for k in results + if k not in metrics and k not in ignored_metrics + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" @@ -866,7 +944,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default="openai", choices=list(ASYNC_REQUEST_FUNCS.keys()), - help="The type of backend or endpoint to use for the benchmark." + help="The type of backend or endpoint to use for the benchmark.", ) parser.add_argument( "--base-url", @@ -888,9 +966,9 @@ def add_cli_args(parser: argparse.ArgumentParser): metavar="KEY=VALUE", nargs="*", help="Key-value pairs (e.g, --header x-additional-info=0.3.3) " - "for headers to be passed with each request. These headers override " \ - "per backend constants and values set via environment variable, and " \ - "will be overriden by other arguments (such as request ids)." + "for headers to be passed with each request. These headers override " + "per backend constants and values set via environment variable, and " + "will be overriden by other arguments (such as request ids).", ) parser.add_argument( "--max-concurrency", @@ -915,19 +993,20 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( "--logprobs", type=int, default=None, - help=("Number of logprobs-per-token to compute & return as part of " - "the request. If unspecified, then either (1) if beam search " - "is disabled, no logprobs are computed & a single dummy " - "logprob is returned for each token; or (2) if beam search " - "is enabled 1 logprob per token is computed"), + help=( + "Number of logprobs-per-token to compute & return as part of " + "the request. 
If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed" + ), ) parser.add_argument( "--request-rate", @@ -1010,32 +1089,34 @@ def add_cli_args(parser: argparse.ArgumentParser): "--ignore-eos", action="store_true", help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", + ) parser.add_argument( "--percentile-metrics", type=str, default="ttft,tpot,itl", help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. " - "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". ") + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ', + ) parser.add_argument( "--metric-percentiles", type=str, default="99", help="Comma-separated list of percentiles for selected metrics. " - "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " - "Default value is \"99\"." - "Use \"--percentile-metrics\" to select metrics.", + 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' + 'Default value is "99".' + 'Use "--percentile-metrics" to select metrics.', ) parser.add_argument( "--goodput", nargs="+", required=False, - help="Specify service level objectives for goodput as \"KEY:VALUE\" " + help='Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is in " - "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' "separated by spaces. Allowed request level metric names are " - "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + '"ttft", "tpot", "e2el". For more context on the definition of ' "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "and the blog: https://hao-ai-lab.github.io/blogs/distserve", ) @@ -1052,22 +1133,19 @@ def add_cli_args(parser: argparse.ArgumentParser): "--top-p", type=float, default=None, - help="Top-p sampling parameter. Only has effect on " - "openai-compatible backends.", + help="Top-p sampling parameter. Only has effect on openai-compatible backends.", ) sampling_group.add_argument( "--top-k", type=int, default=None, - help="Top-k sampling parameter. Only has effect on " - "openai-compatible backends.", + help="Top-k sampling parameter. Only has effect on openai-compatible backends.", ) sampling_group.add_argument( "--min-p", type=float, default=None, - help="Min-p sampling parameter. Only has effect on " - "openai-compatible backends.", + help="Min-p sampling parameter. Only has effect on openai-compatible backends.", ) sampling_group.add_argument( "--temperature", @@ -1100,29 +1178,34 @@ def add_cli_args(parser: argparse.ArgumentParser): ) parser.add_argument( - '--tokenizer-mode', + "--tokenizer-mode", type=str, default="auto", - choices=['auto', 'slow', 'mistral', 'custom'], + choices=["auto", "slow", "mistral", "custom"], help='The tokenizer mode.\n\n* "auto" will use the ' 'fast tokenizer if available.\n* "slow" will ' - 'always use the slow tokenizer. \n* ' + "always use the slow tokenizer. \n* " '"mistral" will always use the `mistral_common` tokenizer. 
\n*' - '"custom" will use --tokenizer to select the preregistered tokenizer.') - - parser.add_argument("--served-model-name", - type=str, - default=None, - help="The model name used in the API. " - "If not specified, the model name will be the " - "same as the ``--model`` argument. ") - - parser.add_argument("--lora-modules", - nargs='+', - default=None, - help="A subset of LoRA module names passed in when " - "launching the server. For each request, the " - "script chooses a LoRA module at random.") + '"custom" will use --tokenizer to select the preregistered tokenizer.', + ) + + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ", + ) + + parser.add_argument( + "--lora-modules", + nargs="+", + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.", + ) parser.add_argument( "--ramp-up-strategy", @@ -1132,7 +1215,8 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The ramp-up strategy. This would be used to " "ramp up the request rate from initial RPS to final " "RPS rate (specified by --ramp-up-start-rps and " - "--ramp-up-end-rps.) over the duration of the benchmark.") + "--ramp-up-end-rps.) over the duration of the benchmark.", + ) parser.add_argument( "--ramp-up-start-rps", type=int, @@ -1153,7 +1237,7 @@ def add_cli_args(parser: argparse.ArgumentParser): default=600, help="Maximum time to wait for the endpoint to become ready " "in seconds (default: 600 seconds / 10 minutes). If set to 0, " - "the ready check will be skipped." + "the ready check will be skipped.", ) @@ -1172,19 +1256,19 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: raise ValueError( "When using ramp-up, do not specify --request-rate. " "The request rate will be controlled by ramp-up parameters. " - "Please remove the --request-rate argument.") + "Please remove the --request-rate argument." + ) if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: raise ValueError( "When using --ramp-up-strategy, both --ramp-up-start-rps and " - "--ramp-up-end-rps must be specified") + "--ramp-up-end-rps must be specified" + ) if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: raise ValueError("Ramp-up start and end RPS must be non-negative") if args.ramp_up_start_rps > args.ramp_up_end_rps: raise ValueError("Ramp-up start RPS must be less than end RPS") - if (args.ramp_up_strategy == "exponential" - and args.ramp_up_start_rps == 0): - raise ValueError( - "For exponential ramp-up, the start RPS cannot be 0.") + if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0: + raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") label = args.label model_id = args.model @@ -1208,17 +1292,19 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: kvstring = item.split("=", 1) headers[kvstring[0].strip()] = kvstring[1].strip() else: - raise ValueError( - "Invalid header format. Please use KEY=VALUE format.") + raise ValueError("Invalid header format. 
Please use KEY=VALUE format.") - tokenizer = get_tokenizer(tokenizer_id, - tokenizer_mode=tokenizer_mode, - trust_remote_code=args.trust_remote_code) + tokenizer = get_tokenizer( + tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) if args.dataset_name is None: raise ValueError( "Please specify '--dataset-name' and the corresponding " - "'--dataset-path' if required.") + "'--dataset-path' if required." + ) # Load the dataset. input_requests = get_samples(args, tokenizer) @@ -1235,13 +1321,15 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: "frequency_penalty": args.frequency_penalty, "presence_penalty": args.presence_penalty, "repetition_penalty": args.repetition_penalty, - }.items() if v is not None + }.items() + if v is not None } # Sampling parameters are only supported by openai-compatible backend. if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: - raise ValueError("Sampling parameters are only supported by " - "openai-compatible backends.") + raise ValueError( + "Sampling parameters are only supported by openai-compatible backends." + ) if "temperature" not in sampling_params: sampling_params["temperature"] = 0.0 # Default to greedy decoding. @@ -1264,9 +1352,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: disable_tqdm=args.disable_tqdm, profile=args.profile, selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, @@ -1285,7 +1371,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: # Setup current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") result_json["date"] = current_dt - result_json["endpoint_type"] = args.backend # for backward compatibility + result_json["endpoint_type"] = args.backend # for backward compatibility result_json["backend"] = args.backend result_json["label"] = label result_json["model_id"] = model_id @@ -1300,11 +1386,13 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: result_json[kvstring[0].strip()] = kvstring[1].strip() else: raise ValueError( - "Invalid metadata format. Please use KEY=VALUE format.") + "Invalid metadata format. Please use KEY=VALUE format." 
+ ) # Traffic - result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") + result_json["request_rate"] = ( + args.request_rate if args.request_rate < float("inf") else "inf" + ) result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency @@ -1319,12 +1407,12 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if not args.save_detailed: # Remove fields with too many data points for field in [ - "input_lens", - "output_lens", - "ttfts", - "itls", - "generated_texts", - "errors", + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", ]: if field in result_json: del result_json[field] @@ -1334,8 +1422,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: # Save to file if args.save_result or args.append_result: base_model_id = model_id.split("/")[-1] - max_concurrency_str = (f"-concurrency{args.max_concurrency}" - if args.max_concurrency is not None else "") + max_concurrency_str = ( + f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None + else "" + ) label = label or args.backend if args.ramp_up_strategy is not None: file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa @@ -1346,9 +1437,9 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if args.result_dir: os.makedirs(args.result_dir, exist_ok=True) file_name = os.path.join(args.result_dir, file_name) - with open(file_name, - mode="a+" if args.append_result else "w", - encoding="utf-8") as outfile: + with open( + file_name, mode="a+" if args.append_result else "w", encoding="utf-8" + ) as outfile: # Append a newline. 
if args.append_result and outfile.tell() != 0: outfile.write("\n") diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 9e38e63a0883..181a3e196586 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Benchmark offline inference throughput.""" + import argparse import dataclasses import json @@ -13,18 +14,21 @@ import torch import uvloop from tqdm import tqdm -from transformers import (AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizerBase) - -from vllm.benchmarks.datasets import (AIMODataset, BurstGPTDataset, - ConversationDataset, - InstructCoderDataset, - PrefixRepetitionRandomDataset, - RandomDataset, SampleRequest, - ShareGPTDataset, SonnetDataset, - VisionArenaDataset) -from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format, - write_to_json) +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import ( + AIMODataset, + BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, + PrefixRepetitionRandomDataset, + RandomDataset, + SampleRequest, + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, +) +from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest @@ -41,23 +45,30 @@ def run_vllm( disable_detokenize: bool = False, ) -> tuple[float, Optional[list[RequestOutput]]]: from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests." + ) # Add the requests to the engine. 
prompts: list[Union[TextPrompt, TokensPrompt]] = [] sampling_params: list[SamplingParams] = [] for request in requests: prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + TokensPrompt( + prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data, + ) + if "prompt_token_ids" in request.prompt + else TextPrompt( + prompt=request.prompt, multi_modal_data=request.multi_modal_data + ) + ) sampling_params.append( SamplingParams( n=n, @@ -66,7 +77,8 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests: Optional[list[LoRARequest]] = None if engine_args.enable_lora: lora_requests = [request.lora_request for request in requests] @@ -78,10 +90,9 @@ def run_vllm( start = time.perf_counter() if do_profile: llm.start_profile() - outputs = llm.generate(prompts, - sampling_params, - lora_request=lora_requests, - use_tqdm=True) + outputs = llm.generate( + prompts, sampling_params, lora_request=lora_requests, use_tqdm=True + ) if do_profile: llm.stop_profile() end = time.perf_counter() @@ -101,7 +112,8 @@ def run_vllm( beam_width=n, max_tokens=output_len, ignore_eos=True, - )) + ), + ) if do_profile: llm.stop_profile() end = time.perf_counter() @@ -109,25 +121,29 @@ def run_vllm( def run_vllm_chat( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - do_profile: bool, - disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + do_profile: bool, + disable_detokenize: bool = False, +) -> tuple[float, list[RequestOutput]]: """ Run vLLM chat benchmark. This function is recommended ONLY for benchmarking multimodal models as it properly handles multimodal inputs and chat formatting. For non-multimodal models, use run_vllm() instead. """ from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of " - "prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests." 
+ ) prompts = [] sampling_params: list[SamplingParams] = [] @@ -141,7 +157,8 @@ def run_vllm_chat( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) start = time.perf_counter() if do_profile: llm.start_profile() @@ -162,7 +179,8 @@ async def run_vllm_async( ) -> float: from vllm import SamplingParams from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, + ) async with build_async_engine_client_from_engine_args( engine_args, @@ -170,11 +188,13 @@ async def run_vllm_async( ) as llm: model_config = await llm.get_model_config() assert all( - model_config.max_model_len >= (request.prompt_len + - request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") + model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests." + ) # Add the requests to the engine. prompts: list[Union[TextPrompt, TokensPrompt]] = [] @@ -182,11 +202,15 @@ async def run_vllm_async( lora_requests: list[Optional[LoRARequest]] = [] for request in requests: prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + TokensPrompt( + prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data, + ) + if "prompt_token_ids" in request.prompt + else TextPrompt( + prompt=request.prompt, multi_modal_data=request.multi_modal_data + ) + ) sampling_params.append( SamplingParams( n=n, @@ -195,19 +219,18 @@ async def run_vllm_async( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests.append(request.lora_request) generators = [] start = time.perf_counter() if do_profile: await llm.start_profile() - for i, (prompt, sp, - lr) in enumerate(zip(prompts, sampling_params, lora_requests)): - generator = llm.generate(prompt, - sp, - lora_request=lr, - request_id=f"test{i}") + for i, (prompt, sp, lr) in enumerate( + zip(prompts, sampling_params, lora_requests) + ): + generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: @@ -228,7 +251,8 @@ def run_hf( disable_detokenize: bool = False, ) -> float: llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + ) if llm.config.model_type == "llama": # To enable padding in the HF backend. tokenizer.pad_token = tokenizer.eos_token @@ -251,14 +275,15 @@ def run_hf( # Check if we can add more requests to the batch. next_prompt_len = requests[i + 1].prompt_len next_output_len = requests[i + 1].expected_output_len - if (max(max_prompt_len, next_prompt_len) + - max(max_output_len, next_output_len)) <= 2048: + if ( + max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len) + ) <= 2048: # We can add more requests to the batch. continue # Generate the sequences. 
- input_ids = tokenizer(batch, return_tensors="pt", - padding=True).input_ids + input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), do_sample=True, @@ -281,8 +306,9 @@ def run_hf( return end - start -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={ @@ -290,9 +316,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, "tokens_per_second": [results["tokens_per_second"]], }, extra_info={ - k: results[k] - for k in ["elapsed_time", "num_requests", "total_num_tokens"] - }) + k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" @@ -324,7 +350,8 @@ def get_requests(args, tokenizer): sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_name == "sonnet": assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") + "Tokenizer/model must have chat template for sonnet dataset." + ) dataset_cls = SonnetDataset sample_kwargs["prefix_len"] = args.prefix_len sample_kwargs["return_prompt_formatted"] = True @@ -333,21 +360,21 @@ def get_requests(args, tokenizer): elif args.dataset_name == "hf": if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: dataset_cls = VisionArenaDataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: dataset_cls = InstructCoderDataset - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_split"] = "train" elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: dataset_cls = ConversationDataset - common_kwargs['dataset_subset'] = args.hf_subset - common_kwargs['dataset_split'] = args.hf_split + common_kwargs["dataset_subset"] = args.hf_subset + common_kwargs["dataset_split"] = args.hf_split sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_cls = AIMODataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" elif args.dataset_name == "prefix_repetition": dataset_cls = PrefixRepetitionRandomDataset sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len @@ -373,8 +400,11 @@ def filter_requests_for_dp(requests, data_parallel_size): global_rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) data_parallel_rank = global_rank // (world_size // data_parallel_size) - return [r for i, r in enumerate(requests) - if i % data_parallel_size == data_parallel_rank] + return [ + r + for i, r in enumerate(requests) + if i % data_parallel_size == data_parallel_rank + ] def validate_args(args): @@ -387,7 +417,8 @@ def validate_args(args): warnings.warn( "The '--dataset' argument will be deprecated in the next release. 
" "Please use '--dataset-name' and '--dataset-path' instead.", - stacklevel=2) + stacklevel=2, + ) args.dataset_path = args.dataset if not getattr(args, "tokenizer", None): @@ -404,9 +435,8 @@ def validate_args(args): and not args.dataset_path and args.dataset_name not in {"prefix_repetition"} ): - print( - "When dataset path is not set, it will default to random dataset") - args.dataset_name = 'random' + print("When dataset path is not set, it will default to random dataset") + args.dataset_name = "random" if args.input_len is None: raise ValueError("input_len must be provided for a random dataset") @@ -414,41 +444,55 @@ def validate_args(args): # --hf-subset and --hf-split: only used # when dataset_name is 'hf' if args.dataset_name != "hf" and ( - getattr(args, "hf_subset", None) is not None - or getattr(args, "hf_split", None) is not None): - warnings.warn("--hf-subset and --hf-split will be ignored \ + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None + ): + warnings.warn( + "--hf-subset and --hf-split will be ignored \ since --dataset-name is not 'hf'.", - stacklevel=2) + stacklevel=2, + ) elif args.dataset_name == "hf": if args.dataset_path in ( - VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() - | ConversationDataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 - elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS - | AIMODataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm-chat", ( + f"{args.dataset_path} needs to use vllm-chat as the backend." + ) + elif args.dataset_path in ( + InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm", ( + f"{args.dataset_path} needs to use vllm as the backend." + ) else: - raise ValueError( - f"{args.dataset_path} is not supported by hf dataset.") + raise ValueError(f"{args.dataset_path} is not supported by hf dataset.") # --random-range-ratio: only used when dataset_name is 'random' - if args.dataset_name != 'random' and args.random_range_ratio is not None: - warnings.warn("--random-range-ratio will be ignored since \ + if args.dataset_name != "random" and args.random_range_ratio is not None: + warnings.warn( + "--random-range-ratio will be ignored since \ --dataset-name is not 'random'.", - stacklevel=2) + stacklevel=2, + ) # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not # set. 
- if args.dataset_name not in {"random", "sonnet", None - } and args.prefix_len is not None: - warnings.warn("--prefix-len will be ignored since --dataset-name\ + if ( + args.dataset_name not in {"random", "sonnet", None} + and args.prefix_len is not None + ): + warnings.warn( + "--prefix-len will be ignored since --dataset-name\ is not 'random', 'sonnet', or not set.", - stacklevel=2) + stacklevel=2, + ) # === LoRA Settings === if getattr(args, "enable_lora", False) and args.backend != "vllm": - raise ValueError( - "LoRA benchmarking is only supported for vLLM backend") + raise ValueError("LoRA benchmarking is only supported for vLLM backend") if getattr(args, "enable_lora", False) and args.lora_path is None: raise ValueError("LoRA path must be provided when enable_lora is True") @@ -458,8 +502,10 @@ def validate_args(args): if args.backend != "hf" and args.hf_max_batch_size is not None: raise ValueError("HF max batch size is only for HF backend.") - if args.backend in {"hf", "mii"} and getattr(args, "quantization", - None) is not None: + if ( + args.backend in {"hf", "mii"} + and getattr(args, "quantization", None) is not None + ): raise ValueError("Quantization is only for vLLM backend.") if args.backend == "mii" and args.dtype != "auto": @@ -467,12 +513,11 @@ def validate_args(args): if args.backend == "mii" and args.n != 1: raise ValueError("n must be 1 for MII backend.") if args.backend == "mii" and args.tokenizer != args.model: - raise ValueError( - "Tokenizer must be the same as the model for MII backend.") + raise ValueError("Tokenizer must be the same as the model for MII backend.") if args.data_parallel_size > 1 and ( - args.distributed_executor_backend != "external_launcher" - or args.async_engine): + args.distributed_executor_backend != "external_launcher" or args.async_engine + ): # --data-parallel is not supported fully. # Old issue: https://github.com/vllm-project/vllm/issues/16222 # Currently we only support data parallel with external launcher @@ -485,19 +530,19 @@ def validate_args(args): def add_cli_args(parser: argparse.ArgumentParser): - parser.add_argument("--backend", - type=str, - choices=["vllm", "hf", "mii", "vllm-chat"], - default="vllm") + parser.add_argument( + "--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm", + ) parser.add_argument( "--dataset-name", type=str, - choices=[ - "sharegpt", "random", "sonnet", "burstgpt", "hf", - "prefix_repetition" - ], + choices=["sharegpt", "random", "sonnet", "burstgpt", "hf", "prefix_repetition"], help="Name of the dataset to benchmark on.", - default="sharegpt") + default="sharegpt", + ) parser.add_argument( "--dataset", type=str, @@ -505,57 +550,70 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Path to the ShareGPT dataset, will be deprecated in\ the next release. The dataset is expected to " "be a json in form of list[dict[..., conversations: " - "list[dict[..., value: <prompt_or_response>]]]]") - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the dataset") - parser.add_argument("--input-len", - type=int, - default=None, - help="Input prompt length for each request") - parser.add_argument("--output-len", - type=int, - default=None, - help="Output length for each request. 
Overrides the " - "output length from the dataset.") - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.") - parser.add_argument("--hf-max-batch-size", - type=int, - default=None, - help="Maximum batch size for HF backend.") + "list[dict[..., value: <prompt_or_response>]]]]", + ) + parser.add_argument( + "--dataset-path", type=str, default=None, help="Path to the dataset" + ) + parser.add_argument( + "--input-len", + type=int, + default=None, + help="Input prompt length for each request", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.", + ) parser.add_argument( - '--output-json', + "--n", type=int, default=1, help="Number of generated sequences per prompt." + ) + parser.add_argument( + "--num-prompts", type=int, default=1000, help="Number of prompts to process." + ) + parser.add_argument( + "--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.", + ) + parser.add_argument( + "--output-json", type=str, default=None, - help='Path to save the throughput results in JSON format.') - parser.add_argument("--async-engine", - action='store_true', - default=False, - help="Use vLLM async engine rather than LLM class.") - parser.add_argument("--disable-frontend-multiprocessing", - action='store_true', - default=False, - help="Disable decoupled async engine frontend.") + help="Path to save the throughput results in JSON format.", + ) + parser.add_argument( + "--async-engine", + action="store_true", + default=False, + help="Use vLLM async engine rather than LLM class.", + ) + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + default=False, + help="Disable decoupled async engine frontend.", + ) parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize the response (i.e. do not include " - "detokenization time in the measurement)")) + help=( + "Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)" + ), + ) # LoRA parser.add_argument( "--lora-path", type=str, default=None, help="Path to the lora adapters to use. This can be an absolute path, " - "a relative path, or a Hugging Face model identifier.") + "a relative path, or a Hugging Face model identifier.", + ) parser.add_argument( "--prefix-len", type=int, @@ -575,24 +633,24 @@ def add_cli_args(parser: argparse.ArgumentParser): ) # hf dtaset - parser.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - parser.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") + parser.add_argument( + "--hf-subset", type=str, default=None, help="Subset of the HF dataset." + ) + parser.add_argument( + "--hf-split", type=str, default=None, help="Split of the HF dataset." + ) parser.add_argument( "--profile", action="store_true", default=False, help="Use Torch Profiler. 
The env variable " - "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.") + "VLLM_TORCH_PROFILER_DIR must be set to enable profiler.", + ) # prefix repetition dataset prefix_repetition_group = parser.add_argument_group( - "prefix repetition dataset options") + "prefix repetition dataset options" + ) prefix_repetition_group.add_argument( "--prefix-repetition-prefix-len", type=int, @@ -634,10 +692,10 @@ def main(args: argparse.Namespace): random.seed(args.seed) # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) + args.tokenizer, trust_remote_code=args.trust_remote_code + ) requests = get_requests(args, tokenizer) - is_multi_modal = any(request.multi_modal_data is not None - for request in requests) + is_multi_modal = any(request.multi_modal_data is not None for request in requests) request_outputs: Optional[list[RequestOutput]] = None if args.backend == "vllm": if args.async_engine: @@ -649,24 +707,37 @@ def main(args: argparse.Namespace): disable_frontend_multiprocessing=args.disable_frontend_multiprocessing, disable_detokenize=args.disable_detokenize, do_profile=args.profile, - )) + ) + ) else: elapsed_time, request_outputs = run_vllm( - requests, args.n, EngineArgs.from_cli_args(args), + requests, + args.n, + EngineArgs.from_cli_args(args), disable_detokenize=args.disable_detokenize, - do_profile=args.profile) + do_profile=args.profile, + ) elif args.backend == "hf": assert args.tensor_parallel_size == 1 if args.profile: - raise NotImplementedError( - "Profiling not implemented yet for backend='hf'.") - elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.hf_max_batch_size, args.trust_remote_code, - args.disable_detokenize) + raise NotImplementedError("Profiling not implemented yet for backend='hf'.") + elapsed_time = run_hf( + requests, + args.model, + tokenizer, + args.n, + args.hf_max_batch_size, + args.trust_remote_code, + args.disable_detokenize, + ) elif args.backend == "vllm-chat": elapsed_time, request_outputs = run_vllm_chat( - requests, args.n, EngineArgs.from_cli_args(args), - disable_detokenize=args.disable_detokenize, do_profile=args.profile) + requests, + args.n, + EngineArgs.from_cli_args(args), + disable_detokenize=args.disable_detokenize, + do_profile=args.profile, + ) else: raise ValueError(f"Unknown backend: {args.backend}") @@ -678,28 +749,31 @@ def main(args: argparse.Namespace): for ro in request_outputs: if not isinstance(ro, RequestOutput): continue - total_prompt_tokens += len( - ro.prompt_token_ids) if ro.prompt_token_ids else 0 - total_output_tokens += sum( - len(o.token_ids) for o in ro.outputs if o) + total_prompt_tokens += ( + len(ro.prompt_token_ids) if ro.prompt_token_ids else 0 + ) + total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o) total_num_tokens = total_prompt_tokens + total_output_tokens else: - total_num_tokens = sum(r.prompt_len + r.expected_output_len - for r in requests) + total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests) total_output_tokens = sum(r.expected_output_len for r in requests) total_prompt_tokens = total_num_tokens - total_output_tokens if is_multi_modal and args.backend != "vllm-chat": - print("\033[91mWARNING\033[0m: Multi-modal request with " - f"{args.backend} backend detected. The " - "following metrics are not accurate because image tokens are not" - " counted. 
See vllm-project/vllm/issues/9778 for details.") + print( + "\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details." + ) # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. # vllm-chat backend counts the image tokens now - print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s" + ) print(f"Total num prompt tokens: {total_prompt_tokens}") print(f"Total num output tokens: {total_output_tokens}") diff --git a/vllm/collect_env.py b/vllm/collect_env.py index fb9d3657790c..4ca0852e3998 100644 --- a/vllm/collect_env.py +++ b/vllm/collect_env.py @@ -9,6 +9,7 @@ import os import subprocess import sys + # Unlike the rest of the PyTorch this file must be python2 compliant. # This script outputs relevant system environment info # Run it with `python collect_env.py` or `python -m torch.utils.collect_env` @@ -20,45 +21,47 @@ try: import torch + TORCH_AVAILABLE = True except (ImportError, NameError, AttributeError, OSError): TORCH_AVAILABLE = False # System Environment Information SystemEnv = namedtuple( - 'SystemEnv', + "SystemEnv", [ - 'torch_version', - 'is_debug_build', - 'cuda_compiled_version', - 'gcc_version', - 'clang_version', - 'cmake_version', - 'os', - 'libc_version', - 'python_version', - 'python_platform', - 'is_cuda_available', - 'cuda_runtime_version', - 'cuda_module_loading', - 'nvidia_driver_version', - 'nvidia_gpu_models', - 'cudnn_version', - 'pip_version', # 'pip' or 'pip3' - 'pip_packages', - 'conda_packages', - 'hip_compiled_version', - 'hip_runtime_version', - 'miopen_runtime_version', - 'caching_allocator_config', - 'is_xnnpack_available', - 'cpu_info', - 'rocm_version', # vllm specific field - 'vllm_version', # vllm specific field - 'vllm_build_flags', # vllm specific field - 'gpu_topo', # vllm specific field - 'env_vars', - ]) + "torch_version", + "is_debug_build", + "cuda_compiled_version", + "gcc_version", + "clang_version", + "cmake_version", + "os", + "libc_version", + "python_version", + "python_platform", + "is_cuda_available", + "cuda_runtime_version", + "cuda_module_loading", + "nvidia_driver_version", + "nvidia_gpu_models", + "cudnn_version", + "pip_version", # 'pip' or 'pip3' + "pip_packages", + "conda_packages", + "hip_compiled_version", + "hip_runtime_version", + "miopen_runtime_version", + "caching_allocator_config", + "is_xnnpack_available", + "cpu_info", + "rocm_version", # vllm specific field + "vllm_version", # vllm specific field + "vllm_build_flags", # vllm specific field + "gpu_topo", # vllm specific field + "env_vars", + ], +) DEFAULT_CONDA_PATTERNS = { "torch", @@ -98,18 +101,17 @@ def run(command): """Return (return-code, stdout, stderr).""" shell = True if type(command) is str else False try: - p = subprocess.Popen(command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=shell) + p = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell + ) raw_output, raw_err = p.communicate() rc = p.returncode - if get_platform() == 'win32': - enc = 'oem' + if get_platform() == "win32": + enc = "oem" else: enc = 
locale.getpreferredencoding() output = raw_output.decode(enc) - if command == 'nvidia-smi topo -m': + if command == "nvidia-smi topo -m": # don't remove the leading whitespace of `nvidia-smi topo -m` # because they are meaningful output = output.rstrip() @@ -120,7 +122,7 @@ def run(command): except FileNotFoundError: cmd_str = command if isinstance(command, str) else command[0] - return 127, '', f"Command not found: {cmd_str}" + return 127, "", f"Command not found: {cmd_str}" def run_and_read_all(run_lambda, command): @@ -147,49 +149,54 @@ def run_and_return_first_line(run_lambda, command): rc, out, _ = run_lambda(command) if rc != 0: return None - return out.split('\n')[0] + return out.split("\n")[0] def get_conda_packages(run_lambda, patterns=None): if patterns is None: patterns = DEFAULT_CONDA_PATTERNS - conda = os.environ.get('CONDA_EXE', 'conda') - out = run_and_read_all(run_lambda, [conda, 'list']) + conda = os.environ.get("CONDA_EXE", "conda") + out = run_and_read_all(run_lambda, [conda, "list"]) if out is None: return out - return "\n".join(line for line in out.splitlines() - if not line.startswith("#") and any(name in line - for name in patterns)) + return "\n".join( + line + for line in out.splitlines() + if not line.startswith("#") and any(name in line for name in patterns) + ) def get_gcc_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)') + return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)") def get_clang_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'clang --version', - r'clang version (.*)') + return run_and_parse_first_match( + run_lambda, "clang --version", r"clang version (.*)" + ) def get_cmake_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'cmake --version', - r'cmake (.*)') + return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)") def get_nvidia_driver_version(run_lambda): - if get_platform() == 'darwin': - cmd = 'kextstat | grep -i cuda' - return run_and_parse_first_match(run_lambda, cmd, - r'com[.]nvidia[.]CUDA [(](.*?)[)]') + if get_platform() == "darwin": + cmd = "kextstat | grep -i cuda" + return run_and_parse_first_match( + run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]" + ) smi = get_nvidia_smi() - return run_and_parse_first_match(run_lambda, smi, - r'Driver Version: (.*?) ') + return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) 
") def get_gpu_info(run_lambda): - if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr( - torch.version, 'hip') and torch.version.hip is not None): + if get_platform() == "darwin" or ( + TORCH_AVAILABLE + and hasattr(torch.version, "hip") + and torch.version.hip is not None + ): if TORCH_AVAILABLE and torch.cuda.is_available(): if torch.version.hip is not None: prop = torch.cuda.get_device_properties(0) @@ -202,43 +209,42 @@ def get_gpu_info(run_lambda): return torch.cuda.get_device_name(None) + gcnArch return None smi = get_nvidia_smi() - uuid_regex = re.compile(r' \(UUID: .+?\)') - rc, out, _ = run_lambda(smi + ' -L') + uuid_regex = re.compile(r" \(UUID: .+?\)") + rc, out, _ = run_lambda(smi + " -L") if rc != 0: return None # Anonymize GPUs by removing their UUID - return re.sub(uuid_regex, '', out) + return re.sub(uuid_regex, "", out) def get_running_cuda_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'nvcc --version', - r'release .+ V(.*)') + return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)") def get_cudnn_version(run_lambda): """Return a list of libcudnn.so; it's hard to tell which one is being used.""" - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%") - where_cmd = os.path.join(system_root, 'System32', 'where') + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%") + where_cmd = os.path.join(system_root, "System32", "where") cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path) - elif get_platform() == 'darwin': + elif get_platform() == "darwin": # CUDA libraries and drivers can be found in /usr/local/cuda/. See # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
- cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*" else: cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' rc, out, _ = run_lambda(cudnn_cmd) # find will return 1 if there are permission errors or if not found if len(out) == 0 or (rc != 1 and rc != 0): - l = os.environ.get('CUDNN_LIBRARY') + l = os.environ.get("CUDNN_LIBRARY") if l is not None and os.path.isfile(l): return os.path.realpath(l) return None files_set = set() - for fn in out.split('\n'): + for fn in out.split("\n"): fn = os.path.realpath(fn) # eliminate symbolic links if os.path.isfile(fn): files_set.add(fn) @@ -248,20 +254,20 @@ def get_cudnn_version(run_lambda): files = sorted(files_set) if len(files) == 1: return files[0] - result = '\n'.join(files) - return 'Probably one of the following:\n{}'.format(result) + result = "\n".join(files) + return "Probably one of the following:\n{}".format(result) def get_nvidia_smi(): # Note: nvidia-smi is currently available only on Windows and Linux - smi = 'nvidia-smi' - if get_platform() == 'win32': - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - program_files_root = os.environ.get('PROGRAMFILES', - 'C:\\Program Files') - legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', - 'NVSMI', smi) - new_path = os.path.join(system_root, 'System32', smi) + smi = "nvidia-smi" + if get_platform() == "win32": + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files") + legacy_path = os.path.join( + program_files_root, "NVIDIA Corporation", "NVSMI", smi + ) + new_path = os.path.join(system_root, "System32", smi) smis = [new_path, legacy_path] for candidate_smi in smis: if os.path.exists(candidate_smi): @@ -272,8 +278,9 @@ def get_nvidia_smi(): def get_rocm_version(run_lambda): """Returns the ROCm version if available, otherwise 'N/A'.""" - return run_and_parse_first_match(run_lambda, 'hipcc --version', - r'HIP version: (\S+)') + return run_and_parse_first_match( + run_lambda, "hipcc --version", r"HIP version: (\S+)" + ) def get_vllm_version(): @@ -282,12 +289,12 @@ def get_vllm_version(): if __version__ == "dev": return "N/A (dev)" version_str = __version_tuple__[-1] - if isinstance(version_str, str) and version_str.startswith('g'): + if isinstance(version_str, str) and version_str.startswith("g"): # it's a dev build - if '.' in version_str: + if "." in version_str: # it's a dev build containing local changes - git_sha = version_str.split('.')[0][1:] - date = version_str.split('.')[-1][1:] + git_sha = version_str.split(".")[0][1:] + date = version_str.split(".")[-1][1:] return f"{__version__} (git sha: {git_sha}, date: {date})" else: # it's a dev build without local changes @@ -298,19 +305,19 @@ def get_vllm_version(): def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. 
- return 'CUDA Archs: {}; ROCm: {}'.format( - os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'), - 'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled', + return "CUDA Archs: {}; ROCm: {}".format( + os.environ.get("TORCH_CUDA_ARCH_LIST", "Not Set"), + "Enabled" if os.environ.get("ROCM_HOME") else "Disabled", ) def get_gpu_topo(run_lambda): output = None - if get_platform() == 'linux': - output = run_and_read_all(run_lambda, 'nvidia-smi topo -m') + if get_platform() == "linux": + output = run_and_read_all(run_lambda, "nvidia-smi topo -m") if output is None: - output = run_and_read_all(run_lambda, 'rocm-smi --showtopo') + output = run_and_read_all(run_lambda, "rocm-smi --showtopo") return output @@ -392,17 +399,17 @@ def get_gpu_topo(run_lambda): def get_cpu_info(run_lambda): - rc, out, err = 0, '', '' - if get_platform() == 'linux': - rc, out, err = run_lambda('lscpu') - elif get_platform() == 'win32': + rc, out, err = 0, "", "" + if get_platform() == "linux": + rc, out, err = run_lambda("lscpu") + elif get_platform() == "win32": rc, out, err = run_lambda( - 'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ - CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE' + "wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ + CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE" ) - elif get_platform() == 'darwin': + elif get_platform() == "darwin": rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") - cpu_info = 'None' + cpu_info = "None" if rc == 0: cpu_info = out else: @@ -411,67 +418,69 @@ def get_cpu_info(run_lambda): def get_platform(): - if sys.platform.startswith('linux'): - return 'linux' - elif sys.platform.startswith('win32'): - return 'win32' - elif sys.platform.startswith('cygwin'): - return 'cygwin' - elif sys.platform.startswith('darwin'): - return 'darwin' + if sys.platform.startswith("linux"): + return "linux" + elif sys.platform.startswith("win32"): + return "win32" + elif sys.platform.startswith("cygwin"): + return "cygwin" + elif sys.platform.startswith("darwin"): + return "darwin" else: return sys.platform def get_mac_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', - r'(.*)') + return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)") def get_windows_version(run_lambda): - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') - findstr_cmd = os.path.join(system_root, 'System32', 'findstr') + system_root = os.environ.get("SYSTEMROOT", "C:\\Windows") + wmic_cmd = os.path.join(system_root, "System32", "Wbem", "wmic") + findstr_cmd = os.path.join(system_root, "System32", "findstr") return run_and_read_all( - run_lambda, - '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) + run_lambda, "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd) + ) def get_lsb_version(run_lambda): - return run_and_parse_first_match(run_lambda, 'lsb_release -a', - r'Description:\t(.*)') + return run_and_parse_first_match( + run_lambda, "lsb_release -a", r"Description:\t(.*)" + ) def check_release_file(run_lambda): - return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', - r'PRETTY_NAME="(.*)"') + return run_and_parse_first_match( + run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"' + ) def get_os(run_lambda): from platform import machine + platform = get_platform() - if platform == 'win32' or platform == 'cygwin': + 
if platform == "win32" or platform == "cygwin": return get_windows_version(run_lambda) - if platform == 'darwin': + if platform == "darwin": version = get_mac_version(run_lambda) if version is None: return None - return 'macOS {} ({})'.format(version, machine()) + return "macOS {} ({})".format(version, machine()) - if platform == 'linux': + if platform == "linux": # Ubuntu/Debian based desc = get_lsb_version(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) # Try reading /etc/*-release desc = check_release_file(run_lambda) if desc is not None: - return '{} ({})'.format(desc, machine()) + return "{} ({})".format(desc, machine()) - return '{} ({})'.format(platform, machine()) + return "{} ({})".format(platform, machine()) # Unknown platform return platform @@ -479,23 +488,25 @@ def get_os(run_lambda): def get_python_platform(): import platform + return platform.platform() def get_libc_version(): import platform - if get_platform() != 'linux': - return 'N/A' - return '-'.join(platform.libc_ver()) + + if get_platform() != "linux": + return "N/A" + return "-".join(platform.libc_ver()) def is_uv_venv(): if os.environ.get("UV"): return True - pyvenv_cfg_path = os.path.join(sys.prefix, 'pyvenv.cfg') + pyvenv_cfg_path = os.path.join(sys.prefix, "pyvenv.cfg") if os.path.exists(pyvenv_cfg_path): - with open(pyvenv_cfg_path, 'r') as f: - return any(line.startswith('uv = ') for line in f) + with open(pyvenv_cfg_path, "r") as f: + return any(line.startswith("uv = ") for line in f) return False @@ -507,13 +518,14 @@ def get_pip_packages(run_lambda, patterns=None): def run_with_pip(): try: import importlib.util - pip_spec = importlib.util.find_spec('pip') + + pip_spec = importlib.util.find_spec("pip") pip_available = pip_spec is not None except ImportError: pip_available = False if pip_available: - cmd = [sys.executable, '-mpip', 'list', '--format=freeze'] + cmd = [sys.executable, "-mpip", "list", "--format=freeze"] elif is_uv_venv(): print("uv is set") cmd = ["uv", "pip", "list", "--format=freeze"] @@ -523,23 +535,24 @@ def run_with_pip(): ) out = run_and_read_all(run_lambda, cmd) - return "\n".join(line for line in out.splitlines() - if any(name in line for name in patterns)) + return "\n".join( + line for line in out.splitlines() if any(name in line for name in patterns) + ) - pip_version = 'pip3' if sys.version[0] == '3' else 'pip' + pip_version = "pip3" if sys.version[0] == "3" else "pip" out = run_with_pip() return pip_version, out def get_cachingallocator_config(): - ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '') + ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") return ca_config def get_cuda_module_loading_config(): if TORCH_AVAILABLE and torch.cuda.is_available(): torch.cuda.init() - config = os.environ.get('CUDA_MODULE_LOADING', '') + config = os.environ.get("CUDA_MODULE_LOADING", "") return config else: return "N/A" @@ -548,17 +561,26 @@ def get_cuda_module_loading_config(): def is_xnnpack_available(): if TORCH_AVAILABLE: import torch.backends.xnnpack - return str( - torch.backends.xnnpack.enabled) # type: ignore[attr-defined] + + return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined] else: return "N/A" def get_env_vars(): - env_vars = '' - secret_terms = ('secret', 'token', 'api', 'access', 'password') - report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN", - "OMP_", "MKL_", "NVIDIA") + env_vars = "" + secret_terms = ("secret", "token", "api", "access", "password") + 
report_prefix = ( + "TORCH", + "NCCL", + "PYTORCH", + "CUDA", + "CUBLAS", + "CUDNN", + "OMP_", + "MKL_", + "NVIDIA", + ) for k, v in os.environ.items(): if any(term in k.lower() for term in secret_terms): continue @@ -579,23 +601,24 @@ def get_env_info(): debug_mode_str = str(torch.version.debug) cuda_available_str = str(torch.cuda.is_available()) cuda_version_str = torch.version.cuda - if not hasattr(torch.version, - 'hip') or torch.version.hip is None: # cuda version - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + if ( + not hasattr(torch.version, "hip") or torch.version.hip is None + ): # cuda version + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" else: # HIP version def get_version_or_na(cfg, prefix): _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s] - return _lst[0] if _lst else 'N/A' + return _lst[0] if _lst else "N/A" - cfg = torch._C._show_config().split('\n') - hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime') - miopen_runtime_version = get_version_or_na(cfg, 'MIOpen') - cuda_version_str = 'N/A' + cfg = torch._C._show_config().split("\n") + hip_runtime_version = get_version_or_na(cfg, "HIP Runtime") + miopen_runtime_version = get_version_or_na(cfg, "MIOpen") + cuda_version_str = "N/A" hip_compiled_version = torch.version.hip else: - version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A' - hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A' + version_str = debug_mode_str = cuda_available_str = cuda_version_str = "N/A" + hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A" sys_version = sys.version.replace("\n", " ") @@ -609,9 +632,9 @@ def get_version_or_na(cfg, prefix): return SystemEnv( torch_version=version_str, is_debug_build=debug_mode_str, - python_version='{} ({}-bit runtime)'.format( - sys_version, - sys.maxsize.bit_length() + 1), + python_version="{} ({}-bit runtime)".format( + sys_version, sys.maxsize.bit_length() + 1 + ), python_platform=get_python_platform(), is_cuda_available=cuda_available_str, cuda_compiled_version=cuda_version_str, @@ -715,15 +738,14 @@ def get_version_or_na(cfg, prefix): def pretty_str(envinfo): - - def replace_nones(dct, replacement='Could not collect'): + def replace_nones(dct, replacement="Could not collect"): for key in dct.keys(): if dct[key] is not None: continue dct[key] = replacement return dct - def replace_bools(dct, true='Yes', false='No'): + def replace_bools(dct, true="Yes", false="No"): for key in dct.keys(): if dct[key] is True: dct[key] = true @@ -731,43 +753,48 @@ def replace_bools(dct, true='Yes', false='No'): dct[key] = false return dct - def prepend(text, tag='[prepend]'): - lines = text.split('\n') + def prepend(text, tag="[prepend]"): + lines = text.split("\n") updated_lines = [tag + line for line in lines] - return '\n'.join(updated_lines) + return "\n".join(updated_lines) - def replace_if_empty(text, replacement='No relevant packages'): + def replace_if_empty(text, replacement="No relevant packages"): if text is not None and len(text) == 0: return replacement return text def maybe_start_on_next_line(string): # If `string` is multiline, prepend a \n to it. 
- if string is not None and len(string.split('\n')) > 1: - return '\n{}\n'.format(string) + if string is not None and len(string.split("\n")) > 1: + return "\n{}\n".format(string) return string mutable_dict = envinfo._asdict() # If nvidia_gpu_models is multiline, start on the next line - mutable_dict['nvidia_gpu_models'] = \ - maybe_start_on_next_line(envinfo.nvidia_gpu_models) + mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line( + envinfo.nvidia_gpu_models + ) # If the machine doesn't have CUDA, report some fields as 'No CUDA' dynamic_cuda_fields = [ - 'cuda_runtime_version', - 'nvidia_gpu_models', - 'nvidia_driver_version', + "cuda_runtime_version", + "nvidia_gpu_models", + "nvidia_driver_version", ] - all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] - all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None - for field in dynamic_cuda_fields) - if TORCH_AVAILABLE and not torch.cuda.is_available( - ) and all_dynamic_cuda_fields_missing: + all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields + ) + if ( + TORCH_AVAILABLE + and not torch.cuda.is_available() + and all_dynamic_cuda_fields_missing + ): for field in all_cuda_fields: - mutable_dict[field] = 'No CUDA' + mutable_dict[field] = "No CUDA" if envinfo.cuda_compiled_version is None: - mutable_dict['cuda_compiled_version'] = 'None' + mutable_dict["cuda_compiled_version"] = "None" # Replace True with Yes, False with No mutable_dict = replace_bools(mutable_dict) @@ -776,20 +803,20 @@ def maybe_start_on_next_line(string): mutable_dict = replace_nones(mutable_dict) # If either of these are '', replace with 'No relevant packages' - mutable_dict['pip_packages'] = replace_if_empty( - mutable_dict['pip_packages']) - mutable_dict['conda_packages'] = replace_if_empty( - mutable_dict['conda_packages']) + mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"]) + mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"]) # Tag conda and pip packages with a prefix # If they were previously None, they'll show up as ie '[conda] Could not collect' - if mutable_dict['pip_packages']: - mutable_dict['pip_packages'] = prepend( - mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version)) - if mutable_dict['conda_packages']: - mutable_dict['conda_packages'] = prepend( - mutable_dict['conda_packages'], '[conda] ') - mutable_dict['cpu_info'] = envinfo.cpu_info + if mutable_dict["pip_packages"]: + mutable_dict["pip_packages"] = prepend( + mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version) + ) + if mutable_dict["conda_packages"]: + mutable_dict["conda_packages"] = prepend( + mutable_dict["conda_packages"], "[conda] " + ) + mutable_dict["cpu_info"] = envinfo.cpu_info return env_info_fmt.format(**mutable_dict) @@ -802,22 +829,29 @@ def main(): output = get_pretty_env_info() print(output) - if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr( - torch.utils, '_crash_handler'): + if ( + TORCH_AVAILABLE + and hasattr(torch, "utils") + and hasattr(torch.utils, "_crash_handler") + ): minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR if sys.platform == "linux" and os.path.exists(minidump_dir): dumps = [ - os.path.join(minidump_dir, dump) - for dump in os.listdir(minidump_dir) + os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir) ] latest = max(dumps, key=os.path.getctime) ctime = os.path.getctime(latest) creation_time = 
datetime.datetime.fromtimestamp(ctime).strftime( - '%Y-%m-%d %H:%M:%S') - msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \ - "if this is related to your bug please include it when you file a report ***" + "%Y-%m-%d %H:%M:%S" + ) + msg = ( + "\n*** Detected a minidump at {} created on {}, ".format( + latest, creation_time + ) + + "if this is related to your bug please include it when you file a report ***" + ) print(msg, file=sys.stderr) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py index 74462fb37ca9..7448bb122152 100644 --- a/vllm/compilation/activation_quant_fusion.py +++ b/vllm/compilation/activation_quant_fusion.py @@ -5,14 +5,21 @@ import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only, - register_replacement) +from torch._inductor.pattern_matcher import ( + PatternMatcherPass, + fwd_only, + register_replacement, +) from torch._ops import OpOverload from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) + QuantKey, + kFp8StaticTensorSym, + kNvfp4Quant, + kStaticTensorScale, +) from vllm.platforms import current_platform from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 @@ -29,11 +36,11 @@ FUSED_OPS: dict[QuantKey, OpOverload] = { kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501 } -silu_and_mul_nvfp4_quant_supported = (current_platform.is_cuda() and hasattr( - torch.ops._C, "silu_and_mul_nvfp4_quant")) +silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr( + torch.ops._C, "silu_and_mul_nvfp4_quant" +) if silu_and_mul_nvfp4_quant_supported: - FUSED_OPS[ - kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 + FUSED_OPS[kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 class ActivationQuantPattern(ABC): @@ -49,16 +56,18 @@ def __init__( self.quant_key = quant_key self.quant_dtype = quant_key.dtype - assert self.quant_key in QUANT_OPS, \ + assert self.quant_key in QUANT_OPS, ( f"unsupported quantization scheme {self.quant_key}" + ) self.QUANT_OP = QUANT_OPS[self.quant_key] - assert self.quant_key in FUSED_OPS, \ + assert self.quant_key in FUSED_OPS, ( f"unsupported fusion scheme {self.quant_key}" + ) self.FUSED_OP = FUSED_OPS[self.quant_key] def empty_quant(self, *args, **kwargs): - kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} + kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} return torch.empty(*args, **kwargs) @abstractmethod @@ -72,37 +81,40 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): """ def __init__(self, symmetric: bool = True): - quant_key = QuantKey(dtype=FP8_DTYPE, - scale=kStaticTensorScale, - symmetric=symmetric) + quant_key = QuantKey( + dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric + ) super().__init__(quant_key) def register(self, pm_pass: PatternMatcherPass): - - def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at1 = auto_functionalized(SILU_MUL_OP, - result=result_silu_mul, - input=input) - at2 = auto_functionalized(self.QUANT_OP, - result=result, - input=at1[1], - scale=scale) + def pattern( + result: torch.Tensor, + 
result_silu_mul: torch.Tensor, + input: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input) + at2 = auto_functionalized( + self.QUANT_OP, result=result, input=at1[1], scale=scale + ) return at2[1] - def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - scale=scale) + def replacement( + result: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, result=result, input=input, scale=scale + ) return at[1] inputs = [ self.empty_quant(5, 4), # result empty_bf16(5, 4), # result_silu_mul empty_bf16(5, 4), # input - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) @@ -117,28 +129,37 @@ def __init__(self): super().__init__(kNvfp4Quant) def register(self, pm_pass: PatternMatcherPass): - - def pattern(result: torch.Tensor, output_scale: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(SILU_MUL_OP, - result=result_silu_mul, - input=input) - at2 = auto_functionalized(self.QUANT_OP, - output=result, - input=at1[1], - output_scale=output_scale, - input_scale=scale) + def pattern( + result: torch.Tensor, + output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input) + at2 = auto_functionalized( + self.QUANT_OP, + output=result, + input=at1[1], + output_scale=output_scale, + input_scale=scale, + ) return at2[1], at2[2] - def replacement(result: torch.Tensor, output_scale: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - result_block_scale=output_scale, - input=input, - input_global_scale=scale) + def replacement( + result: torch.Tensor, + output_scale: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + result_block_scale=output_scale, + input=input, + input_global_scale=scale, + ) return at[1], at[2] inputs = [ @@ -146,7 +167,7 @@ def replacement(result: torch.Tensor, output_scale: torch.Tensor, empty_i32(128, 4), # output_scale empty_bf16(5, 64), # result_silu_mul empty_bf16(5, 64), # input - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) @@ -167,7 +188,8 @@ def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="activation_quant_fusion_pass") + pass_name="activation_quant_fusion_pass" + ) pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern() pattern_silu_mul_fp8.register(self.patterns) @@ -184,6 +206,9 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) def uuid(self): - return VllmInductorPass.hash_source(self, ActivationQuantPattern, - SiluMulFp8StaticQuantPattern, - SiluMulNvfp4QuantPattern) + return VllmInductorPass.hash_source( + self, + ActivationQuantPattern, + SiluMulFp8StaticQuantPattern, + SiluMulNvfp4QuantPattern, + ) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 335bbda5e4eb..da9debbb0e27 
100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -20,8 +20,12 @@ from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname -from .compiler_interface import (CompilerInterface, EagerAdaptor, - InductorAdaptor, InductorStandaloneAdaptor) +from .compiler_interface import ( + CompilerInterface, + EagerAdaptor, + InductorAdaptor, + InductorStandaloneAdaptor, +) from .counter import compilation_counter from .inductor_pass import InductorPass from .pass_manager import PostGradPassManager @@ -33,9 +37,11 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: if compilation_config.use_inductor: # Use standalone compile only if requested, version is new enough, # and the symbol actually exists in this PyTorch build. - if (envs.VLLM_USE_STANDALONE_COMPILE - and is_torch_equal_or_newer("2.8.0.dev") - and hasattr(torch._inductor, "standalone_compile")): + if ( + envs.VLLM_USE_STANDALONE_COMPILE + and is_torch_equal_or_newer("2.8.0.dev") + and hasattr(torch._inductor, "standalone_compile") + ): logger.debug("Using InductorStandaloneAdaptor") return InductorStandaloneAdaptor() else: @@ -70,10 +76,9 @@ def __init__(self, compilation_config: CompilationConfig): def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) - def initialize_cache(self, - cache_dir: str, - disable_cache: bool = False, - prefix: str = ""): + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): """ Initialize the cache directory for the compiler. @@ -101,9 +106,9 @@ def initialize_cache(self, # do not use eval(), it is unsafe. self.cache = ast.literal_eval(f.read()) - self.compiler.initialize_cache(cache_dir=cache_dir, - disable_cache=disable_cache, - prefix=prefix) + self.compiler.initialize_cache( + cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix + ) def save_to_file(self): if self.disable_cache or not self.is_cache_updated: @@ -113,35 +118,46 @@ def save_to_file(self): with open(self.cache_file_path, "w") as f: f.write(data) - def load(self, - graph: fx.GraphModule, - example_inputs: list[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Optional[Callable]: + def load( + self, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Optional[Callable]: if (runtime_shape, graph_index, self.compiler.name) not in self.cache: return None handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] - compiled_graph = self.compiler.load(handle, graph, example_inputs, - graph_index, runtime_shape) + compiled_graph = self.compiler.load( + handle, graph, example_inputs, graph_index, runtime_shape + ) if runtime_shape is None: logger.debug( - "Directly load the %s-th graph for dynamic shape from %s via " - "handle %s", graph_index, self.compiler.name, handle) + "Directly load the %s-th graph for dynamic shape from %s via handle %s", + graph_index, + self.compiler.name, + handle, + ) else: logger.debug( - "Directly load the %s-th graph for shape %s from %s via " - "handle %s", graph_index, str(runtime_shape), - self.compiler.name, handle) + "Directly load the %s-th graph for shape %s from %s via handle %s", + graph_index, + str(runtime_shape), + self.compiler.name, + handle, + ) return compiled_graph - def compile(self, - graph: fx.GraphModule, - example_inputs, - additional_inductor_config, - compilation_config: CompilationConfig, - 
graph_index: int = 0, - num_graphs: int = 1, - runtime_shape: Optional[int] = None) -> Any: + def compile( + self, + graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None, + ) -> Any: if graph_index == 0: # before compiling the first graph, record the start time global compilation_start_time @@ -152,8 +168,7 @@ def compile(self, compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_index, - runtime_shape) + compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. @@ -163,12 +178,16 @@ def compile(self, if runtime_shape is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " - "from the cache, took %.3f s", elapsed) + "from the cache, took %.3f s", + elapsed, + ) else: logger.info( "Directly load the compiled graph(s) for shape %s " - "from the cache, took %.3f s", str(runtime_shape), - elapsed) + "from the cache, took %.3f s", + str(runtime_shape), + elapsed, + ) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -177,37 +196,41 @@ def compile(self, # Let compile_fx generate a key for us maybe_key = None else: - maybe_key = \ - f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" + maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" compiled_graph, handle = self.compiler.compile( - graph, example_inputs, additional_inductor_config, runtime_shape, - maybe_key) + graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key + ) assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: - self.cache[(runtime_shape, graph_index, - self.compiler.name)] = handle + self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph if runtime_shape is None: - logger.info( - "Cache the graph for dynamic shape for later use") + logger.info("Cache the graph for dynamic shape for later use") else: - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) + logger.info( + "Cache the graph of shape %s for later use", str(runtime_shape) + ) if runtime_shape is None: logger.debug( - "Store the %s-th graph for dynamic shape from %s via " - "handle %s", graph_index, self.compiler.name, handle) + "Store the %s-th graph for dynamic shape from %s via handle %s", + graph_index, + self.compiler.name, + handle, + ) else: logger.debug( "Store the %s-th graph for shape %s from %s via handle %s", - graph_index, str(runtime_shape), self.compiler.name, - handle) + graph_index, + str(runtime_shape), + self.compiler.name, + handle, + ) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: @@ -215,11 +238,13 @@ def compile(self, elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed if runtime_shape is None: - logger.info("Compiling a graph for dynamic shape takes %.2f s", - elapsed) + logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed) else: - logger.info("Compiling a graph for shape %s takes %.2f s", - runtime_shape, elapsed) + 
logger.info( + "Compiling a graph for shape %s takes %.2f s", + runtime_shape, + elapsed, + ) return compiled_graph @@ -232,8 +257,9 @@ class SplitItem: graph: fx.GraphModule -def split_graph(graph: fx.GraphModule, - ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]: +def split_graph( + graph: fx.GraphModule, ops: list[str] +) -> tuple[fx.GraphModule, list[SplitItem]]: # split graph by ops subgraph_id = 0 node_to_subgraph_id = {} @@ -241,7 +267,7 @@ def split_graph(graph: fx.GraphModule, for node in graph.graph.nodes: if node.op in ("output", "placeholder"): continue - if node.op == 'call_function' and str(node.target) in ops: + if node.op == "call_function" and str(node.target) in ops: subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id split_op_graphs.append(subgraph_id) @@ -254,10 +280,8 @@ def split_graph(graph: fx.GraphModule, # the semantics of the graph will change when we # have mutations in the graph split_gm = torch.fx.passes.split_module.split_module( - graph, - None, - lambda node: node_to_subgraph_id[node], - keep_original_order=True) + graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True + ) outputs = [] @@ -271,8 +295,7 @@ def split_graph(graph: fx.GraphModule, module = getattr(split_gm, name) graph_id = int(name.replace("submod_", "")) - outputs.append( - SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) + outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) # sort by integer graph_id, rather than string name outputs.sort(key=lambda x: x.graph_id) @@ -295,11 +318,16 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): has some special cudagraph output handling. """ - def __init__(self, module: torch.fx.GraphModule, - compile_submod_names: list[str], vllm_config: VllmConfig, - vllm_backend: "VllmBackend"): + def __init__( + self, + module: torch.fx.GraphModule, + compile_submod_names: list[str], + vllm_config: VllmConfig, + vllm_backend: "VllmBackend", + ): super().__init__(module) from torch._guards import detect_fake_mode + self.fake_mode = detect_fake_mode() self.compile_submod_names = compile_submod_names self.compilation_config = vllm_config.compilation_config @@ -316,9 +344,12 @@ def run(self, *args): with self.fake_mode, enable_python_dispatcher(): return super().run(*fake_args) - def call_module(self, target: torch.fx.node.Target, - args: tuple[torch.fx.node.Argument, - ...], kwargs: dict[str, Any]) -> Any: + def call_module( + self, + target: torch.fx.node.Target, + args: tuple[torch.fx.node.Argument, ...], + kwargs: dict[str, Any], + ) -> Any: assert isinstance(target, str) output = super().call_module(target, args, kwargs) @@ -330,26 +361,34 @@ def call_module(self, target: torch.fx.node.Target, ] global compilation_start_time - compiled_graph_for_dynamic_shape = self.vllm_backend.\ - compiler_manager.compile( - submod, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=index, - num_graphs=len(self.compile_submod_names), - runtime_shape=None) + compiled_graph_for_dynamic_shape = ( + self.vllm_backend.compiler_manager.compile( + submod, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=index, + num_graphs=len(self.compile_submod_names), + runtime_shape=None, + ) + ) # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend piecewise_backend = PiecewiseBackend( - submod, self.vllm_config, index, - len(self.compile_submod_names), sym_shape_indices, 
- compiled_graph_for_dynamic_shape, self.vllm_backend) + submod, + self.vllm_config, + index, + len(self.compile_submod_names), + sym_shape_indices, + compiled_graph_for_dynamic_shape, + self.vllm_backend, + ) - if (self.compilation_config.cudagraph_mode.\ - has_piecewise_cudagraphs() and - not self.compilation_config.use_inductor_graph_partition): + if ( + self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() + and not self.compilation_config.use_inductor_graph_partition + ): # We're using Dynamo-based piecewise splitting, so we wrap # the whole subgraph with a static graph wrapper. from .cuda_graph import CUDAGraphOptions @@ -357,7 +396,8 @@ def call_module(self, target: torch.fx.node.Target, # resolve the static graph wrapper class (e.g. CUDAGraphWrapper # class) as platform dependent. static_graph_wrapper_class = resolve_obj_by_qualname( - current_platform.get_static_graph_wrapper_cls()) + current_platform.get_static_graph_wrapper_cls() + ) # Always assign PIECEWISE runtime mode to the # CUDAGraphWrapper for piecewise_backend, to distinguish @@ -370,7 +410,9 @@ def call_module(self, target: torch.fx.node.Target, cudagraph_options=CUDAGraphOptions( debug_log_enable=piecewise_backend.is_first_graph, gc_disable=not piecewise_backend.is_first_graph, - weak_ref_output=piecewise_backend.is_last_graph)) + weak_ref_output=piecewise_backend.is_last_graph, + ), + ) else: self.module.__dict__[target] = piecewise_backend @@ -388,8 +430,9 @@ def call_module(self, target: torch.fx.node.Target, def set_model_tag(tag: str): """Context manager to set the model tag.""" global model_tag - assert tag != model_tag, \ + assert tag != model_tag, ( f"Model tag {tag} is the same as the current tag {model_tag}." + ) old_tag = model_tag model_tag = tag try: @@ -430,7 +473,6 @@ def __init__( vllm_config: VllmConfig, prefix: str = "", ): - # if the model is initialized with a non-empty prefix, # then usually it's enough to use that prefix, # e.g. language_model, vision_model, etc. @@ -449,7 +491,8 @@ def __init__( self.compilation_config = vllm_config.compilation_config self.compiler_manager: CompilerManager = CompilerManager( - self.compilation_config) + self.compilation_config + ) # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -465,8 +508,10 @@ def configure_post_pass(self): if PASS_KEY in inductor_config: if isinstance(inductor_config[PASS_KEY], PostGradPassManager): # PassManager already added to config, make sure it's correct - assert (inductor_config[PASS_KEY].uuid() == - self.post_grad_pass_manager.uuid()) + assert ( + inductor_config[PASS_KEY].uuid() + == self.post_grad_pass_manager.uuid() + ) else: # Config should automatically wrap all inductor passes assert isinstance(inductor_config[PASS_KEY], InductorPass) @@ -474,7 +519,6 @@ def configure_post_pass(self): inductor_config[PASS_KEY] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: - vllm_config = self.vllm_config if not self.compilation_config.cache_dir: # no provided cache dir, generate one based on the known factors @@ -495,12 +539,12 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # 2. 
factors come from the code files that are traced by Dynamo ( # it mainly summarizes how the model is used in forward pass) - forward_code_files = list( - sorted(self.compilation_config.traced_files)) + forward_code_files = list(sorted(self.compilation_config.traced_files)) self.compilation_config.traced_files.clear() logger.debug( "Traced files (to be considered for compilation cache):\n%s", - "\n".join(forward_code_files)) + "\n".join(forward_code_files), + ) hash_content = [] for filepath in forward_code_files: hash_content.append(filepath) @@ -511,8 +555,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: with open(filepath) as f: hash_content.append(f.read()) import hashlib - code_hash = hashlib.md5("\n".join(hash_content).encode(), - usedforsecurity=False).hexdigest() + + code_hash = hashlib.md5( + "\n".join(hash_content).encode(), usedforsecurity=False + ).hexdigest() factors.append(code_hash) # 3. compiler hash @@ -520,8 +566,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: factors.append(compiler_hash) # combine all factors to generate the cache dir - hash_key = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] + hash_key = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] cache_dir = os.path.join( envs.VLLM_CACHE_ROOT, @@ -535,8 +582,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_config.cache_dir = cache_dir rank = vllm_config.parallel_config.rank dp_rank = vllm_config.parallel_config.data_parallel_rank - local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", - self.prefix) + local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix) os.makedirs(local_cache_dir, exist_ok=True) self.compilation_config.local_cache_dir = local_cache_dir @@ -545,16 +591,19 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if disable_cache: logger.info("vLLM's torch.compile cache is disabled.") else: - logger.info("Using cache directory: %s for vLLM's torch.compile", - local_cache_dir) + logger.info( + "Using cache directory: %s for vLLM's torch.compile", local_cache_dir + ) - self.compiler_manager.initialize_cache(local_cache_dir, disable_cache, - self.prefix) + self.compiler_manager.initialize_cache( + local_cache_dir, disable_cache, self.prefix + ) # when dynamo calls the backend, it means the bytecode # transform and analysis are done compilation_counter.num_graphs_seen += 1 from .monitor import torch_compile_start_time + dynamo_time = time.time() - torch_compile_start_time logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) self.compilation_config.compilation_time += dynamo_time @@ -567,7 +616,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.configure_post_pass() self.split_gm, self.piecewise_graphs = split_graph( - graph, self.compilation_config.splitting_ops) + graph, self.compilation_config.splitting_ops + ) from torch._dynamo.utils import lazy_format_graph_code @@ -576,25 +626,27 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: lazy_format_graph_code("before split", self.graph) lazy_format_graph_code("after split", self.split_gm) - compilation_counter.num_piecewise_graphs_seen += len( - self.piecewise_graphs) + compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs) submod_names_to_compile = [ - item.submod_name for item in self.piecewise_graphs + item.submod_name + for item in 
self.piecewise_graphs if not item.is_splitting_graph ] # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes - PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, - self.vllm_config, - self).run(*example_inputs) + PiecewiseCompileInterpreter( + self.split_gm, submod_names_to_compile, self.vllm_config, self + ).run(*example_inputs) graph_path = os.path.join(local_cache_dir, "computation_graph.py") if not os.path.exists(graph_path): # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa # use `print_readable` because it can include submodules - src = "from __future__ import annotations\nimport torch\n" + \ - self.split_gm.print_readable(print_output=False) + src = ( + "from __future__ import annotations\nimport torch\n" + + self.split_gm.print_readable(print_output=False) + ) src = src.replace("<lambda>", "GraphModule") with open(graph_path, "w") as f: f.write(src) @@ -603,12 +655,15 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self._called = True - if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \ - not self.compilation_config.cudagraph_copy_inputs: + if ( + self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + or not self.compilation_config.cudagraph_copy_inputs + ): return self.split_gm # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode + fake_mode = detect_fake_mode() fake_args = [ fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t @@ -619,10 +674,12 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # for weights and static buffers, they will have concrete shapes. # symbolic shape only happens for input tensors. 
from torch.fx.experimental.symbolic_shapes import is_symbolic + self.sym_tensor_indices = [ - i for i, x in enumerate(fake_args) - if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \ - any(is_symbolic(d) for d in x.size()) + i + for i, x in enumerate(fake_args) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) + and any(is_symbolic(d) for d in x.size()) ] # compiler managed cudagraph input buffers diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 04b76a9c2d22..01fd9f9a1c8e 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -14,7 +14,9 @@ from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -27,8 +29,12 @@ if find_spec("flashinfer"): try: import flashinfer.comm as flashinfer_comm - flashinfer_comm = (flashinfer_comm if hasattr( - flashinfer_comm, "trtllm_allreduce_fusion") else None) + + flashinfer_comm = ( + flashinfer_comm + if hasattr(flashinfer_comm, "trtllm_allreduce_fusion") + else None + ) except ImportError: flashinfer_comm = None else: @@ -44,7 +50,6 @@ class BasePattern: - def __init__(self, dtype: torch.dtype, device: str): self.dtype = dtype self.device = device @@ -53,14 +58,12 @@ def __init__(self, dtype: torch.dtype, device: str): class GEMMReduceScatterPattern(BasePattern): - def get_inputs(self): mul = torch.empty([16, 4], device=self.device, dtype=self.dtype) mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [mul, mm_weight] def register(self, pm_pass: PatternMatcherPass): - def pattern(mul: torch.Tensor, mm_weight: torch.Tensor): mm = torch.ops.aten.mm.default(mul, mm_weight) reduce_scatter = torch.ops.vllm.reduce_scatter.default( @@ -82,12 +85,12 @@ def replacement(mul: torch.Tensor, mm_weight: torch.Tensor): return gemm_rs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AllGatherGEMMPattern(BasePattern): - def get_inputs(self): x = torch.empty([4, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) @@ -95,7 +98,6 @@ def get_inputs(self): return [x, weight] def register(self, pm_pass: PatternMatcherPass): - def pattern( x: torch.Tensor, weight: torch.Tensor, @@ -110,8 +112,8 @@ def pattern( return torch.ops.aten.mm.default(all_gather, weight) def replacement( - x: torch.Tensor, - weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( x, [weight], @@ -120,42 +122,53 @@ def replacement( ) return mm_outputs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class ScaledMMReduceScatterPattern(BasePattern): - def get_inputs(self): input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - mm_weight = torch.empty([16, 16], device=self.device, - 
dtype=FP8_DTYPE).contiguous().transpose(0, 1) + mm_weight = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32) scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) return [input, mm_weight, scale_a, scale_b] def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, mat2: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor) -> torch.Tensor: - scaled_mm = torch.ops.aten._scaled_mm.default(input, - mat2=mat2, - scale_a=scale_a, - scale_b=scale_b, - bias=None, - scale_result=None, - out_dtype=self.dtype) + def pattern( + input: torch.Tensor, + mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + ) -> torch.Tensor: + scaled_mm = torch.ops.aten._scaled_mm.default( + input, + mat2=mat2, + scale_a=scale_a, + scale_b=scale_b, + bias=None, + scale_result=None, + out_dtype=self.dtype, + ) reduce_scatter = torch.ops.vllm.reduce_scatter.default( scaled_mm, dim=0, world_size=self.tp_size, - group_name=self.tp.unique_name) + group_name=self.tp.unique_name, + ) return reduce_scatter - def replacement(input: torch.Tensor, mat2: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor) -> torch.Tensor: + def replacement( + input: torch.Tensor, + mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + ) -> torch.Tensor: gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( input, mat2, @@ -169,16 +182,19 @@ def replacement(input: torch.Tensor, mat2: torch.Tensor, return gemm_rs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AllGatherScaledMMPattern(BasePattern): - def get_inputs(self): x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) - weight = torch.empty([16, 16], device=self.device, - dtype=FP8_DTYPE).contiguous().transpose(0, 1) + weight = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) s1 = x.shape[0] * self.tp_size @@ -188,7 +204,6 @@ def get_inputs(self): return [x, weight, scale_a, scale_b] def register(self, pm_pass: PatternMatcherPass): - def pattern( x: torch.Tensor, weight: torch.Tensor, @@ -196,22 +211,25 @@ def pattern( scale_b: torch.Tensor, ) -> torch.Tensor: all_gather = torch.ops.vllm.all_gather.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name) - - return torch.ops.aten._scaled_mm.default(all_gather, - mat2=weight, - scale_a=scale_a, - scale_b=scale_b, - bias=None, - scale_result=None, - out_dtype=self.dtype) - - def replacement(x: torch.Tensor, weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor) -> torch.Tensor: + x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name + ) + + return torch.ops.aten._scaled_mm.default( + all_gather, + mat2=weight, + scale_a=scale_a, + scale_b=scale_b, + bias=None, + scale_result=None, + out_dtype=self.dtype, + ) + + def replacement( + x: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + ) -> torch.Tensor: ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa x, [weight], @@ -226,29 +244,33 @@ def replacement(x: torch.Tensor, weight: torch.Tensor, ) return mm_outputs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + 
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class CutlassScaledMMReduceScatterPattern(BasePattern): - def get_inputs(self): input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) - mm_weight = torch.empty([16, 16], device=self.device, - dtype=FP8_DTYPE).contiguous().transpose(0, 1) + mm_weight = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32) scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32) - cutlass_mm_output = torch.empty([16, 16], - device=self.device, - dtype=self.dtype) + cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype) return [input, mm_weight, scale_a, scale_b, cutlass_mm_output] def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, weight: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - cutlass_mm_output: torch.Tensor) -> torch.Tensor: + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + cutlass_mm_output: torch.Tensor, + ) -> torch.Tensor: cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( torch.ops._C.cutlass_scaled_mm.default, out=cutlass_mm_output, @@ -256,18 +278,24 @@ def pattern(input: torch.Tensor, weight: torch.Tensor, b=weight, a_scales=scale_a, b_scales=scale_b, - bias=None) + bias=None, + ) reduce_scatter = torch.ops.vllm.reduce_scatter.default( cutlass_scaled_mm[1], dim=0, world_size=self.tp_size, - group_name=self.tp.unique_name) + group_name=self.tp.unique_name, + ) return reduce_scatter - def replacement(input: torch.Tensor, mat2: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - cutlass_mm_output: torch.Tensor) -> torch.Tensor: + def replacement( + input: torch.Tensor, + mat2: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + cutlass_mm_output: torch.Tensor, + ) -> torch.Tensor: gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( input, mat2, @@ -281,16 +309,19 @@ def replacement(input: torch.Tensor, mat2: torch.Tensor, return gemm_rs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AllGatherCutlassScaledMMPattern(BasePattern): - def get_inputs(self): x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) - weight = torch.empty([16, 16], device=self.device, - dtype=FP8_DTYPE).contiguous().transpose(0, 1) + weight = ( + torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + .contiguous() + .transpose(0, 1) + ) s1 = x.shape[0] * self.tp_size @@ -303,7 +334,6 @@ def get_inputs(self): return [x, weight, scale_a, scale_b, output] def register(self, pm_pass: PatternMatcherPass): - def pattern( x: torch.Tensor, weight: torch.Tensor, @@ -312,10 +342,8 @@ def pattern( output: torch.Tensor, ) -> torch.Tensor: all_gather = torch.ops.vllm.all_gather.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp.unique_name) + x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name + ) cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized( torch.ops._C.cutlass_scaled_mm.default, @@ -324,12 +352,17 @@ def pattern( b=weight, a_scales=scale_a, b_scales=scale_b, - bias=None) + bias=None, + ) return cutlass_scaled_mm[1] - def replacement(x: torch.Tensor, weight: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - 
output: torch.Tensor) -> torch.Tensor: + def replacement( + x: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + output: torch.Tensor, + ) -> torch.Tensor: ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa x, [weight], @@ -344,12 +377,12 @@ def replacement(x: torch.Tensor, weight: torch.Tensor, ) return mm_outputs - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AsyncTPPass(VllmPatternMatcherPass): - @enable_fake_mode def __init__(self, config: VllmConfig): super().__init__(config) @@ -357,26 +390,29 @@ def __init__(self, config: VllmConfig): # Enable symmetric memory for the TP process group enable_symm_mem_for_group(get_tp_group().device_group.group_name) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="async_tp_pass") - GEMMReduceScatterPattern(self.model_dtype, - self.device).register(self.patterns) + pass_name="async_tp_pass" + ) + GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns) - AllGatherGEMMPattern(self.model_dtype, - self.device).register(self.patterns) + AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns) # These fusions are enabled only for bfloat16 models because # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling # only supports bfloat16 as the output dtype. if self.model_dtype == torch.bfloat16: - ScaledMMReduceScatterPattern(self.model_dtype, - self.device).register(self.patterns) - AllGatherScaledMMPattern(self.model_dtype, - self.device).register(self.patterns) + ScaledMMReduceScatterPattern(self.model_dtype, self.device).register( + self.patterns + ) + AllGatherScaledMMPattern(self.model_dtype, self.device).register( + self.patterns + ) - CutlassScaledMMReduceScatterPattern( - self.model_dtype, self.device).register(self.patterns) - AllGatherCutlassScaledMMPattern( - self.model_dtype, self.device).register(self.patterns) + CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register( + self.patterns + ) + AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register( + self.patterns + ) self.dump_patterns(config, self.patterns) @@ -405,15 +441,16 @@ def __call__(self, graph: fx.Graph): } try: - _FI_MAX_SIZES.update({ - int(k): int(float(v) * MiB) - for k, v in - envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() - }) + _FI_MAX_SIZES.update( + { + int(k): int(float(v) * MiB) + for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items() + } + ) except Exception as e: raise ValueError( - "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " - + str(e)) from e + "Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e) + ) from e # opt for a more conservative default value # when world size is not in _FI_MAX_SIZES @@ -446,8 +483,9 @@ def call_trtllm_fused_allreduce_norm( max_fusion_size, ) if use_flashinfer: - assert (_FI_WORKSPACE_TENSOR is not None - ), "Flashinfer must be enabled when using flashinfer" + assert _FI_WORKSPACE_TENSOR is not None, ( + "Flashinfer must be enabled when using flashinfer" + ) if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -479,38 +517,43 @@ def call_trtllm_fused_allreduce_norm( quant_out=quant_out, scale_out=scale_out, # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout. 
- SWIZZLED_128x4, + layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, scale_factor=scale_factor, ) else: allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) - if (scale_factor is not None and scale_out is None - and fuse_rms_quant): + if scale_factor is not None and scale_out is None and fuse_rms_quant: # Do fused rms norm static fp8 quant fused op if norm_out is None: torch.ops._C.fused_add_rms_norm_static_fp8_quant( - quant_out, allreduce_out, residual, rms_gamma, - scale_factor, rms_eps) + quant_out, + allreduce_out, + residual, + rms_gamma, + scale_factor, + rms_eps, + ) else: torch.ops._C.rms_norm_static_fp8_quant( - quant_out, allreduce_out, rms_gamma, scale_factor, - rms_eps) + quant_out, allreduce_out, rms_gamma, scale_factor, rms_eps + ) else: if norm_out is None: - torch.ops._C.fused_add_rms_norm(allreduce_out, residual, - rms_gamma, rms_eps) + torch.ops._C.fused_add_rms_norm( + allreduce_out, residual, rms_gamma, rms_eps + ) norm_out = allreduce_out else: - torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, - rms_eps) + torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps) if scale_factor is not None: if scale_out is not None: - torch.ops._C.scaled_fp4_quant(quant_out, norm_out, - scale_out, scale_factor) + torch.ops._C.scaled_fp4_quant( + quant_out, norm_out, scale_out, scale_factor + ) else: torch.ops._C.static_scaled_fp8_quant( - quant_out, norm_out, scale_factor) + quant_out, norm_out, scale_factor + ) if scale_factor is None or norm_out is not None: # we need to return allreduce output # in cases of non quant fused AR + RMS norm @@ -518,22 +561,23 @@ def call_trtllm_fused_allreduce_norm( allreduce_in.copy_(allreduce_out) def call_trtllm_fused_allreduce_norm_fake( - allreduce_in: torch.Tensor, - residual: torch.Tensor, - rms_gamma: torch.Tensor, - rms_eps: float, - world_rank: int, - world_size: int, - launch_with_pdl: bool, - trigger_completion_at_end: bool, - fp32_acc: bool, - max_token_num: int, - pattern_code: int, - fuse_rms_quant: bool, - norm_out: Optional[torch.Tensor] = None, - quant_out: Optional[torch.Tensor] = None, - scale_out: Optional[torch.Tensor] = None, - scale_factor: Optional[torch.Tensor] = None) -> None: + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + pattern_code: int, + fuse_rms_quant: bool, + norm_out: Optional[torch.Tensor] = None, + quant_out: Optional[torch.Tensor] = None, + scale_out: Optional[torch.Tensor] = None, + scale_factor: Optional[torch.Tensor] = None, + ) -> None: pass direct_register_custom_op( @@ -549,7 +593,8 @@ def call_trtllm_fused_allreduce_norm_fake( fake_impl=call_trtllm_fused_allreduce_norm_fake, ) flashinfer_trtllm_fused_allreduce_norm = ( - torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default) + torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default + ) class FlashInferFusedAllReduceParams: @@ -587,7 +632,7 @@ def get_trtllm_fused_allreduce_kwargs(self): class AllReduceRMSNormPattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (without residual) + This pattern replaces the allreduce + rms norm (without residual) with fused flashinfer implementation. Applies to allreduce + rmsnorm before attn in the first Transformer block. 
""" @@ -605,17 +650,15 @@ def __init__( def get_inputs(self): input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) - rms_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.dtype) + rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4], device=self.device, dtype=self.dtype) return [input, rms_result, weight] def register(self, pm_pass: PatternMatcherPass): - - def pattern(input: torch.Tensor, rms_result: torch.Tensor, - weight: torch.Tensor): + def pattern( + input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor + ): allreduce_output = tensor_model_parallel_all_reduce(input) rms = auto_functionalized( RMS_OP, @@ -627,8 +670,9 @@ def pattern(input: torch.Tensor, rms_result: torch.Tensor, # rms_result, allreduce_output return rms[1], allreduce_output - def replacement(input: torch.Tensor, rms_result: torch.Tensor, - weight: torch.Tensor): + def replacement( + input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor + ): residual = torch.zeros_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, @@ -639,20 +683,20 @@ def replacement(input: torch.Tensor, rms_result: torch.Tensor, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. - kARResidualRMSNorm, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) # rms_result, allreduce_in return allreduce[3], allreduce[1] - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedAddRMSNormPattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (with residual) + This pattern replaces the allreduce + rms norm (with residual) with fused flashinfer implementation. Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn. """ @@ -679,9 +723,7 @@ def get_inputs(self): ] def register(self, pm_pass: PatternMatcherPass): - - def pattern(residual: torch.Tensor, input: torch.Tensor, - weight: torch.Tensor): + def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor): allreduce_output = tensor_model_parallel_all_reduce(input) rms = auto_functionalized( RMS_ADD_OP, @@ -693,8 +735,9 @@ def pattern(residual: torch.Tensor, input: torch.Tensor, # input, residual return rms[1], rms[2] - def replacement(residual: torch.Tensor, input: torch.Tensor, - weight: torch.Tensor): + def replacement( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ): allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -704,44 +747,46 @@ def replacement(residual: torch.Tensor, input: torch.Tensor, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. 
- kARResidualRMSNorm, + pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) # allreduce_in, residual return allreduce[1], allreduce[2] - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (without residual) + This pattern replaces the allreduce + rms norm (without residual) + static fp8 quant with fused flashinfer implementation. - Applies to allreduce + rmsnorm + quant before attn + Applies to allreduce + rmsnorm + quant before attn in the first Transformer block. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): - input = torch.zeros([1, 8, 4], - device=self.device, - dtype=self.dtype) - rmsnorm_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.quant_dtype) + input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) + rmsnorm_result = torch.empty( + [1, 8, 4], device=self.device, dtype=self.dtype + ) + quant_result = torch.empty( + [1, 8, 4], device=self.device, dtype=self.quant_dtype + ) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) return [input, rmsnorm_result, quant_result, weight, scale] @@ -754,23 +799,31 @@ def pattern( scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized(RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon) + rmsnorm_out_tuple = auto_functionalized( + RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon, + ) - quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP, - result=quant_result, - input=rmsnorm_out_tuple[1], - scale=scale) + quant_out_tuple = auto_functionalized( + STATIC_FP8_QUANT_OP, + result=quant_result, + input=rmsnorm_out_tuple[1], + scale=scale, + ) # quant_out, allreduce_output return quant_out_tuple[1], all_reduce - def replacement(input: torch.Tensor, result_rms: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): + def replacement( + input: torch.Tensor, + result_rms: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): residual = torch.zeros_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, @@ -781,8 +834,10 @@ def replacement(input: torch.Tensor, result_rms: torch.Tensor, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. 
- kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + # We don't use norm_out afterwards + pattern_code=( + flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant + ), scale_factor=scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) @@ -790,40 +845,41 @@ def replacement(input: torch.Tensor, result_rms: torch.Tensor, # quant_out, allreduce_output return allreduce[4], allreduce[1] - pm.register_replacement(pattern, replacement, get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): """ This pattern replaces the allreduce + rms norm (with residual) + static fp8 quant with fused flashinfer implementation. - Applies to o_proj + rmsnorm after attn + quant and + Applies to o_proj + rmsnorm after attn + quant and mlp + rmsnorm + quant before attn. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params self.quant_dtype = torch.float8_e4m3fn def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): input = torch.empty([4, 4], device=self.device, dtype=self.dtype) - residual = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) - quant_result = torch.empty([4, 4], - device=self.device, - dtype=self.quant_dtype) - scale = torch.empty([1, 1], - device=self.device, - dtype=torch.float32) + quant_result = torch.empty( + [4, 4], device=self.device, dtype=self.quant_dtype + ) + scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) return [ quant_result, @@ -842,25 +898,30 @@ def pattern( ): allreduce_output = tensor_model_parallel_all_reduce(input) - fused_add_rmsnorm_out_tuple = \ - auto_functionalized( + fused_add_rmsnorm_out_tuple = auto_functionalized( RMS_ADD_OP, input=allreduce_output, residual=residual, weight=weight, - epsilon=self.epsilon) + epsilon=self.epsilon, + ) quant_out_tuple = auto_functionalized( STATIC_FP8_QUANT_OP, result=quant_result, input=fused_add_rmsnorm_out_tuple[1], - scale=scale) + scale=scale, + ) # quant_out, allreduce_output return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2] - def replacement(quant_result: torch.Tensor, residual: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): + def replacement( + quant_result: torch.Tensor, + residual: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -870,56 +931,61 @@ def replacement(quant_result: torch.Tensor, residual: torch.Tensor, scale_out=None, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. 
- kARResidualRMSNormFP8Quant, # we don't use norm_out afterwards + # We don't use norm_out afterwards + pattern_code=( + flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant + ), scale_factor=scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) - # # quant_out, rms_norm_residual + # quant_out, rms_norm_residual return allreduce[4], allreduce[2] - pm.register_replacement(pattern, replacement, get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): """ - This pattern replaces the allreduce + rms norm (without residual) + This pattern replaces the allreduce + rms norm (without residual) + static nvfp4 quant with fused flashinfer implementation. - Applies to allreduce + rmsnorm + quant before attn + Applies to allreduce + rmsnorm + quant before attn in the first Transformer block. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): - input = torch.empty([1, 16, 16], - device=self.device, - dtype=self.dtype) - - rmsnorm_result = torch.empty([1, 16, 16], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty((16, 8), - device=self.device, - dtype=torch.uint8) - input_global_scale = torch.empty([1, 1], - device=self.device, - dtype=torch.float32) + input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype) + + rmsnorm_result = torch.empty( + [1, 16, 16], device=self.device, dtype=self.dtype + ) + quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) + input_global_scale = torch.empty( + [1, 1], device=self.device, dtype=torch.float32 + ) weight = torch.empty([16], device=self.device, dtype=self.dtype) - output_scale = torch.empty([128, 4], - device=self.device, - dtype=torch.int32) + output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) return [ - input, rmsnorm_result, quant_result, weight, - input_global_scale, output_scale + input, + rmsnorm_result, + quant_result, + weight, + input_global_scale, + output_scale, ] def pattern( @@ -931,26 +997,33 @@ def pattern( output_scale: torch.Tensor, ): all_reduce = tensor_model_parallel_all_reduce(input) - rmsnorm_out_tuple = auto_functionalized(RMS_OP, - result=rmsnorm_result, - input=all_reduce, - weight=weight, - epsilon=self.epsilon) + rmsnorm_out_tuple = auto_functionalized( + RMS_OP, + result=rmsnorm_result, + input=all_reduce, + weight=weight, + epsilon=self.epsilon, + ) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, input=rmsnorm_out_tuple[1], output_scale=output_scale, - input_scale=input_global_scale) + input_scale=input_global_scale, + ) # quant_out, allreduce_output, output_scale return quant_out_tuple[1], all_reduce, quant_out_tuple[2] - def replacement(input: torch.Tensor, result_rms: torch.Tensor, - quant_result: torch.Tensor, weight: torch.Tensor, - input_global_scale: torch.Tensor, - output_scale: torch.Tensor): + def replacement( + input: torch.Tensor, + result_rms: torch.Tensor, + quant_result: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + output_scale: torch.Tensor, + 
): residual = torch.zeros_like(input) allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, @@ -961,8 +1034,10 @@ def replacement(input: torch.Tensor, result_rms: torch.Tensor, scale_out=output_scale, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. - kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards + # We don't use norm_out afterwards + pattern_code=( + flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant + ), scale_factor=input_global_scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) @@ -970,44 +1045,41 @@ def replacement(input: torch.Tensor, result_rms: torch.Tensor, # quant_out, allreduce_output, output_scale return allreduce[4], allreduce[1], allreduce[5] - pm.register_replacement(pattern, replacement, get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): """ This pattern replaces the allreduce + rms norm (with residual) + static nvfp4 quant with fused flashinfer implementation. - Applies to o_proj + rmsnorm after attn + quant and + Applies to o_proj + rmsnorm after attn + quant and mlp + rmsnorm + quant before attn. """ - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - allreduce_params: FlashInferFusedAllReduceParams): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): super().__init__(dtype, device) self.epsilon = epsilon self.allreduce_params = allreduce_params def register(self, pm_pass: PatternMatcherPass): - def get_inputs(): input = torch.empty([16, 16], device=self.device, dtype=self.dtype) - residual = torch.empty([16, 16], - device=self.device, - dtype=self.dtype) - weight = torch.empty([16, 16], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty((16, 8), - device=self.device, - dtype=torch.uint8) - input_global_scale = torch.empty([1, 1], - device=self.device, - dtype=torch.float32) - output_scale = torch.empty([128, 4], - device=self.device, - dtype=torch.int32) + residual = torch.empty([16, 16], device=self.device, dtype=self.dtype) + weight = torch.empty([16, 16], device=self.device, dtype=self.dtype) + quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8) + input_global_scale = torch.empty( + [1, 1], device=self.device, dtype=torch.float32 + ) + output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) return [ quant_result, @@ -1018,33 +1090,46 @@ def get_inputs(): input_global_scale, ] - def pattern(quant_result: torch.Tensor, residual: torch.Tensor, - input: torch.Tensor, output_scale: torch.Tensor, - weight: torch.Tensor, input_global_scale: torch.Tensor): + def pattern( + quant_result: torch.Tensor, + residual: torch.Tensor, + input: torch.Tensor, + output_scale: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + ): allreduce_output = tensor_model_parallel_all_reduce(input) - fused_add_rmsnorm_out_tuple = \ - auto_functionalized( + fused_add_rmsnorm_out_tuple = auto_functionalized( RMS_ADD_OP, input=allreduce_output, residual=residual, weight=weight, - epsilon=self.epsilon) + epsilon=self.epsilon, + ) quant_out_tuple = auto_functionalized( STATIC_FP4_QUANT_OP, output=quant_result, input=fused_add_rmsnorm_out_tuple[1], output_scale=output_scale, - input_scale=input_global_scale) + input_scale=input_global_scale, + ) # quant_out, 
allreduce_output, output_scale - return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[ - 2], quant_out_tuple[2] + return ( + quant_out_tuple[1], + fused_add_rmsnorm_out_tuple[2], + quant_out_tuple[2], + ) - def replacement(quant_result: torch.Tensor, residual: torch.Tensor, - input: torch.Tensor, output_scale: torch.Tensor, - weight: torch.Tensor, - input_global_scale: torch.Tensor): + def replacement( + quant_result: torch.Tensor, + residual: torch.Tensor, + input: torch.Tensor, + output_scale: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + ): allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -1054,20 +1139,22 @@ def replacement(quant_result: torch.Tensor, residual: torch.Tensor, scale_out=output_scale, rms_gamma=weight, rms_eps=self.epsilon, - pattern_code=flashinfer_comm.AllReduceFusionPattern. - kARResidualRMSNormFP4Quant, # we don't use norm_out afterwards + # We don't use norm_out afterwards + pattern_code=( + flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant + ), scale_factor=input_global_scale, **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), ) # quant_out, rms_norm_residual, output_scale return allreduce[4], allreduce[2], allreduce[5] - pm.register_replacement(pattern, replacement, get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, get_inputs(), pm.fwd_only, pm_pass + ) class AllReduceFusionPass(VllmPatternMatcherPass): - def __init__(self, config: VllmConfig): super().__init__(config) self.disabled = True @@ -1075,7 +1162,8 @@ def __init__(self, config: VllmConfig): if self.tp_size <= 1: return self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="all_reduce_fusion_pass") + pass_name="all_reduce_fusion_pass" + ) if config.model_config is None: return self.hidden_dim = config.model_config.get_hidden_size() @@ -1085,21 +1173,21 @@ def __init__(self, config: VllmConfig): if flashinfer_comm is None: logger.warning( "Flashinfer is not installed or comm module not found, " - "skipping allreduce fusion pass") + "skipping allreduce fusion pass" + ) return # Check if the world size is supported if self.tp_size not in _FI_MAX_SIZES: logger.warning( - "Flashinfer allreduce fusion is not " - "supported for world size %s", + "Flashinfer allreduce fusion is not supported for world size %s", self.tp_size, ) return max_num_token = min( - _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) // - (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), - config.compilation_config.pass_config. 
- fi_allreduce_fusion_max_token_num) + _FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) + // (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)), + config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num, + ) self.ipc_handles, workspace_tensor = ( flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, @@ -1108,7 +1196,8 @@ def __init__(self, config: VllmConfig): hidden_dim=self.hidden_dim, group=self.group, use_fp32_lamport=use_fp32_lamport, - )) + ) + ) global _FI_WORKSPACE_TENSOR _FI_WORKSPACE_TENSOR = workspace_tensor @@ -1119,7 +1208,8 @@ def __init__(self, config: VllmConfig): max_token_num=max_num_token, # fuse rms norm static fp8 quant fused op # in fallback path, when we don't use flashinfer - fuse_rms_quant=config.compilation_config.pass_config.enable_fusion) + fuse_rms_quant=config.compilation_config.pass_config.enable_fusion, + ) self.register_patterns() self.dump_patterns(config, self.patterns) @@ -1185,4 +1275,5 @@ def __del__(self): return if flashinfer_comm is not None: flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce( - self.ipc_handles, self.group) + self.ipc_handles, self.group + ) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index eeca14d1296f..3b5fecaf189b 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -24,14 +24,14 @@ class CompilerInterface: """ The interface for a compiler that can be used by vLLM. """ + # The name of the compiler, e.g. inductor. # This is a class-level attribute. name: str - def initialize_cache(self, - cache_dir: str, - disable_cache: bool = False, - prefix: str = ""): + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): """ when the vLLM process uses `cache_dir` as the cache directory, the compiler should initialize itself with the cache directory, @@ -93,12 +93,14 @@ def compile( """ return None, None - def load(self, - handle: Any, - graph: fx.GraphModule, - example_inputs: list[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Callable: """ Load the compiled function from the handle. Raises an error if the handle is invalid. @@ -150,11 +152,13 @@ def get_inductor_factors() -> list[Any]: factors: list[Any] = [] # summarize system state from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() factors.append(system_factors) # summarize pytorch state from torch._inductor.codecache import torch_key + torch_factors = torch_key() factors.append(torch_factors) return factors @@ -169,18 +173,19 @@ class InductorStandaloneAdaptor(CompilerInterface): Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off. 
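# Editor's sketch (not part of the diff): roughly how the toggle mentioned above
# could be consulted to pick between the two Inductor adaptors. The helper name
# and the os.environ lookup are hypothetical illustrations, not vLLM's actual
# selection logic.
import os


def _pick_inductor_adaptor_sketch():
    # Standalone compile only exists on newer PyTorch releases, so the env var
    # acts as an opt-in switch; fall back to the classic adaptor otherwise.
    if os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1":
        return InductorStandaloneAdaptor()
    return InductorAdaptor()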
""" + name = "inductor_standalone" def compute_hash(self, vllm_config: VllmConfig) -> str: factors = get_inductor_factors() - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] + hash_str = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] return hash_str - def initialize_cache(self, - cache_dir: str, - disable_cache: bool = False, - prefix: str = ""): + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): self.cache_dir = cache_dir def compile( @@ -203,12 +208,14 @@ def compile( dynamic_shapes = "from_tracing_context" from torch._inductor import standalone_compile + with pass_context(runtime_shape): compiled_graph = standalone_compile( graph, example_inputs, dynamic_shapes=dynamic_shapes, - options={"config_patches": current_config}) + options={"config_patches": current_config}, + ) # Save the compiled artifact to disk in the specified path assert key is not None @@ -218,19 +225,23 @@ def compile( compilation_counter.num_compiled_artifacts_saved += 1 return compiled_graph, (key, path) - def load(self, - handle: Any, - graph: fx.GraphModule, - example_inputs: list[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) assert isinstance(handle[1], str) path = handle[1] inductor_compiled_graph = torch._inductor.CompiledArtifact.load( - path=path, format="unpacked") + path=path, format="unpacked" + ) from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) def compiled_graph_wrapper(*args): @@ -250,21 +261,22 @@ class InductorAdaptor(CompilerInterface): """ The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7. """ + name = "inductor" def compute_hash(self, vllm_config: VllmConfig) -> str: factors = get_inductor_factors() - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] + hash_str = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] return hash_str - def initialize_cache(self, - cache_dir: str, - disable_cache: bool = False, - prefix: str = ""): + def initialize_cache( + self, cache_dir: str, disable_cache: bool = False, prefix: str = "" + ): self.cache_dir = cache_dir self.prefix = prefix - self.base_cache_dir = cache_dir[:-len(prefix)] if prefix else cache_dir + self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir if disable_cache: return # redirect the cache directory to a sub-directory @@ -288,6 +300,7 @@ def compile( ) -> tuple[Optional[Callable], Optional[Any]]: compilation_counter.num_inductor_compiles += 1 from torch._inductor.compile_fx import compile_fx + current_config = {} if compiler_config is not None: current_config.update(compiler_config) @@ -308,8 +321,8 @@ def compile( # it to get the hash of the compiled graph directly. 
hash_str, file_path = None, None - from torch._inductor.codecache import (FxGraphCache, - compiled_fx_graph_hash) + from torch._inductor.codecache import FxGraphCache, compiled_fx_graph_hash + if torch.__version__.startswith("2.5"): original_load = FxGraphCache.load original_load_name = "torch._inductor.codecache.FxGraphCache.load" @@ -326,7 +339,8 @@ def hijack_load(*args, **kwargs): if not callable(cell.cell_contents): continue if cell.cell_contents.__code__.co_filename.startswith( - self.base_cache_dir): + self.base_cache_dir + ): # this is the real file path compiled from Inductor file_path = cell.cell_contents.__code__.co_filename break @@ -338,8 +352,7 @@ def hijack_load(*args, **kwargs): original_load_name = None def hijacked_compile_fx_inner(*args, **kwargs): - output = torch._inductor.compile_fx.compile_fx_inner( - *args, **kwargs) + output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs) nonlocal hash_str inductor_compiled_graph = output if inductor_compiled_graph is not None: @@ -353,8 +366,7 @@ def hijacked_compile_fx_inner(*args, **kwargs): if not callable(cell.cell_contents): continue code = cell.cell_contents.__code__ - if code.co_filename.startswith( - self.base_cache_dir): + if code.co_filename.startswith(self.base_cache_dir): # this is the real file path # compiled from Inductor file_path = code.co_filename @@ -387,29 +399,38 @@ def _get_shape_env() -> AlwaysHitShapeEnv: # for hijacking the hash of the compiled graph stack.enter_context( - patch("torch._inductor.codecache.compiled_fx_graph_hash", - hijack_compiled_fx_graph_hash)) + patch( + "torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash, + ) + ) # for providing a dummy shape environment stack.enter_context( - patch("torch._inductor.codecache.FxGraphCache._get_shape_env", - _get_shape_env)) + patch( + "torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env, + ) + ) - from torch._functorch._aot_autograd.autograd_cache import ( - AOTAutogradCache) + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache if hasattr(AOTAutogradCache, "_get_shape_env"): stack.enter_context( patch( "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", - _get_shape_env)) + _get_shape_env, + ) + ) # for forcing the graph to be cached stack.enter_context( patch( "torch._inductor.codecache.FxGraphCache._check_can_cache", - _check_can_cache)) + _check_can_cache, + ) + ) # Dynamo metrics context, see method for more details. stack.enter_context(self.metrics_context()) @@ -422,23 +443,26 @@ def _get_shape_env() -> AlwaysHitShapeEnv: # standalone_compile sometime. if is_torch_equal_or_newer("2.6"): stack.enter_context( - torch._inductor.config.patch(fx_graph_remote_cache=False)) + torch._inductor.config.patch(fx_graph_remote_cache=False) + ) # InductorAdaptor (unfortunately) requires AOTAutogradCache # to be turned off to run. It will fail to acquire the hash_str # and error if not. # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem. 
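# Editor's sketch (not part of the diff): the calls that follow rely on
# contextlib.ExitStack to pile up several temporary monkeypatches for the
# duration of a single compile; each patch is unwound on exit, even on error.
# The patch target below is only an example, not the one used here.
from contextlib import ExitStack
from unittest.mock import patch


def _compile_with_patches_sketch(compile_fn, *args):
    with ExitStack() as stack:
        stack.enter_context(patch("gc.collect", lambda: None))  # example patch
        return compile_fn(*args)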
stack.enter_context( - torch._functorch.config.patch(enable_autograd_cache=False)) + torch._functorch.config.patch(enable_autograd_cache=False) + ) stack.enter_context( - torch._functorch.config.patch( - enable_remote_autograd_cache=False)) + torch._functorch.config.patch(enable_remote_autograd_cache=False) + ) with pass_context(runtime_shape): compiled_graph = compile_fx( graph, example_inputs, inner_compile=hijacked_compile_fx_inner, - config_patches=current_config) + config_patches=current_config, + ) # We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch # compilation cache. So turn off the checks if we disable the @@ -451,52 +475,63 @@ def _get_shape_env() -> AlwaysHitShapeEnv: "failed, leading to a corrupted compilation artifact. " "We recommend trying to " "remove ~/.cache/vllm/torch_compile_cache and try again " - "to see the real issue. ") + "to see the real issue. " + ) assert file_path is not None, ( - "failed to get the file path of the compiled graph") + "failed to get the file path of the compiled graph" + ) return compiled_graph, (hash_str, file_path) - def load(self, - handle: Any, - graph: fx.GraphModule, - example_inputs: list[Any], - graph_index: int, - runtime_shape: Optional[int] = None) -> Callable: + def load( + self, + handle: Any, + graph: fx.GraphModule, + example_inputs: list[Any], + graph_index: int, + runtime_shape: Optional[int] = None, + ) -> Callable: assert isinstance(handle, tuple) assert isinstance(handle[0], str) assert isinstance(handle[1], str) hash_str = handle[0] - from torch._functorch._aot_autograd.autograd_cache import ( - AOTAutogradCache) + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache from torch._inductor.codecache import FxGraphCache + with ExitStack() as exit_stack: exit_stack.enter_context( - patch("torch._inductor.codecache.FxGraphCache._get_shape_env", - lambda *args, **kwargs: AlwaysHitShapeEnv())) + patch( + "torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv(), + ) + ) # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache if hasattr(AOTAutogradCache, "_get_shape_env"): exit_stack.enter_context( patch( "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env", - lambda *args, **kwargs: AlwaysHitShapeEnv())) + lambda *args, **kwargs: AlwaysHitShapeEnv(), + ) + ) # Dynamo metrics context, see method for more details. exit_stack.enter_context(self.metrics_context()) if torch.__version__.startswith("2.5"): inductor_compiled_graph = FxGraphCache._lookup_graph( - hash_str, example_inputs, True, False) + hash_str, example_inputs, True, False + ) assert inductor_compiled_graph is not None, ( "Inductor cache lookup failed. Please remove" f"the cache directory and try again." # noqa ) elif torch.__version__ >= "2.6": - from torch._inductor.output_code import ( - CompiledFxGraphConstantsWithGm) + from torch._inductor.output_code import CompiledFxGraphConstantsWithGm + constants = CompiledFxGraphConstantsWithGm(graph) inductor_compiled_graph, _ = FxGraphCache._lookup_graph( - hash_str, example_inputs, True, None, constants) + hash_str, example_inputs, True, None, constants + ) assert inductor_compiled_graph is not None, ( "Inductor cache lookup failed. Please remove" f"the cache directory and try again." 
# noqa @@ -509,6 +544,7 @@ def load(self, # need to know if the graph returns a tuple from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) # this is the callable we return to Dynamo to run @@ -542,6 +578,7 @@ def metrics_context(self) -> contextlib.AbstractContextManager: """ if is_torch_equal_or_newer("2.6"): import torch._dynamo.utils + return torch._dynamo.utils.get_metrics_context() else: return contextlib.nullcontext() @@ -553,7 +590,8 @@ def set_inductor_config(config, runtime_shape): # can be beneficial config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE config["coordinate_descent_tuning"] = ( - envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING) + envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING + ) class EagerAdaptor(CompilerInterface): diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index e01dd3915a3a..9e8de831bcb2 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -41,7 +41,8 @@ def expect(self, **kwargs): assert getattr(self, k) - getattr(old, k) == v, ( f"{k} not as expected, before it is {getattr(old, k)}" f", after it is {getattr(self, k)}, " - f"expected diff is {v}") + f"expected diff is {v}" + ) compilation_counter = CompilationCounter() diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index befb7736d75a..4c3ac9e56a37 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -12,8 +12,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import validate_cudagraph_capturing_enabled from vllm.config import CUDAGraphMode, VllmConfig -from vllm.distributed.device_communicators.pynccl_allocator import ( - set_graph_pool_id) +from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform @@ -46,10 +45,10 @@ class CUDAGraphWrapper: The workflow of this wrapper in the cudagraph dispatching is as follows: 1. At initialization, a runtime mode is assigned to the wrapper (FULL or - PIECEWISE). - 2. At runtime, the wrapper receives a runtime_mode and a + PIECEWISE). + 2. At runtime, the wrapper receives a runtime_mode and a batch_descriptor(key) from the forward context and blindly trust them - for cudagraph dispatching. + for cudagraph dispatching. 3. If runtime_mode is NONE or runtime_mode does not match the mode of the wrapper, just call the runnable directly. 4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper, @@ -58,18 +57,20 @@ class CUDAGraphWrapper: Note: CUDAGraphWrapper does not store persistent buffers or copy any runtime inputs into that buffers for replay. We assume implementing them - is done outside of the wrapper. That is because we do not make any + is done outside of the wrapper. That is because we do not make any assumption on the dynamic shape (batch size) of the runtime inputs, as a - trade-off for staying orthogonal to compilation logic. Nevertheless, + trade-off for staying orthogonal to compilation logic. Nevertheless, tracing and checking the input addresses to be consistent during replay is guaranteed when VLLM_LOGGING_LEVEL == "DEBUG". 
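# Editor's sketch (not part of the diff): the dispatch rule described in this
# docstring, reduced to its core. Attribute names mirror the wrapper defined
# below; the capture/replay details are elided.
def _dispatch_sketch(wrapper, runtime_mode, batch_descriptor, *args, **kwargs):
    # A mismatched (or NONE) runtime mode means profile/warmup/eager: bypass.
    if runtime_mode != wrapper.runtime_mode:
        return wrapper.runnable(*args, **kwargs)
    # Otherwise capture on first sight of this batch descriptor, replay after.
    entry = wrapper.concrete_cudagraph_entries.setdefault(
        batch_descriptor, CUDAGraphEntry(batch_descriptor=batch_descriptor)
    )
    return entry  # capture-or-replay handling continues in the real wrapper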
""" - def __init__(self, - runnable: Callable, - vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, - cudagraph_options: Optional[CUDAGraphOptions] = None): + def __init__( + self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + cudagraph_options: Optional[CUDAGraphOptions] = None, + ): self.runnable = runnable self.vllm_config = vllm_config self.runtime_mode = runtime_mode @@ -91,15 +92,16 @@ def __init__(self, self.cudagraph_options = cudagraph_options # the entries for different batch descriptors that we need to capture # cudagraphs for. - self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry]\ - = {} + self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {} def __getattr__(self, key: str): # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): return getattr(self.runnable, key) - raise AttributeError(f"Attribute {key} not exists in the runnable of " - f"cudagraph wrapper: {self.runnable}") + raise AttributeError( + f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self.runnable}" + ) def unwrap(self) -> Callable: # in case we need to access the original runnable. @@ -110,8 +112,10 @@ def __call__(self, *args, **kwargs): batch_descriptor = forward_context.batch_descriptor cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode - if cudagraph_runtime_mode == CUDAGraphMode.NONE or \ - cudagraph_runtime_mode != self.runtime_mode: + if ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode != self.runtime_mode + ): # CUDAGraphMode.NONE could mean the profile run, a warmup run, or # running without cudagraphs. # We do not trigger capture/replay if the runtime mode is not @@ -122,8 +126,9 @@ def __call__(self, *args, **kwargs): if batch_descriptor not in self.concrete_cudagraph_entries: # create a new entry for this batch descriptor - self.concrete_cudagraph_entries[batch_descriptor] = \ - CUDAGraphEntry(batch_descriptor=batch_descriptor) + self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry( + batch_descriptor=batch_descriptor + ) entry = self.concrete_cudagraph_entries[batch_descriptor] @@ -133,8 +138,11 @@ def __call__(self, *args, **kwargs): # capturing is fast, we don't need to log it for every # shape. E.g. we only log it for the first subgraph in # piecewise mode. - logger.debug("Capturing a cudagraph on (%s,%s)", - self.runtime_mode.name, entry.batch_descriptor) + logger.debug( + "Capturing a cudagraph on (%s,%s)", + self.runtime_mode.name, + entry.batch_descriptor, + ) # validate that cudagraph capturing is legal at this point. validate_cudagraph_capturing_enabled() @@ -153,8 +161,7 @@ def __call__(self, *args, **kwargs): # therefore, we only run gc for the first graph, # and disable gc for the rest of the graphs. stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) + stack.enter_context(patch("torch.cuda.empty_cache", lambda: None)) if self.graph_pool is not None: set_graph_pool_id(self.graph_pool) @@ -193,7 +200,8 @@ def __call__(self, *args, **kwargs): assert new_input_addresses == entry.input_addresses, ( f"Input addresses for cudagraphs are different " f"during replay. 
Expected {entry.input_addresses}, " - f"got {new_input_addresses}") + f"got {new_input_addresses}" + ) entry.cudagraph.replay() return entry.output diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index fa38cfe49a91..4f5648d3000a 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -34,11 +34,11 @@ def ignore_torch_compile(cls: _T) -> _T: a support_torch_compile decorator, but we don't want to compile the class `cls` that inherits the parent class. This only ignores compiling the forward of the class the - decorator is applied to. + decorator is applied to. If the parent has ignore_torch_compile but the child has support_torch_compile, the child will still be compiled. - + If the class has one or more submodules that have support_torch_compile decorator applied, compile will not be ignored for those submodules. @@ -58,21 +58,18 @@ def _should_ignore_torch_compile(cls) -> bool: def support_torch_compile( *, enable_if: Optional[Callable[[VllmConfig], bool]] = None, -) -> Callable[[_T], _T]: - ... +) -> Callable[[_T], _T]: ... @overload def support_torch_compile( *, dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]], -) -> Callable[[_T], _T]: - ... +) -> Callable[[_T], _T]: ... @overload -def support_torch_compile(cls: _T) -> _T: - ... +def support_torch_compile(cls: _T) -> _T: ... def support_torch_compile( @@ -89,8 +86,7 @@ def support_torch_compile( ```python @support_torch_compile class MyModel(nn.Module): - def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): - ... + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ... ``` Usage 2: use as a decorator with arguments: @@ -98,8 +94,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ```python @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0}) class MyModel(nn.Module): - def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): - ... + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ... ``` `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic @@ -139,7 +134,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): def cls_decorator_helper(cls: _T) -> _T: # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` # to avoid too much indentation for `_support_torch_compile`` - if not hasattr(cls, 'forward'): + if not hasattr(cls, "forward"): raise TypeError("decorated class should have a forward method.") sig = inspect.signature(cls.forward) inferred_dynamic_arg_dims = dynamic_arg_dims @@ -147,26 +142,31 @@ def cls_decorator_helper(cls: _T) -> _T: inferred_dynamic_arg_dims = {} for k, v in sig.parameters.items(): if v.annotation in [ - torch.Tensor, Optional[torch.Tensor], - IntermediateTensors, Optional[IntermediateTensors] + torch.Tensor, + Optional[torch.Tensor], + IntermediateTensors, + Optional[IntermediateTensors], ]: inferred_dynamic_arg_dims[k] = 0 - logger.debug(("Inferred dynamic dimensions for " - "forward method of %s: %s"), cls, - list(inferred_dynamic_arg_dims.keys())) + logger.debug( + ("Inferred dynamic dimensions for forward method of %s: %s"), + cls, + list(inferred_dynamic_arg_dims.keys()), + ) if len(inferred_dynamic_arg_dims) == 0: raise ValueError( "No dynamic dimensions found in the forward method of " - f"{cls}. Please provide dynamic_arg_dims explicitly.") + f"{cls}. Please provide dynamic_arg_dims explicitly." 
+ ) for k in inferred_dynamic_arg_dims: if k not in sig.parameters: raise ValueError( - f"Argument {k} not found in the forward method of {cls}") - return _support_torch_compile(cls, inferred_dynamic_arg_dims, - enable_if) + f"Argument {k} not found in the forward method of {cls}" + ) + return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if) if cls is not None: # use `support_torch_compile` as a decorator without arguments @@ -191,29 +191,32 @@ def _support_torch_compile( # take care of method resolution order # make sure super().__init__ is called on the base class # other than TorchCompileWrapperWithCustomDispatcher - cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,) old_init = cls.__init__ setattr(cls, IGNORE_COMPILE_KEY, False) - def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config enable_compile = enable_if is None or enable_if(vllm_config) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. - self.do_not_compile = \ - vllm_config.compilation_config.level in [ - CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() or _should_ignore_torch_compile( - self.__class__) or not enable_compile + self.do_not_compile = ( + vllm_config.compilation_config.level + in [CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS] + or not supports_dynamo() + or _should_ignore_torch_compile(self.__class__) + or not enable_compile + ) if self.do_not_compile: return compilation_counter.num_models_seen += 1 TorchCompileWrapperWithCustomDispatcher.__init__( - self, compilation_level=vllm_config.compilation_config.level) + self, compilation_level=vllm_config.compilation_config.level + ) cls.__init__ = __init__ @@ -235,26 +238,23 @@ def __call__(self, *args, **kwargs): dims = [dims] if isinstance(dims, int) else dims if isinstance(arg, torch.Tensor): # In case dims is specified with negative indexing - dims = [ - arg.ndim + dim if dim < 0 else dim for dim in dims - ] + dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] torch._dynamo.mark_dynamic(arg, dims) elif isinstance(arg, IntermediateTensors): for tensor in arg.tensors.values(): # In case dims is specified with negative indexing dims = [ - tensor.ndim + dim if dim < 0 else dim - for dim in dims + tensor.ndim + dim if dim < 0 else dim for dim in dims ] torch._dynamo.mark_dynamic(tensor, dims) else: raise ValueError( "Unsupported dynamic dimensions" - f" {dims} for argument {k} with type {type(arg)}.") + f" {dims} for argument {k} with type {type(arg)}." + ) # here, it is the starting point of the `torch.compile` process start_monitoring_torch_compile(self.vllm_config) - logger.debug("Start compiling function %s", - self.original_code_object) + logger.debug("Start compiling function %s", self.original_code_object) # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, @@ -263,8 +263,7 @@ def __call__(self, *args, **kwargs): # it seems Dynamo reuse the compilation across instances, # while we need to make sure the compiled code is not reused. # we need to control all the compilation of the model. 
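# Editor's sketch (not part of the diff): the bookkeeping just below records
# the source file of every function Dynamo traces so the compile cache can be
# keyed on (and invalidated by) those files. A standalone version of the idea:
def _traced_files_sketch(fn) -> set[str]:
    files = {fn.__code__.co_filename}
    for const in fn.__code__.co_consts:
        if hasattr(const, "co_filename"):  # nested functions / comprehensions
            files.add(const.co_filename)
    return files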
- torch._dynamo.eval_frame.remove_from_cache( - self.original_code_object) + torch._dynamo.eval_frame.remove_from_cache(self.original_code_object) # collect all relevant files traced by Dynamo, # so that the compilation cache can trigger re-compilation @@ -272,7 +271,8 @@ def __call__(self, *args, **kwargs): # 1. the file containing the top-level forward function self.vllm_config.compilation_config.traced_files.add( - self.original_code_object.co_filename) + self.original_code_object.co_filename + ) # 2. every time Dynamo sees a function call, it will inline # the function by calling InliningInstructionTranslator.inline_call @@ -282,8 +282,7 @@ def __call__(self, *args, **kwargs): def patched_inline_call(parent, func, args, kwargs): code = func.get_code() - self.vllm_config.compilation_config.traced_files.add( - code.co_filename) + self.vllm_config.compilation_config.traced_files.add(code.co_filename) return inline_call(parent, func, args, kwargs) # Disable the C++ compilation of symbolic shape guards. C++-fication @@ -293,20 +292,20 @@ def patched_inline_call(parent, func, args, kwargs): dynamo_config_patches = {} try: _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards - dynamo_config_patches[ - "enable_cpp_symbolic_shape_guards"] = False + dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False except AttributeError: # Note: this config is not available in torch 2.6, we can skip # if the config doesn't exist - logger.debug( - "enable_cpp_symbolic_shape_guards config not available") - - with patch.object( - InliningInstructionTranslator, "inline_call", - patched_inline_call), torch._dynamo.config.patch( - **dynamo_config_patches - ), maybe_use_cudagraph_partition_wrapper( - self.vllm_config), _torch27_patch_tensor_subclasses(): + logger.debug("enable_cpp_symbolic_shape_guards config not available") + + with ( + patch.object( + InliningInstructionTranslator, "inline_call", patched_inline_call + ), + torch._dynamo.config.patch(**dynamo_config_patches), + maybe_use_cudagraph_partition_wrapper(self.vllm_config), + _torch27_patch_tensor_subclasses(), + ): output = self.compiled_callable(*args, **kwargs) return output @@ -336,18 +335,20 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config - if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs() - and compilation_config.use_inductor_graph_partition): + if ( + compilation_config.cudagraph_mode.has_piecewise_cudagraphs() + and compilation_config.use_inductor_graph_partition + ): from torch._inductor.utils import CUDAGraphWrapperMetadata from vllm.compilation.cuda_graph import CUDAGraphOptions from vllm.platforms import current_platform static_graph_wrapper_class = resolve_obj_by_qualname( - current_platform.get_static_graph_wrapper_cls()) + current_platform.get_static_graph_wrapper_cls() + ) - def customized_cudagraph_wrapper(f, - metadata: CUDAGraphWrapperMetadata): + def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata): partition_id = metadata.partition_index num_partitions = metadata.num_partitions return static_graph_wrapper_class( @@ -358,15 +359,19 @@ def customized_cudagraph_wrapper(f, debug_log_enable=partition_id == 0, gc_disable=partition_id != 0, weak_ref_output=partition_id == num_partitions - 1, - )) + ), + ) torch._inductor.utils.set_customized_partition_wrappers( - customized_cudagraph_wrapper) + customized_cudagraph_wrapper + ) yield - if 
(compilation_config.cudagraph_mode.has_piecewise_cudagraphs() - and compilation_config.use_inductor_graph_partition): + if ( + compilation_config.cudagraph_mode.has_piecewise_cudagraphs() + and compilation_config.use_inductor_graph_partition + ): torch._inductor.utils.set_customized_partition_wrappers(None) @@ -378,23 +383,32 @@ def _torch27_patch_tensor_subclasses(): `BasevLLMParameters` without having to replace them with regular tensors before `torch.compile`-time. """ - from vllm.model_executor.parameter import (BasevLLMParameter, - ModelWeightParameter, - RowvLLMParameter, - _ColumnvLLMParameter) + from vllm.model_executor.parameter import ( + BasevLLMParameter, + ModelWeightParameter, + RowvLLMParameter, + _ColumnvLLMParameter, + ) def return_false(*args, **kwargs): return False - if version.parse("2.7") <= version.parse( - torch.__version__) < version.parse("2.8"): + if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"): yield return - with (torch._dynamo.config.patch("traceable_tensor_subclasses", [ - BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter, - RowvLLMParameter - ]), - patch("torch._dynamo.variables.torch.can_dispatch_torch_function", - return_false)): + with ( + torch._dynamo.config.patch( + "traceable_tensor_subclasses", + [ + BasevLLMParameter, + ModelWeightParameter, + _ColumnvLLMParameter, + RowvLLMParameter, + ], + ), + patch( + "torch._dynamo.variables.torch.can_dispatch_torch_function", return_false + ), + ): yield diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index ce6db9c1ebca..0dffb343f9a2 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -31,8 +31,9 @@ def __call__(self, graph: torch.fx.Graph): # XPU does not support auto-functionalization yet. # Will enable this when switch to vllm-xpu-kernels. if current_platform.is_xpu(): - logger.debug("XPU platform does not support fix functionalization" - "pass currently.") + logger.debug( + "XPU platform does not support fix functionalizationpass currently." + ) return self.nodes_to_remove: list[torch.fx.Node] = [] @@ -45,19 +46,21 @@ def __call__(self, graph: torch.fx.Graph): at_target = node.args[0] if at_target == torch.ops._C.rotary_embedding.default: - query = kwargs['query'] - key = kwargs['key'] + query = kwargs["query"] + key = kwargs["key"] getitem_nodes = self.getitem_users(node) - if (is_func(query, operator.getitem) - and is_func(key, operator.getitem) - and query.args[0] == key.args[0] - and is_func(query.args[0], - torch.ops.aten.split_with_sizes.default) - and all( - is_func(user, torch.ops.aten.slice_scatter.default) - for getitem_node in getitem_nodes.values() - for user in getitem_node.users)): + if ( + is_func(query, operator.getitem) + and is_func(key, operator.getitem) + and query.args[0] == key.args[0] + and is_func(query.args[0], torch.ops.aten.split_with_sizes.default) + and all( + is_func(user, torch.ops.aten.slice_scatter.default) + for getitem_node in getitem_nodes.values() + for user in getitem_node.users + ) + ): # Pattern where query and key are slices of an mm_node. # While functionalized, results at [1] and [2] are scattered # back into mm_node. 
So after de-functionalization, we can @@ -66,8 +69,9 @@ def __call__(self, graph: torch.fx.Graph): mm_node = query.args[0].args[0] for user in getitem_nodes.values(): for user_of_getitem in user.users: - if is_func(user_of_getitem, - torch.ops.aten.slice_scatter.default): + if is_func( + user_of_getitem, torch.ops.aten.slice_scatter.default + ): user_of_getitem.replace_all_uses_with(mm_node) self._remove(user_of_getitem) self._remove(user) @@ -81,49 +85,54 @@ def __call__(self, graph: torch.fx.Graph): # do this blindly, but in practice in vLLM it's ok. The best # solution is to use auto_functionalization_v2 and then use # inductor's builtin defunctionalization (reinplacing) pass. - mutated_args = {1: 'query', 2: 'key'} + mutated_args = {1: "query", 2: "key"} self.defunctionalize(graph, node, mutated_args) # rms_norm replacements avoid the most copies for LLaMa. elif at_target == torch.ops._C.fused_add_rms_norm.default: - mutated_args = {1: 'input', 2: 'residual'} + mutated_args = {1: "input", 2: "residual"} self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 - mutated_args = {1: 'result', 2: 'residual'} + mutated_args = {1: "result", 2: "residual"} self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 - mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} + mutated_args = {1: "result", 2: "scale", 3: "residual"} self.defunctionalize(graph, node, mutated_args) elif at_target in [ - torch.ops._C.rms_norm.default, - torch.ops._C.rms_norm_static_fp8_quant.default, + torch.ops._C.rms_norm.default, + torch.ops._C.rms_norm_static_fp8_quant.default, ]: - mutated_args = {1: 'result'} + mutated_args = {1: "result"} self.defunctionalize(graph, node, mutated_args) # For some reason we need to specify the args for both # silu_and_mul and silu_and_mul_quant. The kwargs # pathway gets the wrong answer. 
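# Editor's sketch (not part of the diff): the core move of this pass. An
# auto_functionalized(op, ...) node returns a tuple whose element [i] aliases a
# mutated kwarg, so each getitem(node, i) user can be rewired to that original
# argument once the op is re-emitted as an in-place call. Simplified:
def _rewire_getitem_users_sketch(node, mutated_args):
    # mutated_args maps tuple index -> the fx.Node the op mutates in place
    for user in list(node.users):
        idx = user.args[1]  # index used by operator.getitem
        user.replace_all_uses_with(mutated_args[idx])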
elif at_target == torch.ops._C.silu_and_mul.default: - mutated_args = {1: 'result'} - self.defunctionalize(graph, - node, - mutated_args, - args=('result', 'input')) + mutated_args = {1: "result"} + self.defunctionalize( + graph, node, mutated_args, args=("result", "input") + ) elif at_target == torch.ops._C.silu_and_mul_quant.default: - mutated_args = {1: 'result'} - self.defunctionalize(graph, - node, - mutated_args, - args=('result', 'input', 'scale')) - elif hasattr( - torch.ops._C, "silu_and_mul_nvfp4_quant" - ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default: - mutated_args = {1: 'result', 2: 'result_block_scale'} - self.defunctionalize(graph, - node, - mutated_args, - args=('result', 'result_block_scale', - 'input', 'input_global_scale')) + mutated_args = {1: "result"} + self.defunctionalize( + graph, node, mutated_args, args=("result", "input", "scale") + ) + elif ( + hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant") + and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default + ): + mutated_args = {1: "result", 2: "result_block_scale"} + self.defunctionalize( + graph, + node, + mutated_args, + args=( + "result", + "result_block_scale", + "input", + "input_global_scale", + ), + ) else: continue # skip the count @@ -136,12 +145,12 @@ def __call__(self, graph: torch.fx.Graph): for node in self.nodes_to_remove: graph.erase_node(node) - logger.debug("De-functionalized %s nodes, removed %s nodes", count, - count_removed) + logger.debug( + "De-functionalized %s nodes, removed %s nodes", count, count_removed + ) self.nodes_to_remove.clear() - def _remove(self, node_or_nodes: Union[torch.fx.Node, - Iterable[torch.fx.Node]]): + def _remove(self, node_or_nodes: Union[torch.fx.Node, Iterable[torch.fx.Node]]): """ Stage a node (or nodes) for removal at the end of the pass. """ @@ -150,12 +159,13 @@ def _remove(self, node_or_nodes: Union[torch.fx.Node, else: self.nodes_to_remove.extend(node_or_nodes) - def defunctionalize(self, - graph: torch.fx.Graph, - node: torch.fx.Node, - mutated_args: dict[int, Union[torch.fx.Node, str]], - args: Optional[tuple[Union[torch.fx.Node, str], - ...]] = None): + def defunctionalize( + self, + graph: torch.fx.Graph, + node: torch.fx.Node, + mutated_args: dict[int, Union[torch.fx.Node, str]], + args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None, + ): """ De-functionalize a node by replacing it with a call to the original. It also replaces the getitem users with the mutated arguments. @@ -165,10 +175,9 @@ def defunctionalize(self, self.insert_defunctionalized(graph, node, args=args) self._remove(node) - def replace_users_with_mutated_args(self, node: torch.fx.Node, - mutated_args: dict[int, - Union[torch.fx.Node, - str]]): + def replace_users_with_mutated_args( + self, node: torch.fx.Node, mutated_args: dict[int, Union[torch.fx.Node, str]] + ): """ Replace all getitem users of the auto-functionalized node with the mutated arguments. @@ -194,11 +203,12 @@ def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]: users[idx] = user return users - def insert_defunctionalized(self, - graph: torch.fx.Graph, - node: torch.fx.Node, - args: Optional[tuple[Union[torch.fx.Node, str], - ...]] = None): + def insert_defunctionalized( + self, + graph: torch.fx.Graph, + node: torch.fx.Node, + args: Optional[tuple[Union[torch.fx.Node, str], ...]] = None, + ): """ Insert a new defunctionalized node into the graph before node. 
If one of the kwargs is 'out', provide args directly, @@ -210,8 +220,9 @@ def insert_defunctionalized(self, :param args: If we cannot use kwargs, specify args directly. If an arg is a string, `node.kwargs[arg]` is used. """ # noqa: E501 - assert is_func(node, auto_functionalized), \ + assert is_func(node, auto_functionalized), ( f"node must be auto-functionalized, is {node} instead" + ) # Create a new call to the original function with graph.inserting_before(node): @@ -220,6 +231,7 @@ def insert_defunctionalized(self, graph.call_function(function, kwargs=node.kwargs) else: # Args passed as strings refer to items in node.kwargs - args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg - for arg in args) + args = tuple( + node.kwargs[arg] if isinstance(arg, str) else arg for arg in args + ) graph.call_function(function, args=args) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 3034b6eaeaca..df54e94a03db 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,8 +12,15 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, - kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) + GroupShape, + QuantKey, + ScaleDesc, + kFp8DynamicTensorSym, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kNvfp4Quant, + kStaticTensorScale, +) from vllm.platforms import current_platform from .inductor_pass import enable_fake_mode @@ -40,12 +47,9 @@ def empty_i32(*args, **kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { - kFp8StaticTensorSym: - torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTensorSym: - torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 - kFp8DynamicTokenSym: - torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501 + kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default @@ -57,80 +61,93 @@ class FusedRMSQuantKey(NamedTuple): quant: type of quantization fused_add: does the op also perform the residual add """ + quant: QuantKey fused_add: bool def __str__(self): - return (f"FusedQuantKey({self.quant}, with" - f"{'' if self.fused_add else 'out'} residual)") + return ( + f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)" + ) FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = { - FusedRMSQuantKey(kFp8StaticTensorSym, False): - torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8StaticTensorSym, True): - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8DynamicTokenSym, False): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 - FusedRMSQuantKey(kFp8DynamicTokenSym, True): - torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8StaticTensorSym, False + ): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8StaticTensorSym, True + ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501 + FusedRMSQuantKey( + 
kFp8DynamicTokenSym, False + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 + FusedRMSQuantKey( + kFp8DynamicTokenSym, True + ): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501 } class RMSNormQuantPattern: - def __init__(self, epsilon: float, key: FusedRMSQuantKey): self.epsilon = epsilon self.quant_dtype = key.quant.dtype - assert key.quant in QUANT_OPS, \ - f"unsupported quantization scheme {key.quant}" + assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}" self.QUANT_OP = QUANT_OPS[key.quant] - assert key in FUSED_OPS, \ - f"unsupported fused rmsnorm+quant op for {key}" + assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}" self.FUSED_OP = FUSED_OPS[key] class RMSNormStaticQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - symmetric=True): - fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey(dtype=quant_dtype, - scale=kStaticTensorScale, - symmetric=symmetric)) + def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + fused_key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey( + dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric + ), + ) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): # Cannot use methods, as the self argument affects tracing - def pattern(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon) - at2 = auto_functionalized(self.QUANT_OP, - result=result, - input=at1[1], - scale=scale) + def pattern( + result: torch.Tensor, + result_rms: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized( + RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon, + ) + at2 = auto_functionalized( + self.QUANT_OP, result=result, input=at1[1], scale=scale + ) # result return at2[1] - def replacement(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon) + def replacement( + result: torch.Tensor, + result_rms: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) # result return at[1] @@ -140,53 +157,60 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, empty_bf16(5, 4), # result_rms empty_bf16(5, 4), # input empty_bf16(1, 5), # weight - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] - pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, - pm_pass) + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass) class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - symmetric=True): - key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey(dtype=quant_dtype, - scale=kStaticTensorScale, - symmetric=symmetric)) + def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True): + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey( + dtype=quant_dtype, scale=kStaticTensorScale, 
symmetric=symmetric + ), + ) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - - def pattern(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon) - at1 = auto_functionalized(self.QUANT_OP, - result=result, - input=at[1], - scale=scale) + def pattern( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + at1 = auto_functionalized( + self.QUANT_OP, result=result, input=at[1], scale=scale + ) # result, residual return at1[1], at[2] - def replacement(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - residual=residual, - weight=weight, - scale=scale, - epsilon=self.epsilon) + def replacement( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon, + ) # result, residual return at[1], at[2] @@ -196,7 +220,7 @@ def replacement(result: torch.Tensor, input: torch.Tensor, empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -209,49 +233,59 @@ def replacement(result: torch.Tensor, input: torch.Tensor, class RMSNormDynamicQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey(dtype=quant_dtype, - scale=scale, - symmetric=symmetric)) + key = FusedRMSQuantKey( + fused_add=False, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - - def pattern(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(RMS_OP, - result=result_rms, - input=input, - weight=weight, - epsilon=self.epsilon) - at2 = auto_functionalized(self.QUANT_OP, - result=result, - input=at1[1], - scale=scale, - scale_ub=None) + def pattern( + result: torch.Tensor, + result_rms: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized( + RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon, + ) + at2 = auto_functionalized( + self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None + ) # result, scale return at2[1], at2[2] - def replacement(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=None) + 
def replacement( + result: torch.Tensor, + result_rms: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None, + ) # result, scale return at[1], at[2] @@ -261,7 +295,7 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, empty_bf16(5, 4), # result_rms empty_bf16(5, 4), # input empty_bf16(1, 5), # weight - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -274,49 +308,59 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): - - def __init__(self, - epsilon: float, - quant_dtype: torch.dtype, - group_shape: GroupShape = GroupShape.PER_TOKEN, - symmetric=True): + def __init__( + self, + epsilon: float, + quant_dtype: torch.dtype, + group_shape: GroupShape = GroupShape.PER_TOKEN, + symmetric=True, + ): scale = ScaleDesc(torch.float32, False, group_shape) - key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey(dtype=quant_dtype, - scale=scale, - symmetric=symmetric)) + key = FusedRMSQuantKey( + fused_add=True, + quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric), + ) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass): - - def pattern(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(RMS_ADD_OP, - input=input, - residual=residual, - weight=weight, - epsilon=self.epsilon) - at1 = auto_functionalized(self.QUANT_OP, - result=result, - input=at[1], - scale=scale, - scale_ub=None) + def pattern( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + at1 = auto_functionalized( + self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None + ) # result, residual, scale return at1[1], at[2], at1[2] - def replacement(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=self.epsilon, - scale_ub=None, - residual=residual) + def replacement( + result: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + ): + at = auto_functionalized( + self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual, + ) # result, residual, scale return at[1], at[3], at[2] @@ -326,7 +370,7 @@ def replacement(result: torch.Tensor, input: torch.Tensor, empty_bf16(5, 4), # input empty_bf16(5, 4), # residual empty_bf16(1, 5), # weight - empty_fp32(1, 1) # scale + empty_fp32(1, 1), # scale ] pm.register_replacement( @@ -349,24 +393,25 @@ def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="rmsnorm_quant_fusion_pass") + pass_name="rmsnorm_quant_fusion_pass" + ) for epsilon in [1e-5, 1e-6]: # Fuse rms_norm + static fp8 quant - RMSNormStaticQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse 
fused_add_rms_norm + static fp8 quant FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns) + self.patterns + ) # Fuse rms_norm + dynamic per-token fp8 quant - RMSNormDynamicQuantPattern(epsilon, - FP8_DTYPE).register(self.patterns) + RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns) # Fuse fused_add_rms_norm + dynamic per-token fp8 quant FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register( - self.patterns) + self.patterns + ) self.dump_patterns(config, self.patterns) @@ -376,8 +421,11 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", self.matched_count) def uuid(self) -> Any: - return self.hash_source(self, RMSNormQuantPattern, - RMSNormStaticQuantPattern, - RMSNormDynamicQuantPattern, - FusedAddRMSNormStaticQuantPattern, - FusedAddRMSNormDynamicQuantPattern) + return self.hash_source( + self, + RMSNormQuantPattern, + RMSNormStaticQuantPattern, + RMSNormDynamicQuantPattern, + FusedAddRMSNormStaticQuantPattern, + FusedAddRMSNormDynamicQuantPattern, + ) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 2c6cf8f12fdc..ae36cef92653 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -12,7 +12,10 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kNvfp4Quant, kStaticTensorScale) + QuantKey, + kNvfp4Quant, + kStaticTensorScale, +) from vllm.platforms import current_platform from vllm.utils import round_up @@ -49,21 +52,21 @@ def __init__( self.quant_dtype = quant_key.dtype self.dtype = dtype - assert self.quant_key in QUANT_OPS, \ + assert self.quant_key in QUANT_OPS, ( f"unsupported quantization scheme {self.quant_key}" + ) self.QUANT_OP = QUANT_OPS[self.quant_key] def empty(self, *args, **kwargs): - kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs} + kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs} return torch.empty(*args, **kwargs) def empty_quant(self, *args, **kwargs): - kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} + kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs} return torch.empty(*args, **kwargs) @staticmethod def wrap_trace_fn(process_fx, trace_fn): - def wrapped(*args, **kwargs): return process_fx(trace_fn(*args, **kwargs)) @@ -72,6 +75,7 @@ def wrapped(*args, **kwargs): @staticmethod def fx_view_to_reshape(gm: torch.fx.GraphModule): from torch._inductor.fx_passes.post_grad import view_to_reshape + view_to_reshape(gm) return gm @@ -100,70 +104,85 @@ def __init__( dtype: torch.dtype, symmetric: bool = True, ): - quant_key = QuantKey(dtype=FP8_DTYPE, - scale=kStaticTensorScale, - symmetric=symmetric) + quant_key = QuantKey( + dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric + ) super().__init__(layer, quant_key, dtype) def _register(self, pm_pass: PatternMatcherPass): - - def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(ATTN_OP, - query=q, - key=k, - value=v, - output=output_attn, - layer_name=self.layer_name, - output_scale=None, - output_block_scale=None) + def pattern( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + output_attn: torch.Tensor, + output_quant: torch.Tensor, + scale: torch.Tensor, + ): + at1 = auto_functionalized( + ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + 
layer_name=self.layer_name, + output_scale=None, + output_block_scale=None, + ) attn_out_view = RESHAPE_OP( - at1[1], [q.shape[0], self.num_heads * self.head_size]) - at2 = auto_functionalized(self.QUANT_OP, - result=output_quant, - input=attn_out_view, - scale=scale) + at1[1], [q.shape[0], self.num_heads * self.head_size] + ) + at2 = auto_functionalized( + self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale + ) return at2[1] - def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - scale: torch.Tensor): + def replacement( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + output_attn: torch.Tensor, + output_quant: torch.Tensor, + scale: torch.Tensor, + ): # attn output in quant_dtype output_attn = torch.ops.aten.full.default( [q.shape[0], self.num_heads, self.head_size], 0.0, dtype=self.quant_dtype, - device=q.device) - at1 = auto_functionalized(ATTN_OP, - query=q, - key=k, - value=v, - output=output_attn, - layer_name=self.layer_name, - output_scale=scale, - output_block_scale=None) + device=q.device, + ) + at1 = auto_functionalized( + ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=scale, + output_block_scale=None, + ) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) inputs = [ - self.empty(5, self.num_heads, self.head_size, - dtype=self.dtype), # q - self.empty(5, self.num_heads, self.head_size, - dtype=self.dtype), # k - self.empty(5, self.num_heads, self.head_size, - dtype=self.dtype), # v - self.empty(5, self.num_heads, self.head_size, - dtype=self.dtype), # attn_output - self.empty_quant(5, - self.num_heads * self.head_size), # quant_output - empty_fp32(1, 1) # scale + self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q + self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k + self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v + self.empty( + 5, self.num_heads, self.head_size, dtype=self.dtype + ), # attn_output + self.empty_quant(5, self.num_heads * self.head_size), # quant_output + empty_fp32(1, 1), # scale ] pm.register_replacement( - pattern, replacement, inputs, + pattern, + replacement, + inputs, AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), - pm_pass) + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + ), + pm_pass, + ) class AttentionNvfp4QuantPattern(AttentionQuantPattern): @@ -180,50 +199,67 @@ def __init__(self, layer: Attention, dtype: torch.dtype): super().__init__(layer, kNvfp4Quant, dtype) def _register(self, pm_pass: PatternMatcherPass): - - def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - output_scale: torch.Tensor, input_scale: torch.Tensor): - at1 = auto_functionalized(ATTN_OP, - query=q, - key=k, - value=v, - output=output_attn, - layer_name=self.layer_name, - output_scale=None, - output_block_scale=None) + def pattern( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + output_attn: torch.Tensor, + output_quant: torch.Tensor, + output_scale: torch.Tensor, + input_scale: torch.Tensor, + ): + at1 = auto_functionalized( + ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=None, + output_block_scale=None, + ) attn_out_view = RESHAPE_OP( - at1[1], [q.shape[0], self.num_heads * self.head_size]) - at2 = auto_functionalized(self.QUANT_OP, - output=output_quant, - 
input=attn_out_view, - output_scale=output_scale, - input_scale=input_scale) + at1[1], [q.shape[0], self.num_heads * self.head_size] + ) + at2 = auto_functionalized( + self.QUANT_OP, + output=output_quant, + input=attn_out_view, + output_scale=output_scale, + input_scale=input_scale, + ) output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) return at2[1], output_scale_view - def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - output_attn: torch.Tensor, output_quant: torch.Tensor, - output_scale: torch.Tensor, input_scale: torch.Tensor): + def replacement( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + output_attn: torch.Tensor, + output_quant: torch.Tensor, + output_scale: torch.Tensor, + input_scale: torch.Tensor, + ): # attention output in quant_dtype output_attn = torch.ops.aten.full.default( [q.shape[0], self.num_heads, self.head_size // 2], 0.0, dtype=self.quant_dtype, - device=q.device) + device=q.device, + ) # attention output block scale - output_scale_view = torch.ops.aten.view.dtype( - output_scale, FP8_DTYPE) - at2 = auto_functionalized(ATTN_OP, - query=q, - key=k, - value=v, - output=output_attn, - layer_name=self.layer_name, - output_scale=input_scale, - output_block_scale=output_scale_view) - output = RESHAPE_OP(at2[1], - [-1, self.num_heads * self.head_size // 2]) + output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE) + at2 = auto_functionalized( + ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=input_scale, + output_block_scale=output_scale_view, + ) + output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2]) return output, at2[2] inputs = [ @@ -231,18 +267,22 @@ def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, empty_bf16(5, self.num_heads, self.head_size), # k empty_bf16(5, self.num_heads, self.head_size), # v empty_bf16(5, self.num_heads, self.head_size), # output_attn - self.empty_quant(5, self.num_heads * self.head_size // - 2), # output_quant - empty_i32(128, round_up(self.num_heads * self.head_size // 16, - 4)), # output_scale + self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant + empty_i32( + 128, round_up(self.num_heads * self.head_size // 16, 4) + ), # output_scale empty_fp32(1, 1), # input_scale ] pm.register_replacement( - pattern, replacement, inputs, + pattern, + replacement, + inputs, AttentionQuantPattern.wrap_trace_fn( - AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), - pm_pass) + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only + ), + pm_pass, + ) class AttnFusionPass(VllmPatternMatcherPass): @@ -267,20 +307,22 @@ def __init__(self, config: VllmConfig): attn_layers = get_layers_from_vllm_config(config, Attention) for layer_name, layer in attn_layers.items(): pattern_fp8 = AttentionFp8StaticQuantPattern( - layer, config.model_config.dtype) + layer, config.model_config.dtype + ) pattern_fp8.register_if_supported(self.patterns) - if current_platform.is_cuda() and hasattr(torch.ops._C, - "scaled_fp4_quant"): + if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): pattern_nvfp4 = AttentionNvfp4QuantPattern( - layer, config.model_config.dtype) + layer, config.model_config.dtype + ) pattern_nvfp4.register_if_supported(self.patterns) if len(attn_layers) == 0: logger.warning( "Attention + quant fusion is enabled, but no attention layers " "were found in CompilationConfig.static_forward_context " - "so no fusion patterns were registered.") + "so no fusion 
patterns were registered." + ) self.dump_patterns(config, self.patterns) @@ -290,6 +332,9 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: logger.debug("Fused quant onto %s attention nodes", self.matched_count) def uuid(self): - return VllmInductorPass.hash_source(self, AttentionQuantPattern, - AttentionFp8StaticQuantPattern, - AttentionNvfp4QuantPattern) + return VllmInductorPass.hash_source( + self, + AttentionQuantPattern, + AttentionFp8StaticQuantPattern, + AttentionNvfp4QuantPattern, + ) diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index 2db8b5441bd6..114b53c74c48 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -19,8 +19,9 @@ def is_auto_func(node: fx.Node, op: OpOverload) -> bool: # Returns the first specified node with the given op (if it exists) -def find_specified_fn_maybe(nodes: Iterable[fx.Node], - op: OpOverload) -> Optional[fx.Node]: +def find_specified_fn_maybe( + nodes: Iterable[fx.Node], op: OpOverload +) -> Optional[fx.Node]: for node in nodes: if node.target == op: return node @@ -35,8 +36,7 @@ def find_specified_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: # Returns the first auto_functionalized node with the given op (if it exists) -def find_auto_fn_maybe(nodes: Iterable[fx.Node], - op: OpOverload) -> Optional[fx.Node]: +def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> Optional[fx.Node]: for node in nodes: if is_func(node, auto_functionalized) and node.args[0] == op: # noqa return node diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index e1b691df385d..9085448d2397 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -11,8 +11,7 @@ import torch from torch import fx -from torch._subclasses.fake_tensor import (FakeTensorMode, - unset_fake_temporarily) +from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily from vllm.utils import is_torch_equal_or_newer @@ -20,14 +19,14 @@ from torch._inductor.custom_graph_pass import CustomGraphPass else: # CustomGraphPass is not present in 2.5 or lower, import our version - from .torch25_custom_graph_pass import ( # noqa: E501 - Torch25CustomGraphPass as CustomGraphPass) + from .torch25_custom_graph_pass import ( + Torch25CustomGraphPass as CustomGraphPass, + ) _pass_context = None class PassContext: - def __init__(self, runtime_shape: Optional[int]): self.runtime_shape = runtime_shape @@ -106,9 +105,9 @@ class CallableInductorPass(InductorPass): implementation of the UUID. 
""" - def __init__(self, - callable: Callable[[fx.Graph], None], - uuid: Optional[Any] = None): + def __init__( + self, callable: Callable[[fx.Graph], None], uuid: Optional[Any] = None + ): self.callable = callable self._uuid = self.hash_source(callable) if uuid is None else uuid @@ -127,8 +126,7 @@ def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(fn) def fn_new(*args, **kwargs) -> Any: - with torch._guards.tracing( - None), unset_fake_temporarily(), FakeTensorMode(): + with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): result = fn(*args, **kwargs) return result diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 35658466d66d..d3c437795fab 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -20,6 +20,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): path = vllm_config.compile_debug_dump_path() if compilation_config.level == CompilationLevel.PIECEWISE and path: import depyf + path.mkdir(parents=True, exist_ok=True) global context_manager context_manager = depyf.prepare_debug(path.as_posix()) @@ -29,8 +30,9 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): def end_monitoring_torch_compile(vllm_config: VllmConfig): compilation_config: CompilationConfig = vllm_config.compilation_config if compilation_config.level == CompilationLevel.PIECEWISE: - logger.info("torch.compile takes %.2f s in total", - compilation_config.compilation_time) + logger.info( + "torch.compile takes %.2f s in total", compilation_config.compilation_time + ) global context_manager if context_manager is not None: context_manager.__exit__(None, None, None) @@ -46,8 +48,10 @@ def validate_cudagraph_capturing_enabled(): # if an illegal cudagraph capturing happens, raise an error. global cudagraph_capturing_enabled if not cudagraph_capturing_enabled: - raise RuntimeError("CUDA graph capturing detected at an inappropriate " - "time. This operation is currently disabled.") + raise RuntimeError( + "CUDA graph capturing detected at an inappropriate " + "time. This operation is currently disabled." + ) def set_cudagraph_capturing_enabled(enabled: bool): diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 2c453daf873d..3d807ab3a6de 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -122,8 +122,9 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Removed %s no-op reshapes and slices", count) # ---------------------- Reshape helpers ---------------------- - def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node], - i_dim: Union[int, SymInt]) -> bool: + def reshape_dims_equivalent( + self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt] + ) -> bool: """ This function checks if two dimensions are equivalent. 
:param dim: The dimension arg to reshape/slice @@ -153,6 +154,4 @@ def reshape_all_dims_equivalent( dims: Iterable[Union[int, torch.fx.Node]], i_dims: Iterable[Union[int, SymInt]], ) -> bool: - return all( - self.reshape_dims_equivalent(s, i_s) - for s, i_s in zip(dims, i_dims)) + return all(self.reshape_dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims)) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index ae26e9f1bf2b..61551766a1c5 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -23,15 +23,19 @@ class ConcreteSizeEntry: class PiecewiseBackend: - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - piecewise_compile_index: int, total_piecewise_compiles: int, - sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend): + def __init__( + self, + graph: fx.GraphModule, + vllm_config: VllmConfig, + piecewise_compile_index: int, + total_piecewise_compiles: int, + sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend, + ): """ The backend for piecewise compilation. - It mainly handles the compilation of static shapes and + It mainly handles the compilation of static shapes and dispatching based on runtime shape. We will compile `self.graph` once for the general shape, @@ -46,13 +50,11 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.vllm_backend = vllm_backend self.is_first_graph = piecewise_compile_index == 0 - self.is_last_graph = ( - piecewise_compile_index == total_piecewise_compiles - 1) + self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1 self.is_full_graph = total_piecewise_compiles == 1 - self.compile_sizes: set[int] = set( - self.compilation_config.compile_sizes) + self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes) self.first_run_finished = False @@ -108,7 +110,8 @@ def __call__(self, *args) -> Any: self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape) + runtime_shape=runtime_shape, + ) # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: diff --git a/vllm/compilation/post_cleanup.py b/vllm/compilation/post_cleanup.py index 6a31f3935da7..55117516838c 100644 --- a/vllm/compilation/post_cleanup.py +++ b/vllm/compilation/post_cleanup.py @@ -16,5 +16,6 @@ class PostCleanupPass(VllmInductorPass): @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: from torch._inductor.pattern_matcher import stable_topological_sort + stable_topological_sort(graph) graph.eliminate_dead_code() diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index a6ca50c925a2..2bc705c3b9a9 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -9,8 +9,7 @@ from vllm.config import VllmConfig from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.platforms import current_platform @@ -23,12 +22,14 @@ class _RMSNormAndQuantOpHelper: """Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" - def __init__(self, - epsilon: float, - dtype: 
torch.dtype, - device: str, - quant_op: Optional[torch._ops.OpOverload] = None, - **kwargs): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs, + ): self.epsilon = epsilon self.dtype = dtype self.device = device @@ -40,60 +41,78 @@ def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor): result=result_buffer, input=input_tensor, weight=weight_tensor, - epsilon=self.epsilon) + epsilon=self.epsilon, + ) - def _functional_fused_add_rmsnorm(self, input_tensor, residual_tensor, - weight_tensor): + def _functional_fused_add_rmsnorm( + self, input_tensor, residual_tensor, weight_tensor + ): return torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, input=input_tensor, residual=residual_tensor, weight=weight_tensor, - epsilon=self.epsilon) - - def _functional_rmsnorm_then_quant(self, rmsnorm_result_buffer, - quant_result_buffer, input_tensor, - weight_tensor, scale_tensor): + epsilon=self.epsilon, + ) + + def _functional_rmsnorm_then_quant( + self, + rmsnorm_result_buffer, + quant_result_buffer, + input_tensor, + weight_tensor, + scale_tensor, + ): if self.quant_op is None: raise RuntimeError( "_RMSNormAndQuantOpHelper was not initialized with a quant_op." ) - rmsnorm_out_tuple = self._functional_rmsnorm(rmsnorm_result_buffer, - input_tensor, - weight_tensor) + rmsnorm_out_tuple = self._functional_rmsnorm( + rmsnorm_result_buffer, input_tensor, weight_tensor + ) quant_out_tuple = torch.ops.higher_order.auto_functionalized( self.quant_op, result=quant_result_buffer, input=rmsnorm_out_tuple[1], - scale=scale_tensor) + scale=scale_tensor, + ) return quant_out_tuple - def _functional_fused_add_rmsnorm_then_quant(self, quant_result_buffer, - input_tensor, residual_tensor, - weight_tensor, scale_tensor): + def _functional_fused_add_rmsnorm_then_quant( + self, + quant_result_buffer, + input_tensor, + residual_tensor, + weight_tensor, + scale_tensor, + ): if self.quant_op is None: raise RuntimeError( "_RMSNormAndQuantOpHelper was not initialized with a quant_op." 
) fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm( - input_tensor, residual_tensor, weight_tensor) + input_tensor, residual_tensor, weight_tensor + ) quant_out_tuple = torch.ops.higher_order.auto_functionalized( self.quant_op, result=quant_result_buffer, input=fused_add_rmsnorm_out_tuple[1], - scale=scale_tensor) + scale=scale_tensor, + ) return quant_out_tuple, fused_add_rmsnorm_out_tuple[2] class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper): """Helper for sequence parallelism patterns.""" - def __init__(self, - epsilon: float, - dtype: torch.dtype, - device: str, - quant_op: Optional[torch._ops.OpOverload] = None, - **kwargs): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + quant_op: Optional[torch._ops.OpOverload] = None, + **kwargs, + ): super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs) self.tp_group = get_tp_group() self.tp_size = get_tensor_model_parallel_world_size() @@ -103,21 +122,16 @@ def _all_reduce(self, x: torch.Tensor) -> torch.Tensor: def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor: return torch.ops.vllm.reduce_scatter.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp_group.unique_name) + x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name + ) def _all_gather(self, x: torch.Tensor) -> torch.Tensor: return torch.ops.vllm.all_gather.default( - x, - dim=0, - world_size=self.tp_size, - group_name=self.tp_group.unique_name) + x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name + ) class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def get_inputs(self): input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) @@ -126,7 +140,6 @@ def get_inputs(self): return [input, permute, arg3_1] def register(self, pm_pass: PatternMatcherPass): - def pattern( input: torch.Tensor, permute: torch.Tensor, @@ -145,26 +158,23 @@ def replacement( reduce_scatter = self._reduce_scatter(input) rmsnorm_result = torch.empty_like(reduce_scatter) - rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, - arg3_1) + rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1) all_gather = self._all_gather(rmsnorm[1]) return all_gather, reduce_scatter - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [ residual, @@ -173,7 +183,6 @@ def get_inputs(self): ] def register(self, pm_pass: PatternMatcherPass): - def pattern( residual: torch.Tensor, mm_1: torch.Tensor, @@ -181,7 +190,8 @@ def pattern( ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights) + all_reduce, residual, rms_norm_weights + ) return rmsnorm[1], rmsnorm[2] def replacement( @@ -191,23 +201,22 @@ def replacement( ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) rmsnorm = self._functional_fused_add_rmsnorm( - 
reduce_scatter, residual, rms_norm_weights) + reduce_scatter, residual, rms_norm_weights + ) all_gather = self._all_gather(rmsnorm[1]) return all_gather, rmsnorm[2] - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) return [ residual, @@ -216,7 +225,6 @@ def get_inputs(self): ] def register(self, pm_pass: PatternMatcherPass): - def pattern( residual: torch.Tensor, mm_1: torch.Tensor, @@ -224,7 +232,8 @@ def pattern( ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) rmsnorm = self._functional_fused_add_rmsnorm( - all_reduce, residual, rms_norm_weights) + all_reduce, residual, rms_norm_weights + ) return rmsnorm[1] def replacement( @@ -234,37 +243,34 @@ def replacement( ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) rmsnorm = self._functional_fused_add_rmsnorm( - reduce_scatter, residual, rms_norm_weights) + reduce_scatter, residual, rms_norm_weights + ) normalized = self._all_gather(rmsnorm[1]) return normalized - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) FP8_DTYPE = current_platform.fp8_dtype() class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - op: torch._ops.OpOverload): + def __init__( + self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload + ): super().__init__(epsilon, dtype, device, quant_op=op) def get_inputs(self): input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype) - rmsnorm_result = torch.empty([1, 8, 4], - device=self.device, - dtype=self.dtype) - quant_result = torch.empty([1, 8, 4], - device=self.device, - dtype=FP8_DTYPE) + rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) + quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE) weight = torch.empty([4], device=self.device, dtype=self.dtype) scale = torch.tensor(1.0, device=self.device, dtype=torch.float32) return [input, rmsnorm_result, quant_result, weight, scale] def register(self, pm_pass: PatternMatcherPass): - def pattern( input: torch.Tensor, rmsnorm_result: torch.Tensor, @@ -274,7 +280,8 @@ def pattern( ): all_reduce = self._all_reduce(input) static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, all_reduce, weight, scale) + rmsnorm_result, quant_result, all_reduce, weight, scale + ) return static_fp8[1], all_reduce def replacement( @@ -286,34 +293,36 @@ def replacement( ): reduce_scatter = self._reduce_scatter(input) - rmsnorm_result = torch.empty_like(reduce_scatter, - dtype=rmsnorm_result.dtype) + rmsnorm_result = torch.empty_like( + reduce_scatter, dtype=rmsnorm_result.dtype + ) quant_result = torch.empty_like( rmsnorm_result, # Output of RMSNorm - dtype=quant_result.dtype) + dtype=quant_result.dtype, + ) static_fp8 = self._functional_rmsnorm_then_quant( - rmsnorm_result, quant_result, 
reduce_scatter, weight, scale) + rmsnorm_result, quant_result, reduce_scatter, weight, scale + ) all_gather = self._all_gather(static_fp8[1]) return all_gather, reduce_scatter - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - op: torch._ops.OpOverload): + def __init__( + self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload + ): super().__init__(epsilon, dtype, device, quant_op=op) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) @@ -326,7 +335,6 @@ def get_inputs(self): ] def register(self, pm_pass: PatternMatcherPass): - def pattern( result: torch.Tensor, residual: torch.Tensor, @@ -335,8 +343,11 @@ def pattern( scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) - static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - result, all_reduce, residual, rms_norm_weights, scale) + static_fp8, rmsnorm_residual_out = ( + self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + result, all_reduce, residual, rms_norm_weights, scale + ) + ) return static_fp8[1], rmsnorm_residual_out def replacement( @@ -347,31 +358,31 @@ def replacement( scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, - dtype=result.dtype) - static_fp8, rmsnorm_residual_out = self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 - quant_result_buf, reduce_scatter, residual, rms_norm_weights, - scale) + quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) + static_fp8, rmsnorm_residual_out = ( + self._functional_fused_add_rmsnorm_then_quant( # noqa: E501 + quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale + ) + ) all_gather = self._all_gather(static_fp8[1]) return all_gather, rmsnorm_residual_out - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - - def __init__(self, epsilon: float, dtype: torch.dtype, device: str, - op: torch._ops.OpOverload): + def __init__( + self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload + ): super().__init__(epsilon, dtype, device, quant_op=op) def get_inputs(self): mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype) residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) - rms_norm_weights = torch.empty([4, 4], - device=self.device, - dtype=self.dtype) + rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype) result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE) scale = torch.empty([1, 1], device=self.device, dtype=torch.float32) @@ -384,7 +395,6 @@ def 
get_inputs(self): ] def register(self, pm_pass: PatternMatcherPass): - def pattern( result: torch.Tensor, residual: torch.Tensor, @@ -394,7 +404,8 @@ def pattern( ) -> tuple[torch.Tensor, torch.Tensor]: all_reduce = self._all_reduce(mm_1) static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - result, all_reduce, residual, rms_norm_weights, scale) + result, all_reduce, residual, rms_norm_weights, scale + ) return static_fp8[1] def replacement( @@ -405,16 +416,16 @@ def replacement( scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: reduce_scatter = self._reduce_scatter(mm_1) - quant_result_buf = torch.empty_like(reduce_scatter, - dtype=result.dtype) + quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype) static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant( - quant_result_buf, reduce_scatter, residual, rms_norm_weights, - scale) + quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale + ) normalized = self._all_gather(static_fp8[1]) return normalized - pm.register_replacement(pattern, replacement, self.get_inputs(), - pm.fwd_only, pm_pass) + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) class SequenceParallelismPass(VllmPatternMatcherPass): @@ -442,30 +453,34 @@ def __init__(self, config: VllmConfig): super().__init__(config) self.patterns: PatternMatcherPass = PatternMatcherPass( - pass_name="sequence_parallelism_pass") + pass_name="sequence_parallelism_pass" + ) for epsilon in [1e-5, 1e-6]: # RMSNorm + Static FP8 quantization patterns fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default FirstAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, - fp8_quant_op).register(self.patterns) + epsilon, self.model_dtype, self.device, fp8_quant_op + ).register(self.patterns) MiddleAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, - fp8_quant_op).register(self.patterns) + epsilon, self.model_dtype, self.device, fp8_quant_op + ).register(self.patterns) LastAllReduceRMSNormStaticFP8Pattern( - epsilon, self.model_dtype, self.device, - fp8_quant_op).register(self.patterns) + epsilon, self.model_dtype, self.device, fp8_quant_op + ).register(self.patterns) # Normal RMSNorm patterns - FirstAllReduceRMSNormPattern(epsilon, self.model_dtype, - self.device).register(self.patterns) + FirstAllReduceRMSNormPattern( + epsilon, self.model_dtype, self.device + ).register(self.patterns) - MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype, - self.device).register(self.patterns) + MiddleAllReduceRMSNormPattern( + epsilon, self.model_dtype, self.device + ).register(self.patterns) - LastAllReduceRMSNormPattern(epsilon, self.model_dtype, - self.device).register(self.patterns) + LastAllReduceRMSNormPattern( + epsilon, self.model_dtype, self.device + ).register(self.patterns) self.dump_patterns(config, self.patterns) def is_applicable_for_shape(self, shape: Optional[int]) -> bool: diff --git a/vllm/compilation/torch25_custom_graph_pass.py b/vllm/compilation/torch25_custom_graph_pass.py index cd3970657522..ea8b56cf9d6a 100644 --- a/vllm/compilation/torch25_custom_graph_pass.py +++ b/vllm/compilation/torch25_custom_graph_pass.py @@ -37,6 +37,8 @@ def __getstate__(self): return self.uuid() def __setstate__(self, state): - raise ValueError("Cannot unpickle CustomGraphPass because pickling" - " is used for cache key uuid. 
Use torch>=2.6 with" - " native uuid support for custom passes.") + raise ValueError( + "Cannot unpickle CustomGraphPass because pickling" + " is used for cache key uuid. Use torch>=2.6 with" + " native uuid support for custom passes." + ) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 59019d74cb80..5aa08220bc2d 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -8,8 +8,7 @@ import regex as re import torch from torch._dynamo.utils import lazy_format_graph_code -from torch._inductor.pattern_matcher import (PatternMatcherPass, - PatternPrettyPrinter) +from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter from vllm.config import VllmConfig from vllm.logger import init_logger @@ -24,20 +23,18 @@ class VllmInductorPass(InductorPass): An inductor pass with access to vLLM PassConfig. It provides timing, logging, and dumping utilities. """ + dump_prefix: ClassVar[Optional[int]] = None """Keep track of pass index for debug dump ordering.""" def __init__(self, config: VllmConfig): self.pass_config = config.compilation_config.pass_config - self.model_dtype = config.model_config.dtype if config.model_config \ - else None - self.device = config.device_config.device if config.device_config \ - else None + self.model_dtype = config.model_config.dtype if config.model_config else None + self.device = config.device_config.device if config.device_config else None self.pass_name = self.__class__.__name__ @staticmethod def time_and_log(call_fn): - @functools.wraps(call_fn) def wrapped(self: VllmInductorPass, graph: torch.fx.Graph): self.begin() @@ -51,8 +48,9 @@ def wrapped(self: VllmInductorPass, graph: torch.fx.Graph): def dump_graph(self, graph: torch.fx.Graph, stage: str): i = VllmInductorPass.dump_prefix i_str = "" if i is None else f".{i}" - lazy_format_graph_code(f"post_grad{i_str}.{self.pass_name}.{stage}", - graph.owning_module) + lazy_format_graph_code( + f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module + ) def begin(self): self._start_time = time.perf_counter_ns() @@ -71,11 +69,13 @@ class VllmPatternMatcherPass(VllmInductorPass): TODO(luka) move more utilities to this pass. """ + matched_count: int = 0 """The number of matched patterns in the pass.""" _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile( - r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>") + r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>" + ) def _replace_op_overloads(self, string: str) -> str: """Replace <OpOverload(..., ...)> with nicer formulations""" @@ -102,19 +102,22 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): debug_dump_path.mkdir(parents=True, exist_ok=True) from vllm.utils import unique_filepath + file_path = unique_filepath( - lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py") + lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py" + ) with file_path.open("w") as f: print( - f'# This file was produced by VllmPatternMatcherPass.' - f'dump_patterns for {self.pass_name}.\n' - f'# It does its best to produce valid-Python-looking code but' - f' please add to dump_patterns if there are any errors.\n\n' - f'from torch._higher_order_ops.auto_functionalize import ' - f'auto_functionalized as auto_functionalized\n' - f'from torch._inductor.pattern_matcher import *', - file=f) + f"# This file was produced by VllmPatternMatcherPass." 
+ f"dump_patterns for {self.pass_name}.\n" + f"# It does its best to produce valid-Python-looking code but" + f" please add to dump_patterns if there are any errors.\n\n" + f"from torch._higher_order_ops.auto_functionalize import " + f"auto_functionalized as auto_functionalized\n" + f"from torch._inductor.pattern_matcher import *", + file=f, + ) for node, patterns in pm_pass.patterns.items(): # fix the operator.getitem repr @@ -133,18 +136,21 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): # Assemble pattern out_node = pp.pretty_print(pattern.pattern) - pattern_repr = "\n".join([f"def pattern_{i}():"] + [ - f"{pp.memoized_objs_names[key]} = " - f"{pp.memoized_objs_pp[key]}" - for key in pp.memoized_objs_names - ] + [f"return {out_node}"]).replace("\n", "\n ") + pattern_repr = "\n".join( + [f"def pattern_{i}():"] + + [ + f"{pp.memoized_objs_names[key]} = " + f"{pp.memoized_objs_pp[key]}" + for key in pp.memoized_objs_names + ] + + [f"return {out_node}"] + ).replace("\n", "\n ") pattern_repr = self._replace_op_overloads(pattern_repr) print(f"{pattern_repr}\n", file=f) class PrinterInductorPass(VllmInductorPass): - def __init__(self, name: str, config: VllmConfig): super().__init__(config) self.name = name diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 062c9dc27017..71a4e1745d4e 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -10,8 +10,7 @@ import torch -from vllm.config import (CompilationLevel, CUDAGraphMode, - get_current_vllm_config) +from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config from vllm.logger import init_logger logger = init_logger(__name__) @@ -30,10 +29,9 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, - compiled_callable: Optional[Callable] = None, - compilation_level: int = 0): - + def __init__( + self, compiled_callable: Optional[Callable] = None, compilation_level: int = 0 + ): vllm_config = get_current_vllm_config() self.vllm_config = vllm_config if compiled_callable is None: @@ -43,13 +41,13 @@ def __init__(self, backend = vllm_config.compilation_config.init_backend(vllm_config) options = None if isinstance(backend, str) and backend == "inductor": - options = get_current_vllm_config( - ).compilation_config.inductor_compile_config + options = ( + get_current_vllm_config().compilation_config.inductor_compile_config + ) - compiled_callable = torch.compile(self.forward, - fullgraph=True, - backend=backend, - options=options) + compiled_callable = torch.compile( + self.forward, fullgraph=True, backend=backend, options=options + ) self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ @@ -59,8 +57,9 @@ def __init__(self, # read the env var to determine whether to use the custom dispatcher # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. - self.use_custom_dispatcher: bool = \ + self.use_custom_dispatcher: bool = ( compilation_level >= CompilationLevel.DYNAMO_ONCE + ) def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. @@ -70,8 +69,7 @@ def __call__(self, *args, **kwargs): return self.compiled_callable(*args, **kwargs) @abstractmethod - def forward(self, *args, **kwargs): - ... + def forward(self, *args, **kwargs): ... 
def bytecode_hook(self, old_code: CodeType, new_code: CodeType): """Hook to save the compiled bytecode for direct execution.""" @@ -103,21 +101,30 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): # but there's no 100% guarantee, since decompliation is # not a reversible process. import depyf + src = depyf.decompile(new_code) with open(decompiled_file, "w") as f: f.write(src) - logger.debug("Dynamo transformed code saved to %s", - decompiled_file) + logger.debug("Dynamo transformed code saved to %s", decompiled_file) except Exception: pass - if self.vllm_config.compilation_config.cudagraph_mode != \ - CUDAGraphMode.NONE and "update" in new_code.co_names: + if ( + self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and "update" in new_code.co_names + ): import depyf + src = depyf.decompile(new_code) - msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa + msg = ( + "Assigning / modifying buffers of nn.Module during forward pass is not " + "allowed when using cudagraph inside the compiler because it will " + "cause silent errors. Please use eager mode or fix the code. The " + "following code contains clues about which buffer is being modified " + f"(please search for the usage of the function `update`):\n{src}" + ) raise RuntimeError(msg) @contextmanager @@ -128,8 +135,9 @@ def dispatch_to_code(self, index: int): variables as the original code. Therefore we can directly switch the code object in the function and call it. - See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. - """ # noqa + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 + for more details. 
+ """ self.__class__.forward.__code__ = self.compiled_codes[index] yield self.__class__.forward.__code__ = self.original_code_object diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index c909265c071d..7c5052c822f8 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1,36 +1,60 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, - PrefixCachingHashAlgo) -from vllm.config.compilation import (CompilationConfig, CompilationLevel, - CUDAGraphMode, PassConfig) +from vllm.config.cache import ( + BlockSize, + CacheConfig, + CacheDType, + MambaDType, + PrefixCachingHashAlgo, +) +from vllm.config.compilation import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + PassConfig, +) from vllm.config.device import Device, DeviceConfig from vllm.config.kv_events import KVEventsConfig from vllm.config.kv_transfer import KVTransferConfig from vllm.config.load import LoadConfig from vllm.config.lora import LoRAConfig -from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode, - ModelConfig, ModelDType, ModelImpl, - RunnerOption, TaskOption, TokenizerMode, - iter_architecture_defaults, - try_match_architecture_defaults) -from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode, - MultiModalConfig) +from vllm.config.model import ( + ConvertOption, + HfOverrides, + LogprobsMode, + ModelConfig, + ModelDType, + ModelImpl, + RunnerOption, + TaskOption, + TokenizerMode, + iter_architecture_defaults, + try_match_architecture_defaults, +) +from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.observability import DetailedTraceModules, ObservabilityConfig -from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, - ParallelConfig) +from vllm.config.parallel import DistributedExecutorBackend, EPLBConfig, ParallelConfig from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy from vllm.config.speculative import SpeculativeConfig from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.structured_outputs import StructuredOutputsConfig -from vllm.config.utils import (ConfigType, SupportsMetricsInfo, config, - get_attr_docs, is_init_field, update_config) -from vllm.config.vllm import (VllmConfig, get_cached_compilation_config, - get_current_vllm_config, - get_layers_from_vllm_config, - set_current_vllm_config) +from vllm.config.utils import ( + ConfigType, + SupportsMetricsInfo, + config, + get_attr_docs, + is_init_field, + update_config, +) +from vllm.config.vllm import ( + VllmConfig, + get_cached_compilation_config, + get_current_vllm_config, + get_layers_from_vllm_config, + set_current_vllm_config, +) __all__ = [ # From vllm.config.cache diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 58770649a8af..3d2496b7f21d 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -22,8 +22,7 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] -CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", - "fp8_inc"] +CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @@ -92,7 +91,8 @@ class CacheConfig: mamba_page_size_padded: Optional[int] = None """ Optional override for mamba page size; used by hybrid 
mamba/attention models to ensure exact alignment with attention page size.""" - + mamba_block_size: Optional[int] = None + """Size of a contiguous cache block in number of tokens for mamba cache.""" mamba_cache_dtype: MambaDType = "auto" """The data type to use for the Mamba cache (both the conv as well as the ssm state). If set to 'auto', the data type will be inferred from the model @@ -144,8 +144,7 @@ def compute_hash(self) -> str: factors.append(self.mamba_cache_dtype) factors.append(self.mamba_ssm_cache_dtype) # `cpu_offload_gb` does not use `torch.compile` yet. - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self) -> None: @@ -159,16 +158,18 @@ def metrics_info(self): # metrics info return {key: str(value) for key, value in self.__dict__.items()} - @model_validator(mode='after') + @model_validator(mode="after") def _verify_args(self) -> Self: if self.cpu_offload_gb < 0: - raise ValueError("CPU offload space must be non-negative" - f", but got {self.cpu_offload_gb}") + raise ValueError( + f"CPU offload space must be non-negative, but got {self.cpu_offload_gb}" + ) if self.gpu_memory_utilization > 1.0: raise ValueError( "GPU memory utilization must be less than 1.0. Got " - f"{self.gpu_memory_utilization}.") + f"{self.gpu_memory_utilization}." + ) return self @@ -181,7 +182,8 @@ def _verify_cache_dtype(self) -> None: "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor.") + "scaling factor." + ) else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") @@ -192,14 +194,17 @@ def _verify_prefix_caching(self) -> None: if self.sliding_window is not None and not envs.VLLM_USE_V1: raise NotImplementedError( "Prefix caching is not supported with sliding window. " - "Run with --disable-sliding-window to use prefix caching.") + "Run with --disable-sliding-window to use prefix caching." + ) - if (self.enable_prefix_caching and self.prefix_caching_hash_algo - not in get_args(PrefixCachingHashAlgo)): + if self.enable_prefix_caching and self.prefix_caching_hash_algo not in get_args( + PrefixCachingHashAlgo + ): raise ValueError( "Unknown prefix caching hash algorithm: " f"{self.prefix_caching_hash_algo}. Must be one of " - f"{get_args(PrefixCachingHashAlgo)}.") + f"{get_args(PrefixCachingHashAlgo)}." + ) def verify_with_parallel_config( self, @@ -211,9 +216,11 @@ def verify_with_parallel_config( num_gpus_per_node = parallel_config.tensor_parallel_size cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node - msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " - f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " - "is allocated for the swap space.") + msg = ( + f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space." + ) if cpu_memory_usage > 0.7 * total_cpu_memory: raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ce173edb4b94..3443d2e1559e 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -33,31 +33,31 @@ class CompilationLevel: class CUDAGraphMode(enum.Enum): - """ Constants for the cudagraph mode in CompilationConfig. 
+ """Constants for the cudagraph mode in CompilationConfig. Meanwhile, the subset enum `NONE`, `PIECEWISE` and `FULL` are also treated as concrete runtime mode for cudagraph runtime dispatching. """ + NONE = 0 PIECEWISE = 1 FULL = 2 FULL_DECODE_ONLY = (FULL, NONE) FULL_AND_PIECEWISE = (FULL, PIECEWISE) - def decode_mode(self) -> 'CUDAGraphMode': - return CUDAGraphMode(self.value[0]) if \ - self.separate_routine() else self + def decode_mode(self) -> "CUDAGraphMode": + return CUDAGraphMode(self.value[0]) if self.separate_routine() else self - def mixed_mode(self) -> 'CUDAGraphMode': - return CUDAGraphMode(self.value[1]) if \ - self.separate_routine() else self + def mixed_mode(self) -> "CUDAGraphMode": + return CUDAGraphMode(self.value[1]) if self.separate_routine() else self def requires_piecewise_compilation(self) -> bool: - return (self.decode_mode() == CUDAGraphMode.PIECEWISE - or self.mixed_mode() == CUDAGraphMode.PIECEWISE) + return ( + self.decode_mode() == CUDAGraphMode.PIECEWISE + or self.mixed_mode() == CUDAGraphMode.PIECEWISE + ) - def max_cudagraph_mode(self) -> 'CUDAGraphMode': - return CUDAGraphMode(max( - self.value)) if self.separate_routine() else self + def max_cudagraph_mode(self) -> "CUDAGraphMode": + return CUDAGraphMode(max(self.value)) if self.separate_routine() else self def has_full_cudagraphs(self) -> bool: return self.max_cudagraph_mode() == CUDAGraphMode.FULL @@ -69,9 +69,7 @@ def separate_routine(self) -> bool: return isinstance(self.value, tuple) def valid_runtime_modes(self) -> bool: - return self in [ - CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL - ] + return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL] def __str__(self) -> str: return self.name @@ -116,11 +114,13 @@ def __post_init__(self) -> None: if self.enable_fusion: logger.warning_once( "Fusion enabled but reshape elimination disabled. " - "RMSNorm/SiluMul + quant (fp8) fusion might not work") + "RMSNorm/SiluMul + quant (fp8) fusion might not work" + ) if self.enable_attn_fusion: logger.warning_once( "Fusion enabled but reshape elimination disabled. " - "Attention + quant (fp8) fusion might not work") + "Attention + quant (fp8) fusion might not work" + ) @config @@ -163,6 +163,7 @@ class CompilationConfig: sufficient for most cases. It might be beneficial to compile for certain small batchsizes, where inductor is good at optimizing. """ + # Top-level Compilation control level: Optional[int] = None """The level of compilation: @@ -340,26 +341,24 @@ class CompilationConfig: """local cache dir for each rank""" bs_to_padded_graph_size: list[int] = field( default=None, # type: ignore - init=False) + init=False, + ) """optimization: Intuitively, bs_to_padded_graph_size should be dict[int, int]. 
since we know all keys are in a range [0, max_capture_size], we can optimize it to list[int] for better lookup performance.""" # keep track of enabled and disabled custom ops - enabled_custom_ops: Counter[str] = field(default_factory=Counter, - init=False) + enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False) """custom ops that are enabled""" - disabled_custom_ops: Counter[str] = field(default_factory=Counter, - init=False) + disabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False) """custom ops that are disabled""" traced_files: set[str] = field(default_factory=set, init=False) """files that are traced for compilation""" compilation_time: float = field(default=0.0, init=False) """time taken for compilation""" - static_forward_context: dict[str, Any] = field(default_factory=dict, - init=False) + static_forward_context: dict[str, Any] = field(default_factory=dict, init=False) """Per-model forward context Map from layer name to layer objects that need to be accessed outside model code, e.g., Attention, FusedMOE when dp_size>1.""" @@ -421,9 +420,9 @@ def __repr__(self) -> str: if pass_config_exclude: exclude["pass_config"] = pass_config_exclude - config = TypeAdapter(CompilationConfig).dump_python(self, - exclude=exclude, - exclude_unset=True) + config = TypeAdapter(CompilationConfig).dump_python( + self, exclude=exclude, exclude_unset=True + ) return str(config) @@ -453,16 +452,16 @@ def __post_init__(self) -> None: # https://github.com/vllm-project/vllm/issues/14703 if is_torch_equal_or_newer("2.6"): - KEY = 'enable_auto_functionalized_v2' + KEY = "enable_auto_functionalized_v2" if KEY not in self.inductor_compile_config: self.inductor_compile_config[KEY] = False for k, v in self.inductor_passes.items(): if not isinstance(v, str): - assert callable(v), ( - f"pass {k} should be callable or a qualified name") - self.inductor_compile_config[k] = v if isinstance( - v, InductorPass) else CallableInductorPass(v) + assert callable(v), f"pass {k} should be callable or a qualified name" + self.inductor_compile_config[k] = ( + v if isinstance(v, InductorPass) else CallableInductorPass(v) + ) continue # resolve function from qualified name @@ -470,54 +469,68 @@ def __post_init__(self) -> None: module = ".".join(names[:-1]) func_name = names[-1] func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func if isinstance( - func, InductorPass) else CallableInductorPass(func) + self.inductor_compile_config[k] = ( + func if isinstance(func, InductorPass) else CallableInductorPass(func) + ) if isinstance(self.pass_config, dict): self.pass_config = PassConfig(**self.pass_config) # migrate the deprecated flags if not self.use_cudagraph: - logger.warning("use_cudagraph is deprecated, use " - "cudagraph_mode=NONE instead.") - if self.cudagraph_mode is not None and \ - self.cudagraph_mode != CUDAGraphMode.NONE: + logger.warning( + "use_cudagraph is deprecated, use cudagraph_mode=NONE instead." + ) + if ( + self.cudagraph_mode is not None + and self.cudagraph_mode != CUDAGraphMode.NONE + ): raise ValueError( "use_cudagraph and cudagraph_mode are mutually" " exclusive, prefer cudagraph_mode since " - "use_cudagraph is deprecated.") + "use_cudagraph is deprecated." 
+ ) self.cudagraph_mode = CUDAGraphMode.NONE if self.full_cuda_graph: - logger.warning("full_cuda_graph is deprecated, use " - "cudagraph_mode=FULL instead.") - if self.cudagraph_mode is not None and \ - not self.cudagraph_mode.has_full_cudagraphs(): - raise ValueError("full_cuda_graph and cudagraph_mode are " - "mutually exclusive, prefer cudagraph_mode " - "since full_cuda_graph is deprecated.") + logger.warning( + "full_cuda_graph is deprecated, use cudagraph_mode=FULL instead." + ) + if ( + self.cudagraph_mode is not None + and not self.cudagraph_mode.has_full_cudagraphs() + ): + raise ValueError( + "full_cuda_graph and cudagraph_mode are " + "mutually exclusive, prefer cudagraph_mode " + "since full_cuda_graph is deprecated." + ) self.cudagraph_mode = CUDAGraphMode.FULL - if (self.use_inductor_graph_partition - and not is_torch_equal_or_newer("2.9.0.dev")): - raise ValueError("use_inductor_graph_partition is only " - "supported with torch>=2.9.0.dev. Set " - "use_inductor_graph_partition=False instead.") + if self.use_inductor_graph_partition and not is_torch_equal_or_newer( + "2.9.0.dev" + ): + raise ValueError( + "use_inductor_graph_partition is only " + "supported with torch>=2.9.0.dev. Set " + "use_inductor_graph_partition=False instead." + ) for op in self.custom_ops: - if op[0] not in {'+', '-'} and op not in {'all', 'none'}: - raise ValueError(f"Invalid syntax '{op}' for custom op, " - "must be 'all', 'none', '+op' or '-op' " - "(where 'op' is the registered op name)") + if op[0] not in {"+", "-"} and op not in {"all", "none"}: + raise ValueError( + f"Invalid syntax '{op}' for custom op, " + "must be 'all', 'none', '+op' or '-op' " + "(where 'op' is the registered op name)" + ) def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") from torch._dynamo.backends.registry import list_backends + torch_backends = list_backends(exclude_tags=tuple()) - if self.level in [ - CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE - ]: + if self.level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: if self.backend == "": return "eager" if self.backend in torch_backends: @@ -529,10 +542,10 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: assert self.level == CompilationLevel.PIECEWISE from vllm.compilation.backends import VllmBackend + return VllmBackend(vllm_config) - def init_with_cudagraph_sizes(self, - cudagraph_capture_sizes: list[int]) -> None: + def init_with_cudagraph_sizes(self, cudagraph_capture_sizes: list[int]) -> None: """To complete the initialization of config, we need to know the cudagraph sizes.""" @@ -542,9 +555,14 @@ def init_with_cudagraph_sizes(self, # de-duplicate the sizes provided by the config dedup_sizes = list(set(self.cudagraph_capture_sizes)) if len(dedup_sizes) < len(self.cudagraph_capture_sizes): - logger.info(("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - cudagraph_capture_sizes, dedup_sizes) + logger.info( + ( + "cudagraph sizes specified by model runner" + " %s is overridden by config %s" + ), + cudagraph_capture_sizes, + dedup_sizes, + ) self.cudagraph_capture_sizes = dedup_sizes computed_compile_sizes = [] @@ -553,9 +571,10 @@ def init_with_cudagraph_sizes(self, self.compile_sizes = list(set(self.compile_sizes)) for x in self.compile_sizes: if isinstance(x, str): - assert x == "cudagraph_capture_sizes", \ - "Unrecognized size type in compile_sizes, " 
\ + assert x == "cudagraph_capture_sizes", ( + "Unrecognized size type in compile_sizes, " f"expect 'cudagraph_capture_sizes', got {x}" + ) computed_compile_sizes.extend(self.cudagraph_capture_sizes) else: assert isinstance(x, int) @@ -564,29 +583,29 @@ def init_with_cudagraph_sizes(self, # sort to make sure cudagraph capture sizes are in descending order self.cudagraph_capture_sizes.sort(reverse=True) - self.max_capture_size = self.cudagraph_capture_sizes[ - 0] if self.cudagraph_capture_sizes else 0 + self.max_capture_size = ( + self.cudagraph_capture_sizes[0] if self.cudagraph_capture_sizes else 0 + ) # pre-compute the mapping from batch size to padded graph size - self.bs_to_padded_graph_size = [ - 0 for i in range(self.max_capture_size + 1) - ] - for end, start in zip(self.cudagraph_capture_sizes, - self.cudagraph_capture_sizes[1:] + [0]): + self.bs_to_padded_graph_size = [0 for i in range(self.max_capture_size + 1)] + for end, start in zip( + self.cudagraph_capture_sizes, self.cudagraph_capture_sizes[1:] + [0] + ): for bs in range(start, end): if bs == start: self.bs_to_padded_graph_size[bs] = start else: self.bs_to_padded_graph_size[bs] = end - self.bs_to_padded_graph_size[ - self.max_capture_size] = self.max_capture_size + self.bs_to_padded_graph_size[self.max_capture_size] = self.max_capture_size def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called only when level is # CompilationLevel.PIECEWISE assert self.level == CompilationLevel.PIECEWISE, ( "set_splitting_ops_for_v1 should only be called when " - "level is CompilationLevel.PIECEWISE") + "level is CompilationLevel.PIECEWISE" + ) if self.use_inductor_graph_partition: self.set_splitting_ops_for_inductor_graph_partition() @@ -608,22 +627,23 @@ def set_splitting_ops_for_v1(self): # list via reference. self.splitting_ops = list(self._attention_ops) elif len(self.splitting_ops) == 0: - logger.warning_once( - "Using piecewise compilation with empty splitting_ops") + logger.warning_once("Using piecewise compilation with empty splitting_ops") if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: logger.warning_once( - "Piecewise compilation with empty splitting_ops do not" \ + "Piecewise compilation with empty splitting_ops do not" "contains piecewise cudagraph. Setting cudagraph_" "mode to NONE. Hint: If you are using attention backends " "that support cudagraph, consider manually setting " "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable " - "full cudagraphs.") + "full cudagraphs." + ) self.cudagraph_mode = CUDAGraphMode.NONE elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: logger.warning_once( "Piecewise compilation with empty splitting_ops do not " "contains piecewise cudagraph. Setting cudagraph_mode " - "to FULL.") + "to FULL." + ) self.cudagraph_mode = CUDAGraphMode.FULL self.splitting_ops = [] @@ -632,10 +652,10 @@ def set_splitting_ops_for_inductor_graph_partition(self): use_inductor_graph_partition_msg = ( "When use_inductor_graph_partition=True, splitting_ops " "are ignored and set to an empty list. Instead, " - "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is " - "used to annotate custom ops for graph partition.") - if self.splitting_ops is not None and \ - len(self.splitting_ops) > 0: + '"tags=(torch._C.Tag.cudagraph_unsafe, )," is ' + "used to annotate custom ops for graph partition." 
+ ) + if self.splitting_ops is not None and len(self.splitting_ops) > 0: logger.warning_once(use_inductor_graph_partition_msg) self.splitting_ops = [] @@ -651,32 +671,38 @@ def set_splitting_ops_for_attn_fusion(self): "list, and cudagraph_mode will be set to FULL. " "Please ensure you are using attention backends that " "support cudagraph or set cudagraph_mode to NONE " - "explicitly if encountering any problems.") + "explicitly if encountering any problems." + ) self.cudagraph_mode = CUDAGraphMode.FULL assert not self.splitting_ops_contain_attention(), ( "attention ops should not be in splitting_ops " - "when enable_attn_fusion is True") + "when enable_attn_fusion is True" + ) def splitting_ops_contain_attention(self) -> bool: return self.splitting_ops is not None and all( - op in self.splitting_ops for op in self._attention_ops) + op in self.splitting_ops for op in self._attention_ops + ) def is_attention_compiled_piecewise(self) -> bool: use_fx_graph_piecewise_compilation = ( self.level == CompilationLevel.PIECEWISE - and self.splitting_ops_contain_attention()) - - inductor_used = (self.level == CompilationLevel.PIECEWISE - and self.use_inductor) or ( - self.level >= CompilationLevel.DYNAMO_AS_IS - and self.backend == "inductor") + and self.splitting_ops_contain_attention() + ) + + inductor_used = ( + self.level == CompilationLevel.PIECEWISE and self.use_inductor + ) or ( + self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor" + ) use_inductor_piecewise_compilation = ( - inductor_used and self.use_inductor_graph_partition - and not self.splitting_ops_contain_attention()) + inductor_used + and self.use_inductor_graph_partition + and not self.splitting_ops_contain_attention() + ) - return use_fx_graph_piecewise_compilation or \ - use_inductor_piecewise_compilation + return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation def custom_op_log_check(self): """ @@ -693,13 +719,14 @@ def custom_op_log_check(self): logger.debug("enabled custom ops: %s", self.enabled_custom_ops) logger.debug("disabled custom ops: %s", self.disabled_custom_ops) - all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops) + all_ops_in_model = self.enabled_custom_ops | self.disabled_custom_ops for op in self.custom_ops: if op in {"all", "none"}: continue - assert op[0] in {'+', '-'}, "Invalid custom op syntax " \ - "(should be checked during init)" + assert op[0] in {"+", "-"}, ( + "Invalid custom op syntax (should be checked during init)" + ) # check if op name exists in model op_name = op[1:] @@ -708,10 +735,17 @@ def custom_op_log_check(self): # Does op exist at all or is it just not present in this model? # Note: Only imported op classes appear in the registry. 
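The lookup table built in init_with_cudagraph_sizes earlier in this hunk is easiest to follow with concrete numbers. A minimal sketch, assuming a hypothetical set of capture sizes (the real ones come from CompilationConfig):

# Every batch size in [0, max_capture_size] is padded up to the nearest
# captured cudagraph size; captured sizes map to themselves.
capture_sizes = sorted([1, 2, 4, 8], reverse=True)  # hypothetical capture sizes
max_capture_size = capture_sizes[0]

bs_to_padded_graph_size = [0] * (max_capture_size + 1)
for end, start in zip(capture_sizes, capture_sizes[1:] + [0]):
    for bs in range(start, end):
        bs_to_padded_graph_size[bs] = start if bs == start else end
bs_to_padded_graph_size[max_capture_size] = max_capture_size

assert bs_to_padded_graph_size == [0, 1, 2, 4, 4, 8, 8, 8, 8]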
- missing_str = "doesn't exist (or wasn't imported/registered)" \ - if op_name not in CustomOp.op_registry \ + missing_str = ( + "doesn't exist (or wasn't imported/registered)" + if op_name not in CustomOp.op_registry else "not present in model" + ) - enable_str = "enabling" if op[0] == '+' else "disabling" - logger.warning_once("Op '%s' %s, %s with '%s' has no effect", - op_name, missing_str, enable_str, op) + enable_str = "enabling" if op[0] == "+" else "disabling" + logger.warning_once( + "Op '%s' %s, %s with '%s' has no effect", + op_name, + missing_str, + enable_str, + op, + ) diff --git a/vllm/config/device.py b/vllm/config/device.py index 4654ac96e0b7..4b6642479541 100644 --- a/vllm/config/device.py +++ b/vllm/config/device.py @@ -45,20 +45,21 @@ def compute_hash(self) -> str: # the device/platform information will be summarized # by torch/vllm automatically. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self): if self.device == "auto": # Automated device type detection from vllm.platforms import current_platform + self.device_type = current_platform.device_type if not self.device_type: raise RuntimeError( "Failed to infer device type, please set " "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` " - "to turn on verbose logging to help debug the issue.") + "to turn on verbose logging to help debug the issue." + ) else: # Device type is assigned explicitly if isinstance(self.device, str): diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py index c3d9a3309eb3..b33294fd66f7 100644 --- a/vllm/config/kv_transfer.py +++ b/vllm/config/kv_transfer.py @@ -76,8 +76,7 @@ def compute_hash(self) -> str: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self) -> None: @@ -85,27 +84,28 @@ def __post_init__(self) -> None: self.engine_id = str(uuid.uuid4()) if self.kv_role is not None and self.kv_role not in get_args(KVRole): - raise ValueError(f"Unsupported kv_role: {self.kv_role}. " - f"Supported roles are {get_args(KVRole)}") + raise ValueError( + f"Unsupported kv_role: {self.kv_role}. 
" + f"Supported roles are {get_args(KVRole)}" + ) if self.kv_connector is not None and self.kv_role is None: - raise ValueError("Please specify kv_disagg_role when kv_connector " - f"is set, supported roles are {get_args(KVRole)}") + raise ValueError( + "Please specify kv_disagg_role when kv_connector " + f"is set, supported roles are {get_args(KVRole)}" + ) @property def is_kv_transfer_instance(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVRole) + return self.kv_connector is not None and self.kv_role in get_args(KVRole) @property def is_kv_producer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVProducer) + return self.kv_connector is not None and self.kv_role in get_args(KVProducer) @property def is_kv_consumer(self) -> bool: - return self.kv_connector is not None and \ - self.kv_role in get_args(KVConsumer) + return self.kv_connector is not None and self.kv_role in get_args(KVConsumer) def get_from_extra_config(self, key, default) -> Any: return self.kv_connector_extra_config.get(key, default) diff --git a/vllm/config/load.py b/vllm/config/load.py index 26ffec23ad5c..6aacff60157b 100644 --- a/vllm/config/load.py +++ b/vllm/config/load.py @@ -61,7 +61,8 @@ class LoadConfig: initialization. However, it uses more CPU RAM. """ model_loader_extra_config: Union[dict, TensorizerConfig] = field( - default_factory=dict) + default_factory=dict + ) """Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format.""" device: Optional[str] = None @@ -99,8 +100,7 @@ def compute_hash(self) -> str: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self): @@ -108,6 +108,7 @@ def __post_init__(self): if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: logger.info( "Ignoring the following patterns when downloading weights: %s", - self.ignore_patterns) + self.ignore_patterns, + ) else: self.ignore_patterns = ["original/**/*"] diff --git a/vllm/config/lora.py b/vllm/config/lora.py index 3fe28f5dad4f..f97f2a111d41 100644 --- a/vllm/config/lora.py +++ b/vllm/config/lora.py @@ -47,8 +47,9 @@ class LoRAConfig: lora_extra_vocab_size: int = 256 """(Deprecated) Maximum size of extra vocabulary that can be present in a LoRA adapter. Will be removed in v0.12.0.""" - lora_vocab_padding_size: ClassVar[int] = current_platform\ - .get_lora_vocab_padding_size() + lora_vocab_padding_size: ClassVar[int] = ( + current_platform.get_lora_vocab_padding_size() + ) default_mm_loras: Optional[dict[str, str]] = None """Dictionary mapping specific modalities to LoRA model paths; this field is only applicable to multimodal models and should be leveraged when a @@ -83,8 +84,7 @@ def compute_hash(self) -> str: factors.append(self.lora_extra_vocab_size) factors.append(self.lora_vocab_padding_size) factors.append(self.bias_enabled) - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self): @@ -92,12 +92,14 @@ def __post_init__(self): logger.warning( "`lora_extra_vocab_size` is deprecated and will be removed " "in v0.12.0. 
Additional vocabulary support for " - "LoRA adapters is being phased out.") + "LoRA adapters is being phased out." + ) # Deprecation warning for enable_lora_bias if self.bias_enabled: - logger.warning("`enable_lora_bias` is deprecated " - "and will be removed in v0.12.0.") + logger.warning( + "`enable_lora_bias` is deprecated and will be removed in v0.12.0." + ) # Setting the maximum rank to 512 should be able to satisfy the vast # majority of applications. @@ -106,11 +108,13 @@ def __post_init__(self): if self.max_lora_rank not in possible_max_ranks: raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " - f"{possible_max_ranks}.") + f"{possible_max_ranks}." + ) if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: raise ValueError( f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " - f"must be one of {possible_lora_extra_vocab_size}.") + f"must be one of {possible_lora_extra_vocab_size}." + ) if self.max_loras < 1: raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") if self.max_cpu_loras is None: @@ -118,12 +122,12 @@ def __post_init__(self): elif self.max_cpu_loras < self.max_loras: raise ValueError( f"max_cpu_loras ({self.max_cpu_loras}) must be >= " - f"max_loras ({self.max_loras})") + f"max_loras ({self.max_loras})" + ) def verify_with_cache_config(self, cache_config: CacheConfig): if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1: - raise ValueError( - "V0 LoRA does not support CPU offload, please use V1.") + raise ValueError("V0 LoRA does not support CPU offload, please use V1.") def verify_with_model_config(self, model_config: ModelConfig): if self.lora_dtype in (None, "auto"): diff --git a/vllm/config/model.py b/vllm/config/model.py index 0bf8a9fe1f0f..146ace9782b9 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -6,31 +6,44 @@ import warnings from dataclasses import InitVar, field from importlib.util import find_spec -from typing import (TYPE_CHECKING, Any, Callable, Literal, Optional, Union, - cast, get_args) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + Optional, + Union, + cast, + get_args, +) import torch -from pydantic import (ConfigDict, SkipValidation, field_validator, - model_validator) +from pydantic import ConfigDict, SkipValidation, field_validator, model_validator from pydantic.dataclasses import dataclass from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE import vllm.envs as envs -from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode, - MultiModalConfig) +from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType from vllm.config.utils import assert_hashable, config, getattr_iter from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.transformers_utils.config import ( - ConfigFormat, get_config, get_hf_image_processor_config, - get_hf_text_config, get_pooling_config, - get_sentence_transformer_tokenizer_config, is_encoder_decoder, - is_interleaved, try_get_generation_config, try_get_safetensors_metadata, - try_get_tokenizer_config, uses_mrope) -from vllm.transformers_utils.runai_utils import (ObjectStorageModel, - is_runai_obj_uri) + ConfigFormat, + get_config, + get_hf_image_processor_config, + get_hf_text_config, + get_pooling_config, + get_sentence_transformer_tokenizer_config, + is_encoder_decoder, + is_interleaved, + try_get_generation_config, + try_get_safetensors_metadata, + 
try_get_tokenizer_config, + uses_mrope, +) +from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils import LayerBlockType, LazyLoader, common_broadcastable_dtype @@ -46,10 +59,10 @@ else: PretrainedConfig = Any - me_quant = LazyLoader("model_executor", globals(), - "vllm.model_executor.layers.quantization") - me_models = LazyLoader("model_executor", globals(), - "vllm.model_executor.models") + me_quant = LazyLoader( + "model_executor", globals(), "vllm.model_executor.layers.quantization" + ) + me_models = LazyLoader("model_executor", globals(), "vllm.model_executor.models") LoadConfig = Any ParallelConfig = Any QuantizationMethods = Any @@ -60,14 +73,23 @@ RunnerOption = Literal["auto", RunnerType] ConvertType = Literal["none", "embed", "classify", "reward"] ConvertOption = Literal["auto", ConvertType] -TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", - "score", "reward", "transcription", "draft"] +TaskOption = Literal[ + "auto", + "generate", + "embedding", + "embed", + "classify", + "score", + "reward", + "transcription", + "draft", +] TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] -LogprobsMode = Literal["raw_logits", "raw_logprobs", "processed_logits", - "processed_logprobs"] -HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig], - PretrainedConfig]] +LogprobsMode = Literal[ + "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" +] +HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig], PretrainedConfig]] ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"] _RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = { @@ -276,9 +298,7 @@ class ModelConfig: multimodal_config: Optional[MultiModalConfig] = None """Configuration for multimodal model. If `None`, this will be inferred from the architecture of `self.model`.""" - limit_mm_per_prompt: InitVar[Optional[dict[str, Union[int, - dict[str, - int]]]]] = None + limit_mm_per_prompt: InitVar[Optional[dict[str, Union[int, dict[str, int]]]]] = None media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None mm_processor_cache_gb: InitVar[Optional[float]] = None @@ -327,15 +347,19 @@ def compute_hash(self) -> str: from vllm.utils.jsontree import json_map_leaves # Handle nested HF configs with unserializable values gracefully - hf_config_json = json.dumps( - json_map_leaves( - lambda v: v.to_dict() - if isinstance(v, PretrainedConfig) else str(v), - self.hf_config.to_dict(), - ), - indent=2, - sort_keys=True, - ) + "\n" + hf_config_json = ( + json.dumps( + json_map_leaves( + lambda v: v.to_dict() + if isinstance(v, PretrainedConfig) + else str(v), + self.hf_config.to_dict(), + ), + indent=2, + sort_keys=True, + ) + + "\n" + ) factors.append(hf_config_json) @@ -373,11 +397,14 @@ def __post_init__( "The global random seed is set to %d. 
Since " "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may " "affect the random state of the Python process that " - "launched vLLM.", self.seed) + "launched vLLM.", + self.seed, + ) # Keep set served_model_name before maybe_model_redirect(self.model) - self.served_model_name = get_served_model_name(self.model, - self.served_model_name) + self.served_model_name = get_served_model_name( + self.model, self.served_model_name + ) self.model = maybe_model_redirect(self.model) # The tokenizer is consistent with the model by default. if self.tokenizer is None: @@ -402,7 +429,8 @@ def __post_init__( hf_overrides_str = json.dumps(hf_overrides_kw) msg = ( "`--rope-scaling` will be removed in a future release. " - f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`" + ) warnings.warn(DeprecationWarning(msg), stacklevel=2) if self.rope_theta is not None: hf_override = {"rope_theta": self.rope_theta} @@ -410,52 +438,58 @@ def __post_init__( hf_overrides_str = json.dumps(hf_overrides_kw) msg = ( "`--rope-theta` will be removed in a future release. " - f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") + f"'Please instead use `--hf-overrides '{hf_overrides_str}'`" + ) warnings.warn(DeprecationWarning(msg), stacklevel=2) self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) - if (backend := envs.VLLM_ATTENTION_BACKEND - ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: + if ( + (backend := envs.VLLM_ATTENTION_BACKEND) + and backend == "FLASHINFER" + and find_spec("flashinfer") is None + ): raise ValueError( "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " "module was not found. See " "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 - "for instructions on how to install it.") + "for instructions on how to install it." 
+ ) from vllm.platforms import current_platform - if (self.override_attention_dtype is not None - and not current_platform.is_rocm()): + if self.override_attention_dtype is not None and not current_platform.is_rocm(): warnings.warn( "override-attention-dtype is set but not using ROCm platform", - stacklevel=2) - - if (self.enable_sleep_mode - and not current_platform.is_sleep_mode_available()): - raise ValueError( - "Sleep mode is not supported on current platform.") + stacklevel=2, + ) - hf_config = get_config(self.hf_config_path or self.model, - self.trust_remote_code, - self.revision, - self.code_revision, - self.config_format, - hf_overrides_kw=hf_overrides_kw, - hf_overrides_fn=hf_overrides_fn) + if self.enable_sleep_mode and not current_platform.is_sleep_mode_available(): + raise ValueError("Sleep mode is not supported on current platform.") + + hf_config = get_config( + self.hf_config_path or self.model, + self.trust_remote_code, + self.revision, + self.code_revision, + self.config_format, + hf_overrides_kw=hf_overrides_kw, + hf_overrides_fn=hf_overrides_fn, + ) self.hf_config = hf_config self.hf_text_config = get_hf_text_config(self.hf_config) - self.attention_chunk_size = getattr(self.hf_text_config, - "attention_chunk_size", None) + self.attention_chunk_size = getattr( + self.hf_text_config, "attention_chunk_size", None + ) self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( - self.model, hf_token=self.hf_token, revision=self.revision) + self.model, hf_token=self.hf_token, revision=self.revision + ) architectures = self.architectures registry = self.registry - is_generative_model = registry.is_text_generation_model( - architectures, self) + is_generative_model = registry.is_text_generation_model(architectures, self) is_pooling_model = registry.is_pooling_model(architectures, self) def _task_to_convert(task: TaskOption) -> ConvertType: @@ -474,8 +508,10 @@ def _task_to_convert(task: TaskOption) -> ConvertType: if self.task is not None: runner: RunnerOption = "auto" convert: ConvertOption = "auto" - msg_prefix = ("The 'task' option has been deprecated and will be " - "removed in v0.13.0 or v1.0, whichever comes first.") + msg_prefix = ( + "The 'task' option has been deprecated and will be " + "removed in v0.13.0 or v1.0, whichever comes first." + ) msg_hint = "Please remove this option." is_generative_task = self.task in _RUNNER_TASKS["generate"] @@ -485,15 +521,19 @@ def _task_to_convert(task: TaskOption) -> ConvertType: if is_generative_task: runner = "generate" convert = "auto" - msg_hint = ("Please replace this option with `--runner " - "generate` to continue using this model " - "as a generative model.") + msg_hint = ( + "Please replace this option with `--runner " + "generate` to continue using this model " + "as a generative model." + ) elif is_pooling_task: runner = "pooling" convert = "auto" - msg_hint = ("Please replace this option with `--runner " - "pooling` to continue using this model " - "as a pooling model.") + msg_hint = ( + "Please replace this option with `--runner " + "pooling` to continue using this model " + "as a pooling model." 
+ ) else: # task == "auto" pass elif is_generative_model or is_pooling_model: @@ -504,9 +544,11 @@ def _task_to_convert(task: TaskOption) -> ConvertType: elif is_pooling_task: runner = "pooling" convert = _task_to_convert(self.task) - msg_hint = ("Please replace this option with `--convert " - f"{convert}` to continue using this model " - "as a pooling model.") + msg_hint = ( + "Please replace this option with `--convert " + f"{convert}` to continue using this model " + "as a pooling model." + ) else: # task == "auto" pass else: @@ -515,9 +557,11 @@ def _task_to_convert(task: TaskOption) -> ConvertType: "is_generative_model": is_generative_model, "is_pooling_model": is_pooling_model, } - raise AssertionError("The model should be a generative or " - "pooling model when task is set to " - f"{self.task!r}. Found: {debug_info}") + raise AssertionError( + "The model should be a generative or " + "pooling model when task is set to " + f"{self.task!r}. Found: {debug_info}" + ) self.runner = runner self.convert = convert @@ -526,16 +570,15 @@ def _task_to_convert(task: TaskOption) -> ConvertType: warnings.warn(msg, DeprecationWarning, stacklevel=2) self.runner_type = self._get_runner_type(architectures, self.runner) - self.convert_type = self._get_convert_type(architectures, - self.runner_type, - self.convert) + self.convert_type = self._get_convert_type( + architectures, self.runner_type, self.convert + ) if self.runner_type == "generate" and not is_generative_model: generate_converts = _RUNNER_CONVERTS["generate"] if self.convert_type not in generate_converts: # Currently we don't have any converters for generative models - raise ValueError( - "This model does not support `--runner generate`.") + raise ValueError("This model does not support `--runner generate`.") if self.runner_type == "pooling" and not is_pooling_model: pooling_converts = _RUNNER_CONVERTS["pooling"] if self.convert_type not in pooling_converts: @@ -543,7 +586,8 @@ def _task_to_convert(task: TaskOption) -> ConvertType: raise ValueError( "This model does not support `--runner pooling`. " f"You can pass `--convert {convert_option} to adapt " - "it into a pooling model.") + "it into a pooling model." + ) # Note: Initialize these attributes early because transformers fallback # may fail to load dynamic modules in child processes @@ -558,11 +602,11 @@ def _task_to_convert(task: TaskOption) -> ConvertType: logger.warning_once( "`override_pooler_config` is deprecated and will be " "removed in v0.12.0 or v1.0.0, whichever is sooner. " - "Please use `pooler_config` instead.") + "Please use `pooler_config` instead." + ) if isinstance(self.override_pooler_config, dict): - self.pooler_config = PoolerConfig( - **self.override_pooler_config) + self.pooler_config = PoolerConfig(**self.override_pooler_config) else: self.pooler_config = self.override_pooler_config @@ -589,11 +633,12 @@ def _task_to_convert(task: TaskOption) -> ConvertType: ) # Interleaved attention is not supported by some backends in V0 - if (not self.disable_sliding_window - and is_interleaved(self.hf_text_config) - and not envs.VLLM_USE_V1 - and (backend := envs.VLLM_ATTENTION_BACKEND) - in ("XFORMERS", "FLASHINFER")): + if ( + not self.disable_sliding_window + and is_interleaved(self.hf_text_config) + and not envs.VLLM_USE_V1 + and (backend := envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER") + ): logger.warning_once( "%s has interleaved attention, which is currently not " "supported by the %s backend. 
Disabling sliding window and " @@ -608,11 +653,14 @@ def _task_to_convert(task: TaskOption) -> ConvertType: self.max_model_len = self.get_and_verify_max_len(self.max_model_len) # Init multimodal config if needed if self._model_info.supports_multimodal: - if (mm_encoder_tp_mode == "data" and - not self._model_info.supports_multimodal_encoder_tp_data): + if ( + mm_encoder_tp_mode == "data" + and not self._model_info.supports_multimodal_encoder_tp_data + ): logger.warning_once( "This model does not support `--mm-encoder-tp-mode data`. " - "Falling back to `--mm-encoder-tp-mode weights`.") + "Falling back to `--mm-encoder-tp-mode weights`." + ) mm_encoder_tp_mode = "weights" mm_config_kwargs = dict( @@ -629,8 +677,7 @@ def _task_to_convert(task: TaskOption) -> ConvertType: ) mm_config_kwargs = { - k: v - for k, v in mm_config_kwargs.items() if v is not None + k: v for k, v in mm_config_kwargs.items() if v is not None } self.multimodal_config = MultiModalConfig(**mm_config_kwargs) @@ -662,8 +709,7 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": if not isinstance(self.tokenizer, str): raise ValueError("tokenizer must be a string after __post_init__.") if not isinstance(self.max_model_len, int): - raise ValueError( - "max_model_len must be an integer after __post_init__.") + raise ValueError("max_model_len must be an integer after __post_init__.") return self def _get_transformers_backend_cls(self) -> str: @@ -718,8 +764,7 @@ def architecture(self) -> str: """The architecture vllm actually used.""" return self._architecture - def maybe_pull_model_tokenizer_for_runai(self, model: str, - tokenizer: str) -> None: + def maybe_pull_model_tokenizer_for_runai(self, model: str, tokenizer: str) -> None: """Pull model/tokenizer from Object Storage to temporary directory when needed. @@ -734,42 +779,45 @@ def maybe_pull_model_tokenizer_for_runai(self, model: str, if is_runai_obj_uri(model): object_storage_model = ObjectStorageModel(url=model) object_storage_model.pull_files( - model, allow_pattern=["*.model", "*.py", "*.json"]) + model, allow_pattern=["*.model", "*.py", "*.json"] + ) self.model_weights = model self.model = object_storage_model.dir # If tokenizer is same as model, download to same directory if model == tokenizer: - object_storage_model.pull_files(model, - ignore_pattern=[ - "*.pt", "*.safetensors", - "*.bin", "*.tensors", - "*.pth" - ]) + object_storage_model.pull_files( + model, + ignore_pattern=[ + "*.pt", + "*.safetensors", + "*.bin", + "*.tensors", + "*.pth", + ], + ) self.tokenizer = object_storage_model.dir return # Only download tokenizer if needed and not already handled if is_runai_obj_uri(tokenizer): object_storage_tokenizer = ObjectStorageModel(url=tokenizer) - object_storage_tokenizer.pull_files(model, - ignore_pattern=[ - "*.pt", "*.safetensors", - "*.bin", "*.tensors", - "*.pth" - ]) + object_storage_tokenizer.pull_files( + model, + ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors", "*.pth"], + ) self.tokenizer = object_storage_tokenizer.dir def _get_encoder_config(self): - return get_sentence_transformer_tokenizer_config( - self.model, self.revision) + return get_sentence_transformer_tokenizer_config(self.model, self.revision) def _verify_tokenizer_mode(self) -> None: tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) if tokenizer_mode not in get_args(TokenizerMode): raise ValueError( f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " - f"one of {get_args(TokenizerMode)}.") + f"one of {get_args(TokenizerMode)}." 
+ ) self.tokenizer_mode = tokenizer_mode def _get_default_runner_type( @@ -811,7 +859,8 @@ def _get_runner_type( logger.info( "Resolved `--runner auto` to `--runner %s`. " "Pass the value explicitly to silence this message.", - runner_type) + runner_type, + ) return runner_type @@ -824,16 +873,16 @@ def _get_default_convert_type( for arch in architectures: if arch in registry.get_supported_archs(): - if (runner_type == "generate" - and registry.is_text_generation_model( - architectures, self)): + if runner_type == "generate" and registry.is_text_generation_model( + architectures, self + ): return "none" - if (runner_type == "pooling" - and registry.is_pooling_model(architectures, self)): + if runner_type == "pooling" and registry.is_pooling_model( + architectures, self + ): return "none" - match = try_match_architecture_defaults(arch, - runner_type=runner_type) + match = try_match_architecture_defaults(arch, runner_type=runner_type) if match: _, (_, convert_type) = match return convert_type @@ -855,15 +904,15 @@ def _get_convert_type( if convert != "auto": return convert - convert_type = self._get_default_convert_type(architectures, - runner_type) + convert_type = self._get_default_convert_type(architectures, runner_type) # Don't log the most common case if convert_type != "none": logger.info( "Resolved `--convert auto` to `--convert %s`. " "Pass the value explicitly to silence this message.", - convert_type) + convert_type, + ) return convert_type @@ -875,8 +924,7 @@ def _get_default_pooling_task( return "classify" for arch in architectures: - match = try_match_architecture_defaults(arch, - runner_type="pooling") + match = try_match_architecture_defaults(arch, runner_type="pooling") if match: _, (_, convert_type) = match assert convert_type != "none" @@ -894,28 +942,26 @@ def _parse_quant_hf_config(self, hf_config: PretrainedConfig): # Set quant_method for ModelOpt models. producer_name = quant_cfg.get("producer", {}).get("name") if producer_name == "modelopt": - quant_algo = quant_cfg.get("quantization", - {}).get("quant_algo") + quant_algo = quant_cfg.get("quantization", {}).get("quant_algo") if quant_algo == "FP8": quant_cfg["quant_method"] = "modelopt" elif quant_algo == "NVFP4": quant_cfg["quant_method"] = "modelopt_fp4" elif quant_algo is not None: - raise ValueError( - f"Unknown ModelOpt quant algo: {quant_algo}") + raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}") return quant_cfg def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS if self.quantization is not None: - self.quantization = cast(me_quant.QuantizationMethods, - self.quantization) + self.quantization = cast(me_quant.QuantizationMethods, self.quantization) # Parse quantization method from the HF model config, if available. quant_cfg = self._parse_quant_hf_config(self.hf_config) - if quant_cfg is None and (text_config := getattr( - self.hf_config, "text_config", None)): + if quant_cfg is None and ( + text_config := getattr(self.hf_config, "text_config", None) + ): # Check the text config as well for multi-modal models. 
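To make the ModelOpt branch of _parse_quant_hf_config above concrete, here is a small sketch with a hypothetical quantization_config dict containing only the keys that code actually reads:

def resolve_modelopt_quant_method(quant_cfg: dict) -> dict:
    # Mirrors the producer/quant_algo mapping shown in the hunk above.
    producer_name = quant_cfg.get("producer", {}).get("name")
    if producer_name == "modelopt":
        quant_algo = quant_cfg.get("quantization", {}).get("quant_algo")
        if quant_algo == "FP8":
            quant_cfg["quant_method"] = "modelopt"
        elif quant_algo == "NVFP4":
            quant_cfg["quant_method"] = "modelopt_fp4"
        elif quant_algo is not None:
            raise ValueError(f"Unknown ModelOpt quant algo: {quant_algo}")
    return quant_cfg

cfg = {"producer": {"name": "modelopt"}, "quantization": {"quant_algo": "FP8"}}
assert resolve_modelopt_quant_method(cfg)["quant_method"] == "modelopt"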
quant_cfg = self._parse_quant_hf_config(text_config) @@ -924,8 +970,9 @@ def _verify_quantization(self) -> None: quant_method = quant_cfg.get("quant_method", "").lower() # Normalize library names - quant_method = quant_method.replace("compressed_tensors", - "compressed-tensors") + quant_method = quant_method.replace( + "compressed_tensors", "compressed-tensors" + ) quant_cfg["quant_method"] = quant_method @@ -959,18 +1006,22 @@ def _verify_quantization(self) -> None: for name in quantization_methods: method = me_quant.get_quantization_config(name) quantization_override = method.override_quantization_method( - quant_cfg, self.quantization) + quant_cfg, self.quantization + ) if quantization_override is not None: # Raise error if the override is not custom (custom would # be in QUANTIZATION_METHODS but not QuantizationMethods) # and hasn't been added to the overrides list. - if (name in get_args(me_quant.QuantizationMethods) - and name not in overrides): + if ( + name in get_args(me_quant.QuantizationMethods) + and name not in overrides + ): raise ValueError( f"Quantization method {name} is an override but " "is has not been added to the `overrides` list " "above. This is necessary to ensure that the " - "overrides are checked in order of preference.") + "overrides are checked in order of preference." + ) quant_method = quantization_override self.quantization = quantization_override break @@ -984,24 +1035,28 @@ def _verify_quantization(self) -> None: "Quantization method specified in the model config " f"({quant_method}) does not match the quantization " f"method specified in the `quantization` argument " - f"({self.quantization}).") + f"({self.quantization})." + ) if self.quantization is not None: if self.quantization not in supported_quantization: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " - f"be one of {supported_quantization}.") + f"be one of {supported_quantization}." + ) from vllm.platforms import current_platform + current_platform.verify_quantization(self.quantization) def _verify_cuda_graph(self) -> None: # CUDAGraph capture not supported for encoder-decoder models on ROCm unsupported_rocm = self.is_encoder_decoder - if (unsupported_rocm and not self.enforce_eager - and current_platform.is_rocm()): + if unsupported_rocm and not self.enforce_eager and current_platform.is_rocm(): logger.warning( "CUDA graph is not supported for %s on ROCm yet, fallback " - "to eager mode.", self.hf_config.model_type) + "to eager mode.", + self.hf_config.model_type, + ) self.enforce_eager = True def _verify_bnb_config(self) -> None: @@ -1011,20 +1066,26 @@ def _verify_bnb_config(self) -> None: # TODO Remove this when bitsandbytes supports. """ is_bitsandbytes = self.quantization == "bitsandbytes" - has_quantization_config = (getattr(self.hf_config, - "quantization_config", None) - is not None) - is_8bit = (self.hf_config.quantization_config.get( - "load_in_8bit", False) if has_quantization_config else False) - if all([ + has_quantization_config = ( + getattr(self.hf_config, "quantization_config", None) is not None + ) + is_8bit = ( + self.hf_config.quantization_config.get("load_in_8bit", False) + if has_quantization_config + else False + ) + if all( + [ is_bitsandbytes, has_quantization_config, is_8bit, not self.enforce_eager, - ]): + ] + ): logger.warning( "CUDA graph is not supported on BitsAndBytes 8bit yet, " - "fallback to the eager mode.") + "fallback to the eager mode." 
+ ) self.enforce_eager = True @@ -1033,7 +1094,8 @@ def _verify_with_expert_parallelism(self) -> None: if num_experts < 1: raise ValueError( "Number of experts in the model must be greater than 0 " - "when expert parallelism is enabled.") + "when expert parallelism is enabled." + ) def verify_dual_chunk_attention_config( self, @@ -1042,45 +1104,54 @@ def verify_dual_chunk_attention_config( if hasattr(self.hf_config, "dual_chunk_attention_config"): # Try loading the sparse attention config from vllm.model_executor.model_loader.weight_utils import ( - get_sparse_attention_config) + get_sparse_attention_config, + ) + sparse_attn_config = get_sparse_attention_config(self, load_config) if sparse_attn_config: self.hf_config.dual_chunk_attention_config[ - "sparse_attention_config"] = sparse_attn_config - if "sparse_attention_enabled" not in \ - self.hf_config.dual_chunk_attention_config: + "sparse_attention_config" + ] = sparse_attn_config + if ( + "sparse_attention_enabled" + not in self.hf_config.dual_chunk_attention_config + ): self.hf_config.dual_chunk_attention_config[ - "sparse_attention_enabled"] = True + "sparse_attention_enabled" + ] = True def verify_with_parallel_config( self, parallel_config: ParallelConfig, ) -> None: - if parallel_config.distributed_executor_backend == "external_launcher": assert self.seed is not None, ( "Seed must be set when using external launcher backend to " - "make sure sampling results are the same across workers.") + "make sure sampling results are the same across workers." + ) - total_num_attention_heads = getattr(self.hf_text_config, - "num_attention_heads", 0) + total_num_attention_heads = getattr( + self.hf_text_config, "num_attention_heads", 0 + ) tensor_parallel_size = parallel_config.tensor_parallel_size if total_num_attention_heads % tensor_parallel_size != 0: raise ValueError( f"Total number of attention heads ({total_num_attention_heads})" " must be divisible by tensor parallel size " - f"({tensor_parallel_size}).") + f"({tensor_parallel_size})." + ) if parallel_config.enable_expert_parallel: self._verify_with_expert_parallelism() pipeline_parallel_size = parallel_config.pipeline_parallel_size - if (pipeline_parallel_size > 1 - and not self.registry.is_pp_supported_model( - self.architectures, self)): + if pipeline_parallel_size > 1 and not self.registry.is_pp_supported_model( + self.architectures, self + ): raise NotImplementedError( "Pipeline parallelism is not supported for this model. " - "Supported models implement the `SupportsPP` interface.") + "Supported models implement the `SupportsPP` interface." 
+ ) def get_sliding_window(self) -> Optional[int]: """Get the sliding window size from the HF text config if present.""" @@ -1096,34 +1167,39 @@ def get_hidden_size(self) -> int: def is_deepseek_mla(self) -> bool: if not hasattr(self.hf_text_config, "model_type"): return False - elif self.hf_text_config.model_type in \ - ('deepseek_v2', 'deepseek_v3', 'deepseek_v32', 'deepseek_mtp', - 'kimi_k2', 'longcat_flash'): + elif self.hf_text_config.model_type in ( + "deepseek_v2", + "deepseek_v3", + "deepseek_v32", + "deepseek_mtp", + "kimi_k2", + "longcat_flash", + ): return self.hf_text_config.kv_lora_rank is not None - elif self.hf_text_config.model_type == 'eagle': + elif self.hf_text_config.model_type == "eagle": # if the model is an EAGLE module, check for the # underlying architecture - return self.hf_text_config.model.model_type in \ - ('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \ + return ( + self.hf_text_config.model.model_type + in ("deepseek_v2", "deepseek_v3", "deepseek_v32") and self.hf_text_config.kv_lora_rank is not None + ) return False def get_head_size(self) -> int: # TODO remove hard code if self.is_deepseek_mla: - qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", - 0) + qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0) if self.use_mla: return self.hf_text_config.kv_lora_rank + qk_rope_head_dim else: - qk_nope_head_dim = getattr(self.hf_text_config, - "qk_nope_head_dim", 0) + qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0) if qk_rope_head_dim and qk_nope_head_dim: return qk_rope_head_dim + qk_nope_head_dim - if hasattr(self.hf_text_config, - "model_type") and (self.hf_text_config.model_type - == "zamba2"): + if hasattr(self.hf_text_config, "model_type") and ( + self.hf_text_config.model_type == "zamba2" + ): return self.hf_text_config.attention_head_dim if self.is_attention_free: @@ -1134,13 +1210,13 @@ def get_head_size(self) -> int: return self.hf_text_config.head_dim # NOTE: Some models (such as PLaMo2.1) use `hidden_size_per_head` - if getattr(self.hf_text_config, "hidden_size_per_head", - None) is not None: + if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: return self.hf_text_config.hidden_size_per_head # FIXME(woosuk): This may not be true for all models. - return (self.hf_text_config.hidden_size // - self.hf_text_config.num_attention_heads) + return ( + self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads + ) def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" @@ -1151,9 +1227,11 @@ def get_total_num_kv_heads(self) -> int: falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] new_decoder_arch_falcon = ( self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False)) - if not new_decoder_arch_falcon and getattr(self.hf_text_config, - "multi_query", False): + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_text_config, "multi_query", False + ): # Multi-query attention, only one KV head. # Currently, tensor parallelism is not supported in this case. 
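A short worked example of get_head_size above for an MLA-style model; the values below are hypothetical stand-ins for the hf_text_config attributes:

# With use_mla, the KV-cache "head size" is kv_lora_rank + qk_rope_head_dim;
# without it, the MLA branch falls back to qk_rope_head_dim + qk_nope_head_dim.
kv_lora_rank = 512       # hypothetical value
qk_rope_head_dim = 64    # hypothetical value
qk_nope_head_dim = 128   # hypothetical value

head_size_mla = kv_lora_rank + qk_rope_head_dim          # 576
head_size_no_mla = qk_rope_head_dim + qk_nope_head_dim   # 192

assert (head_size_mla, head_size_no_mla) == (576, 192)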
return 1 @@ -1164,14 +1242,19 @@ def get_total_num_kv_heads(self) -> int: return self.hf_config.attn_config["kv_n_heads"] return self.hf_config.num_attention_heads if self.hf_config.model_type == "dbrx": - return getattr(self.hf_config.attn_config, "kv_n_heads", - self.hf_config.num_attention_heads) + return getattr( + self.hf_config.attn_config, + "kv_n_heads", + self.hf_config.num_attention_heads, + ) if self.hf_config.model_type == "nemotron-nas": for block in self.hf_config.block_configs: if not block.attention.no_op: - return self.hf_config.num_attention_heads \ + return ( + self.hf_config.num_attention_heads // block.attention.n_heads_in_group + ) raise RuntimeError("Couldn't determine number of kv heads") @@ -1207,8 +1290,7 @@ def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: # the tensor parallel size. We will replicate the KV heads in the # case where the number of KV heads is smaller than the tensor # parallel size so each GPU has at least one KV head. - return max(1, - total_num_kv_heads // parallel_config.tensor_parallel_size) + return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) def get_num_attention_heads(self, parallel_config: ParallelConfig) -> int: num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) @@ -1230,24 +1312,32 @@ def get_num_experts(self) -> int: return num_experts def get_layers_start_end_indices( - self, parallel_config: ParallelConfig) -> tuple[int, int]: + self, parallel_config: ParallelConfig + ) -> tuple[int, int]: from vllm.distributed.utils import get_pp_indices - if (self.hf_text_config.model_type == "deepseek_mtp" - or self.hf_config.model_type == "mimo_mtp" - or self.hf_config.model_type == "glm4_moe_mtp" - or self.hf_config.model_type == "ernie_mtp" - or self.hf_config.model_type == "qwen3_next_mtp"): - total_num_hidden_layers = getattr(self.hf_text_config, - "num_nextn_predict_layers", 0) - elif (self.hf_config.model_type == "longcat_flash_mtp"): - total_num_hidden_layers = getattr(self.hf_text_config, - "num_nextn_predict_layers", 1) + + if ( + self.hf_text_config.model_type == "deepseek_mtp" + or self.hf_config.model_type == "mimo_mtp" + or self.hf_config.model_type == "glm4_moe_mtp" + or self.hf_config.model_type == "ernie_mtp" + or self.hf_config.model_type == "qwen3_next_mtp" + ): + total_num_hidden_layers = getattr( + self.hf_text_config, "num_nextn_predict_layers", 0 + ) + elif self.hf_config.model_type == "longcat_flash_mtp": + total_num_hidden_layers = getattr( + self.hf_text_config, "num_nextn_predict_layers", 1 + ) else: - total_num_hidden_layers = getattr(self.hf_text_config, - "num_hidden_layers", 0) + total_num_hidden_layers = getattr( + self.hf_text_config, "num_hidden_layers", 0 + ) # the layout order is: DP x PP x TP - pp_rank = (parallel_config.rank // parallel_config.tensor_parallel_size - ) % parallel_config.pipeline_parallel_size + pp_rank = ( + parallel_config.rank // parallel_config.tensor_parallel_size + ) % parallel_config.pipeline_parallel_size pp_size = parallel_config.pipeline_parallel_size start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) return start, end @@ -1264,9 +1354,9 @@ def get_num_layers_by_block_type( # This function relies on 'layers_block_type' in hf_config, # for w/o this attribute, we will need to have workarounds like so attn_block_type = block_type == LayerBlockType.attention - is_transformer = not self.is_hybrid and \ - not self.has_noops and \ - not self.is_attention_free + is_transformer = ( + not self.is_hybrid and not 
self.has_noops and not self.is_attention_free + ) start, end = self.get_layers_start_end_indices(parallel_config) if is_transformer: @@ -1279,23 +1369,25 @@ def get_num_layers_by_block_type( return 0 if attn_block_type else end - start elif self.has_noops: block_configs = self.hf_config.block_configs - return sum(not bc.attention.no_op - for bc in block_configs[start:end]) + return sum(not bc.attention.no_op for bc in block_configs[start:end]) else: # Hybrid model Jamba - layers_block_type_value = getattr(self.hf_text_config, - "layers_block_type", None) + layers_block_type_value = getattr( + self.hf_text_config, "layers_block_type", None + ) if layers_block_type_value is not None: - if hasattr(self.hf_text_config, - "model_type") and (self.hf_text_config.model_type - == "zamba2"): + if hasattr(self.hf_text_config, "model_type") and ( + self.hf_text_config.model_type == "zamba2" + ): if attn_block_type: - return sum(t == "hybrid" - for t in layers_block_type_value[start:end]) + return sum( + t == "hybrid" for t in layers_block_type_value[start:end] + ) else: return self.get_num_layers(parallel_config) - return sum(t == block_type.value - for t in layers_block_type_value[start:end]) + return sum( + t == block_type.value for t in layers_block_type_value[start:end] + ) # Hybrid model Minimax attn_type_list = getattr(self.hf_config, "attn_type_list", None) @@ -1306,23 +1398,30 @@ def get_num_layers_by_block_type( layer_types_value = getattr(self.hf_config, "layer_types", None) if layer_types_value is not None: if getattr(block_type, "value", block_type) == "attention": - return sum(t == "full_attention" - for t in layer_types_value[start:end]) - elif getattr(block_type, "value", - block_type) == "linear_attention": - return sum(t == "linear_attention" - for t in layer_types_value[start:end]) + return sum( + t == "full_attention" for t in layer_types_value[start:end] + ) + elif getattr(block_type, "value", block_type) == "linear_attention": + return sum( + t == "linear_attention" for t in layer_types_value[start:end] + ) else: - return sum(t == getattr(block_type, "value", block_type) - for t in layer_types_value[start:end]) - - if (layers_block_type_value is None and attn_type_list is None - and layer_types_value is None): + return sum( + t == getattr(block_type, "value", block_type) + for t in layer_types_value[start:end] + ) + + if ( + layers_block_type_value is None + and attn_type_list is None + and layer_types_value is None + ): raise ValueError( "The model is an hybrid without a" "layers_block_type or an attn_type_list, or a layer_types " "in the hf_config, cannot determine the num of " - f"{block_type.value} layers") + f"{block_type.value} layers" + ) def get_mamba_chunk_size(self) -> Optional[int]: """ @@ -1411,14 +1510,14 @@ def get_diff_sampling_param(self) -> dict[str, Any]: ] if any(p in config for p in available_params): diff_sampling_param = { - p: config.get(p) - for p in available_params if config.get(p) is not None + p: config.get(p) for p in available_params if config.get(p) is not None } # Huggingface definition of max_new_tokens is equivalent # to vLLM's max_tokens if "max_new_tokens" in diff_sampling_param: diff_sampling_param["max_tokens"] = diff_sampling_param.pop( - "max_new_tokens") + "max_new_tokens" + ) else: diff_sampling_param = {} @@ -1427,7 +1526,8 @@ def get_diff_sampling_param(self) -> dict[str, Any]: "Default sampling parameters have been overridden by the " "model's Hugging Face generation config recommended from the " "model creator. 
If this is not intended, please relaunch " - "vLLM instance with `--generation-config vllm`.") + "vLLM instance with `--generation-config vllm`." + ) return diff_sampling_param @property @@ -1449,8 +1549,9 @@ def is_multimodal_raw_input_only_model(self) -> bool: @property def is_cross_encoder(self) -> bool: - return (self._model_info.supports_cross_encoding - or self.convert_type == "classify") + return ( + self._model_info.supports_cross_encoding or self.convert_type == "classify" + ) @property def is_pp_supported(self) -> bool: @@ -1482,8 +1583,9 @@ def use_mla(self) -> bool: @property def is_matryoshka(self) -> bool: - return (bool(getattr(self.hf_config, "matryoshka_dimensions", None)) - or getattr(self.hf_config, "is_matryoshka", False)) + return bool(getattr(self.hf_config, "matryoshka_dimensions", None)) or getattr( + self.hf_config, "is_matryoshka", False + ) @property def matryoshka_dimensions(self): @@ -1507,20 +1609,25 @@ def head_dtype(self) -> torch.dtype: you can use --hf-overrides '{"head_dtype": "model"}' to disable it. """ - head_dtype = _get_head_dtype(config=self.hf_config, - dtype=self.dtype, - runner_type=self.runner_type) + head_dtype = _get_head_dtype( + config=self.hf_config, dtype=self.dtype, runner_type=self.runner_type + ) if self.runner_type != "pooling" and head_dtype != self.dtype: logger.warning_once( "`head_dtype` currently only supports pooling models." - "fallback to model dtype [%s].", self.dtype) + "fallback to model dtype [%s].", + self.dtype, + ) return self.dtype if head_dtype not in current_platform.supported_dtypes: logger.warning_once( "The current platform does not support [%s] head dtype, " - "fallback to model dtype [%s].", head_dtype, self.dtype) + "fallback to model dtype [%s].", + head_dtype, + self.dtype, + ) return self.dtype logger.debug_once("head dtype: %s", head_dtype) @@ -1530,12 +1637,15 @@ def get_and_verify_max_len(self, max_model_len: int): # Consider max_model_len in tokenizer_config only when # pooling models use absolute position_embedding. tokenizer_config = None - if (self.runner_type == "pooling" and getattr( - self.hf_config, "position_embedding_type", "") == "absolute"): + if ( + self.runner_type == "pooling" + and getattr(self.hf_config, "position_embedding_type", "") == "absolute" + ): tokenizer_config = try_get_tokenizer_config( self.tokenizer, trust_remote_code=self.trust_remote_code, - revision=self.tokenizer_revision) + revision=self.tokenizer_revision, + ) max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, tokenizer_config=tokenizer_config, @@ -1543,13 +1653,15 @@ def get_and_verify_max_len(self, max_model_len: int): disable_sliding_window=self.disable_sliding_window, sliding_window=self.get_sliding_window(), spec_target_max_model_len=self.spec_target_max_model_len, - encoder_config=self.encoder_config) + encoder_config=self.encoder_config, + ) logger.info("Using max model len %s", max_model_len) return max_model_len -def get_served_model_name(model: str, - served_model_name: Optional[Union[str, list[str]]]): +def get_served_model_name( + model: str, served_model_name: Optional[Union[str, list[str]]] +): """ If the input is a non-empty list, the first model_name in `served_model_name` is taken. 
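A minimal usage sketch of `get_served_model_name`, assuming the behavior its docstring describes (the first entry of a non-empty `served_model_name` list is used) plus a fallback to `model` when no override is given; the import path below is assumed, not taken from this diff:

    from vllm.config.model import get_served_model_name  # path assumed

    # Non-empty list: the first alias is the name clients see.
    assert get_served_model_name("meta-llama/Llama-3.1-8B", ["llama", "llama-8b"]) == "llama"

    # None (or an empty list) presumably falls back to the model path itself.
    assert get_served_model_name("meta-llama/Llama-3.1-8B", None) == "meta-llama/Llama-3.1-8B"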
@@ -1596,11 +1708,15 @@ def try_match_architecture_defaults( runner_type: Optional[RunnerType] = None, convert_type: Optional[ConvertType] = None, ) -> Optional[tuple[str, tuple[RunnerType, ConvertType]]]: - for suffix, (default_runner_type, - default_convert_type) in iter_architecture_defaults(): - if ((runner_type is None or runner_type == default_runner_type) and - (convert_type is None or convert_type == default_convert_type) - and architecture.endswith(suffix)): + for suffix, ( + default_runner_type, + default_convert_type, + ) in iter_architecture_defaults(): + if ( + (runner_type is None or runner_type == default_runner_type) + and (convert_type is None or convert_type == default_convert_type) + and architecture.endswith(suffix) + ): return suffix, (default_runner_type, default_convert_type) return None @@ -1618,8 +1734,7 @@ def try_match_architecture_defaults( _FLOAT16_NOT_SUPPORTED_MODELS = { "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.", "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.", - "gemma3_text": - "Numerical instability. Please use bfloat16 or float32 instead.", + "gemma3_text": "Numerical instability. Please use bfloat16 or float32 instead.", "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.", "glm4": "Numerical instability. Please use bfloat16 or float32 instead.", } @@ -1635,8 +1750,9 @@ def _is_valid_dtype(model_type: str, dtype: torch.dtype): def _check_valid_dtype(model_type: str, dtype: torch.dtype): if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type] - raise ValueError(f"The model type {model_type!r} " - f"does not support float16. Reason: {reason}") + raise ValueError( + f"The model type {model_type!r} does not support float16. Reason: {reason}" + ) return True @@ -1690,7 +1806,8 @@ def _resolve_auto_dtype( from vllm.platforms import current_platform supported_dtypes = [ - dtype for dtype in current_platform.supported_dtypes + dtype + for dtype in current_platform.supported_dtypes if _is_valid_dtype(model_type, dtype) ] @@ -1717,8 +1834,7 @@ def _resolve_auto_dtype( device_str = f"{device_name!r} (with compute capability {version_str})" logger.warning( - "Your device %s doesn't support %s. " - "Falling back to %s for compatibility.", + "Your device %s doesn't support %s. 
Falling back to %s for compatibility.", device_str, config_dtype, preferred_dtype, @@ -1772,11 +1888,10 @@ def _get_and_verify_dtype( return torch_dtype -def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype, - runner_type: str) -> torch.dtype: - head_dtype: Optional[Union[str, - torch.dtype]] = getattr(config, "head_dtype", - None) +def _get_head_dtype( + config: PretrainedConfig, dtype: torch.dtype, runner_type: str +) -> torch.dtype: + head_dtype: Optional[Union[str, torch.dtype]] = getattr(config, "head_dtype", None) if head_dtype == "model": return dtype @@ -1831,8 +1946,7 @@ def _get_and_verify_max_len( for key in possible_keys: max_len = getattr(hf_config, key, None) if max_len is not None: - max_len_key = key if max_len < derived_max_model_len \ - else max_len_key + max_len_key = key if max_len < derived_max_model_len else max_len_key derived_max_model_len = min(derived_max_model_len, max_len) # For Command-R / Cohere, Cohere2 / Aya Vision models if tmp_max_len := getattr(hf_config, "model_max_length", None): @@ -1841,17 +1955,20 @@ def _get_and_verify_max_len( # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. - if (disable_sliding_window and sliding_window is not None - and sliding_window < derived_max_model_len): + if ( + disable_sliding_window + and sliding_window is not None + and sliding_window < derived_max_model_len + ): max_len_key = "sliding_window" derived_max_model_len = sliding_window # Consider model_max_length in tokenizer_config if tokenizer_config: tokenizer_model_max_length = tokenizer_config.get( - "model_max_length", derived_max_model_len) - derived_max_model_len = min(derived_max_model_len, - tokenizer_model_max_length) + "model_max_length", derived_max_model_len + ) + derived_max_model_len = min(derived_max_model_len, tokenizer_model_max_length) # If none of the keys were found in the config, use a default and # log a warning. @@ -1869,8 +1986,10 @@ def _get_and_verify_max_len( logger.warning( "The model's config.json does not contain any of the following " "keys to determine the original maximum length of the model: " - "%s. Assuming the model's maximum length is %d.", possible_keys, - default_max_len) + "%s. Assuming the model's maximum length is %d.", + possible_keys, + default_max_len, + ) derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) @@ -1888,15 +2007,15 @@ def _get_and_verify_max_len( raise NotImplementedError( "Disabling sliding window is not supported for models " "with rope_scaling. Please raise an issue so we can " - "investigate.") + "investigate." + ) # NOTE: rope_type == "default" does not define factor # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py scaling_factor = rope_scaling.get("factor", 1.0) if rope_type == "yarn": - derived_max_model_len = rope_scaling[ - "original_max_position_embeddings"] + derived_max_model_len = rope_scaling["original_max_position_embeddings"] derived_max_model_len *= scaling_factor if encoder_config and "max_seq_length" in encoder_config: @@ -1913,7 +2032,9 @@ def _get_and_verify_max_len( "which might be too large." 
"Please input with --max-model-len based on your " "request input length and output length, to avoid " - "unnecessary degradation.", max_model_len) + "unnecessary degradation.", + max_model_len, + ) elif max_model_len > derived_max_model_len: # Some models might have a separate key for specifying model_max_length # that will be bigger than derived_max_model_len. We compare user input @@ -1926,24 +2047,28 @@ def _get_and_verify_max_len( raise NotImplementedError( "Disabling sliding window is not supported for models " "model_max_length in the config. Please raise an issue " - "so we can investigate.") + "so we can investigate." + ) else: msg = ( f"User-specified max_model_len ({max_model_len}) is greater " f"than the derived max_model_len ({max_len_key}=" f"{derived_max_model_len} or model_max_length=" - f"{model_max_length} in model's config.json).") + f"{model_max_length} in model's config.json)." + ) warning = ( "VLLM_ALLOW_LONG_MAX_MODEL_LEN must be used with extreme " "caution. If the model uses relative position encoding (RoPE), " "positions exceeding derived_max_model_len lead to nan. If the " "model uses absolute position encoding, positions exceeding " "derived_max_model_len will cause a CUDA array out-of-bounds " - "error.") + "error." + ) if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: logger.warning_once("%s %s", msg, warning) else: raise ValueError( f"{msg} To allow overriding this maximum, set " - f"the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. {warning}") + f"the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. {warning}" + ) return int(max_model_len) diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index fd62d2411ade..fc8d2262dcb4 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -15,12 +15,14 @@ @dataclass class BaseDummyOptions: """Base options for generating dummy data during profiling.""" + count: int = Field(999, ge=0) @dataclass(config=ConfigDict(extra="forbid")) class VideoDummyOptions(BaseDummyOptions): """Options for generating dummy video data during profiling.""" + num_frames: Optional[int] = Field(None, gt=0) width: Optional[int] = Field(None, gt=0) height: Optional[int] = Field(None, gt=0) @@ -29,6 +31,7 @@ class VideoDummyOptions(BaseDummyOptions): @dataclass(config=ConfigDict(extra="forbid")) class ImageDummyOptions(BaseDummyOptions): """Options for generating dummy image data during profiling.""" + width: Optional[int] = Field(None, gt=0) height: Optional[int] = Field(None, gt=0) @@ -36,13 +39,15 @@ class ImageDummyOptions(BaseDummyOptions): @dataclass(config=ConfigDict(extra="forbid")) class AudioDummyOptions(BaseDummyOptions): """Options for generating dummy audio data during profiling.""" + length: Optional[int] = Field(None, gt=0) MMEncoderTPMode = Literal["weights", "data"] MMCacheType = Literal["shm", "lru"] -DummyOptions = Union[BaseDummyOptions, VideoDummyOptions, ImageDummyOptions, - AudioDummyOptions] +DummyOptions = Union[ + BaseDummyOptions, VideoDummyOptions, ImageDummyOptions, AudioDummyOptions +] @config @@ -127,9 +132,8 @@ class MultiModalConfig: @field_validator("limit_per_prompt", mode="before") @classmethod def _validate_limit_per_prompt( - cls, value: dict[str, Union[int, - dict[str, - int]]]) -> dict[str, DummyOptions]: + cls, value: dict[str, Union[int, dict[str, int]]] + ) -> dict[str, DummyOptions]: for k, v in value.items(): # Handle legacy format where only count is specified if isinstance(v, int): @@ -160,8 +164,7 @@ def compute_hash(self) -> str: # no factors to consider. 
# this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def get_limit_per_prompt(self, modality: str) -> int: @@ -196,5 +199,4 @@ def merge_mm_processor_kwargs( return kwargs | dict(inference_kwargs) def is_multimodal_pruning_enabled(self): - return (self.video_pruning_rate is not None - and self.video_pruning_rate > 0) + return self.video_pruning_rate is not None and self.video_pruning_rate > 0 diff --git a/vllm/config/observability.py b/vllm/config/observability.py index 766d03051e21..6c7b5fbbee47 100644 --- a/vllm/config/observability.py +++ b/vllm/config/observability.py @@ -31,8 +31,7 @@ def show_hidden_metrics(self) -> bool: """Check if the hidden metrics should be shown.""" if self.show_hidden_metrics_for_version is None: return False - return version._prev_minor_version_was( - self.show_hidden_metrics_for_version) + return version._prev_minor_version_was(self.show_hidden_metrics_for_version) otlp_traces_endpoint: Optional[str] = None """Target URL to which OpenTelemetry traces will be sent.""" @@ -49,16 +48,18 @@ def show_hidden_metrics(self) -> bool: @cached_property def collect_model_forward_time(self) -> bool: """Whether to collect model forward time for the request.""" - return (self.collect_detailed_traces is not None - and ("model" in self.collect_detailed_traces - or "all" in self.collect_detailed_traces)) + return self.collect_detailed_traces is not None and ( + "model" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces + ) @cached_property def collect_model_execute_time(self) -> bool: """Whether to collect model execute time for the request.""" - return (self.collect_detailed_traces is not None - and ("worker" in self.collect_detailed_traces - or "all" in self.collect_detailed_traces)) + return self.collect_detailed_traces is not None and ( + "worker" in self.collect_detailed_traces + or "all" in self.collect_detailed_traces + ) def compute_hash(self) -> str: """ @@ -75,25 +76,28 @@ def compute_hash(self) -> str: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self): - if (self.collect_detailed_traces is not None - and len(self.collect_detailed_traces) == 1 - and "," in self.collect_detailed_traces[0]): + if ( + self.collect_detailed_traces is not None + and len(self.collect_detailed_traces) == 1 + and "," in self.collect_detailed_traces[0] + ): self._parse_collect_detailed_traces() from vllm.tracing import is_otel_available, otel_import_error_traceback + if not is_otel_available() and self.otlp_traces_endpoint is not None: raise ValueError( "OpenTelemetry is not available. Unable to configure " "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " - f"installed. Original error:\n{otel_import_error_traceback}") + f"installed. 
Original error:\n{otel_import_error_traceback}" + ) def _parse_collect_detailed_traces(self): assert isinstance(self.collect_detailed_traces, list) self.collect_detailed_traces = cast( - list[DetailedTraceModules], - self.collect_detailed_traces[0].split(",")) + list[DetailedTraceModules], self.collect_detailed_traces[0].split(",") + ) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 8b980458ddaf..649b2434ebbf 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -161,9 +161,9 @@ class ParallelConfig: placement_group: Optional[PlacementGroup] = None """ray distributed model workers placement group.""" - distributed_executor_backend: Optional[Union[str, - DistributedExecutorBackend, - type[ExecutorBase]]] = None + distributed_executor_backend: Optional[ + Union[str, DistributedExecutorBackend, type[ExecutorBase]] + ] = None """Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size is less than @@ -253,7 +253,8 @@ def stateless_init_dp_group(self) -> ProcessGroup: from torch.distributed import DistNetworkError from vllm.distributed.utils import ( - stateless_init_torch_distributed_process_group) + stateless_init_torch_distributed_process_group, + ) max_retries = 5 last_exc: Optional[Exception] = None @@ -265,12 +266,12 @@ def stateless_init_dp_group(self) -> ProcessGroup: self.get_next_dp_init_port(), self.data_parallel_rank, self.data_parallel_size, - backend="gloo") + backend="gloo", + ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. if "EADDRINUSE" in str(e): - logger.warning( - "Address already in use. Retrying with a new port.") + logger.warning("Address already in use. Retrying with a new port.") last_exc = e continue # try again with a new port raise e @@ -290,19 +291,22 @@ def stateless_init_dp_group(self) -> ProcessGroup: # Not needed for pplx-kernels as it can handle duplicate input tokens. 
@property def use_sequence_parallel_moe(self) -> bool: - return (envs.VLLM_ALL2ALL_BACKEND - in ("allgather_reducescatter", "naive", - "deepep_high_throughput", "deepep_low_latency") - and self.enable_expert_parallel - and self.tensor_parallel_size > 1 - and self.data_parallel_size > 1) + return ( + envs.VLLM_ALL2ALL_BACKEND + in ( + "allgather_reducescatter", + "naive", + "deepep_high_throughput", + "deepep_low_latency", + ) + and self.enable_expert_parallel + and self.tensor_parallel_size > 1 + and self.data_parallel_size > 1 + ) @staticmethod - def has_unfinished_dp(dp_group: ProcessGroup, - has_unfinished: bool) -> bool: - tensor = torch.tensor([has_unfinished], - dtype=torch.int32, - device="cpu") + def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool: + tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu") # dp rank 0: has_unfinished_seqs=True # dp rank 1: has_unfinished_seqs=False # aggregated: has_unfinished_seqs=True @@ -312,13 +316,10 @@ def has_unfinished_dp(dp_group: ProcessGroup, return aggregated_has_unfinished @staticmethod - def sync_kv_cache_memory_size(dp_group: ProcessGroup, - kv_cache_memory: int) -> int: + def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int: if kv_cache_memory == -1: kv_cache_memory = torch.iinfo(torch.int64).max - tensor = torch.tensor([kv_cache_memory], - dtype=torch.int64, - device="cpu") + tensor = torch.tensor([kv_cache_memory], dtype=torch.int64, device="cpu") # we cannot use broadcast for stateless dp group since it depends # on global rank torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group) @@ -343,38 +344,40 @@ def compute_hash(self): def __post_init__(self) -> None: # Forward deprecated fields to their new location if self.num_redundant_experts is not None: - self.eplb_config.num_redundant_experts = ( - self.num_redundant_experts) + self.eplb_config.num_redundant_experts = self.num_redundant_experts logger.warning_once( "num_redundant_experts is deprecated and has been replaced " "with eplb_config.num_redundant_experts. This will be removed " "in v0.12.0. Changing this field after initialization will " - "have no effect.") + "have no effect." + ) if self.eplb_window_size is not None: self.eplb_config.window_size = self.eplb_window_size logger.warning_once( "eplb_window_size is deprecated and has been replaced " "with eplb_config.window_size. This will be removed " "in v0.12.0. Changing this field after initialization will " - "have no effect.") + "have no effect." + ) if self.eplb_step_interval is not None: self.eplb_config.step_interval = self.eplb_step_interval logger.warning_once( "eplb_step_interval is deprecated and has been replaced " "with eplb_config.step_interval. This will be removed " "in v0.12.0. Changing this field after initialization will " - "have no effect.") + "have no effect." + ) if self.eplb_log_balancedness is not None: self.eplb_config.log_balancedness = self.eplb_log_balancedness logger.warning_once( "eplb_log_balancedness is deprecated and has been replaced " "with eplb_config.log_balancedness. This will be removed " "in v0.12.0. Changing this field after initialization will " - "have no effect.") + "have no effect." 
+ ) # Continue with the rest of the initialization - self.world_size = self.pipeline_parallel_size * \ - self.tensor_parallel_size + self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size if self.distributed_executor_backend == "external_launcher": logger.info("Using external launcher for distributed inference.") @@ -383,26 +386,30 @@ def __post_init__(self) -> None: if self.data_parallel_size_local > self.data_parallel_size: raise ValueError( f"data_parallel_size_local ({self.data_parallel_size_local}) " - f"must be <= data_parallel_size ({self.data_parallel_size})") + f"must be <= data_parallel_size ({self.data_parallel_size})" + ) if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. if self.distributed_executor_backend == "external_launcher": # For external launcher, # we need to set the data parallel rank automatically - self.data_parallel_rank = int(os.environ["RANK"]) \ - // (self.world_size // self.data_parallel_size) - logger.info("Set data_parallel_rank to %d automatically.", - self.data_parallel_rank) + self.data_parallel_rank = int(os.environ["RANK"]) // ( + self.world_size // self.data_parallel_size + ) + logger.info( + "Set data_parallel_rank to %d automatically.", + self.data_parallel_rank, + ) if not self._data_parallel_master_port_list: self._data_parallel_master_port_list = get_open_ports_list(5) - self.data_parallel_master_port = \ - self._data_parallel_master_port_list.pop() + self.data_parallel_master_port = self._data_parallel_master_port_list.pop() if not (0 <= self.data_parallel_rank < self.data_parallel_size): raise ValueError( f"data_parallel_rank ({self.data_parallel_rank})" - f" must be in the range [0, {self.data_parallel_size})") + f" must be in the range [0, {self.data_parallel_size})" + ) else: # Otherwise fall back to env vars (e.g. for offline SPMD case). self.data_parallel_size = envs.VLLM_DP_SIZE @@ -412,8 +419,10 @@ def __post_init__(self) -> None: self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT if self.data_parallel_external_lb: - raise ValueError("data_parallel_external_lb can only " - "be set when data_parallel_size > 1") + raise ValueError( + "data_parallel_external_lb can only " + "be set when data_parallel_size > 1" + ) if self.distributed_executor_backend == "external_launcher": os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" @@ -423,14 +432,15 @@ def __post_init__(self) -> None: if not current_platform.is_cuda(): raise ValueError( "Expert parallelism load balancing is only supported on " - "CUDA devices now.") + "CUDA devices now." + ) if self.eplb_config.num_redundant_experts < 0: raise ValueError( "num_redundant_experts must be non-negative, but got " - f"{self.eplb_config.num_redundant_experts}.") + f"{self.eplb_config.num_redundant_experts}." + ) if not self.enable_expert_parallel: - raise ValueError( - "enable_expert_parallel must be True to use EPLB.") + raise ValueError("enable_expert_parallel must be True to use EPLB.") if self.tensor_parallel_size * self.data_parallel_size <= 1: raise ValueError( "EPLB requires tensor_parallel_size or data_parallel_size " @@ -443,41 +453,50 @@ def __post_init__(self) -> None: "num_redundant_experts is set to " f"{self.eplb_config.num_redundant_experts} but EPLB is not " "enabled. Either enable EPLB or unset " - "num_redundant_experts.") + "num_redundant_experts." 
+ ) if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. from vllm.executor import ray_utils + backend: DistributedExecutorBackend = "mp" ray_found = ray_utils.ray_is_available() if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD: backend = "uni" - elif (current_platform.is_cuda() - and cuda_device_count_stateless() < self.world_size): + elif ( + current_platform.is_cuda() + and cuda_device_count_stateless() < self.world_size + ): if not ray_found: - raise ValueError("Unable to load Ray: " - f"{ray_utils.ray_import_err}. Ray is " - "required for multi-node inference, " - "please install Ray with `pip install " - "ray`.") + raise ValueError( + "Unable to load Ray: " + f"{ray_utils.ray_import_err}. Ray is " + "required for multi-node inference, " + "please install Ray with `pip install " + "ray`." + ) backend = "ray" elif self.data_parallel_backend == "ray": - logger.info("Using ray distributed inference because " - "data_parallel_backend is ray") + logger.info( + "Using ray distributed inference because " + "data_parallel_backend is ray" + ) backend = "ray" elif ray_found: if self.placement_group: backend = "ray" else: from ray import is_initialized as ray_is_initialized + if ray_is_initialized(): from ray.util import get_current_placement_group + if get_current_placement_group(): backend = "ray" self.distributed_executor_backend = backend - logger.debug("Defaulting to use %s for distributed inference", - backend) + logger.debug("Defaulting to use %s for distributed inference", backend) if self.distributed_executor_backend is None and self.world_size == 1: self.distributed_executor_backend = "uni" @@ -486,39 +505,50 @@ def __post_init__(self) -> None: raise ValueError( "Invalid value of `_api_process_rank`. " f"Expected to be `-1` or `[0, {self._api_process_count})`, " - f"but found: {self._api_process_rank}") + f"but found: {self._api_process_rank}" + ) @property def use_ray(self) -> bool: return self.distributed_executor_backend == "ray" or ( isinstance(self.distributed_executor_backend, type) - and getattr(self.distributed_executor_backend, "uses_ray", False)) + and getattr(self.distributed_executor_backend, "uses_ray", False) + ) - @model_validator(mode='after') + @model_validator(mode="after") def _verify_args(self) -> Self: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase from vllm.platforms import current_platform - if self.distributed_executor_backend is not None and not isinstance( - self.distributed_executor_backend, str) and not (isinstance( - self.distributed_executor_backend, type) and issubclass( - self.distributed_executor_backend, ExecutorBase)): + + if ( + self.distributed_executor_backend is not None + and not isinstance(self.distributed_executor_backend, str) + and not ( + isinstance(self.distributed_executor_backend, type) + and issubclass(self.distributed_executor_backend, ExecutorBase) + ) + ): raise ValueError( "Unrecognized distributed executor backend " f"{self.distributed_executor_backend}. Supported " "values are 'ray', 'mp' 'uni', 'external_launcher', " - " custom ExecutorBase subclass or its import path.") + " custom ExecutorBase subclass or its import path." 
+ ) if self.use_ray: from vllm.executor import ray_utils + ray_utils.assert_ray_available() if not current_platform.use_custom_allreduce(): self.disable_custom_all_reduce = True logger.debug( "Disabled the custom all-reduce kernel because it is not " - "supported on current platform.") + "supported on current platform." + ) if self.ray_workers_use_nsight and not self.use_ray: - raise ValueError("Unable to use nsight profiling unless workers " - "run with Ray.") + raise ValueError( + "Unable to use nsight profiling unless workers run with Ray." + ) return self diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py index 85b5a1ace85f..8b10992faa02 100644 --- a/vllm/config/pooler.py +++ b/vllm/config/pooler.py @@ -92,6 +92,5 @@ def compute_hash(self) -> str: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 1b0a10d3a069..396258aac287 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -11,9 +11,11 @@ from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS, - MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, - POOLING_MODEL_MAX_NUM_BATCHED_TOKENS) +from vllm.utils import ( + DEFAULT_MAX_NUM_BATCHED_TOKENS, + MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, +) logger = init_logger(__name__) @@ -164,8 +166,7 @@ def compute_hash(self) -> str: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self, is_encoder_decoder: bool) -> None: @@ -183,7 +184,8 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: self.long_prefill_token_threshold = 0 logger.info( "Encoder-decoder models do not support chunked prefill nor" - " prefix caching; disabling both.") + " prefix caching; disabling both." + ) if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: @@ -193,7 +195,8 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value # for higher throughput. self.max_num_batched_tokens = max( - self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS + ) if self.runner_type == "pooling": # Choose specific value for higher throughput @@ -212,8 +215,8 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: # Ensure max_num_batched_tokens does not exceed model limit. # Some models (e.g., Whisper) have embeddings tied to max length. 
self.max_num_batched_tokens = min( - self.max_num_seqs * self.max_model_len, - self.max_num_batched_tokens) + self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens + ) self.max_num_encoder_input_tokens = self.max_num_batched_tokens self.encoder_cache_size = self.max_num_batched_tokens @@ -221,20 +224,22 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: if self.enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", - self.max_num_batched_tokens) + self.max_num_batched_tokens, + ) self.chunked_prefill_enabled = self.enable_chunked_prefill if self.max_num_partial_prefills > 1: if self.long_prefill_token_threshold == 0: - self.long_prefill_token_threshold = int(self.max_model_len * - 0.04) + self.long_prefill_token_threshold = int(self.max_model_len * 0.04) logger.info( "Concurrent partial prefills enabled with " "max_num_partial_prefills=%d, max_long_partial_prefills=%d, " "long_prefill_token_threshold=%d", - self.max_num_partial_prefills, self.max_long_partial_prefills, - self.long_prefill_token_threshold) + self.max_num_partial_prefills, + self.max_long_partial_prefills, + self.long_prefill_token_threshold, + ) # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)]. # This avoids OOM in tight memory scenarios with small max_num_seqs, @@ -244,61 +249,71 @@ def __post_init__(self, is_encoder_decoder: bool) -> None: self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)] if self.async_scheduling: - self.scheduler_cls = ( - "vllm.v1.core.sched.async_scheduler.AsyncScheduler") + self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler" - @model_validator(mode='after') + @model_validator(mode="after") def _verify_args(self) -> Self: - if (self.max_num_batched_tokens < self.max_model_len - and not self.chunked_prefill_enabled): + if ( + self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled + ): raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " f"smaller than max_model_len ({self.max_model_len}). " "This effectively limits the maximum sequence length to " "max_num_batched_tokens and makes vLLM reject longer " "sequences. Please increase max_num_batched_tokens or " - "decrease max_model_len.") + "decrease max_model_len." + ) if self.max_num_batched_tokens < self.max_num_seqs: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " - f"({self.max_num_seqs}).") + f"({self.max_num_seqs})." + ) if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: logger.warning( "max_num_batched_tokens (%d) exceeds max_num_seqs " "* max_model_len (%d). This may lead to unexpected behavior.", self.max_num_batched_tokens, - self.max_num_seqs * self.max_model_len) + self.max_num_seqs * self.max_model_len, + ) if self.num_lookahead_slots < 0: raise ValueError( "num_lookahead_slots " f"({self.num_lookahead_slots}) must be greater than or " - "equal to 0.") + "equal to 0." + ) if self.max_num_partial_prefills < 1: raise ValueError( f"max_num_partial_prefills ({self.max_num_partial_prefills}) " - "must be greater than or equal to 1.") + "must be greater than or equal to 1." + ) elif self.max_num_partial_prefills > 1: if not self.chunked_prefill_enabled: - raise ValueError("Chunked prefill must be enabled to set " - "max_num_partial_prefills > 1.") + raise ValueError( + "Chunked prefill must be enabled to set " + "max_num_partial_prefills > 1." 
+ ) if self.long_prefill_token_threshold > self.max_model_len: raise ValueError( "long_prefill_token_threshold " f"({self.long_prefill_token_threshold}) cannot be greater " - f"than the max_model_len ({self.max_model_len}).") + f"than the max_model_len ({self.max_model_len})." + ) - if (self.max_long_partial_prefills - < 1) or (self.max_long_partial_prefills - > self.max_num_partial_prefills): + if (self.max_long_partial_prefills < 1) or ( + self.max_long_partial_prefills > self.max_num_partial_prefills + ): raise ValueError( f"max_long_partial_prefills ({self.max_long_partial_prefills}) " "must be greater than or equal to 1 and less than or equal to " - f"max_num_partial_prefills ({self.max_num_partial_prefills}).") + f"max_num_partial_prefills ({self.max_num_partial_prefills})." + ) return self diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index d5c6d1d4d866..aa0c07cf62a3 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -24,23 +24,41 @@ PretrainedConfig = Any ModelConfig = Any - me_quant = LazyLoader("model_executor", globals(), - "vllm.model_executor.layers.quantization") + me_quant = LazyLoader( + "model_executor", globals(), "vllm.model_executor.layers.quantization" + ) logger = init_logger(__name__) -SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", - "mlp_speculator", "draft_model", "deepseek_mtp", - "ernie_mtp", "qwen3_next_mtp", "mimo_mtp", - "longcat_flash_mtp", "mtp"] -MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp", - "qwen3_next_mtp", "longcat_flash_mtp") +SpeculativeMethod = Literal[ + "ngram", + "eagle", + "eagle3", + "medusa", + "mlp_speculator", + "draft_model", + "deepseek_mtp", + "ernie_mtp", + "qwen3_next_mtp", + "mimo_mtp", + "longcat_flash_mtp", + "mtp", +] +MTP_MODEL_TYPES = ( + "deepseek_mtp", + "mimo_mtp", + "glm4_moe_mtp", + "ernie_mtp", + "qwen3_next_mtp", + "longcat_flash_mtp", +) @config @dataclass class SpeculativeConfig: """Configuration for speculative decoding.""" + enforce_eager: Optional[bool] = None """Override the default enforce_eager from model_config""" # General speculative decoding control @@ -107,8 +125,7 @@ class SpeculativeConfig: # required configuration params passed from engine target_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the target model.""" - target_parallel_config: SkipValidation[ - ParallelConfig] = None # type: ignore + target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore """The parallel configuration for the target model.""" enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """Whether vLLM is configured to use chunked prefill or not. Used for @@ -120,8 +137,7 @@ class SpeculativeConfig: # params generated in the post-init stage draft_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the draft model initialized internal.""" - draft_parallel_config: SkipValidation[ - ParallelConfig] = None # type: ignore + draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" def compute_hash(self) -> str: @@ -140,8 +156,7 @@ def compute_hash(self) -> str: # Eagle3 affects the computation graph because it returns intermediate # hidden states in addition to the final hidden state. 
factors.append(self.method == "eagle3") - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @staticmethod @@ -150,58 +165,57 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.model_type = "deepseek_mtp" if hf_config.model_type == "deepseek_mtp": n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["DeepSeekMTPModel"] - }) + hf_config.update( + {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]} + ) if hf_config.architectures[0] == "MiMoForCausalLM": hf_config.model_type = "mimo_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "num_hidden_layers": 0, - "n_predict": n_predict, - "architectures": ["MiMoMTPModel"] - }) + hf_config.update( + { + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["MiMoMTPModel"], + } + ) if hf_config.architectures[0] == "Glm4MoeForCausalLM": hf_config.model_type = "glm4_moe_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "num_hidden_layers": 0, - "n_predict": n_predict, - "architectures": ["Glm4MoeMTPModel"] - }) + hf_config.update( + { + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["Glm4MoeMTPModel"], + } + ) if hf_config.model_type == "ernie4_5_moe": hf_config.model_type = "ernie_mtp" if hf_config.model_type == "ernie_mtp": n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["ErnieMTPModel"] - }) + hf_config.update( + {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]} + ) if hf_config.model_type == "qwen3_next": hf_config.model_type = "qwen3_next_mtp" if hf_config.model_type == "qwen3_next_mtp": n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["Qwen3NextMTP"] - }) + hf_config.update( + {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]} + ) if hf_config.model_type == "longcat_flash": hf_config.model_type = "longcat_flash_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["LongCatFlashMTPModel"] - }) + hf_config.update( + {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} + ) return hf_config def __post_init__(self): - # Note: "method" is a new parameter that helps to extend the # configuration of non-model-based proposers, and the "model" parameter # will be used to set the draft model, eagle head, or additional weight @@ -211,17 +225,17 @@ def __post_init__(self): # default. 
if self.method in MTP_MODEL_TYPES: - logger.warning("method `%s` is deprecated and replaced with mtp.", - self.method) + logger.warning( + "method `%s` is deprecated and replaced with mtp.", self.method + ) self.method = "mtp" if self.model is None and self.num_speculative_tokens is not None: if self.method == "mtp": - assert ( - self.target_model_config - is not None), "target_model_config must be present for mtp" - if self.target_model_config.hf_text_config.model_type \ - == "deepseek_v32": + assert self.target_model_config is not None, ( + "target_model_config must be present for mtp" + ) + if self.target_model_config.hf_text_config.model_type == "deepseek_v32": # FIXME(luccafong): cudgraph with v32 MTP is not supported, # remove this when the issue is fixed. self.enforce_eager = True @@ -235,21 +249,21 @@ def __post_init__(self): self.model = "ngram" else: raise ValueError( - "num_speculative_tokens was provided but without " - "speculative model.") + "num_speculative_tokens was provided but without speculative model." + ) # Automatically configure the method for ngram when "model" is used # instead of "method" - if self.method is None and (self.model is not None - and self.model in ("ngram", "[ngram]")): + if self.method is None and ( + self.model is not None and self.model in ("ngram", "[ngram]") + ): self.method = "ngram" if self.method in ("ngram", "[ngram]"): # Unified to "ngram" internally self.method = "ngram" # Set default values if not provided - if (self.prompt_lookup_min is None - and self.prompt_lookup_max is None): + if self.prompt_lookup_min is None and self.prompt_lookup_max is None: # TODO(woosuk): Tune these values. They are arbitrarily chosen. self.prompt_lookup_min = 5 self.prompt_lookup_max = 5 @@ -263,14 +277,17 @@ def __post_init__(self): # Validate values if self.prompt_lookup_min < 1: raise ValueError( - f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0" + ) if self.prompt_lookup_max < 1: raise ValueError( - f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0" + ) if self.prompt_lookup_min > self.prompt_lookup_max: raise ValueError( f"prompt_lookup_min={self.prompt_lookup_min} must " - f"be <= prompt_lookup_max={self.prompt_lookup_max}") + f"be <= prompt_lookup_max={self.prompt_lookup_max}" + ) # TODO: current we still need extract vocab_size from target model # config, in future, we may try refactor it out, and set @@ -285,25 +302,21 @@ def __post_init__(self): # TODO: Move this import to the top once `ModelConfig` # lives in `vllm.config.model`. from vllm.config import ModelConfig + self.draft_model_config = ModelConfig( model=self.model, runner="draft", tokenizer=self.target_model_config.tokenizer, tokenizer_mode=self.target_model_config.tokenizer_mode, - trust_remote_code=self.target_model_config. - trust_remote_code, - allowed_local_media_path=self.target_model_config. - allowed_local_media_path, - allowed_media_domains=self.target_model_config. - allowed_media_domains, + trust_remote_code=self.target_model_config.trust_remote_code, + allowed_local_media_path=self.target_model_config.allowed_local_media_path, + allowed_media_domains=self.target_model_config.allowed_media_domains, dtype=self.target_model_config.dtype, seed=self.target_model_config.seed, revision=self.revision, code_revision=self.code_revision, - tokenizer_revision=self.target_model_config. 
- tokenizer_revision, - spec_target_max_model_len=self.target_model_config. - max_model_len, + tokenizer_revision=self.target_model_config.tokenizer_revision, + spec_target_max_model_len=self.target_model_config.max_model_len, quantization=self.quantization, enforce_eager=self.target_model_config.enforce_eager, max_logprobs=self.target_model_config.max_logprobs, @@ -311,7 +324,7 @@ def __post_init__(self): ) # Automatically detect the method - if self.method in ('eagle', 'eagle3'): + if self.method in ("eagle", "eagle3"): pass # examples: # yuhuili/EAGLE-LLaMA3-Instruct-8B @@ -323,94 +336,101 @@ def __post_init__(self): self.method = "eagle3" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" - elif (self.draft_model_config.hf_config.model_type == - "mlp_speculator"): + elif self.draft_model_config.hf_config.model_type == "mlp_speculator": self.method = "mlp_speculator" - elif (self.draft_model_config.hf_config.model_type - in MTP_MODEL_TYPES): + elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES: self.method = "mtp" if self.num_speculative_tokens > 1: logger.warning( - "Enabling num_speculative_tokens > 1 will run" \ - "multiple times of forward on same MTP layer" \ - ",which may result in lower acceptance rate" \ - ) - elif (self.draft_model_config.hf_config.model_type - in ("longcat_flash_mtp")): + "Enabling num_speculative_tokens > 1 will run" + "multiple times of forward on same MTP layer" + ",which may result in lower acceptance rate" + ) + elif self.draft_model_config.hf_config.model_type in ( + "longcat_flash_mtp" + ): self.method = "longcat_flash_mtp" if self.num_speculative_tokens > 1: logger.warning( - "LongCat MTP models only have " \ - "one layer. Might need some code changes " \ - "to support multiple layers." - ) + "LongCat MTP models only have " + "one layer. Might need some code changes " + "to support multiple layers." + ) else: self.method = "draft_model" raise NotImplementedError( "Speculative decoding with draft model is not " "supported yet. Please consider using other " "speculative decoding methods such as ngram, medusa, " - "eagle, or mtp.") + "eagle, or mtp." + ) # Replace hf_config for EAGLE draft_model if self.method in ("eagle", "eagle3"): if self.enable_chunked_prefill and not envs.VLLM_USE_V1: raise ValueError( "Chunked prefill and EAGLE are not compatible " - "when using V0.") + "when using V0." 
+ ) - from vllm.transformers_utils.configs import ( - SpeculatorsConfig) - from vllm.transformers_utils.configs.eagle import ( - EAGLEConfig) + from vllm.transformers_utils.configs import SpeculatorsConfig + from vllm.transformers_utils.configs.eagle import EAGLEConfig - if isinstance(self.draft_model_config.hf_config, - (EAGLEConfig, SpeculatorsConfig)): + if isinstance( + self.draft_model_config.hf_config, + (EAGLEConfig, SpeculatorsConfig), + ): pass else: eagle_config = EAGLEConfig( self.draft_model_config.hf_config, method=self.method, - model_type="eagle") + model_type="eagle", + ) self.draft_model_config.hf_config = eagle_config - if (self.num_speculative_tokens is not None - and hasattr(self.draft_model_config.hf_config, - "num_lookahead_tokens")): - self.draft_model_config.hf_config.num_lookahead_tokens = \ - self.num_speculative_tokens + if self.num_speculative_tokens is not None and hasattr( + self.draft_model_config.hf_config, "num_lookahead_tokens" + ): + self.draft_model_config.hf_config.num_lookahead_tokens = ( + self.num_speculative_tokens + ) - n_predict = getattr(self.draft_model_config.hf_config, - "n_predict", None) + n_predict = getattr( + self.draft_model_config.hf_config, "n_predict", None + ) if n_predict is not None: if self.num_speculative_tokens is None: # Default to max value defined in draft model config. self.num_speculative_tokens = n_predict - elif self.num_speculative_tokens > n_predict and \ - self.num_speculative_tokens % n_predict != 0: + elif ( + self.num_speculative_tokens > n_predict + and self.num_speculative_tokens % n_predict != 0 + ): # Ensure divisibility for MTP module reuse. raise ValueError( f"num_speculative_tokens:{self.num_speculative_tokens}" - f" must be divisible by {n_predict=}") + f" must be divisible by {n_predict=}" + ) if self.speculative_token_tree is None: # Generate chain of tokens. - self.speculative_token_tree = str([ - (i + 1) * (0, ) - for i in range(self.num_speculative_tokens) - ]) + self.speculative_token_tree = str( + [(i + 1) * (0,) for i in range(self.num_speculative_tokens)] + ) else: # Sort the token tree breadth-first. 
- tree_choices = ast.literal_eval( - self.speculative_token_tree) + tree_choices = ast.literal_eval(self.speculative_token_tree) self.speculative_token_tree = str( - sorted(tree_choices, key=lambda t: (len(t), t))) + sorted(tree_choices, key=lambda t: (len(t), t)) + ) - self.draft_tensor_parallel_size = \ + self.draft_tensor_parallel_size = ( SpeculativeConfig._verify_and_get_draft_tp( self.target_parallel_config, self.draft_tensor_parallel_size, - self.draft_model_config.hf_config + self.draft_model_config.hf_config, + ) ) self.draft_model_config.max_model_len = ( @@ -418,12 +438,14 @@ def __post_init__(self): self.max_model_len, self.draft_model_config.max_model_len, self.target_model_config.max_model_len, - )) + ) + ) self.draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( - self.target_parallel_config, - self.draft_tensor_parallel_size)) + self.target_parallel_config, self.draft_tensor_parallel_size + ) + ) @staticmethod def _maybe_override_draft_max_model_len( @@ -444,14 +466,17 @@ def _maybe_override_draft_max_model_len( """ if speculative_max_model_len is not None: - if speculative_max_model_len > draft_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {draft_max_model_len=}") + raise ValueError( + f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}" + ) if speculative_max_model_len > target_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {target_max_model_len=}") + raise ValueError( + f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}" + ) return speculative_max_model_len @@ -462,9 +487,10 @@ def _maybe_override_draft_max_model_len( @staticmethod def _verify_and_get_draft_tp( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], - draft_hf_config: PretrainedConfig) -> int: + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig, + ) -> int: """ Verifies and adjusts the tensor parallel size for a draft model specified using speculative_draft_tensor_parallel_size. @@ -478,15 +504,20 @@ def _verify_and_get_draft_tp( logger.warning( "%s cannot currently be run with tp>1; " "setting speculative_draft_tensor_parallel_size=1", - draft_hf_config.model_type) + draft_hf_config.model_type, + ) else: - speculative_draft_tensor_parallel_size = \ + speculative_draft_tensor_parallel_size = ( target_parallel_config.tensor_parallel_size + ) elif speculative_draft_tensor_parallel_size not in ( - 1, target_parallel_config.tensor_parallel_size): + 1, + target_parallel_config.tensor_parallel_size, + ): raise ValueError( f"{speculative_draft_tensor_parallel_size=} cannot be " - f"other value than 1 or target model tensor_parallel_size") + f"other value than 1 or target model tensor_parallel_size" + ) return speculative_draft_tensor_parallel_size @staticmethod @@ -499,52 +530,57 @@ def create_draft_parallel_config( This is mostly a copy of the target parallel config, except the tp_size. """ draft_parallel_config = ParallelConfig( - pipeline_parallel_size=target_parallel_config. - pipeline_parallel_size, + pipeline_parallel_size=target_parallel_config.pipeline_parallel_size, tensor_parallel_size=speculative_draft_tensor_parallel_size, - distributed_executor_backend=target_parallel_config. - distributed_executor_backend, - max_parallel_loading_workers=target_parallel_config. 
- max_parallel_loading_workers, - disable_custom_all_reduce=target_parallel_config. - disable_custom_all_reduce, - ray_workers_use_nsight=target_parallel_config. - ray_workers_use_nsight, + distributed_executor_backend=target_parallel_config.distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce, + ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight, placement_group=target_parallel_config.placement_group, ) return draft_parallel_config - @model_validator(mode='after') + @model_validator(mode="after") def _verify_args(self) -> Self: if self.num_speculative_tokens is None: raise ValueError( "num_speculative_tokens must be provided with " "speculative model unless the draft model config contains an " - "n_predict parameter.") + "n_predict parameter." + ) if self.num_speculative_tokens <= 0: - raise ValueError("Expected num_speculative_tokens to be greater " - f"than zero ({self.num_speculative_tokens}).") + raise ValueError( + "Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens})." + ) if self.draft_model_config: self.draft_model_config.verify_with_parallel_config( - self.draft_parallel_config) + self.draft_parallel_config + ) - if (self.disable_by_batch_size is not None - and self.disable_by_batch_size < 2): - raise ValueError("Expect the batch size threshold of disabling " - "speculative decoding is > 1, but got " - f"{self.disable_by_batch_size=}") + if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2: + raise ValueError( + "Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{self.disable_by_batch_size=}" + ) eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"] - if self.method == "eagle3" and self.target_model_config and not any( - supported_model in - self.target_model_config.hf_text_config.model_type - for supported_model in eagle3_target_supported): + if ( + self.method == "eagle3" + and self.target_model_config + and not any( + supported_model in self.target_model_config.hf_text_config.model_type + for supported_model in eagle3_target_supported + ) + ): raise ValueError( f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 - f"Got {self.target_model_config.hf_text_config.model_type=}") + f"Got {self.target_model_config.hf_text_config.model_type=}" + ) return self diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index b1f14294510f..5111c9c77d90 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -8,8 +8,9 @@ from vllm.config.utils import config -StructuredOutputsBackend = Literal["auto", "xgrammar", "guidance", "outlines", - "lm-format-enforcer"] +StructuredOutputsBackend = Literal[ + "auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer" +] @config @@ -50,15 +51,17 @@ def compute_hash(self) -> str: # no factors to consider. # this config will not affect the computation graph. 
factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self): - if (self.disable_any_whitespace - and self.backend not in ("xgrammar", "guidance")): - raise ValueError("disable_any_whitespace is only supported for " - "xgrammar and guidance backends.") - if (self.disable_additional_properties and self.backend != "guidance"): - raise ValueError("disable_additional_properties is only supported " - "for the guidance backend.") + if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"): + raise ValueError( + "disable_any_whitespace is only supported for " + "xgrammar and guidance backends." + ) + if self.disable_additional_properties and self.backend != "guidance": + raise ValueError( + "disable_additional_properties is only supported " + "for the guidance backend." + ) diff --git a/vllm/config/utils.py b/vllm/config/utils.py index d355ff3a9023..889ebf45b12d 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utility functions for vLLM config dataclasses.""" + import ast import inspect import textwrap @@ -50,7 +51,8 @@ def get_field(cls: ConfigType, name: str) -> Field: if (default := named_field.default) is not MISSING: return field(default=default) raise ValueError( - f"{cls.__name__}.{name} must have a default value or default factory.") + f"{cls.__name__}.{name} must have a default value or default factory." + ) def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any: @@ -78,7 +80,7 @@ def contains_object_print(text: str) -> bool: Returns: result (bool): `True` if a match is found, `False` otherwise. """ - pattern = r'at 0x[a-fA-F0-9]{2,16}>' + pattern = r"at 0x[a-fA-F0-9]{2,16}>" match = re.search(pattern, text) return match is not None @@ -89,7 +91,8 @@ def assert_hashable(text: str) -> bool: raise AssertionError( f"vLLM tried to hash some configs that may have Python objects ids " f"in them. This is a bug, please file an issue. " - f"Text being hashed: {text}") + f"Text being hashed: {text}" + ) def get_attr_docs(cls: type[Any]) -> dict[str, str]: @@ -132,10 +135,12 @@ def pairwise(iterable): # Consider each pair of nodes. for a, b in pairwise(cls_node.body): # Must be an assignment then a constant string. - if (not isinstance(a, (ast.Assign, ast.AnnAssign)) - or not isinstance(b, ast.Expr) - or not isinstance(b.value, ast.Constant) - or not isinstance(b.value.value, str)): + if ( + not isinstance(a, (ast.Assign, ast.AnnAssign)) + or not isinstance(b, ast.Expr) + or not isinstance(b.value, ast.Constant) + or not isinstance(b.value.value, str) + ): continue doc = inspect.cleandoc(b.value.value) @@ -160,29 +165,28 @@ def is_init_field(cls: ConfigType, name: str) -> bool: @runtime_checkable class SupportsHash(Protocol): - - def compute_hash(self) -> str: - ... + def compute_hash(self) -> str: ... class SupportsMetricsInfo(Protocol): - - def metrics_info(self) -> dict[str, str]: - ... + def metrics_info(self) -> dict[str, str]: ... 
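The single-line `...` bodies above are the compact stub style ruff format applies to methods whose body is only `...`. A self-contained sketch of how a `runtime_checkable` protocol of this shape behaves at runtime (the dummy class below is illustrative, not part of vLLM):

    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class SupportsHash(Protocol):
        def compute_hash(self) -> str: ...

    class DummyConfig:
        def compute_hash(self) -> str:
            return "0" * 10

    # Structural isinstance() check: only the presence of compute_hash is
    # verified, not its signature or return type.
    assert isinstance(DummyConfig(), SupportsHash)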
def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT: processed_overrides = {} for field_name, value in overrides.items(): - assert hasattr( - config, field_name), f"{type(config)} has no field `{field_name}`" + assert hasattr(config, field_name), ( + f"{type(config)} has no field `{field_name}`" + ) current_value = getattr(config, field_name) if is_dataclass(current_value) and not is_dataclass(value): assert isinstance(value, dict), ( f"Overrides to {type(config)}.{field_name} must be a dict" - f" or {type(current_value)}, but got {type(value)}") + f" or {type(current_value)}, but got {type(value)}" + ) value = update_config( current_value, # type: ignore[type-var] - value) + value, + ) processed_overrides[field_name] = value return replace(config, **processed_overrides) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ac40b0fd4783..b5856958ce2e 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -38,8 +38,7 @@ if TYPE_CHECKING: from transformers import PretrainedConfig - from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig else: PretrainedConfig = Any @@ -74,14 +73,14 @@ class VllmConfig: speculative_config: Optional[SpeculativeConfig] = None """Speculative decoding configuration.""" structured_outputs_config: StructuredOutputsConfig = field( - default_factory=StructuredOutputsConfig) + default_factory=StructuredOutputsConfig + ) """Structured outputs configuration.""" observability_config: Optional[ObservabilityConfig] = None """Observability configuration.""" quant_config: Optional[QuantizationConfig] = None """Quantization configuration.""" - compilation_config: CompilationConfig = field( - default_factory=CompilationConfig) + compilation_config: CompilationConfig = field(default_factory=CompilationConfig) """`torch.compile` and cudagraph capture configuration for the model. As a shorthand, `-O<n>` can be used to directly specify the compilation @@ -127,6 +126,7 @@ def compute_hash(self) -> str: # summarize vllm config vllm_factors: list[Any] = [] from vllm import __version__ + vllm_factors.append(__version__) vllm_factors.append(envs.VLLM_USE_V1) if self.model_config: @@ -158,8 +158,7 @@ def compute_hash(self) -> str: # LoRA creates static buffers based on max_num_batched_tokens. # The tensor sizes and strides get captured in the torch.compile # graph explicitly. 
- vllm_factors.append( - str(self.scheduler_config.max_num_batched_tokens)) + vllm_factors.append(str(self.scheduler_config.max_num_batched_tokens)) else: vllm_factors.append("None") if self.speculative_config: @@ -197,8 +196,9 @@ def compute_hash(self) -> str: vllm_factors.append("None") factors.append(vllm_factors) - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] + hash_str = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] return hash_str def pad_for_cudagraph(self, batch_size: int) -> int: @@ -210,13 +210,14 @@ def pad_for_cudagraph(self, batch_size: int) -> int: @staticmethod def _get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: + model_config: ModelConfig, load_config: LoadConfig + ) -> Optional[QuantizationConfig]: """Get the quantization config.""" from vllm.platforms import current_platform + if model_config.quantization is not None: - from vllm.model_executor.model_loader.weight_utils import ( - get_quant_config) + from vllm.model_executor.model_loader.weight_utils import get_quant_config + quant_config = get_quant_config(model_config, load_config) capability_tuple = current_platform.get_device_capability() @@ -227,27 +228,30 @@ def _get_quantization_config( f"The quantization method {model_config.quantization} " "is not supported for the current GPU. Minimum " f"capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + f"Current capability: {capability}." + ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") + f"{supported_dtypes}" + ) quant_config.maybe_update_config(model_config.model) return quant_config return None @staticmethod def get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: + model_config: ModelConfig, load_config: LoadConfig + ) -> Optional[QuantizationConfig]: import copy # For some reason, the _ version of this modifies the model_config # object, so using deepcopy to avoid this problem. - return VllmConfig._get_quantization_config(copy.deepcopy(model_config), - load_config) + return VllmConfig._get_quantization_config( + copy.deepcopy(model_config), load_config + ) def with_hf_config( self, @@ -264,15 +268,13 @@ def with_hf_config( return replace(self, model_config=model_config) def __post_init__(self): - """Verify configs are valid & consistent with each other. 
- """ + """Verify configs are valid & consistent with each other.""" self.try_verify_and_update_config() if self.model_config is not None: self.model_config.verify_with_parallel_config(self.parallel_config) - self.model_config.verify_dual_chunk_attention_config( - self.load_config) + self.model_config.verify_dual_chunk_attention_config(self.load_config) self.cache_config.verify_with_parallel_config(self.parallel_config) @@ -282,29 +284,35 @@ def __post_init__(self): if self.quant_config is None and self.model_config is not None: self.quant_config = VllmConfig._get_quantization_config( - self.model_config, self.load_config) + self.model_config, self.load_config + ) from vllm.platforms import current_platform - if self.model_config is not None and \ - self.scheduler_config.chunked_prefill_enabled and \ - self.model_config.dtype == torch.float32 and \ - current_platform.get_device_capability() == (7, 5): + + if ( + self.model_config is not None + and self.scheduler_config.chunked_prefill_enabled + and self.model_config.dtype == torch.float32 + and current_platform.get_device_capability() == (7, 5) + ): logger.warning_once( "Turing devices tensor cores do not support float32 matmul. " "To workaround this limitation, vLLM will set 'ieee' input " - "precision for chunked prefill triton kernels.") + "precision for chunked prefill triton kernels." + ) # If the user does not explicitly set a compilation level, then # we use the default level. The default level depends on other # settings (see the below code). if self.compilation_config.level is None: if envs.VLLM_USE_V1: - if (self.model_config is not None - and not self.model_config.enforce_eager): + if ( + self.model_config is not None + and not self.model_config.enforce_eager + ): self.compilation_config.level = CompilationLevel.PIECEWISE else: - self.compilation_config.level = \ - CompilationLevel.NO_COMPILATION + self.compilation_config.level = CompilationLevel.NO_COMPILATION else: # NB: Passing both --enforce-eager and a compilation level @@ -314,8 +322,7 @@ def __post_init__(self): # async tp is built on top of sequence parallelism # and requires it to be enabled. 
if self.compilation_config.pass_config.enable_async_tp: - self.compilation_config.pass_config.enable_sequence_parallelism = \ - True + self.compilation_config.pass_config.enable_sequence_parallelism = True if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") @@ -323,25 +330,27 @@ def __post_init__(self): # if cudagraph_mode is not explicitly set by users, set default # value if self.compilation_config.cudagraph_mode is None: - if envs.VLLM_USE_V1 and self.compilation_config.level \ - == CompilationLevel.PIECEWISE: + if ( + envs.VLLM_USE_V1 + and self.compilation_config.level == CompilationLevel.PIECEWISE + ): # default to full and piecewise for most models - self.compilation_config.cudagraph_mode = \ + self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE + ) # pooling models and encoder-decoder models # do not support full cudagraphs - if self.model_config is not None and \ - (self.model_config.pooler_config is not None - or self.model_config.is_encoder_decoder): - self.compilation_config.cudagraph_mode = \ - CUDAGraphMode.PIECEWISE + if self.model_config is not None and ( + self.model_config.pooler_config is not None + or self.model_config.is_encoder_decoder + ): + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE else: self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE # disable cudagraph when enforce eager execution - if self.model_config is not None and \ - self.model_config.enforce_eager: + if self.model_config is not None and self.model_config.enforce_eager: logger.info("Cudagraph is disabled under eager mode") self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE elif envs.VLLM_USE_V1: @@ -352,18 +361,21 @@ def __post_init__(self): self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE if self.cache_config.kv_sharing_fast_prefill: - - if self.speculative_config is not None and \ - self.speculative_config.use_eagle(): + if ( + self.speculative_config is not None + and self.speculative_config.use_eagle() + ): raise NotImplementedError( "Fast prefill optimization for KV sharing is not " "compatible with EAGLE as EAGLE requires correct logits " "for all tokens while fast prefill gives incorrect logits " - "for prompt tokens.") + "for prompt tokens." + ) logger.warning_once( "--kv-sharing-fast-prefill requires changes on model side for " - "correctness and to realize prefill savings. ") + "correctness and to realize prefill savings. " + ) disable_chunked_prefill_reasons: list[str] = [] @@ -372,34 +384,51 @@ def __post_init__(self): pooling_type = self.model_config.pooler_config.pooling_type if pooling_type is None or pooling_type.lower() != "last": disable_chunked_prefill_reasons.append( - "Only \"last\" pooling supports chunked " - "prefill and prefix caching; disabling both.") + 'Only "last" pooling supports chunked ' + "prefill and prefix caching; disabling both." + ) if not getattr(self.model_config.hf_config, "is_causal", True): disable_chunked_prefill_reasons.append( "Only models using causal attention supports chunked " - "prefill and prefix caching; disabling both.") + "prefill and prefix caching; disabling both." 
+ ) elif self.model_config.is_encoder_decoder: from vllm.multimodal import MULTIMODAL_REGISTRY - self.scheduler_config.max_num_encoder_input_tokens = \ + + self.scheduler_config.max_num_encoder_input_tokens = ( MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) + ) logger.debug( "Encoder-decoder model detected: setting " "`max_num_encoder_input_tokens` to encoder length (%s)", - self.scheduler_config.max_num_encoder_input_tokens) - if (self.model_config.architecture - == "WhisperForConditionalGeneration" - and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") - != "spawn"): + self.scheduler_config.max_num_encoder_input_tokens, + ) + if ( + self.model_config.architecture == "WhisperForConditionalGeneration" + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn" + ): logger.warning( "Whisper is known to have issues with " "forked workers. If startup is hanging, " "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " - "to 'spawn'.") + "to 'spawn'." + ) - # Disable prefix caching only if chunked prefill is explicitly disabled - # (and not merely unset) - if (self.scheduler_config.chunked_prefill_enabled is False - or disable_chunked_prefill_reasons): + # Final off-switch for CP/APC: + # Disable for (a) collected blockers, (b) encoder–decoder, or + # (c) explicit CP=False when APC wasn't requested. + # Do NOT disable merely because the resolved CP flag is False. + apc_requested = ( + self.cache_config is not None and self.cache_config.enable_prefix_caching + ) + if ( + disable_chunked_prefill_reasons + or (self.model_config is not None and self.model_config.is_encoder_decoder) + or ( + self.scheduler_config.enable_chunked_prefill is False + and not apc_requested + ) + ): for reason in disable_chunked_prefill_reasons: logger.info(reason) self.scheduler_config.chunked_prefill_enabled = False @@ -408,72 +437,85 @@ def __post_init__(self): if self.cache_config is not None: self.cache_config.enable_prefix_caching = False - if (self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events - and not self.cache_config.enable_prefix_caching): + if ( + self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events + and not self.cache_config.enable_prefix_caching + ): logger.warning( "KV cache events are on, but prefix caching is not enabled." - "Use --enable-prefix-caching to enable.") - if (self.kv_events_config is not None - and self.kv_events_config.publisher != "null" - and not self.kv_events_config.enable_kv_cache_events): - logger.warning("KV cache events are disabled," - "but the scheduler is configured to publish them." - "Modify KVEventsConfig.enable_kv_cache_events" - "to True to enable.") + "Use --enable-prefix-caching to enable." + ) + if ( + self.kv_events_config is not None + and self.kv_events_config.publisher != "null" + and not self.kv_events_config.enable_kv_cache_events + ): + logger.warning( + "KV cache events are disabled," + "but the scheduler is configured to publish them." + "Modify KVEventsConfig.enable_kv_cache_events" + "to True to enable." 
+ ) current_platform.check_and_update_config(self) # Do this after all the updates to compilation_config.level - if envs.VLLM_USE_V1 and \ - self.compilation_config.level == CompilationLevel.PIECEWISE: + if ( + envs.VLLM_USE_V1 + and self.compilation_config.level == CompilationLevel.PIECEWISE + ): self.compilation_config.set_splitting_ops_for_v1() # final check of cudagraph mode after all possible updates if envs.VLLM_USE_V1 and current_platform.is_cuda_alike(): - if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\ - and self.model_config is not None and \ - not self.model_config.disable_cascade_attn and\ - not self.compilation_config.cudagraph_mode.\ - has_piecewise_cudagraphs(): + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and self.model_config is not None + and not self.model_config.disable_cascade_attn + and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() # noqa: E501 + ): logger.warning_once( "No piecewise cudagraph for executing cascade attention." " Will fall back to eager execution if a batch runs " - "into cascade attentions") - - if self.compilation_config.cudagraph_mode\ - .requires_piecewise_compilation(): - assert self.compilation_config.level == \ - CompilationLevel.PIECEWISE, \ - "Compilation level should be CompilationLevel.PIECEWISE "\ - "when cudagraph_mode piecewise cudagraphs is used, "\ + "into cascade attentions" + ) + + if self.compilation_config.cudagraph_mode.requires_piecewise_compilation(): + assert self.compilation_config.level == CompilationLevel.PIECEWISE, ( + "Compilation level should be CompilationLevel.PIECEWISE " + "when cudagraph_mode piecewise cudagraphs is used, " f"cudagraph_mode={self.compilation_config.cudagraph_mode}" + ) # final migrate the deprecated flags - self.compilation_config.use_cudagraph = self.compilation_config.\ - cudagraph_mode!= CUDAGraphMode.NONE - self.compilation_config.full_cuda_graph = self.compilation_config.\ - cudagraph_mode.has_full_cudagraphs() + self.compilation_config.use_cudagraph = ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ) + self.compilation_config.full_cuda_graph = ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) if self.parallel_config.enable_dbo: a2a_backend = envs.VLLM_ALL2ALL_BACKEND - assert a2a_backend in \ - ["deepep_low_latency", "deepep_high_throughput"], \ - "Microbatching currently only supports the deepep_low_latency and "\ - f"deepep_high_throughput all2all backend. {a2a_backend} is not "\ - "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\ - "variable to deepep_low_latency or deepep_high_throughput and "\ - "install the DeepEP kernels." + assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( + "Microbatching currently only supports the deepep_low_latency and " + f"deepep_high_throughput all2all backend. {a2a_backend} is not " + "supported. To fix set the VLLM_ALL2ALL_BACKEND environment " + "variable to deepep_low_latency or deepep_high_throughput and " + "install the DeepEP kernels." 
+ ) if not self.model_config.disable_cascade_attn: self.model_config.disable_cascade_attn = True - logger.warning_once( - "Disabling cascade attention when DBO is enabled.") + logger.warning_once("Disabling cascade attention when DBO is enabled.") if not self.instance_id: self.instance_id = random_uuid()[:5] - if (envs.VLLM_USE_V1 - and not self.scheduler_config.disable_hybrid_kv_cache_manager): + if ( + envs.VLLM_USE_V1 + and not self.scheduler_config.disable_hybrid_kv_cache_manager + ): # logger should only print warning message for hybrid models. As we # can't know whether the model is hybrid or not now, so we don't log # warning message here and will log it later. @@ -486,15 +528,18 @@ def __post_init__(self): if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.model_config is not None and \ - self.model_config.attention_chunk_size is not None: - if self.speculative_config is not None and \ - self.speculative_config.use_eagle(): + if ( + self.model_config is not None + and self.model_config.attention_chunk_size is not None + ): + if ( + self.speculative_config is not None + and self.speculative_config.use_eagle() + ): # Hybrid KV cache manager is not yet supported with chunked # local attention + eagle. self.scheduler_config.disable_hybrid_kv_cache_manager = True - elif \ - not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: + elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: logger.warning( "There is a latency regression when using chunked local" " attention with the hybrid KV cache manager. Disabling" @@ -506,14 +551,17 @@ def __post_init__(self): self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.compilation_config.debug_dump_path: - self.compilation_config.debug_dump_path = \ + self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path.absolute().expanduser() + ) if envs.VLLM_DEBUG_DUMP_PATH is not None: env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser() if self.compilation_config.debug_dump_path: logger.warning( "Config-specified debug dump path is overridden" - " by VLLM_DEBUG_DUMP_PATH to %s", env_path) + " by VLLM_DEBUG_DUMP_PATH to %s", + env_path, + ) self.compilation_config.debug_dump_path = env_path def has_blocked_weights(): @@ -533,23 +581,26 @@ def has_blocked_weights(): if "none" not in custom_ops and "-quant_fp8" not in custom_ops: custom_ops.append("+quant_fp8") - def update_sizes_for_sequence_parallelism(self, - possible_sizes: list) -> list: + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when # enable sequence parallelism removed_sizes = [ - size for size in possible_sizes + size + for size in possible_sizes if size % self.parallel_config.tensor_parallel_size != 0 ] if removed_sizes: logger.warning( "Batch sizes %s are removed because they are not " "multiple of tp_size %d when " - "sequence parallelism is enabled", removed_sizes, - self.parallel_config.tensor_parallel_size) + "sequence parallelism is enabled", + removed_sizes, + self.parallel_config.tensor_parallel_size, + ) return [ - size for size in possible_sizes + size + for size in possible_sizes if size % self.parallel_config.tensor_parallel_size == 0 ] @@ -593,13 +644,13 @@ def _set_cudagraph_sizes(self): # calculate the default `batch_size_capture_list` batch_size_capture_list = [] - if self.model_config is not None and \ - not 
self.model_config.enforce_eager: + if self.model_config is not None and not self.model_config.enforce_eager: cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes if len(cuda_graph_sizes) == 1: max_graph_size = cuda_graph_sizes[0] - assert max_graph_size >= 1, "Maximum cudagraph size should be" \ - " greater than or equal to 1." + assert max_graph_size >= 1, ( + "Maximum cudagraph size should be greater than or equal to 1." + ) batch_size_capture_list = [ i for i in [1, 2, 4] if i <= max_graph_size ] + list(range(8, max_graph_size + 1, 8)) @@ -607,18 +658,19 @@ def _set_cudagraph_sizes(self): batch_size_capture_list = sorted(cuda_graph_sizes) else: raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") - if self.parallel_config.tensor_parallel_size > 1 and \ - self.compilation_config.pass_config.enable_sequence_parallelism: - batch_size_capture_list = \ - self.update_sizes_for_sequence_parallelism(batch_size_capture_list) + if ( + self.parallel_config.tensor_parallel_size > 1 + and self.compilation_config.pass_config.enable_sequence_parallelism + ): + batch_size_capture_list = self.update_sizes_for_sequence_parallelism( + batch_size_capture_list + ) max_num_tokens = self.scheduler_config.max_num_batched_tokens batch_size_capture_list = [ - size for size in batch_size_capture_list - if size <= max_num_tokens + size for size in batch_size_capture_list if size <= max_num_tokens ] - self.compilation_config.init_with_cudagraph_sizes( - batch_size_capture_list) + self.compilation_config.init_with_cudagraph_sizes(batch_size_capture_list) def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config @@ -641,7 +693,10 @@ def try_verify_and_update_config(self): return from vllm.model_executor.models.config import ( - MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig) + MODELS_CONFIG_MAP, + HybridAttentionMambaModelConfig, + ) + cls = MODELS_CONFIG_MAP.get(architecture, None) if cls is not None: cls.verify_and_update_config(self) @@ -651,24 +706,29 @@ def try_verify_and_update_config(self): if self.model_config.convert_type == "classify": # Maybe convert ForCausalLM into ForSequenceClassification model. - from vllm.model_executor.models.adapters import ( - SequenceClassificationConfig) + from vllm.model_executor.models.adapters import SequenceClassificationConfig + SequenceClassificationConfig.verify_and_update_config(self) if hasattr(self.model_config, "model_weights") and is_runai_obj_uri( - self.model_config.model_weights): + self.model_config.model_weights + ): if self.load_config.load_format == "auto": - logger.info("Detected Run:ai model config. " - "Overriding `load_format` to 'runai_streamer'") + logger.info( + "Detected Run:ai model config. " + "Overriding `load_format` to 'runai_streamer'" + ) self.load_config.load_format = "runai_streamer" elif self.load_config.load_format != "runai_streamer": - raise ValueError(f"To load a model from S3, 'load_format' " - f"must be 'runai_streamer', " - f"but got '{self.load_config.load_format}'. " - f"Model: {self.model_config.model}") + raise ValueError( + f"To load a model from S3, 'load_format' " + f"must be 'runai_streamer', " + f"but got '{self.load_config.load_format}'. " + f"Model: {self.model_config.model}" + ) def compile_debug_dump_path(self) -> Optional[Path]: - """Returns a rank-aware path for dumping + """Returns a rank-aware path for dumping torch.compile debug information. 
""" if self.compilation_config.debug_dump_path is None: @@ -676,8 +736,11 @@ def compile_debug_dump_path(self) -> Optional[Path]: tp_rank = self.parallel_config.rank dp_rank = self.parallel_config.data_parallel_rank data_parallel_size = self.parallel_config.data_parallel_size - append_path = f"rank_{tp_rank}" if data_parallel_size == 1 \ + append_path = ( + f"rank_{tp_rank}" + if data_parallel_size == 1 else f"rank_{tp_rank}_dp_{dp_rank}" + ) path = self.compilation_config.debug_dump_path / append_path return path @@ -710,7 +773,8 @@ def __str__(self): f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa f"pooler_config={self.model_config.pooler_config!r}, " - f"compilation_config={self.compilation_config!r}") + f"compilation_config={self.compilation_config!r}" + ) _current_vllm_config: Optional[VllmConfig] = None @@ -718,9 +782,9 @@ def __str__(self): @contextmanager -def set_current_vllm_config(vllm_config: VllmConfig, - check_compile=False, - prefix: Optional[str] = None): +def set_current_vllm_config( + vllm_config: VllmConfig, check_compile=False, prefix: Optional[str] = None +): """ Temporarily set the current vLLM config. Used during model initialization. @@ -732,6 +796,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, old_vllm_config = _current_vllm_config old_prefix = _current_prefix from vllm.compilation.counter import compilation_counter + num_models_seen = compilation_counter.num_models_seen try: _current_vllm_config = vllm_config @@ -743,9 +808,11 @@ def set_current_vllm_config(vllm_config: VllmConfig, if check_compile: vllm_config.compilation_config.custom_op_log_check() - if check_compile and \ - vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ - and compilation_counter.num_models_seen == num_models_seen: + if ( + check_compile + and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE + and compilation_counter.num_models_seen == num_models_seen + ): # If the model supports compilation, # compilation_counter.num_models_seen should be increased # by at least 1. @@ -755,7 +822,8 @@ def set_current_vllm_config(vllm_config: VllmConfig, "`torch.compile` is turned on, but the model %s" " does not support it. Please open an issue on GitHub" " if you want it to be supported.", - vllm_config.model_config.model) + vllm_config.model_config.model, + ) finally: _current_vllm_config = old_vllm_config _current_prefix = old_prefix @@ -783,9 +851,10 @@ def get_current_vllm_config() -> VllmConfig: def get_layers_from_vllm_config( - vllm_config: VllmConfig, - layer_type: type[T], - layer_names: Optional[list[str]] = None) -> dict[str, T]: + vllm_config: VllmConfig, + layer_type: type[T], + layer_names: Optional[list[str]] = None, +) -> dict[str, T]: """ Get layers from the vLLM config. 
@@ -796,8 +865,7 @@ def get_layers_from_vllm_config( """ if layer_names is None: - layer_names = list( - vllm_config.compilation_config.static_forward_context.keys()) + layer_names = list(vllm_config.compilation_config.static_forward_context.keys()) forward_context = vllm_config.compilation_config.static_forward_context diff --git a/vllm/connections.py b/vllm/connections.py index 1f341719ae30..8d5e0e5cbf5d 100644 --- a/vllm/connections.py +++ b/vllm/connections.py @@ -41,8 +41,9 @@ def _validate_http_url(self, url: str): parsed_url = urlparse(url) if parsed_url.scheme not in ("http", "https"): - raise ValueError("Invalid HTTP URL: A valid HTTP URL " - "must have scheme 'http' or 'https'.") + raise ValueError( + "Invalid HTTP URL: A valid HTTP URL must have scheme 'http' or 'https'." + ) def _headers(self, **extras: str) -> MutableMapping[str, str]: return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} @@ -61,11 +62,13 @@ def get_response( client = self.get_sync_client() extra_headers = extra_headers or {} - return client.get(url, - headers=self._headers(**extra_headers), - stream=stream, - timeout=timeout, - allow_redirects=allow_redirects) + return client.get( + url, + headers=self._headers(**extra_headers), + stream=stream, + timeout=timeout, + allow_redirects=allow_redirects, + ) async def get_async_response( self, @@ -80,19 +83,19 @@ async def get_async_response( client = await self.get_async_client() extra_headers = extra_headers or {} - return client.get(url, - headers=self._headers(**extra_headers), - timeout=timeout, - allow_redirects=allow_redirects) - - def get_bytes(self, - url: str, - *, - timeout: Optional[float] = None, - allow_redirects: bool = True) -> bytes: - with self.get_response(url, - timeout=timeout, - allow_redirects=allow_redirects) as r: + return client.get( + url, + headers=self._headers(**extra_headers), + timeout=timeout, + allow_redirects=allow_redirects, + ) + + def get_bytes( + self, url: str, *, timeout: Optional[float] = None, allow_redirects: bool = True + ) -> bytes: + with self.get_response( + url, timeout=timeout, allow_redirects=allow_redirects + ) as r: r.raise_for_status() return r.content @@ -105,7 +108,8 @@ async def async_get_bytes( allow_redirects: bool = True, ) -> bytes: async with await self.get_async_response( - url, timeout=timeout, allow_redirects=allow_redirects) as r: + url, timeout=timeout, allow_redirects=allow_redirects + ) as r: r.raise_for_status() return await r.read() diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index af7ca6be1fca..97c6654385b3 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -28,7 +28,7 @@ def find_loaded_library(lib_name) -> Optional[str]: the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. 
- """ # noqa + """ # noqa found_line = None with open("/proc/self/maps") as f: for line in f: @@ -43,17 +43,21 @@ def find_loaded_library(lib_name) -> Optional[str]: start = found_line.index("/") path = found_line[start:].strip() filename = path.split("/")[-1] - assert filename.rpartition(".so")[0].startswith(lib_name), \ + assert filename.rpartition(".so")[0].startswith(lib_name), ( f"Unexpected filename: {filename} for library {lib_name}" + ) return path cumem_available = False try: - from vllm.cumem_allocator import (init_module, python_create_and_map, - python_unmap_and_release) - from vllm.distributed.device_communicators.cuda_wrapper import ( - CudaRTLibrary) + from vllm.cumem_allocator import ( + init_module, + python_create_and_map, + python_unmap_and_release, + ) + from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary + lib_name = find_loaded_library("cumem_allocator") libcudart = CudaRTLibrary() cumem_available = True @@ -86,20 +90,19 @@ def unmap_and_release(allocation_handle: HandleType) -> None: def get_pluggable_allocator( - python_malloc_fn: Callable[[int], - int], python_free_func: Callable[[int, int], - None] + python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None] ) -> torch.cuda.memory.CUDAPluggableAllocator: init_module(python_malloc_fn, python_free_func) new_alloc = torch.cuda.memory.CUDAPluggableAllocator( - lib_name, 'my_malloc', 'my_free') + lib_name, "my_malloc", "my_free" + ) return new_alloc @contextmanager def use_memory_pool_with_allocator( - python_malloc_fn: Callable[[int], int], - python_free_func: Callable[[int, int], None]) -> None: + python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None] +) -> None: new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) with torch.cuda.memory.use_mem_pool(mem_pool): @@ -130,6 +133,7 @@ class CuMemAllocator: the global variable will be overwritten and the free callback will not work as expected. """ + instance: "CuMemAllocator" = None default_tag: str = "default" @@ -147,10 +151,11 @@ def get_instance() -> "CuMemAllocator": def __init__(self): conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") - assert "expandable_segments:True" not in conf, \ - ("Expandable segments are not compatible with memory pool. " + assert "expandable_segments:True" not in conf, ( + "Expandable segments are not compatible with memory pool. " "Please track https://github.com/pytorch/pytorch/issues/147851 " - "for the latest updates.") + "for the latest updates." 
+ ) self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag @@ -167,10 +172,14 @@ def _python_malloc_callback(self, allocation_handle: HandleType) -> None: when memory is allocated in the memory pool.""" py_d_mem = allocation_handle[2] self.pointer_to_data[py_d_mem] = AllocationData( - allocation_handle, self.current_tag) + allocation_handle, self.current_tag + ) logger.debug( "Allocated %s bytes for %s with address %s from cumem allocator", - allocation_handle[1], self.current_tag, py_d_mem) + allocation_handle[1], + self.current_tag, + py_d_mem, + ) return def _python_free_callback(self, ptr: int) -> HandleType: @@ -182,13 +191,13 @@ def _python_free_callback(self, ptr: int) -> HandleType: data.cpu_backup_tensor = None logger.debug( "Freed %s bytes for %s with address %s from cumem allocator", - data.handle[1], data.tag, ptr) + data.handle[1], + data.tag, + ptr, + ) return data.handle - def sleep( - self, - offload_tags: Optional[Union[tuple[str, ...], - str]] = None) -> None: + def sleep(self, offload_tags: Optional[Union[tuple[str, ...], str]] = None) -> None: """ Put the allocator in sleep mode. All data in the memory allocation with the specified tag will be @@ -200,9 +209,9 @@ def sleep( if offload_tags is None: # by default, allocated tensors are offloaded # when the allocator sleeps - offload_tags = (CuMemAllocator.default_tag, ) + offload_tags = (CuMemAllocator.default_tag,) elif isinstance(offload_tags, str): - offload_tags = (offload_tags, ) + offload_tags = (offload_tags,) assert isinstance(offload_tags, tuple) @@ -218,8 +227,9 @@ def sleep( cpu_backup_tensor = torch.empty( size_in_bytes, dtype=torch.uint8, - device='cpu', - pin_memory=is_pin_memory_available()) + device="cpu", + pin_memory=is_pin_memory_available(), + ) cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes) data.cpu_backup_tensor = cpu_backup_tensor @@ -228,8 +238,11 @@ def sleep( logger.info( "CuMemAllocator: sleep freed %.2f GiB memory in total, of which " "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded " - "directly.", total_bytes / 1024**3, backup_bytes / 1024**3, - (total_bytes - backup_bytes) / 1024**3) + "directly.", + total_bytes / 1024**3, + backup_bytes / 1024**3, + (total_bytes - backup_bytes) / 1024**3, + ) gc.collect() torch.cuda.empty_cache() @@ -251,8 +264,9 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None: if data.cpu_backup_tensor is not None: cpu_backup_tensor = data.cpu_backup_tensor if cpu_backup_tensor is not None: - size_in_bytes = cpu_backup_tensor.numel( - ) * cpu_backup_tensor.element_size() + size_in_bytes = ( + cpu_backup_tensor.numel() * cpu_backup_tensor.element_size() + ) cpu_ptr = cpu_backup_tensor.data_ptr() libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) data.cpu_backup_tensor = None @@ -274,8 +288,9 @@ def use_memory_pool(self, tag: Optional[str] = None): old_tag = self.current_tag self.current_tag = tag - with use_memory_pool_with_allocator(self.python_malloc_callback, - self.python_free_callback) as data: + with use_memory_pool_with_allocator( + self.python_malloc_callback, self.python_free_callback + ) as data: # start to hit another PyTorch bug in PyTorch 2.6, # possibly because of gc-related issue w.r.t. the allocator and # the memory pool. 
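Note: the cumem.py changes above are formatting-only, but the hunk surfaces the full tag-based sleep/wake API of CuMemAllocator (get_instance, use_memory_pool, sleep, wake_up). As a rough usage sketch — not taken from this patch, and assuming a CUDA build of vLLM with the cumem_allocator extension importable — the allocator is typically driven like this:

    import torch

    from vllm.device_allocator.cumem import CuMemAllocator

    allocator = CuMemAllocator.get_instance()

    # Tensors created inside this context are allocated through the pluggable
    # memory pool and tracked under the "weights" tag.
    with allocator.use_memory_pool(tag="weights"):
        weights = torch.empty(1 << 20, device="cuda")

    # sleep() offloads allocations carrying the listed tags to pinned CPU
    # memory and discards the rest of the pool's allocations.
    allocator.sleep(offload_tags=("weights",))

    # wake_up() re-maps the GPU memory and copies the offloaded data back.
    allocator.wake_up(tags=["weights"])
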
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 0a5a95176f7c..46a735f22ed8 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -14,28 +14,30 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: return get_tp_group().all_reduce(input_) -def tensor_model_parallel_all_gather(input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: +def tensor_model_parallel_all_gather( + input_: torch.Tensor, dim: int = -1 +) -> torch.Tensor: """All-gather the input tensor across model parallel group.""" return get_tp_group().all_gather(input_, dim) -def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: +def tensor_model_parallel_reduce_scatter( + input_: torch.Tensor, dim: int = -1 +) -> torch.Tensor: """Reduce-Scatter the input tensor across model parallel group.""" return get_tp_group().reduce_scatter(input_, dim) -def tensor_model_parallel_gather(input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: +def tensor_model_parallel_gather( + input_: torch.Tensor, dst: int = 0, dim: int = -1 +) -> Optional[torch.Tensor]: """Gather the input tensor across model parallel group.""" return get_tp_group().gather(input_, dst, dim) -def broadcast_tensor_dict(tensor_dict: Optional[dict[Any, Union[torch.Tensor, - Any]]] = None, - src: int = 0): +def broadcast_tensor_dict( + tensor_dict: Optional[dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0 +): if not torch.distributed.is_initialized(): return tensor_dict return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index bb3fd657facd..a22f43cd88d1 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -33,17 +33,19 @@ class NaiveAll2AllManager(All2AllManagerBase): def __init__(self, cpu_group): super().__init__(cpu_group) - def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_sp_cpu: torch.Tensor, - is_sequence_parallel: bool) -> torch.Tensor: - assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) + def naive_multicast( + self, + x: torch.Tensor, + cu_tokens_across_sp_cpu: torch.Tensor, + is_sequence_parallel: bool, + ) -> torch.Tensor: + assert len(x.shape) == 2 + buffer = torch.empty( + (cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype + ) rank = self.rank if is_sequence_parallel else self.dp_rank - world_size = (self.world_size - if is_sequence_parallel else self.dp_world_size) + world_size = self.world_size if is_sequence_parallel else self.dp_world_size start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1] end = cu_tokens_across_sp_cpu[rank] @@ -59,24 +61,23 @@ def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - is_sequence_parallel: bool = False + is_sequence_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: sp_size = self.tp_group.world_size if is_sequence_parallel else 1 dp_metadata = get_forward_context().dp_metadata cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size) - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_sp_cpu, - is_sequence_parallel) - router_logits = self.naive_multicast(router_logits, - cu_tokens_across_sp_cpu, - is_sequence_parallel) + hidden_states = self.naive_multicast( + 
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel + ) + router_logits = self.naive_multicast( + router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel + ) return hidden_states, router_logits - def combine(self, - hidden_states: torch.Tensor, - is_sequence_parallel: bool = False) -> torch.Tensor: - + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: ep_rank = self.rank if is_sequence_parallel else self.dp_rank dp_metadata = get_forward_context().dp_metadata @@ -107,13 +108,12 @@ def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - is_sequence_parallel: bool = False + is_sequence_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Gather hidden_states and router_logits from all dp ranks. """ - sizes = get_forward_context( - ).dp_metadata.get_chunk_sizes_across_dp_rank() + sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] @@ -124,19 +124,16 @@ def dispatch( ) return hidden_states, router_logits - def combine(self, - hidden_states: torch.Tensor, - is_sequence_parallel: bool = False) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: """ Reduce-scatter hidden_states across all dp ranks. """ - sizes = get_forward_context( - ).dp_metadata.get_chunk_sizes_across_dp_rank() + sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() - hidden_states = dist_group.reduce_scatterv(hidden_states, - dim=0, - sizes=sizes) + hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes) return hidden_states def destroy(self): @@ -149,24 +146,36 @@ class PPLXAll2AllManager(All2AllManagerBase): """ def __init__(self, cpu_group): - assert has_pplx( - ), "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa + assert has_pplx(), ( + "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" + " to install pplx_kernels." 
+ ) super().__init__(cpu_group) if self.internode: # inter-node communication needs nvshmem, # intra-node communication uses p2p mapping directly - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, - nvshmem_init) + from pplx_kernels.nvshmem import ( + nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, + nvshmem_init, + ) + logger.debug( - "Initialize NVSHMEM for pplx_kernels: " - "rank=%d, world size=%d", self.rank, self.world_size) - uid = nvshmem_get_unique_id( - ) if self.rank == 0 else nvshmem_alloc_empty_unique_id() - dist.broadcast(uid, - src=dist.get_process_group_ranks(self.cpu_group)[0], - group=self.cpu_group) + "Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d", + self.rank, + self.world_size, + ) + uid = ( + nvshmem_get_unique_id() + if self.rank == 0 + else nvshmem_alloc_empty_unique_id() + ) + dist.broadcast( + uid, + src=dist.get_process_group_ranks(self.cpu_group)[0], + group=self.cpu_group, + ) logger.debug("PPLX NVSHMEM UID = %s", uid) nvshmem_init(uid, self.rank, self.world_size) @@ -174,21 +183,23 @@ def __init__(self, cpu_group): def get_handle(self, kwargs): import pplx_kernels as pplx + return self.handle_cache.get_or_create( - kwargs, pplx.AllToAll.internode - if self.internode else pplx.AllToAll.intranode) + kwargs, + pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode, + ) def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - is_sequence_parallel: bool = False + is_sequence_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError - def combine(self, - hidden_states: torch.Tensor, - is_sequence_parallel: bool = False) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: raise NotImplementedError def destroy(self): @@ -198,6 +209,7 @@ def destroy(self): if self.internode: from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX NVSHMEM finalize") nvshmem_finalize() @@ -208,8 +220,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): """ def __init__(self, cpu_group): - assert has_deep_ep( - ), "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa + assert has_deep_ep(), ( + "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" + " to install DeepEP kernels." 
+ ) # noqa super().__init__(cpu_group) self.handle_cache = Cache() @@ -224,13 +238,13 @@ def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - is_sequence_parallel: bool = False + is_sequence_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError - def combine(self, - hidden_states: torch.Tensor, - is_sequence_parallel: bool = False) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: raise NotImplementedError def destroy(self): @@ -260,23 +274,27 @@ def _make_all2all_kwargs(self) -> dict[Any, Any]: assert num_rdma_bytes is not None assert num_qps_per_rank is not None - return dict(group=self.cpu_group, - num_nvl_bytes=num_nvl_bytes, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=False, - num_qps_per_rank=num_qps_per_rank) + return dict( + group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=False, + num_qps_per_rank=num_qps_per_rank, + ) def get_handle(self, kwargs): - assert len(kwargs) == 0, ( "DeepEPHTAll2AllManager expects no arguments. All the required " - "args are computed in the Manager itself.") + "args are computed in the Manager itself." + ) import deep_ep + buffer_kwargs = self._make_all2all_kwargs() logger.debug("DeepEP all2all args %s", buffer_kwargs) handle: deep_ep.Buffer = self.handle_cache.get_or_create( - buffer_kwargs, deep_ep.Buffer) + buffer_kwargs, deep_ep.Buffer + ) return handle def set_num_sms(self, num_sms: int): @@ -323,14 +341,17 @@ def _make_all2all_kwargs( num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, hidden=token_hidden_size, num_ranks=num_ep_ranks, - num_experts=num_global_experts) + num_experts=num_global_experts, + ) assert num_rdma_bytes is not None - return dict(group=self.cpu_group, - num_nvl_bytes=num_nvl_bytes, - num_rdma_bytes=num_rdma_bytes, - low_latency_mode=True, - num_qps_per_rank=num_qps_per_rank) + return dict( + group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_qps_per_rank, + ) def get_handle(self, kwargs): """ @@ -338,10 +359,12 @@ def get_handle(self, kwargs): _make_all2all_kwargs. """ import deep_ep + buffer_kwargs = self._make_all2all_kwargs(**kwargs) logger.debug("DeepEP all2all args %s", buffer_kwargs) handle: deep_ep.Buffer = self.handle_cache.get_or_create( - buffer_kwargs, deep_ep.Buffer) + buffer_kwargs, deep_ep.Buffer + ) return handle # DeepEP LL uses RDMA so no SMs are used for communication @@ -355,12 +378,15 @@ class FlashInferAllToAllManager(All2AllManagerBase): """ def __init__(self, cpu_group): - assert has_flashinfer_all2all( - ), "flashinfer all2all module not found. Please install/check flashinfer" # noqa + assert has_flashinfer_all2all(), ( + "flashinfer all2all module not found. 
Please install/check flashinfer" + ) # noqa super().__init__(cpu_group) logger.debug( - "Initialize for flashinfer All2All " - "rank=%d, world size=%d", self.rank, self.world_size) + "Initialize for flashinfer All2All rank=%d, world size=%d", + self.rank, + self.world_size, + ) self.initialized = False self.alltoall_info = None @@ -375,8 +401,7 @@ def initialize( return self.cleanup() - logger.debug("making map: " - "rank=%d, world size=%d", rank, world_size) + logger.debug("making map: rank=%d, world size=%d", rank, world_size) self.mapping = Mapping( world_size, rank, @@ -385,25 +410,28 @@ def initialize( ) from vllm.distributed.device_communicators.mnnvl_compat import ( - CustomCommunicator) + CustomCommunicator, + ) + dp_config = MnnvlConfig( comm_backend=CustomCommunicator(get_dp_group().cpu_group), fabric_page_size=1 << 29, # 512MB - allocation_granularity=0 # Auto-detect + allocation_granularity=0, # Auto-detect ) - self.workspace_tensor = MnnvlMoe.get_moe_workspaces( - self.mapping, dp_config) + self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config) self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace( - self.mapping, dp_config) + self.mapping, dp_config + ) self.world_size = world_size self.rank = rank self.gpus_per_node = gpus_per_node self.initialized = True - logger.info("FlashInfer All2All initialized for rank %s, size %s", - rank, world_size) + logger.info( + "FlashInfer All2All initialized for rank %s, size %s", rank, world_size + ) def ensure_alltoall_workspace_initialized(self): """Ensure workspace is initialized""" @@ -426,8 +454,11 @@ def get_handle(self, kwargs): def cleanup(self): """Clean up workspace""" - if self.initialized and self.workspace_tensor is not None \ - and self.prepare_workspace_tensor is not None: + if ( + self.initialized + and self.workspace_tensor is not None + and self.prepare_workspace_tensor is not None + ): try: del self.workspace_tensor del self.prepare_workspace_tensor diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index 87e0f8e1a967..dabb48320be4 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -19,8 +19,7 @@ import vllm.envs as envs from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.logger import init_logger -from vllm.utils import (cuda_device_count_stateless, - update_environment_variables) +from vllm.utils import cuda_device_count_stateless, update_environment_variables logger = init_logger(__name__) @@ -39,7 +38,7 @@ 4: 2 * MiB, # 2 MB 6: 1 * MiB, # 1 MB 8: 1 * MiB, # 1 MB - } + }, } SYMM_MEM_ALL_REDUCE_MAX_SIZES = { @@ -54,7 +53,7 @@ 4: 32 * MiB, # 32 MB 6: 128 * MiB, # 128 MB 8: 128 * MiB, # 128 MB - } + }, } NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = { @@ -63,14 +62,15 @@ 4: 2 * MiB, # 2 MB 8: 1 * MiB, # 1 MB }, - "always_use_above_world_size": 8 # Always use symm mem for world_size > 8 + "always_use_above_world_size": 8, # Always use symm mem for world_size > 8 } -def should_nccl_symm_mem_allreduce(world_size: int, - input_tensor: torch.Tensor) -> bool: +def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool: from vllm.distributed.device_communicators.pynccl_allocator import ( - is_symmetric_memory_enabled) + is_symmetric_memory_enabled, + ) + if not is_symmetric_memory_enabled(): return False if world_size < 
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]: @@ -78,18 +78,18 @@ def should_nccl_symm_mem_allreduce(world_size: int, threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size) if threshold is not None and input_tensor.nbytes >= threshold: return True - return (world_size - > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]) + return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"] -def producer(batch_src: Sequence[int], - producer_queue, - consumer_queue, - result_queue, - cuda_visible_devices: Optional[str] = None): +def producer( + batch_src: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None, +): if cuda_visible_devices is not None: - update_environment_variables( - {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) lib = CudaRTLibrary() for i in batch_src: @@ -115,14 +115,15 @@ def producer(batch_src: Sequence[int], lib.cudaDeviceReset() -def consumer(batch_tgt: Sequence[int], - producer_queue, - consumer_queue, - result_queue, - cuda_visible_devices: Optional[str] = None): +def consumer( + batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None, +): if cuda_visible_devices is not None: - update_environment_variables( - {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) lib = CudaRTLibrary() for j in batch_tgt: @@ -198,12 +199,26 @@ def can_actually_p2p( producer_queue = smp.Queue() consumer_queue = smp.Queue() result_queue = smp.Queue() - p_src = smp.Process(target=producer, - args=(batch_src, producer_queue, consumer_queue, - result_queue, cuda_visible_devices)) - p_tgt = smp.Process(target=consumer, - args=(batch_tgt, producer_queue, consumer_queue, - result_queue, cuda_visible_devices)) + p_src = smp.Process( + target=producer, + args=( + batch_src, + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices, + ), + ) + p_tgt = smp.Process( + target=consumer, + args=( + batch_tgt, + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices, + ), + ) p_src.start() p_tgt.start() p_src.join() @@ -216,7 +231,10 @@ def can_actually_p2p( if a != b: logger.warning( "Two processes do not agree on the P2P access" - " status on %d -> %d, treat as disabled.", src, tgt) + " status on %d -> %d, treat as disabled.", + src, + tgt, + ) result.append(False) else: result.append(a) @@ -255,12 +273,14 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) path = os.path.join( - envs.VLLM_CACHE_ROOT, - f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json" + ) os.makedirs(os.path.dirname(path), exist_ok=True) from vllm.distributed.parallel_state import get_world_group - if ((not is_distributed or get_world_group().local_rank == 0) - and (not os.path.exists(path))): + + if (not is_distributed or get_world_group().local_rank == 0) and ( + not os.path.exists(path) + ): # only the local master process (with local_rank == 0) can # enter this block to calculate the cache logger.info("generating GPU P2P access cache in %s", path) @@ -279,11 +299,10 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: # we don't use the output of the subprocess directly, # because the subprocess might produce 
logging output with tempfile.NamedTemporaryFile() as output_file: - input_bytes = pickle.dumps( - (batch_src, batch_tgt, output_file.name)) - returned = subprocess.run([sys.executable, __file__], - input=input_bytes, - capture_output=True) + input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name)) + returned = subprocess.run( + [sys.executable, __file__], input=input_bytes, capture_output=True + ) # check if the subprocess is successful try: returned.check_returncode() @@ -292,7 +311,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: raise RuntimeError( f"Error happened when batch testing " f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" - f"{returned.stderr.decode()}") from e + f"{returned.stderr.decode()}" + ) from e with open(output_file.name, "rb") as f: result = pickle.load(f) for _i, _j, r in zip(batch_src, batch_tgt, result): diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index a42081fb0c15..c32be0bec55c 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -10,7 +10,6 @@ class Cache: - def __init__(self): self._cache: WeakValueDictionary = WeakValueDictionary() self._lock = threading.RLock() # Reentrant lock for thread safety @@ -35,9 +34,11 @@ def __init__(self, cpu_group): self.cpu_group = cpu_group # compute some common properties - from vllm.distributed.parallel_state import (get_dp_group, - get_tp_group, - in_the_same_node_as) + from vllm.distributed.parallel_state import ( + get_dp_group, + get_tp_group, + in_the_same_node_as, + ) # all2all lives in ep group, which is merged from dp and tp group self.dp_group = get_dp_group() @@ -63,10 +64,12 @@ def get_handle(self, kwargs): # and reuse it for the same config. raise NotImplementedError - def dispatch(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_sequence_parallel: bool = False): + def dispatch( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_sequence_parallel: bool = False, + ): raise NotImplementedError def set_num_sms(self, num_sms: int): @@ -75,9 +78,7 @@ def set_num_sms(self, num_sms: int): def max_sms_used(self) -> Optional[int]: return None # None means it could use the whole GPU - def combine(self, - hidden_states: torch.Tensor, - is_sequence_parallel: bool = False): + def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False): raise NotImplementedError def destroy(self): @@ -92,11 +93,13 @@ class DeviceCommunicatorBase: communication backend), the `device_group` will also be given. 
""" - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): self.device = device or torch.device("cpu") self.cpu_group = cpu_group self.device_group = device_group @@ -106,11 +109,11 @@ def __init__(self, self.ranks = dist.get_process_group_ranks(cpu_group) self.global_rank = dist.get_rank() self.global_world_size = dist.get_world_size() - self.rank_in_group = dist.get_group_rank(self.cpu_group, - self.global_rank) + self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) use_ep = False from vllm.config import get_current_vllm_config + config = get_current_vllm_config() if config is not None: # as long as we use data parallel (coupled data parallel @@ -134,41 +137,39 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: # NOTE: we have to use concat-style all-gather here, # stack-style all-gather has compatibility issues with # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 - output_size = (input_size[0] * self.world_size, ) + input_size[1:] + output_size = (input_size[0] * self.world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=input_.dtype, - device=input_.device) + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) # All-gather. - dist.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) + dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group) # Reshape - output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.reshape((self.world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (self.world_size * - input_size[dim], ) + - input_size[dim + 1:]) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (self.world_size * input_size[dim],) + + input_size[dim + 1 :] + ) return output_tensor def all_gatherv( self, input_: Union[torch.Tensor, list[torch.Tensor]], dim: int = 0, - sizes: Optional[list[int]] = None + sizes: Optional[list[int]] = None, ) -> Union[torch.Tensor, list[torch.Tensor]]: raise NotImplementedError - def reduce_scatter(self, - input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if dim < 0: # Convert negative dim to positive. 
@@ -180,30 +181,28 @@ def reduce_scatter(self, assert input_tensor.shape[0] % world_size == 0 chunk_size = input_tensor.shape[0] // world_size - output_shape = (chunk_size, ) + input_tensor.shape[1:] + output_shape = (chunk_size,) + input_tensor.shape[1:] - output_tensor = torch.empty(output_shape, - dtype=input_tensor.dtype, - device=input_tensor.device) + output_tensor = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) # Perform reduce-scatter operation - torch.distributed.reduce_scatter_tensor(output_tensor, - input_tensor, - group=self.device_group) + torch.distributed.reduce_scatter_tensor( + output_tensor, input_tensor, group=self.device_group + ) # Reshape before returning return output_tensor.movedim(0, dim).contiguous() - def reduce_scatterv(self, - input_: torch.Tensor, - dim: int = -1, - sizes: Optional[list[int]] = None) -> torch.Tensor: + def reduce_scatterv( + self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None + ) -> torch.Tensor: raise NotImplementedError - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: """ NOTE: We assume that the input tensor is on the same device across all the ranks. @@ -211,7 +210,8 @@ def gather(self, """ world_size = self.world_size assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -222,10 +222,9 @@ def gather(self, else: gather_list = None # Gather. - torch.distributed.gather(input_, - gather_list, - dst=self.ranks[dst], - group=self.device_group) + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) if self.rank_in_group == dst: output_tensor = torch.cat(gather_list, dim=dim) else: @@ -239,10 +238,9 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: dst = (self.rank_in_group + 1) % self.world_size torch.distributed.send(tensor, self.ranks[dst], self.device_group) - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" if src is None: @@ -255,8 +253,7 @@ def recv(self, def destroy(self): pass - def prepare_communication_buffer_for_model(self, - model: torch.nn.Module) -> None: + def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None: """ Prepare the communication buffer for the model. """ @@ -264,11 +261,14 @@ def prepare_communication_buffer_for_model(self, return moe_modules = [ - module for module in model.modules() + module + for module in model.modules() # TODO(bnell): Should use isinstance but can't. Maybe search for # presence of quant_method.init_prepare_finalize? 
- if (module.__class__.__name__ == "FusedMoE" - or module.__class__.__name__ == "SharedFusedMoE") + if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) ] for module in moe_modules: module.quant_method.init_prepare_finalize(module) @@ -277,7 +277,7 @@ def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - is_sequence_parallel: bool = False + is_sequence_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ Dispatch the hidden states and router logits to the appropriate device. @@ -285,9 +285,9 @@ def dispatch( """ return hidden_states, router_logits - def combine(self, - hidden_states: torch.Tensor, - is_sequence_parallel: bool = False) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: """ Combine the hidden states and router logits from the appropriate device. This is a no-op in the base class. diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index bda567f8489c..c09b3ba9ceba 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -15,30 +15,30 @@ class CpuCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): super().__init__(cpu_group, device, device_group, unique_name) self.dist_module = torch.distributed - if (current_platform.get_cpu_architecture() - == CpuArchEnum.X86) and hasattr( - torch.ops._C, - "init_shm_manager") and (unique_name.startswith("tp") - or unique_name.startswith("pp")): + if ( + (current_platform.get_cpu_architecture() == CpuArchEnum.X86) + and hasattr(torch.ops._C, "init_shm_manager") + and (unique_name.startswith("tp") or unique_name.startswith("pp")) + ): self.dist_module = _CPUSHMDistributed(self) def all_reduce(self, input_): self.dist_module.all_reduce(input_, group=self.device_group) return input_ - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: """ NOTE: We assume that the input tensor is on the same device across all the ranks. @@ -46,7 +46,8 @@ def gather(self, """ world_size = self.world_size assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -58,10 +59,9 @@ def gather(self, gather_list = None # Gather. - self.dist_module.gather(input_, - gather_list, - dst=self.ranks[dst], - group=self.device_group) + self.dist_module.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) if self.rank_in_group == dst: output_tensor = torch.cat(gather_list, dim=dim) @@ -77,23 +77,24 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: # NOTE: we have to use concat-style all-gather here, # stack-style all-gather has compatibility issues with # torch.compile . 
see https://github.com/pytorch/pytorch/issues/138795 - output_size = (input_size[0] * self.world_size, ) + input_size[1:] + output_size = (input_size[0] * self.world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=input_.dtype, - device=input_.device) + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) # All-gather. - self.dist_module.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) + self.dist_module.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) # Reshape - output_tensor = output_tensor.reshape((self.world_size, ) + input_size) + output_tensor = output_tensor.reshape((self.world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (self.world_size * - input_size[dim], ) + - input_size[dim + 1:]) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (self.world_size * input_size[dim],) + + input_size[dim + 1 :] + ) return output_tensor def send_tensor_dict( @@ -111,7 +112,6 @@ def recv_tensor_dict( class _CPUSHMDistributed: - def __init__(self, communicator: CpuCommunicator): instance_identifier = os.environ["VLLM_DIST_IDENT"] unique_name = communicator.unique_name @@ -139,24 +139,32 @@ def _init_cpu_shm(self) -> int: return handle - def all_reduce(self, - input: torch.Tensor, - group: Optional[ProcessGroup] = None) -> None: + def all_reduce( + self, input: torch.Tensor, group: Optional[ProcessGroup] = None + ) -> None: torch.ops._C.shm_allreduce(self.handle, input) - def gather(self, - input: torch.Tensor, - gather_list: Optional[list[torch.Tensor]], - dst: int = -1, - group: Optional[ProcessGroup] = None) -> None: + def gather( + self, + input: torch.Tensor, + gather_list: Optional[list[torch.Tensor]], + dst: int = -1, + group: Optional[ProcessGroup] = None, + ) -> None: # Note: different from the torch gather, here we use local dst rank. 
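# (torch.distributed.gather addresses the destination by its rank in the
#  global process group, while the shm_gather op below expects the rank
#  within `group`, so torch.distributed.get_group_rank(group, dst) translates
#  the global rank into the group-local one before the call.)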
- torch.ops._C.shm_gather(self.handle, input, gather_list, - torch.distributed.get_group_rank(group, dst)) + torch.ops._C.shm_gather( + self.handle, + input, + gather_list, + torch.distributed.get_group_rank(group, dst), + ) - def all_gather_into_tensor(self, - output: torch.Tensor, - input: torch.Tensor, - group: Optional[ProcessGroup] = None) -> None: + def all_gather_into_tensor( + self, + output: torch.Tensor, + input: torch.Tensor, + group: Optional[ProcessGroup] = None, + ) -> None: torch.ops._C.shm_all_gather(self.handle, input, output) def send_tensor_dict( @@ -169,11 +177,11 @@ def send_tensor_dict( size_list = [] for v in value_list: if not isinstance(v, torch.Tensor): - raise RuntimeError( - "CpuCommunicator only supports sending tensors.") + raise RuntimeError("CpuCommunicator only supports sending tensors.") size_list.append(v.size()) - key_size_tensor = torch.frombuffer(pickle.dumps([key_list, size_list]), - dtype=torch.uint8) + key_size_tensor = torch.frombuffer( + pickle.dumps([key_list, size_list]), dtype=torch.uint8 + ) value_list.append(key_size_tensor) torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 9c2bf51a813e..45096dffb5b6 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -8,11 +8,12 @@ import vllm.envs as envs from vllm.distributed.device_communicators.all_reduce_utils import ( - should_nccl_symm_mem_allreduce) -from vllm.distributed.device_communicators.pynccl import ( - register_nccl_symmetric_ops) + should_nccl_symm_mem_allreduce, +) +from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops from vllm.distributed.device_communicators.pynccl_allocator import ( - is_symmetric_memory_enabled) + is_symmetric_memory_enabled, +) from vllm.logger import init_logger from vllm.platforms import current_platform @@ -22,20 +23,21 @@ class CudaCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): super().__init__(cpu_group, device, device_group, unique_name) if "tp" not in unique_name: # custom allreduce or torch symm mem can be used only by tp use_custom_allreduce = False use_torch_symm_mem = False else: - from vllm.distributed.parallel_state import ( - _ENABLE_CUSTOM_ALL_REDUCE) + from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM @@ -44,13 +46,13 @@ def __init__(self, # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( - CustomAllreduce) - from vllm.distributed.device_communicators.pynccl import ( - PyNcclCommunicator) + CustomAllreduce, + ) + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.quick_all_reduce import ( - QuickAllReduce) - from vllm.distributed.device_communicators.symm_mem import ( - SymmMemCommunicator) + QuickAllReduce, + ) + from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator self.pynccl_comm: 
Optional[PyNcclCommunicator] = None if self.world_size > 1: @@ -75,8 +77,9 @@ def __init__(self, self.ca_comm = CustomAllreduce( group=self.cpu_group, device=self.device, - symm_mem_enabled=(self.symm_mem_comm is not None - and not self.symm_mem_comm.disabled), + symm_mem_enabled=( + self.symm_mem_comm is not None and not self.symm_mem_comm.disabled + ), ) if current_platform.is_rocm(): @@ -85,35 +88,39 @@ def __init__(self, # Based on quickreduce (https://github.com/mk1-project/quickreduce). # If it's a rocm, 'use_custom_allreduce==True' means it must # currently be an MI300 series. - self.qr_comm = QuickAllReduce(group=self.cpu_group, - device=self.device) + self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) logger.info("Using naive all2all manager.") elif all2all_backend == "allgather_reducescatter": from .all2all import AgRsAll2AllManager + self.all2all_manager = AgRsAll2AllManager(self.cpu_group) logger.info("Using AllGather-ReduceScatter all2all manager.") elif all2all_backend == "pplx": from .all2all import PPLXAll2AllManager + self.all2all_manager = PPLXAll2AllManager(self.cpu_group) logger.info("Using PPLX all2all manager.") elif all2all_backend == "deepep_high_throughput": from .all2all import DeepEPHTAll2AllManager + self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) logger.info("Using DeepEP High-Throughput all2all manager.") elif all2all_backend == "deepep_low_latency": from .all2all import DeepEPLLAll2AllManager + self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) logger.info("Using DeepEP Low-Latency all2all manager.") elif all2all_backend == "flashinfer_all2allv": from .all2all import FlashInferAllToAllManager - self.all2all_manager = FlashInferAllToAllManager( - self.cpu_group) + + self.all2all_manager = FlashInferAllToAllManager(self.cpu_group) logger.info("Using Flashinfer all2allv manager.") else: raise ValueError(f"Unknown all2all backend: {all2all_backend}") @@ -121,28 +128,34 @@ def __init__(self, def all_reduce(self, input_): # since currently we perform copy input -> symm_input -> out-of-place AR # return symm_output, we don't need to check if input is symmetric - if self.pynccl_comm is not None and \ - should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_): + if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce( + self.pynccl_comm.world_size, input_ + ): out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_) if out is not None: return out # always try quick reduce first, then custom allreduce, # and then pynccl. 
(quick reduce just for ROCM MI3*) qr_comm = self.qr_comm - if qr_comm is not None and not qr_comm.disabled and \ - qr_comm.should_quick_allreduce(input_): + if ( + qr_comm is not None + and not qr_comm.disabled + and qr_comm.should_quick_allreduce(input_) + ): out = qr_comm.quick_all_reduce(input_) assert out is not None return out ca_comm = self.ca_comm - if ca_comm is not None and not ca_comm.disabled and \ - ca_comm.should_custom_ar(input_): + if ( + ca_comm is not None + and not ca_comm.disabled + and ca_comm.should_custom_ar(input_) + ): out = ca_comm.custom_all_reduce(input_) assert out is not None return out symm_mem_comm = self.symm_mem_comm - if symm_mem_comm is not None and \ - symm_mem_comm.should_use_symm_mem(input_): + if symm_mem_comm is not None and symm_mem_comm.should_use_symm_mem(input_): out = symm_mem_comm.all_reduce(input_) assert out is not None return out @@ -176,21 +189,20 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): assert input_tensor.shape[0] % world_size == 0 chunk_size = input_tensor.shape[0] // world_size - output_shape = (chunk_size, ) + input_tensor.shape[1:] + output_shape = (chunk_size,) + input_tensor.shape[1:] - output = torch.empty(output_shape, - dtype=input_tensor.dtype, - device=input_tensor.device) + output = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) pynccl_comm.reduce_scatter(output, input_tensor) # Reshape before returning return output.movedim(0, dim).contiguous() - def reduce_scatterv(self, - input_: torch.Tensor, - dim: int = -1, - sizes: Optional[list[int]] = None): + def reduce_scatterv( + self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None + ): world_size = self.world_size pynccl_comm = self.pynccl_comm assert pynccl_comm is not None @@ -209,11 +221,11 @@ def reduce_scatterv(self, else: assert input_tensor.shape[0] % world_size == 0 chunk_size = input_tensor.shape[0] // world_size - output_shape = (chunk_size, ) + input_tensor.shape[1:] + output_shape = (chunk_size,) + input_tensor.shape[1:] - output = torch.empty(output_shape, - dtype=input_tensor.dtype, - device=input_tensor.device) + output = torch.empty( + output_shape, dtype=input_tensor.dtype, device=input_tensor.device + ) if sizes is not None: pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes) @@ -235,10 +247,9 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: else: torch.distributed.send(tensor, self.ranks[dst], self.device_group) - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" if src is None: @@ -261,10 +272,12 @@ def destroy(self): self.all2all_manager.destroy() self.all2all_manager = None - def all_gatherv(self, - input_: Union[torch.Tensor, list[torch.Tensor]], - dim: int = 0, - sizes: Optional[list[int]] = None): + def all_gatherv( + self, + input_: Union[torch.Tensor, list[torch.Tensor]], + dim: int = 0, + sizes: Optional[list[int]] = None, + ): if dim != 0: raise NotImplementedError("only dim 0 all-gatherv is supported") world_size = self.world_size @@ -276,20 +289,20 @@ def all_gatherv(self, if sizes is not None and all(s == sizes[0] for s in sizes): sizes = None - def _all_gather_single(input_: torch.Tensor, - sizes: Optional[list[int]] = None): + def _all_gather_single(input_: 
torch.Tensor, sizes: Optional[list[int]] = None): input_size = input_.size() if sizes is not None: assert len(sizes) == world_size assert input_.shape[dim] == sizes[self.rank_in_group], ( - f"{input_.shape[dim]} != {sizes[self.rank_in_group]}") - output_size = (sum(sizes), ) + input_size[1:] + f"{input_.shape[dim]} != {sizes[self.rank_in_group]}" + ) + output_size = (sum(sizes),) + input_size[1:] else: - output_size = (input_size[0] * world_size, ) + input_size[1:] + output_size = (input_size[0] * world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=input_.dtype, - device=input_.device) + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) if sizes is not None: pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes) else: @@ -311,17 +324,19 @@ def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - is_sequence_parallel: bool = False + is_sequence_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: assert self.all2all_manager is not None hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits, is_sequence_parallel) + hidden_states, router_logits, is_sequence_parallel + ) return hidden_states, router_logits - def combine(self, - hidden_states: torch.Tensor, - is_sequence_parallel: bool = False) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states, - is_sequence_parallel) + hidden_states = self.all2all_manager.combine( + hidden_states, is_sequence_parallel + ) return hidden_states diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index 2c38e8ed21d7..a77d2666e2ce 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -42,7 +42,7 @@ def find_loaded_library(lib_name) -> Optional[str]: the file `/proc/self/maps` contains the memory maps of the process, which includes the shared libraries loaded by the process. We can use this file to find the path of the a loaded library. 
- """ # noqa + """ # noqa found = False with open("/proc/self/maps") as f: for line in f: @@ -57,8 +57,9 @@ def find_loaded_library(lib_name) -> Optional[str]: start = line.index("/") path = line[start:].strip() filename = path.split("/")[-1] - assert filename.rpartition(".so")[0].startswith(lib_name), \ + assert filename.rpartition(".so")[0].startswith(lib_name), ( f"Unexpected filename: {filename} for library {lib_name}" + ) return path @@ -70,30 +71,38 @@ class CudaRTLibrary: Function("cudaDeviceSynchronize", cudaError_t, []), # ​cudaError_t cudaDeviceReset ( void ) Function("cudaDeviceReset", cudaError_t, []), - # const char* cudaGetErrorString ( cudaError_t error ) Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), - # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) - Function("cudaMalloc", cudaError_t, - [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + Function( + "cudaMalloc", + cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t], + ), # ​cudaError_t cudaFree ( void* devPtr ) Function("cudaFree", cudaError_t, [ctypes.c_void_p]), # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) - Function("cudaMemset", cudaError_t, - [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + Function( + "cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t] + ), # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa - Function("cudaMemcpy", cudaError_t, [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind - ]), - + Function( + "cudaMemcpy", + cudaError_t, + [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind], + ), # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa - Function("cudaIpcGetMemHandle", cudaError_t, - [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + Function( + "cudaIpcGetMemHandle", + cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p], + ), # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa - Function("cudaIpcOpenMemHandle", cudaError_t, [ - ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint - ]), + Function( + "cudaIpcOpenMemHandle", + cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint], + ), ] # class attribute to store the mapping from the path to the library @@ -109,11 +118,10 @@ def __init__(self, so_file: Optional[str] = None): so_file = find_loaded_library("libcudart") if so_file is None: so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var - assert so_file is not None, \ - ( - "libcudart is not loaded in the current process, " - "try setting VLLM_CUDART_SO_PATH" - ) + assert so_file is not None, ( + "libcudart is not loaded in the current process, " + "try setting VLLM_CUDART_SO_PATH" + ) if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib @@ -154,27 +162,29 @@ def cudaMalloc(self, size: int) -> ctypes.c_void_p: def cudaFree(self, devPtr: ctypes.c_void_p) -> None: self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) - def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, - count: int) -> None: + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None: self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) - def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, - count: int) -> None: + def cudaMemcpy( + self, dst: ctypes.c_void_p, src: 
ctypes.c_void_p, count: int + ) -> None: cudaMemcpyDefault = 4 kind = cudaMemcpyDefault self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) - def cudaIpcGetMemHandle(self, - devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: handle = cudaIpcMemHandle_t() - self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( - ctypes.byref(handle), devPtr)) + self.CUDART_CHECK( + self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr) + ) return handle - def cudaIpcOpenMemHandle(self, - handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: cudaIpcMemLazyEnablePeerAccess = 1 devPtr = ctypes.c_void_p() - self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( - ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + self.CUDART_CHECK( + self.funcs["cudaIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess + ) + ) return devPtr diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 3cc4bbb25824..fd5c5dfd9da0 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -11,7 +11,9 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.distributed.device_communicators.all_reduce_utils import ( - CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check) + CUSTOM_ALL_REDUCE_MAX_SIZES, + gpu_p2p_access_check, +) from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform @@ -32,8 +34,7 @@ def _can_p2p(rank: int, world_size: int) -> bool: if i == rank: continue if envs.VLLM_SKIP_P2P_CHECK: - logger.info( - "Skipping P2P check and trusting the driver's P2P report.") + logger.info("Skipping P2P check and trusting the driver's P2P report.") return torch.cuda.can_device_access_peer(rank, i) if not gpu_p2p_access_check(rank, i): return False @@ -41,21 +42,23 @@ def _can_p2p(rank: int, world_size: int) -> bool: def is_weak_contiguous(inp: torch.Tensor): - return inp.is_contiguous() or (inp.storage().nbytes() - - inp.storage_offset() * inp.element_size() - == inp.numel() * inp.element_size()) + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size() + ) class CustomAllreduce: - _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] # max_size: max supported allreduce size - def __init__(self, - group: ProcessGroup, - device: Union[int, str, torch.device], - max_size=8192 * 1024, - symm_mem_enabled=False) -> None: + def __init__( + self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=8192 * 1024, + symm_mem_enabled=False, + ) -> None: """ Args: group: the process group to work on. If None, it will use the @@ -72,20 +75,24 @@ def __init__(self, if not custom_ar: # disable because of missing custom allreduce library # e.g. in a non-GPU environment - logger.info("Custom allreduce is disabled because " - "of missing custom allreduce library") + logger.info( + "Custom allreduce is disabled because " + "of missing custom allreduce library" + ) return self.group = group assert dist.get_backend(group) != dist.Backend.NCCL, ( - "CustomAllreduce should be attached to a non-NCCL group.") + "CustomAllreduce should be attached to a non-NCCL group." 
+ ) if not all(in_the_same_node_as(group, source_rank=0)): # No need to initialize custom allreduce for multi-node case. logger.warning( "Custom allreduce is disabled because this process group" - " spans across nodes.") + " spans across nodes." + ) return rank = dist.get_rank(group=self.group) @@ -100,7 +107,9 @@ def __init__(self, "Custom allreduce is disabled due to an unsupported world" " size: %d. Supported world sizes: %s. To silence this " "warning, specify disable_custom_all_reduce=True explicitly.", - world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + world_size, + str(CustomAllreduce._SUPPORTED_WORLD_SIZES), + ) return if isinstance(device, int): @@ -110,13 +119,15 @@ def __init__(self, # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device - device_capability = current_platform.get_device_capability( - ).as_version_str() - if (current_platform.is_cuda() and symm_mem_enabled - and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): + device_capability = current_platform.get_device_capability().as_version_str() + if ( + current_platform.is_cuda() + and symm_mem_enabled + and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES + ): max_size = min( - CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], - max_size) + CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size + ) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) @@ -124,12 +135,9 @@ def __init__(self, device_ids = list(range(cuda_device_count_stateless())) physical_device_id = device_ids[device.index] - tensor = torch.tensor([physical_device_id], - dtype=torch.int, - device="cpu") + tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") gather_list = [ - torch.tensor([0], dtype=torch.int, device="cpu") - for _ in range(world_size) + torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size) ] dist.all_gather(gather_list, tensor, group=self.group) physical_device_ids = [t.item() for t in gather_list] @@ -138,13 +146,13 @@ def __init__(self, # where custom allreduce is not supported # this checks hardware and driver support for NVLink assert current_platform.is_cuda_alike() - fully_connected = current_platform.is_fully_connected( - physical_device_ids) + fully_connected = current_platform.is_fully_connected(physical_device_ids) if world_size > 2 and not fully_connected: logger.warning( "Custom allreduce is disabled because it's not supported on" " more than two PCIe-only GPUs. To silence this warning, " - "specify disable_custom_all_reduce=True explicitly.") + "specify disable_custom_all_reduce=True explicitly." + ) return # test P2P capability, this checks software/cudaruntime support # this is expensive to compute at the first time @@ -154,16 +162,17 @@ def __init__(self, logger.warning( "Custom allreduce is disabled because your platform lacks " "GPU P2P capability or P2P test failed. To silence this " - "warning, specify disable_custom_all_reduce=True explicitly.") + "warning, specify disable_custom_all_reduce=True explicitly." + ) return self.disabled = False # Buffers memory are owned by this Python class and passed to C++. # Metadata composes of two parts: metadata for synchronization and a # temporary buffer for storing intermediate allreduce results. 
- self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, - group=group, - uncached=True) + self.meta_ptrs = self.create_shared_buffer( + ops.meta_size() + max_size, group=group, uncached=True + ) # This is a pre-registered IPC buffer. In eager mode, input tensors # are first copied into this buffer before allreduce is performed self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) @@ -172,21 +181,22 @@ def __init__(self, # 8*world_size bytes where world_size is at most 8. Allocating 8MB # is enough for 131072 such tuples. The largest model I've seen only # needs less than 10000 of registered tuples. - self.rank_data = torch.empty(8 * 1024 * 1024, - dtype=torch.uint8, - device=self.device) + self.rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) self.max_size = max_size self.rank = rank self.world_size = world_size self.fully_connected = fully_connected - self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, - self.fully_connected) + self._ptr = ops.init_custom_ar( + self.meta_ptrs, self.rank_data, rank, self.fully_connected + ) ops.register_buffer(self._ptr, self.buffer_ptrs) @contextmanager def capture(self): """ - The main responsibility of this context manager is the + The main responsibility of this context manager is the `register_graph_buffers` call at the end of the context. It records all the buffer addresses used in the CUDA graph. """ @@ -204,15 +214,13 @@ def register_graph_buffers(self): # We cannot directly use `dist.all_gather_object` here # because it is incompatible with `gloo` backend under inference mode. # see https://github.com/pytorch/pytorch/issues/126032 for details. - all_data = [[None, None] - for _ in range(dist.get_world_size(group=self.group))] + all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] all_data[self.rank] = [handle, offset] ranks = sorted(dist.get_process_group_ranks(group=self.group)) for i, rank in enumerate(ranks): - dist.broadcast_object_list(all_data[i], - src=rank, - group=self.group, - device="cpu") + dist.broadcast_object_list( + all_data[i], src=rank, group=self.group, device="cpu" + ) # Unpack list of tuples to tuple of lists. handles = [d[0] for d in all_data] # type: ignore offsets = [d[1] for d in all_data] # type: ignore @@ -233,13 +241,11 @@ def should_custom_ar(self, inp: torch.Tensor): return inp_size < self.max_size return False - def all_reduce(self, - inp: torch.Tensor, - *, - out: torch.Tensor = None, - registered: bool = False): + def all_reduce( + self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False + ): """Performs an out-of-place all reduce. - + If registered is True, this assumes inp's pointer is already IPC-registered. Otherwise, inp is first copied into a pre-registered buffer. 
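As a minimal usage sketch of the `registered` flag described above (assuming `ca` is an initialized CustomAllreduce and `x` is a CUDA tensor for which `ca.should_custom_ar(x)` is True):

    # Eager path: x is first copied into the pre-registered IPC buffer.
    out = ca.all_reduce(x, registered=False)

    # Graph-capture path: the caller guarantees x's pointer is already
    # IPC-registered (CustomAllreduce.capture() records the addresses used
    # inside the CUDA graph and registers them on exit), so the copy is skipped.
    out = ca.all_reduce(x, registered=True)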
@@ -249,8 +255,9 @@ def all_reduce(self, if registered: ops.all_reduce(self._ptr, inp, out, 0, 0) else: - ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], - self.max_size) + ops.all_reduce( + self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size + ) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -283,9 +290,11 @@ def __del__(self): self.close() @staticmethod - def create_shared_buffer(size_in_bytes: int, - group: Optional[ProcessGroup] = None, - uncached: Optional[bool] = False) -> list[int]: + def create_shared_buffer( + size_in_bytes: int, + group: Optional[ProcessGroup] = None, + uncached: Optional[bool] = False, + ) -> list[int]: pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes) world_size = dist.get_world_size(group=group) @@ -302,9 +311,11 @@ def create_shared_buffer(size_in_bytes: int, return pointers @staticmethod - def free_shared_buffer(pointers: list[int], - group: Optional[ProcessGroup] = None, - rank: Optional[int] = None) -> None: + def free_shared_buffer( + pointers: list[int], + group: Optional[ProcessGroup] = None, + rank: Optional[int] = None, + ) -> None: if rank is None: rank = dist.get_rank(group=group) if ops is not None: diff --git a/vllm/distributed/device_communicators/mnnvl_compat.py b/vllm/distributed/device_communicators/mnnvl_compat.py index 80072c4fa643..61aee2db46b8 100644 --- a/vllm/distributed/device_communicators/mnnvl_compat.py +++ b/vllm/distributed/device_communicators/mnnvl_compat.py @@ -9,7 +9,6 @@ class CustomCommunicator(CommBackend): - def __init__(self, group): self._group = group @@ -24,5 +23,5 @@ def allgather(self, data: int): dist.all_gather_object(gathered, data, group=self._group) return gathered - def Split(self, color: int, key: int) -> 'CustomCommunicator': + def Split(self, color: int, key: int) -> "CustomCommunicator": return self diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 81c02d1899e5..59fa3f9c449b 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -10,8 +10,14 @@ import vllm.envs as envs from vllm.distributed.device_communicators.pynccl_wrapper import ( - NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, - ncclRedOpTypeEnum, ncclUniqueId) + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, +) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import current_stream @@ -23,7 +29,8 @@ def register_nccl_symmetric_ops(pynccl_comm): from vllm.distributed.device_communicators.pynccl_allocator import ( - nccl_symm_mem_context) + nccl_symm_mem_context, + ) from vllm.utils import direct_register_custom_op global _NCCL_SYMM_OPS_REGISTERED @@ -31,8 +38,7 @@ def register_nccl_symmetric_ops(pynccl_comm): return _NCCL_SYMM_OPS_REGISTERED = True - def all_reduce_symmetric_with_copy_impl( - input_tensor: torch.Tensor) -> torch.Tensor: + def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor: with nccl_symm_mem_context(pynccl_comm): symm_input = torch.empty_like(input_tensor) symm_output = torch.empty_like(input_tensor) @@ -40,8 +46,7 @@ def all_reduce_symmetric_with_copy_impl( symm_output = pynccl_comm.all_reduce(symm_input, symm_output) return symm_output - def all_reduce_symmetric_with_copy_fake( - input_tensor: torch.Tensor) -> torch.Tensor: + def 
all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor: return torch.empty_like(input_tensor) direct_register_custom_op( @@ -52,7 +57,6 @@ def all_reduce_symmetric_with_copy_fake( class PyNcclCommunicator: - def __init__( self, group: Union[ProcessGroup, StatelessProcessGroup], @@ -73,7 +77,8 @@ def __init__( if not isinstance(group, StatelessProcessGroup): assert dist.is_initialized() assert dist.get_backend(group) != dist.Backend.NCCL, ( - "PyNcclCommunicator should be attached to a non-NCCL group.") + "PyNcclCommunicator should be attached to a non-NCCL group." + ) # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) @@ -132,7 +137,8 @@ def __init__( # current cuda device to the specified one with torch.cuda.device(device): self.comm: ncclComm_t = self.nccl.ncclCommInitRank( - self.world_size, self.unique_id, self.rank) + self.world_size, self.unique_id, self.rank + ) stream = current_stream() # A small all_reduce for warmup. @@ -141,11 +147,13 @@ def __init__( stream.synchronize() del data - def all_reduce(self, - in_tensor: torch.Tensor, - out_tensor: torch.Tensor = None, - op: ReduceOp = ReduceOp.SUM, - stream=None) -> torch.Tensor: + def all_reduce( + self, + in_tensor: torch.Tensor, + out_tensor: torch.Tensor = None, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ) -> torch.Tensor: if self.disabled: return None # nccl communicator created on a specific device @@ -153,25 +161,28 @@ def all_reduce(self, # otherwise it will cause "illegal memory access" assert in_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {in_tensor.device}") + f"but the input tensor is on {in_tensor.device}" + ) if out_tensor is None: out_tensor = torch.empty_like(in_tensor) if stream is None: stream = current_stream() - self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), - buffer_type(out_tensor.data_ptr()), - in_tensor.numel(), - ncclDataTypeEnum.from_torch(in_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) + self.nccl.ncclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) return out_tensor - def all_gather(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - stream=None): + def all_gather( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None + ): if self.disabled: return # nccl communicator created on a specific device @@ -179,14 +190,18 @@ def all_gather(self, # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() self.nccl.ncclAllGather( buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), input_tensor.numel(), - ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, - cudaStream_t(stream.cuda_stream)) + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) def all_gatherv( self, @@ -202,14 +217,15 @@ def all_gatherv( # otherwise it will cause "illegal 
memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() assert output_tensor.shape[0] == sum(sizes) split_offset = 0 self.nccl.ncclGroupStart() for root, split_size in enumerate(sizes): - dst_slice = output_tensor[split_offset:split_offset + split_size] + dst_slice = output_tensor[split_offset : split_offset + split_size] self.nccl.ncclBroadcast( buffer_type(input_tensor.data_ptr()), buffer_type(dst_slice.data_ptr()), @@ -222,11 +238,13 @@ def all_gatherv( split_offset += split_size self.nccl.ncclGroupEnd() - def reduce_scatter(self, - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - stream=None): + def reduce_scatter( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): if self.disabled: return # nccl communicator created on a specific device @@ -234,15 +252,19 @@ def reduce_scatter(self, # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() self.nccl.ncclReduceScatter( buffer_type(input_tensor.data_ptr()), - buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - cudaStream_t(stream.cuda_stream)) + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) def reduce_scatterv( self, @@ -259,20 +281,25 @@ def reduce_scatterv( # otherwise it will cause "illegal memory access" assert input_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {input_tensor.device}") + f"but the input tensor is on {input_tensor.device}" + ) if stream is None: stream = current_stream() split_offset = 0 self.nccl.ncclGroupStart() for root, split_size in enumerate(sizes): - chunk = input_tensor[split_offset:split_offset + split_size, ...] + chunk = input_tensor[split_offset : split_offset + split_size, ...] 
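# (One ncclReduce per destination rank, issued between ncclGroupStart() and
#  ncclGroupEnd(): rank `root` receives the reduction of every rank's `chunk`
#  slice, which is what makes this loop behave as a variable-size
#  reduce-scatter.)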
self.nccl.ncclReduce( buffer_type(chunk.data_ptr()), - buffer_type(output_tensor.data_ptr()), chunk.numel(), + buffer_type(output_tensor.data_ptr()), + chunk.numel(), ncclDataTypeEnum.from_torch(input_tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), root, self.comm, - cudaStream_t(stream.cuda_stream)) + ncclRedOpTypeEnum.from_torch(op), + root, + self.comm, + cudaStream_t(stream.cuda_stream), + ) split_offset += split_size self.nccl.ncclGroupEnd() @@ -281,31 +308,44 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None): return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() - self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), dst, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def recv(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() - self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def broadcast(self, tensor: torch.Tensor, src: int, stream=None): if self.disabled: return assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() if src == self.rank: @@ -315,9 +355,15 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): else: sendbuff = buffer_type() recvbuff = buffer_type(tensor.data_ptr()) - self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - self.comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclBroadcast( + sendbuff, + recvbuff, + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) def group_start(self): self.nccl.ncclGroupStart() @@ -334,8 +380,7 @@ def register_comm_window(self, tensor: torch.Tensor): ) def register_comm_window_raw(self, ptr: int, size: int): - return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), - size, 1) + return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1) def deregister_comm_window(self, window): return self.nccl.ncclCommWindowDeregister(self.comm, window) diff --git a/vllm/distributed/device_communicators/pynccl_allocator.py b/vllm/distributed/device_communicators/pynccl_allocator.py index bc874c1e197e..3fe4fd744d77 100644 --- a/vllm/distributed/device_communicators/pynccl_allocator.py +++ b/vllm/distributed/device_communicators/pynccl_allocator.py @@ -98,7 +98,9 @@ def compile_nccl_allocator(): "This is expected if NCCL headers are not available. 
" "optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory " "containing the NCCL header. " - "Error: %s", str(e)) + "Error: %s", + str(e), + ) def get_nccl_mem_pool(): @@ -125,21 +127,24 @@ def _cleanup_nccl_allocator_wrapper(): class nccl_symm_mem_context: - def __init__( self, pynccl_comm: PyNcclCommunicator, disabled: bool = False, ): - self.disabled = (disabled or not is_symmetric_memory_enabled() - or pynccl_comm.world_size == 1 - or not current_platform.is_cuda() - or get_nccl_mem_pool() is None or version.parse( - torch.__version__) < version.parse("2.8.0.a0")) + self.disabled = ( + disabled + or not is_symmetric_memory_enabled() + or pynccl_comm.world_size == 1 + or not current_platform.is_cuda() + or get_nccl_mem_pool() is None + or version.parse(torch.__version__) < version.parse("2.8.0.a0") + ) if self.disabled: self.pynccl_comm: Optional[PyNcclCommunicator] = None - self._mem_pool_ctx: contextlib.AbstractContextManager[ - Any] = contextlib.nullcontext() + self._mem_pool_ctx: contextlib.AbstractContextManager[Any] = ( + contextlib.nullcontext() + ) self.is_graph_capture = None self.device = None else: @@ -151,16 +156,16 @@ def __init__( def __enter__(self): if self.disabled: return self - assert ( - self.pynccl_comm - is not None), "Symmetric memory requires pynccl to be initalized" - assert ( - self.pynccl_comm.nccl_version >= 22703 - ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory" + assert self.pynccl_comm is not None, ( + "Symmetric memory requires pynccl to be initalized" + ) + assert self.pynccl_comm.nccl_version >= 22703, ( + "NCCL version 2.27.3 or higher is required for NCCL symmetric memory" + ) if self.is_graph_capture: - assert ( - _graph_pool_id - is not None), "graph_pool_id is not set under graph capture" + assert _graph_pool_id is not None, ( + "graph_pool_id is not set under graph capture" + ) # Pause graph memory pool to use symmetric memory with cuda graph torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id) self._mem_pool_ctx.__enter__() @@ -179,8 +184,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): for segment in _cached_pool_snapshot: if segment["address"] not in _registered_base_addrs: self.pynccl_comm.register_comm_window_raw( - segment["address"], segment["total_size"]) + segment["address"], segment["total_size"] + ) _registered_base_addrs.add(segment["address"]) if self.is_graph_capture: - torch._C._cuda_beginAllocateCurrentThreadToPool( - self.device, _graph_pool_id) + torch._C._cuda_beginAllocateCurrentThreadToPool(self.device, _graph_pool_id) diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 2e9a4e024de4..e4d7b0f8fb85 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -133,88 +133,141 @@ class NCCLLibrary: # const char* ncclGetErrorString(ncclResult_t result) Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), # ncclResult_t ncclGetVersion(int *version); - Function("ncclGetVersion", ncclResult_t, - [ctypes.POINTER(ctypes.c_int)]), + Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); - Function("ncclGetUniqueId", ncclResult_t, - [ctypes.POINTER(ncclUniqueId)]), + Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), # ncclResult_t ncclCommInitRank( # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); # note that ncclComm_t 
is a pointer type, so the first argument # is a pointer to a pointer - Function("ncclCommInitRank", ncclResult_t, [ - ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, - ctypes.c_int - ]), + Function( + "ncclCommInitRank", + ncclResult_t, + [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int], + ), # ncclResult_t ncclAllReduce( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclAllReduce", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclAllReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclReduce( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, int root, # ncclComm_t comm, cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclReduce", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclAllGather( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclAllGather", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclAllGather", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclReduceScatter( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, # cudaStream_t stream); # note that cudaStream_t is a pointer type, so the last argument # is a pointer - Function("ncclReduceScatter", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclReduceScatter", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); - Function("ncclSend", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclSend", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclRecv( # void* recvbuff, size_t count, ncclDataType_t datatype, # int src, ncclComm_t comm, cudaStream_t stream); - Function("ncclRecv", ncclResult_t, [ - buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, - ncclComm_t, cudaStream_t - ]), - + Function( + "ncclRecv", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclBroadcast( # const void* sendbuff, void* recvbuff, size_t count, 
# ncclDataType_t datatype, int root, ncclComm_t comm, # cudaStream_t stream); - Function("ncclBroadcast", ncclResult_t, [ - buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, - ctypes.c_int, ncclComm_t, cudaStream_t - ]), - + Function( + "ncclBroadcast", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -241,8 +294,7 @@ class NCCLLibrary: ), # ncclResult_t ncclCommWindowDeregister( # ncclComm_t comm, ncclWindow_t win); - Function("ncclCommWindowDeregister", ncclResult_t, - [ncclComm_t, ncclWindow_t]), + Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]), ] # class attribute to store the mapping from the path to the library @@ -254,7 +306,6 @@ class NCCLLibrary: path_to_dict_mapping: dict[str, dict[str, Any]] = {} def __init__(self, so_file: Optional[str] = None): - so_file = so_file or find_nccl_library() try: @@ -270,8 +321,10 @@ def __init__(self, so_file: Optional[str] = None): "or it does not support the current platform %s. " "If you already have the library, please set the " "environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) + " to point to the correct nccl library path.", + so_file, + platform.platform(), + ) raise e if so_file not in NCCLLibrary.path_to_dict_mapping: @@ -284,15 +337,18 @@ def __init__(self, so_file: Optional[str] = None): _funcs[func.name] = f except AttributeError: if func.name in [ - "ncclCommWindowRegister", - "ncclCommWindowDeregister" + "ncclCommWindowRegister", + "ncclCommWindowDeregister", ]: if envs.VLLM_USE_NCCL_SYMM_MEM: logger.warning_once( "The symbol %s is not found in the NCCL " "library %s. 
To enable VLLM_USE_NCCL_SYMM_MEM " " please update your NCCL version to >= " - "2.27.03.", func.name, so_file) + "2.27.03.", + func.name, + so_file, + ) if current_platform.is_rocm(): # Having an exception here on ROCm platform is # not allowed during graph capturing @@ -325,88 +381,153 @@ def ncclGetVersion(self) -> str: def ncclGetUniqueId(self) -> ncclUniqueId: unique_id = ncclUniqueId() - self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( - ctypes.byref(unique_id))) + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) return unique_id def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: if len(data) != 128: raise ValueError( - f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes") + f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes" + ) unique_id = ncclUniqueId() ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128) return unique_id - def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, - rank: int) -> ncclComm_t: + def ncclCommInitRank( + self, world_size: int, unique_id: ncclUniqueId, rank: int + ) -> ncclComm_t: comm = ncclComm_t() - self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), - world_size, unique_id, - rank)) + self.NCCL_CHECK( + self._funcs["ncclCommInitRank"]( + ctypes.byref(comm), world_size, unique_id, rank + ) + ) return comm - def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + def ncclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, - datatype, op, comm, - stream)) - - def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, root: int, - comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK( + self._funcs["ncclAllReduce"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count, - datatype, op, root, comm, - stream)) - - def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, op: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + self.NCCL_CHECK( + self._funcs["ncclReduce"]( + sendbuff, recvbuff, count, datatype, op, root, comm, stream + ) + ) + + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # and `op` should be `ncclRedOp_t` # both are aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - 
self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, - count, datatype, op, - comm, stream)) - - def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: + self.NCCL_CHECK( + self._funcs["ncclReduceScatter"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: # `datatype` actually should be `ncclDataType_t` # which is an aliases of `ctypes.c_int` # when we pass int to a function, it will be converted to `ctypes.c_int` # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, - datatype, comm, stream)) - - def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, - dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, - dest, comm, stream)) - - def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, - src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, - comm, stream)) - - def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, root: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, - datatype, root, comm, - stream)) + self.NCCL_CHECK( + self._funcs["ncclAllGather"]( + sendbuff, recvbuff, count, datatype, comm, stream + ) + ) + + def ncclSend( + self, + sendbuff: buffer_type, + count: int, + datatype: int, + dest: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream) + ) + + def ncclRecv( + self, + recvbuff: buffer_type, + count: int, + datatype: int, + src: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) + ) + + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclBroadcast"]( + sendbuff, recvbuff, count, datatype, root, comm, stream + ) + ) def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) @@ -417,19 +538,27 @@ def ncclGroupStart(self) -> None: def ncclGroupEnd(self) -> None: self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) - def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type, - size: int, win_flags: int) -> ncclWindow_t: + def ncclCommWindowRegister( + self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int + ) -> ncclWindow_t: window = ncclWindow_t() - self.NCCL_CHECK(self._funcs["ncclCommWindowRegister"]( - comm, buff, size, ctypes.byref(window), win_flags)) + self.NCCL_CHECK( + self._funcs["ncclCommWindowRegister"]( + comm, buff, size, ctypes.byref(window), win_flags + ) + ) return window - def ncclCommWindowDeregister(self, comm: ncclComm_t, - window: ncclWindow_t) -> None: + def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window)) __all__ = [ - "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", - 
"ncclComm_t", "cudaStream_t", "buffer_type" + "NCCLLibrary", + "ncclDataTypeEnum", + "ncclRedOpTypeEnum", + "ncclUniqueId", + "ncclComm_t", + "cudaStream_t", + "buffer_type", ] diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 836241910e2f..16b6b6c28ea3 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -27,9 +27,10 @@ def is_weak_contiguous(inp: torch.Tensor): - return inp.is_contiguous() or (inp.storage().nbytes() - - inp.storage_offset() * inp.element_size() - == inp.numel() * inp.element_size()) + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size() + ) class QuickReduceRegime(Enum): @@ -44,7 +45,6 @@ class QuickReduceRegime(Enum): class QuickAllReduce: - _SUPPORTED_WORLD_SIZES = [2, 4, 8] _SUPPORTED_DTYPES = [torch.float16, torch.bfloat16] # The following data is based on kernel tests. @@ -58,20 +58,21 @@ class QuickAllReduce: (torch.bfloat16, 8): [16 * MB, 2048 * MB, 2048 * MB, 2048 * MB], } - def __init__(self, group: ProcessGroup, - device: Union[int, str, torch.device]) -> None: + def __init__( + self, group: ProcessGroup, device: Union[int, str, torch.device] + ) -> None: """ - Custom allreduce provides non-destructive acceleration and is + Custom allreduce provides non-destructive acceleration and is available for CUDA and ROCm MI300 series. - Custom quick allreduce leverages quantization for further - acceleration on ROCm. It currently supports Q8, Q6, and Q4 + Custom quick allreduce leverages quantization for further + acceleration on ROCm. It currently supports Q8, Q6, and Q4 quantization formats and FP(float16, bfloat16). - Quick allreduce is designed as a complement to custom allreduce. - Its initialization requires even stricter conditions. + Quick allreduce is designed as a complement to custom allreduce. + Its initialization requires even stricter conditions. - Only the ROCm MI300 series is supported for quick allreduce at + Only the ROCm MI300 series is supported for quick allreduce at this time. Args: @@ -93,18 +94,23 @@ def __init__(self, group: ProcessGroup, if not quick_ar: # disable because of missing quick reduce library # e.g. in a cuda environment - logger.info("Custom quick allreduce is disabled because " - "of missing custom quick allreduce library") + logger.info( + "Custom quick allreduce is disabled because " + "of missing custom quick allreduce library" + ) return self.group = group assert dist.get_backend(group) != dist.Backend.NCCL, ( - "Custom quick allreduce should be attached to a non-NCCL group.") + "Custom quick allreduce should be attached to a non-NCCL group." + ) if not all(in_the_same_node_as(group, source_rank=0)): # No need to initialize custom quick allreduce for # multi-node case. - logger.warning("Custom quick allreduce is disabled because this " - "process group spans across nodes.") + logger.warning( + "Custom quick allreduce is disabled because this " + "process group spans across nodes." + ) return rank = dist.get_rank(group=self.group) world_size = dist.get_world_size(group=self.group) @@ -118,7 +124,9 @@ def __init__(self, group: ProcessGroup, logger.warning( "Custom quick allreduce is disabled due to an " "unsupported world size: %d. 
Supported world sizes: %s.", - world_size, str(QuickAllReduce._SUPPORTED_WORLD_SIZES)) + world_size, + str(QuickAllReduce._SUPPORTED_WORLD_SIZES), + ) return if isinstance(device, int): @@ -134,9 +142,7 @@ def __init__(self, group: ProcessGroup, else: device_ids = list(range(cuda_device_count_stateless())) physical_device_id = device_ids[device.index] - tensor = torch.tensor([physical_device_id], - dtype=torch.int, - device="cpu") + tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") gather_list = [ torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(self.world_size) @@ -148,12 +154,12 @@ def __init__(self, group: ProcessGroup, # where custom quick allreduce is not supported # this checks hardware and driver support for NVLink assert current_platform.is_cuda_alike() - self.fully_connected = current_platform.is_fully_connected( - physical_device_ids) + self.fully_connected = current_platform.is_fully_connected(physical_device_ids) if self.world_size > 2 and not self.fully_connected: logger.debug( "Custom quick allreduce is disabled because it's not supported " - "on more than two PCIe-only GPUs. ") + "on more than two PCIe-only GPUs. " + ) return self.init_quick_all_reduce() @@ -169,24 +175,31 @@ def init_quick_all_reduce(self): "Custom quick allreduce:", f"Invalid quantization level: {regime_str}. " "Supported levels: " - f"{list(QuickReduceRegime.__members__.keys())}") + f"{list(QuickReduceRegime.__members__.keys())}", + ) return if regime_str == "NONE": - logger.debug("Custom quick allreduce is disabled based " - "on env variable " - "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'") + logger.debug( + "Custom quick allreduce is disabled based " + "on env variable " + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION='NONE'" + ) return self.qr_quant_level = QuickReduceRegime[regime_str] vllm_config = get_current_vllm_config() - if vllm_config is not None and \ - hasattr(vllm_config, "model_config") and \ - hasattr(vllm_config.model_config, "dtype"): + if ( + vllm_config is not None + and hasattr(vllm_config, "model_config") + and hasattr(vllm_config.model_config, "dtype") + ): dtype = vllm_config.model_config.dtype if dtype not in [torch.float16, torch.bfloat16]: logger.debug( "Custom quick allreduce disabled: only supports " - "float16 and float16, but get %s.", dtype) + "float16 and float16, but get %s.", + dtype, + ) return if dtype == torch.bfloat16 and self.use_fp16_kernels: @@ -194,7 +207,8 @@ def init_quick_all_reduce(self): "Custom quick allreduce: BF16 inputs will be converted " "to FP16 to improve performance. set " "envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=0 " - "to turn off.") + "to turn off." 
+ ) # VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is specified in MB qr_max_size = envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB @@ -206,8 +220,7 @@ def init_quick_all_reduce(self): ) qr_max_size = qr_max_size * MB self._ptr = ops.init_custom_qr(self.rank, self.world_size, qr_max_size) - self.qr_max_size = qr_max_size if qr_max_size is not None \ - else ops.qr_max_size() + self.qr_max_size = qr_max_size if qr_max_size is not None else ops.qr_max_size() self.create_shared_buffer() self.disabled = False @@ -217,16 +230,15 @@ def _rocm_arch_available(self): try: props = torch.cuda.get_device_properties(0) gcn_arch = getattr(props, "gcnArchName", "") - supported_archs = ['gfx94', 'gfx95'] + supported_archs = ["gfx94", "gfx95"] return any(gfx in gcn_arch for gfx in supported_archs) except Exception as e: - logger.warning("Failed to determine ROCm for quick allreduce: %s", - e) + logger.warning("Failed to determine ROCm for quick allreduce: %s", e) return False def create_shared_buffer(self): """ - Creates a shared buffer for quickreduce. + Creates a shared buffer for quickreduce. Has to be called after init_custom_qr """ handle = ops.qr_get_handle(self._ptr) @@ -253,9 +265,11 @@ def should_quick_allreduce(self, inp: torch.Tensor): dtype = inp.dtype if self.use_fp16_kernels: dtype = torch.float16 - return inp_size <= self.qr_max_size and \ - inp_size >= self._QR_MIN_SIZE[(dtype, self.world_size)]\ - [self.qr_quant_level.value] + return ( + inp_size <= self.qr_max_size + and inp_size + >= self._QR_MIN_SIZE[(dtype, self.world_size)][self.qr_quant_level.value] + ) def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): """Performs an out-of-place custom quick all reduce.""" @@ -263,8 +277,9 @@ def quick_all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None): # as QR uses static IPC buffer. 
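An illustrative sketch of the size gate in `should_quick_allreduce` above. The thresholds here are made up (not the real `_QR_MIN_SIZE` table): the point is only that a tensor qualifies when it is large enough for the chosen quantization level yet still fits the static IPC buffer.

```python
# Hypothetical thresholds, one minimum per quantization level (FP, INT8, INT6, INT4).
MB = 1024 * 1024
MIN_SIZE_BY_LEVEL = [1 * MB, 2 * MB, 2 * MB, 2 * MB]

def should_quick_allreduce(nbytes: int, quant_level: int,
                           qr_max_size: int = 2048 * MB) -> bool:
    # Big enough to amortize quantization, small enough for the IPC buffer.
    return MIN_SIZE_BY_LEVEL[quant_level] <= nbytes <= qr_max_size

print(should_quick_allreduce(16 * MB, quant_level=1))    # True
print(should_quick_allreduce(64 * 1024, quant_level=1))  # False: below the minimum
```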
if out is None: out = torch.empty_like(inp) - ops.qr_all_reduce(self._ptr, inp, out, self.qr_quant_level.value, - self.use_fp16_kernels) + ops.qr_all_reduce( + self._ptr, inp, out, self.qr_quant_level.value, self.use_fp16_kernels + ) return out def close(self): diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py index 69efc8b45270..da79afc7ac14 100644 --- a/vllm/distributed/device_communicators/ray_communicator.py +++ b/vllm/distributed/device_communicators/ray_communicator.py @@ -6,12 +6,12 @@ import ray import torch from ray.exceptions import RayChannelError -from ray.experimental.channel.communicator import (Communicator, - TorchTensorAllocator) +from ray.experimental.channel.communicator import Communicator, TorchTensorAllocator from torch.distributed import ReduceOp from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase) + DeviceCommunicatorBase, +) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.utils import current_stream @@ -59,11 +59,11 @@ def __init__( self._rank: Optional[int] = None self._actor_handles = actor_handles if use_communication_streams: - raise NotImplementedError( - "use_communication_streams is not supported") + raise NotImplementedError("use_communication_streams is not supported") if cuda_stream is not None and cuda_stream != current_stream(): raise ValueError( - "cuda_stream other than the current stream is not supported") + "cuda_stream other than the current stream is not supported" + ) if rank is not None: # Rank is not None, this is Ray worker @@ -99,13 +99,14 @@ def _build_actor_rank_mapping(self): # Ray actor IDs are 32-character hex strings (128 bits) ACTOR_ID_LEN = 32 - actor_id_bytes = actor_id_str.encode('utf-8') - assert len( - actor_id_bytes - ) == ACTOR_ID_LEN, f"Unexpected actor ID length: {len(actor_id_bytes)}" + actor_id_bytes = actor_id_str.encode("utf-8") + assert len(actor_id_bytes) == ACTOR_ID_LEN, ( + f"Unexpected actor ID length: {len(actor_id_bytes)}" + ) - actor_id_tensor = torch.frombuffer( - actor_id_bytes, dtype=torch.uint8).to(self._comm.device) + actor_id_tensor = torch.frombuffer(actor_id_bytes, dtype=torch.uint8).to( + self._comm.device + ) # All-gather full actor IDs from all actors gathered_ids = self._comm.all_gather(actor_id_tensor, dim=0) @@ -115,9 +116,8 @@ def _build_actor_rank_mapping(self): for rank in range(self._world_size): start_idx = rank * ACTOR_ID_LEN end_idx = (rank + 1) * ACTOR_ID_LEN - actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy( - ).tobytes() - actor_id = actor_bytes.decode('utf-8') + actor_bytes = gathered_ids[start_idx:end_idx].cpu().numpy().tobytes() + actor_id = actor_bytes.decode("utf-8") self._actor_id_to_rank[actor_id] = rank def initialize(self, rank: int) -> None: @@ -131,9 +131,10 @@ def get_rank(self, actor: ray.actor.ActorHandle) -> int: """ Return the given actor's rank using device communicator collective ops. """ - assert hasattr(self, '_actor_id_to_rank'), ( + assert hasattr(self, "_actor_id_to_rank"), ( "Actor rank mapping not built. " - "This should have been done during initialization.") + "This should have been done during initialization." 
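A local, single-process sketch (no Ray, fake ids) of the id-exchange trick in `_build_actor_rank_mapping` above: fixed-length actor ids are encoded as uint8 tensors, an all-gather concatenates them, and the result is sliced back apart to build the id → rank mapping.

```python
import torch

ACTOR_ID_LEN = 32
actor_ids = ["a" * ACTOR_ID_LEN, "b" * ACTOR_ID_LEN]  # pretend ids, one per rank

# What an all-gather along dim 0 would return: every rank's id bytes back to back.
gathered = torch.cat(
    [torch.frombuffer(i.encode("utf-8"), dtype=torch.uint8) for i in actor_ids]
)

actor_id_to_rank = {}
for rank in range(len(actor_ids)):
    chunk = gathered[rank * ACTOR_ID_LEN:(rank + 1) * ACTOR_ID_LEN]
    actor_id_to_rank[chunk.numpy().tobytes().decode("utf-8")] = rank

print(actor_id_to_rank)  # {'aaaa...': 0, 'bbbb...': 1}
```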
+ ) actor_id_str = actor._actor_id.hex() diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 0fc9d1cf4f51..4cec60102728 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -14,14 +14,24 @@ import torch.distributed as dist import zmq from torch.distributed import ProcessGroup -from zmq import IPV6 # type: ignore -from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore +from zmq import ( # type: ignore + IPV6, # type: ignore + SUB, + SUBSCRIBE, + XPUB, + XPUB_VERBOSE, + Context, +) import vllm.envs as envs from vllm.distributed.utils import StatelessProcessGroup, sched_yield from vllm.logger import init_logger -from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, - is_valid_ipv6_address) +from vllm.utils import ( + get_ip, + get_open_port, + get_open_zmq_ipc_path, + is_valid_ipv6_address, +) VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL @@ -29,7 +39,6 @@ class SpinTimer: - def record_activity(self): pass @@ -66,12 +75,13 @@ def spin(self): class ShmRingBuffer: - - def __init__(self, - n_reader: int, - max_chunk_bytes: int, - max_chunks: int, - name: Optional[str] = None): + def __init__( + self, + n_reader: int, + max_chunk_bytes: int, + max_chunks: int, + name: Optional[str] = None, + ): """ A shared memory ring buffer implementation for broadcast communication. Essentially, it is a queue where only one will `enqueue` and multiple @@ -120,13 +130,14 @@ def __init__(self, created object to other processes by pickling it. The other processes will get the name of the shared memory and open it, so that they can access the same shared memory buffer. - """# noqa + """ # noqa self.n_reader = n_reader self.metadata_size = 1 + n_reader self.max_chunk_bytes = max_chunk_bytes self.max_chunks = max_chunks - self.total_bytes_of_buffer = (self.max_chunk_bytes + - self.metadata_size) * self.max_chunks + self.total_bytes_of_buffer = ( + self.max_chunk_bytes + self.metadata_size + ) * self.max_chunks self.data_offset = 0 self.metadata_offset = self.max_chunk_bytes * self.max_chunks @@ -134,10 +145,10 @@ def __init__(self, # we are creating a buffer self.is_creator = True self.shared_memory = shared_memory.SharedMemory( - create=True, size=self.total_bytes_of_buffer) + create=True, size=self.total_bytes_of_buffer + ) # initialize the metadata section to 0 - with self.shared_memory.buf[self. - metadata_offset:] as metadata_buffer: + with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer: torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0) else: # we are opening an existing buffer @@ -145,8 +156,10 @@ def __init__(self, # fix to https://stackoverflow.com/q/62748654/9191338 # Python incorrectly tracks shared memory even if it is not # created by the process. The following patch is a workaround. - with patch("multiprocessing.resource_tracker.register", - lambda *args, **kwargs: None): + with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): try: self.shared_memory = shared_memory.SharedMemory(name=name) # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa @@ -154,8 +167,7 @@ def __init__(self, # so the shared memory block size may be larger or equal # to the requested size. The size parameter is ignored # when attaching to an existing block. 
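A worked example (assumed sizes, not vLLM defaults) of the `ShmRingBuffer` layout computed above: all data chunks come first, followed by one metadata block per chunk, where each metadata block is one "written" flag byte plus one flag byte per reader.

```python
n_reader, max_chunk_bytes, max_chunks = 2, 1024, 4

metadata_size = 1 + n_reader                                   # 3 bytes per chunk
total_bytes = (max_chunk_bytes + metadata_size) * max_chunks   # 4108 bytes overall
data_offset = 0                                                # chunk 0 starts here
metadata_offset = max_chunk_bytes * max_chunks                 # metadata region at 4096

print(total_bytes, metadata_offset)  # 4108 4096
```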
- assert (self.shared_memory.size - >= self.total_bytes_of_buffer) + assert self.shared_memory.size >= self.total_bytes_of_buffer except FileNotFoundError: # we might deserialize the object in a different node # in this case, this object is not used, @@ -163,8 +175,12 @@ def __init__(self, pass def handle(self): - return (self.n_reader, self.max_chunk_bytes, self.max_chunks, - self.shared_memory.name) + return ( + self.n_reader, + self.max_chunk_bytes, + self.max_chunks, + self.shared_memory.name, + ) def __reduce__(self): return ( @@ -204,7 +220,6 @@ class Handle: class MessageQueue: - def __init__( self, n_reader, # number of all readers @@ -228,8 +243,7 @@ def __init__( # for local readers, we will: # 1. create a shared memory ring buffer to communicate small data # 2. create a publish-subscribe socket to communicate large data - self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, - max_chunks) + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks) # XPUB is very similar to PUB, # except that it can receive subscription messages @@ -279,8 +293,7 @@ def __init__( self.handle = Handle( local_reader_ranks=local_reader_ranks, - buffer_handle=self.buffer.handle() - if self.buffer is not None else None, + buffer_handle=self.buffer.handle() if self.buffer is not None else None, local_subscribe_addr=local_subscribe_addr, remote_subscribe_addr=remote_subscribe_addr, remote_addr_ipv6=remote_addr_ipv6, @@ -315,8 +328,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.remote_socket = None - self._read_spin_timer = SpinSleepTimer( - ) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() + self._read_spin_timer = ( + SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() + ) else: self.buffer = None # type: ignore self.current_idx = -1 @@ -399,7 +413,8 @@ def acquire_write(self, timeout: Optional[float] = None): " in %s seconds. This typically happens when some" " processes are hanging or doing some" " time-consuming work (e.g. compilation)", - VLLM_RINGBUFFER_WARNING_INTERVAL) + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) n_warning += 1 continue @@ -423,15 +438,16 @@ def acquire_write(self, timeout: Optional[float] = None): metadata_buffer[i] = 0 # mark the block as written metadata_buffer[0] = 1 - self.current_idx = (self.current_idx + - 1) % self.buffer.max_chunks + self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break @contextmanager - def acquire_read(self, - timeout: Optional[float] = None, - cancel: Optional[Event] = None, - indefinite: bool = False): + def acquire_read( + self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None, + indefinite: bool = False, + ): assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() n_warning = 1 @@ -460,15 +476,16 @@ def acquire_read(self, raise TimeoutError # if we wait for a long time, log a message - if not indefinite and (elapsed - > VLLM_RINGBUFFER_WARNING_INTERVAL * - n_warning): + if not indefinite and ( + elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning + ): logger.info( "No available shared memory broadcast block found" " in %s seconds. This typically happens when some" " processes are hanging or doing some" " time-consuming work (e.g. 
compilation).", - VLLM_RINGBUFFER_WARNING_INTERVAL) + VLLM_RINGBUFFER_WARNING_INTERVAL, + ) n_warning += 1 continue @@ -480,14 +497,13 @@ def acquire_read(self, # caller has read from the buffer # set the read flag metadata_buffer[self.local_reader_rank + 1] = 1 - self.current_idx = (self.current_idx + - 1) % self.buffer.max_chunks + self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks self._read_spin_timer.record_activity() break def enqueue(self, obj, timeout: Optional[float] = None): - """ Write to message queue with optional timeout (in seconds) """ + """Write to message queue with optional timeout (in seconds)""" assert self._is_writer, "Only writers can enqueue" serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) if self.n_local_reader > 0: @@ -498,15 +514,17 @@ def enqueue(self, obj, timeout: Optional[float] = None): else: with self.acquire_write(timeout) as buf: buf[0] = 0 # not overflow - buf[1:len(serialized_obj) + 1] = serialized_obj + buf[1 : len(serialized_obj) + 1] = serialized_obj if self.n_remote_reader > 0: self.remote_socket.send(serialized_obj) - def dequeue(self, - timeout: Optional[float] = None, - cancel: Optional[Event] = None, - indefinite: bool = False): - """ Read from message queue with optional timeout (in seconds) """ + def dequeue( + self, + timeout: Optional[float] = None, + cancel: Optional[Event] = None, + indefinite: bool = False, + ): + """Read from message queue with optional timeout (in seconds)""" if self._is_local_reader: with self.acquire_read(timeout, cancel, indefinite) as buf: overflow = buf[0] == 1 @@ -539,11 +557,12 @@ def broadcast_object(self, obj=None): return self.dequeue() @staticmethod - def create_from_process_group(pg: Union[ProcessGroup, - StatelessProcessGroup], - max_chunk_bytes, - max_chunks, - writer_rank=0) -> "MessageQueue": + def create_from_process_group( + pg: Union[ProcessGroup, StatelessProcessGroup], + max_chunk_bytes, + max_chunks, + writer_rank=0, + ) -> "MessageQueue": if isinstance(pg, ProcessGroup): group_rank = dist.get_rank(pg) group_world_size = dist.get_world_size(pg) @@ -554,6 +573,7 @@ def create_from_process_group(pg: Union[ProcessGroup, global_ranks = list(range(pg.world_size)) from vllm.distributed.parallel_state import in_the_same_node_as + status = in_the_same_node_as(pg, source_rank=writer_rank) same_node_ranks = [i for i, s in enumerate(status) if s] n_reader = group_world_size - 1 @@ -570,17 +590,17 @@ def create_from_process_group(pg: Union[ProcessGroup, ) handle = buffer_io.export_handle() if isinstance(pg, ProcessGroup): - dist.broadcast_object_list([handle], - src=global_ranks[writer_rank], - group=pg) + dist.broadcast_object_list( + [handle], src=global_ranks[writer_rank], group=pg + ) else: pg.broadcast_obj(handle, writer_rank) else: if isinstance(pg, ProcessGroup): recv = [None] - dist.broadcast_object_list(recv, - src=global_ranks[writer_rank], - group=pg) + dist.broadcast_object_list( + recv, src=global_ranks[writer_rank], group=pg + ) handle = recv[0] # type: ignore else: handle = pg.broadcast_obj(None, writer_rank) diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py index 0310fc14da25..a5486c30edf2 100644 --- a/vllm/distributed/device_communicators/shm_object_storage.py +++ b/vllm/distributed/device_communicators/shm_object_storage.py @@ -24,63 +24,63 @@ class SingleWriterShmRingBuffer: A single-writer, multiple-reader ring buffer implementation using shared memory. 
This class provides a thread-safe ring buffer where one process can write data while multiple processes/threads can read from it. - + Architecture: - Uses shared memory for cross-process communication - Maintains metadata for each allocated buffer chunk in the writer process - Supports custom "is_free_fn" functions to determine when buffers can be reused - Each buffer chunk contains: `[4-byte id][4-byte size][actual_data]` - + Key Concepts: - monotonic_id_start/end: Track the range of active buffer IDs - data_buffer_start/end: Track the physical memory range in use - Automatic wraparound when reaching buffer end - Lazy garbage collection based on is_free_fn checks - + Example Usage Scenarios: - + Scenario 1: Simple Linear Allocation ``` Buffer size: 100 bytes Initial state: [................................................. ] ^start=end(0) - + After allocating 20 bytes (id=0): [id:0|size:20|data........][...................................] ^start(0) ^end(28) - - After allocating 30 bytes (id=1): + + After allocating 30 bytes (id=1): [id:0|size:20|data........][id:1|size:30|data..............][..] ^start(0) ^end(66) ``` - + Scenario 2: Memory Reclamation ``` Before freeing (both buffers still in use): [id:0|size:20|data........][id:1|size:30|data..............][..] ^start(0) ^end(66) - + After id:0 is marked free by readers: [FREED.................... ][id:1|size:30|data..............][..] ^start(28) ^end(66) - + After both are freed: [FREED..............................................][..] ^start=end(66) ``` - + Scenario 3: Wraparound Allocation (continuing from Scenario 2) ``` Starting from after memory reclamation in Scenario 2: [FREED..............................................][..] ^start=end(66) - + Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound: [id:2|size:40|data........................][FREED.............][..] ^end(148) ^start(66) ``` - + Scenario 4: Error Handling - Out of Space ``` Starting from after wraparound allocation in Scenario 3: @@ -91,17 +91,17 @@ class SingleWriterShmRingBuffer: occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100) -> Raises MemoryError: "Not enough space in the data buffer" ``` - + Thread Safety: - Single writer: Only one process/thread should write (allocate_buf) - - Multiple readers: Multiple processes/threads can read (access_buf) + - Multiple readers: Multiple processes/threads can read (access_buf) - Reader synchronization handled by is_free_fn callback - Writer handles garbage collection (free_buf) based on reader feedback - + Memory Layout per Buffer Chunk: `[4-byte monotonic_id][4-byte chunk_size][actual_data...]` ^metadata_start ^data_start - + The monotonic_id ensures data integrity - readers can verify they're accessing the correct data even after buffer wraparound or reuse. """ @@ -131,15 +131,16 @@ def __init__( self.monotonic_id_end: self.data_buffer_end } # monotonic_id -> start address self.shared_memory = shared_memory.SharedMemory( - create=True, size=self.data_buffer_size, name=name) + create=True, size=self.data_buffer_size, name=name + ) else: # we are opening an existing buffer # fix to https://stackoverflow.com/q/62748654/9191338 # Python incorrectly tracks shared memory even if it is not # created by the process. The following patch is a workaround. 
with patch( - "multiprocessing.resource_tracker.register", - lambda *args, **kwargs: None, + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, ): self.shared_memory = shared_memory.SharedMemory(name=name) # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa @@ -149,8 +150,11 @@ def __init__( # when attaching to an existing block. assert self.shared_memory.size >= self.data_buffer_size - logger.debug("Shared memory created/opened with name: %s, size: %d", - self.shared_memory.name, self.data_buffer_size) + logger.debug( + "Shared memory created/opened with name: %s, size: %d", + self.shared_memory.name, + self.data_buffer_size, + ) def handle(self): return ( @@ -182,19 +186,20 @@ def byte2int(self, byte_data: bytes) -> int: return int.from_bytes(byte_data, "little", signed=True) def allocate_buf(self, size: int) -> tuple[int, int]: - ''' + """ Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory. Memory layout: `[4-byte monotonic_id][4-byte size][buffer data...]` - ''' + """ assert self.is_writer, "Only the writer can allocate buffers." assert size > 0, "Size must be greater than 0" size += self.MD_SIZE # add metadata size to the buffer size # reset to beginning if the buffer does have enough contiguous space buffer_end_reset = self.data_buffer_end % self.data_buffer_size if buffer_end_reset + size > self.data_buffer_size: - buffer_end_reset = (self.data_buffer_end // self.data_buffer_size + - 1) * self.data_buffer_size + buffer_end_reset = ( + self.data_buffer_end // self.data_buffer_size + 1 + ) * self.data_buffer_size else: # no reset needed buffer_end_reset = self.data_buffer_end @@ -203,21 +208,24 @@ def allocate_buf(self, size: int) -> tuple[int, int]: # exceeds the start of the data buffer occupied_size_new = buffer_end_reset + size - self.data_buffer_start if occupied_size_new > self.data_buffer_size: - raise MemoryError("Not enough space in the data buffer, " - "try calling free_buf() to free up space") + raise MemoryError( + "Not enough space in the data buffer, " + "try calling free_buf() to free up space" + ) self.data_buffer_end = buffer_end_reset # first 4 bytes as the monotonic id buf_idx = self.data_buffer_end % self.data_buffer_size - self.shared_memory.buf[buf_idx:buf_idx + self.ID_NBYTES] = \ - self.int2byte(self.monotonic_id_end) + self.shared_memory.buf[buf_idx : buf_idx + self.ID_NBYTES] = self.int2byte( + self.monotonic_id_end + ) # next 4 bytes as the size of the data buffer - self.shared_memory.buf[buf_idx + self.ID_NBYTES: \ - buf_idx + self.MD_SIZE] = self.int2byte(size) + self.shared_memory.buf[buf_idx + self.ID_NBYTES : buf_idx + self.MD_SIZE] = ( + self.int2byte(size) + ) # record metadata - self.metadata[self.monotonic_id_end % - self.ID_MAX] = self.data_buffer_end + self.metadata[self.monotonic_id_end % self.ID_MAX] = self.data_buffer_end # update buffer and monotonic id indices current_buffer_end = self.data_buffer_end current_id_end = self.monotonic_id_end @@ -230,23 +238,26 @@ def access_buf(self, address: int): buf_idx = address % self.data_buffer_size # read metadata - metadata_buff = self.shared_memory.buf[buf_idx:buf_idx + self.MD_SIZE] - id = self.byte2int(metadata_buff[:self.ID_NBYTES]) - size = self.byte2int(metadata_buff[self.ID_NBYTES:self.MD_SIZE]) + metadata_buff = self.shared_memory.buf[buf_idx : buf_idx + self.MD_SIZE] + id = self.byte2int(metadata_buff[: self.ID_NBYTES]) + size = self.byte2int(metadata_buff[self.ID_NBYTES : self.MD_SIZE]) # yield the data buffer and metadata - 
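A minimal sketch of the per-chunk header handling in `allocate_buf` above, using a plain bytearray in place of the shared memory: a 4-byte little-endian monotonic id followed by a 4-byte little-endian size, then the payload, mirroring the `[4-byte id][4-byte size][data]` layout described in the class docstring.

```python
ID_NBYTES = 4
MD_SIZE = 2 * ID_NBYTES  # id + size, 4 bytes each

buf = bytearray(64)
monotonic_id, payload = 7, b"hello"
size = MD_SIZE + len(payload)

buf[0:ID_NBYTES] = monotonic_id.to_bytes(ID_NBYTES, "little", signed=True)
buf[ID_NBYTES:MD_SIZE] = size.to_bytes(ID_NBYTES, "little", signed=True)
buf[MD_SIZE:size] = payload

# A reader recovers the header before touching the data.
read_id = int.from_bytes(buf[0:ID_NBYTES], "little", signed=True)
read_size = int.from_bytes(buf[ID_NBYTES:MD_SIZE], "little", signed=True)
print(read_id, read_size, bytes(buf[MD_SIZE:read_size]))  # 7 13 b'hello'
```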
data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE:buf_idx + - size] - with (memoryview(data_buff) as data_view, ): + data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE : buf_idx + size] + with ( + memoryview(data_buff) as data_view, + ): yield data_view, (id, size) - def free_buf(self, - is_free_fn: Callable[[int, memoryview], bool], - nbytes: Optional[int] = None) -> Iterable[int]: - ''' + def free_buf( + self, + is_free_fn: Callable[[int, memoryview], bool], + nbytes: Optional[int] = None, + ) -> Iterable[int]: + """ Free a buffer of the given size. This is a no-op in shared memory, but we need to keep track of the metadata. - + If freed memory spreads across the end and start of the ring buffer, the actual freed memory will be in two segments. In this case there still might not be a contiguous space of `nbytes` available. @@ -254,13 +265,15 @@ def free_buf(self, Args: nbytes (int, optional): The size of the buffer to free. If None, frees the maximum size of the ring buffer. - ''' + """ assert self.is_writer, "Only the writer can free buffers." logger.debug( "Freeing up space in the ring buffer, " "monotonic_id_start: %d, monotonic_id_end: %d", - self.monotonic_id_start, self.monotonic_id_end) + self.monotonic_id_start, + self.monotonic_id_end, + ) monotonic_id_before = self.monotonic_id_start # if nbytes is None, free up the maximum size of the ring buffer if nbytes is None: @@ -272,8 +285,9 @@ def free_buf(self, if is_free_fn(self.monotonic_id_start, data_buff): # check passed, we can free the buffer del self.metadata[self.monotonic_id_start] - self.monotonic_id_start = ((self.monotonic_id_start + 1) % - self.ID_MAX) + self.monotonic_id_start = ( + self.monotonic_id_start + 1 + ) % self.ID_MAX self.data_buffer_start = address freed_bytes += metadata[1] else: @@ -282,8 +296,11 @@ def free_buf(self, logger.debug( "Freed %d bytes from the ring buffer, " - "monotonic_id_start: %d, monotonic_id_end: %d", freed_bytes, - self.monotonic_id_start, self.monotonic_id_end) + "monotonic_id_start: %d, monotonic_id_end: %d", + freed_bytes, + self.monotonic_id_start, + self.monotonic_id_end, + ) # buffer wrap around if self.data_buffer_start >= self.data_buffer_size: @@ -295,12 +312,12 @@ def free_buf(self, if monotonic_id_after >= monotonic_id_before: return range(monotonic_id_before, monotonic_id_after) else: - return chain(range(monotonic_id_before, self.ID_MAX), - range(0, monotonic_id_after)) + return chain( + range(monotonic_id_before, self.ID_MAX), range(0, monotonic_id_after) + ) class ObjectSerde(ABC): - @abstractmethod def serialize(self, value: Any) -> tuple[Any, int, bytes, int]: """Serialize an object to bytes.""" @@ -313,7 +330,6 @@ def deserialize(self, data: memoryview) -> Any: class MsgpackSerde(ObjectSerde): - def __init__(self): # Delayed import to avoid circular dependency from vllm.multimodal.inputs import MultiModalKwargsItem @@ -325,8 +341,8 @@ def __init__(self): self._mm_kwargs_item_cls = MultiModalKwargsItem def serialize( - self, - value: Any) -> tuple[Union[bytes, list[bytes]], int, bytes, int]: + self, value: Any + ) -> tuple[Union[bytes, list[bytes]], int, bytes, int]: len_arr = None if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)): type_name = type(value).__name__ @@ -339,8 +355,9 @@ def serialize( nbytes = len(value) object_metadata = (type_name, nbytes, len_arr) - serialized_metadata = pickle.dumps(object_metadata, - protocol=pickle.HIGHEST_PROTOCOL) + serialized_metadata = pickle.dumps( + object_metadata, protocol=pickle.HIGHEST_PROTOCOL + 
) return value, nbytes, serialized_metadata, len(serialized_metadata) def deserialize(self, data_view: memoryview) -> Any: @@ -353,7 +370,7 @@ def deserialize(self, data_view: memoryview) -> Any: obj = [] start_idx = 0 for length in len_arr: - item_bytes = serialized_data[start_idx:start_idx + length] + item_bytes = serialized_data[start_idx : start_idx + length] obj.append(item_bytes) start_idx += length obj = self.tensor_decoder.decode(obj) @@ -361,15 +378,14 @@ def deserialize(self, data_view: memoryview) -> Any: obj = [] start_idx = 0 for length in len_arr: - item_bytes = serialized_data[start_idx:start_idx + length] + item_bytes = serialized_data[start_idx : start_idx + length] obj.append(item_bytes) start_idx += length obj = self.mm_decoder.decode(obj) elif type_name == bytes.__name__: obj = pickle.loads(serialized_data) else: - raise ValueError( - f"Unsupported object type '{type_name}' in metadata") + raise ValueError(f"Unsupported object type '{type_name}' in metadata") return obj @@ -388,18 +404,18 @@ class SingleWriterShmObjectStorage: A single-writer, multiple-reader object storage system built on top of a shared memory ring buffer. Provides key-value storage with automatic memory management and cross-process serialization support. - + This storage system follows a FIFO (First-In-First-Out) eviction policy where the oldest objects are automatically freed when memory runs low. Memory is reclaimed based on reader reference counting - objects are only freed when all readers have finished accessing them. - + Architecture: - Single writer process can put(key, value) objects - Multiple reader processes can get(address, monotonic_id) objects - Built on SingleWriterShmRingBuffer for efficient shared memory management - Thread-safe operations with reader synchronization via locks - + Key Features: - FIFO Eviction: Oldest objects are evicted first when memory is full - Reference Counting: Objects are only freed when no readers are @@ -414,7 +430,7 @@ class SingleWriterShmObjectStorage: Memory Layout per Object: `[4-byte reference_count][metadata_size][serialized_object_data]` - + Thread Safety: - Writer operations (put, clear) are single-threaded by design - Reader operations (get) are thread-safe with lock-based reference @@ -482,18 +498,17 @@ def copy_to_buffer( md_bytes: int, data_view: memoryview, ) -> None: - data_view[self.flag_bytes:self.flag_bytes + md_bytes] = metadata + data_view[self.flag_bytes : self.flag_bytes + md_bytes] = metadata if isinstance(data, bytes): data_view[-data_bytes:] = data elif isinstance(data, list): start_idx = self.flag_bytes + md_bytes for item_bytes in data: item_size = len(item_bytes) - data_view[start_idx:start_idx + item_size] = item_bytes + data_view[start_idx : start_idx + item_size] = item_bytes start_idx += item_size else: - raise ValueError( - f"Unsupported data type for serialization: {type(data)}") + raise ValueError(f"Unsupported data type for serialization: {type(data)}") def increment_writer_flag(self, id: int) -> None: """Set the in-use flag for the writer.""" @@ -509,8 +524,9 @@ def free_unused(self) -> None: """Free unused buffers in the ring buffer.""" # try to free up 2*max_object_size bytes of space in the ring buffer, # since the buffer might be fragmented - freed_ids = self.ring_buffer.free_buf(self.default_is_free_check, - 2 * self.max_object_size) + freed_ids = self.ring_buffer.free_buf( + self.default_is_free_check, 2 * self.max_object_size + ) # update the metadata after freeing up space for freed_id in freed_ids: key_to_free = 
self.id_index[freed_id] @@ -537,7 +553,7 @@ def put(self, key: str, value: Any) -> tuple[int, int]: Store a key-value pair in the object storage. Attempts to free max_object_size bytes using FIFO order when the ring buffer runs out of space during a put() operation. - + Args: key: String key to identify the object value: Any serializable Python object @@ -550,15 +566,17 @@ def put(self, key: str, value: Any) -> tuple[int, int]: if key in self.key_index: raise ValueError(f"Key '{key}' already exists in the storage.") - object_data, data_bytes, object_metadata, md_bytes = \ - self.ser_de.serialize(value) + object_data, data_bytes, object_metadata, md_bytes = self.ser_de.serialize( + value + ) buffer_size = self.flag_bytes + data_bytes + md_bytes # Sanity checks if buffer_size > self.max_object_size: raise ValueError( f"Serialized object size ({buffer_size} bytes) exceeds " - f"max object size ({self.max_object_size} bytes)") + f"max object size ({self.max_object_size} bytes)" + ) # Allocate new buffer try: @@ -570,9 +588,10 @@ def put(self, key: str, value: Any) -> tuple[int, int]: # Write data to buffer with self.ring_buffer.access_buf(address) as (data_view, metadata): - data_view[:self.flag_bytes] = self.ring_buffer.int2byte(0) - self.copy_to_buffer(object_data, data_bytes, object_metadata, - md_bytes, data_view) + data_view[: self.flag_bytes] = self.ring_buffer.int2byte(0) + self.copy_to_buffer( + object_data, data_bytes, object_metadata, md_bytes, data_view + ) self.increment_writer_flag(monotonic_id) # Update key index @@ -587,14 +606,15 @@ def get(self, address: int, monotonic_id: int) -> Any: if buf_metadata[0] != monotonic_id: raise ValueError( f"Data for address:id '{address}:{monotonic_id}'" - " has been modified or is invalid.") + " has been modified or is invalid." 
+ ) - obj = self.ser_de.deserialize(data_view[self.flag_bytes:]) + obj = self.ser_de.deserialize(data_view[self.flag_bytes :]) # decrease the in-use flag for reader reads if self._reader_lock is not None: with self._reader_lock: - self.increment_reader_flag(data_view[:self.flag_bytes]) + self.increment_reader_flag(data_view[: self.flag_bytes]) else: # if self._reader_lock is None, it means we are the writer # in this case, we do not need to decrease the reader count @@ -614,7 +634,8 @@ def handle(self): @staticmethod def create_from_handle( - handle: ShmObjectStorageHandle) -> "SingleWriterShmObjectStorage": + handle: ShmObjectStorageHandle, + ) -> "SingleWriterShmObjectStorage": logger.debug("Creating storage from handle: %s", handle) ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle) return SingleWriterShmObjectStorage( diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index 09012d16978d..88451f9552c1 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -7,7 +7,8 @@ from torch.distributed import ProcessGroup from vllm.distributed.device_communicators.all_reduce_utils import ( - SYMM_MEM_ALL_REDUCE_MAX_SIZES) + SYMM_MEM_ALL_REDUCE_MAX_SIZES, +) from vllm.logger import init_logger from vllm.platforms import current_platform @@ -28,20 +29,20 @@ class SymmMemCommunicator: } def __init__( - self, - group: ProcessGroup, - device: Union[int, str, torch.device], - # add options for testing - force_multimem: Optional[bool] = None, - max_size_override: Optional[int] = None): + self, + group: ProcessGroup, + device: Union[int, str, torch.device], + # add options for testing + force_multimem: Optional[bool] = None, + max_size_override: Optional[int] = None, + ): self.disabled = True if not symm_mem_available: return if not current_platform.is_cuda(): - logger.warning("SymmMemCommunicator: symmetric " - "memory is not available.") + logger.warning("SymmMemCommunicator: symmetric memory is not available.") return if isinstance(device, int): device = torch.device(f"cuda:{device}") @@ -52,8 +53,9 @@ def __init__( self.device = device self.group = group self.world_size = dist.get_world_size(self.group) - self.device_capability = current_platform.get_device_capability( - ).as_version_str() + self.device_capability = ( + current_platform.get_device_capability().as_version_str() + ) if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: logger.warning( "SymmMemCommunicator: Device capability %s not supported, " @@ -61,8 +63,7 @@ def __init__( self.device_capability, ) return - if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[ - self.device_capability]: + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]: logger.warning( "SymmMemCommunicator: World size %d not supported, " "communicator is not available.", @@ -77,8 +78,9 @@ def __init__( self.max_size, ) else: - self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[ - self.device_capability][self.world_size] + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ + self.world_size + ] self.buffer = torch_symm_mem.empty( self.max_size // self.dtype.itemsize, @@ -87,8 +89,10 @@ def __init__( ) handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) if handle.multicast_ptr == 0: - logger.warning("SymmMemCommunicator: symmetric memory " - "multicast operations are not supported.") + logger.warning( + "SymmMemCommunicator: symmetric memory " + 
"multicast operations are not supported." + ) return self.force_multimem = force_multimem self.disabled = False @@ -104,15 +108,13 @@ def should_use_symm_mem(self, inp: torch.Tensor): return inp_size < self.max_size def all_reduce( - self, - inp: torch.Tensor, - *, - out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None + ) -> Optional[torch.Tensor]: if not self.should_use_symm_mem(inp): return None if out is None: out = torch.empty_like(inp) - self.buffer[:inp.numel()].copy_(inp.view(-1)) + self.buffer[: inp.numel()].copy_(inp.view(-1)) # Determine which algorithm to use use_multimem = False @@ -121,16 +123,17 @@ def all_reduce( use_multimem = self.force_multimem else: # Normal logic: use multimem for supported world sizes - use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM[ - self.device_capability] + use_multimem = ( + self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability] + ) if use_multimem: - torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], - "sum", - self.group.group_name) + torch.ops.symm_mem.multimem_all_reduce_( + self.buffer[: inp.numel()], "sum", self.group.group_name + ) else: - torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], - "sum", - self.group.group_name) - out.copy_(self.buffer[:inp.numel()].view(out.shape)) + torch.ops.symm_mem.two_shot_all_reduce_( + self.buffer[: inp.numel()], "sum", self.group.group_name + ) + out.copy_(self.buffer[: inp.numel()].view(out.shape)) return out diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 942dd67f065d..e0ac9df9a6af 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -14,8 +14,9 @@ from .base_device_communicator import DeviceCommunicatorBase -USE_RAY = parallel_config = get_current_vllm_config( -).parallel_config.distributed_executor_backend == "ray" +USE_RAY = parallel_config = ( + get_current_vllm_config().parallel_config.distributed_executor_backend == "ray" +) logger = init_logger(__name__) @@ -27,18 +28,21 @@ import torch_xla.runtime as xr from torch_xla._internal import pjrt from torch_xla.distributed.xla_multiprocessing import ( - create_optimized_replica_groups) + create_optimized_replica_groups, + ) + if USE_RAY: from vllm.executor import ray_utils class TpuCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): super().__init__(cpu_group, device, device_group, unique_name) # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node @@ -98,5 +102,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: if USE_TPU_COMMONS: from tpu_commons.distributed.device_communicators import ( - TpuCommunicator as TpuCommonsCommunicator) + TpuCommunicator as TpuCommonsCommunicator, + ) + TpuCommunicator = TpuCommonsCommunicator # type: ignore diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index 27bd176554af..33d5b2cf1d87 100644 --- a/vllm/distributed/device_communicators/xpu_communicator.py +++ 
b/vllm/distributed/device_communicators/xpu_communicator.py @@ -16,12 +16,13 @@ class XpuCommunicator(DeviceCommunicatorBase): - - def __init__(self, - cpu_group: ProcessGroup, - device: Optional[torch.device] = None, - device_group: Optional[ProcessGroup] = None, - unique_name: str = ""): + def __init__( + self, + cpu_group: ProcessGroup, + device: Optional[torch.device] = None, + device_group: Optional[ProcessGroup] = None, + unique_name: str = "", + ): super().__init__(cpu_group, device, device_group, unique_name) if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND @@ -29,10 +30,12 @@ def __init__(self, logger.warning( "`%s` all2all manager is not supported on XPU." "Falling back to `naive` all2all manager for XPU.", - all2all_backend) + all2all_backend, + ) all2all_backend = "naive" if all2all_backend == "naive": from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) logger.info("Using naive all2all manager.") @@ -40,12 +43,12 @@ def all_reduce(self, input_) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if dim < 0: # Convert negative dim to positive. dim += input_.dim() @@ -53,20 +56,19 @@ def gather(self, # cluster so we use all_gather instead for now. input_size = input_.size() # Allocate output tensor. - output_tensor = torch.empty((self.world_size, ) + input_size, - dtype=input_.dtype, - device=input_.device) + output_tensor = torch.empty( + (self.world_size,) + input_size, dtype=input_.dtype, device=input_.device + ) # All-gather. 
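An offline sketch of the reshape that the `gather` implementation here performs after the all-gather: stacking the per-rank tensors along a new leading dimension (what `all_gather_into_tensor` produces), then `movedim` + `reshape` concatenates them along the requested dim, which is equivalent to a true gather at the destination rank. The shapes below are made up.

```python
import torch

world_size, dim = 2, 1
per_rank = [torch.full((3, 4), float(r)) for r in range(world_size)]

stacked = torch.stack(per_rank)                # (world_size, 3, 4), like all_gather
out = stacked.movedim(0, dim).reshape(3, world_size * 4)

# Same result as concatenating the rank tensors along dim=1.
assert torch.equal(out, torch.cat(per_rank, dim=dim))
```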
- dist.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) + dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group) if self.rank_in_group == dst: # Reshape output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (self.world_size * - input_size[dim], ) + - input_size[dim + 1:]) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (self.world_size * input_size[dim],) + + input_size[dim + 1 :] + ) else: output_tensor = None return output_tensor @@ -78,17 +80,19 @@ def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - is_sequence_parallel: bool = False + is_sequence_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: assert self.all2all_manager is not None hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits, is_sequence_parallel) + hidden_states, router_logits, is_sequence_parallel + ) return hidden_states, router_logits - def combine(self, - hidden_states: torch.Tensor, - is_sequence_parallel: bool = False) -> torch.Tensor: + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: assert self.all2all_manager is not None - hidden_states = self.all2all_manager.combine(hidden_states, - is_sequence_parallel) + hidden_states = self.all2all_manager.combine( + hidden_states, is_sequence_parallel + ) return hidden_states diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py index 80511024b930..4cd51dd384ad 100644 --- a/vllm/distributed/eplb/__init__.py +++ b/vllm/distributed/eplb/__init__.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -''' +""" Expert parallelism load balancer (EPLB). -''' +""" from .eplb_state import * from .rebalance_algo import * diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 3e318d784832..663f04027046 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -35,8 +35,11 @@ from torch.distributed import ProcessGroup, all_reduce from vllm.config import ParallelConfig -from vllm.distributed.parallel_state import (get_ep_group, get_node_count, - in_the_same_node_as) +from vllm.distributed.parallel_state import ( + get_ep_group, + get_node_count, + in_the_same_node_as, +) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -190,11 +193,10 @@ def build( """ Build the initial EPLB state. 
""" - physical_to_logical_map_list = ( - cls.build_initial_global_physical_to_logical_map( - model.num_routed_experts, - model.num_redundant_experts, - )) + physical_to_logical_map_list = cls.build_initial_global_physical_to_logical_map( + model.num_routed_experts, + model.num_redundant_experts, + ) physical_to_logical_map = torch.tensor( physical_to_logical_map_list, device=device, @@ -205,7 +207,8 @@ def build( MAX_EXPERT_REDUNDANCY = 1023 assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, ( f"num_redundant_experts {model.num_redundant_experts} " - f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}") + f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}" + ) max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1 logical_to_physical_map = torch.full( (model.num_logical_experts, max_slots_per_logical_expert), @@ -213,31 +216,42 @@ def build( device=device, ) logical_replica_count = torch.zeros( - (model.num_logical_experts, ), + (model.num_logical_experts,), device=device, dtype=torch.long, ) for i in range(model.num_physical_experts): logical_idx = physical_to_logical_map[i] - logical_to_physical_map[logical_idx, - logical_replica_count[logical_idx]] = i + logical_to_physical_map[logical_idx, logical_replica_count[logical_idx]] = i logical_replica_count[logical_idx] += 1 # Duplicate initial mapping for all layers - physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand( - model.num_moe_layers, - -1, - ).contiguous() - logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand( - model.num_moe_layers, - -1, - -1, - ).contiguous() - logical_replica_count = logical_replica_count.unsqueeze(0).expand( - model.num_moe_layers, - -1, - ).contiguous() + physical_to_logical_map = ( + physical_to_logical_map.unsqueeze(0) + .expand( + model.num_moe_layers, + -1, + ) + .contiguous() + ) + logical_to_physical_map = ( + logical_to_physical_map.unsqueeze(0) + .expand( + model.num_moe_layers, + -1, + -1, + ) + .contiguous() + ) + logical_replica_count = ( + logical_replica_count.unsqueeze(0) + .expand( + model.num_moe_layers, + -1, + ) + .contiguous() + ) expert_load_pass = torch.zeros( (model.num_moe_layers, model.num_physical_experts), @@ -246,21 +260,21 @@ def build( ) expert_load_window_size = parallel_config.eplb_config.window_size expert_load_window = torch.zeros( - (expert_load_window_size, model.num_moe_layers, - model.num_physical_experts), + (expert_load_window_size, model.num_moe_layers, model.num_physical_experts), dtype=torch.int32, device=device, ) # Set the initial progress of rearrangement to 3/4 eplb_step_interval = parallel_config.eplb_config.step_interval - expert_rearrangement_step = max( - 0, eplb_step_interval - eplb_step_interval // 4) + expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4) if global_expert_load is not None: ep_group = get_ep_group().device_group - assert global_expert_load.shape == (model.num_moe_layers, - model.num_logical_experts) + assert global_expert_load.shape == ( + model.num_moe_layers, + model.num_logical_experts, + ) assert global_expert_load.dtype == torch.int64 num_replicas = model.num_physical_experts @@ -273,20 +287,21 @@ def build( logger.warning_once( f"num_gpus % num_nodes != 0, " "not using hierarchical rearrangement algorithm.\n" - f"{num_gpus=}, {num_nodes=}") + f"{num_gpus=}, {num_nodes=}" + ) # Get new expert mappings ( new_physical_to_logical_map, new_logical_to_physical_map, new_logical_replica_count, - ) = (rebalance_experts( + ) = rebalance_experts( 
global_expert_load, num_replicas, num_groups, num_nodes, num_gpus, - )) + ) max_physical_slots = new_logical_to_physical_map.shape[-1] assert max_physical_slots <= logical_to_physical_map.shape[-1] @@ -326,11 +341,13 @@ def build( expert_rearrangement_step_interval=eplb_step_interval, ) - def step(self, - model: MixtureOfExperts, - is_dummy: bool = False, - is_profile: bool = False, - log_stats: bool = False) -> None: + def step( + self, + model: MixtureOfExperts, + is_dummy: bool = False, + is_profile: bool = False, + log_stats: bool = False, + ) -> None: """ Step the EPLB state. @@ -369,32 +386,40 @@ def step(self, all_reduce(total_expert_load_pass, group=ep_group) # num_tokens_per_rank: (num_moe_layers, num_ranks) - num_tokens_per_rank = total_expert_load_pass.reshape( - total_expert_load_pass.shape[0], ep_group.size(), - -1).sum(dim=-1).float() + num_tokens_per_rank = ( + total_expert_load_pass.reshape( + total_expert_load_pass.shape[0], ep_group.size(), -1 + ) + .sum(dim=-1) + .float() + ) # Compute balancedness ratio: # for each layer: # (mean load across ranks) / (max load across ranks) avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) - max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum( - dim=0) + max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0) # Just to make type checker happy tokens_tensors: list[float] = torch.stack( - [avg_tokens_tensor, max_tokens_tensor]).tolist() + [avg_tokens_tensor, max_tokens_tensor] + ).tolist() avg_tokens, max_tokens = tokens_tensors balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 if ep_group.rank() == 0: logger.info( - "EPLB step: avg_tokens=%.2f, max_tokens=%d, " - "balancedness=%.4f", avg_tokens, max_tokens, balancedness) + "EPLB step: avg_tokens=%.2f, max_tokens=%d, balancedness=%.4f", + avg_tokens, + max_tokens, + balancedness, + ) # Update the expert load sliding window if not is_dummy: self.expert_load_window[self.expert_load_window_step] = ( - self.expert_load_pass.clone()) + self.expert_load_pass.clone() + ) self.expert_load_window_step += 1 if self.expert_load_window_step >= self.expert_load_window_size: self.expert_load_window_step = 0 @@ -405,8 +430,7 @@ def step(self, # rearrangement step and perform rearrangement to ensure all ranks are # performing collective communication. self.expert_rearrangement_step += 1 - if (self.expert_rearrangement_step - >= self.expert_rearrangement_step_interval): + if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval: self.expert_rearrangement_step = 0 self.rearrange(model) @@ -416,8 +440,8 @@ def rearrange( is_profile: bool = False, execute_shuffle: bool = True, global_expert_load: Optional[torch.Tensor] = None, - rank_mapping: Optional[dict[int, - int]] = None) -> Optional[torch.Tensor]: + rank_mapping: Optional[dict[int, int]] = None, + ) -> Optional[torch.Tensor]: """ Rearrange the experts according to the current load. 
""" @@ -430,8 +454,7 @@ def rearrange( if is_main_rank: torch.cuda.synchronize() time_start = time.perf_counter() - logger.info("Rearranging experts %s...", - "(profile)" if is_profile else "") + logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") if global_expert_load is None: # Map the physical expert load to global logical experts @@ -444,23 +467,25 @@ def rearrange( ) logical_expert_load_window.scatter_add_( dim=-1, - index=self.physical_to_logical_map.unsqueeze(0).expand_as( - self.expert_load_window).long(), + index=self.physical_to_logical_map.unsqueeze(0) + .expand_as(self.expert_load_window) + .long(), src=self.expert_load_window, ) if not execute_shuffle: metadata = torch.tensor( [ - model.num_moe_layers, model.num_logical_experts, - self.physical_to_logical_map.shape[1] + model.num_moe_layers, + model.num_logical_experts, + self.physical_to_logical_map.shape[1], ], dtype=torch.int32, device="cpu", ) - torch.distributed.broadcast(metadata, - group=get_ep_group().cpu_group, - group_src=0) + torch.distributed.broadcast( + metadata, group=get_ep_group().cpu_group, group_src=0 + ) # Perform all-reduce to get the expert load across all ranks global_expert_load_window = logical_expert_load_window.sum(dim=0) @@ -469,9 +494,9 @@ def rearrange( if not execute_shuffle: # (num_moe_layers, old_num_physical_experts) old_global_expert_indices = self.physical_to_logical_map - torch.distributed.broadcast(old_global_expert_indices, - group=ep_group, - group_src=0) + torch.distributed.broadcast( + old_global_expert_indices, group=ep_group, group_src=0 + ) return global_expert_load_window else: assert execute_shuffle @@ -486,10 +511,10 @@ def rearrange( # the GPUs to be released. cpu_group = get_ep_group().cpu_group num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping) - num_gpus = sum(new_rank != -1 - for new_rank in rank_mapping.values()) - num_replicas = num_replicas // ep_group.size( - ) * num_gpus # handle num replicas change + num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values()) + num_replicas = ( + num_replicas // ep_group.size() * num_gpus + ) # handle num replicas change else: num_nodes = get_node_count() num_gpus = ep_group.size() @@ -499,20 +524,21 @@ def rearrange( logger.warning_once( f"num_gpus % num_nodes != 0, " "not using hierarchical rearrangement algorithm.\n" - f"{num_gpus=}, {num_nodes=}") + f"{num_gpus=}, {num_nodes=}" + ) # Get new expert mappings ( new_physical_to_logical_map, new_logical_to_physical_map, new_logical_replica_count, - ) = (rebalance_experts( + ) = rebalance_experts( global_expert_load_window, num_replicas, num_groups, num_nodes, num_gpus, - )) + ) # Update expert weights rearrange_expert_weights_inplace( @@ -525,18 +551,20 @@ def rearrange( ) if not is_profile: - if self.physical_to_logical_map.shape[ - 1] != new_physical_to_logical_map.shape[1]: + if ( + self.physical_to_logical_map.shape[1] + != new_physical_to_logical_map.shape[1] + ): self.physical_to_logical_map = new_physical_to_logical_map.to( - self.physical_to_logical_map.device) + self.physical_to_logical_map.device + ) else: self.physical_to_logical_map.copy_(new_physical_to_logical_map) max_physical_slots = new_logical_to_physical_map.shape[-1] assert max_physical_slots <= self.logical_to_physical_map.shape[-1] new_logical_to_physical_map = torch.nn.functional.pad( new_logical_to_physical_map, - (0, - self.logical_to_physical_map.shape[-1] - max_physical_slots), + (0, self.logical_to_physical_map.shape[-1] - max_physical_slots), value=-1, ) 
self.logical_to_physical_map.copy_(new_logical_to_physical_map) @@ -560,11 +588,10 @@ def recv_state() -> tuple[torch.Tensor, torch.Tensor]: """ ep_group = get_ep_group() metadata = torch.empty(3, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(metadata, - group=ep_group.cpu_group, - group_src=0) + torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0) num_moe_layers, num_logical_experts, num_old_physical_experts = ( - metadata.tolist()) + metadata.tolist() + ) global_expert_load = torch.zeros( (num_moe_layers, num_logical_experts), dtype=torch.int64, @@ -576,9 +603,9 @@ def recv_state() -> tuple[torch.Tensor, torch.Tensor]: dtype=torch.int64, device=ep_group.device, ) - torch.distributed.broadcast(old_global_expert_indices, - group=ep_group.device_group, - group_src=0) + torch.distributed.broadcast( + old_global_expert_indices, group=ep_group.device_group, group_src=0 + ) return global_expert_load, old_global_expert_indices diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index fc43dbe3b653..c9d30d6481ab 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -15,8 +15,9 @@ import torch -def balanced_packing(weight: torch.Tensor, - num_packs: int) -> tuple[torch.Tensor, torch.Tensor]: +def balanced_packing( + weight: torch.Tensor, num_packs: int +) -> tuple[torch.Tensor, torch.Tensor]: """ Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs are as balanced as possible. @@ -34,25 +35,21 @@ def balanced_packing(weight: torch.Tensor, groups_per_pack = num_groups // num_packs if groups_per_pack == 1: - pack_index = torch.arange(weight.size(-1), - dtype=torch.int64, - device=weight.device).expand(weight.shape) + pack_index = torch.arange( + weight.size(-1), dtype=torch.int64, device=weight.device + ).expand(weight.shape) rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) return pack_index, rank_in_pack indices = weight.float().sort(-1, descending=True).indices.cpu() - pack_index = torch.full_like(weight, - fill_value=-1, - dtype=torch.int64, - device="cpu") + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") rank_in_pack = torch.full_like(pack_index, fill_value=-1) for i in range(num_layers): pack_weights = [0] * num_packs pack_items = [0] * num_packs for group in indices[i]: pack = min( - (i - for i in range(num_packs) if pack_items[i] < groups_per_pack), + (i for i in range(num_packs) if pack_items[i] < groups_per_pack), key=pack_weights.__getitem__, ) assert pack_items[pack] < groups_per_pack @@ -64,8 +61,8 @@ def balanced_packing(weight: torch.Tensor, def replicate_experts( - weight: torch.Tensor, - num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + weight: torch.Tensor, num_phy: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. 
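The rebalance_algo.py hunks above, together with the call site in eplb_state.py earlier in this diff, describe the EPLB rebalancing entry point. The following is a minimal, illustrative sketch of calling `rebalance_experts` directly; the import path comes from the file header and the tensor shapes are inferred from the docstrings and the call site, so treat it as a sketch rather than part of the patch.

import torch

from vllm.distributed.eplb.rebalance_algo import rebalance_experts

num_layers, num_logical_experts = 2, 16
num_replicas = 24   # physical experts after replication
num_groups = 8      # expert groups; num_groups % num_nodes == 0 -> hierarchical policy
num_nodes = 2
num_gpus = 8        # must be a multiple of num_nodes

# Per-layer token counts routed to each logical expert, e.g. taken from the
# EPLB load window; any non-negative integer tensor of this shape works here.
weight = torch.randint(0, 1000, (num_layers, num_logical_experts))

phy2log, log2phy, logcnt = rebalance_experts(
    weight, num_replicas, num_groups, num_nodes, num_gpus
)
# phy2log: (num_layers, num_replicas), physical slot -> logical expert id
# log2phy: (num_layers, num_logical_experts, num_replicas - num_logical_experts + 1),
#          logical expert -> its physical slots, padded with -1
# logcnt:  (num_layers, num_logical_experts), replica count per logical expert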
@@ -83,8 +80,7 @@ def replicate_experts( num_redundant = num_phy - num_log assert num_redundant >= 0 device = weight.device - phy2log = torch.arange(num_phy, dtype=torch.int64, - device=device).repeat(n, 1) + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) arangen = torch.arange(n, dtype=torch.int64, device=device) @@ -108,7 +104,7 @@ def rebalance_experts_hierarchical( weight: [num_moe_layers, num_logical_experts] num_physical_experts: number of physical experts after replication num_groups: number of expert groups - num_nodes: number of server nodes, where the intra-node network + num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster num_gpus: number of GPUs, must be a multiple of `num_nodes` @@ -134,45 +130,51 @@ def inverse(perm: torch.Tensor) -> torch.Tensor: inv.scatter_( 1, perm, - torch.arange(perm.size(1), dtype=torch.int64, - device=perm.device).expand(perm.shape), + torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand( + perm.shape + ), ) return inv # Step 1: pack groups to nodes tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) - group_pack_index, group_rank_in_pack = balanced_packing( - tokens_per_group, num_nodes) - log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * - group_size).unsqueeze(-1) + - torch.arange(group_size, - dtype=torch.int64, - device=group_pack_index.device)).flatten(-2) + group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) + log2mlog = ( + ( + (group_pack_index * groups_per_node + group_rank_in_pack) * group_size + ).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device) + ).flatten(-2) mlog2log = inverse(log2mlog) # Step 2: construct redundant experts within nodes # [num_layers * num_nodes, num_logical_experts // num_nodes] tokens_per_mlog = weight.gather(-1, mlog2log).view( - -1, num_logical_experts // num_nodes) + -1, num_logical_experts // num_nodes + ) phy2mlog, phyrank, mlogcnt = replicate_experts( - tokens_per_mlog, num_physical_experts // num_nodes) + tokens_per_mlog, num_physical_experts // num_nodes + ) # Step 3: pack physical_experts to GPUs # [num_layers * num_nodes, num_physical_experts // num_nodes] tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) - pack_index, rank_in_pack = balanced_packing(tokens_per_phy, - num_gpus // num_nodes) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack pphy2phy = inverse(phy2pphy) pphy2mlog = phy2mlog.gather( - -1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] - pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange( - 0, - num_logical_experts, - num_logical_experts // num_nodes, - device=group_pack_index.device, - ).view(1, -1, 1)).flatten(-2) + -1, pphy2phy + ) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = ( + pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + device=group_pack_index.device, + ).view(1, -1, 1) + ).flatten(-2) pphy2log = mlog2log.gather(-1, pphy2mlog) pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) @@ -214,11 +216,13 @@ def rebalance_experts( if num_groups % num_nodes == 0: # 
use hierarchical load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, num_groups, num_nodes, num_gpus) + weight, num_replicas, num_groups, num_nodes, num_gpus + ) else: # use global load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( - weight, num_replicas, 1, 1, num_gpus) + weight, num_replicas, 1, 1, num_gpus + ) num_redundant_experts = num_replicas - num_logical_experts maxlogcnt = num_redundant_experts + 1 log2phy: torch.Tensor = torch.full( @@ -230,8 +234,9 @@ def rebalance_experts( log2phy.view(num_layers, -1).scatter_( -1, phy2log * maxlogcnt + phyrank, - torch.arange(num_replicas, dtype=torch.int64, - device=log2phy.device).expand(num_layers, -1), + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( + num_layers, -1 + ), ) return phy2log, log2phy, logcnt diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index f8a7d1170bb0..344fae457c9b 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -11,8 +11,13 @@ from typing import Optional import torch -from torch.distributed import (P2POp, ProcessGroup, all_gather, - batch_isend_irecv, get_global_rank) +from torch.distributed import ( + P2POp, + ProcessGroup, + all_gather, + batch_isend_irecv, + get_global_rank, +) def idx_local_to_global( @@ -132,8 +137,7 @@ def shuffle_layer( continue if old_indices[src_global] == new_indices[dst_global]: is_received_locally[dst] = True - for weight, buffer in zip(expert_weights, - expert_weights_buffer): + for weight, buffer in zip(expert_weights, expert_weights_buffer): buffer[dst].copy_(weight[src]) p2p_ops: list[P2POp] = [] @@ -177,7 +181,8 @@ def shuffle_layer( torch.distributed.isend, weight[src], dst_global, - ) for weight in expert_weights + ) + for weight in expert_weights ] # 3. Initiate receiving of weights. @@ -216,7 +221,8 @@ def shuffle_layer( torch.distributed.irecv, weight[dst], src_global, - ) for weight in expert_weights_buffer + ) + for weight in expert_weights_buffer ] # 4. Execute the P2P operations. The real communication happens here. @@ -271,29 +277,25 @@ def rearrange_expert_weights_inplace( if rank_mapping is not None: if len(rank_mapping) == ep_group.size(): # scale down - new_global_expert_indices = \ - _map_new_expert_indices_with_rank_mapping( + new_global_expert_indices = _map_new_expert_indices_with_rank_mapping( new_global_expert_indices, rank_mapping, ) else: # scale up - old_global_expert_indices = \ - _map_old_expert_indices_with_rank_mapping( + old_global_expert_indices = _map_old_expert_indices_with_rank_mapping( old_global_expert_indices, rank_mapping, ep_group.size(), ) - assert old_global_expert_indices.shape[ - 1] == new_global_expert_indices.shape[1] + assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1] num_moe_layers, num_physical_experts = old_global_expert_indices.shape assert len(expert_weights) == num_moe_layers num_local_physical_experts = next(iter(expert_weights[0])).shape[0] - assert new_global_expert_indices.shape == (num_moe_layers, - num_physical_experts) + assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) ep_rank = ep_group.rank() ep_size = ep_group.size() @@ -342,13 +344,13 @@ def _map_old_expert_indices_with_rank_mapping( ) -> torch.Tensor: """ Map the old global expert indices to the new global expert indices. 
- + Args: old_global_expert_indices: Shape (num_layers, old_ep_size * num_local_physical_experts). rank_mapping: Mapping from old rank to new rank. new_ep_size: New expert parallelism size. - + Returns: Mapped expert indices with shape (num_layers, new_ep_size * num_local_physical_experts). @@ -379,8 +381,9 @@ def _map_old_expert_indices_with_rank_mapping( new_start_idx = new_rank * num_local_physical_experts new_end_idx = (new_rank + 1) * num_local_physical_experts - mapped_expert_indices[:, new_start_idx:new_end_idx] = \ + mapped_expert_indices[:, new_start_idx:new_end_idx] = ( old_global_expert_indices[:, old_start_idx:old_end_idx] + ) # If new_rank is None or >= new_ep_size, the experts remain -1 # (scale down case) @@ -415,8 +418,9 @@ def _map_new_expert_indices_with_rank_mapping( new_start_idx = new_rank * num_local_physical_experts new_end_idx = (new_rank + 1) * num_local_physical_experts - mapped_expert_indices[:, old_start_idx:old_end_idx] = \ + mapped_expert_indices[:, old_start_idx:old_end_idx] = ( new_global_expert_indices[:, new_start_idx:new_end_idx] + ) return mapped_expert_indices diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 46f0cd9289b2..d93ae63e0eb4 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -22,10 +22,10 @@ class EventBatch( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False, # type: ignore[call-arg] + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] ): ts: float events: list[Any] @@ -33,11 +33,12 @@ class EventBatch( class KVCacheEvent( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False, # type: ignore[call-arg] - tag=True): + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] + tag=True, +): """Base class for all KV cache-related events""" @@ -69,14 +70,14 @@ class KVEventBatch(EventBatch): class EventPublisher(ABC): """Lightweight publisher for EventBatch batches with data parallelism support. - + In data parallel setups, each DP rank runs its own EventPublisher instance to avoid duplicate events and ensure proper event attribution: - + - Each DP rank creates a separate publisher - Publishers automatically annotate events with their data_parallel_rank - This allows consumers to distinguish events from different DP ranks - + The publisher is responsible for adding DP metadata since the scheduler operates independently of DP topology and shouldn't need DP awareness. """ @@ -130,6 +131,7 @@ class ZmqEventPublisher(EventPublisher): topic: Topic to publish events to. 
""" + SHUTDOWN_TIMEOUT: float = 1.0 END_SEQ = (-1).to_bytes(8, "big", signed=True) @@ -156,21 +158,22 @@ def __init__( self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank) self._replay_endpoint = self.offset_endpoint_port( - replay_endpoint, self._dp_rank) + replay_endpoint, self._dp_rank + ) self._hwm = hwm self._socket_setup() # Payload self._seq_gen = count() - self._topic_bytes = topic.encode('utf-8') + self._topic_bytes = topic.encode("utf-8") # Thread self._running = True logger.info("Starting ZMQ publisher thread") - self._thread = threading.Thread(target=self._publisher_thread, - daemon=True, - name="zmq-publisher") + self._thread = threading.Thread( + target=self._publisher_thread, daemon=True, name="zmq-publisher" + ) self._thread.start() def publish(self, events: EventBatch) -> None: @@ -220,10 +223,12 @@ def _socket_setup(self) -> None: self._pub.set_hwm(self._hwm) # Heuristic: bind if wildcard / * present, else connect. # bind stable, connect volatile convention - if (self._endpoint is not None - and ("*" in self._endpoint or "::" in self._endpoint - or self._endpoint.startswith("ipc://") - or self._endpoint.startswith("inproc://"))): + if self._endpoint is not None and ( + "*" in self._endpoint + or "::" in self._endpoint + or self._endpoint.startswith("ipc://") + or self._endpoint.startswith("inproc://") + ): self._pub.bind(self._endpoint) elif self._endpoint is not None: self._pub.connect(self._endpoint) @@ -263,8 +268,7 @@ def _publisher_thread(self) -> None: payload = self._pack.encode(event) seq_bytes = seq.to_bytes(8, "big") - self._pub.send_multipart( - (self._topic_bytes, seq_bytes, payload)) + self._pub.send_multipart((self._topic_bytes, seq_bytes, payload)) self._buffer.append((seq, payload)) self._event_queue.task_done() @@ -291,24 +295,26 @@ def _service_replay(self) -> None: # (identity, empty_delim) are stripped off by the router # receiving payload is (seq_bytes, payload) self._replay.send_multipart( - (client_id, b"", seq.to_bytes(8, "big"), buf)) + (client_id, b"", seq.to_bytes(8, "big"), buf) + ) # Send end of sequence marker # receiving payload is (-1, b""") self._replay.send_multipart((client_id, b"", self.END_SEQ, b"")) @staticmethod - def offset_endpoint_port(endpoint: Optional[str], - data_parallel_rank: int) -> Optional[str]: - """Helper function to offset the port in an endpoint by + def offset_endpoint_port( + endpoint: Optional[str], data_parallel_rank: int + ) -> Optional[str]: + """Helper function to offset the port in an endpoint by the data parallel rank. 
Args: - endpoint: The endpoint string + endpoint: The endpoint string (e.g., "tcp://*:5557" or "inproc://cache") data_parallel_rank: The data parallel rank to offset by Returns: - The endpoint with the port offset by data_parallel_rank + The endpoint with the port offset by data_parallel_rank or suffix appended """ # Do nothing if input is None or data_parallel_rank is 0 @@ -322,7 +328,7 @@ def offset_endpoint_port(endpoint: Optional[str], # Get everything after the last colon (the port) last_colon_idx = endpoint.rfind(":") base_addr = endpoint[:last_colon_idx] - base_port = int(endpoint[last_colon_idx + 1:]) + base_port = int(endpoint[last_colon_idx + 1 :]) new_port = base_port + data_parallel_rank return f"{base_addr}:{new_port}" return endpoint @@ -336,16 +342,15 @@ class EventPublisherFactory: } @classmethod - def register_publisher(cls, name: str, - ctor: Callable[..., EventPublisher]) -> None: + def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None: if name in cls._registry: raise KeyError(f"publisher '{name}' already registered") cls._registry[name] = ctor @classmethod - def create(cls, - config: Optional[KVEventsConfig], - data_parallel_rank: int = 0) -> EventPublisher: + def create( + cls, config: Optional[KVEventsConfig], data_parallel_rank: int = 0 + ) -> EventPublisher: """Create publisher from a config mapping.""" if not config: return NullEventPublisher() @@ -358,5 +363,4 @@ def create(cls, constructor = cls._registry[kind] except KeyError as exc: raise ValueError(f"Unknown event publisher '{kind}'") from exc - return constructor(data_parallel_rank=data_parallel_rank, - **config_dict) + return constructor(data_parallel_rank=data_parallel_rank, **config_dict) diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py index cf58e7914972..2bf4e1feb703 100644 --- a/vllm/distributed/kv_transfer/__init__.py +++ b/vllm/distributed/kv_transfer/__init__.py @@ -2,12 +2,19 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.distributed.kv_transfer.kv_transfer_state import ( - KVConnectorBaseType, ensure_kv_transfer_initialized, - ensure_kv_transfer_shutdown, get_kv_transfer_group, has_kv_transfer_group, - is_v1_kv_transfer_group) + KVConnectorBaseType, + ensure_kv_transfer_initialized, + ensure_kv_transfer_shutdown, + get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group, +) __all__ = [ - "get_kv_transfer_group", "has_kv_transfer_group", - "is_v1_kv_transfer_group", "ensure_kv_transfer_initialized", - "ensure_kv_transfer_shutdown", "KVConnectorBaseType" + "get_kv_transfer_group", + "has_kv_transfer_group", + "is_v1_kv_transfer_group", + "ensure_kv_transfer_initialized", + "ensure_kv_transfer_shutdown", + "KVConnectorBaseType", ] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 873f130ed827..395a4e20e0ba 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -4,15 +4,14 @@ import importlib from typing import TYPE_CHECKING, Callable -# yapf: disable import vllm.envs as envs from vllm.distributed.kv_transfer.kv_connector.base import ( - KVConnectorBase, KVConnectorBaseType) + KVConnectorBase, + KVConnectorBaseType, +) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger -# yapf: enable - if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.config.kv_transfer 
import KVTransferConfig @@ -24,8 +23,7 @@ class KVConnectorFactory: _registry: dict[str, Callable[[], type[KVConnectorBase]]] = {} @classmethod - def register_connector(cls, name: str, module_path: str, - class_name: str) -> None: + def register_connector(cls, name: str, module_path: str, class_name: str) -> None: """Register a connector with a lazy-loading module and class name.""" if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") @@ -43,13 +41,18 @@ def create_connector( role: KVConnectorRole, ) -> KVConnectorBase: if not envs.VLLM_USE_V1: - raise ValueError("Attempting to initialize a V1 Connector, " - f"but found {envs.VLLM_USE_V1=}") + raise ValueError( + "Attempting to initialize a V1 Connector, " + f"but found {envs.VLLM_USE_V1=}" + ) kv_transfer_config = config.kv_transfer_config connector_cls = cls.get_connector_class(kv_transfer_config) - logger.info("Creating v1 connector with name: %s and engine_id: %s", - connector_cls.__name__, kv_transfer_config.engine_id) + logger.info( + "Creating v1 connector with name: %s and engine_id: %s", + connector_cls.__name__, + kv_transfer_config.engine_id, + ) # NOTE(Kuntai): v1 connector is explicitly separated into two roles. # Scheduler connector: # - Co-locate with scheduler process @@ -62,7 +65,7 @@ def create_connector( @classmethod def get_connector_class( - cls, kv_transfer_config: "KVTransferConfig" + cls, kv_transfer_config: "KVTransferConfig" ) -> type[KVConnectorBaseType]: """Get the connector class by name.""" connector_name = kv_transfer_config.kv_connector @@ -71,8 +74,7 @@ def get_connector_class( else: connector_module_path = kv_transfer_config.kv_connector_module_path if connector_module_path is None: - raise ValueError( - f"Unsupported connector type: {connector_name}") + raise ValueError(f"Unsupported connector type: {connector_name}") connector_module = importlib.import_module(connector_module_path) connector_cls = getattr(connector_module, connector_name) return connector_cls @@ -85,29 +87,35 @@ def get_connector_class( KVConnectorFactory.register_connector( "SharedStorageConnector", "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector", - "SharedStorageConnector") + "SharedStorageConnector", +) KVConnectorFactory.register_connector( "P2pNcclConnector", "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector", - "P2pNcclConnector") + "P2pNcclConnector", +) KVConnectorFactory.register_connector( "LMCacheConnectorV1", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", - "LMCacheConnectorV1") + "LMCacheConnectorV1", +) KVConnectorFactory.register_connector( "NixlConnector", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", - "NixlConnector") + "NixlConnector", +) KVConnectorFactory.register_connector( "MultiConnector", "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", - "MultiConnector") + "MultiConnector", +) KVConnectorFactory.register_connector( "OffloadingConnector", "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector", - "OffloadingConnector") + "OffloadingConnector", +) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 103fba41fcb4..056ece60e84d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,6 +3,7 @@ """ KV cache helper for store. 
""" + from collections import defaultdict from collections.abc import Sequence from concurrent.futures import CancelledError, Future @@ -13,8 +14,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -22,14 +22,12 @@ class model_aware_kv_ops_helper: - def __init__(self, config: VllmConfig): self.is_deepseek_mla = config.model_config.is_deepseek_mla self.use_mla_opt = not envs.VLLM_MLA_DISABLE self.tp_size = config.parallel_config.tensor_parallel_size def get_model_args(self, model_executable: torch.nn.Module): - model_config = model_executable.model.config self.model_executable = model_executable num_heads = int(model_config.num_key_value_heads / self.tp_size) @@ -46,12 +44,10 @@ def get_model_args(self, model_executable: torch.nn.Module): # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. # For more details, see vllm/v1/attention/backends/mla/common.py. if self.is_deepseek_mla and self.use_mla_opt: - head_size = model_config.kv_lora_rank + \ - model_config.qk_rope_head_dim + head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim num_heads = 1 elif self.is_deepseek_mla and not self.use_mla_opt: - head_size = model_config.qk_nope_head_dim + \ - model_config.qk_rope_head_dim + head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim else: head_size = getattr(model_config, "head_dim", None) if head_size is None: @@ -68,16 +64,24 @@ def get_kv_from_cache(self, kv_cache, num_heads, head_size): value_cache = kv_cache[1].reshape(-1, num_heads, head_size) return key_cache, value_cache - def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values, - layer, kv_cache, slot_mapping, start_pos, end_pos): - + def put_kv_to_cache( + self, + model_executable: torch.nn.Module, + keys, + values, + layer, + kv_cache, + slot_mapping, + start_pos, + end_pos, + ): model_config = model_executable.model.config if self.is_deepseek_mla and self.use_mla_opt: layer.self_attn.attn = layer.self_attn.mla_attn k_c_normed_k_pe = keys.squeeze(1) - k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank] - k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:] + k_c_normed = k_c_normed_k_pe[:, : model_config.kv_lora_rank] + k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank :] ops.concat_and_cache_mla( k_c_normed.to(kv_cache.device), k_pe.to(kv_cache.device), @@ -107,12 +111,12 @@ def get_kv_connector_cache_layout(): kv_config = vllm_config.kv_transfer_config if kv_config is not None: connector_cls = KVConnectorFactory.get_connector_class(kv_config) - required_kvcache_layout = connector_cls.get_required_kvcache_layout( - vllm_config) + required_kvcache_layout = connector_cls.get_required_kvcache_layout(vllm_config) if required_kvcache_layout is not None: return required_kvcache_layout - logger.info_once("Connectors do not specify a " \ - "kv cache layout, defaulting to NHD.") + logger.info_once( + "Connectors do not specify a kv cache layout, defaulting to NHD." 
+ ) return "NHD" @@ -126,14 +130,16 @@ def __init__(self, world_size: int): self._recv_remaining_count = defaultdict[str, int](lambda: world_size) self._send_remaining_count = defaultdict[str, int](lambda: world_size) - def aggregate(self, - outputs: list[ModelRunnerOutput], - output_rank: int = 0) -> ModelRunnerOutput: + def aggregate( + self, outputs: list[ModelRunnerOutput], output_rank: int = 0 + ) -> ModelRunnerOutput: # Aggregate kv_connector_output from all workers - def update_finished_set(req_ids: Optional[set[str]], - remaining_count_dict: dict[str, int], - finished_set: set[str]) -> None: + def update_finished_set( + req_ids: Optional[set[str]], + remaining_count_dict: dict[str, int], + finished_set: set[str], + ) -> None: for req_id in req_ids or (): remaining_count_dict[req_id] -= 1 if remaining_count_dict[req_id] == 0: @@ -148,10 +154,12 @@ def update_finished_set(req_ids: Optional[set[str]], output = model_runner_output.kv_connector_output if not output: continue - update_finished_set(output.finished_sending, - self._send_remaining_count, finished_sending) - update_finished_set(output.finished_recving, - self._recv_remaining_count, finished_recving) + update_finished_set( + output.finished_sending, self._send_remaining_count, finished_sending + ) + update_finished_set( + output.finished_recving, self._recv_remaining_count, finished_recving + ) # Aggregate kv_connector_stats from all workers. if aggregated_kv_connector_stats is None: @@ -161,10 +169,12 @@ def update_finished_set(req_ids: Optional[set[str]], if aggregated_kv_connector_stats is None: aggregated_kv_connector_stats = kv_connector_stats else: - assert isinstance(aggregated_kv_connector_stats, - type(kv_connector_stats)) - aggregated_kv_connector_stats = \ + assert isinstance( + aggregated_kv_connector_stats, type(kv_connector_stats) + ) + aggregated_kv_connector_stats = ( aggregated_kv_connector_stats.aggregate(kv_connector_stats) + ) invalid_block_ids |= output.invalid_block_ids @@ -180,18 +190,16 @@ def update_finished_set(req_ids: Optional[set[str]], return output - def async_aggregate(self, - output_futures: Sequence[Future[ModelRunnerOutput]], - output_rank: int = 0) -> Future[ModelRunnerOutput]: + def async_aggregate( + self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0 + ) -> Future[ModelRunnerOutput]: """Takes a list of futures and returns a single future which resolves to the respective list of outputs.""" result_future: Future[ModelRunnerOutput] = Future() - outputs: list[Optional[ModelRunnerOutput]] = [None - ] * len(output_futures) + outputs: list[Optional[ModelRunnerOutput]] = [None] * len(output_futures) def make_callback(idx): - def callback(fut): if result_future.done(): return @@ -206,8 +214,10 @@ def callback(fut): # this check assumes io_thread_pool uses a single thread if all(outputs): result_future.set_result( - self.aggregate(cast(list[ModelRunnerOutput], outputs), - output_rank)) + self.aggregate( + cast(list[ModelRunnerOutput], outputs), output_rank + ) + ) return callback @@ -223,12 +233,8 @@ def _make_src_and_dst_indices( src_device: Union[torch.device, str], dst_device: Union[torch.device, str], ) -> tuple[torch.Tensor, torch.Tensor]: - src_indices = torch.tensor(src_block_ids, - device=src_device, - dtype=torch.int64) - dst_indices = torch.tensor(dst_block_ids, - device=dst_device, - dtype=torch.int64) + src_indices = torch.tensor(src_block_ids, device=src_device, dtype=torch.int64) + dst_indices = torch.tensor(dst_block_ids, device=dst_device, 
dtype=torch.int64) return src_indices, dst_indices @@ -240,9 +246,13 @@ def copy_kv_blocks( direction: Literal["h2d", "d2h"], ) -> None: """Copy kv blocks between different buffers.""" - if not src_kv_caches or not dst_kv_caches or \ - not src_block_ids or not dst_block_ids or \ - len(src_block_ids) != len(dst_block_ids): + if ( + not src_kv_caches + or not dst_kv_caches + or not src_block_ids + or not dst_block_ids + or len(src_block_ids) != len(dst_block_ids) + ): return src_device = next(iter(src_kv_caches.values())).device @@ -252,9 +262,11 @@ def copy_kv_blocks( src_block_ids=src_block_ids, dst_block_ids=dst_block_ids, src_device=src_device, - dst_device=dst_device) + dst_device=dst_device, + ) from vllm.platforms import current_platform + if direction == "h2d": copy_fn = current_platform.insert_blocks_to_device else: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py index f00f31dde915..034c7afe97a4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/__init__.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorRole, +) __all__ = ["KVConnectorRole", "KVConnectorBase_V1"] diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index e3b4bcbfd1e6..70225e95aed2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -7,7 +7,7 @@ The class provides the following primitives: Scheduler-side: runs in the scheduler, binds metadata, which is used by the worker-side to load/save KV cache. - get_num_new_matched_tokens() - get number of new tokens + get_num_new_matched_tokens() - get number of new tokens that exist in the remote KV cache. Might be called multiple times for a given request and should be side-effect free. update_state_after_alloc() - update KVConnector state after @@ -49,17 +49,22 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed.kv_events import KVCacheEvent - from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request # s_tensor_list, d_tensor_list, s_indices, d_indices, direction -CopyBlocksOp = Callable[[ - dict[str, torch.Tensor], dict[ - str, torch.Tensor], list[int], list[int], Literal["h2d", "d2h"] -], None] +CopyBlocksOp = Callable[ + [ + dict[str, torch.Tensor], + dict[str, torch.Tensor], + list[int], + list[int], + Literal["h2d", "d2h"], + ], + None, +] logger = init_logger(__name__) @@ -77,15 +82,16 @@ class KVConnectorMetadata(ABC): # noqa: B024 Abstract Metadata used to communicate between the Scheduler KVConnector and Worker KVConnector. """ + pass class KVConnectorBase_V1(ABC): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): logger.warning( "Initializing KVConnectorBase_V1. This API is experimental and " - "subject to change in the future as we iterate the design.") + "subject to change in the future as we iterate the design." 
+ ) self._connector_metadata: Optional[KVConnectorMetadata] = None self._vllm_config = vllm_config self._role = role @@ -98,11 +104,10 @@ def role(self) -> KVConnectorRole: # Worker-side methods # ============================== - def bind_connector_metadata( - self, connector_metadata: KVConnectorMetadata) -> None: + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: """Set the connector metadata from the scheduler. - This function should be called by the model runner every time + This function should be called by the model runner every time before the model execution. The metadata will be used for runtime KV cache loading and saving. @@ -114,7 +119,7 @@ def bind_connector_metadata( def clear_connector_metadata(self) -> None: """Clear the connector metadata. - This function should be called by the model runner every time + This function should be called by the model runner every time after the model execution. """ self._connector_metadata = None @@ -137,7 +142,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): Initialize with the KV caches. Useful for pre-registering the KV Caches in the KVConnector (e.g. for NIXL). - Args: + Args: kv_caches: dictionary of layer names, kv cache """ return @@ -150,8 +155,7 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): return @abstractmethod - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs: Any) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged KV buffer. This is called from the forward context before the @@ -162,9 +166,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. - + """ pass @@ -174,7 +178,7 @@ def wait_for_layer_load(self, layer_name: str) -> None: Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within attention layer to ensure async copying from start_load_kv is complete. - + This interface will be useful for layer-by-layer pipelining. Args: @@ -183,17 +187,21 @@ def wait_for_layer_load(self, layer_name: str) -> None: pass @abstractmethod - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", - **kwargs: Any) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: """ - Start saving a layer of KV cache from vLLM's paged buffer + Start saving a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. @@ -276,7 +284,7 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. 
num_computed_tokens (int): the number of locally @@ -303,9 +311,9 @@ def get_num_new_matched_tokens( pass @abstractmethod - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): """ Update KVConnector state after block allocation. @@ -325,7 +333,8 @@ def update_state_after_alloc(self, request: "Request", @abstractmethod def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: """ Build the connector metadata for this step. @@ -374,8 +383,7 @@ def take_events(self) -> Iterable["KVCacheEvent"]: return () @classmethod - def get_required_kvcache_layout( - cls, vllm_config: "VllmConfig") -> Optional[str]: + def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str]: """ Get the required KV cache layout for this connector. Args: @@ -387,8 +395,10 @@ def get_required_kvcache_layout( """ if cls is KVConnectorBase_V1: - raise TypeError("get_required_kvcache_layout should not be called " - "on the abstract base class") + raise TypeError( + "get_required_kvcache_layout should not be called " + "on the abstract base class" + ) return None def get_finished_count(self) -> Optional[int]: @@ -404,11 +414,10 @@ def get_finished_count(self) -> Optional[int]: @classmethod def build_kv_connector_stats( - cls, - data: Optional[dict[str, - Any]] = None) -> Optional["KVConnectorStats"]: + cls, data: Optional[dict[str, Any]] = None + ) -> Optional["KVConnectorStats"]: """ - KVConnectorStats resolution method. This method allows dynamically + KVConnectorStats resolution method. This method allows dynamically registered connectors to return their own KVConnectorStats object, which can implement custom aggregation logic on the data dict. """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 2b0abe983fbb..b50cc3ab30fa 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -7,7 +7,10 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput @@ -21,7 +24,6 @@ class LMCacheConnectorV1(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) @@ -29,8 +31,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # ============================== # Worker-side methods # ============================== - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs: Any) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged KV buffer. 
This is called from the forward context before the @@ -41,9 +42,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. - + """ self._lmcache_engine.start_load_kv(forward_context, **kwargs) @@ -52,7 +53,7 @@ def wait_for_layer_load(self, layer_name: str) -> None: Block until the KV for a specific layer is loaded into vLLM's paged buffer. This is called from within attention layer to ensure async copying from start_load_kv is complete. - + This interface will be useful for layer-by-layer pipelining. Args: @@ -60,23 +61,28 @@ def wait_for_layer_load(self, layer_name: str) -> None: """ self._lmcache_engine.wait_for_layer_load(layer_name) - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", - **kwargs: Any) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: """ - Start saving the a layer of KV cache from vLLM's paged buffer + Start saving the a layer of KV cache from vLLM's paged buffer to the connector. This is called from within attention layer to enable async copying during execution. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. """ - self._lmcache_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, - **kwargs) + self._lmcache_engine.save_kv_layer( + layer_name, kv_layer, attn_metadata, **kwargs + ) def wait_for_save(self): """ @@ -115,30 +121,31 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ return self._lmcache_engine.get_num_new_matched_tokens( - request, num_computed_tokens), False + request, num_computed_tokens + ), False - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): """ Update KVConnector state after block allocation. """ - self._lmcache_engine.update_state_after_alloc(request, - num_external_tokens) + self._lmcache_engine.update_state_after_alloc(request, num_external_tokens) def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: """ Build the connector metadata for this step. 
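The factory and connector hunks above define how a V1 KV connector is registered and instantiated. As a hedged sketch of wiring in an out-of-tree connector (the module `my_pkg.my_connector` and class `MyConnector` are hypothetical; `register_connector`, `kv_connector`, and `kv_connector_module_path` are the factory/config hooks shown in factory.py above):

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory

# Option 1: register the class under a name; the module is only imported the
# first time the connector is actually created.
KVConnectorFactory.register_connector(
    "MyConnector",          # name referenced by KVTransferConfig.kv_connector
    "my_pkg.my_connector",  # hypothetical module implementing the connector
    "MyConnector",          # class deriving from KVConnectorBase_V1
)

# Option 2: skip registration and point the config at the module directly.
# For unregistered names, get_connector_class() imports kv_connector_module_path
# and looks up a class named after kv_connector.
kv_transfer_config = KVTransferConfig(
    kv_connector="MyConnector",
    kv_connector_module_path="my_pkg.my_connector",
    kv_role="kv_both",  # assumed; kv_role is not part of the hunks shown here
)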
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py index e40007230ba4..879cc9a23581 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py @@ -4,10 +4,8 @@ from typing import Any, Optional, Union from vllm.config.kv_transfer import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_transfer_state import ( - has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group from vllm.logger import init_logger logger = init_logger(__name__) @@ -16,11 +14,12 @@ @dataclass class KVConnectorStats: """ - Base class for KV Connector Stats, a container for transfer performance - metrics or otherwise important telemetry from the connector. + Base class for KV Connector Stats, a container for transfer performance + metrics or otherwise important telemetry from the connector. All sub-classes need to be serializable as stats are sent from worker to logger process. """ + data: dict[str, Any] = field(default_factory=dict) def reset(self): @@ -35,8 +34,8 @@ def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats": def reduce(self) -> dict[str, Union[int, float]]: """ - Reduce the observations collected during a time interval to one or - more representative values (eg avg/median/sum of the series). + Reduce the observations collected during a time interval to one or + more representative values (eg avg/median/sum of the series). This is meant to be called by the logger to produce a summary of the stats for the last time interval. """ @@ -48,14 +47,14 @@ def is_empty(self) -> bool: class KVConnectorLogging: - def __init__(self, kv_tranfer_config: KVTransferConfig): # This should be called on frontend process. assert not has_kv_transfer_group() # Instantiate the connector's stats class. if kv_tranfer_config and kv_tranfer_config.kv_connector: self.connector_cls = KVConnectorFactory.get_connector_class( - kv_tranfer_config) + kv_tranfer_config + ) self.reset() def reset(self): @@ -69,32 +68,37 @@ def observe(self, transfer_stats_data: dict[str, Any]): # We expect transfer_stats_data to be aggregated across all workers and # consist of observations from a single connector or a MultiConnector. transfer_stats = self.connector_cls.build_kv_connector_stats( - transfer_stats_data) + transfer_stats_data + ) if transfer_stats is None: logger.warning_once( "The connector %s is collecting stats but " "does not implement the " "`build_kv_connector_stats` method. " - "Stats will not be logged.", self.connector_cls) + "Stats will not be logged.", + self.connector_cls, + ) return if self.transfer_stats_accumulator is None: self.transfer_stats_accumulator = transfer_stats else: # Accumulate last interval stats. 
- self.transfer_stats_accumulator = \ - self.transfer_stats_accumulator.aggregate(transfer_stats) + self.transfer_stats_accumulator = self.transfer_stats_accumulator.aggregate( + transfer_stats + ) def log(self, log_fn=logger.info): """Log transfer metrics periodically, similar to throughput logging""" - if (self.transfer_stats_accumulator - and not self.transfer_stats_accumulator.is_empty()): + if ( + self.transfer_stats_accumulator + and not self.transfer_stats_accumulator.is_empty() + ): # Produce a single cumulative stats object for the last time # interval from the recorded observations. xfer_metrics = self.transfer_stats_accumulator.reduce() - xfer_metrics_str = ", ".join(f"{k}={v}" - for k, v in xfer_metrics.items()) + xfer_metrics_str = ", ".join(f"{k}={v}" for k, v in xfer_metrics.items()) log_fn("KV Transfer metrics: %s", xfer_metrics_str) # Reset metrics for next interval - self.reset() \ No newline at end of file + self.reset() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index a7713ba326fc..e48d4ccd1d6c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -9,12 +9,13 @@ from vllm.config import VllmConfig from vllm.config.kv_transfer import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import KVConnectorOutput @@ -58,8 +59,7 @@ def reset(self): def reduce(self) -> dict[str, Any]: # TODO (NickLucche) Adjust for logging on separate lines return { - connector_id: stats.reduce() - for connector_id, stats in self.data.items() + connector_id: stats.reduce() for connector_id, stats in self.data.items() } def is_empty(self) -> bool: @@ -87,16 +87,18 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._connectors: list[KVConnectorBase_V1] = [] self._ktc_kv_transfer_config = [] ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "connectors") + "connectors" + ) assert ktcs is not None for ktc in ktcs: temp_config = copy.copy(vllm_config) - engine_id = ktc.get("engine_id", - vllm_config.kv_transfer_config.engine_id) + engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id) temp_config.kv_transfer_config = KVTransferConfig( - **ktc, engine_id=engine_id) + **ktc, engine_id=engine_id + ) self._connectors.append( - KVConnectorFactory.create_connector(temp_config, role)) + KVConnectorFactory.create_connector(temp_config, role) + ) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) # A mapping from request id to the index of the connector chosen to @@ -116,12 +118,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. 
- def bind_connector_metadata( - self, connector_metadata: KVConnectorMetadata) -> None: + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: assert isinstance(connector_metadata, MultiKVConnectorMetadata) if connector_metadata.extra_async_saves: - self._extra_async_saves.update( - connector_metadata.extra_async_saves) + self._extra_async_saves.update(connector_metadata.extra_async_saves) for c, cm in zip(self._connectors, connector_metadata.metadata): c.bind_connector_metadata(cm) @@ -135,8 +135,9 @@ def shutdown(self): try: c.shutdown() except Exception as e: - logger.exception("Exception during connector %s shutdown.", - c.__class__.__name__) + logger.exception( + "Exception during connector %s shutdown.", c.__class__.__name__ + ) exception = e if exception: raise exception @@ -144,8 +145,7 @@ def shutdown(self): # ============================== # Worker-side methods # ============================== - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: for c in self._connectors: c.start_load_kv(forward_context, **kwargs) @@ -153,8 +153,13 @@ def wait_for_layer_load(self, layer_name: str) -> None: for c in self._connectors: c.wait_for_layer_load(layer_name) - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: for c in self._connectors: c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) @@ -206,7 +211,8 @@ def get_num_new_matched_tokens( to_return = (0, False) for i, c in enumerate(self._connectors): toks, load_async = c.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens + ) # If there is a connector still looking up the matches, # we return None to indicate that we are not done yet. if toks is None: @@ -218,27 +224,27 @@ def get_num_new_matched_tokens( to_return = (toks, load_async) return to_return - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - chosen_connector = self._requests_to_connector.get( - request.request_id, -1) + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + chosen_connector = self._requests_to_connector.get(request.request_id, -1) empty_blocks = blocks.new_empty() for i, c in enumerate(self._connectors): if i == chosen_connector: # Forward call to the chosen connector (if any). - c.update_state_after_alloc(request, blocks, - num_external_tokens) + c.update_state_after_alloc(request, blocks, num_external_tokens) else: # Call with empty blocks for other connectors. 
c.update_state_after_alloc(request, empty_blocks, 0) def build_connector_meta( - self, - scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: - metadata = MultiKVConnectorMetadata(metadata=tuple( - c.build_connector_meta(scheduler_output) - for c in self._connectors)) + self, scheduler_output: SchedulerOutput + ) -> MultiKVConnectorMetadata: + metadata = MultiKVConnectorMetadata( + metadata=tuple( + c.build_connector_meta(scheduler_output) for c in self._connectors + ) + ) if self._extra_async_saves: metadata.extra_async_saves = self._extra_async_saves self._extra_async_saves = {} @@ -264,7 +270,8 @@ def request_finished( # TODO we can probably change this to merge the dicts here, # checking for key clashes. raise RuntimeError( - "Only one connector can produce KV transfer params") + "Only one connector can produce KV transfer params" + ) kv_txfer_params = txfer_params if async_saves > 1: self._extra_async_saves[request.request_id] = async_saves - 1 @@ -279,8 +286,7 @@ def take_events(self) -> Iterable["KVCacheEvent"]: yield from c.take_events() @classmethod - def get_required_kvcache_layout( - cls, vllm_config: "VllmConfig") -> Optional[str]: + def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str]: """ Get the required KV cache layout for this connector. Args: @@ -291,34 +297,39 @@ def get_required_kvcache_layout( None if the connector does not require a specific layout. """ ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "connectors") + "connectors" + ) assert ktcs is not None layouts: set[str] = set() temp_vllm_config = copy.copy(vllm_config) for ktc in ktcs: kv_transfer_config = KVTransferConfig(**ktc) temp_vllm_config.kv_transfer_config = kv_transfer_config - connector_cls = KVConnectorFactory.get_connector_class( - kv_transfer_config) - required_kvcache_layout = ( - connector_cls.get_required_kvcache_layout(temp_vllm_config)) + connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config) + required_kvcache_layout = connector_cls.get_required_kvcache_layout( + temp_vllm_config + ) if required_kvcache_layout is not None: layouts.add(required_kvcache_layout) if len(layouts) > 1: - raise ValueError(f"KV cache layout mismatch: " - f"found {len(layouts)} different layouts " - f"({', '.join(layouts) })." - f"All connectors must use the same layout.") + raise ValueError( + f"KV cache layout mismatch: " + f"found {len(layouts)} different layouts " + f"({', '.join(layouts)})." + f"All connectors must use the same layout." + ) return next(iter(layouts), None) @classmethod def build_kv_connector_stats( - cls, - data: Optional[dict[str, - Any]] = None) -> Optional[KVConnectorStats]: - return MultiKVConnectorStats(data=data) if data is not None \ + cls, data: Optional[dict[str, Any]] = None + ) -> Optional[KVConnectorStats]: + return ( + MultiKVConnectorStats(data=data) + if data is not None else MultiKVConnectorStats() + ) def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]: # Group connector stats by connector type. 
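The MultiConnector hunks above read the list of inner connectors from `kv_connector_extra_config["connectors"]` and turn each entry into its own `KVTransferConfig`. A hedged configuration sketch follows; the field names follow this diff, while the particular pair of inner connectors is only an example of names registered in factory.py, and `kv_role` is an assumption not shown in these hunks.

from vllm.config.kv_transfer import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="MultiConnector",
    kv_role="kv_both",  # assumed field, not part of the hunks shown here
    kv_connector_extra_config={
        "connectors": [
            # Each entry is expanded as KVTransferConfig(**entry), inheriting
            # the outer engine_id unless the entry sets its own.
            {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
            {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"},
        ]
    },
)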
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index fdfcc39666ad..e3e3389fd164 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -25,12 +25,17 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) + CopyBlocksOp, + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) from vllm.distributed.utils import divide from vllm.forward_context import ForwardContext from vllm.logger import init_logger @@ -56,6 +61,7 @@ try: from nixl._api import nixl_agent as NixlWrapper from nixl._bindings import nixlXferTelemetry + logger.info("NIXL is available") except ImportError: logger.warning("NIXL is not available") @@ -75,18 +81,19 @@ "cuda", "cpu", ), - "tpu": ("cpu", ), - "xpu": ("cpu", ), + "tpu": ("cpu",), + "xpu": ("cpu",), } # support for oot platform by providing mapping in current_platform _NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices()) class NixlAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True): + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True, +): engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] @@ -107,12 +114,12 @@ class ReqMeta: class NixlConnectorMetadata(KVConnectorMetadata): - def __init__(self): self.reqs_to_recv: dict[ReqId, ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {} self.reqs_to_send: dict[ReqId, float] = {} self.reqs_in_batch: set[ReqId] = set() + self.reqs_not_processed: set[ReqId] = set() def add_new_req( self, @@ -140,20 +147,19 @@ def add_new_req( class NixlConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler: Optional[NixlConnectorScheduler] = \ + self.connector_scheduler: Optional[NixlConnectorScheduler] = ( NixlConnectorScheduler(vllm_config, self.engine_id) + ) self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None - self.connector_worker = NixlConnectorWorker( - vllm_config, self.engine_id) + self.connector_worker = NixlConnectorWorker(vllm_config, self.engine_id) ############################################################ # Class Methods @@ -161,8 +167,10 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): @classmethod def get_required_kvcache_layout(cls, vllm_config: VllmConfig): if vllm_config.model_config is None: - logger.warning_once("Unable to detect current VLLM config. 
" - "Fallback to default kv cache layout.") + logger.warning_once( + "Unable to detect current VLLM config. " + "Fallback to default kv cache layout." + ) return None use_mla = vllm_config.model_config.use_mla if use_mla: @@ -170,8 +178,9 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig): # as the layout should not matter in that case, # which fallback to the default behavior. return None - logger.info_once("NixlConnector setting KV cache " - "layout to HND for better xfer performance.") + logger.info_once( + "NixlConnector setting KV cache layout to HND for better xfer performance." + ) return "HND" ############################################################ @@ -179,18 +188,20 @@ def get_required_kvcache_layout(cls, vllm_config: VllmConfig): ############################################################ def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[Optional[int], bool]: + self, request: "Request", num_computed_tokens: int + ) -> tuple[Optional[int], bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens + ) - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): assert self.connector_scheduler is not None return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens) + request, blocks, num_external_tokens + ) def build_connector_meta( self, @@ -218,8 +229,7 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): assert self.connector_worker is not None self.connector_worker.set_host_xfer_buffer_ops(copy_operation) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None return self.connector_worker.get_finished() @@ -230,14 +240,15 @@ def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: @classmethod def build_kv_connector_stats( - cls, - data: Optional[dict[str, - Any]] = None) -> Optional[KVConnectorStats]: - return NixlKVConnectorStats(data=data) if data is not None \ + cls, data: Optional[dict[str, Any]] = None + ) -> Optional[KVConnectorStats]: + return ( + NixlKVConnectorStats(data=data) + if data is not None else NixlKVConnectorStats() + ) - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None assert isinstance(self._connector_metadata, NixlConnectorMetadata) self.connector_worker.start_load_kv(self._connector_metadata) @@ -246,16 +257,20 @@ def wait_for_layer_load(self, layer_name: str) -> None: """NixlConnector does not do layerwise saving.""" pass - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: """NixlConnector does not save explicitly.""" pass def wait_for_save(self): assert self.connector_worker is not None assert isinstance(self._connector_metadata, NixlConnectorMetadata) - if self.connector_worker.use_host_buffer 
and \ - self.connector_worker.copy_blocks: + if self.connector_worker.use_host_buffer and self.connector_worker.copy_blocks: self.connector_worker.save_kv_to_host(self._connector_metadata) def shutdown(self): @@ -272,11 +287,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.engine_id: EngineId = engine_id self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST self.side_channel_port = ( - envs.VLLM_NIXL_SIDE_CHANNEL_PORT + - vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size) - self.use_host_buffer = \ - vllm_config.kv_transfer_config.kv_buffer_device == "cpu" + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) + self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu" logger.info("Initializing NIXL Scheduler %s", engine_id) # Requests that need to start recv/send. @@ -287,10 +302,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} self._reqs_in_batch: set[ReqId] = set() + # Reqs to remove from processed set because they're not to send after + # remote prefill or aborted. + self._reqs_not_processed: set[ReqId] = set() def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: """ For remote prefill, pull all prompt blocks from remote asynchronously relative to engine execution. @@ -310,7 +328,9 @@ def get_num_new_matched_tokens( logger.debug( "NIXLConnector get_num_new_matched_tokens: " "num_computed_tokens=%s, kv_transfer_params=%s", - num_computed_tokens, params) + num_computed_tokens, + params, + ) if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. @@ -321,15 +341,16 @@ def get_num_new_matched_tokens( # No remote prefill for this request. return 0, False - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): - + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): params = request.kv_transfer_params logger.debug( "NIXLConnector update_state_after_alloc: " "num_external_tokens=%s, kv_transfer_params=%s", - num_external_tokens, params) + num_external_tokens, + params, + ) if not params: return @@ -348,25 +369,33 @@ def update_state_after_alloc(self, request: "Request", # block is not overwritten; and it will be safe to skip saving them # to host xfer buffer. if block_ids: - self._reqs_need_save[request.request_id] = \ - (request, block_ids) + self._reqs_need_save[request.request_id] = (request, block_ids) elif params.get("do_remote_prefill"): if params.get("remote_block_ids"): - if all(p in params for p in ("remote_engine_id", "remote_host", - "remote_port")): + if all( + p in params + for p in ("remote_engine_id", "remote_host", "remote_port") + ): # If remote_blocks and num_external_tokens = 0, we have # a full prefix cache hit on the D worker. We need to call # send_notif in _read_blocks to free the memory on the P. - local_block_ids = (blocks.get_unhashed_block_ids() - if num_external_tokens > 0 else []) + local_block_ids = ( + blocks.get_unhashed_block_ids() + if num_external_tokens > 0 + else [] + ) # Get unhashed blocks to pull from remote. 
self._reqs_need_recv[request.request_id] = ( - request, local_block_ids) + request, + local_block_ids, + ) else: logger.warning( "Got invalid KVTransferParams: %s. This " - "request will not utilize KVTransfer", params) + "request will not utilize KVTransfer", + params, + ) else: assert num_external_tokens == 0 # Only trigger 1 KV transfer per request. @@ -401,11 +430,13 @@ def build_connector_meta( meta.reqs_to_send = self._reqs_need_send meta.reqs_in_batch = self._reqs_in_batch + meta.reqs_not_processed = self._reqs_not_processed # Clear the list once workers start the transfers self._reqs_need_recv.clear() self._reqs_need_save.clear() self._reqs_in_batch = set() + self._reqs_not_processed = set() self._reqs_need_send = {} return meta @@ -423,8 +454,10 @@ def request_finished( params = request.kv_transfer_params logger.debug( - "NIXLConnector request_finished, request_status=%s, " - "kv_transfer_params=%s", request.status, params) + "NIXLConnector request_finished, request_status=%s, kv_transfer_params=%s", + request.status, + params, + ) if not params: return False, None @@ -439,8 +472,12 @@ def request_finished( params["do_remote_prefill"] = False return False, None - if (not params.get("do_remote_decode") - or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): + if not params.get("do_remote_decode"): + return False, None + if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: + # Also include the case of a P/D Prefill request with immediate + # block free (eg abort). Stop tracking this request. + self._reqs_not_processed.add(request.request_id) return False, None # TODO: check whether block_ids actually ever be 0. If not we could @@ -449,8 +486,9 @@ def request_finished( if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion - self._reqs_need_send[request.request_id] = time.perf_counter( - ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + self._reqs_need_send[request.request_id] = ( + time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + ) return delay_free_blocks, dict( do_remote_prefill=True, @@ -459,7 +497,8 @@ def request_finished( remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, - tp_size=self.vllm_config.parallel_config.tensor_parallel_size) + tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + ) class NixlConnectorWorker: @@ -476,9 +515,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size - self.nixl_backends = \ - vllm_config.kv_transfer_config.get_from_extra_config( - "backends", ["UCX"]) + self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config( + "backends", ["UCX"] + ) # TODO temporary, once nixl allows for telemetry flag in config # (next release), we can remove this env var. os.environ["NIXL_TELEMETRY_ENABLE"] = "1" @@ -487,8 +526,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): if nixl_agent_config is None: config = None else: - config = nixl_agent_config(backends=self.nixl_backends) if len( - non_ucx_backends) > 0 else nixl_agent_config(num_threads=8) + config = ( + nixl_agent_config(backends=self.nixl_backends) + if len(non_ucx_backends) > 0 + else nixl_agent_config(num_threads=8) + ) self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. 
@@ -499,9 +541,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # base port (which is sent in the KVTransferParams). # Each TP rank listens/queries on the base_port + tp_rank. self.side_channel_port: int = ( - envs.VLLM_NIXL_SIDE_CHANNEL_PORT + - vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size) + envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) # Metadata. self.engine_id: EngineId = engine_id @@ -512,15 +555,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # KV Caches and nixl tracking data. self.device_type = current_platform.device_type - self.kv_buffer_device: str = \ - vllm_config.kv_transfer_config.kv_buffer_device + self.kv_buffer_device: str = vllm_config.kv_transfer_config.kv_buffer_device if self.device_type not in _NIXL_SUPPORTED_DEVICE: raise RuntimeError(f"{self.device_type} is not supported.") - elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[ - self.device_type]: + elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[self.device_type]: raise RuntimeError( f"{self.device_type} with {self.kv_buffer_device} kv_buffer " - "is not supported.") + "is not supported." + ) self.device_kv_caches: dict[str, torch.Tensor] = {} # cpu kv buffer for xfer @@ -538,7 +580,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): if self.nixl_memory_type is None: raise RuntimeError( f"{self.device_type} with {self.kv_buffer_device} kv_buffer " - "is not supported.") + "is not supported." + ) # Note: host xfer buffer ops when use_host_buffer is True self.copy_blocks: Optional[CopyBlocksOp] = None @@ -577,7 +620,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. max_workers=1, - thread_name_prefix="vllm-nixl-handshake-initiator") + thread_name_prefix="vllm-nixl-handshake-initiator", + ) self._ready_requests = queue.Queue[tuple[ReqId, ReqMeta]]() self._handshake_futures: dict[EngineId, Future[dict[int, str]]] = {} # Protects _handshake_futures and _remote_agents. @@ -594,11 +638,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.block_window_per_layer: list[Optional[int]] = [] self.use_mla = self.model_config.use_mla - backend = get_attn_backend(self.model_config.get_head_size(), - self.model_config.dtype, - self.cache_config.cache_dtype, - self.block_size, - use_mla=self.use_mla) + backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + use_mla=self.use_mla, + ) self.backend_name = backend.get_name() attn_backend = backend_name_to_enum(self.backend_name) self._use_flashinfer = attn_backend == _Backend.FLASHINFER @@ -614,9 +660,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.xfer_stats = NixlKVConnectorStats() @staticmethod - def _nixl_handshake_listener(metadata: NixlAgentMetadata, - ready_event: threading.Event, base_port: int, - tp_rank: int): + def _nixl_handshake_listener( + metadata: NixlAgentMetadata, + ready_event: threading.Event, + base_port: int, + tp_rank: int, + ): """Background thread for getting new NIXL handshakes.""" # NOTE(rob): this is a simple implementation. We will move # to a better approach via HTTP endpoint soon. 
@@ -624,8 +673,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) size_in_bytes = len(encoded_data) - logger.debug("Size of encoded NixlAgentMetadata: %s bytes", - str(size_in_bytes)) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", str(size_in_bytes)) # Listen for new requests for metadata. host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST @@ -636,8 +684,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, while True: identity, _, msg = sock.recv_multipart() if msg != GET_META_MSG: - logger.warning( - "Connection listener got unexpected message %s", msg) + logger.warning("Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data)) def _nixl_handshake( @@ -660,8 +707,9 @@ def _nixl_handshake( tp_ratio = self._tp_size[self.engine_id] // remote_tp_size p_remote_rank = self.tp_rank // tp_ratio path = make_zmq_path("tcp", host, port + p_remote_rank) - logger.debug("Querying metadata on path: %s at remote rank %s", path, - p_remote_rank) + logger.debug( + "Querying metadata on path: %s at remote rank %s", path, p_remote_rank + ) # Send query for the request. with zmq_ctx(zmq.REQ, path) as sock: @@ -670,27 +718,32 @@ def _nixl_handshake( decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) metadata = decoder.decode(metadata_bytes) got_metadata_time = time.perf_counter() - logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) + logger.debug( + "NIXL handshake: get metadata took: %s", got_metadata_time - start_time + ) # Ensure engine id matches. if metadata.engine_id != expected_engine_id: - raise RuntimeError(f"Remote NIXL agent engine ID mismatch. " - f"Expected {expected_engine_id}," - f"received {metadata.engine_id}.") + raise RuntimeError( + f"Remote NIXL agent engine ID mismatch. " + f"Expected {expected_engine_id}," + f"received {metadata.engine_id}." + ) # Register Remote agent. - remote_agent_name = self.add_remote_agent(metadata, p_remote_rank, - remote_tp_size) + remote_agent_name = self.add_remote_agent( + metadata, p_remote_rank, remote_tp_size + ) setup_agent_time = time.perf_counter() - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + logger.debug( + "NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time, + ) # Remote rank -> agent name. 
return {p_remote_rank: remote_agent_name} - def initialize_host_xfer_buffer( - self, kv_caches: dict[str, torch.Tensor]) -> None: + def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: """ Initialize transfer buffer in CPU mem for accelerators NOT directly supported by NIXL (e.g., tpu) @@ -700,9 +753,9 @@ def initialize_host_xfer_buffer( for layer_name, kv_cache in kv_caches.items(): kv_shape = kv_cache.shape kv_dtype = kv_cache.dtype - xfer_buffers[layer_name] = torch.empty(kv_shape, - dtype=kv_dtype, - device="cpu") + xfer_buffers[layer_name] = torch.empty( + kv_shape, dtype=kv_dtype, device="cpu" + ) except MemoryError as e: logger.error("NIXLConnectorWorker gets %s.", e) raise @@ -717,14 +770,19 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): assert self.use_host_buffer self.copy_blocks = copy_operation - def _background_nixl_handshake(self, req_id: str, - remote_engine_id: EngineId, meta: ReqMeta): + def _background_nixl_handshake( + self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta + ): # Do NIXL handshake in background and add to _ready_requests when done. fut = self._handshake_futures.get(remote_engine_id) if fut is None: fut = self._handshake_initiation_executor.submit( - self._nixl_handshake, meta.remote_host, meta.remote_port, - meta.tp_size, remote_engine_id) + self._nixl_handshake, + meta.remote_host, + meta.remote_port, + meta.tp_size, + remote_engine_id, + ) self._handshake_futures[remote_engine_id] = fut def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): @@ -751,18 +809,23 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.initialize_host_xfer_buffer(kv_caches=kv_caches) assert len(self.host_xfer_buffers) == len(kv_caches), ( f"host_buffer: {len(self.host_xfer_buffers)}, " - f"kv_caches: {len(kv_caches)}") + f"kv_caches: {len(kv_caches)}" + ) xfer_buffers = self.host_xfer_buffers else: xfer_buffers = kv_caches assert not self.host_xfer_buffers, ( "host_xfer_buffer should not be initialized when " - f"kv_buffer_device is {self.kv_buffer_device}") + f"kv_buffer_device is {self.kv_buffer_device}" + ) logger.info( "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " - "use_host_buffer: %s", self.use_mla, self.kv_buffer_device, - self.use_host_buffer) + "use_host_buffer: %s", + self.use_mla, + self.kv_buffer_device, + self.use_host_buffer, + ) caches_data = [] # With hybrid allocator, layers can share a kv cache tensor @@ -776,16 +839,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are registered in the same region # to better exploit the memory layout (ie num_blocks is the first dim). - split_k_and_v = not (self.use_mla or self._use_pallas - or self._use_flashinfer) + split_k_and_v = not (self.use_mla or self._use_pallas or self._use_flashinfer) tensor_size_bytes = None # Enable different block lengths for different layers when MLA is used. 
self.block_len_per_layer = list[int]() self.slot_size_per_layer = list[int]() # HD bytes in kv terms for layer_name, cache_or_caches in xfer_buffers.items(): - cache_list = cache_or_caches if split_k_and_v else [ - cache_or_caches - ] + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] for cache in cache_list: base_addr = cache.data_ptr() @@ -799,23 +859,29 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): tensor_size_bytes = curr_tensor_size_bytes self.num_blocks = cache.shape[0] - assert cache.shape[0] == self.num_blocks, \ + assert cache.shape[0] == self.num_blocks, ( "All kv cache tensors must have the same number of blocks" + ) - self.block_len_per_layer.append(curr_tensor_size_bytes // - self.num_blocks) - self.slot_size_per_layer.append(self.block_len_per_layer[-1] // - self.block_size) + self.block_len_per_layer.append( + curr_tensor_size_bytes // self.num_blocks + ) + self.slot_size_per_layer.append( + self.block_len_per_layer[-1] // self.block_size + ) if not self.use_mla: # Different kv cache shape is not supported by HeteroTP - assert tensor_size_bytes == curr_tensor_size_bytes, \ + assert tensor_size_bytes == curr_tensor_size_bytes, ( "All kv cache tensors must have the same size" + ) caches_data.append( - (base_addr, curr_tensor_size_bytes, self.tp_rank, "")) + (base_addr, curr_tensor_size_bytes, self.tp_rank, "") + ) - logger.debug("Different block lengths collected: %s", - set(self.block_len_per_layer)) + logger.debug( + "Different block lengths collected: %s", set(self.block_len_per_layer) + ) assert len(self.block_len_per_layer) == len(seen_base_addresses) assert self.num_blocks != 0 @@ -823,8 +889,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.num_regions = len(caches_data) self.num_layers = len(xfer_buffers.keys()) - descs = self.nixl_wrapper.get_reg_descs(caches_data, - self.nixl_memory_type) + descs = self.nixl_wrapper.get_reg_descs(caches_data, self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends) logger.debug("Done registering descs") @@ -870,21 +935,27 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Register addresses for V cache (K registered first). v_addr = addr + kv_block_len blocks_data.append((v_addr, kv_block_len, self.tp_rank)) - logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, self.tp_rank) + logger.debug( + "Created %s blocks for src engine %s and rank %s", + len(blocks_data), + self.engine_id, + self.tp_rank, + ) - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, - self.nixl_memory_type) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs) + "NIXL_INIT_AGENT", descs + ) # TODO(mgoin): Hybrid memory allocator is currently disabled for # models with local attention (Llama 4). Can remove this once enabled. 
if self.vllm_config.model_config.hf_config.model_type == "llama4": from transformers import Llama4TextConfig - assert isinstance(self.vllm_config.model_config.hf_text_config, - Llama4TextConfig) + + assert isinstance( + self.vllm_config.model_config.hf_text_config, Llama4TextConfig + ) llama4_config = self.vllm_config.model_config.hf_text_config no_rope_layers = llama4_config.no_rope_layers chunk_size = llama4_config.attention_chunk_size @@ -895,8 +966,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): is_local_attention = no_rope_layers[layer_idx] != 0 block_window = chunk_block_size if is_local_attention else None self.block_window_per_layer.append(block_window) - logger.debug("Llama 4 block window per layer mapping: %s", - self.block_window_per_layer) + logger.debug( + "Llama 4 block window per layer mapping: %s", + self.block_window_per_layer, + ) assert len(self.block_window_per_layer) == self.num_layers # After KV Caches registered, listen for new connections. @@ -907,33 +980,37 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, attn_backend_name=self.backend_name, - kv_cache_layout=self.kv_cache_layout) + kv_cache_layout=self.kv_cache_layout, + ) ready_event = threading.Event() self._nixl_handshake_listener_t = threading.Thread( target=self._nixl_handshake_listener, args=(metadata, ready_event, self.side_channel_port, self.tp_rank), daemon=True, - name="nixl_handshake_listener") + name="nixl_handshake_listener", + ) self._nixl_handshake_listener_t.start() ready_event.wait() # Wait for listener ZMQ socket to be ready. - def add_remote_agent(self, - nixl_agent_meta: NixlAgentMetadata, - remote_tp_rank: int = 0, - remote_tp_size: int = 1) -> str: + def add_remote_agent( + self, + nixl_agent_meta: NixlAgentMetadata, + remote_tp_rank: int = 0, + remote_tp_size: int = 1, + ) -> str: """ Add the remote NIXL agent and prepare the descriptors for reading cache blocks from remote. In particular, handle both homogeneous and heterogeneous TP. The former - requires local rank_i to read from remote rank_i. - The latter, assuming D.world_size > P.world_size, requires that two or + requires local rank_i to read from remote rank_i. + The latter, assuming D.world_size > P.world_size, requires that two or more local TP worker share the xfer from a single TP worker. Here's an example (non-MLA case): rank_offset p_remote_tp_rank - (kv split no) + (kv split no) -------------------------------- 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ] / @@ -946,19 +1023,19 @@ def add_remote_agent(self, Decoder TP workers Prefix TP workers (world_size=4) (world_size=2) - tp_ratio = 4 // 2 = 2 - - Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] + tp_ratio = 4 // 2 = 2 + + Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format. - Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio + Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split - along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. 
- + along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. + Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1. Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0 so that the whole cache is shared by "tp_ratio" D TP workers. - """ # noqa: E501 + """ # noqa: E501 engine_id = nixl_agent_meta.engine_id # TODO re-evaluate refreshing for scaling/recovery if remote_tp_rank in self._remote_agents.get(engine_id, {}): @@ -972,15 +1049,16 @@ def add_remote_agent(self, assert nixl_agent_meta.attn_backend_name == self.backend_name remote_agent_name = self.nixl_wrapper.add_remote_agent( - nixl_agent_meta.agent_metadata) + nixl_agent_meta.agent_metadata + ) # Number of D TP workers reading from a single P TP worker. This is # 1 when P and D `--tensor-parallel-size` match. - tp_ratio = divide(self._tp_size[self.engine_id], - self._tp_size[engine_id]) + tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id]) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" - assert not self._use_pallas or tp_ratio == 1, \ - "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." + assert not self._use_pallas or tp_ratio == 1, ( + "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." + ) # Handle tp_size>num_kv_heads: replicate KV cache. total_num_kv_heads = self.model_config.get_total_num_kv_heads() @@ -989,17 +1067,19 @@ def add_remote_agent(self, remote_block_len = nixl_agent_meta.block_lens[0] if self.use_mla or is_kv_replicated: # With replicated KV cache, only the number of blocks can differ. - assert self.block_len_per_layer == nixl_agent_meta.block_lens, \ + assert self.block_len_per_layer == nixl_agent_meta.block_lens, ( "KV cache sizes must match between P and D when replicated" - remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0]) + ) + remote_block_size = remote_block_len // (self.slot_size_per_layer[0]) else: # When MLA is not used, this is a list of the same block length for block_len in nixl_agent_meta.block_lens: - assert block_len == remote_block_len, \ + assert block_len == remote_block_len, ( "All remote layers must have the same block size" + ) remote_block_size = remote_block_len // ( - self.slot_size_per_layer[0] * tp_ratio) + self.slot_size_per_layer[0] * tp_ratio + ) if self._use_flashinfer: # With flashinfer, KV are sent in the same message. remote_block_size //= 2 @@ -1007,8 +1087,7 @@ def add_remote_agent(self, # Heterogeneous TP expects same kv_cache_layout. assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout if self.device_type == "xpu": - raise ValueError( - "Heterogeneous TP is not supported on XPU") + raise ValueError("Heterogeneous TP is not supported on XPU") assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, ( "Remote P worker KV layer cache must be of shape [2, N, " @@ -1017,7 +1096,8 @@ def add_remote_agent(self, assert self.block_size == remote_block_size, ( "Remote P worker with different page/block size is not supported " - f"{self.block_size=}, {remote_block_size=}") + f"{self.block_size=}, {remote_block_size=}" + ) # Create dst descs and xfer side handles. TP workers have same #blocks. if engine_id in self.dst_num_blocks: @@ -1030,16 +1110,17 @@ def add_remote_agent(self, # rank. With heterogeneous TP, prepare the descriptors by splitting the # P KV cache along kv_head dim, of D worker's kv_head size (D>P). # Eg. 
PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. - self.kv_caches_base_addr[ - engine_id] = nixl_agent_meta.kv_caches_base_addr + self.kv_caches_base_addr[engine_id] = nixl_agent_meta.kv_caches_base_addr - assert len(nixl_agent_meta.kv_caches_base_addr) == len( - self.block_len_per_layer) + assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) # Register all remote blocks, but only the corresponding kv heads. for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) - rank_offset = self.tp_rank % tp_ratio * kv_block_len \ - if not (self.use_mla or is_kv_replicated) else 0 + rank_offset = ( + self.tp_rank % tp_ratio * kv_block_len + if not (self.use_mla or is_kv_replicated) + else 0 + ) for block_id in range(nixl_agent_meta.num_blocks): block_offset = block_id * nixl_agent_meta.block_lens[i] # For each block, grab the heads chunk belonging to rank_i @@ -1058,16 +1139,18 @@ def add_remote_agent(self, blocks_data.append((v_addr, kv_block_len, remote_tp_rank)) logger.debug( - "Created %s blocks for dst engine %s with remote rank %s and " - "local rank %s", len(blocks_data), engine_id, remote_tp_rank, - self.tp_rank) + "Created %s blocks for dst engine %s with remote rank %s and local rank %s", + len(blocks_data), + engine_id, + remote_tp_rank, + self.tp_rank, + ) # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, - self.nixl_memory_type) - self.dst_xfer_side_handles[ - engine_id] = self.nixl_wrapper.prep_xfer_dlist( - remote_agent_name, descs) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) + self.dst_xfer_side_handles[engine_id] = self.nixl_wrapper.prep_xfer_dlist( + remote_agent_name, descs + ) return remote_agent_name @@ -1077,13 +1160,20 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): assert self.copy_blocks is not None local_block_ids = meta.local_block_ids - self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches, - local_block_ids, local_block_ids, "h2d") + self.copy_blocks( + self.host_xfer_buffers, + self.device_kv_caches, + local_block_ids, + local_block_ids, + "h2d", + ) if logger.isEnabledFor(logging.DEBUG): logger.debug( "synced recved kv of request[%s] to device kv buffer," - "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids))) + "local_block_ids: %s. ", + req_id, + ",".join(map(str, meta.local_block_ids)), + ) def save_kv_to_host(self, metadata: NixlConnectorMetadata): """copy kv from device to host buffer.""" @@ -1094,11 +1184,18 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): if logger.isEnabledFor(logging.DEBUG): logger.debug( "save_load_kv for request[%s] to host xfer buffer." - "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids))) + "local_block_ids: %s. 
", + req_id, + ",".join(map(str, meta.local_block_ids)), + ) # blocking - self.copy_blocks(self.device_kv_caches, self.host_xfer_buffers, - meta.local_block_ids, meta.local_block_ids, "d2h") + self.copy_blocks( + self.device_kv_caches, + self.host_xfer_buffers, + meta.local_block_ids, + meta.local_block_ids, + "d2h", + ) def get_finished(self) -> tuple[set[str], set[str]]: """ @@ -1111,8 +1208,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: if len(done_sending) > 0 or len(done_recving) > 0: logger.debug( "Rank %s, get_finished: %s requests done sending " - "and %s requests done recving", self.tp_rank, - len(done_sending), len(done_recving)) + "and %s requests done recving", + self.tp_rank, + len(done_sending), + len(done_recving), + ) if self.use_host_buffer: for req_id in done_recving: @@ -1130,8 +1230,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: count = self.consumer_notification_counts_by_req.pop(req_id, 0) logger.warning( "Releasing expired KV blocks for request %s which were " - "retrieved by %d decode worker(s) within %d seconds.", req_id, - count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) + "retrieved by %d decode worker(s) within %d seconds.", + req_id, + count, + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + ) self._reqs_to_process.remove(req_id) del self._reqs_to_send[req_id] done_sending.add(req_id) @@ -1148,18 +1251,21 @@ def _get_new_notifs(self) -> set[str]: for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1) - if (req_id not in self._reqs_to_send - and req_id not in self._reqs_to_process): + if ( + req_id not in self._reqs_to_send + and req_id not in self._reqs_to_process + ): logger.error( "Potentially invalid KV blocks for " "unrecognized request %s were retrieved by " - "a decode worker. They may have expired.", req_id) + "a decode worker. They may have expired.", + req_id, + ) continue self.consumer_notification_counts_by_req[req_id] += 1 # Wait all consumers (D) to be done reading before freeing. - if self.consumer_notification_counts_by_req[req_id] == int( - tp_ratio): + if self.consumer_notification_counts_by_req[req_id] == int(tp_ratio): notified_req_ids.add(req_id) del self.consumer_notification_counts_by_req[req_id] self._reqs_to_process.remove(req_id) @@ -1167,7 +1273,8 @@ def _get_new_notifs(self) -> set[str]: return notified_req_ids def _pop_done_transfers( - self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]: + self, transfers: dict[str, list[tuple[int, float]]] + ) -> set[str]: """ Pop completed xfers by checking for DONE state. Args: @@ -1189,8 +1296,7 @@ def _pop_done_transfers( in_progress = True continue else: - raise RuntimeError("Transfer failed with state %s", - xfer_state) + raise RuntimeError("Transfer failed with state %s", xfer_state) if not in_progress: done_req_ids.add(req_id) del transfers[req_id] @@ -1205,17 +1311,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): remote_engine_id = meta.remote_engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " - "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, - remote_engine_id, len(meta.local_block_ids), - len(meta.remote_block_ids)) + "Num local_block_ids: %s. Num remote_block_ids: %s. 
", + req_id, + remote_engine_id, + len(meta.local_block_ids), + len(meta.remote_block_ids), + ) if self.use_host_buffer: self._recving_metadata[req_id] = meta if remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: if remote_engine_id not in self._remote_agents: - self._background_nixl_handshake( - req_id, remote_engine_id, meta) + self._background_nixl_handshake(req_id, remote_engine_id, meta) continue # Handshake already completed, start async read xfer. @@ -1234,6 +1342,10 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): for req_id in metadata.reqs_in_batch: self._reqs_to_process.add(req_id) + # Remove all requests that are not to be processed (eg aborted). + for req_id in metadata.reqs_not_processed: + self._reqs_to_process.discard(req_id) + # Add to requests that are waiting to be read and track expiration. for req_id, expiration_time in metadata.reqs_to_send.items(): if req_id in self._reqs_to_process: @@ -1242,7 +1354,9 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug( "Remote agent %s available, calling _read_blocks for req %s", - meta.remote_engine_id, req_id) + meta.remote_engine_id, + req_id, + ) self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, @@ -1250,9 +1364,13 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): remote_block_ids=meta.remote_block_ids, ) - def _read_blocks(self, local_block_ids: list[int], - remote_block_ids: list[int], dst_engine_id: str, - request_id: str): + def _read_blocks( + self, + local_block_ids: list[int], + remote_block_ids: list[int], + dst_engine_id: str, + request_id: str, + ): # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the @@ -1265,8 +1383,7 @@ def _read_blocks(self, local_block_ids: list[int], # Number of D TP workers that will read from dst P. Propagate tp_ratio # on notification so that dst worker can wait before freeing blocks. - tp_ratio = self._tp_size[ - self.engine_id] // self._tp_size[dst_engine_id] + tp_ratio = self._tp_size[self.engine_id] // self._tp_size[dst_engine_id] notif_id = f"{request_id}:{tp_ratio}".encode() # Full prefix cache hit: do not need to read remote blocks, @@ -1298,16 +1415,17 @@ def _read_blocks(self, local_block_ids: list[int], if not self.block_window_per_layer: # Default case: assume global attention remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, remote_block_ids) + dst_engine_id, remote_block_ids + ) local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, local_block_ids) + self.engine_id, local_block_ids + ) else: # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) local_descs_list = [] remote_descs_list = [] - for layer_idx, block_window in enumerate( - self.block_window_per_layer): + for layer_idx, block_window in enumerate(self.block_window_per_layer): # For each layer: if block_window is None: # If not chunked, we just use the @@ -1321,9 +1439,11 @@ def _read_blocks(self, local_block_ids: list[int], # Get descs ids for the layer. 
layer_local_desc_ids = self._get_block_descs_ids( - self.engine_id, layer_local_block_ids, layer_idx) + self.engine_id, layer_local_block_ids, layer_idx + ) layer_remote_desc_ids = self._get_block_descs_ids( - dst_engine_id, layer_remote_block_ids, layer_idx) + dst_engine_id, layer_remote_block_ids, layer_idx + ) local_descs_list.append(layer_local_desc_ids) remote_descs_list.append(layer_remote_desc_ids) @@ -1347,13 +1467,11 @@ def _read_blocks(self, local_block_ids: list[int], self.nixl_wrapper.transfer(handle) # Use handle to check completion in future step(). - self._recving_transfers[request_id].append( - (handle, time.perf_counter())) + self._recving_transfers[request_id].append((handle, time.perf_counter())) - def _get_block_descs_ids(self, - engine_id: str, - block_ids: list[int], - layer_idx: Optional[int] = None) -> np.ndarray: + def _get_block_descs_ids( + self, engine_id: str, block_ids: list[int], layer_idx: Optional[int] = None + ) -> np.ndarray: """ Get the descs ids for a set of block ids. If layer_idx is provided, we use the region_ids for the given layer. @@ -1386,7 +1504,7 @@ def get_backend_aware_kv_block_len(self, layer_idx: int): """ Get the block length for one K/V element (K and V have the same size). - For FA and other backends, this is equal to the length of the whole + For FA and other backends, this is equal to the length of the whole block, as K and V are in separate regions. For FlashInfer, this is half the length of the whole block, as K and V share the same region. @@ -1442,10 +1560,9 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: ctx: Optional[zmq.Context] = None try: ctx = zmq.Context() # type: ignore[attr-defined] - yield make_zmq_socket(ctx=ctx, - path=addr, - socket_type=socket_type, - bind=socket_type == zmq.ROUTER) + yield make_zmq_socket( + ctx=ctx, path=addr, socket_type=socket_type, bind=socket_type == zmq.ROUTER + ) finally: if ctx is not None: ctx.destroy(linger=0) @@ -1533,4 +1650,4 @@ def reduce(self) -> dict[str, Union[int, float]]: @property def num_successful_transfers(self) -> int: - return len(self.data["transfer_duration"]) \ No newline at end of file + return len(self.data["transfer_duration"]) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 6936638c7f4e..745af0efba18 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -11,10 +11,11 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorMetadata) +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks @@ -40,7 +41,6 @@ class OffloadingConnectorMetadata(KVConnectorMetadata): class OffloadingConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): super().__init__(vllm_config, role) @@ -57,47 +57,51 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): 
assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None - assert isinstance(self._connector_metadata, - OffloadingConnectorMetadata) + assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) self.connector_worker.start_load_kv(self._connector_metadata) def wait_for_layer_load(self, layer_name: str) -> None: pass - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: pass def wait_for_save(self): assert self.connector_worker is not None - assert isinstance(self._connector_metadata, - OffloadingConnectorMetadata) + assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) self.connector_worker.start_store_kv(self._connector_metadata) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: assert self.connector_worker is not None return self.connector_worker.get_finished(finished_req_ids) def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) + request, num_computed_tokens + ) - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): assert self.connector_scheduler is not None return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens) + request, blocks, num_external_tokens + ) def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: assert self.connector_scheduler is not None return self.connector_scheduler.build_connector_meta(scheduler_output) @@ -124,8 +128,7 @@ class OffloadingConnectorScheduler: def __init__(self, spec: OffloadingSpec): self.gpu_block_size = spec.gpu_block_size self.offloaded_block_size = spec.offloaded_block_size - self.block_size_factor = (self.offloaded_block_size // - self.gpu_block_size) + self.block_size_factor = self.offloaded_block_size // self.gpu_block_size self.manager: OffloadingManager = spec.get_manager() self._requests: dict[ReqId, Request] = {} @@ -151,11 +154,12 @@ def _get_block_hashes( req.block_hashes, self.block_size_factor * start_idx + self.block_size_factor - 1, self.block_size_factor * end_idx if end_idx else None, - self.block_size_factor) + self.block_size_factor, + ) def get_num_new_matched_tokens( - self, request: Request, - num_computed_tokens: int) -> tuple[int, bool]: + self, request: Request, num_computed_tokens: int + ) -> tuple[int, bool]: """ Get number of new tokens that can be loaded beyond the num_computed_tokens. 
@@ -174,8 +178,7 @@ def get_num_new_matched_tokens( """ num_blocks = request.num_tokens // self.offloaded_block_size - assert (len(request.block_hashes) // - self.block_size_factor == num_blocks) + assert len(request.block_hashes) // self.block_size_factor == num_blocks block_hashes = self._get_block_hashes(request) self.manager.touch(block_hashes) @@ -187,12 +190,14 @@ def get_num_new_matched_tokens( start_block_idx = num_computed_tokens // self.offloaded_block_size hits = self.manager.lookup( - self._get_block_hashes(request, start_idx=start_block_idx)) + self._get_block_hashes(request, start_idx=start_block_idx) + ) if hits == 0: return 0, False - num_hit_tokens = (self.offloaded_block_size * - (start_block_idx + hits) - num_computed_tokens) + num_hit_tokens = ( + self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens + ) logger.debug( "Request %s hit %s offloaded tokens after %s GPU hit tokens", request.request_id, @@ -204,8 +209,9 @@ def get_num_new_matched_tokens( return num_hit_tokens, True - def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, - num_external_tokens: int): + def update_state_after_alloc( + self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int + ): self._requests[request.request_id] = request # the block ids are updated in _get_reqs_to_store self._request_block_ids[request.request_id] = [] @@ -216,31 +222,30 @@ def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, block_groups = blocks.get_block_ids() block_ids = block_groups[0] - num_computed_gpu_blocks = sum(block.block_hash is not None - for block in blocks.blocks[0]) + num_computed_gpu_blocks = sum( + block.block_hash is not None for block in blocks.blocks[0] + ) num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size full_block_tokens = num_computed_tokens + num_external_tokens assert full_block_tokens % self.offloaded_block_size == 0 num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks - assert (num_external_tokens == num_pending_gpu_blocks * - self.gpu_block_size) + assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size start_block_idx = num_computed_tokens // self.offloaded_block_size num_blocks = full_block_tokens // self.offloaded_block_size - assert (len(request.block_hashes) // self.block_size_factor - >= num_blocks) - block_hashes = self._get_block_hashes(request, - start_idx=start_block_idx, - end_idx=num_blocks) + assert len(request.block_hashes) // self.block_size_factor >= num_blocks + block_hashes = self._get_block_hashes( + request, start_idx=start_block_idx, end_idx=num_blocks + ) src_spec = self.manager.prepare_load(block_hashes) dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:]) - block_hashes = self._get_block_hashes(request, - start_idx=start_block_idx, - end_idx=num_blocks) + block_hashes = self._get_block_hashes( + request, start_idx=start_block_idx, end_idx=num_blocks + ) self._reqs_to_load[request.request_id] = (src_spec, dst_spec) self._reqs_being_loaded[request.request_id].update(block_hashes) @@ -249,9 +254,7 @@ def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): reqs_to_store: dict[ReqId, TransferSpec] = {} # iterate over both new and cached requests - for req_id, new_block_id_groups, preempted in yield_req_data( - scheduler_output): - + for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): if preempted: 
self._request_block_ids[req_id] = [] @@ -275,11 +278,13 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): assert len(req.block_hashes) >= num_gpu_blocks new_block_hashes = self._get_block_hashes( - req, start_idx=start_block_idx, end_idx=num_blocks) + req, start_idx=start_block_idx, end_idx=num_blocks + ) store_output = self.manager.prepare_store(new_block_hashes) if store_output is None: - logger.warning("Request %s: cannot store %s blocks", req_id, - num_new_blocks) + logger.warning( + "Request %s: cannot store %s blocks", req_id, num_new_blocks + ) continue self._next_stored_block_idx[req_id] = num_blocks @@ -292,7 +297,8 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): self.manager.touch(block_hashes) new_block_hashes = self._get_block_hashes( - req, start_idx=start_block_idx, end_idx=num_blocks) + req, start_idx=start_block_idx, end_idx=num_blocks + ) dst_spec = store_output.store_spec src_block_ids: list[int] = [] for idx, blk_hash in enumerate(new_block_hashes): @@ -317,10 +323,12 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): return reqs_to_store def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: meta = OffloadingConnectorMetadata( reqs_to_load=self._reqs_to_load, - reqs_to_store=self._get_reqs_to_store(scheduler_output)) + reqs_to_store=self._get_reqs_to_store(scheduler_output), + ) self._reqs_to_load = {} return meta @@ -373,15 +381,16 @@ def take_events(self) -> Iterable[KVCacheEvent]: """ for event in self.manager.take_events(): if event.removed: - yield BlockRemoved(block_hashes=event.block_hashes, - medium=event.medium) + yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium) else: - yield BlockStored(block_hashes=event.block_hashes, - parent_block_hash=None, - token_ids=[], - lora_id=None, - block_size=event.block_size, - medium=event.medium) + yield BlockStored( + block_hashes=event.block_hashes, + parent_block_hash=None, + token_ids=[], + lora_id=None, + block_size=event.block_size, + medium=event.medium, + ) class OffloadingConnectorWorker: @@ -408,7 +417,7 @@ def _generate_job_id(self) -> int: return job_id def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - for src_cls, dst_cls, handler in (self.spec.get_handlers(kv_caches)): + for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches): self.worker.register_handler(src_cls, dst_cls, handler) def start_load_kv(self, metadata: OffloadingConnectorMetadata): @@ -426,8 +435,7 @@ def start_store_kv(self, metadata: OffloadingConnectorMetadata): self._store_jobs[req_id].add(job_id) assert self.worker.transfer_async(job_id, transfer_spec) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """ Notifies worker-side connector ids of requests that have finished generating tokens. 
@@ -471,7 +479,8 @@ def get_finished(self, def yield_req_data( - scheduler_output) -> Iterator[tuple[str, tuple[list[int], ...], bool]]: + scheduler_output, +) -> Iterator[tuple[str, tuple[list[int], ...], bool]]: """ Yields: (req_id, new_block_id_groups, preempted) @@ -482,5 +491,8 @@ def yield_req_data( # cached requests cached_reqs = scheduler_output.scheduled_cached_reqs - yield from zip(cached_reqs.req_ids, cached_reqs.new_block_ids, - cached_reqs.resumed_from_preemption) + yield from zip( + cached_reqs.req_ids, + cached_reqs.new_block_ids, + cached_reqs.resumed_from_preemption, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 3dadfa595ef1..0e6693db5cd2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -9,9 +9,13 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( - P2pNcclEngine) + P2pNcclEngine, +) from vllm.distributed.parallel_state import get_world_group from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import MLACommonMetadata @@ -36,8 +40,9 @@ class ReqMeta: num_tokens: int @staticmethod - def make_meta(request_id: str, token_ids: list[int], block_ids: list[int], - block_size: int) -> "ReqMeta": + def make_meta( + request_id: str, token_ids: list[int], block_ids: list[int], block_size: int + ) -> "ReqMeta": block_ids_tensor = torch.tensor(block_ids) return ReqMeta( request_id=request_id, @@ -61,11 +66,11 @@ def add_request( block_size: int, ) -> None: self.requests.append( - ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)) + ReqMeta.make_meta(request_id, token_ids, block_ids, block_size) + ) class P2pNcclConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._block_size = vllm_config.cache_config.block_size @@ -74,24 +79,27 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.is_producer = self.config.is_kv_producer self.chunked_prefill: dict[str, Any] = {} - self._rank = get_world_group().rank \ - if role == KVConnectorRole.WORKER else 0 - self._local_rank = get_world_group().local_rank \ - if role == KVConnectorRole.WORKER else 0 + self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0 + self._local_rank = ( + get_world_group().local_rank if role == KVConnectorRole.WORKER else 0 + ) - self.p2p_nccl_engine = P2pNcclEngine( - local_rank=self._local_rank, - config=self.config, - hostname="", - port_offset=self._rank, - ) if role == KVConnectorRole.WORKER else None + self.p2p_nccl_engine = ( + P2pNcclEngine( + local_rank=self._local_rank, + config=self.config, + hostname="", + port_offset=self._rank, + ) + if role == KVConnectorRole.WORKER + else None + ) # ============================== # Worker-side methods # ============================== - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs: Any) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. 
@@ -143,8 +151,9 @@ def inject_kv_into_layer( Returns: None. The function modifies `layer` in-place. """ - if (isinstance(attn_metadata, MLACommonMetadata) - or layer.shape[1] == 2): # MLA or FlashInfer + if ( + isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2 + ): # MLA or FlashInfer num_block = kv_cache.shape[0] self.check_tensors_except_dim(layer, kv_cache, 0) if len(block_ids) == num_block: @@ -153,8 +162,11 @@ def inject_kv_into_layer( layer[block_ids[:num_block], ...] = kv_cache logger.warning( "🚧kv_cache does not match, block_ids:%d, " - "num_block:%d, request_id:%s", len(block_ids), - num_block, request_id) + "num_block:%d, request_id:%s", + len(block_ids), + num_block, + request_id, + ) elif layer.shape[0] == 2: # FlashAttention num_block = kv_cache.shape[1] @@ -165,12 +177,14 @@ def inject_kv_into_layer( layer[:, block_ids[:num_block], ...] = kv_cache logger.warning( "🚧kv_cache does not match, block_ids:%d, " - "num_block:%d, request_id:%s", len(block_ids), - num_block, request_id) + "num_block:%d, request_id:%s", + len(block_ids), + num_block, + request_id, + ) # Get the metadata - metadata: KVConnectorMetadata = \ - self._get_connector_metadata() + metadata: KVConnectorMetadata = self._get_connector_metadata() assert isinstance(metadata, P2pNcclConnectorMetadata) if metadata is None: @@ -187,21 +201,23 @@ def inject_kv_into_layer( # Only process layers that have kv_cache # attribute (attention layers) Skip non-attention # layers like FusedMoE - kv_cache = getattr(layer, 'kv_cache', None) + kv_cache = getattr(layer, "kv_cache", None) if kv_cache is None: continue layer = kv_cache[forward_context.virtual_engine] kv_cache = self.p2p_nccl_engine.recv_tensor( - request.request_id + "#" + layer_name, remote_address) + request.request_id + "#" + layer_name, remote_address + ) if kv_cache is None: logger.warning("🚧kv_cache is None, %s", request.request_id) continue - inject_kv_into_layer(layer, kv_cache, request.block_ids, - request.request_id) + inject_kv_into_layer( + layer, kv_cache, request.block_ids, request.request_id + ) def wait_for_layer_load(self, layer_name: str) -> None: """Blocking until the KV for a specific layer is loaded into vLLM's @@ -214,9 +230,13 @@ def wait_for_layer_load(self, layer_name: str) -> None: """ return - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", - **kwargs: Any) -> None: + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. @@ -255,8 +275,9 @@ def extract_kv_from_layer( torch.Tensor: A tensor containing the extracted KV slices. Returns None if the layout is unsupported. """ - if (isinstance(attn_metadata, MLACommonMetadata) - or layer.shape[1] == 2): # MLA or FlashInfer + if ( + isinstance(attn_metadata, MLACommonMetadata) or layer.shape[1] == 2 + ): # MLA or FlashInfer return layer[block_ids, ...] 
if layer.shape[0] == 2: # FlashAttention @@ -272,8 +293,9 @@ def extract_kv_from_layer( remote_address = ip + ":" + str(port + self._rank) kv_cache = extract_kv_from_layer(kv_layer, request.block_ids) - self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, - kv_cache, remote_address) + self.p2p_nccl_engine.send_tensor( + request_id + "#" + layer_name, kv_cache, remote_address + ) def wait_for_save(self): if self.is_producer: @@ -281,8 +303,8 @@ def wait_for_save(self): self.p2p_nccl_engine.wait_for_sent() def get_finished( - self, finished_req_ids: set[str], - **kwargs: Any) -> tuple[Optional[set[str]], Optional[set[str]]]: + self, finished_req_ids: set[str], **kwargs: Any + ) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Notifies worker-side connector ids of requests that have finished generating tokens. @@ -296,10 +318,8 @@ def get_finished( assert self.p2p_nccl_engine is not None - no_compile_layers = ( - self._vllm_config.compilation_config.static_forward_context) - return self.p2p_nccl_engine.get_finished(finished_req_ids, - no_compile_layers) + no_compile_layers = self._vllm_config.compilation_config.static_forward_context + return self.p2p_nccl_engine.get_finished(finished_req_ids, no_compile_layers) # ============================== # Scheduler-side methods @@ -326,23 +346,24 @@ def get_num_new_matched_tokens( if self.is_producer: return 0, False - num_external_tokens = (len(request.prompt_token_ids) - 1 - - num_computed_tokens) + num_external_tokens = len(request.prompt_token_ids) - 1 - num_computed_tokens if num_external_tokens < 0: num_external_tokens = 0 return num_external_tokens, False - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): """ Update KVConnector state after block allocation. 
""" if not self.is_producer and num_external_tokens > 0: self._requests_need_load[request.request_id] = ( - request, blocks.get_block_ids()[0]) + request, + blocks.get_block_ids()[0], + ) def build_connector_meta( self, @@ -361,26 +382,33 @@ def build_connector_meta( for new_req in scheduler_output.scheduled_new_reqs: if self.is_producer: - num_scheduled_tokens = ( - scheduler_output.num_scheduled_tokens)[new_req.req_id] + num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[ + new_req.req_id + ] num_tokens = num_scheduled_tokens + new_req.num_computed_tokens # the request's prompt is chunked prefill if num_tokens < len(new_req.prompt_token_ids): # 'CachedRequestData' has no attribute 'prompt_token_ids' self.chunked_prefill[new_req.req_id] = ( - new_req.block_ids[0], new_req.prompt_token_ids) + new_req.block_ids[0], + new_req.prompt_token_ids, + ) continue # the request's prompt is not chunked prefill - meta.add_request(request_id=new_req.req_id, - token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size) + meta.add_request( + request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + ) continue if new_req.req_id in self._requests_need_load: - meta.add_request(request_id=new_req.req_id, - token_ids=new_req.prompt_token_ids, - block_ids=new_req.block_ids[0], - block_size=self._block_size) + meta.add_request( + request_id=new_req.req_id, + token_ids=new_req.prompt_token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + ) self._requests_need_load.pop(new_req.req_id) cached_reqs = scheduler_output.scheduled_cached_reqs @@ -390,24 +418,24 @@ def build_connector_meta( resumed_from_preemption = cached_reqs.resumed_from_preemption[i] if self.is_producer: - num_scheduled_tokens = ( - scheduler_output.num_scheduled_tokens)[req_id] - num_tokens = (num_scheduled_tokens + num_computed_tokens) + num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id] + num_tokens = num_scheduled_tokens + num_computed_tokens assert req_id in self.chunked_prefill block_ids = new_block_ids[0] if not resumed_from_preemption: - block_ids = (self.chunked_prefill[req_id][0] + block_ids) + block_ids = self.chunked_prefill[req_id][0] + block_ids prompt_token_ids = self.chunked_prefill[req_id][1] # the request's prompt is chunked prefill again if num_tokens < len(prompt_token_ids): - self.chunked_prefill[req_id] = (block_ids, - prompt_token_ids) + self.chunked_prefill[req_id] = (block_ids, prompt_token_ids) continue # the request's prompt is all prefilled finally - meta.add_request(request_id=req_id, - token_ids=prompt_token_ids, - block_ids=block_ids, - block_size=self._block_size) + meta.add_request( + request_id=req_id, + token_ids=prompt_token_ids, + block_ids=block_ids, + block_size=self._block_size, + ) self.chunked_prefill.pop(req_id, None) continue @@ -424,10 +452,12 @@ def build_connector_meta( # of the block_ids for the request. 
block_ids = new_block_ids[0] - meta.add_request(request_id=req_id, - token_ids=token_ids, - block_ids=block_ids, - block_size=self._block_size) + meta.add_request( + request_id=req_id, + token_ids=token_ids, + block_ids=block_ids, + block_size=self._block_size, + ) self._requests_need_load.clear() return meta @@ -472,8 +502,7 @@ def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]: port = int(match.group(2)) return ip, port - raise ValueError( - f"Request id {request_id} does not contain hostname and port") + raise ValueError(f"Request id {request_id} does not contain hostname and port") @staticmethod def check_tensors_except_dim(tensor1, tensor2, dim): @@ -481,8 +510,9 @@ def check_tensors_except_dim(tensor1, tensor2, dim): shape2 = tensor2.size() if len(shape1) != len(shape2) or not all( - s1 == s2 - for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim): + s1 == s2 for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim + ): raise NotImplementedError( "Currently, only symmetric TP is supported. Asymmetric TP, PP," - "and others will be supported in future PRs.") + "and others will be supported in future PRs." + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 959bf0277a3f..cff68818ca70 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -17,9 +17,15 @@ from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.device_communicators.pynccl_wrapper import ( - NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum) + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, +) from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 - TensorMemoryPool) + TensorMemoryPool, +) from vllm.utils import current_stream, get_ip logger = logging.getLogger(__name__) @@ -31,12 +37,12 @@ def set_p2p_nccl_context(num_channels: str): original_values: dict[str, Any] = {} env_vars = [ - 'NCCL_MAX_NCHANNELS', - 'NCCL_MIN_NCHANNELS', - 'NCCL_CUMEM_ENABLE', - 'NCCL_BUFFSIZE', - 'NCCL_PROTO', # LL,LL128,SIMPLE - 'NCCL_ALGO', # RING,TREE + "NCCL_MAX_NCHANNELS", + "NCCL_MIN_NCHANNELS", + "NCCL_CUMEM_ENABLE", + "NCCL_BUFFSIZE", + "NCCL_PROTO", # LL,LL128,SIMPLE + "NCCL_ALGO", # RING,TREE ] for var in env_vars: @@ -45,9 +51,9 @@ def set_p2p_nccl_context(num_channels: str): logger.info("set_p2p_nccl_context, original_values: %s", original_values) try: - os.environ['NCCL_MAX_NCHANNELS'] = num_channels - os.environ['NCCL_MIN_NCHANNELS'] = num_channels - os.environ['NCCL_CUMEM_ENABLE'] = '1' + os.environ["NCCL_MAX_NCHANNELS"] = num_channels + os.environ["NCCL_MIN_NCHANNELS"] = num_channels + os.environ["NCCL_CUMEM_ENABLE"] = "1" yield finally: for var in env_vars: @@ -65,13 +71,14 @@ class SendQueueItem: class P2pNcclEngine: - - def __init__(self, - local_rank: int, - config: KVTransferConfig, - hostname: str = "", - port_offset: int = 0, - library_path: Optional[str] = None) -> None: + def __init__( + self, + local_rank: int, + config: KVTransferConfig, + hostname: str = "", + port_offset: int = 0, + library_path: Optional[str] = None, + ) -> None: self.config = config self.rank = port_offset self.local_rank = local_rank @@ -91,8 +98,8 @@ def __init__(self, # The `http_port` must be consistent with the port of OpenAI. 
self.http_address = ( - f"{self._hostname}:" - f"{self.config.kv_connector_extra_config['http_port']}") + f"{self._hostname}:{self.config.kv_connector_extra_config['http_port']}" + ) # If `proxy_ip` or `proxy_port` is `""`, # then the ping thread will not be enabled. @@ -118,15 +125,17 @@ def __init__(self, self.recv_stream = torch.cuda.Stream() mem_pool_size_gb = float( - self.config.get_from_extra_config("mem_pool_size_gb", - DEFAULT_MEM_POOL_SIZE_GB)) - self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb * - 1024**3)) # GB + self.config.get_from_extra_config( + "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB + ) + ) + self.pool = TensorMemoryPool( + max_block_size=int(mem_pool_size_gb * 1024**3) + ) # GB # The sending type includes tree mutually exclusive options: # PUT, GET, PUT_ASYNC. - self.send_type = self.config.get_from_extra_config( - "send_type", "PUT_ASYNC") + self.send_type = self.config.get_from_extra_config("send_type", "PUT_ASYNC") if self.send_type == "GET": # tensor_id: torch.Tensor self.send_store: dict[str, torch.Tensor] = {} @@ -135,8 +144,9 @@ def __init__(self, # tensor_id: torch.Tensor self.send_queue: deque[SendQueueItem] = deque() if self.send_type == "PUT_ASYNC": - self._send_thread = threading.Thread(target=self.send_async, - daemon=True) + self._send_thread = threading.Thread( + target=self.send_async, daemon=True + ) self._send_thread.start() # tensor_id: torch.Tensor/(addr, dtype, shape) @@ -150,10 +160,12 @@ def __init__(self, self.buffer_size_threshold = float(self.config.kv_buffer_size) self.nccl_num_channels = self.config.get_from_extra_config( - "nccl_num_channels", "8") + "nccl_num_channels", "8" + ) self._listener_thread = threading.Thread( - target=self.listen_for_requests, daemon=True) + target=self.listen_for_requests, daemon=True + ) self._listener_thread.start() self._ping_thread = None @@ -164,9 +176,16 @@ def __init__(self, logger.info( "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, " "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_" - "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank, - self.http_address, self.zmq_address, self.proxy_address, - self.send_type, self.buffer_size_threshold, self.nccl_num_channels) + "threshold:%.2f, nccl_num_channels:%s", + self.rank, + self.local_rank, + self.http_address, + self.zmq_address, + self.proxy_address, + self.send_type, + self.buffer_size_threshold, + self.nccl_num_channels, + ) def create_connect(self, remote_address: typing.Optional[str] = None): assert remote_address is not None @@ -176,8 +195,11 @@ def create_connect(self, remote_address: typing.Optional[str] = None): sock.connect(f"tcp://{remote_address}") self.socks[remote_address] = sock if remote_address in self.comms: - logger.info("👋comm exists, remote_address:%s, comms:%s", - remote_address, self.comms) + logger.info( + "👋comm exists, remote_address:%s, comms:%s", + remote_address, + self.comms, + ) return sock, self.comms[remote_address] unique_id = self.nccl.ncclGetUniqueId() @@ -187,11 +209,14 @@ def create_connect(self, remote_address: typing.Optional[str] = None): with torch.cuda.device(self.device): rank = 0 with set_p2p_nccl_context(self.nccl_num_channels): - comm: ncclComm_t = self.nccl.ncclCommInitRank( - 2, unique_id, rank) + comm: ncclComm_t = self.nccl.ncclCommInitRank(2, unique_id, rank) self.comms[remote_address] = (comm, rank) - logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s", - self.zmq_address, remote_address, rank) + logger.info( + "🤝ncclCommInitRank 
Success, %s👉%s, MyRank:%s", + self.zmq_address, + remote_address, + rank, + ) return self.socks[remote_address], self.comms[remote_address] @@ -207,9 +232,9 @@ def send_tensor( self.recv_store_cv.notify() return True - item = SendQueueItem(tensor_id=tensor_id, - remote_address=remote_address, - tensor=tensor) + item = SendQueueItem( + tensor_id=tensor_id, remote_address=remote_address, tensor=tensor + ) if self.send_type == "PUT": return self.send_sync(item) @@ -227,31 +252,45 @@ def send_tensor( logger.warning( "❗[GET]tensor_id:%s, tensor_size:%d, is greater than" "buffer size threshold :%d, skip send to %s, rank:%d", - tensor_id, tensor_size, self.buffer_size_threshold, - remote_address, self.rank) + tensor_id, + tensor_size, + self.buffer_size_threshold, + remote_address, + self.rank, + ) return False - while (self.buffer_size + tensor_size - > self.buffer_size_threshold): + while self.buffer_size + tensor_size > self.buffer_size_threshold: assert len(self.send_store) > 0 oldest_tensor_id = next(iter(self.send_store)) oldest_tensor = self.send_store.pop(oldest_tensor_id) - oldest_tensor_size = oldest_tensor.element_size( - ) * oldest_tensor.numel() + oldest_tensor_size = ( + oldest_tensor.element_size() * oldest_tensor.numel() + ) self.buffer_size -= oldest_tensor_size logger.debug( "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," " buffer_size:%d, oldest_tensor_size:%d, rank:%d", - remote_address, tensor_id, tensor_size, self.buffer_size, - oldest_tensor_size, self.rank) + remote_address, + tensor_id, + tensor_size, + self.buffer_size, + oldest_tensor_size, + self.rank, + ) self.send_store[tensor_id] = tensor self.buffer_size += tensor_size logger.debug( "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " - "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address, - tensor_id, tensor_size, tensor.shape, self.rank, + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", + remote_address, + tensor_id, + tensor_size, + tensor.shape, + self.rank, self.buffer_size, - self.buffer_size / self.buffer_size_threshold * 100) + self.buffer_size / self.buffer_size_threshold * 100, + ) return True def recv_tensor( @@ -269,17 +308,18 @@ def recv_tensor( if tensor is not None: if isinstance(tensor, tuple): addr, dtype, shape = tensor - tensor = self.pool.load_tensor(addr, dtype, shape, - self.device) + tensor = self.pool.load_tensor(addr, dtype, shape, self.device) else: - self.buffer_size -= (tensor.element_size() * - tensor.numel()) + self.buffer_size -= tensor.element_size() * tensor.numel() else: duration = time.time() - start_time logger.warning( - "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, " - "rank:%d", remote_address, tensor_id, duration * 1000, - self.rank) + "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, rank:%d", + remote_address, + tensor_id, + duration * 1000, + self.rank, + ) return tensor # GET @@ -298,14 +338,18 @@ def recv_tensor( message = sock.recv() data = msgpack.loads(message) if data["ret"] != 0: - logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d", - remote_address, tensor_id, data["ret"]) + logger.warning( + "🔴[GET]Recv From %s, tensor_id: %s, ret: %d", + remote_address, + tensor_id, + data["ret"], + ) return None with torch.cuda.stream(self.recv_stream): - tensor = torch.empty(data["shape"], - dtype=getattr(torch, data["dtype"]), - device=self.device) + tensor = torch.empty( + data["shape"], dtype=getattr(torch, data["dtype"]), device=self.device + ) self.recv(comm, tensor, rank ^ 1, self.recv_stream) @@ -320,38 +364,45 @@ def listen_for_requests(self): 
remote_address, message = self.router_socket.recv_multipart() data = msgpack.loads(message) if data["cmd"] == "NEW": - unique_id = self.nccl.unique_id_from_bytes( - bytes(data["unique_id"])) + unique_id = self.nccl.unique_id_from_bytes(bytes(data["unique_id"])) with torch.cuda.device(self.device): rank = 1 with set_p2p_nccl_context(self.nccl_num_channels): comm: ncclComm_t = self.nccl.ncclCommInitRank( - 2, unique_id, rank) + 2, unique_id, rank + ) self.comms[remote_address.decode()] = (comm, rank) - logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", - self.zmq_address, remote_address.decode(), - rank) + logger.info( + "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, + remote_address.decode(), + rank, + ) elif data["cmd"] == "PUT": tensor_id = data["tensor_id"] try: with torch.cuda.stream(self.recv_stream): - tensor = torch.empty(data["shape"], - dtype=getattr( - torch, data["dtype"]), - device=self.device) + tensor = torch.empty( + data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device, + ) self.router_socket.send_multipart([remote_address, b"0"]) comm, rank = self.comms[remote_address.decode()] self.recv(comm, tensor, rank ^ 1, self.recv_stream) tensor_size = tensor.element_size() * tensor.numel() - if (self.buffer_size + tensor_size - > self.buffer_size_threshold): + if self.buffer_size + tensor_size > self.buffer_size_threshold: # Store Tensor in memory pool addr = self.pool.store_tensor(tensor) tensor = (addr, tensor.dtype, tensor.shape) logger.warning( "🔴[PUT]Recv Tensor, Out Of Threshold, " - "%s👈%s, data:%s, addr:%d", self.zmq_address, - remote_address.decode(), data, addr) + "%s👈%s, data:%s, addr:%d", + self.zmq_address, + remote_address.decode(), + data, + addr, + ) else: self.buffer_size += tensor_size @@ -359,9 +410,11 @@ def listen_for_requests(self): self.router_socket.send_multipart([remote_address, b"1"]) tensor = None logger.warning( - "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " - "data:%s", self.zmq_address, remote_address.decode(), - data) + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, data:%s", + self.zmq_address, + remote_address.decode(), + data, + ) with self.recv_store_cv: self.recv_store[tensor_id] = tensor @@ -376,7 +429,7 @@ def listen_for_requests(self): data = { "ret": 0, "shape": tensor.shape, - "dtype": str(tensor.dtype).replace("torch.", "") + "dtype": str(tensor.dtype).replace("torch.", ""), } # LRU self.send_store[tensor_id] = tensor @@ -384,26 +437,26 @@ def listen_for_requests(self): else: data = {"ret": 1} - self.router_socket.send_multipart( - [remote_address, msgpack.dumps(data)]) + self.router_socket.send_multipart([remote_address, msgpack.dumps(data)]) if data["ret"] == 0: comm, rank = self.comms[remote_address.decode()] - self.send(comm, tensor.to(self.device), rank ^ 1, - self.send_stream) + self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) else: logger.warning( "🚧Unexpected, Received message from %s, data:%s", - remote_address, data) + remote_address, + data, + ) def have_sent_tensor_id(self, tensor_id: str): - request_id = tensor_id.split('#')[0] + request_id = tensor_id.split("#")[0] if request_id not in self.send_request_id_to_tensor_ids: self.send_request_id_to_tensor_ids[request_id] = set() self.send_request_id_to_tensor_ids[request_id].add(tensor_id) def have_received_tensor_id(self, tensor_id: str): - request_id = tensor_id.split('#')[0] + request_id = tensor_id.split("#")[0] if request_id not in self.recv_request_id_to_tensor_ids: self.recv_request_id_to_tensor_ids[request_id] = 
set() self.recv_request_id_to_tensor_ids[request_id].add(tensor_id) @@ -427,7 +480,10 @@ def wait_for_sent(self): duration = time.time() - start_time logger.debug( "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" - " to be empty, rank:%d", duration * 1000, self.rank) + " to be empty, rank:%d", + duration * 1000, + self.rank, + ) def send_sync(self, item: SendQueueItem) -> bool: if item.remote_address is None: @@ -443,7 +499,7 @@ def send_sync(self, item: SendQueueItem) -> bool: "cmd": "PUT", "tensor_id": item.tensor_id, "shape": tensor.shape, - "dtype": str(tensor.dtype).replace("torch.", "") + "dtype": str(tensor.dtype).replace("torch.", ""), } sock.send(msgpack.dumps(data)) @@ -452,10 +508,14 @@ def send_sync(self, item: SendQueueItem) -> bool: logger.error( "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", - self.zmq_address, item.remote_address, rank, data, + self.zmq_address, + item.remote_address, + rank, + data, tensor.shape, tensor.element_size() * tensor.numel() / 1024**3, - response.decode()) + response.decode(), + ) return False self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) @@ -466,7 +526,7 @@ def send_sync(self, item: SendQueueItem) -> bool: return True def get_finished( - self, finished_req_ids: set[str], no_compile_layers + self, finished_req_ids: set[str], no_compile_layers ) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Notifies worker-side connector ids of requests that have @@ -486,10 +546,8 @@ def get_finished( if tensor_id in self.recv_store: with self.recv_store_cv: tensor = self.recv_store.pop(tensor_id, None) - self.send_request_id_to_tensor_ids.pop( - request_id, None) - self.recv_request_id_to_tensor_ids.pop( - request_id, None) + self.send_request_id_to_tensor_ids.pop(request_id, None) + self.recv_request_id_to_tensor_ids.pop(request_id, None) if isinstance(tensor, tuple): addr, _, _ = tensor self.pool.free(addr) @@ -510,7 +568,7 @@ def ping(self): data = { "type": "P" if self.config.is_kv_producer else "D", "http_address": self.http_address, - "zmq_address": self.zmq_address + "zmq_address": self.zmq_address, } while True: sock.send(msgpack.dumps(data)) @@ -519,27 +577,39 @@ def ping(self): def send(self, comm, tensor: torch.Tensor, dst: int, stream=None): assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() with torch.cuda.stream(stream): - self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), dst, - comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + comm, + cudaStream_t(stream.cuda_stream), + ) stream.synchronize() def recv(self, comm, tensor: torch.Tensor, src: int, stream=None): assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {tensor.device}" + ) if stream is None: stream = current_stream() with torch.cuda.stream(stream): - self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), src, - comm, cudaStream_t(stream.cuda_stream)) + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + 
ncclDataTypeEnum.from_torch(tensor.dtype), + src, + comm, + cudaStream_t(stream.cuda_stream), + ) stream.synchronize() def close(self) -> None: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py index 26070488bad8..899f1eae86d2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py @@ -67,8 +67,7 @@ def __init__(self, max_block_size: int, min_block_size: int = 512): if max_block_size <= 0 or min_block_size <= 0: raise ValueError("Block sizes must be positive") if max_block_size < min_block_size: - raise ValueError( - "Max block size must be greater than min block size") + raise ValueError("Max block size must be greater than min block size") self.max_block_size = self._round_to_power_of_two(max_block_size) self.min_block_size = self._round_to_power_of_two(min_block_size) @@ -91,17 +90,18 @@ def _initialize_free_lists(self): size //= 2 def _allocate_pinned_memory(self): - self.base_tensor = torch.empty(self.max_block_size // 4, - dtype=torch.float32, - pin_memory=True) + self.base_tensor = torch.empty( + self.max_block_size // 4, dtype=torch.float32, pin_memory=True + ) self.base_address = self.base_tensor.data_ptr() - initial_block = MemoryBlock(size=self.max_block_size, - addr=self.base_address) - self.free_lists[self.max_block_size][ - initial_block.addr] = initial_block + initial_block = MemoryBlock(size=self.max_block_size, addr=self.base_address) + self.free_lists[self.max_block_size][initial_block.addr] = initial_block - logger.debug("TensorMemoryPool, base_address:%d, max_block_size:%d", - self.base_address, self.max_block_size) + logger.debug( + "TensorMemoryPool, base_address:%d, max_block_size:%d", + self.base_address, + self.max_block_size, + ) def allocate(self, size: int) -> int: """Allocates a memory block of at least the requested size. 
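The TensorMemoryPool hunks above and below implement a buddy allocator over a single pinned host buffer: requested sizes are rounded to a power of two, oversized free blocks are split, and freed blocks are merged with their equal-sized buddy. A small sketch of the address arithmetic involved (these helper names are hypothetical; the pool's own methods are _round_to_power_of_two, _split_block and _merge_buddies):

    def round_up_pow2(size: int) -> int:
        # Smallest power of two >= size (one common way to do the rounding).
        return 1 << (size - 1).bit_length()

    def buddy_address(addr: int, base_address: int, size: int) -> int:
        # A block and its buddy are the two halves of a 2*size region aligned
        # to 2*size relative to the pool's base address; this is the same test
        # _merge_buddies uses to decide whether the buddy sits above or below
        # the block being freed.
        offset = size if (addr - base_address) % (2 * size) == 0 else -size
        return addr + offset

    assert round_up_pow2(600) == 1024
    # With the pool base at address 0 and 1 KiB blocks, 0 and 1024 are buddies.
    assert buddy_address(0, 0, 1024) == 1024
    assert buddy_address(1024, 0, 1024) == 0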
@@ -118,8 +118,7 @@ def allocate(self, size: int) -> int: if size <= 0: raise ValueError("Allocation size must be positive") - required_size = self._round_to_power_of_two( - max(size, self.min_block_size)) + required_size = self._round_to_power_of_two(max(size, self.min_block_size)) if required_size > self.max_block_size: raise ValueError("Requested size exceeds maximum block size") @@ -135,8 +134,7 @@ def allocate(self, size: int) -> int: raise ValueError("Insufficient memory") def _split_block(self, block: MemoryBlock, required_size: int): - while (block.size > required_size - and block.size // 2 >= self.min_block_size): + while block.size > required_size and block.size // 2 >= self.min_block_size: buddy_size = block.size // 2 buddy_addr = block.addr + buddy_size @@ -165,8 +163,11 @@ def _merge_buddies(self, block: MemoryBlock): depth = 0 while depth < MAX_MERGE_DEPTH: - buddy_offset = block.size if (block.addr - self.base_address) % ( - 2 * block.size) == 0 else -block.size + buddy_offset = ( + block.size + if (block.addr - self.base_address) % (2 * block.size) == 0 + else -block.size + ) buddy_addr = block.addr + buddy_offset buddy = self.free_lists[block.size].get(buddy_addr) if buddy: @@ -202,14 +203,14 @@ def store_tensor(self, tensor: torch.Tensor) -> int: self.free(addr) raise ValueError( f"Allocated block size {block.size} is smaller than " - f"required size {size}") + f"required size {size}" + ) try: buffer = (ctypes.c_byte * block.size).from_address(block.addr) - cpu_tensor = torch.frombuffer(buffer, - dtype=tensor.dtype, - count=tensor.numel()).reshape( - tensor.shape) + cpu_tensor = torch.frombuffer( + buffer, dtype=tensor.dtype, count=tensor.numel() + ).reshape(tensor.shape) except ValueError as err: self.free(addr) raise ValueError(f"Failed to create tensor view: {err}") from err @@ -218,9 +219,13 @@ def store_tensor(self, tensor: torch.Tensor) -> int: return addr - def load_tensor(self, addr: int, dtype: torch.dtype, shape: tuple[int, - ...], - device: torch.device) -> torch.Tensor: + def load_tensor( + self, + addr: int, + dtype: torch.dtype, + shape: tuple[int, ...], + device: torch.device, + ) -> torch.Tensor: """Loads a tensor from pinned host memory to the specified device. 
Args: @@ -247,8 +252,9 @@ def load_tensor(self, addr: int, dtype: torch.dtype, shape: tuple[int, raise ValueError("Requested tensor size exceeds block size") buffer = (ctypes.c_byte * block.size).from_address(block.addr) - cpu_tensor = torch.frombuffer(buffer, dtype=dtype, - count=num_elements).reshape(shape) + cpu_tensor = torch.frombuffer(buffer, dtype=dtype, count=num_elements).reshape( + shape + ) cuda_tensor = torch.empty(shape, dtype=dtype, device=device) @@ -260,7 +266,7 @@ def cleanup(self): """Cleans up all memory resources and resets the pool state.""" self.free_lists.clear() self.allocated_blocks.clear() - if hasattr(self, 'base_tensor'): + if hasattr(self, "base_tensor"): del self.base_tensor def __del__(self): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index c9949d81465c..a1bab4e06145 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -10,7 +10,10 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.core.sched.output import SchedulerOutput @@ -35,15 +38,22 @@ class ReqMeta: mm_hashes: list[str] @staticmethod - def make_meta(token_ids: list[int], block_ids: list[int], block_size: int, - is_store: bool, mm_hashes: list[str]) -> "ReqMeta": + def make_meta( + token_ids: list[int], + block_ids: list[int], + block_size: int, + is_store: bool, + mm_hashes: list[str], + ) -> "ReqMeta": valid_num_tokens = align_to_block_size(len(token_ids), block_size) token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens] block_ids_tensor = torch.tensor(block_ids) num_blocks = block_ids_tensor.shape[0] block_offsets = torch.arange(0, block_size) - slot_mapping = block_offsets.reshape((1, block_size)) + \ - block_ids_tensor.reshape((num_blocks, 1)) * block_size + slot_mapping = ( + block_offsets.reshape((1, block_size)) + + block_ids_tensor.reshape((num_blocks, 1)) * block_size + ) slot_mapping = slot_mapping.flatten()[:valid_num_tokens] return ReqMeta( token_ids=token_ids_tensor, @@ -66,8 +76,8 @@ def add_request( mm_hashes: list[str], ) -> None: self.requests.append( - ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, - mm_hashes)) + ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes) + ) class SharedStorageConnector(KVConnectorBase_V1): @@ -82,13 +92,13 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._requests_need_load: dict[str, Request] = {} transfer_config = vllm_config.kv_transfer_config self._storage_path = transfer_config.get_from_extra_config( - "shared_storage_path", "/tmp") + "shared_storage_path", "/tmp" + ) logger.info(vllm_config.kv_transfer_config) logger.info("Shared storage path is %s", self._storage_path) - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs: Any) -> None: - """Start loading the KV cache from the connector buffer to vLLM's + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: + """Start loading the KV cache from the connector buffer to vLLM's paged KV buffer. 
Args: @@ -96,7 +106,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs: additional arguments for the load operation Note: - The number of elements in kv_caches and layer_names should be + The number of elements in kv_caches and layer_names should be the same. """ attn_metadata = forward_context.attn_metadata @@ -109,13 +119,13 @@ def inject_kv_into_layer( """Inject the KV cache into the layer. Args: - dst_kv_cache_layer (torch.Tensor): the destination KV cache - layer. In shape [2, num_pages, page_size, xxx] if not + dst_kv_cache_layer (torch.Tensor): the destination KV cache + layer. In shape [2, num_pages, page_size, xxx] if not using MLA, [num_pages, page_size, xxx] otherwise. src_kv_cache (torch.Tensor): the source KV cache. In shape - [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] + [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx] otherwise. - slot_mapping (torch.Tensor): the slot mapping. In shape + slot_mapping (torch.Tensor): the slot mapping. In shape [num_tokens]. """ dst_kv_cache_layer_shape = dst_kv_cache_layer.shape @@ -123,14 +133,16 @@ def inject_kv_into_layer( num_pages = dst_kv_cache_layer_shape[0] page_size = dst_kv_cache_layer_shape[1] dst_kv_cache_layer = dst_kv_cache_layer.reshape( - num_pages * page_size, -1) + num_pages * page_size, -1 + ) dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) else: num_pages = dst_kv_cache_layer_shape[1] page_size = dst_kv_cache_layer_shape[2] dst_kv_cache_layer = dst_kv_cache_layer.reshape( - 2, num_pages * page_size, -1) + 2, num_pages * page_size, -1 + ) dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape) @@ -146,40 +158,39 @@ def inject_kv_into_layer( attn_metadata = forward_context.attn_metadata if attn_metadata is None: - logger.warning( - "In connector.start_load_kv, but the attn_metadata is None") + logger.warning("In connector.start_load_kv, but the attn_metadata is None") return # Load the KV for each request each layer for request in metadata.requests: if request.is_store: continue - logger.info("Inject KV cache of %d tokens to the paged memory", - len(request.slot_mapping)) + logger.info( + "Inject KV cache of %d tokens to the paged memory", + len(request.slot_mapping), + ) for layer_name in forward_context.no_compile_layers: layer = forward_context.no_compile_layers[layer_name] # Only process layers that have kv_cache # attribute (attention layers) Skip non-attention # layers like FusedMoE/MLP etc. - kv_cache_attr = getattr(layer, 'kv_cache', None) + kv_cache_attr = getattr(layer, "kv_cache", None) if kv_cache_attr is None: continue - kv_cache_layer = kv_cache_attr[ \ - forward_context.virtual_engine] + kv_cache_layer = kv_cache_attr[forward_context.virtual_engine] filename = self._generate_filename_debug( - layer_name, request.token_ids, request.mm_hashes) - kv_cache = safetensors.torch.load_file( - filename)["kv_cache"].cuda() - inject_kv_into_layer(kv_cache_layer, kv_cache, - request.slot_mapping) + layer_name, request.token_ids, request.mm_hashes + ) + kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda() + inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping) def wait_for_layer_load(self, layer_name: str) -> None: """Blocking until the KV for a specific layer is loaded into vLLM's - paged buffer. - + paged buffer. + This interface will be useful for layer-by-layer pipelining. 
Args: @@ -187,15 +198,19 @@ def wait_for_layer_load(self, layer_name: str) -> None: """ return - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", - **kwargs: Any) -> None: - """Start saving the KV cache of the layer from vLLM's paged buffer + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs: Any, + ) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer to the connector. Args: layer_name (str): the name of the layer. - kv_layer (torch.Tensor): the paged KV buffer of the current + kv_layer (torch.Tensor): the paged KV buffer of the current layer in vLLM. attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. @@ -212,20 +227,18 @@ def extract_kv_from_layer( """ if isinstance(attn_metadata, MLACommonMetadata): num_pages, page_size = layer.shape[0], layer.shape[1] - return layer.reshape(num_pages * page_size, -1)[slot_mapping, - ...] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] num_pages, page_size = layer.shape[1], layer.shape[2] - return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, - ...] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] connector_metadata = self._get_connector_metadata() assert isinstance(connector_metadata, SharedStorageConnectorMetadata) for request in connector_metadata.requests: if request.is_store: filename = self._generate_filename_debug( - layer_name, request.token_ids, request.mm_hashes) - kv_cache = extract_kv_from_layer(kv_layer, - request.slot_mapping) + layer_name, request.token_ids, request.mm_hashes + ) + kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) tensors = {"kv_cache": kv_cache.detach().cpu()} safetensors.torch.save_file(tensors, filename) @@ -240,14 +253,14 @@ def get_num_new_matched_tokens( """ Get number of new tokens that can be loaded from the external KV cache beyond the num_computed_tokens. - + Args: request (Request): the request object. num_computed_tokens (int): the number of locally computed tokens for this request Returns: - the number of tokens that can be loaded from the + the number of tokens that can be loaded from the external KV cache beyond what is already computed. """ # NOTE: in this debug implementation, we assume that the prompt is @@ -265,13 +278,14 @@ def get_num_new_matched_tokens( # Now, first num_tokens_to_check tokens are hit, we need to prepare # the metadata for the worker connector to correctly load the KV num_tokens_to_check = align_to_block_size( - len(request.prompt_token_ids) - 1, self._block_size) + len(request.prompt_token_ids) - 1, self._block_size + ) return num_tokens_to_check - num_computed_tokens, False - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): """ Update KVConnector state after block allocation. 
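ReqMeta.make_meta earlier in this file builds a flat slot mapping (token position to slot index in the paged KV cache) from the request's block IDs, then truncates it to the token count returned by align_to_block_size. A worked toy example of that arithmetic (the concrete numbers are made up for illustration):

    import torch

    block_size = 4
    block_ids = torch.tensor([7, 3])             # pages assigned to the request
    block_offsets = torch.arange(0, block_size)  # tensor([0, 1, 2, 3])

    slot_mapping = (
        block_offsets.reshape((1, block_size))
        + block_ids.reshape((-1, 1)) * block_size
    ).flatten()
    # tensor([28, 29, 30, 31, 12, 13, 14, 15]); token i is written to slot_mapping[i]

    # align_to_block_size from this file: (num_tokens - 1) // block_size * block_size
    num_tokens = 10
    valid_num_tokens = (num_tokens - 1) // block_size * block_size  # 8
    slot_mapping = slot_mapping[:valid_num_tokens]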
@@ -303,7 +317,8 @@ def build_connector_meta( block_ids=new_req.block_ids[0], block_size=self._block_size, is_store=False, - mm_hashes=[f.identifier for f in new_req.mm_features]) + mm_hashes=[f.identifier for f in new_req.mm_features], + ) total_need_load += 1 else: # NOTE: here, we set the store and load being exclusive, @@ -316,7 +331,8 @@ def build_connector_meta( block_ids=new_req.block_ids[0], block_size=self._block_size, is_store=True, - mm_hashes=[f.identifier for f in new_req.mm_features]) + mm_hashes=[f.identifier for f in new_req.mm_features], + ) cached_reqs = scheduler_output.scheduled_cached_reqs for i, req_id in enumerate(cached_reqs.req_ids): @@ -346,7 +362,8 @@ def build_connector_meta( block_ids=block_ids, block_size=self._block_size, is_store=False, - mm_hashes=[f.identifier for f in request.mm_features]) + mm_hashes=[f.identifier for f in request.mm_features], + ) total_need_load += 1 assert total_need_load == len(self._requests_need_load) @@ -361,14 +378,15 @@ def _found_match_for_request( self, request: "Request", ) -> bool: - """Check if the cache is hit for the request. - """ + """Check if the cache is hit for the request.""" num_tokens_to_check = align_to_block_size( - len(request.prompt_token_ids) - 1, self._block_size) + len(request.prompt_token_ids) - 1, self._block_size + ) foldername = self._generate_foldername_debug( torch.tensor(request.prompt_token_ids)[:num_tokens_to_check], [f.identifier for f in request.mm_features], - create_folder=False) + create_folder=False, + ) return os.path.exists(foldername) def _generate_foldername_debug( @@ -377,7 +395,7 @@ def _generate_foldername_debug( mm_hashes: list[str], create_folder=False, ) -> str: - """Generate a folder name based on the hash of the bytes of the input + """Generate a folder name based on the hash of the bytes of the input ids. """ token_bytes = token_ids.numpy().tobytes() @@ -385,9 +403,8 @@ def _generate_foldername_debug( # to create a canonical key. if mm_hashes: mm_str = "-".join(mm_hashes) - token_bytes += mm_str.encode('utf-8') - input_ids_hash = hashlib.md5(token_bytes, - usedforsecurity=False).hexdigest() + token_bytes += mm_str.encode("utf-8") + input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest() foldername = os.path.join(self._storage_path, input_ids_hash) if create_folder: @@ -400,16 +417,15 @@ def _generate_filename_debug( token_ids: torch.Tensor, mm_hashes: list[str], ) -> str: - """Generate a file name based on the layer name and the hash + """Generate a file name based on the layer name and the hash of the bytes of the input ids. """ - foldername = self._generate_foldername_debug(token_ids, - mm_hashes=mm_hashes, - create_folder=True) + foldername = self._generate_foldername_debug( + token_ids, mm_hashes=mm_hashes, create_folder=True + ) return os.path.join(foldername, f"{layer_name}.safetensors") def align_to_block_size(num_tokens: int, block_size) -> int: - """Align the number of tokens to the block size. - """ + """Align the number of tokens to the block size.""" return (num_tokens - 1) // block_size * block_size diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py index eef14269f196..08b683bfe23f 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -42,39 +42,44 @@ class KVLookupBufferBase(KVCacheBufferBase): Abstract base class for a KVCache lookup buffer. 
This class provides an abstraction for a key-value (KV) cache lookup buffer. - + The key of the lookup buffer: - input_tokens: token IDs of the request - roi: a binary mask on top of input_tokens. - - Purpose of roi: Since KV cache may only be available for a subset of - tokens in the input (for example, when vLLM is connected to an external - KV cache service), roi specifies the subset of tokens that the KV cache + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache is associated with. - - NOTE: roi can be further extended to describe which part of KV the - current process is holding (each process may only hold a part of KV + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV due to TP and PP). This is not implemented for now. - + The value of the lookup buffer: - key: the key tensor in the KV cache - value: the value tensor in the KV cache - - hidden: the final hidden state generated by model forwarding. This allows + - hidden: the final hidden state generated by model forwarding. This allows vLLM to bypass further model forwarding by transmitting the hidden state. """ @abstractmethod - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: + def insert( + self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor, + ) -> None: """Insert into the lookup buffer. - + The functionality is similar to the following python statement ``` buffer[input_tokens, roi] = [key, value, hidden] ``` - + FIXME: in the future, we should only have two arguments, key and value, where key is a tensor dict and value is a tensor dict. - + FIXME: we should transmit both sampler outputs and the hidden states. Args: @@ -82,8 +87,8 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, roi (torch.Tensor): A binary mask on top of the input tokens key (torch.Tensor): The key tensor in the KV cache. value (torch.Tensor): The value tensor in the KV cache. - hidden (torch.Tensor): The final hidden state tensor generated - during model forwarding to bypass model + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model forwarding. Raises: @@ -93,16 +98,16 @@ def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, @abstractmethod def drop_select( - self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: + self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor] + ) -> list[Optional[torch.Tensor]]: """Select and *drop* KV cache entries from the lookup buffer. - + The functionality is similar to the following python statements ``` ret = buffer.pop(input_tokens, roi) return ret ``` - + If `input_tokens` and `roi` is `None`, it means selecting any of the KV caches in the buffer, return, and remove it from the buffer, useful when offloading KV cache to KV cache storage service. 
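The insert/drop_select contract documented above behaves like a mapping keyed on (input_tokens, roi) whose entries are removed when read. A toy, in-memory sketch of that contract (this is not vLLM's buffer implementation, which transmits tensors over pipes and does prefix matching on the tokens; it only illustrates the key/value semantics):

    import torch

    class ToyLookupBuffer:
        def __init__(self) -> None:
            self._entries: dict[tuple, list[torch.Tensor]] = {}

        @staticmethod
        def _key(input_tokens: torch.Tensor, roi: torch.Tensor) -> tuple:
            # roi is a boolean mask over input_tokens marking the positions
            # whose KV is actually present in the buffer.
            return (tuple(input_tokens.tolist()), tuple(roi.tolist()))

        def insert(self, input_tokens, roi, key, value, hidden) -> None:
            # buffer[input_tokens, roi] = [key, value, hidden]
            self._entries[self._key(input_tokens, roi)] = [key, value, hidden]

        def drop_select(self, input_tokens, roi):
            # ret = buffer.pop(input_tokens, roi); payload is None on a miss.
            entry = self._entries.pop(self._key(input_tokens, roi), None)
            if entry is None:
                return [input_tokens, roi, None, None, None]
            return [input_tokens, roi, *entry]

In this toy, a producer would insert with roi = torch.ones_like(input_tokens, dtype=torch.bool), and a later drop_select with the same tokens and roi returns the stored key/value/hidden and frees the entry.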
diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py index 4381aad1e995..44fc6d8ac5ad 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py @@ -6,6 +6,7 @@ into a remote KVStore-based lookup buffer and getting existing KV caches from this remote lookup buffer. """ + import json import os from dataclasses import dataclass @@ -16,8 +17,7 @@ from safetensors.torch import save as safetensors_save from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( - KVStoreBufferBase) +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVStoreBufferBase from vllm.logger import init_logger DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB @@ -37,65 +37,69 @@ class MooncakeStoreConfig: master_server_address: str @staticmethod - def from_file(file_path: str) -> 'MooncakeStoreConfig': + def from_file(file_path: str) -> "MooncakeStoreConfig": """Load the config from a JSON file.""" with open(file_path) as fin: config = json.load(fin) return MooncakeStoreConfig( local_hostname=config.get("local_hostname"), metadata_server=config.get("metadata_server"), - global_segment_size=config.get("global_segment_size", - DEFAULT_GLOBAL_SEGMENT_SIZE), - local_buffer_size=config.get("local_buffer_size", - DEFAULT_LOCAL_BUFFER_SIZE), + global_segment_size=config.get( + "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE + ), + local_buffer_size=config.get( + "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE + ), protocol=config.get("protocol", "tcp"), device_name=config.get("device_name", ""), master_server_address=config.get("master_server_address"), ) @staticmethod - def load_from_env() -> 'MooncakeStoreConfig': + def load_from_env() -> "MooncakeStoreConfig": """Load config from a file specified in the environment variable.""" - config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") if config_file_path is None: raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." + ) return MooncakeStoreConfig.from_file(config_file_path) class MooncakeStore(KVStoreBufferBase): - def __init__( self, config: VllmConfig, ): - try: from mooncake.store import MooncakeDistributedStore except ImportError as e: raise ImportError( "Please install mooncake by following the instructions at " "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector.") from e + "to run vLLM with MooncakeConnector." 
+ ) from e try: self.store = MooncakeDistributedStore() self.config = MooncakeStoreConfig.load_from_env() logger.info("Mooncake Configuration loaded successfully.") - self.store.setup(self.config.local_hostname, - self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, self.config.device_name, - self.config.master_server_address) + self.store.setup( + self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + ) except ValueError as e: logger.error("Configuration loading failed: %s", e) raise except Exception as exc: - logger.error( - "An error occurred while loading the configuration: %s", exc) + logger.error("An error occurred while loading the configuration: %s", exc) raise def close(self): @@ -126,12 +130,9 @@ def _put_impl( value: torch.Tensor, ) -> None: """Put KVCache to Mooncake Store""" - device_id = value.device.index if value.device.type == 'cuda' else -1 + device_id = value.device.index if value.device.type == "cuda" else -1 device_tensor = torch.tensor(device_id, dtype=torch.int32) - value_bytes = safetensors_save({ - "tensor": value, - "device_id": device_tensor - }) + value_bytes = safetensors_save({"tensor": value, "device_id": device_tensor}) try: self.store.put(key, value_bytes) except TypeError as err: @@ -154,8 +155,11 @@ def _get_impl( tensor = loaded_tensors["tensor"] device_id_tensor = loaded_tensors["device_id"] device_id = int(device_id_tensor.item()) - device = torch.device( - 'cuda', device_id) if device_id >= 0 else torch.device('cpu') + device = ( + torch.device("cuda", device_id) + if device_id >= 0 + else torch.device("cpu") + ) return tensor.to(device) return None diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index a0ff7c320f61..cd58ec2e7639 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -1,23 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ - Implements a distributed key-value (KV) cache transfer mechanism. - - Key Features: - - Distributed KV cache transmission using PyNccl pipes. - - Non-blocking `insert`, blocking `drop_select`. - - Use CPU signal pipe to avoid racing condition - - Handles buffer size constraints and provide backpressure mechanism to - stop the prefill instance when the decode instance is slow. +Implements a distributed key-value (KV) cache transfer mechanism. + +Key Features: +- Distributed KV cache transmission using PyNccl pipes. +- Non-blocking `insert`, blocking `drop_select`. +- Use CPU signal pipe to avoid racing condition +- Handles buffer size constraints and provide backpressure mechanism to + stop the prefill instance when the decode instance is slow. 
""" + import threading from collections import deque from typing import Optional, Union import torch -from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( - KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import KVLookupBufferBase from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger @@ -25,9 +25,9 @@ class SimpleBuffer(KVLookupBufferBase): - - def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, - buffer_size_thresh: float): + def __init__( + self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, buffer_size_thresh: float + ): """ signal_pipe: on CPU @@ -51,9 +51,11 @@ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, self.normal_signal = torch.tensor([0], device="cpu") self.end_signal = None - def _matches(self, tokens_roi_sender: list[torch.Tensor], - tokens_roi_recver: list[torch.Tensor]): - + def _matches( + self, + tokens_roi_sender: list[torch.Tensor], + tokens_roi_recver: list[torch.Tensor], + ): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) @@ -74,15 +76,12 @@ def _matches(self, tokens_roi_sender: list[torch.Tensor], # simple common prefix matching min_length = min(len(tokens_sender), len(tokens_recver)) - if torch.allclose(tokens_sender[:min_length], - tokens_recver[:min_length]): + if torch.allclose(tokens_sender[:min_length], tokens_recver[:min_length]): return min_length return 0 - def _send_tensor_and_dec_size(self, - tensor: Optional[torch.Tensor]) -> None: - + def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor]) -> None: assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() if tensor.dtype == torch.bool: @@ -90,7 +89,6 @@ def _send_tensor_and_dec_size(self, self.data_pipe.send_tensor(tensor) def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]): - if isinstance(data, torch.Tensor): return data.element_size() * data.numel() if not data: @@ -100,10 +98,14 @@ def _get_element_size(self, data: Optional[Union[list, torch.Tensor]]): raise AssertionError(f"Unknown data type {type(data)}") - def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor): - + def _add_to_buffer( + self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor, + ): if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() if isinstance(roi, torch.Tensor): @@ -134,9 +136,7 @@ def _is_end_signal(self, signal): return signal is None def drop_select_handler(self): - try: - while True: signal = self.signal_pipe.recv_tensor() if self._is_end_signal(signal): @@ -146,20 +146,21 @@ def drop_select_handler(self): input_tokens = self.data_pipe.recv_tensor() roi = self.data_pipe.recv_tensor() - assert roi is not None, "Please provide the roi when sending "\ - "drop-select request" - roi = (roi > 0.5) + assert roi is not None, ( + "Please provide the roi when sending drop-select request" + ) + roi = roi > 0.5 tokens_roi_recver = [input_tokens, roi] def is_buffer_available( - tokens_roi_recver: list[torch.Tensor], ) -> bool: + tokens_roi_recver: list[torch.Tensor], + ) -> bool: # perform input tokens and roi matching # FIXME: this matching is O(n), ideally it should be O(1) # but this buffer size won't (and shouldn't) be too large so # the fix is not urgent. 
for _ in range(len(self.buffer)): - if self._matches(self.buffer[0], - tokens_roi_recver) > 0: + if self._matches(self.buffer[0], tokens_roi_recver) > 0: return True # rotate the element we just accessed to the end self.buffer.rotate(-1) @@ -167,8 +168,7 @@ def is_buffer_available( with self.buffer_cv: while not is_buffer_available(tokens_roi_recver): - logger.debug( - "KV transfer buffer is not available. Waiting...") + logger.debug("KV transfer buffer is not available. Waiting...") self.buffer_cv.wait() # need to clone the tensor # in case the tensor is freed before sending finishes @@ -178,18 +178,18 @@ def is_buffer_available( self.buffer_cv.notify() except RuntimeError as e: - if 'Connection closed by peer' not in str(e): + if "Connection closed by peer" not in str(e): raise e logger.debug("Closing drop_select_handler") def drop_select( - self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: - - assert self.request_handling_thread is None, \ - "drop_select should be called by the KV cache consumer "\ + self, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor] + ) -> list[Optional[torch.Tensor]]: + assert self.request_handling_thread is None, ( + "drop_select should be called by the KV cache consumer " "(e.g. the decode vLLM instance)" + ) if isinstance(input_tokens, torch.Tensor): input_tokens = input_tokens.clone() @@ -205,30 +205,36 @@ def drop_select( if roi is not None: # convert from float tensor to bool tensor # as PyNccl does not support sending bool tensor - roi = (roi > 0.5) + roi = roi > 0.5 key = self.data_pipe.recv_tensor() value = self.data_pipe.recv_tensor() hidden = self.data_pipe.recv_tensor() return [input_tokens, roi, key, value, hidden] - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - + def insert( + self, + input_tokens: torch.Tensor, + roi: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + hidden: torch.Tensor, + ) -> None: self._add_to_buffer(input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. if self.request_handling_thread is None: self.request_handling_thread = threading.Thread( - target=self.drop_select_handler) + target=self.drop_select_handler + ) self.request_handling_thread.start() def close(self): - - if hasattr(self, "request_handling_thread" - ) and self.request_handling_thread is not None: + if ( + hasattr(self, "request_handling_thread") + and self.request_handling_thread is not None + ): self.request_handling_thread.join() else: diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 1423fd032477..e27c6b2101b8 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -26,11 +26,11 @@ class KVPipeBase(ABC): @abstractmethod def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: """Send a tensor, or None, via the pipe. - + Need to support sending None -- important for error handling. - - TODO: add a `key` argument so that we can use traditional - key-value database as the distributed communication mechanism behind + + TODO: add a `key` argument so that we can use traditional + key-value database as the distributed communication mechanism behind the pipe. 
Args: @@ -46,7 +46,7 @@ def recv_tensor(self) -> Optional[torch.Tensor]: """Receive a tensor (can be None) from the pipeline. Returns: - Optional[torch.Tensor]: The tensor received from the pipeline. Can + Optional[torch.Tensor]: The tensor received from the pipeline. Can be None. Raises: @@ -58,7 +58,7 @@ def recv_tensor(self) -> Optional[torch.Tensor]: def close(self) -> None: """Close the pipeline and release resources. - This method is responsible for closing the communication pipeline + This method is responsible for closing the communication pipeline and releasing any resources associated with it. Raises: diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 2a434e280179..65858f86aa23 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -32,7 +32,7 @@ class MooncakeTransferEngineConfig: device_name: str @staticmethod - def from_file(file_path: str) -> 'MooncakeTransferEngineConfig': + def from_file(file_path: str) -> "MooncakeTransferEngineConfig": """Load the config from a JSON file.""" with open(file_path) as fin: config = json.load(fin) @@ -46,12 +46,13 @@ def from_file(file_path: str) -> 'MooncakeTransferEngineConfig': ) @staticmethod - def load_from_env() -> 'MooncakeTransferEngineConfig': + def load_from_env() -> "MooncakeTransferEngineConfig": """Load config from a file specified in the environment variable.""" - config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") if config_file_path is None: raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." + ) return MooncakeTransferEngineConfig.from_file(config_file_path) @@ -65,7 +66,8 @@ def __init__(self, kv_rank: int, local_rank: int): raise ImportError( "Please install mooncake by following the instructions at " "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector.") from e + "to run vLLM with MooncakeConnector." 
+ ) from e self.engine = TransferEngine() self.local_rank = local_rank @@ -77,16 +79,13 @@ def __init__(self, kv_rank: int, local_rank: int): logger.error(e) raise except Exception as exc: - logger.error( - "An error occurred while loading the configuration: %s", exc) + logger.error("An error occurred while loading the configuration: %s", exc) raise - prefill_host, base_prefill_port = split_host_port( - self.config.prefill_url) + prefill_host, base_prefill_port = split_host_port(self.config.prefill_url) decode_host, base_decode_port = split_host_port(self.config.decode_url) # Avoid ports conflict when running prefill and decode on the same node - if prefill_host == decode_host and \ - base_prefill_port == base_decode_port: + if prefill_host == decode_host and base_prefill_port == base_decode_port: base_decode_port = base_decode_port + 100 prefill_port = base_prefill_port + self.local_rank @@ -94,12 +93,15 @@ def __init__(self, kv_rank: int, local_rank: int): self.prefill_url = join_host_port(prefill_host, prefill_port) self.decode_url = join_host_port(decode_host, decode_port) - self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url, - self.config.metadata_server, self.config.protocol, - self.config.device_name, self.config.metadata_backend) + self.initialize( + self.prefill_url if kv_rank == 0 else self.decode_url, + self.config.metadata_server, + self.config.protocol, + self.config.device_name, + self.config.metadata_backend, + ) - self.remote_url = (self.decode_url - if kv_rank == 0 else self.prefill_url) + self.remote_url = self.decode_url if kv_rank == 0 else self.prefill_url # Initialize ZeroMQ context and sockets self.context = zmq.Context() # type: ignore[attr-defined] @@ -109,51 +111,57 @@ def __init__(self, kv_rank: int, local_rank: int): self.receiver_ack = self.context.socket(zmq.constants.PUSH) self.buffer_cleaner = ThreadPoolExecutor(max_workers=1) - self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port, - decode_host, base_decode_port) + self._setup_metadata_sockets( + kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port + ) - def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int, - d_host: str, d_port: int) -> None: + def _setup_metadata_sockets( + self, kv_rank: int, p_host: str, p_port: int, d_host: str, d_port: int + ) -> None: """Set up ZeroMQ sockets for sending and receiving data.""" # Offsets < 8 are left for initialization in case tp and pp are enabled p_rank_offset = p_port + 8 + self.local_rank * 2 d_rank_offset = d_port + 8 + self.local_rank * 2 if kv_rank == 0: - self.sender_socket.bind( - make_zmq_path("tcp", p_host, p_rank_offset + 1)) + self.sender_socket.bind(make_zmq_path("tcp", p_host, p_rank_offset + 1)) self.receiver_socket.connect( - make_zmq_path("tcp", d_host, d_rank_offset + 1)) - self.sender_ack.connect( - make_zmq_path("tcp", d_host, d_rank_offset + 2)) - self.receiver_ack.bind( - make_zmq_path("tcp", p_host, p_rank_offset + 2)) + make_zmq_path("tcp", d_host, d_rank_offset + 1) + ) + self.sender_ack.connect(make_zmq_path("tcp", d_host, d_rank_offset + 2)) + self.receiver_ack.bind(make_zmq_path("tcp", p_host, p_rank_offset + 2)) else: self.receiver_socket.connect( - make_zmq_path("tcp", p_host, p_rank_offset + 1)) - self.sender_socket.bind( - make_zmq_path("tcp", d_host, d_rank_offset + 1)) - self.receiver_ack.bind( - make_zmq_path("tcp", d_host, d_rank_offset + 2)) - self.sender_ack.connect( - make_zmq_path("tcp", p_host, p_rank_offset + 2)) - - def initialize(self, 
local_hostname: str, metadata_server: str, - protocol: str, device_name: str, - metadata_backend: Union[str, None]) -> None: + make_zmq_path("tcp", p_host, p_rank_offset + 1) + ) + self.sender_socket.bind(make_zmq_path("tcp", d_host, d_rank_offset + 1)) + self.receiver_ack.bind(make_zmq_path("tcp", d_host, d_rank_offset + 2)) + self.sender_ack.connect(make_zmq_path("tcp", p_host, p_rank_offset + 2)) + + def initialize( + self, + local_hostname: str, + metadata_server: str, + protocol: str, + device_name: str, + metadata_backend: Union[str, None], + ) -> None: """Initialize the mooncake instance.""" if metadata_backend is None: - self.engine.initialize(local_hostname, metadata_server, protocol, - device_name) + self.engine.initialize( + local_hostname, metadata_server, protocol, device_name + ) else: supported_backend = ["etcd", "redis"] metadata_backend = metadata_backend.lower() if metadata_backend not in supported_backend: raise ValueError( "Mooncake Configuration error. `metadata_backend`" - f" should be one of {supported_backend}.") + f" should be one of {supported_backend}." + ) - self.engine.initialize_ext(local_hostname, metadata_server, - protocol, device_name, metadata_backend) + self.engine.initialize_ext( + local_hostname, metadata_server, protocol, device_name, metadata_backend + ) def allocate_managed_buffer(self, length: int) -> int: """Allocate a managed buffer of the specified length.""" @@ -167,18 +175,17 @@ def free_managed_buffer(self, buffer: int, length: int) -> int: """Free a previously allocated managed buffer.""" return self.engine.free_managed_buffer(buffer, length) - def transfer_sync(self, buffer: int, peer_buffer_address: int, - length: int) -> int: + def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int: """Synchronously transfer data to the specified address.""" - ret = self.engine.transfer_sync_read(self.remote_url, buffer, - peer_buffer_address, length) + ret = self.engine.transfer_sync_read( + self.remote_url, buffer, peer_buffer_address, length + ) if ret < 0: logger.error("Transfer Return Error") raise Exception("Transfer Return Error") return ret - def write_bytes_to_buffer(self, buffer: int, user_data: bytes, - length: int) -> int: + def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int: """Write bytes to the allocated buffer.""" return self.engine.write_bytes_to_buffer(buffer, user_data, length) @@ -189,7 +196,7 @@ def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: def wait_for_ack(self, src_ptr: int, length: int) -> None: """Asynchronously wait for ACK from the receiver.""" ack = self.sender_ack.recv() - if ack != b'ACK': + if ack != b"ACK": logger.error("Failed to receive ACK from the receiver") self.free_managed_buffer(src_ptr, length) @@ -200,8 +207,8 @@ def send_bytes(self, user_data: bytes) -> None: src_ptr = self.allocate_managed_buffer(length) self.write_bytes_to_buffer(src_ptr, user_data, length) self.sender_socket.send_multipart( - [struct.pack("!Q", src_ptr), - struct.pack("!Q", length)]) + [struct.pack("!Q", src_ptr), struct.pack("!Q", length)] + ) self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) def recv_bytes(self) -> bytes: @@ -214,7 +221,7 @@ def recv_bytes(self) -> bytes: ret = self.read_bytes_from_buffer(dst_ptr, length) # Buffer cleanup - self.receiver_ack.send(b'ACK') + self.receiver_ack.send(b"ACK") self.free_managed_buffer(dst_ptr, length) return ret @@ -223,10 +230,9 @@ def recv_bytes(self) -> bytes: class MooncakePipe(KVPipeBase): 
"""MooncakeTransferEngine based Pipe implementation.""" - def __init__(self, - local_rank: int, - config: KVTransferConfig, - device: Optional[str] = None): + def __init__( + self, local_rank: int, config: KVTransferConfig, device: Optional[str] = None + ): """Initialize the mooncake pipe and set related parameters.""" self.config = config self.local_rank = local_rank @@ -236,8 +242,7 @@ def __init__(self, else: self.device = self._select_device(device) - self.transfer_engine = MooncakeTransferEngine(self.kv_rank, - self.local_rank) + self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank) self.transport_thread: Optional[ThreadPoolExecutor] = None self.none_tensor = torch.tensor([NONE_INT], device=self.device) @@ -267,7 +272,7 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) tensor = tensor if tensor is not None else self.none_tensor - assert (len(tensor.shape) > 0) + assert len(tensor.shape) > 0 self.transport_thread.submit(self._send_impl, tensor) def recv_tensor(self) -> Optional[torch.Tensor]: diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index 7a79a8cc0c93..c79b7e7e5030 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -1,16 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ - This module implements a PyNccl pipe for sending and receiving - Optional[torch.Tensor] between distributed ranks with advanced - communication features. - - Key Features: - - Supports sending and receiving tensors with metadata - - Handles both CUDA and CPU device communications - - Implements a non-blocking tensor transfer mechanism - - Manages buffer size and provides backpressure control - - Supports distributed process groups with configurable parameters +This module implements a PyNccl pipe for sending and receiving +Optional[torch.Tensor] between distributed ranks with advanced +communication features. 
+ +Key Features: +- Supports sending and receiving tensors with metadata +- Handles both CUDA and CPU device communications +- Implements a non-blocking tensor transfer mechanism +- Manages buffer size and provides backpressure control +- Supports distributed process groups with configurable parameters """ import threading @@ -30,7 +30,6 @@ class BrokenPipeException(Exception): - def __init__(self, message): self.message = message super().__init__(self.message) @@ -40,16 +39,17 @@ def __init__(self, message): class PyNcclPipe(KVPipeBase): - METADATA_LENGTH = 16 MAX_TENSOR_DIMENSIONS = 14 METADATA_DTYPE = torch.int64 - def __init__(self, - local_rank: int, - config: KVTransferConfig, - device: Optional[str] = None, - port_offset: int = 0): + def __init__( + self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None, + port_offset: int = 0, + ): self.config = config self.local_rank = local_rank self.kv_rank = self.config.kv_rank @@ -84,9 +84,9 @@ def __init__(self, def _get_device_send_recv_impl( self, group: StatelessProcessGroup - ) -> tuple[Callable[[torch.Tensor, int], None], Callable[ - [torch.Tensor, int], None]]: - + ) -> tuple[ + Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None] + ]: send: Callable[[torch.Tensor, int], None] recv: Callable[[torch.Tensor, int], None] if self.device.type == "cuda": @@ -144,9 +144,9 @@ def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: buffer: A tensor of the specified type and shape, allocated on `self.device`. """ - return torch.empty(metadata["shape"], - dtype=metadata["dtype"], - device=self.device) + return torch.empty( + metadata["shape"], dtype=metadata["dtype"], device=self.device + ) def _send_metadata(self, metadata: Metadata): """ @@ -179,8 +179,7 @@ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: metadata = self._make_metadata(tensor) self._send_metadata(metadata) if tensor is not None: - self.device_send_func(tensor.to(self.device), - self.target_rank_for_send) + self.device_send_func(tensor.to(self.device), self.target_rank_for_send) def _recv_impl(self) -> Optional[torch.Tensor]: """ @@ -198,8 +197,9 @@ def _recv_impl(self) -> Optional[torch.Tensor]: return buffer - def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], - tensor_size: int) -> None: + def send_tensor_wrapper( + self, tensor: Optional[torch.Tensor], tensor_size: int + ) -> None: """ Wrapper for _send_impl to handle exceptions and update buffer size. 
""" @@ -209,9 +209,14 @@ def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], with self.buffer_size_lock: self.buffer_size -= tensor_size except Exception as e: - logger.error("[rank%d]: Exception when trying to send %s, msg: %s", - torch.distributed.get_rank(), str(tensor), str(e)) + logger.error( + "[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), + str(tensor), + str(e), + ) import traceback + traceback.print_exc() def block_if_full(self): @@ -244,8 +249,7 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: with self.buffer_size_lock: self.buffer_size += tensor_size - self.transport_thread.submit(self.send_tensor_wrapper, tensor, - tensor_size) + self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size) def recv_tensor(self) -> Optional[torch.Tensor]: """ @@ -266,6 +270,7 @@ def recv_tensor(self) -> Optional[torch.Tensor]: logger.error("%s", e) logger.error("My device: %s", self.device) import traceback + traceback.print_exc() raise e @@ -275,6 +280,5 @@ def close(self): """ Close the pipe and release associated resources. """ - if hasattr(self, - "transport_thread") and self.transport_thread is not None: + if hasattr(self, "transport_thread") and self.transport_thread is not None: self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index d5747bed9277..f8f65f28ff6d 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -4,10 +4,11 @@ from vllm import envs from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) if TYPE_CHECKING: from vllm.config import VllmConfig @@ -17,7 +18,8 @@ def get_kv_transfer_group() -> KVConnectorBaseType: assert _KV_CONNECTOR_AGENT is not None, ( - "disaggregated KV cache transfer parallel group is not initialized") + "disaggregated KV cache transfer parallel group is not initialized" + ) return _KV_CONNECTOR_AGENT @@ -25,8 +27,7 @@ def has_kv_transfer_group() -> bool: return _KV_CONNECTOR_AGENT is not None -def is_v1_kv_transfer_group( - connector: Optional[KVConnectorBaseType] = None) -> bool: +def is_v1_kv_transfer_group(connector: Optional[KVConnectorBaseType] = None) -> bool: """Check if the KV connector is the v1 connector. 
If the argument is None, it will check the global KV connector @@ -57,11 +58,14 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: if vllm_config.kv_transfer_config is None: return - if (vllm_config.kv_transfer_config.is_kv_transfer_instance - and _KV_CONNECTOR_AGENT is None): + if ( + vllm_config.kv_transfer_config.is_kv_transfer_instance + and _KV_CONNECTOR_AGENT is None + ): if envs.VLLM_USE_V1: _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( - config=vllm_config, role=KVConnectorRole.WORKER) + config=vllm_config, role=KVConnectorRole.WORKER + ) else: raise ValueError("V0 is no longer supported") diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 638170963e2b..aee5507ade46 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -22,6 +22,7 @@ parallelism, you can skip the model parallel initialization and destruction steps. """ + import contextlib import gc import pickle @@ -41,11 +42,16 @@ import vllm.envs as envs from vllm.distributed.device_communicators.base_device_communicator import ( - DeviceCommunicatorBase) + DeviceCommunicatorBase, +) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import (direct_register_custom_op, get_distributed_init_method, - resolve_obj_by_qualname, supports_custom_op) +from vllm.utils import ( + direct_register_custom_op, + get_distributed_init_method, + resolve_obj_by_qualname, + supports_custom_op, +) @dataclass @@ -57,7 +63,7 @@ class GraphCaptureContext: def _split_tensor_dict( - tensor_dict: dict[str, Union[torch.Tensor, Any]] + tensor_dict: dict[str, Union[torch.Tensor, Any]], ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: """Split the tensor dictionary into two parts: 1. A list of (key, value) pairs. If the value is a tensor, it is replaced @@ -74,7 +80,8 @@ def _split_tensor_dict( # receiving side will set the device index. device = value.device.type metadata_list.append( - (key, TensorMetadata(device, value.dtype, value.size()))) + (key, TensorMetadata(device, value.dtype, value.size())) + ) tensor_list.append(value) else: metadata_list.append((key, value)) @@ -116,8 +123,9 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: return torch.empty_like(tensor) -def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: +def reduce_scatter( + tensor: torch.Tensor, dim: int, world_size: int, group_name: str +) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() if group is None: @@ -125,15 +133,17 @@ def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, return group._reduce_scatter_out_place(tensor, dim) -def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: +def reduce_scatter_fake( + tensor: torch.Tensor, dim: int, world_size: int, group_name: str +) -> torch.Tensor: new_shape = list(tensor.shape) new_shape[dim] = tensor.shape[dim] // world_size return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) -def all_gather(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: +def all_gather( + tensor: torch.Tensor, dim: int, world_size: int, group_name: str +) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." 
group = _groups[group_name]() if group is None: @@ -141,8 +151,9 @@ def all_gather(tensor: torch.Tensor, dim: int, world_size: int, return group._all_gather_out_place(tensor, dim) -def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int, - group_name: str) -> torch.Tensor: +def all_gather_fake( + tensor: torch.Tensor, dim: int, world_size: int, group_name: str +) -> torch.Tensor: new_shape = list(tensor.shape) new_shape[dim] = tensor.shape[dim] * world_size return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) @@ -218,7 +229,8 @@ def __init__( for ranks in group_ranks: device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) + ranks, backend=torch_distributed_backend + ) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. cpu_group = torch.distributed.new_group(ranks, backend="gloo") @@ -242,8 +254,7 @@ def __init__( elif current_platform.is_xpu(): self.device = torch.device(f"xpu:{local_rank}") elif current_platform.is_out_of_tree(): - self.device = torch.device( - f"{current_platform.device_name}:{local_rank}") + self.device = torch.device(f"{current_platform.device_name}:{local_rank}") else: self.device = torch.device("cpu") @@ -251,7 +262,8 @@ def __init__( self.device_communicator = None if use_device_communicator and self.world_size > 1: device_comm_cls = resolve_obj_by_qualname( - current_platform.get_device_communicator_cls()) + current_platform.get_device_communicator_cls() + ) self.device_communicator = device_comm_cls( cpu_group=self.cpu_group, device=self.device, @@ -259,19 +271,23 @@ def __init__( unique_name=self.unique_name, ) - from vllm.distributed.device_communicators.shm_broadcast import ( - MessageQueue) + from vllm.distributed.device_communicators.shm_broadcast import MessageQueue + self.mq_broadcaster: Optional[MessageQueue] = None if use_message_queue_broadcaster and self.world_size > 1: self.mq_broadcaster = MessageQueue.create_from_process_group( - self.cpu_group, 1 << 22, 6) + self.cpu_group, 1 << 22, 6 + ) from vllm.platforms import current_platform - self.use_custom_op_call = (current_platform.is_cuda_alike() - or current_platform.is_tpu()) - self.use_cpu_custom_send_recv = (current_platform.is_cpu() and hasattr( - torch.ops._C, "init_shm_manager")) + self.use_custom_op_call = ( + current_platform.is_cuda_alike() or current_platform.is_tpu() + ) + + self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr( + torch.ops._C, "init_shm_manager" + ) @property def first_rank(self): @@ -309,7 +325,8 @@ def prev_rank(self): @contextmanager def graph_capture( - self, graph_capture_context: Optional[GraphCaptureContext] = None): + self, graph_capture_context: Optional[GraphCaptureContext] = None + ): if graph_capture_context is None: stream = torch.cuda.Stream() graph_capture_context = GraphCaptureContext(stream) @@ -320,7 +337,9 @@ def graph_capture( # so we don't abstract it into the base class maybe_ca_context = nullcontext() from vllm.distributed.device_communicators.cuda_communicator import ( - CudaCommunicator) + CudaCommunicator, + ) + if self.device_communicator is not None: assert isinstance(self.device_communicator, CudaCommunicator) ca_comm = self.device_communicator.ca_comm @@ -356,8 +375,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: return input_ if self.use_custom_op_call: - return torch.ops.vllm.all_reduce(input_, - group_name=self.unique_name) + return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name) else: 
return self._all_reduce_out_place(input_) @@ -372,66 +390,62 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: if world_size == 1: return input_ assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if self.use_custom_op_call: - return torch.ops.vllm.all_gather(input_, - dim, - world_size, - group_name=self.unique_name) + return torch.ops.vllm.all_gather( + input_, dim, world_size, group_name=self.unique_name + ) else: return self._all_gather_out_place(input_, dim) - def _all_gather_out_place(self, input_: torch.Tensor, - dim: int) -> torch.Tensor: + def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.all_gather(input_, dim) - def all_gatherv(self, - input_: Union[torch.Tensor, list[torch.Tensor]], - dim: int = 0, - sizes: Optional[list[int]] = None): + def all_gatherv( + self, + input_: Union[torch.Tensor, list[torch.Tensor]], + dim: int = 0, + sizes: Optional[list[int]] = None, + ): if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.all_gatherv(input_, dim, sizes) - def reduce_scatter(self, - input_: torch.Tensor, - dim: int = -1) -> torch.Tensor: + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + ) if self.use_custom_op_call: - return torch.ops.vllm.reduce_scatter(input_, - dim, - world_size, - group_name=self.unique_name) + return torch.ops.vllm.reduce_scatter( + input_, dim, world_size, group_name=self.unique_name + ) else: return self._reduce_scatter_out_place(input_, dim) - def reduce_scatterv(self, - input_: torch.Tensor, - dim: int = -1, - sizes: Optional[list[int]] = None) -> torch.Tensor: + def reduce_scatterv( + self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None + ) -> torch.Tensor: if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.reduce_scatterv(input_, dim, sizes) - def _reduce_scatter_out_place(self, input_: torch.Tensor, - dim: int) -> torch.Tensor: + def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.reduce_scatter(input_, dim) - def gather(self, - input_: torch.Tensor, - dst: int = 0, - dim: int = -1) -> Optional[torch.Tensor]: + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> Optional[torch.Tensor]: """ NOTE: We assume that the input tensor is on the same device across all the ranks. @@ -455,9 +469,9 @@ def broadcast(self, input_: torch.Tensor, src: int = 0): if self.world_size == 1: return input_ # Broadcast. 
- torch.distributed.broadcast(input_, - src=self.ranks[src], - group=self.device_group) + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) return input_ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): @@ -473,21 +487,20 @@ def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): assert src == 0, "Message queue broadcaster only supports src=0" return self.mq_broadcaster.broadcast_object(obj) if self.rank_in_group == src: - torch.distributed.broadcast_object_list([obj], - src=self.ranks[src], - group=self.cpu_group) + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) return obj else: recv = [None] - torch.distributed.broadcast_object_list(recv, - src=self.ranks[src], - group=self.cpu_group) + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) return recv[0] - def broadcast_object_list(self, - obj_list: list[Any], - src: int = 0, - group: Optional[ProcessGroup] = None): + def broadcast_object_list( + self, obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None + ): """Broadcast the input object list. NOTE: `src` is the local rank of the source rank. """ @@ -497,9 +510,9 @@ def broadcast_object_list(self, if self.world_size == 1: return obj_list # Broadcast. - torch.distributed.broadcast_object_list(obj_list, - src=self.ranks[src], - group=self.device_group) + torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) return obj_list def send_object(self, obj: Any, dst: int) -> None: @@ -510,25 +523,22 @@ def send_object(self, obj: Any, dst: int) -> None: assert dst != self.rank_in_group, ( "Invalid destination rank. Destination rank is the same " - "as the current rank.") + "as the current rank." + ) # Serialize object to tensor and get the size as well object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) - size_tensor = torch.tensor([object_tensor.numel()], - dtype=torch.long, - device="cpu") + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) # Send object size - torch.distributed.send(size_tensor, - dst=self.ranks[dst], - group=self.cpu_group) + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) # Send object - torch.distributed.send(object_tensor, - dst=self.ranks[dst], - group=self.cpu_group) + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) return None @@ -545,22 +555,24 @@ def recv_object(self, src: int) -> Any: size_tensor = torch.empty(1, dtype=torch.long, device="cpu") # Receive object size - rank_size = torch.distributed.recv(size_tensor, - src=self.ranks[src], - group=self.cpu_group) + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) # Tensor to receive serialized objects into. object_tensor = torch.empty( # type: ignore[call-overload] size_tensor.item(), # type: ignore[arg-type] dtype=torch.uint8, - device="cpu") + device="cpu", + ) - rank_object = torch.distributed.recv(object_tensor, - src=self.ranks[src], - group=self.cpu_group) + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) assert rank_object == rank_size, ( - "Received object sender rank does not match the size sender rank.") + "Received object sender rank does not match the size sender rank." 
+ ) obj = pickle.loads(object_tensor.numpy().tobytes()) @@ -571,13 +583,13 @@ def broadcast_tensor_dict( tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None + metadata_group: Optional[ProcessGroup] = None, ) -> Optional[dict[str, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary. NOTE: `src` is the local rank of the source rank. """ # Bypass the function if we are using only 1 GPU. - if (not torch.distributed.is_initialized() or self.world_size == 1): + if not torch.distributed.is_initialized() or self.world_size == 1: return tensor_dict group = self.device_group @@ -587,9 +599,9 @@ def broadcast_tensor_dict( rank_in_group = self.rank_in_group if rank_in_group == src: metadata_list: list[tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + assert isinstance(tensor_dict, dict), ( + f"Expecting a dictionary, got {type(tensor_dict)}" + ) metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `broadcast_object_list` has serialization & deserialization, @@ -602,16 +614,14 @@ def broadcast_tensor_dict( continue if tensor.is_cpu: # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=metadata_group, - async_op=True) + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=metadata_group, async_op=True + ) else: # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=group, - async_op=True) + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=group, async_op=True + ) async_handles.append(handle) for async_handle in async_handles: async_handle.wait() @@ -622,9 +632,9 @@ def broadcast_tensor_dict( async_handles = [] for key, value in metadata_list: if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) if tensor.numel() == 0: # Skip broadcasting empty tensors. tensor_dict[key] = tensor @@ -635,14 +645,13 @@ def broadcast_tensor_dict( tensor, src=self.ranks[src], group=metadata_group, - async_op=True) + async_op=True, + ) else: # use group for GPU tensors handle = torch.distributed.broadcast( - tensor, - src=self.ranks[src], - group=group, - async_op=True) + tensor, src=self.ranks[src], group=group, async_op=True + ) async_handles.append(handle) tensor_dict[key] = tensor else: @@ -679,10 +688,10 @@ def send_tensor_dict( # Bypass the function if we are using only 1 GPU. 
if not torch.distributed.is_initialized() or self.world_size == 1: return tensor_dict - all_gather_size = (1 if all_gather_group is None else - all_gather_group.world_size) - all_gather_rank = (0 if all_gather_group is None else - all_gather_group.rank_in_group) + all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size + all_gather_rank = ( + 0 if all_gather_group is None else all_gather_group.rank_in_group + ) group = self.device_group metadata_group = self.cpu_group @@ -695,22 +704,21 @@ def send_tensor_dict( if self.device_communicator is None: raise ValueError("No device communicator found") self.device_communicator.send_tensor_dict( # type: ignore - tensor_dict, dst) + tensor_dict, dst + ) return None metadata_list: list[tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), f"Expecting a dictionary, got {type(tensor_dict)}" + assert isinstance(tensor_dict, dict), ( + f"Expecting a dictionary, got {type(tensor_dict)}" + ) metadata_list, tensor_list = _split_tensor_dict(tensor_dict) # `metadata_list` lives in CPU memory. # `send_object_list` has serialization & deserialization, # all happening on CPU. Therefore, we can use the CPU group. self.send_object(metadata_list, dst=dst) - tensor_keys = [ - k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor) - ] + tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)] assert len(tensor_keys) == len(tensor_list) for key, tensor in zip(tensor_keys, tensor_list): @@ -719,23 +727,25 @@ def send_tensor_dict( continue # send-allgather: send only a slice, then do allgather. - use_all_gather = (all_gather_group is not None - and tensor.numel() % all_gather_size == 0) - use_all_gather = all_gather_tensors.get(key, use_all_gather) \ - if all_gather_tensors else use_all_gather + use_all_gather = ( + all_gather_group is not None and tensor.numel() % all_gather_size == 0 + ) + use_all_gather = ( + all_gather_tensors.get(key, use_all_gather) + if all_gather_tensors + else use_all_gather + ) if use_all_gather: tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] if tensor.is_cpu: # use metadata_group for CPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=metadata_group) + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) else: # use group for GPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=group) + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) return None def recv_tensor_dict( @@ -765,10 +775,10 @@ def recv_tensor_dict( # Bypass the function if we are using only 1 GPU. 
if not torch.distributed.is_initialized() or self.world_size == 1: return None - all_gather_size = (1 if all_gather_group is None else - all_gather_group.world_size) - all_gather_rank = (0 if all_gather_group is None else - all_gather_group.rank_in_group) + all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size + all_gather_rank = ( + 0 if all_gather_group is None else all_gather_group.rank_in_group + ) group = self.device_group metadata_group = self.cpu_group @@ -781,45 +791,47 @@ def recv_tensor_dict( if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.recv_tensor_dict( # type: ignore - src) + src + ) recv_metadata_list = self.recv_object(src=src) tensor_dict: dict[str, Any] = {} for key, value in recv_metadata_list: if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) if tensor.numel() == 0: # Skip broadcasting empty tensors. tensor_dict[key] = tensor continue # send-allgather: send only a slice, then do allgather. - use_all_gather = (all_gather_group is not None - and tensor.numel() % all_gather_size == 0) - use_all_gather = all_gather_tensors.get(key, use_all_gather) \ - if all_gather_tensors else use_all_gather + use_all_gather = ( + all_gather_group is not None + and tensor.numel() % all_gather_size == 0 + ) + use_all_gather = ( + all_gather_tensors.get(key, use_all_gather) + if all_gather_tensors + else use_all_gather + ) if use_all_gather: orig_shape = tensor.shape - tensor = tensor.reshape(all_gather_size, - -1)[all_gather_rank] + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] if tensor.is_cpu: # use metadata_group for CPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=metadata_group) + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) else: # use group for GPU tensors - torch.distributed.recv(tensor, - src=self.ranks[src], - group=group) + torch.distributed.recv(tensor, src=self.ranks[src], group=group) if use_all_gather: # do the allgather tensor = all_gather_group.all_gather( # type: ignore - tensor, dim=0) + tensor, dim=0 + ) tensor = tensor.reshape(orig_shape) tensor_dict[key] = tensor @@ -843,10 +855,9 @@ def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: raise ValueError("No device communicator found") self.device_communicator.send(tensor, dst) - def recv(self, - size: torch.Size, - dtype: torch.dtype, - src: Optional[int] = None) -> torch.Tensor: + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: """Receives a tensor from the source rank.""" """NOTE: `src` is the local rank of the source rank.""" if self.device_communicator is None: @@ -867,28 +878,26 @@ def destroy(self): def prepare_communication_buffer_for_model(self, model: torch.nn.Module): if self.device_communicator is not None: - self.device_communicator.prepare_communication_buffer_for_model( - model) + self.device_communicator.prepare_communication_buffer_for_model(model) def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, - is_sequence_parallel: bool = False + is_sequence_parallel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: if self.device_communicator is not None: - return self.device_communicator.dispatch(hidden_states, - router_logits, - is_sequence_parallel) + return self.device_communicator.dispatch( 
+ hidden_states, router_logits, is_sequence_parallel + ) else: return hidden_states, router_logits - def combine(self, - hidden_states, - is_sequence_parallel: bool = False) -> torch.Tensor: + def combine( + self, hidden_states, is_sequence_parallel: bool = False + ) -> torch.Tensor: if self.device_communicator is not None: - return self.device_communicator.combine(hidden_states, - is_sequence_parallel) + return self.device_communicator.combine(hidden_states, is_sequence_parallel) else: return hidden_states @@ -898,12 +907,13 @@ def combine(self, def get_world_group() -> GroupCoordinator: - assert _WORLD is not None, ("world group is not initialized") + assert _WORLD is not None, "world group is not initialized" return _WORLD -def init_world_group(ranks: list[int], local_rank: int, - backend: str) -> GroupCoordinator: +def init_world_group( + ranks: list[int], local_rank: int, backend: str +) -> GroupCoordinator: return GroupCoordinator( group_ranks=[ranks], local_rank=local_rank, @@ -920,7 +930,6 @@ def init_model_parallel_group( use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, ) -> GroupCoordinator: - return GroupCoordinator( group_ranks=group_ranks, local_rank=local_rank, @@ -935,13 +944,15 @@ def init_model_parallel_group( def get_tp_group() -> GroupCoordinator: - assert _TP is not None, ("tensor model parallel group is not initialized") + assert _TP is not None, "tensor model parallel group is not initialized" return _TP -@deprecated("`get_tensor_model_parallel_group` has been replaced with " - "`get_tp_group` and may be removed after v0.12. Please use " - "`get_tp_group` instead.") +@deprecated( + "`get_tensor_model_parallel_group` has been replaced with " + "`get_tp_group` and may be removed after v0.12. Please use " + "`get_tp_group` instead." +) def get_tensor_model_parallel_group(): return get_tp_group() @@ -950,8 +961,7 @@ def get_tensor_model_parallel_group(): def get_dcp_group() -> GroupCoordinator: - assert _DCP is not None, ( - "decode context model parallel group is not initialized") + assert _DCP is not None, "decode context model parallel group is not initialized" return _DCP @@ -964,7 +974,7 @@ def get_dcp_group() -> GroupCoordinator: def get_dp_group() -> GroupCoordinator: - assert _DP is not None, ("data parallel group is not initialized") + assert _DP is not None, "data parallel group is not initialized" return _DP @@ -972,19 +982,20 @@ def get_dp_group() -> GroupCoordinator: def get_ep_group() -> GroupCoordinator: - assert _EP is not None, ("expert parallel group is not initialized") + assert _EP is not None, "expert parallel group is not initialized" return _EP def get_pp_group() -> GroupCoordinator: - assert _PP is not None, ( - "pipeline model parallel group is not initialized") + assert _PP is not None, "pipeline model parallel group is not initialized" return _PP -@deprecated("`get_pipeline_model_parallel_group` has been replaced with " - "`get_pp_group` and may be removed in v0.12. Please use " - "`get_pp_group` instead.") +@deprecated( + "`get_pipeline_model_parallel_group` has been replaced with " + "`get_pp_group` and may be removed in v0.12. Please use " + "`get_pp_group` instead." +) def get_pipeline_model_parallel_group(): return get_pp_group() @@ -1005,8 +1016,7 @@ def graph_capture(device: torch.device): from other kernels possibly launched on background in the default stream. 
""" context = GraphCaptureContext(torch.cuda.Stream(device=device)) - with get_tp_group().graph_capture(context), get_pp_group().graph_capture( - context): + with get_tp_group().graph_capture(context), get_pp_group().graph_capture(context): yield context @@ -1020,21 +1030,30 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable -def init_distributed_environment(world_size: int = -1, - rank: int = -1, - distributed_init_method: str = "env://", - local_rank: int = -1, - backend: str = "nccl", - timeout: Optional[timedelta] = None): +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", + timeout: Optional[timedelta] = None, +): logger.debug( - "world_size=%d rank=%d local_rank=%d " - "distributed_init_method=%s backend=%s", world_size, rank, local_rank, - distributed_init_method, backend) + "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) from vllm.config import get_current_vllm_config + config = get_current_vllm_config() - if config is not None and config.parallel_config.data_parallel_size > 1 \ - and config.parallel_config.distributed_executor_backend \ - != "external_launcher": + if ( + config is not None + and config.parallel_config.data_parallel_size > 1 + and config.parallel_config.distributed_executor_backend != "external_launcher" + ): parallel_config = config.parallel_config # adjust to take into account data parallelism # offset the rank by the data parallel rank @@ -1046,17 +1065,23 @@ def init_distributed_environment(world_size: int = -1, distributed_init_method = get_distributed_init_method(ip, port) logger.info( "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", - world_size, rank, distributed_init_method) + world_size, + rank, + distributed_init_method, + ) if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " - "distributed environment") + "distributed environment" + ) if not torch.distributed.is_backend_available(backend): logger.warning( - "Distributed backend %s is not available; " - "falling back to gloo.", backend) + "Distributed backend %s is not available; falling back to gloo.", + backend, + ) assert torch.distributed.is_gloo_available(), ( - "Fallback Gloo backend is not available.") + "Fallback Gloo backend is not available." 
+ ) backend = "gloo" # this backend is used for WORLD torch.distributed.init_process_group( @@ -1064,27 +1089,25 @@ def init_distributed_environment(world_size: int = -1, init_method=distributed_init_method, world_size=world_size, rank=rank, - timeout=timeout) + timeout=timeout, + ) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 if local_rank == -1: # local rank not set, this usually happens in single-node # setting, where we can use rank as local rank - if distributed_init_method == "env://": - local_rank = envs.LOCAL_RANK - else: - local_rank = rank + local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank global _WORLD, _NODE_COUNT if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) _NODE_COUNT = _node_count(_WORLD.cpu_group) - logger.debug("Detected %d nodes in the distributed environment", - _NODE_COUNT) + logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( - "world group already initialized with a different world size") + "world group already initialized with a different world size" + ) def initialize_model_parallel( @@ -1120,11 +1143,11 @@ def initialize_model_parallel( assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() rank = torch.distributed.get_rank() - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) + backend = backend or torch.distributed.get_backend(get_world_group().device_group) data_parallel_size = 1 from vllm.config import get_current_vllm_config + config = get_current_vllm_config() if config is not None: data_parallel_size = config.parallel_config.data_parallel_size @@ -1139,77 +1162,82 @@ def initialize_model_parallel( # to get group_ranks for each dimension, transpose that dimension to the # last dimension, then reshape to 2D, then unbind the last dimension all_ranks = torch.arange(world_size).reshape( - -1, data_parallel_size, pipeline_model_parallel_size, - tensor_model_parallel_size) # noqa + -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size + ) # noqa # Build the tensor model-parallel groups. global _TP - assert _TP is None, ("tensor model parallel group is already initialized") + assert _TP is None, "tensor model parallel group is already initialized" group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] # message queue broadcaster is only used in tensor model parallel group - _TP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=True, - group_name="tp") + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp", + ) # Build the DCP model-parallel groups. global _DCP - assert _DCP is None, ( - "decode context model parallel group is already initialized") + assert _DCP is None, "decode context model parallel group is already initialized" # Note(hc): In the current implementation of decode context parallel, # dcp_size must not exceed tp_size, because the world size does not # change by DCP, it simply reuses the GPUs of TP group, and split one # TP group into tp_size//dcp_size DCP groups. 
- group_ranks = all_ranks.reshape( - -1, decode_context_model_parallel_size).unbind(0) + group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - _DCP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - use_message_queue_broadcaster=True, - group_name="dcp") + _DCP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="dcp", + ) # Build the pipeline model-parallel groups. global _PP - assert _PP is None, ( - "pipeline model parallel group is already initialized") - group_ranks = all_ranks.transpose(2, 3).reshape( - -1, pipeline_model_parallel_size).unbind(0) + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] - _PP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="pp") + _PP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="pp" + ) global _DP - assert _DP is None, ("data parallel group is already initialized") - group_ranks = all_ranks.transpose(1, - 3).reshape(-1, - data_parallel_size).unbind(0) + assert _DP is None, "data parallel group is already initialized" + group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - _DP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="dp") + _DP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="dp" + ) global _EP - assert _EP is None, ("expert parallel group is already initialized") - group_ranks = all_ranks.transpose(1, 2).reshape( - -1, data_parallel_size * tensor_model_parallel_size).unbind(0) + assert _EP is None, "expert parallel group is already initialized" + group_ranks = ( + all_ranks.transpose(1, 2) + .reshape(-1, data_parallel_size * tensor_model_parallel_size) + .unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] - _EP = init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name="ep") + _EP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="ep" + ) logger.info( "rank %s in world size %s is assigned as " - "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size, - _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, - _EP.rank_in_group) + "DP rank %s, PP rank %s, TP rank %s, EP rank %s", + rank, + world_size, + _DP.rank_in_group, + _PP.rank_in_group, + _TP.rank_in_group, + _EP.rank_in_group, + ) def ensure_model_parallel_initialized( @@ -1222,24 +1250,27 @@ def ensure_model_parallel_initialized( or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. 
""" - backend = backend or torch.distributed.get_backend( - get_world_group().device_group) + backend = backend or torch.distributed.get_backend(get_world_group().device_group) if not model_parallel_is_initialized(): - initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size, - decode_context_model_parallel_size, backend) + initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + decode_context_model_parallel_size, + backend, + ) return - assert ( - get_tensor_model_parallel_world_size() == tensor_model_parallel_size - ), ("tensor parallel group already initialized, but of unexpected size. " + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( + "tensor parallel group already initialized, but of unexpected size. " f"got: {get_tensor_model_parallel_world_size()=} vs. " - f"wanted: {tensor_model_parallel_size=}") + f"wanted: {tensor_model_parallel_size=}" + ) pp_world_size = get_pp_group().world_size - assert (pp_world_size == pipeline_model_parallel_size), ( + assert pp_world_size == pipeline_model_parallel_size, ( "pipeline parallel group already initialized, but of unexpected size. " f"got: {pp_world_size=} vs. " - f"wanted: {pipeline_model_parallel_size=}") + f"wanted: {pipeline_model_parallel_size=}" + ) def prepare_communication_buffer_for_model(model: torch.nn.Module): @@ -1261,7 +1292,7 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module): def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TP is not None and _PP is not None) + return _TP is not None and _PP is not None _TP_STATE_PATCHED = False @@ -1313,9 +1344,8 @@ def get_decode_context_model_parallel_rank(): def get_node_count() -> int: - """Return the total number of nodes in the distributed environment. """ - assert _NODE_COUNT is not None, ( - "distributed environment is not initialized") + """Return the total number of nodes in the distributed environment.""" + assert _NODE_COUNT is not None, "distributed environment is not initialized" return _NODE_COUNT @@ -1363,9 +1393,11 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): destroy_distributed_environment() if shutdown_ray: import ray # Lazy import Ray + ray.shutdown() gc.collect() from vllm.platforms import current_platform + empty_cache = current_platform.empty_cache if empty_cache is not None: empty_cache() @@ -1373,21 +1405,21 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): if not current_platform.is_cpu(): torch._C._host_emptyCache() except AttributeError: - logger.warning( - "torch._C._host_emptyCache() only available in Pytorch >=2.5") + logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5") -def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], - source_rank: int = 0) -> list[bool]: +def in_the_same_node_as( + pg: Union[ProcessGroup, StatelessProcessGroup], source_rank: int = 0 +) -> list[bool]: """ This is a collective operation that returns if each rank is in the same node as the source rank. It tests if processes are attached to the same memory system (shared access to shared memory). """ if isinstance(pg, ProcessGroup): - assert torch.distributed.get_backend( - pg) != torch.distributed.Backend.NCCL, ( - "in_the_same_node_as should be tested with a non-NCCL group.") + assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, ( + "in_the_same_node_as should be tested with a non-NCCL group." 
+ ) # local rank inside the group rank = torch.distributed.get_rank(group=pg) world_size = torch.distributed.get_world_size(group=pg) @@ -1410,10 +1442,11 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], if rank == source_rank: # create a shared memory segment shm = shared_memory.SharedMemory(create=True, size=128) - shm.buf[:len(magic_message)] = magic_message + shm.buf[: len(magic_message)] = magic_message if isinstance(pg, ProcessGroup): torch.distributed.broadcast_object_list( - [shm.name], src=ranks[source_rank], group=pg) + [shm.name], src=ranks[source_rank], group=pg + ) else: pg.broadcast_obj(shm.name, src=source_rank) is_in_the_same_node[rank] = 1 @@ -1422,17 +1455,20 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], if isinstance(pg, ProcessGroup): recv = [None] torch.distributed.broadcast_object_list( - recv, src=ranks[source_rank], group=pg) + recv, src=ranks[source_rank], group=pg + ) name = recv[0] else: name = pg.broadcast_obj(None, src=source_rank) # fix to https://stackoverflow.com/q/62748654/9191338 # Python incorrectly tracks shared memory even if it is not # created by the process. The following patch is a workaround. - with patch("multiprocessing.resource_tracker.register", - lambda *args, **kwargs: None): + with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): shm = shared_memory.SharedMemory(name=name) - if shm.buf[:len(magic_message)] == magic_message: + if shm.buf[: len(magic_message)] == magic_message: is_in_the_same_node[rank] = 1 except Exception as e: logger.error("Error ignored in is_in_the_same_node: %s", e) diff --git a/vllm/distributed/tpu_distributed_utils.py b/vllm/distributed/tpu_distributed_utils.py index 0a786b4a1708..3db25d1a1964 100644 --- a/vllm/distributed/tpu_distributed_utils.py +++ b/vllm/distributed/tpu_distributed_utils.py @@ -10,18 +10,17 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) logger = init_logger(__name__) class XlaQKVParallelLinear(nn.Module): - - def __init__(self, - qkv_linear: nn.Module, - mesh: Optional["xs.Mesh"] = None): + def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None): super().__init__() assert isinstance(qkv_linear, QKVParallelLinear) self.skip_bias_add = qkv_linear.skip_bias_add @@ -39,21 +38,22 @@ def __init__(self, self._shard_weight(mesh) def _shard_weight(self, mesh: "xs.Mesh"): - self.q_weight = Parameter(self.q_weight.to('xla'), requires_grad=False) - self.k_weight = Parameter(self.k_weight.to('xla'), requires_grad=False) - self.v_weight = Parameter(self.v_weight.to('xla'), requires_grad=False) - xs.mark_sharding(self.q_weight, mesh, ('x', None)) - xs.mark_sharding(self.k_weight, mesh, ('x', None)) - xs.mark_sharding(self.v_weight, mesh, ('x', None)) + self.q_weight = Parameter(self.q_weight.to("xla"), requires_grad=False) + self.k_weight = Parameter(self.k_weight.to("xla"), requires_grad=False) + self.v_weight = Parameter(self.v_weight.to("xla"), requires_grad=False) + xs.mark_sharding(self.q_weight, mesh, ("x", None)) + xs.mark_sharding(self.k_weight, mesh, ("x", None)) + xs.mark_sharding(self.v_weight, mesh, ("x", None)) if self.q_bias is not None: - assert self.k_bias is not None and self.v_bias is not None, \ + assert self.k_bias is 
not None and self.v_bias is not None, ( "QKVParallelLinear should have q, k, and v biases together." - self.q_bias = Parameter(self.q_bias.to('xla'), requires_grad=False) - xs.mark_sharding(self.q_bias, mesh, ('x', )) - self.k_bias = Parameter(self.k_bias.to('xla'), requires_grad=False) - xs.mark_sharding(self.k_bias, mesh, ('x', )) - self.v_bias = Parameter(self.v_bias.to('xla'), requires_grad=False) - xs.mark_sharding(self.v_bias, mesh, ('x', )) + ) + self.q_bias = Parameter(self.q_bias.to("xla"), requires_grad=False) + xs.mark_sharding(self.q_bias, mesh, ("x",)) + self.k_bias = Parameter(self.k_bias.to("xla"), requires_grad=False) + xs.mark_sharding(self.k_bias, mesh, ("x",)) + self.v_bias = Parameter(self.v_bias.to("xla"), requires_grad=False) + xs.mark_sharding(self.v_bias, mesh, ("x",)) def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module): q_proj_size, k_proj_size, _ = qkv_linear.output_sizes @@ -61,22 +61,25 @@ def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module): # along the output dimension. qkv_weight = qkv_linear.weight.data.cpu() q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False) - k_weight = Parameter(qkv_weight[q_proj_size:q_proj_size + k_proj_size], - requires_grad=False) - v_weight = Parameter(qkv_weight[q_proj_size + k_proj_size:], - requires_grad=False) + k_weight = Parameter( + qkv_weight[q_proj_size : q_proj_size + k_proj_size], requires_grad=False + ) + v_weight = Parameter( + qkv_weight[q_proj_size + k_proj_size :], requires_grad=False + ) self.register_parameter("q_weight", q_weight) self.register_parameter("k_weight", k_weight) self.register_parameter("v_weight", v_weight) if qkv_linear.bias is not None: - q_bias = Parameter(qkv_linear.bias[:q_proj_size], - requires_grad=False) - k_bias = Parameter(qkv_linear.bias[q_proj_size:q_proj_size + - k_proj_size], - requires_grad=False) - v_bias = Parameter(qkv_linear.bias[q_proj_size + k_proj_size:], - requires_grad=False) + q_bias = Parameter(qkv_linear.bias[:q_proj_size], requires_grad=False) + k_bias = Parameter( + qkv_linear.bias[q_proj_size : q_proj_size + k_proj_size], + requires_grad=False, + ) + v_bias = Parameter( + qkv_linear.bias[q_proj_size + k_proj_size :], requires_grad=False + ) self.register_parameter("q_bias", q_bias) self.register_parameter("k_bias", k_bias) self.register_parameter("v_bias", v_bias) @@ -102,42 +105,48 @@ def forward(self, input): # The concat and the following split will be noop, and should be # optimized away by the compiler. 
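# --- Editor's illustrative sketch (not part of the patch) -------------------
# The _load_weights_from_qkv_linear hunk above slices a fused QKV weight along
# its output dimension using the layer's output sizes. This is a minimal,
# self-contained sketch of that slicing arithmetic; the sizes below are made-up
# assumptions, not values from any real vLLM model config.
import torch

hidden_size = 16
q_proj_size, k_proj_size, v_proj_size = 8, 4, 4  # fused along dim 0

# Stand-in for QKVParallelLinear.weight, shape (q + k + v, hidden).
qkv_weight = torch.randn(q_proj_size + k_proj_size + v_proj_size, hidden_size)

# The same index ranges as in the diff above.
q_weight = qkv_weight[:q_proj_size]
k_weight = qkv_weight[q_proj_size : q_proj_size + k_proj_size]
v_weight = qkv_weight[q_proj_size + k_proj_size :]

# Concatenating the slices back recovers the fused weight exactly, which is
# why the cat/split pair in forward() can be treated as a no-op by the
# compiler, as the comment in the diff notes.
assert torch.equal(torch.cat([q_weight, k_weight, v_weight], dim=0), qkv_weight)
# ----------------------------------------------------------------------------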
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1) - output_bias = torch.cat([q_bias, k_bias, v_bias], dim=-1) if \ - self.skip_bias_add else None + output_bias = ( + torch.cat([q_bias, k_bias, v_bias], dim=-1) if self.skip_bias_add else None + ) if not self.return_bias: return qkv_proj return qkv_proj, output_bias -def partition_column_parallel_linear(layer: torch.nn.Module, - mesh: xs.Mesh) -> torch.nn.Module: +def partition_column_parallel_linear( + layer: torch.nn.Module, mesh: xs.Mesh +) -> torch.nn.Module: assert isinstance(layer, ColumnParallelLinear) - xs.mark_sharding(layer.weight, mesh, ('x', None)) + xs.mark_sharding(layer.weight, mesh, ("x", None)) logger.debug("Applied column-parallel sharding to %s", layer) return layer -def partition_row_parallel_linear(layer: torch.nn.Module, - mesh: xs.Mesh) -> torch.nn.Module: +def partition_row_parallel_linear( + layer: torch.nn.Module, mesh: xs.Mesh +) -> torch.nn.Module: assert isinstance(layer, RowParallelLinear) - xs.mark_sharding(layer.weight, mesh, (None, 'x')) + xs.mark_sharding(layer.weight, mesh, (None, "x")) logger.debug("Applied row-parallel sharding to %s", layer) return layer -def partition_qkv_parallel_linear(layer: torch.nn.Module, - mesh: xs.Mesh) -> torch.nn.Module: +def partition_qkv_parallel_linear( + layer: torch.nn.Module, mesh: xs.Mesh +) -> torch.nn.Module: assert isinstance(layer, QKVParallelLinear) xla_layer = XlaQKVParallelLinear(layer, mesh) logger.debug("Applied qkv parallel sharding to %s", layer) return xla_layer -MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict([ - ("QKVParallelLinear", partition_qkv_parallel_linear), - ("ColumnParallelLinear", partition_column_parallel_linear), - ("RowParallelLinear", partition_row_parallel_linear), -]) +MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict( + [ + ("QKVParallelLinear", partition_qkv_parallel_linear), + ("ColumnParallelLinear", partition_column_parallel_linear), + ("RowParallelLinear", partition_row_parallel_linear), + ] +) def get_fqn(module): @@ -147,9 +156,9 @@ def get_fqn(module): def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None: """ - Recursively check a PyTorch model and apply appropriate sharding based on + Recursively check a PyTorch model and apply appropriate sharding based on the MODULE_TYPE_TO_WRAPPING_FUNC mapping. - + Args: model: torch.nn.Module to process mesh: An XLA SPMD mesh object used for sharding @@ -161,7 +170,8 @@ def _process_module(module, name=None, parent=None): wrapped_module = wrapping_func(module, mesh) assert parent is not None and name is not None, ( - "Top Level module is not expected to be wrapped.") + "Top Level module is not expected to be wrapped." + ) if wrapped_module is not module: # Wrapped module and module are different py object. # The original module should be replaced by the diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 67f71643d039..a35f28c25385 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -19,9 +19,12 @@ import torch from torch.distributed import ProcessGroup, TCPStore -from torch.distributed.distributed_c10d import (Backend, PrefixStore, - _get_default_timeout, - _unregister_process_group) +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + _get_default_timeout, + _unregister_process_group, +) from torch.distributed.rendezvous import rendezvous import vllm.envs as envs @@ -33,9 +36,9 @@ # We prefer to use os.sched_yield as it results in tighter polling loops, # measured to be around 3e-7 seconds. 
However on earlier versions of Python # os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) -USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) - or (sys.version_info[:2] == (3, 10) - and sys.version_info[2] >= 8)) +USE_SCHED_YIELD = (sys.version_info[:3] >= (3, 11, 1)) or ( + sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8 +) def sched_yield(): @@ -48,7 +51,8 @@ def sched_yield(): def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" assert numerator % denominator == 0, "{} is not divisible by {}".format( - numerator, denominator) + numerator, denominator + ) def divide(numerator, denominator): @@ -63,16 +67,16 @@ def split_tensor_along_last_dim( num_partitions: int, contiguous_split_chunks: bool = False, ) -> Sequence[torch.Tensor]: - """ Split a tensor along its last dimension. + """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. - Returns: - A list of Tensors + Returns: + A list of Tensors """ # Get the size and dimension. last_dim = tensor.dim() - 1 @@ -86,8 +90,9 @@ def split_tensor_along_last_dim( return tensor_list -def get_pp_indices(num_hidden_layers: int, pp_rank: int, - pp_size: int) -> tuple[int, int]: +def get_pp_indices( + num_hidden_layers: int, pp_rank: int, pp_size: int +) -> tuple[int, int]: """Try to evenly distribute layers across partitions. If the number of layers is not divisible by the number of partitions, @@ -104,17 +109,15 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int, partition_list_str = envs.VLLM_PP_LAYER_PARTITION if partition_list_str is not None: try: - partitions = [ - int(layer) for layer in partition_list_str.split(",") - ] + partitions = [int(layer) for layer in partition_list_str.split(",")] except ValueError as err: - raise ValueError("Invalid partition string: {}".format( - partition_list_str)) from err + raise ValueError( + "Invalid partition string: {}".format(partition_list_str) + ) from err if len(partitions) != pp_size: raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") if sum(partitions) != num_hidden_layers: - raise ValueError( - f"{sum(partitions)=} does not match {num_hidden_layers=}.") + raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.") else: layers_per_partition = num_hidden_layers // pp_size partitions = [layers_per_partition for _ in range(pp_size)] @@ -126,7 +129,8 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int, "Hidden layers were unevenly partitioned: [%s]. " "This can be manually overridden using the " "VLLM_PP_LAYER_PARTITION environment variable", - ",".join(str(p) for p in partitions)) + ",".join(str(p) for p in partitions), + ) start_layer = sum(partitions[:pp_rank]) end_layer = start_layer + partitions[pp_rank] @@ -140,6 +144,7 @@ class StatelessProcessGroup: group. Only use it to communicate metadata between processes. For data-plane communication, create NCCL-related objects. 
""" + rank: int world_size: int store: torch._C._distributed_c10d.Store @@ -154,21 +159,16 @@ class StatelessProcessGroup: # src rank -> counter recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) broadcast_send_counter: int = 0 - broadcast_recv_src_counter: dict[int, int] = dataclasses.field( - default_factory=dict) + broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) # A deque to store the data entries, with key and timestamp. - entries: deque[tuple[str, - float]] = dataclasses.field(default_factory=deque) + entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque) def __post_init__(self): assert self.rank < self.world_size self.send_dst_counter = {i: 0 for i in range(self.world_size)} self.recv_src_counter = {i: 0 for i in range(self.world_size)} - self.broadcast_recv_src_counter = { - i: 0 - for i in range(self.world_size) - } + self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)} def send_obj(self, obj: Any, dst: int): """Send an object to a destination rank.""" @@ -192,8 +192,8 @@ def expire_data(self): def recv_obj(self, src: int) -> Any: """Receive an object from a source rank.""" obj = pickle.loads( - self.store.get( - f"send_to/{self.rank}/{self.recv_src_counter[src]}")) + self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}") + ) self.recv_src_counter[src] += 1 return obj @@ -204,15 +204,13 @@ def broadcast_obj(self, obj: Optional[Any], src: int) -> Any: """ if self.rank == src: self.expire_data() - key = (f"broadcast_from/{src}/" - f"{self.broadcast_send_counter}") + key = f"broadcast_from/{src}/{self.broadcast_send_counter}" self.store.set(key, pickle.dumps(obj)) self.broadcast_send_counter += 1 self.entries.append((key, time.time())) return obj else: - key = (f"broadcast_from/{src}/" - f"{self.broadcast_recv_src_counter[src]}") + key = f"broadcast_from/{src}/{self.broadcast_recv_src_counter[src]}" recv_obj = pickle.loads(self.store.get(key)) self.broadcast_recv_src_counter[src] += 1 return recv_obj @@ -278,8 +276,7 @@ def barrier(self, timeout: float = 30.0): # Check for timeout cur_time = time.time() if cur_time - start_time > timeout: - raise RuntimeError("Barrier timed out after %f seconds", - timeout) + raise RuntimeError("Barrier timed out after %f seconds", timeout) # Check for each process for i in range(self.world_size): @@ -326,8 +323,7 @@ def barrier(self, timeout: float = 30.0): while len(processes_departed) < self.world_size: # Check for timeout if time.time() - start_time > timeout: - raise RuntimeError("Barrier departure timed out after %f s", - timeout) + raise RuntimeError("Barrier departure timed out after %f s", timeout) # Check for each process for i in range(self.world_size): @@ -356,14 +352,12 @@ def barrier(self, timeout: float = 30.0): try: self.store.delete_key(f"arrival_{barrier_id}_{i}") except Exception: - logger.debug("Error deleting key: %s", - f'arrival_{barrier_id}_{i}') + logger.debug("Error deleting key: %s", f"arrival_{barrier_id}_{i}") try: self.store.delete_key(f"departure_{barrier_id}_{i}") except Exception: - logger.debug("Error deleting key: %s", - f'departure_{barrier_id}_{i}') + logger.debug("Error deleting key: %s", f"departure_{barrier_id}_{i}") @staticmethod def create( @@ -388,7 +382,7 @@ def create( used for exchanging metadata. With this function, process A and process B can call `StatelessProcessGroup.create` to form a group, and then process A, B, C, and D can call `StatelessProcessGroup.create` to form another group. 
- """ # noqa + """ # noqa launch_server = rank == 0 if launch_server: # listen on the specified interface (instead of 0.0.0.0) @@ -416,14 +410,19 @@ def create( world_size=world_size, store=store, socket=listen_socket, - data_expiration_seconds=data_expiration_seconds) + data_expiration_seconds=data_expiration_seconds, + ) -def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore, - group_rank: int, group_size: int, - timeout: timedelta) -> ProcessGroup: +def init_gloo_process_group( + backend: Backend, + prefix_store: PrefixStore, + group_rank: int, + group_size: int, + timeout: timedelta, +) -> ProcessGroup: """ - Stateless init ProcessGroup with gloo backend compatible with + Stateless init ProcessGroup with gloo backend compatible with different torch versions. """ if is_torch_equal_or_newer("2.6"): @@ -441,10 +440,10 @@ def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore, options, ) from torch.distributed.distributed_c10d import ProcessGroupGloo - backend_class = ProcessGroupGloo(prefix_store, - group_rank, - group_size, - timeout=timeout) + + backend_class = ProcessGroupGloo( + prefix_store, group_rank, group_size, timeout=timeout + ) backend_type = ProcessGroup.BackendType.GLOO device = torch.device("cpu") if is_torch_equal_or_newer("2.6"): @@ -457,8 +456,8 @@ def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore, def stateless_init_torch_distributed_process_group( - host: str, port: int, rank: int, world_size: int, - backend: str) -> ProcessGroup: + host: str, port: int, rank: int, world_size: int, backend: str +) -> ProcessGroup: """ A replacement for `torch.distributed.init_process_group` that does not pollute the global state. The created ProcessGroup object can be used for @@ -495,7 +494,8 @@ def stateless_init_torch_distributed_process_group( timeout = _get_default_timeout(backend) store, rank, world_size = next( - rendezvous(init_method, rank, world_size, timeout=timeout)) + rendezvous(init_method, rank, world_size, timeout=timeout) + ) store.set_timeout(timeout) group_rank = rank @@ -506,22 +506,25 @@ def stateless_init_torch_distributed_process_group( prefix_store = PrefixStore(init_method, store) if backend == "gloo": - return init_gloo_process_group(backend=backend, - prefix_store=prefix_store, - group_rank=group_rank, - group_size=group_size, - timeout=timeout) + return init_gloo_process_group( + backend=backend, + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout, + ) from vllm.platforms import current_platform + return current_platform.stateless_init_device_torch_dist_pg( backend=backend, prefix_store=prefix_store, group_rank=group_rank, group_size=group_size, - timeout=timeout) + timeout=timeout, + ) -def stateless_destroy_torch_distributed_process_group( - pg: ProcessGroup) -> None: +def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: """ Destroy ProcessGroup returned by stateless_init_torch_distributed_process_group(). @@ -531,6 +534,7 @@ def stateless_destroy_torch_distributed_process_group( else: # Lazy import for non-CUDA backends. 
from torch.distributed.distributed_c10d import _shutdown_backend + _shutdown_backend(pg) _unregister_process_group(pg.group_name) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bf293a4d2aa9..a94ef598f2de 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable import argparse import copy import dataclasses @@ -10,9 +9,19 @@ import sys from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations -from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List, - Literal, Optional, Type, TypeVar, Union, cast, get_args, - get_origin) +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Literal, + Optional, + TypeVar, + Union, + cast, + get_args, + get_origin, +) import huggingface_hub import regex as re @@ -21,17 +30,42 @@ from typing_extensions import TypeIs, deprecated import vllm.envs as envs -from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, - ConfigType, ConvertOption, DetailedTraceModules, - Device, DeviceConfig, DistributedExecutorBackend, - EPLBConfig, HfOverrides, KVEventsConfig, - KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, - ModelDType, ObservabilityConfig, ParallelConfig, - PoolerConfig, PrefixCachingHashAlgo, RunnerOption, - SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - StructuredOutputsConfig, TaskOption, TokenizerMode, - VllmConfig, get_attr_docs) +from vllm.config import ( + BlockSize, + CacheConfig, + CacheDType, + CompilationConfig, + ConfigType, + ConvertOption, + DetailedTraceModules, + Device, + DeviceConfig, + DistributedExecutorBackend, + EPLBConfig, + HfOverrides, + KVEventsConfig, + KVTransferConfig, + LoadConfig, + LogprobsMode, + LoRAConfig, + MambaDType, + MMEncoderTPMode, + ModelConfig, + ModelDType, + ObservabilityConfig, + ParallelConfig, + PoolerConfig, + PrefixCachingHashAlgo, + RunnerOption, + SchedulerConfig, + SchedulerPolicy, + SpeculativeConfig, + StructuredOutputsConfig, + TaskOption, + TokenizerMode, + VllmConfig, + get_attr_docs, +) from vllm.config.multimodal import MMCacheType, MultiModalConfig from vllm.config.parallel import ExpertPlacementStrategy from vllm.config.utils import get_field @@ -41,15 +75,15 @@ from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 -from vllm.transformers_utils.config import (get_model_path, is_interleaved, - maybe_override_with_speculators) +from vllm.transformers_utils.config import ( + get_model_path, + is_interleaved, + maybe_override_with_speculators, +) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip, - is_in_ray_actor) +from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor from vllm.v1.sample.logits_processor import LogitsProcessor -# yapf: enable - if TYPE_CHECKING: from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization import QuantizationMethods @@ -70,20 +104,18 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: - def _parse_type(val: str) -> T: try: return return_type(val) except ValueError as e: raise argparse.ArgumentTypeError( - f"Value {val} cannot be converted to {return_type}.") from e + 
f"Value {val} cannot be converted to {return_type}." + ) from e return _parse_type -def optional_type( - return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: - +def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: def _optional_type(val: str) -> Optional[T]: if val == "" or val == "None": return None @@ -124,7 +156,8 @@ def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]: if not all(isinstance(option, option_type) for option in options): raise ValueError( "All options must be of the same type. " - f"Got {options} with types {[type(c) for c in options]}") + f"Got {options} with types {[type(c) for c in options]}" + ) kwarg = "metavar" if contains_type(type_hints, str) else "choices" return {"type": option_type, kwarg: sorted(options)} @@ -191,8 +224,9 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name] = {"default": default, "help": help} # Set other kwargs based on the type hints - json_tip = ("Should either be a valid JSON string or JSON keys passed " - "individually.") + json_tip = ( + "Should either be a valid JSON string or JSON keys passed individually." + ) if dataclass_cls is not None: def parse_dataclass(val: str, cls=dataclass_cls) -> Any: @@ -214,7 +248,8 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: tuple_type = types[0] assert all(t is tuple_type for t in types if t is not Ellipsis), ( "All non-Ellipsis tuple elements must be of the same " - f"type. Got {types}.") + f"type. Got {types}." + ) kwargs[name]["type"] = tuple_type kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) elif contains_type(type_hints, list): @@ -240,19 +275,20 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}" elif contains_type(type_hints, float): kwargs[name]["type"] = float - elif (contains_type(type_hints, dict) - and (contains_type(type_hints, str) - or any(is_not_builtin(th) for th in type_hints))): + elif contains_type(type_hints, dict) and ( + contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints) + ): kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): kwargs[name]["type"] = parse_type(json.loads) kwargs[name]["help"] += f"\n\n{json_tip}" - elif (contains_type(type_hints, str) - or any(is_not_builtin(th) for th in type_hints)): + elif contains_type(type_hints, str) or any( + is_not_builtin(th) for th in type_hints + ): kwargs[name]["type"] = str else: - raise ValueError( - f"Unsupported type {type_hints} for argument {name}.") + raise ValueError(f"Unsupported type {type_hints} for argument {name}.") # If the type hint was a sequence of literals, use the helper function # to update the type and choices @@ -284,9 +320,9 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: @dataclass class EngineArgs: """Arguments for vLLM engine.""" + model: str = ModelConfig.model - served_model_name: Optional[Union[ - str, List[str]]] = ModelConfig.served_model_name + served_model_name: Optional[Union[str, list[str]]] = ModelConfig.served_model_name tokenizer: Optional[str] = ModelConfig.tokenizer hf_config_path: Optional[str] = ModelConfig.hf_config_path runner: RunnerOption = ModelConfig.runner @@ -297,8 +333,7 @@ class EngineArgs: tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode trust_remote_code: bool = ModelConfig.trust_remote_code allowed_local_media_path: str = ModelConfig.allowed_local_media_path - allowed_media_domains: Optional[ - list[str]] = 
ModelConfig.allowed_media_domains + allowed_media_domains: Optional[list[str]] = ModelConfig.allowed_media_domains download_dir: Optional[str] = LoadConfig.download_dir safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy load_format: Union[str, LoadFormats] = LoadConfig.load_format @@ -307,19 +342,17 @@ class EngineArgs: kv_cache_dtype: CacheDType = CacheConfig.cache_dtype seed: Optional[int] = ModelConfig.seed max_model_len: Optional[int] = ModelConfig.max_model_len - cuda_graph_sizes: list[int] = get_field(SchedulerConfig, - "cuda_graph_sizes") + cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes") # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. - distributed_executor_backend: Optional[Union[ - str, DistributedExecutorBackend, - Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend + distributed_executor_backend: Optional[ + Union[str, DistributedExecutorBackend, type[ExecutorBase]] + ] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size - decode_context_parallel_size: int = \ - ParallelConfig.decode_context_parallel_size + decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: Optional[int] = None data_parallel_start_rank: Optional[int] = None @@ -330,38 +363,37 @@ class EngineArgs: data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_dbo: bool = ParallelConfig.enable_dbo - dbo_decode_token_threshold: int = \ - ParallelConfig.dbo_decode_token_threshold - dbo_prefill_token_threshold: int = \ - ParallelConfig.dbo_prefill_token_threshold + dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold + dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb - expert_placement_strategy: ExpertPlacementStrategy = \ + expert_placement_strategy: ExpertPlacementStrategy = ( ParallelConfig.expert_placement_strategy + ) _api_process_count: int = ParallelConfig._api_process_count _api_process_rank: int = ParallelConfig._api_process_rank num_redundant_experts: int = EPLBConfig.num_redundant_experts eplb_window_size: int = EPLBConfig.window_size eplb_step_interval: int = EPLBConfig.step_interval eplb_log_balancedness: bool = EPLBConfig.log_balancedness - max_parallel_loading_workers: Optional[ - int] = ParallelConfig.max_parallel_loading_workers + max_parallel_loading_workers: Optional[int] = ( + ParallelConfig.max_parallel_loading_workers + ) block_size: Optional[BlockSize] = CacheConfig.block_size enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching - prefix_caching_hash_algo: PrefixCachingHashAlgo = \ + prefix_caching_hash_algo: PrefixCachingHashAlgo = ( CacheConfig.prefix_caching_hash_algo + ) disable_sliding_window: bool = ModelConfig.disable_sliding_window disable_cascade_attn: bool = ModelConfig.disable_cascade_attn swap_space: float = CacheConfig.swap_space cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization kv_cache_memory_bytes: 
Optional[int] = CacheConfig.kv_cache_memory_bytes - max_num_batched_tokens: Optional[ - int] = SchedulerConfig.max_num_batched_tokens + max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills - long_prefill_token_threshold: int = \ - SchedulerConfig.long_prefill_token_threshold + long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs max_logprobs: int = ModelConfig.max_logprobs logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode @@ -376,20 +408,22 @@ class EngineArgs: quantization: Optional[QuantizationMethods] = ModelConfig.quantization enforce_eager: bool = ModelConfig.enforce_eager disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce - limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = \ - get_field(MultiModalConfig, "limit_per_prompt") + limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = get_field( + MultiModalConfig, "limit_per_prompt" + ) interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings - media_io_kwargs: dict[str, dict[str, - Any]] = get_field(MultiModalConfig, - "media_io_kwargs") - mm_processor_kwargs: Optional[Dict[str, Any]] = \ - MultiModalConfig.mm_processor_kwargs + media_io_kwargs: dict[str, dict[str, Any]] = get_field( + MultiModalConfig, "media_io_kwargs" + ) + mm_processor_kwargs: Optional[dict[str, Any]] = MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb - mm_processor_cache_type: Optional[MMCacheType] = \ + mm_processor_cache_type: Optional[MMCacheType] = ( MultiModalConfig.mm_processor_cache_type - mm_shm_cache_max_object_size_mb: int = \ + ) + mm_shm_cache_max_object_size_mb: int = ( MultiModalConfig.mm_shm_cache_max_object_size_mb + ) mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode io_processor_plugin: Optional[str] = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling @@ -399,31 +433,28 @@ class EngineArgs: enable_lora_bias: bool = LoRAConfig.bias_enabled max_loras: int = LoRAConfig.max_loras max_lora_rank: int = LoRAConfig.max_lora_rank - default_mm_loras: Optional[Dict[str, str]] = \ - LoRAConfig.default_mm_loras + default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight - num_gpu_blocks_override: Optional[ - int] = CacheConfig.num_gpu_blocks_override + num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots - model_loader_extra_config: dict = \ - get_field(LoadConfig, "model_loader_extra_config") - ignore_patterns: Optional[Union[str, - List[str]]] = LoadConfig.ignore_patterns + model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") + ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns - enable_chunked_prefill: Optional[ - bool] = SchedulerConfig.enable_chunked_prefill + enable_chunked_prefill: Optional[bool] = 
SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input disable_hybrid_kv_cache_manager: bool = ( - SchedulerConfig.disable_hybrid_kv_cache_manager) + SchedulerConfig.disable_hybrid_kv_cache_manager + ) structured_outputs_config: StructuredOutputsConfig = get_field( - VllmConfig, "structured_outputs_config") + VllmConfig, "structured_outputs_config" + ) reasoning_parser: str = StructuredOutputsConfig.reasoning_parser # Deprecated guided decoding fields guided_decoding_backend: Optional[str] = None @@ -431,25 +462,25 @@ class EngineArgs: guided_decoding_disable_any_whitespace: Optional[bool] = None guided_decoding_disable_additional_properties: Optional[bool] = None - logits_processor_pattern: Optional[ - str] = ModelConfig.logits_processor_pattern + logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern - speculative_config: Optional[Dict[str, Any]] = None + speculative_config: Optional[dict[str, Any]] = None - show_hidden_metrics_for_version: Optional[str] = \ + show_hidden_metrics_for_version: Optional[str] = ( ObservabilityConfig.show_hidden_metrics_for_version - otlp_traces_endpoint: Optional[str] = \ - ObservabilityConfig.otlp_traces_endpoint - collect_detailed_traces: Optional[list[DetailedTraceModules]] = \ + ) + otlp_traces_endpoint: Optional[str] = ObservabilityConfig.otlp_traces_endpoint + collect_detailed_traces: Optional[list[DetailedTraceModules]] = ( ObservabilityConfig.collect_detailed_traces + ) scheduling_policy: SchedulerPolicy = SchedulerConfig.policy - scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls + scheduler_cls: Union[str, type[object]] = SchedulerConfig.scheduler_cls pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config - override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ + override_pooler_config: Optional[Union[dict, PoolerConfig]] = ( ModelConfig.override_pooler_config - compilation_config: CompilationConfig = \ - get_field(VllmConfig, "compilation_config") + ) + compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls @@ -458,8 +489,9 @@ class EngineArgs: generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode - override_generation_config: dict[str, Any] = \ - get_field(ModelConfig, "override_generation_config") + override_generation_config: dict[str, Any] = get_field( + ModelConfig, "override_generation_config" + ) model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype @@ -467,8 +499,7 @@ class EngineArgs: mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype - additional_config: dict[str, Any] = \ - get_field(VllmConfig, "additional_config") + additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location @@ -476,34 +507,36 @@ class EngineArgs: # DEPRECATED enable_multimodal_encoder_data_parallel: bool = False - logits_processors: Optional[list[Union[ - str, type[LogitsProcessor]]]] = ModelConfig.logits_processors + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = ( + ModelConfig.logits_processors + ) """Custom logitproc types""" async_scheduling: bool = 
SchedulerConfig.async_scheduling - kv_sharing_fast_prefill: bool = \ - CacheConfig.kv_sharing_fast_prefill + kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object if isinstance(self.compilation_config, dict): - self.compilation_config = CompilationConfig( - **self.compilation_config) + self.compilation_config = CompilationConfig(**self.compilation_config) if isinstance(self.eplb_config, dict): self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins from vllm.plugins import load_general_plugins + load_general_plugins() # when use hf offline,replace model id to local model path if huggingface_hub.constants.HF_HUB_OFFLINE: model_id = self.model self.model = get_model_path(self.model, self.revision) logger.info( - "HF_HUB_OFFLINE is True, replace model_id [%s] " \ - "to model_path [%s]",model_id, self.model) + "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]", + model_id, + self.model, + ) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -515,86 +548,92 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: title="ModelConfig", description=ModelConfig.__doc__, ) - if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]): + if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]): model_group.add_argument("--model", **model_kwargs["model"]) model_group.add_argument("--runner", **model_kwargs["runner"]) model_group.add_argument("--convert", **model_kwargs["convert"]) - model_group.add_argument("--task", - **model_kwargs["task"], - deprecated=True) + model_group.add_argument("--task", **model_kwargs["task"], deprecated=True) model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) - model_group.add_argument("--tokenizer-mode", - **model_kwargs["tokenizer_mode"]) - model_group.add_argument("--trust-remote-code", - **model_kwargs["trust_remote_code"]) + model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"]) + model_group.add_argument( + "--trust-remote-code", **model_kwargs["trust_remote_code"] + ) model_group.add_argument("--dtype", **model_kwargs["dtype"]) model_group.add_argument("--seed", **model_kwargs["seed"]) - model_group.add_argument("--hf-config-path", - **model_kwargs["hf_config_path"]) - model_group.add_argument("--allowed-local-media-path", - **model_kwargs["allowed_local_media_path"]) - model_group.add_argument("--allowed-media-domains", - **model_kwargs["allowed_media_domains"]) + model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"]) + model_group.add_argument( + "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"] + ) + model_group.add_argument( + "--allowed-media-domains", **model_kwargs["allowed_media_domains"] + ) model_group.add_argument("--revision", **model_kwargs["revision"]) - model_group.add_argument("--code-revision", - **model_kwargs["code_revision"]) - model_group.add_argument("--rope-scaling", - **model_kwargs["rope_scaling"]) + model_group.add_argument("--code-revision", **model_kwargs["code_revision"]) + model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"]) model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"]) - model_group.add_argument("--tokenizer-revision", - **model_kwargs["tokenizer_revision"]) - model_group.add_argument("--max-model-len", - **model_kwargs["max_model_len"]) - 
model_group.add_argument("--quantization", "-q", - **model_kwargs["quantization"]) - model_group.add_argument("--enforce-eager", - **model_kwargs["enforce_eager"]) - model_group.add_argument("--max-logprobs", - **model_kwargs["max_logprobs"]) - model_group.add_argument("--logprobs-mode", - **model_kwargs["logprobs_mode"]) - model_group.add_argument("--disable-sliding-window", - **model_kwargs["disable_sliding_window"]) - model_group.add_argument("--disable-cascade-attn", - **model_kwargs["disable_cascade_attn"]) - model_group.add_argument("--skip-tokenizer-init", - **model_kwargs["skip_tokenizer_init"]) - model_group.add_argument("--enable-prompt-embeds", - **model_kwargs["enable_prompt_embeds"]) - model_group.add_argument("--served-model-name", - **model_kwargs["served_model_name"]) - model_group.add_argument("--config-format", - **model_kwargs["config_format"]) + model_group.add_argument( + "--tokenizer-revision", **model_kwargs["tokenizer_revision"] + ) + model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"]) + model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"]) + model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"]) + model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) + model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"]) + model_group.add_argument( + "--disable-sliding-window", **model_kwargs["disable_sliding_window"] + ) + model_group.add_argument( + "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"] + ) + model_group.add_argument( + "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"] + ) + model_group.add_argument( + "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"] + ) + model_group.add_argument( + "--served-model-name", **model_kwargs["served_model_name"] + ) + model_group.add_argument("--config-format", **model_kwargs["config_format"]) # This one is a special case because it can bool # or str. 
TODO: Handle this in get_kwargs - model_group.add_argument("--hf-token", - type=str, - nargs="?", - const=True, - default=model_kwargs["hf_token"]["default"], - help=model_kwargs["hf_token"]["help"]) - model_group.add_argument("--hf-overrides", - **model_kwargs["hf_overrides"]) - model_group.add_argument("--pooler-config", - **model_kwargs["pooler_config"]) - model_group.add_argument("--override-pooler-config", - **model_kwargs["override_pooler_config"], - deprecated=True) - model_group.add_argument("--logits-processor-pattern", - **model_kwargs["logits_processor_pattern"]) - model_group.add_argument("--generation-config", - **model_kwargs["generation_config"]) - model_group.add_argument("--override-generation-config", - **model_kwargs["override_generation_config"]) - model_group.add_argument("--enable-sleep-mode", - **model_kwargs["enable_sleep_mode"]) + model_group.add_argument( + "--hf-token", + type=str, + nargs="?", + const=True, + default=model_kwargs["hf_token"]["default"], + help=model_kwargs["hf_token"]["help"], + ) + model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) + model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"]) + model_group.add_argument( + "--override-pooler-config", + **model_kwargs["override_pooler_config"], + deprecated=True, + ) + model_group.add_argument( + "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"] + ) + model_group.add_argument( + "--generation-config", **model_kwargs["generation_config"] + ) + model_group.add_argument( + "--override-generation-config", **model_kwargs["override_generation_config"] + ) + model_group.add_argument( + "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"] + ) model_group.add_argument("--model-impl", **model_kwargs["model_impl"]) - model_group.add_argument("--override-attention-dtype", - **model_kwargs["override_attention_dtype"]) - model_group.add_argument("--logits-processors", - **model_kwargs["logits_processors"]) - model_group.add_argument("--io-processor-plugin", - **model_kwargs["io_processor_plugin"]) + model_group.add_argument( + "--override-attention-dtype", **model_kwargs["override_attention_dtype"] + ) + model_group.add_argument( + "--logits-processors", **model_kwargs["logits_processors"] + ) + model_group.add_argument( + "--io-processor-plugin", **model_kwargs["io_processor_plugin"] + ) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -603,18 +642,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: description=LoadConfig.__doc__, ) load_group.add_argument("--load-format", **load_kwargs["load_format"]) - load_group.add_argument("--download-dir", - **load_kwargs["download_dir"]) - load_group.add_argument("--safetensors-load-strategy", - **load_kwargs["safetensors_load_strategy"]) - load_group.add_argument("--model-loader-extra-config", - **load_kwargs["model_loader_extra_config"]) - load_group.add_argument("--ignore-patterns", - **load_kwargs["ignore_patterns"]) - load_group.add_argument("--use-tqdm-on-load", - **load_kwargs["use_tqdm_on_load"]) - load_group.add_argument('--pt-load-map-location', - **load_kwargs["pt_load_map_location"]) + load_group.add_argument("--download-dir", **load_kwargs["download_dir"]) + load_group.add_argument( + "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"] + ) + load_group.add_argument( + "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"] + ) + load_group.add_argument("--ignore-patterns", 
**load_kwargs["ignore_patterns"]) + load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"]) + load_group.add_argument( + "--pt-load-map-location", **load_kwargs["pt_load_map_location"] + ) # Structured outputs arguments structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) @@ -626,7 +665,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--reasoning-parser", # This choice is a special case because it's not static choices=list(ReasoningParserManager.reasoning_parsers), - **structured_outputs_kwargs["reasoning_parser"]) + **structured_outputs_kwargs["reasoning_parser"], + ) # Deprecated guided decoding arguments for arg, type in [ ("--guided-decoding-backend", str), @@ -638,7 +678,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: arg, type=type, help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."), - deprecated=True) + deprecated=True, + ) # Parallel arguments parallel_kwargs = get_kwargs(ParallelConfig) @@ -648,111 +689,128 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parallel_group.add_argument( "--distributed-executor-backend", - **parallel_kwargs["distributed_executor_backend"]) + **parallel_kwargs["distributed_executor_backend"], + ) + parallel_group.add_argument( + "--pipeline-parallel-size", + "-pp", + **parallel_kwargs["pipeline_parallel_size"], + ) + parallel_group.add_argument( + "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"] + ) + parallel_group.add_argument( + "--decode-context-parallel-size", + "-dcp", + **parallel_kwargs["decode_context_parallel_size"], + ) + parallel_group.add_argument( + "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] + ) parallel_group.add_argument( - "--pipeline-parallel-size", "-pp", - **parallel_kwargs["pipeline_parallel_size"]) - parallel_group.add_argument("--tensor-parallel-size", "-tp", - **parallel_kwargs["tensor_parallel_size"]) + "--data-parallel-rank", + "-dpn", + type=int, + help="Data parallel rank of this instance. " + "When set, enables external load balancer mode.", + ) parallel_group.add_argument( - "--decode-context-parallel-size", "-dcp", - **parallel_kwargs["decode_context_parallel_size"]) - parallel_group.add_argument("--data-parallel-size", "-dp", - **parallel_kwargs["data_parallel_size"]) + "--data-parallel-start-rank", + "-dpr", + type=int, + help="Starting data parallel rank for secondary nodes.", + ) parallel_group.add_argument( - '--data-parallel-rank', - '-dpn', + "--data-parallel-size-local", + "-dpl", type=int, - help='Data parallel rank of this instance. 
' - 'When set, enables external load balancer mode.') - parallel_group.add_argument('--data-parallel-start-rank', - '-dpr', - type=int, - help='Starting data parallel rank ' - 'for secondary nodes.') - parallel_group.add_argument('--data-parallel-size-local', - '-dpl', - type=int, - help='Number of data parallel replicas ' - 'to run on this node.') - parallel_group.add_argument('--data-parallel-address', - '-dpa', - type=str, - help='Address of data parallel cluster ' - 'head-node.') - parallel_group.add_argument('--data-parallel-rpc-port', - '-dpp', - type=int, - help='Port for data parallel RPC ' - 'communication.') - parallel_group.add_argument('--data-parallel-backend', - '-dpb', - type=str, - default='mp', - help='Backend for data parallel, either ' - '"mp" or "ray".') + help="Number of data parallel replicas to run on this node.", + ) + parallel_group.add_argument( + "--data-parallel-address", + "-dpa", + type=str, + help="Address of data parallel cluster head-node.", + ) + parallel_group.add_argument( + "--data-parallel-rpc-port", + "-dpp", + type=int, + help="Port for data parallel RPC communication.", + ) parallel_group.add_argument( - "--data-parallel-hybrid-lb", - **parallel_kwargs["data_parallel_hybrid_lb"]) + "--data-parallel-backend", + "-dpb", + type=str, + default="mp", + help='Backend for data parallel, either "mp" or "ray".', + ) + parallel_group.add_argument( + "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"] + ) parallel_group.add_argument( - "--enable-expert-parallel", - **parallel_kwargs["enable_expert_parallel"]) - parallel_group.add_argument("--enable-dbo", - **parallel_kwargs["enable_dbo"]) + "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"] + ) + parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) parallel_group.add_argument( "--dbo-decode-token-threshold", - **parallel_kwargs["dbo_decode_token_threshold"]) + **parallel_kwargs["dbo_decode_token_threshold"], + ) parallel_group.add_argument( "--dbo-prefill-token-threshold", - **parallel_kwargs["dbo_prefill_token_threshold"]) - parallel_group.add_argument("--enable-eplb", - **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--eplb-config", - **parallel_kwargs["eplb_config"]) + **parallel_kwargs["dbo_prefill_token_threshold"], + ) + parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) + parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"]) parallel_group.add_argument( "--expert-placement-strategy", - **parallel_kwargs["expert_placement_strategy"]) + **parallel_kwargs["expert_placement_strategy"], + ) parallel_group.add_argument( "--num-redundant-experts", type=int, - help= - "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--eplb-window-size", type=int, help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.", - deprecated=True) + deprecated=True, + ) parallel_group.add_argument( "--eplb-step-interval", type=int, - help= - "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--eplb-log-balancedness", action=argparse.BooleanOptionalAction, - help= - "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] 
--eplb-log-balancedness will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--max-parallel-loading-workers", - **parallel_kwargs["max_parallel_loading_workers"]) + **parallel_kwargs["max_parallel_loading_workers"], + ) parallel_group.add_argument( - "--ray-workers-use-nsight", - **parallel_kwargs["ray_workers_use_nsight"]) + "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"] + ) parallel_group.add_argument( "--disable-custom-all-reduce", - **parallel_kwargs["disable_custom_all_reduce"]) - parallel_group.add_argument("--worker-cls", - **parallel_kwargs["worker_cls"]) - parallel_group.add_argument("--worker-extension-cls", - **parallel_kwargs["worker_extension_cls"]) + **parallel_kwargs["disable_custom_all_reduce"], + ) + parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"]) + parallel_group.add_argument( + "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"] + ) parallel_group.add_argument( "--enable-multimodal-encoder-data-parallel", action="store_true", - deprecated=True) + deprecated=True, + ) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -761,29 +819,36 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: description=CacheConfig.__doc__, ) cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) - cache_group.add_argument("--gpu-memory-utilization", - **cache_kwargs["gpu_memory_utilization"]) - cache_group.add_argument("--kv-cache-memory-bytes", - **cache_kwargs["kv_cache_memory_bytes"]) + cache_group.add_argument( + "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"] + ) + cache_group.add_argument( + "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"] + ) cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) - cache_group.add_argument("--kv-cache-dtype", - **cache_kwargs["cache_dtype"]) - cache_group.add_argument("--num-gpu-blocks-override", - **cache_kwargs["num_gpu_blocks_override"]) - cache_group.add_argument("--enable-prefix-caching", - **cache_kwargs["enable_prefix_caching"]) - cache_group.add_argument("--prefix-caching-hash-algo", - **cache_kwargs["prefix_caching_hash_algo"]) - cache_group.add_argument("--cpu-offload-gb", - **cache_kwargs["cpu_offload_gb"]) - cache_group.add_argument("--calculate-kv-scales", - **cache_kwargs["calculate_kv_scales"]) - cache_group.add_argument("--kv-sharing-fast-prefill", - **cache_kwargs["kv_sharing_fast_prefill"]) - cache_group.add_argument("--mamba-cache-dtype", - **cache_kwargs["mamba_cache_dtype"]) - cache_group.add_argument("--mamba-ssm-cache-dtype", - **cache_kwargs["mamba_ssm_cache_dtype"]) + cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"]) + cache_group.add_argument( + "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"] + ) + cache_group.add_argument( + "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"] + ) + cache_group.add_argument( + "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"] + ) + cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"]) + cache_group.add_argument( + "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"] + ) + cache_group.add_argument( + "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"] + ) + cache_group.add_argument( + "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"] + ) + cache_group.add_argument( + "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"] + ) # 
Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -791,35 +856,41 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: title="MultiModalConfig", description=MultiModalConfig.__doc__, ) - multimodal_group.add_argument("--limit-mm-per-prompt", - **multimodal_kwargs["limit_per_prompt"]) - multimodal_group.add_argument("--media-io-kwargs", - **multimodal_kwargs["media_io_kwargs"]) multimodal_group.add_argument( - "--mm-processor-kwargs", - **multimodal_kwargs["mm_processor_kwargs"]) + "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"] + ) + multimodal_group.add_argument( + "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"] + ) + multimodal_group.add_argument( + "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"] + ) + multimodal_group.add_argument( + "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"] + ) multimodal_group.add_argument( - "--mm-processor-cache-gb", - **multimodal_kwargs["mm_processor_cache_gb"]) - multimodal_group.add_argument("--disable-mm-preprocessor-cache", - action="store_true", - deprecated=True) + "--disable-mm-preprocessor-cache", action="store_true", deprecated=True + ) multimodal_group.add_argument( - "--mm-processor-cache-type", - **multimodal_kwargs["mm_processor_cache_type"]) + "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"] + ) multimodal_group.add_argument( "--mm-shm-cache-max-object-size-mb", - **multimodal_kwargs["mm_shm_cache_max_object_size_mb"]) + **multimodal_kwargs["mm_shm_cache_max_object_size_mb"], + ) multimodal_group.add_argument( - "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) + "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"] + ) + multimodal_group.add_argument( + "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"] + ) multimodal_group.add_argument( - "--interleave-mm-strings", - **multimodal_kwargs["interleave_mm_strings"]) - multimodal_group.add_argument("--skip-mm-profiling", - **multimodal_kwargs["skip_mm_profiling"]) + "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"] + ) multimodal_group.add_argument( - "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]) + "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] + ) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -830,24 +901,23 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: lora_group.add_argument( "--enable-lora", action=argparse.BooleanOptionalAction, - help="If True, enable handling of LoRA adapters.") - lora_group.add_argument("--enable-lora-bias", - **lora_kwargs["bias_enabled"]) + help="If True, enable handling of LoRA adapters.", + ) + lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"]) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) - lora_group.add_argument("--max-lora-rank", - **lora_kwargs["max_lora_rank"]) - lora_group.add_argument("--lora-extra-vocab-size", - **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) + lora_group.add_argument( + "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"] + ) lora_group.add_argument( "--lora-dtype", **lora_kwargs["lora_dtype"], ) - lora_group.add_argument("--max-cpu-loras", - **lora_kwargs["max_cpu_loras"]) - lora_group.add_argument("--fully-sharded-loras", - **lora_kwargs["fully_sharded_loras"]) - lora_group.add_argument("--default-mm-loras", - 
**lora_kwargs["default_mm_loras"]) + lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"]) + lora_group.add_argument( + "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"] + ) + lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) # Observability arguments observability_kwargs = get_kwargs(ObservabilityConfig) @@ -857,21 +927,22 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) observability_group.add_argument( "--show-hidden-metrics-for-version", - **observability_kwargs["show_hidden_metrics_for_version"]) + **observability_kwargs["show_hidden_metrics_for_version"], + ) observability_group.add_argument( - "--otlp-traces-endpoint", - **observability_kwargs["otlp_traces_endpoint"]) + "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"] + ) # TODO: generalise this special case choices = observability_kwargs["collect_detailed_traces"]["choices"] metavar = f"{{{','.join(choices)}}}" observability_kwargs["collect_detailed_traces"]["metavar"] = metavar observability_kwargs["collect_detailed_traces"]["choices"] += [ - ",".join(p) - for p in permutations(get_args(DetailedTraceModules), r=2) + ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2) ] observability_group.add_argument( "--collect-detailed-traces", - **observability_kwargs["collect_detailed_traces"]) + **observability_kwargs["collect_detailed_traces"], + ) # Scheduler arguments scheduler_kwargs = get_kwargs(SchedulerConfig) @@ -880,40 +951,49 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: description=SchedulerConfig.__doc__, ) scheduler_group.add_argument( - "--max-num-batched-tokens", - **scheduler_kwargs["max_num_batched_tokens"]) - scheduler_group.add_argument("--max-num-seqs", - **scheduler_kwargs["max_num_seqs"]) + "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"] + ) scheduler_group.add_argument( - "--max-num-partial-prefills", - **scheduler_kwargs["max_num_partial_prefills"]) + "--max-num-seqs", **scheduler_kwargs["max_num_seqs"] + ) + scheduler_group.add_argument( + "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"] + ) scheduler_group.add_argument( "--max-long-partial-prefills", - **scheduler_kwargs["max_long_partial_prefills"]) - scheduler_group.add_argument('--cuda-graph-sizes', - **scheduler_kwargs["cuda_graph_sizes"]) + **scheduler_kwargs["max_long_partial_prefills"], + ) + scheduler_group.add_argument( + "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"] + ) scheduler_group.add_argument( "--long-prefill-token-threshold", - **scheduler_kwargs["long_prefill_token_threshold"]) - scheduler_group.add_argument("--num-lookahead-slots", - **scheduler_kwargs["num_lookahead_slots"]) + **scheduler_kwargs["long_prefill_token_threshold"], + ) + scheduler_group.add_argument( + "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"] + ) # multi-step scheduling has been removed; corresponding arguments # are no longer supported. 
- scheduler_group.add_argument("--scheduling-policy", - **scheduler_kwargs["policy"]) scheduler_group.add_argument( - "--enable-chunked-prefill", - **scheduler_kwargs["enable_chunked_prefill"]) + "--scheduling-policy", **scheduler_kwargs["policy"] + ) scheduler_group.add_argument( - "--disable-chunked-mm-input", - **scheduler_kwargs["disable_chunked_mm_input"]) - scheduler_group.add_argument("--scheduler-cls", - **scheduler_kwargs["scheduler_cls"]) + "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"] + ) + scheduler_group.add_argument( + "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"] + ) + scheduler_group.add_argument( + "--scheduler-cls", **scheduler_kwargs["scheduler_cls"] + ) scheduler_group.add_argument( "--disable-hybrid-kv-cache-manager", - **scheduler_kwargs["disable_hybrid_kv_cache_manager"]) - scheduler_group.add_argument("--async-scheduling", - **scheduler_kwargs["async_scheduling"]) + **scheduler_kwargs["disable_hybrid_kv_cache_manager"], + ) + scheduler_group.add_argument( + "--async-scheduling", **scheduler_kwargs["async_scheduling"] + ) # vLLM arguments vllm_kwargs = get_kwargs(VllmConfig) @@ -925,23 +1005,29 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: # create_engine_config. So we set the type to a JSON string here to # delay the Pydantic validation that comes with SpeculativeConfig. vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads) - vllm_group.add_argument("--speculative-config", - **vllm_kwargs["speculative_config"]) - vllm_group.add_argument("--kv-transfer-config", - **vllm_kwargs["kv_transfer_config"]) - vllm_group.add_argument('--kv-events-config', - **vllm_kwargs["kv_events_config"]) - vllm_group.add_argument("--compilation-config", "-O", - **vllm_kwargs["compilation_config"]) - vllm_group.add_argument("--additional-config", - **vllm_kwargs["additional_config"]) - vllm_group.add_argument('--structured-outputs-config', - **vllm_kwargs["structured_outputs_config"]) + vllm_group.add_argument( + "--speculative-config", **vllm_kwargs["speculative_config"] + ) + vllm_group.add_argument( + "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"] + ) + vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument( + "--compilation-config", "-O", **vllm_kwargs["compilation_config"] + ) + vllm_group.add_argument( + "--additional-config", **vllm_kwargs["additional_config"] + ) + vllm_group.add_argument( + "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] + ) # Other arguments - parser.add_argument('--disable-log-stats', - action='store_true', - help='Disable logging statistics.') + parser.add_argument( + "--disable-log-stats", + action="store_true", + help="Disable logging statistics.", + ) return parser @@ -950,10 +1036,9 @@ def from_cli_args(cls, args: argparse.Namespace): # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. 
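Editor's note: the comprehension on the next lines builds the engine-args dataclass from the parsed `argparse.Namespace` by matching field names. A self-contained sketch of the same pattern, under assumed names (`ArgsSketch` is hypothetical, not the real `EngineArgs`):

```python
import argparse
import dataclasses
from dataclasses import dataclass


@dataclass
class ArgsSketch:
    model: str = "demo-model"
    max_num_seqs: int = 256

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "ArgsSketch":
        # Keep only namespace attributes whose names match dataclass fields,
        # mirroring the dict comprehension shown just below in the diff.
        attrs = [f.name for f in dataclasses.fields(cls)]
        return cls(**{a: getattr(args, a) for a in attrs if hasattr(args, a)})


ns = argparse.Namespace(model="demo", max_num_seqs=8, unrelated=True)
assert ArgsSketch.from_cli_args(ns) == ArgsSketch(model="demo", max_num_seqs=8)
```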
- engine_args = cls(**{ - attr: getattr(args, attr) - for attr in attrs if hasattr(args, attr) - }) + engine_args = cls( + **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)} + ) return engine_args def create_model_config(self) -> ModelConfig: @@ -962,15 +1047,20 @@ def create_model_config(self) -> ModelConfig: self.quantization = self.load_format = "gguf" # NOTE: This is to allow model loading from S3 in CI - if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 - and self.model in MODELS_ON_S3 and self.load_format == "auto"): + if ( + not isinstance(self, AsyncEngineArgs) + and envs.VLLM_CI_USE_S3 + and self.model in MODELS_ON_S3 + and self.load_format == "auto" + ): self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" if self.disable_mm_preprocessor_cache: logger.warning( "`--disable-mm-preprocessor-cache` is deprecated " "and will be removed in v0.13. " - "Please use `--mm-processor-cache-gb 0` instead.", ) + "Please use `--mm-processor-cache-gb 0` instead.", + ) self.mm_processor_cache_gb = 0 elif envs.VLLM_MM_INPUT_CACHE_GIB != 4: @@ -987,7 +1077,8 @@ def create_model_config(self) -> ModelConfig: logger.warning( "--enable-multimodal-encoder-data-parallel` is deprecated " "and will be removed in v0.13. " - "Please use `--mm-encoder-tp-mode data` instead.") + "Please use `--mm-encoder-tp-mode data` instead." + ) self.mm_encoder_tp_mode = "data" @@ -1029,8 +1120,7 @@ def create_model_config(self) -> ModelConfig: mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, mm_processor_cache_type=self.mm_processor_cache_type, - mm_shm_cache_max_object_size_mb=self. - mm_shm_cache_max_object_size_mb, + mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, pooler_config=self.pooler_config, override_pooler_config=self.override_pooler_config, @@ -1046,33 +1136,34 @@ def create_model_config(self) -> ModelConfig: ) def validate_tensorizer_args(self): - from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig) + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + for key in self.model_loader_extra_config: if key in TensorizerConfig._fields: - self.model_loader_extra_config["tensorizer_config"][ - key] = self.model_loader_extra_config[key] + self.model_loader_extra_config["tensorizer_config"][key] = ( + self.model_loader_extra_config[key] + ) def create_load_config(self) -> LoadConfig: - if self.quantization == "bitsandbytes": self.load_format = "bitsandbytes" if self.load_format == "tensorizer": if hasattr(self.model_loader_extra_config, "to_serializable"): self.model_loader_extra_config = ( - self.model_loader_extra_config.to_serializable()) + self.model_loader_extra_config.to_serializable() + ) self.model_loader_extra_config["tensorizer_config"] = {} - self.model_loader_extra_config["tensorizer_config"][ - "tensorizer_dir"] = self.model + self.model_loader_extra_config["tensorizer_config"]["tensorizer_dir"] = ( + self.model + ) self.validate_tensorizer_args() return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, safetensors_load_strategy=self.safetensors_load_strategy, - device="cpu" - if is_online_quantization(self.quantization) else None, + device="cpu" if is_online_quantization(self.quantization) else None, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, use_tqdm_on_load=self.use_tqdm_on_load, @@ -1100,12 +1191,14 @@ def create_speculative_config( 
# Note(Shangming): These parameters are not obtained from the cli arg # '--speculative-config' and must be passed in when creating the engine # config. - self.speculative_config.update({ - "target_model_config": target_model_config, - "target_parallel_config": target_parallel_config, - "enable_chunked_prefill": enable_chunked_prefill, - "disable_log_stats": disable_log_stats, - }) + self.speculative_config.update( + { + "target_model_config": target_model_config, + "target_parallel_config": target_parallel_config, + "enable_chunked_prefill": enable_chunked_prefill, + "disable_log_stats": disable_log_stats, + } + ) return SpeculativeConfig(**self.speculative_config) def create_engine_config( @@ -1128,21 +1221,21 @@ def create_engine_config( """ current_platform.pre_register_and_update() - device_config = DeviceConfig( - device=cast(Device, current_platform.device_type)) + device_config = DeviceConfig(device=cast(Device, current_platform.device_type)) model_config = self.create_model_config() self.model = model_config.model self.tokenizer = model_config.tokenizer - (self.model, self.tokenizer, - self.speculative_config) = maybe_override_with_speculators( - model=self.model, - tokenizer=self.tokenizer, - revision=self.revision, - trust_remote_code=self.trust_remote_code, - vllm_speculative_config=self.speculative_config, - ) + (self.model, self.tokenizer, self.speculative_config) = ( + maybe_override_with_speculators( + model=self.model, + tokenizer=self.tokenizer, + revision=self.revision, + trust_remote_code=self.trust_remote_code, + vllm_speculative_config=self.speculative_config, + ) + ) # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" # and fall back to V0 for experimental or unsupported features. @@ -1164,12 +1257,17 @@ def create_engine_config( # Set default arguments for V1 Engine. self._set_default_args(usage_context, model_config) # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 - if current_platform.is_cpu() and current_platform.get_cpu_architecture( - ) in (CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM, - CpuArchEnum.RISCV): - logger.info("Chunked prefill is not supported for ARM and POWER, " - "S390X and RISC-V CPUs; " - "disabling it for V1 backend.") + if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( + CpuArchEnum.POWERPC, + CpuArchEnum.S390X, + CpuArchEnum.ARM, + CpuArchEnum.RISCV, + ): + logger.info( + "Chunked prefill is not supported for ARM and POWER, " + "S390X and RISC-V CPUs; " + "disabling it for V1 backend." + ) self.enable_chunked_prefill = False assert self.enable_chunked_prefill is not None @@ -1185,8 +1283,7 @@ def create_engine_config( # because the world size does not change by dcp, it simply # reuses the GPUs of TP group, and split one TP group into # tp_size//dcp_size DCP groups. - assert self.tensor_parallel_size % self.decode_context_parallel_size \ - == 0, ( + assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, ( f"tp_size={self.tensor_parallel_size} must be divisible by" f"dcp_size={self.decode_context_parallel_size}." ) @@ -1215,6 +1312,7 @@ def create_engine_config( # of a Ray task, therefore we check is_ray_initialized() # as opposed to is_in_ray_actor(). 
import ray + ray_runtime_env = ray.get_runtime_context().runtime_env logger.info("Using ray runtime env: %s", ray_runtime_env) @@ -1230,15 +1328,15 @@ def create_engine_config( placement_group = ray.util.get_current_placement_group() assert not headless or not self.data_parallel_hybrid_lb, ( - "data_parallel_hybrid_lb is not applicable in " - "headless mode") + "data_parallel_hybrid_lb is not applicable in headless mode" + ) data_parallel_external_lb = self.data_parallel_rank is not None # Local DP rank = 1, use pure-external LB. if data_parallel_external_lb: assert self.data_parallel_size_local in (1, None), ( - "data_parallel_size_local must be 1 when data_parallel_rank " - "is set") + "data_parallel_size_local must be 1 when data_parallel_rank is set" + ) data_parallel_size_local = 1 # Use full external lb if we have local_size of 1. self.data_parallel_hybrid_lb = False @@ -1261,8 +1359,8 @@ def create_engine_config( self.data_parallel_rank = self.data_parallel_start_rank or 0 else: assert not self.data_parallel_hybrid_lb, ( - "data_parallel_size_local must be set to use " - "data_parallel_hybrid_lb.") + "data_parallel_size_local must be set to use data_parallel_hybrid_lb." + ) # Local DP size defaults to global DP size if not set. data_parallel_size_local = self.data_parallel_size @@ -1273,39 +1371,46 @@ def create_engine_config( if self.data_parallel_backend == "ray": host_ip = get_ip() logger.info( - "Using host IP %s as ray-based data parallel address", - host_ip) + "Using host IP %s as ray-based data parallel address", host_ip + ) data_parallel_address = host_ip else: assert self.data_parallel_backend == "mp", ( "data_parallel_backend can only be ray or mp, got %s", - self.data_parallel_backend) + self.data_parallel_backend, + ) data_parallel_address = ParallelConfig.data_parallel_master_ip else: data_parallel_address = self.data_parallel_address # This port is only used when there are remote data parallel engines, # otherwise the local IPC transport is used. - data_parallel_rpc_port = self.data_parallel_rpc_port if ( + data_parallel_rpc_port = ( self.data_parallel_rpc_port - is not None) else ParallelConfig.data_parallel_rpc_port + if (self.data_parallel_rpc_port is not None) + else ParallelConfig.data_parallel_rpc_port + ) if self.async_scheduling: # Async scheduling does not work with the uniprocess backend. if self.distributed_executor_backend is None: self.distributed_executor_backend = "mp" - logger.info("Defaulting to mp-based distributed executor " - "backend for async scheduling.") + logger.info( + "Defaulting to mp-based distributed executor " + "backend for async scheduling." + ) if self.pipeline_parallel_size > 1: - raise ValueError("Async scheduling is not supported with " - "pipeline-parallel-size > 1.") + raise ValueError( + "Async scheduling is not supported with pipeline-parallel-size > 1." + ) # Currently, async scheduling does not support speculative decoding. # TODO(woosuk): Support it. if self.speculative_config is not None: raise ValueError( "Currently, speculative decoding is not supported with " - "async scheduling.") + "async scheduling." + ) # Forward the deprecated CLI args to the EPLB config. 
if self.num_redundant_experts is not None: @@ -1372,33 +1477,38 @@ def create_engine_config( disable_chunked_mm_input=self.disable_chunked_mm_input, is_multimodal_model=model_config.is_multimodal_model, is_encoder_decoder=model_config.is_encoder_decoder, - send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER - and parallel_config.use_ray), + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, scheduler_cls=self.scheduler_cls, max_num_partial_prefills=self.max_num_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, - disable_hybrid_kv_cache_manager=self. - disable_hybrid_kv_cache_manager, + disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager, async_scheduling=self.async_scheduling, ) if not model_config.is_multimodal_model and self.default_mm_loras: raise ValueError( "Default modality-specific LoRA(s) were provided for a " - "non multimodal model") - - lora_config = LoRAConfig( - bias_enabled=self.enable_lora_bias, - max_lora_rank=self.max_lora_rank, - max_loras=self.max_loras, - default_mm_loras=self.default_mm_loras, - fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, - lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras - and self.max_cpu_loras > 0 else None) if self.enable_lora else None + "non multimodal model" + ) + + lora_config = ( + LoRAConfig( + bias_enabled=self.enable_lora_bias, + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + default_mm_loras=self.default_mm_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras + if self.max_cpu_loras and self.max_cpu_loras > 0 + else None, + ) + if self.enable_lora + else None + ) # bitsandbytes pre-quantized model need a specific model loader if model_config.quantization == "bitsandbytes": @@ -1408,27 +1518,27 @@ def create_engine_config( # Pass reasoning_parser into StructuredOutputsConfig if self.reasoning_parser: - self.structured_outputs_config.reasoning_parser = \ - self.reasoning_parser + self.structured_outputs_config.reasoning_parser = self.reasoning_parser # Forward the deprecated CLI args to the StructuredOutputsConfig so_config = self.structured_outputs_config if self.guided_decoding_backend is not None: - so_config.guided_decoding_backend = \ - self.guided_decoding_backend + so_config.guided_decoding_backend = self.guided_decoding_backend if self.guided_decoding_disable_fallback is not None: - so_config.guided_decoding_disable_fallback = \ - self.guided_decoding_disable_fallback + so_config.guided_decoding_disable_fallback = ( + self.guided_decoding_disable_fallback + ) if self.guided_decoding_disable_any_whitespace is not None: - so_config.guided_decoding_disable_any_whitespace = \ - self.guided_decoding_disable_any_whitespace + so_config.guided_decoding_disable_any_whitespace = ( + self.guided_decoding_disable_any_whitespace + ) if self.guided_decoding_disable_additional_properties is not None: - so_config.guided_decoding_disable_additional_properties = \ - self.guided_decoding_disable_additional_properties + so_config.guided_decoding_disable_additional_properties = ( + self.guided_decoding_disable_additional_properties + ) observability_config = ObservabilityConfig( - show_hidden_metrics_for_version=( - self.show_hidden_metrics_for_version), + 
show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version), otlp_traces_endpoint=self.otlp_traces_endpoint, collect_detailed_traces=self.collect_detailed_traces, ) @@ -1458,25 +1568,28 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: ############################################################# # Unsupported Feature Flags on V1. - if (self.logits_processor_pattern - != EngineArgs.logits_processor_pattern): - _raise_or_fallback(feature_name="--logits-processor-pattern", - recommend_to_remove=False) + if self.logits_processor_pattern != EngineArgs.logits_processor_pattern: + _raise_or_fallback( + feature_name="--logits-processor-pattern", recommend_to_remove=False + ) return False # No Mamba or Encoder-Decoder so far. if not model_config.is_v1_compatible: - _raise_or_fallback(feature_name=model_config.architectures, - recommend_to_remove=False) + _raise_or_fallback( + feature_name=model_config.architectures, recommend_to_remove=False + ) return False # No Concurrent Partial Prefills so far. - if (self.max_num_partial_prefills - != SchedulerConfig.max_num_partial_prefills - or self.max_long_partial_prefills - != SchedulerConfig.max_long_partial_prefills): - _raise_or_fallback(feature_name="Concurrent Partial Prefill", - recommend_to_remove=False) + if ( + self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills + or self.max_long_partial_prefills + != SchedulerConfig.max_long_partial_prefills + ): + _raise_or_fallback( + feature_name="Concurrent Partial Prefill", recommend_to_remove=False + ) return False # V1 supports N-gram, Medusa, and Eagle speculative decoding. @@ -1491,7 +1604,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: raise NotImplementedError( "Draft model speculative decoding is not supported yet. " "Please consider using other speculative decoding methods " - "such as ngram, medusa, eagle, or mtp.") + "such as ngram, medusa, eagle, or mtp." + ) V1_BACKENDS = [ "FLASH_ATTN", @@ -1510,8 +1624,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "XFORMERS", "ROCM_ATTN", ] - if (envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): + if ( + envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS + ): name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" _raise_or_fallback(feature_name=name, recommend_to_remove=True) return False @@ -1520,30 +1636,36 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # Experimental Features - allow users to opt in. 
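Editor's note: the V1-support checks above repeatedly call `_raise_or_fallback(...)` and return `False` to fall back to V0. A simplified standalone sketch of that raise-or-fallback behavior (the real helper, shown later in this file, reads the env via `vllm.envs` and uses slightly different wording):

```python
import logging
import os

logger = logging.getLogger(__name__)


def raise_or_fallback_sketch(feature_name: str, recommend_to_remove: bool) -> None:
    # If the user explicitly opted into V1 via the environment, fail loudly;
    # otherwise warn and let the caller return False to fall back to V0.
    if os.environ.get("VLLM_USE_V1") == "1":
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}."
        )
    msg = f"{feature_name} is not supported by the V1 Engine. Falling back to V0. "
    if recommend_to_remove:
        msg += f"Consider removing {feature_name} from your config."
    logger.warning(msg)


raise_or_fallback_sketch("--logits-processor-pattern", recommend_to_remove=False)
```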
if self.pipeline_parallel_size > 1: - supports_pp = getattr(self.distributed_executor_backend, - 'supports_pp', False) + supports_pp = getattr( + self.distributed_executor_backend, "supports_pp", False + ) if not supports_pp and self.distributed_executor_backend not in ( - ParallelConfig.distributed_executor_backend, "ray", "mp", - "external_launcher"): - name = "Pipeline Parallelism without Ray distributed " \ - "executor or multiprocessing executor or external " \ - "launcher" - _raise_or_fallback(feature_name=name, - recommend_to_remove=False) + ParallelConfig.distributed_executor_backend, + "ray", + "mp", + "external_launcher", + ): + name = ( + "Pipeline Parallelism without Ray distributed " + "executor or multiprocessing executor or external " + "launcher" + ) + _raise_or_fallback(feature_name=name, recommend_to_remove=False) return False - if (current_platform.is_cpu() - and model_config.get_sliding_window() is not None): - _raise_or_fallback(feature_name="sliding window (CPU backend)", - recommend_to_remove=False) + if current_platform.is_cpu() and model_config.get_sliding_window() is not None: + _raise_or_fallback( + feature_name="sliding window (CPU backend)", recommend_to_remove=False + ) return False ############################################################# return True - def _set_default_args(self, usage_context: UsageContext, - model_config: ModelConfig) -> None: + def _set_default_args( + self, usage_context: UsageContext, model_config: ModelConfig + ) -> None: """Set Default Arguments for V1 Engine.""" # V1 always uses chunked prefills and prefix caching @@ -1554,26 +1676,31 @@ def _set_default_args(self, usage_context: UsageContext, # TODO: When prefix caching supports prompt embeds inputs, this # check can be removed. - if (self.enable_prompt_embeds - and self.enable_prefix_caching is not False): + if self.enable_prompt_embeds and self.enable_prefix_caching is not False: logger.warning( "--enable-prompt-embeds and --enable-prefix-caching " "are not supported together in V1. Prefix caching has " - "been disabled.") + "been disabled." + ) self.enable_prefix_caching = False if self.enable_prefix_caching is None: - self.enable_prefix_caching = True + # Disable prefix caching default for hybrid models + # since the feature is still experimental. + if model_config.is_hybrid: + self.enable_prefix_caching = False + else: + self.enable_prefix_caching = True else: - pooling_type = model_config.pooler_config.pooling_type is_causal = getattr(model_config.hf_config, "is_causal", True) - incremental_prefill_supported = (pooling_type is not None - and pooling_type.lower() == "last" - and is_causal) + incremental_prefill_supported = ( + pooling_type is not None + and pooling_type.lower() == "last" + and is_causal + ) - action = "Enabling" if \ - incremental_prefill_supported else "Disabling" + action = "Enabling" if incremental_prefill_supported else "Disabling" if self.enable_chunked_prefill is None: self.enable_chunked_prefill = incremental_prefill_supported @@ -1607,6 +1734,7 @@ def _set_default_args(self, usage_context: UsageContext, # throughput, see PR #17885 for more details. # So here we do an extra device name check to prevent such regression. from vllm.usage.usage_lib import UsageContext + if device_memory >= 70 * GiB_bytes and "a100" not in device_name: # For GPUs like H100 and MI300x, use larger default values. 
default_max_num_batched_tokens = { @@ -1632,15 +1760,15 @@ def _set_default_args(self, usage_context: UsageContext, if current_platform.is_tpu(): default_max_num_batched_tokens_tpu = { UsageContext.LLM_CLASS: { - 'V6E': 2048, - 'V5E': 1024, - 'V5P': 512, + "V6E": 2048, + "V5E": 1024, + "V5P": 512, }, UsageContext.OPENAI_API_SERVER: { - 'V6E': 1024, - 'V5E': 512, - 'V5P': 256, - } + "V6E": 1024, + "V5E": 512, + "V5P": 256, + }, } # cpu specific default values. @@ -1656,47 +1784,58 @@ def _set_default_args(self, usage_context: UsageContext, } use_context_value = usage_context.value if usage_context else None - if (self.max_num_batched_tokens is None - and usage_context in default_max_num_batched_tokens): + if ( + self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens + ): if current_platform.is_tpu(): chip_name = current_platform.get_device_name() - if chip_name in default_max_num_batched_tokens_tpu[ - usage_context]: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens_tpu[ - usage_context][chip_name] + if chip_name in default_max_num_batched_tokens_tpu[usage_context]: + self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[ + usage_context + ][chip_name] else: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens[usage_context] + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context + ] else: if not self.enable_chunked_prefill: self.max_num_batched_tokens = model_config.max_model_len else: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens[usage_context] + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context + ] logger.debug( "Setting max_num_batched_tokens to %d for %s usage context.", - self.max_num_batched_tokens, use_context_value) + self.max_num_batched_tokens, + use_context_value, + ) - if (self.max_num_seqs is None - and usage_context in default_max_num_seqs): - self.max_num_seqs = min(default_max_num_seqs[usage_context], - self.max_num_batched_tokens or sys.maxsize) + if self.max_num_seqs is None and usage_context in default_max_num_seqs: + self.max_num_seqs = min( + default_max_num_seqs[usage_context], + self.max_num_batched_tokens or sys.maxsize, + ) - logger.debug("Setting max_num_seqs to %d for %s usage context.", - self.max_num_seqs, use_context_value) + logger.debug( + "Setting max_num_seqs to %d for %s usage context.", + self.max_num_seqs, + use_context_value, + ) @dataclass class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" + enable_log_requests: bool = False @property @deprecated( "`disable_log_requests` is deprecated and has been replaced with " "`enable_log_requests`. This will be removed in v0.12.0. Please use " - "`enable_log_requests` instead.") + "`enable_log_requests` instead." + ) def disable_log_requests(self) -> bool: return not self.enable_log_requests @@ -1704,28 +1843,34 @@ def disable_log_requests(self) -> bool: @deprecated( "`disable_log_requests` is deprecated and has been replaced with " "`enable_log_requests`. This will be removed in v0.12.0. Please use " - "`enable_log_requests` instead.") + "`enable_log_requests` instead." 
+ ) def disable_log_requests(self, value: bool): self.enable_log_requests = not value @staticmethod - def add_cli_args(parser: FlexibleArgumentParser, - async_args_only: bool = False) -> FlexibleArgumentParser: + def add_cli_args( + parser: FlexibleArgumentParser, async_args_only: bool = False + ) -> FlexibleArgumentParser: # Initialize plugin to update the parser, for example, The plugin may # add a new kind of quantization method to --quantization argument or # a new device to --device argument. load_general_plugins() if not async_args_only: parser = EngineArgs.add_cli_args(parser) - parser.add_argument('--enable-log-requests', - action=argparse.BooleanOptionalAction, - default=AsyncEngineArgs.enable_log_requests, - help='Enable logging requests.') - parser.add_argument('--disable-log-requests', - action=argparse.BooleanOptionalAction, - default=not AsyncEngineArgs.enable_log_requests, - help='[DEPRECATED] Disable logging requests.', - deprecated=True) + parser.add_argument( + "--enable-log-requests", + action=argparse.BooleanOptionalAction, + default=AsyncEngineArgs.enable_log_requests, + help="Enable logging requests.", + ) + parser.add_argument( + "--disable-log-requests", + action=argparse.BooleanOptionalAction, + default=not AsyncEngineArgs.enable_log_requests, + help="[DEPRECATED] Disable logging requests.", + deprecated=True, + ) current_platform.pre_register_and_update(parser) return parser @@ -1733,7 +1878,8 @@ def add_cli_args(parser: FlexibleArgumentParser, def _raise_or_fallback(feature_name: str, recommend_to_remove: bool): if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: raise NotImplementedError( - f"VLLM_USE_V1=1 is not supported with {feature_name}.") + f"VLLM_USE_V1=1 is not supported with {feature_name}." + ) msg = f"{feature_name} is not supported by the V1 Engine. " msg += "Falling back to V0. " if recommend_to_remove: @@ -1752,17 +1898,17 @@ def human_readable_int(value): - '25.6k' -> 25,600 """ value = value.strip() - match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value) + match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value) if match: decimal_multiplier = { - 'k': 10**3, - 'm': 10**6, - 'g': 10**9, + "k": 10**3, + "m": 10**6, + "g": 10**9, } binary_multiplier = { - 'K': 2**10, - 'M': 2**20, - 'G': 2**30, + "K": 2**10, + "M": 2**20, + "G": 2**30, } number, suffix = match.groups() @@ -1775,9 +1921,11 @@ def human_readable_int(value): try: return int(number) * mult except ValueError as e: - raise argparse.ArgumentTypeError("Decimals are not allowed " \ - f"with binary suffixes like {suffix}. Did you mean to use " \ - f"{number}{suffix.lower()} instead?") from e + raise argparse.ArgumentTypeError( + "Decimals are not allowed " + f"with binary suffixes like {suffix}. Did you mean to use " + f"{number}{suffix.lower()} instead?" + ) from e # Regular plain number. 
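Editor's note: `human_readable_int` above accepts lowercase decimal suffixes (k/m/g = 10^3/10^6/10^9) and uppercase binary suffixes (K/M/G = 2^10/2^20/2^30), rejecting decimals combined with binary suffixes; plain numbers fall through to `int()` below. A self-contained sketch of those rules (simplified — t/T handling omitted, and the exact decimal path is assumed, not copied):

```python
import argparse
import re


def human_readable_int_sketch(value: str) -> int:
    value = value.strip()
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgG])", value)
    if match:
        number, suffix = match.groups()
        mult = {"k": 10**3, "m": 10**6, "g": 10**9,
                "K": 2**10, "M": 2**20, "G": 2**30}[suffix]
        if suffix.islower():
            # Decimal suffixes may carry a fractional part, e.g. "25.6k".
            return int(float(number) * mult)
        try:
            return int(number) * mult
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Decimals are not allowed with binary suffixes like {suffix}."
            ) from e
    # Regular plain number.
    return int(value)


assert human_readable_int_sketch("25.6k") == 25_600
assert human_readable_int_sketch("1M") == 2**20
```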
return int(value) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 2762175c430f..45b798ed96cb 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from typing import Counter as CollectionsCounter -from typing import Dict, List, Optional, Type, Union, cast +from collections import Counter as CollectionsCounter +from typing import Optional, Union, cast import numpy as np import prometheus_client @@ -43,7 +43,7 @@ class Metrics: _counter_cls = prometheus_client.Counter _histogram_cls = prometheus_client.Histogram - def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + def __init__(self, labelnames: list[str], vllm_config: VllmConfig): # Unregister any existing vLLM collectors (for CI/CD) self._unregister_vllm_metrics() @@ -51,8 +51,7 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): # Use this flag to hide metrics that were deprecated in # a previous release and which will be removed future - self.show_hidden_metrics = \ - vllm_config.observability_config.show_hidden_metrics + self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics # System stats # Scheduler State @@ -60,12 +59,14 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): name="vllm:num_requests_running", documentation="Number of requests currently running on GPU.", labelnames=labelnames, - multiprocess_mode="sum") + multiprocess_mode="sum", + ) self.gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", labelnames=labelnames, - multiprocess_mode="sum") + multiprocess_mode="sum", + ) self.gauge_lora_info = self._gauge_cls( name="vllm:lora_requests_info", documentation="Running stats on lora requests.", @@ -82,93 +83,173 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames, - multiprocess_mode="sum") + multiprocess_mode="sum", + ) # Iteration stats self.counter_num_preemption = self._counter_cls( name="vllm:num_preemptions_total", documentation="Cumulative number of preemption from the engine.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_prompt_tokens = self._counter_cls( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_generation_tokens = self._counter_cls( name="vllm:generation_tokens_total", documentation="Number of generation tokens processed.", - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_iteration_tokens = self._histogram_cls( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", labelnames=labelnames, - buckets=[ - 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 - ]) + buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + ) self.histogram_time_to_first_token = self._histogram_cls( name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", labelnames=labelnames, buckets=[ - 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, - 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, - 2560.0 - ]) + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + 160.0, + 640.0, + 2560.0, + ], + ) # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds # TODO: in 0.12, only enable if show_hidden_metrics=True self.histogram_time_per_output_token = self._histogram_cls( name="vllm:time_per_output_token_seconds", documentation=( "Histogram of time per output token in seconds." - "DEPRECATED: Use vllm:inter_token_latency_seconds instead."), + "DEPRECATED: Use vllm:inter_token_latency_seconds instead." 
+ ), labelnames=labelnames, buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 - ]) + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + ], + ) self.histogram_inter_token_latency = self._histogram_cls( name="vllm:inter_token_latency_seconds", documentation="Histogram of inter token latency in seconds.", labelnames=labelnames, buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 - ]) + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + ], + ) # Request stats # Latency request_latency_buckets = [ - 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, - 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + 0.3, + 0.5, + 0.8, + 1.0, + 1.5, + 2.0, + 2.5, + 5.0, + 10.0, + 15.0, + 20.0, + 30.0, + 40.0, + 50.0, + 60.0, + 120.0, + 240.0, + 480.0, + 960.0, + 1920.0, + 7680.0, ] self.histogram_e2e_time_request = self._histogram_cls( name="vllm:e2e_request_latency_seconds", documentation="Histogram of end to end request latency in seconds.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) self.histogram_queue_time_request = self._histogram_cls( name="vllm:request_queue_time_seconds", - documentation= - "Histogram of time spent in WAITING phase for request.", + documentation="Histogram of time spent in WAITING phase for request.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) self.histogram_inference_time_request = self._histogram_cls( name="vllm:request_inference_time_seconds", - documentation= - "Histogram of time spent in RUNNING phase for request.", + documentation="Histogram of time spent in RUNNING phase for request.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) self.histogram_prefill_time_request = self._histogram_cls( name="vllm:request_prefill_time_seconds", - documentation= - "Histogram of time spent in PREFILL phase for request.", + documentation="Histogram of time spent in PREFILL phase for request.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) self.histogram_decode_time_request = self._histogram_cls( name="vllm:request_decode_time_seconds", - documentation= - "Histogram of time spent in DECODE phase for request.", + documentation="Histogram of time spent in DECODE phase for request.", labelnames=labelnames, - buckets=request_latency_buckets) + buckets=request_latency_buckets, + ) # Metadata self.histogram_num_prompt_tokens_request = self._histogram_cls( @@ -177,19 +258,18 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): labelnames=labelnames, buckets=build_1_2_5_buckets(max_model_len), ) - self.histogram_num_generation_tokens_request = \ - self._histogram_cls( - name="vllm:request_generation_tokens", - documentation="Number of generation tokens processed.", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) + self.histogram_num_generation_tokens_request = self._histogram_cls( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) self.histogram_max_num_generation_tokens_request = 
self._histogram_cls( name="vllm:request_max_num_generation_tokens", - documentation= - "Histogram of maximum number of requested generation tokens.", + documentation="Histogram of maximum number of requested generation tokens.", labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len)) + buckets=build_1_2_5_buckets(max_model_len), + ) self.histogram_n_request = self._histogram_cls( name="vllm:request_params_n", documentation="Histogram of the n request parameter.", @@ -205,10 +285,10 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): self.counter_request_success = self._counter_cls( name="vllm:request_success_total", documentation="Count of successfully processed requests.", - labelnames=labelnames + [Metrics.labelname_finish_reason]) - + labelnames=labelnames + [Metrics.labelname_finish_reason], + ) -# --8<-- [end:metrics-definitions] + # --8<-- [end:metrics-definitions] def _unregister_vllm_metrics(self) -> None: for collector in list(prometheus_client.REGISTRY._collector_to_names): @@ -220,16 +300,18 @@ class _RayGaugeWrapper: """Wraps around ray.util.metrics.Gauge to provide same API as prometheus_client.Gauge""" - def __init__(self, - name: str, - documentation: str = "", - labelnames: Optional[List[str]] = None, - multiprocess_mode: str = ""): + def __init__( + self, + name: str, + documentation: str = "", + labelnames: Optional[list[str]] = None, + multiprocess_mode: str = "", + ): del multiprocess_mode labelnames_tuple = tuple(labelnames) if labelnames else None - self._gauge = ray_metrics.Gauge(name=name, - description=documentation, - tag_keys=labelnames_tuple) + self._gauge = ray_metrics.Gauge( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def labels(self, **labels): self._gauge.set_default_tags(labels) @@ -247,14 +329,13 @@ class _RayCounterWrapper: """Wraps around ray.util.metrics.Counter to provide same API as prometheus_client.Counter""" - def __init__(self, - name: str, - documentation: str = "", - labelnames: Optional[List[str]] = None): + def __init__( + self, name: str, documentation: str = "", labelnames: Optional[list[str]] = None + ): labelnames_tuple = tuple(labelnames) if labelnames else None - self._counter = ray_metrics.Counter(name=name, - description=documentation, - tag_keys=labelnames_tuple) + self._counter = ray_metrics.Counter( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def labels(self, **labels): self._counter.set_default_tags(labels) @@ -270,17 +351,21 @@ class _RayHistogramWrapper: """Wraps around ray.util.metrics.Histogram to provide same API as prometheus_client.Histogram""" - def __init__(self, - name: str, - documentation: str = "", - labelnames: Optional[List[str]] = None, - buckets: Optional[List[float]] = None): + def __init__( + self, + name: str, + documentation: str = "", + labelnames: Optional[list[str]] = None, + buckets: Optional[list[float]] = None, + ): labelnames_tuple = tuple(labelnames) if labelnames else None boundaries = buckets if buckets else [] - self._histogram = ray_metrics.Histogram(name=name, - description=documentation, - tag_keys=labelnames_tuple, - boundaries=boundaries) + self._histogram = ray_metrics.Histogram( + name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=boundaries, + ) def labels(self, **labels): self._histogram.set_default_tags(labels) @@ -295,14 +380,18 @@ class RayMetrics(Metrics): RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. 
Provides the same metrics as Metrics but uses Ray's util.metrics library. """ - _gauge_cls: Type[prometheus_client.Gauge] = cast( - Type[prometheus_client.Gauge], _RayGaugeWrapper) - _counter_cls: Type[prometheus_client.Counter] = cast( - Type[prometheus_client.Counter], _RayCounterWrapper) - _histogram_cls: Type[prometheus_client.Histogram] = cast( - Type[prometheus_client.Histogram], _RayHistogramWrapper) - - def __init__(self, labelnames: List[str], vllm_config: VllmConfig): + + _gauge_cls: type[prometheus_client.Gauge] = cast( + type[prometheus_client.Gauge], _RayGaugeWrapper + ) + _counter_cls: type[prometheus_client.Counter] = cast( + type[prometheus_client.Counter], _RayCounterWrapper + ) + _histogram_cls: type[prometheus_client.Histogram] = cast( + type[prometheus_client.Histogram], _RayHistogramWrapper + ) + + def __init__(self, labelnames: list[str], vllm_config: VllmConfig): if ray_metrics is None: raise ImportError("RayMetrics requires Ray to be installed.") super().__init__(labelnames, vllm_config) @@ -312,14 +401,14 @@ def _unregister_vllm_metrics(self) -> None: pass -def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: +def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]: """ Builds a list of buckets with increasing powers of 10 multiplied by mantissa values until the value exceeds the specified maximum. """ exponent = 0 - buckets: List[int] = [] + buckets: list[int] = [] while True: for m in mantissa_lst: value = m * 10**exponent @@ -330,7 +419,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]: exponent += 1 -def build_1_2_5_buckets(max_value: int) -> List[int]: +def build_1_2_5_buckets(max_value: int) -> list[int]: """ Example: >>> build_1_2_5_buckets(100) @@ -339,7 +428,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]: return build_buckets([1, 2, 5], max_value) -def build_1_2_3_5_8_buckets(max_value: int) -> List[int]: +def build_1_2_3_5_8_buckets(max_value: int) -> list[int]: """ Example: >>> build_1_2_3_5_8_buckets(100) @@ -348,14 +437,12 @@ def build_1_2_3_5_8_buckets(max_value: int) -> List[int]: return build_buckets([1, 2, 3, 5, 8], max_value) -def local_interval_elapsed(now: float, last_log: float, - local_interval: float) -> bool: +def local_interval_elapsed(now: float, last_log: float, local_interval: float) -> bool: elapsed_time = now - last_log return elapsed_time > local_interval -def get_throughput(tracked_stats: List[int], now: float, - last_log: float) -> float: +def get_throughput(tracked_stats: list[int], now: float, last_log: float) -> float: return float(np.sum(tracked_stats) / (now - last_log)) @@ -369,29 +456,32 @@ def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: def log(self, stats: Stats) -> None: """Called by LLMEngine. - Logs to Stdout every self.local_interval seconds.""" + Logs to Stdout every self.local_interval seconds.""" # Save tracked stats for token counters. self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter) # Log locally every local_interval seconds. - if local_interval_elapsed(stats.now, self.last_local_log, - self.local_interval): + if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): # Compute summary metrics for tracked stats (and log them # to prometheus if applicable). 
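Editor's note: worked example for the bucket helpers defined above. `build_buckets_sketch` mirrors `build_buckets` (hypothetical name, same logic) and the asserts show the 1-2-5 and 1-2-3-5-8 series up to 100, matching the docstring example.

```python
def build_buckets_sketch(mantissa_lst: list[int], max_value: int) -> list[int]:
    # Multiply each mantissa by increasing powers of 10 until the value
    # exceeds max_value, then stop.
    exponent = 0
    buckets: list[int] = []
    while True:
        for m in mantissa_lst:
            value = m * 10**exponent
            if value <= max_value:
                buckets.append(value)
            else:
                return buckets
        exponent += 1


assert build_buckets_sketch([1, 2, 5], 100) == [1, 2, 5, 10, 20, 50, 100]
assert build_buckets_sketch([1, 2, 3, 5, 8], 100) == [
    1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100,
]
```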
- prompt_throughput = get_throughput(self.num_prompt_tokens, - now=stats.now, - last_log=self.last_local_log) + prompt_throughput = get_throughput( + self.num_prompt_tokens, now=stats.now, last_log=self.last_local_log + ) generation_throughput = get_throughput( - self.num_generation_tokens, - now=stats.now, - last_log=self.last_local_log) + self.num_generation_tokens, now=stats.now, last_log=self.last_local_log + ) log_fn = logger.info - if not any((prompt_throughput, generation_throughput, - self.last_prompt_throughput, - self.last_generation_throughput)): + if not any( + ( + prompt_throughput, + generation_throughput, + self.last_prompt_throughput, + self.last_generation_throughput, + ) + ): # Avoid log noise on an idle production system log_fn = logger.debug @@ -409,8 +499,10 @@ def log(self, stats: Stats) -> None: stats.gpu_cache_usage_sys * 100, stats.cpu_cache_usage_sys * 100, ) - if (stats.cpu_prefix_cache_hit_rate >= 0 - or stats.gpu_prefix_cache_hit_rate >= 0): + if ( + stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0 + ): log_fn( "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", stats.gpu_prefix_cache_hit_rate * 100, @@ -433,16 +525,19 @@ def info(self, type: str, obj: SupportsMetricsInfo) -> None: class PrometheusStatLogger(StatLoggerBase): """PrometheusStatLogger is used LLMEngine to log to Prometheus.""" + _metrics_cls = Metrics _gauge_cls = prometheus_client.Gauge - def __init__(self, local_interval: float, labels: Dict[str, str], - vllm_config: VllmConfig) -> None: + def __init__( + self, local_interval: float, labels: dict[str, str], vllm_config: VllmConfig + ) -> None: super().__init__(local_interval, vllm_config) # Prometheus metrics self.labels = labels - self.metrics = self._metrics_cls(labelnames=list(labels.keys()), - vllm_config=vllm_config) + self.metrics = self._metrics_cls( + labelnames=list(labels.keys()), vllm_config=vllm_config + ) def _log_gauge(self, gauge, data: Union[int, float]) -> None: # Convenience function for logging to gauge. @@ -452,90 +547,106 @@ def _log_counter(self, counter, data: Union[int, float]) -> None: # Convenience function for logging to counter. # Prevent ValueError from negative increment if data < 0: - logger.warning("Skipping negative increment of %g to %s", data, - counter) + logger.warning("Skipping negative increment of %g to %s", data, counter) return counter.labels(**self.labels).inc(data) - def _log_counter_labels(self, counter, data: CollectionsCounter, - label_key: str) -> None: + def _log_counter_labels( + self, counter, data: CollectionsCounter, label_key: str + ) -> None: # Convenience function for collection counter of labels. for label, count in data.items(): counter.labels(**{**self.labels, label_key: label}).inc(count) - def _log_histogram(self, histogram, data: Union[List[int], - List[float]]) -> None: + def _log_histogram(self, histogram, data: Union[list[int], list[float]]) -> None: # Convenience function for logging list to histogram. 
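Editor's note: the `_log_counter` / `_log_histogram` helpers above guard against negative counter increments (which `prometheus_client` rejects with a `ValueError`) and observe each histogram datum in a loop, as shown just below. A minimal standalone sketch of both, using hypothetical metric names:

```python
import logging

import prometheus_client

logger = logging.getLogger(__name__)

counter = prometheus_client.Counter(
    "demo_generation_tokens_total", "Demo token counter.", labelnames=["model_name"]
)
histogram = prometheus_client.Histogram(
    "demo_request_latency_seconds", "Demo latency histogram.",
    labelnames=["model_name"], buckets=[0.1, 1.0, 10.0],
)
labels = {"model_name": "demo"}


def log_counter_sketch(data: float) -> None:
    # Skip negative increments with a warning instead of crashing the logger.
    if data < 0:
        logger.warning("Skipping negative increment of %g", data)
        return
    counter.labels(**labels).inc(data)


def log_histogram_sketch(data: list[float]) -> None:
    # Observe each datum individually.
    for datum in data:
        histogram.labels(**labels).observe(datum)


log_counter_sketch(128)   # recorded
log_counter_sketch(-1)    # skipped with a warning
log_histogram_sketch([0.05, 0.4, 2.5])
```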
for datum in data: histogram.labels(**self.labels).observe(datum) - def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None: + def _log_gauge_string(self, gauge, data: dict[str, str]) -> None: gauge.labels(**data).set_to_current_time() def _log_prometheus(self, stats: Stats) -> None: # System state data - self._log_gauge(self.metrics.gauge_scheduler_running, - stats.num_running_sys) - self._log_gauge(self.metrics.gauge_scheduler_waiting, - stats.num_waiting_sys) - self._log_gauge(self.metrics.gauge_gpu_cache_usage, - stats.gpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_scheduler_running, stats.num_running_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, stats.gpu_cache_usage_sys) # Including max-lora in metric, in future this property of lora # config maybe extended to be dynamic. lora_info = { - self.metrics.labelname_running_lora_adapters: - ",".join(stats.running_lora_adapters), - self.metrics.labelname_waiting_lora_adapters: - ",".join(stats.waiting_lora_adapters), - self.metrics.labelname_max_lora: - stats.max_lora, + self.metrics.labelname_running_lora_adapters: ",".join( + stats.running_lora_adapters + ), + self.metrics.labelname_waiting_lora_adapters: ",".join( + stats.waiting_lora_adapters + ), + self.metrics.labelname_max_lora: stats.max_lora, } self._log_gauge_string(self.metrics.gauge_lora_info, lora_info) # Iteration level data - self._log_counter(self.metrics.counter_num_preemption, - stats.num_preemption_iter) - self._log_counter(self.metrics.counter_prompt_tokens, - stats.num_prompt_tokens_iter) - self._log_counter(self.metrics.counter_generation_tokens, - stats.num_generation_tokens_iter) - self._log_histogram(self.metrics.histogram_iteration_tokens, - [stats.num_tokens_iter]) - self._log_histogram(self.metrics.histogram_time_to_first_token, - stats.time_to_first_tokens_iter) - self._log_histogram(self.metrics.histogram_time_per_output_token, - stats.inter_token_latencies_iter) - self._log_histogram(self.metrics.histogram_inter_token_latency, - stats.inter_token_latencies_iter) + self._log_counter( + self.metrics.counter_num_preemption, stats.num_preemption_iter + ) + self._log_counter( + self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter + ) + self._log_counter( + self.metrics.counter_generation_tokens, stats.num_generation_tokens_iter + ) + self._log_histogram( + self.metrics.histogram_iteration_tokens, [stats.num_tokens_iter] + ) + self._log_histogram( + self.metrics.histogram_time_to_first_token, stats.time_to_first_tokens_iter + ) + self._log_histogram( + self.metrics.histogram_time_per_output_token, + stats.inter_token_latencies_iter, + ) + self._log_histogram( + self.metrics.histogram_inter_token_latency, stats.inter_token_latencies_iter + ) # Request level data # Latency - self._log_histogram(self.metrics.histogram_e2e_time_request, - stats.time_e2e_requests) - self._log_histogram(self.metrics.histogram_queue_time_request, - stats.time_queue_requests) - self._log_histogram(self.metrics.histogram_inference_time_request, - stats.time_inference_requests) - self._log_histogram(self.metrics.histogram_prefill_time_request, - stats.time_prefill_requests) - self._log_histogram(self.metrics.histogram_decode_time_request, - stats.time_decode_requests) + self._log_histogram( + self.metrics.histogram_e2e_time_request, stats.time_e2e_requests + ) + self._log_histogram( + self.metrics.histogram_queue_time_request, stats.time_queue_requests + ) + 
self._log_histogram( + self.metrics.histogram_inference_time_request, stats.time_inference_requests + ) + self._log_histogram( + self.metrics.histogram_prefill_time_request, stats.time_prefill_requests + ) + self._log_histogram( + self.metrics.histogram_decode_time_request, stats.time_decode_requests + ) # Metadata - finished_reason_counter = CollectionsCounter( - stats.finished_reason_requests) - self._log_counter_labels(self.metrics.counter_request_success, - finished_reason_counter, - Metrics.labelname_finish_reason) - self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, - stats.num_prompt_tokens_requests) + finished_reason_counter = CollectionsCounter(stats.finished_reason_requests) + self._log_counter_labels( + self.metrics.counter_request_success, + finished_reason_counter, + Metrics.labelname_finish_reason, + ) + self._log_histogram( + self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests, + ) self._log_histogram( self.metrics.histogram_num_generation_tokens_request, - stats.num_generation_tokens_requests) + stats.num_generation_tokens_requests, + ) self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) self._log_histogram( self.metrics.histogram_max_num_generation_tokens_request, - stats.max_num_generation_tokens_requests) - self._log_histogram(self.metrics.histogram_max_tokens_request, - stats.max_tokens_requests) + stats.max_num_generation_tokens_requests, + ) + self._log_histogram( + self.metrics.histogram_max_tokens_request, stats.max_tokens_requests + ) def log(self, stats: Stats): """Logs to prometheus and tracked stats every iteration.""" @@ -547,9 +658,7 @@ def log(self, stats: Stats): self.num_generation_tokens.append(stats.num_generation_tokens_iter) # Log locally every local_interval seconds. - if local_interval_elapsed(stats.now, self.last_local_log, - self.local_interval): - + if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): # Reset tracked stats for next interval. self.num_prompt_tokens = [] self.num_generation_tokens = [] @@ -565,12 +674,14 @@ def info(self, type: str, obj: SupportsMetricsInfo) -> None: name="vllm:cache_config_info", documentation="Information of the LLMEngine CacheConfig", labelnames=metrics_info.keys(), - multiprocess_mode="mostrecent") + multiprocess_mode="mostrecent", + ) info_gauge.labels(**metrics_info).set(1) class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" + _metrics_cls = RayMetrics def info(self, type: str, obj: SupportsMetricsInfo) -> None: diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 9778ab5a8c99..ac796f4e1c75 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -4,7 +4,7 @@ These types are defined in this file to avoid importing vllm.engine.metrics and therefore importing prometheus_client. -This is required due to usage of Prometheus multiprocess mode to enable +This is required due to usage of Prometheus multiprocess mode to enable metrics after splitting out the uvicorn process from the engine process. 
Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR @@ -16,7 +16,6 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List from vllm.config import SupportsMetricsInfo, VllmConfig @@ -24,6 +23,7 @@ @dataclass class Stats: """Created by LLMEngine for use by StatLogger.""" + now: float # System stats (should have _sys suffix) @@ -42,26 +42,26 @@ class Stats: num_prompt_tokens_iter: int num_generation_tokens_iter: int num_tokens_iter: int - time_to_first_tokens_iter: List[float] - inter_token_latencies_iter: List[float] + time_to_first_tokens_iter: list[float] + inter_token_latencies_iter: list[float] num_preemption_iter: int # Request stats (should have _requests suffix) # Latency - time_e2e_requests: List[float] - time_queue_requests: List[float] - time_inference_requests: List[float] - time_prefill_requests: List[float] - time_decode_requests: List[float] + time_e2e_requests: list[float] + time_queue_requests: list[float] + time_inference_requests: list[float] + time_prefill_requests: list[float] + time_decode_requests: list[float] # Metadata - num_prompt_tokens_requests: List[int] - num_generation_tokens_requests: List[int] - n_requests: List[int] - max_num_generation_tokens_requests: List[int] - max_tokens_requests: List[int] - finished_reason_requests: List[str] - waiting_lora_adapters: List[str] - running_lora_adapters: List[str] + num_prompt_tokens_requests: list[int] + num_generation_tokens_requests: list[int] + n_requests: list[int] + max_num_generation_tokens_requests: list[int] + max_tokens_requests: list[int] + finished_reason_requests: list[str] + waiting_lora_adapters: list[str] + running_lora_adapters: list[str] max_lora: str @@ -70,8 +70,8 @@ class StatLoggerBase(ABC): def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: # Tracked stats over current local logging interval. - self.num_prompt_tokens: List[int] = [] - self.num_generation_tokens: List[int] = [] + self.num_prompt_tokens: list[int] = [] + self.num_generation_tokens: list[int] = [] self.last_local_log = time.time() self.local_interval = local_interval diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index bc917f2f57f0..feb2e841c83a 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,8 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union +from collections.abc import AsyncGenerator, Iterable, Mapping +from typing import Any, Optional, Union from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import ModelConfig, VllmConfig @@ -29,23 +30,19 @@ class EngineClient(ABC): @property @abstractmethod - def is_running(self) -> bool: - ... + def is_running(self) -> bool: ... @property @abstractmethod - def is_stopped(self) -> bool: - ... + def is_stopped(self) -> bool: ... @property @abstractmethod - def errored(self) -> bool: - ... + def errored(self) -> bool: ... @property @abstractmethod - def dead_error(self) -> BaseException: - ... + def dead_error(self) -> BaseException: ... 
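Editor's note: `EngineClient` above declares its state flags as abstract read-only properties (`@property` stacked over `@abstractmethod`, now condensed to `...` bodies by the reformat). A minimal sketch of that protocol style with a trivial concrete implementation; the class names here are hypothetical.

```python
from abc import ABC, abstractmethod


class ClientSketch(ABC):
    # Abstract read-only state flags, declared property-first as in
    # EngineClient above.
    @property
    @abstractmethod
    def is_running(self) -> bool: ...

    @property
    @abstractmethod
    def errored(self) -> bool: ...


class HealthyClient(ClientSketch):
    @property
    def is_running(self) -> bool:
        return True

    @property
    def errored(self) -> bool:
        return False


client = HealthyClient()
assert client.is_running and not client.errored
```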
@abstractmethod def generate( @@ -71,7 +68,6 @@ async def beam_search( params: BeamSearchParams, lora_request: Optional[LoRARequest] = None, ) -> AsyncGenerator[RequestOutput, None]: - beam_width = params.beam_width max_tokens = params.max_tokens ignore_eos = params.ignore_eos @@ -112,8 +108,7 @@ async def beam_search( tokenized_length = len(prompt_token_ids) - sort_beams_key = create_sort_beams_key_function( - eos_token_id, length_penalty) + sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) beam_search_params = SamplingParams( logprobs=2 * beam_width, @@ -121,35 +116,49 @@ async def beam_search( temperature=temperature, ) all_beams = [ - BeamSearchSequence(tokens=prompt_token_ids, - cum_logprob=0, - logprobs=[], - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, - lora_request=lora_request) + BeamSearchSequence( + tokens=prompt_token_ids, + cum_logprob=0, + logprobs=[], + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + lora_request=lora_request, + ) ] completed = [] for _ in range(max_tokens): - prompts_batch, lora_req_batch = zip(*[( - TokensPrompt(prompt_token_ids=beam.tokens, - multi_modal_data=beam.multi_modal_data, - mm_processor_kwargs=beam.mm_processor_kwargs), - beam.lora_request, - ) for beam in all_beams]) + prompts_batch, lora_req_batch = zip( + *[ + ( + TokensPrompt( + prompt_token_ids=beam.tokens, + multi_modal_data=beam.multi_modal_data, + mm_processor_kwargs=beam.mm_processor_kwargs, + ), + beam.lora_request, + ) + for beam in all_beams + ] + ) tasks = [] request_id = f"beam_search-{random_uuid()}" - for i, (individual_prompt, - lora_req) in enumerate(zip(prompts_batch, lora_req_batch)): + for i, (individual_prompt, lora_req) in enumerate( + zip(prompts_batch, lora_req_batch) + ): request_id_item = f"{request_id}-{i}" task = asyncio.create_task( collect_from_async_generator( - self.generate(individual_prompt, - beam_search_params, - request_id_item, - lora_request=lora_req))) + self.generate( + individual_prompt, + beam_search_params, + request_id_item, + lora_request=lora_req, + ) + ) + ) tasks.append(task) output = await asyncio.gather(*tasks) @@ -163,32 +172,31 @@ async def beam_search( if result.outputs[0].logprobs is not None: logprobs = result.outputs[0].logprobs[0] for token_id, logprob_obj in logprobs.items(): - if token_id == eos_token_id and \ - not ignore_eos: + if token_id == eos_token_id and not ignore_eos: completed.append( BeamSearchSequence( - tokens=current_beam.tokens + - [token_id] if include_stop_str_in_output + tokens=current_beam.tokens + [token_id] + if include_stop_str_in_output else current_beam.tokens, - logprobs=current_beam.logprobs + - [logprobs], - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, + logprobs=current_beam.logprobs + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, finish_reason="stop", - stop_reason=eos_token_id)) + stop_reason=eos_token_id, + ) + ) else: new_beams.append( BeamSearchSequence( tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + - [logprobs], + logprobs=current_beam.logprobs + [logprobs], lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, - multi_modal_data=current_beam. - multi_modal_data, - mm_processor_kwargs=current_beam. 
- mm_processor_kwargs)) + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam.multi_modal_data, + mm_processor_kwargs=current_beam.mm_processor_kwargs, + ) + ) sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) all_beams = sorted_beams[:beam_width] @@ -198,7 +206,7 @@ async def beam_search( best_beams = sorted_completed[:beam_width] for beam in best_beams: - if (beam.tokens[-1] == eos_token_id and not ignore_eos): + if beam.tokens[-1] == eos_token_id and not ignore_eos: # Skip the eos token in the text. tokens = beam.tokens[tokenized_length:-1] else: @@ -209,19 +217,23 @@ async def beam_search( request_id=request_id, prompt=prompt_text, outputs=[ - CompletionOutput(text=beam.text, - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens[tokenized_length:], - index=i, - logprobs=beam.logprobs, - finish_reason=beam.finish_reason if - beam.finish_reason is not None else "length", - stop_reason=beam.stop_reason) + CompletionOutput( + text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.logprobs, + finish_reason=beam.finish_reason + if beam.finish_reason is not None + else "length", + stop_reason=beam.stop_reason, + ) for (i, beam) in enumerate(best_beams) ], finished=True, prompt_token_ids=prompt_token_ids, - prompt_logprobs=None) + prompt_logprobs=None, + ) @abstractmethod def encode( @@ -271,12 +283,10 @@ async def get_io_processor(self) -> IOProcessor: raise NotImplementedError @abstractmethod - async def is_tracing_enabled(self) -> bool: - ... + async def is_tracing_enabled(self) -> bool: ... @abstractmethod - async def do_log_stats(self) -> None: - ... + async def do_log_stats(self) -> None: ... @abstractmethod async def check_health(self) -> None: @@ -299,8 +309,7 @@ async def reset_mm_cache(self) -> None: ... @abstractmethod - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: + async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: """Reset the prefix cache""" ... @@ -324,17 +333,19 @@ async def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into the engine for future requests.""" ... - async def scale_elastic_ep(self, - new_data_parallel_size: int, - drain_timeout: int = 300) -> None: + async def scale_elastic_ep( + self, new_data_parallel_size: int, drain_timeout: int = 300 + ) -> None: """Scale the engine""" raise NotImplementedError - async def collective_rpc(self, - method: str, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None): + async def collective_rpc( + self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + ): """Perform a collective RPC call to the given path.""" raise NotImplementedError diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 3d1e5dc14d2f..c31d15ddac4f 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -7,6 +7,7 @@ We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead. 
""" + import asyncio import json import ssl @@ -68,9 +69,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: prompt = request_output.prompt assert prompt is not None - text_outputs = [ - prompt + output.text for output in request_output.outputs - ] + text_outputs = [prompt + output.text for output in request_output.outputs] ret = {"text": text_outputs} yield (json.dumps(ret) + "\n").encode("utf-8") @@ -109,16 +108,20 @@ async def init_app( global engine engine_args = AsyncEngineArgs.from_cli_args(args) - engine = (llm_engine - if llm_engine is not None else AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.API_SERVER)) + engine = ( + llm_engine + if llm_engine is not None + else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER + ) + ) app.state.engine_client = engine return app -async def run_server(args: Namespace, - llm_engine: Optional[AsyncLLMEngine] = None, - **uvicorn_kwargs: Any) -> None: +async def run_server( + args: Namespace, llm_engine: Optional[AsyncLLMEngine] = None, **uvicorn_kwargs: Any +) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) @@ -151,26 +154,27 @@ async def run_server(args: Namespace, parser.add_argument("--port", type=parser.check_port, default=8000) parser.add_argument("--ssl-keyfile", type=str, default=None) parser.add_argument("--ssl-certfile", type=str, default=None) - parser.add_argument("--ssl-ca-certs", - type=str, - default=None, - help="The CA certificates file") + parser.add_argument( + "--ssl-ca-certs", type=str, default=None, help="The CA certificates file" + ) parser.add_argument( "--enable-ssl-refresh", action="store_true", default=False, - help="Refresh SSL Context when SSL certificate files change") + help="Refresh SSL Context when SSL certificate files change", + ) parser.add_argument( "--ssl-cert-reqs", type=int, default=int(ssl.CERT_NONE), - help="Whether client certificate is required (see stdlib ssl module's)" + help="Whether client certificate is required (see stdlib ssl module's)", ) parser.add_argument( "--root-path", type=str, default=None, - help="FastAPI root_path when app is behind a path based routing proxy") + help="FastAPI root_path when app is behind a path based routing proxy", + ) parser.add_argument("--log-level", type=str, default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6b0ed23277d3..24eac17950fe 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -8,8 +8,7 @@ from collections.abc import Awaitable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path -from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, - cast) +from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast import jinja2 import jinja2.ext @@ -18,41 +17,37 @@ import jinja2.parser import jinja2.sandbox import transformers.utils.chat_template_utils as hf_chat_utils -# yapf conflicts with isort for this block -# yapf: disable -from openai.types.chat import (ChatCompletionAssistantMessageParam, - ChatCompletionContentPartImageParam, - ChatCompletionContentPartInputAudioParam) from openai.types.chat import ( - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) -from openai.types.chat import (ChatCompletionContentPartRefusalParam, - 
ChatCompletionContentPartTextParam) + ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam, + ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam, + ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam, +) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, +) from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) -from openai.types.chat import (ChatCompletionMessageToolCallParam, - ChatCompletionToolMessageParam) -from openai.types.chat.chat_completion_content_part_input_audio_param import ( - InputAudio) + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam, +) +from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio from openai.types.responses import ResponseInputImageParam from openai_harmony import Message as OpenAIHarmonyMessage from PIL import Image from pydantic import BaseModel, ConfigDict, TypeAdapter -# yapf: enable -from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, - ProcessorMixin) +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin + # pydantic needs the TypedDict from typing_extensions from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal -from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, - MultiModalUUIDDict) +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import MediaConnector -# yapf: disable -from vllm.transformers_utils.chat_templates import ( - get_chat_template_fallback_path) -# yapf: enable +from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import random_uuid, supports_kw @@ -284,9 +279,11 @@ def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: if isinstance(node, jinja2.nodes.Getitem): - return (_is_var_access(node.node, varname) - and isinstance(node.arg, jinja2.nodes.Const) - and node.arg.value == key) + return ( + _is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key + ) if isinstance(node, jinja2.nodes.Getattr): return _is_var_access(node.node, varname) and node.attr == key @@ -301,19 +298,17 @@ def _is_var_or_elems_access( ) -> bool: if isinstance(node, jinja2.nodes.Filter): return node.node is not None and _is_var_or_elems_access( - node.node, varname, key) + node.node, varname, key + ) if isinstance(node, jinja2.nodes.Test): return _is_var_or_elems_access(node.node, varname, key) if isinstance(node, jinja2.nodes.Getitem) and isinstance( - node.arg, jinja2.nodes.Slice): + node.arg, jinja2.nodes.Slice + ): return _is_var_or_elems_access(node.node, varname, key) - # yapf: disable - return ( - _is_attr_access(node, varname, key) if key - else _is_var_access(node, varname) - ) # yapf: enable + return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): @@ -342,8 +337,7 @@ def 
_iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): # the scope in which each variable is defined, but that is too complicated def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): messages_varnames = [ - varname - for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") ] # Search for {%- for message in messages -%} loops @@ -484,8 +478,7 @@ def resolve_hf_chat_template( # 2nd priority: AutoProcessor chat template, unless tool calling is enabled if tools is None: - chat_template = _try_get_processor_chat_template(tokenizer, - model_config) + chat_template = _try_get_processor_chat_template(tokenizer, model_config) if chat_template is not None: return chat_template @@ -678,16 +671,12 @@ def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]: mm_uuids = {} uuids_by_modality = dict(self._uuids_by_modality) if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality: - raise ValueError( - "Mixing raw image and embedding inputs is not allowed" - ) + raise ValueError("Mixing raw image and embedding inputs is not allowed") if "image_embeds" in uuids_by_modality: image_embeds_uuids = uuids_by_modality["image_embeds"] if len(image_embeds_uuids) > 1: - raise ValueError( - "Only one message can have {'type': 'image_embeds'}" - ) + raise ValueError("Only one message can have {'type': 'image_embeds'}") mm_uuids["image"] = uuids_by_modality["image_embeds"] if "image" in uuids_by_modality: mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images @@ -709,16 +698,12 @@ def all_mm_data(self) -> Optional[MultiModalDataDict]: mm_inputs = {} items_by_modality = dict(self._items_by_modality) if "image" in items_by_modality and "image_embeds" in items_by_modality: - raise ValueError( - "Mixing raw image and embedding inputs is not allowed" - ) + raise ValueError("Mixing raw image and embedding inputs is not allowed") if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: - raise ValueError( - "Only one message can have {'type': 'image_embeds'}" - ) + raise ValueError("Only one message can have {'type': 'image_embeds'}") mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images @@ -748,16 +733,12 @@ async def all_mm_data(self) -> Optional[MultiModalDataDict]: items_by_modality[modality] = await asyncio.gather(*coros) if "image" in items_by_modality and "image_embeds" in items_by_modality: - raise ValueError( - "Mixing raw image and embedding inputs is not allowed" - ) + raise ValueError("Mixing raw image and embedding inputs is not allowed") if "image_embeds" in items_by_modality: image_embeds_lst = items_by_modality["image_embeds"] if len(image_embeds_lst) > 1: - raise ValueError( - "Only one message can have {'type': 'image_embeds'}" - ) + raise ValueError("Only one message can have {'type': 'image_embeds'}") mm_inputs["image"] = image_embeds_lst[0] if "image" in items_by_modality: mm_inputs["image"] = items_by_modality["image"] # A list of images @@ -783,9 +764,7 @@ def __init__(self) -> None: # } self._placeholder_storage: dict[str, list] = defaultdict(list) - def _add_placeholder( - self, modality: ModalityStr, placeholder: Optional[str] - ): + def _add_placeholder(self, modality: ModalityStr, placeholder: Optional[str]): mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality] if placeholder: 
self._placeholder_storage[mod_placeholder].append(placeholder) @@ -794,8 +773,7 @@ def mm_placeholder_storage(self) -> dict[str, list]: return dict(self._placeholder_storage) @abstractmethod - def parse_image( - self, image_url: Optional[str], uuid: Optional[str] = None) -> None: + def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None: raise NotImplementedError @abstractmethod @@ -813,9 +791,7 @@ def parse_image_pil( raise NotImplementedError @abstractmethod - def parse_audio( - self, audio_url: Optional[str], uuid: Optional[str] = None - ) -> None: + def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None: raise NotImplementedError @abstractmethod @@ -825,9 +801,7 @@ def parse_input_audio( raise NotImplementedError @abstractmethod - def parse_video( - self, video_url: Optional[str], uuid: Optional[str] = None - ) -> None: + def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: raise NotImplementedError @@ -844,9 +818,7 @@ def __init__(self, tracker: MultiModalItemTracker) -> None: allowed_media_domains=tracker.allowed_media_domains, ) - def parse_image( - self, image_url: Optional[str], uuid: Optional[str] = None - ) -> None: + def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None: image = self._connector.fetch_image(image_url) if image_url else None placeholder = self._tracker.add("image", image, uuid) @@ -879,9 +851,7 @@ def parse_image_pil( placeholder = self._tracker.add("image", image_pil, uuid) self._add_placeholder("image", placeholder) - def parse_audio( - self, audio_url: Optional[str], uuid: Optional[str] = None - ) -> None: + def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None: audio = self._connector.fetch_audio(audio_url) if audio_url else None placeholder = self._tracker.add("audio", audio, uuid) @@ -903,14 +873,8 @@ def parse_input_audio( return self.parse_audio(audio_url, uuid) - def parse_video( - self, video_url: Optional[str], uuid: Optional[str] = None - ) -> None: - video = ( - self._connector.fetch_video(video_url=video_url) - if video_url - else None - ) + def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: + video = self._connector.fetch_video(video_url=video_url) if video_url else None placeholder = self._tracker.add("video", video, uuid) self._add_placeholder("video", placeholder) @@ -929,12 +893,8 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: allowed_media_domains=tracker.allowed_media_domains, ) - def parse_image( - self, image_url: Optional[str], uuid: Optional[str] = None - ) -> None: - image_coro = ( - self._connector.fetch_image_async(image_url) if image_url else None - ) + def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None: + image_coro = self._connector.fetch_image_async(image_url) if image_url else None placeholder = self._tracker.add("image", image_coro, uuid) self._add_placeholder("image", placeholder) @@ -944,9 +904,7 @@ def parse_image_embeds( image_embeds: Union[str, dict[str, str], None], uuid: Optional[str] = None, ) -> None: - future: asyncio.Future[Union[str, dict[str, str], None]] = ( - asyncio.Future() - ) + future: asyncio.Future[Union[str, dict[str, str], None]] = asyncio.Future() if isinstance(image_embeds, dict): embeds = { @@ -977,12 +935,8 @@ def parse_image_pil( placeholder = self._tracker.add("image", future, uuid) self._add_placeholder("image", placeholder) - def parse_audio( - self, audio_url: 
Optional[str], uuid: Optional[str] = None - ) -> None: - audio_coro = ( - self._connector.fetch_audio_async(audio_url) if audio_url else None - ) + def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None: + audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None placeholder = self._tracker.add("audio", audio_coro, uuid) self._add_placeholder("audio", placeholder) @@ -1003,9 +957,7 @@ def parse_input_audio( return self.parse_audio(audio_url, uuid) - def parse_video( - self, video_url: Optional[str], uuid: Optional[str] = None - ) -> None: + def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None: video = ( self._connector.fetch_video_async(video_url=video_url) if video_url @@ -1036,9 +988,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): ) else: - raise TypeError( - f"{type(chat_template)} is not a valid chat template type" - ) + raise TypeError(f"{type(chat_template)} is not a valid chat template type") def _load_chat_template( @@ -1145,9 +1095,7 @@ def _get_full_multimodal_text_prompt( "actual multimodal data items." ) - missing_placeholders.extend( - [placeholder] * placeholder_counts[placeholder] - ) + missing_placeholders.extend([placeholder] * placeholder_counts[placeholder]) # NOTE: Default behaviour: we always add missing placeholders # at the front of the prompt, if interleave_strings=False @@ -1166,9 +1114,7 @@ def _get_full_multimodal_text_prompt( _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python _VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python -_ResponsesInputImageParser = TypeAdapter( - ResponseInputImageParam -).validate_python +_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] # Define a mapping from part types to their corresponding parsing functions. 
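A rough standalone sketch (assumed shapes, not the real vLLM content-part types) of the TypeAdapter-based parser pattern defined above and used in the mapping that follows: each content-part "type" string is paired with a callable that validates the raw dict with pydantic and pulls out a single field.

from pydantic import TypeAdapter
from typing_extensions import TypedDict


class TextPart(TypedDict, total=False):  # hypothetical stand-in type
    type: str
    text: str


# validate_python checks the dict against the TypedDict and returns it
_TextParser = TypeAdapter(TextPart).validate_python

PARSER_MAP = {
    "text": lambda part: _TextParser(part).get("text", None),
}

print(PARSER_MAP["text"]({"type": "text", "text": "hello"}))  # -> hello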
@@ -1179,26 +1125,14 @@ def _get_full_multimodal_text_prompt( "text": lambda part: _TextParser(part).get("text", None), "thinking": lambda part: _ThinkParser(part).get("thinking", None), "input_text": lambda part: _TextParser(part).get("text", None), - "input_image": lambda part: _ResponsesInputImageParser(part).get( - "image_url", None - ), - "image_url": lambda part: _ImageParser(part) - .get("image_url", {}) - .get("url", None), - "image_embeds": lambda part: _ImageEmbedsParser(part).get( - "image_embeds", None - ), + "input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None), + "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None), + "image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None), "image_pil": lambda part: _PILImageParser(part).get("image_pil", None), - "audio_url": lambda part: _AudioParser(part) - .get("audio_url", {}) - .get("url", None), - "input_audio": lambda part: _InputAudioParser(part).get( - "input_audio", None - ), + "audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), + "input_audio": lambda part: _InputAudioParser(part).get("input_audio", None), "refusal": lambda part: _RefusalParser(part).get("refusal", None), - "video_url": lambda part: _VideoParser(part) - .get("video_url", {}) - .get("url", None), + "video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url", None), } @@ -1225,15 +1159,14 @@ def _parse_chat_message_content_mm_part( part_type = part.get("type", None) uuid = part.get("uuid", None) - if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501 + if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501 content = MM_PARSER_MAP[part_type](part) # Special case for 'image_url.detail' # We only support 'auto', which is the default if part_type == "image_url" and part.get("detail", "auto") != "auto": logger.warning( - "'image_url.detail' is currently not supported " - "and will be ignored." + "'image_url.detail' is currently not supported and will be ignored." ) return part_type, content @@ -1242,9 +1175,7 @@ def _parse_chat_message_content_mm_part( # 'type' is required field by pydantic if part_type is None or uuid is not None: if "image_url" in part: - image_params = cast( - CustomChatCompletionContentSimpleImageParam, part - ) + image_params = cast(CustomChatCompletionContentSimpleImageParam, part) image_url = image_params.get("image_url", None) if isinstance(image_url, dict): # Can potentially happen if user provides a uuid @@ -1253,22 +1184,20 @@ def _parse_chat_message_content_mm_part( return "image_url", image_url if "image_pil" in part: # "image_pil" could be None if UUID is provided. - image_params = cast( # type: ignore + image_params = cast( # type: ignore CustomChatCompletionContentPILImageParam, part ) image_pil = image_params.get("image_pil", None) return "image_pil", image_pil if "image_embeds" in part: # "image_embeds" could be None if UUID is provided. 
- image_params = cast( # type: ignore + image_params = cast( # type: ignore ChatCompletionContentPartImageEmbedsParam, part ) image_embeds = image_params.get("image_embeds", None) return "image_embeds", image_embeds if "audio_url" in part: - audio_params = cast( - CustomChatCompletionContentSimpleAudioParam, part - ) + audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part) audio_url = audio_params.get("audio_url", None) if isinstance(audio_url, dict): # Can potentially happen if user provides a uuid @@ -1279,9 +1208,7 @@ def _parse_chat_message_content_mm_part( input_audio_params = cast(dict[str, str], part) return "input_audio", input_audio_params if "video_url" in part: - video_params = cast( - CustomChatCompletionContentSimpleVideoParam, part - ) + video_params = cast(CustomChatCompletionContentSimpleVideoParam, part) video_url = video_params.get("video_url", None) if isinstance(video_url, dict): # Can potentially happen if user provides a uuid @@ -1383,10 +1310,7 @@ def _parse_chat_message_content_part( modality = None if part_type == "image_pil": - if content is not None: - image_content = cast(Image.Image, content) - else: - image_content = None + image_content = cast(Image.Image, content) if content is not None else None mm_parser.parse_image_pil(image_content, uuid) modality = "image" elif part_type in ("image_url", "input_image"): @@ -1418,9 +1342,7 @@ def _parse_chat_message_content_part( return ( {"type": modality} if wrap_dicts - else ( - MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None - ) + else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None) ) @@ -1441,9 +1363,7 @@ def _parse_chat_message_content( if content is None: content = [] elif isinstance(content, str): - content = [ - ChatCompletionContentPartTextParam(type="text", text=content) - ] + content = [ChatCompletionContentPartTextParam(type="text", text=content)] result = _parse_chat_message_content_parts( role, content, # type: ignore @@ -1459,10 +1379,7 @@ def _parse_chat_message_content( # The 'tool_calls' is not None check ensures compatibility. # It's needed only if downstream code doesn't strictly # follow the OpenAI spec. 
- if ( - "tool_calls" in parsed_msg - and parsed_msg["tool_calls"] is not None - ): + if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None: result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) elif role == "tool": parsed_msg = _ToolParser(message) @@ -1572,31 +1489,40 @@ def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: return call_block.set_lineno(lineno) +def _resolve_chat_template_kwargs( + chat_template: str, +): + env = jinja2.sandbox.ImmutableSandboxedEnvironment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[AssistantTracker, jinja2.ext.loopcontrols], + ) + parsed_content = env.parse(chat_template) + template_vars = jinja2.meta.find_undeclared_variables(parsed_content) + return template_vars + + +_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs) + + def resolve_chat_template_kwargs( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], chat_template: str, chat_template_kwargs: dict[str, Any], ) -> dict[str, Any]: fn_kw = { - k for k in chat_template_kwargs + k + for k in chat_template_kwargs if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False) } - env = jinja2.sandbox.ImmutableSandboxedEnvironment( - trim_blocks=True, - lstrip_blocks=True, - extensions=[AssistantTracker, jinja2.ext.loopcontrols], - ) - parsed_content = env.parse(chat_template) - template_vars = jinja2.meta.find_undeclared_variables(parsed_content) + template_vars = _cached_resolve_chat_template_kwargs(chat_template) # We exclude chat_template from kwargs here, because # chat template has been already resolved at this stage unexpected_vars = {"chat_template"} accept_vars = (fn_kw | template_vars) - unexpected_vars - return { - k: v for k, v in chat_template_kwargs.items() if k in accept_vars - } + return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars} def apply_hf_chat_template( diff --git a/vllm/entrypoints/cli/__init__.py b/vllm/entrypoints/cli/__init__.py index 41671b5b98ab..211e157fc7c8 100644 --- a/vllm/entrypoints/cli/__init__.py +++ b/vllm/entrypoints/cli/__init__.py @@ -2,11 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand -from vllm.entrypoints.cli.benchmark.throughput import ( - BenchmarkThroughputSubcommand) +from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand __all__: list[str] = [ "BenchmarkLatencySubcommand", "BenchmarkServingSubcommand", "BenchmarkThroughputSubcommand", -] \ No newline at end of file +] diff --git a/vllm/entrypoints/cli/benchmark/base.py b/vllm/entrypoints/cli/benchmark/base.py index 0c22bc75105e..3263459fd681 100644 --- a/vllm/entrypoints/cli/benchmark/base.py +++ b/vllm/entrypoints/cli/benchmark/base.py @@ -6,7 +6,7 @@ class BenchmarkSubcommandBase(CLISubcommand): - """ The base class of subcommands for vllm bench. """ + """The base class of subcommands for vllm bench.""" help: str diff --git a/vllm/entrypoints/cli/benchmark/latency.py b/vllm/entrypoints/cli/benchmark/latency.py index 3e68963cfd44..548ddf4d603e 100644 --- a/vllm/entrypoints/cli/benchmark/latency.py +++ b/vllm/entrypoints/cli/benchmark/latency.py @@ -7,7 +7,7 @@ class BenchmarkLatencySubcommand(BenchmarkSubcommandBase): - """ The `latency` subcommand for vllm bench. 
""" + """The `latency` subcommand for vllm bench.""" name = "latency" help = "Benchmark the latency of a single batch of requests." diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py index 0c65fd97fc04..d7455daa1a6b 100644 --- a/vllm/entrypoints/cli/benchmark/main.py +++ b/vllm/entrypoints/cli/benchmark/main.py @@ -15,7 +15,7 @@ class BenchmarkSubcommand(CLISubcommand): - """ The `bench` subcommand for the vLLM CLI. """ + """The `bench` subcommand for the vLLM CLI.""" name = "bench" help = "vLLM bench subcommand." @@ -28,14 +28,14 @@ def validate(self, args: argparse.Namespace) -> None: pass def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: bench_parser = subparsers.add_parser( self.name, description=self.help, - usage=f"vllm {self.name} <bench_type> [options]") - bench_subparsers = bench_parser.add_subparsers(required=True, - dest="bench_type") + usage=f"vllm {self.name} <bench_type> [options]", + ) + bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type") for cmd_cls in BenchmarkSubcommandBase.__subclasses__(): cmd_subparser = bench_subparsers.add_parser( @@ -47,7 +47,8 @@ def subparser_init( cmd_subparser.set_defaults(dispatch_function=cmd_cls.cmd) cmd_cls.add_cli_args(cmd_subparser) cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format( - subcmd=f"{self.name} {cmd_cls.name}") + subcmd=f"{self.name} {cmd_cls.name}" + ) return bench_parser diff --git a/vllm/entrypoints/cli/benchmark/serve.py b/vllm/entrypoints/cli/benchmark/serve.py index 3dd7a46d6284..b085f52afb3b 100644 --- a/vllm/entrypoints/cli/benchmark/serve.py +++ b/vllm/entrypoints/cli/benchmark/serve.py @@ -7,7 +7,7 @@ class BenchmarkServingSubcommand(BenchmarkSubcommandBase): - """ The `serve` subcommand for vllm bench. """ + """The `serve` subcommand for vllm bench.""" name = "serve" help = "Benchmark the online serving throughput." diff --git a/vllm/entrypoints/cli/benchmark/throughput.py b/vllm/entrypoints/cli/benchmark/throughput.py index d5d43ad4a359..c25be75ec11e 100644 --- a/vllm/entrypoints/cli/benchmark/throughput.py +++ b/vllm/entrypoints/cli/benchmark/throughput.py @@ -7,7 +7,7 @@ class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase): - """ The `throughput` subcommand for vllm bench. """ + """The `throughput` subcommand for vllm bench.""" name = "throughput" help = "Benchmark offline inference throughput." diff --git a/vllm/entrypoints/cli/collect_env.py b/vllm/entrypoints/cli/collect_env.py index 785c18812adb..e79a7efec6ba 100644 --- a/vllm/entrypoints/cli/collect_env.py +++ b/vllm/entrypoints/cli/collect_env.py @@ -14,7 +14,8 @@ class CollectEnvSubcommand(CLISubcommand): - """The `collect-env` subcommand for the vLLM CLI. 
""" + """The `collect-env` subcommand for the vLLM CLI.""" + name = "collect-env" @staticmethod @@ -23,13 +24,14 @@ def cmd(args: argparse.Namespace) -> None: collect_env_main() def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: return subparsers.add_parser( "collect-env", help="Start collecting environment information.", description="Start collecting environment information.", - usage="vllm collect-env") + usage="vllm collect-env", + ) def cmd_init() -> list[CLISubcommand]: diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index f1bcbc8262bd..0ebfe1c22269 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -'''The CLI entrypoints of vLLM +"""The CLI entrypoints of vLLM Note that all future modules must be lazily loaded within main -to avoid certain eager import breakage.''' +to avoid certain eager import breakage.""" + from __future__ import annotations import importlib.metadata @@ -33,18 +34,17 @@ def main(): epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"), ) parser.add_argument( - '-v', - '--version', - action='version', - version=importlib.metadata.version('vllm'), + "-v", + "--version", + action="version", + version=importlib.metadata.version("vllm"), ) subparsers = parser.add_subparsers(required=False, dest="subparser") cmds = {} for cmd_module in CMD_MODULES: new_cmds = cmd_module.cmd_init() for cmd in new_cmds: - cmd.subparser_init(subparsers).set_defaults( - dispatch_function=cmd.cmd) + cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) cmds[cmd.name] = cmd args = parser.parse_args() if args.subparser in cmds: diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index 1929d6a7f77a..5372210bbf55 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -19,7 +19,6 @@ def _register_signal_handlers(): - def signal_handler(sig, frame): sys.exit(0) @@ -80,26 +79,29 @@ def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None: break conversation.append({"role": "user", "content": input_message}) - stream = client.chat.completions.create(model=model_name, - messages=conversation, - stream=True) + stream = client.chat.completions.create( + model=model_name, messages=conversation, stream=True + ) output = _print_chat_stream(stream) conversation.append({"role": "assistant", "content": output}) -def _add_query_options( - parser: FlexibleArgumentParser) -> FlexibleArgumentParser: +def _add_query_options(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( "--url", type=str, default="http://localhost:8000/v1", - help="url of the running OpenAI-Compatible RESTful API server") + help="url of the running OpenAI-Compatible RESTful API server", + ) parser.add_argument( "--model-name", type=str, default=None, - help=("The model name used in prompt completion, default to " - "the first model in list models API call.")) + help=( + "The model name used in prompt completion, default to " + "the first model in list models API call." + ), + ) parser.add_argument( "--api-key", type=str, @@ -107,12 +109,14 @@ def _add_query_options( help=( "API key for OpenAI services. If provided, this api key " "will overwrite the api key obtained through environment variables." 
- )) + ), + ) return parser class ChatCommand(CLISubcommand): - """The `chat` subcommand for the vLLM CLI. """ + """The `chat` subcommand for the vLLM CLI.""" + name = "chat" @staticmethod @@ -127,9 +131,9 @@ def cmd(args: argparse.Namespace) -> None: if args.quick: conversation.append({"role": "user", "content": args.quick}) - stream = client.chat.completions.create(model=model_name, - messages=conversation, - stream=True) + stream = client.chat.completions.create( + model=model_name, messages=conversation, stream=True + ) output = _print_chat_stream(stream) conversation.append({"role": "assistant", "content": output}) return @@ -142,9 +146,9 @@ def cmd(args: argparse.Namespace) -> None: break conversation.append({"role": "user", "content": input_message}) - stream = client.chat.completions.create(model=model_name, - messages=conversation, - stream=True) + stream = client.chat.completions.create( + model=model_name, messages=conversation, stream=True + ) output = _print_chat_stream(stream) conversation.append({"role": "assistant", "content": output}) @@ -156,39 +160,45 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--system-prompt", type=str, default=None, - help=("The system prompt to be added to the chat template, " - "used for models that support system prompts.")) - parser.add_argument("-q", - "--quick", - type=str, - metavar="MESSAGE", - help=("Send a single prompt as MESSAGE " - "and print the response, then exit.")) + help=( + "The system prompt to be added to the chat template, " + "used for models that support system prompts." + ), + ) + parser.add_argument( + "-q", + "--quick", + type=str, + metavar="MESSAGE", + help=("Send a single prompt as MESSAGE and print the response, then exit."), + ) return parser def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: parser = subparsers.add_parser( "chat", help="Generate chat completions via the running API server.", description="Generate chat completions via the running API server.", - usage="vllm chat [options]") + usage="vllm chat [options]", + ) return ChatCommand.add_cli_args(parser) class CompleteCommand(CLISubcommand): - """The `complete` subcommand for the vLLM CLI. 
""" - name = 'complete' + """The `complete` subcommand for the vLLM CLI.""" + + name = "complete" @staticmethod def cmd(args: argparse.Namespace) -> None: model_name, client = _interactive_cli(args) if args.quick: - stream = client.completions.create(model=model_name, - prompt=args.quick, - stream=True) + stream = client.completions.create( + model=model_name, prompt=args.quick, stream=True + ) _print_completion_stream(stream) return @@ -198,9 +208,9 @@ def cmd(args: argparse.Namespace) -> None: input_prompt = input("> ") except EOFError: break - stream = client.completions.create(model=model_name, - prompt=input_prompt, - stream=True) + stream = client.completions.create( + model=model_name, prompt=input_prompt, stream=True + ) _print_completion_stream(stream) @staticmethod @@ -212,20 +222,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--quick", type=str, metavar="PROMPT", - help= - "Send a single prompt and print the completion output, then exit.") + help="Send a single prompt and print the completion output, then exit.", + ) return parser def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: parser = subparsers.add_parser( "complete", - help=("Generate text completions based on the given prompt " - "via the running API server."), - description=("Generate text completions based on the given prompt " - "via the running API server."), - usage="vllm complete [options]") + help=( + "Generate text completions based on the given prompt " + "via the running API server." + ), + description=( + "Generate text completions based on the given prompt " + "via the running API server." + ), + usage="vllm complete [options]", + ) return CompleteCommand.add_cli_args(parser) diff --git a/vllm/entrypoints/cli/run_batch.py b/vllm/entrypoints/cli/run_batch.py index e669464bff83..6e7a15ada49c 100644 --- a/vllm/entrypoints/cli/run_batch.py +++ b/vllm/entrypoints/cli/run_batch.py @@ -20,14 +20,16 @@ class RunBatchSubcommand(CLISubcommand): """The `run-batch` subcommand for vLLM CLI.""" + name = "run-batch" @staticmethod def cmd(args: argparse.Namespace) -> None: from vllm.entrypoints.openai.run_batch import main as run_batch_main - logger.info("vLLM batch processing API version %s", - importlib.metadata.version("vllm")) + logger.info( + "vLLM batch processing API version %s", importlib.metadata.version("vllm") + ) logger.info("args: %s", args) # Start the Prometheus metrics server. @@ -44,8 +46,8 @@ def cmd(args: argparse.Namespace) -> None: asyncio.run(run_batch_main(args)) def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: from vllm.entrypoints.openai.run_batch import make_arg_parser run_batch_parser = subparsers.add_parser( @@ -53,13 +55,12 @@ def subparser_init( help="Run batch prompts and write results to file.", description=( "Run batch prompts using vLLM's OpenAI-compatible API.\n" - "Supports local or HTTP input/output files."), - usage= - "vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>", + "Supports local or HTTP input/output files." 
+ ), + usage="vllm run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>", ) run_batch_parser = make_arg_parser(run_batch_parser) - run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format( - subcmd=self.name) + run_batch_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name) return run_batch_parser diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 0a5547144800..b3960b74cf01 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -10,21 +10,26 @@ import vllm import vllm.envs as envs from vllm.entrypoints.cli.types import CLISubcommand -from vllm.entrypoints.openai.api_server import (run_server, run_server_worker, - setup_server) -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) +from vllm.entrypoints.openai.api_server import ( + run_server, + run_server_worker, + setup_server, +) +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import (FlexibleArgumentParser, decorate_logs, get_tcp_uri, - set_process_title) +from vllm.utils import ( + FlexibleArgumentParser, + decorate_logs, + get_tcp_uri, + set_process_title, +) from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus -from vllm.v1.utils import (APIServerProcessManager, - wait_for_completion_or_failure) +from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure logger = init_logger(__name__) @@ -38,13 +43,14 @@ class ServeSubcommand(CLISubcommand): - """The `serve` subcommand for the vLLM CLI. """ + """The `serve` subcommand for the vLLM CLI.""" + name = "serve" @staticmethod def cmd(args: argparse.Namespace) -> None: # If model is specified in CLI (as positional arg), it takes precedence - if hasattr(args, 'model_tag') and args.model_tag is not None: + if hasattr(args, "model_tag") and args.model_tag is not None: args.model = args.model_tag if args.headless or args.api_server_count < 1: @@ -60,16 +66,14 @@ def validate(self, args: argparse.Namespace) -> None: validate_parsed_serve_args(args) def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: serve_parser = subparsers.add_parser( - self.name, - description=DESCRIPTION, - usage="vllm serve [model_tag] [options]") + self.name, description=DESCRIPTION, usage="vllm serve [model_tag] [options]" + ) serve_parser = make_arg_parser(serve_parser) - serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format( - subcmd=self.name) + serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name) return serve_parser @@ -78,29 +82,27 @@ def cmd_init() -> list[CLISubcommand]: def run_headless(args: argparse.Namespace): - if args.api_server_count > 1: raise ValueError("api_server_count can't be set in headless mode") # Create the EngineConfig. 
engine_args = vllm.AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER - vllm_config = engine_args.create_engine_config(usage_context=usage_context, - headless=True) + vllm_config = engine_args.create_engine_config( + usage_context=usage_context, headless=True + ) if not envs.VLLM_USE_V1: raise ValueError("Headless mode is only supported for V1") if engine_args.data_parallel_hybrid_lb: - raise ValueError("data_parallel_hybrid_lb is not applicable in " - "headless mode") + raise ValueError("data_parallel_hybrid_lb is not applicable in headless mode") parallel_config = vllm_config.parallel_config local_engine_count = parallel_config.data_parallel_size_local if local_engine_count <= 0: - raise ValueError("data_parallel_size_local must be > 0 in " - "headless mode") + raise ValueError("data_parallel_size_local must be > 0 in headless mode") host = parallel_config.data_parallel_master_ip port = engine_args.data_parallel_rpc_port # add to config too @@ -116,7 +118,10 @@ def signal_handler(signum, frame): logger.info( "Launching %d data parallel engine(s) in headless mode, " - "with head node address %s.", local_engine_count, handshake_address) + "with head node address %s.", + local_engine_count, + handshake_address, + ) # Create the engines. engine_manager = CoreEngineProcManager( @@ -139,7 +144,6 @@ def signal_handler(signum, frame): def run_multi_api_server(args: argparse.Namespace): - assert not args.headless num_api_servers: int = args.api_server_count assert num_api_servers > 0 @@ -161,8 +165,10 @@ def run_multi_api_server(args: argparse.Namespace): raise ValueError("api_server_count > 1 is only supported for V1") if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " - "with api_server_count > 1") + raise ValueError( + "VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " + "with api_server_count > 1" + ) executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats @@ -175,10 +181,9 @@ def run_multi_api_server(args: argparse.Namespace): api_server_manager: Optional[APIServerProcessManager] = None - with launch_core_engines(vllm_config, executor_class, log_stats, - num_api_servers) as (local_engine_manager, - coordinator, addresses): - + with launch_core_engines( + vllm_config, executor_class, log_stats, num_api_servers + ) as (local_engine_manager, coordinator, addresses): # Construct common args for the APIServerProcessManager up-front. api_server_manager_kwargs = dict( target_server_fn=run_api_server_worker_proc, @@ -189,7 +194,9 @@ def run_multi_api_server(args: argparse.Namespace): input_addresses=addresses.inputs, output_addresses=addresses.outputs, stats_update_address=coordinator.get_stats_publish_address() - if coordinator else None) + if coordinator + else None, + ) # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the # start of the API servers until the local engine is started @@ -198,27 +205,26 @@ def run_multi_api_server(args: argparse.Namespace): # via the handshake with the local engine. if dp_rank == 0 or not (external_dp_lb or hybrid_dp_lb): # Start API servers using the manager. - api_server_manager = APIServerProcessManager( - **api_server_manager_kwargs) + api_server_manager = APIServerProcessManager(**api_server_manager_kwargs) # Start API servers now if they weren't already started. 
if api_server_manager is None: api_server_manager_kwargs["stats_update_address"] = ( - addresses.frontend_stats_publish_address) - api_server_manager = APIServerProcessManager( - **api_server_manager_kwargs) + addresses.frontend_stats_publish_address + ) + api_server_manager = APIServerProcessManager(**api_server_manager_kwargs) # Wait for API servers - wait_for_completion_or_failure(api_server_manager=api_server_manager, - engine_manager=local_engine_manager, - coordinator=coordinator) + wait_for_completion_or_failure( + api_server_manager=api_server_manager, + engine_manager=local_engine_manager, + coordinator=coordinator, + ) -def run_api_server_worker_proc(listen_address, - sock, - args, - client_config=None, - **uvicorn_kwargs) -> None: +def run_api_server_worker_proc( + listen_address, sock, args, client_config=None, **uvicorn_kwargs +) -> None: """Entrypoint for individual API server worker processes.""" client_config = client_config or {} server_index = client_config.get("client_index", 0) @@ -228,5 +234,5 @@ def run_api_server_worker_proc(listen_address, decorate_logs() uvloop.run( - run_server_worker(listen_address, sock, args, client_config, - **uvicorn_kwargs)) + run_server_worker(listen_address, sock, args, client_config, **uvicorn_kwargs) + ) diff --git a/vllm/entrypoints/cli/types.py b/vllm/entrypoints/cli/types.py index b88f094b302a..6194f421a1bb 100644 --- a/vllm/entrypoints/cli/types.py +++ b/vllm/entrypoints/cli/types.py @@ -24,6 +24,6 @@ def validate(self, args: argparse.Namespace) -> None: pass def subparser_init( - self, - subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: raise NotImplementedError("Subclasses should implement this method") diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index ea81fdbcd825..f410ee9c4045 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -12,7 +12,10 @@ from openai_harmony import Author, Message, Role, StreamState, TextContent from vllm.entrypoints.harmony_utils import ( - get_encoding, get_streamable_parser_for_assistant, render_for_completion) + get_encoding, + get_streamable_parser_for_assistant, + render_for_completion, +) from vllm.entrypoints.tool import Tool from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput @@ -34,10 +37,11 @@ def _map_tool_name_to_tool_type(tool_name: str) -> str: if tool_name not in _TOOL_NAME_TO_TYPE_MAP: - available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys()) + available_tools = ", ".join(_TOOL_NAME_TO_TYPE_MAP.keys()) raise ValueError( f"Built-in tool name '{tool_name}' not defined in mapping. 
" - f"Available tools: {available_tools}") + f"Available tools: {available_tools}" + ) return _TOOL_NAME_TO_TYPE_MAP[tool_name] @@ -59,7 +63,6 @@ def copy(self): class ConversationContext(ABC): - @abstractmethod def append_output(self, output) -> None: pass @@ -77,9 +80,13 @@ def render_for_completion(self) -> list[int]: pass @abstractmethod - async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack, request_id: str, - mcp_tools: dict[str, Mcp]) -> None: + async def init_tool_sessions( + self, + tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack, + request_id: str, + mcp_tools: dict[str, Mcp], + ) -> None: pass @abstractmethod @@ -88,7 +95,6 @@ async def cleanup_session(self) -> None: class SimpleContext(ConversationContext): - def __init__(self): self.last_output = None self.num_prompt_tokens = 0 @@ -114,9 +120,13 @@ async def call_tool(self) -> list[Message]: def render_for_completion(self) -> list[int]: raise NotImplementedError("Should not be called.") - async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack, request_id: str, - mcp_tools: dict[str, Mcp]) -> None: + async def init_tool_sessions( + self, + tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack, + request_id: str, + mcp_tools: dict[str, Mcp], + ) -> None: pass async def cleanup_session(self) -> None: @@ -124,7 +134,6 @@ async def cleanup_session(self) -> None: class HarmonyContext(ConversationContext): - def __init__( self, messages: list, @@ -155,8 +164,7 @@ def _update_num_reasoning_tokens(self): if self.parser.current_channel in {"analysis", "commentary"}: self.num_reasoning_tokens += 1 - def append_output(self, output: Union[RequestOutput, - list[Message]]) -> None: + def append_output(self, output: Union[RequestOutput, list[Message]]) -> None: if isinstance(output, RequestOutput): output_token_ids = output.outputs[0].token_ids self.parser = get_streamable_parser_for_assistant() @@ -202,8 +210,7 @@ def _update_prefill_token_usage(self, output: RequestOutput) -> None: this_turn_input_tokens = len(output.prompt_token_ids) else: this_turn_input_tokens = 0 - logger.error( - "RequestOutput appended contains no prompt_token_ids.") + logger.error("RequestOutput appended contains no prompt_token_ids.") # Update current turn input tokens self.current_turn.input_tokens = this_turn_input_tokens @@ -216,9 +223,11 @@ def _update_prefill_token_usage(self, output: RequestOutput) -> None: # start counting tool after first turn # tool tokens = this turn prefill - last turn prefill - # last turn decode - this_turn_tool_tokens = (self.current_turn.input_tokens - - self.previous_turn.input_tokens - - self.previous_turn.output_tokens) + this_turn_tool_tokens = ( + self.current_turn.input_tokens + - self.previous_turn.input_tokens + - self.previous_turn.output_tokens + ) # Handle negative tool token counts (shouldn't happen in normal # cases) @@ -227,9 +236,11 @@ def _update_prefill_token_usage(self, output: RequestOutput) -> None: "Negative tool output tokens calculated: %d " "(current_input=%d, previous_input=%d, " "previous_output=%d). 
Setting to 0.", - this_turn_tool_tokens, self.current_turn.input_tokens, + this_turn_tool_tokens, + self.current_turn.input_tokens, self.previous_turn.input_tokens, - self.previous_turn.output_tokens) + self.previous_turn.output_tokens, + ) this_turn_tool_tokens = 0 self.num_tool_output_tokens += this_turn_tool_tokens @@ -271,9 +282,11 @@ def messages(self) -> list: def need_builtin_tool_call(self) -> bool: last_msg = self.messages[-1] recipient = last_msg.recipient - return recipient is not None and (recipient.startswith("browser.") - or recipient.startswith("python") or - recipient.startswith("container.")) + return recipient is not None and ( + recipient.startswith("browser.") + or recipient.startswith("python") + or recipient.startswith("container.") + ) async def call_tool(self) -> list[Message]: if not self.messages: @@ -283,21 +296,24 @@ async def call_tool(self) -> list[Message]: if recipient is not None: if recipient.startswith("browser."): return await self.call_search_tool( - self._tool_sessions["browser"], last_msg) + self._tool_sessions["browser"], last_msg + ) elif recipient.startswith("python"): return await self.call_python_tool( - self._tool_sessions["python"], last_msg) + self._tool_sessions["python"], last_msg + ) elif recipient.startswith("container."): return await self.call_container_tool( - self._tool_sessions["container"], last_msg) + self._tool_sessions["container"], last_msg + ) raise ValueError("No tool call found") def render_for_completion(self) -> list[int]: return render_for_completion(self.messages) - async def call_search_tool(self, tool_session: Union["ClientSession", - Tool], - last_msg: Message) -> list[Message]: + async def call_search_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: Message + ) -> list[Message]: self.called_tools.add("browser") if isinstance(tool_session, Tool): return await tool_session.get_result(self) @@ -308,15 +324,17 @@ async def call_search_tool(self, tool_session: Union["ClientSession", content = TextContent(text=result_str) author = Author(role=Role.TOOL, name=last_msg.recipient) return [ - Message(author=author, - content=[content], - recipient=Role.ASSISTANT, - channel=last_msg.channel) + Message( + author=author, + content=[content], + recipient=Role.ASSISTANT, + channel=last_msg.channel, + ) ] - async def call_python_tool(self, tool_session: Union["ClientSession", - Tool], - last_msg: Message) -> list[Message]: + async def call_python_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: Message + ) -> list[Message]: self.called_tools.add("python") if isinstance(tool_session, Tool): return await tool_session.get_result(self) @@ -330,45 +348,52 @@ async def call_python_tool(self, tool_session: Union["ClientSession", author = Author(role=Role.TOOL, name="python") return [ - Message(author=author, - content=[content], - channel=last_msg.channel, - recipient=Role.ASSISTANT) + Message( + author=author, + content=[content], + channel=last_msg.channel, + recipient=Role.ASSISTANT, + ) ] - async def init_tool_sessions(self, tool_server: Optional[ToolServer], - exit_stack: AsyncExitStack, request_id: str, - mcp_tools: dict[str, Mcp]): + async def init_tool_sessions( + self, + tool_server: Optional[ToolServer], + exit_stack: AsyncExitStack, + request_id: str, + mcp_tools: dict[str, Mcp], + ): if tool_server: for tool_name in self.available_tools: if tool_name not in self._tool_sessions: tool_type = _map_tool_name_to_tool_type(tool_name) - headers = mcp_tools[ - tool_type].headers if tool_type in 
mcp_tools else None + headers = ( + mcp_tools[tool_type].headers if tool_type in mcp_tools else None + ) tool_session = await exit_stack.enter_async_context( - tool_server.new_session(tool_name, request_id, - headers)) + tool_server.new_session(tool_name, request_id, headers) + ) self._tool_sessions[tool_name] = tool_session exit_stack.push_async_exit(self.cleanup_session) - async def call_container_tool(self, tool_session: Union["ClientSession", - Tool], - last_msg: Message) -> list[Message]: + async def call_container_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: Message + ) -> list[Message]: """ - Call container tool. Expect this to be run in a stateful docker - with command line terminal. - The official container tool would at least - expect the following format: - - for tool name: exec - - args: - { - "cmd":List[str] "command to execute", - "workdir":optional[str] "current working directory", - "env":optional[object/dict] "environment variables", - "session_name":optional[str] "session name", - "timeout":optional[int] "timeout in seconds", - "user":optional[str] "user name", - } + Call container tool. Expect this to be run in a stateful docker + with command line terminal. + The official container tool would at least + expect the following format: + - for tool name: exec + - args: + { + "cmd":List[str] "command to execute", + "workdir":optional[str] "current working directory", + "env":optional[object/dict] "environment variables", + "session_name":optional[str] "session name", + "timeout":optional[int] "timeout in seconds", + "user":optional[str] "user name", + } """ self.called_tools.add("container") if isinstance(tool_session, Tool): @@ -380,10 +405,12 @@ async def call_container_tool(self, tool_session: Union["ClientSession", content = TextContent(text=result_str) author = Author(role=Role.TOOL, name=last_msg.recipient) return [ - Message(author=author, - content=[content], - recipient=Role.ASSISTANT, - channel=last_msg.channel) + Message( + author=author, + content=[content], + recipient=Role.ASSISTANT, + channel=last_msg.channel, + ) ] async def cleanup_session(self, *args, **kwargs) -> None: @@ -391,17 +418,21 @@ async def cleanup_session(self, *args, **kwargs) -> None: async def cleanup_tool_session(tool_session): if not isinstance(tool_session, Tool): - logger.info("Cleaning up tool session for %s", - tool_session._client_info) + logger.info( + "Cleaning up tool session for %s", tool_session._client_info + ) with contextlib.suppress(Exception): await tool_session.call_tool("cleanup_session", {}) - await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool]) - for tool in self.called_tools)) + await asyncio.gather( + *( + cleanup_tool_session(self._tool_sessions[tool]) + for tool in self.called_tools + ) + ) class StreamingHarmonyContext(HarmonyContext): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.last_output = None @@ -415,8 +446,7 @@ def __init__(self, *args, **kwargs): def messages(self) -> list: return self._messages - def append_output(self, output: Union[RequestOutput, - list[Message]]) -> None: + def append_output(self, output: Union[RequestOutput, list[Message]]) -> None: if isinstance(output, RequestOutput): # append_output is called for each output token in streaming case, # so we only want to add the prompt tokens once for each message. 
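A minimal standalone sketch (invented session class; not the real ToolServer API) of the AsyncExitStack pattern reformatted in the hunk above: sessions entered via enter_async_context and callbacks registered via push_async_exit are torn down in reverse order when the stack closes.

import asyncio
import contextlib


class FakeSession:
    async def __aenter__(self):
        print("session opened")
        return self

    async def __aexit__(self, *exc):
        print("session closed")


async def extra_cleanup(*exc_info):
    # push_async_exit callbacks receive the same arguments as __aexit__
    print("extra cleanup ran")


async def main():
    async with contextlib.AsyncExitStack() as stack:
        await stack.enter_async_context(FakeSession())
        stack.push_async_exit(extra_cleanup)
        print("work happens here")
    # exit order is LIFO: extra cleanup runs first, then the session closes


asyncio.run(main())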
@@ -438,11 +468,10 @@ def append_output(self, output: Union[RequestOutput, # Check if the current token is part of reasoning content self._update_num_reasoning_tokens() self.last_tok = tok - if len(self._messages) - self.num_init_messages < len( - self.parser.messages): + if len(self._messages) - self.num_init_messages < len(self.parser.messages): self._messages.extend( - self.parser.messages[len(self._messages) - - self.num_init_messages:]) + self.parser.messages[len(self._messages) - self.num_init_messages :] + ) else: # Handle the case of tool output in direct message format assert len(output) == 1, "Tool output should be a single message" @@ -461,8 +490,7 @@ def is_expecting_start(self) -> bool: return self.parser.state == StreamState.EXPECT_START def is_assistant_action_turn(self) -> bool: - return self.last_tok in self.encoding.stop_tokens_for_assistant_actions( - ) + return self.last_tok in self.encoding.stop_tokens_for_assistant_actions() def render_for_completion(self) -> list[int]: # now this list of tokens as next turn's starting tokens diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index 0c1c9c3192fc..bf6cc3e97c82 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -6,25 +6,46 @@ import datetime import json from collections.abc import Iterable, Sequence -from typing import Literal, Optional, Union - -from openai.types.responses import (ResponseFunctionToolCall, - ResponseOutputItem, ResponseOutputMessage, - ResponseOutputText, ResponseReasoningItem) +from typing import Literal, Union + +from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseOutputItem, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, +) from openai.types.responses.response_function_web_search import ( - ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch) + ActionFind, + ActionOpenPage, + ActionSearch, + ResponseFunctionWebSearch, +) from openai.types.responses.response_reasoning_item import ( - Content as ResponseReasoningTextContent) + Content as ResponseReasoningTextContent, +) from openai.types.responses.tool import Tool -from openai_harmony import (Author, ChannelConfig, Conversation, - DeveloperContent, HarmonyEncodingName, Message, - ReasoningEffort, Role, StreamableParser, - SystemContent, TextContent, ToolDescription, - load_harmony_encoding) +from openai_harmony import ( + Author, + ChannelConfig, + Conversation, + DeveloperContent, + HarmonyEncodingName, + Message, + ReasoningEffort, + Role, + StreamableParser, + SystemContent, + TextContent, + ToolDescription, + load_harmony_encoding, +) from vllm import envs -from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam, - ResponseInputOutputItem) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionToolsParam, + ResponseInputOutputItem, +) from vllm.utils import random_uuid REASONING_EFFORT = { @@ -53,33 +74,33 @@ def has_custom_tools(tool_types: list[str]) -> bool: def get_encoding(): global _harmony_encoding if _harmony_encoding is None: - _harmony_encoding = load_harmony_encoding( - HarmonyEncodingName.HARMONY_GPT_OSS) + _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) return _harmony_encoding def get_system_message( - model_identity: Optional[str] = None, - reasoning_effort: Optional[Literal["high", "medium", "low"]] = None, - start_date: Optional[str] = None, - browser_description: Optional[str] = None, - python_description: Optional[str] = None, - 
container_description: Optional[str] = None, - instructions: Optional[str] = None, + model_identity: str | None = None, + reasoning_effort: Literal["high", "medium", "low"] | None = None, + start_date: str | None = None, + browser_description: str | None = None, + python_description: str | None = None, + container_description: str | None = None, + instructions: str | None = None, with_custom_tools: bool = False, ) -> Message: sys_msg_content = SystemContent.new() if model_identity is not None: sys_msg_content = sys_msg_content.with_model_identity(model_identity) - if (instructions is not None - and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS): + if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: current_identity = sys_msg_content.model_identity - new_identity = (f'{current_identity}\n{instructions}' - if current_identity else instructions) + new_identity = ( + f"{current_identity}\n{instructions}" if current_identity else instructions + ) sys_msg_content = sys_msg_content.with_model_identity(new_identity) if reasoning_effort is not None: sys_msg_content = sys_msg_content.with_reasoning_effort( - REASONING_EFFORT[reasoning_effort]) + REASONING_EFFORT[reasoning_effort] + ) if start_date is None: # NOTE(woosuk): This brings non-determinism in vLLM. Be careful. start_date = datetime.datetime.now().strftime("%Y-%m-%d") @@ -94,7 +115,8 @@ def get_system_message( channel_config = sys_msg_content.channel_config invalid_channel = "commentary" new_config = ChannelConfig.require_channels( - [c for c in channel_config.valid_channels if c != invalid_channel]) + [c for c in channel_config.valid_channels if c != invalid_channel] + ) sys_msg_content = sys_msg_content.with_channel_config(new_config) sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) return sys_msg @@ -115,18 +137,21 @@ def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]): def get_developer_message( - instructions: Optional[str] = None, - tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None, + instructions: str | None = None, + tools: list[Union[Tool, ChatCompletionToolsParam]] | None = None, ) -> Message: dev_msg_content = DeveloperContent.new() - if (instructions is not None - and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS): + if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: dev_msg_content = dev_msg_content.with_instructions(instructions) if tools is not None: function_tools: list[Union[Tool, ChatCompletionToolsParam]] = [] for tool in tools: - if tool.type in ("web_search_preview", "code_interpreter", - "container", "mcp"): + if tool.type in ( + "web_search_preview", + "code_interpreter", + "container", + "mcp", + ): # These are built-in tools that are added to the system message. 
# Adding in MCP for now until we support MCP tools executed # server side @@ -141,7 +166,8 @@ def get_developer_message( create_tool_definition(tool) for tool in function_tools ] dev_msg_content = dev_msg_content.with_function_tools( - function_tool_descriptions) + function_tool_descriptions + ) dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) return dev_msg @@ -152,7 +178,7 @@ def get_user_message(content: str) -> Message: def parse_response_input( response_msg: ResponseInputOutputItem, - prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]] + prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]], ) -> Message: if not isinstance(response_msg, dict): response_msg = response_msg.model_dump() @@ -170,32 +196,32 @@ def parse_response_input( if isinstance(content, str): msg = Message.from_role_and_content(role, text_prefix + content) else: - contents = [ - TextContent(text=text_prefix + c["text"]) for c in content - ] + contents = [TextContent(text=text_prefix + c["text"]) for c in content] msg = Message.from_role_and_contents(role, contents) if role == "assistant": msg = msg.with_channel("final") elif response_msg["type"] == "function_call_output": call_id = response_msg["call_id"] - call_response: Optional[ResponseFunctionToolCall] = None + call_response: ResponseFunctionToolCall | None = None for prev_response in reversed(prev_responses): - if isinstance(prev_response, ResponseFunctionToolCall - ) and prev_response.call_id == call_id: + if ( + isinstance(prev_response, ResponseFunctionToolCall) + and prev_response.call_id == call_id + ): call_response = prev_response break if call_response is None: raise ValueError(f"No call message found for {call_id}") msg = Message.from_author_and_content( Author.new(Role.TOOL, f"functions.{call_response.name}"), - response_msg["output"]) + response_msg["output"], + ) elif response_msg["type"] == "reasoning": content = response_msg["content"] assert len(content) == 1 msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"]) elif response_msg["type"] == "function_call": - msg = Message.from_role_and_content(Role.ASSISTANT, - response_msg["arguments"]) + msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"]) msg = msg.with_channel("commentary") msg = msg.with_recipient(f"functions.{response_msg['name']}") msg = msg.with_content_type("json") @@ -231,8 +257,8 @@ def parse_chat_input(chat_msg) -> list[Message]: name = chat_msg.get("name", "") content = chat_msg.get("content", "") or "" msg = Message.from_author_and_content( - Author.new(Role.TOOL, f"functions.{name}"), - content).with_channel("commentary") + Author.new(Role.TOOL, f"functions.{name}"), content + ).with_channel("commentary") return [msg] # Default: user/assistant/system messages with content @@ -249,7 +275,8 @@ def parse_chat_input(chat_msg) -> list[Message]: def render_for_completion(messages: list[Message]) -> list[int]: conversation = Conversation.from_messages(messages) token_ids = get_encoding().render_conversation_for_completion( - conversation, Role.ASSISTANT) + conversation, Role.ASSISTANT + ) return token_ids @@ -273,14 +300,18 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: # TODO: translate to url properly! 
if recipient == "browser.search": action = ActionSearch( - query=f"cursor:{browser_call.get('query', '')}", type="search") + query=f"cursor:{browser_call.get('query', '')}", type="search" + ) elif recipient == "browser.open": action = ActionOpenPage( - url=f"cursor:{browser_call.get('url', '')}", type="open_page") + url=f"cursor:{browser_call.get('url', '')}", type="open_page" + ) elif recipient == "browser.find": - action = ActionFind(pattern=browser_call["pattern"], - url=f"cursor:{browser_call.get('url', '')}", - type="find") + action = ActionFind( + pattern=browser_call["pattern"], + url=f"cursor:{browser_call.get('url', '')}", + type="find", + ) else: raise ValueError(f"Unknown browser action: {recipient}") web_search_item = ResponseFunctionWebSearch( @@ -297,8 +328,9 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: summary=[], type="reasoning", content=[ - ResponseReasoningTextContent(text=content.text, - type="reasoning_text") + ResponseReasoningTextContent( + text=content.text, type="reasoning_text" + ) ], status=None, ) @@ -316,17 +348,20 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: id=f"fc_{random_id}", ) output_items.append(response_item) - elif recipient is not None and (recipient.startswith("python") - or recipient.startswith("browser") - or recipient.startswith("container")): + elif recipient is not None and ( + recipient.startswith("python") + or recipient.startswith("browser") + or recipient.startswith("container") + ): for content in message.content: reasoning_item = ResponseReasoningItem( id=f"rs_{random_uuid()}", summary=[], type="reasoning", content=[ - ResponseReasoningTextContent(text=content.text, - type="reasoning_text") + ResponseReasoningTextContent( + text=content.text, type="reasoning_text" + ) ], status=None, ) @@ -356,15 +391,13 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: return output_items -def parse_remaining_state( - parser: StreamableParser) -> list[ResponseOutputItem]: +def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]: if not parser.current_content: return [] if parser.current_role != Role.ASSISTANT: return [] current_recipient = parser.current_recipient - if (current_recipient is not None - and current_recipient.startswith("browser.")): + if current_recipient is not None and current_recipient.startswith("browser."): return [] if parser.current_channel == "analysis": @@ -373,8 +406,9 @@ def parse_remaining_state( summary=[], type="reasoning", content=[ - ResponseReasoningTextContent(text=parser.current_content, - type="reasoning_text") + ResponseReasoningTextContent( + text=parser.current_content, type="reasoning_text" + ) ], status=None, ) @@ -415,7 +449,8 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser: def parse_chat_output( - token_ids: Sequence[int]) -> tuple[Optional[str], Optional[str], bool]: + token_ids: Sequence[int], +) -> tuple[str | None, str | None, bool]: parser = parse_output_into_messages(token_ids) output_msgs = parser.messages is_tool_call = False # TODO: update this when tool call is supported @@ -430,7 +465,6 @@ def parse_chat_output( else: reasoning_msg = output_msgs[:-1] final_msg = output_msgs[-1] - reasoning_content = "\n".join( - [msg.content[0].text for msg in reasoning_msg]) + reasoning_content = "\n".join([msg.content[0].text for msg in reasoning_msg]) final_content = final_msg.content[0].text return reasoning_content, final_content, is_tool_call diff --git 
a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 8b2acedf805c..349437363c5b 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -12,8 +12,10 @@ from vllm import envs from vllm.engine.protocol import EngineClient -from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, - H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) +from vllm.entrypoints.constants import ( + H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT, +) from vllm.entrypoints.ssl import SSLCertRefresher from vllm.logger import init_logger from vllm.utils import find_process_using_port @@ -22,10 +24,12 @@ logger = init_logger(__name__) -async def serve_http(app: FastAPI, - sock: Optional[socket.socket], - enable_ssl_refresh: bool = False, - **uvicorn_kwargs: Any): +async def serve_http( + app: FastAPI, + sock: Optional[socket.socket], + enable_ssl_refresh: bool = False, + **uvicorn_kwargs: Any, +): """ Start a FastAPI app using Uvicorn, with support for custom Uvicorn config options. Supports http header limits via h11_max_incomplete_event_size and @@ -39,11 +43,12 @@ async def serve_http(app: FastAPI, if methods is None or path is None: continue - logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + logger.info("Route: %s, Methods: %s", path, ", ".join(methods)) # Extract header limit options if present h11_max_incomplete_event_size = uvicorn_kwargs.pop( - "h11_max_incomplete_event_size", None) + "h11_max_incomplete_event_size", None + ) h11_max_header_count = uvicorn_kwargs.pop("h11_max_header_count", None) # Set safe defaults if not provided @@ -62,16 +67,19 @@ async def serve_http(app: FastAPI, loop = asyncio.get_running_loop() - watchdog_task = loop.create_task( - watchdog_loop(server, app.state.engine_client)) - server_task = loop.create_task( - server.serve(sockets=[sock] if sock else None)) - - ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher( - ssl_context=config.ssl, - key_path=config.ssl_keyfile, - cert_path=config.ssl_certfile, - ca_path=config.ssl_ca_certs) + watchdog_task = loop.create_task(watchdog_loop(server, app.state.engine_client)) + server_task = loop.create_task(server.serve(sockets=[sock] if sock else None)) + + ssl_cert_refresher = ( + None + if not enable_ssl_refresh + else SSLCertRefresher( + ssl_context=config.ssl, + key_path=config.ssl_keyfile, + cert_path=config.ssl_certfile, + ca_path=config.ssl_ca_certs, + ) + ) def signal_handler() -> None: # prevents the uvicorn signal handler to exit early @@ -95,7 +103,10 @@ async def dummy_shutdown() -> None: if process is not None: logger.warning( "port %s is used by process %s launched with command:\n%s", - port, process, " ".join(process.cmdline())) + port, + process, + " ".join(process.cmdline()), + ) logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() finally: @@ -131,14 +142,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: """ VLLM V1 AsyncLLM catches exceptions and returns only two types: EngineGenerateError and EngineDeadError. - + EngineGenerateError is raised by the per request generate() method. This error could be request specific (and therefore recoverable - e.g. if there is an error in input processing). - + EngineDeadError is raised by the background output_handler method. This error is global and therefore not recoverable. - + We register these @app.exception_handlers to return nice responses to the end user if they occur and shut down if needed. 
See https://fastapi.tiangolo.com/tutorial/handling-errors/ diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 705a72f657a2..9afbf8b7e1b8 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -11,47 +11,70 @@ from tqdm.auto import tqdm from typing_extensions import TypeVar -from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, - BeamSearchSequence, - create_sort_beams_key_function) -from vllm.config import (CompilationConfig, ModelDType, - StructuredOutputsConfig, TokenizerMode, is_init_field) -from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides, - PoolerConfig, RunnerOption) -from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - ChatTemplateContentFormatOption, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages, - resolve_chat_template_content_format) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.score_utils import (ScoreContentPartParam, - ScoreMultiModalParam, - _cosine_similarity, - _validate_score_input_lens, - compress_token_type_ids, - get_score_prompt) -# yapf: enable -from vllm.entrypoints.utils import (_validate_truncation_size, - log_non_default_args) -from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt, - TokensPrompt) +from vllm.beam_search import ( + BeamSearchInstance, + BeamSearchOutput, + BeamSearchSequence, + create_sort_beams_key_function, +) +from vllm.config import ( + CompilationConfig, + ModelDType, + StructuredOutputsConfig, + TokenizerMode, + is_init_field, +) +from vllm.engine.arg_utils import ( + ConvertOption, + EngineArgs, + HfOverrides, + PoolerConfig, + RunnerOption, +) +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages, + resolve_chat_template_content_format, +) +from vllm.entrypoints.score_utils import ( + ScoreContentPartParam, + ScoreMultiModalParam, + _cosine_similarity, + _validate_score_input_lens, + compress_token_type_ids, + get_score_prompt, +) +from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args +from vllm.inputs import ( + DataPrompt, + PromptType, + SingletonPrompt, + TextPrompt, + TokensPrompt, +) from vllm.inputs.parse import get_prompt_components from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, - PoolingRequestOutput, RequestOutput, - ScoringRequestOutput) +from vllm.outputs import ( + ClassificationRequestOutput, + EmbeddingRequestOutput, + PoolingRequestOutput, + RequestOutput, + ScoringRequestOutput, +) from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams -from vllm.sampling_params import (BeamSearchParams, RequestOutputKind, - SamplingParams) +from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask -from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, - get_cached_tokenizer, - init_tokenizer_from_configs) +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + MistralTokenizer, + get_cached_tokenizer, +) from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, Device, as_iter, is_list_of from vllm.v1.engine import EngineCoreRequest @@ -90,7 +113,7 @@ class 
LLM: or videos from directories specified by the server file system. This is a security risk. Should only be enabled in trusted environments. - allowed_media_domains: If set, only media URLs that belong to this + allowed_media_domains: If set, only media URLs that belong to this domain can be used for multi-modal inputs. tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. @@ -192,13 +215,14 @@ def __init__( mm_processor_kwargs: Optional[dict[str, Any]] = None, pooler_config: Optional[PoolerConfig] = None, override_pooler_config: Optional[PoolerConfig] = None, - structured_outputs_config: Optional[Union[dict[ - str, Any], StructuredOutputsConfig]] = None, + structured_outputs_config: Optional[ + Union[dict[str, Any], StructuredOutputsConfig] + ] = None, kv_cache_memory_bytes: Optional[int] = None, - compilation_config: Optional[Union[int, dict[str, Any], - CompilationConfig]] = None, - logits_processors: Optional[list[Union[str, - type[LogitsProcessor]]]] = None, + compilation_config: Optional[ + Union[int, dict[str, Any], CompilationConfig] + ] = None, + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None, **kwargs: Any, ) -> None: """LLM constructor.""" @@ -214,21 +238,23 @@ def __init__( kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) if "kv_transfer_config" in kwargs and isinstance( - kwargs["kv_transfer_config"], dict): + kwargs["kv_transfer_config"], dict + ): from vllm.config.kv_transfer import KVTransferConfig + raw_config_dict = kwargs["kv_transfer_config"] try: - kwargs["kv_transfer_config"] = KVTransferConfig( - **raw_config_dict) + kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict) except ValidationError as e: logger.error( "Failed to convert 'kv_transfer_config' dict to " "KVTransferConfig object. Dict: %s. Error: %s", - raw_config_dict, e) + raw_config_dict, + e, + ) # Consider re-raising a more specific vLLM error or ValueError # to provide better context to the user. 
- raise ValueError( - f"Invalid 'kv_transfer_config' provided: {e}") from e + raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e if hf_overrides is None: hf_overrides = {} @@ -236,14 +262,16 @@ def __init__( if compilation_config is not None: if isinstance(compilation_config, int): compilation_config_instance = CompilationConfig( - level=compilation_config) + level=compilation_config + ) elif isinstance(compilation_config, dict): compilation_config_instance = CompilationConfig( **{ k: v for k, v in compilation_config.items() if is_init_field(CompilationConfig, k) - }) + } + ) else: compilation_config_instance = compilation_config else: @@ -256,7 +284,8 @@ def __init__( k: v for k, v in structured_outputs_config.items() if is_init_field(StructuredOutputsConfig, k) - }) + } + ) else: structured_outputs_instance = structured_outputs_config else: @@ -299,7 +328,8 @@ def __init__( # Create the Engine (autoselects V0 vs V1) self.llm_engine = LLMEngine.from_engine_args( - engine_args=engine_args, usage_context=UsageContext.LLM_CLASS) + engine_args=engine_args, usage_context=UsageContext.LLM_CLASS + ) self.engine_class = type(self.llm_engine) self.request_counter = Counter() @@ -313,8 +343,9 @@ def __init__( # Load the Input/Output processor plugin if any io_processor_plugin = self.llm_engine.model_config.io_processor_plugin - self.io_processor = get_io_processor(self.llm_engine.vllm_config, - io_processor_plugin) + self.io_processor = get_io_processor( + self.llm_engine.vllm_config, io_processor_plugin + ) @property def model_config(self): @@ -335,17 +366,15 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: def _get_processor(self) -> Processor: if not hasattr(self, "_processor"): vllm_config = self.llm_engine.vllm_config - if self.model_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = init_tokenizer_from_configs(self.model_config) - self._processor = Processor(vllm_config, tokenizer) + self._processor = Processor(vllm_config) + return self._processor def get_default_sampling_params(self) -> SamplingParams: if self.default_sampling_params is None: self.default_sampling_params = ( - self.llm_engine.model_config.get_diff_sampling_param()) + self.llm_engine.model_config.get_diff_sampling_param() + ) if self.default_sampling_params: return SamplingParams.from_optional(**self.default_sampling_params) return SamplingParams() @@ -353,8 +382,9 @@ def get_default_sampling_params(self) -> SamplingParams: def generate( self, prompts: Union[PromptType, Sequence[PromptType]], - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, + sampling_params: Optional[ + Union[SamplingParams, Sequence[SamplingParams]] + ] = None, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, @@ -398,15 +428,15 @@ def generate( raise ValueError( "LLM.generate() is only supported for generative models. " "Try passing `--runner generate` to use the model as a " - "generative model.") + "generative model." + ) if sampling_params is None: # Use default sampling params. 
sampling_params = self.get_default_sampling_params() # Add any modality specific loras to the corresponding prompts - lora_request = self._get_modality_specific_lora_reqs( - prompts, lora_request) + lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request) self._validate_and_add_requests( prompts=prompts, @@ -420,46 +450,59 @@ def generate( return self.engine_class.validate_outputs(outputs, RequestOutput) def _get_modality_specific_lora_reqs( - self, prompts: Union[PromptType, Sequence[PromptType]], - lora_request: Optional[Union[list[LoRARequest], LoRARequest]]): + self, + prompts: Union[PromptType, Sequence[PromptType]], + lora_request: Optional[Union[list[LoRARequest], LoRARequest]], + ): # Grab the lora config off the vllm config on the engine, # since this is the same for both v0 & v1. lora_config = self.llm_engine.vllm_config.lora_config # If there's no lora config / default_mm_loras, or the model # isn't multimodal, leave the lora as is. - if (lora_config is None - or not self.llm_engine.model_config.is_multimodal_model - or (lora_config and lora_config.default_mm_loras is None)): + if ( + lora_config is None + or not self.llm_engine.model_config.is_multimodal_model + or (lora_config and lora_config.default_mm_loras is None) + ): return lora_request if not isinstance(prompts, Sequence): prompts = [prompts] - optional_loras = ([lora_request] * len(prompts) - if not isinstance(lora_request, Sequence) else - lora_request) + optional_loras = ( + [lora_request] * len(prompts) + if not isinstance(lora_request, Sequence) + else lora_request + ) return [ self._resolve_single_prompt_mm_lora( prompt, opt_lora_req, lora_config.default_mm_loras, - ) for prompt, opt_lora_req in zip(prompts, optional_loras) + ) + for prompt, opt_lora_req in zip(prompts, optional_loras) ] - def _resolve_single_prompt_mm_lora(self, prompt: PromptType, - lora_request: Optional[LoRARequest], - default_mm_loras: Optional[dict[str, - str]]): - if (not default_mm_loras or not isinstance(prompt, dict) - or "multi_modal_data" not in prompt): + def _resolve_single_prompt_mm_lora( + self, + prompt: PromptType, + lora_request: Optional[LoRARequest], + default_mm_loras: Optional[dict[str, str]], + ): + if ( + not default_mm_loras + or not isinstance(prompt, dict) + or "multi_modal_data" not in prompt + ): return lora_request prompt = cast(Union[TextPrompt, TokensPrompt], prompt) - intersection = set(prompt["multi_modal_data"].keys()) \ - .intersection(default_mm_loras.keys()) + intersection = set(prompt["multi_modal_data"].keys()).intersection( + default_mm_loras.keys() + ) if not intersection: return lora_request if len(intersection) > 1: @@ -469,7 +512,9 @@ def _resolve_single_prompt_mm_lora(self, prompt: PromptType, " used by a single prompt consuming several modalities; " " currently we only support one lora per request; as such," " lora(s) registered with modalities: %s" - " will be skipped", intersection) + " will be skipped", + intersection, + ) return lora_request # Build the LoRA request; the ID of the default mm lora is the @@ -485,7 +530,8 @@ def _resolve_single_prompt_mm_lora(self, prompt: PromptType, logger.warning( "A modality with a registered lora and a lora_request " "with a different ID were provided; falling back to the " - "lora_request as we only apply one LoRARequest per prompt") + "lora_request as we only apply one LoRARequest per prompt" + ) return lora_request return LoRARequest( @@ -494,11 +540,13 @@ def _resolve_single_prompt_mm_lora(self, prompt: PromptType, modality_lora_path, ) 
- def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: """ Execute an RPC call on all workers. @@ -543,10 +591,10 @@ def _get_beam_search_lora_requests( prompts: list[Union[TokensPrompt, TextPrompt]], ) -> list[Optional[LoRARequest]]: """Get the optional lora request corresponding to each prompt.""" - if isinstance(lora_request, - Sequence) and len(lora_request) != len(prompts): + if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts): raise ValueError( - "Lora request list should be the same length as the prompts") + "Lora request list should be the same length as the prompts" + ) if lora_request is None or isinstance(lora_request, LoRARequest): return [lora_request] * len(prompts) @@ -581,8 +629,7 @@ def beam_search( ignore_eos = params.ignore_eos length_penalty = params.length_penalty - lora_requests = self._get_beam_search_lora_requests( - lora_request, prompts) + lora_requests = self._get_beam_search_lora_requests(lora_request, prompts) tokenizer = self.get_tokenizer() sort_beams_key = create_sort_beams_key_function( @@ -593,31 +640,28 @@ def beam_search( if use_tqdm and concurrency_limit is not None: logger.warning( "Progress bar is not supported when using concurrency_limit. " - "Disabling progress bar.") + "Disabling progress bar." + ) use_tqdm = False if concurrency_limit is None: concurrency_limit = len(prompts) - def create_tokens_prompt_from_beam( - beam: BeamSearchSequence) -> TokensPrompt: - token_prompt_kwargs: TokensPrompt = { - "prompt_token_ids": beam.tokens - } + def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt: + token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens} if beam.multi_modal_data is not None: token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data if beam.mm_processor_kwargs is not None: - token_prompt_kwargs[ - "mm_processor_kwargs"] = beam.mm_processor_kwargs + token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs return TokensPrompt(**token_prompt_kwargs) # generate 2 * beam_width candidates at each step # following the huggingface transformers implementation # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa - beam_search_params = SamplingParams(logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature) + beam_search_params = SamplingParams( + logprobs=2 * beam_width, max_tokens=1, temperature=temperature + ) instances: list[BeamSearchInstance] = [] for lora_req, prompt in zip(lora_requests, prompts): @@ -626,8 +670,7 @@ def create_tokens_prompt_from_beam( if "multi_modal_data" in prompt: mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"] if "mm_processor_kwargs" in prompt: - mm_kwargs["mm_processor_kwargs"] = prompt[ - "mm_processor_kwargs"] + mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"] if "prompt_token_ids" in prompt: prompt = cast(TokensPrompt, prompt) # Needed for mypy @@ -641,48 +684,58 @@ def create_tokens_prompt_from_beam( lora_request=lora_req, logprobs=None, **mm_kwargs, - ), ) + ), + ) for prompt_start in range(0, len(prompts), concurrency_limit): - instances_batch = instances[prompt_start:prompt_start + - 
concurrency_limit] + instances_batch = instances[prompt_start : prompt_start + concurrency_limit] token_iter = range(max_tokens) if use_tqdm: - token_iter = tqdm(token_iter, - desc="Beam search", - unit="token", - unit_scale=False) + token_iter = tqdm( + token_iter, desc="Beam search", unit="token", unit_scale=False + ) logger.warning( "The progress bar shows the upper bound on token steps and " "may finish early due to stopping conditions. It does not " - "reflect instance-level progress.") + "reflect instance-level progress." + ) for _ in token_iter: all_beams: list[BeamSearchSequence] = list( - sum((instance.beams for instance in instances_batch), [])) + sum((instance.beams for instance in instances_batch), []) + ) pos = [0] + list( itertools.accumulate( - len(instance.beams) for instance in instances_batch)) + len(instance.beams) for instance in instances_batch + ) + ) instance_start_and_end: list[tuple[int, int]] = list( - zip(pos[:-1], pos[1:])) + zip(pos[:-1], pos[1:]) + ) if len(all_beams) == 0: break # create corresponding batch entries for prompt & optional lora prompts_batch, lora_req_batch = zip( - *[(create_tokens_prompt_from_beam(beam), beam.lora_request) - for beam in all_beams]) + *[ + (create_tokens_prompt_from_beam(beam), beam.lora_request) + for beam in all_beams + ] + ) # only runs for one step # we don't need to use tqdm here - output = self.generate(prompts_batch, - sampling_params=beam_search_params, - use_tqdm=False, - lora_request=lora_req_batch) + output = self.generate( + prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False, + lora_request=lora_req_batch, + ) - for (start, end), instance in zip(instance_start_and_end, - instances_batch): + for (start, end), instance in zip( + instance_start_and_end, instances_batch + ): instance_new_beams = [] for i in range(start, end): current_beam = all_beams[i] @@ -697,32 +750,32 @@ def create_tokens_prompt_from_beam( for token_id, logprob_obj in logprobs.items(): new_beam = BeamSearchSequence( tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + - [logprobs], + logprobs=current_beam.logprobs + [logprobs], lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, - multi_modal_data=current_beam. - multi_modal_data, - mm_processor_kwargs=current_beam. 
- mm_processor_kwargs) - - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam.multi_modal_data, + mm_processor_kwargs=current_beam.mm_processor_kwargs, + ) + + if ( + token_id == tokenizer.eos_token_id + and not ignore_eos + ): instance.completed.append(new_beam) else: instance_new_beams.append(new_beam) - sorted_beams = sorted(instance_new_beams, - key=sort_beams_key, - reverse=True) + sorted_beams = sorted( + instance_new_beams, key=sort_beams_key, reverse=True + ) instance.beams = sorted_beams[:beam_width] outputs = [] for instance in instances: instance.completed.extend(instance.beams) - sorted_completed = sorted(instance.completed, - key=sort_beams_key, - reverse=True) + sorted_completed = sorted( + instance.completed, key=sort_beams_key, reverse=True + ) best_beams = sorted_completed[:beam_width] for beam in best_beams: @@ -733,8 +786,9 @@ def create_tokens_prompt_from_beam( def preprocess_chat( self, - messages: Union[list[ChatCompletionMessageParam], - list[list[ChatCompletionMessageParam]]], + messages: Union[ + list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]] + ], chat_template: Optional[str] = None, chat_template_content_format: ChatTemplateContentFormatOption = "auto", add_generation_prompt: bool = True, @@ -758,13 +812,10 @@ def preprocess_chat( # Handle multi and single conversations if is_list_of(messages, list): # messages is list[list[...]] - list_of_messages = cast(list[list[ChatCompletionMessageParam]], - messages) + list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages) else: # messages is list[...] - list_of_messages = [ - cast(list[ChatCompletionMessageParam], messages) - ] + list_of_messages = [cast(list[ChatCompletionMessageParam], messages)] tokenizer = self.get_tokenizer() model_config = self.llm_engine.get_model_config() @@ -812,8 +863,9 @@ def preprocess_chat( ) # Special tokens are already included in chat templates so # should not be added by the tokenizer in this case. 
- prompt_token_ids = tokenizer.encode(prompt_str, - add_special_tokens=False) + prompt_token_ids = tokenizer.encode( + prompt_str, add_special_tokens=False + ) prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) @@ -832,10 +884,10 @@ def preprocess_chat( def chat( self, - messages: Union[list[ChatCompletionMessageParam], - list[list[ChatCompletionMessageParam]]], - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, + messages: Union[ + list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]] + ], + sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, @@ -918,8 +970,7 @@ def chat( def encode( self, prompts: Union[PromptType, Sequence[PromptType], DataPrompt], - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, @@ -963,23 +1014,21 @@ def encode( pooling_task = "encode" if pooling_task is None: - if "embed" in self.supported_tasks: - pooling_task = "embed" - else: - pooling_task = "encode" + pooling_task = "embed" if "embed" in self.supported_tasks else "encode" logger.warning_once( "`LLM.encode` is currently using `pooling_task = %s`.\n" "Please use one of the more specific methods or set the " "task directly when using `LLM.encode`:\n" " - For embeddings, use `LLM.embed(...)` " - "or `pooling_task=\"embed\"`.\n" + 'or `pooling_task="embed"`.\n' " - For classification logits, use `LLM.classify(...)` " - "or `pooling_task=\"classify\"`.\n" + 'or `pooling_task="classify"`.\n' " - For rewards, use `LLM.reward(...)` " - "or `pooling_task=\"reward\"`\n" + 'or `pooling_task="reward"`\n' " - For similarity scores, use `LLM.score(...)`.", - pooling_task) + pooling_task, + ) model_config = self.llm_engine.model_config runner_type = model_config.runner_type @@ -987,11 +1036,11 @@ def encode( raise ValueError( "LLM.encode() is only supported for pooling models. " "Try passing `--runner pooling` to use the model as a " - "pooling model.") + "pooling model." + ) if pooling_task not in self.supported_tasks: - raise ValueError( - f"pooling_task must be one of {self.supported_tasks}.") + raise ValueError(f"pooling_task must be one of {self.supported_tasks}.") if pooling_params is None: # Use default pooling params. @@ -1011,7 +1060,8 @@ def encode( "No IOProcessor plugin installed. Please refer " "to the documentation and to the " "'prithvi_geospatial_mae_io_processor' " - "offline inference example for more details.") + "offline inference example for more details." 
+ ) # Validate the request data is valid for the loaded plugin validated_prompt = self.io_processor.parse_request(prompts) @@ -1029,19 +1079,23 @@ def encode( outputs = self._run_engine(use_tqdm=use_tqdm) model_outputs = self.engine_class.validate_outputs( - outputs, PoolingRequestOutput) + outputs, PoolingRequestOutput + ) if io_processor_prompt: # get the post-processed model outputs assert self.io_processor is not None processed_outputs = self.io_processor.post_process( - model_output=model_outputs) + model_output=model_outputs + ) return [ - PoolingRequestOutput[Any](request_id="", - outputs=processed_outputs, - prompt_token_ids=[], - finished=True) + PoolingRequestOutput[Any]( + request_id="", + outputs=processed_outputs, + prompt_token_ids=[], + finished=True, + ) ] else: return model_outputs @@ -1052,8 +1106,7 @@ def embed( *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[EmbeddingRequestOutput]: """ @@ -1082,7 +1135,8 @@ def embed( if "embed" not in self.supported_tasks: raise ValueError( "Embedding API is not supported by this model. " - "Try converting the model using `--convert embed`.") + "Try converting the model using `--convert embed`." + ) items = self.encode( prompts, @@ -1100,8 +1154,7 @@ def classify( prompts: Union[PromptType, Sequence[PromptType]], *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ClassificationRequestOutput]: """ @@ -1129,7 +1182,8 @@ def classify( if "classify" not in self.supported_tasks: raise ValueError( "Classification API is not supported by this model. " - "Try converting the model using `--convert classify`.") + "Try converting the model using `--convert classify`." 
+ ) items = self.encode( prompts, @@ -1148,8 +1202,7 @@ def reward( *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, + pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[PoolingRequestOutput]: """ @@ -1190,7 +1243,6 @@ def _embedding_score( pooling_params: Optional[PoolingParams] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, ) -> list[ScoringRequestOutput]: - encoded_output: list[PoolingRequestOutput] = self.encode( text_1 + text_2, truncate_prompt_tokens=truncate_prompt_tokens, @@ -1200,20 +1252,17 @@ def _embedding_score( pooling_task="embed", ) - encoded_output_1: list[PoolingRequestOutput] = encoded_output[ - 0:len(text_1)] - encoded_output_2: list[PoolingRequestOutput] = encoded_output[ - len(text_1):] + encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)] + encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :] if len(encoded_output_1) == 1: encoded_output_1 = encoded_output_1 * len(encoded_output_2) - scores = _cosine_similarity(tokenizer=tokenizer, - embed_1=encoded_output_1, - embed_2=encoded_output_2) + scores = _cosine_similarity( + tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2 + ) - items = self.engine_class.validate_outputs(scores, - PoolingRequestOutput) + items = self.engine_class.validate_outputs(scores, PoolingRequestOutput) return [ScoringRequestOutput.from_base(item) for item in items] def _cross_encoding_score( @@ -1229,8 +1278,7 @@ def _cross_encoding_score( model_config = self.llm_engine.model_config if isinstance(tokenizer, MistralTokenizer): - raise ValueError( - "Score API is not supported for Mistral tokenizer") + raise ValueError("Score API is not supported for Mistral tokenizer") if len(data_1) == 1: data_1 = data_1 * len(data_2) @@ -1244,8 +1292,9 @@ def _cross_encoding_score( tokenization_kwargs: dict[str, Any] = {} - _validate_truncation_size(model_config.max_model_len, - truncate_prompt_tokens, tokenization_kwargs) + _validate_truncation_size( + model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs + ) prompts = list[PromptType]() @@ -1262,7 +1311,7 @@ def _cross_encoding_score( tokenization_kwargs=tokenization_kwargs, ) - if (token_type_ids := engine_prompt.pop("token_type_ids", None)): + if token_type_ids := engine_prompt.pop("token_type_ids", None): params = pooling_params.clone() compressed = compress_token_type_ids(token_type_ids) params.extra_kwargs = {"compressed_token_type_ids": compressed} @@ -1280,17 +1329,14 @@ def _cross_encoding_score( ) outputs = self._run_engine(use_tqdm=use_tqdm) - items = self.engine_class.validate_outputs(outputs, - PoolingRequestOutput) + items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput) return [ScoringRequestOutput.from_base(item) for item in items] def score( self, - data_1: Union[SingletonPrompt, Sequence[SingletonPrompt], - ScoreMultiModalParam], - data_2: Union[SingletonPrompt, Sequence[SingletonPrompt], - ScoreMultiModalParam], + data_1: Union[SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam], + data_2: Union[SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam], /, *, truncate_prompt_tokens: Optional[int] = None, @@ -1339,16 +1385,21 @@ def score( raise ValueError( "LLM.score() is only supported for 
pooling models. " "Try passing `--runner pooling` to use the model as a " - "pooling model.") + "pooling model." + ) supported_tasks = self.supported_tasks if all(t not in supported_tasks for t in ("embed", "classify")): - raise ValueError("Score API is not supported by this model. " - "Try converting the model using " - "`--convert embed` or `--convert classify`.") + raise ValueError( + "Score API is not supported by this model. " + "Try converting the model using " + "`--convert embed` or `--convert classify`." + ) - if (model_config.is_cross_encoder - and getattr(model_config.hf_config, "num_labels", 0) != 1): + if ( + model_config.is_cross_encoder + and getattr(model_config.hf_config, "num_labels", 0) != 1 + ): raise ValueError("Score API is only enabled for num_labels == 1.") # the tokenizer for models such as @@ -1358,12 +1409,16 @@ def score( if not model_config.is_multimodal_model: - def check_data_type(data: Union[SingletonPrompt, - Sequence[SingletonPrompt], - ScoreMultiModalParam]): + def check_data_type( + data: Union[ + SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam + ], + ): if isinstance(data, dict) and "content" in data: - raise ValueError("ScoreMultiModalParam is not supported " - f"for {model_config.architecture}") + raise ValueError( + "ScoreMultiModalParam is not supported " + f"for {model_config.architecture}" + ) check_data_type(data_1) check_data_type(data_2) @@ -1371,11 +1426,13 @@ def check_data_type(data: Union[SingletonPrompt, def ensure_str(prompt: SingletonPrompt): if isinstance(prompt, dict): if "multi_modal_data" in prompt: - raise ValueError("Multi-modal prompt is not " - "supported for scoring") + raise ValueError( + "Multi-modal prompt is not supported for scoring" + ) elif "prompt_token_ids" in prompt: prompt = tokenizer.decode( - cast(TokensPrompt, prompt)["prompt_token_ids"]) + cast(TokensPrompt, prompt)["prompt_token_ids"] + ) elif "prompt" in prompt: prompt = cast(TextPrompt, prompt)["prompt"] assert type(prompt) is str @@ -1413,7 +1470,8 @@ def ensure_str(prompt: SingletonPrompt): truncate_prompt_tokens, use_tqdm, pooling_params, - lora_request) + lora_request, + ) else: return self._embedding_score( tokenizer, @@ -1422,7 +1480,8 @@ def ensure_str(prompt: SingletonPrompt): truncate_prompt_tokens, use_tqdm, pooling_params, - lora_request) + lora_request, + ) def start_profile(self) -> None: self.llm_engine.start_profile() @@ -1484,8 +1543,12 @@ def get_metrics(self) -> list["Metric"]: def _validate_and_add_requests( self, prompts: Union[PromptType, Sequence[PromptType], DataPrompt], - params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, - Sequence[PoolingParams]], + params: Union[ + SamplingParams, + Sequence[SamplingParams], + PoolingParams, + Sequence[PoolingParams], + ], *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], @@ -1497,14 +1560,13 @@ def _validate_and_add_requests( num_requests = len(prompts) if isinstance(params, Sequence) and len(params) != num_requests: - raise ValueError("The lengths of prompts and params " - "must be the same.") - if isinstance(lora_request, - Sequence) and len(lora_request) != num_requests: - raise ValueError("The lengths of prompts and lora_request " - "must be the same.") - - for sp in params if isinstance(params, Sequence) else (params, ): + raise ValueError("The lengths of prompts and params must be the same.") + if isinstance(lora_request, Sequence) and len(lora_request) != num_requests: + raise 
ValueError( + "The lengths of prompts and lora_request must be the same." + ) + + for sp in params if isinstance(params, Sequence) else (params,): if isinstance(sp, SamplingParams): # We only care about the final output sp.output_kind = RequestOutputKind.FINAL_ONLY @@ -1516,24 +1578,24 @@ def _validate_and_add_requests( it = tqdm_func(it, desc="Adding requests") for i, prompt in enumerate(it): - if isinstance(prompt, dict): self._validate_mm_data_and_uuids( - prompt.get("multi_modal_data"), - prompt.get("multi_modal_uuids")) + prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids") + ) self._add_request( prompt, params[i] if isinstance(params, Sequence) else params, - lora_request=lora_request[i] if isinstance( - lora_request, Sequence) else lora_request, + lora_request=lora_request[i] + if isinstance(lora_request, Sequence) + else lora_request, priority=priority[i] if priority else 0, ) def _validate_mm_data_and_uuids( - self, - multi_modal_data: Optional[Any], # MultiModalDataDict - multi_modal_uuids: Optional[Any], # MultiModalUUIDDict + self, + multi_modal_data: Optional[Any], # MultiModalDataDict + multi_modal_uuids: Optional[Any], # MultiModalUUIDDict ): """ Validate that if any multi-modal data is skipped (i.e. None), @@ -1546,24 +1608,37 @@ def _validate_mm_data_and_uuids( if isinstance(data, list): for i, d in enumerate(data): if d is None: - if multi_modal_uuids is None or modality not in multi_modal_uuids or multi_modal_uuids[ # noqa: E501 - modality] is None: + if ( + multi_modal_uuids is None + or modality not in multi_modal_uuids + or multi_modal_uuids[ # noqa: E501 + modality + ] + is None + ): raise ValueError( f"Multi-modal data for {modality} is None " - f"but UUID is not provided") + f"but UUID is not provided" + ) else: - if len( - multi_modal_uuids[modality] - ) <= i or multi_modal_uuids[modality][i] is None: + if ( + len(multi_modal_uuids[modality]) <= i + or multi_modal_uuids[modality][i] is None + ): raise ValueError( f"Multi-modal data for {modality} is None " - f"but UUID is not provided") + f"but UUID is not provided" + ) else: - if data is None and (multi_modal_uuids is None - or modality not in multi_modal_uuids - or multi_modal_uuids[modality] is None): - raise ValueError(f"Multi-modal data for {modality} is None" - f" but UUID is not provided") + if data is None and ( + multi_modal_uuids is None + or modality not in multi_modal_uuids + or multi_modal_uuids[modality] is None + ): + raise ValueError( + f"Multi-modal data for {modality} is None" + f" but UUID is not provided" + ) def _process_inputs( self, @@ -1576,9 +1651,11 @@ def _process_inputs( ) -> tuple[EngineCoreRequest, dict[str, Any]]: """Use the Processor to process inputs for LLMEngine.""" tokenization_kwargs: dict[str, Any] = {} - _validate_truncation_size(self.model_config.max_model_len, - params.truncate_prompt_tokens, - tokenization_kwargs) + _validate_truncation_size( + self.model_config.max_model_len, + params.truncate_prompt_tokens, + tokenization_kwargs, + ) processor = self._get_processor() engine_request = processor.process_inputs( @@ -1620,9 +1697,7 @@ def _add_request( ) def _run_engine( - self, - *, - use_tqdm: Union[bool, Callable[..., tqdm]] = True + self, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True ) -> list[Union[RequestOutput, PoolingRequestOutput]]: # Initialize tqdm. if use_tqdm: @@ -1632,8 +1707,7 @@ def _run_engine( total=num_requests, desc="Processed prompts", dynamic_ncols=True, - postfix=(f"est. 
speed input: {0:.2f} toks/s, " - f"output: {0:.2f} toks/s"), + postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"), ) # Run the engine. @@ -1653,12 +1727,13 @@ def _run_engine( total_in_toks += len(output.prompt_token_ids) * n in_spd = total_in_toks / pbar.format_dict["elapsed"] total_out_toks += sum( - len(stp.token_ids) for stp in output.outputs) - out_spd = (total_out_toks / - pbar.format_dict["elapsed"]) + len(stp.token_ids) for stp in output.outputs + ) + out_spd = total_out_toks / pbar.format_dict["elapsed"] pbar.postfix = ( f"est. speed input: {in_spd:.2f} toks/s, " - f"output: {out_spd:.2f} toks/s") + f"output: {out_spd:.2f} toks/s" + ) pbar.update(n) else: pbar.update(1) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 152d11c84ea0..96a84668e92b 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -15,7 +15,6 @@ class RequestLogger: - def __init__(self, *, max_log_len: Optional[int]) -> None: self.max_log_len = max_log_len @@ -25,8 +24,7 @@ def log_inputs( prompt: Optional[str], prompt_token_ids: Optional[list[int]], prompt_embeds: Optional[torch.Tensor], - params: Optional[Union[SamplingParams, PoolingParams, - BeamSearchParams]], + params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], ) -> None: max_log_len = self.max_log_len @@ -41,9 +39,14 @@ def log_inputs( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " "prompt_embeds shape: %s, " - "lora_request: %s.", request_id, prompt, params, prompt_token_ids, + "lora_request: %s.", + request_id, + prompt, + params, + prompt_token_ids, prompt_embeds.shape if prompt_embeds is not None else None, - lora_request) + lora_request, + ) def log_outputs( self, @@ -65,8 +68,7 @@ def log_outputs( stream_info = "" if is_streaming: - stream_info = (" (streaming delta)" - if delta else " (streaming complete)") + stream_info = " (streaming delta)" if delta else " (streaming complete)" logger.info( "Generated response %s%s: output: %r, " diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 15844d3162fe..889326dee749 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -25,8 +25,7 @@ import pydantic import regex as re import uvloop -from fastapi import (APIRouter, Depends, FastAPI, Form, HTTPException, Query, - Request) +from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse @@ -42,69 +41,83 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import (load_chat_template, - resolve_hf_chat_template, - resolve_mistral_chat_template) +from vllm.entrypoints.chat_utils import ( + load_chat_template, + resolve_hf_chat_template, + resolve_mistral_chat_template, +) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.cli_args import (make_arg_parser, - validate_parsed_serve_args) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionResponse, - ClassificationRequest, - ClassificationResponse, - CompletionRequest, - CompletionResponse, - 
DetokenizeRequest, - DetokenizeResponse, - EmbeddingRequest, - EmbeddingResponse, ErrorInfo, - ErrorResponse, - IOProcessorResponse, - LoadLoRAAdapterRequest, - PoolingRequest, PoolingResponse, - RerankRequest, RerankResponse, - ResponsesRequest, - ResponsesResponse, ScoreRequest, - ScoreResponse, - StreamingResponsesResponse, - TokenizeRequest, - TokenizeResponse, - TranscriptionRequest, - TranscriptionResponse, - TranslationRequest, - TranslationResponse, - UnloadLoRAAdapterRequest) -# yapf: enable +from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ClassificationRequest, + ClassificationResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + DetokenizeResponse, + EmbeddingRequest, + EmbeddingResponse, + ErrorInfo, + ErrorResponse, + IOProcessorResponse, + LoadLoRAAdapterRequest, + PoolingRequest, + PoolingResponse, + RerankRequest, + RerankResponse, + ResponsesRequest, + ResponsesResponse, + ScoreRequest, + ScoreResponse, + StreamingResponsesResponse, + TokenizeRequest, + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, + TranslationRequest, + TranslationResponse, + UnloadLoRAAdapterRequest, +) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_classification import ( - ServingClassification) +from vllm.entrypoints.openai.serving_classification import ServingClassification from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - LoRAModulePath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_models import ( + BaseModelPath, + LoRAModulePath, + OpenAIServingModels, +) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses from vllm.entrypoints.openai.serving_score import ServingScores -from vllm.entrypoints.openai.serving_tokenization import ( - OpenAIServingTokenization) +from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization from vllm.entrypoints.openai.serving_transcription import ( - OpenAIServingTranscription, OpenAIServingTranslation) + OpenAIServingTranscription, + OpenAIServingTranslation, +) from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer, - ToolServer) -from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, - log_non_default_args, with_cancellation) +from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer +from vllm.entrypoints.utils import ( + cli_env_setup, + load_aware_call, + log_non_default_args, + with_cancellation, +) from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import (Device, FlexibleArgumentParser, decorate_logs, - is_valid_ipv6_address, set_ulimit) +from vllm.utils import ( + Device, + FlexibleArgumentParser, + decorate_logs, + is_valid_ipv6_address, + set_ulimit, +) from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.metrics.prometheus import get_prometheus_registry 
 from vllm.version import __version__ as VLLM_VERSION
@@ -112,7 +125,7 @@
 prometheus_multiproc_dir: tempfile.TemporaryDirectory

 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
-logger = init_logger('vllm.entrypoints.openai.api_server')
+logger = init_logger("vllm.entrypoints.openai.api_server")

 _running_tasks: set[asyncio.Task] = set()

@@ -156,12 +169,11 @@ async def build_async_engine_client(
     disable_frontend_multiprocessing: Optional[bool] = None,
     client_config: Optional[dict[str, Any]] = None,
 ) -> AsyncIterator[EngineClient]:
-
     if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
         # The executor is expected to be mp.
         # Pre-import heavy modules in the forkserver process
         logger.debug("Setup forkserver with pre-imports")
-        multiprocessing.set_start_method('forkserver')
+        multiprocessing.set_start_method("forkserver")
         multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
         forkserver.ensure_running()
         logger.debug("Forkserver setup complete!")
@@ -174,14 +186,13 @@ async def build_async_engine_client(
         engine_args._api_process_rank = client_config.get("client_index", 0)

     if disable_frontend_multiprocessing is None:
-        disable_frontend_multiprocessing = bool(
-            args.disable_frontend_multiprocessing)
+        disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing)

     async with build_async_engine_client_from_engine_args(
-            engine_args,
-            usage_context=usage_context,
-            disable_frontend_multiprocessing=disable_frontend_multiprocessing,
-            client_config=client_config,
+        engine_args,
+        usage_context=usage_context,
+        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
+        client_config=client_config,
     ) as engine:
         yield engine

@@ -211,9 +222,11 @@ async def build_async_engine_client_from_engine_args(
     if disable_frontend_multiprocessing:
         logger.warning(
             "V1 is enabled, but got --disable-frontend-multiprocessing. "
-            "To disable frontend multiprocessing, set VLLM_USE_V1=0.")
+            "To disable frontend multiprocessing, set VLLM_USE_V1=0."
+ ) from vllm.v1.engine.async_llm import AsyncLLM + async_llm: Optional[AsyncLLM] = None # Don't mutate the input client_config @@ -229,7 +242,8 @@ async def build_async_engine_client_from_engine_args( disable_log_stats=engine_args.disable_log_stats, client_addresses=client_config, client_count=client_count, - client_index=client_index) + client_index=client_index, + ) # Don't keep the dummy data in memory await async_llm.reset_mm_cache() @@ -244,9 +258,9 @@ async def validate_json_request(raw_request: Request): content_type = raw_request.headers.get("content-type", "").lower() media_type = content_type.split(";", maxsplit=1)[0] if media_type != "application/json": - raise RequestValidationError(errors=[ - "Unsupported Media Type: Only 'application/json' is allowed" - ]) + raise RequestValidationError( + errors=["Unsupported Media Type: Only 'application/json' is allowed"] + ) router = APIRouter() @@ -368,8 +382,7 @@ async def get_server_load_metrics(request: Request): # - /rerank # - /v1/rerank # - /v2/rerank - return JSONResponse( - content={'server_load': request.app.state.server_load_metrics}) + return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) @router.get("/ping", response_class=Response) @@ -379,22 +392,16 @@ async def ping(raw_request: Request) -> Response: return await health(raw_request) -@router.post("/tokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_IMPLEMENTED.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/tokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def tokenize(request: TokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -402,34 +409,33 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): try: generator = await handler.create_tokenize(request, raw_request) except NotImplementedError as e: - raise HTTPException(status_code=HTTPStatus.NOT_IMPLEMENTED.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e) + ) from e except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, TokenizeResponse): return JSONResponse(content=generator.model_dump()) assert_never(generator) -@router.post("/detokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/detokenize", + dependencies=[Depends(validate_json_request)], + responses={ + 
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def detokenize(request: DetokenizeRequest, raw_request: Request): handler = tokenization(raw_request) @@ -439,12 +445,14 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): except OverflowError as e: raise RequestValidationError(errors=[str(e)]) from e except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, DetokenizeResponse): return JSONResponse(content=generator.model_dump()) @@ -453,15 +461,18 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): def maybe_register_tokenizer_info_endpoint(args): """Conditionally register the tokenizer info endpoint if enabled.""" - if getattr(args, 'enable_tokenizer_info_endpoint', False): + if getattr(args, "enable_tokenizer_info_endpoint", False): @router.get("/tokenizer_info") async def get_tokenizer_info(raw_request: Request): """Get comprehensive tokenizer information.""" result = await tokenization(raw_request).get_tokenizer_info() - return JSONResponse(content=result.model_dump(), - status_code=result.error.code if isinstance( - result, ErrorResponse) else 200) + return JSONResponse( + content=result.model_dump(), + status_code=result.error.code + if isinstance(result, ErrorResponse) + else 200, + ) @router.get("/v1/models") @@ -479,55 +490,52 @@ async def show_version(): async def _convert_stream_to_sse_events( - generator: AsyncGenerator[StreamingResponsesResponse, None] + generator: AsyncGenerator[StreamingResponsesResponse, None], ) -> AsyncGenerator[str, None]: """Convert the generator to a stream of events in SSE format""" async for event in generator: - event_type = getattr(event, 'type', 'unknown') + event_type = getattr(event, "type", "unknown") # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format - event_data = (f"event: {event_type}\n" - f"data: {event.model_dump_json(indent=None)}\n\n") + event_data = ( + f"event: {event_type}\ndata: {event.model_dump_json(indent=None)}\n\n" + ) yield event_data -@router.post("/v1/responses", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/responses", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def create_responses(request: ResponsesRequest, raw_request: Request): handler = responses(raw_request) if handler is None: return base(raw_request).create_error_response( - 
message="The model does not support Responses API") + message="The model does not support Responses API" + ) try: generator = await handler.create_responses(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, ResponsesResponse): return JSONResponse(content=generator.model_dump()) - return StreamingResponse(content=_convert_stream_to_sse_events(generator), - media_type="text/event-stream") + return StreamingResponse( + content=_convert_stream_to_sse_events(generator), media_type="text/event-stream" + ) @router.get("/v1/responses/{response_id}") @@ -540,7 +548,8 @@ async def retrieve_responses( handler = responses(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Responses API") + message="The model does not support Responses API" + ) try: response = await handler.retrieve_responses( @@ -549,16 +558,19 @@ async def retrieve_responses( stream=stream, ) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.error.code) + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) elif isinstance(response, ResponsesResponse): return JSONResponse(content=response.model_dump()) - return StreamingResponse(content=_convert_stream_to_sse_events(response), - media_type="text/event-stream") + return StreamingResponse( + content=_convert_stream_to_sse_events(response), media_type="text/event-stream" + ) @router.post("/v1/responses/{response_id}/cancel") @@ -566,54 +578,51 @@ async def cancel_responses(response_id: str, raw_request: Request): handler = responses(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Responses API") + message="The model does not support Responses API" + ) try: response = await handler.cancel_responses(response_id) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.error.code) + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) return JSONResponse(content=response.model_dump()) -@router.post("/v1/chat/completions", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - } - }) +@router.post( + "/v1/chat/completions", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: 
{"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call -async def create_chat_completion(request: ChatCompletionRequest, - raw_request: Request): +async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): handler = chat(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Chat Completions API") + message="The model does not support Chat Completions API" + ) try: generator = await handler.create_chat_completion(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, ChatCompletionResponse): return JSONResponse(content=generator.model_dump()) @@ -621,108 +630,106 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/completions", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.NOT_FOUND.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/completions", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_completion(request: CompletionRequest, raw_request: Request): handler = completion(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Completions API") + message="The model does not support Completions API" + ) try: generator = await handler.create_completion(request, raw_request) except OverflowError as e: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e) + ) from e except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, CompletionResponse): return JSONResponse(content=generator.model_dump()) return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/embeddings", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - 
"model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/embeddings", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Embeddings API") + message="The model does not support Embeddings API" + ) try: generator = await handler.create_embedding(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, EmbeddingResponse): return JSONResponse(content=generator.model_dump()) assert_never(generator) -@router.post("/pooling", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/pooling", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_pooling(request: PoolingRequest, raw_request: Request): handler = pooling(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Pooling API") + message="The model does not support Pooling API" + ) try: generator = await handler.create_pooling(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, (PoolingResponse, IOProcessorResponse)): return JSONResponse(content=generator.model_dump()) @@ -732,21 +739,23 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): @router.post("/classify", dependencies=[Depends(validate_json_request)]) @with_cancellation @load_aware_call -async def create_classify(request: ClassificationRequest, - raw_request: Request): +async def create_classify(request: ClassificationRequest, raw_request: Request): handler = classify(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Classification API") + message="The model does not support Classification API" + ) try: generator = await handler.create_classify(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + 
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, ClassificationResponse): return JSONResponse(content=generator.model_dump()) @@ -754,96 +763,90 @@ async def create_classify(request: ClassificationRequest, assert_never(generator) -@router.post("/score", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_score(request: ScoreRequest, raw_request: Request): handler = score(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Score API") + message="The model does not support Score API" + ) try: generator = await handler.create_score(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, ScoreResponse): return JSONResponse(content=generator.model_dump()) assert_never(generator) -@router.post("/v1/score", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/score", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def create_score_v1(request: ScoreRequest, raw_request: Request): logger.warning( "To indicate that Score API is not part of standard OpenAI API, we " - "have moved it to `/score`. Please update your client accordingly.") + "have moved it to `/score`. Please update your client accordingly." 
+ ) return await create_score(request, raw_request) -@router.post("/v1/audio/transcriptions", - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.UNPROCESSABLE_ENTITY.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/audio/transcriptions", + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call -async def create_transcriptions(raw_request: Request, - request: Annotated[TranscriptionRequest, - Form()]): +async def create_transcriptions( + raw_request: Request, request: Annotated[TranscriptionRequest, Form()] +): handler = transcription(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Transcriptions API") + message="The model does not support Transcriptions API" + ) audio_data = await request.file.read() try: - generator = await handler.create_transcription(audio_data, request, - raw_request) + generator = await handler.create_transcription(audio_data, request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, TranscriptionResponse): return JSONResponse(content=generator.model_dump()) @@ -851,44 +854,38 @@ async def create_transcriptions(raw_request: Request, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/v1/audio/translations", - responses={ - HTTPStatus.OK.value: { - "content": { - "text/event-stream": {} - } - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.UNPROCESSABLE_ENTITY.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/audio/translations", + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call -async def create_translations(request: Annotated[TranslationRequest, - Form()], - raw_request: Request): +async def create_translations( + request: Annotated[TranslationRequest, Form()], raw_request: Request +): handler = translation(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Translations API") + message="The model does not support Translations API" + ) audio_data = await request.file.read() try: - generator = await handler.create_translation(audio_data, request, - raw_request) + generator = await handler.create_translation(audio_data, request, raw_request) except Exception as e: - raise 
HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, TranslationResponse): return JSONResponse(content=generator.model_dump()) @@ -896,90 +893,88 @@ async def create_translations(request: Annotated[TranslationRequest, return StreamingResponse(content=generator, media_type="text/event-stream") -@router.post("/rerank", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation @load_aware_call async def do_rerank(request: RerankRequest, raw_request: Request): handler = rerank(raw_request) if handler is None: return base(raw_request).create_error_response( - message="The model does not support Rerank (Score) API") + message="The model does not support Rerank (Score) API" + ) try: generator = await handler.do_rerank(request, raw_request) except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - detail=str(e)) from e + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), - status_code=generator.error.code) + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) elif isinstance(generator, RerankResponse): return JSONResponse(content=generator.model_dump()) assert_never(generator) -@router.post("/v1/rerank", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v1/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def do_rerank_v1(request: RerankRequest, raw_request: Request): logger.warning_once( "To indicate that the rerank API is not part of the standard OpenAI" " API, we have located it at `/rerank`. Please update your client " - "accordingly. (Note: Conforms to JinaAI rerank API)") + "accordingly. 
(Note: Conforms to JinaAI rerank API)" + ) return await do_rerank(request, raw_request) -@router.post("/v2/rerank", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/v2/rerank", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) @with_cancellation async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) if envs.VLLM_SERVER_DEV_MODE: - logger.warning("SECURITY WARNING: Development endpoints are enabled! " - "This should NOT be used in production!") + logger.warning( + "SECURITY WARNING: Development endpoints are enabled! " + "This should NOT be used in production!" + ) PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig) @router.get("/server_info") async def show_server_info( raw_request: Request, - config_format: Annotated[Literal["text", "json"], - Query()] = "text", + config_format: Annotated[Literal["text", "json"], Query()] = "text", ): vllm_config: VllmConfig = raw_request.app.state.vllm_config server_info = { - "vllm_config": - str(vllm_config) - if config_format == "text" else PydanticVllmConfig.dump_python( - vllm_config, mode="json", fallback=str) + "vllm_config": str(vllm_config) + if config_format == "text" + else PydanticVllmConfig.dump_python(vllm_config, mode="json", fallback=str) # fallback=str is needed to handle e.g. torch.dtype } return JSONResponse(content=server_info) @@ -1030,19 +1025,24 @@ async def collective_rpc(raw_request: Request): try: body = await raw_request.json() except json.JSONDecodeError as e: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}") from e + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e method = body.get("method") if method is None: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, - detail="Missing 'method' in request body") + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'method' in request body", + ) # For security reason, only serialized string args/kwargs are passed. # User-defined `method` is responsible for deserialization if needed. 
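    # Illustrative request body for this handler (the shape follows from the
    # fields read just below; the method name is hypothetical):
    #
    #     {
    #         "method": "some_worker_method",
    #         "args": ["serialized arg"],
    #         "kwargs": {"key": "value"},
    #         "timeout": 5.0
    #     }
    #
    # A None result maps to an empty 200 response; otherwise the collected
    # results are returned as {"results": [...]}.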
args: list[str] = body.get("args", []) kwargs: dict[str, str] = body.get("kwargs", {}) timeout: Optional[float] = body.get("timeout") results = await engine_client(raw_request).collective_rpc( - method=method, timeout=timeout, args=tuple(args), kwargs=kwargs) + method=method, timeout=timeout, args=tuple(args), kwargs=kwargs + ) if results is None: return Response(status_code=200) response: list[Any] = [] @@ -1054,45 +1054,39 @@ async def collective_rpc(raw_request: Request): return JSONResponse(content={"results": response}) -@router.post("/scale_elastic_ep", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: { - "model": dict - }, - HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.REQUEST_TIMEOUT.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/scale_elastic_ep", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) async def scale_elastic_ep(raw_request: Request): try: body = await raw_request.json() except json.JSONDecodeError as e: - raise HTTPException(status_code=400, - detail="Invalid JSON format") from e # noqa: B904 + raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 new_data_parallel_size = body.get("new_data_parallel_size") drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes if new_data_parallel_size is None: - raise HTTPException(status_code=400, - detail="new_data_parallel_size is required") + raise HTTPException( + status_code=400, detail="new_data_parallel_size is required" + ) - if not isinstance(new_data_parallel_size, - int) or new_data_parallel_size <= 0: + if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0: raise HTTPException( - status_code=400, - detail="new_data_parallel_size must be a positive integer") + status_code=400, detail="new_data_parallel_size must be a positive integer" + ) if not isinstance(drain_timeout, int) or drain_timeout <= 0: - raise HTTPException(status_code=400, - detail="drain_timeout must be a positive integer") + raise HTTPException( + status_code=400, detail="drain_timeout must be a positive integer" + ) # Set scaling flag to prevent new requests global _scaling_elastic_ep @@ -1100,15 +1094,17 @@ async def scale_elastic_ep(raw_request: Request): client = engine_client(raw_request) try: await client.scale_elastic_ep(new_data_parallel_size, drain_timeout) - return JSONResponse({ - "message": - f"Scaled to {new_data_parallel_size} " - "data parallel engines", - }) + return JSONResponse( + { + "message": f"Scaled to {new_data_parallel_size} data parallel engines", + } + ) except TimeoutError as e: - raise HTTPException(status_code=408, - detail="Scale failed due to request drain timeout " - f"after {drain_timeout} seconds") from e + raise HTTPException( + status_code=408, + detail="Scale failed due to request drain timeout " + f"after {drain_timeout} seconds", + ) from e except Exception as e: logger.error("Scale failed: %s", e) raise HTTPException(status_code=500, detail="Scale failed") from e @@ -1145,31 +1141,29 @@ async def is_scaling_elastic_ep(raw_request: Request): ] -@router.post("/invocations", - dependencies=[Depends(validate_json_request)], - responses={ - 
HTTPStatus.BAD_REQUEST.value: { - "model": ErrorResponse - }, - HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: { - "model": ErrorResponse - }, - HTTPStatus.INTERNAL_SERVER_ERROR.value: { - "model": ErrorResponse - }, - }) +@router.post( + "/invocations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) async def invocations(raw_request: Request): """For SageMaker, routes requests based on the request type.""" try: body = await raw_request.json() except json.JSONDecodeError as e: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}") from e + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}" + ) from e - valid_endpoints = [(validator, endpoint) - for validator, (get_handler, - endpoint) in INVOCATION_VALIDATORS - if get_handler(raw_request) is not None] + valid_endpoints = [ + (validator, endpoint) + for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS + if get_handler(raw_request) is not None + ] for request_validator, endpoint in valid_endpoints: try: @@ -1183,8 +1177,7 @@ async def invocations(raw_request: Request): t.__name__ if isinstance(t := validator._type, type) else str(t) for validator, _ in valid_endpoints ] - msg = ("Cannot find suitable handler for request. " - f"Expected one of: {type_names}") + msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" res = base(raw_request).create_error_response(message=msg) return JSONResponse(content=res.model_dump(), status_code=res.error.code) @@ -1192,7 +1185,8 @@ async def invocations(raw_request: Request): if envs.VLLM_TORCH_PROFILER_DIR: logger.warning( "Torch Profiler is enabled in the API server. This should ONLY be " - "used for local development!") + "used for local development!" + ) @router.post("/start_profile") async def start_profile(raw_request: Request): @@ -1212,29 +1206,32 @@ async def stop_profile(raw_request: Request): if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: logger.warning( "LoRA dynamic loading & unloading is enabled in the API server. " - "This should ONLY be used for local development!") + "This should ONLY be used for local development!" 
+ ) - @router.post("/v1/load_lora_adapter", - dependencies=[Depends(validate_json_request)]) - async def load_lora_adapter(request: LoadLoRAAdapterRequest, - raw_request: Request): + @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) + async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): handler = models(raw_request) response = await handler.load_lora_adapter(request) if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.error.code) + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) return Response(status_code=200, content=response) - @router.post("/v1/unload_lora_adapter", - dependencies=[Depends(validate_json_request)]) - async def unload_lora_adapter(request: UnloadLoRAAdapterRequest, - raw_request: Request): + @router.post( + "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)] + ) + async def unload_lora_adapter( + request: UnloadLoRAAdapterRequest, raw_request: Request + ): handler = models(raw_request) response = await handler.unload_lora_adapter(request) if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.error.code) + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) return Response(status_code=200, content=response) @@ -1246,8 +1243,9 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]: with open(log_config_file) as f: return json.load(f) except Exception as e: - logger.warning("Failed to load log config from file %s: error %s", - log_config_file, e) + logger.warning( + "Failed to load log config from file %s: error %s", log_config_file, e + ) return None @@ -1265,9 +1263,7 @@ class AuthenticationMiddleware: def __init__(self, app: ASGIApp, tokens: list[str]) -> None: self.app = app - self.api_tokens = [ - hashlib.sha256(t.encode("utf-8")).digest() for t in tokens - ] + self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens] def verify_token(self, headers: Headers) -> bool: authorization_header_value = headers.get("Authorization") @@ -1286,10 +1282,8 @@ def verify_token(self, headers: Headers) -> bool: return token_match - def __call__(self, scope: Scope, receive: Receive, - send: Send) -> Awaitable[None]: - if scope["type"] not in ("http", - "websocket") or scope["method"] == "OPTIONS": + def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: + if scope["type"] not in ("http", "websocket") or scope["method"] == "OPTIONS": # scope["type"] can be "lifespan" or "startup" for example, # in which case we don't need to do anything return self.app(scope, receive, send) @@ -1298,8 +1292,7 @@ def __call__(self, scope: Scope, receive: Receive, headers = Headers(scope=scope) # Type narrow to satisfy mypy. 
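        # Only paths under /v1 are gated by the check below; other routes,
        # such as the /ping handler shown earlier, pass through untouched.
        # A client would typically send the key as an Authorization header,
        # e.g. "Authorization: Bearer <api-key>"; the exact prefix handling
        # lives in verify_token above, and the "Bearer " form is assumed here
        # only for illustration.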
if url_path.startswith("/v1") and not self.verify_token(headers): - response = JSONResponse(content={"error": "Unauthorized"}, - status_code=401) + response = JSONResponse(content={"error": "Unauthorized"}, status_code=401) return response(scope, receive, send) return self.app(scope, receive, send) @@ -1314,8 +1307,7 @@ class XRequestIdMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app - def __call__(self, scope: Scope, receive: Receive, - send: Send) -> Awaitable[None]: + def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: if scope["type"] not in ("http", "websocket"): return self.app(scope, receive, send) @@ -1329,8 +1321,7 @@ async def send_with_request_id(message: Message) -> None: """ if message["type"] == "http.response.start": response_headers = MutableHeaders(raw=message["headers"]) - request_id = request_headers.get("X-Request-Id", - uuid.uuid4().hex) + request_id = request_headers.get("X-Request-Id", uuid.uuid4().hex) response_headers.append("X-Request-Id", request_id) await send(message) @@ -1353,8 +1344,7 @@ class ScalingMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app - def __call__(self, scope: Scope, receive: Receive, - send: Send) -> Awaitable[None]: + def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: if scope["type"] != "http": return self.app(scope, receive, send) @@ -1362,11 +1352,12 @@ def __call__(self, scope: Scope, receive: Receive, global _scaling_elastic_ep if _scaling_elastic_ep: # Return 503 Service Unavailable response - response = JSONResponse(content={ - "error": - "The model is currently scaling. Please try again later." - }, - status_code=503) + response = JSONResponse( + content={ + "error": "The model is currently scaling. Please try again later." 
+ }, + status_code=503, + ) return response(scope, receive, send) return self.app(scope, receive, send) @@ -1376,28 +1367,27 @@ def _extract_content_from_chunk(chunk_data: dict) -> str: """Extract content from a streaming response chunk.""" try: from vllm.entrypoints.openai.protocol import ( - ChatCompletionStreamResponse, CompletionStreamResponse) + ChatCompletionStreamResponse, + CompletionStreamResponse, + ) # Try using Completion types for type-safe parsing - if chunk_data.get('object') == 'chat.completion.chunk': - chat_response = ChatCompletionStreamResponse.model_validate( - chunk_data) + if chunk_data.get("object") == "chat.completion.chunk": + chat_response = ChatCompletionStreamResponse.model_validate(chunk_data) if chat_response.choices and chat_response.choices[0].delta.content: return chat_response.choices[0].delta.content - elif chunk_data.get('object') == 'text_completion': - completion_response = CompletionStreamResponse.model_validate( - chunk_data) - if completion_response.choices and completion_response.choices[ - 0].text: + elif chunk_data.get("object") == "text_completion": + completion_response = CompletionStreamResponse.model_validate(chunk_data) + if completion_response.choices and completion_response.choices[0].text: return completion_response.choices[0].text except pydantic.ValidationError: # Fallback to manual parsing - if 'choices' in chunk_data and chunk_data['choices']: - choice = chunk_data['choices'][0] - if 'delta' in choice and choice['delta'].get('content'): - return choice['delta']['content'] - elif choice.get('text'): - return choice['text'] + if "choices" in chunk_data and chunk_data["choices"]: + choice = chunk_data["choices"][0] + if "delta" in choice and choice["delta"].get("content"): + return choice["delta"]["content"] + elif choice.get("text"): + return choice["text"] return "" @@ -1413,7 +1403,7 @@ def decode_chunk(self, chunk: bytes) -> list[dict]: import json try: - chunk_str = chunk.decode('utf-8') + chunk_str = chunk.decode("utf-8") except UnicodeDecodeError: # Skip malformed chunks return [] @@ -1422,18 +1412,18 @@ def decode_chunk(self, chunk: bytes) -> list[dict]: events = [] # Process complete lines - while '\n' in self.buffer: - line, self.buffer = self.buffer.split('\n', 1) - line = line.rstrip('\r') # Handle CRLF + while "\n" in self.buffer: + line, self.buffer = self.buffer.split("\n", 1) + line = line.rstrip("\r") # Handle CRLF - if line.startswith('data: '): + if line.startswith("data: "): data_str = line[6:].strip() - if data_str == '[DONE]': - events.append({'type': 'done'}) + if data_str == "[DONE]": + events.append({"type": "done"}) elif data_str: try: event_data = json.loads(data_str) - events.append({'type': 'data', 'data': event_data}) + events.append({"type": "data", "data": event_data}) except json.JSONDecodeError: # Skip malformed JSON continue @@ -1451,7 +1441,7 @@ def add_content(self, content: str) -> None: def get_complete_content(self) -> str: """Get the complete buffered content.""" - return ''.join(self.content_buffer) + return "".join(self.content_buffer) def _log_streaming_response(response, response_body: list) -> None: @@ -1472,10 +1462,10 @@ def buffered_iterator(): events = sse_decoder.decode_chunk(chunk) for event in events: - if event['type'] == 'data': - content = sse_decoder.extract_content(event['data']) + if event["type"] == "data": + content = sse_decoder.extract_content(event["data"]) sse_decoder.add_content(content) - elif event['type'] == 'done': + elif event["type"] == "done": # Log complete 
content when done full_content = sse_decoder.get_complete_content() if full_content: @@ -1484,19 +1474,20 @@ def buffered_iterator(): full_content = full_content[:2048] + "" "...[truncated]" logger.info( - "response_body={streaming_complete: " \ + "response_body={streaming_complete: " "content='%s', chunks=%d}", - full_content, chunk_count) + full_content, + chunk_count, + ) else: logger.info( - "response_body={streaming_complete: " \ - "no_content, chunks=%d}", - chunk_count) + "response_body={streaming_complete: no_content, chunks=%d}", + chunk_count, + ) return response.body_iterator = iterate_in_threadpool(buffered_iterator()) - logger.info("response_body={streaming_started: chunks=%d}", - len(response_body)) + logger.info("response_body={streaming_started: chunks=%d}", len(response_body)) def _log_non_streaming_response(response_body: list) -> None: @@ -1510,10 +1501,9 @@ def _log_non_streaming_response(response_body: list) -> None: def build_app(args: Namespace) -> FastAPI: if args.disable_fastapi_docs: - app = FastAPI(openapi_url=None, - docs_url=None, - redoc_url=None, - lifespan=lifespan) + app = FastAPI( + openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan + ) else: app = FastAPI(lifespan=lifespan) app.include_router(router) @@ -1532,14 +1522,16 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(HTTPException) async def http_exception_handler(_: Request, exc: HTTPException): err = ErrorResponse( - error=ErrorInfo(message=exc.detail, - type=HTTPStatus(exc.status_code).phrase, - code=exc.status_code)) + error=ErrorInfo( + message=exc.detail, + type=HTTPStatus(exc.status_code).phrase, + code=exc.status_code, + ) + ) return JSONResponse(err.model_dump(), status_code=exc.status_code) @app.exception_handler(RequestValidationError) - async def validation_exception_handler(_: Request, - exc: RequestValidationError): + async def validation_exception_handler(_: Request, exc: RequestValidationError): exc_str = str(exc) errors_str = str(exc.errors()) @@ -1548,11 +1540,14 @@ async def validation_exception_handler(_: Request, else: message = exc_str - err = ErrorResponse(error=ErrorInfo(message=message, - type=HTTPStatus.BAD_REQUEST.phrase, - code=HTTPStatus.BAD_REQUEST)) - return JSONResponse(err.model_dump(), - status_code=HTTPStatus.BAD_REQUEST) + err = ErrorResponse( + error=ErrorInfo( + message=message, + type=HTTPStatus.BAD_REQUEST.phrase, + code=HTTPStatus.BAD_REQUEST, + ) + ) + return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]: @@ -1565,16 +1560,16 @@ async def validation_exception_handler(_: Request, app.add_middleware(ScalingMiddleware) if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE: - logger.warning("CAUTION: Enabling log response in the API Server. " - "This can include sensitive information and should be " - "avoided in production.") + logger.warning( + "CAUTION: Enabling log response in the API Server. " + "This can include sensitive information and should be " + "avoided in production." 
+ ) @app.middleware("http") async def log_response(request: Request, call_next): response = await call_next(request) - response_body = [ - section async for section in response.body_iterator - ] + response_body = [section async for section in response.body_iterator] response.body_iterator = iterate_in_threadpool(iter(response_body)) # Check if this is a streaming response by looking at content-type content_type = response.headers.get("content-type", "") @@ -1597,8 +1592,9 @@ async def log_response(request: Request, call_next): elif inspect.iscoroutinefunction(imported): app.middleware("http")(imported) else: - raise ValueError(f"Invalid middleware {middleware}. " - f"Must be a function or a class.") + raise ValueError( + f"Invalid middleware {middleware}. Must be a function or a class." + ) return app @@ -1620,8 +1616,7 @@ async def init_app_state( request_logger = None base_model_paths = [ - BaseModelPath(name=name, model_path=args.model) - for name in served_model_names + BaseModelPath(name=name, model_path=args.model) for name in served_model_names ] state.engine_client = engine_client @@ -1641,7 +1636,8 @@ async def init_app_state( if isinstance(tokenizer, MistralTokenizer): # The warning is logged in resolve_mistral_chat_template. resolved_chat_template = resolve_mistral_chat_template( - chat_template=resolved_chat_template) + chat_template=resolved_chat_template + ) else: hf_chat_template = resolve_hf_chat_template( tokenizer=tokenizer, @@ -1655,7 +1651,9 @@ async def init_app_state( "Using supplied chat template: %s\n" "It is different from official chat template '%s'. " "This discrepancy may lead to performance degradation.", - resolved_chat_template, args.model) + resolved_chat_template, + args.model, + ) if args.tool_server == "demo": tool_server: Optional[ToolServer] = DemoToolServer() @@ -1668,8 +1666,11 @@ async def init_app_state( tool_server = None # Merge default_mm_loras into the static lora_modules - default_mm_loras = (vllm_config.lora_config.default_mm_loras - if vllm_config.lora_config is not None else {}) + default_mm_loras = ( + vllm_config.lora_config.default_mm_loras + if vllm_config.lora_config is not None + else {} + ) lora_modules = args.lora_modules if default_mm_loras: @@ -1677,7 +1678,8 @@ async def init_app_state( LoRAModulePath( name=modality, path=lora_path, - ) for modality, lora_path in default_mm_loras.items() + ) + for modality, lora_path in default_mm_loras.items() ] if args.lora_modules is None: lora_modules = default_mm_lora_paths @@ -1691,85 +1693,114 @@ async def init_app_state( lora_modules=lora_modules, ) await state.openai_serving_models.init_static_loras() - state.openai_serving_responses = OpenAIServingResponses( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - tool_parser=args.tool_call_parser, - tool_server=tool_server, - reasoning_parser=args.structured_outputs_config.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - enable_log_outputs=args.enable_log_outputs, - log_error_stack=args.log_error_stack, - ) if "generate" in supported_tasks else None - state.openai_serving_chat = OpenAIServingChat( - engine_client, - model_config, - state.openai_serving_models, - args.response_role, - 
request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - trust_request_chat_template=args.trust_request_chat_template, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_auto_tools=args.enable_auto_tool_choice, - exclude_tools_when_tool_choice_none=args. - exclude_tools_when_tool_choice_none, - tool_parser=args.tool_call_parser, - reasoning_parser=args.structured_outputs_config.reasoning_parser, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - enable_log_outputs=args.enable_log_outputs, - log_error_stack=args.log_error_stack, - ) if "generate" in supported_tasks else None - state.openai_serving_completion = OpenAIServingCompletion( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - return_tokens_as_token_ids=args.return_tokens_as_token_ids, - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - enable_force_include_usage=args.enable_force_include_usage, - log_error_stack=args.log_error_stack, - ) if "generate" in supported_tasks else None - state.openai_serving_pooling = OpenAIServingPooling( - engine_client, - vllm_config, - state.openai_serving_models, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - log_error_stack=args.log_error_stack, - ) if "encode" in supported_tasks else None - state.openai_serving_embedding = OpenAIServingEmbedding( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - chat_template=resolved_chat_template, - chat_template_content_format=args.chat_template_content_format, - log_error_stack=args.log_error_stack, - ) if "embed" in supported_tasks else None - state.openai_serving_classification = ServingClassification( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - log_error_stack=args.log_error_stack, - ) if "classify" in supported_tasks else None - state.openai_serving_scores = ServingScores( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - log_error_stack=args.log_error_stack, - ) if ("embed" in supported_tasks or "score" in supported_tasks) else None + state.openai_serving_responses = ( + OpenAIServingResponses( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + tool_server=tool_server, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_chat = ( + OpenAIServingChat( + engine_client, + model_config, + state.openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + 
return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none, + tool_parser=args.tool_call_parser, + reasoning_parser=args.structured_outputs_config.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_completion = ( + OpenAIServingCompletion( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + log_error_stack=args.log_error_stack, + ) + if "generate" in supported_tasks + else None + ) + state.openai_serving_pooling = ( + OpenAIServingPooling( + engine_client, + vllm_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if "encode" in supported_tasks + else None + ) + state.openai_serving_embedding = ( + OpenAIServingEmbedding( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, + log_error_stack=args.log_error_stack, + ) + if "embed" in supported_tasks + else None + ) + state.openai_serving_classification = ( + ServingClassification( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + ) + if "classify" in supported_tasks + else None + ) + state.openai_serving_scores = ( + ServingScores( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + ) + if ("embed" in supported_tasks or "score" in supported_tasks) + else None + ) state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, @@ -1777,22 +1808,31 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + trust_request_chat_template=args.trust_request_chat_template, log_error_stack=args.log_error_stack, ) - state.openai_serving_transcription = OpenAIServingTranscription( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - log_error_stack=args.log_error_stack, - ) if "transcription" in supported_tasks else None - state.openai_serving_translation = OpenAIServingTranslation( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger, - log_error_stack=args.log_error_stack, - ) if "transcription" in supported_tasks else None + state.openai_serving_transcription = ( + OpenAIServingTranscription( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + ) + if "transcription" in supported_tasks + else None + ) + 
state.openai_serving_translation = ( + OpenAIServingTranslation( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + log_error_stack=args.log_error_stack, + ) + if "transcription" in supported_tasks + else None + ) state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 @@ -1819,17 +1859,20 @@ def create_server_unix_socket(path: str) -> socket.socket: def validate_api_server_args(args): valid_tool_parses = ToolParserManager.tool_parsers.keys() - if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valid_tool_parses: - raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " - f"(chose from {{ {','.join(valid_tool_parses)} }})") + if args.enable_auto_tool_choice and args.tool_call_parser not in valid_tool_parses: + raise KeyError( + f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valid_tool_parses)} }})" + ) valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() - if ((reasoning_parser := args.structured_outputs_config.reasoning_parser) - and reasoning_parser not in valid_reasoning_parses): + if ( + reasoning_parser := args.structured_outputs_config.reasoning_parser + ) and reasoning_parser not in valid_reasoning_parses: raise KeyError( f"invalid reasoning parser: {reasoning_parser} " - f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + f"(chose from {{ {','.join(valid_reasoning_parses)} }})" + ) def setup_server(args): @@ -1868,8 +1911,7 @@ def signal_handler(*_) -> None: else: addr, port = sock_addr is_ssl = args.ssl_keyfile and args.ssl_certfile - host_part = f"[{addr}]" if is_valid_ipv6_address( - addr) else addr or "0.0.0.0" + host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0" listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}" return listen_address, sock @@ -1884,11 +1926,9 @@ async def run_server(args, **uvicorn_kwargs) -> None: await run_server_worker(listen_address, sock, args, **uvicorn_kwargs) -async def run_server_worker(listen_address, - sock, - args, - client_config=None, - **uvicorn_kwargs) -> None: +async def run_server_worker( + listen_address, sock, args, client_config=None, **uvicorn_kwargs +) -> None: """Run a single API server worker.""" if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: @@ -1897,11 +1937,11 @@ async def run_server_worker(listen_address, # Load logging config for uvicorn if specified log_config = load_log_config(args.log_config_file) if log_config is not None: - uvicorn_kwargs['log_config'] = log_config + uvicorn_kwargs["log_config"] = log_config async with build_async_engine_client( - args, - client_config=client_config, + args, + client_config=client_config, ) as engine_client: maybe_register_tokenizer_info_endpoint(args) app = build_app(args) @@ -1909,9 +1949,11 @@ async def run_server_worker(listen_address, vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) - logger.info("Starting vLLM API server %d on %s", - vllm_config.parallel_config._api_process_rank, - listen_address) + logger.info( + "Starting vLLM API server %d on %s", + vllm_config.parallel_config._api_process_rank, + listen_address, + ) shutdown_task = await serve_http( app, sock=sock, @@ -1945,7 +1987,8 @@ async def run_server_worker(listen_address, # entrypoints. 
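The setup_server change above collapses the IPv6-aware host formatting onto a single line without changing behaviour. For reference, that behaviour can be reproduced with the standard-library ipaddress module; is_valid_ipv6_address is vLLM's own helper, replaced here by a local _is_ipv6 so the sketch stays self-contained.

import ipaddress
from typing import Optional

def format_listen_address(addr: Optional[str], port: int, is_ssl: bool) -> str:
    def _is_ipv6(value: str) -> bool:
        try:
            return isinstance(ipaddress.ip_address(value), ipaddress.IPv6Address)
        except ValueError:
            return False

    # IPv6 literals must be bracketed inside URLs; empty/None hosts fall back
    # to the wildcard address, matching the reformatted expression above.
    host_part = f"[{addr}]" if addr and _is_ipv6(addr) else addr or "0.0.0.0"
    return f"http{'s' if is_ssl else ''}://{host_part}:{port}"

assert format_listen_address("::1", 8000, False) == "http://[::1]:8000"
assert format_listen_address(None, 8000, True) == "https://0.0.0.0:8000"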
cli_env_setup() parser = FlexibleArgumentParser( - description="vLLM OpenAI-Compatible RESTful API server.") + description="vLLM OpenAI-Compatible RESTful API server." + ) parser = make_arg_parser(parser) args = parser.parse_args() validate_parsed_serve_args(args) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a306c2bb7cb5..1f16646db63b 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -18,10 +18,14 @@ import vllm.envs as envs from vllm.config import config from vllm.engine.arg_utils import AsyncEngineArgs, optional_type -from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, - validate_chat_template) -from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT, - H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT) +from vllm.entrypoints.chat_utils import ( + ChatTemplateContentFormatOption, + validate_chat_template, +) +from vllm.entrypoints.constants import ( + H11_MAX_HEADER_COUNT_DEFAULT, + H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT, +) from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger @@ -31,7 +35,6 @@ class LoRAParserAction(argparse.Action): - def __call__( self, parser: argparse.ArgumentParser, @@ -57,8 +60,7 @@ def __call__( lora = LoRAModulePath(**lora_dict) lora_list.append(lora) except json.JSONDecodeError: - parser.error( - f"Invalid JSON format for --lora-modules: {item}") + parser.error(f"Invalid JSON format for --lora-modules: {item}") except TypeError as e: parser.error( f"Invalid fields for --lora-modules: {item} - {str(e)}" @@ -70,14 +72,16 @@ def __call__( @dataclass class FrontendArgs: """Arguments for the OpenAI-compatible frontend server.""" + host: Optional[str] = None """Host name.""" port: int = 8000 """Port number.""" uds: Optional[str] = None """Unix domain socket path. If set, host and port arguments are ignored.""" - uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical", - "trace"] = "info" + uvicorn_log_level: Literal[ + "debug", "info", "warning", "error", "critical", "trace" + ] = "info" """Log level for uvicorn.""" disable_uvicorn_access_log: bool = False """Disable uvicorn access log.""" @@ -218,7 +222,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: valid_tool_parsers = list(ToolParserManager.tool_parsers.keys()) parsers_str = ",".join(valid_tool_parsers) frontend_kwargs["tool_call_parser"]["metavar"] = ( - f"{{{parsers_str}}} or name registered in --tool-parser-plugin") + f"{{{parsers_str}}} or name registered in --tool-parser-plugin" + ) frontend_group = parser.add_argument_group( title="Frontend", @@ -238,27 +243,32 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: register all arguments instead of manually enumerating them here. This avoids code duplication and keeps the argument definitions in one place. """ - parser.add_argument("model_tag", - type=str, - nargs="?", - help="The model tag to serve " - "(optional if specified in config)") + parser.add_argument( + "model_tag", + type=str, + nargs="?", + help="The model tag to serve (optional if specified in config)", + ) parser.add_argument( "--headless", action="store_true", default=False, help="Run in headless mode. 
See multi-node data parallel " - "documentation for more details.") - parser.add_argument("--api-server-count", - "-asc", - type=int, - default=1, - help="How many API server processes to run.") + "documentation for more details.", + ) + parser.add_argument( + "--api-server-count", + "-asc", + type=int, + default=1, + help="How many API server processes to run.", + ) parser.add_argument( "--config", help="Read CLI options from a config file. " "Must be a YAML with the following options: " - "https://docs.vllm.ai/en/latest/configuration/serve_args.html") + "https://docs.vllm.ai/en/latest/configuration/serve_args.html", + ) parser = FrontendArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser) @@ -275,14 +285,13 @@ def validate_parsed_serve_args(args: argparse.Namespace): # Enable auto tool needs a tool call parser to be valid if args.enable_auto_tool_choice and not args.tool_call_parser: - raise TypeError("Error: --enable-auto-tool-choice requires " - "--tool-call-parser") + raise TypeError("Error: --enable-auto-tool-choice requires --tool-call-parser") if args.enable_log_outputs and not args.enable_log_requests: - raise TypeError("Error: --enable-log-outputs requires " - "--enable-log-requests") + raise TypeError("Error: --enable-log-outputs requires --enable-log-requests") def create_parser_for_docs() -> FlexibleArgumentParser: parser_for_docs = FlexibleArgumentParser( - prog="-m vllm.entrypoints.openai.api_server") + prog="-m vllm.entrypoints.openai.api_server" + ) return make_arg_parser(parser_for_docs) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index 29d72256cf70..2ea9fbf386ba 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -19,12 +19,11 @@ def __init__(self, allowed_ids: Iterable[int]): self.allowed_ids: Optional[list[int]] = list(allowed_ids) self.mask: Optional[torch.Tensor] = None - def __call__(self, token_ids: list[int], - logits: torch.Tensor) -> torch.Tensor: + def __call__(self, token_ids: list[int], logits: torch.Tensor) -> torch.Tensor: if self.mask is None: - self.mask = torch.ones((logits.shape[-1], ), - dtype=torch.bool, - device=logits.device) + self.mask = torch.ones( + (logits.shape[-1],), dtype=torch.bool, device=logits.device + ) self.mask[self.allowed_ids] = False self.allowed_ids = None logits.masked_fill_(self.mask, float("-inf")) @@ -39,8 +38,7 @@ def _get_allowed_token_ids_logits_processor( if not allowed_token_ids: raise ValueError("Empty allowed_token_ids provided") if not all(0 <= tid < vocab_size for tid in allowed_token_ids): - raise ValueError("allowed_token_ids contains " - "out-of-vocab token id") + raise ValueError("allowed_token_ids contains out-of-vocab token id") return AllowedTokenIdsLogitsProcessor(allowed_token_ids) @@ -71,20 +69,25 @@ def get_logits_processors( except ValueError as exc: raise ValueError( "Found token_id in logit_bias that is not " - "an integer or string representing an integer") from exc + "an integer or string representing an integer" + ) from exc # Check if token_id is within the vocab size for token_id, bias in clamped_logit_bias.items(): if token_id < 0 or token_id >= len(tokenizer): - raise ValueError(f"token_id {token_id} in logit_bias contains " - "out-of-vocab token id") + raise ValueError( + f"token_id {token_id} in logit_bias contains out-of-vocab token id" + ) logits_processors.append( - partial(logit_bias_logits_processor, clamped_logit_bias)) + 
partial(logit_bias_logits_processor, clamped_logit_bias) + ) if allowed_token_ids is not None: logits_processors.append( _get_allowed_token_ids_logits_processor( - frozenset(allowed_token_ids), len(tokenizer))) + frozenset(allowed_token_ids), len(tokenizer) + ) + ) return logits_processors diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 9d51372887c2..a92e8372b304 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,70 +6,80 @@ import json import time from http import HTTPStatus -from typing import (Annotated, Any, ClassVar, Generic, Literal, Optional, - TypeVar, Union) +from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, Union import regex as re import torch from fastapi import HTTPException, UploadFile -# yapf: disable from openai.types.chat.chat_completion_audio import ( - ChatCompletionAudio as OpenAIChatCompletionAudio) -from openai.types.chat.chat_completion_message import ( - Annotation as OpenAIAnnotation) + ChatCompletionAudio as OpenAIChatCompletionAudio, +) +from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation from openai.types.responses import ( ResponseCodeInterpreterCallCodeDeltaEvent, ResponseCodeInterpreterCallCodeDoneEvent, ResponseCodeInterpreterCallCompletedEvent, ResponseCodeInterpreterCallInProgressEvent, - ResponseCodeInterpreterCallInterpretingEvent) -from openai.types.responses import ( - ResponseCompletedEvent as OpenAIResponseCompletedEvent) -from openai.types.responses import (ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent) + ResponseCodeInterpreterCallInterpretingEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseFunctionToolCall, + ResponseInputItemParam, + ResponseOutputItem, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponsePrompt, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseStatus, + ResponseWebSearchCallCompletedEvent, + ResponseWebSearchCallInProgressEvent, + ResponseWebSearchCallSearchingEvent, +) from openai.types.responses import ( - ResponseCreatedEvent as OpenAIResponseCreatedEvent) -from openai.types.responses import ResponseFunctionToolCall + ResponseCompletedEvent as OpenAIResponseCompletedEvent, +) +from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent from openai.types.responses import ( - ResponseInProgressEvent as OpenAIResponseInProgressEvent) -from openai.types.responses import (ResponseInputItemParam, ResponseOutputItem, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponsePrompt, ResponseReasoningItem, - ResponseReasoningTextDeltaEvent, - ResponseReasoningTextDoneEvent, - ResponseStatus, - ResponseWebSearchCallCompletedEvent, - ResponseWebSearchCallInProgressEvent, - ResponseWebSearchCallSearchingEvent) -# yapf: enable + ResponseInProgressEvent as OpenAIResponseInProgressEvent, +) from openai.types.responses.response_reasoning_item import ( - Content as ResponseReasoningTextContent) + Content as ResponseReasoningTextContent, +) # Backward compatibility for OpenAI client versions try: # For older openai versions (< 1.100.0) from openai.types.responses import ResponseTextConfig except ImportError: # For newer openai versions (>= 1.100.0) - from openai.types.responses import (ResponseFormatTextConfig as - ResponseTextConfig) + from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig from 
openai.types.responses.response import IncompleteDetails, ToolChoice from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning -from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, - ValidationInfo, field_validator, model_validator) +from pydantic import ( + BaseModel, + ConfigDict, + Field, + TypeAdapter, + ValidationInfo, + field_validator, + model_validator, +) from typing_extensions import TypeAlias from vllm import envs -from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - make_tool_call_id) -from vllm.entrypoints.score_utils import (ScoreContentPartParam, - ScoreMultiModalParam) +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id +from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.pooling_params import PoolingParams -from vllm.sampling_params import (BeamSearchParams, RequestOutputKind, - SamplingParams, StructuredOutputsParams) +from vllm.sampling_params import ( + BeamSearchParams, + RequestOutputKind, + SamplingParams, + StructuredOutputsParams, +) from vllm.utils import random_uuid, resolve_obj_by_qualname logger = init_logger(__name__) @@ -103,8 +113,7 @@ def __log_extra_fields__(cls, data, handler): # Compare against both field names and aliases if any(k not in field_names for k in data): logger.warning( - "The following fields were present in the request " - "but ignored: %s", + "The following fields were present in the request but ignored: %s", data.keys() - field_names, ) return result @@ -173,7 +182,7 @@ class JsonSchemaResponseFormat(OpenAIBaseModel): description: Optional[str] = None # schema is the field in openai but that causes conflicts with pydantic so # instead use json_schema with an alias - json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema') + json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema") strict: Optional[bool] = None @@ -181,8 +190,9 @@ class StructuralTag(OpenAIBaseModel): begin: str # schema is the field, but that causes conflicts with pydantic so # instead use structural_tag_schema with an alias - structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, - alias="schema") + structural_tag_schema: Optional[dict[str, Any]] = Field( + default=None, alias="schema" + ) end: str @@ -239,18 +249,19 @@ class LogitsProcessorConstructor(BaseModel): LogitsProcessors = list[Union[str, LogitsProcessorConstructor]] -def get_logits_processors(processors: Optional[LogitsProcessors], - pattern: Optional[str]) -> Optional[list[Any]]: +def get_logits_processors( + processors: Optional[LogitsProcessors], pattern: Optional[str] +) -> Optional[list[Any]]: if processors and pattern: logits_processors = [] for processor in processors: - qualname = processor if isinstance(processor, - str) else processor.qualname + qualname = processor if isinstance(processor, str) else processor.qualname if not re.match(pattern, qualname): raise ValueError( f"Logits processor '{qualname}' is not allowed by this " "server. See --logits-processor-pattern engine argument " - "for more information.") + "for more information." 
+ ) try: logits_processor = resolve_obj_by_qualname(qualname) except Exception as e: @@ -258,37 +269,41 @@ def get_logits_processors(processors: Optional[LogitsProcessors], f"Logits processor '{qualname}' could not be resolved: {e}" ) from e if isinstance(processor, LogitsProcessorConstructor): - logits_processor = logits_processor(*processor.args or [], - **processor.kwargs or {}) + logits_processor = logits_processor( + *processor.args or [], **processor.kwargs or {} + ) logits_processors.append(logits_processor) return logits_processors elif processors: raise ValueError( "The `logits_processors` argument is not supported by this " "server. See --logits-processor-pattern engine argument " - "for more information.") + "for more information." + ) return None -ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, - ResponseReasoningItem, - ResponseFunctionToolCall] +ResponseInputOutputItem: TypeAlias = Union[ + ResponseInputItemParam, ResponseReasoningItem, ResponseFunctionToolCall +] class ResponsesRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/responses/create background: Optional[bool] = False - include: Optional[list[ - Literal[ - "code_interpreter_call.outputs", - "computer_call_output.output.image_url", - "file_search_call.results", - "message.input_image.image_url", - "message.output_text.logprobs", - "reasoning.encrypted_content", - ], - ]] = None + include: Optional[ + list[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ], + ] + ] = None input: Union[str, list[ResponseInputOutputItem]] instructions: Optional[str] = None max_output_tokens: Optional[int] = None @@ -299,8 +314,7 @@ class ResponsesRequest(OpenAIBaseModel): previous_response_id: Optional[str] = None prompt: Optional[ResponsePrompt] = None reasoning: Optional[Reasoning] = None - service_tier: Literal["auto", "default", "flex", "scale", - "priority"] = "auto" + service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto" store: Optional[bool] = True stream: Optional[bool] = False temperature: Optional[float] = None @@ -318,7 +332,8 @@ class ResponsesRequest(OpenAIBaseModel): description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -329,7 +344,8 @@ class ResponsesRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) cache_salt: Optional[str] = Field( default=None, @@ -339,14 +355,18 @@ class ResponsesRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0." 
+ ), + ) enable_response_messages: bool = Field( default=False, description=( "Dictates whether or not to return messages as part of the " "response object. Currently only supported for non-streaming " - "non-background and gpt-oss only. ")) + "non-background and gpt-oss only. " + ), + ) # --8<-- [end:responses-extra-params] _DEFAULT_SAMPLING_PARAMS = { @@ -367,20 +387,25 @@ def to_sampling_params( default_sampling_params = default_sampling_params or {} if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output structured_outputs = None if self.text is not None and self.text.format is not None: response_format = self.text.format - if (response_format.type == "json_schema" - and response_format.schema_ is not None): + if ( + response_format.type == "json_schema" + and response_format.schema_ is not None + ): structured_outputs = StructuredOutputsParams( - json=response_format.schema_) + json=response_format.schema_ + ) elif response_format.type == "json_object": raise NotImplementedError("json_object is not supported") @@ -389,11 +414,11 @@ def to_sampling_params( temperature=temperature, top_p=top_p, max_tokens=max_tokens, - logprobs=self.top_logprobs - if self.is_include_output_logprobs() else None, + logprobs=self.top_logprobs if self.is_include_output_logprobs() else None, stop_token_ids=stop_token_ids, - output_kind=(RequestOutputKind.DELTA - if self.stream else RequestOutputKind.FINAL_ONLY), + output_kind=( + RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY + ), structured_outputs=structured_outputs, ) @@ -401,17 +426,17 @@ def is_include_output_logprobs(self) -> bool: """Check if the request includes output logprobs.""" if self.include is None: return False - return isinstance( - self.include, - list) and "message.output_text.logprobs" in self.include + return ( + isinstance(self.include, list) + and "message.output_text.logprobs" in self.include + ) @model_validator(mode="before") def validate_background(cls, data): if not data.get("background"): return data if not data.get("store", True): - raise ValueError( - "background can only be used when `store` is true") + raise ValueError("background can only be used when `store` is true") return data @model_validator(mode="before") @@ -426,11 +451,12 @@ def check_cache_salt_support(cls, data): if not envs.VLLM_USE_V1: raise ValueError( "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: - raise ValueError("Parameter 'cache_salt' must be a " - "non-empty string if provided.") + "this instance of vLLM, which uses engine V0." + ) + if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." 
+ ) return data @@ -445,8 +471,9 @@ class ChatCompletionRequest(OpenAIBaseModel): top_logprobs: Optional[int] = 0 max_tokens: Optional[int] = Field( default=None, - deprecated= - 'max_tokens is deprecated in favor of the max_completion_tokens field') + deprecated="max_tokens is deprecated in favor of " + "the max_completion_tokens field", + ) max_completion_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 @@ -458,12 +485,14 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = None top_p: Optional[float] = None tools: Optional[list[ChatCompletionToolsParam]] = None - tool_choice: Optional[Union[ - Literal["none"], - Literal["auto"], - Literal["required"], - ChatCompletionNamedToolChoiceParam, - ]] = "none" + tool_choice: Optional[ + Union[ + Literal["none"], + Literal["auto"], + Literal["required"], + ChatCompletionNamedToolChoiceParam, + ] + ] = "none" reasoning_effort: Optional[Literal["low", "medium", "high"]] = None include_reasoning: bool = True @@ -495,23 +524,26 @@ class ChatCompletionRequest(OpenAIBaseModel): default=False, description=( "If true, the new message will be prepended with the last message " - "if they belong to the same role."), + "if they belong to the same role." + ), ) add_generation_prompt: bool = Field( default=True, - description= - ("If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model."), + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), ) continue_final_message: bool = Field( default=False, - description= - ("If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - "This allows you to \"prefill\" part of the model's response for it. " - "Cannot be used at the same time as `add_generation_prompt`."), + description=( + "If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + 'This allows you to "prefill" part of the model\'s response for it. ' + "Cannot be used at the same time as `add_generation_prompt`." + ), ) add_special_tokens: bool = Field( default=False, @@ -520,16 +552,18 @@ class ChatCompletionRequest(OpenAIBaseModel): "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " "special tokens so this should be set to false (as is the " - "default)."), + "default)." + ), ) documents: Optional[list[dict[str, str]]] = Field( default=None, - description= - ("A list of dicts representing documents that will be accessible to " - "the model if it is performing RAG (retrieval-augmented generation)." - " If the template does not support RAG, this argument will have no " - "effect. We recommend that each document should be a dict containing " - "\"title\" and \"text\" keys."), + description=( + "A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + '"title" and "text" keys.' 
+ ), ) chat_template: Optional[str] = Field( default=None, @@ -537,13 +571,15 @@ class ChatCompletionRequest(OpenAIBaseModel): "A Jinja template to use for this conversion. " "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one."), + "does not define one." + ), ) chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template."), + "Will be accessible by the chat template." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -558,42 +594,48 @@ class ChatCompletionRequest(OpenAIBaseModel): description=( "`guided_json` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `json` to `structured_outputs` instead."), + "Please pass `json` to `structured_outputs` instead." + ), ) guided_regex: Optional[str] = Field( default=None, description=( "`guided_regex` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `regex` to `structured_outputs` instead."), + "Please pass `regex` to `structured_outputs` instead." + ), ) guided_choice: Optional[list[str]] = Field( default=None, description=( "`guided_choice` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `choice` to `structured_outputs` instead."), + "Please pass `choice` to `structured_outputs` instead." + ), ) guided_grammar: Optional[str] = Field( default=None, description=( "`guided_grammar` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `grammar` to `structured_outputs` instead."), + "Please pass `grammar` to `structured_outputs` instead." + ), ) structural_tag: Optional[str] = Field( default=None, description=( "`structural_tag` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `structural_tag` to `structured_outputs` instead."), + "Please pass `structural_tag` to `structured_outputs` instead." + ), ) guided_decoding_backend: Optional[str] = Field( default=None, description=( "`guided_decoding_backend` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please remove it from your request."), + "Please remove it from your request." + ), ) guided_whitespace_pattern: Optional[str] = Field( default=None, @@ -608,14 +650,16 @@ class ChatCompletionRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) logits_processors: Optional[LogitsProcessors] = Field( default=None, @@ -627,13 +671,17 @@ class ChatCompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}.")) + "{'param': 'value'}}." 
+ ), + ) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.")) + "that are not JSON-encodable can be identified." + ), + ) return_token_ids: Optional[bool] = Field( default=None, description=( @@ -641,7 +689,9 @@ class ChatCompletionRequest(OpenAIBaseModel): "generated text. In streaming mode, prompt_token_ids is included " "only in the first chunk, and token_ids contains the delta tokens " "for each chunk. This is useful for debugging or when you " - "need to map generated text back to input tokens.")) + "need to map generated text back to input tokens." + ), + ) cache_salt: Optional[str] = Field( default=None, description=( @@ -650,15 +700,20 @@ class ChatCompletionRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0." + ), + ) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.") + description="KVTransfer parameters used for disaggregated serving.", + ) vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, - description=("Additional request parameters with string or " - "numeric values, used by custom extensions."), + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." + ), ) # --8<-- [end:chat-completion-extra-params] @@ -673,13 +728,13 @@ class ChatCompletionRequest(OpenAIBaseModel): } def to_beam_search_params( - self, max_tokens: int, - default_sampling_params: dict) -> BeamSearchParams: - + self, max_tokens: int, default_sampling_params: dict + ) -> BeamSearchParams: n = self.n if self.n is not None else 1 if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) return BeamSearchParams( beam_width=n, @@ -696,7 +751,6 @@ def to_sampling_params( logits_processor_pattern: Optional[str], default_sampling_params: dict, ) -> SamplingParams: - # Default parameters if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( @@ -705,16 +759,20 @@ def to_sampling_params( ) if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) if (top_k := self.top_k) is None: top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) if (min_p := self.min_p) is None: min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: @@ -753,10 +811,10 @@ def to_sampling_params( elif response_format.type == "structural_tag": 
structural_tag = response_format assert structural_tag is not None and isinstance( - structural_tag, StructuralTagResponseFormat) + structural_tag, StructuralTagResponseFormat + ) s_tag_obj = structural_tag.model_dump(by_alias=True) - self.structured_outputs.structural_tag = json.dumps( - s_tag_obj) + self.structured_outputs.structural_tag = json.dumps(s_tag_obj) # Set structured output params for tool calling if json_schema_from_tool is not None: @@ -786,12 +844,14 @@ def to_sampling_params( min_tokens=self.min_tokens, skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, - logits_processors=get_logits_processors(self.logits_processors, - logits_processor_pattern), + logits_processors=get_logits_processors( + self.logits_processors, logits_processor_pattern + ), include_stop_str_in_output=self.include_stop_str_in_output, truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, bad_words=self.bad_words, @@ -809,8 +869,7 @@ def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]: tool_name = self.tool_choice.function.name tools = {tool.function.name: tool.function for tool in self.tools} if tool_name not in tools: - raise ValueError( - f"Tool '{tool_name}' has not been passed in `tools`.") + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") tool = tools[tool_name] return tool.parameters @@ -822,37 +881,31 @@ def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]: def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: return { "properties": { - "name": { - "type": "string", - "enum": [tool.function.name] - }, + "name": {"type": "string", "enum": [tool.function.name]}, # parameters are always generated as '{}' in the final # output if they are missing from the request # (i.e. are None or '{}') so the schema is # updated to produce an empty object in that case "parameters": tool.function.parameters - if tool.function.parameters else { - "type": "object", - "properties": {} - } + if tool.function.parameters + else {"type": "object", "properties": {}}, }, - "required": ["name", "parameters"] + "required": ["name", "parameters"], } - def get_tool_schema_defs( - tools: list[ChatCompletionToolsParam]) -> dict: + def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict: all_defs = dict[str, dict[str, Any]]() for tool in tools: if tool.function.parameters is None: continue defs = tool.function.parameters.pop("$defs", {}) for def_name, def_schema in defs.items(): - if def_name in all_defs and all_defs[ - def_name] != def_schema: + if def_name in all_defs and all_defs[def_name] != def_schema: raise ValueError( f"Tool definition '{def_name}' has " "multiple schemas, which is not " - "supported.") + "supported." 
+ ) else: all_defs[def_name] = def_schema return all_defs @@ -862,8 +915,8 @@ def get_tool_schema_defs( "minItems": 1, "items": { "type": "object", - "anyOf": [get_tool_schema(tool) for tool in self.tools] - } + "anyOf": [get_tool_schema(tool) for tool in self.tools], + }, } json_schema_defs = get_tool_schema_defs(self.tools) if json_schema_defs: @@ -876,8 +929,7 @@ def get_tool_schema_defs( @classmethod def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data @@ -885,24 +937,22 @@ def validate_stream_options(cls, data): @classmethod def check_logprobs(cls, data): if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and (prompt_logprobs > 0 - or prompt_logprobs == -1): + if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): raise ValueError( - "`prompt_logprobs` are not available when `stream=True`.") + "`prompt_logprobs` are not available when `stream=True`." + ) if prompt_logprobs < 0 and prompt_logprobs != -1: - raise ValueError( - "`prompt_logprobs` must be a positive value or -1.") + raise ValueError("`prompt_logprobs` must be a positive value or -1.") if prompt_logprobs == -1 and not envs.VLLM_USE_V1: - raise ValueError("`prompt_logprobs=-1` is only supported with " - "vLLM engine V1.") + raise ValueError( + "`prompt_logprobs=-1` is only supported with vLLM engine V1." + ) if (top_logprobs := data.get("top_logprobs")) is not None: if top_logprobs < 0 and top_logprobs != -1: - raise ValueError( - "`top_logprobs` must be a positive value or -1.") + raise ValueError("`top_logprobs` must be a positive value or -1.") - if (top_logprobs == -1 - or top_logprobs > 0) and not data.get("logprobs"): + if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"): raise ValueError( "when using `top_logprobs`, `logprobs` must be set to true." ) @@ -918,30 +968,32 @@ def check_structured_outputs_count(cls, data): if data.get("structured_outputs", None) is None: return data - structured_outputs_kwargs = data['structured_outputs'] + structured_outputs_kwargs = data["structured_outputs"] count = sum( structured_outputs_kwargs.get(k) is not None - for k in ("json", "regex", "choice")) + for k in ("json", "regex", "choice") + ) # you can only use one kind of constraints for structured outputs if count > 1: raise ValueError( "You can only use one kind of constraints for structured " - "outputs ('json', 'regex' or 'choice').") + "outputs ('json', 'regex' or 'choice')." + ) # you can only either use structured outputs or tools, not both if count > 1 and data.get("tool_choice", "none") not in ( - "none", - "auto", - "required", + "none", + "auto", + "required", ): raise ValueError( "You can only either use constraints for structured outputs " - "or tools, not both.") + "or tools, not both." 
+ ) return data @model_validator(mode="before") @classmethod def check_tool_usage(cls, data): - # if "tool_choice" is not specified but tools are provided, # default to "auto" tool_choice if "tool_choice" not in data and data.get("tools"): @@ -953,52 +1005,58 @@ def check_tool_usage(cls, data): # if "tool_choice" is specified -- validation if "tool_choice" in data and data["tool_choice"] is not None: - # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: - raise ValueError( - "When using `tool_choice`, `tools` must be set.") + raise ValueError("When using `tool_choice`, `tools` must be set.") # make sure that tool choice is either a named tool # OR that it's set to "auto" or "required" - if data["tool_choice"] not in [ - "auto", "required" - ] and not isinstance(data["tool_choice"], dict): + if data["tool_choice"] not in ["auto", "required"] and not isinstance( + data["tool_choice"], dict + ): raise ValueError( - f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\ - 'Only named tools, "none", "auto" or "required" '\ - 'are supported.' + f"Invalid value for `tool_choice`: {data['tool_choice']}! " + 'Only named tools, "none", "auto" or "required" ' + "are supported." ) # if tool_choice is "required" but the "tools" list is empty, # override the data to behave like "none" to align with # OpenAI’s behavior. - if data["tool_choice"] == "required" and isinstance( - data["tools"], list) and len(data["tools"]) == 0: + if ( + data["tool_choice"] == "required" + and isinstance(data["tools"], list) + and len(data["tools"]) == 0 + ): data["tool_choice"] = "none" del data["tools"] return data # ensure that if "tool_choice" is specified as an object, # it matches a valid tool - correct_usage_message = 'Correct usage: `{"type": "function",' \ + correct_usage_message = ( + 'Correct usage: `{"type": "function",' ' "function": {"name": "my_function"}}`' + ) if isinstance(data["tool_choice"], dict): valid_tool = False function = data["tool_choice"].get("function") if not isinstance(function, dict): raise ValueError( f"Invalid value for `function`: `{function}` in " - f"`tool_choice`! {correct_usage_message}") + f"`tool_choice`! {correct_usage_message}" + ) if "name" not in function: - raise ValueError(f"Expected field `name` in `function` in " - f"`tool_choice`! {correct_usage_message}") + raise ValueError( + f"Expected field `name` in `function` in " + f"`tool_choice`! {correct_usage_message}" + ) function_name = function["name"] - if not isinstance(function_name, - str) or len(function_name) == 0: + if not isinstance(function_name, str) or len(function_name) == 0: raise ValueError( f"Invalid `name` in `function`: `{function_name}`" - f" in `tool_choice`! {correct_usage_message}") + f" in `tool_choice`! 
{correct_usage_message}" + ) for tool in data["tools"]: if tool["function"]["name"] == function_name: valid_tool = True @@ -1006,16 +1064,18 @@ def check_tool_usage(cls, data): if not valid_tool: raise ValueError( "The tool specified in `tool_choice` does not match any" - " of the specified `tools`") + " of the specified `tools`" + ) return data @model_validator(mode="before") @classmethod def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get( - "add_generation_prompt"): - raise ValueError("Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True.") + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True." + ) return data @model_validator(mode="before") @@ -1025,11 +1085,12 @@ def check_cache_salt_support(cls, data): if not envs.VLLM_USE_V1: raise ValueError( "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: - raise ValueError("Parameter 'cache_salt' must be a " - "non-empty string if provided.") + "this instance of vLLM, which uses engine V0." + ) + if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data @@ -1078,7 +1139,8 @@ class CompletionRequest(OpenAIBaseModel): default=True, description=( "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt."), + "the prompt." + ), ) response_format: Optional[AnyResponseFormat] = Field( default=None, @@ -1097,35 +1159,40 @@ class CompletionRequest(OpenAIBaseModel): description=( "`guided_json` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `json` to `structured_outputs` instead."), + "Please pass `json` to `structured_outputs` instead." + ), ) guided_regex: Optional[str] = Field( default=None, description=( "`guided_regex` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `regex` to `structured_outputs` instead."), + "Please pass `regex` to `structured_outputs` instead." + ), ) guided_choice: Optional[list[str]] = Field( default=None, description=( "`guided_choice` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `choice` to `structured_outputs` instead."), + "Please pass `choice` to `structured_outputs` instead." + ), ) guided_grammar: Optional[str] = Field( default=None, description=( "`guided_grammar` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `grammar` to `structured_outputs` instead."), + "Please pass `grammar` to `structured_outputs` instead." + ), ) guided_decoding_backend: Optional[str] = Field( default=None, description=( "`guided_decoding_backend` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please remove it from your request."), + "Please remove it from your request." + ), ) guided_whitespace_pattern: Optional[str] = Field( default=None, @@ -1140,14 +1207,16 @@ class CompletionRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). 
Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) logits_processors: Optional[LogitsProcessors] = Field( default=None, @@ -1159,14 +1228,18 @@ class CompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}.")) + "{'param': 'value'}}." + ), + ) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.")) + "that are not JSON-encodable can be identified." + ), + ) return_token_ids: Optional[bool] = Field( default=None, description=( @@ -1174,7 +1247,9 @@ class CompletionRequest(OpenAIBaseModel): "generated text. In streaming mode, prompt_token_ids is included " "only in the first chunk, and token_ids contains the delta tokens " "for each chunk. This is useful for debugging or when you " - "need to map generated text back to input tokens.")) + "need to map generated text back to input tokens." + ), + ) cache_salt: Optional[str] = Field( default=None, @@ -1184,16 +1259,21 @@ class CompletionRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0." + ), + ) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.") + description="KVTransfer parameters used for disaggregated serving.", + ) vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, - description=("Additional request parameters with string or " - "numeric values, used by custom extensions."), + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." 
+ ), ) # --8<-- [end:completion-extra-params] @@ -1212,7 +1292,6 @@ def to_beam_search_params( max_tokens: int, default_sampling_params: Optional[dict] = None, ) -> BeamSearchParams: - if default_sampling_params is None: default_sampling_params = {} n = self.n if self.n is not None else 1 @@ -1235,7 +1314,6 @@ def to_sampling_params( logits_processor_pattern: Optional[str], default_sampling_params: Optional[dict] = None, ) -> SamplingParams: - if default_sampling_params is None: default_sampling_params = {} @@ -1247,16 +1325,20 @@ def to_sampling_params( ) if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) if (top_k := self.top_k) is None: top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) if (min_p := self.min_p) is None: min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: @@ -1277,9 +1359,11 @@ def to_sampling_params( if len(kwargs) > 0: self.structured_outputs = StructuredOutputsParams(**kwargs) - if (self.structured_outputs is not None - and self.response_format is not None - and self.response_format.type == "json_object"): + if ( + self.structured_outputs is not None + and self.response_format is not None + and self.response_format.type == "json_object" + ): self.structured_outputs.json_object = True extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} @@ -1307,16 +1391,18 @@ def to_sampling_params( skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, - logits_processors=get_logits_processors(self.logits_processors, - logits_processor_pattern), + logits_processors=get_logits_processors( + self.logits_processors, logits_processor_pattern + ), truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, - ) + ) @model_validator(mode="before") @classmethod @@ -1324,31 +1410,33 @@ def check_structured_outputs_count(cls, data): if data.get("structured_outputs", None) is None: return data - structured_outputs_kwargs = data['structured_outputs'] + structured_outputs_kwargs = data["structured_outputs"] count = sum( structured_outputs_kwargs.get(k) is not None - for k in ("json", "regex", "choice")) + for k in ("json", "regex", "choice") + ) if count > 1: raise ValueError( "You can only use one kind of constraints for structured " - "outputs ('json', 'regex' or 'choice').") + "outputs ('json', 'regex' or 'choice')." 
+ ) return data @model_validator(mode="before") @classmethod def check_logprobs(cls, data): if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and (prompt_logprobs > 0 - or prompt_logprobs == -1): + if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): raise ValueError( - "`prompt_logprobs` are not available when `stream=True`.") + "`prompt_logprobs` are not available when `stream=True`." + ) if prompt_logprobs < 0 and prompt_logprobs != -1: - raise ValueError( - "`prompt_logprobs` must be a positive value or -1.") + raise ValueError("`prompt_logprobs` must be a positive value or -1.") if prompt_logprobs == -1 and not envs.VLLM_USE_V1: - raise ValueError("`prompt_logprobs=-1` is only supported with " - "vLLM engine V1.") + raise ValueError( + "`prompt_logprobs=-1` is only supported with vLLM engine V1." + ) if (logprobs := data.get("logprobs")) is not None and logprobs < 0: raise ValueError("`logprobs` must be a positive value.") @@ -1358,8 +1446,7 @@ def check_logprobs(cls, data): @classmethod def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data @@ -1369,11 +1456,10 @@ def validate_prompt_and_prompt_embeds(cls, data): prompt = data.get("prompt") prompt_embeds = data.get("prompt_embeds") - prompt_is_empty = (prompt is None - or (isinstance(prompt, str) and prompt == "")) - embeds_is_empty = (prompt_embeds is None - or (isinstance(prompt_embeds, list) - and len(prompt_embeds) == 0)) + prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "") + embeds_is_empty = prompt_embeds is None or ( + isinstance(prompt_embeds, list) and len(prompt_embeds) == 0 + ) if prompt_is_empty and embeds_is_empty: raise ValueError( @@ -1389,11 +1475,12 @@ def check_cache_salt_support(cls, data): if not envs.VLLM_USE_V1: raise ValueError( "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: - raise ValueError("Parameter 'cache_salt' must be a " - "non-empty string if provided.") + "this instance of vLLM, which uses engine V0." + ) + if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data @@ -1412,21 +1499,24 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): default=True, description=( "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt."), + "the prompt." + ), ) priority: int = Field( default=0, description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." 
+ ), ) normalize: Optional[bool] = None @@ -1436,7 +1526,8 @@ def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, dimensions=self.dimensions, - normalize=self.normalize) + normalize=self.normalize, + ) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1451,10 +1542,11 @@ class EmbeddingChatRequest(OpenAIBaseModel): # --8<-- [start:chat-embedding-extra-params] add_generation_prompt: bool = Field( default=False, - description= - ("If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model."), + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), ) add_special_tokens: bool = Field( @@ -1464,7 +1556,8 @@ class EmbeddingChatRequest(OpenAIBaseModel): "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " "special tokens so this should be set to false (as is the " - "default)."), + "default)." + ), ) chat_template: Optional[str] = Field( default=None, @@ -1472,13 +1565,15 @@ class EmbeddingChatRequest(OpenAIBaseModel): "A Jinja template to use for this conversion. " "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one."), + "does not define one." + ), ) chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template."), + "Will be accessible by the chat template." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -1489,14 +1584,16 @@ class EmbeddingChatRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) normalize: Optional[bool] = None # --8<-- [end:chat-embedding-extra-params] @@ -1504,17 +1601,19 @@ class EmbeddingChatRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get( - "add_generation_prompt"): - raise ValueError("Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True.") + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True." 
+ ) return data def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, dimensions=self.dimensions, - normalize=self.normalize) + normalize=self.normalize, + ) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] @@ -1546,7 +1645,6 @@ def to_pooling_params(self): class IOProcessorResponse(OpenAIBaseModel, Generic[T]): - request_id: Optional[str] = None """ The request_id associated with this response @@ -1560,8 +1658,7 @@ class IOProcessorResponse(OpenAIBaseModel, Generic[T]): """ -PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, - IOProcessorRequest] +PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, IOProcessorRequest] class ScoreRequest(OpenAIBaseModel): @@ -1582,7 +1679,8 @@ class ScoreRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) activation: Optional[bool] = None @@ -1592,7 +1690,8 @@ class ScoreRequest(OpenAIBaseModel): def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, - activation=self.activation) + activation=self.activation, + ) class RerankRequest(OpenAIBaseModel): @@ -1614,7 +1713,8 @@ class RerankRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) activation: Optional[bool] = None @@ -1624,7 +1724,8 @@ class RerankRequest(OpenAIBaseModel): def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, - activation=self.activation) + activation=self.activation, + ) class RerankDocument(BaseModel): @@ -1653,8 +1754,7 @@ class CompletionLogProbs(OpenAIBaseModel): text_offset: list[int] = Field(default_factory=list) token_logprobs: list[Optional[float]] = Field(default_factory=list) tokens: list[str] = Field(default_factory=list) - top_logprobs: list[Optional[dict[str, - float]]] = Field(default_factory=list) + top_logprobs: list[Optional[dict[str, float]]] = Field(default_factory=list) class CompletionResponseChoice(OpenAIBaseModel): @@ -1667,7 +1767,8 @@ class CompletionResponseChoice(OpenAIBaseModel): description=( "The stop string or token id that caused the completion " "to stop, None if the completion finished for some other reason " - "including encountering the EOS token"), + "including encountering the EOS token" + ), ) token_ids: Optional[list[int]] = None # For response prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None @@ -1680,14 +1781,16 @@ class CompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[CompletionResponseChoice] - service_tier: Optional[Literal["auto", "default", "flex", "scale", - "priority"]] = None + service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = ( + None + ) system_fingerprint: Optional[str] = None usage: UsageInfo # vLLM-specific fields that are not in OpenAI spec kv_transfer_params: Optional[dict[str, Any]] = Field( - default=None, description="KVTransfer parameters.") + default=None, description="KVTransfer parameters." 
+ ) class CompletionResponseStreamChoice(OpenAIBaseModel): @@ -1700,7 +1803,8 @@ class CompletionResponseStreamChoice(OpenAIBaseModel): description=( "The stop string or token id that caused the completion " "to stop, None if the completion finished for some other reason " - "including encountering the EOS token"), + "including encountering the EOS token" + ), ) # not part of the OpenAI spec but for tracing the tokens # prompt tokens is put into choice to align with CompletionResponseChoice @@ -1774,7 +1878,8 @@ class ClassificationRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) activation: Optional[bool] = None @@ -1784,7 +1889,8 @@ class ClassificationRequest(OpenAIBaseModel): def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, - activation=self.activation) + activation=self.activation, + ) class ClassificationData(OpenAIBaseModel): @@ -1888,8 +1994,9 @@ class ChatCompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[ChatCompletionResponseChoice] - service_tier: Optional[Literal["auto", "default", "flex", "scale", - "priority"]] = None + service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = ( + None + ) system_fingerprint: Optional[str] = None usage: UsageInfo @@ -1897,7 +2004,8 @@ class ChatCompletionResponse(OpenAIBaseModel): prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None prompt_token_ids: Optional[list[int]] = None kv_transfer_params: Optional[dict[str, Any]] = Field( - default=None, description="KVTransfer parameters.") + default=None, description="KVTransfer parameters." + ) class DeltaMessage(OpenAIBaseModel): @@ -2007,10 +2115,9 @@ def from_request( input_messages: Optional[list[ChatCompletionMessageParam]] = None, output_messages: Optional[list[ChatCompletionMessageParam]] = None, ) -> "ResponsesResponse": - incomplete_details: Optional[IncompleteDetails] = None - if status == 'incomplete': - incomplete_details = IncompleteDetails(reason='max_output_tokens') + if status == "incomplete": + incomplete_details = IncompleteDetails(reason="max_output_tokens") # TODO: implement the other reason for incomplete_details, # which is content_filter # incomplete_details = IncompleteDetails(reason='content_filter') @@ -2125,8 +2232,9 @@ class ResponseInProgressEvent(OpenAIResponseInProgressEvent): ResponseCodeInterpreterCallCompletedEvent, ] -BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest, - ScoreRequest, RerankRequest] +BatchRequestInputBody = Union[ + ChatCompletionRequest, EmbeddingRequest, ScoreRequest, RerankRequest +] class BatchRequestInput(OpenAIBaseModel): @@ -2151,7 +2259,7 @@ class BatchRequestInput(OpenAIBaseModel): # The parameters of the request. body: BatchRequestInputBody - @field_validator('body', mode='plain') + @field_validator("body", mode="plain") @classmethod def check_type_for_url(cls, value: Any, info: ValidationInfo): # Use url to disambiguate models @@ -2175,8 +2283,9 @@ class BatchResponseData(OpenAIBaseModel): request_id: str # The body of the response. 
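# --- Editor's illustrative sketch (not part of the patch) ---------------------
# BatchRequestInput in this hunk disambiguates its `body` union by `url`
# (/v1/chat/completions -> ChatCompletionRequest, /v1/embeddings ->
# EmbeddingRequest, and suffix-matched /score and /rerank). A rough sketch of
# how one JSONL line for the batch runner could be assembled; field names
# follow the OpenAI-style batch format used here, but the exact schema should
# be checked against the class definitions above:
import json


def make_batch_line(custom_id: str, model: str, user_prompt: str) -> str:
    """Serialize a single chat-completion batch request as one JSONL line."""
    record = {
        "custom_id": custom_id,
        "method": "POST",  # assumed; mirrors the OpenAI batch file format
        "url": "/v1/chat/completions",
        "body": {
            "model": model,
            "messages": [{"role": "user", "content": user_prompt}],
        },
    }
    return json.dumps(record)


# Example: print(make_batch_line("req-1", "my-model", "Hello!"))
# -----------------------------------------------------------------------------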
- body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, - ScoreResponse, RerankResponse]] = None + body: Optional[ + Union[ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse] + ] = None class BatchRequestOutput(OpenAIBaseModel): @@ -2205,12 +2314,14 @@ class TokenizeCompletionRequest(OpenAIBaseModel): default=True, description=( "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt."), + "the prompt." + ), ) return_token_strs: Optional[bool] = Field( default=False, - description=("If true, also return the token strings " - "corresponding to the token ids."), + description=( + "If true, also return the token strings corresponding to the token ids." + ), ) @@ -2220,24 +2331,27 @@ class TokenizeChatRequest(OpenAIBaseModel): add_generation_prompt: bool = Field( default=True, - description= - ("If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model."), + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), ) return_token_strs: Optional[bool] = Field( default=False, - description=("If true, also return the token strings " - "corresponding to the token ids."), + description=( + "If true, also return the token strings corresponding to the token ids." + ), ) continue_final_message: bool = Field( default=False, - description= - ("If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - "This allows you to \"prefill\" part of the model's response for it. " - "Cannot be used at the same time as `add_generation_prompt`."), + description=( + "If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + 'This allows you to "prefill" part of the model\'s response for it. ' + "Cannot be used at the same time as `add_generation_prompt`." + ), ) add_special_tokens: bool = Field( default=False, @@ -2246,7 +2360,8 @@ class TokenizeChatRequest(OpenAIBaseModel): "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " "special tokens so this should be set to false (as is the " - "default)."), + "default)." + ), ) chat_template: Optional[str] = Field( default=None, @@ -2254,13 +2369,15 @@ class TokenizeChatRequest(OpenAIBaseModel): "A Jinja template to use for this conversion. " "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one."), + "does not define one." + ), ) chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template."), + "Will be accessible by the chat template." 
+ ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -2274,10 +2391,11 @@ class TokenizeChatRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get( - "add_generation_prompt"): - raise ValueError("Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True.") + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True." + ) return data @@ -2321,8 +2439,7 @@ class UnloadLoRAAdapterRequest(BaseModel): ## Protocols for Audio -AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", - "vtt"] +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"] class TranscriptionRequest(OpenAIBaseModel): @@ -2364,7 +2481,8 @@ class TranscriptionRequest(OpenAIBaseModel): ## TODO (varun) : Support if set to 0, certain thresholds are met !! timestamp_granularities: list[Literal["word", "segment"]] = Field( - alias="timestamp_granularities[]", default=[]) + alias="timestamp_granularities[]", default=[] + ) """The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. @@ -2384,8 +2502,10 @@ class TranscriptionRequest(OpenAIBaseModel): vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, - description=("Additional request parameters with string or " - "numeric values, used by custom extensions."), + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." + ), ) # --8<-- [end:transcription-extra-params] @@ -2442,10 +2562,8 @@ class TranscriptionRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None) -> SamplingParams: - + self, default_max_tokens: int, default_sampling_params: Optional[dict] = None + ) -> SamplingParams: max_tokens = default_max_tokens if default_sampling_params is None: @@ -2454,35 +2572,42 @@ def to_sampling_params( # Default parameters if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) if (top_k := self.top_k) is None: top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) if (min_p := self.min_p) is None: min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( "repetition_penalty", - self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"]) - - return SamplingParams.from_optional(temperature=temperature, - max_tokens=max_tokens, - seed=self.seed, - top_p=top_p, - top_k=top_k, - min_p=min_p, - frequency_penalty=self.frequency_penalty, - repetition_penalty=repetition_penalty, - presence_penalty=self.presence_penalty, - output_kind=RequestOutputKind.DELTA - if self.stream \ - else RequestOutputKind.FINAL_ONLY, - 
extra_args=self.vllm_xargs) + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens, + seed=self.seed, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + presence_penalty=self.presence_penalty, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, + extra_args=self.vllm_xargs, + ) @model_validator(mode="before") @classmethod @@ -2496,8 +2621,7 @@ def validate_transcription_request(cls, data): stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream = data.get("stream", False) if any(bool(data.get(so, False)) for so in stream_opts) and not stream: - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data @@ -2675,10 +2799,8 @@ class TranslationRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None) -> SamplingParams: - + self, default_max_tokens: int, default_sampling_params: Optional[dict] = None + ) -> SamplingParams: max_tokens = default_max_tokens if default_sampling_params is None: @@ -2686,14 +2808,17 @@ def to_sampling_params( # Default parameters if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) - return SamplingParams.from_optional(temperature=temperature, - max_tokens=max_tokens, - seed=self.seed, - output_kind=RequestOutputKind.DELTA - if self.stream \ - else RequestOutputKind.FINAL_ONLY) + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens, + seed=self.seed, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, + ) @model_validator(mode="before") @classmethod @@ -2701,8 +2826,7 @@ def validate_stream_options(cls, data): stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream = data.get("stream", False) if any(bool(data.get(so, False)) for so in stream_opts) and not stream: - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 2568c21c4abe..030ce3ce0844 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -18,18 +18,19 @@ from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -# yapf: disable -from vllm.entrypoints.openai.protocol import (BatchRequestInput, - BatchRequestOutput, - BatchResponseData, - ChatCompletionResponse, - EmbeddingResponse, ErrorResponse, - RerankResponse, ScoreResponse) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + BatchRequestInput, + BatchRequestOutput, + BatchResponseData, + ChatCompletionResponse, + EmbeddingResponse, + ErrorResponse, + RerankResponse, + ScoreResponse, +) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_models import (BaseModelPath, - 
OpenAIServingModels) +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.serving_score import ServingScores from vllm.logger import init_logger from vllm.utils import FlexibleArgumentParser, random_uuid @@ -44,10 +45,10 @@ def make_arg_parser(parser: FlexibleArgumentParser): "--input-file", required=True, type=str, - help= - "The path or url to a single input file. Currently supports local file " + help="The path or url to a single input file. Currently supports local file " "paths, or the http protocol (http or https). If a URL is specified, " - "the file should be available via HTTP GET.") + "the file should be available via HTTP GET.", + ) parser.add_argument( "-o", "--output-file", @@ -55,7 +56,8 @@ def make_arg_parser(parser: FlexibleArgumentParser): type=str, help="The path or url to a single output file. Currently supports " "local file paths, or web (http or https) urls. If a URL is specified," - " the file should be available via HTTP PUT.") + " the file should be available via HTTP PUT.", + ) parser.add_argument( "--output-tmp-dir", type=str, @@ -63,24 +65,27 @@ def make_arg_parser(parser: FlexibleArgumentParser): help="The directory to store the output file before uploading it " "to the output URL.", ) - parser.add_argument("--response-role", - type=optional_type(str), - default="assistant", - help="The role name to return if " - "`request.add_generation_prompt=True`.") + parser.add_argument( + "--response-role", + type=optional_type(str), + default="assistant", + help="The role name to return if `request.add_generation_prompt=True`.", + ) parser = AsyncEngineArgs.add_cli_args(parser) - parser.add_argument('--max-log-len', - type=int, - default=None, - help='Max number of prompt characters or prompt ' - 'ID numbers being printed in log.' - '\n\nDefault: Unlimited') + parser.add_argument( + "--max-log-len", + type=int, + default=None, + help="Max number of prompt characters or prompt " + "ID numbers being printed in log." 
+ "\n\nDefault: Unlimited", + ) - parser.add_argument("--enable-metrics", - action="store_true", - help="Enable Prometheus metrics") + parser.add_argument( + "--enable-metrics", action="store_true", help="Enable Prometheus metrics" + ) parser.add_argument( "--url", type=str, @@ -97,16 +102,16 @@ def make_arg_parser(parser: FlexibleArgumentParser): ) parser.add_argument( "--enable-prompt-tokens-details", - action='store_true', + action="store_true", default=False, - help="If set to True, enable prompt_tokens_details in usage.") + help="If set to True, enable prompt_tokens_details in usage.", + ) return parser def parse_args(): - parser = FlexibleArgumentParser( - description="vLLM OpenAI-Compatible batch runner.") + parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible batch runner.") return make_arg_parser(parser).parse_args() @@ -118,7 +123,6 @@ def parse_args(): class BatchProgressTracker: - def __init__(self): self._total = 0 self._pbar: Optional[tqdm] = None @@ -131,29 +135,32 @@ def completed(self): self._pbar.update() def pbar(self) -> tqdm: - enable_tqdm = not torch.distributed.is_initialized( - ) or torch.distributed.get_rank() == 0 - self._pbar = tqdm(total=self._total, - unit="req", - desc="Running batch", - mininterval=5, - disable=not enable_tqdm, - bar_format=_BAR_FORMAT) + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + self._pbar = tqdm( + total=self._total, + unit="req", + desc="Running batch", + mininterval=5, + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ) return self._pbar async def read_file(path_or_url: str) -> str: if path_or_url.startswith("http://") or path_or_url.startswith("https://"): - async with aiohttp.ClientSession() as session, \ - session.get(path_or_url) as resp: + async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp: return await resp.text() else: with open(path_or_url, encoding="utf-8") as f: return f.read() -async def write_local_file(output_path: str, - batch_outputs: list[BatchRequestOutput]) -> None: +async def write_local_file( + output_path: str, batch_outputs: list[BatchRequestOutput] +) -> None: """ Write the responses to a local file. output_path: The path to write the responses to. @@ -166,8 +173,7 @@ async def write_local_file(output_path: str, print(o.model_dump_json(), file=f) -async def upload_data(output_url: str, data_or_file: str, - from_file: bool) -> None: +async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None: """ Upload a local file to a URL. output_url: The URL to upload the file to. @@ -184,23 +190,26 @@ async def upload_data(output_url: str, data_or_file: str, try: # We increase the timeout to 1000 seconds to allow # for large files (default is 300). 
- async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout( - total=1000)) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=1000) + ) as session: if from_file: with open(data_or_file, "rb") as file: - async with session.put(output_url, - data=file) as response: + async with session.put(output_url, data=file) as response: if response.status != 200: - raise Exception(f"Failed to upload file.\n" - f"Status: {response.status}\n" - f"Response: {response.text()}") + raise Exception( + f"Failed to upload file.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}" + ) else: - async with session.put(output_url, - data=data_or_file) as response: + async with session.put(output_url, data=data_or_file) as response: if response.status != 200: - raise Exception(f"Failed to upload data.\n" - f"Status: {response.status}\n" - f"Response: {response.text()}") + raise Exception( + f"Failed to upload data.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}" + ) except Exception as e: if attempt < max_retries: @@ -217,8 +226,9 @@ async def upload_data(output_url: str, data_or_file: str, ) from e -async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], - output_tmp_dir: str) -> None: +async def write_file( + path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str +) -> None: """ Write batch_outputs to a file or upload to a URL. path_or_url: The path or URL to write batch_outputs to. @@ -242,14 +252,13 @@ async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], else: # Write responses to a temporary file and then upload it to the URL. with tempfile.NamedTemporaryFile( - mode="w", - encoding="utf-8", - dir=output_tmp_dir, - prefix="tmp_batch_output_", - suffix=".jsonl", + mode="w", + encoding="utf-8", + dir=output_tmp_dir, + prefix="tmp_batch_output_", + suffix=".jsonl", ) as f: - logger.info("Writing outputs to temporary local file %s", - f.name) + logger.info("Writing outputs to temporary local file %s", f.name) await write_local_file(f.name, batch_outputs) logger.info("Uploading outputs to %s", path_or_url) await upload_data(path_or_url, f.name, from_file=True) @@ -258,8 +267,9 @@ async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], await write_local_file(path_or_url, batch_outputs) -def make_error_request_output(request: BatchRequestInput, - error_msg: str) -> BatchRequestOutput: +def make_error_request_output( + request: BatchRequestInput, error_msg: str +) -> BatchRequestOutput: batch_output = BatchRequestOutput( id=f"vllm-{random_uuid()}", custom_id=request.custom_id, @@ -273,25 +283,28 @@ def make_error_request_output(request: BatchRequestInput, async def make_async_error_request_output( - request: BatchRequestInput, error_msg: str) -> BatchRequestOutput: + request: BatchRequestInput, error_msg: str +) -> BatchRequestOutput: return make_error_request_output(request, error_msg) -async def run_request(serving_engine_func: Callable, - request: BatchRequestInput, - tracker: BatchProgressTracker) -> BatchRequestOutput: +async def run_request( + serving_engine_func: Callable, + request: BatchRequestInput, + tracker: BatchProgressTracker, +) -> BatchRequestOutput: response = await serving_engine_func(request.body) if isinstance( - response, - (ChatCompletionResponse, EmbeddingResponse, ScoreResponse, - RerankResponse), + response, + (ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse), ): batch_output = BatchRequestOutput( 
id=f"vllm-{random_uuid()}", custom_id=request.custom_id, response=BatchResponseData( - body=response, request_id=f"vllm-batch-{random_uuid()}"), + body=response, request_id=f"vllm-batch-{random_uuid()}" + ), error=None, ) elif isinstance(response, ErrorResponse): @@ -300,12 +313,14 @@ async def run_request(serving_engine_func: Callable, custom_id=request.custom_id, response=BatchResponseData( status_code=response.error.code, - request_id=f"vllm-batch-{random_uuid()}"), + request_id=f"vllm-batch-{random_uuid()}", + ), error=response, ) else: batch_output = make_error_request_output( - request, error_msg="Request must not be sent in stream mode") + request, error_msg="Request must not be sent in stream mode" + ) tracker.completed() return batch_output @@ -327,8 +342,7 @@ async def run_batch( request_logger = None base_model_paths = [ - BaseModelPath(name=name, model_path=args.model) - for name in served_model_names + BaseModelPath(name=name, model_path=args.model) for name in served_model_names ] model_config = vllm_config.model_config @@ -343,34 +357,48 @@ async def run_batch( base_model_paths=base_model_paths, lora_modules=None, ) - openai_serving_chat = OpenAIServingChat( - engine_client, - model_config, - openai_serving_models, - args.response_role, - request_logger=request_logger, - chat_template=None, - chat_template_content_format="auto", - enable_prompt_tokens_details=args.enable_prompt_tokens_details, - ) if "generate" in supported_tasks else None - openai_serving_embedding = OpenAIServingEmbedding( - engine_client, - model_config, - openai_serving_models, - request_logger=request_logger, - chat_template=None, - chat_template_content_format="auto", - ) if "embed" in supported_tasks else None - - enable_serving_reranking = ("classify" in supported_tasks and getattr( - model_config.hf_config, "num_labels", 0) == 1) - - openai_serving_scores = ServingScores( - engine_client, - model_config, - openai_serving_models, - request_logger=request_logger, - ) if ("embed" in supported_tasks or enable_serving_reranking) else None + openai_serving_chat = ( + OpenAIServingChat( + engine_client, + model_config, + openai_serving_models, + args.response_role, + request_logger=request_logger, + chat_template=None, + chat_template_content_format="auto", + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + ) + if "generate" in supported_tasks + else None + ) + openai_serving_embedding = ( + OpenAIServingEmbedding( + engine_client, + model_config, + openai_serving_models, + request_logger=request_logger, + chat_template=None, + chat_template_content_format="auto", + ) + if "embed" in supported_tasks + else None + ) + + enable_serving_reranking = ( + "classify" in supported_tasks + and getattr(model_config.hf_config, "num_labels", 0) == 1 + ) + + openai_serving_scores = ( + ServingScores( + engine_client, + model_config, + openai_serving_models, + request_logger=request_logger, + ) + if ("embed" in supported_tasks or enable_serving_reranking) + else None + ) tracker = BatchProgressTracker() logger.info("Reading batch from %s...", args.input_file) @@ -387,61 +415,72 @@ async def run_batch( # Determine the type of request and run it. 
if request.url == "/v1/chat/completions": - chat_handler_fn = openai_serving_chat.create_chat_completion if \ - openai_serving_chat is not None else None + chat_handler_fn = ( + openai_serving_chat.create_chat_completion + if openai_serving_chat is not None + else None + ) if chat_handler_fn is None: response_futures.append( make_async_error_request_output( request, - error_msg= - "The model does not support Chat Completions API", - )) + error_msg="The model does not support Chat Completions API", + ) + ) continue - response_futures.append( - run_request(chat_handler_fn, request, tracker)) + response_futures.append(run_request(chat_handler_fn, request, tracker)) tracker.submitted() elif request.url == "/v1/embeddings": - embed_handler_fn = openai_serving_embedding.create_embedding if \ - openai_serving_embedding is not None else None + embed_handler_fn = ( + openai_serving_embedding.create_embedding + if openai_serving_embedding is not None + else None + ) if embed_handler_fn is None: response_futures.append( make_async_error_request_output( request, error_msg="The model does not support Embeddings API", - )) + ) + ) continue - response_futures.append( - run_request(embed_handler_fn, request, tracker)) + response_futures.append(run_request(embed_handler_fn, request, tracker)) tracker.submitted() elif request.url.endswith("/score"): - score_handler_fn = openai_serving_scores.create_score if \ - openai_serving_scores is not None else None + score_handler_fn = ( + openai_serving_scores.create_score + if openai_serving_scores is not None + else None + ) if score_handler_fn is None: response_futures.append( make_async_error_request_output( request, error_msg="The model does not support Scores API", - )) + ) + ) continue - response_futures.append( - run_request(score_handler_fn, request, tracker)) + response_futures.append(run_request(score_handler_fn, request, tracker)) tracker.submitted() elif request.url.endswith("/rerank"): - rerank_handler_fn = openai_serving_scores.do_rerank if \ - openai_serving_scores is not None else None + rerank_handler_fn = ( + openai_serving_scores.do_rerank + if openai_serving_scores is not None + else None + ) if rerank_handler_fn is None: response_futures.append( make_async_error_request_output( request, error_msg="The model does not support Rerank API", - )) + ) + ) continue - response_futures.append( - run_request(rerank_handler_fn, request, tracker)) + response_futures.append(run_request(rerank_handler_fn, request, tracker)) tracker.submitted() else: response_futures.append( @@ -452,7 +491,8 @@ async def run_batch( " /score, /rerank ." 
"See vllm/entrypoints/openai/api_server.py for supported " "score/rerank versions.", - )) + ) + ) with tracker.pbar(): responses = await asyncio.gather(*response_futures) @@ -465,9 +505,9 @@ async def main(args: Namespace): from vllm.usage.usage_lib import UsageContext async with build_async_engine_client( - args, - usage_context=UsageContext.OPENAI_BATCH_RUNNER, - disable_frontend_multiprocessing=False, + args, + usage_context=UsageContext.OPENAI_BATCH_RUNNER, + disable_frontend_multiprocessing=False, ) as engine_client: vllm_config = await engine_client.get_vllm_config() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a646b16da82c..12dd474936db 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -17,29 +17,48 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, - ConversationMessage, - get_history_tool_calls_cnt, - make_tool_call_id) +from vllm.entrypoints.chat_utils import ( + ChatTemplateContentFormatOption, + ConversationMessage, + get_history_tool_calls_cnt, + make_tool_call_id, +) from vllm.entrypoints.harmony_utils import ( - get_developer_message, get_stop_tokens_for_assistant_actions, - get_streamable_parser_for_assistant, get_system_message, parse_chat_input, - parse_chat_output, render_for_completion) + get_developer_message, + get_stop_tokens_for_assistant_actions, + get_streamable_parser_for_assistant, + get_system_message, + parse_chat_input, + parse_chat_output, + render_for_completion, +) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( - ChatCompletionLogProb, ChatCompletionLogProbs, - ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, - ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition, - PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - clamp_prompt_logprobs) + ChatCompletionLogProb, + ChatCompletionLogProbs, + ChatCompletionLogProbsContent, + ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ErrorResponse, + FunctionCall, + FunctionDefinition, + PromptTokenUsageInfo, + RequestResponseMetadata, + ToolCall, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager -from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( - MistralToolCall) +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.entrypoints.utils import get_max_tokens from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger @@ -48,16 +67,17 @@ from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer 
-from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, - truncate_tool_call_ids, - validate_request_params) +from vllm.transformers_utils.tokenizers import ( + maybe_serialize_tool_calls, + truncate_tool_call_ids, + validate_request_params, +) from vllm.utils import as_list logger = init_logger(__name__) class OpenAIServingChat(OpenAIServing): - def __init__( self, engine_client: EngineClient, @@ -79,13 +99,15 @@ def __init__( enable_log_outputs: bool = False, log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage, - log_error_stack=log_error_stack) + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, + ) self.response_role = response_role self.chat_template = chat_template @@ -97,58 +119,63 @@ def __init__( self.enable_auto_tools: bool = enable_auto_tools if self.enable_auto_tools: logger.info( - "\"auto\" tool choice has been enabled please note that while" + '"auto" tool choice has been enabled please note that while' " the parallel_tool_calls client option is preset for " - "compatibility reasons, it will be ignored.") + "compatibility reasons, it will be ignored." + ) - self.reasoning_parser: Optional[Callable[[AnyTokenizer], - ReasoningParser]] = None + self.reasoning_parser: Optional[Callable[[AnyTokenizer], ReasoningParser]] = ( + None + ) if reasoning_parser: try: - self.reasoning_parser = ( - ReasoningParserManager.get_reasoning_parser( - reasoning_parser)) + self.reasoning_parser = ReasoningParserManager.get_reasoning_parser( + reasoning_parser + ) assert self.reasoning_parser is not None except Exception as e: - raise TypeError( - f"{reasoning_parser=} has not been registered") from e + raise TypeError(f"{reasoning_parser=} has not been registered") from e self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None if self.enable_auto_tools: try: - if (tool_parser == "pythonic" and - model_config.model.startswith("meta-llama/Llama-3.2")): + if tool_parser == "pythonic" and model_config.model.startswith( + "meta-llama/Llama-3.2" + ): logger.warning( - "Llama3.2 models may struggle to emit valid pythonic" - " tool calls") - self.tool_parser = ToolParserManager.get_tool_parser( - tool_parser) + "Llama3.2 models may struggle to emit valid pythonic tool calls" + ) + self.tool_parser = ToolParserManager.get_tool_parser(tool_parser) except Exception as e: - raise TypeError("Error: --enable-auto-tool-choice requires " - f"tool_parser:'{tool_parser}' which has not " - "been registered") from e - self.exclude_tools_when_tool_choice_none = ( - exclude_tools_when_tool_choice_none) + raise TypeError( + "Error: --enable-auto-tool-choice requires " + f"tool_parser:'{tool_parser}' which has not " + "been registered" + ) from e + self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_force_include_usage = enable_force_include_usage - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) + self.default_sampling_params = self.model_config.get_diff_sampling_param() if 
self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source - logger.info("Using default chat sampling params from %s: %s", - source, self.default_sampling_params) - if self.model_config.hf_config.model_type == 'kimi_k2': - self.tool_call_id_type = 'kimi_k2' + logger.info( + "Using default chat sampling params from %s: %s", + source, + self.default_sampling_params, + ) + if self.model_config.hf_config.model_type == "kimi_k2": + self.tool_call_id_type = "kimi_k2" else: - self.tool_call_id_type = 'random' + self.tool_call_id_type = "random" self.use_harmony = model_config.hf_config.model_type == "gpt_oss" if self.use_harmony: if "stop_token_ids" not in self.default_sampling_params: self.default_sampling_params["stop_token_ids"] = [] self.default_sampling_params["stop_token_ids"].extend( - get_stop_tokens_for_assistant_actions()) + get_stop_tokens_for_assistant_actions() + ) # NOTE(woosuk): While OpenAI's chat completion API supports browsing # for some models, currently vLLM doesn't support it. Please use the @@ -164,8 +191,7 @@ async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Optional[Request] = None, - ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, - ErrorResponse]: + ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, ErrorResponse]: """ Chat Completion API similar to OpenAI's API. @@ -186,7 +212,8 @@ async def create_chat_completion( try: lora_request = self._maybe_get_adapters( - request, supports_default_mm_loras=True) + request, supports_default_mm_loras=True + ) model_name = self.models.model_name(lora_request) @@ -202,36 +229,36 @@ async def create_chat_completion( truncate_tool_call_ids(request) validate_request_params(request) - if (request.tool_choice == "auto" and - not (self.enable_auto_tools and tool_parser is not None) - and not isinstance(tokenizer, MistralTokenizer) - and not self.use_harmony): + if ( + request.tool_choice == "auto" + and not (self.enable_auto_tools and tool_parser is not None) + and not isinstance(tokenizer, MistralTokenizer) + and not self.use_harmony + ): # for hf tokenizers, "auto" tools requires # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( - "\"auto\" tool choice requires " + '"auto" tool choice requires ' "--enable-auto-tool-choice and --tool-call-parser to be set" ) - if (request.tools is None - or (request.tool_choice == "none" - and self.exclude_tools_when_tool_choice_none)): + if request.tools is None or ( + request.tool_choice == "none" + and self.exclude_tools_when_tool_choice_none + ): tool_dicts = None else: tool_dicts = [tool.model_dump() for tool in request.tools] if not self.use_harmony: # Common case. - request_chat_template = request.chat_template - chat_template_kwargs = request.chat_template_kwargs - if not self.trust_request_chat_template and ( - request_chat_template is not None or - (chat_template_kwargs and - chat_template_kwargs.get("chat_template") is not None)): - return self.create_error_response( - "Chat template is passed with request, but " - "--trust-request-chat-template is not set. 
" - "Refused request with untrusted chat template.") + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret ( conversation, request_prompts, @@ -240,9 +267,8 @@ async def create_chat_completion( request, tokenizer, request.messages, - chat_template=request_chat_template or self.chat_template, - chat_template_content_format=self. - chat_template_content_format, + chat_template=request.chat_template or self.chat_template, + chat_template_content_format=self.chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, tool_dicts=tool_dicts, @@ -258,13 +284,13 @@ async def create_chat_completion( request_prompts, engine_prompts, ) = self._make_request_with_harmony(request) - except (ValueError, TypeError, RuntimeError, - jinja2.TemplateError) as e: + except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") - request_id = "chatcmpl-" \ - f"{self._base_request_id(raw_request, request.request_id)}" + request_id = ( + f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" + ) request_metadata = RequestResponseMetadata(request_id=request_id) if raw_request: @@ -274,8 +300,7 @@ async def create_chat_completion( generators: list[AsyncGenerator[RequestOutput, None]] = [] try: for i, engine_prompt in enumerate(engine_prompts): - prompt_text, _, _ = (self._get_prompt_components( - request_prompts[i])) + prompt_text, _, _ = self._get_prompt_components(request_prompts[i]) if self.default_sampling_params is None: self.default_sampling_params = {} @@ -284,24 +309,33 @@ async def create_chat_completion( max_model_len=self.max_model_len, request=request, input_length=len(engine_prompt["prompt_token_ids"]), - default_sampling_params=self.default_sampling_params) + default_sampling_params=self.default_sampling_params, + ) sampling_params: Union[SamplingParams, BeamSearchParams] if request.use_beam_search: sampling_params = request.to_beam_search_params( - max_tokens, self.default_sampling_params) + max_tokens, self.default_sampling_params + ) else: sampling_params = request.to_sampling_params( - max_tokens, self.model_config.logits_processor_pattern, - self.default_sampling_params) + max_tokens, + self.model_config.logits_processor_pattern, + self.default_sampling_params, + ) - self._log_inputs(request_id, - request_prompts[i], - params=sampling_params, - lora_request=lora_request) + self._log_inputs( + request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + ) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( @@ -311,15 +345,14 @@ async def create_chat_completion( lora_request=lora_request, ) else: - engine_request, tokenization_kwargs = ( - await self._process_inputs( - request_id, - engine_prompt, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - )) + engine_request, tokenization_kwargs = await self._process_inputs( + 
request_id, + engine_prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) generator = self.engine_client.generate( engine_request, @@ -338,7 +371,7 @@ async def create_chat_completion( return self.create_error_response(str(e)) assert len(generators) == 1 - result_generator, = generators + (result_generator,) = generators # Streaming response if request.stream: @@ -350,12 +383,19 @@ async def create_chat_completion( conversation, tokenizer, request_metadata, - enable_force_include_usage=self.enable_force_include_usage) + enable_force_include_usage=self.enable_force_include_usage, + ) try: return await self.chat_completion_full_generator( - request, result_generator, request_id, model_name, - conversation, tokenizer, request_metadata) + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + ) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -366,7 +406,7 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: return request.messages[-1]["role"] @staticmethod - def _bracket_level(s: str, opening='{', closing='}') -> int: + def _bracket_level(s: str, opening="{", closing="}") -> int: """ Calculate the current level of nested brackets in a given string. """ @@ -379,8 +419,7 @@ def _bracket_level(s: str, opening='{', closing='}') -> int: return level @staticmethod - def _filter_delta_text(delta_text: str, - previous_text: str) -> tuple[str, bool]: + def _filter_delta_text(delta_text: str, previous_text: str) -> tuple[str, bool]: # remove last '},' of the tool definition stemming from the # "name"/"parameters" outer object or closing ']' of the tool list # count occurrences of opening and closing curly braces and @@ -390,10 +429,10 @@ def _filter_delta_text(delta_text: str, bracket_level = OpenAIServingChat._bracket_level(previous_text) updated_delta, passed_zero = "", False for c in delta_text: - if c == '{': + if c == "{": bracket_level += 1 passed_zero = bracket_level == 0 - elif c == '}': + elif c == "}": bracket_level -= 1 passed_zero = bracket_level == 0 @@ -401,7 +440,7 @@ def _filter_delta_text(delta_text: str, updated_delta += c else: # if a comma is reached at level 0 we can stop - if c == ',': + if c == ",": break return updated_delta, passed_zero @@ -411,7 +450,7 @@ def extract_tool_call_required_streaming( current_text: Optional[str], delta_text: str, function_name_returned: bool, - tool_call_idx: Optional[int] = None + tool_call_idx: Optional[int] = None, ) -> tuple[Optional[DeltaMessage], bool]: if current_text is None or current_text == "": # if the current text is empty, we cannot parse it @@ -419,7 +458,7 @@ def extract_tool_call_required_streaming( try: obj = partial_json_parser.loads(current_text) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") obj = None # check if the current text is a valid array @@ -430,60 +469,72 @@ def extract_tool_call_required_streaming( delta_message = None else: _, finishes_previous_tool = OpenAIServingChat._filter_delta_text( - delta_text, previous_text) + delta_text, previous_text + ) # take the last tool call from the generated list current_tool_call = obj[-1] # once parameters have been generated the name is complete as well - if not finishes_previous_tool and ("name" not in current_tool_call - or "parameters" - not in 
current_tool_call): + if not finishes_previous_tool and ( + "name" not in current_tool_call or "parameters" not in current_tool_call + ): function_name_returned = False delta_message = None else: if not function_name_returned: # get partly generated arguments from the latest tool call - param_match = re.search(r'.*"parameters":\s*(.*)', - current_text, re.DOTALL) + param_match = re.search( + r'.*"parameters":\s*(.*)', current_text, re.DOTALL + ) arguments = param_match.group(1) if param_match else "" arguments, _ = OpenAIServingChat._filter_delta_text( - arguments, previous_text) + arguments, previous_text + ) # if this iteration finishes a previous tool call but a # new incomplete tool is already generated, take the # previous from the list - if (finishes_previous_tool - and "parameters" not in current_tool_call): + if finishes_previous_tool and "parameters" not in current_tool_call: current_tool_call = obj[-2] function_name_returned = True tool_call_id = make_tool_call_id( id_type=self.tool_call_id_type, func_name=current_tool_call["name"], - idx=tool_call_idx) - delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(id=tool_call_id, - function=DeltaFunctionCall( - name=current_tool_call["name"], - arguments=arguments), - index=len(obj) - 1, - type="function") - ]) + idx=tool_call_idx, + ) + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + id=tool_call_id, + function=DeltaFunctionCall( + name=current_tool_call["name"], arguments=arguments + ), + index=len(obj) - 1, + type="function", + ) + ] + ) else: delta_text, _ = OpenAIServingChat._filter_delta_text( - delta_text, previous_text) + delta_text, previous_text + ) if delta_text != "": - delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall( - function=DeltaFunctionCall( - # OpenAI API returns None - # instead of name every time - name=None, - arguments=delta_text), - index=len(obj) - 1) - ]) + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + function=DeltaFunctionCall( + # OpenAI API returns None + # instead of name every time + name=None, + arguments=delta_text, + ), + index=len(obj) - 1, + ) + ] + ) else: delta_message = None @@ -512,8 +563,7 @@ async def chat_completion_stream_generator( num_cached_tokens = None if self.use_harmony: harmony_parsers = [ - get_streamable_parser_for_assistant() - for _ in range(num_choices) + get_streamable_parser_for_assistant() for _ in range(num_choices) ] harmony_tools_streamed = [False] * num_choices tools_streamed = [False] * num_choices @@ -526,11 +576,12 @@ async def chat_completion_stream_generator( # Determine whether tools are in use with "auto" tool choice tool_choice_auto = ( not tool_choice_function_name - and self._should_stream_with_auto_tool_parsing(request)) + and self._should_stream_with_auto_tool_parsing(request) + ) all_previous_token_ids: Optional[list[list[int]]] function_name_returned = [False] * num_choices - if self.tool_call_id_type == 'kimi_k2': + if self.tool_call_id_type == "kimi_k2": history_tool_call_cnt = get_history_tool_calls_cnt(conversation) else: history_tool_call_cnt = 0 @@ -577,10 +628,10 @@ async def chat_completion_stream_generator( stream_options = request.stream_options if stream_options: - include_usage = stream_options.include_usage \ - or enable_force_include_usage - include_continuous_usage = include_usage and \ - stream_options.continuous_usage_stats + include_usage = stream_options.include_usage or enable_force_include_usage + include_continuous_usage = ( + include_usage and stream_options.continuous_usage_stats + ) 
else: include_usage, include_continuous_usage = False, False @@ -610,7 +661,8 @@ async def chat_completion_stream_generator( content="", ), logprobs=None, - finish_reason=None) + finish_reason=None, + ) # return prompt_token_ids at the first chunk ever chunk = ChatCompletionStreamResponse( @@ -619,16 +671,20 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name, - prompt_token_ids=(res.prompt_token_ids - if request.return_token_ids else - None)) + prompt_token_ids=( + res.prompt_token_ids + if request.return_token_ids + else None + ), + ) # if continuous usage stats are requested, add it if include_continuous_usage: chunk.usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=num_prompt_tokens) + total_tokens=num_prompt_tokens, + ) data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" @@ -637,33 +693,36 @@ async def chat_completion_stream_generator( # last message if request.echo: last_msg_content: Union[str, list[dict[str, str]]] = "" - if conversation and "content" in conversation[ - -1] and conversation[-1].get("role") == role: + if ( + conversation + and "content" in conversation[-1] + and conversation[-1].get("role") == role + ): last_msg_content = conversation[-1]["content"] or "" if last_msg_content: for i in range(num_choices): - choice_data = ( - ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage( - content=last_msg_content), - logprobs=None, - finish_reason=None)) + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + logprobs=None, + finish_reason=None, + ) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + model=model_name, + ) if include_continuous_usage: chunk.usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=0, - total_tokens=num_prompt_tokens) + total_tokens=num_prompt_tokens, + ) - data = chunk.model_dump_json( - exclude_unset=True) + data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" first_iteration = False @@ -675,15 +734,13 @@ async def chat_completion_stream_generator( continue if request.logprobs and request.top_logprobs is not None: - assert output.logprobs is not None, ( - "Did not output logprobs") + assert output.logprobs is not None, "Did not output logprobs" logprobs = self._create_chat_logprobs( token_ids=output.token_ids, top_logprobs=output.logprobs, tokenizer=tokenizer, num_output_top_logprobs=request.top_logprobs, - return_as_token_id=request. - return_tokens_as_token_ids, + return_as_token_id=request.return_tokens_as_token_ids, ) else: logprobs = None @@ -694,15 +751,17 @@ async def chat_completion_stream_generator( delta_text = "" for token_id in output.token_ids: harmony_parser.process(token_id) - delta_text += (harmony_parser.last_content_delta - or "") + delta_text += harmony_parser.last_content_delta or "" cur_channel = harmony_parser.current_channel cur_recipient = harmony_parser.current_recipient else: delta_text = output.text - if not delta_text and not output.token_ids and \ - not previous_num_tokens[i]: + if ( + not delta_text + and not output.token_ids + and not previous_num_tokens[i] + ): # Chunked prefill case, don't return empty chunks continue @@ -718,7 +777,8 @@ async def chat_completion_stream_generator( # avoid the None + list error. 
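# --- Editor's illustrative sketch (not part of the patch) ---------------------
# The guard that follows exists because `previous_token_ids` may be empty (or
# unset) on the first streamed delta, so concatenating unconditionally would
# be wrong; this tiny helper mirrors the same branch in isolation:
from typing import Optional


def concat_token_ids(previous: Optional[list[int]], new: list[int]) -> list[int]:
    # Mirrors the `if previous_token_ids:` guard in the hunk that follows.
    return previous + new if previous else list(new)


assert concat_token_ids(None, [3, 4]) == [3, 4]
assert concat_token_ids([], [3, 4]) == [3, 4]
assert concat_token_ids([1, 2], [3, 4]) == [1, 2, 3, 4]
# -----------------------------------------------------------------------------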
if previous_token_ids: current_token_ids = previous_token_ids + as_list( - output.token_ids) + output.token_ids + ) else: current_token_ids = as_list(output.token_ids) @@ -728,42 +788,51 @@ async def chat_completion_stream_generator( elif cur_channel == "analysis": if request.include_reasoning: delta_message = DeltaMessage( - reasoning_content=delta_text) + reasoning_content=delta_text + ) else: delta_message = None - elif (cur_channel == "commentary" and cur_recipient - and cur_recipient.startswith("functions.")): + elif ( + cur_channel == "commentary" + and cur_recipient + and cur_recipient.startswith("functions.") + ): # Count completed tool calls to determine index base_index = 0 for msg in harmony_parser.messages: - if (msg.channel == "commentary" - and msg.recipient - and msg.recipient.startswith( - "functions.")): + if ( + msg.channel == "commentary" + and msg.recipient + and msg.recipient.startswith("functions.") + ): base_index += 1 if prev_recipient != cur_recipient: - tool_name = cur_recipient.split( - "functions.", 1)[1] - delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall( - id=make_tool_call_id(), - type="function", - function=DeltaFunctionCall( - name=tool_name, - arguments="", - ), - index=base_index, - ) - ]) + tool_name = cur_recipient.split("functions.", 1)[1] + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=tool_name, + arguments="", + ), + index=base_index, + ) + ] + ) elif delta_text: - delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=base_index, - function=DeltaFunctionCall( - arguments=delta_text), - ) - ]) + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=base_index, + function=DeltaFunctionCall( + arguments=delta_text + ), + ) + ] + ) else: delta_message = None @@ -773,30 +842,37 @@ async def chat_completion_stream_generator( delta_message = None # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: - if (self.reasoning_parser and not reasoning_end_arr[i] - and not reasoning_parser.is_reasoning_end( - previous_token_ids)): + if ( + self.reasoning_parser + and not reasoning_end_arr[i] + and not reasoning_parser.is_reasoning_end( + previous_token_ids + ) + ): assert reasoning_parser is not None delta_message = ( - reasoning_parser. - extract_reasoning_content_streaming( + reasoning_parser.extract_reasoning_content_streaming( previous_text, current_text, delta_text, previous_token_ids, current_token_ids, output.token_ids, - )) + ) + ) # When encountering think end id in delta_token_ids # or think end id in prompt_token_ids # i.e {"enable_thinking": False}, # set reasoning status to end. # Only keep 'content', remove 'reasoning_content'. 
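# --- Editor's illustrative sketch (not part of the patch) ---------------------
# The streaming loop below flips a per-choice reasoning-end flag once the
# reasoning parser reports the end-of-thinking marker, after which deltas are
# emitted as regular `content` instead of `reasoning_content`. A toy version
# of that split with a purely hypothetical end-of-thinking token id:
THINK_END_TOKEN_ID = 128003  # placeholder id, illustrative only


def split_reasoning(token_stream: list[int]) -> tuple[list[int], list[int]]:
    """Split a token stream into (reasoning_tokens, content_tokens)."""
    reasoning_ended = False
    reasoning: list[int] = []
    content: list[int] = []
    for tok in token_stream:
        if not reasoning_ended:
            if tok == THINK_END_TOKEN_ID:
                reasoning_ended = True  # everything after this is content
                continue
            reasoning.append(tok)
        else:
            content.append(tok)
    return reasoning, content
# -----------------------------------------------------------------------------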
if reasoning_parser.is_reasoning_end( - as_list(output.token_ids)) or ( - res.prompt_token_ids - and reasoning_parser.is_reasoning_end( - res.prompt_token_ids)): + as_list(output.token_ids) + ) or ( + res.prompt_token_ids + and reasoning_parser.is_reasoning_end( + res.prompt_token_ids + ) + ): reasoning_end_arr[i] = True if delta_message and delta_message.content: # This need to be added to next `delta_text` @@ -812,22 +888,26 @@ async def chat_completion_stream_generator( if function_name_returned[i]: delta_tool_call = DeltaToolCall( - function=DeltaFunctionCall( - arguments=delta_text), - index=i) + function=DeltaFunctionCall(arguments=delta_text), + index=i, + ) else: delta_tool_call = DeltaToolCall( id=make_tool_call_id(), type="function", function=DeltaFunctionCall( name=tool_choice_function_name, - arguments=delta_text), - index=i) + arguments=delta_text, + ), + index=i, + ) function_name_returned[i] = True - delta_message = DeltaMessage(tool_calls=[ - delta_tool_call, - ]) + delta_message = DeltaMessage( + tool_calls=[ + delta_tool_call, + ] + ) tools_streamed[i] = True elif request.tool_choice == "required": @@ -837,11 +917,9 @@ async def chat_completion_stream_generator( fn_name_returned = function_name_returned[i] if self.reasoning_parser: - _, content = \ - reasoning_parser.extract_reasoning_content( - current_text, - request - ) + _, content = reasoning_parser.extract_reasoning_content( + current_text, request + ) else: content = current_text delta_message, function_name_returned[i] = ( @@ -850,9 +928,14 @@ async def chat_completion_stream_generator( current_text=content, delta_text=delta_text, function_name_returned=fn_name_returned, - tool_call_idx=history_tool_call_cnt)) - if (delta_message and delta_message.tool_calls and - delta_message.tool_calls[0].id is not None): + tool_call_idx=history_tool_call_cnt, + ) + ) + if ( + delta_message + and delta_message.tool_calls + and delta_message.tool_calls[0].id is not None + ): history_tool_call_cnt += 1 tools_streamed[i] = True @@ -866,23 +949,26 @@ async def chat_completion_stream_generator( output_token_ids = as_list(output.token_ids) if not reasoning_end_arr[i]: delta_message = ( - reasoning_parser. - extract_reasoning_content_streaming( + reasoning_parser.extract_reasoning_content_streaming( previous_text, current_text, delta_text, previous_token_ids, current_token_ids, output_token_ids, - )) + ) + ) # When encountering think end id in prompt_token_ids # i.e {"enable_thinking": False}, # set reasoning status to end. # Remove the text and token ids related # to 'reasoning_content'. - if res.prompt_token_ids and \ - reasoning_parser.is_reasoning_end( - res.prompt_token_ids): + if ( + res.prompt_token_ids + and reasoning_parser.is_reasoning_end( + res.prompt_token_ids + ) + ): reasoning_end_arr[i] = True current_token_ids = output_token_ids if delta_message and delta_message.content: @@ -894,12 +980,13 @@ async def chat_completion_stream_generator( # set reasoning status to end. # Remove the text and token ids related # to 'reasoning_content'. 
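
The bookkeeping above treats deltas as reasoning_content until the parser reports that reasoning has ended, either in the generated tokens or already in the prompt (e.g. when thinking is disabled); from then on deltas are plain content. The sketch below captures that state machine in simplified form: THINK_END_ID is a hypothetical token id used only for illustration, and the carry-over of partial content from the step in which reasoning ends is omitted.

THINK_END_ID = 100_000  # hypothetical end-of-reasoning token id


def is_reasoning_end(token_ids: list[int]) -> bool:
    return THINK_END_ID in token_ids


def split_stream(prompt_token_ids: list[int],
                 steps: list[tuple[list[int], str]]) -> tuple[str, str]:
    # Reasoning may already be "over" if the end marker sits in the prompt.
    reasoning_ended = is_reasoning_end(prompt_token_ids)
    reasoning_parts: list[str] = []
    content_parts: list[str] = []
    for token_ids, text in steps:
        if reasoning_ended:
            content_parts.append(text)
        else:
            reasoning_parts.append(text)
            reasoning_ended = is_reasoning_end(token_ids)
    return "".join(reasoning_parts), "".join(content_parts)
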
- if reasoning_parser.is_reasoning_end( - output_token_ids): + if reasoning_parser.is_reasoning_end(output_token_ids): reasoning_end_arr[i] = True - current_token_ids = \ + current_token_ids = ( reasoning_parser.extract_content_ids( - output_token_ids) + output_token_ids + ) + ) if delta_message and delta_message.content: current_text = delta_message.content delta_message.content = None @@ -919,50 +1006,52 @@ async def chat_completion_stream_generator( delta_text = current_text delta_token_ids = current_token_ids - delta_message = ( - tool_parser.extract_tool_calls_streaming( - previous_text=previous_text, - current_text=current_text, - delta_text=delta_text, - previous_token_ids=previous_token_ids, - current_token_ids=current_token_ids, - delta_token_ids=delta_token_ids, - request=request)) - if delta_message and delta_message.tool_calls: - tools_streamed[i] = True - # when only tool calls - elif tool_choice_auto: - assert tool_parser is not None - delta_message = ( - tool_parser.extract_tool_calls_streaming( + delta_message = tool_parser.extract_tool_calls_streaming( previous_text=previous_text, current_text=current_text, delta_text=delta_text, previous_token_ids=previous_token_ids, current_token_ids=current_token_ids, - delta_token_ids=output.token_ids, - request=request)) + delta_token_ids=delta_token_ids, + request=request, + ) + if delta_message and delta_message.tool_calls: + tools_streamed[i] = True + # when only tool calls + elif tool_choice_auto: + assert tool_parser is not None + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + request=request, + ) if delta_message and delta_message.tool_calls: tools_streamed[i] = True # when only reasoning elif self.reasoning_parser: - delta_message = (reasoning_parser. 
- extract_reasoning_content_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - output.token_ids, - )) + delta_message = ( + reasoning_parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + ) + ) # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) # update the previous values for the next iteration - if ((tool_choice_auto or self.reasoning_parser) - and not self.use_harmony): + if ( + tool_choice_auto or self.reasoning_parser + ) and not self.use_harmony: assert previous_texts is not None assert all_previous_token_ids is not None previous_texts[i] = current_text @@ -994,7 +1083,8 @@ async def chat_completion_stream_generator( delta_content = "".join( tc.function.arguments for tc in delta_message.tool_calls - if tc.function and tc.function.arguments) + if tc.function and tc.function.arguments + ) if delta_content: self.request_logger.log_outputs( @@ -1013,8 +1103,12 @@ async def chat_completion_stream_generator( delta=delta_message, logprobs=logprobs, finish_reason=None, - token_ids=(as_list(output.token_ids) - if request.return_token_ids else None)) + token_ids=( + as_list(output.token_ids) + if request.return_token_ids + else None + ), + ) # if the model is finished generating else: @@ -1024,66 +1118,86 @@ async def chat_completion_stream_generator( # only happens if we are NOT using structured outputs auto_tools_called = False if tool_parser: - auto_tools_called = len( - tool_parser.prev_tool_call_arr) > 0 - index = len(tool_parser.prev_tool_call_arr - ) - 1 if auto_tools_called else 0 + auto_tools_called = len(tool_parser.prev_tool_call_arr) > 0 + index = ( + len(tool_parser.prev_tool_call_arr) - 1 + if auto_tools_called + else 0 + ) else: index = 0 - if self._should_check_for_unstreamed_tool_arg_tokens( - delta_message, output) and tool_parser: + if ( + self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output + ) + and tool_parser + ): latest_delta_len = 0 - if ((isinstance( + if ( + isinstance( delta_message.tool_calls[0].function, - DeltaFunctionCall)) and isinstance( - delta_message.tool_calls[0].function. - arguments, str)): + DeltaFunctionCall, + ) + ) and isinstance( + delta_message.tool_calls[0].function.arguments, str + ): latest_delta_len = len( - delta_message.tool_calls[0].function. - arguments) + delta_message.tool_calls[0].function.arguments + ) # get the expected call based on partial JSON # parsing which "autocompletes" the JSON expected_call = json.dumps( tool_parser.prev_tool_call_arr[index].get( - "arguments", {}), - ensure_ascii=False) + "arguments", {} + ), + ensure_ascii=False, + ) # get what we've streamed so far for arguments # for the current tool - actual_call = tool_parser.streamed_args_for_tool[ - index] - if (latest_delta_len > 0): + actual_call = tool_parser.streamed_args_for_tool[index] + if latest_delta_len > 0: actual_call = actual_call[:-latest_delta_len] # check to see if there's anything left to stream - remaining_call = expected_call.replace( - actual_call, "", 1) + remaining_call = expected_call.replace(actual_call, "", 1) # set that as a delta message - delta_message = DeltaMessage(tool_calls=[ - DeltaToolCall(index=index, - function=DeltaFunctionCall( - arguments=remaining_call). 
- model_dump(exclude_none=True)) - ]) + delta_message = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=index, + function=DeltaFunctionCall( + arguments=remaining_call + ).model_dump(exclude_none=True), + ) + ] + ) # Send the finish response for each request.n only once - if auto_tools_called or tools_streamed[i] or ( - self.use_harmony - and harmony_tools_streamed[i]): + if ( + auto_tools_called + or tools_streamed[i] + or (self.use_harmony and harmony_tools_streamed[i]) + ): finish_reason_ = "tool_calls" else: - finish_reason_ = output.finish_reason \ - if output.finish_reason else "stop" + finish_reason_ = ( + output.finish_reason if output.finish_reason else "stop" + ) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, logprobs=logprobs, finish_reason=finish_reason_, stop_reason=output.stop_reason, - token_ids=(as_list(output.token_ids) - if request.return_token_ids else None)) + token_ids=( + as_list(output.token_ids) + if request.return_token_ids + else None + ), + ) finish_reason_sent[i] = True @@ -1092,7 +1206,8 @@ async def chat_completion_stream_generator( object=chunk_object_type, created=created_time, choices=[choice_data], - model=model_name) + model=model_name, + ) # handle usage stats if requested & if continuous if include_continuous_usage: @@ -1110,13 +1225,15 @@ async def chat_completion_stream_generator( # is sent, send the usage if include_usage: completion_tokens = sum(previous_num_tokens) - final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + - completion_tokens) + final_usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) if self.enable_prompt_tokens_details and num_cached_tokens: final_usage.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=num_cached_tokens) + cached_tokens=num_cached_tokens + ) final_usage_chunk = ChatCompletionStreamResponse( id=request_id, @@ -1124,9 +1241,11 @@ async def chat_completion_stream_generator( created=created_time, choices=[], model=model_name, - usage=final_usage) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True)) + usage=final_usage, + ) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True + ) yield f"data: {final_usage_data}\n\n" # report to FastAPI middleware aggregate usage across all choices @@ -1143,14 +1262,13 @@ async def chat_completion_stream_generator( for i in range(num_choices): full_text = ( previous_texts[i] - if previous_texts and i < len(previous_texts) else - f"<streaming_complete: {previous_num_tokens[i]} tokens>" + if previous_texts and i < len(previous_texts) + else f"<streaming_complete: {previous_num_tokens[i]} tokens>" ) self.request_logger.log_outputs( request_id=request_id, outputs=full_text, - output_token_ids= - None, # Consider also logging all token IDs + output_token_ids=None, # Consider also logging all token IDs finish_reason="streaming_complete", is_streaming=True, delta=False, @@ -1174,7 +1292,6 @@ async def chat_completion_full_generator( tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: - created_time = int(time.time()) final_res: Optional[RequestOutput] = None @@ -1190,7 +1307,7 @@ async def chat_completion_full_generator( assert final_res is not None choices: list[ChatCompletionResponseChoice] = [] - if 
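
The "unstreamed tool arguments" handling above serializes the fully parsed arguments, subtracts what has already been streamed (minus the latest delta, which is not yet accounted for), and emits the leftover tail as the final arguments delta. A small self-contained version of that computation:

import json


def remaining_tool_arguments(parsed_arguments: dict, streamed_so_far: str,
                             latest_delta: str) -> str:
    expected = json.dumps(parsed_arguments, ensure_ascii=False)
    if latest_delta:
        # The latest delta has not been folded into the streamed state yet.
        streamed_so_far = streamed_so_far[:-len(latest_delta)]
    # Whatever was never streamed becomes the final arguments delta.
    return expected.replace(streamed_so_far, "", 1)


# remaining_tool_arguments({"city": "Paris"}, '{"city": "Pa', "") == 'ris"}'
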
self.tool_call_id_type == 'kimi_k2': + if self.tool_call_id_type == "kimi_k2": history_tool_call_cnt = get_history_tool_calls_cnt(conversation) else: history_tool_call_cnt = 0 @@ -1244,10 +1361,11 @@ async def chat_completion_full_generator( index=output.index, message=message, logprobs=logprobs, - finish_reason="tool_calls" if - (tool_call_info is not None - and tool_call_info.tools_called) else - output.finish_reason if output.finish_reason else "stop", + finish_reason="tool_calls" + if (tool_call_info is not None and tool_call_info.tools_called) + else output.finish_reason + if output.finish_reason + else "stop", stop_reason=output.stop_reason, ) choices.append(choice_data) @@ -1261,9 +1379,9 @@ async def chat_completion_full_generator( return self.create_error_response(str(e)) # If the reasoning parser is enabled, # tool calls are extracted exclusively from the content. - reasoning_content, content = ( - reasoning_parser.extract_reasoning_content( - output.text, request=request)) + reasoning_content, content = reasoning_parser.extract_reasoning_content( + output.text, request=request + ) if not request.include_reasoning: reasoning_content = None else: @@ -1273,76 +1391,93 @@ async def chat_completion_full_generator( auto_tools_called = False # if auto tools are not enabled, and a named tool choice using # outlines is not being used - if (not self.enable_auto_tools or not self.tool_parser) and \ - (not isinstance(request.tool_choice, - ChatCompletionNamedToolChoiceParam - ) and request.tool_choice != "required"): - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + if (not self.enable_auto_tools or not self.tool_parser) and ( + not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) + and request.tool_choice != "required" + ): + message = ChatMessage( + role=role, reasoning_content=reasoning_content, content=content + ) # if the request uses tools and specified a tool choice - elif request.tool_choice and type( - request.tool_choice) is ChatCompletionNamedToolChoiceParam: - - tool_call_class = MistralToolCall if isinstance( - tokenizer, MistralTokenizer) else ToolCall + elif ( + request.tool_choice + and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam + ): + tool_call_class = ( + MistralToolCall + if isinstance(tokenizer, MistralTokenizer) + else ToolCall + ) message = ChatMessage( role=role, reasoning_content=reasoning_content, content="", tool_calls=[ - tool_call_class(function=FunctionCall( - name=request.tool_choice.function.name, - arguments=content, - )) + tool_call_class( + function=FunctionCall( + name=request.tool_choice.function.name, + arguments=content, + ) + ) ], ) elif request.tool_choice and request.tool_choice == "required": - tool_call_class = MistralToolCall if isinstance( - tokenizer, MistralTokenizer) else ToolCall + tool_call_class = ( + MistralToolCall + if isinstance(tokenizer, MistralTokenizer) + else ToolCall + ) # the fields of FunctionDefinition are a superset of the # tool call outputs and can be used for parsing assert content is not None - tool_calls = TypeAdapter( - list[FunctionDefinition]).validate_json(content) + tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json( + content + ) tool_call_ids = [] for tool_call in tool_calls: tool_call_ids.append( - make_tool_call_id(id_type=self.tool_call_id_type, - func_name=tool_call.name, - idx=history_tool_call_cnt)) + make_tool_call_id( + id_type=self.tool_call_id_type, + func_name=tool_call.name, + idx=history_tool_call_cnt, + 
) + ) history_tool_call_cnt += 1 message = ChatMessage( role=role, content="", tool_calls=[ - tool_call_class(id=tool_call_ids[i], - function=FunctionCall( - name=tool_call.name, - arguments=json.dumps( - tool_call.parameters, - ensure_ascii=False))) + tool_call_class( + id=tool_call_ids[i], + function=FunctionCall( + name=tool_call.name, + arguments=json.dumps( + tool_call.parameters, ensure_ascii=False + ), + ), + ) for i, tool_call in enumerate(tool_calls) ], - reasoning_content=reasoning_content) + reasoning_content=reasoning_content, + ) # if the request doesn't use tool choice # OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": - - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + message = ChatMessage( + role=role, reasoning_content=reasoning_content, content=content + ) # handle when there are tools and tool choice is auto - elif request.tools and ( - request.tool_choice == "auto" - or request.tool_choice is None) and self.enable_auto_tools \ - and self.tool_parser: - + elif ( + request.tools + and (request.tool_choice == "auto" or request.tool_choice is None) + and self.enable_auto_tools + and self.tool_parser + ): try: tool_parser = self.tool_parser(tokenizer) except RuntimeError as e: @@ -1350,16 +1485,19 @@ async def chat_completion_full_generator( return self.create_error_response(str(e)) tool_call_info = tool_parser.extract_tool_calls( - content if content is not None else "", request=request) + content if content is not None else "", request=request + ) # In the OpenAI API the finish_reason is "tools_called" # if the tool choice is auto and the model produced a tool # call. The same is not true for named function calls auto_tools_called = tool_call_info.tools_called if tool_call_info.tools_called: - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=tool_call_info.content, - tool_calls=tool_call_info.tool_calls) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls, + ) else: # FOR NOW make it a chat message; we will have to detect @@ -1368,48 +1506,55 @@ async def chat_completion_full_generator( # try to use content return from tool parser first, # tool parser may do some modify for the content. - if (tool_call_info.content - and len(tool_call_info.content) > 0): + if tool_call_info.content and len(tool_call_info.content) > 0: ret_content = tool_call_info.content - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=ret_content) + message = ChatMessage( + role=role, + reasoning_content=reasoning_content, + content=ret_content, + ) # undetermined case that is still important to handle else: logger.error( "Error in chat_completion_full_generator - cannot determine" " if tools should be extracted. Returning a standard chat " - "completion.") - message = ChatMessage(role=role, - reasoning_content=reasoning_content, - content=content) + "completion." 
+ ) + message = ChatMessage( + role=role, reasoning_content=reasoning_content, content=content + ) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason="tool_calls" if auto_tools_called else - output.finish_reason if output.finish_reason else "stop", + finish_reason="tool_calls" + if auto_tools_called + else output.finish_reason + if output.finish_reason + else "stop", stop_reason=output.stop_reason, - token_ids=(as_list(output.token_ids) - if request.return_token_ids else None), + token_ids=( + as_list(output.token_ids) if request.return_token_ids else None + ), ) choices.append(choice_data) if request.echo: last_msg_content: Union[str, list[dict[str, str]]] = "" - if (conversation and "content" in conversation[-1] - and conversation[-1].get("role") == role): + if ( + conversation + and "content" in conversation[-1] + and conversation[-1].get("role") == role + ): last_msg_content = conversation[-1]["content"] or "" if isinstance(last_msg_content, list): - last_msg_content = "\n".join(msg['text'] - for msg in last_msg_content) + last_msg_content = "\n".join(msg["text"] for msg in last_msg_content) for choice in choices: - full_message = last_msg_content + (choice.message.content - or "") + full_message = last_msg_content + (choice.message.content or "") choice.message.content = full_message assert final_res.prompt_token_ids is not None @@ -1417,14 +1562,17 @@ async def chat_completion_full_generator( if final_res.encoder_prompt_token_ids is not None: num_prompt_tokens += len(final_res.encoder_prompt_token_ids) num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + - num_generated_tokens) + len(output.token_ids) for output in final_res.outputs + ) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) if self.enable_prompt_tokens_details and final_res.num_cached_tokens: usage.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=final_res.num_cached_tokens) + cached_tokens=final_res.num_cached_tokens + ) request_metadata.final_usage_info = usage @@ -1435,8 +1583,9 @@ async def chat_completion_full_generator( choices=choices, usage=usage, prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs), - prompt_token_ids=(final_res.prompt_token_ids - if request.return_token_ids else None), + prompt_token_ids=( + final_res.prompt_token_ids if request.return_token_ids else None + ), kv_transfer_params=final_res.kv_transfer_params, ) @@ -1451,9 +1600,11 @@ async def chat_completion_full_generator( tool_call_descriptions = [] for tc in choice.message.tool_calls: if hasattr(tc.function, "name") and hasattr( - tc.function, "arguments"): + tc.function, "arguments" + ): tool_call_descriptions.append( - f"{tc.function.name}({tc.function.arguments})") + f"{tc.function.name}({tc.function.arguments})" + ) tool_calls_str = ", ".join(tool_call_descriptions) output_text = f"[tool_calls: {tool_calls_str}]" @@ -1461,8 +1612,7 @@ async def chat_completion_full_generator( # Get the corresponding output token IDs output_token_ids = None if choice.index < len(final_res.outputs): - output_token_ids = final_res.outputs[ - choice.index].token_ids + output_token_ids = final_res.outputs[choice.index].token_ids self.request_logger.log_outputs( request_id=request_id, @@ -1476,20 +1626,26 
@@ async def chat_completion_full_generator( return response def _get_top_logprobs( - self, logprobs: dict[int, Logprob], top_logprobs: Optional[int], - tokenizer: AnyTokenizer, - should_return_as_token_id: bool) -> list[ChatCompletionLogProb]: + self, + logprobs: dict[int, Logprob], + top_logprobs: Optional[int], + tokenizer: AnyTokenizer, + should_return_as_token_id: bool, + ) -> list[ChatCompletionLogProb]: return [ ChatCompletionLogProb( - token=(token := self._get_decoded_token( - p[1], - p[0], - tokenizer, - return_as_token_id=should_return_as_token_id, - )), + token=( + token := self._get_decoded_token( + p[1], + p[0], + tokenizer, + return_as_token_id=should_return_as_token_id, + ) + ), logprob=max(p[1].logprob, -9999.0), bytes=list(token.encode("utf-8", errors="replace")), - ) for i, p in enumerate(logprobs.items()) + ) + for i, p in enumerate(logprobs.items()) if top_logprobs and i < top_logprobs ] @@ -1504,12 +1660,14 @@ def _create_chat_logprobs( """Create OpenAI-style logprobs.""" logprobs_content: list[ChatCompletionLogProbsContent] = [] - should_return_as_token_id = return_as_token_id if \ - return_as_token_id is not None else self.return_tokens_as_token_ids + should_return_as_token_id = ( + return_as_token_id + if return_as_token_id is not None + else self.return_tokens_as_token_ids + ) for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] - if step_top_logprobs is None or step_top_logprobs.get( - token_id) is None: + if step_top_logprobs is None or step_top_logprobs.get(token_id) is None: if should_return_as_token_id: token = f"token_id:{token_id}" else: @@ -1519,7 +1677,8 @@ def _create_chat_logprobs( ChatCompletionLogProbsContent( token=token, bytes=list(token.encode("utf-8", errors="replace")), - )) + ) + ) else: step_token = step_top_logprobs[token_id] step_decoded = step_token.decoded_token @@ -1533,17 +1692,21 @@ def _create_chat_logprobs( should_return_as_token_id, ), logprob=max(step_token.logprob, -9999.0), - bytes=None if step_decoded is None else list( - step_decoded.encode("utf-8", errors="replace")), + bytes=None + if step_decoded is None + else list(step_decoded.encode("utf-8", errors="replace")), top_logprobs=self._get_top_logprobs( - step_top_logprobs, num_output_top_logprobs, - tokenizer, should_return_as_token_id), - )) + step_top_logprobs, + num_output_top_logprobs, + tokenizer, + should_return_as_token_id, + ), + ) + ) return ChatCompletionLogProbs(content=logprobs_content) - def _should_stream_with_auto_tool_parsing(self, - request: ChatCompletionRequest): + def _should_stream_with_auto_tool_parsing(self, request: ChatCompletionRequest): """ Utility function to check if streamed tokens should go through the tool call parser that was configured. @@ -1552,8 +1715,12 @@ def _should_stream_with_auto_tool_parsing(self, is configured, "auto" tool choice is enabled, and the request's tool choice field indicates that "auto" tool choice should be used. """ - return (request.tools and self.tool_parser and self.enable_auto_tools - and request.tool_choice in ['auto', None]) + return ( + request.tools + and self.tool_parser + and self.enable_auto_tools + and request.tool_choice in ["auto", None] + ) def _should_check_for_unstreamed_tool_arg_tokens( self, @@ -1566,13 +1733,15 @@ def _should_check_for_unstreamed_tool_arg_tokens( is a tool call with arguments. 
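
The logprob helpers above clamp float("-inf") to the JSON-safe -9999.0 floor, optionally render the token as "token_id:<id>", and attach the token's UTF-8 byte values. A minimal sketch of assembling one such entry as a plain dict (the real code returns ChatCompletionLogProb objects):

def logprob_entry(token: str, logprob: float, token_id: int,
                  as_token_id: bool = False) -> dict:
    shown = f"token_id:{token_id}" if as_token_id else token
    return {
        "token": shown,
        # float("-inf") is not JSON-serializable, so it is clamped.
        "logprob": max(logprob, -9999.0),
        "bytes": list(shown.encode("utf-8", errors="replace")),
    }


# logprob_entry("Hi", float("-inf"), token_id=1234)
# -> {"token": "Hi", "logprob": -9999.0, "bytes": [72, 105]}
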
""" - # yapf: disable return bool( # if there is a delta message that includes tool calls which # include a function that has arguments output.finish_reason is not None - and self.enable_auto_tools and self.tool_parser and delta_message - and delta_message.tool_calls and delta_message.tool_calls[0] + and self.enable_auto_tools + and self.tool_parser + and delta_message + and delta_message.tool_calls + and delta_message.tool_calls[0] and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function.arguments is not None ) @@ -1592,8 +1761,8 @@ def _make_request_with_harmony( reasoning_effort=request.reasoning_effort, browser_description=None, python_description=None, - with_custom_tools=request.tools is not None - ) + with_custom_tools=request.tools is not None, + ) messages.append(sys_msg) # Add developer message. diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index fc56668aeb1b..25e167e9bb0c 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -11,14 +11,18 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (ClassificationData, - ClassificationRequest, - ClassificationResponse, - ErrorResponse, UsageInfo) -# yapf: enable -from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext, - OpenAIServing, - ServeContext) +from vllm.entrypoints.openai.protocol import ( + ClassificationData, + ClassificationRequest, + ClassificationResponse, + ErrorResponse, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import ( + ClassificationServeContext, + OpenAIServing, + ServeContext, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig from vllm.logger import init_logger @@ -29,7 +33,6 @@ class ClassificationMixin(OpenAIServing): - @override async def _preprocess( self, @@ -55,7 +58,8 @@ async def _preprocess( renderer = self._get_renderer(ctx.tokenizer) ctx.engine_prompts = await renderer.render_prompt( prompt_or_prompts=ctx.request.input, - config=self._build_render_config(ctx.request)) + config=self._build_render_config(ctx.request), + ) return None @@ -76,16 +80,16 @@ def _build_response( items: list[ClassificationData] = [] num_prompt_tokens = 0 - final_res_batch_checked = cast(list[PoolingRequestOutput], - ctx.final_res_batch) + final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) for idx, final_res in enumerate(final_res_batch_checked): classify_res = ClassificationOutput.from_base(final_res.outputs) probs = classify_res.probs predicted_index = int(np.argmax(probs)) - label = getattr(self.model_config.hf_config, "id2label", - {}).get(predicted_index) + label = getattr(self.model_config.hf_config, "id2label", {}).get( + predicted_index + ) item = ClassificationData( index=idx, @@ -111,11 +115,11 @@ def _build_response( usage=usage, ) - def _build_render_config(self, - request: ClassificationRequest) -> RenderConfig: + def _build_render_config(self, request: ClassificationRequest) -> RenderConfig: return RenderConfig( max_length=self.max_model_len, - truncate_prompt_tokens=request.truncate_prompt_tokens) + truncate_prompt_tokens=request.truncate_prompt_tokens, + ) class ServingClassification(ClassificationMixin): @@ -144,8 +148,7 @@ async def create_classify( raw_request: Request, ) -> 
Union[ClassificationResponse, ErrorResponse]: model_name = self.models.model_name() - request_id = (f"{self.request_id_prefix}-" - f"{self._base_request_id(raw_request)}") + request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" ctx = ClassificationServeContext( request=request, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index d0756e42b796..ce0a6c0e23e5 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -13,21 +13,19 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (CompletionLogProbs, - CompletionRequest, - CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, - ErrorResponse, - PromptTokenUsageInfo, - RequestResponseMetadata, - UsageInfo) -from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - clamp_prompt_logprobs) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + CompletionLogProbs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorResponse, + PromptTokenUsageInfo, + RequestResponseMetadata, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.utils import get_max_tokens @@ -43,7 +41,6 @@ class OpenAIServingCompletion(OpenAIServing): - def __init__( self, engine_client: EngineClient, @@ -66,8 +63,7 @@ def __init__( log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) + self.default_sampling_params = self.model_config.get_diff_sampling_param() if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source @@ -103,21 +99,17 @@ async def create_completion( # Return error for unsupported features. if request.suffix is not None: - return self.create_error_response( - "suffix is not currently supported") + return self.create_error_response("suffix is not currently supported") if request.echo and request.prompt_embeds is not None: - return self.create_error_response( - "Echo is unsupported with prompt embeds.") + return self.create_error_response("Echo is unsupported with prompt embeds.") - if (request.prompt_logprobs is not None - and request.prompt_embeds is not None): + if request.prompt_logprobs is not None and request.prompt_embeds is not None: return self.create_error_response( - "prompt_logprobs is not compatible with prompt embeds.") + "prompt_logprobs is not compatible with prompt embeds." 
+ ) - request_id = ( - f"cmpl-" - f"{self._base_request_id(raw_request, request.request_id)}") + request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}" created_time = int(time.time()) request_metadata = RequestResponseMetadata(request_id=request_id) @@ -156,7 +148,8 @@ async def create_completion( try: for i, engine_prompt in enumerate(engine_prompts): prompt_text, prompt_token_ids, prompt_embeds = ( - self._get_prompt_components(engine_prompt)) + self._get_prompt_components(engine_prompt) + ) input_length = None if prompt_token_ids is not None: @@ -179,7 +172,8 @@ async def create_completion( sampling_params: Union[SamplingParams, BeamSearchParams] if request.use_beam_search: sampling_params = request.to_beam_search_params( - max_tokens, self.default_sampling_params) + max_tokens, self.default_sampling_params + ) else: sampling_params = request.to_sampling_params( max_tokens, @@ -196,14 +190,16 @@ async def create_completion( lora_request=lora_request, ) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) # Mypy inconsistently requires this second cast in different # environments. It shouldn't be necessary (redundant from above) # but pre-commit in CI fails without it. - engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], - engine_prompt) + engine_prompt = cast(Union[EmbedsPrompt, TokensPrompt], engine_prompt) if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.beam_search( prompt=engine_prompt, @@ -212,15 +208,14 @@ async def create_completion( lora_request=lora_request, ) else: - engine_request, tokenization_kwargs = ( - await self._process_inputs( - request_id_item, - engine_prompt, - sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - priority=request.priority, - )) + engine_request, tokenization_kwargs = await self._process_inputs( + request_id_item, + engine_prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) generator = self.engine_client.generate( engine_request, @@ -246,9 +241,11 @@ async def create_completion( # Similar to the OpenAI API, when n != best_of, we do not stream the # results. Noting that best_of is only supported in V0. In addition, # we do not stream the results when use beam search. 
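
Each prompt above gets its max_tokens bounded by how much of the model's context window remains after the input. The sketch below mirrors the intent of get_max_tokens from vllm.entrypoints.utils, not its exact implementation; the function name and argument order here are illustrative.

from typing import Optional


def max_tokens_budget(max_model_len: int, input_length: int,
                      requested_max_tokens: Optional[int],
                      default_max_tokens: Optional[int]) -> int:
    # Tokens left in the context window after the prompt.
    remaining = max_model_len - input_length
    limits = [remaining]
    if requested_max_tokens is not None:
        limits.append(requested_max_tokens)
    elif default_max_tokens is not None:
        limits.append(default_max_tokens)
    return max(0, min(limits))


# With a 4096-token context and a 4000-token prompt, requesting 512
# completion tokens yields a budget of 96.
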
- stream = (request.stream - and (request.best_of is None or request.n == request.best_of) - and not request.use_beam_search) + stream = ( + request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search + ) # Streaming response if stream: @@ -279,11 +276,13 @@ async def create_completion( # with the inputs token IDs if final_res.prompt is None: engine_prompt = engine_prompts[i] - final_res.prompt = None if is_embeds_prompt( - engine_prompt) else engine_prompt.get("prompt") + final_res.prompt = ( + None + if is_embeds_prompt(engine_prompt) + else engine_prompt.get("prompt") + ) - final_res_batch_checked = cast(list[RequestOutput], - final_res_batch) + final_res_batch_checked = cast(list[RequestOutput], final_res_batch) response = self.request_output_to_completion_response( final_res_batch_checked, @@ -336,10 +335,10 @@ async def completion_stream_generator( stream_options = request.stream_options if stream_options: - include_usage = (stream_options.include_usage - or enable_force_include_usage) - include_continuous_usage = (include_usage and - stream_options.continuous_usage_stats) + include_usage = stream_options.include_usage or enable_force_include_usage + include_continuous_usage = ( + include_usage and stream_options.continuous_usage_stats + ) else: include_usage, include_continuous_usage = False, False @@ -355,16 +354,18 @@ async def completion_stream_generator( prompt_text = res.prompt if prompt_text is None: engine_prompt = engine_prompts[prompt_idx] - prompt_text = None if is_embeds_prompt( - engine_prompt) else engine_prompt.get("prompt") + prompt_text = ( + None + if is_embeds_prompt(engine_prompt) + else engine_prompt.get("prompt") + ) # Prompt details are excluded from later streamed outputs if prompt_token_ids is not None: num_prompt_tokens[prompt_idx] = len(prompt_token_ids) delta_token_ids: GenericSequence[int] - out_logprobs: Optional[GenericSequence[Optional[dict[ - int, Logprob]]]] + out_logprobs: Optional[GenericSequence[Optional[dict[int, Logprob]]]] for output in res.outputs: i = output.index + prompt_idx * num_choices @@ -410,22 +411,23 @@ async def completion_stream_generator( prompt_token_ids_to_return = prompt_token_ids has_echoed[i] = True - if (not delta_text and not delta_token_ids - and not previous_num_tokens[i]): + if ( + not delta_text + and not delta_token_ids + and not previous_num_tokens[i] + ): # Chunked prefill case, don't return empty chunks continue if request.logprobs is not None: - assert out_logprobs is not None, ( - "Did not output logprobs") + assert out_logprobs is not None, "Did not output logprobs" logprobs = self._create_completion_logprobs( token_ids=delta_token_ids, top_logprobs=out_logprobs, num_output_top_logprobs=request.logprobs, tokenizer=tokenizer, initial_text_offset=previous_text_lens[i], - return_as_token_id=request. 
- return_tokens_as_token_ids, + return_as_token_id=request.return_tokens_as_token_ids, ) else: logprobs = None @@ -447,8 +449,11 @@ async def completion_stream_generator( finish_reason=finish_reason, stop_reason=stop_reason, prompt_token_ids=prompt_token_ids_to_return, - token_ids=(as_list(output.token_ids) if - request.return_token_ids else None), + token_ids=( + as_list(output.token_ids) + if request.return_token_ids + else None + ), ) ], ) @@ -474,7 +479,8 @@ async def completion_stream_generator( if self.enable_prompt_tokens_details and num_cached_tokens: final_usage_info.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=num_cached_tokens) + cached_tokens=num_cached_tokens + ) if include_usage: final_usage_chunk = CompletionStreamResponse( @@ -485,7 +491,8 @@ async def completion_stream_generator( usage=final_usage_info, ) final_usage_data = final_usage_chunk.model_dump_json( - exclude_unset=False, exclude_none=True) + exclude_unset=False, exclude_none=True + ) yield f"data: {final_usage_data}\n\n" # report to FastAPI middleware aggregate usage across all choices @@ -520,8 +527,7 @@ def request_output_to_completion_response( prompt_text = final_res.prompt token_ids: GenericSequence[int] - out_logprobs: Optional[GenericSequence[Optional[dict[int, - Logprob]]]] + out_logprobs: Optional[GenericSequence[Optional[dict[int, Logprob]]]] for output in final_res.outputs: assert request.max_tokens is not None @@ -571,10 +577,12 @@ def request_output_to_completion_response( finish_reason=output.finish_reason, stop_reason=output.stop_reason, prompt_logprobs=final_res.prompt_logprobs, - prompt_token_ids=(prompt_token_ids - if request.return_token_ids else None), - token_ids=(as_list(output.token_ids) - if request.return_token_ids else None), + prompt_token_ids=( + prompt_token_ids if request.return_token_ids else None + ), + token_ids=( + as_list(output.token_ids) if request.return_token_ids else None + ), ) choices.append(choice_data) @@ -588,10 +596,14 @@ def request_output_to_completion_response( total_tokens=num_prompt_tokens + num_generated_tokens, ) - if (self.enable_prompt_tokens_details and last_final_res - and last_final_res.num_cached_tokens): + if ( + self.enable_prompt_tokens_details + and last_final_res + and last_final_res.num_cached_tokens + ): usage.prompt_tokens_details = PromptTokenUsageInfo( - cached_tokens=last_final_res.num_cached_tokens) + cached_tokens=last_final_res.num_cached_tokens + ) request_metadata.final_usage_info = usage if final_res_batch: @@ -622,9 +634,11 @@ def _create_completion_logprobs( last_token_len = 0 - should_return_as_token_id = (return_as_token_id - if return_as_token_id is not None else - self.return_tokens_as_token_ids) + should_return_as_token_id = ( + return_as_token_id + if return_as_token_id is not None + else self.return_tokens_as_token_ids + ) for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] if step_top_logprobs is None: @@ -653,19 +667,20 @@ def _create_completion_logprobs( # logprobs, as defined in the openai API # (cf. 
https://github.com/openai/openai-openapi/blob/ # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153) - out_top_logprobs.append({ - # Convert float("-inf") to the - # JSON-serializable float that OpenAI uses - self._get_decoded_token( - top_lp[1], - top_lp[0], - tokenizer, - return_as_token_id=should_return_as_token_id, - ): - max(top_lp[1].logprob, -9999.0) - for i, top_lp in enumerate(step_top_logprobs.items()) - if num_output_top_logprobs >= i - }) + out_top_logprobs.append( + { + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + self._get_decoded_token( + top_lp[1], + top_lp[0], + tokenizer, + return_as_token_id=should_return_as_token_id, + ): max(top_lp[1].logprob, -9999.0) + for i, top_lp in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs >= i + } + ) if len(out_text_offset) == 0: out_text_offset.append(initial_text_offset) @@ -691,6 +706,5 @@ def _build_render_config( truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, cache_salt=request.cache_salt, - needs_detokenization=bool(request.echo - and not request.return_token_ids), + needs_detokenization=bool(request.echo and not request.return_token_ids), ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 647e7daed659..5517ab2802e3 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -14,25 +14,32 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this docstring -# yapf: disable -from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, - EmbeddingCompletionRequest, - EmbeddingRequest, - EmbeddingResponse, - EmbeddingResponseData, - ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext, - OpenAIServing, - ServeContext, - TextTokensPrompt) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + ErrorResponse, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import ( + EmbeddingServeContext, + OpenAIServing, + ServeContext, + TextTokensPrompt, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger -from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, - PoolingOutput, PoolingRequestOutput, RequestOutput) +from vllm.outputs import ( + EmbeddingOutput, + EmbeddingRequestOutput, + PoolingOutput, + PoolingRequestOutput, + RequestOutput, +) from vllm.pooling_params import PoolingParams from vllm.utils import chunk_list @@ -55,7 +62,6 @@ def _get_embedding( class EmbeddingMixin(OpenAIServing): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -63,9 +69,13 @@ def __init__(self, *args, **kwargs): # Avoid repeated attribute lookups self.supports_chunked_processing = bool( - pooler_config and pooler_config.enable_chunked_processing) - self.max_embed_len = (pooler_config.max_embed_len if pooler_config - and pooler_config.max_embed_len else None) + pooler_config and pooler_config.enable_chunked_processing + ) + self.max_embed_len = ( + pooler_config.max_embed_len + if 
pooler_config and pooler_config.max_embed_len + else None + ) @override async def _preprocess( @@ -88,10 +98,8 @@ async def _preprocess( ctx.request, tokenizer, ctx.request.messages, - chat_template=ctx.request.chat_template - or ctx.chat_template, - chat_template_content_format=ctx. - chat_template_content_format, + chat_template=ctx.request.chat_template or ctx.chat_template, + chat_template_content_format=ctx.chat_template_content_format, add_generation_prompt=ctx.request.add_generation_prompt, continue_final_message=False, add_special_tokens=ctx.request.add_special_tokens, @@ -106,8 +114,7 @@ async def _preprocess( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - def _build_render_config( - self, request: EmbeddingCompletionRequest) -> RenderConfig: + def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig: # Set max_length based on chunked processing capability if self._should_use_chunked_processing(request): max_length = None @@ -117,7 +124,8 @@ def _build_render_config( return RenderConfig( max_length=max_length, truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens) + add_special_tokens=request.add_special_tokens, + ) @override def _build_response( @@ -127,16 +135,16 @@ def _build_response( items: list[EmbeddingResponseData] = [] num_prompt_tokens = 0 - final_res_batch_checked = cast(list[PoolingRequestOutput], - ctx.final_res_batch) + final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch) for idx, final_res in enumerate(final_res_batch_checked): embedding_res = EmbeddingRequestOutput.from_base(final_res) item = EmbeddingResponseData( index=idx, - embedding=_get_embedding(embedding_res.outputs, - ctx.request.encoding_format), + embedding=_get_embedding( + embedding_res.outputs, ctx.request.encoding_format + ), ) prompt_token_ids = final_res.prompt_token_ids @@ -162,10 +170,10 @@ def _get_max_position_embeddings(self) -> int: def _should_use_chunked_processing(self, request) -> bool: """Check if chunked processing should be used for this request.""" - return isinstance( - request, - (EmbeddingCompletionRequest, - EmbeddingChatRequest)) and self.supports_chunked_processing + return ( + isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)) + and self.supports_chunked_processing + ) async def _process_chunked_request( self, @@ -183,25 +191,27 @@ async def _process_chunked_request( max_pos_embeddings = self._get_max_position_embeddings() # Process all chunks for MEAN aggregation for chunk_idx, chunk_tokens in enumerate( - chunk_list(token_ids, max_pos_embeddings)): + chunk_list(token_ids, max_pos_embeddings) + ): # Create a request ID for this chunk - chunk_request_id = (f"{ctx.request_id}-prompt-{prompt_idx}-" - f"chunk-{chunk_idx}") + chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}" # Create engine prompt for this chunk - chunk_engine_prompt = EngineTokensPrompt( - prompt_token_ids=chunk_tokens) + chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens) # Create chunk request prompt for logging chunk_text = "" chunk_request_prompt = TextTokensPrompt( - prompt=chunk_text, prompt_token_ids=chunk_tokens) + prompt=chunk_text, prompt_token_ids=chunk_tokens + ) # Log the chunk - self._log_inputs(chunk_request_id, - chunk_request_prompt, - params=pooling_params, - lora_request=ctx.lora_request) + self._log_inputs( + chunk_request_id, + chunk_request_prompt, + 
params=pooling_params, + lora_request=ctx.lora_request, + ) # Create generator for this chunk and wrap it to return indices original_generator = self.engine_client.encode( @@ -227,8 +237,7 @@ def _validate_input( token_num = len(input_ids) # Note: EmbeddingRequest doesn't have max_tokens - if isinstance(request, - (EmbeddingCompletionRequest, EmbeddingChatRequest)): + if isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)): # Check if chunked processing is enabled for pooling models enable_chunked = self._should_use_chunked_processing(request) @@ -248,13 +257,15 @@ def _validate_input( validation_error_msg = ( "This model's {length_type} is {max_length_value} tokens. " "However, you requested {token_num} tokens in the input for " - "embedding generation. Please reduce the length of the input.") + "embedding generation. Please reduce the length of the input." + ) chunked_processing_error_msg = ( "This model's {length_type} is {max_length_value} tokens. " "However, you requested {token_num} tokens in the input for " "embedding generation. Please reduce the length of the input " - "or enable chunked processing.") + "or enable chunked processing." + ) # Check if input exceeds max length if token_num > max_length_value: @@ -262,7 +273,9 @@ def _validate_input( validation_error_msg.format( length_type=length_type, max_length_value=max_length_value, - token_num=token_num)) + token_num=token_num, + ) + ) # Check for chunked processing # when exceeding max_position_embeddings @@ -271,25 +284,31 @@ def _validate_input( # Allow long inputs when chunked processing is enabled logger.info( "Input length %s exceeds max_position_embeddings " - "%s, will use chunked processing", token_num, - max_pos_embeddings) + "%s, will use chunked processing", + token_num, + max_pos_embeddings, + ) else: raise ValueError( chunked_processing_error_msg.format( length_type="maximum position embeddings length", max_length_value=max_pos_embeddings, - token_num=token_num)) + token_num=token_num, + ) + ) - return TextTokensPrompt(prompt=input_text, - prompt_token_ids=input_ids) + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # For other request types, use the parent's implementation return super()._validate_input(request, input_ids, input_text) def _is_text_tokens_prompt(self, prompt) -> bool: """Check if a prompt is a TextTokensPrompt (has prompt_token_ids).""" - return (isinstance(prompt, dict) and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt + ) async def _create_single_prompt_generator( self, @@ -302,10 +321,12 @@ async def _create_single_prompt_generator( """Create a generator for a single prompt using standard processing.""" request_id_item = f"{ctx.request_id}-{prompt_index}" - self._log_inputs(request_id_item, - engine_prompt, - params=pooling_params, - lora_request=ctx.lora_request) + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) # Return the original generator without wrapping return self.engine_client.encode( @@ -333,13 +354,16 @@ async def _prepare_generators( return await super()._prepare_generators(ctx) # Custom logic for chunked processing - generators: list[AsyncGenerator[Union[RequestOutput, - PoolingRequestOutput], - None]] = [] + generators: list[ + AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None] + ] = [] try: - trace_headers = (None if 
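
Chunked embedding processing above splits a long token sequence into windows of at most max_position_embeddings tokens and derives a per-chunk request id of the form "<request_id>-prompt-<i>-chunk-<j>". The sketch below shows that splitting; chunk_list here is a stand-in for the helper imported from vllm.utils.

from typing import Iterator


def chunk_list(lst: list[int], chunk_size: int) -> Iterator[list[int]]:
    # Yield successive windows of at most chunk_size items.
    for start in range(0, len(lst), chunk_size):
        yield lst[start:start + chunk_size]


def chunk_requests(request_id: str, prompt_idx: int, token_ids: list[int],
                   max_pos_embeddings: int) -> list[tuple[str, list[int]]]:
    return [
        (f"{request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}", chunk)
        for chunk_idx, chunk in enumerate(chunk_list(token_ids, max_pos_embeddings))
    ]


# chunk_requests("embd-1", 0, list(range(5)), 2) ->
# [("embd-1-prompt-0-chunk-0", [0, 1]),
#  ("embd-1-prompt-0-chunk-1", [2, 3]),
#  ("embd-1-prompt-0-chunk-2", [4])]
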
ctx.raw_request is None else await - self._get_trace_headers(ctx.raw_request.headers)) + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) pooling_params = self._create_pooling_params(ctx) if isinstance(pooling_params, ErrorResponse): @@ -352,8 +376,7 @@ async def _prepare_generators( return self.create_error_response(str(e)) if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") + return self.create_error_response("Engine prompts not available") max_pos_embeddings = self._get_max_position_embeddings() @@ -363,21 +386,22 @@ async def _prepare_generators( # Cast to TextTokensPrompt since we've verified # prompt_token_ids text_tokens_prompt = cast(TextTokensPrompt, engine_prompt) - if (len(text_tokens_prompt["prompt_token_ids"]) - > max_pos_embeddings): + if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings: # Use chunked processing for this prompt chunk_generators = await self._process_chunked_request( - ctx, text_tokens_prompt, pooling_params, - trace_headers, i) + ctx, text_tokens_prompt, pooling_params, trace_headers, i + ) generators.extend(chunk_generators) continue # Normal processing for short prompts or non-token prompts generator = await self._create_single_prompt_generator( - ctx, engine_prompt, pooling_params, trace_headers, i) + ctx, engine_prompt, pooling_params, trace_headers, i + ) generators.append(generator) from vllm.utils import merge_async_iterators + ctx.result_generator = merge_async_iterators(*generators) return None @@ -401,8 +425,7 @@ async def _collect_batch( ctx = cast(EmbeddingServeContext, ctx) try: if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") + return self.create_error_response("Engine prompts not available") # Check if we used chunked processing use_chunked = self._should_use_chunked_processing(ctx.request) @@ -411,8 +434,7 @@ async def _collect_batch( return await super()._collect_batch(ctx=ctx) if ctx.result_generator is None: - return self.create_error_response( - "Result generator not available") + return self.create_error_response("Result generator not available") # Online aggregation for chunked requests to # minimize memory usage @@ -433,10 +455,10 @@ async def _collect_batch( # Initialize aggregator for this prompt if needed if prompt_idx not in prompt_aggregators: prompt_aggregators[prompt_idx] = { - 'weighted_sum': None, - 'total_weight': 0, - 'chunk_count': 0, - 'request_id': result.request_id.split("-chunk-")[0] + "weighted_sum": None, + "total_weight": 0, + "chunk_count": 0, + "request_id": result.request_id.split("-chunk-")[0], } aggregator = prompt_aggregators[prompt_idx] @@ -448,44 +470,45 @@ async def _collect_batch( return self.create_error_response( f"Expected PoolingRequestOutput for " f"chunked embedding, got " - f"{type(result).__name__}") + f"{type(result).__name__}" + ) # Handle both PoolingOutput and # EmbeddingOutput types - if hasattr(result.outputs, 'data'): + if hasattr(result.outputs, "data"): # PoolingOutput case embedding_data = result.outputs.data - elif hasattr(result.outputs, 'embedding'): + elif hasattr(result.outputs, "embedding"): # EmbeddingOutput case - # convert embedding list to tensor embedding_data = result.outputs.embedding else: return self.create_error_response( - f"Unsupported output type: " - f"{type(result.outputs).__name__}") + f"Unsupported output type: {type(result.outputs).__name__}" + ) if not isinstance(embedding_data, 
torch.Tensor): - embedding_data = torch.tensor(embedding_data, - dtype=torch.float32) + embedding_data = torch.tensor( + embedding_data, dtype=torch.float32 + ) if result.prompt_token_ids is None: return self.create_error_response( - "prompt_token_ids cannot be None for " - "chunked processing") + "prompt_token_ids cannot be None for chunked processing" + ) weight = len(result.prompt_token_ids) - weighted_embedding = embedding_data.to( - dtype=torch.float32) * weight + weighted_embedding = embedding_data.to(dtype=torch.float32) * weight - if aggregator['weighted_sum'] is None: + if aggregator["weighted_sum"] is None: # First chunk - aggregator['weighted_sum'] = weighted_embedding + aggregator["weighted_sum"] = weighted_embedding else: # Accumulate - aggregator['weighted_sum'] += weighted_embedding + aggregator["weighted_sum"] += weighted_embedding - aggregator['total_weight'] += weight - aggregator['chunk_count'] += 1 + aggregator["total_weight"] += weight + aggregator["chunk_count"] += 1 else: # Non-chunked result - extract prompt_idx from request_id parts = result.request_id.split("-") @@ -496,11 +519,13 @@ async def _collect_batch( prompt_idx = result_idx # Fallback to result_idx short_prompts_results[prompt_idx] = cast( - PoolingRequestOutput, result) + PoolingRequestOutput, result + ) # Finalize aggregated results - final_res_batch: list[Union[PoolingRequestOutput, - EmbeddingRequestOutput]] = [] + final_res_batch: list[ + Union[PoolingRequestOutput, EmbeddingRequestOutput] + ] = [] num_prompts = len(ctx.engine_prompts) for prompt_idx in range(num_prompts): @@ -508,55 +533,57 @@ async def _collect_batch( # Finalize MEAN aggregation for this chunked prompt aggregator = prompt_aggregators[prompt_idx] - weighted_sum = aggregator['weighted_sum'] - total_weight = aggregator['total_weight'] - - if (weighted_sum is not None - and isinstance(weighted_sum, torch.Tensor) - and isinstance(total_weight, - (int, float)) and total_weight > 0): + weighted_sum = aggregator["weighted_sum"] + total_weight = aggregator["total_weight"] + if ( + weighted_sum is not None + and isinstance(weighted_sum, torch.Tensor) + and isinstance(total_weight, (int, float)) + and total_weight > 0 + ): # Compute final mean embedding final_embedding = weighted_sum / total_weight # Create a PoolingRequestOutput # for the aggregated result - pooling_output_data = PoolingOutput( - data=final_embedding) + pooling_output_data = PoolingOutput(data=final_embedding) # Get original prompt token IDs for this prompt original_prompt = ctx.engine_prompts[prompt_idx] if not self._is_text_tokens_prompt(original_prompt): return self.create_error_response( - f"Chunked prompt {prompt_idx} is not a " - f"TextTokensPrompt") + f"Chunked prompt {prompt_idx} is not a TextTokensPrompt" + ) - original_token_ids = cast( - TextTokensPrompt, - original_prompt)["prompt_token_ids"] + original_token_ids = cast(TextTokensPrompt, original_prompt)[ + "prompt_token_ids" + ] pooling_request_output = PoolingRequestOutput( - request_id=aggregator['request_id'], + request_id=aggregator["request_id"], prompt_token_ids=original_token_ids, outputs=pooling_output_data, - finished=True) + finished=True, + ) final_res_batch.append(pooling_request_output) else: return self.create_error_response( - f"Failed to aggregate chunks " - f"for prompt {prompt_idx}") + f"Failed to aggregate chunks for prompt {prompt_idx}" + ) elif prompt_idx in short_prompts_results: final_res_batch.append( - cast(PoolingRequestOutput, - short_prompts_results[prompt_idx])) + 
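
The aggregation above computes a length-weighted mean: each chunk's embedding is scaled by its token count, accumulated, and divided by the total weight, so the result approximates the MEAN pooling that would have been applied over the full, unchunked input. A compact version of that accumulation:

from typing import Optional

import torch


def aggregate_chunk_embeddings(
        chunks: list[tuple[torch.Tensor, int]]) -> torch.Tensor:
    weighted_sum: Optional[torch.Tensor] = None
    total_weight = 0
    for embedding, num_tokens in chunks:
        # Weight each chunk by how many tokens it covered.
        weighted = embedding.to(dtype=torch.float32) * num_tokens
        weighted_sum = weighted if weighted_sum is None else weighted_sum + weighted
        total_weight += num_tokens
    if weighted_sum is None or total_weight == 0:
        raise ValueError("No chunks to aggregate")
    return weighted_sum / total_weight


# Two chunks covering 100 and 20 tokens contribute in a 5:1 ratio to the
# final embedding.
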
cast(PoolingRequestOutput, short_prompts_results[prompt_idx]) + ) else: return self.create_error_response( - f"Result not found for prompt {prompt_idx}") + f"Result not found for prompt {prompt_idx}" + ) ctx.final_res_batch = cast( - list[Union[RequestOutput, PoolingRequestOutput]], - final_res_batch) + list[Union[RequestOutput, PoolingRequestOutput]], final_res_batch + ) return None @@ -576,16 +603,20 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + trust_request_chat_template: bool = False, log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - log_error_stack=log_error_stack) + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + log_error_stack=log_error_stack, + ) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.trust_request_chat_template = trust_request_chat_template async def create_embedding( self, @@ -601,7 +632,8 @@ async def create_embedding( model_name = self.models.model_name() request_id = ( f"{self.request_id_prefix}-" - f"{self._base_request_id(raw_request, request.request_id)}") + f"{self._base_request_id(raw_request, request.request_id)}" + ) ctx = EmbeddingServeContext( request=request, @@ -629,3 +661,17 @@ def _create_pooling_params( return self.create_error_response(str(e)) return pooling_params + + async def _preprocess( + self, + ctx: ServeContext, + ) -> Optional[ErrorResponse]: + if isinstance(ctx.request, EmbeddingChatRequest): + error_check_ret = self._validate_chat_template( + request_chat_template=ctx.request.chat_template, + chat_template_kwargs=ctx.request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret + return await super()._preprocess(ctx) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index e58d943d3f7f..6ddde23b4a34 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -16,7 +16,6 @@ from typing_extensions import TypeIs from vllm.entrypoints.utils import _validate_truncation_size -from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.processor import Processor @@ -28,44 +27,47 @@ import vllm.envs as envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - ChatTemplateContentFormatOption, - ConversationMessage, - apply_hf_chat_template, - apply_mistral_chat_template, - parse_chat_messages_futures, - resolve_chat_template_content_format) +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, + ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages_futures, + resolve_chat_template_content_format, +) from vllm.entrypoints.context import ConversationContext from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionResponse, - ClassificationRequest, - ClassificationResponse, - 
CompletionRequest, - CompletionResponse, - DetokenizeRequest, - EmbeddingChatRequest, - EmbeddingCompletionRequest, - EmbeddingRequest, - EmbeddingResponse, ErrorInfo, - ErrorResponse, - IOProcessorRequest, - PoolingResponse, RerankRequest, - ResponsesRequest, ScoreRequest, - ScoreResponse, - TokenizeChatRequest, - TokenizeCompletionRequest, - TokenizeResponse, - TranscriptionRequest, - TranscriptionResponse, - TranslationRequest) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ClassificationRequest, + ClassificationResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingRequest, + EmbeddingResponse, + ErrorInfo, + ErrorResponse, + IOProcessorRequest, + PoolingResponse, + RerankRequest, + ResponsesRequest, + ScoreRequest, + ScoreResponse, + TokenizeChatRequest, + TokenizeCompletionRequest, + TokenizeResponse, + TranscriptionRequest, + TranscriptionResponse, + TranslationRequest, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser -from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer, - RenderConfig) -# yapf: enable +from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig from vllm.inputs.data import PromptType from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.parse import PromptComponents, get_prompt_components @@ -73,15 +75,25 @@ from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin - MultiModalDataDict, MultiModalUUIDDict) + MultiModalDataDict, + MultiModalUUIDDict, +) from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.tracing import (contains_trace_headers, extract_trace_headers, - log_tracing_disabled_warning) +from vllm.tracing import ( + contains_trace_headers, + extract_trace_headers, + log_tracing_disabled_warning, +) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, make_async, - merge_async_iterators, random_uuid) +from vllm.utils import ( + AsyncMicrobatchTokenizer, + is_list_of, + make_async, + merge_async_iterators, + random_uuid, +) logger = init_logger(__name__) @@ -95,8 +107,9 @@ TokenizeCompletionRequest, ] -ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, - TokenizeChatRequest] +ChatLikeRequest = Union[ + ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest +] SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] AnyRequest = Union[ CompletionLikeRequest, @@ -131,13 +144,19 @@ class EmbedsPrompt(TypedDict): def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]: - return (isinstance(prompt, dict) and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt + ) def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]: - return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt - and "prompt_embeds" in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" not in prompt + and "prompt_embeds" in 
prompt + ) RequestT = TypeVar("RequestT", bound=AnyRequest) @@ -161,19 +180,21 @@ class ResponseGenerationMixin(BaseModel): managing result generators and final batch results. """ - result_generator: Optional[AsyncGenerator[tuple[int, Union[ - RequestOutput, PoolingRequestOutput]], None]] = None + result_generator: Optional[ + AsyncGenerator[tuple[int, Union[RequestOutput, PoolingRequestOutput]], None] + ] = None final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field( - default_factory=list) + default_factory=list + ) model_config = ConfigDict(arbitrary_types_allowed=True) class ServeContext( - RequestProcessingMixin, - ResponseGenerationMixin, - BaseModel, - Generic[RequestT], + RequestProcessingMixin, + ResponseGenerationMixin, + BaseModel, + Generic[RequestT], ): # Shared across all requests request: RequestT @@ -241,20 +262,17 @@ def __init__( self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) self._apply_mistral_chat_template_async = make_async( - apply_mistral_chat_template, executor=self._tokenizer_executor) + apply_mistral_chat_template, executor=self._tokenizer_executor + ) - self._async_tokenizer_pool: dict[AnyTokenizer, - AsyncMicrobatchTokenizer] = {} + self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} self.log_error_stack = log_error_stack async def _get_processor(self) -> Processor: if not hasattr(self, "_processor"): vllm_config = await self.engine_client.get_vllm_config() - if self.model_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = init_tokenizer_from_configs(self.model_config) - self._processor = Processor(vllm_config, tokenizer) + self._processor = Processor(vllm_config) + return self._processor def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer: @@ -265,7 +283,8 @@ def _get_renderer(self, tokenizer: Optional[AnyTokenizer]) -> BaseRenderer: return CompletionRenderer( model_config=self.model_config, tokenizer=tokenizer, - async_tokenizer_pool=self._async_tokenizer_pool) + async_tokenizer_pool=self._async_tokenizer_pool, + ) def _build_render_config( self, @@ -348,15 +367,17 @@ async def _pipeline( yield self._build_response(ctx) def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]: - truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", - None) + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) - if (truncate_prompt_tokens is not None - and truncate_prompt_tokens > self.max_model_len): + if ( + truncate_prompt_tokens is not None + and truncate_prompt_tokens > self.max_model_len + ): return self.create_error_response( "truncate_prompt_tokens value is " "greater than max_model_len." - " Please, select a smaller truncation size.") + " Please, select a smaller truncation size." 
+ ) return None def _create_pooling_params( @@ -365,7 +386,8 @@ def _create_pooling_params( ) -> Union[PoolingParams, ErrorResponse]: if not hasattr(ctx.request, "to_pooling_params"): return self.create_error_response( - "Request type does not support pooling parameters") + "Request type does not support pooling parameters" + ) return ctx.request.to_pooling_params() @@ -374,21 +396,23 @@ async def _prepare_generators( ctx: ServeContext, ) -> Optional[ErrorResponse]: """Schedule the request and get the result generator.""" - generators: list[AsyncGenerator[Union[RequestOutput, - PoolingRequestOutput], - None]] = [] + generators: list[ + AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None] + ] = [] try: - trace_headers = (None if ctx.raw_request is None else await - self._get_trace_headers(ctx.raw_request.headers)) + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) pooling_params = self._create_pooling_params(ctx) if isinstance(pooling_params, ErrorResponse): return pooling_params if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") + return self.create_error_response("Engine prompts not available") for i, engine_prompt in enumerate(ctx.engine_prompts): request_id_item = f"{ctx.request_id}-{i}" @@ -426,28 +450,24 @@ async def _collect_batch( """Collect batch results from the result generator.""" try: if ctx.engine_prompts is None: - return self.create_error_response( - "Engine prompts not available") + return self.create_error_response("Engine prompts not available") num_prompts = len(ctx.engine_prompts) - final_res_batch: list[Optional[Union[RequestOutput, - PoolingRequestOutput]]] + final_res_batch: list[Optional[Union[RequestOutput, PoolingRequestOutput]]] final_res_batch = [None] * num_prompts if ctx.result_generator is None: - return self.create_error_response( - "Result generator not available") + return self.create_error_response("Result generator not available") async for i, res in ctx.result_generator: final_res_batch[i] = res if None in final_res_batch: return self.create_error_response( - "Failed to generate results for all prompts") + "Failed to generate results for all prompts" + ) - ctx.final_res_batch = [ - res for res in final_res_batch if res is not None - ] + ctx.final_res_batch = [res for res in final_res_batch if res is not None] return None @@ -466,8 +486,9 @@ def create_error_response( traceback.print_exc() else: traceback.print_stack() - return ErrorResponse(error=ErrorInfo( - message=message, type=err_type, code=status_code.value)) + return ErrorResponse( + error=ErrorInfo(message=message, type=err_type, code=status_code.value) + ) def create_streaming_error_response( self, @@ -476,9 +497,10 @@ def create_streaming_error_response( status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, ) -> str: json_str = json.dumps( - self.create_error_response(message=message, - err_type=err_type, - status_code=status_code).model_dump()) + self.create_error_response( + message=message, err_type=err_type, status_code=status_code + ).model_dump() + ) return json_str async def _check_model( @@ -491,12 +513,17 @@ async def _check_model( return None if request.model in self.models.lora_requests: return None - if (envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and - (load_result := await self.models.resolve_lora(request.model))): + if ( + envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING + and request.model + and (load_result := await 
self.models.resolve_lora(request.model)) + ): if isinstance(load_result, LoRARequest): return None - if (isinstance(load_result, ErrorResponse) and - load_result.error.code == HTTPStatus.BAD_REQUEST.value): + if ( + isinstance(load_result, ErrorResponse) + and load_result.error.code == HTTPStatus.BAD_REQUEST.value + ): error_response = load_result return error_response or self.create_error_response( @@ -506,7 +533,8 @@ async def _check_model( ) def _get_active_default_mm_loras( - self, request: AnyRequest) -> Optional[LoRARequest]: + self, request: AnyRequest + ) -> Optional[LoRARequest]: """Determine if there are any active default multimodal loras.""" # TODO: Currently this is only enabled for chat completions # to be better aligned with only being enabled for .generate @@ -561,8 +589,11 @@ def _get_message_types(self, request: AnyRequest) -> set[str]: return message_types for message in request.messages: - if (isinstance(message, dict) and "content" in message - and isinstance(message["content"], list)): + if ( + isinstance(message, dict) + and "content" in message + and isinstance(message["content"], list) + ): for content_dict in message["content"]: if "type" in content_dict: message_types.add(content_dict["type"].split("_")[0]) @@ -577,17 +608,18 @@ async def _normalize_prompt_text_to_input( ) -> TextTokensPrompt: async_tokenizer = self._get_async_tokenizer(tokenizer) - if (self.model_config.encoder_config is not None - and self.model_config.encoder_config.get( - "do_lower_case", False)): + if ( + self.model_config.encoder_config is not None + and self.model_config.encoder_config.get("do_lower_case", False) + ): prompt = prompt.lower() - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", - None) + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) if truncate_prompt_tokens is None: encoded = await async_tokenizer( - prompt, add_special_tokens=add_special_tokens) + prompt, add_special_tokens=add_special_tokens + ) elif truncate_prompt_tokens < 0: # Negative means we cap at the model's max length encoded = await async_tokenizer( @@ -615,13 +647,12 @@ async def _normalize_prompt_tokens_to_input( prompt_ids: list[int], tokenizer: Optional[AnyTokenizer], ) -> TextTokensPrompt: - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", - None) + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) if truncate_prompt_tokens is None: input_ids = prompt_ids elif truncate_prompt_tokens < 0: - input_ids = prompt_ids[-self.max_model_len:] + input_ids = prompt_ids[-self.max_model_len :] else: input_ids = prompt_ids[-truncate_prompt_tokens:] @@ -644,7 +675,7 @@ def _validate_input( # Note: EmbeddingRequest, ClassificationRequest, # and ScoreRequest doesn't have max_tokens if isinstance( - request, + request, ( EmbeddingChatRequest, EmbeddingCompletionRequest, @@ -660,25 +691,22 @@ def _validate_input( ScoreRequest: "score", ClassificationRequest: "classification", } - operation = operations.get(type(request), - "embedding generation") + operation = operations.get(type(request), "embedding generation") raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the input for {operation}. " - f"Please reduce the length of the input.") - return TextTokensPrompt(prompt=input_text, - prompt_token_ids=input_ids) + f"Please reduce the length of the input." 
+ ) + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # and does not require model context length validation if isinstance( - request, - (TokenizeCompletionRequest, TokenizeChatRequest, - DetokenizeRequest), + request, + (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest), ): - return TextTokensPrompt(prompt=input_text, - prompt_token_ids=input_ids) + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) # chat completion endpoint supports max_completion_tokens if isinstance(request, ChatCompletionRequest): @@ -694,16 +722,17 @@ def _validate_input( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, your request has " f"{token_num} input tokens. Please reduce the length of " - "the input messages.") + "the input messages." + ) - if (max_tokens is not None - and token_num + max_tokens > self.max_model_len): + if max_tokens is not None and token_num + max_tokens > self.max_model_len: raise ValueError( "'max_tokens' or 'max_completion_tokens' is too large: " f"{max_tokens}. This model's maximum context length is " f"{self.max_model_len} tokens and your request has " f"{token_num} input tokens ({max_tokens} > {self.max_model_len}" - f" - {token_num}).") + f" - {token_num})." + ) return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) @@ -718,10 +747,10 @@ async def _tokenize_prompt_input_async( A simpler implementation that tokenizes a single prompt input. """ async for result in self._tokenize_prompt_inputs_async( - request, - tokenizer, + request, + tokenizer, [prompt_input], - add_special_tokens=add_special_tokens, + add_special_tokens=add_special_tokens, ): return result raise ValueError("No results yielded from tokenization") @@ -751,6 +780,26 @@ async def _tokenize_prompt_inputs_async( tokenizer=tokenizer, ) + def _validate_chat_template( + self, + request_chat_template: Optional[str], + chat_template_kwargs: Optional[dict[str, Any]], + trust_request_chat_template: bool, + ) -> Optional[ErrorResponse]: + if not trust_request_chat_template and ( + request_chat_template is not None + or ( + chat_template_kwargs + and chat_template_kwargs.get("chat_template") is not None + ) + ): + return self.create_error_response( + "Chat template is passed with request, but " + "--trust-request-chat-template is not set. " + "Refused request with untrusted chat template." 
+ ) + return None + async def _preprocess_chat( self, request: Union[ChatLikeRequest, ResponsesRequest], @@ -766,9 +815,9 @@ async def _preprocess_chat( tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None, add_special_tokens: bool = False, ) -> tuple[ - list[ConversationMessage], - Sequence[RequestPrompt], - list[EngineTokensPrompt], + list[ConversationMessage], + Sequence[RequestPrompt], + list[EngineTokensPrompt], ]: model_config = self.model_config @@ -818,8 +867,9 @@ async def _preprocess_chat( # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser # is set, we want to prevent parsing a tool_call hallucinated by the LLM - should_parse_tools = tool_parser is not None and (hasattr( - request, "tool_choice") and request.tool_choice != "none") + should_parse_tools = tool_parser is not None and ( + hasattr(request, "tool_choice") and request.tool_choice != "none" + ) if should_parse_tools: if not isinstance(request, ChatCompletionRequest): @@ -827,15 +877,17 @@ async def _preprocess_chat( raise NotImplementedError(msg) request = tool_parser(tokenizer).adjust_request( # type: ignore - request=request) + request=request + ) if tokenizer is None: assert isinstance(request_prompt, str), ( "Prompt has to be a string", "when the tokenizer is not initialised", ) - prompt_inputs = TextTokensPrompt(prompt=request_prompt, - prompt_token_ids=[1]) + prompt_inputs = TextTokensPrompt( + prompt=request_prompt, prompt_token_ids=[1] + ) elif isinstance(request_prompt, str): prompt_inputs = await self._tokenize_prompt_input_async( request, @@ -846,14 +898,16 @@ async def _preprocess_chat( else: # For MistralTokenizer assert is_list_of(request_prompt, int), ( - "Prompt has to be either a string or a list of token ids") + "Prompt has to be either a string or a list of token ids" + ) prompt_inputs = TextTokensPrompt( prompt=tokenizer.decode(request_prompt), prompt_token_ids=request_prompt, ) engine_prompt = EngineTokensPrompt( - prompt_token_ids=prompt_inputs["prompt_token_ids"]) + prompt_token_ids=prompt_inputs["prompt_token_ids"] + ) if mm_data is not None: engine_prompt["multi_modal_data"] = mm_data @@ -880,9 +934,9 @@ async def _process_inputs( ) -> tuple[EngineCoreRequest, dict[str, Any]]: """Use the Processor to process inputs for AsyncLLM.""" tokenization_kwargs: dict[str, Any] = {} - _validate_truncation_size(self.max_model_len, - params.truncate_prompt_tokens, - tokenization_kwargs) + _validate_truncation_size( + self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs + ) processor = await self._get_processor() engine_request = processor.process_inputs( @@ -917,14 +971,14 @@ async def _generate_with_builtin_tools( lora_request=lora_request, ) trace_headers = kwargs.get("trace_headers") - engine_request, tokenization_kwargs = (await self._process_inputs( + engine_request, tokenization_kwargs = await self._process_inputs( request_id, engine_prompt, sampling_params, lora_request=lora_request, trace_headers=trace_headers, priority=priority, - )) + ) generator = self.engine_client.generate( engine_request, @@ -956,12 +1010,10 @@ async def _generate_with_builtin_tools( # Create inputs for the next turn. # Render the next prompt token ids. prompt_token_ids = context.render_for_completion() - engine_prompt = EngineTokensPrompt( - prompt_token_ids=prompt_token_ids) + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) request_prompt = prompt_token_ids # Update the sampling params. 
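# Illustrative note, not part of this diff: the updated line below keeps each
# successive tool-use turn inside the model's context window. After the
# conversation is re-rendered for the next turn, the remaining generation
# budget is whatever is left of max_model_len once the longer prompt is
# accounted for:
#     sampling_params.max_tokens = max_model_len - len(prompt_token_ids)
# For example, with max_model_len = 8192 and a re-rendered prompt of 7000
# tokens, at most 1192 tokens can still be generated in this turn.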
- sampling_params.max_tokens = self.max_model_len - len( - prompt_token_ids) + sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids) # OPTIMIZATION priority = orig_priority - 1 @@ -978,15 +1030,13 @@ def _log_inputs( self, request_id: str, inputs: Union[RequestPrompt, PromptType], - params: Optional[Union[SamplingParams, PoolingParams, - BeamSearchParams]], + params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], ) -> None: if self.request_logger is None: return - prompt, prompt_token_ids, prompt_embeds = ( - self._get_prompt_components(inputs)) + prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs) self.request_logger.log_inputs( request_id, @@ -1012,8 +1062,9 @@ async def _get_trace_headers( return None @staticmethod - def _base_request_id(raw_request: Optional[Request], - default: Optional[str] = None) -> Optional[str]: + def _base_request_id( + raw_request: Optional[Request], default: Optional[str] = None + ) -> Optional[str]: """Pulls the request id to use from a header, if provided""" default = default or random_uuid() if raw_request is None: @@ -1042,8 +1093,8 @@ def _is_model_supported(self, model_name: Optional[str]) -> bool: def clamp_prompt_logprobs( - prompt_logprobs: Union[PromptLogprobs, - None], ) -> Union[PromptLogprobs, None]: + prompt_logprobs: Union[PromptLogprobs, None], +) -> Union[PromptLogprobs, None]: if prompt_logprobs is None: return prompt_logprobs diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index a4efa0815b4e..d2a58a487a76 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -9,11 +9,15 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.openai.protocol import (ErrorInfo, ErrorResponse, - LoadLoRAAdapterRequest, - ModelCard, ModelList, - ModelPermission, - UnloadLoRAAdapterRequest) +from vllm.entrypoints.openai.protocol import ( + ErrorInfo, + ErrorResponse, + LoadLoRAAdapterRequest, + ModelCard, + ModelList, + ModelPermission, + UnloadLoRAAdapterRequest, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry @@ -65,10 +69,10 @@ def __init__( self.lora_id_counter = AtomicCounter(0) self.lora_resolvers: list[LoRAResolver] = [] - for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers( - ): + for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(): self.lora_resolvers.append( - LoRAResolverRegistry.get_resolver(lora_resolver_name)) + LoRAResolverRegistry.get_resolver(lora_resolver_name) + ) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) async def init_static_loras(self): @@ -77,10 +81,12 @@ async def init_static_loras(self): if self.static_lora_modules is None: return for lora in self.static_lora_modules: - load_request = LoadLoRAAdapterRequest(lora_path=lora.path, - lora_name=lora.name) + load_request = LoadLoRAAdapterRequest( + lora_path=lora.path, lora_name=lora.name + ) load_result = await self.load_lora_adapter( - request=load_request, base_model_name=lora.base_model_name) + request=load_request, base_model_name=lora.base_model_name + ) if isinstance(load_result, ErrorResponse): raise ValueError(load_result.error.message) @@ -100,47 +106,48 @@ def model_name(self, lora_request: Optional[LoRARequest] = None) -> str: return self.base_model_paths[0].name async def 
show_available_models(self) -> ModelList: - """Show available models. This includes the base model and all + """Show available models. This includes the base model and all adapters""" model_cards = [ - ModelCard(id=base_model.name, - max_model_len=self.max_model_len, - root=base_model.model_path, - permission=[ModelPermission()]) + ModelCard( + id=base_model.name, + max_model_len=self.max_model_len, + root=base_model.model_path, + permission=[ModelPermission()], + ) for base_model in self.base_model_paths ] lora_cards = [ - ModelCard(id=lora.lora_name, - root=lora.local_path, - parent=lora.base_model_name if lora.base_model_name else - self.base_model_paths[0].name, - permission=[ModelPermission()]) + ModelCard( + id=lora.lora_name, + root=lora.local_path, + parent=lora.base_model_name + if lora.base_model_name + else self.base_model_paths[0].name, + permission=[ModelPermission()], + ) for lora in self.lora_requests.values() ] model_cards.extend(lora_cards) return ModelList(data=model_cards) async def load_lora_adapter( - self, - request: LoadLoRAAdapterRequest, - base_model_name: Optional[str] = None + self, request: LoadLoRAAdapterRequest, base_model_name: Optional[str] = None ) -> Union[ErrorResponse, str]: lora_name = request.lora_name # Ensure atomicity based on the lora name async with self.lora_resolver_lock[lora_name]: - error_check_ret = await self._check_load_lora_adapter_request( - request) + error_check_ret = await self._check_load_lora_adapter_request(request) if error_check_ret is not None: return error_check_ret lora_path = request.lora_path unique_id = self.lora_id_counter.inc(1) - lora_request = LoRARequest(lora_name=lora_name, - lora_int_id=unique_id, - lora_path=lora_path) - if base_model_name is not None and self.is_base_model( - base_model_name): + lora_request = LoRARequest( + lora_name=lora_name, lora_int_id=unique_id, lora_path=lora_path + ) + if base_model_name is not None and self.is_base_model(base_model_name): lora_request.base_model_name = base_model_name # Validate that the adapter can be loaded into the engine @@ -154,24 +161,24 @@ async def load_lora_adapter( error_type = "NotFoundError" status_code = HTTPStatus.NOT_FOUND - return create_error_response(message=str(e), - err_type=error_type, - status_code=status_code) + return create_error_response( + message=str(e), err_type=error_type, status_code=status_code + ) self.lora_requests[lora_name] = lora_request - logger.info("Loaded new LoRA adapter: name '%s', path '%s'", - lora_name, lora_path) + logger.info( + "Loaded new LoRA adapter: name '%s', path '%s'", lora_name, lora_path + ) return f"Success: LoRA adapter '{lora_name}' added successfully." async def unload_lora_adapter( - self, - request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]: + self, request: UnloadLoRAAdapterRequest + ) -> Union[ErrorResponse, str]: lora_name = request.lora_name # Ensure atomicity based on the lora name async with self.lora_resolver_lock[lora_name]: - error_check_ret = await self._check_unload_lora_adapter_request( - request) + error_check_ret = await self._check_unload_lora_adapter_request(request) if error_check_ret is not None: return error_check_ret @@ -181,48 +188,49 @@ async def unload_lora_adapter( return f"Success: LoRA adapter '{lora_name}' removed successfully." 
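# Illustrative sketch, not part of this diff: load_lora_adapter and
# unload_lora_adapter above serialize work per adapter name through a
# defaultdict of locks, so concurrent requests for the same LoRA name cannot
# race while requests for different names still proceed in parallel. A
# standalone toy version of that pattern (all names here are hypothetical):
import asyncio
from collections import defaultdict


class PerNameLocks:
    """One asyncio.Lock per adapter name, created lazily on first use."""

    def __init__(self) -> None:
        self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

    async def run_exclusive(self, name: str, make_coro):
        # Only one coroutine may operate on `name` at a time.
        async with self._locks[name]:
            return await make_coro()


async def _demo() -> None:
    locks = PerNameLocks()

    async def load(name: str) -> str:
        await asyncio.sleep(0.01)  # stand-in for real adapter loading
        return f"loaded {name}"

    print(
        await asyncio.gather(
            locks.run_exclusive("adapter-a", lambda: load("adapter-a")),
            locks.run_exclusive("adapter-a", lambda: load("adapter-a")),
            locks.run_exclusive("adapter-b", lambda: load("adapter-b")),
        )
    )


if __name__ == "__main__":
    asyncio.run(_demo())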
async def _check_load_lora_adapter_request( - self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]: + self, request: LoadLoRAAdapterRequest + ) -> Optional[ErrorResponse]: # Check if both 'lora_name' and 'lora_path' are provided if not request.lora_name or not request.lora_path: return create_error_response( message="Both 'lora_name' and 'lora_path' must be provided.", err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) + status_code=HTTPStatus.BAD_REQUEST, + ) # Check if the lora adapter with the given name already exists if request.lora_name in self.lora_requests: return create_error_response( - message= - f"The lora adapter '{request.lora_name}' has already been " + message=f"The lora adapter '{request.lora_name}' has already been " "loaded.", err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) + status_code=HTTPStatus.BAD_REQUEST, + ) return None async def _check_unload_lora_adapter_request( - self, - request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]: + self, request: UnloadLoRAAdapterRequest + ) -> Optional[ErrorResponse]: # Check if 'lora_name' is not provided return an error if not request.lora_name: return create_error_response( - message= - "'lora_name' needs to be provided to unload a LoRA adapter.", + message="'lora_name' needs to be provided to unload a LoRA adapter.", err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) + status_code=HTTPStatus.BAD_REQUEST, + ) # Check if the lora adapter with the given name exists if request.lora_name not in self.lora_requests: return create_error_response( - message= - f"The lora adapter '{request.lora_name}' cannot be found.", + message=f"The lora adapter '{request.lora_name}' cannot be found.", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND) + status_code=HTTPStatus.NOT_FOUND, + ) return None - async def resolve_lora( - self, lora_name: str) -> Union[LoRARequest, ErrorResponse]: + async def resolve_lora(self, lora_name: str) -> Union[LoRARequest, ErrorResponse]: """Attempt to resolve a LoRA adapter using available resolvers. Args: @@ -244,8 +252,7 @@ async def resolve_lora( # Try to resolve using available resolvers for resolver in self.lora_resolvers: - lora_request = await resolver.resolve_lora( - base_model_name, lora_name) + lora_request = await resolver.resolve_lora(base_model_name, lora_name) if lora_request is not None: found_adapter = True @@ -256,33 +263,43 @@ async def resolve_lora( self.lora_requests[lora_name] = lora_request logger.info( "Resolved and loaded LoRA adapter '%s' using %s", - lora_name, resolver.__class__.__name__) + lora_name, + resolver.__class__.__name__, + ) return lora_request except BaseException as e: logger.warning( "Failed to load LoRA '%s' resolved by %s: %s. " - "Trying next resolver.", lora_name, - resolver.__class__.__name__, e) + "Trying next resolver.", + lora_name, + resolver.__class__.__name__, + e, + ) continue if found_adapter: # An adapter was found, but all attempts to load it failed. return create_error_response( - message=(f"LoRA adapter '{lora_name}' was found " - "but could not be loaded."), + message=( + f"LoRA adapter '{lora_name}' was found but could not be loaded." 
+ ), err_type="BadRequestError", - status_code=HTTPStatus.BAD_REQUEST) + status_code=HTTPStatus.BAD_REQUEST, + ) else: # No adapter was found return create_error_response( message=f"LoRA adapter {lora_name} does not exist", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND) + status_code=HTTPStatus.NOT_FOUND, + ) def create_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: - return ErrorResponse(error=ErrorInfo( - message=message, type=err_type, code=status_code.value)) + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, +) -> ErrorResponse: + return ErrorResponse( + error=ErrorInfo(message=message, type=err_type, code=status_code.value) + ) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 0750c7ec3e9f..390b388e303c 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -17,15 +17,17 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger -# yapf: disable -from vllm.entrypoints.openai.protocol import (ErrorResponse, - IOProcessorRequest, - IOProcessorResponse, - PoolingChatRequest, - PoolingCompletionRequest, - PoolingRequest, PoolingResponse, - PoolingResponseData, UsageInfo) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + IOProcessorRequest, + IOProcessorResponse, + PoolingChatRequest, + PoolingCompletionRequest, + PoolingRequest, + PoolingResponse, + PoolingResponseData, + UsageInfo, +) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig @@ -55,7 +57,6 @@ def _get_data( class OpenAIServingPooling(OpenAIServing): - def __init__( self, engine_client: EngineClient, @@ -65,16 +66,20 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + trust_request_chat_template: bool = False, log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=vllm_config.model_config, - models=models, - request_logger=request_logger, - log_error_stack=log_error_stack) + super().__init__( + engine_client=engine_client, + model_config=vllm_config.model_config, + models=models, + request_logger=request_logger, + log_error_stack=log_error_stack, + ) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.trust_request_chat_template = trust_request_chat_template io_processor_plugin = self.model_config.io_processor_plugin self.io_processor = get_io_processor(vllm_config, io_processor_plugin) @@ -108,12 +113,13 @@ async def create_pooling( if getattr(request, "dimensions", None) is not None: return self.create_error_response( - "dimensions is currently not supported") + "dimensions is currently not supported" + ) - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", - None) + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) truncate_prompt_tokens = _validate_truncation_size( - self.max_model_len, truncate_prompt_tokens) + self.max_model_len, truncate_prompt_tokens + ) if is_io_processor_request: if self.io_processor is None: @@ -121,14 +127,23 
@@ async def create_pooling( "No IOProcessor plugin installed. Please refer " "to the documentation and to the " "'prithvi_geospatial_mae_io_processor' " - "offline inference example for more details.") + "offline inference example for more details." + ) validated_prompt = self.io_processor.parse_request(request) engine_prompts = await self.io_processor.pre_process_async( - prompt=validated_prompt, request_id=request_id) + prompt=validated_prompt, request_id=request_id + ) elif isinstance(request, PoolingChatRequest): + error_check_ret = self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret ( _, _, @@ -138,8 +153,7 @@ async def create_pooling( tokenizer, request.messages, chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self. - chat_template_content_format, + chat_template_content_format=self.chat_template_content_format, # In pooling requests, we are not generating tokens, # so there is no need to append extra tokens to the input add_generation_prompt=False, @@ -152,8 +166,7 @@ async def create_pooling( config=self._build_render_config(request), ) else: - raise ValueError( - f"Unsupported request of type {type(request)}") + raise ValueError(f"Unsupported request of type {type(request)}") except (ValueError, TypeError, jinja2.TemplateError) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) @@ -171,13 +184,18 @@ async def create_pooling( for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" - self._log_inputs(request_id_item, - engine_prompt, - params=pooling_params, - lora_request=lora_request) + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=lora_request, + ) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) generator = self.engine_client.encode( engine_prompt, @@ -203,8 +221,7 @@ async def create_pooling( ) return self.io_processor.output_to_response(output) - assert isinstance(request, - (PoolingCompletionRequest, PoolingChatRequest)) + assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest)) num_prompts = len(engine_prompts) # Non-streaming response @@ -216,8 +233,7 @@ async def create_pooling( assert all(final_res is not None for final_res in final_res_batch) - final_res_batch_checked = cast(list[PoolingRequestOutput], - final_res_batch) + final_res_batch_checked = cast(list[PoolingRequestOutput], final_res_batch) response = self.request_output_to_pooling_response( final_res_batch_checked, @@ -268,9 +284,9 @@ def request_output_to_pooling_response( usage=usage, ) - def _build_render_config( - self, request: PoolingCompletionRequest) -> RenderConfig: + def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig: return RenderConfig( max_length=self.max_model_len, truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens) + add_special_tokens=request.add_special_tokens, + ) diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index faaed2fca392..1b25fd4eb27e 100644 --- 
a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -14,59 +14,81 @@ import jinja2 from fastapi import Request -# yapf conflicts with isort for this block -# yapf: disable from openai.types.responses import ( ResponseCodeInterpreterCallCodeDeltaEvent, ResponseCodeInterpreterCallCodeDoneEvent, ResponseCodeInterpreterCallCompletedEvent, ResponseCodeInterpreterCallInProgressEvent, ResponseCodeInterpreterCallInterpretingEvent, - ResponseCodeInterpreterToolCallParam, ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, ResponseFunctionToolCall, - ResponseFunctionWebSearch, ResponseOutputItem, - ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, - ResponseOutputMessage, ResponseOutputText, ResponseReasoningItem, - ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent, - ResponseStatus, ResponseTextDeltaEvent, ResponseTextDoneEvent, - ResponseWebSearchCallCompletedEvent, ResponseWebSearchCallInProgressEvent, - ResponseWebSearchCallSearchingEvent, response_function_web_search, - response_text_delta_event) -from openai.types.responses.response_output_text import (Logprob, - LogprobTopLogprob) -# yapf: enable + ResponseCodeInterpreterToolCallParam, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseFunctionToolCall, + ResponseFunctionWebSearch, + ResponseOutputItem, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, + ResponseOutputText, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseStatus, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, + ResponseWebSearchCallCompletedEvent, + ResponseWebSearchCallInProgressEvent, + ResponseWebSearchCallSearchingEvent, + response_function_web_search, + response_text_delta_event, +) +from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob from openai.types.responses.response_reasoning_item import ( - Content as ResponseReasoningTextContent) + Content as ResponseReasoningTextContent, +) from openai_harmony import Message as OpenAIHarmonyMessage from vllm import envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - ChatTemplateContentFormatOption) -from vllm.entrypoints.context import (ConversationContext, HarmonyContext, - SimpleContext, StreamingHarmonyContext) +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormatOption, +) +from vllm.entrypoints.context import ( + ConversationContext, + HarmonyContext, + SimpleContext, + StreamingHarmonyContext, +) from vllm.entrypoints.harmony_utils import ( - get_developer_message, get_stop_tokens_for_assistant_actions, - get_system_message, get_user_message, has_custom_tools, - parse_output_message, parse_remaining_state, parse_response_input, - render_for_completion) + get_developer_message, + get_stop_tokens_for_assistant_actions, + get_system_message, + get_user_message, + has_custom_tools, + parse_output_message, + parse_remaining_state, + parse_response_input, + render_for_completion, +) from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (DeltaMessage, ErrorResponse, - InputTokensDetails, - OutputTokensDetails, - RequestResponseMetadata, - ResponseCompletedEvent, - ResponseCreatedEvent, - ResponseInProgressEvent, - ResponseReasoningPartAddedEvent, - 
ResponseReasoningPartDoneEvent, - ResponsesRequest, - ResponsesResponse, ResponseUsage, - StreamingResponsesResponse) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + DeltaMessage, + ErrorResponse, + InputTokensDetails, + OutputTokensDetails, + RequestResponseMetadata, + ResponseCompletedEvent, + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseReasoningPartAddedEvent, + ResponseReasoningPartDoneEvent, + ResponsesRequest, + ResponsesResponse, + ResponseUsage, + StreamingResponsesResponse, +) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.tool_server import ToolServer @@ -84,7 +106,6 @@ class OpenAIServingResponses(OpenAIServing): - def __init__( self, engine_client: EngineClient, @@ -118,27 +139,29 @@ def __init__( self.chat_template_content_format: Final = chat_template_content_format self.enable_log_outputs = enable_log_outputs - self.reasoning_parser: Optional[Callable[[AnyTokenizer], - ReasoningParser]] = None + self.reasoning_parser: Optional[Callable[[AnyTokenizer], ReasoningParser]] = ( + None + ) if reasoning_parser: try: - self.reasoning_parser = ( - ReasoningParserManager.get_reasoning_parser( - reasoning_parser)) + self.reasoning_parser = ReasoningParserManager.get_reasoning_parser( + reasoning_parser + ) assert self.reasoning_parser is not None except Exception as e: - raise TypeError( - f"{reasoning_parser=} has not been registered") from e + raise TypeError(f"{reasoning_parser=} has not been registered") from e self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_force_include_usage = enable_force_include_usage - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) + self.default_sampling_params = self.model_config.get_diff_sampling_param() if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source - logger.info("Using default chat sampling params from %s: %s", - source, self.default_sampling_params) + logger.info( + "Using default chat sampling params from %s: %s", + source, + self.default_sampling_params, + ) # If False (default), the "store" option is (silently) ignored and the # response is not stored. If True, the response is stored in memory. @@ -150,26 +173,31 @@ def __init__( logger.warning_once( "`VLLM_ENABLE_RESPONSES_API_STORE` is enabled. This may " "cause a memory leak since we never remove responses from " - "the store.") + "the store." + ) self.use_harmony = model_config.hf_config.model_type == "gpt_oss" if self.use_harmony: - logger.warning("For gpt-oss, we ignore --enable-auto-tool-choice " - "and always enable tool use.") + logger.warning( + "For gpt-oss, we ignore --enable-auto-tool-choice " + "and always enable tool use." + ) # OpenAI models have two EOS-like tokens: <|return|> and <|call|>. # We need to add them to the stop token ids. 
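# Illustrative note, not part of this diff: gpt-oss can end an assistant turn
# either by emitting its final answer (<|return|>) or by requesting a tool
# invocation (<|call|>). Appending both ids to the default stop_token_ids
# below makes generation halt at whichever of the two appears first.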
if "stop_token_ids" not in self.default_sampling_params: self.default_sampling_params["stop_token_ids"] = [] self.default_sampling_params["stop_token_ids"].extend( - get_stop_tokens_for_assistant_actions()) + get_stop_tokens_for_assistant_actions() + ) # set up tool use self.enable_auto_tools: bool = enable_auto_tools if self.enable_auto_tools: logger.info( - "\"auto\" tool choice has been enabled please note that while" + '"auto" tool choice has been enabled please note that while' " the parallel_tool_calls client option is preset for " - "compatibility reasons, it will be ignored.") + "compatibility reasons, it will be ignored." + ) # HACK(woosuk): This is a hack. We should use a better store. # FIXME: If enable_store=True, this may cause a memory leak since we @@ -185,23 +213,25 @@ def __init__( # HACK(wuhang): This is a hack. We should use a better store. # FIXME: If enable_store=True, this may cause a memory leak since we # never remove events from the store. - self.event_store: dict[str, tuple[deque[StreamingResponsesResponse], - asyncio.Event]] = {} + self.event_store: dict[ + str, tuple[deque[StreamingResponsesResponse], asyncio.Event] + ] = {} self.background_tasks: dict[str, asyncio.Task] = {} self.tool_server = tool_server def _validate_generator_input( - self, - engine_prompt: EngineTokensPrompt) -> Optional[ErrorResponse]: + self, engine_prompt: EngineTokensPrompt + ) -> Optional[ErrorResponse]: """Add validations to the input to the generator here.""" if self.max_model_len <= len(engine_prompt["prompt_token_ids"]): error_message = ( "The engine prompt length" f" {len(engine_prompt['prompt_token_ids'])} " f"exceeds the max_model_len {self.max_model_len}. " - "Please reduce prompt.") + "Please reduce prompt." + ) return self.create_error_response( err_type="invalid_request_error", message=error_message, @@ -213,8 +243,11 @@ async def create_responses( self, request: ResponsesRequest, raw_request: Optional[Request] = None, - ) -> Union[AsyncGenerator[StreamingResponsesResponse, None], - ResponsesResponse, ErrorResponse]: + ) -> Union[ + AsyncGenerator[StreamingResponsesResponse, None], + ResponsesResponse, + ErrorResponse, + ]: error_check_ret = await self._check_model(request) if error_check_ret is not None: logger.error("Error with model %s", error_check_ret) @@ -235,7 +268,8 @@ async def create_responses( "therefore does not support the background mode. To " "enable these features, set the environment variable " "`VLLM_ENABLE_RESPONSES_API_STORE=1` when launching " - "the vLLM server."), + "the vLLM server." + ), status_code=HTTPStatus.BAD_REQUEST, ) # Disable the store option. 
@@ -269,19 +303,24 @@ async def create_responses( if self.use_harmony: messages, request_prompts, engine_prompts = ( - self._make_request_with_harmony(request, prev_response)) + self._make_request_with_harmony(request, prev_response) + ) else: - messages, request_prompts, engine_prompts = ( - await self._make_request(request, prev_response, - tokenizer)) + messages, request_prompts, engine_prompts = await self._make_request( + request, prev_response, tokenizer + ) - except (ValueError, TypeError, RuntimeError, jinja2.TemplateError, - NotImplementedError) as e: + except ( + ValueError, + TypeError, + RuntimeError, + jinja2.TemplateError, + NotImplementedError, + ) as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(f"{e} {e.__cause__}") - request_metadata = RequestResponseMetadata( - request_id=request.request_id) + request_metadata = RequestResponseMetadata(request_id=request.request_id) if raw_request: raw_request.state.request_metadata = request_metadata @@ -309,19 +348,23 @@ async def create_responses( return maybe_error default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) + engine_prompt["prompt_token_ids"] + ) sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) + default_max_tokens, self.default_sampling_params + ) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) context: ConversationContext if self.use_harmony: if request.stream: - context = StreamingHarmonyContext( - messages, available_tools) + context = StreamingHarmonyContext(messages, available_tools) else: context = HarmonyContext(messages, available_tools) else: @@ -342,7 +385,7 @@ async def create_responses( return self.create_error_response(str(e)) assert len(generators) == 1 - result_generator, = generators + (result_generator,) = generators # Store the input messages. if request.store: @@ -396,11 +439,11 @@ async def create_responses( response_id = response.id self.background_tasks[response_id] = task task.add_done_callback( - lambda _: self.background_tasks.pop(response_id, None)) + lambda _: self.background_tasks.pop(response_id, None) + ) if request.stream: - return self.responses_background_stream_generator( - request.request_id) + return self.responses_background_stream_generator(request.request_id) return response if request.stream: @@ -435,7 +478,8 @@ async def _make_request( ): if len(request.tools) > 0: raise NotImplementedError( - "Tool use is not supported in Responses API without Harmony") + "Tool use is not supported in Responses API without Harmony" + ) # Construct the input messages. 
messages = self._construct_input_messages(request, prev_response) _, request_prompts, engine_prompts = await self._preprocess_chat( @@ -454,10 +498,9 @@ def _make_request_with_harmony( ): if request.tool_choice != "auto": raise NotImplementedError( - "Only 'auto' tool_choice is supported in " - "response API with Harmony") - messages = self._construct_input_messages_with_harmony( - request, prev_response) + "Only 'auto' tool_choice is supported in response API with Harmony" + ) + messages = self._construct_input_messages_with_harmony(request, prev_response) prompt_token_ids = render_for_completion(messages) engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) @@ -467,18 +510,21 @@ def _make_request_with_harmony( return messages, [prompt_token_ids], [engine_prompt] - async def _initialize_tool_sessions(self, request: ResponsesRequest, - context: ConversationContext, - exit_stack: AsyncExitStack): + async def _initialize_tool_sessions( + self, + request: ResponsesRequest, + context: ConversationContext, + exit_stack: AsyncExitStack, + ): # we should only initialize the tool session if the request needs tools if len(request.tools) == 0: return mcp_tools = { - tool.server_label: tool - for tool in request.tools if tool.type == "mcp" + tool.server_label: tool for tool in request.tools if tool.type == "mcp" } - await context.init_tool_sessions(self.tool_server, exit_stack, - request.request_id, mcp_tools) + await context.init_tool_sessions( + self.tool_server, exit_stack, request.request_id, mcp_tools + ) async def responses_full_generator( self, @@ -496,8 +542,7 @@ async def responses_full_generator( async with AsyncExitStack() as exit_stack: try: - await self._initialize_tool_sessions(request, context, - exit_stack) + await self._initialize_tool_sessions(request, context, exit_stack) async for _ in result_generator: pass except asyncio.CancelledError: @@ -517,8 +562,8 @@ async def responses_full_generator( assert isinstance(context, HarmonyContext) output = self._make_response_output_items_with_harmony(context) if request.enable_response_messages: - input_messages = context.messages[:context.num_init_messages] - output_messages = context.messages[context.num_init_messages:] + input_messages = context.messages[: context.num_init_messages] + output_messages = context.messages[context.num_init_messages :] num_tool_output_tokens = context.num_tool_output_tokens if len(output) > 0: if context.finish_reason == "length": @@ -534,15 +579,14 @@ async def responses_full_generator( assert len(final_res.outputs) == 1 final_output = final_res.outputs[0] - output = self._make_response_output_items(request, final_output, - tokenizer) + output = self._make_response_output_items(request, final_output, tokenizer) # TODO: context for non-gptoss models doesn't use messages # so we can't get them out yet if request.enable_response_messages: raise NotImplementedError( - "enable_response_messages is currently" - " only supported for gpt-oss") + "enable_response_messages is currently only supported for gpt-oss" + ) # Calculate usage. 
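# Illustrative note, not part of this diff: the usage block assembled below
# follows total_tokens = input_tokens + output_tokens, with two detail
# records on top: cached prompt tokens under input_tokens_details, and
# reasoning plus tool-output tokens under output_tokens_details. For example,
# a request with 1200 prompt tokens (300 of them prefix-cache hits) that
# generates 400 tokens, 150 of which are reasoning, reports
# total_tokens=1600, cached_tokens=300 and reasoning_tokens=150.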
assert final_res.prompt_token_ids is not None num_tool_output_tokens = 0 @@ -557,11 +601,11 @@ async def responses_full_generator( input_tokens=num_prompt_tokens, output_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, - input_tokens_details=InputTokensDetails( - cached_tokens=num_cached_tokens), + input_tokens_details=InputTokensDetails(cached_tokens=num_cached_tokens), output_tokens_details=OutputTokensDetails( reasoning_tokens=num_reasoning_tokens, - tool_output_tokens=num_tool_output_tokens), + tool_output_tokens=num_tool_output_tokens, + ), ) response = ResponsesResponse.from_request( request, @@ -579,54 +623,67 @@ async def responses_full_generator( async with self.response_store_lock: stored_response = self.response_store.get(response.id) # If the response is already cancelled, don't update it. - if (stored_response is None - or stored_response.status != "cancelled"): + if stored_response is None or stored_response.status != "cancelled": self.response_store[response.id] = response return response - def _topk_logprobs(self, logprobs: dict[int, - SampleLogprob], top_logprobs: int, - tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]: + def _topk_logprobs( + self, + logprobs: dict[int, SampleLogprob], + top_logprobs: int, + tokenizer: AnyTokenizer, + ) -> list[LogprobTopLogprob]: """Returns the top-k logprobs from the logprobs dictionary.""" out = [] for i, (token_id, _logprob) in enumerate(logprobs.items()): if i >= top_logprobs: break - text = _logprob.decoded_token if _logprob.decoded_token \ - is not None else tokenizer.decode([token_id]) + text = ( + _logprob.decoded_token + if _logprob.decoded_token is not None + else tokenizer.decode([token_id]) + ) out.append( LogprobTopLogprob( token=text, logprob=max(_logprob.logprob, -9999.0), bytes=list(text.encode("utf-8", errors="replace")), - )) + ) + ) return out def _create_response_logprobs( - self, - token_ids: Sequence[int], - logprobs: Optional[SampleLogprobs], - tokenizer: AnyTokenizer, - top_logprobs: Optional[int] = None) -> list[Logprob]: + self, + token_ids: Sequence[int], + logprobs: Optional[SampleLogprobs], + tokenizer: AnyTokenizer, + top_logprobs: Optional[int] = None, + ) -> list[Logprob]: assert logprobs is not None, "logprobs must be provided" assert len(token_ids) == len(logprobs), ( - "token_ids and logprobs.token_ids must have the same length") + "token_ids and logprobs.token_ids must have the same length" + ) out = [] for i, token_id in enumerate(token_ids): logprob = logprobs[i] token_logprob = logprob[token_id] - text = token_logprob.decoded_token if token_logprob.decoded_token \ - is not None else tokenizer.decode([token_id]) + text = ( + token_logprob.decoded_token + if token_logprob.decoded_token is not None + else tokenizer.decode([token_id]) + ) out.append( Logprob( token=text, logprob=max(token_logprob.logprob, -9999.0), bytes=list(text.encode("utf-8", errors="replace")), - top_logprobs=self._topk_logprobs(logprob, - top_logprobs=top_logprobs, - tokenizer=tokenizer) - if top_logprobs else [], - )) + top_logprobs=self._topk_logprobs( + logprob, top_logprobs=top_logprobs, tokenizer=tokenizer + ) + if top_logprobs + else [], + ) + ) return out def _create_stream_response_logprobs( @@ -634,21 +691,26 @@ def _create_stream_response_logprobs( token_ids: Sequence[int], logprobs: Optional[SampleLogprobs], tokenizer: AnyTokenizer, - top_logprobs: Optional[int] = None + top_logprobs: Optional[int] = None, ) -> list[response_text_delta_event.Logprob]: - lgs = 
self._create_response_logprobs(token_ids=token_ids, - logprobs=logprobs, - tokenizer=tokenizer, - top_logprobs=top_logprobs) + lgs = self._create_response_logprobs( + token_ids=token_ids, + logprobs=logprobs, + tokenizer=tokenizer, + top_logprobs=top_logprobs, + ) return [ response_text_delta_event.Logprob( token=lg.token, logprob=lg.logprob, top_logprobs=[ response_text_delta_event.LogprobTopLogprob( - token=tl.token, logprob=tl.logprob) + token=tl.token, logprob=tl.logprob + ) for tl in lg.top_logprobs - ]) for lg in lgs + ], + ) + for lg in lgs ] def _make_response_output_items( @@ -664,9 +726,9 @@ def _make_response_output_items( logger.exception("Error in reasoning parser creation.") raise e - reasoning_content, content = ( - reasoning_parser.extract_reasoning_content(final_output.text, - request=request)) + reasoning_content, content = reasoning_parser.extract_reasoning_content( + final_output.text, request=request + ) else: reasoning_content = None content = final_output.text @@ -696,8 +758,9 @@ def _make_response_output_items( summary=[], type="reasoning", content=[ - ResponseReasoningTextContent(text=reasoning_content, - type="reasoning_text") + ResponseReasoningTextContent( + text=reasoning_content, type="reasoning_text" + ) ], status=None, # NOTE: Only the last output item has status. ) @@ -712,7 +775,9 @@ def _make_response_output_items( logprobs=final_output.logprobs, tokenizer=tokenizer, top_logprobs=request.top_logprobs, - ) if request.is_include_output_logprobs() else None, + ) + if request.is_include_output_logprobs() + else None, ) message = ResponseOutputMessage( id=f"msg_{random_uuid()}", @@ -745,10 +810,12 @@ def _construct_input_messages( ) -> list[ChatCompletionMessageParam]: messages: list[ChatCompletionMessageParam] = [] if request.instructions: - messages.append({ - "role": "system", - "content": request.instructions, - }) + messages.append( + { + "role": "system", + "content": request.instructions, + } + ) # Prepend the conversation history. if prev_response is not None: @@ -761,10 +828,12 @@ def _construct_input_messages( # NOTE: We skip the reasoning output. if isinstance(output_item, ResponseOutputMessage): for content in output_item.content: - messages.append({ - "role": "assistant", - "content": content.text, - }) + messages.append( + { + "role": "assistant", + "content": content.text, + } + ) # Append the new input. # Responses API supports simple text inputs without chat format. @@ -782,8 +851,7 @@ def _construct_input_messages_with_harmony( messages: list[OpenAIHarmonyMessage] = [] if prev_response is None: # New conversation. 
- reasoning_effort = (request.reasoning.effort - if request.reasoning else None) + reasoning_effort = request.reasoning.effort if request.reasoning else None tool_types = [tool.type for tool in request.tools] # Allow the MCP Tool type to enable built in tools if the @@ -791,37 +859,46 @@ def _construct_input_messages_with_harmony( # envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS: for tool in request.tools: - if (tool.type == "mcp" and tool.server_label - in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS): + if ( + tool.type == "mcp" + and tool.server_label in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS + ): tool_types.append(tool.server_label) - enable_browser = ("web_search_preview" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("browser")) - enable_code_interpreter = ("code_interpreter" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("python")) - enable_container = ("container" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("container")) + enable_browser = ( + "web_search_preview" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("browser") + ) + enable_code_interpreter = ( + "code_interpreter" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("python") + ) + enable_container = ( + "container" in tool_types + and self.tool_server is not None + and self.tool_server.has_tool("container") + ) with_custom_tools = has_custom_tools(tool_types) sys_msg = get_system_message( reasoning_effort=reasoning_effort, - browser_description=self.tool_server.get_tool_description( - "browser") - if enable_browser and self.tool_server is not None else None, - python_description=self.tool_server.get_tool_description( - "python") if enable_code_interpreter - and self.tool_server is not None else None, - container_description=self.tool_server.get_tool_description( - "container") - if enable_container and self.tool_server is not None else None, + browser_description=self.tool_server.get_tool_description("browser") + if enable_browser and self.tool_server is not None + else None, + python_description=self.tool_server.get_tool_description("python") + if enable_code_interpreter and self.tool_server is not None + else None, + container_description=self.tool_server.get_tool_description("container") + if enable_container and self.tool_server is not None + else None, instructions=request.instructions, with_custom_tools=with_custom_tools, ) messages.append(sys_msg) if with_custom_tools: dev_msg = get_developer_message( - instructions=request.instructions, tools=request.tools) + instructions=request.instructions, tools=request.tools + ) messages.append(dev_msg) else: # Continue the previous conversation. @@ -842,8 +919,8 @@ def _construct_input_messages_with_harmony( if prev_msg_i.channel == "final": prev_final_msg_idx = i break - recent_turn_msgs = prev_msgs[prev_final_msg_idx + 1:] - del prev_msgs[prev_final_msg_idx + 1:] + recent_turn_msgs = prev_msgs[prev_final_msg_idx + 1 :] + del prev_msgs[prev_final_msg_idx + 1 :] for msg in recent_turn_msgs: assert isinstance(msg, OpenAIHarmonyMessage) if msg.channel != "analysis": @@ -859,8 +936,7 @@ def _construct_input_messages_with_harmony( else: prev_outputs = [] for response_msg in request.input: - messages.append( - parse_response_input(response_msg, prev_outputs)) + messages.append(parse_response_input(response_msg, prev_outputs)) # User passes in a tool call request and its output. 
We need # to add the tool call request to prev_outputs so that the # parse_response_input can find the tool call request when @@ -880,14 +956,12 @@ async def _run_background_request_stream( self.event_store[request.request_id] = (event_deque, new_event_signal) response = None try: - generator = self.responses_stream_generator( - request, *args, **kwargs) + generator = self.responses_stream_generator(request, *args, **kwargs) async for event in generator: event_deque.append(event) new_event_signal.set() # Signal new event available except Exception as e: - logger.exception("Background request failed for %s", - request.request_id) + logger.exception("Background request failed for %s", request.request_id) response = self.create_error_response(str(e)) finally: new_event_signal.set() @@ -908,11 +982,9 @@ async def _run_background_request( **kwargs, ): try: - response = await self.responses_full_generator( - request, *args, **kwargs) + response = await self.responses_full_generator(request, *args, **kwargs) except Exception as e: - logger.exception("Background request failed for %s", - request.request_id) + logger.exception("Background request failed for %s", request.request_id) response = self.create_error_response(str(e)) if isinstance(response, ErrorResponse): @@ -943,7 +1015,7 @@ async def responses_background_stream_generator( while current_index < len(event_deque): event = event_deque[current_index] yield event - if getattr(event, 'type', 'unknown') == "response.completed": + if getattr(event, "type", "unknown") == "response.completed": return current_index += 1 @@ -954,8 +1026,11 @@ async def retrieve_responses( response_id: str, starting_after: Optional[int], stream: Optional[bool], - ) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[ - StreamingResponsesResponse, None]]: + ) -> Union[ + ErrorResponse, + ResponsesResponse, + AsyncGenerator[StreamingResponsesResponse, None], + ]: async with self.response_store_lock: response = self.response_store.get(response_id) @@ -989,13 +1064,12 @@ async def cancel_responses( response.status = "cancelled" # Abort the request. - if (task := self.background_tasks.get(response_id)): + if task := self.background_tasks.get(response_id): task.cancel() try: await task except asyncio.CancelledError: - logger.exception("Background task for %s was cancelled", - response_id) + logger.exception("Background task for %s was cancelled", response_id) return response def _make_not_found_error(self, response_id: str) -> ErrorResponse: @@ -1008,10 +1082,12 @@ def _make_not_found_error(self, response_id: str) -> ErrorResponse: def _make_store_not_supported_error(self) -> ErrorResponse: return self.create_error_response( err_type="invalid_request_error", - message=("`store=True` (default) is not supported. Please set " - "`store=False` in Responses API or set " - "`VLLM_ENABLE_RESPONSES_API_STORE=1` in the env var when " - "starting the vLLM server."), + message=( + "`store=True` (default) is not supported. Please set " + "`store=False` in Responses API or set " + "`VLLM_ENABLE_RESPONSES_API_STORE=1` in the env var when " + "starting the vLLM server." 
+ ), status_code=HTTPStatus.BAD_REQUEST, ) @@ -1026,7 +1102,8 @@ async def _process_simple_streaming_events( request_metadata: RequestResponseMetadata, created_time: int, _increment_sequence_number_and_return: Callable[ - [StreamingResponsesResponse], StreamingResponsesResponse], + [StreamingResponsesResponse], StreamingResponsesResponse + ], ) -> AsyncGenerator[StreamingResponsesResponse, None]: current_content_index = 0 current_output_index = 0 @@ -1045,18 +1122,20 @@ async def _process_simple_streaming_events( if ctx.last_output.outputs: output = ctx.last_output.outputs[0] if reasoning_parser: - delta_message = \ + delta_message = ( reasoning_parser.extract_reasoning_content_streaming( - previous_text=previous_text, - current_text=previous_text + output.text, - delta_text=output.text, - previous_token_ids=previous_token_ids, - current_token_ids=previous_token_ids + - output.token_ids, - delta_token_ids=output.token_ids, + previous_text=previous_text, + current_text=previous_text + output.text, + delta_text=output.text, + previous_token_ids=previous_token_ids, + current_token_ids=previous_token_ids + output.token_ids, + delta_token_ids=output.token_ids, + ) ) else: - delta_message = DeltaMessage(content=output.text, ) + delta_message = DeltaMessage( + content=output.text, + ) previous_text += output.text previous_token_ids += output.token_ids if not delta_message: @@ -1075,7 +1154,8 @@ async def _process_simple_streaming_events( summary=[], status="in_progress", ), - )) + ) + ) else: yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( @@ -1089,7 +1169,8 @@ async def _process_simple_streaming_events( content=[], status="in_progress", ), - )) + ) + ) yield _increment_sequence_number_and_return( ResponseContentPartAddedEvent( type="response.content_part.added", @@ -1103,21 +1184,26 @@ async def _process_simple_streaming_events( annotations=[], logprobs=[], ), - )) + ) + ) current_content_index += 1 first_delta_sent = True # todo(kebe7jun) tool call support # check delta message and previous delta message are # same as content or reasoning content - if (previous_delta_messages - and previous_delta_messages[-1].reasoning_content - is not None and delta_message.content is not None): + if ( + previous_delta_messages + and previous_delta_messages[-1].reasoning_content is not None + and delta_message.content is not None + ): # from reasoning to normal content, send done # event for reasoning - reason_content = ''.join( - pm.reasoning_content for pm in previous_delta_messages - if pm.reasoning_content is not None) + reason_content = "".join( + pm.reasoning_content + for pm in previous_delta_messages + if pm.reasoning_content is not None + ) yield _increment_sequence_number_and_return( ResponseReasoningTextDoneEvent( type="response.reasoning_text.done", @@ -1126,7 +1212,8 @@ async def _process_simple_streaming_events( output_index=current_output_index, content_index=current_content_index, text=reason_content, - )) + ) + ) current_content_index = 0 reasoning_item = ResponseReasoningItem( type="reasoning", @@ -1146,7 +1233,8 @@ async def _process_simple_streaming_events( sequence_number=-1, output_index=current_output_index, item=reasoning_item, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", @@ -1159,7 +1247,8 @@ async def _process_simple_streaming_events( content=[], status="in_progress", ), - )) + ) + ) current_output_index += 1 current_item_id = str(uuid.uuid4()) yield 
_increment_sequence_number_and_return( @@ -1175,7 +1264,8 @@ async def _process_simple_streaming_events( annotations=[], logprobs=[], ), - )) + ) + ) current_content_index += 1 # reset previous delta messages previous_delta_messages = [] @@ -1189,7 +1279,8 @@ async def _process_simple_streaming_events( output_index=current_output_index, item_id=current_item_id, delta=delta_message.reasoning_content, - )) + ) + ) elif delta_message.content is not None: yield _increment_sequence_number_and_return( ResponseTextDeltaEvent( @@ -1204,16 +1295,21 @@ async def _process_simple_streaming_events( logprobs=output.logprobs, tokenizer=tokenizer, top_logprobs=request.top_logprobs, - ) if request.is_include_output_logprobs() else [], - )) + ) + if request.is_include_output_logprobs() + else [], + ) + ) current_content_index += 1 previous_delta_messages.append(delta_message) if previous_delta_messages: if previous_delta_messages[-1].reasoning_content is not None: - reason_content = ''.join(pm.reasoning_content - for pm in previous_delta_messages - if pm.reasoning_content is not None) + reason_content = "".join( + pm.reasoning_content + for pm in previous_delta_messages + if pm.reasoning_content is not None + ) yield _increment_sequence_number_and_return( ResponseReasoningTextDoneEvent( type="response.reasoning_text.done", @@ -1222,7 +1318,8 @@ async def _process_simple_streaming_events( output_index=current_output_index, content_index=current_content_index, text=reason_content, - )) + ) + ) current_content_index += 1 reasoning_item = ResponseReasoningItem( type="reasoning", @@ -1242,11 +1339,14 @@ async def _process_simple_streaming_events( sequence_number=-1, output_index=current_output_index, item=reasoning_item, - )) + ) + ) elif previous_delta_messages[-1].content is not None: - final_content = ''.join(pm.content - for pm in previous_delta_messages - if pm.content is not None) + final_content = "".join( + pm.content + for pm in previous_delta_messages + if pm.content is not None + ) yield _increment_sequence_number_and_return( ResponseTextDoneEvent( type="response.output_text.done", @@ -1256,7 +1356,8 @@ async def _process_simple_streaming_events( text=final_content, logprobs=[], item_id=current_item_id, - )) + ) + ) current_content_index += 1 part = ResponseOutputText( text=final_content, @@ -1271,7 +1372,8 @@ async def _process_simple_streaming_events( output_index=current_output_index, content_index=current_content_index, part=part, - )) + ) + ) current_content_index += 1 item = ResponseOutputMessage( type="message", @@ -1289,7 +1391,8 @@ async def _process_simple_streaming_events( sequence_number=-1, output_index=current_output_index, item=item, - )) + ) + ) async def _process_harmony_streaming_events( self, @@ -1302,7 +1405,8 @@ async def _process_harmony_streaming_events( request_metadata: RequestResponseMetadata, created_time: int, _increment_sequence_number_and_return: Callable[ - [StreamingResponsesResponse], StreamingResponsesResponse], + [StreamingResponsesResponse], StreamingResponsesResponse + ], ) -> AsyncGenerator[StreamingResponsesResponse, None]: current_content_index = -1 current_output_index = 0 @@ -1310,7 +1414,6 @@ async def _process_harmony_streaming_events( sent_output_item_added = False async for ctx in result_generator: - assert isinstance(ctx, StreamingHarmonyContext) if ctx.is_expecting_start(): @@ -1342,7 +1445,8 @@ async def _process_harmony_streaming_events( output_index=current_output_index, content_index=current_content_index, text=previous_item.content[0].text, - )) + ) 
+ ) yield _increment_sequence_number_and_return( ResponseReasoningPartDoneEvent( type="response.reasoning_part.done", @@ -1351,14 +1455,16 @@ async def _process_harmony_streaming_events( output_index=current_output_index, content_index=current_content_index, part=content, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseOutputItemDoneEvent( type="response.output_item.done", sequence_number=-1, output_index=current_output_index, item=reasoning_item, - )) + ) + ) elif previous_item.channel == "final": text_content = ResponseOutputText( type="output_text", @@ -1374,7 +1480,8 @@ async def _process_harmony_streaming_events( text=previous_item.content[0].text, logprobs=[], item_id=current_item_id, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseContentPartDoneEvent( type="response.content_part.done", @@ -1383,7 +1490,8 @@ async def _process_harmony_streaming_events( output_index=current_output_index, content_index=current_content_index, part=text_content, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseOutputItemDoneEvent( type="response.output_item.done", @@ -1396,12 +1504,15 @@ async def _process_harmony_streaming_events( content=[text_content], status="completed", ), - )) + ) + ) # stream the output of a harmony message if ctx.parser.last_content_delta: - if (ctx.parser.current_channel == "final" - and ctx.parser.current_recipient is None): + if ( + ctx.parser.current_channel == "final" + and ctx.parser.current_recipient is None + ): if not sent_output_item_added: sent_output_item_added = True current_item_id = f"msg_{random_uuid()}" @@ -1417,7 +1528,8 @@ async def _process_harmony_streaming_events( content=[], status="in_progress", ), - )) + ) + ) current_content_index += 1 yield _increment_sequence_number_and_return( ResponseContentPartAddedEvent( @@ -1432,7 +1544,8 @@ async def _process_harmony_streaming_events( annotations=[], logprobs=[], ), - )) + ) + ) yield _increment_sequence_number_and_return( ResponseTextDeltaEvent( type="response.output_text.delta", @@ -1443,9 +1556,12 @@ async def _process_harmony_streaming_events( delta=ctx.parser.last_content_delta, # TODO, use logprobs from ctx.last_request_output logprobs=[], - )) - elif (ctx.parser.current_channel == "analysis" - and ctx.parser.current_recipient is None): + ) + ) + elif ( + ctx.parser.current_channel == "analysis" + and ctx.parser.current_recipient is None + ): if not sent_output_item_added: sent_output_item_added = True current_item_id = f"msg_{random_uuid()}" @@ -1460,7 +1576,8 @@ async def _process_harmony_streaming_events( summary=[], status="in_progress", ), - )) + ) + ) current_content_index += 1 yield _increment_sequence_number_and_return( ResponseReasoningPartAddedEvent( @@ -1473,7 +1590,8 @@ async def _process_harmony_streaming_events( text="", type="reasoning_text", ), - )) + ) + ) yield _increment_sequence_number_and_return( ResponseReasoningTextDeltaEvent( type="response.reasoning_text.delta", @@ -1482,13 +1600,15 @@ async def _process_harmony_streaming_events( content_index=current_content_index, delta=ctx.parser.last_content_delta, sequence_number=-1, - )) + ) + ) # built-in tools will be triggered on the analysis channel # However, occasionally built-in tools will # still be output to commentary. 
- elif (ctx.parser.current_channel == "commentary" - or ctx.parser.current_channel == "analysis" - ) and ctx.parser.current_recipient == "python": + elif ( + ctx.parser.current_channel == "commentary" + or ctx.parser.current_channel == "analysis" + ) and ctx.parser.current_recipient == "python": if not sent_output_item_added: sent_output_item_added = True current_item_id = f"tool_{random_uuid()}" @@ -1505,15 +1625,16 @@ async def _process_harmony_streaming_events( outputs=None, status="in_progress", ), - )) + ) + ) yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallInProgressEvent( - type= - "response.code_interpreter_call.in_progress", + type="response.code_interpreter_call.in_progress", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallCodeDeltaEvent( type="response.code_interpreter_call_code.delta", @@ -1521,41 +1642,41 @@ async def _process_harmony_streaming_events( output_index=current_output_index, item_id=current_item_id, delta=ctx.parser.last_content_delta, - )) + ) + ) # stream tool call outputs if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: previous_item = ctx.parser.messages[-1] - if (self.tool_server is not None - and self.tool_server.has_tool("browser") - and previous_item.recipient is not None - and previous_item.recipient.startswith("browser.")): - function_name = previous_item.recipient[len("browser."):] + if ( + self.tool_server is not None + and self.tool_server.has_tool("browser") + and previous_item.recipient is not None + and previous_item.recipient.startswith("browser.") + ): + function_name = previous_item.recipient[len("browser.") :] action = None parsed_args = json.loads(previous_item.content[0].text) if function_name == "search": - action = (response_function_web_search.ActionSearch( + action = response_function_web_search.ActionSearch( type="search", query=parsed_args["query"], - )) + ) elif function_name == "open": - action = ( - response_function_web_search.ActionOpenPage( - type="open_page", - # TODO: translate to url - url=f"cursor:{parsed_args.get('cursor', '')}", - )) + action = response_function_web_search.ActionOpenPage( + type="open_page", + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + ) elif function_name == "find": - action = ( - response_function_web_search.ActionFind( - type="find", - pattern=parsed_args["pattern"], - # TODO: translate to url - url=f"cursor:{parsed_args.get('cursor', '')}", - )) + action = response_function_web_search.ActionFind( + type="find", + pattern=parsed_args["pattern"], + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + ) else: - raise ValueError( - f"Unknown function name: {function_name}") + raise ValueError(f"Unknown function name: {function_name}") current_item_id = f"tool_{random_uuid()}" yield _increment_sequence_number_and_return( @@ -1563,29 +1684,31 @@ async def _process_harmony_streaming_events( type="response.output_item.added", sequence_number=-1, output_index=current_output_index, - item=response_function_web_search. 
- ResponseFunctionWebSearch( + item=response_function_web_search.ResponseFunctionWebSearch( # TODO: generate a unique id for web search call type="web_search_call", id=current_item_id, action=action, status="in_progress", ), - )) + ) + ) yield _increment_sequence_number_and_return( ResponseWebSearchCallInProgressEvent( type="response.web_search_call.in_progress", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseWebSearchCallSearchingEvent( type="response.web_search_call.searching", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) + ) + ) # enqueue yield _increment_sequence_number_and_return( @@ -1594,7 +1717,8 @@ async def _process_harmony_streaming_events( sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseOutputItemDoneEvent( type="response.output_item.done", @@ -1606,12 +1730,15 @@ async def _process_harmony_streaming_events( action=action, status="completed", ), - )) + ) + ) - if (self.tool_server is not None - and self.tool_server.has_tool("python") - and previous_item.recipient is not None - and previous_item.recipient.startswith("python")): + if ( + self.tool_server is not None + and self.tool_server.has_tool("python") + and previous_item.recipient is not None + and previous_item.recipient.startswith("python") + ): yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallCodeDoneEvent( type="response.code_interpreter_call_code.done", @@ -1619,21 +1746,24 @@ async def _process_harmony_streaming_events( output_index=current_output_index, item_id=current_item_id, code=previous_item.content[0].text, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallInterpretingEvent( type="response.code_interpreter_call.interpreting", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseCodeInterpreterCallCompletedEvent( type="response.code_interpreter_call.completed", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, - )) + ) + ) yield _increment_sequence_number_and_return( ResponseOutputItemDoneEvent( type="response.output_item.done", @@ -1648,7 +1778,8 @@ async def _process_harmony_streaming_events( outputs=[], status="completed", ), - )) + ) + ) async def responses_stream_generator( self, @@ -1669,11 +1800,11 @@ async def responses_stream_generator( sequence_number = 0 def _increment_sequence_number_and_return( - event: StreamingResponsesResponse + event: StreamingResponsesResponse, ) -> StreamingResponsesResponse: nonlocal sequence_number # Set sequence_number if the event has this attribute - if hasattr(event, 'sequence_number'): + if hasattr(event, "sequence_number"): event.sequence_number = sequence_number sequence_number += 1 return event @@ -1683,8 +1814,7 @@ def _increment_sequence_number_and_return( if self.use_harmony: # TODO: in streaming, we noticed this bug: # https://github.com/vllm-project/vllm/issues/25697 - await self._initialize_tool_sessions(request, context, - exit_stack) + await self._initialize_tool_sessions(request, context, exit_stack) processer = self._process_harmony_streaming_events else: processer = self._process_simple_streaming_events @@ -1703,18 +1833,27 @@ def _increment_sequence_number_and_return( type="response.created", sequence_number=-1, response=initial_response, - )) + 
) + ) yield _increment_sequence_number_and_return( ResponseInProgressEvent( type="response.in_progress", sequence_number=-1, response=initial_response, - )) + ) + ) async for event_data in processer( - request, sampling_params, result_generator, context, - model_name, tokenizer, request_metadata, created_time, - _increment_sequence_number_and_return): + request, + sampling_params, + result_generator, + context, + model_name, + tokenizer, + request_metadata, + created_time, + _increment_sequence_number_and_return, + ): yield event_data async def empty_async_generator(): @@ -1738,4 +1877,5 @@ async def empty_async_generator(): type="response.completed", sequence_number=-1, response=final_response.model_dump(), - )) + ) + ) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 623b1c863f77..234a31421828 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -10,22 +10,28 @@ from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument, - RerankRequest, RerankResponse, - RerankResult, RerankUsage, - ScoreRequest, ScoreResponse, - ScoreResponseData, UsageInfo) +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + RerankDocument, + RerankRequest, + RerankResponse, + RerankResult, + RerankUsage, + ScoreRequest, + ScoreResponse, + ScoreResponseData, + UsageInfo, +) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.score_utils import (ScoreContentPartParam, - ScoreMultiModalParam, - _cosine_similarity, - _validate_score_input_lens, - compress_token_type_ids, - get_score_prompt) -# yapf: enable +from vllm.entrypoints.score_utils import ( + ScoreContentPartParam, + ScoreMultiModalParam, + _cosine_similarity, + _validate_score_input_lens, + compress_token_type_ids, + get_score_prompt, +) from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger @@ -38,7 +44,6 @@ class ServingScores(OpenAIServing): - def __init__( self, engine_client: EngineClient, @@ -48,11 +53,13 @@ def __init__( request_logger: Optional[RequestLogger], log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - log_error_stack=log_error_stack) + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + log_error_stack=log_error_stack, + ) async def _embedding_score( self, @@ -68,24 +75,23 @@ async def _embedding_score( input_texts = texts_1 + texts_2 engine_prompts: list[TokensPrompt] = [] - tokenize_async = make_async(tokenizer.__call__, - executor=self._tokenizer_executor) + tokenize_async = make_async( + tokenizer.__call__, executor=self._tokenizer_executor + ) tokenization_kwargs = tokenization_kwargs or {} tokenized_prompts = await asyncio.gather( - *(tokenize_async(t, **tokenization_kwargs) for t in input_texts)) + *(tokenize_async(t, **tokenization_kwargs) for t in input_texts) + ) for tok_result, input_text in zip(tokenized_prompts, input_texts): - - text_token_prompt = \ - self._validate_input( - request, - tok_result["input_ids"], - input_text) 
+ text_token_prompt = self._validate_input( + request, tok_result["input_ids"], input_text + ) engine_prompts.append( - TokensPrompt( - prompt_token_ids=text_token_prompt["prompt_token_ids"])) + TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"]) + ) # Schedule the request and get the result generator. generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] @@ -97,13 +103,14 @@ async def _embedding_score( return self.create_error_response(str(e)) for i, engine_prompt in enumerate(engine_prompts): - request_id_item = f"{request_id}-{i}" - self._log_inputs(request_id_item, - input_texts[i], - params=pooling_params, - lora_request=lora_request) + self._log_inputs( + request_id_item, + input_texts[i], + params=pooling_params, + lora_request=lora_request, + ) generators.append( self.engine_client.encode( @@ -113,15 +120,15 @@ async def _embedding_score( lora_request=lora_request, trace_headers=trace_headers, priority=request.priority, - )) + ) + ) result_generator = merge_async_iterators(*generators) # Non-streaming response final_res_batch: list[PoolingRequestOutput] = [] - embeddings: list[Optional[PoolingRequestOutput]] =\ - [None] * len(engine_prompts) + embeddings: list[Optional[PoolingRequestOutput]] = [None] * len(engine_prompts) async for i, res in result_generator: embeddings[i] = res @@ -140,9 +147,9 @@ async def _embedding_score( if len(emb_texts_1) == 1: emb_texts_1 = emb_texts_1 * len(emb_texts_2) - final_res_batch = _cosine_similarity(tokenizer=tokenizer, - embed_1=emb_texts_1, - embed_2=emb_texts_2) + final_res_batch = _cosine_similarity( + tokenizer=tokenizer, embed_1=emb_texts_1, embed_2=emb_texts_2 + ) return final_res_batch @@ -154,7 +161,6 @@ def _preprocess_score( data_1: Union[str, ScoreContentPartParam], data_2: Union[str, ScoreContentPartParam], ) -> tuple[str, TokensPrompt]: - model_config = self.model_config full_prompt, engine_prompt = get_score_prompt( @@ -164,8 +170,7 @@ def _preprocess_score( tokenizer=tokenizer, tokenization_kwargs=tokenization_kwargs, ) - self._validate_input(request, engine_prompt["prompt_token_ids"], - full_prompt) + self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt) if request.mm_processor_kwargs is not None: engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs @@ -189,22 +194,28 @@ async def _cross_encoding_score( data_1 = data_1 * len(data_2) if isinstance(tokenizer, MistralTokenizer): - raise ValueError( - "MistralTokenizer not supported for cross-encoding") + raise ValueError("MistralTokenizer not supported for cross-encoding") tokenization_kwargs = tokenization_kwargs or {} input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - preprocess_async = make_async(self._preprocess_score, - executor=self._tokenizer_executor) + preprocess_async = make_async( + self._preprocess_score, executor=self._tokenizer_executor + ) preprocessed_prompts = await asyncio.gather( - *(preprocess_async(request=request, - tokenizer=tokenizer, - tokenization_kwargs=tokenization_kwargs, - data_1=t1, - data_2=t2) for t1, t2 in input_pairs)) + *( + preprocess_async( + request=request, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + data_1=t1, + data_2=t2, + ) + for t1, t2 in input_pairs + ) + ) for full_prompt, engine_prompt in preprocessed_prompts: request_prompts.append(full_prompt) @@ -223,19 +234,19 @@ async def _cross_encoding_score( for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" - self._log_inputs(request_id_item, - 
request_prompts[i], - params=default_pooling_params, - lora_request=lora_request) + self._log_inputs( + request_id_item, + request_prompts[i], + params=default_pooling_params, + lora_request=lora_request, + ) - if (token_type_ids := engine_prompt.pop("token_type_ids", None)): + if token_type_ids := engine_prompt.pop("token_type_ids", None): pooling_params = default_pooling_params.clone() compressed = compress_token_type_ids(token_type_ids) - pooling_params.extra_kwargs = { - "compressed_token_type_ids": compressed - } + pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed} else: - pooling_params = (default_pooling_params) + pooling_params = default_pooling_params generator = self.engine_client.encode( engine_prompt, @@ -251,8 +262,9 @@ async def _cross_encoding_score( result_generator = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: list[ - Optional[PoolingRequestOutput]] = [None] * len(engine_prompts) + final_res_batch: list[Optional[PoolingRequestOutput]] = [None] * len( + engine_prompts + ) async for i, res in result_generator: final_res_batch[i] = res @@ -271,18 +283,22 @@ async def _run_scoring( tokenizer = await self.engine_client.get_tokenizer() - truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", - None) + truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) tokenization_kwargs: dict[str, Any] = {} - _validate_truncation_size(self.max_model_len, truncate_prompt_tokens, - tokenization_kwargs) + _validate_truncation_size( + self.max_model_len, truncate_prompt_tokens, tokenization_kwargs + ) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) + trace_headers = ( + None + if raw_request is None + else await self._get_trace_headers(raw_request.headers) + ) - if not self.model_config.is_multimodal_model and (isinstance( - data_1, dict) or isinstance(data_2, dict)): + if not self.model_config.is_multimodal_model and ( + isinstance(data_1, dict) or isinstance(data_2, dict) + ): raise ValueError( f"MultiModalParam is not supported for {self.model_config.architecture}" # noqa: E501 ) @@ -308,7 +324,8 @@ async def _run_scoring( request_id=request_id, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - trace_headers=trace_headers) + trace_headers=trace_headers, + ) else: return await self._embedding_score( @@ -319,7 +336,8 @@ async def _run_scoring( request_id=request_id, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - trace_headers=trace_headers) + trace_headers=trace_headers, + ) async def create_score( self, @@ -362,9 +380,7 @@ async def create_score( return self.create_error_response(str(e)) async def do_rerank( - self, - request: RerankRequest, - raw_request: Optional[Request] = None + self, request: RerankRequest, raw_request: Optional[Request] = None ) -> Union[RerankResponse, ErrorResponse]: """ Rerank API based on JinaAI's rerank API; implements the same @@ -381,9 +397,15 @@ async def do_rerank( request_id = f"rerank-{self._base_request_id(raw_request)}" documents = request.documents - top_n = request.top_n if request.top_n > 0 else ( - len(documents) - if isinstance(documents, list) else len(documents["content"])) + top_n = ( + request.top_n + if request.top_n > 0 + else ( + len(documents) + if isinstance(documents, list) + else len(documents["content"]) + ) + ) try: final_res_batch = await self._run_scoring( @@ -445,9 +467,13 @@ def request_output_to_score_response( ) def 
request_output_to_rerank_response( - self, final_res_batch: list[PoolingRequestOutput], request_id: str, - model_name: str, documents: Union[list[str], ScoreMultiModalParam], - top_n: int) -> RerankResponse: + self, + final_res_batch: list[PoolingRequestOutput], + request_id: str, + model_name: str, + documents: Union[list[str], ScoreMultiModalParam], + top_n: int, + ) -> RerankResponse: """ Convert the output of do_rank to a RerankResponse """ @@ -458,9 +484,9 @@ def request_output_to_rerank_response( result = RerankResult( index=idx, - document=RerankDocument(text=documents[idx]) if isinstance( - documents, list) else RerankDocument( - multi_modal=documents["content"][idx]), + document=RerankDocument(text=documents[idx]) + if isinstance(documents, list) + else RerankDocument(multi_modal=documents["content"][idx]), relevance_score=classify_res.outputs.score, ) results.append(result) @@ -476,4 +502,5 @@ def request_output_to_rerank_response( id=request_id, model=model_name, results=results, - usage=RerankUsage(total_tokens=num_prompt_tokens)) + usage=RerankUsage(total_tokens=num_prompt_tokens), + ) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 3918d08ebf81..7b192dcd6c86 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -10,16 +10,15 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption from vllm.entrypoints.logger import RequestLogger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.entrypoints.openai.protocol import (DetokenizeRequest, - DetokenizeResponse, - ErrorResponse, - TokenizeChatRequest, - TokenizeRequest, - TokenizeResponse, - TokenizerInfoResponse) -# yapf: enable +from vllm.entrypoints.openai.protocol import ( + DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeChatRequest, + TokenizeRequest, + TokenizeResponse, + TokenizerInfoResponse, +) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.renderer import RenderConfig @@ -30,7 +29,6 @@ class OpenAIServingTokenization(OpenAIServing): - def __init__( self, engine_client: EngineClient, @@ -40,16 +38,20 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + trust_request_chat_template: bool = False, log_error_stack: bool = False, ) -> None: - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - log_error_stack=log_error_stack) + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + log_error_stack=log_error_stack, + ) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.trust_request_chat_template = trust_request_chat_template async def create_tokenize( self, @@ -69,8 +71,18 @@ async def create_tokenize( renderer = self._get_renderer(tokenizer) if isinstance(request, TokenizeChatRequest): - tool_dicts = (None if request.tools is None else - [tool.model_dump() for tool in request.tools]) + tool_dicts = ( + None + if request.tools is None + else [tool.model_dump() for tool in request.tools] + ) + error_check_ret = self._validate_chat_template( + 
request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + if error_check_ret is not None: + return error_check_ret ( _, _, @@ -81,8 +93,7 @@ async def create_tokenize( request.messages, tool_dicts=tool_dicts, chat_template=request.chat_template or self.chat_template, - chat_template_content_format=self. - chat_template_content_format, + chat_template_content_format=self.chat_template_content_format, add_generation_prompt=request.add_generation_prompt, continue_final_message=request.continue_final_message, chat_template_kwargs=request.chat_template_kwargs, @@ -99,23 +110,23 @@ async def create_tokenize( input_ids: list[int] = [] for engine_prompt in engine_prompts: - self._log_inputs(request_id, - engine_prompt, - params=None, - lora_request=lora_request) + self._log_inputs( + request_id, engine_prompt, params=None, lora_request=lora_request + ) - if isinstance(engine_prompt, - dict) and "prompt_token_ids" in engine_prompt: + if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt: input_ids.extend(engine_prompt["prompt_token_ids"]) token_strs = None if request.return_token_strs: token_strs = tokenizer.convert_ids_to_tokens(input_ids) - return TokenizeResponse(tokens=input_ids, - token_strs=token_strs, - count=len(input_ids), - max_model_len=self.max_model_len) + return TokenizeResponse( + tokens=input_ids, + token_strs=token_strs, + count=len(input_ids), + max_model_len=self.max_model_len, + ) async def create_detokenize( self, @@ -132,10 +143,9 @@ async def create_detokenize( tokenizer = await self.engine_client.get_tokenizer() - self._log_inputs(request_id, - request.tokens, - params=None, - lora_request=lora_request) + self._log_inputs( + request_id, request.tokens, params=None, lora_request=lora_request + ) prompt_input = await self._tokenize_prompt_input_async( request, @@ -147,15 +157,15 @@ async def create_detokenize( return DetokenizeResponse(prompt=input_text) async def get_tokenizer_info( - self, ) -> Union[TokenizerInfoResponse, ErrorResponse]: + self, + ) -> Union[TokenizerInfoResponse, ErrorResponse]: """Get comprehensive tokenizer information.""" try: tokenizer = await self.engine_client.get_tokenizer() info = TokenizerInfo(tokenizer, self.chat_template).to_dict() return TokenizerInfoResponse(**info) except Exception as e: - return self.create_error_response( - f"Failed to get tokenizer info: {str(e)}") + return self.create_error_response(f"Failed to get tokenizer info: {str(e)}") def _build_render_config(self, request: TokenizeRequest) -> RenderConfig: return RenderConfig(add_special_tokens=request.add_special_tokens) diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 9ba58d442522..6cc31c1e08d3 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -9,10 +9,17 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( - ErrorResponse, RequestResponseMetadata, TranscriptionRequest, - TranscriptionResponse, TranscriptionResponseStreamChoice, - TranscriptionStreamResponse, TranslationRequest, TranslationResponse, - TranslationResponseStreamChoice, TranslationStreamResponse) + ErrorResponse, + RequestResponseMetadata, + TranscriptionRequest, + TranscriptionResponse, + TranscriptionResponseStreamChoice, + TranscriptionStreamResponse, + 
TranslationRequest, + TranslationResponse, + TranslationResponseStreamChoice, + TranslationStreamResponse, +) from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.speech_to_text import OpenAISpeechToText from vllm.logger import init_logger @@ -34,19 +41,19 @@ def __init__( return_tokens_as_token_ids: bool = False, log_error_stack: bool = False, ): - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="transcribe", - log_error_stack=log_error_stack) + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="transcribe", + log_error_stack=log_error_stack, + ) async def create_transcription( - self, audio_data: bytes, request: TranscriptionRequest, - raw_request: Request - ) -> Union[TranscriptionResponse, AsyncGenerator[str, None], - ErrorResponse]: + self, audio_data: bytes, request: TranscriptionRequest, raw_request: Request + ) -> Union[TranscriptionResponse, AsyncGenerator[str, None], ErrorResponse]: """Transcription API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/audio/createTranscription @@ -61,10 +68,13 @@ async def create_transcription( ) async def transcription_stream_generator( - self, request: TranscriptionRequest, - result_generator: list[AsyncGenerator[RequestOutput, None]], - request_id: str, request_metadata: RequestResponseMetadata, - audio_duration_s: float) -> AsyncGenerator[str, None]: + self, + request: TranscriptionRequest, + result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, + request_metadata: RequestResponseMetadata, + audio_duration_s: float, + ) -> AsyncGenerator[str, None]: generator = self._speech_to_text_stream_generator( request=request, list_result_generator=result_generator, @@ -92,17 +102,18 @@ def __init__( return_tokens_as_token_ids: bool = False, log_error_stack: bool = False, ): - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="translate", - log_error_stack=log_error_stack) + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + task_type="translate", + log_error_stack=log_error_stack, + ) async def create_translation( - self, audio_data: bytes, request: TranslationRequest, - raw_request: Request + self, audio_data: bytes, request: TranslationRequest, raw_request: Request ) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]: """Translation API similar to OpenAI's API. 
@@ -118,10 +129,13 @@ async def create_translation( ) async def translation_stream_generator( - self, request: TranslationRequest, - result_generator: list[AsyncGenerator[RequestOutput, None]], - request_id: str, request_metadata: RequestResponseMetadata, - audio_duration_s: float) -> AsyncGenerator[str, None]: + self, + request: TranslationRequest, + result_generator: list[AsyncGenerator[RequestOutput, None]], + request_id: str, + request_metadata: RequestResponseMetadata, + audio_duration_s: float, + ) -> AsyncGenerator[str, None]: generator = self._speech_to_text_stream_generator( request=request, list_result_generator=result_generator, diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 965bdac3ac5a..779498b308e8 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -16,12 +16,18 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( - DeltaMessage, ErrorResponse, RequestResponseMetadata, - TranscriptionResponse, TranscriptionResponseStreamChoice, - TranscriptionStreamResponse, TranslationResponse, - TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - SpeechToTextRequest) + DeltaMessage, + ErrorResponse, + RequestResponseMetadata, + TranscriptionResponse, + TranscriptionResponseStreamChoice, + TranscriptionStreamResponse, + TranslationResponse, + TranslationResponseStreamChoice, + TranslationStreamResponse, + UsageInfo, +) +from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.inputs.data import PromptType from vllm.logger import init_logger @@ -41,7 +47,7 @@ class OpenAISpeechToText(OpenAIServing): - """Base class for speech-to-text operations like transcription and + """Base class for speech-to-text operations like transcription and translation.""" def __init__( @@ -55,30 +61,34 @@ def __init__( task_type: Literal["transcribe", "translate"] = "transcribe", log_error_stack: bool = False, ): - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - log_error_stack=log_error_stack) - - self.default_sampling_params = ( - self.model_config.get_diff_sampling_param()) + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + log_error_stack=log_error_stack, + ) + + self.default_sampling_params = self.model_config.get_diff_sampling_param() self.task_type = task_type self.asr_config = self.model_cls.get_speech_to_text_config( - model_config, task_type) + model_config, task_type + ) self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB if self.default_sampling_params: logger.info( "Overwriting default completion sampling param with: %s", - self.default_sampling_params) + self.default_sampling_params, + ) @cached_property def model_cls(self) -> type[SupportsTranscription]: from vllm.model_executor.model_loader import get_model_cls + model_cls = get_model_cls(self.model_config) return cast(type[SupportsTranscription], model_cls) @@ -90,8 +100,11 @@ async def _preprocess_speech_to_text( # Validate request language = 
self.model_cls.validate_language(request.language) # Skip to_language validation to avoid extra logging for Whisper. - to_language = self.model_cls.validate_language(request.to_language) \ - if request.to_language else None + to_language = ( + self.model_cls.validate_language(request.to_language) + if request.to_language + else None + ) if len(audio_data) / 1024**2 > self.max_audio_filesize_mb: raise ValueError("Maximum file size exceeded.") @@ -102,8 +115,10 @@ async def _preprocess_speech_to_text( y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) duration = librosa.get_duration(y=y, sr=sr) - do_split_audio = (self.asr_config.allow_audio_chunking - and duration > self.asr_config.max_audio_clip_s) + do_split_audio = ( + self.asr_config.allow_audio_chunking + and duration > self.asr_config.max_audio_clip_s + ) chunks = [y] if not do_split_audio else self._split_audio(y, int(sr)) prompts = [] for chunk in chunks: @@ -129,7 +144,7 @@ async def _create_speech_to_text( response_class: type[T], stream_generator_method: Callable[..., AsyncGenerator[str, None]], ) -> Union[T, AsyncGenerator[str, None], ErrorResponse]: - """Base method for speech-to-text operations like transcription and + """Base method for speech-to-text operations like transcription and translation.""" error_check_ret = await self._check_model(request) if error_check_ret is not None: @@ -141,9 +156,10 @@ async def _create_speech_to_text( if self.engine_client.errored: raise self.engine_client.dead_error - if request.response_format not in ['text', 'json']: + if request.response_format not in ["text", "json"]: return self.create_error_response( - "Currently only support response_format `text` or `json`") + "Currently only support response_format `text` or `json`" + ) request_id = f"{self.task_type}-{self._base_request_id(raw_request)}" @@ -156,8 +172,8 @@ async def _create_speech_to_text( if lora_request: return self.create_error_response( - "Currently do not support LoRA for " - f"{self.task_type.title()}.") + f"Currently do not support LoRA for {self.task_type.title()}." + ) prompts, duration_s = await self._preprocess_speech_to_text( request=request, @@ -168,38 +184,42 @@ async def _create_speech_to_text( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) - list_result_generator: Optional[list[AsyncGenerator[RequestOutput, - None]]] = None + list_result_generator: Optional[list[AsyncGenerator[RequestOutput, None]]] = ( + None + ) try: # Unlike most decoder-only models, whisper generation length is not # constrained by the size of the input audio, which is mapped to a # fixed-size log-mel-spectogram. 
default_max_tokens = self.model_config.max_model_len sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) + default_max_tokens, self.default_sampling_params + ) self._log_inputs( request_id, # It will not display special tokens like <|startoftranscript|> request.prompt, params=sampling_params, - lora_request=None) + lora_request=None, + ) list_result_generator = [ self.engine_client.generate( prompt, sampling_params, request_id, - ) for prompt in prompts + ) + for prompt in prompts ] except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) if request.stream: - return stream_generator_method(request, list_result_generator, - request_id, request_metadata, - duration_s) + return stream_generator_method( + request, list_result_generator, request_id, request_metadata, duration_s + ) # Non-streaming response. try: assert list_result_generator is not None @@ -215,12 +235,10 @@ async def _create_speech_to_text( # rounded up as per openAI specs "seconds": int(math.ceil(duration_s)), } - final_response = cast(T, response_class(text=text, - usage=usage)) + final_response = cast(T, response_class(text=text, usage=usage)) else: # no usage in response for translation task - final_response = cast( - T, response_class(text=text)) # type: ignore[call-arg] + final_response = cast(T, response_class(text=text)) # type: ignore[call-arg] return final_response except asyncio.CancelledError: @@ -239,9 +257,11 @@ async def _speech_to_text_stream_generator( chunk_object_type: Literal["translation.chunk", "transcription.chunk"], response_stream_choice_class: Union[ type[TranscriptionResponseStreamChoice], - type[TranslationResponseStreamChoice]], - stream_response_class: Union[type[TranscriptionStreamResponse], - type[TranslationStreamResponse]], + type[TranslationResponseStreamChoice], + ], + stream_response_class: Union[ + type[TranscriptionStreamResponse], type[TranslationStreamResponse] + ], ) -> AsyncGenerator[str, None]: created_time = int(time.time()) model_name = request.model @@ -249,11 +269,14 @@ async def _speech_to_text_stream_generator( completion_tokens = 0 num_prompt_tokens = 0 - include_usage = request.stream_include_usage \ - if request.stream_include_usage else False - include_continuous_usage = request.stream_continuous_usage_stats\ - if include_usage and request.stream_continuous_usage_stats\ + include_usage = ( + request.stream_include_usage if request.stream_include_usage else False + ) + include_continuous_usage = ( + request.stream_continuous_usage_stats + if include_usage and request.stream_continuous_usage_stats else False + ) try: for result_generator in list_result_generator: @@ -262,8 +285,8 @@ async def _speech_to_text_stream_generator( if res.prompt_token_ids is not None: num_prompt_tokens = len(res.prompt_token_ids) if audio_tokens := self.model_cls.get_num_audio_tokens( - audio_duration_s, self.asr_config, - self.model_config): + audio_duration_s, self.asr_config, self.model_config + ): num_prompt_tokens += audio_tokens # We need to do it here, because if there are exceptions in @@ -279,20 +302,22 @@ async def _speech_to_text_stream_generator( if output.finish_reason is None: # Still generating, send delta update. - choice_data = response_stream_choice_class( - delta=delta_message) + choice_data = response_stream_choice_class(delta=delta_message) else: # Model is finished generating. 
choice_data = response_stream_choice_class( delta=delta_message, finish_reason=output.finish_reason, - stop_reason=output.stop_reason) + stop_reason=output.stop_reason, + ) - chunk = stream_response_class(id=request_id, - object=chunk_object_type, - created=created_time, - choices=[choice_data], - model=model_name) + chunk = stream_response_class( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name, + ) # handle usage stats if requested & if continuous if include_continuous_usage: @@ -308,10 +333,11 @@ async def _speech_to_text_stream_generator( # Once the final token is handled, if stream_options.include_usage # is sent, send the usage. if include_usage: - final_usage = UsageInfo(prompt_tokens=num_prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + - completion_tokens) + final_usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) final_usage_chunk = stream_response_class( id=request_id, @@ -319,16 +345,19 @@ async def _speech_to_text_stream_generator( created=created_time, choices=[], model=model_name, - usage=final_usage) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True)) + usage=final_usage, + ) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True + ) yield f"data: {final_usage_data}\n\n" # report to FastAPI middleware aggregate usage across all choices request_metadata.final_usage_info = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=completion_tokens, - total_tokens=num_prompt_tokens + completion_tokens) + total_tokens=num_prompt_tokens + completion_tokens, + ) except Exception as e: # TODO: Use a vllm-specific Validation Error @@ -338,8 +367,9 @@ async def _speech_to_text_stream_generator( # Send the final done message after all response.n are finished yield "data: [DONE]\n\n" - def _split_audio(self, audio_data: np.ndarray, - sample_rate: int) -> list[np.ndarray]: + def _split_audio( + self, audio_data: np.ndarray, sample_rate: int + ) -> list[np.ndarray]: chunk_size = sample_rate * self.asr_config.max_audio_clip_s overlap_size = sample_rate * self.asr_config.overlap_chunk_second chunks = [] @@ -353,17 +383,15 @@ def _split_audio(self, audio_data: np.ndarray, # Find the best split point in the overlap region search_start = i + chunk_size - overlap_size search_end = min(i + chunk_size, audio_data.shape[-1]) - split_point = self._find_split_point(audio_data, search_start, - search_end) + split_point = self._find_split_point(audio_data, search_start, search_end) # Extract chunk up to the split point chunks.append(audio_data[..., i:split_point]) i = split_point return chunks - def _find_split_point(self, wav: np.ndarray, start_idx: int, - end_idx: int) -> int: - """Find the best point to split audio by + def _find_split_point(self, wav: np.ndarray, start_idx: int, end_idx: int) -> int: + """Find the best point to split audio by looking for silence or low amplitude. 
Args: wav: Audio tensor [1, T] @@ -380,8 +408,8 @@ def _find_split_point(self, wav: np.ndarray, start_idx: int, min_energy_window = self.asr_config.min_energy_split_window_size assert min_energy_window is not None for i in range(0, len(segment) - min_energy_window, min_energy_window): - window = segment[i:i + min_energy_window] - energy = (window**2).mean()**0.5 + window = segment[i : i + min_energy_window] + energy = (window**2).mean() ** 0.5 if energy < min_energy: quietest_idx = i + start_idx min_energy = energy diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 02aeab613631..e6ee2fa777f8 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -6,9 +6,11 @@ from functools import cached_property from typing import Callable, Optional, Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ExtractedToolCallInformation) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import import_from_path, is_list_of @@ -38,16 +40,15 @@ def vocab(self) -> dict[str, int]: # whereas all tokenizers have .get_vocab() return self.model_tokenizer.get_vocab() - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: """ Static method that used to adjust the request parameters. """ return request def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Static method that should be implemented for extracting tool calls from a complete model-generated string. @@ -56,7 +57,8 @@ def extract_tool_calls( Static because it's stateless. """ raise NotImplementedError( - "AbstractToolParser.extract_tool_calls has not been implemented!") + "AbstractToolParser.extract_tool_calls has not been implemented!" + ) def extract_tool_calls_streaming( self, @@ -76,8 +78,8 @@ def extract_tool_calls_streaming( previously been parsed and extracted (see constructor) """ raise NotImplementedError( - "AbstractToolParser.extract_tool_calls_streaming has not been " - "implemented!") + "AbstractToolParser.extract_tool_calls_streaming has not been implemented!" 
+ ) class ToolParserManager: @@ -96,13 +98,15 @@ def get_tool_parser(cls, name) -> type: raise KeyError(f"tool helper: '{name}' not found in tool_parsers") @classmethod - def _register_module(cls, - module: type, - module_name: Optional[Union[str, list[str]]] = None, - force: bool = True) -> None: + def _register_module( + cls, + module: type, + module_name: Optional[Union[str, list[str]]] = None, + force: bool = True, + ) -> None: if not issubclass(module, ToolParser): raise TypeError( - f'module must be subclass of ToolParser, but got {type(module)}' + f"module must be subclass of ToolParser, but got {type(module)}" ) if module_name is None: module_name = module.__name__ @@ -111,30 +115,32 @@ def _register_module(cls, for name in module_name: if not force and name in cls.tool_parsers: existed_module = cls.tool_parsers[name] - raise KeyError(f'{name} is already registered ' - f'at {existed_module.__module__}') + raise KeyError( + f"{name} is already registered at {existed_module.__module__}" + ) cls.tool_parsers[name] = module @classmethod def register_module( - cls, - name: Optional[Union[str, list[str]]] = None, - force: bool = True, - module: Union[type, None] = None) -> Union[type, Callable]: + cls, + name: Optional[Union[str, list[str]]] = None, + force: bool = True, + module: Union[type, None] = None, + ) -> Union[type, Callable]: """ Register module with the given name or name list. it can be used as a - decoder(with module as None) or normal function(with module as not + decoder(with module as None) or normal function(with module as not None). """ if not isinstance(force, bool): - raise TypeError(f'force must be a boolean, but got {type(force)}') + raise TypeError(f"force must be a boolean, but got {type(force)}") # raise the error ahead of time - if not (name is None or isinstance(name, str) - or is_list_of(name, str)): + if not (name is None or isinstance(name, str) or is_list_of(name, str)): raise TypeError( - 'name must be None, an instance of str, or a sequence of str, ' - f'but got {type(name)}') + "name must be None, an instance of str, or a sequence of str, " + f"but got {type(name)}" + ) # use it as a normal method: x.register_module(module=SomeClass) if module is not None: @@ -159,6 +165,7 @@ def import_tool_parser(cls, plugin_path: str) -> None: try: import_from_path(module_name, plugin_path) except Exception: - logger.exception("Failed to load module '%s' from %s.", - module_name, plugin_path) + logger.exception( + "Failed to load module '%s' from %s.", module_name, plugin_path + ) return diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py index 09095f899177..c6e8f1686e24 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py @@ -7,13 +7,19 @@ import regex as re from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from 
vllm.transformers_utils.tokenizer import AnyTokenizer @@ -22,15 +28,15 @@ @ToolParserManager.register_module("deepseek_v31") class DeepSeekV31ToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = ( - []) # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_calls_start_token: str = "<|tool▁calls▁begin|>" self.tool_calls_end_token: str = "<|tool▁calls▁end|>" @@ -43,41 +49,43 @@ def __init__(self, tokenizer: AnyTokenizer): ) self.stream_tool_call_portion_regex = re.compile( - r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)") + r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)" + ) self.stream_tool_call_name_regex = re.compile( - r"(?P<function_name>.*)<|tool▁sep|>") + r"(?P<function_name>.*)<|tool▁sep|>" + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - self.tool_calls_start_token_id = self.vocab.get( - self.tool_calls_start_token) - self.tool_calls_end_token_id = self.vocab.get( - self.tool_calls_end_token) - - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + "constructor during construction." + ) + self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) + + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_calls_start_token_id is None - or self.tool_calls_end_token_id is None): + if ( + self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None + ): raise RuntimeError( "DeepSeek-V3.1 Tool parser could not locate tool call " - "start/end tokens in the tokenizer!") + "start/end tokens in the tokenizer!" + ) def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - # sanity check; avoid unnecessary processing if self.tool_calls_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: try: @@ -85,8 +93,7 @@ def extract_tool_calls( # tag and end-of-string so the result of # findall is an array of tuples where one is a function call and # the other is None - function_call_tuples = self.tool_call_regex.findall( - model_output) + function_call_tuples = self.tool_call_regex.findall(model_output) tool_calls = [] for match in function_call_tuples: @@ -94,12 +101,13 @@ def extract_tool_calls( tool_calls.append( ToolCall( type="function", - function=FunctionCall(name=function_name, - arguments=function_args), - )) + function=FunctionCall( + name=function_name, arguments=function_args + ), + ) + ) - content = model_output[:model_output. 
- find(self.tool_calls_start_token)] + content = model_output[: model_output.find(self.tool_calls_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -107,11 +115,10 @@ def extract_tool_calls( ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -123,55 +130,58 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a if self.tool_calls_start_token_id not in current_token_ids: logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) - delta_text = delta_text.replace(self.tool_calls_start_token, - "").replace(self.tool_calls_end_token, - "") + delta_text = delta_text.replace(self.tool_calls_start_token, "").replace( + self.tool_calls_end_token, "" + ) try: - # figure out where we are in the parsing by counting tool call # start & end tags prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) tool_call_portion = None text_portion = None # case: if we're generating text, OR rounding out a tool call - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and self.tool_call_end_token not in delta_text): + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text + ): logger.debug("Generating text content! 
skipping tool parsing.") return DeltaMessage(content=delta_text) if self.tool_call_end_token in delta_text: logger.debug("tool_call_end_token in delta_text") full_text = current_text + delta_text - tool_call_portion = full_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0].rstrip() - delta_text = delta_text.split( - self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split( - self.tool_call_end_token)[-1].lstrip() + tool_call_portion = ( + full_text.split(self.tool_call_start_token)[-1] + .split(self.tool_call_end_token)[0] + .rstrip() + ) + delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() # case -- we're starting a new tool call - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): + if ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count + ): if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[ + -1 + ] else: tool_call_portion = None delta = None @@ -185,27 +195,29 @@ def extract_tool_calls_streaming( logger.debug("Starting on a new tool %s", self.current_tool_id) # case -- we're updating an existing tool call - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - + elif ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count + ): # get the portion of the text that's the tool call - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] text_portion = None # case -- the current tool call is being closed. 
- elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count): - if self.prev_tool_call_arr is None or len( - self.prev_tool_call_arr) == 0: - logger.debug( - "attempting to close tool call, but no tool call") + elif ( + cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count + ): + if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: + logger.debug("attempting to close tool call, but no tool call") return None - diff = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: - diff = (diff.encode("utf-8").decode("unicode_escape") - if diff is str else diff) + diff = ( + diff.encode("utf-8").decode("unicode_escape") + if diff is str + else diff + ) if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') @@ -216,13 +228,16 @@ def extract_tool_calls_streaming( diff, ) self.streamed_args_for_tool[self.current_tool_id] += diff - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump(exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) # case -- otherwise we're just generating text else: @@ -233,17 +248,17 @@ def extract_tool_calls_streaming( current_tool_call = dict() if tool_call_portion: - current_tool_call_matches = ( - self.stream_tool_call_portion_regex.match( - tool_call_portion)) + current_tool_call_matches = self.stream_tool_call_portion_regex.match( + tool_call_portion + ) if current_tool_call_matches: tool_name, tool_args = current_tool_call_matches.groups() current_tool_call["name"] = tool_name current_tool_call["arguments"] = tool_args else: current_tool_call_name_matches = ( - self.stream_tool_call_name_regex.match( - tool_call_portion)) + self.stream_tool_call_name_regex.match(tool_call_portion) + ) if current_tool_call_name_matches: tool_name = current_tool_call_name_matches.groups() current_tool_call["name"] = tool_name @@ -260,16 +275,18 @@ def extract_tool_calls_streaming( function_name: Union[str, None] = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) else: return None @@ -279,15 +296,19 @@ def extract_tool_calls_streaming( if tool_call_portion is None: # if there's text but not tool calls, send that - # otherwise None to skip chunk - delta = (DeltaMessage( - content=delta_text) if text_portion is not None else None) + delta = ( + DeltaMessage(content=delta_text) + if text_portion is not None + else None + ) return delta # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. 
- logger.debug("Trying to parse current tool call with ID %s", - self.current_tool_id) + logger.debug( + "Trying to parse current tool call with ID %s", self.current_tool_id + ) # if we're starting a new tool call, push an empty object in as # a placeholder for the arguments @@ -297,7 +318,8 @@ def extract_tool_calls_streaming( # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + "arguments" + ) cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -311,52 +333,56 @@ def extract_tool_calls_streaming( # case -- prev arguments are defined, but non are now. # probably impossible, but not a fatal error - just keep going elif not cur_arguments and prev_arguments: - logger.error("should be impossible to have arguments reset " - "mid-call. skipping streaming anything.") + logger.error( + "should be impossible to have arguments reset " + "mid-call. skipping streaming anything." + ) delta = None # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=cur_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments # last case -- we have an update to existing arguments. 
elif cur_arguments and prev_arguments: - if (isinstance(delta_text, str) - and cur_arguments != prev_arguments - and len(cur_arguments) > len(prev_arguments) - and cur_arguments.startswith(prev_arguments)): - delta_arguments = cur_arguments[len(prev_arguments):] + if ( + isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments) + ): + delta_arguments = cur_arguments[len(prev_arguments) :] logger.debug("got diff %s", delta_text) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=delta_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments else: delta = None # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: self.prev_tool_call_arr.append(current_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py index ac272b0c3b20..e8a5d2e6dc13 100644 --- a/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py @@ -7,13 +7,19 @@ import regex as re from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -22,15 +28,15 @@ @ToolParserManager.register_module("deepseek_v3") class DeepSeekV3ToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = ( - []) # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_calls_start_token: str = "<|tool▁calls▁begin|>" self.tool_calls_end_token: str = "<|tool▁calls▁end|>" @@ -47,38 +53,39 @@ def __init__(self, tokenizer: AnyTokenizer): ) self.stream_tool_call_name_regex = re.compile( - r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n") + r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n" + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - self.tool_calls_start_token_id = self.vocab.get( - self.tool_calls_start_token) - 
self.tool_calls_end_token_id = self.vocab.get( - self.tool_calls_end_token) - - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + "constructor during construction." + ) + self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) + + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_calls_start_token_id is None - or self.tool_calls_end_token_id is None): + if ( + self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None + ): raise RuntimeError( "DeepSeek-V3 Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" + ) def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - # sanity check; avoid unnecessary processing if self.tool_calls_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: try: @@ -86,8 +93,7 @@ def extract_tool_calls( # tag and end-of-string so the result of # findall is an array of tuples where one is a function call and # the other is None - function_call_tuples = self.tool_call_regex.findall( - model_output) + function_call_tuples = self.tool_call_regex.findall(model_output) tool_calls = [] for match in function_call_tuples: @@ -95,12 +101,13 @@ def extract_tool_calls( tool_calls.append( ToolCall( type=tool_type, - function=FunctionCall(name=function_name, - arguments=function_args), - )) + function=FunctionCall( + name=function_name, arguments=function_args + ), + ) + ) - content = model_output[:model_output. 
- find(self.tool_calls_start_token)] + content = model_output[: model_output.find(self.tool_calls_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -108,11 +115,10 @@ def extract_tool_calls( ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -124,55 +130,58 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a if self.tool_calls_start_token_id not in current_token_ids: logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) - delta_text = delta_text.replace(self.tool_calls_start_token, - "").replace(self.tool_calls_end_token, - "") + delta_text = delta_text.replace(self.tool_calls_start_token, "").replace( + self.tool_calls_end_token, "" + ) try: - # figure out where we are in the parsing by counting tool call # start & end tags prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) tool_call_portion = None text_portion = None # case: if we're generating text, OR rounding out a tool call - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and self.tool_call_end_token not in delta_text): + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text + ): logger.debug("Generating text content! 
skipping tool parsing.") return DeltaMessage(content=delta_text) if self.tool_call_end_token in delta_text: logger.debug("tool_call_end_token in delta_text") full_text = current_text + delta_text - tool_call_portion = full_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0].rstrip() - delta_text = delta_text.split( - self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split( - self.tool_call_end_token)[-1].lstrip() + tool_call_portion = ( + full_text.split(self.tool_call_start_token)[-1] + .split(self.tool_call_end_token)[0] + .rstrip() + ) + delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() # case -- we're starting a new tool call - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): + if ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count + ): if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[ + -1 + ] else: tool_call_portion = None delta = None @@ -186,27 +195,29 @@ def extract_tool_calls_streaming( logger.debug("Starting on a new tool %s", self.current_tool_id) # case -- we're updating an existing tool call - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - + elif ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count + ): # get the portion of the text that's the tool call - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] text_portion = None # case -- the current tool call is being closed. 
- elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count): - if self.prev_tool_call_arr is None or len( - self.prev_tool_call_arr) == 0: - logger.debug( - "attempting to close tool call, but no tool call") + elif ( + cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count + ): + if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: + logger.debug("attempting to close tool call, but no tool call") return None - diff = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: - diff = (diff.encode("utf-8").decode("unicode_escape") - if diff is str else diff) + diff = ( + diff.encode("utf-8").decode("unicode_escape") + if diff is str + else diff + ) if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') @@ -217,13 +228,16 @@ def extract_tool_calls_streaming( diff, ) self.streamed_args_for_tool[self.current_tool_id] += diff - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump(exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) # case -- otherwise we're just generating text else: @@ -234,21 +248,19 @@ def extract_tool_calls_streaming( current_tool_call = dict() if tool_call_portion: - current_tool_call_matches = ( - self.stream_tool_call_portion_regex.match( - tool_call_portion)) + current_tool_call_matches = self.stream_tool_call_portion_regex.match( + tool_call_portion + ) if current_tool_call_matches: - tool_type, tool_name, tool_args = ( - current_tool_call_matches.groups()) + tool_type, tool_name, tool_args = current_tool_call_matches.groups() current_tool_call["name"] = tool_name current_tool_call["arguments"] = tool_args else: current_tool_call_name_matches = ( - self.stream_tool_call_name_regex.match( - tool_call_portion)) + self.stream_tool_call_name_regex.match(tool_call_portion) + ) if current_tool_call_name_matches: - tool_type, tool_name = ( - current_tool_call_name_matches.groups()) + tool_type, tool_name = current_tool_call_name_matches.groups() current_tool_call["name"] = tool_name current_tool_call["arguments"] = "" else: @@ -263,16 +275,18 @@ def extract_tool_calls_streaming( function_name: Union[str, None] = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) else: return None @@ -282,15 +296,19 @@ def extract_tool_calls_streaming( if tool_call_portion is None: # if there's text but not tool calls, send that - # otherwise None to skip chunk - delta = (DeltaMessage( - content=delta_text) if text_portion is not None else None) + delta = ( + DeltaMessage(content=delta_text) + if text_portion is not None + else None + ) return delta # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. 
- logger.debug("Trying to parse current tool call with ID %s", - self.current_tool_id) + logger.debug( + "Trying to parse current tool call with ID %s", self.current_tool_id + ) # if we're starting a new tool call, push an empty object in as # a placeholder for the arguments @@ -300,7 +318,8 @@ def extract_tool_calls_streaming( # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + "arguments" + ) cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -314,52 +333,56 @@ def extract_tool_calls_streaming( # case -- prev arguments are defined, but non are now. # probably impossible, but not a fatal error - just keep going elif not cur_arguments and prev_arguments: - logger.error("should be impossible to have arguments reset " - "mid-call. skipping streaming anything.") + logger.error( + "should be impossible to have arguments reset " + "mid-call. skipping streaming anything." + ) delta = None # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=cur_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments # last case -- we have an update to existing arguments. 
elif cur_arguments and prev_arguments: - if (isinstance(delta_text, str) - and cur_arguments != prev_arguments - and len(cur_arguments) > len(prev_arguments) - and cur_arguments.startswith(prev_arguments)): - delta_arguments = cur_arguments[len(prev_arguments):] + if ( + isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments) + ): + delta_arguments = cur_arguments[len(prev_arguments) :] logger.debug("got diff %s", delta_text) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=delta_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments else: delta = None # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: self.prev_tool_call_arr.append(current_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py index 8fd14f171d0a..1d7d7d3f8629 100644 --- a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py @@ -8,14 +8,20 @@ import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -24,7 +30,6 @@ @ToolParserManager.register_module("glm45") class Glm4MoeModelToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent = False @@ -36,20 +41,20 @@ def __init__(self, tokenizer: AnyTokenizer): self.tool_calls_start_token = self.tool_call_start_token - self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", - re.DOTALL) + self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL) self.func_detail_regex = re.compile( - r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL) + r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL + ) self.func_arg_regex = re.compile( - r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", - re.DOTALL) + r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") + "constructor during construction." 
+ ) - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) self._buffer = "" @@ -58,18 +63,22 @@ def extract_tool_calls( model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - def _is_string_type( - tool_name: str, arg_name: str, - tools: Optional[list[ChatCompletionToolsParam]]) -> bool: + tool_name: str, + arg_name: str, + tools: Optional[list[ChatCompletionToolsParam]], + ) -> bool: if tools is None: return False for tool in tools: if tool.function.name == tool_name: if tool.function.parameters is None: return False - arg_type = tool.function.parameters.get( - "properties", {}).get(arg_name, {}).get("type", None) + arg_type = ( + tool.function.parameters.get("properties", {}) + .get(arg_name, {}) + .get("type", None) + ) return arg_type == "string" logger.warning("No tool named '%s'.", tool_name) return False @@ -101,28 +110,30 @@ def _deserialize(value: str) -> Any: arg_val = value.strip() if not _is_string_type(tc_name, arg_key, request.tools): arg_val = _deserialize(arg_val) - logger.debug("arg_key = %s, arg_val = %s", arg_key, - arg_val) + logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val) arg_dct[arg_key] = arg_val tool_calls.append( - ToolCall(type="function", - function=FunctionCall( - name=tc_name, arguments=json.dumps(arg_dct)))) + ToolCall( + type="function", + function=FunctionCall( + name=tc_name, arguments=json.dumps(arg_dct) + ), + ) + ) except Exception: logger.exception("Failed to extract tool call spec") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: if len(tool_calls) > 0: - content = model_output[:model_output. 
- find(self.tool_calls_start_token)] - return ExtractedToolCallInformation(tools_called=True, - tool_calls=tool_calls, - content=content) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + content = model_output[: model_output.find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=content + ) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -155,7 +166,8 @@ def extract_tool_calls_streaming( self.streamed_args_for_tool.append("") extracted_tool_calls = self.extract_tool_calls( - cur_text[:end_idx + len(self.tool_call_end_token)], request) + cur_text[: end_idx + len(self.tool_call_end_token)], request + ) if len(extracted_tool_calls.tool_calls) == 0: logger.warning("Failed to extract any tool calls.") @@ -163,22 +175,27 @@ def extract_tool_calls_streaming( tool_call = extracted_tool_calls.tool_calls[0] self.prev_tool_call_arr[self.current_tool_id] = { "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments) + "arguments": json.loads(tool_call.function.arguments), } - self.streamed_args_for_tool[ - self.current_tool_id] = tool_call.function.arguments + self.streamed_args_for_tool[self.current_tool_id] = ( + tool_call.function.arguments + ) delta = DeltaMessage( content=extracted_tool_calls.content, tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - id=tool_call.id, - type=tool_call.type, - function=DeltaFunctionCall( - name=tool_call.function.name, - arguments=tool_call.function.arguments)) - ]) + DeltaToolCall( + index=self.current_tool_id, + id=tool_call.id, + type=tool_call.type, + function=DeltaFunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) + ], + ) self.current_tool_id += 1 - self._buffer = cur_text[end_idx + len(self.tool_call_end_token):] + self._buffer = cur_text[end_idx + len(self.tool_call_end_token) :] return delta self._buffer = cur_text[start_idx:] diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py index 824b100f357b..c42b358b1e34 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py @@ -11,17 +11,25 @@ from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, - find_common_prefix, - is_complete_json, - partial_json_loads) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import ( + consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -47,12 +55,12 @@ def __init__(self, tokenizer: AnyTokenizer): self.tool_call_regex = 
re.compile(r"<function_call>\s*") def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: if self.tool_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) dec = JSONDecoder() try: @@ -66,13 +74,15 @@ def extract_tool_calls( start_of_json = match.end() # end_index == the start of the next function call # (if exists) - next_function_call_start = (matches[i + 1].start() if i + - 1 < len(matches) else None) + next_function_call_start = ( + matches[i + 1].start() if i + 1 < len(matches) else None + ) raw_function_calls.append( dec.raw_decode( - model_output[start_of_json:next_function_call_start]) - [0]) + model_output[start_of_json:next_function_call_start] + )[0] + ) logger.debug("Extracted %d tool calls", len(raw_function_calls)) tool_calls = [ @@ -81,13 +91,15 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"], - ensure_ascii=False), + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), ), - ) for function_call in raw_function_calls + ) + for function_call in raw_function_calls ] - content = model_output[:model_output.find(self.bot_token)] + content = model_output[: model_output.find(self.bot_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -96,9 +108,9 @@ def extract_tool_calls( except Exception as e: logger.error("Error in extracting tool call from response %s", e) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -110,9 +122,9 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - - if len(current_text) < len( - self.bot_token) and self.bot_token.startswith(current_text): + if len(current_text) < len(self.bot_token) and self.bot_token.startswith( + current_text + ): return None if not current_text.startswith(self.bot_token): @@ -122,8 +134,7 @@ def extract_tool_calls_streaming( # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. 
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: tool_call_arr = [] is_complete = [] @@ -132,24 +143,23 @@ def extract_tool_calls_streaming( start_idx = consume_space(start_idx, current_text) while start_idx < len(current_text): - (obj, - end_idx) = partial_json_loads(current_text[start_idx:], - flags) + (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags) is_complete.append( - is_complete_json(current_text[start_idx:start_idx + - end_idx])) + is_complete_json(current_text[start_idx : start_idx + end_idx]) + ) start_idx += end_idx start_idx = consume_space(start_idx, current_text) start_idx += len(self.bot_token) start_idx = consume_space(start_idx, current_text) tool_call_arr.append(obj) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # select as the current tool call the one we're on the state at - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} + current_tool_call: dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) # case -- if no tokens have been streamed for the tool, e.g. # only the array brackets, stream nothing @@ -158,9 +168,9 @@ def extract_tool_calls_streaming( # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -168,21 +178,24 @@ def extract_tool_calls_streaming( if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) else: delta = None else: @@ -199,15 +212,18 @@ def extract_tool_calls_streaming( elif not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True else: delta = None @@ -219,34 +235,35 @@ def extract_tool_calls_streaming( delta = None if cur_arguments: - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) argument_diff = None if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) if cur_args_json != prev_args_json: - - prefix = find_common_prefix( - prev_args_json, cur_args_json) + prefix = find_common_prefix(prev_args_json, cur_args_json) argument_diff = prefix[sent:] if argument_diff is not None: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) self.prev_tool_call_arr = tool_call_arr return delta @@ -254,6 +271,6 @@ def extract_tool_calls_streaming( except Exception as e: logger.error("Error trying to handle streaming tool call: %s", e) logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py index ac517616a95b..989973923ae5 100644 --- a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py @@ -9,17 +9,25 @@ from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import (consume_space, - find_common_prefix, - is_complete_json, - partial_json_loads) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import ( + consume_space, + find_common_prefix, + is_complete_json, + partial_json_loads, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -45,21 +53,24 @@ def __init__(self, tokenizer: AnyTokenizer): self.bot_string = "<tool_call>" def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: - stripped = model_output.strip()\ - .removeprefix(self.bot_token)\ - .removeprefix(self.bot_string)\ - .lstrip() - if not stripped or stripped[0] != '[': - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: + stripped = ( + model_output.strip() + .removeprefix(self.bot_token) + .removeprefix(self.bot_string) + .lstrip() + ) + if not stripped or stripped[0] != "[": + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: raw_function_calls = json.loads(stripped) if not isinstance(raw_function_calls, list): raise Exception( - f"Expected dict or list, got {type(raw_function_calls)}") + f"Expected dict or list, got {type(raw_function_calls)}" + ) logger.debug("Extracted %d tool calls", len(raw_function_calls)) tool_calls = [ @@ -68,10 +79,12 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"], - ensure_ascii=False), + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), ), - ) for function_call in raw_function_calls + ) + for 
function_call in raw_function_calls ] return ExtractedToolCallInformation( @@ -82,9 +95,9 @@ def extract_tool_calls( except Exception as e: logger.error("Error in extracting tool call from response %s", e) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -96,41 +109,40 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - start_idx = consume_space(0, current_text) if current_text[start_idx:].startswith(self.bot_token): - start_idx = consume_space(start_idx + len(self.bot_token), - current_text) + start_idx = consume_space(start_idx + len(self.bot_token), current_text) if current_text[start_idx:].startswith(self.bot_string): - start_idx = consume_space(start_idx + len(self.bot_string), - current_text) - if not current_text or start_idx >= len(current_text)\ - or current_text[start_idx] != '[': + start_idx = consume_space(start_idx + len(self.bot_string), current_text) + if ( + not current_text + or start_idx >= len(current_text) + or current_text[start_idx] != "[" + ): return DeltaMessage(content=delta_text) # bit mask flags for partial JSON parsing. If the name hasn't been # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: tool_call_arr = None is_complete = None try: tool_calls, end_idx = partial_json_loads( - current_text[start_idx:], flags) + current_text[start_idx:], flags + ) if type(tool_calls) is list: tool_call_arr = tool_calls else: return DeltaMessage(content=delta_text) is_complete = [True] * len(tool_calls) - if not is_complete_json( - current_text[start_idx:start_idx + end_idx]): + if not is_complete_json(current_text[start_idx : start_idx + end_idx]): is_complete[-1] = False except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # case -- if no tokens have been streamed for the tool, e.g. @@ -145,7 +157,6 @@ def extract_tool_calls_streaming( # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor if len(tool_call_arr) > self.current_tool_id + 1: - # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -153,21 +164,24 @@ def extract_tool_calls_streaming( if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) # re-set stuff pertaining to progress in the current tool self.current_tool_id = len(tool_call_arr) - 1 @@ -181,15 +195,18 @@ def extract_tool_calls_streaming( elif not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True # now we know we're on the same tool call and we're streaming @@ -198,33 +215,35 @@ def extract_tool_calls_streaming( cur_arguments = current_tool_call.get("arguments") if cur_arguments: - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) argument_diff = None if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) if cur_args_json != prev_args_json: - prefix = find_common_prefix( - prev_args_json, cur_args_json) + prefix = find_common_prefix(prev_args_json, cur_args_json) argument_diff = prefix[sent:] if argument_diff is not None: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) self.prev_tool_call_arr = tool_call_arr return delta @@ -232,6 +251,6 @@ def extract_tool_calls_streaming( except Exception as e: logger.error("Error trying to handle streaming tool call: %s", e) logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index 87595953da06..4529eb51796e 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -10,13 +10,19 @@ from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -25,37 +31,41 @@ @ToolParserManager.register_module("hermes") class Hermes2ProToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) if isinstance(self.model_tokenizer, MistralTokenizer): - logger.error( - "Detected Mistral tokenizer when using a Hermes model") + logger.error("Detected Mistral tokenizer when using a Hermes model") self.model_tokenizer = self.model_tokenizer.tokenizer self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_call_start_token: str = "<tool_call>" self.tool_call_end_token: str = "</tool_call>" self.tool_call_regex = re.compile( - r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL) + r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL + ) self.scratch_pad_regex = re.compile( - r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL) + r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") + "constructor during construction." 
+ ) self.tool_call_start_token_ids = self.model_tokenizer.encode( - self.tool_call_start_token, add_special_tokens=False) + self.tool_call_start_token, add_special_tokens=False + ) self.tool_call_end_token_ids = self.model_tokenizer.encode( - self.tool_call_end_token, add_special_tokens=False) + self.tool_call_end_token, add_special_tokens=False + ) self.tool_call_start_token_array = [ self.model_tokenizer.decode([token_id]) @@ -77,13 +87,17 @@ def __init__(self, tokenizer: AnyTokenizer): def tool_call_delta_buffer(self, delta_text: str): # If the sequence of tool_call_start or tool_call_end tokens is not yet # complete, fill the buffer with the token and return "". - if (delta_text in self.tool_call_start_token_array - or delta_text in self.tool_call_end_token_array): + if ( + delta_text in self.tool_call_start_token_array + or delta_text in self.tool_call_end_token_array + ): # If delta_text is the last token of tool_call_start_token or # tool_call_end_token, empty the buffer and return # the buffered text + delta_text. - if (delta_text == self.tool_call_start_token_array[-1] - or delta_text == self.tool_call_end_token_array[-1]): + if ( + delta_text == self.tool_call_start_token_array[-1] + or delta_text == self.tool_call_end_token_array[-1] + ): buffered_text = self.buffered_delta_text self.buffered_delta_text = "" return buffered_text + delta_text @@ -98,9 +112,8 @@ def tool_call_delta_buffer(self, delta_text: str): else: return delta_text - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - if request.tools and request.tool_choice != 'none': + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != "none": # do not skip special tokens because the tool_call tokens are # marked "special" in some models. Since they are skipped # prior to the call to the tool parser, it breaks tool calling. @@ -112,22 +125,19 @@ def extract_tool_calls( model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - # sanity check; avoid unnecessary processing if self.tool_call_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: - try: # there are two possible captures - between tags, or between a # tag and end-of-string so the result of # findall is an array of tuples where one is a function call and # the other is None - function_call_tuples = ( - self.tool_call_regex.findall(model_output)) + function_call_tuples = self.tool_call_regex.findall(model_output) # load the JSON, and then use it to build the Function and # Tool Call @@ -141,24 +151,26 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"], - ensure_ascii=False))) + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), + ), + ) for function_call in raw_function_calls ] - content = model_output[:model_output. 
- find(self.tool_call_start_token)] + content = model_output[: model_output.find(self.tool_call_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if content else None) + content=content if content else None, + ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -177,10 +189,12 @@ def extract_tool_calls_streaming( delta_text = self.tool_call_delta_buffer(delta_text) # If the last characters of previous_text # match self.buffered_delta_text, remove only the matching part. - if (len(previous_text) >= len(self.buffered_delta_text) - and previous_text[-len(self.buffered_delta_text):] - == self.buffered_delta_text): - previous_text = previous_text[:-len(self.buffered_delta_text)] + if ( + len(previous_text) >= len(self.buffered_delta_text) + and previous_text[-len(self.buffered_delta_text) :] + == self.buffered_delta_text + ): + previous_text = previous_text[: -len(self.buffered_delta_text)] current_text = previous_text + delta_text logger.debug("delta_text: %s", delta_text) @@ -191,50 +205,51 @@ def extract_tool_calls_streaming( return DeltaMessage(content=delta_text) try: - # figure out where we are in the parsing by counting tool call # start & end tags - prev_tool_start_count = previous_text.count( - self.tool_call_start_token) + prev_tool_start_count = previous_text.count(self.tool_call_start_token) prev_tool_end_count = previous_text.count(self.tool_call_end_token) - cur_tool_start_count = current_text.count( - self.tool_call_start_token) + cur_tool_start_count = current_text.count(self.tool_call_start_token) cur_tool_end_count = current_text.count(self.tool_call_end_token) tool_call_portion = None text_portion = None # case: if we're generating text, OR rounding out a tool call - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and self.tool_call_end_token not in delta_text): + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text + ): logger.debug("Generating text content! skipping tool parsing.") return DeltaMessage(content=delta_text) if self.tool_call_end_token in delta_text: logger.debug("tool_call_end_token in delta_text") full_text = current_text + delta_text - tool_call_portion = full_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0].rstrip() - delta_text = delta_text.split( - self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split( - self.tool_call_end_token)[-1].lstrip() + tool_call_portion = ( + full_text.split(self.tool_call_start_token)[-1] + .split(self.tool_call_end_token)[0] + .rstrip() + ) + delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() # case: if tool open & close tag counts don't match, we're doing # imaginary "else" block here # something with tools with this diff. # flags for partial JSON parting. 
exported constants from # "Allow" are handled via BIT MASK - flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR # case -- we're starting a new tool call - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): + if ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count + ): if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[ + -1 + ] else: tool_call_portion = None delta = None @@ -248,42 +263,49 @@ def extract_tool_calls_streaming( logger.debug("Starting on a new tool %s", self.current_tool_id) # case -- we're updating an existing tool call - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - + elif ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count + ): # get the portion of the text that's the tool call - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] text_portion = None # case -- the current tool call is being closed. - elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count): - if (self.prev_tool_call_arr is None - or len(self.prev_tool_call_arr) == 0): - logger.debug( - "attempting to close tool call, but no tool call") + elif ( + cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count + ): + if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: + logger.debug("attempting to close tool call, but no tool call") return None - diff = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: - diff = diff.encode('utf-8').decode( - 'unicode_escape') if diff is str else diff - if ('"}' not in delta_text): + diff = ( + diff.encode("utf-8").decode("unicode_escape") + if diff is str + else diff + ) + if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') diff = delta_text[:end_loc] + '"}' logger.debug( "Finishing tool and found diff that had not " - "been streamed yet: %s", diff) - self.streamed_args_for_tool[self.current_tool_id] \ - += diff - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) # case -- otherwise we're just generating text else: @@ -293,13 +315,14 @@ def extract_tool_calls_streaming( return delta try: - - current_tool_call = partial_json_parser.loads( - tool_call_portion or "{}", - flags) if tool_call_portion else None + current_tool_call = ( + partial_json_parser.loads(tool_call_portion or "{}", flags) + if tool_call_portion + else None + ) logger.debug("Parsed tool call %s", current_tool_call) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None except 
json.decoder.JSONDecodeError: logger.debug("unable to parse JSON") @@ -308,19 +331,23 @@ def extract_tool_calls_streaming( # case - we haven't sent the tool name yet. If it's available, send # it. otherwise, wait until it's available. if not self.current_tool_name_sent: - if (current_tool_call is None): + if current_tool_call is None: return None function_name: Union[str, None] = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) else: return None # case -- otherwise, send the tool call delta @@ -329,15 +356,19 @@ def extract_tool_calls_streaming( if tool_call_portion is None: # if there's text but not tool calls, send that - # otherwise None to skip chunk - delta = DeltaMessage(content=delta_text) \ - if text_portion is not None else None + delta = ( + DeltaMessage(content=delta_text) + if text_portion is not None + else None + ) return delta # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. - logger.debug("Trying to parse current tool call with ID %s", - self.current_tool_id) + logger.debug( + "Trying to parse current tool call with ID %s", self.current_tool_id + ) # if we're starting a new tool call, push an empty object in as # a placeholder for the arguments @@ -346,8 +377,9 @@ def extract_tool_calls_streaming( # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON - prev_arguments = ( - self.prev_tool_call_arr[self.current_tool_id].get("arguments")) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -361,8 +393,10 @@ def extract_tool_calls_streaming( # case -- prev arguments are defined, but non are now. # probably impossible, but not a fatal error - just keep going elif not cur_arguments and prev_arguments: - logger.error("should be impossible to have arguments reset " - "mid-call. skipping streaming anything.") + logger.error( + "should be impossible to have arguments reset " + "mid-call. skipping streaming anything." + ) delta = None # case -- we now have the first info about arguments available from @@ -378,38 +412,41 @@ def extract_tool_calls_streaming( # {"search_request": {}} function_name = current_tool_call.get("name") match = re.search( - r'\{"name":\s*"' + - re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)', - tool_call_portion.strip(), re.DOTALL) + r'\{"name":\s*"' + + re.escape(function_name) + + r'"\s*,\s*"arguments":\s*(.*)', + tool_call_portion.strip(), + re.DOTALL, + ) if match: cur_arguments_json = match.group(1) else: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) + cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) - logger.debug("finding %s in %s", delta_text, - cur_arguments_json) + logger.debug("finding %s in %s", delta_text, cur_arguments_json) # get the location where previous args differ from current. 
- if (delta_text not in cur_arguments_json): + if delta_text not in cur_arguments_json: return None - args_delta_start_loc = cur_arguments_json. \ - rindex(delta_text) + \ - len(delta_text) + args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len( + delta_text + ) # use that to find the actual delta arguments_delta = cur_arguments_json[:args_delta_start_loc] - logger.debug("First tokens in arguments received: %s", - arguments_delta) - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[self.current_tool_id] \ - += arguments_delta + logger.debug("First tokens in arguments received: %s", arguments_delta) + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta # last case -- we have an update to existing arguments. elif cur_arguments and prev_arguments: @@ -423,28 +460,32 @@ def extract_tool_calls_streaming( # if the delta_text ends with a '}' and tool_call_portion is a # complete JSON, then the last '}' does not belong to the # arguments, so we should trim it off - if isinstance(delta_text, str) \ - and len(delta_text.rstrip()) >= 1 \ - and delta_text.rstrip()[-1] == '}' \ - and is_complete_json: + if ( + isinstance(delta_text, str) + and len(delta_text.rstrip()) >= 1 + and delta_text.rstrip()[-1] == "}" + and is_complete_json + ): delta_text = delta_text.rstrip()[:-1] logger.debug("got diff %s", delta_text) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=delta_text).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[self.current_tool_id] \ - += delta_text + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=delta_text).model_dump( + exclude_none=True + ), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += delta_text # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[self.current_tool_id] = \ - current_tool_call + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: self.prev_tool_call_arr.append(current_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py index 2b65f2579fb4..1855d69adb21 100644 --- a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py @@ -8,13 +8,19 @@ import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.entrypoints.openai.tool_parsers.utils import consume_space from vllm.logger import init_logger from 
vllm.transformers_utils.tokenizer import AnyTokenizer @@ -25,7 +31,6 @@ @ToolParserManager.register_module("hunyuan_a13b") class HunyuanA13BToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -33,8 +38,7 @@ def __init__(self, tokenizer: AnyTokenizer): self.prev_tool_calls: list[dict] = [] self.current_tool_id = -1 self.current_tool_name_sent = False - self.streamed_args: list[str] = [ - ] # Track arguments sent for each tool + self.streamed_args: list[str] = [] # Track arguments sent for each tool # For backward compatibility with tests self.current_tools_sent: list[bool] = [] @@ -44,12 +48,14 @@ def __init__(self, tokenizer: AnyTokenizer): # Regex patterns for preprocessing self.answer_tool_calls_pattern = re.compile( - r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL) + r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL + ) self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"') self.tool_empty_arg_reg = re.compile( - r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}') + r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}' + ) # TODO: not support nested json object in fc arguments. self.tool_non_empty_arg_reg = re.compile( @@ -66,15 +72,21 @@ def __init__(self, tokenizer: AnyTokenizer): } def preprocess_model_output( - self, model_output: str) -> tuple[Optional[str], Optional[str]]: + self, model_output: str + ) -> tuple[Optional[str], Optional[str]]: # find the location tool call for match in self.answer_tool_calls_pattern.finditer(model_output): start, end = match.span() # check tool_calls whether in side of <think> - think_regions = [(m.start(), m.end()) for m in re.finditer( - r"<think>(.*?)</think>", model_output, flags=re.DOTALL)] - in_think = any(start > t_start and end < t_end - for t_start, t_end in think_regions) + think_regions = [ + (m.start(), m.end()) + for m in re.finditer( + r"<think>(.*?)</think>", model_output, flags=re.DOTALL + ) + ] + in_think = any( + start > t_start and end < t_end for t_start, t_end in think_regions + ) if not in_think: content = model_output[:start] tool_calls_content = match.group(1).strip() @@ -86,24 +98,23 @@ def preprocess_model_output( return model_output, None def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract tool calls from a complete model output. """ try: # Preprocess the model output - content, potential_tool_calls = self.preprocess_model_output( - model_output) + content, potential_tool_calls = self.preprocess_model_output(model_output) if not potential_tool_calls: # some text should be filtered out for no function call # this text is in a13b's chat template. 
if content: content = content.replace("助手:", "", 1) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=content) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=content + ) # Parse the potential tool calls as JSON tool_calls_data = json.loads(potential_tool_calls) @@ -120,8 +131,11 @@ def extract_tool_calls( tool_calls: list[ToolCall] = [] for idx, call in enumerate(tool_calls_data): - if (not isinstance(call, dict) or "name" not in call - or "arguments" not in call): + if ( + not isinstance(call, dict) + or "name" not in call + or "arguments" not in call + ): continue tool_call = ToolCall( @@ -129,8 +143,11 @@ def extract_tool_calls( type="function", function=FunctionCall( name=call["name"], - arguments=(json.dumps(call["arguments"]) if isinstance( - call["arguments"], dict) else call["arguments"]), + arguments=( + json.dumps(call["arguments"]) + if isinstance(call["arguments"], dict) + else call["arguments"] + ), ), ) tool_calls.append(tool_call) @@ -146,9 +163,9 @@ def extract_tool_calls( ) except Exception: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -166,10 +183,12 @@ def extract_tool_calls_streaming( start_idx = consume_space(0, current_text) if current_text[start_idx:].startswith(self.bot_string): - start_idx = consume_space(start_idx + len(self.bot_string), - current_text) - if not current_text or start_idx >= len( - current_text) or current_text[start_idx] != '[': + start_idx = consume_space(start_idx + len(self.bot_string), current_text) + if ( + not current_text + or start_idx >= len(current_text) + or current_text[start_idx] != "[" + ): return DeltaMessage(content=delta_text) self._try_parse_json_tools(current_text[start_idx:]) @@ -185,13 +204,15 @@ def extract_tool_calls_streaming( self._ensure_state_arrays(tool_count) current_idx = self.streaming_state["current_tool_index"] - name_delta = self._handle_tool_name_streaming(current_idx, tool_count, - name_matches) + name_delta = self._handle_tool_name_streaming( + current_idx, tool_count, name_matches + ) if name_delta: return name_delta - args_delta = self._handle_tool_args_streaming(current_text, - current_idx, tool_count) + args_delta = self._handle_tool_args_streaming( + current_text, current_idx, tool_count + ) if args_delta: return args_delta @@ -207,166 +228,195 @@ def _try_parse_json_tools(self, current_text: str): def _handle_test_compatibility(self, current_text: str): if len(self.current_tools_sent) > 0: - if (len(self.current_tools_sent) == 1 - and self.current_tools_sent[0] is False): + if ( + len(self.current_tools_sent) == 1 + and self.current_tools_sent[0] is False + ): name_match = self.tool_name_reg.search(current_text) if name_match: function_name = name_match.group(1) tool_id = f"chatcmpl-tool-{random_uuid()}" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=0, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=0, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tools_sent = [True] self.current_tool_id = 0 self.streaming_state["current_tool_index"] = 0 if len(self.streaming_state["sent_tools"]) == 0: - 
self.streaming_state["sent_tools"].append({ - "sent_name": - True, - "sent_arguments_prefix": - False, - "sent_arguments": - "", - }) + self.streaming_state["sent_tools"].append( + { + "sent_name": True, + "sent_arguments_prefix": False, + "sent_arguments": "", + } + ) else: - self.streaming_state["sent_tools"][0][ - "sent_name"] = True + self.streaming_state["sent_tools"][0]["sent_name"] = True self.current_tool_name_sent = True return delta return None def _ensure_state_arrays(self, tool_count: int): while len(self.streaming_state["sent_tools"]) < tool_count: - self.streaming_state["sent_tools"].append({ - "sent_name": False, - "sent_arguments_prefix": False, - "sent_arguments": "", - }) + self.streaming_state["sent_tools"].append( + { + "sent_name": False, + "sent_arguments_prefix": False, + "sent_arguments": "", + } + ) while len(self.streaming_state["tool_ids"]) < tool_count: self.streaming_state["tool_ids"].append(None) - def _handle_tool_name_streaming(self, current_idx: int, tool_count: int, - name_matches): + def _handle_tool_name_streaming( + self, current_idx: int, tool_count: int, name_matches + ): if current_idx == -1 or current_idx < tool_count - 1: next_idx = current_idx + 1 - if (next_idx < tool_count - and not self.streaming_state["sent_tools"][next_idx] - ["sent_name"]): + if ( + next_idx < tool_count + and not self.streaming_state["sent_tools"][next_idx]["sent_name"] + ): self.streaming_state["current_tool_index"] = next_idx self.current_tool_id = next_idx current_idx = next_idx tool_name = name_matches[current_idx].group(1) tool_id = f"call_{current_idx}_{random_uuid()}" self.streaming_state["tool_ids"][current_idx] = tool_id - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - type="function", - id=tool_id, - function=DeltaFunctionCall(name=tool_name).model_dump( - exclude_none=True), - ) - ]) - self.streaming_state["sent_tools"][current_idx][ - "sent_name"] = True + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall(name=tool_name).model_dump( + exclude_none=True + ), + ) + ] + ) + self.streaming_state["sent_tools"][current_idx]["sent_name"] = True self.current_tool_name_sent = True while len(self.streamed_args) <= current_idx: self.streamed_args.append("") return delta return None - def _handle_tool_args_streaming(self, current_text: str, current_idx: int, - tool_count: int): - + def _handle_tool_args_streaming( + self, current_text: str, current_idx: int, tool_count: int + ): if current_idx >= 0 and current_idx < tool_count: empty_args_match = self.tool_empty_arg_reg.search(current_text) if empty_args_match and empty_args_match.start() > 0: for i in range(tool_count): if i == current_idx: if not self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"]: + "sent_arguments_prefix" + ]: self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] = True + "sent_arguments_prefix" + ] = True self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = "{}" + "sent_arguments" + ] = "{}" while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += "{}" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments="{}").model_dump( - exclude_none=True), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{}" + ).model_dump(exclude_none=True), + ) 
+ ] + ) if current_idx < tool_count - 1: self.streaming_state["current_tool_index"] += 1 self.current_tool_id = self.streaming_state[ - "current_tool_index"] + "current_tool_index" + ] return delta - args_matches = list( - self.tool_non_empty_arg_reg.finditer(current_text)) + args_matches = list(self.tool_non_empty_arg_reg.finditer(current_text)) if current_idx < len(args_matches): args_text = args_matches[current_idx].group(1) is_last_tool = current_idx == tool_count - 1 if not is_last_tool: next_tool_pos = current_text.find( - "},{", args_matches[current_idx].start()) + "},{", args_matches[current_idx].start() + ) if next_tool_pos != -1: - args_end_pos = (next_tool_pos + 1) + args_end_pos = next_tool_pos + 1 args_text = ( - current_text[args_matches[current_idx].start( - ):args_end_pos].split('"arguments":')[1].strip()) + current_text[ + args_matches[current_idx].start() : args_end_pos + ] + .split('"arguments":')[1] + .strip() + ) sent_args = self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] + "sent_arguments" + ] if not self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] and args_text.startswith("{"): + "sent_arguments_prefix" + ] and args_text.startswith("{"): self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] = True + "sent_arguments_prefix" + ] = True self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = "{" + "sent_arguments" + ] = "{" while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += "{" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments="{").model_dump(exclude_none=True), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall(arguments="{").model_dump( + exclude_none=True + ), + ) + ] + ) return delta if args_text.startswith(sent_args): - args_diff = args_text[len(sent_args):] + args_diff = args_text[len(sent_args) :] if args_diff: self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = args_text + "sent_arguments" + ] = args_text while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += args_diff - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments=args_diff).model_dump( - exclude_none=True), - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=args_diff + ).model_dump(exclude_none=True), + ) + ] + ) return delta if args_text.endswith("}") and args_text == sent_args: if current_idx < tool_count - 1: self.streaming_state["current_tool_index"] += 1 self.current_tool_id = self.streaming_state[ - "current_tool_index"] + "current_tool_index" + ] return None diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py index 37c360145b04..9adaea297b05 100644 --- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -9,15 +9,20 @@ from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + 
ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -26,14 +31,12 @@ @ToolParserManager.register_module(["internlm"]) class Internlm2ToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.position = 0 - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - if request.tools and request.tool_choice != 'none': + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != "none": # do not skip special tokens because internlm use the special # tokens to indicate the start and end of the tool calls # information. @@ -57,34 +60,33 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - if '<|action_start|>' not in current_text: + if "<|action_start|>" not in current_text: self.position = len(current_text) return DeltaMessage(content=delta_text) # if the tool call is sent, return an empty delta message # to make sure the finish_reason will be sent correctly. if self.current_tool_id > 0: - return DeltaMessage(content='') + return DeltaMessage(content="") last_pos = self.position - if '<|action_start|><|plugin|>' not in current_text[last_pos:]: + if "<|action_start|><|plugin|>" not in current_text[last_pos:]: return None new_delta = current_text[last_pos:] - text, action = new_delta.split('<|action_start|><|plugin|>') + text, action = new_delta.split("<|action_start|><|plugin|>") if len(text) > 0: self.position = self.position + len(text) return DeltaMessage(content=text) action = action.strip() - action = action.split('<|action_end|>'.strip())[0] + action = action.split("<|action_end|>".strip())[0] # bit mask flags for partial JSON parsing. If the name hasn't been # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. 
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: parsable_arr = action @@ -92,10 +94,9 @@ def extract_tool_calls_streaming( # tool calls are generated in an object in internlm2 # it's not support parallel tool calls try: - tool_call_arr: dict = partial_json_parser.loads( - parsable_arr, flags) + tool_call_arr: dict = partial_json_parser.loads(parsable_arr, flags) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # if the current tool name hasn't been sent, send if available @@ -104,14 +105,18 @@ def extract_tool_calls_streaming( function_name = tool_call_arr.get("name") if function_name: self.current_tool_id = self.current_tool_id + 1 - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True self.streamed_args_for_tool.append("") else: @@ -120,7 +125,8 @@ def extract_tool_calls_streaming( # arguments else: prev_arguments = self.get_arguments( - self.prev_tool_call_arr[self.current_tool_id]) + self.prev_tool_call_arr[self.current_tool_id] + ) cur_arguments = self.get_arguments(tool_call_arr) # not arguments generated @@ -129,43 +135,47 @@ def extract_tool_calls_streaming( # will never happen elif not cur_arguments and prev_arguments: logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") + "INVARIANT - impossible to have arguments reset mid-arguments" + ) delta = None # first time to get parameters elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) - - arguments_delta = cur_arguments_json[:cur_arguments_json. - index(delta_text) + - len(delta_text)] - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta + cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) + + arguments_delta = cur_arguments_json[ + : cur_arguments_json.index(delta_text) + len(delta_text) + ] + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta # both prev and cur parameters, send the increase parameters elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + cur_args_json, prev_args_json + ) + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff # check to see if the name is defined and has been sent. if so, # stream the name - otherwise keep waiting @@ -176,8 +186,8 @@ def extract_tool_calls_streaming( except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None def extract_tool_calls( @@ -187,30 +197,33 @@ def extract_tool_calls( ) -> ExtractedToolCallInformation: text = model_output tools = request.tools - if '<|action_start|><|plugin|>' in text: - text, action = text.split('<|action_start|><|plugin|>') - action = action.split('<|action_end|>'.strip())[0] - action = action[action.find('{'):] + if "<|action_start|><|plugin|>" in text: + text, action = text.split("<|action_start|><|plugin|>") + action = action.split("<|action_end|>".strip())[0] + action = action[action.find("{") :] action_dict = json.loads(action) - name, parameters = action_dict['name'], json.dumps( - action_dict.get('parameters', action_dict.get('arguments', - {})), - ensure_ascii=False) + name, parameters = ( + action_dict["name"], + json.dumps( + action_dict.get("parameters", action_dict.get("arguments", {})), + ensure_ascii=False, + ), + ) if not tools or name not in [t.function.name for t in tools]: - ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=text) + ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=text + ) tool_calls = [ - ToolCall( - function=FunctionCall(name=name, arguments=parameters)) + ToolCall(function=FunctionCall(name=name, arguments=parameters)) ] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=text if len(text) > 0 else None) + content=text if len(text) > 0 else None, + ) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=text) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=text + ) diff 
--git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py index 3b41f6034704..1ae3e0da3351 100644 --- a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py @@ -10,14 +10,17 @@ from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizers import MistralTokenizer @@ -27,7 +30,6 @@ @ToolParserManager.register_module("jamba") class JambaToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -39,33 +41,35 @@ def __init__(self, tokenizer: AnyTokenizer): self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_calls_start_token: str = "<tool_calls>" self.tool_calls_end_token: str = "</tool_calls>" self.tool_calls_regex = re.compile( - rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}", - re.DOTALL) + rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}", re.DOTALL + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - self.tool_calls_start_token_id = self.vocab.get( - self.tool_calls_start_token) - self.tool_calls_end_token_id = self.vocab.get( - self.tool_calls_end_token) - if (self.tool_calls_start_token_id is None - or self.tool_calls_end_token_id is None): + "constructor during construction." + ) + self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) + if ( + self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None + ): raise RuntimeError( "Jamba Tool parser could not locate tool calls start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" + ) - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - if request.tools and request.tool_choice != 'none': + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != "none": # do not skip special tokens because jamba use the special # tokens to indicate the start and end of the tool calls # information. 
@@ -73,17 +77,15 @@ def adjust_request( return request def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: - + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing if self.tool_calls_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: - try: # use a regex to find the tool call between the tags function_calls = self.tool_calls_regex.findall(model_output)[0] @@ -97,25 +99,26 @@ def extract_tool_calls( function=FunctionCall( name=function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"], - ensure_ascii=False), - )) for function_call in raw_function_calls + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), + ), + ) + for function_call in raw_function_calls ] - content = model_output[:model_output. - find(self.tool_calls_start_token)] + content = model_output[: model_output.find(self.tool_calls_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if - (len(content) > 0 and content != " ") else None) + content=content if (len(content) > 0 and content != " ") else None, + ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -127,7 +130,6 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # if the tool call token is not in the tokens generated so far, append # output to contents since it's not a tool if self.tool_calls_start_token not in current_text: @@ -138,8 +140,10 @@ def extract_tool_calls_streaming( # handle if we detected the start of tool calls token which means # the start of tool calling - if (self.tool_calls_start_token_id in delta_token_ids - and len(delta_token_ids) == 1): + if ( + self.tool_calls_start_token_id in delta_token_ids + and len(delta_token_ids) == 1 + ): # if it's the only token, return None, so we don't send a chat # completion and don't send a control token return None @@ -148,28 +152,28 @@ def extract_tool_calls_streaming( # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. 
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: - # Extract the tool calls between the special tool call tokens - parsable_arr = current_text.split( - self.tool_calls_start_token)[-1].split( - self.tool_calls_end_token)[0] + parsable_arr = current_text.split(self.tool_calls_start_token)[-1].split( + self.tool_calls_end_token + )[0] # tool calls are generated in an array, so do partial JSON # parsing on the entire array try: tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags) + parsable_arr, flags + ) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # select as the current tool call the one we're on the state at - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} + current_tool_call: dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) # case -- if no tokens have been streamed for the tool, e.g. # only the array brackets, stream nothing @@ -178,9 +182,9 @@ def extract_tool_calls_streaming( # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -190,16 +194,19 @@ def extract_tool_calls_streaming( if diff: diff = json.dumps(diff, ensure_ascii=False).replace( - self.streamed_args_for_tool[self.current_tool_id], - "") - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += diff + self.streamed_args_for_tool[self.current_tool_id], "" + ) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += diff else: delta = None else: @@ -218,15 +225,18 @@ def extract_tool_calls_streaming( if not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True else: delta = None @@ -234,60 +244,66 @@ def extract_tool_calls_streaming( # now we know we're on the same tool call and we're streaming # arguments else: - - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) cur_arguments = current_tool_call.get("arguments") - new_text = delta_text.replace("\'", "\"") + new_text = delta_text.replace("'", '"') if 
not cur_arguments and not prev_arguments: - delta = None elif not cur_arguments and prev_arguments: logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") + "INVARIANT - impossible to have arguments reset mid-arguments" + ) delta = None elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) - logger.debug("finding %s in %s", new_text, - cur_arguments_json) - - arguments_delta = cur_arguments_json[:cur_arguments_json. - index(new_text) + - len(new_text)] - logger.debug("First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta + cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) + logger.debug("finding %s in %s", new_text, cur_arguments_json) + + arguments_delta = cur_arguments_json[ + : cur_arguments_json.index(new_text) + len(new_text) + ] + logger.debug( + "First tokens in arguments received: %s", arguments_delta + ) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) + logger.debug( + "Searching for diff between \n%s\n%s", + cur_args_json, + prev_args_json, + ) argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) + cur_args_json, prev_args_json + ) logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff else: # try parsing it with regular JSON - if it works we're # at the end, and we need to send the difference between @@ -303,6 +319,6 @@ def extract_tool_calls_streaming( except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py index 834b33052b45..a2eff21a4466 100644 --- a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py @@ -7,13 +7,19 @@ import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, 
ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -22,14 +28,14 @@ @ToolParserManager.register_module(["kimi_k2"]) class KimiK2ToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[str] = ( - []) # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.tool_calls_start_token: str = "<|tool_calls_section_begin|>" self.tool_calls_end_token: str = "<|tool_calls_section_end|>" @@ -45,39 +51,38 @@ def __init__(self, tokenizer: AnyTokenizer): r"(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)" ) - self.stream_tool_call_name_regex = re.compile( - r"(?P<tool_call_id>.+:\d+)\s*") + self.stream_tool_call_name_regex = re.compile(r"(?P<tool_call_id>.+:\d+)\s*") if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") - self.tool_calls_start_token_id = self.vocab.get( - self.tool_calls_start_token) - self.tool_calls_end_token_id = self.vocab.get( - self.tool_calls_end_token) - - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + "constructor during construction." + ) + self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) + + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_calls_start_token_id is None - or self.tool_calls_end_token_id is None): + if ( + self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None + ): raise RuntimeError( "Kimi-K2 Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" 
+ ) def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - # sanity check; avoid unnecessary processing if self.tool_calls_start_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) else: try: @@ -85,8 +90,7 @@ def extract_tool_calls( # tag and end-of-string so the result of # findall is an array of tuples where one is a function call and # the other is None - function_call_tuples = self.tool_call_regex.findall( - model_output) + function_call_tuples = self.tool_call_regex.findall(model_output) logger.debug("function_call_tuples: %s", function_call_tuples) @@ -94,17 +98,18 @@ def extract_tool_calls( for match in function_call_tuples: function_id, function_args = match # function_id: functions.get_weather:0 - function_name = function_id.split('.')[1].split(':')[0] + function_name = function_id.split(".")[1].split(":")[0] tool_calls.append( ToolCall( id=function_id, - type='function', - function=FunctionCall(name=function_name, - arguments=function_args), - )) + type="function", + function=FunctionCall( + name=function_name, arguments=function_args + ), + ) + ) - content = model_output[:model_output. - find(self.tool_calls_start_token)] + content = model_output[: model_output.find(self.tool_calls_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, @@ -112,11 +117,10 @@ def extract_tool_calls( ) except Exception: - logger.exception( - "Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -128,55 +132,58 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) # check to see if we should be streaming a tool call - is there a if self.tool_calls_start_token_id not in current_token_ids: logger.debug("No tool call tokens found!") return DeltaMessage(content=delta_text) - delta_text = delta_text.replace(self.tool_calls_start_token, - "").replace(self.tool_calls_end_token, - "") + delta_text = delta_text.replace(self.tool_calls_start_token, "").replace( + self.tool_calls_end_token, "" + ) try: - # figure out where we are in the parsing by counting tool call # start & end tags prev_tool_start_count = previous_token_ids.count( - self.tool_call_start_token_id) - prev_tool_end_count = previous_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id) cur_tool_start_count = current_token_ids.count( - self.tool_call_start_token_id) - cur_tool_end_count = current_token_ids.count( - self.tool_call_end_token_id) + self.tool_call_start_token_id + ) + cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id) tool_call_portion = None text_portion = None # case: if we're generating text, OR rounding out a tool call - if (cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and self.tool_call_end_token not in 
delta_text): + if ( + cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text + ): logger.debug("Generating text content! skipping tool parsing.") return DeltaMessage(content=delta_text) if self.tool_call_end_token in delta_text: logger.debug("tool_call_end_token in delta_text") full_text = current_text + delta_text - tool_call_portion = full_text.split( - self.tool_call_start_token)[-1].split( - self.tool_call_end_token)[0].rstrip() - delta_text = delta_text.split( - self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split( - self.tool_call_end_token)[-1].lstrip() + tool_call_portion = ( + full_text.split(self.tool_call_start_token)[-1] + .split(self.tool_call_end_token)[0] + .rstrip() + ) + delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() # case -- we're starting a new tool call - if (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count): + if ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count + ): if len(delta_token_ids) > 1: - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[ + -1 + ] else: tool_call_portion = None delta = None @@ -190,27 +197,29 @@ def extract_tool_calls_streaming( logger.debug("Starting on a new tool %s", self.current_tool_id) # case -- we're updating an existing tool call - elif (cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count): - + elif ( + cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count + ): # get the portion of the text that's the tool call - tool_call_portion = current_text.split( - self.tool_call_start_token)[-1] + tool_call_portion = current_text.split(self.tool_call_start_token)[-1] text_portion = None # case -- the current tool call is being closed. 
- elif (cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count): - if self.prev_tool_call_arr is None or len( - self.prev_tool_call_arr) == 0: - logger.debug( - "attempting to close tool call, but no tool call") + elif ( + cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count + ): + if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: + logger.debug("attempting to close tool call, but no tool call") return None - diff = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: - diff = (diff.encode("utf-8").decode("unicode_escape") - if diff is str else diff) + diff = ( + diff.encode("utf-8").decode("unicode_escape") + if diff is str + else diff + ) if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') @@ -221,13 +230,16 @@ def extract_tool_calls_streaming( diff, ) self.streamed_args_for_tool[self.current_tool_id] += diff - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump(exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) # case -- otherwise we're just generating text else: @@ -238,23 +250,23 @@ def extract_tool_calls_streaming( current_tool_call = dict() if tool_call_portion: - current_tool_call_matches = ( - self.stream_tool_call_portion_regex.match( - tool_call_portion)) + current_tool_call_matches = self.stream_tool_call_portion_regex.match( + tool_call_portion + ) if current_tool_call_matches: - tool_id, tool_args = (current_tool_call_matches.groups()) - tool_name = tool_id.split('.')[1].split(':')[0] - current_tool_call['id'] = tool_id + tool_id, tool_args = current_tool_call_matches.groups() + tool_name = tool_id.split(".")[1].split(":")[0] + current_tool_call["id"] = tool_id current_tool_call["name"] = tool_name current_tool_call["arguments"] = tool_args else: current_tool_call_name_matches = ( - self.stream_tool_call_name_regex.match( - tool_call_portion)) + self.stream_tool_call_name_regex.match(tool_call_portion) + ) if current_tool_call_name_matches: - tool_id_str, = current_tool_call_name_matches.groups() - tool_name = tool_id_str.split('.')[1].split(':')[0] - current_tool_call['id'] = tool_id_str + (tool_id_str,) = current_tool_call_name_matches.groups() + tool_name = tool_id_str.split(".")[1].split(":")[0] + current_tool_call["id"] = tool_id_str current_tool_call["name"] = tool_name current_tool_call["arguments"] = "" else: @@ -270,16 +282,18 @@ def extract_tool_calls_streaming( tool_id = current_tool_call.get("id") if function_name: self.current_tool_name_sent = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) else: return None @@ -289,15 +303,19 @@ def extract_tool_calls_streaming( if tool_call_portion is None: # if there's text but not tool calls, send that - # otherwise None to skip chunk - delta = (DeltaMessage( - content=delta_text) if text_portion is not 
None else None) + delta = ( + DeltaMessage(content=delta_text) + if text_portion is not None + else None + ) return delta # now, the nitty-gritty of tool calls # now we have the portion to parse as tool call. - logger.debug("Trying to parse current tool call with ID %s", - self.current_tool_id) + logger.debug( + "Trying to parse current tool call with ID %s", self.current_tool_id + ) # if we're starting a new tool call, push an empty object in as # a placeholder for the arguments @@ -307,7 +325,8 @@ def extract_tool_calls_streaming( # main logic for tool parsing here - compare prev. partially-parsed # JSON to the current partially-parsed JSON prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments") + "arguments" + ) cur_arguments = current_tool_call.get("arguments") logger.debug("diffing old arguments: %s", prev_arguments) @@ -321,52 +340,56 @@ def extract_tool_calls_streaming( # case -- prev arguments are defined, but non are now. # probably impossible, but not a fatal error - just keep going elif not cur_arguments and prev_arguments: - logger.error("should be impossible to have arguments reset " - "mid-call. skipping streaming anything.") + logger.error( + "should be impossible to have arguments reset " + "mid-call. skipping streaming anything." + ) delta = None # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=cur_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments # last case -- we have an update to existing arguments. 
elif cur_arguments and prev_arguments: - if (isinstance(delta_text, str) - and cur_arguments != prev_arguments - and len(cur_arguments) > len(prev_arguments) - and cur_arguments.startswith(prev_arguments)): - delta_arguments = cur_arguments[len(prev_arguments):] + if ( + isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments) + ): + delta_arguments = cur_arguments[len(prev_arguments) :] logger.debug("got diff %s", delta_text) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=delta_arguments).model_dump( - exclude_none=True), - ) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] = cur_arguments + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] = cur_arguments else: delta = None # handle saving the state for the current tool into # the "prev" list for use in diffing for the next iteration if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[ - self.current_tool_id] = current_tool_call + self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: self.prev_tool_call_arr.append(current_tool_call) diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py index 9a9a19ce2188..162675efbc9a 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -9,13 +9,19 @@ from transformers import PreTrainedTokenizerBase import vllm.envs as envs -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger logger = init_logger(__name__) @@ -31,6 +37,7 @@ class Llama4PythonicToolParser(ToolParser): Toolcall parser for Llama4 that produce tool calls in a pythonic style Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic """ + # TODO(mdepinet): Possible future improvements: # 1. Support text + tools separated by either <|python_tag|> or \n\n # 2. Support tools outside of a list (or separated by a semicolon). @@ -40,7 +47,8 @@ class Llama4PythonicToolParser(ToolParser): TOOL_CALL_REGEX = re.compile( r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", - re.DOTALL) + re.DOTALL, + ) def __init__(self, tokenizer: PreTrainedTokenizerBase): super().__init__(tokenizer) @@ -55,8 +63,8 @@ def current_tool_index(self, value: int) -> None: self.current_tool_id = value def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. 
""" @@ -64,46 +72,52 @@ def extract_tool_calls( # remove <|python_start|> and <|python_end|> # as Llama 4 model sometime will output those tokens if model_output.startswith("<|python_start|>"): - model_output = model_output[len("<|python_start|>"):] + model_output = model_output[len("<|python_start|>") :] model_output = model_output.replace("<|python_end|>", "") is_tool_call_pattern = False try: - is_tool_call_pattern = self.TOOL_CALL_REGEX.match( - model_output, - timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + is_tool_call_pattern = ( + self.TOOL_CALL_REGEX.match( + model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS + ) + is not None + ) except TimeoutError: - logger.warning( - "Regex timeout occurred when matching tool call pattern.") - logger.debug("Regex timeout occurred when matching user input: %s", - model_output) + logger.warning("Regex timeout occurred when matching tool call pattern.") + logger.debug( + "Regex timeout occurred when matching user input: %s", model_output + ) if not is_tool_call_pattern: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: module = ast.parse(model_output) parsed = getattr(module.body[0], "value", None) if isinstance(parsed, ast.List) and all( - isinstance(e, ast.Call) for e in parsed.elts): + isinstance(e, ast.Call) for e in parsed.elts + ): return ExtractedToolCallInformation( tools_called=True, tool_calls=[ _handle_single_tool(e) # type: ignore for e in parsed.elts ], - content=None) + content=None, + ) else: raise _UnexpectedAstError( - "Tool output must be a list of function calls") + "Tool output must be a list of function calls" + ) except Exception: logger.exception("Error in extracting tool call from response.") # Treat as regular text - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -115,18 +129,17 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - if not current_text.startswith("[") and not current_text.startswith( - "<|python_start|>"): + "<|python_start|>" + ): return DeltaMessage(content=delta_text) try: # remove <|python_start|> and <|python_end|> if current_text.startswith("<|python_start|>"): - current_text = current_text[len("<|python_start|>"):] + current_text = current_text[len("<|python_start|>") :] if current_text.endswith("<|python_end|>"): - current_text = current_text[:current_text. 
- rfind("<|python_end|>")] + current_text = current_text[: current_text.rfind("<|python_end|>")] valid_and_added_text = _make_valid_python(current_text) if valid_and_added_text is None: return None @@ -135,9 +148,11 @@ def extract_tool_calls_streaming( module = ast.parse(valid_text) parsed = getattr(module.body[0], "value", None) if not isinstance(parsed, ast.List) or not all( - isinstance(e, ast.Call) for e in parsed.elts): + isinstance(e, ast.Call) for e in parsed.elts + ): raise _UnexpectedAstError( - "Tool output must be a list of function calls") + "Tool output must be a list of function calls" + ) tool_calls = [ _handle_single_tool(e) # type: ignore for e in parsed.elts @@ -152,34 +167,36 @@ def extract_tool_calls_streaming( if len(self.streamed_args_for_tool) == index: self.streamed_args_for_tool.append("") - new_call_complete = index < len( - tool_calls) - 1 or ")]" not in added_text + new_call_complete = ( + index < len(tool_calls) - 1 or ")]" not in added_text + ) if new_call_complete: self.current_tool_index += 1 - withheld_suffix = (added_text[:-2] - if not new_call_complete else "") + withheld_suffix = added_text[:-2] if not new_call_complete else "" if not new_call_complete and added_text[-2] == ")": # Function call is incomplete. Withhold the closing bracket. withheld_suffix = withheld_suffix + "}" # Strings get single quotes in the model-produced string. # JSON requires double quotes. withheld_suffix = withheld_suffix.replace("'", '"') - delta = _compute_tool_delta(self.streamed_args_for_tool[index], - new_call, index, withheld_suffix) + delta = _compute_tool_delta( + self.streamed_args_for_tool[index], new_call, index, withheld_suffix + ) if delta is not None: tool_deltas.append(delta) - if (delta.function is not None - and delta.function.arguments is not None): - self.streamed_args_for_tool[ - index] += delta.function.arguments - - # HACK: serving_chat.py inspects the internal state of tool parsers - # when determining its final streaming delta, automatically - # adding autocompleted JSON. - # These two lines avoid that nonsense while ensuring finish_reason - # is set to tool_calls when at least one tool is called. + if ( + delta.function is not None + and delta.function.arguments is not None + ): + self.streamed_args_for_tool[index] += delta.function.arguments + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining its final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. if tool_deltas and not self.prev_tool_call_arr: self.prev_tool_call_arr = [{"arguments": {}}] @@ -188,14 +205,14 @@ def extract_tool_calls_streaming( elif not added_text and self.current_tool_id > 0: # Return an empty DeltaMessage once the tool calls are all done # so that finish_reason gets set. 
- return DeltaMessage(content='') + return DeltaMessage(content="") else: return None except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None @@ -204,8 +221,7 @@ def _get_parameter_value(val: ast.expr) -> Any: return val.value elif isinstance(val, ast.Dict): if not all(isinstance(k, ast.Constant) for k in val.keys): - raise _UnexpectedAstError( - "Dict tool call arguments must have literal keys") + raise _UnexpectedAstError("Dict tool call arguments must have literal keys") return { k.value: _get_parameter_value(v) # type: ignore for k, v in zip(val.keys, val.values) @@ -223,9 +239,10 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: arguments = {} for keyword in call.keywords: arguments[keyword.arg] = _get_parameter_value(keyword.value) - return ToolCall(type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(arguments))) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, arguments=json.dumps(arguments)), + ) def _make_valid_python(text: str) -> Union[tuple[str, str], None]: @@ -261,21 +278,25 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: # we can't fill in a valid value. return None if bracket_stack and bracket_stack[-1] == "{": - trailing_dict_text = text[:text.rfind("{")] + trailing_dict_text = text[: text.rfind("{")] num_keys = trailing_dict_text.count(":") num_values = trailing_dict_text.count(",") if num_keys <= num_values: return None # Incomplete property name within parameter value if bracket_stack and bracket_stack[-1] == "(": - trailing_params_text = text[:text.rfind("(")] + trailing_params_text = text[: text.rfind("(")] num_full_param_names = trailing_params_text.count("=") num_full_param_values = trailing_params_text.count(",") if num_full_param_names <= num_full_param_values: return None # Incomplete parameter name if text.endswith(","): text = text[:-1] - if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( - "[") and not text.endswith(")"): + if ( + bracket_stack + and bracket_stack[-1] == "[" + and not text.endswith("[") + and not text.endswith(")") + ): return None # Incomplete function name added_text = "" @@ -294,23 +315,29 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: return text + added_text, added_text -def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, - index: int, - withheld_suffix: str) -> Union[DeltaToolCall, None]: +def _compute_tool_delta( + previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str +) -> Union[DeltaToolCall, None]: new_call_args = new_call.function.arguments if withheld_suffix: assert new_call_args.endswith(withheld_suffix) - new_call_args = new_call_args[:-len(withheld_suffix)] + new_call_args = new_call_args[: -len(withheld_suffix)] if not previously_sent_args: - return DeltaToolCall(id=new_call.id, - type="function", - index=index, - function=DeltaFunctionCall( - name=new_call.function.name, - arguments=new_call_args, - )) - - arg_diff = new_call_args[len(previously_sent_args):] - return DeltaToolCall( - id=None, index=index, function=DeltaFunctionCall( - arguments=arg_diff)) if arg_diff else None + return DeltaToolCall( + id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + ), + ) + + arg_diff = 
new_call_args[len(previously_sent_args) :] + return ( + DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff) + ) + if arg_diff + else None + ) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index 31b19c8db416..4d5ef5ed64aa 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -11,16 +11,24 @@ from transformers import PreTrainedTokenizerBase from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix, - is_complete_json, - partial_json_loads) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import ( + find_common_prefix, + is_complete_json, + partial_json_loads, +) from vllm.logger import init_logger logger = init_logger(__name__) @@ -33,7 +41,7 @@ class Llama3JsonToolParser(ToolParser): Tool call parser for Llama 3.x and 4 models intended for use with the examples/tool_chat_template_llama.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser llama3_json or + Used when --enable-auto-tool-choice --tool-call-parser llama3_json or llama4_json are set. """ @@ -45,42 +53,45 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.bot_token = "<|python_tag|>" - self.bot_token_id = tokenizer.encode(self.bot_token, - add_special_tokens=False)[0] + self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[ + 0 + ] # Updated regex to match multiple JSONs separated by semicolons # This pattern is more robust and can handle nested JSON objects self.tool_call_regex = re.compile( - r'{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*', - re.DOTALL) + r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*", + re.DOTALL, + ) def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. Only extracts JSON content and ignores any surrounding plain text. Supports both single JSON and multiple JSONs separated by semicolons. 
""" # Quick check before running regex - if not (self.bot_token in model_output or '{' in model_output): - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + if not (self.bot_token in model_output or "{" in model_output): + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) # Find JSON object(s) in the text using regex match = self.tool_call_regex.search(model_output) if not match: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: json_str = match.group(0) # Split by semicolon and strip whitespace - json_objects = [obj.strip() for obj in json_str.split(';')] + json_objects = [obj.strip() for obj in json_str.split(";")] tool_calls: list[ToolCall] = [] for json_obj in json_objects: @@ -95,19 +106,24 @@ def extract_tool_calls( # function call args are JSON but as a string arguments=json.dumps( obj["arguments"] - if "arguments" in obj else obj["parameters"], - ensure_ascii=False)))) - - return ExtractedToolCallInformation(tools_called=True, - tool_calls=tool_calls, - content=None) + if "arguments" in obj + else obj["parameters"], + ensure_ascii=False, + ), + ), + ) + ) + + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=None + ) except Exception: logger.exception("Error in extracting tool call from response.") # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -119,47 +135,49 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - - if not (current_text.startswith(self.bot_token) - or current_text.startswith('{')): + if not ( + current_text.startswith(self.bot_token) or current_text.startswith("{") + ): return DeltaMessage(content=delta_text) # bit mask flags for partial JSON parsing. If the name hasn't been # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. 
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: tool_call_arr = [] is_complete = [] try: # depending on the prompt format the Llama model may or may not # prefix the output with the <|python_tag|> token - start_idx = len(self.bot_token) if current_text.startswith( - self.bot_token) else 0 + start_idx = ( + len(self.bot_token) + if current_text.startswith(self.bot_token) + else 0 + ) while start_idx < len(current_text): - (obj, - end_idx) = partial_json_loads(current_text[start_idx:], - flags) + (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags) is_complete.append( - is_complete_json(current_text[start_idx:start_idx + - end_idx])) - start_idx += end_idx + len('; ') + is_complete_json(current_text[start_idx : start_idx + end_idx]) + ) + start_idx += end_idx + len("; ") # depending on the prompt Llama can use # either arguments or parameters if "parameters" in obj: - assert "arguments" not in obj, \ + assert "arguments" not in obj, ( "model generated both parameters and arguments" + ) obj["arguments"] = obj["parameters"] tool_call_arr.append(obj) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # select as the current tool call the one we're on the state at - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} + current_tool_call: dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) # case -- if no tokens have been streamed for the tool, e.g. # only the array brackets, stream nothing @@ -168,9 +186,9 @@ def extract_tool_calls_streaming( # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -178,21 +196,24 @@ def extract_tool_calls_streaming( if self.current_tool_id >= 0: cur_arguments = current_tool_call.get("arguments") if cur_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) argument_diff = cur_args_json[sent:] logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) else: delta = None else: @@ -209,15 +230,18 @@ def extract_tool_calls_streaming( elif not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True else: delta = None @@ -229,34 +253,35 @@ def extract_tool_calls_streaming( delta = None if cur_arguments: - sent = len( - self.streamed_args_for_tool[self.current_tool_id]) - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) argument_diff = None if is_complete[self.current_tool_id]: argument_diff = cur_args_json[sent:] elif prev_arguments: - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) if cur_args_json != prev_args_json: - - prefix = find_common_prefix( - prev_args_json, cur_args_json) + prefix = find_common_prefix(prev_args_json, cur_args_json) argument_diff = prefix[sent:] if argument_diff is not None: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff). 
- model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += ( + argument_diff + ) self.prev_tool_call_arr = tool_call_arr return delta @@ -264,6 +289,6 @@ def extract_tool_calls_streaming( except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py index 87a3fdc44397..1dc1a0290c8d 100644 --- a/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py @@ -3,16 +3,13 @@ import regex as re -from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import ( - Hermes2ProToolParser) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParserManager +from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.transformers_utils.tokenizer import AnyTokenizer @ToolParserManager.register_module("longcat") class LongcatFlashToolParser(Hermes2ProToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -21,12 +18,15 @@ def __init__(self, tokenizer: AnyTokenizer): self.tool_call_regex = re.compile( r"<longcat_tool_call>(.*?)</longcat_tool_call>|<longcat_tool_call>(.*)", - re.DOTALL) + re.DOTALL, + ) self.tool_call_start_token_ids = self.model_tokenizer.encode( - self.tool_call_start_token, add_special_tokens=False) + self.tool_call_start_token, add_special_tokens=False + ) self.tool_call_end_token_ids = self.model_tokenizer.encode( - self.tool_call_end_token, add_special_tokens=False) + self.tool_call_end_token, add_special_tokens=False + ) self.tool_call_start_token_array = [ self.model_tokenizer.decode([token_id]) diff --git a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py index 0fd62f0b6a7f..0b83fd237a6a 100644 --- a/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py @@ -8,15 +8,20 @@ import regex as re from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -25,7 +30,6 @@ @ToolParserManager.register_module("minimax") class 
MinimaxToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -40,7 +44,8 @@ def __init__(self, tokenizer: AnyTokenizer): self.tool_call_start_token = "<tool_calls>" self.tool_call_end_token = "</tool_calls>" self.tool_call_regex = re.compile( - r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL) + r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL + ) self.thinking_tag_pattern = r"<think>(.*?)</think>" self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"') self.tool_args_pattern = re.compile(r'"arguments":\s*') @@ -52,50 +57,51 @@ def __init__(self, tokenizer: AnyTokenizer): if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") + "constructor during construction." + ) # Get token IDs for tool call start/end tokens - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_call_start_token_id is None - or self.tool_call_end_token_id is None): + if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: logger.warning( "Minimax Tool parser could not locate tool call start/end " - "tokens in the tokenizer. Falling back to string matching.") + "tokens in the tokenizer. Falling back to string matching." + ) def preprocess_model_output(self, model_output: str) -> str: """ Preprocess model output by removing tool calls from thinking tags. - + Args: model_output: Raw model output string - + Returns: Preprocessed model output with tool calls removed from thinking tags """ def remove_tool_calls_from_think(match): think_content = match.group(1) - cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>", - "", - think_content, - flags=re.DOTALL) + cleaned_content = re.sub( + r"<tool_calls>.*?</tool_calls>", "", think_content, flags=re.DOTALL + ) return f"<think>{cleaned_content}</think>" - return re.sub(self.thinking_tag_pattern, - remove_tool_calls_from_think, - model_output, - flags=re.DOTALL) + return re.sub( + self.thinking_tag_pattern, + remove_tool_calls_from_think, + model_output, + flags=re.DOTALL, + ) def _clean_duplicate_braces(self, args_text: str) -> str: """ Clean duplicate closing braces from arguments text. - + Args: args_text: Raw arguments text - + Returns: Cleaned arguments text with proper JSON formatting """ @@ -109,7 +115,7 @@ def _clean_duplicate_braces(self, args_text: str) -> str: except json.JSONDecodeError: pass - while args_text.endswith('}}'): + while args_text.endswith("}}"): candidate = args_text[:-1] try: json.loads(candidate) @@ -122,10 +128,10 @@ def _clean_duplicate_braces(self, args_text: str) -> str: def _clean_delta_braces(self, delta_text: str) -> str: """ Clean delta text by removing excessive closing braces. 
- + Args: delta_text: Delta text to clean - + Returns: Cleaned delta text """ @@ -134,10 +140,10 @@ def _clean_delta_braces(self, delta_text: str) -> str: delta_stripped = delta_text.strip() - if delta_stripped and all(c in '}\n\r\t ' for c in delta_stripped): - brace_count = delta_stripped.count('}') + if delta_stripped and all(c in "}\n\r\t " for c in delta_stripped): + brace_count = delta_stripped.count("}") if brace_count > 1: - return '}\n' if delta_text.endswith('\n') else '}' + return "}\n" if delta_text.endswith("\n") else "}" return delta_text @@ -148,34 +154,32 @@ def extract_tool_calls( ) -> ExtractedToolCallInformation: """ Extract tool calls from model output for non-streaming mode. - + Args: model_output: Complete model output request: Chat completion request - + Returns: ExtractedToolCallInformation containing tool calls and content """ processed_output = self.preprocess_model_output(model_output) if self.tool_call_start_token not in processed_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: - function_call_tuples = self.tool_call_regex.findall( - processed_output) + function_call_tuples = self.tool_call_regex.findall(processed_output) raw_function_calls = [] for match in function_call_tuples: tool_call_content = match[0] if match[0] else match[1] if tool_call_content.strip(): - lines = tool_call_content.strip().split('\n') + lines = tool_call_content.strip().split("\n") for line in lines: line = line.strip() - if line and line.startswith('{') and line.endswith( - '}'): + if line and line.startswith("{") and line.endswith("}"): try: parsed_call = json.loads(line) raw_function_calls.append(parsed_call) @@ -186,25 +190,29 @@ def extract_tool_calls( for function_call in raw_function_calls: if "name" in function_call and "arguments" in function_call: tool_calls.append( - ToolCall(type="function", - function=FunctionCall( - name=function_call["name"], - arguments=json.dumps( - function_call["arguments"], - ensure_ascii=False)))) + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + arguments=json.dumps( + function_call["arguments"], ensure_ascii=False + ), + ), + ) + ) processed_pos = processed_output.find(self.tool_call_start_token) if processed_pos != -1: processed_content = processed_output[:processed_pos].strip() if processed_content: - lines = processed_content.split('\n') + lines = processed_content.split("\n") for line in reversed(lines): line = line.strip() if line: pos = model_output.find(line) if pos != -1: - content = model_output[:pos + len(line)] + content = model_output[: pos + len(line)] break else: content = "" @@ -216,68 +224,74 @@ def extract_tool_calls( return ExtractedToolCallInformation( tools_called=len(tool_calls) > 0, tool_calls=tool_calls, - content=content.strip() if content.strip() else None) + content=content.strip() if content.strip() else None, + ) except Exception: logger.exception( - "An unexpected error occurred during tool call extraction.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + "An unexpected error occurred during tool call extraction." + ) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def _update_thinking_state(self, text: str) -> None: """ Update the thinking tag state based on text content. 
- + Args: text: Text to analyze for thinking tags """ open_count = text.count("<think>") close_count = text.count("</think>") self.in_thinking_tag = open_count > close_count or ( - open_count == close_count and text.endswith("</think>")) + open_count == close_count and text.endswith("</think>") + ) def _is_potential_tag_start(self, text: str) -> bool: """ Check if text might be the start of a tool call tag. - + Args: text: Text to check - + Returns: True if text could be the start of a tool call tag """ for tag in [self.tool_call_start_token, self.tool_call_end_token]: if any( - tag.startswith(text[-i:]) - for i in range(1, min(len(text) + 1, len(tag)))): + tag.startswith(text[-i:]) + for i in range(1, min(len(text) + 1, len(tag))) + ): return True return False def _should_buffer_content(self, delta_text: str) -> bool: """ Determine if content should be buffered for later processing. - + Args: delta_text: Delta text to check - + Returns: True if content should be buffered """ if self.in_thinking_tag: return False - return bool(self.pending_buffer - or self.tool_call_start_token in delta_text - or self.tool_call_end_token in delta_text - or delta_text.startswith('<')) + return bool( + self.pending_buffer + or self.tool_call_start_token in delta_text + or self.tool_call_end_token in delta_text + or delta_text.startswith("<") + ) def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]: """ Split delta text into safe content and potential tag content. - + Args: delta_text: Delta text to split - + Returns: Tuple of (safe_content, potential_tag_content) """ @@ -295,10 +309,10 @@ def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]: def _process_buffer(self, new_content: str) -> str: """ Process buffered content and return output content. - + Args: new_content: New content to add to buffer - + Returns: Processed output content """ @@ -326,7 +340,7 @@ def _process_buffer(self, new_content: str) -> str: break output_content += self.pending_buffer[:tag_pos] - self.pending_buffer = self.pending_buffer[tag_pos + tag_len:] + self.pending_buffer = self.pending_buffer[tag_pos + tag_len :] return output_content @@ -340,13 +354,14 @@ def _reset_streaming_state(self) -> None: def _advance_to_next_tool(self) -> None: """Advance to the next tool in the streaming sequence.""" - self.streaming_state["current_tool_index"] = int( - self.streaming_state["current_tool_index"]) + 1 + self.streaming_state["current_tool_index"] = ( + int(self.streaming_state["current_tool_index"]) + 1 + ) def _set_current_tool_index(self, index: int) -> None: """ Set the current tool index. - + Args: index: Tool index to set """ @@ -355,7 +370,7 @@ def _set_current_tool_index(self, index: int) -> None: def _get_current_tool_index(self) -> int: """ Get the current tool index. - + Returns: Current tool index """ @@ -364,10 +379,10 @@ def _get_current_tool_index(self) -> int: def _get_next_unsent_tool_index(self, tool_count: int) -> int: """ Get the index of the next unsent tool. - + Args: tool_count: Total number of tools - + Returns: Index of next unsent tool, or -1 if all tools sent """ @@ -383,7 +398,7 @@ def _get_next_unsent_tool_index(self, tool_count: int) -> int: def _ensure_state_arrays(self, tool_count: int) -> None: """ Ensure state arrays have sufficient capacity for tool_count tools. 
- + Args: tool_count: Number of tools to prepare for """ @@ -391,11 +406,13 @@ def _ensure_state_arrays(self, tool_count: int) -> None: tool_ids = list(self.streaming_state["tool_ids"]) while len(sent_tools) < tool_count: - sent_tools.append({ - "sent_name": False, - "sent_arguments": "", - "id": make_tool_call_id(), - }) + sent_tools.append( + { + "sent_name": False, + "sent_arguments": "", + "id": make_tool_call_id(), + } + ) while len(tool_ids) < tool_count: tool_ids.append(None) @@ -406,10 +423,10 @@ def _ensure_state_arrays(self, tool_count: int) -> None: def _detect_tools_in_text(self, text: str) -> int: """ Detect the number of tools in text by counting name patterns. - + Args: text: Text to analyze - + Returns: Number of tools detected """ @@ -419,26 +436,26 @@ def _detect_tools_in_text(self, text: str) -> int: def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]: """ Find the boundaries of tool calls in text. - + Args: text: Text to analyze - + Returns: List of (start, end) positions for tool calls """ boundaries = [] i = 0 while i < len(text): - if text[i] == '{': + if text[i] == "{": start = i depth = 0 has_name = False has_arguments = False while i < len(text): - if text[i] == '{': + if text[i] == "{": depth += 1 - elif text[i] == '}': + elif text[i] == "}": depth -= 1 if depth == 0: end = i + 1 @@ -447,10 +464,9 @@ def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]: boundaries.append((start, end)) break - if not has_name and '"name"' in text[start:i + 1]: + if not has_name and '"name"' in text[start : i + 1]: has_name = True - if not has_arguments and '"arguments"' in text[start:i + - 1]: + if not has_arguments and '"arguments"' in text[start : i + 1]: has_arguments = True i += 1 @@ -461,47 +477,46 @@ def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]: i += 1 return boundaries - def _extract_tool_args(self, tool_content: str, - args_match: re.Match[str]) -> str: + def _extract_tool_args(self, tool_content: str, args_match: re.Match[str]) -> str: """ Extract tool arguments from tool content. - + Args: tool_content: Tool call content args_match: Regex match for arguments pattern - + Returns: Extracted arguments as string """ args_start_pos = args_match.end() remaining_content = tool_content[args_start_pos:] - if remaining_content.strip().startswith('{'): + if remaining_content.strip().startswith("{"): depth = 0 for i, char in enumerate(remaining_content): - if char == '{': + if char == "{": depth += 1 - elif char == '}': + elif char == "}": depth -= 1 if depth == 0: - return remaining_content[:i + 1] + return remaining_content[: i + 1] else: - args_end = remaining_content.find('}') + args_end = remaining_content.find("}") if args_end > 0: return remaining_content[:args_end].strip() - return remaining_content.rstrip('}').strip() + return remaining_content.rstrip("}").strip() def _get_current_tool_content( - self, text: str, - tool_index: int) -> tuple[Optional[str], Optional[str]]: + self, text: str, tool_index: int + ) -> tuple[Optional[str], Optional[str]]: """ Get the content of a specific tool by index. 
- + Args: text: Text containing tool calls tool_index: Index of tool to extract - + Returns: Tuple of (tool_name, tool_arguments) or (None, None) if not found """ @@ -522,22 +537,22 @@ def _get_current_tool_content( args_text = self._extract_tool_args(tool_content, args_match) return name, args_text except Exception: - remaining_content = tool_content[args_match.end():] - args_text = remaining_content.rstrip('}').strip() + remaining_content = tool_content[args_match.end() :] + args_text = remaining_content.rstrip("}").strip() return name, args_text return name, None def _handle_tool_name_streaming( - self, tool_content: str, - tool_count: int) -> Union[DeltaMessage, None]: + self, tool_content: str, tool_count: int + ) -> Union[DeltaMessage, None]: """ Handle streaming of tool names. - + Args: tool_content: Content containing tool calls tool_count: Total number of tools - + Returns: DeltaMessage with tool name or None if no tool to stream """ @@ -565,24 +580,29 @@ def _handle_tool_name_streaming( self.streaming_state["sent_tools"] = sent_tools self.streaming_state["tool_ids"] = tool_ids - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=next_idx, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=tool_name).model_dump(exclude_none=True)) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=next_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall(name=tool_name).model_dump( + exclude_none=True + ), + ) + ] + ) def _handle_tool_args_streaming( - self, tool_content: str, - tool_count: int) -> Union[DeltaMessage, None]: + self, tool_content: str, tool_count: int + ) -> Union[DeltaMessage, None]: """ Handle streaming of tool arguments. - + Args: tool_content: Content containing tool calls tool_count: Total number of tools - + Returns: DeltaMessage with tool arguments or None if no arguments to stream """ @@ -591,8 +611,7 @@ def _handle_tool_args_streaming( if current_idx < 0 or current_idx >= tool_count: return None - tool_name, tool_args = self._get_current_tool_content( - tool_content, current_idx) + tool_name, tool_args = self._get_current_tool_content(tool_content, current_idx) if not tool_name or tool_args is None: return None @@ -612,29 +631,37 @@ def _handle_tool_args_streaming( sent_tools[current_idx]["sent_arguments"] = clean_args self.streaming_state["sent_tools"] = sent_tools - if clean_args.endswith('}'): + if clean_args.endswith("}"): self._advance_to_next_tool() - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=current_idx, - function=DeltaFunctionCall( - arguments=args_delta).model_dump( - exclude_none=True)) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=args_delta + ).model_dump(exclude_none=True), + ) + ] + ) elif not sent_args and clean_args: clean_args_delta = self._clean_delta_braces(clean_args) sent_tools[current_idx]["sent_arguments"] = clean_args self.streaming_state["sent_tools"] = sent_tools - if clean_args.endswith('}'): + if clean_args.endswith("}"): self._advance_to_next_tool() - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=current_idx, - function=DeltaFunctionCall( - arguments=clean_args_delta).model_dump( - exclude_none=True)) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=clean_args_delta + ).model_dump(exclude_none=True), + ) + ] + ) return None @@ -652,14 +679,15 @@ def _is_end_tool_calls(self, current_text: str) -> bool: search_start = pos 
+ 1 think_regions = [] - for match in re.finditer(self.thinking_tag_pattern, - current_text, - flags=re.DOTALL): + for match in re.finditer( + self.thinking_tag_pattern, current_text, flags=re.DOTALL + ): think_regions.append((match.start(), match.end())) for pos in end_token_positions: - in_think = any(pos >= t_start and pos < t_end - for t_start, t_end in think_regions) + in_think = any( + pos >= t_start and pos < t_end for t_start, t_end in think_regions + ) if not in_think: return True @@ -682,14 +710,12 @@ def extract_tool_calls_streaming( if self._should_buffer_content(delta_text): buffered_output = self._process_buffer(delta_text) - return DeltaMessage( - content=buffered_output) if buffered_output else None + return DeltaMessage(content=buffered_output) if buffered_output else None if self._is_end_tool_calls(current_text): return DeltaMessage(content=delta_text) - safe_content, potential_tag = self._split_content_for_buffering( - delta_text) + safe_content, potential_tag = self._split_content_for_buffering(delta_text) if potential_tag: self.pending_buffer += potential_tag return DeltaMessage(content=safe_content) if safe_content else None @@ -697,35 +723,39 @@ def extract_tool_calls_streaming( processed_current_text = self.preprocess_model_output(current_text) if self.tool_call_start_token not in processed_current_text: - if (self.tool_call_end_token in delta_text - and self.tool_call_start_token in current_text): + if ( + self.tool_call_end_token in delta_text + and self.tool_call_start_token in current_text + ): return None - if delta_text.strip( - ) == '' and self.tool_call_start_token in current_text: + if delta_text.strip() == "" and self.tool_call_start_token in current_text: return None - if (self._get_current_tool_index() != -1 - and self.tool_call_end_token in current_text): + if ( + self._get_current_tool_index() != -1 + and self.tool_call_end_token in current_text + ): self._reset_streaming_state() return DeltaMessage(content=delta_text) - if (self.tool_call_start_token_id is not None - and self.tool_call_start_token_id in delta_token_ids - and len(delta_token_ids) == 1): + if ( + self.tool_call_start_token_id is not None + and self.tool_call_start_token_id in delta_token_ids + and len(delta_token_ids) == 1 + ): return None - original_tool_start = self._find_tool_start_outside_thinking( - current_text) + original_tool_start = self._find_tool_start_outside_thinking(current_text) if original_tool_start is None: return None content_before_tools = self._extract_content_before_tools( - current_text, delta_text, original_tool_start) + current_text, delta_text, original_tool_start + ) if content_before_tools: return DeltaMessage(content=content_before_tools) try: - tool_content = self._extract_tool_content(current_text, - original_tool_start) + tool_content = self._extract_tool_content(current_text, original_tool_start) current_tools_count = self._detect_tools_in_text(tool_content) if current_tools_count == 0: @@ -736,24 +766,23 @@ def extract_tool_calls_streaming( self._ensure_state_arrays(current_tools_count) - return (self._handle_tool_name_streaming(tool_content, - current_tools_count) - or self._handle_tool_args_streaming( - tool_content, current_tools_count)) + return self._handle_tool_name_streaming( + tool_content, current_tools_count + ) or self._handle_tool_args_streaming(tool_content, current_tools_count) except Exception: - logger.exception("An unexpected error occurred ", - "during streaming tool call handling.") + logger.exception( + "An unexpected error occurred 
", "during streaming tool call handling." + ) return None - def _find_tool_start_outside_thinking(self, - current_text: str) -> Optional[int]: + def _find_tool_start_outside_thinking(self, current_text: str) -> Optional[int]: """ Find the start position of tool calls outside of thinking tags. - + Args: current_text: Current text to search - + Returns: Position of tool call start or None if not found """ @@ -763,26 +792,32 @@ def _find_tool_start_outside_thinking(self, if pos == -1: return None - think_regions = [(m.start(), m.end()) for m in re.finditer( - r"<think>(.*?)</think>", current_text, flags=re.DOTALL)] - in_think = any(pos >= t_start and pos < t_end - for t_start, t_end in think_regions) + think_regions = [ + (m.start(), m.end()) + for m in re.finditer( + r"<think>(.*?)</think>", current_text, flags=re.DOTALL + ) + ] + in_think = any( + pos >= t_start and pos < t_end for t_start, t_end in think_regions + ) if not in_think: return pos search_start = pos + 1 - def _extract_content_before_tools(self, current_text: str, delta_text: str, - tool_start: int) -> Optional[str]: + def _extract_content_before_tools( + self, current_text: str, delta_text: str, tool_start: int + ) -> Optional[str]: """ Extract content that appears before tool calls. - + Args: current_text: Current text delta_text: Delta text tool_start: Start position of tools - + Returns: Content before tools or None """ @@ -791,18 +826,18 @@ def _extract_content_before_tools(self, current_text: str, delta_text: str, if delta_start_pos < tool_start: content_part = delta_text if delta_start_pos + len(delta_text) > tool_start: - content_part = delta_text[:tool_start - delta_start_pos] + content_part = delta_text[: tool_start - delta_start_pos] return content_part if content_part else None return None def _extract_tool_content(self, current_text: str, tool_start: int) -> str: """ Extract tool content from current text starting at tool_start. 
- + Args: current_text: Current text tool_start: Start position of tool calls - + Returns: Extracted tool content """ diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index e6b300fd84e9..b3b8960276bc 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -12,15 +12,20 @@ from partial_json_parser.core.options import Allow from pydantic import Field -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) + ToolParser, + ToolParserManager, +) +from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -30,8 +35,7 @@ class MistralToolCall(ToolCall): - id: str = Field( - default_factory=lambda: MistralToolCall.generate_random_id()) + id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id()) @staticmethod def generate_random_id(): @@ -45,8 +49,9 @@ def is_valid_id(id: str) -> bool: def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool: - return isinstance(model_tokenizer, MistralTokenizer) \ - and model_tokenizer.version >= 11 + return ( + isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11 + ) @ToolParserManager.register_module("mistral") @@ -63,35 +68,38 @@ def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) if not isinstance(self.model_tokenizer, MistralTokenizer): - logger.info("Non-Mistral tokenizer detected when using a Mistral " - "model...") + logger.info("Non-Mistral tokenizer detected when using a Mistral model...") # initialize properties used for state when parsing tool calls in # streaming mode self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.bot_token = "[TOOL_CALLS]" self.bot_token_id = self.vocab.get(self.bot_token) self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) if _is_fn_name_regex_support(self.model_tokenizer): self.fn_name_regex = re.compile( - r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL) + r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)", re.DOTALL + ) else: self.fn_name_regex = None if self.bot_token_id is None: raise RuntimeError( "Mistral Tool Parser could not locate the tool call token in " - "the tokenizer!") - - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - if not isinstance( - self.model_tokenizer, MistralTokenizer - ) and request.tools and request.tool_choice != 'none': + "the tokenizer!" 
+ ) + + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if ( + not isinstance(self.model_tokenizer, MistralTokenizer) + and request.tools + and request.tool_choice != "none" + ): # Do not skip special tokens when using chat template # with Mistral parser as TOOL_CALL token is needed # for tool detection. @@ -113,9 +121,9 @@ def extract_tool_calls( # case -- if a tool call token is not present, return a text response if self.bot_token not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) # first remove the BOT token tool_content = model_output.replace(self.bot_token, "").strip() @@ -134,10 +142,9 @@ def extract_tool_calls( # fn_name is encoded outside serialized json dump # only arguments are serialized - function_call_arr.append({ - "name": fn_name, - "arguments": json.loads(args) - }) + function_call_arr.append( + {"name": fn_name, "arguments": json.loads(args)} + ) else: function_call_arr = json.loads(tool_content) except json.JSONDecodeError: @@ -155,8 +162,11 @@ def extract_tool_calls( function=FunctionCall( name=raw_function_call["name"], # function call args are JSON but as a string - arguments=json.dumps(raw_function_call["arguments"], - ensure_ascii=False))) + arguments=json.dumps( + raw_function_call["arguments"], ensure_ascii=False + ), + ), + ) for raw_function_call in function_call_arr ] @@ -165,14 +175,15 @@ def extract_tool_calls( return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if len(content) > 0 else None) + content=content if len(content) > 0 else None, + ) except Exception: logger.exception("Error in extracting tool call from response.") # return information to just treat the tool call as regular JSON - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=tool_content) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=tool_content + ) def extract_tool_calls_streaming( self, @@ -184,7 +195,6 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # if the tool call token is not in the tokens generated so far, append # output to contents since it's not a tool if self.bot_token not in current_text: @@ -195,8 +205,7 @@ def extract_tool_calls_streaming( # handle if we detected the BOT token which means the start of tool # calling - if (self.bot_token_id in delta_token_ids - and len(delta_token_ids) == 1): + if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1: # if it's the only token, return None, so we don't send a chat # completion any don't send a control token return None @@ -205,10 +214,8 @@ def extract_tool_calls_streaming( # sent yet, don't allow sending # an incomplete string since OpenAI only ever (as far as I have # seen) allows sending the entire tool/ function name at once. 
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: - # replace BOT token with empty string, and convert single quotes # to double to allow parsing as JSON since mistral uses single # quotes instead of double for tool calls @@ -218,15 +225,17 @@ def extract_tool_calls_streaming( # parsing on the entire array try: tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags) + parsable_arr, flags + ) except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') + logger.debug("not enough tokens to parse into JSON yet") return None # select as the current tool call the one we're on the state at - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} + current_tool_call: dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) # case -- if no tokens have been streamed for the tool, e.g. # only the array brackets, stream nothing @@ -235,9 +244,9 @@ def extract_tool_calls_streaming( # case: we are starting a new tool in the array # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): # if we're moving on to a new call, first make sure we # haven't missed anything in the previous one that was # auto-generated due to JSON completions, but wasn't @@ -247,16 +256,19 @@ def extract_tool_calls_streaming( if diff: diff = json.dumps(diff, ensure_ascii=False).replace( - self.streamed_args_for_tool[self.current_tool_id], - "") - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += diff + self.streamed_args_for_tool[self.current_tool_id], "" + ) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += diff else: delta = None else: @@ -275,15 +287,18 @@ def extract_tool_calls_streaming( if not self.current_tool_name_sent: function_name = current_tool_call.get("name") if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=MistralToolCall.generate_random_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), + ) + ] + ) self.current_tool_name_sent = True else: delta = None @@ -291,64 +306,72 @@ def extract_tool_calls_streaming( # now we know we're on the same tool call and we're streaming # arguments else: - - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) cur_arguments = current_tool_call.get("arguments") - new_text = delta_text.replace("\'", "\"") - if ('"}' in new_text): - new_text = new_text[:new_text.rindex('"}')] + new_text = delta_text.replace("'", '"') + if '"}' in 
new_text: + new_text = new_text[: new_text.rindex('"}')] if not cur_arguments and not prev_arguments: - delta = None elif not cur_arguments and prev_arguments: logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") + "INVARIANT - impossible to have arguments reset mid-arguments" + ) delta = None elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False)[:-2] - logger.debug("finding %s in %s", new_text, - cur_arguments_json) + cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[ + :-2 + ] + logger.debug("finding %s in %s", new_text, cur_arguments_json) - if (new_text not in cur_arguments_json): + if new_text not in cur_arguments_json: return None - arguments_delta = cur_arguments_json[:cur_arguments_json. - rindex(new_text) + - len(new_text)] - logger.debug("First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta + arguments_delta = cur_arguments_json[ + : cur_arguments_json.rindex(new_text) + len(new_text) + ] + logger.debug( + "First tokens in arguments received: %s", arguments_delta + ) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) + cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) + prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) + logger.debug( + "Searching for diff between \n%s\n%s", + cur_args_json, + prev_args_json, + ) argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) + cur_args_json, prev_args_json + ) logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff + ).model_dump(exclude_none=True), + ) + ] + ) + self.streamed_args_for_tool[self.current_tool_id] += argument_diff else: # try parsing it with regular JSON - if it works we're # at the end, and we need to send the difference between @@ -364,6 +387,6 @@ def extract_tool_calls_streaming( except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py index 1729fdbc9971..8d7cbbfba649 100644 --- a/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py @@ -7,12 +7,17 @@ from typing import 
TYPE_CHECKING from vllm.entrypoints.harmony_utils import parse_output_into_messages -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger if TYPE_CHECKING: @@ -23,7 +28,6 @@ @ToolParserManager.register_module("openai") class OpenAIToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -57,7 +61,8 @@ def extract_tool_calls( tool_args = json.dumps(json.loads(msg_text)) except json.JSONDecodeError: logger.exception( - "Error decoding JSON tool call from response.") + "Error decoding JSON tool call from response." + ) tool_args = msg_text else: tool_args = msg_text @@ -68,7 +73,8 @@ def extract_tool_calls( name=msg.recipient.split("functions.")[1], arguments=tool_args, ), - )) + ) + ) elif msg.channel == "final": final_content = msg_text diff --git a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py index 85dd56213c6a..114987e5600b 100644 --- a/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py @@ -9,12 +9,17 @@ from transformers import PreTrainedTokenizerBase from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger logger = init_logger(__name__) @@ -26,7 +31,7 @@ class Phi4MiniJsonToolParser(ToolParser): Tool call parser for phi-4-mini models intended for use with the examples/tool_chat_template_llama.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json + Used when --enable-auto-tool-choice --tool-call-parser phi4_mini_json are all set """ @@ -38,39 +43,42 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None: self.prev_tool_call_arr: list[dict[str, Any]] = [] self.current_tool_id: int = -1 self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[str] = [ - ] # map what has been streamed for each tool so far to a list + self.streamed_args_for_tool: list[ + str + ] = [] # map what has been streamed for each tool so far to a list self.bot_token: str = "functools" def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. 
""" logger.debug("Model output: %s", model_output) - pattern = r'functools\[(.*?)\]' + pattern = r"functools\[(.*?)\]" matches = re.search(pattern, model_output, re.DOTALL) if not matches: logger.debug("No function calls found") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: function_call_arr: list[dict[str, Any]] = [] try: - json_content = '[' + matches.group(1) + ']' + json_content = "[" + matches.group(1) + "]" function_call_arr = json.loads(json_content) - logger.debug("Successfully extracted %d function calls", - len(function_call_arr)) + logger.debug( + "Successfully extracted %d function calls", len(function_call_arr) + ) except json.JSONDecodeError as e: logger.error( - "Failed to parse function calls from model output. " - "Error: %s", str(e)) + "Failed to parse function calls from model output. Error: %s", + str(e), + ) tool_calls: list[ToolCall] = [ ToolCall( @@ -81,22 +89,25 @@ def extract_tool_calls( # function call args are JSON but as a string arguments=json.dumps( raw_function_call["arguments"] - if "arguments" in raw_function_call else - raw_function_call["parameters"], - ensure_ascii=False), - )) for raw_function_call in function_call_arr + if "arguments" in raw_function_call + else raw_function_call["parameters"], + ensure_ascii=False, + ), + ), + ) + for raw_function_call in function_call_arr ] # get any content before the tool call - ret = ExtractedToolCallInformation(tools_called=True, - tool_calls=tool_calls, - content=None) + ret = ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=None + ) return ret except Exception: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -108,5 +119,4 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Optional[DeltaMessage]: - return None diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py index 992f141bef0f..272068a6f0ac 100644 --- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py @@ -10,13 +10,19 @@ from transformers import PreTrainedTokenizerBase import vllm.envs as envs -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger logger = init_logger(__name__) @@ -34,6 +40,7 @@ class PythonicToolParser(ToolParser): Used when --enable-auto-tool-choice --tool-call-parser pythonic are all set """ + # TODO(mdepinet): Possible future improvements: # 1. Support text + tools separated by either <|python_tag|> or \n\n # 2. Support tools outside of a list (or separated by a semicolon). 
@@ -43,7 +50,8 @@ class PythonicToolParser(ToolParser): TOOL_CALL_REGEX = re.compile( r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", - re.DOTALL) + re.DOTALL, + ) def __init__(self, tokenizer: PreTrainedTokenizerBase): super().__init__(tokenizer) @@ -58,48 +66,54 @@ def current_tool_index(self, value: int) -> None: self.current_tool_id = value def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. """ is_tool_call_pattern = False try: - is_tool_call_pattern = self.TOOL_CALL_REGEX.match( - model_output, - timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS) is not None + is_tool_call_pattern = ( + self.TOOL_CALL_REGEX.match( + model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS + ) + is not None + ) except TimeoutError: - logger.warning( - "Regex timeout occurred when matching tool call pattern.") - logger.debug("Regex timeout occurred when matching user input: %s", - model_output) + logger.warning("Regex timeout occurred when matching tool call pattern.") + logger.debug( + "Regex timeout occurred when matching user input: %s", model_output + ) if not is_tool_call_pattern: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: module = ast.parse(model_output) parsed = getattr(module.body[0], "value", None) if isinstance(parsed, ast.List) and all( - isinstance(e, ast.Call) for e in parsed.elts): + isinstance(e, ast.Call) for e in parsed.elts + ): return ExtractedToolCallInformation( tools_called=True, tool_calls=[ _handle_single_tool(e) # type: ignore for e in parsed.elts ], - content=None) + content=None, + ) else: raise _UnexpectedAstError( - "Tool output must be a list of function calls") + "Tool output must be a list of function calls" + ) except Exception: logger.exception("Error in extracting tool call from response.") # Treat as regular text - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -111,7 +125,6 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - if not current_text.startswith("["): return DeltaMessage(content=delta_text) @@ -124,9 +137,11 @@ def extract_tool_calls_streaming( module = ast.parse(valid_text) parsed = getattr(module.body[0], "value", None) if not isinstance(parsed, ast.List) or not all( - isinstance(e, ast.Call) for e in parsed.elts): + isinstance(e, ast.Call) for e in parsed.elts + ): raise _UnexpectedAstError( - "Tool output must be a list of function calls") + "Tool output must be a list of function calls" + ) tool_calls = [ _handle_single_tool(e) # type: ignore for e in parsed.elts @@ -141,28 +156,30 @@ def extract_tool_calls_streaming( if len(self.streamed_args_for_tool) == index: self.streamed_args_for_tool.append("") - new_call_complete = index < len( - tool_calls) - 1 or ")]" not in added_text + new_call_complete = ( + index < len(tool_calls) - 1 or ")]" not in added_text + ) if new_call_complete: self.current_tool_index 
+= 1 - withheld_suffix = (added_text[:-2] - if not new_call_complete else "") + withheld_suffix = added_text[:-2] if not new_call_complete else "" if not new_call_complete and added_text[-2] == ")": # Function call is incomplete. Withhold the closing bracket. withheld_suffix = withheld_suffix + "}" # Strings get single quotes in the model-produced string. # JSON requires double quotes. withheld_suffix = withheld_suffix.replace("'", '"') - delta = _compute_tool_delta(self.streamed_args_for_tool[index], - new_call, index, withheld_suffix) + delta = _compute_tool_delta( + self.streamed_args_for_tool[index], new_call, index, withheld_suffix + ) if delta is not None: tool_deltas.append(delta) - if (delta.function is not None - and delta.function.arguments is not None): - self.streamed_args_for_tool[ - index] += delta.function.arguments + if ( + delta.function is not None + and delta.function.arguments is not None + ): + self.streamed_args_for_tool[index] += delta.function.arguments # HACK: serving_chat.py inspects the internal state of tool parsers # when determining its final streaming delta, automatically @@ -177,14 +194,14 @@ def extract_tool_calls_streaming( elif not added_text and self.current_tool_id > 0: # Return an empty DeltaMessage once the tool calls are all done # so that finish_reason gets set. - return DeltaMessage(content='') + return DeltaMessage(content="") else: return None except Exception: logger.exception("Error trying to handle streaming tool call.") logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") + "Skipping chunk as a result of tool streaming extraction error" + ) return None @@ -193,8 +210,7 @@ def _get_parameter_value(val: ast.expr) -> Any: return val.value elif isinstance(val, ast.Dict): if not all(isinstance(k, ast.Constant) for k in val.keys): - raise _UnexpectedAstError( - "Dict tool call arguments must have literal keys") + raise _UnexpectedAstError("Dict tool call arguments must have literal keys") return { k.value: _get_parameter_value(v) # type: ignore for k, v in zip(val.keys, val.values) @@ -214,9 +230,9 @@ def _handle_single_tool(call: ast.Call) -> ToolCall: arguments[keyword.arg] = _get_parameter_value(keyword.value) return ToolCall( type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(arguments, - ensure_ascii=False)), + function=FunctionCall( + name=function_name, arguments=json.dumps(arguments, ensure_ascii=False) + ), ) @@ -253,21 +269,25 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: # we can't fill in a valid value. 
return None if bracket_stack and bracket_stack[-1] == "{": - trailing_dict_text = text[:text.rfind("{")] + trailing_dict_text = text[: text.rfind("{")] num_keys = trailing_dict_text.count(":") num_values = trailing_dict_text.count(",") if num_keys <= num_values: return None # Incomplete property name within parameter value if bracket_stack and bracket_stack[-1] == "(": - trailing_params_text = text[:text.rfind("(")] + trailing_params_text = text[: text.rfind("(")] num_full_param_names = trailing_params_text.count("=") num_full_param_values = trailing_params_text.count(",") if num_full_param_names <= num_full_param_values: return None # Incomplete parameter name if text.endswith(","): text = text[:-1] - if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( - "[") and not text.endswith(")"): + if ( + bracket_stack + and bracket_stack[-1] == "[" + and not text.endswith("[") + and not text.endswith(")") + ): return None # Incomplete function name added_text = "" @@ -286,23 +306,29 @@ def _make_valid_python(text: str) -> Union[tuple[str, str], None]: return text + added_text, added_text -def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, - index: int, - withheld_suffix: str) -> Union[DeltaToolCall, None]: +def _compute_tool_delta( + previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str +) -> Union[DeltaToolCall, None]: new_call_args = new_call.function.arguments if withheld_suffix: assert new_call_args.endswith(withheld_suffix) - new_call_args = new_call_args[:-len(withheld_suffix)] + new_call_args = new_call_args[: -len(withheld_suffix)] if not previously_sent_args: - return DeltaToolCall(id=new_call.id, - type="function", - index=index, - function=DeltaFunctionCall( - name=new_call.function.name, - arguments=new_call_args, - )) - - arg_diff = new_call_args[len(previously_sent_args):] - return DeltaToolCall( - id=None, index=index, function=DeltaFunctionCall( - arguments=arg_diff)) if arg_diff else None + return DeltaToolCall( + id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + ), + ) + + arg_diff = new_call_args[len(previously_sent_args) :] + return ( + DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff) + ) + if arg_diff + else None + ) diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index 955813ddd340..a41ca30bf527 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -8,14 +8,20 @@ import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -24,7 +30,6 @@ @ToolParserManager.register_module("qwen3_coder") class Qwen3CoderToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): 
super().__init__(tokenizer) @@ -49,32 +54,37 @@ def __init__(self, tokenizer: AnyTokenizer): # Regex patterns self.tool_call_complete_regex = re.compile( - r"<tool_call>(.*?)</tool_call>", re.DOTALL) + r"<tool_call>(.*?)</tool_call>", re.DOTALL + ) self.tool_call_regex = re.compile( - r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL) + r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL + ) self.tool_call_function_regex = re.compile( - r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) + r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL + ) self.tool_call_parameter_regex = re.compile( r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)", - re.DOTALL) + re.DOTALL, + ) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ToolParser " - "constructor during construction.") + "constructor during construction." + ) - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) - if (self.tool_call_start_token_id is None - or self.tool_call_end_token_id is None): + if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: raise RuntimeError( "Qwen3 XML Tool parser could not locate tool call start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" + ) - logger.info("vLLM Successfully import tool parser %s !", - self.__class__.__name__) + logger.info( + "vLLM Successfully import tool parser %s !", self.__class__.__name__ + ) def _generate_tool_call_id(self) -> str: """Generate a unique tool call ID.""" @@ -100,14 +110,15 @@ def _reset_streaming_state(self): self.streaming_request = None def _get_arguments_config( - self, func_name: str, - tools: Optional[list[ChatCompletionToolsParam]]) -> dict: + self, func_name: str, tools: Optional[list[ChatCompletionToolsParam]] + ) -> dict: """Extract argument configuration for a function.""" if tools is None: return {} for config in tools: - if not hasattr(config, "type") or not (hasattr( - config, "function") and hasattr(config.function, "name")): + if not hasattr(config, "type") or not ( + hasattr(config, "function") and hasattr(config.function, "name") + ): continue if config.type == "function" and config.function.name == func_name: if not hasattr(config.function, "parameters"): @@ -119,12 +130,12 @@ def _get_arguments_config( return params else: return {} - logger.warning("Tool '%s' is not defined in the tools list.", - func_name) + logger.warning("Tool '%s' is not defined in the tools list.", func_name) return {} - def _convert_param_value(self, param_value: str, param_name: str, - param_config: dict, func_name: str) -> Any: + def _convert_param_value( + self, param_value: str, param_name: str, param_config: dict, func_name: str + ) -> Any: """Convert parameter value based on its type in the schema.""" # Handle null value for any type if param_value.lower() == "null": @@ -135,38 +146,55 @@ def _convert_param_value(self, param_value: str, param_name: str, logger.warning( "Parsed parameter '%s' is not defined in the tool " "parameters for tool '%s', directly returning the " - "string value.", param_name, func_name) + "string value.", + param_name, + func_name, + ) return param_value - if isinstance(param_config[param_name], - dict) and "type" in param_config[param_name]: + if ( + isinstance(param_config[param_name], dict) + and "type" in param_config[param_name] + ): 
param_type = str(param_config[param_name]["type"]).strip().lower() else: param_type = "string" if param_type in ["string", "str", "text", "varchar", "char", "enum"]: return param_value - elif param_type.startswith("int") or param_type.startswith( - "uint") or param_type.startswith( - "long") or param_type.startswith( - "short") or param_type.startswith("unsigned"): + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): try: return int(param_value) except (ValueError, TypeError): logger.warning( "Parsed value '%s' of parameter '%s' is not an " "integer in tool '%s', degenerating to string.", - param_value, param_name, func_name) + param_value, + param_name, + func_name, + ) return param_value elif param_type.startswith("num") or param_type.startswith("float"): try: float_param_value = float(param_value) - return float_param_value if float_param_value - int( - float_param_value) != 0 else int(float_param_value) + return ( + float_param_value + if float_param_value - int(float_param_value) != 0 + else int(float_param_value) + ) except (ValueError, TypeError): logger.warning( "Parsed value '%s' of parameter '%s' is not a float " - "in tool '%s', degenerating to string.", param_value, - param_name, func_name) + "in tool '%s', degenerating to string.", + param_value, + param_name, + func_name, + ) return param_value elif param_type in ["boolean", "bool", "binary"]: param_value = param_value.lower() @@ -174,12 +202,18 @@ def _convert_param_value(self, param_value: str, param_name: str, logger.warning( "Parsed value '%s' of parameter '%s' is not a boolean " "(`true` or `false`) in tool '%s', degenerating to " - "false.", param_value, param_name, func_name) + "false.", + param_value, + param_name, + func_name, + ) return param_value == "true" else: - if param_type in ["object", "array", "arr" - ] or param_type.startswith( - "dict") or param_type.startswith("list"): + if ( + param_type in ["object", "array", "arr"] + or param_type.startswith("dict") + or param_type.startswith("list") + ): try: param_value = json.loads(param_value) return param_value @@ -187,33 +221,37 @@ def _convert_param_value(self, param_value: str, param_name: str, logger.warning( "Parsed value '%s' of parameter '%s' cannot be " "parsed with json.loads in tool '%s', will try " - "other methods to parse it.", param_value, param_name, - func_name) + "other methods to parse it.", + param_value, + param_name, + func_name, + ) try: param_value = ast.literal_eval(param_value) # safer except (ValueError, SyntaxError, TypeError): logger.warning( "Parsed value '%s' of parameter '%s' cannot be " "converted via Python `ast.literal_eval()` in tool " - "'%s', degenerating to string.", param_value, param_name, - func_name) + "'%s', degenerating to string.", + param_value, + param_name, + func_name, + ) return param_value def _parse_xml_function_call( - self, function_call_str: str, - tools: Optional[list[ChatCompletionToolsParam]] + self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]] ) -> Optional[ToolCall]: - # Extract function name end_index = function_call_str.index(">") function_name = function_call_str[:end_index] param_config = self._get_arguments_config(function_name, tools) - parameters = function_call_str[end_index + 1:] + parameters = function_call_str[end_index + 1 :] param_dict = {} for match_text in self.tool_call_parameter_regex.findall(parameters): idx = 
match_text.index(">") param_name = match_text[:idx] - param_value = str(match_text[idx + 1:]) + param_value = str(match_text[idx + 1 :]) # Remove prefix and trailing \n if param_value.startswith("\n"): param_value = param_value[1:] @@ -221,12 +259,13 @@ def _parse_xml_function_call( param_value = param_value[:-1] param_dict[param_name] = self._convert_param_value( - param_value, param_name, param_config, function_name) + param_value, param_name, param_config, function_name + ) return ToolCall( type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(param_dict, - ensure_ascii=False)), + function=FunctionCall( + name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False) + ), ) def _get_function_calls(self, model_output: str) -> list[str]: @@ -242,8 +281,7 @@ def _get_function_calls(self, model_output: str) -> list[str]: raw_function_calls = [] for tool_call in raw_tool_calls: - raw_function_calls.extend( - self.tool_call_function_regex.findall(tool_call)) + raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call)) function_calls = [ match[0] if match[0] else match[1] for match in raw_function_calls @@ -257,16 +295,16 @@ def extract_tool_calls( ) -> ExtractedToolCallInformation: # Quick check to avoid unnecessary processing if self.tool_call_prefix not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) try: function_calls = self._get_function_calls(model_output) if len(function_calls) == 0: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) tool_calls = [ self._parse_xml_function_call(function_call_str, request.tools) @@ -277,12 +315,12 @@ def extract_tool_calls( self.prev_tool_call_arr.clear() # Clear previous calls for tool_call in tool_calls: if tool_call: - self.prev_tool_call_arr.append({ - "name": - tool_call.function.name, - "arguments": - tool_call.function.arguments, - }) + self.prev_tool_call_arr.append( + { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + ) # Extract content before tool calls content_index = model_output.find(self.tool_call_start_token) @@ -298,9 +336,9 @@ def extract_tool_calls( except Exception: logger.exception("Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -322,19 +360,19 @@ def extract_tool_calls_streaming( # Check if this is an EOS token after all tool calls are complete # Check for tool calls in text even if is_tool_call_started # is False (might have been reset after processing all tools) - if (delta_token_ids - and self.tool_call_end_token_id not in delta_token_ids): + if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids: # Count complete tool calls complete_calls = len( - self.tool_call_complete_regex.findall(current_text)) + self.tool_call_complete_regex.findall(current_text) + ) # If we have completed tool calls and populated # prev_tool_call_arr if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: # Check if all tool calls are closed open_calls = current_text.count( - 
self.tool_call_start_token) - current_text.count( - self.tool_call_end_token) + self.tool_call_start_token + ) - current_text.count(self.tool_call_end_token) if open_calls == 0: # Return empty delta for finish_reason processing return DeltaMessage(content="") @@ -370,20 +408,25 @@ def extract_tool_calls_streaming( # Handle normal content before tool calls if not self.is_tool_call_started: # Check if tool call is starting - if (self.tool_call_start_token_id in delta_token_ids - or self.tool_call_start_token in delta_text): + if ( + self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text + ): self.is_tool_call_started = True # Return any content before the tool call if self.tool_call_start_token in delta_text: - content_before = delta_text[:delta_text.index( - self.tool_call_start_token)] + content_before = delta_text[ + : delta_text.index(self.tool_call_start_token) + ] if content_before: return DeltaMessage(content=content_before) return None else: # Check if we're between tool calls - skip whitespace - if (current_text.rstrip().endswith(self.tool_call_end_token) - and delta_text.strip() == ""): + if ( + current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == "" + ): # We just ended a tool call, skip whitespace return None # Normal content, no tool call @@ -413,19 +456,20 @@ def extract_tool_calls_streaming( tool_start_idx = tool_start_positions[self.current_tool_index] # Find where this tool call ends (or current position if not ended yet) - tool_end_idx = current_text.find(self.tool_call_end_token, - tool_start_idx) + tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) if tool_end_idx == -1: tool_text = current_text[tool_start_idx:] else: - tool_text = current_text[tool_start_idx:tool_end_idx + - len(self.tool_call_end_token)] + tool_text = current_text[ + tool_start_idx : tool_end_idx + len(self.tool_call_end_token) + ] # Looking for function header if not self.header_sent: if self.tool_call_prefix in tool_text: func_start = tool_text.find(self.tool_call_prefix) + len( - self.tool_call_prefix) + self.tool_call_prefix + ) func_end = tool_text.find(">", func_start) if func_end != -1: @@ -440,38 +484,44 @@ def extract_tool_calls_streaming( # finish_reason="tool_calls" even if parsing isn't complete already_added = any( tool.get("name") == self.current_function_name - for tool in self.prev_tool_call_arr) + for tool in self.prev_tool_call_arr + ) if not already_added: - self.prev_tool_call_arr.append({ - "name": self.current_function_name, - "arguments": - "{}", # Placeholder, will be updated later - }) + self.prev_tool_call_arr.append( + { + "name": self.current_function_name, + "arguments": "{}", # Placeholder, will be updated later + } + ) # Send header with function info - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - id=self.current_tool_id, - function=DeltaFunctionCall( - name=self.current_function_name, arguments=""), - type="function", - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments="" + ), + type="function", + ) + ] + ) return None # We've sent header, now handle function body if self.in_function: # Send opening brace if not sent yet - if (not self.json_started - and self.parameter_prefix not in delta_text): + if not self.json_started and self.parameter_prefix not in delta_text: self.json_started = True - 
return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="{"), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ] + ) # Make sure json_started is set if we're processing parameters if not self.json_started: @@ -486,35 +536,38 @@ def extract_tool_calls_streaming( # prev_tool_call_arr with final arguments # Find the function content func_start = tool_text.find(self.tool_call_prefix) + len( - self.tool_call_prefix) - func_content_end = tool_text.find(self.function_end_token, - func_start) + self.tool_call_prefix + ) + func_content_end = tool_text.find(self.function_end_token, func_start) if func_content_end != -1: func_content = tool_text[func_start:func_content_end] # Parse to get the complete arguments try: parsed_tool = self._parse_xml_function_call( - func_content, self.streaming_request.tools - if self.streaming_request else None) + func_content, + self.streaming_request.tools + if self.streaming_request + else None, + ) if parsed_tool: # Update existing entry in # prev_tool_call_arr with complete args for i, tool in enumerate(self.prev_tool_call_arr): - if tool.get( - "name") == parsed_tool.function.name: + if tool.get("name") == parsed_tool.function.name: args = parsed_tool.function.arguments - self.prev_tool_call_arr[i][ - "arguments"] = args + self.prev_tool_call_arr[i]["arguments"] = args break except Exception: pass # Ignore parsing errors during streaming - result = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="}"), - ) - ]) + result = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ] + ) # Reset state for next tool self.in_function = False @@ -535,8 +588,11 @@ def extract_tool_calls_streaming( idx += len(self.parameter_prefix) # Check if we should start a new parameter - if (not self.in_param and self.param_count < len(param_starts) - and len(param_starts) > self.param_count): + if ( + not self.in_param + and self.param_count < len(param_starts) + and len(param_starts) > self.param_count + ): # Process the next parameter param_idx = param_starts[self.param_count] param_start = param_idx + len(self.parameter_prefix) @@ -561,9 +617,9 @@ def extract_tool_calls_streaming( next_param_idx = value_text.find(self.parameter_prefix) func_end_idx = value_text.find(self.function_end_token) - if next_param_idx != -1 and (func_end_idx == -1 - or next_param_idx - < func_end_idx): + if next_param_idx != -1 and ( + func_end_idx == -1 or next_param_idx < func_end_idx + ): param_end_idx = next_param_idx elif func_end_idx != -1: param_end_idx = func_end_idx @@ -585,41 +641,49 @@ def extract_tool_calls_streaming( param_value = param_value[:-1] # Store raw value for later processing - self.accumulated_params[ - self.current_param_name] = param_value + self.accumulated_params[self.current_param_name] = param_value # Get parameter configuration for type conversion param_config = self._get_arguments_config( self.current_function_name or "", self.streaming_request.tools - if self.streaming_request else None) + if self.streaming_request + else None, + ) # Convert param value to appropriate type converted_value = self._convert_param_value( - param_value, self.current_param_name, param_config, - self.current_function_name or "") + param_value, + self.current_param_name, + param_config, + 
self.current_function_name or "", + ) # Build JSON fragment based on the converted type # Use json.dumps to properly serialize the value - serialized_value = json.dumps(converted_value, - ensure_ascii=False) + serialized_value = json.dumps( + converted_value, ensure_ascii=False + ) if self.param_count == 0: - json_fragment = (f'"{self.current_param_name}": ' - f'{serialized_value}') + json_fragment = ( + f'"{self.current_param_name}": {serialized_value}' + ) else: - json_fragment = (f', "{self.current_param_name}": ' - f'{serialized_value}') + json_fragment = ( + f', "{self.current_param_name}": {serialized_value}' + ) self.param_count += 1 - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=json_fragment), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments=json_fragment), + ) + ] + ) # Continue parameter value - Not used in the current implementation # since we process complete parameters above @@ -632,31 +696,33 @@ def extract_tool_calls_streaming( # Skip past > if at start if not self.current_param_value and ">" in value_chunk: gt_idx = value_chunk.find(">") - value_chunk = value_chunk[gt_idx + 1:] + value_chunk = value_chunk[gt_idx + 1 :] - if not self.current_param_value and value_chunk.startswith( - "\n"): + if not self.current_param_value and value_chunk.startswith("\n"): value_chunk = value_chunk[1:] # Store complete value full_value = self.current_param_value + value_chunk - self.accumulated_params[ - self.current_param_name] = full_value + self.accumulated_params[self.current_param_name] = full_value # Get parameter configuration for type conversion param_config = self._get_arguments_config( self.current_function_name or "", self.streaming_request.tools - if self.streaming_request else None) + if self.streaming_request + else None, + ) # Convert the parameter value to the appropriate type converted_value = self._convert_param_value( - full_value, self.current_param_name or "", - param_config, self.current_function_name or "") + full_value, + self.current_param_name or "", + param_config, + self.current_function_name or "", + ) # Serialize the converted value - serialized_value = json.dumps(converted_value, - ensure_ascii=False) + serialized_value = json.dumps(converted_value, ensure_ascii=False) # Since we've been streaming the quoted version, # we need to close it properly @@ -665,13 +731,16 @@ def extract_tool_calls_streaming( self.current_param_value = "" # Just close the current parameter string - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments='"'), # Close the string quote - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments='"' + ), # Close the string quote + ) + ] + ) else: # Continue accumulating value value_chunk = delta_text @@ -679,29 +748,36 @@ def extract_tool_calls_streaming( # Handle first chunk after param name if not self.current_param_value and ">" in value_chunk: gt_idx = value_chunk.find(">") - value_chunk = value_chunk[gt_idx + 1:] + value_chunk = value_chunk[gt_idx + 1 :] - if not self.current_param_value and value_chunk.startswith( - "\n"): + if not self.current_param_value and value_chunk.startswith("\n"): value_chunk = value_chunk[1:] if value_chunk: # Stream the escaped delta - prev_escaped = json.dumps( - self.current_param_value, 
ensure_ascii=False - )[1:-1] if self.current_param_value else "" + prev_escaped = ( + json.dumps(self.current_param_value, ensure_ascii=False)[ + 1:-1 + ] + if self.current_param_value + else "" + ) self.current_param_value += value_chunk - full_escaped = json.dumps(self.current_param_value, - ensure_ascii=False)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + full_escaped = json.dumps( + self.current_param_value, ensure_ascii=False + )[1:-1] + delta_escaped = full_escaped[len(prev_escaped) :] if delta_escaped: - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=delta_escaped), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + ), + ) + ] + ) - return None \ No newline at end of file + return None diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py index 4ab67dfea104..1b7e4fec316e 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py @@ -9,14 +9,20 @@ import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -34,12 +40,12 @@ def __init__(self): # Tool configuration information self.tools: Union[list[ChatCompletionToolsParam], None] = None - self.tool_call_start_token: str = '<tool_call>' - self.tool_call_end_token: str = '</tool_call>' - self.function_start_token: str = '<function=' - self.function_end_token: str = '</function>' - self.parameter_start_token: str = '<parameter=' - self.parameter_end_token: str = '</parameter>' + self.tool_call_start_token: str = "<tool_call>" + self.tool_call_end_token: str = "</tool_call>" + self.function_start_token: str = "<function=" + self.function_end_token: str = "</function>" + self.parameter_start_token: str = "<parameter=" + self.parameter_end_token: str = "</parameter>" def reset_streaming_state(self): """Reset streaming parsing state""" @@ -53,16 +59,16 @@ def reset_streaming_state(self): self.current_function_open = False self.parameters = {} self.current_param_name = None - self.current_param_value = '' - self.current_param_value_converted = '' + self.current_param_value = "" + self.current_param_value_converted = "" self.current_param_is_first = False self.should_emit_end_newline = False self.start_quote_emitted = False - self.streaming_buffer = '' + self.streaming_buffer = "" self.last_processed_pos = 0 - self.text_content_buffer = '' + self.text_content_buffer = "" # state for preprocessing and deferred parsing self._pre_inside_parameter = False @@ -78,13 +84,13 @@ def reset_streaming_state(self): def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage: """ Parse single streaming XML chunk and return Delta response - This is the actual streaming 
interface that receives chunks + This is the actual streaming interface that receives chunks one by one and maintains internal state Args: xml_chunk: Single XML chunk string Returns: - DeltaMessage: Contains delta information generated by this chunk, + DeltaMessage: Contains delta information generated by this chunk, returns empty response if no complete elements """ # Record delta count before processing @@ -101,42 +107,67 @@ def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage: new_deltas = self.deltas[initial_delta_count:] # If this chunk contains </function> # but didn't generate '}', then complete it - if (self.current_call_id is not None - and self.function_end_token in xml_chunk): - + if ( + self.current_call_id is not None + and self.function_end_token in xml_chunk + ): # - Added '}' (non-empty parameter ending) # - Added '{}' (empty parameter function) - has_function_close = any((td.tool_calls and any( - (tc.function and tc.id == self.current_call_id - and isinstance(tc.function.arguments, str) and - (tc.function.arguments in ('}', '{}'))) - for tc in td.tool_calls)) for td in new_deltas) + has_function_close = any( + ( + td.tool_calls + and any( + ( + tc.function + and tc.id == self.current_call_id + and isinstance(tc.function.arguments, str) + and (tc.function.arguments in ("}", "{}")) + ) + for tc in td.tool_calls + ) + ) + for td in new_deltas + ) if not has_function_close: # Close potentially unclosed element if self.current_param_name: - self._end_element('parameter') + self._end_element("parameter") if self.current_function_name: - self._end_element('function') + self._end_element("function") # If this chunk contains </tool_call> # but didn't generate final empty delta, then complete it - if (self.current_call_id is not None - and self.tool_call_end_token in xml_chunk): - has_toolcall_close = any((td.tool_calls and any( - (tc.type == 'function' and tc.function and tc.function. 
-                         arguments == '' and tc.id == self.current_call_id)
-                        for tc in td.tool_calls)) for td in new_deltas)
+                if (
+                    self.current_call_id is not None
+                    and self.tool_call_end_token in xml_chunk
+                ):
+                    has_toolcall_close = any(
+                        (
+                            td.tool_calls
+                            and any(
+                                (
+                                    tc.type == "function"
+                                    and tc.function
+                                    and tc.function.arguments == ""
+                                    and tc.id == self.current_call_id
+                                )
+                                for tc in td.tool_calls
+                            )
+                        )
+                        for td in new_deltas
+                    )
                     if not has_toolcall_close:
                         # Close potentially unclosed element
                         if self.current_param_name:
-                            self._end_element('parameter')
+                            self._end_element("parameter")
                         if self.current_function_name:
-                            self._end_element('function')
-                        self._end_element('tool_call')
+                            self._end_element("function")
+                        self._end_element("tool_call")
             except Exception as e:
                 logger.warning("Error with fallback parsing: %s", e)

             # Merge newly generated deltas into single response
             result_delta = self._merge_new_deltas_to_single_response(
-                initial_delta_count)
+                initial_delta_count
+            )
             return result_delta
         else:
             # No complete elements, check if there's unoutput text content
@@ -145,7 +176,7 @@ def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage:
                 text_delta = DeltaMessage(content=self.text_content_buffer)
                 self._emit_delta(text_delta)
                 # Clear buffer to avoid duplicate output
-                self.text_content_buffer = ''
+                self.text_content_buffer = ""
                 return text_delta

             # If this chunk contains end tags but wasn't triggered by parser,
@@ -153,20 +184,21 @@ def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage:
             # Only execute when still on the same call as when entered,
             # to prevent accidentally closing new calls
             # in multi <tool_call> scenarios
-            if (self.current_call_id is not None
-                    and (self.function_end_token in xml_chunk
-                         or self.tool_call_end_token in xml_chunk)):
+            if self.current_call_id is not None and (
+                self.function_end_token in xml_chunk
+                or self.tool_call_end_token in xml_chunk
+            ):
                 # Close potentially unclosed element
                 if self.current_param_name:
-                    self._end_element('parameter')
-                if self.function_end_token in xml_chunk and \
-                        self.current_function_name:
-                    self._end_element('function')
+                    self._end_element("parameter")
+                if self.function_end_token in xml_chunk and self.current_function_name:
+                    self._end_element("function")
                 if self.tool_call_end_token in xml_chunk:
-                    self._end_element('tool_call')
+                    self._end_element("tool_call")
                 # Return the merged delta result generated by this fallback
                 result_delta = self._merge_new_deltas_to_single_response(
-                    initial_delta_count)
+                    initial_delta_count
+                )
                 return result_delta

             # No complete elements, return empty response
@@ -181,11 +213,11 @@ def _escape_xml_special_chars(self, text: str) -> str:
             Escaped text
         """
         xml_escapes = {
-            '&': '&amp;',
-            '<': '&lt;',
-            '>': '&gt;',
-            '"': '&quot;',
-            "'": '&apos;'
+            "&": "&amp;",
+            "<": "&lt;",
+            ">": "&gt;",
+            '"': "&quot;",
+            "'": "&apos;",
         }

         for char, escape in xml_escapes.items():
@@ -204,8 +236,7 @@ def _process_complete_xml_elements(self) -> bool:

         while self.last_processed_pos < len(self.streaming_buffer):
             # Find next complete xml element
-            element, end_pos = self._find_next_complete_element(
-                self.last_processed_pos)
+            element, end_pos = self._find_next_complete_element(self.last_processed_pos)
             if element is None:
                 # No complete element found, wait for more data
                 break
@@ -219,38 +250,46 @@ def _process_complete_xml_elements(self) -> bool:
             try:
                 preprocessed_element = self._preprocess_xml_chunk(element)
                 # Check if this is the first tool_call start
-                if ((preprocessed_element.strip().startswith('<tool_call>') or
preprocessed_element.strip().startswith('<function name=') - ) and self.tool_call_index - == 0) and self.text_content_buffer: + if ( + ( + preprocessed_element.strip().startswith("<tool_call>") + or preprocessed_element.strip().startswith("<function name=") + ) + and self.tool_call_index == 0 + ) and self.text_content_buffer: # First tool_call starts, # output previously collected text content first text_delta = DeltaMessage(content=self.text_content_buffer) self._emit_delta(text_delta) # Clear buffer for potential subsequent text content - self.text_content_buffer = '' + self.text_content_buffer = "" # If a new tool_call starts and # there are already completed tool_calls - if (preprocessed_element.strip().startswith('<tool_call>') - and self.tool_call_index > 0 and self.current_call_id): + if ( + preprocessed_element.strip().startswith("<tool_call>") + and self.tool_call_index > 0 + and self.current_call_id + ): # Reset parser state but preserve generated deltas if self.current_param_name: - self._end_element('parameter') + self._end_element("parameter") if self.current_function_open or self.current_function_name: - self._end_element('function') + self._end_element("function") # Output final tool_call tail delta final_delta = DeltaMessage( role=None, content=None, reasoning_content=None, tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall( - name=None, arguments='')) - ]) + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=""), + ) + ], + ) self._emit_delta(final_delta) # Reset XML parser and current call state self._reset_xml_parser_after_tool_call() @@ -278,10 +317,11 @@ def _should_skip_element(self, element: str) -> bool: """ # If it's a tool_call XML tag, don't skip - if element.startswith( - self.tool_call_start_token) or element.startswith( - self.function_start_token) or element.startswith( - self.parameter_start_token): + if ( + element.startswith(self.tool_call_start_token) + or element.startswith(self.function_start_token) + or element.startswith(self.parameter_start_token) + ): return False # If currently not parsing tool calls and not blank, @@ -301,8 +341,7 @@ def _should_skip_element(self, element: str) -> bool: # Skip blank content return not element - def _find_next_complete_element( - self, start_pos: int) -> tuple[Optional[str], int]: + def _find_next_complete_element(self, start_pos: int) -> tuple[Optional[str], int]: """ Find next complete XML element from specified position @@ -310,7 +349,7 @@ def _find_next_complete_element( start_pos: Position to start searching Returns: - (Complete element string, element end position), + (Complete element string, element end position), returns (None, start_pos) if no complete element found """ buffer = self.streaming_buffer[start_pos:] @@ -318,28 +357,28 @@ def _find_next_complete_element( if not buffer: return None, start_pos - if buffer.startswith('<'): + if buffer.startswith("<"): # Need to ensure no new < appears, # find the nearest one between < and > - tag_end = buffer.find('<', 1) - tag_end2 = buffer.find('>', 1) + tag_end = buffer.find("<", 1) + tag_end2 = buffer.find(">", 1) if tag_end != -1 and tag_end2 != -1: # Next nearest is < if tag_end < tag_end2: return buffer[:tag_end], start_pos + tag_end # Next nearest is >, means found XML element else: - return buffer[:tag_end2 + 1], start_pos + tag_end2 + 1 + return buffer[: tag_end2 + 1], 
start_pos + tag_end2 + 1 elif tag_end != -1: return buffer[:tag_end], start_pos + tag_end elif tag_end2 != -1: - return buffer[:tag_end2 + 1], start_pos + tag_end2 + 1 + return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 else: # If currently not parsing tool calls (entering a tool_call), # check if starts with <tool_call> if self.current_call_id is None: # Check if might be start of <tool_call> - if buffer == '<tool_call>'[:len(buffer)]: + if buffer == "<tool_call>"[: len(buffer)]: # Might be start of <tool_call>, wait for more data return None, start_pos else: @@ -351,7 +390,7 @@ def _find_next_complete_element( return None, start_pos else: # Find text content (until next < or buffer end) - next_tag_pos = buffer.find('<') + next_tag_pos = buffer.find("<") if next_tag_pos != -1: # Found text content text_content = buffer[:next_tag_pos] @@ -362,8 +401,7 @@ def _find_next_complete_element( remaining = buffer return remaining, start_pos + len(remaining) - def _merge_new_deltas_to_single_response( - self, initial_count: int) -> DeltaMessage: + def _merge_new_deltas_to_single_response(self, initial_count: int) -> DeltaMessage: """ Merge newly generated deltas from this processing into a single DeltaMessage @@ -386,7 +424,7 @@ def _merge_new_deltas_to_single_response( # Merge multiple new deltas merged_tool_calls: list[DeltaToolCall] = [] - merged_content: str = '' + merged_content: str = "" for delta in new_deltas: if delta.content: @@ -404,12 +442,13 @@ def _merge_new_deltas_to_single_response( if existing_call and existing_call.function: # Merge to existing tool_call if tool_call.function and tool_call.function.name: - existing_call.function.name = \ - tool_call.function.name - if tool_call.function \ - and tool_call.function.arguments is not None: + existing_call.function.name = tool_call.function.name + if ( + tool_call.function + and tool_call.function.arguments is not None + ): if existing_call.function.arguments is None: - existing_call.function.arguments = '' + existing_call.function.arguments = "" # For streaming JSON parameters, # simply concatenate in order @@ -421,12 +460,14 @@ def _merge_new_deltas_to_single_response( # Add new tool_call merged_tool_calls.append(tool_call) - return DeltaMessage(content=merged_content if merged_content else None, - tool_calls=merged_tool_calls) + return DeltaMessage( + content=merged_content if merged_content else None, + tool_calls=merged_tool_calls, + ) def _preprocess_xml_chunk(self, chunk: str) -> str: """ - Preprocess XML chunk, handle non-standard formats, + Preprocess XML chunk, handle non-standard formats, and escape special characters Args: @@ -439,27 +480,28 @@ def _preprocess_xml_chunk(self, chunk: str) -> str: # Check if this is a tool_call related element is_tool_call = False if chunk.startswith(self.tool_call_start_token) or chunk.startswith( - self.tool_call_end_token): + self.tool_call_end_token + ): is_tool_call = True if chunk.startswith(self.function_start_token) or chunk.startswith( - self.function_end_token): + self.function_end_token + ): is_tool_call = True if chunk.startswith(self.parameter_start_token) or chunk.startswith( - self.parameter_end_token): + self.parameter_end_token + ): is_tool_call = True # Handle <function=name> format -> <function name="name"> - processed = re.sub(r'<function=([^>]+)>', r'<function name="\1">', - chunk) + processed = re.sub(r"<function=([^>]+)>", r'<function name="\1">', chunk) # Handle <parameter=name> format -> <parameter name="name"> - processed = re.sub(r'<parameter=([^>]+)>', 
r'<parameter name="\1">', - processed) + processed = re.sub(r"<parameter=([^>]+)>", r'<parameter name="\1">', processed) original_chunk = chunk # If in parameter value accumulation mode if self._pre_inside_parameter: # Parameter end: output accumulated raw text # safely then return </parameter> - if processed.startswith('</parameter>'): + if processed.startswith("</parameter>"): body_text = self._pre_param_buffer # Trigger deferred parsing mode # literal_eval+json output in end_element @@ -478,29 +520,38 @@ def _preprocess_xml_chunk(self, chunk: str) -> str: # and pass through directly if self._pre_param_buffer == "": # Get current parameter type - param_type = self._get_param_type( - self._pre_current_param_name - ) if self._pre_current_param_name else 'string' + param_type = ( + self._get_param_type(self._pre_current_param_name) + if self._pre_current_param_name + else "string" + ) # Only these types need deferred parsing to # handle Python literals containing single quotes is_object_type = param_type in ["object"] - is_complex_type = (param_type - in ["array", "arr", "sequence"] - or param_type.startswith("dict") - or param_type.startswith("list")) + is_complex_type = ( + param_type in ["array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) # Only delay when contains container symbols # and has single quotes and is complex type - has_container_hint = ('[' in original_chunk) or ( - '{' in original_chunk) or ('(' in original_chunk) + has_container_hint = ( + ("[" in original_chunk) + or ("{" in original_chunk) + or ("(" in original_chunk) + ) # Determine if deferred parsing is needed need_defer = False if is_complex_type: # Complex type, always need deferred parsing need_defer = True - elif is_object_type and has_container_hint and ( - "'" in original_chunk): + elif ( + is_object_type + and has_container_hint + and ("'" in original_chunk) + ): # Object type with container symbols # and single quotes, need deferred parsing need_defer = True @@ -514,7 +565,7 @@ def _preprocess_xml_chunk(self, chunk: str) -> str: return "" # Parameter start: enable accumulation - if processed.startswith('<parameter name='): + if processed.startswith("<parameter name="): m = re.match(r'<parameter name="([^"]+)">', processed) if m: self._pre_current_param_name = m.group(1) @@ -533,76 +584,78 @@ def _emit_delta(self, delta: DeltaMessage): """Emit Delta response (streaming output)""" self.deltas.append(delta) - def _auto_close_open_parameter_if_needed(self, - incoming_tag: Optional[str] = None - ): - """Before starting to process new elements, - if there are unclosed tags from before, + def _auto_close_open_parameter_if_needed(self, incoming_tag: Optional[str] = None): + """Before starting to process new elements, + if there are unclosed tags from before, automatically complete their endings to the parser. - - If there are unclosed parameters, + - If there are unclosed parameters, it's equivalent to feeding `</parameter>` - - When about to start a new function or tool_call, + - When about to start a new function or tool_call, if there are unclosed functions, complete `</function>`. - - When about to start a new tool_call, + - When about to start a new tool_call, if there are unclosed tool_calls, complete `</tool_call>`. 
""" # First close unclosed parameters if self.current_param_name: - self._end_element('parameter') + self._end_element("parameter") # If about to start new function or tool_call, # and there are unclosed functions, close function first - if incoming_tag in ('function', - 'tool_call') and self.current_function_name: - self._end_element('function') + if incoming_tag in ("function", "tool_call") and self.current_function_name: + self._end_element("function") # If about to start new tool_call, # and there are unclosed tool_calls, close tool_call first - if incoming_tag == 'tool_call' and self.current_call_id: - self._end_element('tool_call') + if incoming_tag == "tool_call" and self.current_call_id: + self._end_element("tool_call") def _start_element(self, name: str, attrs: dict[str, str]): """Handle XML start element events""" - if name == 'root': + if name == "root": return - if name == 'tool_call': + if name == "tool_call": # Before opening new tool_call, # automatically complete previous unclosed tags - self._auto_close_open_parameter_if_needed('tool_call') + self._auto_close_open_parameter_if_needed("tool_call") self.parameters = {} self.current_call_id = self._get_next_call_id() self.current_param_is_first = True self.tool_call_index += 1 - elif name.startswith('function') or (name == 'function'): + elif name.startswith("function") or (name == "function"): # If missing tool_call, manually complete if not self.current_call_id: - self._start_element('tool_call', {}) + self._start_element("tool_call", {}) # Before opening new function, # automatically complete previous unclosed tags (parameter/function) - self._auto_close_open_parameter_if_needed('function') + self._auto_close_open_parameter_if_needed("function") function_name = self._extract_function_name(name, attrs) self.current_function_name = function_name self.current_function_open = True if function_name: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall( - name=function_name, arguments='')) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=function_name, arguments="" + ), + ) + ] + ) self._emit_delta(delta) - elif name.startswith('parameter') or (name == 'parameter'): + elif name.startswith("parameter") or (name == "parameter"): # If previous parameter hasn't ended normally, # complete its end first, then start new parameter - self._auto_close_open_parameter_if_needed('parameter') + self._auto_close_open_parameter_if_needed("parameter") param_name = self._extract_parameter_name(name, attrs) self.current_param_name = param_name - self.current_param_value = '' - self.current_param_value_converted = '' + self.current_param_value = "" + self.current_param_value_converted = "" self.start_quote_emitted = False # Reset start quote flag # Only output parameter name and colon, @@ -613,26 +666,36 @@ def _start_element(self, name: str, attrs: dict[str, str]): # First parameter # start JSON, only output parameter name and colon json_start = f'{{"{param_name}": ' - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall( - name=None, arguments=json_start)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + 
function=DeltaFunctionCall( + name=None, arguments=json_start + ), + ) + ] + ) self._emit_delta(delta) self.current_param_is_first = True else: # Subsequent parameters # add comma and parameter name, no quotes json_continue = f', "{param_name}": ' - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall( - name=None, arguments=json_continue)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=json_continue + ), + ) + ] + ) self._emit_delta(delta) self.current_param_is_first = False @@ -644,9 +707,9 @@ def _char_data(self, data: str): if self.defer_current_parameter: original_data = data if self.should_emit_end_newline: - original_data = '\n' + original_data + original_data = "\n" + original_data self.should_emit_end_newline = False - if original_data.endswith('\n'): + if original_data.endswith("\n"): self.should_emit_end_newline = True original_data = original_data[:-1] self.current_param_value += original_data @@ -656,20 +719,24 @@ def _char_data(self, data: str): # Check if this is the first time receiving data for this parameter # If this is the first packet of data and starts with \n, remove \n - if not self.current_param_value and data.startswith('\n'): + if not self.current_param_value and data.startswith("\n"): data = data[1:] # Output start quote for string type (if not already output) - if (param_type - in ['string', 'str', 'text', 'varchar', 'char', 'enum'] - and not self.start_quote_emitted): - quote_delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall(name=None, - arguments='"')) - ]) + if ( + param_type in ["string", "str", "text", "varchar", "char", "enum"] + and not self.start_quote_emitted + ): + quote_delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='"'), + ) + ] + ) self._emit_delta(quote_delta) self.start_quote_emitted = True @@ -679,45 +746,50 @@ def _char_data(self, data: str): original_data = data # Delay output of trailing newline if self.should_emit_end_newline: - original_data = '\n' + original_data + original_data = "\n" + original_data self.should_emit_end_newline = False - if original_data.endswith('\n'): + if original_data.endswith("\n"): self.should_emit_end_newline = True original_data = original_data[:-1] self.current_param_value += original_data # convert parameter value by param_type converted_value = self._convert_param_value( - self.current_param_value, param_type) - output_data = self._convert_for_json_streaming( - converted_value, param_type) + self.current_param_value, param_type + ) + output_data = self._convert_for_json_streaming(converted_value, param_type) - delta_data = output_data[len(self.current_param_value_converted):] + delta_data = output_data[len(self.current_param_value_converted) :] self.current_param_value_converted = output_data - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall(name=None, - arguments=delta_data)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + 
function=DeltaFunctionCall(name=None, arguments=delta_data), + ) + ] + ) self._emit_delta(delta) def _end_element(self, name: str): """Handle XML end element events""" - if name == 'root': + if name == "root": return # If function or tool_call ends and there are still unclosed parameters, # complete parameter end first - if (name.startswith('function') or name == 'function' - or name == 'tool_call') and self.current_param_name: + if ( + name.startswith("function") or name == "function" or name == "tool_call" + ) and self.current_param_name: self._auto_close_open_parameter_if_needed() - if (name.startswith('parameter') - or name == 'parameter') and self.current_param_name: + if ( + name.startswith("parameter") or name == "parameter" + ) and self.current_param_name: # End current parameter param_name = self.current_param_name param_value = self.current_param_value @@ -726,32 +798,39 @@ def _end_element(self, name: str): # perform overall parsing on raw content # accumulated in preprocessing stage and output once if self.defer_current_parameter: - raw_text = self.deferred_param_raw_value \ - if self.deferred_param_raw_value else param_value + raw_text = ( + self.deferred_param_raw_value + if self.deferred_param_raw_value + else param_value + ) parsed_value = None output_arguments = None try: # If previously delayed trailing newline, # add it back before parsing if self.should_emit_end_newline: - raw_for_parse = raw_text + '\n' + raw_for_parse = raw_text + "\n" else: raw_for_parse = raw_text parsed_value = ast.literal_eval(raw_for_parse) - output_arguments = json.dumps(parsed_value, - ensure_ascii=False) + output_arguments = json.dumps(parsed_value, ensure_ascii=False) except Exception: # Fallback: output as string as-is output_arguments = json.dumps(raw_text, ensure_ascii=False) parsed_value = raw_text - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall( - name=None, arguments=output_arguments)) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall( + name=None, arguments=output_arguments + ), + ) + ] + ) self._emit_delta(delta) # Clean up and store @@ -768,84 +847,96 @@ def _end_element(self, name: str): param_type = self._get_param_type(param_name) # convert complete parameter value by param_type - converted_value = self._convert_param_value( - param_value, param_type) + converted_value = self._convert_param_value(param_value, param_type) # Decide whether to add end quote based on parameter type - if param_type in [ - 'string', 'str', 'text', 'varchar', 'char', 'enum' - ]: + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: # For empty string parameters, need special handling if not param_value and not self.start_quote_emitted: # No start quote output, # directly output complete empty string - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall( - name=None, arguments='""')) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='""'), + ) + ] + ) self._emit_delta(delta) else: # Non-empty parameter value, output end quote - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - 
id=self.current_call_id, - type='function', - function=DeltaFunctionCall( - name=None, arguments='"')) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments='"'), + ) + ] + ) self._emit_delta(delta) self.should_emit_end_newline = False # Store converted value self.parameters[param_name] = converted_value self.current_param_name = None - self.current_param_value = '' - self.current_param_value_converted = '' + self.current_param_value = "" + self.current_param_value_converted = "" self.start_quote_emitted = False - elif name.startswith('function') or name == 'function': + elif name.startswith("function") or name == "function": # if there are parameters, close JSON object if self.parameters: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall(name=None, - arguments='}')) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="}"), + ) + ] + ) self._emit_delta(delta) # return empty object else: - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall(name=None, - arguments='{}')) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments="{}"), + ) + ] + ) self._emit_delta(delta) self.current_function_open = False - elif name == 'tool_call': + elif name == "tool_call": # Before ending tool_call, # ensure function is closed to complete missing right brace if self.current_function_open: # If there are still unclosed parameters, close them first if self.current_param_name: - self._end_element('parameter') + self._end_element("parameter") # Close function, ensure output '}' or '{}' - self._end_element('function') + self._end_element("function") # Final Delta - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.tool_call_index - 1, - id=self.current_call_id, - type='function', - function=DeltaFunctionCall(name=None, - arguments='')) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.tool_call_index - 1, + id=self.current_call_id, + type="function", + function=DeltaFunctionCall(name=None, arguments=""), + ) + ] + ) self._emit_delta(delta) # Check if there's text content to output (between tool_calls) @@ -868,30 +959,30 @@ def set_tools(self, tools: Union[list[ChatCompletionToolsParam], None]): def _get_next_call_id(self): """Generate unique call ID""" - return f'call_{uuid.uuid4().hex[:24]}' + return f"call_{uuid.uuid4().hex[:24]}" - def _extract_function_name(self, name: str, - attrs: dict[str, str]) -> Optional[str]: + def _extract_function_name(self, name: str, attrs: dict[str, str]) -> Optional[str]: """Extract function name from various formats""" - if attrs and 'name' in attrs: - return attrs['name'] + if attrs and "name" in attrs: + return attrs["name"] - if '=' in name: - parts = name.split('=', 1) - if len(parts) == 2 and parts[0] == 'function': + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "function": return parts[1] return None - def _extract_parameter_name(self, name: str, - attrs: dict[str, str]) -> Optional[str]: 
+ def _extract_parameter_name( + self, name: str, attrs: dict[str, str] + ) -> Optional[str]: """Extract parameter name from various formats""" - if attrs and 'name' in attrs: - return attrs['name'] + if attrs and "name" in attrs: + return attrs["name"] - if '=' in name: - parts = name.split('=', 1) - if len(parts) == 2 and parts[0] == 'parameter': + if "=" in name: + parts = name.split("=", 1) + if len(parts) == 2 and parts[0] == "parameter": return parts[1] return None @@ -905,30 +996,36 @@ def _get_param_type(self, param_name: str) -> str: Parameter type """ if not self.tools or not self.current_function_name: - return 'string' + return "string" for tool in self.tools: - if not hasattr(tool, 'type') or not (hasattr( - tool, 'function') and hasattr(tool.function, 'name')): + if not hasattr(tool, "type") or not ( + hasattr(tool, "function") and hasattr(tool.function, "name") + ): continue - if tool.type == 'function' and \ - tool.function.name == self.current_function_name: - if not hasattr(tool.function, 'parameters'): - return 'string' + if ( + tool.type == "function" + and tool.function.name == self.current_function_name + ): + if not hasattr(tool.function, "parameters"): + return "string" params = tool.function.parameters - if isinstance(params, dict) and 'properties' in params: - properties = params['properties'] + if isinstance(params, dict) and "properties" in params: + properties = params["properties"] if param_name in properties and isinstance( - properties[param_name], dict): + properties[param_name], dict + ): return self.repair_param_type( - str(properties[param_name].get('type', 'string'))) + str(properties[param_name].get("type", "string")) + ) elif isinstance(params, dict) and param_name in params: param_config = params[param_name] if isinstance(param_config, dict): return self.repair_param_type( - str(param_config.get('type', 'string'))) + str(param_config.get("type", "string")) + ) break - return 'string' + return "string" def repair_param_type(self, param_type: str) -> str: """Repair unknown parameter types by treating them as string @@ -938,21 +1035,25 @@ def repair_param_type(self, param_type: str) -> str: Returns: Repaired parameter type """ - if param_type in [ - 'string', 'str', 'text', 'varchar', 'char', 'enum' - ] or param_type.startswith('int') or param_type.startswith( - 'uint' - ) or param_type.startswith('long') or param_type.startswith( - 'short' - ) or param_type.startswith('unsigned') or param_type.startswith( - 'num') or param_type.startswith('float') or param_type in [ - 'boolean', 'bool', 'binary' - ] or (param_type in ["object", "array", "arr", "sequence"] - or param_type.startswith("dict") - or param_type.startswith("list")): + if ( + param_type in ["string", "str", "text", "varchar", "char", "enum"] + or param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + or param_type.startswith("num") + or param_type.startswith("float") + or param_type in ["boolean", "bool", "binary"] + or ( + param_type in ["object", "array", "arr", "sequence"] + or param_type.startswith("dict") + or param_type.startswith("list") + ) + ): return param_type else: - return 'string' + return "string" def _convert_param_value(self, param_value: str, param_type: str) -> Any: """Convert value based on parameter type @@ -963,42 +1064,51 @@ def _convert_param_value(self, param_value: str, param_type: str) -> Any: Returns: Converted value """ - if param_value.lower() 
== 'null': + if param_value.lower() == "null": return None param_type = param_type.strip().lower() - if param_type in ['string', 'str', 'text', 'varchar', 'char', 'enum']: + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: return param_value - elif (param_type.startswith('int') or param_type.startswith('uint') - or param_type.startswith('long') - or param_type.startswith('short') - or param_type.startswith('unsigned')): + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): try: return int(param_value) except (ValueError, TypeError): logger.warning( "Parsed value '%s' of parameter '%s' is not an integer " - "in tool '%s', degenerating to string.", param_value) + "in tool '%s', degenerating to string.", + param_value, + ) return param_value - elif param_type.startswith('num') or param_type.startswith('float'): + elif param_type.startswith("num") or param_type.startswith("float"): try: float_param_value: float = float(param_value) - return float_param_value if float_param_value - int( - float_param_value) != 0 else int(float_param_value) + return ( + float_param_value + if float_param_value - int(float_param_value) != 0 + else int(float_param_value) + ) except (ValueError, TypeError): logger.warning( "Parsed value '%s' of parameter '%s' is not a float " - "in tool '%s', degenerating to string.", param_value) + "in tool '%s', degenerating to string.", + param_value, + ) return param_value - elif param_type in ['boolean', 'bool', 'binary']: + elif param_type in ["boolean", "bool", "binary"]: param_value = param_value.lower() - return param_value == 'true' + return param_value == "true" else: return param_value - def _convert_for_json_streaming(self, converted_value: Any, - param_type: str) -> str: - """Convert converted_value based on + def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str: + """Convert converted_value based on whether it's empty and if type is string Args: converted_value: Converted value @@ -1008,10 +1118,10 @@ def _convert_for_json_streaming(self, converted_value: Any, Converted string for streaming output """ # Check if value is empty, but exclude numeric 0 - if converted_value is None or converted_value == '': - return '' + if converted_value is None or converted_value == "": + return "" - if param_type in ['string', 'str', 'text', 'varchar', 'char', 'enum']: + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: # String type, remove double quotes return json.dumps(converted_value, ensure_ascii=False)[1:-1] else: @@ -1023,7 +1133,7 @@ def _convert_for_json_streaming(self, converted_value: Any, def _reset_xml_parser_after_tool_call(self): """ - Each tool_call is treated as a separate XML document, + Each tool_call is treated as a separate XML document, so we need to reset the parser after each tool_call. 
""" @@ -1039,12 +1149,12 @@ def _reset_xml_parser_after_tool_call(self): self.current_function_open = False self.parameters = {} self.current_param_name = None - self.current_param_value = '' - self.current_param_value_converted = '' + self.current_param_value = "" + self.current_param_value_converted = "" self.current_param_is_first = False self.should_emit_end_newline = False self.start_quote_emitted = False - self.text_content_buffer = '' + self.text_content_buffer = "" # Reset preprocessing and deferred parsing state self._pre_inside_parameter = False @@ -1056,13 +1166,13 @@ def _reset_xml_parser_after_tool_call(self): @ToolParserManager.register_module("qwen3_xml") class Qwen3XMLToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) self.parser = StreamingXMLToolCallParser() - logger.info("vLLM Successfully import tool parser %s !", - self.__class__.__name__) + logger.info( + "vLLM Successfully import tool parser %s !", self.__class__.__name__ + ) def extract_tool_calls( self, @@ -1091,7 +1201,8 @@ def extract_tool_calls( name=tool_call.function.name, arguments=tool_call.function.arguments, ), - )) + ) + ) return ExtractedToolCallInformation( tool_calls=tool_calls, tools_called=len(tool_calls) > 0, @@ -1119,19 +1230,22 @@ def extract_tool_calls_streaming( # to correctly output tool_call field if not delta_text and delta_token_ids: open_calls = current_text.count( - self.parser.tool_call_start_token) - current_text.count( - self.parser.tool_call_end_token) + self.parser.tool_call_start_token + ) - current_text.count(self.parser.tool_call_end_token) if open_calls == 0 and self.parser.tool_call_index > 0: # If current_call_id is None, use last_completed_call_id - call_id = self.parser.current_call_id or \ - self.parser.last_completed_call_id - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.parser.tool_call_index - 1, - id=call_id, - function=DeltaFunctionCall(arguments=''), - type='function', - ) - ]) + call_id = ( + self.parser.current_call_id or self.parser.last_completed_call_id + ) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.parser.tool_call_index - 1, + id=call_id, + function=DeltaFunctionCall(arguments=""), + type="function", + ) + ] + ) return self.parser.parse_single_streaming_chunks(delta_text) diff --git a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py index 95458f07ff2a..2e7bd0d1d344 100644 --- a/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py @@ -11,14 +11,20 @@ import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ChatCompletionToolsParam, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -51,33 +57,36 @@ def __init__(self, tokenizer: AnyTokenizer): self.failed_count: int = 0 self._reset_streaming_state() - self.tool_call_start_token_id = self.vocab.get( - self.tool_call_start_token) + 
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) self.think_end_token_id = self.vocab.get(self.think_end_token) - if (self.tool_call_start_token_id is None - or self.tool_call_end_token_id is None): + if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: raise RuntimeError( "Seed_Oss XML parser: tokenizer did not include " - "<seed:tool_call> or its closing tag.") + "<seed:tool_call> or its closing tag." + ) tool_start_re = re.escape(self.tool_call_start_token) tool_end_re = re.escape(self.tool_call_end_token) self.tool_call_complete_regex = re.compile( - rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL) + rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL + ) self.tool_call_regex = re.compile( - rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", - re.DOTALL) + rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$", re.DOTALL + ) self.tool_call_function_regex = re.compile( - r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) + r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL + ) self.tool_call_parameter_regex = re.compile( - r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL + ) - logger.info("vLLM Seed-Oss XML tool parser loaded (%s).", - self.__class__.__name__) + logger.info( + "vLLM Seed-Oss XML tool parser loaded (%s).", self.__class__.__name__ + ) def _generate_tool_call_id(self) -> str: """Generate a unique tool call ID.""" @@ -100,20 +109,17 @@ def _reset_streaming_state(self): self.json_closed = False def _parse_xml_function_call( - self, function_call_str: str, - tools: Optional[list[ChatCompletionToolsParam]] + self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]] ) -> Optional[ToolCall]: - def get_arguments_config(func_name: str) -> dict: if tools is None: return {} for config in tools: if not hasattr(config, "type") or not ( - hasattr(config, "function") - and hasattr(config.function, "name")): + hasattr(config, "function") and hasattr(config.function, "name") + ): continue - if (config.type == "function" - and config.function.name == func_name): + if config.type == "function" and config.function.name == func_name: if not hasattr(config.function, "parameters"): return {} params = config.function.parameters @@ -123,12 +129,12 @@ def get_arguments_config(func_name: str) -> dict: return params else: return {} - logger.warning("Tool '%s' is not defined in the tools list.", - func_name) + logger.warning("Tool '%s' is not defined in the tools list.", func_name) return {} - def convert_param_value(param_value: str, param_name: str, - param_config: dict, func_name: str) -> Any: + def convert_param_value( + param_value: str, param_name: str, param_config: dict, func_name: str + ) -> Any: # Handle null value for any type if param_value.lower() == "null": return None @@ -138,44 +144,55 @@ def convert_param_value(param_value: str, param_name: str, logger.warning( "Parsed parameter '%s' is not defined in " "the tool parameters for tool '%s', " - "directly returning the string value.", param_name, - func_name) + "directly returning the string value.", + param_name, + func_name, + ) return param_value - if (isinstance(param_config[param_name], dict) - and "type" in param_config[param_name]): - param_type = str( - param_config[param_name]["type"]).strip().lower() + if ( + isinstance(param_config[param_name], dict) + and "type" in 
param_config[param_name] + ): + param_type = str(param_config[param_name]["type"]).strip().lower() else: param_type = "string" - if param_type in [ - "string", "str", "text", "varchar", "char", "enum" - ]: + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: return param_value - elif (param_type.startswith("int") or param_type.startswith("uint") - or param_type.startswith("long") - or param_type.startswith("short") - or param_type.startswith("unsigned")): + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): try: param_value = int(param_value) # type: ignore except (ValueError, TypeError): logger.warning( "Parsed value '%s' of parameter '%s' is not an integer in tool " - "'%s', degenerating to string.", param_value, - param_name, func_name) + "'%s', degenerating to string.", + param_value, + param_name, + func_name, + ) return param_value - elif param_type.startswith("num") or param_type.startswith( - "float"): + elif param_type.startswith("num") or param_type.startswith("float"): try: float_param_value = float(param_value) - param_value = float_param_value if float_param_value - int( - float_param_value) != 0 else int( - float_param_value) # type: ignore + param_value = ( + float_param_value # type: ignore + if float_param_value - int(float_param_value) != 0 + else int(float_param_value) # type: ignore + ) except (ValueError, TypeError): logger.warning( "Parsed value '%s' of parameter '%s' is not a float in tool " - "'%s', degenerating to string.", param_value, - param_name, func_name) + "'%s', degenerating to string.", + param_value, + param_name, + func_name, + ) return param_value elif param_type in ["boolean", "bool", "binary"]: param_value = param_value.lower() @@ -183,7 +200,10 @@ def convert_param_value(param_value: str, param_name: str, logger.warning( "Parsed value '%s' of parameter '%s' is not a boolean " "(`true` of `false`) in tool '%s', degenerating to false.", - param_value, param_name, func_name) + param_value, + param_name, + func_name, + ) return param_value == "true" else: if param_type == "object" or param_type.startswith("dict"): @@ -194,27 +214,33 @@ def convert_param_value(param_value: str, param_name: str, logger.warning( "Parsed value '%s' of parameter '%s' is not a valid JSON " "object in tool '%s', will try other methods to parse it.", - param_value, param_name, func_name) + param_value, + param_name, + func_name, + ) try: param_value = ast.literal_eval(param_value) except (ValueError, SyntaxError): logger.warning( "Parsed value '%s' of parameter '%s' cannot be converted via " "Python `ast.literal_eval()` in tool '%s', degenerating to string.", - param_value, param_name, func_name) + param_value, + param_name, + func_name, + ) return param_value # Extract function name end_index = function_call_str.index(">") function_name = function_call_str[:end_index] param_config = get_arguments_config(function_name) - parameters = function_call_str[end_index + 1:] + parameters = function_call_str[end_index + 1 :] param_dict = {} for match in self.tool_call_parameter_regex.findall(parameters): match_text = match[0] if match[0] else match[1] idx = match_text.index(">") param_name = match_text[:idx] - param_value = str(match_text[idx + 1:]) + param_value = str(match_text[idx + 1 :]) # Remove prefix and trailing \n if param_value.startswith("\n"): param_value = param_value[1:] @@ -222,12 +248,13 @@ def 
convert_param_value(param_value: str, param_name: str, param_value = param_value[:-1] param_dict[param_name] = convert_param_value( - param_value, param_name, param_config, function_name) + param_value, param_name, param_config, function_name + ) return ToolCall( type="function", - function=FunctionCall(name=function_name, - arguments=json.dumps(param_dict, - ensure_ascii=False)), + function=FunctionCall( + name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False) + ), ) def _get_function_calls(self, model_output: str) -> list[str]: @@ -243,8 +270,7 @@ def _get_function_calls(self, model_output: str) -> list[str]: raw_function_calls = [] for tool_call in raw_tool_calls: - raw_function_calls.extend( - self.tool_call_function_regex.findall(tool_call)) + raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call)) function_calls = [ match[0] if match[0] else match[1] for match in raw_function_calls @@ -258,16 +284,19 @@ def extract_tool_calls( ) -> ExtractedToolCallInformation: # Quick check to avoid unnecessary processing if self.tool_call_prefix not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) # Check if both think start and end tokens are present - if (self.think_start_token in model_output - and self.think_end_token in model_output): + if ( + self.think_start_token in model_output + and self.think_end_token in model_output + ): # Find the position of think end token think_end_index = model_output.find(self.think_end_token) + len( - self.think_end_token) + self.think_end_token + ) # Extract content after think end token result_content = model_output[think_end_index:] thinking_content = model_output[:think_end_index] @@ -278,9 +307,9 @@ def extract_tool_calls( try: function_calls = self._get_function_calls(result_content) if len(function_calls) == 0: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) tool_calls = [ self._parse_xml_function_call(function_call_str, request.tools) @@ -291,19 +320,20 @@ def extract_tool_calls( self.prev_tool_call_arr.clear() # Clear previous calls for tool_call in tool_calls: if tool_call: - self.prev_tool_call_arr.append({ - "name": - tool_call.function.name, - "arguments": - tool_call.function.arguments, - }) + self.prev_tool_call_arr.append( + { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + ) # Extract content before tool calls - tool_call_start_index = result_content.find( - self.tool_call_start_token) + tool_call_start_index = result_content.find(self.tool_call_start_token) tool_call_start_index = ( - tool_call_start_index if tool_call_start_index >= 0 else - result_content.find(self.tool_call_prefix)) + tool_call_start_index + if tool_call_start_index >= 0 + else result_content.find(self.tool_call_prefix) + ) content = thinking_content + result_content[:tool_call_start_index] return ExtractedToolCallInformation( @@ -314,9 +344,9 @@ def extract_tool_calls( except Exception: logger.exception("Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, 
@@ -334,18 +364,18 @@ def extract_tool_calls_streaming( # Check if this is an EOS token after all tool calls are complete # We check for tool calls in the text even if is_tool_call_started # is False because it might have been reset after processing all tools - if (delta_token_ids - and self.tool_call_end_token_id not in delta_token_ids): + if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids: # Count complete tool calls complete_calls = len( - self.tool_call_complete_regex.findall(current_text)) + self.tool_call_complete_regex.findall(current_text) + ) # If we have completed tool calls and populated prev_tool_call_arr if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: # Check if all tool calls are closed open_calls = current_text.count( - self.tool_call_start_token) - current_text.count( - self.tool_call_end_token) + self.tool_call_start_token + ) - current_text.count(self.tool_call_end_token) if open_calls == 0: # Return empty delta message to allow finish_reason processing return DeltaMessage(content="") @@ -375,16 +405,18 @@ def extract_tool_calls_streaming( # Check if there are more tool calls if self.current_tool_index >= current_text.count( - self.tool_call_start_token): + self.tool_call_start_token + ): # No more tool calls self.is_tool_call_started = False # Continue processing next tool return None # Check if end thinking - if (not self.is_thinking_end - and (self.think_end_token_id in delta_token_ids - or self.think_end_token in delta_text)): + if not self.is_thinking_end and ( + self.think_end_token_id in delta_token_ids + or self.think_end_token in delta_text + ): self.is_thinking_end = True # If thinking hasn't ended yet, don't process any tool calls @@ -394,20 +426,25 @@ def extract_tool_calls_streaming( # Handle normal content before tool calls if not self.is_tool_call_started: # Check if tool call is starting - if (self.tool_call_start_token_id in delta_token_ids - or self.tool_call_start_token in delta_text): + if ( + self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text + ): self.is_tool_call_started = True # Return any content before the tool call if self.tool_call_start_token in delta_text: - content_before = delta_text[:delta_text.index( - self.tool_call_start_token)] + content_before = delta_text[ + : delta_text.index(self.tool_call_start_token) + ] if content_before: return DeltaMessage(content=content_before) return None else: # Check if we're between tool calls - skip whitespace - if (current_text.rstrip().endswith(self.tool_call_end_token) - and delta_text.strip() == ""): + if ( + current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == "" + ): # We just ended a tool call, skip whitespace return None # Normal content, no tool call @@ -423,9 +460,11 @@ def extract_tool_calls_streaming( # We're in a tool call, find the current tool call portion # Need to find the correct tool call based on current_tool_index # Only process tool calls after think_end_token - think_end_index = current_text.find(self.think_end_token) + len( - self.think_end_token - ) if self.think_end_token in current_text else 0 + think_end_index = ( + current_text.find(self.think_end_token) + len(self.think_end_token) + if self.think_end_token in current_text + else 0 + ) tool_starts: list[int] = [] idx = think_end_index while True: @@ -441,26 +480,26 @@ def extract_tool_calls_streaming( tool_start_idx = tool_starts[self.current_tool_index] # Find where this tool call ends (or current position if not 
ended yet) - tool_end_idx = current_text.find(self.tool_call_end_token, - tool_start_idx) + tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) if tool_end_idx == -1: tool_text = current_text[tool_start_idx:] else: - tool_text = current_text[tool_start_idx:tool_end_idx + - len(self.tool_call_end_token)] + tool_text = current_text[ + tool_start_idx : tool_end_idx + len(self.tool_call_end_token) + ] # Looking for function header if not self.header_sent: if self.tool_call_prefix in tool_text: func_start = tool_text.find(self.tool_call_prefix) + len( - self.tool_call_prefix) + self.tool_call_prefix + ) func_end = tool_text.find(">", func_start) if func_end != -1: # Found complete function name self.current_function_name = tool_text[func_start:func_end] - self.current_tool_id = self._generate_tool_call_id( - ) # type: ignore + self.current_tool_id = self._generate_tool_call_id() # type: ignore self.header_sent = True self.in_function = True @@ -468,38 +507,44 @@ def extract_tool_calls_streaming( # This ensures finish_reason="tool_calls" even if parsing isn't complete already_added = any( tool.get("name") == self.current_function_name - for tool in self.prev_tool_call_arr) + for tool in self.prev_tool_call_arr + ) if not already_added: - self.prev_tool_call_arr.append({ - "name": self.current_function_name, - "arguments": - "{}", # Placeholder, will be updated later - }) + self.prev_tool_call_arr.append( + { + "name": self.current_function_name, + "arguments": "{}", # Placeholder, will be updated later + } + ) # Send header with function info - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - id=self.current_tool_id, - function=DeltaFunctionCall( - name=self.current_function_name, arguments=""), - type="function", - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments="" + ), + type="function", + ) + ] + ) return None # We've sent header, now handle function body if self.in_function: # Send opening brace if not sent yet - if (not self.json_started - and self.parameter_prefix not in delta_text): + if not self.json_started and self.parameter_prefix not in delta_text: self.json_started = True - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="{"), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ] + ) # Make sure json_started is set if we're processing parameters if not self.json_started: @@ -513,34 +558,38 @@ def extract_tool_calls_streaming( # Extract the complete tool call to update prev_tool_call_arr with final arguments # Find the function content func_start = tool_text.find(self.tool_call_prefix) + len( - self.tool_call_prefix) - func_content_end = tool_text.find(self.function_end_token, - func_start) + self.tool_call_prefix + ) + func_content_end = tool_text.find(self.function_end_token, func_start) if func_content_end != -1: func_content = tool_text[func_start:func_content_end] # Parse to get the complete arguments try: parsed_tool = self._parse_xml_function_call( - func_content, request.tools if request else None) + func_content, request.tools if request else None + ) if parsed_tool: # Update existing entry in prev_tool_call_arr with complete arguments for i, tool in enumerate(self.prev_tool_call_arr): - if tool.get( - 
"name") == parsed_tool.function.name: + if tool.get("name") == parsed_tool.function.name: self.prev_tool_call_arr[i]["arguments"] = ( - parsed_tool.function.arguments) + parsed_tool.function.arguments + ) break except Exception: logger.warning( "Failed to parse tool arguments during streaming.", - exc_info=True) + exc_info=True, + ) - result = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="}"), - ) - ]) + result = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ] + ) # Reset state for next tool self.in_function = False @@ -583,8 +632,7 @@ def extract_tool_calls_streaming( value_text = value_text[1:] # Find where this parameter ends - param_end_idx = value_text.find( - self.parameter_end_token) + param_end_idx = value_text.find(self.parameter_end_token) if param_end_idx != -1: # Complete parameter found param_value = value_text[:param_end_idx] @@ -594,22 +642,33 @@ def extract_tool_calls_streaming( # Build complete JSON fragment for this parameter if self.param_count == 0: json_fragment = ( - '"' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + '"' + + self.current_param_name + + '": "' + + json.dumps(param_value)[1:-1] + + '"' + ) else: json_fragment = ( - ', "' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + ', "' + + self.current_param_name + + '": "' + + json.dumps(param_value)[1:-1] + + '"' + ) self.param_count += 1 - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=json_fragment), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment + ), + ) + ] + ) # Continue parameter value if self.in_param: @@ -621,29 +680,34 @@ def extract_tool_calls_streaming( # Skip past > if at start if not self.current_param_value and ">" in value_chunk: gt_idx = value_chunk.find(">") - value_chunk = value_chunk[gt_idx + 1:] + value_chunk = value_chunk[gt_idx + 1 :] - if not self.current_param_value and value_chunk.startswith( - "\n"): + if not self.current_param_value and value_chunk.startswith("\n"): value_chunk = value_chunk[1:] # Calculate incremental JSON full_value = self.current_param_value + value_chunk - prev_escaped = (json.dumps(self.current_param_value)[1:-1] - if self.current_param_value else "") + prev_escaped = ( + json.dumps(self.current_param_value)[1:-1] + if self.current_param_value + else "" + ) full_escaped = json.dumps(full_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + delta_escaped = full_escaped[len(prev_escaped) :] self.in_param = False self.current_param_value = "" - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=delta_escaped + '"'), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + '"' + ), + ) + ] + ) else: # Continue accumulating value value_chunk = delta_text @@ -651,29 +715,32 @@ def extract_tool_calls_streaming( # Handle first chunk after param name if not self.current_param_value and ">" in value_chunk: gt_idx = value_chunk.find(">") - value_chunk = value_chunk[gt_idx + 1:] + value_chunk = value_chunk[gt_idx + 1 :] - if not self.current_param_value and value_chunk.startswith( - "\n"): + if not 
self.current_param_value and value_chunk.startswith("\n"): value_chunk = value_chunk[1:] if value_chunk: # Stream the escaped delta - prev_escaped = (json.dumps( - self.current_param_value)[1:-1] - if self.current_param_value else "") + prev_escaped = ( + json.dumps(self.current_param_value)[1:-1] + if self.current_param_value + else "" + ) self.current_param_value += value_chunk - full_escaped = json.dumps( - self.current_param_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + full_escaped = json.dumps(self.current_param_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped) :] if delta_escaped: - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=delta_escaped), - ) - ]) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + ), + ) + ] + ) return None diff --git a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py index a20d18eb5254..34bd372b2060 100644 --- a/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py @@ -8,13 +8,19 @@ import regex as re -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -35,9 +41,7 @@ class Step3ToolParser(ToolParser): TOOL_CALL_BEGIN = "<|tool_call_begin|>" TOOL_CALL_END = "<|tool_call_end|>" TOOL_SEP = "<|tool_sep|>" - SPECIAL_TOKENS = [ - TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END - ] + SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END] def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -46,18 +50,16 @@ def __init__(self, tokenizer: AnyTokenizer): self.tool_block_started = False self.tool_block_finished = False - def adjust_request( - self, request: ChatCompletionRequest) -> ChatCompletionRequest: - if request.tools and request.tool_choice != 'none': + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != "none": request.skip_special_tokens = False return request @staticmethod def _parse_steptml_invoke( - action_text: str + action_text: str, ) -> tuple[Optional[str], Optional[dict[str, str]]]: - func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', - action_text) + func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', action_text) if not func_name_match: return None, None func_name = func_name_match.group(1) @@ -65,7 +67,8 @@ def _parse_steptml_invoke( params: dict[str, str] = {} param_matches = re.findall( r'<steptml:parameter name="([^"]+)">([^<]*)</steptml:parameter>', - action_text) + action_text, + ) for name, value in param_matches: params[name] = value.strip() return func_name, params @@ -95,11 +98,13 @@ def _cast_arguments( params[key] = float(value) elif typ == "boolean": 
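+                        # Only "true"/"false" (case-insensitive) are coerced to bool;
+                        # any other string is passed through unchanged.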
lower_val = value.lower() - params[key] = lower_val == "true" if lower_val in ( - "true", "false") else value + params[key] = ( + lower_val == "true" + if lower_val in ("true", "false") + else value + ) elif typ == "null": - params[key] = None if value.lower( - ) == "null" else value + params[key] = None if value.lower() == "null" else value break return params @@ -113,13 +118,12 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # The main loop processes the stream from the last known position. while True: if self.position >= len(current_text): return None # We've processed the entire stream. - unprocessed_text = current_text[self.position:] + unprocessed_text = current_text[self.position :] # STATE: After all tools are done, all subsequent text is content. if self.tool_block_finished: @@ -135,8 +139,10 @@ def extract_tool_calls_streaming( start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN) if start_pos == -1: - if self.TOOL_CALLS_BEGIN.startswith( - unprocessed_text.strip()) and unprocessed_text: + if ( + self.TOOL_CALLS_BEGIN.startswith(unprocessed_text.strip()) + and unprocessed_text + ): return None # It's a prefix, wait. self.position = len(current_text) return DeltaMessage(content=unprocessed_text) @@ -157,9 +163,9 @@ def extract_tool_calls_streaming( continue # Check if we are between tool calls. - tool_finished = ( - self.current_tool_id != -1 and - self.prev_tool_call_arr[self.current_tool_id].get("finished")) + tool_finished = self.current_tool_id != -1 and self.prev_tool_call_arr[ + self.current_tool_id + ].get("finished") if self.current_tool_id == -1 or tool_finished: if unprocessed_text.startswith(self.TOOL_CALL_BEGIN): self.position += len(self.TOOL_CALL_BEGIN) @@ -170,8 +176,7 @@ def extract_tool_calls_streaming( self.current_tool_name_sent = False while len(self.prev_tool_call_arr) <= self.current_tool_id: self.prev_tool_call_arr.append({}) - self.prev_tool_call_arr[ - self.current_tool_id]["finished"] = False + self.prev_tool_call_arr[self.current_tool_id]["finished"] = False continue if self.TOOL_CALL_BEGIN.startswith(unprocessed_text): @@ -179,63 +184,65 @@ def extract_tool_calls_streaming( # STATE: Parsing an active tool call. if self.current_tool_id != -1 and not self.prev_tool_call_arr[ - self.current_tool_id].get("finished", False): + self.current_tool_id + ].get("finished", False): end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END) if end_tool_pos == -1: tool_body = unprocessed_text else: tool_body = unprocessed_text[:end_tool_pos] - if end_tool_pos == -1 and self.TOOL_CALL_END.startswith( - tool_body): + if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(tool_body): return None - function_name, arguments = self._parse_steptml_invoke( - tool_body) + function_name, arguments = self._parse_steptml_invoke(tool_body) if not function_name: return None - tool_call_arr = { - "name": function_name, - "parameters": arguments or {} - } + tool_call_arr = {"name": function_name, "parameters": arguments or {}} # Send the function name as soon as it's parsed. 
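+                # The name is emitted once here with a fresh tool-call id; the
+                # arguments are only streamed once the tool call is complete.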
if not self.current_tool_name_sent: self.current_tool_name_sent = True - self.prev_tool_call_arr[self.current_tool_id].update( - tool_call_arr) - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=f"chatcmpl-tool-{random_uuid()}", - function=DeltaFunctionCall( - name=function_name)) - ]) + self.prev_tool_call_arr[self.current_tool_id].update(tool_call_arr) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall(name=function_name), + ) + ] + ) # Update our internal state with the latest parsed arguments. - self.prev_tool_call_arr[ - self.current_tool_id].update( # noqa: E501 - tool_call_arr) + self.prev_tool_call_arr[self.current_tool_id].update( # noqa: E501 + tool_call_arr + ) # Only send arguments when the tool call is complete. if end_tool_pos != -1: self.position += end_tool_pos + len(self.TOOL_CALL_END) - self.prev_tool_call_arr[ - self.current_tool_id]["finished"] = True + self.prev_tool_call_arr[self.current_tool_id]["finished"] = True final_args = self._cast_arguments( function_name, tool_call_arr.get("parameters", {}), # type: ignore - request) + request, + ) if final_args: - final_args_json = json.dumps(final_args, - ensure_ascii=False) - return DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=final_args_json)) - ]) + final_args_json = json.dumps(final_args, ensure_ascii=False) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=final_args_json + ), + ) + ] + ) # If tool is not finished, return None to wait for more tokens. return None @@ -248,15 +255,15 @@ def extract_tool_calls( request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: if self.TOOL_CALLS_BEGIN not in model_output: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1) if self.TOOL_CALLS_END not in rest: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1) content = (pre_text + post_text).strip() @@ -276,21 +283,22 @@ def extract_tool_calls( if type_part.strip() != "function": continue - function_name, params_dict = self._parse_steptml_invoke( - invoke_part) + function_name, params_dict = self._parse_steptml_invoke(invoke_part) if function_name and params_dict is not None: - params_dict = self._cast_arguments(function_name, params_dict, - request) + params_dict = self._cast_arguments(function_name, params_dict, request) params_str = json.dumps(params_dict, ensure_ascii=False) tool_calls.append( - ToolCall(function=FunctionCall(name=function_name, - arguments=params_str))) + ToolCall( + function=FunctionCall(name=function_name, arguments=params_str) + ) + ) if tool_calls: return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, - content=content if content else None) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + content=content if content else None, + ) + return ExtractedToolCallInformation( + tools_called=False, 
tool_calls=[], content=model_output + ) diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py index aa41cd6dc53e..e076ab38e336 100644 --- a/vllm/entrypoints/openai/tool_parsers/utils.py +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -22,7 +22,7 @@ def find_common_prefix(s1: str, s2: str) -> str: e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '{"fruit": "ap' """ - prefix = '' + prefix = "" min_length = min(len(s1), len(s2)) for i in range(0, min_length): if s1[i] == s2[i]: @@ -40,7 +40,7 @@ def find_common_suffix(s1: str, s2: str) -> str: e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' """ - suffix = '' + suffix = "" min_length = min(len(s1), len(s2)) for i in range(1, min_length + 1): if s1[-i] == s2[-i] and not s1[-i].isalnum(): @@ -70,15 +70,15 @@ def extract_intermediate_diff(curr: str, old: str) -> str: """ suffix = find_common_suffix(curr, old) - old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + old = old[::-1].replace(suffix[::-1], "", 1)[::-1] prefix = find_common_prefix(curr, old) diff = curr if len(suffix): - diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1] if len(prefix): # replace the prefix only once in case it's mirrored - diff = diff.replace(prefix, '', 1) + diff = diff.replace(prefix, "", 1) return diff diff --git a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py index 484e904cd8c3..c1f0d29cc087 100644 --- a/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py @@ -8,13 +8,19 @@ import regex as re from vllm.entrypoints.chat_utils import make_tool_call_id -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, - ExtractedToolCallInformation, - FunctionCall, ToolCall) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser, ToolParserManager) + ToolParser, + ToolParserManager, +) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -24,7 +30,6 @@ @ToolParserManager.register_module("xlam") class xLAMToolParser(ToolParser): - def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) @@ -32,8 +37,7 @@ def __init__(self, tokenizer: AnyTokenizer): self.prev_tool_calls: list[dict] = [] self.current_tool_id = -1 self.current_tool_name_sent = False - self.streamed_args: list[str] = [ - ] # Track arguments sent for each tool + self.streamed_args: list[str] = [] # Track arguments sent for each tool # For backward compatibility with tests self.current_tools_sent: list[bool] = [] @@ -57,7 +61,8 @@ def __init__(self, tokenizer: AnyTokenizer): } def preprocess_model_output( - self, model_output: str) -> tuple[Optional[str], Optional[str]]: + self, model_output: str + ) -> tuple[Optional[str], Optional[str]]: """ Preprocess the model output to extract content and potential tool calls. 
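+        When no tool-call structure is detected, (model_output, None) is
+        returned so the text is treated as normal content.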
Returns: @@ -66,8 +71,7 @@ def preprocess_model_output( # Check for thinking tag thinking_match = re.search(self.thinking_tag_pattern, model_output) if thinking_match: - content = model_output[:thinking_match.start() + - len("</think>")].strip() + content = model_output[: thinking_match.start() + len("</think>")].strip() thinking_content = thinking_match.group(1).strip() # Try to parse the thinking content as JSON @@ -94,8 +98,7 @@ def preprocess_model_output( try: json.loads(json_str) # Extract content by removing the JSON code block - content = re.sub(json_pattern, "", - model_output).strip() + content = re.sub(json_pattern, "", model_output).strip() return content, json_str except json.JSONDecodeError: continue @@ -107,28 +110,30 @@ def preprocess_model_output( return None, model_output except json.JSONDecodeError: # Even if it's not valid JSON yet, it might be a tool call in progress - if ("{" in model_output and "name" in model_output - and "arguments" in model_output): + if ( + "{" in model_output + and "name" in model_output + and "arguments" in model_output + ): return None, model_output # If no tool calls found, return the original output as content return model_output, None def extract_tool_calls( - self, model_output: str, - request: ChatCompletionRequest) -> ExtractedToolCallInformation: + self, model_output: str, request: ChatCompletionRequest + ) -> ExtractedToolCallInformation: """ Extract tool calls from a complete model output. """ try: # Preprocess the model output - content, potential_tool_calls = self.preprocess_model_output( - model_output) + content, potential_tool_calls = self.preprocess_model_output(model_output) if not potential_tool_calls: - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=content) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=content + ) # Parse the potential tool calls as JSON tool_calls_data = json.loads(potential_tool_calls) @@ -145,8 +150,11 @@ def extract_tool_calls( tool_calls: list[ToolCall] = [] for idx, call in enumerate(tool_calls_data): - if (not isinstance(call, dict) or "name" not in call - or "arguments" not in call): + if ( + not isinstance(call, dict) + or "name" not in call + or "arguments" not in call + ): logger.debug("Invalid tool call format at index %d", idx) continue @@ -155,8 +163,11 @@ def extract_tool_calls( type="function", function=FunctionCall( name=call["name"], - arguments=(json.dumps(call["arguments"]) if isinstance( - call["arguments"], dict) else call["arguments"]), + arguments=( + json.dumps(call["arguments"]) + if isinstance(call["arguments"], dict) + else call["arguments"] + ), ), ) tool_calls.append(tool_call) @@ -169,9 +180,9 @@ def extract_tool_calls( except Exception as e: logger.exception("Error extracting tool calls: %s", str(e)) - return ExtractedToolCallInformation(tools_called=False, - tool_calls=[], - content=model_output) + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) def extract_tool_calls_streaming( self, @@ -189,26 +200,36 @@ def extract_tool_calls_streaming( # First, check for a definitive start of a tool call block. # This prevents premature parsing of incomplete output. 
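+        # The heuristics below look for ```json fences, [TOOL_CALLS] or
+        # <tool_call> markers, or a bare JSON array before committing to
+        # tool-call parsing.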
stripped_text = current_text.strip() - preprocessed_content, preprocessed_tool_calls = ( - self.preprocess_model_output(current_text)) + preprocessed_content, preprocessed_tool_calls = self.preprocess_model_output( + current_text + ) # For JSON code blocks, we need to detect them earlier, even if incomplete - has_potential_json_block = ("```json" in current_text - or "```\n[" in current_text - or "[TOOL_CALLS]" in current_text - or "<tool_call>" in current_text) + has_potential_json_block = ( + "```json" in current_text + or "```\n[" in current_text + or "[TOOL_CALLS]" in current_text + or "<tool_call>" in current_text + ) is_tool_call_block = ( stripped_text.startswith("[") or stripped_text.startswith("<tool_call>") - or stripped_text.startswith("[TOOL_CALLS]") or + or stripped_text.startswith("[TOOL_CALLS]") + or # Check if we have thinking tags with JSON-like content following - ("</think>[" in current_text) or + ("</think>[" in current_text) + or # Check if the text contains a JSON array after preprocessing - preprocessed_tool_calls is not None or + preprocessed_tool_calls is not None + or # For JSON code blocks, detect early if we see enough structure - (has_potential_json_block and '"name"' in current_text - and '"arguments"' in current_text)) + ( + has_potential_json_block + and '"name"' in current_text + and '"arguments"' in current_text + ) + ) if not is_tool_call_block: return DeltaMessage(content=delta_text) @@ -225,8 +246,9 @@ def extract_tool_calls_streaming( # Try parsing as JSON to check for complete tool calls try: # Use preprocessed tool calls if available - tool_calls_text = (preprocessed_tool_calls if - preprocessed_tool_calls else current_text) + tool_calls_text = ( + preprocessed_tool_calls if preprocessed_tool_calls else current_text + ) parsed_tools = json.loads(tool_calls_text) if isinstance(parsed_tools, list): # Update our tool array for next time @@ -237,11 +259,15 @@ def extract_tool_calls_streaming( # Check for test-specific state setup (current_tools_sent) # This handles the case where tests manually set current_tools_sent - if (hasattr(self, "current_tools_sent") # type: ignore - and len(self.current_tools_sent) > 0): + if ( + hasattr(self, "current_tools_sent") # type: ignore + and len(self.current_tools_sent) > 0 + ): # If current_tools_sent is set to [False], it means the test wants us to send the name - if (len(self.current_tools_sent) == 1 - and self.current_tools_sent[0] is False): + if ( + len(self.current_tools_sent) == 1 + and self.current_tools_sent[0] is False + ): # Extract the function name using regex name_pattern = r'"name"\s*:\s*"([^"]+)"' name_match = re.search(name_pattern, current_text) @@ -250,51 +276,53 @@ def extract_tool_calls_streaming( # The test expects us to send just the name first tool_id = make_tool_call_id() - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=0, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True), # type: ignore - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=0, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name + ).model_dump(exclude_none=True), # type: ignore + ) + ] + ) # Update state to reflect that we've sent the name self.current_tools_sent = [True] self.current_tool_id = 0 self.streaming_state["current_tool_index"] = 0 if len(self.streaming_state["sent_tools"]) == 0: - self.streaming_state["sent_tools"].append({ - "sent_name": - True, - "sent_arguments_prefix": - False, - 
"sent_arguments": - "", - }) + self.streaming_state["sent_tools"].append( + { + "sent_name": True, + "sent_arguments_prefix": False, + "sent_arguments": "", + } + ) else: - self.streaming_state["sent_tools"][0][ - "sent_name"] = True + self.streaming_state["sent_tools"][0]["sent_name"] = True self.current_tool_name_sent = True return delta # Use regex to identify tool calls in the output # Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks - search_text = (preprocessed_tool_calls - if preprocessed_tool_calls else current_text) + search_text = ( + preprocessed_tool_calls if preprocessed_tool_calls else current_text + ) # For JSON code blocks that aren't complete yet, try to extract the JSON content if not preprocessed_tool_calls and has_potential_json_block: # Try to extract the JSON array from within the code block - json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)", - current_text) + json_match = re.search( + r"```(?:json)?\s*([\s\S]*?)(?:```|$)", current_text + ) if json_match: potential_json = json_match.group(1).strip() # Use this as search text even if it's incomplete if potential_json.startswith("[") and ( - '"name"' in potential_json - and '"arguments"' in potential_json): + '"name"' in potential_json and '"arguments"' in potential_json + ): search_text = potential_json # Try to find complete tool names first @@ -306,8 +334,7 @@ def extract_tool_calls_streaming( if tool_count == 0: # Check if we're in the middle of parsing a tool name partial_name_pattern = r'"name"\s*:\s*"([^"]*)' - partial_matches = list( - re.finditer(partial_name_pattern, search_text)) + partial_matches = list(re.finditer(partial_name_pattern, search_text)) if partial_matches: # We have a partial tool name - not ready to emit yet return None @@ -317,14 +344,13 @@ def extract_tool_calls_streaming( # Ensure our state arrays are large enough while len(self.streaming_state["sent_tools"]) < tool_count: - self.streaming_state["sent_tools"].append({ - "sent_name": - False, - "sent_arguments_prefix": - False, - "sent_arguments": - "", - }) + self.streaming_state["sent_tools"].append( + { + "sent_name": False, + "sent_arguments_prefix": False, + "sent_arguments": "", + } + ) while len(self.streaming_state["tool_ids"]) < tool_count: self.streaming_state["tool_ids"].append(None) @@ -337,14 +363,13 @@ def extract_tool_calls_streaming( next_idx = current_idx + 1 # If tool at next_idx has not been sent yet - if (next_idx < tool_count - and not self.streaming_state["sent_tools"][next_idx] - ["sent_name"]): + if ( + next_idx < tool_count + and not self.streaming_state["sent_tools"][next_idx]["sent_name"] + ): # Update indexes self.streaming_state["current_tool_index"] = next_idx - self.current_tool_id = ( - next_idx # For backward compatibility - ) + self.current_tool_id = next_idx # For backward compatibility current_idx = next_idx # Extract the tool name @@ -354,21 +379,20 @@ def extract_tool_calls_streaming( tool_id = f"call_{current_idx}_{random_uuid()}" self.streaming_state["tool_ids"][current_idx] = tool_id - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - type="function", - id=tool_id, - function=DeltaFunctionCall( - name=tool_name).model_dump( - exclude_none=True), # type: ignore - ) - ]) - self.streaming_state["sent_tools"][current_idx][ - "sent_name"] = True - self.current_tool_name_sent = ( - True # For backward compatibility + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + type="function", + 
id=tool_id, + function=DeltaFunctionCall(name=tool_name).model_dump( + exclude_none=True + ), # type: ignore + ) + ] ) + self.streaming_state["sent_tools"][current_idx]["sent_name"] = True + self.current_tool_name_sent = True # For backward compatibility # Keep track of streamed args for backward compatibility while len(self.streamed_args) <= current_idx: @@ -381,7 +405,8 @@ def extract_tool_calls_streaming( # Support both regular and empty argument objects # First, check for the empty arguments case: "arguments": {} empty_args_pattern = ( - r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}') + r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}' + ) empty_args_match = re.search(empty_args_pattern, search_text) # Check if this tool has empty arguments @@ -391,36 +416,39 @@ def extract_tool_calls_streaming( for i in range(tool_count): if i == current_idx: # If this is our current tool and it has empty arguments - if not self.streaming_state["sent_tools"][ - current_idx]["sent_arguments_prefix"]: + if not self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix" + ]: # Send empty object - self.streaming_state["sent_tools"][ - current_idx][ - "sent_arguments_prefix"] = True - self.streaming_state["sent_tools"][ - current_idx]["sent_arguments"] = "{}" + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix" + ] = True + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments" + ] = "{}" # Update streamed_args for backward compatibility while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += "{}" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments="{}"). - model_dump( - exclude_none=True), # type: ignore - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{}" + ).model_dump(exclude_none=True), # type: ignore + ) + ] + ) # Move to next tool if available if current_idx < tool_count - 1: - self.streaming_state[ - "current_tool_index"] += 1 + self.streaming_state["current_tool_index"] += 1 self.current_tool_id = self.streaming_state[ - "current_tool_index"] + "current_tool_index" + ] return delta @@ -439,72 +467,77 @@ def extract_tool_calls_streaming( # Parse the entire JSON structure to properly extract arguments for each tool try: parsed_tools = json.loads(search_text) - if isinstance( - parsed_tools, - list) and current_idx < len(parsed_tools): + if isinstance(parsed_tools, list) and current_idx < len( + parsed_tools + ): current_tool = parsed_tools[current_idx] - if isinstance(current_tool.get("arguments"), - dict): - args_text = json.dumps( - current_tool["arguments"]) + if isinstance(current_tool.get("arguments"), dict): + args_text = json.dumps(current_tool["arguments"]) else: - args_text = str( - current_tool.get("arguments", "{}")) + args_text = str(current_tool.get("arguments", "{}")) except (json.JSONDecodeError, KeyError, IndexError): # Fallback to regex-based extraction pass # If arguments haven't been sent yet - sent_args = self.streaming_state["sent_tools"][ - current_idx]["sent_arguments"] + sent_args = self.streaming_state["sent_tools"][current_idx][ + "sent_arguments" + ] # If we haven't sent the opening bracket yet if not self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] and args_text.startswith( - "{"): + "sent_arguments_prefix" + ] and args_text.startswith("{"): 
self.streaming_state["sent_tools"][current_idx][ - "sent_arguments_prefix"] = True + "sent_arguments_prefix" + ] = True self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = "{" + "sent_arguments" + ] = "{" # Update streamed_args for backward compatibility while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += "{" - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments="{").model_dump( - exclude_none=True), # type: ignore - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{" + ).model_dump(exclude_none=True), # type: ignore + ) + ] + ) return delta # If we need to send more arguments if args_text.startswith(sent_args): # Calculate what part of arguments we need to send - args_diff = args_text[len(sent_args):] + args_diff = args_text[len(sent_args) :] if args_diff: # Update our state self.streaming_state["sent_tools"][current_idx][ - "sent_arguments"] = args_text + "sent_arguments" + ] = args_text # Update streamed_args for backward compatibility while len(self.streamed_args) <= current_idx: self.streamed_args.append("") self.streamed_args[current_idx] += args_diff - delta = DeltaMessage(tool_calls=[ - DeltaToolCall( - index=current_idx, - function=DeltaFunctionCall( - arguments=args_diff).model_dump( - exclude_none=True), # type: ignore - ) - ]) + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=args_diff + ).model_dump(exclude_none=True), # type: ignore + ) + ] + ) return delta # If the tool's arguments are complete, check if we need to move to the next tool @@ -513,7 +546,8 @@ def extract_tool_calls_streaming( if current_idx < tool_count - 1: self.streaming_state["current_tool_index"] += 1 self.current_tool_id = self.streaming_state[ - "current_tool_index"] # For compatibility + "current_tool_index" + ] # For compatibility # If we got here, we couldn't determine what to stream next return None diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index d7ce57c728ba..98c9cbbbd376 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -13,8 +13,9 @@ from vllm.config import ModelConfig from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt +from vllm.inputs.data import TextPrompt as EngineTextPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt -from vllm.inputs.parse import parse_and_batch_prompt +from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import AsyncMicrobatchTokenizer @@ -41,17 +42,38 @@ class RenderConfig: needs_detokenization: Optional[bool] = False """If True, detokenize IDs back to text for inclusion in outputs.""" + def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> Optional[int]: + """Validate and normalize `truncate_prompt_tokens` parameter.""" + truncate_prompt_tokens = self.truncate_prompt_tokens + if truncate_prompt_tokens is None: + return None + + if truncate_prompt_tokens == 0: + return 0 + + if truncate_prompt_tokens < 0: + truncate_prompt_tokens = model_config.max_model_len + + max_length = self.max_length + if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator] + raise ValueError( + f"{truncate_prompt_tokens=} cannot be greater than " + f"{max_length=}. 
Please select a smaller truncation size." + ) + + return truncate_prompt_tokens + class BaseRenderer(ABC): """ Base class for unified input processing and rendering. - + The Renderer serves as a unified input processor that consolidates tokenization, chat template formatting, and multimodal input handling into a single component. It converts high-level API requests (OpenAI-style JSON) into token IDs and multimodal features ready for engine consumption. - + Key responsibilities: - Convert text prompts to token sequences with proper special tokens - Apply chat templates and format conversations @@ -74,7 +96,7 @@ async def render_prompt( self, *, prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], - config: "RenderConfig", + config: RenderConfig, ) -> list[EngineTokensPrompt]: """ Convert text or token inputs into engine-ready TokensPrompt objects. @@ -90,7 +112,7 @@ async def render_prompt( - ``list[int]``: Single pre-tokenized sequence. - ``list[list[int]]``: Batch of pre-tokenized sequences. config: Render configuration controlling how prompts are prepared - (e.g., tokenization and length handling). + (e.g., tokenization and length handling). Returns: list[EngineTokensPrompt]: Engine-ready token prompts. @@ -104,10 +126,11 @@ async def render_prompt( async def render_prompt_and_embeds( self, *, - prompt_or_prompts: Optional[Union[str, list[str], list[int], - list[list[int]]]] = None, + prompt_or_prompts: Optional[ + Union[str, list[str], list[int], list[list[int]]] + ] = None, prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, - config: "RenderConfig", + config: RenderConfig, ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: """ Convert text/token and/or base64-encoded embeddings inputs into @@ -122,7 +145,7 @@ async def render_prompt_and_embeds( prompt_embeds: Base64-encoded bytes (or list thereof) containing a torch-saved tensor to be used as prompt embeddings. config: Render configuration controlling how prompts are prepared - (e.g., tokenization and length handling). + (e.g., tokenization and length handling). Returns: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: @@ -173,13 +196,13 @@ def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: class CompletionRenderer(BaseRenderer): - def __init__( self, model_config: ModelConfig, tokenizer: Optional[AnyTokenizer] = None, - async_tokenizer_pool: Optional[dict[AnyTokenizer, - AsyncMicrobatchTokenizer]] = None, + async_tokenizer_pool: Optional[ + dict[AnyTokenizer, AsyncMicrobatchTokenizer] + ] = None, ): super().__init__(model_config, tokenizer) self.async_tokenizer_pool = async_tokenizer_pool @@ -189,62 +212,42 @@ async def render_prompt( self, *, prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]], - config: "RenderConfig", + config: RenderConfig, ) -> list[EngineTokensPrompt]: """Implementation of prompt rendering for completion-style requests. - + Uses async tokenizer pooling for improved performance. See base class for detailed parameter documentation. 
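+        An empty list is returned when truncation resolves to 0 tokens.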
""" - truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens( - config.truncate_prompt_tokens, config.max_length) + truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config) if truncate_prompt_tokens == 0: return [] - # Parse and batch the input prompts - batch_inputs = parse_and_batch_prompt(prompt_or_prompts) - - tasks = [] - for prompt_input in batch_inputs: - if prompt_input["is_tokens"] is True: - # Token input - # Note: detokenization is needed when echo is enabled, - # where the input token IDs are decoded back to text. - task = self._maybe_detokenize(prompt_input["content"], - config.max_length, - truncate_prompt_tokens, - config.cache_salt, - config.needs_detokenization) - else: - # Text input - task = self._tokenize(prompt_input["content"], - config.max_length, - truncate_prompt_tokens, - config.add_special_tokens, - config.cache_salt) - tasks.append(task) - - # Wait for all text tokenization to finish - if tasks: - tokenized_text_prompts = await asyncio.gather(*tasks) - return tokenized_text_prompts - - return [] + tasks = ( + self._create_prompt( + prompt_input, + config=config, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + for prompt_input in parse_raw_prompts(prompt_or_prompts) + ) + + return await asyncio.gather(*tasks) async def render_prompt_and_embeds( self, *, - prompt_or_prompts: Optional[Union[str, list[str], list[int], - list[list[int]]]] = None, + prompt_or_prompts: Optional[ + Union[str, list[str], list[int], list[list[int]]] + ] = None, prompt_embeds: Optional[Union[bytes, list[bytes]]] = None, - config: "RenderConfig", + config: RenderConfig, ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: """ Render text/token prompts and/or precomputed embedding prompts. At least one of `prompt_or_prompts` or `prompt_embeds` must be provided. """ - truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens( - config.truncate_prompt_tokens, config.max_length) + truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config) if truncate_prompt_tokens == 0: return [] @@ -252,8 +255,10 @@ async def render_prompt_and_embeds( if prompt_embeds is not None: rendered.extend( - self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens, - config.cache_salt)) + self.load_prompt_embeds( + prompt_embeds, truncate_prompt_tokens, config.cache_salt + ) + ) if prompt_or_prompts is None or prompt_or_prompts == "": return rendered @@ -265,32 +270,9 @@ async def render_prompt_and_embeds( return rendered - def _validate_and_normalize_truncate_tokens( - self, - truncate_prompt_tokens: Optional[int], - max_length: Optional[int], - ) -> Optional[int]: - """Validate and normalize truncate_prompt_tokens parameter.""" - if truncate_prompt_tokens is None: - return None - - if truncate_prompt_tokens == 0: - return 0 - - if truncate_prompt_tokens < 0: - truncate_prompt_tokens = self.model_config.max_model_len - - if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator] - raise ValueError( - f"truncate_prompt_tokens ({truncate_prompt_tokens}) " - f"cannot be greater than max_length ({max_length}). 
" - f"Please select a smaller truncation size.") - - return truncate_prompt_tokens - def _maybe_apply_truncation( - self, token_ids: list[int], - truncate_prompt_tokens: Optional[int]) -> list[int]: + self, token_ids: list[int], truncate_prompt_tokens: Optional[int] + ) -> list[int]: """Apply truncation to token sequence.""" if truncate_prompt_tokens is None: return token_ids @@ -299,7 +281,38 @@ def _maybe_apply_truncation( return token_ids[-truncate_prompt_tokens:] - async def _tokenize( + async def _create_prompt( + self, + prompt_input: Union[EngineTextPrompt, EngineTokensPrompt], + config: RenderConfig, + truncate_prompt_tokens: Optional[int], + ) -> EngineTokensPrompt: + prompt, prompt_token_ids, _ = get_prompt_components(prompt_input) + + if prompt_token_ids is not None: + # NOTE: detokenization is needed when echo is enabled, + # where the input token IDs are decoded back to text. + return await self._create_prompt_from_token_ids( + prompt_token_ids, + config.max_length, + truncate_prompt_tokens, + config.cache_salt, + config.needs_detokenization, + ) + + if prompt is not None: + return await self._create_prompt_from_text( + prompt, + config.max_length, + truncate_prompt_tokens, + config.add_special_tokens, + config.cache_salt, + ) + + # TODO: Also handle embeds prompt using this method + raise NotImplementedError + + async def _create_prompt_from_text( self, text: str, max_length: Optional[int], @@ -311,26 +324,28 @@ async def _tokenize( async_tokenizer = self._get_async_tokenizer() # Handle encoder-specific preprocessing - if (self.model_config.encoder_config is not None - and self.model_config.encoder_config.get( - "do_lower_case", False)): + if ( + self.model_config.encoder_config is not None + and self.model_config.encoder_config.get("do_lower_case", False) + ): text = text.lower() # Tokenize texts if truncate_prompt_tokens is None: - encoded = await async_tokenizer( - text, add_special_tokens=add_special_tokens) + encoded = await async_tokenizer(text, add_special_tokens=add_special_tokens) else: encoded = await async_tokenizer( text, add_special_tokens=add_special_tokens, truncation=True, - max_length=truncate_prompt_tokens) + max_length=truncate_prompt_tokens, + ) - return self._create_tokens_prompt(encoded.input_ids, max_length, - cache_salt, text) + return self._create_tokens_prompt( + encoded.input_ids, max_length, cache_salt, text + ) - async def _maybe_detokenize( + async def _create_prompt_from_token_ids( self, token_ids: list[int], max_length: Optional[int], @@ -339,18 +354,19 @@ async def _maybe_detokenize( needs_detokenization: Optional[bool] = False, ) -> EngineTokensPrompt: """Optionally detokenize token IDs and build a tokens prompt.""" - token_ids = self._maybe_apply_truncation(token_ids, - truncate_prompt_tokens) + token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens) prompt = None - if needs_detokenization is True: + if needs_detokenization: async_tokenizer = self._get_async_tokenizer() prompt = await async_tokenizer.decode(token_ids) - return self._create_tokens_prompt(token_ids=token_ids, - max_length=max_length, - cache_salt=cache_salt, - prompt=prompt) + return self._create_tokens_prompt( + token_ids=token_ids, + max_length=max_length, + cache_salt=cache_salt, + prompt=prompt, + ) def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: """Get or create async tokenizer using shared pool.""" @@ -360,8 +376,7 @@ def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: tokenizer = self.tokenizer if self.tokenizer is None: - raise 
ValueError( - "No tokenizer available for text input processing") + raise ValueError("No tokenizer available for text input processing") if self.async_tokenizer_pool is None: async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) @@ -385,7 +400,8 @@ def _create_tokens_prompt( raise ValueError( f"This model's maximum context length is {max_length} tokens. " f"However, your request has {len(token_ids)} input tokens. " - "Please reduce the length of the input messages.") + "Please reduce the length of the input messages." + ) tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids) if cache_salt is not None: diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 642d6389539b..1fb56d246deb 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -7,31 +7,39 @@ from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ( - BaseMultiModalItemTracker, ChatCompletionContentPartImageEmbedsParam, - ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam, - MultiModalItemTracker, _ContentPart, _parse_chat_message_content_part) + BaseMultiModalItemTracker, + ChatCompletionContentPartImageEmbedsParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam, + MultiModalItemTracker, + _ContentPart, + _parse_chat_message_content_part, +) from vllm.inputs import TokensPrompt from vllm.model_executor.models.interfaces import supports_score_template from vllm.multimodal.inputs import MultiModalDataDict from vllm.outputs import PoolingRequestOutput -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - PreTrainedTokenizer, - PreTrainedTokenizerFast) +from vllm.transformers_utils.tokenizer import ( + AnyTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) ScoreContentPartParam: TypeAlias = Union[ - ChatCompletionContentPartImageParam, - ChatCompletionContentPartImageEmbedsParam] + ChatCompletionContentPartImageParam, ChatCompletionContentPartImageEmbedsParam +] class ScoreMultiModalParam(TypedDict, total=False): """ A specialized parameter type for scoring multimodal content - + The reasons why don't reuse `CustomChatCompletionMessageParam` directly: 1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions 2. Including chat-specific fields would confuse users about their purpose in scoring 3. 
This is a more focused interface that only exposes what's needed for scoring - """ # noqa: E501 + """ # noqa: E501 + content: Required[list[ScoreContentPartParam]] """The multimodal contents""" @@ -41,7 +49,6 @@ def _cosine_similarity( embed_1: list[PoolingRequestOutput], embed_2: list[PoolingRequestOutput], ) -> list[PoolingRequestOutput]: - scorer = CosineSimilarity(0) scores: Union[list[PoolingRequestOutput]] = [] @@ -49,8 +56,7 @@ def _cosine_similarity( pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data) padding = [] - if (pad_token_id := getattr(tokenizer, "pad_token_id", - None)) is not None: + if (pad_token_id := getattr(tokenizer, "pad_token_id", None)) is not None: padding = [pad_token_id] tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids @@ -60,7 +66,9 @@ def _cosine_similarity( request_id=f"{emb_1.request_id}_{emb_2.request_id}", outputs=pair_score, prompt_token_ids=tokens, - finished=True)) + finished=True, + ) + ) return scores @@ -96,8 +104,7 @@ def ensure_str(content: Optional[_ContentPart]) -> str: if content is not None and isinstance(content, str): return cast(str, content) else: - raise ValueError( - f"Only string content is supported, but got {content}.") + raise ValueError(f"Only string content is supported, but got {content}.") prompt_1 = ensure_str(content_1) prompt_2 = ensure_str(content_2) @@ -109,7 +116,6 @@ def _parse_score_content( data: Union[str, ScoreContentPartParam], mm_tracker: BaseMultiModalItemTracker, ) -> Optional[_ContentPart]: - if isinstance(data, str): data = ChatCompletionContentPartTextParam(type="text", text=data) @@ -127,8 +133,10 @@ def _parse_score_content( mm_placeholder_storage = mm_parser.mm_placeholder_storage() - if len(mm_placeholder_storage) != 1 or len( - next(iter(mm_placeholder_storage.values()))) != 1: + if ( + len(mm_placeholder_storage) != 1 + or len(next(iter(mm_placeholder_storage.values()))) != 1 + ): raise ValueError("Only one multi-modal item is supported") return next(iter(mm_placeholder_storage.values()))[0] @@ -149,8 +157,7 @@ def apply_score_template( raise ValueError("Get empty score template from model") return full_prompt - raise ValueError( - f"Unsupported model architecture: {model_config.architecture}") + raise ValueError(f"Unsupported model architecture: {model_config.architecture}") def post_process_tokens( @@ -159,7 +166,7 @@ def post_process_tokens( ) -> None: """ Perform architecture-specific manipulations on the input tokens. - + Note: This is an in-place operation. """ @@ -192,9 +199,9 @@ def get_score_prompt( prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) elif model_config.use_pad_token: # cross_encoder models defaults to using pad_token. - prompt_inputs = tokenizer(text=prompt_1, - text_pair=prompt_2, - **tokenization_kwargs) + prompt_inputs = tokenizer( + text=prompt_1, text_pair=prompt_2, **tokenization_kwargs + ) full_prompt = tokenizer.decode(prompt_inputs["input_ids"]) else: # `llm as reranker` models defaults to not using pad_token. @@ -219,8 +226,10 @@ def compress_token_type_ids(token_type_ids: list[int]) -> int: if not found. 
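+    e.g. compress_token_type_ids([0, 0, 1, 1]) -> 2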
""" first_one = len(token_type_ids) - err_msg = "Token type ids are expected to be a sequence"\ - " of zeros followed by a sequence of ones" + err_msg = ( + "Token type ids are expected to be a sequence" + " of zeros followed by a sequence of ones" + ) for i, type_id in enumerate(token_type_ids): if type_id == 0 and first_one < i: raise ValueError(err_msg) diff --git a/vllm/entrypoints/ssl.py b/vllm/entrypoints/ssl.py index e3646a60a7cc..ff0dd1bbfc6b 100644 --- a/vllm/entrypoints/ssl.py +++ b/vllm/entrypoints/ssl.py @@ -17,11 +17,13 @@ class SSLCertRefresher: reloads them when they change. """ - def __init__(self, - ssl_context: SSLContext, - key_path: Optional[str] = None, - cert_path: Optional[str] = None, - ca_path: Optional[str] = None) -> None: + def __init__( + self, + ssl_context: SSLContext, + key_path: Optional[str] = None, + cert_path: Optional[str] = None, + ca_path: Optional[str] = None, + ) -> None: self.ssl = ssl_context self.key_path = key_path self.cert_path = cert_path @@ -36,8 +38,10 @@ def update_ssl_cert_chain(change: Change, file_path: str) -> None: self.watch_ssl_cert_task = None if self.key_path and self.cert_path: self.watch_ssl_cert_task = asyncio.create_task( - self._watch_files([self.key_path, self.cert_path], - update_ssl_cert_chain)) + self._watch_files( + [self.key_path, self.cert_path], update_ssl_cert_chain + ) + ) # Setup CA files watcher def update_ssl_ca(change: Change, file_path: str) -> None: @@ -48,22 +52,21 @@ def update_ssl_ca(change: Change, file_path: str) -> None: self.watch_ssl_ca_task = None if self.ca_path: self.watch_ssl_ca_task = asyncio.create_task( - self._watch_files([self.ca_path], update_ssl_ca)) + self._watch_files([self.ca_path], update_ssl_ca) + ) - async def _watch_files(self, paths, fun: Callable[[Change, str], - None]) -> None: + async def _watch_files(self, paths, fun: Callable[[Change, str], None]) -> None: """Watch multiple file paths asynchronously.""" logger.info("SSLCertRefresher monitors files: %s", paths) async for changes in awatch(*paths): try: for change, file_path in changes: - logger.info("File change detected: %s - %s", change.name, - file_path) + logger.info("File change detected: %s - %s", change.name, file_path) fun(change, file_path) except Exception as e: logger.error( - "SSLCertRefresher failed taking action on file change. " - "Error: %s", e) + "SSLCertRefresher failed taking action on file change. Error: %s", e + ) def stop(self) -> None: """Stop watching files.""" diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py index f5f4d7d3b556..c74ce1ee16de 100644 --- a/vllm/entrypoints/tool.py +++ b/vllm/entrypoints/tool.py @@ -14,10 +14,12 @@ logger = init_logger(__name__) +MIN_GPT_OSS_VERSION = "0.0.7" + def validate_gpt_oss_install(): """ - Check if the gpt-oss is installed and its version is at least 0.0.3. + Check if the gpt-oss is installed and its version is at least 0.0.7. If not, raise an ImportError. 
""" from importlib.metadata import PackageNotFoundError, version @@ -25,29 +27,27 @@ def validate_gpt_oss_install(): from packaging.version import InvalidVersion, Version try: - pkg_version_str = version("gpt_oss") # e.g., "0.0.5" + pkg_version_str = version("gpt_oss") pkg_version = Version(pkg_version_str) except PackageNotFoundError: raise ImportError("Package 'gpt_oss' is not installed.") from None except InvalidVersion as e: - raise ImportError( - f"Invalid version string for 'gpt_oss': {e}") from None + raise ImportError(f"Invalid version string for 'gpt_oss': {e}") from None - if pkg_version < Version("0.0.3"): + if pkg_version < Version(MIN_GPT_OSS_VERSION): raise ImportError( - f"gpt_oss >= 0.0.3 is required, but {pkg_version} is installed." + f"gpt_oss >= {MIN_GPT_OSS_VERSION} is required, " + f"but {pkg_version} is installed." ) from None class Tool(ABC): - @abstractmethod async def get_result(self, context: "ConversationContext") -> Any: pass class HarmonyBrowserTool(Tool): - def __init__(self): self.enabled = True exa_api_key = os.getenv("EXA_API_KEY") @@ -63,8 +63,8 @@ def __init__(self): except ImportError as e: self.enabled = False logger.warning_once( - "gpt_oss is not installed properly (%s), browsing is disabled", - e) + "gpt_oss is not installed properly (%s), browsing is disabled", e + ) return browser_backend = ExaBackend(source="web", api_key=exa_api_key) @@ -73,6 +73,7 @@ def __init__(self): async def get_result(self, context: "ConversationContext") -> Any: from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) last_msg = context.messages[-1] tool_output_msgs = [] @@ -86,7 +87,6 @@ def tool_config(self) -> Any: class HarmonyPythonTool(Tool): - def __init__(self): self.enabled = True @@ -96,8 +96,9 @@ def __init__(self): except ImportError as e: self.enabled = False logger.warning_once( - "gpt_oss is not installed properly (%s), code interpreter is " - "disabled", e) + "gpt_oss is not installed properly (%s), code interpreter is disabled", + e, + ) return self.python_tool = PythonTool() @@ -121,12 +122,15 @@ async def validate(self): self.enabled = False logger.warning_once( "Code interpreter tool failed to initialize (%s), code " - "interpreter is disabled", e) + "interpreter is disabled", + e, + ) return logger.info_once("Code interpreter tool initialized") async def get_result(self, context: "ConversationContext") -> Any: from vllm.entrypoints.context import HarmonyContext + assert isinstance(context, HarmonyContext) last_msg = context.messages[-1] tool_output_msgs = [] diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py index 4c627b865ef9..b3dceecc1583 100644 --- a/vllm/entrypoints/tool_server.py +++ b/vllm/entrypoints/tool_server.py @@ -18,8 +18,11 @@ async def list_server_and_tools(server_url: str): from mcp import ClientSession from mcp.client.sse import sse_client - async with sse_client(url=server_url) as streams, ClientSession( - *streams) as session: + + async with ( + sse_client(url=server_url) as streams, + ClientSession(*streams) as session, + ): initialize_response = await session.initialize() list_tools_response = await session.list_tools() return initialize_response, list_tools_response @@ -37,21 +40,22 @@ def trim_schema(schema: dict) -> dict: # if there's more than 1 types, also remove "null" type as Harmony will # just ignore it types = [ - type_dict["type"] for type_dict in schema["anyOf"] - if type_dict["type"] != 'null' + type_dict["type"] + for type_dict in schema["anyOf"] + 
if type_dict["type"] != "null" ] schema["type"] = types del schema["anyOf"] if "properties" in schema: schema["properties"] = { - k: trim_schema(v) - for k, v in schema["properties"].items() + k: trim_schema(v) for k, v in schema["properties"].items() } return schema def post_process_tools_description( - list_tools_result: "ListToolsResult") -> "ListToolsResult": + list_tools_result: "ListToolsResult", +) -> "ListToolsResult": # Adapt the MCP tool result for Harmony for tool in list_tools_result.tools: tool.inputSchema = trim_schema(tool.inputSchema) @@ -59,7 +63,8 @@ def post_process_tools_description( # Some tools schema don't need to be part of the prompt (e.g. simple text # in text out for Python) list_tools_result.tools = [ - tool for tool in list_tools_result.tools + tool + for tool in list_tools_result.tools if getattr(tool.annotations, "include_in_prompt", True) ] @@ -67,7 +72,6 @@ def post_process_tools_description( class ToolServer(ABC): - @abstractmethod def has_tool(self, tool_name: str) -> bool: """ @@ -76,8 +80,7 @@ def has_tool(self, tool_name: str) -> bool: pass @abstractmethod - def get_tool_description(self, - tool_name: str) -> Optional[ToolNamespaceConfig]: + def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]: """ Return the tool description for the given tool name. If the tool is not supported, return None. @@ -86,10 +89,7 @@ def get_tool_description(self, @abstractmethod def new_session( - self, - tool_name: str, - session_id: str, - headers: Optional[dict[str, str]] = None + self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None ) -> AbstractAsyncContextManager[Any]: """ Create a session for the tool. @@ -98,14 +98,14 @@ def new_session( class MCPToolServer(ToolServer): - def __init__(self): try: import mcp # noqa: F401 except ImportError: raise ImportError( "mcp is not installed. Please run `pip install mcp` to use " - "MCPToolServer.") from None + "MCPToolServer." + ) from None self.harmony_tool_descriptions = {} async def add_tool_server(self, server_url: str): @@ -114,19 +114,19 @@ async def add_tool_server(self, server_url: str): self.urls: dict[str, str] = {} for url in tool_urls: url = f"http://{url}/sse" - initialize_response, list_tools_response = ( - await list_server_and_tools(url)) + initialize_response, list_tools_response = await list_server_and_tools(url) - list_tools_response = post_process_tools_description( - list_tools_response) + list_tools_response = post_process_tools_description(list_tools_response) tool_from_mcp = ToolNamespaceConfig( name=initialize_response.serverInfo.name, description=initialize_response.instructions, tools=[ - ToolDescription.new(name=tool.name, - description=tool.description, - parameters=tool.inputSchema) + ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.inputSchema, + ) for tool in list_tools_response.tools ], ) @@ -136,9 +136,13 @@ async def add_tool_server(self, server_url: str): else: logger.warning( "Tool %s already exists. 
Ignoring duplicate tool server %s", - tool_from_mcp.name, url) - logger.info("MCPToolServer initialized with tools: %s", - list(self.harmony_tool_descriptions.keys())) + tool_from_mcp.name, + url, + ) + logger.info( + "MCPToolServer initialized with tools: %s", + list(self.harmony_tool_descriptions.keys()), + ) def has_tool(self, tool_name: str): return tool_name in self.harmony_tool_descriptions @@ -147,27 +151,27 @@ def get_tool_description(self, tool_name: str): return self.harmony_tool_descriptions.get(tool_name) @asynccontextmanager - async def new_session(self, - tool_name: str, - session_id: str, - headers: Optional[dict[str, str]] = None): + async def new_session( + self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None + ): from mcp import ClientSession from mcp.client.sse import sse_client + url = self.urls.get(tool_name) request_headers = {"x-session-id": session_id} if headers is not None: request_headers.update(headers) if not url: raise KeyError(f"Tool '{tool_name}' is not supported") - async with sse_client( - url=url, headers=request_headers) as streams, ClientSession( - *streams) as session: + async with ( + sse_client(url=url, headers=request_headers) as streams, + ClientSession(*streams) as session, + ): await session.initialize() yield session class DemoToolServer(ToolServer): - def __init__(self): self.tools: dict[str, Tool] = {} @@ -179,14 +183,14 @@ async def init_and_validate(self): self.tools["browser"] = browser_tool if python_tool.enabled: self.tools["python"] = python_tool - logger.info("DemoToolServer initialized with tools: %s", - list(self.tools.keys())) + logger.info( + "DemoToolServer initialized with tools: %s", list(self.tools.keys()) + ) def has_tool(self, tool_name: str) -> bool: return tool_name in self.tools - def get_tool_description(self, - tool_name: str) -> Optional[ToolNamespaceConfig]: + def get_tool_description(self, tool_name: str) -> Optional[ToolNamespaceConfig]: if tool_name not in self.tools: return None if tool_name == "browser": @@ -197,10 +201,9 @@ def get_tool_description(self, raise ValueError(f"Unknown tool {tool_name}") @asynccontextmanager - async def new_session(self, - tool_name: str, - session_id: str, - headers: Optional[dict[str, str]] = None): + async def new_session( + self, tool_name: str, session_id: str, headers: Optional[dict[str, str]] = None + ): if tool_name not in self.tools: raise KeyError(f"Tool '{tool_name}' is not supported") yield self.tools[tool_name] diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 4a90fe094ae2..c97ca6538814 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -14,8 +14,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -26,7 +25,8 @@ "For full list: vllm {subcmd} --help=all\n" "For a section: vllm {subcmd} --help=ModelConfig (case-insensitive)\n" # noqa: E501 "For a flag: vllm {subcmd} --help=max-model-len (_ or - accepted)\n" # noqa: E501 - "Documentation: https://docs.vllm.ai\n") + "Documentation: https://docs.vllm.ai\n" +) async def listen_for_disconnect(request: Request) -> None: @@ -37,9 +37,9 @@ async def listen_for_disconnect(request: Request) -> None: # If 
load tracking is enabled *and* the counter exists, decrement # it. Combines the previous nested checks into a single condition # to satisfy the linter rule. - if (getattr(request.app.state, "enable_server_load_tracking", - False) - and hasattr(request.app.state, "server_load_metrics")): + if getattr( + request.app.state, "enable_server_load_tracking", False + ) and hasattr(request.app.state, "server_load_metrics"): request.app.state.server_load_metrics -= 1 break @@ -70,15 +70,15 @@ def with_cancellation(handler_func): # normal route handler, with the correct request type hinting. @functools.wraps(handler_func) async def wrapper(*args, **kwargs): - # The request is either the second positional arg or `raw_request` request = args[1] if len(args) > 1 else kwargs["raw_request"] handler_task = asyncio.create_task(handler_func(*args, **kwargs)) cancellation_task = asyncio.create_task(listen_for_disconnect(request)) - done, pending = await asyncio.wait([handler_task, cancellation_task], - return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + [handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED + ) for task in pending: task.cancel() @@ -94,18 +94,16 @@ def decrement_server_load(request: Request): def load_aware_call(func): - @functools.wraps(func) async def wrapper(*args, **kwargs): - raw_request = kwargs.get("raw_request", - args[1] if len(args) > 1 else None) + raw_request = kwargs.get("raw_request", args[1] if len(args) > 1 else None) if raw_request is None: raise ValueError( - "raw_request required when server load tracking is enabled") + "raw_request required when server load tracking is enabled" + ) - if not getattr(raw_request.app.state, "enable_server_load_tracking", - False): + if not getattr(raw_request.app.state, "enable_server_load_tracking", False): return await func(*args, **kwargs) # ensure the counter exists @@ -121,18 +119,18 @@ async def wrapper(*args, **kwargs): if isinstance(response, (JSONResponse, StreamingResponse)): if response.background is None: - response.background = BackgroundTask(decrement_server_load, - raw_request) + response.background = BackgroundTask(decrement_server_load, raw_request) elif isinstance(response.background, BackgroundTasks): - response.background.add_task(decrement_server_load, - raw_request) + response.background.add_task(decrement_server_load, raw_request) elif isinstance(response.background, BackgroundTask): # Convert the single BackgroundTask to BackgroundTasks # and chain the decrement_server_load task to it tasks = BackgroundTasks() - tasks.add_task(response.background.func, - *response.background.args, - **response.background.kwargs) + tasks.add_task( + response.background.func, + *response.background.args, + **response.background.kwargs, + ) tasks.add_task(decrement_server_load, raw_request) response.background = tasks else: @@ -169,7 +167,6 @@ def _validate_truncation_size( truncate_prompt_tokens: Optional[int], tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> Optional[int]: - if truncate_prompt_tokens is not None: if truncate_prompt_tokens <= -1: truncate_prompt_tokens = max_model_len @@ -178,7 +175,8 @@ def _validate_truncation_size( raise ValueError( f"truncate_prompt_tokens value ({truncate_prompt_tokens}) " f"is greater than max_model_len ({max_model_len})." - f" Please, select a smaller truncation size.") + f" Please, select a smaller truncation size." 
+ ) if tokenization_kwargs is not None: tokenization_kwargs["truncation"] = True @@ -191,19 +189,26 @@ def _validate_truncation_size( return truncate_prompt_tokens -def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest, - CompletionRequest], - input_length: int, default_sampling_params: dict) -> int: - - max_tokens = getattr(request, "max_completion_tokens", - None) or request.max_tokens +def get_max_tokens( + max_model_len: int, + request: Union[ChatCompletionRequest, CompletionRequest], + input_length: int, + default_sampling_params: dict, +) -> int: + max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens default_max_tokens = max_model_len - input_length max_output_tokens = current_platform.get_max_output_tokens(input_length) - return min(val - for val in (default_max_tokens, max_tokens, max_output_tokens, - default_sampling_params.get("max_tokens")) - if val is not None) + return min( + val + for val in ( + default_max_tokens, + max_tokens, + max_output_tokens, + default_sampling_params.get("max_tokens"), + ) + if val is not None + ) def log_non_default_args(args: Union[Namespace, EngineArgs]): @@ -227,7 +232,8 @@ def log_non_default_args(args: Union[Namespace, EngineArgs]): if default_args.model != EngineArgs.model: non_default_args["model"] = default_args.model else: - raise TypeError("Unsupported argument type. " \ - "Must be Namespace or EngineArgs instance.") + raise TypeError( + "Unsupported argument type. Must be Namespace or EngineArgs instance." + ) logger.info("non-default args: %s", non_default_args) diff --git a/vllm/env_override.py b/vllm/env_override.py index b06703a2fbf9..7f9054e73846 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -15,9 +15,9 @@ # see https://github.com/vllm-project/vllm/pull/15951 # it avoids unintentional cuda initialization from torch.cuda.is_available() -os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1' +os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" # see https://github.com/vllm-project/vllm/issues/10480 -os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' +os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" # see https://github.com/vllm-project/vllm/issues/10619 torch._inductor.config.compile_threads = 1 diff --git a/vllm/envs.py b/vllm/envs.py index 6dce4bd0f94e..a4f53925626b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -41,7 +41,7 @@ VLLM_LOGGING_STREAM: str = "ext://sys.stdout" VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None - VLLM_LOG_STATS_INTERVAL: float = 10. 
+ VLLM_LOG_STATS_INTERVAL: float = 10.0 VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None @@ -57,8 +57,7 @@ VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False - VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", - "shm"] = "auto" + VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True VLLM_XLA_USE_SPMD: bool = False @@ -81,8 +80,7 @@ VLLM_DOCKER_BUILD_CONTEXT: bool = False VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False - CMAKE_BUILD_TYPE: Optional[Literal["Debug", "Release", - "RelWithDebInfo"]] = None + CMAKE_BUILD_TYPE: Optional[Literal["Debug", "Release", "RelWithDebInfo"]] = None VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms @@ -151,19 +149,20 @@ VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False - VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", - "latency"] = "throughput" + VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput" VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 - VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", - "deepep_high_throughput", - "deepep_low_latency", - "allgather_reducescatter", - "flashinfer_all2allv"] = \ - "allgather_reducescatter" + VLLM_ALL2ALL_BACKEND: Literal[ + "naive", + "pplx", + "deepep_high_throughput", + "deepep_low_latency", + "allgather_reducescatter", + "flashinfer_all2allv", + ] = "allgather_reducescatter" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 VLLM_SLEEP_WHEN_IDLE: bool = False @@ -172,8 +171,9 @@ VLLM_KV_CACHE_LAYOUT: Optional[Literal["NHD", "HND"]] = None VLLM_COMPUTE_NANS_IN_LOGITS: bool = False VLLM_USE_NVFP4_CT_EMULATIONS: bool = False - VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal["FP", "INT8", "INT6", "INT4", - "NONE"] = "NONE" + VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal[ + "FP", "INT8", "INT6", "INT4", "NONE" + ] = "NONE" VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 @@ -237,19 +237,20 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: def env_with_choices( - env_name: str, - default: Optional[str], - choices: Union[list[str], Callable[[], list[str]]], - case_sensitive: bool = True) -> Callable[[], Optional[str]]: + env_name: str, + default: Optional[str], + choices: Union[list[str], Callable[[], list[str]]], + case_sensitive: bool = True, +) -> Callable[[], Optional[str]]: """ Create a lambda that validates environment variable against allowed choices - + Args: env_name: Name of the environment variable default: Default value if not set (can be None) choices: List of valid string options or callable that returns list case_sensitive: Whether validation should be case sensitive - + Returns: Lambda function for environment_variables dict """ @@ -270,8 +271,10 @@ def _get_validated_env() -> Optional[str]: check_choices = actual_choices if check_value not in check_choices: - raise ValueError(f"Invalid 
value '{value}' for {env_name}. " - f"Valid options: {actual_choices}.") + raise ValueError( + f"Invalid value '{value}' for {env_name}. " + f"Valid options: {actual_choices}." + ) return value @@ -279,20 +282,21 @@ def _get_validated_env() -> Optional[str]: def env_list_with_choices( - env_name: str, - default: list[str], - choices: Union[list[str], Callable[[], list[str]]], - case_sensitive: bool = True) -> Callable[[], list[str]]: + env_name: str, + default: list[str], + choices: Union[list[str], Callable[[], list[str]]], + case_sensitive: bool = True, +) -> Callable[[], list[str]]: """ - Create a lambda that validates environment variable + Create a lambda that validates environment variable containing comma-separated values against allowed choices - + Args: env_name: Name of the environment variable default: Default list of values if not set choices: List of valid string options or callable that returns list case_sensitive: Whether validation should be case sensitive - + Returns: Lambda function for environment_variables dict that returns list of strings @@ -322,8 +326,10 @@ def _get_validated_env_list() -> list[str]: check_choices = actual_choices if check_value not in check_choices: - raise ValueError(f"Invalid value '{val}' in {env_name}. " - f"Valid options: {actual_choices}.") + raise ValueError( + f"Invalid value '{val}' in {env_name}. " + f"Valid options: {actual_choices}." + ) return values @@ -339,15 +345,16 @@ def get_vllm_port() -> Optional[int]: Raises: ValueError: If VLLM_PORT is a URI, suggest k8s service discovery issue. """ - if 'VLLM_PORT' not in os.environ: + if "VLLM_PORT" not in os.environ: return None - port = os.getenv('VLLM_PORT', '0') + port = os.getenv("VLLM_PORT", "0") try: return int(port) except ValueError as err: from urllib.parse import urlparse + parsed = urlparse(port) if parsed.scheme: raise ValueError( @@ -355,8 +362,7 @@ def get_vllm_port() -> Optional[int]: "This may be caused by a Kubernetes service discovery issue," "check the warning in: https://docs.vllm.ai/en/stable/serving/env_vars.html" ) from None - raise ValueError( - f"VLLM_PORT '{port}' must be a valid integer") from err + raise ValueError(f"VLLM_PORT '{port}' must be a valid integer") from err # The begin-* and end* here are used by the documentation generator @@ -365,247 +371,200 @@ def get_vllm_port() -> Optional[int]: # --8<-- [start:env-vars-definition] environment_variables: dict[str, Callable[[], Any]] = { - # ================== Installation Time Env Vars ================== - # Target device of vLLM, supporting [cuda (by default), # rocm, cpu] - "VLLM_TARGET_DEVICE": - lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), - + "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), # Main CUDA version of vLLM, supporting [12.6, 12.8, 12.9], # 12.8 is the default. This follows PyTorch but can be overridden. - "VLLM_MAIN_CUDA_VERSION": - lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() or "12.8", - + "VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() + or "12.8", # Maximum number of compilation jobs to run in parallel. # By default this is the number of CPUs - "MAX_JOBS": - lambda: os.getenv("MAX_JOBS", None), - + "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), # Number of threads to use for nvcc # By default this is 1. # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. 
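# A rough, standalone sketch of the lazy accessor pattern behind
# env_with_choices() / env_list_with_choices() above: return a zero-argument
# callable so the variable is read and validated only when accessed, not at
# import time. Names here (choice_env, read_choice) are illustrative only,
# not vLLM APIs.
import os
from typing import Callable, Optional


def choice_env(
    name: str, default: Optional[str], choices: list[str]
) -> Callable[[], Optional[str]]:
    def read_choice() -> Optional[str]:
        value = os.getenv(name, default)
        if value is not None and value not in choices:
            raise ValueError(
                f"Invalid value '{value}' for {name}. Valid options: {choices}."
            )
        return value

    return read_choice


# Usage: nothing is read or validated until the accessor is called.
get_build_type = choice_env(
    "CMAKE_BUILD_TYPE", None, ["Debug", "Release", "RelWithDebInfo"]
)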
- "NVCC_THREADS": - lambda: os.getenv("NVCC_THREADS", None), - + "NVCC_THREADS": lambda: os.getenv("NVCC_THREADS", None), # If set, vllm will use precompiled binaries (*.so) - "VLLM_USE_PRECOMPILED": - lambda: os.environ.get("VLLM_USE_PRECOMPILED", "").strip().lower() in - ("1", "true") or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), - + "VLLM_USE_PRECOMPILED": lambda: os.environ.get("VLLM_USE_PRECOMPILED", "") + .strip() + .lower() + in ("1", "true") + or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), # Used to mark that setup.py is running in a Docker build context, # in order to force the use of precompiled binaries. - "VLLM_DOCKER_BUILD_CONTEXT": - lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "").strip().lower() in - ("1", "true"), - + "VLLM_DOCKER_BUILD_CONTEXT": lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "") + .strip() + .lower() + in ("1", "true"), # Whether to force using nightly wheel in python build. # This is used for testing the nightly wheel in python build. - "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": - lambda: bool(int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) - ), - + "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": lambda: bool( + int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) + ), # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" - "CMAKE_BUILD_TYPE": - env_with_choices("CMAKE_BUILD_TYPE", None, - ["Debug", "Release", "RelWithDebInfo"]), - + "CMAKE_BUILD_TYPE": env_with_choices( + "CMAKE_BUILD_TYPE", None, ["Debug", "Release", "RelWithDebInfo"] + ), # If set, vllm will print verbose logs during installation - "VERBOSE": - lambda: bool(int(os.getenv('VERBOSE', '0'))), - + "VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))), # Root directory for vLLM configuration files # Defaults to `~/.config/vllm` unless `XDG_CONFIG_HOME` is set # Note that this not only affects how vllm finds its configuration files # during runtime, but also affects how vllm installs its configuration # files during **installation**. - "VLLM_CONFIG_ROOT": - lambda: os.path.expanduser( + "VLLM_CONFIG_ROOT": lambda: os.path.expanduser( os.getenv( "VLLM_CONFIG_ROOT", os.path.join(get_default_config_root(), "vllm"), - )), - + ) + ), # ================== Runtime Env Vars ================== - # Root directory for vLLM cache files # Defaults to `~/.cache/vllm` unless `XDG_CACHE_HOME` is set - "VLLM_CACHE_ROOT": - lambda: os.path.expanduser( + "VLLM_CACHE_ROOT": lambda: os.path.expanduser( os.getenv( "VLLM_CACHE_ROOT", os.path.join(get_default_cache_root(), "vllm"), - )), - + ) + ), # used in distributed environment to determine the ip address # of the current node, when the node has multiple network interfaces. # If you are using multi-node inference, you should set this differently # on each node. - 'VLLM_HOST_IP': - lambda: os.getenv('VLLM_HOST_IP', ""), - + "VLLM_HOST_IP": lambda: os.getenv("VLLM_HOST_IP", ""), # used in distributed environment to manually set the communication port # Note: if VLLM_PORT is set, and some code asks for multiple ports, the # VLLM_PORT will be used as the first port, and the rest will be generated # by incrementing the VLLM_PORT value. - 'VLLM_PORT': - get_vllm_port, - + "VLLM_PORT": get_vllm_port, # path used for ipc when the frontend api server is running in # multi-processing mode to communicate with the backend engine process. 
- 'VLLM_RPC_BASE_PATH': - lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()), - + "VLLM_RPC_BASE_PATH": lambda: os.getenv( + "VLLM_RPC_BASE_PATH", tempfile.gettempdir() + ), # If true, will load models from ModelScope instead of Hugging Face Hub. # note that the value is true or false, not numbers - "VLLM_USE_MODELSCOPE": - lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", - + "VLLM_USE_MODELSCOPE": lambda: os.environ.get( + "VLLM_USE_MODELSCOPE", "False" + ).lower() + == "true", # Interval in seconds to log a warning message when the ring buffer is full - "VLLM_RINGBUFFER_WARNING_INTERVAL": - lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")), - + "VLLM_RINGBUFFER_WARNING_INTERVAL": lambda: int( + os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60") + ), # path to cudatoolkit home directory, under which should be bin, include, # and lib directories. - "CUDA_HOME": - lambda: os.environ.get("CUDA_HOME", None), - + "CUDA_HOME": lambda: os.environ.get("CUDA_HOME", None), # Path to the NCCL library file. It is needed because nccl>=2.19 brought # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 - "VLLM_NCCL_SO_PATH": - lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), - + "VLLM_NCCL_SO_PATH": lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl # library file in the locations specified by `LD_LIBRARY_PATH` - "LD_LIBRARY_PATH": - lambda: os.environ.get("LD_LIBRARY_PATH", None), - + "LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None), # flag to control if vllm should use triton flash attention - "VLLM_USE_TRITON_FLASH_ATTN": - lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in - ("true", "1")), - + "VLLM_USE_TRITON_FLASH_ATTN": lambda: ( + os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1") + ), # Use separate prefill and decode kernels for V1 attention instead of # the unified triton kernel. - "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": - lambda: - (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in - ("true", "1")), - + "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: ( + os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() + in ("true", "1") + ), # Use AITER triton unified attention for V1 attention - "VLLM_USE_AITER_UNIFIED_ATTENTION": - lambda: - (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in - ("true", "1")), - + "VLLM_USE_AITER_UNIFIED_ATTENTION": lambda: ( + os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1") + ), # Force vllm to use a specific flash-attention version (2 or 3), only valid # when using the flash-attention backend. - "VLLM_FLASH_ATTN_VERSION": - lambda: maybe_convert_int(os.environ.get("VLLM_FLASH_ATTN_VERSION", None)), - + "VLLM_FLASH_ATTN_VERSION": lambda: maybe_convert_int( + os.environ.get("VLLM_FLASH_ATTN_VERSION", None) + ), # Feature flag to enable/disable Inductor standalone compile. # In torch <= 2.7 we ignore this flag; in torch >= 2.8 this is # disabled by default. - "VLLM_USE_STANDALONE_COMPILE": - lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1", - + "VLLM_USE_STANDALONE_COMPILE": lambda: os.environ.get( + "VLLM_USE_STANDALONE_COMPILE", "0" + ) + == "1", # Debug pattern matching inside custom passes. # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). 
- "VLLM_PATTERN_MATCH_DEBUG": - lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None), - + "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get( + "VLLM_PATTERN_MATCH_DEBUG", None + ), # Dump fx graphs to the given directory. # It will override CompilationConfig.debug_dump_path if set. - "VLLM_DEBUG_DUMP_PATH": - lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None), - + "VLLM_DEBUG_DUMP_PATH": lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None), # local rank of the process in the distributed setting, used to determine # the GPU device id - "LOCAL_RANK": - lambda: int(os.environ.get("LOCAL_RANK", "0")), - + "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), # used to control the visible devices in the distributed setting - "CUDA_VISIBLE_DEVICES": - lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), - + "CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), # timeout for each iteration in the engine - "VLLM_ENGINE_ITERATION_TIMEOUT_S": - lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), - + "VLLM_ENGINE_ITERATION_TIMEOUT_S": lambda: int( + os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60") + ), # API key for vLLM API server - "VLLM_API_KEY": - lambda: os.environ.get("VLLM_API_KEY", None), - + "VLLM_API_KEY": lambda: os.environ.get("VLLM_API_KEY", None), # Whether to log responses from API Server for debugging - "VLLM_DEBUG_LOG_API_SERVER_RESPONSE": - lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" - ).lower() == "true", - + "VLLM_DEBUG_LOG_API_SERVER_RESPONSE": lambda: os.environ.get( + "VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" + ).lower() + == "true", # S3 access information, used for tensorizer to load model from S3 - "S3_ACCESS_KEY_ID": - lambda: os.environ.get("S3_ACCESS_KEY_ID", None), - "S3_SECRET_ACCESS_KEY": - lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), - "S3_ENDPOINT_URL": - lambda: os.environ.get("S3_ENDPOINT_URL", None), - + "S3_ACCESS_KEY_ID": lambda: os.environ.get("S3_ACCESS_KEY_ID", None), + "S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), + "S3_ENDPOINT_URL": lambda: os.environ.get("S3_ENDPOINT_URL", None), # Usage stats collection - "VLLM_USAGE_STATS_SERVER": - lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"), - "VLLM_NO_USAGE_STATS": - lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", - "VLLM_DISABLE_FLASHINFER_PREFILL": - lambda: os.environ.get("VLLM_DISABLE_FLASHINFER_PREFILL", "0") == "1", - "VLLM_DO_NOT_TRACK": - lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get( - "DO_NOT_TRACK", None) or "0") == "1", - "VLLM_USAGE_SOURCE": - lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), - + "VLLM_USAGE_STATS_SERVER": lambda: os.environ.get( + "VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai" + ), + "VLLM_NO_USAGE_STATS": lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", + "VLLM_DISABLE_FLASHINFER_PREFILL": lambda: os.environ.get( + "VLLM_DISABLE_FLASHINFER_PREFILL", "0" + ) + == "1", + "VLLM_DO_NOT_TRACK": lambda: ( + os.environ.get("VLLM_DO_NOT_TRACK", None) + or os.environ.get("DO_NOT_TRACK", None) + or "0" + ) + == "1", + "VLLM_USAGE_SOURCE": lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), # Logging configuration # If set to 0, vllm will not configure logging # If set to 1, vllm will configure logging using the default configuration # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH - "VLLM_CONFIGURE_LOGGING": - lambda: 
int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), - "VLLM_LOGGING_CONFIG_PATH": - lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), - + "VLLM_CONFIGURE_LOGGING": lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_LOGGING_CONFIG_PATH": lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), # this is used for configuring the default logging level - "VLLM_LOGGING_LEVEL": - lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), - + "VLLM_LOGGING_LEVEL": lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), # this is used for configuring the default logging stream - "VLLM_LOGGING_STREAM": - lambda: os.getenv("VLLM_LOGGING_STREAM", "ext://sys.stdout"), - + "VLLM_LOGGING_STREAM": lambda: os.getenv("VLLM_LOGGING_STREAM", "ext://sys.stdout"), # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages - "VLLM_LOGGING_PREFIX": - lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), - + "VLLM_LOGGING_PREFIX": lambda: os.getenv("VLLM_LOGGING_PREFIX", ""), # if set, vllm will call logits processors in a thread pool with this many # threads. This is useful when using custom logits processors that either # (a) launch additional CUDA kernels or (b) do significant CPU-bound work # while not holding the python GIL, or both. - "VLLM_LOGITS_PROCESSOR_THREADS": - lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0")) - if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None, - + "VLLM_LOGITS_PROCESSOR_THREADS": lambda: int( + os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0") + ) + if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ + else None, # If set, vllm will log stats at this interval in seconds # If not set, vllm will log stats every 10 seconds. - "VLLM_LOG_STATS_INTERVAL": - lambda: val if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) - > 0. else 10., - + "VLLM_LOG_STATS_INTERVAL": lambda: val + if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) > 0.0 + else 10.0, # Trace function calls # If set to 1, vllm will trace function calls # Useful for debugging - "VLLM_TRACE_FUNCTION": - lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), - + "VLLM_TRACE_FUNCTION": lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), # Backend for attention computation # Example options: # - "TORCH_SDPA": use torch.nn.MultiheadAttention @@ -617,64 +576,60 @@ def get_vllm_port() -> Optional[int]: # - "FLASHINFER_MLA": use FlashInfer for MLA # - "CUTLASS_MLA": use CUTLASS for MLA # All possible options loaded dynamically from _Backend enum - "VLLM_ATTENTION_BACKEND": - env_with_choices("VLLM_ATTENTION_BACKEND", None, - lambda: list(__import__( - 'vllm.attention.backends.registry', - fromlist=['_Backend'])._Backend.__members__.keys())), - + "VLLM_ATTENTION_BACKEND": env_with_choices( + "VLLM_ATTENTION_BACKEND", + None, + lambda: list( + __import__( + "vllm.attention.backends.registry", fromlist=["_Backend"] + )._Backend.__members__.keys() + ), + ), # If set, vllm will use flashinfer sampler - "VLLM_USE_FLASHINFER_SAMPLER": - lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])) - if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None, - + "VLLM_USE_FLASHINFER_SAMPLER": lambda: bool( + int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]) + ) + if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ + else None, # Pipeline stage partition strategy - "VLLM_PP_LAYER_PARTITION": - lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), - + "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), # (CPU backend only) CPU key-value cache space. 
# default is None and will be set as 4 GB - "VLLM_CPU_KVCACHE_SPACE": - lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")) - if "VLLM_CPU_KVCACHE_SPACE" in os.environ else None, - + "VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")) + if "VLLM_CPU_KVCACHE_SPACE" in os.environ + else None, # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31", # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'. - "VLLM_CPU_OMP_THREADS_BIND": - lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"), - + "VLLM_CPU_OMP_THREADS_BIND": lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"), # (CPU backend only) CPU cores not used by OMP threads . # Those CPU cores will not be used by OMP threads of a rank. - "VLLM_CPU_NUM_OF_RESERVED_CPU": - lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")) - if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None, - + "VLLM_CPU_NUM_OF_RESERVED_CPU": lambda: int( + os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0") + ) + if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ + else None, # (CPU backend only) whether to use prepack for MoE layer. This will be # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might # need to set this to "0" (False). - "VLLM_CPU_MOE_PREPACK": - lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), - + "VLLM_CPU_MOE_PREPACK": lambda: bool(int(os.getenv("VLLM_CPU_MOE_PREPACK", "1"))), # (CPU backend only) whether to use SGL kernels, optimized for small batch. - "VLLM_CPU_SGL_KERNEL": - lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), - + "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), # If the env var is set, then all workers will execute as separate # processes from the engine, and we use the same mechanism to trigger # execution on all workers. # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. - "VLLM_USE_RAY_SPMD_WORKER": - lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))), - + "VLLM_USE_RAY_SPMD_WORKER": lambda: bool( + int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0")) + ), # If the env var is set, it uses the Ray's Compiled Graph # (previously known as ADAG) API which optimizes the # control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. # Note that this variable is set to 1 in V1 by default # when ray distributed executor is used. - "VLLM_USE_RAY_COMPILED_DAG": - lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), - + "VLLM_USE_RAY_COMPILED_DAG": lambda: bool( + int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0")) + ), # If the env var is set, Ray Compiled Graph uses the specified # channel type to communicate between workers belonging to # different pipeline-parallel stages. @@ -683,75 +638,69 @@ def get_vllm_port() -> Optional[int]: # - "nccl": use NCCL for communication # - "shm": use shared memory and gRPC for communication # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. - "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": - env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto", - ["auto", "nccl", "shm"]), - + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": env_with_choices( + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "auto", ["auto", "nccl", "shm"] + ), # If the env var is set, it enables GPU communication overlap # (experimental feature) in Ray's Compiled Graph. This flag is ignored if # VLLM_USE_RAY_COMPILED_DAG is not set. 
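# Many of the boolean flags in this hunk are parsed as
# bool(int(os.getenv(name, "0"))). A small standalone illustration of what
# that idiom accepts; int_flag and DEMO_FLAG are made-up names, not vLLM
# environment variables.
import os


def int_flag(name: str, default: str = "0") -> bool:
    # "0" -> False; "1" (or any other non-zero integer string) -> True;
    # non-numeric strings such as "true" raise ValueError from int().
    return bool(int(os.getenv(name, default)))


os.environ["DEMO_FLAG"] = "1"
assert int_flag("DEMO_FLAG") is True
os.environ["DEMO_FLAG"] = "true"
try:
    int_flag("DEMO_FLAG")
except ValueError:
    pass  # rejected: this idiom expects integer strings only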
- "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": - lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0")) - ), - + "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": lambda: bool( + int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0")) + ), # If the env var is set, it uses a Ray Communicator wrapping # vLLM's pipeline parallelism communicator to interact with Ray's # Compiled Graph. Otherwise, it uses Ray's NCCL communicator. # This flag is ignored if VLLM_USE_RAY_COMPILED_DAG is not set. - "VLLM_USE_RAY_WRAPPED_PP_COMM": - lambda: bool(int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1"))), - + "VLLM_USE_RAY_WRAPPED_PP_COMM": lambda: bool( + int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1")) + ), # Use dedicated multiprocess context for workers. # Both spawn and fork work - "VLLM_WORKER_MULTIPROC_METHOD": - env_with_choices("VLLM_WORKER_MULTIPROC_METHOD", "fork", - ["spawn", "fork"]), - + "VLLM_WORKER_MULTIPROC_METHOD": env_with_choices( + "VLLM_WORKER_MULTIPROC_METHOD", "fork", ["spawn", "fork"] + ), # Path to the cache for storing downloaded assets - "VLLM_ASSETS_CACHE": - lambda: os.path.expanduser( + "VLLM_ASSETS_CACHE": lambda: os.path.expanduser( os.getenv( "VLLM_ASSETS_CACHE", os.path.join(get_default_cache_root(), "vllm", "assets"), - )), - + ) + ), # If the env var is set, we will clean model file in # this path $VLLM_ASSETS_CACHE/model_streamer/$model_name - "VLLM_ASSETS_CACHE_MODEL_CLEAN": - lambda: bool(int(os.getenv("VLLM_ASSETS_CACHE_MODEL_CLEAN", "0"))), - + "VLLM_ASSETS_CACHE_MODEL_CLEAN": lambda: bool( + int(os.getenv("VLLM_ASSETS_CACHE_MODEL_CLEAN", "0")) + ), # Timeout for fetching images when serving multimodal models # Default is 5 seconds - "VLLM_IMAGE_FETCH_TIMEOUT": - lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), - + "VLLM_IMAGE_FETCH_TIMEOUT": lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), # Timeout for fetching videos when serving multimodal models # Default is 30 seconds - "VLLM_VIDEO_FETCH_TIMEOUT": - lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "30")), - + "VLLM_VIDEO_FETCH_TIMEOUT": lambda: int( + os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "30") + ), # Timeout for fetching audio when serving multimodal models # Default is 10 seconds - "VLLM_AUDIO_FETCH_TIMEOUT": - lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), - + "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int( + os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10") + ), # Whether to allow HTTP redirects when fetching from media URLs. # Default to True - "VLLM_MEDIA_URL_ALLOW_REDIRECTS": - lambda: bool(int(os.getenv("VLLM_MEDIA_URL_ALLOW_REDIRECTS", "1"))), - + "VLLM_MEDIA_URL_ALLOW_REDIRECTS": lambda: bool( + int(os.getenv("VLLM_MEDIA_URL_ALLOW_REDIRECTS", "1")) + ), # Max number of workers for the thread pool handling # media bytes loading. Set to 1 to disable parallel processing. # Default is 8 - "VLLM_MEDIA_LOADING_THREAD_COUNT": - lambda: int(os.getenv("VLLM_MEDIA_LOADING_THREAD_COUNT", "8")), - + "VLLM_MEDIA_LOADING_THREAD_COUNT": lambda: int( + os.getenv("VLLM_MEDIA_LOADING_THREAD_COUNT", "8") + ), # Maximum filesize in MB for a single audio file when processing # speech-to-text requests. Files larger than this will be rejected. # Default is 25 MB - "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB": - lambda: int(os.getenv("VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "25")), - + "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB": lambda: int( + os.getenv("VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "25") + ), # Backend for Video IO # - "opencv": Default backend that uses OpenCV stream buffered backend. 
# @@ -759,289 +708,251 @@ def get_vllm_port() -> Optional[int]: # via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and # imported at runtime. # If a non-existing backend is used, an AssertionError will be thrown. - "VLLM_VIDEO_LOADER_BACKEND": - lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), - + "VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv( + "VLLM_VIDEO_LOADER_BACKEND", "opencv" + ), # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache # Default is 4 GiB per API process + 4 GiB per engine core process - "VLLM_MM_INPUT_CACHE_GIB": - lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), - + "VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), # Path to the XLA persistent cache directory. # Only used for XLA devices such as TPUs. - "VLLM_XLA_CACHE_PATH": - lambda: os.path.expanduser( + "VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser( os.getenv( "VLLM_XLA_CACHE_PATH", os.path.join(get_default_cache_root(), "vllm", "xla_cache"), - )), - + ) + ), # If set, assert on XLA recompilation after each execution step. - "VLLM_XLA_CHECK_RECOMPILATION": - lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0"))), - + "VLLM_XLA_CHECK_RECOMPILATION": lambda: bool( + int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION", "0")) + ), # Enable SPMD mode for TPU backend. - "VLLM_XLA_USE_SPMD": - lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), - "VLLM_FUSED_MOE_CHUNK_SIZE": - lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), + "VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), + "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768") + ), # Control whether to use fused MoE activation chunking. Current chunking # logic is incompatible with torch.compile and causes IMA. See issue # https://github.com/vllm-project/vllm/issues/19631. - "VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": - lambda: bool( - int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1"))), - + "VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": lambda: bool( + int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1")) + ), # If set, the OpenAI API server will stay alive even after the underlying # AsyncLLMEngine errors and stops serving requests - "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": - lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)), - + "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool( + os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0) + ), # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows # the user to specify a max sequence length greater than # the max length derived from the model's config.json. # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. - "VLLM_ALLOW_LONG_MAX_MODEL_LEN": - lambda: - (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in - ("1", "true")), - + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": lambda: ( + os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() + in ("1", "true") + ), # If set, forces FP8 Marlin to be used for FP8 quantization regardless # of the hardware support for FP8 compute. 
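# The VLLM_PLUGINS entry earlier in this hunk distinguishes "unset" (None,
# i.e. load everything) from "set" (an explicit comma-separated list). A
# standalone sketch of that pattern; csv_env and DEMO_PLUGINS are
# illustrative names only.
import os
from typing import Optional


def csv_env(name: str) -> Optional[list[str]]:
    if name not in os.environ:
        return None  # unset: caller applies no filtering
    return os.environ[name].split(",")


os.environ.pop("DEMO_PLUGINS", None)
assert csv_env("DEMO_PLUGINS") is None
os.environ["DEMO_PLUGINS"] = "my_plugin_a,my_plugin_b"
assert csv_env("DEMO_PLUGINS") == ["my_plugin_a", "my_plugin_b"]
os.environ["DEMO_PLUGINS"] = ""
assert csv_env("DEMO_PLUGINS") == [""]  # empty string is "set", not "unset"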
- "VLLM_TEST_FORCE_FP8_MARLIN": - lambda: - (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in - ("1", "true")), - "VLLM_TEST_FORCE_LOAD_FORMAT": - lambda: os.getenv("VLLM_TEST_FORCE_LOAD_FORMAT", "dummy"), - + "VLLM_TEST_FORCE_FP8_MARLIN": lambda: ( + os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() + in ("1", "true") + ), + "VLLM_TEST_FORCE_LOAD_FORMAT": lambda: os.getenv( + "VLLM_TEST_FORCE_LOAD_FORMAT", "dummy" + ), # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations - "VLLM_RPC_TIMEOUT": - lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), - + "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), # Timeout in seconds for keeping HTTP connections alive in API server - "VLLM_HTTP_TIMEOUT_KEEP_ALIVE": - lambda: int(os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5")), - + "VLLM_HTTP_TIMEOUT_KEEP_ALIVE": lambda: int( + os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5") + ), # a list of plugin names to load, separated by commas. # if this is not set, it means all plugins will be loaded # if this is set to an empty string, no plugins will be loaded - "VLLM_PLUGINS": - lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[ - "VLLM_PLUGINS"].split(","), - + "VLLM_PLUGINS": lambda: None + if "VLLM_PLUGINS" not in os.environ + else os.environ["VLLM_PLUGINS"].split(","), # a local directory to look in for unrecognized LoRA adapters. # only works if plugins are enabled and # VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled. - "VLLM_LORA_RESOLVER_CACHE_DIR": - lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None), - + "VLLM_LORA_RESOLVER_CACHE_DIR": lambda: os.getenv( + "VLLM_LORA_RESOLVER_CACHE_DIR", None + ), # Enables torch profiler if set. # Both AsyncLLM's CPU traces as well as workers' # traces (CPU & GPU) will be saved under this directory. # Note that it must be an absolute path. - "VLLM_TORCH_PROFILER_DIR": - lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os - .path.abspath(os.path.expanduser(os.getenv( - "VLLM_TORCH_PROFILER_DIR", ".")))), - + "VLLM_TORCH_PROFILER_DIR": lambda: ( + None + if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None + else os.path.abspath( + os.path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", ".")) + ) + ), # Enable torch profiler to record shapes if set # VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will # not record shapes. - "VLLM_TORCH_PROFILER_RECORD_SHAPES": - lambda: bool(os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0"), - + "VLLM_TORCH_PROFILER_RECORD_SHAPES": lambda: bool( + os.getenv("VLLM_TORCH_PROFILER_RECORD_SHAPES", "0") != "0" + ), # Enable torch profiler to profile memory if set # VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1. If not set, torch profiler # will not profile memory. - "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": - lambda: bool( - os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0"), - + "VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY": lambda: bool( + os.getenv("VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY", "0") != "0" + ), # Enable torch profiler to profile stack if set # VLLM_TORCH_PROFILER_WITH_STACK=1. If not set, torch profiler WILL # profile stack by default. 
- "VLLM_TORCH_PROFILER_WITH_STACK": - lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0"), - + "VLLM_TORCH_PROFILER_WITH_STACK": lambda: bool( + os.getenv("VLLM_TORCH_PROFILER_WITH_STACK", "1") != "0" + ), # Enable torch profiler to profile flops if set # VLLM_TORCH_PROFILER_WITH_FLOPS=1. If not set, torch profiler will # not profile flops. - "VLLM_TORCH_PROFILER_WITH_FLOPS": - lambda: bool(os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0"), - + "VLLM_TORCH_PROFILER_WITH_FLOPS": lambda: bool( + os.getenv("VLLM_TORCH_PROFILER_WITH_FLOPS", "0") != "0" + ), # If set, vLLM will use Triton implementations of AWQ. - "VLLM_USE_TRITON_AWQ": - lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), - + "VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), # If set, allow loading or unloading lora adapters in runtime, - "VLLM_ALLOW_RUNTIME_LORA_UPDATING": - lambda: - (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in - ("1", "true")), - + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": lambda: ( + os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() + in ("1", "true") + ), # We assume drivers can report p2p status correctly. # If the program hangs when using custom allreduce, # potantially caused by a bug in the driver (535 series), # if might be helpful to set VLLM_SKIP_P2P_CHECK=0 # so that vLLM can verify if p2p is actually working. # See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa - "VLLM_SKIP_P2P_CHECK": - lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "1") == "1", - + "VLLM_SKIP_P2P_CHECK": lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "1") == "1", # List of quantization kernels that should be disabled, used for testing # and performance comparisons. Currently only affects MPLinearKernel # selection # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel) - "VLLM_DISABLED_KERNELS": - lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[ - "VLLM_DISABLED_KERNELS"].split(","), - + "VLLM_DISABLED_KERNELS": lambda: [] + if "VLLM_DISABLED_KERNELS" not in os.environ + else os.environ["VLLM_DISABLED_KERNELS"].split(","), # Swaps the all reduce backend that we use to coordinate the DP padding # information from NCCL to gloo. - "VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION": - lambda: - (os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() in - ("true", "1")), - + "VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION": lambda: ( + os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() + in ("true", "1") + ), # Disable pynccl (using torch.distributed instead) - "VLLM_DISABLE_PYNCCL": - lambda: - (os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")), - + "VLLM_DISABLE_PYNCCL": lambda: ( + os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1") + ), # If set, use the V1 code path. - "VLLM_USE_V1": - lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), - + "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), # Disable aiter ops unless specifically enabled. # Acts as a parent switch to enable the rest of the other operations. - "VLLM_ROCM_USE_AITER": - lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1") + ), # Whether to use aiter paged attention. # By default is disabled. 
- "VLLM_ROCM_USE_AITER_PAGED_ATTN": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_PAGED_ATTN": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in ("true", "1") + ), # use aiter linear op if aiter ops are enabled # The following list of related ops # - scaled_mm (per-tensor / rowwise) - "VLLM_ROCM_USE_AITER_LINEAR": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_LINEAR": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1") + ), # Whether to use aiter moe ops. # By default is enabled. - "VLLM_ROCM_USE_AITER_MOE": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_MOE": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in ("true", "1") + ), # use aiter rms norm op if aiter ops are enabled. - "VLLM_ROCM_USE_AITER_RMSNORM": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_RMSNORM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1") + ), # Whether to use aiter mla ops. # By default is enabled. - "VLLM_ROCM_USE_AITER_MLA": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_MLA": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1") + ), # Whether to use aiter mha ops. # By default is enabled. - "VLLM_ROCM_USE_AITER_MHA": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_MHA": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1") + ), # Whether to use aiter fp4 gemm asm. # By default is disabled. - "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1") + ), # Whether to use aiter rope. # By default is disabled. - "VLLM_ROCM_USE_TRITON_ROPE": - lambda: (os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_TRITON_ROPE": lambda: ( + os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in ("true", "1") + ), # Whether to use aiter triton fp8 bmm kernel # By default is enabled. 
- "VLLM_ROCM_USE_AITER_FP8BMM": - lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_AITER_FP8BMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1") + ), # use rocm skinny gemms - "VLLM_ROCM_USE_SKINNY_GEMM": - lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_USE_SKINNY_GEMM": lambda: ( + os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1") + ), # Pad the fp8 weights to 256 bytes for ROCm - "VLLM_ROCM_FP8_PADDING": - lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), - + "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), # Pad the weights for the moe kernel - "VLLM_ROCM_MOE_PADDING": - lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), - + "VLLM_ROCM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), # custom paged attention kernel for MI3* cards - "VLLM_ROCM_CUSTOM_PAGED_ATTN": - lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: ( + os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1") + ), # Custom quick allreduce kernel for MI3* cards # Choice of quantization level: FP, INT8, INT6, INT4 or NONE # Recommended for large models to get allreduce - "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION": - env_with_choices("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "NONE", - ["FP", "INT8", "INT6", "INT4", "NONE"]), - + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION": env_with_choices( + "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", + "NONE", + ["FP", "INT8", "INT6", "INT4", "NONE"], + ), # Custom quick allreduce kernel for MI3* cards # Due to the lack of the bfloat16 asm instruction, bfloat16 # kernels are slower than fp16, # If environment variable is set to 1, the input is converted to fp16 - "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16": - lambda: - (os.getenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "True").lower() in - ("true", "1")), - + "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16": lambda: ( + os.getenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "True").lower() + in ("true", "1") + ), # Custom quick allreduce kernel for MI3* cards. # Controls the maximum allowed number of data bytes(MB) for custom quick # allreduce communication. # Default: 2048 MB. # Data exceeding this size will use either custom allreduce or RCCL # communication. - "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB": - lambda: maybe_convert_int( - os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)), - + "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB": lambda: maybe_convert_int( + os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None) + ), # Divisor for dynamic query scale factor calculation for FP8 KV Cache - "Q_SCALE_CONSTANT": - lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), + "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), # Divisor for dynamic key scale factor calculation for FP8 KV Cache - "K_SCALE_CONSTANT": - lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + "K_SCALE_CONSTANT": lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), # Divisor for dynamic value scale factor calculation for FP8 KV Cache - "V_SCALE_CONSTANT": - lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), - + "V_SCALE_CONSTANT": lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), # If set, enable multiprocessing in LLM for the V1 code path. 
- "VLLM_ENABLE_V1_MULTIPROCESSING": - lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), - "VLLM_LOG_BATCHSIZE_INTERVAL": - lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")), - "VLLM_DISABLE_COMPILE_CACHE": - lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))), - + "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool( + int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")) + ), + "VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float( + os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1") + ), + "VLLM_DISABLE_COMPILE_CACHE": lambda: bool( + int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0")) + ), # If set, vllm will run in development mode, which will enable # some additional endpoints for developing and debugging, # e.g. `/reset_prefix_cache` - "VLLM_SERVER_DEV_MODE": - lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))), - + "VLLM_SERVER_DEV_MODE": lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))), # Controls the maximum number of requests to handle in a # single asyncio task when processing per-token outputs in the # V1 AsyncLLM interface. It is applicable when handling a high @@ -1049,175 +960,157 @@ def get_vllm_port() -> Optional[int]: # Setting this too high can result in a higher variance of # inter-message latencies. Setting it too low can negatively impact # TTFT and overall throughput. - "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE": - lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")), - + "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128") + ), # If set, vLLM will disable the MLA attention optimizations. - "VLLM_MLA_DISABLE": - lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), - + "VLLM_MLA_DISABLE": lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), # If set, vLLM will pick up the provided Flash Attention MLA # max number splits for cuda graph decode - "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": - lambda: int(os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", - "32")), - + "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": lambda: int( + os.getenv("VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH", "32") + ), # Number of GPUs per worker in Ray, if it is set to be a fraction, # it allows ray to schedule multiple actors on a single GPU, # so that users can colocate other actors on the same GPUs as vLLM. - "VLLM_RAY_PER_WORKER_GPUS": - lambda: float(os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0")), - + "VLLM_RAY_PER_WORKER_GPUS": lambda: float( + os.getenv("VLLM_RAY_PER_WORKER_GPUS", "1.0") + ), # Bundle indices for Ray, if it is set, it can control precisely # which indices are used for the Ray bundle, for every worker. # Format: comma-separated list of integers, e.g. "0,1,2,3" - "VLLM_RAY_BUNDLE_INDICES": - lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""), - + "VLLM_RAY_BUNDLE_INDICES": lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""), # In some system, find_loaded_library() may not work. So we allow users to # specify the path through environment variable VLLM_CUDART_SO_PATH. - "VLLM_CUDART_SO_PATH": - lambda: os.getenv("VLLM_CUDART_SO_PATH", None), - + "VLLM_CUDART_SO_PATH": lambda: os.getenv("VLLM_CUDART_SO_PATH", None), # Rank of the process in the data parallel setting - "VLLM_DP_RANK": - lambda: int(os.getenv("VLLM_DP_RANK", "0")), - + "VLLM_DP_RANK": lambda: int(os.getenv("VLLM_DP_RANK", "0")), # Rank of the process in the data parallel setting. # Defaults to VLLM_DP_RANK when not set. 
- "VLLM_DP_RANK_LOCAL": - lambda: int( - os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK)), - + "VLLM_DP_RANK_LOCAL": lambda: int( + os.getenv("VLLM_DP_RANK_LOCAL", sys.modules[__name__].VLLM_DP_RANK) + ), # World size of the data parallel setting - "VLLM_DP_SIZE": - lambda: int(os.getenv("VLLM_DP_SIZE", "1")), - + "VLLM_DP_SIZE": lambda: int(os.getenv("VLLM_DP_SIZE", "1")), # IP address of the master node in the data parallel setting - "VLLM_DP_MASTER_IP": - lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), - + "VLLM_DP_MASTER_IP": lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"), # Port of the master node in the data parallel setting - "VLLM_DP_MASTER_PORT": - lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), - + "VLLM_DP_MASTER_PORT": lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), # In the context of executing MoE models with Data-Parallel, Expert-Parallel # and Batched All-to-All dispatch/combine kernels, VLLM_MOE_DP_CHUNK_SIZE # dictates the quantum of tokens that can be dispatched from a DP # rank. All DP ranks process the activations in VLLM_MOE_DP_CHUNK_SIZE # units. - "VLLM_MOE_DP_CHUNK_SIZE": - lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), - + "VLLM_MOE_DP_CHUNK_SIZE": lambda: int(os.getenv("VLLM_MOE_DP_CHUNK_SIZE", "256")), # Randomize inputs during dummy runs when using Data Parallel - "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": - lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1", - + "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: os.environ.get( + "VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0" + ) + == "1", # Whether to use S3 path for model loading in CI via RunAI Streamer - "VLLM_CI_USE_S3": - lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", - + "VLLM_CI_USE_S3": lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", # Use model_redirect to redirect the model name to a local folder. # `model_redirect` can be a json file mapping the model between # repo_id and local folder: # {"meta-llama/Llama-3.2-1B": "/tmp/Llama-3.2-1B"} # or a space separated values table file: # meta-llama/Llama-3.2-1B /tmp/Llama-3.2-1B - "VLLM_MODEL_REDIRECT_PATH": - lambda: os.environ.get("VLLM_MODEL_REDIRECT_PATH", None), - + "VLLM_MODEL_REDIRECT_PATH": lambda: os.environ.get( + "VLLM_MODEL_REDIRECT_PATH", None + ), # Whether to use atomicAdd reduce in gptq/awq marlin kernel. - "VLLM_MARLIN_USE_ATOMIC_ADD": - lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", - + "VLLM_MARLIN_USE_ATOMIC_ADD": lambda: os.environ.get( + "VLLM_MARLIN_USE_ATOMIC_ADD", "0" + ) + == "1", # Whether to use marlin kernel in mxfp4 quantization method - "VLLM_MXFP4_USE_MARLIN": - lambda: maybe_convert_bool(os.environ.get("VLLM_MXFP4_USE_MARLIN", None)), - + "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( + os.environ.get("VLLM_MXFP4_USE_MARLIN", None) + ), # Whether to turn on the outlines cache for V0 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. - "VLLM_V0_USE_OUTLINES_CACHE": - lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", - + "VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get( + "VLLM_V0_USE_OUTLINES_CACHE", "0" + ) + == "1", # Whether to turn on the outlines cache for V1 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. 
- "VLLM_V1_USE_OUTLINES_CACHE": - lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1", - + "VLLM_V1_USE_OUTLINES_CACHE": lambda: os.environ.get( + "VLLM_V1_USE_OUTLINES_CACHE", "0" + ) + == "1", # Gap between padding buckets for the forward pass. So we have # 8, we will run forward pass with [16, 24, 32, ...]. - "VLLM_TPU_BUCKET_PADDING_GAP": - lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]) - if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0, - "VLLM_TPU_MOST_MODEL_LEN": - lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)), - + "VLLM_TPU_BUCKET_PADDING_GAP": lambda: int( + os.environ["VLLM_TPU_BUCKET_PADDING_GAP"] + ) + if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ + else 0, + "VLLM_TPU_MOST_MODEL_LEN": lambda: maybe_convert_int( + os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None) + ), # Whether using Pathways - "VLLM_TPU_USING_PATHWAYS": - lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()), - + "VLLM_TPU_USING_PATHWAYS": lambda: bool( + "proxy" in os.getenv("JAX_PLATFORMS", "").lower() + ), # Allow use of DeepGemm kernels for fused moe ops. - "VLLM_USE_DEEP_GEMM": - lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))), - + "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. - "VLLM_USE_DEEP_GEMM_E8M0": - lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))), + "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool( + int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1")) + ), # TODO(wentao): unify the two E8M0 flags after verifying the correctness. # Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs. - "VLLM_USE_DEEP_GEMM_E8M0_HOPPER": - lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))), + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER": lambda: bool( + int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0")) + ), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine # startup time by a couple of minutes. # Set `VLLM_SKIP_DEEP_GEMM_WARMUP` to disable the warmup. - "VLLM_SKIP_DEEP_GEMM_WARMUP": - lambda: bool(int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0"))), - + "VLLM_SKIP_DEEP_GEMM_WARMUP": lambda: bool( + int(os.getenv("VLLM_SKIP_DEEP_GEMM_WARMUP", "0")) + ), # Whether to use fused grouped_topk used for MoE expert selection. - "VLLM_USE_FUSED_MOE_GROUPED_TOPK": - lambda: bool(int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))), - + "VLLM_USE_FUSED_MOE_GROUPED_TOPK": lambda: bool( + int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1")) + ), # Allow use of FlashInfer MoE kernels for fused moe ops. - "VLLM_USE_FLASHINFER_MOE_FP16": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))), - + "VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0")) + ), # Allow use of FlashInfer MoE kernels for fused moe ops. - "VLLM_USE_FLASHINFER_MOE_FP8": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), - + "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0")) + ), # Allow use of FlashInfer CUTLASS kernels for fused moe ops. 
- "VLLM_USE_FLASHINFER_MOE_FP4": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))), - + "VLLM_USE_FLASHINFER_MOE_FP4": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0")) + ), # If set to 1, use the FlashInfer # MXFP8 (activation) x MXFP4 (weight) MoE backend. - "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))), - + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0")) + ), # If set to 1, use the FlashInfer CUTLASS backend for # MXFP8 (activation) x MXFP4 (weight) MoE. # This is separate from the TRTLLMGEN path controlled by # VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8. - "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS": - lambda: bool(int( - os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0") - )), - + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0")) + ), # If set to 1, use the FlashInfer # BF16 (activation) x MXFP4 (weight) MoE backend. - "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": - lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))), - + "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0")) + ), # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. # It can be changed with this variable if needed for some reason. - "VLLM_XGRAMMAR_CACHE_MB": - lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), - + "VLLM_XGRAMMAR_CACHE_MB": lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), # Control the threshold for msgspec to use 'zero copy' for # serialization/deserialization of tensors. Tensors below # this limit will be encoded into the msgpack buffer, and @@ -1225,23 +1118,23 @@ def get_vllm_port() -> Optional[int]: # While the sending side still actually copies the tensor # in all cases, on the receiving side, tensors above this # limit will actually be zero-copy decoded. - "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": - lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), - + "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": lambda: int( + os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256") + ), # If set, allow insecure serialization using pickle. # This is useful for environments where it is deemed safe to use the # insecure method and it is needed for some reason. - "VLLM_ALLOW_INSECURE_SERIALIZATION": - lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), - + "VLLM_ALLOW_INSECURE_SERIALIZATION": lambda: bool( + int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")) + ), # IP address used for NIXL handshake between remote agents. - "VLLM_NIXL_SIDE_CHANNEL_HOST": - lambda: os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"), - + "VLLM_NIXL_SIDE_CHANNEL_HOST": lambda: os.getenv( + "VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost" + ), # Port used for NIXL handshake between remote agents. 
- "VLLM_NIXL_SIDE_CHANNEL_PORT": - lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600")), - + "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int( + os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600") + ), # all2all backend for vllm's expert parallel communication # Available options: # - "naive": naive all2all implementation using broadcasts @@ -1251,14 +1144,18 @@ def get_vllm_port() -> Optional[int]: # - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_low_latency", use deepep low-latency kernels # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl - "VLLM_ALL2ALL_BACKEND": - env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter", - ["naive", "pplx", - "deepep_high_throughput", - "deepep_low_latency", - "allgather_reducescatter", - "flashinfer_all2allv"]), - + "VLLM_ALL2ALL_BACKEND": env_with_choices( + "VLLM_ALL2ALL_BACKEND", + "allgather_reducescatter", + [ + "naive", + "pplx", + "deepep_high_throughput", + "deepep_low_latency", + "allgather_reducescatter", + "flashinfer_all2allv", + ], + ), # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. # Both require compute capability 10.0 or above. # Available options: @@ -1266,56 +1163,52 @@ def get_vllm_port() -> Optional[int]: # Uses CUTLASS kernels optimized for high-throughput batch inference. # - "latency": # Uses TensorRT-LLM kernels optimized for low-latency inference. - "VLLM_FLASHINFER_MOE_BACKEND": - env_with_choices("VLLM_FLASHINFER_MOE_BACKEND", "throughput", - ["throughput", "latency"]), - + "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices( + "VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"] + ), # Control the maximum number of tokens per expert supported by the # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for # the blockscale tensor of activations NVFP4 Quantization. # This is used to prevent the kernel from running out of memory. - "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": - lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")), - + "VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE": lambda: int( + os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840") + ), # Specifies the thresholds of the communicated tensor sizes under which # vllm should use flashinfer fused allreduce. The variable should be a # JSON with the following format: # { <world size>: <max size in mb> } # Unspecified world sizes will fall back to # { 2: 64, 4: 1, <everything else>: 0.5 } - "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB": - lambda: json.loads(os.getenv( - "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")), - + "VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB": lambda: json.loads( + os.getenv("VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}") + ), # MoE routing strategy selector. # See `RoutingSimulator.get_available_strategies()` # for available # strategies. # Cutstom routing strategies can be registered by # RoutingSimulator.register_strategy() # Note: custom strategies may not produce correct model outputs - "VLLM_MOE_ROUTING_SIMULATION_STRATEGY": - lambda: os.environ.get("VLLM_MOE_ROUTING_SIMULATION_STRATEGY", "").lower(), - + "VLLM_MOE_ROUTING_SIMULATION_STRATEGY": lambda: os.environ.get( + "VLLM_MOE_ROUTING_SIMULATION_STRATEGY", "" + ).lower(), # Regex timeout for use by the vLLM tool parsing plugins. 
- "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": - lambda: int(os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")), - + "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int( + os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1") + ), # Reduce CPU usage when vLLM is idle. Enabling this will incur small # latency penalty when a request eventually comes. - "VLLM_SLEEP_WHEN_IDLE": - lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))), - + "VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))), # Control the max chunk bytes (in MB) for the rpc message queue. # Object larger than this threshold will be broadcast to worker # processes via zmq. - "VLLM_MQ_MAX_CHUNK_BYTES_MB": - lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")), - + "VLLM_MQ_MAX_CHUNK_BYTES_MB": lambda: int( + os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16") + ), # Timeout in seconds for execute_model RPC calls in multiprocessing # executor (only applies when TP > 1). - "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS": - lambda: int(os.getenv("VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", "300")), - + "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS": lambda: int( + os.getenv("VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", "300") + ), # KV Cache layout used throughout vllm. # Some common values are: # - NHD @@ -1323,76 +1216,71 @@ def get_vllm_port() -> Optional[int]: # Where N=num_blocks, H=num_heads and D=head_size. The default value will # leave the layout choice to the backend. Mind that backends may only # implement and support a subset of all possible layouts. - "VLLM_KV_CACHE_LAYOUT": - env_with_choices("VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"]), - + "VLLM_KV_CACHE_LAYOUT": env_with_choices( + "VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"] + ), # Enable checking whether the generated logits contain NaNs, # indicating corrupted output. Useful for debugging low level bugs # or bad hardware but it may add compute overhead. - "VLLM_COMPUTE_NANS_IN_LOGITS": - lambda: bool(int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0"))), - + "VLLM_COMPUTE_NANS_IN_LOGITS": lambda: bool( + int(os.getenv("VLLM_COMPUTE_NANS_IN_LOGITS", "0")) + ), # Controls whether or not emulations are used for NVFP4 # generations on machines < 100 for compressed-tensors # models - "VLLM_USE_NVFP4_CT_EMULATIONS": - lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))), - + "VLLM_USE_NVFP4_CT_EMULATIONS": lambda: bool( + int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")) + ), # Time (in seconds) after which the KV cache on the producer side is # automatically cleared if no READ notification is received from the # consumer. This is only applicable when using NixlConnector in a # disaggregated decode-prefill setup. - "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": - lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480")), - + "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( + os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") + ), # Controls whether or not to use cudnn prefill - "VLLM_USE_CUDNN_PREFILL": - lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))), - + "VLLM_USE_CUDNN_PREFILL": lambda: bool( + int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) + ), # If set to 1/True, use the TRTLLM attention backend in flashinfer. # If set to 0/False, use the default attention backend in flashinfer. # If not set, auto-detect the attention backend in flashinfer. 
- "VLLM_USE_TRTLLM_ATTENTION": - lambda: (None if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ else - os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true")), - + "VLLM_USE_TRTLLM_ATTENTION": lambda: ( + None + if "VLLM_USE_TRTLLM_ATTENTION" not in os.environ + else os.environ["VLLM_USE_TRTLLM_ATTENTION"].lower() in ("1", "true") + ), # If set to 1, when we use fp8 kv, we do not quantize Q to fp8 - "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION": - lambda: bool(int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0"))), - + "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION": lambda: bool( + int(os.getenv("VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION", "0")) + ), # If set, it means we pre-downloaded cubin files and flashinfer will # read the cubin files directly. - "VLLM_HAS_FLASHINFER_CUBIN": - lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False), - + "VLLM_HAS_FLASHINFER_CUBIN": lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False), # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. # Otherwise, uses the first available of: flashinfer cutlass GEMM, # vllm cutlass GEMM, marlin GEMM. - "VLLM_USE_TRTLLM_FP4_GEMM": - lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0"))), - + "VLLM_USE_TRTLLM_FP4_GEMM": lambda: bool( + int(os.getenv("VLLM_USE_TRTLLM_FP4_GEMM", "0")) + ), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. # If set to 1, allows GC to run during capture. - "VLLM_ENABLE_CUDAGRAPH_GC": - lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))), - + "VLLM_ENABLE_CUDAGRAPH_GC": lambda: bool( + int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0")) + ), # Disable padding to CUDA graph capture batch sizes. # TODO(wentao): https://github.com/vllm-project/vllm/issues/23378 # After the issue is fixed, we can remove this flag. - "VLLM_DISABLE_PAD_FOR_CUDAGRAPH": - lambda: bool(int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0"))), - + "VLLM_DISABLE_PAD_FOR_CUDAGRAPH": lambda: bool( + int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0")) + ), # Used to force set up loopback IP - "VLLM_LOOPBACK_IP": - lambda: os.getenv("VLLM_LOOPBACK_IP", ""), - + "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), # Used to set the process name prefix for vLLM processes. # This is useful for debugging and monitoring purposes. # The default value is "VLLM". - "VLLM_PROCESS_NAME_PREFIX": - lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"), - + "VLLM_PROCESS_NAME_PREFIX": lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"), # Allow chunked local attention with hybrid kv cache manager. # Currently using the Hybrid KV cache manager with chunked local attention # in the Llama4 models (the only models currently using chunked local attn) @@ -1400,10 +1288,9 @@ def get_vllm_port() -> Optional[int]: # This flag is used to allow users to enable it if they want to (to save on # kv-cache memory usage and enable longer contexts) # TODO(lucas): Remove this flag once latency regression is resolved. - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": - lambda: bool(int(os.getenv(\ - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0"))), - + "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE": lambda: bool( + int(os.getenv("VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE", "0")) + ), # Enables support for the "store" option in the OpenAI Responses API. # When set to 1, vLLM's OpenAI server will retain the input and output # messages for those requests in memory. 
By default, this is disabled (0), @@ -1413,83 +1300,74 @@ def get_vllm_port() -> Optional[int]: # lost when the vLLM server shuts down. # 2. Enabling this option will cause a memory leak, as stored messages are # never removed from memory until the server terminates. - "VLLM_ENABLE_RESPONSES_API_STORE": - lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), - + "VLLM_ENABLE_RESPONSES_API_STORE": lambda: bool( + int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0")) + ), # If set, use the fp8 mfma in rocm paged attention. - "VLLM_ROCM_FP8_MFMA_PAGE_ATTN": - lambda: bool(int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0"))), - + "VLLM_ROCM_FP8_MFMA_PAGE_ATTN": lambda: bool( + int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0")) + ), # Whether to use pytorch symmetric memory for allreduce - "VLLM_ALLREDUCE_USE_SYMM_MEM": - lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))), - + "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( + int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) + ), # Allows vllm to find tuned config under customized folder - "VLLM_TUNED_CONFIG_FOLDER": - lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), - + "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), # Allows harmony instructions to be injected on system messages - "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": - lambda: bool( - int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))), - + "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool( + int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0")) + ), # Add optional custom scopes for profiling, disable to avoid overheads - "VLLM_CUSTOM_SCOPES_FOR_PROFILING": - lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))), - + "VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool( + int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0")) + ), # Add optional nvtx scopes for profiling, disable to avoid overheads - "VLLM_NVTX_SCOPES_FOR_PROFILING": - lambda: bool(int(os.getenv("VLLM_NVTX_SCOPES_FOR_PROFILING", "0"))), - + "VLLM_NVTX_SCOPES_FOR_PROFILING": lambda: bool( + int(os.getenv("VLLM_NVTX_SCOPES_FOR_PROFILING", "0")) + ), # Represent block hashes in KV cache events as 64-bit integers instead of # raw bytes. Defaults to True for backward compatibility. - "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": - lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))), - + "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": lambda: bool( + int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1")) + ), # Name of the shared memory buffer used for object storage. # Only effective when mm_config.mm_processor_cache_type == "shm". 
- "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": - lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", - "VLLM_OBJECT_STORAGE_SHM_BUFFER"), - + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": lambda: os.getenv( + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_OBJECT_STORAGE_SHM_BUFFER" + ), # The size in MB of the buffers (NVL and RDMA) used by DeepEP - "VLLM_DEEPEP_BUFFER_SIZE_MB": - lambda: int(os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024")), - + "VLLM_DEEPEP_BUFFER_SIZE_MB": lambda: int( + os.getenv("VLLM_DEEPEP_BUFFER_SIZE_MB", "1024") + ), # The number of SMs to allocate for communication kernels when running DBO # the rest of the SMs on the device will be allocated to compute - "VLLM_DBO_COMM_SMS": - lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")), - + "VLLM_DBO_COMM_SMS": lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")), # Valid values are container,code_interpreter,web_search_preview # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter - "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": - env_list_with_choices("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", [], - ["container", - "code_interpreter", - "web_search_preview"]), - + "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": env_list_with_choices( + "GPT_OSS_SYSTEM_TOOL_MCP_LABELS", + [], + ["container", "code_interpreter", "web_search_preview"], + ), # Enable max_autotune & coordinate_descent_tuning in inductor_config # to compile static shapes passed from compile_sizes in compilation_config # If set to 1, enable max_autotune; By default, this is enabled (1) - "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE": - lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "1"))), + "VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE": lambda: bool( + int(os.getenv("VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE", "1")) + ), # If set to 1, enable coordinate_descent_tuning; # By default, this is enabled (1) - "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING": - lambda: bool(int(os.getenv("VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", - "1"))), - + "VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING": lambda: bool( + int(os.getenv("VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING", "1")) + ), # Flag to enable NCCL symmetric memory allocation and registration - "VLLM_USE_NCCL_SYMM_MEM": - lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))), - + "VLLM_USE_NCCL_SYMM_MEM": lambda: bool( + int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0")) + ), # NCCL header path - "VLLM_NCCL_INCLUDE_PATH": - lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None), + "VLLM_NCCL_INCLUDE_PATH": lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None), # Flag to enable FBGemm kernels on model execution "VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))), - # GC debug config # - VLLM_GC_DEBUG=0: disable GC debugger # - VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elpased times @@ -1524,7 +1402,8 @@ def set_vllm_use_v1(use_v1: bool): raise ValueError( "Should not call set_vllm_use_v1() if VLLM_USE_V1 is set " "explicitly by the user. Please raise this as a Github " - "Issue and explicitly set VLLM_USE_V1=0 or 1.") + "Issue and explicitly set VLLM_USE_V1=0 or 1." 
+ ) os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0" @@ -1598,14 +1477,12 @@ def compute_hash() -> str: for key in environment_variables_to_hash: # if this goes out of sync with environment_variables, # it's not a user error, it's a bug - assert key in environment_variables, \ + assert key in environment_variables, ( "Please update environment_variables_to_hash in envs.py" + ) - factors = [ - environment_variables[key]() for key in environment_variables_to_hash - ] + factors = [environment_variables[key]() for key in environment_variables_to_hash] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index fe80be61410c..3a7347b8e465 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -4,11 +4,11 @@ import asyncio import time from abc import ABC, abstractmethod +from collections.abc import Awaitable from functools import cached_property -from typing import Any, Awaitable, Callable, List, Optional, Set, Union +from typing import Any, Callable, Optional, Union -import torch.nn as nn -from typing_extensions import TypeVar, deprecated +from typing_extensions import TypeVar import vllm.platforms from vllm.config import VllmConfig @@ -61,11 +61,13 @@ def _init_executor(self) -> None: raise NotImplementedError @abstractmethod - def collective_rpc(self, - method: Union[str, Callable[[WorkerBase], _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[[WorkerBase], _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: """ Execute an RPC call on all workers. @@ -110,32 +112,29 @@ def determine_num_available_blocks(self) -> tuple[int, int]: return a, b def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: - """Initialize the KV cache by invoking the underlying worker. - """ + """Initialize the KV cache by invoking the underlying worker.""" # NOTE: This is logged in the executor because there can be >1 workers. - logger.info("# %s blocks: %d, # CPU blocks: %d", - vllm.platforms.current_platform.device_name, - num_gpu_blocks, num_cpu_blocks) - max_concurrency = (num_gpu_blocks * self.cache_config.block_size / - self.model_config.max_model_len) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - self.model_config.max_model_len, max_concurrency) + logger.info( + "# %s blocks: %d, # CPU blocks: %d", + vllm.platforms.current_platform.device_name, + num_gpu_blocks, + num_cpu_blocks, + ) + max_concurrency = ( + num_gpu_blocks + * self.cache_config.block_size + / self.model_config.max_model_len + ) + logger.info( + "Maximum concurrency for %s tokens per request: %.2fx", + self.model_config.max_model_len, + max_concurrency, + ) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks - self.collective_rpc("initialize_cache", - args=(num_gpu_blocks, num_cpu_blocks)) - - @deprecated("`llm_engine.model_executor.apply_model` will no longer work " - "in V1 Engine. 
Please replace with `llm_engine.apply_model` " - "and set `VLLM_ALLOW_INSECURE_SERIALIZATION=1`.") - def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: - """ - Run a function directly on the model inside each worker, - returning the result for each of them. - """ - return self.collective_rpc("apply_model", args=(func, )) + self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) @cached_property # Avoid unnecessary RPC calls def supported_tasks(self) -> tuple[SupportedTask, ...]: @@ -144,9 +143,8 @@ def supported_tasks(self) -> tuple[SupportedTask, ...]: def execute_model( self, execute_model_req: ExecuteModelRequest - ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: - output = self.collective_rpc("execute_model", - args=(execute_model_req, )) + ) -> Optional[list[Union[SamplerOutput, PoolerOutput]]]: + output = self.collective_rpc("execute_model", args=(execute_model_req,)) return output[0] def stop_remote_worker_execution_loop(self) -> None: @@ -155,17 +153,17 @@ def stop_remote_worker_execution_loop(self) -> None: def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("add_lora", args=(lora_request, ))) + return all(self.collective_rpc("add_lora", args=(lora_request,))) def remove_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("remove_lora", args=(lora_id, ))) + return all(self.collective_rpc("remove_lora", args=(lora_id,))) def pin_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." - return all(self.collective_rpc("pin_lora", args=(lora_id, ))) + return all(self.collective_rpc("pin_lora", args=(lora_id,))) - def list_loras(self) -> Set[int]: + def list_loras(self) -> set[int]: sets = self.collective_rpc("list_loras") for s in sets: assert s == sets[0], "All workers should have the same LORAs." 
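The initialize_cache hunk above logs an estimated maximum concurrency of num_gpu_blocks * block_size / max_model_len. A quick worked example of that formula with made-up numbers:

# Hypothetical values, chosen only to illustrate the arithmetic.
num_gpu_blocks = 10_000
block_size = 16        # tokens per KV-cache block
max_model_len = 8_192  # tokens per request

max_concurrency = num_gpu_blocks * block_size / max_model_len
print(f"Maximum concurrency for {max_model_len} tokens per request: "
      f"{max_concurrency:.2f}x")  # -> 19.53x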
@@ -186,8 +184,9 @@ def sleep(self, level: int = 1): time_after_sleep = time.perf_counter() self.sleeping_tags = {"weights", "kv_cache"} self.is_sleeping = True - logger.info("It took %.6f seconds to fall asleep.", - time_after_sleep - time_before_sleep) + logger.info( + "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep + ) def wake_up(self, tags: Optional[list[str]] = None): if not self.is_sleeping: @@ -196,15 +195,18 @@ def wake_up(self, tags: Optional[list[str]] = None): if tags: for tag in tags: if tag not in self.sleeping_tags: - logger.warning("Tag %s is not in sleeping tags %s", tag, - self.sleeping_tags) + logger.warning( + "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags + ) return time_before_wakeup = time.perf_counter() self.collective_rpc("wake_up", kwargs=dict(tags=tags)) time_after_wakeup = time.perf_counter() - logger.info("It took %.6f seconds to wake up tags %s.", - time_after_wakeup - time_before_wakeup, - tags if tags is not None else self.sleeping_tags) + logger.info( + "It took %.6f seconds to wake up tags %s.", + time_after_wakeup - time_before_wakeup, + tags if tags is not None else self.sleeping_tags, + ) if tags: for tag in tags: self.sleeping_tags.remove(tag) @@ -219,10 +221,10 @@ def save_sharded_state( pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - self.collective_rpc("save_sharded_state", - kwargs=dict(path=path, - pattern=pattern, - max_size=max_size)) + self.collective_rpc( + "save_sharded_state", + kwargs=dict(path=path, pattern=pattern, max_size=max_size), + ) @abstractmethod def check_health(self) -> None: @@ -235,8 +237,8 @@ def shutdown(self) -> None: self.collective_rpc("shutdown") async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> list[SamplerOutput]: """Executes one model step on the given sequences.""" output = await make_async(self.execute_model)(execute_model_req) return output @@ -253,7 +255,8 @@ async def check_health_async(self) -> None: def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None: """Init KVOutputAggregator""" self.kv_output_aggregator = KVOutputAggregator( - finished_count or self.parallel_config.world_size) + finished_count or self.parallel_config.world_size + ) class DistributedExecutorBase(ExecutorBase): @@ -269,12 +272,13 @@ def __init__(self, *args, **kwargs): def execute_model( self, execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: # TODO: unify into collective_rpc if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", - async_run_tensor_parallel_workers_only=True) + async_run_tensor_parallel_workers_only=True, + ) # Only the driver worker returns the sampling results. driver_outputs = self._driver_execute_model(execute_model_req) @@ -295,7 +299,7 @@ def stop_remote_worker_execution_loop(self) -> None: @abstractmethod def _driver_execute_model( self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[list[SamplerOutput]]: """Run execute_model in the driver worker. 
Passing None will cause the driver to stop the model execution loop @@ -304,11 +308,13 @@ def _driver_execute_model( """ raise NotImplementedError - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[Any]: return self._run_workers(method, *args, **(kwargs or {})) @abstractmethod @@ -339,12 +345,13 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: raise NotImplementedError async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> list[SamplerOutput]: if self.parallel_worker_tasks is None: # Start model execution loop running in the parallel workers self.parallel_worker_tasks = asyncio.create_task( - self._start_worker_execution_loop()) + self._start_worker_execution_loop() + ) # Only the driver worker returns the sampling results. return await self._driver_execute_model_async(execute_model_req) @@ -364,7 +371,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None: async def _driver_execute_model_async( self, execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> List[SamplerOutput]: + ) -> list[SamplerOutput]: """Execute the model asynchronously in the driver worker. Passing None will cause the driver to stop the model execution diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py index 4ce6d8dfad2c..ac16f06b160e 100644 --- a/vllm/executor/msgspec_utils.py +++ b/vllm/executor/msgspec_utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from array import array -from typing import Any, Type +from typing import Any from vllm.multimodal.inputs import MultiModalKwargs from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE @@ -16,13 +16,14 @@ def encode_hook(obj: Any) -> Any: if isinstance(obj, array): assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " - f"Given array has a type code of {obj.typecode}.") + f"Given array has a type code of {obj.typecode}." + ) return obj.tobytes() if isinstance(obj, MultiModalKwargs): return dict(obj) -def decode_hook(type: Type, obj: Any) -> Any: +def decode_hook(type: type, obj: Any) -> Any: """Custom msgspec dec hook that supports array types and MultiModalKwargs. 
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 84747575b496..6a9608d70b69 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -5,23 +5,26 @@ import os from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import cloudpickle import msgspec import vllm.envs as envs -from vllm.executor.executor_base import ( - DistributedExecutorBase) # yapf: disable +from vllm.executor.executor_base import DistributedExecutorBase from vllm.executor.msgspec_utils import encode_hook -from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, - ray) +from vllm.executor.ray_utils import RayWorkerWrapper, initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy from vllm.sequence import ExecuteModelRequest -from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, make_async) +from vllm.utils import ( + _run_task_with_lock, + get_distributed_init_method, + get_ip, + get_open_port, + make_async, +) from vllm.v1.outputs import SamplerOutput if ray is not None: @@ -43,6 +46,7 @@ class RayWorkerMetaData: The order of ray worker creation can be random, and we need to reset the rank after creating all workers. """ + worker: ActorHandle created_rank: int adjusted_rank: int = -1 @@ -55,7 +59,10 @@ class RayDistributedExecutor(DistributedExecutorBase): # These env vars are worker-specific, therefore are NOT copied # from the driver to the workers WORKER_SPECIFIC_ENV_VARS = { - "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES" + "VLLM_HOST_IP", + "VLLM_HOST_PORT", + "LOCAL_RANK", + "CUDA_VISIBLE_DEVICES", } # These non-vLLM env vars are copied from the driver to workers @@ -86,13 +93,13 @@ def _init_executor(self) -> None: self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER if self.use_ray_compiled_dag: assert self.use_ray_spmd_worker, ( - "VLLM_USE_RAY_COMPILED_DAG=1 requires " - "VLLM_USE_RAY_SPMD_WORKER=1") + "VLLM_USE_RAY_COMPILED_DAG=1 requires VLLM_USE_RAY_SPMD_WORKER=1" + ) if self.use_ray_spmd_worker: # TODO: Support SPMD worker for non-DAG Ray executor. assert self.use_ray_compiled_dag, ( - "VLLM_USE_RAY_SPMD_WORKER=1 requires " - "VLLM_USE_RAY_COMPILED_DAG=1") + "VLLM_USE_RAY_SPMD_WORKER=1 requires VLLM_USE_RAY_COMPILED_DAG=1" + ) assert self.uses_ray initialize_ray_cluster(self.parallel_config) @@ -107,14 +114,12 @@ def _init_executor(self) -> None: self._init_workers_ray(placement_group) self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) - self.output_decoder = msgspec.msgpack.Decoder( - Optional[List[SamplerOutput]]) + self.output_decoder = msgspec.msgpack.Decoder(Optional[list[SamplerOutput]]) self.use_v1 = envs.VLLM_USE_V1 - self.pp_locks: Optional[List[asyncio.Lock]] = None + self.pp_locks: Optional[list[asyncio.Lock]] = None if not self.use_ray_compiled_dag: - self.driver_exec_method = make_async( - self.driver_worker.execute_method) + self.driver_exec_method = make_async(self.driver_worker.execute_method) def shutdown(self) -> None: if logger: @@ -122,26 +127,29 @@ def shutdown(self) -> None: logger.info( "Shutting down Ray distributed executor. 
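The msgspec_utils hunks only re-wrap these hooks, and the executor above already shows them being passed as enc_hook/dec_hook. As a minimal, self-contained sketch of that wiring, assuming the same hook semantics (token-id arrays become raw bytes on encode and are rebuilt on decode):

from array import array

import msgspec


def encode_hook(obj):
    # Serialize array("l", ...) token ids as raw bytes.
    if isinstance(obj, array):
        return obj.tobytes()
    raise NotImplementedError(f"cannot encode {type(obj)!r}")


def decode_hook(target_type, obj):
    # Rebuild the array from the raw bytes produced by encode_hook.
    if target_type is array:
        out = array("l")
        out.frombytes(obj)
        return out
    return obj


encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
decoder = msgspec.msgpack.Decoder(array, dec_hook=decode_hook)

buf = encoder.encode(array("l", [1, 2, 3]))
assert decoder.decode(buf).tolist() == [1, 2, 3]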
If you see error log " "from logging.cc regarding SIGTERM received, please ignore " - "because this is the expected termination process in Ray.") + "because this is the expected termination process in Ray." + ) if hasattr(self, "forward_dag") and self.forward_dag is not None: self.forward_dag.teardown() import ray + for worker in self.workers: ray.kill(worker) self.forward_dag = None - def _configure_ray_workers_use_nsight(self, - ray_remote_kwargs) -> Dict[str, Any]: + def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]: # If nsight profiling is enabled, we need to set the profiling # configuration for the ray workers as runtime env. runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) - runtime_env.update({ - "nsight": { - "t": "cuda,cudnn,cublas", - "o": "'worker_process_%p'", - "cuda-graph-trace": "node", + runtime_env.update( + { + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } } - }) + ) return ray_remote_kwargs @@ -149,49 +157,50 @@ def _configure_ray_workers_use_nsight(self, def _get_env_vars_to_be_updated(self): return self._env_vars_for_all_workers - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. self.driver_dummy_worker: Optional[RayWorkerWrapper] = None # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerWrapper] = [] + self.workers: list[RayWorkerWrapper] = [] # Used in ray compiled DAG: indexed first by PP rank, # and then TP rank. In other words, the inner list is # the TP group of workers for a PP rank. - self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + self.pp_tp_workers: list[list[RayWorkerWrapper]] = [] if self.parallel_config.ray_workers_use_nsight: ray_remote_kwargs = self._configure_ray_workers_use_nsight( - ray_remote_kwargs) + ray_remote_kwargs + ) logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) # Create the workers. - bundle_indices: List[int] + bundle_indices: list[int] if envs.VLLM_RAY_BUNDLE_INDICES: # Use the bundle indices specified by the user. - bundle_indices = list( - map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) - assert len(bundle_indices) == self.parallel_config.world_size, \ - ("VLLM_RAY_BUNDLE_INDICES must have the same size" - f" as the world size, but got {bundle_indices=} " - f"and {self.parallel_config.world_size=}") - assert len(set(bundle_indices)) == len(bundle_indices), \ - ("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," - f" but got {bundle_indices=}") + bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))) + assert len(bundle_indices) == self.parallel_config.world_size, ( + "VLLM_RAY_BUNDLE_INDICES must have the same size" + f" as the world size, but got {bundle_indices=} " + f"and {self.parallel_config.world_size=}" + ) + assert len(set(bundle_indices)) == len(bundle_indices), ( + "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," + f" but got {bundle_indices=}" + ) else: # use the first N bundles that have GPU resources. 
bundle_indices = [] for bundle_id, bundle in enumerate(placement_group.bundle_specs): if bundle.get(current_platform.ray_device_key, 0): bundle_indices.append(bundle_id) - bundle_indices = bundle_indices[:self.parallel_config.world_size] + bundle_indices = bundle_indices[: self.parallel_config.world_size] - worker_metadata: List[RayWorkerMetaData] = [] + worker_metadata: list[RayWorkerMetaData] = [] driver_ip = get_ip() for rank, bundle_id in enumerate(bundle_indices): scheduling_strategy = PlacementGroupSchedulingStrategy( @@ -207,8 +216,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, - rpc_rank=rank) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, rpc_rank=rank) else: worker = ray.remote( num_cpus=0, @@ -216,15 +224,15 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", resources={current_platform.ray_device_key: num_gpus}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, - rpc_rank=rank) - worker_metadata.append( - RayWorkerMetaData(worker=worker, created_rank=rank)) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config, rpc_rank=rank) + worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank)) - worker_ips = ray.get([ - each.worker.get_node_ip.remote() # type: ignore[attr-defined] - for each in worker_metadata - ]) + worker_ips = ray.get( + [ + each.worker.get_node_ip.remote() # type: ignore[attr-defined] + for each in worker_metadata + ] + ) for each, ip in zip(worker_metadata, worker_ips): each.ip = ip @@ -239,7 +247,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - vllm_config=self.vllm_config, rpc_rank=0) + vllm_config=self.vllm_config, rpc_rank=0 + ) worker_metadata.pop(i) break @@ -250,9 +259,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", "Ray does not allocate any GPUs on the driver node." f"Driver IP: {driver_ip}, worker IPs: {worker_ips}." "Consider adjusting the Ray placement group or running " - "the driver on a GPU node.") + "the driver on a GPU node." + ) - ip_counts: Dict[str, int] = {} + ip_counts: dict[str, int] = {} for ip in worker_ips: ip_counts[ip] = ip_counts.get(ip, 0) + 1 @@ -272,15 +282,15 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): # After sorting, the workers on the same node will be # close to each other, and the workers on the driver # node will be placed first. - sorted_worker_metadata = sorted(worker_metadata, - key=sort_by_driver_then_worker_ip) + sorted_worker_metadata = sorted( + worker_metadata, key=sort_by_driver_then_worker_ip + ) start_rank = 0 if self.use_ray_spmd_worker else 1 for i, item in enumerate(sorted_worker_metadata): item.adjusted_rank = i + start_rank self.workers = [item.worker for item in sorted_worker_metadata] rerank_mapping = { - item.created_rank: item.adjusted_rank - for item in sorted_worker_metadata + item.created_rank: item.adjusted_rank for item in sorted_worker_metadata } self._run_workers("adjust_rank", rerank_mapping) @@ -291,8 +301,8 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): # driver_dummy_worker can be None when using ray spmd worker. 
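VLLM_RAY_BUNDLE_INDICES above is parsed as a comma-separated list of integers and must match the world size with no duplicates. A tiny illustration of that parsing and validation, mirroring the assertions in the hunk:

def parse_bundle_indices(raw: str, world_size: int) -> list[int]:
    bundle_indices = list(map(int, raw.split(",")))
    assert len(bundle_indices) == world_size, (
        f"expected {world_size} bundle indices, but got {bundle_indices=}"
    )
    assert len(set(bundle_indices)) == len(bundle_indices), (
        f"duplicate values in {bundle_indices=}"
    )
    return bundle_indices


print(parse_bundle_indices("0,1,2,3", world_size=4))  # [0, 1, 2, 3]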
continue worker_node_and_gpu_ids.append( - ray.get(worker.get_node_and_gpu_ids.remote()) \ - ) # type: ignore + ray.get(worker.get_node_and_gpu_ids.remote()) + ) # type: ignore node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -320,20 +330,27 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): f"{n_ips} unique IP addresses {all_ips}. Please check your" " network configuration. If you set `VLLM_HOST_IP`" " environment variable, make sure it is unique for" - " each node.") + " each node." + ) # Set environment variables for the driver and workers. - all_args_to_update_environment_variables = [{ - current_platform.device_control_env_var: - ",".join(map(str, node_gpus[node_id])), - } for (node_id, _) in worker_node_and_gpu_ids] + all_args_to_update_environment_variables = [ + { + current_platform.device_control_env_var: ",".join( + map(str, node_gpus[node_id]) + ), + } + for (node_id, _) in worker_node_and_gpu_ids + ] # Environment variables to copy from driver to workers env_vars_to_copy = get_env_vars_to_copy( exclude_vars=self.WORKER_SPECIFIC_ENV_VARS, additional_vars=set(current_platform.additional_env_vars).union( - self.ADDITIONAL_ENV_VARS), - destination="workers") + self.ADDITIONAL_ENV_VARS + ), + destination="workers", + ) # Copy existing env vars to each worker's args for args in all_args_to_update_environment_variables: @@ -342,11 +359,11 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): if name in os.environ: args[name] = os.environ[name] - self._env_vars_for_all_workers = ( - all_args_to_update_environment_variables) + self._env_vars_for_all_workers = all_args_to_update_environment_variables - self._run_workers("update_environment_variables", - self._get_env_vars_to_be_updated()) + self._run_workers( + "update_environment_variables", self._get_env_vars_to_be_updated() + ) if len(node_gpus) == 1: # in single node case, we don't need to get the IP address. @@ -359,7 +376,8 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): # the node. driver_ip = "127.0.0.1" distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) + driver_ip, get_open_port() + ) # Initialize the actual workers inside worker wrapper. all_kwargs = [] @@ -377,19 +395,20 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): self._run_workers("init_worker", all_kwargs) self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers) + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, + ) if self.use_ray_spmd_worker: for pp_rank in range(self.parallel_config.pipeline_parallel_size): self.pp_tp_workers.append([]) - for tp_rank in range( - self.parallel_config.tensor_parallel_size): + for tp_rank in range(self.parallel_config.tensor_parallel_size): # PP=2, TP=4 # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] - rank = (pp_rank * self.parallel_config.tensor_parallel_size - ) + tp_rank + rank = ( + pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank assert len(self.pp_tp_workers[pp_rank]) == tp_rank assert pp_rank < len(self.pp_tp_workers) self.pp_tp_workers[pp_rank].append(self.workers[rank]) @@ -397,11 +416,11 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): # This is the list of workers that are rank 0 of each TP group EXCEPT # global rank 0. 
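The pp_tp_workers construction above flattens a (pipeline rank, tensor rank) grid into global ranks with rank = pp_rank * tensor_parallel_size + tp_rank, which yields [[0, 1, 2, 3], [4, 5, 6, 7]] for PP=2, TP=4 as the comment notes. The same layout, spelled out:

pipeline_parallel_size = 2
tensor_parallel_size = 4

pp_tp_workers = [
    [pp_rank * tensor_parallel_size + tp_rank
     for tp_rank in range(tensor_parallel_size)]
    for pp_rank in range(pipeline_parallel_size)
]
print(pp_tp_workers)  # [[0, 1, 2, 3], [4, 5, 6, 7]]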
These are the workers that will broadcast to the # rest of the workers. - self.tp_driver_workers: List[RayWorkerWrapper] = [] + self.tp_driver_workers: list[RayWorkerWrapper] = [] # This is the list of workers that are not drivers and not the first # worker in a TP group. These are the workers that will be # broadcasted to. - self.non_driver_workers: List[RayWorkerWrapper] = [] + self.non_driver_workers: list[RayWorkerWrapper] = [] # Enforce rank order for correct rank to return final output. for index, worker in enumerate(self.workers): @@ -414,20 +433,20 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): def _driver_execute_model( self, execute_model_req: Optional[ExecuteModelRequest] - ) -> Optional[List[SamplerOutput]]: + ) -> Optional[list[SamplerOutput]]: """Run execute_model in the driver worker. Passing None will cause the driver to stop the model execution loop running in each of the remote workers. """ assert not self.use_ray_spmd_worker, ( - "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") - return self.driver_worker.execute_method("execute_model", - execute_model_req) + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1" + ) + return self.driver_worker.execute_method("execute_model", execute_model_req) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> list[SamplerOutput]: if not self.use_ray_spmd_worker: return super().execute_model(execute_model_req) @@ -439,10 +458,7 @@ def execute_model( else: serialized_data = self.input_encoder.encode(execute_model_req) outputs = ray.get(self.forward_dag.execute(serialized_data)) - if self.use_v1: - output = outputs[0] - else: - output = self.output_decoder.decode(outputs[0]) + output = outputs[0] if self.use_v1 else self.output_decoder.decode(outputs[0]) return output def _run_workers( @@ -463,19 +479,15 @@ def _run_workers( rather than blocking on the results. - args/kwargs: All workers share the same args/kwargs """ - if isinstance(method, str): - sent_method = method - else: - sent_method = cloudpickle.dumps(method) + sent_method = method if isinstance(method, str) else cloudpickle.dumps(method) del method if self.use_ray_spmd_worker: assert not async_run_tensor_parallel_workers_only, ( - "async_run_tensor_parallel_workers_only is not supported for " - "spmd mode.") + "async_run_tensor_parallel_workers_only is not supported for spmd mode." + ) if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") + raise NotImplementedError("max_concurrent_workers is not supported yet.") # Start the ray workers first. ray_workers = self.workers @@ -519,23 +531,27 @@ def _check_ray_cgraph_installation(self): required_version = version.parse("2.43.0") current_version = version.parse(importlib.metadata.version("ray")) if current_version < required_version: - raise ValueError(f"Ray version {required_version} is " - f"required, but found {current_version}") + raise ValueError( + f"Ray version {required_version} is " + f"required, but found {current_version}" + ) import importlib.util - cgraph_spec = importlib.util.find_spec( - "ray.experimental.compiled_dag_ref") + + cgraph_spec = importlib.util.find_spec("ray.experimental.compiled_dag_ref") if cgraph_spec is None: - raise ValueError("Ray Compiled Graph is not installed. " - "Run `pip install ray[cgraph]` to install it.") + raise ValueError( + "Ray Compiled Graph is not installed. 
" + "Run `pip install ray[cgraph]` to install it." + ) cupy_spec = importlib.util.find_spec("cupy") - if (cupy_spec is None - and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"): + if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl": raise ValueError( "cupy is not installed but required since " "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. " - "Run `pip install ray[cgraph]` and check cupy installation.") + "Run `pip install ray[cgraph]` and check cupy installation." + ) def _compiled_ray_dag(self, enable_asyncio: bool): assert self.parallel_config.use_ray @@ -549,18 +565,26 @@ def _compiled_ray_dag(self, enable_asyncio: bool): # ray.dag, otherwise it will not take effect. os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112 from ray.dag import InputNode, MultiOutputNode - logger.info("RAY_CGRAPH_get_timeout is set to %s", - os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112 - logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s", - envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE) - logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", - envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) + + logger.info( + "RAY_CGRAPH_get_timeout is set to %s", + os.environ["RAY_CGRAPH_get_timeout"], # noqa: SIM112 + ) + logger.info( + "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE, + ) + logger.info( + "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM, + ) channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE if channel_type not in ("auto", "nccl", "shm"): raise ValueError( "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: " - f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.") + f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'." + ) with InputNode() as input_data: # Example DAG: PP=2, TP=4 @@ -585,20 +609,24 @@ def _compiled_ray_dag(self, enable_asyncio: bool): # and the TP group executes in SPMD fashion. if self.use_v1: outputs = [ - worker.execute_model_ray. - bind( # type: ignore[attr-defined] - outputs[i]) for i, worker in enumerate(tp_group) + worker.execute_model_ray.bind( # type: ignore[attr-defined] + outputs[i] + ) + for i, worker in enumerate(tp_group) ] else: outputs = [ - worker.execute_model_spmd. - bind( # type: ignore[attr-defined] - outputs[i]) for i, worker in enumerate(tp_group) + worker.execute_model_spmd.bind( # type: ignore[attr-defined] + outputs[i] + ) + for i, worker in enumerate(tp_group) ] last_pp_rank = len(self.pp_tp_workers) - 1 - if (pp_rank < last_pp_rank and - envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"): + if ( + pp_rank < last_pp_rank + and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm" + ): # Specify how intermediate tensors should be passed # between pp stages, no need to specify for the last # pp stage or when using shared memory (the default). 
@@ -612,30 +640,37 @@ def _compiled_ray_dag(self, enable_asyncio: bool): if envs.VLLM_USE_RAY_WRAPPED_PP_COMM: from ray.experimental.channel.accelerator_context import ( - register_accelerator_context) + register_accelerator_context, + ) from vllm.distributed.device_communicators.ray_communicator import ( - RayPPCommunicator) - register_accelerator_context(torch_module_name="cuda", - communicator_cls=RayPPCommunicator) - logger.info("Using RayPPCommunicator " - "(which wraps vLLM _PP GroupCoordinator) " - "for Ray Compiled Graph communication.") + RayPPCommunicator, + ) + + register_accelerator_context( + torch_module_name="cuda", communicator_cls=RayPPCommunicator + ) + logger.info( + "Using RayPPCommunicator " + "(which wraps vLLM _PP GroupCoordinator) " + "for Ray Compiled Graph communication." + ) else: - logger.info("Using Ray's NCCL communicator for " - "Ray Compiled Graph communication.") + logger.info( + "Using Ray's NCCL communicator for Ray Compiled Graph communication." + ) return forward_dag.experimental_compile( enable_asyncio=enable_asyncio, - _overlap_gpu_communication=envs. - VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) + _overlap_gpu_communication=envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM, + ) def __del__(self): self.shutdown() async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> list[SamplerOutput]: if not self.use_ray_spmd_worker: return await super().execute_model_async(execute_model_req) @@ -648,14 +683,13 @@ async def execute_model_async( return self.output_decoder.decode(output) async def _driver_execute_model_async( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + self, execute_model_req: Optional[ExecuteModelRequest] = None + ) -> list[SamplerOutput]: assert not self.use_ray_spmd_worker, ( - "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1" + ) if not self.tp_driver_workers: - return await self.driver_exec_method("execute_model", - execute_model_req) + return await self.driver_exec_method("execute_model", execute_model_req) if self.pp_locks is None: # This locks each pipeline parallel stage so multiple virtual # engines can't execute on the same stage at the same time @@ -668,16 +702,25 @@ async def _driver_execute_model_async( tasks = [ asyncio.create_task( - _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], - "execute_model", execute_model_req)) + _run_task_with_lock( + self.driver_exec_method, + self.pp_locks[0], + "execute_model", + execute_model_req, + ) + ) ] - for pp_rank, driver_worker in enumerate(self.tp_driver_workers, - start=1): + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, start=1): tasks.append( asyncio.create_task( - _run_task_with_lock(driver_worker.execute_method.remote, - self.pp_locks[pp_rank], - "execute_model", execute_model_req))) + _run_task_with_lock( + driver_worker.execute_method.remote, + self.pp_locks[pp_rank], + "execute_model", + execute_model_req, + ) + ) + ) results = await asyncio.gather(*tasks) @@ -686,7 +729,8 @@ async def _driver_execute_model_async( async def _start_worker_execution_loop(self): assert not self.use_ray_spmd_worker, ( - "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1") + "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1" + ) coros = [ worker.execute_method.remote("start_worker_execution_loop") for worker in self.non_driver_workers diff --git 
a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 5b76334722e9..c3c8a70678ad 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -4,7 +4,7 @@ import os import time from collections import defaultdict -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import msgspec @@ -30,11 +30,13 @@ import ray from ray.util import placement_group_table from ray.util.placement_group import PlacementGroup + try: from ray._private.state import available_resources_per_node except ImportError: # Ray 2.9.x doesn't expose `available_resources_per_node` from ray._private.state import state as _state + available_resources_per_node = _state._available_resources_per_node class RayWorkerWrapper(WorkerWrapperBase): @@ -49,27 +51,28 @@ def __init__(self, *args, **kwargs) -> None: # that thread. self.compiled_dag_cuda_device_set = False - self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, - dec_hook=decode_hook) + self.input_decoder = msgspec.msgpack.Decoder( + ExecuteModelRequest, dec_hook=decode_hook + ) self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) def get_node_ip(self) -> str: return get_ip() - def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + def get_node_and_gpu_ids(self) -> tuple[str, list[int]]: node_id = ray.get_runtime_context().get_node_id() device_key = vllm.platforms.current_platform.ray_device_key if not device_key: - raise RuntimeError("current platform %s does not support ray.", - vllm.platforms.current_platform.device_name) - gpu_ids = ray.get_runtime_context().get_accelerator_ids( - )[device_key] + raise RuntimeError( + "current platform %s does not support ray.", + vllm.platforms.current_platform.device_name, + ) + gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key] return node_id, gpu_ids def execute_model_spmd( - self, req_or_tuple: Union[bytes, - Tuple[bytes, - Optional[IntermediateTensors]]] + self, + req_or_tuple: Union[bytes, tuple[bytes, Optional[IntermediateTensors]]], ) -> bytes: """Execute model in SPMD fashion: used only when SPMD worker and compiled DAG are both enabled. @@ -94,8 +97,9 @@ def execute_model_spmd( current_platform.set_device(self.worker.device) self.compiled_dag_cuda_device_set = True - output = self.worker._execute_model_spmd(execute_model_req, - intermediate_tensors) + output = self.worker._execute_model_spmd( + execute_model_req, intermediate_tensors + ) # Pipeline model request and output to the next pipeline stage. 
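# Illustrative sketch (not from the patch): how an msgspec Encoder/Decoder
# pair with enc_hook/dec_hook, as wired up for RayWorkerWrapper above,
# round-trips a type msgpack cannot serialize natively. `complex` stands in
# for ExecuteModelRequest here; vLLM's own encode_hook/decode_hook handle
# the real request/tensor types.
import msgspec

def enc_hook(obj):
    if isinstance(obj, complex):
        return (obj.real, obj.imag)
    raise NotImplementedError(f"cannot encode objects of type {type(obj)}")

def dec_hook(typ, obj):
    if typ is complex:
        return complex(obj[0], obj[1])
    raise NotImplementedError(f"cannot decode objects of type {typ}")

encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(complex, dec_hook=dec_hook)
assert decoder.decode(encoder.encode(3 + 4j)) == 3 + 4j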
if isinstance(output, IntermediateTensors): output = serialized_req, output @@ -121,11 +125,12 @@ def setup_device_if_necessary(self): def execute_model_ray( self, - scheduler_output: Union["SchedulerOutput", - Tuple["SchedulerOutput", - "IntermediateTensors"]], - ) -> Union["ModelRunnerOutput", Tuple["SchedulerOutput", - "IntermediateTensors"]]: + scheduler_output: Union[ + "SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"] + ], + ) -> Union[ + "ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"] + ]: # This method is used by Ray Compiled Graph to execute the model, # and it needs a special logic of self.setup_device_if_necessary() self.setup_device_if_necessary() @@ -135,7 +140,8 @@ def execute_model_ray( else: scheduler_output, intermediate_tensors = scheduler_output, None output = self.worker.model_runner.execute_model( - scheduler_output, intermediate_tensors) + scheduler_output, intermediate_tensors + ) if isinstance(output, IntermediateTensors): output = scheduler_output, output elif not get_pp_group().is_last_rank: @@ -150,7 +156,7 @@ def execute_model_ray( output = output.get_output() return output - def override_env_vars(self, vars: Dict[str, str]): + def override_env_vars(self, vars: dict[str, str]): os.environ.update(vars) ray_import_err = None @@ -171,12 +177,15 @@ def ray_is_available() -> bool: def assert_ray_available(): """Raise an exception if Ray is not available.""" if ray is None: - raise ValueError(f"Failed to import Ray: {ray_import_err}." - "Please install Ray with `pip install ray`.") + raise ValueError( + f"Failed to import Ray: {ray_import_err}." + "Please install Ray with `pip install ray`." + ) -def _verify_bundles(placement_group: "PlacementGroup", - parallel_config: ParallelConfig, device_str: str): +def _verify_bundles( + placement_group: "PlacementGroup", parallel_config: ParallelConfig, device_str: str +): """Verify a given placement group has bundles located in the right place. There are 2 rules. @@ -184,14 +193,15 @@ def _verify_bundles(placement_group: "PlacementGroup", - Fail if driver node is not included in a placement group. """ assert ray.is_initialized(), ( - "Ray is not initialized although distributed-executor-backend is ray.") + "Ray is not initialized although distributed-executor-backend is ray." + ) pg_data = placement_group_table(placement_group) # bundle_idx -> node_id bundle_to_node_ids = pg_data["bundles_to_node_id"] # bundle_idx -> bundle (e.g., {"GPU": 1}) bundles = pg_data["bundles"] # node_id -> List of bundle (e.g., {"GPU": 1}) - node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list) + node_id_to_bundle: dict[str, list[dict[str, float]]] = defaultdict(list) for bundle_idx, node_id in bundle_to_node_ids.items(): node_id_to_bundle[node_id].append(bundles[bundle_idx]) @@ -217,8 +227,13 @@ def _verify_bundles(placement_group: "PlacementGroup", "unless you have fast interconnect across nodes, like " "Infiniband. 
To resolve this issue, make sure you have more " "than %d GPUs available at each node.", - parallel_config.tensor_parallel_size, device_str, len(bundles), - device_str, node_id, parallel_config.tensor_parallel_size) + parallel_config.tensor_parallel_size, + device_str, + len(bundles), + device_str, + node_id, + parallel_config.tensor_parallel_size, + ) def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): @@ -250,7 +265,9 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): " and make sure the IP addresses used by ray cluster" " are the same as VLLM_HOST_IP environment variable" " specified in each node if you are running on a multi-node.", - int(time.time() - s), placement_group_specs) + int(time.time() - s), + placement_group_specs, + ) try: ray.get(pg_ready_ref, timeout=0) @@ -259,7 +276,8 @@ def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): "Cannot provide a placement group of " f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See " "`ray status` and `ray list nodes` to make sure the cluster has " - "enough resources.") from None + "enough resources." + ) from None def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): @@ -274,8 +292,9 @@ def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): # Exponential backoff for warning print. wait_interval *= 2 logger.info( - "Waiting for removing a placement group of specs for " - "%d seconds.", int(time.time() - s)) + "Waiting for removing a placement group of specs for %d seconds.", + int(time.time() - s), + ) time.sleep(wait_interval) @@ -306,19 +325,21 @@ def initialize_ray_cluster( except ConnectionError: logger.warning( "No existing RAY instance detected. " - "A new instance will be launched with current node resources.") - ray.init(address=ray_address, - num_gpus=parallel_config.world_size, - runtime_env=parallel_config.ray_runtime_env) + "A new instance will be launched with current node resources." + ) + ray.init( + address=ray_address, + num_gpus=parallel_config.world_size, + runtime_env=parallel_config.ray_runtime_env, + ) else: - ray.init(address=ray_address, - runtime_env=parallel_config.ray_runtime_env) + ray.init(address=ray_address, runtime_env=parallel_config.ray_runtime_env) device_str = current_platform.ray_device_key if not device_str: raise ValueError( - f"current platform {current_platform.device_name} does not " - "support ray.") + f"current platform {current_platform.device_name} does not support ray." + ) # Create or get the placement group for worker processes if parallel_config.placement_group: @@ -337,8 +358,8 @@ def initialize_ray_cluster( bundle_devices = bundle.get(device_str, 0) if bundle_devices > 1: raise ValueError( - "Placement group bundle cannot have more than 1 " - f"{device_str}.") + f"Placement group bundle cannot have more than 1 {device_str}." + ) if bundle_devices: device_bundles += 1 if parallel_config.world_size > device_bundles: @@ -346,10 +367,10 @@ def initialize_ray_cluster( f"The number of required {device_str}s exceeds the total " f"number of available {device_str}s in the placement group. " f"Required number of devices: {parallel_config.world_size}. " - f"Total number of devices: {device_bundles}.") + f"Total number of devices: {device_bundles}." + ) else: - logger.info("No current placement group found. " - "Creating a new placement group.") + logger.info("No current placement group found. 
Creating a new placement group.") num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) # Log a warning message and delay resource allocation failure response. # Avoid immediate rejection to allow user-initiated placement group @@ -357,12 +378,14 @@ def initialize_ray_cluster( if parallel_config.world_size > num_devices_in_cluster: logger.warning( "The number of required %ss exceeds the total " - "number of available %ss in the placement group.", device_str, - device_str) + "number of available %ss in the placement group.", + device_str, + device_str, + ) # Create a new placement group - placement_group_specs: List[Dict[str, float]] = ([{ - device_str: 1.0 - } for _ in range(parallel_config.world_size)]) + placement_group_specs: list[dict[str, float]] = [ + {device_str: 1.0} for _ in range(parallel_config.world_size) + ] # vLLM engine is also a worker to execute model with an accelerator, # so it requires to have the device in a current node. Check if @@ -375,14 +398,16 @@ def initialize_ray_cluster( f"Current node has no {device_str} available. " f"{current_node_resource=}. vLLM engine cannot start without " f"{device_str}. Make sure you have at least 1 {device_str} " - f"available in a node {current_node_id=} {current_ip=}.") + f"available in a node {current_node_id=} {current_ip=}." + ) # This way, at least bundle is required to be created in a current # node. placement_group_specs[0][f"node:{current_ip}"] = 0.001 # By default, Ray packs resources as much as possible. current_placement_group = ray.util.placement_group( - placement_group_specs, strategy="PACK") + placement_group_specs, strategy="PACK" + ) _wait_until_pg_ready(current_placement_group) assert current_placement_group is not None @@ -393,6 +418,7 @@ def initialize_ray_cluster( def get_num_tpu_nodes() -> int: from ray._private.accelerators import TPUAcceleratorManager + cluster_resources = ray.cluster_resources() total_tpus = int(cluster_resources["TPU"]) tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index d669592e75f1..8206f23d1878 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -4,7 +4,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from functools import cached_property from multiprocessing import Lock -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Union import torch import torch.distributed as dist @@ -14,8 +14,7 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import worker_receiver_cache_from_config -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - run_method) +from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.outputs import AsyncModelRunnerOutput @@ -25,14 +24,11 @@ class UniProcExecutor(ExecutorBase): - uses_ray: bool = False def _init_executor(self) -> None: - """Initialize the worker and load the model. 
- """ - self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, - rpc_rank=0) + """Initialize the worker and load the model.""" + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) distributed_init_method, rank, local_rank = self._distributed_args() is_driver_worker = True kwargs = dict( @@ -43,24 +39,24 @@ def _init_executor(self) -> None: is_driver_worker=is_driver_worker, ) self.mm_receiver_cache = worker_receiver_cache_from_config( - self.vllm_config, MULTIMODAL_REGISTRY, Lock()) + self.vllm_config, MULTIMODAL_REGISTRY, Lock() + ) self.async_output_thread: Optional[ThreadPoolExecutor] = None if self.max_concurrent_batches > 1: self.async_output_thread = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="WorkerAsyncOutput") + max_workers=1, thread_name_prefix="WorkerAsyncOutput" + ) - self.collective_rpc("init_worker", args=([kwargs], )) + self.collective_rpc("init_worker", args=([kwargs],)) self.collective_rpc("init_device") self.collective_rpc("load_model") def _distributed_args(self) -> tuple[str, int, int]: """Return (distributed_init_method, rank, local_rank).""" - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) + distributed_init_method = get_distributed_init_method(get_ip(), get_open_port()) # set local rank as the device index if specified - device_info = self.vllm_config.device_config.device.__str__().split( - ":") + device_info = self.vllm_config.device_config.device.__str__().split(":") local_rank = int(device_info[1]) if len(device_info) > 1 else 0 return distributed_init_method, 0, local_rank @@ -68,12 +64,14 @@ def _distributed_args(self) -> tuple[str, int, int]: def max_concurrent_batches(self) -> int: return 2 if self.scheduler_config.async_scheduling else 1 - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict] = None, - non_block: bool = False) -> List[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + ) -> list[Any]: if kwargs is None: kwargs = {} if self.mm_receiver_cache is not None and method == "execute_model": @@ -101,10 +99,13 @@ def check_health(self) -> None: return def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: self.driver_worker.reinitialize_distributed(reconfig_request) - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): self.shutdown() return @@ -132,15 +133,16 @@ class ExecutorWithExternalLauncher(UniProcExecutor): deterministic, all the engines will generate the same outputs, and they don't need to synchronize the states with each other. """ + uses_ray: bool = False def _init_executor(self) -> None: - """Initialize the worker and load the model. 
- """ + """Initialize the worker and load the model.""" if envs.VLLM_USE_V1: - assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ - ("To get deterministic execution in V1, " - "please set VLLM_ENABLE_V1_MULTIPROCESSING=0") + assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, ( + "To get deterministic execution in V1, " + "please set VLLM_ENABLE_V1_MULTIPROCESSING=0" + ) super()._init_executor() def _distributed_args(self) -> tuple[str, int, int]: @@ -156,7 +158,7 @@ def _distributed_args(self) -> tuple[str, int, int]: local_rank = int(os.environ["LOCAL_RANK"]) return distributed_init_method, rank, local_rank - def determine_num_available_blocks(self) -> Tuple[int, int]: + def determine_num_available_blocks(self) -> tuple[int, int]: """ Determine the number of available KV blocks. Add an additional all_reduce to get the min across all ranks. @@ -168,6 +170,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: """ a, b = super().determine_num_available_blocks() from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64) b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 09defade00dc..26ad37dda776 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -34,6 +34,7 @@ class BatchDescriptor(NamedTuple): items as minimal as possible to properly and uniquely describe the padded batch for cudagraph. """ + num_tokens: int uniform_decode: bool = False """ @@ -49,29 +50,30 @@ def non_uniform(self) -> "BatchDescriptor": return BatchDescriptor(self.num_tokens, uniform_decode=False) -def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor, - sequence_parallel_size: int) -> list[int]: - sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) // - sequence_parallel_size) +def _compute_sp_num_tokens( + num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int +) -> list[int]: + sp_tokens = ( + num_tokens_across_dp_cpu + sequence_parallel_size - 1 + ) // sequence_parallel_size sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size) return sp_tokens.tolist() -def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor, - sequence_parallel_size: int, - max_num_tokens: int, - chunk_idx: int) -> list[int]: - - sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, - sequence_parallel_size) +def _compute_chunked_local_num_tokens( + num_tokens_across_dp_cpu: torch.Tensor, + sequence_parallel_size: int, + max_num_tokens: int, + chunk_idx: int, +) -> list[int]: + sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu, sequence_parallel_size) sp_size = len(sp_tokens) local_size = [-1] * sp_size for i in range(sp_size): # Take into account sharding if MoE activation is sequence parallel. - local_size[i] = min(max_num_tokens, - sp_tokens[i] - (max_num_tokens * chunk_idx)) + local_size[i] = min(max_num_tokens, sp_tokens[i] - (max_num_tokens * chunk_idx)) if local_size[i] <= 0: local_size[i] = 1 # ensure lockstep even if done return local_size @@ -86,13 +88,15 @@ class DPMetadata: local_sizes: Optional[list[int]] = None @staticmethod - def num_tokens_across_dp(num_tokens: int, dp_size: int, - dp_rank: int) -> torch.Tensor: + def num_tokens_across_dp( + num_tokens: int, dp_size: int, dp_rank: int + ) -> torch.Tensor: """ Gather the num_tokens across all DP ranks and return results in a CPU tensor of size dp_size. 
""" from vllm.distributed.parallel_state import get_dp_group + device = current_platform.device_type group = get_dp_group().device_group @@ -102,14 +106,15 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int, # this optimization if we run into this case. if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: logger.info_once( - "Using CPU all reduce to syncronize DP padding between ranks.") + "Using CPU all reduce to syncronize DP padding between ranks." + ) device = "cpu" group = get_dp_group().cpu_group num_tokens_across_dp = [0] * dp_size num_tokens_across_dp[dp_rank] = num_tokens - num_tokens_tensor = torch.tensor(num_tokens_across_dp, - device=device, - dtype=torch.int32) + num_tokens_tensor = torch.tensor( + num_tokens_across_dp, device=device, dtype=torch.int32 + ) dist.all_reduce(num_tokens_tensor, group=group) return num_tokens_tensor.cpu() @@ -119,16 +124,19 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int, # When sp_size==1, this is just the cummulative num tokens across DP. def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor: num_tokens_across_sp_cpu = ( - (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size) - num_tokens_across_sp_cpu = ( - num_tokens_across_sp_cpu.repeat_interleave(sp_size)) + self.num_tokens_across_dp_cpu - 1 + sp_size + ) // sp_size + num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size) return torch.cumsum(num_tokens_across_sp_cpu, dim=0) @staticmethod def should_ubatch_across_dp( - should_ubatch: bool, orig_num_tokens_per_ubatch: int, - padded_num_tokens_per_ubatch: int, dp_size: int, - dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]: + should_ubatch: bool, + orig_num_tokens_per_ubatch: int, + padded_num_tokens_per_ubatch: int, + dp_size: int, + dp_rank: int, + ) -> tuple[bool, Optional[torch.Tensor]]: """ 1. Decides if each DP rank is going to microbatch. Either all ranks run with microbatching or none of them do. 
If this function decides @@ -154,6 +162,7 @@ def should_ubatch_across_dp( tensor[2][dp_rank] = 1 if should_ubatch else 0 from vllm.distributed.parallel_state import get_dp_group + dist.all_reduce(tensor, group=get_dp_group().device_group) result: bool = bool(torch.all(tensor[2] == 1).item()) @@ -166,8 +175,9 @@ def should_ubatch_across_dp( orig_min_num_tokens = int(orig_num_tokens_tensor.min().item()) padded_max_num_tokens = int(padded_num_tokens_tensor.max().item()) if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens): - logger.debug("Aborting ubatching %s %s", orig_min_num_tokens, - padded_max_num_tokens) + logger.debug( + "Aborting ubatching %s %s", orig_min_num_tokens, padded_max_num_tokens + ) return False, None return result, padded_num_tokens_tensor.cpu() @@ -176,35 +186,37 @@ def make( parallel_config: ParallelConfig, attn_metadata: Any, num_tokens: int, - num_tokens_across_dp_cpu: Optional[torch.Tensor] = None + num_tokens_across_dp_cpu: Optional[torch.Tensor] = None, ) -> "DPMetadata": - assert parallel_config.data_parallel_size > 1 dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank - if attn_metadata is not None and hasattr(attn_metadata, - "num_prefill_tokens"): + if attn_metadata is not None and hasattr(attn_metadata, "num_prefill_tokens"): # for v0 attention backends - batchsize = attn_metadata.num_prefill_tokens + \ - attn_metadata.num_decode_tokens + batchsize = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) else: # for v1 attention backends or no attn_metadata batchsize = num_tokens # If num_tokens_across_dp is None, it will be computed by all_reduce # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize - assert (num_tokens_across_dp_cpu is None - or num_tokens_across_dp_cpu[dp_rank] == batchsize - ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" + assert ( + num_tokens_across_dp_cpu is None + or num_tokens_across_dp_cpu[dp_rank] == batchsize + ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}" if num_tokens_across_dp_cpu is None: num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp( - batchsize, dp_size, dp_rank) + batchsize, dp_size, dp_rank + ) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu) return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu) @contextmanager - def chunked_sizes(self, sequence_parallel_size: int, - max_chunk_size_per_rank: int, chunk_idx: int): + def chunked_sizes( + self, sequence_parallel_size: int, max_chunk_size_per_rank: int, chunk_idx: int + ): """ Context manager to compute and temporarily set the per-rank local token sizes for a specific chunk during chunked forward execution. @@ -225,13 +237,16 @@ def chunked_sizes(self, sequence_parallel_size: int, we use SP between the layers to avoid redundant ops. We need this value to compute the chunked sizes. - max_chunk_size_per_rank: The max number of tokens each rank is + max_chunk_size_per_rank: The max number of tokens each rank is allowed to process in this chunk. chunk_idx: The index of the chunk to compute sizes for. """ self.local_sizes = _compute_chunked_local_num_tokens( - self.num_tokens_across_dp_cpu, sequence_parallel_size, - max_chunk_size_per_rank, chunk_idx) + self.num_tokens_across_dp_cpu, + sequence_parallel_size, + max_chunk_size_per_rank, + chunk_idx, + ) try: yield self.local_sizes finally: @@ -244,7 +259,8 @@ def sp_local_sizes(self, sequence_parallel_size: int): but without any chunking. 
""" self.local_sizes = _compute_sp_num_tokens( - self.num_tokens_across_dp_cpu, sequence_parallel_size) + self.num_tokens_across_dp_cpu, sequence_parallel_size + ) try: yield self.local_sizes finally: @@ -267,8 +283,11 @@ class ForwardContext: for each microbatch. Set dynamically for each forward pass """ - attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"], - list[dict[str, "AttentionMetadata"]]] + attn_metadata: Union[ + "AttentionMetadata", + dict[str, "AttentionMetadata"], + list[dict[str, "AttentionMetadata"]], + ] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass @@ -281,8 +300,9 @@ class ForwardContext: ubatch_slices: Optional[UBatchSlices] = None def __post_init__(self): - assert self.cudagraph_runtime_mode.valid_runtime_modes(), \ + assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" + ) _forward_context: Optional[ForwardContext] = None @@ -292,26 +312,29 @@ def get_forward_context() -> ForwardContext: """Get the current forward context.""" assert _forward_context is not None, ( "Forward context is not set. " - "Please use `set_forward_context` to set the forward context.") + "Please use `set_forward_context` to set the forward context." + ) return _forward_context def create_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - dp_metadata: Optional[DPMetadata] = None, - cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None, - ubatch_slices: Optional[UBatchSlices] = None): - return ForwardContext(no_compile_layers=vllm_config.compilation_config. - static_forward_context, - virtual_engine=virtual_engine, - attn_metadata=attn_metadata, - dp_metadata=dp_metadata, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices) + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + dp_metadata: Optional[DPMetadata] = None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: Optional[BatchDescriptor] = None, + ubatch_slices: Optional[UBatchSlices] = None, +): + return ForwardContext( + no_compile_layers=vllm_config.compilation_config.static_forward_context, + virtual_engine=virtual_engine, + attn_metadata=attn_metadata, + dp_metadata=dp_metadata, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ) @contextmanager @@ -331,14 +354,15 @@ def override_forward_context(forward_context: Optional[ForwardContext]): @contextmanager def set_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, - cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None, - ubatch_slices: Optional[UBatchSlices] = None): + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + batch_descriptor: Optional[BatchDescriptor] = None, + ubatch_slices: Optional[UBatchSlices] = None, +): """A context manager that stores the current forward context, can be attention metadata, etc. 
Here we can inject common logic for every model forward pass. @@ -350,15 +374,24 @@ def set_forward_context( dp_metadata: Optional[DPMetadata] = None if vllm_config.parallel_config.data_parallel_size > 1 and ( - attn_metadata is not None or num_tokens is not None): - dp_metadata = DPMetadata.make(vllm_config.parallel_config, - attn_metadata, num_tokens or 0, - num_tokens_across_dp) - - forward_context = create_forward_context(attn_metadata, vllm_config, - virtual_engine, dp_metadata, - cudagraph_runtime_mode, - batch_descriptor, ubatch_slices) + attn_metadata is not None or num_tokens is not None + ): + dp_metadata = DPMetadata.make( + vllm_config.parallel_config, + attn_metadata, + num_tokens or 0, + num_tokens_across_dp, + ) + + forward_context = create_forward_context( + attn_metadata, + vllm_config, + virtual_engine, + dp_metadata, + cudagraph_runtime_mode, + batch_descriptor, + ubatch_slices, + ) try: with override_forward_context(forward_context): @@ -368,8 +401,9 @@ def set_forward_context( if need_to_track_batchsize: if hasattr(attn_metadata, "num_prefill_tokens"): # for v0 attention backends - batchsize = attn_metadata.num_prefill_tokens + \ - attn_metadata.num_decode_tokens + batchsize = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) else: # for v1 attention backends batchsize = num_tokens @@ -377,13 +411,13 @@ def set_forward_context( # adding a sync point here should not affect # scheduling of the next batch from vllm.platforms import current_platform + synchronize = current_platform.synchronize if synchronize is not None: synchronize() now = time.perf_counter() # time measurement is in milliseconds - batchsize_forward_time[batchsize].append( - (now - forward_start_time) * 1000) + batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000) if now - last_logging_time > batchsize_logging_interval: last_logging_time = now forward_stats = [] @@ -396,6 +430,10 @@ def set_forward_context( forward_stats.append((bs, len(times), medium)) forward_stats.sort(key=lambda x: x[1], reverse=True) if forward_stats: - logger.info(("Batchsize forward time stats " - "(batchsize, count, median_time(ms)): %s"), - forward_stats) + logger.info( + ( + "Batchsize forward time stats " + "(batchsize, count, median_time(ms)): %s" + ), + forward_stats, + ) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 3f1cac531f45..d9aed70c9b97 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,12 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, - EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, - ProcessorInputs, PromptType, SingletonInputs, - SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, embeds_inputs, - to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) +from .data import ( + DataPrompt, + DecoderOnlyInputs, + EmbedsInputs, + EmbedsPrompt, + EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, + ProcessorInputs, + PromptType, + SingletonInputs, + SingletonPrompt, + TextPrompt, + TokenInputs, + TokensPrompt, + build_explicit_enc_dec_prompt, + embeds_inputs, + to_enc_dec_tuple_list, + token_inputs, + zip_enc_dec_prompts, +) __all__ = [ "DataPrompt", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 562e73eead66..c463723e5d0e 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -7,8 +7,11 @@ from typing_extensions import 
NotRequired, TypedDict, TypeIs, TypeVar if TYPE_CHECKING: - from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs, - MultiModalUUIDDict) + from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalInputs, + MultiModalUUIDDict, + ) class TextPrompt(TypedDict): @@ -134,23 +137,27 @@ class DataPrompt(TypedDict): def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]: - return (isinstance(prompt, dict) and "prompt_token_ids" in prompt - and "prompt_embeds" not in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" in prompt + and "prompt_embeds" not in prompt + ) def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]: - return (isinstance(prompt, dict) and "prompt_token_ids" not in prompt - and "prompt_embeds" in prompt) + return ( + isinstance(prompt, dict) + and "prompt_token_ids" not in prompt + and "prompt_embeds" in prompt + ) -_T1_co = TypeVar("_T1_co", - bound=SingletonPrompt, - default=SingletonPrompt, - covariant=True) -_T2_co = TypeVar("_T2_co", - bound=SingletonPrompt, - default=SingletonPrompt, - covariant=True) +_T1_co = TypeVar( + "_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True +) +_T2_co = TypeVar( + "_T2_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True +) # TODO: Make fields ReadOnly once mypy supports it @@ -309,8 +316,9 @@ def build_explicit_enc_dec_prompt( def zip_enc_dec_prompts( enc_prompts: Iterable[_T1], dec_prompts: Iterable[Optional[_T2]], - mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]], - dict[str, Any]]] = None, + mm_processor_kwargs: Optional[ + Union[Iterable[dict[str, Any]], dict[str, Any]] + ] = None, ) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]: """ Zip encoder and decoder prompts together into a list of @@ -329,20 +337,21 @@ def zip_enc_dec_prompts( encoder_prompt, decoder_prompt, cast(dict[str, Any], mm_processor_kwargs), - ) for (encoder_prompt, - decoder_prompt) in zip(enc_prompts, dec_prompts) + ) + for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts) ] return [ - build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, - mm_proc_kwargs) - for (encoder_prompt, decoder_prompt, mm_proc_kwargs - ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs) + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs) + for (encoder_prompt, decoder_prompt, mm_proc_kwargs) in zip( + enc_prompts, dec_prompts, mm_processor_kwargs + ) ] def to_enc_dec_tuple_list( enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], ) -> list[tuple[_T1, Optional[_T2]]]: - return [(enc_dec_prompt["encoder_prompt"], - enc_dec_prompt["decoder_prompt"]) - for enc_dec_prompt in enc_dec_prompts] + return [ + (enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) + for enc_dec_prompt in enc_dec_prompts + ] diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 123c81173120..2f7bd50df022 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -1,49 +1,33 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict, - Union, cast, overload) +from typing import TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict, Union, cast from typing_extensions import TypeIs from vllm.utils import is_list_of -from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs, - PromptType, SingletonInputs, 
SingletonPrompt, TextPrompt, - TokensPrompt) +from .data import ( + EmbedsPrompt, + ExplicitEncoderDecoderPrompt, + ProcessorInputs, + PromptType, + SingletonInputs, + SingletonPrompt, + TextPrompt, + TokensPrompt, +) if TYPE_CHECKING: import torch -class ParsedText(TypedDict): - content: str - is_tokens: Literal[False] - - -class ParsedTokens(TypedDict): - content: list[int] - is_tokens: Literal[True] - - -@overload -def parse_and_batch_prompt( - prompt: Union[str, list[str]], ) -> Sequence[ParsedText]: - ... - - -@overload -def parse_and_batch_prompt( - prompt: Union[list[int], list[list[int]]], ) -> Sequence[ParsedTokens]: - ... - - -def parse_and_batch_prompt( +def parse_raw_prompts( prompt: Union[str, list[str], list[int], list[list[int]]], -) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: +) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]: if isinstance(prompt, str): # case 1: a string - return [ParsedText(content=prompt, is_tokens=False)] + return [TextPrompt(prompt=prompt)] if isinstance(prompt, list): if len(prompt) == 0: @@ -52,13 +36,11 @@ def parse_and_batch_prompt( if is_list_of(prompt, str): # case 2: array of strings prompt = cast(list[str], prompt) - return [ - ParsedText(content=elem, is_tokens=False) for elem in prompt - ] + return [TextPrompt(prompt=elem) for elem in prompt] if is_list_of(prompt, int): # case 3: array of tokens prompt = cast(list[int], prompt) - return [ParsedTokens(content=prompt, is_tokens=True)] + return [TokensPrompt(prompt_token_ids=prompt)] if is_list_of(prompt, list): prompt = cast(list[list[int]], prompt) if len(prompt[0]) == 0: @@ -66,13 +48,12 @@ def parse_and_batch_prompt( if is_list_of(prompt[0], int): # case 4: array of token arrays - return [ - ParsedTokens(content=elem, is_tokens=True) - for elem in prompt - ] + return [TokensPrompt(prompt_token_ids=elem) for elem in prompt] - raise TypeError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") + raise TypeError( + "prompt must be a string, array of strings, " + "array of tokens, or array of token arrays" + ) class ParsedStrPrompt(TypedDict): @@ -95,28 +76,9 @@ class ParsedEmbedsPrompt(TypedDict): content: EmbedsPrompt -ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt, - ParsedTokensPrompt, ParsedEmbedsPrompt] - - -@overload -def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt: - ... - - -@overload -def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt: - ... - - -@overload -def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt: - ... - - -@overload -def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt: - ... +ParsedSingletonPrompt = Union[ + ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, ParsedEmbedsPrompt +] def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: @@ -126,19 +88,19 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: # Type ignores are because mypy does not correctly infer the TypedDicts # Pyright does succeed. 
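# Illustrative sketch (not from the patch): expected behaviour of the new
# parse_raw_prompts above. Raw strings become TextPrompt dicts and raw token
# lists become TokensPrompt dicts; both are TypedDicts, so they compare equal
# to plain dicts. The import path is assumed to be vllm.inputs.parse.
from vllm.inputs.parse import parse_raw_prompts

assert parse_raw_prompts("hi") == [{"prompt": "hi"}]
assert parse_raw_prompts(["a", "b"]) == [{"prompt": "a"}, {"prompt": "b"}]
assert parse_raw_prompts([1, 2, 3]) == [{"prompt_token_ids": [1, 2, 3]}]
assert parse_raw_prompts([[1, 2], [3, 4]]) == [
    {"prompt_token_ids": [1, 2]},
    {"prompt_token_ids": [3, 4]},
]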
if "prompt_embeds" in prompt: - return ParsedEmbedsPrompt( - type="embeds", content=prompt) # type: ignore[typeddict-item] + return ParsedEmbedsPrompt(type="embeds", content=prompt) # type: ignore[typeddict-item] elif "prompt_token_ids" in prompt: - return ParsedTokensPrompt( - type="tokens", content=prompt) # type: ignore[typeddict-item] + return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore[typeddict-item] elif "prompt" in prompt: return ParsedTextPrompt(type="text", content=prompt) raise TypeError( - "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") + "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt" + ) def is_explicit_encoder_decoder_prompt( - prompt: PromptType, ) -> TypeIs[ExplicitEncoderDecoderPrompt]: + prompt: PromptType, +) -> TypeIs[ExplicitEncoderDecoderPrompt]: return isinstance(prompt, dict) and "encoder_prompt" in prompt @@ -165,7 +127,7 @@ def get_prompt_components(prompt: PromptType) -> PromptComponents: if isinstance(prompt, str): return PromptComponents(text=prompt) - if (encoder_prompt := prompt.get("encoder_prompt")): + if encoder_prompt := prompt.get("encoder_prompt"): return get_prompt_components(encoder_prompt) # type: ignore[arg-type] return PromptComponents( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 65460b46cb5a..00f30e483693 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -10,56 +10,77 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.cache import BaseMultiModalProcessorCache -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs, MultiModalUUIDDict) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalEncDecInputs, + MultiModalInputs, + MultiModalUUIDDict, +) from vllm.multimodal.processing import BaseMultiModalProcessor -from vllm.transformers_utils.tokenizer import AnyTokenizer - -from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, - EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, - ProcessorInputs, PromptType, SingletonInputs, - SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, - embeds_inputs, token_inputs) +from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs +from vllm.utils.jsontree import json_iter_leaves + +from .data import ( + DecoderOnlyInputs, + EmbedsInputs, + EmbedsPrompt, + EncoderDecoderInputs, + ExplicitEncoderDecoderPrompt, + ProcessorInputs, + PromptType, + SingletonInputs, + SingletonPrompt, + TextPrompt, + TokenInputs, + TokensPrompt, + embeds_inputs, + token_inputs, +) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) class InputPreprocessor: - def __init__( self, model_config: ModelConfig, - tokenizer: Optional[AnyTokenizer], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, ) -> None: super().__init__() self.model_config = model_config - self.tokenizer = tokenizer self.mm_registry = mm_registry self.mm_processor_cache = mm_processor_cache + if model_config.skip_tokenizer_init: + self.tokenizer = None + else: + self.tokenizer = init_tokenizer_from_configs(model_config) + def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: - raise ValueError("You cannot pass text prompts when " - "`skip_tokenizer_init` is True") + raise ValueError( + "You cannot pass text prompts when 
`skip_tokenizer_init` is True" + ) return self.tokenizer def get_bos_token_id(self) -> Optional[int]: if self.tokenizer is None: - logger.warning("Using None for BOS token id because tokenizer " - "is not initialized") + logger.warning( + "Using None for BOS token id because tokenizer is not initialized" + ) return None return self.tokenizer.bos_token_id def get_eos_token_id(self) -> Optional[int]: if self.tokenizer is None: - logger.warning("Using None for EOS token id because tokenizer " - "is not initialized") + logger.warning( + "Using None for EOS token id because tokenizer is not initialized" + ) return None return self.tokenizer.eos_token_id @@ -74,22 +95,26 @@ def get_decoder_start_token_id(self) -> Optional[int]: if not self.model_config.is_encoder_decoder: logger.warning_once( "Using None for decoder start token id because " - "this is not an encoder/decoder model.") + "this is not an encoder/decoder model." + ) return None if self.model_config is None or self.model_config.hf_config is None: logger.warning_once( "Using None for decoder start token id because " - "model config is not available.") + "model config is not available." + ) return None - dec_start_token_id = getattr(self.model_config.hf_config, - "decoder_start_token_id", None) + dec_start_token_id = getattr( + self.model_config.hf_config, "decoder_start_token_id", None + ) if dec_start_token_id is None: logger.warning_once( "Falling back on <BOS> for decoder start token " "id because decoder start token id is not " - "available.") + "available." + ) dec_start_token_id = self.get_bos_token_id() return dec_start_token_id @@ -159,8 +184,10 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): + if ( + len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id + ): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids return decoder_input_ids @@ -250,11 +277,15 @@ def _process_multimodal( mm_hashes = mm_input["mm_hashes"] # Validate that all mm items have a string as their hash - if not contains_only_strings(mm_hashes): + contains_only_strings = all( + isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes) + ) + if not contains_only_strings: raise ValueError( f"mm_hashes must contain only strings, got: {mm_hashes}. " "This is likely due to an incorrect custom implementation of " - "MultiModalProcessor.apply method.") + "MultiModalProcessor.apply method." + ) return mm_input @@ -263,8 +294,9 @@ def _process_embeds( parsed_content: EmbedsPrompt, ) -> EmbedsInputs: if not self.model_config.enable_prompt_embeds: - raise ValueError("You must set `--enable-prompt-embeds` to input " - "`prompt_embeds`.") + raise ValueError( + "You must set `--enable-prompt-embeds` to input `prompt_embeds`." + ) prompt_embeds = parsed_content["prompt_embeds"] @@ -276,24 +308,25 @@ def _process_embeds( prompt_embeds = prompt_embeds.squeeze(dim=0) if prompt_embeds.ndim != 2: - raise ValueError( - "prompt_embeds must be of shape (seq_len, hidden_size).") + raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).") # Tensors must be on CPU for serialization between processes # in the MsgpackEncoder. Casting to CPU here ensures that there is no # hidden device transfer in the critical path of generation. 
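# Illustrative sketch (not from the patch): the mm_hashes validation in
# _process_multimodal above only needs "every leaf of the nested structure is
# a string". iter_leaves below is a hand-rolled stand-in for
# vllm.utils.jsontree.json_iter_leaves (the real helper may differ), and the
# hash values are made up.
def iter_leaves(obj):
    if isinstance(obj, dict):
        for value in obj.values():
            yield from iter_leaves(value)
    elif isinstance(obj, (list, tuple)):
        for value in obj:
            yield from iter_leaves(value)
    else:
        yield obj

mm_hashes = {"image": ["b9f2c3", "0a1c7e"], "audio": ["77de91"]}
assert all(isinstance(leaf, str) for leaf in iter_leaves(mm_hashes))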
prompt_embeds = prompt_embeds.cpu() - return embeds_inputs(prompt_embeds=prompt_embeds, - cache_salt=parsed_content.get("cache_salt")) + return embeds_inputs( + prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt") + ) def _truncate_inputs( - self, - inputs: list[int], - tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]: - - if not tokenization_kwargs or "truncation" not in \ - tokenization_kwargs or self.tokenizer is None: + self, inputs: list[int], tokenization_kwargs: Optional[dict[str, Any]] = None + ) -> list[int]: + if ( + not tokenization_kwargs + or "truncation" not in tokenization_kwargs + or self.tokenizer is None + ): return inputs max_length = tokenization_kwargs["max_length"] @@ -311,18 +344,22 @@ def _process_tokens( mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> Union[TokenInputs, MultiModalInputs]: prompt_token_ids = self._truncate_inputs( - parsed_content["prompt_token_ids"], tokenization_kwargs) + parsed_content["prompt_token_ids"], tokenization_kwargs + ) inputs: Union[TokenInputs, MultiModalInputs] - if multi_modal_data := parsed_content.get("multi_modal_data"): + if self.model_config.is_multimodal_model: inputs = self._process_multimodal( prompt_token_ids, - multi_modal_data, + parsed_content.get("multi_modal_data", {}), parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) else: + if parsed_content.get("multi_modal_data"): + raise ValueError("This model does not support multimodal inputs") + inputs = token_inputs(prompt_token_ids) if cache_salt := parsed_content.get("cache_salt"): @@ -340,15 +377,18 @@ def _process_text( prompt_text = parsed_content["prompt"] inputs: Union[TokenInputs, MultiModalInputs] - if multi_modal_data := parsed_content.get("multi_modal_data"): + if self.model_config.is_multimodal_model: inputs = self._process_multimodal( prompt_text, - multi_modal_data, + parsed_content.get("multi_modal_data", {}), parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) else: + if parsed_content.get("multi_modal_data"): + raise ValueError("This model does not support multimodal inputs") + prompt_token_ids = self._tokenize_prompt( prompt_text, tokenization_kwargs=tokenization_kwargs, @@ -407,16 +447,20 @@ def _build_enc_dec_llm_inputs( encoder_inputs: SingletonInputs, decoder_inputs: Optional[SingletonInputs], ) -> EncoderDecoderInputs: - if (encoder_inputs["type"] == "embeds" - or decoder_inputs and decoder_inputs["type"] == "embeds"): - raise ValueError("Embedding inputs are not supported for encoder-" - "decoder models") + if ( + encoder_inputs["type"] == "embeds" + or decoder_inputs + and decoder_inputs["type"] == "embeds" + ): + raise ValueError( + "Embedding inputs are not supported for encoder-decoder models" + ) # Needed for mypy - encoder_inputs = cast(Union[TokenInputs, MultiModalInputs], - encoder_inputs) - decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]], - decoder_inputs) + encoder_inputs = cast(Union[TokenInputs, MultiModalInputs], encoder_inputs) + decoder_inputs = cast( + Optional[Union[TokenInputs, MultiModalInputs]], decoder_inputs + ) if decoder_inputs is None: if self.model_config.hf_config.model_type == "whisper": @@ -426,16 +470,18 @@ def _build_enc_dec_llm_inputs( # overridden by the audio features. 
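# Illustrative sketch (not from the patch): the decoder-start-token rule that
# _prepare_decoder_input_ids_for_generation above enforces, as a standalone
# helper. The token ids are made up.
def ensure_decoder_start(decoder_input_ids, decoder_start_token_id):
    if not decoder_input_ids or decoder_input_ids[0] != decoder_start_token_id:
        return [decoder_start_token_id] + list(decoder_input_ids)
    return list(decoder_input_ids)

assert ensure_decoder_start([], 2) == [2]
assert ensure_decoder_start([5, 7], 2) == [2, 5, 7]
assert ensure_decoder_start([2, 5, 7], 2) == [2, 5, 7]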
dec_token_ids = encoder_inputs["prompt_token_ids"].copy() else: - dec_token_ids = self._prepare_decoder_input_ids_for_generation( - None) + dec_token_ids = self._prepare_decoder_input_ids_for_generation(None) decoder_inputs = token_inputs(dec_token_ids) else: if "multi_modal_data" in decoder_inputs: - raise ValueError("Multi-modal decoder inputs of encoder-" - "decoder models are not supported yet") + raise ValueError( + "Multi-modal decoder inputs of encoder-" + "decoder models are not supported yet" + ) dec_token_ids = self._prepare_decoder_input_ids_for_generation( - decoder_inputs["prompt_token_ids"]) + decoder_inputs["prompt_token_ids"] + ) decoder_inputs["prompt_token_ids"] = dec_token_ids return EncoderDecoderInputs( @@ -452,10 +498,14 @@ def _split_enc_dec_mm_inputs( For encoder/decoder models only: Separate Encoder/Decoder inputs from a MultiModalEncDecInputs """ - if (inputs["type"] == "embeds" or decoder_inputs_to_override - and decoder_inputs_to_override["type"] == "embeds"): - raise ValueError("Embedding inputs are not supported for encoder-" - "decoder models") + if ( + inputs["type"] == "embeds" + or decoder_inputs_to_override + and decoder_inputs_to_override["type"] == "embeds" + ): + raise ValueError( + "Embedding inputs are not supported for encoder-decoder models" + ) # Needed for mypy inputs = cast( @@ -472,9 +522,11 @@ def _split_enc_dec_mm_inputs( if inputs["type"] == "multimodal": # Multimodal data inputs if "encoder_prompt_token_ids" not in inputs: - raise RuntimeError("You should register an encoder-decoder " - "multi-modal processor for encoder-decoder " - "models.") + raise RuntimeError( + "You should register an encoder-decoder " + "multi-modal processor for encoder-decoder " + "models." + ) inputs = cast(MultiModalEncDecInputs, inputs) encoder_inputs = token_inputs(inputs["encoder_prompt_token_ids"]) @@ -556,9 +608,9 @@ def _process_encoder_decoder_prompt( # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. 
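# Illustrative sketch (not from the patch): the explicit encoder/decoder
# prompt shape these branches consume. build_explicit_enc_dec_prompt and
# is_explicit_encoder_decoder_prompt are assumed importable from vllm.inputs
# and vllm.inputs.parse respectively; the prompt text is made up.
from vllm.inputs import build_explicit_enc_dec_prompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt

enc_dec_prompt = build_explicit_enc_dec_prompt(
    encoder_prompt={"prompt": "Translate to German: The cat sat."},
    decoder_prompt=None,
)
assert is_explicit_encoder_decoder_prompt(enc_dec_prompt)
assert not is_explicit_encoder_decoder_prompt("a plain decoder-only prompt")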
if self.model_config.is_multimodal_model: - encoder_inputs, decoder_inputs = ( - self._split_enc_dec_mm_inputs(encoder_inputs, - decoder_inputs)) + encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs( + encoder_inputs, decoder_inputs + ) else: # `cast` is needed for mypy, but not pyright inputs = self._prompt_to_llm_inputs( @@ -568,8 +620,7 @@ def _process_encoder_decoder_prompt( ) if self.model_config.is_multimodal_model: # Encoder-Decoder Multimodal model - encoder_inputs, decoder_inputs = ( - self._split_enc_dec_mm_inputs(inputs)) + encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs(inputs) else: encoder_inputs = inputs decoder_inputs = None @@ -581,8 +632,9 @@ def _build_decoder_only_llm_inputs( prompt_inputs: DecoderOnlyInputs, ) -> DecoderOnlyInputs: if "prompt_token_ids" in prompt_inputs: - prompt_inputs = cast(Union[TokenInputs, MultiModalInputs], - prompt_inputs) # Needed for mypy + prompt_inputs = cast( + Union[TokenInputs, MultiModalInputs], prompt_inputs + ) # Needed for mypy return prompt_inputs @@ -633,8 +685,9 @@ def preprocess( ) if is_explicit_encoder_decoder_prompt(prompt): - raise ValueError("Cannot pass encoder-decoder prompt " - "to decoder-only models") + raise ValueError( + "Cannot pass encoder-decoder prompt to decoder-only models" + ) # Decoder-only operation # `cast` is needed for mypy, but not pyright @@ -647,15 +700,3 @@ def preprocess( def clear_cache(self) -> None: if self.mm_processor_cache is not None: self.mm_processor_cache.clear_cache() - - -# Helper function to validate that a nested dictionary contains -# only strings or list of strings as the leaf values. -def contains_only_strings(obj: object): - if isinstance(obj, str): - return True - if isinstance(obj, list): - return all(isinstance(x, str) for x in obj) - if isinstance(obj, dict): - return all(contains_only_strings(v) for v in obj.values()) - return False diff --git a/vllm/logger.py b/vllm/logger.py index 2861e0f1686c..37e8495768c0 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Logging configuration for vLLM.""" + import datetime import json import logging @@ -22,8 +23,10 @@ VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX VLLM_LOGGING_STREAM = envs.VLLM_LOGGING_STREAM -_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " - "[%(fileinfo)s:%(lineno)d] %(message)s") +_FORMAT = ( + f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s " + "[%(fileinfo)s:%(lineno)d] %(message)s" +) _DATE_FORMAT = "%m-%d %H:%M:%S" DEFAULT_LOGGING_CONFIG = { @@ -50,7 +53,7 @@ }, }, "version": 1, - "disable_existing_loggers": False + "disable_existing_loggers": False, } @@ -119,7 +122,8 @@ def _configure_vllm_root_logger() -> None: "VLLM_CONFIGURE_LOGGING evaluated to false, but " "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH " "implies VLLM_CONFIGURE_LOGGING. Please enable " - "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.") + "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH." + ) if VLLM_CONFIGURE_LOGGING: logging_config = DEFAULT_LOGGING_CONFIG @@ -128,13 +132,16 @@ def _configure_vllm_root_logger() -> None: if not path.exists(VLLM_LOGGING_CONFIG_PATH): raise RuntimeError( "Could not load logging config. 
File does not exist: %s", - VLLM_LOGGING_CONFIG_PATH) + VLLM_LOGGING_CONFIG_PATH, + ) with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8") as file: custom_config = json.loads(file.read()) if not isinstance(custom_config, dict): - raise ValueError("Invalid logging config. Expected dict, got %s.", - type(custom_config).__name__) + raise ValueError( + "Invalid logging config. Expected dict, got %s.", + type(custom_config).__name__, + ) logging_config = custom_config for formatter in logging_config.get("formatters", {}).values(): @@ -168,7 +175,7 @@ def init_logger(name: str) -> _VllmLogger: def _trace_calls(log_path, root_dir, frame, event, arg=None): - if event in ['call', 'return']: + if event in ["call", "return"]: # Extract the filename, line number, function name, and the code object filename = frame.f_code.co_filename lineno = frame.f_lineno @@ -188,26 +195,29 @@ def _trace_calls(log_path, root_dir, frame, event, arg=None): last_filename = "" last_lineno = 0 last_func_name = "" - with open(log_path, 'a') as f: + with open(log_path, "a") as f: ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") - if event == 'call': - f.write(f"{ts} Call to" - f" {func_name} in {filename}:{lineno}" - f" from {last_func_name} in {last_filename}:" - f"{last_lineno}\n") + if event == "call": + f.write( + f"{ts} Call to" + f" {func_name} in {filename}:{lineno}" + f" from {last_func_name} in {last_filename}:" + f"{last_lineno}\n" + ) else: - f.write(f"{ts} Return from" - f" {func_name} in {filename}:{lineno}" - f" to {last_func_name} in {last_filename}:" - f"{last_lineno}\n") + f.write( + f"{ts} Return from" + f" {func_name} in {filename}:{lineno}" + f" to {last_func_name} in {last_filename}:" + f"{last_lineno}\n" + ) except NameError: # modules are deleted during shutdown pass return partial(_trace_calls, log_path, root_dir) -def enable_trace_function_call(log_file_path: str, - root_dir: Optional[str] = None): +def enable_trace_function_call(log_file_path: str, root_dir: Optional[str] = None): """ Enable tracing of every function call in code under `root_dir`. This is useful for debugging hangs or crashes. @@ -221,7 +231,8 @@ def enable_trace_function_call(log_file_path: str, logger.warning( "VLLM_TRACE_FUNCTION is enabled. It will record every" " function executed by Python. This will slow down the code. It " - "is suggested to be used for debugging hang or crashes only.") + "is suggested to be used for debugging hang or crashes only." 
+ ) logger.info("Trace frame log is saved to %s", log_file_path) if root_dir is None: # by default, this is the vllm root directory diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py index ad89638e1061..3a97000647d6 100644 --- a/vllm/logging_utils/dump_input.py +++ b/vllm/logging_utils/dump_input.py @@ -21,9 +21,10 @@ def prepare_object_to_dump(obj) -> str: if isinstance(obj, str): return f"'{obj}'" # Double quotes elif isinstance(obj, dict): - dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \ - for k, v in obj.items()}) - return f'{{{dict_str}}}' + dict_str = ", ".join( + {f"{str(k)}: {prepare_object_to_dump(v)}" for k, v in obj.items()} + ) + return f"{{{dict_str}}}" elif isinstance(obj, list): return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]" elif isinstance(obj, set): @@ -36,15 +37,14 @@ def prepare_object_to_dump(obj) -> str: elif isinstance(obj, torch.Tensor): # We only print the 'draft' of the tensor to not expose sensitive data # and to get some metadata in case of CUDA runtime crashed - return (f"Tensor(shape={obj.shape}, " - f"device={obj.device}," - f"dtype={obj.dtype})") - elif hasattr(obj, 'anon_repr'): + return f"Tensor(shape={obj.shape}, device={obj.device},dtype={obj.dtype})" + elif hasattr(obj, "anon_repr"): return obj.anon_repr() - elif hasattr(obj, '__dict__'): + elif hasattr(obj, "__dict__"): items = obj.__dict__.items() - dict_str = ', '.join([f'{str(k)}={prepare_object_to_dump(v)}' \ - for k, v in items]) + dict_str = ", ".join( + [f"{str(k)}={prepare_object_to_dump(v)}" for k, v in items] + ) return f"{type(obj).__name__}({dict_str})" else: # Hacky way to make sure we can serialize the object in JSON format @@ -54,18 +54,22 @@ def prepare_object_to_dump(obj) -> str: return repr(obj) -def dump_engine_exception(config: VllmConfig, - scheduler_output: SchedulerOutput, - scheduler_stats: Optional[SchedulerStats]): +def dump_engine_exception( + config: VllmConfig, + scheduler_output: SchedulerOutput, + scheduler_stats: Optional[SchedulerStats], +): # NOTE: ensure we can log extra info without risking raises # unexpected errors during logging with contextlib.suppress(Exception): _dump_engine_exception(config, scheduler_output, scheduler_stats) -def _dump_engine_exception(config: VllmConfig, - scheduler_output: SchedulerOutput, - scheduler_stats: Optional[SchedulerStats]): +def _dump_engine_exception( + config: VllmConfig, + scheduler_output: SchedulerOutput, + scheduler_stats: Optional[SchedulerStats], +): logger.error( "Dumping input data for V1 LLM engine (v%s) with config: %s, ", VLLM_VERSION, @@ -73,8 +77,7 @@ def _dump_engine_exception(config: VllmConfig, ) try: dump_obj = prepare_object_to_dump(scheduler_output) - logger.error("Dumping scheduler output for model execution: %s", - dump_obj) + logger.error("Dumping scheduler output for model execution: %s", dump_obj) if scheduler_stats: logger.error("Dumping scheduler stats: %s", scheduler_stats) except Exception: diff --git a/vllm/logging_utils/formatter.py b/vllm/logging_utils/formatter.py index 004b79f3ea6e..02ba308e1879 100644 --- a/vllm/logging_utils/formatter.py +++ b/vllm/logging_utils/formatter.py @@ -18,7 +18,6 @@ def __init__(self, fmt, datefmt=None, style="%"): self.root_dir = Path(__file__).resolve().parent.parent.parent def format(self, record): - def shrink_path(relpath: Path) -> str: """ Shortens a file path for logging display: @@ -62,8 +61,7 @@ def shrink_path(relpath: Path) -> str: abs_path = getattr(record, "pathname", None) if abs_path: 
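# Illustrative sketch (not from the patch): the tensor handling in
# prepare_object_to_dump above logs only metadata, never tensor contents.
# describe_tensor is a made-up stand-in, not vLLM's function.
import torch

def describe_tensor(t: torch.Tensor) -> str:
    return f"Tensor(shape={tuple(t.shape)}, device={t.device}, dtype={t.dtype})"

assert describe_tensor(torch.zeros(2, 3)) == (
    "Tensor(shape=(2, 3), device=cpu, dtype=torch.float32)"
)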
try: - relpath = Path(abs_path).resolve().relative_to( - self.root_dir) + relpath = Path(abs_path).resolve().relative_to(self.root_dir) except Exception: relpath = Path(record.filename) else: diff --git a/vllm/logging_utils/log_time.py b/vllm/logging_utils/log_time.py index 013dd144beaf..9e94f463711d 100644 --- a/vllm/logging_utils/log_time.py +++ b/vllm/logging_utils/log_time.py @@ -15,15 +15,17 @@ def logtime(logger, msg=None): """ def _inner(func): - @functools.wraps(func) def _wrapper(*args, **kwargs): start = time.perf_counter() result = func(*args, **kwargs) elapsed = time.perf_counter() - start - prefix = f"Function '{func.__module__}.{func.__qualname__}'" \ - if msg is None else msg + prefix = ( + f"Function '{func.__module__}.{func.__qualname__}'" + if msg is None + else msg + ) logger.debug("%s: Elapsed time %.7f secs", prefix, elapsed) return result diff --git a/vllm/logits_process.py b/vllm/logits_process.py index 48f7e7495b17..6ac30ae0028e 100644 --- a/vllm/logits_process.py +++ b/vllm/logits_process.py @@ -19,8 +19,8 @@ def get_bad_words_logits_processors( - bad_words: list[str], - tokenizer: AnyTokenizer) -> list[LogitsProcessor]: + bad_words: list[str], tokenizer: AnyTokenizer +) -> list[LogitsProcessor]: bad_words_ids: list[list[int]] = list() for bad_word in bad_words: @@ -31,15 +31,15 @@ def get_bad_words_logits_processors( prefix = " " if add_prefix_space else "" prompt = prefix + bad_word.lstrip() - prompt_token_ids = tokenizer.encode(text=prompt, - add_special_tokens=False) + prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False) # If no space at the beginning # or if prefix space produces a new word token if (not add_prefix_space) or ( - add_prefix_space - and prompt_token_ids[0] != bad_words_ids[-1][0] - and len(prompt_token_ids) == len(bad_words_ids[-1])): + add_prefix_space + and prompt_token_ids[0] != bad_words_ids[-1][0] + and len(prompt_token_ids) == len(bad_words_ids[-1]) + ): bad_words_ids.append(prompt_token_ids) return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)] @@ -78,8 +78,9 @@ def __call__( assert len(actual_prefix) == len(expected_prefix) is_match = tuple(actual_prefix) == tuple(expected_prefix) - last_token_bias[last_token_id] += (self._SMALLEST_LOGIT if is_match - else self._NEUTRAL_LOGIT) + last_token_bias[last_token_id] += ( + self._SMALLEST_LOGIT if is_match else self._NEUTRAL_LOGIT + ) logits = logits + self.word_bias + last_token_bias @@ -93,9 +94,9 @@ def _init_word_bias(self, logits: torch.FloatTensor) -> None: self._check_token_ids_bounds(vocab_size=vocab_size) - self.word_bias = torch.zeros((vocab_size, ), - dtype=torch.float, - device=logits.device) + self.word_bias = torch.zeros( + (vocab_size,), dtype=torch.float, device=logits.device + ) for bad_word_ids in self.bad_words_ids: if len(bad_word_ids) == 1: @@ -116,4 +117,5 @@ def _check_token_ids_bounds(self, vocab_size: int) -> None: f" but the following tokens" f" were specified as bad: {invalid_token_ids}." f" All token id values should be integers satisfying:" - f" 0 <= token_id < {vocab_size}.") + f" 0 <= token_id < {vocab_size}." 
+ ) diff --git a/vllm/logprobs.py b/vllm/logprobs.py index e58ca142c00a..2458e43c690f 100644 --- a/vllm/logprobs.py +++ b/vllm/logprobs.py @@ -16,6 +16,7 @@ class Logprob: rank: The vocab rank of chosen token (>=1) decoded_token: The decoded chosen token index """ + logprob: float rank: Optional[int] = None decoded_token: Optional[str] = None diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py index d3bb145dc7bf..4915ef85f4f7 100644 --- a/vllm/lora/layers/__init__.py +++ b/vllm/lora/layers/__init__.py @@ -2,18 +2,23 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.lora.layers.base import BaseLayerWithLoRA from vllm.lora.layers.column_parallel_linear import ( - ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA, + ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithLoRA, - MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithLoRA, - QKVParallelLinearWithShardedLoRA) + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, + QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, +) from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA from vllm.lora.layers.row_parallel_linear import ( - RowParallelLinearWithLoRA, RowParallelLinearWithShardedLoRA) + RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, +) from vllm.lora.layers.utils import LoRAMapping -from vllm.lora.layers.vocal_parallel_embedding import ( - VocabParallelEmbeddingWithLoRA) +from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA __all__ = [ "BaseLayerWithLoRA", diff --git a/vllm/lora/layers/base.py b/vllm/lora/layers/base.py index a80a033e39b4..753dc268a2ff 100644 --- a/vllm/lora/layers/base.py +++ b/vllm/lora/layers/base.py @@ -14,7 +14,6 @@ class BaseLayerWithLoRA(nn.Module): - def slice_lora_a( self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index ed294b0aedaf..d2f017c19ccd 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -8,10 +8,12 @@ from vllm.config.lora import LoRAConfig from vllm.distributed.utils import divide -# yapf: disable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + ReplicatedLinear, + RowParallelLinear, +) from vllm.platforms import current_platform from .base import BaseLayerWithLoRA @@ -19,7 +21,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): - def __init__(self, base_layer: LinearBase): super().__init__() self.base_layer = base_layer @@ -46,16 +47,20 @@ def create_lora_weights( lora_b_out_size = self.output_size elif isinstance(self.base_layer, ColumnParallelLinear): - lora_a_out_size = (lora_config.max_lora_rank if - not lora_config.fully_sharded_loras else divide( - lora_config.max_lora_rank, self.tp_size)) + lora_a_out_size = ( + lora_config.max_lora_rank + if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size) + ) lora_b_out_size = self.output_size elif isinstance(self.base_layer, RowParallelLinear): 
lora_a_out_size = lora_config.max_lora_rank - lora_b_out_size = (self.output_size if - not lora_config.fully_sharded_loras else divide( - self.output_size, self.tp_size)) + lora_b_out_size = ( + self.output_size + if not lora_config.fully_sharded_loras + else divide(self.output_size, self.tp_size) + ) else: raise NotImplementedError @@ -67,7 +72,9 @@ def create_lora_weights( self.input_size, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(self.n_slices)) + ) + for _ in range(self.n_slices) + ) self.lora_b_stacked = tuple( torch.zeros( max_loras, @@ -76,7 +83,9 @@ def create_lora_weights( lora_config.max_lora_rank, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(self.n_slices)) + ) + for _ in range(self.n_slices) + ) if lora_config.bias_enabled: lora_bias_out_size = lora_b_out_size self.lora_bias_stacked = tuple( @@ -86,8 +95,10 @@ def create_lora_weights( lora_bias_out_size, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(self.n_slices)) - self.output_slices = (self.lora_b_stacked[0].shape[2], ) + ) + for _ in range(self.n_slices) + ) + self.output_slices = (self.lora_b_stacked[0].shape[2],) def reset_lora(self, index: int): for s_index in range(self.n_slices): @@ -95,8 +106,9 @@ def reset_lora(self, index: int): self.lora_b_stacked[s_index][index] = 0 if self.lora_config.bias_enabled: # Make mypy happy - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) + self.lora_bias_stacked = cast( + tuple[torch.Tensor, ...], self.lora_bias_stacked + ) self.lora_bias_stacked[s_index][index] = 0 def set_lora( @@ -111,8 +123,9 @@ def set_lora( # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers # store weights in a tuple of size 1. These two layers will # override this function. 
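# --- Illustrative aside, not part of the patch --------------------------------
# The lora_a_stacked / lora_b_stacked buffers allocated above keep one
# (rank x in_features) A matrix and one (out_features x rank) B matrix per
# adapter slot; apply() then adds x @ A[slot].T @ B[slot].T on top of the base
# layer's output (the real code does this through fused Punica kernels).
# A minimal plain-PyTorch sketch of that data flow; the names and sizes below
# are invented for illustration, and the extra leading slot dimension of size 1
# used by the real buffers is dropped for brevity.
import torch

max_loras, rank, in_features, out_features = 4, 8, 32, 64
lora_a_stacked = torch.randn(max_loras, rank, in_features)   # A per slot
lora_b_stacked = torch.randn(max_loras, out_features, rank)  # B per slot

def apply_lora(base_out: torch.Tensor, x: torch.Tensor, slot: int,
               scale: float = 1.0) -> torch.Tensor:
    shrunk = x @ lora_a_stacked[slot].T                 # (tokens, rank)
    return base_out + scale * (shrunk @ lora_b_stacked[slot].T)

x = torch.randn(3, in_features)
base_out = x @ torch.randn(out_features, in_features).T
print(apply_lora(base_out, x, slot=1).shape)            # torch.Size([3, 64])
# -------------------------------------------------------------------------------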
- assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == - self.n_slices == 1) + assert ( + len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1 + ) self.reset_lora(index) if self.tp_size > 1: @@ -121,23 +134,24 @@ def set_lora( if lora_bias is not None: lora_bias = self.slice_bias(lora_bias) - self.lora_a_stacked[0][index, - 0, :lora_a.shape[0], :lora_a.shape[1]].copy_( - lora_a, non_blocking=True) - self.lora_b_stacked[0][index, - 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( - lora_b, non_blocking=True) + self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( + lora_a, non_blocking=True + ) + self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( + lora_b, non_blocking=True + ) if lora_bias is not None: - - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) + self.lora_bias_stacked = cast( + tuple[torch.Tensor, ...], self.lora_bias_stacked + ) assert len(self.lora_bias_stacked) - self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( - lora_bias, non_blocking=True) + self.lora_bias_stacked[0][index, 0, : lora_bias.shape[0]].copy_( + lora_bias, non_blocking=True + ) - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) # In transformers backend, x and output have extra batch dimension like @@ -147,10 +161,15 @@ def apply(self, output = output.flatten(0, 1) x = x.flatten(0, 1) - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_linear( - output, x, self.lora_a_stacked, self.lora_b_stacked, - self.lora_bias_stacked, 1.0, self.output_slices) + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, + 1.0, + self.output_slices, + ) if not current_platform.can_update_inplace(): output = lora_output @@ -158,7 +177,6 @@ def apply(self, @property def weight(self) -> torch.Tensor: - # unquantizedLinear if hasattr(self.base_layer, "weight"): return self.base_layer.weight diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index 6284576446c8..011d38157456 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -10,9 +10,11 @@ from vllm.config.lora import LoRAConfig from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.utils import divide -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, +) from vllm.platforms import current_platform from .base_linear import BaseLinearLayerWithLoRA @@ -20,12 +22,16 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): - """ - For `ColumnParallelLinearWithLoRA` or classes that inherit from + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. 
""" - assert (layer.n_slices == len(layer.lora_a_stacked) == len( - layer.lora_b_stacked) == len(layer.output_slices)) + assert ( + layer.n_slices + == len(layer.lora_a_stacked) + == len(layer.lora_b_stacked) + == len(layer.output_slices) + ) if layer.lora_bias_stacked is not None: assert layer.n_slices == len(layer.lora_bias_stacked) @@ -43,7 +49,8 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): ) shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( - buffers, x, layer.lora_a_stacked, 1.0) + buffers, x, layer.lora_a_stacked, 1.0 + ) if not current_platform.can_update_inplace(): buffers = shrunk_buffers @@ -57,7 +64,8 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): layer.lora_bias_stacked, layer.output_slices, offset_start=0, - add_input=True) + add_input=True, + ) if not current_platform.can_update_inplace(): output = lora_output @@ -81,8 +89,7 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None: # The base_layer type is ColumnParallelLinear or # MergedColumnParallelLinear, their weight sharding logic is # inconsistent when TP is greater than 1. - self.is_merged_col_linear = type( - base_layer) is MergedColumnParallelLinear + self.is_merged_col_linear = type(base_layer) is MergedColumnParallelLinear self.output_size = self.base_layer.output_size_per_partition # There is only one LoRA layer self.n_slices = 1 @@ -97,10 +104,14 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: shard_size = self.output_size // 2 offset = lora_b.shape[0] // 2 - left_weight = lora_b[self.tp_rank * shard_size:(self.tp_rank + 1) * - shard_size, :] - right_weight = lora_b[offset + self.tp_rank * shard_size:offset + - (self.tp_rank + 1) * shard_size, :] + left_weight = lora_b[ + self.tp_rank * shard_size : (self.tp_rank + 1) * shard_size, : + ] + right_weight = lora_b[ + offset + self.tp_rank * shard_size : offset + + (self.tp_rank + 1) * shard_size, + :, + ] lora_b = torch.cat([left_weight, right_weight], dim=0) # Applicable to cases where the base_layer is # ColumnParallelLinear. @@ -133,8 +144,7 @@ def forward( - output - bias """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) + bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None # Matrix multiply. 
output_parallel = self.apply(input_, bias) @@ -147,8 +157,7 @@ def forward( if not self.base_layer.return_bias: return output - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None return output, output_bias @classmethod @@ -162,7 +171,8 @@ def can_replace_layer( ) -> bool: return type(source_layer) is ColumnParallelLinear or ( type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 1) + and len(packed_modules_list) == 1 + ) class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): @@ -175,17 +185,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ def __init__( - self, base_layer: Union[MergedColumnParallelLinear, - QKVParallelLinear]) -> None: + self, base_layer: Union[MergedColumnParallelLinear, QKVParallelLinear] + ) -> None: super().__init__(base_layer) # There are two LoRA layers # the output_sizes in MergedColumnParallelLinear is not sharded by tp # we need to divide it by the tp_size to get correct slices size output_sizes = self.base_layer.output_sizes self.output_slices = tuple( - divide(output_size, self.tp_size) for output_size in output_sizes) + divide(output_size, self.tp_size) for output_size in output_sizes + ) self.n_slices = len(self.output_slices) - self.output_ids = (self.tp_rank, ) * self.n_slices + self.output_ids = (self.tp_rank,) * self.n_slices def create_lora_weights( self, @@ -194,14 +205,16 @@ def create_lora_weights( model_config: Optional[PretrainedConfig] = None, ) -> None: """ - The main reason for overriding this function is to enhance code + The main reason for overriding this function is to enhance code maintainability. """ self.lora_config = lora_config lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) + lora_config.max_lora_rank + if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size) + ) self.lora_a_stacked = tuple( torch.zeros( @@ -211,7 +224,9 @@ def create_lora_weights( self.input_size, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(self.n_slices)) + ) + for _ in range(self.n_slices) + ) self.lora_b_stacked = tuple( torch.zeros( max_loras, @@ -220,7 +235,9 @@ def create_lora_weights( lora_config.max_lora_rank, dtype=lora_config.lora_dtype, device=self.device, - ) for output_size in self.output_slices) + ) + for output_size in self.output_slices + ) if lora_config.bias_enabled: self.lora_bias_stacked = tuple( torch.zeros( @@ -229,7 +246,9 @@ def create_lora_weights( output_size, dtype=lora_config.lora_dtype, device=self.device, - ) for output_size in self.output_slices) + ) + for output_size in self.output_slices + ) def slice_lora_a( self, lora_a: list[Union[torch.Tensor, None]] @@ -241,20 +260,22 @@ def slice_lora_b( ) -> list[Union[torch.Tensor, None]]: sliced_lora_b = [None] * self.n_slices for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): + zip(self.output_ids, self.output_slices) + ): if (lora_b_i := lora_b[i]) is not None: - sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size * - (shard_id + 1), :] + sliced_lora_b[i] = lora_b_i[ + shard_size * shard_id : shard_size * (shard_id + 1), : + ] return sliced_lora_b def slice_bias( - self, bias: list[Union[torch.Tensor, - None]]) -> list[Union[torch.Tensor, None]]: + self, bias: list[Union[torch.Tensor, 
None]] + ) -> list[Union[torch.Tensor, None]]: for i, (shard_id, shard_size) in enumerate( - zip(self.output_ids, self.output_slices)): + zip(self.output_ids, self.output_slices) + ): if (bias_i := bias[i]) is not None: - bias[i] = bias_i[shard_size * shard_id:shard_size * - (shard_id + 1)] + bias[i] = bias_i[shard_size * shard_id : shard_size * (shard_id + 1)] return bias def set_lora( @@ -276,22 +297,22 @@ def set_lora( for i in range(self.n_slices): if (lora_a_i := lora_a[i]) is not None: self.lora_a_stacked[i][ - index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_( - lora_a_i, non_blocking=True) + index, 0, : lora_a_i.shape[0], : lora_a_i.shape[1] + ].copy_(lora_a_i, non_blocking=True) if (lora_b_i := lora_b[i]) is not None: self.lora_b_stacked[i][ - index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_( - lora_b_i, non_blocking=True) + index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1] + ].copy_(lora_b_i, non_blocking=True) if lora_bias is not None: - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) + self.lora_bias_stacked = cast( + tuple[torch.Tensor, ...], self.lora_bias_stacked + ) for i in range(self.n_slices): if (lora_bias_i := lora_bias[i]) is not None: - self.lora_bias_stacked[i][index, - 0, :lora_bias_i.shape[0]].copy_( - lora_bias_i, - non_blocking=True) + self.lora_bias_stacked[i][index, 0, : lora_bias_i.shape[0]].copy_( + lora_bias_i, non_blocking=True + ) @classmethod @_not_fully_sharded_can_replace @@ -302,8 +323,10 @@ def can_replace_layer( packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: - return (type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 2) + return ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2 + ) class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): @@ -321,57 +344,70 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): def __init__(self, base_layer: QKVParallelLinear) -> None: super().__init__(base_layer) - self.q_proj_total_size = (self.base_layer.total_num_heads * - self.base_layer.head_size) - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) - self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * - self.base_layer.head_size) + self.q_proj_total_size = ( + self.base_layer.total_num_heads * self.base_layer.head_size + ) + self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size + self.kv_proj_shard_size = ( + self.base_layer.num_kv_heads * self.base_layer.head_size + ) + self.kv_proj_total_size = ( + self.base_layer.total_num_kv_heads * self.base_layer.head_size + ) # There is only one LoRA layer self.n_slices = 1 def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - self.q_shard_id = self.tp_rank self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas - lora_b_q = lora_b[self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1), :] + lora_b_q = lora_b[ + self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size + * (self.q_shard_id + 1), + :, + ] k_offset = self.q_proj_total_size - lora_b_k = lora_b[k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1), :] + lora_b_k = lora_b[ + k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1), + :, + ] 
v_offset = k_offset + self.kv_proj_total_size - lora_b_v = lora_b[v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1), :] + lora_b_v = lora_b[ + v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1), + :, + ] lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0) return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: - bias_q = bias[self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] + bias_q = bias[ + self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size + * (self.q_shard_id + 1) + ] k_offset = self.q_proj_total_size - bias_k = bias[k_offset + - self.kv_proj_shard_size * self.kv_shard_id:k_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias_k = bias[ + k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1) + ] v_offset = k_offset + self.kv_proj_total_size - bias_v = bias[v_offset + - self.kv_proj_shard_size * self.kv_shard_id:v_offset + - self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias_v = bias[ + v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1) + ] bias = torch.cat([bias_q, bias_k, bias_v], dim=1) return bias @classmethod @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is QKVParallelLinear and len( - packed_modules_list) == 1 + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1 class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): @@ -390,10 +426,10 @@ def __init__(self, base_layer: QKVParallelLinear) -> None: # There are three LoRA layer. self.n_slices = len(self.base_layer.output_sizes) - self.q_proj_shard_size = (self.base_layer.num_heads * - self.base_layer.head_size) - self.kv_proj_shard_size = (self.base_layer.num_kv_heads * - self.base_layer.head_size) + self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size + self.kv_proj_shard_size = ( + self.base_layer.num_kv_heads * self.base_layer.head_size + ) self.q_shard_id = self.tp_rank self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas @@ -415,7 +451,7 @@ def create_lora_weights( model_config: Optional[PretrainedConfig] = None, ) -> None: """ - The main reason for overloading this function is to handle inconsistent + The main reason for overloading this function is to handle inconsistent weight dimensions in qkv lora. 
""" super().create_lora_weights(max_loras, lora_config, model_config) @@ -429,8 +465,7 @@ def can_replace_layer( packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: - return (type(source_layer) is QKVParallelLinear - and len(packed_modules_list) == 3) + return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3 # These following layers are based on the tensor parallelism strategy given in @@ -453,12 +488,12 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: shard_size = self.lora_a_stacked[0].shape[2] start_idx = self.tp_rank * shard_size - lora_a = lora_a[start_idx:start_idx + shard_size, :] + lora_a = lora_a[start_idx : start_idx + shard_size, :] return lora_a - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -480,8 +515,7 @@ def can_replace_layer( ) -class MergedColumnParallelLinearWithShardedLoRA( - MergedColumnParallelLinearWithLoRA): +class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA): """ Differs from MergedColumnParallelLinearWithLoRA by slicing the LoRA A's also. @@ -492,20 +526,22 @@ class MergedColumnParallelLinearWithShardedLoRA( def slice_lora_a( self, lora_a: list[Union[torch.Tensor, None]] ) -> list[Union[torch.Tensor, None]]: - #NOTE: lora_a contains 2 subloras, and each sublora could be None. + # NOTE: lora_a contains 2 subloras, and each sublora could be None. output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size lora_a = [ - lora_a[0][output_start_idx:output_start_idx + - output_shard_size, :] if lora_a[0] is not None else None, - lora_a[1][output_start_idx:output_start_idx + - output_shard_size, :] if lora_a[1] is not None else None, + lora_a[0][output_start_idx : output_start_idx + output_shard_size, :] + if lora_a[0] is not None + else None, + lora_a[1][output_start_idx : output_start_idx + output_shard_size, :] + if lora_a[1] is not None + else None, ] return lora_a - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -538,19 +574,23 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: shard_size = self.lora_a_stacked[0].shape[2] start_idx = self.tp_rank * shard_size - lora_a = lora_a[start_idx:start_idx + shard_size, :] + lora_a = lora_a[start_idx : start_idx + shard_size, :] return lora_a - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig], + ) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( source_layer=source_layer, @@ -563,7 
+603,7 @@ def can_replace_layer(cls, source_layer: nn.Module, class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): """ - Differs from MergedQKVParallelLinearWithLoRA by slicing the + Differs from MergedQKVParallelLinearWithLoRA by slicing the LoRA A's also. Based on S-LoRA, slicing happens along the rank dim. @@ -576,18 +616,21 @@ def slice_lora_a( shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)] lora_a = [ - lora_a[0][start_idx[0]:start_idx[0] + - shard_size[0], :] if lora_a[0] is not None else None, - lora_a[1][start_idx[1]:start_idx[1] + - shard_size[1], :] if lora_a[1] is not None else None, - lora_a[2][start_idx[2]:start_idx[2] + - shard_size[2], :] if lora_a[2] is not None else None, + lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :] + if lora_a[0] is not None + else None, + lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :] + if lora_a[1] is not None + else None, + lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :] + if lora_a[2] is not None + else None, ] return lora_a - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index b8fbad3a4af0..4f30c9db4c67 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -9,11 +9,12 @@ from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.platforms import current_platform from .base import BaseLayerWithLoRA @@ -34,9 +35,14 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): no reindexing will be done. 
""" - def __init__(self, base_layer: LogitsProcessor, hidden_size: int, - dtype: torch.dtype, device: torch.device, - sharded_to_full_mapping: Optional[list[int]]) -> None: + def __init__( + self, + base_layer: LogitsProcessor, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + sharded_to_full_mapping: Optional[list[int]], + ) -> None: super().__init__() self.base_layer = base_layer self.hidden_size = hidden_size @@ -86,8 +92,9 @@ def create_lora_weights( ) -> None: # TODO: Verify if this condition can be further relaxed if 32000 < self.base_layer.vocab_size > 257024: - raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 257024") + raise ValueError( + "When using LoRA, vocab size must be 32000 >= vocab_size <= 257024" + ) self.lora_a_stacked = torch.zeros( ( max_loras, @@ -103,9 +110,10 @@ def create_lora_weights( max_loras, 1, # Pad for kernel compatibility - math.ceil(self.base_layer.vocab_size / - lora_config.lora_vocab_padding_size) * - lora_config.lora_vocab_padding_size, + math.ceil( + self.base_layer.vocab_size / lora_config.lora_vocab_padding_size + ) + * lora_config.lora_vocab_padding_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -119,9 +127,8 @@ def create_lora_weights( ) if self.sharded_to_full_mapping is not None: self.sharded_to_full_mapping_gpu = torch.tensor( - self.sharded_to_full_mapping, - device=self.device, - dtype=torch.long) + self.sharded_to_full_mapping, device=self.device, dtype=torch.long + ) else: self.sharded_to_full_mapping_gpu = None @@ -139,17 +146,17 @@ def set_lora( bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) - self.lora_a_stacked[index, - 0, :lora_a.shape[0], :lora_a.shape[1]].copy_( - lora_a, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( - lora_b, non_blocking=True) + self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( + lora_a, non_blocking=True + ) + self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( + lora_b, non_blocking=True + ) if embeddings_tensor is not None: self.embeddings_tensors[ index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], + : embeddings_tensor.shape[0], + : embeddings_tensor.shape[1], ] = embeddings_tensor def _get_logits( @@ -195,41 +202,41 @@ def _get_logits( dtype=self.embeddings_tensors.dtype, device=self.embeddings_tensors.device, ) - torch.matmul(self.embeddings_tensors, - hidden_states.T, - out=lora_logits[:-1]) + torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) - neg_inf, pos_inf = current_platform.get_infinity_values( - lora_logits.dtype) + neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype) lora_logits[-1] = neg_inf lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded if current_platform.is_tpu() or current_platform.is_xpu(): - indices_padded = indices_padded[:logits.size(0)] - - lora_logits = (lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, - posinf=pos_inf, - neginf=neg_inf)) + indices_padded = indices_padded[: logits.size(0)] + + lora_logits = ( + lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ) + .index_select(0, indices_padded) + .nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf) + ) - logits[:, - self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - 
lora_logits.shape[1]] = lora_logits + logits[ + :, + self.base_layer.org_vocab_size : self.base_layer.org_vocab_size + + lora_logits.shape[1], + ] = lora_logits - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_logits( - logits, hidden_states, self.lora_a_stacked, - self.lora_b_stacked, 1.0) + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0 + ) if not current_platform.can_update_inplace(): logits = lora_output # Remove paddings in vocab (if any). - logits = logits[:, :self.base_layer.vocab_size] + logits = logits[:, : self.base_layer.vocab_size] return logits def forward(self, *args, **kwargs): diff --git a/vllm/lora/layers/qkv_x_parallel_linear.py b/vllm/lora/layers/qkv_x_parallel_linear.py index 367482d0ee07..785cdf38e360 100644 --- a/vllm/lora/layers/qkv_x_parallel_linear.py +++ b/vllm/lora/layers/qkv_x_parallel_linear.py @@ -3,6 +3,6 @@ from .base import BaseLayerWithLoRA -#TODO: Implement this +# TODO: Implement this class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA): pass diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py index 18a8f13ed942..18a35cd1e0f2 100644 --- a/vllm/lora/layers/replicated_linear.py +++ b/vllm/lora/layers/replicated_linear.py @@ -14,9 +14,10 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): - def __init__(self, base_layer: ReplicatedLinear) -> None: - super().__init__(base_layer, ) + super().__init__( + base_layer, + ) # To ensure interface compatibility, set to 1 always. self.output_size = self.base_layer.output_size self.n_slices = 1 @@ -33,14 +34,12 @@ def forward( - output - bias """ - bias = (self.base_layer.bias - if not self.base_layer.skip_bias_add else None) + bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None # Matrix multiply. 
output = self.apply(input_, bias) - output_bias = (self.base_layer.bias - if self.base_layer.skip_bias_add else None) + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None if not self.base_layer.return_bias: return output diff --git a/vllm/lora/layers/row_parallel_linear.py b/vllm/lora/layers/row_parallel_linear.py index d468655e629a..738371f22a36 100644 --- a/vllm/lora/layers/row_parallel_linear.py +++ b/vllm/lora/layers/row_parallel_linear.py @@ -8,9 +8,10 @@ from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig -from vllm.distributed import (split_tensor_along_last_dim, - tensor_model_parallel_all_reduce) -# yapf: disable +from vllm.distributed import ( + split_tensor_along_last_dim, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.linear import RowParallelLinear from vllm.platforms import current_platform @@ -19,7 +20,6 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__(base_layer) @@ -30,11 +30,10 @@ def __init__(self, base_layer: RowParallelLinear) -> None: self.n_slices = 1 def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - shard_size = self.input_size start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size - lora_a = lora_a[:,start_idx:end_idx] + lora_a = lora_a[:, start_idx:end_idx] return lora_a def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: @@ -63,7 +62,8 @@ def forward( else: # TODO: simplify code below splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) + input_, num_partitions=self.tp_size + ) input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. @@ -74,8 +74,11 @@ def forward( output_ = output_parallel if not self.base_layer.skip_bias_add: - output = (output_ + self.base_layer.bias - if self.base_layer.bias is not None else output_) + output = ( + output_ + self.base_layer.bias + if self.base_layer.bias is not None + else output_ + ) output_bias = None else: output = output_ @@ -98,11 +101,11 @@ def can_replace_layer( return type(source_layer) is RowParallelLinear - # The following layer is based on the tensor parallelism strategy given in # Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, # https://arxiv.org/abs/2311.03285. 
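# --- Illustrative aside, not part of the patch --------------------------------
# For the row-parallel layer above, each tensor-parallel rank only sees a shard
# of the input features, so slice_lora_a() keeps the matching columns of LoRA A
# and the partial "shrink" results are summed across ranks (the all-reduce in
# apply()).  A self-contained PyTorch check of that identity, using invented
# sizes and no vLLM or torch.distributed APIs:
import torch

tp_size, in_features, rank = 4, 32, 8
x = torch.randn(5, in_features)
lora_a = torch.randn(rank, in_features)

full = x @ lora_a.T                          # single-GPU result
shard = in_features // tp_size
partial_sum = sum(
    x[:, r * shard:(r + 1) * shard] @ lora_a[:, r * shard:(r + 1) * shard].T
    for r in range(tp_size)
)                                            # per-rank partials, then "all-reduce"
print(torch.allclose(full, partial_sum, atol=1e-5))     # True
# -------------------------------------------------------------------------------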
+ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): """ Differs from RowParallelLinearWithLoRA by slicing the @@ -117,28 +120,26 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: shard_size = self.lora_b_stacked[0].shape[2] start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size - lora_b = lora_b[ start_idx:end_idx,:] + lora_b = lora_b[start_idx:end_idx, :] return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: if bias is None: return bias - self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], - self.lora_bias_stacked) + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) shard_size = self.lora_bias_stacked[0].shape[2] start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size bias = bias[start_idx:end_idx] return bias - def apply(self, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, - output.shape[-1]), output.shape + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape buffer = torch.zeros( (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), dtype=torch.float32, @@ -146,10 +147,11 @@ def apply(self, ) shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( - buffer, x, self.lora_a_stacked, 1.0) + buffer, x, self.lora_a_stacked, 1.0 + ) if not current_platform.can_update_inplace(): buffer = shrunk_buffer - if self.tp_size>1: + if self.tp_size > 1: buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py index 772d32a44c22..2da90f180ee7 100644 --- a/vllm/lora/layers/utils.py +++ b/vllm/lora/layers/utils.py @@ -45,8 +45,7 @@ def _not_fully_sharded_can_replace(can_replace): def dec(*args, **kwargs): decorate = kwargs.pop("decorate") if "decorate" in kwargs else True - condition = (not kwargs["lora_config"].fully_sharded_loras - if decorate else True) + condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True return can_replace(*args, **kwargs) and condition return dec @@ -59,7 +58,8 @@ def _fully_sharded_can_replace(can_replace): """ def dec(*args, **kwargs): - return (can_replace(*args, **kwargs) - and kwargs["lora_config"].fully_sharded_loras) + return ( + can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras + ) return dec diff --git a/vllm/lora/layers/vocal_parallel_embedding.py b/vllm/lora/layers/vocal_parallel_embedding.py index ca01c7e17fff..42eae1d4e3b0 100644 --- a/vllm/lora/layers/vocal_parallel_embedding.py +++ b/vllm/lora/layers/vocal_parallel_embedding.py @@ -9,15 +9,13 @@ from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.platforms import current_platform from .base import BaseLayerWithLoRA class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): - def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer @@ -25,24 +23,26 @@ def __init__(self, base_layer: VocabParallelEmbedding) -> None: self.embeddings_weights: 
Optional[torch.Tensor] def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> None: - + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: if self.base_layer.num_added_embeddings_per_partition > 0: # We can start adding lora weights self.embeddings_weights = self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:self. - base_layer.num_org_embeddings_per_partition + - self.base_layer.num_added_embeddings_per_partition] + self.base_layer.num_org_embeddings_per_partition : self.base_layer.num_org_embeddings_per_partition # noqa: E501 + + self.base_layer.num_added_embeddings_per_partition + ] self.embeddings_slice = ( - self.base_layer.shard_indices.added_vocab_start_index - - self.base_layer.org_vocab_size, - self.base_layer.shard_indices.added_vocab_end_index - - self.base_layer.org_vocab_size) + self.base_layer.shard_indices.added_vocab_start_index + - self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index + - self.base_layer.org_vocab_size, + ) self.base_layer.weight.data[ - self.base_layer.num_org_embeddings_per_partition:].fill_(0) + self.base_layer.num_org_embeddings_per_partition : + ].fill_(0) else: self.embeddings_slice = None self.embeddings_weights = None @@ -59,8 +59,7 @@ def create_lora_weights( self.lora_a_stacked = torch.zeros( ( max_loras, - self.base_layer.org_vocab_size + - lora_config.lora_extra_vocab_size, + self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -97,31 +96,30 @@ def set_lora( self.reset_lora(index) # NOTE self.lora_a_stacked is row-major, and lora_a is col-major, # so we need transpose here - self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( - lora_b, non_blocking=True) + self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True + ) + self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( + lora_b, non_blocking=True + ) if embeddings_tensor is not None: self.embeddings_tensors[ index, - :embeddings_tensor.shape[0], - :embeddings_tensor.shape[1], + : embeddings_tensor.shape[0], + : embeddings_tensor.shape[1], ].copy_(embeddings_tensor, non_blocking=True) if self.embeddings_slice is not None: # TODO(yard1): Optimize this copy, we don't need to copy # everything, just the modified part embeddings = self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * - self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1], self.embeddings_tensors.shape[2], - )[self.embeddings_slice[0]:self.embeddings_slice[1]] + )[self.embeddings_slice[0] : self.embeddings_slice[1]] assert self.embeddings_weights is not None - self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings) def forward(self, x: torch.Tensor) -> torch.Tensor: - added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, - 1, 0) + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0) # NB: Don't use torch.narrow here. 
torch.narrow triggers some # Dynamic Shape specialization in torch.compile @@ -133,26 +131,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x + indices_1, self.lora_a_stacked_2d, ) - full_output = self.base_layer.forward(x + - (indices_0 * added_tokens_mask)) + full_output = self.base_layer.forward(x + (indices_0 * added_tokens_mask)) full_output_org = full_output if full_output.ndim == 3: full_output = full_output.view( - full_output.shape[0] * full_output.shape[1], -1) + full_output.shape[0] * full_output.shape[1], -1 + ) if full_lora_a_embeddings.ndim == 3: full_lora_a_embeddings = full_lora_a_embeddings.view( - full_lora_a_embeddings.shape[0] * - full_lora_a_embeddings.shape[1], + full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1], -1, ) - lora_output: Optional[ - torch.Tensor] = self.punica_wrapper.add_lora_embedding( - full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, full_lora_a_embeddings, self.lora_b_stacked, add_input=True + ) if not current_platform.can_update_inplace(): full_output = lora_output diff --git a/vllm/lora/lora_weights.py b/vllm/lora/lora_weights.py index 90e18217d28b..d502c8eb543f 100644 --- a/vllm/lora/lora_weights.py +++ b/vllm/lora/lora_weights.py @@ -60,8 +60,9 @@ def is_packed(self) -> bool: @property def extra_vocab_size(self) -> int: - return self.embeddings_tensor.shape[ - 0] if self.embeddings_tensor is not None else 0 + return ( + self.embeddings_tensor.shape[0] if self.embeddings_tensor is not None else 0 + ) @classmethod def from_config( @@ -70,44 +71,54 @@ def from_config( peft_helper: PEFTHelper, embeddings_tensor: Optional[torch.Tensor] = None, ) -> "LoRALayerWeights": - return cls(module_name, peft_helper.r, peft_helper.lora_alpha, None, - None, None, embeddings_tensor, - peft_helper.vllm_lora_scaling_factor) + return cls( + module_name, + peft_helper.r, + peft_helper.lora_alpha, + None, + None, + None, + embeddings_tensor, + peft_helper.vllm_lora_scaling_factor, + ) @classmethod def create_dummy_lora_weights( - cls, - module_name: str, - input_dim: int, - output_dim: int, - rank: int, - dtype: torch.dtype, - device: torch.types.Device, - embeddings_tensor_dim: Optional[int] = None, - bias_enabled: Optional[bool] = False) -> "LoRALayerWeights": + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.types.Device, + embeddings_tensor_dim: Optional[int] = None, + bias_enabled: Optional[bool] = False, + ) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() - lora_a = torch.zeros([rank, input_dim], - dtype=dtype, - device=device, - pin_memory=pin_memory) - lora_b = torch.zeros([output_dim, rank], - dtype=dtype, - device=device, - pin_memory=pin_memory) + lora_a = torch.zeros( + [rank, input_dim], dtype=dtype, device=device, pin_memory=pin_memory + ) + lora_b = torch.zeros( + [output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory + ) if bias_enabled: - bias = torch.zeros([output_dim], - dtype=dtype, - device=device, - pin_memory=pin_memory) + bias = torch.zeros( + [output_dim], dtype=dtype, device=device, pin_memory=pin_memory + ) else: bias = None - embeddings_tensor = torch.rand( - 10, - embeddings_tensor_dim, - dtype=dtype, - device=device, - pin_memory=pin_memory) if embeddings_tensor_dim else None + embeddings_tensor = ( + torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, 
+ pin_memory=pin_memory, + ) + if embeddings_tensor_dim + else None + ) return cls( module_name, rank=rank, @@ -174,7 +185,8 @@ def pack( scaling=[ 1 if lora is not None else None # type: ignore for lora in loras - ]) + ], + ) return obj def optimize(self) -> "PackedLoRALayerWeights": diff --git a/vllm/lora/models.py b/vllm/lora/models.py index cc64cc78affa..edf34b483e9a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -17,10 +17,14 @@ from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.peft_helper import PEFTHelper from vllm.lora.punica_wrapper import get_punica_wrapper -from vllm.lora.utils import (from_layer, from_layer_logits_processor, - get_supported_lora_modules, - is_regex_target_modules, - parse_fine_tuned_lora_name, replace_submodule) +from vllm.lora.utils import ( + from_layer, + from_layer_logits_processor, + get_supported_lora_modules, + is_regex_target_modules, + parse_fine_tuned_lora_name, + replace_submodule, +) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import SupportsLoRA, supports_multimodal @@ -36,7 +40,6 @@ class AdapterLRUCache(LRUCache[int, T]): - def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]): super().__init__(capacity) self.deactivate_fn = deactivate_fn @@ -62,7 +65,8 @@ def is_moe_model(model: nn.Module) -> bool: logger.warning_once( "For MoE models, vLLM currently does not support fused MoE LoRA " "inference. Please ensure that the loaded LoRA model does not " - "contain expert weights.") + "contain expert weights." + ) return True return False @@ -85,9 +89,9 @@ def __init__( """ self.id = lora_model_id - assert ( - lora_model_id - > 0), f"a valid lora id should be greater than 0, got {self.id}" + assert lora_model_id > 0, ( + f"a valid lora id should be greater than 0, got {self.id}" + ) self.rank = rank self.loras: dict[str, LoRALayerWeights] = loras @@ -103,8 +107,11 @@ def clone(self, lora_model_id: int) -> "LoRAModel": @property def extra_vocab_size(self) -> int: - return max(lora.extra_vocab_size - for lora in self.loras.values()) if self.loras else 0 + return ( + max(lora.extra_vocab_size for lora in self.loras.values()) + if self.loras + else 0 + ) def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: """Get LoRA for a given module by name""" @@ -133,23 +140,24 @@ def from_lora_tensors( loras: dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( - tensor_name, weights_mapper) + tensor_name, weights_mapper + ) if module_name not in loras: lora_embeddings_tensor = None if embeddings: assert embedding_modules is not None embeddings_module = next( - (k for k in embedding_modules if k in module_name), - None) + (k for k in embedding_modules if k in module_name), None + ) if embeddings_module: lora_embeddings_tensor = embeddings[ - embedding_modules[embeddings_module]].to( - device=device, dtype=dtype) + embedding_modules[embeddings_module] + ].to(device=device, dtype=dtype) if pin_memory: - lora_embeddings_tensor = ( - lora_embeddings_tensor.pin_memory()) + lora_embeddings_tensor = lora_embeddings_tensor.pin_memory() loras[module_name] = LoRALayerWeights.from_config( - module_name, peft_helper, lora_embeddings_tensor) + module_name, peft_helper, lora_embeddings_tensor + ) if is_bias: loras[module_name].bias = tensor.to(device=device, dtype=dtype) @@ -158,26 +166,24 
@@ def from_lora_tensors( bias = bias.pin_memory() loras[module_name].bias = bias elif is_lora_a: - loras[module_name].lora_a = tensor.to(device=device, - dtype=dtype) + loras[module_name].lora_a = tensor.to(device=device, dtype=dtype) if pin_memory: - loras[module_name].lora_a = loras[ - module_name].lora_a.pin_memory() + loras[module_name].lora_a = loras[module_name].lora_a.pin_memory() else: - loras[module_name].lora_b = tensor.to(device=device, - dtype=dtype) + loras[module_name].lora_b = tensor.to(device=device, dtype=dtype) assert embedding_padding_modules is not None - if any(name in module_name - for name in embedding_padding_modules - ) and target_embedding_padding is not None: + if ( + any(name in module_name for name in embedding_padding_modules) + and target_embedding_padding is not None + ): lora_b = loras[module_name].lora_b assert target_embedding_padding >= lora_b.shape[0] addition = target_embedding_padding - lora_b.shape[0] loras[module_name].lora_b = torch.nn.functional.pad( - lora_b, (0, 0, 0, addition)) + lora_b, (0, 0, 0, addition) + ) if pin_memory: - loras[module_name].lora_b = loras[ - module_name].lora_b.pin_memory() + loras[module_name].lora_b = loras[module_name].lora_b.pin_memory() for lora in loras.values(): lora.optimize() @@ -186,19 +192,20 @@ def from_lora_tensors( @classmethod def from_local_checkpoint( - cls, - lora_dir: str, - expected_lora_modules: list[str], - peft_helper: PEFTHelper, - *, - lora_model_id: Optional[int] = None, - device: str = "cuda", - dtype: Optional[torch.dtype] = None, - target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[dict[str, str]] = None, - embedding_padding_modules: Optional[list[str]] = None, - weights_mapper: Optional[WeightsMapper] = None, - tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel": + cls, + lora_dir: str, + expected_lora_modules: list[str], + peft_helper: PEFTHelper, + *, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[dict[str, str]] = None, + embedding_padding_modules: Optional[list[str]] = None, + weights_mapper: Optional[WeightsMapper] = None, + tensorizer_config_dict: Optional[dict] = None, + ) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. Args: @@ -218,16 +225,17 @@ def from_local_checkpoint( lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") new_embeddings_tensor_path = os.path.join( - lora_dir, "new_embeddings.safetensors") - new_embeddings_bin_file_path = os.path.join(lora_dir, - "new_embeddings.bin") + lora_dir, "new_embeddings.safetensors" + ) + new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") tensors: dict[str, torch.Tensor] = {} unexpected_modules: list[Union[list[str], str]] = [] def check_unexpected_modules(modules: dict): for lora_module in modules.keys(): # noqa module_name, _, _ = parse_fine_tuned_lora_name( - lora_module, weights_mapper) + lora_module, weights_mapper + ) part_name = module_name.split(".")[-1] if part_name not in expected_lora_modules: unexpected_modules.append(module_name) @@ -236,19 +244,22 @@ def check_unexpected_modules(modules: dict): f"While loading {lora_dir}, expected" f" target modules in {expected_lora_modules}" f" but received {unexpected_modules}." 
- f" Please verify that the loaded LoRA module is correct") + f" Please verify that the loaded LoRA module is correct" + ) if tensorizer_config_dict: from tensorizer import TensorDeserializer tensorizer_config = TensorizerConfig(**tensorizer_config_dict) - lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir, - "adapter_model.tensors") + lora_tensor_path = os.path.join( + tensorizer_config.tensorizer_dir, "adapter_model.tensors" + ) tensorizer_args = tensorizer_config._construct_tensorizer_args() tensors = TensorDeserializer( lora_tensor_path, dtype=tensorizer_config.dtype, - **tensorizer_args.deserialization_kwargs) + **tensorizer_args.deserialization_kwargs, + ) check_unexpected_modules(tensors) elif os.path.isfile(lora_tensor_path): @@ -259,14 +270,12 @@ def check_unexpected_modules(modules: dict): # loraified. C won’t exist in the safetensor but it will exist in # the target_modules of the adapter_config.json. unexpected_modules = [] - with safetensors.safe_open(lora_tensor_path, - framework="pt") as f: # type: ignore + with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore # Load tensors if there are only expected modules. check_unexpected_modules(f) for module in f.keys(): # noqa tensors[module] = f.get_tensor(module) - elif os.path.isfile(lora_bin_file_path) or os.path.isfile( - lora_pt_file_path): + elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path): # When a bin/pt file is provided, we rely on config to find # unexpected modules. unexpected_modules = [] @@ -284,33 +293,33 @@ def check_unexpected_modules(modules: dict): # https://github.com/vllm-project/vllm/pull/5909. But there's no # other better mechanism. if unexpected_modules and not is_regex_target_modules( - peft_helper.target_modules, expected_lora_modules): + peft_helper.target_modules, expected_lora_modules + ): raise ValueError( f"While loading {lora_dir}, expected" f" target modules in {expected_lora_modules}" f" but received {unexpected_modules}." 
- f" Please verify that the loaded LoRA module is correct") - lora_file_path = (lora_bin_file_path - if os.path.isfile(lora_bin_file_path) else - lora_pt_file_path) - tensors = torch.load(lora_file_path, - map_location=device, - weights_only=True) + f" Please verify that the loaded LoRA module is correct" + ) + lora_file_path = ( + lora_bin_file_path + if os.path.isfile(lora_bin_file_path) + else lora_pt_file_path + ) + tensors = torch.load(lora_file_path, map_location=device, weights_only=True) else: raise ValueError(f"{lora_dir} doesn't contain tensors") embeddings = None if os.path.isfile(new_embeddings_tensor_path): - embeddings = safetensors.torch.load_file( - new_embeddings_tensor_path) + embeddings = safetensors.torch.load_file(new_embeddings_tensor_path) elif os.path.isfile(new_embeddings_bin_file_path): - embeddings = torch.load(new_embeddings_bin_file_path, - map_location=device, - weights_only=True) + embeddings = torch.load( + new_embeddings_bin_file_path, map_location=device, weights_only=True + ) return cls.from_lora_tensors( - lora_model_id=get_lora_id() - if lora_model_id is None else lora_model_id, + lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, tensors=tensors, peft_helper=peft_helper, device=device, @@ -319,7 +328,8 @@ def check_unexpected_modules(modules: dict): target_embedding_padding=target_embedding_padding, embedding_modules=embedding_modules, embedding_padding_modules=embedding_padding_modules, - weights_mapper=weights_mapper) + weights_mapper=weights_mapper, + ) class LoRAModelManager: @@ -374,7 +384,8 @@ def __init__( supports_multimodal(self.model) # In case the model only supports LoRA for # text modules (e.g. ChatGLM) - and hasattr(self.model, "get_mm_mapping")) + and hasattr(self.model, "get_mm_mapping") + ) self.is_pooling_model = is_pooling_model(self.model) self.is_moe_model = is_moe_model(self.model) self.packed_modules: dict[str, list[str]] = {} @@ -407,15 +418,21 @@ def activate_adapter( if lora_id in self._active_adapters: return False first_free_slot = next( - ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) - if lora_id is None), None) + ( + (i, lora_id) + for i, lora_id in enumerate(self.lora_index_to_id) + if lora_id is None + ), + None, + ) if first_free_slot is None: raise ValueError("No free lora slots") index, _ = first_free_slot self._active_adapters[lora_id] = None lora_model = self._registered_adapters[lora_id] - logger.debug("Activating LoRA. int id: %d, slot index: %d", - lora_model.id, index) + logger.debug( + "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index + ) self.lora_index_to_id[index] = lora_model.id for module_name, module in self.modules.items(): module_lora = self._get_lora_layer_weights(lora_model, module_name) @@ -423,17 +440,22 @@ def activate_adapter( module_lora.optimize() # Bias is not explicitly enabled with the flag enable_lora_bias. bias = module_lora.bias - if ((torch.is_tensor(bias) or - (isinstance(bias, Sequence) and any(b is not None - for b in bias))) - and not self.lora_config.bias_enabled): + if ( + torch.is_tensor(bias) + or (isinstance(bias, Sequence) and any(b is not None for b in bias)) + ) and not self.lora_config.bias_enabled: module_lora.bias = None raise ValueError( f"Adapter bias cannot be used for {module_name}" - " without --enable-lora-bias.") - module.set_lora(index, module_lora.lora_a, module_lora.lora_b, - module_lora.embeddings_tensor, - module_lora.bias) + " without --enable-lora-bias." 
+ ) + module.set_lora( + index, + module_lora.lora_a, + module_lora.lora_b, + module_lora.embeddings_tensor, + module_lora.bias, + ) else: module.reset_lora(index) return True @@ -453,7 +475,8 @@ def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" raise NotImplementedError( "Pinning is not supported in LoRAModelManager. " - "Use LRUCacheLoRAModelManager for pinning") # type: ignore + "Use LRUCacheLoRAModelManager for pinning" + ) # type: ignore def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: # update lora states @@ -472,16 +495,14 @@ def remove_all_adapters(self): self._active_adapters.clear() def _create_lora_modules(self): - def _parent_module(module_name: str) -> str: # module name is a dot separated name. # for example: # - given an input 'x.y.z' return 'x.y' # - given an input 'x' return '' - return module_name.rpartition('.')[0] + return module_name.rpartition(".")[0] - for module_name, module in self.model.named_modules( - remove_duplicate=False): + for module_name, module in self.model.named_modules(remove_duplicate=False): if isinstance(module, PPMissingLayer): continue if not self._match_target_modules(module_name): @@ -498,35 +519,48 @@ def _parent_module(module_name: str) -> str: parts = module_name.split(".")[-1] packed_moduled_lst = self.packed_modules_mapping.get(parts, []) new_module = replace_submodule( - self.model, module_name, - from_layer(module, self.lora_slots, self.lora_config, - packed_moduled_lst, self.model.config)) + self.model, + module_name, + from_layer( + module, + self.lora_slots, + self.lora_config, + packed_moduled_lst, + self.model.config, + ), + ) # (yard1): TODO make this more robust if "lm_head" in module_name: - logits_processor_module_name = 'logits_processor' + logits_processor_module_name = "logits_processor" parent_module = _parent_module(module_name) if parent_module: logits_processor_module_name = ( - f"{parent_module}.{logits_processor_module_name}") + f"{parent_module}.{logits_processor_module_name}" + ) logits_processor_module = self.model.get_submodule( - logits_processor_module_name) + logits_processor_module_name + ) new_module = replace_submodule( - self.model, logits_processor_module_name, - from_layer_logits_processor(logits_processor_module, - module, self.lora_slots, - self.lora_config, - self.model.config)) + self.model, + logits_processor_module_name, + from_layer_logits_processor( + logits_processor_module, + module, + self.lora_slots, + self.lora_config, + self.model.config, + ), + ) # In some models, especially multimodal ones, layers with the same # name may have different types, such as nn.Linear and # ReplicatedLinear. The nn.Linear layers cannot be replaced with # LoRA layers, leading to assertion error. 
The following check # aims to prevent this error - if self.supports_mm and not isinstance(new_module, - BaseLayerWithLoRA): + if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA): continue self.register_module(module_name, new_module) self._register_packed_modules(module_name) @@ -538,33 +572,41 @@ def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): self.modules[module_name] = module def create_dummy_lora( - self, - lora_id: int, - rank: int, - embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel: + self, + lora_id: int, + rank: int, + embedding_modules: Optional[dict[str, str]] = None, + ) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}) for module_name, module in self.model.named_modules(): bias_enabled = self.lora_config.bias_enabled - if (not self._match_target_modules(module_name) - or not isinstance(module, BaseLayerWithLoRA) - or self._filter_unsupported_mm_module(module_name)): + if ( + not self._match_target_modules(module_name) + or not isinstance(module, BaseLayerWithLoRA) + or self._filter_unsupported_mm_module(module_name) + ): continue parts = module_name.split(".") if module_name not in self.packed_modules: assert embedding_modules is not None if parts[-1] in embedding_modules: - input_dim = (module.base_layer.org_vocab_size + - self.lora_config.lora_extra_vocab_size if - hasattr(module.base_layer, "org_vocab_size") - else module.base_layer.weight.shape[1]) - output_dim = module.base_layer.embedding_dim if hasattr( - module.base_layer, - "embedding_dim") else module.base_layer.weight.shape[0] - embeddings_tensor_dim = (module.base_layer.embedding_dim if - hasattr(module.base_layer, - "embedding_dim") else - module.base_layer.weight.shape[1]) + input_dim = ( + module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size + if hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1] + ) + output_dim = ( + module.base_layer.embedding_dim + if hasattr(module.base_layer, "embedding_dim") + else module.base_layer.weight.shape[0] + ) + embeddings_tensor_dim = ( + module.base_layer.embedding_dim + if hasattr(module.base_layer, "embedding_dim") + else module.base_layer.weight.shape[1] + ) lora = LoRALayerWeights.create_dummy_lora_weights( module_name, input_dim, @@ -573,7 +615,8 @@ def create_dummy_lora( module.lora_a_stacked[0].dtype, "cpu", embeddings_tensor_dim=embeddings_tensor_dim, - bias_enabled=bias_enabled) + bias_enabled=bias_enabled, + ) else: lora = LoRALayerWeights.create_dummy_lora_weights( module_name, @@ -606,9 +649,11 @@ def create_dummy_lora( def _match_target_modules(self, module_name: str): return any( re.match( - r".*\.{target_module}$".format(target_module=target_module), - module_name) or target_module == module_name - for target_module in self.supported_lora_modules) + r".*\.{target_module}$".format(target_module=target_module), module_name + ) + or target_module == module_name + for target_module in self.supported_lora_modules + ) def _filter_unsupported_mm_module(self, module_name: str) -> bool: """ @@ -619,8 +664,7 @@ def _filter_unsupported_mm_module(self, module_name: str) -> bool: if self.supports_mm: module_mapping: MultiModelKeys = self.model.get_mm_mapping() prefix_lst = module_mapping.connector + module_mapping.tower_model - return any( - [module_name.startswith(prefix) for prefix in prefix_lst]) + return any([module_name.startswith(prefix) for prefix in prefix_lst]) return False def 
_register_packed_modules(self, module_full_name: str) -> None: @@ -654,23 +698,22 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: continue replacement_loras[i] = None # HACK Temporary solution for the pool model. - if self.is_pooling_model and not lora_model.check_lora_name( - module_name): + if self.is_pooling_model and not lora_model.check_lora_name(module_name): replaced_module_name = module_name.replace("model.", "") if lora_model.check_lora_name(module_name): module_name = replaced_module_name lora_model.loras[module_name] = PackedLoRALayerWeights.pack( - replacement_loras) + replacement_loras + ) # Remove the modules that have been replaced. for module in replaced_module: lora_model.loras.pop(module, None) def _get_lora_layer_weights( - self, lora_model: LoRAModel, - module_name: str) -> Optional[LoRALayerWeights]: + self, lora_model: LoRAModel, module_name: str + ) -> Optional[LoRALayerWeights]: org_module_name = module_name - if self.is_pooling_model and not lora_model.check_lora_name( - module_name): + if self.is_pooling_model and not lora_model.check_lora_name(module_name): # If it's a pool model, and the layer name is not found, # remove the prefix 'model.' and search again. module_name = module_name.replace("model.", "") @@ -678,7 +721,8 @@ def _get_lora_layer_weights( org_module_name = module_name logger.info_once( "For the pool model, successfully loaded the LoRA weights " - "after removing the prefix 'model.'.") + "after removing the prefix 'model.'." + ) return lora_model.get_lora(org_module_name) def deactivate_adapter(self, adapter_id: int) -> bool: @@ -689,8 +733,7 @@ def deactivate_adapter(self, adapter_id: int) -> bool: return True def add_adapter(self, adapter: LoRAModel) -> bool: - logger.debug("Adding lora. Model id: %d, " - "int id: %d", adapter.id, adapter.id) + logger.debug("Adding lora. 
Model id: %d, int id: %d", adapter.id, adapter.id) if adapter.id in self._registered_adapters: return False if len(self._registered_adapters) >= self.capacity: @@ -718,24 +761,31 @@ def get_adapter(self, adapter_id: int) -> Optional[LoRAModel]: class LoRALRUCache(AdapterLRUCache[LoRAModel]): - - def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], - bool]): + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]): super().__init__(capacity, deactivate_lora_fn) class LRUCacheLoRAModelManager(LoRAModelManager): """A model manager that manages multiple LoRAs with LRU cache.""" - def __init__(self, model: nn.Module, max_num_seqs: int, - max_num_batched_tokens: int, vocab_size: int, - lora_config: LoRAConfig, device: torch.device): - super().__init__(model, max_num_seqs, max_num_batched_tokens, - vocab_size, lora_config, device) + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + ): + super().__init__( + model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device + ) self._registered_adapters: LoRALRUCache = LoRALRUCache( - self.capacity, self.deactivate_adapter) + self.capacity, self.deactivate_adapter + ) self._active_adapters: LoRALRUCache = LoRALRUCache( - self.lora_slots, self._deactivate_adapter) + self.lora_slots, self._deactivate_adapter + ) def list_adapters(self) -> dict[int, LoRAModel]: """List all registered LoRAModels.""" @@ -743,8 +793,7 @@ def list_adapters(self) -> dict[int, LoRAModel]: def add_adapter(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" - logger.debug("Adding lora. Model id: %d, " - "int id: %d", lora.id, lora.id) + logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id) if lora.id not in self._registered_adapters: self._add_adapter(lora) was_added = True @@ -758,8 +807,10 @@ def activate_adapter( self, lora_id: int, ) -> bool: - if lora_id not in self._active_adapters and len( - self._active_adapters) >= self.lora_slots: + if ( + lora_id not in self._active_adapters + and len(self._active_adapters) >= self.lora_slots + ): self._active_adapters.remove_oldest() result = super().activate_adapter(lora_id) # We always touch to update the LRU cache order @@ -782,8 +833,9 @@ def _pin_lora_in_cpu_cache(self, lora_id: int): try: self._registered_adapters.pin(lora_id) except ValueError as err: - raise ValueError("Pinning failed. " - f"LoRA {lora_id} is not registered.") from err + raise ValueError( + f"Pinning failed. LoRA {lora_id} is not registered." 
+ ) from err def _pin_lora_in_gpu_cache(self, lora_id: int): if lora_id not in self._active_adapters: @@ -794,14 +846,15 @@ def _pin_lora_in_gpu_cache(self, lora_id: int): def create_lora_manager( - model: nn.Module, - max_num_seqs: int, - max_num_batched_tokens: int, - vocab_size: int, - lora_config: LoRAConfig, - device: torch.device, - lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, - **kwargs) -> LoRAModelManager: + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, + **kwargs, +) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" if not isinstance(model, SupportsLoRA): raise ValueError(f"Model {type(model)} is not supported for LoRA.") @@ -812,5 +865,6 @@ def create_lora_manager( vocab_size=vocab_size, lora_config=lora_config, device=device, - **kwargs) + **kwargs, + ) return lora_manager diff --git a/vllm/lora/ops/ipex_ops/__init__.py b/vllm/lora/ops/ipex_ops/__init__.py index 5daa432493b1..f5a5e0e6f951 100644 --- a/vllm/lora/ops/ipex_ops/__init__.py +++ b/vllm/lora/ops/ipex_ops/__init__.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.lora.ops.ipex_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink) +from vllm.lora.ops.ipex_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink __all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/ops/ipex_ops/lora_ops.py b/vllm/lora/ops/ipex_ops/lora_ops.py index 7590c868ecb6..0767f90b2f9e 100644 --- a/vllm/lora/ops/ipex_ops/lora_ops.py +++ b/vllm/lora/ops/ipex_ops/lora_ops.py @@ -13,32 +13,45 @@ raise e -def bgmv_shrink(inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0) -> None: - - ipex.llm.functional.bgmv_shrink(inputs, lora_a_weights, output_tensor, - lora_indices_tensor, scaling) - - -def bgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True) -> None: - ipex.llm.functional.bgmv_expand(inputs, lora_b_weights, output_tensor, - lora_indices_tensor, add_inputs) - - -def bgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True) -> None: - ipex.llm.functional.bgmv_expand_slice(inputs, lora_b_weights, - output_tensor, lora_indices_tensor, - slice_offset, slice_size, add_inputs) +def bgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> None: + ipex.llm.functional.bgmv_shrink( + inputs, lora_a_weights, output_tensor, lora_indices_tensor, scaling + ) + + +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> None: + ipex.llm.functional.bgmv_expand( + inputs, lora_b_weights, output_tensor, lora_indices_tensor, add_inputs + ) + + +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> None: + 
ipex.llm.functional.bgmv_expand_slice( + inputs, + lora_b_weights, + output_tensor, + lora_indices_tensor, + slice_offset, + slice_size, + add_inputs, + ) diff --git a/vllm/lora/ops/torch_ops/__init__.py b/vllm/lora/ops/torch_ops/__init__.py index 22aa3c63dce1..89865af4e9b8 100644 --- a/vllm/lora/ops/torch_ops/__init__.py +++ b/vllm/lora/ops/torch_ops/__init__.py @@ -1,10 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401 -from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, - sgmv_shrink) +from vllm.lora.ops.torch_ops.lora_ops import ( + bgmv_expand, # noqa: F401 + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink, +) __all__ = [ "bgmv_expand", diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py index cba5baad8668..4fc6248d5448 100644 --- a/vllm/lora/ops/torch_ops/lora_ops.py +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -4,30 +4,31 @@ import torch -def sgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - add_inputs: bool = False): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - seq_len_tensor) - - bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, - add_inputs) - - -def bgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs) + + +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +): + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) @@ -58,62 +59,70 @@ def sgmv_shrink( token_nums: int, scaling: float, ): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - seq_len_tensor) + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, - scaling) + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling) -def bgmv_shrink(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) +def bgmv_shrink( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +): + selected_loras = 
lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - - -def sgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - slice_size: int, - add_inputs: bool = False): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - seq_len_tensor) - - bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, - slice_offset, slice_size, add_inputs) - - -def bgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) + output_tensor[:, : outputs.shape[1]] = scaling * outputs[:] + + +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_expand_slice( + inputs, + lora_b_weights, + output_tensor, + exploded_indices, + slice_offset, + slice_size, + add_inputs, + ) + + +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +): + selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype) inputs = inputs.to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) if add_inputs: - output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:] else: - output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] + output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:] diff --git a/vllm/lora/ops/triton_ops/kernel_utils.py b/vllm/lora/ops/triton_ops/kernel_utils.py index e93064d0c83a..f6397a68ddb8 100644 --- a/vllm/lora/ops/triton_ops/kernel_utils.py +++ b/vllm/lora/ops/triton_ops/kernel_utils.py @@ -3,23 +3,35 @@ """ Utilities for Punica kernel construction. 
""" + from vllm.triton_utils import tl, triton @triton.jit -def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr, - b_dtype: tl.constexpr): +def mm_k( + a_ptr, + b_ptr, + ak_stride, + bk_stride, + offset_k, + K: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + CAST_TYPE: tl.constexpr, + b_dtype: tl.constexpr, +): """ Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of B (k x n), iterate, through the K dimension to compute the partial/complete matrix block product. If SPLIT_K == 1, the output m x n product is complete. If SPLIT_K > 1, the thread block computes partial outputs. The partial - outputs are then atomically summed in the caller code. + outputs are then atomically summed in the caller code. Args: - a_ptr: Array of pointers, identifying rows of A + a_ptr: Array of pointers, identifying rows of A b_ptr: Array of pointers, identifying columns of B ak_stride: K dimension stride of the A matrix bk_stride: K dimension stride of the B matrix @@ -29,7 +41,7 @@ def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, BLOCK_K: K dimension atom EVEN_K: True if the blocks of A and B can be loaded without any masking. - SPLIT_K: Parameter signifying parallelism in the K dimension. + SPLIT_K: Parameter signifying parallelism in the K dimension. CAST_TYPE: if True, cast the values from the A matrix to the B matrix dtype. b_dtype: datatype of the B matrix @@ -40,14 +52,12 @@ def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr, tiled_a = tl.load(a_ptr) tiled_b = tl.load(b_ptr) else: - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] - < K - k * (BLOCK_K * SPLIT_K), - other=0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] - < K - k * (BLOCK_K * SPLIT_K), - other=0) + tiled_a = tl.load( + a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0 + ) + tiled_b = tl.load( + b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0 + ) if CAST_TYPE: tiled_a = tiled_a.to(b_dtype) accumulator += tl.dot( @@ -121,7 +131,8 @@ def do_expand_kernel( else: cur_input_ptr = input_ptr + slice_id * input_d0_stride cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(out_ptr.dtype.element_ty)) + tl.pointer_type(out_ptr.dtype.element_ty) + ) # Identify the column indices of B to process. offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N @@ -129,17 +140,35 @@ def do_expand_kernel( # Identify A and B block pointers offset_k = tl.arange(0, BLOCK_K) - a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride) - b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + - offset_k[:, None] * cur_lora_d2_stride + - rbn[None, :] * cur_lora_d1_stride) + a_ptr = ( + cur_input_ptr + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride + ) + b_ptr = ( + cur_lora_ptr + + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride + ) # Compute the block matrix product. 
SPLIT_K = 1 - accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, - offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, - CAST_TYPE, cur_lora_ptr.dtype.element_ty) + accumulator = mm_k( + a_ptr, + b_ptr, + input_d2_stride, + cur_lora_d2_stride, + offset_k, + K, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + CAST_TYPE, + cur_lora_ptr.dtype.element_ty, + ) tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) if SLICE_NUM == 1: @@ -150,10 +179,12 @@ def do_expand_kernel( # Identify the C output pointers to store the results of the accumulator. offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start offset_cm = tl.arange(0, BLOCK_M) - c_ptr = (out_ptr + ram[:, None] * output_d0_stride + - offset_cn[None, :] * output_d1_stride) - c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] - < (cur_slice_start + N)) + c_ptr = ( + out_ptr + + ram[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride + ) + c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < (cur_slice_start + N)) if ADD_INPUTS: tiled_out = tl.load(c_ptr, mask=c_mask) @@ -207,7 +238,8 @@ def do_shrink_kernel( else: # current lora ptr cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(input_ptr.dtype.element_ty)) + tl.pointer_type(input_ptr.dtype.element_ty) + ) # Identify the column indices of B to process. offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N @@ -215,24 +247,42 @@ def do_shrink_kernel( # Identify A and B block pointers offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) - a_ptr = (input_ptr + ram[:, None] * input_d0_stride + - offset_k[None, :] * input_d1_stride) - b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + - rbn[None, :] * lora_d1_stride + - offset_k[:, None] * lora_d2_stride) + a_ptr = ( + input_ptr + ram[:, None] * input_d0_stride + offset_k[None, :] * input_d1_stride + ) + b_ptr = ( + cur_lora_ptr + + lora_d0_stride * lora_index + + rbn[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride + ) # Compute partial/complete block matrix product. - accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k, - K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False, - cur_lora_ptr.dtype.element_ty) + accumulator = mm_k( + a_ptr, + b_ptr, + input_d1_stride, + lora_d2_stride, + offset_k, + K, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + False, + cur_lora_ptr.dtype.element_ty, + ) # Identify the C output pointers to store the results of the accumulator. 
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N offset_cm = tl.arange(0, BLOCK_M) - cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + - slice_id * output_d0_stride) - c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[ - None, :] * output_d2_stride + cur_out_ptr = out_ptr if SLICE_NUM == 1 else out_ptr + slice_id * output_d0_stride + c_ptr = ( + cur_out_ptr + + ram[:, None] * output_d1_stride + + offset_cn[None, :] * output_d2_stride + ) c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N) accumulator *= scaling diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index 467cbaa8af48..a7a552b9903d 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -17,35 +17,35 @@ @triton.jit def _lora_expand_kernel( - input_ptr, - lora_ptr, - out_ptr, - M, - N, - K, - token_indices_sorted_by_lora_ids, - num_tokens_per_lora, - lora_token_start_loc, - lora_ids, - slice_start_loc, - input_d0_stride, - input_d1_stride, - input_d2_stride, # 1 - ls_d0_ptr, - ls_d1_ptr, - ls_d2_ptr, # 1 - output_d0_stride, - output_d1_stride, # 1 - output_hs_ptr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - ADD_INPUTS: tl.constexpr, - CAST_TYPE: tl.constexpr, - SLICE_NUM: tl.constexpr, - SAME_STRIDE: tl.constexpr): - + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, # 1 + output_hs_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr, +): cta_n_num = tl.cdiv(N, BLOCK_N) cta_m_num = tl.cdiv(M, BLOCK_M) @@ -81,8 +81,9 @@ def _lora_expand_kernel( # Identify all rows that this CTA should process. lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) - cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + - lora_m_indices_start + cta_m_offset) + cta_lora_seq_indices = ( + token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + ) # Load all relevant row indices. offset_m = tl.arange(0, BLOCK_M) % cta_m_len @@ -119,22 +120,21 @@ def _lora_expand_kernel( SLICE_NUM, EVEN_K, CAST_TYPE, - ADD_INPUTS) + ADD_INPUTS, + ) @torch.inference_mode() def _lora_expand( inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] - lora_b_weights: list[ - torch.Tensor], # shape [num_lora, hidden_size, lora_rank] - output_tensor: torch. 
- Tensor, # shape [num_tokens, hidden_size * num_slices] + lora_b_weights: list[torch.Tensor], # shape [num_lora, hidden_size, lora_rank] + output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices] token_lora_mapping: torch.Tensor, # shape [num_tokens] token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] - no_lora_flag_cpu: torch.Tensor, # shape [1] + no_lora_flag_cpu: torch.Tensor, # shape [1] offset_start: int = 0, add_inputs: bool = False, ) -> None: @@ -149,7 +149,7 @@ def _lora_expand( token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from the A matrix grouped by LoRA IDs. num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number - of tokens that are to be processed by LoRA ID lora_ids[i] + of tokens that are to be processed by LoRA ID lora_ids[i] lora_token_start_loc (torch.Tensor): A cumulative sum of num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] @@ -158,9 +158,9 @@ def _lora_expand( lora_ids (torch.Tensor): LoRA ids to process. no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates if there are any requests that require LoRA. - offset_start (int, optional): Offset start for output_tensor. + offset_start (int, optional): Offset start for output_tensor. Defaults to 0. - add_inputs (bool, optional): Whether to add the input tensor to the + add_inputs (bool, optional): Whether to add the input tensor to the output tensor. Defaults to False. """ @@ -179,15 +179,20 @@ def _lora_expand( # metadata sanity check. M = inputs.size(1) assert token_lora_mapping.size(0) == M - assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( - 0) + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0) assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 - (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, - lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor, - same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start, - inputs.device) + ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + hidden_sizes_tensor, + same_stride, + MAX_N, + ) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device) K = lora_b_weights[0].shape[-1] # K= rank ADD_INPUTS = add_inputs @@ -206,8 +211,8 @@ def _lora_expand( EVEN_K = K % BLOCK_K == 0 # type: ignore if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ - torch.float16, - torch.bfloat16, + torch.float16, + torch.bfloat16, ]: CAST_TYPE = True diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index e27604728ed0..df343305d710 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -30,39 +30,35 @@ class LoRAKernelMeta: no_lora_flag_cpu: torch.Tensor @staticmethod - def make(max_loras: int, max_num_tokens: int, - device: Union[torch.device, str]) -> "LoRAKernelMeta": - - token_lora_mapping = torch.empty(max_num_tokens, - dtype=torch.int32, - device=device) + def make( + max_loras: int, max_num_tokens: int, device: Union[torch.device, str] + ) -> "LoRAKernelMeta": + token_lora_mapping = torch.empty( + 
max_num_tokens, dtype=torch.int32, device=device + ) - token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens, - dtype=torch.int32, - device=device) + token_indices_sorted_by_lora_ids = torch.empty( + max_num_tokens, dtype=torch.int32, device=device + ) # +1 because "no-lora" is also a possibility # example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1] # is a possibility. - active_lora_ids = torch.empty(max_loras + 1, - dtype=torch.int32, - device=device) + active_lora_ids = torch.empty(max_loras + 1, dtype=torch.int32, device=device) # using running example, [3, 10, 5, 2] is a possibility. - num_tokens_per_lora = torch.zeros(max_loras + 1, - dtype=torch.int32, - device=device) + num_tokens_per_lora = torch.zeros( + max_loras + 1, dtype=torch.int32, device=device + ) # +2 for this because, the first index is always 0. # using running example, lora_token_start_loc # is [0, 3, 13, 18, 20]. - lora_token_start_loc = torch.zeros(max_loras + 2, - dtype=torch.int32, - device=device) + lora_token_start_loc = torch.zeros( + max_loras + 2, dtype=torch.int32, device=device + ) - no_lora_flag_cpu = torch.tensor([False], - dtype=torch.bool, - device='cpu') + no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu") return LoRAKernelMeta( token_lora_mapping=token_lora_mapping, @@ -70,7 +66,8 @@ def make(max_loras: int, max_num_tokens: int, active_lora_ids=active_lora_ids, num_tokens_per_lora=num_tokens_per_lora, lora_token_start_loc=lora_token_start_loc, - no_lora_flag_cpu=no_lora_flag_cpu) + no_lora_flag_cpu=no_lora_flag_cpu, + ) def _reset(self): self.active_lora_ids.fill_(-1) @@ -100,34 +97,44 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: num_tokens = token_lora_mapping.size(0) # copy token lora mapping - self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping, - non_blocking=True) + self.token_lora_mapping[:num_tokens].copy_( + token_lora_mapping, non_blocking=True + ) # token_indices_sorted_by_lora_ids - _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, - stable=True) + _, token_indices_sorted_by_lora_ids = torch.sort( + token_lora_mapping, stable=True + ) # start gpu transfer self.token_indices_sorted_by_lora_ids[:num_tokens].copy_( - token_indices_sorted_by_lora_ids, non_blocking=True) + token_indices_sorted_by_lora_ids, non_blocking=True + ) # active_lora_ids, num_tokens_per_lora - lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, - sorted=True, - return_counts=True) - self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, - non_blocking=True) - self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_( - num_tokens_per_lora, non_blocking=True) + lora_ids, num_tokens_per_lora = torch.unique( + token_lora_mapping, sorted=True, return_counts=True + ) + self.active_lora_ids[: lora_ids.size(0)].copy_(lora_ids, non_blocking=True) + self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_( + num_tokens_per_lora, non_blocking=True + ) # lora_token_start_loc lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) - self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_( - lora_token_start_loc, non_blocking=True) + self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_( + lora_token_start_loc, non_blocking=True + ) def meta_args( self, token_nums: int - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, torch.Tensor]: + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: """ 
This function returns the kernel metadata required for the current forward pass execution of the kernel. The function returns all the diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 57da93c226d2..1e7e43e30de7 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ @@ -16,16 +16,33 @@ @triton.jit -def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, - token_indices_sorted_by_lora_ids, num_tokens_per_lora, - lora_token_start_loc, lora_ids, scaling, - input_d0_stride, input_d1_stride, lora_d0_stride, - lora_d1_stride, lora_d2_stride, output_d0_stride, - output_d1_stride, output_d2_stride, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, - SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr): - +def _lora_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + input_d0_stride, + input_d1_stride, + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + output_d0_stride, + output_d1_stride, + output_d2_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr, +): cta_n_num = tl.cdiv(N, BLOCK_N) cta_m_num = tl.cdiv(M, BLOCK_M) @@ -54,8 +71,9 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, # Identify all rows that this CTA should process. lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) - cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + - lora_m_indices_start + cta_m_offset) + cta_lora_seq_indices = ( + token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + ) # Load all relevant row indices. offset_m = tl.arange(0, BLOCK_M) % cta_m_len @@ -90,17 +108,17 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, BLOCK_K, EVEN_K, SPLIT_K, - SLICE_NUM) + SLICE_NUM, + ) @torch.inference_mode() def _lora_shrink( inputs: torch.Tensor, # shape [num_tokens, hidden_size] - lora_a_weights: list[ - torch.Tensor], # shape [num_loras, lora_rank, hidden_size] + lora_a_weights: list[torch.Tensor], # shape [num_loras, lora_rank, hidden_size] output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] token_lora_mapping: torch.Tensor, # shape [num_tokens] - token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] @@ -118,7 +136,7 @@ def _lora_shrink( token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from the A matrix grouped by LoRA IDs. num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number - of tokens that are to be processed by LoRA ID lora_ids[i] + of tokens that are to be processed by LoRA ID lora_ids[i] lora_token_start_loc (torch.Tensor): A cumulative sum of num_tokens_per_lora. 
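# A minimal sketch with assumed values (not taken from the patch) of how this
# metadata relates to a token -> LoRA mapping; LoRAKernelMeta.prepare_tensors
# derives the same quantities with a stable sort, torch.unique and a cumsum.
import torch

token_lora_mapping = torch.tensor([0, 2, 0, -1, 1, 0, 2])  # 7 tokens

_, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, stable=True)
lora_ids, num_tokens_per_lora = torch.unique(
    token_lora_mapping, sorted=True, return_counts=True
)
# lora_ids            -> [-1, 0, 1, 2]
# num_tokens_per_lora -> [ 1, 3, 1, 2]

lora_token_start_loc = torch.zeros(lora_ids.numel() + 1, dtype=torch.long)
lora_token_start_loc[1:] = torch.cumsum(num_tokens_per_lora, dim=0)
# lora_token_start_loc -> [0, 1, 4, 5, 7]

# rows of the input that LoRA id lora_ids[i] must process
i = 1  # LoRA id 0 in this example
start = int(lora_token_start_loc[i])
count = int(num_tokens_per_lora[i])
rows = token_indices_sorted_by_lora_ids[start : start + count]
assert (token_lora_mapping[rows] == lora_ids[i]).all()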
lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] @@ -147,13 +165,13 @@ def _lora_shrink( # metadata sanity check M = inputs.size(0) assert token_lora_mapping.size(0) == M - assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( - 0) + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0) assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 - (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, - lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device) + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = ( + _get_lora_a_ptr(lora_a_weights, inputs.device) + ) N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank NUM_SLICES = len(lora_a_weights) MAX_LORAS = lora_ids.size(0) diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 4c50fbd27051..3a3e8fc8931e 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -9,9 +9,9 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): """ - `_LORA_A_PTR_DICT` collects the required information during `profile_run`, + `_LORA_A_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. - Refer to: + Refer to: https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py """ key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) @@ -35,14 +35,15 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): lora_strides_d1.append(lora_a_weight.stride(1)) lora_strides_d2.append(lora_a_weight.stride(2)) if len(lora_a_weights) > 1: - lora_ptr_tensor = torch.tensor(tensor_ptrs, - device=device, - dtype=torch.uint64) + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) else: lora_ptr_tensor = lora_a_weights[0] - if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1 - or len(set(lora_strides_d2)) > 1): + if ( + len(set(lora_strides_d0)) > 1 + or len(set(lora_strides_d1)) > 1 + or len(set(lora_strides_d2)) > 1 + ): raise ValueError("All LoRA weights must have the same stride.") _LORA_A_PTR_DICT[key] = ( @@ -54,12 +55,13 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): return _LORA_A_PTR_DICT.get(key) -def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, - device: torch.device): - """ - `_LORA_B_PTR_DICT` collects the required information during `profile_run`, +def _get_lora_b_ptr( + lora_weights: list[torch.Tensor], offset_start: int, device: torch.device +): + """ + `_LORA_B_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. 
- Refer to: + Refer to: https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py """ @@ -91,20 +93,21 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, if len(lora_weights) > 1: # note these are device tensors - lora_ptr_tensor = torch.tensor(tensor_ptrs, - device=device, - dtype=torch.uint64) - slice_start_tensor = torch.tensor(slice_offset_lst, - device=device, - dtype=torch.uint64) + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) + slice_start_tensor = torch.tensor( + slice_offset_lst, device=device, dtype=torch.uint64 + ) else: slice_start_tensor = slice_offset_lst[0] lora_ptr_tensor = lora_b_weight[0] # If each lora has the same stride, there's no need to use a # tensor for storage. - if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and - len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1: + if ( + len(set(lora_strides_d0)) == 1 + and len(set(lora_strides_d1)) == 1 + and len(set(lora_strides_d2)) == 1 + ) and len(set(hidden_sizes)) == 1: lora_strides_d0_tensor = lora_strides_d0[0] lora_strides_d1_tensor = lora_strides_d1[0] lora_strides_d2_tensor = lora_strides_d2[0] @@ -119,8 +122,14 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, same_stride = False # MAX_N is the maximum hidden size among all the lora_b weights MAX_N = max(hidden_sizes) - _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, - lora_strides_d0_tensor, lora_strides_d1_tensor, - lora_strides_d2_tensor, hidden_sizes_tensor, - same_stride, MAX_N) + _LORA_B_PTR_DICT[key] = ( + slice_start_tensor, + lora_ptr_tensor, + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + hidden_sizes_tensor, + same_stride, + MAX_N, + ) return _LORA_B_PTR_DICT.get(key) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 7e7c3c892457..b5570ceca68c 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink) +from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink __all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 29bfd5753a58..4924890b388c 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -33,8 +33,7 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): @impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") -def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor): +def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): T, _ = inputs.shape if len(loras.shape) == 4: loras = loras.squeeze(axis=1) @@ -73,13 +72,12 @@ def bgmv_expand( limit = 1 if output_tensor.shape[1] > outputs.shape[1]: - outputs = F.pad(outputs, - (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) + outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) if add_inputs: - return output_tensor + outputs[:limit, :output_tensor.shape[1]] + return output_tensor + outputs[:limit, : output_tensor.shape[1]] else: - return outputs[:limit, :output_tensor.shape[1]] + return outputs[:limit, : output_tensor.shape[1]] def bgmv_shrink( @@ -98,8 +96,7 @@ def 
bgmv_shrink( scaling (float, optional): Scalar multiplier applied to the output. """ - return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, - lora_indices_tensor) + return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) def bgmv_expand_slice( diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index dc7249c38602..48412eab92d8 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -18,9 +18,9 @@ @dataclass class PEFTHelper: - """ + """ A helper class for PEFT configurations, specifically designed for LoRA. - This class handles configuration validation, compatibility checks for + This class handles configuration validation, compatibility checks for various LoRA implementations. """ @@ -71,37 +71,38 @@ def from_dict(cls, config_dict: dict) -> "PEFTHelper": # Identify any missing required fields missing_fields = required_fields - set(config_dict.keys()) if missing_fields: - raise ValueError( - f"Missing required configuration fields: {missing_fields}") + raise ValueError(f"Missing required configuration fields: {missing_fields}") # Filter out fields that aren't defined in the class - filtered_dict = { - k: v - for k, v in config_dict.items() if k in class_fields - } + filtered_dict = {k: v for k, v in config_dict.items() if k in class_fields} return cls(**filtered_dict) @classmethod def from_local_dir( - cls, - lora_path: str, - max_position_embeddings: Optional[int], - tensorizer_config_dict: Optional[dict] = None) -> "PEFTHelper": + cls, + lora_path: str, + max_position_embeddings: Optional[int], + tensorizer_config_dict: Optional[dict] = None, + ) -> "PEFTHelper": lora_config_path = os.path.join(lora_path, "adapter_config.json") if tensorizer_config_dict: tensorizer_config = TensorizerConfig(**tensorizer_config_dict) tensorizer_args = tensorizer_config._construct_tensorizer_args() from tensorizer.stream_io import open_stream - lora_config_path = os.path.join(tensorizer_config.tensorizer_dir, - "adapter_config.json") - with open_stream(lora_config_path, - mode="rb", - **tensorizer_args.stream_kwargs) as f: + + lora_config_path = os.path.join( + tensorizer_config.tensorizer_dir, "adapter_config.json" + ) + with open_stream( + lora_config_path, mode="rb", **tensorizer_args.stream_kwargs + ) as f: config = json.load(f) - logger.info("Successfully deserialized LoRA config from %s", - tensorizer_config.tensorizer_dir) + logger.info( + "Successfully deserialized LoRA config from %s", + tensorizer_config.tensorizer_dir, + ) else: with open(lora_config_path) as f: @@ -112,16 +113,16 @@ def from_local_dir( def validate_legal(self, lora_config: LoRAConfig) -> None: """ - Validates the LoRA configuration settings against application + Validates the LoRA configuration settings against application constraints and requirements. """ error_msg = self._validate_features() if self.r > lora_config.max_lora_rank: error_msg.append( f"LoRA rank {self.r} is greater than max_lora_rank" - f" {lora_config.max_lora_rank}.") + f" {lora_config.max_lora_rank}." 
+ ) if self.bias != "none" and not lora_config.bias_enabled: - error_msg.append( - "Adapter bias cannot be used without bias_enabled.") + error_msg.append("Adapter bias cannot be used without bias_enabled.") if error_msg: raise ValueError(f"{' '.join(error_msg)}") diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index b3413de1c816..770c3cf7b073 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ @@ -81,39 +81,43 @@ def add_lora_embedding( **kwargs, ) -> Optional[torch.Tensor]: """ - Applies lora specifically for VocabParallelEmbeddingWithLoRA, + Applies lora specifically for VocabParallelEmbeddingWithLoRA, and this layer only requires the expand operation. """ raise NotImplementedError @abstractmethod - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, - **kwargs) -> Optional[torch.Tensor]: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs, + ) -> Optional[torch.Tensor]: """ - Applicable to linear-related lora. + Applicable to linear-related lora. """ raise NotImplementedError @abstractmethod - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> Optional[torch.Tensor]: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. """ @@ -122,41 +126,41 @@ def add_lora_logits(self, class PunicaWrapperBase(PunicaWrapperABC): """ - PunicaWrapperBase is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + PunicaWrapperBase is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica. 
""" - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - self._token_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._sampler_indices_padded = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - self._embeddings_indices = torch.empty(2, - max_num_batched_tokens, - dtype=torch.long, - device=device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + self._token_lora_indices = torch.empty( + max_num_batched_tokens, dtype=torch.long, device=device + ) + self._sampler_indices = torch.empty( + max_num_batched_tokens, dtype=torch.long, device=device + ) + self._sampler_indices_padded = torch.empty( + max_num_batched_tokens, dtype=torch.long, device=device + ) + self._embeddings_indices = torch.empty( + 2, max_num_batched_tokens, dtype=torch.long, device=device + ) # 4 is the number of indices tensors. # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices self.indices_len: list[Optional[int]] = [None] * 4 # these attributes are the information required for sgmv kernel - self._seq_start_locs = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._seq_lengths = torch.empty(max_batches, - dtype=torch.long, - device=device) - self._lora_indices_per_batch = torch.empty(max_batches, - dtype=torch.long, - device=device) + self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device) + self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device) + self._lora_indices_per_batch = torch.empty( + max_batches, dtype=torch.long, device=device + ) self.device: torch.device = device self.max_length: int = 0 self.token_nums: int = 0 @@ -186,28 +190,33 @@ def _update_base_metadata( extra_vocab_size, self.device, ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self._embeddings_indices[:embeddings_indices. 
- shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) + self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[: sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded + ) + self._embeddings_indices[ + : embeddings_indices.shape[0], : embeddings_indices.shape[1] + ].copy_(embeddings_indices) self.indices_len[:] = indices_len - def _update_prefill_metadata(self, - token_lora_tensor: torch.Tensor) -> None: - - (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, - no_lora) = compute_meta(token_lora_tensor) - - self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( - b_seq_start_tensor) - self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) - self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( - lora_indices_tensor) + def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: + ( + b_seq_start_tensor, + seq_length_tensor, + lora_indices_tensor, + batch_size, + max_length, + token_nums, + no_lora, + ) = compute_meta(token_lora_tensor) + + self._seq_start_locs[: b_seq_start_tensor.shape[0]].copy_(b_seq_start_tensor) + self._seq_lengths[: seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[: lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor + ) self.batch_size = batch_size self.max_length = max_length self.token_nums = token_nums @@ -240,35 +249,39 @@ def _apply_bias( bias = bias.view(-1, bias.shape[-1]) bias = bias[indices] bias[indices == -1] = 0 - output[:, offset_left:offset_left + slice] += bias + output[:, offset_left : offset_left + slice] += bias offset_left += slice return output.view_as(org_output) @property def prefill_metadata( - self + self, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: """ - This property provides a convenient way to access the necessary + This property provides a convenient way to access the necessary metadata for prefill-related kernel computations. 1. seq_start_locs: Tensor of sequence start positions. 2. seq_lengths: Tensor of sequence lengths. - 3. lora_indices_per_batch: Tensor of lora indices, and an index of + 3. lora_indices_per_batch: Tensor of lora indices, and an index of -1 means no lora should be applied. 4. batch_size: Batch size after clustering identical lora indices. 5. max_length: The maximum sequence length in the batch. 6. token_nums: The token numbers in the batch. """ - return (self._seq_start_locs[:self.batch_size], - self._seq_lengths[:self.batch_size], - self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length, self.token_nums) + return ( + self._seq_start_locs[: self.batch_size], + self._seq_lengths[: self.batch_size], + self._lora_indices_per_batch[: self.batch_size], + self.batch_size, + self.max_length, + self.token_nums, + ) @property def token_lora_indices(self) -> torch.Tensor: """ - This property provides the lora indices corresponding to each token + This property provides the lora indices corresponding to each token in the batch. An index of -1 means no lora should be applied. 
""" token_lora_len = self.indices_len[0] @@ -276,8 +289,8 @@ def token_lora_indices(self) -> torch.Tensor: @property def sampler_indices(self) -> torch.Tensor: - """ - This property is used to access the lora indices specifically for + """ + This property is used to access the lora indices specifically for LogitsProcessorWithLoRA. """ sampler_indices_len = self.indices_len[1] @@ -294,18 +307,24 @@ def sampler_indices_padded(self) -> torch.Tensor: @property def embeddings_indices(self) -> torch.Tensor: """ - This property provides access to the indices used for lora embeddings, + This property provides access to the indices used for lora embeddings, specifically for VocabParallelEmbeddingWithLoRA. """ embeddings_indices_len = self.indices_len[3] return self._embeddings_indices[:, :embeddings_indices_len] - def update_metadata(self, mapping: "LoRAMapping", - lora_index_to_id: list[Optional[int]], max_loras: int, - vocab_size: int, extra_vocab_size: int, **kwargs): - - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size) + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + **kwargs, + ): + self._update_base_metadata( + mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size + ) if mapping.is_prefill: # Update metadata required for prefill-related operators. @@ -315,16 +334,21 @@ def update_metadata(self, mapping: "LoRAMapping", self.is_prefill = False @abstractmethod - def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], - scale: float, **kwargs) -> Optional[torch.Tensor]: + def add_shrink( + self, + y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale - + Args: y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor @@ -336,31 +360,33 @@ def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], raise NotImplementedError @abstractmethod - def add_expand(self, - y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> Optional[torch.Tensor]: + def add_expand( + self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. - + Semantics: offset = offset_start for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] offset += slice - + Args: y (torch.Tensor): Output tensor. 
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight output_slices (tuple[int, ...]): Every slice's size offset_start (int): The starting position of y, defaults to 0 @@ -371,12 +397,14 @@ def add_expand(self, raise NotImplementedError @abstractmethod - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> Optional[torch.Tensor]: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. and this layer only requires the expand operation. @@ -393,19 +421,21 @@ def add_lora_embedding(self, raise NotImplementedError @abstractmethod - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, - **kwargs) -> Optional[torch.Tensor]: - """ - Applicable to linear-related lora. + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs, + ) -> Optional[torch.Tensor]: + """ + Applicable to linear-related lora. Semantics: for i in range(len(lora_a_stacked)): @@ -430,18 +460,20 @@ def add_lora_linear(self, raise NotImplementedError @abstractmethod - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> Optional[torch.Tensor]: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. - + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 59049cccc8cb..c51a13db873c 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -5,9 +5,14 @@ import torch -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.torch_ops import ( + bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink, +) from .punica_base import PunicaWrapperBase @@ -16,15 +21,19 @@ # inherit this class class PunicaWrapperCPU(PunicaWrapperBase): """ - PunicaWrapperCPU is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + PunicaWrapperCPU is designed to manage and provide metadata for the punica + kernel. 
The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the pytorch punica ops. """ - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, - device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) def _shrink_prefill( self, @@ -33,7 +42,7 @@ def _shrink_prefill( w_t_all: torch.Tensor, scale: float, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return sgmv_shrink( @@ -60,7 +69,7 @@ def _expand_prefill( w_t_all: torch.Tensor, add_inputs: bool, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return sgmv_expand( @@ -89,7 +98,7 @@ def _expand_slice_prefill( y_slice_size: int, add_inputs: bool, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return sgmv_expand_slice( @@ -111,8 +120,9 @@ def _expand_slice_decode( y_slice_size: int, add_inputs: bool, ): - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_inputs) + bgmv_expand_slice( + x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs + ) def _apply_expand( self, @@ -124,18 +134,19 @@ def _apply_expand( add_inputs: bool = True, ): """ - Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` computation, which is suitable for the GEMM of lora'b. """ - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) + expand_slice_fun: Callable = ( + self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode + ) expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) - def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, scale: float): + def _apply_shrink( + self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float + ): """ Perform the ` y+=x@w_t_all` computation, which is suitable for the GEMM of lora'a. @@ -146,25 +157,31 @@ def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, """ y_org = y y = y.view(-1, y.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) + shrink_fun: Callable = ( + self._shrink_prefill if self.is_prefill else self._shrink_decode + ) shrink_fun(y, x, w_t_all, scale) y = y.view_as(y_org) - def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], - scale: float, **kwargs): + def add_shrink( + self, + y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ): """ Performs GEMM for multiple slices of lora_a. When `is_prefill is` true, it indicates that it is currently the prefill stage, and the `_shrink_prefill` function should be called. Otherwise, it is the decode stage, and the _shrink_decode function should be called. 
- + Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale - + Args: y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor @@ -175,33 +192,34 @@ def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], x = x.view(-1, x.shape[-1]) # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) - - def add_expand(self, - y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> None: + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale) + + def add_expand( + self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: """ Performs GEMM and bias addition for multiple slices of lora_b. - + Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] offset += slice - + Args: y (torch.Tensor): Output tensor. x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. @@ -210,8 +228,9 @@ def add_expand(self, y = y.view(-1, y.shape[-1]) offset_left = offset_start if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) + self._apply_bias( + self.token_lora_indices, y, output_slices, lora_bias_stacked + ) for slice_idx in range(len(lora_b_stacked)): self._apply_expand( y, @@ -224,12 +243,14 @@ def add_expand(self, offset_left += output_slices[slice_idx] y = y.view_as(y_org) - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> None: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. 
@@ -244,23 +265,26 @@ def add_lora_embedding(self, """ # Embedding layer only need expand op - expand_fun: Callable = (self._expand_prefill - if self.is_prefill else self._expand_decode) + expand_fun: Callable = ( + self._expand_prefill if self.is_prefill else self._expand_decode + ) expand_fun(y, x, lora_b_stacked, add_inputs) - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs, + ) -> None: """ - Applicable to linear-related lora. + Applicable to linear-related lora. Semantics: for i in range(len(lora_a_stacked)): @@ -285,38 +309,37 @@ def add_lora_linear(self, assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) if lora_bias_stacked is not None: assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) + y = self._apply_bias( + self.token_lora_indices, y, output_slices, lora_bias_stacked + ) if buffer is None: r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default, consistent with the # triton op buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) + torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices)) + ) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_inputs=True, - **kwargs) - - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + self.add_expand( + y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs + ) + + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. - + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked @@ -336,14 +359,8 @@ def add_lora_logits(self, if buffer is None: # We set the buffer to be float32 by default, consistent with the # triton op - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) # LogitsProcessorWithLoRA always using bgmv. 
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) + bgmv_expand(buffer, lora_b_stacked, y, self.sampler_indices, add_inputs=True) y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 467f50050eb2..431e97102faf 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ @@ -15,8 +15,7 @@ from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, - lora_shrink) + from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink from .punica_base import PunicaWrapperBase @@ -24,48 +23,63 @@ @final class PunicaWrapperGPU(PunicaWrapperBase): """ - PunicaWrapperGPU is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + PunicaWrapperGPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica triton kernel. """ - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, - device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - self.max_loras = kwargs['max_loras'] + self.max_loras = kwargs["max_loras"] - self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras, - max_num_batched_tokens, - device=device) - - self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras, - max_batches, - device=device) + self.token_mapping_meta = LoRAKernelMeta.make( + self.max_loras, max_num_batched_tokens, device=device + ) - def update_metadata(self, mapping: LoRAMapping, - lora_index_to_id: list[Optional[int]], max_loras: int, - vocab_size: int, extra_vocab_size: int, **kwargs): + self.prompt_mapping_meta = LoRAKernelMeta.make( + self.max_loras, max_batches, device=device + ) + def update_metadata( + self, + mapping: LoRAMapping, + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + **kwargs, + ): self.is_prefill = mapping.is_prefill - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size) + self._update_base_metadata( + mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size + ) # Prepare cuda kernel metadata tensors self.token_mapping_meta.prepare_tensors(self.token_lora_indices) self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) - def add_shrink(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, - ...], scale: float, **kwargs): + def add_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ): """ Performs GEMM for multiple slices of lora_a. 
- + Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale - + Args: y (torch.Tensor): Output tensors x (torch.Tensor): Input tensor @@ -82,30 +96,32 @@ def add_shrink(self, y: torch.Tensor, x: torch.Tensor, scale, ) - def add_expand(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> None: + def add_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: """ Performs GEMM and bias addition for multiple slices of lora_b. - + Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] offset += slice - + Args: y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. @@ -113,10 +129,8 @@ def add_expand(self, y_org = y y = y.view(-1, y.shape[-1]) if lora_bias_stacked is not None: - token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, - y.size(0)) - self._apply_bias(token_lora_indices, y, output_slices, - lora_bias_stacked) + token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0)) + self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked) assert x.ndim == 3 assert x.size(0) == len(output_slices) @@ -133,12 +147,14 @@ def add_expand(self, y = y.view_as(y_org) - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> None: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -154,26 +170,28 @@ def add_lora_embedding(self, lora_expand( x.unsqueeze(dim=0), - (lora_b_stacked, ), + (lora_b_stacked,), y, *self.token_mapping_meta.meta_args(x.size(0)), offset_start=0, add_inputs=add_inputs, ) - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ - Applicable to linear-related lora. + Applicable to linear-related lora. 
Semantics: for i in range(len(lora_a_stacked)): @@ -198,10 +216,10 @@ def add_lora_linear(self, assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) if lora_bias_stacked is not None: assert len(lora_bias_stacked) == len(output_slices) - token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, - y.size(0)) - y = self._apply_bias(token_lora_indices, y, output_slices, - lora_bias_stacked) + token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0, y.size(0)) + y = self._apply_bias( + token_lora_indices, y, output_slices, lora_bias_stacked + ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -217,7 +235,8 @@ def add_lora_linear(self, x, lora_a_stacked, scale, - **kwargs) + **kwargs, + ) self.add_expand( y, buffer, # type: ignore @@ -225,20 +244,23 @@ def add_lora_linear(self, None, output_slices, add_inputs=True, - **kwargs) - - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs, + ) + + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. - + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked @@ -258,15 +280,21 @@ def add_lora_logits(self, if buffer is None: # We set the buffer to be float32 by default, refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - lora_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0), - *self.prompt_mapping_meta.meta_args(x.size(0)), scale) + lora_shrink( + x, + [lora_a_stacked], + buffer.unsqueeze(dim=0), + *self.prompt_mapping_meta.meta_args(x.size(0)), + scale, + ) - lora_expand(buffer.unsqueeze(dim=0), [lora_b_stacked], - y, - *self.prompt_mapping_meta.meta_args(buffer.size(0)), - add_inputs=True) + lora_expand( + buffer.unsqueeze(dim=0), + [lora_b_stacked], + y, + *self.prompt_mapping_meta.meta_args(buffer.size(0)), + add_inputs=True, + ) y = y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index c684ac77cc9c..c017721803fe 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -14,7 +14,8 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: punica_wrapper_qualname = current_platform.get_punica_wrapper() punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname) punica_wrapper = punica_wrapper_cls(*args, **kwargs) - assert punica_wrapper is not None, \ + assert punica_wrapper is not None, ( "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong." + ) logger.info_once("Using %s.", punica_wrapper_qualname.rsplit(".", 1)[1]) return punica_wrapper diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 5896da516540..5d2f05b815be 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -25,27 +25,29 @@ class PunicaWrapperTPU(PunicaWrapperBase): Multi-LoRA, and to provide the interface for the pytorch punica ops. 
""" - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, - device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) # PunicaWrapperBase defines some tensors with dtype=torch.int64, which # isn't supported by the TPU. So convert those tensors to int32. # Not all of them are used by the TPU so only convert the useful ones. - self._token_lora_indices = self._token_lora_indices.to( - dtype=torch.int32) + self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32) self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) self._sampler_indices_padded = self._sampler_indices_padded.to( - dtype=torch.int32) + dtype=torch.int32 + ) torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True) torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, - True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True) torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, - True) + torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True) torch._dynamo.mark_dynamic(self._token_lora_indices, 0) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) @@ -77,21 +79,38 @@ def shrink( ): return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale) - def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - add_inputs: bool): - return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), - add_inputs) - - def expand_slice(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - add_inputs: bool) -> torch.Tensor: - return bgmv_expand_slice(x, w_t_all, y, - self._get_token_lora_indices(x), y_offset, - y_slice_size, add_inputs) - - def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], - scale: float, **kwargs) -> Optional[torch.Tensor]: + def expand( + self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool + ): + return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs) + + def expand_slice( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ) -> torch.Tensor: + return bgmv_expand_slice( + x, + w_t_all, + y, + self._get_token_lora_indices(x), + y_offset, + y_slice_size, + add_inputs, + ) + + def add_shrink( + self, + y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. 
@@ -115,15 +134,17 @@ def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], y[slice_idx, :, :] = y_s # type: ignore[index] return y - def add_expand(self, - y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> torch.Tensor: + def add_expand( + self, + y: torch.Tensor, + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> torch.Tensor: """ Performs GEMM and bias addition for multiple slices of lora_b. @@ -148,24 +169,29 @@ def add_expand(self, offset_left = 0 if lora_bias_stacked is not None: - y = self._apply_bias(self._get_token_lora_indices(y), y, - output_slices, lora_bias_stacked) + y = self._apply_bias( + self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked + ) for slice_idx in range(len(lora_b_stacked)): - y = self.expand_slice(y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_inputs=add_inputs) + y = self.expand_slice( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) offset_left += output_slices[slice_idx] return y.view_as(y_org) - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> torch.Tensor: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> torch.Tensor: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -182,17 +208,19 @@ def add_lora_embedding(self, # Embedding layer only needs the expand op return self.expand(y, x, lora_b_stacked, add_inputs) - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, - **kwargs) -> torch.Tensor: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[tuple[torch.Tensor, ...]] = None, + **kwargs, + ) -> torch.Tensor: """ Applicable to linear-related lora. 
@@ -219,8 +247,9 @@ def add_lora_linear(self, assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) if lora_bias_stacked is not None: assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self._get_token_lora_indices(y), y, - output_slices, lora_bias_stacked) + y = self._apply_bias( + self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked + ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -231,23 +260,21 @@ def add_lora_linear(self, device=x.device, ) buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - return self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_inputs=True, - **kwargs) - - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> torch.Tensor: + return self.add_expand( + y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs + ) + + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: """ Applies lora specifically for LogitsProcessorWithLoRA. @@ -269,11 +296,7 @@ def add_lora_logits(self, sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale) - y = bgmv_expand(buffer, - lora_b_stacked, - y, - sampler_indices, - add_inputs=True) + y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True) return y.view_as(y_org) def _apply_bias( @@ -304,8 +327,9 @@ def _apply_bias( bias = bias[indices] bias = torch.where(indices[:, None] == -1, 0, bias) - bias = F.pad(bias, (offset_left, output.shape[1] - - (offset_left + slice), 0, 0)) + bias = F.pad( + bias, (offset_left, output.shape[1] - (offset_left + slice), 0, 0) + ) output += bias offset_left += slice @@ -328,8 +352,7 @@ def _update_base_metadata( # Pad the prompt mapping to avoid running into recompiles on the TPU # TODO: Should this happen inside mapping internally? If so how can we # avoid having backend specific LoRAMapping classes? 
- mapping.prompt_mapping = self._pad_prompt_mapping( - mapping.prompt_mapping) + mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping) ( base_indices, @@ -346,35 +369,33 @@ def _update_base_metadata( "cpu", ) self._token_lora_indices = self._pad_to_shape( - base_indices, self._token_lora_indices.shape, - dims=1).to(self.device) - self._sampler_indices = self._pad_to_shape(sampler_indices, - self._sampler_indices.shape, - dims=1).to(self.device) + base_indices, self._token_lora_indices.shape, dims=1 + ).to(self.device) + self._sampler_indices = self._pad_to_shape( + sampler_indices, self._sampler_indices.shape, dims=1 + ).to(self.device) self._sampler_indices_padded = self._pad_to_shape( - sampler_indices_padded, self._sampler_indices_padded.shape, - dims=1).to(self.device) + sampler_indices_padded, self._sampler_indices_padded.shape, dims=1 + ).to(self.device) self._embeddings_indices = self._pad_to_shape( - embeddings_indices, self._embeddings_indices.shape, - dims=2).to(self.device) + embeddings_indices, self._embeddings_indices.shape, dims=2 + ).to(self.device) self.indices_len[:] = indices_len - def _update_prefill_metadata(self, - token_lora_tensor: torch.Tensor) -> None: + def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 - self._lora_indices_per_batch[:self. - batch_size] = token_lora_tensor[:self. - batch_size] + self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[ + : self.batch_size + ] - def _pad_prompt_mapping( - self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: + def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: num_reqs = len(prompt_mapping) # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular # import MIN_NUM_SEQS = 8 - padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) + padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) pad_len = padded_num_reqs - num_reqs padding = [-1] * pad_len @@ -387,5 +408,4 @@ def _pad_to_shape(self, src, target_shape, dims=1): else: pad_rows = target_shape[0] - src.shape[0] pad_cols = target_shape[1] - src.shape[1] - return F.pad(src, (0, pad_cols, 0, pad_rows), - value=0).to(torch.int32) + return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32) diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index 163bb412235c..5196199b2ac3 100644 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ @@ -21,25 +21,35 @@ class PunicaWrapperXPU(PunicaWrapperBase): """ PunicaWrapperXPU is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica ipex kernel. 
""" - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, - device) + def __init__( + self, + max_num_batched_tokens: int, + max_batches: int, + device: Union[torch.device, str], + **kwargs, + ): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) torch._dynamo.mark_dynamic(self._token_lora_indices, 0) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) - def update_metadata(self, mapping: LoRAMapping, - lora_index_to_id: list[Optional[int]], max_loras: int, - vocab_size: int, extra_vocab_size: int, **kwargs): - + def update_metadata( + self, + mapping: LoRAMapping, + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + **kwargs, + ): self.is_prefill = mapping.is_prefill - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size) + self._update_base_metadata( + mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size + ) def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) @@ -63,19 +73,25 @@ def _apply_expand( add_inputs: bool, ): token_lora_indices = self._get_token_lora_indices(x) - bgmv_expand_slice(x, w_t_all, y, token_lora_indices, y_offset, - y_slice_size, add_inputs) + bgmv_expand_slice( + x, w_t_all, y, token_lora_indices, y_offset, y_slice_size, add_inputs + ) - def add_shrink(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, - ...], scale: float, **kwargs): + def add_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ): """ Performs GEMM for multiple slices of lora_a. - + Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale - + Args: y (torch.Tensor): Output tensors x (torch.Tensor): Input tensor @@ -85,33 +101,34 @@ def add_shrink(self, y: torch.Tensor, x: torch.Tensor, x = x.view(-1, x.shape[-1]) for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) - - def add_expand(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> None: + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale) + + def add_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: """ Performs GEMM and bias addition for multiple slices of lora_b. - + Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] offset += slice - + Args: y (torch.Tensor): Output tensor. 
x (torch.Tensor): Input tensors lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. @@ -120,8 +137,7 @@ def add_expand(self, y = y.view(-1, y.shape[-1]) if lora_bias_stacked is not None: token_lora_indices = self._get_token_lora_indices(y) - self._apply_bias(token_lora_indices, y, output_slices, - lora_bias_stacked) + self._apply_bias(token_lora_indices, y, output_slices, lora_bias_stacked) assert x.ndim == 3 assert x.size(0) == len(output_slices) @@ -139,12 +155,14 @@ def add_expand(self, offset_start += output_slices[slice_idx] y.view_as(y_org) - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> None: + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs, + ) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -160,17 +178,19 @@ def add_lora_embedding(self, token_lora_indices = self._get_token_lora_indices(x) bgmv_expand(x, lora_b_stacked, y, token_lora_indices, add_inputs) - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ Applicable to linear-related lora. @@ -198,8 +218,9 @@ def add_lora_linear(self, if lora_bias_stacked is not None: assert len(lora_bias_stacked) == len(output_slices) token_lora_indices = self._get_token_lora_indices(y) - y = self._apply_bias(token_lora_indices, y, output_slices, - lora_bias_stacked) + y = self._apply_bias( + token_lora_indices, y, output_slices, lora_bias_stacked + ) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -215,7 +236,8 @@ def add_lora_linear(self, x, lora_a_stacked, scale, - **kwargs) + **kwargs, + ) self.add_expand( y, buffer, # type: ignore @@ -223,7 +245,8 @@ def add_lora_linear(self, None, output_slices, add_inputs=True, - **kwargs) + **kwargs, + ) @property def sampler_indices_padded(self) -> torch.Tensor: @@ -232,18 +255,20 @@ def sampler_indices_padded(self) -> torch.Tensor: """ return self._sampler_indices_padded[:] - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs, + ) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. 
- + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked @@ -263,14 +288,8 @@ def add_lora_logits(self, if buffer is None: # We set the buffer to be float32 by default, refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - sampler_indices, - add_inputs=True) + bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True) return y.view_as(y_org) diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index d22c29da1c61..90d1614e674d 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -11,7 +11,7 @@ def compute_meta( - token_lora_tensor: torch.Tensor + token_lora_tensor: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: @@ -23,7 +23,8 @@ def compute_meta( """ lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( - token_lora_tensor, return_counts=True) + token_lora_tensor, return_counts=True + ) cum_result = torch.cumsum(seq_length_tensor, dim=0) b_seq_start_tensor = torch.zeros_like(seq_length_tensor) b_seq_start_tensor[1:].copy_(cum_result[:-1]) @@ -36,8 +37,15 @@ def compute_meta( # does not need to launch the triton kernel, which can improve performance if batch_size == 1 and lora_indices_tensor == -1: no_lora = True - return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, token_nums, no_lora) + return ( + b_seq_start_tensor, + seq_length_tensor, + lora_indices_tensor, + batch_size, + max_length, + token_nums, + no_lora, + ) # TODO see if this can be vectorized @@ -83,14 +91,16 @@ def convert_mapping( lora_indices = index_mapping_indices.copy() prompt_mapping: list[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping + lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] lora_idx = None for i in range(len(index_mapping_indices)): # TODO index can be slow. 
optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) + lora_idx = ( + lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 + else -1 + ) embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx @@ -101,23 +111,27 @@ def convert_mapping( ] indices = torch.tensor(indices_list, dtype=torch.long, device=device) - prompt_mapping_tensor = torch.tensor(prompt_mapping, - dtype=torch.long, - device=device) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size), - ]) - embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1, - embeddings_indices) + prompt_mapping_tensor = torch.tensor( + prompt_mapping, dtype=torch.long, device=device + ) + embeddings_indices = torch.stack( + [ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ] + ) + embeddings_indices = torch.where( + embeddings_indices == -1, max_loras - 1, embeddings_indices + ) base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded = torch.where(sampler_indices_padded == -1, - max_loras - 1, sampler_indices_padded) + sampler_indices_padded = torch.where( + sampler_indices_padded == -1, max_loras - 1, sampler_indices_padded + ) sampler_indices_padded = torch.arange( - 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( - sampler_indices_padded * len(sampler_indices_padded)) + 0, len(sampler_indices_padded), device=device, dtype=torch.long + ) + (sampler_indices_padded * len(sampler_indices_padded)) # Contain length of indices tensors. Used to index into each tensor. indices_len = [ diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 523525d46f0b..650e060a5804 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -8,9 +8,10 @@ class LoRARequest( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True, +): # type: ignore[call-arg] """ Request for a LoRA adapter. @@ -22,6 +23,7 @@ class LoRARequest( lora_int_id must be globally unique for a given adapter. This is currently not enforced in vLLM. """ + lora_name: str lora_int_id: int lora_path: str = "" @@ -39,7 +41,8 @@ def __post_init__(self): "and will be removed in a future version. " "Please use 'lora_path' instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) if not self.lora_path: self.lora_path = self.lora_local_path or "" @@ -65,7 +68,8 @@ def local_path(self): "and will be removed in a future version. " "Please use 'path' instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) return self.lora_path @local_path.setter @@ -75,7 +79,8 @@ def local_path(self, value): "and will be removed in a future version. " "Please use 'path' instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) self.lora_path = value def __eq__(self, value: object) -> bool: @@ -84,8 +89,7 @@ def __eq__(self, value: object) -> bool: instances based on lora_name. This allows for identification and comparison lora adapter across engines. 
""" - return isinstance(value, - self.__class__) and self.lora_name == value.lora_name + return isinstance(value, self.__class__) and self.lora_name == value.lora_name def __hash__(self) -> int: """ diff --git a/vllm/lora/resolver.py b/vllm/lora/resolver.py index 5808ae105e86..d366b94521cd 100644 --- a/vllm/lora/resolver.py +++ b/vllm/lora/resolver.py @@ -22,8 +22,9 @@ class LoRAResolver(ABC): """ @abstractmethod - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: """Abstract method to resolve and fetch a LoRA model adapter. Implements logic to locate and download LoRA adapter based on the name. @@ -61,8 +62,10 @@ def register_resolver( if resolver_name in self.resolvers: logger.warning( "LoRA resolver %s is already registered, and will be " - "overwritten by the new resolver instance %s.", resolver_name, - resolver) + "overwritten by the new resolver instance %s.", + resolver_name, + resolver, + ) self.resolvers[resolver_name] = resolver @@ -78,7 +81,8 @@ def get_resolver(self, resolver_name: str) -> LoRAResolver: if resolver_name not in self.resolvers: raise KeyError( f"LoRA resolver '{resolver_name}' not found. " - f"Available resolvers: {list(self.resolvers.keys())}") + f"Available resolvers: {list(self.resolvers.keys())}" + ) return self.resolvers[resolver_name] diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 10ba390bffd9..5e55d44ce8d9 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -6,37 +6,40 @@ import huggingface_hub import regex as re -from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - HFValidationError, RepositoryNotFoundError) +from huggingface_hub.utils import ( + EntryNotFoundError, + HfHubHTTPError, + HFValidationError, + RepositoryNotFoundError, +) from torch import nn from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig from vllm.logger import init_logger + # being imported for _all_lora_classes below -# yapf conflicts with isort for this block -# yapf: disable -from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - ColumnParallelLinearWithShardedLoRA, - LogitsProcessorWithLoRA, - MergedColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithLoRA, - MergedQKVParallelLinearWithShardedLoRA, - QKVParallelLinearWithLoRA, - QKVParallelLinearWithShardedLoRA, - ReplicatedLinearWithLoRA, - RowParallelLinearWithLoRA, - RowParallelLinearWithShardedLoRA, - VocabParallelEmbeddingWithLoRA) +from vllm.lora.layers import ( + BaseLayerWithLoRA, + ColumnParallelLinearWithLoRA, + ColumnParallelLinearWithShardedLoRA, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithLoRA, + MergedQKVParallelLinearWithShardedLoRA, + QKVParallelLinearWithLoRA, + QKVParallelLinearWithShardedLoRA, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + RowParallelLinearWithShardedLoRA, + VocabParallelEmbeddingWithLoRA, +) from vllm.model_executor.layers.linear import LinearBase -# yapf: enable - if TYPE_CHECKING: from vllm.model_executor.layers.logits_processor import LogitsProcessor - from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead) + from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.utils import WeightsMapper logger = 
init_logger(__name__) @@ -58,20 +61,23 @@ } -def from_layer(layer: nn.Module, - max_loras: int, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig] = None) -> nn.Module: +def from_layer( + layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: Optional[PretrainedConfig] = None, +) -> nn.Module: for lora_cls in _all_lora_classes: # specifying kwargs so they can be easily accessed in decorator - if lora_cls.can_replace_layer(source_layer=layer, - lora_config=lora_config, - packed_modules_list=packed_modules_list, - model_config=model_config): + if lora_cls.can_replace_layer( + source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + ): instance_layer = lora_cls(layer) - instance_layer.create_lora_weights(max_loras, lora_config, - model_config) + instance_layer.create_lora_weights(max_loras, lora_config, model_config) return instance_layer return layer @@ -83,15 +89,20 @@ def from_layer_logits_processor( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> LogitsProcessorWithLoRA: - ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, - lm_head.weight.dtype, lm_head.weight.device, - lm_head.get_sharded_to_full_mapping()) + ret = LogitsProcessorWithLoRA( + layer, + lm_head.embedding_dim, + lm_head.weight.dtype, + lm_head.weight.device, + lm_head.get_sharded_to_full_mapping(), + ) ret.create_lora_weights(max_loras, lora_config, model_config) return ret -def replace_submodule(model: nn.Module, module_name: str, - new_module: nn.Module) -> nn.Module: +def replace_submodule( + model: nn.Module, module_name: str, new_module: nn.Module +) -> nn.Module: """Replace a submodule in a model with a new module.""" parent = model.get_submodule(".".join(module_name.split(".")[:-1])) target_name = module_name.split(".")[-1] @@ -100,8 +111,7 @@ def replace_submodule(model: nn.Module, module_name: str, def parse_fine_tuned_lora_name( - name: str, - weights_mapper: Optional["WeightsMapper"] = None + name: str, weights_mapper: Optional["WeightsMapper"] = None ) -> tuple[str, bool, bool]: """Parse the name of lora weights. @@ -134,8 +144,7 @@ def parse_fine_tuned_lora_name( start_index = 2 if name.startswith("base_model.model.") else 0 parts = name.split(".") - if parts[-1] == "weight" and (parts[-2] == "lora_A" - or parts[-2] == "lora_B"): + if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"): new_name = ".".join(parts[start_index:-2]) return new_name, parts[-2] == "lora_A", False @@ -150,12 +159,13 @@ def parse_fine_tuned_lora_name( raise ValueError(f"{name} is unsupported LoRA weight") -def is_regex_target_modules(load_modules: Union[str, list[str]], - expected_lora_modules: list[str]) -> bool: +def is_regex_target_modules( + load_modules: Union[str, list[str]], expected_lora_modules: list[str] +) -> bool: """ - PEFT supports passing `target_modules` in the form of regular expressions, - such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to - determine whether the suffix in the regular expression is present in the + PEFT supports passing `target_modules` in the form of regular expressions, + such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to + determine whether the suffix in the regular expression is present in the `expected_lora_modules`. 
""" @@ -197,7 +207,7 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]: supported_lora_modules.add(name) # get all the linear subfixes. - if isinstance(module, (LinearBase, )): + if isinstance(module, (LinearBase,)): supported_lora_modules.add(name.split(".")[-1]) return list(supported_lora_modules) @@ -225,7 +235,7 @@ def get_adapter_absolute_path(lora_path: str) -> str: return lora_path # If the path starts with ~, expand the user home directory. - if lora_path.startswith('~'): + if lora_path.startswith("~"): return os.path.expanduser(lora_path) # Check if the expanded relative path exists locally. @@ -234,10 +244,13 @@ def get_adapter_absolute_path(lora_path: str) -> str: # If the path does not exist locally, assume it's a Hugging Face repo. try: - local_snapshot_path = huggingface_hub.snapshot_download( - repo_id=lora_path) - except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError, - HFValidationError): + local_snapshot_path = huggingface_hub.snapshot_download(repo_id=lora_path) + except ( + HfHubHTTPError, + RepositoryNotFoundError, + EntryNotFoundError, + HFValidationError, + ): # Handle errors that may occur during the download # Return original path instead of throwing error here logger.exception("Error downloading the HuggingFace model") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index cdb2f86611d8..3ca819fb732c 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -8,8 +8,12 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.lora.models import (LoRAModel, LoRAModelManager, - LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.models import ( + LoRAModel, + LoRAModelManager, + LRUCacheLoRAModelManager, + create_lora_manager, +) from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path @@ -39,7 +43,8 @@ def __init__( self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs self.max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + vllm_config.scheduler_config.max_num_batched_tokens + ) self.vocab_size = vllm_config.model_config.get_vocab_size() self.lora_config = vllm_config.lora_config @@ -81,15 +86,12 @@ def create_lora_manager( def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - supported_lora_modules = ( - self._adapter_manager.supported_lora_modules) - packed_modules_mapping = ( - self._adapter_manager.packed_modules_mapping) + supported_lora_modules = self._adapter_manager.supported_lora_modules + packed_modules_mapping = self._adapter_manager.packed_modules_mapping expected_lora_modules: list[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: - expected_lora_modules.extend( - packed_modules_mapping[module]) + expected_lora_modules.extend(packed_modules_mapping[module]) else: expected_lora_modules.append(module) @@ -97,8 +99,10 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: lora_path = get_adapter_absolute_path(lora_request.lora_path) peft_helper = PEFTHelper.from_local_dir( - lora_path, self.max_position_embeddings, - lora_request.tensorizer_config_dict) + lora_path, + self.max_position_embeddings, + lora_request.tensorizer_config_dict, + ) # Validates the LoRA configuration against requirements before # loading weights, throwing an exception if validation fails. 
@@ -116,12 +120,13 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, - target_embedding_padding=self.vocab_size + - self.lora_config.lora_extra_vocab_size, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, embedding_modules=self.embedding_modules, embedding_padding_modules=self.embedding_padding_modules, tensorizer_config_dict=lora_request.tensorizer_config_dict, - weights_mapper=hf_to_vllm_mapper) + weights_mapper=hf_to_vllm_mapper, + ) except FileNotFoundError as e: # FileNotFoundError should be raised if both @@ -131,26 +136,29 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: # For NotFoundError raise ValueError( f"Loading lora {lora_request.lora_name} failed: No adapter " - f"found for {lora_request.lora_path}") from e + f"found for {lora_request.lora_path}" + ) from e except Exception as e: # For BadRequestError raise e if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: - raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} " - f"is greater than lora_extra_vocab_size " - f"{self.lora_config.lora_extra_vocab_size}.") + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} " + f"is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}." + ) return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: if lora_request.lora_int_id in self.list_adapters(): return False if isinstance(self._cached_dummy_lora, LoRAModel): - dummy_lora = self._cached_dummy_lora.clone( - lora_request.lora_int_id) + dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id) else: dummy_lora = self._adapter_manager.create_dummy_lora( - lora_request.lora_int_id, rank, self.embedding_modules) + lora_request.lora_int_id, rank, self.embedding_modules + ) if self._cached_dummy_lora is None: self._cached_dummy_lora = dummy_lora return self._adapter_manager.add_adapter(dummy_lora) @@ -158,8 +166,7 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: def pin_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.pin_adapter(adapter_id) - def set_active_adapters(self, requests: set[Any], - mapping: Optional[Any]) -> None: + def set_active_adapters(self, requests: set[Any], mapping: Optional[Any]) -> None: self._apply_adapters(requests) if mapping is not None: self._adapter_manager.set_adapter_mapping(mapping) @@ -168,13 +175,15 @@ def _apply_adapters(self, adapter_requests: set[Any]) -> None: existing_adapters = self.list_adapters() models_map = { adapter_request.adapter_id: adapter_request - for adapter_request in adapter_requests if adapter_request + for adapter_request in adapter_requests + if adapter_request } if len(models_map) > self._adapter_manager.adapter_slots: raise RuntimeError( f"Number of requested models ({len(models_map)}) is greater " "than the number of GPU model slots " - f"({self._adapter_manager.adapter_slots}).") + f"({self._adapter_manager.adapter_slots})." 
+ ) requested_ids = set(models_map) for adapter_id in existing_adapters - requested_ids: self.remove_adapter(adapter_id) @@ -227,13 +236,15 @@ def create_lora_manager( def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request - for lora_request in lora_requests if lora_request + for lora_request in lora_requests + if lora_request } if len(loras_map) > self._adapter_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._adapter_manager.lora_slots}).") + f"({self._adapter_manager.lora_slots})." + ) for lora in loras_map.values(): self.add_adapter(lora) @@ -253,15 +264,15 @@ def add_adapter(self, lora_request: LoRARequest) -> bool: # Loading succeeded, now check if we will exceed cache capacity and # evict if the oldest adapter if so if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: - assert isinstance(self._adapter_manager, - LRUCacheLoRAModelManager) + assert isinstance(self._adapter_manager, LRUCacheLoRAModelManager) self._adapter_manager.remove_oldest_adapter() # Then add the new adapter to the cache loaded = self._adapter_manager.add_adapter(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches - loaded = self._adapter_manager.get_adapter( - lora_request.lora_int_id) is not None + loaded = ( + self._adapter_manager.get_adapter(lora_request.lora_int_id) is not None + ) self._adapter_manager.activate_adapter(lora_request.lora_int_id) return loaded diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 3c094cfdb553..b50f0cb3a61a 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.model_executor.parameter import (BasevLLMParameter, - PackedvLLMParameter) +from vllm.model_executor.parameter import BasevLLMParameter, PackedvLLMParameter from vllm.model_executor.utils import set_random_seed __all__ = [ diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index e7eb8247d5ef..ad5a09ca970d 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -32,8 +32,11 @@ def __new__(cls, *args, **kwargs): op_cls_to_instantiate = cls else: op_cls_to_instantiate = cls.op_registry_oot[op_name] - logger.debug("Instantiating custom op: %s using %s", op_name, - str(op_cls_to_instantiate)) + logger.debug( + "Instantiating custom op: %s using %s", + op_name, + str(op_cls_to_instantiate), + ) return super().__new__(op_cls_to_instantiate) def __init__(self): @@ -86,8 +89,7 @@ def dispatch_forward(self): if enabled: compilation_config.enabled_custom_ops.update([self.__class__.name]) else: - compilation_config.disabled_custom_ops.update( - [self.__class__.name]) + compilation_config.disabled_custom_ops.update([self.__class__.name]) if not enabled: return self.forward_native @@ -119,8 +121,7 @@ def enabled(cls) -> bool: enabled = f"+{cls.name}" in custom_ops disabled = f"-{cls.name}" in custom_ops - assert not (enabled - and disabled), f"Cannot enable and disable {cls.name}" + assert not (enabled and disabled), f"Cannot enable and disable {cls.name}" return (CustomOp.default_on() or enabled) and not disabled @@ -131,9 +132,12 @@ def default_on() -> bool: Specifying 'all' or 'none' in custom_op takes precedence. 
""" from vllm.config import CompilationLevel + compilation_config = get_cached_compilation_config() - default_on = (compilation_config.level < CompilationLevel.PIECEWISE - or not compilation_config.use_inductor) + default_on = ( + compilation_config.level < CompilationLevel.PIECEWISE + or not compilation_config.use_inductor + ) count_none = compilation_config.custom_ops.count("none") count_all = compilation_config.custom_ops.count("all") return default_on and not count_none > 0 or count_all > 0 @@ -143,13 +147,12 @@ def default_on() -> bool: # Examples: # - MyOp.enabled() # - op_registry["my_op"].enabled() - op_registry: dict[str, type['CustomOp']] = {} - op_registry_oot: dict[str, type['CustomOp']] = {} + op_registry: dict[str, type["CustomOp"]] = {} + op_registry_oot: dict[str, type["CustomOp"]] = {} # Decorator to register custom ops. @classmethod def register(cls, name: str): - def decorator(op_cls): assert name not in cls.op_registry, f"Duplicate op name: {name}" op_cls.name = name @@ -169,11 +172,9 @@ def decorator(op_cls): # - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod") @classmethod def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None): - def decorator(op_cls): reg_name = name if name is not None else cls.__name__ - assert reg_name not in cls.op_registry_oot, \ - f"Duplicate op name: {reg_name}" + assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" op_cls.name = reg_name cls.op_registry_oot[reg_name] = op_cls return op_cls diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 235df1a77c5c..96745b99f7a7 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom activation functions.""" + import math from typing import Optional @@ -8,8 +9,11 @@ import torch.nn as nn import torch.nn.functional as F -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.utils import set_weight_attrs @@ -32,7 +36,7 @@ class FatreluAndMul(CustomOp): return: (num_tokens, d) or (batch_size, seq_len, d) """ - def __init__(self, threshold: float = 0.): + def __init__(self, threshold: float = 0.0): super().__init__() self.threshold = threshold if current_platform.is_cuda_alike(): @@ -49,7 +53,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x, self.threshold) return out @@ -72,6 +76,7 @@ def __init__(self): self.op = torch.ops._C.silu_and_mul elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.silu_and_mul elif current_platform.is_cpu(): self._forward_method = self.forward_native @@ -83,14 +88,14 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) 
self.op(out, x) return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x) return out @@ -113,6 +118,7 @@ def __init__(self): self.op = torch.ops._C.mul_and_silu elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.silu_and_mul elif current_platform.is_cpu(): self._forward_method = self.forward_native @@ -124,7 +130,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x) return out @@ -156,10 +162,8 @@ def __init__(self, activation_sparsity: float, approximate: str = "none"): # Sparsity. if activation_sparsity == 0.0: - raise ValueError( - "activation_sparsity is 0.0. Please use GeluAndMul.") - target_sparsity_tensor = torch.tensor(activation_sparsity, - dtype=torch.float32) + raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.") + target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32) normal_dist = torch.distributions.normal.Normal(0, 1) self.std_multiplier = normal_dist.icdf(target_sparsity_tensor) @@ -207,6 +211,7 @@ def __init__(self, approximate: str = "none"): self.op = torch.ops._C.gelu_tanh_and_mul elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + if approximate == "none": self.op = ipex_ops.gelu_and_mul else: @@ -219,20 +224,20 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x) return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) self.op(out, x) return out def extra_repr(self) -> str: - return f'approximate={repr(self.approximate)}' + return f"approximate={repr(self.approximate)}" @CustomOp.register("swigluoai_and_mul") @@ -255,7 +260,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit) return out @@ -266,20 +271,19 @@ def extra_repr(self) -> str: @CustomOp.register("gelu_new") class NewGELU(CustomOp): - def __init__(self): super().__init__() if current_platform.is_cuda_alike() or current_platform.is_cpu(): self.op = torch.ops._C.gelu_new elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_new def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" c = math.sqrt(2.0 / math.pi) - return 0.5 * x * (1.0 + torch.tanh(c * - (x + 0.044715 * torch.pow(x, 3.0)))) + return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0)))) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) @@ -292,19 +296,18 @@ def forward_xpu(self, x: 
torch.Tensor) -> torch.Tensor: @CustomOp.register("gelu_fast") class FastGELU(CustomOp): - def __init__(self): super().__init__() if current_platform.is_cuda_alike() or current_platform.is_cpu(): self.op = torch.ops._C.gelu_fast elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_fast def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" - return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * - (1.0 + 0.044715 * x * x))) + return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) @@ -324,6 +327,7 @@ def __init__(self): self.op = torch.ops._C.gelu_quick elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops + self.op = ipex_ops.gelu_quick def forward_native(self, x: torch.Tensor) -> torch.Tensor: @@ -355,7 +359,7 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: return torch.square(F.relu(x)) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - #TODO : implement cuda kernels + # TODO : implement cuda kernels return self.forward_native(x) @@ -378,12 +382,15 @@ def __init__( ): super().__init__() self.alpha_p = nn.Parameter( - torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - - 1).unsqueeze(0)) + torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze( + 0 + ) + ) self.alpha_n = nn.Parameter( torch.log( - torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - - 1).unsqueeze(0)) + torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1 + ).unsqueeze(0) + ) self.register_buffer("beta", torch.tensor(beta, dtype=dtype)) self.register_buffer("eps", torch.tensor(eps, dtype=dtype)) self.with_vector_loads = with_vector_loads @@ -403,8 +410,10 @@ def __init__( self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda) msg += " Enabled torch._dynamo for xIELU CUDA." except Exception as err: - msg += (f" Could not enable torch._dynamo for xIELU ({err}) - " - "this may result in slower performance.") + msg += ( + f" Could not enable torch._dynamo for xIELU ({err}) - " + "this may result in slower performance." 
+ ) self._xielu_cuda_fn = self._xielu_cuda logger.warning_once(msg) except Exception as err: @@ -421,14 +430,12 @@ def _xielu_python(self, x: torch.Tensor) -> torch.Tensor: return torch.where( x > 0, alpha_p * x * x + self.beta * x, - (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + - self.beta * x, + (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x, ) def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor: """Firewall function to prevent torch.compile from seeing .item()""" - assert self._xielu_cuda_obj is not None, ( - "XIELU CUDA object must not be None") + assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None" original_shape = x.shape # CUDA kernel expects 3D tensors, reshape if needed while x.dim() < 3: @@ -486,14 +493,14 @@ def __init__( self.input_is_parallel = input_is_parallel if input_is_parallel: tp_size = get_tensor_model_parallel_world_size() - intermediate_size_per_partition = divide(intermediate_size, - tp_size) + intermediate_size_per_partition = divide(intermediate_size, tp_size) else: intermediate_size_per_partition = intermediate_size if params_dtype is None: params_dtype = torch.get_default_dtype() self.scales = nn.Parameter( - torch.empty(intermediate_size_per_partition, dtype=params_dtype)) + torch.empty(intermediate_size_per_partition, dtype=params_dtype) + ) set_weight_attrs(self.scales, {"weight_loader": self.weight_loader}) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -510,30 +517,21 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data.copy_(loaded_weight) -_ACTIVATION_REGISTRY = LazyDict({ - "gelu": - lambda: nn.GELU(), - "gelu_fast": - lambda: FastGELU(), - "gelu_new": - lambda: NewGELU(), - "gelu_pytorch_tanh": - lambda: nn.GELU(approximate="tanh"), - "relu": - lambda: nn.ReLU(), - "relu2": - lambda: ReLUSquaredActivation(), - "silu": - lambda: nn.SiLU(), - "quick_gelu": - lambda: QuickGELU(), - "tanh": - lambda: nn.Tanh(), - "sigmoid": - lambda: nn.Sigmoid(), - "xielu": - lambda: XIELU(), -}) +_ACTIVATION_REGISTRY = LazyDict( + { + "gelu": lambda: nn.GELU(), + "gelu_fast": lambda: FastGELU(), + "gelu_new": lambda: NewGELU(), + "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"), + "relu": lambda: nn.ReLU(), + "relu2": lambda: ReLUSquaredActivation(), + "silu": lambda: nn.SiLU(), + "quick_gelu": lambda: QuickGELU(), + "tanh": lambda: nn.Tanh(), + "sigmoid": lambda: nn.Sigmoid(), + "xielu": lambda: XIELU(), + } +) def get_act_fn(act_fn_name: str) -> nn.Module: @@ -547,29 +545,25 @@ def get_act_fn(act_fn_name: str) -> nn.Module: act_fn_name = activation_name if act_fn_name not in _ACTIVATION_REGISTRY: - raise ValueError( - f"Activation function {act_fn_name!r} is not supported.") + raise ValueError(f"Activation function {act_fn_name!r} is not supported.") return _ACTIVATION_REGISTRY[act_fn_name] -_ACTIVATION_AND_MUL_REGISTRY = LazyDict({ - "gelu": - lambda: GeluAndMul(), - "silu": - lambda: SiluAndMul(), - "geglu": - lambda: GeluAndMul(), - "swigluoai": - lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), -}) +_ACTIVATION_AND_MUL_REGISTRY = LazyDict( + { + "gelu": lambda: GeluAndMul(), + "silu": lambda: SiluAndMul(), + "geglu": lambda: GeluAndMul(), + "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), + } +) def get_act_and_mul_fn(act_fn_name: str) -> nn.Module: """Get an activation-and-mul (i.e. 
SiluAndMul) function by name.""" act_fn_name = act_fn_name.lower() if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY: - raise ValueError( - f"Activation function {act_fn_name!r} is not supported.") + raise ValueError(f"Activation function {act_fn_name!r} is not supported.") return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name] diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py index 782818f55fbc..fa74c20840da 100644 --- a/vllm/model_executor/layers/attention_layer_base.py +++ b/vllm/model_executor/layers/attention_layer_base.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Base class for attention-like layers.""" + from abc import ABC, abstractmethod from typing import TYPE_CHECKING @@ -10,10 +11,10 @@ class AttentionLayerBase(ABC): """ - Base class for attention-like layers (Attention, Mamba, etc.) + Base class for attention-like layers (Attention, Mamba, etc.) that support the v1 engine. - - This provides a common interface for getting attention backends + + This provides a common interface for getting attention backends from different layer types. """ diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 150c48c0e880..9fd85d1e9e19 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -8,21 +8,20 @@ import torch -import vllm.envs as envs -from vllm.logger import init_logger from vllm.triton_utils import tl, triton -logger = init_logger(__name__) - -def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any, - args: dict[str, Any]) -> dict[str, Any]: +def _matmul_launch_metadata( + grid: Callable[..., Any], kernel: Any, args: dict[str, Any] +) -> dict[str, Any]: ret = {} m, n, k = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]" if "tiles_per_update" in args: - ret["name"] = (f"{kernel.name} [M={m}, N={n}, K={k}, " - f"tiles_per_update={args['tiles_per_update']:02}]") + ret["name"] = ( + f"{kernel.name} [M={m}, N={n}, K={k}, " + f"tiles_per_update={args['tiles_per_update']:02}]" + ) if "c_ptr" in args: bytes_per_elem = args["c_ptr"].element_size() else: @@ -79,8 +78,9 @@ def matmul_kernel_persistent( num_pid_in_group = GROUP_SIZE_M * num_pid_n for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): - pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, - GROUP_SIZE_M, NUM_SMS) + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS + ) start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) @@ -91,46 +91,44 @@ def matmul_kernel_persistent( offs_bn = offs_bn.to(tl.int64) offs_am = tl.where(offs_am < M, offs_am, 0) offs_bn = tl.where(offs_bn < N, offs_bn, 0) - offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), - BLOCK_SIZE_M) - offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), - BLOCK_SIZE_N) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for ki in range(k_tiles): if A_LARGE or B_LARGE: - offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to( - tl.int64) + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) else: offs_k = ki * 
BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + - offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) - - a = tl.load(a_ptrs, - mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, - other=0.0) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) + b_ptrs = b_ptr + ( + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + ) + + a = tl.load( + a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0 + ) + b = tl.load( + b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0 + ) accumulator = tl.dot(a, b, accumulator) tile_id_c += NUM_SMS - pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, - GROUP_SIZE_M, NUM_SMS) + pid_m, pid_n = _compute_pid( + tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS + ) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if C_LARGE: offs_cm = offs_cm.to(tl.int64) offs_cn = offs_cn.to(tl.int64) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) if HAS_BIAS: bias_ptrs = bias_ptr + offs_cn - bias = tl.load(bias_ptrs, mask=offs_cn < N, - other=0.0).to(tl.float32) + bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32) accumulator += bias if c_ptr.dtype.element_ty == tl.float8e4nv: c = accumulator.to(tl.float8e4nv) @@ -139,14 +137,15 @@ def matmul_kernel_persistent( tl.store(c_ptrs, c, mask=c_mask) -def matmul_persistent(a: torch.Tensor, - b: torch.Tensor, - bias: Union[torch.Tensor, None] = None): +def matmul_persistent( + a: torch.Tensor, b: torch.Tensor, bias: Union[torch.Tensor, None] = None +): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.dtype == b.dtype, "Incompatible dtypes" assert bias is None or bias.dim() == 1, ( - "Currently assuming bias is 1D, let Horace know if you run into this") + "Currently assuming bias is 1D, let Horace know if you run into this" + ) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count M, K = a.shape K, N = b.shape @@ -156,10 +155,13 @@ def matmul_persistent(a: torch.Tensor, # 1D launch kernel where each block gets its own program. 
def grid(META): - return (min( - NUM_SMS, - triton.cdiv(M, META["BLOCK_SIZE_M"]) * - triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + return ( + min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), + ) configs = { torch.bfloat16: { @@ -288,8 +290,9 @@ def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor: Tensor with log_softmax applied along the specified dimension """ if dim != -1 and dim != input.ndim - 1: - raise ValueError("This implementation only supports log_softmax along " - "the last dimension") + raise ValueError( + "This implementation only supports log_softmax along the last dimension" + ) # Flatten all dimensions except the last one original_shape = input.shape @@ -305,7 +308,7 @@ def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor: BLOCK_SIZE = 1024 # Launch kernel with one block per row - grid = (n_rows, ) + grid = (n_rows,) _log_softmax_kernel[grid]( input_2d, output, @@ -354,8 +357,9 @@ def mean_kernel( mask = n_offsets < N # Calculate input indices - input_idx = m_idx * input_stride0 + n_offsets * input_stride1 \ - + k_idx * input_stride2 + input_idx = ( + m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2 + ) # Load and accumulate vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0) @@ -367,10 +371,12 @@ def mean_kernel( tl.store(output_ptr + output_idx, mean_val) -def mean_dim(input: torch.Tensor, - dim: int, - keepdim: bool = False, - dtype: Union[torch.dtype, None] = None) -> torch.Tensor: +def mean_dim( + input: torch.Tensor, + dim: int, + keepdim: bool = False, + dtype: Union[torch.dtype, None] = None, +) -> torch.Tensor: """ Triton implementation of torch.mean with single dimension reduction. @@ -387,7 +393,8 @@ def mean_dim(input: torch.Tensor, # Validate inputs assert input.is_cuda, "Input must be a CUDA tensor" assert -input.ndim <= dim < input.ndim, ( - f"Invalid dimension {dim} for tensor with {input.ndim} dimensions") + f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" + ) # Handle negative dim if dim < 0: @@ -426,19 +433,16 @@ def mean_dim(input: torch.Tensor, output_shape = shape.copy() output_shape[dim] = 1 else: - output_shape = shape[:dim] + shape[dim + 1:] + output_shape = shape[:dim] + shape[dim + 1 :] # Create output tensor output = torch.empty(output_shape, dtype=dtype, device=input.device) # Reshape output for kernel - if keepdim: - output_2d = output.reshape(M, 1, K).squeeze(1) - else: - output_2d = output.reshape(M, K) + output_2d = output.reshape(M, 1, K).squeeze(1) if keepdim else output.reshape(M, K) # Launch kernel - grid = (M * K, ) + grid = (M * K,) BLOCK_SIZE = 1024 mean_kernel[grid]( @@ -471,12 +475,10 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float): return log_softmax(input, dim=dim) -def mean_batch_invariant(input, - dim, - keepdim=False, - dtype: Union[torch.dtype, None] = None): - assert dtype is None or dtype == torch.float32, \ - f"unsupported dtype: {dtype}" +def mean_batch_invariant( + input, dim, keepdim=False, dtype: Union[torch.dtype, None] = None +): + assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}" result = input.to(torch.float32) @@ -513,8 +515,9 @@ def enable_batch_invariant_mode(): _batch_invariant_LIB = torch.library.Library("aten", "IMPL") _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") - _batch_invariant_LIB.impl("aten::_log_softmax", - _log_softmax_batch_invariant, 
"CUDA") + _batch_invariant_LIB.impl( + "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA" + ) _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA") @@ -561,12 +564,5 @@ def vllm_kernel_override_batch_invariant(): def init_batch_invariance(): # this will hit all the csrc overrides as well if vllm_kernel_override_batch_invariant(): - curr_attn_backend = envs.VLLM_ATTENTION_BACKEND - supported_backends = ["FLEX_ATTENTION", "FLASHINFER"] - if curr_attn_backend not in supported_backends: - warning = "Forcibly updating attention backend to" \ - f" {supported_backends[0]} for batch_invariant. " \ - f" Supported backends: {supported_backends}." - logger.warning_once(warning) - os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0] + os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" enable_batch_invariant_mode() diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index e7d295aff239..d65c87aba11c 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -23,22 +23,22 @@ from .wy_fast import recompute_w_u_fwd -def chunk_gated_delta_rule_fwd(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None): +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +): g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) # obtain WY representation. u is actually the new v. 
- A = chunk_scaled_dot_kkt_fwd(k=k, - beta=beta, - g_cumsum=g, - cu_seqlens=cu_seqlens, - output_dtype=torch.float32) + A = chunk_scaled_dot_kkt_fwd( + k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32 + ) A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) w, u = recompute_w_u_fwd( k=k, @@ -73,21 +73,22 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor, class ChunkGatedDeltaRuleFunction(torch.autograd.Function): - @staticmethod @input_guard - @torch.amp.custom_fwd(device_type='cuda') - def forward(ctx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, - use_qk_l2norm_in_kernel: bool = False): + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): if use_qk_l2norm_in_kernel: q = l2norm_fwd(q) k = l2norm_fwd(k) @@ -109,17 +110,19 @@ def forward(ctx, @torch.compiler.disable -def chunk_gated_delta_rule(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float = None, - initial_state: torch.Tensor = None, - output_final_state: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False, - use_qk_l2norm_in_kernel: bool = False): +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): r""" Args: q (torch.Tensor): @@ -184,42 +187,55 @@ def chunk_gated_delta_rule(q: torch.Tensor, ) """ assert q.dtype == k.dtype == v.dtype - assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." - assert len( - beta.shape - ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + assert q.dtype != torch.float32, ( + "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + ) + assert len(beta.shape) == 3, ( + "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + ) if head_first: raise DeprecationWarning( "head_first is deprecated and will be removed in a future version. " "Please use head_first=False for now instead.", - stacklevel=2) + stacklevel=2, + ) q, k, v, beta, g = map( - lambda x: rearrange(x, 'b h t ... -> b t h ...'), - (q, k, v, beta, g)) + lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g) + ) if not head_first and q.shape[1] < q.shape[2]: warnings.warn( f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " "Please verify your input tensor format matches the expected shape [B, T, H, ...].", - stacklevel=2) + stacklevel=2, + ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
- f"Please flatten variable-length inputs before processing.") - if initial_state is not None and initial_state.shape[0] != len( - cu_seqlens) - 1: + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: raise ValueError( f"The number of initial states is expected to be equal to the number of input sequences, " f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." ) if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 o, final_state = ChunkGatedDeltaRuleFunction.apply( - q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, - use_qk_l2norm_in_kernel) + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + ) if head_first: - o = rearrange(o, 'b t h ... -> b h t ...') + o = rearrange(o, "b t h ... -> b h t ...") return o, final_state diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py index 34006f87f457..817962d9c946 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py +++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py @@ -20,22 +20,26 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] -@triton.heuristics({ - 'USE_G': lambda args: args['g'] is not None, - 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, - 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, - 'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None, - 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, -}) +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) @triton.autotune( configs=[ - triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4] for num_stages in [2, 3, 4] for BV in [32, 64] + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] ], - key=['H', 'K', 'V', 'BT', 'USE_G'], + key=["H", "K", "V", "BT", "USE_G"], use_cuda_graph=use_cuda_graph, ) -@triton.jit(do_not_specialize=['T']) +@triton.jit(do_not_specialize=["T"]) def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( k, v, @@ -63,8 +67,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( i_v, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos NT = tl.cdiv(T, BT) boh = tl.load(chunk_offsets + i_n).to(tl.int32) @@ -100,87 +106,98 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( # load initial state if USE_INITIAL_STATE: - p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), - (1, 0)) + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) if K > 64: - p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), - (64, BV), (1, 0)) + p_h0_2 = tl.make_block_ptr( + h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) if K 
> 128: - p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), - (64, BV), (1, 0)) + p_h0_3 = tl.make_block_ptr( + h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) if K > 192: - p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), - (64, BV), (1, 0)) + p_h0_4 = tl.make_block_ptr( + h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) # main recurrence for i_t in range(NT): - p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), - (0, i_v * BV), (64, BV), (1, 0)) + p_h1 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0) + ) tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) if K > 64: - p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), - (64, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h2, - b_h2.to(p_h2.dtype.element_ty), - boundary_check=(0, 1)) + p_h2 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) if K > 128: - p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), - (128, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h3, - b_h3.to(p_h3.dtype.element_ty), - boundary_check=(0, 1)) + p_h3 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) if K > 192: - p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), - (192, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h4, - b_h4.to(p_h4.dtype.element_ty), - boundary_check=(0, 1)) + p_h4 = tl.make_block_ptr( + h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) - p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), - (BT, BV), (1, 0)) - p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), - (i_t * BT, i_v * BV), (BT, BV), - (1, 0)) if SAVE_NEW_VALUE else None + p_v = tl.make_block_ptr( + v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_v_new = ( + tl.make_block_ptr( + v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + if SAVE_NEW_VALUE + else None + ) b_v_new = tl.zeros([BT, BV], dtype=tl.float32) - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), - (BT, 64), (1, 0)) + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0) + ) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) if K > 64: - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), - (BT, 64), (1, 0)) + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0) + ) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) if K > 128: - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), - (BT, 64), (1, 0)) + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0) + ) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) if K > 192: - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), - (BT, 64), (1, 0)) + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0) + ) b_w = tl.load(p_w, boundary_check=(0, 1)) b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) b_v_new = -b_v_new + tl.load(p_v, 
boundary_check=(0, 1)) if SAVE_NEW_VALUE: - p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), - (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - tl.store(p_v_new, - b_v_new.to(p_v_new.dtype.element_ty), - boundary_check=(0, 1)) + p_v_new = tl.make_block_ptr( + v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + tl.store( + p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1) + ) if USE_G: m_t = (i_t * BT + tl.arange(0, BT)) < T last_idx = min((i_t + 1) * BT, T) - 1 b_g_last = tl.load(g + bos * H + last_idx * H + i_h) - p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ), - (i_t * BT, ), (BT, ), (0, )) - b_g = tl.load(p_g, boundary_check=(0, )) + p_g = tl.make_block_ptr( + g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + b_g = tl.load(p_g, boundary_check=(0,)) b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] b_g_last = exp(b_g_last) b_h1 = b_h1 * b_g_last @@ -191,49 +208,49 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( if K > 192: b_h4 = b_h4 * b_g_last b_v_new = b_v_new.to(k.dtype.element_ty) - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), - (64, BT), (0, 1)) + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1) + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_h1 += tl.dot(b_k, b_v_new) if K > 64: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), - (64, BT), (0, 1)) + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1) + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_h2 += tl.dot(b_k, b_v_new) if K > 128: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), - (64, BT), (0, 1)) + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1) + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_h3 += tl.dot(b_k, b_v_new) if K > 192: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), - (64, BT), (0, 1)) + p_k = tl.make_block_ptr( + k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1) + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_h4 += tl.dot(b_k, b_v_new) # epilogue if STORE_FINAL_STATE: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), - (1, 0)) + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) if K > 64: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), - (64, BV), (1, 0)) - tl.store(p_ht, - b_h2.to(p_ht.dtype.element_ty), - boundary_check=(0, 1)) + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) if K > 128: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), - (64, BV), (1, 0)) - tl.store(p_ht, - b_h3.to(p_ht.dtype.element_ty), - boundary_check=(0, 1)) + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) if K > 192: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), - (64, BV), (1, 0)) - tl.store(p_ht, - b_h4.to(p_ht.dtype.element_ty), - boundary_check=(0, 1)) + p_ht = tl.make_block_ptr( + ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0) + ) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) def chunk_gated_delta_rule_fwd_h( @@ -251,24 +268,31 @@ def chunk_gated_delta_rule_fwd_h( H = u.shape[-2] BT = chunk_size - chunk_indices = prepare_chunk_indices( 
- cu_seqlens, chunk_size) if cu_seqlens is not None else None + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is not None + else None + ) # N: the actual number of sequences in the batch with either equal or variable lengths if cu_seqlens is None: N, NT, chunk_offsets = B, triton.cdiv(T, BT), None else: - N, NT, chunk_offsets = len(cu_seqlens) - 1, len( - chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) assert K <= 256, "current kernel does not support head dimension larger than 256." h = k.new_empty(B, NT, H, K, V) - final_state = k.new_empty( - N, H, K, V, dtype=torch.float32) if output_final_state else None + final_state = ( + k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + ) v_new = torch.empty_like(u) if save_new_value else None def grid(meta): - return (triton.cdiv(V, meta['BV']), N * H) + return (triton.cdiv(V, meta["BV"]), N * H) chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( k=k, @@ -286,5 +310,6 @@ def grid(meta): Hg=Hg, K=K, V=V, - BT=BT) + BT=BT, + ) return h, v_new, final_state diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py index 332751a1860a..ae404a3615f6 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_o.py +++ b/vllm/model_executor/layers/fla/ops/chunk_o.py @@ -23,24 +23,23 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] -@triton.heuristics({ - 'USE_G': lambda args: args['g'] is not None, - 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None -}) +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) @triton.autotune( configs=[ - triton.Config({ - 'BK': BK, - 'BV': BV - }, - num_warps=num_warps, - num_stages=num_stages) for BK in BKV_LIST - for BV in BKV_LIST for num_warps in NUM_WARPS + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS for num_stages in [2, 3, 4] ], - key=['H', 'K', 'V', 'BT'], + key=["H", "K", "V", "BT"], ) -@triton.jit(do_not_specialize=['T']) +@triton.jit(do_not_specialize=["T"]) def chunk_fwd_kernel_o( q, k, @@ -67,10 +66,14 @@ def chunk_fwd_kernel_o( if IS_VARLEN: i_tg = i_t - i_n, i_t = tl.load(chunk_indices + i_t * 2).to( - tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos NT = tl.cdiv(T, BT) else: @@ -89,12 +92,15 @@ def chunk_fwd_kernel_o( b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), - (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), - (BK, BT), (0, 1)) - p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), - (BK, BV), (1, 0)) + p_q = tl.make_block_ptr( + q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_k = tl.make_block_ptr( + k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1) + ) + p_h = tl.make_block_ptr( + h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 
0) + ) # [BT, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BT] @@ -109,8 +115,8 @@ def chunk_fwd_kernel_o( if USE_G: g += bos * H + i_h - p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, )) - b_g = tl.load(p_g, boundary_check=(0, )) + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) b_o = b_o * exp(b_g)[:, None] b_A = b_A * exp(b_g[:, None] - b_g[None, :]) @@ -119,10 +125,12 @@ def chunk_fwd_kernel_o( m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) b_A = tl.where(m_A, b_A, 0) - p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), - (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), - (BT, BV), (1, 0)) + p_v = tl.make_block_ptr( + v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + p_o = tl.make_block_ptr( + o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) b_v = tl.load(p_v, boundary_check=(0, 1)) # to fix mma -> mma layout conversion @@ -132,30 +140,29 @@ def chunk_fwd_kernel_o( def chunk_fwd_o( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - h: torch.Tensor, - g: Optional[torch.Tensor] = None, # cumsum of log decay - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64) -> torch.Tensor: + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: B, T, Hg, K, V = *q.shape, v.shape[-1] H = v.shape[-2] - if FLA_GDN_FIX_BT: - BT = 64 - else: - BT = min(chunk_size, max(16, triton.next_power_of_2(T))) - chunk_indices = prepare_chunk_indices( - cu_seqlens, BT) if cu_seqlens is not None else None + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 o = torch.empty_like(v) def grid(meta): - return (triton.cdiv(V, meta['BV']), NT, B * H) + return (triton.cdiv(V, meta["BV"]), NT, B * H) chunk_fwd_kernel_o[grid]( q, diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py index d1adc6978f24..0da3f243901f 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py +++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py @@ -17,19 +17,22 @@ from .op import exp -@triton.heuristics({ - 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, - 'USE_G': lambda args: args['g_cumsum'] is not None -}) +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "USE_G": lambda args: args["g_cumsum"] is not None, + } +) @triton.autotune( configs=[ - triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) - for BK in [32, 64, 128] for num_warps in [2, 4, 8] + triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] ], - key=['H', 'K', 'BT', 'IS_VARLEN'], + key=["H", "K", "BT", "IS_VARLEN"], ) -@triton.jit(do_not_specialize=['T']) +@triton.jit(do_not_specialize=["T"]) def chunk_scaled_dot_kkt_fwd_kernel( k, beta, @@ -49,50 +52,63 @@ def chunk_scaled_dot_kkt_fwd_kernel( i_t, 
i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to( - tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T o_t = i_t * BT + tl.arange(0, BT) m_t = o_t < T - p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ), - (i_t * BT, ), (BT, ), (0, )) - b_beta = tl.load(p_beta, boundary_check=(0, )) + p_beta = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + b_beta = tl.load(p_beta, boundary_check=(0,)) b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), - (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), - (1, 0)) + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_kb = b_k * b_beta[:, None] b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) if USE_G: - p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ), - (i_t * BT, ), (BT, ), (0, )) - b_g = tl.load(p_g, boundary_check=(0, )) + p_g = tl.make_block_ptr( + g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + b_g = tl.load(p_g, boundary_check=(0,)) b_g_diff = b_g[:, None] - b_g[None, :] b_A = b_A * exp(b_g_diff) m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) b_A = tl.where(m_A, b_A, 0) - p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), - (i_t * BT, 0), (BT, BT), (1, 0)) + p_A = tl.make_block_ptr( + A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0) + ) tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) def chunk_scaled_dot_kkt_fwd( - k: torch.Tensor, - beta: torch.Tensor, - g_cumsum: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, - output_dtype: torch.dtype = torch.float32) -> torch.Tensor: + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: r""" Compute beta * K * K^T. 
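Note: for readers following the reformatted `chunk_scaled_dot_kkt_fwd_kernel` above, here is a minimal PyTorch sketch of the math one [BT, BT] tile computes: a beta-scaled K Kᵀ product, an optional gating factor from `g_cumsum`, and a strictly lower-triangular mask. Shapes and values are illustrative only, not the Triton implementation.

```python
# Reference for one chunk: A = strict_tril((beta * K) @ K.T * exp(g_i - g_j))
import torch

T, Kd = 64, 32                        # chunk length and head dimension (illustrative)
k = torch.randn(T, Kd)
beta = torch.rand(T)                  # per-token beta
g = torch.cumsum(-torch.rand(T), 0)   # cumulative log-decay (as from chunk_local_cumsum)

A = (beta[:, None] * k) @ k.T                # beta * K * K^T
A = A * torch.exp(g[:, None] - g[None, :])   # gating term, applied only when g_cumsum is given
A = torch.tril(A, diagonal=-1)               # kernel keeps the strictly lower triangle
```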
@@ -120,8 +136,9 @@ def chunk_scaled_dot_kkt_fwd( H = beta.shape[-1] BT = chunk_size - chunk_indices = prepare_chunk_indices( - cu_seqlens, BT) if cu_seqlens is not None else None + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( diff --git a/vllm/model_executor/layers/fla/ops/cumsum.py b/vllm/model_executor/layers/fla/ops/cumsum.py index 370a45fe1635..cfa2b3b48e70 100644 --- a/vllm/model_executor/layers/fla/ops/cumsum.py +++ b/vllm/model_executor/layers/fla/ops/cumsum.py @@ -20,12 +20,12 @@ BS_LIST = [32, 64] if check_shared_mem() else [16, 32] -@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) -@triton.autotune(configs=[ - triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8] -], - key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE']) -@triton.jit(do_not_specialize=['T']) +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) def chunk_local_cumsum_scalar_kernel( s, o, @@ -42,40 +42,47 @@ def chunk_local_cumsum_scalar_kernel( i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to( - tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T if HEAD_FIRST: - p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ), - (i_t * BT, ), (BT, ), (0, )) - p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ), - (i_t * BT, ), (BT, ), (0, )) + p_s = tl.make_block_ptr( + s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,) + ) + p_o = tl.make_block_ptr( + o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,) + ) else: - p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), - (BT, ), (0, )) - p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ), - (BT, ), (0, )) + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) # [BT] - b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32) + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) b_o = tl.cumsum(b_s, axis=0) if REVERSE: b_z = tl.sum(b_s, axis=0) b_o = -b_o + b_z[None] + b_s - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, )) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) -@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) -@triton.autotune(configs=[ - triton.Config({'BS': BS}, num_warps=num_warps) for BS in BS_LIST - for num_warps in [2, 4, 8] -], - key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE']) -@triton.jit(do_not_specialize=['T']) +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({"BS": BS}, 
num_warps=num_warps) + for BS in BS_LIST + for num_warps in [2, 4, 8] + ], + key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) def chunk_local_cumsum_vector_kernel( s, o, @@ -94,30 +101,58 @@ def chunk_local_cumsum_vector_kernel( i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to( - tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T o_i = tl.arange(0, BT) if REVERSE: - m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) else: - m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) if HEAD_FIRST: - p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), - (i_t * BT, i_s * BS), (BT, BS), (1, 0)) - p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), - (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_s = tl.make_block_ptr( + s + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) else: - p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), - (i_t * BT, i_s * BS), (BT, BS), (1, 0)) - p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), - (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_s = tl.make_block_ptr( + s + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) # [BT, BS] b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) b_o = tl.dot(m_s, b_s, allow_tf32=False) @@ -125,102 +160,122 @@ def chunk_local_cumsum_vector_kernel( def chunk_local_cumsum_scalar( - g: torch.Tensor, - chunk_size: int, - reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor: + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: if head_first: B, H, T = g.shape else: B, T, H = g.shape - assert chunk_size == 2**(chunk_size.bit_length() - - 1), "chunk_size must be a power of 2" + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), ( + "chunk_size must be a power of 2" + ) BT = chunk_size - chunk_indices = prepare_chunk_indices( - cu_seqlens, BT) if cu_seqlens is not None else None + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) grid = (NT, B * H) - chunk_local_cumsum_scalar_kernel[grid](g_org, - g, - cu_seqlens, - chunk_indices, - T=T, - B=B, - H=H, - BT=BT, - HEAD_FIRST=head_first, - REVERSE=reverse) + 
chunk_local_cumsum_scalar_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) return g def chunk_local_cumsum_vector( - g: torch.Tensor, - chunk_size: int, - reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor: + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: if head_first: B, H, T, S = g.shape else: B, T, H, S = g.shape BT = chunk_size - chunk_indices = prepare_chunk_indices( - cu_seqlens, chunk_size) if cu_seqlens is not None else None + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is not None + else None + ) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - assert chunk_size == 2**(chunk_size.bit_length() - - 1), "chunk_size must be a power of 2" + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), ( + "chunk_size must be a power of 2" + ) g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) def grid(meta): - return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) + return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) # keep cumulative normalizer in fp32 # this kernel is equivalent to # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) - chunk_local_cumsum_vector_kernel[grid](g_org, - g, - cu_seqlens, - chunk_indices, - T=T, - B=B, - H=H, - S=S, - BT=BT, - HEAD_FIRST=head_first, - REVERSE=reverse) + chunk_local_cumsum_vector_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) return g @input_guard -def chunk_local_cumsum(g: torch.Tensor, - chunk_size: int, - reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, - **kwargs) -> torch.Tensor: +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs, +) -> torch.Tensor: if not head_first and g.shape[1] < g.shape[2]: warnings.warn( f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " "Please verify your input tensor format matches the expected shape [B, T, H, ...].", - stacklevel=2) + stacklevel=2, + ) if cu_seqlens is not None: - assert g.shape[ - 0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + assert g.shape[0] == 1, ( + "Only batch size 1 is supported when cu_seqlens are provided" + ) if len(g.shape) == 3: - return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, - head_first, output_dtype) + return chunk_local_cumsum_scalar( + g, chunk_size, reverse, cu_seqlens, head_first, output_dtype + ) elif len(g.shape) == 4: - return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, - head_first, output_dtype) + return chunk_local_cumsum_vector( + g, chunk_size, reverse, cu_seqlens, head_first, output_dtype + ) else: - raise ValueError(f"Unsupported input shape {g.shape}. 
" - f"which should be (B, T, H, D) if `head_first=False` " - f"or (B, H, T, D) otherwise") + raise ValueError( + f"Unsupported input shape {g.shape}. " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py index 98437340fd24..fa10bdb36caa 100644 --- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -16,17 +16,15 @@ from .op import exp -@triton.heuristics({ - 'USE_INITIAL_STATE': - lambda args: args['h0'] is not None, - 'IS_VARLEN': - lambda args: args['cu_seqlens'] is not None, - "IS_CONTINUOUS_BATCHING": - lambda args: args['ssm_state_indices'] is not None, - "IS_SPEC_DECODING": - lambda args: args['num_accepted_tokens'] is not None, -}) -@triton.jit(do_not_specialize=['N', 'T']) +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) def fused_recurrent_gated_delta_rule_fwd_kernel( q, k, @@ -55,8 +53,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( stride_indices_tok: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, # whether to use initial state INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace - IS_BETA_HEADWISE: tl. - constexpr, # whether beta is headwise vector or scalar, + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, USE_QK_L2NORM_IN_KERNEL: tl.constexpr, IS_VARLEN: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, @@ -66,8 +63,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( i_n, i_hv = i_nh // HV, i_nh % HV i_h = i_hv // (HV // H) if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) all = T T = eos - bos else: @@ -102,8 +101,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 else: i_t = 0 - p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + - i_t).to(tl.int64) * stride_init_state_token + p_h0 = ( + h0 + + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( + tl.int64 + ) + * stride_init_state_token + ) else: p_h0 = h0 + bos * HV * K * V p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] @@ -136,8 +140,13 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( # keep the states for multi-query tokens if INPLACE_FINAL_STATE: - p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + - i_t).to(tl.int64) * stride_final_state_token + p_ht = ( + ht + + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( + tl.int64 + ) + * stride_final_state_token + ) else: p_ht = ht + (bos + i_t) * stride_final_state_token p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] @@ -228,21 +237,22 @@ def fused_recurrent_gated_delta_rule_fwd( class FusedRecurrentFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - inplace_final_state: bool = True, - cu_seqlens: 
Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - use_qk_l2norm_in_kernel: bool = False): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): o, final_state = fused_recurrent_gated_delta_rule_fwd( q=q.contiguous(), k=k.contiguous(), @@ -342,9 +352,10 @@ def fused_recurrent_gated_delta_rule( if cu_seqlens is not None and q.shape[0] != 1: raise ValueError( f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing.") + f"Please flatten variable-length inputs before processing." + ) if scale is None: - scale = k.shape[-1]**-0.5 + scale = k.shape[-1] ** -0.5 else: assert scale > 0, "scale must be positive" if beta is None: diff --git a/vllm/model_executor/layers/fla/ops/index.py b/vllm/model_executor/layers/fla/ops/index.py index 9eca32bc31a0..f023e1378bb8 100644 --- a/vllm/model_executor/layers/fla/ops/index.py +++ b/vllm/model_executor/layers/fla/ops/index.py @@ -20,20 +20,22 @@ def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: @tensor_cache -def prepare_chunk_indices(cu_seqlens: torch.LongTensor, - chunk_size: int) -> torch.LongTensor: - indices = torch.cat([ - torch.arange(n) - for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() - ]) - return torch.stack([indices.eq(0).cumsum(0) - 1, indices], - 1).to(cu_seqlens) +def prepare_chunk_indices( + cu_seqlens: torch.LongTensor, chunk_size: int +) -> torch.LongTensor: + indices = torch.cat( + [ + torch.arange(n) + for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() + ] + ) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) @tensor_cache -def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, - chunk_size: int) -> torch.LongTensor: - return torch.cat([ - cu_seqlens.new_tensor([0]), - triton.cdiv(prepare_lens(cu_seqlens), chunk_size) - ]).cumsum(-1) +def prepare_chunk_offsets( + cu_seqlens: torch.LongTensor, chunk_size: int +) -> torch.LongTensor: + return torch.cat( + [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)] + ).cumsum(-1) diff --git a/vllm/model_executor/layers/fla/ops/l2norm.py b/vllm/model_executor/layers/fla/ops/l2norm.py index ef9788ceaf20..315dd904523b 100644 --- a/vllm/model_executor/layers/fla/ops/l2norm.py +++ b/vllm/model_executor/layers/fla/ops/l2norm.py @@ -19,11 +19,12 @@ USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) -@triton.autotune(configs=[ - triton.Config({}, num_warps=num_warps) - for num_warps in [1, 2, 4, 8, 16, 32] -], - key=['D']) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32] + ], + key=["D"], +) @triton.jit def l2norm_fwd_kernel1( x, @@ -47,11 +48,14 @@ def l2norm_fwd_kernel1( tl.store(y + cols, b_y, mask=mask) -@triton.autotune(configs=[ - triton.Config({'BT': BT}, num_warps=num_warps) - for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST -], - key=['D']) +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + for BT in BT_LIST + ], + 
key=["D"], +) @triton.jit(do_not_specialize=["NB"]) def l2norm_fwd_kernel( x, @@ -85,9 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) -def l2norm_fwd(x: torch.Tensor, - eps: float = 1e-6, - output_dtype: Optional[torch.dtype] = None): +def l2norm_fwd( + x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None +): x_shape_og = x.shape x = x.view(-1, x.shape[-1]) # allocate output @@ -107,7 +111,7 @@ def l2norm_fwd(x: torch.Tensor, if not USE_DEFAULT_FLA_NORM: MBLOCK = 32 # M, N = x.shape - l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )]( + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( x, y, eps, @@ -120,7 +124,7 @@ def l2norm_fwd(x: torch.Tensor, NB = triton.cdiv(T, 2048) def grid(meta): - return (triton.cdiv(T, meta['BT']), ) + return (triton.cdiv(T, meta["BT"]),) l2norm_fwd_kernel[grid]( x, @@ -132,7 +136,7 @@ def grid(meta): BD=BD, ) else: - l2norm_fwd_kernel1[(T, )]( + l2norm_fwd_kernel1[(T,)]( x, y, eps=eps, diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py index a733c6c81e36..655cdb3f30eb 100644 --- a/vllm/model_executor/layers/fla/ops/layernorm_guard.py +++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py @@ -25,14 +25,16 @@ from .utils import input_guard -def rms_norm_ref(x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - upcast=True): +def rms_norm_ref( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + upcast=True, +): dtype = x.dtype weight = weight.float() bias = bias.float() if bias is not None else None @@ -43,12 +45,10 @@ def rms_norm_ref(x, x = x * F.silu(z) if group_size is None: rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = (x * rstd * weight) + bias if bias is not None else (x * rstd * - weight) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) else: x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) - rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + - eps) + rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps) out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight if bias is not None: out = out + bias @@ -57,10 +57,12 @@ def rms_norm_ref(x, return out.to(dtype) -@triton.heuristics({ - "HAS_BIAS": lambda args: args["B"] is not None, - "HAS_Z": lambda args: args["Z"] is not None, -}) +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + "HAS_Z": lambda args: args["Z"] is not None, + } +) @triton.jit def layer_norm_fwd_kernel( X, # pointer to the input @@ -97,17 +99,17 @@ def layer_norm_fwd_kernel( B += group * N # Compute mean and variance cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_Z and not NORM_BEFORE_GATE: z = tl.load(Z + cols, mask=cols < N).to(tl.float32) x *= z * tl.sigmoid(z) if not IS_RMS_NORM: mean = tl.sum(x, axis=0) / N tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.) + xbar = tl.where(cols < N, x - mean, 0.0) var = tl.sum(xbar * xbar, axis=0) / N else: - xbar = tl.where(cols < N, x, 0.) 
+ xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) tl.store(Rstd + row, rstd) @@ -145,64 +147,68 @@ def layer_norm_fwd( if z is not None: assert z.stride(-1) == 1 assert z.shape == (M, N) - assert weight.shape == (N, ) + assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 - assert bias.shape == (N, ) + assert bias.shape == (N,) # allocate output if out is not None: assert out.shape == x.shape else: out = torch.empty_like(x) assert out.stride(-1) == 1 - mean = torch.empty((ngroups * M, ), dtype=torch.float32, - device=x.device) if not is_rms_norm else None - rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + mean = ( + torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) if group_size > BLOCK_N: - raise RuntimeError( - "This layer norm doesn't support feature dim >= 64KB.") + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) - layer_norm_fwd_kernel[grid](x, - out, - weight, - bias, - z, - mean, - rstd, - x.stride(0), - out.stride(0), - z.stride(0) if z is not None else 0, - M, - group_size, - eps, - BLOCK_N=BLOCK_N, - NORM_BEFORE_GATE=norm_before_gate, - IS_RMS_NORM=is_rms_norm, - num_warps=num_warps) + layer_norm_fwd_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) return out, mean, rstd class LayerNormFn(torch.autograd.Function): - @input_guard @staticmethod - def forward(ctx, - x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - is_rms_norm=False): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) - """ + def forward( + ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, + ): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" x_shape_og = x.shape # reshape input data into 2D tensor @@ -236,31 +242,30 @@ def forward(ctx, return y.reshape(x_shape_og) -def layernorm_fn(x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - is_rms_norm=False): - return LayerNormFn.apply(x, weight, bias, z, eps, group_size, - norm_before_gate, is_rms_norm) +def layernorm_fn( + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + return LayerNormFn.apply( + x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm + ) -def rmsnorm_fn(x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True): - return LayerNormFn.apply(x, weight, bias, z, eps, group_size, - norm_before_gate, True) +def rmsnorm_fn( + x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True +): + return LayerNormFn.apply( + x, weight, bias, z, eps, group_size, norm_before_gate, True + ) class LayerNormGated(nn.Module): - def __init__( self, hidden_size, @@ -288,19 +293,19 
@@ def reset_parameters(self): torch.nn.init.zeros_(self.bias) def forward(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) - """ - return layernorm_fn(x, - self.weight, - self.bias, - z=z, - group_size=self.group_size, - eps=self.eps, - norm_before_gate=self.norm_before_gate) + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return layernorm_fn( + x, + self.weight, + self.bias, + z=z, + group_size=self.group_size, + eps=self.eps, + norm_before_gate=self.norm_before_gate, + ) class RMSNormGated(nn.Module): - def __init__( self, hidden_size, @@ -326,12 +331,13 @@ def reset_parameters(self): torch.nn.init.ones_(self.weight) def forward(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) - """ - return rmsnorm_fn(x, - self.weight, - self.bias, - z=z, - eps=self.eps, - group_size=self.group_size, - norm_before_gate=self.norm_before_gate) + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + return rmsnorm_fn( + x, + self.weight, + self.bias, + z=z, + eps=self.eps, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + ) diff --git a/vllm/model_executor/layers/fla/ops/op.py b/vllm/model_executor/layers/fla/ops/op.py index 8c29434ca106..ee2f4185a5df 100644 --- a/vllm/model_executor/layers/fla/ops/op.py +++ b/vllm/model_executor/layers/fla/ops/op.py @@ -11,7 +11,7 @@ from vllm.triton_utils import tl, tldevice, triton -if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': +if os.environ.get("FLA_USE_FAST_OPS", "0") == "1": div = tldevice.fast_dividef exp = tldevice.fast_expf log = tldevice.fast_logf @@ -28,7 +28,7 @@ def div_normal(x, y): log2 = tl.log2 -if not hasattr(tl, 'gather'): +if not hasattr(tl, "gather"): @triton.jit def gather(src, index, axis, _builder=None): diff --git a/vllm/model_executor/layers/fla/ops/solve_tril.py b/vllm/model_executor/layers/fla/ops/solve_tril.py index 97cb0d800d41..d30fea90aec3 100644 --- a/vllm/model_executor/layers/fla/ops/solve_tril.py +++ b/vllm/model_executor/layers/fla/ops/solve_tril.py @@ -17,15 +17,16 @@ from .utils import input_guard -@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5] + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] ], - key=['BT'], + key=["BT"], ) -@triton.jit(do_not_specialize=['T']) +@triton.jit(do_not_specialize=["T"]) def solve_tril_16x16_kernel( A, Ad, @@ -39,10 +40,14 @@ def solve_tril_16x16_kernel( i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to( - tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T @@ -51,13 +56,12 @@ def solve_tril_16x16_kernel( Ad = Ad + (bos * H + i_h) * 16 offset = (i_t * 16) % BT - p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, 
offset), - (16, 16), (1, 0)) - p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), - (1, 0)) + p_A = tl.make_block_ptr( + A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0) + ) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) - b_A = -tl.where( - tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) o_i = tl.arange(0, 16) for i in range(1, min(16, T - i_t * 16)): @@ -66,30 +70,45 @@ def solve_tril_16x16_kernel( mask = o_i == i b_A = tl.where(mask[:, None], b_a, b_A) b_A += o_i[:, None] == o_i[None, :] - tl.store(p_Ai, - b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) + tl.store( + p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) -@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5] + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] ], - key=['H', 'BT', 'IS_VARLEN'], + key=["H", "BT", "IS_VARLEN"], ) -@triton.jit(do_not_specialize=['T']) -def merge_16x16_to_32x32_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices, - T, H: tl.constexpr, BT: tl.constexpr, - IS_VARLEN: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to( - tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T @@ -98,55 +117,80 @@ def merge_16x16_to_32x32_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices, Ad += (bos * H + i_h) * 16 Ai += (bos * H + i_h) * 32 - p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), - (16, 16), (1, 0)) - p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), - (16, 16), (1, 0)) - p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), - (16, 16), (1, 0)) - p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), - (16, 16), (1, 0)) - p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), - (16, 16), (1, 0)) - p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), - (16, 16), (1, 0)) + p_A_21 = tl.make_block_ptr( + A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0) + ) + p_Ad_11 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0) + ) + p_Ad_22 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0) + ) + p_Ai_11 = tl.make_block_ptr( + Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0) + ) + p_Ai_22 = tl.make_block_ptr( + Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 
16), (1, 0) + ) + p_Ai_21 = tl.make_block_ptr( + Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0) + ) A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) - Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), - Ai_11, - input_precision='ieee') - tl.store(p_Ai_11, - Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_22, - Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_21, - Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) + Ai_21 = -tl.dot( + tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee" + ) + tl.store( + p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) -@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4, 8] for num_stages in [2, 3, 4, 5] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4, 5] ], - key=['H', 'BT', 'IS_VARLEN'], + key=["H", "BT", "IS_VARLEN"], ) -@triton.jit(do_not_specialize=['T']) -def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices, - T, H: tl.constexpr, BT: tl.constexpr, - IS_VARLEN: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to( - tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T @@ -155,26 +199,36 @@ def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices, Ad += (bos * H + i_h) * 16 Ai += (bos * H + i_h) * 64 - p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), - (16, 16), (1, 0)) - p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), - (16, 16), (1, 0)) - p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), - (16, 16), (1, 0)) - p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), - (16, 16), (1, 0)) - p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), - (16, 16), (1, 0)) - p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), - (16, 16), (1, 0)) - p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), - (16, 16), (1, 0)) - p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), - (16, 16), (1, 0)) - p_Ad_33 = tl.make_block_ptr(Ad, (T, 
16), (H * 16, 1), (i_t * 64 + 32, 0), - (16, 16), (1, 0)) - p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), - (16, 16), (1, 0)) + p_A_21 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0) + ) + p_A_32 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0) + ) + p_A_31 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0) + ) + p_A_43 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0) + ) + p_A_42 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0) + ) + p_A_41 = tl.make_block_ptr( + A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0) + ) + p_Ad_11 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0) + ) + p_Ad_22 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0) + ) + p_Ad_33 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0) + ) + p_Ad_44 = tl.make_block_ptr( + Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0) + ) A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) @@ -188,124 +242,174 @@ def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices, Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) - Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), - Ai_11, - input_precision='ieee') - Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), - Ai_22, - input_precision='ieee') - Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), - Ai_33, - input_precision='ieee') + Ai_21 = -tl.dot( + tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee" + ) + Ai_32 = -tl.dot( + tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee" + ) + Ai_43 = -tl.dot( + tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee" + ) - Ai_31 = -tl.dot(Ai_33, - tl.dot(A_31, Ai_11, input_precision='ieee') + - tl.dot(A_32, Ai_21, input_precision='ieee'), - input_precision='ieee') - Ai_42 = -tl.dot(Ai_44, - tl.dot(A_42, Ai_22, input_precision='ieee') + - tl.dot(A_43, Ai_32, input_precision='ieee'), - input_precision='ieee') - Ai_41 = -tl.dot(Ai_44, - tl.dot(A_41, Ai_11, input_precision='ieee') + - tl.dot(A_42, Ai_21, input_precision='ieee') + - tl.dot(A_43, Ai_31, input_precision='ieee'), - input_precision='ieee') + Ai_31 = -tl.dot( + Ai_33, + tl.dot(A_31, Ai_11, input_precision="ieee") + + tl.dot(A_32, Ai_21, input_precision="ieee"), + input_precision="ieee", + ) + Ai_42 = -tl.dot( + Ai_44, + tl.dot(A_42, Ai_22, input_precision="ieee") + + tl.dot(A_43, Ai_32, input_precision="ieee"), + input_precision="ieee", + ) + Ai_41 = -tl.dot( + Ai_44, + tl.dot(A_41, Ai_11, input_precision="ieee") + + tl.dot(A_42, Ai_21, input_precision="ieee") + + tl.dot(A_43, Ai_31, input_precision="ieee"), + input_precision="ieee", + ) - p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), - (16, 16), (1, 0)) - p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), - (16, 16), (1, 0)) - p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), - (16, 16), (1, 0)) - p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), - (16, 16), (1, 0)) - p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), - (16, 16), (1, 0)) 
- p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), - (16, 16), (1, 0)) - p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), - (16, 16), (1, 0)) - p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), - (16, 16), (1, 0)) - p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), - (16, 16), (1, 0)) - p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), - (16, 16), (1, 0)) - tl.store(p_Ai_11, - Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_22, - Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_33, - Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_44, - Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_21, - Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_31, - Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_32, - Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_41, - Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_42, - Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_43, - Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) + p_Ai_11 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0) + ) + p_Ai_22 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0) + ) + p_Ai_33 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0) + ) + p_Ai_44 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0) + ) + p_Ai_21 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0) + ) + p_Ai_31 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0) + ) + p_Ai_32 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0) + ) + p_Ai_41 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0) + ) + p_Ai_42 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0) + ) + p_Ai_43 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0) + ) + tl.store( + p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_33, + Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_44, + Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_31, + Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_32, + Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_41, + Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + 
) + tl.store( + p_Ai_42, + Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_43, + Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) fill_zeros = tl.zeros((16, 16), dtype=tl.float32) - p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), - (16, 16), (1, 0)) - p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), - (16, 16), (1, 0)) - p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), - (16, 16), (1, 0)) - p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), - (16, 16), (1, 0)) - p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), - (16, 16), (1, 0)) - p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), - (16, 16), (1, 0)) - tl.store(p_Ai_12, - fill_zeros.to(p_Ai_12.dtype.element_ty, - fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_13, - fill_zeros.to(p_Ai_13.dtype.element_ty, - fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_14, - fill_zeros.to(p_Ai_14.dtype.element_ty, - fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_23, - fill_zeros.to(p_Ai_23.dtype.element_ty, - fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_24, - fill_zeros.to(p_Ai_24.dtype.element_ty, - fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) - tl.store(p_Ai_34, - fill_zeros.to(p_Ai_34.dtype.element_ty, - fp_downcast_rounding="rtne"), - boundary_check=(0, 1)) + p_Ai_12 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0) + ) + p_Ai_13 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0) + ) + p_Ai_14 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0) + ) + p_Ai_23 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0) + ) + p_Ai_24 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0) + ) + p_Ai_34 = tl.make_block_ptr( + Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0) + ) + tl.store( + p_Ai_12, + fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_13, + fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_14, + fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_23, + fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_24, + fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_34, + fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) @input_guard -def solve_tril(A: torch.Tensor, - cu_seqlens: Optional[torch.Tensor] = None, - output_dtype: torch.dtype = torch.float) -> torch.Tensor: +def solve_tril( + A: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: """ Compute the inverse of the lower triangular matrix A should be strictly lower triangular, i.e., A.triu() == 0. 
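solve_tril builds the chunk-level inverse bottom-up: the 16x16 kernel runs forward substitution and re-adds the unit diagonal, and the merge kernels combine those pieces with the block-triangular identity visible above (for example Ai_21 = -Ai_22 @ A_21 @ Ai_11). As a dense sketch of the value it is expected to return in the fixed-length case, the helper below is hypothetical and assumes A is [B, T, H, BT] with one strictly lower-triangular BT x BT block per full chunk, so that the per-block result is the inverse of (I + A).

import torch


def solve_tril_ref(A: torch.Tensor) -> torch.Tensor:
    # Assumed layout: A [B, T, H, BT] with strictly lower-triangular blocks and
    # T an exact multiple of BT (the real kernels also handle partial chunks
    # and cu_seqlens, which this sketch does not).
    B, T, H, BT = A.shape
    assert T % BT == 0, "sketch assumes full chunks only"
    eye = torch.eye(BT, dtype=torch.float32, device=A.device)
    out = torch.empty_like(A, dtype=torch.float32)
    for start in range(0, T, BT):
        blk = A[:, start:start + BT].float().permute(0, 2, 1, 3)  # [B, H, BT, BT]
        inv = torch.linalg.inv(eye + blk)  # I + A is unit lower-triangular
        out[:, start:start + BT] = inv.permute(0, 2, 1, 3)
    return out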
@@ -325,15 +429,13 @@ def solve_tril(A: torch.Tensor, assert A.shape[-1] in [16, 32, 64] B, T, H, BT = A.shape - Ad = torch.empty(B, - T, - H, - 16, - device=A.device, - dtype=torch.float if BT != 16 else output_dtype) + Ad = torch.empty( + B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype + ) - chunk_indices = prepare_chunk_indices( - cu_seqlens, 16) if cu_seqlens is not None else None + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None + ) NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) solve_tril_16x16_kernel[NT, B * H]( A=A, @@ -348,9 +450,14 @@ def solve_tril(A: torch.Tensor, return Ad Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) - merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel - chunk_indices = prepare_chunk_indices( - cu_seqlens, BT) if cu_seqlens is not None else None + merge_fn = ( + merge_16x16_to_32x32_inverse_kernel + if BT == 32 + else merge_16x16_to_64x64_inverse_kernel + ) + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) merge_fn[NT, B * H]( A=A, diff --git a/vllm/model_executor/layers/fla/ops/utils.py b/vllm/model_executor/layers/fla/ops/utils.py index 7fd90cee45d0..07124f33f1e6 100644 --- a/vllm/model_executor/layers/fla/ops/utils.py +++ b/vllm/model_executor/layers/fla/ops/utils.py @@ -27,8 +27,7 @@ SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) -def tensor_cache( - fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: """ A decorator that caches the most recent results of a function with tensor inputs. @@ -52,12 +51,19 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: nonlocal cache_entries, cache_size for i, entry in enumerate(cache_entries): last_args, last_kwargs, last_result = entry - if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \ - and all(a is b for a, b in zip(args, last_args)) \ - and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): - cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [ - (args, kwargs, last_result) - ] + if ( + len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all( + k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items() + ) + ): + cache_entries = ( + cache_entries[:i] + + cache_entries[i + 1 :] + + [(args, kwargs, last_result)] + ) return last_result result = fn(*args, **kwargs) @@ -70,16 +76,16 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper -def input_guard( - fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: """ A decorator to make sure all input tensors are contiguous and set the device based on input tensors. 
""" @functools.wraps(fn) def wrapper(*args, **kwargs): - contiguous_args = (i if not isinstance(i, torch.Tensor) else - i.contiguous() for i in args) + contiguous_args = ( + i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args + ) contiguous_kwargs = { k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items() @@ -112,11 +118,11 @@ def get_available_device() -> str: try: return triton.runtime.driver.active.get_current_target().backend except BaseException: - return 'cpu' + return "cpu" @functools.cache -def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']: +def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: device = get_available_device() mapping = { "cuda": "nvidia", @@ -130,27 +136,28 @@ def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']: # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. # Therefore, we need to check the triton backend to determine the actual GPU vendor. -device = get_available_device() if get_available_device() != 'hip' else 'cuda' +device = get_available_device() if get_available_device() != "hip" else "cuda" device_torch_lib = getattr(torch, device) device_platform = _check_platform() -is_amd = (device_platform == 'amd') -is_intel = (device_platform == 'intel') -is_nvidia = (device_platform == 'nvidia') -is_intel_alchemist = (is_intel - and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0)) -is_nvidia_hopper = (is_nvidia - and ('NVIDIA H' in torch.cuda.get_device_name(0) - or torch.cuda.get_device_capability()[0] >= 9)) -use_cuda_graph = (is_nvidia - and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1') +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) + or torch.cuda.get_device_capability()[0] >= 9 +) +use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" def get_all_max_shared_mem(): try: return [ - triton.runtime.driver.active.utils.get_device_properties(i) - ['max_shared_mem'] for i in range(device_torch_lib.device_count()) + triton.runtime.driver.active.utils.get_device_properties(i)[ + "max_shared_mem" + ] + for i in range(device_torch_lib.device_count()) ] except BaseException: return [-1] diff --git a/vllm/model_executor/layers/fla/ops/wy_fast.py b/vllm/model_executor/layers/fla/ops/wy_fast.py index 70374eb65064..b628a90e843f 100644 --- a/vllm/model_executor/layers/fla/ops/wy_fast.py +++ b/vllm/model_executor/layers/fla/ops/wy_fast.py @@ -17,56 +17,100 @@ from .index import prepare_chunk_indices -@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] ], - key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], ) -@triton.jit(do_not_specialize=['T']) -def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices, - T, H: tl.constexpr, Hg: tl.constexpr, - K: tl.constexpr, V: tl.constexpr, - BT: tl.constexpr, BK: tl.constexpr, 
- BV: tl.constexpr, IS_VARLEN: tl.constexpr): +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to( - tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to( - tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T - p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ), - (i_t * BT, ), (BT, ), (0, )) - p_g = tl.make_block_ptr(g + (bos * H + i_h), (T, ), (H, ), (i_t * BT, ), - (BT, ), (0, )) - p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), - (i_t * BT, 0), (BT, BT), (1, 0)) - b_beta = tl.load(p_beta, boundary_check=(0, )) + p_beta = tl.make_block_ptr( + beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr( + A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0) + ) + b_beta = tl.load(p_beta, boundary_check=(0,)) b_A = tl.load(p_A, boundary_check=(0, 1)) - b_g = tl.exp(tl.load(p_g, boundary_check=(0, ))) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) for i_v in range(tl.cdiv(V, BV)): - p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), - (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), - (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) b_v = tl.load(p_v, boundary_check=(0, 1)) b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) b_u = tl.dot(b_A, b_vb, allow_tf32=False) tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), - (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), - (1, 0)) - p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), - (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) b_w = tl.dot(b_A, b_kb) @@ -85,8 +129,9 @@ def recompute_w_u_fwd( H = v.shape[-2] BT = A.shape[-1] - chunk_indices = prepare_chunk_indices( - cu_seqlens, BT) if cu_seqlens is not None else None + chunk_indices = ( + prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + ) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) BK = 64 BV = 64 diff --git 
a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 75f56cd01a4e..56ffaf861ac7 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -6,10 +6,15 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize) + FusedMoEActivationFormat, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, +) from vllm.model_executor.layers.fused_moe.utils import activation_without_mul from vllm.triton_utils import HAS_TRITON @@ -46,21 +51,31 @@ def get_config() -> Optional[dict[str, Any]]: # import to register the custom ops import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) + BatchedDeepGemmExperts, + ) from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) + BatchedTritonOrDeepGemmExperts, + ) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4, - cutlass_moe_fp8) - from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts) + CutlassBatchedExpertsFp8, + CutlassExpertsFp8, + cutlass_moe_fp4, + cutlass_moe_fp8, + ) + from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) + BatchedTritonExperts, + ) from vllm.model_executor.layers.fused_moe.fused_moe import ( - TritonExperts, fused_experts, fused_topk, get_config_file_name, - grouped_topk) + TritonExperts, + fused_experts, + fused_topk, + get_config_file_name, + grouped_topk, + ) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, + ) __all__ += [ "fused_topk", @@ -82,8 +97,7 @@ def get_config() -> Optional[dict[str, Any]]: # Some model classes directly use the custom ops. Add placeholders # to avoid import errors. 
def _raise_exception(method: str): - raise NotImplementedError( - f"{method} is not implemented as lack of triton.") + raise NotImplementedError(f"{method} is not implemented as lack of triton.") fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk") fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts") diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 2017a01475b2..f30ebec76c67 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -8,15 +8,14 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - deep_gemm_block_shape) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, +) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, - is_deep_gemm_e8m0_used) +from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked, is_deep_gemm_e8m0_used logger = init_logger(__name__) @@ -73,17 +72,14 @@ def _silu_mul_fp8_quant_deep_gemm( base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h base_gate_offset = base_input_offset + cols * stride_i_h base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h - base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + - cols * stride_yq_h) + base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h base_ys_offset = e * stride_ys_e + g * stride_ys_g for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): - gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t, - mask=mask, - other=0.0).to(tl.float32) - up = tl.load(input_ptr + base_up_offset + t * stride_i_t, - mask=mask, - other=0.0) + gate = tl.load( + input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0 + ).to(tl.float32) + up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0) gate = gate * (1.0 / (1.0 + tl.exp(-gate))) y = gate * up @@ -120,8 +116,7 @@ def silu_mul_fp8_quant_deep_gemm_cuda( assert group_size == 128, "H must be divisible by 8" assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E - tokens_per_expert = tokens_per_expert.to(device=y.device, - dtype=torch.int32) + tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32) fp8_dtype = torch.float8_e4m3fn y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) @@ -129,10 +124,12 @@ def silu_mul_fp8_quant_deep_gemm_cuda( stride_ys_e = T * G stride_ys_t = 1 stride_ys_g = T - y_s = torch.empty_strided((E, T, G), - (stride_ys_e, stride_ys_t, stride_ys_g), - dtype=torch.float32, - device=y.device) + y_s = torch.empty_strided( + (E, T, G), + (stride_ys_e, stride_ys_t, stride_ys_g), + dtype=torch.float32, + device=y.device, + ) use_ue8m0 = is_deep_gemm_e8m0_used() @@ -146,17 +143,16 @@ def silu_mul_fp8_quant_deep_gemm_cuda( # We never want to launch more than Tx number of threads # This computes the clip. 
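The clip described in the comment above bounds launch parallelism by an empirical cap and rounds it down to a power of two, never exceeding the token count. A small standalone version of the expression that follows (the function name is illustrative):

from math import log2

def clip_num_parallel_tokens(requested: int, T: int, max_empirical_parallelism: int) -> int:
    # Round min(requested, T) down to a power of two, cap it empirically,
    # and always keep at least one parallel token-worker.
    return max(1, min(max_empirical_parallelism, 2 ** int(log2(min(requested, T)))))

clip_num_parallel_tokens(requested=37, T=500, max_empirical_parallelism=64)  # -> 32
clip_num_parallel_tokens(requested=8, T=3, max_empirical_parallelism=64)     # -> 2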
num_parallel_tokens = max( - 1, - min(max_empirical_parallelism, 2**int(log2(min(num_parallel_tokens, - T))))) + 1, min(max_empirical_parallelism, 2 ** int(log2(min(num_parallel_tokens, T)))) + ) cuda_arch = current_platform.get_device_capability( - device_id=y.device.index).to_int() + device_id=y.device.index + ).to_int() if cuda_arch >= 80: - torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(y, tokens_per_expert, - y_q, y_s, group_size, - use_ue8m0, - num_parallel_tokens) + torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda( + y, tokens_per_expert, y_q, y_s, group_size, use_ue8m0, num_parallel_tokens + ) else: # Default to triton if not on cuda or if arch is too old y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device) @@ -165,7 +161,7 @@ def silu_mul_fp8_quant_deep_gemm_cuda( # Static grid over experts and H-groups. # A loop inside the kernel handles the token dim - grid = (E * G, ) + grid = (E * G,) # strides (elements) stride_i_e, stride_i_t, stride_i_h = y.stride() stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride() @@ -214,7 +210,6 @@ def silu_mul_fp8_quant_deep_gemm_cuda( class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, max_num_tokens: int, @@ -233,10 +228,12 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) def supports_chunking(self) -> bool: return False @@ -266,10 +263,10 @@ def workspace_shapes( # end up sending their tokens. This needs to be fixed. num_dispatchers = self.num_dispatchers num_experts = local_num_experts - max_num_tokens = a.size( - 0) if self.max_num_tokens is None else self.max_num_tokens - workspace13 = (num_experts, max_num_tokens * num_dispatchers, - max(K, N)) + max_num_tokens = ( + a.size(0) if self.max_num_tokens is None else self.max_num_tokens + ) + workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) output = (num_experts, max_num_tokens * num_dispatchers, K) return (workspace13, workspace2, output, a.dtype) @@ -304,7 +301,8 @@ def apply( assert w2.size(1) == K E, max_num_tokens, N, K, top_k_num = self.moe_problem_size( - hidden_states, w1, w2, topk_ids) + hidden_states, w1, w2, topk_ids + ) workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) @@ -312,11 +310,18 @@ def apply( # for the M expectation of each batch, correctly setting this value # may lead to better performance. 
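For reference, the silu_mul_fp8_quant_deep_gemm path in this hunk fuses SiLU(gate) * up with per-128-element-group fp8 quantization. A minimal unfused PyTorch sketch, assuming float8_e4m3fn support, H divisible by the group size, and the usual max-abs / e4m3-max (448) scaling:

import torch

def silu_mul_fp8_quant_ref(y: torch.Tensor, group_size: int = 128, fp8_max: float = 448.0):
    # y: (E, T, 2*H) with gate in the first H columns and up in the last H columns.
    E, T, H2 = y.shape
    H = H2 // 2
    gate = y[..., :H].float()
    up = y[..., H:].float()
    act = gate * torch.sigmoid(gate) * up                        # SiLU(gate) * up
    groups = act.view(E, T, H // group_size, group_size)         # assumes H % group_size == 0
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) / fp8_max
    q = (groups / scales).clamp(-fp8_max, fp8_max).view(E, T, H)
    return q.to(torch.float8_e4m3fn), scales.squeeze(-1)         # (E, T, H), (E, T, H // group_size)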
expected_m = max_num_tokens - fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale), - workspace1, expert_num_tokens, expected_m) + fp8_m_grouped_gemm_nt_masked( + (a1q, a1q_scale), + (w1, self.w1_scale), + workspace1, + expert_num_tokens, + expected_m, + ) a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda( - workspace1, expert_num_tokens) + workspace1, expert_num_tokens + ) - fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale), - output, expert_num_tokens, expected_m) + fp8_m_grouped_gemm_nt_masked( + (a2q, a2q_scale), (w2, self.w2_scale), output, expert_num_tokens, expected_m + ) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index c3c4f4a5d190..d268f70477f4 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -6,16 +6,14 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts) + BatchedDeepGemmExperts, +) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - deep_gemm_block_shape) -from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape +from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, max_num_tokens: int, @@ -31,27 +29,37 @@ def __init__( quant_config=self.quant_config, ) - self.allow_deep_gemm = (allow_deep_gemm - and self.quant_config.use_fp8_w8a8 and - self.block_shape == deep_gemm_block_shape()) + self.allow_deep_gemm = ( + allow_deep_gemm + and self.quant_config.use_fp8_w8a8 + and self.block_shape == deep_gemm_block_shape() + ) - self.batched_deep_gemm_experts = BatchedDeepGemmExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=num_dispatchers, - quant_config=self.quant_config, - ) if self.allow_deep_gemm else None + self.batched_deep_gemm_experts = ( + BatchedDeepGemmExperts( + max_num_tokens=max_num_tokens, + num_dispatchers=num_dispatchers, + quant_config=self.quant_config, + ) + if self.allow_deep_gemm + else None + ) - assert (self.batched_deep_gemm_experts is not None - or self.batched_triton_experts is not None) + assert ( + self.batched_deep_gemm_experts is not None + or self.batched_triton_experts is not None + ) @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: if self.batched_triton_experts is not None: - assert (self.batched_deep_gemm_experts is None - or self.batched_deep_gemm_experts.activation_formats - == self.batched_triton_experts.activation_formats) + assert ( + self.batched_deep_gemm_experts is None + or self.batched_deep_gemm_experts.activation_formats + == self.batched_triton_experts.activation_formats + ) return self.batched_triton_experts.activation_formats else: assert self.batched_deep_gemm_experts is not None @@ -60,14 +68,16 @@ def activation_formats( def supports_chunking(self) -> bool: bdge = self.batched_deep_gemm_experts bte = self.batched_triton_experts - return ((bdge is None or bdge.supports_chunking()) - and (bte is None or bte.supports_chunking())) + return (bdge is 
None or bdge.supports_chunking()) and ( + bte is None or bte.supports_chunking() + ) def supports_expert_map(self) -> bool: bdge = self.batched_deep_gemm_experts bte = self.batched_triton_experts - return ((bdge is None or bdge.supports_expert_map()) - and (bte is None or bte.supports_expert_map())) + return (bdge is None or bdge.supports_expert_map()) and ( + bte is None or bte.supports_expert_map() + ) def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: bdge = self.batched_deep_gemm_experts @@ -80,7 +90,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: if is_bdge_war and is_bte_war: assert bdge_war == bte_war, ( "Both implementations should agree on WeightAndReduce impls. " - f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}") + f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}" + ) if bdge_war is not None: return bdge_war @@ -106,13 +117,29 @@ def workspace_shapes( if self.allow_deep_gemm: assert self.batched_deep_gemm_experts is not None return self.batched_deep_gemm_experts.workspace_shapes( - a, aq, M, N, K, topk, global_num_experts, local_num_experts, - expert_tokens_metadata) + a, + aq, + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_metadata, + ) else: assert self.batched_triton_experts is not None return self.batched_triton_experts.workspace_shapes( - a, aq, M, N, K, topk, global_num_experts, local_num_experts, - expert_tokens_metadata) + a, + aq, + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_metadata, + ) def apply( self, @@ -132,10 +159,26 @@ def apply( expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - experts = (self.batched_deep_gemm_experts - if self.allow_deep_gemm else self.batched_triton_experts) + experts = ( + self.batched_deep_gemm_experts + if self.allow_deep_gemm + else self.batched_triton_experts + ) assert experts is not None - experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids, - activation, global_num_experts, expert_map, a1q_scale, - a2_scale, workspace13, workspace2, expert_tokens_meta, - apply_router_weight_on_input) + experts.apply( + output, + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + activation, + global_num_experts, + expert_map, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_tokens_meta, + apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 34bfe1c16aac..ae2fad1cd0d7 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -9,8 +9,7 @@ from vllm.config import ParallelConfig from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.utils import cdiv, has_triton_kernels from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe @@ -22,7 +21,8 @@ except ImportError: logger.error( "Failed to import Triton kernels. Please make sure your triton " - "version is compatible.") + "version is compatible." 
+ ) def _get_config_dtype_str( @@ -163,8 +163,9 @@ class FusedMoEQuantConfig: _w2: FusedMoEQuantDesc def __post_init__(self): - assert (not self.per_act_token_quant - or self.block_shape is None), "illegal quantization" + assert not self.per_act_token_quant or self.block_shape is None, ( + "illegal quantization" + ) # # Convenience accessors for various properties. @@ -196,9 +197,11 @@ def is_per_tensor(self) -> bool: @property def block_shape(self) -> Optional[list[int]]: - if (self._a1.shape is not None - and self._a1.shape != GroupShape.PER_TENSOR - and self._a1.shape != GroupShape.PER_TOKEN): + if ( + self._a1.shape is not None + and self._a1.shape != GroupShape.PER_TENSOR + and self._a1.shape != GroupShape.PER_TOKEN + ): return [self._a1.shape.row, self._a1.shape.col] else: return None @@ -209,8 +212,7 @@ def is_block_quantized(self) -> bool: @property def a1_scale(self) -> Optional[torch.Tensor]: - assert self._a1.scale is None or isinstance(self._a1.scale, - torch.Tensor) + assert self._a1.scale is None or isinstance(self._a1.scale, torch.Tensor) return self._a1.scale @property @@ -219,8 +221,7 @@ def a1_gscale(self) -> Optional[torch.Tensor]: @property def a2_scale(self) -> Optional[torch.Tensor]: - assert self._a2.scale is None or isinstance(self._a2.scale, - torch.Tensor) + assert self._a2.scale is None or isinstance(self._a2.scale, torch.Tensor) return self._a2.scale @property @@ -229,8 +230,7 @@ def a2_gscale(self) -> Optional[torch.Tensor]: @property def w1_scale(self) -> Optional[torch.Tensor]: - assert self._w1.scale is None or isinstance(self._w1.scale, - torch.Tensor) + assert self._w1.scale is None or isinstance(self._w1.scale, torch.Tensor) return self._w1.scale @property @@ -243,8 +243,7 @@ def w1_bias(self) -> Optional[torch.Tensor]: @property def w1_precision(self) -> Optional["PrecisionConfig"]: - assert self._w1.scale is None or isinstance(self._w1.scale, - PrecisionConfig) + assert self._w1.scale is None or isinstance(self._w1.scale, PrecisionConfig) return self._w1.scale @property @@ -253,8 +252,7 @@ def g1_alphas(self) -> Optional[torch.Tensor]: @property def w2_scale(self) -> Optional[torch.Tensor]: - assert self._w2.scale is None or isinstance(self._w2.scale, - torch.Tensor) + assert self._w2.scale is None or isinstance(self._w2.scale, torch.Tensor) return self._w2.scale @property @@ -267,8 +265,7 @@ def w2_bias(self) -> Optional[torch.Tensor]: @property def w2_precision(self) -> Optional["PrecisionConfig"]: - assert self._w2.scale is None or isinstance(self._w2.scale, - PrecisionConfig) + assert self._w2.scale is None or isinstance(self._w2.scale, PrecisionConfig) return self._w2.scale @property @@ -285,19 +282,19 @@ def use_int8_w8a8(self) -> bool: @property def use_int8_w8a16(self) -> bool: - return (self._a1.dtype is None and self._w1.dtype == torch.int8) + return self._a1.dtype is None and self._w1.dtype == torch.int8 @property def use_int4_w4a16(self) -> bool: - return (self._a1.dtype is None and self._w1.dtype == "int4") + return self._a1.dtype is None and self._w1.dtype == "int4" @property def use_mxfp4_w4a4(self) -> bool: - return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4") + return self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4" @property def use_mxfp4_w4a16(self) -> bool: - return (self._a1.dtype is None and self._w1.dtype == "mxfp4") + return self._a1.dtype is None and self._w1.dtype == "mxfp4" @property def use_nvfp4_w4a4(self) -> bool: @@ -398,19 +395,23 @@ def make( - w1_zp: Optional w1 zero points for int4/int8 
quantization. - w2_zp: Optional w2 zero points for int4/int8 quantization. """ - assert (not isinstance(quant_dtype, str) or quant_dtype == "nvfp4" - or quant_dtype == "mxfp4") - a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype, - per_act_token_quant, - per_out_ch_quant, - block_shape) + assert ( + not isinstance(quant_dtype, str) + or quant_dtype == "nvfp4" + or quant_dtype == "mxfp4" + ) + a_shape, w_shape = _quant_flags_to_group_shape( + quant_dtype, per_act_token_quant, per_out_ch_quant, block_shape + ) quant_config = FusedMoEQuantConfig( _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale), _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale), - _w1=FusedMoEQuantDesc(quant_dtype, w_shape, w1_scale, g1_alphas, - w1_zp, w1_bias), - _w2=FusedMoEQuantDesc(quant_dtype, w_shape, w2_scale, g2_alphas, - w2_zp, w2_bias), + _w1=FusedMoEQuantDesc( + quant_dtype, w_shape, w1_scale, g1_alphas, w1_zp, w1_bias + ), + _w2=FusedMoEQuantDesc( + quant_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias + ), ) assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_out_ch_quant == per_out_ch_quant @@ -430,14 +431,16 @@ def fp8_w8a8_moe_quant_config( """ Construct a quant config for fp8 activations and fp8 weights. """ - return FusedMoEQuantConfig.make(torch.float8_e4m3fn, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - per_act_token_quant=per_act_token_quant, - per_out_ch_quant=per_out_ch_quant, - block_shape=block_shape) + return FusedMoEQuantConfig.make( + torch.float8_e4m3fn, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + ) def int8_w8a8_moe_quant_config( @@ -463,10 +466,11 @@ def int8_w8a8_moe_quant_config( def mxfp4_w4a16_moe_quant_config( - w1_scale: Union[torch.Tensor, "PrecisionConfig"], - w2_scale: Union[torch.Tensor, "PrecisionConfig"], - w1_bias: Optional[torch.Tensor] = None, - w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig: + w1_scale: Union[torch.Tensor, "PrecisionConfig"], + w2_scale: Union[torch.Tensor, "PrecisionConfig"], + w1_bias: Optional[torch.Tensor] = None, + w2_bias: Optional[torch.Tensor] = None, +) -> FusedMoEQuantConfig: """ Construct a quant config for unquantized activations and mxfp4 weights. """ @@ -605,22 +609,26 @@ def use_all2all_kernels(self): @property def use_pplx_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "pplx") + return self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "pplx" @property def use_deepep_ht_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") + return ( + self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + ) @property def use_deepep_ll_kernels(self): - return (self.use_all2all_kernels - and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + return ( + self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" + ) @staticmethod - def make(tp_size_: int, dp_size_: int, - vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + def make( + tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig + ) -> "FusedMoEParallelConfig": """ Determine MoE parallel configuration. 
Based on the input `tp_size_`, `dp_size_` and vllm's parallel config, determine what @@ -700,34 +708,37 @@ def flatten_tp_across_dp(dp_rank: int): tp_rank = dp_rank * tp_size_ + tp_rank return tp_size, tp_rank - use_ep = (dp_size_ * tp_size_ > 1 - and vllm_parallel_config.enable_expert_parallel) + use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel dp_size = dp_size_ dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 tp_size, tp_rank = flatten_tp_across_dp(dp_rank) if not use_ep: - return FusedMoEParallelConfig(tp_size=tp_size, - tp_rank=tp_rank, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=1, - ep_rank=0, - use_ep=False) + return FusedMoEParallelConfig( + tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + use_ep=False, + ) # DP + EP / TP + EP / DP + TP + EP assert use_ep # In EP, each device owns a set of experts fully. There is no tensor # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. ep_size = tp_size ep_rank = tp_rank - return FusedMoEParallelConfig(tp_size=1, - tp_rank=0, - dp_size=dp_size, - dp_rank=dp_rank, - ep_size=ep_size, - ep_rank=ep_rank, - use_ep=True) + return FusedMoEParallelConfig( + tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True, + ) # Adapted from pplx-kernels tests/all_to_all_utils.py @@ -749,8 +760,9 @@ class FusedMoEConfig: def __post_init__(self): if self.dp_size > 1: - logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d", - self.max_num_tokens) + logger.debug_once( + "Using FusedMoEConfig::max_num_tokens=%d", self.max_num_tokens + ) assert self.max_num_tokens > 0 @@ -799,6 +811,8 @@ def use_flashinfer_cutlass_kernels(self): """ Whether to use FlashInfer cutlass kernels for NVFP4 MoE. 
""" - return (envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe() - and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput") + return ( + envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput" + ) diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index 114f349538fb..b62817d0115f 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -22,10 +22,9 @@ def grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" gating_output = gating_output.float() if scoring_func == "softmax": @@ -39,29 +38,30 @@ def grouped_topk( if e_score_correction_bias is not None: original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = (scores.view(num_token, num_expert_group, - -1).topk(2, dim=-1)[0].sum(dim=-1)) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, - sorted=False)[1] # [n, top_k_group] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), - float("-inf")) # [n, e] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] if e_score_correction_bias is not None: topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] topk_weights = original_scores.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -87,21 +87,22 @@ def select_experts( if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None - return grouped_topk(hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias) + return grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + 
routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + ) elif custom_routing_function is None: assert scoring_func == "softmax" - topk_logit_vals, topk_idx = torch.topk(router_logits, - k=top_k, - dim=-1, - sorted=False) + topk_logit_vals, topk_idx = torch.topk( + router_logits, k=top_k, dim=-1, sorted=False + ) if renormalize: topk_vals = torch.softmax(topk_logit_vals, dim=-1) else: @@ -109,16 +110,18 @@ def select_experts( topk_vals = (topk_logit_vals - logZ).exp() return topk_vals.to(torch.float32), topk_idx.to(torch.int32) else: - return custom_routing_function(hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize) + return custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) class IPEXFusedMOE: - def __init__(self, layer: torch.nn.Module) -> None: import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.w13_weight, layer.w2_weight, @@ -146,8 +149,9 @@ def __call__( ) -> torch.Tensor: assert activation == "silu", f"{activation} is not supported." assert not apply_router_weight_on_input - assert routed_scaling_factor == 1.0, \ + assert routed_scaling_factor == 1.0, ( f"routed_scaling_factor {routed_scaling_factor} is not supported." + ) return layer.ipex_fusion( x, use_grouped_topk, @@ -163,7 +167,6 @@ def __call__( class SGLFusedMOE: - def __init__(self, layer: torch.nn.Module) -> None: pass @@ -222,7 +225,6 @@ def __call__( class CPUFusedMOE: - def __init__(self, layer: torch.nn.Module) -> None: pass @@ -289,12 +291,15 @@ def __call__( outputs.append(expert_out) start_idx = end_idx - outs = torch.cat(outputs, - dim=0) if len(outputs) else sorted_tokens.new_empty(0) + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) new_x = torch.empty_like(outs) new_x[idxs] = outs - final_out = (new_x.view( - *topk_ids.shape, -1).type(topk_weights.dtype).mul_( - topk_weights.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype)) + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weights.dtype) + .mul_(topk_weights.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) return final_out diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 1578e4822765..d3fed9332958 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" CUTLASS based Fused MoE kernels.""" +"""CUTLASS based Fused MoE kernels.""" + from typing import Callable, Optional import torch @@ -10,13 +11,17 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - moe_permute, moe_unpermute) + moe_permute, + moe_unpermute, +) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, - _resize_cache) + TopKWeightAndReduceDelegate, + TopKWeightAndReduceNoOP, +) +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize, 
_resize_cache from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -56,20 +61,28 @@ def run_cutlass_moe_fp8( assert w2.dtype == torch.float8_e4m3fn assert a1q.size(-1) == w1.size(2), "Hidden size mismatch w1" assert w1.size(1) == w2.size(2) * 2, "Hidden size mismatch w2" - assert w1_scale.dim() == 1 or w1_scale.size( - 1) == 1 or w1_scale.shape[1] == w1.size(1), "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.size( - 1) == 1 or w2_scale.shape[1] == w2.size(1), "W2 scale shape mismatch" + assert ( + w1_scale.dim() == 1 or w1_scale.size(1) == 1 or w1_scale.shape[1] == w1.size(1) + ), "W1 scale shape mismatch" + assert ( + w2_scale.dim() == 1 or w2_scale.size(1) == 1 or w2_scale.shape[1] == w2.size(1) + ), "W2 scale shape mismatch" assert w1.size(0) == w2.size(0), "Expert number mismatch" - assert a1q_scale is None or a1q_scale.dim() == 0 or a1q_scale.size( - 0) == 1 or a1q_scale.size( - 0) == a1q.shape[0], "Input scale shape mismatch" + assert ( + a1q_scale is None + or a1q_scale.dim() == 0 + or a1q_scale.size(0) == 1 + or a1q_scale.size(0) == a1q.shape[0] + ), "Input scale shape mismatch" assert w1.size(0) == w2.size(0), "Weights expert number mismatch" assert w1.size(0) == w1_scale.size(0), "w1 scales expert number mismatch" assert w1.size(0) == w2_scale.size(0), "w2 scales expert number mismatch" - assert a2_scale is None or a2_scale.dim() == 0 or a2_scale.size( - 0) == 1 or a2_scale.size( - 0) == a1q.shape[0], "Intermediate scale shape mismatch" + assert ( + a2_scale is None + or a2_scale.dim() == 0 + or a2_scale.size(0) == 1 + or a2_scale.size(0) == a1q.shape[0] + ), "Intermediate scale shape mismatch" assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" if expert_map is not None: assert expert_num_tokens is None @@ -97,8 +110,9 @@ def run_cutlass_moe_fp8( if expert_map is not None: "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where(expert_map[topk_ids] != -1, - expert_map[topk_ids], -1) + local_topk_ids = torch.where( + expert_map[topk_ids] != -1, expert_map[topk_ids], -1 + ) else: local_topk_ids = topk_ids @@ -108,35 +122,39 @@ def run_cutlass_moe_fp8( if use_batched_format: mm1_out = _resize_cache(workspace13, (local_E * padded_M, N * 2)) act_out = _resize_cache(workspace2, (local_E * padded_M, N)) - quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), - (local_E * padded_M, N)) + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (local_E * padded_M, N) + ) mm2_out = _resize_cache(workspace2, (local_E * padded_M, K)) else: - a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), - (M * topk, K)) + a1q_perm = _resize_cache( + workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K) + ) mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) act_out = _resize_cache(workspace2, (M * topk, N)) # original workspace are based on input hidden_states dtype (bf16) - quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), - (M * topk, N)) + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M * topk, N) + ) mm2_out = _resize_cache(workspace2, (M * topk, K)) if use_batched_format: assert expert_num_tokens is not None - expert_offsets = torch.empty((local_E), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((local_E, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((local_E, 3), - dtype=torch.int32, - device=device) + expert_offsets = torch.empty((local_E), 
dtype=torch.int32, device=device) + problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device) - ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1, - problem_sizes2, expert_num_tokens, - local_E, padded_M, N, K) + ops.get_cutlass_pplx_moe_mm_data( + expert_offsets, + problem_sizes1, + problem_sizes2, + expert_num_tokens, + local_E, + padded_M, + N, + K, + ) w1_scale = w1_scale.reshape(w1_scale.size(0), -1) w2_scale = w2_scale.reshape(w2_scale.size(0), -1) @@ -146,15 +164,14 @@ def run_cutlass_moe_fp8( # during offset calculations expert_offsets = expert_offsets.to(torch.int64) else: - problem_sizes1 = torch.empty((global_num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((global_num_experts, 3), - dtype=torch.int32, - device=device) - - num_expert = global_num_experts if expert_map is None \ - else expert_map.size(0) + problem_sizes1 = torch.empty( + (global_num_experts, 3), dtype=torch.int32, device=device + ) + problem_sizes2 = torch.empty( + (global_num_experts, 3), dtype=torch.int32, device=device + ) + + num_expert = global_num_experts if expert_map is None else expert_map.size(0) # permuted a1q reuses workspace2 a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute( a1q, @@ -163,12 +180,13 @@ def run_cutlass_moe_fp8( num_expert, local_E, expert_map, - permuted_hidden_states=a1q_perm) + permuted_hidden_states=a1q_perm, + ) expert_offsets = expert_offsets[:-1] - ops.get_cutlass_moe_mm_problem_sizes(local_topk_ids, problem_sizes1, - problem_sizes2, - global_num_experts, N, K) + ops.get_cutlass_moe_mm_problem_sizes( + local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K + ) if not per_act_token and (expert_map is not None or use_batched_format): # this is necessary to avoid imprecise scale calculation caused by @@ -176,38 +194,59 @@ def run_cutlass_moe_fp8( # this rank handles only partial tokens, or when it is batched . mm1_out.fill_(0) - ops.cutlass_moe_mm(mm1_out, a1q, w1, a1q_scale, w1_scale, expert_offsets, - problem_sizes1, ab_strides1, ab_strides1, c_strides1, - per_act_token, per_out_ch) + ops.cutlass_moe_mm( + mm1_out, + a1q, + w1, + a1q_scale, + w1_scale, + expert_offsets, + problem_sizes1, + ab_strides1, + ab_strides1, + c_strides1, + per_act_token, + per_out_ch, + ) activation_callable(act_out, mm1_out) a2q, a2q_scale = ops.scaled_fp8_quant( - act_out, - a2_scale, - use_per_token_if_dynamic=per_act_token, - output=quant_out) + act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out + ) if expert_map is not None: mm2_out.fill_(0) - ops.cutlass_moe_mm(mm2_out, a2q, w2, a2q_scale, w2_scale, expert_offsets, - problem_sizes2, ab_strides2, ab_strides2, c_strides2, - per_act_token, per_out_ch) + ops.cutlass_moe_mm( + mm2_out, + a2q, + w2, + a2q_scale, + w2_scale, + expert_offsets, + problem_sizes2, + ab_strides2, + ab_strides2, + c_strides2, + per_act_token, + per_out_ch, + ) if use_batched_format: output.copy_(mm2_out.reshape(local_E, padded_M, K), non_blocking=True) else: # for non-chunking mode the output is resized from workspace13 # so we need to make sure mm2_out uses workspace2. 
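The workspace handling above leans on _resize_cache to reinterpret slices of preallocated buffers between GEMM and activation stages. A minimal sketch of that idea (assumed behavior: flatten, slice, view; the real helper may differ in details):

import math
import torch

def resize_cache_sketch(buf: torch.Tensor, shape: tuple) -> torch.Tensor:
    # Reuse the leading elements of a flat workspace as a tensor of `shape`
    # without allocating; successive stages may alias the same storage.
    numel = math.prod(shape)
    assert buf.numel() >= numel, "workspace too small"
    return buf.flatten()[:numel].view(*shape)

workspace2 = torch.empty(4096, dtype=torch.bfloat16)
act_out = resize_cache_sketch(workspace2, (32, 64))    # activation output
mm2_out = resize_cache_sketch(workspace2, (16, 128))   # later GEMM output, same storage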
- moe_unpermute(out=output, - permuted_hidden_states=mm2_out, - topk_weights=topk_weights, - inv_permuted_idx=inv_perm) + moe_unpermute( + out=output, + permuted_hidden_states=mm2_out, + topk_weights=topk_weights, + inv_permuted_idx=inv_perm, + ) class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, out_dtype: Optional[torch.dtype], @@ -256,23 +295,40 @@ def apply( activation_callable = lambda o, i: self.activation(activation, o, i) - use_batched_format = self.activation_formats[ - 0] == mk.FusedMoEActivationFormat.BatchedExperts + use_batched_format = ( + self.activation_formats[0] == mk.FusedMoEActivationFormat.BatchedExperts + ) in_dtype = hidden_states.dtype run_cutlass_moe_fp8( - output, hidden_states, w1, w2, topk_ids, activation_callable, - global_num_experts, expert_map, self.w1_scale, self.w2_scale, - a1q_scale, a2_scale, self.ab_strides1, self.ab_strides2, - self.c_strides1, self.c_strides2, workspace13, workspace2, + output, + hidden_states, + w1, + w2, + topk_ids, + activation_callable, + global_num_experts, + expert_map, + self.w1_scale, + self.w2_scale, + a1q_scale, + a2_scale, + self.ab_strides1, + self.ab_strides2, + self.c_strides1, + self.c_strides2, + workspace13, + workspace2, expert_num_tokens, self.out_dtype if self.out_dtype is not None else in_dtype, - self.per_act_token_quant, self.per_out_ch_quant, - use_batched_format, topk_weights) + self.per_act_token_quant, + self.per_out_ch_quant, + use_batched_format, + topk_weights, + ) class CutlassExpertsFp8(CutlassExpertsFp8Base): - def __init__( self, out_dtype: Optional[torch.dtype], @@ -293,10 +349,12 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True @@ -323,12 +381,15 @@ def workspace_shapes( workspace1 = (M * topk, max(N, K)) workspace2 = (M * topk, max(N // 2, K)) output = (M, K) - return (workspace1, workspace2, output, - self.out_dtype if self.out_dtype is not None else a.dtype) + return ( + workspace1, + workspace2, + output, + self.out_dtype if self.out_dtype is not None else a.dtype, + ) class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): - def __init__( self, max_experts_per_worker: int, @@ -354,10 +415,12 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) def supports_chunking(self) -> bool: return False @@ -381,13 +444,15 @@ def workspace_shapes( padded_M = aq.size(1) num_dp = self.num_dispatchers assert num_dp is not None - workspace1 = (self.max_experts_per_worker, padded_M * num_dp, - max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M * num_dp, - max(N // 2, K)) + workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M * num_dp, max(N // 2, K)) output = (self.max_experts_per_worker, padded_M, K) - return (workspace1, workspace2, output, - self.out_dtype if self.out_dtype is not None else a.dtype) + return ( + workspace1, + workspace2, + output, + self.out_dtype if self.out_dtype is not 
None else a.dtype, + ) def cutlass_moe_fp8( @@ -456,18 +521,15 @@ def cutlass_moe_fp8( assert quant_config is not None if quant_config.a1_scale is not None: - assert (quant_config.per_act_token_quant == - quant_config.a1_scale.numel() != 1) + assert quant_config.per_act_token_quant == quant_config.a1_scale.numel() != 1 if quant_config.a2_scale is not None: - assert (quant_config.per_act_token_quant == - quant_config.a2_scale.numel() != 1) + assert quant_config.per_act_token_quant == quant_config.a2_scale.numel() != 1 - assert (quant_config.w1_scale is None - or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1) - == w1_q.size(1)))) + assert quant_config.w1_scale is None or ( + quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1) == w1_q.size(1)) + ) - num_experts = global_num_experts if global_num_experts != -1 else w1_q.size( - 0) + num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), @@ -550,25 +612,30 @@ def run_cutlass_moe_fp4( assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8" assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8" - assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3 - and w2_blockscale.ndim - == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4") + assert ( + w1_fp4.ndim == 3 + and w2_fp4.ndim == 3 + and w1_blockscale.ndim == 3 + and w2_blockscale.ndim == 3 + ), "All Weights must be of rank 3 for cutlass_moe_fp4" m_a, k_a = a.shape e_w1, nx2_w1, half_k_w1 = w1_fp4.shape e_w2, k_w2, half_n_w2 = w2_fp4.shape - assert (e_w1 == e_w2 - and e_w1 == e), ("Number of experts must match", - f" between weights. {e_w1}, {e_w2}, {e}") - assert (k_a == half_k_w1 * 2 - and k == k_w2), ("Hidden size mismatch between a, w1 and w2") - assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " - "expected `n`") - assert (m == m_a), "input shape mismatch" + assert e_w1 == e_w2 and e_w1 == e, ( + "Number of experts must match", + f" between weights. {e_w1}, {e_w2}, {e}", + ) + assert k_a == half_k_w1 * 2 and k == k_w2, ( + "Hidden size mismatch between a, w1 and w2" + ) + assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`" + assert m == m_a, "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" - assert (topk_weights.size(0) == m and topk_ids.size(0) - == m), ("topk must be provided for each row of a") + assert topk_weights.size(0) == m and topk_ids.size(0) == m, ( + "topk must be provided for each row of a" + ) topk = topk_ids.size(1) out_dtype = a.dtype num_topk = topk_ids.size(1) @@ -585,15 +652,25 @@ def run_cutlass_moe_fp4( if apply_router_weight_on_input: # TODO: this only works for topK=1, will need to update for topK>1 - assert num_topk == 1, \ + assert num_topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1" + ) a.mul_(topk_weights.to(out_dtype)) # problem shapes should have [m, n, k] # Note that problem sizes are based on logical number of elements. 
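As the comment above notes, each expert's GEMM is described by a [m, n, k] problem size in logical elements. A host-side sketch of the per-expert sizes the on-device helper produces (illustrative only; the real kernel also emits expert offsets and gather maps):

import torch

def moe_problem_sizes_sketch(topk_ids: torch.Tensor, num_experts: int, n: int, k: int) -> torch.Tensor:
    # m per expert = number of (token, expert) pairs routed to it; n and k are fixed per layer.
    m_per_expert = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    n_col = torch.full_like(m_per_expert, n)
    k_col = torch.full_like(m_per_expert, k)
    return torch.stack([m_per_expert, n_col, k_col], dim=1)   # (num_experts, 3)

topk_ids = torch.tensor([[0, 2], [2, 3], [0, 0]])
moe_problem_sizes_sketch(topk_ids, num_experts=4, n=512, k=256)
# tensor([[  3, 512, 256],
#         [  0, 512, 256],
#         [  2, 512, 256],
#         [  1, 512, 256]])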
- ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, - problem_sizes2, a_map, c_map, e, n, k, - blockscale_offsets) + ops.get_cutlass_moe_mm_data( + topk_ids, + expert_offsets, + problem_sizes1, + problem_sizes2, + a_map, + c_map, + e, + n, + k, + blockscale_offsets, + ) a = ops.shuffle_rows(a, a_map) rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( @@ -606,17 +683,34 @@ def run_cutlass_moe_fp4( c1 = _resize_cache(workspace13, (m * topk, n * 2)) c2 = _resize_cache(workspace2, (m * topk, n)) c3 = _resize_cache(workspace13, (m * topk, k)) - ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale, - w1_blockscale, w1_alphas, problem_sizes1, - expert_offsets[:-1], blockscale_offsets[:-1]) + ops.cutlass_fp4_moe_mm( + c1, + rep_a_fp4, + w1_fp4, + rep_a_blockscale, + w1_blockscale, + w1_alphas, + problem_sizes1, + expert_offsets[:-1], + blockscale_offsets[:-1], + ) del rep_a_fp4, rep_a_blockscale torch.ops._C.silu_and_mul(c2, c1) int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( - c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk) + c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk + ) - ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale, - w2_alphas, problem_sizes2, expert_offsets[:-1], - blockscale_offsets[:-1]) + ops.cutlass_fp4_moe_mm( + c3, + int_fp4, + w2_fp4, + int_blockscale, + w2_blockscale, + w2_alphas, + problem_sizes2, + expert_offsets[:-1], + blockscale_offsets[:-1], + ) del int_fp4, int_blockscale c3 = ops.shuffle_rows(c3, c_map) @@ -624,9 +718,12 @@ def run_cutlass_moe_fp4( assert output.dtype == out_dtype if not apply_router_weight_on_input: output.copy_( - (c3.view(m, num_topk, k) * - topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1), - non_blocking=True) + ( + c3.view(m, num_topk, k) + * topk_weights.view(m, num_topk, 1).to(out_dtype) + ).sum(dim=1), + non_blocking=True, + ) else: output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True) return @@ -634,7 +731,6 @@ def run_cutlass_moe_fp4( # Split into batched and non-batched class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, max_experts_per_worker: int, @@ -649,14 +745,18 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: if self.use_batched_format: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) else: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_expert_map(self) -> bool: return False @@ -691,8 +791,12 @@ def workspace_shapes( workspace1 = (M * topk, max(2 * N, K)) workspace2 = (M * topk, N) output = (M, K) - return (workspace1, workspace2, output, - self.out_dtype if self.out_dtype is not None else a.dtype) + return ( + workspace1, + workspace2, + output, + self.out_dtype if self.out_dtype is not None else a.dtype, + ) def apply( self, @@ -740,21 +844,24 @@ def apply( def cutlass_moe_fp4( - a: torch.Tensor, - w1_fp4: torch.Tensor, - w2_fp4: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_config: FusedMoEQuantConfig, - m: int, - n: int, - k: int, - e: int, - expert_map: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False) -> torch.Tensor: - assert 
expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.") + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_config: FusedMoEQuantConfig, + m: int, + n: int, + k: int, + e: int, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, +) -> torch.Tensor: + assert expert_map is None, ( + "Expert Parallelism / expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE's cutlass_moe_fp4." + ) # TODO(bnell): this feels a bit hacky # NVFP4 requires two levels of quantization, which involves @@ -799,10 +906,13 @@ def cutlass_moe_fp4( def _valid_cutlass_block_scaled_grouped_gemm( - w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str, - apply_router_weight_on_input: bool, - expert_map: Optional[torch.Tensor]) -> bool: - + w1: torch.Tensor, + w2: torch.Tensor, + inplace: bool, + activation: str, + apply_router_weight_on_input: bool, + expert_map: Optional[torch.Tensor], +) -> bool: def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): return N % 128 == 0 and K % 128 == 0 @@ -816,7 +926,7 @@ def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): ) return False - if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn: logger.debug_once( "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s). " "w1.dtype: %s, w2.dtype: %s", @@ -827,19 +937,21 @@ def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): if expert_map is not None: logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled: expert_parallel is" - " not supported.") + "CutlassBlockScaledGroupedGemm disabled: expert_parallel is not supported." + ) return False if activation != "silu": logger.debug_once( - "CutlassBlockScaledGroupedGemm disabled: only activation silu is" - " supported.") + "CutlassBlockScaledGroupedGemm disabled: only activation silu is supported." + ) return False if apply_router_weight_on_input: - logger.debug_once("CutlassBlockScaledGroupedGemm disabled:" - " apply_router_weight_on_input is not supported.") + logger.debug_once( + "CutlassBlockScaledGroupedGemm disabled:" + " apply_router_weight_on_input is not supported." 
+ ) return False if inplace: @@ -867,17 +979,16 @@ def run_cutlass_block_scaled_fused_experts( w2_scale = w2_scale.transpose(1, 2) assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert a.shape[0] == topk_ids.shape[ - 0], "a and topk_ids must have the same batch size" + assert a.shape[0] == topk_ids.shape[0], ( + "a and topk_ids must have the same batch size" + ) assert w1_q.dtype == torch.float8_e4m3fn, "w1_q must be float8_e4m3fn" assert w2_q.dtype == torch.float8_e4m3fn, "w2_q must be float8_e4m3fn" assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1" assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2" assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch" - assert w1_q.shape[0] == w1_scale.shape[ - 0], "w1_scale expert number mismatch" - assert w1_q.shape[0] == w2_scale.shape[ - 0], "w2_scale expert number mismatch" + assert w1_q.shape[0] == w1_scale.shape[0], "w1_scale expert number mismatch" + assert w1_q.shape[0] == w2_scale.shape[0], "w2_scale expert number mismatch" assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype" out_dtype = a.dtype @@ -888,21 +999,14 @@ def run_cutlass_block_scaled_fused_experts( topk = topk_ids.size(1) - a_q, a1_scale = _fp8_quantize(a, - A_scale=None, - per_act_token=False, - block_shape=[128, 128]) + a_q, a1_scale = _fp8_quantize( + a, A_scale=None, per_act_token=False, block_shape=[128, 128] + ) device = a_q.device - expert_offsets = torch.empty((num_experts + 1, ), - dtype=torch.int32, - device=device) - problem_sizes1 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) - problem_sizes2 = torch.empty((num_experts, 3), - dtype=torch.int32, - device=device) + expert_offsets = torch.empty((num_experts + 1,), dtype=torch.int32, device=device) + problem_sizes1 = torch.empty((num_experts, 3), dtype=torch.int32, device=device) + problem_sizes2 = torch.empty((num_experts, 3), dtype=torch.int32, device=device) a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) @@ -938,10 +1042,9 @@ def run_cutlass_block_scaled_fused_experts( intermediate = torch.empty((m * topk, n), dtype=out_dtype, device=device) torch.ops._C.silu_and_mul(intermediate, c1) - intermediate_q, a2_scale = _fp8_quantize(intermediate, - A_scale=None, - per_act_token=False, - block_shape=[128, 128]) + intermediate_q, a2_scale = _fp8_quantize( + intermediate, A_scale=None, per_act_token=False, block_shape=[128, 128] + ) ops.cutlass_blockwise_scaled_grouped_mm( c2, @@ -953,5 +1056,6 @@ def run_cutlass_block_scaled_fused_experts( expert_offsets[:-1], ) - return (c2[c_map].view(m, topk, k) * - topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1) + return ( + c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype) + ).sum(dim=1) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 51a4f275e98c..fec3a7c5d0a9 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -9,17 +9,25 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - compute_aligned_M, deep_gemm_block_shape, 
deepgemm_moe_permute, - deepgemm_unpermute_and_reduce) + compute_aligned_M, + deep_gemm_block_shape, + deepgemm_moe_permute, + deepgemm_unpermute_and_reduce, +) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) + TopKWeightAndReduceNoOP, +) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.utils import has_deep_gemm, run_once from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous @@ -31,8 +39,9 @@ def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool: return align <= M and N % align == 0 and K % align == 0 -def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor) -> bool: +def _valid_deep_gemm( + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor +) -> bool: """ Check if the given problem size is supported by the DeepGemm grouped gemm kernel. All of M, N, K and the quantization block_shape must be @@ -71,17 +80,19 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, ) return False - if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): + if w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn: logger.debug_once( - "DeepGemm disabled: invalid weight dtype(s). " - "w1.dtype: %s, w2.dtype: %s", + "DeepGemm disabled: invalid weight dtype(s). w1.dtype: %s, w2.dtype: %s", w1.dtype, w2.dtype, ) return False - if (not hidden_states.is_contiguous() or not w1.is_contiguous() - or not w2.is_contiguous()): + if ( + not hidden_states.is_contiguous() + or not w1.is_contiguous() + or not w2.is_contiguous() + ): logger.debug_once( "DeepGemm disabled: weights or activations not contiguous. " "hidden_states.is_contiguous(): %s, w1.is_contiguous(): %s, " @@ -96,10 +107,13 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, @run_once -def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - num_topk: int): +def warmup_deepgemm_gg_contiguous_kernels( + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + num_topk: int, +): """ DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the input tensor shapes. In this function, we construct all possible input @@ -108,8 +122,7 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, call and not during actual model inference. """ - assert w1.size(0) == w2.size(0), ( - "w1 and w2 must have the same number of experts") + assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" block_m = deep_gemm_block_shape()[0] num_experts = w1.size(0) @@ -117,36 +130,39 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor, # This is the maximum GroupedGemm M size that we expect to run # the grouped_gemm with. - MAX_M = compute_aligned_M(env.VLLM_FUSED_MOE_CHUNK_SIZE, - num_topk, - num_experts, - block_m, - expert_tokens_meta=None) + MAX_M = compute_aligned_M( + env.VLLM_FUSED_MOE_CHUNK_SIZE, + num_topk, + num_experts, + block_m, + expert_tokens_meta=None, + ) # Distribute expert-ids evenly. 
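The statement that follows assigns one expert id per block_m-sized row block and repeats it across the block. In isolation (assuming block_m = 128, the typical DeepGemm m-alignment), the construction looks like:

import torch

block_m, max_blocks, num_experts = 128, 4, 8
expert_ids_block = torch.randint(0, num_experts, (max_blocks,), dtype=torch.int32)
expert_ids = torch.repeat_interleave(expert_ids_block, block_m)
# Every block_m-row slice of the warmup input maps to a single expert, matching the
# contiguous grouped-GEMM layout the warmup is meant to exercise.
assert expert_ids.shape == (max_blocks * block_m,)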
MAX_BLOCKS = MAX_M // block_m - expert_ids_block = torch.randint(low=0, - high=num_experts, - size=(MAX_BLOCKS, ), - device=device, - dtype=torch.int32) + expert_ids_block = torch.randint( + low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32 + ) expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) def _warmup(w: torch.Tensor, w_scale: torch.Tensor): - _, n, k = w.size() a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn) - a1q_scales = torch.empty((MAX_M, k // block_m), - device=device, - dtype=torch.float32) + a1q_scales = torch.empty( + (MAX_M, k // block_m), device=device, dtype=torch.float32 + ) out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) - pbar = tqdm(total=MAX_BLOCKS, - desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})") + pbar = tqdm( + total=MAX_BLOCKS, desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})" + ) num_tokens = MAX_M while num_tokens > 0: m_grouped_fp8_gemm_nt_contiguous( - (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale), - out[:num_tokens], expert_ids[:num_tokens]) + (a1q[:num_tokens], a1q_scales[:num_tokens]), + (w, w_scale), + out[:num_tokens], + expert_ids[:num_tokens], + ) pbar.update(1) num_tokens = num_tokens - block_m @@ -155,7 +171,6 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor): class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config: FusedMoEQuantConfig): super().__init__(quant_config) assert quant_config.block_shape == deep_gemm_block_shape() @@ -165,10 +180,12 @@ def __init__(self, quant_config: FusedMoEQuantConfig): @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True @@ -193,8 +210,9 @@ def workspace_shapes( ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert self.block_shape is not None block_m = self.block_shape[0] - M_sum = compute_aligned_M(M, topk, local_num_experts, block_m, - expert_tokens_meta) + M_sum = compute_aligned_M( + M, topk, local_num_experts, block_m, expert_tokens_meta + ) assert M_sum % block_m == 0 workspace1 = (M_sum, max(N, K)) @@ -235,18 +253,20 @@ def apply( assert w2.size(1) == K - M_sum = compute_aligned_M(M=topk_ids.size(0), - num_topk=topk_ids.size(1), - local_num_experts=local_num_experts, - alignment=deep_gemm_block_shape()[0], - expert_tokens_meta=expert_tokens_meta) + M_sum = compute_aligned_M( + M=topk_ids.size(0), + num_topk=topk_ids.size(1), + local_num_experts=local_num_experts, + alignment=deep_gemm_block_shape()[0], + expert_tokens_meta=expert_tokens_meta, + ) - a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), - (M_sum, K)) + a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K)) mm1_out = _resize_cache(workspace13, (M_sum, N)) act_out = _resize_cache(workspace2, (M_sum, N // 2)) - quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), - (M_sum, N // 2)) + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) + ) mm2_out = _resize_cache(workspace2, (M_sum, K)) a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute( @@ -256,32 +276,36 @@ def apply( local_num_experts=local_num_experts, expert_map=expert_map, expert_tokens_meta=expert_tokens_meta, - 
aq_out=a1q_perm) + aq_out=a1q_perm, + ) assert a1q.size(0) == M_sum - m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, self.w1_scale), - mm1_out, expert_ids) + m_grouped_fp8_gemm_nt_contiguous( + (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids + ) self.activation(activation, act_out, mm1_out.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = per_token_group_quant_fp8(act_out, - self.block_shape[1], - column_major_scales=True, - out_q=quant_out) + a2q, a2q_scale = per_token_group_quant_fp8( + act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out + ) - m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, self.w2_scale), - mm2_out, expert_ids) + m_grouped_fp8_gemm_nt_contiguous( + (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids + ) if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - deepgemm_unpermute_and_reduce(a=mm2_out, - topk_ids=topk_ids, - topk_weights=topk_weights, - inv_perm=inv_perm, - expert_map=expert_map, - output=output) + deepgemm_unpermute_and_reduce( + a=mm2_out, + topk_ids=topk_ids, + topk_weights=topk_weights, + inv_perm=inv_perm, + expert_map=expert_map, + output=output, + ) def deep_gemm_moe_fp8( @@ -342,7 +366,8 @@ def deep_gemm_moe_fp8( w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, - block_shape=deep_gemm_block_shape()) + block_shape=deep_gemm_block_shape(), + ) fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py index c8469501af5d..2ac968a9b4ab 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py @@ -20,27 +20,33 @@ def deep_gemm_block_shape() -> list[int]: # Lazy import to avoid CUDA initialization problems. import deep_gemm as dg + block = dg.get_m_alignment_for_contiguous_layout() return [block, block] -def expert_num_tokens_round_up_and_sum(expert_num_tokens: torch.Tensor, - alignment: int) -> int: +def expert_num_tokens_round_up_and_sum( + expert_num_tokens: torch.Tensor, alignment: int +) -> int: # Round up each element in expert_num_tokens to the nearest multiple of # alignment. - ent = (expert_num_tokens.to(torch.int64) + - (alignment - 1)) // alignment * alignment + ent = (expert_num_tokens.to(torch.int64) + (alignment - 1)) // alignment * alignment return torch.sum(ent).item() -def compute_aligned_M(M: int, num_topk: int, local_num_experts: int, - alignment: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata]): - - if ((expert_tokens_meta is not None) - and (expert_tokens_meta.expert_num_tokens_cpu is not None)): +def compute_aligned_M( + M: int, + num_topk: int, + local_num_experts: int, + alignment: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], +): + if (expert_tokens_meta is not None) and ( + expert_tokens_meta.expert_num_tokens_cpu is not None + ): return expert_num_tokens_round_up_and_sum( - expert_tokens_meta.expert_num_tokens_cpu, alignment=alignment) + expert_tokens_meta.expert_num_tokens_cpu, alignment=alignment + ) # expert_num_tokens information is not available on the cpu. # compute the max required size. 
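# --- Illustrative sketch, not part of this patch -----------------------------
# Standalone example of the rounding that expert_num_tokens_round_up_and_sum /
# compute_aligned_M perform when per-expert token counts are available on the
# CPU: each expert's count is rounded up to the grouped-GEMM alignment and the
# results are summed. The alignment of 128 is only an assumption here; the
# real value comes from deep_gemm_block_shape().
import torch

def rounded_sum(expert_num_tokens: torch.Tensor, alignment: int) -> int:
    ent = (expert_num_tokens.to(torch.int64) + (alignment - 1)) // alignment * alignment
    return int(ent.sum().item())

counts = torch.tensor([5, 0, 130, 64])   # tokens routed to four local experts
assert rounded_sum(counts, alignment=128) == 128 + 0 + 256 + 128   # M_sum == 512
# ------------------------------------------------------------------------------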
@@ -74,14 +80,14 @@ def _fwd_kernel_ep_scatter_1( cur_expert = tl.program_id(0) offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM) - tokens_per_expert = tl.load(num_recv_tokens_per_expert + offset_cumsum, - mask=offset_cumsum < num_experts, - other=0) + tokens_per_expert = tl.load( + num_recv_tokens_per_expert + offset_cumsum, + mask=offset_cumsum < num_experts, + other=0, + ) tokens_per_expert = round_up_128(tokens_per_expert) cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert - tl.store(expert_start_loc + offset_cumsum, - cumsum, - mask=offset_cumsum < num_experts) + tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) cur_expert_start = tl.load(expert_start_loc + cur_expert) cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) @@ -136,34 +142,31 @@ def _fwd_kernel_ep_scatter_2( mask_s = offset_in_s < SCALE_HIDDEN_SIZE for token_id in range(start_token_id, total_token_num, grid_num): - to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, - mask=mask) - to_copy_s = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 + - offset_in_s, - mask=mask_s) + to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) + to_copy_s = tl.load( + recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s + ) for topk_index in tl.range(0, topk_num, 1, num_stages=4): - expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + - topk_index) + expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) if HAS_EXPERT_MAP: expert_id = apply_expert_map(expert_id, expert_map) if expert_id >= 0: - dest_token_index = tl.atomic_add(expert_start_loc + expert_id, - 1) + dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) tl.store( - output_index + token_id * output_index_stride0 + - topk_index, dest_token_index) - output_tensor_ptr = (output_tensor + - dest_token_index * output_tensor_stride0) + output_index + token_id * output_index_stride0 + topk_index, + dest_token_index, + ) + output_tensor_ptr = ( + output_tensor + dest_token_index * output_tensor_stride0 + ) output_tensor_scale_ptr = ( - output_tensor_scale + - dest_token_index * output_tensor_scale_stride0) + output_tensor_scale + dest_token_index * output_tensor_scale_stride0 + ) tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) - tl.store(output_tensor_scale_ptr + offset_in_s, - to_copy_s, - mask=mask_s) + tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s) @torch.no_grad() @@ -189,7 +192,7 @@ def ep_scatter( assert m_indices.shape[0] % BLOCK_E == 0 - _fwd_kernel_ep_scatter_1[(grid, )]( + _fwd_kernel_ep_scatter_1[(grid,)]( num_recv_tokens_per_expert, expert_start_loc, m_indices, @@ -201,7 +204,7 @@ def ep_scatter( grid = min(recv_topk.shape[0], 1024 * 8) - _fwd_kernel_ep_scatter_2[(grid, )]( + _fwd_kernel_ep_scatter_2[(grid,)]( recv_topk.shape[0], expert_start_loc, recv_x, @@ -265,27 +268,33 @@ def _fwd_kernel_ep_gather( off_d = tl.arange(0, BLOCK_D) accumulator = tl.zeros([BLOCK_D], dtype=tl.float32) for topk_index in range(0, topk_num): - expert_id = tl.load(recv_topk_ids + - cur_token * recv_topk_ids_stride0 + topk_index) + expert_id = tl.load( + recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index + ) if HAS_EXPERT_MAP: expert_id = apply_expert_map(expert_id, expert_map) if expert_id >= 0: - source_token_index = tl.load(input_index + - cur_token * input_index_stride0 + - topk_index) - acc_weight = tl.load(recv_topk_weight + - cur_token * recv_topk_weight_stride0 + - topk_index) - 
tmp = tl.load(input_tensor + - source_token_index * input_tensor_stride0 + - cur_block * BLOCK_D + off_d) + source_token_index = tl.load( + input_index + cur_token * input_index_stride0 + topk_index + ) + acc_weight = tl.load( + recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index + ) + tmp = tl.load( + input_tensor + + source_token_index * input_tensor_stride0 + + cur_block * BLOCK_D + + off_d + ) accumulator += tmp.to(tl.float32) * acc_weight tl.store( - output_tensor + cur_token * output_tensor_stride0 + - cur_block * BLOCK_D + off_d, + output_tensor + + cur_token * output_tensor_stride0 + + cur_block * BLOCK_D + + off_d, accumulator.to(output_tensor.dtype.element_ty), ) @@ -332,44 +341,46 @@ def ep_gather( return -def deepgemm_moe_permute(aq: torch.Tensor, - aq_scale: torch.Tensor, - topk_ids: torch.Tensor, - local_num_experts: int, - expert_map: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - aq_out: Optional[torch.Tensor] = None): - +def deepgemm_moe_permute( + aq: torch.Tensor, + aq_scale: torch.Tensor, + topk_ids: torch.Tensor, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + aq_out: Optional[torch.Tensor] = None, +): assert aq.ndim == 2 - assert topk_ids.dtype.is_signed, ( - "The kernel uses -1 to represent invalid topk_ids") + assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids" H = aq.size(1) device = aq.device block_m = deep_gemm_block_shape()[0] block_k = deep_gemm_block_shape()[1] - M_sum = compute_aligned_M(M=topk_ids.size(0), - num_topk=topk_ids.size(1), - local_num_experts=local_num_experts, - alignment=block_m, - expert_tokens_meta=expert_tokens_meta) + M_sum = compute_aligned_M( + M=topk_ids.size(0), + num_topk=topk_ids.size(1), + local_num_experts=local_num_experts, + alignment=block_m, + expert_tokens_meta=expert_tokens_meta, + ) - expert_start_loc = torch.empty((local_num_experts), - device=device, - dtype=torch.int32) + expert_start_loc = torch.empty( + (local_num_experts), device=device, dtype=torch.int32 + ) assert aq_out is None or aq_out.shape == (M_sum, H) if aq_out is None: aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype) - aq_scale_out = torch.empty((M_sum, H // block_k), - device=device, - dtype=torch.float32) + aq_scale_out = torch.empty( + (M_sum, H // block_k), device=device, dtype=torch.float32 + ) - maybe_has_empty_blocks = ((expert_tokens_meta is None) - or (expert_tokens_meta.expert_num_tokens_cpu - is None)) + maybe_has_empty_blocks = (expert_tokens_meta is None) or ( + expert_tokens_meta.expert_num_tokens_cpu is None + ) expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32) @@ -379,35 +390,39 @@ def deepgemm_moe_permute(aq: torch.Tensor, if expert_tokens_meta is not None: expert_num_tokens = expert_tokens_meta.expert_num_tokens else: - expert_num_tokens = count_expert_num_tokens(topk_ids, - local_num_experts, - expert_map) - - ep_scatter(recv_x=aq, - recv_x_scale=aq_scale, - recv_topk=topk_ids, - num_recv_tokens_per_expert=expert_num_tokens, - expert_start_loc=expert_start_loc, - expert_map=expert_map, - output_tensor=aq_out, - output_tensor_scale=aq_scale_out, - m_indices=expert_ids, - output_index=inv_perm) + expert_num_tokens = count_expert_num_tokens( + topk_ids, local_num_experts, expert_map + ) + + ep_scatter( + recv_x=aq, + recv_x_scale=aq_scale, + recv_topk=topk_ids, + 
num_recv_tokens_per_expert=expert_num_tokens, + expert_start_loc=expert_start_loc, + expert_map=expert_map, + output_tensor=aq_out, + output_tensor_scale=aq_scale_out, + m_indices=expert_ids, + output_index=inv_perm, + ) return aq_out, aq_scale_out, expert_ids, inv_perm def deepgemm_unpermute_and_reduce( - a: torch.Tensor, # Grouped gemm output - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - inv_perm: torch.Tensor, - expert_map: Optional[torch.Tensor], - output: torch.Tensor): - - return ep_gather(input_tensor=a, - recv_topk_ids=topk_ids, - recv_topk_weight=topk_weights, - input_index=inv_perm, - expert_map=expert_map, - output_tensor=output) + a: torch.Tensor, # Grouped gemm output + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + inv_perm: torch.Tensor, + expert_map: Optional[torch.Tensor], + output: torch.Tensor, +): + return ep_gather( + input_tensor=a, + recv_topk_ids=topk_ids, + recv_topk_weight=topk_weights, + input_index=inv_perm, + expert_map=expert_map, + output_tensor=output, + ) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 9e9a9afc18a0..9a2844b7d998 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -8,15 +8,20 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + TopKWeightAndReduceContiguous, + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils import round_up from vllm.v1.worker.ubatching import ( - dbo_current_ubatch_id, dbo_enabled, dbo_switch_to_comm, - dbo_switch_to_compute, dbo_switch_to_compute_sync, + dbo_current_ubatch_id, + dbo_enabled, + dbo_switch_to_comm, + dbo_switch_to_compute, + dbo_switch_to_compute_sync, dbo_yield_and_switch_from_comm_to_compute, - dbo_yield_and_switch_from_compute_to_comm) + dbo_yield_and_switch_from_compute_to_comm, +) class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -25,8 +30,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): """ @staticmethod - def maybe_roundup_layer_hidden_size(hidden_size: int, - dtype: torch.dtype) -> int: + def maybe_roundup_layer_hidden_size(hidden_size: int, dtype: torch.dtype) -> int: # Round up hidden size so it is compatible with DeepEP High Throughput # kernels. 
# DeepEP intranode kernels make copies in units of, @@ -41,8 +45,13 @@ def maybe_roundup_layer_hidden_size(hidden_size: int, hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size) return hidden_size_bytes // dtype.itemsize - def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int, - dp_size: int, rank_expert_offset: int): + def __init__( + self, + buffer: deep_ep.Buffer, + num_dispatchers: int, + dp_size: int, + rank_expert_offset: int, + ): super().__init__() self.buffer = buffer self.num_dispatchers_ = num_dispatchers @@ -91,7 +100,6 @@ def _do_dispatch( a1_scale: Optional[torch.Tensor], quant_config: FusedMoEQuantConfig, ) -> Callable: - has_scales = token_scales is not None # We yield before launching the dispatch kernel since the dispatch @@ -99,22 +107,31 @@ def _do_dispatch( # for the other ubatch before the dispatch kernel starts. dbo_yield_and_switch_from_compute_to_comm() - (num_tokens_per_rank, num_tokens_per_rdma_rank, - dispatch_expert_num_tokens, is_token_in_rank, - event) = self.buffer.get_dispatch_layout( - topk_idx=rank_topk_ids, - num_experts=num_experts, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False) + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + dispatch_expert_num_tokens, + is_token_in_rank, + event, + ) = self.buffer.get_dispatch_layout( + topk_idx=rank_topk_ids, + num_experts=num_experts, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ) token_data = tokens if has_scales: token_data = (tokens, token_scales) ( - token_data, expert_topk_ids, expert_topk_weights, - expert_num_tokens_per_expert_list, handle, event + token_data, + expert_topk_ids, + expert_topk_weights, + expert_num_tokens_per_expert_list, + handle, + event, ) = self.buffer.dispatch( x=token_data, handle=None, @@ -130,7 +147,8 @@ def _do_dispatch( config=self._get_dispatch_config(), previous_event=None, async_finish=self.async_prepare and not dbo_enabled(), - allocate_on_comm_stream=False) + allocate_on_comm_stream=False, + ) # record the handle for this ubatch a2a_idx = dbo_current_ubatch_id() @@ -185,13 +203,15 @@ def _receiver( expert_topk_ids = torch.where( expert_topk_ids == -1, num_experts - 1 if self.rank_expert_offset == 0 else 0, - expert_topk_ids + self.rank_expert_offset) + expert_topk_ids + self.rank_expert_offset, + ) # Makes a GPU-CPU copy. # TODO (varun): Maybe it is better to re-compute the expert_num_tokens # on GPU. 
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( - expert_num_tokens_per_expert_list, device=expert_x.device) + expert_num_tokens_per_expert_list, device=expert_x.device + ) # Dispatch and Quant # DeepEP kernels only support dispatching block-quantized @@ -206,10 +226,16 @@ def _receiver( a1_scale, quant_dtype=quant_config.quant_dtype, per_act_token_quant=False, - block_shape=quant_config.block_shape) + block_shape=quant_config.block_shape, + ) - return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, - expert_topk_weights) + return ( + expert_x, + expert_x_scale, + expert_tokens_meta, + expert_topk_ids, + expert_topk_weights, + ) def supports_async(self) -> bool: return True @@ -224,12 +250,12 @@ def prepare_async( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.ReceiverType: - if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1") + "apply_router_weight_on_input is only implemented for topk=1" + ) a1 = a1 * topk_weights.to(a1.dtype) if quant_config.is_block_quantized: @@ -249,13 +275,15 @@ def prepare_async( a1q_scale = None a1_post_scale = quant_config.a1_scale - return self._do_dispatch(tokens=a1q, - token_scales=a1q_scale, - rank_topk_ids=topk_ids, - rank_topk_weights=topk_weights, - num_experts=num_experts, - a1_scale=a1_post_scale, - quant_config=quant_config) + return self._do_dispatch( + tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=topk_ids, + rank_topk_weights=topk_weights, + num_experts=num_experts, + a1_scale=a1_post_scale, + quant_config=quant_config, + ) def prepare( self, @@ -267,9 +295,15 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - receiver = self.prepare_async(a1, topk_weights, topk_ids, num_experts, - expert_map, apply_router_weight_on_input, - quant_config) + receiver = self.prepare_async( + a1, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) return receiver() def _finalize( @@ -282,7 +316,6 @@ def _finalize( weight_and_reduce_impl: mk.TopKWeightAndReduce, do_async: bool, ) -> Optional[Callable]: - a2a_idx = dbo_current_ubatch_id() handle = self.handles[a2a_idx] assert handle is not None @@ -307,7 +340,8 @@ def _finalize( config=self._get_combine_config(), previous_event=None, async_finish=do_async and not dbo_enabled(), - allocate_on_comm_stream=False) + allocate_on_comm_stream=False, + ) dbo_switch_to_compute() @@ -341,9 +375,15 @@ def finalize_async( apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> Callable: - receiver = self._finalize(output, fused_expert_output, topk_weights, - topk_ids, apply_router_weight_on_input, - weight_and_reduce_impl, True) + receiver = self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + True, + ) assert receiver is not None return receiver @@ -356,6 +396,12 @@ def finalize( apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: - self._finalize(output, fused_expert_output, topk_weights, topk_ids, - apply_router_weight_on_input, weight_and_reduce_impl, - False) + self._finalize( + output, + fused_expert_output, + topk_weights, + topk_ids, + apply_router_weight_on_input, + weight_and_reduce_impl, + False, + ) diff --git 
a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index a9554291db69..6712995b52af 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -8,19 +8,26 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, +) from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input, normalize_batched_scales_shape) -from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, - dbo_maybe_run_recv_hook) + moe_kernel_quantize_input, + normalize_batched_scales_shape, +) +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, + dbo_enabled, + dbo_maybe_run_recv_hook, +) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE] -def dequant_fp8(expert_x_fp8: torch.Tensor, - expert_x_scales: torch.Tensor) -> torch.Tensor: +def dequant_fp8( + expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor +) -> torch.Tensor: """ Return dequantized tensor in fp32 """ @@ -30,7 +37,8 @@ def dequant_fp8(expert_x_fp8: torch.Tensor, num_experts = expert_x_fp8.size(0) expert_x_fp32 = expert_x_fp8.to(torch.float32).view( - num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) + num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE + ) expert_x_scales = expert_x_scales.view(num_experts, -1, 1) return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) @@ -44,11 +52,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # specific hidden sizes. SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168] - def __init__(self, - buffer: deep_ep.Buffer, - max_tokens_per_rank: int, - num_dispatchers: int, - use_fp8_dispatch: bool = False): + def __init__( + self, + buffer: deep_ep.Buffer, + max_tokens_per_rank: int, + num_dispatchers: int, + use_fp8_dispatch: bool = False, + ): super().__init__() self.buffer = buffer @@ -79,10 +89,12 @@ def _do_quant( a1_dtype: torch.dtype, quant_config: FusedMoEQuantConfig, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.use_fp8_dispatch: - block_k = quant_config.block_shape[ - 1] if quant_config.block_shape is not None else None + block_k = ( + quant_config.block_shape[1] + if quant_config.block_shape is not None + else None + ) if block_k == DEEPEP_QUANT_BLOCK_SIZE: # DeepEP kernels did the quantization for us. 
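# --- Illustrative sketch, not part of this patch -----------------------------
# The layout behind the comment above and behind dequant_fp8 earlier in this
# file: each contiguous chunk of DEEPEP_QUANT_BLOCK_SIZE (128) hidden elements
# shares one float32 scale. The shapes below are invented for the example.
import torch

E, T, H, BLOCK = 2, 3, 256, 128                       # experts, tokens, hidden
xq = torch.randn(E, T, H).to(torch.float8_e4m3fn)     # stand-in dispatch output
scales = torch.rand(E, T * (H // BLOCK), dtype=torch.float32) + 0.5

# Same viewing trick as dequant_fp8: group into 128-wide blocks, broadcast one
# scale per block, then restore the original shape.
x_dq = (xq.to(torch.float32).view(E, -1, BLOCK) * scales.view(E, -1, 1)).view(E, T, H)

# Reference: scale block b of token t explicitly.
ref = xq.to(torch.float32).clone()
for t in range(T):
    for b in range(H // BLOCK):
        ref[:, t, b * BLOCK:(b + 1) * BLOCK] *= scales.view(E, T, H // BLOCK)[:, t, b:b + 1]
assert torch.equal(x_dq, ref)
# ------------------------------------------------------------------------------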
x, x_scales = x @@ -99,8 +111,12 @@ def _do_quant( # TODO (varun): Optimization - Use a batched version of quant x = x.view((-1, hidden_dim)) x, x_scales = moe_kernel_quantize_input( - x, quant_config.a1_scale, quant_config.quant_dtype, - quant_config.per_act_token_quant, quant_config.block_shape) + x, + quant_config.a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) x = x.view((num_experts, -1, hidden_dim)) if quant_config.quant_dtype is not None: @@ -122,47 +138,62 @@ def prepare_async( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> tuple[Callable, mk.ReceiverType]: - hidden_size = a1.size(1) - assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ - (f"Hidden Size {hidden_size} not in supported list of hidden sizes" - f"{self.SUPPORTED_HIDDEN_SIZES}") + assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, ( + f"Hidden Size {hidden_size} not in supported list of hidden sizes" + f"{self.SUPPORTED_HIDDEN_SIZES}" + ) a2a_idx = dbo_current_ubatch_id() if self.use_fp8_dispatch: - assert hidden_size % 128 == 0, \ - "DeepEP kernels quantize the inputs in blocks of shape 128" - - has_per_token_scales = quant_config.a1_scale.numel( - ) != 1 if quant_config.a1_scale is not None else ( - quant_config.a2_scale.numel() != 1 - if quant_config.a2_scale is not None else False) + assert hidden_size % 128 == 0, ( + "DeepEP kernels quantize the inputs in blocks of shape 128" + ) + + has_per_token_scales = ( + quant_config.a1_scale.numel() != 1 + if quant_config.a1_scale is not None + else ( + quant_config.a2_scale.numel() != 1 + if quant_config.a2_scale is not None + else False + ) + ) assert not has_per_token_scales, ( - "low_latency kernels doesn't support dispatching per-token scales") + "low_latency kernels doesn't support dispatching per-token scales" + ) if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1") + "apply_router_weight_on_input is only implemented for topk=1" + ) a1 = a1 * topk_weights.to(a1.dtype) # Dispatch - expert_x, expert_num_tokens, handle, _, hook= \ - self.buffer.low_latency_dispatch(a1, - topk_ids, - self.max_tokens_per_rank, - num_experts, - use_fp8=self.use_fp8_dispatch, - async_finish=False, - return_recv_hook=True) + expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch( + a1, + topk_ids, + self.max_tokens_per_rank, + num_experts, + use_fp8=self.use_fp8_dispatch, + async_finish=False, + return_recv_hook=True, + ) self.handles[a2a_idx] = handle return ( hook, - lambda: self._receiver(expert_x, expert_num_tokens, quant_config. 
- a1_scale, a1.dtype, quant_config)) + lambda: self._receiver( + expert_x, + expert_num_tokens, + quant_config.a1_scale, + a1.dtype, + quant_config, + ), + ) def _receiver( self, @@ -172,11 +203,11 @@ def _receiver( a1_dtype: torch.dtype, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, - quant_config) + expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config) expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None + ) return expert_x, expert_x_scale, expert_tokens_meta, None, None @@ -190,10 +221,15 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - hook, receiver = self.prepare_async(a1, topk_weights, topk_ids, - num_experts, expert_map, - apply_router_weight_on_input, - quant_config) + hook, receiver = self.prepare_async( + a1, + topk_weights, + topk_ids, + num_experts, + expert_map, + apply_router_weight_on_input, + quant_config, + ) hook() return receiver() @@ -207,9 +243,9 @@ def _finalize( weight_and_reduce_impl: mk.TopKWeightAndReduce, do_async: bool, ) -> tuple[Callable, Callable]: - assert isinstance( - weight_and_reduce_impl, TopKWeightAndReduceDelegate - ), ("Weight application and reduction happens in the combine kernel.") + assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), ( + "Weight application and reduction happens in the combine kernel." + ) a2a_idx = dbo_current_ubatch_id() do_recv_hook = dbo_enabled() or do_async @@ -231,7 +267,8 @@ def _finalize( async_finish=False, zero_copy=False, return_recv_hook=do_recv_hook, - out=output) + out=output, + ) return recv_hook, lambda: None diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 3ea4ed39e956..a2d8fe0da154 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -8,40 +8,47 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize) + create_flashinfer_prepare_finalize, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) -from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, - has_flashinfer_cutlass_fused_moe) + TopKWeightAndReduceNoOP, +) +from vllm.utils.flashinfer import ( + flashinfer_cutlass_fused_moe, + has_flashinfer_cutlass_fused_moe, +) logger = init_logger(__name__) -def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor) -> bool: +def is_valid_flashinfer_cutlass_fused_moe( + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor +) -> bool: """ Check if the given problem size is supported by the FlashInfer CUTLASS MoE kernel. """ if not has_flashinfer_cutlass_fused_moe(): - logger.debug_once("FlashInferExperts disabled: " - "flashinfer_cutlass_fused_moe not available.") + logger.debug_once( + "FlashInferExperts disabled: flashinfer_cutlass_fused_moe not available." 
+ ) return False # Data type checks - if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8 - or hidden_states.dtype - not in [torch.float32, torch.float16, torch.bfloat16]): + if ( + w1.dtype != torch.uint8 + or w2.dtype != torch.uint8 + or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16] + ): logger.debug_once( "FlashInferExperts disabled: w1/w2 must be torch.uint8 " f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " - f"float32, float16, or bfloat16 (got {hidden_states.dtype}).") + f"float32, float16, or bfloat16 (got {hidden_states.dtype})." + ) return False return True class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, out_dtype: torch.dtype, @@ -52,10 +59,10 @@ def __init__( tp_size: int = 1, ): super().__init__(quant_config) - assert quant_config.quant_dtype in ( - "nvfp4", torch.float8_e4m3fn, - None), ("Only nvfp4, fp8, bfloat16 and" - " float16 quantization are currently supported.") + assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), ( + "Only nvfp4, fp8, bfloat16 and" + " float16 quantization are currently supported." + ) self.ep_rank = ep_rank self.ep_size = ep_size self.tp_rank = tp_rank @@ -64,10 +71,12 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_expert_map(self) -> bool: return False @@ -110,10 +119,8 @@ def workspace_shapes( of each tuple must be the number of tokens. """ aq_m, aq_n = aq.shape - workspace2 = (0, ) - output_shape = (aq_m, - aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m, - aq_n) + workspace2 = (0,) + output_shape = (aq_m, aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m, aq_n) workspace_dtype = a.dtype workspace1 = output_shape # The workspace is determined by `aq`, since it comes after any @@ -138,13 +145,16 @@ def apply( expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: Optional[bool], ): - - assert activation == "silu", ("Only activation silu is supported in " - "FlashInferExperts") + assert activation == "silu", ( + "Only activation silu is supported in FlashInferExperts" + ) if self.quant_dtype == torch.float8_e4m3fn: quant_scales = [ - self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale + self.g1_alphas, + self.a2_gscale, + self.g2_alphas, + self.a1_gscale, ] a1q_scale = None # not passing input_sf in fp8 @@ -153,8 +163,8 @@ def apply( elif self.quant_dtype == "nvfp4": # Ensure w1_scale and w2_scale are not None before calling view assert self.w1_scale is not None and self.w2_scale is not None, ( - "w1_scale and w2_scale must not " - "be None for FlashInferExperts") + "w1_scale and w2_scale must not be None for FlashInferExperts" + ) # Flashinfer CUTLASS kernel takes scalar global scales, # min because inv_scale. 
quant_scales = [ @@ -209,7 +219,8 @@ def flashinfer_cutlass_moe_fp4( FlashInferExperts( out_dtype=hidden_states.dtype, quant_config=quant_config, - )) + ), + ) return fused_experts( hidden_states=hidden_states, @@ -252,7 +263,8 @@ def flashinfer_cutlass_moe( tp_size=tp_size, ep_rank=ep_rank, ep_size=ep_size, - )) + ), + ) return fused_experts( hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index ed364ac77b28..04bc987d0885 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -7,11 +7,11 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.distributed import get_dp_group, get_ep_group from vllm.distributed.device_communicators.base_device_communicator import ( - All2AllManagerBase) + All2AllManagerBase, +) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils.flashinfer import nvfp4_block_scale_interleave @@ -55,13 +55,13 @@ def _apply_router_weight_on_input( """Apply router weight on input if needed.""" if apply_router_weight_on_input: topk = topk_ids.size(1) - assert topk == 1, \ + assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1" + ) a1.mul_(topk_weights.to(a1.dtype)) -class FlashInferAllToAllMoEPrepareAndFinalize( - FlashInferCutlassMoEPrepareAndFinalize): +class FlashInferAllToAllMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize): """FlashInfer implementation using AllToAll communication.""" def __init__( @@ -75,8 +75,7 @@ def __init__( # Initialize all2all_manager only for DP case self.all2all_manager = None if self.use_dp: - self.all2all_manager = get_ep_group( - ).device_communicator.all2all_manager + self.all2all_manager = get_ep_group().device_communicator.all2all_manager def prepare( self, @@ -88,9 +87,9 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - - self._apply_router_weight_on_input(a1, topk_weights, topk_ids, - apply_router_weight_on_input) + self._apply_router_weight_on_input( + a1, topk_weights, topk_ids, apply_router_weight_on_input + ) if not self.use_dp: # Non-DP case: standard quantization @@ -107,18 +106,19 @@ def prepare( global_num_tokens_cpu = get_local_sizes() top_k = topk_ids.size(1) - (self.alltoall_info, topk_ids, topk_weights, a1q, - a1q_scale) = flashinfer_alltoall_dispatch( - self.all2all_manager, - global_num_tokens_cpu, - a1, - quant_config.a1_gscale, - topk_ids, - topk_weights, - top_k, - num_experts, - quant_config, - ) + (self.alltoall_info, topk_ids, topk_weights, a1q, a1q_scale) = ( + flashinfer_alltoall_dispatch( + self.all2all_manager, + global_num_tokens_cpu, + a1, + quant_config.a1_gscale, + topk_ids, + topk_weights, + top_k, + num_experts, + quant_config, + ) + ) return a1q, a1q_scale, None, topk_ids, topk_weights @@ -144,9 +144,7 @@ def finalize( output.copy_(fused_expert_output) -class FlashInferAllGatherMoEPrepareAndFinalize( - FlashInferCutlassMoEPrepareAndFinalize): - +class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFinalize): def __init__( self, use_dp: bool, 
@@ -164,9 +162,9 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - - self._apply_router_weight_on_input(a1, topk_weights, topk_ids, - apply_router_weight_on_input) + self._apply_router_weight_on_input( + a1, topk_weights, topk_ids, apply_router_weight_on_input + ) a1q, a1q_scale = moe_kernel_quantize_input( a1, @@ -177,12 +175,11 @@ def prepare( is_fp4_scale_swizzled=not self.use_dp, ) if self.use_dp: - topk_weights, topk_ids, a1q, a1q_scale = \ - get_dp_group().all_gatherv( - [topk_weights, topk_ids, a1q, a1q_scale], - dim=0, - sizes=get_local_sizes(), - ) + topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv( + [topk_weights, topk_ids, a1q, a1q_scale], + dim=0, + sizes=get_local_sizes(), + ) if quant_config.quant_dtype == "nvfp4": a1q_scale = nvfp4_block_scale_interleave(a1q_scale) @@ -197,10 +194,10 @@ def finalize( apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: - if self.use_dp: fused_expert_output = get_dp_group().reduce_scatterv( - fused_expert_output, dim=0, sizes=get_local_sizes()) + fused_expert_output, dim=0, sizes=get_local_sizes() + ) output.copy_(fused_expert_output) @@ -216,13 +213,16 @@ def flashinfer_alltoall_dispatch( quant_config: FusedMoEQuantConfig, ): from flashinfer.comm.trtllm_alltoall import MnnvlMoe - assert (all2all_manager.ensure_alltoall_workspace_initialized() - ), "FlashInfer AllToAll workspace not available" + + assert all2all_manager.ensure_alltoall_workspace_initialized(), ( + "FlashInfer AllToAll workspace not available" + ) ep_rank = all2all_manager.rank ep_size = all2all_manager.world_size - max_num_token = max(global_num_tokens_cpu - ) if global_num_tokens_cpu is not None else x.shape[0] + max_num_token = ( + max(global_num_tokens_cpu) if global_num_tokens_cpu is not None else x.shape[0] + ) alltoall_info, topk_ids, topk_weights, _ = ( MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( topk_ids, @@ -235,7 +235,8 @@ def flashinfer_alltoall_dispatch( num_experts, num_experts, top_k, - )) + ) + ) x, x_sf = moe_kernel_quantize_input( x, @@ -272,8 +273,10 @@ def flashinfer_alltoall_combine( alltoall_info, ): from flashinfer.comm.trtllm_alltoall import MnnvlMoe - assert (all2all_manager.ensure_alltoall_workspace_initialized() - ), "FlashInfer AllToAll workspace not available" + + assert all2all_manager.ensure_alltoall_workspace_initialized(), ( + "FlashInfer AllToAll workspace not available" + ) return MnnvlMoe.mnnvl_moe_alltoallv_combine( output, alltoall_info, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 74bcffd8ca03..d12d05915566 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -1,37 +1,39 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import List # noqa: UP035 from typing import Optional import torch -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - calculate_tile_tokens_dim) + calculate_tile_tokens_dim, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.utils import 
direct_register_custom_op def flashinfer_fused_moe_blockscale_fp8( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: List[int], #noqa: UP006 - routed_scaling: float = 1.0) -> torch.Tensor: + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: list[int], + routed_scaling: float = 1.0, +) -> torch.Tensor: from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + assert top_k <= global_num_experts assert top_k <= 8 assert topk_group <= 4 @@ -63,30 +65,32 @@ def flashinfer_fused_moe_blockscale_fp8( local_expert_offset=expert_offset, local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling, - tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k, - global_num_experts), + tile_tokens_dim=calculate_tile_tokens_dim( + x.shape[0], top_k, global_num_experts + ), routing_method_type=2, # DeepSeek-styled routing method use_shuffled_weight=False, ) def flashinfer_fused_moe_blockscale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: list[int], - routed_scaling: float = 1.0) -> torch.Tensor: + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: list[int], + routed_scaling: float = 1.0, +) -> torch.Tensor: return torch.empty_like(x) @@ -95,30 +99,31 @@ def flashinfer_fused_moe_blockscale_fp8_fake( op_name="flashinfer_fused_moe_blockscale_fp8", op_func=flashinfer_fused_moe_blockscale_fp8, fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(torch.Tag.needs_fixed_stride_order,), ) def flashinfer_fused_moe_per_tensor_scale_fp8( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - routed_scaling_factor: float = 1.0) -> torch.Tensor: + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], 
+ hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + output2_scales_scalar: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0, +) -> torch.Tensor: num_expert_group = num_expert_group if num_expert_group is not None else 0 topk_group = topk_group if topk_group is not None else 0 @@ -126,10 +131,11 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( hidden_states, input_scale, quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=False) + per_act_token_quant=False, + ) + + from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe - from vllm.utils.flashinfer import ( - flashinfer_trtllm_fp8_per_tensor_scale_moe) return flashinfer_trtllm_fp8_per_tensor_scale_moe( routing_logits=routing_logits, routing_bias=routing_bias, @@ -148,31 +154,34 @@ def flashinfer_fused_moe_per_tensor_scale_fp8( local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, use_routing_scales_on_input=use_routing_scales_on_input, - tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0], - top_k, num_experts), - routing_method_type=routing_method_type) + tile_tokens_dim=calculate_tile_tokens_dim( + hidden_states.shape[0], top_k, num_experts + ), + routing_method_type=routing_method_type, + ) def flashinfer_fused_moe_per_tensor_scale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: Optional[torch.Tensor], - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: Optional[int], - topk_group: Optional[int], - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - routed_scaling_factor: float = 1.0) -> torch.Tensor: + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm2_weights: torch.Tensor, + output1_scales_scalar: torch.Tensor, + output1_scales_gate_scalar: torch.Tensor, + output2_scales_scalar: torch.Tensor, + num_experts: int, + top_k: int, + num_expert_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + use_routing_scales_on_input: bool, + routing_method_type: int, + routed_scaling_factor: float = 1.0, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -182,5 +191,5 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake( op_func=flashinfer_fused_moe_per_tensor_scale_fp8, mutates_args=["hidden_states"], fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(torch.Tag.needs_fixed_stride_order,), ) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index fee628eae4d8..2a768c75b0bc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -1,21 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" + from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_moe import ( - try_get_optimal_moe_config) +from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched) + TopKWeightAndReduceDelegate, + TopKWeightAndReduceNaiveBatched, +) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape, - normalize_scales_shape) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - group_broadcast) + _resize_cache, + moe_kernel_quantize_input, + normalize_batched_scales_shape, + normalize_scales_shape, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast from vllm.triton_utils import tl, triton @@ -56,12 +60,12 @@ def moe_mmk( use_w8a16: tl.constexpr, per_act_token_quant: tl.constexpr, ): - offs_k = tl.arange(0, BLOCK_K) if use_w8a16: - b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[ - None, :] * stride_bsn + b_scale_ptrs = ( + b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn + ) b_scale = tl.load(b_scale_ptrs) if use_w8a8: @@ -94,9 +98,11 @@ def moe_mmk( for k in range(0, tl.cdiv(K, BLOCK_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load(a_ptrs, - mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), - other=0.0) + a = tl.load( + a_ptrs, + mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K), + other=0.0, + ) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) # We accumulate along the K dimension. 
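# --- Illustrative sketch, not part of this patch -----------------------------
# What the group-quantized branch of this K loop computes, written as plain
# PyTorch over a single (M, N) tile. Sizes and group_k are invented; the real
# kernel operates on fp8/int8 blocks fetched with tl.load, and with
# BLOCK_K == group_k it consumes exactly one (a_scale, b_scale) pair per
# iteration, mirroring
#   accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
import torch

M, N, K, group_k = 4, 8, 256, 128                 # one tile, two K groups
a = torch.randn(M, K)                             # activations tile
b = torch.randn(K, N)                             # weights tile
a_scale = torch.rand(M, K // group_k) + 0.5       # per (row, K-group) scale
b_scale = torch.rand(K // group_k, N) + 0.5       # per (K-group, column) scale

acc = torch.zeros(M, N)
for k0 in range(0, K, group_k):
    g = k0 // group_k
    acc += (a[:, k0:k0 + group_k] @ b[k0:k0 + group_k, :]) \
        * a_scale[:, g:g + 1] * b_scale[g:g + 1, :]
# ------------------------------------------------------------------------------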
if use_w8a16: @@ -105,13 +111,12 @@ def moe_mmk( if group_k > 0 and group_n > 0: k_start = k * BLOCK_K offs_ks = k_start // group_k - a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, - mask=mask_m, - other=0.0) + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=mask_m, other=0.0 + ) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - accumulator += tl.dot(a, b) * a_scale[:, - None] * b_scale[None, :] + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: # acc used to enable fp8_fast_accum accumulator = tl.dot(a, b, acc=accumulator) @@ -137,9 +142,9 @@ def moe_mmk( @triton.jit def expert_triton_kernel( - a_ptr, #[max_tokens, K] - b_ptr, #[K, N] - c_ptr, #[max_tokens, N] + a_ptr, # [max_tokens, K] + b_ptr, # [K, N] + c_ptr, # [max_tokens, N] expert_id, compute_type: tl.constexpr, # Dimensions @@ -177,7 +182,6 @@ def expert_triton_kernel( BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): - offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) % N offs_k = tl.arange(0, BLOCK_K) @@ -221,7 +225,8 @@ def expert_triton_kernel( compute_type, use_fp8_w8a8, use_int8_w8a16, - per_act_token_quant) + per_act_token_quant, + ) # store in C offs_cn = tl.arange(0, BLOCK_N) @@ -284,7 +289,7 @@ def batched_triton_kernel( # axis 1 is M_blocks * N_blocks pid_mn = tl.program_id(axis=1) - #num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) + # num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M) num_pid_n = tl.cdiv(N, BLOCK_N) pid_m = pid_mn // num_pid_n pid_n = pid_mn % num_pid_n @@ -300,8 +305,12 @@ def batched_triton_kernel( a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn - c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm + - cta_n_start * stride_cn) + c_ptr = ( + c_ptr + + expert_id * stride_ce + + cta_m_start * stride_cm + + cta_n_start * stride_cn + ) offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N @@ -350,50 +359,54 @@ def batched_triton_kernel( # Kernel config BLOCK_M, BLOCK_N, - BLOCK_K) + BLOCK_K, + ) def invoke_moe_batched_triton_kernel( - A: torch.Tensor, # [E, max_tokens, K] - B: torch.Tensor, # [E, N, K] - C: torch.Tensor, # [E, max_tokens, N] - expert_num_tokens: torch.Tensor, # [E] - compute_type: tl.dtype, - # Quantization data - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: torch.Tensor, - # Quantization schemes - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - config: dict[str, int], - per_act_token_quant: bool, - block_shape: Optional[list[int]] = None): - + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, N, K] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +): assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2) N = C.size(2) - BLOCK_M = config['BLOCK_SIZE_M'] - BLOCK_N = config['BLOCK_SIZE_N'] - BLOCK_K = config['BLOCK_SIZE_K'] + BLOCK_M = config["BLOCK_SIZE_M"] + BLOCK_N = config["BLOCK_SIZE_N"] + BLOCK_K = config["BLOCK_SIZE_K"] - grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * - triton.cdiv(B.size(1), BLOCK_N)) + grid = ( + expert_num_tokens.size(0), + 
triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N), + ) - A_scale = normalize_batched_scales_shape(A_scale, - expert_num_tokens.shape[0]) + A_scale = normalize_batched_scales_shape(A_scale, expert_num_tokens.shape[0]) if B_scale is not None and B_scale.ndim == 1: assert B_scale.numel() == expert_num_tokens.shape[0] B_scale = B_scale.view(-1, 1, 1) assert A_scale is None or A_scale.ndim == 3, ( - f"{0 if A_scale is None else A_scale.shape}") + f"{0 if A_scale is None else A_scale.shape}" + ) assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, ( - f"{0 if B_scale is None else B_scale.shape}") + f"{0 if B_scale is None else B_scale.shape}" + ) if B_scale is not None: if B_scale.ndim == 1: @@ -459,7 +472,8 @@ def invoke_moe_batched_triton_kernel( # Kernel config BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - BLOCK_K=BLOCK_K) + BLOCK_K=BLOCK_K, + ) class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): @@ -512,16 +526,15 @@ def prepare( if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, \ + assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1" + ) a1.mul_(topk_weights.to(a1.dtype)) num_tokens, hidden_dim = a1.size() topk = topk_ids.size(1) - tokens_per_expert = torch.zeros(num_experts, - dtype=torch.int, - device=a1.device) + tokens_per_expert = torch.zeros(num_experts, dtype=torch.int, device=a1.device) num_local_experts = self.num_local_experts @@ -533,15 +546,15 @@ def prepare( b_a1 = torch.zeros( (num_local_experts, self.max_num_tokens, hidden_dim), dtype=b_type, - device=a1.device) + device=a1.device, + ) if quant_config.is_quantized: scale_shape = quant_config.batched_scale_shape( - num_local_experts, self.max_num_tokens, hidden_dim) + num_local_experts, self.max_num_tokens, hidden_dim + ) - b_a1_scale = torch.empty(scale_shape, - dtype=torch.float32, - device=a1.device) + b_a1_scale = torch.empty(scale_shape, dtype=torch.float32, device=a1.device) else: assert quant_config.a1_scale is None b_a1_scale = None @@ -558,11 +571,11 @@ def prepare( continue idx = expert_id - first_expert tokens_per_expert[idx] = rows - rhs = a1[:topks.numel()][topks] + rhs = a1[: topks.numel()][topks] if quant_config.quant_dtype is not None: if a1_scale is not None: if quant_config.is_per_act_token: - rhs_a1_scale = a1_scale[:topks.numel()][topks] + rhs_a1_scale = a1_scale[: topks.numel()][topks] else: rhs_a1_scale = a1_scale else: @@ -578,14 +591,15 @@ def prepare( if quant_config.is_per_act_token: b_a1_scale[idx, :rows] = b_s[:rows] else: - b_a1_scale[idx, :b_s.shape[0]] = b_s + b_a1_scale[idx, : b_s.shape[0]] = b_s else: b_a1[idx, :rows, :] = rhs assert b_a1_scale is None or b_a1_scale.ndim == 3 expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None) + expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None + ) return b_a1, b_a1_scale, expert_tokens_meta, None, None @@ -632,10 +646,12 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) def supports_chunking(self) -> bool: return False @@ -670,8 +686,7 @@ def workspace_shapes( def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: assert 
self.quant_config.is_quantized f32 = torch.float32 - if (self.quant_config.is_per_act_token - or self.quant_config.is_per_tensor): + if self.quant_config.is_per_act_token or self.quant_config.is_per_tensor: return t.to(f32) * scale else: return t.to(f32) * group_broadcast(scale, t.shape) @@ -699,15 +714,16 @@ def apply( expert_num_tokens = expert_tokens_meta.expert_num_tokens num_local_experts = w1.size(0) - assert num_local_experts == w1.size(0), ( - f"{num_local_experts} == {w1.size(0)}") + assert num_local_experts == w1.size(0), f"{num_local_experts} == {w1.size(0)}" N = w1.size(1) // 2 for expert in range(num_local_experts): # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor - if (torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing()): + if ( + torch.compiler.is_compiling() + or torch.cuda.is_current_stream_capturing() + ): num = hidden_states.shape[1] else: num = int(expert_num_tokens[expert].item()) @@ -719,13 +735,11 @@ def apply( if self.quant_config.is_quantized: assert a1q_scale is not None and self.w1_scale is not None - input = self.dequant(hidden_states[expert, :, :], - a1q_scale[expert]) + input = self.dequant(hidden_states[expert, :, :], a1q_scale[expert]) w1_dq = self.dequant(w1[expert], self.w1_scale[expert]) input = input[:num] @ w1_dq.transpose(0, 1) else: - input = hidden_states[expert, :num, :] @ w1[expert].transpose( - 0, 1) + input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) self.activation(activation, tmp, input.to(tmp.dtype)) @@ -749,17 +763,16 @@ def batched_moe_kernel_quantize_input( per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if (torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing()): + if torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing(): # Note: this does a bunch of extra work because expert_num_tokens is # ignored but it does support torch.compile + cudagraphs. 
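# --- Illustrative sketch, not part of this patch -----------------------------
# The two paths below, reduced to a toy per-tensor "quantizer" so the shape
# behaviour is easy to see: the capture-friendly path quantizes the whole
# padded (E, max_tokens, H) buffer with static shapes, while the eager path
# slices each expert's real token count, which needs a GPU->CPU .item() per
# expert and therefore cannot be captured. All names here are made up.
import torch

def toy_quantize(x: torch.Tensor):
    scale = x.abs().amax().clamp(min=1e-6) / 127.0
    return (x / scale).round().clamp(-127, 127).to(torch.int8), scale

E, max_tokens, H = 3, 4, 8
A = torch.randn(E, max_tokens, H)
expert_num_tokens = torch.tensor([2, 0, 3])       # valid rows per expert

# Static-shape path (torch.compile / cudagraph capture): padding rows get
# quantized too -- wasted work, but every shape stays fixed.
A_q_static, _ = toy_quantize(A.view(-1, H))
A_q_static = A_q_static.view(E, max_tokens, H)

# Eager path: quantize only the rows each expert actually received.
A_q_eager = torch.zeros(E, max_tokens, H, dtype=torch.int8)
for e in range(E):
    n = int(expert_num_tokens[e].item())          # data-dependent slicing
    if n > 0:
        A_q_eager[e, :n], _ = toy_quantize(A[e, :n])
# ------------------------------------------------------------------------------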
hidden_dim = A.size(-1) assert A_scale is None or A_scale.ndim <= 2, ( - f"{A_scale.shape if A_scale is not None else None}") - A_q, A_q_scale = moe_kernel_quantize_input(A.view(-1, - hidden_dim), A_scale, - qtype, per_act_token_quant, - block_shape) + f"{A_scale.shape if A_scale is not None else None}" + ) + A_q, A_q_scale = moe_kernel_quantize_input( + A.view(-1, hidden_dim), A_scale, qtype, per_act_token_quant, block_shape + ) A_q = A_q.view(E, -1, hidden_dim) A_q_scale = normalize_batched_scales_shape(A_q_scale, E) @@ -779,9 +792,7 @@ def batched_moe_kernel_quantize_input( else: scale_shape = (E, 1, 1) - A_q_scale = torch.zeros(scale_shape, - dtype=torch.float32, - device=A.device) + A_q_scale = torch.zeros(scale_shape, dtype=torch.float32, device=A.device) num_experts = expert_num_tokens.numel() @@ -791,7 +802,7 @@ def batched_moe_kernel_quantize_input( num_tokens = int(expert_num_tokens[e].item()) if num_tokens > 0: if A_scale is not None: - scales = A_scale[e, :min(num_tokens, A_scale.shape[1])] + scales = A_scale[e, : min(num_tokens, A_scale.shape[1])] else: scales = None A_q[e, :num_tokens], tmp_scale = moe_kernel_quantize_input( @@ -802,7 +813,7 @@ def batched_moe_kernel_quantize_input( block_shape, ) assert tmp_scale is not None - A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale + A_q_scale[e, : tmp_scale.shape[0]] = tmp_scale return A_q, A_q_scale @@ -832,10 +843,12 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.BatchedExperts, - mk.FusedMoEActivationFormat.BatchedExperts) + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) def supports_chunking(self) -> bool: return False @@ -888,26 +901,28 @@ def apply( ): # Check constraints. 
if self.quant_config.use_int4_w4a16: - assert hidden_states.size(-1) // 2 == w1.size(2), ( - "Hidden size mismatch") + assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" else: assert hidden_states.size(-1) == w1.size(2), ( - f"Hidden size mismatch {hidden_states.size(-1)} " - f"!= {w1.size(2)}") + f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}" + ) - assert hidden_states.is_contiguous( - ), "Hidden_states must be contiguous" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + torch.float32, + torch.float16, + torch.bfloat16, + torch.float8_e4m3fn, ] assert expert_tokens_meta is not None expert_num_tokens = expert_tokens_meta.expert_num_tokens E, max_num_tokens, N, K, top_k_num = self.moe_problem_size( - hidden_states, w1, w2, topk_ids) + hidden_states, w1, w2, topk_ids + ) assert w1.size(0) == E assert w2.size(0) == E @@ -932,15 +947,12 @@ def apply( elif hidden_states.dtype == torch.float8_e4m3fn: compute_type = tl.bfloat16 else: - raise ValueError( - f"Unsupported compute_type: {hidden_states.dtype}") + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - intermediate_cache1 = _resize_cache(workspace13, - (E, max_num_tokens, N)) - intermediate_cache2 = _resize_cache(workspace2, - (E, max_num_tokens, N // 2)) + intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N)) + intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) # TODO(bnell): should this be done for any quantized type? if self.quant_config.use_fp8_w8a8: @@ -963,18 +975,29 @@ def apply( use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape) + block_shape=self.block_shape, + ) intermediate_cache2.fill_(0) # TODO (bnell): use triton utility from batched deep gemm. 
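# Illustrative sketch: the `size(-1) // 2` check above exists because int4 weights store two
# 4-bit values per byte. The helpers below show that packing scheme, assuming unsigned
# nibbles in [0, 15] with the even index in the low nibble; the exact layout of vLLM's
# packed weights may differ (the kernels later unpack with `(b >> shift) & 0xF`).
import torch

def pack_int4(w: torch.Tensor) -> torch.Tensor:
    """Pack (..., K) values in [0, 15] into (..., K // 2) uint8."""
    assert w.shape[-1] % 2 == 0
    lo = w[..., 0::2].to(torch.uint8)
    hi = w[..., 1::2].to(torch.uint8)
    return lo | (hi << 4)

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_int4: (..., K // 2) uint8 -> (..., K) values in [0, 15]."""
    lo = packed & 0xF
    hi = (packed >> 4) & 0xF
    return torch.stack((lo, hi), dim=-1).flatten(-2)

# round-trip check:
# w = torch.randint(0, 16, (4, 8)); assert torch.equal(unpack_int4(pack_int4(w)).to(w.dtype), w)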
- self.activation(activation, intermediate_cache2.view(-1, N // 2), - intermediate_cache1.view(-1, N)) + self.activation( + activation, + intermediate_cache2.view(-1, N // 2), + intermediate_cache1.view(-1, N), + ) qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( - intermediate_cache2, a2_scale, max_num_tokens, E, N, - expert_num_tokens, self.quant_dtype, self.per_act_token_quant, - self.block_shape) + intermediate_cache2, + a2_scale, + max_num_tokens, + E, + N, + expert_num_tokens, + self.quant_dtype, + self.per_act_token_quant, + self.block_shape, + ) invoke_moe_batched_triton_kernel( A=qintermediate_cache2, @@ -990,4 +1013,5 @@ def apply( use_int4_w4a16=self.quant_config.use_int4_w4a16, config=config, per_act_token_quant=self.per_act_token_quant, - block_shape=self.block_shape) + block_shape=self.block_shape, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 617d871a5b3d..c46cc016214f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused MoE utilities for GPTQ.""" + from typing import Optional import torch @@ -11,44 +12,49 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) + TopKWeightAndReduceNoOP, +) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_make_workspace_new, marlin_moe_intermediate_size, - maybe_warn_marlin_atomic_add) + marlin_make_workspace_new, + marlin_moe_intermediate_size, + maybe_warn_marlin_atomic_add, +) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import direct_register_custom_op -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - bias1: Optional[torch.Tensor], - bias2: Optional[torch.Tensor], - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - gating_output: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_type_id: int, - apply_router_weight_on_input: bool = False, - global_num_experts: int = -1, - activation: Optional[str] = "silu", - expert_map: Optional[torch.Tensor] = None, - global_scale1: Optional[torch.Tensor] = None, - global_scale2: Optional[torch.Tensor] = None, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - intermediate_cache13: Optional[torch.Tensor] = None, - intermediate_cache2: Optional[torch.Tensor] = None, - is_k_full: bool = True, - output: Optional[torch.Tensor] = None, - inplace: bool = False) -> torch.Tensor: +def fused_marlin_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + bias1: Optional[torch.Tensor], + bias2: Optional[torch.Tensor], + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_type_id: int, + 
apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + activation: Optional[str] = "silu", + expert_map: Optional[torch.Tensor] = None, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + intermediate_cache13: Optional[torch.Tensor] = None, + intermediate_cache2: Optional[torch.Tensor] = None, + is_k_full: bool = True, + output: Optional[torch.Tensor] = None, + inplace: bool = False, +) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -78,23 +84,29 @@ def fused_marlin_moe(hidden_states: torch.Tensor, """ quant_type = ScalarType.from_id(quant_type_id) assert quant_type in [ - scalar_types.uint4, scalar_types.uint8b128, scalar_types.uint4b8, - scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f + scalar_types.uint4, + scalar_types.uint8b128, + scalar_types.uint4b8, + scalar_types.float8_e4m3fn, + scalar_types.float4_e2m1f, ] bit4_scalar_types = [ - scalar_types.uint4, scalar_types.uint4b8, scalar_types.float4_e2m1f + scalar_types.uint4, + scalar_types.uint4b8, + scalar_types.float4_e2m1f, ] num_bits = 4 if quant_type in bit4_scalar_types else 8 # Check constraints. if gating_output is not None: - assert hidden_states.shape[0] == gating_output.shape[ - 0], "Number of tokens mismatch" - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[1] == w2.shape[2] // ( - num_bits // 2), "Hidden size mismatch w2" + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch" + ) + assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // (num_bits // 2), ( + "Hidden size mismatch w2" + ) assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" @@ -115,9 +127,9 @@ def fused_marlin_moe(hidden_states: torch.Tensor, if global_num_experts == -1: global_num_experts = E - sorted_token_ids, expert_ids, num_tokens_post_padded = \ - moe_align_block_size(topk_ids, block_size_m, global_num_experts, - expert_map) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, block_size_m, global_num_experts, expert_map + ) if workspace is None: workspace = marlin_make_workspace_new(hidden_states.device, 4) @@ -131,19 +143,20 @@ def fused_marlin_moe(hidden_states: torch.Tensor, if intermediate_cache13 is None: intermediate_cache13 = torch.empty( - (M * topk * max(2 * N, K), ), + (M * topk * max(2 * N, K),), device=hidden_states.device, dtype=hidden_states.dtype, ) - intermediate_cache1 = _resize_cache(intermediate_cache13, - (M * topk, 2 * N)) + intermediate_cache1 = _resize_cache(intermediate_cache13, (M * topk, 2 * N)) intermediate_cache3 = _resize_cache(intermediate_cache13, (M * topk, K)) intermediate_cache2 = _resize_cache(intermediate_cache2, (M * topk, N)) maybe_warn_marlin_atomic_add(hidden_states.device, hidden_states.dtype) - use_atomic_add = hidden_states.dtype == torch.half or \ - 
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 + use_atomic_add = ( + hidden_states.dtype == torch.half + or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 + ) intermediate_cache1 = ops.moe_wna16_marlin_gemm( hidden_states, @@ -171,18 +184,23 @@ def fused_marlin_moe(hidden_states: torch.Tensor, is_k_full=is_k_full, use_atomic_add=use_atomic_add, use_fp32_reduce=True, - is_zp_float=False) + is_zp_float=False, + ) if activation == "silu": - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, 2 * N)) + torch.ops._C.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, 2 * N) + ) elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, 2 * N)) + torch.ops._C.swigluoai_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, 2 * N) + ) else: - raise ValueError(f"Unsupported activation: {activation}. " - "Only silu and swigluoai activations are supported.") + raise ValueError( + f"Unsupported activation: {activation}. " + "Only silu and swigluoai activations are supported." + ) if expert_map is not None: intermediate_cache3.zero_() @@ -213,39 +231,42 @@ def fused_marlin_moe(hidden_states: torch.Tensor, is_k_full=is_k_full, use_atomic_add=use_atomic_add, use_fp32_reduce=True, - is_zp_float=False).view(-1, topk, K) + is_zp_float=False, + ).view(-1, topk, K) if output is None: output = hidden_states if inplace else torch.empty_like(hidden_states) return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output) -def fused_marlin_moe_fake(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - gating_output: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - quant_type_id: int, - apply_router_weight_on_input: bool = False, - global_num_experts: int = -1, - global_scale1: Optional[torch.Tensor] = None, - global_scale2: Optional[torch.Tensor] = None, - expert_map: Optional[torch.Tensor] = None, - g_idx1: Optional[torch.Tensor] = None, - g_idx2: Optional[torch.Tensor] = None, - sort_indices1: Optional[torch.Tensor] = None, - sort_indices2: Optional[torch.Tensor] = None, - w1_zeros: Optional[torch.Tensor] = None, - w2_zeros: Optional[torch.Tensor] = None, - workspace: Optional[torch.Tensor] = None, - intermediate_cache13: Optional[torch.Tensor] = None, - intermediate_cache2: Optional[torch.Tensor] = None, - is_k_full: bool = True, - output: Optional[torch.Tensor] = None, - inplace: bool = False) -> torch.Tensor: +def fused_marlin_moe_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + quant_type_id: int, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + global_scale1: Optional[torch.Tensor] = None, + global_scale2: Optional[torch.Tensor] = None, + expert_map: Optional[torch.Tensor] = None, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + intermediate_cache13: Optional[torch.Tensor] = None, + intermediate_cache2: Optional[torch.Tensor] = None, + is_k_full: bool = True, + output: 
Optional[torch.Tensor] = None, + inplace: bool = False, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -257,7 +278,6 @@ def fused_marlin_moe_fake(hidden_states: torch.Tensor, class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__(self, quant_config: FusedMoEQuantConfig): # TODO (varun) : Enable activation quantization assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" @@ -279,8 +299,7 @@ def moe_problem_size( if a1.dim() == 2: # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.size(0) == a1.size(0), \ - f"{topk_ids.size(0)} != {a1.size(0)}" + assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}" M = a1.size(0) else: assert a1.dim() == 3 @@ -300,18 +319,27 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True def workspace_shapes( - self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, - topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata] + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # Modular Kernel provisions output buffer from workspace1. However in # the fused_marlin_moe() function, the final torch.sum(), is defined @@ -323,13 +351,13 @@ def workspace_shapes( # workspace2. # Workspace/IntermediateCache allocation matching fused_marlin_moe() - #workspace1 = (M * topk * max(2 * N, K),) - #workspace2 = (M * topk, N) + # workspace1 = (M * topk * max(2 * N, K),) + # workspace2 = (M * topk, N) # Workspace/IntermediateCache allocation accounting for output buffer # provisioning workspace1 = (M * topk, max(N, K)) - workspace2 = (M * topk * max(2 * N, K), ) + workspace2 = (M * topk * max(2 * N, K),) output = (M, K) return (workspace1, workspace2, output, a.dtype) @@ -374,4 +402,5 @@ def apply( # Workspaces are swapped in workspace_shapes() to account for proper # output buffer allocation. Please refer to workspace_shapes(). 
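# Illustrative sketch: moe_problem_size() above recovers (E, M, N, K, top_k) purely from
# tensor shapes. The sketch below assumes unpacked layouts w1: (E, 2N, K), w2: (E, K, N),
# a1: (M, K), topk_ids: (M, top_k); the real Marlin weights are bit-packed, so the actual
# shape arithmetic differs — this only shows the idea.
import torch

def moe_problem_size_sketch(a1, w1, w2, topk_ids):
    E, two_n, K = w1.shape
    assert w2.shape == (E, K, two_n // 2), "w2 must be (E, K, N)"
    assert a1.shape[-1] == K and topk_ids.shape[0] == a1.shape[0]
    M, top_k = topk_ids.shape
    return E, M, two_n // 2, K, top_k

# e.g. moe_problem_size_sketch(torch.randn(5, 32), torch.randn(8, 128, 32),
#                              torch.randn(8, 32, 64), torch.zeros(5, 2, dtype=torch.long))
# -> (8, 5, 64, 32, 2)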
intermediate_cache13=workspace2, - intermediate_cache2=workspace13) + intermediate_cache2=workspace13, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f96525734fd9..fe934e56e6bf 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused MoE Triton kernels.""" + import functools import json import os @@ -13,25 +14,34 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger -# yapf: disable from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str) + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEQuantConfig, + _get_config_dtype_str, +) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( _valid_cutlass_block_scaled_grouped_gemm, - run_cutlass_block_scaled_fused_experts) -# yapf: enable + run_cutlass_block_scaled_fused_experts, +) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - _valid_deep_gemm, deep_gemm_moe_fp8) + _valid_deep_gemm, + deep_gemm_moe_fp8, +) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) + moe_align_block_size, +) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP) + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) + TopKWeightAndReduceNoOP, +) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, activation_without_mul, moe_kernel_quantize_input) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - dequant_mxfp4) + _resize_cache, + activation_without_mul, + moe_kernel_quantize_input, +) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer @@ -43,64 +53,73 @@ @triton.jit -def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, - token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N, - compute_type): +def write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, +): accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) @triton.jit def fused_moe_kernel_gptq_awq( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - b_scale_ptr, - b_zp_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - N: tl.constexpr, - K: tl.constexpr, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). 
- stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_bse, - stride_bsk, - stride_bsn, - stride_bze, - stride_bzk, - stride_bzn, - block_k_diviable: tl.constexpr, - group_size: tl.constexpr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, - has_zp: tl.constexpr, - use_int4_w4a16: tl.constexpr, - use_int8_w8a16: tl.constexpr): + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr, +): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. @@ -149,8 +168,7 @@ def fused_moe_kernel_gptq_awq( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( - tl.int64) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens @@ -159,25 +177,41 @@ def fused_moe_kernel_gptq_awq( # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. 
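# Illustrative sketch: "expert is not in the current expert parallel rank" above refers to
# the expert_map indirection — global expert ids are remapped to local ids, with -1 for
# experts owned by another EP rank, and those rows are written back as zeros. The rank
# layout in the example below is made up for illustration.
import torch

def apply_expert_map_sketch(topk_ids: torch.Tensor, expert_map: torch.Tensor) -> torch.Tensor:
    """topk_ids holds global expert ids; expert_map[g] is the local expert id or -1."""
    return expert_map[topk_ids]

# 8 global experts, this rank owns experts 4..7 as local experts 0..3:
# expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])
# apply_expert_map_sketch(torch.tensor([[1, 5], [4, 7]]), expert_map)
# -> tensor([[-1, 1], [0, 3]]); rows hitting -1 contribute zeros on this rank.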
- write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, - offs_token, token_mask, BLOCK_SIZE_M, - BLOCK_SIZE_N, compute_type) + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) return - offs_bn = (pid_n * BLOCK_SIZE_N + - tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) if use_int4_w4a16: - b_ptrs = b_ptr + off_experts * stride_be + \ - (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \ - stride_bn + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] // 2) * stride_bk + + offs_bn[None, :] * stride_bn + ) b_shifter = (offs_k[:, None] % 2) * 4 elif use_int8_w8a16: - b_ptrs = b_ptr + off_experts * stride_be + \ - offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + b_ptrs = ( + b_ptr + + off_experts * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) if not has_zp and use_int4_w4a16: b_zp_num = 8 @@ -203,34 +237,43 @@ def fused_moe_kernel_gptq_awq( k_mask = None k_other = None - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) b = tl.load(b_ptrs) if use_int4_w4a16: b = (b >> b_shifter) & 0xF - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ - offs_bn[None, :] * stride_bsn + \ - ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * \ - stride_bsk + b_scale_ptrs = ( + b_scale_ptr + + off_experts * stride_bse + + offs_bn[None, :] * stride_bsn + + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + ) b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) b_scale = b_scale.to(tl.float32) if has_zp and use_int4_w4a16: offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size - b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ - (offs_bn[None, :] // 2) * stride_bzn + \ - offs_k_true * stride_bzk + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + (offs_bn[None, :] // 2) * stride_bzn + + offs_k_true * stride_bzk + ) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) - b_zp = ((b_zp >> b_zp_shifter) & 0xF) + b_zp = (b_zp >> b_zp_shifter) & 0xF b_zp = b_zp.to(tl.float32) elif has_zp and use_int8_w8a16: offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size - b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ - offs_bn[None, :] * stride_bzn + \ - offs_k_true * stride_bzk + b_zp_ptrs = ( + b_zp_ptr + + off_experts * stride_bze + + offs_bn[None, :] * stride_bzn + + offs_k_true * stride_bzk + ) b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) b_zp = b_zp.to(tl.float32) @@ -249,17 +292,14 @@ def fused_moe_kernel_gptq_awq( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + 
stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) @@ -365,8 +405,7 @@ def fused_moe_kernel( num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( - tl.int64) + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) token_mask = offs_token < num_valid_tokens @@ -375,22 +414,35 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back zeros to the output when the expert is not # in the current expert parallel rank. - write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, - offs_token, token_mask, BLOCK_SIZE_M, - BLOCK_SIZE_N, compute_type) + write_zeros_to_output( + c_ptr, + stride_cm, + stride_cn, + pid_n, + N, + offs_token, + token_mask, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + compute_type, + ) return - offs_bn = (pid_n * BLOCK_SIZE_N + - tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) + a_ptrs = a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) - b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) + b_ptrs = ( + b_ptr + + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + ) if use_int8_w8a16: - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ - None, :] * stride_bsn + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8 or use_int8_w8a8: @@ -398,17 +450,18 @@ def fused_moe_kernel( if group_k > 0 and group_n > 0: a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm offs_bsn = offs_bn // group_n - b_scale_ptrs = (b_scale_ptr + off_experts * stride_bse + - offs_bsn * stride_bsn) + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) # channel-wise elif per_channel_quant: - b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ - None, :] * stride_bsn + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + ) b_scale = tl.load(b_scale_ptrs) # Load per-token scale for activations a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm - a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, - None] + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] # tensor-wise else: a_scale = tl.load(a_scale_ptr) @@ -426,13 +479,12 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. 
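# Illustrative sketch: a plain-PyTorch restatement of the block-quantized K accumulation
# loop here, assuming the K tile size equals group_k so each tile uses one activation scale
# per token and one weight scale per (k-block, n-group). The shapes below are assumptions
# for the sketch, not the kernel's actual stride layout.
import torch

def block_scaled_matmul_ref(a_q, a_s, b_q, b_s, gk, gn):
    """a_q: (M, K) int8, a_s: (M, K//gk); b_q: (K, N) int8, b_s: (K//gk, N//gn)."""
    M, K = a_q.shape
    N = b_q.shape[1]
    out = torch.zeros(M, N, dtype=torch.float32)
    for kb in range(K // gk):
        a_blk = a_q[:, kb * gk:(kb + 1) * gk].float()
        b_blk = b_q[kb * gk:(kb + 1) * gk, :].float()
        # accumulate this K block, scaled per token (rows) and per weight group (columns)
        scale = a_s[:, kb:kb + 1] * b_s[kb].repeat_interleave(gn)[None, :]
        out += (a_blk @ b_blk) * scale
    return out

# e.g. block_scaled_matmul_ref(torch.randint(-8, 8, (4, 16), dtype=torch.int8),
#                              torch.rand(4, 2),
#                              torch.randint(-8, 8, (16, 8), dtype=torch.int8),
#                              torch.rand(2, 2), gk=8, gn=4).shape -> (4, 8)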
if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) @@ -440,13 +492,12 @@ def fused_moe_kernel( if group_k > 0 and group_n > 0: k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k - a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, - mask=token_mask, - other=0.0) + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) - accumulator += tl.dot(a, b) * a_scale[:, - None] * b_scale[None, :] + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] else: if use_fp8_w8a8: # acc used to enable fp8_fast_accum @@ -461,9 +512,7 @@ def fused_moe_kernel( if HAS_BIAS: accumulator = accumulator + bias[None, :] if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, - mask=token_mask, - other=0) + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) @@ -478,43 +527,46 @@ def fused_moe_kernel( # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] c_mask = token_mask[:, None] & (offs_cn[None, :] < N) tl.store(c_ptrs, accumulator, mask=c_mask) -def invoke_fused_moe_kernel(A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: Optional[torch.Tensor], - topk_weights: Optional[torch.Tensor], - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: dict[str, Any], - compute_type: tl.dtype, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: Optional[list[int]] = None, - B_bias: Optional[torch.Tensor] = None) -> None: +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: dict[str, Any], + compute_type: tl.dtype, + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + per_channel_quant: bool, + block_shape: Optional[list[int]] = None, + B_bias: Optional[torch.Tensor] = None, +) -> None: assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None - assert (block_shape is None - or triton.cdiv(B.size(-2), block_shape[0]) == B_scale.size(-2)) - assert (block_shape is None - or triton.cdiv(B.size(-1), block_shape[1]) == B_scale.size(-1)) + assert block_shape is None or triton.cdiv( + B.size(-2), block_shape[0] + ) == B_scale.size(-2) + assert block_shape is None or triton.cdiv( + B.size(-1), block_shape[1] + ) == B_scale.size(-1) elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None @@ -532,13 +584,17 @@ def invoke_fused_moe_kernel(A: torch.Tensor, # We assume that top_ids of 
each token is unique, # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # and we can skip some invalid blocks. - EM = min(sorted_token_ids.size(0), - A.size(0) * top_k * config['BLOCK_SIZE_M']) - grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( - B.size(1), META['BLOCK_SIZE_N']), ) + EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"]) + grid = lambda META: ( + triton.cdiv(EM, META["BLOCK_SIZE_M"]) + * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]), + ) HAS_BIAS = B_bias is not None - if (use_int8_w8a16 or use_int4_w4a16) and \ - block_shape is not None and block_shape[1] > 0: + if ( + (use_int8_w8a16 or use_int4_w4a16) + and block_shape is not None + and block_shape[1] > 0 + ): assert B_scale is not None and B_scale.ndim == 3 assert B_zp is None or B_zp.ndim == 3 @@ -546,27 +602,41 @@ def invoke_fused_moe_kernel(A: torch.Tensor, num_valid_tokens=num_tokens, group_size=block_shape[1], num_experts=B.size(0), - bit=4 if use_int4_w4a16 else 8) + bit=4 if use_int4_w4a16 else 8, + ) config = config.copy() config.update( - get_moe_wna16_block_config(config=config, - use_moe_wna16_cuda=use_moe_wna16_cuda, - num_valid_tokens=num_tokens, - size_k=A.size(1), - size_n=B.size(1), - num_experts=B.size(1), - group_size=block_shape[1], - real_top_k=top_k, - block_size_m=config["BLOCK_SIZE_M"])) + get_moe_wna16_block_config( + config=config, + use_moe_wna16_cuda=use_moe_wna16_cuda, + num_valid_tokens=num_tokens, + size_k=A.size(1), + size_n=B.size(1), + num_experts=B.size(1), + group_size=block_shape[1], + real_top_k=top_k, + block_size_m=config["BLOCK_SIZE_M"], + ) + ) if use_moe_wna16_cuda: bit = 4 if use_int4_w4a16 else 8 - ops.moe_wna16_gemm(A, C, B, B_scale, B_zp, - topk_weights if mul_routed_weight else None, - sorted_token_ids, expert_ids, - num_tokens_post_padded, top_k, - config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], - config["BLOCK_SIZE_K"], bit) + ops.moe_wna16_gemm( + A, + C, + B, + B_scale, + B_zp, + topk_weights if mul_routed_weight else None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + bit, + ) return fused_moe_kernel_gptq_awq[grid]( @@ -610,8 +680,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, config = config.copy() BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K") if block_shape is not None: - BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], - block_shape[1])) + BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1])) fused_moe_kernel[grid]( A, B, @@ -634,16 +703,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B.stride(1), C.stride(1), C.stride(2), - A_scale.stride(0) - if A_scale is not None and A_scale.ndim == 2 else 0, - A_scale.stride(1) - if A_scale is not None and A_scale.ndim == 2 else 0, - B_scale.stride(0) - if B_scale is not None and B_scale.ndim >= 2 else 0, - B_scale.stride(2) - if B_scale is not None and B_scale.ndim == 3 else 0, - B_scale.stride(1) - if B_scale is not None and B_scale.ndim >= 2 else 0, + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, B_bias.stride(0) if B_bias is not None else 0, B_bias.stride(1) if B_bias is not None else 0, 0 if block_shape is None else block_shape[0], @@ 
-680,28 +744,36 @@ def compute_identity_kernel( if batch_id >= num_tokens or dim_offset >= hidden_dim: return - h = tl.load(hidden_states_ptr + batch_id * hidden_dim + dim_offset + - tl.arange(0, BLOCK_SIZE), - mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim) + h = tl.load( + hidden_states_ptr + + batch_id * hidden_dim + + dim_offset + + tl.arange(0, BLOCK_SIZE), + mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim, + ) result = tl.zeros([BLOCK_SIZE], dtype=tl.float32) for i in range(top_k): scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i) result += h * scale - tl.store(output_ptr + batch_id * hidden_dim + dim_offset + - tl.arange(0, BLOCK_SIZE), - result, - mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim) + tl.store( + output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE), + result, + mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim, + ) -def zero_experts_compute_triton(expert_indices: torch.Tensor, - expert_scales: torch.Tensor, num_experts: int, - zero_expert_type: str, - hidden_states: torch.Tensor) -> torch.Tensor: +def zero_experts_compute_triton( + expert_indices: torch.Tensor, + expert_scales: torch.Tensor, + num_experts: int, + zero_expert_type: str, + hidden_states: torch.Tensor, +) -> torch.Tensor: N = expert_indices.numel() top_k = expert_indices.size(-1) - grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), ) + grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) if zero_expert_type == "identity": zero_expert_mask = expert_indices < num_experts @@ -716,7 +788,7 @@ def zero_experts_compute_triton(expert_indices: torch.Tensor, hidden_dim = hidden_states.size(-1) num_tokens = hidden_states.size(0) - grid = lambda meta: (num_tokens * (hidden_dim // meta['BLOCK_SIZE']), ) + grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),) compute_identity_kernel[grid]( top_k, hidden_states, @@ -732,14 +804,14 @@ def zero_experts_compute_triton(expert_indices: torch.Tensor, # Adapted from: https://github.com/sgl-project/sglang/pull/2628 -def get_config_file_name(E: int, - N: int, - dtype: Optional[str], - block_shape: Optional[list[int]] = None) -> str: +def get_config_file_name( + E: int, N: int, dtype: Optional[str], block_shape: Optional[list[int]] = None +) -> str: device_name = current_platform.get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - block_shape_selector = ("" if not block_shape or not all(block_shape) else - f",block_shape={block_shape}").replace(" ", "") + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ).replace(" ", "") return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 @@ -772,18 +844,21 @@ def get_moe_configs( user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER if user_defined_config_folder is not None: user_defined_config_file_path = os.path.join( - user_defined_config_folder, json_file_name) + user_defined_config_folder, json_file_name + ) config_file_paths.append(user_defined_config_file_path) default_config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) config_file_paths.append(default_config_file_path) for config_file_path in config_file_paths: if os.path.exists(config_file_path): with open(config_file_path) as f: - logger.info("Using configuration from %s for MoE layer.", - 
config_file_path) + logger.info( + "Using configuration from %s for MoE layer.", config_file_path + ) # If a configuration has been found, return it tuned_config = json.load(f) # Delete triton_version from tuned_config @@ -793,16 +868,26 @@ def get_moe_configs( # If no optimized configuration is available, we will use the default # configuration logger.warning( - ("Using default MoE config. Performance might be sub-optimal! " - "Config file not found at %s"), config_file_paths) + ( + "Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s" + ), + config_file_paths, + ) return None -def get_moe_wna16_block_config(config: dict[str, - int], use_moe_wna16_cuda: bool, - num_valid_tokens: int, size_k: int, size_n: int, - num_experts: int, group_size: int, - real_top_k: int, block_size_m: int): +def get_moe_wna16_block_config( + config: dict[str, int], + use_moe_wna16_cuda: bool, + num_valid_tokens: int, + size_k: int, + size_n: int, + num_experts: int, + group_size: int, + real_top_k: int, + block_size_m: int, +): if "BLOCK_SIZE_N" in config and "BLOCK_SIZE_K" in config: # optimal block config is set return {} @@ -824,20 +909,24 @@ def get_moe_wna16_block_config(config: dict[str, num_n_blocks = size_k // block_size_k num_k_blocks = size_n // block_size_k - num_m_blocks = (num_valid_tokens + block_size_m - 1) / block_size_m + \ - num_experts + num_m_blocks = ( + num_valid_tokens + block_size_m - 1 + ) / block_size_m + num_experts if num_valid_tokens // real_top_k <= block_size_m: num_m_blocks = min(num_m_blocks, num_valid_tokens) num_blocks = num_m_blocks * num_n_blocks * num_k_blocks - if size_k % 256 == 0 and num_blocks >= 256 and \ - block_size_k < 256: + if size_k % 256 == 0 and num_blocks >= 256 and block_size_k < 256: block_size_k = 256 num_blocks = num_blocks // (256 // block_size_k) - if num_m_blocks <= 16 and size_k % (block_size_k * 2) == 0 and \ - size_k % (block_size_k * 2) == 0 and block_size_k <= 512 and \ - num_blocks >= 512: + if ( + num_m_blocks <= 16 + and size_k % (block_size_k * 2) == 0 + and size_k % (block_size_k * 2) == 0 + and block_size_k <= 512 + and num_blocks >= 512 + ): block_size_k = block_size_k * 2 num_blocks = num_blocks // 2 @@ -856,10 +945,15 @@ def get_moe_wna16_block_config(config: dict[str, return {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k} -def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int, - num_experts: int, bit: int): - return current_platform.is_cuda() and bit == 4 and \ - group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 6 +def should_moe_wna16_use_cuda( + num_valid_tokens: int, group_size: int, num_experts: int, bit: int +): + return ( + current_platform.is_cuda() + and bit == 4 + and group_size in [32, 64, 128] + and num_valid_tokens / num_experts <= 6 + ) def get_default_config( @@ -889,8 +983,7 @@ def get_default_config( # only set BLOCK_SIZE_M # BLOCK_SIZE_N and BLOCK_SIZE_K would be set later bit = 4 if dtype == "int4_w4a16" else 8 - use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, - block_shape[1], E, bit) + use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, block_shape[1], E, bit) if use_moe_wna16_cuda: config = {"BLOCK_SIZE_M": min(16, M)} elif M <= 20: @@ -925,6 +1018,7 @@ def try_get_optimal_moe_config( block_shape: Optional[list[int]] = None, ) -> dict[str, int]: from vllm.model_executor.layers.fused_moe import get_config + override_config = get_config() if override_config: config = override_config @@ -943,15 +1037,17 @@ def 
try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, - block_shape) + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape) return config -def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool) -> tuple[torch.Tensor, ...]: +def vllm_topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> tuple[torch.Tensor, ...]: ops.topk_softmax( topk_weights, topk_indices, @@ -967,6 +1063,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: if is_rocm_aiter_moe_enabled(): from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax + return rocm_aiter_topk_softmax return vllm_topk_softmax @@ -978,31 +1075,29 @@ def fused_topk( renormalize: bool, indices_type: Optional[torch.dtype] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert hidden_states.size(0) == gating_output.size(0), ( - "Number of tokens mismatch") + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" M, _ = hidden_states.size() - topk_weights = torch.empty(M, - topk, - dtype=torch.float32, - device=hidden_states.device) + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) topk_ids = torch.empty( M, topk, dtype=torch.int32 if indices_type is None else indices_type, - device=hidden_states.device) - token_expert_indices = torch.empty(M, - topk, - dtype=torch.int32, - device=hidden_states.device) + device=hidden_states.device, + ) + token_expert_indices = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. 
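# Illustrative sketch: fused_topk() here fills these buffers with a fused softmax + top-k
# kernel. Below is a plain-PyTorch reference of the same routing math (a sketch only; the
# fused op also produces token_expert_indices and may order ties differently).
import torch

def topk_softmax_ref(gating_output: torch.Tensor, topk: int, renormalize: bool):
    scores = torch.softmax(gating_output.float(), dim=-1)        # (M, E)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)  # (M, topk) each
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

# e.g. topk_softmax_ref(torch.randn(5, 8), topk=2, renormalize=True)[0].sum(-1) -> all ones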
topk_func = dispatch_topk_func() - topk_weights, topk_ids = topk_func(topk_weights, topk_ids, - token_expert_indices, - gating_output_float, renormalize) + topk_weights, topk_ids = topk_func( + topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize + ) return topk_weights, topk_ids, token_expert_indices @@ -1017,9 +1112,9 @@ def fused_topk_bias( n_routed_experts = gating_output.shape[-1] scores = gating_output.softmax(dim=-1) scores_for_choice = scores.view( - -1, n_routed_experts) + e_score_correction_bias.unsqueeze(0) - topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, - sorted=False)[1] + -1, n_routed_experts + ) + e_score_correction_bias.unsqueeze(0) + topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1] topk_weights = scores.gather(1, topk_indices) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -1039,10 +1134,13 @@ def grouped_topk( routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK and \ - current_platform.is_cuda() and \ - num_expert_group <= 32 and topk <= 32 and \ - e_score_correction_bias is not None: + if ( + envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK + and current_platform.is_cuda() + and num_expert_group <= 32 + and topk <= 32 + and e_score_correction_bias is not None + ): return fused_grouped_topk( hidden_states=hidden_states, gating_output=gating_output, @@ -1052,10 +1150,10 @@ def grouped_topk( num_expert_group=num_expert_group, topk_group=topk_group, scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor) + routed_scaling_factor=routed_scaling_factor, + ) - assert hidden_states.size(0) == gating_output.size(0), ( - "Number of tokens mismatch") + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) @@ -1070,30 +1168,31 @@ def grouped_topk( # scores for expert selection but original scores for routing weights original_scores = scores scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = (scores.view(num_token, num_expert_group, - -1).topk(2, dim=-1)[0].sum(dim=-1)) + group_scores = ( + scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) + ) else: - group_scores = scores.view(num_token, num_expert_group, - -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, - sorted=False)[1] # [n, top_k_group] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - scores.size(-1) // num_expert_group).reshape(num_token, -1) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), - float("-inf")) # [n, e] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] if e_score_correction_bias is not None: topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] # Use original unbiased scores for the routing weights topk_weights = 
original_scores.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk(tmp_scores, - k=topk, - dim=-1, - sorted=False) + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -1105,12 +1204,13 @@ def grouped_topk( @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def eplb_map_to_physical_and_record( - topk_ids: torch.Tensor, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - indices_type: Optional[torch.dtype] = None) -> torch.Tensor: - ''' + topk_ids: torch.Tensor, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + indices_type: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ Map the logical expert ids to physical expert ids and record the expert load metrics. @@ -1126,7 +1226,7 @@ def eplb_map_to_physical_and_record( Returns: The physical expert ids. - ''' + """ # 1. Convert the logical expert ids to physical expert ids # Directly select a random replica for each logical expert @@ -1138,13 +1238,14 @@ def eplb_map_to_physical_and_record( # to deterministically choose a replica replica_count = logical_replica_count[topk_ids_long] # Flatten-position based index, reshaped back to `topk_ids` shape - pos_indices = torch.arange(topk_ids.numel(), - device=topk_ids.device, - dtype=torch.long).reshape_as(topk_ids) + pos_indices = torch.arange( + topk_ids.numel(), device=topk_ids.device, dtype=torch.long + ).reshape_as(topk_ids) # Compute pseudo-random indices by modulo replica_indices = (pos_indices % replica_count).unsqueeze(-1) - physical_ids = logical_to_physical_map[topk_ids_long].gather( - -1, replica_indices).squeeze(-1) + physical_ids = ( + logical_to_physical_map[topk_ids_long].gather(-1, replica_indices).squeeze(-1) + ) topk_ids = physical_ids @@ -1169,7 +1270,8 @@ def eplb_map_to_physical_and_record( expert_load_view.scatter_add_( dim=0, index=topk_ids_flatten.long(), - src=torch.ones_like(topk_ids_flatten).to(expert_load_view)) + src=torch.ones_like(topk_ids_flatten).to(expert_load_view), + ) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) @@ -1187,8 +1289,7 @@ def fused_grouped_topk( scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: - assert hidden_states.size(0) == gating_output.size(0), ( - "Number of tokens mismatch") + assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" if scoring_func == "softmax": scores = torch.softmax(gating_output, dim=-1) @@ -1199,8 +1300,14 @@ def fused_grouped_topk( scores_with_bias = scores + e_score_correction_bias.unsqueeze(0) topk_values, topk_indices = ops.grouped_topk( - scores, scores_with_bias.to(scores.dtype), num_expert_group, - topk_group, topk, renormalize, routed_scaling_factor) + scores, + scores_with_bias.to(scores.dtype), + num_expert_group, + topk_group, + topk, + renormalize, + routed_scaling_factor, + ) return topk_values.to(torch.float32), topk_indices.to(torch.int32) @@ -1230,12 +1337,33 @@ def inplace_fused_experts( w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> None: - fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - activation, apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, - use_mxfp4_w4a4, per_channel_quant, global_num_experts, - 
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, - a2_scale, block_shape, w1_bias, w2_bias) + fused_experts_impl( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + True, + activation, + apply_router_weight_on_input, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + use_mxfp4_w4a4, + per_channel_quant, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + w1_bias, + w2_bias, + ) def inplace_fused_experts_fake( @@ -1272,8 +1400,11 @@ def inplace_fused_experts_fake( op_func=inplace_fused_experts, mutates_args=["hidden_states"], fake_impl=inplace_fused_experts_fake, - tags=(() if is_torch_equal_or_newer("2.7.0") else - (torch.Tag.needs_fixed_stride_order, )), + tags=( + () + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ), ) @@ -1304,11 +1435,32 @@ def outplace_fused_experts( w2_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return fused_experts_impl( - hidden_states, w1, w2, topk_weights, topk_ids, False, activation, - apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, - use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, - global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, - a1_scale, a2_scale, block_shape, w1_bias, w2_bias) + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + False, + activation, + apply_router_weight_on_input, + use_fp8_w8a8, + use_int8_w8a8, + use_int8_w8a16, + use_int4_w4a16, + use_mxfp4_w4a4, + per_channel_quant, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1_scale, + a2_scale, + block_shape, + w1_bias, + w2_bias, + ) def outplace_fused_experts_fake( @@ -1343,14 +1495,17 @@ def outplace_fused_experts_fake( op_name="outplace_fused_experts", op_func=outplace_fused_experts, fake_impl=outplace_fused_experts_fake, - tags=(() if is_torch_equal_or_newer("2.7.0") else - (torch.Tag.needs_fixed_stride_order, )), + tags=( + () + if is_torch_equal_or_newer("2.7.0") + else (torch.Tag.needs_fixed_stride_order,) + ), ) def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor: torch.ops.vllm.inplace_fused_experts(**kwargs) - hidden_states = kwargs['hidden_states'] + hidden_states = kwargs["hidden_states"] return hidden_states @@ -1381,7 +1536,6 @@ def fused_experts( allow_deep_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False, ) -> torch.Tensor: - if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG use_fp8_w8a8 = quant_config.use_fp8_w8a8 @@ -1392,8 +1546,11 @@ def fused_experts( # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. 
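# Illustrative sketch: the branch that follows chooses between three implementations. Below
# is a condensed, dependency-free restatement of that dispatch; the boolean flags stand in
# for is_deep_gemm_e8m0_used() / _valid_deep_gemm() and
# _valid_cutlass_block_scaled_grouped_gemm(), which encode the real shape/platform checks.
def select_moe_backend_sketch(use_fp8_w8a8: bool,
                              allow_deep_gemm: bool,
                              deep_gemm_applicable: bool,
                              allow_cutlass_grouped_gemm: bool,
                              cutlass_applicable: bool) -> str:
    if allow_deep_gemm and use_fp8_w8a8 and deep_gemm_applicable:
        return "deep_gemm_moe_fp8"
    if allow_cutlass_grouped_gemm and use_fp8_w8a8 and cutlass_applicable:
        return "cutlass_block_scaled_grouped_gemm"
    return "triton_fused_experts"

# e.g. select_moe_backend_sketch(True, True, True, False, False) -> "deep_gemm_moe_fp8"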
- if (allow_deep_gemm and quant_config.use_fp8_w8a8 and - (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): + if ( + allow_deep_gemm + and quant_config.use_fp8_w8a8 + and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)) + ): assert quant_config is not None assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( @@ -1412,10 +1569,13 @@ def fused_experts( a2_scale=quant_config.a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm( - w1, w2, inplace, activation, apply_router_weight_on_input, - expert_map)): + elif ( + allow_cutlass_block_scaled_grouped_gemm + and use_fp8_w8a8 + and _valid_cutlass_block_scaled_grouped_gemm( + w1, w2, inplace, activation, apply_router_weight_on_input, expert_map + ) + ): assert quant_config is not None return run_cutlass_block_scaled_fused_experts( a=hidden_states, @@ -1424,7 +1584,8 @@ def fused_experts( w1_scale=quant_config.w1_scale, w2_scale=quant_config.w2_scale, topk_weights=topk_weights, - topk_ids=topk_ids) + topk_ids=topk_ids, + ) else: return dispatch_fused_experts_func(inplace)( hidden_states=hidden_states, @@ -1450,7 +1611,8 @@ def fused_experts( a2_scale=quant_config.a2_scale, block_shape=quant_config.block_shape, w1_bias=quant_config.w1_bias, - w2_bias=quant_config.w2_bias) + w2_bias=quant_config.w2_bias, + ) SILU_NO_MUL: str = activation_without_mul("silu") @@ -1507,22 +1669,20 @@ def fused_experts_impl( ) -> torch.Tensor: # Check constraints. if use_int4_w4a16: - assert hidden_states.size(1) // 2 == w1.size(2), ( - "Hidden size mismatch") + assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" elif use_mxfp4_w4a4: # 16bit activation and fp4x2 packed weight assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch" else: assert hidden_states.size(1) == w1.size(2), ( - f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") + f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}" + ) assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] num_tokens = hidden_states.size(0) E, N, _ = w1.size() @@ -1535,17 +1695,21 @@ def fused_experts_impl( CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE M = min(num_tokens, CHUNK_SIZE) - config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, - dtype=hidden_states.dtype) + config_dtype = _get_config_dtype_str( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, + dtype=hidden_states.dtype, + ) # Note: for use_int8_w8a16 or use_int4_w4a16, the activations are # quantized prior to calling fused_experts. 
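# Illustrative sketch: fused_experts_impl() sizes its workspaces for at most CHUNK_SIZE
# tokens and walks the token dimension in slices (the chunk loop appears just below). The
# sketch shows that pattern generically; fn and chunk_size are placeholders.
import torch

def chunked_apply_sketch(x: torch.Tensor, fn, chunk_size: int) -> torch.Tensor:
    out = torch.empty_like(x)
    for start in range(0, x.size(0), chunk_size):
        end = min(start + chunk_size, x.size(0))
        # workspaces sized for `chunk_size` rows get reused for every slice
        out[start:end] = fn(x[start:end])
    return out

# e.g. chunked_apply_sketch(torch.randn(10, 3), lambda t: 2 * t, chunk_size=4)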
- quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_mxfp4_w4a4=use_mxfp4_w4a4) + quant_dtype = _get_config_quant_dtype( + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_mxfp4_w4a4=use_mxfp4_w4a4, + ) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -1560,16 +1724,18 @@ def fused_experts_impl( # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - cache13 = torch.empty(M * top_k_num * max(N, K), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) - intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K) + cache13 = torch.empty( + M * top_k_num * max(N, K), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N) + intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 - intermediate_cache2 = torch.empty((M * top_k_num, N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty( + (M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype + ) if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 @@ -1580,10 +1746,7 @@ def fused_experts_impl( else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") - if inplace: - out_hidden_states = hidden_states - else: - out_hidden_states = torch.empty_like(hidden_states) + out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states) if use_mxfp4_w4a4: # Weight has to be dequantized for mxfp4 emulation. @@ -1593,9 +1756,10 @@ def fused_experts_impl( w2_scale = None for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.size() @@ -1608,8 +1772,9 @@ def fused_experts_impl( # so the cache size and config are already set correctly and # do not need to be adjusted. 
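The `cache13` allocation above is a deliberate aliasing trick: one flat buffer of `M * top_k * max(N, K)` elements is viewed as `intermediate_cache1` for the first GEMM and later as `intermediate_cache3` for the second, because cache1 is no longer needed once cache3 is written. A small self-contained illustration with toy sizes:

import torch

M, top_k, N, K = 4, 2, 8, 6                       # hypothetical sizes
flat = torch.empty(M * top_k * max(N, K))
cache1 = flat[: M * top_k * N].view(M, top_k, N)  # output of the first GEMM
cache3 = flat[: M * top_k * K].view(M, top_k, K)  # reused for the second GEMM
assert cache1.data_ptr() == cache3.data_ptr()     # same storage, two shapes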
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.size(1)] + intermediate_cache2 = intermediate_cache2[ + : tokens_in_chunk * topk_ids.size(1) + ] intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) @@ -1620,45 +1785,51 @@ def fused_experts_impl( A_scale=a1_scale, quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, - block_shape=block_shape) - - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) - - invoke_fused_moe_kernel(qcurr_hidden_states, - w1, - intermediate_cache1, - a1q_scale, - w1_scale, - w1_zp, - curr_topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - apply_router_weight_on_input, - top_k_num, - config, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - block_shape=block_shape, - B_bias=w1_bias) + block_shape=block_shape, + ) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map + ) + + invoke_fused_moe_kernel( + qcurr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + apply_router_weight_on_input, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + B_bias=w1_bias, + ) # Activation function with multiplication if activation == "silu": - torch.ops._C.silu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) + torch.ops._C.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) elif activation == "gelu": - torch.ops._C.gelu_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) + torch.ops._C.gelu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) elif activation == "swigluoai": # alpha = 1.702, limit = 7.0 - torch.ops._C.swigluoai_and_mul(intermediate_cache2, - intermediate_cache1.view(-1, N)) + torch.ops._C.swigluoai_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) # Activation function without multiplication elif activation == SILU_NO_MUL: intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N)) @@ -1673,38 +1844,42 @@ def fused_experts_impl( A_scale=a2_scale, quant_dtype=quant_dtype, per_act_token_quant=per_channel_quant, - block_shape=block_shape) - - invoke_fused_moe_kernel(qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - curr_topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - not apply_router_weight_on_input, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - block_shape=block_shape, - B_bias=w2_bias) - - ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), - out_hidden_states[begin_chunk_idx:end_chunk_idx]) + block_shape=block_shape, + ) + + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + 
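For reference, the fused `silu_and_mul` step above consumes the `[gate | up]` output of the first GEMM (width N) and produces a tensor of width N // 2. The eager-mode equivalent is the usual gated activation; this mirrors the semantics of `torch.ops._C.silu_and_mul`, not the CUDA kernel itself:

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (..., N) laid out as [gate | up]; returns (..., N // 2).
    n = x.shape[-1] // 2
    gate, up = x[..., :n], x[..., n:]
    return F.silu(gate) * up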
curr_topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + block_shape=block_shape, + B_bias=w2_bias, + ) + + ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.size()), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) return out_hidden_states class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, quant_config: FusedMoEQuantConfig, @@ -1713,10 +1888,12 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True @@ -1764,24 +1941,26 @@ def apply( ): # Check constraints. if self.quant_config.use_int4_w4a16: - assert hidden_states.size(-1) // 2 == w1.size(2), ( - "Hidden size mismatch") + assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" else: - assert hidden_states.size(-1) == w1.size(2), \ - (f"Hidden size mismatch {hidden_states.size(-1)} " - f"!= {w1.size(2)}") + assert hidden_states.size(-1) == w1.size(2), ( + f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}" + ) - assert hidden_states.is_contiguous( - ), "Hidden_states must be contiguous" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.dim() == 2 assert w1.stride(-1) == 1, "Stride of last dimension must be 1" assert w2.stride(-1) == 1, "Stride of last dimension must be 1" assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn + torch.float32, + torch.float16, + torch.bfloat16, + torch.float8_e4m3fn, ] E, num_tokens, N, K, top_k_num = self.moe_problem_size( - hidden_states, w1, w2, topk_ids) + hidden_states, w1, w2, topk_ids + ) if global_num_experts == -1: global_num_experts = E @@ -1804,20 +1983,18 @@ def apply( elif hidden_states.dtype == torch.float8_e4m3fn: compute_type = tl.bfloat16 else: - raise ValueError( - f"Unsupported compute_type: {hidden_states.dtype}") + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") # Note that the output tensor might be in workspace1 - intermediate_cache1 = _resize_cache(workspace2, - (num_tokens, top_k_num, N)) - intermediate_cache2 = _resize_cache(workspace13, - (num_tokens * top_k_num, N // 2)) - intermediate_cache3 = _resize_cache(workspace2, - (num_tokens, top_k_num, K)) + intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) + intermediate_cache2 = _resize_cache( + workspace13, (num_tokens * top_k_num, N // 2) + ) + intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map + ) invoke_fused_moe_kernel( hidden_states, @@ -1843,14 +2020,19 @@ def apply( B_bias=self.w1_bias, ) - self.activation(activation, intermediate_cache2, - intermediate_cache1.view(-1, N)) + self.activation( + activation, intermediate_cache2, 
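The final `ops.moe_sum` call above collapses the per-expert results in `intermediate_cache3` (one row per selected expert per token) into the per-token output row; the routing weights are applied inside one of the two GEMM launches, so the reduction itself is a plain sum over the top-k dimension. An eager reference of that semantics (not the custom op):

import torch

def moe_sum_reference(cache3: torch.Tensor, out: torch.Tensor) -> None:
    # cache3: (tokens_in_chunk, top_k, hidden), out: (tokens_in_chunk, hidden)
    out.copy_(cache3.sum(dim=1))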
intermediate_cache1.view(-1, N) + ) a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, a2_scale, self.quant_dtype, - self.per_act_token_quant, self.block_shape) + intermediate_cache2, + a2_scale, + self.quant_dtype, + self.per_act_token_quant, + self.block_shape, + ) invoke_fused_moe_kernel( qintermediate_cache2, @@ -1880,7 +2062,8 @@ def apply( def modular_triton_fused_moe( - quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel: + quant_config: FusedMoEQuantConfig, +) -> mk.FusedMoEModularKernel: return mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP(), TritonExperts(quant_config), diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 18de75851934..39faeed5d10f 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -7,9 +7,12 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) + TopKWeightAndReduceNoOP, +) from vllm.triton_utils import tl, triton from vllm.utils import has_triton_kernels @@ -18,24 +21,24 @@ if has_triton_kernels(): try: import triton_kernels.swiglu - from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation, - matmul_ogs) - from triton_kernels.routing import (RoutingData, routing, - routing_from_bitmatrix) + from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs + from triton_kernels.routing import RoutingData, routing, routing_from_bitmatrix from triton_kernels.tensor import Bitmatrix except (AttributeError, ImportError) as e: logger.error( "Failed to import Triton kernels. Please make sure your triton " - "version is compatible. Error: %s", e) + "version is compatible. Error: %s", + e, + ) @triton.jit def pack_bitmatrix( bitmatrix, topk_ids, - n_rows, # n_rows in bitmatrix / topk_ids + n_rows, # n_rows in bitmatrix / topk_ids bm_cols: tl.constexpr, # n int32_t bitpacks in bitmatrix - n_expts_act, # num_topk + n_expts_act, # num_topk BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): @@ -60,12 +63,12 @@ def pack_bitmatrix( offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32) # All topks that need to go into this column has the correct bit set. # Other bits are 0. x is a 2D tensor. - x = tl.where(div[:, :, None] == offs[None, None, :], - (one << rem)[:, :, None], 0) + x = tl.where( + div[:, :, None] == offs[None, None, :], (one << rem)[:, :, None], 0 + ) # Reduce x to get a single int32_t bitpack. 
y = tl.reduce_or(x, axis=1) - bitmatrix_ptrs = bitmatrix + offsets_m[:, - None] * bm_cols + offs[None, :] + bitmatrix_ptrs = bitmatrix + offsets_m[:, None] * bm_cols + offs[None, :] tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) @@ -82,10 +85,9 @@ def triton_kernel_moe_forward( global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - routing_data, gather_idx, scatter_idx = routing(gating_output, - topk, - sm_first=not renormalize) + routing_data, gather_idx, scatter_idx = routing( + gating_output, topk, sm_first=not renormalize + ) return triton_kernel_fused_experts( None, @@ -99,7 +101,8 @@ def triton_kernel_moe_forward( quant_config=quant_config, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - expert_map=expert_map) + expert_map=expert_map, + ) # This is a triton implementation of the fused_experts function @@ -125,10 +128,8 @@ def triton_kernel_fused_experts( # type check, uint8 means mxfp4 assert hidden_states.dtype == torch.bfloat16 - assert (quant_config.w1_bias is None - or quant_config.w1_bias.dtype == torch.float32) - assert (quant_config.w2_bias is None - or quant_config.w2_bias.dtype == torch.float32) + assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32 + assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 # Shape check, only check non-mxfp4 assert hidden_states.shape[-1] == w1.shape[-2] @@ -141,7 +142,9 @@ def triton_kernel_fused_experts( act = FusedActivation( FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), - (swiglu_alpha, swiglu_limit), 2) + (swiglu_alpha, swiglu_limit), + 2, + ) gammas = routing_data.gate_scal if routing_data else None intermediate_cache1 = matmul_ogs( @@ -152,7 +155,8 @@ def triton_kernel_fused_experts( gather_indx=gather_indx, precision_config=quant_config.w1_precision, gammas=gammas if apply_router_weight_on_input else None, - fused_activation=act) + fused_activation=act, + ) intermediate_cache3 = matmul_ogs( intermediate_cache1, @@ -172,7 +176,6 @@ def make_routing_data( topk_weights: torch.Tensor, num_local_experts: int, ) -> tuple["RoutingData", torch.Tensor, torch.Tensor]: - topk_ids = topk_ids.to(torch.int16) topk_weights = topk_weights.to(torch.bfloat16) @@ -182,11 +185,11 @@ def make_routing_data( BLOCK_SIZE_K = 32 bm_cols = triton.cdiv(num_local_experts, BLOCK_SIZE_K) # n_bitpacks - bitmatrix = torch.zeros((n_rows, bm_cols), - dtype=torch.uint32, - device=topk_ids.device) + bitmatrix = torch.zeros( + (n_rows, bm_cols), dtype=torch.uint32, device=topk_ids.device + ) - grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), ) + grid = (triton.cdiv(n_rows, BLOCK_SIZE_M),) pack_bitmatrix[grid]( bitmatrix, topk_ids, @@ -199,21 +202,20 @@ def make_routing_data( bitmatrix_shape = [n_rows, bm_cols * 32] bitmatrix_shape_max = [n_rows, None] - bitmatrix = Bitmatrix(bitmatrix, - shape=bitmatrix_shape, - shape_max=bitmatrix_shape_max, - scratchpad=None) + bitmatrix = Bitmatrix( + bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=None + ) # matmul_ogs expects invalid topk_weights to be -1s topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights) routing_data, gather_indx, scatter_indx = routing_from_bitmatrix( - bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk) + bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk + ) return routing_data, gather_indx, scatter_indx class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): 
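`pack_bitmatrix` and `make_routing_data` above encode each token's selected experts as a row of 32-bit bitpacks: bit `expert % 32` of word `expert // 32` is set, invalid (-1) entries are skipped, and the result feeds `routing_from_bitmatrix`. A pure-PyTorch reference of that bit layout (int64 is used here for simplicity where the kernel writes uint32 words):

import torch

def pack_bitmatrix_reference(topk_ids: torch.Tensor, num_local_experts: int) -> torch.Tensor:
    n_rows = topk_ids.shape[0]
    bm_cols = (num_local_experts + 31) // 32          # triton.cdiv(num_local_experts, 32)
    bitmatrix = torch.zeros((n_rows, bm_cols), dtype=torch.int64)
    for row in range(n_rows):
        for expert in topk_ids[row].tolist():
            if expert < 0:                            # -1 marks an invalid/unrouted slot
                continue
            bitmatrix[row, expert // 32] |= 1 << (expert % 32)
    return bitmatrix

# e.g. topk_ids = [[1, 34]] over 64 experts -> one row of words [0b10, 0b100] = [2, 4]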
- def __init__(self, quant_config: FusedMoEQuantConfig): super().__init__(quant_config) @@ -234,7 +236,6 @@ def _make_routing_data( class OAITritonExperts(BaseOAITritonExperts): - def __init__(self, quant_config: FusedMoEQuantConfig): # TODO (varun) : Enable activation quantization assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16" @@ -242,18 +243,27 @@ def __init__(self, quant_config: FusedMoEQuantConfig): @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, - mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True def workspace_shapes( - self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, - topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata] + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # workspace are allocated inside the kernel workspace1 = (M, K) @@ -287,7 +297,8 @@ def apply( global_num_experts = local_num_experts routing_data, gather_indx, scatter_indx = self._make_routing_data( - topk_ids, topk_weights, local_num_experts) + topk_ids, topk_weights, local_num_experts + ) experts_output = triton_kernel_fused_experts( None, @@ -302,6 +313,7 @@ def apply( apply_router_weight_on_input=False, global_num_experts=local_num_experts, expert_map=None, # applied already - a1q_scale=a1q_scale) + a1q_scale=a1q_scale, + ) output.copy_(experts_output, non_blocking=True) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3b5ef78b37b0..767f9cd46a93 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -14,58 +14,72 @@ import vllm.envs as envs from vllm.config import get_current_vllm_config from vllm.config.parallel import ExpertPlacementStrategy -from vllm.distributed import (get_dp_group, get_ep_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_dp_group, + get_ep_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.distributed.eplb.eplb_state import EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp -# yapf: disable from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig, - FusedMoEQuantConfig, biased_moe_quant_config) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - zero_experts_compute_triton) -# yapf: enable + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + biased_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.fused_moe import zero_experts_compute_triton from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEActivationFormat, FusedMoEModularKernel, - FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) + FusedMoEActivationFormat, + FusedMoEModularKernel, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, +) 
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) -from vllm.model_executor.layers.fused_moe.routing_simulator import ( - RoutingSimulator) + is_rocm_aiter_moe_enabled, +) +from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx, - round_up) +from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts - from .fused_moe import (TritonExperts, eplb_map_to_physical_and_record, - fused_experts) + from .fused_moe import TritonExperts, eplb_map_to_physical_and_record, fused_experts + if has_pplx(): - from .pplx_prepare_finalize import (PplxPrepareAndFinalize, - pplx_hidden_dim_scale_bytes) + from .pplx_prepare_finalize import ( + PplxPrepareAndFinalize, + pplx_hidden_dim_scale_bytes, + ) if has_deep_ep(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize - from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, - DeepEPLLPrepareAndFinalize) + from .deepep_ll_prepare_finalize import ( + DEEPEP_QUANT_BLOCK_SHAPE, + DeepEPLLPrepareAndFinalize, + ) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore FusedMoEPrepareAndFinalize = None # type: ignore def _eplb_map_to_physical_and_record( - topk_ids: torch.Tensor, expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - indices_type: Optional[torch.dtype]) -> torch.Tensor: + topk_ids: torch.Tensor, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + indices_type: Optional[torch.dtype], + ) -> torch.Tensor: # CPU fallback: no EPLB so just return as is return topk_ids @@ -73,7 +87,8 @@ def _eplb_map_to_physical_and_record( if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk as grouped_topk) + rocm_aiter_grouped_topk as grouped_topk, + ) else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): @@ -92,7 +107,6 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): - def __init__(self, moe: FusedMoEConfig): super().__init__() self.moe = moe @@ -101,9 +115,15 @@ def __init__(self, moe: FusedMoEConfig): self.topk_indices_dtype = None @abstractmethod - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): raise NotImplementedError def uses_weight_scale_2_pattern(self) -> bool: @@ -127,8 +147,7 @@ def _maybe_make_prepare_finalize( prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None # 
TODO: could allow this now - assert not moe.use_flashinfer_cutlass_kernels, \ - "Must be created in modelopt.py" + assert not moe.use_flashinfer_cutlass_kernels, "Must be created in modelopt.py" if moe.use_pplx_kernels: assert quant_config is not None @@ -155,13 +174,13 @@ def _maybe_make_prepare_finalize( hidden_dim_scale_bytes=hidden_scale_bytes, ) - num_dispatchers = (all2all_manager.world_size // - all2all_manager.tp_group.world_size) + num_dispatchers = ( + all2all_manager.world_size // all2all_manager.tp_group.world_size + ) # Intranode pplx a2a takes a group name while internode does not. if not all2all_manager.internode: - all_to_all_args[ - "group_name"] = all2all_manager.cpu_group.group_name + all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name handle = all2all_manager.get_handle(all_to_all_args) @@ -180,8 +199,7 @@ def _maybe_make_prepare_finalize( handle, num_dispatchers=all2all_manager.world_size, dp_size=all2all_manager.dp_world_size, - rank_expert_offset=all2all_manager.rank * - moe.num_local_experts, + rank_expert_offset=all2all_manager.rank * moe.num_local_experts, ) elif moe.use_deepep_ll_kernels: @@ -191,15 +209,16 @@ def _maybe_make_prepare_finalize( token_hidden_size=moe.hidden_dim, num_ep_ranks=all2all_manager.world_size, num_global_experts=moe.num_experts, - num_local_experts=moe.num_experts // - all2all_manager.world_size) + num_local_experts=moe.num_experts // all2all_manager.world_size, + ) handle = all2all_manager.get_handle(all_to_all_args) # Note: We may want to use FP8 dispatch just to reduce # data movement. use_fp8_dispatch = ( quant_config.quant_dtype == current_platform.fp8_dtype() - and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE) + and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE + ) prepare_finalize = DeepEPLLPrepareAndFinalize( handle, @@ -210,11 +229,11 @@ def _maybe_make_prepare_finalize( return prepare_finalize - def maybe_make_prepare_finalize( - self) -> Optional[FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]: if self.moe.moe_parallel_config.use_all2all_kernels: return FusedMoEMethodBase._maybe_make_prepare_finalize( - self.moe, self.moe_quant_config) + self.moe, self.moe_quant_config + ) else: return None @@ -231,11 +250,13 @@ def init_prepare_finalize(self, layer: torch.nn.Module): prepare_finalize = self.maybe_make_prepare_finalize() if prepare_finalize is not None: - logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__, - self, id(self)) + logger.debug( + "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) + ) assert self.topk_indices_dtype is None - assert self.fused_experts is None, \ + assert self.fused_experts is None, ( f"Attempt to override experts for {id(self)}!" 
+ ) self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize, layer) self.fused_experts = FusedMoEModularKernel( @@ -253,11 +274,13 @@ def select_gemm_impl( # gemm implementation raise NotImplementedError( f"{self.__class__.__name__} must select appropriate gemm " - "implementation based on the prepare_finalize") + "implementation based on the prepare_finalize" + ) @abstractmethod def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: raise NotImplementedError @abstractmethod @@ -296,6 +319,7 @@ def __init__(self, moe: FusedMoEConfig): self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() if self.rocm_aiter_moe_enabled: from .rocm_aiter_fused_moe import rocm_aiter_fused_experts + self.rocm_aiter_fused_experts = rocm_aiter_fused_experts else: self.rocm_aiter_fused_experts = None # type: ignore @@ -306,7 +330,8 @@ def __init__(self, moe: FusedMoEConfig): and envs.VLLM_USE_FLASHINFER_MOE_FP16 and self.moe.moe_parallel_config.use_ep and self.moe.moe_parallel_config.dp_size == 1 - and current_platform.get_device_capability()[0] >= 9) + and current_platform.get_device_capability()[0] >= 9 + ) if self.flashinfer_cutlass_moe_enabled: logger.info_once( "Enabling FlashInfer CUTLASS MoE for UnquantizedFusedMoEMethod" @@ -314,28 +339,32 @@ def __init__(self, moe: FusedMoEConfig): from functools import partial from .flashinfer_cutlass_moe import flashinfer_cutlass_moe + self.flashinfer_cutlass_moe = partial( flashinfer_cutlass_moe, quant_config=FUSED_MOE_UNQUANTIZED_CONFIG, tp_rank=self.moe.moe_parallel_config.tp_rank, tp_size=self.moe.moe_parallel_config.tp_size, ep_rank=self.moe.moe_parallel_config.ep_rank, - ep_size=self.moe.moe_parallel_config.ep_size) + ep_size=self.moe.moe_parallel_config.ep_size, + ) else: - if (self.moe.moe_parallel_config.use_ep - and self.moe.moe_parallel_config.dp_size == 1): + if ( + self.moe.moe_parallel_config.use_ep + and self.moe.moe_parallel_config.dp_size == 1 + ): logger.info_once( "FlashInfer CUTLASS MoE is available for EP" " but not enabled, consider setting" - " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.") + " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it." + ) elif self.moe.moe_parallel_config.dp_size > 1: logger.info_once( "FlashInfer CUTLASS MoE is currently not available for DP." 
) self.flashinfer_cutlass_moe = None # type: ignore - def maybe_make_prepare_finalize( - self) -> Optional[FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> Optional[FusedMoEPrepareAndFinalize]: if self.rocm_aiter_moe_enabled: return None else: @@ -347,8 +376,10 @@ def select_gemm_impl( layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None - if (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts): + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): logger.debug("BatchedTritonExperts %s", self.moe) return BatchedTritonExperts( max_num_tokens=self.moe.max_num_tokens, @@ -359,49 +390,65 @@ def select_gemm_impl( logger.debug("TritonExperts %s", self.moe) return TritonExperts(self.moe_quant_config) - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) if self.moe.has_bias: - w13_bias = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, 2 * intermediate_size_per_partition, dtype=params_dtype + ), + requires_grad=False, + ) layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) if self.moe.has_bias: - w2_bias = torch.nn.Parameter(torch.zeros(num_experts, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_bias", w2_bias) set_weight_attrs(w2_bias, extra_weight_attrs) def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor: # Pad the weight tensor. 
This is an optimization on ROCm platform, which # can benefit from tensors located far enough from one another in memory - if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm() - and weight.stride(-1) == 1 - and (weight.stride(-2) * weight.element_size()) % 512 == 0): + if ( + envs.VLLM_ROCM_MOE_PADDING + and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0 + ): num_pad = 256 // weight.element_size() weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] torch.cuda.empty_cache() @@ -416,11 +463,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) # Lazy import to avoid importing triton. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - shuffle_weights) + shuffle_weights, + ) if self.rocm_aiter_moe_enabled: shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + layer.w13_weight.data, layer.w2_weight.data + ) layer.w13_weight.data = shuffled_w13 layer.w2_weight.data = shuffled_w2 @@ -433,6 +482,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if current_platform.is_xpu(): import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.w13_weight, layer.w2_weight, @@ -440,23 +490,28 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) elif current_platform.is_cpu(): from vllm.model_executor.layers.fused_moe import cpu_fused_moe + if current_platform.get_cpu_architecture() == CpuArchEnum.X86: - from vllm.model_executor.layers.utils import ( - check_cpu_sgl_kernel) + from vllm.model_executor.layers.utils import check_cpu_sgl_kernel + dtype_w13 = layer.w13_weight.dtype _, n_w13, k_w13 = layer.w13_weight.size() dtype_w2 = layer.w2_weight.dtype _, n_w2, k_w2 = layer.w2_weight.size() - if (envs.VLLM_CPU_SGL_KERNEL - and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) - and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2)): + if ( + envs.VLLM_CPU_SGL_KERNEL + and check_cpu_sgl_kernel(n_w13, k_w13, dtype_w13) + and check_cpu_sgl_kernel(n_w2, k_w2, dtype_w2) + ): packed_w13_weight = torch.ops._C.convert_weight_packed( - layer.w13_weight) + layer.w13_weight + ) assert packed_w13_weight.size() == layer.w13_weight.size() layer.w13_weight.copy_(packed_w13_weight) del packed_w13_weight packed_w2_weight = torch.ops._C.convert_weight_packed( - layer.w2_weight) + layer.w2_weight + ) assert packed_w2_weight.size() == layer.w2_weight.size() layer.w2_weight.copy_(packed_w2_weight) layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer) @@ -518,7 +573,8 @@ def apply( ) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: if self.moe.has_bias: return biased_moe_quant_config( layer.w13_bias, @@ -550,9 +606,8 @@ def forward_cuda( logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - - zero_expert_num = getattr(layer, 'zero_expert_num', 0) - zero_expert_type = getattr(layer, 'zero_expert_type', None) + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts( hidden_states=x, @@ -574,7 +629,8 @@ def forward_cuda( 
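The pad-then-slice idiom in `_maybe_pad_weight` above keeps the logical weight shape but leaves `num_pad` unused elements at the end of every row of the underlying allocation, which is the memory spacing the ROCm comment refers to. A small demonstration of the effect (sizes are arbitrary; the real code only applies this when the stride/alignment check passes):

import torch
import torch.nn.functional as F

w = torch.randn(4, 256, dtype=torch.bfloat16)           # one row = 512 bytes
num_pad = 256 // w.element_size()                        # 128 elements for bf16
padded = F.pad(w, (0, num_pad), "constant", 0)[..., :-num_pad]

assert padded.shape == w.shape                           # logical shape unchanged
assert padded.stride(0) == w.shape[1] + num_pad          # rows now start 384 elements apart
assert torch.equal(padded, w)                            # values untouched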
logical_replica_count=logical_replica_count, global_num_experts=global_num_experts, zero_expert_num=zero_expert_num, - zero_expert_type=zero_expert_type) + zero_expert_type=zero_expert_type, + ) if self.rocm_aiter_moe_enabled: assert self.fused_experts is None @@ -586,7 +642,8 @@ def forward_cuda( topk_ids=topk_ids, expert_map=expert_map, activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + ) elif self.flashinfer_cutlass_moe_enabled: return self.flashinfer_cutlass_moe( hidden_states=x, @@ -595,11 +652,11 @@ def forward_cuda( topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + ) elif self.fused_experts is not None: if self.moe.has_bias: - raise ValueError( - "FusedMoEModularKernel does not support bias.") + raise ValueError("FusedMoEModularKernel does not support bias.") result = self.fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -629,8 +686,9 @@ def forward_cuda( ) if zero_expert_num != 0 and zero_expert_type is not None: - assert not isinstance(result, tuple), \ + assert not isinstance(result, tuple), ( "Shared + zero experts are mutually exclusive not yet supported" + ) return result, zero_expert_result else: return result @@ -658,11 +716,13 @@ def forward_cpu( logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - if enable_eplb is not False or expert_load_view is not None or \ - logical_to_physical_map is not None or \ - logical_replica_count is not None: - raise NotImplementedError("Expert load balancing is not supported " - "for CPU.") + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for CPU.") return layer.cpu_fused_moe( layer, x, @@ -705,11 +765,13 @@ def forward_xpu( logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - if enable_eplb is not False or expert_load_view is not None or \ - logical_to_physical_map is not None or \ - logical_replica_count is not None: - raise NotImplementedError("Expert load balancing is not supported " - "for XPU.") + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for XPU.") assert custom_routing_function is None return layer.ipex_fusion( x, @@ -751,27 +813,33 @@ def forward_tpu( assert apply_router_weight_on_input is False if scoring_func != "softmax": raise NotImplementedError( - "Only softmax scoring function is supported for TPU.") + "Only softmax scoring function is supported for TPU." + ) if e_score_correction_bias is not None: raise NotImplementedError( - "Expert score correction bias is not supported for TPU.") + "Expert score correction bias is not supported for TPU." + ) assert activation == "silu", f"{activation} is not supported for TPU." - assert routed_scaling_factor == 1.0, \ - f"routed_scaling_factor {routed_scaling_factor} is not supported " \ - f"for TPU." 
- if enable_eplb is not False or expert_load_view is not None or \ - logical_to_physical_map is not None or \ - logical_replica_count is not None: - raise NotImplementedError("Expert load balancing is not supported " - "for TPU.") - return fused_moe_pallas(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk=top_k, - gating_output=router_logits, - global_num_experts=global_num_experts, - expert_map=expert_map, - renormalize=renormalize) + assert routed_scaling_factor == 1.0, ( + f"routed_scaling_factor {routed_scaling_factor} is not supported for TPU." + ) + if ( + enable_eplb is not False + or expert_load_view is not None + or logical_to_physical_map is not None + or logical_replica_count is not None + ): + raise NotImplementedError("Expert load balancing is not supported for TPU.") + return fused_moe_pallas( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk=top_k, + gating_output=router_logits, + global_num_experts=global_num_experts, + expert_map=expert_map, + renormalize=renormalize, + ) if current_platform.is_tpu(): forward_native = forward_tpu @@ -790,27 +858,27 @@ def determine_expert_map( expert_placement_strategy: ExpertPlacementStrategy = "linear", ) -> tuple[int, Optional[torch.Tensor]]: """ - Calculates how many experts should be assigned to each rank for EP and - creates a mapping from global to local expert index. Experts are - distributed evenly across ranks. Any remaining are assigned to the - last rank. - - Args: - ep_size: The size of the expert parallel group - ep_rank: The rank of the current process in the expert parallel - group - global_num_experts: The total number of experts in the model. - expert_placement_strategy: The expert placement strategy. + Calculates how many experts should be assigned to each rank for EP and + creates a mapping from global to local expert index. Experts are + distributed evenly across ranks. Any remaining are assigned to the + last rank. - Returns: - tuple[int, Optional[torch.Tensor]]: A tuple containing: - - local_num_experts (int): The number of experts assigned - to the current rank. - - expert_map (Optional[torch.Tensor]): A tensor of shape - (global_num_experts,) mapping from global to local index. - Contains -1 for experts not assigned to the current rank. - Returns None if ep_size is 1. - """ + Args: + ep_size: The size of the expert parallel group + ep_rank: The rank of the current process in the expert parallel + group + global_num_experts: The total number of experts in the model. + expert_placement_strategy: The expert placement strategy. + + Returns: + tuple[int, Optional[torch.Tensor]]: A tuple containing: + - local_num_experts (int): The number of experts assigned + to the current rank. + - expert_map (Optional[torch.Tensor]): A tensor of shape + (global_num_experts,) mapping from global to local index. + Contains -1 for experts not assigned to the current rank. + Returns None if ep_size is 1. + """ assert ep_size > 0 if ep_size == 1: return (global_num_experts, None) @@ -818,62 +886,64 @@ def determine_expert_map( # Distribute experts as evenly as possible to each rank. 
base_experts = global_num_experts // ep_size remainder = global_num_experts % ep_size - if ep_rank < remainder: - local_num_experts = base_experts + 1 - else: - local_num_experts = base_experts + local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts # Create a tensor of size num_experts filled with -1 - expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) + expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32) # Create an expert map for the local experts if expert_placement_strategy == "linear": start_idx = ep_rank * base_experts + min(ep_rank, remainder) - expert_map[start_idx:start_idx + local_num_experts] = torch.arange( - 0, local_num_experts, dtype=torch.int32) + expert_map[start_idx : start_idx + local_num_experts] = torch.arange( + 0, local_num_experts, dtype=torch.int32 + ) elif expert_placement_strategy == "round_robin": - local_log_experts = torch.arange(ep_rank, - global_num_experts, - ep_size, - dtype=torch.int32) - - expert_map[local_log_experts] = torch.arange(0, - local_num_experts, - dtype=torch.int32) + local_log_experts = torch.arange( + ep_rank, global_num_experts, ep_size, dtype=torch.int32 + ) + + expert_map[local_log_experts] = torch.arange( + 0, local_num_experts, dtype=torch.int32 + ) else: - raise ValueError("Unsupported expert placement strategy " - f"'{expert_placement_strategy}', expected one of " - f"{get_args(ExpertPlacementStrategy)}") + raise ValueError( + "Unsupported expert placement strategy " + f"'{expert_placement_strategy}', expected one of " + f"{get_args(ExpertPlacementStrategy)}" + ) return (local_num_experts, expert_map) def get_compressed_expert_map(expert_map: torch.Tensor) -> str: """ - Compresses the expert map by removing any -1 entries. + Compresses the expert map by removing any -1 entries. - Args: - expert_map (torch.Tensor): A tensor of shape (global_num_experts,) - mapping from global to local index. Contains -1 for experts not - assigned to the current rank. + Args: + expert_map (torch.Tensor): A tensor of shape (global_num_experts,) + mapping from global to local index. Contains -1 for experts not + assigned to the current rank. - Returns: - str: A string mapping from local to global index. - Using str to support hashing for logging once only. - """ + Returns: + str: A string mapping from local to global index. + Using str to support hashing for logging once only. + """ global_indices = torch.where(expert_map != -1)[0] local_indices = expert_map[global_indices] return ", ".join( f"{local_index.item()}->{global_index.item()}" - for local_index, global_index in zip(local_indices, global_indices)) + for local_index, global_index in zip(local_indices, global_indices) + ) def maybe_roundup_hidden_size( - hidden_size: int, act_dtype: torch.dtype, - quant_config: Optional[QuantizationConfig], - moe_parallel_config: FusedMoEParallelConfig) -> int: + hidden_size: int, + act_dtype: torch.dtype, + quant_config: Optional[QuantizationConfig], + moe_parallel_config: FusedMoEParallelConfig, +) -> int: """ Given layer hidden size and MoE configurations, round up hidden_size if necessary. - + Args: hidden_size: Layer hidden-size act_dtype: Data type of the layer activations. @@ -885,24 +955,29 @@ def maybe_roundup_hidden_size( Original hidden size otherwise. 
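A worked example of `determine_expert_map` above with hypothetical sizes (global_num_experts=8, ep_size=3): the first `remainder` ranks receive one extra expert, "linear" placement hands each rank a contiguous block of global experts, and "round_robin" strides through them by ep_size. The helpers below re-derive just the map construction:

import torch

def linear_map(ep_size: int, ep_rank: int, n: int) -> torch.Tensor:
    base, rem = n // ep_size, n % ep_size
    local = base + 1 if ep_rank < rem else base
    m = torch.full((n,), -1, dtype=torch.int32)
    start = ep_rank * base + min(ep_rank, rem)
    m[start:start + local] = torch.arange(local, dtype=torch.int32)
    return m

def round_robin_map(ep_size: int, ep_rank: int, n: int) -> torch.Tensor:
    base, rem = n // ep_size, n % ep_size
    local = base + 1 if ep_rank < rem else base
    m = torch.full((n,), -1, dtype=torch.int32)
    m[torch.arange(ep_rank, n, ep_size)] = torch.arange(local, dtype=torch.int32)
    return m

# ep_rank=1 of ep_size=3 over 8 global experts:
#   linear_map(3, 1, 8)      -> [-1, -1, -1, 0, 1, 2, -1, -1]  (owns globals 3..5)
#   round_robin_map(3, 1, 8) -> [-1, 0, -1, -1, 1, -1, -1, 2]  (owns globals 1, 4, 7)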
""" - if (moe_parallel_config.use_deepep_ht_kernels): - hidden_size = ( - DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size( - hidden_size, act_dtype)) + if moe_parallel_config.use_deepep_ht_kernels: + hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size( + hidden_size, act_dtype + ) # we are padding globally so EP buffer allocation works if quant_config and quant_config.get_name() == "mxfp4": - from vllm.model_executor.layers.quantization.mxfp4 import ( - Mxfp4Backend, get_mxfp4_backend) + Mxfp4Backend, + get_mxfp4_backend, + ) + current_mxfp4_backend = get_mxfp4_backend() - if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 - or current_mxfp4_backend - == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS): + if ( + current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + ): hidden_size = round_up(hidden_size, 128) - elif (current_platform.is_rocm() or current_mxfp4_backend - == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): + elif ( + current_platform.is_rocm() + or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): hidden_size = round_up(hidden_size, 256) return hidden_size @@ -979,19 +1054,19 @@ def __init__( # since model_config is not set in the pytest test. moe_in_dtype = params_dtype - tp_size_ = (tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()) - dp_size_ = (dp_size - if dp_size is not None else get_dp_group().world_size) + tp_size_ = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size self.is_sequence_parallel = is_sequence_parallel self.sp_size = tp_size_ if is_sequence_parallel else 1 - self.moe_parallel_config: FusedMoEParallelConfig = ( - FusedMoEParallelConfig.make( - tp_size_=tp_size_, - dp_size_=dp_size_, - vllm_parallel_config=vllm_config.parallel_config)) + self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_=tp_size_, + dp_size_=dp_size_, + vllm_parallel_config=vllm_config.parallel_config, + ) self.global_num_experts = num_experts + num_redundant_experts self.zero_expert_num = zero_expert_num @@ -1001,9 +1076,9 @@ def __init__( self.expert_mapping = expert_mapping # Round up hidden size if needed. - hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype, - quant_config, - self.moe_parallel_config) + hidden_size = maybe_roundup_hidden_size( + hidden_size, moe_in_dtype, quant_config, self.moe_parallel_config + ) # For smuggling this layer into the fused moe custom op compilation_config = vllm_config.compilation_config @@ -1020,28 +1095,33 @@ def __init__( # Determine expert maps if self.use_ep: if self.enable_eplb: - assert self.global_num_experts % self.ep_size == 0, \ - "EPLB currently only supports even distribution of " \ + assert self.global_num_experts % self.ep_size == 0, ( + "EPLB currently only supports even distribution of " "experts across ranks." + ) else: - assert num_redundant_experts == 0, \ + assert num_redundant_experts == 0, ( "Redundant experts are only supported with EPLB." 
+ ) expert_placement_strategy = ( - vllm_config.parallel_config.expert_placement_strategy) + vllm_config.parallel_config.expert_placement_strategy + ) if expert_placement_strategy == "round_robin": # TODO(Bruce): will support round robin expert placement with # EPLB enabled in the future. - round_robin_supported = ((num_expert_group is not None - and num_expert_group > 1) - and num_redundant_experts == 0 - and not self.enable_eplb) + round_robin_supported = ( + (num_expert_group is not None and num_expert_group > 1) + and num_redundant_experts == 0 + and not self.enable_eplb + ) if not round_robin_supported: logger.warning( "Round-robin expert placement is only supported for " "models with multiple expert groups and no redundant " - "experts. Falling back to linear expert placement.") + "experts. Falling back to linear expert placement." + ) expert_placement_strategy = "linear" self.expert_map: Optional[torch.Tensor] @@ -1057,12 +1137,16 @@ def __init__( "[EP Rank %s/%s] Expert parallelism is enabled. Expert " "placement strategy: %s. Local/global" " number of experts: %s/%s. Experts local to global index map:" - " %s.", self.ep_rank, self.ep_size, expert_placement_strategy, - self.local_num_experts, self.global_num_experts, - get_compressed_expert_map(self.expert_map)) + " %s.", + self.ep_rank, + self.ep_size, + expert_placement_strategy, + self.local_num_experts, + self.global_num_experts, + get_compressed_expert_map(self.expert_map), + ) else: - self.local_num_experts, self.expert_map = (self.global_num_experts, - None) + self.local_num_experts, self.expert_map = (self.global_num_experts, None) self.top_k = top_k @@ -1084,8 +1168,9 @@ def __init__( self.activation = activation if self.scoring_func != "softmax" and not self.use_grouped_topk: - raise ValueError("Only softmax scoring function is supported for " - "non-grouped topk.") + raise ValueError( + "Only softmax scoring function is supported for non-grouped topk." + ) moe = FusedMoEConfig( num_experts=self.global_num_experts, @@ -1104,18 +1189,20 @@ def __init__( # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. quant_method: Optional[QuantizeMethodBase] = None - quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None - else quant_config.get_quant_method(self, prefix)) + quant_method = ( + UnquantizedFusedMoEMethod(moe) + if quant_config is None + else quant_config.get_quant_method(self, prefix) + ) assert quant_method is not None assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method if self.enable_eplb: - from vllm.model_executor.layers.quantization.fp8 import ( - Fp8MoEMethod) - if not isinstance(quant_method, - (Fp8MoEMethod, UnquantizedFusedMoEMethod)): + from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod + + if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)): # TODO: Add support for additional quantization methods. # The implementation for other quantization methods does not # contain essential differences, but the current quant API @@ -1123,22 +1210,23 @@ def __init__( # quantization methods, so I'm leaving it for now. # If you plan to add support for more quantization methods, # please refer to the implementation in `Fp8MoEMethod`. - raise NotImplementedError("EPLB is only supported for FP8 " - "quantization for now.") + raise NotImplementedError( + "EPLB is only supported for FP8 quantization for now." 
+ ) moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, - "intermediate_size_per_partition": - self.intermediate_size_per_partition, + "intermediate_size_per_partition": self.intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, } # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod")): + if self.quant_method.__class__.__name__ in ( + "GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) @@ -1149,31 +1237,37 @@ def __init__( # TODO(bnell): flashinfer uses non-batched format. # Does it really need a batched buffer? - if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels - or self.moe_config.use_flashinfer_cutlass_kernels): + if ( + self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels + or self.moe_config.use_flashinfer_cutlass_kernels + ): if vllm_config.parallel_config.enable_dbo: self.batched_hidden_states = torch.zeros( (2, moe.max_num_tokens, self.hidden_size), dtype=moe.in_dtype, - device=torch.cuda.current_device()) + device=torch.cuda.current_device(), + ) # Note here we use `num_experts` which is logical expert count self.batched_router_logits = torch.zeros( (2, moe.max_num_tokens, num_experts), dtype=moe.in_dtype, - device=torch.cuda.current_device()) + device=torch.cuda.current_device(), + ) else: self.batched_hidden_states = torch.zeros( (moe.max_num_tokens, self.hidden_size), dtype=moe.in_dtype, - device=torch.cuda.current_device()) + device=torch.cuda.current_device(), + ) # Note here we use `num_experts` which is logical expert count self.batched_router_logits = torch.zeros( (moe.max_num_tokens, num_experts), dtype=moe.in_dtype, - device=torch.cuda.current_device()) + device=torch.cuda.current_device(), + ) @property def shared_experts(self) -> Optional[torch.nn.Module]: @@ -1221,9 +1315,11 @@ def use_deepep_ll_kernels(self): @property def use_flashinfer_cutlass_kernels(self): - return (self.moe_quant_config is not None - and self.moe_quant_config.quant_dtype == "nvfp4" - and self.moe_config.use_flashinfer_cutlass_kernels) + return ( + self.moe_quant_config is not None + and self.moe_quant_config.quant_dtype == "nvfp4" + and self.moe_config.use_flashinfer_cutlass_kernels + ) def update_expert_map(self): # ep_size and ep_rank should already be updated @@ -1232,14 +1328,18 @@ def update_expert_map(self): local_num_experts, expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, - global_num_experts=self.global_num_experts) + global_num_experts=self.global_num_experts, + ) self.local_num_experts = local_num_experts self.register_buffer("expert_map", expert_map) - def _load_per_tensor_weight_scale(self, shard_id: str, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - expert_id: int): + def _load_per_tensor_weight_scale( + self, + shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int, + ): param_data = param.data # for per tensor weight quantization if shard_id in ("w1", "w3"): @@ -1251,25 +1351,32 @@ def _load_per_tensor_weight_scale(self, shard_id: str, elif shard_id == "w2": param_data[expert_id] = 
loaded_weight - def _load_combined_w13_weight_scale(self, shard_dim: int, - loaded_weight: torch.Tensor, - param: torch.Tensor, tp_rank: int): + def _load_combined_w13_weight_scale( + self, + shard_dim: int, + loaded_weight: torch.Tensor, + param: torch.Tensor, + tp_rank: int, + ): """ Load w13 weight scales assuming that w1 weight scales and w3 weight scales are stored in the same loaded_weight tensor. """ shard_size = param.shape[shard_dim] - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) param.copy_(loaded_weight) - def _load_model_weight_or_group_weight_scale(self, - shard_dim: int, - expert_data: torch.Tensor, - shard_id: str, - loaded_weight: torch.Tensor, - tp_rank: int, - load_full_w2: bool = False): + def _load_model_weight_or_group_weight_scale( + self, + shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full_w2: bool = False, + ): """ Load grouped weight scales for group quantization or model weights :param shard_dim: dimension to shard @@ -1282,47 +1389,58 @@ def _load_model_weight_or_group_weight_scale(self, if shard_id == "w2": # In the case where we have actorder/g_idx, we do not partition the # w2 scales, as indicated by `load_full` argument, for all tp cases - self._load_w2(shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank, - load_full=load_full_w2) + self._load_w2( + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + load_full=load_full_w2, + ) elif shard_id in ("w1", "w3"): - self._load_w13(shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - - def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, - shard_dim: int, shard_id: str, - loaded_weight: torch.Tensor, - tp_rank: int): + self._load_w13( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + + def _load_per_channel_weight_scale( + self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + ): # for per channel weight quantization if shard_id == "w2": expert_data.copy_(loaded_weight) elif shard_id in ("w1", "w3"): - self._load_w13(shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) - - def _load_w13(self, - expert_data: torch.Tensor, - shard_dim: int, - shard_id: str, - loaded_weight: torch.Tensor, - tp_rank: int, - load_full: bool = False): + self._load_w13( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) + def _load_w13( + self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False, + ): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 if not load_full: - loaded_weight = loaded_weight.narrow(shard_dim, - shard_size * tp_rank, - shard_size) + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. 
if shard_id == "w1": @@ -1333,39 +1451,48 @@ def _load_w13(self, expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data.copy_(loaded_weight) - def _load_w2(self, - expert_data: torch.Tensor, - shard_dim: int, - loaded_weight: torch.Tensor, - tp_rank: int, - load_full: bool = False): - + def _load_w2( + self, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False, + ): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. shard_size = expert_data.shape[shard_dim] if not load_full: - loaded_weight = loaded_weight.narrow(shard_dim, - shard_size * tp_rank, - shard_size) + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) # w2, down_proj: Load into only logical weight of w2. expert_data.copy_(loaded_weight) - def _load_single_value(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, expert_id: int): + def _load_single_value( + self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int + ): param_data = param.data # Input scales can be loaded directly and should be equal. param_data[expert_id] = loaded_weight - def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, - shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): - + def _load_g_idx( + self, + shard_id: str, + expert_data: torch.Tensor, + shard_dim: int, + loaded_weight: torch.Tensor, + tp_rank: int, + ): if shard_id == "w2": - self._load_w2(shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=tp_rank) + self._load_w2( + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank, + ) else: assert shard_id in ("w1", "w3") expert_data.copy_(loaded_weight) @@ -1376,27 +1503,36 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: return self.expert_map[expert_id].item() @overload - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int, - return_success: Literal[False]) -> None: - ... + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: Literal[False], + ) -> None: ... @overload - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int, - return_success: Literal[True]) -> bool: - ... - - def weight_loader(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, - return_success: bool = False) -> Optional[bool]: - + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: Literal[True], + ) -> bool: ... 
+ + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False, + ) -> Optional[bool]: if self.quant_config and self.quant_config.get_name() == "mxfp4": # (FIXME) for gpt-oss all experts are combined if "bias" in weight_name: @@ -1419,13 +1555,13 @@ def weight_loader(self, # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality if self.quant_method.__class__.__name__ in ( - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod"): + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ): loaded_weight = loaded_weight.t().contiguous() if shard_id not in ("w1", "w2", "w3"): - raise ValueError(f"shard_id must be ['w1','w2','w3'] but " - f"got {shard_id}.") + raise ValueError(f"shard_id must be ['w1','w2','w3'] but got {shard_id}.") # Fetch the dim to shard the parameter/loaded weight # based on the shard id. This will be whatever @@ -1487,43 +1623,49 @@ def weight_loader(self, # this is needed for compressed-tensors only loaded_weight = loaded_weight.to(param.data.device) - if ("compressed" in quant_method_name.lower() - and param.data[expert_id] != 1 - and (param.data[expert_id] - loaded_weight).abs() > 1e-5): + if ( + "compressed" in quant_method_name.lower() + and param.data[expert_id] != 1 + and (param.data[expert_id] - loaded_weight).abs() > 1e-5 + ): raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. But got {param.data[expert_id]} " - f"vs. {loaded_weight}") + f"vs. {loaded_weight}" + ) - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) + self._load_single_value( + param=param, loaded_weight=loaded_weight, expert_id=expert_id + ) return True if return_success else None # Case g_idx if "g_idx" in weight_name: - self._load_g_idx(shard_dim=0, - shard_id=shard_id, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank) + self._load_g_idx( + shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + ) return True if return_success else None # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern if "ModelOpt" in quant_method_name: # Determine per-tensor weight scale patterns based on variant # Use the dedicated method instead of brittle string matching - uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern( - ) + uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern() # Call _load_per_tensor_weight_scale() to load per-tensor (scalar) # weights scales. # Input scales are always per-tensor. # Weight scales: FP4 uses "weight_scale_2" and FP8 uses # "weight_scale" for per-tensor scales. 
- is_per_tensor = ("weight_scale_2" in weight_name - if uses_weight_scale_2 else "weight_scale" - in weight_name) or "input_scale" in weight_name + is_per_tensor = ( + "weight_scale_2" in weight_name + if uses_weight_scale_2 + else "weight_scale" in weight_name + ) or "input_scale" in weight_name if is_per_tensor: self._load_per_tensor_weight_scale( shard_id=shard_id, @@ -1558,12 +1700,12 @@ def weight_loader(self, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=self.tp_rank) + tp_rank=self.tp_rank, + ) return True if return_success else None # Case weight scales, zero_points and offset, weight/input global scales - if ("scale" in weight_name or "zero" in weight_name - or "offset" in weight_name): + if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name: # load the weight scales and zp based on the quantization scheme # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported @@ -1576,10 +1718,11 @@ def weight_loader(self, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=self.tp_rank) + tp_rank=self.tp_rank, + ) elif quant_method in [ - FusedMoeWeightScaleSupported.GROUP.value, - FusedMoeWeightScaleSupported.BLOCK.value, + FusedMoeWeightScaleSupported.GROUP.value, + FusedMoeWeightScaleSupported.BLOCK.value, ]: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, @@ -1587,26 +1730,28 @@ def weight_loader(self, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank, - load_full_w2=getattr(param, "load_full_w2", False)) + load_full_w2=getattr(param, "load_full_w2", False), + ) elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: - self._load_per_tensor_weight_scale(shard_id=shard_id, - param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) + self._load_per_tensor_weight_scale( + shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id, + ) else: - WEIGHT_SCALE_SUPPORTED = [ - e.value for e in FusedMoeWeightScaleSupported - ] + WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported] raise ValueError( - f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}" + ) return True if return_success else None # Case weight_shape if "weight_shape" in weight_name: # only required by compressed-tensors - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) + self._load_single_value( + param=param, loaded_weight=loaded_weight, expert_id=expert_id + ) return True if return_success else None # Case model weights @@ -1616,17 +1761,20 @@ def weight_loader(self, shard_dim=shard_dim, loaded_weight=loaded_weight, expert_data=expert_data, - tp_rank=self.tp_rank) + tp_rank=self.tp_rank, + ) return True if return_success else None return False if return_success else None def load_weights( - self, weights: Iterable[tuple[str, - torch.Tensor]]) -> Iterable[str]: + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[str]: if (expert_mapping := self.expert_mapping) is None: - raise ValueError("`self.expert_mapping` must be provided to " - "load weights using `self.load_weights`.") + raise ValueError( + "`self.expert_mapping` must be provided to " + "load weights using `self.load_weights`." 
+ ) for expert_name, loaded_weight in weights: qual_name = f"{self.layer_name}.{expert_name}" for param_name, weight_name, expert_id, shard_id in expert_mapping: @@ -1644,8 +1792,12 @@ def load_weights( return_success=True, ) if success: - logger.debug("Loaded %s for expert %d into %s", param_name, - expert_id, self.layer_name) + logger.debug( + "Loaded %s for expert %d into %s", + param_name, + expert_id, + self.layer_name, + ) yield param_name def get_expert_weights(self) -> Iterable[torch.Tensor]: @@ -1660,9 +1812,11 @@ def get_expert_weights(self) -> Iterable[torch.Tensor]: } return [ - weight.view(self.local_num_experts, -1) for name, weight in weights - if name not in NON_EXPERT_WEIGHTS and weight.shape != torch.Size( - []) and not name.startswith("_shared_experts.") + weight.view(self.local_num_experts, -1) + for name, weight in weights + if name not in NON_EXPERT_WEIGHTS + and weight.shape != torch.Size([]) + and not name.startswith("_shared_experts.") ] def set_eplb_state( @@ -1685,7 +1839,8 @@ def set_eplb_state( def ensure_moe_quant_config(self): if self.quant_method.moe_quant_config is None: self.quant_method.moe_quant_config = ( - self.quant_method.get_fused_moe_quant_config(self)) + self.quant_method.get_fused_moe_quant_config(self) + ) @staticmethod def select_experts( @@ -1715,7 +1870,7 @@ def select_experts( router logits. Returns: - (topk_weights, topk_ids, zero_expert_result) + (topk_weights, topk_ids, zero_expert_result) (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): The weights, expert ids, and zero expert computation result. @@ -1724,7 +1879,9 @@ def select_experts( plain MoE implementations without redundant experts. """ from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_topk, fused_topk_bias) + fused_topk, + fused_topk_bias, + ) # Check if we should use a routing simulation strategy routing_strategy = envs.VLLM_MOE_ROUTING_SIMULATION_STRATEGY @@ -1734,7 +1891,8 @@ def select_experts( router_logits=router_logits, strategy_name=routing_strategy, top_k=top_k, - indices_type=indices_type) + indices_type=indices_type, + ) # DeepSeekv2 uses grouped_top_k if use_grouped_topk: @@ -1749,7 +1907,8 @@ def select_experts( topk_group=topk_group, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + ) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) elif e_score_correction_bias is not None: @@ -1775,7 +1934,8 @@ def select_experts( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, - renormalize=renormalize) + renormalize=renormalize, + ) if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) @@ -1795,9 +1955,12 @@ def select_experts( assert topk_ids.dtype == indices_type or indices_type is None # Compute zero expert result if needed - if (zero_expert_num is not None and zero_expert_num > 0 - and zero_expert_type is not None - and global_num_experts is not None): + if ( + zero_expert_num is not None + and zero_expert_num > 0 + and zero_expert_type is not None + and global_num_experts is not None + ): zero_expert_result = zero_experts_compute_triton( expert_indices=topk_ids, expert_scales=topk_weights, @@ -1822,16 +1985,21 @@ def must_reduce_shared_expert_outputs(self) -> bool: Therefore it is required that we reduce the shared_experts output early. 
""" - return (self.use_pplx_kernels or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels) + return ( + self.use_pplx_kernels + or self.use_deepep_ht_kernels + or self.use_deepep_ll_kernels + ) - def maybe_all_reduce_tensor_model_parallel( - self, final_hidden_states: torch.Tensor): + def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): """ The pplx combine kernel reduces across GPU ranks by default. """ - if (self.use_pplx_kernels or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels): + if ( + self.use_pplx_kernels + or self.use_deepep_ht_kernels + or self.use_deepep_ll_kernels + ): return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) @@ -1843,10 +2011,12 @@ def forward_native( ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: og_hidden_states = hidden_states.shape[-1] if self.hidden_size != og_hidden_states: - hidden_states = F.pad(hidden_states, - (0, self.hidden_size - og_hidden_states), - mode='constant', - value=0.0) + hidden_states = F.pad( + hidden_states, + (0, self.hidden_size - og_hidden_states), + mode="constant", + value=0.0, + ) if self.shared_experts is None: if current_platform.is_tpu(): @@ -1856,19 +2026,24 @@ def forward_native( assert not isinstance(fused_output, tuple) else: fused_output = torch.ops.vllm.moe_forward( - hidden_states, router_logits, self.layer_name) + hidden_states, router_logits, self.layer_name + ) return fused_output[..., :og_hidden_states] else: if current_platform.is_tpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. shared_output, fused_output = self.forward_impl( - hidden_states, router_logits) + hidden_states, router_logits + ) else: shared_output, fused_output = torch.ops.vllm.moe_forward_shared( - hidden_states, router_logits, self.layer_name) - return (shared_output[..., :og_hidden_states], - fused_output[..., :og_hidden_states]) + hidden_states, router_logits, self.layer_name + ) + return ( + shared_output[..., :og_hidden_states], + fused_output[..., :og_hidden_states], + ) def forward_cuda( self, @@ -1887,17 +2062,14 @@ def forward_impl_chunked( assert self.batched_hidden_states.dtype == full_hidden_states.dtype assert self.batched_router_logits.dtype == full_router_logits.dtype # Check size compatibility. 
- assert ( - self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)) - assert ( - self.batched_router_logits.size(-1) == full_router_logits.size(-1)) + assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1) + assert self.batched_router_logits.size(-1) == full_router_logits.size(-1) self.ensure_moe_quant_config() full_fused_final_hidden_states = torch.empty_like(full_hidden_states) if self.shared_experts is not None: - full_shared_final_hidden_states = torch.empty_like( - full_hidden_states) + full_shared_final_hidden_states = torch.empty_like(full_hidden_states) def process_chunk(chunk_start, chunk_end, skip_result_store=False): chunk_size = chunk_end - chunk_start @@ -1911,30 +2083,31 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): if self.batched_hidden_states.dim() == 3: assert self.batched_router_logits.dim() == 3 batch_buffer_idx = dbo_current_ubatch_id() - batched_hidden_states = self.batched_hidden_states[ - batch_buffer_idx, :] - batched_router_logits = self.batched_router_logits[ - batch_buffer_idx, :] + batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :] + batched_router_logits = self.batched_router_logits[batch_buffer_idx, :] else: batched_hidden_states = self.batched_hidden_states batched_router_logits = self.batched_router_logits - assert (batched_hidden_states.size(0) # type: ignore - >= chunk_size) - assert (batched_router_logits.size(0) # type: ignore - >= chunk_size) - staged_hidden_states = batched_hidden_states[: - chunk_size, :] # type: ignore - staged_router_logits = batched_router_logits[: - chunk_size, :] # type: ignore + assert ( + batched_hidden_states.size(0) # type: ignore + >= chunk_size + ) + assert ( + batched_router_logits.size(0) # type: ignore + >= chunk_size + ) + staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore + staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True) # If there are shared experts but we are not using a modular kernel, # the shared experts must be called here - if (not isinstance(self.quant_method.fused_experts, - FusedMoEModularKernel) - and self.shared_experts is not None): + if ( + not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) + and self.shared_experts is not None + ): shared_output = self.shared_experts(staged_hidden_states) else: shared_output = None @@ -1979,16 +2152,16 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): if not skip_result_store: if self.shared_experts is None: - full_fused_final_hidden_states[ - chunk_start:chunk_end, :].copy_(final_hidden_states, - non_blocking=True) + full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states, non_blocking=True + ) else: - full_shared_final_hidden_states[ - chunk_start:chunk_end, :].copy_(final_hidden_states[0], - non_blocking=True) - full_fused_final_hidden_states[ - chunk_start:chunk_end, :].copy_(final_hidden_states[1], - non_blocking=True) + full_shared_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states[0], non_blocking=True + ) + full_fused_final_hidden_states[chunk_start:chunk_end, :].copy_( + final_hidden_states[1], non_blocking=True + ) ctx = get_forward_context() # flashinfer_cutlass_kernels can handle: optional DP + TP/EP @@ -1998,31 +2171,32 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): # If the input to the MoE 
is sequence parallel then divide by sp_size # to find the maximum number of tokens for any individual dispatcher. if self.is_sequence_parallel: - max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers, - self.sp_size) + max_tokens_across_dispatchers = cdiv( + max_tokens_across_dispatchers, self.sp_size + ) num_tokens = full_hidden_states.size(0) for chunk_idx, chunk_start_ in enumerate( - range(0, max_tokens_across_dispatchers, - moe_dp_chunk_size_per_rank)): + range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank) + ): chunk_start = chunk_start_ - chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, - max_tokens_across_dispatchers) + chunk_end = min( + chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers + ) # clamp start and end chunk_start = min(chunk_start, num_tokens - 1) chunk_end = min(chunk_end, num_tokens) - with ctx.dp_metadata.chunked_sizes(self.sp_size, - moe_dp_chunk_size_per_rank, - chunk_idx): - process_chunk(chunk_start, - chunk_end, - skip_result_store=chunk_start_ >= num_tokens) + with ctx.dp_metadata.chunked_sizes( + self.sp_size, moe_dp_chunk_size_per_rank, chunk_idx + ): + process_chunk( + chunk_start, chunk_end, skip_result_store=chunk_start_ >= num_tokens + ) if self.shared_experts is None: return full_fused_final_hidden_states else: - return (full_shared_final_hidden_states, - full_fused_final_hidden_states) + return (full_shared_final_hidden_states, full_fused_final_hidden_states) def forward_impl( self, @@ -2035,36 +2209,45 @@ def forward_impl( # Route to the chunked forward path using the FlashInfer Cutlass kernel # only when data parallelism (DP) is enabled. - _use_flashinfer_cutlass_kernels = (self.dp_size > 1 and - self.use_flashinfer_cutlass_kernels) + _use_flashinfer_cutlass_kernels = ( + self.dp_size > 1 and self.use_flashinfer_cutlass_kernels + ) - if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels - or _use_flashinfer_cutlass_kernels): + if ( + self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels + or _use_flashinfer_cutlass_kernels + ): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( self.dp_size > 1 and not self.moe_parallel_config.use_deepep_ht_kernels - and not self.moe_config.use_flashinfer_cutlass_kernels) + and not self.moe_config.use_flashinfer_cutlass_kernels + ) # If there are shared experts but we are not using a modular kernel, the # shared experts must be called here - if (not isinstance(self.quant_method.fused_experts, - FusedMoEModularKernel) - and self.shared_experts is not None): + if ( + not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) + and self.shared_experts is not None + ): shared_output = self.shared_experts(hidden_states) else: shared_output = None ctx = get_forward_context() - sp_ctx = ctx.dp_metadata.sp_local_sizes( - self.sp_size) if ctx.dp_metadata else nullcontext() + sp_ctx = ( + ctx.dp_metadata.sp_local_sizes(self.sp_size) + if ctx.dp_metadata + else nullcontext() + ) with sp_ctx: if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits, self.is_sequence_parallel) + hidden_states, router_logits, self.is_sequence_parallel + ) # Matrix multiply. 
final_hidden_states = self.quant_method.apply( @@ -2101,16 +2284,18 @@ def forward_impl( assert isinstance(final_hidden_states, tuple) final_hidden_states, zero_expert_result = final_hidden_states - def reduce_output(states: torch.Tensor, - do_combine: bool = True) -> torch.Tensor: + def reduce_output( + states: torch.Tensor, do_combine: bool = True + ) -> torch.Tensor: if do_naive_dispatch_combine and do_combine: - states = get_ep_group().combine(states, - self.is_sequence_parallel) + states = get_ep_group().combine(states, self.is_sequence_parallel) - if (not self.is_sequence_parallel and self.reduce_results - and (self.tp_size > 1 or self.ep_size > 1)): - states = self.maybe_all_reduce_tensor_model_parallel( - states) + if ( + not self.is_sequence_parallel + and self.reduce_results + and (self.tp_size > 1 or self.ep_size > 1) + ): + states = self.maybe_all_reduce_tensor_model_parallel(states) return states @@ -2127,29 +2312,36 @@ def reduce_output(states: torch.Tensor, @classmethod def make_expert_params_mapping( - cls, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int, - num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]: - + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + num_redundant_experts: int = 0, + ) -> list[tuple[str, str, int, str]]: num_physical_experts = num_experts + num_redundant_experts # In the returned mapping: # - `expert_id` is the physical expert id # - `weight_name` contains the weight name of the logical expert # So that we should map the expert id to logical in `weight_name` - physical_to_logical_map = \ + physical_to_logical_map = ( EplbState.build_initial_global_physical_to_logical_map( - num_experts, num_redundant_experts) + num_experts, num_redundant_experts + ) + ) return [ # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" if weight_name - in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", - expert_id, shard_id) for expert_id in range(num_physical_experts) + ( + "experts.w13_" + if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] + else "experts.w2_", + f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", + expert_id, + shard_id, + ) + for expert_id in range(num_physical_experts) for shard_id, weight_name in [ ("w1", ckpt_gate_proj_name), ("w2", ckpt_down_proj_name), @@ -2158,7 +2350,6 @@ def make_expert_params_mapping( ] def extra_repr(self) -> str: - s = ( f"global_num_experts={self.global_num_experts}, " f"local_num_experts={self.local_num_experts}, " @@ -2168,7 +2359,8 @@ def extra_repr(self) -> str: f"ep_size={self.ep_size}, " f"reduce_results={self.reduce_results}, " f"renormalize={self.renormalize}, " - f"use_grouped_topk={self.use_grouped_topk}") + f"use_grouped_topk={self.use_grouped_topk}" + ) if self.use_grouped_topk: s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}" # noqa: E501 @@ -2202,7 +2394,7 @@ def moe_forward_fake( op_func=moe_forward, mutates_args=["hidden_states"], fake_impl=moe_forward_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(torch.Tag.needs_fixed_stride_order,), ) @@ -2232,7 +2424,7 @@ def moe_forward_shared_fake( op_func=moe_forward_shared, mutates_args=["hidden_states"], fake_impl=moe_forward_shared_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(torch.Tag.needs_fixed_stride_order,), ) # Mark the FusedMoE weight_loader as supporting 
MoE-specific parameters diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index a7617f8b7297..1f6209c9d08e 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -10,12 +10,18 @@ import vllm.envs as envs from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable - _resize_cache, count_expert_num_tokens) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, + count_expert_num_tokens, +) from vllm.utils import cdiv -from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled, - dbo_maybe_run_recv_hook, - dbo_register_recv_hook, dbo_yield) +from vllm.v1.worker.ubatching import ( + dbo_current_ubatch_id, + dbo_enabled, + dbo_maybe_run_recv_hook, + dbo_register_recv_hook, + dbo_yield, +) # # This file defines a set of base classes used to make MoE kernels more modular. @@ -59,31 +65,34 @@ class FusedMoEActivationFormat(Enum): """ The standard activation format (num_tokens, hidden dim). """ - Standard = "standard", + + Standard = ("standard",) """ The batched experts format (num experts, max tokens per expert, hidden dim) """ - BatchedExperts = "batched_experts", + BatchedExperts = ("batched_experts",) @dataclass class ExpertTokensMetadata: """ - Metadata regarding expert-token routing. - """ + Metadata regarding expert-token routing. + """ + expert_num_tokens: torch.Tensor expert_num_tokens_cpu: Optional[torch.Tensor] @staticmethod - def make_from_list(expert_num_tokens_list: list[int], - device: str) -> "ExpertTokensMetadata": - expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list, - device="cpu", - dtype=torch.int32) + def make_from_list( + expert_num_tokens_list: list[int], device: str + ) -> "ExpertTokensMetadata": + expert_num_tokens_cpu = torch.tensor( + expert_num_tokens_list, device="cpu", dtype=torch.int32 + ) return ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens_cpu.to(device, - non_blocking=True), - expert_num_tokens_cpu=expert_num_tokens_cpu) + expert_num_tokens=expert_num_tokens_cpu.to(device, non_blocking=True), + expert_num_tokens_cpu=expert_num_tokens_cpu, + ) class TopKWeightAndReduce(ABC): @@ -92,10 +101,14 @@ class TopKWeightAndReduce(ABC): """ @abstractmethod - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: """ Apply topk_weights to the fused_experts_outputs and/or reduce. If an output tensor is not passed, it will be created in the @@ -200,16 +213,16 @@ def prepare_async( - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. 
- Returns a callback or a hook callback pair that when invoked waits for - results from other workers and has the same return signature as + Returns a callback or a hook callback pair that when invoked waits for + results from other workers and has the same return signature as `prepare`, if a hook is returned this is more lightweight check that - the recv is complete without doing extra work (used by DBO, will be + the recv is complete without doing extra work (used by DBO, will be refactored in the very near future) - + e.g. ret = obj.prepare_async(...) - + if isinstance(ret, tuple): hook, receiver = ret hook() @@ -270,10 +283,10 @@ def finalize_async( - weight_and_reduce_impl: An optional TopKWeightAndReduce implementation. - Returns a callback or a hook callback pair that when invoked waits for - results from other workers and has the same return signature as + Returns a callback or a hook callback pair that when invoked waits for + results from other workers and has the same return signature as `finalize`, if a hook is returned this is more lightweight check that - the recv is complete without doing extra work (used by DBO, will be + the recv is complete without doing extra work (used by DBO, will be refactored in the very near future) ret = obj.finalize_async(output, ...) @@ -344,7 +357,8 @@ def __init__( @property @abstractmethod def activation_formats( - self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: + self, + ) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]: """ A property which is a tuple of the input and output activation formats for the 'apply' method. @@ -382,8 +396,7 @@ def moe_problem_size( if a1.dim() == 2: # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.size(0) == a1.size(0), \ - f"{topk_ids.size(0)} != {a1.size(0)}" + assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}" M = a1.size(0) else: assert a1.dim() == 3 @@ -511,8 +524,9 @@ def workspace_shapes( """ raise NotImplementedError - def activation(self, activation: str, output: torch.Tensor, - input: torch.Tensor) -> None: + def activation( + self, activation: str, output: torch.Tensor, input: torch.Tensor + ) -> None: assert output.size(-1) * 2 == input.size(-1) if activation == "silu": torch.ops._C.silu_and_mul(output, input) @@ -522,8 +536,9 @@ def activation(self, activation: str, output: torch.Tensor, raise ValueError(f"Unsupported FusedMoe activation: {activation}") def enable_chunking(self): - return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \ - self.supports_chunking() + return ( + envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking() + ) def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce: raise NotImplementedError @@ -585,8 +600,9 @@ def apply( raise NotImplementedError -def _chunk_scales(scales: Optional[torch.Tensor], start: int, - end: int) -> Optional[torch.Tensor]: +def _chunk_scales( + scales: Optional[torch.Tensor], start: int, end: int +) -> Optional[torch.Tensor]: if scales is not None: if scales.numel() == 1: return scales @@ -596,17 +612,19 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int, class SharedResizableBuffer: - def __init__(self): self.buffer = None - def get(self, shape: tuple[int, ...], device: torch.device, - dtype: torch.dtype): + def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype): if shape == () or shape is None: return None shape_numel = prod(shape) - if (self.buffer is None or self.buffer.numel() < shape_numel - or 
self.buffer.device != device or self.buffer.dtype != dtype): + if ( + self.buffer is None + or self.buffer.numel() < shape_numel + or self.buffer.device != device + or self.buffer.dtype != dtype + ): self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) return self.buffer[:shape_numel].view(*shape) @@ -626,7 +644,6 @@ class FusedMoEModularKernel(torch.nn.Module): """ class SharedBuffers: - def __init__(self) -> None: self.fused_out = SharedResizableBuffer() self.workspace13 = SharedResizableBuffer() @@ -652,12 +669,14 @@ def __init__( self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts self.shared_experts = shared_experts - assert prepare_finalize.activation_format == \ - fused_experts.activation_formats[0], ( - f"{prepare_finalize.__class__.__name__}." - f"{prepare_finalize.activation_format} == " - f"{fused_experts.__class__.__name__}." - f"{fused_experts.activation_formats[0]}") + assert ( + prepare_finalize.activation_format == fused_experts.activation_formats[0] + ), ( + f"{prepare_finalize.__class__.__name__}." + f"{prepare_finalize.activation_format} == " + f"{fused_experts.__class__.__name__}." + f"{fused_experts.activation_formats[0]}" + ) def _do_fused_experts( self, @@ -677,14 +696,21 @@ def _do_fused_experts( expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, ) -> torch.Tensor: + _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids) - _, M, N, K, top_k = self.fused_experts.moe_problem_size( - a1q, w1, w2, topk_ids) - - (workspace13_shape, workspace2_shape, fused_out_shape, - workspace_dtype) = self.fused_experts.workspace_shapes( - a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, - expert_tokens_meta) + (workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype) = ( + self.fused_experts.workspace_shapes( + a1, + a1q, + M, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) + ) # select per-ubatch buffers to avoid cross-ubatch reuse under DBO ubatch_idx = dbo_current_ubatch_id() @@ -692,15 +718,16 @@ def _do_fused_experts( # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. - workspace13 = buffers.workspace13.get(workspace13_shape, - device=a1.device, - dtype=workspace_dtype) - workspace2 = buffers.workspace2.get(workspace2_shape, - device=a1.device, - dtype=workspace_dtype) + workspace13 = buffers.workspace13.get( + workspace13_shape, device=a1.device, dtype=workspace_dtype + ) + workspace2 = buffers.workspace2.get( + workspace2_shape, device=a1.device, dtype=workspace_dtype + ) assert fused_out is None or fused_out.shape == fused_out_shape, ( - f"fused_out {fused_out.shape} but expected {fused_out_shape}") + f"fused_out {fused_out.shape} but expected {fused_out_shape}" + ) if fused_out is None: # reuse workspace13 for the output fused_out = _resize_cache(workspace13, fused_out_shape) @@ -741,9 +768,7 @@ def _maybe_chunk_fused_experts( expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, ) -> torch.Tensor: - - _, M, N, K, top_k = self.fused_experts.moe_problem_size( - a1q, w1, w2, topk_ids) + _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids) CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE num_chunks = cdiv(M, CHUNK_SIZE) @@ -775,18 +800,31 @@ def _maybe_chunk_fused_experts( # Construct the entire output that can then be processed in chunks. 
(_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( - a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, - expert_tokens_meta) + a1, + a1q, + M, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) ubatch_idx = dbo_current_ubatch_id() buffers = self.shared_buffers[ubatch_idx] - fused_out = buffers.fused_out.get(fused_out_shape, - device=a1q.device, - dtype=a1.dtype) + fused_out = buffers.fused_out.get( + fused_out_shape, device=a1q.device, dtype=a1.dtype + ) def slice_input_tensors( - chunk_idx: int - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, torch.Tensor]: + chunk_idx: int, + ) -> tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + ]: s = chunk_idx * CHUNK_SIZE e = min(s + CHUNK_SIZE, M) return ( @@ -799,7 +837,8 @@ def slice_input_tensors( def slice_output_tensor(chunk_idx: int) -> torch.Tensor: assert fused_out.size(0) % M == 0, ( - f"fused_out shape {fused_out.shape} vs M {M}") + f"fused_out shape {fused_out.shape} vs M {M}" + ) factor = fused_out.size(0) // M out_chunk_size = CHUNK_SIZE * factor s = chunk_idx * out_chunk_size @@ -807,38 +846,45 @@ def slice_output_tensor(chunk_idx: int) -> torch.Tensor: return fused_out[s:e] def slice_expert_tokens_metadata( - full_expert_tokens_meta: ExpertTokensMetadata, - chunk_topk_ids: torch.Tensor, local_num_experts: int, - expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata: + full_expert_tokens_meta: ExpertTokensMetadata, + chunk_topk_ids: torch.Tensor, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> ExpertTokensMetadata: # The existing expert_num_tokens is for the entire a1q # input. Chunking forces recomputation of the number # of tokens assigned to each expert. 
c_expert_num_tokens = count_expert_num_tokens( - chunk_topk_ids, local_num_experts, expert_map) + chunk_topk_ids, local_num_experts, expert_map + ) c_expert_num_tokens_cpu = None need_expert_num_tokens_cpu = ( - full_expert_tokens_meta.expert_num_tokens_cpu is not None) + full_expert_tokens_meta.expert_num_tokens_cpu is not None + ) if need_expert_num_tokens_cpu: # This is blocking as some implementations need the count # on the CPU to determine appropriate input/out fused-moe # buffers c_expert_num_tokens_cpu = c_expert_num_tokens.to( - "cpu", non_blocking=False) + "cpu", non_blocking=False + ) return ExpertTokensMetadata( expert_num_tokens=c_expert_num_tokens, - expert_num_tokens_cpu=c_expert_num_tokens_cpu) + expert_num_tokens_cpu=c_expert_num_tokens_cpu, + ) for chunk_idx in range(num_chunks): c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( - slice_input_tensors(chunk_idx)) + slice_input_tensors(chunk_idx) + ) c_expert_tokens_meta = None if expert_tokens_meta is not None: c_expert_tokens_meta = slice_expert_tokens_metadata( - expert_tokens_meta, c_topk_ids, local_num_experts, - expert_map) + expert_tokens_meta, c_topk_ids, local_num_experts, expert_map + ) self._do_fused_experts( fused_out=slice_output_tensor(chunk_idx), @@ -902,10 +948,7 @@ def forward( """ a1 = hidden_states - if inplace and self.shared_experts is None: - output = a1 - else: - output = torch.zeros_like(a1) + output = a1 if inplace and self.shared_experts is None else torch.zeros_like(a1) local_num_experts = w1.size(0) if global_num_experts == -1: @@ -917,16 +960,21 @@ def forward( # TODO(lucas): enable in follow-up assert not dbo_enabled() - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = self.prepare_finalize.prepare( - a1, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) + ( + a1q, + a1q_scale, + expert_tokens_meta, + _expert_topk_ids, + _expert_topk_weights, + ) = self.prepare_finalize.prepare( + a1, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) else: # Overlap shared expert compute with all2all dispatch. dbo_maybe_run_recv_hook() @@ -943,8 +991,9 @@ def forward( # TODO(lucas): refactor this in the alternative schedules followup # currently unpack if we have hook + receiver pair or just # receiver (see finalize_async docstring) - hook, receiver = prepare_ret \ - if isinstance(prepare_ret, tuple) else (None, prepare_ret) + hook, receiver = ( + prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret) + ) if hook is not None: if dbo_enabled(): @@ -956,13 +1005,19 @@ def forward( else: hook() - (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, - _expert_topk_weights) = receiver() + ( + a1q, + a1q_scale, + expert_tokens_meta, + _expert_topk_ids, + _expert_topk_weights, + ) = receiver() # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids - topk_weights = (topk_weights if _expert_topk_weights is None else - _expert_topk_weights) + topk_weights = ( + topk_weights if _expert_topk_weights is None else _expert_topk_weights + ) fused_out = None @@ -1022,8 +1077,11 @@ def forward( # TODO(lucas): refactor this in the alternative schedules followup # currently unpack if we have hook + receiver pair or just # receiver (see finalize_async docstring) - hook, receiver = finalize_ret \ - if isinstance(finalize_ret, tuple) else (None, finalize_ret) + hook, receiver = ( + finalize_ret + if isinstance(finalize_ret, tuple) + else (None, finalize_ret) + ) if hook is not None: if dbo_enabled(): diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index c7d7126bab3a..9994088ca5d9 100644 --- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -14,7 +14,7 @@ def moe_align_block_size( block_size: int, num_experts: int, expert_map: Optional[torch.Tensor] = None, - pad_sorted_ids: bool = False + pad_sorted_ids: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -68,19 +68,18 @@ def moe_align_block_size( max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) if pad_sorted_ids: max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) - sorted_ids = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + ops.moe_align_block_size( + topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad + ) if expert_map is not None: expert_ids = expert_map[expert_ids] diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index 23f618b1a5fd..66c00cf89873 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -7,18 +7,20 @@ def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: """ - Compute the histogram of an int32 tensor. The bin edges are defined by the - min and max values, with step = 1. - """ + Compute the histogram of an int32 tensor. The bin edges are defined by the + min and max values, with step = 1. + """ assert input.dtype == torch.int32, "input must be of torch.int32 dtype." assert min <= max, "min must be less than or equal to max." 
- def searchsorted(sorted_sequence: torch.Tensor, - values_to_search: torch.Tensor) -> torch.Tensor: + def searchsorted( + sorted_sequence: torch.Tensor, values_to_search: torch.Tensor + ) -> torch.Tensor: return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1) - bin_edges = torch.linspace(min, max, max - min + 1, - dtype=input.dtype).to(input.device) + bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to( + input.device + ) return searchsorted(bin_edges, input).to(torch.int32) @@ -41,6 +43,7 @@ def fused_moe( """ assert expert_map is None, "expert_map is not supported for pallas MoE." import torch_xla.experimental.custom_kernel # noqa: F401 + orig_shape = hidden_states.shape hidden_size = hidden_states.shape[-1] num_tokens = hidden_states.shape[:-1].numel() @@ -50,7 +53,8 @@ def fused_moe( dtype = hidden_states.dtype assert (num_tokens * topk) % 16 == 0, ( "The Pallas GMM kernel requires num_tokens * topk to be a multiple of " - f"16 but got {num_tokens * topk}") + f"16 but got {num_tokens * topk}" + ) hidden_states = hidden_states.view(num_tokens, hidden_size) gating_output = gating_output.view(num_tokens, num_experts) @@ -63,8 +67,7 @@ def fused_moe( topk_indices = topk_indices.flatten() topk_argsort_indices = topk_indices.argsort() topk_argsort_revert_indices = topk_argsort_indices.argsort() - token_indices = torch.arange(num_tokens, - device=device).repeat_interleave(topk) + token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk) token_indices = token_indices[topk_argsort_indices] group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1) diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 16a155e71847..698080f8aec6 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -6,7 +6,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size) + moe_align_block_size, +) from vllm.model_executor.layers.fused_moe.utils import _fp8_perm @@ -17,8 +18,9 @@ def _moe_permute( global_num_experts: int, expert_map: Optional[torch.Tensor], block_m: int, -) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - torch.Tensor]: +) -> tuple[ + torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor +]: """ Determine the sorted_token_ids, expert_ids for the given problem size. Permute the hidden states and scales according to `sorted_token_ids`. @@ -27,12 +29,9 @@ def _moe_permute( tokens_in_chunk = curr_hidden_states.size(0) - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, - block_m, - global_num_experts, - expert_map, - pad_sorted_ids=True)) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, block_m, global_num_experts, expert_map, pad_sorted_ids=True + ) inv_perm: Optional[torch.Tensor] = None @@ -43,14 +42,12 @@ def _moe_permute( # Permute according to sorted token ids. 
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) - curr_hidden_states = _fp8_perm(curr_hidden_states, - sorted_token_ids // top_k_num) + curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num) if a1q_scale is not None: a1q_scale = a1q_scale[sorted_token_ids // top_k_num] - return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, - inv_perm) + return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids, inv_perm) def _moe_unpermute_and_reduce( @@ -84,8 +81,9 @@ def moe_permute( align_block_size: Optional[int] = None, fill_invalid_expert: int = -1, permuted_hidden_states: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - torch.Tensor]: +) -> tuple[ + torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor +]: """ This function expands and permutes activation to gather uncontinuous tokens for each expert. @@ -117,13 +115,21 @@ def moe_permute( """ n_token, n_hidden = hidden_states.size() topk = topk_ids.size(1) - assert (n_hidden * hidden_states.element_size() - ) % 16 == 0, "permue kernel need hidden dim align to 16B" + assert (n_hidden * hidden_states.element_size()) % 16 == 0, ( + "permue kernel need hidden dim align to 16B" + ) permuted_row_size = n_token * topk if align_block_size is not None: - permuted_row_size = (permuted_row_size + n_expert * - (align_block_size - 1) + align_block_size - - 1) // align_block_size * align_block_size + permuted_row_size = ( + ( + permuted_row_size + + n_expert * (align_block_size - 1) + + align_block_size + - 1 + ) + // align_block_size + * align_block_size + ) if n_local_expert == -1: n_local_expert = n_expert if permuted_hidden_states is None: @@ -134,40 +140,57 @@ def moe_permute( ) assert permuted_hidden_states.size() == (permuted_row_size, n_hidden), ( f"Expected permuted hidden states to be {(permuted_row_size, n_hidden)}" - f" but got {permuted_hidden_states.size()}") - - token_expert_indices = torch.arange(0, - n_token * topk, - dtype=torch.int32, - device=hidden_states.device).reshape( - (n_token, topk)) - - m_indices = torch.full((permuted_row_size, ), - fill_invalid_expert, - dtype=torch.int32, - device=hidden_states.device) - expert_first_token_offset = torch.empty(n_local_expert + 1, - dtype=torch.int64, - device=hidden_states.device) - permuted_idx = torch.full((permuted_row_size, ), - n_token * topk, - dtype=torch.int32, - device=hidden_states.device) - inv_permuted_idx = torch.empty((n_token, topk), - dtype=torch.int32, - device=hidden_states.device) + f" but got {permuted_hidden_states.size()}" + ) + + token_expert_indices = torch.arange( + 0, n_token * topk, dtype=torch.int32, device=hidden_states.device + ).reshape((n_token, topk)) + + m_indices = torch.full( + (permuted_row_size,), + fill_invalid_expert, + dtype=torch.int32, + device=hidden_states.device, + ) + expert_first_token_offset = torch.empty( + n_local_expert + 1, dtype=torch.int64, device=hidden_states.device + ) + permuted_idx = torch.full( + (permuted_row_size,), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device, + ) + inv_permuted_idx = torch.empty( + (n_token, topk), dtype=torch.int32, device=hidden_states.device + ) topk_ids = topk_ids.to(torch.int32) - torch.ops._moe_C.moe_permute(hidden_states, topk_ids, token_expert_indices, - expert_map, n_expert, n_local_expert, topk, - align_block_size, permuted_hidden_states, - expert_first_token_offset, inv_permuted_idx, - permuted_idx, m_indices) + 
torch.ops._moe_C.moe_permute( + hidden_states, + topk_ids, + token_expert_indices, + expert_map, + n_expert, + n_local_expert, + topk, + align_block_size, + permuted_hidden_states, + expert_first_token_offset, + inv_permuted_idx, + permuted_idx, + m_indices, + ) if a1q_scale is not None and a1q_scale.dim() > 1: - a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // - topk] - return (permuted_hidden_states, a1q_scale, expert_first_token_offset, - inv_permuted_idx.flatten(), m_indices) + a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // topk] + return ( + permuted_hidden_states, + a1q_scale, + expert_first_token_offset, + inv_permuted_idx.flatten(), + m_indices, + ) def moe_unpermute( @@ -185,7 +208,7 @@ def moe_unpermute( - permuted_hidden_states (torch.Tensor): permuted activation. - topk_weights (torch.Tensor): topk expert route weight for each token. - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute. - - expert_first_token_offset (Optional[torch.Tensor]): offset of the first + - expert_first_token_offset (Optional[torch.Tensor]): offset of the first token of each expert for grouped gemm. Returns: - hidden_states (torch.Tensor): The reduced and unpermuted activation @@ -193,12 +216,18 @@ def moe_unpermute( """ topk = topk_weights.size(1) n_hidden = permuted_hidden_states.size(-1) - assert (n_hidden * permuted_hidden_states.element_size() - ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" - - torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, - inv_permuted_idx, expert_first_token_offset, - topk, out) + assert (n_hidden * permuted_hidden_states.element_size()) % 16 == 0, ( + "unpermue kernel need hidden dim align to 16B" + ) + + torch.ops._moe_C.moe_unpermute( + permuted_hidden_states, + topk_weights, + inv_permuted_idx, + expert_first_token_offset, + topk, + out, + ) def moe_permute_unpermute_supported(): diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py index 6160da732951..f721d00d75ea 100644 --- a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +++ b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py @@ -45,7 +45,7 @@ def fused_moe( for expert_idx in range(num_experts): expert_w1 = w1[expert_idx] expert_w2 = w2[expert_idx] - expert_mask = (selected_experts == expert_idx) + expert_mask = selected_experts == expert_idx expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) x = F.linear(hidden_states, expert_w1) gate = F.silu(x[:, :intermediate_size]) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index ddddd2a3b7a2..79212c2b689d 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -9,9 +9,12 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate) + TopKWeightAndReduceDelegate, +) from vllm.model_executor.layers.fused_moe.utils import ( - _validate_scale_shape, moe_kernel_quantize_input) + _validate_scale_shape, + moe_kernel_quantize_input, +) from vllm.utils import cdiv, round_up logger = init_logger(__name__) @@ -60,7 +63,6 @@ def pplx_hidden_dim_scale_bytes( class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): - def __init__( self, a2a: 
pplx.AllToAll, @@ -113,8 +115,9 @@ def prepare_async( if expert_map is not None: logger.warning_once( "The PPLX backend does not support expert mapping. " - "The provided `expert_map` will be ignored.") - expert_map = None #noqa: F841 + "The provided `expert_map` will be ignored." + ) + expert_map = None # noqa: F841 # Is this always going to be a1.device? device = a1.device @@ -123,21 +126,24 @@ def prepare_async( topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1") + "apply_router_weight_on_input is only implemented for topk=1" + ) a1 = a1 * topk_weights.to(a1.dtype) repeat_cols = 4 repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) # TODO(bnell): always pass quant_config.a1_scale? a1q, a1q_scale = moe_kernel_quantize_input( - a1, (None if quant_config.per_act_token_quant else - quant_config.a1_scale), + a1, + (None if quant_config.per_act_token_quant else quant_config.a1_scale), quant_dtype=quant_config.quant_dtype, per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape) + block_shape=quant_config.block_shape, + ) - _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, - quant_config.block_shape) + _validate_scale_shape( + a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape + ) orig_a_scale_block_shape: Optional[int] = None @@ -155,8 +161,9 @@ def prepare_async( # TODO (bnell): use group_broadcast instead? a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) - assert a1q_scale is None or a1q_scale.ndim == 2, \ + assert a1q_scale is None or a1q_scale.ndim == 2, ( f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}" + ) expert_num_tokens = torch.empty( self.num_local_experts, @@ -165,8 +172,11 @@ def prepare_async( ) expert_x = torch.empty( - (self.num_local_experts, - self.max_num_tokens * self.num_dispatchers(), hidden_dim), + ( + self.num_local_experts, + self.max_num_tokens * self.num_dispatchers(), + hidden_dim, + ), dtype=a1q.dtype, device=device, ) @@ -182,14 +192,13 @@ def prepare_async( else: # (M x K_tiles) -> (E x M x K_tiles) assert quant_config.block_shape is not None - num_blocks = cdiv(expert_x.size(2), - quant_config.block_shape[1]) + num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1]) final_dim = num_blocks expert_x_scale_shape = ( self.num_local_experts, expert_x.size(1), - round_up(final_dim, 4) # round up for alignment + round_up(final_dim, 4), # round up for alignment ) expert_x_scale = torch.empty( @@ -226,12 +235,15 @@ def prepare_async( do_recv=True, ) - return (hook, lambda: self._receiver( - expert_num_tokens, - expert_x, - expert_x_scale, - orig_a_scale_block_shape, - )) + return ( + hook, + lambda: self._receiver( + expert_num_tokens, + expert_x, + expert_x_scale, + orig_a_scale_block_shape, + ), + ) def _receiver( self, @@ -240,13 +252,13 @@ def _receiver( expert_x_scale: Optional[torch.Tensor], orig_a_scale_block_shape: Optional[int], ) -> mk.PrepareResultType: - if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None + ) return expert_x, expert_x_scale, expert_tokens_meta, None, None @@ -281,22 +293,24 @@ def finalize_async( apply_router_weight_on_input: bool, 
weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> Callable: - assert isinstance( - weight_and_reduce_impl, TopKWeightAndReduceDelegate - ), ("Weight application and reduction happens in the combine kernel.") + assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), ( + "Weight application and reduction happens in the combine kernel." + ) # This argument is optional # There's not much point setting this unless it is != topk_ids.size(0) bound_m: Optional[torch.Tensor] = None # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on - #num_tokens = output.size(0) # M - #assert topk_ids.size(0) == num_tokens, ( + # num_tokens = output.size(0) # M + # assert topk_ids.size(0) == num_tokens, ( # f"{topk_ids.size(0)} == {num_tokens}") assert topk_ids.size() == topk_weights.size(), ( - f"{topk_ids.size()} == {topk_weights.size()}") + f"{topk_ids.size()} == {topk_weights.size()}" + ) assert output.size(0) <= self.max_num_tokens, ( - f"{output.size(0)} <= {self.max_num_tokens}") + f"{output.size(0)} <= {self.max_num_tokens}" + ) assert output.size(1) == fused_expert_output.size(-1) # Set weights to 1 if we did them in dispatch. This is hacky. @@ -305,21 +319,25 @@ def finalize_async( topk_ids_u32 = topk_ids.view(dtype=torch.uint32) - self.a2a.combine(out_tokens=output, - indices=topk_ids_u32, - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m, - do_send=True, - do_recv=False) - - return lambda: self.a2a.combine(out_tokens=output, - indices=topk_ids_u32, - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m, - do_send=False, - do_recv=True) + self.a2a.combine( + out_tokens=output, + indices=topk_ids_u32, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=True, + do_recv=False, + ) + + return lambda: self.a2a.combine( + out_tokens=output, + indices=topk_ids_u32, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + do_send=False, + do_recv=True, + ) def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 588e5de865dd..be6939a3f62f 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -7,13 +7,13 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) -from vllm.model_executor.layers.fused_moe.utils import ( - moe_kernel_quantize_input) + TopKWeightAndReduceContiguous, + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): - @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard @@ -37,17 +37,21 @@ def prepare( apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, \ + assert topk == 1, ( "apply_router_weight_on_input is only implemented for topk=1" + ) a1.mul_(topk_weights.to(a1.dtype)) a1q, a1q_scale = moe_kernel_quantize_input( - a1, quant_config.a1_scale, quant_config.quant_dtype, - 
quant_config.per_act_token_quant, quant_config.block_shape) + a1, + quant_config.a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) return a1q, a1q_scale, None, None, None @@ -67,4 +71,5 @@ def finalize( fused_expert_output=fused_expert_output, topk_weights=topk_weights, topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 2764af5fc532..801785b18fb9 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -8,7 +8,9 @@ from vllm import envs from vllm.model_executor.layers.fused_moe.config import ( - FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig) + FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEQuantConfig, +) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -38,138 +40,162 @@ class ActivationMethod(IntEnum): @cache def is_rocm_aiter_moe_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_MOE \ + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER_MOE and envs.VLLM_ROCM_USE_AITER + ) def rocm_aiter_asm_moe_tkw1_impl( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: Optional[torch.Tensor] = None, - a16: bool = False, - per_tensor_quant_scale: Optional[torch.Tensor] = None, - expert_mask: Optional[torch.Tensor] = None, - activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: - + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, +) -> torch.Tensor: from aiter import ActivationType from aiter.fused_moe_bf16_asm import asm_moe_tkw1 activation = ActivationType(activation_method) - return asm_moe_tkw1(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - fc1_scale=fc1_scale, - fc2_scale=fc2_scale, - fc1_smooth_scale=fc1_smooth_scale, - fc2_smooth_scale=fc2_smooth_scale, - a16=a16, - per_tensor_quant_scale=per_tensor_quant_scale, - expert_mask=expert_mask, - activation=activation) + return asm_moe_tkw1( + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + fc1_scale=fc1_scale, + fc2_scale=fc2_scale, + fc1_smooth_scale=fc1_smooth_scale, + fc2_smooth_scale=fc2_smooth_scale, + a16=a16, + per_tensor_quant_scale=per_tensor_quant_scale, + expert_mask=expert_mask, + activation=activation, + ) def rocm_aiter_asm_moe_tkw1_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - fc1_scale: Optional[torch.Tensor] = None, - fc2_scale: Optional[torch.Tensor] = None, - fc1_smooth_scale: Optional[torch.Tensor] = None, - fc2_smooth_scale: Optional[torch.Tensor] = None, - a16: bool 
= False, - per_tensor_quant_scale: Optional[torch.Tensor] = None, - expert_mask: Optional[torch.Tensor] = None, - activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor: + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + fc1_scale: Optional[torch.Tensor] = None, + fc2_scale: Optional[torch.Tensor] = None, + fc1_smooth_scale: Optional[torch.Tensor] = None, + fc2_smooth_scale: Optional[torch.Tensor] = None, + a16: bool = False, + per_tensor_quant_scale: Optional[torch.Tensor] = None, + expert_mask: Optional[torch.Tensor] = None, + activation_method: int = ActivationMethod.SILU.value, +) -> torch.Tensor: return torch.empty_like(hidden_states) -def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool) -> None: +def rocm_aiter_topk_softmax_impl( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: from aiter import topk_softmax - topk_softmax(topk_weights, topk_indices, token_expert_indices, - gating_output, renormalize) + topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) -def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool) -> None: + +def rocm_aiter_topk_softmax_fake( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> None: pass def rocm_aiter_biased_grouped_topk_impl( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 1.0 # mul to topk_weights + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: - from aiter import biased_grouped_topk - biased_grouped_topk(gating_output, correction_bias, topk_weights, topk_ids, - num_expert_group, topk_group, need_renorm, - routed_scaling_factor) + biased_grouped_topk( + gating_output, + correction_bias, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + routed_scaling_factor, + ) def rocm_aiter_biased_grouped_topk_fake( - gating_output: torch.Tensor, - correction_bias: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - routed_scaling_factor: float = 1.0 # mul to topk_weights + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: pass def rocm_aiter_grouped_topk_impl( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0 # mul to topk_weights + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, 
+ num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: - from aiter import grouped_topk - grouped_topk(gating_output, topk_weights, topk_ids, num_expert_group, - topk_group, need_renorm, scoring_func, routed_scaling_factor) + grouped_topk( + gating_output, + topk_weights, + topk_ids, + num_expert_group, + topk_group, + need_renorm, + scoring_func, + routed_scaling_factor, + ) def rocm_aiter_grouped_topk_fake( - gating_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_expert_group: int, - topk_group: int, - need_renorm: bool, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0 # mul to topk_weights + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_expert_group: int, + topk_group: int, + need_renorm: bool, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, # mul to topk_weights ) -> None: pass @@ -195,9 +221,21 @@ def rocm_aiter_fused_moe_impl( activation = ActivationType(activation_method) quant_type = QuantType(quant_method) - return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, expert_mask, - activation, quant_type, doweight_stage1, w1_scale, - w2_scale, a1_scale, a2_scale) + return fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask, + activation, + quant_type, + doweight_stage1, + w1_scale, + w2_scale, + a1_scale, + a2_scale, + ) def rocm_aiter_fused_moe_fake( @@ -219,7 +257,6 @@ def rocm_aiter_fused_moe_fake( if current_platform.is_rocm(): - direct_register_custom_op( op_name="rocm_aiter_asm_moe_tkw1", op_func=rocm_aiter_asm_moe_tkw1_impl, @@ -263,14 +300,12 @@ def rocm_aiter_grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: token = hidden_states.shape[0] device = hidden_states.device topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) - topk_weights = torch.empty((token, topk), - dtype=torch.float32, - device=device) + topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device) if e_score_correction_bias is not None: torch.ops.vllm.rocm_aiter_biased_grouped_topk( @@ -283,7 +318,7 @@ def rocm_aiter_grouped_topk( renormalize, ) else: - assert (scoring_func == "softmax" or scoring_func == "sigmoid") + assert scoring_func == "softmax" or scoring_func == "sigmoid" torch.ops.vllm.rocm_aiter_grouped_topk( gating_output, topk_weights, @@ -313,28 +348,30 @@ def rocm_aiter_fused_experts( if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - activation_method = (ActivationMethod.SILU - if activation == "silu" else ActivationMethod.GELU) + activation_method = ( + ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU + ) # All AITER Fused MoE kernels are expecting the following datatypes topk_weights = topk_weights.to(torch.float32) topk_ids = topk_ids.to(torch.int32) - if expert_map is not None: - expert_mask = (expert_map > -1).to(torch.int32) - else: - expert_mask = None + expert_mask = (expert_map > -1).to(torch.int32) if expert_map is not None else None # w8a8 per-channel quantization - if (quant_config.per_act_token_quant and apply_router_weight_on_input - and quant_config.use_fp8_w8a8): + if ( + quant_config.per_act_token_quant + and 
apply_router_weight_on_input + and quant_config.use_fp8_w8a8 + ): # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # This applies topk_weights on the GEMM output of the first FC layer # rather than the second FC. - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" + assert topk_weights.dim() == 2, ( + "`topk_weights` should be in shape (num_tokens, topk)" + ) assert topk_weights.shape[-1] == 1, ( - "Only support topk=1 when" - " `apply_router_weight_on_input` is True") + "Only support topk=1 when `apply_router_weight_on_input` is True" + ) return torch.ops.vllm.rocm_aiter_asm_moe_tkw1( hidden_states, @@ -349,7 +386,8 @@ def rocm_aiter_fused_experts( a16=False, per_tensor_quant_scale=None, expert_mask=expert_mask, - activation_method=activation_method) + activation_method=activation_method, + ) else: quant_method = QuantMethod.NO.value @@ -358,7 +396,8 @@ def rocm_aiter_fused_experts( if quant_config.block_shape is not None and quant_config.use_fp8_w8a8: assert not apply_router_weight_on_input, ( "apply_router_weight_on_input is\ - not supported for block scaled moe") + not supported for block scaled moe" + ) assert quant_config.w1_scale is not None assert quant_config.w2_scale is not None quant_method = QuantMethod.BLOCK_128x128.value @@ -367,12 +406,13 @@ def rocm_aiter_fused_experts( quant_method = QuantMethod.PER_TENSOR.value if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" + assert topk_weights.dim() == 2, ( + "`topk_weights` should be in shape (num_tokens, topk)" + ) _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" + assert topk == 1, ( + "Only support topk=1 when `apply_router_weight_on_input` is True" + ) return torch.ops.vllm.rocm_aiter_fused_moe( hidden_states, @@ -387,17 +427,20 @@ def rocm_aiter_fused_experts( w2_scale=quant_config.w2_scale, a1_scale=quant_config.a1_scale, a2_scale=quant_config.a2_scale, - doweight_stage1=apply_router_weight_on_input) + doweight_stage1=apply_router_weight_on_input, + ) -def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, - topk_indices: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool) -> tuple[torch.Tensor, ...]: - torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices, - token_expert_indices, gating_output, - renormalize) +def rocm_aiter_topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, +) -> tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_softmax( + topk_weights, topk_indices, token_expert_indices, gating_output, renormalize + ) return topk_weights, topk_indices @@ -413,7 +456,7 @@ def shuffle_weights( Args: *tensors: Variable number of torch.Tensor objects. - layout: A pair of integers specifying the block sizes used to divide + layout: A pair of integers specifying the block sizes used to divide the tensors during shuffling. Default is (16, 16). 
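# Editor's note: a minimal, hypothetical sketch of what tiling a 2-D weight
# tensor into (16, 16) blocks can look like, to illustrate the "layout"
# argument described in the shuffle_weights docstring above. The function name
# `tile_weight` and the exact block ordering are assumptions for illustration;
# this does not reproduce AITER's actual shuffle_weight memory layout.
import torch


def tile_weight(w: torch.Tensor, layout: tuple[int, int] = (16, 16)) -> torch.Tensor:
    """Reorder a (rows, cols) tensor so each (layout) block is contiguous."""
    br, bc = layout
    rows, cols = w.shape
    assert rows % br == 0 and cols % bc == 0, "dims must divide the block size"
    # (rows, cols) -> (rows//br, br, cols//bc, bc) -> block-major ordering
    blocks = w.reshape(rows // br, br, cols // bc, bc).permute(0, 2, 1, 3)
    return blocks.contiguous().view(rows, cols)


w = torch.arange(32 * 32, dtype=torch.float32).reshape(32, 32)
print(tile_weight(w).shape)  # torch.Size([32, 32]); same data, block-major order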
Returns: diff --git a/vllm/model_executor/layers/fused_moe/routing_simulator.py b/vllm/model_executor/layers/fused_moe/routing_simulator.py index 8758a570b3c6..af20f4b7c1d2 100644 --- a/vllm/model_executor/layers/fused_moe/routing_simulator.py +++ b/vllm/model_executor/layers/fused_moe/routing_simulator.py @@ -50,9 +50,7 @@ class DistributionBasedRouting(RoutingStrategy): distributions for testing different routing patterns. """ - def __init__(self, - distribution: str = "uniform", - **distribution_params: Any): + def __init__(self, distribution: str = "uniform", **distribution_params: Any): """ Initialize distribution-based routing. @@ -76,8 +74,10 @@ def _validate_distribution_params(self): valid_distributions = ["uniform", "normal"] if self.distribution not in valid_distributions: - raise ValueError(f"Unsupported distribution: {self.distribution}. " - f"Supported distributions: {valid_distributions}") + raise ValueError( + f"Unsupported distribution: {self.distribution}. " + f"Supported distributions: {valid_distributions}" + ) # Set default parameters if not provided if self.distribution == "normal": @@ -112,12 +112,12 @@ def route_tokens( indices_type = torch.long # Generate expert IDs based on the specified distribution - topk_ids = self._sample_expert_ids(num_tokens, num_experts, top_k, - hidden_states.device, indices_type) + topk_ids = self._sample_expert_ids( + num_tokens, num_experts, top_k, hidden_states.device, indices_type + ) # Generate weights based on the distribution - topk_weights = self._generate_weights(num_tokens, top_k, - hidden_states.device) + topk_weights = self._generate_weights(num_tokens, top_k, hidden_states.device) return topk_weights, topk_ids @@ -145,7 +145,8 @@ def _sample_expert_ids( # For normal distribution, sample continuous values and map to # expert IDs continuous_samples = self._sample_continuous_distribution( - num_tokens, top_k, device) + num_tokens, top_k, device + ) # Map continuous samples to expert indices # Normalize to [0, 1] range and scale to [0, num_experts) @@ -158,8 +159,9 @@ def _sample_expert_ids( else: raise ValueError(f"Unsupported distribution: {self.distribution}") - def _sample_continuous_distribution(self, num_tokens: int, top_k: int, - device: torch.device) -> torch.Tensor: + def _sample_continuous_distribution( + self, num_tokens: int, top_k: int, device: torch.device + ) -> torch.Tensor: """Sample from continuous distributions.""" shape = (num_tokens, top_k) @@ -170,7 +172,8 @@ def _sample_continuous_distribution(self, num_tokens: int, top_k: int, else: raise ValueError( - f"Unsupported continuous distribution: {self.distribution}") + f"Unsupported continuous distribution: {self.distribution}" + ) def _normalize_samples(self, samples: torch.Tensor) -> torch.Tensor: """Normalize samples to [0, 1] range.""" @@ -179,11 +182,13 @@ def _normalize_samples(self, samples: torch.Tensor) -> torch.Tensor: return torch.sigmoid(samples) else: - raise ValueError(f"Unsupported distribution for normalization: " - f"{self.distribution}") + raise ValueError( + f"Unsupported distribution for normalization: {self.distribution}" + ) - def _generate_weights(self, num_tokens: int, top_k: int, - device: torch.device) -> torch.Tensor: + def _generate_weights( + self, num_tokens: int, top_k: int, device: torch.device + ) -> torch.Tensor: """Generate weights based on the distribution.""" if self.distribution == "uniform": # All-ones weights for uniform distribution @@ -197,7 +202,8 @@ def _generate_weights(self, num_tokens: int, top_k: int, # For normal 
distribution, generate weights from the same # distribution continuous_weights = self._sample_continuous_distribution( - num_tokens, top_k, device) + num_tokens, top_k, device + ) # Normalize to positive values and sum to 1 weights = torch.abs(continuous_weights) weights = weights / weights.sum(dim=-1, keepdim=True) @@ -205,14 +211,14 @@ def _generate_weights(self, num_tokens: int, top_k: int, else: raise ValueError( - f"Unsupported distribution for weight generation: " - f"{self.distribution}") + f"Unsupported distribution for weight generation: {self.distribution}" + ) def get_distribution_info(self) -> dict: """Get information about the current distribution configuration.""" return { "distribution": self.distribution, - "parameters": self.distribution_params.copy() + "parameters": self.distribution_params.copy(), } @@ -228,10 +234,12 @@ class RoutingSimulator: # Class-level registry of routing strategies _routing_strategies: dict[str, RoutingStrategy] = { # Basic routing strategies - "uniform_random": - DistributionBasedRouting(distribution="uniform", mean=0.0, std=1.0), - "normal_routing": - DistributionBasedRouting(distribution="normal", mean=0.0, std=1.0), + "uniform_random": DistributionBasedRouting( + distribution="uniform", mean=0.0, std=1.0 + ), + "normal_routing": DistributionBasedRouting( + distribution="normal", mean=0.0, std=1.0 + ), } @classmethod @@ -280,7 +288,8 @@ def simulate_routing( raise ValueError( f"Unknown routing strategy: {strategy_name}. " f"Available strategies: " - f"{list(RoutingSimulator._routing_strategies.keys())}") + f"{list(RoutingSimulator._routing_strategies.keys())}" + ) strategy = RoutingSimulator._routing_strategies[strategy_name] return strategy.route_tokens( diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py index fb398eec119f..e725a0f00363 100644 --- a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py +++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py @@ -19,7 +19,7 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize does the weight-application + reduction as part of the pplx combine kernel. But the BatchedPrepareAndFinalize needs an implementation. To facilitate - this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate + this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate so the PrepareAndFinalize implementations could choose how to weight + reduce. """ @@ -27,12 +27,18 @@ class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): def __eq__(self, other): return isinstance(other, TopKWeightAndReduceDelegate) - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: - raise RuntimeError("The caller is expected to choose an appropriate " - "TopKWeightAndReduce implementation.") + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: + raise RuntimeError( + "The caller is expected to choose an appropriate " + "TopKWeightAndReduce implementation." 
+ ) class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): @@ -44,10 +50,14 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): def __eq__(self, other): return isinstance(other, TopKWeightAndReduceNoOP) - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: # Weight application and reduction operations are already done. if output is None: return fused_expert_output @@ -57,7 +67,8 @@ def apply(self, output: Optional[torch.Tensor], assert output.size() == fused_expert_output.size(), ( "output shape is expected to match the fused_expert_output shape. " f"But got output={output.size()}, " - f"used_expert_output={fused_expert_output.size()}") + f"used_expert_output={fused_expert_output.size()}" + ) output.copy_(fused_expert_output, non_blocking=True) return output @@ -71,11 +82,14 @@ class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce): def __eq__(self, other): return isinstance(other, TopKWeightAndReduceContiguous) - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: - + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: m, num_topk = topk_ids.size() k = fused_expert_output.size(-1) if fused_expert_output.ndim == 2: @@ -83,17 +97,21 @@ def apply(self, output: Optional[torch.Tensor], assert fused_expert_output.size() == (m, num_topk, k), ( f"Expected fused_expert_output size {(m, num_topk, k)}. But got " - f"{fused_expert_output.size()}") + f"{fused_expert_output.size()}" + ) if not apply_router_weight_on_input: fused_expert_output.mul_(topk_weights.view(m, -1, 1)) if output is None: - output = torch.empty((m, k), - device=fused_expert_output.device, - dtype=fused_expert_output.dtype) + output = torch.empty( + (m, k), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype, + ) assert output.size() == (m, k), ( - f"Expected output size {(m, k)}. But got {output.size()}") + f"Expected output size {(m, k)}. 
But got {output.size()}" + ) ops.moe_sum(fused_expert_output, output) return output @@ -109,27 +127,35 @@ def __init__(self, rank: int): self.rank = rank def __eq__(self, other): - return (isinstance(other, TopKWeightAndReduceNaiveBatched) - and (other.rank == self.rank)) - - def apply(self, output: Optional[torch.Tensor], - fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> torch.Tensor: + return isinstance(other, TopKWeightAndReduceNaiveBatched) and ( + other.rank == self.rank + ) + + def apply( + self, + output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> torch.Tensor: assert fused_expert_output.ndim == 3 num_tokens = topk_ids.size(0) num_local_experts = fused_expert_output.size(0) K = fused_expert_output.size(-1) if output is None: - output = torch.zeros((num_tokens, K), - device=fused_expert_output.device, - dtype=fused_expert_output.dtype) + output = torch.zeros( + (num_tokens, K), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype, + ) else: output.fill_(0) assert output.size() == (num_tokens, K), ( - f"Expected output size {(num_tokens, K)}, but got {output.size()}") + f"Expected output size {(num_tokens, K)}, but got {output.size()}" + ) first_expert = num_local_experts * self.rank last_expert = first_expert + num_local_experts diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 3de80ff85747..bb1c70dc3895 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -7,15 +7,16 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) -from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - deep_gemm_block_shape) + DeepGemmExperts, + _valid_deep_gemm, + _valid_deep_gemm_shape, +) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import deep_gemm_block_shape from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, quant_config: FusedMoEQuantConfig, @@ -25,33 +26,40 @@ def __init__( self.triton_expert = TritonExperts(quant_config) - self.allow_deep_gemm = (allow_deep_gemm - and self.quant_config.use_fp8_w8a8 and - self.block_shape == deep_gemm_block_shape()) + self.allow_deep_gemm = ( + allow_deep_gemm + and self.quant_config.use_fp8_w8a8 + and self.block_shape == deep_gemm_block_shape() + ) - self.deep_gemm_expert = DeepGemmExperts( - self.quant_config) if self.allow_deep_gemm else None + self.deep_gemm_expert = ( + DeepGemmExperts(self.quant_config) if self.allow_deep_gemm else None + ) @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - assert (self.deep_gemm_expert is None - or self.triton_expert.activation_formats - == self.deep_gemm_expert.activation_formats) + assert ( + self.deep_gemm_expert is None + or self.triton_expert.activation_formats + == self.deep_gemm_expert.activation_formats + ) return self.triton_expert.activation_formats 
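# Editor's note: a simplified, self-contained sketch of the dispatch pattern
# TritonOrDeepGemmExperts uses -- hold two backend implementations and pick one
# per call from a shape predicate, falling back when the preferred path does
# not apply. The names `FallbackDispatcher` and `shape_ok` are hypothetical;
# the real class delegates to DeepGemmExperts/TritonExperts and also consults
# the e8m0 path, which is omitted here.
from typing import Callable


class FallbackDispatcher:
    def __init__(self, preferred, fallback, shape_ok: Callable[[int, int, int], bool]):
        self.preferred = preferred   # e.g. a DeepGemm-backed expert implementation
        self.fallback = fallback     # e.g. a Triton-backed expert implementation
        self.shape_ok = shape_ok     # predicate deciding whether preferred applies

    def apply(self, m: int, n: int, k: int):
        impl = self.preferred if self.shape_ok(m, n, k) else self.fallback
        return impl(m, n, k)


# Toy usage: prefer the "fast" path only when every dim is a multiple of 128.
disp = FallbackDispatcher(
    preferred=lambda m, n, k: f"deep_gemm({m}x{n}x{k})",
    fallback=lambda m, n, k: f"triton({m}x{n}x{k})",
    shape_ok=lambda m, n, k: m % 128 == 0 and n % 128 == 0 and k % 128 == 0,
)
print(disp.apply(256, 256, 256))  # routes to the preferred backend
print(disp.apply(100, 256, 256))  # falls back to the other backend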
def supports_chunking(self) -> bool: dge = self.deep_gemm_expert te = self.triton_expert - return ((dge is None or dge.supports_chunking()) - and (te is None or te.supports_chunking())) + return (dge is None or dge.supports_chunking()) and ( + te is None or te.supports_chunking() + ) def supports_expert_map(self) -> bool: dge = self.deep_gemm_expert te = self.triton_expert - return ((dge is None or dge.supports_expert_map()) - and (te is None or te.supports_expert_map())) + return (dge is None or dge.supports_expert_map()) and ( + te is None or te.supports_expert_map() + ) def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: dge = self.deep_gemm_expert @@ -64,7 +72,8 @@ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: if is_dge_war and is_te_war: assert dge_war == te_war, ( "Both implementations should agree on WeightAndReduce impls. " - f"Got dge_war: {dge_war}, and te_war: {te_war}") + f"Got dge_war: {dge_war}, and te_war: {te_war}" + ) if dge_war is not None: return dge_war @@ -87,17 +96,33 @@ def workspace_shapes( # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. - if self.allow_deep_gemm and (is_deep_gemm_e8m0_used() - or _valid_deep_gemm_shape(M, N, K)): + if self.allow_deep_gemm and ( + is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K) + ): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( - a, aq, M, N, K, topk, global_num_experts, local_num_experts, - expert_tokens_meta) + a, + aq, + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) else: - return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk, - global_num_experts, - local_num_experts, - expert_tokens_meta) + return self.triton_expert.workspace_shapes( + a, + aq, + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) def apply( self, @@ -117,9 +142,9 @@ def apply( expert_tokens_meta: Optional[mk.ExpertTokensMetadata], apply_router_weight_on_input: bool, ): - use_deep_gemm = (self.allow_deep_gemm - and (_valid_deep_gemm(hidden_states, w1, w2) - or is_deep_gemm_e8m0_used())) + use_deep_gemm = self.allow_deep_gemm and ( + _valid_deep_gemm(hidden_states, w1, w2) or is_deep_gemm_e8m0_used() + ) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert assert experts is not None diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 05ed93c942c8..8eb724a7435f 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -5,15 +5,17 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEQuantConfig) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceNoOP) + TopKWeightAndReduceNoOP, +) from vllm.utils import next_power_of_2 class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( self, moe: FusedMoEConfig, @@ -32,10 +34,12 @@ def __init__( @property def activation_formats( - self + self, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - return (mk.FusedMoEActivationFormat.Standard, 
- mk.FusedMoEActivationFormat.Standard) + return ( + mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard, + ) def supports_chunking(self) -> bool: return True @@ -66,8 +70,7 @@ def workspace_shapes( output = (M, K) return (workspace1, workspace2, output, a.dtype) - def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, - local_num_experts: int): + def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int): # Number of tokens in the input tensor. num_tokens = x.shape[0] # Factor to account for the imbalance of the experts. @@ -117,75 +120,49 @@ def apply( x_quant = hidden_states x_scale = a1q_scale if x_scale is not None: - x_scale = x_scale.view(torch.float8_e4m3fn).reshape( - *x_quant.shape[:-1], -1) + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x_quant.shape[:-1], -1) packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( - torch.bfloat16).view(torch.int16) + torch.bfloat16 + ).view(torch.int16) assert self.w1_scale is not None assert self.w2_scale is not None kwargs = { - "topk_ids": - packed_tensor, - "routing_bias": - None, - "hidden_states": - x_quant, - "hidden_states_scale": - x_scale, - "gemm1_weights": - w1, - "gemm1_weights_scale": - self.w1_scale, - "gemm1_bias": - self.w1_bias, - "gemm1_alpha": - self.gemm1_alpha, - "gemm1_beta": - self.gemm1_beta, - "gemm1_clamp_limit": - self.gemm1_clamp_limit, - "gemm2_weights": - w2, - "gemm2_weights_scale": - self.w2_scale, - "gemm2_bias": - self.w2_bias, - "output1_scale_scalar": - None, - "output1_scale_gate_scalar": - None, - "output2_scale_scalar": - None, - "num_experts": - global_num_experts, - "top_k": - topk, - "n_group": - None, - "topk_group": - None, - "intermediate_size": - intermediate_size, - "local_expert_offset": - local_expert_offset, - "local_num_experts": - local_num_experts, - "routed_scaling_factor": - None, - "tile_tokens_dim": - self._get_tile_tokens_dim(x_quant, topk, local_num_experts), - "routing_method_type": - 1, - "do_finalize": - True, - "output": - output, - "tune_max_num_tokens": - self.max_capture_size, + "topk_ids": packed_tensor, + "routing_bias": None, + "hidden_states": x_quant, + "hidden_states_scale": x_scale, + "gemm1_weights": w1, + "gemm1_weights_scale": self.w1_scale, + "gemm1_bias": self.w1_bias, + "gemm1_alpha": self.gemm1_alpha, + "gemm1_beta": self.gemm1_beta, + "gemm1_clamp_limit": self.gemm1_clamp_limit, + "gemm2_weights": w2, + "gemm2_weights_scale": self.w2_scale, + "gemm2_bias": self.w2_bias, + "output1_scale_scalar": None, + "output1_scale_gate_scalar": None, + "output2_scale_scalar": None, + "num_experts": global_num_experts, + "top_k": topk, + "n_group": None, + "topk_group": None, + "intermediate_size": intermediate_size, + "local_expert_offset": local_expert_offset, + "local_num_experts": local_num_experts, + "routed_scaling_factor": None, + "tile_tokens_dim": self._get_tile_tokens_dim( + x_quant, topk, local_num_experts + ), + "routing_method_type": 1, + "do_finalize": True, + "output": output, + "tune_max_num_tokens": self.max_capture_size, } from flashinfer import trtllm_fp4_block_scale_routed_moe + trtllm_fp4_block_scale_routed_moe(**kwargs) return output diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index 678942e568d8..8dc57e5d0ee4 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -7,13 +7,16 @@ from vllm import _custom_ops as ops from 
vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) from vllm.model_executor.layers.quantization.utils.int8_utils import ( - per_token_group_quant_int8, per_token_quant_int8) + per_token_group_quant_int8, + per_token_quant_int8, +) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - quant_dequant_mxfp4) -from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( - mxfp8_quantize) + quant_dequant_mxfp4, +) +from vllm.model_executor.layers.quantization.utils.mxfp8_utils import mxfp8_quantize from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv @@ -21,26 +24,28 @@ @triton.jit -def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts, - topk_numel, expert_map, - HAS_EXPERT_MAP: tl.constexpr, - BLOCK_SIZE: tl.constexpr): - +def _count_expert_num_tokens( + topk_ids_ptr, + expert_num_tokens_ptr, + num_experts, + topk_numel, + expert_map, + HAS_EXPERT_MAP: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): curr_expert = tl.program_id(0) offsets = tl.arange(0, BLOCK_SIZE) topk_ids_ptrs = topk_ids_ptr + offsets - acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32) + acc = tl.zeros((BLOCK_SIZE,), dtype=tl.int32) for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)): mask = offsets < (topk_numel - x * BLOCK_SIZE) expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1) if HAS_EXPERT_MAP: expert_map_ptrs = expert_map + expert_ids expert_map_mask = expert_ids >= 0 - expert_ids = tl.load(expert_map_ptrs, - mask=expert_map_mask, - other=-1) + expert_ids = tl.load(expert_map_ptrs, mask=expert_map_mask, other=-1) has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0) acc = acc + has_curr_expert @@ -51,8 +56,8 @@ def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts, def count_expert_num_tokens( - topk_ids: torch.Tensor, num_local_experts: int, - expert_map: Optional[torch.Tensor]) -> torch.Tensor: + topk_ids: torch.Tensor, num_local_experts: int, expert_map: Optional[torch.Tensor] +) -> torch.Tensor: """ Count the number to tokens assigned to each expert. @@ -68,17 +73,16 @@ def count_expert_num_tokens( A tensor of size num_local_experts, where tensor[i] holds the number of tokens assigned to the ith expert. """ - assert topk_ids.dtype.is_signed, ( - "The kernel uses -1 to represent invalid topk_ids") - expert_num_tokens = torch.empty((num_local_experts), - device=topk_ids.device, - dtype=torch.int32) + assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids" + expert_num_tokens = torch.empty( + (num_local_experts), device=topk_ids.device, dtype=torch.int32 + ) grid = num_local_experts BLOCK_SIZE = min(topk_ids.numel(), 1024) BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE) - _count_expert_num_tokens[(grid, )]( + _count_expert_num_tokens[(grid,)]( topk_ids, expert_num_tokens, num_local_experts, @@ -96,9 +100,10 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - assert prod(v) <= x.numel( - ), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" # CUDAGRAPH unfriendly? - return x.flatten()[:prod(v)].view(*v) + assert prod(v) <= x.numel(), ( + f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" + ) # CUDAGRAPH unfriendly? 
+ return x.flatten()[: prod(v)].view(*v) def _fp4_quantize( @@ -106,9 +111,7 @@ def _fp4_quantize( A_scale: Optional[torch.Tensor], is_sf_swizzled_layout: bool, ) -> tuple[torch.Tensor, torch.Tensor]: - return fp4_quantize(A, - A_scale, - is_sf_swizzled_layout=is_sf_swizzled_layout) + return fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout) def _fp8_quantize( @@ -125,7 +128,8 @@ def _fp8_quantize( # TODO(luka): use QuantFP8 custom op # https://github.com/vllm-project/vllm/issues/20711 A, A_scale = ops.scaled_fp8_quant( - A, A_scale, use_per_token_if_dynamic=per_act_token) + A, A_scale, use_per_token_if_dynamic=per_act_token + ) else: assert not per_act_token assert len(block_shape) == 2 @@ -151,8 +155,7 @@ def _int8_quantize( # activations apply per-token quantization. Otherwise, assume # activation tensor-wise fp8/int8 quantization, dynamic or static if block_shape is None: - assert per_act_token, \ - "int8 quantization only supports block or channel-wise" + assert per_act_token, "int8 quantization only supports block or channel-wise" A, A_scale = per_token_quant_int8(A) else: assert not per_act_token @@ -204,9 +207,7 @@ def moe_kernel_quantize_input( elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == "nvfp4": - return _fp4_quantize(A, - A_scale, - is_sf_swizzled_layout=is_fp4_scale_swizzled) + return _fp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled) elif quant_dtype == "mxfp4": return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == "mxfp8": @@ -225,8 +226,7 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: return m[idx, ...] -def normalize_scales_shape( - scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: +def normalize_scales_shape(scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]: if scales is not None: if scales.numel() == 1: scales = scales.view(1, 1) @@ -242,8 +242,9 @@ def normalize_batched_scales_shape( if scales is not None and scales.ndim < 3: if scales.numel() == 1: scales = scales.view(1) - scales = torch.repeat_interleave(scales, num_experts, - dim=0).view(num_experts, 1, 1) + scales = torch.repeat_interleave(scales, num_experts, dim=0).view( + num_experts, 1, 1 + ) else: scales = scales.view(num_experts, -1, scales.size(-1)) @@ -263,7 +264,8 @@ def _validate_scale_shape( assert a_scale.numel() == 1, f"{a_scale.shape}" elif per_act_token_quant: assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, ( - f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1") + f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1" + ) else: assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 363245daa89d..6a49ae42ca89 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Custom normalization layers.""" + from typing import Optional, Union import torch @@ -14,13 +15,14 @@ def is_rocm_aiter_rmsnorm_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER_RMSNORM \ - and envs.VLLM_ROCM_USE_AITER + return envs.VLLM_ROCM_USE_AITER_RMSNORM and envs.VLLM_ROCM_USE_AITER -def rms_norm(x: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> torch.Tensor: +def rms_norm( + x: 
torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: from vllm import _custom_ops as ops + out = torch.empty_like(x) ops.rms_norm( out, @@ -32,9 +34,13 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, def fused_add_rms_norm( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops + ops.fused_add_rms_norm( x, residual, @@ -44,9 +50,11 @@ def fused_add_rms_norm( return x, residual -def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, - variance_epsilon: float) -> torch.Tensor: +def poly_norm( + x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: from vllm import _custom_ops as ops + out = torch.empty_like(x) ops.poly_norm( out, @@ -58,9 +66,11 @@ def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, return out -def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> torch.Tensor: +def rocm_aiter_rms_norm_impl( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: import aiter as rocm_aiter + if x.dim() > 2: x_original_shape = x.shape x = x.reshape(-1, x_original_shape[-1]) @@ -71,9 +81,11 @@ def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor, def rocm_aiter_rmsnorm2d_fwd_with_add_impl( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: - + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: import aiter as rocm_aiter residual_out = torch.empty_like(residual) @@ -89,14 +101,18 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_impl( return output, residual_out -def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> torch.Tensor: +def rocm_aiter_rms_norm_fake( + x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float +) -> torch.Tensor: return torch.empty_like(x) def rocm_aiter_rmsnorm2d_fwd_with_add_fake( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + variance_epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(x), torch.empty_like(residual) @@ -116,7 +132,8 @@ def rocm_aiter_rmsnorm2d_fwd_with_add_fake( def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype): use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [ - torch.float16, torch.bfloat16 + torch.float16, + torch.bfloat16, ] if use_aiter and with_fused_add: @@ -150,8 +167,9 @@ def __init__( self.hidden_size = hidden_size self.variance_epsilon = eps - self.variance_size_override = (None if var_hidden_size == hidden_size - else var_hidden_size) + self.variance_size_override = ( + None if var_hidden_size == hidden_size else var_hidden_size + ) self.has_weight = has_weight if dtype is not None: self.weight = torch.ones(hidden_size, dtype=dtype) @@ -163,9 +181,11 @@ def __init__( if current_platform.is_rocm(): self.rocm_norm_func = dispatch_rocm_rmsnorm_func( - with_fused_add=False, dtype=weight_dtype) + with_fused_add=False, dtype=weight_dtype + ) self.rocm_norm_func_with_add = 
dispatch_rocm_rmsnorm_func( - with_fused_add=True, dtype=weight_dtype) + with_fused_add=True, dtype=weight_dtype + ) def forward_native( self, @@ -181,8 +201,10 @@ def forward_native( hidden_size = x.shape[-1] if hidden_size != self.hidden_size: - raise ValueError("Expected hidden_size to be " - f"{self.hidden_size}, but found: {hidden_size}") + raise ValueError( + "Expected hidden_size to be " + f"{self.hidden_size}, but found: {hidden_size}" + ) if self.variance_size_override is None: x_var = x @@ -190,9 +212,10 @@ def forward_native( if hidden_size < self.variance_size_override: raise ValueError( "Expected hidden_size to be at least " - f"{self.variance_size_override}, but found: {hidden_size}") + f"{self.variance_size_override}, but found: {hidden_size}" + ) - x_var = x[:, :, :self.variance_size_override] + x_var = x[:, :, : self.variance_size_override] variance = x_var.pow(2).mean(dim=-1, keepdim=True) @@ -215,8 +238,9 @@ def forward_cuda( add_residual = residual is not None if add_residual: - return fused_add_rms_norm(x, residual, self.weight.data, - self.variance_epsilon) + return fused_add_rms_norm( + x, residual, self.weight.data, self.variance_epsilon + ) else: return rms_norm(x, self.weight.data, self.variance_epsilon) @@ -230,11 +254,11 @@ def forward_hip( add_residual = residual is not None if add_residual: - return self.rocm_norm_func_with_add(x, residual, self.weight.data, - self.variance_epsilon) + return self.rocm_norm_func_with_add( + x, residual, self.weight.data, self.variance_epsilon + ) else: - return self.rocm_norm_func(x, self.weight.data, - self.variance_epsilon) + return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon) def forward_xpu( self, @@ -294,10 +318,7 @@ def forward_static( """PyTorch-native implementation equivalent to forward().""" orig_dtype = x.dtype if residual is not None: - if orig_dtype == torch.float16: - x = x + residual.float() - else: - x = x + residual + x = x + residual.float() if orig_dtype == torch.float16 else x + residual residual = x x = x.float() @@ -315,8 +336,7 @@ def forward_native( residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """PyTorch-native implementation equivalent to forward().""" - return self.forward_static(self.weight.data, self.variance_epsilon, x, - residual) + return self.forward_static(self.weight.data, self.variance_epsilon, x, residual) def forward_cuda( self, @@ -328,7 +348,8 @@ def forward_cuda( if not getattr(self, "_is_compiled", False): self.forward_static = torch.compile( # type: ignore - self.forward_static) + self.forward_static + ) self._is_compiled = True return self.forward_native(x, residual) @@ -352,8 +373,7 @@ def __init__( self.variance_epsilon = eps def _norm(self, x): - return x / torch.sqrt( - x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) + return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon) def forward_native( self, @@ -366,9 +386,12 @@ def forward_native( orig_dtype = x.dtype x_float = x.to(torch.float32) - output = (self.weight[0] * self._norm(x_float**3) + - self.weight[1] * self._norm(x_float**2) + - self.weight[2] * self._norm(x_float) + self.bias) + output = ( + self.weight[0] * self._norm(x_float**3) + + self.weight[1] * self._norm(x_float**2) + + self.weight[2] * self._norm(x_float) + + self.bias + ) return output.to(orig_dtype) def forward_cuda( @@ -391,5 +414,6 @@ def __init__(self, dim: int, eps: float = 1e-6): self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) def 
forward(self, x: torch.Tensor): - return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias, - self.eps).type_as(x) + return F.layer_norm( + x.float(), (self.dim,), self.weight, self.bias, self.eps + ).type_as(x) diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py index 0b87acc85120..e874301b02c0 100644 --- a/vllm/model_executor/layers/lightning_attn.py +++ b/vllm/model_executor/layers/lightning_attn.py @@ -9,9 +9,21 @@ @triton.jit -def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, - d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, - NUM_BLOCK, CBLOCK: tl.constexpr): +def _fwd_diag_kernel( + Q, + K, + V, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + CBLOCK: tl.constexpr, +): # This kernel computes the diagonal blocks of the attention matrix # Each diagonal block represents attention # where queries attend to keys in the same block @@ -39,18 +51,36 @@ def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, o_cblock_offset = cblock_offset * e # Calculate pointers to the query, key, value, and output tensors - Q_block_ptr = (Q + qk_offset + qk_block_offset + q_cblock_offset + - tl.arange(0, CBLOCK)[:, None] * d + - tl.arange(0, d)[None, :]) - K_trans_block_ptr = (K + qk_offset + qk_block_offset + - tl.arange(0, CBLOCK)[None, :] * d + - tl.arange(0, d)[:, None]) - V_block_ptr = (V + v_offset + v_block_offset + - tl.arange(0, CBLOCK)[:, None] * e + - tl.arange(0, e)[None, :]) - O_block_ptr = (Out + o_offset + o_block_offset + o_cblock_offset + - tl.arange(0, CBLOCK)[:, None] * e + - tl.arange(0, e)[None, :]) + Q_block_ptr = ( + Q + + qk_offset + + qk_block_offset + + q_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * d + + tl.arange(0, d)[None, :] + ) + K_trans_block_ptr = ( + K + + qk_offset + + qk_block_offset + + tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, d)[:, None] + ) + V_block_ptr = ( + V + + v_offset + + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :] + ) + O_block_ptr = ( + Out + + o_offset + + o_block_offset + + o_cblock_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, e)[None, :] + ) # Load the decay rate for the current head S_block_ptr = S + off_h @@ -60,9 +90,9 @@ def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n, q_index = tl.arange(0, CBLOCK) + i * CBLOCK # Load query values - q = tl.load(Q_block_ptr, - mask=block_offset + q_index[:, None] < n, - other=0.0).to(tl.float32) + q = tl.load(Q_block_ptr, mask=block_offset + q_index[:, None] < n, other=0.0).to( + tl.float32 + ) # Initialize output accumulator qkv = tl.zeros([CBLOCK, e], dtype=tl.float32) @@ -146,18 +176,30 @@ def _fwd_kv_parallel( kv_offset = off_bh * NUM_BLOCK * d * e # Calculate pointers to the key, value, and key-value tensors - K_trans_block_ptr = (K + k_offset + k_block_offset + - tl.arange(0, CBLOCK)[None, :] * d + - tl.arange(0, D_FBLOCK)[:, None]) - V_block_ptr = (V + v_offset + v_block_offset + - tl.arange(0, CBLOCK)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) - KV_block_ptr = (KV + kv_offset + kv_block_offset + - tl.arange(0, D_FBLOCK)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) + K_trans_block_ptr = ( + K + + k_offset + + k_block_offset + + tl.arange(0, CBLOCK)[None, :] * d + + tl.arange(0, D_FBLOCK)[:, None] + ) + V_block_ptr = ( + V + + v_offset + + v_block_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, 
E_FBLOCK)[None, :] + ) + KV_block_ptr = ( + KV + + kv_offset + + kv_block_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :] + ) # Load the decay factors for the current head and block - k_decay_ptr = (K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :]) + k_decay_ptr = K_decay + off_h * BLOCK + tl.arange(0, CBLOCK)[None, :] kv_index = tl.arange(0, CBLOCK) @@ -165,10 +207,7 @@ def _fwd_kv_parallel( kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32) # Handle the last block which might be smaller than BLOCK - if off_block == NUM_BLOCK - 1: - split_n = n - (NUM_BLOCK - 1) * BLOCK - else: - split_n = BLOCK + split_n = n - (NUM_BLOCK - 1) * BLOCK if off_block == NUM_BLOCK - 1 else BLOCK left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK) k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK @@ -177,12 +216,16 @@ def _fwd_kv_parallel( for j in range(num_blocks): left_bound = (1 - j) * left_shift # Load key and value, handling boundary conditions - k_trans = tl.load(K_trans_block_ptr - left_shift * d, - mask=kv_index[None, :] >= left_bound, - other=0.0) - v = tl.load(V_block_ptr - left_shift * e, - mask=kv_index[:, None] >= left_bound, - other=0.0) + k_trans = tl.load( + K_trans_block_ptr - left_shift * d, + mask=kv_index[None, :] >= left_bound, + other=0.0, + ) + v = tl.load( + V_block_ptr - left_shift * e, + mask=kv_index[:, None] >= left_bound, + other=0.0, + ) # Load decay factor and compute weighted key-value outer product k_decay = tl.load(k_decay_ptr) @@ -198,9 +241,20 @@ def _fwd_kv_parallel( @triton.jit -def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n, - d: tl.constexpr, e: tl.constexpr, BLOCK: tl.constexpr, - NUM_BLOCK, D_FBLOCK: tl.constexpr, E_FBLOCK: tl.constexpr): +def _fwd_kv_reduce( + S, + KV, + KV_HISTORY, + b: tl.constexpr, + h: tl.constexpr, + n, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK, + D_FBLOCK: tl.constexpr, + E_FBLOCK: tl.constexpr, +): # This kernel reduces the key-value outer products # across blocks and updates the KV history off_bh = tl.program_id(0) # batch-head index @@ -209,8 +263,12 @@ def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n, kv_offset = off_bh * NUM_BLOCK * d * e # Calculate pointer to the key-value tensor - KV_block_ptr = (KV + kv_offset + tl.arange(0, D_FBLOCK)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) + KV_block_ptr = ( + KV + + kv_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :] + ) # Load the decay rate for the current head s_ptrs = S + off_h @@ -218,9 +276,12 @@ def _fwd_kv_reduce(S, KV, KV_HISTORY, b: tl.constexpr, h: tl.constexpr, n, # Calculate pointer to the key-value history tensor kv_history_offset = off_bh * d * e - KV_HISTORY_block_ptr = (KV_HISTORY + kv_history_offset + - tl.arange(0, D_FBLOCK)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) + KV_HISTORY_block_ptr = ( + KV_HISTORY + + kv_history_offset + + tl.arange(0, D_FBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :] + ) # Load the previous key-value history kv_pre = tl.load(KV_HISTORY_block_ptr).to(tl.float32) @@ -283,12 +344,18 @@ def _fwd_none_diag_kernel( kv_offset = off_bh * NUM_BLOCK * d * e + off_n * d * e + e_offset # Calculate pointers to the query, output, and key-value tensors - Q_block_ptr = (Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + - tl.arange(0, d)[None, :]) - O_block_ptr = (Out + o_offset + tl.arange(0, CBLOCK)[:, None] * e + - 
tl.arange(0, E_FBLOCK)[None, :]) - KV_block_ptr = (KV + kv_offset + tl.arange(0, d)[:, None] * e + - tl.arange(0, E_FBLOCK)[None, :]) + Q_block_ptr = ( + Q + q_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :] + ) + O_block_ptr = ( + Out + + o_offset + + tl.arange(0, CBLOCK)[:, None] * e + + tl.arange(0, E_FBLOCK)[None, :] + ) + KV_block_ptr = ( + KV + kv_offset + tl.arange(0, d)[:, None] * e + tl.arange(0, E_FBLOCK)[None, :] + ) # Load the decay rate for the current head S_block_ptr = S + off_h @@ -301,8 +368,7 @@ def _fwd_none_diag_kernel( q_index = block_offset + tl.arange(0, CBLOCK) # Load query values - q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, - other=0.).to(tl.float32) + q = tl.load(Q_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32) # Compute decay factors for the current sub-block q_decay = tl.exp(-s.to(tl.float32) * (off_c * CBLOCK + c_array[:, None])) @@ -311,20 +377,18 @@ def _fwd_none_diag_kernel( qkv_none_diag = tl.dot(q, kv) * q_decay # Load diagonal attention output (computed by _fwd_diag_kernel) - qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, - other=0.).to(tl.float32) + qkv_diag = tl.load(O_block_ptr, mask=q_index[:, None] < n, other=0.0).to(tl.float32) # Combine diagonal and non-diagonal attention outputs qkv = qkv_diag + qkv_none_diag # Store the result - tl.store(O_block_ptr, - qkv.to(O_block_ptr.dtype.element_ty), - mask=q_index[:, None] < n) + tl.store( + O_block_ptr, qkv.to(O_block_ptr.dtype.element_ty), mask=q_index[:, None] < n + ) class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, s, kv_history): # Forward pass of the lightning attention algorithm @@ -336,8 +400,10 @@ def forward(ctx, q, k, v, s, kv_history): # Check CUDA compute capability capability = torch.cuda.get_device_capability() if capability[0] < 8: - raise RuntimeError("Flash attention currently only supported", - "for compute capability >= 80") + raise RuntimeError( + "Flash attention currently only supported", + "for compute capability >= 80", + ) # Get input dimensions b, h, n, d = q.shape @@ -360,19 +426,21 @@ def forward(ctx, q, k, v, s, kv_history): # Step 1: Compute diagonal blocks of attention grid = (b * h * NUM_BLOCK, NUM_CBLOCK) - _fwd_diag_kernel[grid](q, - k, - v, - o, - s, - b, - h, - n, - d, - e, - BLOCK=BLOCK, - NUM_BLOCK=NUM_BLOCK, - CBLOCK=CBLOCK) + _fwd_diag_kernel[grid]( + q, + k, + v, + o, + s, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + CBLOCK=CBLOCK, + ) # Set feature block sizes NUM_FBLOCK = 1 @@ -386,9 +454,7 @@ def forward(ctx, q, k, v, s, kv_history): assert BLOCK % CBLOCK == 0, "BLOCK must be a multiple of CBLOCK" # Step 2: Compute key-value outer products for each block in parallel - kv = torch.empty((b, h, NUM_BLOCK, d, e), - dtype=torch.float32, - device=q.device) + kv = torch.empty((b, h, NUM_BLOCK, d, e), dtype=torch.float32, device=q.device) grid = (b * h, NUM_BLOCK) _fwd_kv_parallel[grid]( k, @@ -412,18 +478,20 @@ def forward(ctx, q, k, v, s, kv_history): # Step 3: Reduce key-value outer products # across blocks and update KV history grid = (b * h, NUM_FBLOCK) - _fwd_kv_reduce[grid](s, - kv, - kv_history, - b, - h, - n, - d, - e, - BLOCK=BLOCK, - NUM_BLOCK=NUM_BLOCK, - D_FBLOCK=D_FBLOCK, - E_FBLOCK=E_FBLOCK) + _fwd_kv_reduce[grid]( + s, + kv, + kv_history, + b, + h, + n, + d, + e, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + D_FBLOCK=D_FBLOCK, + E_FBLOCK=E_FBLOCK, + ) # Step 4: Compute non-diagonal blocks of attention grid = (b * h, NUM_BLOCK * NUM_CBLOCK) @@ 
-461,12 +529,12 @@ def lightning_attention( v: torch.Tensor, ed: torch.Tensor, block_size: int = 256, - kv_history: Optional[torch.Tensor] = None + kv_history: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Apply lightning attention algorithm + Apply lightning attention algorithm to compute attention efficiently. - + Args: q: Query tensor of shape [batch, heads, seq_len, dim] k: Key tensor of shape [batch, heads, seq_len, dim] @@ -474,7 +542,7 @@ def lightning_attention( ed: Decay rate tensor of shape [heads] block_size: Size of blocks for block-sparse attention kv_history: Optional key-value history from previous computations - + Returns: output: Attention output kv: Updated key-value history @@ -496,9 +564,9 @@ def lightning_attention( # Initialize or clone key-value history if kv_history is None: - kv_history = torch.zeros((q.shape[0], q.shape[1], d, e), - dtype=torch.float32, - device=q.device) + kv_history = torch.zeros( + (q.shape[0], q.shape[1], d, e), dtype=torch.float32, device=q.device + ) else: kv_history = kv_history.clone().contiguous() @@ -533,7 +601,7 @@ def _linear_attn_decode_kernel( ): """ Kernel for linear attention decoding with KV cache. - + This kernel computes attention for a single token using the KV cache. """ pid_b = tl.program_id(0) # batch index @@ -556,8 +624,9 @@ def _linear_attn_decode_kernel( # Calculate offsets for dimensions qk_d_offsets = tl.arange(0, D) v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE - cache_d_offsets = qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[ - None, :] * cache_d1_stride + cache_d_offsets = ( + qk_d_offsets[:, None] * cache_d0_stride + v_d_offsets[None, :] * cache_d1_stride + ) # Calculate offsets for the current batch and head q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride @@ -605,7 +674,7 @@ def linear_decode_forward_triton( ) -> torch.Tensor: """ Perform linear attention decoding using Triton kernels. 
- + Args: q: Query tensor of shape [B, H, 1, D] k: Key tensor of shape [B, H, 1, D] @@ -614,7 +683,7 @@ def linear_decode_forward_triton( slope_rate: Decay rate tensor slot_idx: Slot indices for batches BLOCK_SIZE: Size of blocks for processing - + Returns: output: Attention output tensor """ diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 04a5db07e95c..3881ba12faa0 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -9,25 +9,30 @@ import torch.nn as nn from torch.nn.parameter import Parameter, UninitializedParameter -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.utils import dispatch_unquantized_gemm -# yapf: disable -from vllm.model_executor.parameter import (BasevLLMParameter, - BlockQuantScaleParameter, - ModelWeightParameter, - PackedColumnParameter, - PackedvLLMParameter, - PerTensorScaleParameter, - RowvLLMParameter) -# yapf: enable +from vllm.model_executor.parameter import ( + BasevLLMParameter, + BlockQuantScaleParameter, + ModelWeightParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import GiB_bytes @@ -62,8 +67,7 @@ def adjust_bitblas_shard(param, shard_size, shard_offset): bitblas_tile_size = getattr(param, "bitblas_tile_size", None) if bitblas_tile_size is not None: - return (shard_size // bitblas_tile_size, - shard_offset // bitblas_tile_size) + return (shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size) return shard_size, shard_offset @@ -76,9 +80,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -def adjust_bitsandbytes_4bit_shard(param: Parameter, - shard_offsets: dict[str, tuple[int, int]], - loaded_shard_id: str) -> tuple[int, int]: +def adjust_bitsandbytes_4bit_shard( + param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str +) -> tuple[int, int]: """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" total, _ = shard_offsets["total"] @@ -94,8 +98,8 @@ def adjust_bitsandbytes_4bit_shard(param: Parameter, def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): """For fused modules (QKV and MLP) we have an array of length N that holds 1 scale for each "logical" matrix. So the param - is an array of length N. The loaded_weight corresponds to - one of the shards on disk. Here, we slice the param based on + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on the shard_id for loading. 
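    # Illustrative aside: a minimal sketch of the slicing described above, with
    # a hypothetical fused per-shard scale tensor (names and values invented
    # here only to show the indexing pattern for q/k/v shards).
    #
    #     import torch
    #
    #     fused_scales = torch.tensor([0.5, 0.25, 0.125])  # one scale per logical matrix
    #     shard_index = {"q": 0, "k": 1, "v": 2}
    #
    #     def scale_for_shard(shard_id: str) -> torch.Tensor:
    #         # Keep only the slice of the length-N parameter that matches the
    #         # shard currently being loaded from disk.
    #         idx = shard_index[shard_id]
    #         return fused_scales[idx : idx + 1]
    #
    #     print(scale_for_shard("k"))  # tensor([0.2500])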
""" qkv_idxs = {"q": 0, "k": 1, "v": 2} @@ -122,13 +126,13 @@ def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]): For example, given bnb weight attributes as below: { - 'bnb_shard_offsets': array([0, 4, 8, 16]), + 'bnb_shard_offsets': array([0, 4, 8, 16]), 'bnb_quant_state': {0: ..., 1: ..., 2: ...}, } The function will return: { - 'bnb_shard_offsets': array([0, 4]), + 'bnb_shard_offsets': array([0, 4]), 'bnb_quant_state': {0: ...}, } and @@ -143,8 +147,7 @@ def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]): quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]} quant_state_r = { i - 1: bnb_weight_attrs["bnb_quant_state"][i] - for i in range(1, - len(shard_offsets) - 1) + for i in range(1, len(shard_offsets) - 1) } left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l) right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r) @@ -155,18 +158,23 @@ class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - """Create weights for a linear layer. + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for a linear layer. The weights will be set as attributes of the layer. Args: layer: The layer that is using the LinearMethodBase factory. input_size_per_partition: Size of the weight input dim on rank X. - output_partition_sizes: Sizes of the output dim of each logical + output_partition_sizes: Sizes of the output dim of each logical weight on rank X. E.g., output_partition_sizes for QKVLinear is a list contains the width of Wq, Wk, Wv on rank X. input_size: Size of the input dim of the weight across all ranks. @@ -176,10 +184,12 @@ def create_weights(self, layer: torch.nn.Module, raise NotImplementedError @abstractmethod - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """Apply the weights in layer to the input tensor. Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -188,51 +198,63 @@ def apply(self, class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization.""" - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # This method creates unquantized linear weights. # The weights are not quantized, and they are not sharded. # The amount of memory allocated for the weights is # sum(output_partition_sizes) * input_size_per_partition. 
try: weight_loader = extra_weight_attrs.pop("weight_loader") - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) except torch.cuda.OutOfMemoryError as e: logger.error("Failed to create unquantized linear weights: %s", e) if torch.cuda.is_available(): logger.debug("CUDA device: %s", torch.cuda.current_device()) - logger.debug("Allocated: %.2f GiB", - torch.cuda.memory_allocated() / GiB_bytes) - logger.debug("Reserved: %.2f GiB", - torch.cuda.memory_reserved() / GiB_bytes) + logger.debug( + "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes + ) + logger.debug( + "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes + ) raise RuntimeError( "Failed to create unquantized linear weights. " "This may be caused by insufficient memory to allocate " - "the weight.") from e + "the weight." + ) from e layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if current_platform.is_cpu(): - from vllm.model_executor.layers.utils import ( - dispatch_cpu_unquantized_gemm) - dispatch_cpu_unquantized_gemm(layer, remove_weight=True) + from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + dispatch_cpu_unquantized_gemm(layer, remove_weight=True) + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) @@ -274,17 +296,13 @@ def __init__( self.quant_config = quant_config self.prefix = prefix if quant_config is None: - self.quant_method: Optional[ - QuantizeMethodBase] = UnquantizedLinearMethod() + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() else: - self.quant_method = quant_config.get_quant_method(self, - prefix=prefix) + self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias self.disable_tp = disable_tp - self.tp_rank = (get_tensor_model_parallel_rank() - if not disable_tp else 0) - self.tp_size = (get_tensor_model_parallel_world_size() - if not disable_tp else 1) + self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 + self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 def update_param_tp_status(self): for param in self.parameters(): @@ -329,32 +347,40 @@ def __init__( else: self.output_partition_sizes = [output_size] - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix=prefix, - return_bias=return_bias, - disable_tp=disable_tp) + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) # All the linear layer supports quant method. 
assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size, - self.output_partition_sizes, - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + self, + self.input_size, + self.output_partition_sizes, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader, + ) if bias: self.bias = Parameter( - torch.empty(self.output_size, dtype=self.params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) + torch.empty(self.output_size, dtype=self.params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) else: self.register_parameter("bias", None) @@ -377,7 +403,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param.size() == loaded_weight.size(), ( f"Tried to load weights of size {loaded_weight.size()}" - f"to a parameter of size {param.size()}") + f"to a parameter of size {param.size()}" + ) param.data.copy_(loaded_weight) def forward( @@ -423,7 +450,7 @@ class ColumnParallelLinear(LinearBase): output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) + (e.g. model.layers.0.qkv_proj) return_bias: If true, return bias together with outputs in forward pass. disable_tp: If true, weights matrix won't be sharded through tp rank. """ @@ -444,28 +471,27 @@ def __init__( disable_tp: bool = False, ): # Divide the weight matrix along the last dimension. - self.tp_rank = (get_tensor_model_parallel_rank() - if not disable_tp else 0) - self.tp_size = (get_tensor_model_parallel_world_size() - if not disable_tp else 1) + self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 + self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 self.input_size_per_partition = input_size self.output_size_per_partition = divide(output_size, self.tp_size) self.output_partition_sizes = [self.output_size_per_partition] # If QKV or MergedColumn, use output size of each partition. 
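        # Illustrative aside: the per-rank arithmetic below, sketched with a
        # stand-in for vllm.distributed.divide and hypothetical sizes (a merged
        # gate/up projection at tp_size=4); values are invented for illustration.
        #
        #     def _divide(numerator: int, denominator: int) -> int:
        #         # Stand-in for the imported divide(): exact division or fail loudly.
        #         assert numerator % denominator == 0
        #         return numerator // denominator
        #
        #     tp_size = 4
        #     output_sizes = [11008, 11008]  # two logical matrices fused into one layer
        #
        #     # Each logical output is sharded independently, so the fused weight on
        #     # one rank concatenates the per-rank shard of every logical matrix.
        #     output_partition_sizes = [_divide(s, tp_size) for s in output_sizes]
        #     print(output_partition_sizes, sum(output_partition_sizes))  # [2752, 2752] 5504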
if hasattr(self, "output_sizes"): self.output_partition_sizes = [ - divide(output_size, self.tp_size) - for output_size in self.output_sizes + divide(output_size, self.tp_size) for output_size in self.output_sizes ] - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias, - disable_tp=disable_tp) + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) self.gather_output = gather_output @@ -481,22 +507,27 @@ def __init__( output_size=self.output_size, params_dtype=self.params_dtype, weight_loader=( - self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + ) if bias: self.bias = Parameter( - torch.empty(self.output_size_per_partition, - dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) + torch.empty(self.output_size_per_partition, dtype=params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) else: self.register_parameter("bias", None) self.update_param_tp_status() def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - output_dim = getattr(param, "output_dim", None) is_sharded_weight = getattr(param, "is_sharded_weight", False) @@ -516,16 +547,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): final_shape = list(loaded_weight.shape) if output_dim is not None: assert final_shape[output_dim] % self.tp_size == 0 - final_shape[output_dim] = (final_shape[output_dim] // - self.tp_size) + final_shape[output_dim] = final_shape[output_dim] // self.tp_size param.materialize(final_shape, dtype=loaded_weight.dtype) param_data = param.data if output_dim is not None and not is_sharded_weight: shard_size = param_data.shape[output_dim] start_idx = self.tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). @@ -535,8 +564,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def weight_loader_v2(self, param: BasevLLMParameter, - loaded_weight: torch.Tensor): + def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor): # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). 
if len(loaded_weight.shape) == 0: @@ -614,29 +642,29 @@ def __init__( disable_tp: bool = False, ): self.output_sizes = output_sizes - self.tp_size = (get_tensor_model_parallel_world_size() - if not disable_tp else 1) - self.tp_rank = (get_tensor_model_parallel_rank() - if not disable_tp else 0) - - assert all(output_size % self.tp_size == 0 - for output_size in output_sizes) - super().__init__(input_size=input_size, - output_size=sum(output_sizes), - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix, - return_bias=return_bias, - disable_tp=disable_tp) - - def weight_loader(self, - param: Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): + self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 + self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 + + assert all(output_size % self.tp_size == 0 for output_size in output_sizes) + super().__init__( + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None, + ): # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -647,20 +675,17 @@ def weight_loader(self, param.shard_weight_type[loaded_shard_id] = loaded_weight.item() else: param.shard_weight_type = { - i: loaded_weight.item() - for i, _ in enumerate(self.output_sizes) + i: loaded_weight.item() for i, _ in enumerate(self.output_sizes) } return if is_gguf_weight: - output_dim = getattr(param, "output_dim", None) shard_size = loaded_weight.size(output_dim) // self.tp_size start_idx = self.tp_rank * shard_size if loaded_shard_id is not None: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) @@ -677,14 +702,14 @@ def weight_loader(self, if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, 0) + param_data, loaded_weight, 0 + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) return current_shard_offset = 0 - use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", - False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) shard_offsets: list[tuple[int, int, int]] = [] for i, output_size in enumerate(self.output_sizes): shard_offsets.append((i, current_shard_offset, output_size)) @@ -699,10 +724,12 @@ def weight_loader(self, shard_offset = shard_offset // param.packed_factor # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) shard_size, shard_offset = adjust_bitblas_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) if use_bitsandbytes_4bit: index = list(itertools.accumulate([0] + self.output_sizes)) @@ -712,17 +739,18 @@ def weight_loader(self, } orig_offsets["total"] = (self.output_size, 0) shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( - param, orig_offsets, str(shard_id)) + param, orig_offsets, str(shard_id) + ) loaded_weight_shard = loaded_weight.narrow( - output_dim, shard_offset, shard_size) + output_dim, shard_offset, shard_size + ) self.weight_loader(param, loaded_weight_shard, shard_id) return assert loaded_shard_id < len(self.output_sizes) if output_dim is not None: - shard_offset = (sum(self.output_sizes[:loaded_shard_id]) // - self.tp_size) + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size # Special case for quantization. # If quantized, we need to adjust the offset and size to account @@ -733,12 +761,13 @@ def weight_loader(self, shard_offset = shard_offset // param.packed_factor # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) shard_size, shard_offset = adjust_bitblas_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) - use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", - False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) is_sharded_weight = getattr(param, "is_sharded_weight", False) # bitsandbytes loads the weights of the specific portion # no need to narrow @@ -746,19 +775,17 @@ def weight_loader(self, if use_bitsandbytes_4bit: shard_size = loaded_weight.shape[output_dim] - shard_offset = loaded_weight.shape[output_dim] * \ - loaded_shard_id + shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) + param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = self.tp_rank * shard_size if not is_sharded_weight: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for per-tensor scales in fused case. elif needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, loaded_shard_id) + param_data, loaded_weight, loaded_shard_id + ) else: ignore_warning = getattr(param, "ignore_warning", False) @@ -766,13 +793,15 @@ def weight_loader(self, logger.warning( "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " - "the same for all partitions.") + "the same for all partitions." + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, - loaded_weight: torch.Tensor): + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): """ Handle special case for models where MLP layers are already fused on disk. In this case, we have no shard id. This function @@ -793,25 +822,28 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, # Special case for Quantization. 
# If quantized, we need to adjust the offset and size to account # for the packing. - if isinstance(param, (PackedColumnParameter, PackedvLLMParameter - )) and param.packed_dim == param.output_dim: - shard_size, shard_offset = \ - param.adjust_shard_indexes_for_packing( - shard_size=shard_size, shard_offset=shard_offset) - - loaded_weight_shard = loaded_weight.narrow(param.output_dim, - shard_offset, - shard_size) + if ( + isinstance(param, (PackedColumnParameter, PackedvLLMParameter)) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) self.weight_loader_v2(param, loaded_weight_shard, shard_id) - def weight_loader_v2(self, - param: BasevLLMParameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None): + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None, + ): if loaded_shard_id is None: if isinstance(param, PerTensorScaleParameter): - param.load_merged_column_weight(loaded_weight=loaded_weight, - shard_id=0) + param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_merged_column_weight(loaded_weight=loaded_weight) @@ -830,20 +862,24 @@ def weight_loader_v2(self, assert weight_block_size is not None block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = ( - (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // - block_n) // self.tp_size - shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // - block_n // self.tp_size) + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n + ) // self.tp_size + shard_size = ( + (self.output_sizes[loaded_shard_id] + block_n - 1) + // block_n + // self.tp_size + ) else: - shard_offset = sum( - self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size shard_size = self.output_sizes[loaded_shard_id] // self.tp_size - param.load_merged_column_weight(loaded_weight=loaded_weight, - shard_id=loaded_shard_id, - shard_offset=shard_offset, - shard_size=shard_size, - tp_rank=self.tp_rank) + param.load_merged_column_weight( + loaded_weight=loaded_weight, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + tp_rank=self.tp_rank, + ) class QKVParallelLinear(ColumnParallelLinear): @@ -896,42 +932,43 @@ def __init__( total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. 
- tp_size = (get_tensor_model_parallel_world_size() - if not disable_tp else 1) + tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: self.num_kv_heads = 1 - self.num_kv_head_replicas = divide(tp_size, - self.total_num_kv_heads) + self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) else: self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_head_replicas = 1 input_size = self.hidden_size - output_size = (self.num_heads + - 2 * self.num_kv_heads) * tp_size * self.head_size + output_size = ( + (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + ) self.output_sizes = [ self.num_heads * self.head_size * tp_size, # q_proj self.num_kv_heads * self.head_size * tp_size, # k_proj - self.num_kv_heads * self.head_size * tp_size, # v_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj ] - super().__init__(input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=False, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix, - return_bias=return_bias, - disable_tp=disable_tp) + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) def _get_shard_offset_mapping(self, loaded_shard_id: str): shard_offset_mapping = { "q": 0, "k": self.num_heads * self.head_size, "v": (self.num_heads + self.num_kv_heads) * self.head_size, - "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size + "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size, } return shard_offset_mapping.get(loaded_shard_id) @@ -943,10 +980,11 @@ def _get_shard_size_mapping(self, loaded_shard_id: str): } return shard_size_mapping.get(loaded_shard_id) - def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, - loaded_weight: torch.Tensor): + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): """ - Handle special case for models where QKV layers are already + Handle special case for models where QKV layers are already fused on disk. In this case, we have no shard id. This function determines the shard id by splitting these layers and then calls the weight loader using the shard id. @@ -957,41 +995,49 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, shard_offsets = [ # (shard_id, shard_offset, shard_size) ("q", 0, self.total_num_heads * self.head_size), - ("k", self.total_num_heads * self.head_size, - self.total_num_kv_heads * self.head_size), - ("v", - (self.total_num_heads + self.total_num_kv_heads) * self.head_size, - self.total_num_kv_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), ] for shard_id, shard_offset, shard_size in shard_offsets: # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. 
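            # Illustrative aside: the q/k/v offsets built above follow directly
            # from the head layout; hypothetical GQA sizes (32 query heads,
            # 8 KV heads, head_size 128) give the arithmetic below.
            #
            #     total_num_heads, total_num_kv_heads, head_size = 32, 8, 128
            #
            #     shard_layout = [
            #         # (shard_id, offset into the fused output dim, shard size)
            #         ("q", 0, total_num_heads * head_size),
            #         ("k", total_num_heads * head_size, total_num_kv_heads * head_size),
            #         ("v", (total_num_heads + total_num_kv_heads) * head_size,
            #          total_num_kv_heads * head_size),
            #     ]
            #     for shard_id, offset, size in shard_layout:
            #         print(shard_id, offset, size)
            #     # q 0 4096
            #     # k 4096 1024
            #     # v 5120 1024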
- if isinstance(param, (PackedColumnParameter, PackedvLLMParameter - )) and param.packed_dim == param.output_dim: - shard_size, shard_offset = \ - param.adjust_shard_indexes_for_packing( - shard_size=shard_size, shard_offset=shard_offset) - - loaded_weight_shard = loaded_weight.narrow(param.output_dim, - shard_offset, - shard_size) + if ( + isinstance(param, (PackedColumnParameter, PackedvLLMParameter)) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) self.weight_loader_v2(param, loaded_weight_shard, shard_id) - def weight_loader_v2(self, - param: BasevLLMParameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): if loaded_shard_id is None: # special case for certain models if isinstance(param, PerTensorScaleParameter): - param.load_qkv_weight(loaded_weight=loaded_weight, - shard_id=0, - tp_rank=self.tp_rank) + param.load_qkv_weight( + loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank + ) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): - param.load_qkv_weight(loaded_weight=loaded_weight, - tp_rank=self.tp_rank) + param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank) return # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) @@ -1013,18 +1059,21 @@ def weight_loader_v2(self, shard_offset = (shard_offset + block_n - 1) // block_n shard_size = (shard_size + block_n - 1) // block_n - param.load_qkv_weight(loaded_weight=loaded_weight, - num_heads=self.num_kv_head_replicas, - shard_id=loaded_shard_id, - shard_offset=shard_offset, - shard_size=shard_size, - tp_rank=self.tp_rank) - - def weight_loader(self, - param: Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): + param.load_qkv_weight( + loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + tp_rank=self.tp_rank, + ) + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) @@ -1035,10 +1084,7 @@ def weight_loader(self, param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) param.shard_weight_type[loaded_shard_id] = loaded_weight.item() else: - param.shard_weight_type = { - k: loaded_weight.item() - for k in idx_map - } + param.shard_weight_type = {k: loaded_weight.item() for k in idx_map} return if is_gguf_weight: @@ -1047,8 +1093,7 @@ def weight_loader(self, start_idx = self.tp_rank * shard_size if loaded_shard_id is not None: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) @@ -1066,7 +1111,8 @@ def weight_loader(self, if output_dim is None: if needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, 0) + param_data, loaded_weight, 0 + ) 
assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -1074,13 +1120,18 @@ def weight_loader(self, shard_offsets = [ # (shard_id, shard_offset, shard_size) ("q", 0, self.total_num_heads * self.head_size), - ("k", self.total_num_heads * self.head_size, - self.total_num_kv_heads * self.head_size), - ("v", (self.total_num_heads + self.total_num_kv_heads) * - self.head_size, self.total_num_kv_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), ] - use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", - False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: @@ -1093,27 +1144,35 @@ def weight_loader(self, # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.total_num_heads * self.head_size), - "k": (self.total_num_heads * self.head_size, - self.total_num_kv_heads * self.head_size), - "v": - ((self.total_num_heads + self.total_num_kv_heads) * - self.head_size, - self.total_num_kv_heads * self.head_size), - "total": - ((self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_size, 0) + "k": ( + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + "v": ( + (self.total_num_heads + self.total_num_kv_heads) + * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + "total": ( + (self.total_num_heads + 2 * self.total_num_kv_heads) + * self.head_size, + 0, + ), } shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( - param, orig_qkv_offsets, shard_id) + param, orig_qkv_offsets, shard_id + ) loaded_weight_shard = loaded_weight.narrow( - output_dim, shard_offset, shard_size) + output_dim, shard_offset, shard_size + ) self.weight_loader(param, loaded_weight_shard, shard_id) return @@ -1128,8 +1187,7 @@ def weight_loader(self, shard_offset = self.num_heads * self.head_size shard_size = self.num_kv_heads * self.head_size elif loaded_shard_id == "v": - shard_offset = (self.num_heads + - self.num_kv_heads) * self.head_size + shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account @@ -1141,10 +1199,10 @@ def weight_loader(self, # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset) + param, shard_size, shard_offset + ) - use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", - False) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) is_sharded_weight = getattr(param, "is_sharded_weight", False) # bitsandbytes loads the weights of the specific portion # no need to narrow @@ -1153,20 +1211,24 @@ def weight_loader(self, if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), - "k": (self.num_heads * self.head_size, - self.num_kv_heads * self.head_size), - "v": - ((self.num_heads + self.num_kv_heads) * self.head_size, - self.num_kv_heads * self.head_size), - "total": - ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, - 0) + "k": ( + self.num_heads * self.head_size, + self.num_kv_heads * self.head_size, + ), + "v": ( + (self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size, + ), + "total": ( + (self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0, + ), } shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( - param, orig_qkv_offsets, loaded_shard_id) + param, orig_qkv_offsets, loaded_shard_id + ) - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) + param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": shard_id = self.tp_rank else: @@ -1174,20 +1236,21 @@ def weight_loader(self, start_idx = shard_id * shard_size if not is_sharded_weight: - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for per-tensor scales in fused case. elif needs_scalar_to_array: param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, loaded_shard_id) + param_data, loaded_weight, loaded_shard_id + ) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: logger.warning( "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " - "for all partitions.") + "for all partitions." + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -1243,22 +1306,22 @@ def __init__( disable_tp: bool = False, ): # Divide the weight matrix along the first dimension. 
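        # Illustrative aside: for the row-parallel layer below, the split runs
        # along the input dimension instead; a sketch of the contiguous slice
        # each rank would load (tp layout and sizes are hypothetical).
        #
        #     tp_size = 4
        #     input_size = 8192  # full input dim of, e.g., a down-projection
        #
        #     input_size_per_partition = input_size // tp_size  # 2048 columns per rank
        #     for tp_rank in range(tp_size):
        #         # Mirrors the narrow() in weight_loader: each rank keeps one
        #         # contiguous slice of the checkpoint weight's input dimension.
        #         start = tp_rank * input_size_per_partition
        #         print(f"rank {tp_rank}: cols [{start}, {start + input_size_per_partition})")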
- self.tp_rank = (get_tensor_model_parallel_rank() - if not disable_tp else 0) - self.tp_size = (get_tensor_model_parallel_world_size() - if not disable_tp else 1) + self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 + self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 self.input_size_per_partition = divide(input_size, self.tp_size) self.output_size_per_partition = output_size self.output_partition_sizes = [output_size] - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias, - disable_tp=disable_tp) + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp, + ) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -1272,19 +1335,26 @@ def __init__( output_size=self.output_size, params_dtype=self.params_dtype, weight_loader=( - self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + ) if not reduce_results and (bias and not skip_bias_add): - raise ValueError("When not reduce the results, adding bias to the " - "results can lead to incorrect results") + raise ValueError( + "When not reduce the results, adding bias to the " + "results can lead to incorrect results" + ) if bias: - self.bias = Parameter( - torch.empty(self.output_size, dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) + self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) else: self.register_parameter("bias", None) self.update_param_tp_status() @@ -1307,16 +1377,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if is_gguf_weight and isinstance(param, UninitializedParameter): weight_shape = list(loaded_weight.shape) if input_dim: - weight_shape[input_dim] = (weight_shape[input_dim] // - self.tp_size) + weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data if input_dim is not None and not is_sharded_weight: shard_size = param_data.shape[input_dim] start_idx = self.tp_rank * shard_size - loaded_weight = loaded_weight.narrow(input_dim, start_idx, - shard_size) + loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). @@ -1326,9 +1394,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def weight_loader_v2(self, param: BasevLLMParameter, - loaded_weight: torch.Tensor): - + def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor): # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). 
if len(loaded_weight.shape) == 0: @@ -1345,7 +1411,8 @@ def forward( input_parallel = input_ else: splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) + input_, num_partitions=self.tp_size + ) input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. @@ -1395,37 +1462,44 @@ class QKVCrossParallelLinear(LinearBase): (e.g. model.layers.0.qkv_proj) """ - def __init__(self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): # input_size and output_size are not used, just for alignment input_size = hidden_size output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size - super().__init__(input_size=input_size, - output_size=output_size, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix) + super().__init__( + input_size=input_size, + output_size=output_size, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + ) self.quant_config = quant_config # Empty placeholders for loading as a single module. placeholder_size = 0 assert self.quant_method is not None - self.quant_method.create_weights(self, - placeholder_size, [placeholder_size], - placeholder_size, - placeholder_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + self, + placeholder_size, + [placeholder_size], + placeholder_size, + placeholder_size, + self.params_dtype, + weight_loader=self.weight_loader, + ) # Use a dictionary to avoid submodules parameters auto-registration: # drop-in replacement for a `QKVParallelLinear` module. @@ -1437,7 +1511,8 @@ def __init__(self, quant_config=quant_config, skip_bias_add=skip_bias_add, params_dtype=params_dtype, - prefix=f"{prefix}.q_proj_decoder") + prefix=f"{prefix}.q_proj_decoder", + ) self.proj["kv_proj_encoder"] = QKVParallelLinear( hidden_size=hidden_size, @@ -1448,7 +1523,8 @@ def __init__(self, quant_config=quant_config, skip_bias_add=skip_bias_add, params_dtype=params_dtype, - prefix=f"{prefix}.kv_proj_encoder") + prefix=f"{prefix}.kv_proj_encoder", + ) # `kv_proj_encoder.num_kv_heads` accounts for sharding with tp>1. 
self.q_size = self.q_proj_decoder.output_size_per_partition @@ -1456,10 +1532,13 @@ def __init__(self, if bias: self.bias = torch.nn.Parameter() - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader_v1, - }) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader_v1, + }, + ) else: self.bias = None @@ -1474,9 +1553,7 @@ def q_proj_decoder(self) -> ColumnParallelLinear: for name, param in self.named_parameters(): target_param = getattr(layer, name, None) if target_param is not None: - self.sync_weight_attrs(param, - target_param, - mode="q_proj_decoder") + self.sync_weight_attrs(param, target_param, mode="q_proj_decoder") return layer @property @@ -1485,9 +1562,7 @@ def kv_proj_encoder(self) -> QKVParallelLinear: for name, param in self.named_parameters(): target_param = getattr(layer, name, None) if target_param is not None: - self.sync_weight_attrs(param, - target_param, - mode="kv_proj_encoder") + self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder") return layer def sync_weight_attrs( @@ -1498,15 +1573,14 @@ def sync_weight_attrs( ): missing_attrs_dict = { k: getattr(src_param, k) - for k in (set(vars(src_param).keys()) - - set(vars(tgt_param).keys())) + for k in (set(vars(src_param).keys()) - set(vars(tgt_param).keys())) } # TODO(Isotr0py): handle bitsandbytes 8bit - use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", - False) - if (missing_attrs_dict and use_bitsandbytes_4bit): + use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", False) + if missing_attrs_dict and use_bitsandbytes_4bit: q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard( - missing_attrs_dict) + missing_attrs_dict + ) if mode == "q_proj_decoder": set_weight_attrs(tgt_param, q_proj_attrs) elif mode == "kv_proj_encoder": @@ -1524,12 +1598,10 @@ def _is_same_param( key_to_ignore = ["weight_loader", "_weight_loader"] has_same_type_name = type(src_param) is type(map_param) src_param_attrs = { - k: v - for k, v in src_param.__dict__.items() if k not in key_to_ignore + k: v for k, v in src_param.__dict__.items() if k not in key_to_ignore } map_param_attrs = { - k: v - for k, v in map_param.__dict__.items() if k not in key_to_ignore + k: v for k, v in map_param.__dict__.items() if k not in key_to_ignore } has_same_attrs = src_param_attrs == map_param_attrs return has_same_type_name and has_same_attrs @@ -1540,12 +1612,11 @@ def select_proj_params( param: nn.Parameter, ) -> nn.Parameter: """ - Given the placeholder param, + Given the placeholder param, return the corresponding param in the proj layers. 
""" target_param_list = [ - v for _, v in layer.named_parameters() - if self._is_same_param(param, v) + v for _, v in layer.named_parameters() if self._is_same_param(param, v) ] assert len(target_param_list) == 1 target_param = target_param_list[0] @@ -1568,26 +1639,28 @@ def forward( # type: ignore[override] k, v = kv_enc.split(self.kv_size, dim=-1) return q, k, v - def weight_loader_v1(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): + def weight_loader_v1( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): # just like all other parameters, does not yet # support loading bias with weight_loader_v2 - layer = (self.q_proj_decoder - if loaded_shard_id == "q" else self.kv_proj_encoder) + layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder target_param = self.select_proj_params(layer, param) - shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () + shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else () layer.weight_loader(target_param, loaded_weight, *shard_id_args) - def weight_loader(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): - layer = (self.q_proj_decoder - if loaded_shard_id == "q" else self.kv_proj_encoder) + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): + layer = self.q_proj_decoder if loaded_shard_id == "q" else self.kv_proj_encoder target_param = self.select_proj_params(layer, param) - shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () + shard_id_args = (loaded_shard_id,) if loaded_shard_id != "q" else () if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED: layer.weight_loader_v2(target_param, loaded_weight, *shard_id_args) else: diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 2110aa2769b9..3db5e0b32553 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -1,15 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that compute logits from hidden_stats.""" + from typing import Optional import torch -from vllm.distributed import (tensor_model_parallel_all_gather, - tensor_model_parallel_gather) +from vllm.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_gather, +) from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.platforms import current_platform @@ -23,12 +25,14 @@ class LogitsProcessor(CustomOp): 3. Apply logits processors (if any). """ - def __init__(self, - vocab_size: int, - org_vocab_size: Optional[int] = None, - scale: float = 1.0, - logits_as_input: bool = False, - soft_cap: Optional[float] = None) -> None: + def __init__( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, + ) -> None: """ Args: scale: A scaling factor to apply to the logits. @@ -87,16 +91,14 @@ def _get_logits( embedding_bias: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. 
- logits = lm_head.quant_method.apply(lm_head, - hidden_states, - bias=embedding_bias) + logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias) # Gather logits for TP logits = self._gather_logits(logits) # Remove paddings in vocab (if any). if logits is not None: - logits = logits[..., :self.org_vocab_size] + logits = logits[..., : self.org_vocab_size] return logits def extra_repr(self) -> str: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 319133777992..99f05e2eca0e 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -19,16 +19,21 @@ from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.lightning_attn import ( - lightning_attention, linear_decode_forward_triton) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) + lightning_attention, + linear_decode_forward_triton, +) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata @@ -47,8 +52,7 @@ def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.tp_world = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() - self.weight = nn.Parameter(torch.ones(int(hidden_size / - self.tp_world))) + self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world))) self.weight.weight_loader = self.weight_loader self.variance_epsilon = eps @@ -75,8 +79,7 @@ def _forward( x = x.to(torch.float32) variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32) if self.tp_world > 1: - variance = tensor_model_parallel_all_reduce( - variance) / self.tp_world + variance = tensor_model_parallel_all_reduce(variance) / self.tp_world x = x * torch.rsqrt(variance + self.variance_epsilon) x = x.to(orig_dtype) * self.weight return x @@ -91,17 +94,17 @@ def forward( class MiniMaxText01LinearKernel: - @staticmethod - def jit_linear_forward_prefix(q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kv_caches: torch.Tensor, - slope_rate: torch.Tensor, - block_size: int, - layer_idx: Optional[int] = None, - **kwargs) -> torch.Tensor: - + def jit_linear_forward_prefix( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_caches: torch.Tensor, + slope_rate: torch.Tensor, + block_size: int, + layer_idx: Optional[int] = None, + **kwargs, + ) -> torch.Tensor: slope_rate = slope_rate.to(torch.float32) should_pad_dim = q.dim() == 3 if should_pad_dim: @@ -111,26 +114,22 @@ def jit_linear_forward_prefix(q: torch.Tensor, b, h, n, d = q.shape e = d kv_history = kv_caches.reshape(1, h, d, 
e).contiguous() - output, kv_history = lightning_attention(q, - k, - v, - slope_rate, - block_size=block_size, - kv_history=kv_history) + output, kv_history = lightning_attention( + q, k, v, slope_rate, block_size=block_size, kv_history=kv_history + ) kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e)) assert output.shape[0] == 1, "batch size must be 1" return rearrange(output.squeeze(0), "h n d -> n (h d)") class MiniMaxText01LinearAttention(nn.Module, MambaBase): - @property def mamba_type(self) -> str: return "linear_attention" def get_attn_backend(self) -> type["AttentionBackend"]: - from vllm.v1.attention.backends.linear_attn import ( - LinearAttentionBackend) + from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend + return LinearAttentionBackend def get_state_dtype(self) -> tuple[torch.dtype]: @@ -143,9 +142,8 @@ def get_state_dtype(self) -> tuple[torch.dtype]: def get_state_shape(self) -> tuple[tuple[int, int, int], ...]: return MambaStateShapeCalculator.linear_attention_state_shape( - num_heads=self.num_heads, - tp_size=self.tp_size, - head_dim=self.head_dim) + num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim + ) def __init__( self, @@ -209,16 +207,16 @@ def __init__( eps=1e-5, ) - slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( - self.num_heads) + slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(self.num_heads) if num_hidden_layer <= 1: self.slope_rate = slope_rate * (1 + 1e-5) else: - self.slope_rate = slope_rate * (1 - layer_idx / - (num_hidden_layer - 1) + 1e-5) - self.tp_slope = self.slope_rate[self.tp_rank * - self.tp_heads:(self.tp_rank + 1) * - self.tp_heads].contiguous() + self.slope_rate = slope_rate * ( + 1 - layer_idx / (num_hidden_layer - 1) + 1e-5 + ) + self.tp_slope = self.slope_rate[ + self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads + ].contiguous() compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: @@ -226,36 +224,36 @@ def __init__( compilation_config.static_forward_context[prefix] = self @staticmethod - def weight_direct_load(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: + def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight) return @staticmethod def _build_slope_tensor(n_attention_heads: int): - def get_slopes(n): - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + get_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) - - slopes = torch.tensor(get_slopes(n_attention_heads), - dtype=torch.float32).reshape( - n_attention_heads, 1, 1) + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + slopes = torch.tensor( + get_slopes(n_attention_heads), dtype=torch.float32 + ).reshape(n_attention_heads, 1, 1) return slopes - def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): + def _prefill_and_mix_infer( + self, q, k, v, kv_cache, state_indices_tensor, attn_metadata + ): hidden = [] for _prefill_idx in 
range(getattr(attn_metadata, "num_prefills", 0)): if _prefill_idx >= len(attn_metadata.query_start_loc): @@ -278,12 +276,13 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, slice_layer_cache, self.tp_slope, self.BLOCK, - layer_idx=self.layer_idx) + layer_idx=self.layer_idx, + ) hidden.append(out_slice.contiguous()) if attn_metadata.num_decode_tokens > 0: - hidden_decode = self._decode_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) + hidden_decode = self._decode_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) hidden.insert(0, hidden_decode) if not hidden: @@ -292,18 +291,19 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor, hidden = torch.concat(hidden, dim=0).contiguous() return hidden - def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, - attn_metadata): - q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[:attn_metadata.num_decodes] - hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope, - slot_id, 32) + def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): + q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[: attn_metadata.num_decodes] + hidden = linear_decode_forward_triton( + q, k, v, kv_cache, self.tp_slope, slot_id, 32 + ) return hidden - def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor) -> None: + def forward( + self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor + ) -> None: torch.ops.vllm.linear_attention( hidden_states, output, @@ -311,16 +311,18 @@ def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, self.prefix, ) - def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor) -> None: + def _forward( + self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor + ) -> None: forward_context = get_forward_context() attn_metadata: AttentionMetadata = forward_context.attn_metadata if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] assert isinstance(attn_metadata, LinearAttentionMetadata) - num_actual_tokens = attn_metadata.num_prefill_tokens + \ - attn_metadata.num_decode_tokens + num_actual_tokens = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) else: num_actual_tokens = hidden_states.shape[0] @@ -335,35 +337,39 @@ def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor, num_prefills = getattr(attn_metadata, "num_prefills", 0) if num_prefills > 0: - num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", - 0) + num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0) for prefill_idx in range(num_prefills): - q_start = attn_metadata.query_start_loc[num_decode_tokens + - prefill_idx] - q_end = attn_metadata.query_start_loc[num_decode_tokens + - prefill_idx + 1] + q_start = attn_metadata.query_start_loc[ + num_decode_tokens + prefill_idx + ] + q_end = attn_metadata.query_start_loc[ + num_decode_tokens + prefill_idx + 1 + ] query_len = q_end - q_start - context_len = attn_metadata.seq_lens[ - 
num_decode_tokens + prefill_idx] - query_len + context_len = ( + attn_metadata.seq_lens[num_decode_tokens + prefill_idx] + - query_len + ) if context_len == 0: - block_to_clear = state_indices_tensor[num_decode_tokens - + prefill_idx] + block_to_clear = state_indices_tensor[ + num_decode_tokens + prefill_idx + ] kv_cache[block_to_clear, ...] = 0 decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 if attn_metadata is None: - hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]), - device=q.device, - dtype=q.dtype) + hidden = torch.empty( + (q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype + ) else: if not decode_only: - hidden = self._prefill_and_mix_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) + hidden = self._prefill_and_mix_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) else: - hidden = self._decode_infer(q, k, v, kv_cache, - state_indices_tensor, - attn_metadata) + hidden = self._decode_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) hidden = self.norm._forward(hidden) gate, _ = self.output_gate(hidden_states[:num_actual_tokens]) hidden = F.sigmoid(gate) * hidden @@ -380,9 +386,7 @@ def linear_attention( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self._forward(hidden_states=hidden_states, - output=output, - positions=positions) + self._forward(hidden_states=hidden_states, output=output, positions=positions) def linear_attention_fake( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index d64854cdb381..8ab77965ae80 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -12,20 +12,30 @@ from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + causal_conv1d_fn, + causal_conv1d_update, +) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_scan_fn, selective_state_update) + selective_scan_fn, + selective_state_update, +) from vllm.model_executor.utils import set_weight_attrs from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata @@ -44,22 +54,24 @@ class MambaMixer(MambaBase, CustomOp): **selective** state spaces) """ - def __init__(self, - hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - time_step_rank: int, - use_conv_bias: bool, - use_bias: bool, - use_rms_norm: bool, - rms_norm_has_weight: 
bool = True, - rms_norm_eps: float = 1e-5, - activation="silu", - is_lora_enabled: bool = False, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + time_step_rank: int, + use_conv_bias: bool, + use_bias: bool, + use_rms_norm: bool, + rms_norm_has_weight: bool = True, + rms_norm_eps: float = 1e-5, + activation="silu", + is_lora_enabled: bool = False, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): super().__init__() self.time_step_rank = time_step_rank self.ssm_state_size = ssm_state_size @@ -80,9 +92,9 @@ def __init__(self, # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = MergedColumnParallelLinear(hidden_size, - [intermediate_size] * 2, - bias=use_bias) + self.in_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=use_bias + ) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( @@ -93,17 +105,18 @@ def __init__(self, # time step projection (discretization) - # In the forward we need to apply dt_proj without the bias, # as the bias is added in the selective scan kernel. - self.dt_proj = ColumnParallelLinear(time_step_rank, - intermediate_size, - bias=True, - skip_bias_add=True) + self.dt_proj = ColumnParallelLinear( + time_step_rank, intermediate_size, bias=True, skip_bias_add=True + ) def weight_loader(param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() param.data.copy_( - loaded_weight.data.split(loaded_weight.shape[0] // tp_size, - dim=0)[tp_rank]) + loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[ + tp_rank + ] + ) def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): weight_loader(param, -torch.exp(loaded_weight.float())) @@ -114,7 +127,8 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): intermediate_size // tp_size, ssm_state_size, dtype=torch.float32, - )) + ) + ) self.D = nn.Parameter(torch.ones(intermediate_size // tp_size)) set_weight_attrs(self.D, {"weight_loader": weight_loader}) @@ -127,23 +141,35 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): input_is_parallel=True, ) - self.dt_layernorm = RMSNorm( - time_step_rank, - eps=rms_norm_eps, - has_weight=rms_norm_has_weight, - ) if use_rms_norm else None + self.dt_layernorm = ( + RMSNorm( + time_step_rank, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) + if use_rms_norm + else None + ) - self.b_layernorm = RMSNorm( - ssm_state_size, - eps=rms_norm_eps, - has_weight=rms_norm_has_weight, - ) if use_rms_norm else None + self.b_layernorm = ( + RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) + if use_rms_norm + else None + ) - self.c_layernorm = RMSNorm( - ssm_state_size, - eps=rms_norm_eps, - has_weight=rms_norm_has_weight, - ) if use_rms_norm else None + self.c_layernorm = ( + RMSNorm( + ssm_state_size, + eps=rms_norm_eps, + has_weight=rms_norm_has_weight, + ) + if use_rms_norm + else None + ) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: @@ -157,7 +183,7 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): self.prefix = prefix def _ssm_transform( - 
self, x: torch.Tensor + self, x: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if self.is_lora_enabled: # Lora kernel requires contiguous tensor. @@ -167,7 +193,8 @@ def _ssm_transform( time_step, B, C = torch.split( ssm_params, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], - dim=-1) + dim=-1, + ) if self.use_rms_norm: assert self.dt_layernorm is not None assert self.b_layernorm is not None @@ -185,8 +212,7 @@ def forward(self, hidden_states: torch.Tensor, output: torch.Tensor): self.prefix, ) - def forward_native(self, hidden_states: torch.Tensor, - output: torch.Tensor): + def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor): pass def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): @@ -232,8 +258,9 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) hidden_states_BC, gate = projected_states.chunk(2, dim=-2) - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) if attn_metadata is None: # V1 profile run @@ -281,10 +308,12 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, - query_start_loc=query_start_loc_p) + query_start_loc=query_start_loc_p, + ) # 3. State Space Model sequence transformations. discrete_time_step_p, B_p, C_p = self._ssm_transform( - conv_out_p.transpose(-2, -1)) + conv_out_p.transpose(-2, -1) + ) time_proj_bias = self._time_proj_bias() # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) @@ -301,7 +330,8 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): delta_softplus=True, cache_indices=state_indices_tensor_p, has_initial_state=has_initial_states_p, - query_start_loc=query_start_loc_p) + query_start_loc=query_start_loc_p, + ) ssm_outputs.append(scan_out_p) if has_decode: @@ -312,39 +342,42 @@ def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor): conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d).transpose(0, 1) + conv_state_indices=state_indices_tensor_d, + ).transpose(0, 1) # 3. State Space Model sequence transformation. discrete_time_step_d, B_d, C_d = self._ssm_transform( - conv_out_d.transpose(-2, -1)) + conv_out_d.transpose(-2, -1) + ) time_proj_bias = self._time_proj_bias() # 4. 
Perform the recurrence y ← SSM(A, B, C, Δ)(x) - scan_outputs_d = torch.empty_like( - hidden_states_BC_d.transpose(0, 1)) - selective_state_update(ssm_state, - conv_out_d.transpose(0, 1), - discrete_time_step_d.transpose(0, 1), - self.A, - B_d, - C_d, - self.D, - gate_d.transpose(0, 1), - time_proj_bias, - dt_softplus=True, - state_batch_indices=state_indices_tensor_d, - out=scan_outputs_d) + scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1)) + selective_state_update( + ssm_state, + conv_out_d.transpose(0, 1), + discrete_time_step_d.transpose(0, 1), + self.A, + B_d, + C_d, + self.D, + gate_d.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=state_indices_tensor_d, + out=scan_outputs_d, + ) scan_outputs_d = scan_outputs_d.transpose(0, 1) ssm_outputs.insert(0, scan_outputs_d) - scan_outputs_combined = ssm_outputs[0] if len( - ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) + scan_outputs_combined = ( + ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1) + ) # 5. Final output projection if self.is_lora_enabled: # Lora kernel requires contiguous tensor. - scan_outputs_combined = scan_outputs_combined.transpose( - -2, -1).contiguous() + scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous() out = self.out_proj(scan_outputs_combined)[0] else: out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0] @@ -373,8 +406,8 @@ def mamba_type(self) -> str: return "mamba1" def get_attn_backend(self) -> type["AttentionBackend"]: - from vllm.v1.attention.backends.mamba1_attn import ( - Mamba1AttentionBackend) + from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend + return Mamba1AttentionBackend def _time_proj_bias(self) -> Optional[torch.Tensor]: @@ -406,27 +439,34 @@ def split_batch_to_prefill_and_decode( num_decodes: int, num_padded_decodes: int, ) -> PrefillDecodeSplit: - num_actual_tokens = num_prefill_tokens + num_padded_decodes # In v1, decode tokens come first, then prefill tokens. 
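The decode-first layout noted in the comment above is exactly what makes the `[num_padded_decodes, num_prefill_tokens]` split sizes in the following hunk line up. A tiny self-contained illustration with toy numbers (plain PyTorch, names chosen here for clarity, not vLLM's actual tensors):

import torch

# Toy packed batch: 2 (padded) decode tokens followed by 5 prefill tokens,
# features-first layout to match the dim=-1 split used in the hunk.
num_padded_decodes, num_prefill_tokens = 2, 5
packed = torch.arange(3 * 7).reshape(3, 7)

decode_part, prefill_part = torch.split(
    packed, [num_padded_decodes, num_prefill_tokens], dim=-1
)
assert decode_part.shape == (3, num_padded_decodes)    # first 2 token columns
assert prefill_part.shape == (3, num_prefill_tokens)   # remaining 5 columns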
hidden_states_BC_d, hidden_states_BC_p = torch.split( hidden_states_BC[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], - dim=-1) - gate_d, gate_p = torch.split(gate[..., :num_actual_tokens], - [num_padded_decodes, num_prefill_tokens], - dim=-1) + dim=-1, + ) + gate_d, gate_p = torch.split( + gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1 + ) # num_padded_decodes accounts for CUDA graph padding when applicable state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor[:num_padded_decodes + num_prefills], + state_indices_tensor[: num_padded_decodes + num_prefills], [num_padded_decodes, num_prefills], - dim=0) - query_start_loc_p = (query_start_loc[-num_prefills - 1:] - - num_padded_decodes if num_prefills > 0 else None) - has_initial_states_p = has_initial_states[-num_prefills:] if ( - has_initial_states is not None and num_prefills > 0) else None + dim=0, + ) + query_start_loc_p = ( + query_start_loc[-num_prefills - 1 :] - num_padded_decodes + if num_prefills > 0 + else None + ) + has_initial_states_p = ( + has_initial_states[-num_prefills:] + if (has_initial_states is not None and num_prefills > 0) + else None + ) return PrefillDecodeSplit( hidden_states_BC_p=hidden_states_BC_p, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index bfb0666d361f..7589905ac927 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -11,28 +11,40 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + causal_conv1d_fn, + causal_conv1d_update, +) from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_state_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - mamba_chunk_scan_combined_varlen) + mamba_chunk_scan_combined_varlen, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import ( - LoaderFunction, composed_weight_loader, sharded_weight_loader) + LoaderFunction, + composed_weight_loader, + sharded_weight_loader, +) from vllm.model_executor.utils import set_weight_attrs from 
vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata @@ -43,12 +55,13 @@ # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated @CustomOp.register("mixer2_gated_rms_norm") class Mixer2RMSNormGated(CustomOp): - - def __init__(self, - full_hidden_size: int, - full_n_groups: int, - use_rms_norm: bool = True, - eps: float = 1e-6): + def __init__( + self, + full_hidden_size: int, + full_n_groups: int, + use_rms_norm: bool = True, + eps: float = 1e-6, + ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() @@ -62,13 +75,13 @@ def __init__(self, if self.use_rms_norm: # Register norm weight only if we're actually applying RMSNorm self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) - set_weight_attrs(self.weight, - {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)}) else: # Avoid checkpoint mismatch by skipping unused parameter self.register_parameter("weight", None) - assert (self.full_hidden_size % self.tp_size == 0 - ), "Tensor parallel world size must divide hidden size." + assert self.full_hidden_size % self.tp_size == 0, ( + "Tensor parallel world size must divide hidden size." + ) def forward_native( self, @@ -111,8 +124,7 @@ def forward_native( group_count = hidden_dim // self.group_size x_grouped = x.view(*prefix_dims, group_count, self.group_size) variance = x_grouped.pow(2).mean(-1, keepdim=True) - x_grouped = x_grouped * torch.rsqrt(variance + - self.variance_epsilon) + x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon) x = x_grouped.view(*prefix_dims, hidden_dim) if redundant_tp: @@ -130,18 +142,19 @@ def forward_cuda( input_dtype = x.dtype if not self.use_rms_norm: # Keep gate in float32 for numerical stability during silu - return x * nn.functional.silu(gate.to( - torch.float32)).to(input_dtype) + return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype) - if (((self.n_groups % self.tp_size) != 0) or self.n_groups != 1): + if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1: return self.forward_native(x, gate) - return rms_norm_gated(x, - self.weight.data, - bias=None, - z=gate, - eps=self.variance_epsilon, - norm_before_gate=False) + return rms_norm_gated( + x, + self.weight.data, + bias=None, + z=gate, + eps=self.variance_epsilon, + norm_before_gate=False, + ) def mamba_v2_sharded_weight_loader( @@ -156,7 +169,6 @@ def mamba_v2_sharded_weight_loader( """ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: - # - track boundary of (sharded) param, and loaded_weight, respectively boundary, loaded_boundary = 0, 0 @@ -191,11 +203,12 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # seem to handle slices well. # https://github.com/python/mypy/issues/2410 param.data[ - boundary:(boundary + take), - ... # type: ignore[misc] - ] = loaded_weight[loaded_start_idx:(loaded_start_idx + - take) # type: ignore[misc] - ] # type: ignore[misc] + boundary : (boundary + take), ... 
# type: ignore[misc] + ] = loaded_weight[ + loaded_start_idx : ( + loaded_start_idx + take + ) # type: ignore[misc] + ] # type: ignore[misc] # move indexing boundaries boundary += shard_size @@ -217,23 +230,25 @@ class MambaMixer2(MambaBase, CustomOp): **selective** state spaces) """ - def __init__(self, - hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - use_conv_bias: bool, - use_bias: bool, - n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, - rms_norm_eps: float = 1e-5, - activation: str = "silu", - use_rms_norm: bool = True, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() # For TP, the sharding plan is as follows: @@ -253,15 +268,18 @@ def __init__(self, self.tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() - assert (num_heads % self.tp_size == 0 - ), "Tensor parallel world size must divide num heads." + assert num_heads % self.tp_size == 0, ( + "Tensor parallel world size must divide num heads." + ) assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( "If tensor parallel world size does not divide num_groups, " - "then num_groups must equal 1.") + "then num_groups must equal 1." + ) - assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \ - quant_config is None, ( + assert ( + (n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None + ), ( "Tensor parallel currently supported for quantized models only " "if tensor parallel world size divides num groups." 
) @@ -280,7 +298,8 @@ def __init__(self, # - but if n_groups cannot divide tp_size, we need to # extend some extra groups groups = MambaStateShapeCalculator.extra_groups_for_head_shards( - n_groups, self.tp_size) + n_groups, self.tp_size + ) self.n_groups = n_groups + groups self.groups_ssm_state_size = self.n_groups * self.ssm_state_size @@ -340,8 +359,7 @@ def __init__(self, # to the head shards group_shard_settings = ( self.groups_ssm_state_size, # expected model size - (self.n_groups - n_groups) * - self.ssm_state_size, # extra dims assigned + (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned n_groups == 1, # if there was only one group ) intermediate_settings = (intermediate_size, 0, False) @@ -355,8 +373,7 @@ def __init__(self, set_weight_attrs( self.conv1d.bias, { - "weight_loader": - mamba_v2_sharded_weight_loader( + "weight_loader": mamba_v2_sharded_weight_loader( [ intermediate_settings, group_shard_settings, @@ -372,8 +389,7 @@ def __init__(self, set_weight_attrs( self.conv1d.weight, { - "weight_loader": - mamba_v2_sharded_weight_loader( + "weight_loader": mamba_v2_sharded_weight_loader( [ intermediate_settings, group_shard_settings, @@ -391,8 +407,7 @@ def __init__(self, set_weight_attrs( self.in_proj.weight, { - "weight_loader": - mamba_v2_sharded_weight_loader( + "weight_loader": mamba_v2_sharded_weight_loader( [ intermediate_settings, # for gate intermediate_settings, @@ -418,17 +433,18 @@ def __init__(self, torch.empty( divide(num_heads, self.tp_size), dtype=torch.float32, - )) + ) + ) self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.use_rms_norm = use_rms_norm set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) a_weight_loader = composed_weight_loader( - sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + sharded_weight_loader(0), lambda x: -torch.exp(x.float()) + ) set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) - set_weight_attrs(self.dt_bias, - {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) self.out_proj = RowParallelLinear( intermediate_size, @@ -439,10 +455,9 @@ def __init__(self, prefix=f"{prefix}.out_proj", ) - self.norm = Mixer2RMSNormGated(intermediate_size, - n_groups, - self.use_rms_norm, - eps=rms_norm_eps) + self.norm = Mixer2RMSNormGated( + intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps + ) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: @@ -489,6 +504,9 @@ def forward_cuda( # stay the same and reused for all mamba layers in the same iteration attn_metadata: AttentionMetadata = forward_context.attn_metadata + assert self.cache_config is not None + mamba_block_size = self.cache_config.mamba_block_size + prefix_caching_enabled = self.cache_config.enable_prefix_caching if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] @@ -522,8 +540,9 @@ def forward_cuda( dim=-1, ) - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) # - get hidden_states, B and C after depthwise convolution. 
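Right after this comment the reformatted hunk defines `split_hidden_states_B_C_fn`, which carves the fused convolution output back into the hidden states plus the B and C SSM inputs. A minimal sketch of that three-way split with invented per-rank sizes (in the surrounding file the real sizes come from `intermediate_size` and `groups_ssm_state_size` divided by the TP world size):

import torch

# Toy per-rank sizes standing in for the real ones in this file.
intermediate_size = 8          # per-rank hidden slice
groups_ssm_state_size = 4      # per-rank n_groups * ssm_state_size
num_tokens = 6

fused = torch.randn(num_tokens, intermediate_size + 2 * groups_ssm_state_size)

hidden_states, B, C = torch.split(
    fused,
    [intermediate_size, groups_ssm_state_size, groups_ssm_state_size],
    dim=-1,
)
assert hidden_states.shape == (num_tokens, intermediate_size)
assert B.shape == C.shape == (num_tokens, groups_ssm_state_size)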
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( @@ -538,10 +557,10 @@ def forward_cuda( if attn_metadata is None: # profile run - hidden_states_B_C = (hidden_states_B_C.transpose( - 0, 1).clone().transpose(0, 1)).contiguous() - hidden_states, _B, _C = split_hidden_states_B_C_fn( - hidden_states_B_C) + hidden_states_B_C = ( + hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1) + ).contiguous() + hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C) hidden_states = self.norm(hidden_states, gate) out, _ = self.out_proj(hidden_states) return out @@ -573,12 +592,42 @@ def forward_cuda( dim=0, ) + if prefix_caching_enabled: + # If prefix caching is enabled, retrieve the relevant variables + # for prefill and decode + block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( + torch.split( + attn_metadata.block_idx_last_computed_token, + [num_decodes, num_prefills], + dim=0, + ) + ) + block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = ( + torch.split( + attn_metadata.block_idx_last_scheduled_token, + [num_decodes, num_prefills], + dim=0, + ) + ) + # Prefill-only variables: + block_idx_first_scheduled_token_p = ( + attn_metadata.block_idx_first_scheduled_token_p + ) + num_computed_tokens_p = attn_metadata.num_computed_tokens_p + else: + block_idx_last_computed_token_d = None + block_idx_last_computed_token_p = None + block_idx_last_scheduled_token_d = None + block_idx_last_scheduled_token_p = None + block_idx_first_scheduled_token_p = None + num_computed_tokens_p = None + # Preallocate output tensor to avoid memcpy cost for merging prefill # and decode outputs preallocated_ssm_out = torch.empty( [ num_prefill_tokens + num_decodes, - (self.num_heads // self.tp_size) * self.head_dim + (self.num_heads // self.tp_size) * self.head_dim, ], dtype=hidden_states.dtype, device=hidden_states.device, @@ -592,10 +641,21 @@ def forward_cuda( # Process prefill requests if has_prefill: # 2. Convolution sequence transformation - # - "cache_indices" updates the conv_state cache in positions - # pointed to by "state_indices_tensor" + # - It will read the initial states for every sequence, + # that has "has_initial_states_p" == True, + # from "cache_indices", using "state_indices_tensor_p". + # - It updates the "conv_state" cache in positions pointed + # to by "state_indices_tensor_p". + # In particular, it will always write the state at the + # sequence end. + # In addition, "block_idx_first_scheduled_token_p" and + # "block_idx_last_scheduled_token_p" + # are provided (which are pointers into + # "state_indices_tensor_p"), it will write additional cache + # states aligned at "block_size_to_align". 
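As a rough worked example of the block alignment described in the comment above: with a hypothetical mamba block size of 512 tokens, a sequence that already has 300 tokens cached and is scheduled for 900 more crosses two block boundaries, so two aligned states are written in addition to the state at the sequence end. The arithmetic below is illustrative only and not the kernel's actual code:

# Hypothetical numbers; in vLLM the block size comes from the cache config
# (mamba_block_size / block_size_to_align) and the token counts from the
# attention metadata.
block_size = 512
num_computed_tokens = 300   # already cached for this sequence
num_new_tokens = 900        # scheduled in this step

# Absolute token counts at which a block boundary is crossed during this step.
first_boundary = (num_computed_tokens // block_size + 1) * block_size
boundaries = list(
    range(first_boundary, num_computed_tokens + num_new_tokens + 1, block_size)
)
print(boundaries)                             # [512, 1024] -> block-aligned states
print(num_computed_tokens + num_new_tokens)   # 1200 -> final (partial) state

# With the chunked scan, aligned states are picked every `chunk_stride` chunks,
# e.g. a 256-token chunk size gives:
chunk_size = 256
chunk_stride = block_size // chunk_size       # 2 chunks per mamba block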
x = hidden_states_B_C_p.transpose( - 0, 1) # this is the form that causal-conv see + 0, 1 + ) # this is the form that causal-conv see hidden_states_B_C_p = causal_conv1d_fn( x, conv_weights, @@ -604,31 +664,40 @@ def forward_cuda( conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, + block_idx_first_scheduled_token=block_idx_first_scheduled_token_p, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_p, + initial_state_idx=block_idx_last_computed_token_p, + num_computed_tokens=num_computed_tokens_p, + block_size_to_align=mamba_block_size, metadata=attn_metadata, - query_start_loc=query_start_loc_p).transpose( - 0, 1)[:num_prefill_tokens] + query_start_loc=query_start_loc_p, + ).transpose(0, 1)[:num_prefill_tokens] - hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( - hidden_states_B_C_p) + hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p) # 3. State Space Model sequence transformation initial_states = None - if (has_initial_states_p is not None and prep_initial_states): + if has_initial_states_p is not None and prep_initial_states: + kernel_ssm_indices = state_indices_tensor_p + if prefix_caching_enabled: + kernel_ssm_indices = state_indices_tensor_p.gather( + 1, block_idx_last_computed_token_p.unsqueeze(1) + ).squeeze(1) initial_states = torch.where( has_initial_states_p[:, None, None, None], - ssm_state[state_indices_tensor_p], 0) + ssm_state[kernel_ssm_indices], + 0, + ) # NOTE: final output is an in-place update of out tensor varlen_states = mamba_chunk_scan_combined_varlen( - hidden_states_p.view(num_prefill_tokens, - self.num_heads // self.tp_size, - self.head_dim), + hidden_states_p.view( + num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim + ), dt_p, self.A, - B_p.view(num_prefill_tokens, self.n_groups // self.tp_size, - -1), - C_p.view(num_prefill_tokens, self.n_groups // self.tp_size, - -1), + B_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1), + C_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1), chunk_size=chunk_size, D=self.D, z=None, @@ -638,18 +707,110 @@ def forward_cuda( cu_chunk_seqlens=cu_chunk_seqlen_p, last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, + return_intermediate_states=prefix_caching_enabled, dt_softplus=True, dt_limit=(0.0, float("inf")), - out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, - self.head_dim), - state_dtype=ssm_state.dtype) + out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), + state_dtype=ssm_state.dtype, + ) + + if prefix_caching_enabled: + # The chunk_stride is the number of chunks per mamba block + # e.g., if mamba_block_size = 512 and chunk_size = 256, + # then chunk_stride = 2 + chunk_stride = mamba_block_size // chunk_size + + # Save state for sequences with more than just final state + for seq_idx in range(num_prefills): + # Block index for the first scheduled token + block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[ + seq_idx + ] + + # Block index for the last scheduled token + block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[ + seq_idx + ] + + # Number of blocks that need to be written + n_blocks_to_fill = ( + block_idx_last_scheduled_token - block_idx_first_scheduled_token + ) + + # Skip sequences that don't have any blocks to fill + if n_blocks_to_fill == 0: + continue + + # Look up the state indices + cache_blocks_to_fill = state_indices_tensor_p[ + seq_idx, + 
block_idx_first_scheduled_token:block_idx_last_scheduled_token, + ] + + # First chunk index for this sequence + if seq_idx == 0: + first_chunk = 0 + else: + first_chunk = 1 + last_chunk_indices_p[seq_idx - 1] + + # First chunk that is aligned on the mamba block boundary + first_aligned_chunk = first_chunk + chunk_stride - 1 + + # Calculate the number of computed tokens that were not + # already cached + num_unaligned_computed_tokens = ( + num_computed_tokens_p[seq_idx] % mamba_block_size + ) - # update ssm states - # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor - ssm_state[state_indices_tensor_p] = varlen_states + if num_unaligned_computed_tokens > 0: + # If the number of computed tokens is not block aligned, + # then we need to shift the index accordingly + first_aligned_chunk -= ( + num_unaligned_computed_tokens // chunk_size + ) + + # Get states to write + from_where = varlen_states[ + first_aligned_chunk : first_aligned_chunk + + n_blocks_to_fill * chunk_stride : chunk_stride + ] + + # Write the states + ssm_state[cache_blocks_to_fill] = from_where + + # For all seqs, store the last state (note: might be partial): + ssm_state[ + state_indices_tensor_p.gather( + 1, block_idx_last_scheduled_token_p.unsqueeze(1) + ).squeeze(1) + ] = varlen_states[last_chunk_indices_p] + + else: + # update ssm states + # - varlen state is a (num_prefills, nheads, headdim, dstate) + # tensor + ssm_state[state_indices_tensor_p] = varlen_states # Process decode requests if has_decode: + if prefix_caching_enabled: + state_indices_tensor_d_input = state_indices_tensor_d.gather( + 1, block_idx_last_computed_token_d.unsqueeze(1) + ).squeeze(1) + state_indices_tensor_d_output = state_indices_tensor_d.gather( + 1, block_idx_last_scheduled_token_d.unsqueeze(1) + ).squeeze(1) + # for decode: + # block_idx_first_scheduled_token_d == + # block_idx_last_scheduled_token_d + # at block boundaries: + # block_idx_first_scheduled_token_d > + # block_idx_last_computed_token_d + else: + # Without caching, read and write in-place to the same blocks: + state_indices_tensor_d_input = state_indices_tensor_d + state_indices_tensor_d_output = state_indices_tensor_d + # 2. Convolution sequence transformation hidden_states_B_C_d = causal_conv1d_update( hidden_states_B_C_d, @@ -657,22 +818,28 @@ def forward_cuda( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d) + conv_state_indices=state_indices_tensor_d, + block_idx_last_scheduled_token=block_idx_last_scheduled_token_d, + initial_state_idx=block_idx_last_computed_token_d, + ) - hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn( - hidden_states_B_C_d) + hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d) # 3. 
State Space Model sequence transformation n_groups = self.n_groups // self.tp_size - A_d = self.A[:, None, ...][:, :, None].expand( - -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + A_d = ( + self.A[:, None, ...][:, :, None] + .expand(-1, self.head_dim, self.ssm_state_size) + .to(dtype=torch.float32) + ) dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) D_d = self.D[:, None, ...].expand(-1, self.head_dim) B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups) C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups) hidden_states_d = hidden_states_d.view( - -1, self.num_heads // self.tp_size, self.head_dim) + -1, self.num_heads // self.tp_size, self.head_dim + ) # - the hidden is reshaped into (bs, num_heads, head_dim) # - mamba_cache_params.ssm_state's slots will be selected @@ -689,17 +856,16 @@ def forward_cuda( z=None, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices_tensor_d, - out=preallocated_ssm_out_d.view(num_decodes, -1, - self.head_dim), + state_batch_indices=state_indices_tensor_d_input, + dst_state_batch_indices=state_indices_tensor_d_output, + out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) # 4. gated MLP # GatedRMSNorm internally applying SiLU to the gate # SiLU is applied internally before normalization, unlike standard # norm usage - hidden_states = self.norm(preallocated_ssm_out, - gate[:num_actual_tokens]) + hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens]) # 5. Final linear projection output[:num_actual_tokens], _ = self.out_proj(hidden_states) @@ -729,8 +895,8 @@ def mamba_type(self) -> str: return "mamba2" def get_attn_backend(self) -> type["AttentionBackend"]: - from vllm.v1.attention.backends.mamba2_attn import ( - Mamba2AttentionBackend) + from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend + return Mamba2AttentionBackend @@ -742,9 +908,7 @@ def mamba_mixer2( ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.forward_cuda(hidden_states=hidden_states, - output=output, - mup_vector=mup_vector) + self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector) def mamba_mixer2_fake( diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 677a4b9d87fc..21c36617a872 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -10,7 +10,6 @@ class MambaStateDtypeCalculator: - @classmethod def linear_attention_state_dtype( cls, @@ -21,7 +20,7 @@ def linear_attention_state_dtype( if mamba_cache_dtype == "float32": raise ValueError("fp32 state for minimax is not yet supported") state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) - return (state_dtype, ) + return (state_dtype,) @classmethod def mamba1_state_dtype( @@ -30,8 +29,9 @@ def mamba1_state_dtype( mamba_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: - return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype, - mamba_ssm_cache_dtype) + return cls._mamba_state_dtype( + model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype + ) @classmethod def mamba2_state_dtype( @@ -40,8 +40,9 @@ def mamba2_state_dtype( mamba_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: - return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype, - 
mamba_ssm_cache_dtype) + return cls._mamba_state_dtype( + model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype + ) @classmethod def _mamba_state_dtype( @@ -50,13 +51,11 @@ def _mamba_state_dtype( mamba_cache_dtype: MambaDType, mamba_ssm_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: - conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, - model_dtype) + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) if mamba_ssm_cache_dtype == "auto": temporal_state_dtype = conv_state_dtype else: - temporal_state_dtype = ( - STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]) + temporal_state_dtype = STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype] return (conv_state_dtype, temporal_state_dtype) @@ -66,9 +65,8 @@ def short_conv_state_dtype( model_dtype: Union[ModelDType, torch.dtype], mamba_cache_dtype: MambaDType, ) -> tuple[torch.dtype, ...]: - conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, - model_dtype) - return (conv_state_dtype, ) + conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype) + return (conv_state_dtype,) @classmethod def gated_delta_net_state_dtype( @@ -81,7 +79,6 @@ def gated_delta_net_state_dtype( class MambaStateShapeCalculator: - @classmethod def linear_attention_state_shape( cls, @@ -89,9 +86,8 @@ def linear_attention_state_shape( tp_size: int, head_dim: int, ) -> tuple[tuple[int, int, int], ...]: - state_shape = (num_heads // tp_size, head_dim, head_dim) - return (state_shape, ) + return (state_shape,) @classmethod def mamba1_state_shape( @@ -101,11 +97,9 @@ def mamba1_state_shape( state_size: int, conv_kernel: int, ) -> tuple[tuple[int, int], tuple[int, int]]: - conv_state_shape = (divide(intermediate_size, - tp_world_size), conv_kernel - 1) + conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) - temporal_state_shape = (divide(intermediate_size, - tp_world_size), state_size) + temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) conv_state_shape = conv_state_shape[1], conv_state_shape[0] @@ -124,8 +118,7 @@ def mamba2_state_shape( ) -> tuple[tuple[int, int], tuple[int, int, int]]: # if n_groups is not divisible by world_size, need to extend the shards # to ensure all groups needed by a head is sharded along with it - n_groups = n_groups + cls.extra_groups_for_head_shards( - n_groups, tp_world_size) + n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size) # heads and n_groups are TP-ed conv_dim = intermediate_size + 2 * n_groups * state_size @@ -135,8 +128,7 @@ def mamba2_state_shape( # These are not TP-ed as they depend on A, dt_bias, D # - they are typically small # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) - temporal_state_shape = (divide(num_heads, - tp_world_size), head_dim, state_size) + temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size) return conv_state_shape, temporal_state_shape @classmethod @@ -148,7 +140,7 @@ def short_conv_state_shape( ) -> tuple[tuple[int, int]]: conv_dim = divide(intermediate_size, tp_world_size) conv_state_shape = (conv_kernel - 1, conv_dim) - return (conv_state_shape, ) + return (conv_state_shape,) @classmethod def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): @@ -173,7 +165,7 @@ def gated_delta_net_state_shape( conv_kernel_size: int, num_spec: int = 0, ): - conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads) + conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads conv_state_shape = ( divide(conv_dim, 
tp_world_size), conv_kernel_size - 1 + num_spec, @@ -181,6 +173,9 @@ def gated_delta_net_state_shape( conv_state_shape = conv_state_shape[1], conv_state_shape[0] - temporal_state_shape = (divide(num_v_heads, - tp_world_size), head_k_dim, head_v_dim) + temporal_state_shape = ( + divide(num_v_heads, tp_world_size), + head_k_dim, + head_v_dim, + ) return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index c4102c4753c7..ec486d3b9267 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -20,40 +20,41 @@ def _causal_conv1d_fwd_kernel( # continuous batching w_ptr, # (dim, width) bias_ptr, initial_states_ptr, # conv_states_ptr - cache_indices_ptr, # conv_state_indices_ptr + cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains + # the block indices relevant for each sequence + # plus potential 0-padding at the beginning and at the end has_initial_states_ptr, query_start_loc_ptr, batch_ptr, token_chunk_offset_ptr, + block_idx_first_scheduled_token, # (batch,) + block_idx_last_scheduled_token, # (batch,) + initial_state_idx, # (batch,) + num_computed_tokens, # (batch,) o_ptr, # (dim, seqlen) - actually pointing to x_ptr # Matrix dimensions - batch: tl.int32, # actually padded_batch dim: tl.constexpr, seqlen: tl.int32, # cu_seqlen num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines # Strides - stride_x_seq: tl.constexpr, # stride to get to next sequence, stride_x_dim: tl.constexpr, # stride to get to next feature-value, - stride_x_token: tl. - constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) stride_w_dim: tl.constexpr, # stride to get to next dim-axis value stride_w_width: tl.constexpr, # stride to get to next width-axis value stride_istate_seq: tl.constexpr, stride_istate_dim: tl.constexpr, stride_istate_token: tl.constexpr, stride_cache_indices: tl.constexpr, - stride_o_seq: tl.constexpr, stride_o_dim: tl.constexpr, stride_o_token: tl.constexpr, + stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M # others pad_slot_id: tl.constexpr, # Meta-parameters HAS_BIAS: tl.constexpr, KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, - HAS_INITIAL_STATES: tl.constexpr, - HAS_CACHE: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, USE_PAD_SLOT: tl.constexpr, NP2_STATELEN: tl.constexpr, BLOCK_M: tl.constexpr, @@ -64,7 +65,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching stride_conv_state_seq = stride_istate_seq stride_conv_state_dim = stride_istate_dim stride_conv_state_tok = stride_istate_token - state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value + state_len = ( + KERNEL_WIDTH - 1 + ) # can be passed via argument if it's not the same as this value # one program handles one chunk in a single sequence # rather than mixing sequences - to make updating initial_states across sequences efficiently @@ -84,27 +87,62 @@ def _causal_conv1d_fwd_kernel( # continuous batching # find the actual sequence length seqlen = sequence_end_index - sequence_start_index + B_size: tl.constexpr = stride_block_m * BLOCK_M + + if IS_APC_ENABLED: + # Handle the case if prefix caching is enabled. 
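For readers following the APC branch that continues below, here is a plain-Python restatement of its index arithmetic using the same toy numbers as the earlier block-alignment sketch (block size 512, 300 tokens already computed, 900 new tokens), assuming the sequence starts at offset 0 in the packed varlen batch so that `sequence_end_index == seqlen`; the names mirror the kernel but this is not its actual code:

B_size = 512                      # stride_block_m * BLOCK_M
num_computed = 300                # num_computed_tokens for this sequence
seqlen = 900                      # new tokens this step
sequence_end_index = seqlen       # start offset assumed to be 0

seq_completed_offset = B_size - num_computed % B_size      # 212 tokens close the open block
seq_end_offset = (seqlen - seq_completed_offset) % B_size  # 176 tokens past the last boundary
last_full_block_token_index = sequence_end_index - seq_end_offset
if seq_end_offset == 0:
    last_full_block_token_index -= B_size

assert last_full_block_token_index == 724  # global token 300 + 724 = 1024, a block boundary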
+ # In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr" + + # Get the length of the completed sequence so far and compute the offset. + current_first_index = tl.load(block_idx_first_scheduled_token + idx_seq) + current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq) + sequence_completed_index = tl.load(num_computed_tokens + idx_seq) + + # Compute the offset where the first stride_block_m-aligned first full block is + # Value in "token-space" + sequence_completed_offset_token = sequence_completed_index % B_size + seq_completed_offset = B_size - sequence_completed_offset_token + seq_end_offset = (seqlen - seq_completed_offset) % B_size + last_full_block_token_index = sequence_end_index - seq_end_offset + # If the sequence without the sequence_offset_index is stride_cache_chunk-aligned, then the last full chunk is the second-to-last one + if seq_end_offset == 0: + last_full_block_token_index = last_full_block_token_index - B_size + + # Get the number of blocks to be filled for the current sequence + # If n_block_to_fill = 0, then only the state at the sequence end is stored + n_block_to_fill = current_last_index - current_first_index + + # Get the index of the init block + conv_state_init_index = tl.load(initial_state_idx + idx_seq) + else: + n_block_to_fill = 0 + current_last_index = 0 + conv_state_init_index = 0 + current_first_index = 0 + last_full_block_token_index = 0 + token_offset = BLOCK_M * chunk_offset segment_len = min(BLOCK_M, seqlen - token_offset) # base of the sequence - x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] + x_base = ( + x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim + ) # [BLOCK_N,] + + # cache_idx + conv_states_input_coord = tl.load( + conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index + ).to(tl.int64) - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_cache_indices).to( - tl.int64) - else: - # cache_idx - conv_state_batch_coord = idx_seq if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: + if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return - conv_states_base = (conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + conv_states_base = ( + conv_states_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] @@ -113,14 +151,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching # 2. 
update conv_state with new data [only by the Triton program handles chunk_offset=0] if chunk_offset == 0: # read from conv_states - load_init_state = False - if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES - load_init_state = tl.load(has_initial_states_ptr + idx_seq).to( - tl.int1) + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) if load_init_state: # load from conv_states - prior_tokens = conv_states_base + (state_len - - 1) * stride_conv_state_tok + prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok mask_w = idx_feats < dim if KERNEL_WIDTH == 2: conv_states_ptrs = prior_tokens # [BLOCK_N] @@ -150,40 +184,56 @@ def _causal_conv1d_fwd_kernel( # continuous batching # prior-tokens are zeros if KERNEL_WIDTH >= 2: # STRATEGY1 # first chunk and does not have prior-token, so just set to 0 - col0 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 3: # STRATEGY1 - col1 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 4: # STRATEGY1 - col2 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) if KERNEL_WIDTH >= 5: # STRATEGY1 - col3 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) # STEP 2: # here prepare data for updating conv_state - if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + if ( + state_len <= seqlen + ): # SMALL_CACHE=True (only move part of 'x' into conv_state cache) # just read from 'x' # copy 'x' data to conv_state # load only 'x' data (and set 0 before 'x' if seqlen < state_len) idx_tokens_last = (seqlen - state_len) + tl.arange( - 0, NP2_STATELEN) # [BLOCK_M] - x_ptrs = x_ptr + ( - (sequence_start_index + idx_tokens_last) * - stride_x_token)[:, None] + ( - idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] - mask_x = ((idx_tokens_last >= 0)[:, None] & - (idx_tokens_last < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + 0, NP2_STATELEN + ) # [BLOCK_M] + x_ptrs = ( + x_ptr + + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + mask_x = ( + (idx_tokens_last >= 0)[:, None] + & (idx_tokens_last < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - conv_states_ptrs_target = conv_states_base[None, :] + ( - idx_tokens_conv * stride_conv_state_tok)[:, None] - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats - < dim)[None, :] + # Compute the offset where the last block should be written in the conv_states + conv_states_output_coord = tl.load( + conv_state_indices_ptr + + idx_seq * stride_cache_indices + + current_last_index + ).to(tl.int64) + + conv_states_ptrs_target = ( + conv_states_ptr + + (conv_states_output_coord * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[None, :] + ( # [BLOCK_N,] + idx_tokens_conv * stride_conv_state_tok + )[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] tl.debug_barrier() # NOTE: use this due to bug in Triton compiler - tl.store(conv_states_ptrs_target, new_conv_state, 
mask) + tl.store(conv_states_ptrs_target, loaded_x, mask) else: if load_init_state: @@ -191,39 +241,43 @@ def _causal_conv1d_fwd_kernel( # continuous batching idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] conv_states_ptrs_source = ( - conv_states_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, - None] + conv_states_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens_conv + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) VAL = state_len - seqlen - x_ptrs = x_base[None, :] + ( - (idx_tokens_conv - VAL) * - stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] - mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & - (idx_tokens_conv - VAL < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier( - ) # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load + tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load new_conv_state = tl.where( mask, conv_state, loaded_x ) # BUG in 'tl.where' which requires a barrier before this - conv_states_ptrs_target = conv_states_base + ( - idx_tokens_conv * - stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv - < state_len)[:, None] & (idx_feats < dim)[None, :] + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] tl.store(conv_states_ptrs_target, new_conv_state, mask) else: # load_init_state == False # update conv_state by shifting left, BUT @@ -232,21 +286,25 @@ def _causal_conv1d_fwd_kernel( # continuous batching VAL = state_len - seqlen - x_ptrs = x_base[None, :] + ( - (idx_tokens_conv - VAL) * - stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] - mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & - (idx_tokens_conv - VAL < seqlen)[:, None] & - (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - conv_states_ptrs_target = conv_states_base + ( - idx_tokens_conv * - stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv - < state_len)[:, None] & (idx_feats < dim)[None, :] + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, 
None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] tl.store(conv_states_ptrs_target, new_conv_state, mask) else: # chunk_offset > 0 @@ -256,37 +314,84 @@ def _causal_conv1d_fwd_kernel( # continuous batching mask_w = idx_feats < dim if KERNEL_WIDTH == 2: conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 3: conv_states_ptrs = prior_tokens # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 4: conv_states_ptrs = prior_tokens # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") if KERNEL_WIDTH == 5: # ruff: noqa: F841 conv_states_ptrs = prior_tokens # [BLOCK_N] - col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + + # Store intermediate states aligned with stride_block_m + # The additional states are cached starting from the last stride_block_m. + # For example: + # If n_block_to_fill = 0, then only the state at the sequence end is cached and the process below is not involved. + # If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last + # stride_block_m are cached. 
+ # For example chunk_offset = n_block_to_fill stores the state at last_full_block + if (chunk_offset - 1) < n_block_to_fill: + # Store the states at the chunk boundaries from the start of the sequence + idx_tokens_last = ( + last_full_block_token_index + - (n_block_to_fill - chunk_offset) * B_size + - state_len + ) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = ( + x_ptr + + (idx_tokens_last * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + + mask_x = (idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[ + None, : + ] # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # cache_idx + conv_states_output_coord = tl.load( + conv_state_indices_ptr + + idx_seq * stride_cache_indices + + current_first_index + + (chunk_offset - 1) + ).to(tl.int64) + + conv_states_ptrs_target = ( + conv_states_ptr + + (conv_states_output_coord * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[None, :] + ( # [BLOCK_N,] + idx_tokens_conv * stride_conv_state_tok + )[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, loaded_x, mask) if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) x_base_1d = x_base + token_offset * stride_x_token # starting of chunk @@ -310,7 +415,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching matrix_w = w_col0 matrix_x = col0 for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: if j == 1: # KERNEL_WIDTH-1: matrix_w = w_col1 @@ -351,9 +455,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) mask_1d = (idx_token < segment_len) & ( - idx_feats < dim) # token-index # feature-index - o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token - ) * stride_o_token + (idx_feats * stride_o_dim) + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + + (sequence_start_index + token_offset + idx_token) * stride_o_token + + (idx_feats * stride_o_dim) + ) tl.store(o_ptrs, acc, mask=mask_1d) @@ -368,6 +476,11 @@ def causal_conv1d_fn( has_initial_state: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", pad_slot_id: int = PAD_SLOT_ID, + block_idx_first_scheduled_token: Optional[torch.Tensor] = None, + block_idx_last_scheduled_token: Optional[torch.Tensor] = None, + initial_state_idx: Optional[torch.Tensor] = None, + num_computed_tokens: Optional[torch.Tensor] = None, + block_size_to_align=0, metadata=None, validate_data=False, ): @@ -378,7 +491,7 @@ def causal_conv1d_fn( sequences are concatenated from left to right for varlen weight: (dim, width) conv_states: (...,dim,width - 1) itype - updated inplace if provided + updated inplace if cache_indices are not provided [it use `cache_indices` to get the index to the cache of conv_state for that sequence conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True @@ -410,7 +523,16 @@ def causal_conv1d_fn( for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, 
the kernel will not process entries at indices 0 and 3 - + block_idx_first_scheduled_token: (batch,), dtype int32 + The pointer into cache_indices, where the first cache block to be filled is located. + block_idx_last_scheduled_token: (batch,), dtype int32 + The pointer into cache_indices, where the last cache block to be filled is located. + initial_state_idx: (batch,), dtype int32 + The pointer into cache_indices, where the cache block containing the initial state is located. + num_computed_tokens: (batch,), dtype int32 + The number of tokens already completed for each sequence + block_size_to_align: int + The block size to align the cached states to out: same shape as `x` """ if isinstance(activation, bool) and activation: @@ -427,21 +549,15 @@ def causal_conv1d_fn( batch_ptr = metadata.batch_ptr token_chunk_offset_ptr = metadata.token_chunk_offset_ptr else: - seqlens = query_start_loc.diff().to('cpu') + seqlens = query_start_loc.diff().to("cpu") args = seqlens MAX_NUM_PROGRAMS = 1024 batch_ptr = torch.full( - (MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=x.device + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device ) # tracking which seq-idx the Triton program is handling token_chunk_offset_ptr = torch.full( - (MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=x.device + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device ) # tracking BLOCK_M-based index in the sequence the Triton program is handling is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) @@ -451,7 +567,6 @@ def causal_conv1d_fn( np2_statelen = triton.next_power_of_2(state_len) padded_batch = query_start_loc.size(0) - 1 - stride_x_seq = 0 stride_x_dim = x.stride(0) stride_x_token = x.stride(1) stride_w_dim = weight.stride(0) @@ -460,6 +575,7 @@ def causal_conv1d_fn( stride_istate_dim = 0 stride_istate_token = 0 num_cache_lines = 0 + BLOCK_M = 8 if conv_states is not None: # extensions to support vLLM: # 1. conv_states is used to replaced initial_states @@ -467,23 +583,22 @@ def causal_conv1d_fn( # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] # 4. 
computation can be skipped if cache_indices[idx] == pad_slot_id num_cache_lines = conv_states.size(0) - assert (num_cache_lines == conv_states.shape[0] - and dim == conv_states.shape[1] - and width - 1 <= conv_states.shape[2]) + assert ( + num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2] + ) stride_istate_seq = conv_states.stride(0) stride_istate_dim = conv_states.stride(1) stride_istate_token = conv_states.stride(2) assert stride_istate_dim == 1 if out.dim() == 2: - stride_o_seq = 0 stride_o_dim = out.stride(0) stride_o_token = out.stride(1) else: - stride_o_seq = out.stride(0) stride_o_dim = out.stride(1) stride_o_token = out.stride(2) - stride_cache_indices = cache_indices.stride( - 0) if cache_indices is not None else 0 + stride_cache_indices = cache_indices.stride(0) if cache_indices is not None else 0 if validate_data: assert x.dim() == 2 @@ -497,11 +612,19 @@ def causal_conv1d_fn( assert cache_indices.dim() == 1 assert padded_batch == cache_indices.size(0) if has_initial_state is not None: - assert has_initial_state.size() == (padded_batch, ) - assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`" + assert has_initial_state.size() == (padded_batch,) + assert conv_states is not None, ( + "ERROR: `has_initial_state` is used, which needs also `conv_states`" + ) assert weight.stride(1) == 1 assert (dim, width) == weight.shape assert is_channel_last, "Need to run in channel-last layout" + if block_size_to_align is not None and block_size_to_align > 0: + assert (block_size_to_align % BLOCK_M) == 0, ( + "The mamba block size needs to be divisible by the BLOCK_M" + ) + else: + block_size_to_align = BLOCK_M if metadata is None: @@ -523,44 +646,45 @@ def num_program(META, seqlens): if META["batch_ptr"].nelement() < len(mlist): newlen = len(mlist) + 1 META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - META["token_chunk_offset_ptr"].resize_(newlen).fill_( - PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) if META["batch_ptr"].nelement() >= len(mlist): - META["batch_ptr"][0:len(mlist)].copy_( - torch.from_numpy(np.array(mlist))) - META["token_chunk_offset_ptr"][0:len(mlist)].copy_( - torch.from_numpy(np.array(offsetlist))) + META["batch_ptr"][0 : len(mlist)].copy_( + torch.from_numpy(np.array(mlist)) + ) + META["token_chunk_offset_ptr"][0 : len(mlist)].copy_( + torch.from_numpy(np.array(offsetlist)) + ) META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to( - META["x_ptr"].device) + META["x_ptr"].device + ) return tot else: def num_program(META, nums_dict): - tot = nums_dict[META["BLOCK_M"]]['tot'] + tot = nums_dict[META["BLOCK_M"]]["tot"] - mlist = nums_dict[META["BLOCK_M"]]['mlist'] - mlist_len = nums_dict[META["BLOCK_M"]]['mlist_len'] + mlist = nums_dict[META["BLOCK_M"]]["mlist"] + mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"] - offsetlist = nums_dict[META["BLOCK_M"]]['offsetlist'] + offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"] if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] - META["token_chunk_offset_ptr"] = nums_dict[ - META["BLOCK_M"]]["token_chunk_offset_ptr"] + META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][ + "token_chunk_offset_ptr" + ] else: if META["batch_ptr"].nelement() < mlist_len: newlen = mlist_len + 1 META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - 
META["token_chunk_offset_ptr"].resize_(newlen).fill_( - PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) if META["batch_ptr"].nelement() >= mlist_len: META["batch_ptr"][0:mlist_len].copy_(mlist) - META["token_chunk_offset_ptr"][0:mlist_len].copy_( - offsetlist) + META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist) return tot def grid(META): @@ -584,14 +708,16 @@ def grid(META): query_start_loc, batch_ptr, token_chunk_offset_ptr, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, + num_computed_tokens, out, # Matrix dimensions - padded_batch, dim, cu_seqlen, num_cache_lines, # stride - stride_x_seq, stride_x_dim, stride_x_token, stride_w_dim, @@ -600,22 +726,20 @@ def grid(META): stride_istate_dim, stride_istate_token, stride_cache_indices, - stride_o_seq, stride_o_dim, stride_o_token, + block_size_to_align // BLOCK_M, # others pad_slot_id, # META HAS_BIAS=bias is not None, KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], - HAS_INITIAL_STATES=has_initial_state is not None, - HAS_CACHE=conv_states is not None, - IS_CONTINUOUS_BATCHING=cache_indices is not None, + IS_APC_ENABLED=block_idx_last_scheduled_token is not None, USE_PAD_SLOT=pad_slot_id is not None, NP2_STATELEN=np2_statelen, - #launch_cooperative_grid=True - BLOCK_M=8, + # launch_cooperative_grid=True + BLOCK_M=BLOCK_M, BLOCK_N=256, num_stages=2, ) @@ -629,10 +753,11 @@ def _causal_conv1d_update_kernel( w_ptr, # (dim, width) bias_ptr, conv_state_ptr, - cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, num_accepted_tokens_ptr, query_start_loc_ptr, # (batch + 1) + block_idx_last_scheduled_token, # (batch,) + initial_state_idx, # (batch,) o_ptr, # (batch, dim, seqlen) # Matrix dimensions batch: int, @@ -660,7 +785,7 @@ def _causal_conv1d_update_kernel( KERNEL_WIDTH: tl.constexpr, SILU_ACTIVATION: tl.constexpr, IS_VARLEN: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_APC_ENABLED: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, NP2_STATELEN: tl.constexpr, USE_PAD_SLOT: tl.constexpr, @@ -674,25 +799,29 @@ def _causal_conv1d_update_kernel( # [BLOCK_N,] elements along the feature-dimension (channel) idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) + if IS_APC_ENABLED: + # Get the state from the initial_state_idx + conv_state_init = tl.load(initial_state_idx + idx_seq) + current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq) else: - conv_state_batch_coord = idx_seq + conv_state_init = 0 + current_last_index = 0 + + # cache_idx + conv_states_input_coord = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init + ).to(tl.int64) + if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: + if conv_states_input_coord == pad_slot_id: # not processing as this is not the actual sequence return if IS_VARLEN: query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) - query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to( - tl.int64) + query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) # revise state_len and seqlen - state_len = state_len - (seqlen - - (query_end_index - query_start_index)) + state_len = state_len - (seqlen - (query_end_index - query_start_index)) seqlen = query_end_index - query_start_index x_offset = query_start_index * 
stride_x_token o_offset = query_start_index * stride_o_token @@ -720,14 +849,17 @@ def _causal_conv1d_update_kernel( # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] # - and so on. conv_state_token_offset = ( - tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1) + tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 + ) else: conv_state_token_offset = 0 # STEP 1: READ init_state data - conv_states_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) + conv_states_base = ( + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) mask_w = idx_feats < dim prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok @@ -754,35 +886,50 @@ def _causal_conv1d_update_kernel( # window manner, at each forward pass, the tokens are shift by 1, so we # load since idx_tokens + 1. conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + - conv_state_token_offset * stride_conv_state_tok + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * - stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) + conv_state_ptr + + (conv_states_input_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[ + :, None + ] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_states_input_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) VAL = state_len - seqlen x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] - x_ptrs = x_base[None, :] + ( - (idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + x_ptrs = ( + x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] - mask_x = ((idx_tokens - VAL >= 0)[:, None] & - (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index loaded_x = tl.load(x_ptrs, mask_x, 0.0) tl.debug_barrier() new_conv_state = tl.where(mask, conv_state, loaded_x) - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = conv_state_base + ( - idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + # Get the state from the initial_state_idx + # cache_idx + conv_states_offset = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index + ).to(tl.int64) + conv_state_ptrs_target = ( + conv_state_ptr + + (conv_states_offset * stride_conv_state_seq) # Offset from seq + + (idx_feats * stride_conv_state_dim) + )[None, :] + ( # [BLOCK_N,] + idx_tokens * stride_conv_state_tok + )[:, None] mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] tl.store(conv_state_ptrs_target, new_conv_state, mask) @@ -790,10 +937,11 @@ def _causal_conv1d_update_kernel( if HAS_BIAS: bias = bias_ptr + idx_feats mask_bias = idx_feats < 
dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) # STEP 4: # PRE-LOAD WEIGHTS @@ -909,10 +1057,12 @@ def _causal_conv1d_update_kernel( if SILU_ACTIVATION: acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < seqlen) & (idx_feats < dim - ) # token-index # feature-index - o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * - stride_o_dim) + mask_1d = (idx_token < seqlen) & ( + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim) + ) tl.store(o_ptrs, acc, mask=mask_1d) @@ -923,12 +1073,13 @@ def causal_conv1d_update( weight: torch.Tensor, bias: Optional[torch.Tensor] = None, activation: Union[bool, str, None] = None, - cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, num_accepted_tokens: Optional[torch.Tensor] = None, query_start_loc: Optional[torch.Tensor] = None, max_query_len: int = -1, pad_slot_id: int = PAD_SLOT_ID, + block_idx_last_scheduled_token: Optional[torch.Tensor] = None, + initial_state_idx: Optional[torch.Tensor] = None, validate_data=False, ): """ @@ -942,15 +1093,14 @@ def causal_conv1d_update( conv_state: (..., dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state - starting at the index - @cache_seqlens % state_len. conv_state_indices: (batch,), dtype int32 If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. + block_idx_last_scheduled_token: (batch,), dtype int32 + The pointer into conv_state_indices, where the last cache block to be filled is located. + initial_state_idx: (batch,), dtype int32 + The pointer into conv_state_indices, where the cache block containing the initial state is located. num_accepted_tokens: (batch,), dtype int32 If not None, it indicates the number of accepted tokens for each sequence in the batch. @@ -963,15 +1113,14 @@ def causal_conv1d_update( If query_start_loc is not None, this indicates the maximum query length in the batch. 
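When prefix caching is enabled, `block_idx_last_scheduled_token` and `initial_state_idx` act as an extra level of indirection into `conv_state_indices`: the kernel reads the initial conv state from the slot named by `initial_state_idx` and writes the updated state to the slot named by `block_idx_last_scheduled_token`. A minimal host-side sketch of that lookup, assuming one row of cache-block slots per sequence (the layout and the helper below are illustrative, not part of this patch):

import torch

def resolve_conv_cache_lines(conv_state_indices: torch.Tensor,
                             initial_state_idx: torch.Tensor,
                             block_idx_last_scheduled_token: torch.Tensor):
    """Return (read_line, write_line) cache coordinates per sequence.

    conv_state_indices: (batch, num_blocks) cache-block slots per sequence
    initial_state_idx / block_idx_last_scheduled_token: (batch,) column picks
    """
    rows = torch.arange(conv_state_indices.size(0))
    # Read the initial state from the slot selected by initial_state_idx ...
    read_line = conv_state_indices[rows, initial_state_idx]
    # ... and write the updated state to the slot selected by
    # block_idx_last_scheduled_token.
    write_line = conv_state_indices[rows, block_idx_last_scheduled_token]
    return read_line, write_line

# Example: two sequences, three cache-block slots each.
slots = torch.tensor([[4, 7, 9], [2, 5, 8]], dtype=torch.int32)
init_idx = torch.tensor([0, 1])
last_idx = torch.tensor([2, 2])
print(resolve_conv_cache_lines(slots, init_idx, last_idx))
# -> (tensor([4, 5], dtype=torch.int32), tensor([9, 8], dtype=torch.int32))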
pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded + if conv_state_indices is passed, lets the kernel identify padded entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` """ if validate_data: - assert cache_seqlens is None # not implemented yet - ok for vLLM assert pad_slot_id is not None assert x.stride(1) == 1 if isinstance(activation, bool): @@ -998,20 +1147,19 @@ def causal_conv1d_update( if validate_data: assert dim == weight.size(0) - assert conv_state.stride( - -2 - ) == 1, f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert conv_state.stride(-2) == 1, ( + f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + ) assert state_len >= width - 1 # when above happens, we don't shift-left to keep any records in conv_state assert dim == conv_state.size(1) if conv_state_indices is None: assert conv_state.size(0) >= batch else: - assert (batch, ) == conv_state_indices.shape + assert (batch,) == conv_state_indices.shape assert num_cache_lines >= batch assert weight.stride(1) == 1 # Need this - assert cache_seqlens is None # not needed for vLLM - circular buffer # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' out = x @@ -1028,10 +1176,10 @@ def causal_conv1d_update( stride_o_token, stride_o_dim = out.stride() stride_o_seq = 0 - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = ( + conv_state_indices.stride(0) if conv_state_indices is not None else 0 ) - stride_state_indices = conv_state_indices.stride( - 0) if conv_state_indices is not None else 0 if num_accepted_tokens is not None: state_len = width - 1 + (seqlen - 1) # effective state_len needed else: @@ -1050,10 +1198,11 @@ def grid(META): weight, bias, conv_state, - cache_seqlens, conv_state_indices, num_accepted_tokens, query_start_loc, + block_idx_last_scheduled_token, + initial_state_idx, out, # Matrix dimensions batch, @@ -1081,7 +1230,7 @@ def grid(META): KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_VARLEN=query_start_loc is not None, - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_APC_ENABLED=block_idx_last_scheduled_token is not None, IS_SPEC_DECODING=num_accepted_tokens is not None, NP2_STATELEN=np2_statelen, USE_PAD_SLOT=pad_slot_id is not None, diff --git a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py index f3a45ab097c3..b592906c6f13 100644 --- a/vllm/model_executor/layers/mamba/ops/layernorm_gated.py +++ b/vllm/model_executor/layers/mamba/ops/layernorm_gated.py @@ -46,17 +46,17 @@ def _layer_norm_fwd_1pass_kernel( B += group * N # Compute mean and variance cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_Z and not NORM_BEFORE_GATE: z = tl.load(Z + cols, mask=cols < N).to(tl.float32) x *= z * tl.sigmoid(z) if not IS_RMS_NORM: mean = tl.sum(x, axis=0) / N tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.) 
+ xbar = tl.where(cols < N, x - mean, 0.0) var = tl.sum(xbar * xbar, axis=0) / N else: - xbar = tl.where(cols < N, x, 0.) + xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) tl.store(Rstd + row, rstd) @@ -74,15 +74,17 @@ def _layer_norm_fwd_1pass_kernel( tl.store(Y + cols, y, mask=mask) -def _layer_norm_fwd(x, - weight, - bias, - eps, - z=None, - out=None, - group_size=None, - norm_before_gate=True, - is_rms_norm=False): +def _layer_norm_fwd( + x, + weight, + bias, + eps, + z=None, + out=None, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): M, N = x.shape if group_size is None: group_size = N @@ -92,57 +94,57 @@ def _layer_norm_fwd(x, if z is not None: assert z.stride(-1) == 1 assert z.shape == (M, N) - assert weight.shape == (N, ) + assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 - assert bias.shape == (N, ) + assert bias.shape == (N,) # allocate output if out is not None: assert out.shape == x.shape else: out = torch.empty_like(x) assert out.stride(-1) == 1 - mean = torch.empty((ngroups * M, ), dtype=torch.float32, - device=x.device) if not is_rms_norm else None - rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) + mean = ( + torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) if group_size > BLOCK_N: - raise RuntimeError( - "This layer norm doesn't support feature dim >= 64KB.") + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # heuristics for number of warps num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) with torch.cuda.device(x.device.index): - _layer_norm_fwd_1pass_kernel[grid](x, - out, - weight, - bias, - z, - mean, - rstd, - x.stride(0), - out.stride(0), - z.stride(0) if z is not None else 0, - M, - group_size, - eps, - BLOCK_N=BLOCK_N, - NORM_BEFORE_GATE=norm_before_gate, - IS_RMS_NORM=is_rms_norm, - num_warps=num_warps) + _layer_norm_fwd_1pass_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) return out, mean, rstd -def rms_norm_gated(x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True): +def rms_norm_gated( + x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True +): x_shape_og = x.shape # reshape input data into 2D tensor x = x.reshape(-1, x.shape[-1]) @@ -156,13 +158,15 @@ def rms_norm_gated(x, weight = weight.contiguous() if bias is not None: bias = bias.contiguous() - y, _, _ = _layer_norm_fwd(x, - weight, - bias, - eps, - z=z, - group_size=group_size, - norm_before_gate=norm_before_gate, - is_rms_norm=True) + y, _, _ = _layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=True, + ) return y.reshape(x_shape_og) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 838290a9f5fb..8722eb9a7b22 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ 
b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -11,8 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.triton_utils import HAS_TRITON, tl, triton -TRITON3 = HAS_TRITON and (version.parse(triton.__version__) - >= version.parse("3.0.0")) +TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0")) if TRITON3: @@ -28,16 +27,18 @@ def softplus(dt): return dt -@triton.heuristics( - {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) -@triton.heuristics({ - "HAS_STATE_BATCH_INDICES": - lambda args: args["state_batch_indices_ptr"] is not None -}) @triton.heuristics( - {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) + { + "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] + is not None + } +) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])} +) @triton.jit def _selective_scan_update_kernel( # Pointers to matrices @@ -52,6 +53,7 @@ def _selective_scan_update_kernel( z_ptr, out_ptr, state_batch_indices_ptr, + dst_state_batch_indices_ptr, pad_slot_id, # Matrix dimensions batch, @@ -107,11 +109,18 @@ def _selective_scan_update_kernel( # is taken from the state_batch_indices_ptr Otherwise, the state coordinate # is the same as the batch id. if HAS_STATE_BATCH_INDICES: + dst_state_batch_indices_ptr += pid_b + dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64) + dst_state_ptr = state_ptr + ( + dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head + ) state_batch_indices_ptr += pid_b state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) - state_ptr += (state_batch_idx * stride_state_batch + - pid_h * stride_state_head) + state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head else: + dst_state_ptr = ( + state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head + ) state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head @@ -119,26 +128,29 @@ def _selective_scan_update_kernel( if HAS_DT_BIAS: dt_bias_ptr += pid_h * stride_dt_bias_head A_ptr += pid_h * stride_A_head - B_ptr += pid_b * stride_B_batch + (pid_h // - nheads_ngroups_ratio) * stride_B_group - C_ptr += pid_b * stride_C_batch + (pid_h // - nheads_ngroups_ratio) * stride_C_group + B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group if HAS_Z: z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) - state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + - offs_n[None, :] * stride_state_dstate) + state_ptrs = state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) + dst_state_ptrs = dst_state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) x_ptrs = x_ptr + offs_m * stride_x_dim dt_ptrs = dt_ptr + offs_m * stride_dt_dim if HAS_DT_BIAS: dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim if HAS_D: D_ptr += pid_h * stride_D_head - A_ptrs = A_ptr + (offs_m[:, None] * 
stride_A_dim + - offs_n[None, :] * stride_A_dstate) + A_ptrs = A_ptr + ( + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate + ) B_ptrs = B_ptr + offs_n * stride_B_dstate C_ptrs = C_ptr + offs_n * stride_C_dstate if HAS_D: @@ -148,20 +160,19 @@ def _selective_scan_update_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: - mask &= (state_batch_idx != pad_slot_id) + mask &= state_batch_idx != pad_slot_id state = tl.load(state_ptrs, mask=mask, other=0.0) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, - other=0.0).to(tl.float32) + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if DT_SOFTPLUS: dt = softplus(dt) - A = tl.load(A_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), - other=0.0).to(tl.float32) + A = tl.load( + A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0 + ).to(tl.float32) dA = tl.exp(A * dt[:, None]) else: dt = tl.load(dt_ptr).to(tl.float32) @@ -184,8 +195,8 @@ def _selective_scan_update_kernel( mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: - mask &= (state_batch_idx != pad_slot_id) - tl.store(state_ptrs, state, mask=mask) + mask &= state_batch_idx != pad_slot_id + tl.store(dst_state_ptrs, state, mask=mask) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D @@ -194,19 +205,22 @@ def _selective_scan_update_kernel( tl.store(out_ptrs, out, mask=offs_m < dim) -def selective_state_update(state, - x, - dt, - A, - B, - C, - D=None, - z=None, - dt_bias=None, - dt_softplus=False, - state_batch_indices=None, - pad_slot_id=PAD_SLOT_ID, - out=None): +def selective_state_update( + state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + state_batch_indices=None, + dst_state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID, + out=None, +): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -219,12 +233,12 @@ def selective_state_update(state, z: (batch, dim) or (batch, nheads, dim) dt_bias: (dim,) or (nheads, dim) pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 - out: Preallocated ssm output tensor. Assume same shape as x. + out: Preallocated ssm output tensor. Assume same shape as x. In-place updated. 
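A minimal plain-PyTorch sketch of the read/write indirection added by `dst_state_batch_indices`: the previous SSM state is gathered from the source slots and the updated state is scattered to (possibly different) destination slots, falling back to the original in-place behaviour when the destination indices are omitted. dt_bias, softplus, D and z are left out, and the shapes and helper name are illustrative assumptions rather than the kernel contract:

import torch

def toy_selective_state_update(state, x, dt, A, B, C, src_idx, dst_idx=None):
    # state: (slots, dim, dstate); x, dt: (batch, dim)
    # A: (dim, dstate); B, C: (batch, dstate)
    if dst_idx is None:
        dst_idx = src_idx                      # original in-place behaviour
    prev = state[src_idx]                      # read from source slots
    dA = torch.exp(dt[..., None] * A)          # (batch, dim, dstate)
    dBx = dt[..., None] * B[:, None, :] * x[..., None]
    new = prev * dA + dBx
    state[dst_idx] = new                       # write to destination slots
    return torch.einsum("bds,bs->bd", new, C)  # projected output per token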
""" if state.dim() == 3: @@ -265,20 +279,33 @@ def selective_state_update(state, if dt_bias is not None: assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: - assert state_batch_indices.shape == (batch, ) + assert state_batch_indices.shape == (batch,) + if dst_state_batch_indices is not None: + assert dst_state_batch_indices.shape == (batch,) + else: + # revert to the default behavior of in-place state updates + dst_state_batch_indices = state_batch_indices assert out.shape == x.shape - grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) - z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else - (0, 0, 0)) + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) + z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) # We don't want autotune since it will overwrite the state # We instead tune by hand. - BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else - ((16, 4) if dstate <= 32 else - ((8, 4) if dstate <= 64 else - ((4, 4) if dstate <= 128 else ((4, 8)))))) - tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride( - -1) == 0 and dt_bias.stride(-1) == 0 + BLOCK_SIZE_M, num_warps = ( + (32, 4) + if dstate <= 16 + else ( + (16, 4) + if dstate <= 32 + else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) + ) + ) + tie_hdim = ( + A.stride(-1) == 0 + and A.stride(-2) == 0 + and dt.stride(-1) == 0 + and dt_bias.stride(-1) == 0 + ) with torch.cuda.device(x.device.index): _selective_scan_update_kernel[grid]( state, @@ -292,6 +319,7 @@ def selective_state_update(state, z, out, state_batch_indices, + dst_state_batch_indices, pad_slot_id, batch, nheads, @@ -308,8 +336,7 @@ def selective_state_update(state, dt.stride(0), dt.stride(1), dt.stride(2), - *(dt_bias.stride(0), - dt_bias.stride(1)) if dt_bias is not None else 0, + *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, A.stride(0), A.stride(1), A.stride(2), @@ -333,54 +360,56 @@ def selective_state_update(state, ) -def selective_scan_fn(u, - ssm_states, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - query_start_loc=None, - cache_indices=None, - has_initial_state=None, - pad_slot_id=PAD_SLOT_ID) -> torch.Tensor: +def selective_scan_fn( + u, + ssm_states, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + query_start_loc=None, + cache_indices=None, + has_initial_state=None, + pad_slot_id=PAD_SLOT_ID, +) -> torch.Tensor: """ - u: (dim, total_length) for varlen or (batch, dim, seqlen) + u: (dim, total_length) for varlen or (batch, dim, seqlen) applies changes in place. ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate) applies changes in place. delta: (dim, total_length) for varlen or (batch, dim, seqlen) - A: (dim, dstate) - B: (ngroups, dstate, total_length) for varlen or + A: (dim, dstate) + B: (ngroups, dstate, total_length) for varlen or (batch,ngroups,dstate,seqlen) - C: (ngroups, dstate, total_length) for varlen or + C: (ngroups, dstate, total_length) for varlen or (batch,ngroups,dstate,seqlen) - D: (dim,) - z: (dim, total_length) for varlen or (batch, dim, seqlen) + D: (dim,) + z: (dim, total_length) for varlen or (batch, dim, seqlen) dt_bias: (dim,) or (dim) query_start_loc: (batch + 1) int32 The cumulative sequence lengths of the sequences in the batch, used to index into sequence. prepended with 0. 
- for example: query_start_loc = torch.Tensor([0,10,16,17]), + for example: query_start_loc = torch.Tensor([0,10,16,17]), x.shape=(dim,17) cache_indices: (batch) int32 - A tensor with each cell is a correspondent + A tensor with each cell is a correspondent input and output ssm_state index has_initial_state: (batch) bool - A tensor populated with ones and zeros, - indicate if the ssm_state at the corresponding index should be - used as initial state. Not providing argument assumes + A tensor populated with ones and zeros, + indicate if the ssm_state at the corresponding index should be + used as initial state. Not providing argument assumes there's no initial state pad_slot_id: int - if cache_indices is passed, lets the kernel identify padding entries - that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + if cache_indices is passed, lets the kernel identify padding entries + that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 returns - output: (dim, total_length) for varlen or (batch, dim, seqlen) + output: (dim, total_length) for varlen or (batch, dim, seqlen) supports inplace replacement """ if u.stride(-1) != 1: @@ -404,9 +433,22 @@ def selective_scan_fn(u, if C.dim() == 2 and query_start_loc is not None: C = C.unsqueeze(0) - ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, - query_start_loc, cache_indices, has_initial_state, - ssm_states, pad_slot_id) + ops.selective_scan_fwd( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + ssm_states, + pad_slot_id, + ) if z is None: return delta # output written inplace to delta diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 15a72fc61261..ac5ffc10f295 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -14,79 +14,52 @@ @triton.autotune( configs=[ triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), 
triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=2), + num_warps=2, + ), ], - key=['chunk_size', 'K', 'IS_CAUSAL'], + key=["chunk_size", "K", "IS_CAUSAL"], ) @triton.jit def _bmm_chunk_fwd_kernel( @@ -136,24 +109,26 @@ def _bmm_chunk_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + - offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + - offs_n[None, :] * stride_b_seqlen) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # compute a * b.T for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0).to(dot_dtype) - b = tl.load(b_ptrs, - mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & - (offs_n[None, :] < chunk_size_limit), - other=0.0).to(dot_dtype) + a = tl.load( + a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ).to(dot_dtype) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) + & (offs_n[None, :] < chunk_size_limit), + other=0.0, + ).to(dot_dtype) acc += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -163,20 +138,15 @@ def _bmm_chunk_fwd_kernel( out = acc.to(out_ptr.dtype.element_ty) out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + - offs_n[None, :] * stride_outn) - tl.store(out_ptrs, - out, - mask=(offs_m[:, None] < chunk_size) & - (offs_n[None, :] < chunk_size)) - - -def _bmm_chunk_fwd(a, - b, - chunk_size, - cu_chunk_seqlens, - causal=False, - output_dtype=None): + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), + ) + + +def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None): """ Argument: a: (seqlen, ngroups, k) @@ -198,16 +168,23 @@ def _bmm_chunk_fwd(a, nchunks = len(cu_chunk_seqlens) - 1 # Allocates output. 
out_dtype = a.dtype if output_dtype is None else output_dtype - out = torch.empty((nchunks, ngroups, chunk_size, chunk_size), - device=a.device, - dtype=out_dtype) - dot_dtype = (tl.bfloat16 - if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else - (tl.float16 if a.dtype == torch.float16 - or b.dtype == torch.float16 else tl.float32)) - grid = lambda META: (triton.cdiv( - chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - chunk_size, META['BLOCK_SIZE_N']), nchunks * ngroups) + out = torch.empty( + (nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype + ) + dot_dtype = ( + tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 + else ( + tl.float16 + if a.dtype == torch.float16 or b.dtype == torch.float16 + else tl.float32 + ) + ) + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), + nchunks * ngroups, + ) with torch.cuda.device(a.device.index): _bmm_chunk_fwd_kernel[grid]( a_ptr=a, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index e1e77e14f69d..e5a5c9dd6f71 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -10,101 +10,68 @@ from vllm.triton_utils import tl, triton -TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") @triton.autotune( configs=[ triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, - 
num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=2), + num_warps=2, + ), ], - key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], + key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"], ) @triton.jit def _chunk_scan_fwd_kernel( @@ -177,15 +144,16 @@ def _chunk_scan_fwd_kernel( num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += pid_c * stride_cb_chunk + (pid_h // - nheads_ngroups_ratio) * stride_cb_head + cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += chunk_seqlen_start * stride_C_seqlen + ( - pid_h // nheads_ngroups_ratio) * stride_C_head + C_ptr += ( + chunk_seqlen_start * stride_C_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_C_head + ) # M-block offsets and prev states # - logic in next block may override these if there is an active offset @@ -193,26 +161,31 @@ def _chunk_scan_fwd_kernel( seq_idx_ptr += pid_c * stride_seq_idx_chunk seq_idx = tl.load(seq_idx_ptr) - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_chunk, - mask=pid_c >= 1, - other=-1) + seq_idx_prev = tl.load( + seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1 + ) if HAS_INITSTATES and (seq_idx != seq_idx_prev): - prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_ptr = ( + initstates_ptr + + seq_idx * stride_init_states_batch + + pid_h * stride_init_states_head + ) prev_states_hdim = stride_init_states_hdim prev_states_dstate = stride_init_states_dstate else: - prev_states_ptr = states_ptr + ( - pid_c - 1) * stride_states_chunk + pid_h * stride_states_head + prev_states_ptr = ( + states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head + ) prev_states_hdim = stride_states_hdim prev_states_dstate = stride_states_dstate chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, - mask=offs_m < chunk_size, - other=0.0).to(tl.float32) + dA_cs_m = tl.load( + dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0 + ).to(tl.float32) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) @@ -221,52 +194,66 @@ def _chunk_scan_fwd_kernel( # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 offs_k_dstate = tl.arange( - 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + - offs_k_dstate[None, :] * stride_C_dstate) + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K + ) + C_ptrs = C_ptr + ( + offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate + ) scale_m = tl.exp(dA_cs_m) if BLOCK_SIZE_DSTATE <= 128: - C = tl.load(C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate), - other=0.0) + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate), + 
other=0.0, + ) if not HAS_INITSTATES and (seq_idx != seq_idx_prev): # if no init states AND starting a new sequence, we need zeros - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), - dtype=C_ptr.dtype.element_ty) + prev_states = tl.zeros( + (BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty + ) else: # otherwise read the previous state - prev_states_ptrs = prev_states_ptr \ - + offs_n[None, :] * prev_states_hdim \ - + offs_k_dstate[:, None] * prev_states_dstate - prev_states = tl.load(prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), + other=0.0, + ) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: - prev_states_ptrs = prev_states_ptr \ - + offs_n[None, :] * prev_states_hdim \ - + offs_k_dstate[:, None] * prev_states_dstate + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load(C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate - k), - other=0.0) + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate - k), + other=0.0, + ) if not HAS_INITSTATES and (seq_idx != seq_idx_prev): - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), - dtype=C_ptr.dtype.element_ty) + prev_states = tl.zeros( + (BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=C_ptr.dtype.element_ty + ) else: prev_states = tl.load( prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate - k) & - (offs_n[None, :] < hdim), - other=0.0) + mask=(offs_k_dstate[:, None] < dstate - k) + & (offs_n[None, :] < hdim), + other=0.0, + ) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K @@ -274,36 +261,42 @@ def _chunk_scan_fwd_kernel( acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) - cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + - offs_k[None, :] * stride_cb_csize_k) - x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + - offs_n[None, :] * stride_x_hdim) + cb_ptrs = cb_ptr + ( + offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k + ) + x_ptrs = x_ptr + ( + offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim + ) dt_ptrs = dt_ptr + offs_k * stride_dt_csize dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - K_MAX = chunk_size_limit if not IS_CAUSAL else min( - (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + K_MAX = ( + chunk_size_limit + if not IS_CAUSAL + else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + ) for k in range(0, K_MAX, BLOCK_SIZE_K): - cb = tl.load(cb_ptrs, - mask=(offs_m[:, None] < chunk_size) & - (offs_k[None, :] < chunk_size - k), - other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size - k, - other=0.0).to(tl.float32) + cb = tl.load( + cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to( + tl.float32 + ) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. 
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, - other=0.0).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: mask = offs_m[:, None] >= k + offs_k[None, :] cb = tl.where(mask, cb, 0.0) cb = cb.to(x_ptr.dtype.element_ty) - x = tl.load(x_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) & - (offs_n[None, :] < hdim), - other=0.0) + x = tl.load( + x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), + other=0.0, + ) acc += tl.dot(cb, x) cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k x_ptrs += BLOCK_SIZE_K * stride_x_seqlen @@ -315,35 +308,41 @@ def _chunk_scan_fwd_kernel( if HAS_D: if D_HAS_HDIM: - D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, - mask=offs_n < hdim, - other=0.0).to(tl.float32) + D = tl.load( + D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0 + ).to(tl.float32) else: D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + - offs_n[None, :] * stride_x_hdim), - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_n[None, :] < hdim), - other=0.0).to(tl.float32) + x_residual = tl.load( + x_ptr + + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) acc += x_residual * D if HAS_Z: z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head - z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + - stride_z_hdim * offs_out_n[None, :]) - z = tl.load(z_ptrs, - mask=(offs_out_m[:, None] < chunk_size_limit) & - (offs_out_n[None, :] < hdim), - other=0.0).to(tl.float32) + z_ptrs = z_ptr + ( + stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :] + ) + z = tl.load( + z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) + & (offs_out_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) acc *= z * tl.sigmoid(z) out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head - out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + - offs_out_n[None, :] * stride_out_hdim) - tl.store(out_ptrs, - acc, - mask=(offs_out_m[:, None] < chunk_size_limit) & - (offs_out_n[None, :] < hdim)) + out_ptrs = out_ptr + ( + stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim + ) + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), + ) def _chunk_scan_fwd( @@ -369,24 +368,32 @@ def _chunk_scan_fwd( assert C.shape == (seqlen, ngroups, dstate) assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size) if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + assert D.shape == (nheads, headdim) or D.shape == (nheads,) if z is not None: assert z.shape == x.shape assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == (nheads, nchunks, chunk_size) assert states.shape == (nchunks, nheads, headdim, dstate) - assert seq_idx.shape == (nchunks, ) + assert seq_idx.shape == (nchunks,) - grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton - .cdiv(headdim, META['BLOCK_SIZE_N']), nchunks, nheads) + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), + nchunks, + nheads, + ) - z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else - (0, 0, 0)) - 
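The stride tuples for optional tensors follow a common convention in these wrappers: when a tensor such as z or initial_states is absent, zero strides are passed and a constexpr flag (HAS_Z, HAS_INITSTATES) tells the kernel to skip the corresponding loads, so the launch signature stays fixed. A minimal sketch of that convention (the helper name is illustrative):

from typing import Optional

import torch

def strides_or_zeros(t: Optional[torch.Tensor], ndim: int) -> tuple[int, ...]:
    # Zero strides are safe placeholders because the kernel is compiled with a
    # constexpr flag that prevents it from dereferencing the missing tensor.
    return tuple(t.stride(i) for i in range(ndim)) if t is not None else (0,) * ndim

z = None
initial_states = torch.empty(2, 4, 8, 16)
z_strides = strides_or_zeros(z, 3)                            # (0, 0, 0)
initial_states_strides = strides_or_zeros(initial_states, 4)  # (512, 128, 16, 1)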
initial_states_strides = ((initial_states.stride(0), - initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3)) - if initial_states is not None else (0, 0, 0, 0)) + z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) + initial_states_strides = ( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ) _chunk_scan_fwd_kernel[grid]( cb_ptr=cb, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 3a3e0f293459..11cc125bf219 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -15,14 +15,14 @@ @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_H': 2}), - triton.Config({'BLOCK_SIZE_H': 4}), - triton.Config({'BLOCK_SIZE_H': 8}), - triton.Config({'BLOCK_SIZE_H': 16}), - triton.Config({'BLOCK_SIZE_H': 32}), - triton.Config({'BLOCK_SIZE_H': 64}), + triton.Config({"BLOCK_SIZE_H": 2}), + triton.Config({"BLOCK_SIZE_H": 4}), + triton.Config({"BLOCK_SIZE_H": 8}), + triton.Config({"BLOCK_SIZE_H": 16}), + triton.Config({"BLOCK_SIZE_H": 32}), + triton.Config({"BLOCK_SIZE_H": 64}), ], - key=['chunk_size', 'nheads'], + key=["chunk_size", "nheads"], ) @triton.jit def _chunk_cumsum_fwd_kernel( @@ -70,118 +70,99 @@ def _chunk_cumsum_fwd_kernel( offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) - dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + - offs_c[None, :] * stride_dt_seqlen) + dt_ptrs = dt_ptr + ( + offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen + ) A_ptrs = A_ptr + offs_h * stride_A_head - dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + - offs_c[None, :] * stride_dt_out_csize) - dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + - offs_c[None, :] * stride_dA_cs_csize) + dt_out_ptrs = dt_out_ptr + ( + offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize + ) + dA_cs_ptrs = dA_cumsum_ptr + ( + offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize + ) chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - dt = tl.load(dt_ptrs, - mask=(offs_h[:, None] < nheads) & - (offs_c[None, :] < chunk_size_limit), - other=0.0).to(tl.float32) + dt = tl.load( + dt_ptrs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), + other=0.0, + ).to(tl.float32) if HAS_DT_BIAS: - dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, - mask=offs_h < nheads, - other=0.0).to(tl.float32) + dt_bias = tl.load( + dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0 + ).to(tl.float32) dt += dt_bias[:, None] if DT_SOFTPLUS: dt = tl.where(dt <= 20.0, softplus(dt), dt) dt = tl.clamp(dt, dt_min, dt_max) dt = tl.where( - (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, - 0.0) - tl.store(dt_out_ptrs, - dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0 + ) + tl.store( + dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) - tl.store(dA_cs_ptrs, - dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + 
tl.store( + dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) @triton.autotune( configs=[ triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=2), + num_warps=2, + ), ], - key=['hdim', 'dstate', 'chunk_size'], + key=["hdim", "dstate", "chunk_size"], ) @triton.jit def _chunk_state_fwd_kernel( @@ -227,8 +208,10 @@ def _chunk_state_fwd_kernel( pid_n = tl.program_id(axis=0) % num_pid_n chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - b_ptr += chunk_seqlen_start * stride_b_seqlen + ( - pid_h // nheads_ngroups_ratio) * stride_b_head + b_ptr += ( + chunk_seqlen_start * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head @@ -236,32 +219,38 @@ def _chunk_state_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + - offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + - offs_k[:, None] * stride_b_seqlen) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to( + tl.float32 + ) 
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load(x_ptrs, - mask=(offs_m[:, None] < hdim) & - (offs_k[None, :] < chunk_size_limit - k), - other=0.0) - b = tl.load(b_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) & - (offs_n[None, :] < dstate), - other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size_limit - k, - other=0.0).to(tl.float32) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, - other=0.0).to(tl.float32) + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) @@ -277,8 +266,9 @@ def _chunk_state_fwd_kernel( states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + - offs_n[None, :] * stride_states_dstate) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) tl.store(states_ptrs, states, mask=c_mask) @@ -286,79 +276,52 @@ def _chunk_state_fwd_kernel( @triton.autotune( configs=[ triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, num_stages=3, - num_warps=8), + num_warps=8, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=4), + num_warps=4, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=5, - num_warps=2), + num_warps=2, + ), triton.Config( - { - 'BLOCK_SIZE_M': 64, - 
'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, num_stages=4, - num_warps=2), + num_warps=2, + ), ], - key=['hdim', 'dstate', 'chunk_size'], + key=["hdim", "dstate", "chunk_size"], ) @triton.jit def _chunk_state_varlen_kernel( @@ -414,12 +377,16 @@ def _chunk_state_varlen_kernel( pid_n = tl.program_id(axis=0) % num_pid_n end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) pid_c = (end_idx - 1) // chunk_size - b_ptr += pid_c * chunk_size * stride_b_seqlen + ( - pid_h // nheads_ngroups_ratio) * stride_b_head + b_ptr += ( + pid_c * chunk_size * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + chunk_states_ptr += ( + pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + ) if HAS_INITSTATES: # if there are init states provided, we differentiate between states (which @@ -430,13 +397,16 @@ def _chunk_state_varlen_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + - offs_k[None, :] * stride_x_seqlen) - b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + - offs_k[:, None] * stride_b_seqlen) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * - stride_dA_cs_csize).to(tl.float32) + dA_cs_last = tl.load( + dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize + ).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize chunk_size_limit = end_idx - pid_c * chunk_size @@ -445,24 +415,31 @@ def _chunk_state_varlen_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load(x_ptrs, - mask=(offs_m[:, None] < hdim) & - (offs_k[None, :] < chunk_size_limit - k) & - (offs_k[None, :] >= start_idx_cur - k), - other=0.0) - b = tl.load(b_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) & - (offs_n[None, :] < dstate) & - (offs_k[:, None] >= start_idx_cur - k), - other=0.0).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size_limit - k, - other=0.0).to(tl.float32) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, - other=0.0).to(tl.float32) + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) + & (offs_k[None, :] < chunk_size_limit - k) + & (offs_k[None, :] >= start_idx_cur - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) + & (offs_n[None, :] < dstate) + & (offs_k[:, None] >= start_idx_cur - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) scale = tl.where( (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), - tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + tl.exp(dA_cs_last - dA_cs_k) * dt_k, + 0.0, + ) b *= scale[:, 
None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -475,39 +452,43 @@ def _chunk_state_varlen_kernel( # If HAS_INITSTATES==True need to consider two possibilities # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs # - if state_idx >= pid * chunk_size, then we need to insert initstates - if ((start_idx < pid_c * chunk_size) # first chunk - or (HAS_INITSTATES)): - + if ( + (start_idx < pid_c * chunk_size) # first chunk + or (HAS_INITSTATES) + ): dA_cs_boundary = 0.0 # default if not HAS_INITSTATES: past_states_ptrs = chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim + - offs_n[None, :] * stride_chunk_states_dstate) + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate + ) else: - # - this seems repetitive, buts its to help the compiler if start_idx < pid_c * chunk_size: past_states_ptrs = chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim + - offs_n[None, :] * stride_chunk_states_dstate) + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate + ) else: past_states_ptrs = initstates_ptr + ( - pid_b * stride_init_states_batch + - offs_m[:, None] * stride_init_states_hdim + - offs_n[None, :] * stride_init_states_dstate) + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate + ) # need to adjust the boundary if start_idx > pid_c * chunk_size: - dA_cs_boundary = tl.load(dA_cumsum_ptr + - (start_idx - pid_c * chunk_size - - 1) * stride_dA_cs_csize).to( - tl.float32) + dA_cs_boundary = tl.load( + dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize + ).to(tl.float32) - past_states = tl.load(past_states_ptrs, - mask=(offs_m[:, None] < hdim) & - (offs_n[None, :] < dstate), - other=0.0).to(tl.float32) + past_states = tl.load( + past_states_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) scale = tl.exp(dA_cs_last - dA_cs_boundary) acc += past_states * scale @@ -517,36 +498,34 @@ def _chunk_state_varlen_kernel( states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + - offs_n[None, :] * stride_states_dstate) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) tl.store(states_ptrs, states, mask=c_mask) -def _chunk_cumsum_fwd(dt, - A, - chunk_size, - cu_chunk_seqlens, - dt_bias=None, - dt_softplus=False, - dt_limit=(0.0, float("inf"))): +def _chunk_cumsum_fwd( + dt, + A, + chunk_size, + cu_chunk_seqlens, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), +): seqlen, nheads = dt.shape - assert A.shape == (nheads, ) + assert A.shape == (nheads,) if dt_bias is not None: - assert dt_bias.shape == (nheads, ) + assert dt_bias.shape == (nheads,) nchunks = cu_chunk_seqlens.shape[0] - 1 - dt_out = torch.empty(nheads, - nchunks, - chunk_size, - device=dt.device, - dtype=torch.float32) - dA_cumsum = torch.empty(nheads, - nchunks, - chunk_size, - device=dt.device, - dtype=torch.float32) - grid_chunk_cs = lambda META: (nchunks, - triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + dt_out = torch.empty( + nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + dA_cumsum = 
torch.empty( + nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"])) with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( dt_ptr=dt, @@ -563,8 +542,7 @@ def _chunk_cumsum_fwd(dt, stride_dt_seqlen=dt.stride(0), stride_dt_head=dt.stride(1), stride_A_head=A.stride(0), - stride_dt_bias_head=dt_bias.stride(0) - if dt_bias is not None else 0, + stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0, stride_dt_out_head=dt_out.stride(0), stride_dt_out_chunk=dt_out.stride(1), stride_dt_out_csize=dt_out.stride(2), @@ -578,13 +556,9 @@ def _chunk_cumsum_fwd(dt, return dA_cumsum, dt_out -def _chunk_state_fwd(B, - x, - dt, - dA_cumsum, - cu_chunk_seqlens, - states=None, - states_in_fp32=True): +def _chunk_state_fwd( + B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True +): seqlen, nheads, headdim = x.shape _, nchunks, chunk_size = dt.shape _, ngroups, dstate = B.shape @@ -597,12 +571,16 @@ def _chunk_state_fwd(B, assert states.shape == (nchunks, nheads, headdim, dstate) else: states_dtype = torch.float32 if states_in_fp32 else B.dtype - states = torch.empty((nchunks, nheads, headdim, dstate), - device=x.device, - dtype=states_dtype) + states = torch.empty( + (nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype + ) - grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. - cdiv(dstate, META['BLOCK_SIZE_N']), nchunks, nheads) + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + nchunks, + nheads, + ) with torch.cuda.device(x.device.index): _chunk_state_fwd_kernel[grid]( x_ptr=x, @@ -636,13 +614,9 @@ def _chunk_state_fwd(B, return states -def chunk_state_varlen(B, - x, - dt, - dA_cumsum, - cu_seqlens, - chunk_states, - initial_states=None): +def chunk_state_varlen( + B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None +): total_seqlen, nheads, headdim = x.shape _, nchunks, chunk_size = dt.shape _, ngroups, dstate = B.shape @@ -657,21 +631,32 @@ def chunk_state_varlen(B, if initial_states is not None: assert initial_states.shape == (batch, nheads, headdim, dstate) - states = torch.empty(batch, - nheads, - headdim, - dstate, - dtype=chunk_states.dtype, - device=chunk_states.device) - - initial_states_strides = ((initial_states.stride(0), - initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3)) - if initial_states is not None else (0, 0, 0, 0)) - - grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton. 
- cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads) + states = torch.empty( + batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device, + ) + + initial_states_strides = ( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ) + + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + batch, + nheads, + ) with torch.cuda.device(x.device.index): _chunk_state_varlen_kernel[grid]( x_ptr=x, @@ -710,5 +695,6 @@ def chunk_state_varlen(B, stride_init_states_head=initial_states_strides[1], stride_init_states_hdim=initial_states_strides[2], stride_init_states_dstate=initial_states_strides[3], - HAS_INITSTATES=initial_states is not None) + HAS_INITSTATES=initial_states is not None, + ) return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index f3eb61d5840e..ac905ada7229 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -17,62 +17,66 @@ from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd from .ssd_state_passing import _state_passing_fwd -TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") def is_int_pow_2(n): return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 -def _mamba_chunk_scan_combined_fwd(x, - dt, - A, - B, - C, - chunk_size, - out, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - cu_seqlens=None, - cu_chunk_seqlens=None, - last_chunk_indices=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - state_dtype=None): +def _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + return_intermediate_states=False, + seq_idx=None, + cu_seqlens=None, + cu_chunk_seqlens=None, + last_chunk_indices=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + state_dtype=None, +): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" seqlen, nheads, headdim = x.shape _, ngroups, dstate = B.shape assert nheads % ngroups == 0 assert B.shape == (seqlen, ngroups, dstate) assert dt.shape == (seqlen, nheads) - assert A.shape == (nheads, ) + assert A.shape == (nheads,) assert C.shape == B.shape if z is not None: assert z.shape == x.shape if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads, ) + assert D.shape == (nheads, headdim) or D.shape == (nheads,) if seq_idx is not None: - assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1, ) + assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,) if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: C = C.contiguous() - if x.stride(-1) != 1 and x.stride( - 0) != 1: # Either M or K dimension should be contiguous + if ( + x.stride(-1) != 1 and x.stride(0) != 1 + ): # Either M or K dimension should be contiguous x = x.contiguous() - if z is not None and z.stride(-1) != 1 and z.stride( - 0) != 1: # Either M or K dimension should be contiguous + if ( + z is not None and z.stride(-1) != 1 and z.stride(0) != 1 + ): # Either M or K dimension should be contiguous z = z.contiguous() if D is not None and D.stride(-1) != 1: D = D.contiguous() assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens" if 
initial_states is not None: - assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, - dstate) + assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate) # This function executes 5 sub-functions for computing mamba # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ @@ -85,22 +89,21 @@ def _mamba_chunk_scan_combined_fwd(x, # 1. Compute chunked cumsum of A * dt # - here dt may go through a softplus activation - dA_cumsum, dt = _chunk_cumsum_fwd(dt, - A, - chunk_size, - cu_chunk_seqlens, - dt_bias=dt_bias, - dt_softplus=dt_softplus, - dt_limit=dt_limit) + dA_cumsum, dt = _chunk_cumsum_fwd( + dt, + A, + chunk_size, + cu_chunk_seqlens, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + ) # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) - states = _chunk_state_fwd(B, - x, - dt, - dA_cumsum, - cu_chunk_seqlens, - states_in_fp32=True) + states = _chunk_state_fwd( + B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True + ) # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) @@ -113,18 +116,15 @@ def _mamba_chunk_scan_combined_fwd(x, dA_cumsum, # (nheads, nchunks, chunk_size) cu_chunk_seqlens, initial_states=rearrange(initial_states, "... p n -> ... (p n)") - if initial_states is not None else - None, # (batch, nheads, headdim*dstate) + if initial_states is not None + else None, # (batch, nheads, headdim*dstate) seq_idx=seq_idx, - out_dtype=state_dtype if state_dtype is not None else C.dtype) + out_dtype=state_dtype if state_dtype is not None else C.dtype, + ) states = rearrange(states, "... (p n) -> ... p n", n=dstate) # 4. Compute batched matrix multiply for C_j^T B_i terms - CB = _bmm_chunk_fwd(C, - B, - chunk_size, - cu_chunk_seqlens, - output_dtype=torch.float32) + CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32) # 5. Scan and compute the diagonal blocks, taking into # account past causal states. 
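For orientation, here is a minimal pure-PyTorch sketch of step 1 above (the chunked cumsum of A * dt), assuming equal-length chunks and none of the variable-length bookkeeping the Triton kernels handle; `naive_chunk_cumsum` is an illustrative name, not part of vLLM.

import torch
import torch.nn.functional as F

def naive_chunk_cumsum(dt, A, chunk_size, dt_bias=None, dt_softplus=False,
                       dt_limit=(0.0, float("inf"))):
    # dt: (seqlen, nheads), A: (nheads,); seqlen assumed divisible by chunk_size.
    seqlen, nheads = dt.shape
    dt = dt.float()
    if dt_bias is not None:
        dt = dt + dt_bias.float()
    if dt_softplus:
        # Thresholded softplus, as in the kernel: identity for large dt.
        dt = torch.where(dt <= 20.0, F.softplus(dt), dt)
    dt = dt.clamp(*dt_limit)
    # Reshape to (nheads, nchunks, chunk_size), matching the kernel's output layout.
    dt = dt.T.reshape(nheads, seqlen // chunk_size, chunk_size)
    # dA = dt * A per head, cumulative sum within each chunk.
    dA_cumsum = torch.cumsum(dt * A.float()[:, None, None], dim=-1)
    return dA_cumsum, dt

# Example: 512 tokens, 8 heads, chunk_size 128 -> both outputs have shape (8, 4, 128).
dA_cs, dt_chunks = naive_chunk_cumsum(torch.randn(512, 8), -torch.rand(8), 128,
                                      dt_softplus=True)

The remaining steps (intra-chunk states, inter-chunk state passing, the C_j^T B_i batched matmul, and the diagonal-block scan) build on these per-chunk cumulative sums; this sketch only mirrors the layout of the real `_chunk_cumsum_fwd` outputs, not its varlen handling.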
@@ -151,28 +151,32 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, ) - return states[last_chunk_indices] + if return_intermediate_states: + return states + else: + return states[last_chunk_indices] def mamba_chunk_scan_combined_varlen( - x, - dt, - A, - B, - C, - chunk_size, - cu_seqlens, - cu_chunk_seqlens, - last_chunk_indices, - seq_idx, - out, - D=None, - z=None, - dt_bias=None, - initial_states=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - state_dtype=None, + x, + dt, + A, + B, + C, + chunk_size, + cu_seqlens, + cu_chunk_seqlens, + last_chunk_indices, + seq_idx, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_intermediate_states=False, + state_dtype=None, ): """ Argument: @@ -213,12 +217,14 @@ def mamba_chunk_scan_combined_varlen( z=z, dt_bias=dt_bias, initial_states=initial_states, + return_intermediate_states=return_intermediate_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, cu_chunk_seqlens=cu_chunk_seqlens, last_chunk_indices=last_chunk_indices, dt_softplus=dt_softplus, dt_limit=dt_limit, - state_dtype=state_dtype) + state_dtype=state_dtype, + ) return varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index f09af262cfc2..5481bab17e5a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -13,14 +13,14 @@ @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE': 64}), - triton.Config({'BLOCK_SIZE': 128}), - triton.Config({'BLOCK_SIZE': 256}), - triton.Config({'BLOCK_SIZE': 512}), - triton.Config({'BLOCK_SIZE': 1024}), - triton.Config({'BLOCK_SIZE': 2048}), + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), ], - key=['dim'], + key=["dim"], ) @triton.jit def _state_passing_fwd_kernel( @@ -58,8 +58,7 @@ def _state_passing_fwd_kernel( pid_m = tl.program_id(axis=0) states_ptr += pid_h * stride_states_head - dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - - 1) * stride_dA_cs_csize + dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_h * stride_out_head offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -67,31 +66,35 @@ def _state_passing_fwd_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim if HAS_INITSTATES: - initstates_ptrs = initstates_ptr \ - + pid_h * stride_initstates_head \ + initstates_ptrs = ( + initstates_ptr + + pid_h * stride_initstates_head + offs_m * stride_initstates_dim + ) - states = tl.load(initstates_ptrs, mask=offs_m < dim, - other=0.0).to(tl.float32) + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) else: - states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) prev_seq_idx = 0 for c in range(nchunks): - new_states = tl.load(states_ptrs, mask=offs_m < dim, - other=0.0).to(tl.float32) + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk) # we have started a new sequence if prev_seq_idx != seq_idx: if HAS_INITSTATES: - initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \ - + pid_h * stride_initstates_head \ + 
initstates_ptrs = ( + initstates_ptr + + seq_idx * stride_initstates_batch + + pid_h * stride_initstates_head + offs_m * stride_initstates_dim - states = tl.load(initstates_ptrs, mask=offs_m < dim, - other=0.0).to(tl.float32) + ) + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to( + tl.float32 + ) else: - states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) prev_seq_idx = seq_idx states = tl.exp(dA_cs) * states + new_states @@ -115,16 +118,15 @@ def _state_passing_fwd( assert dA_cumsum.shape == (nheads, nchunks, chunk_size) seqlen = seq_idx.shape[-1] out_dtype = states.dtype if out_dtype is None else out_dtype - out = torch.empty((nchunks, nheads, dim), - device=states.device, - dtype=out_dtype) + out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype) - initial_states_strides = ((initial_states.stride(0), - initial_states.stride(1), - initial_states.stride(2)) - if initial_states is not None else (0, 0, 0)) + initial_states_strides = ( + (initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) + if initial_states is not None + else (0, 0, 0) + ) - grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), nheads) + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( states_ptr=states, diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index eb4223ade5f0..32273d137eca 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -13,29 +13,35 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + causal_conv1d_fn, + causal_conv1d_update, +) from vllm.utils import direct_register_custom_op -from vllm.v1.attention.backends.short_conv_attn import ( - ShortConvAttentionMetadata) +from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata @CustomOp.register("short_conv") class ShortConv(MambaBase, CustomOp): - - def __init__(self, - config, - dim: int, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - prefix: str = ""): + def __init__( + self, + config, + dim: int, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ): super().__init__() self.config = config self.layer_idx = layer_idx @@ -72,7 +78,7 @@ def __init__(self, if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - self.kv_cache = (torch.tensor([]), ) + self.kv_cache = (torch.tensor([]),) self.model_config = model_config 
self.cache_config = cache_config @@ -121,8 +127,9 @@ def forward_cuda( B, C, x = BCx.chunk(3, dim=-1) - conv_weights = self.conv.weight.view(self.conv.weight.size(0), - self.conv.weight.size(2)) + conv_weights = self.conv.weight.view( + self.conv.weight.size(0), self.conv.weight.size(2) + ) if attn_metadata is None: # V1 profile run @@ -163,23 +170,26 @@ def forward_cuda( dim=0, ) query_start_loc_p = ( - attn_metadata.query_start_loc[-num_prefills - 1:] - - num_decodes if has_prefill else None) + attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes + if has_prefill + else None + ) conv_output_list = [] if has_prefill: Bx_p = (B_p * x_p).transpose(0, 1) - Bx = causal_conv1d_fn(Bx_p, - conv_weights, - self.conv.bias, - activation=None, - conv_states=conv_state, - has_initial_state=has_initial_states_p, - cache_indices=state_indices_tensor_p, - metadata=attn_metadata, - query_start_loc=query_start_loc_p).transpose( - 0, 1)[:num_prefill_tokens] + Bx = causal_conv1d_fn( + Bx_p, + conv_weights, + self.conv.bias, + activation=None, + conv_states=conv_state, + has_initial_state=has_initial_states_p, + cache_indices=state_indices_tensor_p, + metadata=attn_metadata, + query_start_loc=query_start_loc_p, + ).transpose(0, 1)[:num_prefill_tokens] y = C_p * Bx conv_output_list.append(y) @@ -192,7 +202,8 @@ def forward_cuda( conv_weights, self.conv.bias, activation=None, - conv_state_indices=state_indices_tensor_d) + conv_state_indices=state_indices_tensor_d, + ) y = C_d * Bx conv_output_list.insert(0, y) @@ -222,8 +233,8 @@ def mamba_type(self) -> str: return "short_conv" def get_attn_backend(self) -> type["AttentionBackend"]: - from vllm.v1.attention.backends.short_conv_attn import ( - ShortConvAttentionBackend) + from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend + return ShortConvAttentionBackend diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 66bf3823e191..b8e99226d13e 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -13,8 +13,8 @@ @dataclass class MLAModules: - """Modules used in MLA. - """ + """Modules used in MLA.""" + kv_a_layernorm: torch.nn.Module kv_b_proj: torch.nn.Module rotary_emb: torch.nn.Module @@ -36,7 +36,7 @@ class MultiHeadLatentAttention(CustomOp): because there is only one in-tree implementation in forward_native. TODO: implement this with a new PluggableLayer mechanism. - This class takes positions and hidden_states as input. + This class takes positions and hidden_states as input. The input tensors can either contain prefill tokens or decode tokens. 
The class does the following: @@ -125,12 +125,15 @@ def forward_native( kv_lora = None if self.q_lora_rank is not None: - assert self.fused_qkv_a_proj is not None, \ + assert self.fused_qkv_a_proj is not None, ( "fused_qkv_a_proj is required when q_lora_rank is not None" - assert self.q_a_layernorm is not None, \ + ) + assert self.q_a_layernorm is not None, ( "q_a_layernorm is required when q_lora_rank is not None" - assert self.q_b_proj is not None, \ + ) + assert self.q_b_proj is not None, ( "q_b_proj is required when q_lora_rank is not None" + ) qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] q_c, kv_lora = qkv_lora.split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], @@ -139,34 +142,35 @@ def forward_native( q_c = self.q_a_layernorm(q_c) q = self.q_b_proj(q_c)[0] else: - assert self.kv_a_proj_with_mqa is not None, \ + assert self.kv_a_proj_with_mqa is not None, ( "kv_a_proj_with_mqa is required when q_lora_rank is None" - assert self.q_proj is not None, \ + ) + assert self.q_proj is not None, ( "q_proj is required when q_lora_rank is None" + ) kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] q = self.q_proj(hidden_states)[0] - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], - dim=-1) + kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c) q = q.view(-1, self.num_heads, self.qk_head_dim) # Add head dim of 1 to k_pe k_pe = k_pe.unsqueeze(1) - q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( - positions, q[..., self.qk_nope_head_dim:], k_pe) + q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim :], k_pe + ) if self.indexer and self.is_sparse: - _topk_indices = self.indexer(hidden_states, q_c, positions, - self.rotary_emb) + _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb) attn_out = self.mla_attn( q, kv_c_normed, k_pe, - output_shape=(hidden_states.shape[0], - self.num_heads * self.v_head_dim)) + output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), + ) return self.o_proj(attn_out)[0] def forward_cuda(self, *args, **kwargs): diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 139011ce10be..979939ebc468 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -25,12 +25,14 @@ PoolingFn = Callable[ [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], - Union[torch.Tensor, list[torch.Tensor]]] + Union[torch.Tensor, list[torch.Tensor]], +] ClassifierFn = Callable[[torch.Tensor], torch.Tensor] class PoolingType(IntEnum): """Enumeration for different types of pooling methods.""" + LAST = 0 ALL = 1 CLS = 2 @@ -50,8 +52,7 @@ def from_config( pooler_config: PoolerConfig, ) -> "ResolvedPoolingConfig": assert pooler_config.pooling_type is not None - return cls(task=task, - pooling_type=PoolingType[pooler_config.pooling_type]) + return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type]) @dataclass(frozen=True) @@ -71,8 +72,9 @@ def for_encode(pooler_config: PoolerConfig): if pooler_config.pooling_type == "STEP": return StepPooler() - resolved_config = ResolvedPoolingConfig(task="encode", - pooling_type=PoolingType.ALL) + resolved_config = ResolvedPoolingConfig( + task="encode", pooling_type=PoolingType.ALL + ) return SimplePooler.from_config(resolved_config) @@ -129,10 +131,10 @@ def get_prompt_lens( return pooling_metadata.prompt_lens -def get_prompt_token_ids( - pooling_metadata: 
PoolingMetadata) -> list[torch.Tensor]: +def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: assert pooling_metadata.prompt_token_ids is not None, ( - "Please set `requires_token_ids=True` in `get_pooling_updates`") + "Please set `requires_token_ids=True` in `get_pooling_updates`" + ) return [ pooling_metadata.prompt_token_ids[i, :num] @@ -140,8 +142,7 @@ def get_prompt_token_ids( ] -def get_pooling_params( - pooling_metadata: PoolingMetadata) -> list[PoolingParams]: +def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]: pooling_params = pooling_metadata.pooling_params return pooling_params @@ -150,7 +151,8 @@ def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: pooling_params = get_pooling_params(pooling_metadata) tasks: list[PoolingTask] = [ - task for pooling_param in pooling_params + task + for pooling_param in pooling_params if (task := pooling_param.task) is not None ] assert len(pooling_params) == len(tasks) @@ -173,17 +175,22 @@ def get_classification_activation_function(config: PretrainedConfig): def get_cross_encoder_activation_function(config: PretrainedConfig): function_name: Optional[str] = None - if (hasattr(config, "sentence_transformers") - and "activation_fn" in config.sentence_transformers): + if ( + hasattr(config, "sentence_transformers") + and "activation_fn" in config.sentence_transformers + ): function_name = config.sentence_transformers["activation_fn"] - elif (hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None): + elif ( + hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None + ): function_name = config.sbert_ce_default_activation_function if function_name is not None: assert function_name.startswith("torch.nn.modules."), ( "Loading of activation functions is restricted to " - "torch.nn.modules for security reasons") + "torch.nn.modules for security reasons" + ) fn = resolve_obj_by_qualname(function_name)() return PoolerActivation.wraps(fn) @@ -191,7 +198,6 @@ def get_cross_encoder_activation_function(config: PretrainedConfig): class PoolingMethod(nn.Module, ABC): - @staticmethod def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": if pooling_type == PoolingType.LAST: @@ -230,7 +236,6 @@ def forward( class CLSPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} @@ -239,14 +244,14 @@ def forward_all( hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - assert not pooling_cursor.is_partial_prefill(), \ + assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with CLS pooling" + ) return hidden_states[pooling_cursor.first_token_indices_gpu] class LastPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} @@ -259,7 +264,6 @@ def forward_all( class AllPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode"} @@ -268,18 +272,17 @@ def forward_all( hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - - assert not pooling_cursor.is_partial_prefill(), \ + assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with ALL pooling" + ) hidden_states_lst = list( - hidden_states.split( - 
pooling_cursor.num_scheduled_tokens_cpu.tolist())) + hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()) + ) return [hidden_states_lst[i] for i in pooling_cursor.index] class MeanPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: return {"encode", "embed", "classify", "score"} @@ -288,12 +291,13 @@ def forward_all( hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, ) -> Union[list[torch.Tensor], torch.Tensor]: - - assert not pooling_cursor.is_partial_prefill(), \ + assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with MEAN pooling" + ) - prompt_lens = pooling_cursor.prompt_lens_cpu.to(hidden_states.device, - non_blocking=True) + prompt_lens = pooling_cursor.prompt_lens_cpu.to( + hidden_states.device, non_blocking=True + ) # Use float32 for torch.cumsum in MeanPool, # otherwise precision will be lost significantly. @@ -301,15 +305,15 @@ def forward_all( start_indices = pooling_cursor.first_token_indices_gpu end_indices = pooling_cursor.last_token_indices_gpu - return (cumsum[end_indices] - cumsum[start_indices] + - hidden_states[start_indices]) / prompt_lens.unsqueeze(1) + return ( + cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices] + ) / prompt_lens.unsqueeze(1) _T = TypeVar("_T", torch.Tensor, list[torch.Tensor]) class BasePoolerActivation(nn.Module, ABC): - @abstractmethod def forward(self, pooled_data: _T) -> _T: # shape: @@ -320,7 +324,6 @@ def forward(self, pooled_data: _T) -> _T: class PoolerActivation(BasePoolerActivation): - @staticmethod def wraps(module: nn.Module): if isinstance(module, nn.Identity): @@ -342,42 +345,42 @@ def forward(self, pooled_data: _T) -> _T: class PoolerIdentity(PoolerActivation): - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: return pooled_data class PoolerNormalize(PoolerActivation): - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: return F.normalize(pooled_data, p=2, dim=-1) class PoolerMultiLabelClassify(PoolerActivation): - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: return F.sigmoid(pooled_data) class PoolerClassify(PoolerActivation): - def __init__(self, *, static_num_labels: bool = True) -> None: super().__init__() if static_num_labels: vllm_config = get_current_vllm_config() - self.num_labels = getattr(vllm_config.model_config.hf_config, - "num_labels", 0) + self.num_labels = getattr( + vllm_config.model_config.hf_config, "num_labels", 0 + ) if self.num_labels == 0: - logger.warning("num_labels should be > 0 for classification" - "models, falling back to softmax. " - "Please check if the configuration is correct.") + logger.warning( + "num_labels should be > 0 for classification" + "models, falling back to softmax. " + "Please check if the configuration is correct." 
+ ) else: self.num_labels = None def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = (self.num_labels if self.num_labels is not None else - pooled_data.shape[-1]) + num_labels = ( + self.num_labels if self.num_labels is not None else pooled_data.shape[-1] + ) if num_labels < 2: return F.sigmoid(pooled_data) @@ -386,7 +389,6 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class LambdaPoolerActivation(PoolerActivation): - def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]): super().__init__() @@ -397,32 +399,35 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: class PoolerHead(nn.Module): - def __init__(self, activation: PoolerActivation) -> None: super().__init__() self.activation = activation - def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], - pooling_metadata: PoolingMetadata): - + def forward( + self, + pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata, + ): return self.activation(pooled_data) class EmbeddingPoolerHead(PoolerHead): - def __init__(self) -> None: super().__init__(activation=PoolerNormalize()) # Load ST projector if available vllm_config = get_current_vllm_config() - self.projector: Optional[nn.Module] = _load_st_projector( - vllm_config.model_config) if vllm_config else None + self.projector: Optional[nn.Module] = ( + _load_st_projector(vllm_config.model_config) if vllm_config else None + ) self.head_dtype = vllm_config.model_config.head_dtype - def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], - pooling_metadata: PoolingMetadata): - + def forward( + self, + pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata, + ): if isinstance(pooled_data, list): pooled_data = torch.stack(pooled_data) # pooled_data shape: [batchsize, hidden_dimension] @@ -437,14 +442,11 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_params = get_pooling_params(pooling_metadata) # for matryoshka representation - dimensions_list = [ - pooling_param.dimensions for pooling_param in pooling_params - ] + dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params] if any(d is not None for d in dimensions_list): # change the output dimension assert len(pooled_data) == len(dimensions_list) - if len(set(dimensions_list)) == 1 and not isinstance( - pooled_data, list): + if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list): # if all dimensions are the same d = dimensions_list[0] pooled_data = pooled_data[..., :d] @@ -470,16 +472,17 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], class RewardPoolerHead(PoolerHead): - def __init__(self) -> None: super().__init__(activation=PoolerClassify(static_num_labels=False)) vllm_config = get_current_vllm_config() self.head_dtype = vllm_config.model_config.head_dtype - def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], - pooling_metadata: PoolingMetadata): - + def forward( + self, + pooled_data: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata, + ): if isinstance(pooled_data, list): pooled_data = [p.to(self.head_dtype) for p in pooled_data] else: @@ -547,8 +550,9 @@ def forward( class StepPooler(Pooler): - - def __init__(self, ) -> None: + def __init__( + self, + ) -> None: super().__init__() self.pooling = AllPool() @@ -566,9 +570,9 @@ def extract_states( pooling_params = get_pooling_params(pooling_metadata) - for data, 
token_id, pooling_param in zip(pooled_data_lst, - prompt_token_ids, - pooling_params): + for data, token_id, pooling_param in zip( + pooled_data_lst, prompt_token_ids, pooling_params + ): step_tag_id = pooling_param.step_tag_id returned_token_ids = pooling_param.returned_token_ids @@ -627,8 +631,9 @@ def __init__( self.pooling = pooling self.classifier = classifier self.act_fn = act_fn or PoolerClassify() - self.logit_bias: Optional[ - float] = vllm_config.model_config.pooler_config.logit_bias + self.logit_bias: Optional[float] = ( + vllm_config.model_config.pooler_config.logit_bias + ) self.head_dtype = vllm_config.model_config.head_dtype def get_supported_tasks(self) -> Set[PoolingTask]: @@ -660,8 +665,7 @@ def forward( scores = self.act_fn(pooled_data) if flags[0] else pooled_data else: scores = [ - self.act_fn(vecs) if f else vecs - for vecs, f in zip(pooled_data, flags) + self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags) ] # scores shape: [batchsize, num_labels] @@ -678,7 +682,8 @@ def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None: if task not in pooler.get_supported_tasks(): raise ValueError( f"{pooler=} does not support {task=}. " - f"Supported tasks: {pooler.get_supported_tasks()}") + f"Supported tasks: {pooler.get_supported_tasks()}" + ) self.poolers_by_task = poolers_by_task @@ -701,12 +706,13 @@ def forward( if not (pooler := poolers_by_task.get(task)): raise ValueError( f"Unsupported task: {task} " - f"Supported tasks: {self.get_supported_tasks()}") + f"Supported tasks: {self.get_supported_tasks()}" + ) num_items = len(list(group)) group_output: PoolerOutput = pooler( hidden_states, - pooling_metadata[offset:offset + num_items], + pooling_metadata[offset : offset + num_items], ) outputs.extend(group_output) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 8cac47b5a39a..9d1c66e56e91 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -3,8 +3,7 @@ from typing import Literal, get_args -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig QuantizationMethods = Literal[ "awq", @@ -52,9 +51,13 @@ def register_quantization_config(quantization: str): quantization (str): The quantization method name. Examples: - >>> from vllm.model_executor.layers.quantization import register_quantization_config + >>> from vllm.model_executor.layers.quantization import ( + ... register_quantization_config, + ... ) >>> from vllm.model_executor.layers.quantization import get_quantization_config - >>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + >>> from vllm.model_executor.layers.quantization.base_config import ( + ... QuantizationConfig, + ... ) >>> >>> @register_quantization_config("my_quant") ... class MyQuantConfig(QuantizationConfig): @@ -67,10 +70,12 @@ def register_quantization_config(quantization: str): def _wrapper(quant_config_cls): if quantization in QUANTIZATION_METHODS: raise ValueError( - f"The quantization method `{quantization}` is already exists.") + f"The quantization method `{quantization}` is already exists." 
+ ) if not issubclass(quant_config_cls, QuantizationConfig): - raise ValueError("The quantization config must be a subclass of " - "`QuantizationConfig`.") + raise ValueError( + "The quantization config must be a subclass of `QuantizationConfig`." + ) _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls QUANTIZATION_METHODS.append(quantization) return quant_config_cls @@ -90,8 +95,9 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .awq_marlin import AWQMarlinConfig from .bitblas import BitBLASConfig from .bitsandbytes import BitsAndBytesConfig - from .compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsConfig) + from .compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, + ) from .deepspeedfp import DeepSpeedFPConfig from .experts_int8 import ExpertsInt8Config from .fbgemm_fp8 import FBGEMMFp8Config diff --git a/vllm/model_executor/layers/quantization/auto_round.py b/vllm/model_executor/layers/quantization/auto_round.py index eb7600af3371..b7ebc6f272db 100644 --- a/vllm/model_executor/layers/quantization/auto_round.py +++ b/vllm/model_executor/layers/quantization/auto_round.py @@ -7,10 +7,11 @@ import torch from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization import (QuantizationConfig, - QuantizationMethods) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -52,36 +53,45 @@ def __init__( ) -> None: super().__init__() if weight_bits not in self.SUPPORTED_BITS: - raise ValueError(f"Unsupported weight_bits: {weight_bits}, " - f"currently only support {self.SUPPORTED_BITS}") + raise ValueError( + f"Unsupported weight_bits: {weight_bits}, " + f"currently only support {self.SUPPORTED_BITS}" + ) if data_type not in self.SUPPORTED_DTYPES: raise ValueError( f"Unsupported data_type: {data_type}," - f" currently only support {self.SUPPORTED_DTYPES}") + f" currently only support {self.SUPPORTED_DTYPES}" + ) if packing_format not in self.SUPPORTED_FORMATS: raise ValueError( f"Unsupported packing_format: {packing_format}, " - f"currently only support {self.SUPPORTED_FORMATS}") + f"currently only support {self.SUPPORTED_FORMATS}" + ) if backend not in self.SUPPORTED_BACKENDS: raise ValueError( f"Unsupported backend: {backend}, " - f"currently only support {self.SUPPORTED_BACKENDS}") + f"currently only support {self.SUPPORTED_BACKENDS}" + ) self.weight_bits = weight_bits self.group_size = group_size self.sym = sym self.packing_format = packing_format - self.block_name_to_quantize = (block_name_to_quantize.split(",") if - isinstance(block_name_to_quantize, str) - else block_name_to_quantize) + self.block_name_to_quantize = ( + block_name_to_quantize.split(",") + if isinstance(block_name_to_quantize, str) + else block_name_to_quantize + ) self.extra_config = extra_config self.data_type = data_type self.backend = backend self.pack_factor = Fraction(32, weight_bits) def __repr__(self) -> str: - return (f"AutoRoundConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, sym={self.sym})") + return ( + f"AutoRoundConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, 
sym={self.sym})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -105,19 +115,18 @@ def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": weight_bits=cls.get_from_keys(config, ["bits"]), group_size=cls.get_from_keys(config, ["group_size"]), sym=cls.get_from_keys(config, ["sym"]), - packing_format=cls.get_from_keys_or(config, ["packing_format"], - "auto_round:auto_gptq"), + packing_format=cls.get_from_keys_or( + config, ["packing_format"], "auto_round:auto_gptq" + ), block_name_to_quantize=cls.get_from_keys_or( - config, ["block_name_to_quantize", "to_quant_block_names"], - None), + config, ["block_name_to_quantize", "to_quant_block_names"], None + ), extra_config=cls.get_from_keys_or(config, ["extra_config"], None), data_type=cls.get_from_keys_or(config, ["data_type"], "int"), - backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], - "auto"), + backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], "auto"), ) def get_layer_config(self, layer, layer_name: str): - def get_config(name: str, quantized: bool = True): cfg = self.extra_config.get(name, {}) if self.extra_config else {} return ( @@ -134,39 +143,38 @@ def get_config(name: str, quantized: bool = True): quantized = not isinstance(layer, ParallelLMHead) if self.block_name_to_quantize: quantized = any( - layer_name.startswith(name) - for name in self.block_name_to_quantize) + layer_name.startswith(name) for name in self.block_name_to_quantize + ) # 3. Handle fused MoE - if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower( - ): + if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower(): moe_configs = [ - get_config(name, quantized) for name in self.extra_config + get_config(name, quantized) + for name in self.extra_config if name.startswith(layer_name) ] if moe_configs: if len(set(moe_configs)) == 1: return moe_configs[0] - raise ValueError(f"Fused MoE layer '{layer_name}' requires " - f"consistent quant config for all sub-layers") + raise ValueError( + f"Fused MoE layer '{layer_name}' requires " + f"consistent quant config for all sub-layers" + ) # 4. Handle fused QKV or other patterns if self.extra_config: for fusion_key, sub_keys in self.packed_modules_mapping.items(): - if fusion_key in layer_name and layer_name.count( - fusion_key) == 1: + if fusion_key in layer_name and layer_name.count(fusion_key) == 1: sub_names = [ - layer_name.replace(fusion_key, sub_key) - for sub_key in sub_keys - ] - sub_configs = [ - get_config(name, quantized) for name in sub_names + layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys ] + sub_configs = [get_config(name, quantized) for name in sub_names] if len(set(sub_configs)) == 1: return sub_configs[0] raise ValueError( f"Fused module '{layer_name}' requires " - f"consistent quant config for {sub_names}") + f"consistent quant config for {sub_names}" + ) # 5. 
Fallback return get_config(layer_name, quantized) @@ -177,14 +185,17 @@ def check_quantized(self, weight_bits: int) -> bool: def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.block_name_to_quantize is not None: self.block_name_to_quantize = hf_to_vllm_mapper.apply_list( - self.block_name_to_quantize) + self.block_name_to_quantize + ) if self.extra_config is not None: self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config) def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, check_moe_marlin_supports_layer) + check_marlin_supported, + check_moe_marlin_supports_layer, + ) weight_bits, group_size, sym = self.get_layer_config(layer, prefix) if not self.check_quantized(weight_bits): @@ -206,19 +217,23 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): 4: scalar_types.uint4, 8: scalar_types.uint8, } - use_marlin = (weight_bits - in AWQ_TYPE_MAP) and check_marlin_supported( - AWQ_TYPE_MAP[weight_bits], group_size, not sym) + use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported( + AWQ_TYPE_MAP[weight_bits], group_size, not sym + ) if isinstance(layer, FusedMoE): use_marlin = use_marlin and check_moe_marlin_supports_layer( - layer, group_size) + layer, group_size + ) else: use_marlin = False if use_marlin: from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod) + AWQMarlinConfig, + AWQMarlinLinearMethod, + AWQMoEMethod, + ) quant_args_marlin = AWQMarlinConfig( weight_bits=weight_bits, @@ -230,7 +245,9 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): ) else: from vllm.model_executor.layers.quantization.awq import ( - AWQConfig, AWQLinearMethod) + AWQConfig, + AWQLinearMethod, + ) quant_args = AWQConfig( weight_bits=weight_bits, @@ -241,8 +258,7 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): if isinstance(layer, FusedMoE): if use_marlin: return AWQMoEMethod(quant_args_marlin, layer.moe_config) - from vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config) + from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config config = { "quant_method": "awq", @@ -251,8 +267,7 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): "zero_point": not sym, "lm_head": False, } - return MoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) + return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix) if isinstance(layer, (LinearBase, ParallelLMHead)): if use_marlin: @@ -261,13 +276,12 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): return AWQLinearMethod(quant_args) return None - def apply_gptq_quant_layer(self, - layer, - prefix: str, - backend: str = "auto"): + def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"): from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, check_moe_marlin_supports_layer) + check_marlin_supported, + check_moe_marlin_supports_layer, + ) weight_bits, group_size, sym = self.get_layer_config(layer, prefix) if not self.check_quantized(weight_bits): @@ -289,19 +303,21 @@ def apply_gptq_quant_layer(self, (4, True): scalar_types.uint4b8, (8, True): 
scalar_types.uint8b128, } - use_marlin = (weight_bits, - sym) in GPTQ_TYPE_MAP and check_marlin_supported( - GPTQ_TYPE_MAP[(weight_bits, sym)], - group_size, - has_zp=not sym) + use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported( + GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym + ) if isinstance(layer, FusedMoE): use_marlin = use_marlin and check_moe_marlin_supports_layer( - layer, group_size) + layer, group_size + ) else: use_marlin = False if use_marlin: from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod) + GPTQMarlinConfig, + GPTQMarlinLinearMethod, + GPTQMarlinMoEMethod, + ) quant_args_marlin = GPTQMarlinConfig( weight_bits=weight_bits, @@ -314,7 +330,9 @@ def apply_gptq_quant_layer(self, ) else: from vllm.model_executor.layers.quantization.gptq import ( - GPTQConfig, GPTQLinearMethod) + GPTQConfig, + GPTQLinearMethod, + ) quant_args = GPTQConfig( weight_bits=weight_bits, @@ -329,7 +347,8 @@ def apply_gptq_quant_layer(self, return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config) else: from vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config) + MoeWNA16Config, + ) config = { "quant_method": "gptq", @@ -339,7 +358,8 @@ def apply_gptq_quant_layer(self, "lm_head": False, } return MoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) + layer, prefix + ) if isinstance(layer, (LinearBase, ParallelLMHead)): if use_marlin: @@ -357,29 +377,36 @@ def apply_ipex_quant_layer(self, layer, prefix: str): else: return None from vllm.model_executor.layers.quantization.ipex_quant import ( - IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod) + IPEXAWQLinearMethod, + IPEXConfig, + IPEXGPTQLinearMethod, + ) if isinstance(layer, (LinearBase, ParallelLMHead)): if "awq" in self.packing_format: - config = IPEXConfig(method="awq", - weight_bits=weight_bits, - group_size=group_size) + config = IPEXConfig( + method="awq", weight_bits=weight_bits, group_size=group_size + ) return IPEXAWQLinearMethod(config) elif "gptq" in self.packing_format: - config = IPEXConfig(method="gptq", - weight_bits=weight_bits, - group_size=group_size) + config = IPEXConfig( + method="gptq", weight_bits=weight_bits, group_size=group_size + ) return IPEXGPTQLinearMethod(config) else: raise ValueError( f"ipex backend only supports awq " - f"and gtpq format,but got {self.packing_format}") + f"and gtpq format,but got {self.packing_format}" + ) else: return None def get_quant_method(self, layer: torch.nn.Module, prefix: str): - if (current_platform.is_cpu() or current_platform.is_xpu() - or self.backend == "ipex"): + if ( + current_platform.is_cpu() + or current_platform.is_xpu() + or self.backend == "ipex" + ): return self.apply_ipex_quant_layer(layer, prefix) if "gptq" in self.packing_format or "gptq" in self.backend: return self.apply_gptq_quant_layer(layer, prefix) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index af602eb9aca3..d4f667564848 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -8,13 +8,17 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + 
UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - PackedvLLMParameter) + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter logger = init_logger(__name__) @@ -41,14 +45,17 @@ def __init__( if self.weight_bits != 4: raise ValueError( "Currently, only 4-bit weight quantization is supported for " - f"AWQ, but got {self.weight_bits} bits.") + f"AWQ, but got {self.weight_bits} bits." + ) self.pack_factor = 32 // self.weight_bits def __repr__(self) -> str: - return (f"AWQConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"zero_point={self.zero_point}, " - f"modules_to_not_convert={self.modules_to_not_convert})") + return ( + f"AWQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) def get_name(self) -> QuantizationMethods: return "awq" @@ -75,7 +82,8 @@ def from_config(cls, config: dict[str, Any]) -> "AWQConfig": group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) zero_point = cls.get_from_keys(config, ["zero_point"]) modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None) + config, ["modules_to_not_convert"], None + ) return cls(weight_bits, group_size, zero_point, modules_to_not_convert) def get_quant_method( @@ -90,10 +98,12 @@ def get_quant_method( from .awq_marlin import AWQMarlinConfig, AWQMoEMethod from .moe_wna16 import MoeWNA16Config from .utils.marlin_utils import check_moe_marlin_supports_layer + if not check_moe_marlin_supports_layer(layer, self.group_size): logger.warning_once( f"Layer '{prefix}' is not supported by AWQMoeMarlin. " - "Falling back to Moe WNA16 kernels.") + "Falling back to Moe WNA16 kernels." + ) config = { "quant_method": "awq", "bits": self.weight_bits, @@ -102,7 +112,8 @@ def get_quant_method( "lm_head": False, } return MoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) + layer, prefix + ) marlin_compatible_config_dict = { "quant_method": "awq", "bits": self.weight_bits, @@ -112,7 +123,8 @@ def get_quant_method( "modules_to_not_convert": self.modules_to_not_convert, } awq_marlin_config = AWQMarlinConfig.from_config( - marlin_compatible_config_dict) + marlin_compatible_config_dict + ) return AWQMoEMethod(awq_marlin_config, layer.moe_config) return None @@ -131,11 +143,16 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Normalize group_size if self.quant_config.group_size != -1: group_size = self.quant_config.group_size @@ -146,14 +163,16 @@ def create_weights(self, layer: torch.nn.Module, raise ValueError( "The input size is not aligned with the quantized " "weight shape. 
This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." + ) output_size_per_partition = sum(output_partition_sizes) if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." + ) weight_loader = extra_weight_attrs.get("weight_loader") qweight = PackedvLLMParameter( @@ -166,7 +185,8 @@ def create_weights(self, layer: torch.nn.Module, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) num_groups = input_size_per_partition // group_size @@ -180,38 +200,40 @@ def create_weights(self, layer: torch.nn.Module, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - scales = GroupQuantScaleParameter(data=torch.empty( - num_groups, - output_size_per_partition, - dtype=params_dtype, - ), - input_dim=0, - output_dim=1, - weight_loader=weight_loader) + scales = GroupQuantScaleParameter( + data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("qzeros", qzeros) layer.register_parameter("scales", scales) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.qweight = torch.nn.Parameter(layer.qweight.data, - requires_grad=False) - layer.qzeros = torch.nn.Parameter(layer.qzeros.data, - requires_grad=False) - layer.scales = torch.nn.Parameter(layer.scales.data, - requires_grad=False) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: qweight = layer.qweight scales = layer.scales qzeros = layer.qzeros pack_factor = self.quant_config.pack_factor - out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) reshaped_x = x.reshape(-1, x.shape[-1]) # num_tokens >= threshold @@ -221,8 +243,7 @@ def apply(self, out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) out = torch.matmul(reshaped_x, out) else: - out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, - pack_factor) + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: out.add_(bias) return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 6bf6ea914651..5d142387d4d9 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -9,30 +9,46 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEQuantConfig) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from 
vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, - UnquantizedFusedMoEMethod) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod, - set_weight_attrs) + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, + set_weight_attrs, +) from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.awq import (AWQConfig, - is_layer_skipped_awq) +from vllm.model_executor.layers.quantization.awq import AWQConfig, is_layer_skipped_awq from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, - check_marlin_supports_layer, check_moe_marlin_supports_layer, - marlin_make_empty_g_idx, marlin_make_workspace_new, - marlin_moe_permute_scales, marlin_permute_bias, marlin_permute_scales, - moe_awq_to_marlin_zero_points, verify_marlin_supported, - verify_marlin_supports_shape) + apply_awq_marlin_linear, + awq_to_marlin_zero_points, + check_marlin_supported, + check_marlin_supports_layer, + check_moe_marlin_supports_layer, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_moe_permute_scales, + marlin_permute_bias, + marlin_permute_scales, + moe_awq_to_marlin_zero_points, + verify_marlin_supported, + verify_marlin_supports_shape, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - PackedvLLMParameter) +from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -48,10 +64,15 @@ class AWQMarlinConfig(QuantizationConfig): 8: scalar_types.uint8, } - def __init__(self, weight_bits: int, group_size: int, zero_point: bool, - lm_head_quantized: bool, - modules_to_not_convert: Optional[list[str]], - full_config: dict[str, Any]) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any], + ) -> None: super().__init__() self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size @@ -62,21 +83,25 @@ def __init__(self, weight_bits: int, group_size: int, zero_point: bool, self.full_config = full_config if self.weight_bits not in self.TYPE_MAP: - raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " - f"Supported num_bits = {self.TYPE_MAP.keys()}") + raise ValueError( + f"Unsupported num_bits = {self.weight_bits}. 
" + f"Supported num_bits = {self.TYPE_MAP.keys()}" + ) self.quant_type = self.TYPE_MAP[self.weight_bits] - verify_marlin_supported(self.quant_type, - group_size=self.group_size, - has_zp=self.zero_point) + verify_marlin_supported( + self.quant_type, group_size=self.group_size, has_zp=self.zero_point + ) def __repr__(self) -> str: - return (f"AWQMarlinConfig(quant_type={self.quant_type}, " - f"group_size={self.group_size}, " - f"zero_point={self.zero_point}, " - f"lm_head_quantized={self.lm_head_quantized}, " - f"modules_to_not_convert={self.modules_to_not_convert})") + return ( + f"AWQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point}, " + f"lm_head_quantized={self.lm_head_quantized}, " + f"modules_to_not_convert={self.modules_to_not_convert})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -99,37 +124,51 @@ def from_config(cls, config: dict[str, Any]) -> "AWQMarlinConfig": weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) zero_point = cls.get_from_keys(config, ["zero_point"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None) - return cls(weight_bits, group_size, zero_point, lm_head_quantized, - modules_to_not_convert, config) + config, ["modules_to_not_convert"], None + ) + return cls( + weight_bits, + group_size, + zero_point, + lm_head_quantized, + modules_to_not_convert, + config, + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) - is_valid_user_quant = (user_quant is None or user_quant == "marlin" - or user_quant == "awq_marlin") + is_valid_user_quant = ( + user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin" + ) if can_convert and is_valid_user_quant: - msg = ("The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name())) + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) logger.info(msg) return cls.get_name() if can_convert and user_quant == "awq": - logger.info("Detected that the model can run with awq_marlin" - ", however you specified quantization=awq explicitly," - " so forcing awq. Use quantization=awq_marlin for" - " faster inference") + logger.info( + "Detected that the model can run with awq_marlin" + ", however you specified quantization=awq explicitly," + " so forcing awq. Use quantization=awq_marlin for" + " faster inference" + ) return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): if is_layer_skipped_awq(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() # Check if the layer is supported by AWQMarlin. @@ -138,21 +177,25 @@ def get_quant_method(self, layer: torch.nn.Module, "Layer '%s' is not supported by AWQMarlin. 
Falling back to unoptimized AWQ kernels.", # noqa: E501 prefix, ) - return AWQConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + return AWQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) return AWQMarlinLinearMethod(self) elif isinstance(layer, FusedMoE): - from vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config) + from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config + if is_layer_skipped_awq( - prefix, getattr(self, "modules_to_not_convert", [])): + prefix, getattr(self, "modules_to_not_convert", []) + ): return UnquantizedFusedMoEMethod(layer.moe_config) if not check_moe_marlin_supports_layer(layer, self.group_size): logger.warning_once( f"Layer '{prefix}' is not supported by AWQMoeMarlin. " - "Falling back to Moe WNA16 kernels.") - return MoeWNA16Config.from_config( - self.full_config).get_quant_method(layer, prefix) + "Falling back to Moe WNA16 kernels." + ) + return MoeWNA16Config.from_config(self.full_config).get_quant_method( + layer, prefix + ) return AWQMoEMethod(self, layer.moe_config) return None @@ -171,15 +214,15 @@ def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]): return False # If we cannot find the info needed in the config, cannot convert. - if (num_bits is None or group_size is None or zero_point is None): + if num_bits is None or group_size is None or zero_point is None: return False if num_bits not in cls.TYPE_MAP: return False - return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], - group_size=group_size, - has_zp=zero_point) + return check_marlin_supported( + quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point + ) class AWQMarlinLinearMethod(LinearMethodBase): @@ -216,7 +259,8 @@ def create_weights( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=input_size, - group_size=group_size) + group_size=group_size, + ) qweight = PackedvLLMParameter( data=torch.empty( @@ -228,7 +272,8 @@ def create_weights( output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) num_groups = input_size_per_partition // group_size @@ -242,16 +287,19 @@ def create_weights( output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - scales = GroupQuantScaleParameter(data=torch.empty( - num_groups, - output_size_per_partition, - dtype=params_dtype, - ), - input_dim=0, - output_dim=1, - weight_loader=weight_loader) + scales = GroupQuantScaleParameter( + data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("qzeros", qzeros) @@ -267,12 +315,9 @@ def create_weights( # Here, we handle the repacking def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device = layer.qweight.device - layer.qweight = torch.nn.Parameter(layer.qweight.data, - requires_grad=False) - layer.qzeros = torch.nn.Parameter(layer.qzeros.data, - requires_grad=False) - layer.scales = torch.nn.Parameter(layer.scales.data, - requires_grad=False) + layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) # 
Allocate marlin workspace layer.workspace = marlin_make_workspace_new(device) @@ -282,7 +327,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qweight, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_parameter(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. @@ -290,7 +336,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scales, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_parameter(layer, "scales", marlin_scales) # Permute zero-points from AWQ format to marlin format. @@ -298,7 +345,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qzeros, size_k=layer.num_groups, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_parameter(layer, "qzeros", marlin_zp) # Not-used @@ -325,11 +373,11 @@ def apply( quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, - bias=bias) + bias=bias, + ) class AWQMoEMethod(FusedMoEMethodBase): - def __init__( self, quant_config: AWQMarlinConfig, @@ -341,75 +389,93 @@ def __init__( raise ValueError("AWQMoEMethod only supports 4bit now.") self.quant_type = scalar_types.uint4 - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - extra_weight_attrs.update({ - "is_transposed": - True, - "quant_method": - FusedMoeWeightScaleSupported.GROUP.value, - }) + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + extra_weight_attrs.update( + { + "is_transposed": True, + "quant_method": FusedMoeWeightScaleSupported.GROUP.value, + } + ) w13_qweight = Parameter( - torch.empty(num_experts, - hidden_size, - 2 * intermediate_size_per_partition // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + torch.empty( + num_experts, + hidden_size, + 2 * intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w13_qweight", w13_qweight) set_weight_attrs(w13_qweight, extra_weight_attrs) - w2_qweight = Parameter(torch.empty(num_experts, - intermediate_size_per_partition, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + w2_qweight = Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w2_qweight", w2_qweight) set_weight_attrs(w2_qweight, extra_weight_attrs) num_groups_w13 = hidden_size // self.quant_config.group_size - num_groups_w2 = (intermediate_size_per_partition // - self.quant_config.group_size) + num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. 
- w13_scales = Parameter(torch.empty(num_experts, - num_groups_w13, - intermediate_size_per_partition * 2, - dtype=params_dtype), - requires_grad=False) + w13_scales = Parameter( + torch.empty( + num_experts, + num_groups_w13, + intermediate_size_per_partition * 2, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) - w2_scales = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w2_scales = Parameter( + torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) # WEIGHT_ZERO_POINT # Allocate 2 zero points for w1 and w3 respectively. w13_qzeros = Parameter( - torch.empty(num_experts, - num_groups_w13, - 2 * intermediate_size_per_partition // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + torch.empty( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w13_qzeros", w13_qzeros) set_weight_attrs(w13_qzeros, extra_weight_attrs) - w2_qzeros = Parameter(torch.empty(num_experts, - num_groups_w2, - hidden_size // - self.quant_config.pack_factor, - dtype=torch.int32), - requires_grad=False) + w2_qzeros = Parameter( + torch.empty( + num_experts, + num_groups_w2, + hidden_size // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) @@ -469,14 +535,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_qzeros, size_k=layer.w13_qzeros.shape[1], size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.weight_bits, + ) replace_parameter(layer, "w13_qzeros", marlin_w13_zp) marlin_w2_zp = moe_awq_to_marlin_zero_points( layer.w2_qzeros, size_k=layer.w2_qzeros.shape[1], size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, - num_bits=self.quant_config.weight_bits) + num_bits=self.quant_config.weight_bits, + ) replace_parameter(layer, "w2_qzeros", marlin_w2_zp) if hasattr(layer, "w13_bias") and layer.w13_bias is not None: @@ -486,7 +554,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return None def apply( @@ -515,8 +584,7 @@ def apply( assert self.fused_experts is None if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `AWQMoEMethod` yet.") + raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.") assert activation == "silu", "Only SiLU activation is supported." 
@@ -532,7 +600,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return torch.ops.vllm.fused_marlin_moe( x, @@ -551,4 +620,5 @@ def apply( expert_map=expert_map, w1_zeros=layer.w13_qzeros, w2_zeros=layer.w2_qzeros, - workspace=layer.workspace) + workspace=layer.workspace, + ) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index 2e8894436a98..67b4dbbfd4d8 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -10,15 +10,16 @@ @triton.jit def awq_dequantize_kernel( - qweight_ptr, # quantized matrix - scales_ptr, # scales, per group - zeros_ptr, # zeros, per group - group_size, # Should always be one of the supported group sizes - result_ptr, # Output matrix - num_cols, # input num cols in qweight - num_rows, # input num rows in qweight - BLOCK_SIZE_X: tl.constexpr, - BLOCK_SIZE_Y: tl.constexpr): + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # Should always be one of the supported group sizes + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr, +): # Set up the pids. pid_x = tl.program_id(axis=0) pid_y = tl.program_id(axis=1) @@ -35,10 +36,10 @@ def awq_dequantize_kernel( # Compute offsets and masks for result output ptr. result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) - result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange( - 0, BLOCK_SIZE_X * 8) - result_offsets = (8 * num_cols * result_offsets_y[:, None] + - result_offsets_x[None, :]) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + result_offsets = ( + 8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :] + ) result_masks_y = result_offsets_y < num_rows result_masks_x = result_offsets_x < num_cols * 8 @@ -52,8 +53,9 @@ def awq_dequantize_kernel( # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] # that will map given indices to the correct order. - reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + - tl.arange(0, 4)[:, None]).reshape(8) + reverse_awq_order_tensor = ( + (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] + ).reshape(8) # Use this to compute a set of shifts that can be used to unpack and # reorder the values in iweights and zeros. @@ -85,10 +87,8 @@ def awq_dequantize_kernel( # Compute scale offsets and masks. 
scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) - scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + - tl.arange(0, BLOCK_SIZE_X * 8)) - scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + - scale_offsets_x[None, :]) + scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :] scale_masks_y = scale_offsets_y < num_rows // group_size scale_masks_x = scale_offsets_x < num_cols * 8 scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] @@ -106,10 +106,21 @@ def awq_dequantize_kernel( @triton.jit -def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, - group_size, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - SPLIT_K: tl.constexpr): +def awq_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + zeros_ptr, + scales_ptr, + M, + N, + K, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): pid = tl.program_id(axis=0) pid_z = tl.program_id(1) @@ -128,18 +139,17 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, # (BLOCK_SIZE_M, BLOCK_SIZE_N)) # accumulator = accumulator & 0x0 # accumulator = accumulator.to(accumulator_dtype) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), - dtype=accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] # that will map given indices to the correct order. - reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + - tl.arange(0, 4)[:, None]).reshape(8) + reverse_awq_order_tensor = ( + (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] + ).reshape(8) # Create the necessary shifts to use to unpack. shifts = reverse_awq_order_tensor * 4 - shifts = tl.broadcast_to(shifts[None, :], - (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) # Offsets and masks. @@ -178,8 +188,8 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, # Dequantize b. 
offsets_szk = ( - (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + - tl.arange(0, 1)) + BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K + ) // group_size + tl.arange(0, 1) offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] masks_zk = offsets_szk < K // group_size masks_z = masks_zk[:, None] & masks_zn[None, :] @@ -220,11 +230,13 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, # qweights - [K , M // 8], int32 # scales - [K // G, M ], float16 # zeros - [K // G, M // 8], int32 -def awq_dequantize_triton(qweight: torch.Tensor, - scales: torch.Tensor, - zeros: torch.Tensor, - block_size_x: int = 32, - block_size_y: int = 32) -> torch.Tensor: +def awq_dequantize_triton( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32, +) -> torch.Tensor: K = qweight.shape[0] M = scales.shape[1] group_size = qweight.shape[0] // scales.shape[0] @@ -238,27 +250,31 @@ def awq_dequantize_triton(qweight: torch.Tensor, # Result tensor: # number of rows = same as input tensor # number of cols = 8 x input tensor num cols - result = torch.empty(qweight.shape[0], - qweight.shape[1] * 8, - device=qweight.device, - dtype=scales.dtype) + result = torch.empty( + qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype, + ) Y = qweight.shape[0] # num rows X = qweight.shape[1] # num cols grid = lambda META: ( - triton.cdiv(X, META['BLOCK_SIZE_X']), - triton.cdiv(Y, META['BLOCK_SIZE_Y']), + triton.cdiv(X, META["BLOCK_SIZE_X"]), + triton.cdiv(Y, META["BLOCK_SIZE_Y"]), + ) + awq_dequantize_kernel[grid]( + qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y, ) - awq_dequantize_kernel[grid](qweight, - scales, - zeros, - group_size, - result, - X, - Y, - BLOCK_SIZE_X=block_size_x, - BLOCK_SIZE_Y=block_size_y) return result @@ -268,14 +284,16 @@ def awq_dequantize_triton(qweight: torch.Tensor, # qzeros - [K // G, N // 8] # scales - [K // G, N] # split_k_iters - parallelism along K-dimension, int, power of 2. 
-def awq_gemm_triton(input: torch.Tensor, - qweight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - split_k_iters: int, - block_size_m: int = 32, - block_size_n: int = 32, - block_size_k: int = 32) -> torch.Tensor: +def awq_gemm_triton( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, +) -> torch.Tensor: M, K = input.shape N = qweight.shape[1] * 8 group_size = qweight.shape[0] // qzeros.shape[0] @@ -290,30 +308,29 @@ def awq_gemm_triton(input: torch.Tensor, assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( - N, META['BLOCK_SIZE_N']), + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), split_k_iters, ) - result = torch.zeros((split_k_iters, M, N), - dtype=scales.dtype, - device=input.device) + result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device) # A = input, B = qweight, C = result # A = M x K, B = K x N, C = M x N - awq_gemm_kernel[grid](input, - qweight, - result, - qzeros, - scales, - M, - N, - K, - group_size, - BLOCK_SIZE_M=block_size_m, - BLOCK_SIZE_N=block_size_n, - BLOCK_SIZE_K=block_size_k, - SPLIT_K=split_k_iters) + awq_gemm_kernel[grid]( + input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + SPLIT_K=split_k_iters, + ) result = result.sum(0) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 807a9866a18b..26f5e8bb6c7d 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -19,8 +19,9 @@ class QuantizeMethodBase(ABC): """Base class for different quantized methods.""" @abstractmethod - def create_weights(self, layer: torch.nn.Module, *weight_args, - **extra_weight_attrs): + def create_weights( + self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs + ): """Create weights for a layer. The weights will be set as attributes of the layer.""" @@ -34,8 +35,7 @@ def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: raise NotImplementedError # Not required functions - def embedding(self, layer: torch.nn.Module, *args, - **kwargs) -> torch.Tensor: + def embedding(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: """Gather embeddings in the layer based on indices in the input tensor. Expects create_weights to have been called before on the layer.""" @@ -49,19 +49,16 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: return -def method_has_implemented_embedding( - method_class: type[QuantizeMethodBase]) -> bool: +def method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> bool: """ Not all quant methods have embedding implemented, so we need to check that it exists for our given method. We check this by making sure the function has been changed from the base implementation. 
""" - base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", - None) + base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None) class_embedding = inspect.getattr_static(method_class, "embedding", None) - return (class_embedding is not None - and class_embedding is not base_embedding) + return class_embedding is not None and class_embedding is not base_embedding class QuantizationConfig(ABC): @@ -107,12 +104,13 @@ def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig": @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: """ - Detects if this quantization method can support a given checkpoint - format by overriding the user specified quantization method -- - this method should only be overwritten by subclasses in exceptional - circumstances + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances """ return None @@ -122,12 +120,12 @@ def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any: for key in keys: if key in config: return config[key] - raise ValueError(f"Cannot find any of {keys} in the model's " - "quantization config.") + raise ValueError( + f"Cannot find any of {keys} in the model's quantization config." + ) @staticmethod - def get_from_keys_or(config: dict[str, Any], keys: list[str], - default: Any) -> Any: + def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any: """Get an optional value from the model's quantization config.""" try: return QuantizationConfig.get_from_keys(config, keys) @@ -135,10 +133,11 @@ def get_from_keys_or(config: dict[str, Any], keys: list[str], return default @abstractmethod - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional[QuantizeMethodBase]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: """Get the quantize method to use for the quantized layer. - + Args: layer: The layer for the quant method. 
prefix: The full name of the layer in the state dict @@ -152,7 +151,8 @@ def get_cache_scale(self, name: str) -> Optional[str]: return None def apply_vllm_mapper( # noqa: B027 - self, hf_to_vllm_mapper: "WeightsMapper"): + self, hf_to_vllm_mapper: "WeightsMapper" + ): """ Interface for models to update module names referenced in quantization configs in order to reflect the vllm model structure diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py index 81e51f4a4358..d2e0582be197 100644 --- a/vllm/model_executor/layers/quantization/bitblas.py +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -7,16 +7,23 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import (QuantizationConfig, - QuantizationMethods) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS, - BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION) + BITBLAS_OPTIMIZE_FEATURES, + BITBLAS_SUPPORTED_NUM_BITS, + BITBLAS_SUPPORTED_SYM, + MINIMUM_BITBLAS_VERSION, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -27,6 +34,7 @@ class BitBLASConfig(QuantizationConfig): Reference: https://github.com/Microsoft/BitBLAS """ + TORCH_DTYPE = torch.float16 STORAGE_DTYPE = "int8" # assume int8 storage TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) @@ -45,11 +53,14 @@ def __init__( ) -> None: try: import bitblas + if version.parse(bitblas.__version__) < version.parse( - MINIMUM_BITBLAS_VERSION): + MINIMUM_BITBLAS_VERSION + ): raise ImportError( "bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError as e: bitblas_import_exception = e raise ValueError( @@ -77,12 +88,14 @@ def __init__( raise ValueError( f"BitBLAS does not support weight_bits = {self.weight_bits}. " f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} " - "are supported.") + "are supported." + ) if self.is_sym not in BITBLAS_SUPPORTED_SYM: raise ValueError( f"BitBLAS does not support is_sym = {self.is_sym}. " - f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.") + f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported." 
+ ) storage_dtype = self.STORAGE_DTYPE storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) @@ -97,11 +110,13 @@ def __init__( self.zeros_mode = self.ZEROS_MODE def __repr__(self) -> str: - return (f"BitBLASConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act}, " - f"is_sym={self.is_sym}, " - f"quant_method={self.quant_method})") + return ( + f"BitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -121,9 +136,9 @@ def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @staticmethod - def get_from_keys(config: dict[str, Any], - keys: list[str], - default: Any = None) -> Any: + def get_from_keys( + config: dict[str, Any], keys: list[str], default: Any = None + ) -> Any: """Get a value from the model's quantization config.""" for key in keys: if key in config: @@ -137,34 +152,40 @@ def from_config(cls, config: dict[str, Any]) -> "BitBLASConfig": desc_act = cls.get_from_keys(config, ["desc_act"], False) is_sym = cls.get_from_keys(config, ["sym"], False) quant_method = cls.get_from_keys(config, ["quant_method"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(weight_bits, group_size, desc_act, is_sym, quant_method, - lm_head_quantized) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls( + weight_bits, group_size, desc_act, is_sym, quant_method, lm_head_quantized + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: # compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq <=0.7.1 is_bitblas_format: bool - is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas" - or hf_quant_cfg.get("is_bitblas_format", False)) + is_bitblas_format = hf_quant_cfg.get( + "checkpoint_format" + ) == "bitblas" or hf_quant_cfg.get("is_bitblas_format", False) - is_valid_user_quant = (user_quant is None or user_quant == "gptq" - or user_quant == "bitblas") + is_valid_user_quant = ( + user_quant is None or user_quant == "gptq" or user_quant == "bitblas" + ) if is_bitblas_format and is_valid_user_quant: - msg = ("The model is serialized in {} format. Using {} kernel.". - format(cls.get_name(), cls.get_name())) + msg = "The model is serialized in {} format. Using {} kernel.".format( + cls.get_name(), cls.get_name() + ) logger.info(msg) return cls.get_name() return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["BitBLASLinearMethod"]: - if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) - and self.lm_head_quantized): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["BitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): return BitBLASLinearMethod(self) return None @@ -175,6 +196,7 @@ class BitBLASLinearMethod(LinearMethodBase): Args: quant_config: The BitBLAS quantization config. """ + # USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS # Instead of BITBLAS_OPTIMIZE_FEATURES # If you want to high contiguous batching @@ -204,7 +226,7 @@ def create_weights_gptq( ) -> None: """Creates quantized weights for use in linear operations. 
- The function initializes and returns a dictionary containing quantized + The function initializes and returns a dictionary containing quantized weights, scales, and zeros for performing quantized matrix multiplication operations. @@ -213,11 +235,11 @@ def create_weights_gptq( output_partition_sizes: List of output partition sizes. input_size: The total size of the input (unused). output_size: The total size of the output (unused). - params_dtype: + params_dtype: The data type of the parameters (expected to be torch.float16). Returns: - A dictionary containing the quantized weights ('qweight'), + A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). Raises: @@ -229,17 +251,19 @@ def create_weights_gptq( weight_loader = extra_weight_attrs["weight_loader"] if params_dtype not in self.quant_config.get_supported_act_dtypes(): - raise ValueError("Parameter data type must be torch.float16, " - f"but got {params_dtype}") + raise ValueError( + f"Parameter data type must be torch.float16, but got {params_dtype}" + ) group_size = self.quant_config.group_size if group_size is None: group_size = -1 # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) - if (group_size != -1 and input_size_per_partition % group_size != 0): + if group_size != -1 and input_size_per_partition % group_size != 0: raise ValueError( f"Input size per partition ({input_size_per_partition}) must " - f"be divisible by group size ({group_size}).") + f"be divisible by group size ({group_size})." + ) # Initialize or retrieve the BitBLAS matrix multiplication operator. self._configure_bitblas_matmul( @@ -265,34 +289,33 @@ def create_weights_gptq( output_dim=0, packed_dim=1, packed_factor=self.quant_config.pack_factor, - bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2] - if self.bitblas_matmul.propagate_b else None), + bitblas_tile_size=( + self.bitblas_matmul.retrieve_weight_shape()[-2] + if self.bitblas_matmul.propagate_b + else None + ), weight_loader=weight_loader, ) # Compute the number of input groups for channel-wise quantization. - input_groups = (1 if group_size == -1 else input_size_per_partition // - group_size) + input_groups = 1 if group_size == -1 else input_size_per_partition // group_size # Initialize scales and zeros for the quantized weights. weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( output_size_per_partition, input_groups, device="cuda", dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if input_groups == 1: - scales = ChannelQuantScaleParameter(output_dim=0, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) else: - scales = GroupQuantScaleParameter(output_dim=0, - input_dim=1, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=0, input_dim=1, **weight_scale_args + ) if self.quant_config.zeros_mode == "quantized": zeros = PackedvLLMParameter( @@ -312,17 +335,22 @@ def create_weights_gptq( else: zeros = BasevLLMParameter( - torch.empty(output_size_per_partition, - input_groups, - device="cuda", - dtype=params_dtype), + torch.empty( + output_size_per_partition, + input_groups, + device="cuda", + dtype=params_dtype, + ), weight_loader=weight_loader, ) # Set attributes to indicate how scales and zeros are applied. 
- set_weight_attrs(zeros, { - "input_dim": None if input_groups == 1 else 1, - "output_dim": 0, - }) + set_weight_attrs( + zeros, + { + "input_dim": None if input_groups == 1 else 1, + "output_dim": 0, + }, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("scales", scales) @@ -339,13 +367,19 @@ def create_weights( **extra_weight_attrs, ): if self.quant_config.quant_method == "gptq": - return self.create_weights_gptq(layer, input_size_per_partition, - output_partition_sizes, input_size, - output_size, params_dtype, - **extra_weight_attrs) + return self.create_weights_gptq( + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ) else: raise ValueError( - f"Unsupported quant_method {self.quant_config.quant_method}") + f"Unsupported quant_method {self.quant_config.quant_method}" + ) def _configure_bitblas_matmul( self, @@ -359,6 +393,7 @@ def _configure_bitblas_matmul( out_dtype="float16", ): from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] with_scaling = False @@ -374,7 +409,8 @@ def _configure_bitblas_matmul( W_dtype = f"int{bits}" else: raise ValueError( - f"Unsupported quant_method {self.quant_config.quant_method}") + f"Unsupported quant_method {self.quant_config.quant_method}" + ) matmul_config = MatmulConfig( N=outfeatures, @@ -392,38 +428,40 @@ def _configure_bitblas_matmul( zeros_mode=zeros_mode, ) self.bitblas_matmul = self._get_or_create_bitblas_operator( - matmul_config, enable_tuning) + matmul_config, enable_tuning + ) def _get_or_create_bitblas_operator(self, config, enable_tuning): from bitblas import Matmul, auto_detect_nvidia_target from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() BITBLAS_TARGET = auto_detect_nvidia_target() if global_operator_cache.size() == 0: - global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, - BITBLAS_TARGET) + global_operator_cache.load_from_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: - bitblas_matmul = Matmul(config, - target=BITBLAS_TARGET, - enable_tuning=False) + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) if enable_tuning: - TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...") + TUNING_MESSAGE = f"BitBLAS Operator {config} is tuning ..." logger.info(TUNING_MESSAGE) bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) TUNED_MESSAGE = ( - f"BitBLAS Operator {config} tuned and saved to database.") + f"BitBLAS Operator {config} tuned and saved to database." + ) logger.info(TUNED_MESSAGE) else: _message = f"BitBLAS Operator {config} created." logger.info(_message) else: - _message = ( - f"BitBLAS Operator {config} found in global_operator_cache.") + _message = f"BitBLAS Operator {config} found in global_operator_cache." 
logger.info(_message) return bitblas_matmul @@ -444,7 +482,7 @@ def apply_gptq( else: output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros) - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) if bias is not None: output.add_(bias) # In-place add @@ -460,4 +498,5 @@ def apply( return self.apply_gptq(*args, **kwargs) else: raise ValueError( - f"Unsupported quant_method {self.quant_config.quant_method}") + f"Unsupported quant_method {self.quant_config.quant_method}" + ) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 7b7011cb06d3..80ed121bd85b 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -6,15 +6,21 @@ import torch from packaging import version -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEQuantConfig) -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod, - set_weight_attrs) -from vllm.model_executor.layers.quantization import (QuantizationConfig, - QuantizationMethods) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, + set_weight_attrs, +) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -51,16 +57,19 @@ def __init__( self.llm_int8_threshold = llm_int8_threshold if self.bnb_4bit_quant_storage not in ["uint8"]: - raise ValueError("Unsupported bnb_4bit_quant_storage: " - f"{self.bnb_4bit_quant_storage}") + raise ValueError( + f"Unsupported bnb_4bit_quant_storage: {self.bnb_4bit_quant_storage}" + ) def __repr__(self) -> str: - return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, " - f"load_in_4bit={self.load_in_4bit}, " - f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, " - f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, " - f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, " - f"llm_int8_skip_modules={self.llm_int8_skip_modules})") + return ( + f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, " + f"load_in_4bit={self.load_in_4bit}, " + f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, " + f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, " + f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, " + f"llm_int8_skip_modules={self.llm_int8_skip_modules})" + ) @classmethod def get_name(self) -> QuantizationMethods: @@ -80,7 +89,6 @@ def get_config_filenames() -> list[str]: @classmethod def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig": - def get_safe_value(config, keys, default_value=None): try: value = cls.get_from_keys(config, keys) @@ -88,30 +96,32 @@ def get_safe_value(config, keys, default_value=None): except ValueError: return default_value - load_in_8bit = get_safe_value(config, ["load_in_8bit"], - default_value=False) - load_in_4bit = get_safe_value(config, ["load_in_4bit"], - default_value=True) - bnb_4bit_compute_dtype = get_safe_value(config, - ["bnb_4bit_compute_dtype"], - default_value="float32") - 
bnb_4bit_quant_storage = get_safe_value(config, - ["bnb_4bit_quant_storage"], - default_value="uint8") - bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], - default_value="fp4") + load_in_8bit = get_safe_value(config, ["load_in_8bit"], default_value=False) + load_in_4bit = get_safe_value(config, ["load_in_4bit"], default_value=True) + bnb_4bit_compute_dtype = get_safe_value( + config, ["bnb_4bit_compute_dtype"], default_value="float32" + ) + bnb_4bit_quant_storage = get_safe_value( + config, ["bnb_4bit_quant_storage"], default_value="uint8" + ) + bnb_4bit_quant_type = get_safe_value( + config, ["bnb_4bit_quant_type"], default_value="fp4" + ) bnb_4bit_use_double_quant = get_safe_value( - config, ["bnb_4bit_use_double_quant"], default_value=False) + config, ["bnb_4bit_use_double_quant"], default_value=False + ) llm_int8_enable_fp32_cpu_offload = get_safe_value( - config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False) - llm_int8_has_fp16_weight = get_safe_value(config, - ["llm_int8_has_fp16_weight"], - default_value=False) - llm_int8_skip_modules = get_safe_value(config, - ["llm_int8_skip_modules"], - default_value=[]) - llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"], - default_value=6.0) + config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False + ) + llm_int8_has_fp16_weight = get_safe_value( + config, ["llm_int8_has_fp16_weight"], default_value=False + ) + llm_int8_skip_modules = get_safe_value( + config, ["llm_int8_skip_modules"], default_value=[] + ) + llm_int8_threshold = get_safe_value( + config, ["llm_int8_threshold"], default_value=6.0 + ) return cls( load_in_8bit=load_in_8bit, @@ -123,7 +133,8 @@ def get_safe_value(config, keys, default_value=None): llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload, llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, llm_int8_skip_modules=llm_int8_skip_modules, - llm_int8_threshold=llm_int8_threshold) + llm_int8_threshold=llm_int8_threshold, + ) def get_quant_method( self, layer: torch.nn.Module, prefix: str @@ -139,15 +150,15 @@ def get_quant_method( def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]): # Split the prefix into its dot-separated components - components = prefix.split('.') + components = prefix.split(".") # Check if any of the skip modules exactly matches any component - substr_check = any(module_name in components - for module_name in llm_int8_skip_modules) + substr_check = any( + module_name in components for module_name in llm_int8_skip_modules + ) # Allow certain layers to not be quantized - set_components = set(".".join(components[:i + 1]) - for i in range(len(components))) + set_components = set(".".join(components[: i + 1]) for i in range(len(components))) set_llm_int8_skip_modules = set(llm_int8_skip_modules) prefix_check = len(set_llm_int8_skip_modules & set_components) != 0 @@ -171,39 +182,53 @@ class BitsAndBytesLinearMethod(LinearMethodBase): def __init__(self, quant_config: BitsAndBytesConfig): try: import bitsandbytes - if version.parse( - bitsandbytes.__version__) < version.parse("0.46.1"): - raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.46.1.") + + if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"): + raise ImportError( + "bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.46.1." 
+ ) except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.46.1 via " - "`pip install bitsandbytes>=0.46.1` to use " - "bitsandbytes quantizer.") from err + raise ImportError( + "Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " + "bitsandbytes quantizer." + ) from err self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): from bitsandbytes.nn import Int8Params def create_qweight_for_8bit(): qweight = Int8Params( - data=torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=torch.int8), + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8, + ), has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight, - requires_grad=False) + requires_grad=False, + ) set_weight_attrs( - qweight, { + qweight, + { "input_dim": 0, "output_dim": 0, "pack_factor": 1, "use_bitsandbytes_8bit": True, - "generation": 0 - }) + "generation": 0, + }, + ) return qweight def create_qweight_for_4bit(): @@ -212,20 +237,22 @@ def create_qweight_for_4bit(): total_size = input_size_per_partition * sum(output_partition_sizes) if total_size % quant_ratio != 0: raise ValueError( - "The input size is not aligned with the quantized " - "weight shape.") + "The input size is not aligned with the quantized weight shape." + ) - qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio, - 1, - dtype=torch.uint8), - requires_grad=False) + qweight = torch.nn.Parameter( + torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8), + requires_grad=False, + ) set_weight_attrs( - qweight, { + qweight, + { "input_dim": 0, "output_dim": 0, "pack_factor": quant_ratio, - "use_bitsandbytes_4bit": True - }) + "use_bitsandbytes_4bit": True, + }, + ) return qweight if self.quant_config.load_in_8bit: @@ -237,22 +264,23 @@ def create_qweight_for_4bit(): layer.register_parameter("weight", qweight) set_weight_attrs(qweight, extra_weight_attrs) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.quant_config.load_in_8bit: return self._apply_8bit_weight(layer, x, bias) else: return self._apply_4bit_weight(layer, x, bias) def _apply_8bit_weight( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: # only load the bitsandbytes module when needed from bitsandbytes import MatmulLtState, matmul @@ -272,11 +300,9 @@ def _apply_8bit_weight( out_dim_0 = x.shape[0] out_dim_1 = sum( - [quant_state[1].shape[0] for quant_state in quant_states.items()]) - out = torch.empty(out_dim_0, - out_dim_1, - dtype=torch.float16, - device=x.device) + [quant_state[1].shape[0] for quant_state in quant_states.items()] + ) + out = torch.empty(out_dim_0, out_dim_1, dtype=torch.float16, device=x.device) current_index = 0 for i in range(len(quant_states)): @@ -286,33 
+312,36 @@ def _apply_8bit_weight( # create new matmul_states if generation == 0 or generation == 1: matmul_states[i] = MatmulLtState() - matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]] + matmul_states[i].CB = qweight[offsets[i] : offsets[i + 1]] matmul_states[i].SCB = quant_states[i].to(x.device) - matmul_states[i].threshold = ( - self.quant_config.llm_int8_threshold) - matmul_states[i].has_fp16_weights = ( - self.quant_config.llm_int8_has_fp16_weight) + matmul_states[i].threshold = self.quant_config.llm_int8_threshold + matmul_states[ + i + ].has_fp16_weights = self.quant_config.llm_int8_has_fp16_weight matmul_states[i].is_training = False - if matmul_states[i].threshold > 0.0 and not matmul_states[ - i].has_fp16_weights: + if ( + matmul_states[i].threshold > 0.0 + and not matmul_states[i].has_fp16_weights + ): matmul_states[i].use_pool = True new_x = bf_x.unsqueeze(0) - out[:, current_index:current_index + output_size] = matmul( - new_x, - qweight[offsets[i]:offsets[i + 1]], - state=matmul_states[i]) + out[:, current_index : current_index + output_size] = matmul( + new_x, qweight[offsets[i] : offsets[i + 1]], state=matmul_states[i] + ) current_index += output_size # only update the matmul_states if it is not profile_run - if (generation > 0 - and not self.quant_config.llm_int8_has_fp16_weight - and matmul_states[i].CB is not None - and matmul_states[i].CxB is not None): + if ( + generation > 0 + and not self.quant_config.llm_int8_has_fp16_weight + and matmul_states[i].CB is not None + and matmul_states[i].CxB is not None + ): del matmul_states[i].CB - qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB + qweight[offsets[i] : offsets[i + 1]] = matmul_states[i].CxB out = out.to(original_type) @@ -327,11 +356,11 @@ def _apply_8bit_weight( return out def _apply_4bit_weight( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: original_type = x.dtype original_shape = x.shape reshape_after_matmul = False @@ -346,11 +375,9 @@ def _apply_4bit_weight( out_dim_0 = x.shape[0] out_dim_1 = sum( - [quant_state[1].shape[0] for quant_state in quant_states.items()]) - out = torch.empty(out_dim_0, - out_dim_1, - dtype=torch.bfloat16, - device=x.device) + [quant_state[1].shape[0] for quant_state in quant_states.items()] + ) + out = torch.empty(out_dim_0, out_dim_1, dtype=torch.bfloat16, device=x.device) apply_bnb_4bit(bf_x, qweight, offsets, out) out = out.to(original_type) @@ -371,6 +398,7 @@ def _apply_bnb_4bit( ) -> None: # only load the bitsandbytes module when needed from bitsandbytes import matmul_4bit + quant_states = weight.bnb_quant_state current_index = 0 for i in range(len(quant_states)): @@ -379,8 +407,9 @@ def _apply_bnb_4bit( # matmul_4bit(..., out = ...). Infeasible now due to the bug # https://github.com/TimDettmers/bitsandbytes/issues/1235. # Need to change after the bug is fixed. 
- out[:, current_index:current_index + output_size] = matmul_4bit( - x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i]) + out[:, current_index : current_index + output_size] = matmul_4bit( + x, weight[offsets[i] : offsets[i + 1]].t(), quant_states[i] + ) current_index += output_size @@ -394,11 +423,13 @@ def _apply_bnb_4bit_fake( try: - direct_register_custom_op(op_name="apply_bnb_4bit", - op_func=_apply_bnb_4bit, - mutates_args=["out"], - fake_impl=_apply_bnb_4bit_fake, - dispatch_key=current_platform.dispatch_key) + direct_register_custom_op( + op_name="apply_bnb_4bit", + op_func=_apply_bnb_4bit, + mutates_args=["out"], + fake_impl=_apply_bnb_4bit_fake, + dispatch_key=current_platform.dispatch_key, + ) apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit except AttributeError as error: @@ -420,14 +451,18 @@ def __init__( super().__init__(moe) try: import bitsandbytes - if version.parse( - bitsandbytes.__version__) < version.parse("0.46.1"): - raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.46.1.") + + if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"): + raise ImportError( + "bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.46.1." + ) except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.46.1 via " - "`pip install bitsandbytes>=0.46.1` to use " - "bitsandbytes quantizer.") from err + raise ImportError( + "Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " + "bitsandbytes quantizer." + ) from err self.quant_config = quant_config def create_weights( @@ -453,7 +488,8 @@ def create_weights( ) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return None def apply( @@ -480,11 +516,13 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: from vllm.model_executor.layers.fused_moe import fused_experts + assert self.fused_experts is None if enable_eplb: raise NotImplementedError( - "EPLB not supported for `BitsAndBytesMoEMethod` yet.") + "EPLB not supported for `BitsAndBytesMoEMethod` yet." 
+ ) topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -497,7 +535,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) if self.quant_config.load_in_8bit: w13, w2 = self._apply_8bit_dequant(layer) else: @@ -527,8 +566,9 @@ def _create_weights_4bit( ): quant_ratio = calculate_quant_ratio(params_dtype) # Fused gate_up_proj (column parallel) - w13_total_size = (hidden_size * 2 * - intermediate_size_per_partition) // quant_ratio + w13_total_size = ( + hidden_size * 2 * intermediate_size_per_partition + ) // quant_ratio w13_qweight = torch.nn.Parameter( torch.empty( num_experts, @@ -543,26 +583,20 @@ def _create_weights_4bit( set_weight_attrs( w13_qweight, { - "num_experts": - num_experts, - "input_dim": - hidden_size, - "output_dim": - 2 * intermediate_size_per_partition, + "num_experts": num_experts, + "input_dim": hidden_size, + "output_dim": 2 * intermediate_size_per_partition, "experts_shape": ( num_experts, intermediate_size_per_partition * 2, hidden_size, ), - "pack_factor": - quant_ratio, - "use_bitsandbytes_4bit": - True, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True, }, ) # down_proj (row parallel) - w2_total_size = (hidden_size * - intermediate_size_per_partition) // quant_ratio + w2_total_size = (hidden_size * intermediate_size_per_partition) // quant_ratio w2_qweight = torch.nn.Parameter( torch.empty( num_experts, @@ -575,21 +609,16 @@ def _create_weights_4bit( set_weight_attrs( w2_qweight, { - "num_experts": - num_experts, - "input_dim": - intermediate_size_per_partition, - "output_dim": - hidden_size, + "num_experts": num_experts, + "input_dim": intermediate_size_per_partition, + "output_dim": hidden_size, "experts_shape": ( num_experts, hidden_size, intermediate_size_per_partition, ), - "pack_factor": - quant_ratio, - "use_bitsandbytes_4bit": - True, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True, }, ) layer.register_parameter("w2_weight", w2_qweight) @@ -607,8 +636,10 @@ def _create_weights_8bit( raise NotImplementedError def _apply_4bit_dequnt( - self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + self, layer: torch.nn.Module + ) -> tuple[torch.Tensor, torch.Tensor]: from bitsandbytes.functional import dequantize_4bit + w13 = dequantize_4bit( layer.w13_weight.reshape(-1, 1), layer.w13_weight.bnb_quant_state, @@ -622,5 +653,6 @@ def _apply_4bit_dequnt( return w13, w2 def _apply_8bit_dequant( - self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + self, layer: torch.nn.Module + ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 3f771ea2abd1..e89d002078ac 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -5,39 +5,62 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast import torch -from compressed_tensors.config import (CompressionFormat, - SparsityCompressionConfig, - SparsityStructure) -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) +from compressed_tensors.config import ( + CompressionFormat, + 
SparsityCompressionConfig, + SparsityStructure, +) +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStrategy, + QuantizationType, +) from compressed_tensors.transform import TransformConfig import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 - CompressedTensorsMoEMethod) + CompressedTensorsMoEMethod, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, - CompressedTensorsScheme, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int, - CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, - CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, - CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) + W4A16SPARSE24_SUPPORTED_BITS, + WNA16_SUPPORTED_BITS, + CompressedTensors24, + CompressedTensorsScheme, + CompressedTensorsW4A4Fp4, + CompressedTensorsW4A8Fp8, + CompressedTensorsW4A8Int, + CompressedTensorsW4A16Fp4, + CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, + CompressedTensorsWNA16, +) from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 - CompressedTensorsLinearTransformMethod, get_linear_transform_schemes) + CompressedTensorsLinearTransformMethod, + get_linear_transform_schemes, +) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - find_matched_target, is_activation_quantization_format, - should_ignore_layer) + find_matched_target, + is_activation_quantization_format, + should_ignore_layer, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported) + cutlass_fp4_supported, +) from vllm.platforms import current_platform if TYPE_CHECKING: @@ -52,7 +75,6 @@ class CompressedTensorsConfig(QuantizationConfig): - def __init__( self, target_scheme_map: dict[str, Any], @@ -75,8 +97,7 @@ def __init__( self.config = config if transform_config: - self.transform_config = TransformConfig.model_validate( - transform_config) + self.transform_config = TransformConfig.model_validate(transform_config) else: self.transform_config = None @@ -94,16 +115,16 @@ def get_name(self) -> QuantizationMethods: return "compressed-tensors" def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): - self.target_scheme_map = hf_to_vllm_mapper.apply_dict( - self.target_scheme_map) + self.target_scheme_map = hf_to_vllm_mapper.apply_dict(self.target_scheme_map) self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict( - self.sparsity_scheme_map) + self.sparsity_scheme_map + ) self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list( - 
self.sparsity_ignore_list) + self.sparsity_ignore_list + ) if self.kv_cache_scheme is not None: - self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict( - self.kv_cache_scheme) + self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(self.kv_cache_scheme) def get_quant_method( self, @@ -116,8 +137,8 @@ def get_quant_method( # collect schemes quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) input_tfms, output_tfms = get_linear_transform_schemes( - layer, prefix, self.transform_config, - self.packed_modules_mapping) + layer, prefix, self.transform_config, self.packed_modules_mapping + ) # choose quantization method quant_method: LinearMethodBase = UnquantizedLinearMethod() @@ -128,7 +149,8 @@ def get_quant_method( # choose transform method if any((input_tfms, output_tfms)): return CompressedTensorsLinearTransformMethod.from_schemes( - quant_method, quant_scheme, input_tfms, output_tfms) + quant_method, quant_scheme, input_tfms, output_tfms + ) else: return quant_method @@ -143,10 +165,10 @@ def get_quant_method( def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig": ignore: list[str] = cast(list[str], config.get("ignore", [])) quant_format = cast(str, config.get("format")) - target_scheme_map = cls._quantization_scheme_map_from_config( - config=config) + target_scheme_map = cls._quantization_scheme_map_from_config(config=config) sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config( - config=config) + config=config + ) transform_config = config.get("transform_config") return cls( @@ -173,18 +195,17 @@ def _parse_sparsity_config( if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)): return dict(), [] - sparsity_config = SparsityCompressionConfig.model_validate( - sparsity_config) + sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config) sparse_scheme_map: dict[str, SparsityCompressionConfig] = { - target: sparsity_config - for target in sparsity_config.targets or list() + target: sparsity_config for target in sparsity_config.targets or list() } sparsity_ignore_list = sparsity_config.ignore or list() return sparse_scheme_map, sparsity_ignore_list @classmethod def _quantization_scheme_map_from_config( - cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: + cls, config: dict[str, Any] + ) -> QUANTIZATION_SCHEME_MAP_TYPE: """ :param config: The `quantization_config` dictionary from config.json :return: A dictionary mapping target layer names to their corresponding @@ -207,19 +228,19 @@ def _quantization_scheme_map_from_config( targets = quant_config.get("targets") for target in targets: target_scheme_map[target] = {} - target_scheme_map[target][ - "weights"] = QuantizationArgs.model_validate( - quant_config.get("weights")) + target_scheme_map[target]["weights"] = QuantizationArgs.model_validate( + quant_config.get("weights") + ) target_scheme_map[target]["input_activations"] = None - target_scheme_map[target]["format"] = quant_config.get( - "format") + target_scheme_map[target]["format"] = quant_config.get("format") format = target_scheme_map[target].get("format") # If no per-config format defined, use global format in config - act_quant_format = is_activation_quantization_format( - format - ) if format is not None else is_activation_quantization_format( - quant_format) + act_quant_format = ( + is_activation_quantization_format(format) + if format is not None + else is_activation_quantization_format(quant_format) + ) # TODO(czhu): w4a8fp8 is in packed-quantized format # but needs input activation quantization 
input_activations = quant_config.get("input_activations") @@ -229,22 +250,25 @@ def _quantization_scheme_map_from_config( # should be w8a16fp8 w8a16fp8 can also run for cases where # there is an input_quant but it is ignored if not input_activations: - assert target_scheme_map[target][ - "weights"].type == QuantizationType.FLOAT + assert ( + target_scheme_map[target]["weights"].type + == QuantizationType.FLOAT + ) else: - target_scheme_map[target][ - "input_activations"] = QuantizationArgs.model_validate( # noqa: E501 - quant_config.get("input_activations")) + target_scheme_map[target]["input_activations"] = ( + QuantizationArgs.model_validate( + quant_config.get("input_activations") + ) + ) return target_scheme_map @classmethod def get_config_filenames(cls) -> list[str]: return [] - def _check_scheme_supported(self, - min_capability: int, - error: bool = True, - match_exact: bool = False) -> bool: + def _check_scheme_supported( + self, min_capability: int, error: bool = True, match_exact: bool = False + ) -> bool: capability_tuple = current_platform.get_device_capability() if capability_tuple is not None: @@ -255,115 +279,155 @@ def _check_scheme_supported(self, raise RuntimeError( "Quantization scheme is not supported for ", "the current GPU. Required capability: ", - f"{min_capability}. Current capability: {capability}.") + f"{min_capability}. Current capability: {capability}.", + ) else: supported = capability >= min_capability if error and not supported: raise RuntimeError( "Quantization scheme is not supported for ", f"the current GPU. Min capability: {min_capability}. ", - f"Current capability: {capability}.") + f"Current capability: {capability}.", + ) return supported else: return False - def _is_fp4a4_nvfp4(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs): - + def _is_fp4a4_nvfp4( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ): if weight_quant is None or input_quant is None: return False - is_tensor_group_quant = (weight_quant.strategy - == QuantizationStrategy.TENSOR_GROUP.value - and input_quant.strategy - == QuantizationStrategy.TENSOR_GROUP.value) + is_tensor_group_quant = ( + weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + ) is_symmetric = weight_quant.symmetric and input_quant.symmetric - is_group_size_16 = (weight_quant.group_size == 16 - and input_quant.group_size == 16) - is_float_type = (weight_quant.type == QuantizationType.FLOAT - and input_quant.type == QuantizationType.FLOAT.value) + is_group_size_16 = ( + weight_quant.group_size == 16 and input_quant.group_size == 16 + ) + is_float_type = ( + weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT.value + ) is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4 - return (is_tensor_group_quant and is_float_type and is_4_bits - and is_group_size_16 and is_symmetric) - - def _is_fp4a16_nvfp4(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs): + return ( + is_tensor_group_quant + and is_float_type + and is_4_bits + and is_group_size_16 + and is_symmetric + ) + def _is_fp4a16_nvfp4( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ): is_weight_only = weight_quant is not None and input_quant is None is_tensor_group_quant = ( - weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value) + weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + ) is_symmetric = 
weight_quant.symmetric is_group_size_16 = weight_quant.group_size == 16 is_float_type = weight_quant.type == QuantizationType.FLOAT is_4_bits = weight_quant.num_bits == 4 - return (is_weight_only and is_tensor_group_quant and is_float_type - and is_4_bits and is_group_size_16 and is_symmetric) + return ( + is_weight_only + and is_tensor_group_quant + and is_float_type + and is_4_bits + and is_group_size_16 + and is_symmetric + ) - def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + def _is_static_tensor_w8a8( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_tensor = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TENSOR.value) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_tensor = ( + weight_strategy + and input_quant.strategy == QuantizationStrategy.TENSOR.value + ) is_static = not weight_quant.dynamic and not input_quant.dynamic # Both symmetric and asymmetric input quantization supported. # Only symmetric weight quantization supported. return is_8_bits and is_tensor and weight_quant.symmetric and is_static - def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + def _is_dynamic_token_w8a8( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_token = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TOKEN.value) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_token = ( + weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) is_dynamic = not weight_quant.dynamic and input_quant.dynamic # Both symmetric and asymmetric input quantization supported. # Only symmetric weight quantization supported. return is_8_bits and is_token and weight_quant.symmetric and is_dynamic - def _is_dynamic_token_w4a8_int(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + def _is_dynamic_token_w4a8_int( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: is_weight_4_bits = weight_quant.num_bits == 4 is_activation_8_bits = input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.GROUP.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) - is_token = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TOKEN.value) + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_token = ( + weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) is_dynamic = not weight_quant.dynamic and input_quant.dynamic # Both symmetric and asymmetric input quantization supported. # Only symmetric weight quantization supported. 
- return (is_weight_4_bits and is_activation_8_bits and is_token - and weight_quant.symmetric and is_dynamic) + return ( + is_weight_4_bits + and is_activation_8_bits + and is_token + and weight_quant.symmetric + and is_dynamic + ) - def _is_fp8_w8a8(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + def _is_fp8_w8a8( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: # Confirm weights and activations quantized. if weight_quant is None or input_quant is None: return False # Confirm weight scheme is supported. - is_floating_point = (weight_quant.type == QuantizationType.FLOAT - and input_quant.type == QuantizationType.FLOAT) + is_floating_point = ( + weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT + ) is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, - QuantizationStrategy.BLOCK - ]) - if not (is_floating_point and is_symmetric_weight and is_static_weight - and is_tensor_or_channel_or_block_weight): + is_tensor_or_channel_or_block_weight = weight_quant.strategy in [ + QuantizationStrategy.TENSOR, + QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK, + ] + if not ( + is_floating_point + and is_symmetric_weight + and is_static_weight + and is_tensor_or_channel_or_block_weight + ): return False # Dynamic quantization is always supported if weights supported. @@ -372,45 +436,56 @@ def _is_fp8_w8a8(self, weight_quant: QuantizationArgs, # Confirm activation scheme is supported. is_symmetric_activation = input_quant.symmetric - is_per_tensor_activation = ( - input_quant.strategy == QuantizationStrategy.TENSOR) + is_per_tensor_activation = input_quant.strategy == QuantizationStrategy.TENSOR return is_symmetric_activation and is_per_tensor_activation - def _is_fp8_w4a8(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + def _is_fp8_w4a8( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: if not weight_quant or not input_quant: return False is_weight_4_bits = weight_quant.num_bits == 4 is_activation_8_bits = input_quant.num_bits == 8 - weight_strategy = ( - weight_quant.strategy == QuantizationStrategy.GROUP.value) - is_token = (weight_strategy and input_quant.strategy - == QuantizationStrategy.TOKEN.value) + weight_strategy = weight_quant.strategy == QuantizationStrategy.GROUP.value + is_token = ( + weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) is_dynamic = not weight_quant.dynamic and input_quant.dynamic is_symmetric = weight_quant.symmetric and input_quant.symmetric # Only per-group symmetric weight (4bit) # + per-tok symmetric activation (8bit) quantization supported. 
- return (is_weight_4_bits and is_activation_8_bits and is_token - and is_symmetric and is_dynamic) - - def _is_fp8_w4a8_sm90(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: - return (self._check_scheme_supported(90, error=False, match_exact=True) - and self._is_fp8_w4a8(weight_quant, input_quant)) - - def _is_fp8_w8a8_sm90(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: - return (self._check_scheme_supported(90, error=False, match_exact=True) - and self._is_fp8_w8a8(weight_quant, input_quant)) - - def _is_fp8_w8a8_sm100(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: - return (self._check_scheme_supported( - 100, error=False, match_exact=True) - and self._is_fp8_w8a8(weight_quant, input_quant)) - - def _is_fp8_w8a16(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + return ( + is_weight_4_bits + and is_activation_8_bits + and is_token + and is_symmetric + and is_dynamic + ) + + def _is_fp8_w4a8_sm90( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: + return self._check_scheme_supported( + 90, error=False, match_exact=True + ) and self._is_fp8_w4a8(weight_quant, input_quant) + + def _is_fp8_w8a8_sm90( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: + return self._check_scheme_supported( + 90, error=False, match_exact=True + ) and self._is_fp8_w8a8(weight_quant, input_quant) + + def _is_fp8_w8a8_sm100( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: + return self._check_scheme_supported( + 100, error=False, match_exact=True + ) and self._is_fp8_w8a8(weight_quant, input_quant) + + def _is_fp8_w8a16( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: # Confirm weights quantized. if weight_quant is None: return False @@ -422,33 +497,35 @@ def _is_fp8_w8a16(self, weight_quant: QuantizationArgs, # Confirm weight scheme is supported. is_symmetric_weight = weight_quant.symmetric is_static_weight = not weight_quant.dynamic - is_tensor_or_channel_or_block_weight = (weight_quant.strategy in [ - QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL, - QuantizationStrategy.BLOCK - ]) - if not (is_symmetric_weight and is_static_weight # noqa: SIM103 - and is_tensor_or_channel_or_block_weight): - return False - - # All conditions satisfied. 
- return True + is_tensor_or_channel_or_block_weight = weight_quant.strategy in [ + QuantizationStrategy.TENSOR, + QuantizationStrategy.CHANNEL, + QuantizationStrategy.BLOCK, + ] + return ( + is_symmetric_weight + and is_static_weight + and is_tensor_or_channel_or_block_weight + ) - def _is_wNa16_group_channel(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + def _is_wNa16_group_channel( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ) -> bool: input_quant_none = input_quant is None is_channel_group = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value - or weight_quant.strategy == QuantizationStrategy.GROUP.value) + or weight_quant.strategy == QuantizationStrategy.GROUP.value + ) is_static = not weight_quant.dynamic - return (is_channel_group and input_quant_none and is_static) + return is_channel_group and input_quant_none and is_static def _get_scheme_from_parts( - self, - weight_quant: QuantizationArgs, - input_quant: QuantizationArgs, - format: Optional[str] = None) -> "CompressedTensorsScheme": - + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + format: Optional[str] = None, + ) -> "CompressedTensorsScheme": # use the per-layer format if defined, otherwise, use global format format = format if format is not None else self.quant_format @@ -457,94 +534,105 @@ def _get_scheme_from_parts( return CompressedTensorsW4A16Fp4() if self._is_fp8_w4a8_sm90(weight_quant, input_quant): - return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits, - strategy=weight_quant.strategy, - symmetric=weight_quant.symmetric, - group_size=weight_quant.group_size, - actorder=weight_quant.actorder) + return CompressedTensorsW4A8Fp8( + num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder, + ) if self._is_wNa16_group_channel(weight_quant, input_quant): - if (format == CompressionFormat.marlin_24.value - and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): + if ( + format == CompressionFormat.marlin_24.value + and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS + ): assert weight_quant.symmetric return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, - group_size=weight_quant.group_size) - if (format == CompressionFormat.pack_quantized.value - and weight_quant.num_bits in WNA16_SUPPORTED_BITS): + group_size=weight_quant.group_size, + ) + if ( + format == CompressionFormat.pack_quantized.value + and weight_quant.num_bits in WNA16_SUPPORTED_BITS + ): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, symmetric=weight_quant.symmetric, group_size=weight_quant.group_size, - actorder=weight_quant.actorder) + actorder=weight_quant.actorder, + ) act_quant_format = is_activation_quantization_format(format) if act_quant_format: if self._is_fp4a4_nvfp4(weight_quant, input_quant): - if cutlass_fp4_supported( - ) or envs.VLLM_USE_NVFP4_CT_EMULATIONS: + if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS: return CompressedTensorsW4A4Fp4() else: logger.warning_once( "Current platform does not support cutlass NVFP4." - " Running CompressedTensorsW4A16Fp4.") - return CompressedTensorsW4A16Fp4( - has_input_global_scale=True) + " Running CompressedTensorsW4A16Fp4." 
+ ) + return CompressedTensorsW4A16Fp4(has_input_global_scale=True) if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), error=False) + CompressedTensorsW8A8Fp8.get_min_capability(), error=False + ) if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( weight_quant=weight_quant, - is_static_input_scheme=(input_quant - and not input_quant.dynamic)) + is_static_input_scheme=( + input_quant and not input_quant.dynamic + ), + ) else: # note: input_quant will be present for converted models; # will be ignored during inference post loading return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=not input_quant.dynamic) + is_static_input_scheme=not input_quant.dynamic, + ) # note: input_quant can be None if self._is_fp8_w8a16(weight_quant, input_quant): - is_static_input_scheme = (input_quant - and not input_quant.dynamic) + is_static_input_scheme = input_quant and not input_quant.dynamic return CompressedTensorsW8A16Fp8( strategy=weight_quant.strategy, - is_static_input_scheme=is_static_input_scheme) + is_static_input_scheme=is_static_input_scheme, + ) if self._is_static_tensor_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( strategy=weight_quant.strategy, is_static_input_scheme=True, - input_symmetric=input_quant.symmetric) + input_symmetric=input_quant.symmetric, + ) if self._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8( strategy=weight_quant.strategy, is_static_input_scheme=False, - input_symmetric=input_quant.symmetric) + input_symmetric=input_quant.symmetric, + ) if self._is_dynamic_token_w4a8_int(weight_quant, input_quant): - is_static_input_scheme = (input_quant - and not input_quant.dynamic) + is_static_input_scheme = input_quant and not input_quant.dynamic return CompressedTensorsW4A8Int( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, group_size=weight_quant.group_size, is_static_input_scheme=is_static_input_scheme, - input_symmetric=input_quant.symmetric) + input_symmetric=input_quant.symmetric, + ) - raise NotImplementedError( - "No compressed-tensors compatible scheme was found.") + raise NotImplementedError("No compressed-tensors compatible scheme was found.") - def get_scheme(self, - layer: torch.nn.Module, - layer_name: Optional[str] = None - ) -> Optional["CompressedTensorsScheme"]: + def get_scheme( + self, layer: torch.nn.Module, layer_name: Optional[str] = None + ) -> Optional["CompressedTensorsScheme"]: """ compressed-tensors supports non uniform in the following way: @@ -561,9 +649,9 @@ def get_scheme(self, # Find the "target" in the compressed-tensors config # that our layer conforms to. 
# TODO (@kylesayrs): support ignore module names with ct matching utils - if should_ignore_layer(layer_name, - ignore=self.ignore, - fused_mapping=self.packed_modules_mapping): + if should_ignore_layer( + layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping + ): return None # Will be empty for models with only sparsity @@ -573,7 +661,8 @@ def get_scheme(self, layer_name=layer_name, module=layer, targets=self.target_scheme_map.keys(), - fused_mapping=self.packed_modules_mapping) + fused_mapping=self.packed_modules_mapping, + ) scheme_dict = self.target_scheme_map[matched_target] weight_quant = scheme_dict.get("weights") @@ -582,25 +671,31 @@ def get_scheme(self, # Find the sparsity scheme of the layer # assume that fused layers inherit first component's sparsity scheme - sparsity_targets = (self.sparsity_scheme_map.keys() - - set(self.sparsity_ignore_list)) + sparsity_targets = self.sparsity_scheme_map.keys() - set( + self.sparsity_ignore_list + ) sparsity_scheme: Optional[SparsityCompressionConfig] = None with suppress(ValueError): matched_target = find_matched_target( layer_name=layer_name, module=layer, targets=sparsity_targets, - fused_mapping=self.packed_modules_mapping) + fused_mapping=self.packed_modules_mapping, + ) sparsity_scheme = self.sparsity_scheme_map[matched_target] - if self.supports_cutlass_24(weight_quant=weight_quant, - input_quant=input_quant, - sparsity_scheme=sparsity_scheme): + if self.supports_cutlass_24( + weight_quant=weight_quant, + input_quant=input_quant, + sparsity_scheme=sparsity_scheme, + ): # Have a valid sparsity scheme # Validate layer is supported by Cutlass 2:4 Kernel - model_compression_config = (None if sparsity_scheme is None - or sparsity_scheme.format == "dense" - else self.config) + model_compression_config = ( + None + if sparsity_scheme is None or sparsity_scheme.format == "dense" + else self.config + ) scheme = CompressedTensors24( quantized=weight_quant is not None or input_quant is not None, @@ -609,23 +704,23 @@ def get_scheme(self, model_compression_config=model_compression_config, ) elif weight_quant is None: - logger.warning_once("Acceleration for non-quantized schemes is " - "not supported by Compressed Tensors. " - "Falling back to UnquantizedLinearMethod") + logger.warning_once( + "Acceleration for non-quantized schemes is " + "not supported by Compressed Tensors. " + "Falling back to UnquantizedLinearMethod" + ) return None else: # Find the quant_scheme scheme = self._get_scheme_from_parts( # type: ignore - weight_quant=weight_quant, - input_quant=input_quant, - format=format) + weight_quant=weight_quant, input_quant=input_quant, format=format + ) # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) self._check_scheme_supported(scheme.get_min_capability()) - logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, - layer_name) + logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name) return scheme def get_cache_scale(self, name: str) -> Optional[str]: @@ -647,16 +742,18 @@ def get_cache_scale(self, name: str) -> Optional[str]: def has_blocked_weights(self) -> bool: for scheme in self.target_scheme_map.values(): weight_quant = scheme.get("weights") - if (weight_quant is not None - and weight_quant.strategy == QuantizationStrategy.BLOCK): + if ( + weight_quant is not None + and weight_quant.strategy == QuantizationStrategy.BLOCK + ): return True return False @staticmethod def supports_cutlass_24( - weight_quant: Optional[QuantizationArgs], - input_quant: Optional[QuantizationArgs], - sparsity_scheme: Optional[SparsityCompressionConfig] = None + weight_quant: Optional[QuantizationArgs], + input_quant: Optional[QuantizationArgs], + sparsity_scheme: Optional[SparsityCompressionConfig] = None, ) -> bool: """ Check if the layer is supported by the Cutlass 2:4 Kernel @@ -666,7 +763,7 @@ def supports_cutlass_24( - Weight only quantization is not-supported - Supported weight quantization strategies are TENSOR and CHANNEL - Supported input quantization strategies are TENSOR and TOKEN - - Only 8 bit quantization is supported + - Only 8 bit quantization is supported :return: True if the layer is supported by the Cutlass 2:4 Kernel False otherwise @@ -675,16 +772,17 @@ def supports_cutlass_24( return False is_valid_sparsity_structure: bool = ( - sparsity_scheme.sparsity_structure == - SparsityStructure.TWO_FOUR.value) + sparsity_scheme.sparsity_structure == SparsityStructure.TWO_FOUR.value + ) valid_compressors = { CompressionFormat.dense.value, - CompressionFormat.sparse_24_bitmask.value + CompressionFormat.sparse_24_bitmask.value, } - is_valid_sparsity = (is_valid_sparsity_structure - and sparsity_scheme.format in valid_compressors) + is_valid_sparsity = ( + is_valid_sparsity_structure and sparsity_scheme.format in valid_compressors + ) if not is_valid_sparsity: return False @@ -699,7 +797,7 @@ def supports_cutlass_24( supported_weight_quant_strategies = [ QuantizationStrategy.TENSOR.value, - QuantizationStrategy.CHANNEL.value + QuantizationStrategy.CHANNEL.value, ] assert weight_quant is not None @@ -708,7 +806,8 @@ def supports_cutlass_24( return False supported_input_quant_strategies = [ - QuantizationStrategy.TENSOR.value, QuantizationStrategy.TOKEN.value + QuantizationStrategy.TENSOR.value, + QuantizationStrategy.TOKEN.value, ] if input_quant.strategy not in supported_input_quant_strategies: @@ -718,18 +817,22 @@ def supports_cutlass_24( class CompressedTensorsLinearMethod(LinearMethodBase): - def __init__(self, quantization_config: CompressedTensorsConfig): self.quantization_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): """ Use the CompressedTensorsScheme associated with each layer to create the necessary 
parameters for the layer. See LinearMethodBase for param @@ -743,12 +846,15 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes=output_partition_sizes, output_size=output_size, params_dtype=params_dtype, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None): + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): """ Use the output of create_weights and the CompressedTensorsScheme associated with the layer to apply the forward pass with the @@ -788,18 +894,21 @@ def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]): raise NotImplementedError( "Currently supported kv cache quantization is " "num_bits=8, type=float, however " - f"received num_bits={num_bits}, type={type_}") + f"received num_bits={num_bits}, type={type_}" + ) strategy = kv_cache_scheme.get("strategy") if strategy != "tensor": raise NotImplementedError( "Only support per-tensor scaling factor " "for compressed-tensors KV cache. " - f"Expected strategy: tensor, found strategy: {strategy}") + f"Expected strategy: tensor, found strategy: {strategy}" + ) is_symmetric = kv_cache_scheme.get("symmetric") if not is_symmetric: raise NotImplementedError( "Only support symmetric scaling factor " "for compressed-tensors KV cache. " - f"However found symmetric: {is_symmetric}") + f"However found symmetric: {is_symmetric}" + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 8504ba73defb..41e7f1c7a499 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -7,8 +7,7 @@ import torch from compressed_tensors import CompressionFormat -from compressed_tensors.quantization import (ActivationOrdering, - QuantizationStrategy) +from compressed_tensors.quantization import ActivationOrdering, QuantizationStrategy import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -16,41 +15,66 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( - FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported) + FusedMoE, + FusedMoEActivationFormat, + FusedMoEConfig, + FusedMoEMethodBase, + FusedMoEPermuteExpertsUnpermute, + FusedMoeWeightScaleSupported, +) from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, - int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config, - int8_w8a16_moe_quant_config, nvfp4_moe_quant_config) + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, + int4_w4a16_moe_quant_config, + int8_w8a8_moe_quant_config, + int8_w8a16_moe_quant_config, + nvfp4_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - is_valid_flashinfer_cutlass_fused_moe) + is_valid_flashinfer_cutlass_fused_moe, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa - WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) + WNA16_SUPPORTED_BITS, + 
WNA16_SUPPORTED_TYPES_MAP, +) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - find_matched_target) + find_matched_target, +) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, - select_nvfp4_gemm_impl) + build_flashinfer_fp4_cutlass_moe_prepare_finalize, + reorder_w1w3_to_w3w1, + select_nvfp4_gemm_impl, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - expert_weight_is_col_major, requant_weight_ue8m0_inplace) + expert_weight_is_col_major, + requant_weight_ue8m0_inplace, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_moe_marlin_supports_layer, marlin_make_workspace_new, - marlin_moe_permute_scales) + check_moe_marlin_supports_layer, + marlin_make_workspace_new, + marlin_moe_permute_scales, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_moe_fp4_layer_for_marlin) + prepare_moe_fp4_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - prepare_moe_fp8_layer_for_marlin) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - swizzle_blockscale) + prepare_moe_fp8_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) + all_close_1d, + normalize_e4m3fn_to_e4m3fnuz, + per_tensor_dequantize, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import CpuArchEnum, current_platform from vllm.scalar_type import scalar_types -from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor, - is_deep_gemm_e8m0_used) +from vllm.utils.deep_gemm import ( + get_col_major_tma_aligned_tensor, + is_deep_gemm_e8m0_used, +) logger = init_logger(__name__) @@ -61,22 +85,24 @@ class GPTQMarlinState(Enum): __all__ = [ - "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsMoEMethod", + "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Int8MoEMethod", - "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod" + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsW4A4MoeMethod", + "CompressedTensorsW4A8Int8MoEMethod", ] class CompressedTensorsMoEMethod(FusedMoEMethodBase): - def __init_(self, moe: FusedMoEConfig): super().__init__(moe) @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 - layer: torch.nn.Module + layer: torch.nn.Module, ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. 
@@ -86,9 +112,7 @@ def get_moe_method( else: # May have instead defined the linear layers in the fused model - fused_layers = [ - "re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*" - ] + fused_layers = ["re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*"] current_scheme = None for fused_layer in fused_layers: # Check if one of the fused layers are defined in quant_config @@ -96,64 +120,67 @@ def get_moe_method( layer_name=fused_layer, module=layer, targets=quant_config.target_scheme_map.keys(), - fused_mapping=quant_config.packed_modules_mapping) + fused_mapping=quant_config.packed_modules_mapping, + ) # Only valid if down_proj, gate_proj, and up_proj # are mapped to the same quant scheme in the quant_config if current_scheme is None: - current_scheme = quant_config.target_scheme_map.get( - matched_target) + current_scheme = quant_config.target_scheme_map.get(matched_target) else: assert current_scheme == quant_config.target_scheme_map.get( - matched_target) + matched_target + ) - weight_quant = quant_config.target_scheme_map[matched_target].get( - "weights") + weight_quant = quant_config.target_scheme_map[matched_target].get("weights") input_quant = quant_config.target_scheme_map[matched_target].get( - "input_activations") + "input_activations" + ) if quant_config._is_wNa16_group_channel(weight_quant, input_quant): # group_size=None means channelwise group_size = weight_quant.group_size or -1 # Prefer to use the MarlinMoE kernel when it is supported. if not check_moe_marlin_supports_layer(layer, group_size): - if (weight_quant.strategy in QuantizationStrategy.GROUP and - weight_quant.actorder in (ActivationOrdering.GROUP, - ActivationOrdering.DYNAMIC)): + if ( + weight_quant.strategy in QuantizationStrategy.GROUP + and weight_quant.actorder + in (ActivationOrdering.GROUP, ActivationOrdering.DYNAMIC) + ): raise ValueError( "WNA16MoE is not supported with actorder=group/dynamic." 
) logger.info_once("Using CompressedTensorsWNA16MoEMethod") - return CompressedTensorsWNA16MoEMethod(quant_config, - layer.moe_config) + return CompressedTensorsWNA16MoEMethod(quant_config, layer.moe_config) else: logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod( - quant_config, layer.moe_config) + quant_config, layer.moe_config + ) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A4MoeMethod(layer.moe_config) - elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) - or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) - or quant_config._is_fp8_w8a8(weight_quant, input_quant)): - return CompressedTensorsW8A8Fp8MoEMethod(quant_config, - layer.moe_config) + elif ( + quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) + or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) + or quant_config._is_fp8_w8a8(weight_quant, input_quant) + ): + return CompressedTensorsW8A8Fp8MoEMethod(quant_config, layer.moe_config) elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Int8MoEMethod(quant_config, - layer.moe_config) - elif quant_config._is_dynamic_token_w4a8_int(weight_quant, - input_quant): - return CompressedTensorsW4A8Int8MoEMethod(quant_config, - layer.moe_config) + return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config) + elif quant_config._is_dynamic_token_w4a8_int(weight_quant, input_quant): + return CompressedTensorsW4A8Int8MoEMethod(quant_config, layer.moe_config) else: raise RuntimeError( - f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" + ) class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): - def __init__(self, moe: FusedMoEConfig): from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 - detect_nvfp4_moe_support) + detect_nvfp4_moe_support, + ) + super().__init__(moe) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported @@ -161,10 +188,15 @@ def __init__(self, moe: FusedMoEConfig): self.use_marlin = _nvfp4.use_marlin self.group_size = 16 - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): layer.num_experts = num_experts layer.params_dtype = params_dtype @@ -175,8 +207,10 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, # 2 fp4 items are packed in the input dimension hidden_size // 2, requires_grad=False, - dtype=torch.uint8), - requires_grad=False) + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_packed", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) @@ -186,8 +220,10 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // 2, - dtype=torch.uint8), - requires_grad=False) + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_packed", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -198,11 +234,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: 
int, 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.group_size, - dtype=torch.float8_e4m3fn), - requires_grad=False) + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) w2_weight_scale = torch.nn.Parameter( @@ -211,120 +250,135 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // self.group_size, - dtype=torch.float8_e4m3fn), - requires_grad=False) + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}) + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # Weight Global Scales - w13_weight_scale_2 = torch.nn.Parameter(torch.empty( - num_experts, 2, dtype=torch.float32), - requires_grad=False) + w13_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) - w2_weight_scale_2 = torch.nn.Parameter(torch.empty( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) # Input Global Scales - w13_input_scale = torch.nn.Parameter(torch.empty(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) + w13_input_scale = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_input_global_scale", w13_input_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.empty(num_experts, - dtype=torch.float32), - requires_grad=False) + w2_input_scale = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_input_global_scale", w2_input_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w2_input_scale, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # From packed to weight - layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter( + layer.w13_weight_packed.data, requires_grad=False + ) - layer.w2_weight = 
torch.nn.Parameter(layer.w2_weight_packed.data, - requires_grad=False) + layer.w2_weight = torch.nn.Parameter( + layer.w2_weight_packed.data, requires_grad=False + ) # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel. if self.allow_flashinfer: - w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data, - layer.w13_weight_scale.data, - dim=-2) + w, s = reorder_w1w3_to_w3w1( + layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2 + ) layer.w13_weight = torch.nn.Parameter(w, requires_grad=False) layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False) - if not torch.allclose(layer.w13_weight_global_scale[:, 0], - layer.w13_weight_global_scale[:, 1]): + if not torch.allclose( + layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] + ): logger.warning_once( "w1_weight_global_scale must match w3_weight_global_scale. " - "Accuracy may be affected.") + "Accuracy may be affected." + ) # Take inverse of global scale saved to disk layer.w13_weight_scale_2 = torch.nn.Parameter( - 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False) + 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False + ) layer.w2_weight_scale_2 = torch.nn.Parameter( - 1 / layer.w2_weight_global_scale.data, requires_grad=False) + 1 / layer.w2_weight_global_scale.data, requires_grad=False + ) if self.use_marlin: prepare_moe_fp4_layer_for_marlin(layer) return # swizzle weight scales - layer.w13_weight_scale = torch.nn.Parameter(swizzle_blockscale( - layer.w13_weight_scale), - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w13_weight_scale), requires_grad=False + ) - layer.w2_weight_scale = torch.nn.Parameter(swizzle_blockscale( - layer.w2_weight_scale), - requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + ) # w13 - w13_input_global_scale = layer.w13_input_global_scale.max( - dim=1).values.to(torch.float32) + w13_input_global_scale = layer.w13_input_global_scale.max(dim=1).values.to( + torch.float32 + ) layer.g1_alphas = torch.nn.Parameter( ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), - requires_grad=False) + requires_grad=False, + ) layer.w13_input_scale_quant = torch.nn.Parameter( - (w13_input_global_scale), requires_grad=False) + (w13_input_global_scale), requires_grad=False + ) # w2 layer.g2_alphas = torch.nn.Parameter( ((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to( - torch.float32), - requires_grad=False) + torch.float32 + ), + requires_grad=False, + ) layer.w2_input_scale_quant = torch.nn.Parameter( - (layer.w2_input_global_scale), requires_grad=False) + (layer.w2_input_global_scale), requires_grad=False + ) - def maybe_make_prepare_finalize( - self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: if self.use_marlin: return None elif not self.allow_flashinfer: return super().maybe_make_prepare_finalize() - prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( - self.moe) + prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize @@ -344,7 +398,8 @@ def select_gemm_impl( return experts def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: if self.use_marlin: return None @@ -381,8 +436,9 @@ def 
apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: - raise NotImplementedError("EPLB not supported for " - "`CompressedTensorsW4A4MoeMethod` yet.") + raise NotImplementedError( + "EPLB not supported for `CompressedTensorsW4A4MoeMethod` yet." + ) assert activation == "silu", "Only SiLU activation is supported." topk_weights, topk_ids, _ = FusedMoE.select_experts( @@ -423,12 +479,13 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, - workspace=layer.workspace) + workspace=layer.workspace, + ) elif self.fused_experts is not None: assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight), ( - "Flashinfer CUTLASS Fused MoE not applicable!") + x, layer.w13_weight, layer.w2_weight + ), "Flashinfer CUTLASS Fused MoE not applicable!" return self.fused_experts( hidden_states=x, @@ -446,11 +503,12 @@ def apply( # FlashInfer fused experts path elif self.allow_flashinfer: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - flashinfer_cutlass_moe_fp4) + flashinfer_cutlass_moe_fp4, + ) assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight), ( - "Flashinfer CUTLASS Fused MoE not applicable!") + x, layer.w13_weight, layer.w2_weight + ), "Flashinfer CUTLASS Fused MoE not applicable!" assert self.moe_quant_config is not None @@ -468,12 +526,13 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, ) else: - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) + from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 - assert expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "CompressedTensorsW4A4MoeMethod.") + assert expert_map is None, ( + "Expert Parallelism / expert_map " + "is currently not supported for " + "CompressedTensorsW4A4MoeMethod." 
+ ) assert self.moe_quant_config is not None # Cutlass moe takes in activations in BF16/Half precision @@ -495,7 +554,6 @@ def apply( class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): - def __init__( self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 @@ -503,17 +561,19 @@ def __init__( ): super().__init__(moe) self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( - "weights") + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations") + "input_activations" + ) - per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR - and self.input_quant.strategy - == QuantizationStrategy.TENSOR) + per_tensor = ( + self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR + ) per_channel = ( self.weight_quant.strategy == QuantizationStrategy.CHANNEL - and self.input_quant.strategy == QuantizationStrategy.TOKEN) + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) if not (per_tensor or per_channel): assert self.weight_quant.strategy == QuantizationStrategy.BLOCK self.weight_block_size = self.weight_quant.block_structure @@ -526,33 +586,44 @@ def __init__( if self.static_input_scales and per_channel: raise ValueError( "For FP8 Fused MoE layer, we require either per tensor or " - "channelwise, dynamic per token quantization.") + "channelwise, dynamic per token quantization." + ) # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - self.use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN - and not self.block_quant) + self.use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + and not self.block_quant + ) # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled, + ) self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() # cutlass path self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100( - self.weight_quant, self.input_quant) + self.weight_quant, self.input_quant + ) self.use_cutlass = not self.block_quant and ( quant_config._is_fp8_w8a8_sm90(self.weight_quant, self.input_quant) - or self.is_fp8_w8a8_sm100) + or self.is_fp8_w8a8_sm100 + ) self.disable_expert_map = False - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size layer.num_experts = num_experts @@ -577,31 +648,38 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, raise ValueError( f"The output_size of gate's and up's weight = " f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_n = {block_n}.") - if (tp_size > 1 - and intermediate_size_per_partition % block_k != 0): + f"weight quantization block_n = {block_n}." 
+ ) + if tp_size > 1 and intermediate_size_per_partition % block_k != 0: # Required by row parallel raise ValueError( f"The input_size of down's weight = " f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_k = {block_k}.") + f"weight quantization block_k = {block_k}." + ) # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -609,70 +687,83 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, if self.weight_quant.strategy == QuantizationStrategy.TENSOR: # Allocate 2 scales for w1 and w3 respectively. # They are combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, 2, dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-TENSOR quantization for FusedMoE.weight_loader. extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-CHANNEL quantization for FusedMoE.weight_loader. 
extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, - 2 * - ((intermediate_size_per_partition + block_n - 1) // block_n), - (hidden_size + block_k - 1) // block_k, - dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, (hidden_size + block_n - 1) // block_n, - (intermediate_size_per_partition + block_k - 1) // block_k, - dtype=torch.float32), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-CHANNEL quantization for FusedMoE.weight_loader. extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.static_input_scales: - w13_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) else: @@ -684,46 +775,53 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # We take the max of all the scales in case they differ. if self.static_input_scales: assert self.input_quant.strategy == QuantizationStrategy.TENSOR - if (layer.w13_input_scale is None or layer.w2_input_scale is None): + if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.w13_input_scale) - or not all_close_1d(layer.w2_input_scale)): + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): logger.warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer.") + "for each layer." 
+ ) layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False) + layer.w13_input_scale.max(), requires_grad=False + ) layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False) + layer.w2_input_scale.max(), requires_grad=False + ) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, - layer.w13_input_scale) - w2_weight, w2_weight_scale, w2_input_scale = \ + w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, - layer.w2_input_scale) + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, - requires_grad=False) + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, - requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) # For Per-TENSOR case, Fp8 moe kernel needs single weight scale # for w13 per expert. Use max then dequant and requant each expert. @@ -735,29 +833,31 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_weight_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) start += shard_size - layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - rocm_aiter_fused_experts, shuffle_weights) + rocm_aiter_fused_experts, + shuffle_weights, + ) # reshaping weights is required for aiter moe kernel. 
shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + layer.w13_weight.data, layer.w2_weight.data + ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) elif self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) @@ -770,20 +870,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device = layer.w13_weight.device # ab_strides1 and c_strides2 are the same self.ab_strides1_c_strides2 = torch.full( - (layer.local_num_experts, ), + (layer.local_num_experts,), layer.hidden_size, device=device, - dtype=torch.int64) + dtype=torch.int64, + ) self.ab_strides2 = torch.full( - (layer.local_num_experts, ), + (layer.local_num_experts,), layer.intermediate_size_per_partition, device=device, - dtype=torch.int64) + dtype=torch.int64, + ) self.c_strides1 = torch.full( - (layer.local_num_experts, ), + (layer.local_num_experts,), 2 * layer.intermediate_size_per_partition, device=device, - dtype=torch.int64) + dtype=torch.int64, + ) if is_deep_gemm_e8m0_used() and self.block_quant: assert layer.weight_block_size is not None @@ -803,13 +906,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Ensure column-major TMA alignment expected by DeepGEMM. if expert_weight_is_col_major(layer.w13_weight_scale): layer.w13_weight_scale = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale) + layer.w13_weight_scale + ) if expert_weight_is_col_major(layer.w2_weight_scale): layer.w2_weight_scale = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale) + layer.w2_weight_scale + ) - def maybe_make_prepare_finalize( - self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: if self.use_marlin or self.rocm_aiter_moe_enabled: return None else: @@ -824,16 +928,19 @@ def select_gemm_impl( assert self.moe_quant_config is not None if self.use_cutlass: from vllm.model_executor.layers.fused_moe import ( - CutlassBatchedExpertsFp8, CutlassExpertsFp8) + CutlassBatchedExpertsFp8, + CutlassExpertsFp8, + ) experts: FusedMoEPermuteExpertsUnpermute num_dispatchers = prepare_finalize.num_dispatchers() - if (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts): - logger.debug("CutlassBatchedExpertsFp8(%s)", - self.__class__.__name__) + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): + logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__) experts = CutlassBatchedExpertsFp8( self.moe.num_local_experts, num_dispatchers, @@ -855,23 +962,27 @@ def select_gemm_impl( quant_config=self.moe_quant_config, ) - self.disable_expert_map = (num_dispatchers > 1 - or not experts.supports_expert_map()) + self.disable_expert_map = ( + num_dispatchers > 1 or not experts.supports_expert_map() + ) return experts # triton path from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts) + BatchedTritonOrDeepGemmExperts, + ) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, + ) assert not self.rocm_aiter_moe_enabled and not self.use_marlin - if (prepare_finalize.activation_format == - 
FusedMoEActivationFormat.BatchedExperts): - max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank( - ) + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) @@ -881,20 +992,17 @@ def select_gemm_impl( quant_config=self.moe_quant_config, ) else: - logger.debug("TritonOrDeepGemmExperts(%s)", - self.__class__.__name__) - return TritonOrDeepGemmExperts(self.moe_quant_config, - allow_deep_gemm=True) + logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) + return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: if self.use_marlin: return None - per_act_token = ( - self.input_quant.strategy == QuantizationStrategy.TOKEN) - per_channel_quant = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN + per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, @@ -931,8 +1039,8 @@ def apply( ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( - "EPLB not supported for " - "`CompressedTensorsW8A8Fp8MoEMethod` yet.") + "EPLB not supported for `CompressedTensorsW8A8Fp8MoEMethod` yet." + ) topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, @@ -949,18 +1057,15 @@ def apply( indices_type=self.topk_indices_dtype, ) - per_act_token = ( - self.input_quant.strategy == QuantizationStrategy.TOKEN) - per_channel_quant = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN + per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL # # Note: the order here is important. self.fused_experts can override # cutlass fp8 or fused_experts but not marlin or rocm. # if self.use_marlin: - assert activation == "silu", ( - f"{activation} not supported for Marlin MoE.") + assert activation == "silu", f"{activation} not supported for Marlin MoE." 
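The ordering note in apply() above encodes a strict backend priority. A minimal sketch of that priority, using illustrative boolean flags rather than the real attributes (the helper name is hypothetical, not vLLM API):

# Marlin and ROCm AITER always take precedence; a user-supplied
# fused_experts override comes next, then the CUTLASS FP8 path,
# and finally the Triton fused_experts fallback.
def pick_fp8_moe_backend(use_marlin: bool, rocm_aiter: bool,
                         has_fused_experts_override: bool,
                         use_cutlass: bool) -> str:
    if use_marlin:
        return "fused_marlin_moe"
    if rocm_aiter:
        return "rocm_aiter_fused_experts"
    if has_fused_experts_override:
        return "self.fused_experts"
    if use_cutlass:
        return "cutlass_moe_fp8"
    return "fused_experts (triton)"

assert pick_fp8_moe_backend(True, False, True, True) == "fused_marlin_moe"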
assert self.fused_experts is None return torch.ops.vllm.fused_marlin_moe( x, @@ -977,11 +1082,14 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, - workspace=layer.workspace) + workspace=layer.workspace, + ) elif self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - rocm_aiter_fused_experts) + rocm_aiter_fused_experts, + ) + assert per_act_token == per_channel_quant assert self.moe_quant_config is not None assert self.fused_experts is None @@ -1016,6 +1124,7 @@ def apply( # small-batch fallback on SM100 if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: from vllm.model_executor.layers.fused_moe import fused_experts + assert per_act_token == per_channel_quant return fused_experts( hidden_states=x, @@ -1032,7 +1141,9 @@ def apply( ) else: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) + cutlass_moe_fp8, + ) + assert per_act_token == per_channel_quant assert self.moe_quant_config is not None return cutlass_moe_fp8( @@ -1053,6 +1164,7 @@ def apply( else: from vllm.model_executor.layers.fused_moe import fused_experts + assert per_act_token == per_channel_quant assert self.moe_quant_config is not None return fused_experts( @@ -1071,7 +1183,6 @@ def apply( class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): - def __init__( self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 @@ -1079,69 +1190,83 @@ def __init__( ): super().__init__(moe) self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( - "weights") + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations") + "input_activations" + ) per_channel = ( self.weight_quant.strategy == QuantizationStrategy.CHANNEL - and self.input_quant.strategy == QuantizationStrategy.TOKEN) + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) if not per_channel: raise ValueError( "For INT8 Fused MoE layers, we require channelwise, " "dynamic per token quantization. Found " - f"{self.weight_quant}, {self.input_quant}") + f"{self.weight_quant}, {self.input_quant}" + ) self.static_input_scales = not self.input_quant.dynamic if self.static_input_scales: raise ValueError( "For INT8 Fused MoE layers, we require channelwise, " - "dynamic per token quantization. Found static input scales.") - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + "dynamic per token quantization. Found static input scales." 
+ ) + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): params_dtype = torch.int8 # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - hidden_size, - 1, - dtype=torch.float32), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-CHANNEL quantization for FusedMoE.weight_loader. extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) @@ -1154,7 +1279,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return int8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, @@ -1190,8 +1316,8 @@ def apply( if enable_eplb: raise NotImplementedError( - "EPLB not supported for " - "`CompressedTensorsW8A8Int8MoEMethod` yet.") + "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet." 
+ ) from vllm.model_executor.layers.fused_moe import fused_experts @@ -1207,7 +1333,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( hidden_states=x, @@ -1225,7 +1352,6 @@ def apply( class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): - def __init__( self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 @@ -1241,58 +1367,71 @@ def __init__( self.strategy = config.strategy self.group_size = config.group_size self.actorder = config.actorder - assert config.symmetric, ( - "Only symmetric quantization is supported for MoE") - - if not (self.quant_config.quant_format - == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS): - raise ValueError("For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}") - self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] + assert config.symmetric, "Only symmetric quantization is supported for MoE" - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + if not ( + self.quant_config.quant_format == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS + ): + raise ValueError( + "For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}", + ) + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] - intermediate_size_full = extra_weight_attrs.pop( - "intermediate_size_full") + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") # Will transpose the loaded weight along the # intermediate and hidden dim sizes. 
Will # shard for TP along the transposed dims - extra_weight_attrs.update({ - "is_transposed": True, - "quant_method": self.strategy - }) - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size // self.packed_factor, - 2 * intermediate_size_per_partition, - dtype=torch.int32), - requires_grad=False) + extra_weight_attrs.update( + {"is_transposed": True, "quant_method": self.strategy} + ) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_packed", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - intermediate_size_per_partition // self.packed_factor, - hidden_size, - dtype=torch.int32), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_packed", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # In the case where we have actorder/g_idx, # we do not partition the w2 scales load_full_w2 = self.actorder and self.group_size != -1 - w2_scales_size = (intermediate_size_full - if load_full_w2 else intermediate_size_per_partition) + w2_scales_size = ( + intermediate_size_full if load_full_w2 else intermediate_size_per_partition + ) self.is_k_full = (not self.actorder) or ( - intermediate_size_per_partition == intermediate_size_full) + intermediate_size_per_partition == intermediate_size_full + ) if self.strategy == "channel": num_groups_w2 = num_groups_w13 = 1 @@ -1301,30 +1440,34 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, num_groups_w2 = w2_scales_size // self.group_size num_groups_w13 = hidden_size // self.group_size - w13_scale = torch.nn.Parameter(torch.ones( - num_experts, - num_groups_w13, - 2 * intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w13_scale = torch.nn.Parameter( + torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_scale) set_weight_attrs(w13_scale, extra_weight_attrs) - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_scale) set_weight_attrs(w2_scale, extra_weight_attrs) set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2}) - w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) layer.register_parameter("w2_weight_shape", w2_weight_shape) set_weight_attrs(w2_weight_shape, extra_weight_attrs) - w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) layer.register_parameter("w13_weight_shape", w13_weight_shape) set_weight_attrs(w13_weight_shape, extra_weight_attrs) @@ -1359,8 +1502,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, ), requires_grad=False, ) - 
layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) w2_g_idx_sort_indices = torch.nn.Parameter( @@ -1371,8 +1513,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, ), requires_grad=False, ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) layer.a13_scale = None @@ -1392,41 +1533,37 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx) for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort( - layer.w13_weight_g_idx[e]).to(torch.int32) - w2_g_idx_sort_indices[e] = torch.argsort( - layer.w2_weight_g_idx[e]).to(torch.int32) + w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to( + torch.int32 + ) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to( + torch.int32 + ) w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][ - w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][ - w2_g_idx_sort_indices[e]] + w13_g_idx_sort_indices[e] + ] + w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]] replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx) replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx) - replace_parameter(layer, "w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_parameter(layer, "w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) else: layer.w13_weight_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_weight_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) @@ -1456,8 +1593,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_weight_scale, - size_k=layer.w2_weight_scale.shape[1] * - (self.group_size if self.group_size != -1 else self.packed_factor), + size_k=layer.w2_weight_scale.shape[1] + * (self.group_size if self.group_size != -1 else self.packed_factor), size_n=layer.w2_weight_scale.shape[2], group_size=self.group_size, ) @@ -1466,7 +1603,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.workspace = marlin_make_workspace_new(device, 4) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return None def apply( @@ -1496,11 
+1634,10 @@ def apply( if enable_eplb: raise NotImplementedError( - "EPLB not supported for " - "`CompressedTensorsWNA16MarlinMoEMethod` yet.") + "EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet." + ) - assert activation == "silu", ( - f"{activation} not supported for Marlin MoE.") + assert activation == "silu", f"{activation} not supported for Marlin MoE." topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, @@ -1514,7 +1651,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return torch.ops.vllm.fused_marlin_moe( x, @@ -1536,11 +1674,11 @@ def apply( sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, workspace=layer.workspace, - is_k_full=self.is_k_full) + is_k_full=self.is_k_full, + ) class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): - def __init__( self, quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 @@ -1559,43 +1697,55 @@ def __init__( self.group_size = config.group_size # grouped actorder isn't supported by this kernel assert config.actorder != "group" - assert config.symmetric, ( - "Only symmetric quantization is supported for MoE") - - if not (self.quant_config.quant_format - == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS): - raise ValueError("For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}") + assert config.symmetric, "Only symmetric quantization is supported for MoE" - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + if not ( + self.quant_config.quant_format == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS + ): + raise ValueError( + "For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}", + ) + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Will transpose the loaded weight along the # intermediate and hidden dim sizes. 
Will # shard for TP along the transposed dims - extra_weight_attrs.update({ - "is_transposed": True, - "quant_method": self.strategy - }) - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size // self.packed_factor, - 2 * intermediate_size_per_partition, - dtype=torch.int32), - requires_grad=False) + extra_weight_attrs.update( + {"is_transposed": True, "quant_method": self.strategy} + ) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_packed", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - intermediate_size_per_partition // self.packed_factor, - hidden_size, - dtype=torch.int32), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight_packed", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -1608,30 +1758,34 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, num_groups_w2 = w2_scales_size // self.group_size num_groups_w13 = hidden_size // self.group_size - w13_scale = torch.nn.Parameter(torch.ones( - num_experts, - num_groups_w13, - 2 * intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w13_scale = torch.nn.Parameter( + torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_scale) set_weight_attrs(w13_scale, extra_weight_attrs) - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - num_groups_w2, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_scale) set_weight_attrs(w2_scale, extra_weight_attrs) set_weight_attrs(w2_scale, {"load_full_w2": False}) - w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) layer.register_parameter("w2_weight_shape", w2_weight_shape) set_weight_attrs(w2_weight_shape, extra_weight_attrs) - w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), - requires_grad=False) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) layer.register_parameter("w13_weight_shape", w13_weight_shape) set_weight_attrs(w13_weight_shape, extra_weight_attrs) @@ -1666,8 +1820,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, ), requires_grad=False, ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) w2_g_idx_sort_indices = torch.nn.Parameter( @@ -1678,8 +1831,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, ), requires_grad=False, ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) 
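Elsewhere in this patch (the Marlin MoE method's process_weights_after_loading), g_idx tensors like the ones registered above are reordered per expert with argsort so that columns from the same quantization group become contiguous. A toy sketch of that reordering, with a made-up per-column group index:

import torch

# Hypothetical g_idx: quantization group of each weight column.
g_idx = torch.tensor([2, 0, 1, 0, 2, 1], dtype=torch.int32)
sort_indices = torch.argsort(g_idx).to(torch.int32)  # column permutation
sorted_g_idx = g_idx[sort_indices]                   # groups now contiguous
print(sorted_g_idx.tolist())  # [0, 0, 1, 1, 2, 2]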
layer.a13_scale = None @@ -1688,25 +1840,29 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Reconfigure packed weights and scales to match moe_wna16 format layer.w13_weight_packed = torch.nn.Parameter( - layer.w13_weight_packed.transpose(1, 2).contiguous().view( - torch.uint8), - requires_grad=False) + layer.w13_weight_packed.transpose(1, 2).contiguous().view(torch.uint8), + requires_grad=False, + ) layer.w2_weight_packed = torch.nn.Parameter( - layer.w2_weight_packed.transpose(1, - 2).contiguous().view(torch.uint8), - requires_grad=False) + layer.w2_weight_packed.transpose(1, 2).contiguous().view(torch.uint8), + requires_grad=False, + ) layer.w13_weight_scale = torch.nn.Parameter( - layer.w13_weight_scale.transpose(1, 2).contiguous(), - requires_grad=False) + layer.w13_weight_scale.transpose(1, 2).contiguous(), requires_grad=False + ) layer.w2_weight_scale = torch.nn.Parameter( - layer.w2_weight_scale.transpose(1, 2).contiguous(), - requires_grad=False) + layer.w2_weight_scale.transpose(1, 2).contiguous(), requires_grad=False + ) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: assert self.num_bits == 4 or self.num_bits == 8 - config_builder = (int4_w4a16_moe_quant_config if self.num_bits == 4 - else int8_w8a16_moe_quant_config) + config_builder = ( + int4_w4a16_moe_quant_config + if self.num_bits == 4 + else int8_w8a16_moe_quant_config + ) return config_builder( w1_scale=layer.w13_weight_scale, @@ -1742,8 +1898,9 @@ def apply( assert self.fused_experts is None if enable_eplb: - raise NotImplementedError("EPLB not supported for " - "`CompressedTensorsWNA16MoEMethod` yet.") + raise NotImplementedError( + "EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet." 
+ ) from vllm.model_executor.layers.fused_moe import fused_experts @@ -1759,7 +1916,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( x, @@ -1787,9 +1945,10 @@ class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): """ def __init__( - self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 - moe: FusedMoEConfig): + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig, + ): super().__init__(moe) self.has_bias = self.moe.has_bias self.quant_config = quant_config @@ -1797,8 +1956,7 @@ def __init__( # Validate scheme: weights=W4 (channel or group), # activations=dynamic TOKEN (A8) wq = self.quant_config.target_scheme_map["Linear"].get("weights") - aq = self.quant_config.target_scheme_map["Linear"].get( - "input_activations") + aq = self.quant_config.target_scheme_map["Linear"].get("input_activations") # Must be dynamic per-token activations if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic: @@ -1809,8 +1967,7 @@ def __init__( # Weight can be channel-wise (group_size=None) or group-wise self.group_size = wq.group_size if (wq.group_size is not None) else -1 if wq.num_bits != 4: - raise ValueError( - "This method only supports 4-bit weights (num_bits=4).") + raise ValueError("This method only supports 4-bit weights (num_bits=4).") # CPU only if not current_platform.is_cpu(): @@ -1824,14 +1981,20 @@ def __init__( except AttributeError as err: raise RuntimeError( f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops; - install a newer build.""") from err + install a newer build.""" + ) from err self.static_input_scales = False # always dynamic per token # ---- parameter creation ---- - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Shapes per local rank (TP/EP): # w13: [E, 2*I_local, H] int8 (int4 values in [-8,7]) # w2 : [E, H, I_local] int8 @@ -1850,13 +2013,15 @@ def _n_scale_cols(in_features: int) -> int: return 1 if g == -1 else (in_features // g) # Register unpacked int4-as-int8 weights the loader will fill. 
- w13 = torch.nn.Parameter(torch.empty(E, 2 * IN, H, dtype=torch.int8), - requires_grad=False) + w13 = torch.nn.Parameter( + torch.empty(E, 2 * IN, H, dtype=torch.int8), requires_grad=False + ) set_weight_attrs(w13, extra_weight_attrs) layer.register_parameter("w13_weight", w13) - w2 = torch.nn.Parameter(torch.empty(E, H, IN, dtype=torch.int8), - requires_grad=False) + w2 = torch.nn.Parameter( + torch.empty(E, H, IN, dtype=torch.int8), requires_grad=False + ) set_weight_attrs(w2, extra_weight_attrs) layer.register_parameter("w2_weight", w2) @@ -1865,54 +2030,48 @@ def _n_scale_cols(in_features: int) -> int: # KleidiAI groupwise kernels accepts bfloat16 scales scale_dtype = torch.float32 if g == -1 else torch.bfloat16 - w13_s = torch.nn.Parameter(torch.ones(E, - 2 * IN, - _n_scale_cols(H), - dtype=scale_dtype), - requires_grad=False) + w13_s = torch.nn.Parameter( + torch.ones(E, 2 * IN, _n_scale_cols(H), dtype=scale_dtype), + requires_grad=False, + ) set_weight_attrs( - w13_s, { - "quant_method": "channel" if g == -1 else "group", - **extra_weight_attrs - }) + w13_s, + {"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs}, + ) layer.register_parameter("w13_weight_scale", w13_s) - w2_s = torch.nn.Parameter(torch.ones(E, - H, - _n_scale_cols(IN), - dtype=scale_dtype), - requires_grad=False) + w2_s = torch.nn.Parameter( + torch.ones(E, H, _n_scale_cols(IN), dtype=scale_dtype), requires_grad=False + ) set_weight_attrs( - w2_s, { - "quant_method": "channel" if g == -1 else "group", - **extra_weight_attrs - }) + w2_s, + {"quant_method": "channel" if g == -1 else "group", **extra_weight_attrs}, + ) layer.register_parameter("w2_weight_scale", w2_s) if self.has_bias: - w13_bias = torch.nn.Parameter(torch.zeros(E, - 2 * IN, - dtype=params_dtype), - requires_grad=False) + w13_bias = torch.nn.Parameter( + torch.zeros(E, 2 * IN, dtype=params_dtype), requires_grad=False + ) layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) - w2_bias = torch.nn.Parameter(torch.zeros(num_experts, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w2_bias = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=params_dtype), + requires_grad=False, + ) layer.register_parameter("w2_bias", w2_bias) set_weight_attrs(w2_bias, extra_weight_attrs) # Placeholders for packed weights (will be replaced after packing) layer.register_parameter( - "w13_weight_packed", - torch.nn.Parameter(torch.empty(0), requires_grad=False)) + "w13_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs) layer.register_parameter( - "w2_weight_packed", - torch.nn.Parameter(torch.empty(0), requires_grad=False)) + "w2_weight_packed", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs) # dims for 4 bit fused matmuls @@ -1930,15 +2089,19 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: IN = layer.w2_in_features g = layer.group_size - def _pack_matrix(int4_as_int8_2d: torch.Tensor, - scales_2d: torch.Tensor, - bias_1d: Optional[torch.Tensor], in_features: int, - out_features: int) -> torch.Tensor: + def _pack_matrix( + int4_as_int8_2d: torch.Tensor, + scales_2d: torch.Tensor, + bias_1d: Optional[torch.Tensor], + in_features: int, + out_features: int, + ) -> torch.Tensor: # int4 values are stored as int8 in [-8,7]. # Shift to unsigned nibble and pack pairs along input-dim. 
tmp = int4_as_int8_2d.add(8) # [out, in] uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to( - torch.uint8) # [out, in//2] + torch.uint8 + ) # [out, in//2] # KleidiAI groupwise kernels accepts float32 scales # KleidiAI groupwise kernels accepts bfloat16 scales @@ -1946,15 +2109,19 @@ def _pack_matrix(int4_as_int8_2d: torch.Tensor, scales = scales_2d.to(scale_dtype) bias = None if bias_1d is None else bias_1d.to(torch.float32) return torch.ops.aten._dyn_quant_pack_4bit_weight( - uint8_nibbles, scales, bias, g if g != -1 else in_features, - in_features, out_features) + uint8_nibbles, + scales, + bias, + g if g != -1 else in_features, + in_features, + out_features, + ) # Pack per expert w13_packed_list = [] w2_packed_list = [] - has_w13_bias = hasattr(layer, - "w13_bias") and layer.w13_bias is not None + has_w13_bias = hasattr(layer, "w13_bias") and layer.w13_bias is not None has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None for e in range(E): @@ -1964,7 +2131,9 @@ def _pack_matrix(int4_as_int8_2d: torch.Tensor, layer.w13_weight_scale[e], # [2I, H/g or 1] layer.w13_bias[e] if has_w13_bias else None, # [2I] H, - I2)) + I2, + ) + ) w2_packed_list.append( _pack_matrix( # w2 shape is [H, IN]; we need [out, in] == [H, IN]. @@ -1972,42 +2141,58 @@ def _pack_matrix(int4_as_int8_2d: torch.Tensor, layer.w2_weight_scale[e], # [H, IN/g or 1] layer.w2_bias[e] if has_w2_bias else None, # [H] IN, - layer.w2_out_features # in_features=IN, out_features=H - )) + layer.w2_out_features, # in_features=IN, out_features=H + ) + ) # each packed tensor has identical shape per expert; stack on dim 0 w13_packed = torch.stack(w13_packed_list, dim=0) w2_packed = torch.stack(w2_packed_list, dim=0) - replace_parameter(layer, "w13_weight_packed", - torch.nn.Parameter(w13_packed, requires_grad=False)) - replace_parameter(layer, "w2_weight_packed", - torch.nn.Parameter(w2_packed, requires_grad=False)) + replace_parameter( + layer, + "w13_weight_packed", + torch.nn.Parameter(w13_packed, requires_grad=False), + ) + replace_parameter( + layer, + "w2_weight_packed", + torch.nn.Parameter(w2_packed, requires_grad=False), + ) # free raw tensors/scales/bias now that they're packed into the payload. 
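        # Swapping in empty tensors keeps each parameter registered on the
        # layer while releasing the unpacked int8 weights, scales, and biases,
        # which are no longer needed once the packed payload exists.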
replace_parameter( - layer, "w13_weight", - torch.nn.Parameter(torch.empty(0), requires_grad=False)) + layer, "w13_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) replace_parameter( - layer, "w2_weight", - torch.nn.Parameter(torch.empty(0), requires_grad=False)) + layer, "w2_weight", torch.nn.Parameter(torch.empty(0), requires_grad=False) + ) replace_parameter( - layer, "w13_weight_scale", - torch.nn.Parameter(torch.empty(0), requires_grad=False)) + layer, + "w13_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) replace_parameter( - layer, "w2_weight_scale", - torch.nn.Parameter(torch.empty(0), requires_grad=False)) + layer, + "w2_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) if has_w13_bias: replace_parameter( - layer, "w13_bias", - torch.nn.Parameter(torch.empty(0), requires_grad=False)) + layer, + "w13_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) if has_w2_bias: replace_parameter( - layer, "w2_bias", - torch.nn.Parameter(torch.empty(0), requires_grad=False)) + layer, + "w2_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False), + ) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: # CPU dynamic 4-bit MoE path does not use modular kernels or # fused_experts; quant config is not needed. return None @@ -2036,9 +2221,9 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet." - assert activation in ( - "silu", "swigluoai", - "swiglu"), "Only SiLU/SwiGLUGU/SwiGLUUG are supported." + assert activation in ("silu", "swigluoai", "swiglu"), ( + "Only SiLU/SwiGLUGU/SwiGLUUG are supported." 
+ ) assert expert_map is None, """expert_map/EP not implemented for CPU dyn-4bit MoE.""" @@ -2068,7 +2253,15 @@ def _act_kind(s: str) -> int: ) return torch.ops._C.dynamic_4bit_int_moe( - x, topk_ids.to(torch.long), topk_weights, layer.w13_weight_packed, - layer.w2_weight_packed, layer.w2_out_features, - layer.w2_in_features, layer.w13_out_features, layer.group_size, - apply_router_weight_on_input, int(_act_kind(activation))) \ No newline at end of file + x, + topk_ids.to(torch.long), + topk_weights, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w2_out_features, + layer.w2_in_features, + layer.w13_out_features, + layer.group_size, + apply_router_weight_on_input, + int(_act_kind(activation)), + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index cac65cca5093..fc0634394ece 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -5,23 +5,30 @@ from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int -from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, - CompressedTensorsW4A16Sparse24) +from .compressed_tensors_w4a16_24 import ( + W4A16SPARSE24_SUPPORTED_BITS, + CompressedTensorsW4A16Sparse24, +) from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 -from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, - CompressedTensorsWNA16) +from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 from .compressed_tensors_24 import CompressedTensors24 # isort: skip __all__ = [ - "CompressedTensorsScheme", "CompressedTensorsWNA16", - "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", - "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", - "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", - "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", - "CompressedTensorsW4A8Fp8" + "CompressedTensorsScheme", + "CompressedTensorsWNA16", + "CompressedTensorsW8A16Fp8", + "CompressedTensorsW4A16Sparse24", + "CompressedTensorsW8A8Int8", + "CompressedTensorsW8A8Fp8", + "WNA16_SUPPORTED_BITS", + "W4A16SPARSE24_SUPPORTED_BITS", + "CompressedTensors24", + "CompressedTensorsW4A16Fp4", + "CompressedTensorsW4A4Fp4", + "CompressedTensorsW4A8Int", + "CompressedTensorsW4A8Fp8", ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 168b221a9cfe..068eecf5e026 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -5,25 +5,33 @@ import torch from compressed_tensors import CompressionFormat, ModelCompressor -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy, - QuantizationType) +from compressed_tensors.quantization import ( + QuantizationArgs, + 
QuantizationStrategy, + QuantizationType, +) from compressed_tensors.utils import combine_shards from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise, sparse_cutlass_supported) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + convert_to_channelwise, + sparse_cutlass_supported, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) __all__ = ["CompressedTensors24"] @@ -31,7 +39,6 @@ class CompressedTensors24(CompressedTensorsScheme): - def __init__( self, quantized: bool = False, @@ -44,14 +51,20 @@ def __init__( self.input_quant = input_quant self.model_compressor = ( ModelCompressor.from_compression_config(model_compression_config) - if model_compression_config is not None else None) + if model_compression_config is not None + else None + ) self.do_sparse_decompress = ( self.model_compressor is not None and self.model_compressor.sparsity_config.format - == CompressionFormat.sparse_24_bitmask.value) + == CompressionFormat.sparse_24_bitmask.value + ) - if quantized and input_quant is not None and \ - self._get_quant_dtype() == current_platform.fp8_dtype(): + if ( + quantized + and input_quant is not None + and self._get_quant_dtype() == current_platform.fp8_dtype() + ): static = not input_quant.dynamic g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN self.quant_fp8 = QuantFP8(static, g_shape) @@ -74,7 +87,8 @@ def create_weights( if not sparse_cutlass_supported(): raise ValueError( "Sparse CUTLASS not supported. 
vLLM must be built with " - "CUDA 12.2 or later to use this feature") + "CUDA 12.2 or later to use this feature" + ) layer.logical_widths = output_partition_sizes layer.input_size = input_size @@ -93,9 +107,9 @@ def create_weights( weight_loader=weight_loader, ) if self.do_sparse_decompress: - assert all(partition_size % 8 == 0 - for partition_size in output_partition_sizes - ), "All partitions must be divisible by 8 for " + assert all( + partition_size % 8 == 0 for partition_size in output_partition_sizes + ), "All partitions must be divisible by 8 for " "2:4 sparse compressed models" shape = BasevLLMParameter( @@ -130,20 +144,24 @@ def create_weights( # Check if quantized, not just 2:4 Sparse if self.quantized: - if (self.weight_quant and self.weight_quant.strategy - == QuantizationStrategy.CHANNEL.value): + if ( + self.weight_quant + and self.weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ): weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), + data=torch.empty( + (sum(output_partition_sizes), 1), dtype=torch.float32 + ), output_dim=0, weight_loader=weight_loader, ) else: - assert (self.weight_quant and self.weight_quant.strategy - == QuantizationStrategy.TENSOR.value) + assert ( + self.weight_quant + and self.weight_quant.strategy == QuantizationStrategy.TENSOR.value + ) weight_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), - dtype=torch.float32), + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) @@ -152,8 +170,7 @@ def create_weights( # input quant will be non-none if self.input_quant and not self.input_quant.dynamic: # register input quant scale - assert (self.input_quant.strategy == - QuantizationStrategy.TENSOR.value) + assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value input_scale = BasevLLMParameter( data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader, @@ -163,12 +180,12 @@ def create_weights( else: # for sparse-only, pass in 1 for weight/input scales - weight_scale = torch.nn.Parameter(data=torch.ones( - 1, dtype=torch.float32), - requires_grad=False) - input_scale = torch.nn.Parameter(data=torch.ones( - 1, dtype=torch.float32), - requires_grad=False) + weight_scale = torch.nn.Parameter( + data=torch.ones(1, dtype=torch.float32), requires_grad=False + ) + input_scale = torch.nn.Parameter( + data=torch.ones(1, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("input_scale", input_scale) layer.register_parameter("weight_scale", weight_scale) @@ -199,8 +216,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # torch.compile workaround if hasattr(layer, "input_scale"): - layer.input_scale = torch.nn.Parameter(layer.input_scale.data, - requires_grad=False) + layer.input_scale = torch.nn.Parameter( + layer.input_scale.data, requires_grad=False + ) if self.weight_quant: if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value: @@ -214,11 +232,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else: # torch.compile workaround layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.data, requires_grad=False) + layer.weight_scale.data, requires_grad=False + ) # Set all negative zero values to 0 prior to compression - if (layer.weight.dtype.is_floating_point - and layer.weight.dtype.itemsize >= 2): + if layer.weight.dtype.is_floating_point and layer.weight.dtype.itemsize >= 2: 
layer.weight.data[layer.weight.data == -0.0] = 0.0 w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data) @@ -243,7 +261,7 @@ def apply_weights( :return: The output tensor of the layer """ if self.quantized: - scale = getattr(layer, 'input_scale', None) + scale = getattr(layer, "input_scale", None) if self.weights_dtype == torch.int8: ops_output = ops.scaled_int8_quant(x, scale=scale) @@ -286,12 +304,16 @@ def _get_quant_dtype(self) -> torch.dtype: if not is_8_bits: raise ValueError("Cutlass only supports 8-bit quantization") - if (self.weight_quant.type == QuantizationType.FLOAT - and self.input_quant.type == QuantizationType.FLOAT): + if ( + self.weight_quant.type == QuantizationType.FLOAT + and self.input_quant.type == QuantizationType.FLOAT + ): return torch.float8_e4m3fn - if (self.weight_quant.type == QuantizationType.INT - and self.input_quant.type == QuantizationType.INT): + if ( + self.weight_quant.type == QuantizationType.INT + and self.input_quant.type == QuantizationType.INT + ): return torch.int8 raise ValueError("Quantization type not supported by Cutlass") @@ -317,7 +339,7 @@ def _decompress_bitmask_compressed_weight( :param bitmask: The 2:4 bitmask associated with the compressed weights, representing the positions of non-zero elements in the compressed tensor. - :param layer: The layer whose weights need to be processed after + :param layer: The layer whose weights need to be processed after loading. :return: The decompressed 2:4 sparse weight tensor. """ @@ -343,14 +365,16 @@ def _process_split( if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)): split_weights = torch.split(compressed, layer.logical_widths) split_bitmask = torch.split(bitmask, layer.logical_widths) - split_shape = [(out, layer.input_size_per_partition) - for out in layer.logical_widths] + split_shape = [ + (out, layer.input_size_per_partition) for out in layer.logical_widths + ] if split_weights: decompressed_shards = [ _process_split(compressed_weight, shape, bitmask) for compressed_weight, shape, bitmask in zip( - split_weights, split_shape, split_bitmask) + split_weights, split_shape, split_bitmask + ) ] decompressed = combine_shards(decompressed_shards) else: @@ -362,5 +386,6 @@ def _process_split( layer.input_size_per_partition, ), bitmask=bitmask, - )) + ) + ) return decompressed diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index a5d48f235674..688621cbf79a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -11,7 +11,7 @@ class CompressedTensorsScheme(ABC): """ - Abstract class used to describe the weight creation and forward pass + Abstract class used to describe the weight creation and forward pass of different quantization schemes supported by CompressedTensors. """ @@ -26,20 +26,21 @@ def get_min_capability(cls) -> int: @abstractmethod def create_weights(self, *args, **kwargs): """ - Weight creation for the particular scheme. Inputs to this function + Weight creation for the particular scheme. 
Inputs to this function """ raise NotImplementedError @abstractmethod - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]): + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ): """ - Run the forward pass for the particular scheme. This is where + Run the forward pass for the particular scheme. This is where scheme-specific dequant/quant steps/kernels should be applied. - :param layer: torch.nn.Module with the registered weights and - other parameters relevant to the particular scheme. + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. :param x: input to the layer :param bias: bias parameter diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py index 3f3e7668fcf7..af06418c959d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -8,13 +8,18 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) + GPTQ_MARLIN_24_MAX_PARALLEL, + GPTQ_MARLIN_24_MIN_THREAD_N, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsW4A16Sparse24"] @@ -25,11 +30,7 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): - - def __init__(self, - strategy: str, - num_bits: int, - group_size: Optional[int] = None): + def __init__(self, strategy: str, num_bits: int, group_size: Optional[int] = None): self.strategy = strategy self.group_size = group_size self.tile_size = 16 @@ -37,13 +38,13 @@ def __init__(self, if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP: raise ValueError( f"Unsupported num_bits = {num_bits}. 
" - f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}") + f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}" + ) self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits] if self.strategy == "group" and self.group_size is None: - raise ValueError( - "group_size must be given when using strategy group") + raise ValueError("group_size must be given when using strategy group") @classmethod def get_min_capability(cls) -> int: @@ -52,18 +53,20 @@ def get_min_capability(cls) -> int: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # required by torch.compile to be torch.nn.Parameter - layer.weight_packed = Parameter(layer.weight_packed.data, - requires_grad=False) - layer.scale_packed = Parameter(layer.scale_packed.data, - requires_grad=False) + layer.weight_packed = Parameter(layer.weight_packed.data, requires_grad=False) + layer.scale_packed = Parameter(layer.scale_packed.data, requires_grad=False) layer.meta = Parameter(layer.meta.data, requires_grad=False) - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - + def create_weights( + self, + layer: torch.nn.Module, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): assert params_dtype == torch.float16, ( "float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501 ) @@ -71,55 +74,59 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, pack_factor = 32 // self.quant_type.size_bits output_size_per_partition = sum(output_partition_sizes) - qweight = PackedvLLMParameter(data=torch.empty( - input_size_per_partition // self.tile_size // 2, - output_size_per_partition * self.tile_size // pack_factor, - dtype=torch.int32, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=pack_factor, - marlin_tile_size=self.tile_size, - weight_loader=weight_loader) - - input_groups = (1 if self.group_size is None else - input_size_per_partition // self.group_size) + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.tile_size // 2, + output_size_per_partition * self.tile_size // pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=pack_factor, + marlin_tile_size=self.tile_size, + weight_loader=weight_loader, + ) + + input_groups = ( + 1 + if self.group_size is None + else input_size_per_partition // self.group_size + ) weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( input_groups, output_size_per_partition, dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if self.group_size is not None: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) else: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) - - weight_shape = BasevLLMParameter(data=torch.empty(2, - dtype=torch.int64), - weight_loader=weight_loader) - - meta = PackedvLLMParameter(data=torch.empty( - input_size_per_partition // 8 // 2 // 2, - output_size_per_partition * 2, - dtype=torch.int16, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=1, - marlin_tile_size=2, - weight_loader=weight_loader) + scales = 
ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) + + weight_shape = BasevLLMParameter( + data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader + ) + + meta = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader, + ) layer.register_parameter("weight_packed", qweight) layer.register_parameter("weight_shape", weight_shape) @@ -127,16 +134,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.register_parameter("meta", meta) max_workspace_size = ( - output_size_per_partition // - GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL + output_size_per_partition // GPTQ_MARLIN_24_MIN_THREAD_N + ) * GPTQ_MARLIN_24_MAX_PARALLEL - workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int), - requires_grad=False) + workspace = Parameter( + torch.zeros(max_workspace_size, dtype=torch.int), requires_grad=False + ) layer.workspace = workspace - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: qweight = layer.weight_packed meta = layer.meta scales = layer.scale_packed @@ -148,11 +156,19 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, size_k = x_2d.shape[1] size_n = scales.shape[1] - output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, - workspace, self.quant_type, size_m, - size_n, size_k) + output_2d = ops.gptq_marlin_24_gemm( + x_2d, + qweight, + meta, + scales, + workspace, + self.quant_type, + size_m, + size_n, + size_k, + ) - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py index 96dccf04d490..a96f51538b38 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py @@ -6,18 +6,22 @@ from torch.nn.parameter import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin) -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, +) +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) __all__ = ["CompressedTensorsW4A16Fp4"] class CompressedTensorsW4A16Fp4(CompressedTensorsScheme): - def __init__(self, has_input_global_scale: bool = False): self.has_input_global_scale = has_input_global_scale self.group_size = 16 @@ -27,49 +31,59 @@ def get_min_capability(cls) -> int: # dont restrict as emulations return 80 - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - 
input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition # Weight - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // 2, - dtype=torch.uint8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_packed", weight) # Global Weight Scale weight_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("weight_global_scale", weight_global_scale) # Per Group Weight Scale - weight_scale = GroupQuantScaleParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // self.group_size, - dtype=torch.float8_e4m3fn, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) if self.has_input_global_scale: input_global_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), - dtype=torch.float32), - weight_loader=weight_loader) + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("input_global_scale", input_global_scale) def process_weights_after_loading(self, layer) -> None: @@ -81,25 +95,30 @@ def process_weights_after_loading(self, layer) -> None: # Rename weight_global_scale to weight_scale_2 that marlin expects # Note: ct stores the inverse of what is expected by the marlin kernel layer.weight_scale_2 = Parameter( - 1 / layer.weight_global_scale.max().to(torch.float32), - requires_grad=False) + 1 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False + ) del layer.weight_global_scale if self.has_input_global_scale: layer.input_global_scale = torch.nn.Parameter( - layer.input_global_scale.data, requires_grad=False) + layer.input_global_scale.data, requires_grad=False + ) prepare_fp4_layer_for_marlin(layer) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return apply_fp4_marlin_linear(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - weight_scale_2=layer.weight_scale_2, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias) + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + 
workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index d472427756d4..676f4de6ee7b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -9,14 +9,17 @@ from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501 - run_nvfp4_emulations) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - swizzle_blockscale) -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + run_nvfp4_emulations, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer logger = init_logger(__name__) @@ -25,7 +28,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): - def __init__(self): if envs.VLLM_USE_TRTLLM_FP4_GEMM: assert has_flashinfer(), "TRTLLM FP4 GEMM requires FlashInfer" @@ -54,58 +56,67 @@ def get_min_capability(cls) -> int: return 80 return 100 - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition # Weight - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // 2, - dtype=torch.uint8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_packed", weight) # Global Weight Scale weight_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("weight_global_scale", weight_global_scale) # Per Group Weight Scale - weight_scale = GroupQuantScaleParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition // self.group_size, - dtype=torch.float8_e4m3fn, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 
self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) input_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("input_global_scale", input_global_scale) def process_weights_after_loading(self, layer) -> None: - global_input_scale = layer.input_global_scale.max().to(torch.float32) - layer.input_global_scale = Parameter(global_input_scale, - requires_grad=False) + layer.input_global_scale = Parameter(global_input_scale, requires_grad=False) layer.weight_global_scale = Parameter( - layer.weight_global_scale.max().to(torch.float32), - requires_grad=False) + layer.weight_global_scale.max().to(torch.float32), requires_grad=False + ) if self.backend == "flashinfer-trtllm": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. @@ -118,40 +129,43 @@ def process_weights_after_loading(self, layer) -> None: weight_scale = layer.weight_scale.data epilogue_tile_m = 128 - weight = shuffle_matrix_a(weight.view(torch.uint8), - epilogue_tile_m) - weight_scale = (shuffle_matrix_sf_a(weight_scale.view( - torch.uint8), epilogue_tile_m).reshape( - weight_scale.shape).view(torch.float8_e4m3fn)) + weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) + weight_scale = ( + shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale.shape) + .view(torch.float8_e4m3fn) + ) layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_packed = Parameter(weight, requires_grad=False) else: swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) if self.backend == "fbgemm": - swizzled_weight_scale = swizzled_weight_scale.view(-1).view( - torch.uint8) - layer.weight_scale = Parameter(swizzled_weight_scale, - requires_grad=False) - layer.weight_packed = Parameter(layer.weight_packed.data, - requires_grad=False) + swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8) + layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) + layer.weight_packed = Parameter( + layer.weight_packed.data, requires_grad=False + ) layer.alpha = Parameter( 1 / (layer.input_global_scale * layer.weight_global_scale), - requires_grad=False) - - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + requires_grad=False, + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if envs.VLLM_USE_NVFP4_CT_EMULATIONS: out = run_nvfp4_emulations( x=x, input_global_scale=layer.input_global_scale, weight=layer.weight_packed, weight_scale_swizzled=layer.weight_scale, - weight_global_scale=layer.weight_global_scale) + weight_global_scale=layer.weight_global_scale, + ) if bias is not None: out = out + bias return out @@ -162,8 +176,14 @@ def apply_weights(self, # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale) - mm_args = (x_fp4, layer.weight_packed, x_blockscale, - layer.weight_scale, layer.alpha, output_dtype) + mm_args = ( + x_fp4, + layer.weight_packed, + x_blockscale, + layer.weight_scale, + layer.alpha, + output_dtype, + ) if self.backend == "flashinfer-trtllm": out = flashinfer_scaled_fp4_mm(*mm_args, backend="trtllm") elif self.backend == 
"flashinfer-cutlass": diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py index 3d9827058803..59d99e1e1c90 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py @@ -8,18 +8,21 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( - MPLinearLayerConfig, choose_mp_linear_kernel) + MPLinearLayerConfig, + choose_mp_linear_kernel, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_repeat_scales_on_all_ranks) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) -# yapf: enable + marlin_repeat_scales_on_all_ranks, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -34,13 +37,14 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): _kernel_backends_being_used: set[str] = set() - def __init__(self, - strategy: str, - num_bits: int, - group_size: Optional[int] = None, - symmetric: Optional[bool] = True, - actorder: Optional[ActivationOrdering] = None): - + def __init__( + self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None, + ): self.pack_factor = 32 // num_bits self.strategy = strategy self.symmetric = symmetric @@ -48,13 +52,15 @@ def __init__(self, self.has_g_idx = actorder == ActivationOrdering.GROUP if self.group_size != 128 or self.strategy != "group": - raise ValueError("W4A8 kernels require group quantization " \ - "with group size 128") + raise ValueError( + "W4A8 kernels require group quantization with group size 128" + ) if num_bits not in W4A8_SUPPORTED_TYPES_MAP: raise ValueError( f"Unsupported num_bits = {num_bits}. 
" - f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}") + f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}" + ) self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] @@ -63,38 +69,45 @@ def get_min_capability(cls) -> int: # hopper return 90 - def create_weights(self, layer: torch.nn.Module, output_size: int, - input_size: int, output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - + def create_weights( + self, + layer: torch.nn.Module, + output_size: int, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), - partition_weight_shape=\ - (input_size_per_partition, output_size_per_partition), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), weight_type=self.quant_type, act_type=torch.float8_e4m3fn, # always use fp8(e4m3) group_size=self.group_size, zero_points=not self.symmetric, has_g_idx=self.has_g_idx, - out_type=params_dtype + out_type=params_dtype, ) kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW4A8Fp8", - kernel_type.__name__) + logger.info("Using %s for CompressedTensorsW4A8Fp8", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # If group_size is -1, we are in channelwise case. group_size = self.group_size if self.group_size != -1 else input_size - row_parallel = (input_size != input_size_per_partition) + row_parallel = input_size != input_size_per_partition partition_scales = not marlin_repeat_scales_on_all_ranks( - self.has_g_idx, self.group_size, row_parallel) + self.has_g_idx, self.group_size, row_parallel + ) scales_and_zp_size = input_size // group_size @@ -102,68 +115,69 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, assert input_size_per_partition % group_size == 0 scales_and_zp_size = input_size_per_partition // group_size - weight = PackedvLLMParameter(input_dim=1, - output_dim=0, - weight_loader=weight_loader, - packed_factor=self.pack_factor, - packed_dim=1, - data=torch.empty( - output_size_per_partition, - input_size_per_partition // - self.pack_factor, - dtype=torch.int32, - )) + weight = PackedvLLMParameter( + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + ) # TODO(czhu): allocate the packed fp8 scales memory here? 
# the scales will be expanded by 8x via `cutlass_pack_scale_fp8` weight_scale_args = { - "weight_loader": - weight_loader, - "data": - torch.empty( + "weight_loader": weight_loader, + "data": torch.empty( output_size_per_partition, scales_and_zp_size, dtype=torch.float8_e4m3fn, - ) + ), } if not partition_scales: - weight_scale = ChannelQuantScaleParameter(output_dim=0, - **weight_scale_args) + weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) else: - weight_scale = GroupQuantScaleParameter(output_dim=0, - input_dim=1, - **weight_scale_args) + weight_scale = GroupQuantScaleParameter( + output_dim=0, input_dim=1, **weight_scale_args + ) # A 2D array defining the original shape of the weights # before packing - weight_shape = BasevLLMParameter(data=torch.empty(2, - dtype=torch.int64), - weight_loader=weight_loader) + weight_shape = BasevLLMParameter( + data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader + ) # per-channel scales weight_chan_scale = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), - dtype=torch.float32), + data=torch.empty((output_size_per_partition, 1), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("weight_packed", weight) layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) layer.register_parameter("weight_chan_scale", weight_chan_scale) - self.kernel = kernel_type(mp_linear_kernel_config, - w_q_param_name="weight_packed", - w_s_param_name="weight_scale", - w_zp_param_name="weight_zero_point", - w_gidx_param_name="weight_g_idx") + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx", + ) # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. 
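    # The selected MPLinearKernel looks up the parameters registered above by
    # the names passed to it (weight_packed, weight_scale, ...) and repacks
    # them into its preferred layout once checkpoint loading is complete.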
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py index f1fca85508a6..61a9f6b75cb1 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py @@ -7,12 +7,17 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( - MPLinearLayerConfig, choose_mp_linear_kernel) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - ModelWeightParameter) + MPLinearLayerConfig, + choose_mp_linear_kernel, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + ModelWeightParameter, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -27,12 +32,14 @@ class CompressedTensorsW4A8Int(CompressedTensorsScheme): _kernel_backends_being_used: set[str] = set() - def __init__(self, - strategy: str, - num_bits: int, - group_size: Optional[int] = None, - is_static_input_scheme: bool = False, - input_symmetric: bool = True): + def __init__( + self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + is_static_input_scheme: bool = False, + input_symmetric: bool = True, + ): self.strategy = strategy self.group_size = -1 if group_size is None else group_size self.is_static_input_scheme = is_static_input_scheme @@ -41,42 +48,53 @@ def __init__(self, if num_bits not in W4A8_SUPPORTED_TYPES_MAP: raise ValueError( f"Unsupported num_bits = {num_bits}." 
- f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}") + f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}" + ) self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] @classmethod def get_min_capability(cls) -> int: return 1 - def create_weights(self, layer: torch.nn.Module, output_size: int, - input_size: int, output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_size: int, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) - row_parallel = (input_size != input_size_per_partition) + row_parallel = input_size != input_size_per_partition # Compute effective group_size if self.group_size == -1: - effective_group_size = (input_size_per_partition - if row_parallel else input_size) + effective_group_size = ( + input_size_per_partition if row_parallel else input_size + ) else: effective_group_size = self.group_size # Ensure group_size divides input_size_per_partition assert input_size_per_partition % effective_group_size == 0, ( f"input_size_per_partition {input_size_per_partition}" - f" not divisible by group_size {effective_group_size}") + f" not divisible by group_size {effective_group_size}" + ) # Determine scale partitioning - is_channelwise = (self.group_size == -1) - repeat_scales = (is_channelwise and row_parallel) + is_channelwise = self.group_size == -1 + repeat_scales = is_channelwise and row_parallel partition_scales = not repeat_scales mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), - partition_weight_shape=(input_size_per_partition, - output_size_per_partition), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), weight_type=self.quant_type, act_type=params_dtype, group_size=effective_group_size, @@ -86,50 +104,50 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW4A8Int", - kernel_type.__name__) + logger.info("Using %s for CompressedTensorsW4A8Int", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) scales_and_zp_size = input_size_per_partition // effective_group_size - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.int8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) weight_scale_args = { - "weight_loader": - weight_loader, - "data": - torch.empty(output_size_per_partition, - scales_and_zp_size, - dtype=params_dtype) + "weight_loader": weight_loader, + "data": torch.empty( + output_size_per_partition, scales_and_zp_size, dtype=params_dtype + ), } if partition_scales: - weight_scale = GroupQuantScaleParameter(output_dim=0, - input_dim=1, - **weight_scale_args) + weight_scale = GroupQuantScaleParameter( + output_dim=0, input_dim=1, **weight_scale_args + ) else: - weight_scale = ChannelQuantScaleParameter(output_dim=0, - **weight_scale_args) + 
weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) layer.register_parameter("weight_packed", weight) layer.register_parameter("weight_scale", weight_scale) - self.kernel = kernel_type(mp_linear_kernel_config, - w_q_param_name="weight_packed", - w_s_param_name="weight_scale", - w_zp_param_name=None, - w_gidx_param_name=None) + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name=None, + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py index 01a87a088899..709d2538e6ad 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -7,24 +7,27 @@ from compressed_tensors.quantization import QuantizationStrategy from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + convert_to_channelwise, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) __all__ = ["CompressedTensorsW8A16Fp8"] -SUPPORTED_STRATEGIES = [ - QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR -] +SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR] class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): - def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme @@ -39,31 +42,36 @@ def get_min_capability(cls) -> int: # we expand each scale to its shard's channels. 
def process_weights_after_loading(self, layer) -> None: if self.strategy == QuantizationStrategy.TENSOR: - ws_channelwise = convert_to_channelwise(layer.weight_scale, - layer.logical_widths) - layer.weight_scale = torch.nn.Parameter(ws_channelwise, - requires_grad=False) + ws_channelwise = convert_to_channelwise( + layer.weight_scale, layer.logical_widths + ) + layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False) else: # required by torch.compile to be torch.nn.Parameter - layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, - requires_grad=False) + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data, requires_grad=False + ) # Weights must be transposed for marlin - layer.weight = torch.nn.Parameter(layer.weight.t(), - requires_grad=False) + layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False) if self.is_static_input_scheme: # required by torch.compile to be torch.nn.Parameter - layer.input_scale = torch.nn.Parameter(layer.input_scale.data, - requires_grad=False) + layer.input_scale = torch.nn.Parameter( + layer.input_scale.data, requires_grad=False + ) prepare_fp8_layer_for_marlin(layer) - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - + def create_weights( + self, + layer: torch.nn.Module, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition @@ -72,50 +80,59 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.weight_block_size = None # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE if self.strategy == QuantizationStrategy.CHANNEL: weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) elif self.strategy == QuantizationStrategy.TENSOR: - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) else: raise ValueError( f"Unsupported weight strategy={self.strategy}, " - f"supported strategies are {SUPPORTED_STRATEGIES}") + f"supported strategies are {SUPPORTED_STRATEGIES}" + ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE (to deal with converted checkpoints) if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + input_scale = 
PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("input_scale", input_scale) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - return apply_fp8_marlin_linear(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias) + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 4755c17c5967..902c9c7bde97 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -4,25 +4,35 @@ from typing import Callable, Optional import torch -from compressed_tensors.quantization import (QuantizationArgs, - QuantizationStrategy) +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from torch.nn import Parameter from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, - create_fp8_input_scale, create_fp8_scale_parameter, - create_fp8_weight_parameter, maybe_post_process_fp8_weight_block, - process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy, - process_fp8_weight_tensor_strategy, validate_fp8_block_shape) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + W8A8BlockFp8LinearOp, + check_aiter_fp8_linear_support, + create_fp8_input_scale, + create_fp8_scale_parameter, + create_fp8_weight_parameter, + maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, + process_fp8_weight_channel_strategy, + process_fp8_weight_tensor_strategy, + validate_fp8_block_shape, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, cutlass_block_fp8_supported, maybe_create_device_identity) -from vllm.model_executor.parameter import (BlockQuantScaleParameter, - ChannelQuantScaleParameter, - PerTensorScaleParameter) + Fp8LinearOp, + cutlass_block_fp8_supported, + maybe_create_device_identity, +) +from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ChannelQuantScaleParameter, + PerTensorScaleParameter, +) __all__ = ["CompressedTensorsW8A8Fp8"] @@ -34,9 +44,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): - - def __init__(self, weight_quant: QuantizationArgs, - is_static_input_scheme: bool): + def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): self.weight_quant = weight_quant self.strategy = weight_quant.strategy self.out_dtype = torch.get_default_dtype() @@ -46,8 
+54,11 @@ def __init__(self, weight_quant: QuantizationArgs, if self.weight_block_size is not None: self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) else: - self.act_q_group_shape = GroupShape.PER_TENSOR \ - if is_static_input_scheme else GroupShape.PER_TOKEN + self.act_q_group_shape = ( + GroupShape.PER_TENSOR + if is_static_input_scheme + else GroupShape.PER_TOKEN + ) self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() @@ -63,18 +74,25 @@ def __init__(self, weight_quant: QuantizationArgs, else: self.fp8_linear = Fp8LinearOp( act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_q_group_shape) + act_quant_group_shape=self.act_q_group_shape, + ) @classmethod def get_min_capability(cls) -> int: # lovelace and up return 89 - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - weight_loader: Callable, **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): maybe_create_device_identity() output_size_per_partition = sum(output_partition_sizes) @@ -86,48 +104,57 @@ def create_weights(self, layer: torch.nn.Module, assert self.weight_block_size is not None layer.weight_block_size = self.weight_block_size # Validate block quantization shapes - validate_fp8_block_shape(layer, input_size, output_size, - input_size_per_partition, - output_partition_sizes, - self.weight_block_size) + validate_fp8_block_shape( + layer, + input_size, + output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size, + ) # WEIGHT - weight = create_fp8_weight_parameter(output_size_per_partition, - input_size_per_partition, - weight_loader) + weight = create_fp8_weight_parameter( + output_size_per_partition, input_size_per_partition, weight_loader + ) layer.register_parameter("weight", weight) # WEIGHT SCALE weight_scale = create_fp8_scale_parameter( - strategy_to_parameter_type[self.strategy], output_partition_sizes, - input_size_per_partition, layer.weight_block_size, weight_loader) + strategy_to_parameter_type[self.strategy], + output_partition_sizes, + input_size_per_partition, + layer.weight_block_size, + weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: - input_scale = create_fp8_input_scale(output_partition_sizes, - weight_loader) + input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader) layer.register_parameter("input_scale", input_scale) def process_weights_after_loading(self, layer) -> None: if self.strategy == QuantizationStrategy.TENSOR: - weight, weight_scale, input_scale = ( - process_fp8_weight_tensor_strategy( - layer.weight, layer.weight_scale, layer.logical_widths, - getattr(layer, 'input_scale', None))) + weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( + layer.weight, + layer.weight_scale, + layer.logical_widths, + getattr(layer, "input_scale", None), + ) weight = weight.t() elif self.strategy == QuantizationStrategy.CHANNEL: - weight, weight_scale, input_scale = ( - process_fp8_weight_channel_strategy( - layer.weight, layer.weight_scale, - getattr(layer, 'input_scale', None))) + weight, weight_scale, input_scale = 
process_fp8_weight_channel_strategy( + layer.weight, layer.weight_scale, getattr(layer, "input_scale", None) + ) weight = weight.t() elif self.strategy == QuantizationStrategy.BLOCK: assert self.is_static_input_scheme is False weight, weight_scale = process_fp8_weight_block_strategy( - layer.weight, layer.weight_scale) + layer.weight, layer.weight_scale + ) input_scale = None else: @@ -137,25 +164,23 @@ def process_weights_after_loading(self, layer) -> None: layer.weight = Parameter(weight.data, requires_grad=False) layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) if input_scale is not None: - layer.input_scale = Parameter(input_scale.data, - requires_grad=False) + layer.input_scale = Parameter(input_scale.data, requires_grad=False) # INPUT SCALE - if self.is_static_input_scheme and hasattr(layer, 'input_scale'): - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) + if self.is_static_input_scheme and hasattr(layer, "input_scale"): + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) else: layer.input_scale = None if self.strategy == QuantizationStrategy.BLOCK: - maybe_post_process_fp8_weight_block( - layer, self.cutlass_block_fp8_supported) - - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.weight_block_size is not None: return self.w8a8_block_fp8_linear.apply( input=x, @@ -165,9 +190,11 @@ def apply_weights(self, bias=bias, ) - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias) + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index 6189f0609d85..70316a7553ca 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -8,13 +8,18 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + ScaledMMLinearLayerConfig, + choose_scaled_mm_linear_kernel, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) logger = init_logger(__name__) @@ -22,8 +27,9 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): _kernel_backends_being_used: set[str] = set() - def __init__(self, strategy: str, is_static_input_scheme: bool, - input_symmetric: bool): + def __init__( + self, strategy: str, is_static_input_scheme: bool, 
input_symmetric: bool + ): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme self.input_symmetric = input_symmetric @@ -33,56 +39,61 @@ def get_min_capability(cls) -> int: # turing and up return 75 - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): layer.logical_widths = output_partition_sizes scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), is_static_input_scheme=self.is_static_input_scheme, - input_symmetric=self.input_symmetric) + input_symmetric=self.input_symmetric, + ) - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config) + kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsW8A8Int8", - kernel_type.__name__) + logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=torch.int8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE if self.strategy == QuantizationStrategy.CHANNEL: weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) else: assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE if self.is_static_input_scheme: - input_scale = BasevLLMParameter(data=torch.empty( - 1, dtype=torch.float32), - weight_loader=weight_loader) + input_scale = BasevLLMParameter( + data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader + ) layer.register_parameter("input_scale", input_scale) if not self.input_symmetric: @@ -90,22 +101,25 @@ def create_weights(self, layer: torch.nn.Module, # as the weights # AZP loaded as int8 but used as int32 input_zero_point = BasevLLMParameter( - data=torch.empty(1, dtype=torch.int8), - weight_loader=weight_loader) + data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader + ) layer.register_parameter("input_zero_point", input_zero_point) - self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj") + self.kernel = kernel_type( + 
c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj", + ) # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 74787603e002..188fc15fd948 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -8,29 +8,29 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( - MPLinearLayerConfig, choose_mp_linear_kernel) + MPLinearLayerConfig, + choose_mp_linear_kernel, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - marlin_repeat_scales_on_all_ranks) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) -# yapf: enable + marlin_repeat_scales_on_all_ranks, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) __all__ = ["CompressedTensorsWNA16"] -WNA16_SUPPORTED_TYPES_MAP = { - 4: scalar_types.uint4b8, - 8: scalar_types.uint8b128 -} +WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128} WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8} WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) @@ -38,13 +38,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): _kernel_backends_being_used: set[str] = set() - def __init__(self, - strategy: str, - num_bits: int, - group_size: Optional[int] = None, - symmetric: Optional[bool] = True, - actorder: Optional[ActivationOrdering] = None): - + def __init__( + self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None, + ): self.pack_factor = 32 // num_bits self.strategy = strategy self.symmetric = symmetric @@ -52,55 +53,67 @@ def __init__(self, self.has_g_idx = actorder == ActivationOrdering.GROUP if self.group_size == -1 and self.strategy != "channel": - raise ValueError("Marlin kernels require group quantization or " - "channelwise quantization, but found no group " - "size and strategy is not channelwise.") + raise ValueError( + "Marlin kernels require group quantization or " + 
"channelwise quantization, but found no group " + "size and strategy is not channelwise." + ) if num_bits not in WNA16_SUPPORTED_TYPES_MAP: raise ValueError( f"Unsupported num_bits = {num_bits}. " - f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") + f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}" + ) - self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits] - if not self.symmetric else - WNA16_SUPPORTED_TYPES_MAP[num_bits]) + self.quant_type = ( + WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits] + if not self.symmetric + else WNA16_SUPPORTED_TYPES_MAP[num_bits] + ) @classmethod def get_min_capability(cls) -> int: # ampere and up return 80 - def create_weights(self, layer: torch.nn.Module, output_size: int, - input_size: int, output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - + def create_weights( + self, + layer: torch.nn.Module, + output_size: int, + input_size: int, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), - partition_weight_shape=\ - (input_size_per_partition, output_size_per_partition), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), weight_type=self.quant_type, act_type=params_dtype, group_size=self.group_size, zero_points=not self.symmetric, - has_g_idx=self.has_g_idx + has_g_idx=self.has_g_idx, ) kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for CompressedTensorsWNA16", - kernel_type.__name__) + logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # If group_size is -1, we are in channelwise case. 
group_size = self.group_size if self.group_size != -1 else input_size - row_parallel = (input_size != input_size_per_partition) + row_parallel = input_size != input_size_per_partition partition_scales = not marlin_repeat_scales_on_all_ranks( - self.has_g_idx, self.group_size, row_parallel) + self.has_g_idx, self.group_size, row_parallel + ) scales_and_zp_size = input_size // group_size @@ -108,65 +121,65 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, assert input_size_per_partition % group_size == 0 scales_and_zp_size = input_size_per_partition // group_size - weight = PackedvLLMParameter(input_dim=1, - output_dim=0, - weight_loader=weight_loader, - packed_factor=self.pack_factor, - packed_dim=1, - data=torch.empty( - output_size_per_partition, - input_size_per_partition // - self.pack_factor, - dtype=torch.int32, - )) + weight = PackedvLLMParameter( + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + ) weight_scale_args = { - "weight_loader": - weight_loader, - "data": - torch.empty( + "weight_loader": weight_loader, + "data": torch.empty( output_size_per_partition, scales_and_zp_size, dtype=params_dtype, - ) + ), } zeros_args = { - "weight_loader": - weight_loader, - "data": - torch.zeros( + "weight_loader": weight_loader, + "data": torch.zeros( output_size_per_partition // self.pack_factor, scales_and_zp_size, dtype=torch.int32, - ) + ), } if not partition_scales: - weight_scale = ChannelQuantScaleParameter(output_dim=0, - **weight_scale_args) + weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) if not self.symmetric: - qzeros = PackedColumnParameter(output_dim=0, - packed_dim=0, - packed_factor=self.pack_factor, - **zeros_args) + qzeros = PackedColumnParameter( + output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args, + ) else: - weight_scale = GroupQuantScaleParameter(output_dim=0, - input_dim=1, - **weight_scale_args) + weight_scale = GroupQuantScaleParameter( + output_dim=0, input_dim=1, **weight_scale_args + ) if not self.symmetric: - qzeros = PackedvLLMParameter(input_dim=1, - output_dim=0, - packed_dim=0, - packed_factor=self.pack_factor, - **zeros_args) + qzeros = PackedvLLMParameter( + input_dim=1, + output_dim=0, + packed_dim=0, + packed_factor=self.pack_factor, + **zeros_args, + ) # A 2D array defining the original shape of the weights # before packing - weight_shape = BasevLLMParameter(data=torch.empty(2, - dtype=torch.int64), - weight_loader=weight_loader) + weight_shape = BasevLLMParameter( + data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader + ) layer.register_parameter("weight_packed", weight) layer.register_parameter("weight_scale", weight_scale) @@ -177,25 +190,30 @@ def create_weights(self, layer: torch.nn.Module, output_size: int, # group index (for activation reordering) if self.has_g_idx: - weight_g_idx = RowvLLMParameter(data=torch.empty( - input_size_per_partition, - dtype=torch.int32, - ), - input_dim=0, - weight_loader=weight_loader) + weight_g_idx = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_g_idx", weight_g_idx) - self.kernel = kernel_type(mp_linear_kernel_config, - w_q_param_name="weight_packed", - w_s_param_name="weight_scale", - 
w_zp_param_name="weight_zero_point", - w_gidx_param_name="weight_g_idx") + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx", + ) # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py index d098185146e4..a51fe28b975e 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py @@ -5,19 +5,28 @@ from typing import Callable, Optional import torch -from compressed_tensors.transform import (TransformArgs, TransformConfig, - TransformLocation, TransformScheme) +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformLocation, + TransformScheme, +) from compressed_tensors.utils import is_match -from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, - LinearMethodBase, - QKVCrossParallelLinear) +from vllm.model_executor.layers.linear import ( + WEIGHT_LOADER_V2_SUPPORTED, + LinearMethodBase, + QKVCrossParallelLinear, +) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsScheme) + CompressedTensorsScheme, +) from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501 - HadamardTransform) + HadamardTransform, +) from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 - TransformTuple) + TransformTuple, +) class CompressedTensorsLinearTransformMethod(LinearMethodBase): @@ -35,21 +44,25 @@ def from_schemes( output_tfms: dict[int, TransformTuple], ) -> "CompressedTensorsLinearTransformMethod": from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import ( # noqa: E501 - QutlassNvFP4LinearMethod, is_qutlass_fp4_scheme) + QutlassNvFP4LinearMethod, + is_qutlass_fp4_scheme, + ) assert input_tfms or output_tfms if is_qutlass_fp4_scheme(quant_scheme, input_tfms): - return QutlassNvFP4LinearMethod(quant_method, input_tfms, - output_tfms) + return QutlassNvFP4LinearMethod(quant_method, input_tfms, output_tfms) # hadacore or dense gemm is selected by Transform module return cls(quant_method, input_tfms, output_tfms) - def __init__(self, quant_method: LinearMethodBase, - input_tfms: dict[int, TransformTuple], - output_tfms: dict[int, TransformTuple]): + def __init__( + self, + quant_method: LinearMethodBase, + input_tfms: dict[int, TransformTuple], + output_tfms: dict[int, TransformTuple], + ): self.quant_method = quant_method self.input_tfms = input_tfms self.output_tfms = output_tfms @@ -57,15 +70,18 @@ def __init__(self, quant_method: LinearMethodBase, self.input_transform: Optional[HadamardTransform] = None self.output_transform: 
Optional[HadamardTransform] = None - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # get weight loader for transforms - weight_loader: Callable = extra_weight_attrs.get( - "weight_loader") # type: ignore[assignment] + weight_loader: Callable = extra_weight_attrs.get("weight_loader") # type: ignore[assignment] # HACK: UnquantizedLinearMethod does not support weight loader v2, but # transforms (specifically SharedWeightParameter) requires @@ -86,7 +102,8 @@ def create_weights(self, layer: torch.nn.Module, input_size=input_size, output_size=output_size, params_dtype=params_dtype, - **extra_weight_attrs) + **extra_weight_attrs, + ) # validate schemes num_partitions = len(output_partition_sizes) @@ -98,10 +115,13 @@ def create_weights(self, layer: torch.nn.Module, location = list(self.input_tfms.values())[0].args.location transform_name = f"{scheme_name}_{location}" - transform = HadamardTransform(self.input_tfms, layer, - weight_loader, - input_size_per_partition, - output_partition_sizes) + transform = HadamardTransform( + self.input_tfms, + layer, + weight_loader, + input_size_per_partition, + output_partition_sizes, + ) layer.register_module(transform_name, transform) self.input_transform = transform @@ -110,10 +130,13 @@ def create_weights(self, layer: torch.nn.Module, location = list(self.output_tfms.values())[0].args.location transform_name = f"{scheme_name}_{location}" - transform = HadamardTransform(self.output_tfms, layer, - weight_loader, - input_size_per_partition, - output_partition_sizes) + transform = HadamardTransform( + self.output_tfms, + layer, + weight_loader, + input_size_per_partition, + output_partition_sizes, + ) layer.register_module(transform_name, transform) self.output_transform = transform @@ -128,11 +151,12 @@ def process_weights_after_loading(self, layer): if isinstance(submodule, HadamardTransform): submodule.process_weights_after_loading() - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.input_transform is not None: x = self.input_transform(x) @@ -143,8 +167,9 @@ def apply(self, # (@ksayers): confirm that this is done concurrently if self.output_transform is not None: for part_id, (start, length) in enumerate(self.partition_ranges): - x[:, start:start + length] = self.output_transform( - x[:, start:start + length].contiguous(), part_id=part_id) + x[:, start : start + length] = self.output_transform( + x[:, start : start + length].contiguous(), part_id=part_id + ) return x @@ -171,39 +196,41 @@ def _validate_tfm_schemes(self, num_partitions: int): def get_linear_transform_schemes( - layer: torch.nn.Module, layer_name: str, + layer: torch.nn.Module, + layer_name: str, transform_config: Optional[TransformConfig], - packed_modules_mapping: dict[str, list[str]] -) -> tuple[dict[int, TransformTuple], dict[ - int, TransformTuple]]: # [input_transform, [output_transform, ...]] + packed_modules_mapping: dict[str, list[str]], +) -> tuple[ + dict[int, TransformTuple], dict[int, 
TransformTuple] +]: # [input_transform, [output_transform, ...]] # there can only be one transform input scheme per (fused) module input_tfms = {} output_tfms = {} - partition_names = get_layer_partition_names(layer_name, - packed_modules_mapping) + partition_names = get_layer_partition_names(layer_name, packed_modules_mapping) for scheme_name, scheme, args in get_schemes_args(transform_config): for part_index, part_name in enumerate(partition_names): - if is_match(part_name, layer, args.targets, - args.ignore) and args.is_online(): + if ( + is_match(part_name, layer, args.targets, args.ignore) + and args.is_online() + ): if args.location == TransformLocation.INPUT: - input_tfms[part_index] = TransformTuple( - scheme_name, scheme, args) + input_tfms[part_index] = TransformTuple(scheme_name, scheme, args) elif args.location == TransformLocation.OUTPUT: - output_tfms[part_index] = TransformTuple( - scheme_name, scheme, args) + output_tfms[part_index] = TransformTuple(scheme_name, scheme, args) else: - raise ValueError(f"Cannot apply `{args.location}` " - f"transform to `{layer_name}`") + raise ValueError( + f"Cannot apply `{args.location}` transform to `{layer_name}`" + ) return (input_tfms, output_tfms) def get_schemes_args( - transform_config: Optional[TransformConfig] + transform_config: Optional[TransformConfig], ) -> Generator[tuple[str, TransformScheme, TransformArgs]]: if transform_config is None: return @@ -214,20 +241,20 @@ def get_schemes_args( def get_layer_partition_names( - layer_name: str, packed_modules_mapping: dict[str, - list[str]]) -> list[str]: + layer_name: str, packed_modules_mapping: dict[str, list[str]] +) -> list[str]: """ Get all partition names associated with this layer. Names are returned in order of their partition indices. 
- + ```python mapping = {"gate_up_proj", "gate_proj", "up_proj"} - assert get_layer_partition_names( - "mlp.gate_up_proj", mapping) == ["gate_proj", "up_proj"] - assert get_layer_partition_names( - "mlp.down_proj", mapping) == ["down_proj"] - """ + assert get_layer_partition_names("mlp.gate_up_proj", mapping) == [ + "gate_proj", + "up_proj", + ] + assert get_layer_partition_names("mlp.down_proj", mapping) == ["down_proj"]""" for fused_suffix, part_suffixes in packed_modules_mapping.items(): if layer_name.endswith(fused_suffix): return [ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py index 5e863354715e..ecd798257fce 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py @@ -5,19 +5,21 @@ from typing import Callable import torch -from compressed_tensors.transform import (TransformArgs, TransformLocation, - TransformScheme) +from compressed_tensors.transform import ( + TransformArgs, + TransformLocation, + TransformScheme, +) from torch import Tensor import vllm._custom_ops as ops -from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_world_size) +from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501 - TransformTuple) + TransformTuple, +) from vllm.model_executor.layers.utils import dispatch_unquantized_gemm -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.parameter import SharedWeightParameter @@ -27,22 +29,28 @@ class HadamardTransform(torch.nn.Module): transforms. 
Meant to be used with `CompressedTensorsLinearTransformMethod` and attention transforms method (not implemented yet) """ + transforms: dict[int, TransformTuple] # info parsed from transforms config weight: SharedWeightParameter # container for shared tensors scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0)) - def __init__(self, transforms: dict[int, TransformTuple], - layer: torch.nn.Module, weight_loader: Callable, - input_size_per_partition: int, - output_partition_sizes: list[int]): + def __init__( + self, + transforms: dict[int, TransformTuple], + layer: torch.nn.Module, + weight_loader: Callable, + input_size_per_partition: int, + output_partition_sizes: list[int], + ): super().__init__() self.transforms = transforms self.scales = {} if get_tensor_model_parallel_world_size() > 1: - raise NotImplementedError("Online transforms with tensor " - "parallelism is not supported") + raise NotImplementedError( + "Online transforms with tensor parallelism is not supported" + ) # Similar to row/col parallel params, but tensors are separate # to allow for loading with shared memory @@ -50,11 +58,11 @@ def __init__(self, transforms: dict[int, TransformTuple], # create shared partition data for each partition of the original weight input_size = input_size_per_partition - for part_index, (_scheme_name, scheme, - args) in self.transforms.items(): + for part_index, (_scheme_name, scheme, args) in self.transforms.items(): output_size = output_partition_sizes[part_index] - weight_size = self._get_weight_size(layer, scheme, args, - input_size, output_size) + weight_size = self._get_weight_size( + layer, scheme, args, input_size, output_size + ) data_key = self._get_data_key(scheme, weight_size) self.weight.add_partition( @@ -101,28 +109,41 @@ def forward(self, value: Tensor, part_id: int = 0) -> Tensor: # fall back to dense else: weight = self.weight.partitions[part_id] - weight = weight if self.transforms[ - part_id].args.inverse else weight.T # linear := x(W.T) + weight = ( + weight if self.transforms[part_id].args.inverse else weight.T + ) # linear := x(W.T) scale = self.scales[part_id] if self.transforms[part_id].scheme.head_dim is not None: value = value.unflatten(-1, (-1, weight.size(0))) - value = dispatch_unquantized_gemm()(self, value.to( - weight.dtype), weight, None).to(value.dtype) * scale + value = ( + dispatch_unquantized_gemm()( + self, value.to(weight.dtype), weight, None + ).to(value.dtype) + * scale + ) value = value.flatten(-2, -1) return value - return dispatch_unquantized_gemm()(self, value.to( - weight.dtype), weight, None).to(value.dtype) * scale + return ( + dispatch_unquantized_gemm()( + self, value.to(weight.dtype), weight, None + ).to(value.dtype) + * scale + ) - def _get_data_key(self, scheme: TransformScheme, - weight_size: int) -> Hashable: + def _get_data_key(self, scheme: TransformScheme, weight_size: int) -> Hashable: return (id(scheme), weight_size) - def _get_weight_size(self, layer: torch.nn.Module, scheme: TransformScheme, - args: TransformArgs, input_size: int, - output_size: int) -> int: + def _get_weight_size( + self, + layer: torch.nn.Module, + scheme: TransformScheme, + args: TransformArgs, + input_size: int, + output_size: int, + ) -> int: if scheme.head_dim is not None: return scheme.head_dim diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py index 69b39f31eec1..b800c5f5d436 
100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py @@ -5,42 +5,61 @@ import torch from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsScheme, CompressedTensorsW4A4Fp4) + CompressedTensorsScheme, + CompressedTensorsW4A4Fp4, +) from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 - CompressedTensorsLinearTransformMethod, TransformTuple) + CompressedTensorsLinearTransformMethod, + TransformTuple, +) __all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"] -def is_qutlass_fp4_scheme(quant_scheme: Optional[CompressedTensorsScheme], - input_tfms: dict[int, TransformTuple]) -> bool: - return isinstance( - quant_scheme, - (CompressedTensorsW4A4Fp4, )) and len(input_tfms) == 1 and input_tfms[ - 0].scheme.head_dim == quant_scheme.group_size +def is_qutlass_fp4_scheme( + quant_scheme: Optional[CompressedTensorsScheme], + input_tfms: dict[int, TransformTuple], +) -> bool: + return ( + isinstance(quant_scheme, (CompressedTensorsW4A4Fp4,)) + and len(input_tfms) == 1 + and input_tfms[0].scheme.head_dim == quant_scheme.group_size + ) class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod): - - def create_weights(self, layer, input_size_per_partition, - output_partition_sizes, input_size, output_size, - params_dtype, **extra_weight_attrs): + def create_weights( + self, + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ): # initializes fp4 qparams - assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4, )) - ret = super().create_weights(layer, input_size_per_partition, - output_partition_sizes, input_size, - output_size, params_dtype, - **extra_weight_attrs) + assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4,)) + ret = super().create_weights( + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ) assert self.input_transform is not None assert len(self.input_transform.weight) == 1 - assert self.input_transform.weight[0].size( - 0) == layer.scheme.group_size + assert self.input_transform.weight[0].size(0) == layer.scheme.group_size return ret - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py index d926b4c12db1..ed326197295d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py @@ -17,13 +17,29 @@ def is_weak_contiguous(x: torch.Tensor): @triton.jit -def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, - M, N, K, stride_am, stride_ak, stride_bk, stride_bn, - stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_SCALE_A: tl.constexpr, - BLOCK_SIZE_SCALE_B: tl.constexpr): +def scaled_mm_kernel( + 
a_ptr, + b_ptr, + scale_a_ptr, + scale_b_ptr, + c_ptr, + bias_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + ACCUMULATOR_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_SCALE_A: tl.constexpr, + BLOCK_SIZE_SCALE_B: tl.constexpr, +): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -32,8 +48,7 @@ def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, pid_n = pid % num_pid_n accumulator_dtype = ACCUMULATOR_DTYPE - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), - dtype=accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) # NOTE: Some tensor inputs are so large, they will cause int32 overflow # so it is necessary to use tl.int64 for all the offsets, else SEGV will @@ -47,20 +62,22 @@ def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, masks_bn = offsets_bn < N offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) - offsets_a = (stride_am * offsets_am[:, None] + - stride_ak * offsets_k[None, :]) - offsets_b = (stride_bk * offsets_k[:, None] + - stride_bn * offsets_bn[None, :]) + offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :] + offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :] # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create # appropriate offsets and masks for each case. Same goes for # BLOCK_SIZE_SCALE_B. - offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) + - (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M) + offsets_scale_am = ( + tl.arange(0, BLOCK_SIZE_SCALE_A) + + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M + ) masks_scale_am = offsets_scale_am < M - offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) + - (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N) + offsets_scale_bn = ( + tl.arange(0, BLOCK_SIZE_SCALE_B) + + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N + ) masks_scale_bn = offsets_scale_bn < N a_ptrs = a_ptr + offsets_a @@ -114,8 +131,7 @@ def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) offs_cm = offs_cm.to(tl.int64) offs_cn = offs_cn.to(tl.int64) - c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] + - stride_cn * offs_cn[None, :]) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) @@ -123,16 +139,18 @@ def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr, # input - [M, K] # weight - [K, N] -def triton_scaled_mm(input: torch.Tensor, - weight: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: type[torch.dtype], - bias: Optional[torch.Tensor] = None, - block_size_m: int = 32, - block_size_n: int = 32, - block_size_k: int = 32, - use_heuristic=True) -> torch.Tensor: +def triton_scaled_mm( + input: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, + use_heuristic=True, +) -> torch.Tensor: M, K = input.shape N = weight.shape[1] @@ -144,17 +162,16 @@ def triton_scaled_mm(input: torch.Tensor, scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b assert scale_a.dtype == scale_b.dtype and 
scale_a.is_floating_point() - assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 - or scale_a.shape[0] == M) - assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 - or scale_b.shape[0] == N) + assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M) + assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N) assert out_dtype.is_floating_point assert bias is None or bias.is_floating_point() assert is_weak_contiguous(input) assert is_weak_contiguous(weight) - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( - N, META['BLOCK_SIZE_N']), ) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) result = torch.empty((M, N), dtype=out_dtype, device=input.device) @@ -181,26 +198,28 @@ def triton_scaled_mm(input: torch.Tensor, # A = input, B = weight, C = result # A = M x K, B = K x N, C = M x N - scaled_mm_kernel[grid](input, - weight, - scale_a, - scale_b, - result, - bias, - M, - N, - K, - input.stride(0), - input.stride(1), - weight.stride(0), - weight.stride(1), - result.stride(0), - result.stride(1), - accumulator_dtype, - BLOCK_SIZE_M=block_size_m, - BLOCK_SIZE_N=block_size_n, - BLOCK_SIZE_K=block_size_k, - BLOCK_SIZE_SCALE_A=block_size_sa, - BLOCK_SIZE_SCALE_B=block_size_sb) + scaled_mm_kernel[grid]( + input, + weight, + scale_a, + scale_b, + result, + bias, + M, + N, + K, + input.stride(0), + input.stride(1), + weight.stride(0), + weight.stride(1), + result.stride(0), + result.stride(1), + accumulator_dtype, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + BLOCK_SIZE_SCALE_A=block_size_sa, + BLOCK_SIZE_SCALE_B=block_size_sb, + ) return result.to(out_dtype) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index b2dd2501095f..d8beaafff2ef 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -15,7 +15,7 @@ def is_activation_quantization_format(format: str) -> bool: CompressionFormat.naive_quantized.value, CompressionFormat.int_quantized.value, CompressionFormat.float_quantized.value, - CompressionFormat.nvfp4_pack_quantized.value + CompressionFormat.nvfp4_pack_quantized.value, ] return format in _ACTIVATION_QUANTIZATION_FORMATS @@ -23,7 +23,7 @@ def is_activation_quantization_format(format: str) -> bool: def should_ignore_layer( layer_name: Optional[str], ignore: Iterable[str] = tuple(), - fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> bool: if layer_name is None: return False @@ -49,7 +49,8 @@ def should_ignore_layer( should_ignore_layer = None for shard_name in shard_names: should_ignore_shard = check_equal_or_regex_match( - layer_name=shard_name, targets=ignore) + layer_name=shard_name, targets=ignore + ) # If shard_idx=0, set layer ignore to match shard. if should_ignore_layer is None: @@ -57,37 +58,36 @@ def should_ignore_layer( # If shard_idx=1+ confirm scheme matches prior shards. elif should_ignore_shard != should_ignore_layer: - raise ValueError(f"Found a different quantization schemes for " - f"{shard_proj_names} in {layer_name}. vLLM " - "requires all to use the same scheme.") + raise ValueError( + f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. 
vLLM " + "requires all to use the same scheme." + ) # Unfused layers like down_proj and o_proj will match # the safetensors checkpoint already. else: - should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, - targets=ignore) + should_ignore_layer = check_equal_or_regex_match( + layer_name=layer_name, targets=ignore + ) assert should_ignore_layer is not None return should_ignore_layer -def check_equal_or_regex_match(layer_name: str, - targets: Iterable[str]) -> bool: +def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool: """ Checks whether a layer_name is exactly equal or a regex match for if target starts with 're:' to any target in list. """ - for target in targets: - if _is_equal_or_regex_match(layer_name, target): - return True - return False + return any(_is_equal_or_regex_match(layer_name, target) for target in targets) def find_matched_target( layer_name: Optional[str], module: Module, targets: Iterable[str], - fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> str: """ Helper function to look up which "target" in the compressed-tensors @@ -120,19 +120,21 @@ def find_matched_target( matched_target = ( _find_first_match(layer_name, targets) or _find_first_match(module.__class__.__name__, targets, True) - or _match_fused_layer(layer_name, targets, fused_mapping)) + or _match_fused_layer(layer_name, targets, fused_mapping) + ) if matched_target is None: raise ValueError( f"Unable to find matching target for {layer_name} in the " - "compressed-tensors config.") + "compressed-tensors config." + ) return matched_target -def _find_first_match(value: str, - targets: Iterable[str], - check_contains: bool = False) -> Optional[str]: +def _find_first_match( + value: str, targets: Iterable[str], check_contains: bool = False +) -> Optional[str]: """ Returns first element of target that matches value either exactly or as a regex after 're:'. If check_contains is set to True, @@ -144,16 +146,14 @@ def _find_first_match(value: str, """ for target in targets: - if _is_equal_or_regex_match(value, - target, - check_contains=check_contains): + if _is_equal_or_regex_match(value, target, check_contains=check_contains): return target return None -def _is_equal_or_regex_match(value: str, - target: str, - check_contains: bool = False) -> bool: +def _is_equal_or_regex_match( + value: str, target: str, check_contains: bool = False +) -> bool: """ Checks whether a value is exactly equal or a regex match for target if target starts with 're:'. If check_contains is set to True, @@ -173,10 +173,12 @@ def _is_equal_or_regex_match(value: str, def _match_fused_layer( - layer_name: str, target_layers: Iterable[str], - fused_mapping: Mapping[str, list[str]]) -> Optional[str]: + layer_name: str, + target_layers: Iterable[str], + fused_mapping: Mapping[str, list[str]], +) -> Optional[str]: """ - Match a fused layer name to its corresponding individual layer in + Match a fused layer name to its corresponding individual layer in target_layers. 
Returns first value in fused_mapping which matches targets Implements an "all" matching strategy where a fused layer matches iff @@ -193,8 +195,7 @@ def _match_fused_layer( "model.layers.0.self_attn.v_proj"] """ # find layer_name in mapping - fused = next((key for key in fused_mapping if layer_name.endswith(key)), - None) + fused = next((key for key in fused_mapping if layer_name.endswith(key)), None) if fused is None: return None diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py index 4a189ab4a171..82a2103a19f3 100644 --- a/vllm/model_executor/layers/quantization/deepspeedfp.py +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -9,15 +9,17 @@ from packaging import version from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import (QuantizationConfig, - QuantizationMethods) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.model_executor.utils import set_weight_attrs class DeepSpeedFPConfig(QuantizationConfig): """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. - - Args: + + Args: weight_bits: the target quantization bits, 6 or 8. group_size: group size for quantizaiton, default to 128. """ @@ -36,11 +38,14 @@ def __init__( raise ValueError( "Currently, only 6-bit or 8-bit weight quantization are " f"supported for DeepSpeed FP quantizaiton, but got " - f"{self.weight_bits} bits.") + f"{self.weight_bits} bits." + ) def __repr__(self) -> str: - return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " - f"group_size={self.group_size}") + return ( + f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -71,8 +76,9 @@ def get_config_filenames() -> list[str]: "quantize_config.json", ] - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["DeepSpeedFPLinearMethod"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["DeepSpeedFPLinearMethod"]: if isinstance(layer, LinearBase): return DeepSpeedFPLinearMethod(self) return None @@ -89,15 +95,17 @@ def __init__(self, quant_config: DeepSpeedFPConfig): self.quant_config = quant_config self.weight = None - def create_weights(self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - weight_loader=None, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs, + ): del output_size del input_size output_size_per_partition = sum(output_partition_sizes) @@ -106,10 +114,13 @@ def create_weights(self, params_dtype=params_dtype, quant_config=self.quant_config, ) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - }) + set_weight_attrs( + weight, + { + "input_dim": 1, + "output_dim": 0, + }, + ) layer.register_parameter("weight", weight) def quant_weight_loader(param, loaded_weight, *args, **kwargs): @@ -125,10 +136,12 @@ def quant_weight_loader(param, loaded_weight, *args, **kwargs): extra_weight_attrs["weight_loader"] = quant_weight_loader set_weight_attrs(weight, extra_weight_attrs) - def apply(self, - layer: torch.nn.Module, - x: 
torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: weight = layer.weight y = weight.ds_dequantize() return F.linear(x, y, bias) @@ -141,23 +154,33 @@ class DeepSpeedFPParameter(nn.Parameter): GPUs, and can be dequantized on-the-fly when needed by the model. """ - def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, - quant_config: DeepSpeedFPConfig): + def __new__( + cls, + orig_shape: torch.Size, + params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig, + ): try: import deepspeed + if version.parse(deepspeed.__version__) < version.parse("0.14.2"): - raise ImportError("deepspeed version is wrong. Please " - "install deepspeed>=0.14.2.") + raise ImportError( + "deepspeed version is wrong. Please install deepspeed>=0.14.2." + ) from deepspeed.ops.fp_quantizer import FP_Quantize except ImportError as err: - raise ImportError("Please install deepspeed>=0.14.2 via " - "`pip install deepspeed>=0.14.2` to use " - "deepspeedfp quantizer.") from err - data = torch.empty(( - orig_shape.numel() // quant_config.group_size, - quant_config.group_size * quant_config.weight_bits // 8 + 4, - ), - dtype=torch.int8) + raise ImportError( + "Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer." + ) from err + data = torch.empty( + ( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8, + ) self = torch.Tensor._make_subclass(cls, data, data.requires_grad) self.orig_shape = orig_shape self.quant_config = quant_config @@ -172,7 +195,8 @@ def ds_quantize_(self, tensor: torch.Tensor): self.fp_quantizer.quantize( tensor.data, q_bits=self.quant_config.weight_bits, - )) + ) + ) def ds_dequantize(self, fp_out=None) -> torch.Tensor: """ @@ -180,7 +204,8 @@ def ds_dequantize(self, fp_out=None) -> torch.Tensor: """ assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 return self.fp_quantizer.dequantize( - self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits + ) def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: """ @@ -189,7 +214,5 @@ def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: """ assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 return self.fp_quantizer.selective_dequantize( - self.data, - indices, - fp_out=fp_out, - q_bits=self.quant_config.weight_bits) + self.data, indices, fp_out=fp_out, q_bits=self.quant_config.weight_bits + ) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 38d7e200b303..909b04c79f23 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -6,15 +6,21 @@ import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, +) from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, int8_w8a16_moe_quant_config) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) + FusedMoEQuantConfig, + 
int8_w8a16_moe_quant_config, +) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.utils import set_weight_attrs @@ -44,8 +50,9 @@ def get_config_filenames(cls) -> list[str]: def from_config(cls, config: dict[str, Any]) -> "ExpertsInt8Config": return cls() - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): @@ -54,7 +61,6 @@ def get_quant_method(self, layer: torch.nn.Module, class ExpertsInt8MoEMethod(FusedMoEMethodBase): - def __init__( self, quant_config: ExpertsInt8Config, @@ -63,57 +69,70 @@ def __init__( super().__init__(moe) self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): int8_dtype = torch.int8 - assert 'weight_loader' in extra_weight_attrs - weight_loader = extra_weight_attrs['weight_loader'] + assert "weight_loader" in extra_weight_attrs + weight_loader = extra_weight_attrs["weight_loader"] wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader( - layer, weight_loader) - extra_weight_attrs['weight_loader'] = wrapped_weight_loader + layer, weight_loader + ) + extra_weight_attrs["weight_loader"] = wrapped_weight_loader # Fused gate_up_proj (column parallel) - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=int8_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=int8_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=int8_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=int8_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - w13_scale = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - dtype=torch.float32), - requires_grad=False) + w13_scale = torch.nn.Parameter( + torch.zeros( + num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32 + ), + requires_grad=False, + ) layer.register_parameter("w13_scale", w13_scale) - w2_scale = torch.nn.Parameter(torch.zeros(num_experts, - hidden_size, - dtype=torch.float32), - requires_grad=False) + w2_scale = torch.nn.Parameter( + torch.zeros(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) layer.register_parameter("w2_scale", w2_scale) def get_fused_moe_quant_config( - self, 
layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: - return int8_w8a16_moe_quant_config(w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - w1_zp=None, - w2_zp=None) + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + return int8_w8a16_moe_quant_config( + w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None + ) def apply( self, @@ -142,7 +161,8 @@ def apply( if enable_eplb: raise NotImplementedError( - "EPLB not supported for `ExpertsInt8MoEMethod` yet.") + "EPLB not supported for `ExpertsInt8MoEMethod` yet." + ) from vllm.model_executor.layers.fused_moe import fused_experts @@ -158,7 +178,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( x, @@ -176,11 +197,13 @@ def apply( @staticmethod def quantizing_weight_loader(layer, weight_loader): - - def quantize_and_call_weight_loader(param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, shard_id: int, - expert_id: int): + def quantize_and_call_weight_loader( + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: int, + expert_id: int, + ): tp_rank = get_tensor_model_parallel_rank() shard_size = layer.intermediate_size_per_partition shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) @@ -188,33 +211,28 @@ def quantize_and_call_weight_loader(param: torch.nn.Parameter, loaded_weight = loaded_weight.to(device) # w1, gate_proj case: Load into first shard of w13. if shard_id == "w1": - scales = quantize_in_place_and_get_scales( - loaded_weight[shard, :]) - layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, - 0]) + scales = quantize_in_place_and_get_scales(loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, 0]) # w3, up_proj case: Load into second shard of w13. elif shard_id == "w3": - scales = quantize_in_place_and_get_scales( - loaded_weight[shard, :]) - layer.w13_scale.data[expert_id, shard_size:2 * - shard_size].copy_(scales[:, 0]) + scales = quantize_in_place_and_get_scales(loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, shard_size : 2 * shard_size].copy_( + scales[:, 0] + ) # w2, down_proj case: Load into only shard of w2. 
             elif shard_id == "w2":
-                scales = quantize_in_place_and_get_scales(loaded_weight[:,
-                                                                        shard])
+                scales = quantize_in_place_and_get_scales(loaded_weight[:, shard])
                 layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
             else:
-                raise ValueError(
-                    f"Shard id must be in [0,1,2] but got {shard_id}")
-            weight_loader(param, loaded_weight, weight_name, shard_id,
-                          expert_id)
+                raise ValueError(f"Shard id must be w1, w2 or w3 but got {shard_id}")
+            weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
 
         return quantize_and_call_weight_loader
 
 
 def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
     vmax = torch.iinfo(torch.int8).max
-    scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)
+    scales = torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax
 
     weight.div_(scales)
     weight.round_()
diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
index b2cab7d4614a..5d390cbd7b1e 100644
--- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py
+++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -8,19 +8,33 @@
 from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase)
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape, is_layer_skipped)
+    GroupShape,
+    is_layer_skipped,
+)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz)
-from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
-                                           ModelWeightParameter)
+    Fp8LinearOp,
+    maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+from vllm.model_executor.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -60,23 +74,26 @@ def from_config(cls, config: dict[str, Any]) -> "FBGEMMFp8Config":
         input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
         return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["QuantizeMethodBase"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix=prefix,
-                                ignored_layers=self.ignore_list,
-                                fused_mapping=self.packed_modules_mapping):
+            if is_layer_skipped(
+                prefix=prefix,
+                ignored_layers=self.ignore_list,
+                fused_mapping=self.packed_modules_mapping,
+            ):
                 return UnquantizedLinearMethod()
             return FBGEMMFp8LinearMethod(self)
         return None
 
 
 class FBGEMMFp8LinearMethod(LinearMethodBase):
-
     def __init__(self, quant_config: FBGEMMFp8Config):
         self.quant_config = quant_config
         self.fp8_linear = Fp8LinearOp(
-            act_quant_static=False,
act_quant_group_shape=GroupShape.PER_TOKEN + ) self.out_dtype = torch.get_default_dtype() def create_weights( @@ -101,43 +118,45 @@ def create_weights( layer.orig_dtype = params_dtype # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE - weight_scale = ChannelQuantScaleParameter(data=torch.empty( - (sum(output_partition_sizes), 1), dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader) + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE UPPER BOUND - input_scale_ub = torch.nn.Parameter(torch.tensor( - (self.quant_config.input_scale_ub), dtype=torch.float32), - requires_grad=False) + input_scale_ub = torch.nn.Parameter( + torch.tensor((self.quant_config.input_scale_ub), dtype=torch.float32), + requires_grad=False, + ) layer.input_scale_ub = input_scale_ub def process_weights_after_loading(self, layer: Module) -> None: # required by torch.compile - layer.weight_scale = Parameter(layer.weight_scale.data, - requires_grad=False) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) layer.weight = Parameter(layer.weight.data, requires_grad=False) weight = layer.weight if current_platform.is_fp8_fnuz(): - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=layer.weight_scale, - input_scale=None) + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=layer.weight_scale, input_scale=None + ) if input_scale is not None: layer.input_scale = Parameter(input_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) @@ -148,11 +167,12 @@ def process_weights_after_loading(self, layer: Module) -> None: # Activations not quantized for marlin. 
del layer.input_scale_ub - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.quant_config.use_marlin: return apply_fp8_marlin_linear( input=x, @@ -161,12 +181,15 @@ def apply(self, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - bias=bias) - - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=None, - input_scale_ub=layer.input_scale_ub, - bias=bias) + bias=bias, + ) + + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=None, + input_scale_ub=layer.input_scale_ub, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index dbcf4b2fbee5..2123fd9eba15 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -14,51 +14,85 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( - FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, - FusedMoeWeightScaleSupported) + FusedMoE, + FusedMoEActivationFormat, + FusedMoEMethodBase, + FusedMoEPermuteExpertsUnpermute, + FusedMoEPrepareAndFinalize, + FusedMoeWeightScaleSupported, +) from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, fp8_w8a8_moe_quant_config) -from vllm.model_executor.layers.fused_moe.layer import ( - UnquantizedFusedMoEMethod) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, + FlashinferMoeBackend, + apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, - register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, - select_cutlass_fp8_gemm_impl, swap_w13_to_w31) + flashinfer_cutlass_moe_fp8, + get_flashinfer_moe_backend, + register_moe_scaling_factors, + rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, + swap_w13_to_w31, +) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support, - create_fp8_input_scale, create_fp8_scale_parameter, - create_fp8_weight_parameter, expert_weight_is_col_major, - maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, - process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace, - validate_fp8_block_shape) + 
W8A8BlockFp8LinearOp, + check_aiter_fp8_linear_support, + create_fp8_input_scale, + create_fp8_scale_parameter, + create_fp8_weight_parameter, + expert_weight_is_col_major, + maybe_post_process_fp8_weight_block, + process_fp8_weight_block_strategy, + process_fp8_weight_tensor_strategy, + requant_weight_ue8m0_inplace, + validate_fp8_block_shape, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, - prepare_moe_fp8_layer_for_marlin) + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, + prepare_moe_fp8_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped) + GroupShape, + is_layer_skipped, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, - cutlass_fp8_supported, maybe_create_device_identity, - normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) -from vllm.model_executor.parameter import (BlockQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + Fp8LinearOp, + all_close_1d, + cutlass_block_fp8_supported, + cutlass_fp8_supported, + maybe_create_device_identity, + normalize_e4m3fn_to_e4m3fnuz, + per_tensor_dequantize, +) +from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import (get_col_major_tma_aligned_tensor, - is_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import ( + get_col_major_tma_aligned_tensor, + is_deep_gemm_e8m0_used, + is_deep_gemm_supported, +) from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -85,22 +119,25 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: Note: Shape-specific fallbacks may still occur at runtime. 
""" # prefer FlashInfer backends when available and enabled on supported GPUs - if (current_platform.is_cuda() - and current_platform.is_device_capability(100) - and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()): + if ( + current_platform.is_cuda() + and current_platform.is_device_capability(100) + and envs.VLLM_USE_FLASHINFER_MOE_FP8 + and has_flashinfer_moe() + ): backend = get_flashinfer_moe_backend() if backend == FlashinferMoeBackend.TENSORRT_LLM: - logger.info_once( - "Using FlashInfer FP8 MoE TRTLLM backend for SM100") + logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") return Fp8MoeBackend.FLASHINFER_TRTLLM else: - logger.info_once( - "Using FlashInfer FP8 MoE CUTLASS backend for SM100") + logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100") return Fp8MoeBackend.FLASHINFER_CUTLASS # weight-only path for older GPUs without native FP8 - use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN) + use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + ) if current_platform.is_rocm(): use_marlin = False if use_marlin: @@ -110,17 +147,18 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend: # deepGEMM on supported platforms with block-quantized weights if envs.VLLM_USE_DEEP_GEMM and block_quant: if not has_deep_gemm(): - logger.warning_once( - "DeepGEMM backend requested but not available.") + logger.warning_once("DeepGEMM backend requested but not available.") elif is_deep_gemm_supported(): logger.info_once("Using DeepGEMM backend for FP8 MoE") return Fp8MoeBackend.DEEPGEMM # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights - if (current_platform.is_cuda() - and current_platform.is_device_capability(100) and block_quant): - logger.info_once( - "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE") + if ( + current_platform.is_cuda() + and current_platform.is_device_capability(100) + and block_quant + ): + logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE") return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM # default to Triton @@ -143,23 +181,26 @@ def __init__( self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if activation_scheme not in ACTIVATION_SCHEMES: - raise ValueError( - f"Unsupported activation scheme {activation_scheme}") + raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] if weight_block_size is not None: if not is_checkpoint_fp8_serialized: raise ValueError( "The block-wise quantization only supports fp8-serialized " - "checkpoint for now.") + "checkpoint for now." + ) if len(weight_block_size) != 2: raise ValueError( "The quantization block size of weight must have 2 " - f"dimensions, but got {len(weight_block_size)} dimensions") + f"dimensions, but got {len(weight_block_size)} dimensions" + ) if activation_scheme != "dynamic": - raise ValueError("The block-wise quantization only supports " - "dynamic activation scheme for now, but got " - f"{activation_scheme} activation scheme.") + raise ValueError( + "The block-wise quantization only supports " + "dynamic activation scheme for now, but got " + f"{activation_scheme} activation scheme." 
+ ) self.weight_block_size = weight_block_size @classmethod @@ -180,41 +221,48 @@ def get_config_filenames(cls) -> list[str]: def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.ignored_layers is not None: - self.ignored_layers = hf_to_vllm_mapper.apply_list( - self.ignored_layers) + self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers) @classmethod def from_config(cls, config: dict[str, Any]) -> "Fp8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) - is_checkpoint_fp8_serialized = ("fp8" in quant_method) + is_checkpoint_fp8_serialized = "fp8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) - weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], - None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) if not ignored_layers: - ignored_layers = cls.get_from_keys_or(config, - ["modules_to_not_convert"], - None) - return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, - activation_scheme=activation_scheme, - ignored_layers=ignored_layers, - weight_block_size=weight_block_size) - - def get_xpu_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + ignored_layers = cls.get_from_keys_or( + config, ["modules_to_not_convert"], None + ) + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=weight_block_size, + ) + + def get_xpu_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention from vllm.model_executor.layers.quantization.ipex_quant import ( - XPUFp8LinearMethod, XPUFp8MoEMethod) + XPUFp8LinearMethod, + XPUFp8MoEMethod, + ) + fp8_config = Fp8Config( is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized, activation_scheme=self.activation_scheme, ignored_layers=self.ignored_layers, - weight_block_size=self.weight_block_size) + weight_block_size=self.weight_block_size, + ) if isinstance(layer, LinearBase): - if is_layer_skipped(prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): return UnquantizedLinearMethod() return XPUFp8LinearMethod(fp8_config) elif isinstance(layer, FusedMoE): @@ -223,22 +271,27 @@ def get_xpu_quant_method(self, layer: torch.nn.Module, return Fp8KVCacheMethod(self) return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if current_platform.is_xpu(): return self.get_xpu_quant_method(layer, prefix) if isinstance(layer, LinearBase): - if is_layer_skipped(prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): - if is_layer_skipped(prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping): + if is_layer_skipped( 
+ prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): return UnquantizedFusedMoEMethod(layer.moe_config) return Fp8MoEMethod(self, layer) elif isinstance(layer, Attention): @@ -291,8 +344,10 @@ def __init__(self, quant_config: Fp8Config): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - self.use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN) + self.use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + ) # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False @@ -323,7 +378,8 @@ def __init__(self, quant_config: Fp8Config): else: self.fp8_linear = Fp8LinearOp( act_quant_static=self.act_q_static, - act_quant_group_shape=self.act_q_group_shape) + act_quant_group_shape=self.act_q_group_shape, + ) def create_weights( self, @@ -348,25 +404,32 @@ def create_weights( if self.block_quant: assert self.weight_block_size is not None layer.weight_block_size = self.weight_block_size - validate_fp8_block_shape(layer, input_size, output_size, - input_size_per_partition, - output_partition_sizes, - self.weight_block_size) + validate_fp8_block_shape( + layer, + input_size, + output_size, + input_size_per_partition, + output_partition_sizes, + self.weight_block_size, + ) # WEIGHT if self.quant_config.is_checkpoint_fp8_serialized: - weight = create_fp8_weight_parameter(output_size_per_partition, - input_size_per_partition, - weight_loader) + weight = create_fp8_weight_parameter( + output_size_per_partition, input_size_per_partition, weight_loader + ) else: # For non-serialized checkpoints, use original dtype - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=params_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # If checkpoint is serialized fp8, load them. 
@@ -374,28 +437,32 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE if not self.block_quant: - scale = create_fp8_scale_parameter(PerTensorScaleParameter, - output_partition_sizes, - input_size_per_partition, - None, weight_loader) + scale = create_fp8_scale_parameter( + PerTensorScaleParameter, + output_partition_sizes, + input_size_per_partition, + None, + weight_loader, + ) set_weight_attrs(scale, {"scale_type": "weight_scale"}) layer.register_parameter("weight_scale", scale) else: assert not self.act_q_static assert self.weight_block_size is not None - scale = create_fp8_scale_parameter(BlockQuantScaleParameter, - output_partition_sizes, - input_size_per_partition, - self.weight_block_size, - weight_loader) + scale = create_fp8_scale_parameter( + BlockQuantScaleParameter, + output_partition_sizes, + input_size_per_partition, + self.weight_block_size, + weight_loader, + ) set_weight_attrs(scale, {"scale_type": "weight_scale"}) # The weight_scale_inv name is intentional for deepseekv3 layer.register_parameter("weight_scale_inv", scale) # INPUT ACTIVATION SCALE if self.act_q_static: - scale = create_fp8_input_scale(output_partition_sizes, - weight_loader) + scale = create_fp8_input_scale(output_partition_sizes, weight_loader) set_weight_attrs(scale, {"scale_type": "input_scale"}) layer.register_parameter("input_scale", scale) else: @@ -410,15 +477,15 @@ def process_weights_after_loading(self, layer: Module) -> None: size_k_first = False weight, weight_scale = process_fp8_weight_block_strategy( - layer.weight, layer.weight_scale_inv) + layer.weight, layer.weight_scale_inv + ) # Delete the weight_scale_inv parameter to avoid confusion # with the weight_scale parameter del layer.weight_scale_inv # If checkpoint not serialized fp8, quantize the weights. elif not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, - scale=None) + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) weight = qweight.t() # If checkpoint is fp8 per-tensor, handle that there are N scales for N @@ -430,10 +497,12 @@ def process_weights_after_loading(self, layer: Module) -> None: # If using w8a8, torch._scaled_mm needs per tensor, so # requantize the logical shards as a single weight. if not self.use_marlin: - weight, weight_scale, input_scale = ( - process_fp8_weight_tensor_strategy( - weight, weight_scale, layer.logical_widths, - getattr(layer, 'input_scale', None))) + weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( + weight, + weight_scale, + layer.logical_widths, + getattr(layer, "input_scale", None), + ) if self.act_q_static: assert input_scale is not None input_scale = input_scale.max() @@ -442,9 +511,11 @@ def process_weights_after_loading(self, layer: Module) -> None: # Update layer with new values. 
layer.weight = Parameter(weight.data, requires_grad=False) layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) - layer.input_scale = Parameter( - input_scale, - requires_grad=False) if input_scale is not None else None + layer.input_scale = ( + Parameter(input_scale, requires_grad=False) + if input_scale is not None + else None + ) if self.use_marlin: prepare_fp8_layer_for_marlin(layer, size_k_first) @@ -453,14 +524,14 @@ def process_weights_after_loading(self, layer: Module) -> None: return if self.block_quant: - maybe_post_process_fp8_weight_block( - layer, self.cutlass_block_fp8_supported) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported) + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.use_marlin: return apply_fp8_marlin_linear( input=x, @@ -469,7 +540,8 @@ def apply(self, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - bias=bias) + bias=bias, + ) if self.block_quant: assert self.weight_block_size is not None @@ -482,12 +554,14 @@ def apply(self, bias=bias, ) - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias) + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias, + ) class Fp8MoEMethod(FusedMoEMethodBase): @@ -508,29 +582,33 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): self.layer = layer self.quant_config = quant_config self.weight_block_size = self.quant_config.weight_block_size - self.block_quant = self.weight_block_size is not None + self.block_quant: bool = self.weight_block_size is not None - self.fused_experts: Optional[ - mk.FusedMoEModularKernel] = None # type: ignore + self.fused_experts: Optional[mk.FusedMoEModularKernel] = None # type: ignore self.fp8_backend = get_fp8_moe_backend(self.block_quant) - self.use_marlin = (self.fp8_backend == Fp8MoeBackend.MARLIN) + self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS - self.allow_deep_gemm = (self.fp8_backend == Fp8MoeBackend.DEEPGEMM) + self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM self.allow_cutlass_block_scaled_grouped_gemm = ( self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM ) - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size layer.num_experts = num_experts @@ -555,31 +633,38 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, raise ValueError( f"The output_size of gate's and up's 
weight = " f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_n = {block_n}.") - if (tp_size > 1 - and intermediate_size_per_partition % block_k != 0): + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1 and intermediate_size_per_partition % block_k != 0: # Required by row parallel raise ValueError( f"The input_size of down's weight = " f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_k = {block_k}.") + f"weight quantization block_k = {block_k}." + ) # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -587,20 +672,19 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, if not self.block_quant: # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, 2, dtype=torch.float32), - requires_grad=False) - w2_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) else: w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, - 2 * ((intermediate_size_per_partition + block_n - 1) // - block_n), + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), (hidden_size + block_k - 1) // block_k, dtype=torch.float32, ), @@ -622,9 +706,10 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK. - value} if self.block_quant else - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + if self.block_quant + else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() @@ -637,17 +722,18 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, if not self.quant_config.is_checkpoint_fp8_serialized: raise ValueError( "Found static activation scheme for checkpoint that " - "was not serialized fp8.") + "was not serialized fp8." 
+ ) - w13_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) @@ -658,7 +744,9 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def process_weights_after_loading(self, layer: Module) -> None: # Lazy import to avoid importing triton too early. from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, shuffle_weights) + is_rocm_aiter_moe_enabled, + shuffle_weights, + ) self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() @@ -666,20 +754,23 @@ def process_weights_after_loading(self, layer: Module) -> None: if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" if current_platform.is_fp8_fnuz(): - w13_weight, w13_weight_scale_inv, w13_input_scale = \ + w13_weight, w13_weight_scale_inv, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale_inv, - layer.w13_input_scale) - w2_weight, w2_weight_scale_inv, w2_input_scale = \ + layer.w13_weight, + layer.w13_weight_scale_inv, + layer.w13_input_scale, + ) + ) + w2_weight, w2_weight_scale_inv, w2_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale_inv, - layer.w2_input_scale) + layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale + ) + ) elif self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is # applied on different half for flashinfer vs vllm w13_weight = swap_w13_to_w31(layer.w13_weight.data) - w13_weight_scale_inv = swap_w13_to_w31( - layer.w13_weight_scale_inv.data) + w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data) w2_weight = layer.w2_weight.data w2_weight_scale_inv = layer.w2_weight_scale_inv.data else: @@ -690,65 +781,67 @@ def process_weights_after_loading(self, layer: Module) -> None: # torch.compile() cannot use Parameter subclasses. layer.w13_weight = Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale_inv = Parameter(w13_weight_scale_inv, - requires_grad=False) + layer.w13_weight_scale_inv = Parameter( + w13_weight_scale_inv, requires_grad=False + ) layer.w2_weight = Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv, - requires_grad=False) + layer.w2_weight_scale_inv = Parameter( + w2_weight_scale_inv, requires_grad=False + ) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + layer.w13_weight.data, layer.w2_weight.data + ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) # DeepGemm scales need to be transposed and aligned. 
We try to do # it ahead of time for performance reasons. if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): if expert_weight_is_col_major(layer.w13_weight_scale_inv): - layer.w13_weight_scale_inv = \ - get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv) + layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w13_weight_scale_inv + ) if expert_weight_is_col_major(layer.w2_weight_scale_inv): - layer.w2_weight_scale_inv = \ - get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv) + layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w2_weight_scale_inv + ) # If checkpoint is fp16, quantize in place. elif not self.quant_config.is_checkpoint_fp8_serialized: fp8_dtype = current_platform.fp8_dtype() - w13_weight = torch.empty_like(layer.w13_weight.data, - dtype=fp8_dtype) + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. - layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - layer.local_num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.local_num_experts, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) for expert in range(layer.local_num_experts): - w13_weight[expert, :, :], layer.w13_weight_scale[ - expert] = ops.scaled_fp8_quant( - layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_weight_scale[ - expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :]) - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight, layer.w2_weight) + layer.w13_weight, layer.w2_weight + ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) # If checkpoint is fp8, we need to handle that the # MoE kernels require single activation scale and single weight # scale for w13 per expert. @@ -756,46 +849,54 @@ def process_weights_after_loading(self, layer: Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.quant_config.activation_scheme == "static": - if (layer.w13_input_scale is None - or layer.w2_input_scale is None): + if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.w13_input_scale) - or not all_close_1d(layer.w2_input_scale)): + "activation scales are None." 
+ ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): logger.warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer.") + "for each layer." + ) layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False) + layer.w13_input_scale.max(), requires_grad=False + ) layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False) + layer.w2_input_scale.max(), requires_grad=False + ) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = \ + w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, - layer.w13_input_scale) - w2_weight, w2_weight_scale, w2_input_scale = \ + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, - layer.w2_input_scale) + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + ) # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False) + w13_weight_scale, requires_grad=False + ) if w13_input_scale is not None: layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, - requires_grad=False) + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) if w2_input_scale is not None: layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False) + w2_input_scale, requires_grad=False + ) # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. 
@@ -806,25 +907,25 @@ def process_weights_after_loading(self, layer: Module) -> None: start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_weight_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) start += shard_size if self.rocm_aiter_moe_enabled: shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight, layer.w2_weight) + layer.w13_weight, layer.w2_weight + ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) if self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is @@ -832,8 +933,7 @@ def process_weights_after_loading(self, layer: Module) -> None: assert not self.block_quant register_moe_scaling_factors(layer) w13_weight = swap_w13_to_w31(layer.w13_weight.data) - if self.flashinfer_moe_backend == \ - FlashinferMoeBackend.TENSORRT_LLM: + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) layer.w13_weight.data = w13_weight.data @@ -861,20 +961,24 @@ def process_weights_after_loading(self, layer: Module) -> None: # Ensure column-major TMA alignment expected by DeepGEMM. 
if expert_weight_is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale_inv) + layer.w13_weight_scale_inv + ) if expert_weight_is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale_inv) + layer.w2_weight_scale_inv + ) - def maybe_make_prepare_finalize( - self) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if (self.rocm_aiter_moe_enabled or self.use_marlin - or self.flashinfer_moe_backend - == FlashinferMoeBackend.TENSORRT_LLM): + def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if ( + self.rocm_aiter_moe_enabled + or self.use_marlin + or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): return None elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - prepare_finalize = ( - build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe)) + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + self.moe + ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize else: @@ -886,23 +990,30 @@ def select_gemm_impl( layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: from vllm.model_executor.layers.fused_moe import ( - BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts) + BatchedTritonOrDeepGemmExperts, + TritonOrDeepGemmExperts, + ) assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( - "Marlin and ROCm AITER are not supported with all2all yet.") + "Marlin and ROCm AITER are not supported with all2all yet." + ) assert self.moe_quant_config is not None - if (prepare_finalize.activation_format == - FusedMoEActivationFormat.BatchedExperts): - max_num_tokens_per_rank = ( - prepare_finalize.max_num_tokens_per_rank()) + if ( + prepare_finalize.activation_format + == FusedMoEActivationFormat.BatchedExperts + ): + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None logger.debug( "BatchedTritonOrDeepGemmExperts(%s): " "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", - self.__class__.__name__, max_num_tokens_per_rank, - self.weight_block_size, False) + self.__class__.__name__, + max_num_tokens_per_rank, + self.weight_block_size, + False, + ) return BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, num_dispatchers=prepare_finalize.num_dispatchers(), @@ -919,22 +1030,30 @@ def select_gemm_impl( else: logger.debug( "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", - self.__class__.__name__, self.weight_block_size, False) + self.__class__.__name__, + self.weight_block_size, + False, + ) return TritonOrDeepGemmExperts( quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, ) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: if self.use_marlin: return None return fp8_w8a8_moe_quant_config( - w1_scale=(layer.w13_weight_scale_inv - if self.block_quant else layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale_inv - if self.block_quant else layer.w2_weight_scale), + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale + ), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.weight_block_size, @@ -963,25 +1082,33 @@ def apply( 
logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - if enable_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - and self.fused_experts is None): - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") - assert scoring_func == 'sigmoid', ( - f"Expected 'sigmoid' scoring func but got {scoring_func}") + if ( + self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + and self.fused_experts is None + ): + assert activation == "silu", ( + f"Expected 'silu' activation but got {activation}" + ) + assert scoring_func == "sigmoid", ( + f"Expected 'sigmoid' scoring func but got {scoring_func}" + ) if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - assert (renormalize and use_grouped_topk - and custom_routing_function is None) - e_score_correction_bias = (e_score_correction_bias.to( - x.dtype) if e_score_correction_bias is not None else None) + + assert ( + renormalize and use_grouped_topk and custom_routing_function is None + ) + e_score_correction_bias = ( + e_score_correction_bias.to(x.dtype) + if e_score_correction_bias is not None + else None + ) return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( routing_logits=router_logits.to(torch.float32), routing_bias=e_score_correction_bias, @@ -1001,8 +1128,7 @@ def apply( routed_scaling=routed_scaling_factor, ) else: - assert (not renormalize - and custom_routing_function is not None) + assert not renormalize and custom_routing_function is not None result = apply_flashinfer_per_tensor_scale_fp8( layer=layer, hidden_states=x, @@ -1012,10 +1138,11 @@ def apply( top_k=top_k, num_expert_group=num_expert_group, topk_group=topk_group, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + ) - zero_expert_num = getattr(layer, 'zero_expert_num', 0) - zero_expert_type = getattr(layer, 'zero_expert_type', None) + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) select_result = FusedMoE.select_experts( hidden_states=x, @@ -1048,7 +1175,9 @@ def apply( if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_fused_experts) + rocm_aiter_fused_experts, + ) + assert self.fused_experts is None result = rocm_aiter_fused_experts( x, @@ -1059,10 +1188,10 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, - quant_config=self.moe_quant_config) + quant_config=self.moe_quant_config, + ) elif self.use_marlin: - assert activation == "silu", ( - f"{activation} not supported for Marlin MoE.") + assert activation == "silu", f"{activation} not supported for Marlin MoE." 
assert self.fused_experts is None result = torch.ops.vllm.fused_marlin_moe( x, @@ -1079,7 +1208,8 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, - workspace=layer.workspace) + workspace=layer.workspace, + ) elif self.fused_experts: result = self.fused_experts( hidden_states=x, @@ -1094,12 +1224,14 @@ def apply( expert_map=expert_map, ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - assert self.block_quant is None - assert (not renormalize and custom_routing_function is not None) - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") - assert scoring_func == 'sigmoid', ( - f"Expected 'sigmoid' scoring func but got {scoring_func}") + assert not self.block_quant + assert not renormalize and custom_routing_function is not None + assert activation == "silu", ( + f"Expected 'silu' activation but got {activation}" + ) + assert scoring_func == "sigmoid", ( + f"Expected 'sigmoid' scoring func but got {scoring_func}" + ) result = flashinfer_cutlass_moe_fp8( x, @@ -1114,6 +1246,7 @@ def apply( ) else: from vllm.model_executor.layers.fused_moe import fused_experts + result = fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -1128,10 +1261,13 @@ def apply( quant_config=self.moe_quant_config, allow_deep_gemm=self.allow_deep_gemm, allow_cutlass_block_scaled_grouped_gemm=( - self.allow_cutlass_block_scaled_grouped_gemm)) + self.allow_cutlass_block_scaled_grouped_gemm + ), + ) if zero_expert_num != 0 and zero_expert_type is not None: - assert not isinstance(result, tuple), \ + assert not isinstance(result, tuple), ( "Shared + zero experts are mutually exclusive not yet supported" + ) return result, zero_expert_result else: return result diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index da1688808bb5..8296bc2ea3b4 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -10,17 +10,22 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEQuantConfig) -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.utils import set_weight_attrs from vllm.utils import direct_register_custom_op @@ -30,13 +35,12 @@ class GGUFConfig(QuantizationConfig): """Config class for GGUF.""" - def __init__(self, - unquantized_modules: Optional[list[str]] = None) -> None: + def __init__(self, unquantized_modules: Optional[list[str]] = None) -> None: super().__init__() self.unquantized_modules = 
unquantized_modules or [] def __repr__(self) -> str: - return ("GGUFConfig()") + return "GGUFConfig()" def get_name(self) -> QuantizationMethods: return "gguf" @@ -56,8 +60,9 @@ def get_config_filenames(cls) -> list[str]: def from_config(cls, config: dict[str, Any]) -> "GGUFConfig": return cls() - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): if is_layer_skipped_gguf(prefix, self.unquantized_modules): return UnquantizedLinearMethod() @@ -108,8 +113,9 @@ def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]): MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES -def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, - qweight_type: int) -> torch.Tensor: +def _fused_mul_mat_gguf( + x: torch.Tensor, qweight: torch.Tensor, qweight_type: int +) -> torch.Tensor: if qweight_type in IMATRIX_QUANT_TYPES: mmvq_safe = 8 if qweight.shape[0] > 5120 else 16 else: @@ -117,10 +123,7 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, # HACK: when doing chunked prefill we don't generate output tokens # so input to logits generator is empty which causes invalid parameter if x.shape[0] == 0: - return torch.empty(x.shape[0], - qweight.shape[0], - dtype=x.dtype, - device=x.device) + return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device) # there is no need to call any kernel for fp16/bf16 if qweight_type in UNQUANTIZED_TYPES: return x @ qweight.T @@ -141,8 +144,7 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, # Might be useful if llama.cpp adds a new quantization type. # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. 
qweight_type = WeightType(qweight_type) - raise NotImplementedError( - f"Unsupported GGUF quantization type: {qweight_type}") + raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}") return y @@ -151,10 +153,7 @@ def _fused_mul_mat_gguf_fake( qweight: torch.Tensor, qweight_type: int, ) -> torch.Tensor: - return torch.empty(x.shape[0], - qweight.shape[0], - dtype=x.dtype, - device=x.device) + return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device) try: @@ -179,10 +178,9 @@ def _fused_moe_gguf( qweight_type2: int, activation: str, ) -> torch.Tensor: - def act(x: torch.Tensor): d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) + output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) if activation == "silu": torch.ops._C.silu_and_mul(out, x) @@ -193,50 +191,73 @@ def act(x: torch.Tensor): return out # lazy import to avoid triggering triton import in CPU backend - from vllm.model_executor.layers.fused_moe.fused_moe import ( - moe_align_block_size) + from vllm.model_executor.layers.fused_moe.fused_moe import moe_align_block_size out_hidden_states = torch.empty_like(x) # unless we decent expert reuse we are better off running moe_vec kernel - if (qweight_type2 in MMQ_QUANT_TYPES and qweight_type in MMQ_QUANT_TYPES - and x.shape[0] > 64): + if ( + qweight_type2 in MMQ_QUANT_TYPES + and qweight_type in MMQ_QUANT_TYPES + and x.shape[0] > 64 + ): num_tokens, _ = x.shape E, N, _ = w1.shape top_k = topk_ids.shape[1] BLOCK_SIZE = ops.ggml_moe_get_block_size(qweight_type) - sorted_token_ids, expert_ids, num_tokens_post_padded = \ - moe_align_block_size(topk_ids, BLOCK_SIZE, E) - out = ops.ggml_moe_a8(x, w1, sorted_token_ids, expert_ids, - num_tokens_post_padded, qweight_type, N, top_k, - num_tokens) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, BLOCK_SIZE, E + ) + out = ops.ggml_moe_a8( + x, + w1, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + qweight_type, + N, + top_k, + num_tokens, + ) out = act(out) - out = ops.ggml_moe_a8(out, w2, sorted_token_ids, expert_ids, - num_tokens_post_padded, qweight_type2, - w2.shape[1], 1, num_tokens * top_k) + out = ops.ggml_moe_a8( + out, + w2, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + qweight_type2, + w2.shape[1], + 1, + num_tokens * top_k, + ) out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_( - topk_weights.view(num_tokens, top_k, 1)) + topk_weights.view(num_tokens, top_k, 1) + ) ops.moe_sum(out, out_hidden_states) elif qweight_type2 in MMVQ_QUANT_TYPES and qweight_type in MMVQ_QUANT_TYPES: num_tokens, _ = x.shape E, N, _ = w1.shape top_k = topk_ids.shape[1] - out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, - num_tokens) + out = ops.ggml_moe_a8_vec(x, w1, topk_ids, top_k, qweight_type, N, num_tokens) out = act(out) - out = ops.ggml_moe_a8_vec(out, w2, topk_ids, 1, qweight_type2, - w2.shape[1], num_tokens * top_k) + out = ops.ggml_moe_a8_vec( + out, w2, topk_ids, 1, qweight_type2, w2.shape[1], num_tokens * top_k + ) out = out.reshape(num_tokens, top_k, w2.shape[1]).mul_( - topk_weights.view(num_tokens, top_k, 1)) + topk_weights.view(num_tokens, top_k, 1) + ) ops.moe_sum(out, out_hidden_states) else: - logger.warning_once("There is no support for fast MoE kernel " - "for current quantization method. " - "Falling back to slow implementation. ") + logger.warning_once( + "There is no support for fast MoE kernel " + "for current quantization method. 
" + "Falling back to slow implementation. " + ) for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)): - inp = x[tok].reshape((1, ) + x.shape[1:]) + inp = x[tok].reshape((1,) + x.shape[1:]) current_hidden_state = None for ww, ii in zip(w, idx): expert_up = w1[ii] @@ -245,8 +266,9 @@ def act(x: torch.Tensor): out = act(out) expert_down = w2[ii] - current_state = fused_mul_mat_gguf(out, expert_down, - qweight_type2).mul_(ww) + current_state = fused_mul_mat_gguf( + out, expert_down, qweight_type2 + ).mul_(ww) if current_hidden_state is None: current_hidden_state = current_state else: @@ -292,15 +314,15 @@ def _apply_gguf_embedding( elif qweight_type in DEQUANT_TYPES: block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] x_flat = x.flatten() - assert (hidden_size == qweight.shape[1] // type_size * block_size) + assert hidden_size == qweight.shape[1] // type_size * block_size quant = torch.index_select(qweight, dim=0, index=x_flat) - dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, - x_flat.shape[0], dtype) + dequant = ops.ggml_dequantize( + quant, qweight_type, hidden_size, x_flat.shape[0], dtype + ) return dequant.view(*x.shape, hidden_size) else: qweight_type = WeightType(qweight_type) - raise NotImplementedError( - f"Unsupported GGUF quantization type: {qweight_type}") + raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}") def _apply_gguf_embedding_fake( @@ -335,18 +357,24 @@ class GGUFLinearMethod(LinearMethodBase): def __init__(self, quant_config: GGUFConfig): self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): self.params_dtype = params_dtype output_size_per_partition = sum(output_partition_sizes) tensor_shape = (output_size_per_partition, input_size_per_partition) qweight = GGUFUninitializedParameter(requires_grad=False) set_weight_attrs( - qweight, { + qweight, + { "input_dim": 1, "output_dim": 0, "tensor_shape": tensor_shape, @@ -354,31 +382,34 @@ def create_weights(self, layer: torch.nn.Module, "data_container": [], "shard_id": [], "shard_id_map": {}, - }) + }, + ) set_weight_attrs(qweight, extra_weight_attrs) layer.register_parameter("qweight", qweight) - qweight_type = Parameter(torch.empty(len(output_partition_sizes), - dtype=torch.uint8), - requires_grad=False) + qweight_type = Parameter( + torch.empty(len(output_partition_sizes), dtype=torch.uint8), + requires_grad=False, + ) set_weight_attrs( - qweight_type, { + qweight_type, + { "is_gguf_weight_type": True, "weight_type": 0, "shard_weight_type": {}, - "ignore_warning": True - }) + "ignore_warning": True, + }, + ) set_weight_attrs(qweight_type, extra_weight_attrs) layer.register_parameter("qweight_type", qweight_type) def process_weights_after_loading(self, layer: torch.nn.Module): qweight_type = layer.qweight_type.weight_type - if not (qweight_type in UNQUANTIZED_TYPES - or qweight_type in DEQUANT_TYPES): + if not (qweight_type in UNQUANTIZED_TYPES or qweight_type in DEQUANT_TYPES): qweight_type = WeightType(qweight_type) raise ValueError( - f"Unsupported GGUF quantization type {qweight_type} in " - f"layer {layer}.") + f"Unsupported GGUF quantization type 
{qweight_type} in layer {layer}." + ) # For MergedColumnParallelLinear and QKVParallelLinear, we need to # materialize the padded weight parameter for CUDA Graph compatibility. self._create_padded_weight_param(layer) @@ -391,22 +422,22 @@ def _create_padded_weight_param(self, layer: torch.nn.Module): if len(data_container := qweight.data_container) > 1: dtype = {data.dtype for data in data_container} assert len(dtype) == 1, ValueError( - f"Data container has mixed dtypes: {dtype}") + f"Data container has mixed dtypes: {dtype}" + ) dtype = next(iter(dtype)) # concat dim0 and pad dim1 padded_side = max(x.size(1) for x in data_container) concat_side = sum(x.size(0) for x in data_container) # Pad the quantized weights to dense tensor, and create a map # with the location of each shard in the padded tensor. - padded_data = torch.zeros((concat_side, padded_side), - dtype=dtype, - device=qweight.device) + padded_data = torch.zeros( + (concat_side, padded_side), dtype=dtype, device=qweight.device + ) # (dim0_start, dim0_end, dim1_size) shard_offset_map = dict[str, tuple[int, int, int]]() for idx in shard_id: id_in_container = shard_id_map[idx] - start = sum( - x.size(0) for x in data_container[:id_in_container]) + start = sum(x.size(0) for x in data_container[:id_in_container]) end = start + data_container[id_in_container].size(0) size = data_container[id_in_container].size(1) padded_data[start:end, :size] = data_container[id_in_container] @@ -414,14 +445,15 @@ def _create_padded_weight_param(self, layer: torch.nn.Module): qweight.data_container.clear() padded_param = Parameter(padded_data, requires_grad=False) set_weight_attrs(padded_param, vars(qweight)) - set_weight_attrs(padded_param, - {"shard_offset_map": shard_offset_map}) + set_weight_attrs(padded_param, {"shard_offset_map": shard_offset_map}) layer.register_parameter("qweight", padded_param) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: shard_id = layer.qweight.shard_id if shard_id: @@ -434,8 +466,9 @@ def apply(self, qweight_type = layer.qweight_type.shard_weight_type[idx] result.append( fused_mul_mat_gguf( - x, qweight[start:end, :offset].contiguous(), - qweight_type)) + x, qweight[start:end, :offset].contiguous(), qweight_type + ) + ) out = torch.cat(result, axis=1) else: qweight = layer.qweight @@ -461,63 +494,71 @@ def __init__( super().__init__(moe) self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - - tensor_shape = (num_experts, 2 * intermediate_size_per_partition, - hidden_size) - #gate up proj + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + tensor_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size) + # gate up proj w13_qweight = GGUFUninitializedParameter(requires_grad=False) set_weight_attrs( - w13_qweight, { + w13_qweight, + { "input_dim": 1, "output_dim": 0, "tensor_shape": tensor_shape, "is_gguf_weight": True, "data_container": [], - }) + }, + ) set_weight_attrs(w13_qweight, extra_weight_attrs) layer.register_parameter("w13_qweight", w13_qweight) - w13_qweight_type = Parameter(torch.empty(1, 
dtype=torch.uint8), - requires_grad=False) - set_weight_attrs(w13_qweight_type, { - "is_gguf_weight_type": True, - "weight_type": 0, - "ignore_warning": True - }) + w13_qweight_type = Parameter( + torch.empty(1, dtype=torch.uint8), requires_grad=False + ) + set_weight_attrs( + w13_qweight_type, + {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True}, + ) set_weight_attrs(w13_qweight_type, extra_weight_attrs) layer.register_parameter("w13_qweight_type", w13_qweight_type) - tensor_shape = (num_experts, intermediate_size_per_partition, - hidden_size) - #gate down proj + tensor_shape = (num_experts, intermediate_size_per_partition, hidden_size) + # gate down proj w2_qweight = GGUFUninitializedParameter(requires_grad=False) set_weight_attrs( - w2_qweight, { + w2_qweight, + { "input_dim": 1, "output_dim": 0, "tensor_shape": tensor_shape, "is_gguf_weight": True, "data_container": [], - }) + }, + ) set_weight_attrs(w2_qweight, extra_weight_attrs) layer.register_parameter("w2_qweight", w2_qweight) - w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8), - requires_grad=False) - set_weight_attrs(w2_qweight_type, { - "is_gguf_weight_type": True, - "weight_type": 0, - "ignore_warning": True - }) + w2_qweight_type = Parameter( + torch.empty(1, dtype=torch.uint8), requires_grad=False + ) + set_weight_attrs( + w2_qweight_type, + {"is_gguf_weight_type": True, "weight_type": 0, "ignore_warning": True}, + ) set_weight_attrs(w2_qweight_type, extra_weight_attrs) layer.register_parameter("w2_qweight_type", w2_qweight_type) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return None def apply( @@ -546,14 +587,14 @@ def apply( assert self.fused_experts is None if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `GGUFMoEMethod` yet.") + raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.") assert activation == "silu", "Only SiLU activation is supported." if apply_router_weight_on_input: raise NotImplementedError( "Apply router weight on input is not supported for" - "fused GGUF MoE method.") + "fused GGUF MoE method." + ) topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, @@ -567,11 +608,18 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) - return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, - topk_weights, topk_ids, - layer.w13_qweight_type.weight_type, - layer.w2_qweight_type.weight_type, activation) + indices_type=self.topk_indices_dtype, + ) + return fused_moe_gguf( + x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights, + topk_ids, + layer.w13_qweight_type.weight_type, + layer.w2_qweight_type.weight_type, + activation, + ) class GGUFEmbeddingMethod(GGUFLinearMethod): @@ -581,17 +629,14 @@ class GGUFEmbeddingMethod(GGUFLinearMethod): quant_config: The GGUF quantization config. 
""" - def embedding(self, layer: torch.nn.Module, - x: torch.Tensor) -> torch.Tensor: + def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor: qweight = layer.qweight qweight_type = layer.qweight_type.weight_type hidden_size = qweight.tensor_shape[1] - return apply_gguf_embedding(x, - qweight, - qweight_type, - hidden_size, - dtype=self.params_dtype) + return apply_gguf_embedding( + x, qweight, qweight_type, hidden_size, dtype=self.params_dtype + ) class GGUFUninitializedParameter(UninitializedParameter): diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 842ce92333c9..8f36fc70c444 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -14,14 +14,19 @@ from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_linear_quant_method) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) + get_linear_quant_method, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) from vllm.transformers_utils.config import get_safetensors_params_metadata from vllm.utils import is_list_of @@ -81,7 +86,8 @@ def __init__( if self.weight_bits not in [2, 3, 4, 8]: raise ValueError( "Currently, only 2/3/4/8-bit weight quantization is " - f"supported for GPTQ, but got {self.weight_bits} bits.") + f"supported for GPTQ, but got {self.weight_bits} bits." 
+ ) self.modules_in_block_to_quantize = modules_in_block_to_quantize or [] @@ -123,14 +129,22 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQConfig": weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - autoround_version = cls.get_from_keys_or(config, ["autoround_version"], - default="") + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + autoround_version = cls.get_from_keys_or( + config, ["autoround_version"], default="" + ) modules_in_block_to_quantize = cls.get_from_keys_or( - config, ["modules_in_block_to_quantize"], default=None) - return cls(weight_bits, group_size, desc_act, lm_head_quantized, - dynamic, autoround_version, modules_in_block_to_quantize) + config, ["modules_in_block_to_quantize"], default=None + ) + return cls( + weight_bits, + group_size, + desc_act, + lm_head_quantized, + dynamic, + autoround_version, + modules_in_block_to_quantize, + ) def get_quant_method( self, layer: torch.nn.Module, prefix: str @@ -146,43 +160,40 @@ def get_quant_method( "sym": True, # GPTQ typically uses symmetric quantization "lm_head": False, } - return MoeWNA16Config.from_config(config).get_quant_method( - layer, prefix) + return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix) return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) def apply_vllm_mapper(self, hf_to_vllm_mapper): if self.modules_in_block_to_quantize is not None: self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list( - self.modules_in_block_to_quantize) + self.modules_in_block_to_quantize + ) - def maybe_update_config(self, - model_name: str, - revision: Optional[str] = None): + def maybe_update_config(self, model_name: str, revision: Optional[str] = None): if self.modules_in_block_to_quantize: if is_list_of(self.modules_in_block_to_quantize, list): # original modules_in_block_to_quantize: list[list[str]] # flatten original modules_in_block_to_quantize self.modules_in_block_to_quantize = [ - item for sublist in self.modules_in_block_to_quantize + item + for sublist in self.modules_in_block_to_quantize for item in sublist ] return unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32] - metadata = get_safetensors_params_metadata(model_name, - revision=revision) + metadata = get_safetensors_params_metadata(model_name, revision=revision) quant_layers: set[str] = { param_name.rsplit(".", 1)[0] for param_name, info in metadata.items() - if (dtype := info.get('dtype', None)) + if (dtype := info.get("dtype", None)) and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes } self.modules_in_block_to_quantize = list(quant_layers) class ExllamaState(Enum): - UNUSED = enum.auto() UNINITIALIZED = enum.auto() READY = enum.auto() @@ -214,14 +225,15 @@ def create_weights( raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." + ) output_size_per_partition = sum(output_partition_sizes) - if (output_size_per_partition % self.quant_config.pack_factor.numerator - != 0): + if output_size_per_partition % self.quant_config.pack_factor.numerator != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." 
+ ) if self.quant_config.group_size != -1: group_size = self.quant_config.group_size @@ -230,8 +242,10 @@ def create_weights( exllama_state = ExllamaState.UNINITIALIZED scale_and_zero_size = input_size // group_size scale_and_zero_input_dim = None - if (input_size != input_size_per_partition - and self.quant_config.group_size != -1): + if ( + input_size != input_size_per_partition + and self.quant_config.group_size != -1 + ): # For act-order models, we cannot use Exllama for row parallel layer if self.quant_config.desc_act: exllama_state = ExllamaState.UNUSED @@ -250,56 +264,56 @@ def create_weights( output_dim=1, packed_dim=0, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) - - g_idx = RowvLLMParameter(data=torch.tensor( - [ - i // self.quant_config.group_size - for i in range(input_size_per_partition) - ], - dtype=torch.int32, - ), - input_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) + + g_idx = RowvLLMParameter( + data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) qzeros_args = { - "data": - torch.empty( + "data": torch.empty( scale_and_zero_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( scale_and_zero_size, output_size_per_partition, dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if scale_and_zero_input_dim is None: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) qzeros = PackedColumnParameter( output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) qzeros = PackedvLLMParameter( input_dim=0, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) @@ -321,24 +335,30 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.quant_config.desc_act: layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - layer.g_idx.data = torch.empty((0, ), - dtype=torch.int, - device=layer.g_idx.device) + layer.g_idx.data = torch.empty( + (0,), dtype=torch.int, device=layer.g_idx.device + ) layer.exllama_state = ExllamaState.READY - ops.gptq_shuffle(layer.qweight, layer.g_idx, - self.quant_config.weight_bits) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - out_shape = x.shape[:-1] + (layer.qweight.shape[-1], ) + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.qweight.shape[-1],) reshaped_x = x.reshape(-1, x.shape[-1]) - output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, - layer.scales, layer.g_idx, - layer.exllama_state == ExllamaState.READY, - self.quant_config.weight_bits) + output = ops.gptq_gemm( + reshaped_x, 
+ layer.qweight, + layer.qzeros, + layer.scales, + layer.g_idx, + layer.exllama_state == ExllamaState.READY, + self.quant_config.weight_bits, + ) if bias is not None: output.add_(bias) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py index c193dd85e32f..85cf4ed4ac58 100644 --- a/vllm/model_executor/layers/quantization/gptq_bitblas.py +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -7,25 +7,39 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - set_weight_attrs) -from vllm.model_executor.layers.quantization import (QuantizationConfig, - QuantizationMethods) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + set_weight_attrs, +) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( - BitBLASLinearKernel, MPLinearLayerConfig) + BitBLASLinearKernel, + MPLinearLayerConfig, +) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS) + BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS, +) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM) + BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM, +) from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks, - check_bitblas_supported, verify_bitblas_supported) + MINIMUM_BITBLAS_VERSION, + bitblas_repeat_scales_on_all_ranks, + check_bitblas_supported, + verify_bitblas_supported, +) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -60,14 +74,16 @@ def __init__( quant_method: Optional[str], lm_head_quantized: bool, ) -> None: - try: import bitblas + if version.parse(bitblas.__version__) < version.parse( - MINIMUM_BITBLAS_VERSION): + MINIMUM_BITBLAS_VERSION + ): raise ImportError( "bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError as e: bitblas_import_exception = e raise ValueError( @@ -95,17 +111,20 @@ def __init__( raise ValueError( f"BitBLAS does not support weight_bits = {self.weight_bits}. " f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} " - "are supported.") + "are supported." + ) if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM: raise ValueError( f"BitBLAS does not support is_sym = {self.is_sym}. " - f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.") + f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported." 
+ ) self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE - storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE - if c.isdigit())) + storage_nbit = int( + "".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE if c.isdigit()) + ) # 4 Bits packed into 32 bit datatype. self.pack_factor = storage_nbit // weight_bits @@ -115,17 +134,20 @@ def __init__( self.zeros_mode = self.ZEROS_MODE if (weight_bits, is_sym) not in self.TYPE_MAP: - raise ValueError("Unsupported quantization config: " - f"bits={weight_bits}, sym={is_sym}") + raise ValueError( + f"Unsupported quantization config: bits={weight_bits}, sym={is_sym}" + ) self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] def __repr__(self) -> str: - return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act})" - f"is_sym={self.is_sym}, " - f"quant_method={self.quant_method})") + return ( + f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})" + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -150,36 +172,46 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQBitBLASConfig": desc_act = cls.get_from_keys(config, ["desc_act"]) is_sym = cls.get_from_keys(config, ["sym"]) quant_method = cls.get_from_keys(config, ["quant_method"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(weight_bits, group_size, desc_act, is_sym, quant_method, - lm_head_quantized) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + return cls( + weight_bits, group_size, desc_act, is_sym, quant_method, lm_head_quantized + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg) - is_valid_user_quant = (user_quant is None or user_quant == "bitblas" - or user_quant == "gptq_bitblas") + is_valid_user_quant = ( + user_quant is None + or user_quant == "bitblas" + or user_quant == "gptq_bitblas" + ) if can_convert and is_valid_user_quant: - msg = ("The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name())) + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) logger.info(msg) return cls.get_name() if can_convert and user_quant == "gptq": - logger.info("Detected that the model can run with gptq_bitblas" - ", however you specified quantization=gptq explicitly," - " so forcing gptq. Use quantization=gptq_bitblas for" - " faster inference") + logger.info( + "Detected that the model can run with gptq_bitblas" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. 
Use quantization=gptq_bitblas for" + " faster inference" + ) return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQBitBLASLinearMethod"]: - if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) - and self.lm_head_quantized): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["GPTQBitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or ( + isinstance(layer, ParallelLMHead) and self.lm_head_quantized + ): return GPTQBitBLASLinearMethod(self) return None @@ -200,8 +232,7 @@ def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]): return False # If we cannot find the info needed in the config, cannot convert. - if (num_bits is None or group_size is None or sym is None - or desc_act is None): + if num_bits is None or group_size is None or sym is None or desc_act is None: return False if (num_bits, sym) not in cls.TYPE_MAP: @@ -214,9 +245,9 @@ def is_gptq_bitblas_compatible(cls, quant_config: dict[str, Any]): return False # Otherwise, can convert if model satisfies bitblas constraints. - return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits, - sym)], - group_size=group_size) + return check_bitblas_supported( + quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size + ) class GPTQBitBLASLinearMethod(LinearMethodBase): @@ -232,8 +263,10 @@ class GPTQBitBLASLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQBitBLASConfig) -> None: self.quant_config = quant_config # Verify supported on platform. - verify_bitblas_supported(quant_type=self.quant_config.quant_type, - group_size=self.quant_config.group_size) + verify_bitblas_supported( + quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size, + ) def create_weights( self, @@ -247,7 +280,7 @@ def create_weights( ) -> None: """Creates quantized weights for use in linear operations. - The function initializes and returns a dictionary containing + The function initializes and returns a dictionary containing quantized weights, scales, and zeros for performing quantized matrix multiplication operations. @@ -256,11 +289,11 @@ def create_weights( output_partition_sizes: The size of the output partition. input_size: The total size of the input (unused). output_size: The total size of the output (unused). - params_dtype: + params_dtype: The data type of the parameters (expected to be torch.float16). Returns: - A dictionary containing the quantized weights ('qweight'), + A dictionary containing the quantized weights ('qweight'), scales ('scales'), and zeros ('zeros'). Raises: @@ -269,8 +302,9 @@ def create_weights( in `quant_config`. 
""" if params_dtype != torch.float16: - raise ValueError("Parameter data type must be torch.float16, " - f"but got {params_dtype}") + raise ValueError( + f"Parameter data type must be torch.float16, but got {params_dtype}" + ) # Normalize group_size if self.quant_config.group_size != -1: @@ -293,18 +327,19 @@ def create_weights( mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), - partition_weight_shape=\ - (input_size_per_partition, output_size_per_partition), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), weight_type=self.quant_config.quant_type, act_type=params_dtype, group_size=self.quant_config.group_size, zero_points=False, - has_g_idx=self.quant_config.desc_act + has_g_idx=self.quant_config.desc_act, ) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for GPTQBitBLASLinearMethod", - kernel_type.__name__) + logger.info("Using %s for GPTQBitBLASLinearMethod", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # Normalize group_size @@ -314,9 +349,9 @@ def create_weights( group_size = input_size # Determine sharding - if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act, - self.quant_config.group_size, - is_row_parallel): + if bitblas_repeat_scales_on_all_ranks( + self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel + ): # By setting scale_dim == None, weight_loader will # repeat the scales on each GPU in TP>1 case. scales_and_zp_input_dim = None @@ -339,16 +374,19 @@ def create_weights( output_dim=1, packed_dim=0, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) # Activation order # Ignore warning from fused linear layers such as QKVParallelLinear. 
- g_idx = RowvLLMParameter(data=torch.empty( - input_size_per_partition, - dtype=torch.int32, - ), - input_dim=0, - weight_loader=weight_loader) + g_idx = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) # Scales scales = Parameter( @@ -370,45 +408,42 @@ def create_weights( # Quantized zero-points qzeros_args = { - "data": - torch.empty( + "data": torch.empty( scales_and_zp_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( scales_and_zp_size, output_size_per_partition, dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if scales_and_zp_input_dim is None: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) qzeros = PackedColumnParameter( output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) qzeros = PackedvLLMParameter( input_dim=0, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 253675e25f34..8fa70a240f9f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -10,30 +10,48 @@ import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEQuantConfig) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, - UnquantizedFusedMoEMethod) -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) + FusedMoE, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, + UnquantizedFusedMoEMethod, +) +from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( - MPLinearLayerConfig, choose_mp_linear_kernel) + MPLinearLayerConfig, + choose_mp_linear_kernel, +) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_dynamic_override, get_linear_quant_method, override_config) + get_dynamic_override, + get_linear_quant_method, + override_config, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supported, check_moe_marlin_supports_layer, - marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias, - 
marlin_repeat_scales_on_all_ranks, verify_marlin_supported) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) + check_marlin_supported, + check_moe_marlin_supports_layer, + marlin_make_workspace_new, + marlin_moe_permute_scales, + marlin_permute_bias, + marlin_repeat_scales_on_all_ranks, + verify_marlin_supported, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.transformers_utils.config import get_safetensors_params_metadata @@ -52,9 +70,13 @@ def get_moe_quant_method( if isinstance(layer, FusedMoE): # False = skip module, None = no override, else = Positive match - if get_dynamic_override( # noqa: E712 + if ( + get_dynamic_override( # noqa: E712 cloned_config, # noqa: E712 - layer_name=prefix) == False: # noqa: E712 + layer_name=prefix, + ) + == False + ): # noqa: E712 return UnquantizedFusedMoEMethod(layer.moe_config) if prefix: @@ -75,15 +97,16 @@ class GPTQMarlinConfig(QuantizationConfig): } def __init__( - self, - weight_bits: int, - group_size: int, - desc_act: bool, - is_sym: bool, - lm_head_quantized: bool, - dynamic: dict[str, dict[str, Union[int, bool]]], - full_config: dict[str, Any], - modules_in_block_to_quantize: Optional[list[str]] = None) -> None: + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + dynamic: dict[str, dict[str, Union[int, bool]]], + full_config: dict[str, Any], + modules_in_block_to_quantize: Optional[list[str]] = None, + ) -> None: super().__init__() if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False @@ -125,8 +148,9 @@ def __init__( self.full_config = full_config if (weight_bits, is_sym) not in self.TYPE_MAP: - raise ValueError("Unsupported quantization config: " - f"bits={weight_bits}, sym={is_sym}") + raise ValueError( + f"Unsupported quantization config: bits={weight_bits}, sym={is_sym}" + ) self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] @@ -169,50 +193,64 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQMarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) is_sym = cls.get_from_keys(config, ["sym"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) modules_in_block_to_quantize = cls.get_from_keys_or( - config, ["modules_in_block_to_quantize"], default=None) - return cls(weight_bits, group_size, desc_act, is_sym, - lm_head_quantized, dynamic, config, - modules_in_block_to_quantize) + config, ["modules_in_block_to_quantize"], default=None + ) + return cls( + weight_bits, + group_size, + desc_act, + is_sym, + lm_head_quantized, + dynamic, + config, + modules_in_block_to_quantize, + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) - is_valid_user_quant = (user_quant is None or user_quant == "marlin" - or user_quant == "gptq_marlin") + is_valid_user_quant = ( + user_quant is None or user_quant == "marlin" or user_quant == 
"gptq_marlin" + ) if can_convert and is_valid_user_quant: - msg = ("The model is convertible to {} during runtime." - " Using {} kernel.".format(cls.get_name(), cls.get_name())) + msg = ( + "The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name()) + ) logger.info(msg) return cls.get_name() if can_convert and user_quant == "gptq": - logger.info("Detected that the model can run with gptq_marlin" - ", however you specified quantization=gptq explicitly," - " so forcing gptq. Use quantization=gptq_marlin for" - " faster inference") + logger.info( + "Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_marlin for" + " faster inference" + ) return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, FusedMoE): - from vllm.model_executor.layers.quantization.moe_wna16 import ( - MoeWNA16Config) + from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config + if not check_moe_marlin_supports_layer(layer, self.group_size): logger.warning_once( f"Layer '{prefix}' is not supported by GPTQMoeMarlin. " - "Falling back to Moe WNA16 kernels.") - return MoeWNA16Config.from_config( - self.full_config).get_quant_method(layer, prefix) - return get_moe_quant_method(self, layer, prefix, - GPTQMarlinMoEMethod) - return get_linear_quant_method(self, layer, prefix, - GPTQMarlinLinearMethod) + "Falling back to Moe WNA16 kernels." + ) + return MoeWNA16Config.from_config(self.full_config).get_quant_method( + layer, prefix + ) + return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod) + return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) @classmethod def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]): @@ -229,41 +267,40 @@ def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]): return False # Marlin conversion is only valid if required properties are found - if (num_bits is None or group_size is None or sym is None - or desc_act is None): + if num_bits is None or group_size is None or sym is None or desc_act is None: return False if (num_bits, sym) not in cls.TYPE_MAP: return False - return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], - group_size=group_size) + return check_marlin_supported( + quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size + ) def apply_vllm_mapper(self, hf_to_vllm_mapper): if self.modules_in_block_to_quantize is not None: self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list( - self.modules_in_block_to_quantize) + self.modules_in_block_to_quantize + ) - def maybe_update_config(self, - model_name: str, - revision: Optional[str] = None): + def maybe_update_config(self, model_name: str, revision: Optional[str] = None): if self.modules_in_block_to_quantize: if is_list_of(self.modules_in_block_to_quantize, list): # original modules_in_block_to_quantize: list[list[str]] # flatten original modules_in_block_to_quantize self.modules_in_block_to_quantize = [ - item for sublist in self.modules_in_block_to_quantize + item + for sublist in self.modules_in_block_to_quantize for item in sublist ] return unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32] - metadata = get_safetensors_params_metadata(model_name, - revision=revision) + metadata = 
get_safetensors_params_metadata(model_name, revision=revision) quant_layers: set[str] = { param_name.rsplit(".", 1)[0] for param_name, info in metadata.items() - if (dtype := info.get('dtype', None)) + if (dtype := info.get("dtype", None)) and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes } self.modules_in_block_to_quantize = list(quant_layers) @@ -282,8 +319,10 @@ def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config # Verify supported on platform. - verify_marlin_supported(quant_type=self.quant_config.quant_type, - group_size=self.quant_config.group_size) + verify_marlin_supported( + quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size, + ) def create_weights( self, @@ -301,20 +340,21 @@ def create_weights( mp_linear_kernel_config = MPLinearLayerConfig( full_weight_shape=(input_size, output_size), - partition_weight_shape=\ - (input_size_per_partition, output_size_per_partition), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), weight_type=self.quant_config.quant_type, act_type=params_dtype, group_size=self.quant_config.group_size, zero_points=False, - has_g_idx=self.quant_config.desc_act + has_g_idx=self.quant_config.desc_act, ) kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: - logger.info("Using %s for GPTQMarlinLinearMethod", - kernel_type.__name__) + logger.info("Using %s for GPTQMarlinLinearMethod", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # Normalize group_size @@ -324,9 +364,9 @@ def create_weights( group_size = input_size # Determine sharding - if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, - self.quant_config.group_size, - is_row_parallel): + if marlin_repeat_scales_on_all_ranks( + self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel + ): # By setting scale_dim == None, weight_loader will # repeat the scales on each GPU in TP>1 case. 
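        # (With desc_act and grouped scales, a row-parallel shard can end up
        # touching channels from every group, which is why each TP rank keeps
        # the full scale tensor instead of a slice.)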
scales_and_zp_input_dim = None @@ -348,67 +388,69 @@ def create_weights( output_dim=1, packed_dim=0, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) # Activation order - g_idx = RowvLLMParameter(data=torch.empty( - input_size_per_partition, - dtype=torch.int32, - ), - input_dim=0, - weight_loader=weight_loader) + g_idx = RowvLLMParameter( + data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) qzeros_args = { - "data": - torch.empty( + "data": torch.empty( scales_and_zp_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( scales_and_zp_size, output_size_per_partition, dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if scales_and_zp_input_dim is None: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) qzeros = PackedColumnParameter( output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) qzeros = PackedvLLMParameter( input_dim=0, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - self.kernel = kernel_type(mp_linear_kernel_config, - w_q_param_name="qweight", - w_s_param_name="scales", - w_zp_param_name="qzeros", - w_gidx_param_name="g_idx") + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx", + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.kernel.process_weights_after_loading(layer) @@ -437,8 +479,7 @@ def __init__( elif self.quant_config.quant_type.size_bits == 8: self.quant_type = scalar_types.uint8b128 else: - raise ValueError( - "GPTQMarlinMoEMethod only supports int4 and int8 now.") + raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") def create_weights( self, @@ -449,28 +490,27 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - intermediate_size_full = extra_weight_attrs.pop( - "intermediate_size_full") + intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full") self.is_k_full = (not self.quant_config.desc_act) or ( - intermediate_size_per_partition == intermediate_size_full) + intermediate_size_per_partition == intermediate_size_full + ) if self.quant_config.group_size != -1: scales_size13 = hidden_size // self.quant_config.group_size - w2_scales_size = (intermediate_size_full - if self.quant_config.desc_act else - intermediate_size_per_partition) - scales_size2 = (w2_scales_size // self.quant_config.group_size) + w2_scales_size = ( + intermediate_size_full + if self.quant_config.desc_act + else intermediate_size_per_partition + ) + scales_size2 = w2_scales_size // self.quant_config.group_size strategy = FusedMoeWeightScaleSupported.GROUP.value else: scales_size13 = 1 
scales_size2 = 1 strategy = FusedMoeWeightScaleSupported.CHANNEL.value - extra_weight_attrs.update({ - "quant_method": strategy, - "is_transposed": True - }) + extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True}) # Fused gate_up_proj (column parallel) w13_qweight = torch.nn.Parameter( torch.empty( @@ -487,8 +527,7 @@ def create_weights( w2_qweight = torch.nn.Parameter( torch.empty( num_experts, - intermediate_size_per_partition // - self.quant_config.pack_factor, + intermediate_size_per_partition // self.quant_config.pack_factor, hidden_size, dtype=torch.int32, ), @@ -498,51 +537,51 @@ def create_weights( set_weight_attrs(w2_qweight, extra_weight_attrs) # up_proj scales w13_scales = torch.nn.Parameter( - torch.empty(num_experts, - scales_size13, - 2 * intermediate_size_per_partition, - dtype=params_dtype), + torch.empty( + num_experts, + scales_size13, + 2 * intermediate_size_per_partition, + dtype=params_dtype, + ), requires_grad=False, ) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) # down_proj scales w2_scales = torch.nn.Parameter( - torch.empty(num_experts, - scales_size2, - hidden_size, - dtype=params_dtype), + torch.empty(num_experts, scales_size2, hidden_size, dtype=params_dtype), requires_grad=False, ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) # don't shard the w2 scales when running act order - set_weight_attrs(w2_scales, - {"load_full_w2": self.quant_config.desc_act}) + set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act}) # up_proj scales w13_qzeros = torch.nn.Parameter( - torch.empty(num_experts, - scales_size13, - 2 * intermediate_size_per_partition // - self.quant_config.pack_factor, - dtype=params_dtype), + torch.empty( + num_experts, + scales_size13, + 2 * intermediate_size_per_partition // self.quant_config.pack_factor, + dtype=params_dtype, + ), requires_grad=False, ) layer.register_parameter("w13_qzeros", w13_qzeros) set_weight_attrs(w13_qzeros, extra_weight_attrs) # down_proj scales w2_qzeros = torch.nn.Parameter( - torch.empty(num_experts, - scales_size2, - hidden_size // self.quant_config.pack_factor, - dtype=params_dtype), + torch.empty( + num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype, + ), requires_grad=False, ) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) # don't shard the w2 scales when running act order - set_weight_attrs(w2_qzeros, - {"load_full_w2": self.quant_config.desc_act}) + set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act}) w13_g_idx = torch.nn.Parameter( torch.empty( num_experts, @@ -571,8 +610,7 @@ def create_weights( ), requires_grad=False, ) - layer.register_parameter("w13_g_idx_sort_indices", - w13_g_idx_sort_indices) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) w2_g_idx_sort_indices = torch.nn.Parameter( torch.empty( @@ -582,15 +620,13 @@ def create_weights( ), requires_grad=False, ) - layer.register_parameter("w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) device = layer.w13_qweight.device layer.workspace = marlin_make_workspace_new(device, 4) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Process act_order if 
self.quant_config.desc_act: # Get sorting based on g_idx @@ -600,42 +636,36 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort( - layer.w13_g_idx[e]).to(torch.int32) + w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to( + torch.int32 + ) w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( - torch.int32) - w13_sorted_g_idx[e] = layer.w13_g_idx[e][ - w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_g_idx[e][ - w2_g_idx_sort_indices[e]] + torch.int32 + ) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]] replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) - replace_parameter(layer, "w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_parameter(layer, "w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) else: # Reset g_idx related tensors num_experts = layer.w13_g_idx.shape[0] device = layer.w13_g_idx.device layer.w13_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, - device=device), + torch.empty((num_experts, 0), dtype=torch.int32, device=device), requires_grad=False, ) # Repack weights @@ -665,9 +695,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_parameter(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] * - (self.quant_config.group_size if self.quant_config.group_size != -1 - else self.quant_config.pack_factor), + size_k=layer.w2_scales.shape[1] + * ( + self.quant_config.group_size + if self.quant_config.group_size != -1 + else self.quant_config.pack_factor + ), size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) @@ -680,7 +713,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return None def apply( @@ -710,7 +744,8 @@ def apply( if enable_eplb: raise NotImplementedError( - "EPLB not supported for `GPTQMarlinMoEMethod` yet.") + "EPLB not supported for `GPTQMarlinMoEMethod` yet." + ) assert activation == "silu", "Only SiLU activation is supported." 
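(For context on the act-order path reformatted above: a minimal sketch of what the
per-expert g_idx sorting in process_weights_after_loading computes. Only
torch.argsort and the cast to torch.int32 mirror the code; the tensor values and
the names perm / sorted_g_idx are illustrative.)

import torch

# Toy per-channel group indices, deliberately out of order (act-order style).
g_idx = torch.tensor([2, 0, 1, 0, 2, 1, 0, 1], dtype=torch.int32)
perm = torch.argsort(g_idx)           # channel permutation used to gather weights/scales
sorted_g_idx = g_idx[perm]            # group ids become non-decreasing
sort_indices = perm.to(torch.int32)   # dtype used for the kernel-side sort indices
assert bool(torch.all(sorted_g_idx[:-1] <= sorted_g_idx[1:]))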
@@ -726,7 +761,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return torch.ops.vllm.fused_marlin_moe( x, @@ -748,4 +784,5 @@ def apply( sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, workspace=layer.workspace, - is_k_full=self.is_k_full) + is_k_full=self.is_k_full, + ) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 6b9e3effc29d..8f0df55b0a5c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -9,12 +9,16 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import (QuantizationConfig, - QuantizationMethods) -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -24,15 +28,12 @@ GPTQ_MARLIN_24_MIN_THREAD_K = 128 GPTQ_MARLIN_24_MAX_PARALLEL = 64 -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [ - scalar_types.uint4b8, scalar_types.uint8b128 -] +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] class GPTQMarlin24Config(QuantizationConfig): - """Config class for Marlin24. - """ + """Config class for Marlin24.""" def __init__( self, @@ -48,17 +49,18 @@ def __init__( self.group_size = group_size # Verify - if quant_type is None or \ - quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: + if quant_type is None or quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: raise ValueError( f"Marlin_24 does not support quant_type = {quant_type}. " f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} " - "are supported.") + "are supported." + ) if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: raise ValueError( f"Marlin_24 does not support group_size = {self.group_size}. " f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " - "are supported.") + "are supported." 
+ ) self.quant_type = quant_type @@ -83,7 +85,8 @@ def __init__( def __repr__(self) -> str: return "Marlin24Config(quant_type={}, group_size={})".format( - self.quant_type, self.group_size) + self.quant_type, self.group_size + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -110,23 +113,26 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config": @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: - is_marlin_24_format = ( - hf_quant_cfg.get("checkpoint_format") == "marlin_24") + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: + is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24" - is_valid_user_quant = (user_quant is None or user_quant == "gptq" - or user_quant == "gptq_marlin_24") + is_valid_user_quant = ( + user_quant is None or user_quant == "gptq" or user_quant == "gptq_marlin_24" + ) if is_marlin_24_format and is_valid_user_quant: - msg = ("The model is serialized in {} format. " - "Using {} kernel.".format(cls.get_name(), cls.get_name())) + msg = "The model is serialized in {} format. Using {} kernel.".format( + cls.get_name(), cls.get_name() + ) logger.info(msg) return cls.get_name() return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlin24LinearMethod"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["GPTQMarlin24LinearMethod"]: if isinstance(layer, LinearBase): return GPTQMarlin24LinearMethod(self) return None @@ -156,7 +162,8 @@ def create_weights( weight_loader = extra_weight_attrs["weight_loader"] if params_dtype != torch.float16: raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") + f"The params dtype must be float16, but got {params_dtype}" + ) # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) @@ -164,38 +171,46 @@ def create_weights( raise ValueError( f"Weight output_size_per_partition = " f"{output_size_per_partition} is not divisible by " - f"min_n_threads = {self.quant_config.min_n_threads}.") + f"min_n_threads = {self.quant_config.min_n_threads}." + ) if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( f"Weight output_size_per_partition = " f"{output_size_per_partition} is not divisible by " - f"pack_factor = {self.quant_config.pack_factor}.") + f"pack_factor = {self.quant_config.pack_factor}." + ) # Validate input_size_per_partition if input_size_per_partition % self.quant_config.min_k_threads != 0: raise ValueError( f"Weight input_size_per_partition = " f"{input_size_per_partition} is not divisible by " - f"min_k_threads = {self.quant_config.min_k_threads}.") - if (self.quant_config.group_size != -1 and - input_size_per_partition % self.quant_config.group_size != 0): - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible by " - f"group_size = {self.quant_config.group_size}.") + f"min_k_threads = {self.quant_config.min_k_threads}." + ) + if ( + self.quant_config.group_size != -1 + and input_size_per_partition % self.quant_config.group_size != 0 + ): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}." 
+ ) # Check that we have at least 4 tiles horizontally in the shard num_tiles_per_perm = self.quant_config.perm_len // ( - self.quant_config.tile_size**2) + self.quant_config.tile_size**2 + ) if output_size_per_partition % num_tiles_per_perm != 0: - raise ValueError( - "Each permutation group must reside on the same gpu") + raise ValueError("Each permutation group must reside on the same gpu") # Quantized 4Bit weights packed into Int32. qweight = PackedvLLMParameter( data=torch.empty( input_size_per_partition // self.quant_config.tile_size // 2, - output_size_per_partition * self.quant_config.tile_size // - self.quant_config.pack_factor, + output_size_per_partition + * self.quant_config.tile_size + // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), @@ -204,55 +219,57 @@ def create_weights( packed_dim=1, packed_factor=self.quant_config.pack_factor, marlin_tile_size=self.quant_config.tile_size, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) # Meta - meta = PackedvLLMParameter(data=torch.empty( - input_size_per_partition // 8 // 2 // 2, - output_size_per_partition * 2, - device="cuda", - dtype=torch.int16, - ), - input_dim=0, - output_dim=1, - packed_dim=1, - packed_factor=1, - marlin_tile_size=2, - weight_loader=weight_loader) + meta = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + device="cuda", + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader, + ) # Determine if channelwise or not - input_groups = (1 if self.quant_config.group_size == -1 else - input_size_per_partition // - self.quant_config.group_size) + input_groups = ( + 1 + if self.quant_config.group_size == -1 + else input_size_per_partition // self.quant_config.group_size + ) weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( input_groups, output_size_per_partition, device="cuda", dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if input_groups == 1: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) # Allocate workspace (Used for internal locking mechanism) max_workspace_size = ( - output_size_per_partition // - self.quant_config.min_n_threads) * self.quant_config.max_parallel + output_size_per_partition // self.quant_config.min_n_threads + ) * self.quant_config.max_parallel - workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - weight_loader=weight_loader) + workspace = BasevLLMParameter( + data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int), + weight_loader=weight_loader, + ) layer.register_parameter("B_24", qweight) layer.register_parameter("B_meta", meta) @@ -283,12 +300,19 @@ def apply( size_k = x_2d.shape[1] size_n = scales.shape[1] - output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, - workspace, - self.quant_config.quant_type, - size_m, size_n, size_k) - - output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + output_2d = ops.gptq_marlin_24_gemm( + x_2d, + qweight, + meta, + scales, + workspace, + self.quant_config.quant_type, + size_m, + size_n, + size_k, + ) + + output = 
output_2d.view(x.shape[:-1] + (output_2d.shape[1],)) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/hqq_marlin.py b/vllm/model_executor/layers/quantization/hqq_marlin.py index 8385ccac32a2..e61caf6b459b 100644 --- a/vllm/model_executor/layers/quantization/hqq_marlin.py +++ b/vllm/model_executor/layers/quantization/hqq_marlin.py @@ -7,20 +7,32 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - marlin_make_empty_g_idx, marlin_permute_bias, marlin_permute_scales) + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + marlin_make_empty_g_idx, + marlin_permute_bias, + marlin_permute_scales, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace) + MarlinWorkspace, +) from vllm.model_executor.layers.quantization.utils.quant_utils import gptq_pack -from vllm.model_executor.parameter import (BasevLLMParameter, - GroupQuantScaleParameter, - PackedvLLMParameter) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, +) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -36,10 +48,10 @@ def __init__( skip_modules: Optional[list[str]] = None, ) -> None: super().__init__() - assert group_size == 64, ("The only supported HQQ group size is " - "currently 64.") - assert weight_bits == 4, ("The only supported HQQ quantization " - "bitsize is currently 4.") + assert group_size == 64, "The only supported HQQ group size is currently 64." + assert weight_bits == 4, ( + "The only supported HQQ quantization bitsize is currently 4." 
+ ) self.weight_bits = weight_bits self.group_size = group_size @@ -48,8 +60,10 @@ def __init__( self.skip_modules = skip_modules def __repr__(self) -> str: - return (f"HQQMarlinConfig(quant_type={self.quant_type}, " - f"group_size={self.group_size})") + return ( + f"HQQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -69,7 +83,7 @@ def get_config_filenames(cls) -> list[str]: @classmethod def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig": - wq_params = (config["quant_config"]["weight_quant_params"]) + wq_params = config["quant_config"]["weight_quant_params"] weight_bits = cls.get_from_keys(wq_params, ["nbits"]) group_size = cls.get_from_keys(wq_params, ["group_size"]) skip_modules = config["skip_modules"] @@ -77,14 +91,16 @@ def from_config(cls, config: dict[str, Any]) -> "HQQMarlinConfig": def is_layer_skipped(self, prefix: str) -> bool: # Split the prefix into its dot-separated components - components = prefix.split('.') + components = prefix.split(".") # Check if any of the skip modules exactly matches any component return self.skip_modules is not None and any( - module_name in components for module_name in self.skip_modules) + module_name in components for module_name in self.skip_modules + ) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): if self.is_layer_skipped(prefix): return UnquantizedLinearMethod() @@ -94,7 +110,6 @@ def get_quant_method(self, layer: torch.nn.Module, # Empty HQQ parameter, will be ignored during loading class HQQEmptyParameter(BasevLLMParameter): - def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): pass @@ -112,23 +127,18 @@ def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # HQQ packing creates issues with sharding - therefore, prior to loading, we # repack to GPTQ. We also reshape the weights to their proper GPTQ shape. 
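The HQQweightParameter class defined just below carries out this repack at load time. As a minimal standalone illustration of the 4-bit nibble unpack it relies on (hypothetical names, not part of this patch or the vLLM API):

```python
import torch


def unpack_4bit_u8_sketch(w_q: torch.Tensor) -> torch.Tensor:
    """Expand a [K/2, N] uint8 tensor holding two 4-bit values per byte into [K, N]."""
    step = w_q.shape[0]
    out = torch.empty(2 * step, w_q.shape[1], dtype=torch.uint8, device=w_q.device)
    out[:step] = (w_q & 0b11110000) >> 4  # high nibbles first, as in unpack_4bit_u8
    out[step:] = w_q & 0b00001111         # then low nibbles
    return out


packed = torch.randint(0, 256, (4, 8), dtype=torch.uint8)
unpacked = unpack_4bit_u8_sketch(packed)
assert unpacked.shape == (8, 8) and int(unpacked.max()) <= 15
```

After this unpack, the weights are reshaped and re-packed with gptq_pack so the usual GPTQ sharding and loading paths apply unchanged.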
class HQQweightParameter(PackedvLLMParameter): - # unpack function from https://github.com/mobiusml/hqq - def unpack_4bit_u8(self, - W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8 + def unpack_4bit_u8(self, W_q: torch.Tensor) -> torch.Tensor: # uint8/2 > uint8 assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)" dtype = torch.uint8 step = W_q.shape[0] - tmp = torch.empty([2 * step, W_q.shape[1]], - dtype=dtype, - device=W_q.device) + tmp = torch.empty([2 * step, W_q.shape[1]], dtype=dtype, device=W_q.device) tmp[:step] = (W_q & 0b11110000) >> 4 tmp[step:] = W_q & 0b00001111 return tmp - def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int, - **kwargs): + def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int, **kwargs): super().__init__(packed_factor, packed_dim, None, **kwargs) self.weight_bits = weight_bits self.input_shape = self.shape[self.input_dim] * self.packed_factor @@ -136,36 +146,41 @@ def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int, def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): loaded_weight = self.unpack_4bit_u8(loaded_weight) - loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( - 1, 0) - loaded_weight = gptq_pack(loaded_weight, self.weight_bits, - loaded_weight.shape[0], - loaded_weight.shape[1]) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(1, 0) + loaded_weight = gptq_pack( + loaded_weight, + self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1], + ) super().load_merged_column_weight(loaded_weight, **kwargs) def load_row_parallel_weight(self, loaded_weight: torch.Tensor): loaded_weight = self.unpack_4bit_u8(loaded_weight) - loaded_weight = loaded_weight.reshape(self.output_shape, - -1).transpose(1, 0) - loaded_weight = gptq_pack(loaded_weight, self.weight_bits, - loaded_weight.shape[0], - loaded_weight.shape[1]) + loaded_weight = loaded_weight.reshape(self.output_shape, -1).transpose(1, 0) + loaded_weight = gptq_pack( + loaded_weight, + self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1], + ) super().load_row_parallel_weight(loaded_weight) def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): loaded_weight = self.unpack_4bit_u8(loaded_weight) - loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose( - 1, 0) - loaded_weight = gptq_pack(loaded_weight, self.weight_bits, - loaded_weight.shape[0], - loaded_weight.shape[1]) + loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(1, 0) + loaded_weight = gptq_pack( + loaded_weight, + self.weight_bits, + loaded_weight.shape[0], + loaded_weight.shape[1], + ) super().load_qkv_weight(loaded_weight, **kwargs) # Zero points and scales in HQQ must also be reshaped to correspond to W_q's # GPTQ shape (transposed - we transpose them too when processing weights). class HQQZeroScaleParameter(GroupQuantScaleParameter): - def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): loaded_weight = loaded_weight.reshape(-1, self.shape[1]) super().load_merged_column_weight(loaded_weight, **kwargs) @@ -180,8 +195,7 @@ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): class HQQMarlinMethod(LinearMethodBase): - """Linear method for HQQ Marlin. 
- """ + """Linear method for HQQ Marlin.""" def __init__( self, @@ -204,8 +218,9 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader", error_loader) - self.scales_and_zp_size = (input_size_per_partition // - self.quant_config.group_size) + self.scales_and_zp_size = ( + input_size_per_partition // self.quant_config.group_size + ) qweight = HQQweightParameter( data=torch.empty( @@ -218,25 +233,30 @@ def create_weights( packed_dim=0, packed_factor=self.quant_config.pack_factor, weight_bits=self.quant_config.weight_bits, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - zeros = HQQZeroScaleParameter(data=torch.empty( - self.output_size_per_partition, - self.scales_and_zp_size, - dtype=params_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) - - scales = HQQZeroScaleParameter(data=torch.empty( - self.output_size_per_partition, - self.scales_and_zp_size, - dtype=params_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + zeros = HQQZeroScaleParameter( + data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + scales = HQQZeroScaleParameter( + data=torch.empty( + self.output_size_per_partition, + self.scales_and_zp_size, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("W_q", qweight) layer.register_parameter("zero", zeros) @@ -244,17 +264,29 @@ def create_weights( # Ignore extra parameters in the HQQ model. # To be added as needed. - ignore_parameters = ("axis", "channel_wise", "compute_dtype", - "encoded_state_dict", "group_size", "nbits", - "offload_meta", "optimize", "packing", - "quant_scale", "quant_zero", "round_zero", - "shape", "stores_quant_config", - "unpack_view_dtype", "view_as_float") + ignore_parameters = ( + "axis", + "channel_wise", + "compute_dtype", + "encoded_state_dict", + "group_size", + "nbits", + "offload_meta", + "optimize", + "packing", + "quant_scale", + "quant_zero", + "round_zero", + "shape", + "stores_quant_config", + "unpack_view_dtype", + "view_as_float", + ) for name in ignore_parameters: layer.register_parameter( name, - HQQEmptyParameter(data=torch.empty(0), - weight_loader=weight_loader)) + HQQEmptyParameter(data=torch.empty(0), weight_loader=weight_loader), + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: dev = layer.W_q.device @@ -268,14 +300,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.output_size_per_partition, self.quant_config.weight_bits, ).to(dev) - marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0), - self.input_size_per_partition, - self.output_size_per_partition, - self.quant_config.group_size).to(dev) - marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0), - self.input_size_per_partition, - self.output_size_per_partition, - self.quant_config.group_size).to(dev) + marlin_s = marlin_permute_scales( + layer.scale.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size, + ).to(dev) + marlin_zp = marlin_permute_scales( + layer.zero.transpose(1, 0), + self.input_size_per_partition, + self.output_size_per_partition, + self.quant_config.group_size, + ).to(dev) layer.g_idx = marlin_make_empty_g_idx(dev) layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev) @@ -293,9 +329,11 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) 
-> torch.Tensor: - workspace = MarlinWorkspace(self.output_size_per_partition, - GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = MarlinWorkspace( + self.output_size_per_partition, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL, + ) scales = layer.marlin_scales zeros = layer.marlin_zeros diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py index 8aa1f1a14bfc..4e736378e9da 100644 --- a/vllm/model_executor/layers/quantization/inc.py +++ b/vllm/model_executor/layers/quantization/inc.py @@ -21,12 +21,15 @@ import torch from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, UnquantizedFusedMoEMethod) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) + FusedMoE, + UnquantizedFusedMoEMethod, +) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) class INCConfig(QuantizationConfig): @@ -44,8 +47,9 @@ def get_supported_act_dtypes(cls) -> list[torch.dtype]: def from_config(cls, config: dict[str, Any]) -> "INCConfig": raise AssertionError - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index ece3e5817116..8786638869a4 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -7,8 +7,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform # Using the default value (240.0) from pytorch will cause accuracy @@ -28,12 +27,12 @@ class QuantFP8(CustomOp): """ def __init__( - self, - static: bool, - group_shape: GroupShape, - num_token_padding: Optional[int] = None, - column_major_scales: bool = False, - use_ue8m0: Optional[bool] = None, # for Torch compile + self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None, + column_major_scales: bool = False, + use_ue8m0: Optional[bool] = None, # for Torch compile ): """ :param static: static or dynamic quantization @@ -57,8 +56,9 @@ def __init__( self.group_size = group_shape.col else: assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} - assert not static or group_shape == GroupShape.PER_TENSOR, \ + assert not static or group_shape == GroupShape.PER_TENSOR, ( "Only per-tensor scales supported for static quantization." 
+ ) self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN def forward_cuda( @@ -70,23 +70,28 @@ def forward_cuda( if self.is_group_quant: assert scale is None, "Group quantization is always dynamic" from vllm.model_executor.layers.quantization.utils import fp8_utils + return fp8_utils.per_token_group_quant_fp8( x, group_size=self.group_size, column_major_scales=self.column_major_scales, dtype=_FP8_DTYPE, - use_ue8m0=self.use_ue8m0) + use_ue8m0=self.use_ue8m0, + ) assert (scale is not None) == self.static - assert scale_ub is None or (not self.static and self.group_shape - == GroupShape.PER_TOKEN - and scale_ub.numel() == 1) + assert scale_ub is None or ( + not self.static + and self.group_shape == GroupShape.PER_TOKEN + and scale_ub.numel() == 1 + ) return ops.scaled_fp8_quant( x, scale, num_token_padding=self.num_token_padding, scale_ub=scale_ub, - use_per_token_if_dynamic=self.use_per_token_if_dynamic) + use_per_token_if_dynamic=self.use_per_token_if_dynamic, + ) def forward_native( self, @@ -99,9 +104,11 @@ def forward_native( return self._quantize_group_native(x) assert (scale is not None) == self.static - assert scale_ub is None or (not self.static and self.group_shape - == GroupShape.PER_TOKEN - and scale_ub.numel() == 1) + assert scale_ub is None or ( + not self.static + and self.group_shape == GroupShape.PER_TOKEN + and scale_ub.numel() == 1 + ) if scale is None: if self.group_shape == GroupShape.PER_TOKEN: @@ -130,7 +137,8 @@ def forward_native( return out, scale def _quantize_group_native( - self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: orig_shape = x.shape hidden_dim = x.shape[-1] num_groups = (hidden_dim + self.group_size - 1) // self.group_size @@ -138,7 +146,7 @@ def _quantize_group_native( if padded_dim != hidden_dim: padding = padded_dim - hidden_dim - x = F.pad(x, (0, padding), mode='constant', value=0.0) + x = F.pad(x, (0, padding), mode="constant", value=0.0) x_grouped = x.view(-1, num_groups, self.group_size) absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float() @@ -156,7 +164,7 @@ def _quantize_group_native( x_quant = x_quant.view(orig_shape) scales = scales.squeeze(-1) - scales = scales.reshape(orig_shape[:-1] + (num_groups, )) + scales = scales.reshape(orig_shape[:-1] + (num_groups,)) if self.column_major_scales: scales = scales.transpose(-2, -1).contiguous().transpose(-1, -2) diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py index 353942cdd591..4aa0e464e0f5 100644 --- a/vllm/model_executor/layers/quantization/ipex_quant.py +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -9,17 +9,25 @@ from torch.nn.parameter import Parameter from vllm._ipex_ops import ipex_ops as ops -from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, - FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe import ( + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization import (QuantizationConfig, - QuantizationMethods) -from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, - is_layer_skipped_awq) -from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, - Fp8LinearMethod) +from vllm.model_executor.layers.linear import ( 
+ LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) +from vllm.model_executor.layers.quantization.awq import ( + AWQLinearMethod, + is_layer_skipped_awq, +) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -56,17 +64,22 @@ def __init__( self.pack_factor = 32 // self.weight_bits if self.weight_bits not in [4]: - raise ValueError(f"IPEX quantization supports weight bits [4], " - f"but got {self.weight_bits}.") + raise ValueError( + f"IPEX quantization supports weight bits [4], " + f"but got {self.weight_bits}." + ) if self.method not in ["awq", "gptq"]: - raise ValueError(f"IPEX quantization supports [awq, gptq], " - f"but got {self.method}.") + raise ValueError( + f"IPEX quantization supports [awq, gptq], but got {self.method}." + ) def __repr__(self) -> str: - return (f"IPEXConfig(method={self.method}," - f"weight_bits={self.weight_bits}, " - f"group_size={self.group_size})") + return ( + f"IPEXConfig(method={self.method}," + f"weight_bits={self.weight_bits}, " + f"group_size={self.group_size})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -92,24 +105,24 @@ def from_config(cls, config: dict[str, Any]) -> "IPEXConfig": method = cls.get_from_keys(config, ["quant_method"]).lower() if method == "awq": weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) - group_size = cls.get_from_keys(config, - ["q_group_size", "group_size"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None) - return cls(method, weight_bits, group_size, modules_to_not_convert, - False, False) + config, ["modules_to_not_convert"], None + ) + return cls( + method, weight_bits, group_size, modules_to_not_convert, False, False + ) # otherwise for gptq weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False) - return cls(method, weight_bits, group_size, [], desc_act, - lm_head_quantized) + return cls(method, weight_bits, group_size, [], desc_act, lm_head_quantized) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: if not current_platform.is_cpu() and not current_platform.is_xpu(): return None @@ -120,8 +133,9 @@ def override_quantization_method( return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["LinearMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["LinearMethodBase"]: if isinstance(layer, LinearBase): if self.method == "awq": if is_layer_skipped_awq(prefix, self.modules_to_not_convert): @@ -133,8 +147,7 @@ def get_quant_method(self, layer: torch.nn.Module, class IPEXGPTQLinearMethod(GPTQLinearMethod): - """GPTQ linear method using IPEX for the CPU/XPU backend. 
- """ + """GPTQ linear method using IPEX for the CPU/XPU backend.""" def __init__(self, quant_config: IPEXConfig): self.quant_config = quant_config # type: ignore @@ -144,18 +157,20 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: try: import intel_extension_for_pytorch as ipex - if version.parse( - ipex.__version__) < version.parse(MIN_IPEX_VERSION): + + if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION): raise ImportError( "intel_extension_for_pytorch version is " "wrong. Please install " - f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}." + ) except ImportError as err: raise ImportError( "Please install " f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" - " to use IPEX-AWQ linear method.") from err + " to use IPEX-AWQ linear method." + ) from err # Using the compute dtype (lowp_mode) as INT8 to leverage instructions # with better performance. lowp_mode = ipex.quantization.WoqLowpMode.INT8 @@ -172,32 +187,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) layer.ipex_output_size = layer.qweight.shape[-1] g_idx = layer.g_idx if self.quant_config.desc_act else None - layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ - IPEXWeightOnlyQuantizedLinear.from_weight( - layer.qweight, - layer.scales, - layer.qzeros, - layer.qweight.size(0), - layer.ipex_output_size, - qconfig=qconfig, - g_idx=g_idx, - bias=bias, - group_size=self.quant_config.group_size, - quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"] + layer.ipex_qlinear = ( + ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + g_idx=g_idx, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"], + ) ) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) - return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size,)) class IPEXAWQLinearMethod(AWQLinearMethod): - """AWQ linear method using IPEX for the CPU/XPU backend. - """ + """AWQ linear method using IPEX for the CPU/XPU backend.""" def __init__(self, quant_config: IPEXConfig): self.quant_config = quant_config # type: ignore @@ -209,18 +226,20 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: try: import intel_extension_for_pytorch as ipex - if version.parse( - ipex.__version__) < version.parse(MIN_IPEX_VERSION): + + if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION): raise ImportError( "intel_extension_for_pytorch version is " "wrong. Please install " - f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.") + f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}." + ) except ImportError as err: raise ImportError( "Please install " f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" - " to use IPEX-AWQ linear method.") from err + " to use IPEX-AWQ linear method." 
+ ) from err # Using the compute dtype (lowp_mode) as INT8 to leverage instructions # with better performance. @@ -237,104 +256,117 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: group_size=self.quant_config.group_size, ) - layer.ipex_output_size = layer.qweight.size( - 1) * self.quant_config.pack_factor - layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ - IPEXWeightOnlyQuantizedLinear.from_weight( - layer.qweight, - layer.scales, - layer.qzeros, - layer.qweight.size(0), - layer.ipex_output_size, - qconfig=qconfig, - bias=bias, - group_size=self.quant_config.group_size, - quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore + layer.ipex_output_size = layer.qweight.size(1) * self.quant_config.pack_factor + layer.ipex_qlinear = ( + ipex.llm.quantization.woq_linear.IPEXWeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + bias=bias, + group_size=self.quant_config.group_size, + quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"], # type: ignore + ) ) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: reshaped_x = x.reshape(-1, x.shape[-1]) out = layer.ipex_qlinear(reshaped_x) - return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + return out.reshape(x.shape[:-1] + (layer.ipex_output_size,)) class XPUFp8LinearMethod(Fp8LinearMethod): - def __init__(self, quant_config: Fp8Config): super().__init__(quant_config) def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint not serialized fp8, quantize the weights. if not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, - scale=None) + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) # Update the layer with the new values. 
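Since this path quantizes checkpoint weights on the fly, it may help to spell out what a dynamic per-tensor FP8 quantization such as the ops.scaled_fp8_quant(layer.weight, scale=None) call above boils down to. A rough standalone sketch, not the fused vLLM op, assuming a PyTorch build with float8 dtypes:

```python
import torch


def per_tensor_fp8_sketch(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max           # 448 for e4m3fn
    scale = w.abs().max().float() / fp8_max  # one scale for the whole tensor
    q = (w.float() / scale).clamp(-fp8_max, fp8_max).to(fp8)
    return q, scale                          # dequantize as q.float() * scale


qweight, weight_scale = per_tensor_fp8_sketch(torch.randn(128, 64))
```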
layer.weight = Parameter(qweight, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.input_scale = None - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: weight = layer.weight.data weight_scale = layer.weight_scale.data - output = torch.ops.torch_ipex.fp8_gemm_w8a16(x, weight, True, - weight_scale, bias) + output = torch.ops.torch_ipex.fp8_gemm_w8a16( + x, weight, True, weight_scale, bias + ) return output class XPUFp8MoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): super().__init__(layer.moe_config) self.quant_config = quant_config - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. 
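The two scales allocated next mirror how checkpoints store the fused gate/up projection: w13 is w1 and w3 concatenated along the output dimension, so each half can arrive with its own scale until the merged tensor is re-quantized. A minimal sketch of that shape bookkeeping (illustrative values only, not vLLM API):

```python
import torch

num_experts, intermediate, hidden = 4, 16, 32
w1 = torch.randn(num_experts, intermediate, hidden)  # gate projection per expert
w3 = torch.randn(num_experts, intermediate, hidden)  # up projection per expert

# The fused layout allocated above: (E, 2 * intermediate, hidden).
w13 = torch.cat([w1, w3], dim=1)

w13_scale_at_load = torch.ones(num_experts, 2)  # one scale per (w1, w3) half while loading
w13_scale_final = torch.ones(num_experts)       # single per-expert scale after re-quantizing w13

assert w13.shape == (num_experts, 2 * intermediate, hidden)
```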
- w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) # INPUT_SCALES layer.w13_input_scale = None layer.w2_input_scale = None @@ -342,29 +374,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def process_weights_after_loading(self, layer: Module) -> None: if not self.quant_config.is_checkpoint_fp8_serialized: fp8_dtype = current_platform.fp8_dtype() - w13_weight = torch.empty_like(layer.w13_weight.data, - dtype=fp8_dtype) + w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. - layer.w13_weight_scale = torch.nn.Parameter(torch.ones( - layer.local_num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + torch.ones( + layer.local_num_experts, + dtype=torch.float32, + device=w13_weight.device, + ), + requires_grad=False, + ) for expert in range(layer.local_num_experts): - w13_weight[expert, :, :], layer.w13_weight_scale[ - expert] = ops.scaled_fp8_quant( - layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_weight_scale[ - expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :]) - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) import intel_extension_for_pytorch as ipex + layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( layer.w13_weight, layer.w2_weight, @@ -376,7 +409,8 @@ def process_weights_after_loading(self, layer: Module) -> None: ) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return None def apply( diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py index 1280f5f1eadf..055a3ebbced6 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -24,7 +24,6 @@ class MPLinearLayerConfig: class MPLinearKernel(ABC): - @classmethod @abstractmethod def get_min_capability(cls) -> int: @@ -32,16 +31,17 @@ def get_min_capability(cls) -> int: 
@classmethod @abstractmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: raise NotImplementedError - def __init__(self, - c: MPLinearLayerConfig, - w_q_param_name: str, - w_s_param_name: str, - w_zp_param_name: Optional[str] = None, - w_gidx_param_name: Optional[str] = None) -> None: + def __init__( + self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None, + ) -> None: assert self.can_implement(c) self.config = c self.w_q_name = w_q_param_name @@ -58,31 +58,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: raise NotImplementedError @abstractmethod - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: raise NotImplementedError - def _transform_param(self, layer: torch.nn.Module, name: Optional[str], - fn: Callable) -> None: + def _transform_param( + self, layer: torch.nn.Module, name: Optional[str], fn: Callable + ) -> None: if name is not None and getattr(layer, name, None) is not None: - old_param = getattr(layer, name) new_param = fn(old_param) # replace the parameter with torch.nn.Parameter for TorchDynamo # compatibility replace_parameter( - layer, name, - torch.nn.Parameter(new_param.data, requires_grad=False)) + layer, name, torch.nn.Parameter(new_param.data, requires_grad=False) + ) def _get_weight_params( - self, layer: torch.nn.Module) -> tuple[ - torch.Tensor, # w_q - torch.Tensor, # w_s - Optional[torch.Tensor], # w_zp, - Optional[torch.Tensor] # w_gidx - ]: + self, layer: torch.nn.Module + ) -> tuple[ + torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor], # w_gidx + ]: return ( getattr(layer, self.w_q_name), getattr(layer, self.w_s_name), diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index f10d20999bee..1759d142e6cc 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -5,23 +5,33 @@ import vllm.envs as envs from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 - AllSparkLinearKernel) + AllSparkLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501 - BitBLASLinearKernel) + BitBLASLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501 - ConchLinearKernel) + ConchLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501 - CutlassW4A8LinearKernel) + CutlassW4A8LinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501 - Dynamic4bitLinearKernel) + Dynamic4bitLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 - ExllamaLinearKernel) + ExllamaLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 - MacheteLinearKernel) + MacheteLinearKernel, +) 
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 - MarlinLinearKernel) + MarlinLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 - MPLinearKernel, MPLinearLayerConfig) + MPLinearKernel, + MPLinearLayerConfig, +) from vllm.platforms import current_platform # in priority/performance order (when available) @@ -38,11 +48,11 @@ def choose_mp_linear_kernel( - config: MPLinearLayerConfig, - compute_capability: Optional[int] = None) -> type[MPLinearKernel]: + config: MPLinearLayerConfig, compute_capability: Optional[int] = None +) -> type[MPLinearKernel]: """ Choose an MPLinearKernel that can implement the given config for the given - compute capability. Attempts to choose the best kernel in terms of + compute capability. Attempts to choose the best kernel in terms of performance. Args: @@ -69,14 +79,18 @@ def choose_mp_linear_kernel( for kernel in _POSSIBLE_KERNELS: if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: failure_reasons.append( - f' {kernel.__name__} disabled by environment variable') + f" {kernel.__name__} disabled by environment variable" + ) continue - if (compute_capability is not None - and kernel.get_min_capability() > compute_capability): + if ( + compute_capability is not None + and kernel.get_min_capability() > compute_capability + ): failure_reasons.append( f"{kernel.__name__} requires capability " f"{kernel.get_min_capability()}, current compute " - f" capability is {compute_capability}") + f" capability is {compute_capability}" + ) continue can_implement, failure_reason = kernel.can_implement(config) @@ -84,10 +98,10 @@ def choose_mp_linear_kernel( return kernel else: failure_reasons.append( - f' {kernel.__name__} cannot implement due to: {failure_reason}' + f" {kernel.__name__} cannot implement due to: {failure_reason}" ) raise ValueError( - "Failed to find a kernel that can implement the "\ - "WNA16 linear layer. Reasons: \n" - + '\n'.join(failure_reasons)) + "Failed to find a kernel that can implement the " + "WNA16 linear layer. 
Reasons: \n" + "\n".join(failure_reasons) + ) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py index 785e559df8f7..c353372b05ec 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py @@ -8,22 +8,21 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.allspark_utils import ( - ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, check_allspark_supported_dtype_shape) -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + check_allspark_supported_dtype_shape, +) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig class AllSparkLinearKernel(MPLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 80 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if c.has_g_idx: return False, "Act reordering currently not supported by AllSpark" @@ -35,7 +34,8 @@ def can_implement(cls, c.partition_weight_shape[1], # out_features c.group_size, c.weight_type, - c.act_type) + c.act_type, + ) # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} @@ -49,8 +49,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: sm_count = properties.multi_processor_count sm_version = properties.major * 10 + properties.minor gemm_args = {} - gemm_args['sm_count'] = sm_count - gemm_args['sm_version'] = sm_version + gemm_args["sm_count"] = sm_count + gemm_args["sm_version"] = sm_version self.gemm_args = gemm_args @@ -59,43 +59,42 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: old_scale_param = getattr(layer, self.w_s_name) assert isinstance(old_weight_param, BasevLLMParameter) - permute_param_layout_(old_weight_param, - input_dim=0, - output_dim=1, - packed_dim=0) + permute_param_layout_(old_weight_param, input_dim=0, output_dim=1, packed_dim=0) assert isinstance(old_scale_param, BasevLLMParameter) permute_param_layout_(old_scale_param, input_dim=0, output_dim=1) # unpack weight from K / 4 x N int32 to K x N uint8 - new_weight_param = torch.nn.Parameter(old_weight_param.data, - requires_grad=False) - new_weight_param.data = new_weight_param.data.t().contiguous().view( - dtype=torch.uint8) + new_weight_param = torch.nn.Parameter( + old_weight_param.data, requires_grad=False + ) + new_weight_param.data = ( + new_weight_param.data.t().contiguous().view(dtype=torch.uint8) + ) new_weight_param.data = new_weight_param.data.t().contiguous() - new_scale_param = torch.nn.Parameter(old_scale_param.data, - requires_grad=False) + new_scale_param = torch.nn.Parameter(old_scale_param.data, requires_grad=False) # reorder K x N weight as N32K16 format for Ampere W8A16 - new_weight_param.data, new_scale_param.data, _ = \ - ops.allspark_repack_weight( - new_weight_param.data, new_scale_param.data, None, - c.zero_points) + new_weight_param.data, new_scale_param.data, _ = ops.allspark_repack_weight( + new_weight_param.data, new_scale_param.data, None, c.zero_points + ) replace_parameter(layer, self.w_q_name, new_weight_param.data) 
replace_parameter(layer, self.w_s_name, new_scale_param.data) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config gemm_args = self.gemm_args w_q, w_s, _, _ = self._get_weight_params(layer) reshaped_x = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) output = ops.allspark_w8a16_gemm( a=reshaped_x, @@ -104,11 +103,12 @@ def apply_weights(self, b_qzeros=None, n=c.partition_weight_shape[1], group_size=c.group_size, - sm_count=gemm_args['sm_count'], - sm_version=gemm_args['sm_version'], + sm_count=gemm_args["sm_count"], + sm_version=gemm_args["sm_version"], CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp=c.zero_points, - n32k16_reorder=True) + n32k16_reorder=True, + ) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py index fe72910659e2..d1ff582c4e21 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -10,10 +10,16 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES, - MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx, - check_bitblas_supports_shape, query_bitblas_supported_quant_types, - unpack_gptq_qweight, unpack_gptq_qzeros) + BITBLAS_OPTIMIZE_FEATURES, + BITBLAS_SUPPORTED_GROUP_SIZES, + MINIMUM_BITBLAS_VERSION, + bitblas_make_empty_g_idx, + bitblas_sort_g_idx, + check_bitblas_supports_shape, + query_bitblas_supported_quant_types, + unpack_gptq_qweight, + unpack_gptq_qzeros, +) from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig @@ -21,7 +27,6 @@ class BitBLASLinearKernel(MPLinearKernel): - OPT_FEATURES: list[int] = BITBLAS_OPTIMIZE_FEATURES ENABLE_TUNING: bool = True MATMUL_LAYOUT: str = "nt" @@ -44,8 +49,9 @@ def __init__( bitblas_quant_config: Optional[QuantizationConfig] = None, ): self.quant_config = bitblas_quant_config - super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name, - w_gidx_param_name) + super().__init__( + c, w_q_param_name, w_s_param_name, w_zp_param_name, w_gidx_param_name + ) def repack_bitblas_from_gptq( self, @@ -54,19 +60,18 @@ def repack_bitblas_from_gptq( qzeros: Optional[torch.Tensor] = None, ): from bitblas.quantization.utils import general_compress + assert self.bitblas_matmul is not None, "bitblas_matmul is None" quant_config = self.quant_config # qweight in gptq old quant linear stored with # (outfeatures, infeatures), should be transposed. 
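The unpack_gptq_qweight call that follows recovers the raw 4-bit integers from the GPTQ storage format, where eight values are packed into each int32 along the input dimension, low nibble first. A generic, self-contained sketch of that idea (not the actual vLLM helper):

```python
import torch


def unpack_int32_to_4bit(packed: torch.Tensor) -> torch.Tensor:
    """[K // 8, N] int32 with eight 4-bit values per word -> [K, N] uint8 in [0, 15]."""
    shifts = torch.arange(0, 32, 4, device=packed.device)        # low nibble first
    vals = (packed.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF  # [K // 8, 8, N]
    return vals.to(torch.uint8).reshape(-1, packed.shape[-1])    # [K, N]


packed = torch.randint(0, 2**31 - 1, (4, 16), dtype=torch.int32)
assert unpack_int32_to_4bit(packed).shape == (32, 16)
```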
- qweight = b_q_weight.T.contiguous().view( - quant_config.torch_storage_dtype) # type: ignore[union-attr] - intweight = unpack_gptq_qweight( - qweight, - quant_config.weight_bits).contiguous() # type: ignore[union-attr] + qweight = b_q_weight.T.contiguous().view(quant_config.torch_storage_dtype) # type: ignore[union-attr] + intweight = unpack_gptq_qweight(qweight, quant_config.weight_bits).contiguous() # type: ignore[union-attr] if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined] qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined] - intweight.cpu()).cuda() + intweight.cpu() + ).cuda() # scales in gptq old quant linear stored with # (infeatures // group_size, outfeatures), should be transposed. scales = scales.T.contiguous() @@ -90,9 +95,14 @@ def repack_bitblas_from_gptq( general_compress( intzeros.T.contiguous().cpu().numpy(), weight_bits, - )).to(qweight.device). - to(quant_config.torch_storage_dtype # type: ignore[union-attr] - ).contiguous()) + ) + ) + .to(qweight.device) + .to( + quant_config.torch_storage_dtype # type: ignore[union-attr] + ) + .contiguous() + ) else: raise ValueError("Unsupported zeros type: {}".format(zeros_mode)) @@ -103,41 +113,50 @@ def get_min_capability(cls) -> int: return 70 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: - + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: is_bitblas_installed = True try: import bitblas + if version.parse(bitblas.__version__) < version.parse( - MINIMUM_BITBLAS_VERSION): + MINIMUM_BITBLAS_VERSION + ): raise ImportError( "bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError: is_bitblas_installed = False if not is_bitblas_installed: - return False, "bitblas is not installed. Please install bitblas "\ - "by running `pip install bitblas>="\ - f"{MINIMUM_BITBLAS_VERSION}`" + return ( + False, + "bitblas is not installed. 
Please install bitblas " + "by running `pip install bitblas>=" + f"{MINIMUM_BITBLAS_VERSION}`", + ) quant_types = query_bitblas_supported_quant_types(c.zero_points) if c.weight_type not in quant_types: - return False, (f"Quant type ({c.weight_type}) not supported by" - f" BitBLAS, supported types are: {quant_types}") + return False, ( + f"Quant type ({c.weight_type}) not supported by" + f" BitBLAS, supported types are: {quant_types}" + ) if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES: - return False, (f"Group size ({c.group_size}) not supported by " - "BitBLAS, supported group sizes are: " - f"{BITBLAS_SUPPORTED_GROUP_SIZES}") + return False, ( + f"Group size ({c.group_size}) not supported by " + "BitBLAS, supported group sizes are: " + f"{BITBLAS_SUPPORTED_GROUP_SIZES}" + ) return check_bitblas_supports_shape( c.partition_weight_shape[1], # out_features c.partition_weight_shape[0], # in_features c.full_weight_shape[0], # in_features - c.group_size) + c.group_size, + ) # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} @@ -149,14 +168,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Default names since bitblas requires empty parameters for these, # TODO: remove this requirement from bitblas (allow optional tensors) - if self.w_gidx_name is None: - self.w_gidx_name = "g_idx" - if self.w_zp_name is None: - self.w_zp_name = "qzeros" + if getattr(self, "w_gidx_name", None) is None: + self.w_gidx_name: str = "g_idx" + if getattr(self, "w_zp_name", None) is None: + self.w_zp_name: str = "qzeros" if c.has_g_idx: g_idx, g_idx_sort_indices = bitblas_sort_g_idx( - getattr(layer, self.w_gidx_name)) + getattr(layer, self.w_gidx_name) + ) self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) layer.g_idx_sort_indices = g_idx_sort_indices else: @@ -169,13 +189,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device)) # Repack weights - bitblas_qweight, bitblas_scales, bitblas_qzeros = ( - self.repack_bitblas_from_gptq( - layer.qweight, - layer.scales, - None if quant_config.is_sym else # type: ignore[union-attr] - layer.qzeros, # type: ignore[union-attr] - )) + bitblas_qweight, bitblas_scales, bitblas_qzeros = self.repack_bitblas_from_gptq( + layer.qweight, + layer.scales, + None if quant_config.is_sym else layer.qzeros, # type: ignore[union-attr] + ) replace_parameter(layer, self.w_q_name, bitblas_qweight) replace_parameter(layer, self.w_s_name, bitblas_scales) if bitblas_qzeros is not None: @@ -212,6 +230,7 @@ def _configure_bitblas_matmul( bits, ): from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] quant_config = self.quant_config with_scaling = False @@ -248,30 +267,33 @@ def _configure_bitblas_matmul( zeros_mode=zeros_mode, ) self.bitblas_matmul = self._get_or_create_bitblas_operator( - matmul_config, enable_tuning) + matmul_config, enable_tuning + ) def _get_or_create_bitblas_operator(self, config, enable_tuning): from bitblas import Matmul, auto_detect_nvidia_target from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() BITBLAS_TARGET = auto_detect_nvidia_target() if global_operator_cache.size() == 0: - global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, - BITBLAS_TARGET) + global_operator_cache.load_from_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul 
is None: - bitblas_matmul = Matmul(config, - target=BITBLAS_TARGET, - enable_tuning=False) + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) if enable_tuning: bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + BITBLAS_DATABASE_PATH, BITBLAS_TARGET + ) TUNING_MESSAGE = ( - f"BitBLAS Operator {config} tuned and saved to database.") + f"BitBLAS Operator {config} tuned and saved to database." + ) logger.info(TUNING_MESSAGE) else: _message = f"BitBLAS Operator {config} created without tuning. " @@ -287,7 +309,7 @@ def apply_gptq_bitblas_linear( x: torch.Tensor, ) -> torch.Tensor: output_size_per_partition = self.config.partition_weight_shape[1] - out_shape = x.shape[:-1] + (output_size_per_partition, ) + out_shape = x.shape[:-1] + (output_size_per_partition,) args = [x, layer.qweight, layer.scales] if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined] args.append(layer.qzeros) @@ -297,5 +319,6 @@ def apply_gptq_bitblas_linear( def apply_weights(self, layer, x, bias=None): NOT_IMPLEMENT_MESSAGE = ( f"{self.__class__.__name__}.apply_weights is not implemented. " - "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead") + "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead" + ) raise NotImplementedError(NOT_IMPLEMENT_MESSAGE) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py index f80af548f019..281fca7888ab 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py @@ -6,44 +6,49 @@ import torch -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.scalar_type import scalar_types from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig _CONCH_SUPPORTED_WEIGHT_TYPES: Final = [ - scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8, - scalar_types.uint8b128 + scalar_types.uint4, + scalar_types.uint8, + scalar_types.uint4b8, + scalar_types.uint8b128, ] _CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128] class ConchLinearKernel(MPLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 80 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES: - error_msg = f"Weight type ({c.weight_type}) not supported by "\ - "ConchLinearKernel, supported types are: " \ - f"{_CONCH_SUPPORTED_WEIGHT_TYPES}" + error_msg = ( + f"Weight type ({c.weight_type}) not supported by " + "ConchLinearKernel, supported types are: " + f"{_CONCH_SUPPORTED_WEIGHT_TYPES}" + ) return False, error_msg if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES: - error_msg = f"Group size ({c.group_size}) not supported by "\ - "ConchLinearKernel, supported group sizes are: " \ - f"{_CONCH_SUPPORTED_GROUP_SIZES}" + error_msg = ( + f"Group size ({c.group_size}) not supported by " + "ConchLinearKernel, supported group sizes are: " + f"{_CONCH_SUPPORTED_GROUP_SIZES}" + ) return False, error_msg if find_spec("conch") is None: - error_msg = "conch-triton-kernels is not installed, please "\ - "install it via `pip 
install conch-triton-kernels` "\ - "and try again!" + error_msg = ( + "conch-triton-kernels is not installed, please " + "install it via `pip install conch-triton-kernels` " + "and try again!" + ) return False, error_msg return True, None @@ -52,7 +57,6 @@ def can_implement(cls, # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_scale` is: {input_dim = 0, output_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) @@ -68,10 +72,12 @@ def transform_w_s(x): self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: from conch.ops.quantization.gemm import mixed_precision_gemm w_q, w_s, w_zp, _ = self._get_weight_params(layer) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py index 9e23c0dd3595..f5df7a244b42 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -7,10 +7,8 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -18,26 +16,22 @@ class CutlassW4A8LinearKernel(MPLinearKernel): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # dynamic per-tok fp8 activation quantization - self.quant_fp8 = QuantFP8(static=False, - group_shape=GroupShape.PER_TOKEN) + self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN) @classmethod def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_cuda(): return False, "CUTLASS only supported on CUDA" if not current_platform.is_device_capability(90): - return False, "CUTLASS W4A8 requires compute capability of 90 "\ - "(Hopper)" + return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)" if c.act_type != torch.float8_e4m3fn: return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations" @@ -49,8 +43,11 @@ def can_implement(cls, return False, "Zero points not supported by CUTLASS W4A8" if c.weight_type != scalar_types.int4: - return False, f"Quant type ({c.weight_type}) not supported by "\ - "CUTLASS W4A8, only supported int4" + return ( + False, + f"Quant type ({c.weight_type}) not supported by " + "CUTLASS W4A8, only supported int4", + ) # TODO(czhu): support -1 (column-wise) if c.group_size != 128: @@ -58,12 +55,16 @@ def can_implement(cls, in_features, out_features = c.partition_weight_shape if 
in_features % 128 or out_features % 128: - return False, "K and N must be divisible by 128, got "\ - f"{c.partition_weight_shape}" + return ( + False, + f"K and N must be divisible by 128, got {c.partition_weight_shape}", + ) if c.out_type != torch.bfloat16: - return False, "Only bfloat16 output type currently supported"\ - f"got {c.out_type=}" + return ( + False, + f"Only bfloat16 output type currently supportedgot {c.out_type=}", + ) return True, None @@ -71,13 +72,11 @@ def can_implement(cls, # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} # `weight_scale` is: {input_dim = 0, output_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module): - # TODO(czhu): optimize speed/mem usage def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) - x.data = ops.cutlass_encode_and_reorder_int4b( - x.data.t().contiguous().t()) + x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t()) return x def transform_w_s(x): @@ -92,24 +91,28 @@ def transform_w_s(x): self._transform_param(layer, self.w_s_name, transform_w_s) self._transform_param(layer, "weight_chan_scale", lambda x: x) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config w_q, w_s, _, _ = self._get_weight_params(layer) w_ch_s = layer.weight_chan_scale x_2d = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) x_2d, act_scales = self.quant_fp8(x_2d) - output = ops.cutlass_w4a8_mm(a=x_2d, - b_q=w_q, - b_group_scales=w_s, - b_group_size=c.group_size, - a_token_scales=act_scales, - b_channel_scales=w_ch_s) + output = ops.cutlass_w4a8_mm( + a=x_2d, + b_q=w_q, + b_group_scales=w_s, + b_group_size=c.group_size, + a_token_scales=act_scales, + b_channel_scales=w_ch_s, + ) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py index 7bd326f47f9e..7631236e6f64 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py @@ -20,37 +20,45 @@ def get_min_capability(cls) -> int: return 1 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_cpu(): return False, "Only CPU is supported" if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: return False, f"Unsupported quant type {c.weight_type}" - if current_platform.get_cpu_architecture( - ) == CpuArchEnum.ARM and c.act_type not in [ + if ( + current_platform.get_cpu_architecture() == CpuArchEnum.ARM + and c.act_type + not in [ torch.float32, - ]: - return False, "Dynamic4bitLinearKernel on Arm requires"\ - " Float32 activations" + ] + ): + return False, "Dynamic4bitLinearKernel on Arm requires Float32 activations" if c.full_weight_shape[0] % c.group_size != 0: - return False, f"Group size ({c.group_size}) does not evenly divide"\ - " the number of input features "\ - f"({c.full_weight_shape[0]})" + return ( + False, + f"Group size 
({c.group_size}) does not evenly divide" + " the number of input features " + f"({c.full_weight_shape[0]})", + ) if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: try: # Attempt to retrieve the operation _ = torch.ops.aten._dyn_quant_matmul_4bit except AttributeError: - return False, f"PyTorch {torch.__version__} does not support"\ - " _dyn_quant_matmul_4bit. Install a newer version" + return ( + False, + f"PyTorch {torch.__version__} does not support" + " _dyn_quant_matmul_4bit. Install a newer version", + ) return True, None def process_weights_after_loading(self, layer: torch.nn.Module): c = self.config packed_weight = getattr(layer, self.w_q_name) packed_weight = packed_weight.add(8) - uint8_packed = (packed_weight[::, 1::2] << 4 - | packed_weight[::, ::2]).to(torch.uint8) + uint8_packed = (packed_weight[::, 1::2] << 4 | packed_weight[::, ::2]).to( + torch.uint8 + ) scales = getattr(layer, self.w_s_name) block_size = c.group_size @@ -71,22 +79,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module): # Repack weights as per kernel requirement w = torch.ops.aten._dyn_quant_pack_4bit_weight( - uint8_packed, scales, layer.bias, block_size, - c.partition_weight_shape[0], c.partition_weight_shape[1]) - replace_parameter(layer, self.w_q_name, - torch.nn.Parameter(w, requires_grad=False)) + uint8_packed, + scales, + layer.bias, + block_size, + c.partition_weight_shape[0], + c.partition_weight_shape[1], + ) + replace_parameter( + layer, self.w_q_name, torch.nn.Parameter(w, requires_grad=False) + ) setattr(layer, self.w_s_name, None) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config x_2d = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) w_q = getattr(layer, self.w_q_name) output = torch.ops.aten._dyn_quant_matmul_4bit( - x_2d, w_q, c.group_size, c.partition_weight_shape[0], - c.partition_weight_shape[1]) + x_2d, + w_q, + c.group_size, + c.partition_weight_shape[0], + c.partition_weight_shape[1], + ) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py index fef333e862d5..a57d3f65267e 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py @@ -7,9 +7,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_quantized_values_into_int32) -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) + pack_quantized_values_into_int32, +) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.scalar_type import scalar_types from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig @@ -25,31 +25,41 @@ def get_min_capability(cls) -> int: return 60 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: - if c.has_g_idx and\ - c.partition_weight_shape[0] != c.full_weight_shape[0]: - return False, "Act reordering currently not supported by Exllama, "\ - "when the input features are partitioned across "\ - "devices" + def 
can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]: + return ( + False, + "Act reordering currently not supported by Exllama, " + "when the input features are partitioned across " + "devices", + ) if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0: - return False, "Output features must be a multiple of the pack " \ - "factor (32 / num_bits) so that we can correctly " \ - "pack the zero points" + return ( + False, + "Output features must be a multiple of the pack " + "factor (32 / num_bits) so that we can correctly " + "pack the zero points", + ) if c.act_type != torch.float16: return False, "Exllama only supports float16 activations" if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: - return False, f"Quant type ({c.weight_type}) not supported by "\ - "Exllama, supported types are: "\ - f"{cls.SUPPORTED_QUANT_TYPES}" + return ( + False, + f"Quant type ({c.weight_type}) not supported by " + "Exllama, supported types are: " + f"{cls.SUPPORTED_QUANT_TYPES}", + ) if c.full_weight_shape[0] % c.group_size != 0: - return False, f"Group size ({c.group_size}) does not evenly divide"\ - " the number of input features "\ - f"({c.full_weight_shape[0]})" + return ( + False, + f"Group size ({c.group_size}) does not evenly divide" + " the number of input features " + f"({c.full_weight_shape[0]})", + ) return True, None @@ -70,21 +80,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module): # exllama kernel adding 1 to the zero points during inference) # Documentation of the bug can be found here: # https://garden.danieldk.eu/GPTQ-Checkpoint-Format - zeros = torch.full((groups, out_features), - c.weight_type.bias - 1, - dtype=torch.int32, - device=device) + zeros = torch.full( + (groups, out_features), + c.weight_type.bias - 1, + dtype=torch.int32, + device=device, + ) else: raise NotImplementedError( "A 0 zero-point is not supported by Exllama due to " "a bug in the original GPTQ checkpoint format leading to " "exllama kernel adding 1 to the zero points during " - "inference") - zeros = pack_quantized_values_into_int32(zeros, - c.weight_type, - packed_dim=1) - setattr(layer, self.w_zp_name, - torch.nn.Parameter(zeros, requires_grad=False)) + "inference" + ) + zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1) + setattr( + layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False) + ) if c.has_g_idx: @@ -96,10 +108,9 @@ def transform_w_g_idx(x): self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) else: self.w_gidx_name = "g_idx" - empty_g_idx = torch.nn.Parameter(torch.empty((0, ), - dtype=torch.int, - device=device), - requires_grad=False) + empty_g_idx = torch.nn.Parameter( + torch.empty((0,), dtype=torch.int, device=device), requires_grad=False + ) setattr(layer, self.w_gidx_name, empty_g_idx) def transform_w_q(x): @@ -122,21 +133,24 @@ def transform_w_s(x): self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config x_2d = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) w_q, w_s, w_zp, w_g_idx = 
self._get_weight_params(layer) assert w_zp is not None, "Zero points are required by Exllama" assert w_g_idx is not None, "Group index is required by Exllama" - output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True, - c.weight_type.size_bits) + output = ops.gptq_gemm( + x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits + ) if bias is not None: output.add_(bias) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index da951ddab2e4..df2f8fedce7e 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -8,26 +8,27 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.machete_utils import ( - check_machete_supports_shape, query_machete_supported_group_sizes, - query_machete_supported_quant_types) + check_machete_supports_shape, + query_machete_supported_group_sizes, + query_machete_supported_quant_types, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_quantized_values_into_int32, unpack_quantized_values_into_int32) -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) + pack_quantized_values_into_int32, + unpack_quantized_values_into_int32, +) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.platforms import current_platform from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig class MacheteLinearKernel(MPLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: # Machete uses CUTLASS, so it can only be compatible with Nvidia if not current_platform.is_cuda(): return False, "Machete only supported on CUDA" @@ -35,25 +36,33 @@ def can_implement(cls, if not current_platform.is_device_capability(90): return False, "Machete requires compute capability of 90 (Hopper)" - if c.has_g_idx and\ - c.partition_weight_shape[0] != c.full_weight_shape[0]: - return False, "Act reordering currently not supported by Machete, "\ - "when the input features are partitioned across "\ - "devices" - - if c.weight_type not in query_machete_supported_quant_types( - c.zero_points): - return False, f"Quant type ({c.weight_type}) not supported by "\ - "Machete, supported types are: "\ - f"{query_machete_supported_quant_types(c.zero_points)}" + if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]: + return ( + False, + "Act reordering currently not supported by Machete, " + "when the input features are partitioned across " + "devices", + ) + + if c.weight_type not in query_machete_supported_quant_types(c.zero_points): + return ( + False, + f"Quant type ({c.weight_type}) not supported by " + "Machete, supported types are: " + f"{query_machete_supported_quant_types(c.zero_points)}", + ) if c.group_size not in query_machete_supported_group_sizes(c.act_type): - return False, f"Group size ({c.group_size}) not supported by "\ - "Machete, supported group sizes are: "\ - f"{query_machete_supported_group_sizes(c.act_type)}" + return ( + False, + f"Group size ({c.group_size}) not supported by " + "Machete, supported group sizes are: " + f"{query_machete_supported_group_sizes(c.act_type)}", + ) - return 
check_machete_supports_shape(c.partition_weight_shape[0], - c.partition_weight_shape[1]) + return check_machete_supports_shape( + c.partition_weight_shape[0], c.partition_weight_shape[1] + ) # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} @@ -64,30 +73,33 @@ def process_weights_after_loading(self, layer: torch.nn.Module): if c.has_g_idx: assert self.w_gidx_name is not None - perm = torch.argsort(getattr(layer, self.w_gidx_name))\ - .to(torch.int) + perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int) self.act_perm = lambda x: x[:, perm] # use `ops.permute_cols` if possible - if c.act_type in [torch.float16, torch.bfloat16] \ - and c.partition_weight_shape[0] % 8 == 0: + if ( + c.act_type in [torch.float16, torch.bfloat16] + and c.partition_weight_shape[0] % 8 == 0 + ): self.act_perm = partial(ops.permute_cols, perm=perm) def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) if c.has_g_idx: - x_unpacked = unpack_quantized_values_into_int32(x.data, - c.weight_type, - packed_dim=0) + x_unpacked = unpack_quantized_values_into_int32( + x.data, c.weight_type, packed_dim=0 + ) x_perm = x_unpacked[perm, :] - x.data = pack_quantized_values_into_int32(x_perm, - c.weight_type, - packed_dim=0) - x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), - a_type=c.act_type, - b_type=c.weight_type, - group_scales_type=c.act_type) + x.data = pack_quantized_values_into_int32( + x_perm, c.weight_type, packed_dim=0 + ) + x.data = ops.machete_prepack_B( + x.data.t().contiguous().t(), + a_type=c.act_type, + b_type=c.weight_type, + group_scales_type=c.act_type, + ) return x def transform_w_s(x): @@ -99,9 +111,9 @@ def transform_w_s(x): def transform_w_zp(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1) - x_unpacked = unpack_quantized_values_into_int32(x.data, - c.weight_type, - packed_dim=1) + x_unpacked = unpack_quantized_values_into_int32( + x.data, c.weight_type, packed_dim=1 + ) w_s = getattr(layer, self.w_s_name).data # pre-apply scales to zero-points x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous() @@ -113,15 +125,17 @@ def transform_w_zp(x): if c.zero_points: self._transform_param(layer, self.w_zp_name, transform_w_zp) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config w_q, w_s, w_zp, _ = self._get_weight_params(layer) x_2d = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1],) if c.has_g_idx: x_2d = self.act_perm(x_2d) @@ -131,12 +145,14 @@ def apply_weights(self, else: w_zp = None - output = ops.machete_mm(a=x_2d, - b_q=w_q, - b_type=c.weight_type, - b_group_zeros=w_zp, - b_group_scales=w_s, - b_group_size=c.group_size) + output = ops.machete_mm( + a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_group_zeros=w_zp, + b_group_scales=w_s, + b_group_size=c.group_size, + ) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py index 5eb99383097b..0be448e4e3d8 100644 --- 
a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -7,46 +7,58 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, - check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, - marlin_sort_g_idx, marlin_zero_points, query_marlin_supported_quant_types, - unpack_cols) -from vllm.model_executor.parameter import (BasevLLMParameter, - permute_param_layout_) + MARLIN_SUPPORTED_GROUP_SIZES, + apply_gptq_marlin_linear, + check_marlin_supports_shape, + marlin_is_k_full, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + marlin_sort_g_idx, + marlin_zero_points, + query_marlin_supported_quant_types, + unpack_cols, +) +from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.platforms import current_platform from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig class MarlinLinearKernel(MPLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 80 @classmethod - def can_implement(cls, - c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: # Marlin uses inline PTX, so it can only be compatible with Nvidia if not current_platform.is_cuda(): return False, "Marlin only supported on CUDA" quant_types = query_marlin_supported_quant_types(c.zero_points) if c.weight_type not in quant_types: - return False, f"Quant type ({c.weight_type}) not supported by"\ - f" Marlin, supported types are: {quant_types}" + return ( + False, + f"Quant type ({c.weight_type}) not supported by" + f" Marlin, supported types are: {quant_types}", + ) if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return False, f"Group size ({c.group_size}) not supported by "\ - "Marlin, supported group sizes are: "\ - f"{MARLIN_SUPPORTED_GROUP_SIZES}" + return ( + False, + f"Group size ({c.group_size}) not supported by " + "Marlin, supported group sizes are: " + f"{MARLIN_SUPPORTED_GROUP_SIZES}", + ) return check_marlin_supports_shape( c.partition_weight_shape[1], # out_features c.partition_weight_shape[0], # in_features c.full_weight_shape[0], # in_features - c.group_size) + c.group_size, + ) # note assumes that # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} @@ -55,7 +67,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device = getattr(layer, self.w_q_name).device c = self.config - row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) # Allocate marlin workspace. 
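For orientation, the can_implement() checks and the row-parallel test reformatted in this file can be condensed into a small standalone sketch. ToyLinearConfig and SUPPORTED_GROUP_SIZES below are illustrative stand-ins (not the real MPLinearLayerConfig or MARLIN_SUPPORTED_GROUP_SIZES), so treat this as the shape of the logic rather than the vLLM implementation:

from dataclasses import dataclass
from typing import Optional

# Assumed example values; the real lists live in the vLLM marlin/bitblas utils.
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

@dataclass
class ToyLinearConfig:
    full_weight_shape: tuple[int, int]       # (in_features, out_features)
    partition_weight_shape: tuple[int, int]  # the shard seen by this rank
    group_size: int
    has_g_idx: bool

def can_implement(c: ToyLinearConfig) -> tuple[bool, Optional[str]]:
    if c.group_size not in SUPPORTED_GROUP_SIZES:
        return False, (
            f"Group size ({c.group_size}) not supported, "
            f"supported group sizes are: {SUPPORTED_GROUP_SIZES}"
        )
    if c.group_size != -1 and c.full_weight_shape[0] % c.group_size != 0:
        return False, (
            f"Group size ({c.group_size}) does not evenly divide the "
            f"number of input features ({c.full_weight_shape[0]})"
        )
    return True, None

def is_row_parallel(c: ToyLinearConfig) -> bool:
    # Same comparison as process_weights_after_loading above: the input
    # dimension is sharded when the partition is smaller than the full weight.
    return c.partition_weight_shape[0] != c.full_weight_shape[0]

cfg = ToyLinearConfig(
    full_weight_shape=(4096, 11008),
    partition_weight_shape=(2048, 11008),   # in_features split across 2 ranks
    group_size=128,
    has_g_idx=False,
)
print(can_implement(cfg))    # (True, None)
print(is_row_parallel(cfg))  # True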
@@ -71,25 +83,30 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def transform_w_q(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) - x.data = ops.gptq_marlin_repack(x.data.contiguous(), - perm=layer.g_idx_sort_indices, - size_k=c.partition_weight_shape[0], - size_n=c.partition_weight_shape[1], - num_bits=c.weight_type.size_bits) + x.data = ops.gptq_marlin_repack( + x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits, + ) return x def transform_w_s(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1) - x.data = marlin_permute_scales(x.data.contiguous(), - size_k=c.partition_weight_shape[0], - size_n=c.partition_weight_shape[1], - group_size=c.group_size) + x.data = marlin_permute_scales( + x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size, + ) return x if c.has_g_idx: g_idx, g_idx_sort_indices = marlin_sort_g_idx( - getattr(layer, self.w_gidx_name)) + getattr(layer, self.w_gidx_name) + ) self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) layer.g_idx_sort_indices = g_idx_sort_indices else: @@ -97,16 +114,24 @@ def transform_w_s(x): layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) if c.zero_points: - grouped_k = (c.partition_weight_shape[0] // - c.group_size if c.group_size != -1 else 1) - self._transform_param(layer, self.w_zp_name, lambda x: \ - marlin_zero_points( - unpack_cols(x.t(), c.weight_type.size_bits, - grouped_k, - c.partition_weight_shape[1]), + grouped_k = ( + c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1 + ) + self._transform_param( + layer, + self.w_zp_name, + lambda x: marlin_zero_points( + unpack_cols( + x.t(), + c.weight_type.size_bits, + grouped_k, + c.partition_weight_shape[1], + ), size_k=grouped_k, size_n=c.partition_weight_shape[1], - num_bits=c.weight_type.size_bits)) + num_bits=c.weight_type.size_bits, + ), + ) else: setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) self._transform_param(layer, self.w_q_name, transform_w_q) @@ -115,10 +140,12 @@ def transform_w_s(x): if hasattr(layer, "bias") and layer.bias is not None: layer.bias.data = marlin_permute_bias(layer.bias) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: c = self.config w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) @@ -136,4 +163,5 @@ def apply_weights(self, input_size_per_partition=c.partition_weight_shape[0], output_size_per_partition=c.partition_weight_shape[1], is_k_full=self.is_k_full, - bias=bias) + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py index 9ebf5f303792..d9b999e3d5dd 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -16,7 +16,6 @@ class ScaledMMLinearLayerConfig: class ScaledMMLinearKernel(ABC): - @classmethod @abstractmethod def get_min_capability(cls) -> int: @@ -24,13 +23,18 @@ def get_min_capability(cls) -> int: 
@classmethod @abstractmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: raise NotImplementedError - def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, - w_s_param_name: str, i_s_param_name: str, - i_zp_param_name: str, azp_adj_param_name: str) -> None: + def __init__( + self, + c: ScaledMMLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + i_s_param_name: str, + i_zp_param_name: str, + azp_adj_param_name: str, + ) -> None: assert self.can_implement(c) self.config = c self.w_q_name = w_q_param_name @@ -44,20 +48,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: raise NotImplementedError @abstractmethod - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: raise NotImplementedError def _get_weight_params( - self, layer: torch.nn.Module) -> tuple[ - torch.Tensor, # weight - torch.Tensor, # weight_scale - Optional[torch.Tensor], # input_scale, - Optional[torch.Tensor], # input_zp - Optional[torch.Tensor], # azp_adj - ]: + self, layer: torch.nn.Module + ) -> tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + Optional[torch.Tensor], # input_scale, + Optional[torch.Tensor], # input_zp + Optional[torch.Tensor], # azp_adj + ]: return ( getattr(layer, self.w_q_name), getattr(layer, self.w_s_name), diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 2bc68ab3ebd1..ee5416bae01c 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -5,17 +5,24 @@ from typing import Optional from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( - AiterScaledMMLinearKernel) + AiterScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( - CPUScaledMMLinearKernel) + CPUScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( - CutlassScaledMMLinearKernel) + CutlassScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 - ScaledMMLinearKernel, ScaledMMLinearLayerConfig) + ScaledMMLinearKernel, + ScaledMMLinearLayerConfig, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( - TritonScaledMMLinearKernel) + TritonScaledMMLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( - XLAScaledMMLinearKernel) + XLAScaledMMLinearKernel, +) from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) @@ -28,19 +35,18 @@ def choose_scaled_mm_linear_kernel( - config: ScaledMMLinearLayerConfig, - compute_capability: Optional[int] = None + config: ScaledMMLinearLayerConfig, compute_capability: Optional[int] = None ) -> type[ScaledMMLinearKernel]: """ - Choose an ScaledMMLinearKernel that can implement the given config for the - given compute capability. Attempts to choose the best kernel in terms of + Choose an ScaledMMLinearKernel that can implement the given config for the + given compute capability. 
Attempts to choose the best kernel in terms of performance. Args: - config (ScaledMMLinearLayerConfig): Description of the linear layer + config (ScaledMMLinearLayerConfig): Description of the linear layer to be implemented. compute_capability (Optional[int], optional): The compute capability of - the target device, if None uses `current_platform` to get the + the target device, if None uses `current_platform` to get the compute capability. Defaults to None. Raises: @@ -57,22 +63,25 @@ def choose_scaled_mm_linear_kernel( failure_reasons = [] for kernel in _POSSIBLE_KERNELS[current_platform._enum]: - if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ - .split(","): + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","): failure_reasons.append( - f' {kernel.__name__} disabled by environment variable') + f" {kernel.__name__} disabled by environment variable" + ) continue # If the current platform uses compute_capability, # make sure the kernel supports the compute cability. if compute_capability is not None: kernel_min_capability = kernel.get_min_capability() - if (kernel_min_capability is not None - and kernel_min_capability > compute_capability): + if ( + kernel_min_capability is not None + and kernel_min_capability > compute_capability + ): failure_reasons.append( f"{kernel.__name__} requires capability " f"{kernel_min_capability}, current compute capability " - f"is {compute_capability}") + f"is {compute_capability}" + ) continue can_implement, failure_reason = kernel.can_implement(config) @@ -80,10 +89,10 @@ def choose_scaled_mm_linear_kernel( return kernel else: failure_reasons.append( - f' {kernel.__name__} cannot implement due to: {failure_reason}' + f" {kernel.__name__} cannot implement due to: {failure_reason}" ) raise ValueError( - "Failed to find a kernel that can implement the "\ - "ScaledMM linear layer. Reasons: \n" - + '\n'.join(failure_reasons)) + "Failed to find a kernel that can implement the " + "ScaledMM linear layer. 
Reasons: \n" + "\n".join(failure_reasons) + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index e8e950a4bb7b..e97beefdd9c2 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -22,7 +22,6 @@ def rocm_aiter_gemm_w8a8_impl( bias: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - from aiter import gemm_a8w8_CK # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects @@ -40,7 +39,6 @@ def rocm_aiter_gemm_w8a8_fake( bias: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - m = A.shape[0] n = B.shape[0] Y = torch.empty(m, n, dtype=output_dtype, device=A.device) @@ -56,50 +54,53 @@ def rocm_aiter_gemm_w8a8_fake( class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 90 @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_rocm(): return ( False, - "AiterScaledMMLinearKernel requires `aiter` which is not " + - "currently supported on non-ROCm platform.") + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "currently supported on non-ROCm platform.", + ) try: import aiter # noqa: F401 # deliberately attempt to import aiter except Exception: return ( False, - "AiterScaledMMLinearKernel requires `aiter` which is not " + - "installed on ROCm.") + "AiterScaledMMLinearKernel requires `aiter` which is not " + + "installed on ROCm.", + ) # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled - if not ( - envs.VLLM_ROCM_USE_AITER_LINEAR \ - and envs.VLLM_ROCM_USE_AITER - ): - return (False, "AiterScaledMMLinearKernel is disabled. " + - "Enable by setting `VLLM_ROCM_USE_AITER=1` " + - "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " + - "`VLLM_ROCM_USE_AITER_LINEAR` default is True.") + if not (envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER): + return ( + False, + "AiterScaledMMLinearKernel is disabled. " + + "Enable by setting `VLLM_ROCM_USE_AITER=1` " + + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " + + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.", + ) if not c.input_symmetric: - return (False, - "AiterScaledMMLinearKernel only supports symmetric " + - "quantization.") + return ( + False, + "AiterScaledMMLinearKernel only supports symmetric " + "quantization.", + ) return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ `AiterScaledMMLinearKernel` implements a fused version of `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` @@ -116,29 +117,27 @@ def apply_weights(self, # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. 
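# Illustrative aside (not aiter/vLLM code): an unfused reference for what this
# fused path computes, i.e. output ~= torch.mm(scale_a * a, scale_b * b), using
# dynamic per-token symmetric int8 quantization as described in the comments
# above. dynamic_per_token_int8 and ref_scaled_mm are made-up helper names for
# this sketch.
import torch

def dynamic_per_token_int8(x: torch.Tensor):
    # Scale each row so its max-magnitude value maps to 127; x ~= x_q.float() * scale.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return x_q, scale

def ref_scaled_mm(x_q, x_s, w_q, w_s, bias=None, out_dtype=torch.float32):
    # x_q: [M, K] int8 with per-token scales x_s: [M, 1]
    # w_q: [K, N] int8 with per-channel scales w_s: [1, N]
    out = (x_q.float() * x_s) @ (w_q.float() * w_s)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

x, w = torch.randn(4, 64), torch.randn(64, 32)
x_q, x_s = dynamic_per_token_int8(x)
wt_q, wt_s = dynamic_per_token_int8(w.t())        # per-output-channel weight scales
out = ref_scaled_mm(x_q, x_s, wt_q.t(), wt_s.t())
print((out - x @ w).abs().max())                  # small int8 rounding error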
symmetric = azp_adj is None - assert symmetric, ("AiterScaledMMLinearKernel only supports" - " symmetric quantization.") - x_q, x_s, x_zp = ops.scaled_int8_quant(x, - i_s, - i_zp, - symmetric=symmetric) - - assert x_zp is None, ("AiterScaledMMLinearKernel only supports" - " symmetric quantization.") + assert symmetric, ( + "AiterScaledMMLinearKernel only supports symmetric quantization." + ) + x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric) + + assert x_zp is None, ( + "AiterScaledMMLinearKernel only supports symmetric quantization." + ) out_dtype = x.dtype - assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == w_q.shape[ - 1] and bias.dtype == out_dtype + assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype m = x_q.shape[0] # a n = w_q.shape[1] # b - per_tensor_scale_a = (x_s.numel() == 1) - per_tensor_scale_b = (w_s.numel() == 1) - per_token_scale_a = (x_s.numel() == m) - per_channel_scale_b = (w_s.numel() == n) + per_tensor_scale_a = x_s.numel() == 1 + per_tensor_scale_b = w_s.numel() == 1 + per_token_scale_a = x_s.numel() == m + per_channel_scale_b = w_s.numel() == n # @TODO: # Maybe broadcast the per-tensor-scale into per-channel-scale @@ -146,16 +145,19 @@ def apply_weights(self, # For now, it only supports: # - per-tensor-per-tensor a8w8 scaled GEMM, and # - per-token-per-channel a8w8 scaled GEMM - assert ((per_tensor_scale_a and per_tensor_scale_b) - or (per_token_scale_a and per_channel_scale_b)), ( - "Currently only support per-tensor-per-tensor GEMM " + - " and per-token-per-channel GEMM through AITER" - " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + - "does not support AITER block scaled GEMM.") + assert (per_tensor_scale_a and per_tensor_scale_b) or ( + per_token_scale_a and per_channel_scale_b + ), ( + "Currently only support per-tensor-per-tensor GEMM " + + " and per-token-per-channel GEMM through AITER" + " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + + "does not support AITER block scaled GEMM." 
+ ) # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # a to be [M, K] # b to be [N, K] # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s, - bias, out_dtype) + return torch.ops.vllm.rocm_aiter_gemm_w8a8( + x_q, w_q.t(), x_s, w_s, bias, out_dtype + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py index 59d2b5bce962..cb00b0c8af21 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py @@ -9,24 +9,22 @@ from vllm import envs from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise) + convert_to_channelwise, +) from vllm.model_executor.layers.utils import check_cpu_sgl_kernel from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from .ScaledMMLinearKernel import (ScaledMMLinearKernel, - ScaledMMLinearLayerConfig) +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig class CPUScaledMMLinearKernel(ScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_cpu(): return False, "CPUScaledMM requires running on CPU." @@ -36,9 +34,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight = getattr(layer, self.w_q_name) dtype = weight.dtype N, K = weight.size() - if (current_platform.get_cpu_architecture() == CpuArchEnum.X86 - and envs.VLLM_CPU_SGL_KERNEL and self.config.input_symmetric - and check_cpu_sgl_kernel(N, K, dtype)): + if ( + current_platform.get_cpu_architecture() == CpuArchEnum.X86 + and envs.VLLM_CPU_SGL_KERNEL + and self.config.input_symmetric + and check_cpu_sgl_kernel(N, K, dtype) + ): self.linear_method = self._apply_weights_sgl self.process_weights_for_sgl(layer) else: @@ -50,8 +51,10 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # Transpose to [K, N] for convenience weight = getattr(layer, self.w_q_name) replace_parameter( - layer, self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False)) + layer, + self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False), + ) # WEIGHT SCALE # oneDNN kernels support only per-tensor and per-channel. 
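The convert_to_channelwise() calls in the surrounding CPU and CUTLASS hunks expand one per-tensor scale per fused partition (detected via len(layer.logical_widths) > 1) into a per-output-channel vector. A minimal sketch of that expansion, with made-up widths and scales and a hypothetical to_channelwise() helper standing in for the vLLM utility:

import torch

def to_channelwise(per_partition_scales: torch.Tensor,
                   logical_widths: list[int]) -> torch.Tensor:
    # per_partition_scales: one scale per logical partition, shape [num_partitions]
    # result: one scale per output channel, shape [sum(logical_widths), 1]
    widths = torch.tensor(logical_widths)
    return per_partition_scales.reshape(-1).repeat_interleave(widths).unsqueeze(-1)

logical_widths = [1024, 256, 256]             # fused QKV output widths (example)
scales = torch.tensor([0.021, 0.017, 0.019])  # one per-tensor scale per partition
channelwise = to_channelwise(scales, logical_widths)
print(channelwise.shape)                      # torch.Size([1536, 1])
print(channelwise[:3, 0], channelwise[1024, 0])  # first Q scales, first K scale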
@@ -60,11 +63,12 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: is_fused_module = len(layer.logical_widths) > 1 weight_scale = getattr(layer, self.w_s_name) if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, - layer.logical_widths) + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( - layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) + layer, + self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) # INPUT SCALE if self.config.is_static_input_scheme: @@ -72,8 +76,10 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: if self.config.input_symmetric: replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False)) + layer, + self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False), + ) setattr(layer, self.i_zp_name, None) else: input_zero_point = getattr(layer, self.i_zp_name) @@ -84,16 +90,17 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: range_max = (input_scale * (int8_traits.max - azps)).max() range_min = (input_scale * (int8_traits.min - azps)).min() - scale = (range_max - range_min) / (int8_traits.max - - int8_traits.min) + scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(scale, requires_grad=False)) + layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + ) - azp = (int8_traits.min - - range_min / scale).round().to(dtype=torch.int32) - replace_parameter(layer, self.i_zp_name, - torch.nn.Parameter(azp, requires_grad=False)) + azp = ( + (int8_traits.min - range_min / scale).round().to(dtype=torch.int32) + ) + replace_parameter( + layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + ) else: setattr(layer, self.i_s_name, None) @@ -105,14 +112,16 @@ def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: # s_a * s_b * [(A - zp_a)B] + bias = # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = # s_a * GEMM_output - s_a * zp_a * adj + bias - if not (self.config.input_symmetric - and self.config.is_static_input_scheme): + if not (self.config.input_symmetric and self.config.is_static_input_scheme): weight = getattr(layer, self.w_q_name) weight_scale = getattr(layer, self.w_s_name) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) azp_adj = azp_adj * weight_scale.squeeze() - setattr(layer, self.azp_adj_name, - torch.nn.Parameter(azp_adj, requires_grad=False)) + setattr( + layer, + self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False), + ) else: setattr(layer, self.azp_adj_name, None) @@ -135,34 +144,37 @@ def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: weight = getattr(layer, self.w_q_name) packed_weight = torch.ops._C.convert_weight_packed(weight) replace_parameter( - layer, self.w_q_name, - torch.nn.Parameter(packed_weight, requires_grad=False)) + layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) + ) if layer.bias is not None: bias = layer.bias layer.register_parameter( - "bias_fp32", - torch.nn.Parameter(bias.float().data, requires_grad=False)) + "bias_fp32", torch.nn.Parameter(bias.float().data, requires_grad=False) + ) # WEIGHT SCALE # CPU SGL kernels only support per-channel. # For per-tensor quant, convert to the per-channel case. 
weight_scale = getattr(layer, self.w_s_name) if not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, - layer.logical_widths) + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( - layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) + layer, + self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) setattr(layer, self.i_s_name, None) setattr(layer, self.i_zp_name, None) setattr(layer, self.azp_adj_name, None) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return self.linear_method( layer, x, @@ -170,31 +182,33 @@ def apply_weights(self, ) def _apply_weights_onednn( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. x_q, x_s, x_zp = ops.onednn_scaled_int8_quant( - x, i_s, i_zp, self.config.input_symmetric) + x, i_s, i_zp, self.config.input_symmetric + ) m = x.size(0) n = self.dnnl_handler.n out = torch.empty((m, n), dtype=x.dtype) - ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, - bias) + ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias) return out def _apply_weights_sgl( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: w_q, w_s, _, _, _ = self._get_weight_params(layer) return torch.ops._C.int8_scaled_mm_with_quant( x, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 2f982f96b0d0..f1dafdf14c7a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -8,23 +8,20 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise) + convert_to_channelwise, +) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import (ScaledMMLinearKernel, - ScaledMMLinearLayerConfig) +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: - + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_cuda(): return False, "CutlassScaledMM requires running on CUDA." @@ -35,8 +32,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Cutlass kernels need transposed weight. 
weight = getattr(layer, self.w_q_name) replace_parameter( - layer, self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False)) + layer, + self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False), + ) # WEIGHT SCALE # Cutlass kernels support only per-tensor and per-channel. @@ -45,11 +44,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: is_fused_module = len(layer.logical_widths) > 1 weight_scale = getattr(layer, self.w_s_name) if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, - layer.logical_widths) + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) replace_parameter( - layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) + layer, + self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) # INPUT SCALE if self.config.is_static_input_scheme: @@ -57,8 +57,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.config.input_symmetric: replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False)) + layer, + self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False), + ) setattr(layer, self.i_zp_name, None) else: input_zero_point = getattr(layer, self.i_zp_name) @@ -69,17 +71,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: range_max = (input_scale * (int8_traits.max - azps)).max() range_min = (input_scale * (int8_traits.min - azps)).min() - scale = (range_max - range_min) / (int8_traits.max - - int8_traits.min) + scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(scale, requires_grad=False)) + layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) + ) # AZP loaded as int8 but used as int32 - azp = (int8_traits.min - - range_min / scale).to(dtype=torch.int32) - replace_parameter(layer, self.i_zp_name, - torch.nn.Parameter(azp, requires_grad=False)) + azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) + replace_parameter( + layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) + ) else: setattr(layer, self.i_s_name, None) @@ -97,41 +98,44 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # cutlass_w8a8 requires azp to be folded into azp_adj # in the per-tensor case azp_adj = getattr(layer, self.i_zp_name) * azp_adj - setattr(layer, self.azp_adj_name, - torch.nn.Parameter(azp_adj, requires_grad=False)) + setattr( + layer, + self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False), + ) else: setattr(layer, self.azp_adj_name, None) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) # ops.scaled_int8_quant supports both dynamic and static quant: # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. 
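# Illustrative aside (not kernel code): the azp_adj used below follows the
# identity written out in the oneDNN comments earlier in this diff,
#   (A - zp_a) @ B == A @ B - zp_a * B.sum(dim=0),
# which lets a per-tensor activation zero point be folded into a precomputed
# per-column adjustment. A quick numeric check:
import torch

A = torch.randint(-128, 128, (8, 16)).float()   # quantized activations (as float)
B = torch.randint(-128, 128, (16, 4)).float()   # quantized weights, [K, N]
zp_a = 11.0                                     # per-tensor activation zero point

lhs = (A - zp_a) @ B
azp_adj = B.sum(dim=0)                          # precomputed per-column adjustment
rhs = A @ B - zp_a * azp_adj
print(torch.allclose(lhs, rhs))                 # True (exact integer arithmetic)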
symmetric = azp_adj is None - x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(), - i_s, - i_zp, - symmetric=symmetric) + x_q, x_s, x_zp = ops.scaled_int8_quant( + x.contiguous(), i_s, i_zp, symmetric=symmetric + ) if x_zp is not None: # Currently, static is always per-tensor and dynamic is per-token static = i_zp is not None azp = None if static else x_zp - return ops.cutlass_scaled_mm_azp(x_q, - w_q, - scale_a=x_s, - scale_b=w_s, - out_dtype=x.dtype, - azp_adj=azp_adj, - azp=azp, - bias=bias) - return ops.cutlass_scaled_mm(x_q, - w_q, - scale_a=x_s, - scale_b=w_s, - out_dtype=x.dtype, - bias=bias) + return ops.cutlass_scaled_mm_azp( + x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + azp_adj=azp_adj, + azp=azp, + bias=bias, + ) + return ops.cutlass_scaled_mm( + x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias + ) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py index 817565cf2827..7e21afca5750 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -12,30 +12,32 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: return 75 @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if current_platform.is_cpu(): return ( False, - "TritonScaledMMLinearKernel requires Triton which is not " + - "currently supported on CPU.") + "TritonScaledMMLinearKernel requires Triton which is not " + + "currently supported on CPU.", + ) if not c.input_symmetric: - return (False, - "TritonScaledMMLinearKernel only supports symmetric " + - "quantization.") + return ( + False, + "TritonScaledMMLinearKernel only supports symmetric " + "quantization.", + ) return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return super().apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 0b931b2d8b81..63eee1e28861 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -9,25 +9,23 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise) + convert_to_channelwise, +) from vllm.platforms import current_platform -from .ScaledMMLinearKernel import (ScaledMMLinearKernel, - ScaledMMLinearLayerConfig) +from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig class XLAScaledMMLinearKernel(ScaledMMLinearKernel): - @classmethod def get_min_capability(cls) -> int: raise NotImplementedError( "TPU platform does have a concept of compute capability, " - "this method should not be called.") + "this method should not be called." 
+ ) @classmethod - def can_implement( - cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: - + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: if not current_platform.is_tpu(): return False, "ScaledMMXLA requires running on TPU." @@ -46,8 +44,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # WEIGHT # [out, in] (different than cutlass_scaled_mm) weight = getattr(layer, self.w_q_name) - replace_parameter(layer, self.w_q_name, - torch.nn.Parameter(weight.data, requires_grad=False)) + replace_parameter( + layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) + ) # WEIGHT SCALE # XLA kernels support only per-tensor and per-channel. @@ -56,14 +55,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: is_fused_module = len(layer.logical_widths) > 1 weight_scale = getattr(layer, self.w_s_name) if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, - layer.logical_widths) + weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) # [out_channel,] (different than cutlass_scaled_mm) weight_scale = weight_scale.squeeze(-1) replace_parameter( - layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) + layer, + self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False), + ) # Only support symmetric dynamic activation quantization. setattr(layer, self.i_s_name, None) @@ -74,8 +74,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # to specialize the graph since bias is not dynamic. warnings.filterwarnings( "ignore", - message= - "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501 + message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501 ) def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): @@ -84,14 +83,17 @@ def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): return x + bias - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: w_q, w_s, _, _, _ = self._get_weight_params(layer) # Required to register custom ops. import torch_xla.experimental.custom_kernel # noqa: F401 + out = torch.ops.xla.quantized_matmul_int8( x, w_q, diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 27e2b7846d38..78456dcf1ca5 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -5,7 +5,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -14,12 +16,12 @@ class BaseKVCacheMethod(QuantizeMethodBase): """ Quant method that adds `_k_scale` and `_v_scale` attributes to the - Attention layer to support loading those scaling factors from checkpoints. + Attention layer to support loading those scaling factors from checkpoints. 
The k/v_scale will be used to: - quantize k/v_cache entries before saving them to the cache - dequantize k/v_cache entries before fetching them from the cache - :param quant_config: the appropriate QuantizationConfig + :param quant_config: the appropriate QuantizationConfig """ def __init__(self, quant_config: QuantizationConfig): @@ -33,19 +35,14 @@ def create_weights(self, layer: torch.nn.Module): # Initialize the Q and KV cache scales to -1.0, an invalid value. # If the q and k/v_scales appear in the checkpoint, it will be # overwritten when loading weights. - layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), - requires_grad=False) - layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), - requires_grad=False) - layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), - requires_grad=False) + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) # Initialize P = softmax(QK^T) scales - layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), - requires_grad=False) + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: - raise RuntimeError( - f"{self.__class__.__name__}.apply should not be called.") + raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 @@ -77,16 +74,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: k_scale *= 2 v_scale *= 2 - if not isinstance(k_scale, float) or not isinstance( - v_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError( + "Only support per-tensor scaling factor for fp8 KV cache" + ) if layer.q_scale < 0.0: logger.warning_once( "Checkpoint does not provide a q scaling factor. " "Setting it to k_scale. This only matters for " - "FP8 Attention backends (flash-attn or flashinfer).") + "FP8 Attention backends (flash-attn or flashinfer)." + ) layer._q_scale.copy_(k_scale) layer._q_scale_float = k_scale @@ -95,12 +93,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer._v_scale.copy_(v_scale) layer._k_scale_float = k_scale layer._v_scale_float = v_scale - if (k_scale == 1.0 and v_scale == 1.0 - and "e5m2" not in layer.kv_cache_dtype): + if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: logger.warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. " "If this is unintended, verify that k/v_scale " - "scaling factors are properly set in the checkpoint.") + "scaling factors are properly set in the checkpoint." 
+            )

         if layer.q_scale > 0.0:
             q_scale = layer.q_scale
@@ -116,26 +114,31 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         else:
             prob_scale = 1.0

-        is_singleton_float = lambda x: isinstance(x, float) or isinstance(
-            x, torch.Tensor) and x.numel() == 1 and x.is_floating_point()
-        if not is_singleton_float(q_scale) or not is_singleton_float(
-                prob_scale):
-            raise ValueError("Only support per-tensor scaling factor"
-                             "for fp8-quantized Q/prob")
+        is_singleton_float = (
+            lambda x: isinstance(x, float)
+            or isinstance(x, torch.Tensor)
+            and x.numel() == 1
+            and x.is_floating_point()
+        )
+        if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale):
+            raise ValueError(
+                "Only support per-tensor scaling factor for fp8-quantized Q/prob"
+            )

         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
-        layer._q_scale_float = q_scale.item() if isinstance(
-            q_scale, torch.Tensor) else q_scale
+        layer._q_scale_float = (
+            q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale
+        )
         layer._prob_scale.copy_(prob_scale)
-        if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
-                                              or prob_scale == 1.0):
+        if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 or prob_scale == 1.0):
             logger.warning_once(
                 f"Using uncalibrated q_scale {q_scale} and/or prob_scale "
                 f"{prob_scale} with fp8 attention. This may cause accuracy "
                 "issues. Please make sure q/prob scaling factors are "
-                "available in the fp8 checkpoint.")
+                "available in the fp8 checkpoint."
+            )

         del layer.k_scale
         del layer.v_scale
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 1ca82cdcbc78..8c074ebdc8db 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -12,40 +12,70 @@
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
-    nvfp4_moe_quant_config)
+    FusedMoEConfig,
+    FusedMoEQuantConfig,
+    fp8_w8a8_moe_quant_config,
+    nvfp4_moe_quant_config,
+)
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
-    is_valid_flashinfer_cutlass_fused_moe)
+    is_valid_flashinfer_cutlass_fused_moe,
+)
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
-                                               UnquantizedLinearMethod)
+    FusedMoE,
+    FusedMoEMethodBase,
+    FusedMoeWeightScaleSupported,
+)
+from vllm.model_executor.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase)
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
-    build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
-    select_nvfp4_gemm_impl)
+    build_flashinfer_fp4_cutlass_moe_prepare_finalize,
+    reorder_w1w3_to_w3w1,
+    select_nvfp4_gemm_impl,
+)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
-    FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
+    FlashinferMoeBackend,
+    
apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, - register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, - select_cutlass_fp8_gemm_impl, swap_w13_to_w31) + flashinfer_cutlass_moe_fp8, + get_flashinfer_moe_backend, + register_moe_scaling_factors, + rotate_flashinfer_fp8_moe_weights, + select_cutlass_fp8_gemm_impl, + swap_w13_to_w31, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, is_fp4_marlin_supported, - prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) + apply_fp4_marlin_linear, + is_fp4_marlin_supported, + prepare_fp4_layer_for_marlin, + prepare_moe_fp4_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale) + GroupShape, + cutlass_fp4_supported, + is_layer_skipped, + swizzle_blockscale, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, requantize_with_max_scale) -from vllm.model_executor.parameter import (ModelWeightParameter, - PerTensorScaleParameter) + Fp8LinearOp, + requantize_with_max_scale, +) +from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from vllm.scalar_type import scalar_types from vllm.utils import next_power_of_2 -from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer, - has_flashinfer_moe) +from vllm.utils.flashinfer import ( + flashinfer_scaled_fp4_mm, + has_flashinfer, + has_flashinfer_moe, +) if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -70,8 +100,10 @@ def __init__( self.kv_cache_quant_method = kv_cache_quant_method self.exclude_modules = exclude_modules or [] if is_checkpoint_fp8_serialized: - logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" - " the format is experimental and could change.") + logger.warning( + "Detected ModelOpt fp8 checkpoint. Please note that" + " the format is experimental and could change." + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -91,12 +123,12 @@ def get_config_filenames(cls) -> list[str]: def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.exclude_modules is not None: - self.exclude_modules = hf_to_vllm_mapper.apply_list( - self.exclude_modules) + self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: """Detect if this ModelOpt config should be used based on quantization config.""" @@ -132,8 +164,7 @@ def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": # ModelOpt format: {"quantization": {"quant_algo": "..."}} quant_config = cls.get_from_keys(config, ["quantization"]) if not isinstance(quant_config, dict): - raise ValueError( - "Expected 'quantization' to be a dictionary in config") + raise ValueError("Expected 'quantization' to be a dictionary in config") quant_method = quant_config.get("quant_algo", "") if not quant_method: raise ValueError("Missing 'quant_algo' in quantization config") @@ -153,11 +184,11 @@ def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": f"ModelOpt currently only supports: {QUANT_ALGOS} " "quantizations in vLLM. 
Please check the " "`hf_quant_config.json` file for your model's " - "quant configuration.") - is_checkpoint_fp8_serialized = ("FP8" in quant_method) + "quant configuration." + ) + is_checkpoint_fp8_serialized = "FP8" in quant_method - return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, - exclude_modules) + return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules) def is_layer_excluded(self, prefix: str) -> bool: """ @@ -172,28 +203,32 @@ def is_layer_excluded(self, prefix: str) -> bool: return False # First check exact matching with fused layer support - if is_layer_skipped(prefix, self.exclude_modules, - self.packed_modules_mapping): + if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping): return True # Then check substring matching for patterns not caught by exact match for module in self.exclude_modules: # Skip exact matches already handled above - if (module != prefix and - (module in prefix or - (prefix.startswith("language_model.") - and module in prefix.removeprefix("language_model.")))): + if module != prefix and ( + module in prefix + or ( + prefix.startswith("language_model.") + and module in prefix.removeprefix("language_model.") + ) + ): return True return False - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): if self.is_layer_excluded(prefix): return UnquantizedLinearMethod() # Check if this is a vision model layer that should not be quantized - if ("vision_tower" in prefix or "vision_model" in prefix): + if "vision_tower" in prefix or "vision_model" in prefix: return UnquantizedLinearMethod() return ModelOptFp8LinearMethod(self) elif isinstance(layer, Attention): @@ -218,7 +253,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = Fp8LinearOp( - act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) + act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR + ) def create_weights( self, @@ -236,29 +272,34 @@ def create_weights( layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized else - params_dtype) - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) weight_scale[:] = 
torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE - scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", scale) @@ -268,11 +309,11 @@ def process_weights_after_loading(self, layer: Module) -> None: max_w_scale = layer.weight_scale.max() if not (layer.weight_scale == layer.weight_scale[0]).all(): max_w_scale, weight = requantize_with_max_scale( - layer.weight, layer.weight_scale, layer.logical_widths) + layer.weight, layer.weight_scale, layer.logical_widths + ) layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) def apply( self, @@ -280,11 +321,13 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=layer.input_scale, - bias=bias) + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + ) class ModelOptFp8MoEMethod(FusedMoEMethodBase): @@ -304,7 +347,9 @@ def __init__( self.layer = layer self.quant_config = quant_config from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_fp8_supported) + cutlass_fp8_supported, + ) + self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): @@ -314,13 +359,15 @@ def __init__( ) def maybe_make_prepare_finalize( - self, ) -> Optional[mk.FusedMoEPrepareAndFinalize]: + self, + ) -> Optional[mk.FusedMoEPrepareAndFinalize]: # TRT LLM not supported with all2all yet. 
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: return None elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - prepare_finalize = ( - build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe)) + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + self.moe + ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize else: @@ -348,18 +395,21 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - # Use FP8 dtype if checkpoint is serialized - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_fp8_serialized else - params_dtype) + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) weight_loader = extra_weight_attrs.get("weight_loader") w13_weight = ModelWeightParameter( - data=torch.empty(num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=weight_dtype), + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=weight_dtype, + ), input_dim=2, output_dim=1, weight_loader=weight_loader, @@ -367,10 +417,12 @@ def create_weights( layer.register_parameter("w13_weight", w13_weight) w2_weight = ModelWeightParameter( - data=torch.empty(num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=weight_dtype), + data=torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=weight_dtype, + ), input_dim=2, output_dim=1, weight_loader=weight_loader, @@ -390,7 +442,7 @@ def create_weights( weight_loader=weight_loader, ) w2_weight_scale = PerTensorScaleParameter( - data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) @@ -398,15 +450,16 @@ def create_weights( # Set weight loader attributes for scales extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) # INPUT SCALES - Per-tensor scaling for ModelOpt w13_input_scale = PerTensorScaleParameter( - data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) w2_input_scale = PerTensorScaleParameter( - data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) @@ -417,22 +470,20 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: Only supports pre-quantized checkpoints with FP8 weights and scales. """ - layer.w13_weight = Parameter(layer.w13_weight.data, - requires_grad=False) + layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) from vllm._custom_ops import scaled_fp8_quant from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - per_tensor_dequantize) + per_tensor_dequantize, + ) # Handle scale parameters - if hasattr(layer, - "w13_weight_scale") and layer.w13_weight_scale is not None: + if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None: # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max of the w1 and w3 scales # then dequant and requant each expert. 
if layer.w13_weight_scale.dim() == 2: - # Get the maximum scale across w1 and w3 for each expert max_w13_scales = layer.w13_weight_scale.max(dim=1).values @@ -445,51 +496,52 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: for shard_id in range(2): # w1 and w3 # Dequantize using the original scale for this shard dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - intermediate_size, :], + layer.w13_weight[expert_id][ + start : start + intermediate_size, : + ], layer.w13_weight_scale[expert_id][shard_id], ) # Requantize using the combined max scale ( - layer.w13_weight[expert_id][start:start + - intermediate_size, :], + layer.w13_weight[expert_id][ + start : start + intermediate_size, : + ], _, - ) = scaled_fp8_quant(dq_weight, - max_w13_scales[expert_id]) + ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) start += intermediate_size # Update the scale parameter to be per-expert - layer.w13_weight_scale = Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False) else: - layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, - requires_grad=False) + layer.w13_weight_scale = Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) - if hasattr(layer, - "w2_weight_scale") and layer.w2_weight_scale is not None: - layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data, - requires_grad=False) + if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None: + layer.w2_weight_scale = Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) # Input scales must be equal for each expert in fp8 MoE layers. - if hasattr(layer, - "w13_input_scale") and layer.w13_input_scale is not None: - layer.w13_input_scale = Parameter(layer.w13_input_scale.max(), - requires_grad=False) - if hasattr(layer, - "w2_input_scale") and layer.w2_input_scale is not None: - layer.w2_input_scale = Parameter(layer.w2_input_scale.max(), - requires_grad=False) + if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None: + layer.w13_input_scale = Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None: + layer.w2_input_scale = Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) if self.flashinfer_moe_backend is not None: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) register_moe_scaling_factors(layer) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, - layer.w2_weight) + rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: return None @@ -526,12 +578,14 @@ def apply( ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( - "EPLB not supported for `ModelOptFp8MoEMethod` yet.") + "EPLB not supported for `ModelOptFp8MoEMethod` yet." 
+ ) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: assert self.fused_experts is None - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") + assert activation == "silu", ( + f"Expected 'silu' activation but got {activation}" + ) assert not renormalize return apply_flashinfer_per_tensor_scale_fp8( layer=layer, @@ -542,7 +596,8 @@ def apply( top_k=top_k, num_expert_group=num_expert_group, topk_group=topk_group, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + ) # Expert selection topk_weights, topk_ids, _ = FusedMoE.select_experts( @@ -579,8 +634,9 @@ def apply( ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert not renormalize - assert activation == 'silu', ( - f"Expected 'silu' activation but got {activation}") + assert activation == "silu", ( + f"Expected 'silu' activation but got {activation}" + ) return flashinfer_cutlass_moe_fp8( x, layer, @@ -593,8 +649,8 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, ) else: - from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts) + from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts + assert self.moe_quant_config is not None return fused_experts( @@ -627,7 +683,8 @@ def __init__( if is_checkpoint_nvfp4_serialized: logger.warning( "Detected ModelOpt NVFP4 checkpoint. Please note that" - " the format is experimental and could change in future.") + " the format is experimental and could change in future." + ) self.group_size = group_size self.kv_cache_quant_algo = kv_cache_quant_algo @@ -651,12 +708,12 @@ def get_config_filenames(cls) -> list[str]: def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.exclude_modules is not None: - self.exclude_modules = hf_to_vllm_mapper.apply_list( - self.exclude_modules) + self.exclude_modules = hf_to_vllm_mapper.apply_list(self.exclude_modules) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: """Detect if this ModelOpt FP4 config should be used based on quantization config.""" if hf_quant_cfg is None: @@ -694,8 +751,7 @@ def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": # {"quantization": {"quant_algo": "..."}} quant_config = cls.get_from_keys(config, ["quantization"]) if not isinstance(quant_config, dict): - raise ValueError( - "Expected 'quantization' to be a dictionary in config") + raise ValueError("Expected 'quantization' to be a dictionary in config") quant_method = quant_config.get("quant_algo", "") if not quant_method: @@ -709,8 +765,10 @@ def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": elif isinstance(kv_cache_quant_algo_raw, str): kv_cache_quant_algo = kv_cache_quant_algo_raw else: - raise ValueError(f"kv_cache_quant_algo must be a string, got " - f"{type(kv_cache_quant_algo_raw)}") + raise ValueError( + f"kv_cache_quant_algo must be a string, got " + f"{type(kv_cache_quant_algo_raw)}" + ) # Handle group_size with proper type validation group_size_raw = quant_config.get("group_size") @@ -722,14 +780,16 @@ def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": try: group_size = int(group_size_raw) except (ValueError, TypeError): - raise ValueError(f"group_size must be an integer, got " - f"{type(group_size_raw)}") from None + raise ValueError( + f"group_size must be an 
integer, got {type(group_size_raw)}" + ) from None # "exclude_modules" is the key in the legacy hf_quant_config.json exclude_modules = quant_config.get("exclude_modules", []) if not isinstance(exclude_modules, list): - raise ValueError(f"exclude_modules must be a list, got " - f"{type(exclude_modules)}") + raise ValueError( + f"exclude_modules must be a list, got {type(exclude_modules)}" + ) else: # Compressed-tensors style format: # {"quant_algo": "...", "quant_method": "modelopt"} @@ -743,8 +803,10 @@ def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": elif isinstance(kv_cache_quant_algo_raw, str): kv_cache_quant_algo = kv_cache_quant_algo_raw else: - raise ValueError(f"kv_cache_quant_algo must be a string, got " - f"{type(kv_cache_quant_algo_raw)}") + raise ValueError( + f"kv_cache_quant_algo must be a string, got " + f"{type(kv_cache_quant_algo_raw)}" + ) # Handle group_size with proper type validation group_size_raw = config.get("group_size") @@ -756,40 +818,46 @@ def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": try: group_size = int(group_size_raw) except (ValueError, TypeError): - raise ValueError(f"group_size must be an integer, got " - f"{type(group_size_raw)}") from None + raise ValueError( + f"group_size must be an integer, got {type(group_size_raw)}" + ) from None # "ignore" is the key in config.json exclude_modules = config.get("ignore", []) if not isinstance(exclude_modules, list): - raise ValueError(f"exclude_modules must be a list, got " - f"{type(exclude_modules)}") + raise ValueError( + f"exclude_modules must be a list, got {type(exclude_modules)}" + ) if quant_method not in QUANT_ALGOS: raise ValueError( f"ModelOpt currently only supports: {QUANT_ALGOS} " "quantizations in vLLM. Please check the " "`hf_quant_config.json` file for your model's " - "quant configuration.") - is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method) + "quant configuration." + ) + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method # For FP4, these fields are required if is_checkpoint_nvfp4_serialized and "quantization" in config: # Check if required fields are present in the quantization config quant_config = config["quantization"] - required_fields = [ - "group_size", "kv_cache_quant_algo", "exclude_modules" - ] + required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"] missing_fields = [ field for field in required_fields if field not in quant_config ] if missing_fields: raise ValueError( f"NVFP4 quantization requires the following fields in " - f"hf_quant_config.json: {missing_fields}") - - return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, - exclude_modules, group_size) + f"hf_quant_config.json: {missing_fields}" + ) + + return cls( + is_checkpoint_nvfp4_serialized, + kv_cache_quant_algo, + exclude_modules, + group_size, + ) def is_layer_excluded(self, prefix: str) -> bool: """ @@ -797,28 +865,30 @@ def is_layer_excluded(self, prefix: str) -> bool: Handles both exact matching (for fused layers) and pattern matching. """ # First check exact matching with fused layer support - if is_layer_skipped(prefix, self.exclude_modules, - self.packed_modules_mapping): + if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping): return True # Check regex pattern matching for patterns not caught by exact match import regex as re + for pattern in self.exclude_modules: # Skip patterns that would be caught by exact matching - if '*' in pattern or '.' 
in pattern: - regex_str = pattern.replace('.', r'\.').replace('*', r'.*') + if "*" in pattern or "." in pattern: + regex_str = pattern.replace(".", r"\.").replace("*", r".*") if re.fullmatch(regex_str, prefix): return True return False - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): if self.is_layer_excluded(prefix): return UnquantizedLinearMethod() # Check if this is a vision model layer that should not be quantized - if ("vision_tower" in prefix or "vision_model" in prefix): + if "vision_tower" in prefix or "vision_model" in prefix: return UnquantizedLinearMethod() return ModelOptNvFp4LinearMethod(self) elif isinstance(layer, Attention): @@ -833,8 +903,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): Supports loading kv-cache scaling factors from FP8 checkpoints. """ - def __init__(self, quant_config: Union[ModelOptFp8Config, - ModelOptNvFp4Config]): + def __init__(self, quant_config: Union[ModelOptFp8Config, ModelOptNvFp4Config]): super().__init__(quant_config) @@ -862,9 +931,11 @@ def __init__(self, quant_config: ModelOptNvFp4Config) -> None: elif is_fp4_marlin_supported(): self.backend = "marlin" else: - raise ValueError("Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above.") + raise ValueError( + "Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above." + ) def create_weights( self, @@ -878,59 +949,69 @@ def create_weights( ): del input_size, output_size if not self.quant_config.is_checkpoint_nvfp4_serialized: - raise ValueError("NVFP4 quantization was selected, " - " dynamic quantization is not supported.") + raise ValueError( + "NVFP4 quantization was selected, " + " dynamic quantization is not supported." 
+ ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition - if (input_size_per_partition % 16 != 0): - raise ValueError("Unsupported model when in features size is " - "not multiple of 16") + if input_size_per_partition % 16 != 0: + raise ValueError( + "Unsupported model when in features size is not multiple of 16" + ) # The nvfp4 weight is still represented as - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_nvfp4_serialized - else params_dtype) + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype + ) # Weight weight = ModelWeightParameter( data=torch.empty( # 2 fp4 items are packed in the input dimension layer.output_size_per_partition, layer.input_size_per_partition // 2, - dtype=torch.uint8), + dtype=torch.uint8, + ), input_dim=1, output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # Input Weight Scale - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("input_scale", input_scale) # Global Weight Scale - weight_scale_2 = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale_2", weight_scale_2) # Per Block Weight Scale - weight_scale = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition // self.quant_config.group_size, - dtype=weight_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: Module) -> None: - # global scales: input_scale_2 = layer.input_scale.max().to(torch.float32) layer.input_scale = Parameter(input_scale_2, requires_grad=False) @@ -938,18 +1019,21 @@ def process_weights_after_loading(self, layer: Module) -> None: weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) - layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, - requires_grad=False) + layer.alpha = Parameter( + layer.input_scale * layer.weight_scale_2, requires_grad=False + ) # Calculate `1 / input_scale` so that we don't need to do so at runtime layer.input_scale_inv = Parameter( - (1 / layer.input_scale).to(torch.float32), requires_grad=False) + (1 / layer.input_scale).to(torch.float32), requires_grad=False + ) # Swizzle the weight blockscale. 
# contracting dimension is input dimension # block_size = 16; - assert (layer.weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Block scale must be represented as FP8-E4M3") + assert layer.weight_scale.dtype == torch.float8_e4m3fn, ( + "Weight Block scale must be represented as FP8-E4M3" + ) if self.backend == "marlin": prepare_fp4_layer_for_marlin(layer) @@ -966,18 +1050,18 @@ def process_weights_after_loading(self, layer: Module) -> None: weight_scale = layer.weight_scale.data epilogue_tile_m = 128 - weight = shuffle_matrix_a(weight.view(torch.uint8), - epilogue_tile_m) - weight_scale = (shuffle_matrix_sf_a(weight_scale.view( - torch.uint8), epilogue_tile_m).reshape( - weight_scale.shape).view(torch.float8_e4m3fn)) + weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) + weight_scale = ( + shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale.shape) + .view(torch.float8_e4m3fn) + ) layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False) else: swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) - layer.weight_scale = Parameter(swizzled_weight_scale, - requires_grad=False) + layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight = Parameter(layer.weight.data, requires_grad=False) def apply( @@ -995,7 +1079,8 @@ def apply( workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, - bias=bias) + bias=bias, + ) output_dtype = x.dtype output_shape = [x.shape[0], layer.weight.shape[0]] @@ -1005,11 +1090,11 @@ def apply( # validate dtypes of quantized input, input block scale, # weight and weight_blockscale - assert (x_fp4.dtype == torch.uint8) - assert (layer.weight.dtype == torch.uint8) - assert (x_blockscale.dtype == torch.float8_e4m3fn) - assert (layer.weight_scale.dtype == torch.float8_e4m3fn) - assert (layer.alpha.dtype == torch.float32) + assert x_fp4.dtype == torch.uint8 + assert layer.weight.dtype == torch.uint8 + assert x_blockscale.dtype == torch.float8_e4m3fn + assert layer.weight_scale.dtype == torch.float8_e4m3fn + assert layer.alpha.dtype == torch.float32 mm_args = ( x_fp4, @@ -1055,7 +1140,9 @@ def __init__( layer: torch.nn.Module, ) -> None: from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 - detect_nvfp4_moe_support) + detect_nvfp4_moe_support, + ) + super().__init__(moe) self.quant_config = quant_config self.layer = layer @@ -1069,19 +1156,23 @@ def __init__( self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" - " for ModelOptNvFp4FusedMoE.") + " for ModelOptNvFp4FusedMoE." + ) - def maybe_make_prepare_finalize( - self) -> Optional[mk.FusedMoEPrepareAndFinalize]: - if (self.use_marlin - or (self.allow_flashinfer and self.flashinfer_moe_backend - == FlashinferMoeBackend.TENSORRT_LLM)): + def maybe_make_prepare_finalize(self) -> Optional[mk.FusedMoEPrepareAndFinalize]: + if self.use_marlin or ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): return None - elif (self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): + elif ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ): # For now, fp4 moe only works with the flashinfer dispatcher. 
- prepare_finalize = ( - build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)) + prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( + self.moe + ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize else: @@ -1107,12 +1198,20 @@ def uses_weight_scale_2_pattern(self) -> bool: """ return True - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): if not self.quant_config.is_checkpoint_nvfp4_serialized: - raise ValueError("NVFP4 quantization was selected, " - " dynamic quantization is not supported.") + raise ValueError( + "NVFP4 quantization was selected, " + " dynamic quantization is not supported." + ) layer.num_experts = num_experts layer.params_dtype = params_dtype @@ -1127,10 +1226,12 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // 2, - dtype=weight_dtype), + dtype=weight_dtype, + ), input_dim=1, output_dim=2, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w13_weight", w13_weight) # GEMM 2 @@ -1140,10 +1241,12 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // 2, - dtype=weight_dtype), + dtype=weight_dtype, + ), input_dim=1, output_dim=2, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w2_weight", w2_weight) w13_weight_scale = ModelWeightParameter( @@ -1152,10 +1255,12 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, 2 * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.quant_config.group_size, - dtype=weight_scale_dtype), + dtype=weight_scale_dtype, + ), input_dim=1, output_dim=2, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = ModelWeightParameter( @@ -1163,38 +1268,45 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, num_experts, hidden_size, # 2 fp4 items are packed in the input dimension - intermediate_size_per_partition // - self.quant_config.group_size, - dtype=weight_scale_dtype), + intermediate_size_per_partition // self.quant_config.group_size, + dtype=weight_scale_dtype, + ), input_dim=1, output_dim=2, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) w13_weight_scale_2 = PerTensorScaleParameter( data=torch.empty(num_experts, 2, dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) w2_weight_scale_2 = PerTensorScaleParameter( data=torch.empty(num_experts, dtype=torch.float32), - weight_loader=weight_loader) + weight_loader=weight_loader, + ) layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) extra_weight_attrs.update( - {"quant_method": 
FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) - w13_input_scale = PerTensorScaleParameter(data=torch.empty( - num_experts, 2, dtype=torch.float32), - weight_loader=weight_loader) + w13_input_scale = PerTensorScaleParameter( + data=torch.empty(num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("w13_input_scale", w13_input_scale) - w2_input_scale = PerTensorScaleParameter(data=torch.empty( - num_experts, dtype=torch.float32), - weight_loader=weight_loader) + w2_input_scale = PerTensorScaleParameter( + data=torch.empty(num_experts, dtype=torch.float32), + weight_loader=weight_loader, + ) layer.register_parameter("w2_input_scale", w2_input_scale) def prepare_static_weights_for_trtllm_fp4_moe( @@ -1212,24 +1324,30 @@ def prepare_static_weights_for_trtllm_fp4_moe( from flashinfer import nvfp4_block_scale_interleave from flashinfer.fused_moe.core import ( _maybe_get_cached_w2_permute_indices, - _maybe_get_cached_w3_w1_permute_indices) + _maybe_get_cached_w3_w1_permute_indices, + ) + """Prepare quantized weights for kernel (done offline with weights).""" epilogue_tile_m = 128 # FIXME: this depends on the kernel internals # Convert quantized weights to proper formats gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape( - num_experts, 2 * intermediate_size, hidden_size // 2) # packed fp4 + num_experts, 2 * intermediate_size, hidden_size // 2 + ) # packed fp4 gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( - torch.float8_e4m3fn).reshape(num_experts, 2 * intermediate_size, - hidden_size // - 16) # fp8 scaling factors + torch.float8_e4m3fn + ).reshape( + num_experts, 2 * intermediate_size, hidden_size // 16 + ) # fp8 scaling factors gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape( - num_experts, hidden_size, intermediate_size // 2) # packed fp4 + num_experts, hidden_size, intermediate_size // 2 + ) # packed fp4 gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view( - torch.float8_e4m3fn).reshape(num_experts, hidden_size, - intermediate_size // - 16) # fp8 scaling factors + torch.float8_e4m3fn + ).reshape( + num_experts, hidden_size, intermediate_size // 16 + ) # fp8 scaling factors gemm1_weights_fp4_shuffled = [] gemm1_scales_fp4_shuffled = [] @@ -1245,9 +1363,11 @@ def prepare_static_weights_for_trtllm_fp4_moe( gemm1_weights_fp4[i].view(torch.uint8), epilogue_tile_m, ) - gemm1_weights_fp4_shuffled.append(gemm1_weights_fp4[i].view( - torch.uint8)[permute_indices.to( - gemm1_weights_fp4.device)].contiguous()) + gemm1_weights_fp4_shuffled.append( + gemm1_weights_fp4[i] + .view(torch.uint8)[permute_indices.to(gemm1_weights_fp4.device)] + .contiguous() + ) permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( self._cache_permute_indices, @@ -1256,18 +1376,25 @@ def prepare_static_weights_for_trtllm_fp4_moe( num_elts_per_sf=16, ) gemm1_scales_fp4_shuffled.append( - nvfp4_block_scale_interleave(gemm1_scales_linear_fp4[i].view( - torch.uint8)[permute_sf_indices.to( - gemm1_scales_linear_fp4.device)].contiguous())) + nvfp4_block_scale_interleave( + gemm1_scales_linear_fp4[i] + .view(torch.uint8)[ + permute_sf_indices.to(gemm1_scales_linear_fp4.device) + ] + .contiguous() + ) + ) permute_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m, ) - gemm2_weights_fp4_shuffled.append(gemm2_weights_fp4[i].view( - torch.uint8)[permute_indices.to( - 
gemm2_weights_fp4.device)].contiguous()) + gemm2_weights_fp4_shuffled.append( + gemm2_weights_fp4[i] + .view(torch.uint8)[permute_indices.to(gemm2_weights_fp4.device)] + .contiguous() + ) permute_sf_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, @@ -1276,23 +1403,29 @@ def prepare_static_weights_for_trtllm_fp4_moe( num_elts_per_sf=16, ) gemm2_scales_fp4_shuffled.append( - nvfp4_block_scale_interleave(gemm2_scales_linear_fp4[i].view( - torch.uint8)[permute_sf_indices.to( - gemm2_scales_linear_fp4.device)].contiguous())) + nvfp4_block_scale_interleave( + gemm2_scales_linear_fp4[i] + .view(torch.uint8)[ + permute_sf_indices.to(gemm2_scales_linear_fp4.device) + ] + .contiguous() + ) + ) # Stack weights for all experts gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled) gemm1_scales_fp4_shuffled = ( - torch.stack(gemm1_scales_fp4_shuffled).view( - torch.float8_e4m3fn).reshape(num_experts, - 2 * intermediate_size, - hidden_size // 16)) + torch.stack(gemm1_scales_fp4_shuffled) + .view(torch.float8_e4m3fn) + .reshape(num_experts, 2 * intermediate_size, hidden_size // 16) + ) gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled) gemm2_scales_fp4_shuffled = ( - torch.stack(gemm2_scales_fp4_shuffled).view( - torch.float8_e4m3fn).reshape(num_experts, hidden_size, - intermediate_size // 16)) + torch.stack(gemm2_scales_fp4_shuffled) + .view(torch.float8_e4m3fn) + .reshape(num_experts, hidden_size, intermediate_size // 16) + ) return ( gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, @@ -1307,74 +1440,86 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.allow_flashinfer: gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1( - gemm1_weight, gemm1_weight_scale, dim=-2) + gemm1_weight, gemm1_weight_scale, dim=-2 + ) layer.w13_weight = Parameter(gemm1_weight, requires_grad=False) - layer.w13_weight_scale = Parameter(gemm1_weight_scale, - requires_grad=False) + layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False) # Common processing for w13_weight_scale_2 - if not torch.allclose(layer.w13_weight_scale_2[:, 0], - layer.w13_weight_scale_2[:, 1]): + if not torch.allclose( + layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] + ): logger.warning_once( "w1_weight_scale_2 must match w3_weight_scale_2. " - "Accuracy may be affected.") + "Accuracy may be affected." + ) w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] - layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, - requires_grad=False) + layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) # Common processing for input scales and alphas - w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( - torch.float32) + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) layer.g1_alphas = Parameter( (w13_input_scale * w13_weight_scale_2).to(torch.float32), - requires_grad=False) + requires_grad=False, + ) # This is for quantization, so we need to invert it. layer.w13_input_scale_quant = Parameter( - (1 / w13_input_scale).to(torch.float32), requires_grad=False) + (1 / w13_input_scale).to(torch.float32), requires_grad=False + ) # GEMM 2 processing layer.g2_alphas = Parameter( (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), - requires_grad=False) + requires_grad=False, + ) # This is for quantization, so we need to invert it. 
layer.w2_input_scale_quant = Parameter( - (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) + (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False + ) # TensorRT-LLM specific processing - if self.allow_flashinfer and \ - self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): # Prepare static weights for TRT-LLM kernel # alternate: prepare_static_weight_layouts_for_trtllm_moe - (gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, - gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled - ) = self.prepare_static_weights_for_trtllm_fp4_moe( - layer.w13_weight, - layer.w2_weight, - layer.w13_weight_scale, - layer.w2_weight_scale, - layer.w2_weight.size(-2), # hidden_size - layer.w13_weight.size(-2) // 2, # intermediate_size - layer.w13_weight.size(0), # num_experts - ) + ( + gemm1_weights_fp4_shuffled, + gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, + gemm2_scales_fp4_shuffled, + ) = self.prepare_static_weights_for_trtllm_fp4_moe( + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + layer.w2_weight.size(-2), # hidden_size + layer.w13_weight.size(-2) // 2, # intermediate_size + layer.w13_weight.size(0), # num_experts + ) logger.debug_once("Finished shuffling weights for TRT-LLM MOE") layer.gemm1_weights_fp4_shuffled = Parameter( - gemm1_weights_fp4_shuffled, requires_grad=False) + gemm1_weights_fp4_shuffled, requires_grad=False + ) layer.gemm2_weights_fp4_shuffled = Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False) + gemm2_weights_fp4_shuffled, requires_grad=False + ) layer.gemm1_scales_fp4_shuffled = Parameter( - gemm1_scales_fp4_shuffled, requires_grad=False) + gemm1_scales_fp4_shuffled, requires_grad=False + ) layer.gemm2_scales_fp4_shuffled = Parameter( - gemm2_scales_fp4_shuffled, requires_grad=False) + gemm2_scales_fp4_shuffled, requires_grad=False + ) # Additional parameter needed for TRT-LLM layer.g1_scale_c = Parameter( - (layer.w2_input_scale_quant * layer.g1_alphas).to( - torch.float32), + (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) @@ -1392,29 +1537,36 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: del layer.w2_input_scale_quant else: # Non-TRT-LLM processing (Cutlass or non-flashinfer) - assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( - "Expected weight_scale.dim(1) to be divisible by 16") - assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Blockscale must be represented as FP8-E4M3") - w13_blockscale_swizzled = swizzle_blockscale( - layer.w13_weight_scale) - layer.w13_weight_scale = Parameter(w13_blockscale_swizzled, - requires_grad=False) - - assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( - "Expected weight_scale.dim(1) to be divisible by 16") - assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( - "Weight Blockscale must be represented as FP8-E4M3") + assert layer.w13_weight_scale.shape[2] % 16 == 0, ( + "Expected weight_scale.dim(1) to be divisible by 16" + ) + assert layer.w13_weight_scale.dtype == torch.float8_e4m3fn, ( + "Weight Blockscale must be represented as FP8-E4M3" + ) + w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) + layer.w13_weight_scale = Parameter( + w13_blockscale_swizzled, requires_grad=False + ) + + assert layer.w2_weight_scale.shape[2] % 16 == 0, ( + "Expected weight_scale.dim(1) to be divisible by 16" + ) + 
assert layer.w2_weight_scale.dtype == torch.float8_e4m3fn, ( + "Weight Blockscale must be represented as FP8-E4M3" + ) w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) - layer.w2_weight_scale = Parameter(w2_blockscale_swizzled, - requires_grad=False) - layer.w2_weight = Parameter(layer.w2_weight.data, - requires_grad=False) + layer.w2_weight_scale = Parameter( + w2_blockscale_swizzled, requires_grad=False + ) + layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: - if (self.use_marlin or self.flashinfer_moe_backend - == FlashinferMoeBackend.TENSORRT_LLM): + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: + if ( + self.use_marlin + or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): return None return nvfp4_moe_quant_config( @@ -1451,11 +1603,14 @@ def apply( ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: if enable_eplb: raise NotImplementedError( - "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") + "EPLB not supported for `ModelOptNvFp4FusedMoE` yet." + ) assert activation == "silu", "Only SiLU activation is supported." - if (self.allow_flashinfer and self.flashinfer_moe_backend - == FlashinferMoeBackend.TENSORRT_LLM): + if ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM + ): import flashinfer from vllm.model_executor.models.llama4 import Llama4MoE @@ -1463,14 +1618,16 @@ def apply( assert self.fused_experts is None a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, - hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) - use_llama4_routing = \ + (hidden_states_fp4, hidden_states_scale_linear_fp4) = ( + flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) + ) + use_llama4_routing = ( custom_routing_function is Llama4MoE.custom_routing_function + ) routing_method_type = flashinfer.RoutingMethodType.DeepSeekV3 if use_llama4_routing: routing_method_type = flashinfer.RoutingMethodType.Llama4 @@ -1479,36 +1636,40 @@ def apply( routing_bias = routing_bias.to(torch.bfloat16) out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( routing_logits=router_logits - if use_llama4_routing else router_logits.to(torch.float32), + if use_llama4_routing + else router_logits.to(torch.float32), routing_bias=routing_bias, hidden_states=hidden_states_fp4, hidden_states_scale=hidden_states_scale_linear_fp4.view( - torch.float8_e4m3fn).flatten(), + torch.float8_e4m3fn + ).flatten(), gemm1_weights=layer.gemm1_weights_fp4_shuffled.data, gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn), + torch.float8_e4m3fn + ), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, gemm2_weights=layer.gemm2_weights_fp4_shuffled.data, gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn), + torch.float8_e4m3fn + ), gemm2_bias=None, output1_scale_scalar=layer.g1_scale_c.data, output1_scale_gate_scalar=layer.g1_alphas.data, output2_scale_scalar=layer.g2_alphas.data, num_experts=global_num_experts, top_k=top_k, - n_group=num_expert_group - if num_expert_group is not None else 0, + n_group=num_expert_group if num_expert_group is not None else 0, topk_group=topk_group if topk_group is not None else 0, intermediate_size=layer.intermediate_size_per_partition, local_expert_offset=layer.ep_rank * layer.local_num_experts, 
local_num_experts=layer.local_num_experts, routed_scaling_factor=None, - tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k, - layer.local_num_experts), + tile_tokens_dim=_get_tile_tokens_dim( + x.shape[0], top_k, layer.local_num_experts + ), routing_method_type=routing_method_type, do_finalize=True, )[0] @@ -1526,7 +1687,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) # # Note: the order here is important. self.fused_experts can override @@ -1552,15 +1714,18 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, - workspace=layer.workspace) + workspace=layer.workspace, + ) elif self.fused_experts is not None: - assert self.allow_flashinfer and \ - self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + assert ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ) assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight), ( - "Flashinfer CUTLASS Fused MoE not applicable!") + x, layer.w13_weight, layer.w2_weight + ), "Flashinfer CUTLASS Fused MoE not applicable!" return self.fused_experts( hidden_states=x, @@ -1574,10 +1739,14 @@ def apply( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif (self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS): + elif ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ): from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - flashinfer_cutlass_moe_fp4) + flashinfer_cutlass_moe_fp4, + ) + assert self.moe_quant_config is not None return flashinfer_cutlass_moe_fp4( @@ -1596,8 +1765,8 @@ def apply( else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case # only (no EP). 
- from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) + from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 + assert self.moe_quant_config is not None return cutlass_moe_fp4( a=x, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index ee8d33e636f9..3719672f6e52 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -7,17 +7,25 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, int4_w4a16_moe_quant_config, - int8_w8a16_moe_quant_config) + FusedMoEQuantConfig, + int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - check_marlin_supports_layer) + check_marlin_supports_layer, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -25,10 +33,16 @@ class MoeWNA16Config(QuantizationConfig): """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" - def __init__(self, linear_quant_method: str, weight_bits: int, - group_size: int, has_zp: bool, lm_head_quantized: bool, - modules_to_not_convert: Optional[list[str]], - full_config: dict[str, Any]) -> None: + def __init__( + self, + linear_quant_method: str, + weight_bits: int, + group_size: int, + has_zp: bool, + lm_head_quantized: bool, + modules_to_not_convert: Optional[list[str]], + full_config: dict[str, Any], + ) -> None: super().__init__() self.weight_bits = weight_bits self.group_size = group_size @@ -40,26 +54,25 @@ def __init__(self, linear_quant_method: str, weight_bits: int, self.use_marlin = False # Avoid circular import from vllm.model_executor.layers.quantization.awq import AWQConfig - from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig) - from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) + from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig + from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig + if self.linear_quant_method == "gptq": - self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible( - full_config) + self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config) elif self.linear_quant_method == "awq": capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) awq_min_capability = AWQConfig.get_min_capability() if device_capability < awq_min_capability: raise ValueError( "The quantization method moe_wna16 + awq is not supported " "for the current GPU. 
" f"Minimum capability: {awq_min_capability}. " - f"Current capability: {device_capability}.") - self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible( - full_config) + f"Current capability: {device_capability}." + ) + self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(full_config) else: raise ValueError("moe_wna16 only support gptq and awq.") @@ -89,24 +102,32 @@ def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config": linear_quant_method = cls.get_from_keys(config, ["quant_method"]) weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) if linear_quant_method == "gptq": has_zp = not cls.get_from_keys(config, ["sym"]) modules_to_not_convert = [] elif linear_quant_method == "awq": has_zp = cls.get_from_keys(config, ["zero_point"]) modules_to_not_convert = cls.get_from_keys_or( - config, ["modules_to_not_convert"], None) + config, ["modules_to_not_convert"], None + ) else: raise ValueError("moe_wna16 only support gptq and awq.") - return cls(linear_quant_method, weight_bits, group_size, has_zp, - lm_head_quantized, modules_to_not_convert, config) + return cls( + linear_quant_method, + weight_bits, + group_size, + has_zp, + lm_head_quantized, + modules_to_not_convert, + config, + ) @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) if can_convert and user_quant == "moe_wna16": return cls.get_name() @@ -120,46 +141,59 @@ def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]): desc_act = quant_config.get("desc_act") capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) # Avoid circular import from vllm.model_executor.layers.quantization.awq import AWQConfig + awq_min_capability = AWQConfig.get_min_capability() - gptq_compatible = quant_method == "gptq" and \ - not desc_act and num_bits in [4, 8] - awq_compatible = quant_method == "awq" and num_bits == 4 and \ - device_capability >= awq_min_capability + gptq_compatible = quant_method == "gptq" and not desc_act and num_bits in [4, 8] + awq_compatible = ( + quant_method == "awq" + and num_bits == 4 + and device_capability >= awq_min_capability + ) return gptq_compatible or awq_compatible - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if is_layer_skipped_quant(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() elif isinstance(layer, LinearBase): # Avoid circular import from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig) + AWQMarlinConfig, + ) from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) + GPTQMarlinConfig, + ) + if self.linear_quant_method == "gptq": if self.use_marlin: return GPTQMarlinConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + 
self.full_config + ).get_quant_method(layer, prefix) else: - return GPTQConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + return GPTQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) elif self.linear_quant_method == "awq": if self.use_marlin and check_marlin_supports_layer( - layer, self.group_size): + layer, self.group_size + ): return AWQMarlinConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + self.full_config + ).get_quant_method(layer, prefix) else: - return AWQConfig.from_config( - self.full_config).get_quant_method(layer, prefix) + return AWQConfig.from_config(self.full_config).get_quant_method( + layer, prefix + ) else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): @@ -178,15 +212,19 @@ class MoeWNA16Method(FusedMoEMethodBase): quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. """ - def __init__(self, quant_config: MoeWNA16Config, - moe: "FusedMoEConfig") -> None: + def __init__(self, quant_config: MoeWNA16Config, moe: "FusedMoEConfig") -> None: super().__init__(moe) self.quant_config = quant_config - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): self.moe = layer layer.quant_config = self.quant_config bit8_pack_factor = self.quant_config.bit8_pack_factor @@ -196,8 +234,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, # make intermediate_size and hidden_size divisible by group_size # we reduce the group size to ensure that # and we would repeat the loaded_weight later - while intermediate_size_per_partition % group_size or \ - hidden_size % group_size: + while intermediate_size_per_partition % group_size or hidden_size % group_size: group_size = group_size // 2 group_size_div_factor *= 2 assert group_size >= 32 @@ -205,71 +242,85 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.group_size_div_factor = group_size_div_factor strategy = FusedMoeWeightScaleSupported.GROUP.value - extra_weight_attrs.update({ - "quant_method": strategy, - "is_transposed": False - }) + extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False}) - assert 'weight_loader' in extra_weight_attrs - weight_loader = extra_weight_attrs['weight_loader'] - wrapped_weight_loader = MoeWNA16Method.get_weight_loader( - layer, weight_loader) - extra_weight_attrs['weight_loader'] = wrapped_weight_loader + assert "weight_loader" in extra_weight_attrs + weight_loader = extra_weight_attrs["weight_loader"] + wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader) + extra_weight_attrs["weight_loader"] = wrapped_weight_loader # Fused gate_up_proj (column parallel) - w13_qweight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // bit8_pack_factor, - dtype=torch.uint8), - requires_grad=False) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // bit8_pack_factor, + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w13_qweight", w13_qweight) set_weight_attrs(w13_qweight, extra_weight_attrs) # down_proj (row parallel) - w2_qweight = 
torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // bit8_pack_factor, - dtype=torch.uint8), - requires_grad=False) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // bit8_pack_factor, + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w2_qweight", w2_qweight) set_weight_attrs(w2_qweight, extra_weight_attrs) - w13_scales = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // group_size, - dtype=params_dtype), - requires_grad=False) + w13_scales = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) - w2_scales = torch.nn.Parameter(torch.zeros( - num_experts, - hidden_size, - intermediate_size_per_partition // group_size, - dtype=params_dtype), - requires_grad=False) + w2_scales = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) if self.quant_config.has_zp: - w13_qzeros = torch.nn.Parameter(torch.zeros( - num_experts, - 2 * intermediate_size_per_partition // bit8_pack_factor, - hidden_size // group_size, - dtype=torch.uint8), - requires_grad=False) + w13_qzeros = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition // bit8_pack_factor, + hidden_size // group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w13_qzeros", w13_qzeros) set_weight_attrs(w13_qzeros, extra_weight_attrs) - w2_qzeros = torch.nn.Parameter(torch.zeros( - num_experts, - hidden_size // bit8_pack_factor, - intermediate_size_per_partition // group_size, - dtype=torch.uint8), - requires_grad=False) + w2_qzeros = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size // bit8_pack_factor, + intermediate_size_per_partition // group_size, + dtype=torch.uint8, + ), + requires_grad=False, + ) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) @@ -280,19 +331,23 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, if not self.quant_config.has_zp: invalid_param_keys += ["w13_qzeros", "w2_qzeros"] for key in invalid_param_keys: - param = torch.nn.Parameter(torch.empty((0, ), - dtype=torch.int32), - requires_grad=False) + param = torch.nn.Parameter( + torch.empty((0,), dtype=torch.int32), requires_grad=False + ) layer.register_parameter(key, param) set_weight_attrs(param, extra_weight_attrs) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp assert weight_bits == 4 or weight_bits == 8 - config_builder = (int4_w4a16_moe_quant_config - if weight_bits == 4 else int8_w8a16_moe_quant_config) + config_builder = ( + int4_w4a16_moe_quant_config + if weight_bits == 4 + else int8_w8a16_moe_quant_config + ) return config_builder( w1_scale=layer.w13_scales, @@ -327,10 +382,10 @@ def apply( ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: - 
raise NotImplementedError( - "EPLB not supported for `MoeWNA16Method` yet.") + raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.") from vllm.model_executor.layers.fused_moe import fused_experts + assert activation == "silu", "Only SiLU activation is supported." topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, @@ -344,7 +399,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( x, @@ -361,7 +417,6 @@ def apply( @staticmethod def get_weight_loader(layer, weight_loader): - def convert_awq_tensor(tensor, tensor_type): # convert awq qweight/qzeros to a standard format (assume int4) # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8) @@ -377,9 +432,7 @@ def convert_awq_tensor(tensor, tensor_type): # 2. unpack to uint4 (only when weight_bits == 4) # shape (a, 4 * b) -> (a, 4 * b, 2) - shifter = torch.tensor([0, 4], - dtype=torch.uint8, - device=tensor.device) + shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device) tensor = (tensor[:, :, None] >> shifter) & 0xF # 3. change order, see @@ -404,20 +457,20 @@ def convert_awq_tensor(tensor, tensor_type): def convert_gptq_int4_qzeros(tensor): tensor = tensor.view(torch.uint8) - shifter = torch.tensor([0, 4], - dtype=torch.uint8, - device=tensor.device) + shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device) tensor = (tensor[:, :, None] >> shifter) & 0xF tensor = tensor + 1 tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16 return tensor - def moe_wna16_weight_loader(param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - shard_id: str, - expert_id: int, - return_success: bool = False): + def moe_wna16_weight_loader( + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False, + ): if "g_idx" in weight_name: return False if return_success else None if not layer.quant_config.has_zp and "qzeros" in weight_name: @@ -432,8 +485,7 @@ def moe_wna16_weight_loader(param: torch.nn.Parameter, if layer.quant_config.linear_quant_method == "awq": assert layer.quant_config.weight_bits == 4 if "weight" in weight_name: - loaded_weight = convert_awq_tensor(loaded_weight, - "qweight") + loaded_weight = convert_awq_tensor(loaded_weight, "qweight") elif "zeros" in weight_name: loaded_weight = convert_awq_tensor(loaded_weight, "qzeros") else: @@ -441,44 +493,50 @@ def moe_wna16_weight_loader(param: torch.nn.Parameter, elif layer.quant_config.linear_quant_method == "gptq": assert layer.quant_config.weight_bits in [4, 8] if "weight" in weight_name: - loaded_weight = loaded_weight.T.contiguous().view( - torch.uint8) + loaded_weight = loaded_weight.T.contiguous().view(torch.uint8) elif "zeros" in weight_name: # add 1 to gptq qzeros to align with awq loaded_weight = loaded_weight.view(torch.uint8) if layer.quant_config.weight_bits == 4: - loaded_weight = convert_gptq_int4_qzeros( - loaded_weight).T + loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T else: loaded_weight = loaded_weight.T + 1 else: loaded_weight = loaded_weight.T # repeat the qzeros/scales to fit new group size - if layer.group_size_div_factor > 1 and \ - "qzeros" in weight_name or "scales" in weight_name: + if ( + layer.group_size_div_factor > 1 + and "qzeros" in weight_name + or "scales" in weight_name + ): loaded_weight 
= loaded_weight.repeat_interleave( - layer.group_size_div_factor, 1) + layer.group_size_div_factor, 1 + ) if "w13_qzeros" in weight_name: - tensor = loaded_weight.view(layer.tp_size, -1, - loaded_weight.size(1))[tp_rank] + tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[ + tp_rank + ] if shard_id == "w1": - param.data[expert_id, :shard_size // 2] = tensor + param.data[expert_id, : shard_size // 2] = tensor else: - param.data[expert_id, shard_size // 2:] = tensor + param.data[expert_id, shard_size // 2 :] = tensor return True if return_success else None elif "w2_qzeros" in weight_name: param.data[expert_id] = loaded_weight.view( - loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank] + loaded_weight.size(0), layer.tp_size, -1 + )[:, tp_rank] return True if return_success else None else: # Delegate to the original loader, passing return_success - return weight_loader(param, - loaded_weight, - weight_name, - shard_id, - expert_id, - return_success=return_success) + return weight_loader( + param, + loaded_weight, + weight_name, + shard_id, + expert_id, + return_success=return_success, + ) return moe_wna16_weight_loader diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 950bf33dbf01..b379d4bf3ae1 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -9,32 +9,45 @@ from vllm import envs from vllm.config import get_current_vllm_config from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, +) from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config, - mxfp4_w4a16_moe_quant_config) + FusedMoEQuantConfig, + mxfp4_w4a4_moe_quant_config, + mxfp4_w4a16_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( - OAITritonExperts) + OAITritonExperts, +) from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_moe_fp4_layer_for_marlin) + prepare_moe_fp4_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - _can_support_mxfp4, _swizzle_mxfp4) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + _can_support_mxfp4, + _swizzle_mxfp4, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types -from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, - next_power_of_2, round_up) +from vllm.utils import ( + has_triton_kernels, + is_torch_equal_or_newer, + next_power_of_2, + 
round_up, +) from vllm.utils.flashinfer import has_flashinfer logger = init_logger(__name__) @@ -60,42 +73,57 @@ class Mxfp4Backend(Enum): def get_mxfp4_backend(): # Backend Selection if current_platform.is_cuda(): - if (current_platform.is_device_capability(90) and has_flashinfer() - and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): + if ( + current_platform.is_device_capability(90) + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 + ): logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") return Mxfp4Backend.SM90_FI_MXFP4_BF16 - elif (current_platform.is_device_capability(100) and has_flashinfer() - and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS): - logger.info_once( - "Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") + elif ( + current_platform.is_device_capability(100) + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS + ): + logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - elif (current_platform.is_device_capability(100) and has_flashinfer() - and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8): + elif ( + current_platform.is_device_capability(100) + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + ): logger.info_once( "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, " "for high concurrency throughput workloads consider setting " "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better " - "performance") + "performance" + ) return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM elif current_platform.is_device_capability(100) and has_flashinfer(): logger.info_once( "Using FlashInfer MXFP4 BF16 backend for SM100, " "For faster performance on SM100, consider setting " "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact " - "accuracy.") + "accuracy." + ) return Mxfp4Backend.SM100_FI_MXFP4_BF16 - elif ((current_platform.is_device_capability(100) - or current_platform.is_device_capability(90)) - and not has_flashinfer()): + elif ( + current_platform.is_device_capability(100) + or current_platform.is_device_capability(90) + ) and not has_flashinfer(): logger.warning_once( "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer " "is not available. This may result in degraded performance. " - "Please `pip install vllm[flashinfer]` for best results.") + "Please `pip install vllm[flashinfer]` for best results." 
+ ) # If FlashInfer is not available, try either Marlin or Triton - if envs.VLLM_MXFP4_USE_MARLIN or current_platform.get_device_capability( - )[0] < 9 or not has_triton_kernels() or not is_torch_equal_or_newer( - "2.8.0"): + if ( + envs.VLLM_MXFP4_USE_MARLIN + or current_platform.get_device_capability()[0] < 9 + or not has_triton_kernels() + or not is_torch_equal_or_newer("2.8.0") + ): logger.info_once("Using Marlin backend") return Mxfp4Backend.MARLIN else: @@ -109,7 +137,6 @@ def get_mxfp4_backend(): class Mxfp4Config(QuantizationConfig): - def __init__(self, ignored_layers: Optional[list[str]] = None): super().__init__() self.ignored_layers = ignored_layers @@ -134,43 +161,51 @@ def get_supported_act_dtypes(cls) -> list[torch.dtype]: def get_config_filenames(cls) -> list[str]: return [] - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): if self.ignored_layers and is_layer_skipped( - prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping): + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): return UnquantizedLinearMethod() raise NotImplementedError("Mxfp4 linear layer is not implemented") elif isinstance(layer, FusedMoE): return Mxfp4MoEMethod(layer.moe_config) elif isinstance(layer, Attention): - raise NotImplementedError( - "Mxfp4 attention layer is not implemented") + raise NotImplementedError("Mxfp4 attention layer is not implemented") return None class Mxfp4MoEMethod(FusedMoEMethodBase): - def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.topk_indices_dtype = None self.moe = moe self.mxfp4_backend = get_mxfp4_backend() - self.max_capture_size = get_current_vllm_config( - ).compilation_config.max_capture_size + self.max_capture_size = ( + get_current_vllm_config().compilation_config.max_capture_size + ) assert self.mxfp4_backend != Mxfp4Backend.NONE, ( "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available." - "Please check your environment and try again.") + "Please check your environment and try again." + ) self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): self.num_experts = num_experts weight_dtype = torch.uint8 scale_dtype = torch.uint8 @@ -185,8 +220,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, mxfp4_block = 32 - intermediate_size_per_partition_after_pad = \ - intermediate_size_per_partition + intermediate_size_per_partition_after_pad = intermediate_size_per_partition if self.mxfp4_backend == Mxfp4Backend.MARLIN: # The moe marlin kernel requires that for each linear # n % 256 == 0 and k % 128 == 0. 
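Reviewer note: the Marlin branch above pads both GEMM dimensions so that n % 256 == 0 and k % 128 == 0. A minimal sketch of that padding rule, assuming `round_up` behaves like `vllm.utils.round_up` and using made-up sizes:

    def round_up(x: int, m: int) -> int:
        # Smallest multiple of m that is greater than or equal to x.
        return ((x + m - 1) // m) * m

    intermediate_size_per_partition = 1440  # hypothetical per-rank size
    hidden_size = 2880                      # hypothetical hidden size

    print(round_up(intermediate_size_per_partition, 128))  # 1536
    print(round_up(hidden_size, 256))                      # 3072

With the intermediate size rounded up to a multiple of 128 and the hidden size to a multiple of 256, both the gate/up and down projections satisfy the divisibility constraints described in the comment.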
@@ -197,34 +231,44 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, # n = hidden_size # k = intermediate_size_per_partition_after_pad intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 128) + intermediate_size_per_partition, 128 + ) hidden_size = round_up(hidden_size, 256) layer.params_dtype = params_dtype layer.num_experts = num_experts layer.hidden_size = hidden_size - layer.intermediate_size_per_partition = \ + layer.intermediate_size_per_partition = ( intermediate_size_per_partition_after_pad - elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): + ) + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): # pad the intermediate size to be a multiple of 2 * mxfp4_block # for to hold non-uniform sharded tensor as well as swizzling # other padding to increase performance intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 256) + intermediate_size_per_partition, 256 + ) hidden_size = round_up(hidden_size, 256) - elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + ): intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 128) + intermediate_size_per_partition, 128 + ) hidden_size = round_up(hidden_size, 128) elif current_platform.is_rocm(): intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 256) + intermediate_size_per_partition, 256 + ) hidden_size = round_up(hidden_size, 256) else: intermediate_size_per_partition_after_pad = round_up( - intermediate_size_per_partition, 64) + intermediate_size_per_partition, 64 + ) self.intermediate_size = intermediate_size_per_partition_after_pad self.hidden_size = hidden_size @@ -303,47 +347,61 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, def process_weights_after_loading(self, layer): if self.mxfp4_backend == Mxfp4Backend.MARLIN: prepare_moe_fp4_layer_for_marlin(layer) - elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): - from flashinfer.fp4_quantization import ( - nvfp4_block_scale_interleave) - from flashinfer.fused_moe.core import ( - _maybe_get_cached_w2_permute_indices) - layer.gemm1_alpha = Parameter(torch.tensor( - [1.702] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False) - layer.gemm1_beta = Parameter(torch.tensor( - [1.0] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False) - layer.gemm1_clamp_limit = Parameter(torch.tensor( - [7.0] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False) + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): + from flashinfer.fp4_quantization import nvfp4_block_scale_interleave + from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices + + layer.gemm1_alpha = Parameter( + torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + layer.gemm1_beta = Parameter( + torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + 
layer.gemm1_clamp_limit = Parameter( + torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) sf_block_size = 32 # mxfp4 block size - assert (layer.w13_weight.dim() == 3 - and layer.w13_weight.shape[0] == self.num_experts - and layer.w13_weight.shape[1] == self.intermediate_size * 2 - and layer.w13_weight.shape[2] == self.hidden_size // 2) - assert (layer.w13_weight_scale.dim() == 3 - and layer.w13_weight_scale.shape[0] == self.num_experts - and layer.w13_weight_scale.shape[1] - == self.intermediate_size * 2 - and layer.w13_weight_scale.shape[2] - == self.hidden_size // sf_block_size) - assert (layer.w2_weight.dim() == 3 - and layer.w2_weight.shape[0] == self.num_experts - and layer.w2_weight.shape[1] == self.hidden_size and - layer.w2_weight.shape[2] == self.intermediate_size // 2) - assert (layer.w2_weight_scale.dim() == 3 - and layer.w2_weight_scale.shape[1] == self.hidden_size - and layer.w2_weight_scale.shape[2] - == self.intermediate_size // sf_block_size) - assert (layer.w13_bias.dim() == 2 - and layer.w13_bias.shape[0] == self.num_experts - and layer.w13_bias.shape[1] == self.intermediate_size * 2) - assert (layer.w2_bias.dim() == 2 - and layer.w2_bias.shape[0] == self.num_experts - and layer.w2_bias.shape[1] == self.hidden_size) + assert ( + layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2 + ) + assert ( + layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size + ) + assert ( + layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size + and layer.w2_weight.shape[2] == self.intermediate_size // 2 + ) + assert ( + layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size + ) + assert ( + layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2 + ) + assert ( + layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size + ) w13_weight_scale = layer.w13_weight_scale.data w2_weight_scale = layer.w2_weight_scale.data @@ -391,9 +449,11 @@ def swap_every_two_rows(x, axis=-1): w13_weight[i].view(torch.uint8), epilogue_tile_m, ) - gemm1_weights_mxfp4_shuffled.append(w13_weight[i].view( - torch.uint8)[permute_indices.to( - w13_weight.device)].contiguous()) + gemm1_weights_mxfp4_shuffled.append( + w13_weight[i] + .view(torch.uint8)[permute_indices.to(w13_weight.device)] + .contiguous() + ) # w13 scale shuffling permute_sf_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, @@ -402,27 +462,37 @@ def swap_every_two_rows(x, axis=-1): num_elts_per_sf=16, ) gemm1_scales_mxfp4_shuffled.append( - nvfp4_block_scale_interleave(w13_weight_scale[i].view( - torch.uint8)[permute_sf_indices.to( - w13_weight_scale.device)].contiguous())) + nvfp4_block_scale_interleave( + w13_weight_scale[i] + .view(torch.uint8)[ + permute_sf_indices.to(w13_weight_scale.device) + ] + .contiguous() + ) + ) # w13 bias shuffling permute_bias_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, 
w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m, ) - gemm1_bias_shuffled.append(w13_bias[i].clone().reshape( - -1, - 1)[permute_bias_indices.to(w13_bias.device)].contiguous()) + gemm1_bias_shuffled.append( + w13_bias[i] + .clone() + .reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)] + .contiguous() + ) # w2 weight shuffling permute_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, w2_weight[i].view(torch.uint8), epilogue_tile_m, ) - gemm2_weights_mxfp4_shuffled.append(w2_weight[i].view( - torch.uint8)[permute_indices.to( - w2_weight.device)].contiguous()) + gemm2_weights_mxfp4_shuffled.append( + w2_weight[i] + .view(torch.uint8)[permute_indices.to(w2_weight.device)] + .contiguous() + ) # w2 scale shuffling permute_sf_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, @@ -431,81 +501,115 @@ def swap_every_two_rows(x, axis=-1): num_elts_per_sf=16, ) gemm2_scales_mxfp4_shuffled.append( - nvfp4_block_scale_interleave(w2_weight_scale[i].view( - torch.uint8)[permute_sf_indices.to( - w2_weight_scale.device)].contiguous())) + nvfp4_block_scale_interleave( + w2_weight_scale[i] + .view(torch.uint8)[ + permute_sf_indices.to(w2_weight_scale.device) + ] + .contiguous() + ) + ) # w2 bias shuffling permute_indices = _maybe_get_cached_w2_permute_indices( self._cache_permute_indices, w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m, ) - gemm2_bias_shuffled.append(w2_bias[i].clone().reshape( - -1, 1)[permute_indices.to(w2_bias.device)].contiguous()) + gemm2_bias_shuffled.append( + w2_bias[i] + .clone() + .reshape(-1, 1)[permute_indices.to(w2_bias.device)] + .contiguous() + ) w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled) - w13_weight_scale = torch.stack( - gemm1_scales_mxfp4_shuffled).reshape( - self.num_experts, 2 * self.intermediate_size, - self.hidden_size // sf_block_size).view( - torch.float8_e4m3fn) + w13_weight_scale = ( + torch.stack(gemm1_scales_mxfp4_shuffled) + .reshape( + self.num_experts, + 2 * self.intermediate_size, + self.hidden_size // sf_block_size, + ) + .view(torch.float8_e4m3fn) + ) w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled) - w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape( - self.num_experts, self.hidden_size, self.intermediate_size // - sf_block_size).view(torch.float8_e4m3fn) + w2_weight_scale = ( + torch.stack(gemm2_scales_mxfp4_shuffled) + .reshape( + self.num_experts, + self.hidden_size, + self.intermediate_size // sf_block_size, + ) + .view(torch.float8_e4m3fn) + ) layer.w13_weight = Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = Parameter(w13_weight_scale, - requires_grad=False) + layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False) layer.w2_weight = Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = Parameter(w2_weight_scale, - requires_grad=False) + layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False) layer.w13_bias = Parameter( torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1), - requires_grad=False) - layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape( - self.num_experts, -1), - requires_grad=False) - elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): - layer.gemm1_alpha = Parameter(torch.tensor( - [1.702] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False) - layer.gemm1_beta = Parameter(torch.tensor( - [1.0] * self.num_experts, dtype=torch.float32).cuda(), - 
requires_grad=False) - layer.gemm1_clamp_limit = Parameter(torch.tensor( - [7.0] * self.num_experts, dtype=torch.float32).cuda(), - requires_grad=False) + requires_grad=False, + ) + layer.w2_bias = Parameter( + torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1), + requires_grad=False, + ) + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + ): + layer.gemm1_alpha = Parameter( + torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + layer.gemm1_beta = Parameter( + torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) + layer.gemm1_clamp_limit = Parameter( + torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(), + requires_grad=False, + ) sf_block_size = 32 # mxfp4 block size # Common shape assertions - assert (layer.w13_weight.dim() == 3 - and layer.w13_weight.shape[0] == self.num_experts - and layer.w13_weight.shape[1] == self.intermediate_size * 2 - and layer.w13_weight.shape[2] == self.hidden_size // 2) - assert (layer.w13_weight_scale.dim() == 3 - and layer.w13_weight_scale.shape[0] == self.num_experts - and layer.w13_weight_scale.shape[1] - == self.intermediate_size * 2 - and layer.w13_weight_scale.shape[2] - == self.hidden_size // sf_block_size) - assert (layer.w2_weight.dim() == 3 - and layer.w2_weight.shape[0] == self.num_experts - and layer.w2_weight.shape[1] == self.hidden_size and - layer.w2_weight.shape[2] == self.intermediate_size // 2) - assert (layer.w2_weight_scale.dim() == 3 - and layer.w2_weight_scale.shape[1] == self.hidden_size - and layer.w2_weight_scale.shape[2] - == self.intermediate_size // sf_block_size) - assert (layer.w13_bias.dim() == 2 - and layer.w13_bias.shape[0] == self.num_experts - and layer.w13_bias.shape[1] == self.intermediate_size * 2) - assert (layer.w2_bias.dim() == 2 - and layer.w2_bias.shape[0] == self.num_experts - and layer.w2_bias.shape[1] == self.hidden_size) + assert ( + layer.w13_weight.dim() == 3 + and layer.w13_weight.shape[0] == self.num_experts + and layer.w13_weight.shape[1] == self.intermediate_size * 2 + and layer.w13_weight.shape[2] == self.hidden_size // 2 + ) + assert ( + layer.w13_weight_scale.dim() == 3 + and layer.w13_weight_scale.shape[0] == self.num_experts + and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2 + and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size + ) + assert ( + layer.w2_weight.dim() == 3 + and layer.w2_weight.shape[0] == self.num_experts + and layer.w2_weight.shape[1] == self.hidden_size + and layer.w2_weight.shape[2] == self.intermediate_size // 2 + ) + assert ( + layer.w2_weight_scale.dim() == 3 + and layer.w2_weight_scale.shape[1] == self.hidden_size + and layer.w2_weight_scale.shape[2] + == self.intermediate_size // sf_block_size + ) + assert ( + layer.w13_bias.dim() == 2 + and layer.w13_bias.shape[0] == self.num_experts + and layer.w13_bias.shape[1] == self.intermediate_size * 2 + ) + assert ( + layer.w2_bias.dim() == 2 + and layer.w2_bias.shape[0] == self.num_experts + and layer.w2_bias.shape[1] == self.hidden_size + ) # De-interleave and swap for w13 weight, bias, and scales w13_w = layer.w13_weight.data @@ -531,51 +635,55 @@ def swap_every_two_rows(x, axis=-1): orig_shape = w13_scale_swapped.shape w13_scale_interleaved = block_scale_interleave( - w13_scale_swapped.view(torch.uint8)).reshape(orig_shape) + w13_scale_swapped.view(torch.uint8) + ).reshape(orig_shape) w2_s = 
layer.w2_weight_scale.data orig_shape = w2_s.shape w2_scale_interleaved = block_scale_interleave( - w2_s.view(torch.uint8)).reshape(orig_shape) - - layer.w13_weight = Parameter(w13_weight_swapped, - requires_grad=False) - layer.w13_weight_scale = Parameter(w13_scale_interleaved, - requires_grad=False) - layer.w13_bias = Parameter(w13_bias_swapped, - requires_grad=False) - layer.w2_weight_scale = Parameter(w2_scale_interleaved, - requires_grad=False) + w2_s.view(torch.uint8) + ).reshape(orig_shape) + + layer.w13_weight = Parameter(w13_weight_swapped, requires_grad=False) + layer.w13_weight_scale = Parameter( + w13_scale_interleaved, requires_grad=False + ) + layer.w13_bias = Parameter(w13_bias_swapped, requires_grad=False) + layer.w2_weight_scale = Parameter( + w2_scale_interleaved, requires_grad=False + ) elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: def _interleave_mxfp4_cutlass_sm90(w): w_shape = w.shape - w_interleaved = w.reshape(w_shape[0], w_shape[1], - (w_shape[2] // 4), 4) + w_interleaved = w.reshape( + w_shape[0], w_shape[1], (w_shape[2] // 4), 4 + ) w_interleaved = w_interleaved.permute(0, 2, 1, 3) w_interleaved = w_interleaved.reshape( - w_shape[0], w_shape[2] // 4, w_shape[1] * 4) + w_shape[0], w_shape[2] // 4, w_shape[1] * 4 + ) return w_interleaved - w31_scales = w13_scale_swapped.to(torch.uint8).view( - torch.uint8) - w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90( - w31_scales) + w31_scales = w13_scale_swapped.to(torch.uint8).view(torch.uint8) + w31_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w31_scales) w2_weight_scale = layer.w2_weight_scale.data w2_scales = w2_weight_scale.to(torch.uint8).view(torch.uint8) - w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90( - w2_scales) - - layer.w13_weight = torch.nn.Parameter(torch.cat([w3_w, w1_w], - dim=1), - requires_grad=False) - layer.w13_bias = torch.nn.Parameter(w13_bias_swapped, - requires_grad=False) + w2_scales_interleaved = _interleave_mxfp4_cutlass_sm90(w2_scales) + + layer.w13_weight = torch.nn.Parameter( + torch.cat([w3_w, w1_w], dim=1), requires_grad=False + ) + layer.w13_bias = torch.nn.Parameter( + w13_bias_swapped, requires_grad=False + ) layer.w13_weight_scale = torch.nn.Parameter( - w31_scales_interleaved, requires_grad=False) + w31_scales_interleaved, requires_grad=False + ) layer.w2_weight_scale = torch.nn.Parameter( - w2_scales_interleaved, requires_grad=False) + w2_scales_interleaved, requires_grad=False + ) elif self.mxfp4_backend == Mxfp4Backend.TRITON: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -590,22 +698,25 @@ def _interleave_mxfp4_cutlass_sm90(w): # batched activation format. As self.fused_experts is not # initialized at this point, we resort to checking the MoE config # directly. 
- is_batched_moe = (self.moe.use_pplx_kernels - or self.moe.use_deepep_ll_kernels) + is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels if is_batched_moe: num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 else: num_warps = 8 w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( - layer.w13_weight, layer.w13_weight_scale, num_warps) + layer.w13_weight, layer.w13_weight_scale, num_warps + ) w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( - layer.w2_weight, layer.w2_weight_scale, num_warps) + layer.w2_weight, layer.w2_weight_scale, num_warps + ) self.w13_precision_config = PrecisionConfig( - weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)) + weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex) + ) self.w2_precision_config = PrecisionConfig( - weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)) + weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) + ) self.w13_weight_triton_tensor = w13_weight self.w2_weight_triton_tensor = w2_weight @@ -644,8 +755,8 @@ def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): return tile_tokens_dim def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: - + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: if self.mxfp4_backend == Mxfp4Backend.MARLIN: return mxfp4_w4a16_moe_quant_config( w1_bias=layer.w13_bias, @@ -677,14 +788,19 @@ def select_gemm_impl( prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: - if (prepare_finalize.activation_format == - mk.FusedMoEActivationFormat.BatchedExperts): + if ( + prepare_finalize.activation_format + == mk.FusedMoEActivationFormat.BatchedExperts + ): raise NotImplementedError( - "Mxfp4 does not support batched experts format for EP") + "Mxfp4 does not support batched experts format for EP" + ) else: assert self.moe_quant_config is not None - if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): + if ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): # B200 code-path kwargs = { "gemm1_alpha": layer.gemm1_alpha, @@ -693,36 +809,34 @@ def select_gemm_impl( # TODO(bnell): part of quant_config "max_capture_size": self.max_capture_size, } - return TrtLlmGenExperts(self.moe, self.moe_quant_config, - **kwargs) - elif (self.mxfp4_backend == Mxfp4Backend.MARLIN): + return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs) + elif self.mxfp4_backend == Mxfp4Backend.MARLIN: return MarlinExperts(self.moe_quant_config) else: return OAITritonExperts(self.moe_quant_config) def _route_and_experts( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + 
top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert isinstance(self.fused_experts, mk.FusedMoEModularKernel) topk_weights, topk_ids, _ = FusedMoE.select_experts( @@ -741,12 +855,17 @@ def _route_and_experts( expert_map=expert_map, expert_load_view=expert_load_view, logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count) + logical_replica_count=logical_replica_count, + ) - w13_weight = (self.w13_weight_triton_tensor - if layer.w13_weight is None else layer.w13_weight) - w2_weight = (self.w2_weight_triton_tensor - if layer.w2_weight is None else layer.w2_weight) + w13_weight = ( + self.w13_weight_triton_tensor + if layer.w13_weight is None + else layer.w13_weight + ) + w2_weight = ( + self.w2_weight_triton_tensor if layer.w2_weight is None else layer.w2_weight + ) assert all([w is not None for w in [w13_weight, w2_weight]]) return self.fused_experts( @@ -785,7 +904,6 @@ def apply( logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") @@ -824,7 +942,8 @@ def apply( custom_routing_function=custom_routing_function, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + ) return torch.ops.vllm.fused_marlin_moe( x, @@ -843,28 +962,39 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, activation=activation, - expert_map=expert_map) + expert_map=expert_map, + ) assert _can_support_mxfp4( - use_grouped_topk, topk_group, num_expert_group, expert_map, - custom_routing_function, e_score_correction_bias, - apply_router_weight_on_input, scoring_func, activation, - expert_load_view, logical_to_physical_map, - logical_replica_count), ( - "MXFP4 are not supported with this configuration.") - - if (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16): + use_grouped_topk, + topk_group, + num_expert_group, + expert_map, + custom_routing_function, + e_score_correction_bias, + apply_router_weight_on_input, + scoring_func, + activation, + expert_load_view, + logical_to_physical_map, + logical_replica_count, + ), "MXFP4 are not supported with this configuration." 
+ + if ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 + ): from flashinfer import trtllm_fp4_block_scale_moe + if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16: assert x.dtype == torch.bfloat16 x_quant = x x_scale = None elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM: from flashinfer import mxfp8_quantize + x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8 - x_scale = x_scale.view(torch.float8_e4m3fn).reshape( - *x.shape[:-1], -1) + x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1) trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), @@ -897,8 +1027,10 @@ def apply( tune_max_num_tokens=self.max_capture_size, )[0] return trtllm_gen_output - elif (self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS - or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16): + elif ( + self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 + ): from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe topk_weights, topk_ids, _ = FusedMoE.select_experts( @@ -916,13 +1048,11 @@ def apply( # Backend-specific preparation if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS: - from flashinfer import mxfp8_quantize x_quant, x_scale = mxfp8_quantize(x, True, 32) - fake_input_scale = torch.ones(self.num_experts, - device=x.device) + fake_input_scale = torch.ones(self.num_experts, device=x.device) quant_scales = [ layer.w13_weight_scale.contiguous().view(torch.int32), fake_input_scale, @@ -934,10 +1064,8 @@ def apply( extra_kwargs = dict( use_mxfp8_act_scaling=True, input_sf=x_scale, - fc1_expert_weights=layer.w13_weight.contiguous().view( - torch.long), - fc2_expert_weights=layer.w2_weight.contiguous().view( - torch.long), + fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long), + fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long), ) elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16: assert x.dtype == torch.bfloat16 @@ -978,7 +1106,9 @@ def apply( return output elif self.mxfp4_backend == Mxfp4Backend.TRITON: from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 - triton_kernel_moe_forward) + triton_kernel_moe_forward, + ) + return triton_kernel_moe_forward( hidden_states=x, w1=self.w13_weight_triton_tensor, diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py index 5b9fee69bb02..60519bdaea02 100644 --- a/vllm/model_executor/layers/quantization/petit.py +++ b/vllm/model_executor/layers/quantization/petit.py @@ -9,19 +9,24 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.petit_utils import ( - apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit, - verify_petit_nvfp4_supported) -from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) -from vllm.model_executor.parameter import (ModelWeightParameter, - PerTensorScaleParameter) + apply_petit_nvfp4_linear, + prepare_nvfp4_layer_for_petit, + verify_petit_nvfp4_supported, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped +from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from vllm.platforms import current_platform # Initialize logger for the module @@ -43,8 +48,10 @@ def __init__( self._check_hardware_support() self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized if is_checkpoint_nvfp4_serialized: - logger.warning("Detected nvfp4 checkpoint. Please note that the " - "format is experimental and subject to change.") + logger.warning( + "Detected nvfp4 checkpoint. Please note that the " + "format is experimental and subject to change." + ) self.group_size = group_size self.kv_cache_quant_algo = kv_cache_quant_algo self.exclude_modules = exclude_modules @@ -61,7 +68,8 @@ def _check_hardware_support(self) -> None: "The 'petit' quantization backend is designed for AMD GPUs " "and is not supported on the CUDA platform. For NVIDIA GPUs, " "please use a different quantization method such as FP8, AWQ, " - "or GPTQ.") + "or GPTQ." + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -86,8 +94,7 @@ def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config": quant_method_raw = qc.get("quant_algo") if not isinstance(quant_method_raw, str) or not quant_method_raw: - raise ValueError( - "Missing or invalid 'quant_algo' in quantization config.") + raise ValueError("Missing or invalid 'quant_algo' in quantization config.") quant_method = quant_method_raw.upper() group_size_raw = qc.get("group_size") @@ -101,19 +108,18 @@ def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config": kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto" if not isinstance(kv_cache_quant_algo_raw, str): - raise ValueError( - "'kv_cache_quant_algo' must be a string if provided.") + raise ValueError("'kv_cache_quant_algo' must be a string if provided.") kv_cache_quant_algo = kv_cache_quant_algo_raw exclude_raw = qc.get("exclude_modules", []) if exclude_raw is None: exclude_modules: list[str] = [] elif isinstance(exclude_raw, list) and all( - isinstance(x, str) for x in exclude_raw): + isinstance(x, str) for x in exclude_raw + ): exclude_modules = exclude_raw else: - raise ValueError( - "'exclude_modules' must be a list[str] (or omitted).") + raise ValueError("'exclude_modules' must be a list[str] (or omitted).") is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method @@ -126,7 +132,8 @@ def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config": @classmethod def override_quantization_method( - cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + cls, hf_quant_cfg, user_quant + ) -> Optional[QuantizationMethods]: if not current_platform.is_rocm(): return None @@ -142,23 +149,24 @@ def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool: algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() return algo == "NVFP4" - def is_layer_excluded(self, prefix: str, - exclude_modules: list[str]) -> bool: + def is_layer_excluded(self, prefix: str, exclude_modules: list[str]) -> bool: for pattern in exclude_modules: regex_str = pattern.replace(".", r"\.").replace("*", r".*") if re.fullmatch(regex_str, prefix): return True return False - def 
get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import exclude = self.require_exclude_modules() if isinstance(layer, LinearBase): if is_layer_skipped(prefix, exclude) or self.is_layer_excluded( - prefix, exclude): + prefix, exclude + ): return UnquantizedLinearMethod() return PetitNvFp4LinearMethod(self) elif isinstance(layer, Attention): @@ -220,8 +228,10 @@ def create_weights( ): del input_size, output_size if not self.quant_config.is_checkpoint_nvfp4_serialized: - raise ValueError("NVFP4 quantization was selected, " - " dynamic quantization is not supported.") + raise ValueError( + "NVFP4 quantization was selected, " + " dynamic quantization is not supported." + ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") @@ -231,12 +241,15 @@ def create_weights( layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition if input_size_per_partition % 16 != 0: - raise ValueError("Unsupported model when in features size is " - "not multiple of 16") + raise ValueError( + "Unsupported model when in features size is not multiple of 16" + ) - weight_dtype = (torch.float8_e4m3fn - if self.quant_config.is_checkpoint_nvfp4_serialized - else params_dtype) + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype + ) weight = ModelWeightParameter( data=torch.empty( @@ -283,8 +296,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) layer.input_scale = Parameter(input_scale_2, requires_grad=False) layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) - layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, - requires_grad=False) + layer.alpha = Parameter( + layer.input_scale * layer.weight_scale_2, requires_grad=False + ) prepare_nvfp4_layer_for_petit(layer) del layer.input_scale diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 45ea8e3520f1..c0156321f65d 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -8,18 +8,19 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizeMethodBase) -from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, - Fp8KVCacheMethod, - Fp8LinearMethod) +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase +from vllm.model_executor.layers.quantization.fp8 import ( + Fp8Config, + Fp8KVCacheMethod, + Fp8LinearMethod, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, is_layer_skipped) -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp) + GroupShape, + is_layer_skipped, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.platforms import current_platform 
ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -36,20 +37,20 @@ def __init__( ignored_layers: Optional[list[str]] = None, ) -> None: if not current_platform.is_rocm(): - raise ValueError( - "ptpc_fp8 quantization is supported only on ROCm.") + raise ValueError("ptpc_fp8 quantization is supported only on ROCm.") if not current_platform.has_device_capability(94): raise ValueError( "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501 ) if activation_scheme == "static": - raise ValueError( - "ptpc_fp8 as of now only support dynamic quantization.") + raise ValueError("ptpc_fp8 as of now only support dynamic quantization.") - super().__init__(is_checkpoint_fp8_serialized=False, - activation_scheme=activation_scheme, - ignored_layers=ignored_layers) + super().__init__( + is_checkpoint_fp8_serialized=False, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -59,11 +60,11 @@ def get_name(cls) -> QuantizationMethods: def from_config(cls, config: dict[str, Any]) -> "PTPCFp8Config": activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) - return cls(activation_scheme=activation_scheme, - ignored_layers=ignored_layers) + return cls(activation_scheme=activation_scheme, ignored_layers=ignored_layers) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): @@ -79,7 +80,7 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): """Linear method for Per-Token and Per-Channel FP8 Quantization. Only supports loading quantized BF16 model checkpoints with dynamic activation scaling. To load FP16 model checkpoints, user must specify - to convert the FP16 model weight loading into BF16. + to convert the FP16 model weight loading into BF16. The weight scaling factor will be initialized after the model weights are loaded. @@ -92,38 +93,45 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): """ def __init__(self, quant_config: PTPCFp8Config): - assert current_platform.is_rocm(), \ + assert current_platform.is_rocm(), ( "PTPCFp8LinearMethod is only supported on ROCm." + ) super().__init__(quant_config=quant_config) # Force weight quantization self.quant_config.is_checkpoint_fp8_serialized = False self.fp8_linear = Fp8LinearOp( - act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN) + act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.weight = torch.nn.Parameter(layer.weight.data, - requires_grad=False) + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) - assert layer.weight.data.dtype == torch.bfloat16, \ - f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 + assert layer.weight.data.dtype == torch.bfloat16, ( + f"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. {str(layer.weight.data.dtype)} is specified." # noqa: E501 + ) # Quantize the weights. 
qweight, weight_scale = ops.scaled_fp8_quant( - layer.weight, scale=None, use_per_token_if_dynamic=True) + layer.weight, scale=None, use_per_token_if_dynamic=True + ) # Update the layer with the new values. layer.weight = Parameter( - qweight.t(), requires_grad=False) # Pretranspose the weight + qweight.t(), requires_grad=False + ) # Pretranspose the weight layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.input_scale = None - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - input_scale=None, - input_scale_ub=None, - bias=bias) + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + input_scale_ub=None, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index c65212c01819..37911a549645 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -8,18 +8,30 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 - QuarkMoEMethod) + QuarkMoEMethod, +) from vllm.model_executor.layers.quantization.quark.schemes import ( - QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8) + QuarkScheme, + QuarkW4A4MXFP4, + QuarkW8A8Fp8, + QuarkW8A8Int8, +) from vllm.model_executor.layers.quantization.quark.utils import ( - deep_compare, should_ignore_layer) + deep_compare, + should_ignore_layer, +) from vllm.platforms import current_platform __all__ = ["QuarkLinearMethod"] @@ -28,12 +40,13 @@ class QuarkConfig(QuantizationConfig): - - def __init__(self, - quant_config: dict[str, Any], - kv_cache_group: Optional[list[str]] = None, - kv_cache_config: Optional[dict[str, Any]] = None, - pack_method: str = "reorder"): + def __init__( + self, + quant_config: dict[str, Any], + kv_cache_group: Optional[list[str]] = None, + kv_cache_config: Optional[dict[str, Any]] = None, + pack_method: str = "reorder", + ): super().__init__() if kv_cache_group is None: kv_cache_group = [] @@ -55,15 +68,16 @@ def get_min_capability(cls) -> int: def get_name(self) -> QuantizationMethods: return "quark" - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import # Check if the layer is skipped for quantization. 
exclude_layers = cast(list[str], self.quant_config.get("exclude")) - if should_ignore_layer(prefix, - ignore=exclude_layers, - fused_mapping=self.packed_modules_mapping): + if should_ignore_layer( + prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping + ): return UnquantizedLinearMethod() if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) @@ -73,17 +87,17 @@ def get_quant_method(self, layer: torch.nn.Module, return QuarkKVCacheMethod(self) if isinstance(layer, FusedMoE): - return QuarkMoEMethod.get_moe_method(self, - module=layer, - layer_name=prefix) + return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix) return None @classmethod def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": export_config = config.get("export") if export_config is None: - raise ValueError("The export key should be included in " - "the configurations of Quark quantized model") + raise ValueError( + "The export key should be included in " + "the configurations of Quark quantized model" + ) kv_cache_group = cast(list[str], export_config.get("kv_cache_group")) pack_method = cast(str, export_config.get("pack_method")) @@ -96,33 +110,32 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": kv_cache_config = None else: kv_cache_set = set(kv_cache_group) - layer_quant_config = cast(dict[str, Any], - config.get("layer_quant_config")) + layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config")) layer_quant_names = list(layer_quant_config.keys()) layer_quant_set = set(layer_quant_names) if not kv_cache_set.issubset(layer_quant_set): - raise ValueError("The Quark quantized model has the " - "kv_cache_group parameter setting, " - "but no kv_cache quantization settings " - "were found in the quantization " - "configuration.") + raise ValueError( + "The Quark quantized model has the " + "kv_cache_group parameter setting, " + "but no kv_cache quantization settings " + "were found in the quantization " + "configuration." + ) q_configs = [ cast(dict[str, Any], layer_quant_config.get(name)) for name in kv_cache_group ] - if not all( - deep_compare(q_config, q_configs[0]) - for q_config in q_configs): + if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs): raise ValueError( "The quantization method used for kv_cache should " "be the same, but the quantization method for the " - "kv_cache layer in the config is different.") + "kv_cache layer in the config is different." + ) kv_cache_config = q_configs[0].get("output_tensors") if kv_cache_config is None: - raise ValueError( - "The kv_cache quantization configuration is empty.") + raise ValueError("The kv_cache quantization configuration is empty.") # Since we have already set kv_cache quantization configurations, # we will remove the quantization configuration for the @@ -132,23 +145,22 @@ def from_config(cls, config: dict[str, Any]) -> "QuarkConfig": # In case q_proj output is also quantized, remove the configuration # to keep qkv consistency. 
- q_proj_q_config = cast(dict[str, Any], - layer_quant_config.get("*q_proj")) + q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj")) if q_proj_q_config is not None: q_proj_q_config["output_tensors"] = None - return cls(quant_config=config, - kv_cache_group=kv_cache_group, - kv_cache_config=kv_cache_config, - pack_method=pack_method) + return cls( + quant_config=config, + kv_cache_group=kv_cache_group, + kv_cache_config=kv_cache_config, + pack_method=pack_method, + ) @classmethod def get_config_filenames(cls) -> list[str]: return [] - def _check_scheme_supported(self, - min_capability: int, - error: bool = True) -> bool: + def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool: capability_tuple = current_platform.get_device_capability() if capability_tuple is not None: @@ -158,26 +170,33 @@ def _check_scheme_supported(self, raise RuntimeError( "Quantization scheme is not supported for ", f"the current GPU. Min capability: {min_capability}. ", - f"Current capability: {capability}.") + f"Current capability: {capability}.", + ) return supported else: return False - def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]]) -> bool: + def _is_fp8_w8a8( + self, + weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]], + ) -> bool: # Confirm weights and input quantized. if weight_quant is None or input_quant is None: return False # Confirm weight scheme is supported - is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3" - and input_quant.get("dtype") == "fp8_e4m3") + is_fp8_dtype = ( + weight_quant.get("dtype") == "fp8_e4m3" + and input_quant.get("dtype") == "fp8_e4m3" + ) is_static_weight = not weight_quant.get("is_dynamic") - is_per_tensor_or_channel_weight = (weight_quant.get("qscheme") - in ["per_tensor", "per_channel"]) + is_per_tensor_or_channel_weight = weight_quant.get("qscheme") in [ + "per_tensor", + "per_channel", + ] - if not (is_fp8_dtype and is_static_weight - and is_per_tensor_or_channel_weight): + if not (is_fp8_dtype and is_static_weight and is_per_tensor_or_channel_weight): return False # Dynamic quantization is always supported if weights supported. @@ -185,76 +204,86 @@ def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]], return True # Confirm activation scheme is supported. - is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor") + is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor" return is_per_tensor_activation - def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]]) -> bool: + def _is_static_tensor_w8a8( + self, + weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]], + ) -> bool: # Confirm weights and input quantized. 
if weight_quant is None or input_quant is None: return False - is_int8_dtype = (weight_quant.get("dtype") == "int8" - and input_quant.get("dtype") == "int8") + is_int8_dtype = ( + weight_quant.get("dtype") == "int8" and input_quant.get("dtype") == "int8" + ) - is_tensor = (weight_quant.get("qscheme") - in ["per_tensor", "per_channel"] - and input_quant.get("qscheme") == "per_tensor") + is_tensor = ( + weight_quant.get("qscheme") in ["per_tensor", "per_channel"] + and input_quant.get("qscheme") == "per_tensor" + ) - is_static = (not weight_quant.get("is_dynamic") - and not input_quant.get("is_dynamic")) + is_static = not weight_quant.get("is_dynamic") and not input_quant.get( + "is_dynamic" + ) - is_weight_symmetric = (weight_quant.get("symmetric") is True) + is_weight_symmetric = weight_quant.get("symmetric") is True # Both symmetric and asymmetric input quantization supported. # Only symmetric weight quantization supported. return is_int8_dtype and is_tensor and is_weight_symmetric and is_static - def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]], - input_quant: Optional[dict[str, Any]]) -> bool: + def _is_mx_fp4( + self, + weight_quant: Optional[dict[str, Any]], + input_quant: Optional[dict[str, Any]], + ) -> bool: # Confirm weights and input quantized. if weight_quant is None or input_quant is None: - logger.debug("Quark model is not in MX-FP4 format: " - "weight_quant or input_quant not set") + logger.debug( + "Quark model is not in MX-FP4 format: " + "weight_quant or input_quant not set" + ) return False # Input and weight dtype needs to be fp4. - if weight_quant.get("dtype") != "fp4" or input_quant.get( - "dtype") != "fp4": + if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4": logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") return False # Input and weight qscheme needs to be per group. - if weight_quant.get("qscheme") != "per_group" or input_quant.get( - "qscheme") != "per_group": + if ( + weight_quant.get("qscheme") != "per_group" + or input_quant.get("qscheme") != "per_group" + ): logger.debug("Quark model is not in MX-FP4 format: not per_group") return False # Input and weight group size needs to be 32. - if weight_quant.get("group_size") != 32 or input_quant.get( - "group_size") != 32: - logger.debug( - "Quark model is not in MX-FP4 format: not group_size=32") + if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32: + logger.debug("Quark model is not in MX-FP4 format: not group_size=32") return False # Activations need to use dynamic quantization. if input_quant.get("is_dynamic") is False: - logger.debug( - "Quark model is not in MX-FP4 format: not activation dynamic") + logger.debug("Quark model is not in MX-FP4 format: not activation dynamic") return False # Activations and weight scales need to be in e8m0 format. 
- if weight_quant.get("scale_format") != "e8m0" or input_quant.get( - "scale_format") != "e8m0": - logger.debug( - "Quark model is not in MX-FP4 format: not scale_format e8m0") + if ( + weight_quant.get("scale_format") != "e8m0" + or input_quant.get("scale_format") != "e8m0" + ): + logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0") return False return True - def _find_matched_config(self, layer_name: str, - module: torch.nn.Module) -> dict[str, Any]: - + def _find_matched_config( + self, layer_name: str, module: torch.nn.Module + ) -> dict[str, Any]: proj_name = layer_name.split(".")[-1] if proj_name in self.packed_modules_mapping: shard_proj_names = self.packed_modules_mapping[proj_name] @@ -269,59 +298,66 @@ def _find_matched_config(self, layer_name: str, for shard_name in shard_names ] if not all( - deep_compare(q_config, shard_configs[0]) - for q_config in shard_configs): + deep_compare(q_config, shard_configs[0]) for q_config in shard_configs + ): raise ValueError( f"Found a different quantization configuration for " f"{shard_proj_names} in {layer_name}. vLLM " - "requires all to use the same scheme.") + "requires all to use the same scheme." + ) return shard_configs[0] else: layer_quant_config = cast( - dict[str, Any], self.quant_config.get("layer_quant_config")) + dict[str, Any], self.quant_config.get("layer_quant_config") + ) for name_pattern in layer_quant_config: if fnmatch.fnmatch(layer_name, name_pattern): return layer_quant_config[name_pattern] layer_type = cast(str, type(module)) layer_type_quant_config = cast( - dict[str, Any], - self.quant_config.get("layer_type_quant_config")) + dict[str, Any], self.quant_config.get("layer_type_quant_config") + ) if layer_type in layer_type_quant_config: return layer_type_quant_config[layer_type] global_quant_config = cast( - dict[str, Any], self.quant_config.get("global_quant_config")) + dict[str, Any], self.quant_config.get("global_quant_config") + ) return global_quant_config def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": if config.get("output_tensors") or config.get("bias"): raise NotImplementedError( "Currently, Quark models with output_tensors " - "and bias quantized are not supported") + "and bias quantized are not supported" + ) weight_config = cast(dict[str, Any], config.get("weight")) input_config = cast(dict[str, Any], config.get("input_tensors")) if self._is_fp8_w8a8(weight_config, input_config): is_fp8_w8a8_supported = self._check_scheme_supported( - QuarkW8A8Fp8.get_min_capability(), error=False) + QuarkW8A8Fp8.get_min_capability(), error=False + ) if is_fp8_w8a8_supported: return QuarkW8A8Fp8(weight_config, input_config) elif self._is_static_tensor_w8a8(weight_config, input_config): weight_qscheme = cast(str, weight_config.get("qscheme")) - return QuarkW8A8Int8(qscheme=weight_qscheme, - is_static_input_scheme=True, - input_symmetric=input_config.get("symmetric")) + return QuarkW8A8Int8( + qscheme=weight_qscheme, + is_static_input_scheme=True, + input_symmetric=input_config.get("symmetric"), + ) elif self._is_mx_fp4(weight_config, input_config): return QuarkW4A4MXFP4(weight_config, input_config) - raise NotImplementedError("No quark compatible scheme was found. " - f"Weight config: {weight_config}, " - f"Input config: {input_config}") - - def get_scheme(self, layer: torch.nn.Module, - layer_name: str) -> "QuarkScheme": + raise NotImplementedError( + "No quark compatible scheme was found. 
" + f"Weight config: {weight_config}, " + f"Input config: {input_config}" + ) + def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": layer_quant_config = self._find_matched_config(layer_name, layer) # Find the quant_scheme @@ -335,7 +371,7 @@ def get_scheme(self, layer: torch.nn.Module, def get_cache_scale(self, name: str) -> Optional[str]: """ Check whether the param name matches the format for k/v cache scales - in quark. If this is the case, return its equivalent param name + in quark. If this is the case, return its equivalent param name expected by vLLM :param name: param name @@ -355,18 +391,22 @@ def get_cache_scale(self, name: str) -> Optional[str]: class QuarkLinearMethod(LinearMethodBase): - def __init__(self, quantization_config: QuarkConfig): self.quantization_config = quantization_config def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.scheme.process_weights_after_loading(layer) - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): """ Use the CompressedTensorsScheme associated with each layer to create the necessary parameters for the layer. See LinearMethodBase for param @@ -380,12 +420,15 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes=output_partition_sizes, output_size=output_size, params_dtype=params_dtype, - weight_loader=weight_loader) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None): + weight_loader=weight_loader, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): """ Use the output of create_weights and the CompressedTensorsScheme associated with the layer to apply the forward pass with the @@ -422,11 +465,13 @@ def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]): if dtype != "fp8_e4m3": raise NotImplementedError( "Currently supported kv cache quantization is " - f"dtype=fp8_e4m3, however received {dtype}") + f"dtype=fp8_e4m3, however received {dtype}" + ) qscheme = kv_cache_config.get("qscheme") if qscheme != "per_tensor": raise NotImplementedError( "Only support per-tensor scaling factor " "for quark KV cache. 
" - f"Expected qscheme: per_tensor, found qscheme: {qscheme}") + f"Expected qscheme: per_tensor, found qscheme: {qscheme}" + ) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 24497cc756c1..810057757a83 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -8,66 +8,71 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, - FusedMoEMethodBase, - FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, + FusedMoeWeightScaleSupported, +) from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, - mxfp4_w4a4_moe_quant_config) + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, + mxfp4_w4a4_moe_quant_config, +) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - prepare_moe_fp8_layer_for_marlin) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - OCP_MX_BLOCK_SIZE) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + prepare_moe_fp8_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) + all_close_1d, + normalize_e4m3fn_to_e4m3fnuz, + per_tensor_dequantize, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types logger = init_logger(__name__) -__all__ = [ - "QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod" -] +__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod"] class QuarkMoEMethod(FusedMoEMethodBase): - def __init__(self, moe: FusedMoEConfig): super().__init__(moe) @staticmethod def get_moe_method( - quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 - module: torch.nn.Module, - layer_name: str) -> "QuarkMoEMethod": - layer_quant_config = quant_config._find_matched_config( - layer_name, module) - - if (layer_quant_config.get("output_tensors") - or layer_quant_config.get("bias")): - raise NotImplementedError("Currently, Quark models with " - "output_tensors and bias " - "quantized are not supported") + quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 + module: torch.nn.Module, + layer_name: str, + ) -> "QuarkMoEMethod": + layer_quant_config = quant_config._find_matched_config(layer_name, module) + + if layer_quant_config.get("output_tensors") or layer_quant_config.get("bias"): + raise NotImplementedError( + "Currently, Quark models with " + "output_tensors and bias " + "quantized are not supported" + ) weight_config = layer_quant_config.get("weight") input_config = layer_quant_config.get("input_tensors") if quant_config._is_fp8_w8a8(weight_config, input_config): - return QuarkW8A8Fp8MoEMethod(weight_config, input_config, - module.moe_config) + return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) 
elif quant_config._is_mx_fp4(weight_config, input_config): - return QuarkW4A4MXFp4MoEMethod(weight_config, input_config, - module.moe_config) + return QuarkW4A4MXFp4MoEMethod( + weight_config, input_config, module.moe_config + ) else: raise RuntimeError("Unsupported FusedMoe scheme") class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): - def __init__( self, weight_config: dict[str, Any], @@ -80,38 +85,50 @@ def __init__( self.weight_qscheme = self.weight_quant.get("qscheme") self.input_qscheme = self.input_quant.get("qscheme") - per_tensor = (self.weight_qscheme == "per_tensor" - and self.input_qscheme == "per_tensor") - per_channel = (self.weight_qscheme == "per_channel" - and self.input_qscheme == "per_channel") - self.act_quant_group_shape = GroupShape.PER_TOKEN \ - if per_channel else GroupShape.PER_TENSOR + per_tensor = ( + self.weight_qscheme == "per_tensor" and self.input_qscheme == "per_tensor" + ) + per_channel = ( + self.weight_qscheme == "per_channel" and self.input_qscheme == "per_channel" + ) + self.act_quant_group_shape = ( + GroupShape.PER_TOKEN if per_channel else GroupShape.PER_TENSOR + ) if not (per_tensor or per_channel): raise ValueError( "For FP8 Fused MoE layers, only per-tensor and per-channel " "scales for weights and activations are supported. Found " - f"{self.weight_qscheme}, {self.input_qscheme}") # noqa E501 + f"{self.weight_qscheme}, {self.input_qscheme}" + ) # noqa E501 self.static_input_scales = not self.input_quant.get("is_dynamic") if self.static_input_scales and per_channel: raise ValueError( "For FP8 Fused MoE layer, we require either per tensor or " - "channelwise, dynamic per token quantization.") + "channelwise, dynamic per token quantization." + ) # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - self.use_marlin = (not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN) + self.use_marlin = ( + not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN + ) # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size layer.num_experts = num_experts @@ -120,21 +137,27 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype = torch.float8_e4m3fn # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, 
+ ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -142,48 +165,54 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, if self.weight_qscheme == "per_tensor": # Allocate 2 scales for w1 and w3 respectively. # They are combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, 2, dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-TENSOR quantization for FusedMoE.weight_loader. extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) elif self.weight_qscheme == "per_channel": # quark's scale is 1 dim. - w13_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - dtype=torch.float32), - requires_grad=False) + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter(torch.ones( - num_experts, hidden_size, dtype=torch.float32), - requires_grad=False) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, dtype=torch.float32), + requires_grad=False, + ) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-CHANNEL quantization for FusedMoE.weight_loader. extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.static_input_scales: - w13_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, extra_weight_attrs) - w2_input_scale = torch.nn.Parameter(torch.ones( - num_experts, dtype=torch.float32), - requires_grad=False) + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) else: @@ -194,46 +223,53 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. 
if self.static_input_scales: - if (layer.w13_input_scale is None or layer.w2_input_scale is None): + if layer.w13_input_scale is None or layer.w2_input_scale is None: raise ValueError( "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.w13_input_scale) - or not all_close_1d(layer.w2_input_scale)): + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): logger.warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer. ") + "for each layer. " + ) layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False) + layer.w13_input_scale.max(), requires_grad=False + ) layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False) + layer.w2_input_scale.max(), requires_grad=False + ) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, - layer.w13_input_scale) - w2_weight, w2_weight_scale, w2_input_scale = \ + w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, - layer.w2_input_scale) + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, - requires_grad=False) + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, - requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) # For per-tensor case, Fp8 moe kernel needs single weight scale # for w13 per expert. Use max then dequant and requant each expert. 
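# --- Illustrative sketch (not part of this diff) ------------------------------
# The per-tensor branch referenced above keeps a single w13 scale per expert:
# take the max over the two shard scales, dequantize each shard with its own
# scale, then requantize with that max. A minimal stand-alone version of the
# idea, assuming symmetric fp8 scaling; `requant_w13_with_max_scale` is a
# hypothetical helper, whereas the actual code uses per_tensor_dequantize and
# ops.scaled_fp8_quant, which also handle clamping to the fp8 range.
import torch

def requant_w13_with_max_scale(
    w13: torch.Tensor,           # (2 * shard_size, hidden) fp8 weights of one expert
    shard_scales: torch.Tensor,  # (2,) float32 scales for the w1 / w3 shards
    shard_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    max_scale = shard_scales.max()
    out = torch.empty_like(w13)
    for shard_id in range(2):
        rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        dq = w13[rows].to(torch.float32) * shard_scales[shard_id]  # dequantize shard
        out[rows] = (dq / max_scale).to(w13.dtype)                 # requantize with max
    return out, max_scale
# ------------------------------------------------------------------------------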
@@ -245,42 +281,45 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_weight_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) start += shard_size - layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) # quark's scale is 1 dim. elif self.weight_qscheme == "per_channel": if self.act_quant_group_shape == GroupShape.PER_TOKEN: w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1) layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False) + w13_weight_scale, requires_grad=False + ) w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1) - layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, - requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - rocm_aiter_fused_experts, shuffle_weights) + rocm_aiter_fused_experts, + shuffle_weights, + ) # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data) + layer.w13_weight.data, layer.w2_weight.data + ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, - requires_grad=False) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts elif self.use_marlin: - prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. del layer.w13_input_scale @@ -288,10 +327,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: self.fused_experts_func = None else: from vllm.model_executor.layers.fused_moe import fused_experts + self.fused_experts_func = fused_experts def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, @@ -327,7 +368,8 @@ def apply( if enable_eplb: raise NotImplementedError( - "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") + "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet." 
+ ) topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, @@ -341,7 +383,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) if self.rocm_aiter_moe_enabled: return self.rocm_aiter_fused_experts_func( @@ -353,10 +396,10 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, quant_config=self.moe_quant_config, - expert_map=expert_map) + expert_map=expert_map, + ) if self.use_marlin: - assert activation == "silu", ( - f"{activation} not supported for Marlin MoE.") + assert activation == "silu", f"{activation} not supported for Marlin MoE." return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -371,7 +414,8 @@ def apply( quant_type_id=scalar_types.float8_e4m3fn.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, - expert_map=expert_map) + expert_map=expert_map, + ) assert self.fused_experts_func is not None @@ -386,11 +430,11 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, - quant_config=self.moe_quant_config) + quant_config=self.moe_quant_config, + ) class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): - def __init__( self, weight_config: dict[str, Any], @@ -403,19 +447,20 @@ def __init__( weight_qscheme = self.weight_quant.get("qscheme") input_qscheme = self.input_quant.get("qscheme") - if not (weight_qscheme == "per_group" - and input_qscheme == "per_group"): + if not (weight_qscheme == "per_group" and input_qscheme == "per_group"): raise ValueError( "For MX(FP4) Fused MoE layers, only per-group scales " "for weights and activations are supported. Found " - f"{weight_qscheme}, {input_qscheme}") # noqa E501 + f"{weight_qscheme}, {input_qscheme}" + ) # noqa E501 self.static_input_scales = not self.input_quant.get("is_dynamic") if self.static_input_scales: raise NotImplementedError( "QuarkW4A4MXFp4MoEMethod with static input scales is currently " - "not implemented. Please open an issue.") + "not implemented. Please open an issue." + ) if not current_platform.supports_mx(): self.emulate = True @@ -423,7 +468,8 @@ def __init__( "The current platform does not support native MXFP4 " "computation. Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision.") + "layers computed in high precision." + ) else: self.emulate = True logger.warning_once( @@ -431,36 +477,49 @@ def __init__( "computation, but kernels are not yet integrated in vLLM. " "Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " - "layers computed in high precision.") - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): + "layers computed in high precision." 
+ ) + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) params_dtype = torch.uint8 # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // 2, - dtype=params_dtype), - requires_grad=False) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // 2, - dtype=params_dtype), - requires_grad=False) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=params_dtype, + ), + requires_grad=False, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -491,7 +550,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight_scale", w2_weight_scale) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return mxfp4_w4a4_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, @@ -527,7 +587,8 @@ def apply( if enable_eplb: raise NotImplementedError( - "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.") + "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet." + ) from vllm.model_executor.layers.fused_moe import fused_experts @@ -543,7 +604,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) out = fused_experts( x, diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py index c167e949ac26..ddec0f6ea8eb 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py @@ -11,7 +11,7 @@ class QuarkScheme(ABC): """ - Abstract class used to describe the weight creation and forward pass + Abstract class used to describe the weight creation and forward pass of different quantization schemes supported by Quark. """ @@ -26,20 +26,21 @@ def get_min_capability(cls) -> int: @abstractmethod def create_weights(self, *args, **kwargs): """ - Weight creation for the particular scheme. Inputs to this function + Weight creation for the particular scheme. Inputs to this function """ raise NotImplementedError @abstractmethod - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]): + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ): """ - Run the forward pass for the particular scheme. This is where + Run the forward pass for the particular scheme. 
This is where scheme-specific dequant/quant steps/kernels should be applied. - :param layer: torch.nn.Module with the registered weights and - other parameters relevant to the particular scheme. + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. :param x: input to the layer :param bias: bias parameter diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index bcf3911095ac..9bedd7fa2563 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -10,17 +10,21 @@ from vllm import envs from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4) -from vllm.model_executor.parameter import (GroupQuantScaleParameter, - PackedvLLMParameter) + OCP_MX_BLOCK_SIZE, + dequant_mxfp4, + quant_dequant_mxfp4, +) +from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter from vllm.platforms import current_platform @cache def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM \ + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM and envs.VLLM_ROCM_USE_AITER + ) try: @@ -29,6 +33,7 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool: from aiter.ops.triton.quant import dynamic_mxfp4_quant from vllm.utils import direct_register_custom_op + if is_rocm_aiter_fp4_asm_gemm_enabled(): from aiter import gemm_a4w4, per_1x32_f4_quant_hip @@ -51,17 +56,13 @@ def gemm_with_dynamic_quant( # 32 alignment is enough for dim0 padding of output for # gemm_a4w4 kernel - y = torch.empty((M + 31) // 32 * 32, - weight.shape[0], - device=x_q.device, - dtype=out_dtype) - - gemm_a4w4(x_q, - weight, - x_s, - weight_scale.view(x_s.dtype), - y, - bpreshuffle=True) + y = torch.empty( + (M + 31) // 32 * 32, weight.shape[0], device=x_q.device, dtype=out_dtype + ) + + gemm_a4w4( + x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True + ) return y[:M] else: if x_scales is None: @@ -69,10 +70,9 @@ def gemm_with_dynamic_quant( else: x_q = x x_s = x_scales - y = torch.empty(x_q.shape[0], - weight.shape[0], - device=x_q.device, - dtype=out_dtype) + y = torch.empty( + x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype + ) gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) return y @@ -85,9 +85,9 @@ def gemm_with_dynamic_quant_fake( rocm_use_aiter_fp4_asm_gemm: bool = False, out_dtype: Optional[torch.dtype] = torch.bfloat16, ) -> torch.Tensor: - return torch.empty((*x.shape[:-1], weight.shape[0]), - dtype=out_dtype, - device=x.device) + return torch.empty( + (*x.shape[:-1], weight.shape[0]), dtype=out_dtype, device=x.device + ) direct_register_custom_op( op_name="gemm_with_dynamic_quant", @@ -104,46 +104,45 @@ def gemm_with_dynamic_quant_fake( class QuarkW4A4MXFP4(QuarkScheme): - - def __init__(self, weight_quant_spec: dict[str, Any], - input_quant_spec: dict[str, Any]): + def __init__( + self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any] + ): self.out_dtype = torch.get_default_dtype() self.qscheme = "per_group" self.weight_quant_spec = weight_quant_spec self.input_quant_spec = input_quant_spec self.emulate = not 
current_platform.supports_mx() self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled() - if not self.emulate and (dynamic_mxfp4_quant is None - or gemm_afp4wfp4 is None): + if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None): # Currently need these kernels if not emulating raise NotImplementedError( f"{self.__class__.__name__} requires AITER to be installed " "for non-emulation mode! Please refer to " - "https://github.com/ROCm/aiter for installation details.") + "https://github.com/ROCm/aiter for installation details." + ) @classmethod def get_min_capability(cls) -> int: return 70 def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.weight = torch.nn.Parameter(layer.weight.data, - requires_grad=False) + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) if self.emulate: - layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, - requires_grad=False) + layer.weight_scale = torch.nn.Parameter( + layer.weight_scale.data, requires_grad=False + ) try: from quark.torch.export.nn.modules import realquantizer - from quark.torch.quantization.config.config import ( - QuantizationSpec) + from quark.torch.quantization.config.config import QuantizationSpec except ImportError as err: raise ImportError( "The package `amd-quark` is required to use AMD Quark " "MX-FP4 models. Please install it with `pip install " - "amd-quark`.") from err + "amd-quark`." + ) from err - weight_quant_spec = QuantizationSpec.from_dict( - self.weight_quant_spec) + weight_quant_spec = QuantizationSpec.from_dict(self.weight_quant_spec) weight_quantizer = realquantizer.get_real_quantizer( qspec=weight_quant_spec, @@ -170,29 +169,34 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight_scale_shuffle = layer.weight_scale.data sm, sn = weight_scale_shuffle.shape weight_scale_shuffle = weight_scale_shuffle.view( - sm // 32, 2, 16, sn // 8, 2, 4, 1) + sm // 32, 2, 16, sn // 8, 2, 4, 1 + ) weight_scale_shuffle = weight_scale_shuffle.permute( - 0, 3, 5, 2, 4, 1, 6).contiguous() + 0, 3, 5, 2, 4, 1, 6 + ).contiguous() weight_scale_shuffle = weight_scale_shuffle.view(sm, sn) - layer.weight_scale = torch.nn.Parameter(weight_scale_shuffle, - requires_grad=False) + layer.weight_scale = torch.nn.Parameter( + weight_scale_shuffle, requires_grad=False + ) # shuffle weight weight_shuffle = layer.weight.data - weight_shuffle = shuffle_weight(weight_shuffle, - layout=(16, 16)) - layer.weight = torch.nn.Parameter(weight_shuffle, - requires_grad=False) + weight_shuffle = shuffle_weight(weight_shuffle, layout=(16, 16)) + layer.weight = torch.nn.Parameter(weight_shuffle, requires_grad=False) else: layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.data.T.contiguous(), - requires_grad=False) - - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + layer.weight_scale.data.T.contiguous(), requires_grad=False + ) + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes @@ -224,16 +228,21 @@ def create_weights(self, layer: torch.nn.Module, ) layer.register_parameter("weight_scale", weight_scale) - def apply_weights(self, - layer: 
torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: if self.emulate: dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype) x = quant_dequant_mxfp4(x) return F.linear(x, dq_w, bias) else: return torch.ops.vllm.gemm_with_dynamic_quant( - x, layer.weight, layer.weight_scale, - self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype) + x, + layer.weight, + layer.weight_scale, + self.rocm_use_aiter_fp4_asm_gemm, + self.out_dtype, + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 2cb35249f49e..553698a7dc94 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,37 +7,43 @@ from torch.nn import Parameter from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) + Fp8LinearOp, + normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from vllm.platforms import current_platform __all__ = ["QuarkW8A8Fp8"] class QuarkW8A8Fp8(QuarkScheme): - - def __init__(self, weight_config: dict[str, Any], - input_config: Optional[dict[str, Any]]): + def __init__( + self, weight_config: dict[str, Any], input_config: Optional[dict[str, Any]] + ): self.weight_qscheme = cast(str, weight_config.get("qscheme")) self.is_static_input_scheme: bool = False self.input_qscheme: Optional[str] = None if input_config is not None: - self.is_static_input_scheme = not cast( - bool, input_config.get("is_dynamic")) + self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) - per_token = (not self.is_static_input_scheme - and self.input_qscheme == "per_channel") - self.act_quant_group_shape = GroupShape.PER_TOKEN \ - if per_token else GroupShape.PER_TENSOR + per_token = ( + not self.is_static_input_scheme and self.input_qscheme == "per_channel" + ) + self.act_quant_group_shape = ( + GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR + ) self.fp8_linear = Fp8LinearOp( act_quant_static=self.is_static_input_scheme, - act_quant_group_shape=self.act_quant_group_shape) + act_quant_group_shape=self.act_quant_group_shape, + ) self.out_dtype = torch.get_default_dtype() @classmethod @@ -51,14 +57,14 @@ def process_weights_after_loading(self, layer) -> None: # requantize so we can always run per tensor if self.weight_qscheme == "per_tensor": if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) + input_scale = getattr(layer, "input_scale", None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, weight_scale=layer.weight_scale, - input_scale=input_scale) + input_scale=input_scale, + ) if input_scale is not 
None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) + layer.input_scale = Parameter(input_scale, requires_grad=False) else: max_w_scale = layer.weight_scale weight = layer.weight @@ -77,15 +83,14 @@ def process_weights_after_loading(self, layer) -> None: weight = layer.weight if current_platform.is_fp8_fnuz(): - input_scale = getattr(layer, 'input_scale', None) - weight, weight_scale, input_scale = \ - normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=layer.weight_scale, - input_scale=input_scale) + input_scale = getattr(layer, "input_scale", None) + weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=input_scale, + ) if input_scale is not None: - layer.input_scale = Parameter(input_scale, - requires_grad=False) + layer.input_scale = Parameter(input_scale, requires_grad=False) else: weight_scale = layer.weight_scale.data if self.act_quant_group_shape == GroupShape.PER_TOKEN: @@ -95,32 +100,37 @@ def process_weights_after_loading(self, layer) -> None: layer.weight_scale = Parameter(weight_scale, requires_grad=False) else: - raise ValueError( - f"Unknown quantization scheme {self.weight_qscheme}") + raise ValueError(f"Unknown quantization scheme {self.weight_qscheme}") # INPUT SCALE if self.is_static_input_scheme: - layer.input_scale = Parameter(layer.input_scale.max(), - requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) else: layer.input_scale = None - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): output_size_per_partition = sum(output_partition_sizes) layer.logical_widths = output_partition_sizes # WEIGHT - weight = ModelWeightParameter(data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE @@ -128,15 +138,16 @@ def create_weights(self, layer: torch.nn.Module, # the newly added parameters if self.weight_qscheme == "per_channel": weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes)), - dtype=torch.float32), + data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) else: assert self.weight_qscheme == "per_tensor" - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) # min requirement for fp8 kernels weight_scale[:] = torch.finfo(torch.float32).min @@ -144,20 +155,24 @@ def create_weights(self, layer: torch.nn.Module, # INPUT SCALE if self.is_static_input_scheme: - input_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), 
dtype=torch.float32), - weight_loader=weight_loader) + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) input_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", input_scale) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - return self.fp8_linear.apply(input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - out_dtype=self.out_dtype, - input_scale=layer.input_scale, - bias=bias) + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.fp8_linear.apply( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, + input_scale=layer.input_scale, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py index ae68d5bbc268..c41dd05d1062 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -7,12 +7,16 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( - ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) + ScaledMMLinearLayerConfig, + choose_scaled_mm_linear_kernel, +) from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter) +from vllm.model_executor.parameter import ( + BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) logger = init_logger(__name__) @@ -20,8 +24,12 @@ class QuarkW8A8Int8(QuarkScheme): _kernel_backends_being_used: set[str] = set() - def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool], - input_symmetric: Optional[bool]): + def __init__( + self, + qscheme: str, + is_static_input_scheme: Optional[bool], + input_symmetric: Optional[bool], + ): self.qscheme = qscheme self.is_static_input_scheme = is_static_input_scheme self.input_symmetric = input_symmetric @@ -31,92 +39,101 @@ def get_min_capability(cls) -> int: # turing and up return 75 - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): layer.logical_widths = output_partition_sizes scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( is_channelwise=(self.qscheme == "per_channel"), is_static_input_scheme=(self.is_static_input_scheme is True), - input_symmetric=(self.input_symmetric is True)) + input_symmetric=(self.input_symmetric is True), + ) - kernel_type = choose_scaled_mm_linear_kernel( - scaled_mm_linear_kernel_config) + kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config) if kernel_type.__name__ not in self._kernel_backends_being_used: logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) self._kernel_backends_being_used.add(kernel_type.__name__) # WEIGHT - weight = 
ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=torch.int8), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) # WEIGHT SCALE if self.qscheme == "per_channel": weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes)), - dtype=torch.float32), + data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) ChannelQuantZPParameter = ChannelQuantScaleParameter weight_zero_point = ChannelQuantZPParameter( - data=torch.empty((sum(output_partition_sizes)), - dtype=torch.int8), + data=torch.empty((sum(output_partition_sizes)), dtype=torch.int8), output_dim=0, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) else: assert self.qscheme == "per_tensor" - weight_scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) PerTensorZPParameter = PerTensorScaleParameter weight_zero_point = PerTensorZPParameter( - data=torch.empty(len(output_partition_sizes), - dtype=torch.int8), - weight_loader=weight_loader) + data=torch.empty(len(output_partition_sizes), dtype=torch.int8), + weight_loader=weight_loader, + ) layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_zero_point", weight_zero_point) # INPUT SCALE if self.is_static_input_scheme: - input_scale = BasevLLMParameter(data=torch.empty( - 1, dtype=torch.float32), - weight_loader=weight_loader) + input_scale = BasevLLMParameter( + data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader + ) layer.register_parameter("input_scale", input_scale) - input_zero_point = BasevLLMParameter(data=torch.empty( - 1, dtype=torch.int8), - weight_loader=weight_loader) + input_zero_point = BasevLLMParameter( + data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader + ) layer.register_parameter("input_zero_point", input_zero_point) - self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, - w_q_param_name="weight", - w_s_param_name="weight_scale", - i_s_param_name="input_scale", - i_zp_param_name="input_zero_point", - azp_adj_param_name="azp_adj") + self.kernel = kernel_type( + c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj", + ) # Checkpoints are serialized in quark format, which is # different from the format the kernel may want. Handle repacking here. 
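[Editor's note] The per-channel branch above registers an `int8` weight of shape `[sum(output_partition_sizes), input_size_per_partition]` plus a `float32` scale laid out along `output_dim=0`. A minimal sketch of symmetric per-output-channel int8 quantization that yields exactly those shapes, for orientation only (the real tensors come from Quark checkpoints and are repacked by the selected ScaledMM kernel, as the comment above notes):

```python
import torch


def symmetric_int8_per_channel(weight_fp: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per output row (output_dim=0); clamp to avoid division by zero.
    scale = weight_fp.abs().amax(dim=1).clamp(min=1e-8) / 127.0
    qweight = torch.clamp(torch.round(weight_fp / scale[:, None]), -128, 127).to(torch.int8)
    return qweight, scale.to(torch.float32)


w = torch.randn(256, 128)  # [output_size_per_partition, input_size_per_partition]
qw, s = symmetric_int8_per_channel(w)
assert qw.dtype == torch.int8 and s.shape == (256,)
```

The per-tensor branch differs only in keeping a single scale entry per logical partition instead of one per output channel.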
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.register_parameter("weight_zero_point", None) - delattr(layer, 'weight_zero_point') + delattr(layer, "weight_zero_point") if self.input_symmetric: layer.register_parameter("input_zero_point", None) - delattr(layer, 'input_zero_point') + delattr(layer, "input_zero_point") self.kernel.process_weights_after_loading(layer) - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ) -> torch.Tensor: return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/quark/utils.py b/vllm/model_executor/layers/quantization/quark/utils.py index 99f5ec15933a..0eb4b20a6e52 100644 --- a/vllm/model_executor/layers/quantization/quark/utils.py +++ b/vllm/model_executor/layers/quantization/quark/utils.py @@ -24,7 +24,7 @@ def deep_compare(dict1: Any, dict2: Any) -> bool: def should_ignore_layer( layer_name: Optional[str], ignore: Iterable[str], - fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> bool: if layer_name is None: return False @@ -50,7 +50,8 @@ def should_ignore_layer( should_ignore_layer = None for shard_name in shard_names: should_ignore_shard = check_equal_or_regex_match( - layer_name=shard_name, targets=ignore) + layer_name=shard_name, targets=ignore + ) # If shard_idx=0, set layer ignore to match shard. if should_ignore_layer is None: @@ -58,35 +59,34 @@ def should_ignore_layer( # If shard_idx=1+ confirm scheme matches prior shards. elif should_ignore_shard != should_ignore_layer: - raise ValueError(f"Found a different quantization schemes for " - f"{shard_proj_names} in {layer_name}. vLLM " - "requires all to use the same scheme.") + raise ValueError( + f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme." + ) # Unfused layers like down_proj and o_proj will match # the safetensors checkpoint already. else: - should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, - targets=ignore) + should_ignore_layer = check_equal_or_regex_match( + layer_name=layer_name, targets=ignore + ) assert should_ignore_layer is not None return should_ignore_layer -def check_equal_or_regex_match(layer_name: str, - targets: Iterable[str]) -> bool: +def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool: """ - Checks whether a layer_name is exactly equal or a regex match for + Checks whether a layer_name is exactly equal or a regex match for if target starts with 're:' to any target in list. """ - for target in targets: - if _is_equal_or_regex_match(layer_name, target): - return True - return False + return any(_is_equal_or_regex_match(layer_name, target) for target in targets) -def _is_equal_or_regex_match(value: str, - target: str, - check_contains: bool = False) -> bool: +def _is_equal_or_regex_match( + value: str, target: str, check_contains: bool = False +) -> bool: """ Checks whether a value is exactly equal or a regex match for target if target starts with 're:'. 
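[Editor's note] The ignore-list helpers above follow the convention that a target prefixed with `re:` is interpreted as a regular expression, while any other target must match the layer name exactly. A small sketch of that convention; whether the private helper uses `re.match` or stricter matching is an assumption here:

```python
import re


def matches_target(value: str, target: str) -> bool:
    # Targets starting with "re:" are treated as regular expressions;
    # everything else requires an exact string match. (Assumed to mirror
    # the helpers above; the exact regex call is not shown in this hunk.)
    if target.startswith("re:"):
        return re.match(target[len("re:"):], value) is not None
    return value == target


print(matches_target("model.layers.0.mlp.gate_proj", "re:.*gate_proj"))  # True
print(matches_target("lm_head", "lm_head"))                              # True
```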
If check_contains is set to True, diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 015dc136bb82..e0070e207048 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -10,36 +10,45 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + FusedMoEConfig, + FusedMoEMethodBase, +) from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, int4_w4a16_moe_quant_config, - int8_w8a16_moe_quant_config) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - set_weight_attrs) + FusedMoEQuantConfig, + int4_w4a16_moe_quant_config, + int8_w8a16_moe_quant_config, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + set_weight_attrs, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) logger = init_logger(__name__) """By default, use 8 bit as target precision, but it can be overridden by setting the RTN_NUM_BITS envvar """ -NUM_BITS = os.getenv('RTN_NUM_BITS', "8") +NUM_BITS = os.getenv("RTN_NUM_BITS", "8") """By default, use group size of 128 parameters, but it can be overridden by setting the RTN_GROUP_SIZE envvar """ -GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128") +GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128") class RTNConfig(QuantizationConfig): - """Config class for RTN. - """ + """Config class for RTN.""" def __init__( - self, - weight_bits: int = int(NUM_BITS), - group_size: int = int(GROUP_SIZE), + self, + weight_bits: int = int(NUM_BITS), + group_size: int = int(GROUP_SIZE), ) -> None: self.weight_bits = weight_bits self.group_size = group_size @@ -47,11 +56,13 @@ def __init__( if self.weight_bits != 4 and self.weight_bits != 8: raise ValueError( "Currently, only 4-bit or 8-bit weight quantization is " - f"supported for RTN, but got {self.weight_bits} bits.") + f"supported for RTN, but got {self.weight_bits} bits." + ) def __repr__(self) -> str: - return (f"RTNConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size})") + return ( + f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})" + ) @classmethod def get_name(cls) -> QuantizationMethods: @@ -75,8 +86,9 @@ def from_config(cls, config: dict[str, Any]) -> "RTNConfig": group_size = cls.get_from_keys(config, ["group_size"]) return cls(weight_bits, group_size) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return RTNLinearMethod(self) elif isinstance(layer, FusedMoE): @@ -89,8 +101,9 @@ class RTNTensor: overloading the copy_ method. 
""" - def __init__(self, data: torch.Tensor, scale: torch.Tensor, - quant_config: RTNConfig) -> None: + def __init__( + self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig + ) -> None: self.data = data self.scale = scale self.quant_config = quant_config @@ -99,7 +112,9 @@ def narrow(self, dim, start, length): factor = 1 if self.quant_config.weight_bits == 8 else 2 return RTNTensor( self.data.narrow(dim, start // factor, length // factor), - self.scale.narrow(dim, start, length), self.quant_config) + self.scale.narrow(dim, start, length), + self.quant_config, + ) def __getitem__(self, key): return RTNTensor(self.data[key], self.scale[key], self.quant_config) @@ -115,9 +130,11 @@ def shape(self): return torch.Size((shape[0] * factor, shape[1])) def copy_(self, loaded_weight: torch.Tensor) -> None: - qweight, weight_scale = rtn_quantize(loaded_weight.cuda(), - self.quant_config.weight_bits, - self.quant_config.group_size) + qweight, weight_scale = rtn_quantize( + loaded_weight.cuda(), + self.quant_config.weight_bits, + self.quant_config.group_size, + ) self.data.copy_(qweight) self.scale.data.copy_(weight_scale) @@ -133,8 +150,9 @@ class RTNParameter(Parameter): def __new__(cls, data: torch.Tensor, **kwargs): return super().__new__(cls, data=data, requires_grad=False) - def __init__(self, data: torch.Tensor, scale: torch.Tensor, - quant_config: RTNConfig) -> None: + def __init__( + self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig + ) -> None: self.scale = scale self.quant_config = quant_config @@ -164,31 +182,39 @@ def create_weights( **extra_weight_attrs, ): output_size_per_partition = sum(output_partition_sizes) - num_groups_per_col = (input_size_per_partition // - self.quant_config.group_size - if self.quant_config.group_size != -1 else 1) + num_groups_per_col = ( + input_size_per_partition // self.quant_config.group_size + if self.quant_config.group_size != -1 + else 1 + ) scale = Parameter( - torch.empty(output_size_per_partition, - num_groups_per_col, - dtype=params_dtype), + torch.empty( + output_size_per_partition, num_groups_per_col, dtype=params_dtype + ), requires_grad=False, ) factor = 1 if self.quant_config.weight_bits == 8 else 2 - weight = RTNParameter(data=torch.empty(output_size_per_partition // - factor, - input_size_per_partition, - dtype=torch.uint8), - scale=scale, - quant_config=self.quant_config) + weight = RTNParameter( + data=torch.empty( + output_size_per_partition // factor, + input_size_per_partition, + dtype=torch.uint8, + ), + scale=scale, + quant_config=self.quant_config, + ) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - **extra_weight_attrs, - "input_dim": 1, - "output_dim": 0, - }) + set_weight_attrs( + weight, + { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }, + ) layer.register_parameter("scale", scale) layer.output_size_per_partition = output_size_per_partition @@ -196,10 +222,12 @@ def create_weights( def process_weights_after_loading(self, layer: torch.nn.Module) -> None: fix_weights(layer, "weight") - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: qweight = layer.weight scale = layer.scale @@ -213,57 +241,75 @@ def apply(self, class RTNMoEMethod(FusedMoEMethodBase): - def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig): super().__init__(moe) self.quant_config = 
quant_config - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): factor = 1 if self.quant_config.weight_bits == 8 else 2 # Fused gate_up_proj (column parallel) - num_groups_per_col = (hidden_size // self.quant_config.group_size - if self.quant_config.group_size != -1 else 1) + num_groups_per_col = ( + hidden_size // self.quant_config.group_size + if self.quant_config.group_size != -1 + else 1 + ) w13_scale = Parameter( - torch.empty(num_experts, - 2 * intermediate_size_per_partition, - num_groups_per_col, - dtype=params_dtype), + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + num_groups_per_col, + dtype=params_dtype, + ), requires_grad=False, ) layer.register_parameter("w13_scale", w13_scale) - w13_weight = RTNParameter(data=torch.empty( - num_experts, - 2 * intermediate_size_per_partition // factor, - hidden_size, - dtype=torch.uint8), - scale=w13_scale, - quant_config=self.quant_config) + w13_weight = RTNParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition // factor, + hidden_size, + dtype=torch.uint8, + ), + scale=w13_scale, + quant_config=self.quant_config, + ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) # down_proj (row parallel) - num_groups_per_col = (intermediate_size_per_partition // - self.quant_config.group_size - if self.quant_config.group_size != -1 else 1) - w2_scale = Parameter(torch.zeros(num_experts, - hidden_size, - num_groups_per_col, - dtype=params_dtype), - requires_grad=False) + num_groups_per_col = ( + intermediate_size_per_partition // self.quant_config.group_size + if self.quant_config.group_size != -1 + else 1 + ) + w2_scale = Parameter( + torch.zeros( + num_experts, hidden_size, num_groups_per_col, dtype=params_dtype + ), + requires_grad=False, + ) layer.register_parameter("w2_scale", w2_scale) - w2_weight = RTNParameter(data=torch.empty( - num_experts, - hidden_size // factor, - intermediate_size_per_partition, - dtype=torch.uint8), - scale=w2_scale, - quant_config=self.quant_config) + w2_weight = RTNParameter( + data=torch.empty( + num_experts, + hidden_size // factor, + intermediate_size_per_partition, + dtype=torch.uint8, + ), + scale=w2_scale, + quant_config=self.quant_config, + ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) @@ -273,12 +319,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: fix_weights(layer, "w2_weight", weight_bits == 4) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: weight_bits = self.quant_config.weight_bits group_size = self.quant_config.group_size assert weight_bits == 4 or weight_bits == 8 - config_builder = (int4_w4a16_moe_quant_config - if weight_bits == 4 else int8_w8a16_moe_quant_config) + config_builder = ( + int4_w4a16_moe_quant_config + if weight_bits == 4 + else int8_w8a16_moe_quant_config + ) return config_builder( w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, @@ -313,8 +363,7 @@ def apply( assert self.fused_experts is None if enable_eplb: - raise NotImplementedError( - "EPLB not supported for 
`RTNMoEMethod` yet.") + raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.") from vllm.model_executor.layers.fused_moe import fused_experts @@ -330,7 +379,8 @@ def apply( scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + ) return fused_experts( x, @@ -347,15 +397,16 @@ def apply( ) -def rtn_quantize(tensor: torch.Tensor, num_bits: int, - group_size: int) -> tuple[torch.Tensor, torch.Tensor]: +def rtn_quantize( + tensor: torch.Tensor, num_bits: int, group_size: int +) -> tuple[torch.Tensor, torch.Tensor]: """Quantize a tensor using per-group static scaling factor. Args: tensor: The input tensor. num_bits: Target precision for the result (supported values are 8 or 4). - group_size: Quantization granularity. + group_size: Quantization granularity. If equal to -1, each row in the input tensor is treated as one group. """ @@ -364,15 +415,18 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int, tensor = tensor.unsqueeze(0) q_range = 2**num_bits - num_groups = (tensor.shape[1] * tensor.shape[2] // - group_size if group_size != -1 else tensor.shape[1]) + num_groups = ( + tensor.shape[1] * tensor.shape[2] // group_size + if group_size != -1 + else tensor.shape[1] + ) """Calculate a scaling factor per input group. """ input_flat = tensor.reshape(tensor.shape[0], num_groups, -1) input_min = torch.min(input_flat, dim=2, keepdim=True)[0] input_max = torch.max(input_flat, dim=2, keepdim=True)[0] input_max_abs = torch.max(input_min.abs(), input_max.abs()) - scale = (input_max_abs * 2.0 / (q_range - 1)) + scale = input_max_abs * 2.0 / (q_range - 1) """Scale each input group, round to the nearest integer, shift the range and truncate. """ @@ -388,9 +442,10 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int, if num_bits == 4: """Pack two 4-bit values into each byte. """ - inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xf) - inputs_q = inputs_q.reshape(tensor.shape[0], tensor.shape[1] // 2, - tensor.shape[2]) + inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xF) + inputs_q = inputs_q.reshape( + tensor.shape[0], tensor.shape[1] // 2, tensor.shape[2] + ) inputs_q = inputs_q.contiguous() if not batch_present: @@ -420,9 +475,9 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: if num_bits == 4: input_dim *= 2 - data = torch.empty((batch, input_dim, output_dim), - dtype=scale.dtype, - device=tensor.device) + data = torch.empty( + (batch, input_dim, output_dim), dtype=scale.dtype, device=tensor.device + ) if num_bits == 8: data.copy_(tensor) @@ -432,8 +487,9 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """ tensor = tensor.reshape(batch, input_dim, output_dim // 2) for i in range(2): - data[:, :, i::2] = ((tensor << 4 * - (1 - i)) >> 4).to(torch.int8) - q_range // 2 + data[:, :, i::2] = ((tensor << 4 * (1 - i)) >> 4).to( + torch.int8 + ) - q_range // 2 """Scale each input group with its scaling factor. """ scale = scale.reshape(batch, num_groups, -1) @@ -447,9 +503,7 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return input_deq -def fix_weights(layer: torch.nn.Module, - param_name: str, - reshape: bool = False): +def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False): """torch.compile does not know how to deal with a Parameter subclass (aka RTNParameter). 
As we don't really need RTNParameters for the forward pass, we replace them with equivalent instances of Parameters. diff --git a/vllm/model_executor/layers/quantization/schema.py b/vllm/model_executor/layers/quantization/schema.py index a108152929d9..9396da0ecd1a 100644 --- a/vllm/model_executor/layers/quantization/schema.py +++ b/vllm/model_executor/layers/quantization/schema.py @@ -30,7 +30,8 @@ class KVCacheQuantSchema(BaseModel): def check_is_fp8(self) -> "KVCacheQuantSchema": assert self.dtype == "float8_e4m3fn", ( "Loaded scaling factors intended for KV cache dtype = " - f"{self.dtype} rather than float8_e4m3fn!") + f"{self.dtype} rather than float8_e4m3fn!" + ) return self @model_validator(mode="after") @@ -41,15 +42,18 @@ def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": num_hidden_layers = context["num_hidden_layers"] assert len(self.scaling_factor) == tp_size, ( f"Loaded dictionary has TP size {len(self.scaling_factor)} " - f"but LLM engine is currently running with TP size {tp_size}.") + f"but LLM engine is currently running with TP size {tp_size}." + ) for tp_rank, layer_maps in self.scaling_factor.items(): assert len(layer_maps) == num_hidden_layers, ( f"KV cache scales map for TP rank {tp_rank} is malformed. " f"Expected {num_hidden_layers} layers, got " - f"{len(layer_maps)}.") + f"{len(layer_maps)}." + ) for i in range(tp_size): assert i in self.scaling_factor, ( - f"KV cache scales map for TP rank {i} not found.") + f"KV cache scales map for TP rank {i} not found." + ) return self @model_validator(mode="after") @@ -62,7 +66,8 @@ def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": for i in range(num_hidden_layers): assert i in layer_scales_map, ( f"Could not find KV cache scales for layer {i} in " - f"TP rank {tp_rank}.") + f"TP rank {tp_rank}." + ) return self @@ -82,5 +87,6 @@ def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": assert model_type == self.model_type, ( f"Model type is {model_type} but loaded " f"scaling factors belonging to different " - f"model type {self.model_type}!") + f"model type {self.model_type}!" 
+ ) return self diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py index 7e38304ad6d9..629d0b863041 100644 --- a/vllm/model_executor/layers/quantization/torchao.py +++ b/vllm/model_executor/layers/quantization/torchao.py @@ -8,11 +8,16 @@ from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -39,10 +44,12 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool: class TorchAOConfig(QuantizationConfig): """Config class for torchao.""" - def __init__(self, - torchao_config, - skip_modules: Optional[list[str]] = None, - is_checkpoint_torchao_serialized: bool = False) -> None: + def __init__( + self, + torchao_config, + skip_modules: Optional[list[str]] = None, + is_checkpoint_torchao_serialized: bool = False, + ) -> None: """ # TorchAO quantization relies on tensor subclasses. In order, # to enable proper caching this needs standalone compile @@ -63,8 +70,10 @@ def __init__(self, self.is_checkpoint_torchao_serialized = is_checkpoint_torchao_serialized def __repr__(self) -> str: - return f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, " \ + return ( + f"TorchAOConfig({self.torchao_config=}, {self.skip_modules=}, " f"{self.is_checkpoint_torchao_serialized=})" + ) def get_name(self) -> QuantizationMethods: return "torchao" @@ -95,13 +104,15 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig": ) from err quant_method = cls.get_from_keys_or(config, ["quant_method"], None) - is_checkpoint_torchao_serialized = (quant_method is not None - and "torchao" in quant_method) + is_checkpoint_torchao_serialized = ( + quant_method is not None and "torchao" in quant_method + ) hf_config = cls.get_from_keys_or(config, ["quant_type"], None) assert hf_config is not None, "quant_type must be specified" assert len(hf_config) == 1 and "default" in hf_config, ( - "Expected only one key 'default' in quant_type dictionary") + "Expected only one key 'default' in quant_type dictionary" + ) quant_type = hf_config["default"] ao_config = config_from_dict(quant_type) @@ -127,9 +138,7 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig": def from_config_file(cls, config_file: str) -> "TorchAOConfig": """Initialize class from a config file. 
Example: ``` - config = ( - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) - ) + config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) fn = "torchao_config.json" with open(fn, "w") as f: @@ -154,8 +163,9 @@ def from_config_dict_json(cls, config_dict_json: str) -> "TorchAOConfig": hf_config = {"quant_type": {"default": config_dict}} return cls.from_config(hf_config) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: if not isinstance(layer, LinearBase): return None @@ -167,12 +177,13 @@ def get_quant_method(self, layer: torch.nn.Module, module_fqn = prefix if isinstance(self.torchao_config, ModuleFqnToConfig): module_fqn_to_config = self.torchao_config.module_fqn_to_config - c = module_fqn_to_config.get( - module_fqn) or module_fqn_to_config.get("_default", None) + c = module_fqn_to_config.get(module_fqn) or module_fqn_to_config.get( + "_default", None + ) if c is not None: current_torchao_config = TorchAOConfig( - c, self.skip_modules, - self.is_checkpoint_torchao_serialized) + c, self.skip_modules, self.is_checkpoint_torchao_serialized + ) return TorchAOLinearMethod(current_torchao_config) else: return UnquantizedLinearMethod() @@ -183,8 +194,9 @@ def get_scaled_act_names(self) -> list[str]: return [] -def torchao_quantize_param_data(param: torch.Tensor, - torchao_config: Any) -> torch.nn.Parameter: +def torchao_quantize_param_data( + param: torch.Tensor, torchao_config: Any +) -> torch.nn.Parameter: """Quantize a Tensor with torchao quantization specified by torchao_config Args: @@ -205,7 +217,8 @@ def torchao_quantize_param_data(param: torch.Tensor, # while some of our configs need to do module swap, and only non-top # level modules support module swap dummy_linear = torch.nn.Sequential( - torch.nn.Linear(param.shape[1], param.shape[0], bias=False)) + torch.nn.Linear(param.shape[1], param.shape[0], bias=False) + ) dummy_linear[0].weight = param quantize_(dummy_linear, torchao_config) @@ -243,7 +256,8 @@ def create_weights( ) if self.quant_config.is_checkpoint_torchao_serialized: weight = torchao_quantize_param_data( - weight, self.quant_config.torchao_config) + weight, self.quant_config.torchao_config + ) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) @@ -264,7 +278,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # quantize the weight on the fly if the checkpoint is not already # quantized by torchao - weight = torchao_quantize_param_data(layer.weight, - self.quant_config.torchao_config) + weight = torchao_quantize_param_data( + layer.weight, self.quant_config.torchao_config + ) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py index 7f738d170db4..a24cd41659a0 100644 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -8,8 +8,10 @@ from torch.nn.parameter import Parameter from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import (QuantizationConfig, - QuantizationMethods) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + QuantizationMethods, +) from vllm.model_executor.parameter import ModelWeightParameter 
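[Editor's note] Returning to the `TorchAOConfig` helpers in the torchao.py hunk above: a hedged usage sketch of the JSON round-trip that `from_config_dict_json` enables. The torchao import locations for `config_to_dict` and `PerRow` are assumptions and may differ across torchao versions:

```python
import json

from torchao.core.config import config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

from vllm.model_executor.layers.quantization.torchao import TorchAOConfig

# Serialize a torchao config to JSON, then rebuild the vLLM wrapper from it.
# from_config wraps the dict as {"quant_type": {"default": ...}} and calls
# config_from_dict, so config_to_dict(...) round-trips cleanly.
ao_cfg = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
cfg_json = json.dumps(config_to_dict(ao_cfg))
quant_cfg = TorchAOConfig.from_config_dict_json(cfg_json)
print(quant_cfg)
```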
ACTIVATION_SCHEMES = ["none", "dynamic"] @@ -24,8 +26,7 @@ def __init__( ) -> None: super().__init__() if activation_scheme not in ACTIVATION_SCHEMES: - raise ValueError( - f"Unsupported activation scheme {activation_scheme}") + raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme def get_name(self) -> QuantizationMethods: @@ -36,8 +37,7 @@ def get_supported_act_dtypes(self) -> list[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - raise NotImplementedError( - "This function should not be called with TPU Backend") + raise NotImplementedError("This function should not be called with TPU Backend") @staticmethod def get_config_filenames() -> list[str]: @@ -48,50 +48,61 @@ def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig": activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) return cls(activation_scheme=activation_scheme) - def get_quant_method(self, layer: Module, - prefix: str) -> Optional["TPUInt8LinearMethod"]: + def get_quant_method( + self, layer: Module, prefix: str + ) -> Optional["TPUInt8LinearMethod"]: if isinstance(layer, LinearBase): return TPUInt8LinearMethod(self) return None class TPUInt8LinearMethod(LinearMethodBase): - """Int8 Linear method for TPU Quant. """ + """Int8 Linear method for TPU Quant.""" def __init__(self, quant_config: Int8TpuConfig): self.quant_config = quant_config self.quantize_activation = False - if self.quant_config.activation_scheme == 'dynamic': + if self.quant_config.activation_scheme == "dynamic": self.quantize_activation = True - def create_weights(self, layer: Module, input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - + def create_weights( + self, + layer: Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): weight_loader = extra_weight_attrs.get("weight_loader") - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) def _quantize_weight( - self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: weight_dtype = weight.dtype weight = weight.cpu().to(torch.float32) n_bit = 8 eps = 1e-5 - max_int = 2**(n_bit - 1) - 1 - min_int = -(2**(n_bit - 1)) + max_int = 2 ** (n_bit - 1) - 1 + min_int = -(2 ** (n_bit - 1)) max_val = weight.abs().amax(dim=-1, keepdim=True) max_val = max_val.clamp(min=eps) qscale = max_val / max_int - qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int, - max_int).to(torch.int8) + qweight = torch.clamp( + torch.round(weight * (1.0 / qscale)), min_int, max_int + ).to(torch.int8) qscale = qscale.squeeze().to(weight_dtype) return qweight, qscale @@ -104,21 +115,25 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight = Parameter(qweight, requires_grad=False) layer.scale = Parameter(qscale, requires_grad=False) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) 
-> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: try: import torch_xla.experimental.custom_kernel # noqa: F401 except ImportError as err: raise ImportError( "Please install torch_xla by following the instructions at " "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501 - "to run vLLM on TPU.") from err + "to run vLLM on TPU." + ) from err weight = layer.weight scale = layer.scale out = torch.ops.xla.quantized_matmul_int8( - x, weight, scale, quantize_activation=self.quantize_activation) + x, weight, scale, quantize_activation=self.quantize_activation + ) if bias is not None: out = out + bias return out diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py index 6ad56bae3dca..07c18029fb4d 100644 --- a/vllm/model_executor/layers/quantization/utils/__init__.py +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .layer_utils import replace_parameter, update_tensor_inplace - -__all__ = ['update_tensor_inplace', 'replace_parameter'] +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ["update_tensor_inplace", "replace_parameter"] diff --git a/vllm/model_executor/layers/quantization/utils/allspark_utils.py b/vllm/model_executor/layers/quantization/utils/allspark_utils.py index 1992b4d20147..4c324682e5e6 100644 --- a/vllm/model_executor/layers/quantization/utils/allspark_utils.py +++ b/vllm/model_executor/layers/quantization/utils/allspark_utils.py @@ -12,41 +12,56 @@ ALLSPARK_AMPERE_K_ALIGN = 16 -def check_allspark_supported_dtype_shape(input_size_per_partition: int, - output_size_per_partition: int, - group_size: int, - weight_dtype: ScalarType, - act_dtype: torch.dtype): +def check_allspark_supported_dtype_shape( + input_size_per_partition: int, + output_size_per_partition: int, + group_size: int, + weight_dtype: ScalarType, + act_dtype: torch.dtype, +): capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = -1 if capability_tuple is None else capability_tuple.to_int() # For Ampere GPU if device_capability >= 80 and device_capability < 90: if group_size != -1: - return False, \ - "For Ampere GPU, AllSpark does not support group_size "\ - f"= {group_size}. Only group_size = -1 are supported." + return ( + False, + "For Ampere GPU, AllSpark does not support group_size " + f"= {group_size}. Only group_size = -1 are supported.", + ) if weight_dtype not in ALLSPARK_SUPPORTED_QUANT_TYPES: - return False, "For Ampere GPU, AllSpark does not support "\ - f"quant type ({weight_dtype}). Only quant type "\ - f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported." - - if input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 \ - or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0: - return False, \ - "AllSpark needs input_size_per_partition % "\ - f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and "\ - f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 "\ - "for Ampere GPU optimized kernels." + return ( + False, + "For Ampere GPU, AllSpark does not support " + f"quant type ({weight_dtype}). 
Only quant type " + f"({ALLSPARK_SUPPORTED_QUANT_TYPES}) are supported.", + ) + + if ( + input_size_per_partition % ALLSPARK_AMPERE_K_ALIGN != 0 + or output_size_per_partition % ALLSPARK_AMPERE_N_ALIGN != 0 + ): + return ( + False, + "AllSpark needs input_size_per_partition % " + f"{ALLSPARK_AMPERE_K_ALIGN} = 0 and " + f"output_size_per_partition % {ALLSPARK_AMPERE_N_ALIGN} = 0 " + "for Ampere GPU optimized kernels.", + ) if act_dtype != torch.float16 and act_dtype != torch.bfloat16: - return False, \ - "AllSpark only supports act_dtype = float16 or bfloat16,"\ - f"for Ampere GPU, but got act_dtype = {act_dtype}." + return ( + False, + "AllSpark only supports act_dtype = float16 or bfloat16," + f"for Ampere GPU, but got act_dtype = {act_dtype}.", + ) else: - return False, "AllSpark currently does not support "\ - f"device_capability = {device_capability}." + return ( + False, + "AllSpark currently does not support " + f"device_capability = {device_capability}.", + ) return True, None diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py index 4c2e54873586..4b7a22a26653 100644 --- a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py +++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py @@ -28,13 +28,14 @@ # Determines the supported quantization types for BitBLAS based on the # device's capability and whether zero-point (zp) is used. -def query_bitblas_supported_quant_types(has_zp: bool, - device_capability: Optional[int] = None - ): +def query_bitblas_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): if device_capability is None: capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) if device_capability < 70: return [] @@ -50,97 +51,116 @@ def query_bitblas_supported_quant_types(has_zp: bool, def _check_bitblas_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: - + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) - supported_types = query_bitblas_supported_quant_types( - has_zp, device_capability) + supported_types = query_bitblas_supported_quant_types(has_zp, device_capability) if quant_type not in supported_types: - return (False, f"BitBLAS does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).") - if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES): - return (False, f"BitBLAS does not support group_size = {group_size}. " - f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} " - "are supported.") + return ( + False, + f"BitBLAS does not support weight_bits = {quant_type}. 
" + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES: + return ( + False, + f"BitBLAS does not support group_size = {group_size}. " + f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) # Finally, check if bitblas is installed try: import bitblas - if version.parse( - bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION): - raise ImportError("bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + + if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION): + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError: return False, "BitBLAS is not installed." return True, None -def check_bitblas_supported(quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None) -> bool: - cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp, - device_capability) +def check_bitblas_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_bitblas_supported( + quant_type, group_size, has_zp, device_capability + ) return cond -def verify_bitblas_supported(quant_type: ScalarType, - group_size: int, - has_zp: bool = False) -> None: +def verify_bitblas_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_bitblas_supports_shape(output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, group_size: int) -> None: - +def verify_bitblas_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: # Validate output_size_per_partition if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0: - raise ValueError(f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) # Validate input_size_per_partition if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0: - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") - - if (group_size < input_size - and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + if group_size < input_size and input_size_per_partition % group_size != 0: raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" f" is not divisible by group_size = {group_size}." "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + "with --quantization gptq." + ) -def check_bitblas_supports_shape(output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, group_size: int) \ - -> tuple[bool, Optional[str]]: +def check_bitblas_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> tuple[bool, Optional[str]]: try: - verify_bitblas_supports_shape(output_size_per_partition, - input_size_per_partition, input_size, - group_size) + verify_bitblas_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) except ValueError as e: return False, e.__str__() return True, None @@ -150,8 +170,9 @@ def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int, - is_row_parallel: bool) -> bool: +def bitblas_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -159,17 +180,18 @@ def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int, def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), - requires_grad=False) + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), - requires_grad=False) + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) -def bitblas_sort_g_idx( - g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def bitblas_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices @@ -186,8 +208,7 @@ def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor: for col in range(unpacked_zeros.shape[1]): i = col % elems_per_int32 - unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> - (bits * i)) & 0xF + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) & 0xF if not is_gptq_v2: return unpacked_zeros + 1 return unpacked_zeros @@ -204,7 +225,6 @@ def unpack_gptq_qweight(qweight, bits): ) for col in range(unpacked_weight.shape[1]): i = col % elems_per_int8 - unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> - (bits * i)) + unpacked_weight[:, col] = qweight[:, col // elems_per_int8] >> (bits * i) return torch.bitwise_and(unpacked_weight, 2**bits - 1) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index a520302c62d9..7059a029ba67 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -1,18 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project """Utility helpers for NVFP4 + FlashInfer fused-MoE path""" + from __future__ import annotations import torch import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEQuantConfig) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts) + FlashInferExperts, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize) + create_flashinfer_prepare_finalize, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe @@ -25,15 +30,17 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool: """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" - return (envs.VLLM_USE_FLASHINFER_MOE_FP4 - and has_flashinfer_cutlass_fused_moe() - and current_platform.is_cuda() - and current_platform.is_device_capability(100)) + return ( + envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and current_platform.is_cuda() + and current_platform.is_device_capability(100) + ) -def reorder_w1w3_to_w3w1(weight: torch.Tensor, - scale: torch.Tensor, - dim: int = -2) -> tuple[torch.Tensor, torch.Tensor]: +def reorder_w1w3_to_w3w1( + weight: torch.Tensor, scale: torch.Tensor, dim: int = -2 +) -> tuple[torch.Tensor, torch.Tensor]: """Re-order the concatenated `[w1, w3]` tensors to `[w3, w1]`""" size = weight.size(dim) assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}" @@ -42,18 +49,21 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor, w1, w3 = weight.split(half, dim=dim) s1, s3 = scale.split(half, dim=dim) - return (torch.cat([w3, w1], - dim=dim).contiguous(), torch.cat([s3, s1], - dim=dim).contiguous()) + return ( + torch.cat([w3, w1], dim=dim).contiguous(), + torch.cat([s3, s1], dim=dim).contiguous(), + ) def build_flashinfer_fp4_cutlass_moe_prepare_finalize( - moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize: + moe: FusedMoEConfig, +) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv" return create_flashinfer_prepare_finalize( - use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv) + use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv + ) def select_nvfp4_gemm_impl( @@ -76,4 +86,5 @@ def select_nvfp4_gemm_impl( # native cutlass experts currently don't support DP; TP case won't call this raise ValueError( "CutlassExpertsFp4 doesn't support DP. 
Use flashinfer CUTLASS " - "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)") + "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)" + ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index b779a5355b67..7f32ef00647c 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -8,12 +8,16 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEQuantConfig) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts) + FlashInferExperts, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - create_flashinfer_prepare_finalize) + create_flashinfer_prepare_finalize, +) logger = init_logger(__name__) @@ -24,7 +28,6 @@ class FlashinferMoeBackend(Enum): def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): - # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now. # TODO: Revert this to dynamic calculation once a new version of FlashInfer # with the necessary kernels is released. @@ -44,13 +47,16 @@ def calculate_tile_tokens_dim(num_tokens, top_k, num_experts): def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: - return x.reshape(-1, 2, x.shape[-2] // 2, - x.shape[-1]).flip(dims=[1]).reshape(x.shape) + return ( + x.reshape(-1, 2, x.shape[-2] // 2, x.shape[-1]).flip(dims=[1]).reshape(x.shape) + ) -def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor): +def rotate_flashinfer_fp8_moe_weights( + gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor +): from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a + epilogue_tile_m = 128 num_experts = gemm1_weights.shape[0] hidden_size = gemm1_weights.shape[-1] @@ -60,13 +66,13 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor, gemm1_weights_fp8_interleaved = [] for i in range(num_experts): gemm1_weights_fp8_interleaved.append( - reorder_rows_for_gated_act_gemm(gemm1_weights[i])) + reorder_rows_for_gated_act_gemm(gemm1_weights[i]) + ) # Stack weights and scales for all experts - gemm1_weights_fp8_interleaved = torch.stack( - gemm1_weights_fp8_interleaved).reshape(num_experts, - 2 * intermediate_size, - hidden_size) + gemm1_weights_fp8_interleaved = torch.stack(gemm1_weights_fp8_interleaved).reshape( + num_experts, 2 * intermediate_size, hidden_size + ) # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp8_shuffled = [] @@ -74,18 +80,21 @@ def rotate_flashinfer_fp8_moe_weights(gemm1_weights: torch.Tensor, for i in range(num_experts): gemm1_weights_fp8_shuffled.append( shuffle_matrix_a( - gemm1_weights_fp8_interleaved[i].view(torch.uint8), - epilogue_tile_m)) + gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m + ) + ) gemm2_weights_fp8_shuffled.append( - shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), - epilogue_tile_m)) + shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), epilogue_tile_m) + ) # Stack weights for all experts gemm1_weights.data = torch.stack(gemm1_weights_fp8_shuffled).view( - torch.float8_e4m3fn) + torch.float8_e4m3fn + ) 
gemm2_weights.data = torch.stack(gemm2_weights_fp8_shuffled).view( - torch.float8_e4m3fn) + torch.float8_e4m3fn + ) def apply_flashinfer_per_tensor_scale_fp8( @@ -102,16 +111,22 @@ def apply_flashinfer_per_tensor_scale_fp8( from flashinfer.fused_moe import RoutingMethodType import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 + assert layer.output1_scales_scalar is not None, ( - "Expected output1_scales_scalar to be initialized") + "Expected output1_scales_scalar to be initialized" + ) assert layer.output1_scales_scalar is not None, ( - "Expected output1_scales_gate_scalar to be initialized") + "Expected output1_scales_gate_scalar to be initialized" + ) assert layer.output1_scales_scalar is not None, ( - "Expected output2_scales_scalar to be initialized") + "Expected output2_scales_scalar to be initialized" + ) from vllm.model_executor.models.llama4 import Llama4MoE - assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \ + + assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( "FusedMoE flashinfer kernels are only supported for Llama4" + ) return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( routing_logits=router_logits, routing_bias=routing_bias, @@ -140,37 +155,39 @@ def get_moe_scaling_factors( activation_scale: torch.Tensor, gemm2_weights_scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - output1_scales_scalar = gemm1_weights_scale * input_scale * ( - 1.0 / activation_scale) + output1_scales_scalar = gemm1_weights_scale * input_scale * (1.0 / activation_scale) output1_scales_gate_scalar = gemm1_weights_scale * input_scale output2_scales_scalar = activation_scale * gemm2_weights_scale - return output1_scales_scalar, output1_scales_gate_scalar, \ - output2_scales_scalar + return output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar def register_moe_scaling_factors(layer: torch.nn.Module) -> None: - output1_scales, output1_gate_scales, output2_scales = \ - get_moe_scaling_factors( - layer.w13_input_scale, layer.w13_weight_scale, - layer.w2_input_scale, layer.w2_weight_scale - ) + output1_scales, output1_gate_scales, output2_scales = get_moe_scaling_factors( + layer.w13_input_scale, + layer.w13_weight_scale, + layer.w2_input_scale, + layer.w2_weight_scale, + ) layer.register_parameter( - 'output1_scales_scalar', - torch.nn.Parameter(output1_scales, requires_grad=False)) + "output1_scales_scalar", torch.nn.Parameter(output1_scales, requires_grad=False) + ) layer.register_parameter( - 'output1_scales_gate_scalar', - torch.nn.Parameter(output1_gate_scales, requires_grad=False)) + "output1_scales_gate_scalar", + torch.nn.Parameter(output1_gate_scales, requires_grad=False), + ) layer.register_parameter( - 'output2_scales_scalar', - torch.nn.Parameter(output2_scales, requires_grad=False)) + "output2_scales_scalar", torch.nn.Parameter(output2_scales, requires_grad=False) + ) layer.register_parameter( - 'w2_input_scale_inv', - torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False)) + "w2_input_scale_inv", + torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False), + ) def build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize: + moe: Optional[FusedMoEConfig], +) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False return 
create_flashinfer_prepare_finalize(use_dp) @@ -193,8 +210,7 @@ def select_cutlass_fp8_gemm_impl( tp_size=moe.moe_parallel_config.tp_size, ) - assert out_dtype is not None, ( - "If moe config is None, out_dtype must be passed") + assert out_dtype is not None, "If moe config is None, out_dtype must be passed" return FlashInferExperts( out_dtype=out_dtype, quant_config=quant_config, @@ -217,9 +233,10 @@ def flashinfer_cutlass_moe_fp8( fused_experts = mk.FusedMoEModularKernel( build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None), - select_cutlass_fp8_gemm_impl(moe=None, - quant_config=quant_config, - out_dtype=hidden_states.dtype)) + select_cutlass_fp8_gemm_impl( + moe=None, quant_config=quant_config, out_dtype=hidden_states.dtype + ), + ) return fused_experts( hidden_states, @@ -245,4 +262,5 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend: allowed_backends = ["throughput", "latency"] raise ValueError( f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" - f" expected one of {allowed_backends}") + f" expected one of {allowed_backends}" + ) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 13bb69190eae..16ede6113a94 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -15,18 +15,26 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape, group_broadcast) + GroupShape, + group_broadcast, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_BLOCK_FP8_SUPPORTED) -from vllm.model_executor.parameter import (BlockQuantScaleParameter, - ChannelQuantScaleParameter, - PerTensorScaleParameter) + CUTLASS_BLOCK_FP8_SUPPORTED, +) +from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ChannelQuantScaleParameter, + PerTensorScaleParameter, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op -from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used, - is_deep_gemm_supported, - should_use_deepgemm_for_fp8_linear) +from vllm.utils.deep_gemm import ( + fp8_gemm_nt, + is_deep_gemm_e8m0_used, + is_deep_gemm_supported, + should_use_deepgemm_for_fp8_linear, +) logger = init_logger(__name__) @@ -56,7 +64,8 @@ def cutlass_scaled_mm( out_dtype=output_dtype, scale_a=As, # SM90 block FP8 requires row-major scale_b, which we do ahead of time - scale_b=Bs if block_size is not None and is_hopper else Bs.T) + scale_b=Bs if block_size is not None and is_hopper else Bs.T, + ) def rocm_aiter_gemm_w8a8_blockscale_impl( @@ -80,7 +89,6 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - m = A.shape[0] n = B.shape[0] Y = torch.empty(m, n, dtype=output_dtype, device=A.device) @@ -93,9 +101,11 @@ def rocm_aiter_gemm_w8a8_blockscale_fake( op_func=rocm_aiter_gemm_w8a8_blockscale_impl, fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, ) - if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR - and current_platform.is_fp8_fnuz()): - + if ( + envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz() + ): import aiter as rocm_aiter from aiter import get_hip_quant @@ -113,8 +123,9 @@ def _w8a8_triton_block_scaled_mm_func( 
block_size: list[int], output_dtype: torch.dtype, ) -> torch.Tensor: - return w8a8_triton_block_scaled_mm(qx, weight, x_scale, weight_scale, - block_size, output_dtype) + return w8a8_triton_block_scaled_mm( + qx, weight, x_scale, weight_scale, block_size, output_dtype + ) def _w8a8_triton_block_scaled_mm_fake( @@ -125,9 +136,9 @@ def _w8a8_triton_block_scaled_mm_fake( block_size: list[int], output_dtype: torch.dtype, ) -> torch.Tensor: - return torch.empty((qx.size(0), weight.size(0)), - dtype=output_dtype, - device=qx.device) + return torch.empty( + (qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device + ) direct_register_custom_op( @@ -147,22 +158,24 @@ def _padded_cutlass( ) -> torch.Tensor: pad_multiple = 4 dim = qx.shape[0] - padded = dim if dim % pad_multiple == 0 else dim + pad_multiple - ( - dim % pad_multiple) + padded = ( + dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple) + ) padded_shape = [padded, *qx.shape[1:]] padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype) - padded_qx[0:qx.shape[0], ...].copy_(qx) + padded_qx[0 : qx.shape[0], ...].copy_(qx) padded_x_scale_shape = [*x_scale.shape[1:], padded] - padded_x_scale = torch.ones(padded_x_scale_shape, - device=x_scale.device, - dtype=x_scale.dtype).permute(-1, -2) - padded_x_scale[0:x_scale.shape[0], ...].copy_(x_scale) + padded_x_scale = torch.ones( + padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype + ).permute(-1, -2) + padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale) - output = cutlass_scaled_mm(padded_qx, weight, padded_x_scale, weight_scale, - block_size, output_dtype, True) - return output[0:qx.shape[0], ...] + output = cutlass_scaled_mm( + padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True + ) + return output[0 : qx.shape[0], ...] def _padded_cutlass_fake( @@ -173,9 +186,9 @@ def _padded_cutlass_fake( block_size: list[int], output_dtype: torch.dtype, ) -> torch.Tensor: - return torch.empty((qx.size(0), weight.size(0)), - dtype=output_dtype, - device=qx.device) + return torch.empty( + (qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device + ) direct_register_custom_op( @@ -185,18 +198,30 @@ def _padded_cutlass_fake( ) -def _fp8_gemm_nt_op(q_input: torch.Tensor, input_scale: torch.Tensor, - weight: torch.Tensor, weight_scale: torch.Tensor, - output: torch.Tensor, use_deep_gemm_e8m0: bool) -> None: - fp8_gemm_nt((q_input, input_scale), (weight, weight_scale), - output, - is_deep_gemm_e8m0_used=use_deep_gemm_e8m0) +def _fp8_gemm_nt_op( + q_input: torch.Tensor, + input_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + output: torch.Tensor, + use_deep_gemm_e8m0: bool, +) -> None: + fp8_gemm_nt( + (q_input, input_scale), + (weight, weight_scale), + output, + is_deep_gemm_e8m0_used=use_deep_gemm_e8m0, + ) -def _fp8_gemm_nt_op_fake(q_input: torch.Tensor, input_scale: torch.Tensor, - weight: torch.Tensor, weight_scale: torch.Tensor, - output: torch.Tensor, - use_deep_gemm_e8m0: bool) -> None: +def _fp8_gemm_nt_op_fake( + q_input: torch.Tensor, + input_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + output: torch.Tensor, + use_deep_gemm_e8m0: bool, +) -> None: return None @@ -233,15 +258,21 @@ def __init__( # We can't use _dispatch_w8a8_blockscale_op to figure out if we want # to use deepgemm because we don't know the shape of weights (and # whether deepgemm supports it) at the init time. 
- self.w8a8_blockscale_op, self.input_quant_op = \ - self._dispatch_w8a8_blockscale_op( - cutlass_block_fp8_supported, use_aiter_and_is_supported) - self.deepgemm_input_quant_op = (QuantFP8( - False, - self.act_quant_group_shape, - column_major_scales=True, - use_ue8m0=self.use_deep_gemm_e8m0) if self.is_deep_gemm_supported - else None) + self.w8a8_blockscale_op, self.input_quant_op = ( + self._dispatch_w8a8_blockscale_op( + cutlass_block_fp8_supported, use_aiter_and_is_supported + ) + ) + self.deepgemm_input_quant_op = ( + QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=self.use_deep_gemm_e8m0, + ) + if self.is_deep_gemm_supported + else None + ) def apply( self, @@ -257,8 +288,9 @@ def apply( output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype - if should_use_deepgemm_for_fp8_linear(output_dtype, weight, - self.is_deep_gemm_supported): + if should_use_deepgemm_for_fp8_linear( + output_dtype, weight, self.is_deep_gemm_supported + ): output = self._run_deepgemm(input_2d, weight, weight_scale) else: output = self.w8a8_blockscale_op(input_2d, weight, weight_scale) @@ -275,12 +307,14 @@ def _run_deepgemm( ) -> torch.Tensor: assert self.deepgemm_input_quant_op is not None q_input, input_scale = self.deepgemm_input_quant_op(input_2d) - output = torch.empty((q_input.shape[0], weight.shape[0]), - dtype=torch.bfloat16, - device=q_input.device) - torch.ops.vllm.fp8_gemm_nt_op(q_input, input_scale, weight, - weight_scale, output, - self.use_deep_gemm_e8m0) + output = torch.empty( + (q_input.shape[0], weight.shape[0]), + dtype=torch.bfloat16, + device=q_input.device, + ) + torch.ops.vllm.fp8_gemm_nt_op( + q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0 + ) return output def _run_cutlass( @@ -292,15 +326,24 @@ def _run_cutlass( assert self.input_quant_op is not None q_input, input_scale = self.input_quant_op(input_2d) if self.is_hopper: - return torch.ops.vllm.padded_cutlass(q_input, weight, input_scale, - weight_scale, - list(self.weight_group_shape), - input_2d.dtype) + return torch.ops.vllm.padded_cutlass( + q_input, + weight, + input_scale, + weight_scale, + list(self.weight_group_shape), + input_2d.dtype, + ) else: - return cutlass_scaled_mm(q_input, weight, - input_scale, weight_scale, - list(self.weight_group_shape), - input_2d.dtype, False) + return cutlass_scaled_mm( + q_input, + weight, + input_scale, + weight_scale, + list(self.weight_group_shape), + input_2d.dtype, + False, + ) def _run_aiter( self, @@ -310,10 +353,16 @@ def _run_aiter( ) -> torch.Tensor: assert self.act_quant_group_shape == GroupShape(1, 128) q_input, input_scale = aiter_per1x128_quant( - input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8) + input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8 + ) return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale( - q_input, weight, input_scale, weight_scale, - self.weight_group_shape, input_2d.dtype) + q_input, + weight, + input_scale, + weight_scale, + self.weight_group_shape, + input_2d.dtype, + ) def _run_triton( self, @@ -324,34 +373,52 @@ def _run_triton( assert self.input_quant_op is not None q_input, input_scale = self.input_quant_op(input_2d) return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( - q_input, weight, input_scale, weight_scale, - self.weight_group_shape, input_2d.dtype) + q_input, + weight, + input_scale, + weight_scale, + self.weight_group_shape, + input_2d.dtype, + ) def _dispatch_w8a8_blockscale_op( self, use_cutlass: bool, use_aiter_and_is_supported: 
bool, - ) -> tuple[Callable[[ - torch.Tensor, - torch.Tensor, + ) -> tuple[ + Callable[ + [ + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], torch.Tensor, - ], torch.Tensor], Optional[QuantFP8]]: + ], + Optional[QuantFP8], + ]: if use_cutlass: - return self._run_cutlass, (QuantFP8(False, - self.act_quant_group_shape, - column_major_scales=True, - use_ue8m0=False)) + return self._run_cutlass, ( + QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=True, + use_ue8m0=False, + ) + ) if use_aiter_and_is_supported: return self._run_aiter, None - return self._run_triton, (QuantFP8(False, - self.act_quant_group_shape, - column_major_scales=False, - use_ue8m0=False)) + return self._run_triton, ( + QuantFP8( + False, + self.act_quant_group_shape, + column_major_scales=False, + use_ue8m0=False, + ) + ) def input_to_float8( - x: torch.Tensor, - dtype: Optional[torch.dtype] = None + x: torch.Tensor, dtype: Optional[torch.dtype] = None ) -> tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to float8 values " "with tensor-wise quantization.""" @@ -410,8 +477,9 @@ def _per_token_group_quant_fp8( row_g_id = g_id % groups_per_row # Ensure offset calculations use int64 to prevent overflow - y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * - group_size) + y_ptr_offset = (row.to(tl.int64) * y_row_stride) + ( + row_g_id.to(tl.int64) * group_size + ) y_ptr += y_ptr_offset y_q_ptr_offset = g_id.to(tl.int64) * group_size @@ -465,8 +533,9 @@ def _per_token_group_quant_fp8_colmajor( row_g_id = g_id % groups_per_row # Ensure offset calculations use int64 to prevent overflow - y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * - group_size) + y_ptr_offset = (row.to(tl.int64) * y_row_stride) + ( + row_g_id.to(tl.int64) * group_size + ) y_ptr += y_ptr_offset y_q_ptr_offset = g_id.to(tl.int64) * group_size @@ -478,8 +547,7 @@ def _per_token_group_quant_fp8_colmajor( scale_col = g_id % blocks_per_row scale_row = g_id // blocks_per_row # Ensure offset calculation uses int64 for y_s_ptr - y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to( - tl.int64) + y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(tl.int64) y_s_ptr += y_s_ptr_offset cols = tl.arange(0, BLOCK) # group_size <= BLOCK @@ -523,9 +591,10 @@ def per_token_group_quant_fp8( if use_ue8m0 is None: use_ue8m0 = is_deep_gemm_e8m0_used() dtype = current_platform.fp8_dtype() if dtype is None else dtype - assert (x.shape[-1] % group_size == 0), ( + assert x.shape[-1] % group_size == 0, ( f"the last dimension of `x` {x.shape[-1]} must be divisible " - f"by `group_size` {group_size}") + f"by `group_size` {group_size}" + ) assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) @@ -539,18 +608,18 @@ def per_token_group_quant_fp8( # Allocate the scale tensor in either row- or column-major format. if column_major_scales: - shape = (x.shape[-1] // group_size, ) + x.shape[:-1] - x_s = torch.empty(shape, device=x.device, - dtype=torch.float32).permute(-1, -2) + shape = (x.shape[-1] // group_size,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) else: - shape = x.shape[:-1] + (x.shape[-1] // group_size, ) + shape = x.shape[:-1] + (x.shape[-1] // group_size,) x_s = torch.empty(shape, device=x.device, dtype=torch.float32) # prefer CUDA kernel if available # TODO(bnell): this causes some fp8 moe test to fail. 
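# A rough pure-PyTorch sketch of the per-token-group FP8 quantization done by
# the kernels below, assuming the usual absmax / fp8_max scaling; the function
# name, eps default, and shapes are illustrative, not the fused implementation.
import torch

def per_token_group_quant_fp8_ref(x, group_size, eps=1e-10,
                                  dtype=torch.float8_e4m3fn):
    assert x.shape[-1] % group_size == 0
    finfo = torch.finfo(dtype)
    g = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    # one scale per (token, group): max |x| over the group, floored by eps
    absmax = g.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = absmax / finfo.max
    q = (g / scale).clamp(finfo.min, finfo.max).to(dtype)
    return q.reshape(x.shape), scale.squeeze(-1)

x = torch.randn(4, 256)
x_q, x_s = per_token_group_quant_fp8_ref(x, group_size=128)
assert x_q.shape == x.shape and x_s.shape == (4, 2)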
if current_platform.is_cuda() and x.is_contiguous(): - torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps, - fp8_min, fp8_max, use_ue8m0) + torch.ops._C.per_token_group_fp8_quant( + x, x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0 + ) return x_q, x_s # TRITON FALLBACK @@ -561,7 +630,7 @@ def per_token_group_quant_fp8( num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 if column_major_scales: - _per_token_group_quant_fp8_colmajor[(M, )]( + _per_token_group_quant_fp8_colmajor[(M,)]( x, x_q, x_s, @@ -578,7 +647,7 @@ def per_token_group_quant_fp8( num_stages=num_stages, ) else: - _per_token_group_quant_fp8[(M, )]( + _per_token_group_quant_fp8[(M,)]( x, x_q, x_s, @@ -656,12 +725,8 @@ def _w8a8_triton_block_scaled_mm( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k @@ -687,8 +752,9 @@ def _w8a8_triton_block_scaled_mm( @functools.lru_cache -def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, - block_k: int) -> Optional[dict[int, Any]]: +def get_w8a8_block_fp8_configs( + N: int, K: int, block_n: int, block_k: int +) -> Optional[dict[int, Any]]: """ Return optimized configurations for the w8a8 block fp8 kernel. The return value will be a dictionary that maps an irregular grid of @@ -703,7 +769,8 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501 config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) if os.path.exists(config_file_path): with open(config_file_path) as f: logger.info( @@ -759,7 +826,7 @@ def w8a8_triton_block_scaled_mm( assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N, ) + C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) @@ -780,8 +847,9 @@ def w8a8_triton_block_scaled_mm( } def grid(META): - return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * - triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) _w8a8_triton_block_scaled_mm[grid]( A, @@ -811,9 +879,9 @@ def grid(META): def requant_weight_ue8m0_inplace( - weight: torch.Tensor, - weight_scale: torch.Tensor, - block_size: Sequence[int] = (128, 128), + weight: torch.Tensor, + weight_scale: torch.Tensor, + block_size: Sequence[int] = (128, 128), ) -> None: """Re-quantise *weight* so that its per-block scaling factors are in the UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace. @@ -830,8 +898,9 @@ def requant_weight_ue8m0_inplace( return if weight.dtype != torch.float8_e4m3fn: - raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got " - f"{weight.dtype} instead.") + raise ValueError( + f"Expected *weight* to be torch.float8_e4m3fn, got {weight.dtype} instead." 
+ ) from vllm.utils.deep_gemm import per_block_cast_to_fp8 @@ -860,8 +929,9 @@ def requant_weight_ue8m0_inplace( s_exp = s_exp[:m_cur, :k_cur] w_dq = w_q.to(torch.float32) * s_exp # Re-quantise using power-of-two scaling (UE8M0). - w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k], - use_ue8m0=True) + w_requant, s_requant = per_block_cast_to_fp8( + w_dq, [block_m, block_k], use_ue8m0=True + ) # Write back the results in-place. w_q.copy_(w_requant) @@ -871,28 +941,39 @@ def requant_weight_ue8m0_inplace( def check_aiter_fp8_linear_support() -> bool: """AITER is only supported on ROCm and only for FP8_FNUZ and at the moment are MI300 series""" - return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_LINEAR - and current_platform.is_fp8_fnuz()) + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and current_platform.is_fp8_fnuz() + ) def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: """Pad the weight tensor. This is an optimization on ROCm platform, which can benefit from tensors located far enough from one another in memory""" - if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() - and weight.stride(-1) == 1 - and (weight.stride(-2) * weight.element_size()) % 512 == 0): + if ( + envs.VLLM_ROCM_FP8_PADDING + and current_platform.is_rocm() + and weight.stride(-1) == 1 + and (weight.stride(-2) * weight.element_size()) % 512 == 0 + ): num_pad = 256 // weight.element_size() import torch.nn.functional as F + weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] torch.cuda.empty_cache() return weight -def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int, - output_size: int, input_size_per_partition: int, - output_partition_sizes: list[int], - block_size: list[int]) -> None: +def validate_fp8_block_shape( + layer: torch.nn.Module, + input_size: int, + output_size: int, + input_size_per_partition: int, + output_partition_sizes: list[int], + block_size: list[int], +) -> None: """Validate block quantization shapes for tensor parallelism.""" from vllm.distributed import get_tensor_model_parallel_world_size @@ -900,15 +981,18 @@ def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int, block_n, block_k = block_size[0], block_size[1] # Required by row parallel - if (tp_size > 1 and input_size // input_size_per_partition == tp_size - and input_size_per_partition % block_k != 0): + if ( + tp_size > 1 + and input_size // input_size_per_partition == tp_size + and input_size_per_partition % block_k != 0 + ): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition} " - f"is not divisible by weight quantization block_k = {block_k}.") + f"is not divisible by weight quantization block_k = {block_k}." + ) # Required by column parallel or enabling merged weights - is_tp_split = (tp_size > 1 - and output_size // sum(output_partition_sizes) == tp_size) + is_tp_split = tp_size > 1 and output_size // sum(output_partition_sizes) == tp_size is_merged_gemm = len(output_partition_sizes) > 1 if is_tp_split or is_merged_gemm: sizes_to_check = output_partition_sizes @@ -921,33 +1005,44 @@ def validate_fp8_block_shape(layer: torch.nn.Module, input_size: int, raise ValueError( f"Weight output_partition_size = " f"{output_partition_size} is not divisible by " - f"weight quantization block_n = {block_n}.") + f"weight quantization block_n = {block_n}." 
+ ) def create_fp8_weight_parameter( - output_size_per_partition: int, input_size_per_partition: int, - weight_loader: Optional[Callable]) -> torch.nn.Parameter: + output_size_per_partition: int, + input_size_per_partition: int, + weight_loader: Optional[Callable], +) -> torch.nn.Parameter: """Create FP8 weight parameter.""" from vllm.model_executor.parameter import ModelWeightParameter - return ModelWeightParameter(data=torch.empty(output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) + return ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) def create_fp8_scale_parameter( - parameter_type: torch.nn.Parameter, output_partition_sizes: list[int], - input_size_per_partition: int, block_size: Optional[list[int]], - weight_loader: Optional[Callable]) -> torch.nn.Parameter: + parameter_type: torch.nn.Parameter, + output_partition_sizes: list[int], + input_size_per_partition: int, + block_size: Optional[list[int]], + weight_loader: Optional[Callable], +) -> torch.nn.Parameter: """Create scale parameter based on quantization strategy.""" if parameter_type == ChannelQuantScaleParameter: - scale = parameter_type(data=torch.empty( - (sum(output_partition_sizes), 1), dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader) + scale = parameter_type( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) elif parameter_type == BlockQuantScaleParameter: assert block_size is not None block_n, block_k = block_size[0], block_size[1] @@ -963,9 +1058,10 @@ def create_fp8_scale_parameter( weight_loader=weight_loader, ) elif parameter_type == PerTensorScaleParameter: - scale = parameter_type(data=torch.empty(len(output_partition_sizes), - dtype=torch.float32), - weight_loader=weight_loader) + scale = parameter_type( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) else: raise ValueError(f"Unknown parameter type: {parameter_type}") @@ -974,14 +1070,15 @@ def create_fp8_scale_parameter( def create_fp8_input_scale( - output_partition_sizes: list[int], - weight_loader: Optional[Callable]) -> torch.nn.Parameter: + output_partition_sizes: list[int], weight_loader: Optional[Callable] +) -> torch.nn.Parameter: """Create input scale parameter for static activation quantization.""" from vllm.model_executor.parameter import PerTensorScaleParameter - scale = PerTensorScaleParameter(data=torch.empty( - len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader) + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) scale[:] = torch.finfo(torch.float32).min return scale @@ -990,15 +1087,18 @@ def process_fp8_weight_tensor_strategy( weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: list[int], - input_scale: Optional[torch.Tensor] = None + input_scale: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Process weights for tensor-wise quantization strategy.""" from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) + normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale, + ) if current_platform.is_fp8_fnuz(): 
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, weight_scale=weight_scale, input_scale=input_scale) + weight=weight, weight_scale=weight_scale, input_scale=input_scale + ) # Requantize with max scale weight_scale, weight = requantize_with_max_scale( @@ -1014,15 +1114,17 @@ def process_fp8_weight_tensor_strategy( def process_fp8_weight_channel_strategy( weight: torch.Tensor, weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None + input_scale: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Process weights for channel-wise quantization strategy.""" from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - normalize_e4m3fn_to_e4m3fnuz) + normalize_e4m3fn_to_e4m3fnuz, + ) if current_platform.is_fp8_fnuz(): weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, weight_scale=weight_scale, input_scale=input_scale) + weight=weight, weight_scale=weight_scale, input_scale=input_scale + ) return weight, weight_scale, input_scale @@ -1033,37 +1135,48 @@ def process_fp8_weight_block_strategy( ) -> tuple[torch.Tensor, torch.Tensor]: """Process weights for block-wise quantization strategy.""" from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - normalize_e4m3fn_to_e4m3fnuz) + normalize_e4m3fn_to_e4m3fnuz, + ) if current_platform.is_fp8_fnuz(): weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, weight_scale=weight_scale) + weight=weight, weight_scale=weight_scale + ) weight = _maybe_pad_fp8_weight(weight) return weight, weight_scale -def maybe_post_process_fp8_weight_block(layer: torch.nn.Module, - cutlass_block_fp8_supported: bool): +def maybe_post_process_fp8_weight_block( + layer: torch.nn.Module, cutlass_block_fp8_supported: bool +): assert layer.weight_block_size is not None - from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, - should_use_deepgemm_for_fp8_linear) + from vllm.utils.deep_gemm import ( + is_deep_gemm_e8m0_used, + should_use_deepgemm_for_fp8_linear, + ) # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to # requantize the weight and input to the specific scale # at the same time. 
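# A small sketch of the UE8M0 idea mentioned above: per-block scales are forced
# to powers of two (assumed here to be rounded up so requantized values cannot
# overflow); per_block_cast_to_fp8(..., use_ue8m0=True) is the real helper.
import torch

def to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    # round each positive scale up to the nearest power of two
    return torch.exp2(torch.ceil(torch.log2(scale)))

block = torch.randn(128, 128)
scale = block.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
scale_ue8m0 = to_ue8m0(scale)
q = (block / scale_ue8m0).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
# dequantized values should stay close to the originals
max_err = (q.float() * scale_ue8m0 - block).abs().max()
assert max_err <= 0.1 * block.abs().amax()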
should_use_deepgemm = should_use_deepgemm_for_fp8_linear( - layer.orig_dtype, layer.weight) + layer.orig_dtype, layer.weight + ) if is_deep_gemm_e8m0_used() and should_use_deepgemm: block_sz = tuple(layer.weight_block_size) - requant_weight_ue8m0_inplace(layer.weight.data, - layer.weight_scale.data, block_sz) + requant_weight_ue8m0_inplace( + layer.weight.data, layer.weight_scale.data, block_sz + ) # SM90 Block FP8 CUTLASS requires row-major weight scales - elif (current_platform.is_device_capability(90) - and cutlass_block_fp8_supported and not should_use_deepgemm): + elif ( + current_platform.is_device_capability(90) + and cutlass_block_fp8_supported + and not should_use_deepgemm + ): layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.data.T.contiguous(), requires_grad=False) + layer.weight_scale.data.T.contiguous(), requires_grad=False + ) def expert_weight_is_col_major(x: torch.Tensor) -> bool: diff --git a/vllm/model_executor/layers/quantization/utils/gptq_utils.py b/vllm/model_executor/layers/quantization/utils/gptq_utils.py index fd76af230620..6209dda955ce 100644 --- a/vllm/model_executor/layers/quantization/utils/gptq_utils.py +++ b/vllm/model_executor/layers/quantization/utils/gptq_utils.py @@ -9,10 +9,11 @@ import regex as re import torch -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, UnquantizedEmbeddingMethod) + ParallelLMHead, + UnquantizedEmbeddingMethod, +) if TYPE_CHECKING: from ..gptq import GPTQConfig @@ -25,16 +26,13 @@ # Match dynamic rules with module name (prefix) and override quantize # config if module (prefix) matches a rule def override_config(config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str): - weight_bits = get_dynamic_override(config, prefix, "bits", - config.weight_bits) + weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits) if isinstance(weight_bits, int): config.weight_bits = weight_bits - group_size = get_dynamic_override(config, prefix, "group_size", - config.group_size) + group_size = get_dynamic_override(config, prefix, "group_size", config.group_size) if isinstance(group_size, int): config.group_size = group_size - desc_act = get_dynamic_override(config, prefix, "desc_act", - config.desc_act) + desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act) if isinstance(desc_act, bool): config.desc_act = desc_act @@ -46,25 +44,27 @@ def override_config(config: Union[GPTQConfig, GPTQMarlinConfig], prefix: str): config.is_sym = is_sym if (config.weight_bits, config.is_sym) not in config.TYPE_MAP: - raise ValueError("Unsupported quantization config: " - f"bits={config.weight_bits}, sym={config.is_sym}") + raise ValueError( + "Unsupported quantization config: " + f"bits={config.weight_bits}, sym={config.is_sym}" + ) - config.quant_type = config.TYPE_MAP[(config.weight_bits, - config.is_sym)] + config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)] elif config.get_name() == "gptq": assert isinstance(config, GPTQConfig) if config.weight_bits not in [2, 3, 4, 8]: raise ValueError( "Currently, only 2/3/4/8-bit weight quantization is " - f"supported for GPTQ, but got {config.weight_bits} bits.") + f"supported for GPTQ, but got {config.weight_bits} bits." 
+ ) def get_dynamic_override( config: Union[GPTQConfig, GPTQMarlinConfig], layer_name: str, key: Optional[str] = None, - default_value: Union[int, bool, - None] = None) -> Union[dict, int, bool, None]: + default_value: Union[int, bool, None] = None, +) -> Union[dict, int, bool, None]: for pattern, pattern_dict in config.dynamic.items(): # Negative match: matched modules are excluded from quantized init if pattern.startswith("-:"): @@ -83,7 +83,7 @@ def get_dynamic_override( def is_layer_gptq_quantized( prefix: str, quantized_layers: list[str], - fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj @@ -106,8 +106,9 @@ def is_layer_gptq_quantized( is_quantized = None for shard_prefix in shard_prefixes: - is_shard_quantized = any(layer in shard_prefix - for layer in quantized_layers) + is_shard_quantized = any( + layer in shard_prefix for layer in quantized_layers + ) if is_quantized is None: is_quantized = is_shard_quantized @@ -115,7 +116,8 @@ def is_layer_gptq_quantized( raise ValueError( f"Detected some but not all shards of {prefix} " "are quantized. All shards of fused layers " - "to have the same precision.") + "to have the same precision." + ) else: is_quantized = any(layer in prefix for layer in quantized_layers) @@ -130,18 +132,20 @@ def get_linear_quant_method( linear_method_cls: type, ): cloned_config = deepcopy(config) - parallel_lm_head_quantized = isinstance( - layer, ParallelLMHead) and cloned_config.lm_head_quantized + parallel_lm_head_quantized = ( + isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized + ) if isinstance(layer, LinearBase) or parallel_lm_head_quantized: is_layer_quantized = is_layer_gptq_quantized( prefix=prefix, quantized_layers=cloned_config.modules_in_block_to_quantize, - fused_mapping=cloned_config.packed_modules_mapping) + fused_mapping=cloned_config.packed_modules_mapping, + ) # False = skip module, None = no override, else = Positive match if get_dynamic_override( # noqa: E712 - cloned_config, # noqa: E712 - layer_name=prefix) == False or ( - not is_layer_quantized): # noqa: E712 + cloned_config, # noqa: E712 + layer_name=prefix, + ) == False or (not is_layer_quantized): # noqa: E712 if parallel_lm_head_quantized: return UnquantizedEmbeddingMethod() return UnquantizedLinearMethod() diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py index 62e458ec3c93..1b8efe4332c5 100644 --- a/vllm/model_executor/layers/quantization/utils/int8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py @@ -30,12 +30,9 @@ def apply_w8a8_block_int8_linear( output_shape = [*input.shape[:-1], weight.shape[0]] q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1]) - output = w8a8_block_int8_matmul(q_input, - weight, - x_scale, - weight_scale, - block_size, - output_dtype=input.dtype) + output = w8a8_block_int8_matmul( + q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype + ) if bias is not None: output = output + bias @@ -43,8 +40,8 @@ def apply_w8a8_block_int8_linear( def input_to_int8( - x: torch.Tensor, - dtype: torch.dtype = torch.int8) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, dtype: torch.dtype = torch.int8 +) -> tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to int8 values with tensor-wise quantization.""" iinfo = 
torch.iinfo(dtype) @@ -78,8 +75,8 @@ def block_dequant( for i in range(k_tiles): for j in range(n_tiles): x_dq_block[ - j * block_n:min((j + 1) * block_n, n), - i * block_k:min((i + 1) * block_k, k), + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), ] *= x_s[j][i] return x_dq_block @@ -91,15 +88,17 @@ def block_dequant( # NOTE: This can be removed when hip.libdevice.round() is available. @core.extern def round_f32(arg0, _builder=None): - return core.extern_elementwise("", - "", [arg0], { - (core.dtype("fp32"), ): - ("llvm.round", core.dtype("fp32")), - (core.dtype("fp64"), ): - ("llvm.round", core.dtype("fp64")), - }, - is_pure=True, - _builder=_builder) + return core.extern_elementwise( + "", + "", + [arg0], + { + (core.dtype("fp32"),): ("llvm.round", core.dtype("fp32")), + (core.dtype("fp64"),): ("llvm.round", core.dtype("fp64")), + }, + is_pure=True, + _builder=_builder, + ) @triton.jit def round_int8(x): @@ -127,8 +126,7 @@ def _per_token_quant_int8( cols = tl.arange(0, BLOCK) mask = cols < N - x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, - other=0.0).to(tl.float32) + x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32) absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) scale_x = absmax / 127 x_q = x * (127 / absmax) @@ -142,15 +140,13 @@ def per_token_quant_int8(x): M = x.numel() // x.shape[-1] N = x.shape[-1] x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) - scales = torch.empty(x.shape[:-1] + (1, ), - device=x.device, - dtype=torch.float32) + scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32) BLOCK = triton.next_power_of_2(N) # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) assert x.is_contiguous() - _per_token_quant_int8[(M, )]( + _per_token_quant_int8[(M,)]( x, x_q, scales, @@ -229,8 +225,9 @@ def per_token_group_quant_int8( tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. 
""" - assert (x.shape[-1] % group_size == 0 - ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.shape[-1] % group_size == 0, ( + "the last dimension of `x` cannot be divisible by `group_size`" + ) assert x.is_contiguous(), "`x` is not contiguous" iinfo = torch.iinfo(dtype) @@ -239,15 +236,15 @@ def per_token_group_quant_int8( x_q = torch.empty_like(x, device=x.device, dtype=dtype) x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size, ), + x.shape[:-1] + (x.shape[-1] // group_size,), device=x.device, dtype=torch.float32, ) # prefer CUDA kernel if available if current_platform.is_cuda(): - torch.ops._C.per_token_group_quant_int8(x, x_q, x_s, group_size, eps, - float(int8_min), - float(int8_max)) + torch.ops._C.per_token_group_quant_int8( + x, x_q, x_s, group_size, eps, float(int8_min), float(int8_max) + ) return x_q, x_s M = x.numel() // group_size @@ -257,7 +254,7 @@ def per_token_group_quant_int8( # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 - _per_token_group_quant_int8[(M, )]( + _per_token_group_quant_int8[(M,)]( x, x_q, x_s, @@ -333,20 +330,15 @@ def _w8a8_block_int8_matmul( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, - mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k a_s = tl.load(As_ptrs + offs_ks * stride_As_k) b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) - accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, - None] * b_s[None, :] + accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -365,8 +357,9 @@ def _w8a8_block_int8_matmul( @functools.lru_cache -def get_w8a8_block_int8_configs(N: int, K: int, block_n: int, - block_k: int) -> Optional[dict[int, Any]]: +def get_w8a8_block_int8_configs( + N: int, K: int, block_n: int, block_k: int +) -> Optional[dict[int, Any]]: """ Return optimized configurations for the w8a8 block fp8 kernel. @@ -382,7 +375,8 @@ def get_w8a8_block_int8_configs(N: int, K: int, block_n: int, json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" # noqa: E501 config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) if os.path.exists(config_file_path): with open(config_file_path) as f: logger.info( @@ -395,8 +389,10 @@ def get_w8a8_block_int8_configs(N: int, K: int, block_n: int, # If no optimized configuration is available, we will use the default # configuration logger.warning( - ("Using default W8A8 Block INT8 kernel config. Performance might " - "be sub-optimal! Config file not found at %s"), + ( + "Using default W8A8 Block INT8 kernel config. Performance might " + "be sub-optimal! 
Config file not found at %s" + ), config_file_path, ) return None @@ -441,7 +437,7 @@ def w8a8_block_int8_matmul( assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N, ) + C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1]) @@ -462,8 +458,9 @@ def w8a8_block_int8_matmul( } def grid(META): - return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * - triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) _w8a8_block_int8_matmul[grid]( A, diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py index fbc0f23acb59..4bf31340a2f6 100644 --- a/vllm/model_executor/layers/quantization/utils/layer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -20,12 +20,15 @@ def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): # Newly generated tensors need to replace existing tensors that are # already registered as parameters by vLLM (and won't be freed) -def replace_parameter(mod: torch.nn.Module, name: str, - new: Union[torch.Tensor, torch.nn.Parameter]) -> None: - +def replace_parameter( + mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter] +) -> None: old = getattr(mod, name) - if type(old) is type(new) and old.dtype == new.dtype and \ - old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + if ( + type(old) is type(new) + and old.dtype == new.dtype + and old.untyped_storage().nbytes() == new.untyped_storage().nbytes() + ): # If we can just update in-place to avoid re-registering # can be faster if the underlying storage is the same update_tensor_inplace(old, new) @@ -36,5 +39,4 @@ def replace_parameter(mod: torch.nn.Module, name: str, # parameters for `torch.compile` compatibility if not isinstance(new, torch.nn.Parameter): new = torch.nn.Parameter(new, requires_grad=False) - mod.register_parameter(name, - torch.nn.Parameter(new, requires_grad=False)) + mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py index fbb850d22776..69466bdcb64c 100644 --- a/vllm/model_executor/layers/quantization/utils/machete_utils.py +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -39,12 +39,19 @@ def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]: return [-1, 128] -def check_machete_supports_shape(in_features: int, out_featrues: int) \ - -> tuple[bool, Optional[str]]: +def check_machete_supports_shape( + in_features: int, out_featrues: int +) -> tuple[bool, Optional[str]]: if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: - return False, "Input features size must be divisible by "\ - f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + return ( + False, + "Input features size must be divisible by " + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}", + ) if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: - return False, "Output features size must be divisible by "\ - f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return ( + False, + "Output features size must be divisible by " + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}", + ) return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py 
b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 6c7604cc9d04..d2fa5af1b854 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -40,8 +40,9 @@ def query_marlin_supported_quant_types( ): if device_capability is None: capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) if device_capability < 80: return [] @@ -50,10 +51,12 @@ def query_marlin_supported_quant_types( # - has_zp is False: return quant_types that has not zero points # - has_zp is None: both if has_zp is None: - types0 = query_marlin_supported_quant_types(False, include_fp_type, - device_capability) - types1 = query_marlin_supported_quant_types(True, include_fp_type, - device_capability) + types0 = query_marlin_supported_quant_types( + False, include_fp_type, device_capability + ) + types1 = query_marlin_supported_quant_types( + True, include_fp_type, device_capability + ) return types0 + types1 if has_zp: @@ -68,108 +71,126 @@ def query_marlin_supported_quant_types( def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: - + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = current_platform.get_device_capability() - device_capability = (-1 if capability_tuple is None else - capability_tuple.to_int()) + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) supported_types = query_marlin_supported_quant_types( - has_zp, True, device_capability) + has_zp, True, device_capability + ) if quant_type not in supported_types: - return (False, f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).") - if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): - return (False, f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.") + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) return True, None -def check_marlin_supported(quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, - device_capability) +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) return cond -def verify_marlin_supported(quant_type: ScalarType, - group_size: int, - has_zp: bool = False) -> None: +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape(output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, group_size: int) -> None: - +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError(f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError(f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") - - if (group_size < input_size - and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq.") + "with --quantization gptq." 
+ ) -def check_marlin_supports_shape(output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, group_size: int) \ - -> tuple[bool, Optional[str]]: +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape(output_size_per_partition, - input_size_per_partition, input_size, - group_size) + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) except ValueError as e: return False, e.__str__() return True, None -def check_marlin_supports_layer(layer: LinearBase, group_size: int) \ - -> bool: - output_size_per_partition = getattr(layer, "output_size_per_partition", - None) or layer.output_size - input_size_per_partition = getattr(layer, "input_size_per_partition", - None) or layer.input_size +def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: + output_size_per_partition = ( + getattr(layer, "output_size_per_partition", None) or layer.output_size + ) + input_size_per_partition = ( + getattr(layer, "input_size_per_partition", None) or layer.input_size + ) return check_marlin_supports_shape( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=layer.input_size, - group_size=group_size)[0] + group_size=group_size, + )[0] -def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ - -> bool: +def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: hidden_size = layer.hidden_size intermediate_size_per_partition = layer.intermediate_size_per_partition # apply_router_weight_on_input is not supported for moe marlin @@ -180,51 +201,58 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \ # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) # down: (n, k) = (hidden_size, intermediate_size_per_partition) # moe marlin requires n % 128 == 0 and k % 64 == 0 - supports_shape = hidden_size % 128 == 0 and \ - intermediate_size_per_partition % max(64, group_size) == 0 + supports_shape = ( + hidden_size % 128 == 0 + and intermediate_size_per_partition % max(64, group_size) == 0 + ) supports_group_size = group_size in [-1, 32, 64, 128] - return supports_shape and supports_group_size and \ - supports_router_weight and supports_activation + return ( + supports_shape + and supports_group_size + and supports_router_weight + and supports_activation + ) -def marlin_moe_intermediate_size(w1_packed: torch.Tensor, - w2_packed: torch.Tensor): +def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tensor): """ Given Marlin packed weight matrices w1_packed, and w2_packed, - return the MoE intermediate size N + return the MoE intermediate size N """ marlin_tile_size = 16 return w2_packed.size(1) * marlin_tile_size -def marlin_make_workspace(output_size_per_partition: int, - device: torch.device) -> torch.Tensor: - max_workspace_size = (output_size_per_partition // - GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros(max_workspace_size, - dtype=torch.int, - device=device, - requires_grad=False) + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, 
requires_grad=False + ) -def marlin_make_workspace_new(device: torch.device, - max_blocks_per_sm: int = 1) -> torch.Tensor: +def marlin_make_workspace_new( + device: torch.device, max_blocks_per_sm: int = 1 +) -> torch.Tensor: # In the new marlin kernel, we use the num of threadblocks as workspace # size. The num of threadblocks is sms_count * max_blocks_per_sm. sms = torch.cuda.get_device_properties(device).multi_processor_count - return torch.zeros(sms * max_blocks_per_sm, - dtype=torch.int, - device=device, - requires_grad=False) + return torch.zeros( + sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False + ) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, - is_row_parallel: bool) -> bool: +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -232,17 +260,18 @@ def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), - requires_grad=False) + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), - requires_grad=False) + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) -def marlin_sort_g_idx( - g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices @@ -253,14 +282,13 @@ def get_scale_perms(): scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int) -> torch.Tensor: - +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] @@ -296,8 +324,9 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -318,8 +347,9 @@ def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, return zp -def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, - size_n: int, num_bits: int) -> torch.Tensor: +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: # AWQ zero-points are quantized and packed 
on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -341,8 +371,9 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp -def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, - size_n: int, num_bits: int): +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -350,8 +381,7 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, - num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) return output @@ -363,7 +393,8 @@ def maybe_warn_marlin_atomic_add(device, dtype): logger.info_once( "You are running Marlin kernel with bf16 on GPUs before SM90. " "You can consider change to fp16 to achieve better performance " - "if possible.") + "if possible." + ) def maybe_warn_marlin_atomic_add_env(): @@ -375,12 +406,13 @@ def maybe_warn_marlin_atomic_add_env(): "Marlin kernel can achieve better performance for small size_n " "with experimental use_atomic_add feature. " "You can consider set environment variable " - "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") - + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible." + ) -def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, - dtype: torch.dtype) -> bool: +def should_use_atomic_add_reduce( + m: int, n: int, k: int, device: torch.device, dtype: torch.dtype +) -> bool: # the performance of atomicAdd is better than global reduce # only when m*n is small and k is large if n >= 2048 or k < 2048 or device.type != "cuda": @@ -402,88 +434,98 @@ def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition, ) - - use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), - n=output_size_per_partition, - k=reshaped_x.size(1), - device=input.device, - dtype=input.dtype) - - output = ops.gptq_marlin_gemm(reshaped_x, - None, - weight, - bias, - weight_scale, - None, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False) + out_shape = input.shape[:-1] + 
(output_size_per_partition,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype, + ) + + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + bias, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) return output.reshape(out_shape) def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition, ) - - use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), - n=output_size_per_partition, - k=reshaped_x.size(1), - device=input.device, - dtype=input.dtype) - - output = ops.gptq_marlin_gemm(reshaped_x, - None, - weight, - bias, - weight_scale, - None, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype, + ) + + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + bias, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py index 94ffdcd26ecd..c5e34f392fb2 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py @@ -8,8 +8,12 @@ import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, - marlin_permute_scales, should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + should_use_atomic_add_reduce, +) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types 
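The apply_gptq_marlin_linear and apply_awq_marlin_linear wrappers reformatted above share one pattern: flatten every leading activation dimension into a single (m, k) view, run the quantized GEMM, then restore the original shape with the partitioned output size appended. A minimal sketch of that flatten/restore pattern, with a plain matmul standing in for ops.gptq_marlin_gemm and all shapes chosen purely for illustration:

import torch

def linear_2d_wrapper(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Collapse all leading dims (batch, seq, ...) into one "m" dimension.
    x2d = x.reshape(-1, x.shape[-1])            # (m, k)
    out_shape = x.shape[:-1] + (w.shape[0],)    # (..., n)
    y = x2d @ w.t()                             # stand-in for the Marlin GEMM
    return y.reshape(out_shape)

x = torch.randn(2, 5, 64)    # (batch, seq, k) -- illustrative shapes only
w = torch.randn(128, 64)     # (n, k)
assert linear_2d_wrapper(x, w).shape == (2, 5, 128)

The same reshape bookkeeping is what out_shape = input.shape[:-1] + (output_size_per_partition,) expresses in the patched functions.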
@@ -28,7 +32,8 @@ def nvfp4_marlin_process_scales(marlin_scales): "NVFP4 Marlin assumes the scales to be >=0, but has encountered " "negative scales. Accuracy will likely be degraded. This is " "because it changes the scales from FP8-S1E4M3 to a special " - "FP8-S0E5M3 format to speedup the dequantization.") + "FP8-S0E5M3 format to speedup the dequantization." + ) # convert to half first, we would convert to fp8 later marlin_scales = marlin_scales.to(torch.half) @@ -36,11 +41,13 @@ def nvfp4_marlin_process_scales(marlin_scales): # 8 is the number of scale number using by one thread marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( - marlin_scales.size(0) * 2, -1) + marlin_scales.size(0) * 2, -1 + ) # fit the layout of fp8 dequantization marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( - marlin_scales.size(0), -1) + marlin_scales.size(0), -1 + ) # We assume that weight_scale (FP8-S1E4M3) is always greater # than or equal to 0. So we can convert @@ -60,11 +67,13 @@ def mxfp4_marlin_process_scales(marlin_scales): # 8 is the number of scale number using by one thread marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( - marlin_scales.size(0) * 2, -1) + marlin_scales.size(0) * 2, -1 + ) # fit the layout of fp8 dequantization marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( - marlin_scales.size(0), -1) + marlin_scales.size(0), -1 + ) marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) return marlin_scales @@ -78,48 +87,49 @@ def nvfp4_marlin_process_global_scale(global_scale): target_exponent = 8 # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 - exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) - return global_scale * (2.0**(exponent_bias - 7)) + exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp4_exponent - 1) + return global_scale * (2.0 ** (exponent_bias - 7)) def apply_fp4_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_scale_2: Optional[torch.Tensor], - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: Optional[torch.Tensor], + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: # For GPUs that lack FP4 hardware support, we can leverage the # Marlin kernel for fast weight-only FP4 quantization reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n, ) - - use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), - n=size_n, - k=size_k, - device=input.device, - dtype=input.dtype) - - output = ops.gptq_marlin_gemm(a=reshaped_x, - c=None, - b_q_weight=weight, - b_bias=bias, - b_scales=weight_scale, - global_scale=weight_scale_2, - b_zeros=None, - g_idx=None, - perm=None, - workspace=workspace, - b_q_type=scalar_types.float4_e2m1f, - size_m=reshaped_x.size(0), - size_n=size_n, - size_k=size_k, - use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce) + out_shape = input.shape[:-1] + (size_n,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype + 
) + + output = ops.gptq_marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=weight, + b_bias=bias, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + ) return output.reshape(out_shape) @@ -129,7 +139,8 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "Your GPU does not have native support for FP4 computation but " "FP4 quantization is being used. Weight-only FP4 compression will " "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") + "performance for compute-heavy workloads." + ) is_nvfp4 = hasattr(layer, "weight_scale_2") group_size = 16 if is_nvfp4 else 32 @@ -150,11 +161,13 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: perm = torch.empty(0, dtype=torch.int, device=device) qweight = layer.weight.view(torch.int32).T.contiguous() - marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, - perm=perm, - size_k=part_size_k, - size_n=part_size_n, - num_bits=4) + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4, + ) layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) # WEIGHT SCALES @@ -165,27 +178,23 @@ def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: weight_scale = weight_scale.view(torch.float8_e8m0fnu) weight_scale = weight_scale.to(param_dtype) - weight_scale = marlin_permute_scales(s=weight_scale, - size_k=part_size_k, - size_n=part_size_n, - group_size=group_size) + weight_scale = marlin_permute_scales( + s=weight_scale, size_k=part_size_k, size_n=part_size_n, group_size=group_size + ) if is_nvfp4: weight_scale = nvfp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) weight_scale_2 = layer.weight_scale_2.to(param_dtype) weight_scale_2 = nvfp4_marlin_process_global_scale(weight_scale_2) - layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, - requires_grad=False) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, requires_grad=False) else: weight_scale = mxfp4_marlin_process_scales(weight_scale) - layer.weight_scale = torch.nn.Parameter(weight_scale, - requires_grad=False) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) if hasattr(layer, "bias") and layer.bias is not None: - assert layer.bias.shape == (part_size_n, ) + assert layer.bias.shape == (part_size_n,) bias = marlin_permute_bias(layer.bias) layer.bias = torch.nn.Parameter(bias, requires_grad=False) @@ -197,7 +206,8 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: "Your GPU does not have native support for FP4 computation but " "FP4 quantization is being used. Weight-only FP4 compression will " "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") + "performance for compute-heavy workloads." 
+ ) is_nvfp4 = hasattr(layer, "w13_weight_scale_2") group_size = 16 if is_nvfp4 else 32 @@ -227,11 +237,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: for i in range(e): qweight = weight[i].view(torch.int32).T.contiguous() - marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, - perm=perm, - size_k=size_k, - size_n=size_n, - num_bits=4) + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=4 + ) tensor_list.append(marlin_qweight) weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) @@ -247,8 +255,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: scales = scales.view(torch.float8_e8m0fnu) scales = scales.to(param_dtype) if is_nvfp4: - global_scale = getattr(layer, - name + "_weight_scale_2").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) tensor_list = [] if "w13" in name: @@ -259,10 +266,9 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: for i in range(e): scale = scales[i].T - marlin_scales = marlin_permute_scales(s=scale, - size_k=size_k, - size_n=size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scale, size_k=size_k, size_n=size_n, group_size=group_size + ) if is_nvfp4: marlin_scales = nvfp4_marlin_process_scales(marlin_scales) else: @@ -275,8 +281,7 @@ def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: if is_nvfp4: global_scale = nvfp4_marlin_process_global_scale(global_scale) - global_scale = torch.nn.Parameter(global_scale, - requires_grad=False) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) setattr(layer, name + "_weight_scale_2", global_scale) # BIAS @@ -306,26 +311,26 @@ def rand_marlin_weight_nvfp4_like(weight, group_size): global_scale = scales.max() / 448 scales = (scales / global_scale).to(torch.float8_e4m3fn) - fp4_weight = torch.randint(0, - 256, (size_n, size_k // 2), - dtype=torch.uint8, - device=weight.device) - fp4_weight_part_1 = ((fp4_weight & 0b10000000) | - ((fp4_weight & 0b01110000) >> 2)) + fp4_weight = torch.randint( + 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=weight.device + ) + fp4_weight_part_1 = (fp4_weight & 0b10000000) | ((fp4_weight & 0b01110000) >> 2) fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) fp4_weight2 = fp4_weight << 4 - fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | - ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = (fp4_weight2 & 0b10000000) | ((fp4_weight2 & 0b01110000) >> 2) fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) weight_ref = torch.cat( - [fp4_weight_part_2.unsqueeze(2), - fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) - weight_ref = weight_ref * global_scale.to(weight.dtype) * \ - scales.repeat_interleave(group_size, 1).to(weight.dtype) + [fp4_weight_part_2.unsqueeze(2), fp4_weight_part_1.unsqueeze(2)], 2 + ).view(size_n, size_k) + weight_ref = ( + weight_ref + * global_scale.to(weight.dtype) + * scales.repeat_interleave(group_size, 1).to(weight.dtype) + ) marlin_qweight = ops.gptq_marlin_repack( b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), @@ -335,10 +340,9 @@ def rand_marlin_weight_nvfp4_like(weight, group_size): num_bits=4, ) - marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), - size_k=size_k, - size_n=size_n, - group_size=group_size) + 
marlin_scales = marlin_permute_scales( + s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size + ) marlin_scales = nvfp4_marlin_process_scales(marlin_scales) global_scale = nvfp4_marlin_process_global_scale(global_scale) @@ -351,32 +355,31 @@ def rand_marlin_weight_mxfp4_like(weight, group_size): size_n, size_k = weight.shape device = weight.device - scales = torch.randint(100, - 125, (size_n, size_k // group_size), - dtype=torch.uint8, - device=weight.device) + scales = torch.randint( + 100, + 125, + (size_n, size_k // group_size), + dtype=torch.uint8, + device=weight.device, + ) scales = scales.view(torch.float8_e8m0fnu) - fp4_weight = torch.randint(0, - 256, (size_n, size_k // 2), - dtype=torch.uint8, - device=weight.device) - fp4_weight_part_1 = ((fp4_weight & 0b10000000) | - ((fp4_weight & 0b01110000) >> 2)) + fp4_weight = torch.randint( + 0, 256, (size_n, size_k // 2), dtype=torch.uint8, device=weight.device + ) + fp4_weight_part_1 = (fp4_weight & 0b10000000) | ((fp4_weight & 0b01110000) >> 2) fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) fp4_weight2 = fp4_weight << 4 - fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | - ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = (fp4_weight2 & 0b10000000) | ((fp4_weight2 & 0b01110000) >> 2) fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) weight_ref = torch.cat( - [fp4_weight_part_2.unsqueeze(2), - fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) - weight_ref = weight_ref * \ - scales.repeat_interleave(group_size, 1).to(weight.dtype) + [fp4_weight_part_2.unsqueeze(2), fp4_weight_part_1.unsqueeze(2)], 2 + ).view(size_n, size_k) + weight_ref = weight_ref * scales.repeat_interleave(group_size, 1).to(weight.dtype) marlin_qweight = ops.gptq_marlin_repack( b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), @@ -386,10 +389,9 @@ def rand_marlin_weight_mxfp4_like(weight, group_size): num_bits=4, ) - marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), - size_k=size_k, - size_n=size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scales.T.to(weight.dtype), size_k=size_k, size_n=size_n, group_size=group_size + ) marlin_scales = mxfp4_marlin_process_scales(marlin_scales) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index 511e19545d5a..9348ac158daa 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -8,8 +8,12 @@ import vllm._custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_bias, - marlin_permute_scales, should_use_atomic_add_reduce) + USE_FP32_REDUCE_DEFAULT, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + should_use_atomic_add_reduce, +) from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -28,60 +32,63 @@ def fp8_fused_exponent_bias_into_scales(scales): target_exponent = 8 # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 - exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1) s = 
torch.ones_like(scales) * 2 s = s**exponent_bias return scales * s def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n, ) - - use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), - n=size_n, - k=size_k, - device=input.device, - dtype=input.dtype) - - output = ops.gptq_marlin_gemm(a=reshaped_x, - c=None, - b_q_weight=weight, - b_bias=bias, - b_scales=weight_scale, - global_scale=None, - b_zeros=None, - g_idx=None, - perm=None, - workspace=workspace, - b_q_type=scalar_types.float8_e4m3fn, - size_m=reshaped_x.size(0), - size_n=size_n, - size_k=size_k, - use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce) + out_shape = input.shape[:-1] + (size_n,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype + ) + + output = ops.gptq_marlin_gemm( + a=reshaped_x, + c=None, + b_q_weight=weight, + b_bias=bias, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + ) return output.reshape(out_shape) -def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, - size_k_first: bool = True) -> None: +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, size_k_first: bool = True +) -> None: logger.warning_once( "Your GPU does not have native support for FP8 computation but " "FP8 quantization is being used. Weight-only FP8 compression will " "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") + "performance for compute-heavy workloads." 
+ ) part_size_n = layer.output_size_per_partition part_size_k = layer.input_size_per_partition @@ -104,11 +111,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, if not size_k_first: qweight = qweight.T.contiguous() - marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, - perm=perm, - size_k=part_size_k, - size_n=part_size_n, - num_bits=8) + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) # WEIGHT SCALES @@ -151,26 +160,27 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, # size_n may not divisible by block_size[0] scales = scales[:, :part_size_n] - marlin_scales = marlin_permute_scales(s=scales, - size_k=part_size_k, - size_n=part_size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size + ) marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) if hasattr(layer, "bias") and layer.bias is not None: - assert layer.bias.shape == (part_size_n, ) + assert layer.bias.shape == (part_size_n,) bias = marlin_permute_bias(layer.bias) layer.bias = torch.nn.Parameter(bias, requires_grad=False) -def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, - size_k_first: bool = True) -> None: +def prepare_moe_fp8_layer_for_marlin( + layer: torch.nn.Module, size_k_first: bool = True +) -> None: logger.warning_once( "Your GPU does not have native support for FP8 computation but " "FP8 quantization is being used. Weight-only FP8 compression will " "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads.") + "performance for compute-heavy workloads." 
+ ) e = layer.num_experts k = layer.hidden_size @@ -202,11 +212,9 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, if not size_k_first: qweight = qweight.T.contiguous() - marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, - perm=perm, - size_k=size_k, - size_n=size_n, - num_bits=8) + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8 + ) tensor_list.append(marlin_qweight) weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) @@ -265,10 +273,9 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, scales = scales[..., :size_n].contiguous() for i in range(e): - marlin_scales = marlin_permute_scales(s=scales[i], - size_k=size_k, - size_n=size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size + ) tensor_list.append(marlin_scales) scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) @@ -295,8 +302,9 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module, setattr(layer, name, bias) -def pack_fp8_to_int32(fp8_tensor: torch.Tensor, - size_k_first: bool = True) -> torch.Tensor: +def pack_fp8_to_int32( + fp8_tensor: torch.Tensor, size_k_first: bool = True +) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ @@ -335,10 +343,9 @@ def marlin_quant_fp8_torch(weight, group_size): num_bits=8, ) - marlin_scales = marlin_permute_scales(s=scales.T, - size_k=size_k, - size_n=size_n, - group_size=group_size) + marlin_scales = marlin_permute_scales( + s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size + ) marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index f5acd03cc662..1bbd88d5ca71 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -9,24 +9,26 @@ from vllm.scalar_type import ScalarType -from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, - marlin_zero_points) -from .quant_utils import (get_pack_factor, gptq_quantize_weights, - quantize_weights, sort_weights) +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) class MarlinWorkspace: - def __init__(self, out_features, min_thread_n, max_parallel): - assert (out_features % min_thread_n == 0), ( + assert out_features % min_thread_n == 0, ( "out_features = {} is indivisible by min_thread_n = {}".format( - out_features, min_thread_n)) + out_features, min_thread_n + ) + ) - max_workspace_size = ((out_features // min_thread_n) * max_parallel) + max_workspace_size = (out_features // min_thread_n) * max_parallel - self.scratch = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda") + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): @@ -54,8 +56,7 @@ def marlin_weights(q_w, size_k, size_n, num_bits, perm): q_w = q_w.cpu().numpy().astype(np.uint32) - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=np.uint32) + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] 
<< num_bits * i @@ -71,10 +72,10 @@ def get_weight_perm(num_bits: int): col = i // 4 for block in [0, 1]: for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, ]: perm1.append(16 * row + col + 8 * block) for j in range(4): @@ -94,11 +95,13 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None): +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): size_k, size_n = w.shape num_bits = quant_type.size_bits @@ -109,7 +112,8 @@ def marlin_quantize(w: torch.Tensor, # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm) + w, quant_type, group_size, act_order, test_perm + ) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing @@ -130,8 +134,7 @@ def marlin_quantize(w: torch.Tensor, return res_list -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, - group_size: int): +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): size_k, size_n = w.shape # Normalize group_size @@ -144,18 +147,13 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, num_groups = size_k // group_size # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, - quant_type, - group_size, - zero_points=True) + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) # Reformat to marlin weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, - weight_perm) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, - quant_type.size_bits) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) # Create result res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py index 1c93c364679d..90011f116bb0 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -26,8 +26,7 @@ # matrix elements into reordered metadata matrix elements (or, # equivalently, for gathering reordered metadata matrix element back # into metadata matrix elements). 
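The semi-structured helpers below operate on 2:4 sparse weights, i.e. at most two non-zero values in every group of four consecutive entries along a row; mask_creator later in this file builds such masks by keeping the two largest-magnitude entries per group. A minimal sketch of that 2:4 selection, ignoring the metadata packing and reordering, with every name below invented for illustration:

import torch

def two_four_mask(w: torch.Tensor) -> torch.Tensor:
    # For every group of 4 consecutive values, zero the 2 smallest magnitudes.
    groups = w.detach().abs().reshape(-1, 4)
    drop = torch.argsort(groups, dim=1)[:, :2]   # indices of 2 smallest per group
    mask = torch.ones_like(groups).scatter_(dim=1, index=drop, value=0.0)
    return mask.reshape(w.shape)

w = torch.randn(8, 16)        # illustrative shape; numel divisible by 4
mask = two_four_mask(w)
assert torch.all(mask.reshape(-1, 4).sum(dim=1) == 2)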
-def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, - device): +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) @@ -35,9 +34,13 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, group_x = 64 group_y = 32 if meta_dtype.itemsize == 2 else 16 - dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + - (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + - ((dst_rows % group_x) // 8) * 4) + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) @@ -50,8 +53,7 @@ def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, interleave = 2 cols_maj = dst_cols // interleave cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + - cols_min).view(-1) + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) # This function converts dense matrix into sparse semi-structured @@ -75,17 +77,18 @@ def sparse_semi_structured_from_dense_cutlass(dense): raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError( - "Invalid number of elements per meta element calculated") + raise RuntimeError("Invalid number of elements per meta element calculated") if meta_dtype == torch.int32: if m % 16 != 0: raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16") + f"Number of rows of dense matrix {m} must be divisible by 16" + ) else: if m % 32 != 0: raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32") + f"Number of rows of dense matrix {m} must be divisible by 32" + ) if k % (4 * quadbits_per_meta_elem) != 0: raise RuntimeError( f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 @@ -146,40 +149,39 @@ def sparse_semi_structured_from_dense_cutlass(dense): idxs1 = bit2 | (bit3.to(torch.int64) << 1) if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) else: - sparse = dense_2.gather(-1, - idxs0.unsqueeze(-1) // 2).view( - m, - k // 2) # type: ignore[possibly-undefined] + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined] meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view( - (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) if quadbits_per_meta_elem == 4: - meta = (meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12)) + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) elif quadbits_per_meta_elem == 8: - meta = (meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) 
- | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28)) + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined] meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device) + m, meta_ncols, meta_dtype, device + ) meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) return (sparse, meta_reordered.view(m, meta_ncols)) @@ -222,13 +224,14 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: raise RuntimeError( f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix") + "expected according to the number of columns of meta matrix" + ) # Undo meta tensor elements reordering. meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device) - meta = torch.gather(meta_reordered.view(-1), 0, - meta_offsets).view(m, meta_ncols) + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) # Unpack sparse tensor back to original dense tensor, using # information provided by meta tensor. Note that torch.float @@ -270,16 +273,17 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): meta_2[:, :, 15] = (meta >> 30) & 0b11 dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( - -1, 1).repeat(1, 2).view(-1) + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) - dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) if sparse.dtype != torch.float: # dense.scatter_(0, dense_offsets, sparse.view(-1)) dense.scatter_(0, dense_offsets, sparse.reshape(-1)) else: - dense.view(torch.half).scatter_(0, dense_offsets, - sparse.view(torch.half).view(-1)) + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) return dense.view(m, 2 * k) @@ -287,8 +291,8 @@ def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): def mask_creator(tensor): """ Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask will correspond to the given tensor. 
:param N: The number of weights in a group to keep @@ -301,14 +305,14 @@ def mask_creator(tensor): # for i, tensor in enumerate(tensors): if tensor.numel() % M != 0: raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " - f"{M} groups") + f"Tensor of size {tensor.shape} can't be evenly divided into {M} groups" + ) num_groups = tensor.numel() // M # N:M sparsity for linear layers tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) @@ -342,7 +346,7 @@ def check_24(w, num_rows_to_sample=50, _verbose=False): for i in sampled_row_idxs: for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): total_segments += 1 - block = w[i, j:j + BLOCK_SIZE] + block = w[i, j : j + BLOCK_SIZE] num_nonzero = torch.count_nonzero(block) if num_nonzero > MAX_NON_ZEROS: print("i = {} j = {} block = {}".format(i, j, block)) @@ -359,8 +363,7 @@ def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): # Compress q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( - q_24_no_zp) + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() # Restore bias @@ -390,13 +393,12 @@ def get_weight_perm_24(num_bits: int): col_o = col // 2 for block in [0, 1]: for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + - 4 * block) + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) for j in range(4): perm_list.extend([p + 1 * j for p in perm1]) perm = numpy.array(perm_list) @@ -413,9 +415,9 @@ def get_weight_perm_24(num_bits: int): return perm -def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, - group_size: int) -> torch.Tensor: - +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms_24() if group_size < size_k and group_size != -1: s = s.reshape((-1, len(scale_perm)))[:, scale_perm] @@ -443,17 +445,18 @@ def marlin_24_quantize( # Quantize w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False) + w_24, quant_type, group_size, act_order=False + ) # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, - quant_type) + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) size_k_comp = size_k // 2 # Reformat to marlin weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, - quant_type.size_bits, weight_perm) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) # Create result diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index fb1d041f3449..e1286b243f3b 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ 
b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -14,8 +14,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): - """ weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel - """ + """weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel""" import triton_kernels.matmul_ogs_details.opt_flags as opt_flags from triton_kernels.numerics import InFlexData from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor @@ -25,30 +24,38 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): value_layout_opts: dict[str, Any] = {} scale_layout_opts: dict[str, Any] = {} - if (current_platform.is_cuda() - and current_platform.is_device_capability(90) - and not is_torch_equal_or_newer("2.8.1")): + if ( + current_platform.is_cuda() + and current_platform.is_device_capability(90) + and not is_torch_equal_or_newer("2.8.1") + ): logger.warning_once( "Mxfp4 on hopper is running on torch < 2.8.1, " "this cause swizling to be disabled, which may " - "cause performance degradation. Please upgrade to torch nightly") + "cause performance degradation. Please upgrade to torch nightly" + ) value_layout = StridedLayout scale_layout = StridedLayout elif current_platform.is_rocm(): - from triton_kernels.tensor_details.layout import (GFX950MXScaleLayout, - StridedLayout) + from triton_kernels.tensor_details.layout import ( + GFX950MXScaleLayout, + StridedLayout, + ) from vllm.platforms.rocm import on_gfx950 + value_layout = StridedLayout scale_layout = GFX950MXScaleLayout if on_gfx950() else StridedLayout else: - value_layout, value_layout_opts = \ - layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout( + mx_axis=1 + ) scale_layout, scale_layout_opts = ( layout.make_default_matmul_mxfp4_w_scale_layout( - mx_axis=1, num_warps=num_warps)) - if current_platform.is_cuda() and \ - current_platform.is_device_capability(100): + mx_axis=1, num_warps=num_warps + ) + ) + if current_platform.is_cuda() and current_platform.is_device_capability(100): constraints = { "is_persistent": True, "epilogue_subtile": 1, @@ -57,66 +64,83 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): # transpose the tensor so that the quantization axis is on dim1 quant_tensor = quant_tensor.transpose(-2, -1) scale = scale.transpose(-2, -1) - quant_tensor = convert_layout(wrap_torch_tensor(quant_tensor, dtype=FP4), - value_layout, **value_layout_opts) - scale = convert_layout(wrap_torch_tensor(scale), scale_layout, - **scale_layout_opts) + quant_tensor = convert_layout( + wrap_torch_tensor(quant_tensor, dtype=FP4), value_layout, **value_layout_opts + ) + scale = convert_layout(wrap_torch_tensor(scale), scale_layout, **scale_layout_opts) return quant_tensor, InFlexData(), scale -def _can_support_mxfp4(use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - e_score_correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - scoring_func: str = "softmax", - activation: str = "swigluoai", - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None): - return not (use_grouped_topk or topk_group or num_expert_group - or custom_routing_function or e_score_correction_bias - or apply_router_weight_on_input or scoring_func != "softmax" - or activation != 
"swigluoai" or expert_load_view - or logical_to_physical_map or logical_replica_count) - - -def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor, - float_dtype: torch.dtype) -> torch.Tensor: +def _can_support_mxfp4( + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + scoring_func: str = "softmax", + activation: str = "swigluoai", + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, +): + return not ( + use_grouped_topk + or topk_group + or num_expert_group + or custom_routing_function + or e_score_correction_bias + or apply_router_weight_on_input + or scoring_func != "softmax" + or activation != "swigluoai" + or expert_load_view + or logical_to_physical_map + or logical_replica_count + ) + + +def _dequant_mxfp4( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype +) -> torch.Tensor: try: from quark.torch.kernel import mx except ImportError as err: - raise ImportError("The package `amd-quark` is required to use " - "MX-FP4 models. Please install it with `pip install " - "amd-quark`.") from err + raise ImportError( + "The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`." + ) from err return mx.dq_mxfp4(x, scale, float_dtype) -def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor, - float_dtype: torch.dtype) -> torch.Tensor: - return torch.empty((*x.shape[:-1], x.shape[-1] * 2), - dtype=float_dtype, - device=x.device) +def _dequant_mxfp4_fake( + x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype +) -> torch.Tensor: + return torch.empty( + (*x.shape[:-1], x.shape[-1] * 2), dtype=float_dtype, device=x.device + ) -def _quant_dequant_mxfp4(x: torch.Tensor, - scale_calculation_mode: str = "even") -> torch.Tensor: +def _quant_dequant_mxfp4( + x: torch.Tensor, scale_calculation_mode: str = "even" +) -> torch.Tensor: try: from quark.torch.kernel import mx except ImportError as err: - raise ImportError("The package `amd-quark` is required to use " - "MX-FP4 models. Please install it with `pip install " - "amd-quark`.") from err + raise ImportError( + "The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`." + ) from err return mx.qdq_mxfp4(x, scale_calculation_mode) -def _quant_dequant_mxfp4_fake(x: torch.Tensor, - scale_calculation_mode: str = "even" - ) -> torch.Tensor: +def _quant_dequant_mxfp4_fake( + x: torch.Tensor, scale_calculation_mode: str = "even" +) -> torch.Tensor: return torch.empty_like(x) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py index 2a6b21c918f4..2d211565c19e 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp8_utils.py @@ -9,12 +9,13 @@ def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - try: from flashinfer import mxfp8_quantize except ImportError as err: - raise ImportError("The package `flashinfer` is required to do " - "MX-FP8 quantization. 
Please install it with" \ - "`pip install flashinfer`") from err + raise ImportError( + "The package `flashinfer` is required to do " + "MX-FP8 quantization. Please install it with" + "`pip install flashinfer`" + ) from err return mxfp8_quantize(x, is_sf_swizzled_layout=False) diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py index 8648771cb017..62b480210fc0 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py @@ -12,8 +12,9 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() -kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], - dtype=torch.float32) +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) def break_fp4_bytes(a, dtype): @@ -45,12 +46,9 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): return out[0:m, 0:k] -def dequantize_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): +def dequantize_to_dtype( + tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 +): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. assert tensor_fp4.dtype == torch.uint8 @@ -95,8 +93,7 @@ def ref_nvfp4_quant(x, global_scale, block_size): assert x.ndim == 2 m, n = x.shape x = torch.reshape(x, (m, n // block_size, block_size)) - vec_max = torch.max(torch.abs(x), dim=-1, - keepdim=True)[0].to(torch.float32) + vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32) scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX)) scale = torch.clamp(scale, max=448, min=-448) scale = scale.to(torch.float8_e4m3fn).to(torch.float32) @@ -108,10 +105,13 @@ def ref_nvfp4_quant(x, global_scale, block_size): return cast_to_fp4(clipped_x), scale.squeeze(-1) -def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor, - weight: torch.Tensor, - weight_scale_swizzled: torch.Tensor, - weight_global_scale: torch.Tensor): +def run_nvfp4_emulations( + x: torch.Tensor, + input_global_scale: torch.Tensor, + weight: torch.Tensor, + weight_scale_swizzled: torch.Tensor, + weight_global_scale: torch.Tensor, +): group_size = 16 x_m, x_k = x.shape output_dtype = x.dtype @@ -127,9 +127,14 @@ def run_nvfp4_emulations(x: torch.Tensor, input_global_scale: torch.Tensor, # dequantize weight w_fp4 = weight.data.view(torch.uint8) - w_dq = dequantize_to_dtype(w_fp4, weight_scale_swizzled.data, - weight_global_scale, output_dtype, x.device, - group_size) + w_dq = dequantize_to_dtype( + w_fp4, + weight_scale_swizzled.data, + weight_global_scale, + output_dtype, + x.device, + group_size, + ) # matmul out = torch.matmul(x_dq, w_dq.t()) diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py index 21af74c6b72b..c3f26cc77411 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py @@ -5,11 +5,14 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - is_flashinfer_fp4_cutlass_moe_available) + is_flashinfer_fp4_cutlass_moe_available, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - is_fp4_marlin_supported) + 
is_fp4_marlin_supported, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - cutlass_fp4_supported) + cutlass_fp4_supported, +) __all__ = ["detect_nvfp4_moe_support", "NvFp4Support"] @@ -29,12 +32,12 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support: """Detect platform support for NV-FP4 fused-MoE path""" cutlass_supported = cutlass_fp4_supported() - allow_flashinfer = (cutlass_supported - and is_flashinfer_fp4_cutlass_moe_available()) + allow_flashinfer = cutlass_supported and is_flashinfer_fp4_cutlass_moe_available() if allow_flashinfer: - _logger.info_once("Using FlashInfer kernels for %s.", class_name - or "NVFP4 path") + _logger.info_once( + "Using FlashInfer kernels for %s.", class_name or "NVFP4 path" + ) else: if envs.VLLM_USE_FLASHINFER_MOE_FP4: _logger.warning_once( @@ -50,7 +53,8 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support: else: raise ValueError( "Current platform does not support NVFP4 quantization. " - "Please use Blackwell GPUs or enable FlashInfer.") + "Please use Blackwell GPUs or enable FlashInfer." + ) return NvFp4Support( cutlass_supported=cutlass_supported, diff --git a/vllm/model_executor/layers/quantization/utils/petit_utils.py b/vllm/model_executor/layers/quantization/utils/petit_utils.py index 00d3def1db81..1f053103fc3c 100644 --- a/vllm/model_executor/layers/quantization/utils/petit_utils.py +++ b/vllm/model_executor/layers/quantization/utils/petit_utils.py @@ -11,14 +11,15 @@ # 1. Create a global variable as a placeholder for the module _petit_kernel: Optional["ModuleType"] = None -_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with " - "`pip install petit-kernel`.") +_PETIT_INSTALL_MSG = ( + "Petit is not installed. Please install it with `pip install petit-kernel`." +) def _import_petit_kernel() -> "ModuleType": """ A helper function to handle the lazy import. - The first time this function is called, it will import the petit_kernel + The first time this function is called, it will import the petit_kernel library and store it in the global _petit_kernel variable. Subsequent calls will return the already-loaded module directly. """ @@ -28,6 +29,7 @@ def _import_petit_kernel() -> "ModuleType": try: import petit_kernel + _petit_kernel = petit_kernel return _petit_kernel except ImportError: @@ -41,14 +43,16 @@ def _import_petit_kernel() -> "ModuleType": def _check_petit_nvfp4_supported( - quant_method: str, - group_size: Optional[int]) -> tuple[bool, Optional[str]]: + quant_method: str, group_size: Optional[int] +) -> tuple[bool, Optional[str]]: if quant_method != "NVFP4": return ( False, - ("Petit currently only supports: NVFP4 quantizations in sglang. " - "Please check the `hf_quant_config.json` file for your model's " - "quant configuration."), + ( + "Petit currently only supports: NVFP4 quantizations in sglang. " + "Please check the `hf_quant_config.json` file for your model's " + "quant configuration." 
+ ), ) if group_size is not None and group_size != 16: return ( @@ -58,10 +62,8 @@ def _check_petit_nvfp4_supported( return (True, None) -def verify_petit_nvfp4_supported(quant_method: str, - group_size: Optional[int]) -> None: - supported, error_msg = _check_petit_nvfp4_supported( - quant_method, group_size) +def verify_petit_nvfp4_supported(quant_method: str, group_size: Optional[int]) -> None: + supported, error_msg = _check_petit_nvfp4_supported(quant_method, group_size) if not supported: assert error_msg is not None raise ValueError(error_msg) @@ -77,15 +79,15 @@ def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None: qweight = layer.weight.view(torch.int32).contiguous() # 3. Call functions through the imported module variable. - petit_qweight = petit_kernel.repack_nvfp4(qweight, - size_n=part_size_n, - size_k=part_size_k) + petit_qweight = petit_kernel.repack_nvfp4( + qweight, size_n=part_size_n, size_k=part_size_k + ) layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False) # Permute scales - weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale, - size_k=part_size_k, - size_n=part_size_n) + weight_scale = petit_kernel.process_nvfp4_scales( + scales=layer.weight_scale, size_k=part_size_k, size_n=part_size_n + ) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) @@ -102,7 +104,7 @@ def apply_petit_nvfp4_linear( petit_kernel = _import_petit_kernel() reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n, ) + out_shape = input.shape[:-1] + (size_n,) # TODO: Use auto-tuning to find the performant solution_id # Call the function via the module variable. diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index acd9058fe694..2e9b279465f9 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """This file is used for /tests and /benchmarks""" + from collections.abc import Mapping from dataclasses import dataclass from types import MappingProxyType @@ -31,8 +32,8 @@ class GroupShape(_GroupShape): """ # Aliases for common quantization group shapes - PER_TENSOR: ClassVar['GroupShape'] - PER_TOKEN: ClassVar['GroupShape'] + PER_TENSOR: ClassVar["GroupShape"] + PER_TOKEN: ClassVar["GroupShape"] def is_per_tensor(self) -> bool: return self.row == -1 and self.col == -1 @@ -56,18 +57,26 @@ class ScaleDesc: static: static scale if True, dynamic if False group_shape: group shape of the scale """ + dtype: torch.dtype static: bool group_shape: GroupShape def __str__(self): - group_shape = ('per_tensor' - if self.group_shape == GroupShape.PER_TENSOR else - ('per_token' if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape))) - - return (f"{fx.graph.dtype_abbrs[self.dtype]}," - f"{'static' if self.static else 'dynamic'},{group_shape}") + group_shape = ( + "per_tensor" + if self.group_shape == GroupShape.PER_TENSOR + else ( + "per_token" + if self.group_shape == GroupShape.PER_TOKEN + else str(self.group_shape) + ) + ) + + return ( + f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'static' if self.static else 'dynamic'},{group_shape}" + ) @dataclass(frozen=True) @@ -79,6 +88,7 @@ class QuantKey: scale2: second-level scale descriptor symmetric: symmetric if True, asymmetric if False """ + dtype: 
torch.dtype scale: ScaleDesc scale2: Optional[ScaleDesc] = None @@ -86,9 +96,11 @@ class QuantKey: def __str__(self): scale2_str = f"scale2({self.scale2})," if self.scale2 else "" - return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]}," - f"scale({self.scale}),{scale2_str}" - f"{'a' if not self.symmetric else ''}symmetric)") + return ( + f"QuantKey({fx.graph.dtype_abbrs[self.dtype]}," + f"scale({self.scale}),{scale2_str}" + f"{'a' if not self.symmetric else ''}symmetric)" + ) kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR) @@ -101,16 +113,16 @@ def __str__(self): kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) -kNvfp4Quant = QuantKey(FP4_DTYPE, - scale=kNvfp4GroupScale, - scale2=kStaticTensorScale) +kNvfp4Quant = QuantKey(FP4_DTYPE, scale=kNvfp4GroupScale, scale2=kStaticTensorScale) # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent - return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], - group_shape[1] if group_shape[1] > 0 else x.shape[-1]) + return ( + group_shape[0] if group_shape[0] > 0 else x.shape[-2], + group_shape[1] if group_shape[1] > 0 else x.shape[-1], + ) # Useful when treating N-dimensional group scaling as extended numpy-style @@ -131,9 +143,11 @@ def group_broadcast(t, shape): for i, s in enumerate(shape): if t.shape[i] != s and t.shape[i] != 1: assert s % t.shape[i] == 0 - t = t.unsqueeze(i + 1)\ - .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\ + t = ( + t.unsqueeze(i + 1) + .expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :]) .flatten(i, i + 1) + ) return t @@ -151,9 +165,10 @@ def scaled_quantize( quant_dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: group_shape = _normalize_quant_group_shape(x, group_shape) - assert quant_dtype.is_floating_point, \ - "currently `scaled_quantize` only supports floating point dtypes " \ + assert quant_dtype.is_floating_point, ( + "currently `scaled_quantize` only supports floating point dtypes " "but could be extended to support other dtypes" + ) finfo = torch.finfo(quant_dtype) @@ -175,11 +190,13 @@ def scaled_quantize( # Apply scale and convert form: # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) - x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\ - .clamp(min=finfo.min, max=finfo.max)\ - .reshape(blk_m, blk_n, group_shape[0], group_shape[1])\ - .permute(0, 2, 1, 3)\ + x_scl_sat = ( + (x_blkd_permd * scale.unsqueeze(-1)) + .clamp(min=finfo.min, max=finfo.max) + .reshape(blk_m, blk_n, group_shape[0], group_shape[1]) + .permute(0, 2, 1, 3) .reshape(x.shape) + ) return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal() @@ -200,7 +217,8 @@ def scaled_dequantize( if group_shape is None: raise AssertionError( "if x_s is 1D tensor, group_shape must be provided otherwise " - "its ambiguous which dimension to broadcast x_s to") + "its ambiguous which dimension to broadcast x_s to" + ) # unsqueeze the scales for the dimension where we want to broadcast # across the full extent if group_shape[0] == x_q.shape[-2]: @@ -210,7 +228,8 @@ def scaled_dequantize( else: raise AssertionError( "if x_s is a vector we should be broadcasting it to the full " - "extent of one of the dimensions") + "extent of one of the dimensions" + ) if group_shape is not None: assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1] @@ -219,9 +238,9 @@ def 
scaled_dequantize( return (x_q.to(torch.float32) * x_s).to(out_dtype) -def pack_quantized_values_into_int32(w_q: torch.Tensor, - wtype: ScalarType, - packed_dim: int = 0): +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): # move dim to pack to the end perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) inv_perm = tuple(perm.index(i) for i in range(len(perm))) @@ -241,9 +260,9 @@ def pack_quantized_values_into_int32(w_q: torch.Tensor, return res.permute(inv_perm) -def unpack_quantized_values_into_int32(w_q: torch.Tensor, - wtype: ScalarType, - packed_dim: int = 0): +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): # move dim to pack to the end perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) inv_perm = tuple(perm.index(i) for i in range(len(perm))) @@ -265,7 +284,7 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor, def is_layer_skipped( prefix: str, ignored_layers: list[str], - fused_mapping: Mapping[str, list[str]] = MappingProxyType({}) + fused_mapping: Mapping[str, list[str]] = MappingProxyType({}), ) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj @@ -291,12 +310,16 @@ def is_layer_skipped( raise ValueError( f"Detected some but not all shards of {prefix} " "are quantized. All shards of fused layers " - "to have the same precision.") + "to have the same precision." + ) elif "experts" in prefix: - return any([ - prefix in layer_name for layer_name in ignored_layers - if "experts" in layer_name - ]) + return any( + [ + prefix in layer_name + for layer_name in ignored_layers + if "experts" in layer_name + ] + ) else: is_skipped = prefix in ignored_layers @@ -309,16 +332,18 @@ def get_pack_factor(num_bits): return 32 // num_bits -def permute_rows(q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None): +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): assert q_w.shape == w_ref.shape orig_device = q_w.device k_size, _ = q_w.shape - g_idx = torch.zeros((k_size, ), dtype=torch.int32) + g_idx = torch.zeros((k_size,), dtype=torch.int32) for i in range(k_size): g_idx[i] = i // group_size @@ -337,16 +362,20 @@ def permute_rows(q_w: torch.Tensor, ) -def quantize_weights(w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False): - assert quant_type.is_integer(), \ +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert quant_type.is_integer(), ( "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, \ - "to have group zero points, group_size must be provided "\ + ) + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " "(-1 group_size is channelwise)" + ) orig_device = w.device orig_type = w.dtype @@ -376,14 +405,16 @@ def quantize_weights(w: torch.Tensor, if zero_points: assert not quant_type.is_signed() and quant_type.max() > 0 w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ - .clamp(min_q_val, max_q_val).int() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, 
max_q_val).int() + ) else: # If the bias is such that there are no possible negative/positive # values, set the max value to inf to avoid divide by 0 w_s = torch.max( abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) # Quantize w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) @@ -430,19 +461,22 @@ def reshape_w(w): SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] -def gptq_quantize_weights(w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None): +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): size_k, _ = w.shape assert w.is_floating_point(), "w must be float" - assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \ + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, ( f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" + ) + assert group_size in SUPPORTED_GROUP_SIZES + [size_k], ( + f"Unsupported groupsize = {group_size}" + ) w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) @@ -450,13 +484,13 @@ def gptq_quantize_weights(w: torch.Tensor, g_idx = torch.empty(0, dtype=torch.int, device=w.device) rand_perm = torch.empty(0, dtype=torch.int, device=w.device) if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k) + assert group_size < size_k, ( + "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + ) - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, - test_perm) + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) return w_ref, w_q, w_s, g_idx, rand_perm @@ -464,8 +498,7 @@ def gptq_quantize_weights(w: torch.Tensor, def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): orig_device = q_w.device - sort_indices = torch.argsort(g_idx).to( - dtype=torch.int32) # Sort based on g_idx + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx g_idx = g_idx[sort_indices].contiguous() q_w = q_w[sort_indices, :].contiguous() @@ -535,10 +568,11 @@ def unpack_cols( ): pack_factor = get_pack_factor(num_bits) assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, size_n // pack_factor - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor) + assert packed_q_w.shape == (size_k, size_n // pack_factor), ( + "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + ) orig_device = packed_q_w.device @@ -604,7 +638,8 @@ def swizzle_blockscale(scale: torch.Tensor) -> torch.Tensor: """ assert scale.dtype == torch.float8_e4m3fn, ( "swizzle_blockscale expects the input tensor to be in " - "torch.float8_e4m3fn format.") + "torch.float8_e4m3fn format." 
+ ) scale_ndim = scale.ndim if scale_ndim == 2: @@ -619,9 +654,9 @@ def _round_up(x: int, m: int) -> int: M_padded = _round_up(M, 128) K_padded = _round_up(K, 4) - padded = torch.zeros((B, M_padded, K_padded), - dtype=scale.dtype, - device=scale.device) + padded = torch.zeros( + (B, M_padded, K_padded), dtype=scale.dtype, device=scale.device + ) padded[:B, :M, :K] = scale # Reshape / permute to the layout required by the kernel. diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index b434b7acfea8..c26cd4f28cb6 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -10,8 +10,7 @@ from vllm import envs from vllm.config import CompilationLevel, get_current_vllm_config from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer @@ -24,9 +23,11 @@ # torch._scaled_mm rowwise feature. # The condition is determined once as the operations # are time-consuming. -USE_ROWWISE_TORCH_SCALED_MM = (current_platform.is_rocm() and version.parse( - torch.__version__) >= version.parse("2.7") - and current_platform.has_device_capability(94)) +USE_ROWWISE_TORCH_SCALED_MM = ( + current_platform.is_rocm() + and version.parse(torch.__version__) >= version.parse("2.7") + and current_platform.has_device_capability(94) +) def sparse_cutlass_supported() -> bool: @@ -74,8 +75,8 @@ def cutlass_group_gemm_supported() -> bool: def per_tensor_dequantize( - tensor: torch.Tensor, inv_scale: Union[float, - torch.Tensor]) -> torch.Tensor: + tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] +) -> torch.Tensor: fake_qweight = tensor.to(torch.float16) dq_weight = fake_qweight * inv_scale return dq_weight @@ -87,12 +88,12 @@ def all_close_1d(x: torch.Tensor) -> bool: def convert_to_channelwise( - weight_scale: torch.Tensor, - logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]: + weight_scale: torch.Tensor, logical_widths: list[int] +) -> tuple[torch.Tensor, torch.Tensor]: # Create channelwise buffer - weight_scale_channel = torch.empty((sum(logical_widths), 1), - dtype=torch.float32, - device=weight_scale.device) + weight_scale_channel = torch.empty( + (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device + ) # Expand each scale to match the size of each logical matrix. start = 0 @@ -105,8 +106,8 @@ def convert_to_channelwise( def requantize_with_max_scale( - weight: torch.Tensor, weight_scale: torch.Tensor, - logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]: + weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: list[int] +) -> tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. max_w_scale = weight_scale.max() @@ -116,8 +117,9 @@ def requantize_with_max_scale( # from disk in this case. Skip requantization in this case (since) # we already are quantized with the single scale. 
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 - unfused_module_in_checkpoint = (weight_scale[-1] - > torch.finfo(torch.float8_e4m3fn).min) + unfused_module_in_checkpoint = ( + weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min + ) # If unfused checkpoint, need requanize with the single scale. if unfused_module_in_checkpoint: @@ -127,10 +129,8 @@ def requantize_with_max_scale( if logical_width == 0: continue end = start + logical_width - weight_dq = per_tensor_dequantize(weight[start:end, :], - weight_scale[idx]) - weight[start:end, :], _ = ops.scaled_fp8_quant( - weight_dq, max_w_scale) + weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx]) + weight[start:end, :], _ = ops.scaled_fp8_quant(weight_dq, max_w_scale) start = end return max_w_scale, weight @@ -143,75 +143,102 @@ def maybe_create_device_identity(): TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) -def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, - out_dtype: torch.dtype, scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - output_shape: list, **kwargs) -> torch.Tensor: - +def cutlass_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, + **kwargs, +) -> torch.Tensor: # Fused GEMM_DQ - output = ops.cutlass_scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias) + output = ops.cutlass_scaled_mm( + qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias + ) return output.view(*output_shape) -def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, - out_dtype: torch.dtype, scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - output_shape: list, **kwargs) -> torch.Tensor: - - return flashinfer_scaled_fp8_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias) - - -def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor) -> torch.Tensor: +def flashinfer_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, + **kwargs, +) -> torch.Tensor: + return flashinfer_scaled_fp8_mm( + qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias + ) + + +def rocm_per_tensor_w8a8_scaled_mm_impl( + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx - if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx() and \ - qinput.shape[0] == 1 and \ - qinput.shape[1] % 16 == 0 and \ - ((bias is None) or (bias.dtype == out_dtype)) : - output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, - current_platform.get_cu_count(), bias) + + if ( + envs.VLLM_ROCM_USE_SKINNY_GEMM + and on_mi3xx() + and qinput.shape[0] == 1 + and qinput.shape[1] % 16 == 0 + and ((bias is None) or (bias.dtype == out_dtype)) + ): + output = ops.wvSplitKQ( + weight.t(), + qinput, + out_dtype, + scale_a, + scale_b, + current_platform.get_cu_count(), + bias, + ) else: - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias) + 
output = torch._scaled_mm( + qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + bias=bias, + ) return output -def rocm_per_tensor_w8a8_scaled_mm_fake(qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - bias: torch.Tensor) -> torch.Tensor: - return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), - dtype=out_dtype) +def rocm_per_tensor_w8a8_scaled_mm_fake( + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype) -def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - output_shape: list) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( - qinput, weight, out_dtype, scale_a, scale_b, bias) + qinput, weight, out_dtype, scale_a, scale_b, bias + ) return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) @@ -222,18 +249,19 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, ) -def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - output_shape: list) -> torch.Tensor: - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b, - bias=bias) +def torch_per_tensor_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + output = torch._scaled_mm( + qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias + ) # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 if type(output) is tuple and len(output) == 2: @@ -242,13 +270,17 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape) -def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - output_shape: list, - **kwargs) -> torch.Tensor: +def torch_per_token_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, + **kwargs, +) -> torch.Tensor: # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # when using it. # For now it has only been validated on ROCm platform. 
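
A minimal illustration, not taken from the patch, of the dispatch rule these w8a8 helpers implement, assuming hypothetical scale tensors where per-tensor scales have numel() == 1 and rowwise scales are shaped [M, 1] and [N, 1]; the real dispatch_w8a8_scaled_mm further below additionally picks a backend (rocm, flashinfer, cutlass, or torch) for the per-tensor case:

import torch

def pick_w8a8_scaled_mm_path(scale_a: torch.Tensor,
                             scale_b: torch.Tensor,
                             use_rowwise_scaled_mm: bool) -> str:
    # Per-tensor scales are single-element tensors; anything larger is
    # per-token (activations) or per-channel (weights).
    per_tensor_weights = scale_b.numel() == 1
    per_tensor_activations = scale_a.numel() == 1
    if per_tensor_weights and per_tensor_activations:
        return "per_tensor"
    if (not per_tensor_weights and not per_tensor_activations
            and use_rowwise_scaled_mm):
        # Rowwise torch._scaled_mm path; per the note above, callers must
        # check USE_ROWWISE_TORCH_SCALED_MM before taking it.
        return "per_token_rowwise"
    return "channelwise_fallback"

# Example: per-token activation scales [M, 1] with per-channel weight scales [N, 1]
print(pick_w8a8_scaled_mm_path(torch.ones(8, 1), torch.ones(16, 1), True))
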
@@ -260,25 +292,31 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, # rowwise scaled GEMM before using it # Fused GEMM_DQ Rowwise GEMM - output = torch._scaled_mm(qinput, - weight, - out_dtype=out_dtype, - scale_a=scale_a, - scale_b=scale_b.t(), - bias=bias) + output = torch._scaled_mm( + qinput, + weight, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b.t(), + bias=bias, + ) output = torch.narrow(output, 0, 0, qinput.shape[0]) output = output.view(*output_shape) return output -def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - output_shape: list, - **kwargs) -> torch.Tensor: +def torch_channelwise_w8a8_scaled_mm( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, + **kwargs, +) -> torch.Tensor: # Use unfused DQ due to limitations with scaled_mm # Symmetric quantized GEMM by definition computes the following: @@ -296,11 +334,13 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, # GEMM # This computes C = (X * W). # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm(qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32) + output = torch._scaled_mm( + qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32, + ) # A fix for discrepancy in scaled_mm which returns tuple # for torch < 2.5 and a single value in torch >= 2.5 if type(output) is tuple and len(output) == 2: @@ -318,9 +358,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def dispatch_w8a8_scaled_mm( - preferred_backend: str, per_tensor_weights: bool, - per_tensor_activations: bool) -> Callable[..., torch.Tensor]: - + preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool +) -> Callable[..., torch.Tensor]: if per_tensor_weights and per_tensor_activations: if preferred_backend == "rocm": return rocm_per_tensor_w8a8_scaled_mm @@ -335,8 +374,11 @@ def dispatch_w8a8_scaled_mm( return cutlass_w8a8_scaled_mm # If torch.scaled_mm supports per-channel (weights) per-token (inputs) - if not per_tensor_weights and not per_tensor_activations \ - and USE_ROWWISE_TORCH_SCALED_MM: + if ( + not per_tensor_weights + and not per_tensor_activations + and USE_ROWWISE_TORCH_SCALED_MM + ): return torch_per_token_w8a8_scaled_mm # Normally, torch.scaled_mm supports per tensor weights + activations only # so fallback to naive if per channel or per token @@ -353,15 +395,16 @@ class Fp8LinearOp: in the __init__ method, as reading config is not allowed inside forward. """ - def __init__(self, - act_quant_static: bool, - act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, - pad_output: Optional[bool] = None): + def __init__( + self, + act_quant_static: bool, + act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, + pad_output: Optional[bool] = None, + ): if current_platform.is_rocm(): self.preferred_backend = "rocm" elif current_platform.is_cuda() and cutlass_fp8_supported(): - if has_flashinfer() and current_platform.has_device_capability( - 100): + if has_flashinfer() and current_platform.has_device_capability(100): self.preferred_backend = "flashinfer" else: self.preferred_backend = "cutlass" @@ -375,15 +418,19 @@ def __init__(self, # as it breaks with dynamic shapes. 
if pad_output is None: config = get_current_vllm_config().compilation_config - pad_output = config.level < CompilationLevel.PIECEWISE and \ - self.preferred_backend == "torch" + pad_output = ( + config.level < CompilationLevel.PIECEWISE + and self.preferred_backend == "torch" + ) self.output_padding = 17 if pad_output else None self.act_quant_static = act_quant_static self.act_quant_group_shape = act_quant_group_shape - self.quant_fp8 = QuantFP8(static=act_quant_static, - group_shape=act_quant_group_shape, - num_token_padding=self.output_padding) + self.quant_fp8 = QuantFP8( + static=act_quant_static, + group_shape=act_quant_group_shape, + num_token_padding=self.output_padding, + ) def apply( self, @@ -417,27 +464,29 @@ def apply( else: qinput, x_scale = input_2d, input_scale - per_tensor_weights = (weight_scale.numel() == 1) - per_tensor_activations = (x_scale.numel() == 1) + per_tensor_weights = weight_scale.numel() == 1 + per_tensor_activations = x_scale.numel() == 1 # TODO(luka) do this dispatch during init (after ScaledMM refactor) - w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(self.preferred_backend, - per_tensor_weights, - per_tensor_activations) + w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( + self.preferred_backend, per_tensor_weights, per_tensor_activations + ) - return w8a8_scaled_mm_func(qinput=qinput, - weight=weight, - out_dtype=out_dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - output_shape=output_shape) + return w8a8_scaled_mm_func( + qinput=qinput, + weight=weight, + out_dtype=out_dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + output_shape=output_shape, + ) def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, weight_scale: torch.Tensor, - input_scale: Optional[torch.Tensor] = None + input_scale: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: assert weight.dtype == torch.float8_e4m3fn # The bits pattern 10000000(-128) represents zero in e4m3fn diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py index 3f2d571777c0..6ae2db0f428c 100644 --- a/vllm/model_executor/layers/resampler.py +++ b/vllm/model_executor/layers/resampler.py @@ -32,6 +32,7 @@ Example models: Qwen (Qwen-VL), MiniCPM-V 2.0 """ + import math from functools import partial from typing import Callable, Optional, Union @@ -47,8 +48,9 @@ DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) -def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, - int]) -> torch.Tensor: +def get_abs_pos( + abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, int] +) -> torch.Tensor: # abs_pos: L, C # tgt_size: (H, W) # return: M, C @@ -56,21 +58,26 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, dtype = abs_pos.dtype if isinstance(tgt_size, int): tgt_size = (tgt_size, tgt_size) - if (src_size == tgt_size[0] and src_size == tgt_size[1]): + if src_size == tgt_size[0] and src_size == tgt_size[1]: return abs_pos - return (F.interpolate( - abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), - size=(tgt_size[0], tgt_size[1]), - mode="bicubic", - align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) + return ( + F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size[0], tgt_size[1]), + mode="bicubic", + align_corners=False, + ) + .permute(0, 2, 3, 1) + .flatten(0, 2) + .to(dtype=dtype) + ) # sin/cos positional embedding helpers are adapted from: # 
https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 def get_1d_sincos_pos_embed_from_grid( - embed_dim: int, pos: np.ndarray, - version: tuple[int, int] = (2, 0)) -> torch.Tensor: + embed_dim: int, pos: np.ndarray, version: tuple[int, int] = (2, 0) +) -> torch.Tensor: """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) / (H, W) @@ -96,15 +103,17 @@ def get_1d_sincos_pos_embed_from_grid( def get_2d_sincos_pos_embed_from_grid( - embed_dim: int, grid: np.ndarray, - version: tuple[int, int] = (2, 0)) -> torch.Tensor: + embed_dim: int, grid: np.ndarray, version: tuple[int, int] = (2, 0) +) -> torch.Tensor: assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) + embed_dim // 2, grid[0], version + ) # (H*W, D/2) or (H, W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) + embed_dim // 2, grid[1], version + ) # (H*W, D/2) or (H, W, D/2) if version == (2, 0): emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) @@ -114,10 +123,10 @@ def get_2d_sincos_pos_embed_from_grid( def get_2d_sincos_pos_embed( - embed_dim: int, - grid_size: Union[int, tuple[int, int]], - cls_token: bool = False, - version: tuple[int, int] = (2, 0), + embed_dim: int, + grid_size: Union[int, tuple[int, int]], + cls_token: bool = False, + version: tuple[int, int] = (2, 0), ) -> torch.Tensor: """ grid_size: int of the grid height and width @@ -134,15 +143,13 @@ def get_2d_sincos_pos_embed( grid_w = np.arange(grid_w_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) - assert isinstance(grid, np.ndarray) and \ - grid.shape == (2, grid_h_size, grid_w_size) + assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size) if version == (2, 0): grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], - axis=0) + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) else: pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) return pos_embed @@ -156,15 +163,17 @@ class BaseResampler(nn.Module): A tensor with the shape of (grid_size**2, embed_dim) """ - def __init__(self, - num_queries: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - do_post_projection: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.num_queries = num_queries @@ -174,14 +183,16 @@ def __init__(self, self.query = nn.Parameter(torch.empty(self.num_queries, embed_dim)) if kv_dim is not None and kv_dim != embed_dim: - self.kv_proj = ReplicatedLinear(kv_dim, - embed_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_proj") + self.kv_proj = ReplicatedLinear( + kv_dim, + embed_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj", + ) else: # 
Maintain the same return value with ReplicatedLinear.forward - self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa nn.Identity()(*args, **kwargs), None, ) @@ -190,9 +201,11 @@ def __init__(self, self.ln_kv = norm_layer(embed_dim) self.do_post_projection = do_post_projection self.ln_post = norm_layer(embed_dim) if do_post_projection else None - self.proj = nn.Parameter( - (embed_dim**-0.5) * - torch.empty(embed_dim, embed_dim)) if do_post_projection else None + self.proj = ( + nn.Parameter((embed_dim**-0.5) * torch.empty(embed_dim, embed_dim)) + if do_post_projection + else None + ) def _repeat(self, query, N: int): return query.unsqueeze(1).repeat(1, N, 1) @@ -206,32 +219,35 @@ class Resampler2(BaseResampler): present in minicpmv2.0, but not qwen-vl. """ - def __init__(self, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - adaptive: bool = False, - do_post_projection: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - super().__init__(grid_size**2, - embed_dim, - num_heads, - kv_dim, - norm_layer, - do_post_projection=do_post_projection, - quant_config=quant_config, - prefix=prefix) + def __init__( + self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + grid_size**2, + embed_dim, + num_heads, + kv_dim, + norm_layer, + do_post_projection=do_post_projection, + quant_config=quant_config, + prefix=prefix, + ) self.adaptive = adaptive - pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, - grid_size, - version=(2, 0)) + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, grid_size, version=(2, 0)) self.pos_embed = nn.Parameter( - torch.from_numpy(pos_embed_arr).requires_grad_(False)) + torch.from_numpy(pos_embed_arr).requires_grad_(False) + ) def forward( self, @@ -242,15 +258,16 @@ def forward( if tgt_sizes is None: tgt_sizes = int(math.sqrt(x.size(1))) if self.adaptive: - pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, - tgt_sizes, - version=(2, 0)) - pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, - dtype=x.dtype) + pos_embed_arr = get_2d_sincos_pos_embed( + self.embed_dim, tgt_sizes, version=(2, 0) + ) + pos_embed = torch.from_numpy(pos_embed_arr).to( + device=x.device, dtype=x.dtype + ) else: - pos_embed = get_abs_pos(self.pos_embed, - tgt_sizes).to(device=x.device, - dtype=x.dtype) + pos_embed = get_abs_pos(self.pos_embed, tgt_sizes).to( + device=x.device, dtype=x.dtype + ) x, _ = self.kv_proj(x) x = self.ln_kv(x).permute(1, 0, 2) diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 3576368981c7..e6956de4bfaa 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Rotary Positional Embeddings.""" + from typing import Any, Optional import torch @@ -37,8 +38,7 @@ def get_rope( if rope_scaling is not None: # Transforms every value that is a list into a tuple for caching calls rope_scaling_tuple = { - k: tuple(v) if 
isinstance(v, list) else v - for k, v in rope_scaling.items() + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() } rope_scaling_args = tuple(rope_scaling_tuple.items()) else: @@ -56,8 +56,16 @@ def get_rope( if partial_rotary_factor < 1.0: rotary_dim = int(rotary_dim * partial_rotary_factor) - key = (head_size, rotary_dim, max_position, base, is_neox_style, - rope_scaling_args, dual_chunk_attention_args, dtype) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dual_chunk_attention_args, + dtype, + ) if key in _ROPE_DICT: return _ROPE_DICT[key] @@ -67,13 +75,19 @@ def get_rope( for k, v in dual_chunk_attention_config.items() if k in ("chunk_size", "local_size") } - rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, dtype, - **extra_kwargs) + rotary_emb = DualChunkRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + **extra_kwargs, + ) elif not rope_scaling: - rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style, dtype) + rotary_emb = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ) else: scaling_type = rope_scaling["rope_type"] @@ -81,18 +95,23 @@ def get_rope( scaling_factor = rope_scaling["factor"] low_freq_factor = rope_scaling["low_freq_factor"] high_freq_factor = rope_scaling["high_freq_factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] - rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, dtype, - scaling_factor, low_freq_factor, - high_freq_factor, - original_max_position) + original_max_position = rope_scaling["original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + scaling_factor, + low_freq_factor, + high_freq_factor, + original_max_position, + ) elif scaling_type == "mllama4": - rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, dtype) + rotary_emb = Llama4VisionRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ) elif scaling_type == "default": if "mrope_section" in rope_scaling: rotary_emb = MRotaryEmbedding( @@ -103,8 +122,7 @@ def get_rope( is_neox_style, dtype, mrope_section=rope_scaling["mrope_section"], - mrope_interleaved=rope_scaling.get("mrope_interleaved", - False), + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), ) else: rotary_emb = RotaryEmbedding( @@ -117,41 +135,63 @@ def get_rope( ) elif scaling_type == "linear": scaling_factor = rope_scaling["factor"] - rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, - scaling_factor, dtype) + rotary_emb = LinearScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) elif scaling_type == "ntk": scaling_factor = rope_scaling["factor"] - mixed_b = rope_scaling.get('mixed_b', None) - rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, - scaling_factor, dtype, - mixed_b) + mixed_b = rope_scaling.get("mixed_b", None) + rotary_emb = NTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + mixed_b, + ) elif scaling_type == "dynamic": if "alpha" in rope_scaling: scaling_alpha = 
rope_scaling["alpha"] rotary_emb = DynamicNTKAlphaRotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, - scaling_alpha, dtype) + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_alpha, + dtype, + ) elif "factor" in rope_scaling: scaling_factor = rope_scaling["factor"] rotary_emb = DynamicNTKScalingRotaryEmbedding( - head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor, dtype) + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) else: - raise ValueError("Dynamic rope scaling must contain either " - "'alpha' or 'factor' field") + raise ValueError( + "Dynamic rope scaling must contain either 'alpha' or 'factor' field" + ) elif scaling_type == "yarn": scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] + original_max_position = rope_scaling["original_max_position_embeddings"] extra_kwargs = { k: v for k, v in rope_scaling.items() - if k in ("extrapolation_factor", "attn_factor", "beta_fast", - "beta_slow") + if k + in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow") } if "mrope_section" in rope_scaling: rotary_emb = MRotaryEmbedding( @@ -162,42 +202,69 @@ def get_rope( is_neox_style, dtype, mrope_section=rope_scaling["mrope_section"], - mrope_interleaved=rope_scaling.get("mrope_interleaved", - False), + mrope_interleaved=rope_scaling.get("mrope_interleaved", False), scaling_factor=scaling_factor, - **extra_kwargs) + **extra_kwargs, + ) else: rotary_emb = YaRNScalingRotaryEmbedding( - head_size, rotary_dim, original_max_position, base, - is_neox_style, scaling_factor, dtype, **extra_kwargs) + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) elif scaling_type == "deepseek_yarn": scaling_factor = rope_scaling["factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] + original_max_position = rope_scaling["original_max_position_embeddings"] # assert max_position == original_max_position * scaling_factor extra_kwargs = { k: v for k, v in rope_scaling.items() - if k in ("extrapolation_factor", "attn_factor", "beta_fast", - "beta_slow", "mscale", "mscale_all_dim") + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) } rotary_emb = DeepseekScalingRotaryEmbedding( - head_size, rotary_dim, original_max_position, base, - is_neox_style, scaling_factor, dtype, **extra_kwargs) + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) elif scaling_type == "longrope": short_factor = rope_scaling["short_factor"] long_factor = rope_scaling["long_factor"] - original_max_position = rope_scaling[ - "original_max_position_embeddings"] + original_max_position = rope_scaling["original_max_position_embeddings"] extra_kwargs = { k: v for k, v in rope_scaling.items() if k in ("short_mscale", "long_mscale") } rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( - head_size, rotary_dim, max_position, original_max_position, - base, is_neox_style, dtype, short_factor, long_factor, - **extra_kwargs) + head_size, + rotary_dim, + max_position, + original_max_position, + base, + is_neox_style, + dtype, + short_factor, + long_factor, + **extra_kwargs, + ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb diff --git 
a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 0cf634f82a8a..cf50b60118b9 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Rotary Positional Embeddings Base Class.""" + from typing import Optional import torch @@ -8,8 +9,10 @@ from vllm.model_executor.custom_op import CustomOp from .common import apply_rotary_emb_torch -from .rocm_aiter_rope_ops import (is_rocm_triton_rotary_embedding_enabled, - rocm_aiter_rotary_emb) +from .rocm_aiter_rope_ops import ( + is_rocm_triton_rotary_embedding_enabled, + rocm_aiter_rotary_emb, +) @CustomOp.register("rotary_embedding") @@ -47,8 +50,9 @@ def __init__( cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - self.is_rocm_triton_rotary_embedding_enabled = \ + self.is_rocm_triton_rotary_embedding_enabled = ( is_rocm_triton_rotary_embedding_enabled() + ) def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" @@ -56,8 +60,12 @@ def _compute_inv_freq(self, base: float) -> torch.Tensor: # use CPU to compute the cache and then move it to GPU. However, we # create the cache on GPU for faster initialization. This may cause # a slight numerical difference between the HF implementation and ours. - inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: @@ -74,10 +82,11 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) # is expensive, so avoid calling it if possible - if self.cos_sin_cache.device != query.device or \ - self.cos_sin_cache.dtype != query.dtype: - self.cos_sin_cache = self.cos_sin_cache.to(query.device, - dtype=query.dtype) + if ( + self.cos_sin_cache.device != query.device + or self.cos_sin_cache.dtype != query.dtype + ): + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) def forward_native( self, @@ -93,20 +102,18 @@ def forward_native( query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_emb_torch(query_rot, cos, sin, - self.is_neox_style) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) # key may be None in some cases, e.g. 
cross-layer KV sharing if key is not None: key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_emb_torch(key_rot, cos, sin, - self.is_neox_style) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -117,19 +124,30 @@ def forward_cuda( key: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if self.use_flashinfer: - torch.ops.vllm.flashinfer_rotary_embedding(positions, query, key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style) + torch.ops.vllm.flashinfer_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key from vllm import _custom_ops as ops + self._match_cos_sin_cache_dtype(query) # ops.rotary_embedding() is an in-place operation # that updates the query and key tensors. - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key def forward_hip( @@ -140,9 +158,15 @@ def forward_hip( ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if self.is_rocm_triton_rotary_embedding_enabled: self._match_cos_sin_cache_dtype(query) - rocm_aiter_rotary_emb(positions, query, key, self.cos_sin_cache, - self.head_size, self.rotary_dim, - self.is_neox_style) + rocm_aiter_rotary_emb( + positions, + query, + key, + self.cos_sin_cache, + self.head_size, + self.rotary_dim, + self.is_neox_style, + ) else: # ops.rotary_embedding() is an in-place operation # that updates the query and key tensors. @@ -166,8 +190,14 @@ def forward_xpu( # ipex.llm.functional.rotary_embedding_batched return self.forward_native(positions, query, key) else: - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key def extra_repr(self) -> str: diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 0d11d1ffea9f..124ea0236cbf 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -20,8 +20,8 @@ # common functions def rotate_neox(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -53,9 +53,9 @@ def apply_rotary_emb_torch( return torch.stack((o1, o2), dim=-1).flatten(-2) -def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool) -> torch.Tensor: +def apply_rotary_emb_dispatch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool +) -> torch.Tensor: """ Args: x: [num_tokens, num_heads, head_size] @@ -65,15 +65,14 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor, positional embeddings. 
""" if current_platform.is_cuda(): - return apply_rotary_emb(x.unsqueeze(0), cos, sin, - not is_neox_style).squeeze(0) + return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0) else: return apply_rotary_emb_torch(x, cos, sin, is_neox_style) @cache def dispatch_rotary_emb_function( - default: Optional[Callable[..., torch.Tensor]] = None + default: Optional[Callable[..., torch.Tensor]] = None, ) -> Callable[..., torch.Tensor]: if current_platform.is_cuda(): return apply_rotary_emb @@ -81,11 +80,13 @@ def dispatch_rotary_emb_function( if current_platform.is_rocm(): if find_spec("flash_attn") is not None: from flash_attn.ops.triton.rotary import apply_rotary + return apply_rotary else: logger.warning( "flash_attn is not installed. Falling back to PyTorch " - "implementation for rotary embeddings.") + "implementation for rotary embeddings." + ) if default is not None: return default @@ -95,31 +96,37 @@ def dispatch_rotary_emb_function( # yarn functions # Inverse dim formula to find dim based on number of rotations -def yarn_find_correction_dim(num_rotations: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> float: - return (dim * math.log(max_position_embeddings / - (num_rotations * 2 * math.pi))) / (2 * - math.log(base)) +def yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) # Find dim range bounds based on rotations def yarn_find_correction_range( - low_rot: int, - high_rot: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> tuple[int, int]: + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> tuple[int, int]: low = math.floor( - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) high = math.ceil( - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) return max(low, 0), min(high, dim - 1) # Clamp values just in case -def yarn_linear_ramp_mask(low: float, high: float, dim: int, - dtype: torch.dtype) -> torch.Tensor: +def yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype +) -> torch.Tensor: if low == high: high += 0.001 # Prevent singularity @@ -143,7 +150,7 @@ def _flashinfer_rotary_embedding( is_neox: bool, ) -> None: """Custom op wrapper for flashinfer's rotary embedding. - + This is an in-place operation that modifies query and key tensors directly. 
""" from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index 736ec2c1dd3a..eaedca9b5219 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -9,8 +9,12 @@ from vllm.platforms import current_platform from .base import RotaryEmbedding -from .common import (rotate_gptj, rotate_neox, yarn_find_correction_range, - yarn_linear_ramp_mask) +from .common import ( + rotate_gptj, + rotate_neox, + yarn_find_correction_range, + yarn_linear_ramp_mask, +) def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: @@ -49,42 +53,56 @@ def __init__( self.beta_slow = beta_slow # Get n-d magnitude scaling corrected for interpolation. self.mscale = float( - yarn_get_mscale(self.scaling_factor, float(mscale)) / - yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * - attn_factor) - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + yarn_get_mscale(self.scaling_factor, float(mscale)) + / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) + * attn_factor + ) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base**( - torch.arange(0, - self.rotary_dim, - 2, - dtype=torch.float, - device=current_platform.device_type) / - self.rotary_dim) + pos_freqs = self.base ** ( + torch.arange( + 0, + self.rotary_dim, + 2, + dtype=torch.float, + device=current_platform.device_type, + ) + / self.rotary_dim + ) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow, - self.rotary_dim, self.base, - self.max_position_embeddings) + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = (1 - yarn_linear_ramp_mask( - low, high, self.rotary_dim // 2, - dtype=torch.float)) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * ( - 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + inv_freq_mask = ( + 1 + - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) - t = torch.arange(self.max_position_embeddings * self.scaling_factor, - device=current_platform.device_type, - dtype=torch.float32) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=current_platform.device_type, + dtype=torch.float32, + ) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = (freqs.cos() * self.mscale) - sin = (freqs.sin() * self.mscale) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale cache = torch.cat((cos, sin), dim=-1) return cache @@ -98,14 +116,15 @@ def forward_native( """PyTorch-native implementation equivalent to forward().""" assert key is not None self._match_cos_sin_cache_dtype(query) - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., 
:self.rotary_dim] + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] - cos_sin = self.cos_sin_cache[torch.add(positions, offsets) - if offsets is not None else positions] + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: # NOTE(woosuk): Here we assume that the positions tensor has the diff --git a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py index 27e41dd0fa97..0e6eddda772f 100644 --- a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py @@ -35,18 +35,17 @@ def __init__( self.local_size = local_size self.dtype = dtype self.device = torch.device(f"cuda:{torch.cuda.current_device()}") - (q_cache, qc_cache, k_cache, qc_no_clamp_cache, - q_inter_cache) = self._compute_cos_sin_cache() + (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = ( + self._compute_cos_sin_cache() + ) self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) - self.register_buffer("cos_sin_qc_no_clamp_cache", - qc_no_clamp_cache, - persistent=False) - self.register_buffer("cos_sin_q_inter_cache", - q_inter_cache, - persistent=False) + self.register_buffer( + "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False + ) + self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False) def _compute_inv_freq(self, base: float) -> torch.Tensor: """Compute the inverse frequency.""" @@ -59,8 +58,12 @@ def _compute_inv_freq(self, base: float) -> torch.Tensor: # use CPU to compute the cache and then move it to GPU. However, we # create the cache on GPU for faster initialization. This may cause # a slight numerical difference between the HF implementation and ours. 
- inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: @@ -68,16 +71,15 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) chunk_len = self.chunk_size - self.local_size q_t = torch.arange(chunk_len, dtype=torch.float) - qc_t = (torch.arange(chunk_len, dtype=torch.float) + - chunk_len).clamp(max=self.chunk_size) - k_t = torch.arange(self.max_position_embeddings, - dtype=torch.float) % chunk_len + qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp( + max=self.chunk_size + ) + k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len # count from chunk_len, no clamp(self.chunk_size) restriction qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len # count from self.chunk_size for q_inter's rope - q_inter_t = torch.arange(chunk_len, - dtype=torch.float) + self.chunk_size + q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size q_freqs = torch.outer(q_t, inv_freq) qc_freqs = torch.outer(qc_t, inv_freq) @@ -97,18 +99,21 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: q_inter_cos = q_inter_freqs.cos() q_inter_sin = q_inter_freqs.sin() - q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype, - device=self.device) - qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype, - device=self.device) - k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype, - device=self.device) - qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), - dim=-1).to(dtype=self.dtype, - device=self.device) - q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), - dim=-1).to(dtype=self.dtype, - device=self.device) + q_cache = torch.cat((q_cos, q_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + k_cache = torch.cat((k_cos, k_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) + q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to( + dtype=self.dtype, device=self.device + ) return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache def forward_native( @@ -120,45 +125,59 @@ def forward_native( ) -> tuple[torch.Tensor, torch.Tensor]: query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., :self.rotary_dim] + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] else: query_pass = None key_pass = None - positions_with_offsets = (torch.add(positions, offsets) - if offsets is not None else positions) + positions_with_offsets = ( + torch.add(positions, offsets) if offsets is not None else positions + ) key = self._apply_rotary_embedding( - self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass) + self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass + ) chunk_len = self.chunk_size - self.local_size query = 
self._apply_rotary_embedding( self.cos_sin_q_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) + query_rot, + query_pass, + ) query_succ = self._apply_rotary_embedding( self.cos_sin_qc_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) + query_rot, + query_pass, + ) query_inter = self._apply_rotary_embedding( self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), - query_rot, query_pass) + query_rot, + query_pass, + ) query_succ_critical = self._apply_rotary_embedding( self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) + query_rot, + query_pass, + ) query_inter_critical = self._apply_rotary_embedding( self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], - query_rot, query_pass) + query_rot, + query_pass, + ) # merge query into one tensor to simplify the interfaces - query = torch.cat(( - query, - query_succ, - query_inter, - query_succ_critical, - query_inter_critical, - ), - dim=-1) + query = torch.cat( + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ), + dim=-1, + ) return query, key def forward_cuda( diff --git a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py index 1da39bbd303b..dd9d06d4b288 100644 --- a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py @@ -23,14 +23,16 @@ def __init__( dtype: torch.dtype, ) -> None: self.scaling_alpha = scaling_alpha - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_cos_sin_cache(self) -> torch.Tensor: # For Hunyuan DynamicNTKAlphaRotaryEmbedding max_len = self.max_position_embeddings - base = self.base * self.scaling_alpha**(self.rotary_dim / - (self.rotary_dim - 2)) + base = self.base * self.scaling_alpha ** ( + self.rotary_dim / (self.rotary_dim - 2) + ) inv_freq = self._compute_inv_freq(base) t = torch.arange(max_len, dtype=torch.float) diff --git a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py index ec2008b90cfb..28fd87ecc21f 100644 --- a/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py @@ -44,8 +44,9 @@ def __init__( dtype: torch.dtype, ) -> None: self.scaling_factor = scaling_factor - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_cos_sin_cache(self) -> torch.Tensor: # NOTE(woosuk): self.max_position_embeddings is the original @@ -54,9 +55,9 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: # self.max_position_embeddings * self.scaling_factor. 
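The hunk below only reflows the dynamic NTK base adjustment; as a rough standalone sketch of the same arithmetic (the helper name and plain-float return are assumptions for illustration):

    def dynamic_ntk_base(base: float, scaling_factor: float,
                         max_position_embeddings: int, rotary_dim: int) -> float:
        # The cache is sized for max_position_embeddings * scaling_factor
        # positions, and the rotary base is enlarged before computing the
        # inverse frequencies, as in the reflowed expression below.
        max_len = max_position_embeddings * scaling_factor
        return base * (
            (scaling_factor * max_len / max_position_embeddings)
            - (scaling_factor - 1)
        ) ** (rotary_dim / (rotary_dim - 2))
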
max_len = self.max_position_embeddings * self.scaling_factor base = self.base * ( - (self.scaling_factor * max_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**(self.rotary_dim / - (self.rotary_dim - 2)) + (self.scaling_factor * max_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.rotary_dim / (self.rotary_dim - 2)) inv_freq = self._compute_inv_freq(base) t = torch.arange(max_len, dtype=torch.float) diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index 4960c20f4060..2bc0477c5af2 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -33,41 +33,37 @@ def forward_native( # type: ignore[override] assert section_h == section_w # Split according to [h w h w h w h w... t t t...] section_cos_t = cos[..., -section_t:] - section_cos_h = cos[..., :section_h + section_w:2] - section_cos_w = cos[..., 1:section_h + section_w:2] + section_cos_h = cos[..., : section_h + section_w : 2] + section_cos_w = cos[..., 1 : section_h + section_w : 2] - cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[ - 1], section_cos_w[2] - cos_hw = torch.stack([cos_h, cos_w], - dim=-1).reshape(cos_h.shape[:-1] + - (cos_h.shape[-1] * 2, )) + cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[1], section_cos_w[2] + cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape( + cos_h.shape[:-1] + (cos_h.shape[-1] * 2,) + ) cos = torch.cat([cos_hw, cos_t], dim=-1) section_sin_t = sin[..., -section_t:] - section_sin_h = sin[..., :section_h + section_w:2] - section_sin_w = sin[..., 1:section_h + section_w:2] + section_sin_h = sin[..., : section_h + section_w : 2] + section_sin_w = sin[..., 1 : section_h + section_w : 2] - sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[ - 1], section_sin_w[2] - sin_hw = torch.stack([sin_h, sin_w], - dim=-1).reshape(sin_h.shape[:-1] + - (sin_h.shape[-1] * 2, )) + sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[1], section_sin_w[2] + sin_hw = torch.stack([sin_h, sin_w], dim=-1).reshape( + sin_h.shape[:-1] + (sin_h.shape[-1] * 2,) + ) sin = torch.cat([sin_hw, sin_t], dim=-1) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, - self.is_neox_style) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, - self.is_neox_style) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -77,4 +73,4 @@ def forward_cuda( # type: ignore[override] query: torch.Tensor, key: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - return self.forward_native(positions, query, key) \ No newline at end of file + return self.forward_native(positions, query, key) diff --git 
a/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py index 6e920991882d..cbb3ee4e9974 100644 --- a/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py @@ -71,8 +71,9 @@ def __init__( if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] self.scaling_factors: list[float] = scaling_factors # noqa - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) # Lazy initialized. self._scaling_factor_to_offset: dict[float, int] diff --git a/vllm/model_executor/layers/rotary_embedding/llama3_rope.py b/vllm/model_executor/layers/rotary_embedding/llama3_rope.py index adcef549bc4c..ed9a6031eb6f 100644 --- a/vllm/model_executor/layers/rotary_embedding/llama3_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/llama3_rope.py @@ -9,7 +9,6 @@ class Llama3RotaryEmbedding(RotaryEmbedding): - def __init__( self, head_size: int, @@ -27,8 +26,9 @@ def __init__( self.low_freq_factor = low_freq_factor self.high_freq_factor = high_freq_factor self.orig_max_position = orig_max_position - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, base: float) -> torch.Tensor: inv_freqs = super()._compute_inv_freq(base) @@ -37,8 +37,9 @@ def _compute_inv_freq(self, base: float) -> torch.Tensor: wave_len = 2 * math.pi / inv_freqs if self.low_freq_factor != self.high_freq_factor: - smooth = (self.orig_max_position / wave_len - self.low_freq_factor - ) / (self.high_freq_factor - self.low_freq_factor) + smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) else: smooth = 0 new_freqs = torch.where( @@ -47,8 +48,7 @@ def _compute_inv_freq(self, base: float) -> torch.Tensor: torch.where( wave_len > low_freq_wavelen, inv_freqs / self.scaling_factor, - (1 - smooth) * inv_freqs / self.scaling_factor + - smooth * inv_freqs, + (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs, ), ) return new_freqs diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py index c98a426a2a1e..0b808e31c903 100644 --- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py @@ -10,7 +10,6 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): - def __init__( self, head_size: int, @@ -20,12 +19,13 @@ def __init__( is_neox_style: bool, dtype: torch.dtype, ): - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, base: float) -> torch.Tensor: inv_freqs = super()._compute_inv_freq(base) - inv_freqs = inv_freqs[:(self.rotary_dim // 2)] + inv_freqs = inv_freqs[: (self.rotary_dim // 2)] return inv_freqs def _compute_cos_sin_cache(self) -> torch.Tensor: @@ -34,23 +34,23 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: # self.max_position_embeddings here is number of image patches # i.e. 
(image_size // patch_size) ** 2 num_patches = self.max_position_embeddings - img_idx = torch.arange(num_patches, - dtype=torch.int32) \ - .reshape(num_patches, 1) + img_idx = torch.arange(num_patches, dtype=torch.int32).reshape(num_patches, 1) img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) img_idx[-1, -1] = -2 # set to ID_CLS_TOKEN num_patches_single_dim = int(math.sqrt(num_patches)) frequencies_x = img_idx % num_patches_single_dim frequencies_y = img_idx // num_patches_single_dim - freqs_x = ((frequencies_x + 1)[..., None] * - inv_freq[None, None, :]).repeat_interleave(2, dim=-1) - freqs_y = ((frequencies_y + 1)[..., None] * - inv_freq[None, None, :]).repeat_interleave(2, dim=-1) - freqs = torch.cat([freqs_x, freqs_y], - dim=-1).float().contiguous()[..., ::2] + freqs_x = ( + (frequencies_x + 1)[..., None] * inv_freq[None, None, :] + ).repeat_interleave(2, dim=-1) + freqs_y = ( + (frequencies_y + 1)[..., None] * inv_freq[None, None, :] + ).repeat_interleave(2, dim=-1) + freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) cache = torch.view_as_complex( - torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)) + torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) + ) return cache def forward_native( # type: ignore[override] @@ -62,10 +62,8 @@ def forward_native( # type: ignore[override] # self.cos_sin_cache here is complex tensor so we cannot cast into # query's dtype directly with self._match_cos_sin_cache_dtype self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) - query_ = torch.view_as_complex(query.float().reshape( - *query.shape[:-1], -1, 2)) - key_ = torch.view_as_complex(key.float().reshape( - *key.shape[:-1], -1, 2)) + query_ = torch.view_as_complex(query.float().reshape(*query.shape[:-1], -1, 2)) + key_ = torch.view_as_complex(key.float().reshape(*key.shape[:-1], -1, 2)) broadcast_shape = [ d if i == 1 or i == (query_.ndim - 1) else 1 for i, d in enumerate(query_.shape) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 9bf0d6bd15e7..120979970679 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -62,10 +62,8 @@ def _triton_mrope_forward( # Updated offsets for half head_dim cos_offsets = tl.arange(0, pad_hd // 2) if is_interleaved: - h_mask = (((cos_offsets % 3) == 1) & - (cos_offsets <= 3 * mrope_section_h)) - w_mask = (((cos_offsets % 3) == 2) & - (cos_offsets <= 3 * mrope_section_w)) + h_mask = ((cos_offsets % 3) == 1) & (cos_offsets <= 3 * mrope_section_h) + w_mask = ((cos_offsets % 3) == 2) & (cos_offsets <= 3 * mrope_section_w) t_mask = ~(h_mask | w_mask) else: t_end = mrope_section_t @@ -89,21 +87,25 @@ def _triton_mrope_forward( # program instance (i.e. 
for the current token) separately # #################################################################### # left half of the head - first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange( - 0, pad_hd // 2)[None, :] - first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange( - 0, pad_hd // 2)[None, :] - first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange( - 0, pad_hd // 2)[None, :] < rd // 2) - first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange( - 0, pad_hd // 2)[None, :] < rd // 2) - - q_tile_1 = tl.load(q_ptr + first_half_q_offsets, - mask=first_q_mask, - other=0).to(sin_row.dtype) - k_tile_1 = tl.load(k_ptr + first_half_k_offsets, - mask=first_k_mask, - other=0).to(sin_row.dtype) + first_half_q_offsets = ( + tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + ) + first_half_k_offsets = ( + tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] + ) + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( + tl.arange(0, pad_hd // 2)[None, :] < rd // 2 + ) + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( + tl.arange(0, pad_hd // 2)[None, :] < rd // 2 + ) + + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( + sin_row.dtype + ) + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( + sin_row.dtype + ) # right half of the head second_half_q_offsets = first_half_q_offsets + (rd // 2) @@ -111,12 +113,12 @@ def _triton_mrope_forward( second_q_mask = first_q_mask second_k_mask = first_k_mask - q_tile_2 = tl.load(q_ptr + second_half_q_offsets, - mask=second_q_mask, - other=0).to(sin_row.dtype) - k_tile_2 = tl.load(k_ptr + second_half_k_offsets, - mask=second_k_mask, - other=0).to(sin_row.dtype) + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to( + sin_row.dtype + ) + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to( + sin_row.dtype + ) # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] # Since cos and sin are now half-size, @@ -168,7 +170,7 @@ def triton_mrope( cos = cos.contiguous() sin = sin.contiguous() - _triton_mrope_forward[(n_row, )]( + _triton_mrope_forward[(n_row,)]( q, k, cos, @@ -189,15 +191,14 @@ def triton_mrope( return q, k -def apply_interleaved_rope(x: torch.Tensor, - mrope_section: list[int]) -> torch.Tensor: +def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: """Apply interleaved MRoPE to 3D rotary embeddings. Reorganizes frequency layout from chunked [TTT...HHH...WWW] to interleaved [THTHWHTHW...TT], preserving frequency continuity. 
""" x_t = x[0].clone() - x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3] - x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3] + x_t[..., 1 : mrope_section[1] * 3 : 3] = x[1, ..., 1 : mrope_section[1] * 3 : 3] + x_t[..., 2 : mrope_section[2] * 3 : 3] = x[2, ..., 2 : mrope_section[2] * 3 : 3] return x_t @@ -222,7 +223,6 @@ def __init__( beta_fast: int = 32, beta_slow: int = 1, ) -> None: - self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor self.attn_factor = attn_factor @@ -230,8 +230,7 @@ def __init__( self.beta_slow = beta_slow if self.scaling_factor is not None: # Get n-d magnitude scaling corrected for interpolation - self.mscale = float( - yarn_get_mscale(self.scaling_factor) * attn_factor) + self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor) else: self.mscale = 1.0 @@ -239,8 +238,14 @@ def __init__( # the input video. We enlarge max_position_embeddings to 4 times to get # a larger the cos and sin cache. self.cache_max_position_num = max_position_embeddings * 4 - super().__init__(head_size, rotary_dim, self.cache_max_position_num, - base, is_neox_style, dtype) + super().__init__( + head_size, + rotary_dim, + self.cache_max_position_num, + base, + is_neox_style, + dtype, + ) self.mrope_section = mrope_section self.mrope_interleaved = mrope_interleaved @@ -286,31 +291,27 @@ def forward_native( cos = apply_interleaved_rope(cos, self.mrope_section) sin = apply_interleaved_rope(sin, self.mrope_section) else: - cos = torch.cat([ - m[i] for i, m in enumerate( - cos.split(self.mrope_section, dim=-1)) - ], - dim=-1) - sin = torch.cat([ - m[i] for i, m in enumerate( - sin.split(self.mrope_section, dim=-1)) - ], - dim=-1) + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, - self.is_neox_style) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, - self.is_neox_style) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -321,7 +322,6 @@ def forward_cuda( key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert positions.ndim == 1 or positions.ndim == 2 assert key is not None @@ -348,17 +348,15 @@ def forward_cuda( return q.reshape(query_shape), k.reshape(key_shape) query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, - self.is_neox_style) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + 
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, - self.is_neox_style) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -397,21 +395,19 @@ def get_input_positions( image_grid_thw = [] if image_grid_thw is None else image_grid_thw video_grid_thw = [] if video_grid_thw is None else video_grid_thw - second_per_grid_ts = [] if second_per_grid_ts is None else \ - second_per_grid_ts - - llm_positions, mrope_position_delta = \ - cls.get_input_positions_tensor( - input_tokens=input_tokens, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=context_len, - seq_len=seq_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) + second_per_grid_ts = [] if second_per_grid_ts is None else second_per_grid_ts + + llm_positions, mrope_position_delta = cls.get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) return llm_positions.tolist(), mrope_position_delta @@ -429,6 +425,7 @@ def get_input_positions_tensor( use_audio_in_video: bool = False, ) -> tuple[torch.Tensor, int]: from vllm.transformers_utils.config import thinker_uses_mrope + if thinker_uses_mrope(hf_config): return cls._omni_get_input_positions_tensor( input_tokens=input_tokens, @@ -527,7 +524,8 @@ def _glm4v_get_input_positions_tensor( input_type_group: list[tuple[str, int, int]] = [] for key, group_iter in itertools.groupby( - enumerate(input_token_type), lambda x: x[1]): + enumerate(input_token_type), lambda x: x[1] + ): group_list = list(group_iter) start_index = group_list[0][0] end_index = group_list[-1][0] + 1 @@ -536,25 +534,42 @@ def _glm4v_get_input_positions_tensor( video_frame_num = 1 mm_data_idx = 0 for modality_type, start_idx, end_idx in input_type_group: - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) if modality_type == "image": t, h, w = ( image_grid_thw[mm_data_idx][0], image_grid_thw[mm_data_idx][1], image_grid_thw[mm_data_idx][2], ) - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size - - t_index = torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() 
+ ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx) + torch.stack([t_index, h_index, w_index]) + st_idx + ) mm_data_idx += 1 elif modality_type == "video": @@ -563,18 +578,34 @@ def _glm4v_get_input_positions_tensor( image_grid_thw[mm_data_idx][1], image_grid_thw[mm_data_idx][2], ) - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) for t_idx in range(llm_grid_t): - t_index = torch.tensor(t_idx).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view( - 1, -1, 1).expand(1, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view( - 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + t_index = ( + torch.tensor(t_idx) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(1, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(1, llm_grid_h, -1) + .flatten() + ) llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx) + torch.stack([t_index, h_index, w_index]) + st_idx + ) mm_data_idx += 1 video_frame_num += 1 @@ -582,19 +613,17 @@ def _glm4v_get_input_positions_tensor( else: text_len = end_idx - start_idx llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + - st_idx) + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) video_frame_num = 1 else: text_len = len(input_tokens) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1)) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = llm_positions[:, context_len:seq_len] - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta @classmethod @@ -609,8 +638,7 @@ def _qwen3vl_get_input_positions_tensor( ) -> tuple[torch.Tensor, int]: """Get mrope input positions and delta value.""" - video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw - for _ in range(t)] + video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw for _ in range(t)] image_token_id = hf_config.image_token_id video_token_id = hf_config.video_token_id @@ -619,7 +647,8 @@ def _qwen3vl_get_input_positions_tensor( input_tokens_tensor = torch.tensor(input_tokens) vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id).squeeze(1) + input_tokens_tensor == vision_start_token_id + ).squeeze(1) vision_tokens = input_tokens_tensor[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() @@ -657,35 +686,50 @@ def _qwen3vl_get_input_positions_tensor( remain_videos -= 1 ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) text_len = ed - st - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( - 
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - t_index = torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta @@ -729,7 +773,8 @@ def _ernie_get_input_positions_tensor( input_type_group: list[tuple[str, int, int]] = [] for key, group_iter in itertools.groupby( - enumerate(input_token_type), lambda x: x[1]): + enumerate(input_token_type), lambda x: x[1] + ): group_list = list(group_iter) start_index = group_list[0][0] end_index = group_list[-1][0] + 1 @@ -738,25 +783,42 @@ def _ernie_get_input_positions_tensor( video_frame_num = 1 mm_data_idx = 0 for modality_type, start_idx, end_idx in input_type_group: - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) if modality_type == "image": t, h, w = ( image_grid_thw[mm_data_idx][0], image_grid_thw[mm_data_idx][1], image_grid_thw[mm_data_idx][2], ) - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_conv_size, w // spatial_conv_size - - t_index = torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_conv_size, + w // spatial_conv_size, + ) + + t_index = ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx) + torch.stack([t_index, h_index, w_index]) + st_idx + ) mm_data_idx += 1 elif modality_type == "video": @@ -765,22 +827,34 @@ def _ernie_get_input_positions_tensor( 
video_grid_thw[mm_data_idx][1], video_grid_thw[mm_data_idx][2], ) - llm_grid_t, llm_grid_h, llm_grid_w = (t // - temporal_conv_size, - h // - spatial_conv_size, - w // - spatial_conv_size) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t // temporal_conv_size, + h // spatial_conv_size, + w // spatial_conv_size, + ) for t_idx in range(llm_grid_t): - t_index = torch.tensor(t_idx).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view( - 1, -1, 1).expand(1, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view( - 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + t_index = ( + torch.tensor(t_idx) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(1, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(1, llm_grid_h, -1) + .flatten() + ) llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + st_idx) + torch.stack([t_index, h_index, w_index]) + st_idx + ) mm_data_idx += 1 video_frame_num += 1 @@ -788,19 +862,17 @@ def _ernie_get_input_positions_tensor( else: text_len = end_idx - start_idx llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + - st_idx) + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) video_frame_num = 1 else: text_len = len(input_tokens) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1)) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1)) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = llm_positions[:, context_len:seq_len] - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() return llm_positions, mrope_position_delta @classmethod @@ -817,8 +889,7 @@ def _keye_get_input_positions_tensor( video_grid_thw = video_grid_thw[0] """Get mrope input positions and delta value (Keye series).""" - def split_thw( - grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]: + def split_thw(grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]: """ Split grid_thw along the t dimension. 
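Splitting along t presumably expands each [t, h, w] entry into t per-frame [1, h, w] entries, in the same spirit as the list comprehension applied to video_grid_thw earlier in this file; a small illustrative example (values made up):

    grid_thw = [[2, 6, 8], [1, 4, 4]]
    per_frame = [[1, h, w] for t, h, w in grid_thw for _ in range(t)]
    # per_frame == [[1, 6, 8], [1, 6, 8], [1, 4, 4]]
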
@@ -889,36 +960,54 @@ def split_thw( remain_frames -= 1 ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) text_len = ed - st - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) - t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w)).long().flatten() + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + ) + .long() + .flatten() + ) - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta @@ -940,12 +1029,12 @@ def _vl_get_input_positions_tensor( video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size - tokens_per_second = getattr(hf_config.vision_config, - "tokens_per_second", 1.0) + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) input_tokens_tensor = torch.tensor(input_tokens) vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id).squeeze(1) + input_tokens_tensor == vision_start_token_id + ).squeeze(1) vision_tokens = input_tokens_tensor[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() @@ -993,37 +1082,56 @@ def _vl_get_input_positions_tensor( remain_videos -= 1 ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) text_len = ed - st - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) - t_index = 
(torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * - tokens_per_second).long().flatten() + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta @@ -1070,8 +1178,9 @@ def _omni_get_input_positions_tensor( vision_end_token_id = thinker_config.vision_end_token_id seconds_per_chunk = thinker_config.seconds_per_chunk spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr(thinker_config.vision_config, - "tokens_per_second", 25) + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) if isinstance(image_grid_thw, list): image_grid_thw = torch.tensor(image_grid_thw) @@ -1091,28 +1200,30 @@ def _omni_get_input_positions_tensor( idx = 0 while idx < len(src_item): new_src_item_len = len(new_src_item) - start_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 - if src_item[idx] not in [ - audio_token_id, video_token_id, image_token_id - ]: + start_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]: if use_audio_in_video and idx > 0: - if src_item[idx] == vision_end_token_id and \ - src_item[idx - 1] == audio_end_token_id: + if ( + src_item[idx] == vision_end_token_id + and src_item[idx - 1] == audio_end_token_id + ): # processing the <|audio_eos|> before <|vision_eos|> start_idx -= 1 - elif src_item[idx] == audio_start_token_id and \ - src_item[idx - 1] == vision_start_token_id: + elif ( + src_item[idx] == audio_start_token_id + and src_item[idx - 1] == vision_start_token_id + ): # processing the <|audio_bos|> after <|vision_eos|> start_idx -= 1 new_src_item.append(src_item[idx]) - llm_pos_ids = torch.tensor([start_idx], - dtype=torch.long).expand(3, -1) + llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1) llm_pos_ids_list.append(llm_pos_ids) elif src_item[idx] == audio_token_id: assert audio_seqlens is not None audio_seqlen = audio_seqlens[audio_idx] - place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 
+ 1) + place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1 new_src_item.extend([audio_token_id] * place_num) llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx llm_pos_ids_list.append(llm_pos_ids) @@ -1123,26 +1234,30 @@ def _omni_get_input_positions_tensor( grid_ws = image_grid_thw[:, 2] t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long() llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, image_idx, spatial_merge_size, t_index, grid_hs, - grid_ws) + start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) llm_pos_ids_list.append(llm_pos_ids) vision_seqlen = image_grid_thw[image_idx].prod() // ( - spatial_merge_size**2) + spatial_merge_size**2 + ) new_src_item.extend([image_token_id] * vision_seqlen) image_idx += 1 elif src_item[idx] == video_token_id and not use_audio_in_video: grid_t = video_grid_thw[video_idx][0] grid_hs = video_grid_thw[:, 1] grid_ws = video_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * - second_per_grid_ts[video_idx] * - tokens_per_second).long() + t_index = ( + torch.arange(grid_t) + * second_per_grid_ts[video_idx] + * tokens_per_second + ).long() llm_pos_ids = cls._get_llm_pos_ids_for_vision( - start_idx, video_idx, spatial_merge_size, t_index, grid_hs, - grid_ws) + start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) llm_pos_ids_list.append(llm_pos_ids) vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2) + spatial_merge_size**2 + ) new_src_item.extend([video_token_id] * vision_seqlen) video_idx += 1 else: @@ -1150,56 +1265,73 @@ def _omni_get_input_positions_tensor( assert audio_seqlens is not None audio_seqlen = audio_seqlens[audio_idx] vision_seqlen = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2) + spatial_merge_size**2 + ) grid_t = video_grid_thw[video_idx][0] grid_h = video_grid_thw[video_idx][1] grid_w = video_grid_thw[video_idx][2] grid_hs = video_grid_thw[:, 1] grid_ws = video_grid_thw[:, 2] t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = (torch.arange(grid_t) * - second_per_grid_ts[video_idx] * - tokens_per_second).long() + t_index = ( + torch.arange(grid_t) + * second_per_grid_ts[video_idx] + * tokens_per_second + ).long() t_index_split_chunk = cls._split_list_into_ranges( - t_index, t_ntoken_per_chunk) + t_index, t_ntoken_per_chunk + ) place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 pure_audio_len = place_num - 2 added_audio_len = 0 audio_llm_pos_ids_list: list[torch.Tensor] = [] for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = len( - t_chunk) * grid_h * grid_w // (spatial_merge_size**2) - new_src_item.extend([video_token_id] * - vision_ntoken_per_chunk) + vision_ntoken_per_chunk = ( + len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + ) + new_src_item.extend([video_token_id] * vision_ntoken_per_chunk) vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision( - start_idx, video_idx, spatial_merge_size, t_chunk, - grid_hs, grid_ws).split(1, dim=1) + start_idx, + video_idx, + spatial_merge_size, + t_chunk, + grid_hs, + grid_ws, + ).split(1, dim=1) llm_pos_ids_list.extend(vision_llm_pos_ids_list) new_src_item.extend( - min(t_ntoken_per_chunk, pure_audio_len - - added_audio_len) * [audio_token_id]) - audio_start_idx = start_idx if len( - audio_llm_pos_ids_list - ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1 - if min(t_ntoken_per_chunk, - pure_audio_len - added_audio_len) > 0: - audio_llm_pos_ids_list = (torch.arange( - min(t_ntoken_per_chunk, pure_audio_len - - 
added_audio_len)).expand(3, -1) + - audio_start_idx).split(1, - dim=1) + min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) + * [audio_token_id] + ) + audio_start_idx = ( + start_idx + if len(audio_llm_pos_ids_list) == 0 + else audio_llm_pos_ids_list[-1][0].item() + 1 + ) + if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0: + audio_llm_pos_ids_list = ( + torch.arange( + min( + t_ntoken_per_chunk, pure_audio_len - added_audio_len + ) + ).expand(3, -1) + + audio_start_idx + ).split(1, dim=1) else: audio_llm_pos_ids_list = [] - added_audio_len += min(t_ntoken_per_chunk, - pure_audio_len - added_audio_len) + added_audio_len += min( + t_ntoken_per_chunk, pure_audio_len - added_audio_len + ) llm_pos_ids_list.extend(audio_llm_pos_ids_list) if added_audio_len < pure_audio_len: new_src_item.extend( - (pure_audio_len - added_audio_len) * [audio_token_id]) + (pure_audio_len - added_audio_len) * [audio_token_id] + ) audio_llm_pos_ids_list = ( - torch.arange(pure_audio_len - added_audio_len).expand( - 3, -1) + llm_pos_ids_list[-1].max() + 1).split( - 1, dim=1) + torch.arange(pure_audio_len - added_audio_len).expand(3, -1) + + llm_pos_ids_list[-1].max() + + 1 + ).split(1, dim=1) llm_pos_ids_list.extend(audio_llm_pos_ids_list) audio_idx += 1 video_idx += 1 @@ -1207,8 +1339,9 @@ def _omni_get_input_positions_tensor( idx += len(new_src_item) - new_src_item_len llm_positions = torch.cat(llm_pos_ids_list, dim=1) - mrope_position_delta = torch.cat(llm_pos_ids_list, - dim=1).max() + 1 - len(src_item) + mrope_position_delta = ( + torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item) + ) llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta @@ -1225,22 +1358,34 @@ def _get_llm_pos_ids_for_vision( llm_pos_ids_list = [] llm_grid_h = grid_hs[vision_idx] // spatial_merge_size llm_grid_w = grid_ws[vision_idx] // spatial_merge_size - h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand( - len(t_index), -1, llm_grid_w).flatten()) - w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( - len(t_index), llm_grid_h, -1).flatten()) - t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view( - -1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten() + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(len(t_index), -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(len(t_index), llm_grid_h, -1) + .flatten() + ) + t_index_tensor = ( + torch.Tensor(t_index) + .to(llm_grid_h.device) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .long() + .flatten() + ) _llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index]) llm_pos_ids_list.append(_llm_pos_ids + start_idx) llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) return llm_pos_ids @staticmethod - def _split_list_into_ranges(lst: torch.Tensor, - interval: int) -> list[list[int]]: - ranges: list[list[int]] = [[] - for _ in range((max(lst) // interval) + 1)] + def _split_list_into_ranges(lst: torch.Tensor, interval: int) -> list[list[int]]: + ranges: list[list[int]] = [[] for _ in range((max(lst) // interval) + 1)] for num in lst: index = num // interval ranges[index].append(num) @@ -1254,19 +1399,27 @@ def get_next_input_positions( ) -> list[list[int]]: return [ list( - range(context_len + mrope_position_delta, - seq_len + mrope_position_delta)) for _ in range(3) + range( + context_len + mrope_position_delta, seq_len + mrope_position_delta + ) + ) + for _ in range(3) ] @staticmethod - def 
get_next_input_positions_tensor(out: np.ndarray, out_offset: int, - mrope_position_delta: int, - context_len: int, num_new_tokens: int): - - values = np.arange(mrope_position_delta + context_len, - mrope_position_delta + context_len + num_new_tokens, - dtype=out.dtype) - out[:, out_offset:out_offset + num_new_tokens] = values + def get_next_input_positions_tensor( + out: np.ndarray, + out_offset: int, + mrope_position_delta: int, + context_len: int, + num_new_tokens: int, + ): + values = np.arange( + mrope_position_delta + context_len, + mrope_position_delta + context_len + num_new_tokens, + dtype=out.dtype, + ) + out[:, out_offset : out_offset + num_new_tokens] = values @classmethod def omni_get_updates_use_audio_in_video( @@ -1291,27 +1444,28 @@ def omni_get_updates_use_audio_in_video( audio_end_token_id = thinker_config.audio_end_token_id seconds_per_chunk = thinker_config.seconds_per_chunk spatial_merge_size = thinker_config.vision_config.spatial_merge_size - tokens_per_second = getattr(thinker_config.vision_config, - "tokens_per_second", 25) + tokens_per_second = getattr( + thinker_config.vision_config, "tokens_per_second", 25 + ) grid_t = video_grid_thw[0] grid_h = video_grid_thw[1] grid_w = video_grid_thw[2] t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk) - t_index = (torch.arange(grid_t) * video_second_per_grid_t * - tokens_per_second).long() - t_index_split_chunk = cls._split_list_into_ranges( - t_index, t_ntoken_per_chunk) + t_index = ( + torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second + ).long() + t_index_split_chunk = cls._split_list_into_ranges(t_index, t_ntoken_per_chunk) updates = [audio_start_token_id] added_audio_len = 0 for t_chunk in t_index_split_chunk: - vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // ( - spatial_merge_size**2) + vision_ntoken_per_chunk = ( + len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2) + ) updates.extend([video_token_id] * vision_ntoken_per_chunk) - audio_chunk_size = min(t_ntoken_per_chunk, - audio_len - added_audio_len) + audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len) updates.extend(audio_chunk_size * [audio_token_id]) added_audio_len += audio_chunk_size if added_audio_len < audio_len: diff --git a/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py index 42926bad22ef..560fb100413d 100644 --- a/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py @@ -10,33 +10,39 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with fixed and mixed NTK scaling. 
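For context on the fixed-NTK branch reformatted below: the base is multiplied by the scaling factor and the resulting inverse frequencies are divided by scaling_factor ** (2 / rotary_dim). A minimal sketch of that branch only (the helper name is illustrative; the mixed_b branch is omitted):

    import torch

    def fixed_ntk_inv_freq(base: float, rotary_dim: int,
                           scaling_factor: float) -> torch.Tensor:
        # Fixed NTK: enlarge the base, then rescale the inverse frequencies,
        # mirroring the "mixed_b is None" path in the hunk below.
        exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
        inv_freq = 1.0 / ((base * scaling_factor) ** exponents)
        return inv_freq / scaling_factor ** (2 / rotary_dim)
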
- https://kexue.fm/archives/9706 """ - - def __init__(self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: float, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - mixed_b: Optional[float] = None) -> None: + https://kexue.fm/archives/9706""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + mixed_b: Optional[float] = None, + ) -> None: self.scaling_factor = scaling_factor self.mixed_b = mixed_b - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, base: float) -> torch.Tensor: base = self.base * (self.scaling_factor if self.mixed_b is None else 1) inv_freq = super()._compute_inv_freq(base) if self.mixed_b is None: - inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim) + inv_freq = inv_freq / self.scaling_factor ** (2 / self.rotary_dim) else: - a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim / - 2)**self.mixed_b - lambda_1_m = (a * torch.arange( - 1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp() + a = ( + torch.tensor(self.scaling_factor).log() + / (self.rotary_dim / 2) ** self.mixed_b + ) + lambda_1_m = ( + a * torch.arange(1, self.rotary_dim // 2 + 1).float() ** self.mixed_b + ).exp() inv_freq = inv_freq / lambda_1_m return inv_freq diff --git a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py index 9c36d633e2a9..02ad142d676b 100644 --- a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py @@ -44,14 +44,13 @@ def __init__( self.short_factor = short_factor self.long_factor = long_factor - scale = self.max_position_embeddings / \ - self.original_max_position_embeddings + scale = self.max_position_embeddings / self.original_max_position_embeddings if scale <= 1.0: scaling_factor = 1.0 else: scaling_factor = math.sqrt( - 1 + math.log(scale) / - math.log(self.original_max_position_embeddings)) + 1 + math.log(scale) / math.log(self.original_max_position_embeddings) + ) if short_mscale is None: short_mscale = scaling_factor if long_mscale is None: @@ -61,22 +60,32 @@ def __init__( self.long_mscale = long_mscale short_cache = self._compute_cos_sin_cache( - original_max_position_embeddings, short_factor, short_mscale) + original_max_position_embeddings, short_factor, short_mscale + ) short_cache = short_cache.to(dtype) - long_cache = self._compute_cos_sin_cache(max_position_embeddings, - long_factor, long_mscale) + long_cache = self._compute_cos_sin_cache( + max_position_embeddings, long_factor, long_mscale + ) long_cache = long_cache.to(dtype) long_short_cache = torch.cat([short_cache, long_cache], dim=0) - self.register_buffer("long_short_cos_sin_cache", - long_short_cache, - persistent=False) + self.register_buffer( + "long_short_cos_sin_cache", long_short_cache, persistent=False + ) def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor: rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) - inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))) + inv_freq = 1.0 / ( + rescale_factors + * ( + self.base + 
** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) + / self.rotary_dim + ) + ) + ) return inv_freq def _compute_cos_sin_cache( @@ -105,10 +114,14 @@ def forward( key = key.view(*key.shape[:-1], -1, self.head_size) k = self.original_max_position_embeddings - long_prompt_offset = (torch.any(positions > k).float() * - torch.full_like(positions, k)).long() - idx = (torch.add(positions, long_prompt_offset) - if long_prompt_offset is not None else positions) + long_prompt_offset = ( + torch.any(positions > k).float() * torch.full_like(positions, k) + ).long() + idx = ( + torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None + else positions + ) idx = torch.add(idx, offsets) if offsets is not None else idx cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) @@ -116,13 +129,13 @@ def forward( cos = cos.repeat(1, 2).unsqueeze(-2) sin = sin.repeat(1, 2).unsqueeze(-2) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] query_rot = query_rot * cos + rotate_neox(query_rot) * sin query = torch.cat((query_rot, query_pass), dim=-1) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] key_rot = key_rot * cos + rotate_neox(key_rot) * sin key = torch.cat((key_rot, key_pass), dim=-1) diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py index da7c84cb442d..223350d43267 100644 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py @@ -9,8 +9,11 @@ def is_rocm_triton_rotary_embedding_enabled() -> bool: - return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_TRITON_ROPE) + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_TRITON_ROPE + ) def rocm_aiter_rotary_emb_with_key_forward_triton_impl( @@ -23,6 +26,7 @@ def rocm_aiter_rotary_emb_with_key_forward_triton_impl( is_nope_first: bool = False, ) -> None: import aiter.ops.triton.rope as ops + ops.rope_cached_thd_positions_2c_fwd_inplace( query, key, @@ -48,7 +52,6 @@ def rocm_aiter_rotary_emb_with_key_forward_triton_fake( if is_rocm_triton_rotary_embedding_enabled(): - direct_register_custom_op( op_name="rocm_aiter_rotary_emb_with_key_forward_triton", op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl, @@ -58,10 +61,15 @@ def rocm_aiter_rotary_emb_with_key_forward_triton_fake( ) -def rocm_aiter_rotary_emb(positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, cos_sin_cache: torch.Tensor, - head_size: int, rotary_dim: int, - is_neox_style: bool): +def rocm_aiter_rotary_emb( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + cos_sin_cache: torch.Tensor, + head_size: int, + rotary_dim: int, + is_neox_style: bool, +): num_tokens = positions.numel() cos, sin = cos_sin_cache.chunk(2, dim=-1) query_shape = query.shape diff --git a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py index 851565c5667a..93c92e7801e1 100644 --- a/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py @@ -4,8 +4,7 @@ import torch from .base import 
RotaryEmbedding -from .common import (yarn_find_correction_range, yarn_get_mscale, - yarn_linear_ramp_mask) +from .common import yarn_find_correction_range, yarn_get_mscale, yarn_linear_ramp_mask class YaRNScalingRotaryEmbedding(RotaryEmbedding): @@ -36,33 +35,42 @@ def __init__( self.beta_slow = beta_slow # Get n-d magnitude scaling corrected for interpolation self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor) - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base**( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / - self.rotary_dim) + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - low, high = yarn_find_correction_range(self.beta_fast, self.beta_slow, - self.rotary_dim, self.base, - self.max_position_embeddings) + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = (1 - yarn_linear_ramp_mask( - low, high, self.rotary_dim // 2, - dtype=torch.float)) * self.extrapolation_factor - inv_freq = inv_freq_interpolation * ( - 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + inv_freq_mask = ( + 1 + - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) - t = torch.arange(self.max_position_embeddings * self.scaling_factor, - dtype=torch.float32) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, dtype=torch.float32 + ) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = (freqs.cos() * self.mscale) - sin = (freqs.sin() * self.mscale) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale cache = torch.cat((cos, sin), dim=-1) return cache diff --git a/vllm/model_executor/layers/shared_fused_moe/__init__.py b/vllm/model_executor/layers/shared_fused_moe/__init__.py index b87c69d3edd0..b047e9cad04a 100644 --- a/vllm/model_executor/layers/shared_fused_moe/__init__.py +++ b/vllm/model_executor/layers/shared_fused_moe/__init__.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import ( - SharedFusedMoE) +from vllm.model_executor.layers.shared_fused_moe.shared_fused_moe import SharedFusedMoE __all__ = ["SharedFusedMoE"] diff --git a/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py index e1e3d188d985..a8b09a5c3cdb 100644 --- a/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py +++ b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py @@ -40,8 +40,11 @@ def forward( # Reduce outputs if necessary, since the MLP should # have been created with reduce_results=False. 
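The comment above motivates the reduction gate below. A rough standalone illustration of the same gating, using torch.distributed directly rather than vLLM's tensor-parallel helpers (the function name and arguments are assumptions for illustration):

    import torch
    import torch.distributed as dist

    def maybe_reduce_shared_out(shared_out: torch.Tensor,
                                reduce_results: bool, tp_size: int) -> torch.Tensor:
        # The shared-expert MLP was built with reduce_results=False, so its
        # partial output is all-reduced here exactly once, and only when
        # tensor parallelism is actually in use.
        if reduce_results and tp_size > 1 and dist.is_initialized():
            dist.all_reduce(shared_out)
        return shared_out
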
- if (self.reduce_results and self.tp_size > 1 - and self.must_reduce_shared_expert_outputs()): + if ( + self.reduce_results + and self.tp_size > 1 + and self.must_reduce_shared_expert_outputs() + ): shared_out = tensor_model_parallel_all_reduce(shared_out) fused_out = super().forward( diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 96dd58c0e4d2..e522cc450d6b 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utility methods for model layers.""" + from typing import Callable, Optional import torch @@ -24,8 +25,8 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor: # This will be used together with triton swiglu kernel shape = w.shape N = shape[-1] - first = w[..., :N // 2] - second = w[..., N // 2:] + first = w[..., : N // 2] + second = w[..., N // 2 :] stacked = torch.stack((first, second), dim=-1) w_shuffled = stacked.reshape(shape) @@ -39,9 +40,9 @@ def get_token_bin_counts_and_mask( ) -> tuple[torch.Tensor, torch.Tensor]: # Compute the bin counts for the tokens. # vocab_size + 1 for padding. - bin_counts = torch.zeros((num_seqs, vocab_size + 1), - dtype=torch.long, - device=tokens.device) + bin_counts = torch.zeros( + (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device + ) bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) bin_counts = bin_counts[:, :vocab_size] mask = bin_counts > 0 @@ -49,18 +50,21 @@ def get_token_bin_counts_and_mask( return bin_counts, mask -def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, - output_tokens_tensor: torch.Tensor, - presence_penalties: torch.Tensor, - frequency_penalties: torch.Tensor, - repetition_penalties: torch.Tensor) -> torch.Tensor: +def apply_penalties( + logits: torch.Tensor, + prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> torch.Tensor: """ Applies penalties in place to the logits tensor logits : The input logits tensor of shape [num_seqs, vocab_size] - prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts - are padded to the maximum prompt length within the batch using - `vocab_size` as the padding value. The value `vocab_size` is used - for padding because it does not correspond to any valid token ID + prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts + are padded to the maximum prompt length within the batch using + `vocab_size` as the padding value. The value `vocab_size` is used + for padding because it does not correspond to any valid token ID in the vocabulary. output_tokens_tensor: The output tokens tensor. 
presence_penalties: The presence penalties of shape (num_seqs, ) @@ -68,15 +72,17 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, repetition_penalties: The repetition penalties of shape (num_seqs, ) """ num_seqs, vocab_size = logits.shape - _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor, - vocab_size, num_seqs) + _, prompt_mask = get_token_bin_counts_and_mask( + prompt_tokens_tensor, vocab_size, num_seqs + ) output_bin_counts, output_mask = get_token_bin_counts_and_mask( - output_tokens_tensor, vocab_size, num_seqs) + output_tokens_tensor, vocab_size, num_seqs + ) # Apply repetition penalties as a custom op from vllm._custom_ops import apply_repetition_penalties - apply_repetition_penalties(logits, prompt_mask, output_mask, - repetition_penalties) + + apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties) # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details @@ -85,22 +91,27 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, return logits -def default_unquantized_gemm(layer: torch.nn.Module, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None): +def default_unquantized_gemm( + layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, +): return torch.nn.functional.linear(x, weight, bias) def rocm_unquantized_gemm_impl( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None +) -> torch.Tensor: from vllm.platforms.rocm import on_gfx9 + k = weight.shape[1] - use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ - x.dtype in [torch.float16, torch.bfloat16] \ - and k % 8 == 0) + use_skinny = ( + envs.VLLM_ROCM_USE_SKINNY_GEMM + and on_gfx9() + and x.dtype in [torch.float16, torch.bfloat16] + and k % 8 == 0 + ) if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) @@ -120,16 +131,17 @@ def rocm_unquantized_gemm_impl( def rocm_unquantized_gemm_impl_fake( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None +) -> torch.Tensor: return x.new_empty((*x.shape[:-1], weight.shape[0])) -def rocm_unquantized_gemm(layer: torch.nn.Module, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def rocm_unquantized_gemm( + layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias) @@ -141,9 +153,12 @@ def rocm_unquantized_gemm(layer: torch.nn.Module, def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool: - return (torch._C._cpu._is_amx_tile_supported() - and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0 - and n % 16 == 0) + return ( + torch._C._cpu._is_amx_tile_supported() + and (dtype in (torch.bfloat16, torch.int8)) + and k % 32 == 0 + and n % 16 == 0 + ) def dispatch_cpu_unquantized_gemm( @@ -158,31 +173,32 @@ def dispatch_cpu_unquantized_gemm( bias_f32 = layer.bias.to(torch.float32) else: bias_f32 = None - layer.cpu_linear = ( - lambda x, weight, bias: torch.ops._C.weight_packed_linear( - x, packed_weight, bias_f32 - if bias is not None else None, True)) + layer.cpu_linear = lambda x, 
weight, bias: torch.ops._C.weight_packed_linear( + x, packed_weight, bias_f32 if bias is not None else None, True + ) if remove_weight: - layer.weight = torch.nn.Parameter(torch.empty(0), - requires_grad=False) - elif (ops._supports_onednn - and current_platform.get_cpu_architecture() == CpuArchEnum.X86): + layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) + elif ops._supports_onednn and ( + current_platform.get_cpu_architecture() == CpuArchEnum.X86 + or ops.is_onednn_acl_supported() + ): origin_weight = layer.weight if remove_weight: - layer.weight = torch.nn.Parameter(torch.empty(0), - requires_grad=False) + layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) handler = ops.create_onednn_mm(origin_weight.t(), 32) - layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm( - handler, x, bias) + layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias) else: layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear( - x, weight, bias) + x, weight, bias + ) -def cpu_unquantized_gemm(layer: torch.nn.Module, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None): +def cpu_unquantized_gemm( + layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, +): return layer.cpu_linear(x, weight, bias) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index aa64d4e09ae1..b7253c7f0e52 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -9,12 +9,18 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) + QuantizationConfig, + QuantizeMethodBase, + method_has_implemented_embedding, +) from vllm.model_executor.layers.utils import dispatch_unquantized_gemm from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.utils import set_weight_attrs @@ -26,65 +32,73 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): """Unquantized method for embeddings.""" - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: list[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): """Create weights for embedding layer.""" - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) def 
process_weights_after_loading(self, layer: torch.nn.Module) -> None: if current_platform.is_cpu(): - from vllm.model_executor.layers.utils import ( - dispatch_cpu_unquantized_gemm) + from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm + dispatch_cpu_unquantized_gemm(layer, remove_weight=False) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) - def embedding(self, layer: torch.nn.Module, - input_: torch.Tensor) -> torch.Tensor: + def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: return F.embedding(input_, layer.weight) -def pad_vocab_size(vocab_size: int, - pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: +def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // pad_to) * pad_to def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, - rank: int, - offset: int = 0) -> Sequence[int]: + per_partition_vocab_size: int, rank: int, offset: int = 0 +) -> Sequence[int]: index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size return index_f + offset, index_l + offset -def vocab_range_from_global_vocab_size(global_vocab_size: int, - rank: int, - world_size: int, - offset: int = 0) -> Sequence[int]: +def vocab_range_from_global_vocab_size( + global_vocab_size: int, rank: int, world_size: int, offset: int = 0 +) -> Sequence[int]: per_partition_vocab_size = divide(global_vocab_size, world_size) - return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, - rank, - offset=offset) + return vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, offset=offset + ) @dataclass class VocabParallelEmbeddingShardIndices: """Indices for a shard of a vocab parallel embedding.""" + padded_org_vocab_start_index: int padded_org_vocab_end_index: int padded_added_vocab_start_index: int @@ -105,13 +119,11 @@ def num_added_elements(self) -> int: @property def num_org_elements_padded(self) -> int: - return (self.padded_org_vocab_end_index - - self.padded_org_vocab_start_index) + return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index @property def num_added_elements_padded(self) -> int: - return (self.padded_added_vocab_end_index - - self.padded_added_vocab_start_index) + return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index @property def num_org_vocab_padding(self) -> int: @@ -127,17 +139,14 @@ def num_elements_padded(self) -> int: def __post_init__(self): # sanity checks - assert (self.padded_org_vocab_start_index - <= self.padded_org_vocab_end_index) - assert (self.padded_added_vocab_start_index - <= self.padded_added_vocab_end_index) + assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index + assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index assert self.org_vocab_start_index <= self.org_vocab_end_index assert self.added_vocab_start_index <= self.added_vocab_end_index assert self.org_vocab_start_index <= self.padded_org_vocab_start_index - assert (self.added_vocab_start_index - <= self.padded_added_vocab_start_index) + assert self.added_vocab_start_index <= self.padded_added_vocab_start_index assert 
self.org_vocab_end_index <= self.padded_org_vocab_end_index assert self.added_vocab_end_index <= self.padded_added_vocab_end_index @@ -147,20 +156,27 @@ def __post_init__(self): @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def get_masked_input_and_mask( - input_: torch.Tensor, org_vocab_start_index: int, - org_vocab_end_index: int, num_org_vocab_padding: int, - added_vocab_start_index: int, - added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]: + input_: torch.Tensor, + org_vocab_start_index: int, + org_vocab_end_index: int, + num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int, +) -> tuple[torch.Tensor, torch.Tensor]: # torch.compile will fuse all of the pointwise ops below # into a single kernel, making it very fast - org_vocab_mask = (input_ >= org_vocab_start_index) & ( - input_ < org_vocab_end_index) + org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) added_vocab_mask = (input_ >= added_vocab_start_index) & ( - input_ < added_vocab_end_index) - added_offset = added_vocab_start_index - ( - org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding - valid_offset = (org_vocab_start_index * - org_vocab_mask) + (added_offset * added_vocab_mask) + input_ < added_vocab_end_index + ) + added_offset = ( + added_vocab_start_index + - (org_vocab_end_index - org_vocab_start_index) + - num_org_vocab_padding + ) + valid_offset = (org_vocab_start_index * org_vocab_mask) + ( + added_offset * added_vocab_mask + ) vocab_mask = org_vocab_mask | added_vocab_mask input_ = vocab_mask * (input_ - valid_offset) return input_, ~vocab_mask @@ -206,14 +222,16 @@ class VocabParallelEmbedding(CustomOp): prefix: full name of the layer in the state dict """ # noqa: E501 - def __init__(self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() # Keep the input dimensions. @@ -223,18 +241,22 @@ def __init__(self, self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings num_added_embeddings = num_embeddings - self.org_vocab_size - self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, - self.padding_size) + self.org_vocab_size_padded = pad_vocab_size( + self.org_vocab_size, self.padding_size + ) self.num_embeddings_padded = pad_vocab_size( - self.org_vocab_size_padded + num_added_embeddings, - self.padding_size) + self.org_vocab_size_padded + num_added_embeddings, self.padding_size + ) assert self.org_vocab_size_padded <= self.num_embeddings_padded - self.shard_indices = self._get_indices(self.num_embeddings_padded, - self.org_vocab_size_padded, - self.num_embeddings, - self.org_vocab_size, tp_rank, - self.tp_size) + self.shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + tp_rank, + self.tp_size, + ) self.embedding_dim = embedding_dim quant_method = None @@ -248,11 +270,13 @@ def __init__(self, # layer type like ParallelLMHead, this is not important. 
is_embedding_layer = type(self) is VocabParallelEmbedding quant_method_implements_embedding = method_has_implemented_embedding( - type(quant_method)) + type(quant_method) + ) if is_embedding_layer and not quant_method_implements_embedding: raise NotImplementedError( f"The class {type(quant_method).__name__} must implement " - "the 'embedding' method, see UnquantizedEmbeddingMethod.") + "the 'embedding' method, see UnquantizedEmbeddingMethod." + ) self.quant_method: QuantizeMethodBase = quant_method @@ -260,58 +284,73 @@ def __init__(self, params_dtype = torch.get_default_dtype() # Divide the weight matrix along the vocabulary dimension. self.num_added_embeddings = self.num_embeddings - self.org_vocab_size - self.num_embeddings_per_partition = divide(self.num_embeddings_padded, - self.tp_size) - assert (self.shard_indices.num_elements_padded == - self.num_embeddings_per_partition) + self.num_embeddings_per_partition = divide( + self.num_embeddings_padded, self.tp_size + ) + assert ( + self.shard_indices.num_elements_padded == self.num_embeddings_per_partition + ) self.num_org_embeddings_per_partition = ( - self.shard_indices.org_vocab_end_index - - self.shard_indices.org_vocab_start_index) + self.shard_indices.org_vocab_end_index + - self.shard_indices.org_vocab_start_index + ) self.num_added_embeddings_per_partition = ( - self.shard_indices.added_vocab_end_index - - self.shard_indices.added_vocab_start_index) - - self.quant_method.create_weights(self, - self.embedding_dim, - [self.num_embeddings_per_partition], - self.embedding_dim, - self.num_embeddings_padded, - params_dtype=params_dtype, - weight_loader=self.weight_loader) + self.shard_indices.added_vocab_end_index + - self.shard_indices.added_vocab_start_index + ) + + self.quant_method.create_weights( + self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) @classmethod - def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, - vocab_size: int, org_vocab_size: int, tp_rank: int, - tp_size: int) -> VocabParallelEmbeddingShardIndices: + def _get_indices( + cls, + vocab_size_padded: int, + org_vocab_size_padded: int, + vocab_size: int, + org_vocab_size: int, + tp_rank: int, + tp_size: int, + ) -> VocabParallelEmbeddingShardIndices: """Get start and end indices for vocab parallel embedding, following the layout outlined in the class docstring, based on the given tp_rank and tp_size.""" num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded padded_org_vocab_start_index, padded_org_vocab_end_index = ( - vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, - tp_size)) + vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size) + ) padded_added_vocab_start_index, padded_added_vocab_end_index = ( - vocab_range_from_global_vocab_size(num_added_embeddings_padded, - tp_rank, - tp_size, - offset=org_vocab_size)) + vocab_range_from_global_vocab_size( + num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size + ) + ) # remove padding - org_vocab_start_index = min(padded_org_vocab_start_index, - org_vocab_size) + org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size) org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) - added_vocab_start_index = min(padded_added_vocab_start_index, - vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size) added_vocab_end_index = 
min(padded_added_vocab_end_index, vocab_size) return VocabParallelEmbeddingShardIndices( - padded_org_vocab_start_index, padded_org_vocab_end_index, - padded_added_vocab_start_index, padded_added_vocab_end_index, - org_vocab_start_index, org_vocab_end_index, - added_vocab_start_index, added_vocab_end_index) + padded_org_vocab_start_index, + padded_org_vocab_end_index, + padded_added_vocab_start_index, + padded_added_vocab_end_index, + org_vocab_start_index, + org_vocab_end_index, + added_vocab_start_index, + added_vocab_end_index, + ) def get_sharded_to_full_mapping(self) -> Optional[list[int]]: """Get a mapping that can be used to reindex the gathered logits for sampling. - + During sampling, we gather logits from all ranks. The relationship of index->token_id will follow the same format as outlined in the class docstring. However, after the gather, we want to reindex the final @@ -326,32 +365,49 @@ def get_sharded_to_full_mapping(self) -> Optional[list[int]]: added_embeddings: list[int] = [] padding: list[int] = [] for tp_rank in range(self.tp_size): - shard_indices = self._get_indices(self.num_embeddings_padded, - self.org_vocab_size_padded, - self.num_embeddings, - self.org_vocab_size, tp_rank, - self.tp_size) + shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + tp_rank, + self.tp_size, + ) range_start = self.num_embeddings_per_partition * tp_rank range_end = self.num_embeddings_per_partition * (tp_rank + 1) base_embeddings.extend( - range(range_start, - range_start + shard_indices.num_org_elements)) + range(range_start, range_start + shard_indices.num_org_elements) + ) padding.extend( - range(range_start + shard_indices.num_org_elements, - range_start + shard_indices.num_org_elements_padded)) + range( + range_start + shard_indices.num_org_elements, + range_start + shard_indices.num_org_elements_padded, + ) + ) added_embeddings.extend( range( range_start + shard_indices.num_org_elements_padded, - range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements)) + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + ) + ) padding.extend( range( - range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements, - range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements_padded)) - assert (range_start + shard_indices.num_org_elements_padded + - shard_indices.num_added_elements_padded == range_end) + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded, + ) + ) + assert ( + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded + == range_end + ) ret = base_embeddings + added_embeddings + padding assert len(ret) == self.num_embeddings_padded return ret @@ -385,10 +441,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # If param packed on the same dim we are sharding on, then # need to adjust offsets of loaded weight by pack_factor. 
if packed_dim is not None and packed_dim == output_dim: - packed_factor = param.packed_factor if isinstance( - param, BasevLLMParameter) else param.pack_factor - assert loaded_weight.shape[output_dim] == (self.org_vocab_size // - param.packed_factor) + packed_factor = ( + param.packed_factor + if isinstance(param, BasevLLMParameter) + else param.pack_factor + ) + assert loaded_weight.shape[output_dim] == ( + self.org_vocab_size // param.packed_factor + ) start_idx = start_idx // packed_factor shard_size = shard_size // packed_factor else: @@ -396,23 +456,24 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # Copy the data. Select chunk corresponding to current shard. loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - param[:loaded_weight.shape[0]].data.copy_(loaded_weight) - param[loaded_weight.shape[0]:].data.fill_(0) + param[: loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0] :].data.fill_(0) def forward_native(self, input_): if self.tp_size > 1: # Build the mask. masked_input, input_mask = get_masked_input_and_mask( - input_, self.shard_indices.org_vocab_start_index, + input_, + self.shard_indices.org_vocab_start_index, self.shard_indices.org_vocab_end_index, self.shard_indices.num_org_vocab_padding, self.shard_indices.added_vocab_start_index, - self.shard_indices.added_vocab_end_index) + self.shard_indices.added_vocab_end_index, + ) else: masked_input = input_ # Get the embeddings. - output_parallel = self.quant_method.embedding(self, - masked_input.long()) + output_parallel = self.quant_method.embedding(self, masked_input.long()) # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) @@ -427,8 +488,8 @@ def extra_repr(self) -> str: s = f"num_embeddings={self.num_embeddings_per_partition}" s += f", embedding_dim={self.embedding_dim}" s += f", org_vocab_size={self.org_vocab_size}" - s += f', num_embeddings_padded={self.num_embeddings_padded}' - s += f', tp_size={self.tp_size}' + s += f", num_embeddings_padded={self.num_embeddings_padded}" + s += f", tp_size={self.tp_size}" return s @@ -449,27 +510,38 @@ class ParallelLMHead(VocabParallelEmbedding): padding_size: padding size for the vocabulary. 
""" - def __init__(self, - num_embeddings: int, - embedding_dim: int, - bias: bool = False, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings, padding_size, quant_config, - prefix) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__( + num_embeddings, + embedding_dim, + params_dtype, + org_num_embeddings, + padding_size, + quant_config, + prefix, + ) self.quant_config = quant_config if bias: self.bias = Parameter( - torch.empty(self.num_embeddings_per_partition, - dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) + torch.empty(self.num_embeddings_per_partition, dtype=params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) else: self.register_parameter("bias", None) diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 138a2ff30b62..df0d059594a7 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -9,18 +9,20 @@ from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.bitsandbytes_loader import ( - BitsAndBytesModelLoader) +from vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader from vllm.model_executor.model_loader.runai_streamer_loader import ( - RunaiModelStreamerLoader) -from vllm.model_executor.model_loader.sharded_state_loader import ( - ShardedStateLoader) + RunaiModelStreamerLoader, +) +from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader from vllm.model_executor.model_loader.utils import ( - get_architecture_class_name, get_model_architecture, get_model_cls) + get_architecture_class_name, + get_model_architecture, + get_model_cls, +) logger = init_logger(__name__) @@ -69,7 +71,10 @@ def register_model_loader(load_format: str): Examples: >>> from vllm.config.load import LoadConfig - >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader + >>> from vllm.model_executor.model_loader import ( + ... get_model_loader, + ... register_model_loader, + ... 
) >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader >>> >>> @register_model_loader("my_loader") @@ -89,14 +94,20 @@ def _wrapper(model_loader_cls): if load_format in _LOAD_FORMAT_TO_MODEL_LOADER: logger.warning( "Load format `%s` is already registered, and will be " - "overwritten by the new loader class `%s`.", load_format, - model_loader_cls) + "overwritten by the new loader class `%s`.", + load_format, + model_loader_cls, + ) if not issubclass(model_loader_cls, BaseModelLoader): - raise ValueError("The model loader must be a subclass of " - "`BaseModelLoader`.") + raise ValueError( + "The model loader must be a subclass of `BaseModelLoader`." + ) _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls - logger.info("Registered model loader `%s` with load format `%s`", - model_loader_cls, load_format) + logger.info( + "Registered model loader `%s` with load format `%s`", + model_loader_cls, + load_format, + ) return model_loader_cls return _wrapper @@ -110,14 +121,13 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config) -def get_model(*, - vllm_config: VllmConfig, - model_config: Optional[ModelConfig] = None) -> nn.Module: +def get_model( + *, vllm_config: VllmConfig, model_config: Optional[ModelConfig] = None +) -> nn.Module: loader = get_model_loader(vllm_config.load_config) if model_config is None: model_config = vllm_config.model_config - return loader.load_model(vllm_config=vllm_config, - model_config=model_config) + return loader.load_model(vllm_config=vllm_config, model_config=model_config) __all__ = [ diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index ab538a3c9562..6106a1ab8a85 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -9,7 +9,10 @@ from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) + initialize_model, + process_weights_after_loading, + set_default_torch_dtype, +) logger = init_logger(__name__) @@ -26,24 +29,26 @@ def download_model(self, model_config: ModelConfig) -> None: raise NotImplementedError @abstractmethod - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: - """Load weights into a model. This standalone API allows + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: + """Load weights into a model. 
This standalone API allows inplace weights loading for an already-initialized model""" raise NotImplementedError - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig + ) -> nn.Module: """Load a model with the given configurations.""" device_config = vllm_config.device_config load_config = vllm_config.load_config - load_device = device_config.device if load_config.device is None else \ - load_config.device + load_device = ( + device_config.device if load_config.device is None else load_config.device + ) target_device = torch.device(load_device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config, - model_config=model_config) + model = initialize_model( + vllm_config=vllm_config, model_config=model_config + ) logger.debug("Loading weights on %s ...", load_device) # Quantization does not happen in `load_weights` but after it diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 4edf193b54ac..8c1ff0300b24 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -18,38 +18,43 @@ from vllm.config import ModelConfig from vllm.config.load import LoadConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -# yapf: enable +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (LinearBase, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.utils import (ParamMapping, - set_default_torch_dtype) +from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import ( - download_safetensors_index_file_from_hf, download_weights_from_hf, - filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, - pt_weights_iterator, safetensors_weights_iterator) + download_safetensors_index_file_from_hf, + download_weights_from_hf, + filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, + pt_weights_iterator, + safetensors_weights_iterator, +) from vllm.model_executor.models import is_pooling_model -from vllm.model_executor.utils import (get_moe_expert_mapping, - get_packed_modules_mapping, - set_weight_attrs) +from vllm.model_executor.utils import ( + get_moe_expert_mapping, + get_packed_modules_mapping, + set_weight_attrs, +) from vllm.platforms import current_platform -# yapf conflicts with isort for this block - logger = init_logger(__name__) def is_moe_model(model: torch.nn.Module) -> bool: """Checks if the model contains FusedMoE layers.""" - return bool(any( - isinstance(module, FusedMoE) for module in model.modules())) + return bool(any(isinstance(module, FusedMoE) for module in model.modules())) class BitsAndBytesModelLoader(BaseModelLoader): @@ -92,8 +97,7 @@ def _get_weight_files( if is_local: for pattern in 
allowed_patterns: - weight_files = glob.glob( - os.path.join(model_name_or_path, pattern)) + weight_files = glob.glob(os.path.join(model_name_or_path, pattern)) if weight_files: return model_name_or_path, weight_files, pattern else: @@ -109,20 +113,24 @@ def _get_weight_files( revision, ignore_patterns=self.load_config.ignore_patterns, ) - return hf_folder, glob.glob( - os.path.join(hf_folder, pattern)), pattern + return ( + hf_folder, + glob.glob(os.path.join(hf_folder, pattern)), + pattern, + ) - raise RuntimeError( - f"No model weights found in: `{model_name_or_path}`") + raise RuntimeError(f"No model weights found in: `{model_name_or_path}`") - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> tuple[list[str], bool]: + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str] + ) -> tuple[list[str], bool]: """Prepare weight files for the model.""" allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] hf_folder, hf_weights_files, matched_pattern = self._get_weight_files( - model_name_or_path, allowed_patterns, revision) + model_name_or_path, allowed_patterns, revision + ) use_safetensors = matched_pattern == "*.safetensors" is_local = os.path.isdir(model_name_or_path) @@ -141,25 +149,27 @@ def _prepare_weights(self, model_name_or_path: str, revision, ) hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder, index_file) + hf_weights_files, hf_folder, index_file + ) else: - hf_weights_files = filter_files_not_needed_for_inference( - hf_weights_files) + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) if len(hf_weights_files) == 0: raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") + f"Cannot find any model weights with `{model_name_or_path}`" + ) return hf_weights_files, use_safetensors def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): - def _maybe_pool_model(module_name: str): # For pool model, we need to add the prefix `model.` # for the weight name if possible. - if self.is_pool_model and self.target_modules[0]. \ - startswith("model.") and not module_name.startswith( - "model."): + if ( + self.is_pool_model + and self.target_modules[0].startswith("model.") + and not module_name.startswith("model.") + ): return "model." + module_name return module_name @@ -187,8 +197,7 @@ def _get_quantized_weights_iterator( self, model_name_or_path: str, revision: Optional[str], - ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, - Any]]: + ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, as well as the quantization state dictionary.""" @@ -196,37 +205,41 @@ def _get_quantized_weights_iterator( try: import bitsandbytes - if version.parse( - bitsandbytes.__version__) < version.parse("0.46.1"): - raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.46.1.") + if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"): + raise ImportError( + "bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.46.1." + ) except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.46.1 via " - "`pip install bitsandbytes>=0.46.1` to use " - "bitsandbytes quantizer.") from err + raise ImportError( + "Please install bitsandbytes>=0.46.1 via " + "`pip install bitsandbytes>=0.46.1` to use " + "bitsandbytes quantizer." 
+ ) from err hf_weights_files, use_safetensors = self._prepare_weights( - model_name_or_path, revision) + model_name_or_path, revision + ) quant_state_dict: dict[str, Any] = {} if self.pre_quant: if self.load_8bit: return self._quantized_8bit_generator( - hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict + hf_weights_files, use_safetensors, quant_state_dict + ), quant_state_dict else: return self._quantized_4bit_generator( - hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict + hf_weights_files, use_safetensors, quant_state_dict + ), quant_state_dict - return self._unquantized_generator(hf_weights_files, use_safetensors, - quant_state_dict), quant_state_dict + return self._unquantized_generator( + hf_weights_files, use_safetensors, quant_state_dict + ), quant_state_dict def _is_8bit_weight_name(self, weight_name: str): quantized_suffix = {".scb", ".weight_format"} - return any(weight_name.lower().endswith(suffix) - for suffix in quantized_suffix) + return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix) def _is_4bit_weight_name(self, weight_name: str): quantized_suffix = { @@ -239,12 +252,13 @@ def _is_4bit_weight_name(self, weight_name: str): suffix = weight_name.split(".")[-1] return any(q_suffix in suffix for q_suffix in quantized_suffix) - def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: + def _quantized_8bit_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): if not mapped_weight_name.lower().endswith(".scb"): continue @@ -253,9 +267,9 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, quant_state_dict[weight_key] = weight_tensor for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): if self._is_8bit_weight_name(mapped_weight_name): continue @@ -266,18 +280,18 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, else: yield org_weight_name, weight_tensor - def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: + def _quantized_4bit_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: from bitsandbytes.functional import QuantState # First iterate over all quant state weights - weight_iterator = self._hf_weight_iter(hf_weights_files, - use_safetensors) + weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors) temp_state_dict = {} for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in weight_iterator: if not self._is_4bit_weight_name(mapped_weight_name): continue @@ -289,98 +303,111 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, temp_state_dict[mapped_weight_name] = weight_tensor # Closure to parse quant_state for each prequant weight - def _parse_quant_state(param_name: str, - temp_state_dict: dict) -> QuantState: + def _parse_quant_state(param_name: str, temp_state_dict: dict) -> QuantState: quant_state = {} for k in temp_state_dict: if param_name + "." 
in k: quant_state[k] = temp_state_dict[k] - return QuantState.from_dict(quant_state, - device=current_platform.device_type) + return QuantState.from_dict( + quant_state, device=current_platform.device_type + ) # Second iterate over all prequant and normal weights # pre quantized weights would have a quant_state for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): if self._is_4bit_weight_name(mapped_weight_name): continue - if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" - in temp_state_dict) or ( - f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" - in temp_state_dict): - quant_state = _parse_quant_state(mapped_weight_name, - temp_state_dict) + if ( + f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict + ) or ( + f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict + ): + quant_state = _parse_quant_state(mapped_weight_name, temp_state_dict) quant_state_dict[mapped_weight_name] = quant_state yield org_weight_name, weight_tensor else: yield org_weight_name, weight_tensor - def _unquantized_generator(self, hf_weights_files, use_safetensors, - quant_state_dict) -> Generator: + def _unquantized_generator( + self, hf_weights_files, use_safetensors, quant_state_dict + ) -> Generator: from bitsandbytes.functional import quantize_4bit global_tp_size = get_tensor_model_parallel_world_size() global_tp_rank = get_tensor_model_parallel_rank() - check_match = (lambda weight_name, module_name: weight_name. - removesuffix(".weight") == module_name) + check_match = ( + lambda weight_name, module_name: weight_name.removesuffix(".weight") + == module_name + ) for ( - org_weight_name, - mapped_weight_name, - weight_tensor, + org_weight_name, + mapped_weight_name, + weight_tensor, ) in self._hf_weight_iter(hf_weights_files, use_safetensors): - # override tp_size and tp_rank if the module has disabled TP - if any(tp_disabled_module in mapped_weight_name - for tp_disabled_module in self.tp_disabled_modules): + if any( + tp_disabled_module in mapped_weight_name + for tp_disabled_module in self.tp_disabled_modules + ): tp_size = 1 tp_rank = 0 else: tp_size = global_tp_size tp_rank = global_tp_rank - if any(target_module in mapped_weight_name - for target_module in self.target_modules - ) and mapped_weight_name.endswith(".weight"): + if any( + target_module in mapped_weight_name + for target_module in self.target_modules + ) and mapped_weight_name.endswith(".weight"): # Without sharding if any( - check_match(mapped_weight_name, module) - for module in self.unsharded_weights_modules): + check_match(mapped_weight_name, module) + for module in self.unsharded_weights_modules + ): weight_sub_tensor = weight_tensor # Shard by column elif any( - check_match(mapped_weight_name, module) - for module in self.column_sharded_weights_modules): + check_match(mapped_weight_name, module) + for module in self.column_sharded_weights_modules + ): total_size = weight_tensor.size(-1) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) - weight_sub_tensor = weight_tensor[..., - start_index:end_index] + weight_sub_tensor = weight_tensor[..., start_index:end_index] # Weights have fused on disk. In this case, we assume that the # weight and module use same name. 
elif any( - check_match(mapped_weight_name, module) - for module in self.maybe_fused_weights_modules): + check_match(mapped_weight_name, module) + for module in self.maybe_fused_weights_modules + ): # special case for fused weights # get the size of each shard weight tensor total_shard_sizes = next( - (sizes for module, sizes in - self.maybe_fused_weights_modules.items() - if check_match(mapped_weight_name, module))) + ( + sizes + for module, sizes in self.maybe_fused_weights_modules.items() # noqa: E501 + if check_match(mapped_weight_name, module) + ) + ) total_size = weight_tensor.size(0) assert total_size == sum(total_shard_sizes) # get the start/end index of each shard weight tensor total_start_index = list( - itertools.accumulate([0] + total_shard_sizes))[:-1] - shard_weights_index = [( - idx + size // tp_size * tp_rank, - idx + size // tp_size * (tp_rank + 1), - ) for idx, size in zip(total_start_index, - total_shard_sizes)] + itertools.accumulate([0] + total_shard_sizes) + )[:-1] + shard_weights_index = [ + ( + idx + size // tp_size * tp_rank, + idx + size // tp_size * (tp_rank + 1), + ) + for idx, size in zip(total_start_index, total_shard_sizes) + ] # slice and reorder the weight tensor weight_tensor = [ weight_tensor[start_index:end_index, ...] @@ -392,15 +419,15 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, total_size = weight_tensor.size(0) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) - weight_sub_tensor = weight_tensor[start_index:end_index, - ...] + weight_sub_tensor = weight_tensor[start_index:end_index, ...] # bitsandbytes requires data in GPU if weight_sub_tensor.is_cuda: loaded_weight = weight_sub_tensor else: loaded_weight = weight_sub_tensor.to( - device=current_platform.device_type) + device=current_platform.device_type + ) # remove the following after the issue is fixed: # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 @@ -421,12 +448,13 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, def _get_bnb_target_modules(self, model: nn.Module) -> None: """ - Identify and collect all modules that support BitsAndBytes + Identify and collect all modules that support BitsAndBytes quantization. """ for name, module in model.named_modules(): - if (isinstance(module, LinearBase) - and hasattr(module.quant_method, "quant_config")): + if isinstance(module, LinearBase) and hasattr( + module.quant_method, "quant_config" + ): if modules_info := self.modules_mapping.get_sub_modules(name): # Map vllm's names to transformers's names. rep_name, sub_modules = modules_info @@ -442,45 +470,48 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None: if module.disable_tp: self.tp_disabled_modules.append(name) elif isinstance(module, FusedMoE) and hasattr( - module.quant_method, "quant_config"): + module.quant_method, "quant_config" + ): # TODO: support FusedMoE with prequant and 8bit. if self.pre_quant and self.load_8bit: raise ValueError( "Prequant BitsAndBytes 8bit models with FusedMoE " - "is not supported yet.") + "is not supported yet." + ) # Get the corresponding weight name using module name and # expert_params_mapping. for exp in self.expert_params_mapping: weight_name = exp[1] - rep_name = name.replace("experts", - "") + weight_name.removesuffix(".") + rep_name = name.replace("experts", "") + weight_name.removesuffix( + "." 
+ ) self.target_modules.append(rep_name) - assert (self.target_modules - ), "vLLM currently does not support BNB quantization for" + assert self.target_modules, ( + "vLLM currently does not support BNB quantization for" + ) f" {type(model).__name__}" def _classify_module_sharding(self, model: nn.Module): """ - Categorize modules based on their weight sharding requirements + Categorize modules based on their weight sharding requirements for tensor parallelism. """ for name, module in model.named_modules(): # Some modules like `ReplicatedLinear` should not have their weights # sharded. The reason for implementing it this way is to avoid new # static variable in the model implementation. - if isinstance(module, (ReplicatedLinear, )): + if isinstance(module, (ReplicatedLinear,)): self.unsharded_weights_modules.append(name) # `QKVParallelLinear` and `MergedColumnParallelLinear` might have # fused weights on disk. We need to use the output sizes of these # modules to shard the weights correctly. - elif isinstance(module, - (QKVParallelLinear, MergedColumnParallelLinear)): + elif isinstance(module, (QKVParallelLinear, MergedColumnParallelLinear)): self.maybe_fused_weights_modules[name] = module.output_sizes # In TP, these weights are partitioned along the column # dimension (dim=-1) - elif isinstance(module, (RowParallelLinear, )): + elif isinstance(module, (RowParallelLinear,)): self.column_sharded_weights_modules.append(name) elif isinstance(module, FusedMoE): expert_mapping = self.expert_params_mapping @@ -488,48 +519,53 @@ def _classify_module_sharding(self, model: nn.Module): if exp[-1] == "w2": weight_name = exp[1] rep_name = name.replace( - "experts", "") + weight_name.removesuffix(".") + "experts", "" + ) + weight_name.removesuffix(".") self.column_sharded_weights_modules.append(rep_name) - def _verify_model_compatibility(self, model: nn.Module, - model_config: ModelConfig) -> None: + def _verify_model_compatibility( + self, model: nn.Module, model_config: ModelConfig + ) -> None: """ Verify that the model is compatible with BitsAndBytes quantization. """ if not hasattr(model, "load_weights"): raise AttributeError( "The required method 'load_weights' is not defined in class" - f" {type(model).__name__}.") + f" {type(model).__name__}." + ) if not hasattr(model, "packed_modules_mapping"): raise AttributeError( f"Model {type(model).__name__} does not support BitsAndBytes " - "quantization yet. No 'packed_modules_mapping' found.") + "quantization yet. No 'packed_modules_mapping' found." + ) - quant_config = getattr(model_config.hf_config, "quantization_config", - None) + quant_config = getattr(model_config.hf_config, "quantization_config", None) if quant_config is not None: quant_method = quant_config.get("quant_method") if quant_method == "bitsandbytes": self.pre_quant = True else: raise ValueError( - f"BitsAndBytes loader does not support {quant_method} " - "quantization") + f"BitsAndBytes loader does not support {quant_method} quantization" + ) # The quant_states in pre_quantized models cannot work with a split # weight tensor. So TP does not work with pre_quantized bnb models. if self.pre_quant and get_tensor_model_parallel_world_size() > 1: raise ValueError( "Prequant BitsAndBytes models with tensor parallelism is not " - "supported. Please try with pipeline parallelism.") + "supported. Please try with pipeline parallelism." 
+ ) if self.pre_quant: self.load_8bit = quant_config.get("load_in_8bit", False) - def _initialize_loader_state(self, model: nn.Module, - model_config: ModelConfig) -> None: + def _initialize_loader_state( + self, model: nn.Module, model_config: ModelConfig + ) -> None: """ - Initialize the loader's internal state based on the model and + Initialize the loader's internal state based on the model and configuration. """ self.is_pool_model = is_pooling_model(model) @@ -541,7 +577,8 @@ def _initialize_loader_state(self, model: nn.Module, raise AttributeError( f"MoE Model {type(model).__name__} does not support " "BitsAndBytes quantization yet. Ensure this model has " - "'get_expert_mapping' method.") + "'get_expert_mapping' method." + ) # For some models like Molmo, we need to use hf_to_vllm_mapper # to ensure correct loading of weights. if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None): @@ -552,22 +589,20 @@ def _initialize_loader_state(self, model: nn.Module, def _dequantize_dq(self, quant_states: Any): """ - When BNB employs Double Quantization, we perform the dequantization of - these constants during weight loading rather than at inference time, - thereby avoiding this computational overhead during inference. This + When BNB employs Double Quantization, we perform the dequantization of + these constants during weight loading rather than at inference time, + thereby avoiding this computational overhead during inference. This comes at the cost of increased memory usage. """ from bitsandbytes.functional import QuantState, dequantize_blockwise def _dequantize_single_state(quant_state): """Helper function to dequantize a single QuantState object.""" - if not (isinstance(quant_state, QuantState) - and quant_state.nested): + if not (isinstance(quant_state, QuantState) and quant_state.nested): return # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356 - absmax = dequantize_blockwise(quant_state.absmax, - quant_state.state2) + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset # Ensure float32 dtype @@ -586,10 +621,9 @@ def _dequantize_single_state(quant_state): _dequantize_single_state(quant_states) return quant_states - def _fuse_moe_quant_states(self, model: nn.Module, - quant_states_dict: dict) -> dict: + def _fuse_moe_quant_states(self, model: nn.Module, quant_states_dict: dict) -> dict: """ - + This function consolidates individual expert quantization states into fused representations for w13 and w2. """ @@ -609,12 +643,12 @@ def _fuse_moe_quant_states(self, model: nn.Module, for exp in expert_mapping: shard_id = exp[-1] if shard_id not in ("w1", "w2", "w3"): - raise ValueError(f"shard_id must be ['w1','w2','w3'] but " - f"got {shard_id}.") + raise ValueError( + f"shard_id must be ['w1','w2','w3'] but got {shard_id}." 
+ ) layer_prefix = name.split("experts")[0] weight_qual_name = layer_prefix + exp[1] + "weight" - quant_state = self._dequantize_dq( - quant_states_dict[weight_qual_name]) + quant_state = self._dequantize_dq(quant_states_dict[weight_qual_name]) if shard_id == "w1": w1_states_lst.append(quant_state) elif shard_id == "w2": @@ -622,14 +656,12 @@ def _fuse_moe_quant_states(self, model: nn.Module, else: w3_states_lst.append(quant_state) del quant_states_dict[weight_qual_name] - assert (len(w1_states_lst) == len(w2_states_lst) == - len(w3_states_lst)) + assert len(w1_states_lst) == len(w2_states_lst) == len(w3_states_lst) w13_absmax_lst = [] w2_absmax_lst = [] w13_total_dim0 = 0 w2_total_dim0 = 0 - for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, - w3_states_lst): + for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, w3_states_lst): assert w1_qs.shape == w3_qs.shape assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype @@ -669,12 +701,13 @@ def _fuse_moe_quant_states(self, model: nn.Module, return expert_qs_dict def _stack_quantization_states( - self, model: nn.Module, - quant_state_dict: dict) -> dict[str, dict[int, Any]]: + self, model: nn.Module, quant_state_dict: dict + ) -> dict[str, dict[int, Any]]: stacked_quant_state_dict: dict[str, dict[int, Any]] = {} # TODO: Change this lazy import to normal import # after the checks are updated to run on a new version from vllm.model_executor.models.utils import is_pp_missing_parameter + param_dict = dict(model.named_parameters()) for quant_param_name in quant_state_dict: if is_pp_missing_parameter(quant_param_name, model): @@ -684,23 +717,23 @@ def _stack_quantization_states( shard_index = 0 for shard_name, ( - weight_name, - index, + weight_name, + index, ) in self.modules_mapping.inverse_packed_mapping.items(): # Some models, such as MiniCPM V2.5/2.6, contain both # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' # from being incorrectly identified as being present in # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight shard_pos = quant_param_name.find(shard_name) - can_correct_rename = (shard_pos - > 0) and (quant_param_name[shard_pos - 1] - == ".") + can_correct_rename = (shard_pos > 0) and ( + quant_param_name[shard_pos - 1] == "." + ) # If the quant_param_name is packed, it won't occur in the # param_dict before renaming. 
- new_quant_param_name = quant_param_name.replace( - shard_name, weight_name) - need_rename = (quant_param_name not in param_dict) \ - and (new_quant_param_name in param_dict) + new_quant_param_name = quant_param_name.replace(shard_name, weight_name) + need_rename = (quant_param_name not in param_dict) and ( + new_quant_param_name in param_dict + ) if can_correct_rename and need_rename: shard_index = index quant_param_name = new_quant_param_name @@ -714,12 +747,14 @@ def _stack_quantization_states( if quant_param_name not in stacked_quant_state_dict: stacked_quant_state_dict[quant_param_name] = {} - stacked_quant_state_dict[quant_param_name][shard_index] = ( - quant_state_dict[non_stacked_param_name]) + stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[ + non_stacked_param_name + ] return stacked_quant_state_dict - def _bind_quant_states_to_params(self, model: nn.Module, - stacked_quant_state_dict: dict) -> None: + def _bind_quant_states_to_params( + self, model: nn.Module, stacked_quant_state_dict: dict + ) -> None: # save quant_states and offsets as the attributes of the parameters param_dict = dict(model.named_parameters()) for param_name, param in param_dict.items(): @@ -733,13 +768,11 @@ def _bind_quant_states_to_params(self, model: nn.Module, pack_ratio = getattr(param, "pack_factor", -1) if pack_ratio == -1: - raise ValueError( - f"pack_factor not set for parameter {param_name}.") + raise ValueError(f"pack_factor not set for parameter {param_name}.") num_elements = [0] * len(quant_states) for seq, quant_state in quant_states.items(): - num_elements[seq] = (math.prod(quant_state.shape) // - pack_ratio) + num_elements[seq] = math.prod(quant_state.shape) // pack_ratio offsets = np.concatenate(([0], np.cumsum(num_elements))) # Make torch infer_schema happy @@ -748,38 +781,39 @@ def _bind_quant_states_to_params(self, model: nn.Module, if self.load_8bit: set_weight_attrs( - param, {"matmul_state": [None] * len(quant_states)}) - - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + param, {"matmul_state": [None] * len(quant_states)} + ) + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: self._verify_model_compatibility(model, model_config) self._initialize_loader_state(model, model_config) - logger.info("Loading weights with BitsAndBytes quantization. " - "May take a while ...") - qweight_iterator, quant_state_dict = ( - self._get_quantized_weights_iterator( - model_config.model, - model_config.revision, - )) + logger.info( + "Loading weights with BitsAndBytes quantization. May take a while ..." + ) + qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator( + model_config.model, + model_config.revision, + ) weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights(qweight_iterator) # Some models may have weights loading tracker unimplemented. 
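To make the offset bookkeeping in _bind_quant_states_to_params concrete, here is a hedged worked example; the shard shapes and pack factor are made up, and only the num_elements/offsets arithmetic shown above is reproduced.

import math

import numpy as np

shard_shapes = {0: (4096, 4096), 1: (1024, 4096), 2: (1024, 4096)}  # illustrative q/k/v shards
pack_ratio = 2  # assumed pack_factor: two packed values per stored element

num_elements = [0] * len(shard_shapes)
for seq, shape in shard_shapes.items():
    num_elements[seq] = math.prod(shape) // pack_ratio

offsets = np.concatenate(([0], np.cumsum(num_elements)))
# offsets -> 0, 8388608, 10485760, 12582912; shard i spans offsets[i]:offsets[i + 1]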
if loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: - raise ValueError("Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") - expert_quant_state_dict = self._fuse_moe_quant_states( - model, quant_state_dict) + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}" + ) + expert_quant_state_dict = self._fuse_moe_quant_states(model, quant_state_dict) stacked_quant_state_dict = self._stack_quantization_states( - model, quant_state_dict) + model, quant_state_dict + ) stacked_quant_state_dict = { **expert_quant_state_dict, - **stacked_quant_state_dict + **stacked_quant_state_dict, } self._bind_quant_states_to_params(model, stacked_quant_state_dict) torch.cuda.empty_cache() diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 8e2db9292ff8..206b8244569f 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -16,12 +16,18 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( - download_safetensors_index_file_from_hf, download_weights_from_hf, - fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, - filter_files_not_needed_for_inference, maybe_download_from_modelscope, + download_safetensors_index_file_from_hf, + download_weights_from_hf, + fastsafetensors_weights_iterator, + filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, + maybe_download_from_modelscope, multi_thread_pt_weights_iterator, - multi_thread_safetensors_weights_iterator, np_cache_weights_iterator, - pt_weights_iterator, safetensors_weights_iterator) + multi_thread_safetensors_weights_iterator, + np_cache_weights_iterator, + pt_weights_iterator, + safetensors_weights_iterator, +) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -63,9 +69,11 @@ def __init__(self, load_config: LoadConfig): unexpected_keys = set(extra_config.keys()) - allowed_keys if unexpected_keys: - raise ValueError(f"Unexpected extra config keys for load format " - f"{load_config.load_format}: " - f"{unexpected_keys}") + raise ValueError( + f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{unexpected_keys}" + ) def _prepare_weights( self, @@ -77,8 +85,10 @@ def _prepare_weights( """Prepare weights for the model. If the model is not local, it will be downloaded.""" - model_name_or_path = (maybe_download_from_modelscope( - model_name_or_path, revision) or model_name_or_path) + model_name_or_path = ( + maybe_download_from_modelscope(model_name_or_path, revision) + or model_name_or_path + ) is_local = os.path.isdir(model_name_or_path) load_format = self.load_config.load_format @@ -87,8 +97,7 @@ def _prepare_weights( # Some quantized models use .pt files for storing the weights. 
if load_format == "auto": allow_patterns = ["*.safetensors", "*.bin"] - elif (load_format == "safetensors" - or load_format == "fastsafetensors"): + elif load_format == "safetensors" or load_format == "fastsafetensors": use_safetensors = True allow_patterns = ["*.safetensors"] elif load_format == "mistral": @@ -141,25 +150,29 @@ def _prepare_weights( revision, ) hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder, index_file) + hf_weights_files, hf_folder, index_file + ) else: - hf_weights_files = filter_files_not_needed_for_inference( - hf_weights_files) + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) if len(hf_weights_files) == 0: raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") + f"Cannot find any model weights with `{model_name_or_path}`" + ) return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, source: "Source" + self, source: "Source" ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - source.model_or_path, source.revision, source.fall_back_to_pt, - source.allow_patterns_overrides) + source.model_or_path, + source.revision, + source.fall_back_to_pt, + source.allow_patterns_overrides, + ) if self.load_config.load_format == "npcache": # Currently np_cache only support *.bin checkpoints assert use_safetensors is False @@ -178,13 +191,13 @@ def _get_weights_iterator( ) else: if extra_config.get("enable_multithread_load"): - weights_iterator = ( - multi_thread_safetensors_weights_iterator( - hf_weights_files, - self.load_config.use_tqdm_on_load, - max_workers=extra_config.get( - "num_threads", self.DEFAULT_NUM_THREADS), - )) + weights_iterator = multi_thread_safetensors_weights_iterator( + hf_weights_files, + self.load_config.use_tqdm_on_load, + max_workers=extra_config.get( + "num_threads", self.DEFAULT_NUM_THREADS + ), + ) else: weights_iterator = safetensors_weights_iterator( hf_weights_files, @@ -197,8 +210,9 @@ def _get_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, self.load_config.pt_load_map_location, - max_workers=extra_config.get("num_threads", - self.DEFAULT_NUM_THREADS), + max_workers=extra_config.get( + "num_threads", self.DEFAULT_NUM_THREADS + ), ) else: weights_iterator = pt_weights_iterator( @@ -226,8 +240,7 @@ def _xla_weights_iterator(iterator: Generator): if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() # Apply the prefix. 
- return ((source.prefix + name, tensor) - for (name, tensor) in weights_iterator) + return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) def get_all_weights( self, @@ -238,10 +251,8 @@ def get_all_weights( model_config.model, model_config.revision, prefix="", - fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", - True), - allow_patterns_overrides=getattr(model, "allow_patterns_overrides", - None), + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None), ) yield from self._get_weights_iterator(primary_weights) @@ -253,13 +264,14 @@ def get_all_weights( yield from self._get_weights_iterator(source) def download_model(self, model_config: ModelConfig) -> None: - self._prepare_weights(model_config.model, - model_config.revision, - fall_back_to_pt=True, - allow_patterns_overrides=None) + self._prepare_weights( + model_config.model, + model_config.revision, + fall_back_to_pt=True, + allow_patterns_overrides=None, + ) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: weights_to_load = {name for name, _ in model.named_parameters()} # if we don't have `model.weight_metadata_and_attr_saved` defined and @@ -267,38 +279,43 @@ def load_weights(self, model: nn.Module, # or the first run of online quantization # see online_quantization.py for detailed notes offline_quantization_or_first_run_of_online_quantization = not getattr( - model, "weight_metadata_and_attr_saved", False) + model, "weight_metadata_and_attr_saved", False + ) if model_config.quantization is None: # model is not quantized loaded_weights = model.load_weights( - self.get_all_weights(model_config, model)) + self.get_all_weights(model_config, model) + ) elif offline_quantization_or_first_run_of_online_quantization: # case 1: offline quantized checkpoint # case 2: Step I1 first run of weight loading with # online quantization # see online_quantization.py for detailed notes loaded_weights = model.load_weights( - self.get_all_weights(model_config, model)) + self.get_all_weights(model_config, model) + ) else: # to avoid circular dependency from vllm.model_executor.model_loader.online_quantization import ( - load_weights_and_online_quantize) + load_weights_and_online_quantize, + ) # subsequent runs of weight loading with online # quantization - loaded_weights = load_weights_and_online_quantize( - self, model, model_config) + loaded_weights = load_weights_and_online_quantize(self, model, model_config) self.counter_after_loading_weights = time.perf_counter() logger.info( "Loading weights took %.2f seconds", - self.counter_after_loading_weights - - self.counter_before_loading_weights) + self.counter_after_loading_weights - self.counter_before_loading_weights, + ) # We only enable strict check for non-quantized models # that have loaded weights tracking currently. 
if model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: - raise ValueError("Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}" + ) diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py index 5b8c6268f64e..b2a934ce5949 100644 --- a/vllm/model_executor/model_loader/dummy_loader.py +++ b/vllm/model_executor/model_loader/dummy_loader.py @@ -5,8 +5,7 @@ from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.weight_utils import ( - initialize_dummy_weights) +from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights class DummyModelLoader(BaseModelLoader): @@ -15,14 +14,15 @@ class DummyModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index aaee8f3f7635..93dc754a571c 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -13,10 +13,15 @@ from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) + initialize_model, + process_weights_after_loading, + set_default_torch_dtype, +) from vllm.model_executor.model_loader.weight_utils import ( - get_gguf_extra_tensor_names, get_gguf_weight_type_map, - gguf_quant_weights_iterator) + get_gguf_extra_tensor_names, + get_gguf_weight_type_map, + gguf_quant_weights_iterator, +) class GGUFModelLoader(BaseModelLoader): @@ -29,15 +34,18 @@ class GGUFModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) def _prepare_weights(self, model_name_or_path: str): if os.path.isfile(model_name_or_path): return model_name_or_path # for raw HTTPS link if model_name_or_path.startswith( - ("http://", "https://")) and model_name_or_path.endswith(".gguf"): + ("http://", "https://") + ) and model_name_or_path.endswith(".gguf"): return hf_hub_download(url=model_name_or_path) # repo id/filename.gguf if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"): @@ -46,7 
+54,8 @@ def _prepare_weights(self, model_name_or_path: str): else: raise ValueError( f"Unrecognised GGUF reference: {model_name_or_path} " - "(expected local file, raw URL, or <repo_id>/<filename>.gguf)") + "(expected local file, raw URL, or <repo_id>/<filename>.gguf)" + ) def _get_gguf_weights_map(self, model_config: ModelConfig): """ @@ -68,25 +77,32 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): # GGUF layer map assumes that we will have a merged expert weights # so we need to map them manually for idx in range(config.num_hidden_layers): - gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ - f"model.layers.{idx}.mlp.gate.e_score_correction_bias" - gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.down_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = ( + f"model.layers.{idx}.mlp.gate.e_score_correction_bias" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + ) if model_type in ("qwen2_moe", "qwen3_moe"): model_type = model_type.replace("_", "") # GGUF layer map assumes that we will have a merged expert weights # so we need to map them manually for idx in range(config.num_hidden_layers): - gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.down_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + ) arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): @@ -99,7 +115,8 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): name_map = gguf.get_tensor_name_map(arch, num_layers) with torch.device("meta"): dummy_model = AutoModelForCausalLM.from_config( - config, trust_remote_code=model_config.trust_remote_code) + config, trust_remote_code=model_config.trust_remote_code + ) state_dict = dummy_model.state_dict() for hf_name in state_dict: @@ -111,31 +128,31 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): def _get_weights_iterator( self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str] ) -> Generator[tuple[str, torch.Tensor], None, None]: - return gguf_quant_weights_iterator(model_name_or_path, - gguf_to_hf_name_map) + return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: local_model_path = 
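A hedged worked example of the manual GGUF-to-HF expert-weight mapping built above, for a hypothetical two-layer MoE config; because vLLM fuses expert weights, every per-layer GGUF expert tensor maps to expert 0 of the HF name.

num_hidden_layers = 2
gguf_to_hf_name_map: dict[str, str] = {}
for idx in range(num_hidden_layers):
    gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
        f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
    )
    gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
        f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
    )
    gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
        f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
    )

assert gguf_to_hf_name_map["blk.1.ffn_up_exps.weight"] == (
    "model.layers.1.mlp.experts.0.up_proj.weight"
)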
self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) model.load_weights( - self._get_weights_iterator(local_model_path, gguf_weights_map)) + self._get_weights_iterator(local_model_path, gguf_weights_map) + ) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig + ) -> nn.Module: device_config = vllm_config.device_config local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) # we can only know if tie word embeddings after mapping weights if "lm_head.weight" in get_gguf_extra_tensor_names( - local_model_path, gguf_weights_map): + local_model_path, gguf_weights_map + ): model_config.hf_config.update({"tie_word_embeddings": True}) - weight_type_map = get_gguf_weight_type_map(model_config.model, - gguf_weights_map) + weight_type_map = get_gguf_weight_type_map(model_config.model, gguf_weights_map) # filter out unquantized modules to skip unquant_names = [ diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py index beec2d20ad69..890dd7231a0e 100644 --- a/vllm/model_executor/model_loader/online_quantization.py +++ b/vllm/model_executor/model_loader/online_quantization.py @@ -9,8 +9,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.default_loader import DefaultModelLoader -from vllm.model_executor.model_loader.utils import ( - process_weights_after_loading) +from vllm.model_executor.model_loader.utils import process_weights_after_loading logger = init_logger(__name__) @@ -63,7 +62,8 @@ def maybe_save_metadata_and_attributes_for_weight_reloading( - model: nn.Module, model_config: ModelConfig): + model: nn.Module, model_config: ModelConfig +): # following is to support on the fly quantization, currently only supported # for torchao if model_config.quantization != "torchao": @@ -73,10 +73,12 @@ def maybe_save_metadata_and_attributes_for_weight_reloading( # In case `process_weights_after_loading` is called multiple times # we'll skip it at later times logger.warning( - "process_weights_after_loading already called for model %s", model) + "process_weights_after_loading already called for model %s", model + ) return from vllm.model_executor.model_loader.weight_utils import get_quant_config + quant_config = get_quant_config(model_config, None) # If checkpoint is already torchao serialized, this means it's @@ -86,8 +88,10 @@ def maybe_save_metadata_and_attributes_for_weight_reloading( # This step record the weights metadata and weight attributes so we can # restore the bfloat16 model weights during the relad step (R1 and R2) # see Notes in online_quantization.py for more details - if not (hasattr(quant_config, "is_checkpoint_torchao_serialized") and \ - not quant_config.is_checkpoint_torchao_serialized): + if not ( + hasattr(quant_config, "is_checkpoint_torchao_serialized") + and not quant_config.is_checkpoint_torchao_serialized + ): return # This is the I2 step of online quantiztion that saves @@ -144,23 +148,23 @@ def _bond_method_to_cls(func, obj): return types.MethodType(func, obj) -def load_weights_and_online_quantize(model_loader: DefaultModelLoader, - model: nn.Module, - model_config: ModelConfig) -> set[str]: +def load_weights_and_online_quantize( + model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig +) -> 
set[str]: # online quantization, right now only enabled for # torchao # R1, R2, R3, R4 in the Notes # TODO: Add fp8 support - assert model_config.quantization == "torchao", "online " \ - "quantization is only enabled for torchao currently" + assert model_config.quantization == "torchao", ( + "online quantization is only enabled for torchao currently" + ) # TODO: use create_weights to restore the weights to original state # Step R1: First restore the quantized weights to original bfloat16 # weights, with original metadata (shape, dtype, device) # and attributes, so that bfloat16 weights can be loaded properly - existing_param_names = dict( - model.named_parameters(remove_duplicate=False)).keys() + existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys() named_modules = dict(model.named_modules(remove_duplicate=False)) model_device = None @@ -170,9 +174,11 @@ def load_weights_and_online_quantize(model_loader: DefaultModelLoader, _dtype = d["dtype"] _device = d["device"] if model_device is not None: - assert model_device == _device, "Expecting all weights " \ - "to be in the same device for now, got both: " \ + assert model_device == _device, ( + "Expecting all weights " + "to be in the same device for now, got both: " f"{model_device} and {_device}" + ) else: model_device = _device @@ -180,9 +186,10 @@ def load_weights_and_online_quantize(model_loader: DefaultModelLoader, module_name, weight_name = name.rsplit(".", 1) module = named_modules[module_name] setattr( - module, weight_name, - torch.nn.Parameter( - torch.empty(_shape, dtype=_dtype, device=_device))) + module, + weight_name, + torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)), + ) # recorded_weight_attr is # {"weight_name": {"weight_attr_key": attr}} @@ -196,8 +203,7 @@ def load_weights_and_online_quantize(model_loader: DefaultModelLoader, # "layer.1.weight": ..., # } # } - for full_weight_name, weight_attr_dict in \ - model.recorded_weight_attr.items(): + for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items(): for attr_name, attr in weight_attr_dict.items(): module_name, weight_name = full_weight_name.rsplit(".", 1) module = named_modules[module_name] @@ -207,7 +213,8 @@ def load_weights_and_online_quantize(model_loader: DefaultModelLoader, # Step I1: reload bfloat16 / high precision weights loaded_weights = model.load_weights( - model_loader.get_all_weights(model_config, model)) + model_loader.get_all_weights(model_config, model) + ) # Step I2: online quantize the weights # manually process weights after loading diff --git a/vllm/model_executor/model_loader/runai_streamer_loader.py b/vllm/model_executor/model_loader/runai_streamer_loader.py index dc941401a04e..50a92edd1162 100644 --- a/vllm/model_executor/model_loader/runai_streamer_loader.py +++ b/vllm/model_executor/model_loader/runai_streamer_loader.py @@ -13,16 +13,17 @@ from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( - download_safetensors_index_file_from_hf, download_weights_from_hf, - runai_safetensors_weights_iterator) -from vllm.transformers_utils.runai_utils import (is_runai_obj_uri, - list_safetensors) + download_safetensors_index_file_from_hf, + download_weights_from_hf, + runai_safetensors_weights_iterator, +) +from vllm.transformers_utils.runai_utils import is_runai_obj_uri, list_safetensors class RunaiModelStreamerLoader(BaseModelLoader): """ - Model loader that can load 
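The R1 step above rebuilds empty high-precision parameters from the recorded (shape, dtype, device) metadata before the checkpoint is reloaded. A minimal sketch of that restore loop, under the assumption that the metadata dict is keyed by parameter name with "shape"/"dtype"/"device" entries; the toy model is ours.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8, bias=False))
recorded_metadata = {
    "0.weight": {"shape": (8, 8), "dtype": torch.bfloat16, "device": "cpu"},
}

named_modules = dict(model.named_modules(remove_duplicate=False))
for name, meta in recorded_metadata.items():
    module_name, weight_name = name.rsplit(".", 1)
    module = named_modules[module_name]
    # Re-create an empty parameter so bfloat16 weights can be loaded over it.
    setattr(
        module,
        weight_name,
        nn.Parameter(
            torch.empty(meta["shape"], dtype=meta["dtype"], device=meta["device"])
        ),
    )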
safetensors - files from local FS or S3 bucket. + Model loader that can load safetensors + files from local FS or S3 bucket. """ def __init__(self, load_config: LoadConfig): @@ -30,25 +31,28 @@ def __init__(self, load_config: LoadConfig): if load_config.model_loader_extra_config: extra_config = load_config.model_loader_extra_config - if ("concurrency" in extra_config - and isinstance(extra_config.get("concurrency"), int)): + if "concurrency" in extra_config and isinstance( + extra_config.get("concurrency"), int + ): os.environ["RUNAI_STREAMER_CONCURRENCY"] = str( - extra_config.get("concurrency")) + extra_config.get("concurrency") + ) - if ("memory_limit" in extra_config - and isinstance(extra_config.get("memory_limit"), int)): + if "memory_limit" in extra_config and isinstance( + extra_config.get("memory_limit"), int + ): os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str( - extra_config.get("memory_limit")) + extra_config.get("memory_limit") + ) - runai_streamer_s3_endpoint = os.getenv( - 'RUNAI_STREAMER_S3_ENDPOINT') - aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL') - if (runai_streamer_s3_endpoint is None - and aws_endpoint_url is not None): + runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT") + aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL") + if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None: os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]) -> list[str]: + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str] + ) -> list[str]: """Prepare weights for the model. If the model is not local, it will be downloaded.""" @@ -58,31 +62,34 @@ def _prepare_weights(self, model_name_or_path: str, safetensors_pattern = "*.safetensors" index_file = SAFE_WEIGHTS_INDEX_NAME - hf_folder = (model_name_or_path if (is_local or is_object_storage_path) - else download_weights_from_hf( - model_name_or_path, - self.load_config.download_dir, - [safetensors_pattern], - revision, - ignore_patterns=self.load_config.ignore_patterns, - )) + hf_folder = ( + model_name_or_path + if (is_local or is_object_storage_path) + else download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [safetensors_pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + ) hf_weights_files = list_safetensors(path=hf_folder) if not is_local and not is_object_storage_path: download_safetensors_index_file_from_hf( - model_name_or_path, index_file, self.load_config.download_dir, - revision) + model_name_or_path, index_file, self.load_config.download_dir, revision + ) if not hf_weights_files: raise RuntimeError( - f"Cannot find any safetensors model weights with " - f"`{model_name_or_path}`") + f"Cannot find any safetensors model weights with `{model_name_or_path}`" + ) return hf_weights_files def _get_weights_iterator( - self, model_or_path: str, - revision: str) -> Generator[tuple[str, torch.Tensor], None, None]: + self, model_or_path: str, revision: str + ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_weights_files = self._prepare_weights(model_or_path, revision) return runai_safetensors_weights_iterator( @@ -94,11 +101,11 @@ def download_model(self, model_config: ModelConfig) -> None: """Download model if necessary""" self._prepare_weights(model_config.model, model_config.revision) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: 
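The Run:ai streamer options handled above are plain integers in model_loader_extra_config and are forwarded to the streamer via environment variables. A hedged sketch; the endpoint and values are placeholders, and the memory_limit unit (bytes) is an assumption.

import os

extra_config = {"concurrency": 16, "memory_limit": 8 * 1024**3}

os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(extra_config["concurrency"])
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(extra_config["memory_limit"])
# When RUNAI_STREAMER_S3_ENDPOINT is unset, the loader falls back to AWS_ENDPOINT_URL.
os.environ.setdefault("RUNAI_STREAMER_S3_ENDPOINT", "https://s3.example.com")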
+ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: """Load weights into a model.""" model_weights = model_config.model if hasattr(model_config, "model_weights"): model_weights = model_config.model_weights model.load_weights( - self._get_weights_iterator(model_weights, model_config.revision)) + self._get_weights_iterator(model_weights, model_config.revision) + ) diff --git a/vllm/model_executor/model_loader/sharded_state_loader.py b/vllm/model_executor/model_loader/sharded_state_loader.py index a85ca065d1d2..d50a1a8f9dbf 100644 --- a/vllm/model_executor/model_loader/sharded_state_loader.py +++ b/vllm/model_executor/model_loader/sharded_state_loader.py @@ -15,7 +15,9 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, runai_safetensors_weights_iterator) + download_weights_from_hf, + runai_safetensors_weights_iterator, +) from vllm.transformers_utils.s3_utils import glob as s3_glob from vllm.transformers_utils.utils import is_s3 @@ -36,23 +38,30 @@ class ShardedStateLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) - extra_config = ({} if load_config.model_loader_extra_config is None - else load_config.model_loader_extra_config.copy()) + extra_config = ( + {} + if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy() + ) self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) if extra_config: - raise ValueError(f"Unexpected extra config keys for load format " - f"{load_config.load_format}: " - f"{load_config.model_loader_extra_config.keys()}") + raise ValueError( + f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}" + ) @staticmethod def _filter_subtensors( - tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: + tensors: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: """ Filter out all tensors that share the same memory or a subset of the memory of another tensor. 
""" same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = ( - collections.defaultdict(list)) + collections.defaultdict(list) + ) for key, tensor in tensors.items(): if tensor.numel(): ptr = tensor.untyped_storage().data_ptr() @@ -80,8 +89,7 @@ def get_end_ptr(tensor: torch.Tensor) -> int: result[k] = t return result - def _prepare_weights(self, model_name_or_path: str, - revision: Optional[str]): + def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]): if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path): return model_name_or_path else: @@ -97,8 +105,7 @@ def _prepare_weights(self, model_name_or_path: str, def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: from vllm.distributed import get_tensor_model_parallel_rank model_weights = model_config.model @@ -115,15 +122,15 @@ def load_weights(self, model: nn.Module, filepaths = [] if is_s3(local_model_path): file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}" - filepaths = s3_glob(path=local_model_path, - allow_pattern=[file_pattern]) + filepaths = s3_glob(path=local_model_path, allow_pattern=[file_pattern]) else: filepaths = glob.glob(pattern) if not filepaths: # TODO: support un-sharded checkpoints too raise ValueError( f"Could not find checkpoint files '{pattern}', only " - f"pre-sharded checkpoints are currently supported!") + f"pre-sharded checkpoints are currently supported!" + ) state_dict = self._filter_subtensors(model.state_dict()) for key, tensor in self.iterate_over_files(filepaths): # If loading with LoRA enabled, additional padding may @@ -136,8 +143,7 @@ def load_weights(self, model: nn.Module, param_data = param_data.narrow(dim, 0, size) if tensor.shape != param_shape: logger.warning( - "loading tensor of shape %s into " - "parameter '%s' of shape %s", + "loading tensor of shape %s into parameter '%s' of shape %s", tensor.shape, key, param_shape, @@ -145,15 +151,16 @@ def load_weights(self, model: nn.Module, param_data.copy_(tensor) state_dict.pop(key) if state_dict: - raise ValueError( - f"Missing keys {tuple(state_dict)} in loaded state!") + raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!") def iterate_over_files( - self, paths) -> Generator[tuple[str, torch.Tensor], None, None]: + self, paths + ) -> Generator[tuple[str, torch.Tensor], None, None]: if self.load_config.load_format == "runai_streamer_sharded": yield from runai_safetensors_weights_iterator(paths, True) else: from safetensors.torch import safe_open + for path in paths: with safe_open(path, framework="pt") as f: for key in f.keys(): # noqa: SIM118 diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 13f4eebf1038..9d58278f996b 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -22,11 +22,9 @@ from transformers import PretrainedConfig import vllm.envs as envs -from vllm.config import (ModelConfig, ParallelConfig, VllmConfig, - set_current_vllm_config) +from vllm.config import ModelConfig, ParallelConfig, VllmConfig, set_current_vllm_config from vllm.logger import init_logger -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import 
VocabParallelEmbedding from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser, PlaceholderModule @@ -34,11 +32,14 @@ from vllm.engine.arg_utils import EngineArgs try: - from tensorizer import (DecryptionParams, EncryptionParams, - TensorDeserializer, TensorSerializer) + from tensorizer import ( + DecryptionParams, + EncryptionParams, + TensorDeserializer, + TensorSerializer, + ) from tensorizer.stream_io import open_stream - from tensorizer.utils import (convert_bytes, get_mem_usage, - no_init_or_tensor) + from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor except ImportError: tensorizer = PlaceholderModule("tensorizer") @@ -52,9 +53,15 @@ no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor") __all__ = [ - 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', - 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', - 'no_init_or_tensor', 'TensorizerConfig' + "EncryptionParams", + "DecryptionParams", + "TensorDeserializer", + "TensorSerializer", + "open_stream", + "convert_bytes", + "get_mem_usage", + "no_init_or_tensor", + "TensorizerConfig", ] logger = init_logger(__name__) @@ -73,12 +80,12 @@ def tensorizer_kwargs_arg(value): raise argparse.ArgumentTypeError( f"Not deserializable to dict: {value}. serialization_kwargs and " f"deserialization_kwargs must be " - f"deserializable from a JSON string to a dictionary. ") + f"deserializable from a JSON string to a dictionary. " + ) return loaded class MetaTensorMode(TorchDispatchMode): - def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} @@ -88,8 +95,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) -def meta_tensor_mode(loading_code=None, ): - +def meta_tensor_mode( + loading_code=None, +): if loading_code is None: return _NoInitOrTensorImpl.context_manager() elif callable(loading_code): @@ -99,15 +107,15 @@ def meta_tensor_mode(loading_code=None, ): raise TypeError( "expected a callable to evaluate," " or None if being used as a context manager;" - f' got an object of type "{type(loading_code).__name__}" instead.') + f' got an object of type "{type(loading_code).__name__}" instead.' 
+ ) class _NoInitOrTensorImpl: _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm) _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES) - is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", - default=False) + is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False) _count_active: int = 0 _count_active_lock = threading.Lock() @@ -139,7 +147,6 @@ def context_manager(cls): @staticmethod def _disable(func): - def wrapper(*args, **kwargs): if not _NoInitOrTensorImpl.is_active.get(): return func(*args, **kwargs) @@ -162,10 +169,10 @@ class TensorizerConfig(MutableMapping): stream_kwargs: Optional[dict[str, Any]] = None serialization_kwargs: Optional[dict[str, Any]] = None deserialization_kwargs: Optional[dict[str, Any]] = None - _extra_serialization_attrs: Optional[dict[str, Any]] = field(init=False, - default=None) - model_class: Optional[type[torch.nn.Module]] = field(init=False, - default=None) + _extra_serialization_attrs: Optional[dict[str, Any]] = field( + init=False, default=None + ) + model_class: Optional[type[torch.nn.Module]] = field(init=False, default=None) hf_config: Optional[PretrainedConfig] = field(init=False, default=None) dtype: Optional[Union[str, torch.dtype]] = field(init=False, default=None) _is_sharded: bool = field(init=False, default=False) @@ -220,19 +227,23 @@ class TensorizerConfig(MutableMapping): def __post_init__(self): # check if the configuration is for a sharded vLLM model - self._is_sharded = isinstance(self.tensorizer_uri, str) \ - and re.search(r'%0\dd', self.tensorizer_uri) is not None + self._is_sharded = ( + isinstance(self.tensorizer_uri, str) + and re.search(r"%0\dd", self.tensorizer_uri) is not None + ) if self.tensorizer_dir and self.lora_dir: raise ValueError( "Only one of tensorizer_dir or lora_dir may be specified. " "Use lora_dir exclusively when serializing LoRA adapters, " - "and tensorizer_dir or tensorizer_uri otherwise.") + "and tensorizer_dir or tensorizer_uri otherwise." + ) if self.tensorizer_dir and self.tensorizer_uri: logger.warning_once( "Provided both tensorizer_dir and tensorizer_uri. " "Inferring tensorizer_dir from tensorizer_uri as the " - "latter takes precedence.") + "latter takes precedence." + ) self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) if not self.tensorizer_uri: if self.lora_dir: @@ -240,11 +251,13 @@ def __post_init__(self): elif self.tensorizer_dir: self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors" else: - raise ValueError("Unable to resolve tensorizer_uri. " - "A valid tensorizer_uri or tensorizer_dir " - "must be provided for deserialization, and a " - "valid tensorizer_uri, tensorizer_uri, or " - "lora_dir for serialization.") + raise ValueError( + "Unable to resolve tensorizer_uri. " + "A valid tensorizer_uri or tensorizer_dir " + "must be provided for deserialization, and a " + "valid tensorizer_uri, tensorizer_uri, or " + "lora_dir for serialization." 
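The sharded-model check in __post_init__ above looks for a "%0Nd" printf-style template in tensorizer_uri, which is later formatted with the tensor-parallel rank. A worked example; the bucket name is a placeholder.

import re


def is_sharded_uri(uri) -> bool:
    return isinstance(uri, str) and re.search(r"%0\dd", uri) is not None


assert is_sharded_uri("s3://my-bucket/model-rank-%04d.tensors")
assert not is_sharded_uri("s3://my-bucket/model.tensors")
# The same template is later filled with the TP rank:
print("s3://my-bucket/model-rank-%04d.tensors" % 3)  # ...model-rank-0003.tensors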
+ ) else: self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) @@ -280,8 +293,12 @@ def to_serializable(self) -> dict[str, Any]: tc_dict = {} for k, v in raw_tc_dict.items(): - if (k not in blacklisted and k not in tc_dict - and not k.startswith("_") and v is not None): + if ( + k not in blacklisted + and k not in tc_dict + and not k.startswith("_") + and v is not None + ): tc_dict[k] = v return tc_dict @@ -293,26 +310,25 @@ def verify_with_parallel_config( self, parallel_config: "ParallelConfig", ) -> None: - if parallel_config.tensor_parallel_size > 1 \ - and not self._is_sharded: + if parallel_config.tensor_parallel_size > 1 and not self._is_sharded: raise ValueError( "For a sharded model, tensorizer_uri should include a" " string format template like '%04d' to be formatted" - " with the rank of the shard") + " with the rank of the shard" + ) def verify_with_model_config(self, model_config: "ModelConfig") -> None: - if (model_config.quantization is not None - and self.tensorizer_uri is not None): + if model_config.quantization is not None and self.tensorizer_uri is not None: logger.warning( "Loading a model using Tensorizer with quantization on vLLM" - " is unstable and may lead to errors.") + " is unstable and may lead to errors." + ) def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): if tensorizer_args is None: tensorizer_args = self._construct_tensorizer_args() - return open_stream(self.tensorizer_uri, - **tensorizer_args.stream_kwargs) + return open_stream(self.tensorizer_uri, **tensorizer_args.stream_kwargs) def keys(self): return self._keys @@ -354,34 +370,36 @@ def __init__(self, tensorizer_config: TensorizerConfig): for k, v in tensorizer_config.items(): setattr(self, k, v) self.file_obj = tensorizer_config.tensorizer_uri - self.s3_access_key_id = (tensorizer_config.s3_access_key_id - or envs.S3_ACCESS_KEY_ID) - self.s3_secret_access_key = (tensorizer_config.s3_secret_access_key - or envs.S3_SECRET_ACCESS_KEY) + self.s3_access_key_id = ( + tensorizer_config.s3_access_key_id or envs.S3_ACCESS_KEY_ID + ) + self.s3_secret_access_key = ( + tensorizer_config.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY + ) self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL self.stream_kwargs = { "s3_access_key_id": tensorizer_config.s3_access_key_id, "s3_secret_access_key": tensorizer_config.s3_secret_access_key, "s3_endpoint": tensorizer_config.s3_endpoint, - **(tensorizer_config.stream_kwargs or {}) + **(tensorizer_config.stream_kwargs or {}), } self.deserialization_kwargs = { "verify_hash": tensorizer_config.verify_hash, "encryption": tensorizer_config.encryption_keyfile, "num_readers": tensorizer_config.num_readers, - **(tensorizer_config.deserialization_kwargs or {}) + **(tensorizer_config.deserialization_kwargs or {}), } if self.encryption_keyfile: with open_stream( - tensorizer_config.encryption_keyfile, - **self.stream_kwargs, + tensorizer_config.encryption_keyfile, + **self.stream_kwargs, ) as stream: key = stream.read() decryption_params = DecryptionParams.from_key(key) - self.deserialization_kwargs['encryption'] = decryption_params + self.deserialization_kwargs["encryption"] = decryption_params @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -389,17 +407,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: # Tensorizer options arg group group = parser.add_argument_group( - 'tensorizer options', - description=('Options for configuring the behavior of the' - ' 
tensorizer deserializer when ' - 'load_format=tensorizer is specified when ' - 'initializing an LLMEngine, either via the CLI ' - 'when running the vLLM OpenAI inference server ' - 'with a JSON string passed to ' - '--model-loader-extra-config or as arguments given ' - 'to TensorizerConfig when passed to ' - 'model_loader_extra_config in the constructor ' - 'for LLMEngine.')) + "tensorizer options", + description=( + "Options for configuring the behavior of the" + " tensorizer deserializer when " + "load_format=tensorizer is specified when " + "initializing an LLMEngine, either via the CLI " + "when running the vLLM OpenAI inference server " + "with a JSON string passed to " + "--model-loader-extra-config or as arguments given " + "to TensorizerConfig when passed to " + "model_loader_extra_config in the constructor " + "for LLMEngine." + ), + ) group.add_argument( "--tensorizer-uri", @@ -419,7 +440,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default=None, help="The file path to a binary file containing a binary key to " - "use for decryption. Can be a file path or S3 network URI.") + "use for decryption. Can be a file path or S3 network URI.", + ) group.add_argument( "--num-readers", default=None, @@ -427,7 +449,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Controls how many threads are allowed to read concurrently " "from the source file. Default is `None`, which will dynamically " "set the number of readers based on the available resources " - "and model size. This greatly increases performance.") + "and model size. This greatly increases performance.", + ) group.add_argument( "--s3-access-key-id", type=str, @@ -455,72 +478,81 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": attrs = [attr.name for attr in dataclasses.fields(cls)] - tensorizer_args = cls(**{ - attr: getattr(args, attr) - for attr in attrs if hasattr(args, attr) - }) + tensorizer_args = cls( + **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)} + ) return tensorizer_args def _check_tensors_on_meta_device(model: nn.Module) -> None: for tensor in model.state_dict().values(): - if tensor.device.type == 'meta': + if tensor.device.type == "meta": raise ValueError( "The serialized model contains tensors on the meta device," " indicating that some tensors were not loaded properly." " Please check that the parameters of the model being" " specified match that of the serialized model, such as" - " its quantization.") + " its quantization." 
+ ) def _resize_lora_embeddings(model: nn.Module): """Modify LoRA embedding layers to use bigger tensors to allow for adapter added tokens.""" for child in model.modules(): - if (isinstance(child, VocabParallelEmbedding) and child.weight.shape[0] - < child.num_embeddings_per_partition): - new_weight = torch.empty(child.num_embeddings_per_partition, - child.embedding_dim, - dtype=child.weight.dtype, - device=child.weight.device) - new_weight[:child.weight.shape[0]].copy_(child.weight.data) - new_weight[child.weight.shape[0]:].fill_(0) + if ( + isinstance(child, VocabParallelEmbedding) + and child.weight.shape[0] < child.num_embeddings_per_partition + ): + new_weight = torch.empty( + child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device, + ) + new_weight[: child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0] :].fill_(0) child.weight.data = new_weight -def init_tensorizer_model(tensorizer_config: TensorizerConfig, - vllm_config: VllmConfig) -> nn.Module: +def init_tensorizer_model( + tensorizer_config: TensorizerConfig, vllm_config: VllmConfig +) -> nn.Module: assert tensorizer_config.hf_config is not None model_args = tensorizer_config.hf_config model_args.torch_dtype = tensorizer_config.dtype assert tensorizer_config.model_class is not None # TODO: Do we need to consider old-style model class? - with meta_tensor_mode(), set_current_vllm_config(vllm_config, - check_compile=True): + with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True): return tensorizer_config.model_class(vllm_config=vllm_config) -def deserialize_tensorizer_model(model: nn.Module, - tensorizer_config: TensorizerConfig) -> None: +def deserialize_tensorizer_model( + model: nn.Module, tensorizer_config: TensorizerConfig +) -> None: tensorizer_args = tensorizer_config._construct_tensorizer_args() if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri): raise ValueError( f"{tensorizer_config.tensorizer_uri} is not a valid " f"tensorizer URI. Please check that the URI is correct. " f"It must either point to a local existing file, or have a " - f"S3, HTTP or HTTPS scheme.") + f"S3, HTTP or HTTPS scheme." 
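The LoRA embedding resize above copies the existing rows and zero-fills the rows reserved for adapter-added tokens. A standalone sketch with made-up sizes:

import torch

old_weight = torch.randn(100, 16)  # current embedding rows
num_embeddings_per_partition = 104  # assumed padded size leaving room for added tokens

new_weight = torch.empty(
    num_embeddings_per_partition, old_weight.shape[1], dtype=old_weight.dtype
)
new_weight[: old_weight.shape[0]].copy_(old_weight)
new_weight[old_weight.shape[0] :].fill_(0)
assert new_weight.shape == (104, 16) and new_weight[100:].abs().sum() == 0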
+ ) before_mem = get_mem_usage() start = time.perf_counter() - with open_stream( - tensorizer_config.tensorizer_uri, - mode="rb", - **tensorizer_args.stream_kwargs) as stream, TensorDeserializer( - stream, - dtype=tensorizer_config.dtype, - device=f'xpu:{torch.xpu.current_device()}' - if current_platform.is_xpu() else - f'cuda:{torch.cuda.current_device()}', - **tensorizer_args.deserialization_kwargs) as deserializer: + with ( + open_stream( + tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs + ) as stream, + TensorDeserializer( + stream, + dtype=tensorizer_config.dtype, + device=f"xpu:{torch.xpu.current_device()}" + if current_platform.is_xpu() + else f"cuda:{torch.cuda.current_device()}", + **tensorizer_args.deserialization_kwargs, + ) as deserializer, + ): deserializer.load_into_module(model) end = time.perf_counter() @@ -529,8 +561,9 @@ def deserialize_tensorizer_model(model: nn.Module, per_second = convert_bytes(deserializer.total_tensor_bytes / duration) after_mem = get_mem_usage() deserializer.close() - logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, - end - start, per_second) + logger.info( + "Deserialized %s in %0.2fs, %s/s", total_bytes_str, end - start, per_second + ) logger.info("Memory usage before: %s", before_mem) logger.info("Memory usage after: %s", after_mem) @@ -540,20 +573,21 @@ def deserialize_tensorizer_model(model: nn.Module, def tensorizer_weights_iterator( - tensorizer_args: "TensorizerArgs" + tensorizer_args: "TensorizerArgs", ) -> Generator[tuple[str, torch.Tensor], None, None]: - logger.warning("Deserializing HuggingFace models is not optimized for " - "loading on vLLM, as tensorizer is forced to load to CPU. " - "Consider deserializing a vLLM model instead for faster " - "load times. See the " - "examples/others/tensorize_vllm_model.py example script " - "for serializing vLLM models.") + logger.warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the " + "examples/others/tensorize_vllm_model.py example script " + "for serializing vLLM models." + ) deserializer_args = tensorizer_args.deserialization_kwargs stream_kwargs = tensorizer_args.stream_kwargs stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs) - with TensorDeserializer(stream, **deserializer_args, - device="cpu") as state: + with TensorDeserializer(stream, **deserializer_args, device="cpu") as state: yield from state.items() del state @@ -571,41 +605,54 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: bool: True if the model is a vLLM model, False otherwise. """ tensorizer_args = tensorizer_config._construct_tensorizer_args() - deserializer = TensorDeserializer(open_stream( - tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs), - **tensorizer_args.deserialization_kwargs, - lazy_load=True) + deserializer = TensorDeserializer( + open_stream(tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs), + **tensorizer_args.deserialization_kwargs, + lazy_load=True, + ) if tensorizer_config.vllm_tensorized: logger.warning( "Please note that newly serialized vLLM models are automatically " "inferred as vLLM models, so setting vllm_tensorized=True is " - "only necessary for models serialized prior to this change.") + "only necessary for models serialized prior to this change." 
+ ) return True return ".vllm_tensorized_marker" in deserializer def serialize_extra_artifacts( - tensorizer_args: TensorizerArgs, - served_model_name: Union[str, list[str], None]) -> None: + tensorizer_args: TensorizerArgs, served_model_name: Union[str, list[str], None] +) -> None: if not isinstance(served_model_name, str): raise ValueError( f"served_model_name must be a str for serialize_extra_artifacts, " - f"not {type(served_model_name)}.") + f"not {type(served_model_name)}." + ) with tempfile.TemporaryDirectory() as tmpdir: - snapshot_download(served_model_name, - local_dir=tmpdir, - ignore_patterns=[ - "*.pt", "*.safetensors", "*.bin", "*.cache", - "*.gitattributes", "*.md" - ]) + snapshot_download( + served_model_name, + local_dir=tmpdir, + ignore_patterns=[ + "*.pt", + "*.safetensors", + "*.bin", + "*.cache", + "*.gitattributes", + "*.md", + ], + ) for artifact in os.scandir(tmpdir): if not artifact.is_file(): continue - with open(artifact.path, "rb") as f, open_stream( + with ( + open(artifact.path, "rb") as f, + open_stream( f"{tensorizer_args.tensorizer_dir}/{artifact.name}", mode="wb+", - **tensorizer_args.stream_kwargs) as stream: + **tensorizer_args.stream_kwargs, + ) as stream, + ): logger.info("Writing artifact %s", artifact.name) stream.write(f.read()) @@ -617,7 +664,8 @@ def serialize_vllm_model( ) -> nn.Module: model.register_parameter( "vllm_tensorized_marker", - nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + nn.Parameter(torch.tensor((1,), device="meta"), requires_grad=False), + ) tensorizer_args = tensorizer_config._construct_tensorizer_args() @@ -630,13 +678,17 @@ def serialize_vllm_model( output_file = tensorizer_args.tensorizer_uri if tensorizer_config._is_sharded: from vllm.distributed import get_tensor_model_parallel_rank + output_file = output_file % get_tensor_model_parallel_rank() - with open_stream(output_file, mode="wb+", - **tensorizer_args.stream_kwargs) as stream: - serializer = TensorSerializer(stream, - encryption=encryption_params, - **tensorizer_config.serialization_kwargs) + with open_stream( + output_file, mode="wb+", **tensorizer_args.stream_kwargs + ) as stream: + serializer = TensorSerializer( + stream, + encryption=encryption_params, + **tensorizer_config.serialization_kwargs, + ) serializer.write_module(model) serializer.close() @@ -646,29 +698,32 @@ def serialize_vllm_model( return model -def tensorize_vllm_model(engine_args: "EngineArgs", - tensorizer_config: TensorizerConfig, - generate_keyfile: bool = True): +def tensorize_vllm_model( + engine_args: "EngineArgs", + tensorizer_config: TensorizerConfig, + generate_keyfile: bool = True, +): """Utility to load a model and then serialize it with Tensorizer - Intended to be used separately from running a vLLM server since it - creates its own Engine instance. + Intended to be used separately from running a vLLM server since it + creates its own Engine instance. 
""" engine_config = engine_args.create_engine_config() tensorizer_config.verify_with_model_config(engine_config.model_config) - tensorizer_config.verify_with_parallel_config( - engine_config.parallel_config) + tensorizer_config.verify_with_parallel_config(engine_config.parallel_config) # generate the encryption key before creating the engine to support sharding - if generate_keyfile and (keyfile := - tensorizer_config.encryption_keyfile) is not None: + if ( + generate_keyfile + and (keyfile := tensorizer_config.encryption_keyfile) is not None + ): encryption_params = EncryptionParams.random() with open_stream( - keyfile, - mode="wb+", - s3_access_key_id=tensorizer_config.s3_access_key_id, - s3_secret_access_key=tensorizer_config.s3_secret_access_key, - s3_endpoint=tensorizer_config.s3_endpoint, + keyfile, + mode="wb+", + s3_access_key_id=tensorizer_config.s3_access_key_id, + s3_secret_access_key=tensorizer_config.s3_secret_access_key, + s3_endpoint=tensorizer_config.s3_endpoint, ) as stream: stream.write(encryption_params.key) @@ -683,8 +738,7 @@ def tensorize_vllm_model(engine_args: "EngineArgs", ) -def tensorize_lora_adapter(lora_path: str, - tensorizer_config: TensorizerConfig): +def tensorize_lora_adapter(lora_path: str, tensorizer_config: TensorizerConfig): """ Uses tensorizer to serialize a LoRA adapter. Assumes that the files needed to load a LoRA adapter are a safetensors-format file called @@ -720,19 +774,20 @@ def tensorize_lora_adapter(lora_path: str, tensorizer_args = tensorizer_config._construct_tensorizer_args() - with open_stream(f"{tensorizer_config.tensorizer_dir}/adapter_config.json", - mode="wb+", - **tensorizer_args.stream_kwargs) as f: - + with open_stream( + f"{tensorizer_config.tensorizer_dir}/adapter_config.json", + mode="wb+", + **tensorizer_args.stream_kwargs, + ) as f: f.write(json.dumps(config).encode("utf-8")) - lora_uri = (f"{tensorizer_config.tensorizer_dir}" - f"/adapter_model.tensors") - with open_stream(lora_uri, mode="wb+", - **tensorizer_args.stream_kwargs) as f: + lora_uri = f"{tensorizer_config.tensorizer_dir}/adapter_model.tensors" + with open_stream(lora_uri, mode="wb+", **tensorizer_args.stream_kwargs) as f: serializer = TensorSerializer(f) serializer.write_state_dict(tensors) serializer.close() - logger.info("Successfully serialized LoRA files to %s", - str(tensorizer_config.tensorizer_dir)) + logger.info( + "Successfully serialized LoRA files to %s", + str(tensorizer_config.tensorizer_dir), + ) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 65ea49c64294..5585a74f8926 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -13,11 +13,18 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model, - is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator) -from vllm.model_executor.model_loader.utils import (get_model_architecture, - initialize_model, - set_default_torch_dtype) + TensorizerConfig, + deserialize_tensorizer_model, + init_tensorizer_model, + is_vllm_tensorized, + serialize_vllm_model, + tensorizer_weights_iterator, +) +from vllm.model_executor.model_loader.utils import ( + get_model_architecture, + initialize_model, + set_default_torch_dtype, +) logger = init_logger(__name__) @@ -44,15 
+51,18 @@ def __init__(self, load_config: LoadConfig): else: validate_config(load_config.model_loader_extra_config) self.tensorizer_config = TensorizerConfig( - **load_config.model_loader_extra_config["tensorizer_config"]) + **load_config.model_loader_extra_config["tensorizer_config"] + ) - def _verify_config(self, model_config: ModelConfig, - parallel_config: ParallelConfig): + def _verify_config( + self, model_config: ModelConfig, parallel_config: ParallelConfig + ): self.tensorizer_config.verify_with_model_config(model_config) self.tensorizer_config.verify_with_parallel_config(parallel_config) def _get_weights_iterator( - self, ) -> Generator[tuple[str, torch.Tensor], None, None]: + self, + ) -> Generator[tuple[str, torch.Tensor], None, None]: tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) @@ -82,8 +92,7 @@ def download_model(self, model_config: ModelConfig) -> None: with self.tensorizer_config.open_stream(): pass - def _patch_tensorizer_config( - self, model_config: ModelConfig) -> TensorizerConfig: + def _patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig: model_class = get_model_architecture(model_config)[0] tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class @@ -91,8 +100,7 @@ def _patch_tensorizer_config( tensorizer_config.dtype = model_config.dtype return tensorizer_config - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: """Load serialized model weights with tensorizer. Expects a vLLM-tensorized model. See the @@ -104,8 +112,9 @@ def load_weights(self, model: nn.Module, else: model.load_weights(self._get_weights_iterator()) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig + ) -> nn.Module: parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) @@ -113,8 +122,8 @@ def load_model(self, vllm_config: VllmConfig, from vllm.distributed import get_tensor_model_parallel_rank self.tensorizer_config.tensorizer_uri = ( - self.tensorizer_config.tensorizer_uri % - get_tensor_model_parallel_rank()) + self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank() + ) if is_vllm_tensorized(self.tensorizer_config): tensorizer_config = self._patch_tensorizer_config(model_config) @@ -122,8 +131,8 @@ def load_model(self, vllm_config: VllmConfig, with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = init_tensorizer_model( - tensorizer_config=tensorizer_config, - vllm_config=vllm_config) + tensorizer_config=tensorizer_config, vllm_config=vllm_config + ) self.load_weights(model, model_config) return model return self._load_model_serialized_cpu(vllm_config=vllm_config) diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index a70cdeb483e6..fc97003de8e3 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ b/vllm/model_executor/model_loader/tpu.py @@ -13,7 +13,10 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) + initialize_model, + process_weights_after_loading, + set_default_torch_dtype, +) logger = 
init_logger(__name__) @@ -34,33 +37,31 @@ def load_model( self.counter_before_loading_weights = time.perf_counter() model_config = vllm_config.model_config assert model_config.quantization is None, "Quantization not supported" - target_device = torch.device('cpu') + target_device = torch.device("cpu") with set_default_torch_dtype(model_config.dtype): with target_device: model = initialize_model(vllm_config=vllm_config) load_format = vllm_config.load_config.load_format if load_format != "dummy": - weights_to_load = { - name - for name, _ in model.named_parameters() - } + weights_to_load = {name for name, _ in model.named_parameters()} all_weights = self.get_all_weights(model_config, model) loaded_weights = model.load_weights(all_weights) self.counter_after_loading_weights = time.perf_counter() logger.info( "Loading weights took %.2f seconds", - self.counter_after_loading_weights - - self.counter_before_loading_weights) + self.counter_after_loading_weights + - self.counter_before_loading_weights, + ) # We only enable strict check for non-quantized models # that have loaded weights tracking currently. - if model_config.quantization is None and \ - loaded_weights is not None: + if model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: raise ValueError( "Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") + f"checkpoint: {weights_not_loaded}" + ) else: logger.info("Use dummy weight during weight loading.") @@ -68,11 +69,13 @@ def load_model( counter_before_partition = time.perf_counter() model = model.eval() - model = model.to('xla') + model = model.to("xla") shard_model(model, mesh) counter_after_partition = time.perf_counter() - logger.info("Partition model took %.2f seconds", - counter_after_partition - counter_before_partition) + logger.info( + "Partition model took %.2f seconds", + counter_after_partition - counter_before_partition, + ) # Ensure the model is properly loaded. self._check_model_is_loaded(mesh, model) @@ -82,12 +85,12 @@ def load_model( if not model_config.is_multimodal_model: model.model = torch.compile(model.model, backend="openxla") else: - model.language_model.model = \ - torch.compile(model.language_model.model, backend="openxla") + model.language_model.model = torch.compile( + model.language_model.model, backend="openxla" + ) return model - def _check_model_is_loaded(self, mesh: Optional[xs.Mesh], - model: nn.Module) -> None: + def _check_model_is_loaded(self, mesh: Optional[xs.Mesh], model: nn.Module) -> None: """ Ensure the model is properly loaded. 1. All model parameters and buffers are on XLA device. 
@@ -99,16 +102,18 @@ def _check_model_is_loaded(self, mesh: Optional[xs.Mesh], # Check parameters for name, param in model.named_parameters(): assert param.device.type == device_type, ( - f"Parameter {name} is on {param.device.type} " - f"instead of {device_type}") + f"Parameter {name} is on {param.device.type} instead of {device_type}" + ) # Check buffers for name, buffer in model.named_buffers(): assert buffer.device.type == device_type, ( - f"Buffer {name} is on {buffer.device.type} " - f"instead of {device_type}") + f"Buffer {name} is on {buffer.device.type} instead of {device_type}" + ) for module in model.modules(): - if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'): - raise AssertionError("QKVParallelLinear should be replaced by \ - XlaQKVParallelLinear under SPMD mode.") + if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"): + raise AssertionError( + "QKVParallelLinear should be replaced by \ + XlaQKVParallelLinear under SPMD mode." + ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 293edadcc240..ba8d53c0ba14 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for selecting and loading models.""" + import contextlib import inspect import warnings @@ -17,12 +18,16 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import QKVCrossParallelLinear from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.models.adapters import ( - as_embedding_model, as_reward_model, as_seq_cls_model, - try_create_mm_pooling_model_cls) -from vllm.model_executor.models.interfaces import (SupportsQuant, - supports_multimodal) + as_embedding_model, + as_reward_model, + as_seq_cls_model, + try_create_mm_pooling_model_cls, +) +from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -57,16 +62,16 @@ def initialize_model( all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class - with set_current_vllm_config(vllm_config, - check_compile=True, - prefix=prefix): + with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix): return model_class(vllm_config=vllm_config, prefix=prefix) - msg = ("vLLM model class should accept `vllm_config` and `prefix` as " - "input arguments. Possibly you have an old-style model class" - " registered from out of tree and it is used for new vLLM version. " - "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " - "for the design and update the model class accordingly.") + msg = ( + "vLLM model class should accept `vllm_config` and `prefix` as " + "input arguments. Possibly you have an old-style model class" + " registered from out of tree and it is used for new vLLM version. " + "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " + "for the design and update the model class accordingly." 
+ ) warnings.warn(msg, DeprecationWarning, stacklevel=2) logger.warning( @@ -87,20 +92,19 @@ def initialize_model( kwargs["lora_config"] = vllm_config.lora_config if "scheduler_config" in all_params: kwargs["scheduler_config"] = vllm_config.scheduler_config - with set_current_vllm_config(vllm_config, - check_compile=True, - prefix=prefix): + with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix): return model_class(**kwargs) -def process_weights_after_loading(model: nn.Module, model_config: ModelConfig, - target_device: torch.device) -> None: - +def process_weights_after_loading( + model: nn.Module, model_config: ModelConfig, target_device: torch.device +) -> None: # to avoid circular dependency from vllm.model_executor.model_loader.online_quantization import ( - maybe_save_metadata_and_attributes_for_weight_reloading) - maybe_save_metadata_and_attributes_for_weight_reloading( - model, model_config) + maybe_save_metadata_and_attributes_for_weight_reloading, + ) + + maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config) for _, module in model.named_modules(): if isinstance(module, QKVCrossParallelLinear): @@ -122,16 +126,16 @@ def process_weights_after_loading(model: nn.Module, model_config: ModelConfig, # NOTE: This intentionally happens after other modules so we can easily # decompress the weights for MLA. for _, module in model.named_modules(): - if isinstance(module, Attention) and \ - hasattr(module, "process_weights_after_loading"): + if isinstance(module, Attention) and hasattr( + module, "process_weights_after_loading" + ): # TODO(lucas): see if there is a way to unify the signatures # of process_weights_after_loading module.process_weights_after_loading(model_config.dtype) @contextmanager -def device_loading_context(module: torch.nn.Module, - target_device: torch.device): +def device_loading_context(module: torch.nn.Module, target_device: torch.device): if target_device.type == "cpu": # If target is CPU, no need to move anything yield module @@ -176,8 +180,7 @@ def device_loading_context(module: torch.nn.Module, """Caches the outputs of `_get_model_architecture`.""" -def _get_model_architecture( - model_config: ModelConfig) -> tuple[type[nn.Module], str]: +def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) model_cls, arch = model_config.registry.resolve_model_cls( @@ -191,7 +194,9 @@ def _get_model_architecture( logger.warning_once( "%s has no vLLM implementation, falling back to Transformers " "implementation. 
Some features may not be supported and " - "performance may not be optimal.", arch) + "performance may not be optimal.", + arch, + ) convert_type = model_config.convert_type if convert_type != "none" and supports_multimodal(model_cls): @@ -220,16 +225,17 @@ def _get_model_architecture( return model_cls, arch -def get_model_architecture( - model_config: ModelConfig) -> tuple[type[nn.Module], str]: - key = hash(( - model_config.model, - model_config.convert_type, - model_config.runner_type, - model_config.trust_remote_code, - model_config.model_impl, - tuple(getattr(model_config.hf_config, "architectures", [])), - )) +def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]: + key = hash( + ( + model_config.model, + model_config.convert_type, + model_config.runner_type, + model_config.trust_remote_code, + model_config.model_impl, + tuple(getattr(model_config.hf_config, "architectures", [])), + ) + ) if key in _MODEL_ARCH_BY_HASH: return _MODEL_ARCH_BY_HASH[key] @@ -253,9 +259,9 @@ class ParamMapping: It creates a bidirectional mapping between packed parameters and their constituent parts. """ + packed_mapping: dict[str, list[str]] - inverse_packed_mapping: dict[str, tuple[str, - int]] = field(default_factory=dict) + inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict) def __post_init__(self): for packed_name, sub_params in self.packed_mapping.items(): @@ -268,16 +274,16 @@ def __post_init__(self): index, ) - def get_sub_modules(self, - module_name: str) -> Optional[tuple[str, list[str]]]: + def get_sub_modules(self, module_name: str) -> Optional[tuple[str, list[str]]]: for key, value in self.packed_mapping.items(): if module_name.endswith(key): return key, value return None -def configure_quant_config(quant_config: QuantizationConfig, - model_class: type[nn.Module]): +def configure_quant_config( + quant_config: QuantizationConfig, model_class: type[nn.Module] +): """ Pass packed_modules_mapping by reference to quant_config so that quant_config can properly match fused modules diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 6c5f7bbcc8aa..c40185c1c084 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for downloading and initializing model weights.""" + import concurrent.futures import fnmatch import glob @@ -28,18 +29,18 @@ from vllm.config.load import LoadConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import (QuantizationConfig, - get_quantization_config) +from vllm.model_executor.layers.quantization import ( + QuantizationConfig, + get_quantization_config, +) from vllm.platforms import current_platform from vllm.utils import PlaceholderModule try: from runai_model_streamer import SafetensorsStreamer except ImportError: - runai_model_streamer = PlaceholderModule( - "runai_model_streamer") # type: ignore[assignment] - SafetensorsStreamer = runai_model_streamer.placeholder_attr( - "SafetensorsStreamer") + runai_model_streamer = PlaceholderModule("runai_model_streamer") # type: ignore[assignment] + SafetensorsStreamer = runai_model_streamer.placeholder_attr("SafetensorsStreamer") try: import gguf @@ -50,8 +51,7 @@ from fastsafetensors import SafeTensorsFileLoader, 
SingleGroup except ImportError: fastsafetensors = PlaceholderModule("fastsafetensors") - SafeTensorsFileLoader = fastsafetensors.placeholder_attr( - "SafeTensorsFileLoader") + SafeTensorsFileLoader = fastsafetensors.placeholder_attr("SafeTensorsFileLoader") SingleGroup = fastsafetensors.placeholder_attr("SingleGroup") logger = init_logger(__name__) @@ -64,12 +64,12 @@ def enable_hf_transfer(): - """automatically activates hf_transfer - """ + """automatically activates hf_transfer""" if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True except ImportError: pass @@ -79,13 +79,11 @@ def enable_hf_transfer(): class DisabledTqdm(tqdm): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, disable=True) -def get_lock(model_name_or_path: Union[str, Path], - cache_dir: Optional[str] = None): +def get_lock(model_name_or_path: Union[str, Path], cache_dir: Optional[str] = None): lock_dir = cache_dir or temp_dir model_name_or_path = str(model_name_or_path) os.makedirs(os.path.dirname(lock_dir), exist_ok=True) @@ -94,15 +92,14 @@ def get_lock(model_name_or_path: Union[str, Path], # add hash to avoid conflict with old users' lock files lock_file_name = hash_name + model_name + ".lock" # mode 0o666 is required for the filelock to be shared across users - lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), - mode=0o666) + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) return lock @contextmanager -def atomic_writer(filepath: Union[str, Path], - mode: str = 'w', - encoding: Optional[str] = None) -> Generator[IO]: +def atomic_writer( + filepath: Union[str, Path], mode: str = "w", encoding: Optional[str] = None +) -> Generator[IO]: """ Context manager that provides an atomic file writing routine. @@ -133,8 +130,8 @@ def atomic_writer(filepath: Union[str, Path], except Exception: logger.exception( - "Error during atomic write. Original file '%s' not modified", - filepath) + "Error during atomic write. Original file '%s' not modified", filepath + ) raise finally: # Clean up the temporary file if it still exists. @@ -143,16 +140,16 @@ def atomic_writer(filepath: Union[str, Path], def maybe_download_from_modelscope( - model: str, - revision: Optional[str] = None, - download_dir: Optional[str] = None, - ignore_patterns: Optional[Union[str, list[str]]] = None, - allow_patterns: Optional[Union[list[str], - str]] = None) -> Optional[str]: + model: str, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + ignore_patterns: Optional[Union[str, list[str]]] = None, + allow_patterns: Optional[Union[list[str], str]] = None, +) -> Optional[str]: """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. - Returns the path to the downloaded model, or None if the model is not - downloaded from ModelScope.""" + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" if envs.VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. @@ -226,9 +223,9 @@ def convert_bin_to_safetensor_file( # TODO(woosuk): Move this to other place. 
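For readers skimming these weight_utils.py hunks: the atomic_writer helper reformatted just above follows the usual write-to-temp-then-replace pattern. A minimal self-contained sketch of that pattern, assuming a simplified helper name and error handling that are not vLLM's exact implementation:

import os
import tempfile
from contextlib import contextmanager

@contextmanager
def atomic_writer_sketch(filepath: str, mode: str = "w", encoding=None):
    # Write into a temporary file in the same directory as the target, then
    # atomically swap it in, so readers never observe a half-written file.
    dirname = os.path.dirname(os.path.abspath(filepath))
    fd, tmp_path = tempfile.mkstemp(dir=dirname)
    os.close(fd)
    try:
        with open(tmp_path, mode, encoding=encoding) as f:
            yield f
        os.replace(tmp_path, filepath)  # atomic rename on the same filesystem
    finally:
        # If writing failed before the replace, clean up the temporary file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)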
-def get_quant_config(model_config: ModelConfig, - load_config: LoadConfig) -> QuantizationConfig: - +def get_quant_config( + model_config: ModelConfig, load_config: LoadConfig +) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) # GGUF doesn't have config file @@ -236,16 +233,14 @@ def get_quant_config(model_config: ModelConfig, return quant_cls() # Read the quantization config from the HF model config, if available. - hf_quant_config = getattr(model_config.hf_config, "quantization_config", - None) + hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) # some vision model may keep quantization_config in their text_config hf_text_config = getattr(model_config.hf_config, "text_config", None) if hf_quant_config is None and hf_text_config is not None: hf_quant_config = getattr(hf_text_config, "quantization_config", None) if hf_quant_config is None: # compressed-tensors uses a compressions_config - hf_quant_config = getattr(model_config.hf_config, "compression_config", - None) + hf_quant_config = getattr(model_config.hf_config, "compression_config", None) if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) @@ -253,8 +248,7 @@ def get_quant_config(model_config: ModelConfig, # if hf_quant_config is None, we will try to get config from # hf_overrides hf_overrides = model_config.hf_overrides - quantization_config_file = hf_overrides.get("quantization_config_file", - None) + quantization_config_file = hf_overrides.get("quantization_config_file", None) if quantization_config_file is not None: if hasattr(quant_cls, "from_config_file"): return quant_cls.from_config_file(quantization_config_file) @@ -262,9 +256,9 @@ def get_quant_config(model_config: ModelConfig, raise NotImplementedError( "from_config_file is specified in hf_override config, " "but quant_cls.from_config_file is not implemented in " - f"{quant_cls}") - quantization_config_json = hf_overrides.get( - "quantization_config_dict_json", None) + f"{quant_cls}" + ) + quantization_config_json = hf_overrides.get("quantization_config_dict_json", None) if quantization_config_json is not None: if hasattr(quant_cls, "from_config_dict_json"): return quant_cls.from_config_dict_json(quantization_config_json) @@ -272,17 +266,21 @@ def get_quant_config(model_config: ModelConfig, raise NotImplementedError( "from_config_dict_json is specified in hf_override config, " "but quant_cls.from_config_dict_json is not implemented in " - f"{quant_cls}") + f"{quant_cls}" + ) # Inflight BNB quantization if model_config.quantization == "bitsandbytes": return quant_cls.from_config({}) - model_name_or_path = maybe_download_from_modelscope( - model_config.model, - revision=model_config.revision, - download_dir=load_config.download_dir, - allow_patterns=["*.json"], - ) or model_config.model + model_name_or_path = ( + maybe_download_from_modelscope( + model_config.model, + revision=model_config.revision, + download_dir=load_config.download_dir, + allow_patterns=["*.json"], + ) + or model_config.model + ) is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. 
@@ -307,16 +305,15 @@ def get_quant_config(model_config: ModelConfig, config_files = glob.glob(os.path.join(hf_folder, "*.json")) quant_config_files = [ - f for f in config_files if any( - f.endswith(x) for x in possible_config_filenames) + f for f in config_files if any(f.endswith(x) for x in possible_config_filenames) ] if len(quant_config_files) == 0: - raise ValueError( - f"Cannot find the config file for {model_config.quantization}") + raise ValueError(f"Cannot find the config file for {model_config.quantization}") if len(quant_config_files) > 1: raise ValueError( f"Found multiple config files for {model_config.quantization}: " - f"{quant_config_files}") + f"{quant_config_files}" + ) quant_config_file = quant_config_files[0] with open(quant_config_file) as f: @@ -330,7 +327,8 @@ def get_quant_config(model_config: ModelConfig, else: raise ValueError( f"Unsupported quantization config" - f" found for {model_config.quantization} in {f}.") + f" found for {model_config.quantization} in {f}." + ) return quant_cls.from_config(config) @@ -399,9 +397,7 @@ def download_weights_from_hf( # so we only have to call snapshot_download once. try: fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, - detail=False, - revision=revision) + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) # Use the first pattern found in the HF repo's files. for pattern in allow_patterns: @@ -413,7 +409,10 @@ def download_weights_from_hf( logger.warning( "Failed to get file list for '%s'. Trying each pattern in " "allow_patterns individually until weights have been " - "downloaded. Error: %s", model_name_or_path, e) + "downloaded. Error: %s", + model_name_or_path, + e, + ) logger.info("Using model weights format %s", allow_patterns) # Use file lock to prevent multiple processes from @@ -436,8 +435,11 @@ def download_weights_from_hf( break time_taken = time.perf_counter() - start_time if time_taken > 0.5: - logger.info("Time spent downloading weights for %s: %.6f seconds", - model_name_or_path, time_taken) + logger.info( + "Time spent downloading weights for %s: %.6f seconds", + model_name_or_path, + time_taken, + ) return hf_folder @@ -481,9 +483,9 @@ def download_safetensors_index_file_from_hf( # Passing both of these to the weight loader functionality breaks. # So, we use the index_file to # look up which safetensors files should be used. -def filter_duplicate_safetensors_files(hf_weights_files: list[str], - hf_folder: str, - index_file: str) -> list[str]: +def filter_duplicate_safetensors_files( + hf_weights_files: list[str], hf_folder: str, index_file: str +) -> list[str]: # model.safetensors.index.json is a mapping from keys in the # torch state_dict to safetensors file holding that weight. index_file_name = os.path.join(hf_folder, index_file) @@ -496,17 +498,13 @@ def filter_duplicate_safetensors_files(hf_weights_files: list[str], weight_map = json.load(f)["weight_map"] weight_files_in_index = set() for weight_name in weight_map: - weight_files_in_index.add( - os.path.join(hf_folder, weight_map[weight_name])) + weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) # Filter out any fields that are not found in the index file. 
- hf_weights_files = [ - f for f in hf_weights_files if f in weight_files_in_index - ] + hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] return hf_weights_files -def filter_files_not_needed_for_inference( - hf_weights_files: list[str]) -> list[str]: +def filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]: """ Exclude files that are not needed for inference. @@ -520,8 +518,7 @@ def filter_files_not_needed_for_inference( "scaler.pt", ] hf_weights_files = [ - f for f in hf_weights_files - if not any(f.endswith(x) for x in blacklist) + f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist) ] return hf_weights_files @@ -534,8 +531,9 @@ def filter_files_not_needed_for_inference( def enable_tqdm(use_tqdm_on_load: bool): - return use_tqdm_on_load and (not torch.distributed.is_initialized() - or torch.distributed.get_rank() == 0) + return use_tqdm_on_load and ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) def np_cache_weights_iterator( @@ -560,14 +558,12 @@ def np_cache_weights_iterator( if not os.path.exists(weight_names_file): weight_names: list[str] = [] for bin_file in tqdm( - hf_weights_files, - desc="Loading np_cache checkpoint shards", - disable=not enable_tqdm(use_tqdm_on_load), - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, ): - state = torch.load(bin_file, - map_location="cpu", - weights_only=True) + state = torch.load(bin_file, map_location="cpu", weights_only=True) for name, param in state.items(): param_path = os.path.join(np_folder, name) with open(param_path, "wb") as f: @@ -597,10 +593,10 @@ def safetensors_weights_iterator( loading_desc += " (eager)" for st_file in tqdm( - hf_weights_files, - desc=loading_desc, - disable=not enable_tqdm(use_tqdm_on_load), - bar_format=_BAR_FORMAT, + hf_weights_files, + desc=loading_desc, + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, ): if safetensors_load_strategy == "eager": with open(st_file, "rb") as f: @@ -624,12 +620,8 @@ def _load_file(st_file: str): result = load_file(st_file, device="cpu") return result - with concurrent.futures.ThreadPoolExecutor( - max_workers=max_workers) as executor: - futures = [ - executor.submit(_load_file, st_file) - for st_file in hf_weights_files - ] + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files] futures_iter = tqdm( concurrent.futures.as_completed(futures), total=len(hf_weights_files), @@ -652,7 +644,8 @@ def runai_safetensors_weights_iterator( streamer.stream_files(hf_weights_files) total_tensors = sum( len(tensors_meta) - for tensors_meta in streamer.files_to_tensors_metadata.values()) + for tensors_meta in streamer.files_to_tensors_metadata.values() + ) tensor_iter = tqdm( streamer.get_tensors(), @@ -689,19 +682,19 @@ def fastsafetensors_weights_iterator( else: pg = SingleGroup() - device = torch.device(f'cuda:{pg.rank()}') + device = torch.device(f"cuda:{pg.rank()}") weight_files_sub_lists = [ - hf_weights_files[i:i + pg.size()] + hf_weights_files[i : i + pg.size()] for i in range(0, len(hf_weights_files), pg.size()) ] nogds = False for f_list in tqdm( - weight_files_sub_lists, - desc="Loading safetensors using Fastsafetensor loader", - disable=not enable_tqdm(use_tqdm_on_load), - bar_format=_BAR_FORMAT, + weight_files_sub_lists, + 
desc="Loading safetensors using Fastsafetensor loader", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, ): loader = _init_loader(pg, device, f_list, nogds=nogds) try: @@ -738,14 +731,14 @@ def pt_weights_iterator( ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" for bin_file in tqdm( - hf_weights_files, - desc="Loading pt checkpoint shards", - disable=not enable_tqdm(use_tqdm_on_load), - bar_format=_BAR_FORMAT, + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm(use_tqdm_on_load), + bar_format=_BAR_FORMAT, ): - state = torch.load(bin_file, - map_location=pt_load_map_location, - weights_only=True) + state = torch.load( + bin_file, map_location=pt_load_map_location, weights_only=True + ) yield from state.items() del state @@ -759,15 +752,13 @@ def multi_thread_pt_weights_iterator( """Multi-Thread iterate over the weights in the model bin/pt files.""" def _load_file(bin_file: str): - return torch.load(bin_file, - map_location=pt_load_map_location, - weights_only=True) + return torch.load( + bin_file, map_location=pt_load_map_location, weights_only=True + ) - with concurrent.futures.ThreadPoolExecutor( - max_workers=max_workers) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [ - executor.submit(_load_file, bin_file) - for bin_file in hf_weights_files + executor.submit(_load_file, bin_file) for bin_file in hf_weights_files ] futures_iter = tqdm( concurrent.futures.as_completed(futures), @@ -784,7 +775,8 @@ def _load_file(bin_file: str): def get_gguf_extra_tensor_names( - gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]: + gguf_file: str, gguf_to_hf_name_map: dict[str, str] +) -> list[str]: reader = gguf.GGUFReader(gguf_file) expected_gguf_keys = set(gguf_to_hf_name_map.keys()) exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) @@ -793,14 +785,16 @@ def get_gguf_extra_tensor_names( def get_gguf_weight_type_map( - gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]: + gguf_file: str, gguf_to_hf_name_map: dict[str, str] +) -> dict[str, str]: """ Return GGUF mapped weight's name and its quant type """ reader = gguf.GGUFReader(gguf_file) return { gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name - for tensor in reader.tensors if tensor.name in gguf_to_hf_name_map + for tensor in reader.tensors + if tensor.name in gguf_to_hf_name_map } @@ -850,8 +844,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: return x -def default_weight_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" try: if param.numel() == 1 and loaded_weight.numel() == 1: @@ -862,7 +855,8 @@ def default_weight_loader(param: torch.Tensor, else: assert param.size() == loaded_weight.size(), ( f"Attempted to load weight ({loaded_weight.size()}) " - f"into parameter ({param.size()})") + f"into parameter ({param.size()})" + ) param.data.copy_(loaded_weight) except Exception: @@ -871,8 +865,9 @@ def default_weight_loader(param: torch.Tensor, raise -def row_parallel_weight_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: +def row_parallel_weight_loader( + param: torch.Tensor, loaded_weight: torch.Tensor +) -> None: """Load weights that are row-parallelized.""" tp_rank = get_tensor_model_parallel_rank() shard_dim = 0 if param.dim() != 1 else None @@ 
-904,12 +899,11 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: def composed_weight_loader( - loader: LoaderFunction, fn: Callable[[torch.Tensor], - torch.Tensor]) -> LoaderFunction: + loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor] +) -> LoaderFunction: """Create a weight loader that post-processes the weights after loading""" - def composed_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: + def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: loader(param, loaded_weight) param.data.copy_(fn(param)) return @@ -945,13 +939,18 @@ def initialize_dummy_weights( # from a CPU tensor. # Note: We avoid using torch.rank_like as it doesn't currently # support the generator argument. - param.copy_((high - low) * - torch.rand(param.shape, - generator=generator, - dtype=param.dtype, - layout=param.layout, - requires_grad=param.requires_grad, - device="cpu") + low) + param.copy_( + (high - low) + * torch.rand( + param.shape, + generator=generator, + dtype=param.dtype, + layout=param.layout, + requires_grad=param.requires_grad, + device="cpu", + ) + + low + ) torch._sync(param) continue @@ -961,8 +960,7 @@ def initialize_dummy_weights( # uniform_ doesn't support < 16-bit datatypes (FP8) dtype = param.data.dtype tmp_param = param.data.to(torch.float16) - tmp_param = tmp_param.uniform_(low, high, - generator=generator).to(dtype) + tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype) param.data.copy_(tmp_param) else: param.uniform_(low, high, generator=generator) @@ -991,7 +989,8 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: "This format is deprecated in favor of separate k_scale and " "v_scale tensors and will be removed in a future release. 
" "Functionally, we will remap kv_scale to k_scale and duplicate " - "k_scale to v_scale") + "k_scale to v_scale" + ) # NOTE: we remap the deprecated kv_scale to k_scale remapped_name = name.replace(".kv_scale", ".attn.k_scale") if remapped_name not in params_dict: @@ -1005,23 +1004,26 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: if any("mla_attn" in key for key in params_dict): attn_str = "mla_attn.mla_attn" - logger.debug_once(f"Found mla_attn with k_scale and v_scale in " - f"the checkpoint, using {attn_str} as attn_str") + logger.debug_once( + f"Found mla_attn with k_scale and v_scale in " + f"the checkpoint, using {attn_str} as attn_str" + ) else: attn_str = "attn" # Define scale name mapping patterns in order of precedence scale_mapping_patterns = [ # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale -> # .self_attn.attn.{k,v}_scale - (r"\.self_attn\.([kv])_proj\.([kv])_scale$", - rf".self_attn.{attn_str}.\2_scale"), + ( + r"\.self_attn\.([kv])_proj\.([kv])_scale$", + rf".self_attn.{attn_str}.\2_scale", + ), # QKV proj format: .self_attn.qkv_proj.{k,v}_scale -> # .self_attn.attn.{k,v}_scale (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"), # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale -> # .self_attn.attn.{k,v}_scale - (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale" - ), + (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"), # Default format: .{k,v}_scale -> .attn.{k,v}_scale (r"\.([kv])_scale$", r".attn.\1_scale"), ] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 4ccba64f2c11..b56cb3340048 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,13 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .interfaces import (HasInnerState, SupportsLoRA, SupportsMRoPE, - SupportsMultiModal, SupportsPP, SupportsTranscription, - SupportsV0Only, has_inner_state, supports_lora, - supports_mrope, supports_multimodal, supports_pp, - supports_transcription, supports_v0_only) -from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration, - is_pooling_model, is_text_generation_model) +from .interfaces import ( + HasInnerState, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, + SupportsTranscription, + SupportsV0Only, + has_inner_state, + supports_lora, + supports_mrope, + supports_multimodal, + supports_pp, + supports_transcription, + supports_v0_only, +) +from .interfaces_base import ( + VllmModelForPooling, + VllmModelForTextGeneration, + is_pooling_model, + is_text_generation_model, +) from .registry import ModelRegistry __all__ = [ diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index c4328a176a5d..fd8a0b87e43e 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -13,8 +13,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.config import VerifyAndUpdateConfig -from vllm.transformers_utils.config import (get_hf_file_bytes, - get_hf_file_to_dict) +from vllm.transformers_utils.config import get_hf_file_bytes, get_hf_file_to_dict from .interfaces_base import VllmModelForPooling, is_pooling_model @@ -37,8 +36,9 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: """Load Sentence-Transformers 
Dense projection layers.""" try: - modules = get_hf_file_to_dict("modules.json", model_config.model, - model_config.revision) + modules = get_hf_file_to_dict( + "modules.json", model_config.model, model_config.revision + ) if not modules: return None @@ -46,8 +46,7 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: modules = modules.get("modules", []) dense_modules = [ - m for m in modules - if m.get("type") == "sentence_transformers.models.Dense" + m for m in modules if m.get("type") == "sentence_transformers.models.Dense" ] if not dense_modules: return None @@ -57,15 +56,18 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: folder = module.get("path", "") config_path = f"{folder}/config.json" if folder else "config.json" - layer_config = get_hf_file_to_dict(config_path, model_config.model, - model_config.revision) + layer_config = get_hf_file_to_dict( + config_path, model_config.model, model_config.revision + ) if not layer_config: continue - linear = nn.Linear(layer_config.get("in_features", 768), - layer_config.get("out_features", 768), - bias=layer_config.get("bias", True), - dtype=model_config.head_dtype) + linear = nn.Linear( + layer_config.get("in_features", 768), + layer_config.get("out_features", 768), + bias=layer_config.get("bias", True), + dtype=model_config.head_dtype, + ) if not _load_dense_weights(linear, folder, model_config): continue @@ -80,40 +82,45 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: return None -def _load_dense_weights(linear: nn.Linear, folder: str, - model_config: "ModelConfig") -> bool: +def _load_dense_weights( + linear: nn.Linear, folder: str, model_config: "ModelConfig" +) -> bool: """Load weights using vLLM's weight_loader pattern.""" - from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader) + from vllm.model_executor.model_loader.weight_utils import default_weight_loader for filename in ["model.safetensors", "pytorch_model.bin"]: file_path = f"{folder}/{filename}" if folder else filename try: - file_bytes = get_hf_file_bytes(file_path, model_config.model, - model_config.revision) + file_bytes = get_hf_file_bytes( + file_path, model_config.model, model_config.revision + ) if not file_bytes: continue if filename.endswith(".safetensors"): from safetensors.torch import load as load_safetensors + state_dict = load_safetensors(file_bytes) else: import io - state_dict = torch.load(io.BytesIO(file_bytes), - map_location="cpu", - weights_only=True) + + state_dict = torch.load( + io.BytesIO(file_bytes), map_location="cpu", weights_only=True + ) for weight_key in ["weight", "linear.weight", "dense.weight"]: if weight_key in state_dict: - weight_loader = getattr(linear.weight, "weight_loader", - default_weight_loader) + weight_loader = getattr( + linear.weight, "weight_loader", default_weight_loader + ) weight_loader(linear.weight, state_dict[weight_key]) bias_key = weight_key.replace("weight", "bias") if linear.bias is not None and bias_key in state_dict: - bias_loader = getattr(linear.bias, "weight_loader", - default_weight_loader) + bias_loader = getattr( + linear.bias, "weight_loader", default_weight_loader + ) bias_loader(linear.bias, state_dict[bias_key]) return True except Exception: @@ -133,9 +140,7 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T: - class CallVisitor(ast.NodeVisitor): - def __init__(self): self.calls = [] @@ -150,7 +155,6 @@ def 
visit_Call(self, node): return None class ModelForPooling(orig_cls, VllmModelForPooling): - is_pooling_model = True def __init__( @@ -172,7 +176,6 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T: from .utils import AutoWeightsLoader, WeightsMapper class ModelForPooling(orig_cls, VllmModelForPooling): - is_pooling_model = True def __init__( @@ -202,8 +205,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # TODO: Support uninitialized params tracking # We have deleted this attribute, so don't load it - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) + weights = ( + (name, data) + for name, data in weights + if not name.startswith("lm_head.") + ) # If `*ForCausalLM` defines `load_weights` on the inner model # and there are no other inner modules with parameters, @@ -212,7 +218,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # Whether only `self.model` contains parameters model_is_only_param = all( name == "model" or next(child.parameters(), None) is None - for name, child in self.named_children()) + for name, child in self.named_children() + ) if model_is_only_param: mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) @@ -252,7 +259,6 @@ def as_embedding_model(cls: _T) -> _T: from vllm.model_executor.layers.pooler import DispatchPooler, Pooler class ModelForEmbedding(_create_pooling_model_cls(cls)): - def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None @@ -261,10 +267,10 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): { "encode": Pooler.for_encode(pooler_config), "embed": Pooler.for_embed(pooler_config), - }, ) + }, + ) - ModelForEmbedding.__name__ = \ - _get_pooling_model_name(cls.__name__, "ForEmbedding") + ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding") return ModelForEmbedding # type: ignore @@ -287,17 +293,21 @@ def as_seq_cls_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.linear import ReplicatedLinear - from vllm.model_executor.layers.pooler import (ClassifierPooler, - DispatchPooler, Pooler, - PoolingMethod, PoolingType) + from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + DispatchPooler, + Pooler, + PoolingMethod, + PoolingType, + ) from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.sequence import IntermediateTensors from .utils import get_model_hidden_size, maybe_prefix - class ModelForSequenceClassification(_create_pooling_model_cls(cls), - SupportsCrossEncoding): - + class ModelForSequenceClassification( + _create_pooling_model_cls(cls), SupportsCrossEncoding + ): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -319,24 +329,25 @@ def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): assert pooling_type_str is not None pooling_type = PoolingType[pooling_type_str] - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=PoolingMethod.from_pooling_type(pooling_type), - classifier=self._classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=PoolingMethod.from_pooling_type(pooling_type), - classifier=self._classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - 
vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=PoolingMethod.from_pooling_type(pooling_type), + classifier=self._classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=PoolingMethod.from_pooling_type(pooling_type), + classifier=self._classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) def _classifier(self, x: torch.Tensor): x, _ = self.score(x.float()) @@ -349,8 +360,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return super().forward(input_ids, positions, intermediate_tensors, - inputs_embeds) + return super().forward( + input_ids, positions, intermediate_tensors, inputs_embeds + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): tokens = getattr(self.config, "classifier_from_token", None) @@ -363,9 +375,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # ForSequenceClassification model. return seq_cls_model_loader(self, weights) - - ModelForSequenceClassification.__name__ = \ - _get_pooling_model_name(cls.__name__, "ForSequenceClassification") + ModelForSequenceClassification.__name__ = _get_pooling_model_name( + cls.__name__, "ForSequenceClassification" + ) return ModelForSequenceClassification # type: ignore @@ -388,22 +400,20 @@ def as_reward_model(cls: _T) -> _T: from vllm.model_executor.layers.pooler import DispatchPooler, Pooler class ModelForReward(_create_pooling_model_cls(cls)): - def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, ) + {"encode": Pooler.for_encode(pooler_config)}, + ) - ModelForReward.__name__ = \ - _get_pooling_model_name(cls.__name__, "ForReward") + ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward") return ModelForReward # type: ignore class SequenceClassificationConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config @@ -428,12 +438,11 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: def load_weights_using_from_2_way_softmax( - model, weights: Iterable[tuple[str, torch.Tensor]]): + model, weights: Iterable[tuple[str, torch.Tensor]] +): # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 - from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead) - from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader) + from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead + from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config @@ -446,24 +455,27 @@ def load_weights_using_from_2_way_softmax( model.lm_head = model.model.embed_tokens else: quant_config = model.vllm_config.quant_config - model.lm_head = ParallelLMHead(model.config.vocab_size, - model.config.hidden_size, - quant_config=quant_config) + model.lm_head = ParallelLMHead( + model.config.vocab_size, model.config.hidden_size, quant_config=quant_config + ) loader = 
AutoWeightsLoader(model) loaded_weights = loader.load_weights(weights) from vllm.transformers_utils.tokenizer import get_tokenizer - tokenizer = get_tokenizer(model_config.tokenizer, - revision=model_config.tokenizer_revision, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code) + + tokenizer = get_tokenizer( + model_config.tokenizer, + revision=model_config.tokenizer_revision, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + ) false_id = tokenizer.convert_tokens_to_ids(tokens[0]) true_id = tokenizer.convert_tokens_to_ids(tokens[1]) score_weight = model.lm_head.weight.data[[true_id]].to( - torch.float32) - model.lm_head.weight.data[[false_id]].to( - torch.float32) + torch.float32 + ) - model.lm_head.weight.data[[false_id]].to(torch.float32) param = model.score.weight weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -475,13 +487,9 @@ def load_weights_using_from_2_way_softmax( return loaded_weights -def load_weights_no_post_processing(model, - weights: Iterable[tuple[str, - torch.Tensor]]): - from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead) - from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader) +def load_weights_no_post_processing(model, weights: Iterable[tuple[str, torch.Tensor]]): + from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead + from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config @@ -493,18 +501,21 @@ def load_weights_no_post_processing(model, model.lm_head = model.model.embed_tokens else: quant_config = model.vllm_config.quant_config - model.lm_head = ParallelLMHead(model.config.vocab_size, - model.config.hidden_size, - quant_config=quant_config) + model.lm_head = ParallelLMHead( + model.config.vocab_size, model.config.hidden_size, quant_config=quant_config + ) loader = AutoWeightsLoader(model) loaded_weights = loader.load_weights(weights) from vllm.transformers_utils.tokenizer import get_tokenizer - tokenizer = get_tokenizer(model_config.tokenizer, - revision=model_config.tokenizer_revision, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code) + + tokenizer = get_tokenizer( + model_config.tokenizer, + revision=model_config.tokenizer_revision, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + ) token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] score_weight = model.lm_head.weight.data[token_ids] diff --git a/vllm/model_executor/models/aimv2.py b/vllm/model_executor/models/aimv2.py index 419f8a5ae2c7..2423ad5b0c3a 100644 --- a/vllm/model_executor/models/aimv2.py +++ b/vllm/model_executor/models/aimv2.py @@ -14,18 +14,20 @@ from vllm.distributed.utils import divide from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.transformers_utils.configs.ovis import AIMv2Config class 
AIMv2SwiGLUFFN(nn.Module): - - def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, - prefix: str): + def __init__( + self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str + ): super().__init__() hidden_features = config.intermediate_size in_features = config.hidden_size @@ -55,7 +57,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AIMv2PatchEmbed(nn.Module): - def __init__(self, config: AIMv2Config): super().__init__() self.proj = nn.Conv2d( @@ -73,14 +74,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AIMv2ViTPreprocessor(nn.Module): - def __init__(self, config: AIMv2Config): super().__init__() - num_patches = (config.image_size // config.patch_size)**2 + num_patches = (config.image_size // config.patch_size) ** 2 self.patchifier = AIMv2PatchEmbed(config) - self.pos_embed = nn.Parameter( - torch.zeros((1, num_patches, config.hidden_size))) + self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size))) def forward(self, x: torch.Tensor) -> torch.Tensor: tokens = self.patchifier(x) @@ -91,9 +90,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AIMv2Attention(nn.Module): - - def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, - prefix: str): + def __init__( + self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -103,7 +102,8 @@ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, raise ValueError( "embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.qkv = QKVParallelLinear( @@ -126,8 +126,9 @@ def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def forward(self, x: torch.Tensor) -> torch.Tensor: qkv, _ = self.qkv(x) @@ -139,17 +140,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AIMv2Block(nn.Module): - - def __init__(self, config: AIMv2Config, quant_config: QuantizationConfig, - prefix: str): + def __init__( + self, config: AIMv2Config, quant_config: QuantizationConfig, prefix: str + ): super().__init__() - self.attn = AIMv2Attention(config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = AIMv2Attention( + config, quant_config=quant_config, prefix=f"{prefix}.attn" + ) self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = AIMv2SwiGLUFFN(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = AIMv2SwiGLUFFN( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -159,7 +160,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class AIMv2Transformer(nn.Module): - def __init__( self, config: AIMv2Config, @@ -170,13 +170,14 @@ def __init__( ): super().__init__() - self.blocks = nn.ModuleList([ - AIMv2Block(config, quant_config, prefix=f"{prefix}.blocks.{i}") - for i in range(config.num_hidden_layers) - ]) + self.blocks = nn.ModuleList( + [ + AIMv2Block(config, 
quant_config, prefix=f"{prefix}.blocks.{i}") + for i in range(config.num_hidden_layers) + ] + ) if require_post_norm: - self.post_trunk_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.post_trunk_norm = None @@ -190,29 +191,30 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor: class AIMv2Model(torch.nn.Module): - - def __init__(self, - config: AIMv2Config, - quant_config: QuantizationConfig, - *, - require_post_norm: Optional[bool] = None, - prefix: str = ""): + def __init__( + self, + config: AIMv2Config, + quant_config: QuantizationConfig, + *, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ): super().__init__() self.preprocessor = AIMv2ViTPreprocessor(config) - self.trunk = AIMv2Transformer(config, - quant_config=quant_config, - require_post_norm=require_post_norm, - prefix=f"{prefix}.trunk") + self.trunk = AIMv2Transformer( + config, + quant_config=quant_config, + require_post_norm=require_post_norm, + prefix=f"{prefix}.trunk", + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - x = self.preprocessor(pixel_values) x = self.trunk(x) return x - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".fc13", ".fc1", 0), @@ -223,11 +225,13 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: # post_layernorm is optional in SiglipVisionModel - if (name.startswith("trunk.post_trunk_norm") - and self.trunk.post_trunk_norm is None): + if ( + name.startswith("trunk.post_trunk_norm") + and self.trunk.post_trunk_norm is None + ): continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -238,8 +242,7 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index 6dab4ed14345..743207082721 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -24,6 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Apertus model compatible with HuggingFace weights.""" + from collections.abc import Iterable from typing import Any, Optional, Union @@ -38,27 +39,38 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import XIELU from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class ApertusMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -86,8 +98,10 @@ def __init__( prefix=f"{prefix}.down_proj", ) if hidden_act != "xielu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only xIELU is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only xIELU is supported for now." 
+ ) self.act_fn = XIELU() def forward(self, x): @@ -98,7 +112,6 @@ def forward(self, x): class ApertusAttention(nn.Module): - def __init__( self, config: ApertusConfig, @@ -138,8 +151,7 @@ def __init__( head_dim = self.hidden_size // self.total_num_heads self.head_dim = head_dim # Phi models introduced a partial_rotary_factor parameter in the config - self.partial_rotary_factor = getattr(config, "partial_rotary_factor", - 1) + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -164,9 +176,9 @@ def __init__( prefix=f"{prefix}.o_proj", ) - self._init_rotary_emb(config, - rope_scaling=rope_scaling, - quant_config=quant_config) + self._init_rotary_emb( + config, rope_scaling=rope_scaling, quant_config=quant_config + ) sliding_window = None if layer_types := getattr(config, "layer_types", None): @@ -174,8 +186,11 @@ def __init__( if is_sliding: sliding_window = config.sliding_window - attn_cls = (EncoderOnlyAttention - if attn_type == AttentionType.ENCODER_ONLY else Attention) + attn_cls = ( + EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY + else Attention + ) self.attn = attn_cls( self.num_heads, @@ -206,9 +221,12 @@ def forward( output, _ = self.o_proj(attn_output) return output - def _init_rotary_emb(self, config: ApertusConfig, - rope_scaling: Optional[dict[str, Any]], - quant_config: Optional[QuantizationConfig]) -> None: + def _init_rotary_emb( + self, + config: ApertusConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig], + ) -> None: is_neox_style = True is_gguf = quant_config and quant_config.get_name() == "gguf" if is_gguf and config.model_type == "apertus": @@ -226,7 +244,6 @@ def _init_rotary_emb(self, config: ApertusConfig, class ApertusDecoderLayer(nn.Module): - def __init__( self, config: ApertusConfig, @@ -239,18 +256,20 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) bias_o_proj = attention_bias # support internlm/internlm3-8b with qkv_bias - if hasattr(config, 'qkv_bias'): + if hasattr(config, "qkv_bias"): attention_bias = config.qkv_bias # Apertus defaults to causal attention as it is a decoder-only model. 
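For reference, the attention hunks above keep rewrapping the same grouped-query bookkeeping: q_size = num_heads * head_dim, kv_size = num_kv_heads * head_dim, scaling = head_dim ** -0.5, and a fused QKV output split back into q, k and v. A tiny illustrative sketch follows; the head counts and shapes are arbitrary assumptions, not values from any config.

import torch

num_heads, num_kv_heads, head_dim = 8, 2, 64
q_size = num_heads * head_dim       # 512
kv_size = num_kv_heads * head_dim   # 128
scaling = head_dim ** -0.5          # 0.125

# Pretend output of a fused QKV linear for a batch of 4 tokens.
qkv = torch.randn(4, q_size + 2 * kv_size)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape, scaling)
# torch.Size([4, 512]) torch.Size([4, 128]) torch.Size([4, 128]) 0.125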
@@ -266,8 +285,9 @@ def __init__( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -286,10 +306,10 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.feedforward_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -302,26 +322,24 @@ def forward( residual = hidden_states hidden_states = self.attention_layernorm(hidden_states) else: - hidden_states, residual = self.attention_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) + hidden_states, residual = self.attention_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) # Fully Connected - hidden_states, residual = self.feedforward_layernorm( - hidden_states, residual) + hidden_states, residual = self.feedforward_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class ApertusModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = ApertusDecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config @@ -331,12 +349,16 @@ def __init__(self, self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -347,10 +369,12 @@ def __init__(self, self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: @@ -360,9 +384,9 @@ def __init__(self, self.aux_hidden_state_layers = tuple[int, ...]() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return 
self.embed_tokens(input_ids) @@ -373,8 +397,9 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, - list[torch.Tensor]]]: + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -387,17 +412,15 @@ def forward( residual = intermediate_tensors["residual"] aux_hidden_states = [] - for idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) @@ -405,8 +428,7 @@ def forward( return hidden_states, aux_hidden_states return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -424,19 +446,19 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -469,8 +491,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -482,15 +503,17 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP): # LoRA specific attributes embedding_modules = { "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings" + "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = ApertusDecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -498,9 +521,11 @@ def __init__(self, self.config = config self.lora_config = lora_config - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model"), - layer_type=layer_type) + self.model = self._init_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type, + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size @@ -514,24 +539,25 @@ def __init__(self, DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers @@ -540,13 +566,15 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) - def _init_model(self, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = ApertusDecoderLayer): - return ApertusModel(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) + def _init_model( + self, + 
vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = ApertusDecoderLayer, + ): + return ApertusModel( + vllm_config=vllm_config, prefix=prefix, layer_type=layer_type + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -558,8 +586,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( @@ -569,11 +598,9 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index 1ee378af76c9..634e94b16814 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -20,32 +20,43 @@ from vllm.distributed import get_pp_group from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, +) class ArceeMLP(nn.Module): """Feed-forward layer for Arcee using ReLU^2 activation (no gating as in LLaMA).""" - def __init__(self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[Any] = None, - bias: bool = False, - prefix: str = "", - reduce_results: bool = True) -> None: + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[Any] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + ) -> None: super().__init__() # Single linear projection up to intermediate size # (no separate gate projection) @@ -66,8 +77,10 @@ def __init__(self, prefix=f"{prefix}.down_proj", ) if hidden_act != "relu2": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only 'relu2' is supported for AFM.") + raise ValueError( + f"Unsupported activation: {hidden_act}. 
" + "Only 'relu2' is supported for AFM." + ) # Define ReLU^2 activation: (ReLU(x))^2 elementwise self.act_fn = ReLUSquaredActivation() @@ -82,38 +95,45 @@ class ArceeDecoderLayer(nn.Module): """Transformer decoder block for Arcee, with self-attention and ReLU^2 MLP.""" - def __init__(self, - config: LlamaConfig, - cache_config: Optional[Any] = None, - quant_config: Optional[Any] = None, - prefix: str = "") -> None: + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[Any] = None, + quant_config: Optional[Any] = None, + prefix: str = "", + ) -> None: super().__init__() self.hidden_size = config.hidden_size # Rotary embedding parameters (reuse LLaMA defaults) rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Determine if attention bias is needed (some variants use bias terms) attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) bias_o_proj = attention_bias if hasattr(config, "qkv_bias"): attention_bias = config.qkv_bias # Self-Attention (using LLaMA's attention structure) from vllm.model_executor.models.llama import ( - LlamaAttention) # import here to avoid circular import + LlamaAttention, # import here to avoid circular import + ) + self.self_attn = LlamaAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -123,8 +143,8 @@ def __init__(self, cache_config=cache_config, prefix=f"{prefix}.self_attn", attn_type=getattr( - config, "attn_type", - "decoder"), # assume decoder (causal) unless specified + config, "attn_type", "decoder" + ), # assume decoder (causal) unless specified ) # MLP with ReLU^2 activation self.mlp = ArceeMLP( @@ -136,14 +156,16 @@ def __init__(self, prefix=f"{prefix}.mlp", ) # Layer normalization layers (RMSNorm as in LLaMA) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( - self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor] + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: # Self-Attention block if residual is None: @@ -151,13 +173,10 @@ def forward( hidden_states = self.input_layernorm(hidden_states) else: # Fused residual add + layernorm if supported - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) + hidden_states, residual = 
self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) # Feed-forward block - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -167,11 +186,13 @@ class ArceeModel(nn.Module): """The transformer model backbone for Arcee (embedding layer + stacked decoder blocks + final norm).""" - def __init__(self, - *, - vllm_config, - prefix: str = "", - layer_type: type[nn.Module] = ArceeDecoderLayer) -> None: + def __init__( + self, + *, + vllm_config, + prefix: str = "", + layer_type: type[nn.Module] = ArceeDecoderLayer, + ) -> None: super().__init__() config: LlamaConfig = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -182,8 +203,9 @@ def __init__(self, self.org_vocab_size = config.vocab_size # Word embeddings (parallelized if using pipeline parallel) - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -191,16 +213,17 @@ def __init__(self, quant_config=quant_config, ) else: - self.embed_tokens = PPMissingLayer( - ) # placeholder on non-embedding ranks + self.embed_tokens = PPMissingLayer() # placeholder on non-embedding ranks # Build decoder layers across pipeline ranks self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) # Final RMSNorm on the last pipeline stage @@ -215,9 +238,9 @@ def __init__(self, # Prepare factory for empty intermediate tensors # (for pipeline scheduling) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -227,44 +250,47 @@ def forward( input_ids: Optional[torch.Tensor], positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None - ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, - list[torch.Tensor]]]: + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: # Embedding lookup (on first pipeline rank) if get_pp_group().is_first_rank: - hidden_states = (inputs_embeds if inputs_embeds is not None else - self.get_input_embeddings(input_ids)) + hidden_states = ( + inputs_embeds + if inputs_embeds is not None + else self.get_input_embeddings(input_ids) + ) residual = None else: assert intermediate_tensors is not None, ( - "IntermediateTensors must be provided for non-first " - "pipeline ranks") + "IntermediateTensors must be provided for non-first pipeline ranks" + ) hidden_states = intermediate_tensors["hidden_states"] residual = 
intermediate_tensors["residual"] aux_hidden_states: list[torch.Tensor] = [] for idx, layer in enumerate( - islice(self.layers, self.start_layer, self.end_layer)): + islice(self.layers, self.start_layer, self.end_layer) + ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append( - hidden_states + - residual) # capture pre-layer hidden state if needed + hidden_states + residual + ) # capture pre-layer hidden state if needed hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: # Send intermediate results to the next pipeline stage - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) # On last rank: apply final layer norm hidden_states, _ = self.norm(hidden_states, residual) if len(aux_hidden_states) > 0: return hidden_states, aux_hidden_states return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights, mapping q/k/v projections to fused qkv_proj.""" stacked_params_mapping = [ (".qkv_proj", ".q_proj", "q"), @@ -278,17 +304,17 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -331,8 +357,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -342,6 +367,7 @@ def load_weights(self, weights: Iterable[tuple[str, class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): """Arcee Model for causal language modeling, integrated with vLLM runtime.""" + # Map fused module names to their submodule components # (for quantization and LoRA) packed_modules_mapping = { @@ -354,8 +380,7 @@ def __init__(self, *, vllm_config, prefix: str = "") -> None: self.config = config # Initialize the inner Transformer model (ArceeModel) - self.model = ArceeModel(vllm_config=vllm_config, - prefix=f"{prefix}.model") + self.model = ArceeModel(vllm_config=vllm_config, prefix=f"{prefix}.model") # On the last pipeline stage, set up the LM head and logits processor if get_pp_group().is_last_rank: # Determine vocabulary size (including any LoRA extra tokens @@ -373,34 +398,35 @@ def __init__(self, *, vllm_config, prefix: str = "") -> None: ) if config.tie_word_embeddings: # Tie output weights with input embedding matrix - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + 
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: # Placeholder for lm_head on non-last ranks self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + model_output = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) return model_output - def compute_logits(self, - hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # Compute final logits from hidden states (last pipeline rank only) logits = self.logits_processor(self.lm_head, hidden_states) return logits @@ -408,15 +434,14 @@ def compute_logits(self, def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights into the model (delegates to inner model and handles tied embeddings).""" loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), - skip_substrs=["gate_proj"]) + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + skip_substrs=["gate_proj"], + ) # AutoWeightLoader handles weight name remapping, including fusing # separate q_proj, k_proj, v_proj into qkv_proj return loader.load_weights(weights) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 55d16fd75ceb..760df1cef82b 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Snowflake Arctic model.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -11,24 +12,33 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - 
QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.deepspeedfp import ( - DeepSpeedFPConfig, DeepSpeedFPParameter) + DeepSpeedFPConfig, + DeepSpeedFPParameter, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -36,41 +46,50 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig from .interfaces import SupportsPP, SupportsQuant -from .utils import (extract_layer_index, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class ArcticMLP(nn.Module): - - def __init__(self, - config: ArcticConfig, - expert_id: int = -1, - is_residual_mlp: bool = False, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = ""): + def __init__( + self, + config: ArcticConfig, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size self.expert_id = expert_id - self.ffn_dim = config.intermediate_size if not is_residual_mlp \ - else self.hidden_size - - self.w13 = MergedColumnParallelLinear(self.hidden_size, - [self.ffn_dim] * 2, - bias=False, - quant_config=quant_config) - self.w2 = RowParallelLinear(self.ffn_dim, - self.hidden_size, - bias=False, - reduce_results=reduce_results, - quant_config=quant_config) + self.ffn_dim = ( + config.intermediate_size if not is_residual_mlp else self.hidden_size + ) + + self.w13 = MergedColumnParallelLinear( + self.hidden_size, [self.ffn_dim] * 2, bias=False, quant_config=quant_config + ) + self.w2 = RowParallelLinear( + self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config, + ) if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, hidden_states): @@ -85,13 +104,15 @@ class ArcticMoE(nn.Module): Model-parallel implementation of Arctic MoE Layer. 
""" - def __init__(self, - config: ArcticConfig, - tp_size: Optional[int] = None, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = ""): + def __init__( + self, + config: ArcticConfig, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ): super().__init__() layer_id = extract_layer_index(prefix) @@ -111,52 +132,75 @@ def __init__(self, self.params_dtype = params_dtype if not self.is_moe_layer: - self.mlp = ArcticMLP(config, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.mlp") + self.mlp = ArcticMLP( + config, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.mlp", + ) else: - self.gate = ReplicatedLinear(self.hidden_size, - self.num_experts, - bias=False, - params_dtype=self.params_dtype, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) if self.is_quant: self.ws = DeepSpeedFPParameter( - torch.Size((self.num_experts, 2 * self.intermediate_size, - self.hidden_size)), + torch.Size( + (self.num_experts, 2 * self.intermediate_size, self.hidden_size) + ), params_dtype=params_dtype, quant_config=quant_config, ) self.w2s = DeepSpeedFPParameter( - torch.Size((self.num_experts, self.hidden_size, - self.intermediate_size)), + torch.Size( + (self.num_experts, self.hidden_size, self.intermediate_size) + ), params_dtype=params_dtype, quant_config=quant_config, ) else: self.ws = nn.Parameter( - torch.empty(self.num_experts, - 2 * self.intermediate_size, - self.hidden_size, - device=current_platform.device_type, - dtype=self.params_dtype)) + torch.empty( + self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) self.w2s = nn.Parameter( - torch.empty(self.num_experts, - self.hidden_size, - self.intermediate_size, - device=current_platform.device_type, - dtype=self.params_dtype)) - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): + torch.empty( + self.num_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) + set_weight_attrs( + self.ws, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2s, + { + "weight_loader": self.weight_loader, + }, + ) + + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + expert_id: int, + ): tp_rank = get_tensor_model_parallel_rank() param_data = param.ds_dequantize() if self.is_quant else param.data shard_size = self.intermediate_size @@ -164,8 +208,9 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, if weight_name.endswith("w1.weight"): param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] if 
weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] if self.is_quant: @@ -178,15 +223,14 @@ def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, _ = self.gate(hidden_states) do_normalize = self.top_k > 1 topk_weights, topk_ids, token_expert_indices = fused_topk( - hidden_states, router_logits, self.top_k, renormalize=do_normalize) + hidden_states, router_logits, self.top_k, renormalize=do_normalize + ) # topk_ids: (num_tokens, k) if self.is_quant: if 2 * num_tokens <= self.num_experts: # If much fewer tokens than experts, use selective dequantize. - ws_dequantized = self.ws.ds_selective_dequantize( - topk_ids.flatten()) - w2s_dequantized = self.w2s.ds_selective_dequantize( - topk_ids.flatten()) + ws_dequantized = self.ws.ds_selective_dequantize(topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize(topk_ids.flatten()) # We gathered the experts to the tokens so update the mapping. topk_ids = torch.arange( 0, @@ -203,10 +247,10 @@ def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: w2s_dequantized if self.is_quant else self.w2s, topk_weights, topk_ids, - inplace=True) + inplace=True, + ) if self.reduce_results and self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) def forward(self, hidden_states: torch.Tensor): @@ -218,7 +262,6 @@ def forward(self, hidden_states: torch.Tensor): class ArcticAttention(nn.Module): - def __init__( self, config: ArcticConfig, @@ -248,12 +291,14 @@ def __init__( self.rope_theta = config.rope_theta self.scaling = self.head_dim**-0.5 - self.qkv_proj = QKVParallelLinear(self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=False, - quant_config=quant_config) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, @@ -270,13 +315,15 @@ def __init__( is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -292,7 +339,6 @@ def forward( class ArcticDecoderLayer(nn.Module): - def __init__( self, config: ArcticConfig, @@ -305,10 +351,12 @@ def __init__( layer_idx = extract_layer_index(prefix) is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 self.use_residual = config.use_residual and is_moe_layer - self.self_attn = ArcticAttention(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = ArcticAttention( + config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = ArcticMoE( config, quant_config=quant_config, @@ -316,18 +364,21 @@ def __init__( prefix=f"{prefix}.block_sparse_moe", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + 
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) if self.use_residual: - self.residual_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.residual_mlp = ArcticMLP(config, - is_residual_mlp=True, - reduce_results=False, - prefix=f"{prefix}.residual_mlp") + self.residual_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.residual_mlp = ArcticMLP( + config, + is_residual_mlp=True, + reduce_results=False, + prefix=f"{prefix}.residual_mlp", + ) def forward( self, @@ -361,7 +412,6 @@ def forward( @support_torch_compile class ArcticModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -371,19 +421,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=self.vocab_size) + self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: ArcticDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self._attn_implementation = config._attn_implementation self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -419,8 +470,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.model = ArcticModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = ArcticModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.vocab_size, @@ -433,10 +485,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok self.unpadded_vocab_size = config.vocab_size - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -448,8 +502,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -459,8 +514,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) 
return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -474,28 +528,47 @@ def load_weights(self, weights: Iterable[tuple[str, for layer in range(num_layers): mlp_params_mapping.append( - (f"layers.{layer}.residual_mlp.w13.weight", - f"layers.{layer}.residual_mlp.w1.weight", 0)) + ( + f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", + 0, + ) + ) mlp_params_mapping.append( - (f"layers.{layer}.residual_mlp.w13.weight", - f"layers.{layer}.residual_mlp.w3.weight", 1)) + ( + f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", + 1, + ) + ) if layer % 2 == 0: # MLP layers mlp_params_mapping.append( - (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", - f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + ( + f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", + 0, + ) + ) mlp_params_mapping.append( - (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", - f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + ( + f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", + 1, + ) + ) else: # MoE layers for expert_id in range(self.config.num_local_experts): expert_params_mapping.append( - ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + ("ws", f"experts.{expert_id}.w1.weight", expert_id) + ) expert_params_mapping.append( - ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + ("w2s", f"experts.{expert_id}.w2.weight", expert_id) + ) expert_params_mapping.append( - ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + ("ws", f"experts.{expert_id}.w3.weight", expert_id) + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -503,9 +576,10 @@ def load_weights(self, weights: Iterable[tuple[str, logger.info( "It will take ~10 minutes loading from the 16-bit weights. " "Alternatively, use the prequantized 8-bit weights of arctic " - "and set load-format to `sharded_state` will accelerate loading.") + "and set load-format to `sharded_state` will accelerate loading." 
+ ) for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -530,8 +604,7 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, shard_id \ - in expert_params_mapping: + for param_name, weight_name, shard_id in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -539,10 +612,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=shard_id) + weight_loader( + param, loaded_weight, weight_name, expert_id=shard_id + ) break else: if name.endswith(".bias") and name not in params_dict: @@ -551,8 +623,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 188624e606ff..7db118ca0745 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -14,33 +14,43 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -# yapf: disable from .idefics2_vision_model import Idefics2VisionConfig from .idefics2_vision_model import ( - Idefics2VisionTransformer as Idefics3VisionTransformer) -# yapf: enable + Idefics2VisionTransformer as Idefics3VisionTransformer, +) from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, 
+ WeightsMapper, + is_pp_missing_parameter, + maybe_prefix, +) class AriaImagePixelInputs(TensorSchema): @@ -81,8 +91,7 @@ def __init__( # Identity layer self.post_layernorm = nn.Identity() - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -92,7 +101,6 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - # NOTE: post_layernorm is not used in Aria if "post_layernorm" in name: continue @@ -107,15 +115,13 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class AriaProjectorMLP(nn.Module): - def __init__( self, in_features: int, @@ -124,12 +130,8 @@ def __init__( ) -> None: super().__init__() - self.linear_in = ColumnParallelLinear(in_features, - hidden_features, - bias=False) - self.linear_out = RowParallelLinear(hidden_features, - output_dim, - bias=False) + self.linear_in = ColumnParallelLinear(in_features, hidden_features, bias=False) + self.linear_out = RowParallelLinear(hidden_features, output_dim, bias=False) self.act = get_act_fn("gelu_new") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -163,15 +165,17 @@ def __init__(self, config: AriaConfig) -> None: self.output_dim = config.text_config.hidden_size self.query = nn.Parameter( - torch.empty(config.max_value_projector_patch_to_query_dict, - self.in_features)) + torch.empty( + config.max_value_projector_patch_to_query_dict, self.in_features + ) + ) self.cross_attn = AriaCrossAttention(config) self.layer_norm = nn.LayerNorm(self.in_features) - self.feed_forward = AriaProjectorMLP(self.in_features, - self.hidden_features, - self.output_dim) + self.feed_forward = AriaProjectorMLP( + self.in_features, self.hidden_features, self.output_dim + ) def forward( self, @@ -181,9 +185,11 @@ def forward( batch_size, num_patches = x.shape[0], x.shape[1] if num_patches not in self.patch_to_query_dict: - raise KeyError(f"Number of patches {num_patches} not found in " - "patch_to_query_dict amongst possible values " - f"{self.patch_to_query_dict.keys()}.") + raise KeyError( + f"Number of patches {num_patches} not found in " + "patch_to_query_dict amongst possible values " + f"{self.patch_to_query_dict.keys()}." + ) query_num = self.patch_to_query_dict[num_patches] @@ -201,32 +207,32 @@ def forward( class AriaFusedMoE(FusedMoE): - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - shard_id: str) -> None: + def weight_loader( + self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str + ) -> None: # Override the weight_loader to handle the expert weights in the Aria # model, which are already packed with experts, and merge the gate and # up weights for each expert. 
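The override continuing below merges packed expert weights whose last dimension holds the up and gate halves, slicing each half for the current tensor-parallel rank before re-joining and transposing. A hedged sketch of that slicing with made-up shapes (this is not the AriaFusedMoE code itself):

import torch

num_experts, hidden, inter, tp_size, tp_rank = 4, 16, 32, 2, 0

# Assumed checkpoint layout: (num_experts, hidden_size, 2 * moe_intermediate_size)
loaded_weight = torch.randn(num_experts, hidden, 2 * inter)

up, gate = loaded_weight.chunk(2, dim=-1)            # each (E, H, I)
up_rank = up.chunk(tp_size, dim=-1)[tp_rank]         # (E, H, I // tp)
gate_rank = gate.chunk(tp_size, dim=-1)[tp_rank]     # (E, H, I // tp)
up_and_gate = torch.cat([up_rank, gate_rank], dim=-1).transpose(1, 2)

print(up_and_gate.shape)  # torch.Size([4, 32, 16]) i.e. (E, 2 * I // tp, H)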
# Note: Loading expert weights with quantization is not supported tp_rank = get_tensor_model_parallel_rank() - if shard_id == 'w13': + if shard_id == "w13": # the shape of loaded_weight is # (num_experts, hidden_size, 2 * moe_intermediate_size) if self.tp_size > 1: up, gate = loaded_weight.chunk(2, dim=-1) up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank] gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank] - up_and_gate = torch.cat([up_current_rank, gate_current_rank], - dim=-1).transpose(1, 2) + up_and_gate = torch.cat( + [up_current_rank, gate_current_rank], dim=-1 + ).transpose(1, 2) param.data.copy_(up_and_gate) else: param.data.copy_(loaded_weight.transpose(1, 2)) - elif shard_id == 'w2': + elif shard_id == "w2": # the shape of loaded_weight is # (num_experts, moe_intermediate_size, hidden_size) if self.tp_size > 1: - down_current_rank = loaded_weight.chunk(self.tp_size, - dim=1)[tp_rank] + down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[tp_rank] param.data.copy_(down_current_rank.transpose(1, 2)) else: param.data.copy_(loaded_weight.transpose(1, 2)) @@ -251,8 +257,8 @@ def __init__( self.config = config self.router_weight = nn.Parameter( - torch.empty( - (self.config.moe_num_experts, self.config.hidden_size))) + torch.empty((self.config.moe_num_experts, self.config.hidden_size)) + ) self.experts = AriaFusedMoE( num_experts=config.moe_num_experts, @@ -283,8 +289,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: torch.Tensor: Output tensor after passing through the MoE layer. """ - router_output = torch.nn.functional.linear(hidden_states, - self.router_weight) + router_output = torch.nn.functional.linear(hidden_states, self.router_weight) hidden_states_copy = hidden_states.clone() # NOTE: hidden_states will be modified inplace by `FusedMoE` @@ -307,9 +312,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - self.mlp = AriaTextMoELayer(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = AriaTextMoELayer( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) class AriaTextModel(LlamaModel, SupportsQuant): @@ -317,6 +322,7 @@ class AriaTextModel(LlamaModel, SupportsQuant): Custom LlamaModel for the AriaMoE model which modifies the standard LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. 
""" + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -325,14 +331,13 @@ class AriaTextModel(LlamaModel, SupportsQuant): } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=AriaTextDecoderLayer) + super().__init__( + vllm_config=vllm_config, prefix=prefix, layer_type=AriaTextDecoderLayer + ) # Adapted from LlamaModel.load_weights with the modification of adding # the expert weights mapping to `stacked_params_mapping` - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -340,27 +345,27 @@ def load_weights(self, weights: Iterable[tuple[str, (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), - ("experts.w13_weight", "experts.fc1.weight", 'w13'), - ("experts.w2_weight", "experts.fc2.weight", 'w2'), + ("experts.w13_weight", "experts.fc1.weight", "w13"), + ("experts.w2_weight", "experts.fc2.weight", "w2"), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -392,15 +397,13 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class AriaProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(AriaConfig) @@ -419,7 +422,6 @@ def get_num_image_tokens(self) -> int: class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -442,16 +444,16 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=max_image_size, - height=max_image_size, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=max_image_size, + height=max_image_size, + num_images=num_images, + overrides=image_overrides, + ) } class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -482,9 
+484,11 @@ def _get_prompt_updates( ] -@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor, - info=AriaProcessingInfo, - dummy_inputs=AriaDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + AriaMultiModalProcessor, + info=AriaProcessingInfo, + dummy_inputs=AriaDummyInputsBuilder, +) class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): """ Aria model for conditional generation tasks. @@ -492,6 +496,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): This model combines a vision tower, a multi-modal projector, and a language model to perform tasks that involve both image and text inputs. """ + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( @@ -537,8 +542,9 @@ def __init__( vllm_config=vllm_config.with_hf_config(config.text_config), prefix=maybe_prefix(prefix, "language_model.model"), ) - self.pad_token_id = (self.config.pad_token_id - if self.config.pad_token_id is not None else -1) + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, @@ -548,11 +554,13 @@ def __init__( prefix=maybe_prefix(prefix, "lm_head"), ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.vocab_size, logit_scale + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[AriaImagePixelInputs]: + self, **kwargs: object + ) -> Optional[AriaImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) pixel_mask = kwargs.pop("pixel_mask", None) @@ -588,8 +596,8 @@ def _process_image_input( ) -> tuple[torch.Tensor, torch.Tensor]: assert self.vision_tower is not None - pixel_values = image_input['pixel_values'] - pixel_mask = image_input['pixel_mask'] + pixel_values = image_input["pixel_values"] + pixel_mask = image_input["pixel_mask"] patch_attention_mask = self._create_patch_attention_mask(pixel_mask) @@ -607,8 +615,7 @@ def _process_image_input( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index a682252f4a2b..6e93de524e48 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -10,30 +10,36 @@ from transformers.activations import ACT2FN from transformers.image_processing_utils import get_size_dict from transformers.models.aya_vision import AyaVisionConfig -from transformers.models.aya_vision.processing_aya_vision import ( - AyaVisionProcessor) +from transformers.models.aya_vision.processing_aya_vision import AyaVisionProcessor from transformers.models.got_ocr2.image_processing_got_ocr2 import ( - get_optimal_tiled_canvas) + get_optimal_tiled_canvas, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from 
vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalFieldConfig, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalFieldConfig, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) class AyaVisionImagePixelInputs(TensorSchema): @@ -61,17 +67,17 @@ class AyaVisionImagePixelInputs(TensorSchema): class AyaVisionMultiModalProjector(nn.Module): - def __init__(self, config: AyaVisionConfig): super().__init__() self.config = config self.downsample_factor = config.downsample_factor self.alignment_intermediate_size = getattr( - config, "alignment_intermediate_size", - config.text_config.hidden_size) - self.layernorm = nn.LayerNorm(config.vision_config.hidden_size * - (config.downsample_factor**2), - eps=config.adapter_layer_norm_eps) + config, "alignment_intermediate_size", config.text_config.hidden_size + ) + self.layernorm = nn.LayerNorm( + config.vision_config.hidden_size * (config.downsample_factor**2), + eps=config.adapter_layer_norm_eps, + ) self.linear_1 = nn.Linear( config.vision_config.hidden_size * (config.downsample_factor**2), @@ -81,9 +87,11 @@ def __init__(self, config: AyaVisionConfig): self.act = ACT2FN["silu"] # SwiGLU uses SiLU activation # For SwiGLU, project down to half size since we split intermediate dim - self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2, - config.text_config.hidden_size, - bias=True) + self.linear_2 = nn.Linear( + self.alignment_intermediate_size // 2, + config.text_config.hidden_size, + bias=True, + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: image_features = self.pixel_shuffle(image_features) @@ -97,26 +105,31 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_2(hidden_states) return hidden_states - def pixel_shuffle(self, - image_features: torch.Tensor) -> torch.Tensor: # B, S, D + def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor: # B, S, D batch_size, seq_length, _ = image_features.shape height = width = int(seq_length**0.5) - image_features = image_features.reshape(image_features.shape[0], width, - height, -1) + image_features = image_features.reshape( + image_features.shape[0], width, height, -1 + ) channels = image_features.shape[-1] image_features = image_features.reshape( - batch_size, width, int(height / self.downsample_factor), - int(channels * self.downsample_factor)) + batch_size, + width, + int(height / self.downsample_factor), + int(channels * self.downsample_factor), + ) image_features = image_features.permute(0, 2, 1, 3) image_features = image_features.reshape( - batch_size, int(height / self.downsample_factor), - int(width / self.downsample_factor), -1) + batch_size, + int(height / self.downsample_factor), + int(width / self.downsample_factor), + -1, + ) image_features = image_features.permute(0, 2, 1, 3) 
return image_features class AyaVisionProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> AyaVisionConfig: return self.ctx.get_hf_config(AyaVisionConfig) @@ -131,14 +144,20 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() - height = image_processor.size['height'] - width = image_processor.size['width'] + height = image_processor.size["height"] + width = image_processor.size["width"] max_patches = image_processor.max_patches - return ImageSize(height=height * max_patches, - width=width * max_patches) + return ImageSize(height=height * max_patches, width=width * max_patches) - def get_num_patches(self, *, image_width: int, image_height: int, - size: dict, min_patches: int, max_patches: int) -> int: + def get_num_patches( + self, + *, + image_width: int, + image_height: int, + size: dict, + min_patches: int, + max_patches: int, + ) -> int: """ Calculate the number of patches needed for a given image based on size constraints. This method replicates and adjusts the logic from: @@ -146,15 +165,16 @@ def get_num_patches(self, *, image_width: int, image_height: int, """ size = get_size_dict(size, default_to_square=False) num_columns, num_rows = get_optimal_tiled_canvas( - (image_height, image_width), (size["height"], size["width"]), - min_patches, max_patches) + (image_height, image_width), + (size["height"], size["width"]), + min_patches, + max_patches, + ) num_blocks = num_columns * num_rows return num_blocks if num_blocks == 1 else num_blocks + 1 -class AyaVisionDummyInputsBuilder( - BaseDummyInputsBuilder[AyaVisionProcessingInfo]): - +class AyaVisionDummyInputsBuilder(BaseDummyInputsBuilder[AyaVisionProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -170,23 +190,21 @@ def get_dummy_mm_data( mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - image_size = \ - self.info.get_image_size_with_most_features() + image_size = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=image_size.width, - height=image_size.height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=image_size.width, + height=image_size.height, + num_images=num_images, + overrides=image_overrides, + ) } -class AyaVisionMultiModalProcessor( - BaseMultiModalProcessor[AyaVisionProcessingInfo]): - +class AyaVisionMultiModalProcessor(BaseMultiModalProcessor[AyaVisionProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -205,13 +223,13 @@ def _call_hf_processor( # HF processor pops the `num_patches` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": - images - }).get_items("image", ImageProcessorItems)) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) image_sizes = [ - parsed_images.get_image_size(i) - for i in range(len(parsed_images)) + parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] num_patches = [ @@ -220,7 +238,8 @@ def _call_hf_processor( image_height=image_size.height, size=image_processor.size, min_patches=image_processor.min_patches, - max_patches=image_processor.max_patches) + 
max_patches=image_processor.max_patches, + ) for image_size in image_sizes ] processed_outputs["num_patches"] = torch.tensor(num_patches) @@ -234,8 +253,7 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: num_patches = hf_inputs.get("num_patches", torch.empty(0)) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -282,10 +300,10 @@ def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int: return _get_layer_index(feature_layers, num_hidden_layers) # If we have multiple feature layers, initialize up to the deepest m elif isinstance(feature_layers, (list, tuple)): - return max( - _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) - raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" - " is not supported") + return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError( + f"vision_layer_feature type: {type(feature_layers)} is not supported" + ) def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: @@ -297,9 +315,9 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: @MULTIMODAL_REGISTRY.register_processor( AyaVisionMultiModalProcessor, info=AyaVisionProcessingInfo, - dummy_inputs=AyaVisionDummyInputsBuilder) -class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=AyaVisionDummyInputsBuilder, +) +class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( @@ -309,7 +327,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -332,7 +351,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vision_config, quant_config, num_hidden_layers_override=num_hidden_layers, - prefix=maybe_prefix(prefix, "vision_model")) + prefix=maybe_prefix(prefix, "vision_model"), + ) self.vocab_size = config.text_config.vocab_size self.multi_modal_projector = AyaVisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( @@ -340,14 +360,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): hf_config=config.text_config, prefix=maybe_prefix(prefix, "model"), # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm - architectures=["Cohere2ForCausalLM"]) + architectures=["Cohere2ForCausalLM"], + ) @property def dtype(self): return next(self.parameters()).dtype - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -361,20 +381,21 @@ def _image_pixels_to_features( feature_select_strategy=self.config.vision_feature_select_strategy, ) - def _process_image_input(self, image_input: AyaVisionImagePixelInputs, - **kwargs) -> list[torch.Tensor]: + def _process_image_input( + self, image_input: AyaVisionImagePixelInputs, **kwargs + ) -> list[torch.Tensor]: assert self.vision_tower is not None 
pixel_values = image_input["pixel_values"] num_patches = image_input["num_patches"] image_features = self._image_pixels_to_features( - self.vision_tower, pixel_values=pixel_values) + self.vision_tower, pixel_values=pixel_values + ) image_embeds = self.multi_modal_projector(image_features) - return [ - e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist()) - ] + return [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())] def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]: + self, **kwargs: object + ) -> Optional[AyaVisionImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -390,13 +411,13 @@ def _parse_and_validate_image_input( resolve_bindings={ "h": self.config.vision_config.image_size, "w": self.config.vision_config.image_size, - }) + }, + ) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index db8d0a871047..a8f0e5993e2b 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -20,6 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only BaiChuan model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable from itertools import islice @@ -32,32 +33,45 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, row_parallel_weight_loader) + default_weight_loader, + row_parallel_weight_loader, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - closest_power_of_2 = 
2**math.floor(math.log2(total_num_heads)) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) @@ -65,22 +79,20 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes class BaiChuanMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -90,16 +102,15 @@ def __init__( ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -125,12 +136,10 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = hidden_size // self.total_num_heads self.position_embedding = position_embedding self.rope_theta = rope_theta @@ -160,12 +169,14 @@ def __init__( alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) else: self.rotary_emb = get_rope( self.head_dim, @@ -174,12 +185,14 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -196,18 +209,18 @@ def forward( class BaiChuanDecoderLayer(nn.Module): - - def __init__(self, - config: PretrainedConfig, - position_embedding: str, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PretrainedConfig, + position_embedding: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = BaiChuanAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -224,10 +237,10 @@ def __init__(self, hidden_act=config.hidden_act, quant_config=quant_config, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -240,23 +253,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class 
BaiChuanModel(nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -278,17 +288,15 @@ def __init__( ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: BaiChuanDecoderLayer(config, - position_embedding, - cache_config, - quant_config, - prefix=prefix), + lambda prefix: BaiChuanDecoderLayer( + config, position_embedding, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -317,15 +325,16 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -337,7 +346,7 @@ def load_weights(self, weights: Iterable[tuple[str, if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -357,15 +366,13 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, - SupportsQuant): +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): packed_modules_mapping = { "W_pack": ["W_pack"], "gate_up_proj": [ @@ -389,19 +396,24 @@ def __init__( self.lora_config = lora_config self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config - self.model = BaiChuanModel(vllm_config=vllm_config, - prefix=prefix, - position_embedding=position_embedding) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.model = BaiChuanModel( + vllm_config=vllm_config, + prefix=prefix, + position_embedding=position_embedding, + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.lm_head.weight.weight_loader = self.lm_head_weight_loader if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: 
torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -413,8 +425,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -424,13 +437,11 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) - def lm_head_weight_loader(self, param: nn.Parameter, - loaded_weight: torch.Tensor): + def lm_head_weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): # Unlike Baichuan, Baichuan2 normalizes the head weights. # Refer to: # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 @@ -454,13 +465,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config if config.hidden_size == 4096: # baichuan2 7b - super().__init__(vllm_config=vllm_config, - prefix=prefix, - position_embedding="ROPE") + super().__init__( + vllm_config=vllm_config, prefix=prefix, position_embedding="ROPE" + ) else: # baichuan 13b, baichuan2 13b - super().__init__(vllm_config=vllm_config, - prefix=prefix, - position_embedding="ALIBI") + super().__init__( + vllm_config=vllm_config, prefix=prefix, position_embedding="ALIBI" + ) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -469,6 +480,6 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - position_embedding="ROPE") + super().__init__( + vllm_config=vllm_config, prefix=prefix, position_embedding="ROPE" + ) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 6e470378cb60..3911ba599069 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only BailingMoE model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -35,31 +36,42 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class BailingAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -79,8 +91,7 @@ def __init__( assert self.total_num_heads >= self.total_kv_heads self.num_heads = self.total_num_heads // tp_size - self.head_dim = config.head_dim or (self.hidden_size // - self.total_num_heads) + self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) self.q_size_per_rank = self.head_dim * self.num_heads self.num_kv_heads = self.total_kv_heads // tp_size self.kv_size_per_rank = self.num_kv_heads * self.head_dim @@ -99,12 +110,16 @@ def __init__( ) if self.use_qk_norm: - self.query_layernorm = (RMSNorm( - self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm - else nn.LayerNorm(self.head_dim, eps=1e-6)) - self.key_layernorm = (RMSNorm( - self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm - else nn.LayerNorm(self.head_dim, eps=1e-6)) + self.query_layernorm = ( + RMSNorm(self.head_dim, eps=config.rms_norm_eps) + if self.use_rmsnorm + else nn.LayerNorm(self.head_dim, eps=1e-6) + ) + self.key_layernorm = ( + RMSNorm(self.head_dim, eps=config.rms_norm_eps) + if self.use_rmsnorm + else nn.LayerNorm(self.head_dim, eps=1e-6) + ) self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -115,8 +130,7 @@ def __init__( prefix=f"{prefix}.dense", ) - self.partial_rotary_factor = getattr(config, "partial_rotary_factor", - 1.0) + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) self.rotary_dim = getattr(config, "rotary_dim", self.head_dim) @@ -144,12 +158,10 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.Tensor, ) 
-> torch.Tensor: - qkv, _ = self.query_key_value(hidden_states) - q, k, v = qkv.split([ - self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank - ], - dim=-1) + q, k, v = qkv.split( + [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], dim=-1 + ) if self.use_qk_norm: q = q.view(-1, self.num_heads, self.head_dim) @@ -168,7 +180,6 @@ def forward( class BailingMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -203,7 +214,6 @@ def forward(self, x): class BailingMoE(nn.Module): - def __init__( self, intermediate_size: int, @@ -225,10 +235,8 @@ def __init__( self.score_function = getattr(config, "score_function", None) self.n_group = getattr(config, "n_group", None) self.topk_group = getattr(config, "topk_group", None) - self.use_grouped_topk = (self.n_group is not None - and self.topk_group is not None) - self.routed_scaling_factor = getattr(config, "routed_scaling_factor", - 1.0) + self.use_grouped_topk = self.n_group is not None and self.topk_group is not None + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) router_dtype = getattr(config, "router_dtype", None) if router_dtype is None: @@ -247,21 +255,23 @@ def __init__( if getattr(config, "moe_router_enable_expert_bias", False): self.gate.expert_bias = nn.Parameter( - torch.empty((config.num_experts, ), dtype=torch.float32)) + torch.empty((config.num_experts,), dtype=torch.float32) + ) else: self.gate.expert_bias = None - self.correction_bias = (self.gate.expert_bias.data - if self.gate.expert_bias is not None else None) + self.correction_bias = ( + self.gate.expert_bias.data if self.gate.expert_bias is not None else None + ) if self.score_function is not None: assert ( - self.score_function == "softmax" - and self.correction_bias is None + self.score_function == "softmax" and self.correction_bias is None ) or ( - self.score_function == "sigmoid" - and self.correction_bias is not None - ), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)" # noqa: E501 + self.score_function == "sigmoid" and self.correction_bias is not None + ), ( + "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)" # noqa: E501 + ) else: # default value for scoring_func self.score_function = "softmax" @@ -293,7 +303,8 @@ def __init__( config=config, quant_config=quant_config, reduce_results=False, - prefix=f"{prefix}.shared_experts") + prefix=f"{prefix}.shared_experts", + ) else: self.shared_experts = None @@ -306,8 +317,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits = self.gate(hidden_states.to(self.router_dtype)) router_logits = router_logits.to(hidden_states.dtype) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) final_hidden_states *= self.routed_scaling_factor @@ -315,13 +327,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) class BailingMoeBlock(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -330,30 +340,26 @@ def __init__( prefix: str = "", ): super().__init__() - layer_idx = 
int(prefix.split('.')[-1]) + layer_idx = int(prefix.split(".")[-1]) self.config = config hidden_size = config.hidden_size intermediate_size = config.intermediate_size self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) - self.attention = BailingAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attention") + self.attention = BailingAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attention" + ) - self.post_attention_layernorm = RMSNorm(hidden_size, - eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) # Choose MLP class based on the number of experts and layer index if layer_idx < config.first_k_dense_replace: mlp_class = BailingMLP else: mlp_class = BailingMoE - self.mlp = mlp_class(intermediate_size, - config, - quant_config, - True, - prefix=f"{prefix}.mlp") + self.mlp = mlp_class( + intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp" + ) def forward( self, @@ -365,23 +371,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.attention( hidden_states=hidden_states, position_ids=position_ids, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class BailingMoeModel(nn.Module): - def __init__( self, *, @@ -396,11 +399,11 @@ def __init__( self.config = config self.vocab_size = config.vocab_size self.embed_dim = config.hidden_size - self.tie_word_embeddings = getattr(config, "tie_word_embeddings", - False) + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False) - if get_pp_group().is_first_rank or (self.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + self.tie_word_embeddings and get_pp_group().is_last_rank + ): self.word_embeddings = VocabParallelEmbedding( self.vocab_size, self.embed_dim, @@ -420,11 +423,12 @@ def __init__( quant_config=quant_config, prefix=prefix, ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) @@ -460,10 +464,9 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) else: if residual is None: hidden_states = self.norm(hidden_states) @@ -479,8 +482,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: num_experts=self.config.num_experts, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -491,14 +493,14 @@ def load_weights(self, weights: Iterable[tuple[str, 
loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if (hasattr(self.config, "norm_head") and self.config.norm_head - and "lm_head.weight" in name): - loaded_weight = F.normalize(loaded_weight, - dim=0, - p=2, - eps=1e-7) - - for (param_name, weight_name, shard_id) in stacked_params_mapping: + if ( + hasattr(self.config, "norm_head") + and self.config.norm_head + and "lm_head.weight" in name + ): + loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7) + + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue if "mlp.experts" in name: @@ -548,15 +550,15 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): - packed_modules_mapping = { "query_key_value": ["query_key_value"], "gate_up_proj": [ @@ -582,10 +584,10 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config self.max_position_embeddings = config.max_position_embeddings - self.model = BailingMoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.tie_word_embeddings = getattr(config, "tie_word_embeddings", - False) + self.model = BailingMoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False) if get_pp_group().is_last_rank: if self.tie_word_embeddings: @@ -602,7 +604,8 @@ def __init__( self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -614,8 +617,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( @@ -625,8 +629,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None), diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 4a6154dc548a..42c1c7be1a75 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Bamba model.""" + # Added by the IBM Team, 2024 from collections.abc import Iterable from typing import Optional @@ -16,29 +17,38 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear 
import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class BambaMLP(nn.Module): - def __init__( self, config: BambaConfig, @@ -59,8 +69,10 @@ def __init__( quant_config=quant_config, ) if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -71,38 +83,38 @@ def forward(self, x): class BambaMixerDecoderLayer(nn.Module): - - def __init__(self, - config: BambaConfig, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: BambaConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config - self.mamba = MambaMixer2(hidden_size= config.hidden_size, - ssm_state_size = config.mamba_d_state, - conv_kernel_size = config.mamba_d_conv, - intermediate_size = config.mamba_expand *\ - config.hidden_size, - use_conv_bias = config.mamba_conv_bias, - use_bias = config.mamba_proj_bias, - n_groups=config.mamba_n_groups, - num_heads=config.mamba_n_heads, - head_dim=config.mamba_d_head, - rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.mixer") + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=config.mamba_expand * config.hidden_size, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) self.feed_forward = BambaMLP(config, quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -114,8 +126,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) output = torch.empty_like(hidden_states) self.mamba(hidden_states, output) @@ -126,7 +137,6 @@ def forward( class BambaAttentionDecoderLayer(nn.Module): - def __init__( self, config: BambaConfig, @@ -139,8 +149,7 @@ def __init__( super().__init__() rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads @@ -188,10 +197,12 @@ def __init__( bias=False, quant_config=quant_config, ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) self.attn = Attention( self.num_heads, @@ -203,10 +214,8 @@ def __init__( ) 
self.feed_forward = BambaMLP(config, quant_config=quant_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def self_attention( self, @@ -233,29 +242,26 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attention( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual ALL_DECODER_LAYER_TYPES = { "attention": BambaAttentionDecoderLayer, - "mamba": BambaMixerDecoderLayer + "mamba": BambaMixerDecoderLayer, } @support_torch_compile class BambaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -266,8 +272,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -279,8 +288,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) - layer_class = ALL_DECODER_LAYER_TYPES[ - config.layers_block_type[layer_idx]] + layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[layer_idx]] return layer_class( config, layer_idx, @@ -291,13 +299,13 @@ def get_layer(prefix: str): ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -309,7 +317,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -330,15 +337,13 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - 
torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -383,22 +388,22 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsQuant): +class BambaForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant +): packed_modules_mapping = { "qkv_proj": [ "q_proj", "k_proj", "v_proj", ], - "gate_up_proj": ["up_proj", "down_proj"] + "gate_up_proj": ["up_proj", "down_proj"], } # LoRA specific attributes @@ -413,7 +418,6 @@ def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -453,19 +457,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Bamba currently does not support prefix caching" - self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config - self.model = BambaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = BambaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -476,28 +477,33 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states @@ -508,7 +514,6 @@ def compute_logits( logits = 
self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 2ec3edc5a0a7..d9d4c62639d5 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -13,17 +13,21 @@ from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.pooler import (ClassifierPooler, - DispatchPooler, Pooler, - PoolingMethod, - PoolingParamsUpdate, - PoolingType) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + DispatchPooler, + Pooler, + PoolingMethod, + PoolingParamsUpdate, + PoolingType, +) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask from vllm.v1.pool.metadata import PoolingMetadata @@ -34,19 +38,19 @@ class BertEmbedding(nn.Module): - def __init__(self, config: BertConfig): - super().__init__() self.size = config.hidden_size - self.word_embeddings = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.position_embeddings = VocabParallelEmbedding( - config.max_position_embeddings, config.hidden_size) + config.max_position_embeddings, config.hidden_size + ) self.token_type_embeddings = VocabParallelEmbedding( - config.type_vocab_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + config.type_vocab_size, config.hidden_size + ) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.register_buffer( "position_ids", @@ -54,8 +58,9 @@ def __init__(self, config: BertConfig): ) self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": - raise ValueError("Only 'absolute' position_embedding_type" + - " is supported") + raise ValueError( + "Only 'absolute' position_embedding_type" + " is supported" + ) def forward( self, @@ -78,7 +83,6 @@ def forward( class BertPooler(Pooler): - def __init__(self, config: BertConfig): super().__init__() @@ -113,19 +117,22 @@ def forward( class BertEncoder(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - self.layer = nn.ModuleList([ - BertLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.layer.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ]) + self.layer = nn.ModuleList( + [ + BertLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + 
prefix=f"{prefix}.layer.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -137,12 +144,13 @@ def forward( class BertLayer(nn.Module): - - def __init__(self, - config: BertConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.attention = BertAttention( @@ -151,20 +159,24 @@ def __init__(self, layer_norm_eps=config.layer_norm_eps, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attention") + prefix=f"{prefix}.attention", + ) self.intermediate = BertIntermediate( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.intermediate") + prefix=f"{prefix}.intermediate", + ) - self.output = BertOutput(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - layer_norm_eps=config.layer_norm_eps, - quant_config=quant_config, - prefix=f"{prefix}.output") + self.output = BertOutput( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + layer_norm_eps=config.layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output", + ) def forward(self, hidden_states: torch.Tensor): attn_output = self.attention(hidden_states) @@ -174,7 +186,6 @@ def forward(self, hidden_states: torch.Tensor): class BertAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -186,16 +197,20 @@ def __init__( ): super().__init__() - self.self = BertSelfAttention(hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.output") + self.self = BertSelfAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.output", + ) - self.output = BertSelfOutput(hidden_size=hidden_size, - layer_norm_eps=layer_norm_eps, - quant_config=quant_config, - prefix=f"{prefix}.output") + self.output = BertSelfOutput( + hidden_size=hidden_size, + layer_norm_eps=layer_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.output", + ) def forward( self, @@ -206,7 +221,6 @@ def forward( class BertSelfAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -239,15 +253,18 @@ def __init__( total_num_kv_heads=self.total_num_kv_heads, bias=True, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + prefix=f"{prefix}.qkv_proj", + ) - self.attn = EncoderOnlyAttention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = EncoderOnlyAttention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -260,41 +277,48 @@ def forward( class BertSelfOutput(nn.Module): - - def __init__(self, - hidden_size: int, - layer_norm_eps: float, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): 
super().__init__() - self.dense = RowParallelLinear(input_size=hidden_size, - output_size=hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.dense = RowParallelLinear( + input_size=hidden_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) - def forward(self, hidden_states: torch.Tensor, - input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: hidden_states, _ = self.dense(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertIntermediate(nn.Module): - - def __init__(self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.dense = ColumnParallelLinear(input_size=hidden_size, - output_size=intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.dense = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) self.intermediate_act_fn = get_act_fn(hidden_act) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -304,25 +328,29 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertOutput(nn.Module): - - def __init__(self, - hidden_size: int, - intermediate_size: int, - layer_norm_eps: float, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + layer_norm_eps: float, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.dense = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.dense = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) - def forward(self, hidden_states: torch.Tensor, - input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: hidden_states, _ = self.dense(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states @@ -331,7 +359,6 @@ def forward(self, hidden_states: torch.Tensor, @support_torch_compile @default_pooling_type("CLS") class BertModel(nn.Module, SupportsQuant): - is_pooling_model = True packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} @@ -347,11 +374,10 @@ def __init__( self.config = vllm_config.model_config.hf_config self.embeddings = embedding_class(self.config) - self.encoder = BertEncoder(vllm_config=vllm_config, - prefix=f"{prefix}.encoder") + self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embeddings(input_ids) + return self.embeddings.word_embeddings(input_ids) def forward( self, @@ -380,7 +406,7 @@ def _load_weights(self, weights: Iterable[tuple[str, 
torch.Tensor]]): other_weights = [] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -398,8 +424,7 @@ def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): return other_weights, loaded_stacked_params - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: other_weights, loaded_stacked_params = self._load_weights(weights) loader = AutoWeightsLoader(self, skip_prefixes=["pooler."]) @@ -410,7 +435,6 @@ def load_weights(self, weights: Iterable[tuple[str, @default_pooling_type("ALL") class BertPoolingModel(BertModel): - is_pooling_model = True def __init__( @@ -429,8 +453,7 @@ def __init__( config = vllm_config.model_config.hf_config self.pooler = BertPooler(config) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: other_weights, loaded_stacked_params = self._load_weights(weights) loader = AutoWeightsLoader(self) @@ -459,8 +482,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.model = self._build_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = self._build_model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.pooler = self._build_pooler(pooler_config) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -473,34 +497,35 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.model(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + return self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights_list = list(weights) - has_model_prefix = any( - name.startswith("model.") for name, _ in weights_list) + has_model_prefix = any(name.startswith("model.") for name, _ in weights_list) if not has_model_prefix: mapper = WeightsMapper(orig_to_new_prefix={"": "model."}) loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."]) return loader.load_weights(weights_list, mapper=mapper) - def _build_model(self, - vllm_config: VllmConfig, - prefix: str = "") -> BertModel: - return BertModel(vllm_config=vllm_config, - prefix=prefix, - embedding_class=BertEmbedding) + def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel: + return BertModel( + vllm_config=vllm_config, prefix=prefix, embedding_class=BertEmbedding + ) def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: - return DispatchPooler({ - "encode": Pooler.for_encode(pooler_config), - "embed": Pooler.for_embed(pooler_config), - }) + return DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) # Here we encode the token type ids together with the input ids. 
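
The hunk below only reformats the token-type-id packing helpers used by the BERT classifiers, but the trick they implement is easy to miss: token type ids are OR-ed into the high bits of the int32 input ids with `TOKEN_TYPE_SHIFT = 30` and decoded (and stripped) again before the embedding lookup, so both tensors travel through the model as a single `input_ids` tensor. As a rough illustration of that round trip — a minimal sketch in plain PyTorch, not part of this patch; the `encode`/`decode` names and the sample token ids are made up, mirroring `_encode_token_type_ids`/`_decode_token_type_ids` shown in the hunk:

```python
import torch

TOKEN_TYPE_SHIFT = 30  # type ids live at bit 30; vocab ids must stay below 2**30

def encode(input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> None:
    # Pack token type ids into the high bits of the (possibly right-padded) input ids.
    input_ids[: token_type_ids.shape[0]].bitwise_or_(token_type_ids << TOKEN_TYPE_SHIFT)

def decode(input_ids: torch.Tensor) -> torch.Tensor:
    # Recover the token type ids and clear them from input_ids in place.
    ids_mask = torch.ones_like(input_ids) << TOKEN_TYPE_SHIFT
    token_type_ids = (input_ids & ids_mask) >> TOKEN_TYPE_SHIFT
    input_ids.bitwise_and_(ids_mask.bitwise_not())
    return token_type_ids

ids = torch.tensor([101, 2054, 102, 2003, 102], dtype=torch.int32)  # hypothetical token ids
types = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32)            # segment A / segment B
encode(ids, types)
assert torch.equal(decode(ids), types)
assert torch.equal(ids, torch.tensor([101, 2054, 102, 2003, 102], dtype=torch.int32))
```

This only works while the token type ids fit in the bits above the vocabulary range, which is why the forward passes further down assert `vocab_size < (1 << TOKEN_TYPE_SHIFT)` before encoding.
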
@@ -527,18 +552,18 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: TOKEN_TYPE_SHIFT = 30 -def _encode_token_type_ids(input_ids: torch.Tensor, - token_type_ids: torch.Tensor) -> None: +def _encode_token_type_ids( + input_ids: torch.Tensor, token_type_ids: torch.Tensor +) -> None: # input_ids can be padded to the right - input_ids[:token_type_ids.shape[0]].bitwise_or_( - token_type_ids << TOKEN_TYPE_SHIFT) + input_ids[: token_type_ids.shape[0]].bitwise_or_(token_type_ids << TOKEN_TYPE_SHIFT) def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: - - ids_mask = torch.ones_like(input_ids, - dtype=torch.int32, - device=input_ids.device) << TOKEN_TYPE_SHIFT + ids_mask = ( + torch.ones_like(input_ids, dtype=torch.int32, device=input_ids.device) + << TOKEN_TYPE_SHIFT + ) tokens_mask = ids_mask.bitwise_not() token_type_ids = input_ids.bitwise_and(ids_mask) >> TOKEN_TYPE_SHIFT @@ -549,17 +574,16 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: @default_pooling_type("CLS") -class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsQuant): +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant): """A model that uses Bert to provide embedding functionalities. - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. - Attributes: - model: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
+ """ is_pooling_model = True @@ -568,34 +592,39 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.num_labels = config.num_labels - self.bert = BertPoolingModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "bert"), - embedding_class=BertEmbedding) - self.classifier = nn.Linear(config.hidden_size, - config.num_labels, - dtype=vllm_config.model_config.head_dtype) + self.bert = BertPoolingModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + ) + self.classifier = nn.Linear( + config.hidden_size, + config.num_labels, + dtype=vllm_config.model_config.head_dtype, + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=self.bert.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=self.bert.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=self.bert.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=self.bert.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.bert.get_input_embeddings(input_ids) @@ -613,16 +642,17 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if token_type_ids is not None: assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) assert input_ids is not None _encode_token_type_ids(input_ids, token_type_ids) - return self.bert(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + return self.bert( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) @default_pooling_type("ALL") @@ -634,20 +664,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.head_dtype = vllm_config.model_config.head_dtype self.num_labels = config.num_labels - self.bert = BertModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "bert"), - embedding_class=BertEmbedding) - self.classifier = nn.Linear(config.hidden_size, - config.num_labels, - dtype=self.head_dtype) + self.bert = BertModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + ) + self.classifier = nn.Linear( + config.hidden_size, config.num_labels, dtype=self.head_dtype + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + } + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.bert.get_input_embeddings(input_ids) @@ -665,16 +698,17 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, 
token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if token_type_ids is not None: assert self.bert.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) assert input_ids is not None _encode_token_type_ids(input_ids, token_type_ids) - hidden_states = self.bert(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + hidden_states = self.bert( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) hidden_states = hidden_states.to(self.head_dtype) return self.classifier(hidden_states) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 4e1eba32d259..05cb0e22a0aa 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -10,25 +10,30 @@ from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.activation import (get_act_and_mul_fn, - get_act_fn) -from vllm.model_executor.layers.fused_moe import (activation_without_mul, - fused_topk) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.activation import get_act_and_mul_fn, get_act_fn +from vllm.model_executor.layers.fused_moe import activation_without_mul, fused_topk +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, - maybe_prefix) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + maybe_prefix, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -40,24 +45,24 @@ class BertWithRopeEmbedding(nn.Module): - def __init__(self, config: PretrainedConfig): - super().__init__() if config.position_embedding_type not in ["rope", "rotary"]: - raise ValueError("Only 'rotary'('rope') position_embedding_type" + - " is supported") + raise ValueError( + "Only 'rotary'('rope') position_embedding_type" + " is supported" + ) - self.word_embeddings = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) if config.type_vocab_size > 0: self.token_type_embeddings = VocabParallelEmbedding( - config.type_vocab_size, config.hidden_size) + config.type_vocab_size, config.hidden_size + ) else: 
self.token_type_embeddings = None - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, @@ -70,9 +75,9 @@ def forward( embeddings = inputs_embeds if self.token_type_embeddings is not None: if token_type_ids is None: - token_type_ids = torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device) + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=inputs_embeds.device + ) token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings += token_type_embeddings @@ -82,7 +87,6 @@ def forward( class BertWithRopeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -119,23 +123,28 @@ def __init__( total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + prefix=f"{prefix}.qkv_proj", + ) self.rotary_emb = get_rope(**rotary_kwargs) - self.attn = EncoderOnlyAttention(num_heads=self.num_heads, - head_size=self.head_dim, - scale=self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = EncoderOnlyAttention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) - self.out_proj = RowParallelLinear(input_size=hidden_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.out_proj = RowParallelLinear( + input_size=hidden_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) def forward( self, @@ -151,14 +160,15 @@ def forward( class BertWithRopeGatedMLP(nn.Module): - - def __init__(self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.act_fn = get_act_and_mul_fn(hidden_act) self.gate_up_proj = MergedColumnParallelLinear( @@ -168,11 +178,13 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", ) - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(hidden_states) @@ -182,26 +194,31 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertWithRopeMLP(nn.Module): - - def __init__(self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - bias: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.act_fn = get_act_fn(hidden_act) - self.up_proj = ColumnParallelLinear(input_size=hidden_size, - output_size=intermediate_size, - bias=bias, - 
quant_config=quant_config, - prefix=f"{prefix}.up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.up_proj(hidden_states) @@ -211,7 +228,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class NomicMoE(nn.Module): - def __init__( self, num_experts: int, @@ -236,28 +252,40 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - self.router = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False) + self.router = ReplicatedLinear( + self.hidden_size, self.num_total_experts, bias=False + ) self.w1 = nn.Parameter( - torch.empty(self.num_total_experts, - self.intermediate_size, - self.hidden_size, - device=current_platform.device_type, - dtype=self.params_dtype)) + torch.empty( + self.num_total_experts, + self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) self.w2 = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device=current_platform.device_type, - dtype=self.params_dtype)) + torch.empty( + self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) self.bias = nn.Parameter(torch.zeros(self.hidden_size)) - set_weight_attrs(self.w1, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2, { - "weight_loader": self.weight_loader, - }) + set_weight_attrs( + self.w1, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2, + { + "weight_loader": self.weight_loader, + }, + ) def weight_loader( self, @@ -293,10 +321,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # FIXME(Isotr0py): This implementation is too tricky, # we should use FusedMoE instead in the future # after supporting ungated activation for it. 
- topk_weights, topk_ids, _ = fused_topk(hidden_states, - router_logits, - self.top_k, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk( + hidden_states, router_logits, self.top_k, renormalize=False + ) final_hidden_states = torch.ops.vllm.outplace_fused_experts( hidden_states=hidden_states, @@ -308,22 +335,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) + self.bias class BertWithRopeBlock(nn.Module): - - def __init__(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - moe: bool = False, - bias: bool = True, - rotary_kwargs: Optional[dict] = None, - prefix: str = ""): + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + moe: bool = False, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = "", + ): super().__init__() self.attn = BertWithRopeAttention( hidden_size=config.hidden_size, @@ -332,14 +359,17 @@ def __init__(self, quant_config=quant_config, bias=bias, rotary_kwargs=rotary_kwargs, - prefix=f"{prefix}.attention") + prefix=f"{prefix}.attention", + ) if moe: - self.mlp = NomicMoE(num_experts=config.num_experts, - top_k=config.moe_top_k, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act) + self.mlp = NomicMoE( + num_experts=config.num_experts, + top_k=config.moe_top_k, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) else: if config.hidden_act in ["silu", "geglu"]: self.mlp = BertWithRopeGatedMLP( @@ -348,7 +378,8 @@ def __init__(self, hidden_act=config.hidden_act, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + ) else: self.mlp = BertWithRopeMLP( hidden_size=config.hidden_size, @@ -356,12 +387,11 @@ def __init__(self, hidden_act=config.hidden_act, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + ) - self.attn_ln = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.mlp_ln = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor): attn_output = self.attn(positions, hidden_states) @@ -372,27 +402,32 @@ def forward(self, positions: torch.Tensor, hidden_states: torch.Tensor): class BertWithRopeEncoder(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - bias: bool = True, - rotary_kwargs: Optional[dict] = None, - prefix: str = ""): + def __init__( + self, + vllm_config: VllmConfig, + bias: bool = True, + rotary_kwargs: Optional[dict] = None, + prefix: str = "", + ): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config every_n = getattr(config, "moe_every_n_layers", 0) - self.layers = nn.ModuleList([ - BertWithRopeBlock(config=config, - cache_config=cache_config, - quant_config=quant_config, - bias=bias, - moe=every_n > 0 and (layer_idx % every_n == 1), - 
rotary_kwargs=rotary_kwargs, - prefix=f"{prefix}.layer.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + BertWithRopeBlock( + config=config, + cache_config=cache_config, + quant_config=quant_config, + bias=bias, + moe=every_n > 0 and (layer_idx % every_n == 1), + rotary_kwargs=rotary_kwargs, + prefix=f"{prefix}.layer.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -409,11 +444,13 @@ def forward( class BertWithRope(nn.Module, SupportsQuant): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - add_pooling_layer: bool = False): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + add_pooling_layer: bool = False, + ): super().__init__() self.vllm_config = vllm_config self.add_pooling_layer = add_pooling_layer @@ -423,7 +460,8 @@ def __init__(self, vllm_config=vllm_config, bias=getattr(self.config, "bias", True), rotary_kwargs=self.config.rotary_kwargs, - prefix=f"{prefix}.encoder") + prefix=f"{prefix}.encoder", + ) self.pooler = BertPooler(self.config) if add_pooling_layer else None def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -440,12 +478,12 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embeddings(input_ids=input_ids, - token_type_ids=token_type_ids) + hidden_states = self.embeddings( + input_ids=input_ids, token_type_ids=token_type_ids + ) return self.encoder(positions, hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.hf_to_vllm_mapper.apply(weights) if self.config.hidden_act in ["silu", "geglu"]: @@ -462,7 +500,7 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if not self.add_pooling_layer and "pooler" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -478,8 +516,7 @@ def load_weights(self, weights: Iterable[tuple[str, if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if name.endswith((".w1", ".w2")): # Nomic-MoE has fused experts weights weight_loader(param, loaded_weight, name) @@ -506,7 +543,8 @@ class NomicBertModel(BertWithRope): "experts.mlp.": "", "experts.": "", "router.layer": "router", - }) + } + ) class GteNewModel(BertWithRope): @@ -518,7 +556,8 @@ class GteNewModel(BertWithRope): "layer": "layers", "attention.qkv_proj": "attn.qkv_proj", "attention.o_proj": "attn.out_proj", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs): super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) @@ -539,15 +578,13 @@ def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]): else: yield name, weight - def ignore_unnecessary_layers(self, - weights: Iterable[tuple[str, torch.Tensor]]): + def ignore_unnecessary_layers(self, weights: Iterable[tuple[str, torch.Tensor]]): for name, weight in weights: if name.startswith("classifier"): continue yield name, weight - def 
load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.ignore_unnecessary_layers(weights) weights = self.split_up_gate_proj(weights) return super().load_weights(weights) @@ -561,7 +598,8 @@ class SnowflakeGteNewModel(GteNewModel): "layer": "layers", "attention.qkv_proj": "attn.qkv_proj", "attention.o_proj": "attn.out_proj", - }) + } + ) class JinaRobertaModel(BertWithRope): @@ -576,11 +614,11 @@ class JinaRobertaModel(BertWithRope): "mlp.fc1.": "mlp.up_proj.", "mlp.fc2": "mlp.down_proj", "norm2": "mlp_ln", - }) + } + ) @torch.inference_mode() - def jina_merge_lora_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]): + def jina_merge_lora_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # use for jina-embeddings-v3 # Merge Lora weights into a single weight tensor. # This is a temporary solution until we have a better way to handle @@ -601,7 +639,7 @@ def jina_merge_lora_weights(self, weights: Iterable[tuple[str, if o in name: dtype = weights[name].dtype shape = weights[name].shape - weight_name = name[:-len(o)] + weight_name = name[: -len(o)] if "embeddings" in weight_name: B = weights[weight_name + a][i].to(device).float() @@ -610,20 +648,23 @@ def jina_merge_lora_weights(self, weights: Iterable[tuple[str, B = weights[weight_name + b][i].to(device).float() A = weights[weight_name + a][i].to(device).float() - weight = (weights[weight_name + o].to(device) + - torch.matmul(B, A).view(shape) * scaling) + weight = ( + weights[weight_name + o].to(device) + + torch.matmul(B, A).view(shape) * scaling + ) weight = weight.cpu().to(dtype) weights[weight_name.replace(".parametrizations", "")] = weight - del weights[weight_name + o], weights[weight_name + - a], weights[weight_name + - b] + del ( + weights[weight_name + o], + weights[weight_name + a], + weights[weight_name + b], + ) return [(name, weight) for name, weight in weights.items()] - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.jina_merge_lora_weights(weights) return super().load_weights(weights) @@ -637,9 +678,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - self.new = GteNewModel(vllm_config=vllm_config, - prefix=prefix, - add_pooling_layer=True) + self.new = GteNewModel( + vllm_config=vllm_config, prefix=prefix, add_pooling_layer=True + ) self.classifier = ReplicatedLinear( config.hidden_size, config.num_labels, @@ -647,29 +688,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config, params_dtype=vllm_config.model_config.head_dtype, prefix=maybe_prefix(prefix, "classifier"), - return_bias=False) + return_bias=False, + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=self.new.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=self.new.pooler, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + 
"classify": ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=self.new.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) @@ -686,8 +729,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - return self.new(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + return self.new( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 2b457fd8a5b2..aa361e0a2a39 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Minimal implementation of BlipVisionModel intended to be only used +"""Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" + from collections.abc import Iterable from typing import Optional, Union @@ -12,9 +13,11 @@ from vllm.attention.layer import MultiHeadAttention from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -27,14 +30,14 @@ def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: - grid_length = get_blip_patch_grid_length(image_size=image_size, - patch_size=patch_size) + grid_length = get_blip_patch_grid_length( + image_size=image_size, patch_size=patch_size + ) return grid_length * grid_length # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa class BlipVisionEmbeddings(nn.Module): - def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]): super().__init__() @@ -52,25 +55,28 @@ def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]): stride=self.patch_size, ) - self.num_patches = get_blip_num_patches(image_size=self.image_size, - patch_size=self.patch_size) + self.num_patches = get_blip_num_patches( + image_size=self.image_size, patch_size=self.patch_size + ) self.num_positions = self.num_patches + 1 self.position_embedding = nn.Parameter( - torch.randn(1, self.num_positions, self.embed_dim)) + torch.randn(1, self.num_positions, self.embed_dim) + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size = pixel_values.shape[0] target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to( - dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = self.patch_embedding( + 
pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) position_embeds = self.position_embedding.to(target_dtype) - embeddings = embeddings + position_embeds[:, :embeddings.size(1), :] + embeddings = embeddings + position_embeds[:, : embeddings.size(1), :] return embeddings @@ -93,7 +99,8 @@ def __init__( raise ValueError( "embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout @@ -115,12 +122,16 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) def forward( self, @@ -137,7 +148,6 @@ def forward( class BlipMLP(nn.Module): - def __init__( self, config: BlipVisionConfig, @@ -149,16 +159,20 @@ def __init__( self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -169,7 +183,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BlipEncoderLayer(nn.Module): - def __init__( self, config: BlipVisionConfig, @@ -184,13 +197,9 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - self.layer_norm1 = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.mlp = BlipMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.layer_norm2 = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = BlipMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residual = hidden_states @@ -209,7 +218,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BlipEncoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` self + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`BlipEncoderLayer`]. 
Args: @@ -232,12 +241,16 @@ def __init__( else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - BlipEncoderLayer(config=config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + BlipEncoderLayer( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward(self, inputs_embeds: torch.Tensor): hidden_states = inputs_embeds @@ -284,8 +297,9 @@ def __init__( require_post_norm = len(self.encoder.layers) == num_hidden_layers if require_post_norm: - self.post_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) else: self.post_layernorm = None @@ -298,8 +312,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return self.post_layernorm(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -312,8 +325,7 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: # post_layernorm is not needed in BlipVisionModel - if (name.startswith("post_layernorm") - and self.post_layernorm is None): + if name.startswith("post_layernorm") and self.post_layernorm is None: continue # omit layers when num_hidden_layers_override is set @@ -322,7 +334,7 @@ def load_weights(self, weights: Iterable[tuple[str, if layer_idx >= layer_count: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -332,8 +344,7 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 3d057654cca7..8e94d5935026 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -6,27 +6,42 @@ import torch import torch.nn as nn -from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, - apply_chunking_to_forward) +from transformers import ( + BatchFeature, + Blip2Config, + Blip2QFormerConfig, + apply_chunking_to_forward, +) from vllm.config import CacheConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptIndexTargets, - PromptInsertion, PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + 
BaseProcessingInfo, + PromptIndexTargets, + PromptInsertion, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .blip import BlipVisionModel -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, - SupportsQuant) +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix @@ -38,6 +53,7 @@ class Blip2ImagePixelInputs(TensorSchema): - h: Height of each image - w: Width of each image """ + type: Literal["pixel_values"] data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] @@ -49,6 +65,7 @@ class Blip2ImageEmbeddingInputs(TensorSchema): - f: Image feature size - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["image_embeds"] data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")] @@ -57,7 +74,6 @@ class Blip2ImageEmbeddingInputs(TensorSchema): class Blip2QFormerMultiHeadAttention(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -78,8 +94,7 @@ def __init__( ) self.num_attention_heads = config.num_attention_heads - self.attention_head_size = (config.hidden_size // - config.num_attention_heads) + self.attention_head_size = config.hidden_size // config.num_attention_heads self.all_head_size = self.num_attention_heads * self.attention_head_size self.scaling = self.attention_head_size**-0.5 @@ -91,18 +106,18 @@ def __init__( self.key = nn.Linear(kv_hidden_size, self.all_head_size) self.value = nn.Linear(kv_hidden_size, self.all_head_size) - self.position_embedding_type = getattr(config, - "position_embedding_type", - "absolute") + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) if self.position_embedding_type != "absolute": - raise NotImplementedError("Unsupported position_embedding_type: " - f"{self.position_embedding_type}") + raise NotImplementedError( + f"Unsupported position_embedding_type: {self.position_embedding_type}" + ) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x): - x = x.view(*x.size()[:-1], self.num_attention_heads, - self.attention_head_size) + x = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size) return x.permute(0, 2, 1, 3) def forward( @@ -113,10 +128,8 @@ def forward( is_cross_attention = encoder_hidden_states is not None if is_cross_attention: - key_layer = self.transpose_for_scores( - self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores( - self.value(encoder_hidden_states)) + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) @@ -125,10 +138,8 @@ def forward( query_layer = self.transpose_for_scores(mixed_query_layer) - attention_scores = torch.matmul(query_layer, - key_layer.transpose(-1, -2)) - attention_probs = torch.softmax(attention_scores * self.scaling, - dim=-1) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_probs = torch.softmax(attention_scores * self.scaling, dim=-1) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer 
paper. @@ -137,20 +148,19 @@ def forward( context_layer = torch.matmul(attention_probs_dropped, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - context_layer = context_layer.view(*context_layer.size()[:-2], - self.all_head_size) + context_layer = context_layer.view( + *context_layer.size()[:-2], self.all_head_size + ) return context_layer class Blip2QFormerSelfOutput(nn.Module): - def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( @@ -165,7 +175,6 @@ def forward( class Blip2QFormerAttention(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -202,7 +211,6 @@ def forward( class Blip2QFormerIntermediate(nn.Module): - def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() @@ -216,13 +224,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Blip2QFormerOutput(nn.Module): - def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None: super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward( @@ -237,7 +243,6 @@ def forward( class Blip2QFormerLayer(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -251,10 +256,12 @@ def __init__( self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = Blip2QFormerAttention(config, - quant_config=quant_config, - cache_config=cache_config, - prefix=f"{prefix}.attention") + self.attention = Blip2QFormerAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.attention", + ) self.layer_idx = layer_idx @@ -264,15 +271,16 @@ def __init__( quant_config=quant_config, cache_config=cache_config, is_cross_attention=True, - prefix=f"{prefix}.crossattention") + prefix=f"{prefix}.crossattention", + ) self.has_cross_attention = True else: self.has_cross_attention = False self.intermediate_query = Blip2QFormerIntermediate( - config, prefix=f"{prefix}.intermediate_query") - self.output_query = Blip2QFormerOutput(config, - prefix=f"{prefix}.output_query") + config, prefix=f"{prefix}.intermediate_query" + ) + self.output_query = Blip2QFormerOutput(config, prefix=f"{prefix}.output_query") def forward( self, @@ -305,8 +313,7 @@ def forward( self.seq_len_dim, attention_output[:, query_length:, :], ) - layer_output = torch.cat([layer_output, layer_output_text], - dim=1) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) else: layer_output = apply_chunking_to_forward( self.feed_forward_chunk, @@ -317,21 +324,18 @@ def forward( return layer_output - def feed_forward_chunk(self, - attention_output: torch.Tensor) -> torch.Tensor: + def feed_forward_chunk(self, attention_output: torch.Tensor) -> torch.Tensor: intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output - def feed_forward_chunk_query( - self, attention_output: torch.Tensor) -> torch.Tensor: + def feed_forward_chunk_query(self, 
attention_output: torch.Tensor) -> torch.Tensor: intermediate_output = self.intermediate_query(attention_output) layer_output = self.output_query(intermediate_output, attention_output) return layer_output class Blip2QFormerEncoder(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -344,14 +348,18 @@ def __init__( self.config = config - self.layer = nn.ModuleList([ - Blip2QFormerLayer(config, - quant_config=quant_config, - cache_config=cache_config, - layer_idx=layer_idx, - prefix=f"{prefix}.layer.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ]) + self.layer = nn.ModuleList( + [ + Blip2QFormerLayer( + config, + quant_config=quant_config, + cache_config=cache_config, + layer_idx=layer_idx, + prefix=f"{prefix}.layer.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -373,7 +381,6 @@ def forward( # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025 class Blip2QFormerModel(nn.Module): - def __init__( self, config: Blip2QFormerConfig, @@ -386,14 +393,15 @@ def __init__( self.config = config - self.layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.encoder = Blip2QFormerEncoder(config, - quant_config=quant_config, - cache_config=cache_config, - prefix=f"{prefix}.encoder") + self.encoder = Blip2QFormerEncoder( + config, + quant_config=quant_config, + cache_config=cache_config, + prefix=f"{prefix}.encoder", + ) def forward( self, @@ -415,7 +423,6 @@ def forward( class Blip2ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Blip2Config) @@ -428,7 +435,6 @@ def get_num_image_tokens(self) -> int: class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -447,16 +453,16 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=max_image_size, - height=max_image_size, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=max_image_size, + height=max_image_size, + num_images=num_images, + overrides=image_overrides, + ) } class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): - def _call_hf_processor( self, prompt: str, @@ -509,11 +515,14 @@ def _get_prompt_updates( ] -@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor, - info=Blip2ProcessingInfo, - dummy_inputs=Blip2DummyInputsBuilder) -class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, - SupportsQuant): +@MULTIMODAL_REGISTRY.register_processor( + Blip2MultiModalProcessor, + info=Blip2ProcessingInfo, + dummy_inputs=Blip2DummyInputsBuilder, +) +class Blip2ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant +): merge_by_field_config = True @classmethod @@ -524,7 +533,6 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: raise ValueError("Only image modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -537,13 +545,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vision_model = 
BlipVisionModel(config.vision_config, quant_config) self.query_tokens = nn.Parameter( - torch.zeros(1, config.num_query_tokens, - config.qformer_config.hidden_size)) + torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size) + ) - self.qformer = Blip2QFormerModel(config.qformer_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.qformer") + self.qformer = Blip2QFormerModel( + config.qformer_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.qformer", + ) self.language_projection = nn.Linear( config.qformer_config.hidden_size, @@ -558,10 +568,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Blip2ImageInputs]: + self, **kwargs: object + ) -> Optional[Blip2ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -570,12 +582,11 @@ def _parse_and_validate_image_input( if pixel_values is not None: expected_h = expected_w = self.config.vision_config.image_size - return Blip2ImagePixelInputs(type="pixel_values", - data=pixel_values, - resolve_bindings={ - "h": expected_h, - "w": expected_w - }) + return Blip2ImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) if image_embeds is not None: return Blip2ImageEmbeddingInputs( @@ -585,34 +596,30 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") - def _image_pixels_to_features(self, vision_model: BlipVisionModel, - pixel_values: torch.Tensor) -> torch.Tensor: - + def _image_pixels_to_features( + self, vision_model: BlipVisionModel, pixel_values: torch.Tensor + ) -> torch.Tensor: # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower image_features = vision_model(pixel_values) return image_features - def _process_image_pixels(self, - inputs: Blip2ImagePixelInputs) -> torch.Tensor: + def _process_image_pixels(self, inputs: Blip2ImagePixelInputs) -> torch.Tensor: assert self.vision_model is not None pixel_values = inputs["data"] return self._image_pixels_to_features(self.vision_model, pixel_values) - def _process_image_input(self, - image_input: Blip2ImageInputs) -> torch.Tensor: - + def _process_image_input(self, image_input: Blip2ImageInputs) -> torch.Tensor: if image_input["type"] == "image_embeds": return image_input["data"] assert self.vision_model is not None image_features = self._process_image_pixels(image_input) - query_tokens = self.query_tokens.expand(image_features.shape[0], -1, - -1) + query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1) query_output = self.qformer( query_embeds=query_tokens, encoder_hidden_states=image_features, @@ -623,8 +630,7 @@ def _process_image_input(self, def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -651,7 +657,7 @@ def forward( `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`. 
To reserve space in KV cache, we have to insert placeholder tokens - before they are inputted to the model, so the input processor prepends + before they are inputted to the model, so the input processor prepends dummy tokens (denoted as `50265`), resulting in: `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`. @@ -664,7 +670,7 @@ def forward( Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. - + Info: [`Blip2ImageInputs`][vllm.model_executor.models.blip2.Blip2ImageInputs] """ @@ -672,10 +678,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -685,7 +690,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 30816f72a267..4a814fc4020d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable from itertools import islice @@ -30,29 +31,40 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP, SupportsQuant -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) base = torch.tensor( - 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, ) powers = torch.arange(1, 1 + 
closest_power_of_2, dtype=torch.int32) @@ -60,22 +72,20 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, ) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - closest_power_of_2) - extra_powers = torch.arange(start=1, - end=1 + 2 * num_remaining_heads, - step=2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes class BloomAttention(nn.Module): - def __init__( self, config: BloomConfig, @@ -115,13 +125,15 @@ def __init__( alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -137,7 +149,6 @@ def forward( class BloomMLP(nn.Module): - def __init__( self, config: BloomConfig, @@ -165,7 +176,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class BloomBlock(nn.Module): - def __init__( self, config: BloomConfig, @@ -176,17 +186,17 @@ def __init__( super().__init__() hidden_size = config.hidden_size - self.input_layernorm = nn.LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attention") + self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.self_attention = BloomAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attention" + ) self.post_attention_layernorm = nn.LayerNorm( - hidden_size, eps=config.layer_norm_epsilon) + hidden_size, eps=config.layer_norm_epsilon + ) self.mlp = BloomMLP(config, quant_config) self.apply_residual_connection_post_layernorm = ( - config.apply_residual_connection_post_layernorm) + config.apply_residual_connection_post_layernorm + ) def forward( self, @@ -223,7 +233,6 @@ def forward( @support_torch_compile class BloomModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -240,20 +249,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_dim, ) self.word_embeddings_layernorm = nn.LayerNorm( - self.embed_dim, eps=config.layer_norm_epsilon) + self.embed_dim, eps=config.layer_norm_epsilon + ) # Transformer blocks self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, lambda prefix: BloomBlock( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.h", + ) # Final Layer Norm self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = 
make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.word_embeddings(input_ids) @@ -281,8 +293,7 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -300,14 +311,14 @@ def load_weights(self, weights: Iterable[tuple[str, if output_dim is not None: loaded_weight_shape = loaded_weight.shape loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) + loaded_weight_shape[:output_dim] + + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1 :] + ) + loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1) loaded_weight = loaded_weight.reshape(loaded_weight_shape) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -315,27 +326,28 @@ def load_weights(self, weights: Iterable[tuple[str, class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = BloomModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = BloomModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) if self.config.tie_word_embeddings: self.lm_head = self.transformer.word_embeddings else: - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -347,8 +359,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -358,17 +371,16 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"]) weights = _add_transformer_prefix(weights) return loader.load_weights(weights) def _add_transformer_prefix( - weights: 
Iterable[tuple[str, torch.Tensor]] + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, torch.Tensor]]: for name, tensor in weights: - if not name.startswith('transformer.'): - name = 'transformer.' + name + if not name.startswith("transformer."): + name = "transformer." + name yield name, tensor diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index b1432dcb9d6d..d8756e236f4c 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -9,8 +9,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, - ChameleonVQVAEConfig) +from transformers import ( + BatchFeature, + ChameleonConfig, + ChameleonProcessor, + ChameleonVQVAEConfig, +) from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig @@ -19,33 +23,53 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, row_parallel_weight_loader) + default_weight_loader, + row_parallel_weight_loader, +) from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, - SupportsQuant) -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import ( + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -58,12 +82,12 @@ class ChameleonImagePixelInputs(TensorSchema): - h: Height of each image - w: Width of each image """ + type: Literal["pixel_values"] data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] class ChameleonProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(ChameleonConfig) @@ -78,9 
+102,7 @@ def get_num_image_tokens(self) -> int: return processor.image_seq_length -class ChameleonDummyInputsBuilder( - BaseDummyInputsBuilder[ChameleonProcessingInfo]): - +class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -103,17 +125,16 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=width, - height=height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=width, + height=height, + num_images=num_images, + overrides=image_overrides, + ) } -class ChameleonMultiModalProcessor( - BaseMultiModalProcessor[ChameleonProcessingInfo]): - +class ChameleonMultiModalProcessor(BaseMultiModalProcessor[ChameleonProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -182,29 +203,23 @@ def _get_prompt_updates( class ChameleonLayerNorm(nn.LayerNorm): - def __init__(self, hidden_size, *args, **kwargs): super().__init__(hidden_size, *args, **kwargs) - self.normalized_shape = (hidden_size[-1], ) + self.normalized_shape = (hidden_size[-1],) - set_weight_attrs(self.weight, - {"weight_loader": row_parallel_weight_loader}) - set_weight_attrs(self.bias, - {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.bias, {"weight_loader": row_parallel_weight_loader}) def forward(self, hidden_states): - hidden_states = F.layer_norm(hidden_states, - self.normalized_shape, - None, - None, - eps=1e-5) + hidden_states = F.layer_norm( + hidden_states, self.normalized_shape, None, None, eps=1e-5 + ) hidden_states = hidden_states * self.weight + self.bias return hidden_states # Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP class ChameleonMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -218,14 +233,18 @@ def __init__( input_size=hidden_size, output_sizes=[intermediate_size] * 2, bias=bias, - quant_config=quant_config) - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config) + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -237,7 +256,6 @@ def forward(self, x): # Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa class ChameleonAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -298,16 +316,19 @@ def __init__( rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) - def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: # reshape for layernorm q = q.reshape(-1, self.num_heads, self.head_dim) k = k.reshape(-1, self.num_kv_heads, self.head_dim) @@ -333,7 +354,6 @@ def forward( class ChameleonDecoderLayer(nn.Module): - def __init__( self, config: ChameleonConfig, @@ -346,17 +366,19 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 4096) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 4096) self.self_attn = ChameleonAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -372,10 +394,10 @@ def __init__( quant_config=quant_config, bias=getattr(config, "mlp_bias", False), ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -383,28 +405,24 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class ChameleonSwinDecoderLayer(nn.Module): - def __init__( self, config: ChameleonConfig, @@ -417,17 +435,19 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = 
getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 4096) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 4096) self.self_attn = ChameleonAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -443,10 +463,10 @@ def __init__( quant_config=quant_config, bias=getattr(config, "mlp_bias", False), ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -454,7 +474,6 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - residual = hidden_states hidden_states = self.self_attn( positions=positions, @@ -475,7 +494,6 @@ def forward( # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa class ChameleonVQVAEVectorQuantizer(nn.Module): - def __init__(self, config: ChameleonVQVAEConfig): super().__init__() self.num_embeddings = config.num_embeddings @@ -491,55 +509,52 @@ def forward(self, hidden_state: torch.Tensor): # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z distances = ( - torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + - torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, - self.embedding.weight.transpose(0, 1))) + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", + hidden_state_flattened, + self.embedding.weight.transpose(0, 1), + ) + ) min_encoding_indices = torch.argmin(distances, dim=1) hidden_state_quant = self.embedding(min_encoding_indices).view( - hidden_state.shape) + hidden_state.shape + ) # compute loss for embedding - loss = torch.mean((hidden_state_quant.detach() - hidden_state)** - 2) + self.beta * torch.mean( - (hidden_state_quant - hidden_state.detach())**2) + loss = torch.mean( + (hidden_state_quant.detach() - hidden_state) ** 2 + ) + self.beta * torch.mean((hidden_state_quant - hidden_state.detach()) ** 2) # preserve gradients - hidden_state_quant = hidden_state + (hidden_state_quant - - hidden_state).detach() + hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach() # reshape back to match original input shape - hidden_state_quant = hidden_state_quant.permute(0, 3, 1, - 2).contiguous() + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() return hidden_state_quant, loss, min_encoding_indices # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa class ChameleonVQVAEEncoderConvDownsample(nn.Module): - def __init__(self, in_channels: int): 
super().__init__() - self.conv = nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) def forward(self, hidden_states: torch.Tensor): # no asymmetric padding in torch conv, must do it ourselves - hidden_states = F.pad(hidden_states, - pad=(0, 1, 0, 1), - mode="constant", - value=0) + hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0) hidden_states = self.conv(hidden_states) return hidden_states # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa class ChameleonVQVAEEncoderResnetBlock(nn.Module): - def __init__( self, config: ChameleonVQVAEConfig, @@ -549,42 +564,31 @@ def __init__( ): super().__init__() self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None \ - else out_channels + self.out_channels = in_channels if out_channels is None else out_channels self.use_conv_shortcut = conv_shortcut - self.norm1 = torch.nn.GroupNorm(num_groups=32, - num_channels=in_channels, - eps=1e-6, - affine=True) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - self.norm2 = torch.nn.GroupNorm(num_groups=32, - num_channels=out_channels, - eps=1e-6, - affine=True) + self.norm1 = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = torch.nn.GroupNorm( + num_groups=32, num_channels=out_channels, eps=1e-6, affine=True + ) self.dropout = torch.nn.Dropout(config.dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, hidden_states: torch.Tensor): residual = hidden_states @@ -608,35 +612,25 @@ def forward(self, hidden_states: torch.Tensor): # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa class ChameleonVQVAEEncoderAttnBlock(nn.Module): - def __init__(self, in_channels: int): super().__init__() self.in_channels = in_channels - self.norm = torch.nn.GroupNorm(num_groups=32, - num_channels=in_channels, - eps=1e-6, - affine=True) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, 
kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) def forward(self, hidden_states: torch.Tensor): residual = hidden_states @@ -647,20 +641,20 @@ def forward(self, hidden_states: torch.Tensor): # compute attention batch_size, channels, height, width = query_states.shape - query_states = query_states.reshape(batch_size, channels, - height * width).permute(0, 2, 1) + query_states = query_states.reshape( + batch_size, channels, height * width + ).permute(0, 2, 1) key_states = key_states.reshape(batch_size, channels, height * width) attn_weights = torch.bmm(query_states, key_states) - attn_weights = attn_weights * (int(channels)**(-0.5)) + attn_weights = attn_weights * (int(channels) ** (-0.5)) attn_weights = F.softmax(attn_weights, dim=2) # attend to values - value_states = value_states.reshape(batch_size, channels, - height * width) + value_states = value_states.reshape(batch_size, channels, height * width) attn_weights = attn_weights.permute(0, 2, 1) - attn_output = torch.bmm(value_states, - attn_weights).reshape(batch_size, channels, - height, width) + attn_output = torch.bmm(value_states, attn_weights).reshape( + batch_size, channels, height, width + ) attn_output = self.proj_out(attn_output) return residual + attn_output @@ -668,7 +662,6 @@ def forward(self, hidden_states: torch.Tensor): # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa class ChameleonVQVAEEncoder(nn.Module): - def __init__(self, config: ChameleonVQVAEConfig): super().__init__() @@ -681,14 +674,12 @@ def __init__(self, config: ChameleonVQVAEConfig): latent_channels = config.latent_channels channel_multiplier = config.channel_multiplier - self.conv_in = torch.nn.Conv2d(in_channels, - base_channels, - kernel_size=3, - stride=1, - padding=1) + self.conv_in = torch.nn.Conv2d( + in_channels, base_channels, kernel_size=3, stride=1, padding=1 + ) curr_res = resolution - in_channel_multiplier = (1, ) + tuple(channel_multiplier) + in_channel_multiplier = (1,) + tuple(channel_multiplier) self.in_channel_multiplier = in_channel_multiplier self.down = nn.ModuleList() for i_level in range(self.num_resolutions): @@ -702,11 +693,14 @@ def __init__(self, config: ChameleonVQVAEConfig): config=config, in_channels=block_in, out_channels=block_out, - )) + ) + ) block_in = block_out - if (config.attn_resolutions is not None - and curr_res in config.attn_resolutions - and config.attn_type == "vanilla"): + if ( + config.attn_resolutions is not None + and curr_res in config.attn_resolutions + and config.attn_type == "vanilla" + ): attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) down = nn.Module() @@ -723,18 +717,20 @@ def __init__(self, config: ChameleonVQVAEConfig): in_channels=block_in, out_channels=block_in, ) - self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock( - block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.attn_1 = ( + ChameleonVQVAEEncoderAttnBlock(block_in) + if config.attn_type == "vanilla" + else nn.Identity() + ) self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( config=config, in_channels=block_in, out_channels=block_in, ) - self.norm_out = torch.nn.GroupNorm(num_groups=32, - num_channels=block_in, - eps=1e-6, - affine=True) + self.norm_out = torch.nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) self.conv_out = torch.nn.Conv2d( block_in, 2 * 
latent_channels if double_latent else latent_channels, @@ -750,15 +746,12 @@ def forward(self, pixel_values: torch.Tensor): hidden_states = [self.conv_in(pixel_values)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): - hidden_state = self.down[i_level].block[i_block]( - hidden_states[-1]) + hidden_state = self.down[i_level].block[i_block](hidden_states[-1]) if len(self.down[i_level].attn) > 0: - hidden_state = self.down[i_level].attn[i_block]( - hidden_state) + hidden_state = self.down[i_level].attn[i_block](hidden_state) hidden_states.append(hidden_state) if i_level != self.num_resolutions - 1: - hidden_states.append(self.down[i_level].downsample( - hidden_states[-1])) + hidden_states.append(self.down[i_level].downsample(hidden_states[-1])) # middle last_hidden_state = hidden_states[-1] @@ -775,15 +768,14 @@ def forward(self, pixel_values: torch.Tensor): # Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa class ChameleonVQVAE(nn.Module): - def __init__(self, config: ChameleonVQVAEConfig): super().__init__() self.encoder = ChameleonVQVAEEncoder(config) self.quantize = ChameleonVQVAEVectorQuantizer(config) - self.quant_conv = torch.nn.Conv2d(config.latent_channels, - config.embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, - config.latent_channels, 1) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d( + config.embed_dim, config.latent_channels, 1 + ) self.eval() # Chameleon's VQ model is frozen def encode( @@ -811,10 +803,9 @@ def val2name(self): @cached_property def image_tokens(self): - return sorted([ - val for name, val in self.vocab_map.items() - if name.startswith("IMGIMG") - ]) + return sorted( + [val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")] + ) @cached_property def bpe2img(self): @@ -822,13 +813,10 @@ def bpe2img(self): def remap(old_name: str) -> str: return "".join( - img_tkn_chr_mapping.get(c, c) - for c in old_name[len("IMGIMG"):-1]) + img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1] + ) - return { - tok: int(remap(self.val2name[tok])) - for tok in self.image_tokens - } + return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens} @cached_property def img2bpe(self): @@ -837,7 +825,8 @@ def img2bpe(self): @cached_property def bpe2img_search_tensors(self): return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor( - sorted(self.bpe2img.values())) + sorted(self.bpe2img.values()) + ) @cached_property def img2bpe_mapping_tensor(self): @@ -853,7 +842,6 @@ def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: class ChameleonModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -867,25 +855,29 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size, config.hidden_size, ) - self.vocabulary_mapping = ChameleonImageVocabularyMapping( - config.vocabulary_map) - decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \ + self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map) + decoder_layer = ( + ChameleonDecoderLayer + if not self.config.swin_norm else ChameleonSwinDecoderLayer + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: decoder_layer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: decoder_layer( + 
config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.vqmodel = ChameleonVQVAE(config.vq_config) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -926,10 +918,9 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -937,14 +928,16 @@ def forward( @MULTIMODAL_REGISTRY.register_processor( ChameleonMultiModalProcessor, info=ChameleonProcessingInfo, - dummy_inputs=ChameleonDummyInputsBuilder) -class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP, SupportsQuant): + dummy_inputs=ChameleonDummyInputsBuilder, +) +class ChameleonForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant +): merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } @classmethod @@ -960,8 +953,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config - self.model = ChameleonModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = ChameleonModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, @@ -972,13 +966,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]: + self, **kwargs: object + ) -> Optional[ChameleonImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) if pixel_values is None: @@ -987,24 +984,23 @@ def _parse_and_validate_image_input( vq_config: ChameleonVQVAEConfig = self.config.vq_config expected_h = expected_w = vq_config.resolution - return ChameleonImagePixelInputs(type="pixel_values", - data=pixel_values, - resolve_bindings={ - "h": expected_h, - "w": expected_w - }) + return ChameleonImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) def get_language_model(self) -> torch.nn.Module: return self.model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> 
MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] assert self.model.vqmodel is not None - image_tokens = self.model.get_image_tokens(image_input["data"].to( - self.config.torch_dtype)) + image_tokens = self.model.get_image_tokens( + image_input["data"].to(self.config.torch_dtype) + ) vision_embeddings = self.model.get_input_embeddings(image_tokens) return vision_embeddings @@ -1016,14 +1012,12 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( @@ -1040,8 +1034,7 @@ def compute_logits( return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1056,8 +1049,7 @@ def load_weights(self, weights: Iterable[tuple[str, if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue @@ -1075,8 +1067,7 @@ def load_weights(self, weights: Iterable[tuple[str, # not vqvae for now. use_default_weight_loading = True else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -1096,7 +1087,8 @@ def load_weights(self, weights: Iterable[tuple[str, # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
kv-scale is not loaded.", # noqa: E501 @@ -1109,15 +1101,15 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) if use_default_weight_loading and name in params_dict: if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index c182201fe256..ece719df61f7 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -3,6 +3,7 @@ # Adapted from # https://github.com/zai-org/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" + import json from collections.abc import Iterable from itertools import islice @@ -18,26 +19,34 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GLMAttention(nn.Module): - def __init__( self, config: ChatGLMConfig, @@ -52,9 +61,11 @@ def __init__( assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.multi_query_attention = config.multi_query_attention - self.total_num_kv_heads = (config.multi_query_group_num - if config.multi_query_attention else - config.num_attention_heads) + self.total_num_kv_heads = ( + config.multi_query_group_num + if config.multi_query_attention + else config.num_attention_heads + ) if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. 
@@ -99,13 +110,15 @@ def __init__( base=10000 * rope_ratio, is_neox_style=is_neox_style, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -183,25 +196,27 @@ def __init__( ): super().__init__() self.apply_residual_connection_post_layernorm = ( - config.apply_residual_connection_post_layernorm) + config.apply_residual_connection_post_layernorm + ) self.fp32_residual_connection = config.fp32_residual_connection layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm # Layernorm on the input data. - self.input_layernorm = layer_norm_func(config.hidden_size, - eps=config.layernorm_epsilon) + self.input_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon + ) # Self attention. - self.self_attention = GLMAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attention") + self.self_attention = GLMAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attention" + ) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output self.post_attention_layernorm = layer_norm_func( - config.hidden_size, eps=config.layernorm_epsilon) + config.hidden_size, eps=config.layernorm_epsilon + ) # MLP self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp") @@ -261,8 +276,7 @@ def __init__( # Transformer layers. self.start_layer, self.end_layer, self.layers = make_layers( self.num_layers, - lambda prefix: GLMBlock( - config, cache_config, quant_config, prefix=prefix), + lambda prefix: GLMBlock(config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) @@ -270,11 +284,12 @@ def __init__( layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm # Final layer norm before output. 
self.final_layernorm = layer_norm_func( - config.hidden_size, eps=config.layernorm_epsilon) + config.hidden_size, eps=config.layernorm_epsilon + ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def forward( self, @@ -282,8 +297,9 @@ def forward( position_ids: torch.Tensor, ) -> Union[torch.Tensor, IntermediateTensors]: for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states = layer(hidden_states=hidden_states, - position_ids=position_ids) + hidden_states = layer( + hidden_states=hidden_states, position_ids=position_ids + ) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -298,8 +314,10 @@ def forward( @support_torch_compile class ChatGLMModel(nn.Module, SupportsQuant): packed_modules_mapping = { - "linear_proj.merged_proj": - ["linear_proj.gate_proj", "linear_proj.dense_h_to_4h"] + "linear_proj.merged_proj": [ + "linear_proj.gate_proj", + "linear_proj.dense_h_to_4h", + ] } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -311,26 +329,30 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config - self.embedding = VocabParallelEmbedding(config.padded_vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embedding") + self.embedding = VocabParallelEmbedding( + config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embedding", + ) self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, - cache_config, - quant_config, - prefix=f"{prefix}.encoder") + self.encoder = GLMTransformer( + config, cache_config, quant_config, prefix=f"{prefix}.encoder" + ) - self.output_layer = ParallelLMHead(config.padded_vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.output_layer") + self.output_layer = ParallelLMHead( + config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.output_layer", + ) self.make_empty_intermediate_tensors = ( - self.encoder.make_empty_intermediate_tensors) + self.encoder.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embedding(input_ids) @@ -360,8 +382,7 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("linear_proj.merged_proj", "linear_proj.gate_proj", 0), @@ -371,7 +392,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -392,8 +413,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) 
weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -401,7 +421,8 @@ def load_weights(self, weights: Iterable[tuple[str, class ChatGLMBaseModel(nn.Module): hf_to_vllm_mapper = WeightsMapper( - orig_to_new_substr={".word_embeddings": ""}, ) + orig_to_new_substr={".word_embeddings": ""}, + ) def __init__( self, @@ -420,18 +441,17 @@ def __init__( self.multimodal_config = multimodal_config self.quant_config = quant_config - self.max_position_embeddings = getattr(config, "max_sequence_length", - 8192) - self.transformer = transformer_type(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) + self.transformer = transformer_type( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) if self.config.tie_word_embeddings: - self.transformer.output_layer.weight = ( - self.transformer.embedding.weight) + self.transformer.output_layer.weight = self.transformer.embedding.weight self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -448,11 +468,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) -class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, - SupportsQuant): +class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsQuant): packed_modules_mapping = { "query_key_value": ["query_key_value"], - "dense_h_to_4h": ["dense_h_to_4h"] + "dense_h_to_4h": ["dense_h_to_4h"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -463,7 +482,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): "The configuration of this model indicates that it supports " "vision inputs, but you instantiated the text-only version " "of this model. 
Please use the vision model by setting " - f"`--hf-overrides '{json.dumps(hf_overrides)}'`") + f"`--hf-overrides '{json.dumps(hf_overrides)}'`" + ) super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -474,6 +494,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 451da2120048..f05d5c4cc1d8 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,37 +1,88 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Minimal implementation of CLIPVisionModel intended to be only used -within a vision language model.""" -from collections.abc import Iterable -from typing import Optional, Union +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn -from transformers import CLIPVisionConfig - +from transformers import ( + BatchFeature, + CLIPConfig, + CLIPProcessor, + CLIPTextConfig, + CLIPVisionConfig, +) + +from vllm.attention import Attention from vllm.attention.layer import MultiHeadAttention +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptReplacement, + PromptUpdate, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal +from .interfaces_base import default_pooling_type +from .utils import AutoWeightsLoader, maybe_prefix +from .vision import ( + VisionEncoderInfo, + VisionFeatureSelectStrategy, + VisionFeatureSelectStrategyStr, + get_num_selected_vision_tokens, + resolve_visual_encoder_outputs, +) + + +class CLIPImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + """ 
-from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy, - resolve_visual_encoder_outputs) + type: Literal["pixel_values"] + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): - def get_num_image_tokens( self, *, image_width: int, image_height: int, ) -> int: - return self.get_patch_grid_length()**2 + 1 + return self.get_patch_grid_length() ** 2 + 1 def get_image_size(self) -> int: return self.vision_config.image_size @@ -45,9 +96,215 @@ def get_patch_grid_length(self) -> int: return image_size // patch_size -# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa -class CLIPVisionEmbeddings(nn.Module): +_POOLING_TYPE_TO_STRATEGY: dict[str, VisionFeatureSelectStrategyStr] = { + "MEAN": "full", + "ALL": "full", + "CLS": "class", + # This lets us use the same pooling type for both text and image + "LAST": "class", +} + + +def _get_vision_feature_select_strategy(pooling_type: str): + try: + return _POOLING_TYPE_TO_STRATEGY[pooling_type] + except KeyError: + raise ValueError( + f"No feature selection strategy is defined for " + f"pooling_type: {pooling_type!r}" + ) from None + + +class CLIPProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(CLIPConfig) + + def get_vision_encoder_info(self): + return CLIPEncoderInfo(self.get_hf_config()) + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(CLIPProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + vision_encoder_info = self.get_vision_encoder_info() + + pooler_config = self.ctx.model_config.pooler_config + assert pooler_config is not None + + return get_num_selected_vision_tokens( + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), + _get_vision_feature_select_strategy(pooler_config.pooling_type), + ) + + def get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self.get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + +class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) + } + + +class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): + @cached_property + def image_token_id(self) -> int: + tokenizer = self.info.get_tokenizer() + dummy_token_id = 0 + + assert dummy_token_id not in tokenizer.all_special_ids + + return dummy_token_id + def apply( + self, + prompt: Union[str, 
list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + *, + mm_uuids: Optional[MultiModalUUIDDict] = None, + ) -> MultiModalInputs: + if prompt and mm_data: + raise ValueError( + "CLIP accepts text-only or image-only inputs, not both! " + "Image-only inputs means passing an image with an empty text " + "prompt." + ) + + if mm_data: + # For multi-modal data, the prompt after processing should + # only contain the dummy image tokens + tokenization_kwargs = { + **(tokenization_kwargs or {}), + "add_special_tokens": False, + } + + return super().apply( + prompt=prompt, + mm_data=mm_data, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, + ) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + image_token_id = self.image_token_id + + def get_replacement(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + return [image_token_id] * num_image_tokens + + return [ + PromptReplacement( + modality="image", + target=PromptIndexTargets.start(), + replacement=get_replacement, + ), + ] + + +# Adapted from: https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/models/clip/modeling_clip.py +class CLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + + embed_dim = config.hidden_size + + self.token_embedding = VocabParallelEmbedding(config.vocab_size, embed_dim) + self.position_embedding = VocabParallelEmbedding( + config.max_position_embeddings, embed_dim + ) + + def forward( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if inputs_embeds is None: + if input_ids is None: + raise ValueError( + "Either `input_ids` or `input_embeds` must be provided" + ) + + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class CLIPVisionEmbeddings(nn.Module): def __init__(self, config: CLIPVisionConfig): super().__init__() self.config = config @@ -66,19 +323,21 @@ def __init__(self, config: CLIPVisionConfig): bias=False, ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) - self.register_buffer("position_ids", - torch.arange(self.num_positions).expand((1, -1)), - persistent=False) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + 
torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size = pixel_values.shape[0] target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to( - dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) @@ -89,15 +348,16 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: class CLIPAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, + *, prefix: str = "", - ): + attn_cls: Union[type[Attention], type[MultiHeadAttention]], + ) -> None: super().__init__() + self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads @@ -106,7 +366,8 @@ def __init__( raise ValueError( "embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.qkv_proj = QKVParallelLinear( @@ -127,8 +388,12 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -145,26 +410,29 @@ def forward( class CLIPMLP(nn.Module): - def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -175,29 +443,26 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class CLIPEncoderLayer(nn.Module): - def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, + *, prefix: str = "", + attn_cls: Union[type[Attention], type[MultiHeadAttention]], ) -> None: super().__init__() self.self_attn = CLIPAttention( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + attn_cls=attn_cls, ) - self.layer_norm1 = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.mlp = CLIPMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.layer_norm2 = 
nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -223,10 +488,12 @@ class CLIPEncoder(nn.Module): def __init__( self, - config: CLIPVisionConfig, + config: Union[CLIPTextConfig, CLIPVisionConfig], quant_config: Optional[QuantizationConfig] = None, num_hidden_layers_override: Optional[int] = None, + *, prefix: str = "", + attn_cls: Union[type[Attention], type[MultiHeadAttention]], ) -> None: super().__init__() @@ -236,15 +503,22 @@ def __init__( num_hidden_layers = config.num_hidden_layers else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - CLIPEncoderLayer(config=config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + CLIPEncoderLayer( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + attn_cls=attn_cls, + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward( - self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool + self, + inputs_embeds: torch.Tensor, + return_all_hidden_states: bool, ) -> Union[torch.Tensor, list[torch.Tensor]]: hidden_states_pool = [inputs_embeds] hidden_states = inputs_embeds @@ -260,8 +534,85 @@ def forward( return hidden_states -class CLIPVisionTransformer(nn.Module): +class CLIPTextTransformer(nn.Module): + def __init__( + self, + config: CLIPTextConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPTextEmbeddings(config) + + self.encoder = CLIPEncoder( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + attn_cls=Attention, + ) + + self.final_layer_norm = nn.LayerNorm( + embed_dim, + eps=config.layer_norm_eps, + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embeddings.token_embedding(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=False, + ) + last_hidden_state = self.final_layer_norm(last_hidden_state) + + return last_hidden_state + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = 
getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class CLIPVisionTransformer(nn.Module): def __init__( self, config: CLIPVisionConfig, @@ -287,6 +638,7 @@ def __init__( quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", + attn_cls=MultiHeadAttention, ) num_hidden_layers = config.num_hidden_layers @@ -301,11 +653,18 @@ def __init__( require_post_norm = len(self.encoder.layers) == num_hidden_layers if require_post_norm: - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) else: self.post_layernorm = None + @property + def dtype(self): + return next(self.parameters()).dtype + + @property + def device(self): + return next(self.parameters()).device + def forward( self, pixel_values: torch.Tensor, @@ -313,7 +672,6 @@ def forward( select_layers: Optional[list[int]] = None, feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, ) -> torch.Tensor: - hidden_states = self.embeddings(pixel_values) hidden_states = self.pre_layrnorm(hidden_states) @@ -335,12 +693,46 @@ def forward( return encoder_outputs + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + layer_count = len(self.encoder.layers) -class CLIPVisionModel(nn.Module, SupportsQuant): - config_class = CLIPVisionConfig - main_input_name = "pixel_values" - packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + for name, loaded_weight in weights: + # post_layernorm is not needed in CLIPVisionModel + if name.startswith("post_layernorm") and self.post_layernorm is None: + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + +class CLIPVisionModel(nn.Module): def __init__( self, config: CLIPVisionConfig, @@ -351,12 +743,14 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + self.vision_model = CLIPVisionTransformer( config=config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, require_post_norm=require_post_norm, - prefix=f"{prefix}.vision_model") + prefix=f"{prefix}.vision_model", + ) def forward( self, @@ -370,49 +764,201 @@ def forward( feature_select_strategy=feature_select_strategy, ) + @property + def dtype(self): + return self.vision_model.dtype + @property def device(self): - return next(self.parameters()).device + return self.vision_model.device - # (TODO) Add prefix argument for filtering out weights to be loaded - # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, weights: 
Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - layer_count = len(self.vision_model.encoder.layers) - for name, loaded_weight in weights: - # post_layernorm is not needed in CLIPVisionModel - if (name.startswith("vision_model.post_layernorm") - and self.vision_model.post_layernorm is None): - continue +# Assume EOS token corresponds to LAST token in text model +@default_pooling_type("LAST") +@MULTIMODAL_REGISTRY.register_processor( + CLIPMultiModalProcessor, + info=CLIPProcessingInfo, + dummy_inputs=CLIPDummyInputsBuilder, +) +class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): + is_pooling_model = True - # omit layers when num_hidden_layers_override is set - if name.startswith("vision_model.encoder.layers"): - layer_idx = int(name.split(".")[3]) - if layer_idx >= layer_count: - continue + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + merge_by_field_config = True - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return None - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: CLIPConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = CLIPTextTransformer( + text_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "text_model"), + ) + self.vision_model = CLIPVisionTransformer( + vision_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.visual_projection = nn.Linear( + self.vision_embed_dim, + self.projection_dim, + bias=False, + ) + self.text_projection = nn.Linear( + self.text_embed_dim, + self.projection_dim, + bias=False, + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + self.pooler_config = pooler_config + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) + + # Assumes that self.forward is called after self.get_input_embeddings + self._is_text_input = True + + def get_text_features( + self, + input_ids: Optional[torch.Tensor], + position_ids: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pooled_output = self.text_model( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + text_features 
= self.text_projection(pooled_output) + + return text_features + + def get_image_features( + self, + pixel_values: torch.Tensor, + feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None, + ) -> torch.Tensor: + if feature_select_strategy is None: + feature_select_strategy = _get_vision_feature_select_strategy( + self.pooler_config.pooling_type + ) + + pooled_output = self.vision_model( + pixel_values=pixel_values, + select_layers=None, + feature_select_strategy=feature_select_strategy, + ) + + image_features = self.visual_projection(pooled_output) + + return image_features + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Optional[CLIPImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is None: + return None + + expected_h = expected_w = self.config.vision_config.image_size + return CLIPImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) + + def _process_image_inputs(self, inputs: CLIPImagePixelInputs) -> torch.Tensor: + pixel_values = inputs["data"] + + return self.get_image_features(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.text_model + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + *, + is_multimodal: Optional[torch.Tensor] = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + self._is_text_input = ( + multimodal_embeddings is None or len(multimodal_embeddings) == 0 + ) + + # This is to satisfy the type checker for each overload + if multimodal_embeddings is None or is_multimodal is None: + return super().get_input_embeddings(input_ids) + + return super().get_input_embeddings( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + + vision_embeddings = self._process_image_inputs(image_input) + return vision_embeddings + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> torch.Tensor: + if intermediate_tensors is not None: + raise RuntimeError("PP is not supported for this model") + + # Multimodal inputs + if not self._is_text_input: + return inputs_embeds + + # Text inputs + return self.get_text_features( + input_ids=input_ids, position_ids=positions, inputs_embeds=inputs_embeds + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_substrs=[".position_ids"], + ignore_unexpected_prefixes=["logit_scale."], + ) + + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 70f2a3fd339a..73aafbd01144 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -11,34 +11,44 @@ from transformers import BatchFeature, PretrainedConfig from transformers.models.cohere2_vision import Cohere2VisionConfig from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import ( # noqa: E501 - get_optimal_tiled_canvas) + get_optimal_tiled_canvas, +) from 
transformers.models.cohere2_vision.processing_cohere2_vision import ( - Cohere2VisionProcessor) + Cohere2VisionProcessor, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import MulAndSilu -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalFieldConfig, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalFieldConfig, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) class Cohere2VisionImagePixelInputs(TensorSchema): @@ -67,7 +77,7 @@ class Cohere2VisionImagePixelInputs(TensorSchema): class Cohere2VisionMultiModalProjector(nn.Module): """Multimodal projector that maps vision features to text embedding space. - + Uses pixel shuffle downsampling followed by SwiGLU activation. """ @@ -76,8 +86,7 @@ def __init__(self, config: Cohere2VisionConfig, prefix: str = ""): self.downsample_factor = config.downsample_factor # Input dimension after pixel shuffle downsampling - input_dim = config.vision_config.hidden_size * ( - config.downsample_factor**2) + input_dim = config.vision_config.hidden_size * (config.downsample_factor**2) # MergedColumnParallelLinear expects the intermediate size to be a list # of sizes, so that it will load the weights as two separate linear # layers before applying any parallelism. @@ -110,28 +119,26 @@ def forward(self, image_features): def pixel_shuffle(self, image_features: torch.Tensor) -> torch.Tensor: """Apply pixel shuffle downsampling to reduce spatial dimensions. - + Args: image_features: Input tensor of shape [B, S, D] where S = H*W - + Returns: Downsampled tensor with increased channel dimension """ - height = width = int(image_features.shape[1]**0.5) + height = width = int(image_features.shape[1] ** 0.5) x = image_features.reshape(image_features.shape[0], width, height, -1) n, h, w, c = x.size() - scale_factor = 1. 
/ self.downsample_factor + scale_factor = 1.0 / self.downsample_factor nh = int(h * scale_factor) nw = int(w * scale_factor) - x = x.reshape(n, nh, self.downsample_factor, nw, - self.downsample_factor, c) + x = x.reshape(n, nh, self.downsample_factor, nw, self.downsample_factor, c) x = x.permute(0, 1, 3, 2, 4, 5).contiguous() x = x.reshape(n, nh, nw, -1) return x class Cohere2VisionProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> Cohere2VisionConfig: return self.ctx.get_hf_config(Cohere2VisionConfig) @@ -146,8 +153,8 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() - height = image_processor.size['height'] - width = image_processor.size['width'] + height = image_processor.size["height"] + width = image_processor.size["width"] max_patches = image_processor.max_patches return ImageSize(height=height * max_patches, width=width) @@ -196,8 +203,8 @@ def get_num_patches( class Cohere2VisionDummyInputsBuilder( - BaseDummyInputsBuilder[Cohere2VisionProcessingInfo]): - + BaseDummyInputsBuilder[Cohere2VisionProcessingInfo] +): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -213,23 +220,23 @@ def get_dummy_mm_data( mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - image_size = \ - self.info.get_image_size_with_most_features() + image_size = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=image_size.width, - height=image_size.height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=image_size.width, + height=image_size.height, + num_images=num_images, + overrides=image_overrides, + ) } class Cohere2VisionMultiModalProcessor( - BaseMultiModalProcessor[Cohere2VisionProcessingInfo]): - + BaseMultiModalProcessor[Cohere2VisionProcessingInfo] +): def _call_hf_processor( self, prompt: str, @@ -245,22 +252,26 @@ def _call_hf_processor( ) # Ensure num_patches is available for proper tensor splitting - if "num_patches" not in processed_outputs and ( - images := mm_data.get("images")) is not None: + if ( + "num_patches" not in processed_outputs + and (images := mm_data.get("images")) is not None + ): hf_processor = self.info.get_hf_processor(**mm_kwargs) # Fallback calculation if HF processor didn't provide num_patches - parsed_images = self._get_data_parser().parse_mm_data({ - "image": - images - }).get_items("image", ImageProcessorItems) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) num_patches = [ self.info.get_num_patches( image_width=parsed_images.get_image_size(i).width, image_height=parsed_images.get_image_size(i).height, processor=hf_processor, - ) for i in range(len(parsed_images)) + ) + for i in range(len(parsed_images)) ] processed_outputs["num_patches"] = torch.tensor(num_patches) @@ -273,8 +284,7 @@ def _get_mm_fields_config( ) -> Mapping[str, MultiModalFieldConfig]: num_patches = hf_inputs.get("num_patches", torch.empty(0)) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), num_patches=MultiModalFieldConfig.batched("image"), 
image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -301,8 +311,7 @@ def get_replacement(item_idx: int): image_height=image_size.height, processor=hf_processor, ) - patch_tokens = (image_token * img_tokens_per_tile + - img_line_break_token) + patch_tokens = image_token * img_tokens_per_tile + img_line_break_token repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}" return PromptUpdateDetails.select_text(repl, image_token) @@ -319,9 +328,9 @@ def get_replacement(item_idx: int): @MULTIMODAL_REGISTRY.register_processor( Cohere2VisionMultiModalProcessor, info=Cohere2VisionProcessingInfo, - dummy_inputs=Cohere2VisionDummyInputsBuilder) -class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=Cohere2VisionDummyInputsBuilder, +) +class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( @@ -330,7 +339,8 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, "model.multi_modal_projector.": "multi_modal_projector.", "model.language_model.": "language_model.model.", "lm_head.": "language_model.lm_head.", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -342,37 +352,39 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multimodal_config = multimodal_config self._patch_quant_config(config, quant_config) - self.vision_tower = SiglipVisionModel(config.vision_config, - quant_config, - prefix=maybe_prefix( - prefix, "vision_tower")) + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.vocab_size = config.text_config.vocab_size - self.multi_modal_projector = \ - Cohere2VisionMultiModalProjector( - config, prefix=maybe_prefix(prefix, "multi_modal_projector")) + self.multi_modal_projector = Cohere2VisionMultiModalProjector( + config, prefix=maybe_prefix(prefix, "multi_modal_projector") + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), - architectures=config.text_config.architectures) + architectures=config.text_config.architectures, + ) @property def dtype(self): return next(self.parameters()).dtype - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - def _process_image_input(self, image_input: Cohere2VisionImagePixelInputs, - **kwargs) -> list[torch.Tensor]: + def _process_image_input( + self, image_input: Cohere2VisionImagePixelInputs, **kwargs + ) -> list[torch.Tensor]: """Process image pixels through vision tower and projector. 
- + Args: - image_input: Validated image input containing pixel values and + image_input: Validated image input containing pixel values and patch counts - + Returns: List of flattened image embeddings, one per image """ @@ -388,17 +400,15 @@ def _process_image_input(self, image_input: Cohere2VisionImagePixelInputs, image_embeds = self.multi_modal_projector(image_features) # Split and flatten embeddings per image - return [ - e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist()) - ] + return [e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())] def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Cohere2VisionImagePixelInputs]: + self, **kwargs: object + ) -> Optional[Cohere2VisionImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) - assert image_embeds is None, \ - "Cohere2Vision does not support image_embeds." + assert image_embeds is None, "Cohere2Vision does not support image_embeds." if pixel_values is None: return None @@ -410,25 +420,26 @@ def _parse_and_validate_image_input( resolve_bindings={ "h": self.config.vision_config.image_size, "w": self.config.vision_config.image_size, - }) + }, + ) - def _patch_quant_config(self, config: PretrainedConfig, - quant_config: QuantizationConfig): + def _patch_quant_config( + self, config: PretrainedConfig, quant_config: QuantizationConfig + ): # the awq models from OpenGVLab missing `modules_to_not_convert` # patch the quant_config to add `modules_to_not_convert` back if isinstance(quant_config, AWQConfig): text_config = config.text_config - llm_quant_config = getattr(text_config, "quantization_config", - None) - if (not quant_config.modules_to_not_convert) and (llm_quant_config - is not None): + llm_quant_config = getattr(text_config, "quantization_config", None) + if (not quant_config.modules_to_not_convert) and ( + llm_quant_config is not None + ): quant_config.modules_to_not_convert.append("vision_tower") def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index f3929ef3b593..e38c3c0492fb 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -22,6 +22,7 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -35,26 +36,33 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - 
VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name, - row_parallel_weight_loader) + default_weight_loader, + maybe_remap_kv_scale_name, + row_parallel_weight_loader, +) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) @torch.compile(backend=current_platform.simple_compile_backend) @@ -63,30 +71,27 @@ def layer_norm_func(hidden_states, weight, variance_epsilon): hidden_states = hidden_states.to(torch.float32) mean = hidden_states.mean(-1, keepdim=True) variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) - hidden_states = (hidden_states - mean) * torch.rsqrt(variance + - variance_epsilon) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon) hidden_states = weight.to(torch.float32) * hidden_states return hidden_states.to(input_dtype) class LayerNorm(nn.Module): - def __init__(self, param_shape=None, eps=1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(param_shape)) self.variance_epsilon = eps - set_weight_attrs(self.weight, - {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.weight, {"weight_loader": row_parallel_weight_loader}) def forward(self, hidden_states, residuals=None): - hidden_states = layer_norm_func(hidden_states, self.weight, - self.variance_epsilon) + hidden_states = layer_norm_func( + hidden_states, self.weight, self.variance_epsilon + ) return hidden_states, residuals # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere class CohereMLP(nn.Module): - def __init__( self, config: Union[CohereConfig, Cohere2Config], @@ -121,7 +126,6 @@ def forward(self, x): class CohereAttention(nn.Module): - def __init__( self, config: Union[CohereConfig, Cohere2Config], @@ -151,8 +155,8 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.max_position_embeddings = getattr( - config, "model_max_length", None) or getattr( - config, "max_position_embeddings", 8192) + config, "model_max_length", None + ) or getattr(config, "max_position_embeddings", 8192) self.rope_theta = config.rope_theta self.rope_scaling = getattr(config, "rope_scaling", None) self.use_qk_norm = getattr(config, "use_qk_norm", False) @@ -190,21 +194,24 @@ def __init__( if config.layer_types[layer_idx] == "sliding_attention": self.sliding_window = config.sliding_window - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - per_layer_sliding_window=self.sliding_window, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn", + ) if self.use_qk_norm: - self.q_norm = LayerNorm(param_shape=(self.num_heads, - 
self.head_dim), - eps=config.layer_norm_eps) - self.k_norm = LayerNorm(param_shape=(self.num_kv_heads, - self.head_dim), - eps=config.layer_norm_eps) + self.q_norm = LayerNorm( + param_shape=(self.num_heads, self.head_dim), eps=config.layer_norm_eps + ) + self.k_norm = LayerNorm( + param_shape=(self.num_kv_heads, self.head_dim), + eps=config.layer_norm_eps, + ) def _apply_qk_norm(self, q, k): q = q.view(*q.shape[:-1], -1, self.head_dim) @@ -232,25 +239,27 @@ def forward( class CohereDecoderLayer(nn.Module): - - def __init__(self, - config: Union[CohereConfig, Cohere2Config], - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Union[CohereConfig, Cohere2Config], + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = CohereAttention( + config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) - self.mlp = CohereMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), - eps=config.layer_norm_eps) + self.mlp = CohereMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.input_layernorm = LayerNorm( + param_shape=(config.hidden_size), eps=config.layer_norm_eps + ) def forward( self, @@ -274,7 +283,6 @@ def forward( @support_torch_compile class CohereModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -285,22 +293,29 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config self.config = config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: CohereDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.norm = LayerNorm(param_shape=(config.hidden_size), - eps=config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + self.norm = LayerNorm( + param_shape=(config.hidden_size), eps=config.layer_norm_eps + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -329,15 +344,13 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = 
self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -349,14 +362,15 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -386,8 +400,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -421,13 +434,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.quant_config = quant_config - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=config.logit_scale) - self.model = CohereModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, scale=config.logit_scale + ) + self.model = CohereModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -440,26 +455,27 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, ) -> Optional[torch.Tensor]: - is_not_lora = hasattr(self.model.embed_tokens, 'weight') + is_not_lora = hasattr(self.model.embed_tokens, "weight") if is_not_lora: - logits = self.logits_processor(self.model.embed_tokens, - hidden_states) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) else: - logits = self.logits_processor(self.model.embed_tokens.base_layer, - hidden_states) + logits = self.logits_processor( + self.model.embed_tokens.base_layer, hidden_states + ) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> 
set[str]: loader = AutoWeightsLoader( - self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"]) + self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"] + ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 5711b5ebe85e..caf481f5aec6 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -4,28 +4,24 @@ from typing import TYPE_CHECKING import vllm.envs as envs -from vllm.config.compilation import CUDAGraphMode from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: - from vllm.config import VllmConfig logger = init_logger(__name__) class VerifyAndUpdateConfig: - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: raise NotImplementedError class Gemma3TextModelConfig: - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: hf_config = vllm_config.model_config.hf_config @@ -33,7 +29,6 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class GteNewModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config @@ -49,12 +44,11 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), "max_position": config.max_position_embeddings, "base": config.rope_theta, - "rope_scaling": getattr(config, "rope_scaling", None) + "rope_scaling": getattr(config, "rope_scaling", None), } class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: pooler_config = vllm_config.model_config.pooler_config @@ -63,7 +57,6 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class JinaRobertaModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config @@ -77,29 +70,27 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), "max_position": config.max_position_embeddings, "base": getattr(config, "rope_theta", config.rotary_emb_base), - "rope_scaling": getattr(config, "rope_scaling", None) + "rope_scaling": getattr(config, "rope_scaling", None), } class NomicBertModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config assert config.__class__.__name__ == "NomicBertConfig" assert config.activation_function in ["swiglu", "gelu"] - config.position_embedding_type = getattr(config, - "position_embedding_type", - "rope") + config.position_embedding_type = getattr( + config, "position_embedding_type", "rope" + ) if config.activation_function == "swiglu": config.hidden_act = "silu" else: config.hidden_act = config.activation_function - assert (config.mlp_fc1_bias == config.mlp_fc2_bias == - config.qkv_proj_bias) + assert config.mlp_fc1_bias == config.mlp_fc2_bias == config.qkv_proj_bias config.bias = config.qkv_proj_bias assert config.rotary_emb_scale_base is None @@ -118,7 +109,7 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: "rotary_dim": rotary_emb_dim, "max_position": max_trained_positions, "base": getattr(config, 
"rope_theta", config.rotary_emb_base), - "rope_scaling": getattr(config, "rope_scaling", None) + "rope_scaling": getattr(config, "rope_scaling", None), } # we ignore config.rotary_scaling_factor so that for datasets shorter @@ -126,15 +117,18 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: # with SentenceTransformer. # The context extension uses vllm style rope_theta and rope_scaling. # See #17785 #18755 - if (not vllm_config.model_config.hf_overrides - and vllm_config.model_config.original_max_model_len is None): + if ( + not vllm_config.model_config.hf_overrides + and vllm_config.model_config.original_max_model_len is None + ): # Default # Reset max_model_len to max_trained_positions. # nomic-embed-text-v2-moe the length is set to 512 # by sentence_bert_config.json. max_model_len_before = vllm_config.model_config.max_model_len - max_model_len = min(vllm_config.model_config.max_model_len, - max_trained_positions) + max_model_len = min( + vllm_config.model_config.max_model_len, max_trained_positions + ) vllm_config.recalculate_max_model_len(max_model_len) logger.warning( @@ -142,7 +136,9 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: "Changing max_model_len from %s to %s. " "To enable context extension, see: " "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html", - max_model_len_before, vllm_config.model_config.max_model_len) + max_model_len_before, + vllm_config.model_config.max_model_len, + ) else: # We need to re-verify max_model_len to avoid lengths # greater than position_embedding. @@ -152,7 +148,8 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: if isinstance(model_config.hf_overrides, dict): # hf_overrides_kw max_model_len = model_config.hf_overrides.get( - "max_model_len", vllm_config.model_config.max_model_len) + "max_model_len", vllm_config.model_config.max_model_len + ) else: # hf_overrides_fn # This might be overridden by sentence_bert_config.json. 
@@ -174,7 +171,6 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class Qwen2ForProcessRewardModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: pooler_config = vllm_config.model_config.pooler_config @@ -184,7 +180,6 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class Qwen2ForRewardModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: pooler_config = vllm_config.model_config.pooler_config @@ -194,27 +189,26 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config - is_original_qwen3_reranker = getattr(config, - "is_original_qwen3_reranker", - False) + is_original_qwen3_reranker = getattr( + config, "is_original_qwen3_reranker", False + ) if not is_original_qwen3_reranker: return tokens = getattr(config, "classifier_from_token", None) - assert tokens is not None and len(tokens) == 2, \ - ("Try loading the original Qwen3 Reranker?, see: " - "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py") + assert tokens is not None and len(tokens) == 2, ( + "Try loading the original Qwen3 Reranker?, see: " + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py" + ) vllm_config.model_config.hf_config.method = "from_2_way_softmax" class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config @@ -225,7 +219,6 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: config = vllm_config.model_config.hf_config @@ -241,12 +234,11 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: "rotary_dim": getattr(config, "rotary_emb_dim", head_dim), "max_position": config.max_position_embeddings, "base": config.rope_theta, - "rope_scaling": getattr(config, "rope_scaling", None) + "rope_scaling": getattr(config, "rope_scaling", None), } class GptOssForCausalLMConfig(VerifyAndUpdateConfig): - @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: structured_outputs_config = vllm_config.structured_outputs_config @@ -269,12 +261,11 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: cuda_graph_sizes += [i for i in range(256, 993, 16)] scheduler_config.cuda_graph_sizes = cuda_graph_sizes logger.info( - "Overriding max cuda graph capture size to " - "%d for performance.", 992) + "Overriding max cuda graph capture size to %d for performance.", 992 + ) class MambaModelConfig(VerifyAndUpdateConfig): - @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ @@ -290,29 +281,42 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: model_config = vllm_config.model_config cache_config = vllm_config.cache_config - compilation_config = vllm_config.compilation_config - # TODO(tdoublep): remove once prefix caching is enabled - cache_config.enable_prefix_caching = False - logger.info("Hybrid or mamba-based model detected: disabling prefix " - "caching since it is not yet supported.") - - # TODO(tdoublep): remove as 
full cuda graph support is added - FCG_NOT_SUPPORTED_MODELS = [ - "Lfm2ForCausalLM", - "MiniMaxText01ForCausalLM", + # Set mamba block size to max_model_len (this may get + overridden by prefix caching logic later) + cache_config.mamba_block_size = model_config.max_model_len + + # TODO(@tdoublep) find a better way to do this than whitelist + MAMBA2_MODELS = [ + "BambaForCausalLM", + "FalconH1ForCausalLM", + "GraniteMoeHybridForCausalLM", + "Mamba2ForCausalLM", + "NemotronHForCausalLM", + "Zamba2ForCausalLM", ] - - if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS - and compilation_config.cudagraph_mode is None): - logger.info( - "Hybrid or mamba-based model detected: setting cudagraph mode " - "to FULL_AND_PIECEWISE in order to optimize performance.") - compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE + if cache_config.enable_prefix_caching: + if model_config.architecture in MAMBA2_MODELS: + logger.info( + "Warning: Prefix caching is currently enabled. " + "Its support for Mamba2 layers is experimental. " + "Please report any issues you may observe." + ) + else: + logger.info( + "Hybrid or mamba-based model detected without " + "support for prefix caching: disabling." + ) + cache_config.enable_prefix_caching = False + + # TODO(tdoublep): remove once cascade attention is supported + logger.info( + "Disabling cascade attention since it is not supported for hybrid models." + ) + model_config.disable_cascade_attn = True class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): - @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ @@ -346,7 +350,8 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=1, num_kv_heads=model_config.get_num_kv_heads(parallel_config), head_size=model_config.get_head_size(), - dtype=kv_cache_dtype).page_size_bytes + dtype=kv_cache_dtype, + ).page_size_bytes model_cls, _ = ModelRegistry.resolve_model_cls( model_config.architecture, @@ -360,27 +365,49 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: block_size=model_config.max_model_len, ).page_size_bytes - # some attention backends (e.g. FA) only support setting - block size to multiple of 16, so let's suggest a value - that would work (note: FA is currently not compatible - with mamba layers, use FlashInfer instead). - attn_block_size = 16 * cdiv(mamba_page_size, - 16 * attn_page_size_1_token) + if cache_config.enable_prefix_caching: + # With prefix caching, select attention block size to + # optimize for mamba kernel performance + + # mamba SSD kernel uses a chunk_size, e.g. 256 + # Align the block to the kernel: use lowest multiple of chunk_size + # of attention tokens that would fit mamba_page_size: + # e.g. for mamba page size = 788kB + # attn_1_token = 2kB -> fits ~394 tokens + # then round up to a multiple of 256 -> 512 tokens + # End result: + # attn_block_size = 512 + # mamba_block_size = 512 (aligned to a multiple of chunk_size) + # TODO(tdoublep): this constraint can be relaxed fairly + # easily by changing the way we layout chunks in the + # mamba2 kernels. + chunk_size = model_config.get_mamba_chunk_size() + attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) + attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) + cache_config.mamba_block_size = attn_block_size + else: + # Without prefix caching, select minimum valid attention block size + # to minimize mamba state padding + + # some attention backends (e.g.
FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token) # override attention block size if either (a) the # user has not set it or (b) the user has set it # too small. - if (cache_config.block_size is None - or cache_config.block_size < attn_block_size): + if cache_config.block_size is None or cache_config.block_size < attn_block_size: cache_config.block_size = attn_block_size logger.info( "Setting attention block size to %d tokens " "to ensure that attention page size is >= mamba page size.", - attn_block_size) + attn_block_size, + ) # compute new attention page size - attn_page_size = \ - cache_config.block_size * attn_page_size_1_token + attn_page_size = cache_config.block_size * attn_page_size_1_token assert attn_page_size >= mamba_page_size @@ -389,19 +416,23 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: return # pad mamba page size to exactly match attention - if (cache_config.mamba_page_size_padded is None - or cache_config.mamba_page_size_padded != attn_page_size): - cache_config.mamba_page_size_padded = (attn_page_size) - mamba_padding_pct = 100 * (attn_page_size - - mamba_page_size) / mamba_page_size + if ( + cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size + ): + cache_config.mamba_page_size_padded = attn_page_size + mamba_padding_pct = ( + 100 * (attn_page_size - mamba_page_size) / mamba_page_size + ) logger.info( "Padding mamba page size by %.2f%% to ensure " "that mamba page size and attention page size are " - "exactly equal.", mamba_padding_pct) + "exactly equal.", + mamba_padding_pct, + ) class DeepseekV32ForCausalLM(VerifyAndUpdateConfig): - @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ @@ -416,8 +447,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # For DeepSeekV3.2, we use a custom fp8 format as default (i.e. 
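The block-size arithmetic above is easy to sanity-check by hand. The following standalone sketch reuses the illustrative numbers from the code comment (a ~788 kB mamba state, ~2 kB of attention KV cache per token, Mamba2 chunk size 256) and reimplements cdiv locally so it runs on its own; it only illustrates the math, it is not the vLLM code path.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, standing in for vllm.utils.cdiv.
    return -(a // -b)


mamba_page_size = 788  # kB per mamba state (figure from the comment above)
attn_page_size_1_token = 2  # kB of attention KV cache per token
chunk_size = 256  # Mamba2 SSD kernel chunk size

# With prefix caching: round the covering token count up to a chunk multiple.
tokens_to_cover_state = cdiv(mamba_page_size, attn_page_size_1_token)  # 394
attn_block_size_pc = chunk_size * cdiv(tokens_to_cover_state, chunk_size)  # 512

# Without prefix caching: smallest multiple of 16 whose page covers the state.
attn_block_size_min = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)  # 400

# Either way the attention page ends up >= the mamba page; the leftover is what
# the later hunk reports when padding mamba_page_size_padded to match.
attn_page_size = attn_block_size_min * attn_page_size_1_token  # 800 kB
padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size  # ~1.52

print(attn_block_size_pc, attn_block_size_min, round(padding_pct, 2))
```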
# "auto") cache_config = vllm_config.cache_config - if cache_config.cache_dtype == "auto" or \ - cache_config.cache_dtype.startswith("fp8"): + if cache_config.cache_dtype == "auto" or cache_config.cache_dtype.startswith( + "fp8" + ): cache_config.cache_dtype = "fp8_ds_mla" logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") if cache_config.cache_dtype == "bfloat16": diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index f863b1da5505..8ec7a82a7b2a 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -11,25 +11,39 @@ from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class DbrxRouter(nn.Module): @@ -60,7 +74,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class DbrxExperts(FusedMoE): - def __init__( self, config: DbrxConfig, @@ -82,12 +95,16 @@ def __init__( ) self.config = config self.d_model = config.d_model - self.intermediate_size = (self.config.ffn_config.ffn_hidden_size // - self.tp_size) + self.intermediate_size = self.config.ffn_config.ffn_hidden_size // self.tp_size # Define custom weight loader for dbrx model - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, param_name: str): + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + param_name: str, + ): tp_rank = get_tensor_model_parallel_rank() param_data = param.data shard_size = self.intermediate_size @@ -111,8 +128,9 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_weight, [-1, self.intermediate_size * self.tp_size, self.d_model], ) - param_data[:, shard_size:2 * - shard_size, :] = loaded_weight[:, shard, :] + param_data[:, shard_size : 2 * shard_size, :] = loaded_weight[ + :, shard, : + ] elif param_name.endswith("weight_scale"): param_data[:, 1] = loaded_weight else: @@ -151,10 +169,12 @@ def __init__( self.router = DbrxRouter(config, self.params_dtype) - self.experts = DbrxExperts(config=config, - 
quant_config=quant_config, - params_dtype=self.params_dtype, - prefix=f"{prefix}.experts") + self.experts = DbrxExperts( + config=config, + quant_config=quant_config, + params_dtype=self.params_dtype, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape @@ -166,7 +186,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class DbrxAttention(nn.Module): - def __init__( self, config: DbrxConfig, @@ -222,13 +241,15 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -246,7 +267,6 @@ def forward( class DbrxFusedNormAttention(nn.Module): - def __init__( self, config: DbrxConfig, @@ -256,10 +276,9 @@ def __init__( ): super().__init__() self.d_model = config.d_model - self.attn = DbrxAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = DbrxAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.norm_1 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model) @@ -281,7 +300,6 @@ def forward( class DbrxBlock(nn.Module): - def __init__( self, config: DbrxConfig, @@ -291,10 +309,8 @@ def __init__( ): super().__init__() self.norm_attn_norm = DbrxFusedNormAttention( - config, - cache_config, - quant_config, - prefix=f"{prefix}.norm_attn_norm") + config, cache_config, quant_config, prefix=f"{prefix}.norm_attn_norm" + ) self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn") def forward( @@ -312,7 +328,6 @@ def forward( class DbrxModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -327,19 +342,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.blocks = make_layers( config.n_layers, - lambda prefix: DbrxBlock( - config, cache_config, quant_config, prefix=prefix), + lambda prefix: DbrxBlock(config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.blocks", ) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): - if hasattr(module, "bias") and isinstance(module.bias, - nn.Parameter): + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): # Remove the bias term in Linear and LayerNorm. 
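To make the fused-expert sharding in the DbrxExperts weight_loader above concrete, here is a toy tensor-parallel example. All sizes are made up, and the per-rank shard is assumed to be the contiguous slice [tp_rank * shard_size, (tp_rank + 1) * shard_size) of the full intermediate dimension; only the indexing pattern matches the hunk.

```python
import torch

num_experts, d_model, tp_size, shard_size = 4, 8, 2, 3  # made-up sizes
tp_rank = 1

# Checkpoint "v1" tensor, holding the full (unsharded) intermediate dimension,
# viewed per expert as in the weight_loader above.
loaded_weight = torch.randn(num_experts, shard_size * tp_size, d_model)

# Local fused "w13" parameter: w1 occupies rows [0, shard_size) per expert,
# v1 occupies rows [shard_size, 2 * shard_size).
param_data = torch.zeros(num_experts, 2 * shard_size, d_model)

# Assumed per-rank shard: a contiguous slice of the intermediate dimension.
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
param_data[:, shard_size : 2 * shard_size, :] = loaded_weight[:, shard, :]

assert torch.equal(param_data[:, shard_size:, :], loaded_weight[:, shard, :])
```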
module.register_parameter("bias", None) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.d_model)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.d_model + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -366,24 +379,27 @@ def forward( hidden_states = self.norm_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - expert_params_mapping = [( - "w13" if weight_name in ["w1", "v1"] else "w2", - f"mlp.{weight_name}", - ) for weight_name in ["w1", "v1", "w2"]] + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + expert_params_mapping = [ + ( + "w13" if weight_name in ["w1", "v1"] else "w2", + f"mlp.{weight_name}", + ) + for weight_name in ["w1", "v1", "w2"] + ] params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -409,28 +425,25 @@ def load_weights(self, weights: Iterable[tuple[str, if name is None: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class DbrxForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config if config.tie_word_embeddings: - raise ValueError( - "tie_word_embeddings is not supported for Dbrx models.") + raise ValueError("tie_word_embeddings is not supported for Dbrx models.") self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = DbrxModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, @@ -439,10 +452,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -454,8 +469,9 @@ def forward( 
intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -465,7 +481,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index ffc843fe033c..67258c2f77b8 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -33,33 +34,43 @@ from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class DeepseekMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -71,17 +82,19 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - 
reduce_results=reduce_results) + quant_config=quant_config, + reduce_results=reduce_results, + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -92,7 +105,6 @@ def forward(self, x): class DeepseekMoE(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -108,26 +120,29 @@ def __init__( if self.tp_size > self.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.n_routed_experts}.") - - self.experts = nn.ModuleList([ - DeepseekMLP(hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False) - for idx in range(self.n_routed_experts) - ]) + f"the number of experts {self.n_routed_experts}." + ) + + self.experts = nn.ModuleList( + [ + DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + for idx in range(self.n_routed_experts) + ] + ) self.pack_params() - self.gate = ReplicatedLinear(config.hidden_size, - self.n_routed_experts, - bias=False, - quant_config=None) + self.gate = ReplicatedLinear( + config.hidden_size, self.n_routed_experts, bias=False, quant_config=None + ) if config.n_shared_experts is not None: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) + intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, @@ -167,25 +182,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, router_logits, self.top_k, - renormalize=self.config.norm_topk_prob) + renormalize=self.config.norm_topk_prob, + ) - final_hidden_states = fused_experts(hidden_states, - self.w1, - self.w2, - topk_weights, - topk_ids, - inplace=True) + final_hidden_states = fused_experts( + hidden_states, self.w1, self.w2, topk_weights, topk_ids, inplace=True + ) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) class DeepseekAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -244,13 +255,15 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -266,7 +279,6 @@ def forward( class DeepseekDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -279,8 +291,7 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + 
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = DeepseekAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -292,12 +303,14 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): + self.mlp = DeepseekMoE( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, @@ -306,10 +319,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -322,22 +335,19 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class DeepseekModel(nn.Module): - fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -358,11 +368,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: DeepseekDecoderLayer( config, cache_config, quant_config=quant_config, prefix=prefix ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -386,15 +397,13 @@ def forward( for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -409,7 +418,7 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if 
"rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -417,8 +426,9 @@ def load_weights(self, weights: Iterable[tuple[str, if name.endswith(".bias") and name not in params_dict: continue # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): + if ( + "mlp.experts." in name or "mlp.shared_experts." in name + ) and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue @@ -431,14 +441,14 @@ def load_weights(self, weights: Iterable[tuple[str, if name.endswith(".bias") and name not in params_dict: continue # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): + if ( + "mlp.experts." in name or "mlp.shared_experts." in name + ) and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -456,8 +466,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = DeepseekModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = DeepseekModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, @@ -468,7 +479,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -480,8 +492,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -491,7 +504,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index c42a66d86912..467468dcc01e 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -14,18 +14,23 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, 
VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer, - DeepseekV3ForCausalLM) + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2DecoderLayer, + DeepseekV3ForCausalLM, +) from .utils import AutoWeightsLoader, maybe_prefix @support_torch_compile class DeepseekV2Model(nn.Module): - def __init__( self, *, @@ -34,8 +39,7 @@ def __init__( start_layer_id: int = 0, ) -> None: super().__init__() - self.config = vllm_config. \ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config quant_config = vllm_config.quant_config self.vocab_size = self.config.vocab_size @@ -46,12 +50,15 @@ def __init__( prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList([ - DeepseekV2DecoderLayer( - vllm_config, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - ) for i in range(self.config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer( + vllm_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) + for i in range(self.config.num_hidden_layers) + ] + ) self.fc = nn.Linear( self.config.model.hidden_size * 2, @@ -59,12 +66,9 @@ def __init__( bias=False, ) - self.enorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) - self.hnorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) - self.norm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.enorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.hnorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -78,8 +82,8 @@ def forward( input_embeds = self.embed_tokens(input_ids) inputs = torch.cat( - [self.enorm(input_embeds), - self.hnorm(hidden_states)], dim=-1) + [self.enorm(input_embeds), self.hnorm(hidden_states)], dim=-1 + ) hidden_states = self.fc(inputs) residual = None for layer in self.layers: @@ -91,8 +95,7 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -107,7 +110,8 @@ def load_weights(self, weights: Iterable[tuple[str, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -132,8 +136,9 @@ def load_weights(self, weights: Iterable[tuple[str, # QKV fusion is optional, fall back to normal # weight loading if it's not enabled # if go with fusion option, then update name - if ((param_name == "fused_qkv_a_proj") - and name_mapped not in params_dict): + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: continue else: name = name_mapped @@ -165,8 +170,7 @@ def load_weights(self, weights: 
Iterable[tuple[str, break else: # if PP disabled then draft will share embed with target - if get_pp_group().world_size == 1 and \ - "embed_tokens." in name: + if get_pp_group().world_size == 1 and "embed_tokens." in name: continue # Skip loading extra bias for GPTQ models. @@ -179,34 +183,37 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) - self.config = vllm_config. \ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config quant_config = vllm_config.quant_config target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) - self.model = DeepseekV2Model(vllm_config=vllm_config, - prefix="model", - start_layer_id=target_layer_num) + vllm_config.parallel_config + ) + self.model = DeepseekV2Model( + vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num + ) - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.vocab_size, - scale=logit_scale) + self.logits_processor = LogitsProcessor( + self.config.vocab_size, scale=logit_scale + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -232,7 +239,6 @@ def compute_logits( return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - def transform(inputs): name, loaded_weight = inputs if "lm_head" not in name: diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 02a25ab762e5..36c1e0cbe69b 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -13,18 +13,18 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from .deepseek_v2 import (DeepseekV2DecoderLayer, - get_spec_layer_idx_from_weight_name) +from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name from .interfaces import SupportsPP from .utils import maybe_prefix class SharedHead(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -33,17 +33,18 @@ def __init__( ) -> None: super().__init__() self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "head")) + self.head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + 
quant_config=quant_config, + prefix=maybe_prefix(prefix, "head"), + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(hidden_states) class DeepSeekMultiTokenPredictorLayer(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: super().__init__() @@ -52,9 +53,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.eh_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) + self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) self.is_v32 = hasattr(config, "index_topk") if self.is_v32: @@ -63,14 +62,16 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: vllm_config.scheduler_config.max_num_batched_tokens, topk_tokens, dtype=torch.int32, - device="cuda") + device="cuda", + ) else: topk_indices_buffer = None - self.shared_head = SharedHead(config=config, - prefix=prefix, - quant_config=quant_config) - self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, - topk_indices_buffer) + self.shared_head = SharedHead( + config=config, prefix=prefix, quant_config=quant_config + ) + self.mtp_block = DeepseekV2DecoderLayer( + vllm_config, prefix, topk_indices_buffer + ) def forward( self, @@ -87,30 +88,34 @@ def forward( previous_hidden_states = self.hnorm(previous_hidden_states) hidden_states = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) hidden_states = residual + hidden_states return hidden_states class DeepSeekMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - DeepSeekMultiTokenPredictorLayer(vllm_config, - f"{prefix}.layers.{idx}") - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.layers = torch.nn.ModuleDict( + { + str(idx): DeepSeekMultiTokenPredictorLayer( + vllm_config, f"{prefix}.layers.{idx}" + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -130,7 +135,7 @@ def forward( ) -> torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = (spec_step_idx % self.num_mtp_layers) + current_step_idx = spec_step_idx % self.num_mtp_layers return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, @@ -144,22 +149,21 @@ def compute_logits( hidden_states: torch.Tensor, spec_step_idx: int = 0, ) -> torch.Tensor: - current_step_idx = (spec_step_idx % self.num_mtp_layers) - mtp_layer = self.layers[str(self.mtp_start_layer_idx + - current_step_idx)] - logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states)) + current_step_idx = spec_step_idx % self.num_mtp_layers + mtp_layer = 
self.layers[str(self.mtp_start_layer_idx + current_step_idx)] + logits = self.logits_processor( + mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) + ) return logits class DeepSeekMTP(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = DeepSeekMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -173,8 +177,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( @@ -184,8 +189,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.model.compute_logits(hidden_states, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), @@ -197,7 +201,8 @@ def load_weights(self, weights: Iterable[tuple[str, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -208,7 +213,7 @@ def load_weights(self, weights: Iterable[tuple[str, if spec_layer is None: continue name = self._rewrite_spec_layer_name(spec_layer, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -218,14 +223,15 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal # weight loading if it's not enabled - if ((param_name == "fused_qkv_a_proj") - and name_mapped not in params_dict): + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: continue else: name = name_mapped @@ -247,11 +253,13 @@ def load_weights(self, weights: Iterable[tuple[str, param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. @@ -260,13 +268,16 @@ def load_weights(self, weights: Iterable[tuple[str, # According to DeepSeek-V3 Technical Report, MTP modules # shares embedding layer. We only load the first weights. 
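The MTP layer lookup above hinges on the ModuleDict being keyed by absolute checkpoint layer indices. A minimal sketch follows, using hypothetical sizes (a 61-layer base model with two MTP layers) so the modulo wrap is visible; real DeepSeek checkpoints may differ.

```python
mtp_start_layer_idx = 61  # config.num_hidden_layers
num_mtp_layers = 2  # config.num_nextn_predict_layers


def mtp_layer_key(spec_step_idx: int) -> str:
    # Wrap the speculative step into [0, num_mtp_layers) and shift it to the
    # absolute layer index used as the ModuleDict key.
    current_step_idx = spec_step_idx % num_mtp_layers
    return str(mtp_start_layer_idx + current_step_idx)


assert [mtp_layer_key(i) for i in range(4)] == ["61", "62", "61", "62"]
```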
- if (spec_layer != self.model.mtp_start_layer_idx - and ".layers" not in name): + if ( + spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name + ): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -278,7 +289,11 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: and rename shared layer weights to be top level. """ spec_layer_weight_names = [ - "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", ] shared_weight_names = ["embed_tokens"] spec_layer_weight = False @@ -291,8 +306,9 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace(f"model.layers.{spec_layer}.", - f"model.layers.{spec_layer}.mtp_block.") + name = name.replace( + f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." + ) elif shared_weight: # treat shared weights as top level weights name = name.replace(f"model.layers.{spec_layer}.", "model.") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b7f96d0d1552..f149b02e5522 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only DeepseekV2/DeepseekV3 model.""" + import typing from collections.abc import Callable, Iterable from itertools import islice @@ -36,46 +37,61 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ParallelConfig, VllmConfig, - get_current_vllm_config) -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) + per_token_group_quant_fp8, +) 
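The renaming rules in _rewrite_spec_layer_name above reduce to three cases. The condensed sketch below applies them to a few example weight names; it is a simplified reconstruction for illustration, not the function itself.

```python
SPEC_LAYER_WEIGHTS = ["embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"]
SHARED_WEIGHTS = ["embed_tokens"]


def rewrite(spec_layer: int, name: str) -> str:
    prefix = f"model.layers.{spec_layer}."
    if any(w in name for w in SHARED_WEIGHTS):
        # Shared with the target model: hoist to the top level.
        return name.replace(prefix, "model.")
    if any(w in name for w in SPEC_LAYER_WEIGHTS):
        # MTP-specific weight: keep its per-layer name.
        return name
    # Anything else is a regular transformer weight inside mtp_block.
    return name.replace(prefix, prefix + "mtp_block.")


assert rewrite(61, "model.layers.61.embed_tokens.weight") == "model.embed_tokens.weight"
assert rewrite(61, "model.layers.61.enorm.weight") == "model.layers.61.enorm.weight"
assert (
    rewrite(61, "model.layers.61.self_attn.o_proj.weight")
    == "model.layers.61.mtp_block.self_attn.o_proj.weight"
)
```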
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import cdiv, direct_register_custom_op from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits -from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend, - DeepseekV32IndexerMetadata) +from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV32IndexerBackend, + DeepseekV32IndexerMetadata, +) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops @@ -86,7 +102,6 @@ class DeepseekV2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -104,21 +119,26 @@ def __init__( # replicated and no collective ops are needed. # Otherwise we use standard TP with an allreduce at the end. self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, disable_tp=is_sequence_parallel, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - disable_tp=is_sequence_parallel, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=is_sequence_parallel, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -129,7 +149,6 @@ def forward(self, x): class DeepseekV2MoE(nn.Module): - def __init__( self, config: Union[DeepseekV2Config, DeepseekV3Config], @@ -152,17 +171,22 @@ def __init__( self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") - - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) if config.topk_method == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=torch.float32)) + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) else: self.gate.e_score_correction_bias = None @@ -172,14 +196,13 @@ def __init__( self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) if config.n_shared_experts is None: self.experts = FusedMoE( @@ -204,8 +227,7 @@ def __init__( ) self.shared_experts = None else: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) + intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, @@ -253,8 +275,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - fused_moe_out = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if self.shared_experts is not None: shared_output, final_hidden_states = fused_moe_out @@ -268,7 +291,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states *= self.routed_scaling_factor elif self.shared_experts is not None: assert shared_output is not None - shared_output *= (1. 
/ self.routed_scaling_factor) + shared_output *= 1.0 / self.routed_scaling_factor if self.shared_experts is not None: assert shared_output is not None @@ -276,25 +299,26 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( - final_hidden_states, 0) + final_hidden_states, 0 + ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: - final_hidden_states = ( - self.experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states)) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) return final_hidden_states.view(num_tokens, hidden_dim) def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: import math + if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 class DeepseekV2Attention(nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -329,60 +353,70 @@ def __init__( self.scaling = self.qk_head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - assert topk_indices_buffer is None, "topk_indices_buffer is not \ + assert topk_indices_buffer is None, ( + "topk_indices_buffer is not \ supported for DeepseekV2Attention" + ) if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") + prefix=f"{prefix}.kv_b_proj", + ) # O projection. 
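The yarn_get_mscale helper above feeds the attention softmax scale: when rope_scaling is set, the diff multiplies the base head-dim scale by mscale twice. A worked example with illustrative values only (a 192-dim query/key head and rope_scaling factor 40 with mscale_all_dim = 1.0), not numbers taken from any particular checkpoint:

```python
import math


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # Same formula as in the diff above.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


qk_head_dim = 128 + 64  # qk_nope_head_dim + qk_rope_head_dim (illustrative)
base_scaling = qk_head_dim**-0.5  # ~0.0722

mscale = yarn_get_mscale(scale=40, mscale=1.0)  # ~1.3689
scaling = base_scaling * mscale * mscale  # ~0.1352, the value handed to Attention

print(round(mscale, 4), round(base_scaling, 4), round(scaling, 4))
```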
- self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' + rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) @@ -390,13 +424,15 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - self.attn = Attention(self.num_local_heads, - self.qk_head_dim, - self.scaling, - num_kv_heads=self.num_local_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -406,47 +442,43 @@ def forward( if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] q = self.q_a_layernorm(q) - q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, - self.qk_head_dim) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: - q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, - self.qk_head_dim) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, _ = latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] - kv = kv.view(-1, self.num_local_heads, - self.qk_nope_head_dim + self.v_head_dim) + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) - k[..., :self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe # padding value to qk_head_dim for alignment v = torch.nn.functional.pad( - v, [0, self.qk_head_dim - self.v_head_dim], - value=0).view(-1, self.num_local_heads * self.qk_head_dim) + v, [0, self.qk_head_dim - self.v_head_dim], value=0 + ).view(-1, self.num_local_heads * self.qk_head_dim) attn_output = self.attn(q, k, v) - attn_output = attn_output.view( - -1, self.num_local_heads, - self.qk_head_dim)[..., :self.v_head_dim].reshape( - -1, 
self.num_local_heads * self.v_head_dim) + attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): - - def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str, - cache_config: CacheConfig): + def __init__( + self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig + ): super().__init__() self.kv_cache = [torch.tensor([])] self.head_dim = head_dim @@ -466,8 +498,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: dtype=self.dtype, ) - def forward(self): - ... + def forward(self): ... def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend @@ -497,27 +528,33 @@ def cp_gather_indexer_k_quant_cache( value = [] scale = [] - full_block = torch.arange(tot - 1, - device=kv_cache.device, - dtype=torch.int32) - non_remaining_value = kv_cache[blocks[full_block], :block_size * - head_dim].view(-1, head_dim) - non_remaining_scale = kv_cache[blocks[full_block], - block_size * head_dim:].view(-1, 4) + full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32) + non_remaining_value = kv_cache[ + blocks[full_block], : block_size * head_dim + ].view(-1, head_dim) + non_remaining_scale = kv_cache[ + blocks[full_block], block_size * head_dim : + ].view(-1, 4) remaining = s - (tot - 1) * block_size - value = torch.cat([ - non_remaining_value, - kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim) - ], - dim=0) - scale = torch.cat([ - non_remaining_scale, - kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim + - remaining * 4].view(-1, 4) - ], - dim=0) + value = torch.cat( + [ + non_remaining_value, + kv_cache[blocks[-1], : remaining * head_dim].view(-1, head_dim), + ], + dim=0, + ) + scale = torch.cat( + [ + non_remaining_scale, + kv_cache[ + blocks[-1], + block_size * head_dim : block_size * head_dim + remaining * 4, + ].view(-1, 4), + ], + dim=0, + ) expected_value.append(value) expected_scale.append(scale) @@ -545,7 +582,6 @@ def sparse_attn_indexer( total_seq_lens: int, topk_indices_buffer: Optional[torch.Tensor], ) -> torch.Tensor: - # careful! 
this will be None in dummy run attn_metadata = get_forward_context().attn_metadata # assert isinstance(attn_metadata, dict) @@ -580,16 +616,18 @@ def sparse_attn_indexer( scale_fmt, ) - topk_indices_buffer[:hidden_states.shape[0]] = -1 + topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: prefill_metadata = attn_metadata.prefill for chunk in prefill_metadata.chunks: - k_fp8 = torch.empty([chunk.total_seq_lens, head_dim], - device=k.device, - dtype=torch.float8_e4m3fn) - k_scale = torch.empty([chunk.total_seq_lens, 1], - device=k.device, - dtype=torch.float32) + k_fp8 = torch.empty( + [chunk.total_seq_lens, head_dim], + device=k.device, + dtype=torch.float8_e4m3fn, + ) + k_scale = torch.empty( + [chunk.total_seq_lens, 1], device=k.device, dtype=torch.float32 + ) cp_gather_indexer_k_quant_cache( kv_cache, k_fp8, @@ -599,27 +637,26 @@ def sparse_attn_indexer( chunk.num_reqs, ) logits = fp8_mqa_logits( - q_fp8[chunk.token_start:chunk.token_end], + q_fp8[chunk.token_start : chunk.token_end], (k_fp8, k_scale), - weights[chunk.token_start:chunk.token_end], + weights[chunk.token_start : chunk.token_end], chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, ) - topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), - dim=-1)[1] + topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), dim=-1)[1] topk_indices -= chunk.cu_seqlen_ks[:, None] mask_lo = topk_indices >= 0 - mask_hi = topk_indices - (chunk.cu_seqlen_ke - - chunk.cu_seqlen_ks)[:, None] < 0 - mask = torch.full_like(topk_indices, - False, - dtype=torch.bool, - device=topk_indices.device) + mask_hi = ( + topk_indices - (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks)[:, None] < 0 + ) + mask = torch.full_like( + topk_indices, False, dtype=torch.bool, device=topk_indices.device + ) mask = mask_lo & mask_hi topk_indices = topk_indices.masked_fill(~mask, -1) topk_indices_buffer[ - chunk.token_start:chunk.token_end, :topk_indices. 
- shape[-1]] = topk_indices.to(dtype=torch.int32) + chunk.token_start : chunk.token_end, : topk_indices.shape[-1] + ] = topk_indices.to(dtype=torch.int32) if has_decode: decode_metadata = attn_metadata.decode @@ -633,10 +670,12 @@ def sparse_attn_indexer( # prefill and decode by decode_threshold # (currently set to 1 + speculative tokens) padded_q_fp8_decode_tokens = pack_seq_triton( - q_fp8[:num_decode_tokens], decode_lens) + q_fp8[:num_decode_tokens], decode_lens + ) else: padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( - decode_lens.shape[0], -1, *q_fp8.shape[1:]) + decode_lens.shape[0], -1, *q_fp8.shape[1:] + ) # TODO: move and optimize below logic with triton kernels batch_size = padded_q_fp8_decode_tokens.shape[0] next_n = padded_q_fp8_decode_tokens.shape[1] @@ -654,22 +693,24 @@ def sparse_attn_indexer( # padded query len current_device = padded_q_fp8_decode_tokens.device padded_num_tokens = batch_size * next_n - positions = torch.arange(max_model_len, - device=current_device).unsqueeze(0).expand( - batch_size * next_n, -1) - row_indices = torch.arange(padded_num_tokens, - device=current_device) // next_n - next_n_offset = torch.arange( - padded_num_tokens, - device=padded_q_fp8_decode_tokens.device) % next_n - index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + - next_n_offset).unsqueeze(1) + positions = ( + torch.arange(max_model_len, device=current_device) + .unsqueeze(0) + .expand(batch_size * next_n, -1) + ) + row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n + next_n_offset = ( + torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device) + % next_n + ) + index_end_pos = ( + decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + ).unsqueeze(1) # index_end_pos: [B * N, 1] mask = positions <= index_end_pos # mask: [B * N, L] - logits = logits.masked_fill(~mask, float('-inf')) - topk_indices = logits.topk(topk_tokens, - dim=-1)[1].to(torch.int32) # [B * N, K] + logits = logits.masked_fill(~mask, float("-inf")) + topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32) # [B * N, K] # ensure we don't set indices for the top k # that is out of range(masked already) # this will happen if context length is shorter than K @@ -679,9 +720,11 @@ def sparse_attn_indexer( # the topk indices removing padded tokens topk_indices = unpack_seq_triton( topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), - decode_lens) - topk_indices_buffer[:num_decode_tokens, :topk_indices. - shape[-1]] = topk_indices.to(dtype=torch.int32) + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices.to(dtype=torch.int32) + ) return topk_indices_buffer @@ -704,11 +747,10 @@ def sparse_attn_indexer_fake( # profile run # NOTE(Chen): create the max possible flattened_kv. So that # profile_run can get correct memory usage. 
- _flattened_kv = torch.empty([total_seq_lens, head_dim + 4], - device=k.device, - dtype=torch.uint8) - _k_fp8 = _flattened_kv[..., :head_dim].view( - torch.float8_e4m3fn).contiguous() + _flattened_kv = torch.empty( + [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 + ) + _k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous() _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() return topk_indices_buffer @@ -723,16 +765,17 @@ def sparse_attn_indexer_fake( class Indexer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - config: Union[DeepseekV2Config, DeepseekV3Config], - hidden_size: int, - q_lora_rank: int, - quant_config: Optional[QuantizationConfig], - cache_config: Optional[CacheConfig], - topk_indices_buffer: Optional[torch.Tensor], - prefix: str = ""): + def __init__( + self, + vllm_config: VllmConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], + hidden_size: int, + q_lora_rank: int, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + topk_indices_buffer: Optional[torch.Tensor], + prefix: str = "", + ): super().__init__() self.vllm_config = vllm_config self.config = config @@ -743,21 +786,24 @@ def __init__(self, self.rope_dim = config.qk_rope_head_dim # 64 self.q_lora_rank = q_lora_rank # 1536 # no tensor parallel, just replicated - self.wq_b = ReplicatedLinear(self.q_lora_rank, - self.head_dim * self.n_head, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wq_b") - self.wk = ReplicatedLinear(hidden_size, - self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wk") + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.head_dim * self.n_head, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + ) + self.wk = ReplicatedLinear( + hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk", + ) self.k_norm = LayerNorm(self.head_dim, eps=1e-6) - self.weights_proj = ReplicatedLinear(hidden_size, - self.n_head, - quant_config=None, - prefix=f"{prefix}.weights_proj") + self.weights_proj = ReplicatedLinear( + hidden_size, self.n_head, quant_config=None, prefix=f"{prefix}.weights_proj" + ) self.softmax_scale = self.head_dim**-0.5 self.scale_fmt = "ue8m0" @@ -768,28 +814,31 @@ def __init__(self, # where we store value in fp8 and scale in fp32 # per self.quant_block_size element self.k_cache = DeepseekV32IndexerCache( - head_dim=self.head_dim + - self.head_dim // self.quant_block_size * 4, + head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4, dtype=torch.uint8, prefix=f"{prefix}.k_cache", - cache_config=cache_config) + cache_config=cache_config, + ) self.max_model_len = vllm_config.model_config.max_model_len self.prefix = prefix - from vllm.v1.attention.backends.mla.indexer import ( - get_max_prefill_buffer_size) + from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size + self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) - def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, - rotary_emb) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb + ) -> torch.Tensor: q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) q_pe, q_nope = torch.split( - q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) k, _ = self.wk(hidden_states) k = self.k_norm(k) k_pe, k_nope = 
torch.split( - k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 + ) q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) q = torch.cat([q_pe, q_nope], dim=-1) @@ -797,17 +846,19 @@ def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, # we only quant q here since k quant is fused with cache insertion q = q.view(-1, self.head_dim) - q_fp8, q_scale = per_token_group_quant_fp8(q, - self.quant_block_size, - column_major_scales=False, - use_ue8m0=self.scale_fmt - is not None) + q_fp8, q_scale = per_token_group_quant_fp8( + q, + self.quant_block_size, + column_major_scales=False, + use_ue8m0=self.scale_fmt is not None, + ) q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) q_scale = q_scale.view(-1, self.n_head, 1) weights, _ = self.weights_proj(hidden_states) - weights = weights.unsqueeze( - -1) * q_scale * self.softmax_scale * self.n_head**-0.5 + weights = ( + weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 + ) weights = weights.squeeze(-1) return torch.ops.vllm.sparse_attn_indexer( @@ -831,7 +882,7 @@ class DeepseekV2MLAAttention(nn.Module): """ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - + For more info see MLACommonImpl in: vllm/v1/attention/backends/mla/utils.py """ @@ -881,53 +932,60 @@ def __init__( bias=False, quant_config=quant_config, prefix=f"{prefix}.fused_qkv_a_proj", - disable_tp=True) + disable_tp=True, + ) else: self.kv_a_proj_with_mqa = ReplicatedLinear( self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) if self.q_lora_rank is not None: - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(self.q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + self.q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + prefix=f"{prefix}.kv_b_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - 
max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) + rope_scaling["rope_type"] = "deepseek_yarn" + self.rotary_emb = get_rope( + qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, + ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) scaling_factor = rope_scaling["factor"] @@ -937,9 +995,16 @@ def __init__( self.is_v32 = hasattr(config, "index_topk") if self.is_v32: - self.indexer = Indexer(vllm_config, config, hidden_size, - q_lora_rank, quant_config, cache_config, - topk_indices_buffer, f"{prefix}.indexer") + self.indexer = Indexer( + vllm_config, + config, + hidden_size, + q_lora_rank, + quant_config, + cache_config, + topk_indices_buffer, + f"{prefix}.indexer", + ) else: self.indexer = None @@ -949,11 +1014,12 @@ def __init__( rotary_emb=self.rotary_emb, o_proj=self.o_proj, fused_qkv_a_proj=self.fused_qkv_a_proj - if self.q_lora_rank is not None else None, + if self.q_lora_rank is not None + else None, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa - if self.q_lora_rank is None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, + if self.q_lora_rank is None + else None, + q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None, q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, q_proj=self.q_proj if self.q_lora_rank is None else None, indexer=self.indexer, @@ -985,11 +1051,12 @@ def forward( class DeepseekV2DecoderLayer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer: Optional[torch.Tensor] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str, + topk_indices_buffer: Optional[torch.Tensor] = None, + ) -> None: super().__init__() config = vllm_config.model_config.hf_config @@ -1001,11 +1068,10 @@ def __init__(self, self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. 
- layer_idx = int(prefix.split(sep='.')[-1]) + layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx if model_config.use_mla: attn_cls = DeepseekV2MLAAttention @@ -1019,8 +1085,7 @@ def __init__(self, qk_nope_head_dim=config.qk_nope_head_dim, qk_rope_head_dim=config.qk_rope_head_dim, v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, + q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None, kv_lora_rank=config.kv_lora_rank, rope_theta=rope_theta, rope_scaling=rope_scaling, @@ -1031,9 +1096,11 @@ def __init__(self, topk_indices_buffer=topk_indices_buffer, ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): self.mlp = DeepseekV2MoE( config=config, parallel_config=parallel_config, @@ -1048,10 +1115,10 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.routed_scaling_factor = config.routed_scaling_factor def forward( @@ -1065,8 +1132,7 @@ def forward( residual = hidden_states.clone() hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -1076,32 +1142,29 @@ def forward( # Fix FP16 overflow # We scale both hidden_states and residual before # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor + hidden_states *= 1.0 / self.routed_scaling_factor if self.layer_idx == 0: # The residual is shared by all layers, we only scale it on # first layer. - residual *= 1. / self.routed_scaling_factor + residual *= 1.0 / self.routed_scaling_factor # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) - if isinstance(self.mlp, - DeepseekV2MLP) and hidden_states.dtype == torch.float16: + if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16: # Fix FP16 overflow # Scaling the DeepseekV2MLP output, it is the input of # input_layernorm of next decoder layer. # The scaling of DeepseekV2MOE output would be done in the forward # of DeepseekV2MOE - hidden_states *= 1. 
/ self.routed_scaling_factor + hidden_states *= 1.0 / self.routed_scaling_factor return hidden_states, residual @support_torch_compile class DeepseekV2Model(nn.Module): - fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -1119,7 +1182,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.scheduler_config.max_num_batched_tokens, topk_tokens, dtype=torch.int32, - device="cuda") + device="cuda", + ) else: topk_indices_buffer = None @@ -1128,23 +1192,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, - topk_indices_buffer), - prefix=f"{prefix}.layers") + lambda prefix: DeepseekV2DecoderLayer( + vllm_config, prefix, topk_indices_buffer + ), + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -1171,17 +1238,15 @@ def forward( hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, - SupportsLoRA): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } @@ -1197,16 +1262,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # initializing DeepseekV2Model, as it is passed inplace to # quantization config init and may be used to select the # quant_method for relevant layers during initialization. 
- self.fuse_qkv_a_proj = hasattr( - config, "q_lora_rank") and config.q_lora_rank is not None + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) if self.fuse_qkv_a_proj: self.packed_modules_mapping["fused_qkv_a_proj"] = [ "q_a_proj", "kv_a_proj_with_mqa", ] - self.model = DeepseekV2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = DeepseekV2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( config.vocab_size, @@ -1218,12 +1285,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) self.expert_weights = [] # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) + self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_expert_groups = config.n_group self.moe_layers: list[FusedMoE] = [] @@ -1272,8 +1339,7 @@ def update_physical_experts_metadata( assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, DeepseekV2MoE): moe = layer.mlp @@ -1292,8 +1358,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -1303,8 +1370,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -1320,7 +1386,8 @@ def load_weights(self, weights: Iterable[tuple[str, ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=self.num_redundant_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -1332,7 +1399,7 @@ def load_weights(self, weights: Iterable[tuple[str, if spec_layer is not None: continue # skip spec decode layers for main model - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -1342,15 +1409,16 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." 
in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name_mapped = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal # weight loading if it's not enabled # if go with fusion option, then update name - if ((param_name == "fused_qkv_a_proj") - and name_mapped not in params_dict): + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: continue else: name = name_mapped @@ -1387,14 +1455,17 @@ def load_weights(self, weights: Iterable[tuple[str, # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. - weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -1418,8 +1489,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -1432,13 +1504,15 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): # Compatibility with # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py -def get_spec_layer_idx_from_weight_name(config: Union[DeepseekV2Config, - DeepseekV3Config], - weight_name: str) -> Optional[int]: - if (hasattr(config, "num_nextn_predict_layers") - and config.num_nextn_predict_layers > 0): +def get_spec_layer_idx_from_weight_name( + config: Union[DeepseekV2Config, DeepseekV3Config], weight_name: str +) -> Optional[int]: + if ( + hasattr(config, "num_nextn_predict_layers") + and config.num_nextn_predict_layers > 0 + ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx+i}."): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): return layer_idx + i return None diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 107949df2270..8226e88c47a2 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -3,6 +3,7 @@ # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py """Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Literal, Optional, Union @@ -20,28 +21,44 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, MultiModalUUIDDict) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalProcessingInfo, - 
PromptReplacement, PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, - MlpProjectorConfig, - VisionEncoderConfig) -from vllm.transformers_utils.processors.deepseek_vl2 import ( - DeepseekVLV2Processor) +from vllm.transformers_utils.configs.deepseek_vl2 import ( + DeepseekVLV2Config, + MlpProjectorConfig, + VisionEncoderConfig, +) +from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) # The image token id may be various _IMAGE_TOKEN = "<image>" @@ -56,9 +73,9 @@ class DeepseekVL2ImagePixelInputs(TensorSchema): - h: Height of each image - w: Width of each image """ + type: Literal["pixel_values"] - data: Annotated[torch.Tensor, - TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})] + data: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})] images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)] @@ -69,51 +86,53 @@ class DeepseekVL2VImageEmbeddingInputs(TensorSchema): - f: Image feature size - h: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "f", "h")] + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], TensorShape("bn", "f", "h") + ] -DeepseekVL2ImageInputs = Union[DeepseekVL2ImagePixelInputs, - DeepseekVL2VImageEmbeddingInputs] +DeepseekVL2ImageInputs = Union[ + DeepseekVL2ImagePixelInputs, DeepseekVL2VImageEmbeddingInputs +] class MlpProjector(nn.Module): - def __init__(self, cfg: MlpProjectorConfig): - super().__init__() self.cfg = cfg - assert not cfg.token_pooling, ( - "Token pooling is not supported currently.") + assert not cfg.token_pooling, "Token pooling is not supported currently." 
if cfg.projector_type == "downsample_mlp_gelu": mlp_depth = cfg.depth mlp_ratio = cfg.mlp_ratio modules = [ nn.Linear( - cfg.input_dim * cfg.downsample_ratio * - cfg.downsample_ratio, cfg.n_embed * mlp_ratio) + cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, + cfg.n_embed * mlp_ratio, + ) ] for _ in range(1, mlp_depth - 1): modules.append(nn.GELU()) modules.append( - nn.Linear(cfg.n_embed * mlp_ratio, - cfg.n_embed * mlp_ratio)) + nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio) + ) modules.append(nn.GELU()) modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) modules = nn.Sequential(*modules) else: raise NotImplementedError( - f"Unsupported projector type: {cfg.projector_type}") + f"Unsupported projector type: {cfg.projector_type}" + ) self.layers = modules def forward(self, x): bs, hw, input_dim = x.shape - h = w = int((hw)**0.5) + h = w = int((hw) ** 0.5) """compute padding""" if h % self.cfg.downsample_ratio: pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio @@ -124,17 +143,18 @@ def forward(self, x): x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) """4 to 1 concat""" x = x.permute(0, 3, 1, 2) # B, C, H, W - x = F.unfold(x, - kernel_size=self.cfg.downsample_ratio, - stride=self.cfg.downsample_ratio, - padding=0) # B, C*4, HW // 4 + x = F.unfold( + x, + kernel_size=self.cfg.downsample_ratio, + stride=self.cfg.downsample_ratio, + padding=0, + ) # B, C*4, HW // 4 x = x.permute(0, 2, 1) return self.layers(x) class DeepseekVL2ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(DeepseekVLV2Config) @@ -144,11 +164,9 @@ def get_hf_processor(self, **kwargs: object): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_num_image_tokens(self, - *, - image_width: int, - image_height: int, - cropping: bool = True) -> int: + def get_num_image_tokens( + self, *, image_width: int, image_height: int, cropping: bool = True + ) -> int: hf_processor = self.get_hf_processor() image_size = hf_processor.image_size patch_size = hf_processor.patch_size @@ -156,9 +174,12 @@ def get_num_image_tokens(self, if cropping: best_width, best_height = hf_processor.select_best_resolution( - (image_width, image_height)) - num_width_tiles, num_height_tiles = (best_width // image_size, - best_height // image_size) + (image_width, image_height) + ) + num_width_tiles, num_height_tiles = ( + best_width // image_size, + best_height // image_size, + ) else: num_width_tiles = num_height_tiles = 1 @@ -171,15 +192,16 @@ def get_num_image_tokens(self, def get_image_size_with_most_features(self) -> ImageSize: hf_config = self.get_hf_config() candidate_resolutions = hf_config.candidate_resolutions - height, width = max(candidate_resolutions, - key=lambda x: self.get_num_image_tokens( - image_width=x[1], image_height=x[0])) + height, width = max( + candidate_resolutions, + key=lambda x: self.get_num_image_tokens( + image_width=x[1], image_height=x[0] + ), + ) return ImageSize(width=width, height=height) -class DeepseekVL2DummyInputsBuilder( - BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): - +class DeepseekVL2DummyInputsBuilder(BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -201,17 +223,18 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=max_image_size.width, - height=max_image_size.height, - 
num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=max_image_size.width, + height=max_image_size.height, + num_images=num_images, + overrides=image_overrides, + ) } class DeepseekVL2MultiModalProcessor( - BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): - + BaseMultiModalProcessor[DeepseekVL2ProcessingInfo] +): def _call_hf_processor( self, prompt: str, @@ -221,9 +244,7 @@ def _call_hf_processor( ) -> BatchFeature: if not mm_data: tokenizer = self.info.get_tokenizer() - return tokenizer(prompt, - add_special_tokens=True, - return_tensors="pt") + return tokenizer(prompt, add_special_tokens=True, return_tensors="pt") processed_outputs = super()._call_hf_processor( prompt=prompt, @@ -233,7 +254,8 @@ def _call_hf_processor( ) processed_outputs["num_patches"] = ( - processed_outputs["images_spatial_crop"].prod(-1) + 1) + processed_outputs["images_spatial_crop"].prod(-1) + 1 + ) return processed_outputs @@ -245,8 +267,7 @@ def _get_mm_fields_config( num_patches = hf_inputs.get("num_patches", torch.empty(0)) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), images_spatial_crop=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), ) @@ -264,7 +285,8 @@ def _get_prompt_updates( def get_replacement_deepseek_vl2(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -319,13 +341,16 @@ def _cached_apply_hf_processor( @MULTIMODAL_REGISTRY.register_processor( DeepseekVL2MultiModalProcessor, info=DeepseekVL2ProcessingInfo, - dummy_inputs=DeepseekVL2DummyInputsBuilder) + dummy_inputs=DeepseekVL2DummyInputsBuilder, +) class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): merge_by_field_config = True - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "language.": "language_model.", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language.": "language_model.", + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -351,9 +376,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): tokenizer = cached_tokenizer_from_config(model_config) self.image_token_id: int = tokenizer.vocab[_IMAGE_TOKEN] - self.vision = self._init_vision_module(self.vision_config, - quant_config, - maybe_prefix(prefix, "vision")) + self.vision = self._init_vision_module( + self.vision_config, quant_config, maybe_prefix(prefix, "vision") + ) self.projector = MlpProjector(self.projector_config) self.tile_tag = config.tile_tag @@ -361,14 +386,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # special token for image token sequence format embed_std = 1 / torch.sqrt( - torch.tensor(self.projector_config.n_embed, dtype=torch.float32)) + torch.tensor(self.projector_config.n_embed, dtype=torch.float32) + ) if self.tile_tag == "2D": # <|view_seperator|>, <|\n|> self.image_newline = nn.Parameter( - torch.randn(self.projector_config.n_embed) * embed_std) + torch.randn(self.projector_config.n_embed) * embed_std + ) # This is a typo in original implementation self.view_seperator = nn.Parameter( - torch.randn(self.projector_config.n_embed) * embed_std) + torch.randn(self.projector_config.n_embed) * embed_std + ) else: raise 
ValueError( f"Only 2D tile_tag is supported currently, got: {self.tile_tag}" @@ -389,19 +417,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str): """Return (parent_module, final_attr_name) for a dotted module path.""" - names = dotted_name.split('.') + names = dotted_name.split(".") parent = root for n in names[:-1]: parent = getattr(parent, n) return parent, names[-1] - #patch for timm ViT instance to support tensor parallel - def patch_vit_for_tp(self, vit: torch.nn.Module, - quant_config: QuantizationConfig): + # patch for timm ViT instance to support tensor parallel + def patch_vit_for_tp(self, vit: torch.nn.Module, quant_config: QuantizationConfig): try: import timm except ImportError as e: @@ -411,17 +439,14 @@ def patch_vit_for_tp(self, vit: torch.nn.Module, if isinstance(module, nn.Linear): parent, attr_name = self._get_parent_and_attr(vit, name) if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1": - new_linear = replace_linear_class(module, - "colwise", - quant_config, - prefix=name) + new_linear = replace_linear_class( + module, "colwise", quant_config, prefix=name + ) setattr(parent, attr_name, new_linear) - elif isinstance(parent, - timm.layers.Mlp) and attr_name == "fc2": - new_linear = replace_linear_class(module, - "rowwise", - quant_config, - prefix=name) + elif isinstance(parent, timm.layers.Mlp) and attr_name == "fc2": + new_linear = replace_linear_class( + module, "rowwise", quant_config, prefix=name + ) setattr(parent, attr_name, new_linear) return vit @@ -454,7 +479,8 @@ def _init_vision_module( return model def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[DeepseekVL2ImageInputs]: + self, **kwargs: object + ) -> Optional[DeepseekVL2ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) images_spatial_crop = kwargs.pop("images_spatial_crop", None) image_embeds = kwargs.pop("image_embeds", None) @@ -471,7 +497,8 @@ def _parse_and_validate_image_input( resolve_bindings={ "h": expected_h, "w": expected_w, - }) + }, + ) if image_embeds is not None: return DeepseekVL2VImageEmbeddingInputs( @@ -509,8 +536,9 @@ def _pixel_values_to_embedding( global_features = images_embeds[tile_index] # [num_height_tiles * num_width_tiles, hw, D] - local_features = images_embeds[tile_index + 1:tile_index + 1 + - num_tiles_in_image] + local_features = images_embeds[ + tile_index + 1 : tile_index + 1 + num_tiles_in_image + ] tile_index += num_tiles_in_image + 1 # format global and local features @@ -522,8 +550,7 @@ def _pixel_values_to_embedding( new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h) # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D] - global_features = torch.cat([global_features, new_lines_in_global], - dim=1) + global_features = torch.cat([global_features, new_lines_in_global], dim=1) # [h, w + 1, D] -> [h * (w + 1), D] global_features = global_features.view(-1, n_dim) @@ -531,22 +558,22 @@ def _pixel_values_to_embedding( # ----------------- local view add newline ----------------- # [num_height_tiles * num_width_tiles, h * w, D] -> # [num_height_tiles * h, num_width_tiles * w, D] - local_features = rearrange(local_features, - "(th tw) (h w) d -> (th h) (tw w) d", - th=num_height_tiles, - tw=num_width_tiles, - h=h, - w=w) + local_features = rearrange( + local_features, + "(th 
tw) (h w) d -> (th h) (tw w) d", + th=num_height_tiles, + tw=num_width_tiles, + h=h, + w=w, + ) # [D] -> [num_height_tiles * h, 1, D] - new_lines_in_local = repeat(self.image_newline, - "d -> (th h) 1 d", - th=num_height_tiles, - h=h) + new_lines_in_local = repeat( + self.image_newline, "d -> (th h) 1 d", th=num_height_tiles, h=h + ) # [num_height_tiles * h, num_width_tiles * w + 1, D] - local_features = torch.cat([local_features, new_lines_in_local], - dim=1) + local_features = torch.cat([local_features, new_lines_in_local], dim=1) # [num_height_tiles * h, num_width_tiles * w + 1, D] # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D] @@ -554,23 +581,28 @@ def _pixel_values_to_embedding( # merge global and local tiles if self.global_view_pos == "head": - global_local_features = torch.cat([ - global_features, - self.view_seperator[None, :], - local_features, - ]) + global_local_features = torch.cat( + [ + global_features, + self.view_seperator[None, :], + local_features, + ] + ) else: - global_local_features = torch.cat([ - local_features, - self.view_seperator[None, :], - global_features, - ]) + global_local_features = torch.cat( + [ + local_features, + self.view_seperator[None, :], + global_features, + ] + ) vision_embeddings.append(global_local_features) return vision_embeddings def _process_image_input( - self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]: + self, image_input: DeepseekVL2ImageInputs + ) -> list[torch.Tensor]: if image_input["type"] == "image_embeds": image_data = image_input["data"] if is_list_of(image_data, torch.Tensor): @@ -588,33 +620,33 @@ def _process_image_input( images_spatial_crop = image_input["images_spatial_crop"] return self._pixel_values_to_embedding( - pixel_values=pixel_values, images_spatial_crop=images_spatial_crop) + pixel_values=pixel_values, images_spatial_crop=images_spatial_crop + ) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] vision_embeddings = self._process_image_input(image_input) return vision_embeddings - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object): - + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -624,10 +656,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - autoloaded_weights = loader.load_weights(weights, - mapper=self.hf_to_vllm_mapper) + autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return 
autoloaded_weights diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index 2a09234b59ed..1ae7457fb215 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -24,6 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only dots1 model.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -35,33 +36,45 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Dots1MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -73,19 +86,24 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -96,7 +114,6 @@ def forward(self, x): class Dots1MoE(nn.Module): - def __init__( self, config: Dots1Config, @@ -109,17 +126,22 @@ def __init__( self.n_shared_experts = config.n_shared_experts if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") - - self.gate = ReplicatedLinear(config.hidden_size, - config.n_routed_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) if config.topk_method == "noaux_tc": - self.gate.e_score_correction_bias = (nn.Parameter( - torch.empty(config.n_routed_experts))) + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts) + ) else: self.gate.e_score_correction_bias = None @@ -138,11 +160,11 @@ def __init__( scoring_func=config.scoring_func, # we do scaling outside, set factor to 1.0 to avoid double mul routed_scaling_factor=1.0, - e_score_correction_bias=self.gate.e_score_correction_bias) + e_score_correction_bias=self.gate.e_score_correction_bias, + ) if config.n_shared_experts is not None: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) + intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = Dots1MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, @@ -158,19 +180,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor + final_hidden_states = ( + self.experts(hidden_states=hidden_states, router_logits=router_logits) + * self.routed_scaling_factor + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) class Dots1Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -200,8 +221,7 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. 
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = getattr(config, "head_dim", - hidden_size // self.total_num_heads) + self.head_dim = getattr(config, "head_dim", hidden_size // self.total_num_heads) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -244,14 +264,15 @@ def __init__( self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) - def forward(self, positions: torch.Tensor, - hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, positions: torch.Tensor, hidden_states: torch.Tensor + ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = self.q_norm(q.reshape(-1, self.num_heads, - self.head_dim)).reshape(q.shape) - k = self.k_norm(k.reshape(-1, self.num_kv_heads, - self.head_dim)).reshape(k.shape) + q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape( + k.shape + ) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) @@ -259,7 +280,6 @@ def forward(self, positions: torch.Tensor, class Dots1DecoderLayer(nn.Module): - def __init__( self, config: Dots1Config, @@ -272,9 +292,8 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - layer_idx = int(prefix.split(sep='.')[-1]) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx self.self_attn = Dots1Attention( @@ -289,12 +308,14 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = Dots1MoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): + self.mlp = Dots1MoE( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) else: self.mlp = Dots1MLP( hidden_size=config.hidden_size, @@ -303,10 +324,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.routed_scaling_factor = config.routed_scaling_factor def forward( @@ -319,19 +340,15 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, 
hidden_states=hidden_states) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class Dots1Model(nn.Module): - fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -350,7 +367,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) else: self.embed_tokens = PPMissingLayer() @@ -363,15 +381,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config=cache_config, quant_config=quant_config, ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -400,10 +419,9 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -412,10 +430,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -430,10 +448,10 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." 
in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) if name.endswith(".bias") and name not in params_dict: @@ -456,11 +474,13 @@ def load_weights(self, weights: Iterable[tuple[str, param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: if name.endswith(".bias") and name not in params_dict: @@ -471,15 +491,15 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): - packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -498,19 +518,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Dots1Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Dots1Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -537,8 +560,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index cda46d123901..1bc50f27269e 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -10,40 +10,52 @@ from transformers.models.qwen2_vl import Qwen2VLProcessor from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import (check_upstream_fa_availability, - maybe_get_vit_flash_attn_backend) +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import utils as dist_utils from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - 
MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, - SupportsLoRA, - SupportsMultiModal, - SupportsPP) +from vllm.model_executor.models.interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention -from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder, - Qwen2VLMultiModalProcessor, - Qwen2VLProcessingInfo) -from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, - maybe_prefix) +from vllm.model_executor.models.qwen2_vl import ( + Qwen2VLDummyInputsBuilder, + Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig, - DotsVisionConfig) +from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionConfig from vllm.utils.tensor_schema import TensorSchema, TensorShape from .vision import run_dp_sharded_mrope_vision_model @@ -59,6 +71,7 @@ class DotsOCRImagePixelInputs(TensorSchema): - ni: Number of images - cps: Number of channels * patch_size * patch_size """ + type: Literal["pixel_values"] pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")] @@ -72,18 +85,17 @@ class DotsOCRImageEmbeddingInputs(TensorSchema): - hs: Hidden size - ni: Number of images """ + type: Literal["image_embeds"] image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] -DotsOCRImageInputs = Union[DotsOCRImagePixelInputs, - DotsOCRImageEmbeddingInputs] +DotsOCRImageInputs = Union[DotsOCRImagePixelInputs, DotsOCRImageEmbeddingInputs] class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) return IMAGE_TOKEN * num_images @@ -102,23 +114,22 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), } class DotsOCRProcessingInfo(Qwen2VLProcessingInfo): - def get_hf_config(self) -> DotsOCRConfig: config = self.ctx.get_hf_config() - if not config.__class__.__name__ == 'DotsOCRConfig': + if not config.__class__.__name__ == "DotsOCRConfig": raise TypeError(f"Expected DotsOCRConfig, got {type(config)}") - if hasattr(config, "vision_config") and isinstance( - config.vision_config, dict): + if hasattr(config, 
"vision_config") and isinstance(config.vision_config, dict): config.vision_config = DotsVisionConfig(**config.vision_config) return config @@ -138,8 +149,7 @@ def get_hf_processor( self, **kwargs: object, ) -> Qwen2VLProcessor: - self.get_tokenizer( - ).image_token = IMAGE_TOKEN # Ensure image token is set + self.get_tokenizer().image_token = IMAGE_TOKEN # Ensure image token is set processor = self.ctx.get_hf_processor( Qwen2VLProcessor, **kwargs, @@ -151,13 +161,14 @@ def get_hf_processor( def rotate_half(x): """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb_vision(tensor: torch.Tensor, - freqs: torch.Tensor) -> torch.Tensor: +def apply_rotary_pos_emb_vision( + tensor: torch.Tensor, freqs: torch.Tensor +) -> torch.Tensor: orig_dtype = tensor.dtype tensor = tensor.float() @@ -175,23 +186,20 @@ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) return freqs class PatchMerger(nn.Module): - def __init__( self, dim: int, @@ -210,19 +218,23 @@ def __init__( self.ln_q = RMSNorm(context_dim, eps=1e-6) self.mlp = nn.Sequential( - ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - return_bias=False, - prefix=f"{prefix}.0", - disable_tp=use_data_parallel), + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + return_bias=False, + prefix=f"{prefix}.0", + disable_tp=use_data_parallel, + ), nn.GELU(), - RowParallelLinear(self.hidden_size, - dim, - bias=True, - return_bias=False, - prefix=f"{prefix}.2", - disable_tp=use_data_parallel), + RowParallelLinear( + self.hidden_size, + dim, + bias=True, + return_bias=False, + prefix=f"{prefix}.2", + disable_tp=use_data_parallel, + ), ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -234,26 +246,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DotsVisionAttention(nn.Module): - - def __init__(self, - config, - dim: int, - num_heads: int = 16, - bias: bool = True, - *, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False) -> None: + def __init__( + self, + config, + dim: int, + num_heads: int = 16, + bias: bool = True, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ) -> None: super().__init__() self.embed_dim = dim - self.tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) - self.tp_rank = (0 if use_data_parallel else - get_tensor_model_parallel_rank()) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( 
- num_heads, self.tp_size) + num_heads, self.tp_size + ) # qkv/proj follow Qwen2-VL style; bias controlled by arg self.qkv = QKVParallelLinear( hidden_size=dim, @@ -262,31 +276,40 @@ def __init__(self, bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv", - disable_tp=use_data_parallel) - self.proj = RowParallelLinear(input_size=dim, - output_size=dim, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.proj", - disable_tp=use_data_parallel) + disable_tp=use_data_parallel, + ) + self.proj = RowParallelLinear( + input_size=dim, + output_size=dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) # Select attention backend self.attn_backend = get_vit_attn_backend( - self.hidden_size_per_attention_head, torch.get_default_dtype()) + self.hidden_size_per_attention_head, torch.get_default_dtype() + ) self.use_upstream_fa = False - self.attn_backend, self.flash_attn_varlen_func \ - = maybe_get_vit_flash_attn_backend( + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, ) + ) if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( - f"Unsupported vision attention backend: {self.attn_backend}") + f"Unsupported vision attention backend: {self.attn_backend}" + ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def forward( @@ -317,18 +340,23 @@ def forward( q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) - output = self.flash_attn_varlen_func(q_, - k_, - v_, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) - context_layer = output.view(bs, -1, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + output = self.flash_attn_varlen_func( + q_, + k_, + v_, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + ) + context_layer = output.view( + bs, + -1, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) elif self.attn_backend == _Backend.TORCH_SDPA: outputs = [] for i in range(1, len(cu_seqlens)): @@ -337,21 +365,20 @@ def forward( q_i = q[:, s:e].permute(0, 2, 1, 3) k_i = k[:, s:e].permute(0, 2, 1, 3) v_i = v[:, s:e].permute(0, 2, 1, 3) - out_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) out_i = out_i.permute(0, 2, 1, 3) outputs.append(out_i) context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) else: raise RuntimeError("Unsupported 
attention backend") @@ -363,31 +390,36 @@ def forward( class DotsSwiGLUFFN(nn.Module): - - def __init__(self, - config, - *, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() hidden_features = config.intermediate_size in_features = config.embed_dim bias = config.use_bias # Referenced aimv2.py AIMv2SwiGLUFFN - self.fc13 = MergedColumnParallelLinear(in_features, - [hidden_features] * 2, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.fc13", - disable_tp=use_data_parallel) - self.fc2 = RowParallelLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.fc2", - disable_tp=use_data_parallel) + self.fc13 = MergedColumnParallelLinear( + in_features, + [hidden_features] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc13", + disable_tp=use_data_parallel, + ) + self.fc2 = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, + ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -396,8 +428,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x, _ = self.fc2(x) return x - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("fc13", "fc1", 0), ("fc13", "fc3", 1), @@ -405,7 +436,6 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -423,15 +453,13 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class DotsPatchEmbed(nn.Module): - def __init__(self, config): super().__init__() self.num_channels = config.num_channels @@ -448,15 +476,19 @@ def __init__(self, config): self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: - x = x.view(-1, self.num_channels, self.temporal_patch_size, - self.patch_size, self.patch_size)[:, :, 0] + x = x.view( + -1, + self.num_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + )[:, :, 0] x = self.proj(x).view(-1, self.embed_dim) x = self.norm(x) return x class DotsViTPreprocessor(nn.Module): - def __init__(self, config): super().__init__() self.patch_h = config.patch_size @@ -471,7 +503,6 @@ def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: class DotsVisionBlock(nn.Module): - def __init__( self, config, @@ -482,27 +513,33 @@ def __init__( ): super().__init__() - self.attn = DotsVisionAttention(config, - config.embed_dim, - num_heads=config.num_attention_heads, - bias=config.use_bias, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel) + self.attn = DotsVisionAttention( + config, + config.embed_dim, + num_heads=config.num_attention_heads, + bias=config.use_bias, 
+ quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, + ) self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) - self.mlp = DotsSwiGLUFFN(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) + self.mlp = DotsSwiGLUFFN( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) - def forward(self, - hidden_states: torch.Tensor, - *, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, - seqlens: Optional[list[int]] = None) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + *, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, + seqlens: Optional[list[int]] = None, + ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, @@ -515,7 +552,6 @@ def forward(self, class DotsVisionTransformer(nn.Module): - def __init__( self, config: DotsVisionConfig, @@ -535,26 +571,34 @@ def __init__( head_dim = config.embed_dim // config.num_attention_heads self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype()) - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability(torch.get_default_dtype()): + head_size=head_dim, dtype=torch.get_default_dtype() + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): self.attn_backend = _Backend.FLASH_ATTN self.out_hidden_size = config.hidden_size # Keep blocks for compatibility with other vision towers - num_layers = (config.num_hidden_layers if num_hidden_layers_override - is None else num_hidden_layers_override) - self.blocks = nn.ModuleList([ - DotsVisionBlock(config, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{i}", - use_data_parallel=use_data_parallel) - for i in range(num_layers) - ]) + num_layers = ( + config.num_hidden_layers + if num_hidden_layers_override is None + else num_hidden_layers_override + ) + self.blocks = nn.ModuleList( + [ + DotsVisionBlock( + config, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}", + use_data_parallel=use_data_parallel, + ) + for i in range(num_layers) + ] + ) if require_post_norm is None: - require_post_norm = (len(self.blocks) == config.num_hidden_layers) + require_post_norm = len(self.blocks) == config.num_hidden_layers if require_post_norm and self.config.post_norm: - self.post_trunk_norm = RMSNorm(config.embed_dim, - eps=config.rms_norm_eps) + self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) else: self.post_trunk_norm = None @@ -573,7 +617,7 @@ def dtype(self) -> torch.dtype: def device(self) -> torch.device: return self.patch_embed.patchifier.proj.weight.device - def get_pos_ids_by_grid(self, grid_thw): + def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]: pos_ids = [] for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) @@ -595,56 +639,58 @@ def get_pos_ids_by_grid(self, grid_thw): ) wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) return pos_ids - def rot_pos_emb(self, grid_thw): + def 
rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: pos_ids = self.get_pos_ids_by_grid(grid_thw) pos_ids = torch.cat(pos_ids, dim=0) - max_grid_size = grid_thw[:, 1:].max() + max_grid_size = max(max(h, w) for _, h, w in grid_thw) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb def compute_attn_mask_seqlen( - self, cu_seqlens: torch.Tensor + self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens - def forward(self, hidden_states: torch.Tensor, - grid_thw: list[list[int]]) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, grid_thw: list[list[int]] + ) -> torch.Tensor: + rotary_pos_emb = self.rot_pos_emb(grid_thw) + # Convert grid_thw to tensor (always expecting list format now) - grid_thw = torch.tensor(grid_thw, - device=hidden_states.device, - dtype=torch.long) + grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long) hidden_states = hidden_states.to(self.dtype) hidden_states = self.patch_embed(hidden_states, grid_thw) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, - dtype=grid_thw.dtype - if torch.jit.is_tracing() else torch.int32, - ) + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) for blk in self.blocks: - hidden_states = blk(hidden_states, - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - max_seqlen=max_seqlen, - seqlens=seqlens) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) if self.post_trunk_norm is not None: hidden_states = self.post_trunk_norm(hidden_states) @@ -658,8 +704,7 @@ def forward(self, hidden_states: torch.Tensor, info=DotsOCRProcessingInfo, dummy_inputs=DotsOCRDummyInputsBuilder, ) -class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA): +class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( @@ -709,7 +754,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vision_config, quant_config=self.quant_config, prefix=maybe_prefix(prefix, "vision_tower"), - use_data_parallel=self.use_data_parallel) + use_data_parallel=self.use_data_parallel, + ) self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( vllm_config=vllm_config, hf_config=self.config, @@ -718,7 +764,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[DotsOCRImageInputs]: + self, **kwargs: object + ) -> Optional[DotsOCRImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ 
-727,27 +774,30 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - return DotsOCRImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return DotsOCRImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: - return DotsOCRImageEmbeddingInputs(type="image_embeds", - image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + return DotsOCRImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) def _process_image_input( - self, image_input: DotsOCRImageInputs) -> tuple[torch.Tensor, ...]: + self, image_input: DotsOCRImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type( - self.vision_tower.dtype) + image_embeds = image_input["image_embeds"].type(self.vision_tower.dtype) else: - pixel_values = image_input["pixel_values"].type( - self.vision_tower.dtype) + pixel_values = image_input["pixel_values"].type(self.vision_tower.dtype) if self.use_data_parallel: return run_dp_sharded_mrope_vision_model( @@ -757,21 +807,23 @@ def _process_image_input( rope_type="rope_3d", ) else: - image_embeds = self.vision_tower( - pixel_values, grid_thw)[:, :self.config.hidden_size] + image_embeds = self.vision_tower(pixel_values, grid_thw_list)[ + :, : self.config.hidden_size + ] # Split concatenated embeddings for each image item. merge_size = self.vision_tower.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return image_embeds.split(sizes) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -812,8 +864,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ernie45.py b/vllm/model_executor/models/ernie45.py index e7302dc5ecdd..b1d26cddcc5e 100644 --- a/vllm/model_executor/models/ernie45.py +++ b/vllm/model_executor/models/ernie45.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Erine model compatible with HuggingFace weights.""" + from vllm.config import VllmConfig from vllm.model_executor.models.llama import LlamaForCausalLM @@ -29,7 +30,6 @@ class Ernie4_5ForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) # Hack Llama model to fit HF format Ernie4.5 dense implementation diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index 38c5249380c3..3cb93177a383 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only ErineMoE model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -38,30 +39,40 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Ernie4_5_MoeMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -74,19 +85,24 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, bias=use_bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=use_bias, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -97,7 +113,6 @@ def forward(self, x): class Ernie4_5_MoeMoE(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -109,23 +124,26 @@ def __init__( layer_idx = extract_layer_index(prefix) self.layer_idx = layer_idx self.tp_size = get_tensor_model_parallel_world_size() - self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0) - > 0) + self.has_shared_experts = getattr(config, "moe_num_shared_experts", 0) > 0 if self.tp_size > config.moe_num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.moe_num_experts}.") - - self.gate = ReplicatedLinear(config.hidden_size, - config.moe_num_experts, - bias=False, - params_dtype=torch.float32, - quant_config=None, - prefix=f"{prefix}.gate") + f"the number of experts {config.moe_num_experts}." + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts, + bias=False, + params_dtype=torch.float32, + quant_config=None, + prefix=f"{prefix}.gate", + ) self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.moe_num_experts, dtype=torch.float32)) + torch.empty(config.moe_num_experts, dtype=torch.float32) + ) self.experts = FusedMoE( num_experts=config.moe_num_experts, @@ -136,19 +154,21 @@ def __init__( renormalize=True, quant_config=quant_config, prefix=f"{prefix}.experts", - e_score_correction_bias=self.gate.e_score_correction_bias) + e_score_correction_bias=self.gate.e_score_correction_bias, + ) if self.has_shared_experts: - intermediate_size = (config.moe_intermediate_size * - config.moe_num_shared_experts) + intermediate_size = ( + config.moe_intermediate_size * config.moe_num_shared_experts + ) self.shared_experts = Ernie4_5_MoeMLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.shared_experts", - reduce_results=self.experts.must_reduce_shared_expert_outputs( - )) + reduce_results=self.experts.must_reduce_shared_expert_outputs(), + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape @@ -160,23 +180,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) - if self.has_shared_experts and \ - shared_output is not None: + if self.has_shared_experts and shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = ( - self.experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states)) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) return final_hidden_states.view(orig_shape) class Ernie4_5_MoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -219,19 +238,23 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = QKVParallelLinear(hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + 
prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) self.rotary_emb = get_rope( self.head_dim, @@ -241,20 +264,21 @@ def __init__( is_neox_style=False, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -268,7 +292,6 @@ def forward( class Ernie4_5_MoeDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -280,18 +303,17 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 500000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 131072) + max_position_embeddings = getattr(config, "max_position_embeddings", 131072) self.self_attn = Ernie4_5_MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, - head_dim=getattr(config, 'head_dim', None), + head_dim=getattr(config, "head_dim", None), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'use_bias', False), + qkv_bias=getattr(config, "use_bias", False), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -303,30 +325,35 @@ def __init__( # MoE moe_num_experts = getattr(config, "moe_num_experts", 0) moe_layer_start_index = getattr(config, "moe_layer_start_index", 0) - moe_layer_end_index = getattr(config, "moe_layer_end_index", - config.num_hidden_layers - 1) + moe_layer_end_index = getattr( + config, "moe_layer_end_index", config.num_hidden_layers - 1 + ) moe_layer_interval = getattr(config, "moe_layer_interval", 1) use_moe = getattr(config, "use_moe", moe_num_experts > 0) - if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0) - and layer_idx >= moe_layer_start_index - and layer_idx <= moe_layer_end_index): - self.mlp = Ernie4_5_MoeMoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if ( + use_moe + and ((layer_idx + 1) % moe_layer_interval == 0) + and layer_idx >= moe_layer_start_index + and layer_idx <= moe_layer_end_index + ): + self.mlp = Ernie4_5_MoeMoE( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) else: self.mlp = Ernie4_5_MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - use_bias=getattr(config, 'use_bias', False), + use_bias=getattr(config, "use_bias", False), quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - 
eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -334,14 +361,12 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> torch.Tensor: - # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, @@ -349,8 +374,7 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) @@ -359,7 +383,6 @@ def forward( @support_torch_compile class Ernie4_5_MoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -376,16 +399,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Ernie4_5_MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Ernie4_5_MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) @@ -394,9 +420,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -408,7 +434,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -424,27 +449,25 @@ def forward( hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.moe_num_experts) + num_experts=self.config.moe_num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # 
(param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -458,8 +481,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if self.config.tie_word_embeddings and name.endswith( - "lm_head.weight"): + if self.config.tie_word_embeddings and name.endswith("lm_head.weight"): continue # MTP will be supported soon. if "mtp" in name: @@ -469,17 +491,18 @@ def load_weights(self, weights: Iterable[tuple[str, name = name.replace("moe_statics", "gate") loaded_weight = loaded_weight.squeeze(0) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -502,22 +525,26 @@ def load_weights(self, weights: Iterable[tuple[str, continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): @@ -528,8 +555,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -556,15 +584,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Ernie4_5_MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Ernie4_5_MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() @@ -572,7 +602,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -584,8 +615,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -595,12 +627,10 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 903ddf7953ea..493260cf73ef 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Erine VL model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable, Mapping, Sequence from functools import partial @@ -35,8 +36,10 @@ from transformers import BatchFeature from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import (check_upstream_fa_availability, - maybe_get_vit_flash_attn_backend) +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state @@ -44,26 +47,38 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageSize, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix from .vision import get_vit_attn_backend @@ -78,15 +93,14 @@ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), - "... d two -> ... (d two)", - two=2) + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) -def apply_rotary_emb_torch(x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False) -> torch.Tensor: +def apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) @@ -94,22 +108,21 @@ def apply_rotary_emb_torch(x: torch.Tensor, ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] cos = repeat( - cos, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) sin = repeat( - sin, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)" + ) return torch.cat( [ - x[..., :ro_dim] * cos + - rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], ], dim=-1, ) -def apply_rotary_pos_emb_vision(t: torch.Tensor, - freqs: torch.Tensor) -> torch.Tensor: +def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: t_ = t.float() cos = freqs.cos() sin = freqs.sin() @@ -123,14 +136,14 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): """All-gather the input tensor interleavely across model parallel group.""" import torch.distributed as dist + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] - dist.all_gather(gathered_tensors, - local_tensor, - group=parallel_state.get_tp_group().device_group) + dist.all_gather( + gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group + ) gathered_tensors_split = [ - torch.split(tensor, hidden_size // tp_size, -1) - for tensor in gathered_tensors + torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors ] ordered_tensors = [ tensor for pair in zip(*gathered_tensors_split) for tensor in pair @@ -155,9 +168,11 @@ def __init__( self.tp_size = parallel_state.get_tensor_model_parallel_world_size() self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, self.tp_size) + num_heads, self.tp_size + ) self.qkv = QKVParallelLinear( hidden_size=embed_dim, @@ -166,69 +181,79 @@ def __init__( total_num_kv_heads=num_heads, bias=True, quant_config=quant_config, - prefix=f"{prefix}.qkv") - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj") + prefix=f"{prefix}.qkv", + ) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + ) # Detect attention implementation. self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype()) + dtype=torch.get_default_dtype(), + ) self.use_upstream_fa = False - self.attn_backend, self.flash_attn_varlen_func \ - = maybe_get_vit_flash_attn_backend( + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, ) + ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( f"Ernie45-VL does not support {self.attn_backend} backend now." 
) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape if self.tp_size > 1: - qkv = all_gather_interleave(qkv, self.qkv.hidden_size, - self.tp_size) + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) # 3 * [s, b, head * head_dim] if self.tp_size > 1: - splitter = partial(dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] v = splitter(v)[self.tp_rank] # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) q, k, v = (x.view(*new_shape) for x in (q, k, v)) return q, k, v def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -237,30 +262,30 @@ def forward( q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: qk_concat = torch.cat([q, k], dim=0) qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = self.flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) - - context_layer = rearrange(output, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() + output = self.flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + ) + + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
outputs = [] @@ -270,36 +295,36 @@ def forward( q_i = q[:, start_idx:end_idx] k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() output, _ = self.proj(context_layer) return output class Ernie4_5_VisionMLP(nn.Module): - def __init__( self, in_features: int, @@ -309,15 +334,19 @@ def __init__( prefix: str = "", ): super().__init__() - self.fc1 = ColumnParallelLinear(in_features, - hidden_features, - quant_config=quant_config, - prefix=f"{prefix}.fc1") + self.fc1 = ColumnParallelLinear( + in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) self.act = act_layer() - self.fc2 = RowParallelLinear(hidden_features, - in_features, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc2 = RowParallelLinear( + hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -327,7 +356,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Ernie4_5_VisionBlock(nn.Module): - def __init__( self, dim: int, @@ -346,27 +374,30 @@ def __init__( self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.attn = Ernie4_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Ernie4_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) - self.mlp = Ernie4_5_VisionMLP(dim, - mlp_hidden_dim, - act_layer=act_layer, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = Ernie4_5_VisionMLP( + dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) def forward( - self, - hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - 
hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, @@ -379,7 +410,6 @@ def forward( class Ernie4_5_VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -387,18 +417,16 @@ def __init__( embed_dim: int = 1280, prefix="", ) -> None: - super().__init__() self.patch_size = patch_size self.in_channels = in_channels self.embed_dim = embed_dim - self.proj = nn.Linear(in_channels * patch_size * patch_size, - embed_dim, - bias=False) + self.proj = nn.Linear( + in_channels * patch_size * patch_size, embed_dim, bias=False + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - target_dtype = self.proj.weight.dtype hidden_states = hidden_states.to(target_dtype) hidden_states = self.proj(hidden_states) @@ -407,22 +435,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Ernie4_5_VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - self.inv_freq = 1.0 / theta**( - torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim) + self.inv_freq = 1.0 / theta ** ( + torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim + ) def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(input=seq, vec2=self.inv_freq) return freqs class Ernie4_5_VisionTransformer(nn.Module): - def __init__( self, vision_config, @@ -430,7 +457,6 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super().__init__() patch_size = vision_config.patch_size spatial_merge_size = vision_config.spatial_merge_size @@ -456,24 +482,31 @@ def __init__( head_dim = embed_dim // num_heads self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList([ - Ernie4_5_VisionBlock(dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(depth) - ]) - - assert (hidden_size == embed_dim - ), "vit's config.hidden must be equal to config.embed_dim" + self.blocks = nn.ModuleList( + [ + Ernie4_5_VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) + for layer_idx in range(depth) + ] + ) + + assert hidden_size == embed_dim, ( + "vit's config.hidden must be equal to config.embed_dim" + ) self.ln = nn.LayerNorm(hidden_size, eps=1e-6) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype()) - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability(torch.get_default_dtype()): + head_size=head_dim, dtype=torch.get_default_dtype() + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): self.attn_backend = _Backend.FLASH_ATTN @property @@ -489,20 +522,27 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = 
wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) @@ -510,29 +550,29 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: return rotary_pos_emb def compute_attn_mask_seqlen( - self, cu_seqlens: torch.Tensor + self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens - def forward(self, - hidden_states: torch.Tensor, - grid_thw: torch.Tensor, - num_pad=0) -> torch.Tensor: - + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0 + ) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) if num_pad > 0: cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) @@ -569,8 +609,7 @@ def load_weights(self, weights) -> set[str]: for name, loaded_weight in weights: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -587,6 +626,7 @@ class Ernie4_5_VLImagePixelInputs(TensorSchema): - ni: Number of images - cps: Number of channels * patch_size * patch_size """ + type: Literal["pixel_values"] pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")] @@ -605,6 +645,7 @@ class Ernie4_5_VLVideoPixelInputs(TensorSchema): - cps: Number of channels * temporal_patch_size * patch_size * patch_size """ + type: Literal["pixel_values_videos"] pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")] video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] @@ -664,14 +705,15 @@ def smart_resize( class VariableResolutionResamplerModel(nn.Module): - - def __init__(self, - in_dim, - out_dim, - spatial_conv_size, - temporal_conv_size, - config, - prefix: str = "") -> None: + def __init__( + self, + in_dim, + out_dim, + spatial_conv_size, + temporal_conv_size, + config, + prefix: str = "", + ) -> None: super().__init__() self.in_dim = in_dim self.out_dim = out_dim @@ -681,18 +723,21 @@ def 
__init__(self, self.use_temporal_conv = config.use_temporal_conv # compress 2d conv(picture) to 1d - self.spatial_dim = (self.in_dim * self.spatial_conv_size * - self.spatial_conv_size) + self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size # compress 3d conv(video) to 1d - self.temporal_dim = (self.in_dim * self.spatial_conv_size * - self.spatial_conv_size * self.temporal_conv_size) + self.temporal_dim = ( + self.in_dim + * self.spatial_conv_size + * self.spatial_conv_size + * self.temporal_conv_size + ) self.spatial_linear1 = ColumnParallelLinear( self.spatial_dim, self.spatial_dim, bias=True, gather_output=True, - quant_config=getattr(config, 'quant_config', None), + quant_config=getattr(config, "quant_config", None), prefix=f"{prefix}.spatial_linear1", ) @@ -703,7 +748,7 @@ def __init__(self, self.spatial_dim, bias=True, gather_output=True, - quant_config=getattr(config, 'quant_config', None), + quant_config=getattr(config, "quant_config", None), prefix=f"{prefix}.spatial_linear2", ) @@ -715,7 +760,7 @@ def __init__(self, self.spatial_dim, bias=True, gather_output=True, - quant_config=getattr(config, 'quant_config', None), + quant_config=getattr(config, "quant_config", None), prefix=f"{prefix}.temporal_linear1", ) @@ -726,7 +771,7 @@ def __init__(self, self.spatial_dim, bias=True, gather_output=True, - quant_config=getattr(config, 'quant_config', None), + quant_config=getattr(config, "quant_config", None), prefix=f"{prefix}.temporal_linear2", ) @@ -737,12 +782,13 @@ def __init__(self, self.out_dim, bias=True, gather_output=True, - quant_config=getattr(config, 'quant_config', None), + quant_config=getattr(config, "quant_config", None), prefix=f"{prefix}.mlp", ) - self.after_norm = RMSNorm(hidden_size=out_dim, - eps=getattr(config, 'rms_norm_eps', 1e-6)) + self.after_norm = RMSNorm( + hidden_size=out_dim, eps=getattr(config, "rms_norm_eps", 1e-6) + ) def spatial_conv_reshape(self, x, spatial_conv_size): S, C = x.shape @@ -750,7 +796,6 @@ def spatial_conv_reshape(self, x, spatial_conv_size): return x def forward(self, x, grid_thw): - def fwd_spatial(x): x = self.spatial_conv_reshape(x, self.spatial_conv_size) @@ -762,43 +807,48 @@ def fwd_spatial(x): return x def fwd_placeholder(x, grid_thw, to_tensor=False): - grid_thw_cpu = grid_thw.cpu().numpy() grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] - grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size** - 2) + grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size**2) - tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // ( - self.spatial_conv_size**2) - batch_offset = np.empty(tokens_per_img_or_vid.size, - dtype=tokens_per_img_or_vid.dtype) + tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // (self.spatial_conv_size**2) + batch_offset = np.empty( + tokens_per_img_or_vid.size, dtype=tokens_per_img_or_vid.dtype + ) batch_offset[0] = 0 batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] slice_offsets = [] for temporoal_size, spatial_size, b_offset in zip( - grid_t, grid_hw_after_conv, batch_offset): + grid_t, grid_hw_after_conv, batch_offset + ): for temp_offset in range(0, temporoal_size, 2): slice_offsets.append( np.arange( b_offset + (temp_offset) * spatial_size, b_offset + (temp_offset + 1) * spatial_size, - )) - slice_offsets = torch.tensor(np.concatenate(slice_offsets, - axis=-1)).to(x.device) + ) + ) + slice_offsets = torch.tensor(np.concatenate(slice_offsets, axis=-1)).to( + x.device + ) slice_offsets2 = [] for temporoal_size, spatial_size, b_offset in zip( - grid_t, 
grid_hw_after_conv, batch_offset): - for temp_offset in range(1 if temporoal_size > 1 else 0, - temporoal_size, 2): + grid_t, grid_hw_after_conv, batch_offset + ): + for temp_offset in range( + 1 if temporoal_size > 1 else 0, temporoal_size, 2 + ): slice_offsets2.append( np.arange( b_offset + (temp_offset) * spatial_size, b_offset + (temp_offset + 1) * spatial_size, - )) - slice_offsets2 = torch.tensor( - np.concatenate(slice_offsets2, axis=-1)).to(x.device) + ) + ) + slice_offsets2 = torch.tensor(np.concatenate(slice_offsets2, axis=-1)).to( + x.device + ) x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) @@ -824,9 +874,7 @@ def fwd_mlp(x): x = fwd_mlp(x) return x - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() @@ -834,15 +882,13 @@ def load_weights(self, weights: Iterable[tuple[str, if name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.model_config.hf_config @@ -890,11 +936,9 @@ def _get_vision_info( min_pixels=image_processor.min_pixels, max_pixels=image_processor.max_pixels, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) grid_t = max(num_frames // temporal_conv_size, 1) grid_h = preprocessed_size.height // patch_size @@ -987,8 +1031,7 @@ def get_num_frames_with_most_features( max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 2) @@ -1003,15 +1046,12 @@ def get_max_video_tokens( return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) -class Ernie4_5VLMultiModalProcessor( - BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): - +class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): def _pixel_values_norm( self, pixel_values: torch.Tensor, @@ -1020,28 +1060,32 @@ def _pixel_values_norm( hf_config = self.info.get_hf_config() vision_config = hf_config.vision_config image_processor = self.info.get_image_processor(**mm_kwargs) - image_mean_tensor = torch.tensor(image_processor.image_mean, - dtype=torch.float32).reshape( - [1, 3, 1, 1]) - image_std_tensor = torch.tensor(image_processor.image_std, - dtype=torch.float32).reshape( - [1, 3, 1, 1]) - rescale_factor = torch.tensor(image_processor.rescale_factor, - dtype=torch.float32) + image_mean_tensor = torch.tensor( + 
image_processor.image_mean, dtype=torch.float32 + ).reshape([1, 3, 1, 1]) + image_std_tensor = torch.tensor( + image_processor.image_std, dtype=torch.float32 + ).reshape([1, 3, 1, 1]) + rescale_factor = torch.tensor( + image_processor.rescale_factor, dtype=torch.float32 + ) patch_size_squared = vision_config.patch_size**2 - image_mean_tensor = (image_mean_tensor.squeeze( - [-2, -1]).repeat_interleave(patch_size_squared, -1)) - image_std_tensor = (image_std_tensor.squeeze( - [-2, -1]).repeat_interleave(patch_size_squared, -1)) + image_mean_tensor = image_mean_tensor.squeeze([-2, -1]).repeat_interleave( + patch_size_squared, -1 + ) + image_std_tensor = image_std_tensor.squeeze([-2, -1]).repeat_interleave( + patch_size_squared, -1 + ) if not image_mean_tensor.is_contiguous(): image_mean_tensor = image_mean_tensor.contiguous() if not image_std_tensor.is_contiguous(): image_std_tensor = image_std_tensor.contiguous() - pixel_values = (rescale_factor * pixel_values.to(torch.float32) - - image_mean_tensor) / image_std_tensor + pixel_values = ( + rescale_factor * pixel_values.to(torch.float32) - image_mean_tensor + ) / image_std_tensor pixel_values = pixel_values.to(hf_config.torch_dtype) return pixel_values @@ -1057,8 +1101,9 @@ def _call_hf_processor( if "images" not in mm_data and "videos" not in mm_data and prompt != "": tokenizer = self.info.get_tokenizer() prompt_ids = tokenizer.encode(prompt) - tokenizer_output = BatchFeature(dict(input_ids=[prompt_ids]), - tensor_type="pt") + tokenizer_output = BatchFeature( + dict(input_ids=[prompt_ids]), tensor_type="pt" + ) return tokenizer_output if "images" not in mm_data: @@ -1067,38 +1112,40 @@ def _call_hf_processor( mm_data["videos"] = [] processor_output = self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), - dict(text=[prompt], - images=mm_data["images"], - videos=mm_data["videos"]), + dict(text=[prompt], images=mm_data["images"], videos=mm_data["videos"]), dict(**mm_kwargs, **tok_kwargs), ) # Divide the processor_output into two modalities: image and video. 
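# A minimal sketch of the image/video split that the block below performs,
# assuming the HF processor returns one flattened patch tensor plus a per-item
# grid_thw of shape (N, 3) whose rows with temporal extent > 1 are videos, and
# that image patches precede video patches in the flat tensor (as the code
# below assumes). The helper name `split_patches_by_grid_thw` is illustrative
# only and not part of this change.
import torch

def split_patches_by_grid_thw(patches: torch.Tensor, grid_thw: torch.Tensor):
    is_video = grid_thw[:, 0] > 1
    image_grid_thw = grid_thw[~is_video]
    video_grid_thw = grid_thw[is_video]
    # Image patches come first, so their total count marks the cut point.
    num_image_patches = int(image_grid_thw.prod(dim=1).sum())
    pixel_values = patches[:num_image_patches]
    pixel_values_videos = patches[num_image_patches:]
    return (pixel_values, image_grid_thw), (pixel_values_videos, video_grid_thw)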
if processor_output is not None: - pixel_values = processor_output['images'] + pixel_values = processor_output["images"] if pixel_values is not None: - processor_output['images'] = self._pixel_values_norm( - pixel_values, mm_kwargs) + processor_output["images"] = self._pixel_values_norm( + pixel_values, mm_kwargs + ) for key in list(processor_output.keys()): if processor_output[key] is None: del processor_output[key] continue if key == "grid_thw": - grid_thw = processor_output['grid_thw'] - pixel_values_all = processor_output['images'] + grid_thw = processor_output["grid_thw"] + pixel_values_all = processor_output["images"] # Identify elements where the first # dimension is greater than 1 and # treat them as the video modality mask = grid_thw[:, 0] > 1 processor_output["video_grid_thw"] = grid_thw[mask] processor_output["image_grid_thw"] = grid_thw[~mask] - image_patch_num = processor_output["image_grid_thw"].prod( - dim=1).sum() - processor_output[ - 'pixel_values'] = pixel_values_all[:image_patch_num] - processor_output['pixel_values_videos'] = pixel_values_all[ - image_patch_num:] - del processor_output['images'] + image_patch_num = ( + processor_output["image_grid_thw"].prod(dim=1).sum() + ) + processor_output["pixel_values"] = pixel_values_all[ + :image_patch_num + ] + processor_output["pixel_values_videos"] = pixel_values_all[ + image_patch_num: + ] + del processor_output["images"] return processor_output @@ -1112,13 +1159,13 @@ def _get_prompt_updates( before_placeholder = { "image": "<|image@placeholder|>", - "video": "<|video@placeholder|>" + "video": "<|video@placeholder|>", } after_placeholder = { # image and video have same placeholder "image": "<|IMAGE_PLACEHOLDER|>", - "video": "<|IMAGE_PLACEHOLDER|>" + "video": "<|IMAGE_PLACEHOLDER|>", } merge_length = hf_processor.spatial_conv_size**2 @@ -1128,8 +1175,11 @@ def get_replacement_ernie45vl(item_idx: int, modality: str): grid_thw = out_item[f"{modality}_grid_thw"].data assert isinstance(grid_thw, torch.Tensor) if modality == "video": - num_tokens = int(grid_thw.prod( - )) // hf_processor.temporal_conv_size // merge_length + num_tokens = ( + int(grid_thw.prod()) + // hf_processor.temporal_conv_size + // merge_length + ) else: num_tokens = int(grid_thw.prod()) // merge_length return after_placeholder[modality] * num_tokens @@ -1138,9 +1188,9 @@ def get_replacement_ernie45vl(item_idx: int, modality: str): PromptReplacement( modality=modality, target=before_placeholder[modality], - replacement=partial(get_replacement_ernie45vl, - modality=modality), - ) for modality in ("image", "video") + replacement=partial(get_replacement_ernie45vl, modality=modality), + ) + for modality in ("image", "video") ] def _get_mm_fields_config( @@ -1148,7 +1198,6 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_grid_sizes = image_grid_thw.prod(-1) @@ -1157,28 +1206,28 @@ def _get_mm_fields_config( return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), + "image", image_grid_sizes + ), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), video_grid_thw=MultiModalFieldConfig.batched("video"), ) -class Ernie4_5_VLDummyInputsBuilder( - BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): - +class 
Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) prompt = "" for i in range(num_images): - prompt += (f"Picture {i+1}:" - "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>") + prompt += ( + f"Picture {i + 1}:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + ) for i in range(num_videos): - prompt += (f"Video {i+1}:" - "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>") + prompt += f"Video {i + 1}:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" return prompt def get_dummy_mm_data( @@ -1190,35 +1239,39 @@ def get_dummy_mm_data( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "video": - self._get_dummy_videos(width=target_width, - height=target_height, - num_frames=target_num_frames, - num_videos=num_videos, - overrides=video_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ), } @MULTIMODAL_REGISTRY.register_processor( Ernie4_5VLMultiModalProcessor, info=Ernie4_5_VLProcessingInfo, - dummy_inputs=Ernie4_5_VLDummyInputsBuilder) -class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): + dummy_inputs=Ernie4_5_VLDummyInputsBuilder, +) +class Ernie4_5_VLMoeForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP +): merge_by_field_config = True packed_modules_mapping = { @@ -1250,7 +1303,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, "temporal_linear.0.": "temporal_linear1.", "temporal_linear.2.": "temporal_linear2.", "temporal_linear.3.": "temporal_norm.", - }) + }, + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -1288,11 +1342,13 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.config.spatial_conv_size, self.config.temporal_conv_size, config=self.config, - prefix=maybe_prefix(prefix, "resampler_model")) + prefix=maybe_prefix(prefix, "resampler_model"), + ) self.visual_token_mask = None self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def compute_logits( self, @@ -1311,7 +1367,8 @@ def _vision_forward( if grid_thw.numel() % 3 != 0: raise ValueError( f"grid_thw has {grid_thw.numel()} elements after filtering," - "which is not divisible by 3.") + "which is not divisible by 3." 
+ ) grid_thw = grid_thw.reshape(-1, 3) # example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]] grid_thw = F.pad( @@ -1324,8 +1381,9 @@ def _vision_forward( def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: if getattr(self.config, "im_patch_id", None) is not None: - self.visual_token_mask = ( - input_ids == self.config.im_patch_id).reshape(-1, 1) + self.visual_token_mask = (input_ids == self.config.im_patch_id).reshape( + -1, 1 + ) else: self.visual_token_mask = None @@ -1333,7 +1391,8 @@ def get_language_model(self) -> torch.nn.Module: return self.language_model def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]: + self, **kwargs: object + ) -> Optional[Ernie4_5_VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1341,12 +1400,15 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - return Ernie4_5_VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Ernie4_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Ernie4_5_VLVideoInputs]: + self, **kwargs: object + ) -> Optional[Ernie4_5_VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1361,16 +1423,15 @@ def _parse_and_validate_video_input( ) def _process_image_input( - self, - image_input: Ernie4_5_VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Ernie4_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 - pixel_values = image_input["pixel_values"].type( - self.vision_model.dtype) - image_features = self._vision_forward(pixel_values=pixel_values, - grid_thw=grid_thw) + pixel_values = image_input["pixel_values"].type(self.vision_model.dtype) + image_features = self._vision_forward( + pixel_values=pixel_values, grid_thw=grid_thw + ) image_embeds = self.resampler_model(image_features, grid_thw) merge_size = self.vision_model.spatial_merge_size @@ -1379,21 +1440,25 @@ def _process_image_input( return image_embeds.split(sizes.tolist()) def _process_video_input( - self, - video_input: Ernie4_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: - + self, video_input: Ernie4_5_VLVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 pixel_values_videos = video_input["pixel_values_videos"].type( - self.vision_model.dtype) - video_features = self._vision_forward(pixel_values=pixel_values_videos, - grid_thw=grid_thw) + self.vision_model.dtype + ) + video_features = self._vision_forward( + pixel_values=pixel_values_videos, grid_thw=grid_thw + ) video_embeds = self.resampler_model(video_features, grid_thw) merge_size = self.vision_model.spatial_merge_size - sizes = (grid_thw.prod(-1) // - self.config.temporal_conv_size) // merge_size // merge_size + sizes = ( + (grid_thw.prod(-1) // self.config.temporal_conv_size) + // merge_size + // merge_size + ) return video_embeds.split(sizes.tolist()) @@ -1403,20 +1468,22 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
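# The loop below leans on the fact that **kwargs is an ordinary dict (ordered
# since Python 3.7), so the modality whose tensors were passed first is parsed
# first. A tiny illustration with the same key names; argument values are
# placeholders and the helper is hypothetical:
def modality_order(**kwargs) -> list[str]:
    order = []
    for key in kwargs:
        if key in ("pixel_values", "image_embeds") and "images" not in order:
            order.append("images")
        if key in ("pixel_values_videos", "video_embeds") and "videos" not in order:
            order.append("videos")
    return order

# modality_order(pixel_values_videos=..., pixel_values=...) -> ["videos", "images"]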
for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_videos", - "video_embeds") and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "videos" not in modalities + ): + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - + self, **kwargs: object + ) -> Optional[MultiModalEmbeddings]: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return None @@ -1447,8 +1514,7 @@ def get_input_embeddings( is_multimodal: Optional[torch.Tensor] = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: - if multimodal_embeddings is not None and len( - multimodal_embeddings) > 0: + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) # This is to satisfy the type checker for each overload @@ -1470,7 +1536,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ): - forward_kwargs = { "input_ids": input_ids, "positions": positions, @@ -1479,20 +1544,17 @@ def forward( } if self.visual_token_mask is not None: - if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]: - padding_len = inputs_embeds.shape[ - 0] - self.visual_token_mask.shape[0] + padding_len = inputs_embeds.shape[0] - self.visual_token_mask.shape[0] # right pad False pad = torch.zeros( (padding_len, self.visual_token_mask.shape[1]), dtype=self.visual_token_mask.dtype, - device=self.visual_token_mask.device) - self.visual_token_mask = torch.cat( - [self.visual_token_mask, pad], dim=0) + device=self.visual_token_mask.device, + ) + self.visual_token_mask = torch.cat([self.visual_token_mask, pad], dim=0) - forward_kwargs.update( - {"visual_token_mask": self.visual_token_mask}) + forward_kwargs.update({"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None hidden_states = self.language_model.model( @@ -1502,8 +1564,6 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 21772f766b40..51f49b8587e6 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Erine VL model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -31,31 +32,43 @@ from transformers import PretrainedConfig from vllm.attention import Attention + # from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( - Ernie4_5_VLRotaryEmbedding) + Ernie4_5_VLRotaryEmbedding, +) from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .ernie45_moe import Ernie4_5_MoeMLP from .interfaces import SupportsPP -from .utils import (PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -65,7 +78,6 @@ class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP): class Ernie4_5_VLMoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -109,19 +121,23 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = QKVParallelLinear(hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) t_rope = freq_allocation h_rope = (self.head_dim // 2 - freq_allocation) // 2 @@ -134,22 +150,24 @@ def __init__( base=rope_theta, is_neox_style=False, dtype=torch.get_default_dtype(), - mrope_section=[h_rope, w_rope, t_rope]) + mrope_section=[h_rope, w_rope, t_rope], + ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + 
prefix=f"{prefix}.attn", + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -163,7 +181,6 @@ def forward( class Ernie4_5_VLMoeMoE(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -175,8 +192,7 @@ def __init__( layer_idx = extract_layer_index(prefix) self.layer_idx = layer_idx self.tp_size = get_tensor_model_parallel_world_size() - self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0) - > 0) + self.has_shared_experts = getattr(config, "moe_num_shared_experts", 0) > 0 self.hidden_size = config.hidden_size moe_num_experts = config.moe_num_experts @@ -185,33 +201,40 @@ def __init__( if self.tp_size > max_moe_num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {moe_num_experts}.") + f"the number of experts {moe_num_experts}." + ) moe_layer_start_index = config.moe_layer_start_index text_moe_layer_start_index = moe_layer_start_index[0] vision_moe_layer_start_index = moe_layer_start_index[1] moe_layer_end_index = config.moe_layer_end_index moe_layer_end_index = getattr( - config, "moe_layer_end_index", - [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) + config, + "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1], + ) text_moe_layer_end_index = moe_layer_end_index[0] vision_moe_layer_end_index = moe_layer_end_index[1] assert config.moe_num_experts[0] == config.moe_num_experts[1] self.e_score_correction_bias = nn.Parameter( - torch.empty(2, config.moe_num_experts[0], dtype=torch.float32)) + torch.empty(2, config.moe_num_experts[0], dtype=torch.float32) + ) assert text_moe_layer_start_index <= text_moe_layer_end_index - if layer_idx >= text_moe_layer_start_index and \ - layer_idx <= text_moe_layer_end_index: + if ( + layer_idx >= text_moe_layer_start_index + and layer_idx <= text_moe_layer_end_index + ): self.text_experts_gate = ReplicatedLinear( config.hidden_size, config.moe_num_experts[0], bias=False, params_dtype=torch.float32, quant_config=quant_config, - prefix=f"{prefix}.text_experts_gate") + prefix=f"{prefix}.text_experts_gate", + ) self.text_experts = FusedMoE( num_experts=config.moe_num_experts[0], @@ -222,26 +245,31 @@ def __init__( renormalize=True, quant_config=quant_config, e_score_correction_bias=self.e_score_correction_bias[0], - prefix=f"{prefix}.text_experts") + prefix=f"{prefix}.text_experts", + ) else: self.text_experts = Ernie4_5_VLMoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - use_bias=getattr(config, 'use_bias', False), + use_bias=getattr(config, "use_bias", False), quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + ) assert vision_moe_layer_start_index <= vision_moe_layer_end_index - if layer_idx >= vision_moe_layer_start_index and \ - layer_idx <= vision_moe_layer_end_index: + if ( + layer_idx >= vision_moe_layer_start_index + and layer_idx <= vision_moe_layer_end_index + ): self.vision_experts_gate = ReplicatedLinear( config.hidden_size, config.moe_num_experts[1], bias=False, params_dtype=torch.float32, quant_config=quant_config, - prefix=f"{prefix}.vision_experts_gate") + prefix=f"{prefix}.vision_experts_gate", + ) self.vision_experts = FusedMoE( num_experts=config.moe_num_experts[1], @@ -252,27 +280,30 @@ def __init__( renormalize=True, quant_config=quant_config, 
e_score_correction_bias=self.e_score_correction_bias[1], - prefix=f"{prefix}.vision_experts") + prefix=f"{prefix}.vision_experts", + ) else: self.vision_experts = Ernie4_5_VLMoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - use_bias=getattr(config, 'use_bias', False), + use_bias=getattr(config, "use_bias", False), quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + ) if self.has_shared_experts: - intermediate_size = (config.moe_intermediate_size[0] * - config.moe_num_shared_experts) + intermediate_size = ( + config.moe_intermediate_size[0] * config.moe_num_shared_experts + ) self.shared_experts = Ernie4_5_VLMoeMLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.shared_experts", - reduce_results=self.text_experts. - must_reduce_shared_expert_outputs()) + reduce_results=self.text_experts.must_reduce_shared_expert_outputs(), + ) def forward( self, @@ -280,7 +311,6 @@ def forward( visual_token_mask: torch.Tensor, **kwargs: object, ) -> torch.Tensor: - orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) @@ -291,54 +321,61 @@ def forward( if visual_token_mask is not None and visual_token_mask.all(): # only vision modal input router_logits, _ = self.vision_experts_gate( - hidden_states.to(dtype=torch.float32)) + hidden_states.to(dtype=torch.float32) + ) final_hidden_states = self.vision_experts( - hidden_states=hidden_states, router_logits=router_logits) + hidden_states=hidden_states, router_logits=router_logits + ) elif visual_token_mask is not None and visual_token_mask.any(): # text and vision modals input - visual_token_mask = visual_token_mask.repeat( - 1, self.hidden_size).bool() + visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool() text_token_mask = ~visual_token_mask final_hidden_states = torch.zeros_like(hidden_states) text_hidden_states = hidden_states[text_token_mask].reshape( - -1, self.hidden_size) + -1, self.hidden_size + ) vision_hidden_states = hidden_states[visual_token_mask].reshape( - -1, self.hidden_size) + -1, self.hidden_size + ) text_router_logits, _ = self.text_experts_gate( - text_hidden_states.to(dtype=torch.float32)) + text_hidden_states.to(dtype=torch.float32) + ) final_hidden_states[text_token_mask] = self.text_experts( - hidden_states=text_hidden_states, - router_logits=text_router_logits).flatten() + hidden_states=text_hidden_states, router_logits=text_router_logits + ).flatten() vision_router_logits, _ = self.vision_experts_gate( - vision_hidden_states.to(dtype=torch.float32)) + vision_hidden_states.to(dtype=torch.float32) + ) final_hidden_states[visual_token_mask] = self.vision_experts( - hidden_states=vision_hidden_states, - router_logits=vision_router_logits).flatten() + hidden_states=vision_hidden_states, router_logits=vision_router_logits + ).flatten() else: # only text modal input text_router_logits, _ = self.text_experts_gate( - hidden_states.to(dtype=torch.float32)) + hidden_states.to(dtype=torch.float32) + ) final_hidden_states = self.text_experts( - hidden_states=hidden_states, router_logits=text_router_logits) + hidden_states=hidden_states, router_logits=text_router_logits + ) - if self.has_shared_experts and \ - shared_output is not None: + if self.has_shared_experts and shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: 
final_hidden_states = ( self.text_experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states)) + final_hidden_states + ) + ) return final_hidden_states.view(orig_shape) class Ernie4_5_VLMoeDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -351,20 +388,19 @@ def __init__( rope_theta = getattr(config, "rope_theta", 500000) rope_scaling = getattr(config, "rope_scaling", None) freq_allocation = getattr(config, "freq_allocation", 20) - max_position_embeddings = getattr(config, "max_position_embeddings", - 131072) + max_position_embeddings = getattr(config, "max_position_embeddings", 131072) self.self_attn = Ernie4_5_VLMoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, - head_dim=getattr(config, 'head_dim', None), + head_dim=getattr(config, "head_dim", None), rope_theta=rope_theta, rope_scaling=rope_scaling, freq_allocation=freq_allocation, max_position_embeddings=max_position_embeddings, rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'use_bias', False), + qkv_bias=getattr(config, "use_bias", False), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -377,8 +413,10 @@ def __init__( moe_layer_start_index = config.moe_layer_start_index min_moe_layer_start_index = min(moe_layer_start_index) moe_layer_end_index = getattr( - config, "moe_layer_end_index", - [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) + config, + "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1], + ) max_moe_layer_end_index = max(moe_layer_end_index) assert min_moe_layer_start_index <= max_moe_layer_end_index moe_num_experts = config.moe_num_experts @@ -386,25 +424,29 @@ def __init__( moe_layer_interval = getattr(config, "moe_layer_interval", 1) use_moe = getattr(config, "use_moe", max_moe_num_experts > 0) - if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0) - and layer_idx >= min_moe_layer_start_index - and layer_idx <= max_moe_layer_end_index): - self.mlp = Ernie4_5_VLMoeMoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if ( + use_moe + and ((layer_idx + 1) % moe_layer_interval == 0) + and layer_idx >= min_moe_layer_start_index + and layer_idx <= max_moe_layer_end_index + ): + self.mlp = Ernie4_5_VLMoeMoE( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) else: self.mlp = Ernie4_5_VLMoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - use_bias=getattr(config, 'use_bias', False), + use_bias=getattr(config, "use_bias", False), quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -414,14 +456,12 @@ def forward( visual_token_mask: Optional[torch.Tensor], **kwargs: object, ) -> torch.Tensor: - # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, @@ -429,12 
+469,10 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) if isinstance(self.mlp, Ernie4_5_VLMoeMoE): - hidden_states = self.mlp(hidden_states, visual_token_mask, - **kwargs) + hidden_states = self.mlp(hidden_states, visual_token_mask, **kwargs) else: hidden_states = self.mlp(hidden_states) @@ -452,7 +490,6 @@ def forward( # "visual_token_mask": 0, # }) class Ernie4_5_VLMoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -471,7 +508,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) else: self.embed_tokens = PPMissingLayer() @@ -481,7 +519,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config=config, cache_config=cache_config, quant_config=quant_config, - prefix=prefix), + prefix=prefix, + ), prefix=f"{prefix}.layers", ) @@ -490,9 +529,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -506,7 +545,6 @@ def forward( visual_token_mask: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -519,14 +557,14 @@ def forward( residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual, - visual_token_mask, **kwargs) + hidden_states, residual = layer( + positions, hidden_states, residual, visual_token_mask, **kwargs + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) @@ -555,15 +593,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Ernie4_5_VLMoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Ernie4_5_VLMoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() @@ -571,7 +611,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + 
self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -584,8 +625,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, **kwargs) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) return hidden_states def compute_logits( @@ -595,8 +637,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -612,32 +653,31 @@ def load_weights(self, weights: Iterable[tuple[str, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=max(self.config.moe_num_experts)) + num_experts=max(self.config.moe_num_experts), + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.config.tie_word_embeddings and name.endswith( - "lm_head.weight"): + if self.config.tie_word_embeddings and name.endswith("lm_head.weight"): loaded_params.add("lm_head.weight") continue # MTP will be supported soon. - if "mtp" in name or \ - "vision_model" in name or \ - "resampler_model" in name: + if "mtp" in name or "vision_model" in name or "resampler_model" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -652,14 +692,13 @@ def load_weights(self, weights: Iterable[tuple[str, if "mlp.experts" in name: moe_offset = int(name.split(".")[-3]) vision_expert_start_idx = self.config.moe_num_experts[0] - is_text_expert = \ - moe_offset <= vision_expert_start_idx - 1 + is_text_expert = moe_offset <= vision_expert_start_idx - 1 if is_text_expert: name = name.replace(".experts.", ".text_experts.") else: name = name.replace( f".experts.{moe_offset}", - f".vision_experts.{moe_offset-vision_expert_start_idx}" + f".vision_experts.{moe_offset - vision_expert_start_idx}", ) for mapping in expert_params_mapping: @@ -670,8 +709,7 @@ def load_weights(self, weights: Iterable[tuple[str, # Distinguish between vision experts and text experts moe_offset = int(name.split(".")[-3]) - is_text_expert = \ - moe_offset <= self.config.moe_num_experts[0] - 1 + is_text_expert = moe_offset <= self.config.moe_num_experts[0] - 1 name = name.replace(weight_name, param_name) if is_text_expert: @@ -684,36 +722,40 @@ def load_weights(self, weights: Iterable[tuple[str, continue # Skip loading extra bias for GPTQ models. 
- if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Distinguish between vision expert gate # and text expert gate if name.endswith("mlp.gate.weight"): - name = name.replace("gate.weight", - "text_experts_gate.weight") + name = name.replace("gate.weight", "text_experts_gate.weight") loaded_weight = loaded_weight.T elif name.endswith("mlp.gate.weight_1"): - name = name.replace("gate.weight_1", - "vision_experts_gate.weight") + name = name.replace( + "gate.weight_1", "vision_experts_gate.weight" + ) loaded_weight = loaded_weight.T if "e_score_correction_bias" in name: name = name.replace(".moe_statics.", ".") # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -725,8 +767,9 @@ def load_weights(self, weights: Iterable[tuple[str, param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index 2e6ef2d476a6..46a7131f2499 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Ernie-MTP model.""" + from collections.abc import Iterable from typing import Optional @@ -33,7 +34,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors @@ -43,7 +46,6 @@ class ErnieMultiTokenPredictorLayer(nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -52,13 +54,11 @@ def __init__( super().__init__() config = vllm_config.model_config.hf_config - self.mtp_emb_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.mtp_hidden_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.mtp_linear_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) + self.mtp_emb_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mtp_hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mtp_linear_proj = nn.Linear( + config.hidden_size * 2, config.hidden_size, bias=False + ) self.mtp_block = LlamaDecoderLayer(vllm_config, prefix) def forward( @@ -76,18 +76,18 @@ def forward( previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states) hidden_states = self.mtp_linear_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) hidden_states = residual + hidden_states return hidden_states class ErnieMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -95,15 +95,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - ErnieMultiTokenPredictorLayer( - vllm_config, - f"{prefix}.layers.{idx}", - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.layers = torch.nn.ModuleDict( + { + str(idx): ErnieMultiTokenPredictorLayer( + vllm_config, + f"{prefix}.layers.{idx}", + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -142,17 +145,18 @@ def compute_logits( class ErnieMTP(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - self.model = ErnieMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head")) + self.model = ErnieMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = 
self.model.embed_tokens.weight @@ -170,8 +174,9 @@ def forward( spec_step_idx: int = 0, ) -> torch.Tensor: assert spec_step_idx == 0, "ernie_mtp only support predict one token" - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( @@ -179,11 +184,9 @@ def compute_logits( hidden_states: torch.Tensor, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.model.compute_logits(hidden_states, self.lm_head, - spec_step_idx) + return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -195,16 +198,14 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - - if self.config.tie_word_embeddings and name.endswith( - "lm_head.weight"): + if self.config.tie_word_embeddings and name.endswith("lm_head.weight"): continue if "rotary_emb.inv_freq" in name: continue if "mtp" in name: name = self._rewrite_spec_layer_name(self.config, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -216,12 +217,13 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -233,8 +235,9 @@ def load_weights(self, weights: Iterable[tuple[str, break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -242,33 +245,36 @@ def load_weights(self, weights: Iterable[tuple[str, # According to DeepSeek-V3 Technical Report, MTP modules # shares embedding layer. We only load the first weights. 
- if "mtp_" not in name and ("embed_tokens" not in name - and "lm_head" not in name): + if "mtp_" not in name and ( + "embed_tokens" not in name and "lm_head" not in name + ): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - def _rewrite_spec_layer_name(self, config: PretrainedConfig, - name: str) -> str: + def _rewrite_spec_layer_name(self, config: PretrainedConfig, name: str) -> str: """ Rewrite the weight name to match the format of the original model. """ spec_layer_weight_names = [ - "embed_tokens", "mtp_emb_norm", "mtp_hidden_norm", - "mtp_linear_proj" + "embed_tokens", + "mtp_emb_norm", + "mtp_hidden_norm", + "mtp_linear_proj", ] layer_idx = config.num_hidden_layers for weight_name in spec_layer_weight_names: if weight_name in name: name = name.replace( f"model.{weight_name}.0.", - f"model.layers.{layer_idx}.{weight_name}.") + f"model.layers.{layer_idx}.{weight_name}.", + ) return name - name = name.replace("model.mtp_block.0.", - f"model.layers.{layer_idx}.mtp_block.") + name = name.replace( + "model.mtp_block.0.", f"model.layers.{layer_idx}.mtp_block." + ) return name diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 5dafcd595e4a..1f0b5723721c 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -39,26 +39,37 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class ExaoneGatedMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -84,8 +95,9 @@ def __init__( prefix=f"{prefix}.c_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -96,7 +108,6 @@ def forward(self, x): class ExaoneAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -191,7 +202,6 @@ def forward( class ExaoneBlockAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -233,7 +243,6 @@ def forward( class ExaoneDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -246,21 +255,24 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.attn = ExaoneBlockAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -305,7 +317,6 @@ def forward( @support_torch_compile class ExaoneModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -316,12 +327,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.wte = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.wte = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -341,14 +356,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.h", ) if get_pp_group().is_last_rank: - self.ln_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) + self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) else: self.ln_f = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -379,16 +393,14 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states - def 
load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -402,19 +414,19 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -447,8 +459,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -499,7 +510,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -507,14 +519,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.transformer.wte.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -526,8 +539,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + model_output = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( @@ -537,14 +551,12 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # With tie_word_embeddings, we can skip 
lm_head.weight # The weight might appear unnecessarily in the files if the model is # processed with quantization, LoRA, fine-tuning, etc. - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index c78eedff6670..230a2c80104b 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -35,27 +35,38 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Exaone4GatedMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -81,8 +92,9 @@ def __init__( prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -93,7 +105,6 @@ def forward(self, x): class Exaone4Attention(nn.Module): - def __init__( self, config: Exaone4Config, @@ -208,7 +219,6 @@ def forward( class Exaone4DecoderLayer(nn.Module): - def __init__( self, config: Exaone4Config, @@ -221,22 +231,25 @@ def __init__( rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.self_attn = Exaone4Attention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -253,10 +266,12 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_feedforward_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -290,7 +305,6 @@ def forward( @support_torch_compile class Exaone4Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -301,11 +315,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -329,9 +347,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -362,16 +380,14 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return 
IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -385,19 +401,19 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -430,8 +446,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -482,7 +497,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -490,14 +506,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -509,8 +526,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( @@ -520,14 +538,12 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: 
Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # With tie_word_embeddings, we can skip lm_head.weight # The weight might appear unnecessarily in the files if the model is # processed with quantization, LoRA, fine-tuning, etc. - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/fairseq2_llama.py b/vllm/model_executor/models/fairseq2_llama.py index d78ee100b26d..ca0e7e64df53 100644 --- a/vllm/model_executor/models/fairseq2_llama.py +++ b/vllm/model_executor/models/fairseq2_llama.py @@ -23,8 +23,10 @@ from torch.nn import Parameter from vllm.config import VllmConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.linear import set_weight_attrs from vllm.model_executor.models.llama import LlamaForCausalLM @@ -32,7 +34,6 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) self.tp_rank = get_tensor_model_parallel_rank() @@ -45,14 +46,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): f"model.{self.tp_rank}.pt", ] - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # fairseq2's serialization adds a wrapper to usual .pt state_dict's: # { "model_key": my_model_name, "my_model_name": state_dict } # which we first need to unpack weights_wrapped = dict(weights) - weights = weights_wrapped[ - weights_wrapped["model_key"]].items() # type: ignore + weights = weights_wrapped[weights_wrapped["model_key"]].items() # type: ignore # remap keys fs2_to_vllm_mapper = WeightsMapper( @@ -77,12 +76,14 @@ def load_weights(self, weights: Iterable[tuple[str, loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights( - (self.reshape_fairseq2_weights(name, loaded_weight, params) - for name, loaded_weight in weights)) + ( + self.reshape_fairseq2_weights(name, loaded_weight, params) + for name, loaded_weight in weights + ) + ) def flag_sharded_weights(self, params: dict[str, Parameter]): """Sets the `is_sharded_weight` flag to True for all sharded weights""" @@ -113,35 +114,34 @@ def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor: attn_in //= self.tp_size n_heads //= self.tp_size attn_out = self.config.hidden_size - return (w.view(n_heads, attn_in // n_heads // 2, 2, - attn_out).transpose(1, - 2).reshape(attn_in, attn_out)) + return ( + w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + .transpose(1, 2) + .reshape(attn_in, attn_out) + ) modules = name.split(".") # rotary embeds should be sliced if "k_proj" in modules: - loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads) + loaded_weight = permute(loaded_weight, self.config.num_key_value_heads) elif "q_proj" in modules: - loaded_weight = permute(loaded_weight, - self.config.num_attention_heads) + loaded_weight = permute(loaded_weight, self.config.num_attention_heads) # We make the loaded weights compatible with both 
# full checkpoints and tp sharded checkpoints. # Embeddings are repeated to fit the vocab size. - # Other weights are flagged for the weight_loader calls. + # Other weights are flagged for the weight_loader calls. if any(emb in modules for emb in ["embed_tokens", "lm_head"]): # Embeddings are sharded on dim 0 dim = 0 # In fairseq2, vocab size has to be divisible by tp_size # so we don't worry about padding - if self.tp_size > 1 and loaded_weight.shape[ - dim] < self.config.vocab_size: - assert loaded_weight.shape[ - dim] * self.tp_size == self.config.vocab_size, \ - "vocab_size should be divisible by tp_size." + if self.tp_size > 1 and loaded_weight.shape[dim] < self.config.vocab_size: + assert ( + loaded_weight.shape[dim] * self.tp_size == self.config.vocab_size + ), "vocab_size should be divisible by tp_size." repeats = [1] * len(loaded_weight.size()) repeats[dim] = self.tp_size # repeat to match vocab size and to be easily 'narrow'able diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 0c50056d1c52..211a9120789e 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -33,55 +33,65 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import RWConfig from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) FalconConfig = Union[HF_FalconConfig, RWConfig] def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: - closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) - base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), - dtype=torch.float32) + closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32 + ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != total_num_heads: extra_base = torch.tensor( - 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), - dtype=torch.float32) - num_remaining_heads = min(closest_power_of_2, - total_num_heads - 
closest_power_of_2) - extra_powers = torch.arange(1, - 1 + 2 * num_remaining_heads, - 2, - dtype=torch.int32) - slopes = torch.cat( - [slopes, torch.pow(extra_base, extra_powers)], dim=0) + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32 + ) + num_remaining_heads = min( + closest_power_of_2, total_num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32 + ) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) return slopes class FalconAttention(nn.Module): - def __init__( self, config: FalconConfig, @@ -133,59 +143,68 @@ def __init__( # Layer-wise attention scaling self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) - self.reduce_row_parallel_results = not (config.new_decoder_architecture - or config.parallel_attn) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, bias=config.bias, skip_bias_add=True, quant_config=quant_config, - reduce_results=self.reduce_row_parallel_results) + reduce_results=self.reduce_row_parallel_results, + ) self.use_rotary = config.rotary self.use_alibi = config.alibi assert not (self.use_rotary and self.use_alibi), ( - "Rotary and alibi are mutually exclusive.") + "Rotary and alibi are mutually exclusive." + ) if self.use_rotary: rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads head_end = (tp_rank + 1) * self.num_heads - alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * - self.inv_norm_factor) + alibi_slopes = ( + _get_alibi_slopes(self.total_num_heads) * self.inv_norm_factor + ) alibi_slopes = alibi_slopes[head_start:head_end].tolist() - self.attn = Attention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) else: - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -204,7 +223,6 @@ def forward( class FalconMLP(nn.Module): - def __init__( self, config: FalconConfig, @@ -213,21 +231,25 @@ def __init__( super().__init__() hidden_size = config.hidden_size - self.dense_h_to_4h = 
ColumnParallelLinear(hidden_size, - 4 * hidden_size, - bias=config.bias, - skip_bias_add=True, - quant_config=quant_config) + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config, + ) self.act = get_act_fn("gelu") - self.reduce_row_parallel_results = not (config.new_decoder_architecture - or config.parallel_attn) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, bias=config.bias, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results, - quant_config=quant_config) + quant_config=quant_config, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. @@ -240,7 +262,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FalconDecoderLayer(nn.Module): - def __init__( self, config: FalconConfig, @@ -252,39 +273,36 @@ def __init__( hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.self_attention = FalconAttention( - config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attention") + config, cache_config, quant_config, prefix=f"{prefix}.self_attention" + ) self.mlp = FalconMLP(config, quant_config) self.config = config - if (not hasattr(config, "num_ln_in_parallel_attn")): + if not hasattr(config, "num_ln_in_parallel_attn"): config.num_ln_in_parallel_attn = None - if (config.num_ln_in_parallel_attn is None - and config.new_decoder_architecture): + if config.num_ln_in_parallel_attn is None and config.new_decoder_architecture: config.num_ln_in_parallel_attn = 2 if not config.parallel_attn: self.post_attention_layernorm = LayerNorm( - hidden_size, eps=config.layer_norm_epsilon) - self.input_layernorm = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + hidden_size, eps=config.layer_norm_epsilon + ) + self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) else: if config.num_ln_in_parallel_attn == 2: # The layer norm before self-attention - self.ln_attn = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) # The layer norm before the MLP - self.ln_mlp = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) else: - self.input_layernorm = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) + self.input_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon + ) - self.reduce_row_parallel_results = not (config.new_decoder_architecture - or config.parallel_attn) + self.reduce_row_parallel_results = not ( + config.new_decoder_architecture or config.parallel_attn + ) def forward( self, @@ -314,8 +332,11 @@ def forward( residual += attention_output mlp_layernorm_out = self.post_attention_layernorm(residual) - if (self.config.new_decoder_architecture and self.config.parallel_attn - and self.config.num_ln_in_parallel_attn == 1): + if ( + self.config.new_decoder_architecture + and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1 + ): mlp_layernorm_out = attention_layernorm_out # MLP. 
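The `_get_alibi_slopes` reformat above follows the standard ALiBi recipe: for a power-of-two head count the slopes form a geometric sequence, and for other head counts the remainder is filled with every other slope from the sequence for twice the closest power of two. A minimal pure-Python sketch of that formula, as an illustration rather than the function from this patch:

import math

def alibi_slopes(total_num_heads: int) -> list[float]:
    closest = 2 ** math.floor(math.log2(total_num_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest) - 3)))
    slopes = [base ** p for p in range(1, closest + 1)]
    if closest != total_num_heads:
        # Interpolate the remaining slopes from the sequence for 2 * closest heads.
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest) - 3)))
        num_extra = min(closest, total_num_heads - closest)
        slopes += [extra_base ** p for p in range(1, 2 * num_extra + 1, 2)]
    return slopes

# 8 heads -> [1/2, 1/4, ..., 1/256]; 12 heads appends 4 interpolated slopes.
assert len(alibi_slopes(12)) == 12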
@@ -340,7 +361,6 @@ def forward( @support_torch_compile class FalconModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -363,14 +383,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, lambda prefix: FalconDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.h", + ) # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.word_embeddings(input_ids) @@ -396,8 +418,7 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: total_num_heads = self.config.num_attention_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads @@ -420,26 +441,34 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_weight_shape = loaded_weight.shape if output_dim is not None: loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + - (total_num_kv_heads, num_query_heads_per_kv_head + 2, - -1) + loaded_weight_shape[output_dim + 1:]) + loaded_weight_shape[:output_dim] + + (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) + + loaded_weight_shape[output_dim + 1 :] + ) wq = loaded_weight.narrow( - output_dim + 1, 0, - num_query_heads_per_kv_head).reshape( - *loaded_weight_shape[:output_dim], -1, - *loaded_weight_shape[output_dim + 1:]) + output_dim + 1, 0, num_query_heads_per_kv_head + ).reshape( + *loaded_weight_shape[:output_dim], + -1, + *loaded_weight_shape[output_dim + 1 :], + ) wk = loaded_weight.narrow( - output_dim + 1, num_query_heads_per_kv_head, - 1).reshape(*loaded_weight_shape[:output_dim], -1, - *loaded_weight_shape[output_dim + 1:]) + output_dim + 1, num_query_heads_per_kv_head, 1 + ).reshape( + *loaded_weight_shape[:output_dim], + -1, + *loaded_weight_shape[output_dim + 1 :], + ) wv = loaded_weight.narrow( - output_dim + 1, num_query_heads_per_kv_head + 1, - 1).reshape(*loaded_weight_shape[:output_dim], -1, - *loaded_weight_shape[output_dim + 1:]) + output_dim + 1, num_query_heads_per_kv_head + 1, 1 + ).reshape( + *loaded_weight_shape[:output_dim], + -1, + *loaded_weight_shape[output_dim + 1 :], + ) loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -456,15 +485,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = FalconModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = FalconModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) # only Falcon-11B doesn't share lm_head weight with 
word embeddings # and previous Falcon model doesn't have tie_word_embeddings config # so we set tie_word_embeddings to True by default - self.tie_word_embeddings = (config.tie_word_embeddings - if config.tie_word_embeddings is not None - else True) + self.tie_word_embeddings = ( + config.tie_word_embeddings + if config.tie_word_embeddings is not None + else True + ) if self.tie_word_embeddings: self.lm_head = self.transformer.word_embeddings else: @@ -476,7 +507,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -488,8 +520,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -499,11 +532,9 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index f382018e2222..8af08711038d 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only FalconH1 model.""" + from collections.abc import Iterable from typing import Optional @@ -15,28 +16,38 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP -from .utils 
import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class FalconH1MLP(nn.Module): - def __init__( self, config: FalconH1Config, @@ -60,13 +71,15 @@ def __init__( self.intermediate_size = config.intermediate_size self.gate_multiplier, self.down_multiplier = config.mlp_multipliers if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): x, _ = self.gate_up_proj(x) - x[:, :self.intermediate_size // self.tp_size] *= self.gate_multiplier + x[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier x = self.act_fn(x) x, _ = self.down_proj(x) x = x * self.down_multiplier @@ -74,7 +87,6 @@ def forward(self, x): class FalconH1SSMDecoderLayer(nn.Module): - def __init__( self, config: FalconH1Config, @@ -87,8 +99,11 @@ def __init__( self.config = config self.tp_size = get_tensor_model_parallel_world_size() - self.d_ssm = (int(config.mamba_expand * config.hidden_size) - if config.mamba_d_ssm is None else config.mamba_d_ssm) + self.d_ssm = ( + int(config.mamba_expand * config.hidden_size) + if config.mamba_d_ssm is None + else config.mamba_d_ssm + ) self.mamba = MambaMixer2( hidden_size=config.hidden_size, @@ -115,15 +130,15 @@ def __init__( def _init_mup_vector(self): """ - Non learnable per-block scaling vector composed of element-wise - multipliersapplied to each separate contiguous block of the output + Non learnable per-block scaling vector composed of element-wise + multipliersapplied to each separate contiguous block of the output of the linear projection (in_proj) before further processing (gating, convolution, SSM): - Z block: [0 : d_ssm] → zxbcdt_multipliers[0] - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1] - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2] - - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S] + - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S] → zxbcdt_multipliers[3] - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4] @@ -133,38 +148,38 @@ def _init_mup_vector(self): - S: SSM state size per group - All indices are divided by tp_size to support tensor parallelism """ - vector_shape = (2 * self.d_ssm + 2 * self.groups_time_state_size + - self.config.mamba_n_heads) // self.tp_size + vector_shape = ( + 2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads + ) // self.tp_size mup_vector = torch.ones(1, vector_shape) # Z vector 0 -> d_ssm - mup_vector[:, :self.d_ssm // - self.tp_size] *= self.zxbcdt_multipliers[0] + mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0] # X vector d_ssm -> 2 * d_ssm - mup_vector[:, - (self.d_ssm // - self.tp_size):(2 * self.d_ssm // - self.tp_size)] *= self.zxbcdt_multipliers[1] + mup_vector[ + :, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size) + ] *= self.zxbcdt_multipliers[1] # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state) mup_vector[ :, - (2 * self.d_ssm) // - self.tp_size:(2 * self.d_ssm + self.groups_time_state_size) // - self.tp_size, + (2 * self.d_ssm) // self.tp_size : ( + 2 * self.d_ssm + self.groups_time_state_size + ) + // self.tp_size, ] *= self.zxbcdt_multipliers[2] # 
C vector 2 * d_ssm + (n_group * d_state) # -> 2 * d_ssm + 2 * (n_group * d_state) mup_vector[ :, - (2 * self.d_ssm + self.groups_time_state_size) // - self.tp_size:(2 * self.d_ssm + 2 * self.groups_time_state_size) // - self.tp_size, + (2 * self.d_ssm + self.groups_time_state_size) // self.tp_size : ( + 2 * self.d_ssm + 2 * self.groups_time_state_size + ) + // self.tp_size, ] *= self.zxbcdt_multipliers[3] # dt vector 2 * d_ssm + 2 * (n_group * d_state) # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads mup_vector[ :, - (2 * self.d_ssm + 2 * self.groups_time_state_size) // - self.tp_size:, + (2 * self.d_ssm + 2 * self.groups_time_state_size) // self.tp_size :, ] *= self.zxbcdt_multipliers[4] self.register_buffer("mup_vector", mup_vector, persistent=False) @@ -185,7 +200,6 @@ def forward( class FalconH1AttentionDecoderLayer(nn.Module): - def __init__( self, config: FalconH1Config, @@ -196,8 +210,7 @@ def __init__( super().__init__() rope_theta = getattr(config, "rope_theta", 1e11) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads @@ -213,8 +226,11 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = (config.hidden_size // self.total_num_heads if getattr( - config, "head_dim", None) is None else config.head_dim) + self.head_dim = ( + config.hidden_size // self.total_num_heads + if getattr(config, "head_dim", None) is None + else config.head_dim + ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -345,10 +361,8 @@ def __init__( self.feed_forward = FalconH1MLP(config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -380,7 +394,8 @@ def forward( # We assume both branches produce outputs of the same # dimensionality (config.hidden_size). 
hidden_states = (attn_hidden * self.attn_out_multiplier) + ( - ssm_hidden * self.ssm_out_multiplier) + ssm_hidden * self.ssm_out_multiplier + ) hidden_states = hidden_states + residual # feed-forward @@ -394,7 +409,6 @@ def forward( @support_torch_compile class FalconH1Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: FalconH1Config = vllm_config.model_config.hf_config @@ -404,12 +418,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -433,13 +449,13 @@ def get_layer(prefix: str): ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.final_layernorm = PPMissingLayer() @@ -453,13 +469,13 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds * self.embedding_multiplier else: - hidden_states = (self.get_input_embeddings(input_ids) * - self.embedding_multiplier) + hidden_states = ( + self.get_input_embeddings(input_ids) * self.embedding_multiplier + ) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -471,15 +487,16 @@ def forward( hidden_states=hidden_states, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.final_layernorm(hidden_states) return hidden_states -class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid): +class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -496,7 +513,6 @@ def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -521,10 +537,11 @@ def get_mamba_state_shape_from_config( parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config - intermediate_size = (int(hf_config.mamba_expand * - hf_config.hidden_size) - if hf_config.mamba_d_ssm is None else - hf_config.mamba_d_ssm) + intermediate_size = ( + 
int(hf_config.mamba_expand * hf_config.hidden_size) + if hf_config.mamba_d_ssm is None + else hf_config.mamba_d_ssm + ) return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, @@ -540,19 +557,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert (not cache_config.enable_prefix_caching - ), "FalconH1 currently does not support prefix caching" self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config - self.model = FalconH1Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = FalconH1Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.tie_word_embeddings = config.tie_word_embeddings self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -566,14 +581,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), prefix=maybe_prefix(prefix, "lm_head"), ) self.lm_head_multiplier = config.lm_head_multiplier if self.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) # Used to track and store by the Mamba cache between steps. self.logits_processor = LogitsProcessor( @@ -585,7 +600,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -598,7 +614,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ): - hidden_states = self.model( input_ids, positions, @@ -616,8 +631,7 @@ def compute_logits( return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -664,8 +678,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index a0f8d0659c59..83572563c15e 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -16,28 +16,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
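Several of the `load_weights` hunks above walk a `stacked_params_mapping` of `(fused_param, checkpoint_shard, shard_id)` tuples so that separate q/k/v and gate/up tensors from the checkpoint are routed into the fused projections. A minimal sketch of just the key-renaming step; the mapping entries and weight names below are illustrative, not copied from any single model in this patch.

STACKED_PARAMS_MAPPING = [
    # (fused param name, checkpoint shard name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap_checkpoint_name(name: str) -> tuple[str, object]:
    """Map a per-shard checkpoint key to (fused key, shard_id); pass through otherwise."""
    for fused, shard, shard_id in STACKED_PARAMS_MAPPING:
        if shard in name:
            return name.replace(shard, fused), shard_id
    return name, None  # not a stacked parameter

assert remap_checkpoint_name("model.layers.0.self_attn.k_proj.weight") == (
    "model.layers.0.self_attn.qkv_proj.weight",
    "k",
)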
-""" PyTorch Fuyu model.""" +"""PyTorch Fuyu model.""" + import math from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Literal, Optional import torch import torch.nn as nn -from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, - FuyuProcessor) +from transformers import BatchFeature, FuyuConfig, FuyuImageProcessor, FuyuProcessor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -72,7 +78,6 @@ class FuyuImagePatchInputs(TensorSchema): class FuyuProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(FuyuConfig) @@ -124,12 +129,12 @@ def get_num_image_tokens( def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() - return ImageSize(width=image_processor.size["width"], - height=image_processor.size["height"]) + return ImageSize( + width=image_processor.size["width"], height=image_processor.size["height"] + ) class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -139,23 +144,22 @@ def get_dummy_mm_data( mm_counts: Mapping[str, int], mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): - def _call_hf_processor( self, prompt: str, @@ -179,7 +183,8 @@ def _call_hf_processor( image_patches = processed_outputs["image_patches"] processed_outputs["image_patches"] = flatten_bn(image_patches) processed_outputs["patches_per_image"] = torch.tensor( - [len(p) for p in image_patches]) + [len(p) for p in image_patches] + ) return processed_outputs @@ -206,7 +211,8 @@ def _get_mm_fields_config( return dict( image_patches=MultiModalFieldConfig.flat_from_sizes( - "image", patches_per_image), + "image", patches_per_image + ), 
patches_per_image=MultiModalFieldConfig.batched("image"), ) @@ -232,8 +238,7 @@ def get_replacement_fuyu(item_idx: int): image_width=image_size.width, image_height=image_size.height, ) - image_tokens = ([_IMAGE_TOKEN_ID] * ncols + - [_NEWLINE_TOKEN_ID]) * nrows + image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows return PromptUpdateDetails.select_token_id( image_tokens + [bos_token_id], @@ -249,9 +254,11 @@ def get_replacement_fuyu(item_idx: int): ] -@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor, - info=FuyuProcessingInfo, - dummy_inputs=FuyuDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + FuyuMultiModalProcessor, + info=FuyuProcessingInfo, + dummy_inputs=FuyuDummyInputsBuilder, +) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): merge_by_field_config = True @@ -260,7 +267,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): "model.vision_embed_tokens.": "vision_embed_tokens.", "model.language_model.": "language_model.model.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -292,10 +300,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "language_model"), ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[FuyuImagePatchInputs]: + self, **kwargs: object + ) -> Optional[FuyuImagePatchInputs]: image_patches = kwargs.pop("image_patches", None) patches_per_image = kwargs.pop("patches_per_image", None) @@ -310,21 +320,20 @@ def _parse_and_validate_image_input( ) def _process_image_input( - self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings: + self, image_input: FuyuImagePatchInputs + ) -> MultiModalEmbeddings: image_patches_flat = image_input["image_patches_flat"] patches_per_image = image_input["patches_per_image"] assert self.vision_embed_tokens is not None - vision_embeddings_flat, _ = self.vision_embed_tokens( - image_patches_flat) + vision_embeddings_flat, _ = self.vision_embed_tokens(image_patches_flat) return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -355,10 +364,10 @@ def compute_logits( hidden_states: torch.Tensor, ) -> Optional[torch.Tensor]: logits = self.language_model.logits_processor( - self.language_model.lm_head, hidden_states) + self.language_model.lm_head, hidden_states + ) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index c19425b6cb6d..b152f52223cf 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Gemma model compatible with HuggingFace weights.""" + from collections.abc import Iterable from functools import cache from itertools import islice @@ -32,21 +33,26 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -66,19 +72,22 @@ def _get_gemma_act_fn( "`%s`, edit the config JSON to set " "`hidden_activation=%s` instead of `hidden_act`. " "See https://github.com/huggingface/transformers/pull/29402 " - "for more details.", hidden_act, hidden_act) + "for more details.", + hidden_act, + hidden_act, + ) return GeluAndMul(approximate="tanh") elif hidden_activation == "gelu_pytorch_tanh": return GeluAndMul(approximate="tanh") elif hidden_activation == "gelu": return GeluAndMul(approximate="none") else: - raise ValueError(f"Activation function {hidden_act} is not " - "supported for Gemma models.") + raise ValueError( + f"Activation function {hidden_act} is not supported for Gemma models." 
+ ) class GemmaMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -113,7 +122,6 @@ def forward(self, x): class GemmaAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -172,13 +180,15 @@ def __init__( base=self.rope_theta, is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -194,7 +204,6 @@ def forward( class GemmaDecoderLayer(nn.Module): - def __init__( self, config: GemmaConfig, @@ -223,10 +232,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -239,23 +248,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class GemmaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -272,8 +278,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: GemmaDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -281,12 +289,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # data type such as bfloat16, not float32. 
# See https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 - self.register_buffer("normalizer", - torch.tensor(normalizer), - persistent=False) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.register_buffer("normalizer", torch.tensor(normalizer), persistent=False) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -315,15 +321,13 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -335,7 +339,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue name = name.replace(shard_name, param_name) @@ -355,8 +359,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -388,11 +391,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.model = GemmaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = GemmaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -404,8 +409,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -415,11 +421,9 @@ def compute_logits( logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings 
else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 3f76e1e7d42a..2d26edcf6609 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -31,29 +31,35 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Gemma2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -64,18 +70,17 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"): raise ValueError( "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_act` and `hidden_activation` to " - "`gelu_pytorch_tanh`.") + "`gelu_pytorch_tanh`." 
+ ) self.act_fn = GeluAndMul(approximate="tanh") def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -86,19 +91,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Gemma2Attention(nn.Module): - - def __init__(self, - config: Gemma2Config, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - max_position_embeddings: int, - rope_theta: float, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - attn_logits_soft_cap: Optional[float] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Gemma2Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + rope_theta: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = hidden_size @@ -148,15 +154,17 @@ def __init__(self, is_sliding = config.layer_types[layer_idx] == "sliding_attention" sliding_window = config.sliding_window if is_sliding else None - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=attn_logits_soft_cap, - per_layer_sliding_window=sliding_window, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -172,7 +180,6 @@ def forward( class Gemma2DecoderLayer(nn.Module): - def __init__( self, config: Gemma2Config, @@ -203,14 +210,16 @@ def __init__( hidden_activation=config.hidden_activation, quant_config=quant_config, ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -222,8 +231,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -231,7 +239,8 @@ def forward( hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, residual = self.pre_feedforward_layernorm( - hidden_states, residual) + hidden_states, residual + ) hidden_states = self.mlp(hidden_states) hidden_states = self.post_feedforward_layernorm(hidden_states) return hidden_states, residual @@ -239,7 +248,6 @@ def forward( @support_torch_compile class Gemma2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, 
prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -255,8 +263,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Gemma2DecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -264,12 +274,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # data type such as bfloat16, not float32. # See https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 - self.register_buffer("normalizer", - torch.tensor(normalizer), - persistent=False) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.register_buffer("normalizer", torch.tensor(normalizer), persistent=False) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -299,15 +307,13 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -319,17 +325,17 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue name = name.replace(shard_name, param_name) @@ -353,8 +359,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -384,12 +389,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # currently all existing Gemma models have `tie_word_embeddings` enabled assert config.tie_word_embeddings self.quant_config = quant_config - self.model = 
Gemma2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Gemma2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.logits_processor = LogitsProcessor( - config.vocab_size, soft_cap=config.final_logit_softcapping) + config.vocab_size, soft_cap=config.final_logit_softcapping + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -401,8 +409,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -412,11 +421,9 @@ def compute_logits( logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 77c0ef8cb91d..9fa8e1c78b12 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -31,30 +31,36 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from ...attention.layers.encoder_only_attention import EncoderOnlyAttention from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Gemma3MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -82,7 +88,8 @@ def __init__( raise ValueError( "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_act` and `hidden_activation` to " - "`gelu_pytorch_tanh`.") + "`gelu_pytorch_tanh`." 
+ ) self.act_fn = GeluAndMul(approximate="tanh") def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -93,18 +100,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Gemma3Attention(nn.Module): - - def __init__(self, - config: Gemma3TextConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - max_position_embeddings: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - attn_logits_soft_cap: Optional[float] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Gemma3TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = hidden_size @@ -174,19 +182,24 @@ def __init__(self, else: attn_type = AttentionType.ENCODER_ONLY - attn_cls = (EncoderOnlyAttention - if attn_type == AttentionType.ENCODER_ONLY else Attention) + attn_cls = ( + EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY + else Attention + ) - self.attn = attn_cls(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - attn_type=attn_type, - logits_soft_cap=attn_logits_soft_cap, - per_layer_sliding_window=sliding_window, - prefix=f"{prefix}.attn") + self.attn = attn_cls( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -225,11 +238,7 @@ def forward( # output is discarded and overwritten below. While this duplicates # computation, it maintains compatibility. # TODO(woosuk): Optimize by implementing custom attention kernels. 
- attn_output = self.naive_attn_with_masks(q, - k, - v, - out=attn_output, - **kwargs) + attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs) output, _ = self.o_proj(attn_output) return output @@ -283,7 +292,6 @@ def naive_attn_with_masks( class Gemma3DecoderLayer(nn.Module): - def __init__( self, config: Gemma3TextConfig, @@ -313,14 +321,16 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -333,8 +343,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -343,7 +352,8 @@ def forward( hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, residual = self.pre_feedforward_layernorm( - hidden_states, residual) + hidden_states, residual + ) hidden_states = self.mlp(hidden_states) hidden_states = self.post_feedforward_layernorm(hidden_states) return hidden_states, residual @@ -351,7 +361,6 @@ def forward( @support_torch_compile class Gemma3Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -368,8 +377,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Gemma3DecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -377,12 +388,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # data type such as bfloat16, not float32. 
# See https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 - self.register_buffer("normalizer", - torch.tensor(normalizer), - persistent=False) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.register_buffer("normalizer", torch.tensor(normalizer), persistent=False) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: # NOTE(woosuk): Only apply the normalizer to the output of @@ -415,15 +424,13 @@ def forward( **kwargs, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -435,33 +442,33 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue # Check if this is a scale parameter that needs remapping first - if name.endswith( - (".k_scale", ".v_scale", ".q_scale", ".prob_scale")): + if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): # Try to remap the scale name first remapped_name = maybe_remap_kv_scale_name(name, params_dict) if remapped_name is not None and remapped_name in params_dict: # Successfully remapped, use the remapped name param = params_dict[remapped_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(remapped_name) continue # If remapping failed, continue with normal processing - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue name = name.replace(shard_name, param_name) @@ -485,8 +492,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -516,12 +522,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # currently all existing Gemma models have `tie_word_embeddings` enabled assert config.tie_word_embeddings self.quant_config = 
quant_config - self.model = Gemma3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Gemma3Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.logits_processor = LogitsProcessor( - config.vocab_size, soft_cap=config.final_logit_softcapping) + config.vocab_size, soft_cap=config.final_logit_softcapping + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -534,8 +543,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, **kwargs) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) return hidden_states def compute_logits( @@ -545,11 +555,9 @@ def compute_logits( logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index e1956b94cdc8..95b0b0dab5a1 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -16,29 +16,40 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -# yapf: disable -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalPromptUpdates, - MultiModalPromptUpdatesApplyResult, - PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate, - PromptUpdateDetails, - replace_token_matches) -# yapf: enable +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, + replace_token_matches, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) logger = 
init_logger(__name__) @@ -53,6 +64,7 @@ class Gemma3ImagePixelInputs(TensorSchema): - w: Width of each patch - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[torch.Tensor, TensorShape("p", 3, "h", "w")] @@ -64,7 +76,6 @@ class Gemma3ImagePixelInputs(TensorSchema): class Gemma3ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Gemma3Config) @@ -107,19 +118,21 @@ def get_num_crops( processor = self.get_hf_processor() images_kwargs = self._resolve_image_kwargs( - processor, { - "do_pan_and_scan", "pan_and_scan_min_crop_size", + processor, + { + "do_pan_and_scan", + "pan_and_scan_min_crop_size", "pan_and_scan_max_num_crops", - "pan_and_scan_min_ratio_to_activate" - }) + "pan_and_scan_min_ratio_to_activate", + }, + ) do_pan_and_scan = images_kwargs["do_pan_and_scan"] - pan_and_scan_min_crop_size = images_kwargs[ - "pan_and_scan_min_crop_size"] - pan_and_scan_max_num_crops = images_kwargs[ - "pan_and_scan_max_num_crops"] + pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"] + pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] pan_and_scan_min_ratio_to_activate = images_kwargs[ - "pan_and_scan_min_ratio_to_activate"] + "pan_and_scan_min_ratio_to_activate" + ] if not do_pan_and_scan: return 0 @@ -127,7 +140,8 @@ def get_num_crops( if envs.VLLM_USE_V1: logger.warning_once( "`do_pan_and_scan=True` has suboptimal results on V1 " - "because of the simplified attention pattern being used.") + "because of the simplified attention pattern being used." + ) # Based on Gemma3ImageProcessor.pan_and_scan if image_width >= image_height: @@ -187,10 +201,10 @@ def get_image_repl( crops_image_tokens = " ".join(boi_token for _ in range(num_crops)) image_text = ( f"Here is the original image {boi_token} and here are some " - f"crops to help you see better {crops_image_tokens}") + f"crops to help you see better {crops_image_tokens}" + ) - repl_full = image_text.replace(boi_token, - processor.full_image_sequence) + repl_full = image_text.replace(boi_token, processor.full_image_sequence) tokenizer = processor.tokenizer vocab = tokenizer.get_vocab() @@ -221,7 +235,8 @@ def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() images_kwargs = self._resolve_image_kwargs( - processor, {"pan_and_scan_max_num_crops"}) + processor, {"pan_and_scan_max_num_crops"} + ) max_num_crops = images_kwargs["pan_and_scan_max_num_crops"] # Result in the max possible feature size (h:w = max_num_crops:1) @@ -229,7 +244,6 @@ def get_image_size_with_most_features(self) -> ImageSize: class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -246,22 +260,21 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): - def 
_call_hf_processor( self, prompt: str, @@ -278,20 +291,22 @@ def _call_hf_processor( # HF processor pops the `num_crops` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": - images - }).get_items("image", ImageProcessorItems)) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) image_sizes = [ - parsed_images.get_image_size(i) - for i in range(len(parsed_images)) + parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] hf_processor = self.info.get_hf_processor(**mm_kwargs) num_crops = [ - self.info.get_num_crops(image_width=size.width, - image_height=size.height, - processor=hf_processor) + self.info.get_num_crops( + image_width=size.width, + image_height=size.height, + processor=hf_processor, + ) for size in image_sizes ] processed_outputs["num_patches"] = torch.tensor(num_crops) + 1 @@ -306,8 +321,7 @@ def _get_mm_fields_config( num_patches = hf_inputs.get("num_patches", torch.empty(0)) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), num_patches=MultiModalFieldConfig.batched("image"), ) @@ -343,8 +357,7 @@ def _apply_token_matches( prompt: list[int], mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: - token_ids, res = super()._apply_token_matches(prompt, - mm_prompt_updates) + token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -403,8 +416,7 @@ def get_repl_toks(tok: int) -> list[int]: repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = super()._find_mm_placeholders(repl_token_ids, - mm_prompt_updates) + repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates) return { modality: [ @@ -414,39 +426,43 @@ def get_repl_toks(tok: int) -> list[int]: start_idx=repl_orig_idxs[p.start_idx], tokens=p.tokens, is_embed=p.is_embed, - ) for p in placeholders + ) + for p in placeholders ] for modality, placeholders in repls.items() } class Gemma3MultiModalProjector(nn.Module): - def __init__(self, config: Gemma3Config): super().__init__() self.mm_input_projection_weight = nn.Parameter( - torch.zeros(config.vision_config.hidden_size, - config.text_config.hidden_size)) + torch.zeros( + config.vision_config.hidden_size, config.text_config.hidden_size + ) + ) self.mm_soft_emb_norm = GemmaRMSNorm( - config.vision_config.hidden_size, - eps=config.vision_config.layer_norm_eps) + config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps + ) - self.patches_per_image = int(config.vision_config.image_size // - config.vision_config.patch_size) + self.patches_per_image = int( + config.vision_config.image_size // config.vision_config.patch_size + ) self.tokens_per_side = int(config.mm_tokens_per_image**0.5) self.kernel_size = self.patches_per_image // self.tokens_per_side - self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, - stride=self.kernel_size) + self.avg_pool = nn.AvgPool2d( + kernel_size=self.kernel_size, stride=self.kernel_size + ) def forward(self, vision_outputs: torch.Tensor): batch_size, _, seq_length = vision_outputs.shape reshaped_vision_outputs = vision_outputs.transpose(1, 2) reshaped_vision_outputs = 
reshaped_vision_outputs.reshape( - batch_size, seq_length, self.patches_per_image, - self.patches_per_image) + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) reshaped_vision_outputs = reshaped_vision_outputs.contiguous() pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) @@ -456,15 +472,19 @@ def forward(self, vision_outputs: torch.Tensor): normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) projected_vision_outputs = torch.matmul( - normed_vision_outputs, self.mm_input_projection_weight) + normed_vision_outputs, self.mm_input_projection_weight + ) return projected_vision_outputs.type_as(vision_outputs) -@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor, - info=Gemma3ProcessingInfo, - dummy_inputs=Gemma3DummyInputsBuilder) -class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA): +@MULTIMODAL_REGISTRY.register_processor( + Gemma3MultiModalProcessor, + info=Gemma3ProcessingInfo, + dummy_inputs=Gemma3DummyInputsBuilder, +) +class Gemma3ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA +): merge_by_field_config = True packed_modules_mapping = { @@ -486,7 +506,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -504,10 +525,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config self.multimodal_config = multimodal_config - self.vision_tower = SiglipVisionModel(config.vision_config, - quant_config, - prefix=maybe_prefix( - prefix, "vision_tower")) + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = Gemma3MultiModalProjector(config) self.language_model = init_vllm_registered_model( @@ -524,14 +546,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.language_model.logits_processor.scale *= logit_scale self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) @property def dtype(self): return next(self.parameters()).dtype def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Gemma3ImageInputs]: + self, **kwargs: object + ) -> Optional[Gemma3ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -541,12 +565,11 @@ def _parse_and_validate_image_input( image_size = self.config.vision_config.image_size - return Gemma3ImagePixelInputs(pixel_values=pixel_values, - num_patches=num_patches, - resolve_bindings={ - "h": image_size, - "w": image_size - }) + return Gemma3ImagePixelInputs( + pixel_values=pixel_values, + num_patches=num_patches, + resolve_bindings={"h": image_size, "w": image_size}, + ) def _image_pixels_to_features( self, @@ -570,35 +593,36 @@ def _process_image_input( ) image_embeds = self.multi_modal_projector(image_features) - return [ - e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist()) - ] + return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())] def get_language_model(self) -> torch.nn.Module: return self.language_model - def 
get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] return self._process_image_input(image_input) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object) -> IntermediateTensors: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds, - **kwargs) + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) return hidden_states @@ -646,7 +670,7 @@ def prepare_attn_masks( # Consider the bidirectional attention between image tokens. img_mask = torch.zeros_like(global_attn_mask) - img_pos = (input_token_ids == self.config.image_token_index) + img_pos = input_token_ids == self.config.image_token_index img_mask[:, :, :, img_pos] += 1 img_mask[:, :, img_pos, :] += 1 global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask) @@ -656,10 +680,10 @@ def prepare_attn_masks( if sliding_window is not None: # Create a local causal mask with sliding window (1024). local_attn_mask = torch.ones_like(global_attn_mask) - local_attn_mask = torch.tril(local_attn_mask, - diagonal=-sliding_window) - local_attn_mask = torch.where(local_attn_mask == 0, - global_attn_mask, float("-inf")) + local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window) + local_attn_mask = torch.where( + local_attn_mask == 0, global_attn_mask, float("-inf") + ) local_attn_masks.append(local_attn_mask) kwargs["global_attn_masks"] = global_attn_masks kwargs["local_attn_masks"] = local_attn_masks @@ -671,8 +695,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -683,4 +706,5 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="language_model", connector="multi_modal_projector", - tower_model="vision_tower") + tower_model="vision_tower", + ) diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 0b6bccb33498..e4ea4256ebc2 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -28,28 +28,38 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY, - GeluAndMul, - GeluAndMulSparse) +from vllm.model_executor.layers.activation import ( + _ACTIVATION_REGISTRY, + GeluAndMul, + GeluAndMulSparse, +) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - 
MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from .interfaces import SupportsQuant -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, make_layers, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -110,9 +120,11 @@ def __init__( eps=rms_norm_eps, ) self.router_input_scale = torch.tensor( - hidden_size**-1.0, dtype=self.modality_router.weight.dtype) + hidden_size**-1.0, dtype=self.modality_router.weight.dtype + ) self.correct_output_scale = nn.Parameter( - torch.zeros(hidden_size, dtype=torch.float32)) + torch.zeros(hidden_size, dtype=torch.float32) + ) def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: router_inputs = self.router_norm(x) * self.router_input_scale @@ -120,15 +132,17 @@ def _compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor: return torch.tanh(routed.float()).type_as(x) def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor: - return (corrected.type_as(self.correct_output_scale) * - self.correct_output_scale).type_as(corrected) + return ( + corrected.type_as(self.correct_output_scale) * self.correct_output_scale + ).type_as(corrected) def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: # hidden: [altup_num_inputs, num_tokens, hidden_size] # modalities: [num_tokens, num_altup_inputs] # all_coefs: [num_tokens, num_altup_inputs ** 2] modalities = self._compute_router_modalities( - hidden_states[self.altup_active_idx]) + hidden_states[self.altup_active_idx] + ) all_coefs = self.prediction_coefs(modalities) # Reshape and transpose the 2D matrix for the matmul. @@ -146,8 +160,9 @@ def predict(self, hidden_states: torch.Tensor) -> torch.Tensor: predictions += hidden_states return predictions.contiguous() - def correct(self, predictions: torch.Tensor, - activated: torch.Tensor) -> torch.Tensor: + def correct( + self, predictions: torch.Tensor, activated: torch.Tensor + ) -> torch.Tensor: # predictions: [altup_num_inputs, num_tokens, hidden_size] # activated: [num_tokens, hidden_size] # modalities: [num_tokens, altup_num_inputs] @@ -215,7 +230,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Gemma3nMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -244,12 +258,16 @@ def __init__( raise ValueError( "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation " "function. Please set `hidden_act` and `hidden_activation` to " - "`gelu_pytorch_tanh`.") + "`gelu_pytorch_tanh`." 
+ ) - self.act_fn = GeluAndMulSparse( - activation_sparsity=activation_sparsity, - approximate="tanh") if activation_sparsity > 0.0 else GeluAndMul( - approximate="tanh") + self.act_fn = ( + GeluAndMulSparse( + activation_sparsity=activation_sparsity, approximate="tanh" + ) + if activation_sparsity > 0.0 + else GeluAndMul(approximate="tanh") + ) def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) @@ -259,17 +277,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Gemma3nAttention(nn.Module): - - def __init__(self, - config: Gemma3nTextConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - max_position_embeddings: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Gemma3nTextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = hidden_size @@ -307,13 +326,11 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) - self.q_norm = RMSNorm(hidden_size=self.head_dim, - eps=config.rms_norm_eps) - self.k_norm = RMSNorm(hidden_size=self.head_dim, - eps=config.rms_norm_eps) - self.v_norm = RMSNorm(hidden_size=self.head_dim, - eps=config.rms_norm_eps, - has_weight=False) + self.q_norm = RMSNorm(hidden_size=self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(hidden_size=self.head_dim, eps=config.rms_norm_eps) + self.v_norm = RMSNorm( + hidden_size=self.head_dim, eps=config.rms_norm_eps, has_weight=False + ) layer_idx = extract_layer_index(prefix) is_sliding = config.layer_types[layer_idx] == "sliding_attention" @@ -329,8 +346,9 @@ def __init__(self, rope_theta = config.rope_theta rope_scaling = config.rope_scaling - first_kv_shared_layer_idx = (config.num_hidden_layers - - config.num_kv_shared_layers) + first_kv_shared_layer_idx = ( + config.num_hidden_layers - config.num_kv_shared_layers + ) self.is_kv_shared = layer_idx >= first_kv_shared_layer_idx kv_sharing_target_layer_name = None @@ -361,7 +379,8 @@ def __init__(self, quant_config=quant_config, per_layer_sliding_window=self.sliding_window, kv_sharing_target_layer_name=kv_sharing_target_layer_name, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + ) def forward( self, @@ -390,7 +409,6 @@ def forward( class Gemma3nDecoderLayer(nn.Module): - def __init__( self, config: Gemma3nTextConfig, @@ -426,12 +444,12 @@ def __init__( self.mlp = Gemma3nMLP( hidden_size=config.hidden_size, # NOTE: Matformer https://github.com/huggingface/transformers/blob/a52478253bbe522a420e88ea3940d4d98a935300/src/transformers/models/gemma3n/modular_gemma3n.py#L258 # noqa: E501 - intermediate_size=config.intermediate_size[extract_layer_index( - prefix)], + intermediate_size=config.intermediate_size[extract_layer_index(prefix)], hidden_activation=config.hidden_activation, quant_config=quant_config, activation_sparsity=config.activation_sparsity_pattern[ - extract_layer_index(prefix)], + extract_layer_index(prefix) + ], prefix=f"{prefix}.mlp", ) self.laurel = Gemma3nLaurelBlock( @@ -493,7 +511,6 @@ def forward( per_layer_input: torch.Tensor, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: - # ActUp (predict). 
predictions = self.altup.predict(hidden_states) active_prediction = predictions[self.altup_active_idx] @@ -508,8 +525,7 @@ def forward( ) attn = self.post_attention_layernorm(attn) attn_gated = attn + active_prediction - attn_laurel = (attn_gated + laurel_output) / torch.sqrt( - torch.tensor(2.0)) + attn_laurel = (attn_gated + laurel_output) / torch.sqrt(torch.tensor(2.0)) # MLP. attn_norm = self.pre_feedforward_layernorm(attn_laurel) @@ -518,8 +534,7 @@ def forward( attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm # ActUp (connect). - corrected_predictions = self.altup.correct(predictions, - attn_ffw_laurel_gated) + corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated) first_prediction = corrected_predictions[self.altup_active_idx] first_prediction = self.altup.scale_corrected_output(first_prediction) @@ -537,8 +552,9 @@ def forward( # This enables torch.compile if --kv-sharing-fast-prefill passed -@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. - kv_sharing_fast_prefill) +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) class Gemma3nSelfDecoder(nn.Module): """ Includes altup embedding and self decoder layers @@ -595,34 +611,41 @@ def __init__( eps=config.rms_norm_eps, ) self.per_layer_input_scale = torch.rsqrt(torch.tensor(2.0)).to( - self.embed_tokens.weight.dtype) + self.embed_tokens.weight.dtype + ) self.per_layer_projection_scale = torch.tensor( config.hidden_size**0.5, dtype=self.embed_tokens.weight.dtype, ) - self.altup_projections = nn.ModuleList([ - ColumnParallelLinear( - config.hidden_size, - config.hidden_size, - bias=False, - gather_output=True, - return_bias=False, - quant_config=quant_config, - prefix=f"{prefix}.altup_projections.{idx-1}", - ) for idx in range(1, self.config.altup_num_inputs) - ]) - - def get_per_layer_input_embeddings( - self, input_ids: torch.Tensor) -> torch.Tensor: + self.altup_projections = nn.ModuleList( + [ + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + gather_output=True, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.altup_projections.{idx - 1}", + ) + for idx in range(1, self.config.altup_num_inputs) + ] + ) + + def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: # Deal with the fact that vocab_size_per_layer_input < vocab_size # which causes us to have some out of vocab tokens by setting # those token ids to 0. This matches the HF implementation. 
per_layer_inputs_mask = torch.logical_and( - input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input) - per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, - torch.zeros_like(input_ids)) - return self.embed_tokens_per_layer( - per_layer_inputs_tokens) * self.embed_scale_per_layer + input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input + ) + per_layer_inputs_tokens = torch.where( + per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids) + ) + return ( + self.embed_tokens_per_layer(per_layer_inputs_tokens) + * self.embed_scale_per_layer + ) def get_per_layer_inputs( self, @@ -635,8 +658,7 @@ def get_per_layer_inputs( self.config.num_hidden_layers, self.config.hidden_size_per_layer_input, ) - per_layer_projection = self.per_layer_projection_norm( - per_layer_projection) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) if per_layer_inputs is not None: # Profiling run does not compute per_layer_inputs per_layer_inputs = per_layer_projection + per_layer_inputs @@ -651,15 +673,13 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor: # Altup embed. hidden_states = [hidden_states_0] * self.config.altup_num_inputs - target_magnitude = torch.mean(hidden_states_0**2, dim=-1, - keepdim=True)**0.5 + target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5 for i in range(1, self.config.altup_num_inputs): hidden_states[i] = self.altup_projections[i - 1](hidden_states[i]) - new_magnitude = torch.mean(hidden_states[i]**2, - dim=-1, - keepdim=True)**0.5 - hidden_states[i] *= target_magnitude / torch.maximum( - new_magnitude, EPS) + new_magnitude = ( + torch.mean(hidden_states[i] ** 2, dim=-1, keepdim=True) ** 0.5 + ) + hidden_states[i] *= target_magnitude / torch.maximum(new_magnitude, EPS) hidden_states = torch.stack(hidden_states, dim=-1) return hidden_states @@ -677,7 +697,8 @@ def forward( hidden_states_0 = self.get_input_embeddings(input_ids) adjusted_per_layer_inputs = self.get_per_layer_inputs( - hidden_states_0, per_layer_inputs) + hidden_states_0, per_layer_inputs + ) hidden_states = self.altup_embed(hidden_states_0) # [altnum_inputs, num_tokens, hidden_size] @@ -700,8 +721,9 @@ def forward( # This enables torch.compile if --kv-sharing-fast-prefill passed -@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config. - kv_sharing_fast_prefill) +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) class Gemma3nCrossDecoder(nn.Module): """ Cross-decoder layers @@ -743,10 +765,10 @@ def forward( # This disables torch.compile if --kv-sharing-fast-prefill passed -@support_torch_compile(enable_if=lambda vllm_config: not vllm_config. 
- cache_config.kv_sharing_fast_prefill) +@support_torch_compile( + enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill +) class Gemma3nTextModel(nn.Module, SupportsQuant): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -755,27 +777,33 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - self.altup_unembed_projections = nn.ModuleList([ - ColumnParallelLinear( - config.hidden_size, - config.hidden_size, - bias=False, - gather_output=True, - return_bias=False, - quant_config=quant_config, - prefix=f"{prefix}.altup_unembed_projections.{idx-1}", - ) for idx in range(1, self.config.altup_num_inputs) - ]) + self.altup_unembed_projections = nn.ModuleList( + [ + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + gather_output=True, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.altup_unembed_projections.{idx - 1}", + ) + for idx in range(1, self.config.altup_num_inputs) + ] + ) # Allocate config.num_kv_shared_layers layers for self-decoder self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Gemma3nDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) - first_kv_shared_layer_idx = (config.num_hidden_layers - - config.num_kv_shared_layers) + first_kv_shared_layer_idx = ( + config.num_hidden_layers - config.num_kv_shared_layers + ) # NOTE(sarckk): importing this top level seems to cause issues # during running of tests. @@ -810,18 +838,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # TODO(sarckk): Extract this functionality to interface max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens device = next(self.parameters()).device - self.positions = torch.zeros(max_num_tokens, - dtype=torch.int64, - device=device) + self.positions = torch.zeros( + max_num_tokens, dtype=torch.int64, device=device + ) self.hidden_states = torch.zeros( - (max_num_tokens, config.hidden_size, - self.config.altup_num_inputs), + (max_num_tokens, config.hidden_size, self.config.altup_num_inputs), dtype=self.embed_tokens.weight.dtype, device=device, ) self.per_layer_inputs = torch.zeros( - (max_num_tokens, self.config.num_hidden_layers, - self.config.hidden_size_per_layer_input), + ( + max_num_tokens, + self.config.num_hidden_layers, + self.config.hidden_size_per_layer_input, + ), dtype=self.embed_tokens.weight.dtype, device=device, ) @@ -830,8 +860,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def embed_tokens(self): return self.self_decoder.embed_tokens - def get_per_layer_input_embeddings( - self, input_ids: torch.Tensor) -> torch.Tensor: + def get_per_layer_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.self_decoder.get_per_layer_input_embeddings(input_ids) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -849,27 +878,26 @@ def fast_prefill_forward( attn_metadata = get_forward_context().attn_metadata # attn_metadata is None during dummy runs - if (self.fast_prefill_enabled and attn_metadata is not None): + if self.fast_prefill_enabled and attn_metadata is not None: assert isinstance(attn_metadata, dict) # Last layer is a KV sharing layer layer_attn_metadata = attn_metadata[ - 
self.layers[-1].self_attn.attn.layer_name] - if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)): - logits_indices_padded = ( - layer_attn_metadata.logits_indices_padded) + self.layers[-1].self_attn.attn.layer_name + ] + if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata): + logits_indices_padded = layer_attn_metadata.logits_indices_padded num_logits_indices = layer_attn_metadata.num_logits_indices # Copy inputs for cudagraph batch_size = positions.size(0) self.positions[:batch_size].copy_(positions) - self_decoder_hidden_states, per_layer_inputs_adjusted = \ - self.self_decoder( - input_ids=input_ids, - positions=self.positions[:batch_size], - inputs_embeds=inputs_embeds, - per_layer_inputs=per_layer_inputs, - **kwargs, - ) + self_decoder_hidden_states, per_layer_inputs_adjusted = self.self_decoder( + input_ids=input_ids, + positions=self.positions[:batch_size], + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) if logits_indices_padded is None: logits_indices_padded = torch.arange( @@ -889,11 +917,14 @@ def fast_prefill_forward( # Copy inputs for cudagraph num_padded_logits_indices = logits_indices_padded.size(0) self.positions[:num_padded_logits_indices].copy_( - positions[logits_indices_padded]) + positions[logits_indices_padded] + ) self.hidden_states[:num_padded_logits_indices].copy_( - self_decoder_hidden_states[logits_indices_padded]) + self_decoder_hidden_states[logits_indices_padded] + ) self.per_layer_inputs[:num_padded_logits_indices].copy_( - per_layer_inputs_adjusted[logits_indices_padded]) + per_layer_inputs_adjusted[logits_indices_padded] + ) cross_decoder_hidden_states = self.cross_decoder( positions=self.positions[:num_padded_logits_indices], hidden_states=self.hidden_states[:num_padded_logits_indices], @@ -905,7 +936,8 @@ def fast_prefill_forward( assert num_logits_indices > 0 # Merge cross-decoder and self-decoder hidden states hidden_states[logits_indices_padded[:num_logits_indices]] = ( - cross_decoder_hidden_states[:num_logits_indices]) + cross_decoder_hidden_states[:num_logits_indices] + ) else: hidden_states = cross_decoder_hidden_states @@ -939,17 +971,19 @@ def altup_unembed( hidden_states: torch.Tensor, ) -> torch.Tensor: # Altup unembed. 
- target_magnitude = torch.mean(hidden_states[..., 0]**2, - dim=-1, - keepdim=True)**0.5 + target_magnitude = ( + torch.mean(hidden_states[..., 0] ** 2, dim=-1, keepdim=True) ** 0.5 + ) for i in range(1, self.config.altup_num_inputs): hidden_states[..., i] = self.altup_unembed_projections[i - 1]( - hidden_states[..., i]) - new_magnitude = torch.mean(hidden_states[..., i]**2, - dim=-1, - keepdim=True)**0.5 + hidden_states[..., i] + ) + new_magnitude = ( + torch.mean(hidden_states[..., i] ** 2, dim=-1, keepdim=True) ** 0.5 + ) hidden_states[..., i] *= target_magnitude / torch.maximum( - new_magnitude, EPS) + new_magnitude, EPS + ) # [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size] hidden_states = torch.mean(hidden_states, dim=-1) return hidden_states @@ -982,8 +1016,7 @@ def forward( hidden_states = self.altup_unembed(hidden_states) return self.norm(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -997,22 +1030,24 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: # decoder layer weights, altup_unembed_projections and rmsnorm # are initialized in text model, others are in self decoder - if (not name.startswith('layers') - and not name.startswith('altup_unembed_projections') - and not name.startswith('norm')): + if ( + not name.startswith("layers") + and not name.startswith("altup_unembed_projections") + and not name.startswith("norm") + ): name = f"self_decoder.{name}" - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, shard_name, shard_id) in stacked_params_mapping: + for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue # Avoid spurious match with ".up_proj". 
@@ -1039,8 +1074,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -1067,10 +1101,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = config self.cache_config = vllm_config.cache_config - self.model = Gemma3nTextModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Gemma3nTextModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.logits_processor = LogitsProcessor( - config.vocab_size, soft_cap=config.final_logit_softcapping) + config.vocab_size, soft_cap=config.final_logit_softcapping + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -1085,7 +1121,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model( input_ids, positions, @@ -1103,11 +1138,11 @@ def compute_logits( logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, - skip_substrs=([ - "embed_audio.", "embed_vision.", - "audio_tower.", "vision_tower." - ])) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_substrs=( + ["embed_audio.", "embed_vision.", "audio_tower.", "vision_tower."] + ), + ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 22f9967ebdcf..0e69fcfd8feb 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -5,14 +5,17 @@ import numpy as np import torch -# yapf: disable + from torch import nn from transformers import AutoModel, BatchFeature -from transformers.models.gemma3n import (Gemma3nAudioConfig, - Gemma3nAudioFeatureExtractor, - Gemma3nConfig, Gemma3nProcessor, - Gemma3nTextConfig, - Gemma3nVisionConfig) +from transformers.models.gemma3n import ( + Gemma3nAudioConfig, + Gemma3nAudioFeatureExtractor, + Gemma3nConfig, + Gemma3nProcessor, + Gemma3nTextConfig, + Gemma3nVisionConfig, +) from transformers.models.siglip import SiglipImageProcessorFast from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig @@ -21,33 +24,44 @@ from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import RowParallelLinear -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - 
BaseProcessingInfo, - MultiModalPromptUpdates, - MultiModalPromptUpdatesApplyResult, - PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate, - PromptUpdateDetails, - replace_token_matches) -# yapf: enable +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, + replace_token_matches, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) logger = init_logger(__name__) @@ -64,6 +78,7 @@ class Gemma3nImagePixelInputs(TensorSchema): - h: Height of each patch - w: Width of each patch """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] @@ -75,6 +90,7 @@ class Gemma3nAudioInputs(TensorSchema): - s: seq_length - f: num_features """ + type: Literal["audio"] = "audio" input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")] input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")] @@ -84,7 +100,6 @@ class Gemma3nAudioInputs(TensorSchema): class Gemma3nProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Gemma3nConfig) @@ -95,9 +110,8 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "audio": None} def get_max_tokens_per_item( - self, seq_len: int, - mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]: - + self, seq_len: int, mm_counts: Mapping[str, int] + ) -> Optional[Mapping[str, int]]: return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO} def get_image_repl( @@ -109,7 +123,7 @@ def get_image_repl( ) -> str: """ Get the replacement text for image tokens. - + For Gemma3n, this should return the full_image_sequence which includes BOI token, repeated image tokens, and EOI token. """ @@ -117,7 +131,8 @@ def get_image_repl( processor = self.get_hf_processor() return PromptUpdateDetails.select_token_id( - processor.full_image_sequence, processor.image_token_id) + processor.full_image_sequence, processor.image_token_id + ) def get_audio_repl( self, @@ -126,7 +141,7 @@ def get_audio_repl( ) -> str: """ Get the replacement text for audio tokens. - + For Gemma3n, this should return the full_audio_sequence which includes BOA token, repeated audio tokens, and EOA token. 
""" @@ -135,11 +150,11 @@ def get_audio_repl( # Return the full audio sequence as defined by the processor return PromptUpdateDetails.select_token_id( - processor.full_audio_sequence, processor.audio_token_id) + processor.full_audio_sequence, processor.audio_token_id + ) class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_audios = mm_counts.get("audio", 0) @@ -159,7 +174,9 @@ def get_dummy_mm_data( num_images = mm_counts.get("image", 0) num_audios = mm_counts.get("audio", 0) processor = self.info.get_hf_processor() - audio_feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501 + audio_feature_extractor: Gemma3nAudioFeatureExtractor = ( + processor.feature_extractor + ) audio_len = audio_feature_extractor.fft_length image_processor: SiglipImageProcessorFast = processor.image_processor img_width = image_processor.size.get("width", 224) @@ -169,21 +186,19 @@ def get_dummy_mm_data( audio_overrides = mm_options.get("audio") if mm_options else None return { - "image": - self._get_dummy_images(width=img_width, - height=img_height, - num_images=num_images, - overrides=image_overrides), - "audio": - self._get_dummy_audios(length=audio_len, - num_audios=num_audios, - overrides=audio_overrides) + "image": self._get_dummy_images( + width=img_width, + height=img_height, + num_images=num_images, + overrides=image_overrides, + ), + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ), } -class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo] - ): - +class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_hf_processor().feature_extractor return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -195,12 +210,11 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - # HF Transformers audio processor no longer accepts `audios` key. # We pop `audios` and replace it with `audio` key to suppress # the warning. 
- if 'audios' in mm_data: - mm_data['audio'] = mm_data.pop('audios') + if "audios" in mm_data: + mm_data["audio"] = mm_data.pop("audios") processed_outputs = super()._call_hf_processor( prompt, mm_data, @@ -208,15 +222,17 @@ def _call_hf_processor( tok_kwargs, ) - if 'input_features' in processed_outputs: + if "input_features" in processed_outputs: # Padding enables audio_tower to run in batched mode - processed_outputs["input_features_padded"] = \ - processed_outputs["input_features"] + processed_outputs["input_features_padded"] = processed_outputs[ + "input_features" + ] # Unpad features here since we need the output of each item to be # independent of other items for the cache to work correctly unpadded_features = [ - f[mask] for f, mask in zip( + f[mask] + for f, mask in zip( processed_outputs["input_features"], processed_outputs["input_features_mask"], ) @@ -229,7 +245,6 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( pixel_values=MultiModalFieldConfig.batched("image"), input_features_padded=MultiModalFieldConfig.batched("audio"), @@ -264,21 +279,25 @@ def get_replacement_image(item_idx: int): modality="image", target=image_token, replacement=get_replacement_image, - )) + ) + ) # Handle audio tokens if "audio" in mm_items: audio_token = hf_processor.audio_token def get_replacement_audio(item_idx: int): - return self.info.get_audio_repl(processor=hf_processor, ) + return self.info.get_audio_repl( + processor=hf_processor, + ) prompt_updates.append( PromptReplacement( modality="audio", target=audio_token, replacement=get_replacement_audio, - )) + ) + ) return prompt_updates @@ -287,8 +306,7 @@ def _apply_token_matches( prompt: list[int], mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: - token_ids, res = super()._apply_token_matches(prompt, - mm_prompt_updates) + token_ids, res = super()._apply_token_matches(prompt, mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -347,8 +365,7 @@ def get_repl_toks(tok: int) -> list[int]: repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = super()._find_mm_placeholders(repl_token_ids, - mm_prompt_updates) + repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates) return { modality: [ @@ -358,14 +375,15 @@ def get_repl_toks(tok: int) -> list[int]: start_idx=repl_orig_idxs[p.start_idx], tokens=p.tokens, is_embed=p.is_embed, - ) for p in placeholders + ) + for p in placeholders ] for modality, placeholders in repls.items() } class Gemma3nMultimodalEmbedder(nn.Module): - """Embeds token ids or soft tokens for multimodal content into language + """Embeds token ids or soft tokens for multimodal content into language model space.""" def __init__( @@ -425,7 +443,8 @@ def forward( """ # noqa: E501 if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( - "You must specify exactly one of input_ids or inputs_embeds") + "You must specify exactly one of input_ids or inputs_embeds" + ) if inputs_embeds is not None: emb_norm = self.soft_embedding_norm(inputs_embeds) @@ -437,11 +456,14 @@ def forward( return self.embedding_post_projection_norm(emb_norm_proj) -@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor, - info=Gemma3nProcessingInfo, - dummy_inputs=Gemma3nDummyInputsBuilder) -class Gemma3nForConditionalGeneration(nn.Module, 
SupportsMultiModal, - SupportsTranscription): +@MULTIMODAL_REGISTRY.register_processor( + Gemma3nMultiModalProcessor, + info=Gemma3nProcessingInfo, + dummy_inputs=Gemma3nDummyInputsBuilder, +) +class Gemma3nForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsTranscription +): merge_by_field_config = True supported_languages = ISO639_1_SUPPORTED_LANGS @@ -468,7 +490,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal, "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", "model": "language_model.model", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -482,10 +505,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vision_tower = AutoModel.from_config(config=config.vision_config) self.audio_tower = AutoModel.from_config(config=config.audio_config) - self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, - config.text_config) - self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, - config.text_config) + self.embed_vision = Gemma3nMultimodalEmbedder( + config.vision_config, config.text_config + ) + self.embed_audio = Gemma3nMultimodalEmbedder( + config.audio_config, config.text_config + ) self.language_model: nn.Module = init_vllm_registered_model( vllm_config=vllm_config, @@ -501,10 +526,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config.text_config.num_hidden_layers, self.config.text_config.hidden_size_per_layer_input, device=self.language_model.model.embed_tokens.weight.device, - dtype=self.language_model.model.embed_tokens.weight.dtype) + dtype=self.language_model.model.embed_tokens.weight.dtype, + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Gemma3nImageInputs]: + self, **kwargs: object + ) -> Optional[Gemma3nImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) # TODO is this the case? @@ -515,8 +542,8 @@ def _parse_and_validate_image_input( return Gemma3nImagePixelInputs(pixel_values=pixel_values) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Gemma3nAudioInputs]: - + self, **kwargs: object + ) -> Optional[Gemma3nAudioInputs]: input_features_padded = kwargs.pop("input_features_padded", None) if input_features_padded is None: return None @@ -536,14 +563,20 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key == "input_features_padded" \ - and "audio" not in mm_input_by_modality: - mm_input_by_modality[ - "audio"] = self._parse_and_validate_audio_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key == "input_features_padded" + and "audio" not in mm_input_by_modality + ): + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input( + **kwargs + ) return mm_input_by_modality def _process_image_input( @@ -553,16 +586,20 @@ def _process_image_input( assert self.vision_tower is not None pixel_values = image_input["pixel_values"] - vision_outputs = self.vision_tower(pixel_values=pixel_values, - do_pooling=False, - return_dict=True).last_hidden_state + vision_outputs = self.vision_tower( + pixel_values=pixel_values, do_pooling=False, return_dict=True + ).last_hidden_state # TODO try to avoid copy here # (batch, channels, height, width) to (batch, height * width, channels) - vision_outputs = vision_outputs.reshape( - vision_outputs.shape[0], - self.config.vision_config.hidden_size, - self.config.vision_soft_tokens_per_image, - ).permute(0, 2, 1).contiguous() + vision_outputs = ( + vision_outputs.reshape( + vision_outputs.shape[0], + self.config.vision_config.hidden_size, + self.config.vision_soft_tokens_per_image, + ) + .permute(0, 2, 1) + .contiguous() + ) # Normalize and embed the soft tokens into language model space. vision_outputs *= self.config.vision_config.hidden_size**0.5 # Return a list of embeddings instead of a batched tensor @@ -576,8 +613,9 @@ def _process_audio_input( # Run on padded features to enable batching input_features = audio_input["input_features_padded"].squeeze(1) input_features_mask = audio_input["input_features_mask"].squeeze(1) - audio_outputs, audio_mask = self.audio_tower(input_features, - ~input_features_mask) + audio_outputs, audio_mask = self.audio_tower( + input_features, ~input_features_mask + ) audio_features = self.embed_audio(inputs_embeds=audio_outputs) # ruff: noqa @@ -587,30 +625,29 @@ def _process_audio_input( # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab. 
# TODO precompute and cache padding - audio_padding_toks = torch.tensor([[self.vocab_size - 1]], - dtype=torch.long, - device=audio_features.device) + audio_padding_toks = torch.tensor( + [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device + ) audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks) - audio_features = torch.where(audio_mask.unsqueeze(-1), - audio_padding_embs, audio_features) + audio_features = torch.where( + audio_mask.unsqueeze(-1), audio_padding_embs, audio_features + ) audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501 extra_padding_features = audio_padding_embs.expand( - audio_batch_size, extra_padding_tokens, audio_embed_dim) + audio_batch_size, extra_padding_tokens, audio_embed_dim + ) - audio_features = torch.cat((audio_features, extra_padding_features), - dim=1) + audio_features = torch.cat((audio_features, extra_padding_features), dim=1) # Return a list of embeddings instead of a batched tensor return audio_features.unbind(0) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if mm_input_by_modality is None: return [] @@ -640,12 +677,16 @@ def get_input_embeddings( # them here, as the model forward has only access to the input_embeds. if input_ids is not None: per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings( - input_ids) + input_ids + ) per_layer_inputs = per_layer_inputs.reshape( - -1, self.config.text_config.num_hidden_layers, - self.config.text_config.hidden_size_per_layer_input) - self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_( - per_layer_inputs) + -1, + self.config.text_config.num_hidden_layers, + self.config.text_config.hidden_size_per_layer_input, + ) + self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_( + per_layer_inputs + ) # This is to satisfy the type checker for each overload if multimodal_embeddings is None or is_multimodal is None: @@ -658,12 +699,14 @@ def get_input_embeddings( handle_oov_mm_token=handle_oov_mm_token, ) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object) -> IntermediateTensors: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None @@ -672,7 +715,7 @@ def forward(self, # select a chunk of pre-allocated PLEs. During normal execution, # `get_input_embeddings` is called before forward, hence this slice # will contain PLEs computed from the actual input_ids. 
- per_layer_inputs = self.per_layer_embeddings[:inputs_embeds.shape[0]] + per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]] hidden_states = self.language_model.model( input_ids, @@ -680,7 +723,8 @@ def forward(self, per_layer_inputs=per_layer_inputs, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, - **kwargs) + **kwargs, + ) return hidden_states @@ -690,8 +734,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -702,7 +745,8 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="language_model", connector="multi_modal_projector", - tower_model="vision_tower") + tower_model="vision_tower", + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -714,16 +758,19 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: raise ValueError(f"Unsupported modality: {modality}") @classmethod - def get_generation_prompt(cls, audio: np.ndarray, - stt_config: SpeechToTextConfig, - model_config: ModelConfig, - language: Optional[str], - task_type: Literal["transcribe", "translate"], - request_prompt: str, - to_language: Optional[str]) -> PromptType: + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: """ Gemma3n supports "free-form" transcription. - We fix its prompt here to standardize transcriptions/translations + We fix its prompt here to standardize transcriptions/translations requests. """ # Transcribe this audio [into <>] | for transcription @@ -752,8 +799,9 @@ def get_generation_prompt(cls, audio: np.ndarray, return cast(PromptType, prompts_dict) @classmethod - def get_speech_to_text_config(cls, model_config: ModelConfig, - task_type: str) -> SpeechToTextConfig: + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: return SpeechToTextConfig( # Let's set this to 30 as suggested in the docs for now, although # the model is only limited by its context length. 
diff --git a/vllm/model_executor/models/glm.py b/vllm/model_executor/models/glm.py index defa77b84e44..a6991f8e43fe 100644 --- a/vllm/model_executor/models/glm.py +++ b/vllm/model_executor/models/glm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only HF format GLM-4 model compatible with THUDM weights.""" + from vllm.config import VllmConfig from vllm.model_executor.models.llama import LlamaForCausalLM @@ -8,7 +9,6 @@ class GlmForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.partial_rotary_factor = 0.5 super().__init__(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index f49f21a40f82..f25f50602e6c 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" + from collections.abc import Iterable from typing import Optional, Union @@ -34,8 +35,7 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope @@ -49,21 +49,22 @@ class Glm4Attention(nn.Module): - - def __init__(self, - config: Glm4Config, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - head_dim: Optional[int] = None, - qkv_bias: bool = False, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[tuple] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER) -> None: + def __init__( + self, + config: Glm4Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + qkv_bias: bool = False, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -112,14 +113,16 @@ def __init__(self, partial_rotary_factor=partial_rotary_factor, is_neox_style=False, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=attn_type) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=attn_type, + ) def forward( self, @@ -135,11 +138,12 @@ def forward( class Glm4DecoderLayer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str = "", - config: 
Optional[Glm4Config] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[Glm4Config] = None, + ) -> None: super().__init__() config = config or vllm_config.model_config.hf_config @@ -157,8 +161,8 @@ def __init__(self, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - qkv_bias=getattr(config, 'attention_bias', False), - head_dim=getattr(config, 'head_dim', None), + qkv_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, rope_scaling=rope_scaling, @@ -172,14 +176,14 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_self_attn_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mlp_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_self_attn_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -192,8 +196,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -202,8 +205,7 @@ def forward( hidden_states = self.post_self_attn_layernorm(hidden_states) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) hidden_states = self.post_mlp_layernorm(hidden_states) @@ -221,13 +223,13 @@ def forward( "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class Glm4Model(LlamaModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=Glm4DecoderLayer) + super().__init__( + vllm_config=vllm_config, prefix=prefix, layer_type=Glm4DecoderLayer + ) class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): @@ -253,25 +255,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.model = Glm4Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Glm4Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - 
self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -283,8 +288,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -294,11 +300,9 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index c253631eb8b4..304e721fade5 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -29,7 +29,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Callable, Literal, Optional, Union, override +from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np import torch @@ -41,47 +41,65 @@ from transformers import __version__ as TRANSFORMERS_VERSION from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from transformers.models.glm4v.image_processing_glm4v import ( - Glm4vImageProcessor, smart_resize) -from transformers.models.glm4v.video_processing_glm4v import ( - Glm4vVideoProcessor) + Glm4vImageProcessor, + smart_resize, +) +from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.video_utils import VideoMetadata from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import (check_upstream_fa_availability, - maybe_get_vit_flash_attn_backend) +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions -from vllm.distributed import (get_tensor_model_parallel_world_size, - parallel_state) +from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, VideoItem) -from 
vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from ..layers.activation import SiluAndMul -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .qwen2_vl import (_create_qwen2vl_field_factory, - apply_rotary_pos_emb_vision) -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -100,6 +118,7 @@ class Glm4vImagePixelInputs(TensorSchema): - ni: Number of images - g: Grid dimensions (3 for grid_t, grid_h, grid_w) """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[torch.Tensor, TensorShape("np", "cpp")] @@ -114,6 +133,7 @@ class Glm4vImageEmbeddingInputs(TensorSchema): - n: Number of images - g: Grid dimensions (3 for grid_t, grid_h, grid_w) """ + type: Literal["image_embeds"] = "image_embeds" image_embeds: Annotated[torch.Tensor, TensorShape("f", "h")] @@ -133,6 +153,7 @@ class Glm4vVideoPixelInputs(TensorSchema): - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w) """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctpp")] @@ -148,6 +169,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema): - g: Grid dimensions (3 for grid_t which is usually 1 for processed video, grid_h, grid_w) """ + type: Literal["video_embeds"] = "video_embeds" video_embeds: Annotated[torch.Tensor, TensorShape("p", "h")] @@ -160,7 +182,6 @@ class Glm4vVideoEmbeddingInputs(TensorSchema): class Glm4vVisionMLP(nn.Module): - def __init__( self, in_features: int, @@ -208,8 +229,7 @@ def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): ) gathered_tensors_split = [ - torch.split(tensor, hidden_size // tp_size, -1) - for tensor in gathered_tensors + torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors ] ordered_tensors = [ tensor for pair in zip(*gathered_tensors_split) for tensor in pair @@ -219,7 +239,6 @@ def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): class Glm4vVisionAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -231,14 +250,18 @@ def __init__( ) -> None: super().__init__() # Per attention head and per partition values. 
- self.tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) - self.tp_rank = (0 if use_data_parallel else - parallel_state.get_tensor_model_parallel_rank()) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.tp_rank = ( + 0 if use_data_parallel else parallel_state.get_tensor_model_parallel_rank() + ) self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, self.tp_size) + num_heads, self.tp_size + ) self.qkv = QKVParallelLinear( hidden_size=embed_dim, @@ -263,26 +286,30 @@ def __init__( # Detect attention implementation. self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype()) + dtype=torch.get_default_dtype(), + ) self.use_upstream_fa = False - self.attn_backend, self.flash_attn_varlen_func \ - = maybe_get_vit_flash_attn_backend( + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, ) + ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( - f"GLM-4V does not support {self.attn_backend} backend now.") + f"GLM-4V does not support {self.attn_backend} backend now." + ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -303,12 +330,12 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: return q, k, v def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -317,8 +344,7 @@ def forward( q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) @@ -326,7 +352,6 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) output = self.flash_attn_varlen_func( @@ -341,9 +366,9 @@ def forward( causal=False, ) - context_layer = rearrange(output, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
outputs = [] @@ -353,36 +378,36 @@ def forward( q_i = q[:, start_idx:end_idx] k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() output, _ = self.proj(context_layer) return output class Glm4vVisionBlock(nn.Module): - def __init__( self, dim: int, @@ -416,12 +441,12 @@ def __init__( ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), @@ -437,7 +462,6 @@ def forward( class Glm4vVisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -461,14 +485,12 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) x = self.proj(x).view(L, self.hidden_size) return x class Glm4vPatchMerger(nn.Module): - def __init__( self, d_model: int, @@ -519,7 +541,6 @@ def forward(self, x: torch.Tensor): class Glm4vVisionEmbeddings(nn.Module): - def __init__(self, config: Glm4vVisionConfig): super().__init__() self.config = config @@ -527,18 +548,18 @@ def __init__(self, config: Glm4vVisionConfig): self.image_size = config.image_size self.patch_size = config.patch_size - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer( "position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False, ) - def forward(self, embeddings, lengths, image_shapes, h_coords, - w_coords) -> torch.Tensor: + def forward( + self, embeddings, lengths, image_shapes, h_coords, w_coords + ) -> torch.Tensor: pos_embed_weight = 
self.position_embedding.weight hidden_size = pos_embed_weight.shape[1] total_seq = h_coords.shape[0] @@ -549,29 +570,27 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, # Handle empty sequence case if total_seq == 0: - adapted_pos_embed = torch.empty(0, - hidden_size, - device=device, - dtype=pos_embed_weight.dtype) + adapted_pos_embed = torch.empty( + 0, hidden_size, device=device, dtype=pos_embed_weight.dtype + ) else: # Convert inputs to tensors if needed if isinstance(lengths, list): - lengths = torch.tensor(lengths, - device=device, - dtype=torch.long) + lengths = torch.tensor(lengths, device=device, dtype=torch.long) if not isinstance(image_shapes, torch.Tensor): - image_shapes = torch.tensor(image_shapes, - device=device, - dtype=torch.long) + image_shapes = torch.tensor( + image_shapes, device=device, dtype=torch.long + ) # Prepare 2D position embedding orig_size_sq = pos_embed_weight.shape[0] orig_size = int(orig_size_sq**0.5) - pos_embed_2d = (pos_embed_weight.view( - orig_size, orig_size, - hidden_size).permute(2, 0, - 1).unsqueeze(0).to(device=device, - dtype=torch.float32)) + pos_embed_2d = ( + pos_embed_weight.view(orig_size, orig_size, hidden_size) + .permute(2, 0, 1) + .unsqueeze(0) + .to(device=device, dtype=torch.float32) + ) # Calculate target dimensions for each patch # Add bounds checking for data parallel mode @@ -584,23 +603,21 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, for i in range(len(lengths)): # Cycle through available shapes shape_idx = i % image_shapes.shape[0] - target_h_list.append(image_shapes[shape_idx, - 1].repeat(lengths[i])) - target_w_list.append(image_shapes[shape_idx, - 2].repeat(lengths[i])) - target_h = torch.cat(target_h_list).to(device=device, - dtype=torch.float32) - target_w = torch.cat(target_w_list).to(device=device, - dtype=torch.float32) + target_h_list.append(image_shapes[shape_idx, 1].repeat(lengths[i])) + target_w_list.append(image_shapes[shape_idx, 2].repeat(lengths[i])) + target_h = torch.cat(target_h_list).to( + device=device, dtype=torch.float32 + ) + target_w = torch.cat(target_w_list).to( + device=device, dtype=torch.float32 + ) else: - target_h = torch.cat([ - image_shapes[i, 1].repeat(lengths[i]) - for i in range(len(lengths)) - ]).to(device=device, dtype=torch.float32) - target_w = torch.cat([ - image_shapes[i, 2].repeat(lengths[i]) - for i in range(len(lengths)) - ]).to(device=device, dtype=torch.float32) + target_h = torch.cat( + [image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))] + ).to(device=device, dtype=torch.float32) + target_w = torch.cat( + [image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))] + ).to(device=device, dtype=torch.float32) # Normalize coordinates to [-1, 1] range for grid_sample h_coords = h_coords.to(device=device, dtype=torch.float32) @@ -609,8 +626,7 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, norm_h = ((h_coords + 0.5) / target_h) * 2 - 1 # Create sampling grid - grid = (torch.stack((norm_w, norm_h), - dim=-1).unsqueeze(0).unsqueeze(2)) + grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2) # Perform bicubic interpolation interpolated_embed_fp32 = F.grid_sample( @@ -623,9 +639,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, # Reshape and convert back to original dtype adapted_pos_embed_fp32 = ( - interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)) - adapted_pos_embed = adapted_pos_embed_fp32.to( - pos_embed_weight.dtype).to(embeddings.device) + 
interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0) + ) + adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to( + embeddings.device + ) # Add adapted position encoding to embeddings embeddings = embeddings + adapted_pos_embed @@ -633,13 +651,11 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, class Glm4vVisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -648,16 +664,22 @@ def update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( - 0, - self.dim, - 2, - dtype=torch.float, - device=self.inv_freq.device, - ) / self.dim)) - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, + self.dim, + 2, + dtype=torch.float, + device=self.inv_freq.device, + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -667,7 +689,6 @@ def forward(self, seqlen: int) -> torch.Tensor: class Glm4vVisionTransformer(nn.Module): - def __init__( self, vision_config: Glm4vVisionConfig, @@ -700,17 +721,20 @@ def __init__( norm_layer = partial(RMSNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList([ - Glm4vVisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.out_hidden_size, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=self.use_data_parallel, - ) for layer_idx in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + Glm4vVisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.out_hidden_size, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=self.use_data_parallel, + ) + for layer_idx in range(depth) + ] + ) self.merger = Glm4vPatchMerger( d_model=vision_config.out_hidden_size, context_dim=vision_config.intermediate_size, @@ -721,21 +745,25 @@ def __init__( ) self.embeddings = Glm4vVisionEmbeddings(vision_config) - self.post_conv_layernorm = RMSNorm(vision_config.hidden_size, - eps=vision_config.rms_norm_eps) + self.post_conv_layernorm = RMSNorm( + vision_config.hidden_size, eps=vision_config.rms_norm_eps + ) self.downsample = nn.Conv2d( in_channels=vision_config.hidden_size, out_channels=vision_config.out_hidden_size, kernel_size=vision_config.spatial_merge_size, stride=vision_config.spatial_merge_size, ) - self.post_layernorm = RMSNorm(vision_config.hidden_size, - eps=vision_config.rms_norm_eps) + self.post_layernorm = RMSNorm( + vision_config.hidden_size, eps=vision_config.rms_norm_eps + ) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype()) - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability(torch.get_default_dtype()): + head_size=head_dim, 
dtype=torch.get_default_dtype() + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): self.attn_backend = _Backend.FLASH_ATTN @property @@ -751,20 +779,27 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = (hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten()) - wpos_ids = (wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten()) - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) @@ -777,8 +812,10 @@ def compute_attn_mask_seqlen( ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return max_seqlen, seqlens @@ -798,15 +835,16 @@ def forward( # compute position embedding rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw) # compute cu_seqlens - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32) + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # pre-compute seqlens for attn mask to reduce cuMemcpy operations max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) - x = self.embeddings(x, seqlens, grid_thw, image_type_ids[:, 0], - image_type_ids[:, 1]) + x = self.embeddings( + x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1] + ) # transformers x = x.unsqueeze(1) @@ -822,16 +860,14 @@ def forward( # adapter x = self.post_layernorm(x) - x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, - x.shape[-1]) + x = x.view(-1, self.spatial_merge_size, self.spatial_merge_size, x.shape[-1]) x = x.permute(0, 3, 1, 2) x = self.downsample(x).view(-1, self.out_hidden_size) x = self.merger(x) return x - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), @@ -855,15 +891,13 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = 
getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Glm4vProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config() @@ -896,17 +930,16 @@ def _get_vision_info( if do_resize: resized_height, resized_width = smart_resize( num_frames=num_frames - if num_frames > temporal_patch_size else temporal_patch_size, + if num_frames > temporal_patch_size + else temporal_patch_size, height=image_height, width=image_width, factor=patch_size * merge_size, max_pixels=max_image_pixels, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) # NOTE: Frames are padded to be divisible by `temporal_patch_size` # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 @@ -922,8 +955,9 @@ def _get_vision_info( return preprocessed_size, num_vision_tokens def get_image_size_with_most_features(self) -> ImageSize: - max_image_size, _ = self._get_vision_info(image_width=9999999, - image_height=9999999) + max_image_size, _ = self._get_vision_info( + image_width=9999999, image_height=9999999 + ) return max_image_size def get_num_image_tokens( @@ -990,22 +1024,22 @@ def get_num_frames_with_most_features( max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO + ) return max(max_frames_per_video, 1) - def _get_video_second_idx(self, metadata: dict[str, Any], - total_frames: int) -> list[int]: + def _get_video_second_idx( + self, metadata: dict[str, Any], total_frames: int + ) -> list[int]: video_processor = self.get_video_processor() video_fps = metadata.get("fps", video_processor.fps) meta_frames = metadata.get("total_num_frames", total_frames) max_frame_idx = meta_frames - 1 - duration = metadata.get("duration", - round(max_frame_idx / video_fps) + 1) + duration = metadata.get("duration", round(max_frame_idx / video_fps) + 1) do_sample_frames = metadata["do_sample_frames"] if not do_sample_frames: frame_indices = metadata["frames_indices"] @@ -1016,18 +1050,17 @@ def _get_video_second_idx(self, metadata: dict[str, Any], min( max_frame_idx, int(math.ceil(i * video_fps / video_processor.fps)), - ) for i in range(n) + ) + for i in range(n) ] else: - num_samples = int(video_processor.max_duration * - video_processor.fps) + num_samples = int(video_processor.max_duration * video_processor.fps) if num_samples >= meta_frames: frame_indices = list(range(meta_frames)) else: - target_seconds = np.linspace(0, - duration, - num_samples, - endpoint=True) + target_seconds = np.linspace( + 0, duration, num_samples, endpoint=True + ) frame_indices = [ min(max_frame_idx, int(math.ceil(t * video_fps))) for t in target_seconds @@ -1069,8 +1102,7 @@ def _construct_video_placeholder( assert isinstance(grid_thw, torch.Tensor) timestamps = self._get_video_second_idx(metadata, len(video_array)) frames_idx_token = [ - 
tokenizer.encode(str(i), add_special_tokens=False) - for i in timestamps + tokenizer.encode(str(i), add_special_tokens=False) for i in timestamps ] T, H, W = grid_thw num_tokens_per_frame = int(H * W) // merge_length @@ -1078,8 +1110,7 @@ def _construct_video_placeholder( placeholder.append(bov_token_id) for frame_idx in frames_idx_token: placeholder.append(boi_token_id) - placeholder.extend([hf_processor.video_token_id] * - num_tokens_per_frame) + placeholder.extend([hf_processor.video_token_id] * num_tokens_per_frame) placeholder.append(eoi_token_id) placeholder.extend(frame_idx) placeholder.append(eov_token_id) @@ -1088,7 +1119,6 @@ def _construct_video_placeholder( class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -1116,22 +1146,22 @@ def get_dummy_mm_data( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = ( - self.info.get_image_size_with_most_features()) + target_width, target_height = self.info.get_image_size_with_most_features() target_num_frames = self.info.get_num_frames_with_most_features( - seq_len, mm_counts) + seq_len, mm_counts + ) image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, @@ -1155,22 +1185,28 @@ def _get_dummy_videos( logger.warning( "video.num_frames override (%d) exceeds model's " "maximum number of frames (%d), will be ignored", - overrides.num_frames, num_frames) + overrides.num_frames, + num_frames, + ) num_frames = min(num_frames, overrides.num_frames) if overrides.width: if overrides.width > width: logger.warning( "video.width override (%d) exceeds model's " - "maximum width (%d), will be ignored", overrides.width, - width) + "maximum width (%d), will be ignored", + overrides.width, + width, + ) width = min(width, overrides.width) if overrides.height: if overrides.height > height: logger.warning( "video.height override (%d) exceeds model's " "maximum height (%d), will be ignored", - overrides.height, height) - height = min(height, override.height) + overrides.height, + height, + ) + height = min(height, overrides.height) video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) video_items = [] @@ -1190,7 +1226,6 @@ def _get_dummy_videos( class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: return MultiModalDataParser(video_needs_metadata=True) @@ -1207,8 +1242,11 @@ def _call_hf_processor( # GLM-4.1V use `image_token_id` as video placeholder, we need to # replace it with `video_token_id` for video processing. So we # separate video processing from image processing. 
- if ("videos" in mm_data and isinstance(mm_data["videos"], list) - and len(mm_data["videos"]) > 0): + if ( + "videos" in mm_data + and isinstance(mm_data["videos"], list) + and len(mm_data["videos"]) > 0 + ): video_grid_thw_lst = [] pixel_values_videos_lst = [] for item in mm_data.pop("videos", []): @@ -1217,25 +1255,31 @@ def _call_hf_processor( # don't update mm_kwargs inplace video_mm_kwargs = dict(**mm_kwargs) video_mm_kwargs["do_sample_frames"] = metadata.get( - "do_sample_frames", True) + "do_sample_frames", True + ) video_mm_data = dict() video_mm_data["videos"] = [[video_array]] # backward compatibility for Transformers 4.55 unuse_metadata = ["do_sample_frames"] - if not hasattr( - VideoMetadata, - "frames_indices") and "frames_indices" in metadata: + if ( + not hasattr(VideoMetadata, "frames_indices") + and "frames_indices" in metadata + ): unuse_metadata.append("frames_indices") - video_mm_data["video_metadata"] = [[ - VideoMetadata( - **{ - k: metadata[k] - for k in metadata if k not in unuse_metadata - }) - ]] + video_mm_data["video_metadata"] = [ + [ + VideoMetadata( + **{ + k: metadata[k] + for k in metadata + if k not in unuse_metadata + } + ) + ] + ] video_outputs = super()._call_hf_processor( prompt="<|begin_of_video|><|video|><|end_of_video|>", @@ -1244,7 +1288,8 @@ def _call_hf_processor( tok_kwargs=tok_kwargs, ) if not video_mm_kwargs["do_sample_frames"] and Version( - TRANSFORMERS_VERSION) < Version("4.56.0"): + TRANSFORMERS_VERSION + ) < Version("4.56.0"): # Transformers v4.55 has incorrect timestamps issue for # skip sampling. We construct the placeholder manually to # get placeholders with correct timestamps. @@ -1257,9 +1302,9 @@ def _call_hf_processor( else: input_ids = video_outputs.pop("input_ids") input_ids[input_ids == processor.image_token_id] = ( - processor.video_token_id) - video_placeholder = processor.tokenizer.batch_decode( - input_ids)[0] + processor.video_token_id + ) + video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] prompt = prompt.replace( "<|begin_of_video|><|video|><|end_of_video|>", video_placeholder, @@ -1267,8 +1312,7 @@ def _call_hf_processor( ) video_grid_thw_lst.append(video_outputs["video_grid_thw"]) - pixel_values_videos_lst.append( - video_outputs["pixel_values_videos"]) + pixel_values_videos_lst.append(video_outputs["pixel_values_videos"]) video_outputs = dict( pixel_values_videos=torch.cat(pixel_values_videos_lst), video_grid_thw=torch.cat(video_grid_thw_lst), @@ -1294,8 +1338,8 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return _create_qwen2vl_field_factory( - self.info.get_hf_config().vision_config.spatial_merge_size)( - hf_inputs) + self.info.get_hf_config().vision_config.spatial_merge_size + )(hf_inputs) def _get_prompt_updates( self, @@ -1304,8 +1348,7 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) merge_length = image_processor.merge_size**2 @@ -1324,7 +1367,8 @@ def get_video_replacement_glm4v(item_idx: int): video, metadata = mm_items["video"][item_idx] placeholder = self.info._construct_video_placeholder( - video, metadata, grid_thw) + video, metadata, grid_thw + ) return PromptUpdateDetails.select_token_id( placeholder, embed_token_id=hf_processor.video_token_id, @@ 
-1349,8 +1393,9 @@ def get_video_replacement_glm4v(item_idx: int): info=Glm4vProcessingInfo, dummy_inputs=Glm4vDummyInputsBuilder, ) -class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): +class Glm4vForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP +): merge_by_field_config = True packed_modules_mapping = { @@ -1359,7 +1404,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, "k_proj", "v_proj", ], - "gate_up_proj": ["gate_up_proj"] + "gate_up_proj": ["gate_up_proj"], } # To ensure correct weight loading and mapping. @@ -1368,7 +1413,8 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal, "lm_head.": "language_model.lm_head.", "model.language_model.": "language_model.model.", "model.visual.": "visual.", - }) + } + ) supports_encoder_tp_data = True @@ -1410,13 +1456,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), - architectures=architectures) + architectures=architectures, + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Glm4vImageInputs]: + self, **kwargs: object + ) -> Optional[Glm4vImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1439,7 +1488,8 @@ def _parse_and_validate_image_input( ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Glm4vVideoInputs]: + self, **kwargs: object + ) -> Optional[Glm4vVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1462,7 +1512,8 @@ def _parse_and_validate_video_input( ) def _process_image_input( - self, image_input: Glm4vImageInputs) -> tuple[torch.Tensor, ...]: + self, image_input: Glm4vImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1472,20 +1523,21 @@ def _process_image_input( else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values, - grid_thw.tolist(), - rope_type="rope_3d") + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d" + ) else: - image_embeds = self.visual(pixel_values, - grid_thw=grid_thw.tolist()) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw.tolist()) merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return image_embeds.split(sizes) def _process_video_input( - self, video_input: Glm4vVideoInputs) -> tuple[torch.Tensor, ...]: + self, video_input: Glm4vVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1494,19 +1546,25 @@ def _process_video_input( video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = 
video_input["pixel_values_videos"].type( - self.visual.dtype) + self.visual.dtype + ) if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values_videos, - grid_thw.tolist(), - rope_type="rope_3d") + return run_dp_sharded_mrope_vision_model( + self.visual, + pixel_values_videos, + grid_thw.tolist(), + rope_type="rope_3d", + ) else: - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw.tolist()) + video_embeds = self.visual( + pixel_values_videos, grid_thw=grid_thw.tolist() + ) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: @@ -1515,23 +1573,29 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if (input_key in ("pixel_values", "image_embeds") - and "image" not in mm_input_by_modality): - mm_input_by_modality["image"] = ( - self._parse_and_validate_image_input(**kwargs)) - if (input_key in ("pixel_values_videos", "video_embeds") - and "video" not in mm_input_by_modality): - mm_input_by_modality["video"] = ( - self._parse_and_validate_video_input(**kwargs)) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + self, **kwargs: object + ) -> Optional[MultiModalEmbeddings]: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None @@ -1591,8 +1655,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 2557748b7faa..5db6f297dbf2 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights.""" + import typing from collections.abc import Callable, Iterable from itertools import islice @@ -34,35 +35,48 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Glm4MoeMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -74,19 +88,24 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -97,7 +116,6 @@ def forward(self, x): class Glm4MoE(nn.Module): - def __init__( self, config: Glm4MoeConfig, @@ -116,8 +134,10 @@ def __init__( self.n_shared_experts: int = config.n_shared_experts if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." 
+ ) # NOTE In the transformers implementation, the gate isn't an nn.Linear, # so we cannot use ReplicatedLinear here. # See: https://github.com/huggingface/transformers/blob/v4.55.1/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L260 @@ -128,7 +148,8 @@ def __init__( dtype=torch.float32, ) self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts, dtype=torch.float32)) + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) # Load balancing settings. vllm_config = get_current_vllm_config() @@ -137,18 +158,16 @@ def __init__( self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) if config.n_shared_experts is not None: - intermediate_size = (config.moe_intermediate_size * - config.n_shared_experts) + intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = Glm4MoeMLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, @@ -195,7 +214,8 @@ def __init__( routed_scaling_factor=1.0, e_score_correction_bias=self.gate.e_score_correction_bias, enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) + num_redundant_experts=self.n_redundant_experts, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -204,27 +224,27 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states.to(dtype=torch.float32)) - fused_moe_out = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if self.shared_experts is not None: shared_output, final_hidden_states = fused_moe_out assert shared_output is not None - final_hidden_states = \ - final_hidden_states * self.routed_scaling_factor\ - + shared_output + final_hidden_states = ( + final_hidden_states * self.routed_scaling_factor + shared_output + ) else: final_hidden_states = fused_moe_out * self.routed_scaling_factor if self.tp_size > 1: - final_hidden_states = ( - self.experts.maybe_all_reduce_tensor_model_parallel( - final_hidden_states)) + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) return final_hidden_states.view(num_tokens, hidden_dim) class Glm4MoeAttention(nn.Module): - def __init__( self, config: Glm4MoeConfig, @@ -266,19 +286,23 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.use_qk_norm = use_qk_norm - self.qkv_proj = QKVParallelLinear(hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + 
prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) self.rotary_emb = get_rope( @@ -311,10 +335,12 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.use_qk_norm: - q = self.q_norm(q.reshape(-1, self.num_heads, - self.head_dim)).reshape(q.shape) - k = self.k_norm(k.reshape(-1, self.num_kv_heads, - self.head_dim)).reshape(k.shape) + q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape( + q.shape + ) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, self.head_dim)).reshape( + k.shape + ) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) @@ -323,7 +349,6 @@ def forward( class Glm4MoeDecoderLayer(nn.Module): - def __init__( self, config: Glm4MoeConfig, @@ -336,11 +361,10 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 131072) + max_position_embeddings = getattr(config, "max_position_embeddings", 131072) # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. - layer_idx = int(prefix.split(sep='.')[-1]) + layer_idx = int(prefix.split(sep=".")[-1]) self.layer_idx = layer_idx self.self_attn = Glm4MoeAttention( @@ -360,8 +384,10 @@ def __init__( use_qk_norm=config.use_qk_norm, ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace): + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + ): self.mlp = Glm4MoE( config=config, quant_config=quant_config, @@ -369,16 +395,18 @@ def __init__( enable_eplb=enable_eplb, ) else: - self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.mlp = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.routed_scaling_factor = config.routed_scaling_factor def forward( @@ -391,12 +419,9 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = 
self.mlp(hidden_states) return hidden_states, residual @@ -407,9 +432,9 @@ def forward( "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class Glm4MoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -423,9 +448,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - prefix=f"{prefix}.embed_tokens") + config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens" + ) else: self.embed_tokens = PPMissingLayer() @@ -438,15 +462,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=prefix, enable_eplb=enable_eplb, ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -473,27 +498,26 @@ def forward( hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales @@ -502,10 +526,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -522,7 +546,7 @@ def load_weights(self, weights: Iterable[tuple[str, spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). 
if weight_name not in name: continue @@ -532,7 +556,7 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. @@ -567,14 +591,17 @@ def load_weights(self, weights: Iterable[tuple[str, # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. - weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -598,8 +625,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -627,24 +655,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Glm4MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Glm4MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) self.expert_weights = [] # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) + self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_expert_groups = config.n_group self.moe_layers: list[FusedMoE] = [] @@ -695,8 +725,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -706,8 +737,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) @@ -715,13 +745,14 @@ def get_expert_mapping(self) -> list[tuple[str, str, 
int, str]]: return self.model.get_expert_mapping() -def get_spec_layer_idx_from_weight_name(config: Glm4MoeConfig, - weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): +def get_spec_layer_idx_from_weight_name( + config: Glm4MoeConfig, weight_name: str +) -> Optional[int]: + if hasattr(config, "num_nextn_predict_layers") and ( + config.num_nextn_predict_layers > 0 + ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if f"layers.{layer_idx+i}." in weight_name: + if f"layers.{layer_idx + i}." in weight_name: return layer_idx + i return None diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index 57b698e239ec..beb40632246c 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -36,7 +36,9 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors @@ -46,7 +48,6 @@ class SharedHead(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -55,17 +56,18 @@ def __init__( ) -> None: super().__init__() self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "head")) + self.head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head"), + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.norm(hidden_states) class Glm4MoeMultiTokenPredictorLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -76,16 +78,16 @@ def __init__( super().__init__() self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.eh_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) - self.shared_head = SharedHead(config=config, - prefix=prefix, - quant_config=quant_config) - self.mtp_block = Glm4MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) + self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + self.shared_head = SharedHead( + config=config, prefix=prefix, quant_config=quant_config + ) + self.mtp_block = Glm4MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) def forward( self, @@ -102,34 +104,37 @@ def forward( previous_hidden_states = self.hnorm(previous_hidden_states) hidden_states = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) hidden_states = residual + hidden_states return hidden_states class Glm4MoeMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = 
vllm_config.model_config.hf_config self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - Glm4MoeMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.layers = torch.nn.ModuleDict( + { + str(idx): Glm4MoeMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -149,7 +154,7 @@ def forward( ) -> torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = (spec_step_idx % self.num_mtp_layers) + current_step_idx = spec_step_idx % self.num_mtp_layers return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, @@ -163,22 +168,21 @@ def compute_logits( hidden_states: torch.Tensor, spec_step_idx: int = 0, ) -> torch.Tensor: - current_step_idx = (spec_step_idx % self.num_mtp_layers) - mtp_layer = self.layers[str(self.mtp_start_layer_idx + - current_step_idx)] - logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states)) + current_step_idx = spec_step_idx % self.num_mtp_layers + mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] + logits = self.logits_processor( + mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) + ) return logits class Glm4MoeMTP(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = Glm4MoeMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -192,8 +196,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( @@ -203,8 +208,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.model.compute_logits(hidden_states, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -220,7 +224,8 @@ def load_weights(self, weights: Iterable[tuple[str, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -229,7 +234,7 @@ def load_weights(self, weights: Iterable[tuple[str, if spec_layer is None: continue name = 
self._rewrite_spec_layer_name(spec_layer, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -239,7 +244,7 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. @@ -259,11 +264,13 @@ def load_weights(self, weights: Iterable[tuple[str, param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. @@ -272,13 +279,16 @@ def load_weights(self, weights: Iterable[tuple[str, # According to DeepSeek-V3 Technical Report, MTP modules # shares embedding layer. We only load the first weights. - if (spec_layer != self.model.mtp_start_layer_idx - and ".layers" not in name): + if ( + spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name + ): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -290,7 +300,11 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: and rename shared layer weights to be top level. """ spec_layer_weight_names = [ - "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", ] shared_weight_names = ["embed_tokens"] spec_layer_weight = False @@ -303,8 +317,9 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace(f"model.layers.{spec_layer}.", - f"model.layers.{spec_layer}.mtp_block.") + name = name.replace( + f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." 
+ ) elif shared_weight: # treat shared weights as top level weights name = name.replace(f"model.layers.{spec_layer}.", "model.") diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index bc898105cbcb..a5c3ce0e6bf7 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -4,6 +4,7 @@ # Adapted from # https://github.com/zai-org/CogAgent """Inference-only CogAgent model compatible with THUDM weights.""" + from argparse import Namespace from collections.abc import Mapping, Sequence from typing import Annotated, Literal, Optional, Union @@ -22,28 +23,40 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from vllm.utils.tensor_schema import TensorSchema, TensorShape from .chatglm import ChatGLMBaseModel, ChatGLMModel -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) class GLMVImagePixelInputs(TensorSchema): @@ -54,21 +67,22 @@ class GLMVImagePixelInputs(TensorSchema): - h: Height of image - w: Width of image """ + type: Literal["pixel_values"] = "pixel_values" data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")] class EVA2CLIPPatchEmbedding(nn.Module): - def __init__(self, config): super().__init__() - self.proj = nn.Conv2d(config.in_channels, - config.hidden_size, - kernel_size=config.patch_size, - stride=config.patch_size) + self.proj = nn.Conv2d( + config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size, + ) self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) - self.position_embedding = nn.Embedding(config.num_positions, - config.hidden_size) + self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size) def forward(self, images: torch.Tensor) -> torch.Tensor: """ @@ -80,8 +94,7 @@ def forward(self, images: torch.Tensor) -> torch.Tensor: torch.Tensor Transformed tensor with shape (B, L, D) """ - images = images.to(device=self.proj.weight.device, - dtype=self.proj.weight.dtype) + images = images.to(device=self.proj.weight.device, 
dtype=self.proj.weight.dtype) x = self.proj(images) x = x.flatten(2).transpose(1, 2) cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) @@ -91,12 +104,11 @@ def forward(self, images: torch.Tensor) -> torch.Tensor: class EVA2CLIPAttention(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -119,8 +131,9 @@ def __init__( prefix=f"{prefix}.dense", ) - self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, - self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_rank, self.head_dim, self.scale + ) self.output_dropout = torch.nn.Dropout(config.dropout_prob) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -134,12 +147,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class EVA2CLIPMLP(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() self.config = config @@ -165,29 +177,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class EVA2CLIPTransformerLayer(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() - self.input_layernorm = LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.attention = EVA2CLIPAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.attention") - self.mlp = EVA2CLIPMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.post_attention_layernorm = LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = EVA2CLIPAttention( + config, quant_config=quant_config, prefix=f"{prefix}.attention" + ) + self.mlp = EVA2CLIPMLP( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) + self.post_attention_layernorm = LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) def forward(self, hidden_states): attention_input = hidden_states - attention_output = self.input_layernorm( - self.attention(attention_input)) + attention_output = self.input_layernorm(self.attention(attention_input)) hidden_states = attention_input + attention_output mlp_input = hidden_states mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) @@ -196,20 +206,23 @@ def forward(self, hidden_states): class EVA2CLIPTransformer(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() - self.layers = nn.ModuleList([ - EVA2CLIPTransformerLayer(config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + EVA2CLIPTransformerLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) def forward(self, hidden_states): for layer_module in self.layers: @@ -218,13 +231,12 @@ def forward(self, hidden_states): class EVA2CLIPGLU(nn.Module): - def __init__( self, config, in_features, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): """ The original implementation is the same as: @@ -233,14 +245,14 @@ def __init__( config.hidden_size, config.ffn_hidden_size, bias=False, - quant_config=quant_config + quant_config=quant_config, ) 
self.gate_proj = ColumnParallelLinear( config.hidden_size, config.ffn_hidden_size, bias=False, - quant_config=quant_config + quant_config=quant_config, ) ``` ``` @@ -255,7 +267,7 @@ def __init__( config.hidden_size, [config.ffn_hidden_size] * 2, bias=False, - quant_config=quant_config + quant_config=quant_config, ) ``` ``` @@ -263,27 +275,32 @@ def __init__( ``` """ super().__init__() - self.linear_proj = ReplicatedLinear(in_features, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.linear_proj") + self.linear_proj = ReplicatedLinear( + in_features, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj", + ) self.norm1 = nn.LayerNorm(config.hidden_size) self.act1 = nn.GELU() self.act2 = SiluAndMul() self.merged_proj = MergedColumnParallelLinear( - config.hidden_size, [config.ffn_hidden_size] * 2, + config.hidden_size, + [config.ffn_hidden_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.merged_proj") + prefix=f"{prefix}.merged_proj", + ) self.dense_4h_to_h = RowParallelLinear( config.ffn_hidden_size, config.hidden_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.dense_4h_to_h") + prefix=f"{prefix}.dense_4h_to_h", + ) def forward(self, x): x, _ = self.linear_proj(x) @@ -295,27 +312,30 @@ def forward(self, x): class EVA2CLIPModel(nn.Module): - def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '', + prefix: str = "", ): super().__init__() vision_config = Namespace(**config.vision_config) self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config) - self.transformer = EVA2CLIPTransformer(vision_config, - quant_config=quant_config, - prefix=f"{prefix}.transformer") - self.linear_proj = EVA2CLIPGLU(config, - in_features=config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.linear_proj") - self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, - out_channels=config.hidden_size, - kernel_size=2, - stride=2) + self.transformer = EVA2CLIPTransformer( + vision_config, quant_config=quant_config, prefix=f"{prefix}.transformer" + ) + self.linear_proj = EVA2CLIPGLU( + config, + in_features=config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.linear_proj", + ) + self.conv = nn.Conv2d( + in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2, + ) self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.scaling_factor = vision_config.scaling_factor @@ -349,15 +369,14 @@ def forward(self, images: torch.Tensor) -> torch.Tensor: class GLM4VModel(ChatGLMModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) quant_config = vllm_config.quant_config - self.vision = EVA2CLIPModel(self.config, - quant_config, - prefix=f"{prefix}.vision") + self.vision = EVA2CLIPModel( + self.config, quant_config, prefix=f"{prefix}.vision" + ) class GLM4VProcessor: @@ -379,17 +398,19 @@ def __init__( vision_config = config.vision_config image_size = vision_config["image_size"] - self.image_transform = transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), - ), - ]) + self.image_transform = transforms.Compose( + [ + transforms.Resize( + 
(image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) def __call__( self, @@ -424,7 +445,6 @@ def __call__( class GLM4VProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(ChatGLMConfig) @@ -454,7 +474,6 @@ def get_num_image_feature_tokens(self) -> int: class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -477,16 +496,16 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): - def _hf_processor_applies_updates( self, prompt_text: str, @@ -530,17 +549,18 @@ def get_replacement(item_idx: int): ] -@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor, - info=GLM4VProcessingInfo, - dummy_inputs=GLM4VDummyInputsBuilder) -class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + GLM4VMultiModalProcessor, + info=GLM4VProcessingInfo, + dummy_inputs=GLM4VDummyInputsBuilder, +) +class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP): merge_by_field_config = True packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"], - "merged_proj": ["gate_proj", "dense_h_to_4h"] + "merged_proj": ["gate_proj", "dense_h_to_4h"], } def get_mm_mapping(self) -> MultiModelKeys: @@ -550,7 +570,8 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="transformer.encoder", connector="transformer.vision.linear_proj", - tower_model="transformer.vision.transformer") + tower_model="transformer.vision.transformer", + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -575,22 +596,21 @@ def __init__( self.transformer: GLM4VModel def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[GLMVImagePixelInputs]: + self, **kwargs: object + ) -> Optional[GLMVImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) if pixel_values is not None: expected_h = expected_w = self.config.vision_config["image_size"] - return GLMVImagePixelInputs(type="pixel_values", - data=pixel_values, - resolve_bindings={ - "h": expected_h, - "w": expected_w - }) + return GLMVImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, + ) return None - def _process_image_input( - self, image_input: GLMVImagePixelInputs) -> torch.Tensor: + def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor: pixel_values = image_input["data"].to(dtype=self.config.torch_dtype) return self.transformer.vision(pixel_values) @@ -600,8 +620,7 @@ def get_language_model(self) -> torch.nn.Module: get_input_embeddings = SupportsMultiModal.get_input_embeddings - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> 
MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -620,7 +639,8 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 24274db148bd..53d6026c5938 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -31,27 +32,36 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed.parallel_state import ( - get_pp_group, get_tensor_model_parallel_world_size) + get_pp_group, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from ..layers.pooler import DispatchPooler, Pooler from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GPT2Attention(nn.Module): - def __init__( self, config: GPT2Config, @@ -62,8 +72,7 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -84,12 +93,14 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.c_proj", ) - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -103,7 +114,6 @@ def forward( class GPT2MLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -137,7 +147,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GPT2Block(nn.Module): - def __init__( self, config: GPT2Config, @@ 
-147,19 +156,14 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - inner_dim = (config.n_inner if config.n_inner is not None else 4 * - hidden_size) + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = GPT2Attention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, - config, - quant_config, - prefix=f"{prefix}.mlp") + self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp") def forward( self, @@ -181,7 +185,6 @@ def forward( @support_torch_compile class GPT2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -194,20 +197,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert not config.scale_attn_by_inverse_layer_idx assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size - self.wte = VocabParallelEmbedding(config.vocab_size, - self.embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.wte") + self.wte = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.wte", + ) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: GPT2Block( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + lambda prefix: GPT2Block(config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h", + ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.n_embd)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.n_embd + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -237,8 +242,7 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -260,34 +264,35 @@ def load_weights(self, weights: Iterable[tuple[str, if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class GPT2LMHeadModel(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = GPT2Model(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.lm_head") + self.transformer = GPT2Model( + vllm_config=vllm_config, 
prefix=maybe_prefix(prefix, "transformer") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) if self.config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.transformer.wte) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -299,8 +304,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -310,8 +316,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) weights = _add_transformer_prefix(weights) return loader.load_weights(weights) @@ -334,22 +339,25 @@ class GPT2ForSequenceClassification(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - self.transformer = GPT2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "gpt2")) - self.score = nn.Linear(config.n_embd, - config.num_labels, - bias=False, - dtype=vllm_config.model_config.head_dtype) + self.transformer = GPT2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt2") + ) + self.score = nn.Linear( + config.n_embd, + config.num_labels, + bias=False, + dtype=vllm_config.model_config.head_dtype, + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - Pooler.for_classify(pooler_config, classifier=self.score), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": Pooler.for_classify(pooler_config, classifier=self.score), + } + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) @@ -366,15 +374,15 @@ def forward( input_ids=input_ids, position_ids=positions, inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + intermediate_tensors=intermediate_tensors, + ) return hidden_states def _add_transformer_prefix( - weights: Iterable[tuple[str, torch.Tensor]] + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, torch.Tensor]]: for name, tensor in weights: - if not name.startswith('transformer.') and not name.startswith( - "lm_head"): - name = 'transformer.' + name + if not name.startswith("transformer.") and not name.startswith("lm_head"): + name = "transformer." 
+ name yield name, tensor diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 162018450e7c..b6d3d8f3f2e6 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -20,6 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPTBigCode model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -33,24 +34,31 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GPTBigCodeAttention(nn.Module): - def __init__( self, config: GPTBigCodeConfig, @@ -61,11 +69,9 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads - self.tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert total_num_heads % self.tensor_model_parallel_world_size == 0 - self.num_heads = (total_num_heads // - self.tensor_model_parallel_world_size) + self.num_heads = total_num_heads // self.tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads self.scale = self.head_dim**-0.5 @@ -94,13 +100,15 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.c_proj", ) - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -110,7 +118,8 @@ def forward( q, k, v = qkv.split( [ self.hidden_size // self.tensor_model_parallel_world_size, - self.kv_dim, self.kv_dim + self.kv_dim, + self.kv_dim, ], dim=-1, ) @@ -120,7 +129,6 @@ def forward( class GPTBigMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -154,7 +162,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GPTBigCodeBlock(nn.Module): - def __init__( self, config: GPTBigCodeConfig, @@ -164,19 +171,14 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - inner_dim = (config.n_inner if 
config.n_inner is not None else 4 * - hidden_size) + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = GPTBigCodeAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigMLP(inner_dim, - config, - quant_config, - prefix=f"{prefix}.mlp") + self.mlp = GPTBigMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp") def forward( self, @@ -184,7 +186,9 @@ def forward( ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn(hidden_states=hidden_states, ) + attn_output = self.attn( + hidden_states=hidden_states, + ) # residual connection hidden_states = attn_output + residual @@ -198,7 +202,6 @@ def forward( @support_torch_compile class GPTBigCodeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -211,23 +214,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert not config.add_cross_attention self.embed_dim = config.hidden_size - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab - self.wte = VocabParallelEmbedding(self.vocab_size, - self.embed_dim, - org_num_embeddings=config.vocab_size) + self.wte = VocabParallelEmbedding( + self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size + ) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, lambda prefix: GPTBigCodeBlock( - config, cache_config, quant_config, prefix=prefix), + config, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.n_embd)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.n_embd + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -254,8 +261,7 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -266,13 +272,12 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method if "c_attn.input_scale" in name: - weight_loader(param, loaded_weight, 'q') - weight_loader(param, loaded_weight, 'k') - weight_loader(param, loaded_weight, 'v') + weight_loader(param, loaded_weight, "q") + weight_loader(param, loaded_weight, "k") + weight_loader(param, 
loaded_weight, "v") else: weight_loader(param, loaded_weight) loaded_params.add(name) @@ -292,9 +297,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = GPTBigCodeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) if self.config.tie_word_embeddings: self.lm_head = self.transformer.wte else: @@ -302,14 +307,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.transformer.vocab_size, self.transformer.embed_dim, org_num_embeddings=self.config.vocab_size, - prefix=maybe_prefix(prefix, "lm_head")) + prefix=maybe_prefix(prefix, "lm_head"), + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -321,8 +329,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -332,8 +341,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = None if self.config.tie_word_embeddings: skip_prefixes = ["lm_head."] diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 698387fab946..5428512dec19 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-J model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -31,26 +32,35 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GPTJAttention(nn.Module): - def __init__( self, config: GPTJConfig, @@ -85,8 +95,7 @@ def __init__( assert getattr(config, "rotary", True) assert config.rotary_dim % 2 == 0 rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_emb = get_rope( self.head_size, rotary_dim=config.rotary_dim, @@ -94,12 +103,14 @@ def __init__( base=rope_theta, is_neox_style=False, ) - self.attn = Attention(self.num_heads, - self.head_size, - scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -115,7 +126,6 @@ def forward( class GPTJMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -144,7 +154,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GPTJBlock(nn.Module): - def __init__( self, config: GPTJConfig, @@ -153,13 +162,11 @@ def __init__( prefix: str = "", ): super().__init__() - inner_dim = (4 * config.n_embd - if config.n_inner is None else config.n_inner) + inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = GPTJAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( @@ -180,7 +187,6 @@ def forward( @support_torch_compile class GPTJModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -197,14 +203,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.h = make_layers( config.n_layer, - lambda prefix: 
GPTJBlock( - config, cache_config, quant_config, prefix=prefix), + lambda prefix: GPTJBlock(config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.n_embd)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.n_embd + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -230,8 +235,7 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -246,19 +250,20 @@ def load_weights(self, weights: Iterable[tuple[str, if "attn.bias" in name or "attn.masked_bias" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -281,15 +286,13 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class GPTJForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -297,9 +300,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = GPTJModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, @@ -309,7 +312,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -321,19 +325,18 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> 
Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( self, hidden_states: torch.Tensor, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - self.lm_head.bias) + logits = self.logits_processor(self.lm_head, hidden_states, self.lm_head.bias) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 45519a94d854..8278ae03d88a 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-NeoX model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -31,25 +32,32 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GPTNeoXAttention(nn.Module): - def __init__( self, config: GPTNeoXConfig, @@ -63,11 +71,9 @@ def __init__( self.head_size = self.hidden_size // self.total_num_heads self.bias = getattr(config, "attention_bias", True) - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.query_key_value = QKVParallelLinear( config.hidden_size, @@ -86,20 +92,21 @@ def __init__( rotary_dim = int(self.head_size * config.rotary_pct) assert rotary_dim % 2 == 0 rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.rotary_emb = get_rope( self.head_size, 
rotary_dim=rotary_dim, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_size, - scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -115,7 +122,6 @@ def forward( class GPTNeoXMLP(nn.Module): - def __init__( self, config: GPTNeoXConfig, @@ -142,7 +148,6 @@ def forward(self, hidden_states): class GPTNeoXLayer(nn.Module): - def __init__( self, config: GPTNeoXConfig, @@ -152,14 +157,15 @@ def __init__( ): super().__init__() self.use_parallel_residual = config.use_parallel_residual - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attention") + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.attention = GPTNeoXAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attention" + ) self.mlp = GPTNeoXMLP(config, quant_config) def forward( @@ -192,7 +198,6 @@ def forward( @support_torch_compile class GPTNeoXModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -209,14 +214,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: GPTNeoXLayer( - config, cache_config, quant_config, prefix=prefix), + config, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) - self.final_layer_norm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_in(input_ids) @@ -242,16 +249,17 @@ def forward( hidden_states = self.final_layer_norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if ("attention.bias" in name or "attention.masked_bias" in name - or "rotary_emb.inv_freq" in name): + if ( + "attention.bias" in name + or "attention.masked_bias" in name + or "rotary_emb.inv_freq" in name + ): continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using OpenRLHF may include # these tensors in the checkpoint. Skip them. 
continue @@ -269,29 +277,29 @@ def load_weights(self, weights: Iterable[tuple[str, if output_dim is not None: loaded_weight_shape = loaded_weight.shape loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) + loaded_weight_shape[:output_dim] + + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1 :] + ) + loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1) loaded_weight = loaded_weight.reshape(loaded_weight_shape) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class GPTNeoXForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "gpt_neox")) + self.gpt_neox = GPTNeoXModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt_neox") + ) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, @@ -302,7 +310,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.gpt_neox.make_empty_intermediate_tensors) + self.gpt_neox.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.gpt_neox.get_input_embeddings(input_ids) @@ -314,8 +323,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.gpt_neox(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.gpt_neox( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -325,7 +335,6 @@ def compute_logits( logits = self.logits_processor(self.embed_out, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 47ba5084d608..17f911435079 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -11,33 +11,41 @@ from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - 
RowParallelLinear) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from vllm.utils import cdiv from .interfaces import SupportsEagle3, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class OAIAttention(nn.Module): - def __init__( self, config: GptOssConfig, @@ -59,16 +67,13 @@ def __init__( base=config.rope_theta, dtype=torch.float32, rope_scaling={ - "rope_type": - "yarn", - "factor": - config.rope_scaling["factor"], - "original_max_position_embeddings": - config.rope_scaling["original_max_position_embeddings"], - "beta_fast": - config.rope_scaling["beta_fast"], - "beta_slow": - config.rope_scaling["beta_slow"], + "rope_type": "yarn", + "factor": config.rope_scaling["factor"], + "original_max_position_embeddings": config.rope_scaling[ + "original_max_position_embeddings" + ], + "beta_fast": config.rope_scaling["beta_fast"], + "beta_slow": config.rope_scaling["beta_slow"], }, is_neox_style=True, ) @@ -76,8 +81,8 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() self.sinks = torch.nn.Parameter( - torch.empty(config.num_attention_heads // tp_size, - requires_grad=False)) + torch.empty(config.num_attention_heads // tp_size, requires_grad=False) + ) self.q_size = self.num_attention_heads * self.head_dim // tp_size self.kv_size = self.num_key_value_heads * self.head_dim // tp_size @@ -104,8 +109,7 @@ def __init__( self.num_local_key_value_heads = config.num_key_value_heads // tp_size # Only apply sliding window to every other layer - sliding_window = (config.sliding_window if self.layer_idx % - 2 == 0 else None) + sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None self.attn = Attention( self.num_local_attention_heads, self.head_dim, @@ -119,8 +123,9 @@ def __init__( sinks=self.sinks, ) - def forward(self, hidden_states: torch.Tensor, - positions: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, positions: torch.Tensor + ) -> torch.Tensor: qkv, _ = self.qkv(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) @@ -131,7 +136,6 @@ def forward(self, hidden_states: torch.Tensor, class MLPBlock(torch.nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -150,21 +154,22 @@ def __init__( self.num_experts = config.num_local_experts self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.router = torch.nn.Linear(config.hidden_size, - config.num_local_experts) + self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts) assert config.intermediate_size % 
self.world_size == 0 - self.experts = FusedMoE(num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - prefix=f"{prefix}.experts", - apply_router_weight_on_input=False, - has_bias=True, - activation="swigluoai", - is_sequence_parallel=self.is_sequence_parallel) + self.experts = FusedMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + apply_router_weight_on_input=False, + has_bias=True, + activation="swigluoai", + is_sequence_parallel=self.is_sequence_parallel, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: num_tokens = x.shape[0] @@ -181,7 +186,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerBlock(torch.nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -193,12 +197,10 @@ def __init__( cache_config = vllm_config.cache_config self.layer_idx = extract_layer_index(prefix) - self.attn = OAIAttention(config, - prefix=f"{prefix}.attn", - cache_config=cache_config) - self.mlp = MLPBlock(vllm_config, - self.layer_idx, - prefix=f"{prefix}.mlp") + self.attn = OAIAttention( + config, prefix=f"{prefix}.attn", cache_config=cache_config + ) + self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp") self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) @@ -213,19 +215,16 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.attn(hidden_states, positions) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) output = self.mlp(hidden_states) return output, residual @support_torch_compile class GptOssModel(nn.Module): - def __init__( self, *, @@ -249,9 +248,9 @@ def __init__( prefix=f"{prefix}.layers", ) self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], self.config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size + ) self.aux_hidden_state_layers = tuple[int, ...]() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -280,14 +279,10 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] if i in self.aux_hidden_state_layers: - aux_hidden_states.append(x if residual is None else x + - residual) + aux_hidden_states.append(x if residual is None else x + residual) x, residual = layer(x, positions, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": x, - "residual": residual - }) + return IntermediateTensors({"hidden_states": x, "residual": residual}) x, _ = self.norm(x, residual) if len(aux_hidden_states) > 0: @@ -315,15 +310,12 @@ def _load_weights_mxfp4( intermediate_size = self.config.intermediate_size intermediate_size_block = 
intermediate_size // mxfp4_block - per_rank_intermediate_size_block = cdiv(intermediate_size_block, - tp_size) - per_rank_intermediate_size = (per_rank_intermediate_size_block * - mxfp4_block) + per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) + per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block # Calculate common slicing bounds for current rank tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) for name, weight in weights: # Skip layers on other devices. @@ -338,18 +330,17 @@ def _load_weights_mxfp4( if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w2_weight_scale" in name: @@ -357,66 +348,68 @@ def _load_weights_mxfp4( if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[..., tp_rank_start // - mxfp4_block:tp_rank_end // - mxfp4_block] + narrow_weight = weight[ + ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block + ] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w13_weight" in name: # Handle MLP gate and up projection weights # flat weight from (E, 2 * N, block_size, entry_per_block) # to (E, 2 * N, -1), shouldn't trigger copy for contiguous - weight = weight.view(num_experts, 2 * intermediate_size, - -1).contiguous() + weight = weight.view( + num_experts, 2 * intermediate_size, -1 + ).contiguous() # Extract gate and up projection parts # since the weight is shuffled, we can slice directly if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end, - ...] + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w2_weight" in name: # Handle MLP down projection weights # same flatten here, but since 2 mx4 value are packed in 1 # uint8, divide by 2 - weight = weight.view(num_experts, -1, - intermediate_size // 2).contiguous() + weight = weight.view( + num_experts, -1, intermediate_size // 2 + ).contiguous() if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
else: - narrow_weight = weight[..., - tp_rank_start // 2:tp_rank_end // 2] + narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w13_bias" in name: @@ -425,35 +418,32 @@ def _load_weights_mxfp4( if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - narrow_weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + narrow_weight, + weight_name=name, + shard_id=None, + expert_id=None, + ) loaded_params.add(name) continue elif ".w2_bias" in name: # Handle MLP down projection bias param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if use_ep: weight = weight[ep_rank_start:ep_rank_end, ...] else: # (only load on rank 0 to avoid duplication) if tp_rank != 0: weight.zero_() - weight_loader(param, - weight, - weight_name=name, - shard_id=None, - expert_id=None) + weight_loader( + param, weight, weight_name=name, shard_id=None, expert_id=None + ) loaded_params.add(name) continue elif "sinks" in name: @@ -468,8 +458,7 @@ def _load_weights_mxfp4( continue name = name.replace(weight_name, param_name) param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, weight) else: @@ -480,8 +469,7 @@ def _load_weights_mxfp4( if name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) loaded_params.add(name) return loaded_params @@ -507,8 +495,7 @@ def _load_weights_other( per_rank_intermediate_size = cdiv(intermediate_size, tp_size) # Calculate common slicing bounds for current rank tp_rank_start = tp_rank * per_rank_intermediate_size - tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, - intermediate_size) + tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) for name, weight in weights: # Skip layers on other devices. @@ -521,8 +508,7 @@ def _load_weights_other( if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: - narrow_weight = weight[:, :, - 2 * tp_rank_start:2 * tp_rank_end] + narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end] narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() param = params_dict[name] @@ -548,8 +534,7 @@ def _load_weights_other( if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
else: - narrow_weight = weight[:, - 2 * tp_rank_start:2 * tp_rank_end] + narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end] param = params_dict[name] param.copy_(narrow_weight) @@ -579,8 +564,7 @@ def _load_weights_other( continue name = name.replace(weight_name, param_name) param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, weight) else: @@ -591,14 +575,12 @@ def _load_weights_other( if name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) loaded_params.add(name) return loaded_params - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv", ".q_proj", "q"), @@ -620,16 +602,29 @@ def load_weights(self, weights: Iterable[tuple[str, ep_rank_start = ep_rank * experts_per_rank ep_rank_end = (ep_rank + 1) * experts_per_rank - quant_method = (self.config.quantization_config['quant_method'] if - hasattr(self.config, "quantization_config") else None) + quant_method = ( + self.config.quantization_config["quant_method"] + if hasattr(self.config, "quantization_config") + else None + ) if quant_method == "mxfp4": - return self._load_weights_mxfp4(ep_rank_end, ep_rank_start, - heads_per_rank, head_start, - weights, stacked_params_mapping) + return self._load_weights_mxfp4( + ep_rank_end, + ep_rank_start, + heads_per_rank, + head_start, + weights, + stacked_params_mapping, + ) else: - return self._load_weights_other(ep_rank_end, ep_rank_start, - heads_per_rank, head_start, - weights, stacked_params_mapping) + return self._load_weights_other( + ep_rank_end, + ep_rank_start, + heads_per_rank, + head_start, + weights, + stacked_params_mapping, + ) class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): @@ -641,17 +636,14 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3): }, orig_to_new_suffix={ ".embed_tokens.weight": ".embedding.weight", - # MoE MXFP4 weights ".gate_up_proj_blocks": ".w13_weight", ".down_proj_blocks": ".w2_weight", ".gate_up_proj_scales": ".w13_weight_scale", ".down_proj_scales": ".w2_weight_scale", - # MoE other weights ".gate_up_proj": ".w13_weight", ".down_proj": ".w2_weight", - # MoE Bias ".gate_up_proj_bias": ".w13_bias", ".down_proj_bias": ".w2_bias", @@ -678,7 +670,8 @@ def __init__( ) self.logits_processor = LogitsProcessor(self.config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers @@ -690,23 +683,22 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: - return self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + def forward( + self, + 
input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model(input_ids, positions, intermediate_tensors, inputs_embeds) def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 893cc8a41455..e9bc592c0797 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only IBM Granite model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -37,25 +38,36 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_layers, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_layers, + maybe_prefix, +) class GraniteMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -71,15 +83,19 @@ def __init__( output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -90,7 +106,6 @@ def forward(self, x): class GraniteAttention(nn.Module): - def __init__( self, config: GraniteConfig, @@ -155,13 +170,15 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -177,7 +194,6 @@ def forward( class GraniteDecoderLayer(nn.Module): - def __init__( self, config: GraniteConfig, @@ -191,21 +207,24 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.self_attn = GraniteAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -223,10 +242,10 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -251,7 +270,6 @@ def forward( @support_torch_compile class GraniteModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -262,12 +280,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -275,18 +297,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if 
not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) else: self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: GraniteDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") + lambda prefix: GraniteDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -317,15 +343,16 @@ def forward( hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -337,18 +364,19 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -377,8 +405,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -414,8 +441,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.model = GraniteModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = GraniteModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -427,7 +455,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -438,9 +467,9 @@ 
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if hasattr(config, "logits_scaling"): logit_scale /= config.logits_scaling - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, scale=logit_scale + ) else: self.lm_head = PPMissingLayer() @@ -454,32 +483,31 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output - def compute_logits(self, - hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # With tie_word_embeddings, we can skip lm_head.weight # The weight might appear unnecessarily in the files if the model is # processed with quantization, LoRA, fine-tuning, etc. - skip_prefixes = (["lm_head."] - if self.config.tie_word_embeddings else None) + skip_prefixes = ["lm_head."] if self.config.tie_word_embeddings else None loader = AutoWeightsLoader( self, diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index e543c6040fc0..82bceaf3ed01 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only IBM Granite speech model.""" + import math from collections.abc import Iterable, Mapping from typing import Annotated, Optional, Union @@ -34,25 +35,37 @@ from vllm.config import CacheConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .blip2 import Blip2QFormerModel -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix @@ -60,7 +73,7 @@ class GraniteSpeechAudioInputs(TensorSchema): """ Audio input features for Granite Speech model. - + Dimensions: - b: Batch size - fi: Number of input features from the Mel spectrogram. @@ -79,7 +92,6 @@ class GraniteSpeechAudioInputs(TensorSchema): class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": 1} @@ -96,8 +108,8 @@ def get_max_audio_len(self): ### Input Processing & Multimodal utils class GraniteSpeechMultiModalProcessor( - BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]): - + BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo] +): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_hf_processor().audio_processor sampling_rate = feature_extractor.melspec_kwargs["sample_rate"] @@ -133,7 +145,8 @@ def get_replacement(item_idx: int): audio = audios.get(item_idx) audio_length = audio.shape[-1] num_projector_features = feature_extractor._get_num_audio_features( - [audio_length])[0] + [audio_length] + )[0] return [audio_token_id] * num_projector_features return [ @@ -170,14 +183,15 @@ def _call_hf_processor( # This is used to split the batch back out after padding. 
audio_token_index = self.info.get_hf_config().audio_token_index processed_outputs["audio_embed_sizes"] = ( - processed_outputs["input_ids"] == audio_token_index).sum(-1) + processed_outputs["input_ids"] == audio_token_index + ).sum(-1) return processed_outputs class GraniteSpeechDummyInputsBuilder( - BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]): - + BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo] +): def get_dummy_mm_data( self, seq_len: int, @@ -188,8 +202,7 @@ def get_dummy_mm_data( audio_overrides = mm_options.get("audio") if mm_options else None return { - "audio": - self._get_dummy_audios( + "audio": self._get_dummy_audios( length=self.info.get_max_audio_len(), num_audios=num_audios, overrides=audio_overrides, @@ -205,7 +218,6 @@ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: ### QFormer Projector class GraniteSpeechEncoderProjector(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -220,8 +232,8 @@ def __init__( self.num_queries = config.window_size // config.downsample_rate self.query = nn.Parameter( - torch.zeros(1, self.num_queries, - config.projector_config.hidden_size)) + torch.zeros(1, self.num_queries, config.projector_config.hidden_size) + ) # NOTE - this is implemented generically in transformers, # but for now we create the QFormer model directly since @@ -232,17 +244,16 @@ def __init__( cache_config=cache_config, prefix=f"{prefix}.qformer", ) - self.linear = nn.Linear(config.projector_config.hidden_size, - config.text_config.hidden_size) + self.linear = nn.Linear( + config.projector_config.hidden_size, config.text_config.hidden_size + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, seq_len, dim = hidden_states.size() nblocks = math.ceil(seq_len / self.window_size) pad = nblocks * self.window_size - seq_len - hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), - "constant", 0) - hidden_states = hidden_states.view(batch_size * nblocks, - self.window_size, dim) + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), "constant", 0) + hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim) last_hidden_state = self.qformer( query_embeds=self.query.data, @@ -254,7 +265,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, nblocks * self.window_size // self.downsample_rate, -1, - )) + ) + ) return query_proj @@ -264,10 +276,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GraniteSpeechConformerFeedForward(nn.Module): """Feedforward module for conformer encoder blocks.""" - def __init__(self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.pre_norm = nn.LayerNorm(config.hidden_dim) @@ -313,16 +327,16 @@ def __init__(self, config: PretrainedConfig, prefix: str = ""): self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, config.hidden_dim) - self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, - self.dim_head) + self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head) if self.context_size <= 0 or self.context_size > self.max_pos_emb: raise ValueError( "Context size is either less than 0 or exceeds the max_pos_emb" ) - def forward(self, hidden_states: 
torch.Tensor, - attention_dists: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, attention_dists: torch.Tensor + ) -> torch.Tensor: hidden_states = self.pre_norm(hidden_states) bsz, num_features, _ = hidden_states.shape @@ -331,47 +345,53 @@ def forward(self, hidden_states: torch.Tensor, if remainder > 0: # right padding to reach block size hidden_states = torch.nn.functional.pad( - hidden_states, (0, 0, 0, self.context_size - remainder)) + hidden_states, (0, 0, 0, self.context_size - remainder) + ) # NOTE: would be nice to try to use qkvparallellinear # here for this block attention implementation if possible query_states = self.to_q(hidden_states) key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1) - query_states = query_states.reshape(bsz, num_blocks, self.context_size, - self.num_heads, - -1).transpose(2, 3) - key_states = key_states.reshape(bsz, num_blocks, self.context_size, - self.num_heads, -1).transpose(2, 3) - value_states = value_states.reshape(bsz, num_blocks, self.context_size, - self.num_heads, - -1).transpose(2, 3) + query_states = query_states.reshape( + bsz, num_blocks, self.context_size, self.num_heads, -1 + ).transpose(2, 3) + key_states = key_states.reshape( + bsz, num_blocks, self.context_size, self.num_heads, -1 + ).transpose(2, 3) + value_states = value_states.reshape( + bsz, num_blocks, self.context_size, self.num_heads, -1 + ).transpose(2, 3) # shaw's relative positional embedding dist = attention_dists.to(hidden_states.device) rel_pos_emb = self.rel_pos_emb(dist) - rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + - list(rel_pos_emb.shape)) - pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, - dim=-1) * self.scale + rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape)) + pos_attn = ( + torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) + * self.scale + ) if remainder > 0: # masked attention in the extended block - mask = torch.ones(self.context_size, - self.context_size, - dtype=bool, - device=hidden_states.device) + mask = torch.ones( + self.context_size, + self.context_size, + dtype=bool, + device=hidden_states.device, + ) mask[:remainder, :remainder] = 0 mask_value = -torch.finfo(pos_attn.dtype).max pos_attn[:, -1, :].masked_fill_(mask, mask_value) - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.MATH): - out = F.scaled_dot_product_attention(query_states, - key_states, - value_states, - attn_mask=pos_attn, - scale=self.scale) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + out = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=pos_attn, + scale=self.scale, + ) out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1) return self.to_out(out[:, :num_features, :]) @@ -379,22 +399,16 @@ def forward(self, hidden_states: torch.Tensor, class GraniteSpeechConformerDepthWiseConv1d(nn.Module): """Wrapper for padded 1D pointwise convolution.""" - def __init__(self, - chan_in: int, - chan_out: int, - kernel_size: int, - prefix: str = ""): + def __init__(self, chan_in: int, chan_out: int, kernel_size: int, prefix: str = ""): super().__init__() # Padding for the 1D conv is symmetric or close (i.e., offset by one). 
pad = kernel_size // 2 pad_offset = (kernel_size + 1) % 2 self.padding = (pad, pad - pad_offset) - self.conv = nn.Conv1d(chan_in, - chan_out, - kernel_size, - groups=chan_in, - bias=False) + self.conv = nn.Conv1d( + chan_in, chan_out, kernel_size, groups=chan_in, bias=False + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = F.pad(hidden_states, self.padding) @@ -439,21 +453,19 @@ class GraniteSpeechConformerBlock(nn.Module): def __init__(self, config: PretrainedConfig, prefix: str = ""): super().__init__() - self.ff1 = GraniteSpeechConformerFeedForward(config, - prefix=f"{prefix}.ff1") - self.attn = GraniteSpeechConformerAttention(config, - prefix=f"{prefix}.attn") - self.conv = GraniteSpeechConformerConvModule(config, - prefix=f"{prefix}.conv") - self.ff2 = GraniteSpeechConformerFeedForward(config, - prefix=f"{prefix}.ff2") + self.ff1 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff1") + self.attn = GraniteSpeechConformerAttention(config, prefix=f"{prefix}.attn") + self.conv = GraniteSpeechConformerConvModule(config, prefix=f"{prefix}.conv") + self.ff2 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff2") self.post_norm = nn.LayerNorm(config.hidden_dim) - def forward(self, hidden_states: torch.Tensor, - attention_dists: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, attention_dists: torch.Tensor + ) -> torch.Tensor: hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states - hidden_states = self.attn( - hidden_states, attention_dists=attention_dists) + hidden_states + hidden_states = ( + self.attn(hidden_states, attention_dists=attention_dists) + hidden_states + ) hidden_states = self.conv(hidden_states) + hidden_states hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states hidden_states = self.post_norm(hidden_states) @@ -463,29 +475,33 @@ def forward(self, hidden_states: torch.Tensor, class GraniteSpeechCTCEncoder(nn.Module): """CTC Encoder comprising conformer blocks and additional linear layers.""" - def __init__(self, - config: PretrainedConfig, - prefix: str, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, + config: PretrainedConfig, + prefix: str, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config # Precompute clamped relative positional encoding distances seq = torch.arange(config.context_size) relpos_dist = seq.view(-1, 1) - seq.view(1, -1) - self.attention_dists = torch.clamp( - relpos_dist, -config.context_size, - config.context_size) + config.max_pos_emb - - self.input_linear = nn.Linear(config.input_dim, - config.hidden_dim, - bias=True) - self.layers = nn.ModuleList([ - GraniteSpeechConformerBlock( - config, - prefix=f"{prefix}.layers.{idx}", - ) for idx in range(config.num_layers) - ]) + self.attention_dists = ( + torch.clamp(relpos_dist, -config.context_size, config.context_size) + + config.max_pos_emb + ) + + self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True) + self.layers = nn.ModuleList( + [ + GraniteSpeechConformerBlock( + config, + prefix=f"{prefix}.layers.{idx}", + ) + for idx in range(config.num_layers) + ] + ) self.out = ColumnParallelLinear( input_size=config.hidden_dim, @@ -508,8 +524,7 @@ def __init__(self, def forward(self, hidden_states: torch.Tensor): hidden_states = self.input_linear(hidden_states) for idx, layer in enumerate(self.layers, start=1): - hidden_states = layer(hidden_states, - attention_dists=self.attention_dists) + hidden_states = 
layer(hidden_states, attention_dists=self.attention_dists) if idx == self.num_layers // 2: hidden_states_mid = hidden_states.clone() @@ -523,12 +538,13 @@ def forward(self, hidden_states: torch.Tensor): @MULTIMODAL_REGISTRY.register_processor( GraniteSpeechMultiModalProcessor, info=GraniteSpeechMultiModalProcessingInfo, - dummy_inputs=GraniteSpeechDummyInputsBuilder) + dummy_inputs=GraniteSpeechDummyInputsBuilder, +) class GraniteSpeechForConditionalGeneration( - nn.Module, - SupportsMultiModal, - SupportsPP, - SupportsLoRA, + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsLoRA, ): merge_by_field_config = True @@ -584,7 +600,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_audio_input( self, @@ -602,17 +619,21 @@ def _parse_and_validate_audio_input( # from the processor, but we handle rebuilding it here since # vLLM generally processes everything independently + batches. if input_features_mask is None: - input_features_mask = self._build_input_features_mask( - audio_embed_sizes) + input_features_mask = self._build_input_features_mask(audio_embed_sizes) if not isinstance(input_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. " - f"Got type: {type(input_features)}") + raise ValueError( + "Incorrect type of audio input features. " + f"Got type: {type(input_features)}" + ) if input_features_mask is not None and not isinstance( - input_features_mask, torch.Tensor): - raise ValueError("Incorrect type of audio input features mask. " - f"Got type: {type(input_features_mask)}") + input_features_mask, torch.Tensor + ): + raise ValueError( + "Incorrect type of audio input features mask. " + f"Got type: {type(input_features_mask)}" + ) if isinstance(input_features, torch.Tensor): # Granite speech currently only allows one audio token per instance @@ -625,16 +646,17 @@ def _parse_and_validate_audio_input( if len(input_features.shape) != 3: raise ValueError( "Squeezed input features should be 3D but are of shape " - f"{input_features.shape}") - input_features = input_features.to( - self.encoder.input_linear.weight.dtype) + f"{input_features.shape}" + ) + input_features = input_features.to(self.encoder.input_linear.weight.dtype) else: # Otherwise we have a list of tensors, which are almost certainly # differing in their respective numbers of audio features; # stack them into a 3D tensor of size [bsz, most_num_features, 160]. input_features = self._pad_and_stack_input_features( - input_features, ).to(self.encoder.input_linear.weight.dtype) + input_features, + ).to(self.encoder.input_linear.weight.dtype) return GraniteSpeechAudioInputs( input_features=input_features, @@ -706,7 +728,7 @@ def _process_audio_input( audio_input: GraniteSpeechAudioInputs, ) -> tuple[torch.Tensor]: """Compute the audio features to be merged into the LLM embeddings. 
- + Args: audio_input: GraniteSpeechAudioInputs Audio inputs object containing Mel features, an input features @@ -769,8 +791,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - model_output = self.language_model(input_ids, positions, - intermediate_tensors, inputs_embeds) + model_output = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 489c0bb3d3af..4711ed05c587 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GraniteMoe model.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional @@ -33,27 +34,35 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_layers, - maybe_prefix) +from .utils import AutoWeightsLoader, is_pp_missing_parameter, make_layers, maybe_prefix class GraniteMoeMoE(nn.Module): @@ -64,39 +73,45 @@ class GraniteMoeMoE(nn.Module): across ranks. """ - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - is_sequence_parallel=False, - prefix: str = ""): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + is_sequence_parallel=False, + prefix: str = "", + ): super().__init__() self.hidden_size = hidden_size self.is_sequence_parallel = is_sequence_parallel # Gate always runs at half / full precision for now. 
- self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - params_dtype=params_dtype, - quant_config=None, - prefix=f"{prefix}.gate") - - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=tp_size, - prefix=f"{prefix}.experts", - is_sequence_parallel=self.is_sequence_parallel) + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts", + is_sequence_parallel=self.is_sequence_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -112,7 +127,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( - final_hidden_states, 0) + final_hidden_states, 0 + ) num_tokens = orig_shape[0] final_hidden_states = final_hidden_states[:num_tokens] @@ -120,7 +136,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GraniteMoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -153,8 +168,11 @@ def __init__( self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = (attention_multiplier if attention_multiplier - is not None else self.head_dim**-1) + self.scaling = ( + attention_multiplier + if attention_multiplier is not None + else self.head_dim**-1 + ) self.rope_theta = rope_theta self.qkv_proj = QKVParallelLinear( @@ -181,13 +199,15 @@ def __init__( is_neox_style=True, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -203,7 +223,6 @@ def forward( class GraniteMoeDecoderLayer(nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -230,7 +249,8 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - attention_multiplier=config.attention_multiplier) + attention_multiplier=config.attention_multiplier, + ) self.block_sparse_moe = GraniteMoeMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, @@ -238,12 +258,13 @@ def __init__( intermediate_size=config.intermediate_size, quant_config=quant_config, is_sequence_parallel=parallel_config.use_sequence_parallel_moe, - prefix=f"{prefix}.block_sparse_moe") + prefix=f"{prefix}.block_sparse_moe", + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, 
eps=config.rms_norm_eps + ) self.residual_multiplier = config.residual_multiplier @@ -270,7 +291,6 @@ def forward( @support_torch_compile class GraniteMoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -280,8 +300,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config # Required by MixtralModel - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -295,7 +318,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -321,17 +345,18 @@ def forward( for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.norm(hidden_states) return hidden_states - def _load_weights(self, - weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """ - This function is copied from `MixtralModel.load_weights`, mainly to - decouple from mixtral, avoiding impact on support like BNB + This function is copied from `MixtralModel.load_weights`, mainly to + decouple from mixtral, avoiding impact on support like BNB quantization. """ stacked_params_mapping = [ @@ -347,30 +372,33 @@ def _load_weights(self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) + num_experts=self.config.num_local_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): @@ -393,21 +421,25 @@ def _load_weights(self, # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -418,40 +450,45 @@ def _load_weights(self, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: new_weights = {} for n, p in weights: - if n.endswith('.block_sparse_moe.input_linear.weight'): + if n.endswith(".block_sparse_moe.input_linear.weight"): for e in range(p.size(0)): w1_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w1.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w1.weight", + ) w3_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w3.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w3.weight", + ) w1_param, w3_param = p[e].chunk(2, dim=0) assert w1_name not in new_weights assert w3_name not in new_weights new_weights[w1_name] = w1_param new_weights[w3_name] = w3_param - elif n.endswith('.block_sparse_moe.output_linear.weight'): + elif n.endswith(".block_sparse_moe.output_linear.weight"): for e in range(p.size(0)): w2_name = n.replace( - '.block_sparse_moe.output_linear.weight', - f".block_sparse_moe.experts.{e}.w2.weight") + ".block_sparse_moe.output_linear.weight", + f".block_sparse_moe.experts.{e}.w2.weight", + ) w2_param = p[e] assert w2_name not in new_weights new_weights[w2_name] = w2_param - elif n.endswith('.block_sparse_moe.router.layer.weight'): - gate_name = n.replace('.block_sparse_moe.router.layer.weight', - ".block_sparse_moe.gate.weight") + elif n.endswith(".block_sparse_moe.router.layer.weight"): + gate_name = n.replace( + ".block_sparse_moe.router.layer.weight", + ".block_sparse_moe.gate.weight", + ) assert gate_name not in new_weights new_weights[gate_name] = p else: @@ -486,8 +523,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.lora_config = lora_config - self.model = GraniteMoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = GraniteMoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -498,17 +536,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=1 / - self.config.logits_scaling) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + scale=1 / self.config.logits_scaling, + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -520,30 +560,29 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def compute_logits(self, - hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index f5751fe47bb8..f877dc576427 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only GraniteMoeHybrid model.""" + # Added by the IBM Team, 2025 from collections.abc import Iterable from typing import Optional @@ -15,58 +16,67 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization 
import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .granitemoe import GraniteMoeMoE from .granitemoeshared import GraniteMoeSharedMLP -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class GraniteMoeHybridMambaDecoderLayer(nn.Module): - - def __init__(self, - config: GraniteMoeHybridConfig, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: GraniteMoeHybridConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size self.residual_multiplier = config.residual_multiplier - self.mamba = MambaMixer2(hidden_size= config.hidden_size, - ssm_state_size = config.mamba_d_state, - conv_kernel_size = config.mamba_d_conv, - intermediate_size = config.mamba_expand *\ - config.hidden_size, - use_conv_bias = config.mamba_conv_bias, - use_bias = config.mamba_proj_bias, - n_groups=config.mamba_n_groups, - num_heads=config.mamba_n_heads, - head_dim=config.mamba_d_head, - rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.mixer") + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=config.mamba_expand * config.hidden_size, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) self.block_sparse_moe = None if getattr(config, "num_local_experts", 0) > 0: @@ -76,20 +86,21 @@ def __init__(self, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") + prefix=f"{prefix}.block_sparse_moe", + ) - self.shared_mlp = None if \ - getattr(config, 'shared_intermediate_size', 0) == 0 \ + self.shared_mlp = ( + None + if getattr(config, "shared_intermediate_size", 0) == 0 else GraniteMoeSharedMLP( - config, - quant_config=quant_config, - prefix=f"{prefix}.shared_mlp" + config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp" ) + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - 
self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -114,8 +125,7 @@ def forward( if self.block_sparse_moe is not None: moe_hidden_states = hidden_states.clone() moe_hidden_states = self.block_sparse_moe(moe_hidden_states) - hidden_states = moe_hidden_states + self.shared_mlp( - hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) del moe_hidden_states else: hidden_states = self.shared_mlp(hidden_states) @@ -125,7 +135,6 @@ def forward( class GraniteMoeHybridAttentionDecoderLayer(nn.Module): - def __init__( self, config: GraniteMoeHybridConfig, @@ -143,7 +152,8 @@ def __init__( config, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.self_attn") + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = None if getattr(config, "num_local_experts", 0) > 0: @@ -153,20 +163,21 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") + prefix=f"{prefix}.block_sparse_moe", + ) - self.shared_mlp = None if \ - getattr(config, 'shared_intermediate_size', 0) == 0 \ + self.shared_mlp = ( + None + if getattr(config, "shared_intermediate_size", 0) == 0 else GraniteMoeSharedMLP( - config, - quant_config=quant_config, - prefix=f"{prefix}.shared_mlp" + config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp" ) + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -194,8 +205,7 @@ def forward( if self.block_sparse_moe is not None: moe_hidden_states = hidden_states.clone() moe_hidden_states = self.block_sparse_moe(moe_hidden_states) - hidden_states = moe_hidden_states + self.shared_mlp( - hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) del moe_hidden_states else: hidden_states = self.shared_mlp(hidden_states) @@ -205,7 +215,6 @@ def forward( class GraniteMoeHybridAttention(nn.Module): - def __init__( self, config: GraniteMoeHybridConfig, @@ -237,19 +246,23 @@ def __init__( assert tp_size % self.total_num_kv_heads == 0 self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size) - self.qkv_proj = QKVParallelLinear(self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.hidden_size, - self.hidden_size, - bias=self.attention_bias, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=self.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) if config.position_embedding_type == "rope": self.rotary_emb = get_rope( @@ -257,34 +270,38 @@ def __init__( rotary_dim=self.head_dim, 
max_position=config.max_position_embeddings, base=int(config.rope_theta), - rope_scaling=config.rope_scaling \ - if hasattr(config, "rope_scaling") \ - and config.rope_scaling is not None else None, + rope_scaling=config.rope_scaling + if hasattr(config, "rope_scaling") and config.rope_scaling is not None + else None, is_neox_style=True, ) else: self.rotary_emb = None - self.attn = Attention(self.num_heads, - self.head_dim, - self.attention_multiplier, - num_kv_heads=self.num_key_value_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.attention_multiplier, + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - query, key, value = qkv.split([ - self.num_heads * self.head_dim, self.num_key_value_heads * - self.head_dim, self.num_key_value_heads * self.head_dim - ], - dim=-1) + query, key, value = qkv.split( + [ + self.num_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ], + dim=-1, + ) if self.rotary_emb is not None: query, key = self.rotary_emb(positions, query, key) @@ -304,7 +321,6 @@ def forward( @support_torch_compile class GraniteMoeHybridModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -315,8 +331,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -329,8 +348,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) - layer_class = ALL_DECODER_LAYER_TYPES[ - config.layer_types[layer_idx]] + layer_class = ALL_DECODER_LAYER_TYPES[config.layer_types[layer_idx]] return layer_class( config, layer_idx, @@ -341,10 +359,11 @@ def get_layer(prefix: str): ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -358,7 +377,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -368,7 +386,7 @@ def forward( residual = None else: if intermediate_tensors is None: - raise RuntimeError('Intermediate tensors may not be None!') + raise RuntimeError("Intermediate tensors may not be None!") hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] @@ -376,21 +394,19 @@ def forward( for i, layer in 
enumerate(self.layers): if isinstance(layer, GraniteMoeHybridAttentionDecoderLayer): num_attn += 1 - hidden_states, residual = layer(positions=positions, - hidden_states=hidden_states, - residual=residual) + hidden_states, residual = layer( + positions=positions, hidden_states=hidden_states, residual=residual + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -402,8 +418,7 @@ def load_weights(self, weights: Iterable[tuple[str, def _load(n, p): param = params_dict[n] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, p) loaded_params.add(n) @@ -411,20 +426,14 @@ def _load_shard(n, p, shard_id): # Skip layers on other devices. if not is_pp_missing_parameter(n, self): param = params_dict[n] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, p, shard_id) loaded_params.add(n) def _load_expert(n, p, name, shard_id, expert_id): param = params_dict[n] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, - p, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id) loaded_params.add(n) for n, p in weights: @@ -437,49 +446,62 @@ def _load_expert(n, p, name, shard_id, expert_id): # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate) # The renaming and parameter loading logic is the same for weight # and weight_scale tensors so we can reuse them without issues. 
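For reference, the renaming that the following branches perform can be written as a standalone sketch over plain tensors (not part of the patch; the module names and shapes below are illustrative, and the real loader routes each tensor through FusedMoE's weight_loader with shard ids instead of building a dict):

import torch


def remap_granite_moe_expert_weights(
    name: str, tensor: torch.Tensor
) -> dict[str, torch.Tensor]:
    # Illustrative sketch of the checkpoint -> vLLM renaming above.
    remapped: dict[str, torch.Tensor] = {}
    if name.endswith(".block_sparse_moe.input_linear.weight"):
        # input_linear fuses w1 (gate) and w3 (up) for every expert.
        for e in range(tensor.size(0)):
            w1, w3 = tensor[e].chunk(2, dim=0)
            base = name.replace(".input_linear.weight", f".experts.{e}")
            remapped[f"{base}.w1.weight"] = w1
            remapped[f"{base}.w3.weight"] = w3
    elif name.endswith(".block_sparse_moe.output_linear.weight"):
        # output_linear holds the w2 (down) projection for every expert.
        for e in range(tensor.size(0)):
            base = name.replace(".output_linear.weight", f".experts.{e}")
            remapped[f"{base}.w2.weight"] = tensor[e]
    elif name.endswith(".block_sparse_moe.router.layer.weight"):
        remapped[name.replace(".router.layer.weight", ".gate.weight")] = tensor
    else:
        remapped[name] = tensor
    return remapped


# Toy usage: 4 experts, hidden_size=8, intermediate_size=16 (arbitrary numbers).
fused = torch.randn(4, 2 * 16, 8)
out = remap_granite_moe_expert_weights(
    "model.layers.0.block_sparse_moe.input_linear.weight", fused
)
assert out["model.layers.0.block_sparse_moe.experts.0.w1.weight"].shape == (16, 8)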
- if (n.endswith('.block_sparse_moe.input_linear.weight') or - n.endswith('.block_sparse_moe.input_linear.weight_scale')): + if n.endswith(".block_sparse_moe.input_linear.weight") or n.endswith( + ".block_sparse_moe.input_linear.weight_scale" + ): for e in range(p.size(0)): w1_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w1.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w1.weight", + ) w3_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w3.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w3.weight", + ) w1_param, w3_param = p[e].chunk(2, dim=0) - _load_expert(n.replace('.input_linear.', '.experts.w13_'), - w1_param, - w1_name, - shard_id='w1', - expert_id=e) - _load_expert(n.replace('.input_linear.', '.experts.w13_'), - w3_param, - w3_name, - shard_id='w3', - expert_id=e) - elif (n.endswith('.block_sparse_moe.output_linear.weight') or - n.endswith('.block_sparse_moe.output_linear.weight_scale')): + _load_expert( + n.replace(".input_linear.", ".experts.w13_"), + w1_param, + w1_name, + shard_id="w1", + expert_id=e, + ) + _load_expert( + n.replace(".input_linear.", ".experts.w13_"), + w3_param, + w3_name, + shard_id="w3", + expert_id=e, + ) + elif n.endswith(".block_sparse_moe.output_linear.weight") or n.endswith( + ".block_sparse_moe.output_linear.weight_scale" + ): for e in range(p.size(0)): w2_name = n.replace( - '.block_sparse_moe.output_linear.weight', - f".block_sparse_moe.experts.{e}.w2.weight") + ".block_sparse_moe.output_linear.weight", + f".block_sparse_moe.experts.{e}.w2.weight", + ) w2_param = p[e] - _load_expert(n.replace('.output_linear.', '.experts.w2_'), - w2_param, - w2_name, - shard_id='w2', - expert_id=e) - elif n.endswith('.block_sparse_moe.router.layer.weight'): - gate_name = n.replace('.block_sparse_moe.router.layer.weight', - ".block_sparse_moe.gate.weight") + _load_expert( + n.replace(".output_linear.", ".experts.w2_"), + w2_param, + w2_name, + shard_id="w2", + expert_id=e, + ) + elif n.endswith(".block_sparse_moe.router.layer.weight"): + gate_name = n.replace( + ".block_sparse_moe.router.layer.weight", + ".block_sparse_moe.gate.weight", + ) _load(gate_name, p) else: loaded = False for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name in n: - _load_shard(n.replace(weight_name, param_name), - p, - shard_id=shard_id) + _load_shard( + n.replace(weight_name, param_name), p, shard_id=shard_id + ) loaded = True if not loaded: _load(n, p) @@ -487,8 +509,9 @@ def _load_expert(n, p, name, shard_id, expert_id): return loaded_params -class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, - SupportsPP, IsHybrid, SupportsQuant): +class GraniteMoeHybridForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant +): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -507,7 +530,6 @@ def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -549,19 +571,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = 
vllm_config.scheduler_config - if cache_config.enable_prefix_caching: - raise RuntimeError( - "GraniteMoeHybrid currently does not support prefix caching") - self.quant_config = vllm_config.quant_config self.config = config self.scheduler_config = scheduler_config - self.model = GraniteMoeHybridModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = GraniteMoeHybridModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -573,31 +590,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=self.quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + prefix=maybe_prefix(prefix, "lm_head"), + ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=1 / - self.config.logits_scaling) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + scale=1 / self.config.logits_scaling, + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states @@ -608,7 +631,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index c864856db654..93302821ca68 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -5,6 +5,7 @@ The architecture is the same as granitemoe but with the addition of shared experts. 
""" + from collections.abc import Iterable from itertools import islice from typing import Optional @@ -18,12 +19,17 @@ from vllm.distributed import get_pp_group from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.sequence import IntermediateTensors from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE @@ -32,7 +38,6 @@ class GraniteMoeSharedMLP(nn.Module): - def __init__( self, config: GraniteMoeSharedConfig, @@ -48,16 +53,20 @@ def __init__( output_sizes=[self.hidden_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.input_linear") + prefix=f"{prefix}.input_linear", + ) self.output_linear = RowParallelLinear( self.hidden_size, self.input_size, bias=False, quant_config=quant_config, - prefix=f"{prefix}.output_linear") + prefix=f"{prefix}.output_linear", + ) if config.hidden_act != "silu": - raise ValueError(f"Unsupported activation: {config.hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -68,7 +77,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GraniteMoeSharedDecoderLayer(nn.Module): - def __init__( self, config: GraniteMoeSharedConfig, @@ -91,26 +99,28 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - attention_multiplier=config.attention_multiplier) + attention_multiplier=config.attention_multiplier, + ) self.block_sparse_moe = GraniteMoeMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, quant_config=quant_config, - prefix=f"{prefix}.block_sparse_moe") - self.shared_mlp = None if \ - getattr(config, 'shared_intermediate_size', 0) == 0 \ + prefix=f"{prefix}.block_sparse_moe", + ) + self.shared_mlp = ( + None + if getattr(config, "shared_intermediate_size", 0) == 0 else GraniteMoeSharedMLP( - config, - quant_config=quant_config, - prefix=f"{prefix}.shared_mlp" + config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp" ) + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.residual_multiplier = config.residual_multiplier @@ -144,7 +154,6 @@ def forward( @support_torch_compile class GraniteMoeSharedModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -156,8 +165,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config # 
Required by MixtralModel self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -174,7 +186,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: GraniteMoeSharedDecoderLayer( config, cache_config, quant_config=quant_config, prefix=prefix ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -200,40 +213,46 @@ def forward( for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: new_weights = {} for n, p in weights: - if n.endswith('.block_sparse_moe.input_linear.weight'): + if n.endswith(".block_sparse_moe.input_linear.weight"): for e in range(p.size(0)): w1_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w1.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w1.weight", + ) w3_name = n.replace( - '.block_sparse_moe.input_linear.weight', - f".block_sparse_moe.experts.{e}.w3.weight") + ".block_sparse_moe.input_linear.weight", + f".block_sparse_moe.experts.{e}.w3.weight", + ) w1_param, w3_param = p[e].chunk(2, dim=0) assert w1_name not in new_weights assert w3_name not in new_weights new_weights[w1_name] = w1_param new_weights[w3_name] = w3_param - elif n.endswith('.block_sparse_moe.output_linear.weight'): + elif n.endswith(".block_sparse_moe.output_linear.weight"): for e in range(p.size(0)): w2_name = n.replace( - '.block_sparse_moe.output_linear.weight', - f".block_sparse_moe.experts.{e}.w2.weight") + ".block_sparse_moe.output_linear.weight", + f".block_sparse_moe.experts.{e}.w2.weight", + ) w2_param = p[e] assert w2_name not in new_weights new_weights[w2_name] = w2_param - elif n.endswith('.block_sparse_moe.router.layer.weight'): - gate_name = n.replace('.block_sparse_moe.router.layer.weight', - ".block_sparse_moe.gate.weight") + elif n.endswith(".block_sparse_moe.router.layer.weight"): + gate_name = n.replace( + ".block_sparse_moe.router.layer.weight", + ".block_sparse_moe.gate.weight", + ) assert gate_name not in new_weights new_weights[gate_name] = p else: @@ -268,9 +287,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.lora_config = lora_config - self.model = GraniteMoeSharedModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = GraniteMoeSharedModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -281,16 +300,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else 
lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + prefix=maybe_prefix(prefix, "lm_head"), + ) if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - scale=1 / - self.config.logits_scaling) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, + config.vocab_size, + scale=1 / self.config.logits_scaling, + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -302,30 +324,29 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states - def compute_logits(self, - hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 639d8f620f94..ac78dd9e753a 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -9,11 +9,15 @@ from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, - PoolerHead, PoolerNormalize, - PoolingParamsUpdate, - get_prompt_lens, - get_prompt_token_ids) +from vllm.model_executor.layers.pooler import ( + DispatchPooler, + Pooler, + PoolerHead, + PoolerNormalize, + PoolingParamsUpdate, + get_prompt_lens, + get_prompt_token_ids, +) from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config @@ -47,12 +51,11 @@ def __init__(self, model_config: ModelConfig): def tokens_to_ids(tokens: list[str]) -> np.ndarray: return np.array([self.token_ids[token] for token in tokens]) - self.user_pattern_ids = tokens_to_ids( - ["▁<", "|", "user", "|", ">", "<0x0A>"]) + self.user_pattern_ids = tokens_to_ids(["▁<", "|", "user", "|", ">", "<0x0A>"]) self.embed_newline_pattern_ids = tokens_to_ids( - ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"]) - self.embed_pattern_ids = tokens_to_ids( - ["▁<", "|", 
"embed", "|", ">", "<0x0A>"]) + ["<0x0A>", "<", "|", "embed", "|", ">", "<0x0A>"] + ) + self.embed_pattern_ids = tokens_to_ids(["▁<", "|", "embed", "|", ">", "<0x0A>"]) def _find_array( self, @@ -86,7 +89,7 @@ def _find_array( end_idx = arr_len for i in range(start_idx, min(end_idx, arr_len - target_len + 1)): - if (arr[i:i + target_len] == target).all(): + if (arr[i : i + target_len] == target).all(): return i return -1 @@ -105,31 +108,37 @@ def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int: # Return no instruction in case of missing BOS token. if prompt_token_ids[0] != self.token_ids["<s>"]: - logger.warning("BOS token not found in prompt, " - "thus using empty string for instruction. " - "GritLM requires BOS token in prompt.") + logger.warning( + "BOS token not found in prompt, " + "thus using empty string for instruction. " + "GritLM requires BOS token in prompt." + ) return instruction_len # If user pattern is found in the prompt, that means there should be # a newline token before the embed pattern. embed_pattern_ids = self.embed_pattern_ids - if self._find_array(prompt_token_ids, - self.user_pattern_ids, - start_idx=1, - end_idx=2) == 1: + if ( + self._find_array( + prompt_token_ids, self.user_pattern_ids, start_idx=1, end_idx=2 + ) + == 1 + ): embed_pattern_ids = self.embed_newline_pattern_ids # Find the embed pattern in the prompt. - found_embed_pattern_idx = self._find_array(prompt_token_ids, - embed_pattern_ids, - start_idx=1) + found_embed_pattern_idx = self._find_array( + prompt_token_ids, embed_pattern_ids, start_idx=1 + ) if found_embed_pattern_idx != -1: instruction_len = found_embed_pattern_idx + len(embed_pattern_ids) else: - logger.warning("Query instruction not found in prompt, " - "thus using BOS token as instruction instead. " - "GritLM requires query instruction in prompt.") + logger.warning( + "Query instruction not found in prompt, " + "thus using BOS token as instruction instead. " + "GritLM requires query instruction in prompt." 
+ ) instruction_len = 1 return instruction_len @@ -146,8 +155,9 @@ def forward_one( prompt_len: Optional[torch.Tensor] = None, instr_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + assert prompt_len is None or prompt_len == hidden_states.shape[0], ( "partial prefill not supported with MEAN pooling" + ) return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32) @@ -161,9 +171,11 @@ def forward_all( pooled_data = list[torch.Tensor]() for prompt_len, instr_len in zip(prompt_lens, instr_lens): - pooled_data.append(hidden_states[offset + instr_len:offset + - prompt_len].mean( - dim=0, dtype=torch.float32)) + pooled_data.append( + hidden_states[offset + instr_len : offset + prompt_len].mean( + dim=0, dtype=torch.float32 + ) + ) offset += prompt_len return pooled_data @@ -184,15 +196,16 @@ def forward( if isinstance(hidden_states, list): return [ - self.forward_one(h, prompt_len, instr_len) for h, prompt_len, - instr_len in zip(hidden_states, prompt_lens, instr_lens) + self.forward_one(h, prompt_len, instr_len) + for h, prompt_len, instr_len in zip( + hidden_states, prompt_lens, instr_lens + ) ] return self.forward_all(hidden_states, prompt_lens, instr_lens) class GritLMPooler(Pooler): - def __init__(self, model_config: ModelConfig): super().__init__() @@ -254,9 +267,9 @@ def __init__( pooler_config = vllm_config.model_config.pooler_config if pooler_config is not None: - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "embed": - GritLMPooler(vllm_config.model_config), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": GritLMPooler(vllm_config.model_config), + } + ) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 996e41fe84ff..f4139685b79f 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
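The GritLM pooler hunks just above only reflow the instruction-masked mean pooling. As a standalone reference (not part of the patch, with toy tensors standing in for model hidden states), the same pooling over a packed batch looks like this:

import torch


def instruction_masked_mean_pool(
    hidden_states: torch.Tensor,  # (sum(prompt_lens), hidden_size), packed prompts
    prompt_lens: list[int],
    instr_lens: list[int],
) -> list[torch.Tensor]:
    # Illustrative sketch of GritLMMeanPool.forward_all above.
    pooled: list[torch.Tensor] = []
    offset = 0
    for prompt_len, instr_len in zip(prompt_lens, instr_lens):
        # Average only the tokens after each prompt's instruction prefix.
        span = hidden_states[offset + instr_len : offset + prompt_len]
        pooled.append(span.mean(dim=0, dtype=torch.float32))
        offset += prompt_len
    return pooled


hidden = torch.randn(7 + 5, 16)  # two prompts packed back to back
vecs = instruction_masked_mean_pool(hidden, prompt_lens=[7, 5], instr_lens=[3, 1])
assert vecs[0].shape == (16,) and vecs[1].shape == (16,)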
"""Inference-only Grok1 model.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -36,22 +37,33 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) # Default Grok1-specific constants, overridden by config values if present DEFAULT_ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845 @@ -68,37 +80,43 @@ class Grok1MoE(nn.Module): across ranks. """ - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = ""): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = hidden_size # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - params_dtype=params_dtype, - quant_config=None, - prefix=f"{prefix}.gate") - - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=tp_size, - activation="gelu", - prefix=f"{prefix}.experts") + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + activation="gelu", + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
@@ -112,18 +130,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Grok1Attention(nn.Module): - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - config=None, # Added config parameter + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + config=None, # Added config parameter ) -> None: super().__init__() self.hidden_size = hidden_size @@ -172,19 +189,21 @@ def __init__( is_neox_style=True, ) - attn_logits_soft_cap = max( - getattr(config, "attn_logit_softcapping", 30.0), 0.0) + attn_logits_soft_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=attn_logits_soft_cap, - prefix=f"{prefix}.attn") - self.attn_multiplier = getattr(self.config, "attn_output_multiplier", - 1.0) if self.config else 1.0 + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + prefix=f"{prefix}.attn", + ) + self.attn_multiplier = ( + getattr(self.config, "attn_output_multiplier", 1.0) if self.config else 1.0 + ) def forward( self, @@ -201,7 +220,6 @@ def forward( class Grok1DecoderLayer(nn.Module): - def __init__( self, config, @@ -214,8 +232,7 @@ def __init__( # Check for fp8 quantization self.use_fp8 = False if quant_config is not None: - self.use_fp8 = getattr(quant_config, "is_fp8_w8a8", - lambda: False)() + self.use_fp8 = getattr(quant_config, "is_fp8_w8a8", lambda: False)() if not self.use_fp8 and hasattr(quant_config, "is_fp8"): self.use_fp8 = quant_config.is_fp8 @@ -231,27 +248,26 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", - config=config) # Pass config to Grok1Attention + config=config, + ) # Pass config to Grok1Attention # Grok1 uses "num_experts" in its config num_experts = getattr(config, "num_experts", 8) num_experts_per_tok = getattr(config, "num_experts_per_tok", 2) - self.moe_block = Grok1MoE(num_experts=num_experts, - top_k=num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - quant_config=quant_config, - prefix=f"{prefix}.moe_block") - - self.pre_attn_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attn_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_moe_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_moe_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.moe_block = Grok1MoE( + num_experts=num_experts, + top_k=num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.moe_block", + ) + + self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_moe_norm = 
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -264,8 +280,7 @@ def forward( residual = hidden_states hidden_states = self.pre_attn_norm(hidden_states) else: - hidden_states, residual = self.pre_attn_norm( - hidden_states, residual) + hidden_states, residual = self.pre_attn_norm(hidden_states, residual) hidden_states = self.attn( positions=positions, @@ -285,7 +300,6 @@ def forward( @support_torch_compile class Grok1Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -297,13 +311,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size self.embedding_multiplier_scale = getattr( - config, "embedding_multiplier_scale", - DEFAULT_EMBEDDING_MULTIPLIER_SCALE) + config, "embedding_multiplier_scale", DEFAULT_EMBEDDING_MULTIPLIER_SCALE + ) self.embed_tokens = VocabParallelEmbedding( self.vocab_size, @@ -317,12 +334,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lambda prefix: Grok1DecoderLayer( config, cache_config, quant_config=quant_config, prefix=prefix ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -351,10 +369,9 @@ def forward( hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -367,10 +384,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_gate_proj_name="linear", # Grok1 specific ckpt_down_proj_name="linear_1", # Grok1 specific ckpt_up_proj_name="linear_v", # Grok1 specific - num_experts=num_experts) + num_experts=num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -382,25 +399,27 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if 
loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -423,21 +442,25 @@ def load_weights(self, weights: Iterable[tuple[str, # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -453,8 +476,9 @@ def load_weights(self, weights: Iterable[tuple[str, name = name.replace("scale", "weight") param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -482,8 +506,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.model = Grok1Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Grok1Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -502,13 +527,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight self.output_multiplier_scale = getattr( - config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - self.output_multiplier_scale) + config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, self.output_multiplier_scale + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -520,8 +547,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, 
positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -531,11 +559,9 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Skip lm_head when tie_word_embeddings is True - skip_prefixes = (["lm_head"] - if self.config.tie_word_embeddings else None) + skip_prefixes = ["lm_head"] if self.config.tie_word_embeddings else None loader = AutoWeightsLoader( self, diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index b42df3ad8650..d7ee0fd8fd37 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -18,20 +18,33 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargsItems, MultiModalUUIDDict -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - MultiModalDataItems) -from vllm.multimodal.processing import (MultiModalProcessingInfo, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + MultiModalProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.transformers_utils.tokenizer import AnyTokenizer from .intern_vit import InternVisionModel -from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, - BaseInternVLDummyInputsBuilder, - BaseInternVLMultiModalProcessor, - BaseInternVLProcessingInfo, BaseInternVLProcessor, - InternVLChatModel, build_transform, - find_closest_aspect_ratio, get_internvl_target_ratios) +from .internvl import ( + IMG_CONTEXT, + IMG_END, + IMG_START, + BaseInternVLDummyInputsBuilder, + BaseInternVLMultiModalProcessor, + BaseInternVLProcessingInfo, + BaseInternVLProcessor, + InternVLChatModel, + build_transform, + find_closest_aspect_ratio, + get_internvl_target_ratios, +) def resolve_h2ovl_min_max_num( @@ -61,8 +74,10 @@ def get_h2ovl_target_ratios( # if prior_aspect_ratio is provided, filter the target ratios if prior_aspect_ratio is not None: target_ratios = [ - ratio for ratio in target_ratios if prior_aspect_ratio[0] % - ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0 + ratio + for ratio in target_ratios + if prior_aspect_ratio[0] % ratio[0] != 0 + and prior_aspect_ratio[1] % ratio[1] != 0 ] return target_ratios @@ -207,7 +222,8 @@ def image_to_pixel_values_h2ovl( ) # combine pixel values pixel_values = torch.cat( - [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0) + [pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0 + ) else: pixel_values, _ = _preprocess_image( @@ -223,7 +239,6 @@ def image_to_pixel_values_h2ovl( class H2OVLProcessor(BaseInternVLProcessor): - def __init__( self, config: PretrainedConfig, @@ -270,14 +285,18 @@ def resolve_min_max_num( dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> tuple[int, int]: - min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch - is None else min_dynamic_patch) - max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch - is None else max_dynamic_patch) - dynamic_image_size = (self.dynamic_image_size 
if dynamic_image_size - is None else dynamic_image_size) - use_thumbnail = (self.use_thumbnail - if use_thumbnail is None else use_thumbnail) + min_dynamic_patch = ( + self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch + ) + max_dynamic_patch = ( + self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch + ) + dynamic_image_size = ( + self.dynamic_image_size + if dynamic_image_size is None + else dynamic_image_size + ) + use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail return resolve_h2ovl_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -318,7 +337,7 @@ def get_num_image_tokens( image_height: int, use_msac: Optional[bool] = None, ) -> int: - use_msac = (self.use_msac if use_msac is None else use_msac) + use_msac = self.use_msac if use_msac is None else use_msac use_thumbnail = self.use_thumbnail @@ -387,12 +406,12 @@ def _images_to_pixel_values_lst( max_num=max_num, use_thumbnail=self.use_thumbnail, use_msac=use_msac, - ) for image in images + ) + for image in images ] class H2OVLProcessingInfo(BaseInternVLProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> H2OVLProcessor: return self.ctx.init_processor( H2OVLProcessor, @@ -419,9 +438,7 @@ def get_num_image_tokens( ) -class H2OVLMultiModalProcessor( - BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]): - +class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]): def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -446,7 +463,8 @@ def _get_prompt_updates( def get_replacement_internvl(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -506,9 +524,9 @@ def _cached_apply_hf_processor( @MULTIMODAL_REGISTRY.register_processor( H2OVLMultiModalProcessor, info=H2OVLProcessingInfo, - dummy_inputs=BaseInternVLDummyInputsBuilder) + dummy_inputs=BaseInternVLDummyInputsBuilder, +) class H2OVLChatModel(InternVLChatModel): - def _init_vision_model( self, config: PretrainedConfig, @@ -520,8 +538,9 @@ def _init_vision_model( if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: - num_hidden_layers = (config.vision_config.num_hidden_layers + - vision_feature_layer + 1) + num_hidden_layers = ( + config.vision_config.num_hidden_layers + vision_feature_layer + 1 + ) else: num_hidden_layers = vision_feature_layer + 1 diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 09f124426fa1..d33406b7be2b 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
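Just above, the h2ovl.py diff ends with _init_vision_model turning a (possibly negative) select_layer index into a vision-encoder depth. A standalone reference of that arithmetic (not part of the patch; the helper name and numbers are illustrative):

def resolve_num_vision_layers(total_layers: int, vision_feature_layer: int) -> int:
    # A negative index counts back from the last encoder layer, Python-style:
    # -1 keeps every layer, -2 drops the final one. A non-negative index keeps
    # layers 0..index inclusive.
    if vision_feature_layer < 0:
        return total_layers + vision_feature_layer + 1
    return vision_feature_layer + 1


assert resolve_num_vision_layers(24, -1) == 24  # full encoder
assert resolve_num_vision_layers(24, -2) == 23  # drop the last layer
assert resolve_num_vision_layers(24, 11) == 12  # keep the first 12 layers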
"""Inference-only HunYuan model compatible with HuggingFace weights.""" + import typing from collections.abc import Callable, Iterable from typing import Any, Optional, Union @@ -35,29 +36,44 @@ from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_layers, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_layers, + maybe_prefix, +) def _is_moe(config: PretrainedConfig) -> bool: @@ -80,7 +96,6 @@ def _get_cla_factor(config: PretrainedConfig) -> int: class HunYuanMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -108,8 +123,9 @@ def __init__( reduce_results=reduce_results, ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -120,7 +136,6 @@ def forward(self, x): class HunYuanAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -204,10 +219,8 @@ def __init__( ) if self.use_qk_norm: - self.query_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) - self.key_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) + self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, @@ -221,9 +234,11 @@ def forward( ori_k = k if self.use_qk_norm: q = self.query_layernorm( - q.view(-1, self.num_heads, self.head_dim).contiguous()) + q.view(-1, self.num_heads, self.head_dim).contiguous() + ) k = self.key_layernorm( - k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + k.view(-1, self.num_kv_heads, self.head_dim).contiguous() + ) attn_output = self.attn(q, k, v) # For o_proj @@ -233,7 +248,6 @@ def forward( class HunYuanCrossAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -316,10 +330,8 @@ def __init__( ) if self.use_qk_norm: - self.query_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) - self.key_layernorm = RMSNorm(self.head_dim, - eps=config.rms_norm_eps) + self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) def forward( self, @@ -335,9 +347,11 @@ def forward( q, _ = self.rotary_emb(positions, q, k_tmp) if self.use_qk_norm: q = self.query_layernorm( - q.view(-1, self.num_heads, self.head_dim).contiguous()) + q.view(-1, self.num_heads, self.head_dim).contiguous() + ) k = self.key_layernorm( - k.view(-1, self.num_kv_heads, self.head_dim).contiguous()) + k.view(-1, self.num_kv_heads, self.head_dim).contiguous() + ) attn_output = self.attn(q, k, v) # For o_proj @@ -347,7 +361,6 @@ def forward( class HunYuanSparseMoeBlock(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -367,7 +380,8 @@ def __init__( if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") + f"the number of experts {config.num_experts}." + ) # Get layer_id topk if config.moe_topk is a list if isinstance(config.moe_topk, list): @@ -380,9 +394,11 @@ def __init__( # If it is moe, moe_intermediate_size is preferred intermediate_size = config.intermediate_size if config.moe_intermediate_size is not None: - intermediate_size = (config.moe_intermediate_size if isinstance( - config.moe_intermediate_size, int) else - config.moe_intermediate_size[layer_id]) + intermediate_size = ( + config.moe_intermediate_size + if isinstance(config.moe_intermediate_size, int) + else config.moe_intermediate_size[layer_id] + ) # Load balancing settings. 
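The expert-parallel bookkeeping computed just after this comment (physical experts = logical + redundant, split evenly across EP ranks) is easy to sanity-check in isolation (not part of the patch; the 64/4/4 numbers are made up):

def physical_expert_range(
    n_logical_experts: int,
    n_redundant_experts: int,
    ep_size: int,
    ep_rank: int,
) -> tuple[int, int]:
    # Illustrative sketch of the start/end indices of this rank's expert slice.
    n_physical = n_logical_experts + n_redundant_experts
    assert n_physical % ep_size == 0, "experts must split evenly across EP ranks"
    n_local = n_physical // ep_size
    start = ep_rank * n_local
    return start, start + n_local


# 64 routed experts plus 4 redundant replicas over 4 EP ranks -> 17 per rank.
for rank in range(4):
    print(rank, physical_expert_range(64, 4, ep_size=4, ep_rank=rank))
# 0 (0, 17)   1 (17, 34)   2 (34, 51)   3 (51, 68)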
vllm_config = get_current_vllm_config() @@ -391,13 +407,12 @@ def __init__( self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) self.experts = FusedMoE( num_experts=self.n_routed_experts, @@ -412,11 +427,13 @@ def __init__( num_redundant_experts=self.n_redundant_experts, ) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) if config.use_mixed_mlp_moe > 0: # Get layer_id num_shared_expert if config.num_shared_expert is # a list. @@ -448,19 +465,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) class HunYuanDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -474,30 +490,37 @@ def __init__( assert layer_id >= 0 self.layer_id = layer_id self.hidden_size = config.hidden_size - self.intermediate_size = (config.intermediate_size if isinstance( - config.intermediate_size, int) else - config.intermediate_size[layer_id]) + self.intermediate_size = ( + config.intermediate_size + if isinstance(config.intermediate_size, int) + else config.intermediate_size[layer_id] + ) rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) cla_factor = _get_cla_factor(config) - attention_type = (AttentionType.ENCODER_DECODER - if layer_id >= 0 and layer_id % cla_factor != 0 else - AttentionType.DECODER) + attention_type = ( + AttentionType.ENCODER_DECODER + if layer_id >= 0 and layer_id % cla_factor != 0 + else AttentionType.DECODER + ) if attention_type == AttentionType.DECODER: self.self_attn = HunYuanAttention( config=config, 
hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -512,8 +535,9 @@ def __init__( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -544,10 +568,10 @@ def __init__( prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -561,8 +585,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, ori_kv_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -570,15 +593,13 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual, ori_kv_states @support_torch_compile class HunYuanModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -593,12 +614,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -656,25 +681,27 @@ def forward( prev_kv_states, ) - if (getattr(self.config, "use_cla", False) - and (i - self.start_layer) % cla_factor == 0): + if ( + getattr(self.config, "use_cla", False) + and (i - self.start_layer) % cla_factor == 0 + ): prev_kv_states = kv_states else: prev_kv_states = None if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states def _split_qkv_weight(self, qkv: torch.Tensor): num_attention_heads = self.config.num_attention_heads - num_kv_heads = getattr(self.config, "num_key_value_heads", - self.config.num_attention_heads) + 
num_kv_heads = getattr( + self.config, "num_key_value_heads", self.config.num_attention_heads + ) num_key_value_groups = num_attention_heads // num_kv_heads hidden_size = self.config.hidden_size @@ -685,8 +712,9 @@ def _split_qkv_weight(self, qkv: torch.Tensor): else: attention_head_dim = self.config.hidden_size // num_attention_heads - qkv = qkv.reshape(num_kv_heads, num_key_value_groups + 2, - attention_head_dim, hidden_size) + qkv = qkv.reshape( + num_kv_heads, num_key_value_groups + 2, attention_head_dim, hidden_size + ) q, k, v = torch.split(qkv, (num_key_value_groups, 1, 1), dim=1) q = q.reshape(-1, hidden_size) k = k.reshape(-1, hidden_size) @@ -719,16 +747,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ] num_attention_heads = self.config.num_attention_heads - num_kv_heads = getattr(self.config, "num_key_value_heads", - self.config.num_attention_heads) + num_kv_heads = getattr( + self.config, "num_key_value_heads", self.config.num_attention_heads + ) split_params_mapping = [ (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None), ( ".qkv_proj", ".qkv_proj", num_attention_heads + num_kv_heads * 2, - [("q", num_attention_heads), ("k", num_kv_heads), - ("v", num_kv_heads)], + [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)], self._split_qkv_weight, ), ] @@ -743,8 +771,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): name = name.replace("gate_proj_bias", "gate_proj.bias") if "up_proj_bias" in name: name = name.replace("up_proj_bias", "up_proj.bias") - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue @@ -754,11 +781,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): if self.config.tie_word_embeddings and "lm_head.weight" in name: continue if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name)): + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache scales for compressed-tensors quantization param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) continue @@ -794,11 +821,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): continue for ( - param_name, - weight_name, - den, - split_param, - func, + param_name, + weight_name, + den, + split_param, + func, ) in split_params_mapping: if weight_name not in name: continue @@ -819,12 +846,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): for shard_id, num in split_param: new_offset = offset + num * units if func: - weight_loader(param, - func(loaded_weight)[offset:new_offset], - shard_id) + weight_loader( + param, func(loaded_weight)[offset:new_offset], shard_id + ) else: - weight_loader(param, loaded_weight[offset:new_offset], - shard_id) + weight_loader(param, loaded_weight[offset:new_offset], shard_id) offset = new_offset break @@ -850,8 +876,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. 
- weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) success = weight_loader( param, loaded_weight, @@ -881,8 +908,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): name = name.replace("wg.", "") param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -924,9 +952,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() @@ -937,8 +965,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( @@ -949,25 +978,23 @@ def compute_logits( return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) @@ -976,7 +1003,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -1028,8 +1054,7 @@ def update_physical_experts_metadata( assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, HunYuanSparseMoeBlock): moe = layer.mlp @@ -1043,7 +1068,6 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: class HunYuanDenseV1Base(HunyuanV1ModelBase): - def __init__(self, *, vllm_config: VllmConfig, 
prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -1053,4 +1077,4 @@ class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base): class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base): - pass \ No newline at end of file + pass diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 10d3bc8464ba..611c14733c71 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -2,27 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # copied from : https://github.com/huggingface/transformers import ast -import sys from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial -from itertools import chain -from typing import Any, Literal, Optional, TypedDict, Union +from itertools import accumulate +from typing import Annotated, Any, Literal, Optional, Union import numpy as np -import PIL -from einops import rearrange -from PIL import Image - -if sys.version_info >= (3, 11): - import typing - Unpack = typing.Unpack -else: - import typing_extensions - Unpack = typing_extensions.Unpack - import torch import torch.nn as nn +from einops import rearrange from timm.layers import LayerNorm, LayerNorm2d from timm.models.regnet import RegStage from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig @@ -33,20 +22,32 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageSize, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - InputProcessingContext, - PromptReplacement, PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix +from .utils import ( + AutoWeightsLoader, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vision_encoder_info EOT = "<|endofturn|>" @@ -57,8 +58,8 @@ # Based on combine_frames_into_images in # https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py def get_num_combined_frames( - num_frames: int, - max_grid_shape: tuple[int, int] = (3, 3), + num_frames: int, + max_grid_shape: tuple[int, int] = (3, 3), ) -> int: max_num_grids = max_grid_shape[0] * max_grid_shape[1] @@ -69,32 +70,48 @@ def get_num_combined_frames( return num_canvases + (leftover_frames > 0) -class HCXVisionMultimodalPixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values_images: list[torch.Tensor] +class HCXVisionImagePixelInputs(TensorSchema): """ - Shape: `[(num_grids, num_channels, height, width), ...]` if anyres - - Note that 
`height` or `width` may be different per batch and image, - in which case the data is passed as a list instead of a batched tensor. + Dimensions: + - n: Number of images + - g: Number of grids + - c: Number of channels (3) + - h: Height + - w: Width """ - image_sizes_images: list[tuple[Union[int, float]]] - """ - Shape: `[(height, width), ...]` - """ - vision_query_lengths_images: list[Union[int, float]] - pixel_values_videos: list[tuple[Union[int, float]]] + + type: Literal["pixel_values"] = "pixel_values" + pixel_values_images: Annotated[ + list[torch.Tensor], TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"}) + ] + image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)] + + +HCXVisionImageInputs = HCXVisionImagePixelInputs + + +class HCXVisionVideoPixelInputs(TensorSchema): """ - Shape: `[(num_grids, num_channels, height, width), ...]` if anyres + Dimensions: + - n: Number of videos + - f: Number of frames + - g: Number of grids + - c: Number of channels (3) + - h: Height + - w: Width """ - vision_query_lengths_videos: list[Union[int, float]] + type: Literal["pixel_values_videos"] = "pixel_values_videos" + pixel_values_videos: Annotated[ + list[list[torch.Tensor]], + TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"}), + ] -HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs] +HCXVisionVideoInputs = HCXVisionVideoPixelInputs -class HCXVisionProcessingInfo(BaseProcessingInfo): +class HCXVisionProcessingInfo(BaseProcessingInfo): def get_vision_encoder_info(self): return get_vision_encoder_info(self.get_hf_config()) @@ -135,15 +152,14 @@ def get_max_image_tokens(self) -> int: ) -class HCXVisionDummyInputsBuilder( - BaseDummyInputsBuilder[HCXVisionProcessingInfo]): - +class HCXVisionDummyInputsBuilder(BaseDummyInputsBuilder[HCXVisionProcessingInfo]): def get_dummy_text( self, mm_counts: Mapping[str, int], ) -> str: dummy_text = IMAGE_TOKEN * mm_counts.get( - "image", 0) + VIDEO_TOKEN * mm_counts.get("video", 0) + "image", 0 + ) + VIDEO_TOKEN * mm_counts.get("video", 0) return dummy_text def get_dummy_mm_data( @@ -155,35 +171,30 @@ def get_dummy_mm_data( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() target_num_frames = 32 image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images( + "image": self._get_dummy_images( width=target_width, height=target_height, num_images=num_images, overrides=image_overrides, ), - "video": - self._get_dummy_videos( + "video": self._get_dummy_videos( width=target_width - 1, height=target_height - 1, num_frames=target_num_frames, num_videos=num_videos, overrides=video_overrides, - ) + ), } -class HCXVisionMultiModalProcessor( - BaseMultiModalProcessor[HCXVisionProcessingInfo]): - +class HCXVisionMultiModalProcessor(BaseMultiModalProcessor[HCXVisionProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -191,27 +202,9 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - - def replace_multimodal_token( - token_ids: torch.Tensor, - target_token: int, - repeats: list[int], - ): - output = list[int]() - _repeats_idx = 0 - for token_id in token_ids: - if token_id == target_token: - output += [token_id.item()] * repeats[_repeats_idx] - _repeats_idx += 1 - 
else: - output += [token_id.item()] - - return torch.tensor(output, device=token_ids.device) - for video_idx, video_arr in enumerate(mm_data.get("videos", [])): - if video_arr.dtype == np.uint8: - continue - mm_data["videos"][video_idx] = video_arr.astype(np.uint8) + if video_arr.dtype != np.uint8: + mm_data["videos"][video_idx] = video_arr.astype(np.uint8) processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), @@ -223,20 +216,16 @@ def replace_multimodal_token( ) # text-only if len(mm_data) > 0: - # batchify input as a single item - images = mm_data.get("images", None) - batched_images = None if images is None else [images] - - # list of video in single conversation - videos = mm_data.get("videos", None) - batched_videos = None if videos is None else [videos] + images = mm_data.get("images") + videos = mm_data.get("videos") + # batchify input as a single item _processed_outputs = self.info.ctx.call_hf_processor( hf_processor=self.info.get_hf_processor(**mm_kwargs), data=dict( text=None, - images=batched_images, - videos=batched_videos, + images=None if images is None else [images], + videos=None if videos is None else [videos], ), ) # mm-only @@ -246,51 +235,48 @@ def replace_multimodal_token( _processed_outputs[k] = v[0] if images: - tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) - processed_outputs["input_ids"] = torch.stack([ - replace_multimodal_token( - token_ids=_input_ids, - target_token=image_token_id, - repeats=_processed_outputs[ - "vision_query_lengths_images"], - ) for _input_ids in processed_outputs["input_ids"] - ], - dim=0) + _processed_outputs["image_sizes_images"] = torch.tensor( + _processed_outputs["image_sizes_images"] + ) + _processed_outputs["vision_query_lengths_images"] = torch.tensor( + _processed_outputs["vision_query_lengths_images"] + ) if videos: - _num_per_videos = [ - get_num_combined_frames(len(video)) for video in videos + _idx_per_video = [ + 0, + *accumulate( + get_num_combined_frames(len(video)) for video in videos + ), ] _processed_outputs["pixel_values_videos"] = [ - _processed_outputs["pixel_values_videos"] - [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(len(videos)) + _processed_outputs["pixel_values_videos"][ + _idx_per_video[i] : _idx_per_video[i + 1] + ] + for i in range(len(videos)) ] _processed_outputs["vision_query_lengths_videos"] = [ - _processed_outputs["vision_query_lengths_videos"] - [sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])] - for _i in range(len(videos)) + torch.tensor( + _processed_outputs["vision_query_lengths_videos"][ + _idx_per_video[i] : _idx_per_video[i + 1] + ] + ) + for i in range(len(videos)) ] - tokenizer = self.info.get_tokenizer() - video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN) - processed_outputs["input_ids"] = torch.stack([ - replace_multimodal_token( - token_ids=_input_ids, - target_token=video_token_id, - repeats=[ - sum(lens) for lens in - _processed_outputs["vision_query_lengths_videos"] - ], - ) for _input_ids in processed_outputs["input_ids"] - ], - dim=0) - processed_outputs.update(_processed_outputs) return processed_outputs + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -311,13 +297,11 @@ def 
get_replacement_hyperclovax( out_item = out_mm_kwargs[modality][item_idx] if modality == "image": - lens = out_item["vision_query_lengths_images"].data - num_tokens = self.info.get_num_image_tokens( - vision_query_length=lens) + lens = out_item["vision_query_lengths_images"].data.tolist() + num_tokens = self.info.get_num_image_tokens(vision_query_length=lens) elif modality == "video": - lens = out_item["vision_query_lengths_videos"].data - num_tokens = self.info.get_num_video_tokens( - vision_query_length=lens) + lens = out_item["vision_query_lengths_videos"].data.tolist() + num_tokens = self.info.get_num_video_tokens(vision_query_length=lens) else: raise NotImplementedError(modality) @@ -334,7 +318,8 @@ def get_replacement_hyperclovax( modality=modality, out_mm_kwargs=out_mm_kwargs, ), - ) for modality in ("image", "video") + ) + for modality in ("image", "video") ] def _get_mm_fields_config( @@ -343,31 +328,17 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - # image pixel_values_images=MultiModalFieldConfig.batched("image"), image_sizes_images=MultiModalFieldConfig.batched("image"), vision_query_lengths_images=MultiModalFieldConfig.batched("image"), - num_queries_vis_abstractors_images=MultiModalFieldConfig.batched( - "image"), - num_queries_vis_abstractors_slow_images=MultiModalFieldConfig. - batched("image"), - first_last_frames_slows_images=MultiModalFieldConfig.batched( - "image"), - # video pixel_values_videos=MultiModalFieldConfig.batched("video"), - image_sizes_videos=MultiModalFieldConfig.batched("video"), vision_query_lengths_videos=MultiModalFieldConfig.batched("video"), - num_queries_vis_abstractors_videos=MultiModalFieldConfig.batched( - "video"), - num_queries_vis_abstractors_slow_videos=MultiModalFieldConfig. - batched("video"), - first_last_frames_slows_videos=MultiModalFieldConfig.batched( - "video"), ) def _build_hcxvision_hf_info( - ctx: InputProcessingContext, ) -> HCXVisionProcessingInfo: + ctx: InputProcessingContext, +) -> HCXVisionProcessingInfo: return HCXVisionProcessingInfo(ctx) @@ -425,7 +396,6 @@ def init_vision_tower_for_hcxvision( class HCXVisionMlp(nn.Module): - def __init__( self, mm_projector_type, @@ -447,8 +417,9 @@ def __init__( self.act = act_layer() self.fc2 = nn.Linear(2 * hidden_features, out_features) else: - raise NotImplementedError("{} is not implemented".format( - self.mm_projector_type)) + raise NotImplementedError( + "{} is not implemented".format(self.mm_projector_type) + ) def forward(self, x): x = self.fc1(x) @@ -460,7 +431,7 @@ def forward(self, x): class HCXVisionCAbstractor(nn.Module): """ This module is based on C-Abstractor, whose license is under apache-2.0. - You can check the original code at + You can check the original code at https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py and we made necessary modifications. 
""" @@ -482,7 +453,8 @@ def __init__( # Positional embedding if pos_emb: self.pos_emb = torch.nn.Parameter( - torch.zeros(1, num_input_tokens, encoder_hidden_size)) + torch.zeros(1, num_input_tokens, encoder_hidden_size) + ) self.pos_emb.data.normal_(mean=0.0, std=0.02) else: self.pos_emb = None @@ -493,8 +465,9 @@ def __init__( else: self.prenorm = None - self.build_net(num_queries, encoder_hidden_size, hidden_size, - output_hidden_size) + self.build_net( + num_queries, encoder_hidden_size, hidden_size, output_hidden_size + ) self.dtype = next(self.parameters()).dtype def forward( @@ -531,7 +504,8 @@ def _forward( if num_queries_vis_abstractors is not None: assert num_grids is not None return self._forward_adaptive_num_query( - x, num_queries_vis_abstractors, num_grids) + x, num_queries_vis_abstractors, num_grids + ) x = self.net(x) x = rearrange(x, "b d h w -> b (h w) d") @@ -552,7 +526,7 @@ def _forward_adaptive_num_query( for i, num_queries in enumerate(num_queries_vis_abstractors): hw = int(num_queries**0.5) sampler = nn.AdaptiveAvgPool2d((hw, hw)) - out = sampler(x[num_grids[i]:num_grids[i + 1], :]) + out = sampler(x[num_grids[i] : num_grids[i + 1], :]) out = self.net[2](out) # s2 out = rearrange(out, "b d h w -> b (h w) d") @@ -570,8 +544,9 @@ def build_net( depth: int = 3, mlp_depth: int = 2, ): - assert (n_queries**0.5).is_integer( - ), f"n_queries must be square number. n_queries: {n_queries}" + assert (n_queries**0.5).is_integer(), ( + f"n_queries must be square number. n_queries: {n_queries}" + ) hw = int(n_queries**0.5) # RegBlock = ResBlock + SE @@ -596,8 +571,7 @@ def build_net( ) self.net = nn.Sequential(s1, sampler, s2) - self.readout = self.build_mlp(mlp_depth, hidden_size, - output_hidden_size) + self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size) def build_mlp( self, @@ -615,12 +589,14 @@ def build_mlp( @MULTIMODAL_REGISTRY.register_processor( _build_hcxvision_hf_processor, info=_build_hcxvision_hf_info, - dummy_inputs=HCXVisionDummyInputsBuilder) + dummy_inputs=HCXVisionDummyInputsBuilder, +) class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__( @@ -650,7 +626,8 @@ def __init__( ## possible_resolution should be matched with preprocessor_config.json config.possible_resolutions = self._init_possible_resolutions( - config, vision_config) + config, vision_config + ) # init models & parameters with no_init_weights(): # weight will be loaded in from_pretrained @@ -661,11 +638,11 @@ def __init__( require_post_norm=False, prefix=maybe_prefix(prefix, "vision_model"), ) - self.mm_projector = self._init_mm_projector(config, text_config, - vision_config) + self.mm_projector = self._init_mm_projector(config, text_config, vision_config) - self.lm_head_vocab_size = getattr(text_config, "padded_vocab_size", - text_config.vocab_size) + self.lm_head_vocab_size = getattr( + text_config, "padded_vocab_size", text_config.vocab_size + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=text_config, @@ -674,7 +651,8 @@ def __init__( if config.anyres: self.image_newline = nn.Parameter( - torch.empty(text_config.hidden_size, dtype=self.dtype)) + torch.empty(text_config.hidden_size, dtype=self.dtype) + ) self.config = config self.vision_config = vision_config @@ -692,55 +670,92 @@ def get_placeholder_str(cls, modality: 
str, i: int) -> Optional[str]: raise ValueError("Only image or video modality is supported") + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> Optional[HCXVisionImageInputs]: + pixel_values_images = kwargs.pop("pixel_values_images", None) + + if pixel_values_images is None: + return None + + image_sizes_images = kwargs.pop("image_sizes_images") + + return HCXVisionImagePixelInputs( + pixel_values_images=pixel_values_images, + image_sizes_images=image_sizes_images, + ) + + def _parse_and_validate_video_input( + self, + **kwargs: object, + ) -> Optional[HCXVisionVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + + if pixel_values_videos is None: + return None + + return HCXVisionVideoPixelInputs( + pixel_values_videos=pixel_values_videos, + ) + + def _process_image_input( + self, + image_input: HCXVisionImageInputs, + ) -> tuple[torch.Tensor, ...]: + return self.forward_images( + pixel_values_images=image_input["pixel_values_images"], + image_sizes_images=image_input["image_sizes_images"], + ) + + def _process_video_input( + self, + video_input: HCXVisionVideoInputs, + ) -> tuple[torch.Tensor, ...]: + return self.forward_videos( + pixel_values_videos=video_input["pixel_values_videos"], + ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key == "pixel_values_images" and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key == "pixel_values_videos" and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) + + return modalities + def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( self, - **kwargs: Unpack[HCXVisionMultimodalInputs], + **kwargs: object, ) -> MultiModalEmbeddings: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return [] + + # The result multimodal_embeddings is a tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities.
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings - multimodal_embeddings = list() - if kwargs.get("pixel_values_images") is not None: - for _pixel_values_images, _image_sizes_images in zip( - kwargs["pixel_values_images"], - kwargs["image_sizes_images"]): - _pixel_values_images = _pixel_values_images.unsqueeze(dim=0) - _image_sizes_images = _image_sizes_images.unsqueeze(dim=0) - _len_pixel_values_images = [ - len(pixel_value) for pixel_value in _pixel_values_images - ] - if isinstance(_image_sizes_images, torch.Tensor): - _image_sizes_images = _image_sizes_images.detach().cpu( - ).tolist() - _multimodal_embeddings_images = self.forward_images( - pixel_values_images=_pixel_values_images, - image_sizes_images=_image_sizes_images, - len_pixel_values_images=_len_pixel_values_images, - ) - _multimodal_embeddings_images = torch.cat( - _multimodal_embeddings_images, dim=0) - multimodal_embeddings.append(_multimodal_embeddings_images) - - if kwargs.get("pixel_values_videos") is not None: - for _pixel_values_videos, _vision_query_lengths_videos in zip( - kwargs["pixel_values_videos"], - kwargs["vision_query_lengths_videos"]): - _len_pixel_values_videos = [ - len(_vision_query_lengths) - for _vision_query_lengths in _vision_query_lengths_videos - ] - _c, _w, _h = _pixel_values_videos.shape[-3:] - _pixel_values_videos = _pixel_values_videos.reshape( - sum(_len_pixel_values_videos), -1, _c, _w, - _h).unsqueeze(dim=0) - _multimodal_embeddings_videos = self.forward_videos( - pixel_values_videos=_pixel_values_videos, - len_pixel_values_videos=_len_pixel_values_videos, - ) - _multimodal_embeddings_videos = torch.cat( - _multimodal_embeddings_videos, dim=0) - multimodal_embeddings.append(_multimodal_embeddings_videos) return multimodal_embeddings def forward( @@ -754,87 +769,66 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def forward_images( self, - pixel_values_images: list[list[torch.FloatTensor]], - image_sizes_images: list[list[tuple[int, int]]], - len_pixel_values_images: list[int], - ) -> list[list[torch.Tensor]]: - if sum(len_pixel_values_images) == 0: - return None - - concat_pixel_values_images = torch.cat(list( - chain(*pixel_values_images)), - dim=0) + pixel_values_images: list[torch.Tensor], + image_sizes_images: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True) visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 - image_forward_outs = self.vision_model( - concat_pixel_values_images)[:, visual_token_idx:] + image_forward_outs = self.vision_model(pixel_values_image_flat)[ + :, visual_token_idx: + ] - image_forward_outs = image_forward_outs.to( - dtype=self.mm_projector.dtype) + image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype) image_forward_outs = self.mm_projector(image_forward_outs) # b (h w) d - split_sizes = [ - pixel_value.shape[0] for pixel_value in 
chain(*pixel_values_images) - ] - image_forward_outs = torch.split(image_forward_outs, - split_sizes, - dim=0) + split_sizes = [len(item) for item in pixel_values_images] + image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0) # newline for anyres postprocessing image_features = anyres_postprocessing( image_forward_outs=image_forward_outs, - image_sizes=[ - image_size for image_sizes in image_sizes_images - for image_size in image_sizes - ], - num_queries_vis_abstractor=self.config. - num_queries_vis_abstractor_image, + image_sizes=image_sizes_images.tolist(), + num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image, unpad=self.config.unpad, patch_size=self.vision_config.patch_size, grid_size=self.vision_config.image_size, image_newline=self.image_newline, possible_resolutions=self.config.possible_resolutions, ) - return image_features + + return tuple(image_features) def forward_videos( self, - pixel_values_videos: list[list[torch.FloatTensor]], - len_pixel_values_videos: list[int], - ) -> list[torch.Tensor]: - - len_video_grids = sum(len_pixel_values_videos) - if len_video_grids == 0: - return None - - # Run Vision Model - concat_pixel_values_videos = torch.cat(list( - chain(*pixel_values_videos)), - dim=0) + pixel_values_videos: list[list[torch.Tensor]], + ) -> tuple[torch.Tensor, ...]: + pixel_values_videos_flat = flatten_bn( + [frame for frames in pixel_values_videos for frame in frames], + concat=True, + ) visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 - video_forward_outs = self.vision_model( - concat_pixel_values_videos)[:, visual_token_idx:] + video_forward_outs = self.vision_model(pixel_values_videos_flat)[ + :, visual_token_idx: + ] - video_forward_outs = video_forward_outs.to( - dtype=self.mm_projector.dtype) + video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype) # Run MM-Projector # len(num_grids) == len(num_queries_vis_abstractors) + 1 grid_idx = 0 - num_grids = [ - grid_idx - ] # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56] - num_queries_vis_abstractors = [ - ] # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9] + # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56] + num_grids = [grid_idx] + # e.g. 
[81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9] + num_queries_vis_abstractors = [] len_total_frames = video_forward_outs.shape[0] if self.config.first_last_frames_slow: @@ -842,22 +836,26 @@ def forward_videos( assert len_total_frames != 0 if len_total_frames <= 2: num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow) + self.config.num_queries_vis_abstractor_video_slow + ) grid_idx += len_total_frames num_grids.append(grid_idx) else: num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow) + self.config.num_queries_vis_abstractor_video_slow + ) grid_idx += 1 num_grids.append(grid_idx) num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_fast) + self.config.num_queries_vis_abstractor_video_fast + ) grid_idx += len_total_frames - 2 num_grids.append(grid_idx) num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow) + self.config.num_queries_vis_abstractor_video_slow + ) grid_idx += 1 num_grids.append(grid_idx) else: @@ -866,17 +864,19 @@ def forward_videos( for pixel_values_frame in pixel_values_frames: if len(pixel_values_frame) > 0: num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_slow) + self.config.num_queries_vis_abstractor_video_slow + ) grid_idx += 1 num_grids.append(grid_idx) num_queries_vis_abstractors.append( - self.config.num_queries_vis_abstractor_video_fast) + self.config.num_queries_vis_abstractor_video_fast + ) grid_idx = grid_idx + len(pixel_values_frame) - 1 num_grids.append(grid_idx) - video_forward_outs = self.mm_projector(video_forward_outs, - num_queries_vis_abstractors, - num_grids) + video_forward_outs = self.mm_projector( + video_forward_outs, num_queries_vis_abstractors, num_grids + ) video_features = [] # what we want to return target_features = [] @@ -898,14 +898,19 @@ def forward_videos( target_group_size = 0 elif video_group_size < target_group_size: - raise RuntimeError( - f"{video_group_size=} < {target_group_size=}") + raise RuntimeError(f"{video_group_size=} < {target_group_size=}") - assert len(target_features - ) == 0, f"target_features is not empty!! {target_features}" + assert len(target_features) == 0, ( + f"target_features is not empty!! 
{target_features}" + ) assert len(video_groups) == len(video_features) - return video_features + feats_per_video = [len(video) for video in pixel_values_videos] + idxs_per_video = [0, *accumulate(feats_per_video)] + return tuple( + torch.cat(video_features[idxs_per_video[i] : idxs_per_video[i + 1]]) + for i in range(len(feats_per_video)) + ) def _prepare_multimodal_kwargs(self, **kwargs: object): output = defaultdict(list) @@ -914,7 +919,7 @@ def _prepare_multimodal_kwargs(self, **kwargs: object): continue # if empty batch of empty sample new_k, is_video = k, False - if (not k.endswith("_images") and not k.endswith("_videos")): + if not k.endswith("_images") and not k.endswith("_videos"): pass else: new_k, is_video = k.split("_")[:-1], k.split("_")[-1] @@ -967,10 +972,10 @@ def _init_possible_resolutions( if i * j <= config.max_num_grids: possible_resolutions.append([i, j]) - possible_resolutions = [[ - ys * vision_config.image_size, - xs * vision_config.image_size - ] for ys, xs in possible_resolutions] + possible_resolutions = [ + [ys * vision_config.image_size, xs * vision_config.image_size] + for ys, xs in possible_resolutions + ] return possible_resolutions else: return config.possible_resolutions @@ -983,14 +988,13 @@ def _init_mm_projector( ): input_hidden_size = vision_config.hidden_size if config.mm_projector_type == "linear": - mm_projector = nn.Linear(input_hidden_size, - text_config.hidden_size) + mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size) mm_projector.dtype = next(mm_projector.parameters()).dtype elif config.mm_projector_type == "cabstractor": mm_projector = HCXVisionCAbstractor( num_queries=config.num_queries_vis_abstractor_image, - num_input_tokens=(vision_config.image_size // - vision_config.patch_size)**2, + num_input_tokens=(vision_config.image_size // vision_config.patch_size) + ** 2, encoder_hidden_size=input_hidden_size, hidden_size=input_hidden_size, output_hidden_size=text_config.hidden_size, @@ -1007,8 +1011,7 @@ def _init_mm_projector( return mm_projector -def unpad_image(tensor: torch.Tensor, - original_size: tuple[int, int]) -> torch.Tensor: +def unpad_image(tensor: torch.Tensor, original_size: tuple[int, int]) -> torch.Tensor: original_width, original_height = original_size current_height, current_width = tensor.shape[1:] @@ -1019,18 +1022,17 @@ def unpad_image(tensor: torch.Tensor, scale_factor = current_width / original_width new_height = int(original_height * scale_factor) padding = (current_height - new_height) // 2 - unpadded_tensor = tensor[:, padding:current_height - padding, :] + unpadded_tensor = tensor[:, padding : current_height - padding, :] else: scale_factor = current_height / original_height new_width = int(original_width * scale_factor) padding = (current_width - new_width) // 2 - unpadded_tensor = tensor[:, :, padding:current_width - padding] + unpadded_tensor = tensor[:, :, padding : current_width - padding] return unpadded_tensor -def select_best_resolution(original_size: tuple, - possible_resolutions: list) -> tuple: +def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: original_height, original_width = original_size best_fit = None max_effective_resolution = 0 @@ -1038,15 +1040,19 @@ def select_best_resolution(original_size: tuple, for height, width in possible_resolutions: scale = min(width / original_width, height / original_height) - downscaled_width, downscaled_height = int(original_width * scale), int( - original_height * scale) - effective_resolution = min(downscaled_width * 
downscaled_height, - original_width * original_height) + downscaled_width, downscaled_height = ( + int(original_width * scale), + int(original_height * scale), + ) + effective_resolution = min( + downscaled_width * downscaled_height, original_width * original_height + ) wasted_resolution = (width * height) - effective_resolution if effective_resolution > max_effective_resolution or ( - effective_resolution == max_effective_resolution - and wasted_resolution < min_wasted_resolution): + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution + ): max_effective_resolution = effective_resolution min_wasted_resolution = wasted_resolution best_fit = (height, width) @@ -1059,12 +1065,16 @@ def get_anyres_image_grid_shape( grid_pinpoints: Union[str, list[tuple[int, int]]], patch_size: int, ) -> tuple[int, int]: - possible_resolutions = grid_pinpoints if isinstance( - grid_pinpoints, list) else ast.literal_eval(grid_pinpoints) + possible_resolutions = ( + grid_pinpoints + if isinstance(grid_pinpoints, list) + else ast.literal_eval(grid_pinpoints) + ) original_width, original_height = image_size - height, width = select_best_resolution((original_height, original_width), - possible_resolutions) + height, width = select_best_resolution( + (original_height, original_width), possible_resolutions + ) return width // patch_size, height // patch_size @@ -1082,12 +1092,15 @@ def reshape_and_unpad_image_features( image_feature = image_feature[1:] assert height * width == base_image_feature.shape[0], ( - f"{height=} * {width=} != {base_image_feature.shape[0]=}") + f"{height=} * {width=} != {base_image_feature.shape[0]=}" + ) num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_size, possible_resolutions, grid_size) - image_feature = image_feature.view(num_patch_height, num_patch_width, - height, width, -1) + image_size, possible_resolutions, grid_size + ) + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) if unpad: image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() @@ -1096,8 +1109,9 @@ def reshape_and_unpad_image_features( image_feature = torch.cat( ( image_feature, - image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1).to(image_feature.device), + image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.device), ), dim=-1, ) @@ -1111,20 +1125,21 @@ def reshape_and_unpad_image_features( def anyres_postprocessing( - image_forward_outs: list[torch.FloatTensor], + image_forward_outs: list[torch.Tensor], image_sizes: list[list[int]], possible_resolutions: list[tuple[int, int]], patch_size: int, grid_size: int, - image_newline: torch.FloatTensor, + image_newline: torch.Tensor, num_queries_vis_abstractor: int = -1, unpad: bool = False, -) -> list[torch.FloatTensor]: +) -> list[torch.Tensor]: height = width = grid_size // patch_size if num_queries_vis_abstractor > 0: - assert (num_queries_vis_abstractor**0.5 - ).is_integer(), "n_queries must be square number" + assert (num_queries_vis_abstractor**0.5).is_integer(), ( + "n_queries must be square number" + ) height = width = int(num_queries_vis_abstractor**0.5) # post-processing (unpad, add newline) @@ -1144,29 +1159,8 @@ def anyres_postprocessing( else: image_feature = image_feature[0] image_feature = torch.cat( - (image_feature, image_newline[None].to(image_feature.device)), - dim=0) + (image_feature, image_newline[None].to(image_feature.device)), dim=0 + ) new_image_features.append(image_feature) - 
image_features = new_image_features - return image_features - - -def resize_image( - image: Union[np.ndarray, PIL.Image.Image], - max_side: int = 378, -) -> np.ndarray: - image_arr = image - if isinstance(image, np.ndarray): - image = Image.fromarray(image) - - width, height = image.size - cur_max_size = max(width, height) - if cur_max_size <= max_side: - return image_arr - - scale = max_side / cur_max_size - width = int(width * scale) - height = int(height * scale) - image = image.resize((width, height), Image.LANCZOS) - image_arr = np.array(image) - return image_arr + + return new_image_features diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 2f0c4240413b..02c46a11a179 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -24,14 +24,18 @@ import torch from torch import nn from transformers.models.idefics2.configuration_idefics2 import ( - Idefics2Config, Idefics2VisionConfig) + Idefics2Config, + Idefics2VisionConfig, +) from vllm.attention.layer import MultiHeadAttention from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -67,13 +71,14 @@ def __init__(self, config: Idefics2VisionConfig): self.num_patches_per_side = self.image_size // self.patch_size self.num_patches = self.num_patches_per_side**2 self.num_positions = self.num_patches - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - def forward(self, - pixel_values: torch.FloatTensor, - patch_attention_mask: torch.BoolTensor, - tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: batch_size, _, max_im_h, max_im_w = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(target_dtype)) @@ -82,14 +87,14 @@ def forward(self, max_im_h // self.patch_size, max_im_w // self.patch_size, ) - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, - 1 / self.num_patches_per_side) - position_ids = torch.full(size=(batch_size, - max_nb_patches_h * max_nb_patches_w), - fill_value=0) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): - if tgt_sizes is not None: nb_patches_h = tgt_sizes[batch_idx][0] nb_patches_w = tgt_sizes[batch_idx][1] @@ -98,14 +103,15 @@ def forward(self, nb_patches_w = p_attn_mask[0].sum() fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) - bucket_coords_h = torch.bucketize(fractional_coords_h, - boundaries, - right=True) - bucket_coords_w = torch.bucketize(fractional_coords_w, - boundaries, 
- right=True) - pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + - bucket_coords_w).flatten() + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) embeddings += self.position_embedding(position_ids) @@ -130,12 +136,12 @@ def __init__( if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501 - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() assert self.num_heads % tp_size == 0 self.num_heads_per_partition = self.num_heads // tp_size @@ -156,8 +162,9 @@ def __init__( disable_tp=use_data_parallel, ) # Use unified MultiHeadAttention with Flash Attention support - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def forward( self, @@ -175,7 +182,6 @@ def forward( class Idefics2VisionMLP(nn.Module): - def __init__( self, config: Idefics2VisionConfig, @@ -211,7 +217,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Idefics2EncoderLayer(nn.Module): - def __init__( self, config: Idefics2Config, @@ -225,15 +230,16 @@ def __init__( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", - use_data_parallel=use_data_parallel) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.mlp = Idefics2VisionMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + use_data_parallel=use_data_parallel, + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, @@ -284,13 +290,17 @@ def __init__( else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - Idefics2EncoderLayer(config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - use_data_parallel=use_data_parallel) - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Idefics2EncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward( self, @@ -313,7 +323,6 @@ def forward( class Idefics2VisionTransformer(nn.Module): - def __init__( self, config: Idefics2VisionConfig, @@ -335,7 +344,8 @@ def __init__( quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", - use_data_parallel=use_data_parallel) + use_data_parallel=use_data_parallel, + ) num_hidden_layers = config.num_hidden_layers if 
len(self.encoder.layers) > config.num_hidden_layers: @@ -345,10 +355,14 @@ def __init__( ) self.require_post_norm = require_post_norm - self.post_layernorm = nn.LayerNorm( - embed_dim, - eps=config.layer_norm_eps, - ) if require_post_norm else nn.Identity() + self.post_layernorm = ( + nn.LayerNorm( + embed_dim, + eps=config.layer_norm_eps, + ) + if require_post_norm + else nn.Identity() + ) def get_input_embeddings(self): return self.embeddings @@ -365,15 +379,13 @@ def forward( tgt_sizes=tgt_sizes, ) if self.use_data_parallel: - encoder_outputs = run_dp_sharded_vision_model( - hidden_states, self.encoder) + encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder) else: encoder_outputs = self.encoder(hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -390,8 +402,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue # post_layernorm is optional - if (name.startswith("post_layernorm.") - and not self.require_post_norm): + if name.startswith("post_layernorm.") and not self.require_post_norm: continue # omit layers when num_hidden_layers_override is set @@ -410,8 +421,7 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 567793e9b7ee..effdbdc1ac38 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -22,8 +22,12 @@ import torch from torch import nn -from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, - Idefics3Processor) +from transformers import ( + BatchFeature, + Idefics3Config, + Idefics3ImageProcessor, + Idefics3Processor, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -33,27 +37,30 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageProcessorItems, ImageSize -# yapf conflicts with isort for this block -# yapf: disable -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalDataItems, PromptReplacement, - PromptUpdate, PromptUpdateDetails) -# yapf: enable +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalDataItems, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -# yapf: disable from .idefics2_vision_model import ( - Idefics2VisionTransformer as Idefics3VisionTransformer) -# yapf: enable + 
Idefics2VisionTransformer as Idefics3VisionTransformer, +) from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal from .llama import LlamaModel -from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix +from .utils import AutoWeightsLoader, maybe_prefix class Idefics3ImagePixelInputs(TensorSchema): @@ -65,9 +72,10 @@ class Idefics3ImagePixelInputs(TensorSchema): - h: Height - w: Width """ + type: Literal["pixel_values"] pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] - pixel_attention_mask: torch.Tensor + pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -78,6 +86,7 @@ class Idefics3ImageEmbeddingInputs(TensorSchema): - f: Image feature size - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["image_embeds"] data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")] @@ -86,20 +95,21 @@ class Idefics3ImageEmbeddingInputs(TensorSchema): class Idefics3ProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> Idefics3Processor: return self.ctx.get_hf_processor(Idefics3Processor, **kwargs) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def _resize_output_size(self, - *, - height: int, - width: int, - max_len: Optional[int] = None, - min_len: int = 1, - max_size: Optional[int] = None) -> tuple[int, int]: + def _resize_output_size( + self, + *, + height: int, + width: int, + max_len: Optional[int] = None, + min_len: int = 1, + max_size: Optional[int] = None, + ) -> tuple[int, int]: # Set default value for max_len if not provided max_len = max(height, width) if max_len is None else max_len aspect_ratio = width / height @@ -135,18 +145,19 @@ def _get_resize_output_image_size( ) -> tuple[int, int]: hf_processor = self.get_hf_processor() image_processor: Idefics3ImageProcessor = hf_processor.image_processor - max_image_size = image_processor.size['longest_edge'] + max_image_size = image_processor.size["longest_edge"] if resolution_max_side > max_image_size: raise ValueError( - "`resolution_max_side` cannot be larger than `max_image_size`") + "`resolution_max_side` cannot be larger than `max_image_size`" + ) height, width = image_height, image_width # Find the output size, when rescaling the longest edge to max_len and # preserving the aspect ratio - height, width = self._resize_output_size(height=height, - width=width, - max_len=resolution_max_side) + height, width = self._resize_output_size( + height=height, width=width, max_len=resolution_max_side + ) return height, width def _get_image_feature_grid_size( @@ -161,12 +172,13 @@ def _get_image_feature_grid_size( image_processor: Idefics3ImageProcessor = processor.image_processor - max_image_size = image_processor.max_image_size['longest_edge'] - size = image_processor.size['longest_edge'] + max_image_size = image_processor.max_image_size["longest_edge"] + size = image_processor.size["longest_edge"] assert size % max_image_size == 0, ( "`longest_edge` in image_processor's `size` must be divisible by " "`longest_edge` in `max_image_size`, this may be caused by " - "incorrect mm_kwargs override.") + "incorrect mm_kwargs override." 
+ ) resized_height, resized_width = self._get_resize_output_image_size( image_width=image_width, @@ -196,8 +208,8 @@ def get_num_patches( return grid_w * grid_h + 1 def _get_image_token( - self, - processor: Optional[Idefics3Processor]) -> tuple[str, str, str]: + self, processor: Optional[Idefics3Processor] + ) -> tuple[str, str, str]: if processor is None: processor = self.get_hf_processor() @@ -217,7 +229,8 @@ def get_image_repl( processor = self.get_hf_processor() image_token, fake_image_token, global_img_token = self._get_image_token( - processor) + processor + ) image_seq_len = processor.image_seq_len grid_placeholder = "<row_{n_h}_col_{n_w}>" @@ -236,19 +249,20 @@ def get_image_repl( tiles_placeholder = list[str]() for i in range(grid_h): for j in range(grid_w): - placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1, - n_w=j + 1) + placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1, n_w=j + 1) tiles_placeholder.append(placeholder_per_tile) # Add line break if it is the last tile in the row if j == grid_w - 1: tiles_placeholder.append("\n") - return "".join([ - *tiles_placeholder, - "\n", - global_img_placeholder, - fake_image_token, - ]) + return "".join( + [ + *tiles_placeholder, + "\n", + global_img_placeholder, + fake_image_token, + ] + ) def get_num_image_tokens( self, @@ -278,9 +292,7 @@ def get_image_size_with_most_features(self) -> ImageSize: ) -class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] - ): - +class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -298,22 +310,21 @@ def get_dummy_mm_data( num_images = mm_counts.get("image", 0) hf_processor = self.info.get_hf_processor() image_processor: Idefics3ImageProcessor = hf_processor.image_processor - longest_edge = image_processor.max_image_size['longest_edge'] + longest_edge = image_processor.max_image_size["longest_edge"] image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=longest_edge, - height=longest_edge, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=longest_edge, + height=longest_edge, + num_images=num_images, + overrides=image_overrides, + ) } -class Idefics3MultiModalProcessor( - BaseMultiModalProcessor[Idefics3ProcessingInfo]): - +class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -334,9 +345,11 @@ def _call_hf_processor( tok_kwargs, ) - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": images - }).get_items("image", ImageProcessorItems)) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) image_sizes = [ parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] @@ -347,7 +360,8 @@ def _call_hf_processor( image_width=size.width, image_height=size.height, processor=hf_processor, - ) for size in image_sizes + ) + for size in image_sizes ] processed_outputs["num_patches"] = torch.tensor(num_patches) @@ -365,10 +379,10 @@ def _get_mm_fields_config( num_patches = hf_inputs.get("num_patches", torch.empty(0)) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes( - "image", 
num_patches), + "image", num_patches + ), image_embeds=MultiModalFieldConfig.batched("image"), num_patches=MultiModalFieldConfig.batched("image"), ) @@ -408,7 +422,6 @@ def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails: class Idefics3SimpleMLP(nn.Module): - def __init__( self, config: Idefics3Config, @@ -416,8 +429,7 @@ def __init__( prefix: str = "", ): super().__init__() - input_size = config.vision_config.hidden_size * (config.scale_factor** - 2) + input_size = config.vision_config.hidden_size * (config.scale_factor**2) output_size = config.text_config.hidden_size self.proj = ReplicatedLinear( input_size, @@ -433,7 +445,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Idefics3Connector(nn.Module): - def __init__( self, config: Idefics3Config, @@ -448,14 +459,11 @@ def __init__( prefix=maybe_prefix(prefix, "modality_projection"), ) - def pixel_shuffle(self, - x: torch.Tensor, - scale_factor: int = 2) -> torch.Tensor: + def pixel_shuffle(self, x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor: bsz, seq, embed_dim = x.size() height = width = int(seq**0.5) x = x.view(bsz, height, width, embed_dim) - x = x.view(bsz, height, int(width / scale_factor), - embed_dim * scale_factor) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) x = x.permute(0, 2, 1, 3) x = x.reshape( bsz, @@ -464,19 +472,16 @@ def pixel_shuffle(self, embed_dim * (scale_factor**2), ) x = x.permute(0, 2, 1, 3) - x = x.reshape(bsz, int(seq / (scale_factor**2)), - embed_dim * (scale_factor**2)) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) return x def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor: - image_hidden_states = self.pixel_shuffle(image_hidden_states, - self.scale_factor) + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) image_hidden_states = self.modality_projection(image_hidden_states) return image_hidden_states class Idefics3Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -488,7 +493,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vision_model = Idefics3VisionTransformer( config.vision_config, quant_config=quant_config, - prefix=maybe_prefix(prefix, "vision_model")) + prefix=maybe_prefix(prefix, "vision_model"), + ) self.connector = Idefics3Connector( config, quant_config, @@ -500,8 +506,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.image_seq_len = int( - ((config.vision_config.image_size // - config.vision_config.patch_size)**2) / (config.scale_factor**2)) + ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) + / (config.scale_factor**2) + ) self.image_token_id = self.config.image_token_id def image_pixels_to_features( @@ -518,21 +525,21 @@ def image_pixels_to_features( # Remove padding images - padding images are full 0. 
nb_values_per_image = pixel_values.shape[1:].numel() real_images_inds = (pixel_values == 0.0).sum( - dim=(-1, -2, -3)) != nb_values_per_image + dim=(-1, -2, -3) + ) != nb_values_per_image pixel_values = pixel_values[real_images_inds].contiguous() # Handle the vision attention mask # Remove padding images from the mask - pixel_attention_mask = pixel_attention_mask[ - real_images_inds].contiguous() + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() patch_size = self.config.vision_config.patch_size - patches_subgrid = pixel_attention_mask.unfold(dimension=1, - size=patch_size, - step=patch_size) - patches_subgrid = patches_subgrid.unfold(dimension=2, - size=patch_size, - step=patch_size) + patches_subgrid = pixel_attention_mask.unfold( + dimension=1, size=patch_size, step=patch_size + ) + patches_subgrid = patches_subgrid.unfold( + dimension=2, size=patch_size, step=patch_size + ) patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() # Get sequence from the vision encoder @@ -553,7 +560,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.text_model( input_ids, positions, @@ -566,9 +572,11 @@ def forward( @MULTIMODAL_REGISTRY.register_processor( Idefics3MultiModalProcessor, info=Idefics3ProcessingInfo, - dummy_inputs=Idefics3DummyInputsBuilder) -class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA): + dummy_inputs=Idefics3DummyInputsBuilder, +) +class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -598,8 +606,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.multimodal_config = multimodal_config - self.model = Idefics3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Idefics3Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.image_token_id = self.config.image_token_id self.lm_head = ParallelLMHead( @@ -613,7 +622,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.text_config.vocab_size) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[ImageInputs]: + self, **kwargs: object + ) -> Optional[ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -621,47 +631,27 @@ def _parse_and_validate_image_input( return None if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - return Idefics3ImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + data=image_embeds, ) if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - pixel_attention_mask = kwargs.pop("pixel_attention_mask") - if not isinstance(pixel_attention_mask, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel_attention_mask. " - f"Got type: {type(pixel_attention_mask)}") - num_patches = kwargs.pop("num_patches") - if not isinstance(num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_patches. 
" - f"Got type: {type(num_patches)}") - expected_h = expected_w = self.config.vision_config.image_size + return Idefics3ImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values, concat=True), - pixel_attention_mask=flatten_bn(pixel_attention_mask, - concat=True), - num_patches=flatten_bn(num_patches, concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w - }, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + num_patches=num_patches, + resolve_bindings={"h": expected_h, "w": expected_w}, ) raise AssertionError("This line should be unreachable.") - def _process_image_pixels( - self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: + def _process_image_pixels(self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor: pixel_values = inputs["pixel_values"] pixel_attention_mask = inputs["pixel_attention_mask"] @@ -681,15 +671,12 @@ def _process_image_input( image_features = self.model.connector(image_features) num_patches = image_input["num_patches"] - return [ - e.flatten(0, 1) for e in image_features.split(num_patches.tolist()) - ] + return [e.flatten(0, 1) for e in image_features.split(num_patches.tolist())] def get_language_model(self) -> torch.nn.Module: return self.model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -707,10 +694,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.model.text_model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.model.text_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -718,8 +704,7 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) @@ -730,4 +715,5 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="model.text_model", connector="model.connector", - tower_model="model.vision_model") + tower_model="model.vision_model", + ) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index c95c63cd8534..38c9d5abb587 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -2,8 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, MutableSequence -from typing import (TYPE_CHECKING, Callable, ClassVar, Literal, Optional, - Protocol, Union, overload, runtime_checkable) +from typing import ( + TYPE_CHECKING, + Callable, + ClassVar, + Literal, + Optional, + Protocol, + Union, + overload, + runtime_checkable, +) import numpy as np import torch @@ -76,10 +85,9 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: """ ... 
- def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: """ - Returns multimodal embeddings generated from multimodal kwargs + Returns multimodal embeddings generated from multimodal kwargs to be merged with text embeddings. Note: @@ -93,7 +101,7 @@ def get_language_model(self) -> VllmModel: """ Returns the underlying language model used for text generation. - This is typically the `torch.nn.Module` instance responsible for + This is typically the `torch.nn.Module` instance responsible for processing the merged multimodal embeddings and producing hidden states Returns: @@ -102,8 +110,7 @@ def get_language_model(self) -> VllmModel: ... @overload - def get_input_embeddings(self, input_ids: Tensor) -> Tensor: - ... + def get_input_embeddings(self, input_ids: Tensor) -> Tensor: ... @overload def get_input_embeddings( @@ -113,8 +120,7 @@ def get_input_embeddings( *, is_multimodal: torch.Tensor, handle_oov_mm_token: bool = False, - ) -> Tensor: - ... + ) -> Tensor: ... def _get_text_embeddings( self, @@ -172,7 +178,8 @@ def get_input_embeddings( raise ValueError( "`get_input_embeddings` now requires `is_multimodal` arg, " "please update your model runner according to " - "https://github.com/vllm-project/vllm/pull/16229.") + "https://github.com/vllm-project/vllm/pull/16229." + ) return _merge_multimodal_embeddings( inputs_embeds=inputs_embeds, @@ -187,12 +194,15 @@ class SupportsMultiModalPruning(Protocol): embeddings and positions. Model may require custom positions for dynamic pruning of multimodal embeddings. """ + supports_multimodal_pruning: ClassVar[Literal[True]] = True def recompute_mrope_positions( - self, input_ids: list[int], - multimodal_embeddings: MultiModalEmbeddings, - mrope_positions: torch.LongTensor, num_computed_tokens: int + self, + input_ids: list[int], + multimodal_embeddings: MultiModalEmbeddings, + mrope_positions: torch.LongTensor, + num_computed_tokens: int, ) -> tuple[MultiModalEmbeddings, Tensor, int]: """ Update part of input mrope positions (starting with @@ -218,14 +228,11 @@ def recompute_mrope_positions( @overload -def supports_multimodal( - model: type[object]) -> TypeIs[type[SupportsMultiModal]]: - ... +def supports_multimodal(model: type[object]) -> TypeIs[type[SupportsMultiModal]]: ... @overload -def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: - ... +def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: ... def supports_multimodal( @@ -234,32 +241,27 @@ def supports_multimodal( return getattr(model, "supports_multimodal", False) -def supports_multimodal_raw_input_only( - model: Union[type[object], object]) -> bool: +def supports_multimodal_raw_input_only(model: Union[type[object], object]) -> bool: return getattr(model, "supports_multimodal_raw_input_only", False) -def supports_multimodal_encoder_tp_data( - model: Union[type[object], object]) -> bool: +def supports_multimodal_encoder_tp_data(model: Union[type[object], object]) -> bool: return getattr(model, "supports_encoder_tp_data", False) @overload def supports_multimodal_pruning( - model: type[object]) -> TypeIs[type[SupportsMultiModalPruning]]: - ... + model: type[object], +) -> TypeIs[type[SupportsMultiModalPruning]]: ... @overload -def supports_multimodal_pruning( - model: object) -> TypeIs[SupportsMultiModalPruning]: - ... +def supports_multimodal_pruning(model: object) -> TypeIs[SupportsMultiModalPruning]: ... 
def supports_multimodal_pruning( model: Union[type[object], object], -) -> Union[TypeIs[type[SupportsMultiModalPruning]], - TypeIs[SupportsMultiModalPruning]]: +) -> Union[TypeIs[type[SupportsMultiModalPruning]], TypeIs[SupportsMultiModalPruning]]: return getattr(model, "supports_multimodal_pruning", False) @@ -280,7 +282,7 @@ class SupportsScoreTemplate(Protocol): def get_score_template(cls, query: str, document: str) -> Optional[str]: """ Generate a full prompt by populating the score template with query and document content. - """ # noqa: E501 + """ # noqa: E501 ... @classmethod @@ -293,13 +295,12 @@ def post_process_tokens(cls, prompt: TokensPrompt) -> None: @overload def supports_score_template( - model: type[object]) -> TypeIs[type[SupportsScoreTemplate]]: - ... + model: type[object], +) -> TypeIs[type[SupportsScoreTemplate]]: ... @overload -def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]: - ... +def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]: ... def supports_score_template( @@ -339,13 +340,11 @@ class _SupportsLoRAType(Protocol): @overload -def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]: - ... +def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]: ... @overload -def supports_lora(model: object) -> TypeIs[SupportsLoRA]: - ... +def supports_lora(model: object) -> TypeIs[SupportsLoRA]: ... def supports_lora( @@ -359,8 +358,7 @@ def supports_lora( "embedding_modules", "embedding_padding_modules", ) - missing_attrs = tuple(attr for attr in lora_attrs - if not hasattr(model, attr)) + missing_attrs = tuple(attr for attr in lora_attrs if not hasattr(model, attr)) if getattr(model, "supports_lora", False): if missing_attrs: @@ -374,7 +372,9 @@ def supports_lora( if not missing_attrs: logger.warning( "The model (%s) contains all LoRA-specific attributes, " - "but does not set `supports_lora=True`.", model) + "but does not set `supports_lora=True`.", + model, + ) return result @@ -434,25 +434,21 @@ def make_empty_intermediate_tensors( batch_size: int, dtype: torch.dtype, device: torch.device, - ) -> "IntermediateTensors": - ... + ) -> "IntermediateTensors": ... def forward( self, *, intermediate_tensors: Optional["IntermediateTensors"], - ) -> Union[Tensor, "IntermediateTensors"]: - ... + ) -> Union[Tensor, "IntermediateTensors"]: ... @overload -def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]: - ... +def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]: ... @overload -def supports_pp(model: object) -> TypeIs[SupportsPP]: - ... +def supports_pp(model: object) -> TypeIs[SupportsPP]: ... 
def supports_pp( @@ -464,12 +460,13 @@ def supports_pp( if supports_attributes and not supports_inspect: logger.warning( "The model (%s) sets `supports_pp=True`, but does not accept " - "`intermediate_tensors` in its `forward` method", model) + "`intermediate_tensors` in its `forward` method", + model, + ) if not supports_attributes: - pp_attrs = ("make_empty_intermediate_tensors", ) - missing_attrs = tuple(attr for attr in pp_attrs - if not hasattr(model, attr)) + pp_attrs = ("make_empty_intermediate_tensors",) + missing_attrs = tuple(attr for attr in pp_attrs if not hasattr(model, attr)) if getattr(model, "supports_pp", False): if missing_attrs: @@ -483,7 +480,9 @@ def supports_pp( if not missing_attrs: logger.warning( "The model (%s) contains all PP-specific attributes, " - "but does not set `supports_pp=True`.", model) + "but does not set `supports_pp=True`.", + model, + ) return supports_attributes and supports_inspect @@ -516,17 +515,15 @@ class HasInnerState(Protocol): @overload -def has_inner_state(model: object) -> TypeIs[HasInnerState]: - ... +def has_inner_state(model: object) -> TypeIs[HasInnerState]: ... @overload -def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: - ... +def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: ... def has_inner_state( - model: Union[type[object], object] + model: Union[type[object], object], ) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]: return getattr(model, "has_inner_state", False) @@ -545,17 +542,15 @@ class IsAttentionFree(Protocol): @overload -def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: - ... +def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: ... @overload -def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: - ... +def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: ... def is_attention_free( - model: Union[type[object], object] + model: Union[type[object], object], ) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]: return getattr(model, "is_attention_free", False) @@ -563,7 +558,7 @@ def is_attention_free( @runtime_checkable class IsHybrid(Protocol): """The interface required for all models like Jamba that have both - attention and mamba blocks, indicates that + attention and mamba blocks, indicates that hf_config has 'layers_block_type'""" is_hybrid: ClassVar[Literal[True]] = True @@ -593,17 +588,15 @@ def get_mamba_state_shape_from_config( @overload -def is_hybrid(model: object) -> TypeIs[IsHybrid]: - ... +def is_hybrid(model: object) -> TypeIs[IsHybrid]: ... @overload -def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: - ... +def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: ... def is_hybrid( - model: Union[type[object], object] + model: Union[type[object], object], ) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]: return getattr(model, "is_hybrid", False) @@ -654,7 +647,7 @@ def set_eplb_state( ) -> None: """ Register the EPLB state in the MoE model. - + Since these are views of the actual EPLB state, any changes made by the EPLB algorithm are automatically reflected in the model's behavior without requiring additional method calls to set new states. @@ -674,8 +667,7 @@ def update_physical_experts_metadata( self, num_physical_experts: int, num_local_physical_experts: int, - ) -> None: - ... + ) -> None: ... 
def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: @@ -688,17 +680,15 @@ class HasNoOps(Protocol): @overload -def has_noops(model: object) -> TypeIs[HasNoOps]: - ... +def has_noops(model: object) -> TypeIs[HasNoOps]: ... @overload -def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: - ... +def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: ... def has_noops( - model: Union[type[object], object] + model: Union[type[object], object], ) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]: return getattr(model, "has_noops", False) @@ -712,13 +702,12 @@ class SupportsCrossEncoding(Protocol): @overload def supports_cross_encoding( - model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]: - ... + model: type[object], +) -> TypeIs[type[SupportsCrossEncoding]]: ... @overload -def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: - ... +def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: ... def _supports_cross_encoding( @@ -746,7 +735,6 @@ def __new__(cls, *args, **kwargs) -> Self: # find config passed in arguments quant_config = cls._find_quant_config(*args, **kwargs) if quant_config is not None: - # attach config to model for general use instance.quant_config = quant_config @@ -755,7 +743,8 @@ def __new__(cls, *args, **kwargs) -> Self: instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper) if instance.packed_modules_mapping is not None: instance.quant_config.packed_modules_mapping.update( - instance.packed_modules_mapping) + instance.packed_modules_mapping + ) return instance @@ -778,6 +767,7 @@ def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]: @runtime_checkable class SupportsTranscription(Protocol): """The interface required for all models that support transcription.""" + # Mapping from ISO639_1 language codes: language names supported_languages: ClassVar[Mapping[str, str]] @@ -798,16 +788,20 @@ def __init_subclass__(cls, **kwargs): raise ValueError( f"{cls.__name__}.supported_languages contains invalid " f"language codes: {sorted(invalid)}\n. " - f"Valid choices are: {sorted(LANGUAGES.keys())}") + f"Valid choices are: {sorted(LANGUAGES.keys())}" + ) @classmethod - def get_generation_prompt(cls, audio: np.ndarray, - stt_config: SpeechToTextConfig, - model_config: ModelConfig, - language: Optional[str], - task_type: Literal["transcribe", "translate"], - request_prompt: str, - to_language: Optional[str]) -> PromptType: + def get_generation_prompt( + cls, + audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: """Get the prompt for the ASR model. The model has control over the construction, as long as it returns a valid PromptType.""" @@ -816,17 +810,14 @@ def get_generation_prompt(cls, audio: np.ndarray, @classmethod def get_other_languages(cls) -> Mapping[str, str]: # other possible language codes from the whisper map - return { - k: v - for k, v in LANGUAGES.items() if k not in cls.supported_languages - } + return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages} @classmethod def validate_language(cls, language: Optional[str]) -> Optional[str]: """ - Ensure the language specified in the transcription request - is a valid ISO 639-1 language code. 
If the request language is - valid, but not natively supported by the model, trigger a + Ensure the language specified in the transcription request + is a valid ISO 639-1 language code. If the request language is + valid, but not natively supported by the model, trigger a warning (but not an exception). """ if language is None or language in cls.supported_languages: @@ -843,22 +834,25 @@ def validate_language(cls, language: Optional[str]) -> Optional[str]: else: raise ValueError( f"Unsupported language: {language!r}. Must be one of " - f"{list(cls.supported_languages.keys())}.") + f"{list(cls.supported_languages.keys())}." + ) @classmethod def get_speech_to_text_config( - cls, model_config: ModelConfig, - task_type: Literal["transcribe", - "translate"]) -> SpeechToTextConfig: + cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"] + ) -> SpeechToTextConfig: """Get the speech to text config for the ASR model.""" ... @classmethod - def get_num_audio_tokens(cls, audio_duration_s: float, - stt_config: SpeechToTextConfig, - model_config: ModelConfig) -> Optional[int]: + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: """ - Map from audio duration to number of audio tokens produced by the ASR + Map from audio duration to number of audio tokens produced by the ASR model, without running a forward pass. This is used for estimating the amount of processing for this audio. """ @@ -867,13 +861,12 @@ def get_num_audio_tokens(cls, audio_duration_s: float, @overload def supports_transcription( - model: type[object]) -> TypeIs[type[SupportsTranscription]]: - ... + model: type[object], +) -> TypeIs[type[SupportsTranscription]]: ... @overload -def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: - ... +def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: ... def supports_transcription( @@ -890,13 +883,11 @@ class SupportsV0Only(Protocol): @overload -def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]: - ... +def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]: ... @overload -def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: - ... +def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: ... def supports_v0_only( @@ -907,7 +898,7 @@ def supports_v0_only( @runtime_checkable class SupportsEagle3(Protocol): - """The interface required for models that support + """The interface required for models that support EAGLE3 speculative decoding.""" supports_eagle3: ClassVar[Literal[True]] = True @@ -924,7 +915,7 @@ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: """ Set which layers should output auxiliary hidden states for EAGLE3. - + Args: layers: Tuple of layer indices that should output auxiliary hidden states. @@ -935,7 +926,7 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: """ Get the layer indices that should output auxiliary hidden states for EAGLE3. - + Returns: Tuple of layer indices for auxiliary hidden state outputs. """ @@ -943,13 +934,11 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: @overload -def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]: - ... +def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]: ... @overload -def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]: - ... +def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]: ... 
def supports_eagle3( @@ -985,10 +974,10 @@ def get_mrope_input_positions( ) -> tuple[torch.Tensor, int]: """ Get M-RoPE input positions and delta value for this specific model. - + This method should be implemented by each model that supports M-RoPE to provide model-specific logic for computing input positions. - + Args: input_tokens: List of input token IDs hf_config: HuggingFace model configuration @@ -999,7 +988,7 @@ def get_mrope_input_positions( seq_len: Sequence length audio_feature_lengths: Audio feature lengths for multimodal models use_audio_in_video: Whether to use audio in video for interleaving - + Returns: Tuple of (llm_positions, mrope_position_delta) - llm_positions: Tensor of shape [3, num_tokens] @@ -1010,13 +999,11 @@ def get_mrope_input_positions( @overload -def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]: - ... +def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]: ... @overload -def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]: - ... +def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]: ... def supports_mrope( diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 84146db0943c..b697eb25b5cc 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,7 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol, - Union, overload, runtime_checkable) +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + Optional, + Protocol, + Union, + overload, + runtime_checkable, +) import torch import torch.nn as nn @@ -38,8 +47,7 @@ def __init__( self, vllm_config: VllmConfig, prefix: str = "", - ) -> None: - ... + ) -> None: ... def get_input_embeddings( self, @@ -52,8 +60,7 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, - ) -> T_co: - ... + ) -> T_co: ... def _check_vllm_model_init(model: Union[type[object], object]) -> bool: @@ -61,8 +68,7 @@ def _check_vllm_model_init(model: Union[type[object], object]) -> bool: return supports_kw(model_init, "vllm_config") -def _check_vllm_model_get_input_embeddings( - model: Union[type[object], object]) -> bool: +def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) -> bool: model_get_input_embeddings = getattr(model, "get_input_embeddings", None) if not callable(model_get_input_embeddings): logger.warning( @@ -80,11 +86,9 @@ def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: return False vllm_kws = ("input_ids", "positions") - missing_kws = tuple(kw for kw in vllm_kws - if not supports_kw(model_forward, kw)) + missing_kws = tuple(kw for kw in vllm_kws if not supports_kw(model_forward, kw)) - if missing_kws and (isinstance(model, type) - and issubclass(model, nn.Module)): + if missing_kws and (isinstance(model, type) and issubclass(model, nn.Module)): logger.warning( "The model (%s) is missing " "vLLM-specific keywords from its `forward` method: %s", @@ -96,21 +100,21 @@ def _check_vllm_model_forward(model: Union[type[object], object]) -> bool: @overload -def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: - ... +def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: ... @overload -def is_vllm_model(model: object) -> TypeIs[VllmModel]: - ... +def is_vllm_model(model: object) -> TypeIs[VllmModel]: ... 
def is_vllm_model( model: Union[type[object], object], ) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]: - return (_check_vllm_model_init(model) - and _check_vllm_model_get_input_embeddings(model) - and _check_vllm_model_forward(model)) + return ( + _check_vllm_model_init(model) + and _check_vllm_model_get_input_embeddings(model) + and _check_vllm_model_forward(model) + ) @runtime_checkable @@ -127,20 +131,19 @@ def compute_logits( @overload def is_text_generation_model( - model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]: - ... + model: type[object], +) -> TypeIs[type[VllmModelForTextGeneration]]: ... @overload -def is_text_generation_model( - model: object) -> TypeIs[VllmModelForTextGeneration]: - ... +def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration]: ... def is_text_generation_model( model: Union[type[object], object], -) -> Union[TypeIs[type[VllmModelForTextGeneration]], - TypeIs[VllmModelForTextGeneration]]: +) -> Union[ + TypeIs[type[VllmModelForTextGeneration]], TypeIs[VllmModelForTextGeneration] +]: if not is_vllm_model(model): return False @@ -179,13 +182,11 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): @overload -def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: - ... +def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: ... @overload -def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: - ... +def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ... def is_pooling_model( diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 2c341d283971..9435ff0d26cf 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -17,28 +17,32 @@ from transformers import PretrainedConfig from vllm.attention.layer import MultiHeadAttention -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from .vision import run_dp_sharded_vision_model NORM2FN = { - 'rms_norm': RMSNorm, - 'layer_norm': nn.LayerNorm, + "rms_norm": RMSNorm, + "layer_norm": nn.LayerNorm, } class InternVisionEmbeddings(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config @@ -48,28 +52,36 @@ def __init__(self, config: PretrainedConfig): self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) - self.patch_embedding = nn.Conv2d(in_channels=3, - out_channels=self.embed_dim, - kernel_size=self.patch_size, - stride=self.patch_size) + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) - self.num_patches = (self.image_size // self.patch_size)**2 + 
self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches + 1 self.position_embedding = nn.Parameter( - torch.randn(1, self.num_positions, self.embed_dim)) + torch.randn(1, self.num_positions, self.embed_dim) + ) def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int): target_dtype = pos_embed.dtype - pos_embed = pos_embed.float().reshape( - 1, self.image_size // self.patch_size, - self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) - pos_embed = F.interpolate(pos_embed, - size=(H, W), - mode='bicubic', - align_corners=False) - return pos_embed.reshape(1, -1, H * W).permute(0, 2, - 1).to(target_dtype) + pos_embed = ( + pos_embed.float() + .reshape( + 1, + self.image_size // self.patch_size, + self.image_size // self.patch_size, + -1, + ) + .permute(0, 3, 1, 2) + ) + pos_embed = F.interpolate( + pos_embed, size=(H, W), mode="bicubic", align_corners=False + ) + return pos_embed.reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype) def _get_position_embedding(self, H: int, W: int) -> torch.Tensor: position_embedding = self.position_embedding @@ -86,12 +98,12 @@ def _get_position_embedding(self, H: int, W: int) -> torch.Tensor: def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to( - target_dtype)) # shape = [*, channel, width, height] + patch_embeds = self.patch_embedding( + pixel_values.to(target_dtype) + ) # shape = [*, channel, width, height] batch_size, _, height, width = patch_embeds.shape patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - class_embeds = self.class_embedding.expand(batch_size, 1, - -1).to(target_dtype) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) position_embedding = self._get_position_embedding(height, width) embeddings = embeddings + position_embedding.to(target_dtype) @@ -99,7 +111,6 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: class InternVisionPatchModel(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config @@ -114,8 +125,7 @@ def forward( pixel_embeds: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if pixel_values is None and pixel_embeds is None: - raise ValueError( - 'You have to specify pixel_values or pixel_embeds') + raise ValueError("You have to specify pixel_values or pixel_embeds") if pixel_embeds is not None: hidden_states = pixel_embeds @@ -123,8 +133,7 @@ def forward( if pixel_values.ndim == 4: hidden_states = self.embeddings(pixel_values) else: - raise ValueError( - f'wrong pixel_values size: {pixel_values.shape}') + raise ValueError(f"wrong pixel_values size: {pixel_values.shape}") return hidden_states @@ -149,19 +158,21 @@ def __init__( self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( - f'embed_dim must be divisible by num_heads ' - f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' - f' {self.num_heads}).') + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) - self.tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) - self.tp_rank = (0 if use_data_parallel else - get_tensor_model_parallel_rank()) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() # Additional dummy heads are used to enable TP for common GPU counts. self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim - self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads, - self.tp_size) + self.num_heads_per_partition = divide( + num_dummy_heads + self.num_heads, self.tp_size + ) self.scale = self.head_dim**-0.5 self.qkv = QKVParallelLinear( @@ -177,12 +188,16 @@ def __init__( self.qk_normalization = config.qk_normalization if self.qk_normalization: - self.q_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) - self.k_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) + self.q_norm = RMSNorm( + self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim, + ) + self.k_norm = RMSNorm( + self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim, + ) self.proj = RowParallelLinear( self.dummy_dim, @@ -192,8 +207,9 @@ def __init__( disable_tp=use_data_parallel, ) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): if self.tp_size > 1: @@ -202,8 +218,7 @@ def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): q = self.q_norm(q) k = self.k_norm(k) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k @@ -222,7 +237,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class InternMLP(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -234,18 +248,22 @@ def __init__( self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1", - disable_tp=use_data_parallel) - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2", - disable_tp=use_data_parallel) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -256,7 +274,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class InternVisionEncoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -272,25 +289,25 @@ def __init__( self.intermediate_size = config.intermediate_size self.norm_type = config.norm_type - self.attn = self._init_attn(config, - quant_config, - num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.attn", - 
use_data_parallel=use_data_parallel) - - self.mlp = InternMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) - self.norm1 = NORM2FN[self.norm_type](self.embed_dim, - eps=config.layer_norm_eps) - self.norm2 = NORM2FN[self.norm_type](self.embed_dim, - eps=config.layer_norm_eps) - - self.ls1 = nn.Parameter(config.initializer_factor * - torch.ones(self.embed_dim)) - self.ls2 = nn.Parameter(config.initializer_factor * - torch.ones(self.embed_dim)) + self.attn = self._init_attn( + config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, + ) + + self.mlp = InternMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) def _init_attn( self, @@ -302,35 +319,34 @@ def _init_attn( use_data_parallel: bool = False, ): # fallback to sdpa attention if tp unavailable - tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() num_heads = config.num_attention_heads # if the number of heads is not divisible by tp_size, # we also disable Attention's TP - use_data_parallel = (use_data_parallel - or (num_heads + num_dummy_heads) % tp_size != 0) - return InternParallelAttention(config, - quant_config=quant_config, - num_dummy_heads=num_dummy_heads, - prefix=prefix, - use_data_parallel=use_data_parallel) + use_data_parallel = ( + use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0 + ) + return InternParallelAttention( + config, + quant_config=quant_config, + num_dummy_heads=num_dummy_heads, + prefix=prefix, + use_data_parallel=use_data_parallel, + ) def forward( self, hidden_states: torch.Tensor, ): - hidden_states = hidden_states + self.attn( - self.norm1(hidden_states)) * self.ls1 + hidden_states = hidden_states + self.attn(self.norm1(hidden_states)) * self.ls1 - hidden_states = hidden_states + self.mlp( - self.norm2(hidden_states)) * self.ls2 + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * self.ls2 return hidden_states class InternVisionEncoder(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -350,17 +366,20 @@ def __init__( else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - InternVisionEncoderLayer(config, - quant_config, - num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.layers.{layer_idx}", - use_data_parallel=use_data_parallel) - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + InternVisionEncoderLayer( + config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward(self, inputs_embeds: torch.Tensor): - hidden_states = inputs_embeds for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states) @@ -369,7 +388,6 @@ def forward(self, inputs_embeds: torch.Tensor): class InternVisionModel(nn.Module): - packed_modules_mapping = { "qkv": ["qkv"], } @@ -408,8 +426,7 @@ def forward( pixel_embeds: Optional[torch.Tensor] = None, ) -> 
torch.FloatTensor: if pixel_values is None and pixel_embeds is None: - raise ValueError( - 'You have to specify pixel_values or pixel_embeds') + raise ValueError("You have to specify pixel_values or pixel_embeds") if pixel_embeds is not None: hidden_states = pixel_embeds @@ -417,25 +434,21 @@ def forward( if pixel_values.ndim == 4: hidden_states = self.embeddings(pixel_values) else: - raise ValueError( - f'wrong pixel_values size: {pixel_values.shape}') + raise ValueError(f"wrong pixel_values size: {pixel_values.shape}") if self.use_data_parallel: - encoder_outputs = run_dp_sharded_vision_model( - hidden_states, self.encoder) + encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder) else: encoder_outputs = self.encoder(inputs_embeds=hidden_states) return encoder_outputs - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 221ff08b4384..128791541b3d 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -13,33 +13,42 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .interfaces_base import default_pooling_type -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class InternLM2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -64,8 +73,9 @@ def __init__( prefix=f"{prefix}.w2", ) if hidden_act != "silu": - raise ValueError(f"Unsupported 
activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -76,7 +86,6 @@ def forward(self, x): class InternLM2Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -157,16 +166,16 @@ def split_qkv(self, qkv: torch.Tensor): qkv = qkv[::3] + qkv[1::3] + qkv[2::3] qkv = torch.cat(qkv, dim=-1) - qkv = qkv.view(seq_len, self.total_num_kv_heads, - self.key_value_groups + 2, self.head_dim) + qkv = qkv.view( + seq_len, self.total_num_kv_heads, self.key_value_groups + 2, self.head_dim + ) q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2) q = q.reshape(seq_len, self.q_size * self.tp_size) k = k.reshape(seq_len, self.kv_size * self.tp_size) v = v.reshape(seq_len, self.kv_size * self.tp_size) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] v = splitter(v)[self.tp_rank] @@ -186,7 +195,6 @@ def forward( class InternLMDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -198,8 +206,7 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.attention = InternLM2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -218,8 +225,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.feed_forward", ) - self.attention_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -233,8 +239,7 @@ def forward( residual = hidden_states hidden_states = self.attention_norm(hidden_states) else: - hidden_states, residual = self.attention_norm( - hidden_states, residual) + hidden_states, residual = self.attention_norm(hidden_states, residual) hidden_states = self.attention( positions=positions, hidden_states=hidden_states, @@ -248,13 +253,13 @@ def forward( @support_torch_compile class InternLM2Model(nn.Module): - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer): + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config @@ -270,12 +275,14 @@ def __init__( self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: layer_type( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return 
self.tok_embeddings(input_ids) @@ -300,10 +307,9 @@ def forward( for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -314,11 +320,13 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): "gate_up_proj": ["w1", "w3"], } - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - model_type: type[InternLM2Model] = InternLM2Model): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + model_type: type[InternLM2Model] = InternLM2Model, + ): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -328,17 +336,21 @@ def __init__(self, self.quant_config = quant_config self.lora_config = lora_config - self.model = model_type(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.output = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "output")) + self.model = model_type( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.output = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "output"), + ) if self.config.tie_word_embeddings: self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -350,8 +362,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -361,8 +374,7 @@ def compute_logits( logits = self.logits_processor(self.output, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w1", 0), @@ -373,7 +385,7 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -393,8 +405,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -402,7 +413,6 @@ def load_weights(self, weights: Iterable[tuple[str, 
@default_pooling_type("ALL") class InternLM2ForRewardModel(InternLM2ForCausalLM): - is_pooling_model = True def __init__( @@ -412,9 +422,7 @@ def __init__( prefix: str = "", model_type: type[InternLM2Model] = InternLM2Model, ): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - model_type=model_type) + super().__init__(vllm_config=vllm_config, prefix=prefix, model_type=model_type) for attr in ("output", "logits_processor"): delattr(self, attr) @@ -422,19 +430,22 @@ def __init__( config = vllm_config.model_config.hf_config self.head_dtype = vllm_config.model_config.head_dtype - self.v_head = RowParallelLinear(config.hidden_size, - 1, - bias=False, - input_is_parallel=False, - params_dtype=self.head_dtype, - prefix=maybe_prefix(prefix, "v_head"), - return_bias=False) + self.v_head = RowParallelLinear( + config.hidden_size, + 1, + bias=False, + input_is_parallel=False, + params_dtype=self.head_dtype, + prefix=maybe_prefix(prefix, "v_head"), + return_bias=False, + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, ) + {"encode": Pooler.for_encode(pooler_config)}, + ) def forward( self, @@ -443,8 +454,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) hidden_states = hidden_states.to(self.head_dtype) logits = self.v_head(hidden_states) return logits diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index d41ac2b70bc6..5344ded280b2 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -12,14 +12,16 @@ from vllm.distributed import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.internlm2 import (InternLM2Attention, - InternLM2ForCausalLM, - InternLM2MLP, InternLM2Model) +from vllm.model_executor.models.internlm2 import ( + InternLM2Attention, + InternLM2ForCausalLM, + InternLM2MLP, + InternLM2Model, +) from vllm.sequence import IntermediateTensors class InternLM2VEDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -31,8 +33,7 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.attention = InternLM2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -58,8 +59,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.feed_forward_ve", ) - self.attention_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -74,8 +74,7 @@ def forward( residual = hidden_states hidden_states = self.attention_norm(hidden_states) else: - hidden_states, residual = self.attention_norm( - hidden_states, residual) + hidden_states, residual = self.attention_norm(hidden_states, 
residual) hidden_states = self.attention( positions=positions, hidden_states=hidden_states, @@ -84,27 +83,25 @@ def forward( # Fully Connected hidden_states, residual = self.ffn_norm(hidden_states, residual) if visual_token_mask is not None and visual_token_mask.any(): - visual_token_mask = visual_token_mask.repeat( - 1, self.hidden_size).bool() + visual_token_mask = visual_token_mask.repeat(1, self.hidden_size).bool() text_token_mask = ~visual_token_mask hidden_states[visual_token_mask] = self.feed_forward_ve( - hidden_states[visual_token_mask].reshape( - -1, self.hidden_size)).flatten() + hidden_states[visual_token_mask].reshape(-1, self.hidden_size) + ).flatten() if text_token_mask.any(): hidden_states[text_token_mask] = self.feed_forward( - hidden_states[text_token_mask].reshape( - -1, self.hidden_size)).flatten() + hidden_states[text_token_mask].reshape(-1, self.hidden_size) + ).flatten() else: hidden_states = self.feed_forward(hidden_states) return hidden_states, residual class InternLM2VEModel(InternLM2Model): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=InternLM2VEDecoderLayer) + super().__init__( + vllm_config=vllm_config, prefix=prefix, layer_type=InternLM2VEDecoderLayer + ) def forward( self, @@ -132,17 +129,15 @@ def forward( visual_token_mask=visual_token_mask, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class InternLM2VEForCausalLM(InternLM2ForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - model_type=InternLM2VEModel) + super().__init__( + vllm_config=vllm_config, prefix=prefix, model_type=InternLM2VEModel + ) diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index 7d82dad34a7a..06c7c8ccd0b5 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -15,9 +15,11 @@ from transformers import BatchFeature, InternVLProcessor, PretrainedConfig from transformers.activations import ACT2FN from transformers.models.got_ocr2.image_processing_got_ocr2_fast import ( - GotOcr2ImageProcessorFast) + GotOcr2ImageProcessorFast, +) from transformers.models.internvl.video_processing_internvl import ( - InternVLVideoProcessor) + InternVLVideoProcessor, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -25,38 +27,57 @@ from vllm.model_executor.models.interns1_vit import InternS1VisionModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + 
BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.processor import ( - cached_video_processor_from_config) +from vllm.transformers_utils.processor import cached_video_processor_from_config from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) class InternS1MultiModalProjector(nn.Module): - def __init__(self, config): super().__init__() - self.layer_norm = nn.LayerNorm(config.vision_config.hidden_size * - int(1 / config.downsample_ratio)**2) + self.layer_norm = nn.LayerNorm( + config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2 + ) self.linear_1 = nn.Linear( - config.vision_config.hidden_size * - int(1 / config.downsample_ratio)**2, - config.text_config.hidden_size) + config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2, + config.text_config.hidden_size, + ) self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = nn.Linear(config.text_config.hidden_size, - config.text_config.hidden_size) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size + ) def forward(self, image_features): hidden_states = self.layer_norm(image_features) @@ -75,6 +96,7 @@ class InternS1ImagePixelInputs(TensorSchema): - w: Width - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -87,13 +109,14 @@ class InternS1ImageEmbeddingInputs(TensorSchema): - tifs: Total image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("ni", "tifs", "hs")] + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], TensorShape("ni", "tifs", "hs") + ] -InternS1ImageInputs = Union[InternS1ImagePixelInputs, - InternS1ImageEmbeddingInputs] +InternS1ImageInputs = Union[InternS1ImagePixelInputs, InternS1ImageEmbeddingInputs] class InternS1VideoPixelInputs(TensorSchema): @@ -105,6 +128,7 @@ class InternS1VideoPixelInputs(TensorSchema): - h: Height - w: Width """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" pixel_values: Annotated[torch.Tensor, TensorShape("bnv", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -117,13 +141,14 @@ class InternS1VideoEmbeddingInputs(TensorSchema): - tvfs: Total video feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["video_embeds"] = "video_embeds" - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("nv", "tvfs", "hs")] + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], TensorShape("nv", "tvfs", "hs") + ] -InternS1VideoInputs = Union[InternS1VideoPixelInputs, - InternS1VideoEmbeddingInputs] +InternS1VideoInputs = Union[InternS1VideoPixelInputs, InternS1VideoEmbeddingInputs] def resolve_interns1_min_max_num( 
@@ -145,10 +170,13 @@ def get_interns1_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: - target_ratios = {(i, j) - for n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) if min_num <= i * j <= max_num} + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } return sorted(target_ratios, key=lambda x: x[0] * x[1]) @@ -158,9 +186,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo): def get_hf_processor(self, **kwargs: object) -> InternVLProcessor: hf_processor = self.ctx.get_hf_processor(InternVLProcessor, **kwargs) hf_processor.video_processor = cached_video_processor_from_config( - self.ctx.model_config, - processor_cls=InternVLVideoProcessor, - **kwargs) + self.ctx.model_config, processor_cls=InternVLVideoProcessor, **kwargs + ) return hf_processor def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: @@ -171,18 +198,19 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - processor: Optional['GotOcr2ImageProcessorFast'] = None, + processor: Optional["GotOcr2ImageProcessorFast"] = None, ) -> int: if processor is None: processor = self.get_hf_processor().image_processor if not isinstance(processor, GotOcr2ImageProcessorFast): - raise ValueError(f'GotOcr2ImageProcessorFast is expected but got ' - f'{type(processor)}') + raise ValueError( + f"GotOcr2ImageProcessorFast is expected but got {type(processor)}" + ) num_image_patches = processor.get_number_of_image_patches( - image_height, image_width, images_kwargs=dict()) - num_image_tokens = self.get_hf_processor( - ).image_seq_length * num_image_patches + image_height, image_width, images_kwargs=dict() + ) + num_image_tokens = self.get_hf_processor().image_seq_length * num_image_patches return num_image_tokens def resolve_target_ratios(self, use_thumbnail: Optional[bool] = None): @@ -197,7 +225,8 @@ def resolve_target_ratios(self, use_thumbnail: Optional[bool] = None): min_dynamic_patch, max_dynamic_patch, dynamic_image_size, - use_thumbnail=use_thumbnail) + use_thumbnail=use_thumbnail, + ) return get_interns1_target_ratios(min_num, max_num) @@ -219,11 +248,11 @@ def get_image_size_with_most_features(self) -> ImageSize: ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) - assert not (largest_feature_size == 0 or largest_feature_pinpoint - is None), ("Cannot have a largest feature size of 0!") + assert not (largest_feature_size == 0 or largest_feature_pinpoint is None), ( + "Cannot have a largest feature size of 0!" 
+ ) return largest_feature_pinpoint @@ -248,15 +277,13 @@ def get_num_frames_with_most_features( processor = self.get_hf_processor() max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = (seq_len - - max_image_tokens) // processor.image_seq_length + max_total_frames = (seq_len - max_image_tokens) // processor.image_seq_length max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) -class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo] - ): +class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]): """DummyInputsBuilder for InternS1-style models.""" def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: @@ -273,10 +300,10 @@ def get_dummy_mm_data( mm_counts: Mapping[str, int], mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -287,23 +314,24 @@ def get_dummy_mm_data( video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "video": - self._get_dummy_videos(width=image_size_w, - height=image_size_h, - num_frames=target_num_frames, - num_videos=num_videos, - overrides=video_overrides), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( + width=image_size_w, + height=image_size_h, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ), } -class InternS1MultiModalProcessor( - BaseMultiModalProcessor[InternS1ProcessingInfo]): - """ Basic image-only MultiModalProcessor for InternS1-style models.""" +class InternS1MultiModalProcessor(BaseMultiModalProcessor[InternS1ProcessingInfo]): + """Basic image-only MultiModalProcessor for InternS1-style models.""" def _call_hf_processor( self, @@ -320,15 +348,14 @@ def _call_hf_processor( hf_processor = self.info.get_hf_processor(**mm_kwargs) tokenizer = hf_processor.tokenizer - video_token_id = tokenizer.encode(hf_processor.video_token, - add_special_tokens=False) + video_token_id = tokenizer.encode( + hf_processor.video_token, add_special_tokens=False + ) assert len(video_token_id) == 1 video_token_id = video_token_id[0] - prompt = re.sub(hf_processor.image_token, "<image_placeholder>", - prompt) - prompt = re.sub(hf_processor.video_token, "<video_placeholder>", - prompt) + prompt = re.sub(hf_processor.image_token, "<image_placeholder>", prompt) + prompt = re.sub(hf_processor.video_token, "<video_placeholder>", prompt) image_outputs = {} if images: @@ -340,13 +367,11 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) - image_pixel_values.append( - processed_outputs.pop("pixel_values")) + image_pixel_values.append(processed_outputs.pop("pixel_values")) input_ids = processed_outputs.pop("input_ids") image_placeholder = tokenizer.batch_decode(input_ids)[0] - prompt = prompt.replace("<image_placeholder>", - image_placeholder, 1) + prompt = 
prompt.replace("<image_placeholder>", image_placeholder, 1) num_patches = [len(item) for item in image_pixel_values] image_outputs = { @@ -365,16 +390,13 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) - video_pixel_values.append( - processed_outputs.pop("pixel_values")) + video_pixel_values.append(processed_outputs.pop("pixel_values")) input_ids = processed_outputs.pop("input_ids") - input_ids[input_ids == - hf_processor.image_token_id] = video_token_id + input_ids[input_ids == hf_processor.image_token_id] = video_token_id video_placeholder = tokenizer.batch_decode(input_ids)[0] - prompt = prompt.replace("<video_placeholder>", - video_placeholder, 1) + prompt = prompt.replace("<video_placeholder>", video_placeholder, 1) num_frames = [len(item) for item in video_pixel_values] video_outputs = { @@ -383,10 +405,8 @@ def _call_hf_processor( "video_token_id": torch.tensor(video_token_id), } - prompt = re.sub("<image_placeholder>", hf_processor.image_token, - prompt) - prompt = re.sub("<video_placeholder>", hf_processor.video_token, - prompt) + prompt = re.sub("<image_placeholder>", hf_processor.image_token, prompt) + prompt = re.sub("<video_placeholder>", hf_processor.video_token, prompt) text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt") return BatchFeature({**text_outputs, **image_outputs, **video_outputs}) @@ -396,7 +416,6 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0)) num_images = len(image_num_patches) @@ -404,12 +423,14 @@ def _get_mm_fields_config( return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_patches), + "image", image_num_patches + ), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_patches), + "video", video_num_patches + ), video_num_patches=MultiModalFieldConfig.batched("video"), video_token_id=MultiModalFieldConfig.shared("video", num_videos), ) @@ -443,7 +464,8 @@ def _get_prompt_updates( def get_replacement_interns1_image(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -453,19 +475,16 @@ def get_replacement_interns1_image(item_idx: int): repl_features = img_context_token * feature_size repl_full = start_image_token + repl_features + end_image_token - return PromptUpdateDetails.select_text(repl_full, - img_context_token) + return PromptUpdateDetails.select_text(repl_full, img_context_token) def get_replacement_interns1_video(item_idx: int): num_patches = video_num_patches[item_idx] repl_features = video_token * hf_processor.image_seq_length - repl_features_with_sep = (start_image_token + repl_features + - end_image_token) + repl_features_with_sep = start_image_token + repl_features + end_image_token # num_patches is equal to num_frames - repl_full = '\n'.join([ - f'Frame{i+1}: {repl_features_with_sep}' - for i in range(num_patches) - ]) + repl_full = "\n".join( + [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)] + ) return 
PromptUpdateDetails.select_text(repl_full, video_token) @@ -486,9 +505,11 @@ def get_replacement_interns1_video(item_idx: int): @MULTIMODAL_REGISTRY.register_processor( InternS1MultiModalProcessor, info=InternS1ProcessingInfo, - dummy_inputs=InternS1DummyInputsBuilder) -class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP, SupportsLoRA): + dummy_inputs=InternS1DummyInputsBuilder, +) +class InternS1ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA +): merge_by_field_config = True # To ensure correct weight loading and mapping. @@ -498,14 +519,15 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal, "model.language_model.": "language_model.model.", "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: # transformers InternVLProcessor uses <IMG_CONTEXT> as the separator # refer to https://github.com/huggingface/transformers/blob/f90de364c2484c7c325bbe05befdcf487bd75b63/src/transformers/models/internvl/processing_internvl.py#L116 if modality.startswith("image"): - return '<IMG_CONTEXT>' + return "<IMG_CONTEXT>" if modality.startswith("video"): return "<video>" @@ -524,7 +546,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: patch_size = config.vision_config.patch_size[0] self.patch_size = patch_size self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.downsample_ratio = config.downsample_ratio self.llm_arch_name = config.text_config.architectures[0] @@ -547,7 +570,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.visual_token_mask = None self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _init_vision_model( self, @@ -573,8 +597,12 @@ def pixel_shuffle(self, x, scale_factor=0.5): x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(n, int(h * scale_factor), int(w * scale_factor), - int(c / (scale_factor * scale_factor))) + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) x = x.permute(0, 2, 1, 3).contiguous() return x @@ -582,18 +610,17 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = self.vision_tower(pixel_values=pixel_values) vit_embeds = vit_embeds[:, 1:, :] - h = w = int(vit_embeds.shape[1]**0.5) + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, - scale_factor=self.downsample_ratio) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, - vit_embeds.shape[-1]) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.multi_modal_projector(vit_embeds) return vit_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[InternS1ImageInputs]: + self, **kwargs: object + ) -> Optional[InternS1ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = 
kwargs.pop("image_embeds", None) @@ -626,7 +653,8 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[InternS1VideoInputs]: + self, **kwargs: object + ) -> Optional[InternS1VideoInputs]: pixel_values_flat_video = kwargs.pop("pixel_values_videos", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("video_embeds", None) @@ -662,8 +690,10 @@ def _process_vision_input( self, image_input: Union[InternS1ImageInputs, InternS1VideoInputs], ) -> tuple[torch.Tensor, ...]: - if (image_input["type"] == "image_embeds" - or image_input["type"] == "video_embeds"): + if ( + image_input["type"] == "image_embeds" + or image_input["type"] == "video_embeds" + ): return image_input["data"] assert self.vision_tower is not None @@ -674,14 +704,12 @@ def _process_vision_input( # Only one image in the current batch if len(num_patches) == 1: - return (image_embeds.view(-1, - self.config.text_config.hidden_size), ) + return (image_embeds.view(-1, self.config.text_config.hidden_size),) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. feature_size = image_embeds.shape[1] - image_embeds = image_embeds.view(-1, - self.config.text_config.hidden_size) + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] @@ -693,14 +721,13 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ( - "pixel_values_videos", ) and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos",) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities @@ -710,9 +737,7 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -743,8 +768,7 @@ def get_input_embeddings( is_multimodal: Optional[torch.Tensor] = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: - if multimodal_embeddings is not None and len( - multimodal_embeddings) > 0: + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) # This is to satisfy the type checker for each overload @@ -766,7 +790,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> IntermediateTensors: - if intermediate_tensors is not None: input_ids = None inputs_embeds = None @@ -787,8 +810,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: 
Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -799,4 +821,5 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="language_model", connector="multi_modal_projector", - tower_model="vision_tower") + tower_model="vision_tower", + ) diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py index eb6b685d03dc..f5965bdf7c9c 100644 --- a/vllm/model_executor/models/interns1_vit.py +++ b/vllm/model_executor/models/interns1_vit.py @@ -18,48 +18,45 @@ from vllm.attention.layer import MultiHeadAttention from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader NORM2FN = { - 'rms_norm': RMSNorm, - 'layer_norm': nn.LayerNorm, + "rms_norm": RMSNorm, + "layer_norm": nn.LayerNorm, } class InternS1VisionPatchEmbeddings(nn.Module): - def __init__(self, config): super().__init__() image_size, patch_size = config.image_size, config.patch_size num_channels, hidden_size = config.num_channels, config.hidden_size - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // - patch_size[0]) - patch_shape = (image_size[0] // patch_size[0], - image_size[1] // patch_size[1]) + num_patches = (image_size[1] // patch_size[1]) * ( + image_size[0] // patch_size[0] + ) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) self.image_size = image_size self.patch_size = patch_size self.num_channels = num_channels self.num_patches = num_patches self.patch_shape = patch_shape - self.projection = nn.Conv2d(num_channels, - hidden_size, - kernel_size=patch_size, - stride=patch_size) + self.projection = nn.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = pixel_values.shape if num_channels != self.num_channels: raise ValueError( "Make sure that the channel dimension of the pixel values " - "match with the one set in the configuration.") + "match with the one set in the configuration." 
+ ) - embeddings = self.projection( - pixel_values.to(self.projection.weight.dtype)) + embeddings = self.projection(pixel_values.to(self.projection.weight.dtype)) patch_height, patch_width = embeddings.shape[2], embeddings.shape[3] embeddings = embeddings.flatten(2).transpose(1, 2) @@ -67,30 +64,32 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: class InternS1VisionEmbeddings(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if config.use_mask_token: - self.mask_token = nn.Parameter( - torch.zeros(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) else: self.mask_token = None self.patch_embeddings = InternS1VisionPatchEmbeddings(config) self.patch_size = config.patch_size - self.image_size = (config.image_size if isinstance( - config.image_size, Iterable) else - (config.image_size, config.image_size)) + self.image_size = ( + config.image_size + if isinstance(config.image_size, Iterable) + else (config.image_size, config.image_size) + ) num_patches = self.patch_embeddings.num_patches if config.use_absolute_position_embeddings: self.position_embeddings = nn.Parameter( - torch.zeros(1, num_patches + 1, config.hidden_size)) + torch.zeros(1, num_patches + 1, config.hidden_size) + ) else: self.position_embeddings = None - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, - width: int) -> torch.Tensor: + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing. 
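# A minimal, self-contained sketch of the 2D position-embedding interpolation that
# interpolate_pos_encoding (reformatted in the hunks around here) performs: reshape the
# learned patch grid, move channels first, interpolate to the new grid, and flatten back.
# The helper name, the example sizes, and the "bicubic" mode are illustrative assumptions,
# not values taken from this patch.
import torch
import torch.nn.functional as F

def interpolate_patch_pos_embed(pos_embed: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    # pos_embed: (1, num_positions, dim) patch position embeddings, CLS token excluded.
    _, num_positions, dim = pos_embed.shape
    side = int(num_positions ** 0.5)                                  # original side x side grid
    grid = pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)  # -> (1, dim, side, side)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)    # -> (1, new_h*new_w, dim)

# Usage: a 14x14 grid of 1024-dim embeddings resized for a 16x16 patch grid.
resized = interpolate_patch_pos_embed(torch.randn(1, 196, 1024), 16, 16)
assert resized.shape == (1, 256, 1024)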
@@ -105,8 +104,11 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, # always interpolate when tracing to ensure the exported model # works for dynamic input shapes - if not torch.jit.is_tracing( - ) and num_patches == num_positions and height == width: + if ( + not torch.jit.is_tracing() + and num_patches == num_positions + and height == width + ): return self.position_embeddings class_pos_embed = self.position_embeddings[:, :1] @@ -118,8 +120,9 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, new_width = width // self.patch_size[1] sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, - sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.reshape( + 1, sqrt_num_positions, sqrt_num_positions, dim + ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( @@ -139,8 +142,7 @@ def forward( bool_masked_pos: Optional[torch.BoolTensor] = None, ) -> torch.Tensor: _, _, height, width = pixel_values.shape - embeddings, (patch_height, - patch_width) = self.patch_embeddings(pixel_values) + embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values) batch_size, seq_len, _ = embeddings.size() if bool_masked_pos is not None: @@ -154,7 +156,8 @@ def forward( if self.position_embeddings is not None: embeddings = embeddings + self.interpolate_pos_encoding( - embeddings, height, width) + embeddings, height, width + ) return embeddings, (patch_height, patch_width) @@ -176,39 +179,43 @@ def __init__( self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( - f'embed_dim must be divisible by num_heads ' - f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' - f' {self.num_heads}).') + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) # Additional dummy heads are used to enable TP for common GPU counts. 
self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim self.scale = self.head_dim**-0.5 - self.q_proj = nn.Linear(self.embed_dim, - self.num_heads * self.head_dim, - bias=config.attention_bias) - self.k_proj = nn.Linear(self.embed_dim, - self.num_heads * self.head_dim, - bias=config.attention_bias) - self.v_proj = nn.Linear(self.embed_dim, - self.num_heads * self.head_dim, - bias=config.attention_bias) + self.q_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias + ) self.qk_normalization = config.use_qk_norm if self.qk_normalization: - self.q_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) - self.k_norm = RMSNorm(self.dummy_dim, - eps=config.layer_norm_eps, - var_hidden_size=self.embed_dim) + self.q_norm = RMSNorm( + self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim, + ) + self.k_norm = RMSNorm( + self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim, + ) self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim) # Use unified MultiHeadAttention with automatic backend selection - self.attn = MultiHeadAttention(self.num_heads, self.head_dim, - self.scale) + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape @@ -230,7 +237,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class InternS1VisionMLP(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -241,16 +247,20 @@ def __init__( self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1") - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2") + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -261,7 +271,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class InternS1VisionLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -272,26 +281,30 @@ def __init__( ) -> None: super().__init__() - self.attention = self._init_attn(config, - quant_config, - num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.attention") + self.attention = self._init_attn( + config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.attention", + ) - self.mlp = InternS1VisionMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = InternS1VisionMLP( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) self.layernorm_before = NORM2FN[config.norm_type]( - config.hidden_size, eps=config.layer_norm_eps) + config.hidden_size, eps=config.layer_norm_eps + ) self.layernorm_after = NORM2FN[config.norm_type]( - config.hidden_size, eps=config.layer_norm_eps) + config.hidden_size, 
eps=config.layer_norm_eps + ) init_values = config.layer_scale_init_value - self.lambda_1 = nn.Parameter(init_values * - torch.ones(config.hidden_size), - requires_grad=True) - self.lambda_2 = nn.Parameter(init_values * - torch.ones(config.hidden_size), - requires_grad=True) + self.lambda_1 = nn.Parameter( + init_values * torch.ones(config.hidden_size), requires_grad=True + ) + self.lambda_2 = nn.Parameter( + init_values * torch.ones(config.hidden_size), requires_grad=True + ) def _init_attn( self, @@ -307,17 +320,20 @@ def forward( self, hidden_states: torch.Tensor, ): - hidden_states = hidden_states + self.attention( - self.layernorm_before(hidden_states)) * self.lambda_1 + hidden_states = ( + hidden_states + + self.attention(self.layernorm_before(hidden_states)) * self.lambda_1 + ) - hidden_states = hidden_states + self.mlp( - self.layernorm_after(hidden_states)) * self.lambda_2 + hidden_states = ( + hidden_states + + self.mlp(self.layernorm_after(hidden_states)) * self.lambda_2 + ) return hidden_states class InternS1VisionEncoder(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -336,16 +352,19 @@ def __init__( else: num_hidden_layers = num_hidden_layers_override - self.layer = nn.ModuleList([ - InternS1VisionLayer(config, - quant_config, - num_dummy_heads=num_dummy_heads, - prefix=f"{prefix}.layer.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layer = nn.ModuleList( + [ + InternS1VisionLayer( + config, + quant_config, + num_dummy_heads=num_dummy_heads, + prefix=f"{prefix}.layer.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward(self, inputs_embeds: torch.Tensor): - hidden_states = inputs_embeds for encoder_layer in self.layer: hidden_states = encoder_layer(hidden_states) @@ -354,7 +373,6 @@ def forward(self, inputs_embeds: torch.Tensor): class InternS1VisionModel(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -375,9 +393,11 @@ def __init__( num_dummy_heads=num_dummy_heads, prefix=f"{prefix}.encoder", ) - self.layernorm = (nn.Identity() if config.use_mean_pooling else - nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps)) + self.layernorm = ( + nn.Identity() + if config.use_mean_pooling + else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) def get_input_embeddings(self): return self.embeddings.patch_embeddings @@ -388,8 +408,7 @@ def forward( pixel_embeds: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: if pixel_values is None and pixel_embeds is None: - raise ValueError( - 'You have to specify pixel_values or pixel_embeds') + raise ValueError("You have to specify pixel_values or pixel_embeds") if pixel_embeds is not None: hidden_states = pixel_embeds @@ -397,22 +416,19 @@ def forward( if pixel_values.ndim == 4: hidden_states, _ = self.embeddings(pixel_values) else: - raise ValueError( - f'wrong pixel_values size: {pixel_values.shape}') + raise ValueError(f"wrong pixel_values size: {pixel_values.shape}") encoder_outputs = self.encoder(inputs_embeds=hidden_states) encoder_outputs = self.layernorm(encoder_outputs) return encoder_outputs - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", 
default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 48ea5a18a22d..3cd3807dd888 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -23,31 +23,48 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.models.intern_vit import (InternVisionModel, - InternVisionPatchModel) +from vllm.model_executor.models.intern_vit import ( + InternVisionModel, + InternVisionPatchModel, +) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import set_default_torch_num_threads from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix -IMG_START = '<img>' -IMG_END = '</img>' -IMG_CONTEXT = '<IMG_CONTEXT>' +IMG_START = "<img>" +IMG_END = "</img>" +IMG_CONTEXT = "<IMG_CONTEXT>" IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) @@ -62,6 +79,7 @@ class InternVLImagePixelInputs(TensorSchema): - h: Height of each image patch - w: Width of each image patch """ + type: Literal["pixel_values"] pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -74,13 +92,12 @@ class InternVLImageEmbeddingInputs(TensorSchema): - f: Total image feature size - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["image_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("n", "f", "h")] + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")] -InternVLImageInputs = Union[InternVLImagePixelInputs, - InternVLImageEmbeddingInputs] +InternVLImageInputs = Union[InternVLImagePixelInputs, InternVLImageEmbeddingInputs] class InternVLVideoPixelInputs(TensorSchema): @@ -92,6 +109,7 @@ class InternVLVideoPixelInputs(TensorSchema): - h: Height of each video frame - w: Width of each video frame """ + type: Literal["pixel_values_videos"] pixel_values_flat: Annotated[torch.Tensor, 
TensorShape("bvf", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -104,25 +122,27 @@ class InternVLVideoEmbeddingInputs(TensorSchema): - f: Total video feature size - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["video_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("n", "f", "h")] + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")] -InternVLVideoInputs = Union[InternVLVideoPixelInputs, - InternVLVideoEmbeddingInputs] +InternVLVideoInputs = Union[InternVLVideoPixelInputs, InternVLVideoEmbeddingInputs] # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def build_transform(input_size: int): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD - transform = T.Compose([ - T.Lambda(lambda img: convert_image_mode(img, 'RGB')), - T.Resize((input_size, input_size), - interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=MEAN, std=STD) - ]) + transform = T.Compose( + [ + T.Lambda(lambda img: convert_image_mode(img, "RGB")), + T.Resize( + (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC + ), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) # Image transformation operations (which include tensor computations # on the CPU) can occupy a substantial number of CPU cores, introducing # overhead due to CPU contention. This issue becomes particularly @@ -147,7 +167,7 @@ def find_closest_aspect_ratio( height: int, image_size: int, ) -> tuple[int, int]: - best_ratio_diff = float('inf') + best_ratio_diff = float("inf") best_ratio = (1, 1) area = width * height for ratio in target_ratios: @@ -182,10 +202,13 @@ def get_internvl_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: - target_ratios = {(i, j) - for n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) if min_num <= i * j <= max_num} + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } return sorted(target_ratios, key=lambda x: x[0] * x[1]) @@ -243,10 +266,12 @@ def dynamic_preprocess_internvl( resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): - box = ((i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size) + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) @@ -349,7 +374,8 @@ def __init__( assert isinstance(dynamic_image_size, bool) self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.image_size = image_size self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch @@ -377,14 +403,18 @@ def resolve_min_max_num( dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> tuple[int, int]: - min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch - is None else min_dynamic_patch) - max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch - is None else 
max_dynamic_patch) - dynamic_image_size = (self.dynamic_image_size if dynamic_image_size - is None else dynamic_image_size) - use_thumbnail = (self.use_thumbnail - if use_thumbnail is None else use_thumbnail) + min_dynamic_patch = ( + self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch + ) + max_dynamic_patch = ( + self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch + ) + dynamic_image_size = ( + self.dynamic_image_size + if dynamic_image_size is None + else dynamic_image_size + ) + use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail return resolve_internvl_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -451,7 +481,8 @@ def _images_to_pixel_values_lst( min_num=min_num, max_num=max_num, use_thumbnail=self.use_thumbnail, - ) for image in images + ) + for image in images ] def _preprocess_image( @@ -472,10 +503,10 @@ def _preprocess_image( dynamic_image_size=dynamic_image_size, ) image_inputs = { - "pixel_values_flat": - torch.cat(pixel_values_lst), - "image_num_patches": - torch.tensor([len(item) for item in pixel_values_lst]), + "pixel_values_flat": torch.cat(pixel_values_lst), + "image_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst] + ), } for pixel_values in pixel_values_lst: @@ -483,11 +514,10 @@ def _preprocess_image( feature_size = num_patches * self.num_image_token image_repl = self.get_image_repl(feature_size, num_patches) - text = [t.replace('<image>', image_repl.full, 1) for t in text] + text = [t.replace("<image>", image_repl.full, 1) for t in text] return text, image_inputs - def _make_batch_input(self, - input_item: Optional[Union[Any, list[Any]]] = None): + def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None): if input_item is None: input_item = [] if not isinstance(input_item, list): @@ -581,7 +611,8 @@ def _videos_to_pixel_values_lst( min_num=min_num, max_num=max_num, use_thumbnail=False, - ) for video in videos + ) + for video in videos ] def _preprocess_video( @@ -598,18 +629,19 @@ def _preprocess_video( dynamic_image_size=dynamic_image_size, ) video_inputs = { - "pixel_values_flat_video": - torch.cat(pixel_values_lst_video), - "video_num_patches": - torch.tensor([len(item) for item in pixel_values_lst_video]), + "pixel_values_flat_video": torch.cat(pixel_values_lst_video), + "video_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst_video] + ), } for pixel_values in pixel_values_lst_video: num_patches = pixel_values.shape[0] - video_repl = self.get_video_repl(self.num_image_token, - num_patches, self.video_token) - text = [t.replace('<video>', video_repl.full, 1) for t in text] + video_repl = self.get_video_repl( + self.num_image_token, num_patches, self.video_token + ) + text = [t.replace("<video>", video_repl.full, 1) for t in text] return text, video_inputs def __call__( @@ -665,9 +697,9 @@ def get_video_repl( repl_features = video_context_token * self.num_image_token repl_features_with_sep = IMG_START + repl_features + IMG_END # num_patches is equal to num_frames - repl_full = ''.join([ - f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches) - ]) + repl_full = "".join( + [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)] + ) return PromptUpdateDetails.select_text(repl_full, video_context_token) @@ -714,8 +746,7 @@ def get_image_size_with_most_features(self) -> ImageSize: ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = 
ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") @@ -750,23 +781,23 @@ def get_dummy_mm_data( mm_counts: Mapping[str, int], mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): - """ Basic image-only MultiModalProcessor for InternVL-style models.""" + """Basic image-only MultiModalProcessor for InternVL-style models.""" def _call_hf_processor( self, @@ -802,7 +833,8 @@ def _get_mm_fields_config( return dict( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_patches), + "image", image_num_patches + ), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), @@ -830,7 +862,8 @@ def _get_prompt_updates( def get_replacement_internvl(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -889,8 +922,7 @@ def get_num_frames_with_most_features( processor = self.get_hf_processor() max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = (seq_len - - max_image_tokens) // processor.num_image_token + max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) @@ -906,7 +938,8 @@ def get_hf_processor(self, **kwargs: object) -> InternVLProcessor: class InternVLDummyInputsBuilder( - BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]): + BaseInternVLDummyInputsBuilder[InternVLProcessingInfo] +): """InternVL DummyInputsBuilder extended for video support""" def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: @@ -920,23 +953,25 @@ def get_dummy_mm_data( mm_counts: Mapping[str, int], mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - dummy_image = super().get_dummy_mm_data(seq_len=seq_len, - mm_counts=mm_counts, - mm_options=mm_options) + dummy_image = super().get_dummy_mm_data( + seq_len=seq_len, mm_counts=mm_counts, mm_options=mm_options + ) if self.info.supports_video: config = self.info.get_hf_config() image_size: int = config.vision_config.image_size - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) num_videos = mm_counts.get("video", 0) video_overrides = mm_options.get("video") if mm_options else None dummy_video = { - "video": - self._get_dummy_videos(width=image_size, - height=image_size, - num_frames=target_num_frames, - 
num_videos=num_videos, - overrides=video_overrides) + "video": self._get_dummy_videos( + width=image_size, + height=image_size, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ) } else: dummy_video = {} @@ -944,7 +979,8 @@ def get_dummy_mm_data( class InternVLMultiModalProcessor( - BaseInternVLMultiModalProcessor[InternVLProcessingInfo]): + BaseInternVLMultiModalProcessor[InternVLProcessingInfo] +): """InternVL MultiModalProcessor extended for video support""" def _call_hf_processor( @@ -954,12 +990,15 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs, tok_kwargs) + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs + ) hf_processor = self.info.get_hf_processor(**mm_kwargs) - if self.info.supports_video and ( - video_token_id := hf_processor.video_token_id) is not None: + if ( + self.info.supports_video + and (video_token_id := hf_processor.video_token_id) is not None + ): processed_outputs["video_token_id"] = torch.tensor(video_token_id) return processed_outputs @@ -968,18 +1007,16 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - image_fields = super()._get_mm_fields_config(hf_inputs, - hf_processor_mm_kwargs) + image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs) if self.info.supports_video: - video_num_patches = hf_inputs.get("video_num_patches", - torch.empty(0)) + video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0)) num_videos = len(video_num_patches) video_fields = dict( pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_patches), + "video", video_num_patches + ), video_num_patches=MultiModalFieldConfig.batched("video"), - video_token_id=MultiModalFieldConfig.shared( - "video", num_videos), + video_token_id=MultiModalFieldConfig.shared("video", num_videos), ) else: video_fields = {} @@ -1015,9 +1052,8 @@ def get_video_replacement_internvl(item_idx: int): assert isinstance(num_patches, int) return hf_processor.get_video_repl( - feature_size, - num_patches, - video_context_token=hf_processor.video_token) + feature_size, num_patches, video_context_token=hf_processor.video_token + ) if self.info.supports_video: prompt_repl = [ @@ -1026,7 +1062,7 @@ def get_video_replacement_internvl(item_idx: int): modality="video", target="<video>", replacement=get_video_replacement_internvl, - ) + ), ] return prompt_repl @@ -1035,9 +1071,9 @@ def get_video_replacement_internvl(item_idx: int): @MULTIMODAL_REGISTRY.register_processor( InternVLMultiModalProcessor, info=InternVLProcessingInfo, - dummy_inputs=InternVLDummyInputsBuilder) -class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA): + dummy_inputs=InternVLDummyInputsBuilder, +) +class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): merge_by_field_config = True supports_encoder_tp_data = True @@ -1067,12 +1103,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: patch_size = config.vision_config.patch_size self.patch_size = patch_size self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version 
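# A hedged, self-contained sketch of the dynamic-tiling ratio enumeration that
# get_internvl_target_ratios (above) implements: collect every (columns, rows) grid whose
# tile count falls in [min_num, max_num] and order the grids by tile count. The concrete
# numbers in the usage notes below are assumptions for illustration, not values from this patch.
def target_ratios(min_num: int, max_num: int) -> list[tuple[int, int]]:
    ratios = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    return sorted(ratios, key=lambda r: r[0] * r[1])

# Usage: grids from (1, 1) up to 6-tile layouts such as (2, 3) and (3, 2), ordered by tile count.
print(target_ratios(1, 6))
# Each selected tile then contributes num_image_token tokens, computed above as
# (image_size // patch_size) ** 2 * downsample_ratio ** 2; for example, assuming the
# commonly used InternVL settings of image_size=448, patch_size=14, downsample_ratio=0.5,
# that is (448 // 14) ** 2 * 0.5 ** 2 == 256 tokens per tile.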
self.llm_arch_name = config.text_config.architectures[0] - self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM' + self.is_mono = self.llm_arch_name == "InternLM2VEForCausalLM" self.vision_model = self._init_vision_model( config, quant_config=quant_config, @@ -1093,18 +1130,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.visual_token_mask = None self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _patch_quant_config(self, config: PretrainedConfig, - quant_config: QuantizationConfig): + def _patch_quant_config( + self, config: PretrainedConfig, quant_config: QuantizationConfig + ): # the awq models from OpenGVLab missing `modules_to_not_convert` # patch the quant_config to add `modules_to_not_convert` back if isinstance(quant_config, AWQConfig): text_config = config.text_config - llm_quant_config = getattr(text_config, "quantization_config", - None) - if (not quant_config.modules_to_not_convert) and \ - (llm_quant_config is not None): + llm_quant_config = getattr(text_config, "quantization_config", None) + if (not quant_config.modules_to_not_convert) and ( + llm_quant_config is not None + ): quant_config.modules_to_not_convert.append("vision_model") def _init_vision_model( @@ -1118,8 +1157,9 @@ def _init_vision_model( if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 + num_hidden_layers = ( + config.vision_config.num_hidden_layers + vision_feature_layer + 1 + ) else: num_hidden_layers = vision_feature_layer + 1 @@ -1128,7 +1168,8 @@ def _init_vision_model( quant_config=quant_config, num_hidden_layers_override=num_hidden_layers, prefix=prefix, - use_data_parallel=self.use_data_parallel) + use_data_parallel=self.use_data_parallel, + ) else: return InternVisionPatchModel(config.vision_config) @@ -1137,9 +1178,10 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: llm_hidden_size = config.text_config.hidden_size return nn.Sequential( - nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), - nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, - llm_hidden_size), + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size + ), nn.GELU(), nn.Linear(llm_hidden_size, llm_hidden_size), ) @@ -1150,9 +1192,13 @@ def pixel_shuffle(self, x, scale_factor=0.5): x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(n, int(h * scale_factor), int(w * scale_factor), - int(c / (scale_factor * scale_factor))) - if self.ps_version == 'v1': + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": pass else: x = x.permute(0, 2, 1, 3).contiguous() @@ -1162,17 +1208,16 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = self.vision_model(pixel_values=pixel_values) vit_embeds = vit_embeds[:, 1:, :] - h = w = int(vit_embeds.shape[1]**0.5) + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, - scale_factor=self.downsample_ratio) - vit_embeds = 
vit_embeds.reshape(vit_embeds.shape[0], -1, - vit_embeds.shape[-1]) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[InternVLImageInputs]: + self, **kwargs: object + ) -> Optional[InternVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -1204,7 +1249,8 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[InternVLVideoPixelInputs]: + self, **kwargs: object + ) -> Optional[InternVLVideoPixelInputs]: pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("image_embeds", None) @@ -1239,8 +1285,10 @@ def _process_vision_input( self, image_input: Union[InternVLImageInputs, InternVLVideoInputs], ) -> tuple[torch.Tensor, ...]: - if (image_input["type"] == "image_embeds" - or image_input["type"] == "video_embeds"): + if ( + image_input["type"] == "image_embeds" + or image_input["type"] == "video_embeds" + ): return image_input["data"] assert self.vision_model is not None @@ -1251,14 +1299,12 @@ def _process_vision_input( # Only one image in the current batch if len(num_patches) == 1: - return (image_embeds.view(-1, - self.config.text_config.hidden_size), ) + return (image_embeds.view(-1, self.config.text_config.hidden_size),) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. feature_size = image_embeds.shape[1] - image_embeds = image_embeds.view(-1, - self.config.text_config.hidden_size) + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] @@ -1270,31 +1316,29 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values_flat", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_flat_video", - ) and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values_flat", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_flat_video",) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: if self.is_mono: assert self.img_context_token_id is not None - self.visual_token_mask = ( - input_ids == self.img_context_token_id).reshape(-1, 1) + self.visual_token_mask = (input_ids == self.img_context_token_id).reshape( + -1, 1 + ) else: self.visual_token_mask = None def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -1325,8 +1369,7 @@ def get_input_embeddings( is_multimodal: Optional[torch.Tensor] = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: - if multimodal_embeddings is not None and len( - multimodal_embeddings) > 0: + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) # This is to satisfy the type checker for each overload @@ -1348,7 +1391,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> IntermediateTensors: - if intermediate_tensors is not None: input_ids = None inputs_embeds = None @@ -1362,8 +1404,7 @@ def forward( # Only required if the model is mono-architecture if self.visual_token_mask is not None: - forward_kwargs.update( - {"visual_token_mask": self.visual_token_mask}) + forward_kwargs.update({"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None hidden_states = self.language_model.model(**forward_kwargs) @@ -1375,14 +1416,21 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B skip_prefixes = [ - "action_embed", "temporal_embed", "track_embed", - "track_embed_decoder", "box_token", "cg_criterion", "cg_model", - "loc_encoder", "loc_decoder", "sam", "temporal_token", - "track_token" + "action_embed", + "temporal_embed", + "track_embed", + "track_embed_decoder", + "box_token", + "cg_criterion", + "cg_model", + "loc_encoder", + "loc_decoder", + "sam", + "temporal_token", + "track_token", ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights) @@ -1394,4 +1442,5 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="language_model", connector="mlp1", - tower_model="vision_model") + tower_model="vision_model", + ) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 0eb1578b4361..d788ed7ec2af 100644 --- 
a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -32,48 +32,57 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import JAISConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class SwiGLUActivation(nn.Module): - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: return x1 * nn.functional.silu(x2) def _get_alibi_slopes(n): - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( - 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + _get_alibi_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) class JAISAttention(nn.Module): - def __init__( self, config: JAISConfig, @@ -84,8 +93,7 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert total_num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // total_num_heads @@ -113,13 +121,15 @@ def __init__( head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -133,7 +143,6 @@ def forward( class JAISMLP(nn.Module): - def __init__( self, intermediate_size: int, @@ -149,12 +158,16 @@ def __init__( bias=True, 
quant_config=quant_config, ) - self.c_fc2 = (ColumnParallelLinear( - hidden_size, - intermediate_size, - bias=True, - quant_config=quant_config, - ) if self.swiglu else None) + self.c_fc2 = ( + ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) + if self.swiglu + else None + ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, @@ -168,14 +181,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.swiglu: hidden_states2, _ = self.c_fc2(hidden_states) hidden_states, _ = self.c_fc(hidden_states) - hidden_states = (self.act(hidden_states, hidden_states2) - if self.swiglu else self.act(hidden_states)) + hidden_states = ( + self.act(hidden_states, hidden_states2) + if self.swiglu + else self.act(hidden_states) + ) hidden_states, _ = self.c_proj(hidden_states) return hidden_states class JAISBlock(nn.Module): - def __init__( self, config: JAISConfig, @@ -185,14 +200,12 @@ def __init__( ): super().__init__() hidden_size = config.hidden_size - inner_dim = (config.n_inner if config.n_inner is not None else 4 * - hidden_size) + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = JAISAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = JAISMLP(inner_dim, config, quant_config) @@ -202,7 +215,9 @@ def forward( ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_output = self.attn(hidden_states=hidden_states, ) + attn_output = self.attn( + hidden_states=hidden_states, + ) # residual connection hidden_states = attn_output + residual @@ -216,7 +231,6 @@ def forward( @support_torch_compile class JAISModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -230,9 +244,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) - self.wpe = (nn.Embedding(config.max_position_embeddings, - self.embed_dim) - if config.position_embedding_type != "alibi" else None) + self.wpe = ( + nn.Embedding(config.max_position_embeddings, self.embed_dim) + if config.position_embedding_type != "alibi" + else None + ) if hasattr(config, "embeddings_scale"): self.embeddings_scale = config.embeddings_scale else: @@ -240,17 +256,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: JAISBlock(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: JAISBlock( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.n_embd)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.n_embd + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -270,8 +288,9 @@ def forward( hidden_states = 
inputs_embeds + position_embeds else: hidden_states = inputs_embeds - hidden_states *= torch.tensor(float(self.embeddings_scale), - dtype=hidden_states.dtype) + hidden_states *= torch.tensor( + float(self.embeddings_scale), dtype=hidden_states.dtype + ) else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] @@ -287,32 +306,33 @@ def forward( class JAISLMHeadModel(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.transformer = JAISModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) + self.transformer = JAISModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) if self.config.tie_word_embeddings: self.lm_head = self.transformer.wte else: - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: - self.output_logits_scale = (config.mup_output_alpha * - config.mup_width_scale) - self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, - scale=self.output_logits_scale) + self.output_logits_scale = config.mup_output_alpha * config.mup_width_scale + self.logits_processor = LogitsProcessor( + vocab_size=config.vocab_size, scale=self.output_logits_scale + ) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -324,8 +344,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[IntermediateTensors, torch.Tensor]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -335,8 +356,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -366,8 +386,7 @@ def load_weights(self, weights: Iterable[tuple[str, if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index e8277e259bc5..0371458f5578 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jamba model.""" + from 
collections.abc import Iterable from itertools import islice from typing import Optional @@ -16,37 +17,50 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaMLP as JambaMLP from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class JambaMoE(nn.Module): - - def __init__(self, - config: JambaConfig, - num_experts: Optional[int] = None, - top_k: Optional[int] = None, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: JambaConfig, + num_experts: Optional[int] = None, + top_k: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.num_total_experts = num_experts or config.num_experts self.top_k = top_k or config.num_experts_per_tok @@ -54,23 +68,27 @@ def __init__(self, self.intermediate_size = config.intermediate_size if self.num_total_experts > 1: - self.router = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - quant_config=None, - params_dtype=params_dtype) - - self.experts = FusedMoE(self.num_total_experts, - self.top_k, - self.hidden_size, - self.intermediate_size, - tp_size=tp_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=False, - use_grouped_topk=False, - quant_config=quant_config, - prefix=f"{prefix}.experts") + self.router = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None, + params_dtype=params_dtype, + ) + + self.experts = FusedMoE( + self.num_total_experts, + self.top_k, + self.hidden_size, + self.intermediate_size, + tp_size=tp_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + use_grouped_topk=False, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape @@ -79,43 +97,46 
@@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.num_total_experts > 1: router_logits, _ = self.router(hidden_states) else: - router_logits = torch.ones((hidden_states.shape[0], 1), - device=hidden_states.device, - dtype=hidden_states.dtype) + router_logits = torch.ones( + (hidden_states.shape[0], 1), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) hidden_states = self.experts(hidden_states, router_logits) return hidden_states.view(orig_shape) class JambaMambaDecoderLayer(nn.Module): - - def __init__(self, - config: JambaConfig, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - is_lora_enabled: Optional[bool] = False, - prefix: str = "", - **kwargs) -> None: + def __init__( + self, + config: JambaConfig, + layer_idx: int, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + is_lora_enabled: Optional[bool] = False, + prefix: str = "", + **kwargs, + ) -> None: super().__init__() self.config = config self.is_lora_enabled = is_lora_enabled - self.mamba = MambaMixer(hidden_size= config.hidden_size, - ssm_state_size = config.mamba_d_state, - conv_kernel_size = config.mamba_d_conv, - intermediate_size = config.mamba_expand *\ - config.hidden_size, - time_step_rank = config.mamba_dt_rank, - use_conv_bias = config.mamba_conv_bias, - use_bias = config.mamba_proj_bias, - use_rms_norm=True, - rms_norm_eps=config.rms_norm_eps, - activation=config.hidden_act, - is_lora_enabled = self.is_lora_enabled, - model_config=model_config, - cache_config=cache_config, - prefix=f"{prefix}.mixer", - ) + self.mamba = MambaMixer( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=config.mamba_expand * config.hidden_size, + time_step_rank=config.mamba_dt_rank, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + use_rms_norm=True, + rms_norm_eps=config.rms_norm_eps, + activation=config.hidden_act, + is_lora_enabled=self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.mixer", + ) num_experts = config.layers_num_experts[layer_idx] if num_experts > 1: @@ -132,10 +153,8 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.feed_forward", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -147,8 +166,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) output = torch.empty_like(hidden_states) self.mamba(hidden_states, output) @@ -159,15 +177,16 @@ def forward( class JambaAttentionDecoderLayer(nn.Module): - - def __init__(self, - config: JambaConfig, - layer_idx: int, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - **kwargs) -> None: + def __init__( + self, + config: JambaConfig, + layer_idx: int, + model_config: 
Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + **kwargs, + ) -> None: super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -197,10 +216,12 @@ def __init__(self, bias=False, quant_config=quant_config, ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) self.attn = Attention( self.num_heads, @@ -226,10 +247,8 @@ def __init__(self, quant_config=quant_config, prefix=f"{prefix}.feed_forward", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def self_attention( self, @@ -254,29 +273,26 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attention( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm(hidden_states, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual ALL_DECODER_LAYER_TYPES = { "attention": JambaAttentionDecoderLayer, - "mamba": JambaMambaDecoderLayer + "mamba": JambaMambaDecoderLayer, } @support_torch_compile class JambaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -287,8 +303,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -302,24 +321,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) - layer_class = ALL_DECODER_LAYER_TYPES[ - config.layers_block_type[layer_idx]] - return layer_class(config, - layer_idx, - model_config, - cache_config, - quant_config=quant_config, - prefix=prefix, - **extra_kwargs) + layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[layer_idx]] + return layer_class( + config, + layer_idx, + model_config, + cache_config, + quant_config=quant_config, + prefix=prefix, + **extra_kwargs, + ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) - 
self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -343,15 +363,14 @@ def forward( residual = intermediate_tensors["residual"] for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions=positions, - hidden_states=hidden_states, - residual=residual) + hidden_states, residual = layer( + positions=positions, hidden_states=hidden_states, residual=residual + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states @@ -362,10 +381,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) + num_experts=self.config.num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -384,7 +403,7 @@ def load_weights(self, weights: Iterable[tuple[str, for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if 'experts' in name: + if "experts" in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. @@ -399,10 +418,10 @@ def load_weights(self, weights: Iterable[tuple[str, break else: for ( - param_name, - weight_name, - expert_id, - shard_id, + param_name, + weight_name, + expert_id, + shard_id, ) in expert_params_mapping: if weight_name not in name: continue @@ -412,11 +431,13 @@ def load_weights(self, weights: Iterable[tuple[str, name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. 
@@ -426,19 +447,18 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid): - hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ - ".self_attn.": ".", - ".A_log": ".A" - }, ) +class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".self_attn.": ".", ".A_log": ".A"}, + ) packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -461,16 +481,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ + assert not cache_config.enable_prefix_caching, ( "Jamba currently does not support prefix caching" + ) super().__init__() self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.scheduler_config = scheduler_config - self.model = JambaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = JambaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -481,33 +503,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) + return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) @@ -517,7 +543,6 @@ def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return 
MambaStateDtypeCalculator.mamba1_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -547,8 +572,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -557,7 +581,6 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: class JambaForSequenceClassification(JambaForCausalLM): - is_pooling_model = True def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -565,7 +588,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config num_labels: int = config.num_labels - score_bias: bool = getattr(config, 'score_bias', False) + score_bias: bool = getattr(config, "score_bias", False) # TODO: The original reward weights have float32 accuracy data, we # would like to load them in fp32 to get that extra precision. @@ -580,12 +603,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - Pooler.for_classify( - pooler_config, - classifier=self.score, - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": Pooler.for_classify( + pooler_config, + classifier=self.score, + ), + } + ) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index f8c2a1e507a7..9711eeeeec33 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -10,36 +10,34 @@ from vllm.config import ModelConfig, VllmConfig from vllm.inputs import TokensPrompt from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors -from .interfaces import (SupportsCrossEncoding, SupportsMultiModal, - SupportsScoreTemplate) -from .qwen2_vl import (Qwen2VLDummyInputsBuilder, - Qwen2VLForConditionalGeneration, - Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo) +from .interfaces import SupportsCrossEncoding, SupportsMultiModal, SupportsScoreTemplate +from .qwen2_vl import ( + Qwen2VLDummyInputsBuilder, + Qwen2VLForConditionalGeneration, + Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo, +) from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix logger = init_logger(__name__) class JinaVLScorer(nn.Module): - def __init__(self, model_config: "ModelConfig"): super().__init__() config = model_config.hf_config head_dtype = model_config.head_dtype - self.dense = ColumnParallelLinear(config.hidden_size, - config.hidden_size, - params_dtype=head_dtype, - bias=True) - self.out_proj = RowParallelLinear(config.hidden_size, - config.num_labels, - params_dtype=head_dtype, - bias=True) + self.dense = ColumnParallelLinear( + config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True + ) + self.out_proj = RowParallelLinear( + config.hidden_size, config.num_labels, 
params_dtype=head_dtype, bias=True + ) def forward(self, x, **kwargs): x, _ = self.dense(x) @@ -49,7 +47,6 @@ def forward(self, x, **kwargs): class JinaVLMultiModalProcessor(Qwen2VLMultiModalProcessor): - def _call_hf_processor( self, prompt: str, @@ -57,25 +54,26 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - # NOTE: We should reverse the order of the mm_data because the # query prompt is placed after the document prompt in the score # template for JinaVLForRanking model, but in mm_data they are # stored in the opposite order (query first, then document). for _, value in mm_data.items(): value.reverse() - return super()._call_hf_processor(prompt, mm_data, mm_kwargs, - tok_kwargs) - - -@MULTIMODAL_REGISTRY.register_processor(JinaVLMultiModalProcessor, - info=Qwen2VLProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) -class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, - SupportsCrossEncoding, - SupportsMultiModal, - SupportsScoreTemplate): - + return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs) + + +@MULTIMODAL_REGISTRY.register_processor( + JinaVLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder, +) +class JinaVLForSequenceClassification( + Qwen2VLForConditionalGeneration, + SupportsCrossEncoding, + SupportsMultiModal, + SupportsScoreTemplate, +): is_pooling_model = True weight_mapper = WeightsMapper( orig_to_new_prefix={ @@ -87,23 +85,24 @@ class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, # mapping for original checkpoint "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "qwen2_vl")) + super().__init__( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "qwen2_vl") + ) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None self.score = JinaVLScorer(vllm_config.model_config) - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - Pooler.for_classify(pooler_config, classifier=self.score), - "score": - Pooler.for_classify(pooler_config, classifier=self.score), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": Pooler.for_classify(pooler_config, classifier=self.score), + "score": Pooler.for_classify(pooler_config, classifier=self.score), + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -118,9 +117,8 @@ def get_score_template(cls, query: str, document: str) -> Optional[str]: @classmethod def post_process_tokens(cls, prompt: TokensPrompt) -> None: - # add score target token at the end of prompt tokens - prompt['prompt_token_ids'].append(100) + prompt["prompt_token_ids"].append(100) def forward( self, diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 20f705cca8e6..7ccbc81431f6 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -13,8 +13,7 @@ from transformers import PretrainedConfig from transformers.activations import GELUActivation from transformers.feature_extraction_utils import BatchFeature -from transformers.modeling_outputs import (BaseModelOutput, - BaseModelOutputWithPooling) +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils 
import torch_int from vllm.attention.backends.registry import _Backend @@ -23,34 +22,57 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, VideoItem) -from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ImageSize, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .siglip import SiglipMLP -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, is_pp_missing_parameter, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + is_pp_missing_parameter, + maybe_prefix, +) from .vision import get_vit_attn_backend logger = init_logger(__name__) @@ -82,8 +104,10 @@ def smart_resize( width = factor if max(height, width) / min(height, width) > 200: - raise ValueError("absolute aspect ratio must be smaller than 200, got " - "{max(height, width) / min(height, width)}") + raise ValueError( + "absolute aspect ratio must be smaller than 200, got " + "{max(height, width) / min(height, width)}" + ) h_bar = round(height / factor) * factor w_bar = round(width / factor) * factor if h_bar * w_bar > max_pixels: @@ -100,17 +124,17 @@ def smart_resize( class KeyeImagePixelInputs(TensorSchema): """ Dimensions: - - b: Batch size - - np: Number of patches + - bnp: Batch size * Number of patches - c: Number of channels - ps: Patch size - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ + type: Literal["pixel_values"] pixel_values: Annotated[ - torch.Tensor, - TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})] + torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) + ] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] @@ -118,11 +142,12 @@ class 
KeyeImageEmbeddingInputs(TensorSchema): """ Dimensions: - nf: Number of image features - - hs: Hidden size (must match the hidden size of language model + - hs: Hidden size (must match the hidden size of language model backbone) - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ + type: Literal["image_embeds"] image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] @@ -134,17 +159,17 @@ class KeyeImageEmbeddingInputs(TensorSchema): class KeyeVideoPixelInputs(TensorSchema): """ Dimensions: - - b: Batch size - - np: Number of patches + - bnp: Batch size * Number of patches - c: Number of channels - ps: Patch size - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ + type: Literal["pixel_values_videos"] pixel_values_videos: Annotated[ - torch.Tensor, - TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})] + torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) + ] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] @@ -152,11 +177,12 @@ class KeyeVideoEmbeddingInputs(TensorSchema): """ Dimensions: - nf: Number of video features - - hs: Hidden size (must match the hidden size of language model + - hs: Hidden size (must match the hidden size of language model backbone) - nv: Number of videos - g: Grid dimensions (3 for t, h, w) """ + type: Literal["video_embeds"] video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] @@ -166,7 +192,6 @@ class KeyeVideoEmbeddingInputs(TensorSchema): class KeyeVisionEmbeddings(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config @@ -182,12 +207,11 @@ def __init__(self, config: PretrainedConfig): padding="valid", ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.cache_position_embedding = dict() self.cache_position_count = dict() - self.position_embedding = nn.Embedding(self.num_positions, - self.embed_dim) + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.packing_position_embedding = nn.Embedding(32768, self.embed_dim) self.register_buffer( @@ -203,7 +227,6 @@ def interpolate_pos_encoding( width: int, is_after_patchify: bool = False, ) -> torch.Tensor: - num_positions = self.position_embedding.weight.shape[0] patch_pos_embed = self.position_embedding.weight.unsqueeze(0) @@ -218,8 +241,9 @@ def interpolate_pos_encoding( new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) - patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, - sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.reshape( + 1, sqrt_num_positions, sqrt_num_positions, dim + ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( @@ -232,11 +256,7 @@ def interpolate_pos_encoding( patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed - def fetch_position_embedding_lfu_cache(self, - embeddings, - h, - w, - max_cache: int = 20): + def fetch_position_embedding_lfu_cache(self, embeddings, h, w, max_cache: int = 20): grid = (h, w) if grid in self.cache_position_embedding: self.cache_position_count[grid] += 1 @@ -250,8 +270,7 @@ def fetch_position_embedding_lfu_cache(self, self.cache_position_count.pop(min_hit_grid) self.cache_position_embedding.pop(min_hit_grid) - position_embedding 
= self.interpolate_pos_encoding( - embeddings, h, w, True) + position_embedding = self.interpolate_pos_encoding(embeddings, h, w, True) self.cache_position_count[grid] = 1 self.cache_position_embedding[grid] = position_embedding return position_embedding @@ -260,10 +279,14 @@ def forward( self, pixel_values: torch.FloatTensor, position_ids: Optional[torch.Tensor] = None, - image_grid_thw: Optional[list[Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ]]] = None, + image_grid_thw: Optional[ + list[ + Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ] + ] + ] = None, interpolate_pos_encoding=False, ) -> torch.Tensor: if pixel_values.dim() == 4: @@ -282,8 +305,7 @@ def forward( ) = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype pixel_values = rearrange(pixel_values, "b l c h w -> (b l) c h w") - patch_embeds = self.patch_embedding( - pixel_values.to(dtype=target_dtype)) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) embeddings = patch_embeds.flatten(-2).squeeze(-1) if interpolate_pos_encoding and image_grid_thw is not None: @@ -293,19 +315,23 @@ def forward( t, h, w = image_grid end = start + t * h * w image_embeddings = embeddings[start:end, :] - position_embedding = (self.interpolate_pos_encoding( - image_embeddings, h, w, True).squeeze(0).repeat(t, 1)) + position_embedding = ( + self.interpolate_pos_encoding(image_embeddings, h, w, True) + .squeeze(0) + .repeat(t, 1) + ) image_embeddings = image_embeddings + position_embedding tmp_embeddings.append(image_embeddings) start = end embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) else: - embeddings = embeddings + self.packing_position_embedding( - position_ids) + embeddings = embeddings + self.packing_position_embedding(position_ids) return embeddings else: - raise ValueError("Unsupported pixel_values dimension:" - f" {pixel_values.dim()}. Expected 4 or 5.") + raise ValueError( + "Unsupported pixel_values dimension:" + f" {pixel_values.dim()}. Expected 4 or 5." + ) def apply_rotary_pos_emb_flashatt( @@ -372,18 +398,20 @@ def __init__( # Detect attention implementation. self.attn_backend = get_vit_attn_backend( - head_size=self.head_dim, dtype=torch.get_default_dtype()) + head_size=self.head_dim, dtype=torch.get_default_dtype() + ) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): self.attn_backend = _Backend.FLASH_ATTN self.use_upstream_fa = True if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: raise RuntimeError( - f"Keye-VL does not support {self.attn_backend} backend now.") + f"Keye-VL does not support {self.attn_backend} backend now." + ) def forward( self, @@ -417,8 +445,7 @@ def forward( ) else: if cu_seqlens is None: - raise ValueError( - "cu_seqlens cannot be None when rope_emb is not None.") + raise ValueError("cu_seqlens cannot be None when rope_emb is not None.") cos, sin = rope_emb q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) k = k.view( @@ -452,29 +479,26 @@ def forward( causal=False, softmax_scale=self.scale, ) - context_layer = rearrange(output, - "(b s) ... -> b s ...", - b=batch_size) + context_layer = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) - context_layer = rearrange(context_layer, - "b s h d -> b s (h d)").contiguous() + context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous() output, _ = self.out_proj(context_layer) return output class SigLIPRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim @@ -482,8 +506,9 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None: self.rope_init() def rope_init(self): - inv_freq = 1.0 / (self.theta**( - torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim)) + inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: @@ -497,7 +522,6 @@ def forward(self, seqlen: int) -> torch.Tensor: class KeyeSiglipEncoderLayer(nn.Module): - def __init__( self, config: Union[PretrainedConfig], @@ -506,15 +530,13 @@ def __init__( ): super().__init__() self.embed_dim = config.hidden_size - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = KeyeSiglipAttention( config, quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( config, quant_config=quant_config, @@ -529,7 +551,6 @@ def forward( cu_seqlens: Optional[list[torch.Tensor]] = None, rope_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ) -> tuple[torch.FloatTensor]: - residual = hidden_states hidden_states = self.layer_norm1(hidden_states) @@ -553,7 +574,6 @@ def forward( class KeyeSiglipEncoder(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -565,13 +585,16 @@ def __init__( embed_dim = config.hidden_size num_heads = config.num_attention_heads head_dim = embed_dim // num_heads - self.layers = nn.ModuleList([ - KeyeSiglipEncoderLayer( - config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - ) for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + KeyeSiglipEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) self.rotary_pos_emb = SigLIPRotaryEmbedding(head_dim // 2) @staticmethod @@ -591,10 +614,14 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cu_seqlens: Optional[list[torch.Tensor]] = None, - image_grid_thw: Optional[list[Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ]]] = None, + image_grid_thw: Optional[ + list[ + Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ] + ] + ] = None, height_position_ids: Optional[torch.Tensor] = None, width_position_ids: Optional[torch.Tensor] = None, use_rope: Optional[bool] = False, @@ 
-610,8 +637,7 @@ def forward( split_hids = list() split_wids = list() for t, h, w in flatten_image_grid_thw: - image_pids = torch.arange(t * h * w, - device=device) % (h * w) + image_pids = torch.arange(t * h * w, device=device) % (h * w) sample_hids = image_pids // w sample_wids = image_pids % w split_hids.append(sample_hids) @@ -647,7 +673,6 @@ def forward( class KeyeSiglipVisionTransformer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -664,8 +689,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.encoder", ) - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, @@ -682,15 +706,18 @@ def forward( cu_seqlens: Optional[list[torch.Tensor]] = None, padding_mask: Optional[torch.Tensor] = None, vision_return_embed_list: Optional[bool] = False, - image_grid_thw: Optional[list[Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ]]] = None, + image_grid_thw: Optional[ + list[ + Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ] + ] + ] = None, return_pooler_output: Optional[bool] = True, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, ) -> BaseModelOutputWithPooling: - hidden_states = self.embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, @@ -716,8 +743,10 @@ def forward( sample_hidden_state = list() if cu_seqlens is None: - raise ValueError("cu_seqlens cannot be None for " - "SiglipVisionTransformer output processing.") + raise ValueError( + "cu_seqlens cannot be None for " + "SiglipVisionTransformer output processing." + ) for i in range(cu_seqlens.shape[0] - 1): start = cu_seqlens[i] end = cu_seqlens[i + 1] @@ -766,16 +795,19 @@ def forward( interpolate_pos_encoding: bool = False, position_ids: Optional[torch.Tensor] = None, vision_return_embed_list: Optional[bool] = False, - image_grid_thw: Optional[list[Union[ - tuple[int, int, int], - list[tuple[int, int, int]], - ]]] = None, + image_grid_thw: Optional[ + list[ + Union[ + tuple[int, int, int], + list[tuple[int, int, int]], + ] + ] + ] = None, cu_seqlens: Optional[list[torch.Tensor]] = None, return_pooler_output: Optional[bool] = True, use_rope: Optional[bool] = False, window_size: Optional[bool] = -1, ) -> BaseModelOutputWithPooling: - return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, @@ -791,8 +823,7 @@ def forward( window_size=window_size, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -808,22 +839,24 @@ def load_weights(self, weights: Iterable[tuple[str, if "head.mlp" in name or "head.probe" in name: continue if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name)): + scale_name := self.quant_config.get_cache_scale(name) + ): param = params_dict[scale_name] weight_loader = getattr( param, "weight_loader", default_weight_loader, ) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue for ( - param_name, - weight_name, - shard_id, + param_name, + weight_name, + shard_id, ) in stacked_params_mapping: if weight_name not in name: continue @@ -856,7 
+889,6 @@ def load_weights(self, weights: Iterable[tuple[str, class Projector(nn.Module): - def __init__( self, text_config: PretrainedConfig, @@ -869,12 +901,13 @@ def __init__( self.vision_config = vision_config self.merge_kernel_size = (2, 2) - self.hidden_size = (self.vision_config.hidden_size * - self.merge_kernel_size[0] * - self.merge_kernel_size[1]) + self.hidden_size = ( + self.vision_config.hidden_size + * self.merge_kernel_size[0] + * self.merge_kernel_size[1] + ) - self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, - eps=1e-05) + self.pre_norm = torch.nn.LayerNorm(self.vision_config.hidden_size, eps=1e-05) self.act = GELUActivation() self.linear_1 = ColumnParallelLinear( @@ -900,8 +933,7 @@ def forward( m1, m2 = self.merge_kernel_size if isinstance(image_features, (list, tuple)): processed_features = list() - for image_feature, image_grid in zip(image_features, - image_grid_thw): + for image_feature, image_grid in zip(image_features, image_grid_thw): image_feature = self.pre_norm(image_feature) t, h, w = image_grid @@ -924,8 +956,7 @@ def forward( dims = image_features.shape[:-1] dim = image_features.shape[-1] image_features = image_features.view(np.prod(dims), dim) - hidden_states = self.pre_norm(image_features).view( - -1, self.hidden_size) + hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size) hidden_states = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) @@ -933,7 +964,9 @@ def forward( return hidden_states.view(*dims, -1) -def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ): +def _keye_field_config( + hf_inputs: Mapping[str, torch.Tensor], +): image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_grid_sizes = image_grid_thw.prod(-1) @@ -941,21 +974,18 @@ def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ): video_grid_sizes = video_grid_thw.prod(-1) return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), + video_embeds=MultiModalFieldConfig.flat_from_sizes("video", video_grid_sizes), video_grid_thw=MultiModalFieldConfig.batched("video"), ) class KeyeMultiModalDataParser(MultiModalDataParser): - def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -992,17 +1022,18 @@ def _parse_video_data( class KeyeProcessingInfo(BaseProcessingInfo): - def get_max_image_size(self) -> int: - return 9999999 #_MAX_IMAGE_SIZE + return 9999999 # _MAX_IMAGE_SIZE def get_max_frame_per_video(self) -> int: - return 16 #_MAX_FRAMES_PER_VIDEO + return 16 # _MAX_FRAMES_PER_VIDEO def get_image_processor(self, **kwargs: object): return self.get_hf_processor(**kwargs).image_processor - def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits( + self, + ) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} def get_mm_max_tokens_per_item( @@ -1041,11 +1072,9 @@ def _get_vision_info( min_pixels=image_processor.min_pixels, 
max_pixels=image_processor.max_pixels, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) padded_num_frames = num_frames + num_frames % temporal_patch_size @@ -1088,7 +1117,9 @@ def get_num_video_tokens( ) return num_video_tokens - def get_image_size_with_most_features(self, ) -> ImageSize: + def get_image_size_with_most_features( + self, + ) -> ImageSize: max_image_size, _ = self._get_vision_info( image_width=self.get_max_image_size(), image_height=self.get_max_image_size(), @@ -1132,8 +1163,7 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int: max_videos = mm_config.get_limit_per_prompt("video") max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) max_frames_per_video = min( max_total_frames // max(max_videos, 1), self.get_max_frame_per_video(), @@ -1156,7 +1186,6 @@ def get_max_video_tokens(self, seq_len: int) -> int: class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -1176,24 +1205,20 @@ def get_dummy_mm_data( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = ( - self.info.get_image_size_with_most_features()) - target_num_frames = self.info.get_num_frames_with_most_features( - seq_len) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features(seq_len) image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None mm_data = { - "image": - self._get_dummy_images( + "image": self._get_dummy_images( width=target_width, height=target_height, num_images=num_images, overrides=image_overrides, ), - "video": - self._get_dummy_videos( + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, @@ -1205,12 +1230,10 @@ def get_dummy_mm_data( return mm_data -class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]): - ... +class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]): ... 
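# Illustrative sketch (hypothetical grid values, not part of this patch): the
# multimodal field configs in this file size their "flat" fields with
# grid_thw.prod(-1), i.e. the number of flattened patches each image or video
# contributes, so that MultiModalFieldConfig.flat_from_sizes can split the
# concatenated pixel values back into per-item chunks, while batched("image")
# keeps one (t, h, w) metadata row per item.
import torch

grid_thw = torch.tensor([[1, 4, 6], [2, 8, 8]])  # one (t, h, w) row per item
patch_sizes = grid_thw.prod(-1)  # tensor([24, 128]): 1*4*6 and 2*8*8 patches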
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: return KeyeMultiModalDataParser() @@ -1221,8 +1244,7 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -1246,7 +1268,8 @@ def get_replacement_keye(item_idx: int, modality: str): modality=modality, target=[placeholder[modality]], replacement=partial(get_replacement_keye, modality=modality), - ) for modality in ("image", "video") + ) + for modality in ("image", "video") ] def _get_mm_fields_config( @@ -1258,6 +1281,8 @@ def _get_mm_fields_config( class BaseKeyeModule(nn.Module): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1270,10 +1295,12 @@ class BaseKeyeModule(nn.Module): ], } - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "lm_head.": "language_model.lm_head.", - "model.": "language_model.model.", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -1313,18 +1340,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) @abstractmethod - def _build_projector(self, - text_config: PretrainedConfig, - vision_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: + def _build_projector( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: raise ValueError("Need projector") - def _process_image_input(self, - image_input: Any) -> tuple[torch.Tensor, ...]: + def _process_image_input(self, image_input: Any) -> tuple[torch.Tensor, ...]: siglip_position_ids = list() image_grid_hws = list() sample_indices = list() @@ -1339,21 +1368,22 @@ def _process_image_input(self, image_grid_hws.append(thw_tuple) image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) siglip_position_ids.append(image_position_ids) - sample_indices.append(torch.full((numel, ), idx, - dtype=torch.int64)) + sample_indices.append(torch.full((numel,), idx, dtype=torch.int64)) cu_seqlens.append(cu_seqlens[-1] + numel) if image_input["type"] == "image_embeds": raise ValueError( - "Image embeddings are not supported for this processing path.") + "Image embeddings are not supported for this processing path." 
+ ) else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) - siglip_position_ids = torch.concat(siglip_position_ids, - dim=0).to(pixel_values.device) + siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( + pixel_values.device + ) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( - pixel_values.device) - sample_indices = torch.concat(sample_indices, - dim=0).to(pixel_values.device) + pixel_values.device + ) + sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values.device) image_embeds = self.visual( pixel_values=pixel_values, @@ -1373,7 +1403,7 @@ def _process_video_embeds( self, video_type: Literal["video_embeds", "pixel_values_videos"], video_grid_thw: list[torch.Tensor], - pixel_values_videos: Optional[torch.Tensor] = None + pixel_values_videos: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, list[torch.Tensor]]: siglip_position_ids = list() video_grid_hws = list() @@ -1388,21 +1418,24 @@ def _process_video_embeds( video_grid_hws.append(thw_tuple) video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) siglip_position_ids.append(video_position_ids) - sample_indices.append(torch.full((numel, ), idx, - dtype=torch.int64)) + sample_indices.append(torch.full((numel,), idx, dtype=torch.int64)) cu_seqlens.append(cu_seqlens[-1] + numel) if video_type == "video_embeds": raise ValueError( - "Video embeddings are not supported for this processing path.") + "Video embeddings are not supported for this processing path." + ) else: pixel_values_videos = pixel_values_videos.type(self.visual.dtype) siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( - pixel_values_videos.device) + pixel_values_videos.device + ) cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( - pixel_values_videos.device) - sample_indices = torch.concat(sample_indices, - dim=0).to(pixel_values_videos.device) + pixel_values_videos.device + ) + sample_indices = torch.concat(sample_indices, dim=0).to( + pixel_values_videos.device + ) video_embeds = self.visual( pixel_values=pixel_values_videos, @@ -1422,14 +1455,16 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} for input_key in kwargs: - if (input_key in ("pixel_values", "image_embeds") - and "images" not in modalities): - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if (input_key in ("pixel_values_videos", "video_embeds") - and "videos" not in modalities): - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "videos" not in modalities + ): + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities @@ -1437,8 +1472,8 @@ def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - + self, **kwargs: object + ) -> Optional[MultiModalEmbeddings]: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return None @@ -1495,8 +1530,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> 
set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -1514,40 +1548,21 @@ def get_mm_mapping(self) -> MultiModelKeys: info=KeyeProcessingInfo, dummy_inputs=KeyeDummyInputsBuilder, ) -class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, - SupportsLoRA, SupportsPP): - - def _build_projector(self, - text_config: PretrainedConfig, - vision_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: +class KeyeForConditionalGeneration( + BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP +): + def _build_projector( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: return Projector(text_config, vision_config, quant_config, prefix) - def _validate_and_reshape_mm_tensor( - self, mm_input: NestedTensors, - name: str) -> Union[torch.Tensor, list[torch.Tensor]]: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim == 5: - return mm_input - if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return mm_input.reshape(-1, mm_input.shape[-1]) - elif is_list_of(mm_input, torch.Tensor): - if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2 - for p in mm_input): - return mm_input - return torch.concat(mm_input) - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[KeyeImageInputs]: + self, **kwargs: object + ) -> Optional[KeyeImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1556,11 +1571,6 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - return KeyeImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -1568,11 +1578,6 @@ def _parse_and_validate_image_input( ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") - return KeyeImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -1580,7 +1585,8 @@ def _parse_and_validate_image_input( ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[KeyeVideoInputs]: + self, **kwargs: object + ) -> Optional[KeyeVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1589,13 +1595,6 @@ def _parse_and_validate_video_input( return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, - "video pixel values", - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - return KeyeVideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, @@ -1603,11 +1602,6 @@ def _parse_and_validate_video_input( ) if video_embeds is 
not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") - return KeyeVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, @@ -1615,11 +1609,12 @@ def _parse_and_validate_video_input( ) def _process_video_input( - self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]: + self, video_input: KeyeVideoInputs + ) -> tuple[torch.Tensor, ...]: video_type = video_input["type"] video_grid_thw = video_input["video_grid_thw"] pixel_values_videos = video_input.get("pixel_values_videos", None) return tuple( - self._process_video_embeds(video_type, video_grid_thw, - pixel_values_videos)) + self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos) + ) diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 93a3bf5f98f7..578436fcad21 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -15,22 +15,36 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalFieldConfig, - MultiModalKwargsItems, VideoItem) -from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, - MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .keye import (BaseKeyeModule, BaseMultiModalProcessor, - KeyeBaseDummyInputsBuilder, KeyeProcessingInfo) +from .keye import ( + BaseKeyeModule, + BaseMultiModalProcessor, + KeyeBaseDummyInputsBuilder, + KeyeProcessingInfo, +) logger = init_logger(__name__) @@ -58,8 +72,9 @@ def split_thw(grid_thw: torch.Tensor) -> torch.Tensor: return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0) -def get_num_patches(grid_thw: torch.Tensor, - num_frames: Union[list[int], torch.Tensor]) -> list[int]: +def get_num_patches( + grid_thw: torch.Tensor, num_frames: Union[list[int], torch.Tensor] +) -> list[int]: """ Return num_patches per video. @@ -73,9 +88,13 @@ def get_num_patches(grid_thw: torch.Tensor, Examples: >>> # Suppose there are 2 videos with a total of 3 grids - >>> grid_thw = torch.tensor([[2, 2, 2], # grid 0: 2*2*2=8 patches - ... [2, 2, 2], # grid 1: 2*2*2=8 patches - ... [1, 1, 1]]) # grid 2: 1*1*1=1 patches + >>> grid_thw = torch.tensor( + ... [ + ... [2, 2, 2], # grid 0: 2*2*2=8 patches + ... [2, 2, 2], # grid 1: 2*2*2=8 patches + ... [1, 1, 1], + ... ] + ... ) # grid 2: 1*1*1=1 patches >>> num_frames = [2, 1] # The first video contains 2 grids, the second contains 1 grid. 
>>> get_num_patches(grid_thw, num_frames) @@ -90,28 +109,31 @@ def get_num_patches(grid_thw: torch.Tensor, num_grids_per_frame = grid_thw.prod(dim=1) start_idx_per_video = [0, *itertools.accumulate(num_frames)] num_patches = [ - num_grids_per_frame[start_idx_per_video[i]:start_idx_per_video[i + 1]]. - sum() for i in range(len(num_frames)) + num_grids_per_frame[start_idx_per_video[i] : start_idx_per_video[i + 1]].sum() + for i in range(len(num_frames)) ] - return torch.stack(num_patches) if num_patches else torch.zeros( - 0, dtype=grid_thw.dtype, device=grid_thw.device) + return ( + torch.stack(num_patches) + if num_patches + else torch.zeros(0, dtype=grid_thw.dtype, device=grid_thw.device) + ) class KeyeVL1_5ImagePixelInputs(TensorSchema): """ Dimensions: - - b: Batch size - - np: Number of patches + - bnp: Batch size * Number of patches - c: Number of channels - ps: Patch size - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ + type: Literal["pixel_values"] pixel_values: Annotated[ - torch.Tensor, - TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})] + torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) + ] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] @@ -125,29 +147,29 @@ class KeyeVL1_5ImageEmbeddingInputs(TensorSchema): - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ + type: Literal["image_embeds"] image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] -KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs, - KeyeVL1_5ImageEmbeddingInputs] +KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs, KeyeVL1_5ImageEmbeddingInputs] class KeyeVL1_5VideoPixelInputs(TensorSchema): """ Dimensions: - - b: Batch size - - np: Number of patches + - bnp: Batch size * Number of patches - c: Number of channels - ps: Patch size - ni: Number of images - g: Grid dimensions (3 for t, h, w) """ + type: Literal["pixel_values_videos"] pixel_values_videos: Annotated[ - torch.Tensor, - TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})] + torch.Tensor, TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"}) + ] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] num_frames: torch.Tensor @@ -162,18 +184,17 @@ class KeyeVL1_5VideoEmbeddingInputs(TensorSchema): - nv: Number of videos - g: Grid dimensions (3 for t, h, w) """ + type: Literal["video_embeds"] video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] num_frames: torch.Tensor -KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs, - KeyeVL1_5VideoEmbeddingInputs] +KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs, KeyeVL1_5VideoEmbeddingInputs] class KeyeVL1_5Projector(nn.Module): - def __init__( self, text_config: PretrainedConfig, @@ -186,9 +207,11 @@ def __init__( self.vision_config = vision_config self.merge_kernel_size = (2, 2) - self.hidden_size = (self.vision_config.hidden_size * - self.merge_kernel_size[0] * - self.merge_kernel_size[1]) + self.hidden_size = ( + self.vision_config.hidden_size + * self.merge_kernel_size[0] + * self.merge_kernel_size[1] + ) self.pre_norm = torch.nn.LayerNorm(self.hidden_size, eps=1e-05) self.act = GELUActivation() @@ -210,15 +233,13 @@ def __init__( def forward( self, - image_features: Union[torch.Tensor, tuple[torch.Tensor], - list[torch.Tensor]], + image_features: Union[torch.Tensor, tuple[torch.Tensor], list[torch.Tensor]], image_grid_thw: list[tuple[int, int, int]], ) -> Union[torch.Tensor, 
list[torch.Tensor]]: m1, m2 = self.merge_kernel_size if isinstance(image_features, (list, tuple)): processed_features = list() - for image_feature, image_grid in zip(image_features, - image_grid_thw): + for image_feature, image_grid in zip(image_features, image_grid_thw): t, h, w = image_grid image_feature = rearrange( image_feature, @@ -240,8 +261,7 @@ def forward( dims = image_features.shape[:-1] dim = image_features.shape[-1] image_features = image_features.view(np.prod(dims), dim) - hidden_states = self.pre_norm(image_features.view( - -1, self.hidden_size)) + hidden_states = self.pre_norm(image_features.view(-1, self.hidden_size)) hidden_states = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) @@ -250,24 +270,28 @@ def forward( class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo): - def get_max_frame_per_video(self) -> int: return 2048 - def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]: + def get_supported_mm_limits( + self, + ) -> Mapping[str, Optional[int]]: return {"image": None, "video": 1} -def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ): - image_grid_thw = hf_inputs.get("image_grid_thw", - torch.empty((0, 3), dtype=torch.int64)) +def _keye_field_config( + hf_inputs: Mapping[str, torch.Tensor], +): + image_grid_thw = hf_inputs.get( + "image_grid_thw", torch.empty((0, 3), dtype=torch.int64) + ) image_grid_sizes = image_grid_thw.prod(-1) - video_grid_thw = hf_inputs.get("video_grid_thw", - torch.empty((0, 3), dtype=torch.int64)) + video_grid_thw = hf_inputs.get( + "video_grid_thw", torch.empty((0, 3), dtype=torch.int64) + ) video_grid_thw = split_thw(video_grid_thw) - num_frames = hf_inputs.get("num_frames", - video_grid_thw[:, 0]).clone().tolist() + num_frames = hf_inputs.get("num_frames", video_grid_thw[:, 0]).clone().tolist() video_num_patches = get_num_patches(video_grid_thw, num_frames) @@ -287,22 +311,20 @@ def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ): else: j += 1 video_num_grids = torch.tensor(video_num_grids) - return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_patches), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_patches), - video_grid_thw=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_grids), - num_frames=MultiModalFieldConfig.batched("video")) + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes("image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_patches + ), + video_embeds=MultiModalFieldConfig.flat_from_sizes("video", video_num_patches), + video_grid_thw=MultiModalFieldConfig.flat_from_sizes("video", video_num_grids), + num_frames=MultiModalFieldConfig.batched("video"), + ) class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): - def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -338,9 +360,7 @@ def _parse_video_data( return super()._parse_video_data(data) -class KeyeVL1_5MultiModalProcessor( - BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]): - +class 
KeyeVL1_5MultiModalProcessor(BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return KeyeVL1_5MultiModalDataParser() @@ -351,8 +371,7 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() image_token_id = vocab[hf_processor.image_token] @@ -361,44 +380,49 @@ def _get_prompt_updates( merge_length = image_processor.merge_size**2 out_mm_kwargs_data = out_mm_kwargs.get_data() - frame_types: list[torch.Tensor] = \ - hf_processor_mm_kwargs.get("frame_types", None) - timestamps: list[torch.Tensor] = \ - hf_processor_mm_kwargs.get("timestamps", None) + frame_types: list[torch.Tensor] = hf_processor_mm_kwargs.get( + "frame_types", None + ) + timestamps: list[torch.Tensor] = hf_processor_mm_kwargs.get("timestamps", None) num_videos = mm_items.get_count("video", strict=False) if frame_types is None: frame_types = [None] * num_videos - assert len(frame_types) == num_videos, \ - f"Number of frame_types={len(frame_types)} " \ + assert len(frame_types) == num_videos, ( + f"Number of frame_types={len(frame_types)} " f"doesn't equal to number of videos={num_videos}" + ) if timestamps is None: timestamps = [None] * num_videos - assert len(timestamps) == num_videos, \ - f"Number of timestamps={len(timestamps)} " \ + assert len(timestamps) == num_videos, ( + f"Number of timestamps={len(timestamps)} " f"doesn't equal to number of videos={num_videos}" + ) video_grid_thw = out_mm_kwargs_data.get( - 'video_grid_thw', torch.empty((0, 3), dtype=torch.int64)) + "video_grid_thw", torch.empty((0, 3), dtype=torch.int64) + ) num_frames = out_mm_kwargs_data.get( - 'num_frames', torch.tensor([], dtype=torch.int64)) + "num_frames", torch.tensor([], dtype=torch.int64) + ) - assert len(num_frames) == num_videos, \ - f"Size of num_frames={len(num_frames)} " \ + assert len(num_frames) == num_videos, ( + f"Size of num_frames={len(num_frames)} " f"doesn't equal to number of videos={num_videos}" + ) video_grid_hws = split_thw(video_grid_thw) assert int(num_frames.sum().tolist()) == video_grid_hws.shape[0], ( f"The first dimension of `video_grid_hws`={video_grid_hws.shape[0]}" - f"doesn't equal to num of frames.") + f"doesn't equal to num of frames." 
+ ) - cu_seqlens = torch.cumsum(torch.tensor([0] + num_frames.tolist()), - dim=-1) + cu_seqlens = torch.cumsum(torch.tensor([0] + num_frames.tolist()), dim=-1) def get_replacement_keye(item_idx: int, modality: str): """ Args: - item_idx(int): The item index of modality to replace + item_idx(int): The item index of modality to replace modality(str): The modality """ if modality == "image": @@ -413,16 +437,15 @@ def get_replacement_keye(item_idx: int, modality: str): video_timestamps = timestamps[item_idx] video_frame_types = frame_types[item_idx] grid_thw = video_grid_hws[ - cu_seqlens[item_idx]:cu_seqlens[item_idx + 1]] + cu_seqlens[item_idx] : cu_seqlens[item_idx + 1] + ] nframes = grid_thw.shape[0] if video_timestamps is None: video_timestamps = [""] * nframes else: - video_timestamps = [ - format(ts, ".1f") for ts in video_timestamps - ] + video_timestamps = [format(ts, ".1f") for ts in video_timestamps] if video_frame_types is None: video_frame_types = [0] * nframes @@ -437,7 +460,8 @@ def get_replacement_keye(item_idx: int, modality: str): placeholders.append(vocab[hf_processor.fast_end]) return PromptUpdateDetails.select_token_id( - placeholders, embed_token_id=video_token_id) + placeholders, embed_token_id=video_token_id + ) else: raise ValueError(f"Unsupported modality {modality}") @@ -446,7 +470,8 @@ def get_replacement_keye(item_idx: int, modality: str): modality=modality, target=[placeholder[modality]], replacement=partial(get_replacement_keye, modality=modality), - ) for modality in ("image", "video") + ) + for modality in ("image", "video") ] def _get_mm_fields_config( @@ -458,8 +483,8 @@ def _get_mm_fields_config( class KeyeVL1_5DummyInputsBuilder( - KeyeBaseDummyInputsBuilder[KeyeVL1_5ProcessingInfo]): - ... + KeyeBaseDummyInputsBuilder[KeyeVL1_5ProcessingInfo] +): ... @MULTIMODAL_REGISTRY.register_processor( @@ -467,42 +492,26 @@ class KeyeVL1_5DummyInputsBuilder( info=KeyeVL1_5ProcessingInfo, dummy_inputs=KeyeVL1_5DummyInputsBuilder, ) -class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, - SupportsLoRA, SupportsPP): - - def _build_projector(self, - text_config: PretrainedConfig, - vision_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: - return KeyeVL1_5Projector(text_config, vision_config, quant_config, - prefix) +class KeyeVL1_5ForConditionalGeneration( + BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP +): + def _build_projector( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + return KeyeVL1_5Projector(text_config, vision_config, quant_config, prefix) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config: PretrainedConfig = vllm_config.model_config.hf_config self.merge_size = config.vision_config.spatial_merge_size super().__init__(vllm_config=vllm_config, prefix=prefix) - def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors, - expected_dim: int, name: str): - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == expected_dim: - return mm_input - elif mm_input.ndim == expected_dim + 1: - return mm_input.reshape(-1, *mm_input.shape[2:]) - else: - raise ValueError( - f"{name} should be {expected_dim}D or " - f"batched {expected_dim}D tensor." 
- f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})") - else: - return torch.concat(mm_input) - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]: + self, **kwargs: object + ) -> Optional[KeyeVL1_5ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -511,11 +520,6 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, expected_dim=4, name="image pixel values") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, expected_dim=2, name="image grid_thw") - return KeyeVL1_5ImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -523,11 +527,6 @@ def _parse_and_validate_image_input( ) if image_embeds is not None: - image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, expected_dim=2, name="image embeds") - image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, expected_dim=2, name="image grid_thw") - return KeyeVL1_5ImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, @@ -535,7 +534,8 @@ def _parse_and_validate_image_input( ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[KeyeVL1_5VideoInputs]: + self, **kwargs: object + ) -> Optional[KeyeVL1_5VideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -545,43 +545,31 @@ def _parse_and_validate_video_input( return None if pixel_values_videos is not None: - pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, - expected_dim=4, - name="video pixel values", - ) - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, expected_dim=2, name="video grid_thw") - - num_frames = self._validate_and_reshape_mm_tensor( - num_frames, expected_dim=1, name="video num frames") - return KeyeVL1_5VideoPixelInputs( type="pixel_values_videos", pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, - num_frames=num_frames) + num_frames=num_frames, + ) if video_embeds is not None: - video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, expected_dim=2, name="video embeds") - video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, expected_dim=2, name="video grid_thw") - - return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds", - video_embeds=video_embeds, - video_grid_thw=video_grid_thw, - num_frames=num_frames) + return KeyeVL1_5VideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + num_frames=num_frames, + ) def _process_video_input( - self, - video_input: KeyeVL1_5VideoInputs) -> tuple[torch.Tensor, ...]: + self, video_input: KeyeVL1_5VideoInputs + ) -> tuple[torch.Tensor, ...]: video_type = video_input["type"] video_grid_thw = split_thw(video_input["video_grid_thw"]) pixel_values_videos = video_input.get("pixel_values_videos", None) - video_embeds = self._process_video_embeds(video_type, video_grid_thw, - pixel_values_videos) + video_embeds = self._process_video_embeds( + video_type, video_grid_thw, pixel_values_videos + ) video_embeds = torch.concat(video_embeds, dim=0) num_frames = video_input["num_frames"].clone().tolist() @@ -589,10 +577,11 @@ def _process_video_input( num_patches = get_num_patches(video_grid_thw, 
num_frames).tolist() patch_cu_seqlens = torch.cumsum( - torch.tensor([0] + num_patches).detach().clone(), dim=-1) - patch_cu_seqlens = torch.div(patch_cu_seqlens, - self.merge_size**2, - rounding_mode="floor") + torch.tensor([0] + num_patches).detach().clone(), dim=-1 + ) + patch_cu_seqlens = torch.div( + patch_cu_seqlens, self.merge_size**2, rounding_mode="floor" + ) new_video_embeds = [] for idx in range(patch_cu_seqlens.shape[0] - 1): diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index a47bdd2f5ab5..f7381e6b6b93 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -60,21 +60,34 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - SupportsPP) +from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP from vllm.model_executor.models.moonvit import MoonVitPretrainedModel from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig @@ -93,33 +106,35 @@ class MaxImageTokenMeta: class KimiVLMultiModalProjector(nn.Module): - - def __init__(self, config: KimiVLConfig, \ - use_data_parallel: bool = False, prefix: str = ""): + def __init__( + self, config: KimiVLConfig, use_data_parallel: bool = False, prefix: str = "" + ): super().__init__() self.use_data_parallel = use_data_parallel - self.hidden_size = (config.vision_config.hidden_size * - config.vision_config.merge_kernel_size[0] * - config.vision_config.merge_kernel_size[1]) - - self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, - eps=1e-5) - self.linear_1 = ReplicatedLinear(self.hidden_size, - self.hidden_size, - bias=True, - prefix=maybe_prefix( - prefix, "linear_1")) - self.linear_2 = ReplicatedLinear(self.hidden_size, - config.text_config.hidden_size, - bias=True, - prefix=maybe_prefix( - prefix, "linear_2")) + self.hidden_size = ( + config.vision_config.hidden_size + * config.vision_config.merge_kernel_size[0] + * config.vision_config.merge_kernel_size[1] + ) + + self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-5) + self.linear_1 = ReplicatedLinear( + self.hidden_size, + 
self.hidden_size, + bias=True, + prefix=maybe_prefix(prefix, "linear_1"), + ) + self.linear_2 = ReplicatedLinear( + self.hidden_size, + config.text_config.hidden_size, + bias=True, + prefix=maybe_prefix(prefix, "linear_2"), + ) self.act = GELUActivation() def forward(self, image_features: torch.Tensor) -> torch.Tensor: - hidden_states = self.pre_norm(image_features).view( - -1, self.hidden_size) + hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size) hidden_states, _ = self.linear_1(hidden_states) hidden_states = self.act(hidden_states) hidden_states, _ = self.linear_2(hidden_states) @@ -134,6 +149,7 @@ class KimiVLImagePixelInputs(TensorSchema): - ps: Patch size - ni: Number of images """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ @@ -150,7 +166,6 @@ class KimiVLImagePixelInputs(TensorSchema): class KimiVLProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(KimiVLConfig) @@ -169,25 +184,25 @@ def get_num_image_tokens( in_token_limit = hf_processor.image_processor.in_token_limit height = image_height width = image_width - assert isinstance(height, - int), f"height must be int, current height {height}" - assert isinstance(width, - int), f"width must be int, current width {width}" + assert isinstance(height, int), f"height must be int, current height {height}" + assert isinstance(width, int), f"width must be int, current width {width}" assert kernel_size is not None, "kernel_size must be specified" if (width // patch_size) * (height // patch_size) > in_token_limit: - scale = math.sqrt(in_token_limit / ((width // patch_size) * - (height // patch_size))) + scale = math.sqrt( + in_token_limit / ((width // patch_size) * (height // patch_size)) + ) new_w, new_h = int(width * scale), int(height * scale) width, height = new_w, new_h kernel_height, kernel_width = kernel_size - pad_height = (kernel_height * patch_size - height % - (kernel_height * patch_size)) % (kernel_height * - patch_size) - pad_width = (kernel_width * patch_size - width % - (kernel_width * patch_size)) % (kernel_width * patch_size) + pad_height = ( + kernel_height * patch_size - height % (kernel_height * patch_size) + ) % (kernel_height * patch_size) + pad_width = ( + kernel_width * patch_size - width % (kernel_width * patch_size) + ) % (kernel_width * patch_size) # Calculate new dimensions after padding and patching token_height = (height + pad_height) // (kernel_size[0] * patch_size) @@ -200,7 +215,6 @@ def image_token_id(self) -> int: class KimiVLDummyInputsBuilder(BaseDummyInputsBuilder[KimiVLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -220,16 +234,16 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=MaxImageTokenMeta.width, - height=MaxImageTokenMeta.height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=MaxImageTokenMeta.width, + height=MaxImageTokenMeta.height, + num_images=num_images, + overrides=image_overrides, + ) } class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]): - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -242,7 +256,8 @@ def _get_mm_fields_config( # image_grid_hws is shapes for each subtensor in pixel_values return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), + "image", image_grid_sizes + ), 
image_grid_hws=MultiModalFieldConfig.batched("image"), ) @@ -256,7 +271,8 @@ def _get_prompt_updates( def get_replacement(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -278,11 +294,13 @@ def get_replacement(item_idx: int): ] -@MULTIMODAL_REGISTRY.register_processor(KimiVLMultiModalProcessor, - info=KimiVLProcessingInfo, - dummy_inputs=KimiVLDummyInputsBuilder) -class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + KimiVLMultiModalProcessor, + info=KimiVLProcessingInfo, + dummy_inputs=KimiVLDummyInputsBuilder, +) +class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True supports_encoder_tp_data = True @@ -305,21 +323,27 @@ def __init__( quant_config = vllm_config.quant_config assert isinstance(config.vision_config, MoonViTConfig) - self.use_data_parallel = model_config.multimodal_config.mm_encoder_tp_mode == "data" + self.use_data_parallel = ( + model_config.multimodal_config.mm_encoder_tp_mode == "data" + ) self.hidden_size = config.text_config.hidden_size - self.vision_tower = MoonVitPretrainedModel(config.vision_config, - self.use_data_parallel, - prefix=maybe_prefix( - prefix, "vision_tower")) + self.vision_tower = MoonVitPretrainedModel( + config.vision_config, + self.use_data_parallel, + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = KimiVLMultiModalProjector( config=config, use_data_parallel=self.use_data_parallel, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) self.quant_config = quant_config sub_vllm_config = copy.deepcopy(vllm_config) - sub_vllm_config.model_config.hf_config = sub_vllm_config.model_config.hf_config.text_config + sub_vllm_config.model_config.hf_config = ( + sub_vllm_config.model_config.hf_config.text_config + ) self.language_model = DeepseekV2Model( vllm_config=sub_vllm_config, prefix=maybe_prefix(prefix, "language_model"), @@ -336,31 +360,17 @@ def __init__( else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) self.media_placeholder: int = self.config.media_placeholder_token_id - # ref: qwen2_vl.py - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: - if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") - if isinstance(mm_input, torch.Tensor): - if mm_input.ndim == 2: - return mm_input - if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") - return mm_input.reshape(-1, mm_input.shape[-1]) - else: - return torch.concat(mm_input) - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[KimiVLImageInputs]: + self, **kwargs: object + ) -> Optional[KimiVLImageInputs]: # image input type must be pixel values now pixel_values = kwargs.pop("pixel_values", None) image_grid_hws = kwargs.pop("image_grid_hws", None) @@ -368,21 +378,6 @@ def _parse_and_validate_image_input( if pixel_values is None: return None - image_grid_hws = self._validate_and_reshape_mm_tensor( - image_grid_hws, "image grid hws") - # pixel_values may have complex shapes - num_channels = 3 - patch_size = self.config.vision_config.patch_size - if isinstance(pixel_values, list): - pixel_values = torch.cat([ - x.reshape(-1, num_channels, patch_size, patch_size) - for x in pixel_values - ]) - else: - pixel_values = pixel_values.reshape(-1, num_channels, patch_size, - patch_size) - pixel_values = pixel_values.to(self.vision_tower.dtype) - return KimiVLImagePixelInputs( type="pixel_values", pixel_values=pixel_values, @@ -391,34 +386,32 @@ def _parse_and_validate_image_input( # perform vt on processored pixel_values @torch.inference_mode() - def _process_image_pixels(self, - inputs: KimiVLImagePixelInputs) -> torch.Tensor: + def _process_image_pixels(self, inputs: KimiVLImagePixelInputs) -> torch.Tensor: assert self.vision_tower is not None pixel_values = inputs["pixel_values"] image_grid_hws = inputs["image_grid_hws"] if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.vision_tower, - pixel_values, - image_grid_hws.tolist(), - rope_type="rope_2d") + return run_dp_sharded_mrope_vision_model( + self.vision_tower, + pixel_values, + image_grid_hws.tolist(), + rope_type="rope_2d", + ) else: return self.vision_tower(pixel_values, image_grid_hws) - def _process_image_input(self, - image_input: KimiVLImageInputs) -> torch.Tensor: + def _process_image_input(self, image_input: KimiVLImageInputs) -> torch.Tensor: assert image_input["type"] == "pixel_values" image_features = self._process_image_pixels(image_input) assert isinstance(image_features, (list, tuple)) lengths = [x.shape[0] for x in image_features] - return self.multi_modal_projector( - torch.cat(image_features)).split(lengths) + return self.multi_modal_projector(torch.cat(image_features)).split(lengths) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> Optional[NestedTensors]: + def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]: # Validate the multimodal input keyword arguments image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: @@ -448,8 +441,7 @@ def forward( return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - **kwargs) -> torch.Tensor: + def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, **kwargs) return logits @@ -478,7 +470,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=config.n_routed_experts) + num_experts=config.n_routed_experts, + ) else: expert_params_mapping = [] @@ -494,8 +487,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): if spec_layer is not None: continue # skip spec decode layers 
for main model - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue @@ -509,8 +501,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # not vision model for now. use_default_weight_loading = True else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue # We have mlp.experts[0].gate_proj in the checkpoint. @@ -519,7 +510,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. @@ -534,8 +525,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id, **kwargs) break else: - for idx, (param_name, weight_name, expert_id, - shard_id) in enumerate(expert_params_mapping): + for idx, ( + param_name, + weight_name, + expert_id, + shard_id, + ) in enumerate(expert_params_mapping): if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -545,12 +540,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - expert_id=expert_id, - shard_id=shard_id, - **kwargs) + weight_loader( + param, + loaded_weight, + name, + expert_id=expert_id, + shard_id=shard_id, + **kwargs, + ) break else: use_default_weight_loading = True @@ -567,18 +564,18 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight, **kwargs) -def get_spec_layer_idx_from_weight_name(config: DeepseekV2Config, - weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): +def get_spec_layer_idx_from_weight_name( + config: DeepseekV2Config, weight_name: str +) -> Optional[int]: + if hasattr(config, "num_nextn_predict_layers") and ( + config.num_nextn_predict_layers > 0 + ): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): - if weight_name.startswith(f"model.layers.{layer_idx+i}."): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): return layer_idx + i return None diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index f9def222a1ec..ae5c97426ee7 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -14,30 +14,40 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + 
MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.short_conv import ShortConv from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, SupportsQuant +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Lfm2MLP(nn.Module): - def __init__( self, dim: int, @@ -80,7 +90,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Lfm2Attention(nn.Module): - def __init__( self, config: Lfm2Config, @@ -177,7 +186,6 @@ def forward( class Lfm2AttentionDecoderLayer(nn.Module): - def __init__( self, config: Lfm2Config, @@ -195,11 +203,12 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = Lfm2Attention( config=config, @@ -238,16 +247,13 @@ def forward( residual = hidden_states hidden_states = self.operator_norm(hidden_states) else: - hidden_states, residual = self.operator_norm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) + hidden_states, residual = self.operator_norm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) hidden_states, residual = self.ffn_norm(hidden_states, residual) return self.feed_forward(hidden_states), residual class Lfm2ShortConvDecoderLayer(nn.Module): - def __init__( self, config: Lfm2Config, @@ -290,8 +296,7 @@ def forward( residual = hidden_states hidden_states = self.operator_norm(hidden_states) else: - hidden_states, residual = self.operator_norm( - hidden_states, residual) + hidden_states, residual = self.operator_norm(hidden_states, residual) output = torch.empty_like(hidden_states) self.conv( hidden_states, @@ -304,7 +309,6 @@ def forward( @support_torch_compile class Lfm2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -315,21 +319,24 @@ def __init__(self, *, vllm_config: 
VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size) + self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size + ) def get_layer(prefix: str): layer_idx = extract_layer_index(prefix) is_attn = self.config.layer_types[layer_idx] == "full_attention" - layer_class = (Lfm2AttentionDecoderLayer - if is_attn else Lfm2ShortConvDecoderLayer) + layer_class = ( + Lfm2AttentionDecoderLayer if is_attn else Lfm2ShortConvDecoderLayer + ) return layer_class( config, layer_idx, @@ -340,14 +347,14 @@ def get_layer(prefix: str): ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: - self.embedding_norm = RMSNorm(config.hidden_size, - eps=config.norm_eps) + self.embedding_norm = RMSNorm(config.hidden_size, eps=config.norm_eps) else: self.embedding_norm = PPMissingLayer() @@ -379,15 +386,13 @@ def forward( residual=residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.embedding_norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), @@ -398,7 +403,6 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -414,15 +418,15 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class Lfm2ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsQuant): +class Lfm2ForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant +): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -447,7 +451,6 @@ def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, ...]: - return MambaStateDtypeCalculator.short_conv_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -458,7 +461,7 @@ def 
get_mamba_state_shape_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[tuple[int, int]]: - """ Calculate shapes for LFM2's convolutional cache. + """Calculate shapes for LFM2's convolutional cache. Args: vllm_config: vLLM config @@ -482,8 +485,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert (not cache_config.enable_prefix_caching - ), "Lfm2 currently does not support prefix caching" + assert not cache_config.enable_prefix_caching, ( + "Lfm2 currently does not support prefix caching" + ) super().__init__() self.config = config @@ -491,8 +495,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.scheduler_config = scheduler_config self.model_config = vllm_config.model_config - self.model = Lfm2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Lfm2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = self.config.vocab_size @@ -507,8 +512,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -516,11 +522,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: else: self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -533,19 +541,18 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a6081d331511..948c9280f953 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only LLaMA model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -38,27 +39,38 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class LlamaMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -89,8 +101,9 @@ def __init__( prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -101,7 +114,6 @@ def forward(self, x): class LlamaAttention(nn.Module): - def __init__( self, config: LlamaConfig, @@ -141,8 +153,7 @@ def __init__( head_dim = self.hidden_size // self.total_num_heads self.head_dim = head_dim # Phi models introduced a partial_rotary_factor parameter in the config - self.partial_rotary_factor = getattr(config, "partial_rotary_factor", - 1) + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -167,33 +178,36 @@ def __init__( prefix=f"{prefix}.o_proj", ) - self._init_rotary_emb(config, - rope_scaling=rope_scaling, - quant_config=quant_config) + self._init_rotary_emb( + config, rope_scaling=rope_scaling, quant_config=quant_config + ) sliding_window = None if layer_types := getattr(config, "layer_types", None): # Fix for Eagle3 compatibility: # for draft models, subtract target layer count # to get draft-relative layer index starting from 0 - if hasattr(config, 'target_layer_count'): + if hasattr(config, "target_layer_count"): # This is a draft model, # adjust layer_idx to be relative to draft layers effective_layer_idx = layer_idx - config.target_layer_count else: # This is a target model, use layer_idx directly effective_layer_idx = layer_idx - assert effective_layer_idx < len(layer_types), \ + assert effective_layer_idx < len(layer_types), ( f"effective_layer_idx: {effective_layer_idx} \ is out of bounds for layer_types: {layer_types}" + ) - is_sliding = layer_types[ - effective_layer_idx] == "sliding_attention" + is_sliding = layer_types[effective_layer_idx] == "sliding_attention" if is_sliding: sliding_window = config.sliding_window - attn_cls = (EncoderOnlyAttention - if attn_type == AttentionType.ENCODER_ONLY else Attention) + attn_cls = ( + EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY + else Attention + ) self.attn = attn_cls( self.num_heads, @@ -219,9 +233,12 @@ def forward( output, _ = self.o_proj(attn_output) return output - def _init_rotary_emb(self, config: LlamaConfig, - rope_scaling: Optional[dict[str, Any]], - quant_config: Optional[QuantizationConfig]) -> None: + def _init_rotary_emb( + self, + config: LlamaConfig, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig], + ) -> None: is_neox_style = True is_gguf = quant_config and quant_config.get_name() == "gguf" if is_gguf and config.model_type == "llama": @@ -239,11 +256,12 @@ def _init_rotary_emb(self, config: LlamaConfig, class LlamaDecoderLayer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str = "", - config: Optional[LlamaConfig] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[LlamaConfig] = None, + ) -> None: super().__init__() config = config or vllm_config.model_config.hf_config @@ -254,18 +272,20 @@ def __init__(self, rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, 
"max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) bias_o_proj = attention_bias # support internlm/internlm3-8b with qkv_bias - if hasattr(config, 'qkv_bias'): + if hasattr(config, "qkv_bias"): attention_bias = config.qkv_bias # By default, Llama uses causal attention as it is a decoder-only model. @@ -281,8 +301,9 @@ def __init__(self, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -301,10 +322,10 @@ def __init__(self, bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -317,31 +338,28 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual - def get_quant_config( - self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]: + def get_quant_config(self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]: """Get quantization config for this layer. 
Override in subclasses.""" return vllm_config.quant_config @support_torch_compile class LlamaModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config @@ -350,12 +368,16 @@ def __init__(self, self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -376,9 +398,9 @@ def __init__(self, self.aux_hidden_state_layers = tuple[int, ...]() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -389,8 +411,9 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, - list[torch.Tensor]]]: + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -404,16 +427,16 @@ def forward( aux_hidden_states = [] for idx, layer in enumerate( - islice(self.layers, self.start_layer, self.end_layer)): + islice(self.layers, self.start_layer, self.end_layer) + ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) @@ -421,8 +444,7 @@ def forward( return hidden_states, aux_hidden_states return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -436,19 +458,19 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -481,8 +503,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -491,13 +512,13 @@ def load_weights(self, weights: Iterable[tuple[str, class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } # LoRA specific attributes embedding_modules = { "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings" + "lm_head": "output_embeddings", } embedding_padding_modules = ["lm_head"] @@ -527,11 +548,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): "norm": "model.norm", } - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -539,9 +562,11 @@ def __init__(self, self.config = config self.lora_config = lora_config - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model"), - layer_type=layer_type) + self.model = self._init_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + layer_type=layer_type, + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size @@ -555,39 +580,45 @@ def __init__(self, DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + 
"""Override to return default layers for Llama + + Note: The GPU model runner will override this with layers from + the speculative config if available, providing dynamic configuration. + """ num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) - def _init_model(self, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): - return LlamaModel(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer, + ): + return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -599,8 +630,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( @@ -610,16 +642,15 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights( self.maybe_remap_mistral(name, loaded_weight) - for name, loaded_weight in weights) + for name, loaded_weight in weights + ) # This function is used to remap the mistral format as # used by Mistral and Llama <=2 @@ -628,12 +659,14 @@ def maybe_remap_mistral( name: str, loaded_weight: torch.Tensor, ) -> tuple[str, torch.Tensor]: - def permute(w: torch.Tensor, n_heads: int, attn_out: int): attn_in = self.config.head_dim * n_heads - return w.view(n_heads, attn_in // n_heads // 2, 2, - attn_out).transpose(1, 2).reshape(attn_in, attn_out) + return ( + w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + .transpose(1, 2) + .reshape(attn_in, attn_out) + ) mapping = self.mistral_mapping modules = name.split(".") @@ -642,29 +675,32 @@ def permute(w: torch.Tensor, n_heads: int, attn_out: int): # If using quantized model in mistral format, # quantization scales (qscale_weight) also need to be sliced if "wk" in modules and modules[-1] == "weight": - loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads, - self.config.hidden_size) - elif "wk" in modules and modules[ - -1] == "qscale_weight" and loaded_weight.numel() > 1: - loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads, 1) + loaded_weight = permute( + loaded_weight, self.config.num_key_value_heads, self.config.hidden_size + ) + elif ( + "wk" in modules + and modules[-1] == "qscale_weight" + and loaded_weight.numel() > 1 + ): + loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1) elif "wq" in modules and modules[-1] == "weight": - loaded_weight = permute(loaded_weight, - self.config.num_attention_heads, - self.config.hidden_size) - elif "wq" in modules and modules[ - -1] == "qscale_weight" and loaded_weight.numel() > 1: - loaded_weight = permute(loaded_weight, - 
self.config.num_attention_heads, 1) + loaded_weight = permute( + loaded_weight, self.config.num_attention_heads, self.config.hidden_size + ) + elif ( + "wq" in modules + and modules[-1] == "qscale_weight" + and loaded_weight.numel() > 1 + ): + loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1) num_modules = len(modules) for i in range(num_modules): item = modules[i] next_item = modules[i + 1] if i < num_modules - 1 else None - combined_item = (f"{item}.{next_item}" - if next_item is not None else None) + combined_item = f"{item}.{next_item}" if next_item is not None else None if combined_item in mapping: name = name.replace(combined_item, mapping[combined_item]) diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 32d4f69c6bf1..075f35a098a4 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -17,6 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" + from collections.abc import Iterable from typing import Any, Optional @@ -28,27 +29,36 @@ from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.utils import sequence_parallel_chunk from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel -from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk, - is_pp_missing_parameter) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + fast_topk, + is_pp_missing_parameter, +) class Llama4MoE(nn.Module): - @staticmethod def custom_routing_function( hidden_states: torch.Tensor, @@ -73,11 +83,13 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe intermediate_size_moe = config.intermediate_size - self.router = ReplicatedLinear(config.hidden_size, - config.num_local_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.router") + self.router = ReplicatedLinear( + config.hidden_size, + config.num_local_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.router", + ) self.shared_expert = LlamaMLP( hidden_size=config.hidden_size, @@ -123,26 +135,28 @@ def forward(self, hidden_states): experts_out = experts_out[:num_tokens] elif self.tp_size > 1: experts_out = self.experts.maybe_all_reduce_tensor_model_parallel( - experts_out) + experts_out + ) return experts_out class 
Llama4Attention(nn.Module): - - def __init__(self, - config: Llama4TextConfig, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - bias_o_proj: bool = False, - cache_config: Optional[CacheConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Llama4TextConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + bias_o_proj: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.layer_idx = extract_layer_index(prefix) self.hidden_size = hidden_size @@ -167,20 +181,23 @@ def __init__(self, self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.attn_temperature_tuning = self.nope and \ - config.attn_temperature_tuning + self.attn_temperature_tuning = self.nope and config.attn_temperature_tuning self.floor_scale = getattr(config, "floor_scale", 8192.0) self.attn_scale = getattr(config, "attn_scale", 0.1) self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.n_rep = self.num_heads // self.num_kv_heads - self.qk_norm = RMSNorm( - hidden_size=self.head_dim, - eps=config.rms_norm_eps, - has_weight=False, - dtype=torch.float32, - ) if self.use_qk_norm else None + self.qk_norm = ( + RMSNorm( + hidden_size=self.head_dim, + eps=config.rms_norm_eps, + has_weight=False, + dtype=torch.float32, + ) + if self.use_qk_norm + else None + ) self.qkv_proj = QKVParallelLinear( hidden_size=hidden_size, head_size=self.head_dim, @@ -203,18 +220,21 @@ def __init__(self, if is_gguf and config.model_type == "llama": is_neox_style = False - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=int(rope_theta), - rope_scaling=rope_scaling if rope_scaling != "default" else None, - is_neox_style=is_neox_style, - ) if not self.nope else None + self.rotary_emb = ( + get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=int(rope_theta), + rope_scaling=rope_scaling if rope_scaling != "default" else None, + is_neox_style=is_neox_style, + ) + if not self.nope + else None + ) use_chunked_local_attn = not self.nope and config.attention_chunk_size - attn_cls = (ChunkedLocalAttention - if use_chunked_local_attn else Attention) + attn_cls = ChunkedLocalAttention if use_chunked_local_attn else Attention self.attn = attn_cls( self.num_heads, self.head_dim, @@ -223,9 +243,12 @@ def __init__(self, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", - **({ - "attention_chunk_size": config.attention_chunk_size - } if use_chunked_local_attn else {})) + **( + {"attention_chunk_size": config.attention_chunk_size} + if use_chunked_local_attn + else {} + ), + ) def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: floor = torch.floor((positions + 1.0) / self.floor_scale) @@ -270,11 +293,12 @@ def forward( class Llama4DecoderLayer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str = "", - config: Optional[Llama4TextConfig] = None) -> None: + def __init__( + self, + 
vllm_config: VllmConfig, + prefix: str = "", + config: Optional[Llama4TextConfig] = None, + ) -> None: super().__init__() config = config or vllm_config.model_config.hf_config @@ -302,8 +326,10 @@ def __init__(self, cache_config=cache_config, prefix=f"{prefix}.self_attn", ) - is_moe_layer = config.interleave_moe_layer_step > 0 and ( - self.layer_idx + 1) % config.interleave_moe_layer_step == 0 + is_moe_layer = ( + config.interleave_moe_layer_step > 0 + and (self.layer_idx + 1) % config.interleave_moe_layer_step == 0 + ) if is_moe_layer: self.feed_forward = Llama4MoE( vllm_config=vllm_config, @@ -318,10 +344,10 @@ def __init__(self, bias=False, prefix=f"{prefix}.feed_forward", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -334,30 +360,26 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual @support_torch_compile class Llama4Model(LlamaModel): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer, + ): self.num_experts = vllm_config.model_config.hf_config.num_local_experts - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) + super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) def load_moe_expert_weights( self, @@ -408,9 +430,7 @@ def load_moe_expert_weights( # Iterate over all the expert parameters and load the weights if we find # a match in weight name. - for (param_name, weight_name, expert_id, - shard_id) in expert_params_mapping: - + for param_name, weight_name, expert_id, shard_id in expert_params_mapping: # Get a view of the loaded_weight to avoid modifying the original # one across iterations. new_loaded_weight = loaded_weight @@ -419,7 +439,7 @@ def load_moe_expert_weights( # the expert index from the expected weight name. if fused: # The string between e_str and proj_str is the expert index. - e_str, _, proj_str, _ = weight_name.split('.') + e_str, _, proj_str, _ = weight_name.split(".") weight_name = f"{e_str}.{proj_str}" param_name = f"{param_name}weight" @@ -436,8 +456,9 @@ def load_moe_expert_weights( continue # Skip if the current weight is for the bias. 
- if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[full_param_name] @@ -456,13 +477,14 @@ def load_moe_expert_weights( # starting expert index for the current EP rank and extract the # corresponding expert weights. layer_idx = extract_layer_index(name) - expert_map = self.layers[ - layer_idx].feed_forward.experts.expert_map + expert_map = self.layers[layer_idx].feed_forward.experts.expert_map if expert_map is not None: - local_expert_indices = (expert_map != -1) \ - .nonzero() \ - .flatten() \ - .to(new_loaded_weight.device) + local_expert_indices = ( + (expert_map != -1) + .nonzero() + .flatten() + .to(new_loaded_weight.device) + ) new_loaded_weight = new_loaded_weight[local_expert_indices] expert_id = local_expert_indices[0].item() else: @@ -471,19 +493,20 @@ def load_moe_expert_weights( # Load the weight into the module parameter with corresponding # shard id and expert id. - weight_loader(param, - new_loaded_weight, - full_param_name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + new_loaded_weight, + full_param_name, + shard_id=shard_id, + expert_id=expert_id, + ) loaded_params.add(full_param_name) expert_param_loaded = True return expert_param_loaded - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Name mapping from the parameter name to the shard name and # corresponding shard id. stacked_params_mapping = [ @@ -503,14 +526,16 @@ def load_weights(self, weights: Iterable[tuple[str, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.num_experts) + num_experts=self.num_experts, + ) # Expert parameter mapping for the case where the expert weights are # fused into a single weight tensor. expert_params_mapping_fused = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_up_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="gate_up_proj", - num_experts=1) + num_experts=1, + ) # All the module parameters. params_dict = dict(self.named_parameters()) # The module parameters that have been loaded. @@ -518,7 +543,6 @@ def load_weights(self, weights: Iterable[tuple[str, # Iterate over all the weights and load them into module parameters. for name, loaded_weight in weights: - # If the name contains "experts.gate_up_proj" or "experts.down_proj" # without the expert indices, it means the expert weights are fused # into a single weight tensor across all experts. @@ -529,13 +553,14 @@ def load_weights(self, weights: Iterable[tuple[str, # If kv cache quantization scales exist and the weight name # corresponds to one of the kv cache quantization scales, load # them. 
- if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -552,8 +577,9 @@ def load_weights(self, weights: Iterable[tuple[str, # For ModelOpt checkpoints, we need to rename the self_attn # weight/weight_scale names except for kv cache scales. - if not (name.endswith( - (".k_scale", ".v_scale")) and "self_attn" in name): + if not ( + name.endswith((".k_scale", ".v_scale")) and "self_attn" in name + ): name = name.replace(weight_name, param_name) # Skip if the current weight corresponds to a parameter that @@ -572,8 +598,7 @@ def load_weights(self, weights: Iterable[tuple[str, # Load the weight into the module parameter with corresponding # shard id and exit the for loop and the else block. param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) @@ -587,12 +612,14 @@ def load_weights(self, weights: Iterable[tuple[str, else: # First, try to load MoE weights using load_moe_expert_weights. # If successful, move on to next loaded weight. - if self.load_moe_expert_weights(name, - loaded_weight, - params_dict, - loaded_params, - expert_params_mapping, - fused=fused_experts_params): + if self.load_moe_expert_weights( + name, + loaded_weight, + params_dict, + loaded_params, + expert_params_mapping, + fused=fused_experts_params, + ): continue # Skip if the current weight corresponds to a parameter that @@ -604,37 +631,40 @@ def load_weights(self, weights: Iterable[tuple[str, # per-expert patterns, i.e. one weight scale tensor for all # experts. scale_names = [ - "w13_input_scale", "w13_weight_scale", "w2_input_scale", - "w2_weight_scale" + "w13_input_scale", + "w13_weight_scale", + "w2_input_scale", + "w2_weight_scale", ] - if ("experts." in name and any(scale_name in name - for scale_name in scale_names)): - + if "experts." in name and any( + scale_name in name for scale_name in scale_names + ): param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) # If weight loader supports special moe loading, use it to # avoid expensive runtime reflection - if getattr(weight_loader, 'supports_moe_loading', False): + if getattr(weight_loader, "supports_moe_loading", False): # Map the weight name to the corresponding shard id. shard_id = "w2" if "w2_" in name else "w1" # Transpose if weight scales are FP8 block scales with # three dimensions: # [num_experts, hidden_in, hidden_out]. 
- if name.endswith("weight_scale") \ - and loaded_weight.dtype == torch.float8_e4m3fn \ - and loaded_weight.ndim == 3: + if ( + name.endswith("weight_scale") + and loaded_weight.dtype == torch.float8_e4m3fn + and loaded_weight.ndim == 3 + ): loaded_weight = loaded_weight.transpose(-1, -2) # Load the weight into the module parameter with # corresponding shard id and expert id. - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=0) + weight_loader( + param, loaded_weight, name, shard_id=shard_id, expert_id=0 + ) else: # Regular weight loader (handles both @@ -646,8 +676,7 @@ def load_weights(self, weights: Iterable[tuple[str, # Handle normal (non-stacked, non-MoE) weights. param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -656,7 +685,6 @@ def load_weights(self, weights: Iterable[tuple[str, class Llama4ForCausalLM(LlamaForCausalLM): - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -667,30 +695,29 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): gen_config = vllm_config.model_config.try_get_generation_config() gen_config.update(vllm_config.model_config.override_generation_config) # enable temperature tuning by default when max_model_len > 32K - default_attn_temperature_tuning = \ - vllm_config.model_config.max_model_len > 32768 - vllm_config.model_config.hf_config.attn_temperature_tuning \ - = gen_config.get( - "attn_temperature_tuning", default_attn_temperature_tuning) - - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=Llama4DecoderLayer) - - def _init_model(self, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer): - return Llama4Model(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + default_attn_temperature_tuning = vllm_config.model_config.max_model_len > 32768 + vllm_config.model_config.hf_config.attn_temperature_tuning = gen_config.get( + "attn_temperature_tuning", default_attn_temperature_tuning + ) + + super().__init__( + vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer + ) + + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer, + ): + return Llama4Model( + vllm_config=vllm_config, prefix=prefix, layer_type=layer_type + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) weights = [ self.permute_qk_weight_for_rotary(name, loaded_weight) @@ -703,10 +730,8 @@ def permute_qk_weight_for_rotary( name: str, loaded_weight: torch.Tensor, ) -> tuple[str, torch.Tensor]: - # Helper function to permute the weight's channels def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool): - # Calculate the expected shape of the weight. # Do not rely on w's shape, as it may be in another layout. 
attn_in = self.config.head_dim * n_heads @@ -719,28 +744,39 @@ def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool): # If the weight is a weight scale, we need to divide attn_out by # block size, which is currently 16. - elif w.dtype == torch.float8_e4m3fn and is_weight_scale \ - and w.shape[1] * 16 == attn_out: + elif ( + w.dtype == torch.float8_e4m3fn + and is_weight_scale + and w.shape[1] * 16 == attn_out + ): attn_out = attn_out // 16 - return w.view(n_heads, attn_in // n_heads // 2, 2, - attn_out).transpose(1, 2).reshape(attn_in, attn_out) + return ( + w.view(n_heads, attn_in // n_heads // 2, 2, attn_out) + .transpose(1, 2) + .reshape(attn_in, attn_out) + ) modules = name.split(".") # Permute Q/K weights and weight block scales for rotary embedding is_weight = modules[-1] == "weight" - is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and - loaded_weight.dtype == torch.float8_e4m3fn) + is_nvfp4_weight_scale = ( + modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn + ) if is_weight or is_nvfp4_weight_scale: - if ("wk" in modules or "k_proj" in modules): - loaded_weight = permute(loaded_weight, - self.config.num_key_value_heads, - is_nvfp4_weight_scale) - elif ("wq" in modules or "q_proj" in modules): - loaded_weight = permute(loaded_weight, - self.config.num_attention_heads, - is_nvfp4_weight_scale) + if "wk" in modules or "k_proj" in modules: + loaded_weight = permute( + loaded_weight, + self.config.num_key_value_heads, + is_nvfp4_weight_scale, + ) + elif "wq" in modules or "q_proj" in modules: + loaded_weight = permute( + loaded_weight, + self.config.num_attention_heads, + is_nvfp4_weight_scale, + ) return name, loaded_weight diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 572eca344e0a..039022ef4527 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -30,11 +30,9 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.torchao import TorchAOConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, - Llama4ForCausalLM) +from vllm.model_executor.models.llama4 import Llama4DecoderLayer, Llama4ForCausalLM from vllm.model_executor.models.utils import extract_layer_index from .interfaces import SupportsMultiModal @@ -45,7 +43,6 @@ @support_torch_compile class LlamaModel(nn.Module): - def __init__( self, *, @@ -55,8 +52,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.config = ( - vllm_config.speculative_config.draft_model_config.hf_config) + self.config = vllm_config.speculative_config.draft_model_config.hf_config self.validate_and_update_config(start_layer_id, quant_config) self.vocab_size = self.config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -65,18 +61,20 @@ def __init__( prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList([ - Llama4DecoderLayer( - vllm_config=vllm_config, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - config=self.config, - ) for i in range(self.config.num_hidden_layers) - ]) - self.fc = 
torch.nn.Linear(self.config.hidden_size * 2, - self.config.hidden_size, - bias=False) - self.norm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.layers = nn.ModuleList( + [ + Llama4DecoderLayer( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, + ) + for i in range(self.config.num_hidden_layers) + ] + ) + self.fc = torch.nn.Linear( + self.config.hidden_size * 2, self.config.hidden_size, bias=False + ) + self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -90,8 +88,7 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: if inputs_embeds is None: inputs_embeds = self.get_input_embeddings(input_ids) - hidden_states = self.fc( - torch.cat((inputs_embeds, hidden_states), dim=-1)) + hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) residual = None for layer in self.layers: hidden_states, residual = layer( @@ -102,8 +99,7 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -126,69 +122,66 @@ def load_weights(self, weights: Iterable[tuple[str, break else: # if PP disabled then draft will share embed with target - if get_pp_group().world_size == 1 and \ - "embed_tokens." in name: + if get_pp_group().world_size == 1 and "embed_tokens." in name: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) for name in params_dict: # if PP disabled then draft will share embed with target - if get_pp_group().world_size == 1 and \ - "embed_tokens." in name: + if get_pp_group().world_size == 1 and "embed_tokens." in name: continue assert name in loaded_params, f"{name} is not loaded!" 
return loaded_params def validate_and_update_config( - self, - start_layer_id: int, - quant_config: Optional[QuantizationConfig] = None) -> None: + self, start_layer_id: int, quant_config: Optional[QuantizationConfig] = None + ) -> None: # yoco and moe is not supported by draft model yet assert self.config.yoco_global_kv_layer is None assert self.config.yoco_local_kv_layer is None assert len(self.config.moe_layers) == 0 # draft model layer index is increased by start_layer_id, # so we need to pad relevant configs accordingly - self.config.no_rope_layers = [ - 0 - ] * start_layer_id + self.config.no_rope_layers + self.config.no_rope_layers = [0] * start_layer_id + self.config.no_rope_layers # currently only TorchAO quantization is supported if isinstance(quant_config, TorchAOConfig): def pad_layer_name(layer: str) -> str: layer_index = extract_layer_index(layer) - return layer.replace(str(layer_index), - str(layer_index + start_layer_id)) + return layer.replace( + str(layer_index), str(layer_index + start_layer_id) + ) - quant_config.torchao_config.module_fqn_to_config = { + torchao_config = quant_config.torchao_config + torchao_config.module_fqn_to_config = { pad_layer_name(layer): quantization - for layer, quantization in - quant_config.torchao_config.module_fqn_to_config.items() + for layer, quantization in torchao_config.module_fqn_to_config.items() } class EagleLlama4ForCausalLM(Llama4ForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) - self.config = ( - vllm_config.speculative_config.draft_model_config.hf_config) + self.config = vllm_config.speculative_config.draft_model_config.hf_config target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) + vllm_config.parallel_config + ) # draft model quantization config may differ from target model quant_config = VllmConfig.get_quantization_config( - vllm_config.speculative_config.draft_model_config, - vllm_config.load_config) - self.model = LlamaModel(vllm_config=vllm_config, - prefix="model", - start_layer_id=target_layer_num, - quant_config=quant_config) + vllm_config.speculative_config.draft_model_config, vllm_config.load_config + ) + self.model = LlamaModel( + vllm_config=vllm_config, + prefix="model", + start_layer_id=target_layer_num, + quant_config=quant_config, + ) logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.vocab_size, - scale=logit_scale) + self.logits_processor = LogitsProcessor( + self.config.vocab_size, scale=logit_scale + ) def get_language_model(self) -> torch.nn.Module: return self.model @@ -204,13 +197,10 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states, inputs_embeds) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> None: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: def transform(inputs): name, loaded_weight = inputs - name, weight = self.permute_qk_weight_for_rotary( - name, loaded_weight) + name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight) if "lm_head" not in name: name = "model." 
+ name return name, weight diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index d7d6b1745fc8..5df158818c9f 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -13,11 +13,9 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama import (LlamaDecoderLayer, - LlamaForCausalLM) +from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM from .utils import AutoWeightsLoader, maybe_prefix @@ -25,7 +23,6 @@ class LlamaDecoderLayer(LlamaDecoderLayer): - def __init__( self, vllm_config: VllmConfig, @@ -44,7 +41,6 @@ def __init__( @support_torch_compile class LlamaModel(nn.Module): - def __init__( self, *, @@ -53,8 +49,7 @@ def __init__( start_layer_id: int = 0, ) -> None: super().__init__() - self.config = vllm_config. \ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -63,17 +58,20 @@ def __init__( prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList([ - LlamaDecoderLayer( - vllm_config, - i == 0, - prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), - config=self.config, - ) for i in range(self.config.num_hidden_layers) - ]) - self.fc = torch.nn.Linear(self.config.hidden_size * 2, - self.config.hidden_size, - bias=False) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + vllm_config, + i == 0, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + config=self.config, + ) + for i in range(self.config.num_hidden_layers) + ] + ) + self.fc = torch.nn.Linear( + self.config.hidden_size * 2, self.config.hidden_size, bias=False + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -85,8 +83,7 @@ def forward( hidden_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: input_embeds = self.embed_tokens(input_ids) - hidden_states = self.fc( - torch.cat((input_embeds, hidden_states), dim=-1)) + hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1)) residual = None for layer in self.layers: hidden_states, residual = layer( @@ -97,8 +94,7 @@ def forward( hidden_states = hidden_states + residual return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -119,40 +115,37 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight, shard_id) break else: - # if PP disabled then draft will share embed with target - if get_pp_group().world_size == 1 and \ - "embed_tokens." in name: + if get_pp_group().world_size == 1 and "embed_tokens." 
in name: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class EagleLlamaForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) - self.config = vllm_config. \ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config # Ensure draft_vocab_size is set # default to the base vocab size when absent if getattr(self.config, "draft_vocab_size", None) is None: base_vocab_size = getattr(self.config, "vocab_size", None) self.config.draft_vocab_size = base_vocab_size target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) - self.model = LlamaModel(vllm_config=vllm_config, - prefix="model", - start_layer_id=target_layer_num) + vllm_config.parallel_config + ) + self.model = LlamaModel( + vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num + ) logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.config.vocab_size, - scale=logit_scale) + self.logits_processor = LogitsProcessor( + self.config.vocab_size, scale=logit_scale + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -171,7 +164,6 @@ def forward( return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - def transform(inputs): name, loaded_weight = inputs if "lm_head" not in name: diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 3fb6f2f8d5ec..155a4ecea28f 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -13,13 +13,15 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama import (LlamaDecoderLayer, - LlamaForCausalLM) +from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM +from vllm.multimodal.inputs import NestedTensors from .utils import AutoWeightsLoader, maybe_prefix @@ -27,11 +29,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str = "", - config: Optional[LlamaConfig] = None) -> None: + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + config: Optional[LlamaConfig] = None, + ) -> None: super().__init__(vllm_config, prefix=prefix, config=config) config = config or vllm_config.model_config.hf_config @@ -55,26 +58,27 @@ def __init__(self, else: self._residual_norm = self._norm_after_residual - def get_quant_config( - self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]: + def 
get_quant_config(self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]: """Use drafter's quantization config instead of verifier's.""" draft_model_config = vllm_config.speculative_config.draft_model_config draft_load_config = vllm_config.load_config - return VllmConfig.get_quantization_config( - draft_model_config, - draft_load_config) if draft_model_config else None + return ( + VllmConfig.get_quantization_config(draft_model_config, draft_load_config) + if draft_model_config + else None + ) def _norm_before_residual( - self, - hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: hidden_states = self.hidden_norm(hidden_states) residual = hidden_states return hidden_states, residual def _norm_after_residual( - self, - hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: residual = hidden_states hidden_states = self.hidden_norm(hidden_states) return hidden_states, residual @@ -86,11 +90,9 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - embeds = self.input_layernorm(embeds) - hidden_states, residual = self._residual_norm( - hidden_states=hidden_states) + hidden_states, residual = self._residual_norm(hidden_states=hidden_states) hidden_states = torch.cat([embeds, hidden_states], dim=-1) # Self Attention @@ -99,8 +101,7 @@ def forward( hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) # Fully Connected hidden_states = self.mlp(hidden_states) @@ -109,7 +110,6 @@ def forward( class LlamaModel(nn.Module): - def __init__( self, *, @@ -118,8 +118,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - self.config = vllm_config. 
\ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config self.vocab_size = self.config.vocab_size current_vllm_config = get_current_vllm_config() @@ -130,21 +129,23 @@ def __init__( prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList([ - LlamaDecoderLayer( - current_vllm_config, - prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), - config=self.config, - ) - ]) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + current_vllm_config, + prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), + config=self.config, + ) + ] + ) if hasattr(self.config, "target_hidden_size"): - self.fc = torch.nn.Linear(self.config.target_hidden_size * 3, - self.config.hidden_size, - bias=False) + self.fc = torch.nn.Linear( + self.config.target_hidden_size * 3, self.config.hidden_size, bias=False + ) else: - self.fc = torch.nn.Linear(self.config.hidden_size * 3, - self.config.hidden_size, - bias=False) + self.fc = torch.nn.Linear( + self.config.hidden_size * 3, self.config.hidden_size, bias=False + ) self.norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, @@ -175,8 +176,7 @@ def forward( hidden_states, hidden_prenorm = self.norm(hidden_states, residual) return hidden_states, hidden_prenorm - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -188,8 +188,8 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if 'midlayer.' in name: - name = name.replace('midlayer.', 'layers.0.') + if "midlayer." in name: + name = name.replace("midlayer.", "layers.0.") for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -200,33 +200,31 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Eagle3LlamaForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) - self.config = vllm_config. 
\ - speculative_config.draft_model_config.hf_config + self.config = vllm_config.speculative_config.draft_model_config.hf_config # Ensure draft_vocab_size is set # default to the base vocab size when absent if getattr(self.config, "draft_vocab_size", None) is None: base_vocab_size = getattr(self.config, "vocab_size", None) self.config.draft_vocab_size = base_vocab_size target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) + vllm_config.parallel_config + ) # Store target layer count in draft config for # proper layer_types indexing in draft models self.config.target_layer_count = target_layer_num - self.model = LlamaModel(vllm_config=vllm_config, - prefix="model", - start_layer_id=target_layer_num) + self.model = LlamaModel( + vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num + ) logit_scale = getattr(self.config, "logit_scale", 1.0) self.lm_head = ParallelLMHead( @@ -234,15 +232,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config.hidden_size, org_num_embeddings=self.config.draft_vocab_size, padding_size=(DEFAULT_VOCAB_PADDING_SIZE), - prefix=maybe_prefix(prefix, "lm_head")) - self.logits_processor = LogitsProcessor(self.config.draft_vocab_size, - scale=logit_scale) + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.config.draft_vocab_size, scale=logit_scale + ) self.draft_id_to_target_id = nn.Parameter( torch.zeros(self.config.draft_vocab_size, dtype=torch.long), requires_grad=False, ) - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + is_multimodal: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def forward( @@ -260,17 +265,21 @@ def compute_logits( ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states) if self.draft_id_to_target_id is None: - assert logits.shape[1] == self.config.vocab_size, \ - "Expected logits to have shape " \ + assert logits.shape[1] == self.config.vocab_size, ( + "Expected logits to have shape " f"(*, {self.config.vocab_size}), but got {logits.shape}" + ) return logits base = torch.arange(self.config.draft_vocab_size, device=logits.device) targets = base + self.draft_id_to_target_id - logits_new = logits.new_full(( - logits.shape[0], - self.config.vocab_size, - ), float('-inf')) + logits_new = logits.new_full( + ( + logits.shape[0], + self.config.vocab_size, + ), + float("-inf"), + ) logits_new[:, targets] = logits return logits_new diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 725468ddef86..3d46e22a0d21 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -3,35 +3,49 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, - Union) +from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union import torch import torch.nn as nn -from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, - PixtralVisionConfig, PretrainedConfig, - SiglipVisionConfig) +from transformers import ( + BatchFeature, + CLIPVisionConfig, + LlavaConfig, + PixtralVisionConfig, + PretrainedConfig, + SiglipVisionConfig, +) from transformers.models.llava import LlavaProcessor from transformers.models.pixtral 
import PixtralProcessor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, - MultiModalUUIDDict) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - InputProcessingContext, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -40,8 +54,12 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_num_selected_vision_tokens, get_vision_encoder_info @@ -52,10 +70,11 @@ class LlavaImagePixelInputs(TensorSchema): - c: Number of channels (3) - h: Height - w: Width - + Note that `height` or `width` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] @@ -67,14 +86,16 @@ class PixtralHFImagePixelInputs(TensorSchema): - c: Number of channels - h: Height - w: Width - + Note that `height` or `width` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. 
""" + type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral" pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"})] + TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}), + ] class LlavaImageEmbeddingInputs(TensorSchema): @@ -84,36 +105,43 @@ class LlavaImageEmbeddingInputs(TensorSchema): - ifs: Image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs, - LlavaImageEmbeddingInputs] +LlavaImageInputs = Union[ + LlavaImagePixelInputs, PixtralHFImagePixelInputs, LlavaImageEmbeddingInputs +] class LlavaMultiModalProjector(nn.Module): - - def __init__(self, - vision_hidden_size: int, - text_hidden_size: int, - projector_hidden_act: str, - multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.linear_1 = ColumnParallelLinear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_1") + self.linear_1 = ColumnParallelLinear( + vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = RowParallelLinear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_2") + self.linear_2 = RowParallelLinear( + text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.linear_1(image_features) @@ -134,7 +162,6 @@ class LlavaLikeProcessor(Protocol): class BaseLlavaProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> LlavaLikeConfig: return self.ctx.get_hf_config(LlavaConfig) @@ -183,7 +210,6 @@ def get_max_image_tokens(self) -> int: class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -200,22 +226,21 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class LlavaProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_processor(self, **kwargs: object): hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs) # In case patch_size is omitted from `processor_config.json` @@ -227,7 +252,6 @@ def get_hf_processor(self, **kwargs: object): class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): - # Copied from 
BaseMultiModalProcessor @abstractmethod def _get_mm_fields_config( @@ -248,7 +272,8 @@ def _get_prompt_updates( def get_replacement(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -270,9 +295,7 @@ def get_replacement(item_idx: int): ] -class LlavaMultiModalProcessor( - BaseLlavaMultiModalProcessor[LlavaProcessingInfo]): - +class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]): def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -285,14 +308,11 @@ def _get_mm_fields_config( class PixtralHFProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(PixtralProcessor, **kwargs) -class PixtralHFMultiModalProcessor( - BaseMultiModalProcessor[PixtralHFProcessingInfo]): - +class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -372,7 +392,8 @@ def get_replacement(item_idx: int): def _build_llava_or_pixtral_hf_info( - ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo: + ctx: InputProcessingContext, +) -> BaseLlavaProcessingInfo: hf_config = ctx.get_hf_config(LlavaConfig) if isinstance(hf_config.vision_config, PixtralVisionConfig): @@ -407,7 +428,7 @@ def _build_llava_or_pixtral_hf_processor( def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: """Determine the number of hidden layers to initialize up to in the visual encoder. - + Args: hf_config: Model config with vision feature layer(s). """ @@ -418,10 +439,10 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: return _get_layer_index(feature_layers, num_hidden_layers) # If we have multiple feature layers, initialize up to the deepest one elif isinstance(feature_layers, (list, tuple)): - return max( - _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) - raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" - " is not supported") + return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError( + f"vision_layer_feature type: {type(feature_layers)} is not supported" + ) def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: @@ -479,14 +500,17 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor, - info=_build_llava_or_pixtral_hf_info, - dummy_inputs=LlavaDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + _build_llava_or_pixtral_hf_processor, + info=_build_llava_or_pixtral_hf_info, + dummy_inputs=LlavaDummyInputsBuilder, +) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } hf_to_vllm_mapper = WeightsMapper( @@ -496,7 +520,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -517,11 +542,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # 
NOTE: These are special cases for Pixtral-12B in the HF-format # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa - if (config.text_config.architectures is None - and config.text_config.model_type == "mistral"): + if ( + config.text_config.architectures is None + and config.text_config.model_type == "mistral" + ): config.text_config.architectures = ["MistralForCausalLM"] - if (config.projector_hidden_act is None - and config.vision_config.hidden_act == "gelu"): + if ( + config.projector_hidden_act is None + and config.vision_config.hidden_act == "gelu" + ): config.projector_hidden_act = "gelu" # TODO: Optionally initializes this for supporting embeddings. @@ -530,14 +559,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act, multimodal_projector_bias=config.multimodal_projector_bias, quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) else: self.vision_tower = None self.multi_modal_projector = None @@ -549,10 +580,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[LlavaImageInputs]: + self, **kwargs: object + ) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -560,45 +593,33 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - if self.config.vision_config.model_type == "pixtral": return PixtralHFImagePixelInputs( type="pixel_values_pixtral", - pixel_values=flatten_bn(pixel_values), + pixel_values=pixel_values, ) expected_h = expected_w = self.config.vision_config.image_size return LlavaImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values, concat=True), - resolve_bindings={ - "h": expected_h, - "w": expected_w - }, + pixel_values=pixel_values, + resolve_bindings={"h": expected_h, "w": expected_w}, ) if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") - if self.config.vision_config.model_type == "pixtral": raise ValueError("Pixtral-HF does not support image_embeds.") return LlavaImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel, - PixtralHFVisionModel], + vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel], pixel_values: Union[torch.Tensor, list[torch.Tensor]], ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since @@ -631,9 +652,7 @@ def _process_image_input( if isinstance(image_features, torch.Tensor): return self.multi_modal_projector(image_features) - feature_sizes = [ - image_feature.shape[0] for image_feature in image_features - ] + feature_sizes = [image_feature.shape[0] for image_feature in image_features] image_embeds = self.multi_modal_projector(torch.cat(image_features)) image_embeds = torch.split(image_embeds, feature_sizes) @@ -642,8 +661,7 @@ def _process_image_input( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -698,10 +716,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -711,8 +728,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.vision_tower is None and self.multi_modal_projector is None: skip_prefixes.extend(["vision_tower.", "multi_modal_projector."]) @@ -722,7 +738,6 @@ def load_weights(self, weights: Iterable[tuple[str, class MantisProcessingInfo(LlavaProcessingInfo): - def get_hf_processor(self, **kwargs: object): hf_config = self.get_hf_config() vision_info = self.get_vision_encoder_info() @@ -737,7 +752,6 @@ def get_hf_processor(self, **kwargs: object): class MantisMultiModalProcessor(LlavaMultiModalProcessor): - def apply( self, prompt: Union[str, list[int]], @@ -755,11 +769,13 @@ def apply( image_height=-1, ) - result = super().apply(prompt, - mm_data, - hf_processor_mm_kwargs, - tokenization_kwargs, - mm_uuids=mm_uuids) + result = super().apply( + prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() @@ -769,19 +785,24 @@ def apply( # We reimplement the functionality of MLlavaProcessor from # https://github.com/TIGER-AI-Lab/Mantis.git def get_replacement_mantis(item_idx: int): - return "".join([ - f"(image {item_idx+1}: <Image>", # 7 tokens - "<image>" * num_image_tokens, - "</Image>)", # 3 tokens - ]) - - mantis_mm_repls = self._bind_and_group_updates([ - PromptReplacement( - modality="image", - target=[image_token_id] * 
num_image_tokens, - replacement=get_replacement_mantis, + return "".join( + [ + f"(image {item_idx + 1}: <Image>", # 7 tokens + "<image>" * num_image_tokens, + "</Image>)", # 3 tokens + ] ) - ], mm_item_counts) + + mantis_mm_repls = self._bind_and_group_updates( + [ + PromptReplacement( + modality="image", + target=[image_token_id] * num_image_tokens, + replacement=get_replacement_mantis, + ) + ], + mm_item_counts, + ) prompt_ids, _ = self._apply_prompt_updates( result["prompt_token_ids"], @@ -812,8 +833,10 @@ def get_replacement_mantis(item_idx: int): # To use this model, please use # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` -@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor, - info=MantisProcessingInfo, - dummy_inputs=LlavaDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + MantisMultiModalProcessor, + info=MantisProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder, +) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 70fd0b2e5efb..caedace7cab1 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -3,14 +3,15 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping -from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, - Union) +from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union import torch import torch.nn as nn from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor from transformers.models.llava_next.modeling_llava_next import ( - get_anyres_image_grid_shape, unpad_image) + get_anyres_image_grid_shape, + unpad_image, +) from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY @@ -21,12 +22,21 @@ from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, - LlavaDummyInputsBuilder, LlavaLikeConfig, - LlavaMultiModalProjector, init_vision_tower_for_llava) +from .llava import ( + BaseLlavaMultiModalProcessor, + BaseLlavaProcessingInfo, + LlavaDummyInputsBuilder, + LlavaLikeConfig, + LlavaMultiModalProjector, + init_vision_tower_for_llava, +) from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_num_selected_vision_tokens @@ -38,14 +48,16 @@ class LlavaNextImagePixelInputs(TensorSchema): - c: Number of channels (3) - h: Height - w: Width - + Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})] + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), + ] image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] # This should be in `(height, width)` format. 
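The LLaVA-family input classes above all follow the same TensorSchema pattern: each field is an Annotated tensor with a symbolic TensorShape, per-image dimensions such as the patch count are declared in dynamic_dims, and fixed dimensions are bound at construction time via resolve_bindings (as the parsing code elsewhere in this patch does). A minimal sketch of that pattern, assuming vLLM's TensorSchema accepts the keyword-plus-resolve_bindings call shown in this diff; the class name and the concrete sizes (336-pixel patches, two images) are illustrative only:

    from typing import Annotated, Literal, Optional, Union

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class ExampleImagePixelInputs(TensorSchema):
        """Illustrative schema mirroring LlavaNextImagePixelInputs above."""

        type: Literal["pixel_values"] = "pixel_values"

        # "np" (patches per image) can differ between images, so it is a
        # dynamic dimension and the data may arrive as a list of tensors.
        pixel_values: Annotated[
            Union[torch.Tensor, list[torch.Tensor]],
            TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}),
        ]

        # (height, width) of the original image, one row per image.
        image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]


    inputs = ExampleImagePixelInputs(
        pixel_values=[torch.randn(5, 3, 336, 336), torch.randn(3, 3, 336, 336)],
        image_sizes=torch.tensor([[672, 1008], [672, 672]]),
        resolve_bindings={"h": 336, "w": 336},
    )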
@@ -58,12 +70,12 @@ class LlavaNextImageEmbeddingInputs(TensorSchema): - ifs: Image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, - LlavaNextImageEmbeddingInputs] +LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, LlavaNextImageEmbeddingInputs] class LlavaNextLikeConfig(LlavaLikeConfig, Protocol): @@ -71,7 +83,6 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol): class LlavaNextProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_config(self) -> LlavaNextLikeConfig: return self.ctx.get_hf_config(LlavaNextConfig) @@ -141,12 +152,14 @@ def _get_num_unpadded_features( if aspect_ratio > current_aspect_ratio: new_height = int( - round(original_height * (current_width / original_width), 7)) + round(original_height * (current_width / original_width), 7) + ) padding = (current_height - new_height) // 2 current_height = current_height - (2 * padding) else: new_width = int( - round(original_width * (current_height / original_height), 7)) + round(original_width * (current_height / original_height), 7) + ) padding = (current_width - new_width) // 2 current_width = current_width - (2 * padding) @@ -159,13 +172,13 @@ def get_image_size_with_most_features(self) -> ImageSize: hf_config = self.get_hf_config() largest_feature_size, largest_feature_pinpoint = 0, None - for (height, width) in hf_config.image_grid_pinpoints: - feat_size = self.get_num_image_tokens(image_width=width, - image_height=height) + for height, width in hf_config.image_grid_pinpoints: + feat_size = self.get_num_image_tokens( + image_width=width, image_height=height + ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") @@ -177,7 +190,6 @@ def get_image_size_with_most_features(self) -> ImageSize: class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]): - # Copied from BaseMultiModalProcessor @abstractmethod def _get_mm_fields_config( @@ -189,8 +201,8 @@ def _get_mm_fields_config( class LlavaNextMultiModalProcessor( - BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]): - + BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo] +): def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -203,11 +215,13 @@ def _get_mm_fields_config( ) -@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor, - info=LlavaNextProcessingInfo, - dummy_inputs=LlavaDummyInputsBuilder) -class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextMultiModalProcessor, + info=LlavaNextProcessingInfo, + dummy_inputs=LlavaDummyInputsBuilder, +) +class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -217,7 +231,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, "model.multi_modal_projector.": "multi_modal_projector.", "model.image_newline": "image_newline", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -240,12 +255,14 @@ def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # Used for multimodal granite models to control encoder outputs elif isinstance(vision_feature_layer, (list, tuple)): vision_hidden_size = config.vision_config.hidden_size * len( - vision_feature_layer) + vision_feature_layer + ) self.select_layers = vision_feature_layer else: raise TypeError( f"vision_layer_feature type: {type(vision_feature_layer)}" - " is not supported") + " is not supported" + ) self.config = config self.multimodal_config = multimodal_config @@ -255,14 +272,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) - self.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size)) + prefix=maybe_prefix(prefix, "vision_tower"), + ) + self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size)) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=vision_hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act, - multimodal_projector_bias=config.multimodal_projector_bias) + multimodal_projector_bias=config.multimodal_projector_bias, + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -271,10 +289,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[LlavaNextImageInputs]: + self, **kwargs: object + ) -> Optional[LlavaNextImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -283,32 +303,21 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_sizes, (torch.Tensor, list)): - raise ValueError("Incorrect type of image sizes. " - f"Got type: {type(image_sizes)}") - expected_h = expected_w = self.config.vision_config.image_size return LlavaNextImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values), - image_sizes=flatten_bn(image_sizes, concat=True), + pixel_values=pixel_values, + image_sizes=image_sizes, resolve_bindings={ "h": expected_h, "w": expected_w, - }) + }, + ) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeds. 
" - f"Got type: {type(image_embeds)}") - return LlavaNextImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") @@ -327,21 +336,23 @@ def _image_pixels_to_features( ) # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py - def _merge_image_patch_embeddings(self, image_size: torch.Tensor, - patch_embeddings: torch.Tensor, *, - strategy: str) -> torch.Tensor: + def _merge_image_patch_embeddings( + self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str + ) -> torch.Tensor: if strategy == "flat": return patch_embeddings.flatten(0, 1) if strategy.startswith("spatial"): - height = width = self.config.vision_config.image_size \ + height = width = ( + self.config.vision_config.image_size // self.config.vision_config.patch_size + ) base_patch_embeds = patch_embeddings[0] if height * width != base_patch_embeds.shape[0]: raise ValueError( - "The number of patches is not consistent with the " - "image size.") + "The number of patches is not consistent with the image size." + ) if patch_embeddings.shape[0] > 1: other_patch_embeds = patch_embeddings[1:] @@ -358,37 +369,51 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, num_patches = num_patch_height * num_patch_width # Image patches might be padded for batch processing - other_patch_embeds = other_patch_embeds[:num_patches] \ - .view(num_patch_height, num_patch_width, height, width, -1) + other_patch_embeds = other_patch_embeds[:num_patches].view( + num_patch_height, num_patch_width, height, width, -1 + ) if "unpad" in strategy: - other_patch_embeds = other_patch_embeds \ - .permute(4, 0, 2, 1, 3).contiguous() \ - .flatten(1, 2).flatten(2, 3) - other_patch_embeds = unpad_image(other_patch_embeds, - (orig_height, orig_width)) - other_patch_embeds = torch.cat(( - other_patch_embeds, - self.image_newline[:, None, None] \ - .expand(*other_patch_embeds.shape[:-1], 1) \ + other_patch_embeds = ( + other_patch_embeds.permute(4, 0, 2, 1, 3) + .contiguous() + .flatten(1, 2) + .flatten(2, 3) + ) + other_patch_embeds = unpad_image( + other_patch_embeds, (orig_height, orig_width) + ) + other_patch_embeds = torch.cat( + ( + other_patch_embeds, + self.image_newline[:, None, None] + .expand(*other_patch_embeds.shape[:-1], 1) .to(other_patch_embeds.device), - ), dim=-1) - other_patch_embeds = other_patch_embeds \ - .flatten(1, 2).transpose(0, 1) + ), + dim=-1, + ) + other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose( + 0, 1 + ) else: - other_patch_embeds = other_patch_embeds \ - .permute(0, 2, 1, 3, 4).contiguous() \ + other_patch_embeds = ( + other_patch_embeds.permute(0, 2, 1, 3, 4) + .contiguous() .flatten(0, 3) + ) merged_patch_embeddings = torch.cat( - (base_patch_embeds, other_patch_embeds), dim=0) + (base_patch_embeds, other_patch_embeds), dim=0 + ) else: if "unpad" in strategy: merged_patch_embeddings = torch.cat( - (base_patch_embeds, - self.image_newline[None] \ - .to(base_patch_embeds.device) - ), dim=0) + ( + base_patch_embeds, + self.image_newline[None].to(base_patch_embeds.device), + ), + dim=0, + ) else: merged_patch_embeddings = base_patch_embeds @@ -408,20 +433,25 @@ def _process_image_pixels( b, num_patches, c, h, w = pixel_values.shape stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) stacked_image_features = self._image_pixels_to_features( - self.vision_tower, stacked_pixel_values) + self.vision_tower, stacked_pixel_values + ) 
stacked_patch_embeddings = self.multi_modal_projector( - stacked_image_features) + stacked_image_features + ) return stacked_patch_embeddings.view( - b, num_patches, *stacked_patch_embeddings.shape[1:]) + b, num_patches, *stacked_patch_embeddings.shape[1:] + ) num_patches_per_batch = [v.shape[0] for v in pixel_values] stacked_pixel_values = torch.cat(pixel_values) stacked_image_features = self._image_pixels_to_features( - self.vision_tower, stacked_pixel_values) + self.vision_tower, stacked_pixel_values + ) - return torch.split(self.multi_modal_projector(stacked_image_features), - num_patches_per_batch) + return torch.split( + self.multi_modal_projector(stacked_image_features), num_patches_per_batch + ) def _process_image_input( self, @@ -437,21 +467,21 @@ def _process_image_input( batch_size = len(image_input["data"]) vision_config = self.config.vision_config default_height = default_width = vision_config.image_size - image_sizes = torch.as_tensor([[default_height, default_width] - for _ in range(batch_size)]) + image_sizes = torch.as_tensor( + [[default_height, default_width] for _ in range(batch_size)] + ) return [ - self._merge_image_patch_embeddings(image_sizes[i], - patch_features_batch, - strategy="spatial_unpad") + self._merge_image_patch_embeddings( + image_sizes[i], patch_features_batch, strategy="spatial_unpad" + ) for i, patch_features_batch in enumerate(patch_embeddings) ] def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -535,10 +565,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( @@ -547,7 +576,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 7aabef32b4a9..074acc7943a4 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -7,21 +7,30 @@ import torch import torch.nn as nn -from transformers import (BatchFeature, LlavaNextVideoConfig, - LlavaNextVideoProcessor) +from transformers import BatchFeature, LlavaNextVideoConfig, LlavaNextVideoProcessor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.clip import CLIPVisionModel from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - VideoEmbeddingItems, VideoProcessorItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - 
BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageSize, + MultiModalDataItems, + VideoEmbeddingItems, + VideoProcessorItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -30,34 +39,39 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llava import init_vision_tower_for_llava from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vision_encoder_info class LlavaNextVideoPixelInputs(TensorSchema): - """ + """ Dimensions: - - bs: Batch size - - nv: Number of videos - - nf: Number of frames - - nc: Number of channels (3) + - bn: Batch size * number of videos + - f: Number of frames + - c: Number of channels (3) - h: Height of each frame - w: Width of each frame - Note that `num_frames` may be different for each batch, in which case + Note that `f` may be different for each batch, in which case the data is passed as a list instead of a batched tensor. Note that it only supports one video input for one batch. """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bs", "nv", "nf", 3, "h", "w")] + pixel_values_videos: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}), + ] class LlavaNextVideoProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(LlavaNextVideoConfig) @@ -137,8 +151,8 @@ def get_num_frames_with_most_features( class LlavaNextVideoDummyInputsBuilder( - BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]): - + BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo] +): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_videos = mm_counts.get("video", 0) @@ -155,16 +169,15 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) video_overrides = mm_options.get("video") if mm_options else None return { - "video": - self._get_dummy_videos( + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, @@ -175,8 +188,8 @@ def get_dummy_mm_data( class LlavaNextVideoMultiModalProcessor( - BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]): - + BaseMultiModalProcessor[LlavaNextVideoProcessingInfo] +): def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -195,7 +208,8 @@ def _get_prompt_updates( def get_replacement(item_idx: int): videos = mm_items.get_items( - "video", (VideoEmbeddingItems, VideoProcessorItems)) + "video", (VideoEmbeddingItems, VideoProcessorItems) + ) if isinstance(videos, VideoEmbeddingItems): num_video_tokens = 
videos.get_feature_size(item_idx) @@ -220,7 +234,6 @@ def get_replacement(item_idx: int): # adopted from transformers modeling_llava_next_video.py class LlavaNextVideoPooler(nn.Module): - def __init__(self, config: LlavaNextVideoConfig): super().__init__() @@ -237,36 +250,41 @@ def __init__(self, config: LlavaNextVideoConfig): else: # TODO: Support Conv2d pooling layer, need to load weights raise ValueError( - f"Unknown pooling mode: {mode}. Expected [`average`, `max`]") + f"Unknown pooling mode: {mode}. Expected [`average`, `max`]" + ) def forward(self, image_features: torch.Tensor): ori_width = int( - math.sqrt(image_features.shape[1] * self.image_size // - self.image_size)) + math.sqrt(image_features.shape[1] * self.image_size // self.image_size) + ) ori_height = int(ori_width * self.image_size // self.image_size) batch_size, _, dim = image_features.shape - image_features_spatial = image_features \ - .view(batch_size, ori_height, ori_height, dim) \ - .permute(0, 3, 1, 2) + image_features_spatial = image_features.view( + batch_size, ori_height, ori_height, dim + ).permute(0, 3, 1, 2) image_features_spatial = self.pool(image_features_spatial) return image_features_spatial.flatten(2).transpose(1, 2).contiguous() class LlavaNextMultiModalProjector(nn.Module): - - def __init__(self, vision_hidden_size: int, text_hidden_size: int, - projector_hidden_act: str, multimodal_projector_bias: bool): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + ): super().__init__() - self.linear_1 = nn.Linear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias) + self.linear_1 = nn.Linear( + vision_hidden_size, text_hidden_size, bias=multimodal_projector_bias + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = nn.Linear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias) + self.linear_2 = nn.Linear( + text_hidden_size, text_hidden_size, bias=multimodal_projector_bias + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_1(image_features) @@ -280,8 +298,8 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: info=LlavaNextVideoProcessingInfo, dummy_inputs=LlavaNextVideoDummyInputsBuilder, ) -class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -291,7 +309,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, "model.multi_modal_projector.": "multi_modal_projector.", "model.image_newline": "image_newline", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -316,13 +335,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.vision_resampler = LlavaNextVideoPooler(config) self.multi_modal_projector = LlavaNextMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act, - multimodal_projector_bias=config.multimodal_projector_bias) + multimodal_projector_bias=config.multimodal_projector_bias, + ) 
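LlavaNextMultiModalProjector above is just a two-layer MLP that maps vision features into the language model's hidden size before they are merged with the text embeddings. A stand-alone sketch of the same computation; the sizes (1024 -> 4096) and the frame/token counts are assumptions for illustration, and nn.GELU stands in for whatever get_act_fn(projector_hidden_act) resolves to:

    import torch
    import torch.nn as nn

    # Assumed sizes; the real ones come from the HF vision/text configs.
    vision_hidden_size, text_hidden_size = 1024, 4096

    projector = nn.Sequential(
        nn.Linear(vision_hidden_size, text_hidden_size, bias=True),
        nn.GELU(),  # stand-in for get_act_fn(projector_hidden_act)
        nn.Linear(text_hidden_size, text_hidden_size, bias=True),
    )

    # One video's worth of pooled frame features: (frames * tokens, vision_hidden_size)
    image_features = torch.randn(8 * 576, vision_hidden_size)
    image_embeds = projector(image_features)  # -> (8 * 576, text_hidden_size)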
self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, @@ -330,14 +351,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) self.make_empty_intermediate_tensors = ( - self.language_model.model.make_empty_intermediate_tensors) + self.language_model.model.make_empty_intermediate_tensors + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]: + self, **kwargs: object + ) -> Optional[LlavaNextVideoPixelInputs]: """ A legal video input should have the following dimensions: { - "pixel_values_videos" : + "pixel_values_videos" : list[b, Tensor(nb_frames, nb_channels, height, width)] } """ @@ -347,12 +370,14 @@ def _parse_and_validate_video_input( return None expected_h = expected_w = self.config.vision_config.image_size - return LlavaNextVideoPixelInputs(type="pixel_values_videos", - data=pixel_values_videos, - resolve_bindings={ - "h": expected_h, - "w": expected_w, - }) + return LlavaNextVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + resolve_bindings={ + "h": expected_h, + "w": expected_w, + }, + ) def _video_pixels_to_features( self, @@ -372,36 +397,32 @@ def _video_pixels_to_features( def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs): assert self.vision_tower is not None - video_pixels = inputs["data"] + video_pixels = inputs["pixel_values_videos"] if isinstance(video_pixels, torch.Tensor): - # TODO: support multiple videos per input - b, num_videos, num_frames, c, h, w = video_pixels.shape - assert (num_videos == 1) - stacked_pixels = video_pixels.view(b * num_videos * num_frames, c, - h, w) + bn, f, c, h, w = video_pixels.shape + stacked_pixels = video_pixels.view(bn * f, c, h, w) stacked_embeddings = self._video_pixels_to_features( - self.vision_tower, stacked_pixels) - embeds = stacked_embeddings.view(b, num_frames, - *stacked_embeddings.shape[1:]) + self.vision_tower, stacked_pixels + ) + embeds = stacked_embeddings.view(bn, f, *stacked_embeddings.shape[1:]) elif is_list_of(video_pixels, torch.Tensor): frames_per_videos = [v.shape[0] for v in video_pixels] stacked_pixels = torch.cat(video_pixels, dim=0) stacked_embeddings = self._video_pixels_to_features( - self.vision_tower, stacked_pixels) + self.vision_tower, stacked_pixels + ) embeds = torch.split(stacked_embeddings, frames_per_videos, dim=0) else: - raise ValueError( - f"Unsupported type of video input {type(video_pixels)}") + raise ValueError(f"Unsupported type of video input {type(video_pixels)}") return [e.flatten(0, 1) for e in embeds] def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: video_input = self._parse_and_validate_video_input(**kwargs) if video_input is None: return [] @@ -425,10 +446,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -438,8 +458,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def 
load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # This model doesn't support images for now diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 4379f24da1bf..05f1621694c3 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -7,19 +7,27 @@ import torch import torch.nn as nn -from transformers import (BatchFeature, LlavaOnevisionConfig, - LlavaOnevisionProcessor) +from transformers import BatchFeature, LlavaOnevisionConfig, LlavaOnevisionProcessor from transformers.models.llava_onevision.modeling_llava_onevision import ( - get_anyres_image_grid_shape, unpad_image) + get_anyres_image_grid_shape, + unpad_image, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - VideoEmbeddingItems, VideoProcessorItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageSize, + MultiModalDataItems, + VideoEmbeddingItems, + VideoProcessorItems, +) from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -27,11 +35,18 @@ from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava -from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig, - LlavaNextProcessingInfo) +from .llava_next import ( + BaseLlavaNextMultiModalProcessor, + LlavaNextLikeConfig, + LlavaNextProcessingInfo, +) from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) # For profile run _MAX_FRAMES_PER_VIDEO = 16 @@ -46,10 +61,11 @@ class LlavaOnevisionVideoPixelInputs(TensorSchema): - h: Height - w: Width - Note that `num_videos` may be different for each batch, and 'num_frames' + Note that `f` may be different for each batch, and 'num_frames' may be different for each video, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values_videos"] = "pixel_values_videos" pixel_values_videos: Annotated[ @@ -70,6 +86,7 @@ class LlavaOnevisionImagePixelInputs(TensorSchema): Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. 
""" + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ @@ -87,6 +104,7 @@ class LlavaOnevisionImageEmbeddingInputs(TensorSchema): - ifs: Image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[ @@ -95,11 +113,13 @@ class LlavaOnevisionImageEmbeddingInputs(TensorSchema): ] -LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs, - LlavaOnevisionImageEmbeddingInputs] +LlavaOnevisionImageInputs = Union[ + LlavaOnevisionImagePixelInputs, LlavaOnevisionImageEmbeddingInputs +] -LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs, - LlavaOnevisionVideoPixelInputs] +LlavaOnevisionMultiInputs = Union[ + LlavaOnevisionImageInputs, LlavaOnevisionVideoPixelInputs +] class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol): @@ -107,7 +127,6 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol): class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): - def get_hf_config(self) -> LlavaOnevisionLikeConfig: return self.ctx.get_hf_config(LlavaOnevisionConfig) @@ -136,12 +155,14 @@ def _get_num_unpadded_features( if aspect_ratio > current_aspect_ratio: new_height = int( - round(original_height * (current_width / original_width), 7)) + round(original_height * (current_width / original_width), 7) + ) padding = (current_height - new_height) // 2 current_height = current_height - (2 * padding) else: new_width = int( - round(original_width * (current_height / original_height), 7)) + round(original_width * (current_height / original_height), 7) + ) padding = (current_width - new_width) // 2 current_width = current_width - (2 * padding) @@ -218,8 +239,9 @@ def get_num_frames_with_most_features( max_videos = mm_counts.get("video", 0) max_total_frames = self._get_max_video_frames(seq_len) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO + ) return max(max_frames_per_video, 1) @@ -233,14 +255,13 @@ def get_max_video_tokens( return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), ) class LlavaOnevisionDummyInputsBuilder( - LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]): - + LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo] +): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -260,35 +281,34 @@ def get_dummy_mm_data( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, - mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + 
overrides=image_overrides, + ), + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, overrides=video_overrides, - ) + ), } class LlavaOnevisionMultiModalProcessor( - BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]): - + BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo] +): def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -405,7 +425,8 @@ def _get_prompt_updates( def get_video_replacement(item_idx: int): videos = mm_items.get_items( - "video", (VideoEmbeddingItems, VideoProcessorItems)) + "video", (VideoEmbeddingItems, VideoProcessorItems) + ) if isinstance(videos, VideoEmbeddingItems): num_video_tokens = videos.get_feature_size(item_idx) @@ -430,17 +451,20 @@ def get_video_replacement(item_idx: int): class LlavaOnevisionMultiModalProjector(nn.Module): - def __init__(self, config: LlavaOnevisionConfig): super().__init__() - self.linear_1 = nn.Linear(config.vision_config.hidden_size, - config.text_config.hidden_size, - bias=config.multimodal_projector_bias) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) self.act = get_act_fn(config.projector_hidden_act) - self.linear_2 = nn.Linear(config.text_config.hidden_size, - config.text_config.hidden_size, - bias=config.multimodal_projector_bias) + self.linear_2 = nn.Linear( + config.text_config.hidden_size, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_1(image_features) @@ -452,9 +476,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: @MULTIMODAL_REGISTRY.register_processor( LlavaOnevisionMultiModalProcessor, info=LlavaOnevisionProcessingInfo, - dummy_inputs=LlavaOnevisionDummyInputsBuilder) -class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=LlavaOnevisionDummyInputsBuilder, +) +class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -464,7 +489,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, "model.multi_modal_projector.": "multi_modal_projector.", "model.image_newline": "image_newline", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -489,21 +515,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, prefix=maybe_prefix(prefix, "language_model"), ) - self.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size)) + self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size)) self.make_empty_intermediate_tensors = ( - self.language_model.model.make_empty_intermediate_tensors) + self.language_model.model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]: + self, **kwargs: object + ) -> Optional[LlavaOnevisionImageInputs]: pixel_values = 
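The `_get_num_unpadded_features` hunk above trims the padded anyres patch grid back to the source image's aspect ratio before counting tokens. A rough stand-alone sketch of just that trimming step, with variable meanings inferred from the surrounding code and all sizes in feature-grid units:

```python
def unpad_grid(original_height, original_width, current_height, current_width):
    # Assumed: aspect ratios compare the real image to the padded feature grid.
    aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        # Image is wider than the grid: rows at the top/bottom are padding.
        new_height = int(round(original_height * (current_width / original_width), 7))
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        # Image is taller than the grid: columns at the left/right are padding.
        new_width = int(round(original_width * (current_height / original_height), 7))
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding
    return current_height, current_width


# e.g. a 1000x600 image whose features were padded to a square 24x24 grid
print(unpad_grid(600, 1000, 24, 24))  # (14, 24): ten rows of padding removed
```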
kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -512,42 +540,31 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_sizes, (torch.Tensor, list)): - raise ValueError("Incorrect type of image sizes. " - f"Got type: {type(image_sizes)}") - return LlavaOnevisionImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values), - image_sizes=flatten_bn(image_sizes, concat=True), + pixel_values=pixel_values, + image_sizes=image_sizes, resolve_bindings={ "h": self.config.vision_config.image_size, - "w": self.config.vision_config.image_size - }) + "w": self.config.vision_config.image_size, + }, + ) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeds. " - f"Got type: {type(image_embeds)}") - return LlavaOnevisionImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") def _parse_and_validate_video_input( - self, - **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]: + self, **kwargs: object + ) -> Optional[LlavaOnevisionVideoPixelInputs]: """ A legal video input should have the following dimensions: { - "pixel_values_videos" : + "pixel_values_videos" : list[b, Tensor(nb_frames, nb_channels, height, width)] } """ @@ -555,17 +572,14 @@ def _parse_and_validate_video_input( if pixel_values_videos is None: return None - if not isinstance(pixel_values_videos, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel_values_videos. " - f"Got type: {type(pixel_values_videos)}") - return LlavaOnevisionVideoPixelInputs( type="pixel_values_videos", - pixel_values_videos=flatten_bn(pixel_values_videos), + pixel_values_videos=pixel_values_videos, resolve_bindings={ "h": self.config.vision_config.image_size, - "w": self.config.vision_config.image_size - }) + "w": self.config.vision_config.image_size, + }, + ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -573,14 +587,20 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key in ("pixel_values_videos", "video_embeds" - ) and "video" not in mm_input_by_modality: - mm_input_by_modality[ - "video"] = self._parse_and_validate_video_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) return mm_input_by_modality @@ -597,25 +617,29 @@ def _image_pixels_to_features( ) # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py - def _merge_image_patch_embeddings(self, - image_size: torch.Tensor, - patch_embeddings: torch.Tensor, - *, - image_newline=None, - vision_aspect_ratio="anyres_max_9", - strategy: str) -> torch.Tensor: + def _merge_image_patch_embeddings( + self, + image_size: torch.Tensor, + patch_embeddings: torch.Tensor, + *, + image_newline=None, + vision_aspect_ratio="anyres_max_9", + strategy: str, + ) -> torch.Tensor: if strategy == "flat": return patch_embeddings.flatten(0, 1) if strategy.startswith("spatial"): - height = width = self.config.vision_config.image_size \ + height = width = ( + self.config.vision_config.image_size // self.config.vision_config.patch_size + ) base_patch_embeds = patch_embeddings[0] if height * width != base_patch_embeds.shape[0]: raise ValueError( - "The number of patches is not consistent with the " - "image size.") + "The number of patches is not consistent with the image size." 
+ ) if patch_embeddings.shape[0] > 1: other_patch_embeds = patch_embeddings[1:] @@ -632,53 +656,66 @@ def _merge_image_patch_embeddings(self, num_patches = num_patch_height * num_patch_width # Image patches might be padded for batch processing - other_patch_embeds = other_patch_embeds[:num_patches] \ - .view(num_patch_height, num_patch_width, height, width, -1) + other_patch_embeds = other_patch_embeds[:num_patches].view( + num_patch_height, num_patch_width, height, width, -1 + ) if "unpad" in strategy: - other_patch_embeds = other_patch_embeds \ - .permute(4, 0, 2, 1, 3).contiguous() \ - .flatten(1, 2).flatten(2, 3) - other_patch_embeds = unpad_image(other_patch_embeds, - (orig_height, orig_width)) + other_patch_embeds = ( + other_patch_embeds.permute(4, 0, 2, 1, 3) + .contiguous() + .flatten(1, 2) + .flatten(2, 3) + ) + other_patch_embeds = unpad_image( + other_patch_embeds, (orig_height, orig_width) + ) max_num_patches = int( - vision_aspect_ratio.removeprefix("anyres_max_")) + vision_aspect_ratio.removeprefix("anyres_max_") + ) channels, curr_height, curr_width = other_patch_embeds.shape - ratio = math.sqrt(curr_height * curr_width / - (max_num_patches * height**2)) + ratio = math.sqrt( + curr_height * curr_width / (max_num_patches * height**2) + ) if ratio > 1.1: other_patch_embeds = other_patch_embeds[None] other_patch_embeds = nn.functional.interpolate( - other_patch_embeds, [ - int(curr_height // ratio), - int(curr_width // ratio) - ], - mode="bilinear")[0] + other_patch_embeds, + [int(curr_height // ratio), int(curr_width // ratio)], + mode="bilinear", + )[0] if image_newline is not None: other_patch_embeds = torch.cat( ( other_patch_embeds, - image_newline[:, None, None] \ - .expand(*other_patch_embeds.shape[:-1], 1) \ + image_newline[:, None, None] + .expand(*other_patch_embeds.shape[:-1], 1) .to(other_patch_embeds.device), ), - dim=-1) - other_patch_embeds = other_patch_embeds \ - .flatten(1, 2).transpose(0, 1) + dim=-1, + ) + other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose( + 0, 1 + ) else: - other_patch_embeds = other_patch_embeds \ - .permute(0, 2, 1, 3, 4).contiguous() \ + other_patch_embeds = ( + other_patch_embeds.permute(0, 2, 1, 3, 4) + .contiguous() .flatten(0, 3) + ) merged_patch_embeddings = torch.cat( - (base_patch_embeds, other_patch_embeds), dim=0) + (base_patch_embeds, other_patch_embeds), dim=0 + ) else: if "unpad" in strategy: merged_patch_embeddings = torch.cat( - (base_patch_embeds, - self.image_newline[None] \ - .to(base_patch_embeds.device) - ), dim=0) + ( + base_patch_embeds, + self.image_newline[None].to(base_patch_embeds.device), + ), + dim=0, + ) else: merged_patch_embeddings = base_patch_embeds @@ -698,21 +735,27 @@ def _process_image_pixels( b, num_patches, c, h, w = pixel_values.shape stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) stacked_image_features = self._image_pixels_to_features( - self.vision_tower, stacked_pixel_values) + self.vision_tower, stacked_pixel_values + ) stacked_patch_embeddings = self.multi_modal_projector( - stacked_image_features) + stacked_image_features + ) return stacked_patch_embeddings.view( - b, num_patches, *stacked_patch_embeddings.shape[1:]) + b, num_patches, *stacked_patch_embeddings.shape[1:] + ) num_patches_per_batch = [v.shape[0] for v in pixel_values] stacked_pixel_values = torch.cat(pixel_values) stacked_image_features = self._image_pixels_to_features( - self.vision_tower, stacked_pixel_values) + self.vision_tower, stacked_pixel_values + ) return [ - 
self.multi_modal_projector(image_features) for image_features in - torch.split(stacked_image_features, num_patches_per_batch) + self.multi_modal_projector(image_features) + for image_features in torch.split( + stacked_image_features, num_patches_per_batch + ) ] def _process_image_input( @@ -729,15 +772,17 @@ def _process_image_input( batch_size = len(image_input["pixel_values"]) vision_config = self.config.vision_config default_height = default_width = vision_config.image_size - image_sizes = torch.as_tensor([[default_height, default_width] - for _ in range(batch_size)]) + image_sizes = torch.as_tensor( + [[default_height, default_width] for _ in range(batch_size)] + ) return [ self._merge_image_patch_embeddings( image_sizes[i], patch_features_batch, image_newline=self.image_newline, - strategy="spatial_unpad") + strategy="spatial_unpad", + ) for i, patch_features_batch in enumerate(patch_embeddings) ] @@ -763,36 +808,39 @@ def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs): if isinstance(video_pixels, torch.Tensor): total_videos, frames, c, h, w = video_pixels.shape - video_pixels_flat = video_pixels.view(total_videos * frames, c, h, - w) + video_pixels_flat = video_pixels.view(total_videos * frames, c, h, w) embeddings_flat = self._video_pixels_to_features( - self.vision_tower, video_pixels_flat) + self.vision_tower, video_pixels_flat + ) embeddings_flat = embeddings_flat.reshape( - total_videos, frames * embeddings_flat.shape[1], -1) + total_videos, frames * embeddings_flat.shape[1], -1 + ) image_newline = self.image_newline[None, None, :].expand( - total_videos, -1, -1) + total_videos, -1, -1 + ) return torch.cat((embeddings_flat, image_newline), dim=1) frames_per_video = [len(video) for video in video_pixels] video_pixels_flat = torch.cat(video_pixels) embeddings_flat = self._video_pixels_to_features( - self.vision_tower, video_pixels_flat) + self.vision_tower, video_pixels_flat + ) image_newline = self.image_newline[None, None, :] return [ torch.cat( ( - embeds.reshape(1, num_frame * embeddings_flat.shape[1], - -1), + embeds.reshape(1, num_frame * embeddings_flat.shape[1], -1), image_newline, ), dim=1, - ) for num_frame, embeds in zip( + ) + for num_frame, embeds in zip( frames_per_video, torch.split(embeddings_flat, frames_per_video), ) @@ -808,9 +856,9 @@ def apply_pooling(self, image_features: torch.Tensor, stride: int = 2): # TODO support other pooling types config height, width = image_features.shape[2:] scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)] - image_feature = nn.functional.interpolate(image_features, - size=scaled_shape, - mode='bilinear') + image_feature = nn.functional.interpolate( + image_features, size=scaled_shape, mode="bilinear" + ) image_feature = image_feature.permute(0, 2, 3, 1) image_feature = image_feature.view(batch_frames, -1, dim) return image_feature @@ -818,10 +866,8 @@ def apply_pooling(self, image_features: torch.Tensor, stride: int = 2): def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] return None @@ -860,10 +906,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = 
self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -873,7 +918,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/longcat_flash.py b/vllm/model_executor/models/longcat_flash.py index 78e6e3d4b535..5020da37df89 100644 --- a/vllm/model_executor/models/longcat_flash.py +++ b/vllm/model_executor/models/longcat_flash.py @@ -32,6 +32,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. """Inference-only Flash model compatible with HuggingFace weights.""" + import typing from collections.abc import Callable, Iterable from typing import Optional, Union @@ -47,29 +48,37 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.utils.int8_utils import ( - block_dequant) +from vllm.model_executor.layers.quantization.utils.int8_utils import block_dequant from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class FlashConfig(PretrainedConfig): """Flash model configuration.""" + model_type = "longcat_flash" keys_to_ignore_at_inference = ["past_key_values"] @@ -132,8 +141,9 @@ def __init__( self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size - self.num_hidden_layers = (num_hidden_layers if num_hidden_layers - is not None else num_layers) + self.num_hidden_layers = ( + num_hidden_layers if num_hidden_layers is not None else num_layers + ) self.num_attention_heads = num_attention_heads self.ep_size = ep_size self.kv_lora_rank = kv_lora_rank @@ -162,8 +172,11 @@ def __init__( self.zero_expert_type = zero_expert_type self.routed_scaling_factor = routed_scaling_factor self.hidden_act = "silu" - self.intermediate_size = self.ffn_hidden_size if hasattr( - self, "ffn_hidden_size") else self.intermediate_size + self.intermediate_size = ( + self.ffn_hidden_size + 
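The `apply_pooling` hunk a little further up shrinks per-frame video tokens with a stride-2 bilinear downsample before they reach the language model. A toy version under an assumed `(frames, height, width, dim)` layout; the real method first reshapes from flat tokens, which is not shown here:

```python
import math

import torch
import torch.nn.functional as F


def pool_frame_features(feats: torch.Tensor, stride: int = 2) -> torch.Tensor:
    # Toy take on apply_pooling: downsample the spatial grid, then re-flatten tokens.
    frames, height, width, dim = feats.shape
    grid = feats.permute(0, 3, 1, 2)  # -> (frames, dim, h, w) for interpolate
    scaled = [math.ceil(height / stride), math.ceil(width / stride)]
    grid = F.interpolate(grid, size=scaled, mode="bilinear")
    return grid.permute(0, 2, 3, 1).reshape(frames, -1, dim)


print(pool_frame_features(torch.randn(8, 27, 27, 16)).shape)
# torch.Size([8, 196, 16])  # 14 x 14 tokens per frame after stride-2 pooling
```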
if hasattr(self, "ffn_hidden_size") + else self.intermediate_size + ) if hasattr(self, "moe_intermediate_size"): self.moe_intermediate_size = self.moe_intermediate_size elif hasattr(self, "expert_ffn_hidden_size"): @@ -201,8 +214,9 @@ def __init__( prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -216,15 +230,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LongcatRouter(nn.Module): - - def __init__(self, - config, - zero_expert_num=0, - rounter_params_dtype=torch.bfloat16, - prefix: str = ""): + def __init__( + self, + config, + zero_expert_num=0, + rounter_params_dtype=torch.bfloat16, + prefix: str = "", + ): super().__init__() - self.n_routed_experts = config.n_routed_experts if hasattr( - config, "n_routed_experts") else config.num_experts[0] + self.n_routed_experts = ( + config.n_routed_experts + if hasattr(config, "n_routed_experts") + else config.num_experts[0] + ) self.n_routed_experts = self.n_routed_experts + zero_expert_num self.classifier = ReplicatedLinear( config.hidden_size, @@ -235,7 +253,8 @@ def __init__(self, prefix=f"{prefix}.classifier", ) self.e_score_correction_bias = nn.Parameter( - torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype)) + torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype) + ) def forward(self, hidden_states): logits, _ = self.classifier(hidden_states) @@ -243,7 +262,6 @@ def forward(self, hidden_states): class LongcatMoe(nn.Module): - def __init__( self, config: FlashConfig, @@ -271,7 +289,8 @@ def __init__( config=config, zero_expert_num=self.zero_expert_num, rounter_params_dtype=self.rounter_params_dtype, - prefix=f"{prefix}.gate") + prefix=f"{prefix}.gate", + ) self.experts = FusedMoE( num_experts=num_experts, @@ -291,14 +310,13 @@ def __init__( ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = self.router(hidden_states.to( - self.rounter_params_dtype)) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + router_logits = self.router(hidden_states.to(self.rounter_params_dtype)) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) return final_hidden_states.view(num_tokens, hidden_dim) @@ -316,67 +334,76 @@ def __init__( enable_eplb: bool = False, ) -> None: super().__init__() - self.layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = int(prefix.split(sep=".")[-1]) self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) + config.original_max_position_embeddings + ) # Dual attention structure - self.self_attn = nn.ModuleList([ - DeepseekV2MLAAttention( - vllm_config=vllm_config, - config=config, - 
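FlashMLP in the hunk above uses the usual merged gate/up projection: a single matmul yields `[gate, up]` on the last dimension, and SiluAndMul computes `silu(gate) * up` before `down_proj`. A plain-PyTorch stand-in, without the tensor-parallel linear layers or quantization:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySwiGLUMLP(nn.Module):
    # Stand-in for FlashMLP: gate and up projections fused into one matmul.
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)  # SiluAndMul, then down projection


print(TinySwiGLUMLP(16, 64)(torch.randn(2, 3, 16)).shape)  # torch.Size([2, 3, 16])
```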
hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=(config.q_lora_rank if hasattr( - config, "q_lora_rank") else None), - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=None if "self_attn" in getattr( - config, "disable_quant_module", []) else quant_config, - prefix=f"{prefix}.self_attn.{i}", - ) for i in range(2) - ]) - self.input_layernorm = nn.ModuleList([ - RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - for i in range(2) - ]) - self.post_attention_layernorm = nn.ModuleList([ - RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - for i in range(2) - ]) + self.self_attn = nn.ModuleList( + [ + DeepseekV2MLAAttention( + vllm_config=vllm_config, + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=( + config.q_lora_rank if hasattr(config, "q_lora_rank") else None + ), + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=None + if "self_attn" in getattr(config, "disable_quant_module", []) + else quant_config, + prefix=f"{prefix}.self_attn.{i}", + ) + for i in range(2) + ] + ) + self.input_layernorm = nn.ModuleList( + [RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for i in range(2)] + ) + self.post_attention_layernorm = nn.ModuleList( + [RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for i in range(2)] + ) # Dual MLP structure - self.mlps = nn.ModuleList([ - FlashMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=None if "mlps" in getattr( - config, "disable_quant_module", []) else quant_config, - prefix=f"{prefix}.mlps.{i}", - ) for i in range(2) - ]) + self.mlps = nn.ModuleList( + [ + FlashMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=None + if "mlps" in getattr(config, "disable_quant_module", []) + else quant_config, + prefix=f"{prefix}.mlps.{i}", + ) + for i in range(2) + ] + ) self.mlp = LongcatMoe( config=config, - num_experts=config.n_routed_experts if hasattr( - config, "n_routed_experts") else - config.num_experts[self.layer_idx], + num_experts=config.n_routed_experts + if hasattr(config, "n_routed_experts") + else config.num_experts[self.layer_idx], top_k=config.moe_topk - if hasattr(config, "moe_topk") else config.num_experts_per_tok, + if hasattr(config, "moe_topk") + else config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, quant_config=quant_config, @@ -389,13 +416,11 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - if residual is None: residual = hidden_states hidden_states = self.input_layernorm[0](hidden_states) else: - hidden_states, residual = self.input_layernorm[0](hidden_states, - residual) + hidden_states, residual = self.input_layernorm[0](hidden_states, residual) hidden_states = self.self_attn[0]( positions=positions, @@ -403,7 +428,8 @@ def forward( ) hidden_states, residual = 
self.post_attention_layernorm[0]( - hidden_states, residual) + hidden_states, residual + ) # moe hidden_states_copy = hidden_states.clone() @@ -412,8 +438,7 @@ def forward( # first mlp hidden_states = self.mlps[0](hidden_states) - hidden_states, residual = self.input_layernorm[1](hidden_states, - residual) + hidden_states, residual = self.input_layernorm[1](hidden_states, residual) # second_attn hidden_states = self.self_attn[1]( @@ -421,7 +446,8 @@ def forward( hidden_states=hidden_states, ) hidden_states, residual = self.post_attention_layernorm[1]( - hidden_states, residual) + hidden_states, residual + ) # second_mlp hidden_states = self.mlps[1](hidden_states) @@ -462,14 +488,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config, prefix=prefix, ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -501,10 +528,9 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -532,26 +558,32 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - config.intermediate_size = config.ffn_hidden_size if hasattr( - config, "ffn_hidden_size") else config.intermediate_size + config.intermediate_size = ( + config.ffn_hidden_size + if hasattr(config, "ffn_hidden_size") + else config.intermediate_size + ) self.lora_config = lora_config self.quant_config = quant_config - self.model = FlashModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = FlashModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -563,8 +595,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -581,14 +614,12 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: 
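FlashModel.forward above follows vLLM's pipeline-parallel convention: non-final ranks ship `hidden_states` and `residual` onward as IntermediateTensors, and only the last rank applies the final RMSNorm with the fused-residual signature. A toy illustration with a stand-in norm rather than the real RMSNorm kernel:

```python
import torch


def fused_norm(x, residual):
    # Stand-in for RMSNorm(x, residual): add the residual, then normalize.
    x = x + residual
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6), x


def pp_stage_output(hidden_states, residual, is_last_rank: bool):
    # Intermediate pipeline ranks return raw tensors for the next stage;
    # the last rank applies the final norm and returns only hidden states.
    if not is_last_rank:
        return {"hidden_states": hidden_states, "residual": residual}
    hidden_states, _ = fused_norm(hidden_states, residual)
    return hidden_states


h, r = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
print(type(pp_stage_output(h, r, is_last_rank=False)))  # <class 'dict'>
print(pp_stage_output(h, r, is_last_rank=True).shape)   # torch.Size([2, 4, 8])
```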
ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts if hasattr( - self.config, "n_routed_experts") else - self.config.num_experts[0], + num_experts=self.config.n_routed_experts + if hasattr(self.config, "n_routed_experts") + else self.config.num_experts[0], ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("fused_qkv_a_proj", "q_a_proj", 0), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), @@ -610,8 +641,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if (name.endswith(".bias") - or name.endswith("_bias")) and name not in params_dict: + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip mtp if ".mtp." in name: @@ -633,22 +665,25 @@ def load_weights(self, weights: Iterable[tuple[str, # Skip mtp if ".mtp." in name_mapped: continue - if (name_mapped.endswith(".bias") - or name_mapped.endswith("_bias") - ) and name not in params_dict: + if ( + name_mapped.endswith(".bias") or name_mapped.endswith("_bias") + ) and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name_mapped] weight_loader = param.weight_loader - weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -672,8 +707,9 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) for layer_id in range(self.config.num_hidden_layers): @@ -681,35 +717,35 @@ def load_weights(self, weights: Iterable[tuple[str, if isinstance(self.model.layers[layer_id], PPMissingLayer): continue self_attn = self.model.layers[layer_id].self_attn[i] - if hasattr(self.quant_config, "weight_block_size" - ) and self_attn.kv_b_proj.weight.dtype in ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ): + if hasattr( + self.quant_config, "weight_block_size" + ) and self_attn.kv_b_proj.weight.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): weight_block_size = self.quant_config.weight_block_size if weight_block_size is not None: assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") dtype = torch.get_default_dtype() - w = block_dequant(self_attn.kv_b_proj.weight, - self_attn.kv_b_proj.weight_scale_inv, - weight_block_size).to(dtype) + w = block_dequant( + self_attn.kv_b_proj.weight, + self_attn.kv_b_proj.weight_scale_inv, + weight_block_size, + ).to(dtype) else: w = self_attn.kv_b_proj.weight w_kc, w_vc = w.unflatten( - 0, - (-1, - self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split( - [self_attn.qk_nope_head_dim, self_attn.v_head_dim], - dim=1) - self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose( - 1, 2) + 0, (-1, 
self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) self_attn.w_vc = w_vc.contiguous().transpose(1, 2) if self.config.mla_scale_q_lora: self_attn.q_a_layernorm.weight.data *= ( - self.config.hidden_size / self.config.q_lora_rank)**0.5 + self.config.hidden_size / self.config.q_lora_rank + ) ** 0.5 if self.config.mla_scale_kv_lora: self_attn.kv_a_layernorm.weight.data *= ( - self.config.hidden_size / - self.config.kv_lora_rank)**0.5 + self.config.hidden_size / self.config.kv_lora_rank + ) ** 0.5 return loaded_params diff --git a/vllm/model_executor/models/longcat_flash_mtp.py b/vllm/model_executor/models/longcat_flash_mtp.py index e288658a7ebf..55468f354c3a 100644 --- a/vllm/model_executor/models/longcat_flash_mtp.py +++ b/vllm/model_executor/models/longcat_flash_mtp.py @@ -15,10 +15,11 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.utils.int8_utils import ( - block_dequant) +from vllm.model_executor.layers.quantization.utils.int8_utils import block_dequant from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.longcat_flash import FlashConfig from vllm.sequence import IntermediateTensors @@ -29,7 +30,6 @@ class LongCatMultiTokenPredictorLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -40,14 +40,15 @@ def __init__( super().__init__() self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.eh_proj = ReplicatedLinear(2 * config.hidden_size, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix="eh_proj") + self.eh_proj = ReplicatedLinear( + 2 * config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix="eh_proj", + ) self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix) - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -62,39 +63,43 @@ def forward( previous_hidden_states = self.hnorm(previous_hidden_states) hidden_states, _ = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + torch.cat([inputs_embeds, previous_hidden_states], dim=-1) + ) - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states class LongCatMultiTokenPredictor(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() config = FlashConfig(**vllm_config.model_config.hf_config.__dict__) - vllm_config.model_config.hf_config.intermediate_size \ - = config.intermediate_size + 
vllm_config.model_config.hf_config.intermediate_size = config.intermediate_size self.mtp_start_layer_idx = config.num_hidden_layers * 2 self.num_mtp_layers = 1 - self.layers = torch.nn.ModuleDict({ - str(idx): - LongCatMultiTokenPredictorLayer( - config, - prefix=f"{prefix}.layers.{idx}", - vllm_config=vllm_config, - quant_config=quant_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.layers = torch.nn.ModuleDict( + { + str(idx): LongCatMultiTokenPredictorLayer( + config, + prefix=f"{prefix}.layers.{idx}", + vllm_config=vllm_config, + quant_config=quant_config, + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -110,7 +115,7 @@ def forward( ) -> torch.Tensor: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = (spec_step_idx % self.num_mtp_layers) + current_step_idx = spec_step_idx % self.num_mtp_layers return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( input_ids, positions, @@ -121,21 +126,22 @@ def forward( class LongCatFlashMTP(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() # LongCat MTP without MoE layers vllm_config.model_config.hf_config.n_routed_experts = None - self.config = FlashConfig( - **vllm_config.model_config.hf_config.__dict__) - self.quant_config = None if "mtp" in getattr( - self.config, "disable_quant_module", - []) else vllm_config.quant_config + self.config = FlashConfig(**vllm_config.model_config.hf_config.__dict__) + self.quant_config = ( + None + if "mtp" in getattr(self.config, "disable_quant_module", []) + else vllm_config.quant_config + ) - self.model = LongCatMultiTokenPredictor(vllm_config=vllm_config, - quant_config=self.quant_config, - prefix=maybe_prefix( - prefix, "model")) + self.model = LongCatMultiTokenPredictor( + vllm_config=vllm_config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "model"), + ) self.lm_head = ParallelLMHead( self.config.vocab_size, self.config.hidden_size, @@ -153,8 +159,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( @@ -162,12 +169,10 @@ def compute_logits( hidden_states: torch.Tensor, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), @@ -176,53 +181,31 @@ def load_weights(self, weights: Iterable[tuple[str, ] new_to_old_names_mapping = { - "model.mtp.embed_tokens.weight": - "model.layers.0.embed_tokens.weight", + "model.mtp.embed_tokens.weight": "model.layers.0.embed_tokens.weight", "model.mtp.layers.0.eh_proj.weight": "eh_proj.weight", - "model.mtp.layers.0.eh_proj.weight_scale_inv": - "eh_proj.weight_scale_inv", + "model.mtp.layers.0.eh_proj.weight_scale_inv": "eh_proj.weight_scale_inv", "model.mtp.layers.0.enorm.m.weight": "enorm.weight", 
"model.mtp.layers.0.hnorm.m.weight": "hnorm.weight", - "model.mtp.layers.0.input_layernorm.weight": - "model.layers.0.input_layernorm.weight", - "model.mtp.layers.0.post_attention_layernorm.weight": - "model.layers.0.post_attention_layernorm.weight", - "model.mtp.layers.0.self_attn.kv_a_layernorm.weight": - "model.layers.0.self_attn.kv_a_layernorm.weight", - "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight": - "model.layers.0.self_attn.kv_a_proj_with_mqa.weight", - "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv": - "model.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv", - "model.mtp.layers.0.self_attn.kv_b_proj.weight": - "model.layers.0.self_attn.kv_b_proj.weight", - "model.mtp.layers.0.self_attn.kv_b_proj.weight_scale_inv": - "model.layers.0.self_attn.kv_b_proj.weight_scale_inv", - "model.mtp.layers.0.self_attn.o_proj.weight": - "model.layers.0.self_attn.o_proj.weight", - "model.mtp.layers.0.self_attn.o_proj.weight_scale_inv": - "model.layers.0.self_attn.o_proj.weight_scale_inv", - "model.mtp.layers.0.self_attn.q_a_layernorm.weight": - "model.layers.0.self_attn.q_a_layernorm.weight", - "model.mtp.layers.0.self_attn.q_a_proj.weight": - "model.layers.0.self_attn.q_a_proj.weight", - "model.mtp.layers.0.self_attn.q_a_proj.weight_scale_inv": - "model.layers.0.self_attn.q_a_proj.weight_scale_inv", - "model.mtp.layers.0.self_attn.q_b_proj.weight": - "model.layers.0.self_attn.q_b_proj.weight", - "model.mtp.layers.0.self_attn.q_b_proj.weight_scale_inv": - "model.layers.0.self_attn.q_b_proj.weight_scale_inv", - "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight": - "model.layers.0.mlp.down_proj.weight", - "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight_scale_inv": - "model.layers.0.mlp.down_proj.weight_scale_inv", - "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight": - "model.layers.0.mlp.gate_proj.weight", - "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight_scale_inv": - "model.layers.0.mlp.gate_proj.weight_scale_inv", - "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight": - "model.layers.0.mlp.up_proj.weight", - "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight_scale_inv": - "model.layers.0.mlp.up_proj.weight_scale_inv", + "model.mtp.layers.0.input_layernorm.weight": "model.layers.0.input_layernorm.weight", # noqa: E501 + "model.mtp.layers.0.post_attention_layernorm.weight": "model.layers.0.post_attention_layernorm.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_a_layernorm.weight": "model.layers.0.self_attn.kv_a_layernorm.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight": "model.layers.0.self_attn.kv_a_proj_with_mqa.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv": "model.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_b_proj.weight": "model.layers.0.self_attn.kv_b_proj.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.kv_b_proj.weight_scale_inv": "model.layers.0.self_attn.kv_b_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.self_attn.o_proj.weight": "model.layers.0.self_attn.o_proj.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.o_proj.weight_scale_inv": "model.layers.0.self_attn.o_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.self_attn.q_a_layernorm.weight": "model.layers.0.self_attn.q_a_layernorm.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.q_a_proj.weight": "model.layers.0.self_attn.q_a_proj.weight", # noqa: E501 + 
"model.mtp.layers.0.self_attn.q_a_proj.weight_scale_inv": "model.layers.0.self_attn.q_a_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.self_attn.q_b_proj.weight": "model.layers.0.self_attn.q_b_proj.weight", # noqa: E501 + "model.mtp.layers.0.self_attn.q_b_proj.weight_scale_inv": "model.layers.0.self_attn.q_b_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight": "model.layers.0.mlp.down_proj.weight", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight_scale_inv": "model.layers.0.mlp.down_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight": "model.layers.0.mlp.gate_proj.weight", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight_scale_inv": "model.layers.0.mlp.gate_proj.weight_scale_inv", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight": "model.layers.0.mlp.up_proj.weight", # noqa: E501 + "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight_scale_inv": "model.layers.0.mlp.up_proj.weight_scale_inv", # noqa: E501 "model.mtp.norm.weight": "final_layernorm.weight", } @@ -231,13 +214,13 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - spec_layer = self.get_spec_layer_idx_from_weight_name( - self.config, name) + spec_layer = self.get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is None: continue - name = self._rewrite_spec_layer_name(spec_layer, name, - new_to_old_names_mapping) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + name = self._rewrite_spec_layer_name( + spec_layer, name, new_to_old_names_mapping + ) + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -247,14 +230,13 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # QKV fusion is optional, fall back to normal # weight loading if it's not enabled - if ((param_name == "fused_qkv_a_proj") - and name not in params_dict): + if (param_name == "fused_qkv_a_proj") and name not in params_dict: continue # Skip loading extra bias for GPTQ models. @@ -272,48 +254,54 @@ def load_weights(self, weights: Iterable[tuple[str, # According to DeepSeek-V3 Technical Report, MTP modules # shares embedding layer. We only load the first weights. 
- if (spec_layer != self.model.mtp_start_layer_idx - and ".layers" not in name): + if ( + spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name + ): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) spec_layer_id = self.config.num_hidden_layers * 2 self_attn = self.model.layers[str(spec_layer_id)].mtp_block.self_attn if hasattr( - self.quant_config, - "weight_block_size") and self_attn.kv_b_proj.weight.dtype in ( - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ): + self.quant_config, "weight_block_size" + ) and self_attn.kv_b_proj.weight.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): weight_block_size = self.quant_config.weight_block_size if weight_block_size is not None: dtype = torch.get_default_dtype() - w = block_dequant(self_attn.kv_b_proj.weight, - self_attn.kv_b_proj.weight_scale_inv, - weight_block_size).to(dtype) + w = block_dequant( + self_attn.kv_b_proj.weight, + self_attn.kv_b_proj.weight_scale_inv, + weight_block_size, + ).to(dtype) else: w = self_attn.kv_b_proj.weight else: w = self_attn.kv_b_proj.weight w_kc, w_vc = w.unflatten( - 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)).split( - [self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) self_attn.w_vc = w_vc.contiguous().transpose(1, 2) if self.config.mla_scale_q_lora: self_attn.q_a_layernorm.weight.data *= ( - self.config.hidden_size / self.config.q_lora_rank)**0.5 + self.config.hidden_size / self.config.q_lora_rank + ) ** 0.5 if self.config.mla_scale_kv_lora: self_attn.kv_a_layernorm.weight.data *= ( - self.config.hidden_size / self.config.kv_lora_rank)**0.5 + self.config.hidden_size / self.config.kv_lora_rank + ) ** 0.5 return loaded_params - def _rewrite_spec_layer_name(self, spec_layer: int, name: str, - new_to_old_names_mapping: dict) -> str: + def _rewrite_spec_layer_name( + self, spec_layer: int, name: str, new_to_old_names_mapping: dict + ) -> str: """ Rewrite the weight name to match the format of the original model. Add .mtp_block for modules in transformer layer block for spec layer @@ -322,11 +310,18 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str, if name in new_to_old_names_mapping: name = new_to_old_names_mapping[name] spec_layer_weight_names = [ - "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", ] - if name.startswith("enorm") or name.startswith( - "hnorm") or name.startswith("eh_proj") or name.startswith( - "final_layernorm"): + if ( + name.startswith("enorm") + or name.startswith("hnorm") + or name.startswith("eh_proj") + or name.startswith("final_layernorm") + ): name = "model.layers." + str(spec_layer) + "." + name shared_weight_names = ["embed_tokens"] spec_layer_weight = False @@ -339,15 +334,17 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str, break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace("model.layers.0.", - f"model.layers.{spec_layer}.mtp_block.") + name = name.replace( + "model.layers.0.", f"model.layers.{spec_layer}.mtp_block." 
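The `kv_b_proj` handling a few lines up (and its twin in longcat_flash.py) dequantizes the weight if it is stored as block-FP8 and then slices it into per-head `w_kc` / `w_vc` pieces for MLA. A worked example of just the unflatten/split step, with small made-up dimensions:

```python
import torch

# Assumed toy sizes: 4 kv heads, qk_nope_head_dim=3, v_head_dim=2, kv_lora_rank=5.
qk_nope_head_dim, v_head_dim, num_heads, kv_lora_rank = 3, 2, 4, 5
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Split the stacked kv_b_proj weight into per-head key ("nope") and value slices.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
print(w_kc.shape)  # torch.Size([4, 3, 5])  -> used as w_kc
print(w_vc.shape)  # torch.Size([4, 2, 5])  -> used as w_vc
```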
+ ) elif shared_weight: # treat shared weights as top level weights name = name.replace("model.layers.0.", "model.") return name - def get_spec_layer_idx_from_weight_name(self, config: PretrainedConfig, - weight_name: str) -> Optional[int]: + def get_spec_layer_idx_from_weight_name( + self, config: PretrainedConfig, weight_name: str + ) -> Optional[int]: if "model.mtp" in weight_name: return config.num_hidden_layers * 2 return None diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index d810701c50b4..fa11f92cce33 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """PyTorch MAMBA model.""" + from collections.abc import Iterable from typing import Optional @@ -15,51 +16,66 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree, SupportsPP) +from vllm.model_executor.models.interfaces import ( + HasInnerState, + IsAttentionFree, + SupportsPP, +) from vllm.sequence import IntermediateTensors -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) KVCache = tuple[torch.Tensor, torch.Tensor] class MambaDecoderLayer(nn.Module): - - def __init__(self, - config: MambaConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - is_lora_enabled: Optional[bool] = False, - prefix: str = "") -> None: + def __init__( + self, + config: MambaConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + is_lora_enabled: Optional[bool] = False, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.is_falcon_mamba = config.model_type == "falcon_mamba" self.is_lora_enabled = is_lora_enabled mixer_rms_eps = config.mixer_rms_eps if self.is_falcon_mamba else None - self.mixer = MambaMixer(hidden_size=config.hidden_size, - ssm_state_size=config.state_size, - conv_kernel_size=config.conv_kernel, - intermediate_size=config.intermediate_size, - time_step_rank=config.time_step_rank, - use_conv_bias=config.use_conv_bias, - use_bias=config.use_bias, - use_rms_norm=self.is_falcon_mamba, - rms_norm_has_weight=not self.is_falcon_mamba, - rms_norm_eps=mixer_rms_eps, - activation=config.hidden_act, - is_lora_enabled=self.is_lora_enabled, - model_config=model_config, - cache_config=cache_config, - prefix=f"{prefix}.mixer") + self.mixer = MambaMixer( + hidden_size=config.hidden_size, + 
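`_rewrite_spec_layer_name` in the hunk above routes checkpoint names onto the MTP module layout: enorm/hnorm/eh_proj/final_layernorm weights gain a `model.layers.{spec_layer}` prefix, shared embeddings move to the top level, and the remaining transformer weights land under `mtp_block`. A condensed toy of that routing, with a made-up spec-layer index:

```python
def rewrite(name: str, spec_layer: int) -> str:
    # Condensed sketch of the renaming rules; the real method also tracks
    # a shared_head entry and iterates explicit weight-name lists.
    if name.startswith(("enorm", "hnorm", "eh_proj", "final_layernorm")):
        return f"model.layers.{spec_layer}.{name}"
    if "embed_tokens" in name:  # shared with the base model
        return name.replace("model.layers.0.", "model.")
    return name.replace("model.layers.0.", f"model.layers.{spec_layer}.mtp_block.")


print(rewrite("enorm.weight", 56))
# model.layers.56.enorm.weight
print(rewrite("model.layers.0.self_attn.o_proj.weight", 56))
# model.layers.56.mtp_block.self_attn.o_proj.weight
```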
ssm_state_size=config.state_size, + conv_kernel_size=config.conv_kernel, + intermediate_size=config.intermediate_size, + time_step_rank=config.time_step_rank, + use_conv_bias=config.use_conv_bias, + use_bias=config.use_bias, + use_rms_norm=self.is_falcon_mamba, + rms_norm_has_weight=not self.is_falcon_mamba, + rms_norm_eps=mixer_rms_eps, + activation=config.hidden_act, + is_lora_enabled=self.is_lora_enabled, + model_config=model_config, + cache_config=cache_config, + prefix=f"{prefix}.mixer", + ) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -82,7 +98,6 @@ def forward( @support_torch_compile class MambaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -94,8 +109,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): is_lora_enabled = bool(lora_config) self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -107,19 +125,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MambaDecoderLayer(config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - is_lora_enabled=is_lora_enabled, - prefix=prefix), - prefix=f"{prefix}.layers") - - self.norm_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + lambda prefix: MambaDecoderLayer( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + is_lora_enabled=is_lora_enabled, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings(input_ids) @@ -144,20 +164,18 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions=positions, - hidden_states=hidden_states, - residual=residual) + hidden_states, residual = layer( + positions=positions, hidden_states=hidden_states, residual=residual + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -170,29 +188,29 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) 
loaded_params.add(name) return loaded_params class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config self.scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ + assert not cache_config.enable_prefix_caching, ( "Mamba does not support prefix caching" + ) super().__init__() self.config = config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - self.backbone = MambaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "backbone")) + self.backbone = MambaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -206,28 +224,33 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.backbone.make_empty_intermediate_tensors) + self.backbone.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - hidden_states = self.backbone(input_ids, positions, - intermediate_tensors, inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.backbone( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states @@ -236,7 +259,6 @@ def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba1_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -255,11 +277,11 @@ def get_mamba_state_shape_from_config( tp_world_size=parallel_config.tensor_parallel_size, intermediate_size=hf_config.intermediate_size, state_size=hf_config.state_size, - conv_kernel=hf_config.conv_kernel) + conv_kernel=hf_config.conv_kernel, + ) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) + return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) @@ -268,7 +290,6 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, 
weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index f8a5a8f6081b..4491648f3a0a 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """PyTorch MAMBA2 model.""" + from collections.abc import Iterable from typing import Optional @@ -15,49 +16,60 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (HasInnerState, - IsAttentionFree) +from vllm.model_executor.models.interfaces import HasInnerState, IsAttentionFree from vllm.sequence import IntermediateTensors -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) KVCache = tuple[torch.Tensor, torch.Tensor] class Mamba2DecoderLayer(nn.Module): - - def __init__(self, - config: MambaConfig, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: MambaConfig, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config - self.mixer = MambaMixer2(hidden_size=config.hidden_size, - ssm_state_size=config.state_size, - conv_kernel_size=config.conv_kernel, - intermediate_size=getattr( - config, "intermediate_size", - config.expand * config.hidden_size), - use_conv_bias=config.use_conv_bias, - use_bias=config.use_bias, - n_groups=config.n_groups, - num_heads=config.num_heads, - head_dim=config.head_dim, - rms_norm_eps=config.layer_norm_epsilon, - activation=config.hidden_act, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.mixer") + self.mixer = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.state_size, + conv_kernel_size=config.conv_kernel, + intermediate_size=getattr( + config, "intermediate_size", config.expand * config.hidden_size + ), + use_conv_bias=config.use_conv_bias, + use_bias=config.use_bias, + n_groups=config.n_groups, + num_heads=config.num_heads, + head_dim=config.head_dim, + rms_norm_eps=config.layer_norm_epsilon, + activation=config.hidden_act, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) 
@@ -80,7 +92,6 @@ def forward( @support_torch_compile class Mamba2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -93,8 +104,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert not is_lora_enabled self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -106,18 +120,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Mamba2DecoderLayer(config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") - - self.norm_f = RMSNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + lambda prefix: Mamba2DecoderLayer( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings(input_ids) @@ -141,22 +157,20 @@ def forward( residual = intermediate_tensors["residual"] for i, layer in enumerate(self.layers): - hidden_states, residual = layer(positions=positions, - hidden_states=hidden_states, - residual=residual) + hidden_states, residual = layer( + positions=positions, hidden_states=hidden_states, residual=residual + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -170,21 +184,18 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): - @classmethod def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -222,19 +233,17 @@ def get_mamba_state_shape_from_config( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = 
vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Mamba does not support prefix caching" super().__init__() self.config = config self.vllm_config = vllm_config self.scheduler_config = scheduler_config self.model_config = vllm_config.model_config - self.backbone = Mamba2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "backbone")) + self.backbone = Mamba2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -246,36 +255,40 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.backbone.embeddings) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.backbone.make_empty_intermediate_tensors) + self.backbone.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.backbone.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - hidden_states = self.backbone(input_ids, positions, - intermediate_tensors, inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.backbone( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) + return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) @@ -284,7 +297,6 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index f083c2cb0380..7e1d2bf14bb5 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -9,24 +9,28 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from .utils import maybe_prefix class ResidualBlock(nn.Module): - - def 
__init__(self, config: VllmConfig, hidden_size: int, - num_layers: int) -> None: + def __init__(self, config: VllmConfig, hidden_size: int, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList([ - nn.Linear(hidden_size, - hidden_size, - bias=getattr(config, "medusa_fc_bias", False)) - for _ in range(num_layers) - ]) + self.layers = nn.ModuleList( + [ + nn.Linear( + hidden_size, + hidden_size, + bias=getattr(config, "medusa_fc_bias", False), + ) + for _ in range(num_layers) + ] + ) self.act = nn.SiLU() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -38,13 +42,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Medusa(nn.Module): """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774 Reference implementation: https://github.com/FasterDecoding/Medusa - + Differences from reference implementation: 1. Currently this only supports generating proposals from top-1 tokens. - 2. We have an optional token_map which reduces draft vocab to most - frequently used tokens to give some additional speed-up by reducing - sampling overhead. This is disabled unless the checkpoint file has - explicit token_map tensor and config has an optional attribute + 2. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute truncated_vocab_size < vocab_size. To use this technique, one has to find the top-k most frequent tokens in target dataset and add that as a tensor in the draft checkpoint (using key token_map). Also, the draft config @@ -54,12 +58,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.speculative_config.draft_model_config.hf_config super().__init__() self.config = config - self.blocks = nn.ModuleList([ - ResidualBlock(config=config, - hidden_size=self.config.hidden_size, - num_layers=self.config.num_hidden_layers) - for _ in range(self.config.num_heads) - ]) + self.blocks = nn.ModuleList( + [ + ResidualBlock( + config=config, + hidden_size=self.config.hidden_size, + num_layers=self.config.num_hidden_layers, + ) + for _ in range(self.config.num_heads) + ] + ) self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size @@ -72,24 +80,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: padding_size=DEFAULT_VOCAB_PADDING_SIZE, prefix=maybe_prefix(prefix, "lm_head"), ) - self.lm_heads = [ - self.lm_head for _ in range(self.config.num_heads) - ] + self.lm_heads = [self.lm_head for _ in range(self.config.num_heads)] else: - self.lm_heads = nn.ModuleList([ - ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=self.truncated_vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - prefix=maybe_prefix(prefix, f"lm_heads.{i}"), - ) for i in range(self.config.num_heads) - ]) + self.lm_heads = nn.ModuleList( + [ + ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, f"lm_heads.{i}"), + ) + for i in range(self.config.num_heads) + ] + ) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.truncated_vocab_size, - logit_scale) + 
self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.truncated_vocab_size, logit_scale + ) # Token map is a idx to token mapping to reduce the vocab size for # the draft model. Using smaller vocab size for draft, containing @@ -120,17 +129,20 @@ def compute_logits( if self.token_map is None: logits_lst.append(_logits) else: - logits_lst.append(-torch.inf * torch.ones( - size=(*_logits.shape[:-1], self.orig_vocab_size), - device=_logits.device, - dtype=_logits.dtype)) + logits_lst.append( + -torch.inf + * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype, + ) + ) logits_lst[-1][..., self.token_map] = _logits return logits_lst - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -141,30 +153,33 @@ def load_weights(self, weights: Iterable[tuple[str, if name == "token_map": if self.truncated_vocab_size < self.orig_vocab_size: - self.token_map = nn.Parameter(loaded_weight, - requires_grad=False) + self.token_map = nn.Parameter(loaded_weight, requires_grad=False) elif name in params_dict: weights_map[name] = loaded_weight - elif (getattr(self.config, "original_lm_head", False) - and name == "lm_heads.0.weight"): + elif ( + getattr(self.config, "original_lm_head", False) + and name == "lm_heads.0.weight" + ): weights_map["lm_head.weight"] = loaded_weight for name, loaded_weight in weights_map.items(): - if "lm_head" in name and self.token_map is not None and\ - loaded_weight.shape[0] > self.token_map.shape[0]: - + if ( + "lm_head" in name + and self.token_map is not None + and loaded_weight.shape[0] > self.token_map.shape[0] + ): loaded_weight = loaded_weight[self.token_map] param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) if self.token_map is not None: self.token_map.to(device=self.lm_heads[0].weight.device) - assert (self.truncated_vocab_size - == self.orig_vocab_size) or (self.token_map is not None) + assert (self.truncated_vocab_size == self.orig_vocab_size) or ( + self.token_map is not None + ) return loaded_params diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 65b3ee1c0e18..47839a2c6b03 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -39,17 +39,26 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser -from vllm.multimodal.processing import (BaseMultiModalProcessor, - 
BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.midashenglm import DashengConfig @@ -63,7 +72,8 @@ def _resolve_tuple2(x: _Tuple2) -> tuple[int, int]: if isinstance(x, collections.abc.Sequence): assert len(x) == 2, ( - f"Expected a sequence of length 2, got {x} with length {len(x)}") + f"Expected a sequence of length 2, got {x} with length {len(x)}" + ) return cast(tuple[int, int], tuple(x)) return (x, x) @@ -80,12 +90,14 @@ def calculate_mel_frames_dasheng( if center: audio_length_samples = audio_length_samples + n_fft - return (int(1 + ((audio_length_samples - n_fft) / hop_size)) // - dasheng_subsampling // model_subsampling) + return ( + int(1 + ((audio_length_samples - n_fft) / hop_size)) + // dasheng_subsampling + // model_subsampling + ) class AudioPatchEmbed(nn.Module): - def __init__( self, input_size: _Tuple2 = 64, @@ -118,14 +130,14 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) if self.flatten: - x = torch.permute(torch.flatten( - x, 2, 3), (0, 2, 1)) # rearrange(x, "b c f t -> b (f t) c") + x = torch.permute( + torch.flatten(x, 2, 3), (0, 2, 1) + ) # rearrange(x, "b c f t -> b (f t) c") x = self.norm(x) return x class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace @@ -136,7 +148,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DashengMlp(nn.Module): - def __init__( self, in_features: int, @@ -170,7 +181,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DashengAttention(nn.Module): - def __init__( self, dim: int, @@ -237,7 +247,6 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None): class DashengBlock(nn.Module): - def __init__( self, dim: int, @@ -257,8 +266,9 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.attn", ) - self.ls1 = (LayerScale(dim, init_values=init_values) - if init_values else nn.Identity()) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) self.norm2 = nn.LayerNorm(dim, eps=1e-6) self.mlp = DashengMlp( @@ -267,8 +277,9 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.ls2 = (LayerScale(dim, init_values=init_values) - if init_values else nn.Identity()) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) # Kwargs usually has a mask parameter that is passed to Attention def forward( @@ -282,7 +293,6 @@ def forward( class DashengFrontend(nn.Module): - def __init__(self, config: DashengConfig): super().__init__() self.config = config @@ -302,9 +312,7 @@ def __init__(self, config: DashengConfig): n_mels=self.config.n_mels, sample_rate=self.config.sample_rate, ) - self.register_buffer("melscale_fbanks", - melscale_fbanks, - persistent=False) + self.register_buffer("melscale_fbanks", melscale_fbanks, persistent=False) self.melscale_fbanks: torch.Tensor def forward(self, waveform: torch.Tensor) -> torch.Tensor: @@ -319,8 +327,7 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: normalized=False, center=self.config.center, ) - mel_spectrogram = ( - spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT + mel_spectrogram = (spectrogram.mT @ 
self.melscale_fbanks.to(torch.float32)).mT # x has shape [batch, freq, time]. # F.amplitude_to_DB accepts inputs shaped as: # - [freq, time] @@ -339,7 +346,6 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: class DashengAudioTransformer(nn.Module): - def __init__( self, config: DashengConfig, @@ -365,9 +371,11 @@ def __init__( ) self.time_pos_embed = nn.Parameter( - torch.empty(1, config.embed_dim, 1, self.patch_embed.grid_size[1])) + torch.empty(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) + ) self.freq_pos_embed = nn.Parameter( - torch.empty(1, config.embed_dim, self.patch_embed.grid_size[0], 1)) + torch.empty(1, config.embed_dim, self.patch_embed.grid_size[0], 1) + ) self.blocks = nn.ModuleList( DashengBlock( dim=config.embed_dim, @@ -377,7 +385,9 @@ def __init__( init_values=config.init_values, quant_config=quant_config, prefix=f"{prefix}.blocks.{i}", - ) for i in range(config.depth)) + ) + for i in range(config.depth) + ) self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6) def forward_features( @@ -387,10 +397,12 @@ def forward_features( ) -> torch.Tensor: t = x.shape[-1] x = x + self.time_pos_embed[:, :, :, :t] - x = (x + self.freq_pos_embed[:, :, :, :] - ) # Just to support __getitem__ in posembed - x = torch.permute(torch.flatten(x, 2, 3), - (0, 2, 1)) # rearrange(x, "b c f t -> b (f t) c") + x = ( + x + self.freq_pos_embed[:, :, :, :] + ) # Just to support __getitem__ in posembed + x = torch.permute( + torch.flatten(x, 2, 3), (0, 2, 1) + ) # rearrange(x, "b c f t -> b (f t) c") for block in self.blocks: x = block(x, mask) x = self.norm(x) @@ -423,7 +435,8 @@ def forward( if x_length is not None: assert len(x_length) == len(x), ( - "batchsizes of input x and x_length need to be same") + "batchsizes of input x and x_length need to be same" + ) assert x_length.ndim == 1, "Lengths are of size (B,)" scaled_lengths = (x_length / (self.hop_length * 4)).long() mask = self._to_mask(max_length=t, lengths=scaled_lengths) @@ -444,7 +457,6 @@ def forward( class AudioProjectorSubsample(nn.Module): - def __init__( self, in_dim: int, @@ -483,13 +495,14 @@ def forward(self, x, mask=None): mask = mask[:, :-num_frames_to_discard] if mask is None: mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device) - x = x.reshape(batch_size, -1, self.k * - dim) # rearrange(x, "b (s k) d -> b s (k d)", k=self.k) + x = x.reshape( + batch_size, -1, self.k * dim + ) # rearrange(x, "b (s k) d -> b s (k d)", k=self.k) for layer in self.net: x = layer(x) mask = mask.reshape( - batch_size, -1, - self.k) # rearrange(mask, "b (s k) -> b s k", k=self.k) + batch_size, -1, self.k + ) # rearrange(mask, "b (s k) -> b s k", k=self.k) mask = mask.any(dim=-1).long() return x, mask @@ -503,7 +516,6 @@ class MiDashengLMAudioInputs(TypedDict): class MiDashengLMProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config() @@ -522,9 +534,7 @@ def get_max_audio_len(self): return 160000 -class MiDashengLMDummyInputsBuilder( - BaseDummyInputsBuilder[MiDashengLMProcessingInfo]): - +class MiDashengLMDummyInputsBuilder(BaseDummyInputsBuilder[MiDashengLMProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) @@ -547,16 +557,17 @@ def get_dummy_mm_data( audio_overrides = mm_options.get("audio") if mm_options else None return { - "audio": - self._get_dummy_audios(length=self.info.get_max_audio_len(), - num_audios=num_audios, - overrides=audio_overrides) + "audio": self._get_dummy_audios( + 
length=self.info.get_max_audio_len(), + num_audios=num_audios, + overrides=audio_overrides, + ) } class MiDashengLMMultiModalProcessor( - BaseMultiModalProcessor[MiDashengLMProcessingInfo]): - + BaseMultiModalProcessor[MiDashengLMProcessingInfo] +): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -578,8 +589,10 @@ def _call_hf_processor( (0, min_audio_len - audio.shape[-1]), mode="constant", constant_values=0, - ) if isinstance(audio, np.ndarray) - and audio.shape[-1] < min_audio_len else audio for audio in audios + ) + if isinstance(audio, np.ndarray) and audio.shape[-1] < min_audio_len + else audio + for audio in audios ] if processed_audios: @@ -590,7 +603,9 @@ def _call_hf_processor( prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - mm_kwargs = dict(**mm_kwargs, ) + mm_kwargs = dict( + **mm_kwargs, + ) return super()._call_hf_processor( prompt=prompt, @@ -627,11 +642,13 @@ def _get_prompt_updates( if audio_length is None: audio_output_lengths = [] else: - audio_length_np = (audio_length.cpu().numpy() if isinstance( - audio_length, torch.Tensor) else audio_length) + audio_length_np = ( + audio_length.cpu().numpy() + if isinstance(audio_length, torch.Tensor) + else audio_length + ) audio_output_lengths = [ - max(1, calculate_mel_frames_dasheng( - int(length))) # at least one frame + max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame for length in audio_length_np ] @@ -708,22 +725,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config self.make_empty_intermediate_tensors = ( - self.decoder.make_empty_intermediate_tensors) + self.decoder.make_empty_intermediate_tensors + ) - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of {name}. Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): return mm_input.reshape(-1, *mm_input.shape[2:]) if name == "input_values": max_length = max(tensor.shape[1] for tensor in mm_input) padded_mm_input = [ - torch.nn.functional.pad(tensor, - (0, max_length - tensor.shape[1])) - if tensor.shape[1] < max_length else tensor + torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1])) + if tensor.shape[1] < max_length + else tensor for tensor in mm_input ] return torch.concat(padded_mm_input) @@ -731,65 +749,67 @@ def _validate_and_reshape_mm_tensor(self, mm_input: object, return torch.concat(mm_input) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[MiDashengLMAudioInputs]: + self, **kwargs: object + ) -> Optional[MiDashengLMAudioInputs]: input_values = kwargs.pop("input_values", None) audio_length = kwargs.pop("audio_length", None) if input_values is None: return None input_values = self._validate_and_reshape_mm_tensor( - input_values, "input_values") + input_values, "input_values" + ) audio_length = self._validate_and_reshape_mm_tensor( - audio_length, "audio_length") + audio_length, "audio_length" + ) if not isinstance(input_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. 
" - f"Got type: {type(input_values)}") + raise ValueError( + "Incorrect type of audio input features. " + f"Got type: {type(input_values)}" + ) return MiDashengLMAudioInputs( input_values=input_values, audio_length=audio_length, ) - def _process_audio_input( - self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor: + def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor: # Process audio through encoder and projector input_values = audio_input["input_values"] audio_length = audio_input["audio_length"] - encoder_out, encoder_atts = self.audio_encoder(input_values, - audio_length) + encoder_out, encoder_atts = self.audio_encoder(input_values, audio_length) audio_embeddings, _ = self.audio_projector(encoder_out, encoder_atts) - audio_embeddings = audio_embeddings.to( - audio_input["input_values"].dtype) + audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype) batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape - audio_length_np = (audio_length.cpu().numpy() if isinstance( - audio_length, torch.Tensor) else audio_length) + audio_length_np = ( + audio_length.cpu().numpy() + if isinstance(audio_length, torch.Tensor) + else audio_length + ) audio_output_lengths = [ - max(1, calculate_mel_frames_dasheng( - int(length))) # at least one frame + max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame for length in audio_length_np ] audio_output_lengths = torch.tensor(audio_output_lengths).to( - audio_embeddings.device) + audio_embeddings.device + ) audio_feature_mask = torch.arange( - max_audio_tokens, - device=audio_embeddings.device).unsqueeze(0).expand( - batch_size, - max_audio_tokens) < audio_output_lengths.unsqueeze(1) + max_audio_tokens, device=audio_embeddings.device + ).unsqueeze(0).expand( + batch_size, max_audio_tokens + ) < audio_output_lengths.unsqueeze(1) - masked_audio_features = audio_embeddings[audio_feature_mask].view( - -1, embed_dim) + masked_audio_features = audio_embeddings[audio_feature_mask].view(-1, embed_dim) - return torch.split(masked_audio_features, - audio_output_lengths.tolist()) + return torch.split(masked_audio_features, audio_output_lengths.tolist()) def get_language_model(self) -> torch.nn.Module: return self.decoder - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: @@ -828,7 +848,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.decoder.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index d256c1f3eed7..e01e06421842 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -25,6 +25,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only MiMo model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -39,7 +40,9 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model from vllm.sequence import IntermediateTensors @@ -54,9 +57,9 @@ "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class MiMoModel(Qwen2Model): - def forward( self, input_ids: torch.Tensor, @@ -81,15 +84,13 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states = hidden_states + residual return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -104,18 +105,19 @@ def load_weights(self, weights: Iterable[tuple[str, continue if "rotary_emb.inv_freq" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -139,15 +141,13 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) config = vllm_config.model_config.hf_config @@ -159,25 +159,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config - self.model = MiMoModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MiMoModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( 
+ config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def compute_logits( self, diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index 9c1e36094c4a..b678a06b7f20 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiMo-MTP model.""" + from collections.abc import Iterable from typing import Optional @@ -31,7 +32,9 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer from vllm.sequence import IntermediateTensors @@ -40,7 +43,6 @@ class MiMoMultiTokenPredictorLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -51,19 +53,18 @@ def __init__( ) -> None: super().__init__() - self.token_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.hidden_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.input_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) - self.mtp_block = Qwen2DecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.token_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hidden_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_proj = nn.Linear( + config.hidden_size * 2, config.hidden_size, bias=False + ) + self.mtp_block = Qwen2DecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -79,17 +80,17 @@ def forward( previous_hidden_states = self.hidden_layernorm(previous_hidden_states) hidden_states = self.input_proj( - torch.cat([previous_hidden_states, inputs_embeds], dim=-1)) + torch.cat([previous_hidden_states, inputs_embeds], dim=-1) + ) - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) + hidden_states, residual = self.mtp_block( + positions=positions, hidden_states=hidden_states, residual=None + ) hidden_states = residual + hidden_states return self.final_layernorm(hidden_states) class MiMoMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -102,18 +103,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size, ) - self.mtp_layers = torch.nn.ModuleDict({ - str(idx): - MiMoMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) - for idx in range(self.mtp_start_layer_idx, 
- self.mtp_start_layer_idx + self.num_mtp_layers) - }) + self.mtp_layers = torch.nn.ModuleDict( + { + str(idx): MiMoMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range( + self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers, + ) + } + ) self.logits_processor = LogitsProcessor(config.vocab_size) @@ -128,7 +132,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, spec_step_idx: int = 0, ) -> torch.Tensor: - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) return self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)]( @@ -150,16 +153,17 @@ def compute_logits( class MiMoMTP(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config - self.model = MiMoMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head")) + self.model = MiMoMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -174,8 +178,9 @@ def forward( spec_step_idx: int = 0, ) -> torch.Tensor: assert spec_step_idx == 0, "mimo_mtp only support predict one token now" - hidden_states = self.model(input_ids, positions, hidden_states, - inputs_embeds, spec_step_idx) + hidden_states = self.model( + input_ids, positions, hidden_states, inputs_embeds, spec_step_idx + ) return hidden_states def compute_logits( @@ -183,11 +188,9 @@ def compute_logits( hidden_states: torch.Tensor, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.model.compute_logits(hidden_states, self.lm_head, - spec_step_idx) + return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -199,12 +202,11 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: continue name = self.map_model_name_to_mtp_param_name(name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -216,7 +218,7 @@ def load_weights(self, weights: Iterable[tuple[str, # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): + if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
@@ -231,12 +233,12 @@ def load_weights(self, weights: Iterable[tuple[str, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if "mtp_layers" not in name and ("embed_tokens" not in name - and "lm_head" not in name): + if "mtp_layers" not in name and ( + "embed_tokens" not in name and "lm_head" not in name + ): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -253,8 +255,10 @@ def map_model_name_to_mtp_param_name(self, name: str) -> str: name = name.replace(match.group(), f"{match.group(1)}{new_num}.") # check for early turn name_without_prefix = [ - "token_layernorm", "hidden_layernorm", "input_proj", - "final_layernorm" + "token_layernorm", + "hidden_layernorm", + "input_proj", + "final_layernorm", ] for sub_name in name_without_prefix: if sub_name in name: @@ -272,7 +276,11 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: Add .mtp_block for modules in transformer layer block for spec layer """ spec_layer_weight_names = [ - "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + "embed_tokens", + "enorm", + "hnorm", + "eh_proj", + "shared_head", ] spec_layer_weight = False for weight_name in spec_layer_weight_names: @@ -281,6 +289,7 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: break if not spec_layer_weight: # treat rest weights as weights for transformer layer block - name = name.replace(f"model.layers.{spec_layer}.", - f"model.layers.{spec_layer}.mtp_block.") + name = name.replace( + f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block." + ) return name diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 55fe3e2ae3ae..06cb6bc61576 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable from itertools import islice @@ -35,30 +36,42 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class MiniCPMMoE(nn.Module): @@ -90,34 +103,53 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - self.gate = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype, - quant_config=None) + self.gate = ReplicatedLinear( + self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None, + ) self.ws = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device=current_platform.device_type, - dtype=self.params_dtype)) + torch.empty( + self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) self.w2s = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device=current_platform.device_type, - dtype=self.params_dtype)) - - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): + torch.empty( + self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device=current_platform.device_type, + dtype=self.params_dtype, + ) + ) + + set_weight_attrs( + self.ws, + { + "weight_loader": 
self.weight_loader, + }, + ) + set_weight_attrs( + self.w2s, + { + "weight_loader": self.weight_loader, + }, + ) + + def weight_loader( + self, + param: nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + expert_id: int, + ): tp_rank = get_tensor_model_parallel_rank() param_data = param.data shard_size = self.intermediate_size @@ -125,8 +157,9 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, if weight_name.endswith("w1.weight"): param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] @@ -136,27 +169,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - topk_weights, topk_ids, _ = fused_topk(hidden_states, - router_logits, - self.top_k, - renormalize=True) + topk_weights, topk_ids, _ = fused_topk( + hidden_states, router_logits, self.top_k, renormalize=True + ) - final_hidden_states = fused_experts(hidden_states, - self.ws, - self.w2s, - topk_weights, - topk_ids, - inplace=True) + final_hidden_states = fused_experts( + hidden_states, self.ws, self.w2s, topk_weights, topk_ids, inplace=True + ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) class MiniCPMMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -167,20 +194,20 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if hidden_act == "silu": self.act_fn = SiluAndMul() elif hidden_act == "fatrelu": self.act_fn = FatreluAndMul(threshold=hidden_act_param) else: - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu and fatrelu are supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu and fatrelu are supported for now." 
+ ) def forward(self, x): gate_up, _ = self.gate_up_proj(x) @@ -190,7 +217,6 @@ def forward(self, x): class MiniCPMAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -249,13 +275,15 @@ def __init__( rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -274,7 +302,6 @@ def forward( class MiniCPMDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -289,15 +316,15 @@ def __init__( self.hidden_size = config.hidden_size self.rope_theta = getattr(config, "rope_theta", 10000) self.rope_scaling = getattr(config, "rope_scaling", None) - self.max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) + self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.prefix = prefix self._init_attn_block() self._init_ffn_block() def _init_attn_block(self): - self.input_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.input_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.self_attn = MiniCPMAttention( hidden_size=self.hidden_size, num_heads=self.config.num_attention_heads, @@ -311,15 +338,16 @@ def _init_attn_block(self): ) def _init_ffn_block(self): - self.post_attention_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: self.mlp = MiniCPMMLP( hidden_size=self.hidden_size, intermediate_size=self.config.intermediate_size, hidden_act=self.config.hidden_act, - hidden_act_param=getattr(self.config, "hidden_act_param", 0.), + hidden_act_param=getattr(self.config, "hidden_act_param", 0.0), quant_config=self.quant_config, ) else: @@ -327,7 +355,8 @@ def _init_ffn_block(self): num_experts=self.config.num_experts, top_k=self.config.num_experts_per_tok, hidden_size=self.config.hidden_size, - intermediate_size=self.config.intermediate_size) + intermediate_size=self.config.intermediate_size, + ) def forward( self, @@ -342,22 +371,23 @@ def forward( positions=positions, hidden_states=hidden_states, ) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) + ) # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) + ) return hidden_states, None @support_torch_compile class MiniCPMModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -369,8 +399,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = 
(lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -384,9 +417,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.aux_hidden_state_layers = tuple[int, ...]() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], self.config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size + ) def _init_layers( self, @@ -398,8 +431,10 @@ def _init_layers( self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MiniCPMDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) @@ -411,8 +446,9 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, - list[torch.Tensor]]]: + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -425,11 +461,12 @@ def forward( aux_hidden_states = [] for idx, layer in enumerate( - islice(self.layers, self.start_layer, self.end_layer)): + islice(self.layers, self.start_layer, self.end_layer) + ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append( - hidden_states + - residual if residual is not None else hidden_states) + hidden_states + residual if residual is not None else hidden_states + ) hidden_states, residual = layer( positions, hidden_states, @@ -437,10 +474,9 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states = self.norm(hidden_states) @@ -448,8 +484,7 @@ def forward( return hidden_states, aux_hidden_states return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -460,8 +495,11 @@ def load_weights(self, weights: Iterable[tuple[str, ] expert_params_mapping = [ # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + ( + "ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) for expert_id in range(self.num_experts) for weight_name in ["w1", "w2", "w3"] ] @@ -471,12 +509,11 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" 
in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -498,10 +535,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + weight_loader( + param, loaded_weight, weight_name, expert_id=expert_id + ) break else: # Skip loading extra bias for GPTQ models. @@ -510,8 +546,9 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -551,8 +588,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.cache_config = cache_config self.quant_config = quant_config - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = self._init_model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) unpadded_vocab_size = config.vocab_size if lora_config: @@ -564,7 +602,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -572,10 +611,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.scale_width = self.config.hidden_size / self.config.dim_model_base - self.logits_processor = LogitsProcessor(unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) @@ -596,10 +635,12 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, - list[torch.Tensor]]]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) if isinstance(model_output, tuple) and len(model_output) == 2: # Aux hidden states are present. 
@@ -621,11 +662,9 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 92c13e81bf3e..35f02a1538e8 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -24,6 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only MiniCPM3 model compatible with HuggingFace weights.""" + from typing import Any, Optional import torch @@ -34,20 +35,23 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, - MiniCPMForCausalLM, - MiniCPMModel) +from vllm.model_executor.models.minicpm import ( + MiniCPMDecoderLayer, + MiniCPMForCausalLM, + MiniCPMModel, +) from .utils import make_layers class MiniCPM3Attention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -83,33 +87,37 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config) + self.q_a_proj = ReplicatedLinear( + self.hidden_size, self.q_lora_rank, bias=False, quant_config=quant_config + ) self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config) - - self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size, - self.kv_lora_rank + - self.qk_rope_head_dim, - bias=False, - quant_config=quant_config) - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False, - quant_config=quant_config) + quant_config=quant_config, + ) # O projection. 
- self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) self.rotary_emb = get_rope( self.qk_rope_head_dim, @@ -118,13 +126,15 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_local_heads, - self.qk_head_dim, - self.scaling, - num_kv_heads=self.num_local_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -135,55 +145,52 @@ def forward( q = self.q_a_layernorm(q) q, _ = self.q_b_proj(q) q = q.view(-1, self.num_local_heads, self.qk_head_dim) - _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states) - kv_a, _ = latent_cache.split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) kv_a = self.kv_a_layernorm(kv_a.contiguous()) kv, _ = self.kv_b_proj(kv_a) - kv = kv.view(-1, self.num_local_heads, - self.qk_nope_head_dim + self.v_head_dim) + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank:] + k_pe = latent_cache[:, :, self.kv_lora_rank :] q_pe, k_pe = self.rotary_emb( positions, q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim), - k_pe.reshape(-1, self.qk_rope_head_dim)) + k_pe.reshape(-1, self.qk_rope_head_dim), + ) q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim) k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim) - q[..., self.qk_nope_head_dim:] = q_pe + q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) - k[..., :self.qk_nope_head_dim] = k_nope - k[..., self.qk_nope_head_dim:] = k_pe + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe q = q.reshape(-1, self.num_local_heads * self.qk_head_dim) k = k.view(-1, self.num_local_heads * self.qk_head_dim) v = torch.nn.functional.pad( - v, [0, self.qk_head_dim - self.v_head_dim], - value=0).view(-1, self.num_local_heads * self.qk_head_dim) + v, [0, self.qk_head_dim - self.v_head_dim], value=0 + ).view(-1, self.num_local_heads * self.qk_head_dim) attn_output = self.attn(q, k, v) - attn_output = attn_output.view( - -1, self.num_local_heads, - self.qk_head_dim)[..., :self.v_head_dim].reshape( - -1, self.num_local_heads * self.v_head_dim) + attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) output, _ = self.o_proj(attn_output) return output class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): - def _init_attn_block(self): - self.input_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.input_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.self_attn = MiniCPM3Attention( config=self.config, hidden_size=self.hidden_size, @@ -203,7 +210,6 @@ def _init_attn_block(self): class 
MiniCPM3Model(MiniCPMModel): - def _init_layers( self, prefix: str, @@ -214,8 +220,10 @@ def _init_layers( self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MiniCPM3DecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) class MiniCPM3ForCausalLM(MiniCPMForCausalLM): diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 2af0d546ce63..6c635b248109 100644 --- a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only EagleMiniCPM model compatible with HuggingFace weights.""" + import math from collections.abc import Iterable from typing import Optional, Union @@ -37,7 +38,10 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors @@ -45,12 +49,15 @@ from .minicpm import MiniCPMAttention as EagleMiniCPMAttention from .minicpm import MiniCPMMLP as EagleMiniCPMMLP from .minicpm import MiniCPMMoE as EagleMiniCPMMoE -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) class EagleMiniCPMDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -65,15 +72,15 @@ def __init__( self.hidden_size = config.hidden_size self.rope_theta = getattr(config, "rope_theta", 10000) self.rope_scaling = getattr(config, "rope_scaling", None) - self.max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) + self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.prefix = prefix self._init_attn_block() self._init_ffn_block() def _init_attn_block(self): - self.input_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.input_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.self_attn = EagleMiniCPMAttention( hidden_size=self.hidden_size, num_heads=self.config.num_attention_heads, @@ -87,15 +94,16 @@ def _init_attn_block(self): ) def _init_ffn_block(self): - self.post_attention_layernorm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + self.config.hidden_size, eps=self.config.rms_norm_eps + ) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: self.mlp = EagleMiniCPMMLP( hidden_size=self.hidden_size, intermediate_size=self.config.intermediate_size, hidden_act=self.config.hidden_act, - hidden_act_param=getattr(self.config, "hidden_act_param", 0.), + hidden_act_param=getattr(self.config, "hidden_act_param", 0.0), quant_config=self.quant_config, ) else: @@ -103,7 +111,8 @@ def _init_ffn_block(self): num_experts=self.config.num_experts, top_k=self.config.num_experts_per_tok, 
hidden_size=self.config.hidden_size, - intermediate_size=self.config.intermediate_size) + intermediate_size=self.config.intermediate_size, + ) def forward( self, @@ -118,27 +127,26 @@ def forward( positions=positions, hidden_states=hidden_states, ) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.mup_denominator)) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.mup_denominator) + ) # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * \ - (self.config.scale_depth / math.sqrt(self.config.mup_denominator)) + hidden_states = residual + hidden_states * ( + self.config.scale_depth / math.sqrt(self.config.mup_denominator) + ) return hidden_states, None @support_torch_compile class EagleMiniCPMModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - start_layer: int = 0): + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", start_layer: int = 0 + ): super().__init__() config = vllm_config.speculative_config.draft_model_config.hf_config @@ -149,13 +157,16 @@ def __init__(self, self.config = config self.cache_config = cache_config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - self.fc = torch.nn.Linear(self.config.hidden_size * 2, - self.config.hidden_size, - bias=False) + self.fc = torch.nn.Linear( + self.config.hidden_size * 2, self.config.hidden_size, bias=False + ) self.input_norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.embed_tokens = VocabParallelEmbedding( @@ -164,12 +175,11 @@ def __init__(self, org_num_embeddings=config.vocab_size, ) self.num_experts = getattr(self.config, "num_experts", 0) - self._init_layers(prefix, config, cache_config, quant_config, - start_layer) + self._init_layers(prefix, config, cache_config, quant_config, start_layer) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], self.config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size + ) def _init_layers( self, @@ -179,14 +189,17 @@ def _init_layers( quant_config: Optional[QuantizationConfig], start_layer: int, ): - self.eagle_layers = nn.ModuleList([ - EagleMiniCPMDecoderLayer( - config, - cache_config, - quant_config, - f"{prefix}.eagle_layers.{i + start_layer}", - ) for i in range(self.config.num_hidden_layers) - ]) + self.eagle_layers = nn.ModuleList( + [ + EagleMiniCPMDecoderLayer( + config, + cache_config, + quant_config, + f"{prefix}.eagle_layers.{i + start_layer}", + ) + for i in range(self.config.num_hidden_layers) + ] + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) @@ -202,8 +215,7 @@ def forward( input_embeds = self.input_norm1(input_embeds) hidden_states = self.input_norm2(hidden_states) - hidden_states = self.fc( - torch.cat((input_embeds, 
hidden_states), dim=-1)) + hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1)) residual = None for layer in self.eagle_layers: hidden_states, residual = layer( @@ -214,8 +226,7 @@ def forward( return hidden_states, hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -226,8 +237,11 @@ def load_weights(self, weights: Iterable[tuple[str, ] expert_params_mapping = [ # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + ( + "ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) for expert_id in range(self.num_experts) for weight_name in ["w1", "w2", "w3"] ] @@ -237,12 +251,11 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -264,10 +277,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) + weight_loader( + param, loaded_weight, weight_name, expert_id=expert_id + ) break else: # Skip loading extra bias for GPTQ models. 
@@ -276,8 +288,9 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -319,11 +332,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config target_layer_num = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config) + vllm_config.parallel_config + ) - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model"), - start_layer=target_layer_num) + self.model = self._init_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + start_layer=target_layer_num, + ) unpadded_vocab_size = config.vocab_size if lora_config: @@ -335,7 +351,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -343,19 +360,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) self.scale_width = self.config.hidden_size / self.config.dim_model_base - self.logits_processor = LogitsProcessor(unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) - def _init_model(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - start_layer: int = 0): - return EagleMiniCPMModel(vllm_config=vllm_config, - prefix=prefix, - start_layer=start_layer) + def _init_model( + self, *, vllm_config: VllmConfig, prefix: str = "", start_layer: int = 0 + ): + return EagleMiniCPMModel( + vllm_config=vllm_config, prefix=prefix, start_layer=start_layer + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -366,8 +381,7 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - hidden_states, hidden_states2 = self.model(input_ids, positions, - hidden_states) + hidden_states, hidden_states2 = self.model(input_ids, positions, hidden_states) hidden_states = hidden_states / self.scale_width hidden_states2 = hidden_states2 / self.scale_width return hidden_states, hidden_states2 @@ -379,11 +393,9 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 74b2a2e62cd5..34f05122abe3 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -23,6 +23,7 @@ # See the 
License for the specific language governing permissions and # limitations under the License. """Inference-only MiniCPM-O model compatible with HuggingFace weights.""" + from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Any, Callable, Literal, Optional, Union @@ -30,31 +31,47 @@ from torch import nn from transformers import BatchFeature from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.whisper.modeling_whisper import (ACT2FN, - WhisperAttention, - WhisperConfig, - WhisperEncoder) +from transformers.models.whisper.modeling_whisper import ( + ACT2FN, + WhisperAttention, + WhisperConfig, + WhisperEncoder, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) -from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, - DictEmbeddingItems, ModalityData, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioItem, + AudioProcessorItems, + DictEmbeddingItems, + ModalityData, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, - MiniCPMVDummyInputsBuilder, - MiniCPMVMultiModalDataParser, - MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, - _minicpmv_field_config) -from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, - maybe_prefix) +from .minicpmv import ( + _MAX_FRAMES_PER_VIDEO, + MiniCPMV2_6, + MiniCPMVDummyInputsBuilder, + MiniCPMVMultiModalDataParser, + MiniCPMVMultiModalProcessor, + MiniCPMVProcessingInfo, + _minicpmv_field_config, +) +from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix CPU_DEVICE = torch.device("cpu") @@ -68,6 +85,7 @@ class MiniCPMOAudioFeatureInputs(TensorSchema): - l: Length - s: Number of slices """ + type: Literal["audio_features"] = "audio_features" audio_features: Annotated[ @@ -96,9 +114,10 @@ class MiniCPMOAudioEmbeddingInputs(TensorSchema): - bn: Batch size * number of audios - s: Number of slices - h: Hidden size (must match language model backbone) - + Length of each slice may vary, so pass it as a list. 
""" + type: Literal["audio_embeds"] = "audio_embeds" audio_embeds: Annotated[ @@ -107,8 +126,7 @@ class MiniCPMOAudioEmbeddingInputs(TensorSchema): ] -MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, - MiniCPMOAudioEmbeddingInputs] +MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, MiniCPMOAudioEmbeddingInputs] def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): @@ -125,7 +143,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems): - def __init__( self, data: Mapping[str, torch.Tensor], @@ -143,7 +160,6 @@ def __init__( class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): - def _parse_audio_data( self, data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], @@ -215,18 +231,17 @@ def get_num_frames_with_most_features( max_image_tokens = self.get_max_image_tokens() * max_images max_audio_tokens = self.get_max_audio_tokens() * max_audios - max_total_frames = self.get_max_video_frames(seq_len - - max_image_tokens - - max_audio_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_total_frames = self.get_max_video_frames( + seq_len - max_image_tokens - max_audio_tokens + ) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO + ) return max(max_frames_per_video, 1) -class MiniCPMODummyInputsBuilder( - MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): - +class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) @@ -241,16 +256,17 @@ def get_dummy_mm_data( mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) - audio_len = self.info.get_max_audio_chunks_with_most_features() * \ - self.info.get_default_audio_sampling_rate() + audio_len = ( + self.info.get_max_audio_chunks_with_most_features() + * self.info.get_default_audio_sampling_rate() + ) audio_overrides = mm_options.get("audio") if mm_options else None audio_mm_data = { - "audio": - self._get_dummy_audios(length=audio_len, - num_audios=num_audios, - overrides=audio_overrides) + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) } return { @@ -259,12 +275,11 @@ def get_dummy_mm_data( } -class MiniCPMOMultiModalProcessor( - MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]): - +class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return MiniCPMOMultiModalDataParser( - target_sr=self.info.get_default_audio_sampling_rate()) + target_sr=self.info.get_default_audio_sampling_rate() + ) def get_audio_prompt_texts( self, @@ -287,10 +302,11 @@ def process_audios( if (audios := mm_data.get("audios")) is None: return {} - parsed_audios = (self._get_data_parser().parse_mm_data({ - "audio": audios - }).get_items("audio", - (MiniCPMOAudioEmbeddingItems, AudioProcessorItems))) + parsed_audios = ( + self._get_data_parser() + .parse_mm_data({"audio": audios}) + .get_items("audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)) + ) if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems): audio_inputs = {} @@ -298,9 +314,7 @@ def process_audios( audio_inputs = self._base_call_hf_processor( prompts=[self.info.audio_pattern] * len(parsed_audios), mm_data={"audios": [[audio] for audio in parsed_audios]}, - 
mm_kwargs={ - **mm_kwargs, "chunk_input": True - }, + mm_kwargs={**mm_kwargs, "chunk_input": True}, tok_kwargs=tok_kwargs, out_keys={"audio_features", "audio_feature_lens"}, ) @@ -308,7 +322,8 @@ def process_audios( # Avoid padding since we need the output for each audio to be # independent of other audios for the cache to work correctly unpadded_audio_features = [ - feat[:, :feature_len] for feat, feature_len in zip( + feat[:, :feature_len] + for feat, feature_len in zip( audio_inputs["audio_features"], audio_inputs["audio_feature_lens"], ) @@ -348,12 +363,14 @@ def _get_prompt_updates( def get_audio_replacement(item_idx: int): audios = mm_items.get_items( - "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)) + "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems) + ) if isinstance(audios, MiniCPMOAudioEmbeddingItems): single_audio_embeds = audios.get(item_idx)["audio_embeds"] audio_len = self.info.get_audio_len_by_num_chunks( - sum(map(len, single_audio_embeds))) + sum(map(len, single_audio_embeds)) + ) else: audio_len = audios.get_audio_length(item_idx) @@ -364,9 +381,11 @@ def get_audio_replacement(item_idx: int): return [ *base_updates, - PromptReplacement(modality="audio", - target=audio_placeholder, - replacement=get_audio_replacement), + PromptReplacement( + modality="audio", + target=audio_placeholder, + replacement=get_audio_replacement, + ), ] def _get_mm_fields_config( @@ -378,16 +397,11 @@ def _get_mm_fields_config( class MultiModalProjector(nn.Module): - def __init__(self, in_dim: int, out_dim: int): super().__init__() - self.linear1 = nn.Linear(in_features=in_dim, - out_features=out_dim, - bias=True) + self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) self.relu = nn.ReLU() - self.linear2 = nn.Linear(in_features=out_dim, - out_features=out_dim, - bias=True) + self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) def forward(self, audio_features: torch.Tensor) -> torch.Tensor: hidden_states = self.relu(self.linear1(audio_features)) @@ -396,7 +410,6 @@ def forward(self, audio_features: torch.Tensor) -> torch.Tensor: class MiniCPMWhisperEncoderLayer(nn.Module): - def __init__(self, config: WhisperConfig, layer_idx: int): super().__init__() self.embed_dim = config.d_model @@ -428,39 +441,40 @@ def forward( attention_mask=attention_mask, past_key_value=past_key_values, ) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, - p=self.activation_dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16: hidden_states = cast_overflow_tensors(hidden_states) - outputs = (hidden_states, ) + outputs = (hidden_states,) return outputs class MiniCPMWhisperEncoder(WhisperEncoder): - def __init__(self, config: WhisperConfig): super().__init__(config) 
- self.layers = nn.ModuleList([ - MiniCPMWhisperEncoderLayer(config, layer_idx=i) - for i in range(config.encoder_layers) - ]) + self.layers = nn.ModuleList( + [ + MiniCPMWhisperEncoderLayer(config, layer_idx=i) + for i in range(config.encoder_layers) + ] + ) def forward( self, @@ -468,8 +482,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, ) -> BaseModelOutputWithPast: # Ignore copy - input_features = input_features.to(dtype=self.conv1.weight.dtype, - device=self.conv1.weight.device) + input_features = input_features.to( + dtype=self.conv1.weight.dtype, device=self.conv1.weight.device + ) inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) @@ -478,17 +493,17 @@ def forward( embed_pos = self.embed_positions.weight - embed_pos = embed_pos[:inputs_embeds.shape[1], :] + embed_pos = embed_pos[: inputs_embeds.shape[1], :] hidden_states = inputs_embeds + embed_pos - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) encoder_states = () for idx, encoder_layer in enumerate(self.layers): - encoder_states = encoder_states + (hidden_states, ) + encoder_states = encoder_states + (hidden_states,) to_drop = False if self.training: dropout_probability = torch.rand([]) @@ -507,7 +522,7 @@ def forward( hidden_states = layer_outputs[0] hidden_states = self.layer_norm(hidden_states) - encoder_states = encoder_states + (hidden_states, ) + encoder_states = encoder_states + (hidden_states,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, @@ -518,7 +533,8 @@ def forward( @MULTIMODAL_REGISTRY.register_processor( MiniCPMOMultiModalProcessor, info=MiniCPMOProcessingInfo, - dummy_inputs=MiniCPMODummyInputsBuilder) + dummy_inputs=MiniCPMODummyInputsBuilder, +) class MiniCPMO(MiniCPMV2_6): packed_modules_mapping = { "qkv_proj": [ @@ -545,8 +561,9 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) - self.apm = self.init_audio_module(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "apm")) + self.apm = self.init_audio_module( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") + ) self.audio_token_id = None @@ -555,16 +572,16 @@ def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): audio_config = self.config.audio_config model = MiniCPMWhisperEncoder(audio_config) audio_output_dim = int(audio_config.encoder_ffn_dim // 4) - self.audio_avg_pooler = \ - nn.AvgPool1d(self.config.audio_pool_step, - stride=self.config.audio_pool_step) - self.audio_projection_layer = \ - MultiModalProjector(in_dim=audio_output_dim,out_dim=self.embed_dim) + self.audio_avg_pooler = nn.AvgPool1d( + self.config.audio_pool_step, stride=self.config.audio_pool_step + ) + self.audio_projection_layer = MultiModalProjector( + in_dim=audio_output_dim, out_dim=self.embed_dim + ) self.audio_encoder_layer = -1 return model - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["tts"]) return loader.load_weights(weights) @@ -585,14 +602,13 @@ def subsequent_chunk_mask( start_indices = torch.zeros_like(row_indices) else: # Compute start indices vectorially - 
start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, - min=0) + start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, min=0) start_indices = start_chunk_indices * chunk_size # Compute ending indices vectorially end_chunk_indices = chunk_indices + 1 - end_indices = torch.clamp(end_chunk_indices * chunk_size + - num_lookhead, - max=size) + end_indices = torch.clamp( + end_chunk_indices * chunk_size + num_lookhead, max=size + ) # Create column indices for broadcasting col_indices = torch.arange(size, device=device).unsqueeze(0) start_indices = start_indices.unsqueeze(1) @@ -601,19 +617,18 @@ def subsequent_chunk_mask( ret = (col_indices >= start_indices) & (col_indices < end_indices) return ret - def _get_feat_extract_output_lengths(self, - input_lengths: torch.LongTensor): + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 input_lengths_after_pooling = ( - input_lengths_after_cnn - - self.config.audio_pool_step) // self.config.audio_pool_step + 1 - input_lengths_after_pooling = input_lengths_after_pooling.to( - dtype=torch.int32) + input_lengths_after_cnn - self.config.audio_pool_step + ) // self.config.audio_pool_step + 1 + input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) return input_lengths_after_cnn, input_lengths_after_pooling def get_audio_hidden_states( - self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]: + self, data: MiniCPMOAudioFeatureInputs + ) -> list[torch.Tensor]: chunk_length = self.config.audio_chunk_length # (bs, 80, frames) or [], multi audios need filled in advance @@ -642,23 +657,26 @@ def get_audio_hidden_states( max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) - seq_range = (torch.arange( - 0, - max_seq_len, - dtype=audio_feature_lens.dtype, - device=audio_feature_lens.device).unsqueeze(0).expand( - batch_size, max_seq_len)) - lengths_expand = audio_feature_lens.unsqueeze(1).expand( - batch_size, max_seq_len) + seq_range = ( + torch.arange( + 0, + max_seq_len, + dtype=audio_feature_lens.dtype, + device=audio_feature_lens.device, + ) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) + lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) # Create mask padding_mask = seq_range >= lengths_expand # 1 for padded values - audio_attention_mask_ = padding_mask.view( - batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len, - max_seq_len) + audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) audio_attention_mask = audio_attention_mask_.to( - dtype=self.apm.conv1.weight.dtype, - device=self.apm.conv1.weight.device) + dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device + ) if chunk_length > 0: chunk_num_frame = int(chunk_length * 50) @@ -669,20 +687,22 @@ def get_audio_hidden_states( device=audio_attention_mask_.device, ) audio_attention_mask_ = torch.logical_or( - audio_attention_mask_, torch.logical_not(chunk_mask)) + audio_attention_mask_, torch.logical_not(chunk_mask) + ) audio_attention_mask[audio_attention_mask_] = float("-inf") audio_states = self.apm( - wavforms, attention_mask=audio_attention_mask).hidden_states[ - self.audio_encoder_layer] + wavforms, attention_mask=audio_attention_mask + ).hidden_states[self.audio_encoder_layer] audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) 
audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) - _, feature_lens_after_pooling = \ - self._get_feat_extract_output_lengths(audio_feature_lens) + _, feature_lens_after_pooling = self._get_feat_extract_output_lengths( + audio_feature_lens + ) num_audio_tokens = feature_lens_after_pooling @@ -692,7 +712,8 @@ def get_audio_hidden_states( target_audio_embeds_lst = list[torch.Tensor]() for _ in range(len(audio_feature_lens_raw[i])): target_audio_embeds_lst.append( - audio_embeds[idx, :num_audio_tokens[idx], :]) + audio_embeds[idx, : num_audio_tokens[idx], :] + ) idx += 1 final_audio_embeds.append(torch.cat(target_audio_embeds_lst)) @@ -700,7 +721,8 @@ def get_audio_hidden_states( return final_audio_embeds def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]: + self, **kwargs: object + ) -> Optional[MiniCPMOAudioInputs]: audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) @@ -714,8 +736,9 @@ def _parse_and_validate_audio_input( if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_embeds. " - f"Got type: {type(audio_embeds)}") + raise ValueError( + f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}" + ) audio_embeds_flat = flatten_bn(audio_embeds) @@ -725,13 +748,16 @@ def _parse_and_validate_audio_input( ) if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_features. " - f"Got type: {type(audio_features)}") + raise ValueError( + f"Incorrect type of audio_features. Got type: {type(audio_features)}" + ) audio_feature_lens = kwargs.pop("audio_feature_lens") if not isinstance(audio_feature_lens, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_feature_lens. " - f"Got type: {type(audio_feature_lens)}") + raise ValueError( + "Incorrect type of audio_feature_lens. " + f"Got type: {type(audio_feature_lens)}" + ) audio_features_flat = flatten_bn(audio_features) audio_feature_lens_flat = flatten_bn(audio_feature_lens) @@ -748,10 +774,11 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("audio_features", - "audio_embeds") and "audios" not in modalities: - modalities["audios"] = self._parse_and_validate_audio_input( - **kwargs) + if ( + input_key in ("audio_features", "audio_embeds") + and "audios" not in modalities + ): + modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) return modalities diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 8bef1ec514ab..09f973e98db9 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" + import math from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence @@ -43,8 +44,11 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig -from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, - get_2d_sincos_pos_embed) +from vllm.model_executor.layers.resampler import ( + BaseResampler, + Resampler2, + get_2d_sincos_pos_embed, +) from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM @@ -52,17 +56,33 @@ from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, - ImageProcessorItems, ImageSize, - ModalityData, ModalityDataItems, - MultiModalDataItems, MultiModalDataParser, - VideoItem, VideoProcessorItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails, - ResolvedPromptUpdate, _seq2text) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ImageItem, + ImageProcessorItems, + ImageSize, + ModalityData, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, + VideoItem, + VideoProcessorItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, + ResolvedPromptUpdate, + _seq2text, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -70,8 +90,12 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix # For profile run @@ -121,45 +145,48 @@ class MiniCPMVImageEmbeddingInputs(TensorSchema): ] -MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, - MiniCPMVImageEmbeddingInputs] +MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs] DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) class Resampler2_5(BaseResampler): - - def __init__(self, - num_queries: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - max_size: tuple[int, int] = (70, 70), - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - super().__init__(num_queries, - embed_dim, - num_heads, - kv_dim, - norm_layer, - quant_config=quant_config, - prefix=prefix) + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: 
Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: tuple[int, int] = (70, 70), + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + quant_config=quant_config, + prefix=prefix, + ) self.max_size = max_size self._set_2d_pos_cache(self.max_size) - def _set_2d_pos_cache(self, - max_size: tuple[int, int], - device: torch.types.Device = "cpu") -> None: - pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, - max_size, - version=(2, 5)) + def _set_2d_pos_cache( + self, max_size: tuple[int, int], device: torch.types.Device = "cpu" + ) -> None: + pos_embed_arr = get_2d_sincos_pos_embed( + self.embed_dim, max_size, version=(2, 5) + ) pos_embed = torch.from_numpy(pos_embed_arr).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) - def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, - device: torch.types.Device) -> None: + def _adjust_pos_cache( + self, tgt_sizes: torch.Tensor, device: torch.types.Device + ) -> None: max_h = tgt_sizes[:, 0].max().item() max_w = tgt_sizes[:, 1].max().item() assert isinstance(max_h, int) and isinstance(max_w, int) @@ -171,8 +198,7 @@ def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, ) self._set_2d_pos_cache(self.max_size, device) - def forward(self, x: torch.Tensor, - tgt_sizes: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor: assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -186,21 +212,20 @@ def forward(self, x: torch.Tensor, max_patch_len = patch_len.max().item() assert isinstance(max_patch_len, int) - key_padding_mask = torch.zeros((bs, max_patch_len), - dtype=torch.bool, - device=device) + key_padding_mask = torch.zeros( + (bs, max_patch_len), dtype=torch.bool, device=device + ) pos_embed = [] for i in range(bs): tgt_h, tgt_w = tgt_sizes[i].tolist() - pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( - (tgt_h * tgt_w, -1)).to(dtype)) # patches * D - key_padding_mask[i, patch_len[i]:] = True - pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, - batch_first=True, - padding_value=0.0).permute( - 1, 0, - 2) # BLD => L * B * D + pos_embed.append( + self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype) + ) # patches * D + key_padding_mask[i, patch_len[i] :] = True + pos_embed = torch.nn.utils.rnn.pad_sequence( + pos_embed, batch_first=True, padding_value=0.0 + ).permute(1, 0, 2) # BLD => L * B * D x, _ = self.kv_proj(x) # B * L * D x = self.ln_kv(x).permute(1, 0, 2) # L * B * D @@ -221,33 +246,37 @@ def forward(self, x: torch.Tensor, class Resampler4_5(Resampler2_5): + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: tuple[int, int] = (70, 70), + max_temporal_size: int = 36000, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + max_size, + quant_config=quant_config, + prefix=prefix, + ) - def __init__(self, - num_queries: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - max_size: tuple[int, int] = (70, 70), - max_temporal_size: int = 36000, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - super().__init__(num_queries, - embed_dim, - num_heads, 
- kv_dim, - norm_layer, - max_size, - quant_config=quant_config, - prefix=prefix) - - trunc_normal_(self.query, std=.02) + trunc_normal_(self.query, std=0.02) self.max_temporal_size = max_temporal_size self._set_temporal_pos_cache(self.max_temporal_size) self.apply(self._init_weights) - def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim: int, - pos: np.ndarray): + def get_1d_sincos_pos_embed_from_temporal_size( + self, embed_dim: int, pos: np.ndarray + ): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) @@ -255,11 +284,11 @@ def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim: int, """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2. - omega = 1. / 10000**omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) @@ -267,25 +296,31 @@ def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim: int, emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb - def _set_temporal_pos_cache(self, - max_temporal_size: int, - device: torch.types.Device = "cpu") -> None: + def _set_temporal_pos_cache( + self, max_temporal_size: int, device: torch.types.Device = "cpu" + ) -> None: temporal_size = np.arange(max_temporal_size, dtype=np.float32) - pos_embed = torch.from_numpy( - self.get_1d_sincos_pos_embed_from_temporal_size( - self.embed_dim, temporal_size)).float().to(device) + pos_embed = ( + torch.from_numpy( + self.get_1d_sincos_pos_embed_from_temporal_size( + self.embed_dim, temporal_size + ) + ) + .float() + .to(device) + ) self.register_buffer("temporal_pos_embed", pos_embed, persistent=False) - def _adjust_temporal_pos_cache(self, - max_temporal_size: int, - device: torch.types.Device = "cpu"): + def _adjust_temporal_pos_cache( + self, max_temporal_size: int, device: torch.types.Device = "cpu" + ): if max_temporal_size > self.max_temporal_size: self.max_temporal_size = max_temporal_size self._set_temporal_pos_cache(self.max_temporal_size, device) def _init_weights(self, m: Union[nn.Linear, nn.LayerNorm]): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -297,7 +332,7 @@ def forward( x: torch.Tensor, tgt_sizes: torch.Tensor, # temporal_ids for high refresh rate videos - temporal_ids=None + temporal_ids=None, ) -> torch.Tensor: assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] @@ -323,9 +358,9 @@ def forward( max_patch_len = patch_len.max().item() assert isinstance(max_patch_len, int) - key_padding_mask = torch.zeros((bs, max_patch_len), - dtype=torch.bool, - device=device) + key_padding_mask = torch.zeros( + (bs, max_patch_len), dtype=torch.bool, device=device + ) x, _ = self.kv_proj(x) # B * L * D x = self.ln_kv(x).permute(1, 0, 2) # L * B * D @@ -338,19 +373,21 @@ def forward( if temporal_pos_emb: if temporal_ids_flatten[i] == -1: pos_embed_temporal.append( - torch.zeros(self.embed_dim, dtype=dtype, - device=device)) + torch.zeros(self.embed_dim, dtype=dtype, device=device) + ) else: - pos_embed_temporal.append(self.temporal_pos_embed[ - temporal_ids_flatten[i]].to(dtype)) # D + pos_embed_temporal.append( + 
self.temporal_pos_embed[temporal_ids_flatten[i]].to(dtype) + ) # D - pos_embed_2d.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( - (tgt_h * tgt_w, -1)).to(dtype)) # patches * D - key_padding_mask[i, patch_len[i]:] = True + pos_embed_2d.append( + self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype) + ) # patches * D + key_padding_mask[i, patch_len[i] :] = True pos_embed_2d = torch.nn.utils.rnn.pad_sequence( - pos_embed_2d, batch_first=True, - padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D + pos_embed_2d, batch_first=True, padding_value=0.0 + ).permute(1, 0, 2) # BLD => L * B * D k = x v = x + pos_embed_2d @@ -366,26 +403,27 @@ def forward( end = start + len(tp) # L * (end-start) * D -> (end-start) * L * D # -> 1 * L*(end-start) * D - merge_k.append(k[:, start:end, :].permute(1, 0, 2).reshape( - -1, self.embed_dim)) - merge_v.append(v[:, start:end, :].permute(1, 0, 2).reshape( - -1, self.embed_dim)) + merge_k.append( + k[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim) + ) + merge_v.append( + v[:, start:end, :].permute(1, 0, 2).reshape(-1, self.embed_dim) + ) merge_key_padding_mask.append( - key_padding_mask[start:end, :].reshape(-1, 1)) + key_padding_mask[start:end, :].reshape(-1, 1) + ) start = end - k = torch.nn.utils.rnn.pad_sequence(merge_k, - batch_first=True, - padding_value=0.0).permute( - 1, 0, 2) # L*(end-start) - v = torch.nn.utils.rnn.pad_sequence(merge_v, - batch_first=True, - padding_value=0.0).permute( - 1, 0, 2) # L*(end-start) + k = torch.nn.utils.rnn.pad_sequence( + merge_k, batch_first=True, padding_value=0.0 + ).permute(1, 0, 2) # L*(end-start) + v = torch.nn.utils.rnn.pad_sequence( + merge_v, batch_first=True, padding_value=0.0 + ).permute(1, 0, 2) # L*(end-start) key_padding_mask = torch.nn.utils.rnn.pad_sequence( - merge_key_padding_mask, batch_first=True, - padding_value=True).squeeze(-1) + merge_key_padding_mask, batch_first=True, padding_value=True + ).squeeze(-1) out = self.attn( self._repeat(q, bs), # Q * B * D @@ -436,7 +474,6 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): class MiniCPMVImageEmbeddingItems(DictEmbeddingItems): - def __init__( self, data: Mapping[str, torch.Tensor], @@ -458,7 +495,6 @@ def get_image_size(self, index: int) -> ImageSize: class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems): - def __init__( self, data: Mapping[str, torch.Tensor], @@ -483,7 +519,6 @@ def get_num_frames(self, index: int) -> int: class MiniCPMVMultiModalDataParser(MultiModalDataParser): - def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], @@ -652,21 +687,18 @@ def get_num_frames_with_most_features( max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self.get_max_video_frames(seq_len - - max_image_tokens) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO) + max_total_frames = self.get_max_video_frames(seq_len - max_image_tokens) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO + ) return max(max_frames_per_video, 1) -_I = TypeVar("_I", - bound=MiniCPMVProcessingInfo, - default=MiniCPMVProcessingInfo) +_I = TypeVar("_I", bound=MiniCPMVProcessingInfo, default=MiniCPMVProcessingInfo) class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -685,52 +717,54 @@ def 
get_dummy_mm_data( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - image_width, image_height = \ - self.info.get_image_size_with_most_features() - video_width, video_height = \ - self.info.get_video_frame_size_with_most_features() - num_video_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + image_width, image_height = self.info.get_image_size_with_most_features() + video_width, video_height = self.info.get_video_frame_size_with_most_features() + num_video_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=image_width, - height=image_height, - num_images=num_images, - overrides=image_overrides), + "image": self._get_dummy_images( + width=image_width, + height=image_height, + num_images=num_images, + overrides=image_overrides, + ), "video": [ - self._get_dummy_images(width=video_width, - height=video_height, - num_images=num_video_frames, - overrides=video_overrides) - ] * num_videos, + self._get_dummy_images( + width=video_width, + height=video_height, + num_images=num_video_frames, + overrides=video_overrides, + ) + ] + * num_videos, } class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): - def _get_data_parser(self) -> MultiModalDataParser: return MiniCPMVMultiModalDataParser() - def get_image_prompt_texts(self, - image_size: ImageSize, - image_idx: int = 0) -> str: + def get_image_prompt_texts(self, image_size: ImageSize, image_idx: int = 0) -> str: return self.info.get_slice_image_placeholder( image_size, image_idx=image_idx, ) - def get_video_prompt_texts(self, image_size: ImageSize, - num_frames: int) -> str: - return self.info.get_slice_image_placeholder( - image_size=image_size, - image_idx=0, - max_slice_nums=self.info.get_video_max_slice_num(), - use_image_id=False, - ) * num_frames + def get_video_prompt_texts(self, image_size: ImageSize, num_frames: int) -> str: + return ( + self.info.get_slice_image_placeholder( + image_size=image_size, + image_idx=0, + max_slice_nums=self.info.get_video_max_slice_num(), + use_image_id=False, + ) + * num_frames + ) def process_images( self, @@ -741,10 +775,11 @@ def process_images( if (images := mm_data.get("images")) is None: return {} - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": images - }).get_items("image", - (MiniCPMVImageEmbeddingItems, ImageProcessorItems))) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems)) + ) if isinstance(parsed_images, MiniCPMVImageEmbeddingItems): image_inputs = {} @@ -772,24 +807,23 @@ def process_videos( if (videos := mm_data.get("videos")) is None: return {} - parsed_videos = (self._get_data_parser().parse_mm_data({ - "video": videos - }).get_items("video", - (MiniCPMVVideoEmbeddingItems, VideoProcessorItems))) + parsed_videos = ( + self._get_data_parser() + .parse_mm_data({"video": videos}) + .get_items("video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems)) + ) if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems): video_inputs = {} else: video_inputs = self._base_call_hf_processor( prompts=[ - self.info.image_pattern * len(video) - for video in parsed_videos + self.info.image_pattern * len(video) for video in parsed_videos ], mm_data={"images": list(parsed_videos)}, mm_kwargs={ **mm_kwargs, - 
"max_slice_nums": - self.info.get_video_max_slice_num(), + "max_slice_nums": self.info.get_video_max_slice_num(), }, tok_kwargs=tok_kwargs, out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, @@ -837,10 +871,7 @@ def _base_call_hf_processor( for i, prompt in enumerate(prompts): inputs_one = super()._call_hf_processor( prompt=prompt, - mm_data={ - k: v[i] - for k, v in mm_data.items() - }, + mm_data={k: v[i] for k, v in mm_data.items()}, mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) @@ -863,10 +894,12 @@ def _call_hf_processor( input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)]) mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs, tok_kwargs) - return BatchFeature({ - "input_ids": input_ids, - **mm_inputs, - }) + return BatchFeature( + { + "input_ids": input_ids, + **mm_inputs, + } + ) def _hf_processor_applies_updates( self, @@ -883,22 +916,26 @@ def _get_prompt_updates( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: - placeholders = [("image", self.info.image_pattern), - ("video", self.info.video_pattern)] + placeholders = [ + ("image", self.info.image_pattern), + ("video", self.info.video_pattern), + ] # hard code for inconsistency of encode-decode image_pattern additional_placeholders = [] tokenizer = self.info.get_tokenizer() for modality, pattern in placeholders: sub_pattern = tokenizer.decode( - tokenizer.encode(pattern, add_special_tokens=False)) + tokenizer.encode(pattern, add_special_tokens=False) + ) if sub_pattern != pattern: additional_placeholders.append((modality, sub_pattern)) placeholders += additional_placeholders def get_image_replacement(item_idx: int): images = mm_items.get_items( - "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems)) + "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems) + ) image_size = images.get_image_size(item_idx) @@ -909,7 +946,8 @@ def get_image_replacement(item_idx: int): def get_video_replacement(item_idx: int): videos = mm_items.get_items( - "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems)) + "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems) + ) frame_size = videos.get_frame_size(item_idx) num_frames = videos.get_num_frames(item_idx) @@ -925,9 +963,9 @@ def get_video_replacement(item_idx: int): } return [ - PromptReplacement(modality=modality, - target=pattern, - replacement=get_replacement[modality]) + PromptReplacement( + modality=modality, target=pattern, replacement=get_replacement[modality] + ) for modality, pattern in placeholders ] @@ -964,7 +1002,8 @@ def _recompute_cached_prompt_update( 1, ), "<unk>", - )) + ) + ) return new_update @@ -1007,24 +1046,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multimodal_config = multimodal_config self.version = get_version_by_config(self.config) - self.llm = self.init_llm(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "llm")) - self.vpm = self.init_vision_module(config, - quant_config, - prefix=maybe_prefix(prefix, "vpm")) - self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else - self.vpm.embeddings.embed_dim) + self.llm = self.init_llm( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "llm") + ) + self.vpm = self.init_vision_module( + config, quant_config, prefix=maybe_prefix(prefix, "vpm") + ) + self.vision_dim = ( + self.vpm.embed_dim + if self.version == (2, 0) + else self.vpm.embeddings.embed_dim + ) self.embed_dim = self.config.hidden_size - self.resampler = self.init_resampler(self.embed_dim, - self.vision_dim, - 
quant_config=quant_config, - prefix=maybe_prefix( - prefix, "resampler")) + self.resampler = self.init_resampler( + self.embed_dim, + self.vision_dim, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "resampler"), + ) self.mm_token_ids = set[int]() - self.make_empty_intermediate_tensors = ( - self.llm.make_empty_intermediate_tensors) + self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors def _parse_and_validate_vision_input( self, @@ -1046,7 +1089,8 @@ def _parse_and_validate_vision_input( if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError( f"Incorrect type of image_embeds for {modality=}. " - f"Got type: {type(image_embeds)}") + f"Got type: {type(image_embeds)}" + ) image_embeds_flat = flatten_bn(image_embeds) @@ -1058,12 +1102,15 @@ def _parse_and_validate_vision_input( if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError( f"Incorrect type of pixel_values for {modality=}. " - f"Got type: {type(pixel_values)}") + f"Got type: {type(pixel_values)}" + ) tgt_sizes = kwargs.pop("tgt_sizes") if not isinstance(tgt_sizes, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. " - f"Got type: {type(tgt_sizes)}") + raise ValueError( + f"Incorrect type of tgt_sizes for {modality=}. " + f"Got type: {type(tgt_sizes)}" + ) num_slices = [[len(p) for p in ps] for ps in pixel_values] num_slices_flat = flatten_bn(torch.tensor(num_slices)) @@ -1084,12 +1131,17 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): modalities["images"] = self._parse_and_validate_vision_input( - "images", **kwargs) - if input_key in ("video_pixel_values", - "video_embeds") and "videos" not in modalities: + "images", **kwargs + ) + if ( + input_key in ("video_pixel_values", "video_embeds") + and "videos" not in modalities + ): def _image_key(video_key: str): if video_key == "video_token_id": @@ -1098,10 +1150,8 @@ def _image_key(video_key: str): return video_key.removeprefix("video_") modalities["videos"] = self._parse_and_validate_vision_input( - "videos", **{ - _image_key(k): v - for k, v in kwargs.items() - }) + "videos", **{_image_key(k): v for k, v in kwargs.items()} + ) return modalities @@ -1115,10 +1165,7 @@ def _process_vision_input( image_features_flat = self.get_vision_hidden_states(image_input) num_slices = image_input["num_slices"] - return [ - e.flatten(0, 1) - for e in image_features_flat.split(num_slices.tolist()) - ] + return [e.flatten(0, 1) for e in image_features_flat.split(num_slices.tolist())] def _process_multimodal_inputs(self, modalities: dict): # The result multimodal_embeddings is tuple of tensors, with each @@ -1142,8 +1189,7 @@ def _process_multimodal_inputs(self, modalities: dict): def get_language_model(self) -> torch.nn.Module: return self.llm - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -1175,8 +1221,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.llm.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - 
torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) @@ -1184,9 +1229,9 @@ def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ - return MultiModelKeys.from_string_field(language_model="llm", - connector="resampler", - tower_model="vpm") + return MultiModelKeys.from_string_field( + language_model="llm", connector="resampler", tower_model="vpm" + ) def init_llm( self, @@ -1203,20 +1248,20 @@ def init_vision_module( ) -> nn.Module: raise NotImplementedError - def init_resampler(self, - embed_dim: int, - vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: raise NotImplementedError - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: raise NotImplementedError class MiniCPMV2_0(MiniCPMVBaseModel): - supports_encoder_tp_data = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -1253,8 +1298,10 @@ def init_vision_module( model = model.to(dtype=torch.get_default_dtype()) - if (isinstance(model, timm.models.VisionTransformer) - and model.attn_pool is not None): + if ( + isinstance(model, timm.models.VisionTransformer) + and model.attn_pool is not None + ): model.attn_pool = torch.nn.Identity() if self.config.drop_vision_last_layer: @@ -1262,27 +1309,30 @@ def init_vision_module( return model - def init_resampler(self, - embed_dim: int, - vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: with set_default_torch_dtype(torch.float16): - resampler = Resampler2(embed_dim=embed_dim, - num_heads=embed_dim // 128, - grid_size=int( - math.sqrt(self.config.query_num)), - kv_dim=vision_dim, - adaptive=False, - do_post_projection=True, - quant_config=quant_config, - prefix=prefix) - - return resampler.to(device=current_platform.device_type, - dtype=torch.get_default_dtype()) - - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + resampler = Resampler2( + embed_dim=embed_dim, + num_heads=embed_dim // 128, + grid_size=int(math.sqrt(self.config.query_num)), + kv_dim=vision_dim, + adaptive=False, + do_post_projection=True, + quant_config=quant_config, + prefix=prefix, + ) + + return resampler.to( + device=current_platform.device_type, dtype=torch.get_default_dtype() + ) + + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] P_h, P_w = self.vpm.patch_embed.patch_size @@ -1294,7 +1344,8 @@ def get_vision_hidden_states( H, W = pixel_value[0].shape[-2:] tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w)) vision_embedding = self.vpm.forward_features( - pixel_value.unsqueeze(0).type(dtype)) + pixel_value.unsqueeze(0).type(dtype) + ) if num_prefix_tokens > 0: vision_embedding = vision_embedding[:, num_prefix_tokens:] @@ -1343,24 +1394,28 @@ def init_vision_module( model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, - embed_dim: int, - vision_dim: int, - 
quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: with set_default_torch_dtype(torch.float16): - resampler = Resampler2_5(num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - quant_config=quant_config, - prefix=prefix) - - return resampler.to(device=current_platform.device_type, - dtype=torch.get_default_dtype()) - - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix, + ) + + return resampler.to( + device=current_platform.device_type, dtype=torch.get_default_dtype() + ) + + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -1370,9 +1425,7 @@ def get_vision_hidden_states( device = pixel_values[0].device dtype = pixel_values[0].dtype - all_pixel_values = torch.zeros((B, 3, P, L), - dtype=dtype, - device=device) + all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device) for i, pixel_values_item in enumerate(pixel_values): L_item = pixel_values_item.shape[-1] all_pixel_values[i, ..., :L_item] = pixel_values_item @@ -1381,9 +1434,7 @@ def get_vision_hidden_states( max_patches = num_patches.max().item() assert isinstance(max_patches, int) - patch_attn_mask = torch.zeros((B, max_patches), - dtype=torch.bool, - device=device) + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) for i, num_patches_item in enumerate(num_patches): patch_attn_mask[i, :num_patches_item] = True @@ -1436,25 +1487,29 @@ def init_vision_module( model.encoder.layers = model.encoder.layers[:-1] return model - def init_resampler(self, - embed_dim: int, - vision_dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> nn.Module: + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: with set_default_torch_dtype(torch.float16): # The resampler in 2.6 remains consistent with the one in 2.5. 
- resampler = Resampler2_5(num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - quant_config=quant_config, - prefix=prefix) - - return resampler.to(device=current_platform.device_type, - dtype=torch.get_default_dtype()) - - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix, + ) + + return resampler.to( + device=current_platform.device_type, dtype=torch.get_default_dtype() + ) + + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -1464,9 +1519,7 @@ def get_vision_hidden_states( device = pixel_values[0].device dtype = pixel_values[0].dtype - all_pixel_values = torch.zeros((B, 3, P, L), - dtype=dtype, - device=device) + all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device) for i, pixel_values_item in enumerate(pixel_values): L_item = pixel_values_item.shape[-1] all_pixel_values[i, ..., :L_item] = pixel_values_item @@ -1475,9 +1528,7 @@ def get_vision_hidden_states( max_patches = num_patches.max().item() assert isinstance(max_patches, int) - patch_attn_mask = torch.zeros((B, max_patches), - dtype=torch.bool, - device=device) + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) for i, num_patches_item in enumerate(num_patches): patch_attn_mask[i, :num_patches_item] = True @@ -1489,10 +1540,8 @@ def get_vision_hidden_states( return self.resampler(vision_embedding, tgt_sizes) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, - skip_prefixes=["apm.", "audio", "tts"]) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"]) return loader.load_weights(weights) @@ -1552,18 +1601,20 @@ def init_resampler( quant_config = self._maybe_ignore_quant_config(quant_config) with set_default_torch_dtype(torch.float16): # The resampler in 4.0 remains consistent with the one in 2.5/2.6. 
- resampler = Resampler2_5(num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - quant_config=quant_config, - prefix=prefix) - - return resampler.to(device=current_platform.device_type, - dtype=torch.get_default_dtype()) - - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix, + ) + + return resampler.to( + device=current_platform.device_type, dtype=torch.get_default_dtype() + ) + + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -1573,9 +1624,7 @@ def get_vision_hidden_states( device = pixel_values[0].device dtype = pixel_values[0].dtype - all_pixel_values = torch.zeros((B, 3, P, L), - dtype=dtype, - device=device) + all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device) for i, pixel_values_item in enumerate(pixel_values): L_item = pixel_values_item.shape[-1] all_pixel_values[i, ..., :L_item] = pixel_values_item @@ -1584,9 +1633,7 @@ def get_vision_hidden_states( max_patches = num_patches.max().item() assert isinstance(max_patches, int) - patch_attn_mask = torch.zeros((B, max_patches), - dtype=torch.bool, - device=device) + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) for i, num_patches_item in enumerate(num_patches): patch_attn_mask[i, :num_patches_item] = True @@ -1598,10 +1645,8 @@ def get_vision_hidden_states( return self.resampler(vision_embedding, tgt_sizes) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, - skip_prefixes=["apm.", "audio", "tts"]) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"]) return loader.load_weights(weights) @@ -1661,21 +1706,23 @@ def init_resampler( quant_config = self._maybe_ignore_quant_config(quant_config) with set_default_torch_dtype(torch.float16): # The resampler in 4.0 remains consistent with the one in 2.5/2.6. 
- resampler = Resampler4_5(num_queries=self.config.query_num, - embed_dim=embed_dim, - num_heads=embed_dim // 128, - kv_dim=vision_dim, - quant_config=quant_config, - prefix=prefix) - - return resampler.to(device=current_platform.device_type, - dtype=torch.get_default_dtype()) - - def get_vision_hidden_states( - self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + resampler = Resampler4_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix, + ) + + return resampler.to( + device=current_platform.device_type, dtype=torch.get_default_dtype() + ) + + def get_vision_hidden_states(self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] - temporal_ids = data.get('temporal_ids', None) + temporal_ids = data.get("temporal_ids", None) B = len(pixel_values) P = pixel_values[0].shape[-2] @@ -1683,11 +1730,10 @@ def get_vision_hidden_states( device = pixel_values[0].device dtype = pixel_values[0].dtype - all_pixel_values = torch.zeros((B, 3, P, L), - dtype=dtype, - device=device) - all_temporal_ids = None if temporal_ids is None else flatten_2d_lists( - temporal_ids) + all_pixel_values = torch.zeros((B, 3, P, L), dtype=dtype, device=device) + all_temporal_ids = ( + None if temporal_ids is None else flatten_2d_lists(temporal_ids) + ) for i, pixel_values_item in enumerate(pixel_values): L_item = pixel_values_item.shape[-1] all_pixel_values[i, ..., :L_item] = pixel_values_item @@ -1696,9 +1742,7 @@ def get_vision_hidden_states( max_patches = num_patches.max().item() assert isinstance(max_patches, int) - patch_attn_mask = torch.zeros((B, max_patches), - dtype=torch.bool, - device=device) + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) for i, num_patches_item in enumerate(num_patches): patch_attn_mask[i, :num_patches_item] = True @@ -1710,10 +1754,8 @@ def get_vision_hidden_states( return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, - skip_prefixes=["apm.", "audio", "tts"]) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["apm.", "audio", "tts"]) return loader.load_weights(weights) @@ -1729,7 +1771,8 @@ def load_weights(self, weights: Iterable[tuple[str, @MULTIMODAL_REGISTRY.register_processor( MiniCPMVMultiModalProcessor, info=MiniCPMVProcessingInfo, - dummy_inputs=MiniCPMVDummyInputsBuilder) + dummy_inputs=MiniCPMVDummyInputsBuilder, +) class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): """ Different versions of MiniCPMV use different visual encoders and LLMs, @@ -1751,9 +1794,12 @@ def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): instance_cls = _SUPPORT_VERSION.get(version) if instance_cls is None: supported_versions = ", ".join( - [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())]) - raise ValueError(f"Currently, MiniCPMV only supports versions " - f"{supported_versions}. Got version: {version}") + [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())] + ) + raise ValueError( + f"Currently, MiniCPMV only supports versions " + f"{supported_versions}. 
Got version: {version}" + ) # quant_config references base class members, # so update values before init is called diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 45228aa0bb93..e6e0952f71dd 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only MiniMaxText01 model.""" + from collections.abc import Iterable from itertools import islice from typing import TYPE_CHECKING, Optional, Union @@ -18,25 +19,33 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed.parallel_state import ( - get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mamba.linear_attn import ( - MiniMaxText01LinearAttention) +from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import maybe_prefix from vllm.sequence import IntermediateTensors @@ -45,25 +54,22 @@ from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers -def replace_weight_name(name: str, - key: str = None, - to: str = None, - count: int = None, - prefix: str = None) -> str: - name = name.replace(key, to) if count is None else \ - name.replace(key, to, count) +def replace_weight_name( + name: str, key: str = None, to: str = None, count: int = None, prefix: str = None +) -> str: + name = name.replace(key, to) if count is None else name.replace(key, to, count) return name def weight_loader_with_alias(alias: str): - def wrapper(func: callable): - - def inner_func(param: torch.Tensor, - loaded_weight: torch.Tensor, - *args, - prefix: str = None, - **kwargs): + def inner_func( + param: torch.Tensor, + loaded_weight: torch.Tensor, + *args, + prefix: str = None, + **kwargs, + ): value = func(param, loaded_weight, *args, **kwargs) return value @@ -73,7 +79,6 @@ def inner_func(param: torch.Tensor, class MiniMaxText01MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -103,7 +108,6 @@ def __init__( return def 
forward(self, x: torch.Tensor) -> torch.Tensor: - gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -111,7 +115,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MiniMaxText01MoE(nn.Module): - def __init__( self, num_experts: int, @@ -162,8 +165,7 @@ def __init__( return @staticmethod - def gate_weight_loader(param: nn.Parameter, - loaded_weight: torch.Tensor) -> None: + def gate_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None: assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight.to(torch.float32)) return @@ -173,13 +175,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, self.hidden_size) router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32)) final_hidden_states = self.experts( - hidden_states, router_logits_fp32.to(hidden_states.dtype)) + hidden_states, router_logits_fp32.to(hidden_states.dtype) + ) final_hidden = final_hidden_states.view(num_tokens, hidden_size) return final_hidden class MiniMaxText01Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -253,8 +255,13 @@ def __init__( ) return - def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, - positions: torch.Tensor, **kwargs) -> None: + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + **kwargs, + ) -> None: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) @@ -263,7 +270,6 @@ def forward(self, hidden_states: torch.Tensor, output: torch.Tensor, class MiniMaxText01DecoderLayer(nn.Module): - def __init__( self, config: MiniMaxConfig, @@ -288,14 +294,17 @@ def __init__( head_dim = getattr(config, "head_dim", None) if head_dim is None: head_dim = config.hidden_size // config.num_attention_heads - if hasattr(config, "max_model_len") and isinstance( - config.max_model_len, int): - max_position_embeddings = min(config.max_position_embeddings, - config.max_model_len) + if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int): + max_position_embeddings = min( + config.max_position_embeddings, config.max_model_len + ) if config.attention_type == 0: use_headxdim = True - hidden_inner = (head_dim * config.num_attention_heads - if use_headxdim else config.hidden_size) + hidden_inner = ( + head_dim * config.num_attention_heads + if use_headxdim + else config.hidden_size + ) self.self_attn = MiniMaxText01LinearAttention( hidden_size=self.hidden_size, hidden_inner_size=hidden_inner, @@ -309,14 +318,16 @@ def __init__( quant_config=quant_config, layer_idx=self._ilayer, linear_layer_idx=linear_layer_id, - prefix=prefix) + prefix=prefix, + ) elif config.attention_type == 1: self.self_attn = MiniMaxText01Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, head_dim=head_dim, rotary_dim=config.rotary_dim - if hasattr(config, "rotary_dim") else head_dim, + if hasattr(config, "rotary_dim") + else head_dim, num_kv_heads=config.num_key_value_heads, max_position=max_position_embeddings, rope_theta=rope_theta, @@ -324,10 +335,12 @@ def __init__( quant_config=quant_config, layer_idx=self._ilayer, cache_config=cache_config, - prefix=prefix) + prefix=prefix, + ) else: raise ValueError( - f"Unsupported attention type: {self.config.attention_type}") + f"Unsupported attention type: {self.config.attention_type}" + ) if expert_num == 1: self.mlp = MiniMaxText01MLP( @@ -335,7 
+348,8 @@ def __init__( intermediate_size=config.intermediate_size, quant_config=quant_config, layer_idx=self._ilayer, - prefix=prefix) + prefix=prefix, + ) else: self.block_sparse_moe = MiniMaxText01MoE( num_experts=expert_num, @@ -344,39 +358,51 @@ def __init__( intermediate_size=config.intermediate_size, layer_idx=self._ilayer, quant_config=quant_config, - prefix=prefix) + prefix=prefix, + ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) if config.attention_type == 0: self.layernorm_attention_alpha = getattr( - config, 'layernorm_linear_attention_alpha', - getattr(config, 'linear_attn_alpha_factor', 1)) + config, + "layernorm_linear_attention_alpha", + getattr(config, "linear_attn_alpha_factor", 1), + ) self.layernorm_attention_beta = getattr( - config, 'layernorm_linear_attention_beta', - getattr(config, 'linear_attn_beta_factor', 1)) + config, + "layernorm_linear_attention_beta", + getattr(config, "linear_attn_beta_factor", 1), + ) else: self.layernorm_attention_alpha = getattr( - config, 'layernorm_full_attention_alpha', - getattr(config, 'full_attn_alpha_factor', 1)) + config, + "layernorm_full_attention_alpha", + getattr(config, "full_attn_alpha_factor", 1), + ) self.layernorm_attention_beta = getattr( - config, 'layernorm_full_attention_beta', - getattr(config, 'full_attn_beta_factor', 1)) + config, + "layernorm_full_attention_beta", + getattr(config, "full_attn_beta_factor", 1), + ) self.layernorm_mlp_alpha = getattr( - config, 'layernorm_mlp_alpha', - getattr(config, 'mlp_alpha_factor', 1)) + config, "layernorm_mlp_alpha", getattr(config, "mlp_alpha_factor", 1) + ) self.layernorm_mlp_beta = getattr( - config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor', - 1)) - self.postnorm = getattr(config, 'postnorm', False) + config, "layernorm_mlp_beta", getattr(config, "mlp_beta_factor", 1) + ) + self.postnorm = getattr(config, "postnorm", False) self.shared_moe = False - shared_intermediate = getattr(config, 'shared_intermediate_size', 0) + shared_intermediate = getattr(config, "shared_intermediate_size", 0) if isinstance(shared_intermediate, list): - shared_intermediate = shared_intermediate[ - layer_id] if layer_id < len(shared_intermediate) else 0 + shared_intermediate = ( + shared_intermediate[layer_id] + if layer_id < len(shared_intermediate) + else 0 + ) if shared_intermediate > 0: self.shared_moe = True self.shared_mlp = MiniMaxText01MLP( @@ -384,7 +410,8 @@ def __init__( intermediate_size=shared_intermediate, quant_config=quant_config, layer_idx=self._ilayer, - prefix=prefix) + prefix=prefix, + ) self.coefficient = ReplicatedLinear( self.hidden_size, 1, @@ -392,20 +419,19 @@ def __init__( quant_config=quant_config, params_dtype=torch.float32, ) - self.coefficient.weight.weight_loader = ( - self.shared_moe_coefficient_loader) - self.shared_moe_mode = getattr(config, 'shared_moe_mode', - 'softmax') + self.coefficient.weight.weight_loader = self.shared_moe_coefficient_loader + self.shared_moe_mode = getattr(config, "shared_moe_mode", "softmax") return - def forward(self, - hidden_states: torch.Tensor, - positions: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - is_warmup: bool = False, - **kwargs) -> tuple[torch.Tensor, torch.Tensor]: - + def 
forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + is_warmup: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: layernorm_input = hidden_states layernorm_output = self.input_layernorm(layernorm_input) residual = layernorm_output if self.postnorm else layernorm_input @@ -417,8 +443,7 @@ def forward(self, ) residual = residual * self.layernorm_attention_alpha - self_attention_output = (self_attention_output * - self.layernorm_attention_beta) + self_attention_output = self_attention_output * self.layernorm_attention_beta layernorm_input = residual + self_attention_output layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -432,19 +457,16 @@ def forward(self, if self.shared_moe: before_moe_dtype = layernorm_output.dtype moe_hidden_fp32 = moe_hidden_states.to(torch.float32) - output_mlp = self.shared_mlp(layernorm_output).to( - torch.float32) + output_mlp = self.shared_mlp(layernorm_output).to(torch.float32) coef, _ = self.coefficient(layernorm_output.to(torch.float32)) - if self.shared_moe_mode == 'softmax': + if self.shared_moe_mode == "softmax": coef = torch.nn.functional.softmax(coef, dim=-1) - hidden_states = moe_hidden_fp32 * ( - 1 - coef) + output_mlp * coef - elif self.shared_moe_mode == 'sigmoid': + hidden_states = moe_hidden_fp32 * (1 - coef) + output_mlp * coef + elif self.shared_moe_mode == "sigmoid": coef = torch.nn.functional.sigmoid(coef) - hidden_states = moe_hidden_fp32 * ( - 1 - coef) + output_mlp * coef + hidden_states = moe_hidden_fp32 * (1 - coef) + output_mlp * coef hidden_states = hidden_states.to(before_moe_dtype) else: @@ -458,8 +480,9 @@ def forward(self, return hidden_states, None @staticmethod - def shared_moe_coefficient_loader(param: torch.Tensor, - loaded_weight: torch.Tensor) -> None: + def shared_moe_coefficient_loader( + param: torch.Tensor, loaded_weight: torch.Tensor + ) -> None: assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight.to(torch.float32)) @@ -468,7 +491,6 @@ def shared_moe_coefficient_loader(param: torch.Tensor, @support_torch_compile class MiniMaxText01Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: MiniMaxConfig = vllm_config.model_config.hf_config @@ -481,8 +503,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size = config.vocab_size self.decoder_attention_types = getattr( - config, "attn_type_list", False) or getattr( - config, "decoder_attention_types", False) + config, "attn_type_list", False + ) or getattr(config, "decoder_attention_types", False) # The HF format uses "layer_types" instead of "attn_type_list" # where "linear_attention" is 0 and "full_attention" is 1 if not self.decoder_attention_types and hasattr(config, "layer_types"): @@ -510,50 +532,57 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = PPMissingLayer() def layer_fn(prefix): - layer_idx = int(prefix.split('.')[-1]) + layer_idx = int(prefix.split(".")[-1]) layer_config = config - layer_config.attention_type = self.decoder_attention_types[ - layer_idx] + layer_config.attention_type = self.decoder_attention_types[layer_idx] layer_config.layer_idx = layer_idx decoder_kwargs = { "quant_config": quant_config, "layer_id": layer_idx, "model_config": model_config, - "cache_config": cache_config + "cache_config": cache_config, } if layer_config.attention_type == 0: 
decoder_kwargs["linear_layer_id"] = sum( - 1 for i in range(layer_idx) - if self.decoder_attention_types[i] == 0) + 1 for i in range(layer_idx) if self.decoder_attention_types[i] == 0 + ) else: decoder_kwargs["linear_layer_id"] = None if hasattr(config, "num_local_experts") and isinstance( - config.num_local_experts, list): - decoder_kwargs["expert_num"] = config.num_local_experts[ - layer_idx] + config.num_local_experts, list + ): + decoder_kwargs["expert_num"] = config.num_local_experts[layer_idx] elif hasattr(config, "num_local_experts") and isinstance( - config.num_local_experts, int): + config.num_local_experts, int + ): decoder_kwargs["expert_num"] = config.num_local_experts else: decoder_kwargs["expert_num"] = 1 - return MiniMaxText01DecoderLayer(layer_config, - **decoder_kwargs, - prefix=prefix) + return MiniMaxText01DecoderLayer( + layer_config, **decoder_kwargs, prefix=prefix + ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers") + config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers" + ) - linear_layer_nums = sum(1 for i in range(config.num_hidden_layers) - if self.decoder_attention_types[i] == 0) + linear_layer_nums = sum( + 1 + for i in range(config.num_hidden_layers) + if self.decoder_attention_types[i] == 0 + ) max_slots_number = scheduler_config.max_num_seqs - self.cache_shape = (linear_layer_nums, max_slots_number, - config.num_attention_heads // - get_tensor_model_parallel_world_size(), - config.head_dim, config.head_dim) + self.cache_shape = ( + linear_layer_nums, + max_slots_number, + config.num_attention_heads // get_tensor_model_parallel_world_size(), + config.head_dim, + config.head_dim, + ) _dummy = torch.zeros(1) self._dtype = _dummy.dtype del _dummy @@ -568,12 +597,12 @@ def layer_fn(prefix): self.embed_scale = 1.0 return - def _clear_prefill_cache(self, attn_metadata, - minimax_cache_tensors: torch.Tensor, **kwargs): + def _clear_prefill_cache( + self, attn_metadata, minimax_cache_tensors: torch.Tensor, **kwargs + ): seq_to_slot_maps = {} seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), []) - for _, seq_to_slot_map in ( - self.minimax_cache.cache_indices_mapping.items()): + for _, seq_to_slot_map in self.minimax_cache.cache_indices_mapping.items(): seq_to_slot_maps.update(seq_to_slot_map) slots_to_clear = [] @@ -581,25 +610,29 @@ def _clear_prefill_cache(self, attn_metadata, if _prefill_id >= len(seq_id_map): break seq_id = seq_id_map[_prefill_id] - if attn_metadata.context_lens_tensor[ - _prefill_id] == 0 and seq_id in seq_to_slot_maps: + if ( + attn_metadata.context_lens_tensor[_prefill_id] == 0 + and seq_id in seq_to_slot_maps + ): slots_to_clear.append(seq_to_slot_maps[seq_id]) if slots_to_clear: - slots_tensor = torch.tensor(slots_to_clear, - device=minimax_cache_tensors.device, - dtype=torch.long) + slots_tensor = torch.tensor( + slots_to_clear, device=minimax_cache_tensors.device, dtype=torch.long + ) minimax_cache_tensors[:, slots_tensor, ...] 
= 0 def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) - def forward(self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> Union[torch.Tensor, IntermediateTensors]: + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -622,10 +655,9 @@ def forward(self, residual=residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) if residual is not None: hidden_states, _ = self.norm(hidden_states, residual) else: @@ -635,9 +667,7 @@ def forward(self, class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - super().__init__() config = vllm_config.model_config.hf_config lora_config = vllm_config.lora_config @@ -652,8 +682,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.unpadded_vocab_size = self.config.vocab_size if hasattr(vllm_config.model_config, "max_model_len"): self.config.max_model_len = vllm_config.model_config.max_model_len - self.model = MiniMaxText01Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MiniMaxText01Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( self.unpadded_vocab_size, @@ -663,37 +694,41 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.config.vocab_size + ) else: self.lm_head = PPMissingLayer() self.lm_head.float() flash_layer_count = sum( - 1 for attn_type in self.model.decoder_attention_types - if attn_type == 1) + 1 for attn_type in self.model.decoder_attention_types if attn_type == 1 + ) self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)] return def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.model.minimax_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) + input_buffers, **kwargs + ) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs( - batch_size) + return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, **kwargs) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + 
inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) return hidden_states @@ -703,21 +738,20 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: return logits def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -729,7 +763,8 @@ def which_layer(name: str) -> int: def is_linear_attn_layer(layer_idx: int) -> bool: if layer_idx is None or layer_idx >= len( - self.model.decoder_attention_types): + self.model.decoder_attention_types + ): return False return self.model.decoder_attention_types[layer_idx] == 0 @@ -737,39 +772,48 @@ def is_moe_weight(name: str) -> bool: return "block_sparse_moe" in name and not name.endswith(".bias") def get_expert_id(param_name): - pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.' + pattern = r"model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\." 
match = re.search(pattern, param_name) if match: return match.group(1) return None - def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_sparse_moe_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: if isinstance(self.config.num_local_experts, list): expert_params_mapping = [ - ("w13_weight" - if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id) + ( + "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + ) for expert_id in range(max(self.config.num_local_experts)) for weight_name in ["w1", "w2", "w3"] ] else: expert_params_mapping = [ - ("w13_scale" if weight_name in ["w1", "w3"] else - "w2_scale", f"{expert_id}.{weight_name}.weight_scale", - expert_id, weight_name) + ( + "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"{expert_id}.{weight_name}.weight_scale", + expert_id, + weight_name, + ) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] - ] + [("w13_weight" if weight_name in ["w1", "w3"] else - "w2_weight", f"{expert_id}.{weight_name}.weight", - expert_id, weight_name) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"]] - for (param_name, weight_name, expert_id, - shard_id) in expert_params_mapping: + ] + [ + ( + "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", + f"{expert_id}.{weight_name}.weight", + expert_id, + weight_name, + ) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + for param_name, weight_name, expert_id, shard_id in expert_params_mapping: name_expert_id = get_expert_id(name) - if name_expert_id is not None and int(name_expert_id) != int( - expert_id): + if name_expert_id is not None and int(name_expert_id) != int(expert_id): continue if weight_name not in name: continue @@ -779,19 +823,20 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor, param = params_dict[name] weight_loader = param.weight_loader weight_loader = weight_loader_with_alias(name)(weight_loader) - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id, - shard_id=shard_id) + weight_loader( + param, + loaded_weight, + weight_name, + expert_id=expert_id, + shard_id=shard_id, + ) loaded_params.add(name) break else: if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -800,8 +845,9 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor, def is_shared_mlp_weight(name: str) -> bool: return "shared_mlp" in name and not name.endswith(".bias") - def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_shared_mlp_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: if not self.CONCAT_FFN: if "gate_proj" in name: name = name.replace("gate_proj", "w1", 1) @@ -819,8 +865,7 @@ def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor, if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = 
weight_loader_with_alias(name)(weight_loader) if not self.CONCAT_FFN: weight_loader(param, loaded_weight) @@ -830,31 +875,31 @@ def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor, elif "down_proj" in name: weight_loader(param, loaded_weight) else: - raise AssertionError( - "MLP weight not in [gate_up_proj, down_proj]") + raise AssertionError("MLP weight not in [gate_up_proj, down_proj]") loaded_params.add(name) return def is_mha_weight(name: str) -> bool: return "self_attn" in name and not name.endswith(".bias") - def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_linear_attn_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: if is_pp_missing_parameter(name, self): return param = params_dict[name] weight_loader = getattr( - param, "weight_loader", - MiniMaxText01LinearAttention.weight_direct_load) + param, "weight_loader", MiniMaxText01LinearAttention.weight_direct_load + ) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return - def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: - + def load_flash_attn_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: flash_mha_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -862,16 +907,14 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor, ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - for (param_name, weight_name, - shard_id) in flash_mha_params_mapping: + for param_name, weight_name, shard_id in flash_mha_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) @@ -881,36 +924,32 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor, return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return def is_layer_norm_weight(name: str) -> bool: - return "norm" in name and not name.endswith( - ".bias") and name in params_dict + return "norm" in name and not name.endswith(".bias") and name in params_dict - def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_layer_norm_weight( + name: str, loaded_weight: torch.Tensor, self + ) -> None: if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return - def load_basic_weight(name: str, loaded_weight: torch.Tensor, - self) -> None: + def load_basic_weight(name: str, loaded_weight: torch.Tensor, self) -> None: if is_pp_missing_parameter(name, self): return param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = 
getattr(param, "weight_loader", default_weight_loader) weight_loader = weight_loader_with_alias(name)(weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) @@ -919,7 +958,8 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor, for name, loaded_weight in weights: weight_at_layer = which_layer(name) if weight_at_layer and weight_at_layer >= len( - self.model.decoder_attention_types): + self.model.decoder_attention_types + ): continue if is_layer_norm_weight(name): @@ -949,7 +989,6 @@ def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.linear_attention_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index 938c9a689fcf..a25a7097a6ec 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -7,12 +7,13 @@ import torch.nn as nn from transformers import BatchFeature, PretrainedConfig from transformers.models.llava_next.modeling_llava_next import ( - get_anyres_image_grid_shape, unpad_image) + get_anyres_image_grid_shape, + unpad_image, +) from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig @@ -21,13 +22,19 @@ from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .llava import (BaseLlavaMultiModalProcessor, LlavaDummyInputsBuilder, - init_vision_tower_for_llava) +from .llava import ( + BaseLlavaMultiModalProcessor, + LlavaDummyInputsBuilder, + init_vision_tower_for_llava, +) from .llava_next import LlavaNextProcessingInfo from .pixtral import PixtralHFVisionModel from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + init_vllm_registered_model, + maybe_prefix, +) class MiniMaxVL01ImagePixelInputs(TensorSchema): @@ -42,10 +49,12 @@ class MiniMaxVL01ImagePixelInputs(TensorSchema): Note that `num_patches` may be different per batch and image, in which case the data is passed as a list instead of a batched tensor. """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"})] + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np", "h", "w"}), + ] image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] # This should be in `(height, width)` format. 
@@ -58,36 +67,43 @@ class MiniMaxVL01ImageEmbeddingInputs(TensorSchema): - ifs: Image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -MiniMaxVL01ImageInputs = Union[MiniMaxVL01ImagePixelInputs, - MiniMaxVL01ImageEmbeddingInputs] +MiniMaxVL01ImageInputs = Union[ + MiniMaxVL01ImagePixelInputs, MiniMaxVL01ImageEmbeddingInputs +] class MiniMaxVL01MultiModalProjector(nn.Module): - - def __init__(self, - vision_hidden_size: int, - text_hidden_size: int, - projector_hidden_act: str, - multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.linear_1 = ColumnParallelLinear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_1") + self.linear_1 = ColumnParallelLinear( + vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = RowParallelLinear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_2") + self.linear_2 = RowParallelLinear( + text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.linear_1(image_features) @@ -101,15 +117,13 @@ class MiniMaxVL01DummyInputsBuilder(LlavaDummyInputsBuilder): class MiniMaxVL01ProcessingInfo(LlavaNextProcessingInfo): - def get_hf_config(self): # Need to override the config type return self.ctx.get_hf_config(PretrainedConfig) def get_hf_processor(self, **kwargs: object): hf_processor = self.ctx.get_hf_processor(**kwargs) image_processor = hf_processor.image_processor - image_processor.anyres_preprocess = ( - image_processor.anyres_for_vllm_preprocess) + image_processor.anyres_preprocess = image_processor.anyres_for_vllm_preprocess return hf_processor @@ -118,8 +132,8 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: class MiniMaxVL01MultiModalProcessor( - BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo]): - + BaseLlavaMultiModalProcessor[MiniMaxVL01ProcessingInfo] +): def _call_hf_processor( self, prompt: str, @@ -162,13 +176,14 @@ def _get_mm_fields_config( @MULTIMODAL_REGISTRY.register_processor( MiniMaxVL01MultiModalProcessor, info=MiniMaxVL01ProcessingInfo, - dummy_inputs=MiniMaxVL01DummyInputsBuilder) -class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=MiniMaxVL01DummyInputsBuilder, +) +class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } @classmethod @@ -193,16 +208,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, 
"vision_tower"), + ) self.multi_modal_projector = MiniMaxVL01MultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act, multimodal_projector_bias=True, quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) - self.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size)) + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) + self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size)) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, @@ -215,15 +231,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.pad_token_id = self.config.pad_token_id self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def get_language_model(self) -> torch.nn.Module: return self.language_model def _image_pixels_to_features( self, - vision_tower: Union[CLIPVisionModel, SiglipVisionModel, - PixtralHFVisionModel], + vision_tower: Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel], pixel_values: Union[torch.Tensor, list[torch.Tensor]], ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: # NOTE: we skip the step to select the vision feature layer since @@ -231,55 +247,55 @@ def _image_pixels_to_features( feature_select_strategy = self.config.vision_feature_select_strategy return tuple( vision_tower(p, feature_select_strategy=feature_select_strategy) - for p in pixel_values) + for p in pixel_values + ) # adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631 - def pack_image_features(self, image_features: list[torch.Tensor], - image_sizes: torch.Tensor): + def pack_image_features( + self, image_features: list[torch.Tensor], image_sizes: torch.Tensor + ): new_image_features = [] for image_idx, image_feature in enumerate(image_features): if image_feature.shape[0] > 1: base_image_feature = image_feature[0] image_feature = image_feature[1:] - height = width = (self.config.vision_config.image_size // - self.config.vision_config.patch_size) + height = width = ( + self.config.vision_config.image_size + // self.config.vision_config.patch_size + ) if height * width != base_image_feature.shape[0]: raise ValueError( - "The number of patches is not consistent with " - "the image size.") + "The number of patches is not consistent with the image size." 
+ ) num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_sizes[image_idx], self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) - image_feature = image_feature.view(num_patch_height, - num_patch_width, height, - width, -1) - image_feature = image_feature.permute(4, 0, 2, 1, - 3).contiguous() + image_feature = image_feature.view( + num_patch_height, num_patch_width, height, width, -1 + ) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, - image_sizes[image_idx]) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) image_feature = torch.cat( ( image_feature, - self.image_newline[:, None, None].expand( - *image_feature.shape[:-1], 1).to( - image_feature.dtype), + self.image_newline[:, None, None] + .expand(*image_feature.shape[:-1], 1) + .to(image_feature.dtype), ), dim=-1, ) image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), - dim=0) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) else: image_feature = image_feature[0] image_feature = torch.cat( - (image_feature, - self.image_newline[None].to(image_feature)), - dim=0) + (image_feature, self.image_newline[None].to(image_feature)), dim=0 + ) new_image_features.append(image_feature) return new_image_features @@ -305,9 +321,7 @@ def _process_image_input( if isinstance(image_features, torch.Tensor): return self.multi_modal_projector(image_features) - feature_sizes = [ - image_feature.shape[0] for image_feature in image_features - ] + feature_sizes = [image_feature.shape[0] for image_feature in image_features] image_embeds = self.multi_modal_projector(torch.cat(image_features)) image_embeds = torch.split(image_embeds, feature_sizes) @@ -315,7 +329,8 @@ def _process_image_input( return self.pack_image_features(image_embeds, image_sizes) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[MiniMaxVL01ImageInputs]: + self, **kwargs: object + ) -> Optional[MiniMaxVL01ImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -324,34 +339,21 @@ def _parse_and_validate_image_input( return None if pixel_values is not None and image_sizes is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - - if not isinstance(image_sizes, (torch.Tensor, list)): - raise ValueError("Incorrect type of image sizes. " - f"Got type: {type(image_sizes)}") - return MiniMaxVL01ImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values), - image_sizes=flatten_bn(image_sizes, concat=True), + pixel_values=pixel_values, + image_sizes=image_sizes, ) if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. 
" - f"Got type: {type(image_embeds)}") - return MiniMaxVL01ImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -366,7 +368,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - if intermediate_tensors is not None: inputs_embeds = None elif inputs_embeds is None: @@ -378,10 +379,9 @@ def forward( ) input_ids = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -391,7 +391,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index d7c48758cca7..8e74425c5dbd 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -3,43 +3,58 @@ from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence -from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, - Union) +from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union import torch import torch.nn as nn -from transformers import (BatchFeature, Mistral3Config, PixtralVisionConfig, - PretrainedConfig) +from transformers import ( + BatchFeature, + Mistral3Config, + PixtralVisionConfig, + PretrainedConfig, +) from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - InputProcessingContext, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, + 
PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vision_encoder_info @@ -67,38 +82,43 @@ class Mistral3PatchMerger(nn.Module): Learned merging of spatial_merge_size ** 2 patches """ - def __init__(self, vision_hidden_size: int, spatial_merge_size: int, - patch_size: int): + def __init__( + self, vision_hidden_size: int, spatial_merge_size: int, patch_size: int + ): super().__init__() self.vision_hidden_size = vision_hidden_size self.spatial_merge_size = spatial_merge_size self.patch_size = patch_size - self.merging_layer = nn.Linear(vision_hidden_size * - self.spatial_merge_size**2, - vision_hidden_size, - bias=False) + self.merging_layer = nn.Linear( + vision_hidden_size * self.spatial_merge_size**2, + vision_hidden_size, + bias=False, + ) - def forward(self, image_features: torch.Tensor, - image_sizes: torch.Tensor) -> torch.Tensor: - image_sizes = [(image_size[0] // self.patch_size, - image_size[1] // self.patch_size) - for image_size in image_sizes] + def forward( + self, image_features: torch.Tensor, image_sizes: torch.Tensor + ) -> torch.Tensor: + image_sizes = [ + (image_size[0] // self.patch_size, image_size[1] // self.patch_size) + for image_size in image_sizes + ] tokens_per_image = [h * w for h, w in image_sizes] d = image_features.shape[-1] permuted_tensor = [] for image_index, image_tokens in enumerate( - image_features.split(tokens_per_image)): + image_features.split(tokens_per_image) + ): # Reshape image_tokens into a 2D grid h, w = image_sizes[image_index] - image_grid = image_tokens.view(h, w, d).permute(2, 0, - 1).unsqueeze(0) + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0) grid = torch.nn.functional.unfold( image_grid, kernel_size=self.spatial_merge_size, - stride=self.spatial_merge_size) + stride=self.spatial_merge_size, + ) grid = grid.view(d * self.spatial_merge_size**2, -1).t() permuted_tensor.append(grid) @@ -108,38 +128,45 @@ def forward(self, image_features: torch.Tensor, class Mistral3MultiModalProjector(nn.Module): - - def __init__(self, - vision_hidden_size: int, - text_hidden_size: int, - spatial_merge_size: int, - patch_size: int, - projector_hidden_act: str, - multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + spatial_merge_size: int, + patch_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.norm = RMSNorm(vision_hidden_size, eps=1e-5) self.patch_merger = Mistral3PatchMerger( vision_hidden_size=vision_hidden_size, spatial_merge_size=spatial_merge_size, - patch_size=patch_size) + patch_size=patch_size, + ) - self.linear_1 = ColumnParallelLinear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - 
quant_config=quant_config, - prefix=f"{prefix}.linear_1") + self.linear_1 = ColumnParallelLinear( + vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = RowParallelLinear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_2") - - def forward(self, image_features: torch.Tensor, - image_sizes: torch.Tensor) -> torch.Tensor: + self.linear_2 = RowParallelLinear( + text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) + + def forward( + self, image_features: torch.Tensor, image_sizes: torch.Tensor + ) -> torch.Tensor: image_features = self.norm(image_features) image_features = self.patch_merger(image_features, image_sizes) hidden_states, _ = self.linear_1(image_features) @@ -160,7 +187,6 @@ class LlavaLikeProcessor(Protocol): class BaseLlavaProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> LlavaLikeConfig: return self.ctx.get_hf_config(Mistral3Config) @@ -196,7 +222,6 @@ def get_image_size_with_most_features(self) -> ImageSize: class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -213,29 +238,26 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class Mistral3ProcessingInfo(BaseLlavaProcessingInfo): - def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor(PixtralProcessor, **kwargs) -class Mistral3MultiModalProcessor( - BaseMultiModalProcessor[Mistral3ProcessingInfo]): - +class Mistral3MultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -252,7 +274,6 @@ def _call_hf_processor( pixel_values = processed_outputs.get("pixel_values") if pixel_values is not None: - # Avoid padding since we need the output for each image to be # independent of other images for the cache to work correctly image_sizes = processed_outputs["image_sizes"] @@ -316,7 +337,8 @@ def get_replacement(item_idx: int): def _build_mistral3_info( - ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo: + ctx: InputProcessingContext, +) -> BaseLlavaProcessingInfo: hf_config = ctx.get_hf_config(Mistral3Config) assert isinstance(hf_config.vision_config, PixtralVisionConfig) return Mistral3ProcessingInfo(ctx) @@ -339,7 +361,7 @@ def _build_mistral3_processor( def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: """Determine the number of hidden layers to initialize up to in the visual encoder. - + Args: hf_config: Model config with vision feature layer(s). 
""" @@ -350,10 +372,10 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: return _get_layer_index(feature_layers, num_hidden_layers) # If we have multiple feature layers, initialize up to the deepest one elif isinstance(feature_layers, (list, tuple)): - return max( - _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) - raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" - " is not supported") + return max(_get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError( + f"vision_layer_feature type: {type(feature_layers)} is not supported" + ) def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: @@ -396,13 +418,16 @@ def init_vision_tower_for_llava( @MULTIMODAL_REGISTRY.register_processor( _build_mistral3_processor, info=_build_mistral3_info, - dummy_inputs=Mistral3DummyInputsBuilder) -class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, - SupportsMultiModal, SupportsPP): + dummy_inputs=Mistral3DummyInputsBuilder, +) +class Mistral3ForConditionalGeneration( + nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP +): + merge_by_field_config = True packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } hf_to_vllm_mapper = WeightsMapper( @@ -412,7 +437,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -433,11 +459,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # NOTE: These are special cases for Pixtral-12B in the HF-format # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa - if (config.text_config.architectures is None - and config.text_config.model_type == "mistral"): + if ( + config.text_config.architectures is None + and config.text_config.model_type == "mistral" + ): config.text_config.architectures = ["MistralForCausalLM"] - if (config.projector_hidden_act is None - and config.vision_config.hidden_act == "gelu"): + if ( + config.projector_hidden_act is None + and config.vision_config.hidden_act == "gelu" + ): config.projector_hidden_act = "gelu" # TODO: Optionally initializes this for supporting embeddings. 
@@ -446,7 +476,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = Mistral3MultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, @@ -455,7 +486,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: patch_size=config.vision_config.patch_size, multimodal_projector_bias=config.multimodal_projector_bias, quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) else: self.vision_tower = None self.multi_modal_projector = None @@ -467,24 +499,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Mistral3ImagePixelInputs]: + self, **kwargs: object + ) -> Optional[Mistral3ImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: return None - assert pixel_values is not None - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - return Mistral3ImagePixelInputs( type="pixel_values_pixtral", - pixel_values=flatten_bn(pixel_values), + pixel_values=pixel_values, ) def _process_image_input( @@ -494,8 +523,9 @@ def _process_image_input( if image_input["type"] == "image_embeds": return image_input["data"] - image_sizes = [(img.shape[-2], img.shape[-1]) - for img in image_input["pixel_values"]] + image_sizes = [ + (img.shape[-2], img.shape[-1]) for img in image_input["pixel_values"] + ] image_features = self.vision_tower(image_input["pixel_values"]) @@ -507,19 +537,19 @@ def _process_image_input( for image_feature in image_features ] - image_embeds = self.multi_modal_projector(torch.cat(image_features), - image_sizes) + image_embeds = self.multi_modal_projector( + torch.cat(image_features), image_sizes + ) if len(feature_sizes) > 1: image_embeds = torch.split(image_embeds, feature_sizes) else: - image_embeds = (image_embeds, ) + image_embeds = (image_embeds,) return image_embeds def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -576,10 +606,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -589,8 +618,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if 
self.vision_tower is None and self.multi_modal_projector is None: skip_prefixes = ["vision_tower.", "multi_modal_projector."] @@ -605,4 +633,5 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="language_model", connector="multi_modal_projector", - tower_model="vision_tower") + tower_model="vision_tower", + ) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index bebf0b5adac5..37b49349ec12 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" + import typing from collections.abc import Callable, Iterable from itertools import islice @@ -35,26 +36,41 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class MixtralMoE(nn.Module): @@ -66,17 +82,19 @@ class MixtralMoE(nn.Module): across ranks. 
""" - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - dp_size: Optional[int] = None, - prefix: str = "", - enable_eplb: bool = False): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + enable_eplb: bool = False, + ): super().__init__() self.hidden_size = hidden_size @@ -91,38 +109,40 @@ def __init__(self, self.n_routed_experts = num_experts self.n_logical_experts = num_experts - self.n_redundant_experts = ( - parallel_config.eplb_config.num_redundant_experts) - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_redundant_experts = parallel_config.eplb_config.num_redundant_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - params_dtype=params_dtype, - quant_config=None, - prefix=f"{prefix}.gate") - - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - params_dtype=params_dtype, - reduce_results=True, - renormalize=True, - quant_config=quant_config, - tp_size=tp_size, - dp_size=dp_size, - prefix=f"{prefix}.experts", - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts) + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + dp_size=dp_size, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
@@ -135,7 +155,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__( self, config: MixtralConfig, @@ -196,13 +215,15 @@ def __init__( base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -218,7 +239,6 @@ def forward( class MixtralDecoderLayer(nn.Module): - def __init__( self, config: MixtralConfig, @@ -240,7 +260,8 @@ def __init__( rope_theta=rope_theta, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.self_attn") + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, @@ -248,11 +269,12 @@ def __init__( intermediate_size=config.intermediate_size, quant_config=quant_config, prefix=f"{prefix}.block_sparse_moe", - enable_eplb=enable_eplb) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + enable_eplb=enable_eplb, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -265,23 +287,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.block_sparse_moe(hidden_states) return hidden_states, residual @support_torch_compile class MixtralModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -293,8 +312,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -305,8 +327,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.enable_eplb = parallel_config.enable_eplb - self.num_redundant_experts = ( - parallel_config.eplb_config.num_redundant_experts) + self.num_redundant_experts = parallel_config.eplb_config.num_redundant_experts self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -317,12 +338,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=prefix, enable_eplb=self.enable_eplb, ), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - 
make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -347,10 +369,9 @@ def forward( for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -362,10 +383,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", num_experts=self.config.num_local_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=self.num_redundant_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -377,25 +398,27 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): @@ -424,20 +447,23 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name_mapped, self): continue - if ((name_mapped.endswith(".bias") - or name_mapped.endswith("_bias")) - and name_mapped not in params_dict): + if ( + name_mapped.endswith(".bias") or name_mapped.endswith("_bias") + ) and name_mapped not in params_dict: continue param = params_dict[name_mapped] - weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -445,8 +471,9 @@ def load_weights(self, weights: Iterable[tuple[str, if is_expert_weight: continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -457,15 +484,15 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, - MixtureOfExperts): +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -492,8 +519,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.model = MixtralModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = MixtralModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -504,16 +532,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) self.expert_weights = [] self.moe_layers: list[FusedMoE] = [] @@ -524,7 +555,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): continue assert isinstance(layer, MixtralDecoderLayer) if hasattr(layer, "block_sparse_moe") and isinstance( - layer.block_sparse_moe, MixtralMoE): + layer.block_sparse_moe, MixtralMoE + ): example_moe = layer.block_sparse_moe 
self.moe_layers.append(layer.block_sparse_moe.experts) @@ -565,11 +597,11 @@ def update_physical_experts_metadata( assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if hasattr(layer, "block_sparse_moe") and isinstance( - layer.block_sparse_moe, MixtralMoE): + layer.block_sparse_moe, MixtralMoE + ): moe = layer.block_sparse_moe moe.n_local_physical_experts = num_local_physical_experts moe.n_physical_experts = num_physical_experts @@ -586,8 +618,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -597,8 +630,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 9864ca2dc474..b624a6200ab3 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -27,35 +27,49 @@ from transformers.image_utils import SizeDict from transformers.models.llama4 import Llama4Processor from transformers.models.llama4.image_processing_llama4_fast import ( - find_supported_resolutions, get_best_fit) + find_supported_resolutions, + get_best_fit, +) from vllm.attention.layer import MultiHeadAttention from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - InputProcessingContext, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, + 
PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .interfaces import ( + MultiModalEmbeddings, + SupportsEagle3, + SupportsMultiModal, + SupportsPP, +) from .llama4 import Llama4ForCausalLM from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix from .vision import run_dp_sharded_vision_model @@ -72,9 +86,10 @@ class Llama4ImagePatchInputs(TensorSchema): type: Literal["pixel_values"] = "pixel_values" - flat_data: Annotated[torch.Tensor, - TensorShape("total_num_chunks", "num_channels", - "image_size", "image_size")] + flat_data: Annotated[ + torch.Tensor, + TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"), + ] patches_per_image: Annotated[torch.Tensor, TensorShape("batch_size")] """ @@ -93,7 +108,6 @@ class Llama4ImagePatchInputs(TensorSchema): class Llama4VisionMLP(nn.Module): - def __init__( self, input_size: int, @@ -135,7 +149,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Llama4MultiModalProjector(nn.Module): - def __init__( self, config, @@ -165,9 +178,9 @@ def pixel_shuffle(input_tensor, shuffle_ratio): input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) batch_size, height, width, channels = input_tensor.size() - reshaped_tensor = input_tensor.view(batch_size, height, - int(width * shuffle_ratio), - int(channels / shuffle_ratio)) + reshaped_tensor = input_tensor.view( + batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio) + ) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() reshaped_tensor = reshaped_tensor.view( @@ -178,13 +191,11 @@ def pixel_shuffle(input_tensor, shuffle_ratio): ) reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() - output_tensor = reshaped_tensor.view(batch_size, -1, - reshaped_tensor.shape[-1]) + output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1]) return output_tensor class Llama4VisionPixelShuffleMLP(nn.Module): - def __init__( self, config, @@ -194,8 +205,9 @@ def __init__( ): super().__init__() self.pixel_shuffle_ratio = config.pixel_shuffle_ratio - self.inner_dim = int(config.projector_input_dim // - (self.pixel_shuffle_ratio**2)) + self.inner_dim = int( + config.projector_input_dim // (self.pixel_shuffle_ratio**2) + ) self.output_dim = config.projector_output_dim self.mlp = Llama4VisionMLP( input_size=config.intermediate_size, @@ -209,13 +221,11 @@ def __init__( ) def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: - encoded_patches = pixel_shuffle(encoded_patches, - self.pixel_shuffle_ratio) + encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio) return self.mlp(encoded_patches) class Llama4VisionAttention(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -225,8 +235,9 @@ def __init__( ): super().__init__() self.config = config - self.tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // self.num_heads @@ -237,8 +248,9 @@ def __init__( self.attention_dropout = config.attention_dropout self.scaling = self.head_dim**-0.5 - self.attn = MultiHeadAttention(self.num_local_heads, 
self.head_dim, - self.scaling) + self.attn = MultiHeadAttention( + self.num_local_heads, self.head_dim, self.scaling + ) if use_data_parallel: self.qkv_proj = ReplicatedLinear( @@ -277,7 +289,7 @@ def __init__( head_size=self.head_dim, rotary_dim=config.hidden_size // config.num_attention_heads // 2, # number of image patches - max_position=(config.image_size // config.patch_size)**2, + max_position=(config.image_size // config.patch_size) ** 2, base=config.rope_theta, rope_scaling={"rope_type": "mllama4"}, is_neox_style=False, @@ -308,7 +320,6 @@ def forward( class Llama4VisionEncoderLayer(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -357,12 +368,11 @@ def forward( hidden_state = self.mlp(hidden_state) hidden_state = residual + hidden_state - outputs = (hidden_state, ) + outputs = (hidden_state,) return outputs class Llama4VisionEncoder(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -372,14 +382,17 @@ def __init__( ): super().__init__() self.config = config - self.layers = nn.ModuleList([ - Llama4VisionEncoderLayer( - config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}", - use_data_parallel=use_data_parallel, - ) for layer_idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Llama4VisionEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -387,9 +400,9 @@ def forward( ) -> torch.Tensor: r""" Args: - hidden_states: Input tensor of shape + hidden_states: Input tensor of shape (batch_size, sequence_length, hidden_size). - Hidden states from the model embeddings, representing + Hidden states from the model embeddings, representing the input tokens. associated vectors than the model's internal embedding lookup matrix. 
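A quick shape check for the `pixel_shuffle` helper reformatted earlier in mllama4.py: the number of vision tokens shrinks by `shuffle_ratio**2` while the channel width grows by the same factor. The sizes below assume a `pixel_shuffle_ratio` of 0.5 and are illustrative only:

import torch

from vllm.model_executor.models.mllama4 import pixel_shuffle

# One image tile with 32 * 32 = 1024 patches and a hidden size of 8.
x = torch.randn(1, 1024, 8)
y = pixel_shuffle(x, shuffle_ratio=0.5)

# 4x fewer tokens, 4x wider channels: (1, 1024, 8) -> (1, 256, 32).
assert y.shape == (1, 256, 32)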
@@ -403,7 +416,6 @@ def forward( class Llama4UnfoldConvolution(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -415,8 +427,7 @@ def __init__( kernel_size = config.patch_size if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) - self.unfold = torch.nn.Unfold(kernel_size=kernel_size, - stride=config.patch_size) + self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) self.linear = ColumnParallelLinear( input_size=config.num_channels * kernel_size[0] * kernel_size[1], output_size=config.hidden_size, @@ -435,7 +446,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Llama4VisionModel(nn.Module): - def __init__( self, config: Llama4VisionConfig, @@ -450,7 +460,7 @@ def __init__( self.hidden_size = config.hidden_size self.num_channels = config.num_channels - self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 self.scale = config.hidden_size**-0.5 self.patch_embedding = Llama4UnfoldConvolution( @@ -460,10 +470,10 @@ def __init__( use_data_parallel=use_data_parallel, ) - self.class_embedding = nn.Parameter(self.scale * - torch.randn(self.hidden_size)) + self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) self.positional_embedding_vlm = nn.Parameter( - self.scale * torch.randn(self.num_patches, self.hidden_size)) + self.scale * torch.randn(self.num_patches, self.hidden_size) + ) # layer norms self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5) @@ -492,8 +502,9 @@ def forward( num_tiles, num_patches, hidden_dim = hidden_state.shape # Add cls token - class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, - hidden_state.shape[-1]) + class_embedding = self.class_embedding.expand( + hidden_state.shape[0], 1, hidden_state.shape[-1] + ) hidden_state = torch.cat([hidden_state, class_embedding], dim=1) num_patches += 1 @@ -505,7 +516,8 @@ def forward( hidden_dim, ) positional_embedding = self.positional_embedding_vlm.to( - dtype=hidden_state.dtype, device=hidden_state.device) + dtype=hidden_state.dtype, device=hidden_state.device + ) hidden_state = hidden_state + positional_embedding hidden_state = self.layernorm_pre(hidden_state) hidden_state = hidden_state.view(num_tiles, -1, hidden_dim) @@ -524,7 +536,6 @@ def forward( class Mllama4ProcessingInfo(BaseProcessingInfo): - def __init__(self, ctx: InputProcessingContext) -> None: super().__init__(ctx) @@ -532,9 +543,9 @@ def get_hf_config(self) -> Llama4Config: return self.ctx.get_hf_config(Llama4Config) def get_hf_processor(self, **kwargs: object) -> Llama4Processor: - return self.ctx.get_hf_processor(Llama4Processor, - use_fast=kwargs.pop("use_fast", True), - **kwargs) + return self.ctx.get_hf_processor( + Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs + ) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: # Although vLLM can support more images from an infra capability @@ -546,13 +557,13 @@ def get_patch_per_chunk(vision_config: Llama4VisionConfig) -> int: image_size = vision_config.image_size patch_size = vision_config.patch_size - assert ( - image_size % - patch_size == 0), f"chunk size {image_size} should be multiple of " + assert image_size % patch_size == 0, ( + f"chunk size {image_size} should be multiple of " + ) f"patch_size {patch_size}" ds_ratio = int(round(1.0 / (vision_config.pixel_shuffle_ratio**2))) - return (image_size // patch_size)**2 // ds_ratio + return (image_size // patch_size) ** 2 // 
ds_ratio def get_max_num_tiles(self) -> int: image_processor = self.get_hf_processor().image_processor @@ -562,13 +573,10 @@ def get_image_size_with_most_features(self) -> ImageSize: vision_config = self.get_hf_config().vision_config image_size = vision_config.image_size # Result in the max possible feature size (h:w = 16:1) - return ImageSize(height=self.get_max_num_tiles() * image_size, - width=image_size) - + return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size) -class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] - ): +class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -592,15 +600,16 @@ def _call_hf_processor( vision_config = self.info.get_hf_config().vision_config if processed_outputs.get("pixel_values") is not None: - assert ( - "images" in mm_data - ), "images expected to be in mm_data when pixel_values is present" + assert "images" in mm_data, ( + "images expected to be in mm_data when pixel_values is present" + ) images = mm_data["images"] - parsed_images = (self._get_data_parser().parse_mm_data({ - "image": - images - }).get_items("image", ImageProcessorItems)) + parsed_images = ( + self._get_data_parser() + .parse_mm_data({"image": images}) + .get_items("image", ImageProcessorItems) + ) tile_size = vision_config.image_size possible_resolutions = find_supported_resolutions( @@ -612,20 +621,20 @@ def _call_hf_processor( (image.size[1], image.size[0]), torch.tensor(possible_resolutions), resize_to_max_canvas=image_processor.resize_to_max_canvas, - ) for image in parsed_images + ) + for image in parsed_images ] # TODO tile height/width do not necessarily need to match - aspect_ratios = [(image_size[0] // tile_size, - image_size[1] // tile_size) - for image_size in best_fit_sizes] + aspect_ratios = [ + (image_size[0] // tile_size, image_size[1] // tile_size) + for image_size in best_fit_sizes + ] patches_per_image = [ - 1 if r_h * r_w == 1 else 1 + r_h * r_w - for (r_h, r_w) in aspect_ratios + 1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios ] processed_outputs["aspect_ratios"] = torch.tensor(aspect_ratios) - processed_outputs["patches_per_image"] = torch.tensor( - patches_per_image) + processed_outputs["patches_per_image"] = torch.tensor(patches_per_image) return processed_outputs @@ -637,7 +646,8 @@ def _get_mm_fields_config( patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0)) return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", patches_per_image), + "image", patches_per_image + ), patches_per_image=MultiModalFieldConfig.batched("image"), aspect_ratios=MultiModalFieldConfig.batched("image"), ) @@ -677,7 +687,6 @@ def get_replacement(item_idx: int): class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -694,17 +703,17 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - (target_width, - target_height) = self.info.get_image_size_with_most_features() + (target_width, target_height) = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + 
num_images=num_images, + overrides=image_overrides, + ) } @@ -713,8 +722,9 @@ def get_dummy_mm_data( info=Mllama4ProcessingInfo, dummy_inputs=Mllama4DummyInputsBuilder, ) -class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +class Llama4ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 +): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -747,24 +757,42 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): use_data_parallel=self.use_data_parallel, ) self.multi_modal_projector = Llama4MultiModalProjector( - self.config, - None, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + self.config, None, prefix=maybe_prefix(prefix, "multi_modal_projector") + ) else: self.vision_model = None self.multi_modal_projector = None self.language_model = initialize_model( - vllm_config=vllm_config.with_hf_config(config.text_config, - ["LlamaForCausalLM"]), + vllm_config=vllm_config.with_hf_config( + config.text_config, ["LlamaForCausalLM"] + ), prefix=maybe_prefix(prefix, "language_model"), model_class=Llama4ForCausalLM, ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) + + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + """Set which layers should output auxiliary hidden states for EAGLE3.""" + # Delegate to underlying language model (Llama4ForCausalLM) + assert hasattr(self.language_model, "set_aux_hidden_state_layers") + self.language_model.set_aux_hidden_state_layers(layers) + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + """Get the layer indices for auxiliary hidden state outputs. + + Note: The GPU model runner will override this with layers from + the speculative config if available, providing dynamic configuration. 
+ """ + # Delegate to underlying language model (Llama4ForCausalLM) + assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers") + return self.language_model.get_eagle3_aux_hidden_state_layers() def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Llama4ImagePatchInputs]: + self, **kwargs: object + ) -> Optional[Llama4ImagePatchInputs]: # num_images, 1, num_chunks, channel, image_size, image_size pixel_values = kwargs.pop("pixel_values", None) if pixel_values is None: @@ -786,8 +814,8 @@ def _parse_and_validate_image_input( ) def _process_image_input( - self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings: - + self, image_input: Llama4ImagePatchInputs + ) -> MultiModalEmbeddings: assert self.vision_model and self.multi_modal_projector flat_data = image_input["flat_data"] patches_per_image = image_input["patches_per_image"].tolist() @@ -795,12 +823,12 @@ def _process_image_input( # shard image input if self.use_data_parallel: vision_embeddings_flat = run_dp_sharded_vision_model( - flat_data, self.vision_model) + flat_data, self.vision_model + ) else: vision_embeddings_flat = self.vision_model(flat_data) - vision_embeddings_flat = self.multi_modal_projector( - vision_embeddings_flat) + vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat) return [ img.flatten(0, 1) @@ -828,8 +856,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - return self.language_model(input_ids, positions, intermediate_tensors, - inputs_embeds) + return self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) def compute_logits( self, @@ -841,8 +870,7 @@ def separate_weights( self, weights: Iterable[tuple[str, torch.Tensor]], prefix: str, - ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[ - str, torch.Tensor]]]: + ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]: weights1, weights2 = tee(weights, 2) def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]: @@ -884,31 +912,33 @@ def _consolidate_qkv_weights( def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str: """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM format.""" - if name.startswith("model.") or name.startswith( - "language_model.model."): - renamed = name.replace("model.", "language_model.model.", - 1) if name.startswith("model.") else name + if name.startswith("model.") or name.startswith("language_model.model."): + renamed = ( + name.replace("model.", "language_model.model.", 1) + if name.startswith("model.") + else name + ) # Handle expert scale parameters with flat naming - if "feed_forward.experts." in name and ("_input_scale" in name or - "_weight_scale" in name): + if "feed_forward.experts." 
in name and ( + "_input_scale" in name or "_weight_scale" in name + ): # Map checkpoint naming to vLLM's expected naming if "down_proj_input_scale" in renamed: - return renamed.replace("down_proj_input_scale", - "w2_input_scale") + return renamed.replace("down_proj_input_scale", "w2_input_scale") elif "down_proj_weight_scale" in renamed: - return renamed.replace("down_proj_weight_scale", - "w2_weight_scale") + return renamed.replace("down_proj_weight_scale", "w2_weight_scale") elif "gate_up_proj_input_scale" in renamed: - return renamed.replace("gate_up_proj_input_scale", - "w13_input_scale") + return renamed.replace( + "gate_up_proj_input_scale", "w13_input_scale" + ) elif "gate_up_proj_weight_scale" in renamed: - return renamed.replace("gate_up_proj_weight_scale", - "w13_weight_scale") + return renamed.replace( + "gate_up_proj_weight_scale", "w13_weight_scale" + ) return renamed # Handle attention scale parameters - elif "self_attn." in name and (".k_scale" in name - or ".v_scale" in name): + elif "self_attn." in name and (".k_scale" in name or ".v_scale" in name): if ".k_proj.k_scale" in renamed: return renamed.replace(".k_proj.k_scale", ".attn.k_scale") elif ".v_proj.v_scale" in renamed: @@ -919,8 +949,7 @@ def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str: return renamed elif name.startswith("lm_head.weight"): - return name.replace("lm_head.weight", - "language_model.lm_head.weight") + return name.replace("lm_head.weight", "language_model.lm_head.weight") return name @@ -943,7 +972,7 @@ def _separate_and_rename_weights( return language_model_weights, other_weights def _handle_expert_scale_broadcasting( - self, weights: list[tuple[str, torch.Tensor]], params_dict: dict + self, weights: list[tuple[str, torch.Tensor]], params_dict: dict ) -> tuple[list[tuple[str, torch.Tensor]], set[str]]: """Handle expert scale parameters that need broadcasting. @@ -956,12 +985,18 @@ def _handle_expert_scale_broadcasting( for name, weight in weights: # Check if this is an expert scale parameter that needs broadcasting - if ("feed_forward.experts." in name and "scale" in name - and ".shared_expert" not in name): + if ( + "feed_forward.experts." 
in name + and "scale" in name + and ".shared_expert" not in name + ): if name in params_dict: param = params_dict[name] - if (hasattr(param, 'data') and param.data.numel() > 1 - and weight.numel() == 1): + if ( + hasattr(param, "data") + and param.data.numel() > 1 + and weight.numel() == 1 + ): # Broadcast single value to all experts param.data.fill_(weight.item()) updated_params.add(name) @@ -973,10 +1008,12 @@ def _handle_expert_scale_broadcasting( return regular_weights, expert_scale_weights, updated_params - def _load_other_weights(self, other_weights: Iterable[tuple[str, - torch.Tensor]], - params_dict: dict, - stacked_params_mapping: list) -> set[str]: + def _load_other_weights( + self, + other_weights: Iterable[tuple[str, torch.Tensor]], + params_dict: dict, + stacked_params_mapping: list, + ) -> set[str]: """Load non-language-model weights with stacking support.""" updated_params = set() @@ -997,16 +1034,13 @@ def _load_other_weights(self, other_weights: Iterable[tuple[str, else: # Use regular weight loading param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) updated_params.add(name) return updated_params - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), @@ -1023,8 +1057,9 @@ def load_weights(self, weights: Iterable[tuple[str, updated_params: set[str] = set() # Separate and rename weights - language_model_weights, other_weights = ( - self._separate_and_rename_weights(weights)) + language_model_weights, other_weights = self._separate_and_rename_weights( + weights + ) # Skip loading vision model and projector if they're not initialized. 
if self.vision_model is None and self.multi_modal_projector is None: @@ -1032,8 +1067,8 @@ def load_weights(self, weights: Iterable[tuple[str, # Handle expert scale parameters regular_weights, expert_scale_weights, updated_params_from_experts = ( - self._handle_expert_scale_broadcasting(language_model_weights, - params_dict)) + self._handle_expert_scale_broadcasting(language_model_weights, params_dict) + ) updated_params.update(updated_params_from_experts) loader = AutoWeightsLoader(self) @@ -1042,13 +1077,12 @@ def load_weights(self, weights: Iterable[tuple[str, updated_params.update(loaded_language_model_params) if expert_scale_weights: - loaded_expert_scale_params = loader.load_weights( - expert_scale_weights) + loaded_expert_scale_params = loader.load_weights(expert_scale_weights) if loaded_expert_scale_params: updated_params.update(loaded_expert_scale_params) updated_params.update( - self._load_other_weights(other_weights, params_dict, - stacked_params_mapping)) + self._load_other_weights(other_weights, params_dict, stacked_params_mapping) + ) return updated_params diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 0f375134ef00..4901ac74fb28 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -10,7 +10,9 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from .utils import maybe_prefix @@ -74,8 +76,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.n_predict = config.n_predict self.vocab_size = config.vocab_size self.emb_dim = config.emb_dim - self.inner_dim = config.inner_dim if config.inner_dim != 0 \ - else config.emb_dim + self.inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim self.max_speculative_tokens = config.num_lookahead_tokens @@ -83,72 +84,93 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.scale_input = config.scale_input if self.tie_weights: - assert ( - self.n_predict > 1 - ), "You cannot tie weights between stages when only 1 exists" + assert self.n_predict > 1, ( + "You cannot tie weights between stages when only 1 exists" + ) embedding = VocabParallelEmbedding( - config.vocab_size, - self.inner_dim, - org_num_embeddings=config.vocab_size) + config.vocab_size, self.inner_dim, org_num_embeddings=config.vocab_size + ) self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens) # the initial projection from the base model may # have a different size, so that stays separate. 
proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False) proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False) - self.proj = nn.ModuleList([proj_first] + [proj_tied] * - (self.max_speculative_tokens - 1)) - - self.head = nn.ModuleList([ - ParallelLMHead(self.vocab_size, - self.inner_dim, - bias=False, - prefix=maybe_prefix(prefix, f"head.{i}")) - for i in range(self.max_speculative_tokens) - ]) - - ln = MLPSpeculatorLayerNorm(self.inner_dim, - elementwise_scale_and_shift=True) + self.proj = nn.ModuleList( + [proj_first] + [proj_tied] * (self.max_speculative_tokens - 1) + ) + + self.head = nn.ModuleList( + [ + ParallelLMHead( + self.vocab_size, + self.inner_dim, + bias=False, + prefix=maybe_prefix(prefix, f"head.{i}"), + ) + for i in range(self.max_speculative_tokens) + ] + ) + + ln = MLPSpeculatorLayerNorm( + self.inner_dim, elementwise_scale_and_shift=True + ) self.ln = nn.ModuleList([ln] * self.max_speculative_tokens) else: - self.emb = nn.ModuleList([ - VocabParallelEmbedding(config.vocab_size, - self.inner_dim, - org_num_embeddings=config.vocab_size) - for _ in range(self.max_speculative_tokens) - ]) - - self.proj = nn.ModuleList([ - nn.Linear((self.emb_dim if i == 0 else self.inner_dim), - self.inner_dim, - bias=False) - for i in range(self.max_speculative_tokens) - ]) - - self.head = nn.ModuleList([ - ParallelLMHead(self.vocab_size, - self.inner_dim, - bias=False, - prefix=maybe_prefix(prefix, f"head.{i}")) - for i in range(self.max_speculative_tokens) - ]) - self.ln = nn.ModuleList([ - MLPSpeculatorLayerNorm(self.inner_dim, - elementwise_scale_and_shift=True) - for _ in range(self.max_speculative_tokens) - ]) + self.emb = nn.ModuleList( + [ + VocabParallelEmbedding( + config.vocab_size, + self.inner_dim, + org_num_embeddings=config.vocab_size, + ) + for _ in range(self.max_speculative_tokens) + ] + ) + + self.proj = nn.ModuleList( + [ + nn.Linear( + (self.emb_dim if i == 0 else self.inner_dim), + self.inner_dim, + bias=False, + ) + for i in range(self.max_speculative_tokens) + ] + ) + + self.head = nn.ModuleList( + [ + ParallelLMHead( + self.vocab_size, + self.inner_dim, + bias=False, + prefix=maybe_prefix(prefix, f"head.{i}"), + ) + for i in range(self.max_speculative_tokens) + ] + ) + self.ln = nn.ModuleList( + [ + MLPSpeculatorLayerNorm( + self.inner_dim, elementwise_scale_and_shift=True + ) + for _ in range(self.max_speculative_tokens) + ] + ) if self.scale_input: self.ln0 = MLPSpeculatorLayerNorm( - self.emb_dim, elementwise_scale_and_shift=False) + self.emb_dim, elementwise_scale_and_shift=False + ) - self.state_weight = 0.5**(0.5 / config.n_predict) - self.emb_weight = math.sqrt( - (1 - self.state_weight**2) * (self.inner_dim / 2)) + self.state_weight = 0.5 ** (0.5 / config.n_predict) + self.emb_weight = math.sqrt((1 - self.state_weight**2) * (self.inner_dim / 2)) self.activation = nn.GELU() self.config = config - self.logits_processor = LogitsProcessor(config.vocab_size, - config.vocab_size, 1.0) + self.logits_processor = LogitsProcessor( + config.vocab_size, config.vocab_size, 1.0 + ) # NOTE(woosuk): This method is commented out because it is old code # using V0. We should either port it to V1 or remove it. 
@@ -201,16 +223,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # return next_tokens - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: name = name.replace("speculator.", "") param = params_dict.get(name) if param is not None: - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index e4a51b369737..2e3b76aaaabc 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -11,16 +11,17 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.pooler import (ClassifierPooler, - DispatchPooler, Pooler, - PoolingMethod, - PoolingParamsUpdate, - PoolingType) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + DispatchPooler, + Pooler, + PoolingMethod, + PoolingParamsUpdate, + PoolingType, +) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask @@ -32,16 +33,15 @@ class ModernBertEmbeddings(nn.Module): - def __init__(self, config: ModernBertConfig): - super().__init__() self.config = config - self.tok_embeddings = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) - self.norm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps, - bias=config.norm_bias) + self.tok_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.norm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps, bias=config.norm_bias + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.tok_embeddings(input_ids) @@ -60,24 +60,20 @@ def forward( class ModernBertRotaryEmbedding(RotaryEmbedding): - - def __init__(self, config: ModernBertConfig, head_size: int, dim: int, - base: float): + def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float): super().__init__( head_size=head_size, rotary_dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, is_neox_style=True, - dtype=torch.float16) + dtype=torch.float16, + ) self.config = config class ModernBertAttention(nn.Module): - - def __init__(self, - config: ModernBertConfig, - layer_id: Optional[int] = None): + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -99,24 +95,27 @@ def __init__(self, sliding_window = None if layer_id % config.global_attn_every_n_layers != 0: sliding_window = 
config.local_attention // 2 - rope_theta = config.local_rope_theta if config.local_rope_theta \ - is not None else config.global_rope_theta + rope_theta = ( + config.local_rope_theta + if config.local_rope_theta is not None + else config.global_rope_theta + ) else: rope_theta = config.global_rope_theta - self.rotary_emb = ModernBertRotaryEmbedding(config=config, - head_size=self.head_dim, - dim=self.head_dim, - base=rope_theta) + self.rotary_emb = ModernBertRotaryEmbedding( + config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta + ) self.attn = EncoderOnlyAttention( self.num_heads, self.head_dim, self.scaling, prefix=f"{layer_id}.attn", - per_layer_sliding_window=sliding_window) - self.Wo = RowParallelLinear(config.hidden_size, - config.hidden_size, - bias=config.attention_bias) + per_layer_sliding_window=sliding_window, + ) + self.Wo = RowParallelLinear( + config.hidden_size, config.hidden_size, bias=config.attention_bias + ) def forward( self, @@ -133,17 +132,16 @@ def forward( class ModernBertMLP(nn.Module): - def __init__(self, config: ModernBertConfig): super().__init__() self.config = config - self.Wi = nn.Linear(config.hidden_size, - int(config.intermediate_size) * 2, - bias=config.mlp_bias) + self.Wi = nn.Linear( + config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias + ) self.act = nn.GELU() - self.Wo = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=config.mlp_bias) + self.Wo = RowParallelLinear( + config.intermediate_size, config.hidden_size, bias=config.mlp_bias + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input, gate = self.Wi(hidden_states).chunk(2, dim=-1) @@ -151,23 +149,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class ModernBertLayer(nn.Module): - - def __init__(self, - config: ModernBertConfig, - prefix: str = "", - layer_id: Optional[int] = None): + def __init__( + self, config: ModernBertConfig, prefix: str = "", layer_id: Optional[int] = None + ): super().__init__() self.config = config if layer_id == 0: self.attn_norm = nn.Identity() else: - self.attn_norm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps, - bias=config.norm_bias) + self.attn_norm = nn.LayerNorm( + config.hidden_size, eps=config.norm_eps, bias=config.norm_bias + ) self.attn = ModernBertAttention(config=config, layer_id=layer_id) - self.mlp_norm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps, - bias=config.norm_bias) + self.mlp_norm = nn.LayerNorm( + config.hidden_size, eps=config.norm_eps, bias=config.norm_bias + ) self.mlp = ModernBertMLP(config) def forward( @@ -175,8 +171,9 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.Tensor, ) -> torch.Tensor: - attn_outputs = self.attn(hidden_states=self.attn_norm(hidden_states), - position_ids=position_ids) + attn_outputs = self.attn( + hidden_states=self.attn_norm(hidden_states), position_ids=position_ids + ) hidden_states = hidden_states + attn_outputs mlp_output = self.mlp(self.mlp_norm(hidden_states)) hidden_states = hidden_states + mlp_output @@ -184,14 +181,15 @@ def forward( class ModernBertEncoderLayer(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - self.layers = nn.ModuleList([ - ModernBertLayer(config=config, layer_id=layer_id) - for layer_id in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + ModernBertLayer(config=config, layer_id=layer_id) + for layer_id in 
range(config.num_hidden_layers) + ] + ) def forward( self, @@ -207,7 +205,8 @@ def forward( @default_pooling_type("CLS") class ModernBertModel(nn.Module): hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={"layers.": "encoder_layer.layers."}) + orig_to_new_prefix={"layers.": "encoder_layer.layers."} + ) def __init__( self, @@ -219,15 +218,14 @@ def __init__( self.config = config self.embeddings = ModernBertEmbeddings(config) self.encoder_layer = ModernBertEncoderLayer(vllm_config) - self.final_norm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps, - bias=config.norm_bias) + self.final_norm = nn.LayerNorm( + config.hidden_size, eps=config.norm_eps, bias=config.norm_bias + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embeddings.get_input_embeddings(input_ids) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weights = self.hf_to_vllm_mapper.apply(weights) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -235,8 +233,7 @@ def load_weights(self, weights: Iterable[tuple[str, if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -251,8 +248,9 @@ def forward( if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embeddings(input_ids=input_ids, - inputs_embeds=inputs_embeds) + hidden_states = self.embeddings( + input_ids=input_ids, inputs_embeds=inputs_embeds + ) outputs = self.encoder_layer( hidden_states=hidden_states, @@ -263,18 +261,18 @@ def forward( class ModernBertPooler(Pooler): - def __init__(self, config: ModernBertConfig): super().__init__() pooling_type = PoolingType[config.classifier_pooling.upper()] self.pooling = PoolingMethod.from_pooling_type(pooling_type) - self.dense = nn.Linear(config.hidden_size, config.hidden_size, - config.classifier_bias) + self.dense = nn.Linear( + config.hidden_size, config.hidden_size, config.classifier_bias + ) self.act = nn.GELU() - self.norm = nn.LayerNorm(config.hidden_size, - eps=config.norm_eps, - bias=config.norm_bias) + self.norm = nn.LayerNorm( + config.hidden_size, eps=config.norm_eps, bias=config.norm_bias + ) def get_supported_tasks(self) -> Set[PoolingTask]: return self.pooling.get_supported_tasks() @@ -303,53 +301,55 @@ def forward( @default_pooling_type("CLS") class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): - is_pooling_model = True def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config - self.model = ModernBertModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "modernbert")) - self.classifier = nn.Linear(config.hidden_size, - config.num_labels, - dtype=vllm_config.model_config.head_dtype) + self.model = ModernBertModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert") + ) + self.classifier = nn.Linear( + config.hidden_size, + config.num_labels, + dtype=vllm_config.model_config.head_dtype, + ) self.pooling = ModernBertPooler(config) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - 
Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=self.pooling, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=self.pooling, - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=self.pooling, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=self.pooling, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - self_weights = [] def weight_filter(): for name, weight in weights: if name.startswith("model."): - yield name[len("model."):], weight + yield name[len("model.") :], weight else: self_weights.append((name, weight)) @@ -360,13 +360,11 @@ def weight_filter(): for name, loaded_weight in self_weights: if name.startswith("classifier"): param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) if name.startswith("head"): - param = params_dict["pooling." + name[len("head") + 1:]] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + param = params_dict["pooling." + name[len("head") + 1 :]] + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) def forward( diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py index 11a2a384c165..666796d835a3 100644 --- a/vllm/model_executor/models/module_mapping.py +++ b/vllm/model_executor/models/module_mapping.py @@ -54,19 +54,22 @@ class MultiModelKeys(ModelKeys): generator: list[str] = field(default_factory=list) @staticmethod - def from_string_field(language_model: Union[str, list[str]] = None, - connector: Union[str, list[str]] = None, - tower_model: Union[str, list[str]] = None, - generator: Union[str, list[str]] = None, - **kwargs) -> 'MultiModelKeys': - + def from_string_field( + language_model: Union[str, list[str]] = None, + connector: Union[str, list[str]] = None, + tower_model: Union[str, list[str]] = None, + generator: Union[str, list[str]] = None, + **kwargs, + ) -> "MultiModelKeys": def to_list(value): if value is None: return [] return [value] if isinstance(value, str) else list(value) - return MultiModelKeys(language_model=to_list(language_model), - connector=to_list(connector), - tower_model=to_list(tower_model), - generator=to_list(generator), - **kwargs) + return MultiModelKeys( + language_model=to_list(language_model), + connector=to_list(connector), + tower_model=to_list(tower_model), + generator=to_list(generator), + **kwargs, + ) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index a77a2eb0f5a8..734841d0dc98 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -13,8 +13,7 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from transformers import (BatchFeature, 
PretrainedConfig, ProcessorMixin, - TensorType) +from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput @@ -23,43 +22,65 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather) -from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU, - SiluAndMul) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) +from vllm.model_executor.layers.activation import MulAndSilu, QuickGELU, SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptIndexTargets, - PromptInsertion, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptInsertion, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP, SupportsQuant) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) # TODO: hard-coded for now. Consider making it configurable. 
VIT_LAYERS = [-2, -9] @@ -81,16 +102,22 @@ class MolmoImageInputs(TensorSchema): - tp: Token sequence positions - pd: Patch dimension """ - images: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"})] + + images: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}), + ] # Number of crops may vary per batch and image, so pass it as a list. - image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]], - TensorShape("bn", "nc", "np", dynamic_dims={"nc"})] + image_masks: Annotated[ + Optional[Union[torch.Tensor, list[torch.Tensor]]], + TensorShape("bn", "nc", "np", dynamic_dims={"nc"}), + ] feat_is_patch: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "nc", "tp", dynamic_dims={"nc"})] + TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}), + ] # A boolean mask indicating which image features correspond to patch tokens. num_crops: Annotated[torch.Tensor, TensorShape("bn")] @@ -110,8 +137,7 @@ class VisionBackboneConfig: image_norm_eps: float = 1e-5 def __post_init__(self): - self.image_default_input_size = tuple( - self.image_default_input_size) # type: ignore[assignment] + self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment] @property def image_num_patch(self): @@ -207,15 +233,13 @@ def __init__( ) self.scale = self.head_dim**-0.5 - self.attn = MultiHeadAttention(self.num_heads, - self.head_dim, - self.scale, - num_kv_heads=self.num_kv_heads) - - def forward(self, - inputs_q: torch.Tensor, - inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: + self.attn = MultiHeadAttention( + self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads + ) + def forward( + self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None + ) -> torch.Tensor: if inputs_kv is not None: inputs_k = inputs_kv inputs_v = inputs_kv @@ -242,8 +266,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.attention = MultiHeadDotProductAttention( - config, quant_config=quant_config) + self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config) self.feed_forward = ViTMLP(config, quant_config) self.attention_norm = nn.LayerNorm( config.image_emb_dim, @@ -269,10 +292,12 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.resblocks = nn.ModuleList([ - ResidualAttentionBlock(config, quant_config) - for _ in range(config.image_num_layers) - ]) + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(config, quant_config) + for _ in range(config.image_num_layers) + ] + ) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: hidden_states = [] @@ -297,19 +322,18 @@ def __init__( super().__init__() scale = config.image_emb_dim**-0.5 self.patch_num = config.image_num_patch - self.class_embedding = nn.Parameter( - torch.randn(config.image_emb_dim) * scale) + self.class_embedding = nn.Parameter(torch.randn(config.image_emb_dim) * scale) self.num_prefix_tokens: int = NUM_PREFIX_TOKENS self.positional_embedding = nn.Parameter( - torch.randn(config.image_num_pos, config.image_emb_dim) * scale) + torch.randn(config.image_num_pos, config.image_emb_dim) * scale + ) image_patch_size = config.image_patch_size self.patch_embedding = nn.Linear( image_patch_size * image_patch_size * 3, config.image_emb_dim, bias=False, ) - self.pre_ln = nn.LayerNorm(config.image_emb_dim, - eps=config.image_norm_eps) + 
self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps) self.transformer = BlockCollection(config, quant_config) def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: @@ -317,8 +341,12 @@ def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: pos_emb = self.positional_embedding[1:] pos_emb = pos_emb.reshape( - (int(math.sqrt(pos_emb.shape[0])), - int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])) + ( + int(math.sqrt(pos_emb.shape[0])), + int(math.sqrt(pos_emb.shape[0])), + pos_emb.shape[1], + ) + ) (patch_num_0, patch_num_1) = patch_num @@ -335,13 +363,12 @@ def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) - x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], - dim=1).to(x.dtype) + x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype) return x - def forward(self, - x: torch.Tensor, - patch_num: Optional[int] = None) -> list[torch.Tensor]: + def forward( + self, x: torch.Tensor, patch_num: Optional[int] = None + ) -> list[torch.Tensor]: """ : param x: (batch_size, num_patch, n_pixels) """ @@ -353,8 +380,8 @@ def forward(self, # class embeddings and positional embeddings x = torch.cat( - [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], - dim=1) + [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1 + ) x = self.add_pos_emb(x, patch_num) x = self.pre_ln(x) @@ -382,8 +409,7 @@ def __init__( assert self.total_num_heads % self.tp_size == 0 self.num_heads = self.total_num_heads // self.tp_size - self.total_num_kv_heads = config.num_key_value_heads \ - or self.total_num_heads + self.total_num_kv_heads = config.num_key_value_heads or self.total_num_heads if self.total_num_kv_heads >= self.tp_size: assert self.total_num_kv_heads % self.tp_size == 0 else: @@ -411,10 +437,10 @@ def __init__( self.q_norm: Optional[nn.Module] = None if config.attention_layer_norm: self.tp_rank = get_tensor_model_parallel_rank() - self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, - eps=config.layer_norm_eps) - self.q_norm = RMSNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, eps=config.layer_norm_eps + ) + self.q_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) # Rotary embeddings. self.rotary_emb = get_rope( @@ -424,13 +450,15 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) # Attention output projection. 
self.o_proj = RowParallelLinear( @@ -440,16 +468,16 @@ def __init__( quant_config=quant_config, ) - def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) q = self.q_norm(q) k = self.k_norm(k) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k @@ -472,10 +500,12 @@ def forward( class LanguageModelMLP(nn.Module): """Molmo's LLM mlp.""" - def __init__(self, - config: PretrainedConfig, - input_dim: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + def __init__( + self, + config: PretrainedConfig, + input_dim: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size // 2 @@ -547,7 +577,6 @@ def forward( class MolmoDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -557,20 +586,19 @@ def __init__( ) -> None: super().__init__() # Attention block. - self.self_attn = MolmoAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = MolmoAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attn" + ) # MLP block. self.mlp = LanguageModelMLP(config, quant_config=quant_config) # LayerNorm assert config.layer_norm_type == "rms" - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.layer_norm_eps + ) def forward( self, @@ -583,21 +611,18 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class MolmoDecoderNormAfterLayer(MolmoDecoderLayer): - def forward( self, positions: torch.Tensor, @@ -638,16 +663,14 @@ def __init__( (self.image_num_patch[0] + 1) // POOLING_SIZE, (self.image_num_patch[1] + 1) // POOLING_SIZE, ) - self.image_vit = VisionTransformer(vision_config, - quant_config=quant_config) + self.image_vit = VisionTransformer(vision_config, quant_config=quant_config) self.num_prefix_tokens = self.image_vit.num_prefix_tokens - assert self.num_prefix_tokens in { - 0, 1 - }, "Only 0 or 1 prefix tokens are supported" + assert self.num_prefix_tokens in {0, 1}, ( + "Only 0 or 1 prefix tokens are supported" + ) self.image_pooling_2d = MultiHeadDotProductAttention( - vision_config, - nlayers=len(self.vit_layers), - quant_config=quant_config) + vision_config, nlayers=len(self.vit_layers), 
quant_config=quant_config + ) self.image_projector = ImageProjectorMLP( config, input_dim=vision_config.image_emb_dim, @@ -671,8 +694,7 @@ def encode_image(self, images: torch.Tensor) -> torch.Tensor: """ B, T, N, D = images.shape - mask = ~torch.all( - images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) + mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) images = images.view(B * T, N, D) image_features = self.image_vit(images) @@ -707,21 +729,22 @@ def forward( assert image_masks is not None pad_embed = self.pad_embed[:, None, None, None, :] all_pad = image_masks == 0 - partial_pad = torch.logical_and( - image_masks < 1, - torch.logical_not(all_pad)).to(dtype=torch.float32) + partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to( + dtype=torch.float32 + ) all_pad = all_pad.to(dtype=torch.float32) - image_features = image_features + pad_embed[0] * torch.unsqueeze( - all_pad, -1) + image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1) image_features = image_features + pad_embed[1] * torch.unsqueeze( - partial_pad, -1) + partial_pad, -1 + ) image_features = image_features.to(og_dtype) image_features = image_features.reshape( - (batch_size, num_image) + self.image_num_patch + (-1, ), ) + (batch_size, num_image) + self.image_num_patch + (-1,), + ) - if (missing_w := self.image_num_patch[0] % POOLING_SIZE): + if missing_w := self.image_num_patch[0] % POOLING_SIZE: # Padding for image pooling (see below) image_features = F.pad( image_features, @@ -731,7 +754,7 @@ def forward( # image pooling image_features = rearrange( image_features, - 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', + "b n (h dh) (w dw) c -> (b n h w) (dh dw) c", dh=POOLING_SIZE, dw=POOLING_SIZE, ) @@ -747,8 +770,7 @@ def forward( # image_features: (batch_size, num_image, num_patch, d_model) return image_features - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("merged_linear", "gate_proj", 0), @@ -758,7 +780,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -777,8 +799,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -786,7 +807,6 @@ def load_weights(self, weights: Iterable[tuple[str, @support_torch_compile class MolmoModel(nn.Module, SupportsQuant): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -804,21 +824,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config, ) - decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \ - else MolmoDecoderLayer + decoder_layer = ( + MolmoDecoderNormAfterLayer if config.norm_after else MolmoDecoderLayer + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: decoder_layer( - config, 
cache_config, quant_config, prefix=prefix), + config, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) assert config.layer_norm_type == "rms" self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -849,18 +871,16 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) if residual is not None: hidden_states, _ = self.norm(hidden_states, residual) else: hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -871,8 +891,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -939,8 +958,12 @@ def get_patches_grid_size( def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]: - tilings = [(i, j) for i in range(1, max_num + 1) - for j in range(1, max_num + 1) if i * j <= max_num] + tilings = [ + (i, j) + for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num + ] return sorted(tilings, key=lambda x: x[0] * x[1]) @@ -1128,7 +1151,8 @@ def __call__( **kwargs, ) -> BatchFeature: outputs = self.processor.process( # type: ignore - text, images, **kwargs) + text, images, **kwargs + ) if images is None: images = [] @@ -1146,7 +1170,8 @@ def __call__( self.select_tiling( image_width=image.size[0], image_height=image.size[1], - ) for image in images + ) + for image in images ] # For each image: tiling_h * tiling_w + extra num_crops = torch.tensor(tilings).prod(-1) + 1 @@ -1160,7 +1185,6 @@ def __call__( class MolmoProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper: processor = self.ctx.get_hf_processor(**kwargs) return MolmoProcessorWrapper(processor) @@ -1209,8 +1233,7 @@ def get_image_size_with_most_features(self) -> ImageSize: ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") @@ -1219,7 +1242,6 @@ def get_image_size_with_most_features(self) -> ImageSize: class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -1229,23 +1251,22 @@ def get_dummy_mm_data( mm_counts: Mapping[str, int], mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - 
self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): - def _apply_hf_processor_tokens_only( self, prompt_tokens: list[int], @@ -1263,7 +1284,7 @@ def _apply_hf_processor_tokens_only( processor, # type: ignore dict(tokens=tokens), ) - prompt_ids, = processed_data.pop("input_ids").tolist() + (prompt_ids,) = processed_data.pop("input_ids").tolist() return prompt_ids @@ -1277,10 +1298,8 @@ def _get_mm_fields_config( return dict( images=MultiModalFieldConfig.flat_from_sizes("image", num_crops), - image_masks=MultiModalFieldConfig.flat_from_sizes( - "image", num_crops), - feat_is_patch=MultiModalFieldConfig.flat_from_sizes( - "image", num_crops), + image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops), + feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops), num_crops=MultiModalFieldConfig.batched("image"), img_patch_id=MultiModalFieldConfig.shared("image", num_images), ) @@ -1303,8 +1322,7 @@ def _get_prompt_updates( img_end_id = processor.im_end_id extra_row = [img_patch_id] * image_token_length_w + [img_col_id] - extra_joint = ([img_start_id] + extra_row * image_token_length_h + - [img_end_id]) + extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id] def get_insertion_molmo(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) @@ -1315,10 +1333,12 @@ def get_insertion_molmo(item_idx: int): image_height=image_size.height, ) - joint_row = ([img_patch_id] * ((ncols + 1) // pooling_size) + - [img_col_id]) - joint = ([img_start_id] + joint_row * - ((nrows + 1) // pooling_size) + [img_end_id]) + joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id] + joint = ( + [img_start_id] + + joint_row * ((nrows + 1) // pooling_size) + + [img_end_id] + ) return PromptUpdateDetails.select_token_id( extra_joint + joint, @@ -1334,11 +1354,14 @@ def get_insertion_molmo(item_idx: int): ] -@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor, - info=MolmoProcessingInfo, - dummy_inputs=MolmoDummyInputsBuilder) -class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, - SupportsQuant): +@MULTIMODAL_REGISTRY.register_processor( + MolmoMultiModalProcessor, + info=MolmoProcessingInfo, + dummy_inputs=MolmoDummyInputsBuilder, +) +class MolmoForCausalLM( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant +): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ # vision backbone mapping @@ -1370,7 +1393,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, packed_modules_mapping = { "qkv_proj": ["qkv_proj"], "gate_up_proj": ["gate_up_proj"], # language model - "merged_linear": ["gate_proj", "up_proj"] # image_projector + "merged_linear": ["gate_proj", "up_proj"], # image_projector } @classmethod @@ -1391,10 +1414,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config vision_config = VisionBackboneConfig() - self.vision_backbone = MolmoVisionBackbone(config, 
vision_config, - quant_config) - self.model = MolmoModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config) + self.model = MolmoModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.img_patch_id = None if self.config.weight_tying: @@ -1407,11 +1430,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(config.embedding_size - or config.vocab_size) + self.logits_processor = LogitsProcessor( + config.embedding_size or config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( self, @@ -1426,14 +1451,16 @@ def _parse_and_validate_image_input( return None if not isinstance(num_crops, (torch.Tensor, list)): - raise ValueError("Incorrect type of num_crops. " - f"Got type: {type(num_crops)}") + raise ValueError( + f"Incorrect type of num_crops. Got type: {type(num_crops)}" + ) num_crops = flatten_bn(num_crops, concat=True) img_patch_id = kwargs.pop("img_patch_id", None) if not isinstance(img_patch_id, torch.Tensor): - raise ValueError("Incorrect type of img_patch_id. " - f"Got type: {type(img_patch_id)}") + raise ValueError( + f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}" + ) self.img_patch_id = img_patch_id.flatten().unique().item() return MolmoImageInputs( @@ -1454,19 +1481,22 @@ def _process_image_input( # Call the vision backbone on the whole batch at once images_flat = flatten_bn(images, concat=True) - image_masks_flat = (None if image_masks is None else flatten_bn( - image_masks, concat=True)) + image_masks_flat = ( + None if image_masks is None else flatten_bn(image_masks, concat=True) + ) feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True) image_features_flat = self.vision_backbone( images=images_flat.unsqueeze(0), - image_masks=(None if image_masks_flat is None else - image_masks_flat.unsqueeze(0)), + image_masks=( + None if image_masks_flat is None else image_masks_flat.unsqueeze(0) + ), ).squeeze(0) # Only the features corresponding to patch tokens are relevant return [ - feats[f_is_patch] for feats, f_is_patch in zip( + feats[f_is_patch] + for feats, f_is_patch in zip( image_features_flat.split(num_crops.tolist()), feat_is_patch_flat.split(num_crops.tolist()), ) @@ -1475,8 +1505,7 @@ def _process_image_input( def get_language_model(self) -> torch.nn.Module: return self.model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -1491,14 +1520,12 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> torch.Tensor: - if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -1507,7 +1534,6 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - loader = AutoWeightsLoader(self) weights = _get_weights_with_merged_embedding(weights) return 
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -1524,7 +1550,7 @@ def get_mm_mapping(self) -> MultiModelKeys: def _get_weights_with_merged_embedding( - weights: Iterable[tuple[str, torch.Tensor]] + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, torch.Tensor]]: embedding_weights = {} for name, weight in weights: diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py index caa00763fc3d..3bf8fce0de0d 100644 --- a/vllm/model_executor/models/moonvit.py +++ b/vllm/model_executor/models/moonvit.py @@ -91,10 +91,10 @@ def multihead_attention( """ # Unified format legal check assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims" - assert q_cu_seqlens[-1] == q.shape[ - 0], "q_cu_seqlens must sum to q.shape[0]" - assert (k_cu_seqlens[-1] == k.shape[0] == - v.shape[0]), "k_cu_seqlens must sum to k.shape[0]" + assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]" + assert k_cu_seqlens[-1] == k.shape[0] == v.shape[0], ( + "k_cu_seqlens must sum to k.shape[0]" + ) assert q.dtype in [ torch.bfloat16, torch.float16, @@ -137,23 +137,19 @@ def sdpa_attention( k_cu_seqlens: Optional cumulative sequence lengths of k. """ seq_length = q.shape[0] - attention_mask = torch.zeros([1, seq_length, seq_length], - device=q.device, - dtype=torch.bool) + attention_mask = torch.zeros( + [1, seq_length, seq_length], device=q.device, dtype=torch.bool + ) for i in range(1, len(q_cu_seqlens)): attention_mask[ ..., - q_cu_seqlens[i - 1]:q_cu_seqlens[i], - q_cu_seqlens[i - 1]:q_cu_seqlens[i], + q_cu_seqlens[i - 1] : q_cu_seqlens[i], + q_cu_seqlens[i - 1] : q_cu_seqlens[i], ] = True q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - attn_output = F.scaled_dot_product_attention(q, - k, - v, - attention_mask, - dropout_p=0.0) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) return attn_output @@ -172,8 +168,9 @@ def _apply_rope_input_validation(x, freqs_cis): assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype -def apply_rope(xq: torch.Tensor, xk: torch.Tensor, - freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def apply_rope( + xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: """ Args: (The leading dimensions of all inputs should be the same) xq: query, tensor of shape (..., num_heads, head_dim) @@ -189,20 +186,15 @@ def apply_rope(xq: torch.Tensor, xk: torch.Tensor, # ..., num_heads, head_dim/2 xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten( - -2) # ..., num_heads, head_dim - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten( - -2) # ..., num_heads, head_dim + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim return xq_out.type_as(xq), xk_out.type_as(xk) class Learnable2DInterpPosEmb(nn.Module): - - def __init__(self, - height: int, - width: int, - dim: int, - interpolation_mode: str = "bicubic") -> None: + def __init__( + self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic" + ) -> None: super().__init__() self.height = height self.width = width @@ -224,13 +216,16 @@ def forward(self, x: torch.Tensor, grid_hws: 
torch.Tensor) -> torch.Tensor: self.weight.permute((2, 0, 1)).unsqueeze(0), size=shape, mode=self.interpolation_mode, - ).squeeze(0).permute((1, 2, 0)).flatten(end_dim=1)) + ) + .squeeze(0) + .permute((1, 2, 0)) + .flatten(end_dim=1) + ) out = x + torch.cat(pos_embs) return out class MoonVisionPatchEmbed(nn.Module): - def __init__( self, out_dim: int, @@ -240,23 +235,23 @@ def __init__( pos_emb_width: int = 14, ): super().__init__() - assert isinstance( - patch_size, - (int, Sequence)), f"Invalid patch_size type: {type(patch_size)}" + assert isinstance(patch_size, (int, Sequence)), ( + f"Invalid patch_size type: {type(patch_size)}" + ) if isinstance(patch_size, int): patch_size = (patch_size, patch_size) - assert (len(patch_size) == 2 - ), f"Expected patch_size to be a tuple of 2, got {patch_size}" + assert len(patch_size) == 2, ( + f"Expected patch_size to be a tuple of 2, got {patch_size}" + ) self.patch_size = patch_size - self.proj = nn.Conv2d(in_dim, - out_dim, - kernel_size=patch_size, - stride=patch_size) + self.proj = nn.Conv2d( + in_dim, out_dim, kernel_size=patch_size, stride=patch_size + ) - self.pos_emb = Learnable2DInterpPosEmb(height=pos_emb_height, - width=pos_emb_width, - dim=out_dim) + self.pos_emb = Learnable2DInterpPosEmb( + height=pos_emb_height, width=pos_emb_width, dim=out_dim + ) def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: """ @@ -295,12 +290,9 @@ class Rope2DPosEmb(nn.Module): device (str): the device to store the precomputed cis """ - def __init__(self, - dim: int, - max_height: int, - max_width: int, - theta_base=10000, - device="cuda"): + def __init__( + self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda" + ): super().__init__() self.dim = dim assert self.dim % 4 == 0, "dim must be divisible by 4" @@ -325,18 +317,18 @@ def precomputed_freqs_cis(self) -> torch.Tensor: flat_pos = torch.arange(0, N).float().to(self.device) x_pos = flat_pos % self.max_width y_pos = flat_pos // self.max_width - dim_range = (torch.arange(0, self.dim, - 4)[:(self.dim // 4)].float().to(self.device) - ) # C/4 - freqs = 1.0 / (self.theta_base**(dim_range / self.dim)) + dim_range = ( + torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device) + ) # C/4 + freqs = 1.0 / (self.theta_base ** (dim_range / self.dim)) x_freqs = torch.outer(x_pos, freqs).float() # N, C/4 y_freqs = torch.outer(y_pos, freqs).float() # N, C/4 x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4 y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4 # N, C/4, 2 freqs_cis = torch.cat( - [x_cis.unsqueeze(dim=-1), - y_cis.unsqueeze(dim=-1)], dim=-1) + [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1 + ) # max_height, max_width, C/2 freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1) return freqs_cis @@ -349,12 +341,13 @@ def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor: freqs_cis: tensor of shape (sum(t * height * width), dim//2) """ shapes = grid_hws.tolist() - assert all(1 <= h <= self.max_height and 1 <= w <= self.max_width - for h, w in shapes), ( - shapes, - self.max_height, - self.max_width, - ) + assert all( + 1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes + ), ( + shapes, + self.max_height, + self.max_width, + ) freqs_cis = torch.cat( [ self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2) @@ -364,8 +357,9 @@ def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor: ) return freqs_cis - def get_freqs_cis_by_idx(self, 
pos_idx: torch.Tensor, - pos_idx_mask: torch.Tensor) -> torch.Tensor: + def get_freqs_cis_by_idx( + self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor + ) -> torch.Tensor: """ Args: pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token. @@ -374,16 +368,20 @@ def get_freqs_cis_by_idx(self, pos_idx: torch.Tensor, Return: freqs_cis: tensor of shape (..., dim//2) """ - assert (pos_idx.shape[:-1] == pos_idx_mask.shape - and pos_idx.shape[-1] == 2 and pos_idx.ndim - == pos_idx_mask.ndim + 1), (pos_idx.shape, pos_idx_mask.shape) + assert ( + pos_idx.shape[:-1] == pos_idx_mask.shape + and pos_idx.shape[-1] == 2 + and pos_idx.ndim == pos_idx_mask.ndim + 1 + ), (pos_idx.shape, pos_idx_mask.shape) assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype - shp = pos_idx_mask.shape + (self.dim // 2, ) # ..., head_dim/2 - freqs_cis = torch.ones(shp, dtype=torch.complex64, - device=self.device) # ..., head_dim/2 - freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[pos_idx[ - ..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]] + shp = pos_idx_mask.shape + (self.dim // 2,) # ..., head_dim/2 + freqs_cis = torch.ones( + shp, dtype=torch.complex64, device=self.device + ) # ..., head_dim/2 + freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[ + pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask] + ] return freqs_cis @@ -394,23 +392,23 @@ class MLP2(nn.Module): bias: whether to use bias in linear layer. """ - def __init__(self, - dims: list[int], - activation, - bias: bool = True, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + dims: list[int], + activation, + bias: bool = True, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() assert len(dims) == 3 self.use_data_parallel = use_data_parallel - self.fc0 = ReplicatedLinear(dims[0], - dims[1], - bias=bias, - prefix=maybe_prefix(prefix, "fc0")) - self.fc1 = ReplicatedLinear(dims[1], - dims[2], - bias=bias, - prefix=maybe_prefix(prefix, "fc1")) + self.fc0 = ReplicatedLinear( + dims[0], dims[1], bias=bias, prefix=maybe_prefix(prefix, "fc0") + ) + self.fc1 = ReplicatedLinear( + dims[1], dims[2], bias=bias, prefix=maybe_prefix(prefix, "fc1") + ) self.activation = activation def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -421,7 +419,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MoonVitEncoderLayer(nn.Module): - def __init__( self, num_heads: int, @@ -446,18 +443,18 @@ def __init__( self.norm0 = nn.LayerNorm(hidden_dim) self.norm1 = nn.LayerNorm(hidden_dim) self.use_data_parallel = use_data_parallel - self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], - activation, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) - self.wqkv = ReplicatedLinear(hidden_dim, - hidden_dim * 3, - bias=attn_bias, - prefix=f"{prefix}.wqkv") - self.wo = ReplicatedLinear(hidden_dim, - hidden_dim, - bias=attn_bias, - prefix=f"{prefix}.wo") + self.mlp = MLP2( + [hidden_dim, mlp_dim, hidden_dim], + activation, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + self.wqkv = ReplicatedLinear( + hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv" + ) + self.wo = ReplicatedLinear( + hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo" + ) def attention_qkvpacked( self, @@ -484,11 +481,9 @@ def attention_qkvpacked( xq, xk = apply_rope(xq, xk, rope_freqs_cis) attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation] - attn_out = attn_func(xq, - xk, - xv, - q_cu_seqlens=cu_seqlens, - 
k_cu_seqlens=cu_seqlens) + attn_out = attn_func( + xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens + ) attn_out, _ = self.wo(attn_out) return attn_out @@ -507,9 +502,9 @@ def forward( """ residual = hidden_states hidden_states = self.norm0(hidden_states) - attn_out = self.attention_qkvpacked(hidden_states, - cu_seqlens, - rope_freqs_cis=rope_freqs_cis) + attn_out = self.attention_qkvpacked( + hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis + ) hidden_states = residual + attn_out residual = hidden_states @@ -519,7 +514,6 @@ def forward( class MoonVitEncoder(nn.Module): - def __init__( self, hidden_dim: int, @@ -531,27 +525,37 @@ def __init__( super().__init__() self.rope_2d = Rope2DPosEmb( - block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512) + block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512 + ) self.blocks = nn.ModuleList( - [MoonVitEncoderLayer(use_data_parallel=use_data_parallel, \ - prefix=f"{prefix}.blocks.{layer_idx}", \ - **block_cfg) for layer_idx in range(num_layers)]) + [ + MoonVitEncoderLayer( + use_data_parallel=use_data_parallel, + prefix=f"{prefix}.blocks.{layer_idx}", + **block_cfg, + ) + for layer_idx in range(num_layers) + ] + ) self.final_layernorm = nn.LayerNorm(hidden_dim) - def forward(self, hidden_states: torch.Tensor, - grid_hw: torch.Tensor) -> torch.Tensor: - rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens( - grid_hws=grid_hw) + def forward( + self, hidden_states: torch.Tensor, grid_hw: torch.Tensor + ) -> torch.Tensor: + rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw) lengths = torch.cat( - (torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), - (grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device))) + ( + torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype), + (grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device), + ) + ) cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32) for _, block in enumerate(self.blocks): - hidden_states = block(hidden_states, - cu_seqlens, - rope_freqs_cis=rope_freqs_cis) + hidden_states = block( + hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis + ) hidden_states = self.final_layernorm(hidden_states) @@ -559,9 +563,9 @@ def forward(self, hidden_states: torch.Tensor, def patch_merger( - x: torch.Tensor, - grid_hw: torch.Tensor, - merge_kernel_size: list[int, int] = (2, 2), + x: torch.Tensor, + grid_hw: torch.Tensor, + merge_kernel_size: list[int, int] = (2, 2), ) -> list[torch.Tensor]: d_model = x.size(-1) @@ -570,15 +574,17 @@ def patch_merger( for x_shape in grid_hw.tolist(): height, width = x_shape[0], x_shape[1] # Get the current sequence - seq = x[pre_sum:pre_sum + height * width] + seq = x[pre_sum : pre_sum + height * width] # Reshape along self.merge_kernel_size and concat to the last dimension kernel_height, kernel_width = merge_kernel_size new_height, new_width = height // kernel_height, width // kernel_width - reshaped_seq = seq.view(new_height, kernel_height, new_width, - kernel_width, d_model) + reshaped_seq = seq.view( + new_height, kernel_height, new_width, kernel_width, d_model + ) reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous() - padded_seq = reshaped_seq.view(new_height * new_width, - kernel_height * kernel_width, -1) + padded_seq = reshaped_seq.view( + new_height * new_width, kernel_height * kernel_width, -1 + ) outputs.append(padded_seq) pre_sum += height * width @@ -586,7 +592,6 @@ def patch_merger( class MoonVitVLProjector(nn.Module): - def __init__( self, in_channels: int, @@ -596,13 
+601,10 @@ def __init__( out_dim: int = 4096, ): super().__init__() - self.hidden_size = in_channels * merge_kernel_size[ - 0] * merge_kernel_size[1] + self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1] self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps) - self.linear_1 = nn.Linear(self.hidden_size, - self.hidden_size, - bias=True) + self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) self.act = ACT2FN[hidden_act] self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True) @@ -621,12 +623,14 @@ class MoonVitPretrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True - def __init__(self, - config: MoonViTConfig, - use_data_parallel: bool = False, - prefix: str = "", - *inputs, - **kwargs): + def __init__( + self, + config: MoonViTConfig, + use_data_parallel: bool = False, + prefix: str = "", + *inputs, + **kwargs, + ): super().__init__(config, *inputs, **kwargs) config = deepcopy(config) self.use_data_parallel = use_data_parallel @@ -655,8 +659,9 @@ def __init__(self, prefix=f"{prefix}.encoder", ) - def forward(self, pixel_values: torch.Tensor, - grid_hw: torch.Tensor) -> torch.Tensor: + def forward( + self, pixel_values: torch.Tensor, grid_hw: torch.Tensor + ) -> torch.Tensor: """ Args: pixel_values (torch.Tensor): The input pixel values. @@ -667,7 +672,7 @@ def forward(self, pixel_values: torch.Tensor, """ hidden_states = self.patch_embed(pixel_values, grid_hw) hidden_states = self.encoder(hidden_states, grid_hw) - hidden_states = patch_merger(hidden_states, - grid_hw, - merge_kernel_size=self.merge_kernel_size) + hidden_states = patch_merger( + hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size + ) return hidden_states diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 64d669e8ac3e..3f1f2bbcb026 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -14,30 +14,38 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) def _get_alibi_slopes( total_num_heads: int, alibi_bias_max: int, ) -> torch.Tensor: - next_power_of_2 = 
2**math.ceil(math.log2(total_num_heads)) + next_power_of_2 = 2 ** math.ceil(math.log2(total_num_heads)) m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32) m = m.mul(alibi_bias_max / next_power_of_2) slopes = 1.0 / torch.pow(2, m) @@ -47,7 +55,6 @@ def _get_alibi_slopes( class MPTAttention(nn.Module): - def __init__( self, config: MptConfig, @@ -107,20 +114,21 @@ def __init__( tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads head_end = (tp_rank + 1) * self.num_heads - alibi_slopes = _get_alibi_slopes(self.total_num_heads, - self.alibi_bias_max) + alibi_slopes = _get_alibi_slopes(self.total_num_heads, self.alibi_bias_max) alibi_slopes = alibi_slopes[head_start:head_end].tolist() self.head_dim = self.d_model // self.total_num_heads scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -141,7 +149,6 @@ def forward( class MPTMLP(nn.Module): - def __init__( self, config: MptConfig, @@ -173,7 +180,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MPTBlock(nn.Module): - def __init__( self, config: MptConfig, @@ -184,10 +190,9 @@ def __init__( super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attn") + self.attn = MPTAttention( + config, cache_config, quant_config, prefix=f"{prefix}.attn" + ) self.norm_2 = nn.LayerNorm(hidden_size) self.ffn = MPTMLP(config, quant_config) @@ -210,7 +215,6 @@ def forward( @support_torch_compile class MPTModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -227,19 +231,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.blocks = make_layers( config.n_layers, - lambda prefix: MPTBlock( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.blocks") + lambda prefix: MPTBlock(config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.blocks", + ) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): - if hasattr(module, "bias") and isinstance( - module.bias, nn.Parameter): + if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): # Remove the bias term in Linear and LayerNorm. 
module.register_parameter("bias", None) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.d_model)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.d_model + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -267,8 +270,7 @@ def forward( hidden_states = self.norm_f(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -278,15 +280,13 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class MPTForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -295,12 +295,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert config.tie_word_embeddings self.quant_config = quant_config - self.transformer = MPTModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "transformer")) + self.transformer = MPTModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.get_input_embeddings(input_ids) @@ -312,8 +314,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -323,7 +326,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 11b152fe79da..7c64d14ca9d7 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -25,25 +25,44 @@ from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - MultiModalEmbeddings, - SupportsMultiModal) -from vllm.model_executor.models.internvl import (calculate_internvl_targets, - 
get_internvl_target_ratios) +from vllm.model_executor.models.interfaces import ( + HasInnerState, + IsHybrid, + MultiModalEmbeddings, + SupportsMultiModal, +) +from vllm.model_executor.models.internvl import ( + calculate_internvl_targets, + get_internvl_target_ratios, +) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM from vllm.model_executor.models.radio import RadioModel -from vllm.model_executor.models.utils import (flatten_bn, - init_vllm_registered_model, - maybe_prefix) +from vllm.model_executor.models.utils import ( + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs, MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargs, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.radio import RadioConfig @@ -87,8 +106,9 @@ class NanoNemotronVLImageEmbeddinInputs(TypedDict): """ -NanoNemotronVLImageInputs = Union[NanoNemotronVLImagePixelInputs, - NanoNemotronVLImageEmbeddinInputs] +NanoNemotronVLImageInputs = Union[ + NanoNemotronVLImagePixelInputs, NanoNemotronVLImageEmbeddinInputs +] class NanoNemotronVLVideoPixelInputs(TensorSchema): @@ -100,6 +120,7 @@ class NanoNemotronVLVideoPixelInputs(TensorSchema): - h: Height of each video frame - w: Width of each video frame """ + type: Literal["pixel_values_videos"] pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")] num_patches: Annotated[torch.Tensor, TensorShape("bn")] @@ -112,21 +133,19 @@ class NanoNemotronVLVideoEmbeddingInputs(TensorSchema): - f: Total video feature size - h: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["video_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("n", "f", "h")] + data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")] -NanoNemotronVLVideoInputs = Union[NanoNemotronVLVideoPixelInputs, - NanoNemotronVLVideoEmbeddingInputs] +NanoNemotronVLVideoInputs = Union[ + NanoNemotronVLVideoPixelInputs, NanoNemotronVLVideoEmbeddingInputs +] -def dynamic_preprocess(image, - *, - image_size=512, - max_num_tiles=12, - use_thumbnail=True, - idx=0): +def dynamic_preprocess( + image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0 +): orig_width, orig_height = image.size target_ratios = get_internvl_target_ratios(1, max_num_tiles) @@ -136,7 +155,8 @@ def dynamic_preprocess(image, orig_height=orig_height, target_ratios=target_ratios, image_size=image_size, - use_thumbnail=False) + use_thumbnail=False, + ) # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] @@ -156,12 +176,12 @@ def 
dynamic_preprocess(image, processed_images.append(thumbnail_img) processed_images = [ - img.convert("RGB") if img.mode != "RGB" else img - for img in processed_images + img.convert("RGB") if img.mode != "RGB" else img for img in processed_images ] processed_images = [ - T.Resize((image_size, image_size), - interpolation=T.InterpolationMode.BICUBIC)(img) + T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)( + img + ) for img in processed_images ] processed_images = [T.ToTensor()(img) for img in processed_images] @@ -222,8 +242,9 @@ class BaseNanoNemotronVLProcessor(ABC): https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252 """ - def __init__(self, config: PretrainedConfig, tokenizer: AnyTokenizer, - *args, **kwargs) -> None: + def __init__( + self, config: PretrainedConfig, tokenizer: AnyTokenizer, *args, **kwargs + ) -> None: super().__init__() self.config = config @@ -233,7 +254,8 @@ def __init__(self, config: PretrainedConfig, tokenizer: AnyTokenizer, patch_size: int = config.patch_size self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.image_size = image_size self.use_thumbnail: bool = config.use_thumbnail self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1) @@ -283,7 +305,8 @@ def _images_to_pixel_values_lst( max_num=max_num_tiles, use_thumbnail=self.use_thumbnail, idx=idx, - ) for idx, image in enumerate(images) + ) + for idx, image in enumerate(images) ] def _preprocess_image( @@ -295,24 +318,22 @@ def _preprocess_image( if len(images) == 0: image_inputs = {} else: - pixel_values_lst = self._images_to_pixel_values_lst( - images, max_num_tiles) + pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles) image_inputs = { - "pixel_values_flat": - torch.cat(pixel_values_lst), - "image_num_patches": - torch.tensor([len(item) for item in pixel_values_lst]), + "pixel_values_flat": torch.cat(pixel_values_lst), + "image_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst] + ), } for pixel_values in pixel_values_lst: num_patches = pixel_values.shape[0] feature_size = num_patches * self.num_image_token image_repl = self.get_image_repl(feature_size, num_patches) - text = [t.replace('<image>', image_repl.full, 1) for t in text] + text = [t.replace("<image>", image_repl.full, 1) for t in text] return text, image_inputs - def _make_batch_input(self, - input_item: Optional[Union[Any, list[Any]]] = None): + def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None): if input_item is None: input_item = [] if not isinstance(input_item, list): @@ -392,14 +413,14 @@ def _videos_to_pixel_values_lst( max_num_tiles: int, dynamic_image_size: Optional[bool] = None, ) -> list[torch.Tensor]: - return [ video_to_pixel_values( video, input_size=self.image_size, max_num_tiles=max_num_tiles, use_thumbnail=self.use_thumbnail, - ) for video in videos + ) + for video in videos ] def _preprocess_video( @@ -419,18 +440,19 @@ def _preprocess_video( ) video_inputs = { - "pixel_values_flat_video": - torch.cat(pixel_values_lst_video), - "video_num_patches": - torch.tensor([len(item) for item in pixel_values_lst_video]), + "pixel_values_flat_video": torch.cat(pixel_values_lst_video), + "video_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst_video] + ), } for pixel_values in pixel_values_lst_video: num_patches = pixel_values.shape[0] - video_repl = 
self.get_video_repl(self.num_image_token, - num_patches, self.video_token) - text = [t.replace('<video>', video_repl.full, 1) for t in text] + video_repl = self.get_video_repl( + self.num_image_token, num_patches, self.video_token + ) + text = [t.replace("<video>", video_repl.full, 1) for t in text] return text, video_inputs def __call__( @@ -488,9 +510,9 @@ def get_video_repl( repl_features = video_context_token * self.num_image_token repl_features_with_sep = IMG_START + repl_features + IMG_END # num_patches is equal to num_frames - repl_full = ''.join([ - f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches) - ]) + repl_full = "".join( + [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)] + ) return PromptUpdateDetails.select_text(repl_full, video_context_token) @@ -525,8 +547,7 @@ def get_num_image_tokens( max_num_tiles=max_num_tiles, ) - def get_image_size_with_most_features(self, - max_num_tiles: int) -> ImageSize: + def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize: processor = self.get_hf_processor() base_size = processor.image_size @@ -544,8 +565,7 @@ def get_image_size_with_most_features(self, ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") @@ -557,7 +577,8 @@ def get_max_image_tokens(self) -> int: # Use default max_num_tiles for max tokens calculation max_num_tiles = 12 target_width, target_height = self.get_image_size_with_most_features( - max_num_tiles) + max_num_tiles + ) return self.get_num_image_tokens( image_width=target_width, @@ -571,7 +592,7 @@ def get_max_image_tokens(self) -> int: class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo): - """ ProcessingInfo extended for video processing""" + """ProcessingInfo extended for video processing""" @property def supports_video(self): @@ -595,8 +616,7 @@ def get_num_frames_with_most_features( processor = self.get_hf_processor() # we get the CustomProcessor here max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = (seq_len - - max_image_tokens) // processor.num_image_token + max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token max_frames_per_video = max_total_frames // max(max_videos, 1) max_frames_per_video = min(max_frames_per_video, MAX_FRAMES) @@ -649,7 +669,8 @@ def _get_mm_fields_config( return dict( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_patches), + "image", image_num_patches + ), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), @@ -675,7 +696,8 @@ def _get_prompt_updates( def get_replacement_custom(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -694,9 +716,9 @@ def get_replacement_custom(item_idx: int): local_image_num_patches = image_num_patches if isinstance(local_image_num_patches, torch.Tensor): local_image_num_patches = local_image_num_patches.tolist() - if isinstance( - local_image_num_patches, - (list, tuple)) and item_idx < 
len(local_image_num_patches): + if isinstance(local_image_num_patches, (list, tuple)) and item_idx < len( + local_image_num_patches + ): num_patches = int(local_image_num_patches[item_idx]) return hf_processor.get_image_repl(feature_size, num_patches) @@ -711,7 +733,8 @@ def get_replacement_custom(item_idx: int): class NanoNemotronVLMultiModalProcessor( - NanoNemotronBaseVLMultiModalProcessor[NanoNemotronVLProcessingInfo]): + NanoNemotronBaseVLMultiModalProcessor[NanoNemotronVLProcessingInfo] +): """MultiModalProcessor extended for video support""" def _call_hf_processor( @@ -721,12 +744,15 @@ def _call_hf_processor( mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: - processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs, tok_kwargs) + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs + ) hf_processor = self.info.get_hf_processor(**mm_kwargs) - if self.info.supports_video and ( - video_token_id := hf_processor.video_token_id) is not None: + if ( + self.info.supports_video + and (video_token_id := hf_processor.video_token_id) is not None + ): processed_outputs["video_token_id"] = torch.tensor(video_token_id) return processed_outputs @@ -735,18 +761,17 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - image_fields = super()._get_mm_fields_config(hf_inputs, - hf_processor_mm_kwargs) + image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs) if self.info.supports_video: - video_num_patches = hf_inputs.get("video_num_patches", - torch.empty(0)) + video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0)) num_videos = len(video_num_patches) video_fields = dict( pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_patches), + "video", video_num_patches + ), video_num_patches=MultiModalFieldConfig.batched("video"), - video_token_id=MultiModalFieldConfig.shared( - "video", num_videos)) + video_token_id=MultiModalFieldConfig.shared("video", num_videos), + ) else: video_fields = {} @@ -781,9 +806,8 @@ def get_video_replacement_internvl(item_idx: int): assert isinstance(num_patches, int) return hf_processor.get_video_repl( - feature_size, - num_patches, - video_context_token=hf_processor.video_token) + feature_size, num_patches, video_context_token=hf_processor.video_token + ) if self.info.supports_video: prompt_repl = [ @@ -792,7 +816,7 @@ def get_video_replacement_internvl(item_idx: int): modality="video", target="<video>", replacement=get_video_replacement_internvl, - ) + ), ] return prompt_repl @@ -814,23 +838,26 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: # Use default max_num_tiles for dummy data generation max_num_tiles = 12 - target_width, target_height = ( - self.info.get_image_size_with_most_features(max_num_tiles)) + target_width, target_height = self.info.get_image_size_with_most_features( + max_num_tiles + ) num_images = mm_counts.get("image", 0) image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class NanoNemotronVLDummyInputsBuilder( - NanoNemotronVLDummyInputsBuilder[NanoNemotronVLProcessingInfo]): + 
NanoNemotronVLDummyInputsBuilder[NanoNemotronVLProcessingInfo] +): """DummyInputsBuilder extended for video support""" def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: @@ -844,23 +871,25 @@ def get_dummy_mm_data( mm_counts: Mapping[str, int], mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - dummy_image = super().get_dummy_mm_data(seq_len=seq_len, - mm_counts=mm_counts, - mm_options=mm_options) + dummy_image = super().get_dummy_mm_data( + seq_len=seq_len, mm_counts=mm_counts, mm_options=mm_options + ) if self.info.supports_video: config = self.info.get_hf_config() image_size: int = config.force_image_size - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) num_videos = mm_counts.get("video", 0) video_overrides = mm_options.get("video") if mm_options else None dummy_video = { - "video": - self._get_dummy_videos(width=image_size, - height=image_size, - num_frames=target_num_frames, - num_videos=num_videos, - overrides=video_overrides) + "video": self._get_dummy_videos( + width=image_size, + height=image_size, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ) } else: dummy_video = {} @@ -872,9 +901,7 @@ def get_dummy_mm_data( info=NanoNemotronVLProcessingInfo, dummy_inputs=NanoNemotronVLDummyInputsBuilder, ) -class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, - SupportsMultiModal): - +class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModal): @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -892,7 +919,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.patch_size = patch_size self.template = config.template self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version self.image_tag_type = config.image_tag_type @@ -903,7 +931,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "language_model"), ) self.vision_model = self.get_vit_model_from_radio_config(config).to( - self.language_model.config.torch_dtype) + self.language_model.config.torch_dtype + ) # Construct the vision projection. 
vit_hidden_size = config.vit_hidden_size @@ -911,18 +940,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): llm_hidden_size = config.text_config.hidden_size self.mlp1 = nn.Sequential( - RMSNorm(hidden_size=vit_hidden_size * - int(1 / self.downsample_ratio)**2, - eps=1e-5), + RMSNorm( + hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2, + eps=1e-5, + ), nn.Linear( - vit_hidden_size * int(1 / self.downsample_ratio)**2, + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, vision_projection_hidden_size, bias=False, ), ReLUSquaredActivation(), - nn.Linear(vision_projection_hidden_size, - llm_hidden_size, - bias=False), + nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False), ) self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype) @@ -962,17 +990,16 @@ def pixel_shuffle(self, x, scale_factor=0.5): def extract_feature(self, pixel_values): vit_embeds = self.vision_model(pixel_values) vit_embeds = vit_embeds.to(dtype=torch.bfloat16) - h = w = int(vit_embeds.shape[1]**0.5) + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, - scale_factor=self.downsample_ratio) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, - vit_embeds.shape[-1]) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[NanoNemotronVLImageInputs]: + self, **kwargs: object + ) -> Optional[NanoNemotronVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -982,8 +1009,10 @@ def _parse_and_validate_image_input( if image_embeds is not None: if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") + raise ValueError( + "Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}" + ) return NanoNemotronVLImageEmbeddinInputs( type="image_embeds", @@ -996,12 +1025,16 @@ def _parse_and_validate_image_input( if pixel_values_flat is not None: if not isinstance(pixel_values_flat, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat)}") + raise ValueError( + "Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat)}" + ) if not isinstance(image_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(image_num_patches)}") + raise ValueError( + "Incorrect type of image_num_patches. 
" + f"Got type: {type(image_num_patches)}" + ) pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True) @@ -1015,7 +1048,8 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") def _process_image_input( - self, image_input: NanoNemotronVLImageInputs) -> torch.Tensor: + self, image_input: NanoNemotronVLImageInputs + ) -> torch.Tensor: if image_input["type"] == "image_embeds": return image_input["data"] @@ -1026,22 +1060,20 @@ def _process_image_input( # Only one image in the current batch if len(num_patches) == 1: - return (image_embeds.view(-1, - self.config.text_config.hidden_size), ) + return (image_embeds.view(-1, self.config.text_config.hidden_size),) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. feature_size = image_embeds.shape[1] - image_embeds = image_embeds.view(-1, - self.config.text_config.hidden_size) + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] return image_embeds.split(image_feature_sizes) def _parse_and_validate_video_input( - self, - **kwargs: object) -> Optional[NanoNemotronVLVideoPixelInputs]: + self, **kwargs: object + ) -> Optional[NanoNemotronVLVideoPixelInputs]: pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None) video_num_patches = kwargs.pop("video_num_patches", None) video_embeds = kwargs.pop("video_embeds", None) @@ -1061,15 +1093,18 @@ def _parse_and_validate_video_input( if pixel_values_flat_video is not None: if not isinstance(pixel_values_flat_video, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values_flat_video)}") + raise ValueError( + "Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat_video)}" + ) if not isinstance(video_num_patches, (torch.Tensor, list)): - raise ValueError("Incorrect type of image_num_patches. " - f"Got type: {type(video_num_patches)}") + raise ValueError( + "Incorrect type of image_num_patches. " + f"Got type: {type(video_num_patches)}" + ) - pixel_values_flat_video = flatten_bn(pixel_values_flat_video, - concat=True) + pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True) video_num_patches = flatten_bn(video_num_patches, concat=True) expected_h = expected_w = self.config.force_image_size resolve_bindings = {"h": expected_h, "w": expected_w} @@ -1088,19 +1123,17 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values_flat", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_flat_video", - ) and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values_flat", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_flat_video",) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: # Validate the multimodal input keyword arguments modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if modalities is None: @@ -1193,16 +1226,13 @@ def is_vision_weights(name: str) -> bool: default_weight_loader(param, w) elif is_vision_weights(name): # Convert: vision_model.radio_model.* → radio_model.* - hf_key = name[len( - "vision_model."):] # Remove "vision_model." prefix + hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix vision_weights.append((hf_key, w)) self.language_model.load_weights(llm_weights) self.vision_model.load_weights(vision_weights) - def print_architecture(self, - detailed: bool = True, - save_to_file: str = None): + def print_architecture(self, detailed: bool = True, save_to_file: str = None): """ Print model architecture with parameter names, shapes, and sizes. @@ -1238,20 +1268,26 @@ def print_architecture(self, # Group parameters by main component if name.startswith("language_model"): param_groups["language_model"].append( - (name, param.shape, param_size, param.dtype)) + (name, param.shape, param_size, param.dtype) + ) elif name.startswith("vision_model"): param_groups["vision_model"].append( - (name, param.shape, param_size, param.dtype)) + (name, param.shape, param_size, param.dtype) + ) elif name.startswith("mlp1"): param_groups["mlp1"].append( - (name, param.shape, param_size, param.dtype)) + (name, param.shape, param_size, param.dtype) + ) else: param_groups["other"].append( - (name, param.shape, param_size, param.dtype)) + (name, param.shape, param_size, param.dtype) + ) if detailed: - print(f"{name:<70} | Shape: {str(param.shape):<25} | " - f"Size: {param_size:>12,} | Dtype: {param.dtype}") + print( + f"{name:<70} | Shape: {str(param.shape):<25} | " + f"Size: {param_size:>12,} | Dtype: {param.dtype}" + ) print("=" * 100) print("Summary by Component:") @@ -1260,11 +1296,16 @@ def print_architecture(self, for component, params in param_groups.items(): if params: # Only show components that have parameters component_total = sum(size for _, _, size, _ in params) - percentage = ((component_total / total_params) * - 100 if total_params > 0 else 0) - print(f"{component:<20} | Parameters: {len(params):>4} | " - f"Total Size: {component_total:>15,} | " - f"{percentage:>6.2f}%") + percentage = ( + (component_total / total_params) * 100 + if total_params > 0 + else 0 + ) + print( + f"{component:<20} | Parameters: {len(params):>4} | " + f"Total Size: {component_total:>15,} | " + f"{percentage:>6.2f}%" + ) print("-" * 60) print(f"{'Total Parameters':<20} | {total_params:>15,}") @@ -1320,10 +1361,9 @@ def get_vit_model_from_radio_config(self, hf_config): hf_config_vision = hf_config.vision_config 
model_name = hf_config_vision.args.get("model") if model_name is None: - raise ValueError(f'Unsupported vit model type: {model_name}') + raise ValueError(f"Unsupported vit model type: {model_name}") - preferred_resolution = getattr(hf_config_vision, - "preferred_resolution", None) + preferred_resolution = getattr(hf_config_vision, "preferred_resolution", None) image_size = preferred_resolution[0] if preferred_resolution else 224 patch_size = getattr(hf_config_vision, "patch_size", 16) @@ -1333,33 +1373,36 @@ def get_vit_model_from_radio_config(self, hf_config): patch_size=patch_size, norm_mean=hf_config.norm_mean, norm_std=hf_config.norm_std, - reg_tokens=(hf_config_vision.args.get("register_multiple") - if hasattr(hf_config_vision, "args") - and isinstance(hf_config_vision.args, dict) else None), + reg_tokens=( + hf_config_vision.args.get("register_multiple") + if hasattr(hf_config_vision, "args") + and isinstance(hf_config_vision.args, dict) + else None + ), ) return RadioModel(config=radio_config) def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): return self.language_model.mamba_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs) + input_buffers, **kwargs + ) def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return (self.language_model.mamba_cache. - get_seqlen_agnostic_capture_inputs(batch_size)) + return self.language_model.mamba_cache.get_seqlen_agnostic_capture_inputs( + batch_size + ) @classmethod def get_mamba_state_shape_from_config(cls, vllm_config: "VllmConfig"): text_config = vllm_config.model_config.hf_config.text_config temp_vllm_config = copy.deepcopy(vllm_config) temp_vllm_config.model_config.hf_config = text_config - return NemotronHForCausalLM.get_mamba_state_shape_from_config( - temp_vllm_config) + return NemotronHForCausalLM.get_mamba_state_shape_from_config(temp_vllm_config) @classmethod def get_mamba_state_dtype_from_config(cls, vllm_config: "VllmConfig"): text_config = vllm_config.model_config.hf_config.text_config temp_vllm_config = copy.deepcopy(vllm_config) temp_vllm_config.model_config.hf_config = text_config - return NemotronHForCausalLM.get_mamba_state_dtype_from_config( - temp_vllm_config) + return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 6bb2f7392cb4..8f07a2cf12f7 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Nemotron model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -35,23 +36,35 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) # The architecture is pretty similar to Llama, with these changes: # - There is no gate_proj, just up_proj @@ -65,20 +78,21 @@ def _cast_if_autocast_enabled(*args): return args else: return torch.amp.autocast_mode._cast( - args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype()) + args, device_type="cuda", dtype=torch.get_autocast_gpu_dtype() + ) class NemotronLayerNorm1P(nn.LayerNorm): - - def __init__(self, - normalized_shape: Union[int, list[int], torch.Size], - eps: float = 1e-5, - elementwise_affine: bool = True, - bias: bool = True, - device=None, - dtype=None): - super().__init__(normalized_shape, eps, elementwise_affine, bias, - device, dtype) + def __init__( + self, + normalized_shape: Union[int, list[int], torch.Size], + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None, + ): + super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype) def forward( self, @@ -88,15 +102,15 @@ def forward( if residual is not None: x = x + residual residual = x - args = _cast_if_autocast_enabled(x, self.normalized_shape, - self.weight + 1, self.bias, self.eps) + args = _cast_if_autocast_enabled( + x, self.normalized_shape, self.weight + 1, self.bias, self.eps + ) with torch.amp.autocast("cuda", enabled=False): x = torch.nn.functional.layer_norm(*args) return x if residual is None else (x, residual) class NemotronMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -107,16 +121,20 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - self.up_proj = ColumnParallelLinear(input_size=hidden_size, - output_size=intermediate_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - 
prefix=f"{prefix}.down_proj") + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) self.act_fn = get_act_fn(hidden_act) def forward(self, x): @@ -127,7 +145,6 @@ def forward(self, x): class NemotronAttention(nn.Module): - def __init__( self, config: NemotronConfig, @@ -194,13 +211,15 @@ def __init__( rope_scaling=rope_scaling, partial_rotary_factor=self.partial_rotary_factor, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -216,7 +235,6 @@ def forward( class NemotronDecoderLayer(nn.Module): - def __init__( self, config: NemotronConfig, @@ -229,21 +247,24 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.self_attn = NemotronAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -260,10 +281,12 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, - eps=config.norm_eps) + self.input_layernorm = NemotronLayerNorm1P( + config.hidden_size, eps=config.norm_eps + ) self.post_attention_layernorm = NemotronLayerNorm1P( - config.hidden_size, eps=config.norm_eps) + config.hidden_size, eps=config.norm_eps + ) def forward( self, @@ -276,23 +299,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class NemotronModel(nn.Module): - def __init__(self, *, vllm_config: 
VllmConfig, prefix: str = ""): super().__init__() @@ -303,12 +323,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -318,19 +342,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: NemotronDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") + lambda prefix: NemotronDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: - self.norm = NemotronLayerNorm1P(config.hidden_size, - eps=config.norm_eps) + self.norm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps) else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -357,16 +383,14 @@ def forward( hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -376,18 +400,19 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, 
shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -416,8 +441,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -450,8 +474,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.model = NemotronModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = NemotronModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -463,7 +488,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -471,14 +497,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -490,8 +517,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( @@ -501,7 +529,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 987920ecc331..0a05c63a31ea 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -17,6 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
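Aside: the `load_weights` hunks above keep vLLM's stacked-parameter convention, where separate `q_proj`/`k_proj`/`v_proj` checkpoint tensors are loaded into a single fused `qkv_proj` parameter. A minimal standalone sketch of that renaming step, assuming a plain lookup in place of the real parameter registry (the `remap` helper is illustrative and not part of the patch):

```python
from typing import Optional

# The mapping mirrors the diff above; remap() itself is only an illustration.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]


def remap(checkpoint_name: str) -> tuple[str, Optional[str]]:
    """Map a checkpoint weight name onto the fused vLLM parameter name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_name:
            return checkpoint_name.replace(weight_name, param_name), shard_id
    return checkpoint_name, None  # fall through to the default loader


print(remap("model.layers.0.self_attn.q_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'q')
```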
"""Inference-only NemotronH model.""" + from collections.abc import Iterable from typing import Optional @@ -30,30 +31,46 @@ from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsLoRA, SupportsPP, - SupportsQuant) + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.interfaces import ( + HasInnerState, + IsHybrid, + SupportsLoRA, + SupportsPP, + SupportsQuant, +) from vllm.model_executor.models.utils import ( - AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory, - make_layers, maybe_prefix) + AutoWeightsLoader, + WeightsMapper, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronHConfig class NemotronHMLP(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -65,7 +82,7 @@ def __init__( super().__init__() hybrid_override_pattern = config.hybrid_override_pattern - mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1 + mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1 if isinstance(config.intermediate_size, list): if len(config.intermediate_size) == 1: intermediate_size = config.intermediate_size[0] @@ -98,7 +115,6 @@ def forward(self, x: torch.Tensor): class NemotronHMLPDecoderLayer(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -138,7 +154,6 @@ def forward( class NemotronHMambaDecoderLayer(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -188,7 +203,6 @@ def forward( class NemotronHAttention(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -261,7 +275,6 @@ def forward( class NemotronHAttentionDecoderLayer(nn.Module): - def __init__( self, config: NemotronHConfig, @@ -310,7 +323,6 @@ def forward( @support_torch_compile class NemotronHModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -321,8 +333,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): lora_config = vllm_config.lora_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ 
-335,7 +350,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) layer_class = ALL_DECODER_LAYER_TYPES[ - config.hybrid_override_pattern[layer_idx]] + config.hybrid_override_pattern[layer_idx] + ] return layer_class( config, layer_idx, @@ -346,11 +362,11 @@ def get_layer(prefix: str): ) self.start_layer, self.end_layer, self.layers = make_layers( - len(config.hybrid_override_pattern), - get_layer, - prefix=f"{prefix}.layers") + len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers" + ) self.make_empty_intmd_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size) + ["hidden_states", "residual"], config.hidden_size + ) self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -364,7 +380,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -385,15 +400,13 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -427,22 +440,19 @@ def load_weights(self, weights: Iterable[tuple[str, # load other params else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - IsHybrid, SupportsQuant): +class NemotronHForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant +): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"backbone": "model"}, - orig_to_new_substr={ - "A_log": "A", - "embeddings": "embed_tokens" - }, + orig_to_new_substr={"A_log": "A", "embeddings": "embed_tokens"}, ) packed_modules_mapping = { @@ -465,7 +475,6 @@ def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -505,19 +514,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config self.model_config = vllm_config.model_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "NemotronH currently does not support prefix caching" self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config - self.model = NemotronHModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = NemotronHModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) 
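Aside: `NemotronHModel.get_layer` above picks a decoder-layer class by indexing `ALL_DECODER_LAYER_TYPES` with one character of `hybrid_override_pattern` per layer, and `NemotronHMLP` counts the `"-"` characters seen so far to find its position among the MLP layers. A self-contained sketch of that dispatch, with stand-in classes and an assumed character mapping (only `"-"` = MLP is explicit in the diff; `"M"` for Mamba and `"*"` for attention are assumptions here):

```python
# Stand-in layer classes; only the dispatch logic mirrors the diff above.
class MLPLayer: ...
class MambaLayer: ...
class AttentionLayer: ...

# Assumed mapping: "-" marks MLP layers (as in the count("-") expression above);
# "M" -> Mamba and "*" -> attention are assumptions made for this sketch.
ALL_DECODER_LAYER_TYPES = {"-": MLPLayer, "M": MambaLayer, "*": AttentionLayer}


def layer_class_for(pattern: str, layer_idx: int) -> type:
    return ALL_DECODER_LAYER_TYPES[pattern[layer_idx]]


def mlp_index(pattern: str, layer_idx: int) -> int:
    # Position of layer `layer_idx` among all MLP ("-") layers so far, matching
    # hybrid_override_pattern[: layer_idx + 1].count("-") - 1 in the diff.
    return pattern[: layer_idx + 1].count("-") - 1


pattern = "M-M-*-"
print([layer_class_for(pattern, i).__name__ for i in range(len(pattern))])
print(mlp_index(pattern, 3))  # the layer at index 3 is the second MLP -> 1
```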
self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -528,27 +535,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) - self.make_empty_intmd_tensors = (self.model.make_empty_intmd_tensors) + self.make_empty_intmd_tensors = self.model.make_empty_intmd_tensors def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states @@ -559,7 +570,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index d474c8db41b2..ddd623b5de23 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
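Aside: the `hf_to_vllm_mapper` in the NemotronH hunks above rewrites checkpoint names before loading: the `backbone` prefix becomes `model`, and the `A_log`/`embeddings` substrings are renamed. A rough sketch of that rewriting, using plain dicts rather than vLLM's actual `WeightsMapper` implementation:

```python
# Mimics the mapping declared in the diff above; not vLLM's WeightsMapper class.
ORIG_TO_NEW_PREFIX = {"backbone": "model"}
ORIG_TO_NEW_SUBSTR = {"A_log": "A", "embeddings": "embed_tokens"}


def map_name(name: str) -> str:
    for old, new in ORIG_TO_NEW_PREFIX.items():
        if name.startswith(old):
            name = new + name[len(old):]
    for old, new in ORIG_TO_NEW_SUBSTR.items():
        name = name.replace(old, new)
    return name


print(map_name("backbone.embeddings.weight"))     # model.embed_tokens.weight
print(map_name("backbone.layers.3.mixer.A_log"))  # model.layers.3.mixer.A
```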
"""Inference-only deci model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -40,16 +41,26 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.llama import LlamaAttention, LlamaMLP from vllm.sequence import IntermediateTensors from .interfaces import HasNoOps, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int: @@ -66,7 +77,6 @@ def _find_multiple(n: int, k: int) -> int: class DeciLMAttention(LlamaAttention): - def __init__( self, config: LlamaConfig, @@ -83,18 +93,34 @@ def __init__( prefix: str = "", attn_type: str = AttentionType.DECODER, ) -> None: - super().__init__(config, hidden_size, num_heads, num_kv_heads, - rope_theta, rope_scaling, max_position_embeddings, - quant_config, bias, bias_o_proj, cache_config, prefix, - attn_type) + super().__init__( + config, + hidden_size, + num_heads, + num_kv_heads, + rope_theta, + rope_scaling, + max_position_embeddings, + quant_config, + bias, + bias_o_proj, + cache_config, + prefix, + attn_type, + ) - def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]], - quant_config: Optional[QuantizationConfig]) -> None: + def _init_rotary_emb( + self, + config, + rope_scaling: Optional[dict[str, Any]], + quant_config: Optional[QuantizationConfig], + ) -> None: # Enables YARN for Mistral and LLaMA4 derivatives. 
is_neox_style = True if hasattr(config, "position_embedding_type"): is_neox_style = config.position_embedding_type not in [ - "mistral_yarn", "rope_llama4" + "mistral_yarn", + "rope_llama4", ] self.rotary_emb = get_rope( @@ -104,11 +130,11 @@ def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]], base=self.rope_theta, rope_scaling=rope_scaling, is_neox_style=is_neox_style, - partial_rotary_factor=self.partial_rotary_factor) + partial_rotary_factor=self.partial_rotary_factor, + ) class DeciLMDecoderLayer(nn.Module): - def __init__( self, config: LlamaConfig, @@ -126,23 +152,26 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): + config, "original_max_position_embeddings", None + ): rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) bias_o_proj = attention_bias # support internlm/internlm3-8b with qkv_bias if hasattr(config, "qkv_bias"): attention_bias = config.qkv_bias if not self._is_no_op_attention: - num_kv_heads = (config.num_attention_heads // - block_config.attention.n_heads_in_group) + num_kv_heads = ( + config.num_attention_heads // block_config.attention.n_heads_in_group + ) self.self_attn = DeciLMAttention( config=config, hidden_size=self.hidden_size, @@ -157,13 +186,13 @@ def __init__( cache_config=cache_config, prefix=f"{prefix}.self_attn", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if not self._is_no_op_ffn: ffn_mult = block_config.ffn.ffn_mult intermediate_size = _ffn_mult_to_intermediate_size( - ffn_mult, config.hidden_size) + ffn_mult, config.hidden_size + ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -173,8 +202,9 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -187,12 +217,11 @@ def forward( if self._is_no_op_attention: pass else: - if (residual is None): + if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -201,14 +230,14 @@ def forward( # Fully Connected if not self._is_no_op_ffn: hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual + ) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class DeciModel(nn.Module): - def __init__( self, *, @@ -226,12 +255,16 @@ def __init__( self.config = config self.quant_config = quant_config self.padding_idx = config.pad_token_id - lora_vocab = 
((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -261,9 +294,9 @@ def get_layer(prefix: str): else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -289,24 +322,20 @@ def forward( kv_cache_index = 0 for layer in islice(self.layers, self.start_layer, self.end_layer): if not layer._is_no_op_attention: - hidden_states, residual = layer(positions, hidden_states, - residual) + hidden_states, residual = layer(positions, hidden_states, residual) kv_cache_index += 1 else: - hidden_states, residual = layer(positions, hidden_states, - residual) + hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -320,19 +349,19 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue if self.quant_config is not None and ( - scale_name := self.quant_config.get_cache_scale(name)): + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -365,8 +394,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -413,8 +441,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.lora_config = lora_config - self.model = self._init_model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = self._init_model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size @@ -428,24 +457,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else - lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) if config.tie_word_embeddings: - self.lm_head = self.lm_head.tie_weights( - self.model.embed_tokens) + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): return DeciModel(vllm_config=vllm_config, prefix=prefix) @@ -460,8 +490,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits( @@ -471,11 +502,9 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index e6c4c5b022dc..268644bc9249 100644 
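Aside: both the Nemotron and Deci `load_weights` hunks above special-case kv-cache quantization scales: the quant config maps the checkpoint name to a scale parameter, and the stored value is reduced to a scalar when the checkpoint saved it as a tensor. A small sketch of that normalization step (the function name is illustrative):

```python
import torch


def normalize_cache_scale(loaded_weight: torch.Tensor) -> torch.Tensor:
    # Mirrors `loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]`
    # in the load_weights hunks above: keep a 0-dim scalar as-is, otherwise
    # take the first element of the stored tensor.
    return loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]


print(normalize_cache_scale(torch.tensor(0.5)))          # tensor(0.5000)
print(normalize_cache_scale(torch.tensor([0.5, 0.25])))  # tensor(0.5000)
```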
--- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -22,34 +22,45 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.models.internvl import ( - BaseInternVLDummyInputsBuilder, BaseInternVLMultiModalProcessor, - BaseInternVLProcessingInfo, InternVLImageEmbeddingInputs, - InternVLImageInputs, InternVLImagePixelInputs, InternVLProcessor) + BaseInternVLDummyInputsBuilder, + BaseInternVLMultiModalProcessor, + BaseInternVLProcessingInfo, + InternVLImageEmbeddingInputs, + InternVLImageInputs, + InternVLImagePixelInputs, + InternVLProcessor, +) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.processing import PromptUpdateDetails from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.processor import ( - cached_image_processor_from_config) +from vllm.transformers_utils.processor import cached_image_processor_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix -IMG_START = '<img>' -IMG_END = '</img>' -IMG_CONTEXT = '<image>' +IMG_START = "<img>" +IMG_END = "</img>" +IMG_CONTEXT = "<image>" def build_transform(input_size: int): - return T.Compose([ - T.Lambda(lambda img: convert_image_mode(img, 'RGB')), - T.Resize((input_size, input_size), - interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - ]) + return T.Compose( + [ + T.Lambda(lambda img: convert_image_mode(img, "RGB")), + T.Resize( + (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC + ), + T.ToTensor(), + ] + ) # adapted from https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1 @@ -61,15 +72,16 @@ def find_closest_aspect_ratio( height: int, image_size: int, ) -> tuple[int, int]: - best_factor = float('-inf') + best_factor = float("-inf") best_ratio = (1, 1) area = width * height for rw, rh in target_ratios: target_aspect_ratio = rw / rh size_factor = min((rw * rh * image_size * image_size) / area, 0.6) - ratio_closeness = min(target_aspect_ratio / aspect_ratio, - aspect_ratio / target_aspect_ratio) + ratio_closeness = min( + target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio + ) factor = size_factor * ratio_closeness if factor > best_factor: @@ -132,10 +144,12 @@ def dynamic_preprocess_nemotron_vl( resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): - box = ((i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size) + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) @@ -153,10 +167,13 @@ def get_nemotron_vl_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: - target_ratios = {(i, j) - for 
n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) if min_num <= i * j <= max_num} + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } return sorted(target_ratios, key=lambda x: x[0] * x[1]) @@ -184,7 +201,6 @@ def image_to_pixel_values_nemotron_vl( class NemotronVLProcessor(InternVLProcessor): - def __init__( self, config: PretrainedConfig, @@ -215,7 +231,8 @@ def __init__( assert isinstance(dynamic_image_size, bool) self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.image_size = image_size self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch @@ -267,7 +284,8 @@ def _images_to_pixel_values_lst( min_num=min_num, max_num=max_num, use_thumbnail=self.use_thumbnail, - ) for image in images + ) + for image in images ] def _preprocess_image( @@ -288,10 +306,10 @@ def _preprocess_image( dynamic_image_size=dynamic_image_size, ) image_inputs = { - "pixel_values_flat": - torch.cat(pixel_values_lst), - "image_num_patches": - torch.tensor([len(item) for item in pixel_values_lst]), + "pixel_values_flat": torch.cat(pixel_values_lst), + "image_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst] + ), } for pixel_values in pixel_values_lst: @@ -299,10 +317,9 @@ def _preprocess_image( feature_size = num_patches * self.num_image_token image_repl = self.get_image_repl(feature_size, num_patches) NVL_IMAGE_CONTEXT = image_repl.full.replace( - "<image>", "<NVL_IMG_CONTEXT>") - text = [ - t.replace('<image>', NVL_IMAGE_CONTEXT, 1) for t in text - ] + "<image>", "<NVL_IMG_CONTEXT>" + ) + text = [t.replace("<image>", NVL_IMAGE_CONTEXT, 1) for t in text] text = [t.replace("<NVL_IMG_CONTEXT>", IMG_CONTEXT) for t in text] return text, image_inputs @@ -339,9 +356,9 @@ def get_image_processor(self, **kwargs: object): @MULTIMODAL_REGISTRY.register_processor( BaseInternVLMultiModalProcessor[NemotronVLProcessingInfo], info=NemotronVLProcessingInfo, - dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo]) -class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, - SupportsLoRA): + dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo], +) +class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): merge_by_field_config = True @classmethod @@ -366,7 +383,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: patch_size = config.vision_config.patch_size self.patch_size = patch_size self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version @@ -389,18 +407,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.visual_token_mask = None self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _patch_quant_config(self, config: PretrainedConfig, - quant_config: QuantizationConfig): + def _patch_quant_config( + self, config: PretrainedConfig, quant_config: QuantizationConfig + ): # the awq models from OpenGVLab missing `modules_to_not_convert` # patch the quant_config to add `modules_to_not_convert` back if 
isinstance(quant_config, AWQConfig): text_config = config.text_config - llm_quant_config = getattr(text_config, "quantization_config", - None) - if (not quant_config.modules_to_not_convert) and \ - (llm_quant_config is not None): + llm_quant_config = getattr(text_config, "quantization_config", None) + if (not quant_config.modules_to_not_convert) and ( + llm_quant_config is not None + ): quant_config.modules_to_not_convert.append("vision_model") def _init_vision_model( @@ -410,8 +430,7 @@ def _init_vision_model( *, prefix: str, ): - return AutoModel.from_config(config.vision_config, - trust_remote_code=True) + return AutoModel.from_config(config.vision_config, trust_remote_code=True) def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: vit_hidden_size = config.vit_hidden_size @@ -419,11 +438,14 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: llm_hidden_size = config.text_config.hidden_size return nn.Sequential( - nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, - bias=True), - nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, - vision_projection_hidden_size, - bias=True), + nn.LayerNorm( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, bias=True + ), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, + vision_projection_hidden_size, + bias=True, + ), nn.GELU(), nn.Linear(vision_projection_hidden_size, llm_hidden_size), ) @@ -434,9 +456,13 @@ def pixel_shuffle(self, x, scale_factor=0.5): x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(n, int(h * scale_factor), int(w * scale_factor), - int(c / (scale_factor * scale_factor))) - if self.ps_version == 'v1': + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": pass else: x = x.permute(0, 2, 1, 3).contiguous() @@ -447,17 +473,16 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = self.vision_model(x=pixel_values).features vit_embeds = vit_embeds.to(dtype=torch.bfloat16) - h = w = int(vit_embeds.shape[1]**0.5) + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, - scale_factor=self.downsample_ratio) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, - vit_embeds.shape[-1]) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) vit_embeds = self.mlp1(vit_embeds) return vit_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[InternVLImageInputs]: + self, **kwargs: object + ) -> Optional[InternVLImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -482,7 +507,7 @@ def _parse_and_validate_image_input( num_patches=image_num_patches, resolve_bindings={ "h": self.config.force_image_size, - "w": self.config.force_image_size + "w": self.config.force_image_size, }, ) @@ -503,14 +528,12 @@ def _process_image_input( # Only one image in the current batch if len(num_patches) == 1: - return (image_embeds.view(-1, - self.config.text_config.hidden_size), ) + return (image_embeds.view(-1, self.config.text_config.hidden_size),) # NOTE: Image embeddings are split into 
separate tensors for each image # by the size of each embedding. feature_size = image_embeds.shape[1] - image_embeds = image_embeds.view(-1, - self.config.text_config.hidden_size) + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] @@ -522,10 +545,11 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values_flat", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) + if ( + input_key in ("pixel_values_flat", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) return modalities @@ -535,9 +559,7 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -564,8 +586,7 @@ def get_input_embeddings( is_multimodal: Optional[torch.Tensor] = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: - if multimodal_embeddings is not None and len( - multimodal_embeddings) > 0: + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) # This is to satisfy the type checker for each overload @@ -587,7 +608,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> IntermediateTensors: - if intermediate_tensors is not None: input_ids = None inputs_embeds = None @@ -601,8 +621,7 @@ def forward( # Only required if the model is mono-architecture if self.visual_token_mask is not None: - forward_kwargs.update( - {"visual_token_mask": self.visual_token_mask}) + forward_kwargs.update({"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None hidden_states = self.language_model.model(**forward_kwargs) @@ -614,8 +633,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ## Ignore registered_buffers ## see https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/input_conditioner.py#L28 # noqa: E501 skip_substrs = ["norm_mean", "norm_std"] @@ -629,4 +647,5 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="language_model", connector="mlp1", - tower_model="vision_model") + tower_model="vision_model", + ) diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index 03b77823e969..f17bf3b09d5b 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -18,22 +18,30 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - MultiModalDataItems) -from vllm.multimodal.processing import 
(PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from .intern_vit import InternVisionModel -from .internvl import (BaseInternVLDummyInputsBuilder, - BaseInternVLMultiModalProcessor, - BaseInternVLProcessingInfo, BaseInternVLProcessor, - InternVLChatModel) +from .internvl import ( + BaseInternVLDummyInputsBuilder, + BaseInternVLMultiModalProcessor, + BaseInternVLProcessingInfo, + BaseInternVLProcessor, + InternVLChatModel, +) IMG_PAD = "<|vision_pad|>" class NVLMProcessor(BaseInternVLProcessor): - @property def image_token_id(self) -> int: return self.tokenizer.get_vocab()[IMG_PAD] @@ -51,8 +59,9 @@ def get_image_repl( tile_pos_identifiers += ["<tile_global_thumbnail>"] context_size = feature_size // num_patches - features = "".join(identifier + IMG_PAD * context_size - for identifier in tile_pos_identifiers) + features = "".join( + identifier + IMG_PAD * context_size for identifier in tile_pos_identifiers + ) # We include the start and end as well because "<Image><tile" is # tokenized as ["<Image", "><", "tile"], resulting in assertion error @@ -63,7 +72,6 @@ def get_image_repl( class NVLMProcessingInfo(BaseInternVLProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> NVLMProcessor: return self.ctx.init_processor( NVLMProcessor, @@ -73,9 +81,7 @@ def get_hf_processor(self, **kwargs: object) -> NVLMProcessor: ) -class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo] - ): - +class NVLMDummyInputsBuilder(BaseInternVLDummyInputsBuilder[NVLMProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -89,24 +95,22 @@ def get_dummy_mm_data( mm_counts: Mapping[str, int], mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } -class NVLMMultiModalProcessor( - BaseInternVLMultiModalProcessor[NVLMProcessingInfo]): - +class NVLMMultiModalProcessor(BaseInternVLMultiModalProcessor[NVLMProcessingInfo]): def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -129,7 +133,8 @@ def _get_prompt_updates( def get_replacement_nvlm(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -159,21 +164,24 @@ def get_replacement_nvlm(item_idx: int): ] -@MULTIMODAL_REGISTRY.register_processor(NVLMMultiModalProcessor, - info=NVLMProcessingInfo, - dummy_inputs=NVLMDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + NVLMMultiModalProcessor, + info=NVLMProcessingInfo, + dummy_inputs=NVLMDummyInputsBuilder, +) class NVLM_D_Model(InternVLChatModel): - def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: 
vit_hidden_size = config.vision_config.hidden_size llm_intermediate_size = config.text_config.intermediate_size llm_hidden_size = config.text_config.hidden_size return nn.Sequential( - nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), - nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, - llm_intermediate_size, - bias=False), + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), + nn.Linear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, + llm_intermediate_size, + bias=False, + ), nn.GELU(), nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False), ) @@ -189,8 +197,9 @@ def _init_vision_model( if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 + num_hidden_layers = ( + config.vision_config.num_hidden_layers + vision_feature_layer + 1 + ) else: num_hidden_layers = vision_feature_layer + 1 diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 9fa8760073c1..f334bbf9feeb 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMo model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -36,21 +37,29 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class OlmoAttention(nn.Module): @@ -70,15 +79,13 @@ def __init__( super().__init__() self.config = config self.hidden_size = config.hidden_size - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads assert self.hidden_size % self.total_num_heads == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta @@ -102,12 
+109,14 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) # Attention output projection. self.o_proj = RowParallelLinear( @@ -189,28 +198,29 @@ class OlmoDecoderLayer(nn.Module): (plus another skip connection). """ - def __init__(self, - config: OlmoConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = OlmoAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attn" + ) # MLP block. self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp") # LayerNorm - self.input_layernorm = nn.LayerNorm(config.hidden_size, - elementwise_affine=False, - bias=False) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - elementwise_affine=False, - bias=False) + self.input_layernorm = nn.LayerNorm( + config.hidden_size, elementwise_affine=False, bias=False + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, elementwise_affine=False, bias=False + ) def forward( self, @@ -233,7 +243,6 @@ def forward( @support_torch_compile class OlmoModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -243,19 +252,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: OlmoDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.norm = nn.LayerNorm(config.hidden_size, - elementwise_affine=False, - bias=False) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + self.norm = nn.LayerNorm( + config.hidden_size, elementwise_affine=False, bias=False + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -291,8 +303,7 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -304,7 +315,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - for 
(param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -324,8 +335,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -335,6 +345,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. """ + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -352,8 +363,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.model = OlmoModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = OlmoModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: @@ -367,7 +379,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -394,11 +407,11 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None), + skip_prefixes=( + ["lm_head.weight"] if self.config.tie_word_embeddings else None + ), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index e7e30ee8df0f..79234cc4dd8d 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -42,18 +42,27 @@ from vllm.distributed.utils import split_tensor_along_last_dim from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP from vllm.model_executor.models.utils import ( - AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + 
make_layers, + maybe_prefix, +) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Olmo3Config @@ -78,8 +87,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert self.total_num_heads % self.tp_size == 0 self.num_heads = self.total_num_heads // self.tp_size - self.total_num_kv_heads = (self.config.num_key_value_heads - or self.total_num_heads) + self.total_num_kv_heads = ( + self.config.num_key_value_heads or self.total_num_heads + ) if self.total_num_kv_heads >= self.tp_size: assert self.total_num_kv_heads % self.tp_size == 0 else: @@ -108,15 +118,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.total_num_kv_heads * self.head_dim, eps=self.config.rms_norm_eps, ) - self.q_norm = RMSNorm(self.config.hidden_size, - eps=self.config.rms_norm_eps) + self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) self.scaling = self.head_dim**-0.5 layer_idx = extract_layer_index(prefix) sliding_window = None - if ((layer_types := getattr(self.config, "layer_types", None)) - is not None and layer_types[layer_idx] == "sliding_attention"): + if ( + layer_types := getattr(self.config, "layer_types", None) + ) is not None and layer_types[layer_idx] == "sliding_attention": sliding_window = self.config.sliding_window self.attn = Attention( @@ -132,8 +142,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Rotary embeddings. Rope scaling is only applied on full attention # layers. - self.rope_scaling = (self.config.rope_scaling - if sliding_window is None else None) + self.rope_scaling = self.config.rope_scaling if sliding_window is None else None self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -151,16 +160,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.o_proj", ) - def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) q = self.q_norm(q) k = self.k_norm(k) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k @@ -236,18 +245,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config assert isinstance(config, (Olmo2Config, Olmo3Config)) # Attention block. - self.self_attn = Olmo2Attention(vllm_config=vllm_config, - prefix=f"{prefix}.self_attn") + self.self_attn = Olmo2Attention( + vllm_config=vllm_config, prefix=f"{prefix}.self_attn" + ) # MLP block. 
self.mlp = Olmo2MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp") # LayerNorm - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) - self.post_feedforward_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -270,7 +282,6 @@ def forward( @support_torch_compile class Olmo2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config @@ -283,17 +294,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( self.config.num_hidden_layers, - lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config, - prefix=prefix), + lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm( self.config.hidden_size, eps=self.config.rms_norm_eps, ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - self.config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -334,8 +344,7 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -366,8 +375,7 @@ def load_weights(self, weights: Iterable[tuple[str, if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -377,6 +385,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): """ Extremely barebones HF model wrapper. 
""" + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -394,8 +403,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config assert isinstance(config, (Olmo2Config, Olmo3Config)) self.config = config - self.model = Olmo2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Olmo2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: @@ -409,7 +419,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -439,7 +450,8 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None), + skip_prefixes=( + ["lm_head.weight"] if self.config.tie_word_embeddings else None + ), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 77ece544d490..90ec1a890417 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OLMoE model compatible with HuggingFace weights.""" + from collections.abc import Iterable from functools import partial from itertools import islice @@ -25,28 +26,39 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.distributed.utils import split_tensor_along_last_dim from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -60,33 
+72,36 @@ class OlmoeMoE(nn.Module): across ranks. """ - def __init__(self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - prefix: str = ""): + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = hidden_size # Gate always runs at half / full precision for now. - self.gate = ReplicatedLinear(hidden_size, - num_experts, - bias=False, - quant_config=None) - - self.experts = FusedMoE(num_experts=num_experts, - top_k=top_k, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - reduce_results=True, - renormalize=False, - quant_config=quant_config, - tp_size=tp_size, - prefix=f"{prefix}.experts") + self.gate = ReplicatedLinear( + hidden_size, num_experts, bias=False, quant_config=None + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -95,13 +110,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) return final_hidden_states.view(orig_shape) class OlmoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -148,8 +163,7 @@ def __init__( self.tp_size = tp_size self.tp_rank = get_tensor_model_parallel_rank() self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5) - self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, - eps=1e-5) + self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=1e-5) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, @@ -165,24 +179,26 @@ def __init__( rope_scaling=rope_scaling, is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def _apply_qk_norm(self, q: torch.Tensor, - k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) q = self.q_norm(q) k = self.k_norm(k) if self.tp_size > 1: - splitter = partial(split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k @@ 
-202,7 +218,6 @@ def forward( class OlmoeDecoderLayer(nn.Module): - def __init__( self, config: OlmoeConfig, @@ -214,8 +229,7 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 4096) + max_position_embeddings = getattr(config, "max_position_embeddings", 4096) self.self_attn = OlmoeAttention( hidden_size=self.hidden_size, @@ -251,8 +265,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, @@ -260,15 +273,13 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class OlmoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -285,13 +296,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: OlmoeDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=1e-5) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -322,10 +335,9 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -337,10 +349,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) + num_experts=self.config.num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -354,7 +366,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). 
if weight_name not in name: continue @@ -391,11 +403,13 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. @@ -407,7 +421,8 @@ def load_weights(self, weights: Iterable[tuple[str, # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501 @@ -419,8 +434,9 @@ def load_weights(self, weights: Iterable[tuple[str, name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -445,16 +461,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = OlmoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.model = OlmoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -466,16 +486,16 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index c4746166471c..eadfea6084e5 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only OPT model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -32,25 +33,33 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class OPTLearnedPositionalEmbedding(nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int): # OPT is set up so that if padding_idx is specified then offset the # embedding ids by 2 and adjust num_embeddings appropriately. Other @@ -63,7 +72,6 @@ def forward(self, positions: torch.Tensor): class OPTAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -75,8 +83,7 @@ def __init__( ) -> None: super().__init__() self.embed_dim = embed_dim - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() total_num_heads = num_heads assert num_heads % tensor_model_parallel_world_size == 0 self.num_heads = total_num_heads // tensor_model_parallel_world_size @@ -98,12 +105,14 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -117,7 +126,6 @@ def forward( class OPTDecoderLayer(nn.Module): - def __init__( self, config: OPTConfig, @@ -139,8 +147,8 @@ def __init__( self.do_layer_norm_before = config.do_layer_norm_before self.self_attn_layer_norm = nn.LayerNorm( - self.embed_dim, - elementwise_affine=config.layer_norm_elementwise_affine) + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) self.fc1 = ColumnParallelLinear( self.embed_dim, config.ffn_dim, @@ -157,8 +165,8 @@ def __init__( prefix=f"{prefix}.fc2", ) self.final_layer_norm = nn.LayerNorm( - self.embed_dim, - elementwise_affine=config.layer_norm_elementwise_affine) + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) def forward( self, @@ -191,7 +199,6 @@ def forward( class OPTDecoder(nn.Module): - def __init__( self, config: OPTConfig, @@ -210,24 +217,29 @@ def __init__( ) # Positional 
embeddings are replicated (not sharded). self.embed_positions = OPTLearnedPositionalEmbedding( - config.max_position_embeddings, config.hidden_size) + config.max_position_embeddings, config.hidden_size + ) # Project out & in will be replicated if they exist. if config.word_embed_proj_dim != config.hidden_size: - self.project_out = ReplicatedLinear(config.hidden_size, - config.word_embed_proj_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.project_out") + self.project_out = ReplicatedLinear( + config.hidden_size, + config.word_embed_proj_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.project_out", + ) else: self.project_out = None if config.word_embed_proj_dim != config.hidden_size: - self.project_in = ReplicatedLinear(config.word_embed_proj_dim, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.project_in") + self.project_in = ReplicatedLinear( + config.word_embed_proj_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.project_in", + ) else: self.project_in = None @@ -238,15 +250,18 @@ def __init__( if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm( config.hidden_size, - elementwise_affine=config.layer_norm_elementwise_affine) + elementwise_affine=config.layer_norm_elementwise_affine, + ) else: self.final_layer_norm = None self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: OPTDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -283,7 +298,6 @@ def forward( @support_torch_compile class OPTModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -291,13 +305,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - self.decoder = OPTDecoder(config, - cache_config, - quant_config, - prefix=f"{prefix}.decoder") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.decoder = OPTDecoder( + config, cache_config, quant_config, prefix=f"{prefix}.decoder" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.decoder.get_input_embeddings(input_ids) @@ -309,13 +322,11 @@ def forward( intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - return self.decoder(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + return self.decoder( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -325,7 +336,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, 
loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -345,8 +356,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -357,9 +367,11 @@ class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA): "qkv_proj": ["q_proj", "k_proj", "v_proj"], } - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "decoder.": "model.decoder.", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "decoder.": "model.decoder.", + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -367,18 +379,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = OPTModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = OPTModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if self.config.tie_word_embeddings: self.lm_head = self.model.decoder.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.word_embed_proj_dim, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.word_embed_proj_dim, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -390,8 +405,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -401,11 +417,11 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None), + skip_prefixes=( + ["lm_head.weight"] if self.config.tie_word_embeddings else None + ), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 586fea343d6f..0ce172938955 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -6,6 +6,7 @@ # Copyright (c) OrionStar Inc. 
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -19,25 +20,32 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class OrionMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -47,16 +55,15 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -67,7 +74,6 @@ def forward(self, x): class OrionAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -125,13 +131,15 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -147,7 +155,6 @@ def forward( class OrionDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -159,8 +166,7 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = OrionAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -179,10 +185,10 @@ def __init__( quant_config=quant_config, ) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -209,7 +215,6 @@ def forward( @support_torch_compile class OrionModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -226,13 +231,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: OrionDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory([ + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + [ "hidden_states", - ], config.hidden_size)) + ], + config.hidden_size, + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -255,14 +264,15 @@ def forward( for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer(positions, hidden_states) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -274,7 +284,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for 
param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -294,32 +304,34 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class OrionForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = OrionModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.model = OrionModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -331,8 +343,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -342,7 +355,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index f8674b4f0e3f..08ce8c5d83a6 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -16,7 +16,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" PyTorch Ovis model.""" +"""PyTorch Ovis model.""" + import math from collections.abc import Iterable, Mapping from typing import Annotated, Literal, Optional, Union @@ -33,15 +34,24 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.aimv2 import AIMv2Model from vllm.model_executor.models.siglip import SiglipVisionModel -from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, - init_vllm_registered_model, - maybe_prefix) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageSize, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processors.ovis import OvisProcessor @@ -74,7 +84,6 @@ def st_argmax(y_soft: torch.Tensor, dim: int): # straight-through softmax class VisualTokenizer(torch.nn.Module): - def __init__( self, config: PretrainedConfig, @@ -92,12 +101,15 @@ def __init__( head_dim = config.vocab_size - len(IMAGE_INDICATOR_IDS) self.head = torch.nn.Sequential( ReplicatedLinear( - config.backbone_config.hidden_size * config.hidden_stride * - config.hidden_stride, + config.backbone_config.hidden_size + * config.hidden_stride + * config.hidden_stride, head_dim, bias=False, return_bias=False, - ), torch.nn.LayerNorm(head_dim)) + ), + torch.nn.LayerNorm(head_dim), + ) def _init_backbone( self, @@ -120,8 +132,7 @@ def _init_backbone( quant_config=quant_config, prefix=prefix, ) - raise ValueError( - f"Unsupported visual tokenizer model_type: {model_type}") + raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}") @property def dtype(self) -> torch.dtype: @@ -132,16 +143,17 @@ def device(self) -> torch.device: return next(self.head.parameters()).device def tokenize(self, logits: torch.Tensor) -> torch.Tensor: - if self.config.tokenize_function == 'softmax': + if self.config.tokenize_function == "softmax": tokens = softmax(logits, dim=-1) - elif self.config.tokenize_function == 'gumbel_argmax': + elif self.config.tokenize_function == "gumbel_argmax": tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True) - elif self.config.tokenize_function == 'st_argmax': + elif self.config.tokenize_function == "st_argmax": tokens = st_argmax(logits, dim=-1) else: raise ValueError( - 'Invalid `max_type`, expected softmax or gumbel_argmax ' - f'or st_argmax, but got {self.config.tokenize_function}') + "Invalid `max_type`, expected softmax or gumbel_argmax " + f"or st_argmax, but got {self.config.tokenize_function}" + ) return tokens def encode(self, pixel_values: torch.Tensor) -> torch.Tensor: @@ -158,25 +170,30 @@ def encode(self, pixel_values: torch.Tensor) -> torch.Tensor: n, L, d = features.shape sqrt_l = int(L**0.5) assert sqrt_l**2 == L, ( - "The token sequence length should be a perfect square.") + "The token sequence length should be a perfect square." 
+ ) features = features.reshape(n, sqrt_l, sqrt_l, d) - pl = (self.config.hidden_stride - - (sqrt_l % - self.config.hidden_stride)) % self.config.hidden_stride + pl = ( + self.config.hidden_stride - (sqrt_l % self.config.hidden_stride) + ) % self.config.hidden_stride features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0) sqrt_l += pl - features = features.reshape(n, sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, - sqrt_l // self.config.hidden_stride, - self.config.hidden_stride, d) + features = features.reshape( + n, + sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, + sqrt_l // self.config.hidden_stride, + self.config.hidden_stride, + d, + ) # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d] features = features.permute(0, 1, 3, 2, 4, 5) # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d] features = features.flatten(3) # [n, sqrt_l/hs*sqrt_l/hs, hs*hs*d] features = features.reshape( - n, -1, - self.config.hidden_stride * self.config.hidden_stride * d) + n, -1, self.config.hidden_stride * self.config.hidden_stride * d + ) return features @@ -206,23 +223,25 @@ class OvisImagePatchInputs(TensorSchema): - patches_per_image: List of number of total patches for each image in the batch. """ + type: Literal["image_patches"] - flat_data: Annotated[torch.Tensor, - TensorShape("batch_patches", "patch_size")] + flat_data: Annotated[torch.Tensor, TensorShape("batch_patches", "patch_size")] indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")] - patches_per_image: Annotated[list[int], - TensorShape("num_patches_per_image")] + patches_per_image: Annotated[list[int], TensorShape("num_patches_per_image")] # This is used to restore the first two dimensions of `flat_data`. class VisualEmbedding(torch.nn.Embedding): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, visual_tokens: Tensor) -> Tensor: if visual_tokens.dtype in [ - torch.int8, torch.int16, torch.int32, torch.int64, torch.long + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.long, ]: return super().forward(visual_tokens) return torch.matmul(visual_tokens, self.weight) @@ -237,7 +256,6 @@ def dtype(self): class OvisProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object): return self.ctx.get_hf_processor( OvisProcessor, @@ -254,9 +272,10 @@ def get_image_segment_len(self) -> int: patch_grid_length = math.ceil(image_size / patch_size) assert patch_grid_length % hidden_stride == 0, ( f"patch_grid_length {patch_grid_length} is not divisible by " - f"hidden_stride {hidden_stride}") + f"hidden_stride {hidden_stride}" + ) # minus 1 for presented image token - return (patch_grid_length // hidden_stride)**2 - 1 + return (patch_grid_length // hidden_stride) ** 2 - 1 def get_image_pad_token(self) -> str: hf_text_config = self.get_hf_config().get_text_config() @@ -275,7 +294,6 @@ def get_image_size_with_most_features(self) -> ImageSize: class OvisDummyInputsBuilder(BaseDummyInputsBuilder[OvisProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) return IMAGE_TOKEN * num_images @@ -288,29 +306,28 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None mm_data = { - "image": - self._get_dummy_images(width=target_width, - 
height=target_height, - num_images=num_images, - overrides=image_overrides), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), } return mm_data class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): - def image_indicators_to_visual_tokens( self, image_indicators: list[int], ) -> list[int]: """ - Filter image indicators placeholders and convert them to corresponding + Filter image indicators placeholders and convert them to corresponding tokens in visual tokenizer. For example, [-301, -300, -302, -300, -303, -300, -304, -300, -305] should return [vocab_size-1, vocab_size-2, ..., vocab_size-5] @@ -356,7 +373,6 @@ def _apply_hf_processor_tokens_only( self, prompt_tokens: list[int], ) -> list[int]: - return prompt_tokens def _get_mm_fields_config( @@ -364,9 +380,11 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict(pixel_values=MultiModalFieldConfig.batched("image"), - grids=MultiModalFieldConfig.batched("image"), - indicator_tokens=MultiModalFieldConfig.batched("image")) + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image"), + indicator_tokens=MultiModalFieldConfig.batched("image"), + ) def _get_prompt_updates( self, @@ -374,7 +392,6 @@ def _get_prompt_updates( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptReplacement]: - def get_replacement_ovis(item_idx: int): out_item = out_mm_kwargs["image"][item_idx] grid = out_item["grids"].data @@ -391,11 +408,12 @@ def get_replacement_ovis(item_idx: int): ] -@MULTIMODAL_REGISTRY.register_processor(OvisMultiModalProcessor, - info=OvisProcessingInfo, - dummy_inputs=OvisDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + OvisMultiModalProcessor, + info=OvisProcessingInfo, + dummy_inputs=OvisDummyInputsBuilder, +) class Ovis(nn.Module, SupportsMultiModal, SupportsPP): - @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -421,17 +439,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.vte = VisualEmbedding( - self.config.visual_tokenizer_config.vocab_size, - self.config.hidden_size) + self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size + ) text_model_type = self.config.get_text_config().model_type self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] self.make_empty_intermediate_tensors = ( - self.get_language_model().make_empty_intermediate_tensors) + self.get_language_model().make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[OvisImagePatchInputs]: + self, **kwargs: object + ) -> Optional[OvisImagePatchInputs]: pixel_values = kwargs.pop("pixel_values", None) indicator_tokens = kwargs.pop("indicator_tokens", None) @@ -440,12 +460,15 @@ def _parse_and_validate_image_input( if pixel_values is not None and indicator_tokens is not None: if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + f"Incorrect type of pixel values. Got type: {type(pixel_values)}" + ) if not isinstance(indicator_tokens, (torch.Tensor, list)): - raise ValueError("Incorrect type of indicator_tokens. 
" - f"Got type: {type(pixel_values)}") + raise ValueError( + "Incorrect type of indicator_tokens. " + f"Got type: {type(pixel_values)}" + ) flat_data = flatten_bn(pixel_values, concat=True) if flat_data.ndim >= 3: @@ -453,49 +476,46 @@ def _parse_and_validate_image_input( return OvisImagePatchInputs( type="image_patches", flat_data=flat_data, - patches_per_image=[ - x.shape[0] for x in flatten_bn(pixel_values) - ], - indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), - concat=True), + patches_per_image=[x.shape[0] for x in flatten_bn(pixel_values)], + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True), ) raise AssertionError("This line should be unreachable.") def _process_image_input( - self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings: + self, image_input: OvisImagePatchInputs + ) -> MultiModalEmbeddings: image_patches_flat = image_input["flat_data"] patches_per_image = image_input["patches_per_image"] indicator_tokens = image_input["indicator_tokens"] indicator_per_image = list( - map(lambda x: x + 1 if x > 1 else x + 2, patches_per_image)) + map(lambda x: x + 1 if x > 1 else x + 2, patches_per_image) + ) target_dtype = self.visual_tokenizer.dtype - visual_tokens = self.visual_tokenizer( - image_patches_flat.to(target_dtype)) + visual_tokens = self.visual_tokenizer(image_patches_flat.to(target_dtype)) visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. indicator_embeds = self.vte(indicator_tokens) - indicator_embeds_per_image = indicator_embeds.split( - indicator_per_image) + indicator_embeds_per_image = indicator_embeds.split(indicator_per_image) visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) vision_embeddings = [] - for indicator, visual in zip(indicator_embeds_per_image, - visual_embeds_per_image): + for indicator, visual in zip( + indicator_embeds_per_image, visual_embeds_per_image + ): vision_embeddings_per_image = [] for i in range(visual.shape[0]): vision_embeddings_per_image.append( - torch.cat([indicator[i:i + 1], visual[i]], dim=0)) - vision_embeddings_per_image.append(indicator[i + 1:]) - vision_embeddings.append( - torch.cat(vision_embeddings_per_image, dim=0)) + torch.cat([indicator[i : i + 1], visual[i]], dim=0) + ) + vision_embeddings_per_image.append(indicator[i + 1 :]) + vision_embeddings.append(torch.cat(vision_embeddings_per_image, dim=0)) return tuple(vision_embeddings) - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -532,8 +552,7 @@ def compute_logits( logits = self.llm.compute_logits(hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index 18dea14379a6..8f73f2ff8263 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" PyTorch Ovis model.""" +"""PyTorch Ovis model.""" + from collections.abc import Iterable, Mapping from functools import partial from typing import Literal, Optional, TypedDict, Union @@ -13,18 +14,26 @@ 
from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.models.ovis import (OvisImagePatchInputs, - VisualEmbedding) +from vllm.model_executor.models.ovis import OvisImagePatchInputs, VisualEmbedding from vllm.model_executor.models.siglip2navit import Siglip2NavitModel -from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, - init_vllm_registered_model, - maybe_prefix) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageSize, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor @@ -71,12 +80,14 @@ class OvisVideoPatchInputs(TypedDict): def _ovis2_5_field_config(): - return dict(pixel_values=MultiModalFieldConfig.batched("image"), - grids=MultiModalFieldConfig.batched("image"), - indicator_tokens=MultiModalFieldConfig.batched("image"), - video_pixel_values=MultiModalFieldConfig.batched("video"), - video_indicator_tokens=MultiModalFieldConfig.batched("video"), - video_grids=MultiModalFieldConfig.batched("video")) + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image"), + indicator_tokens=MultiModalFieldConfig.batched("image"), + video_pixel_values=MultiModalFieldConfig.batched("video"), + video_indicator_tokens=MultiModalFieldConfig.batched("video"), + video_grids=MultiModalFieldConfig.batched("video"), + ) class VisualTokenizer(torch.nn.Module): @@ -108,7 +119,9 @@ def __init__( head_dim, bias=False, return_bias=False, - ), torch.nn.LayerNorm(head_dim)) + ), + torch.nn.LayerNorm(head_dim), + ) def _init_backbone( self, @@ -119,12 +132,13 @@ def _init_backbone( ): model_type = config.model_type if model_type == "siglip2_navit": - return Siglip2NavitModel(config=config, - quant_config=quant_config, - prefix=prefix, - use_data_parallel=use_data_parallel) - raise ValueError( - f"Unsupported visual tokenizer model_type: {model_type}") + return Siglip2NavitModel( + config=config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=use_data_parallel, + ) + raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}") @property def dtype(self) -> torch.dtype: @@ -135,22 +149,22 @@ def device(self) -> torch.device: return next(self.head.parameters()).device def tokenize(self, logits: torch.Tensor) -> torch.Tensor: - tokens = torch.softmax(logits, dim=-1, - dtype=torch.float32).to(logits.dtype) + tokens = torch.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype) return tokens - def encode(self, pixel_values: torch.Tensor, - grid_thws: torch.Tensor) -> torch.Tensor: + def encode( + self, pixel_values: torch.Tensor, grid_thws: torch.Tensor + ) -> torch.Tensor: features = self.vit(pixel_values, 
grid_thws) # refer to qwen2.5-vl patchmerger seq_len, _ = features.shape - features = features.reshape(seq_len // (self.config.hidden_stride**2), - -1) + features = features.reshape(seq_len // (self.config.hidden_stride**2), -1) return features - def forward(self, pixel_values: torch.Tensor, - grid_thws: torch.Tensor) -> torch.Tensor: + def forward( + self, pixel_values: torch.Tensor, grid_thws: torch.Tensor + ) -> torch.Tensor: features = self.encode(pixel_values, grid_thws) logits = self.head(features) tokens = self.tokenize(logits) @@ -167,7 +181,6 @@ def forward(self, pixel_values: torch.Tensor, class Ovis2_5ProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config() @@ -220,8 +233,9 @@ def get_num_image_tokens( def get_max_image_tokens(self) -> int: target_width, target_height = self.get_image_size_with_most_features() - return self.get_num_image_tokens(image_width=target_width, - image_height=target_height) + return self.get_num_image_tokens( + image_width=target_width, image_height=target_height + ) def _get_max_video_frames(self, max_tokens: int) -> int: target_width, target_height = self.get_image_size_with_most_features() @@ -247,8 +261,7 @@ def get_num_frames_with_most_features( max_images = mm_counts.get("image", 0) max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images - max_total_frames = self._get_max_video_frames(seq_len - - max_image_tokens) + max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) max_frames_per_video = max_total_frames // max(max_videos, 1) return max(max_frames_per_video, 1) @@ -260,9 +273,9 @@ def get_num_video_tokens( num_frames: int, image_processor: Optional[BaseImageProcessor], ) -> int: - num_video_tokens = self.get_num_image_tokens(image_width=image_width, - image_height=image_height, - num_frames=num_frames) + num_video_tokens = self.get_num_image_tokens( + image_width=image_width, image_height=image_height, num_frames=num_frames + ) return num_video_tokens def get_max_video_tokens( @@ -274,14 +287,12 @@ def get_max_video_tokens( return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -296,48 +307,47 @@ def get_dummy_mm_data( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( 
width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, overrides=video_overrides, - ) + ), } return mm_data -class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] - ): - +class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]): def visual_indicators_to_visual_tokens( self, visual_indicators: list[int], ) -> list[int]: """ - Filter image indicators placeholders and convert them to corresponding + Filter image indicators placeholders and convert them to corresponding tokens in visual tokenizer. """ hf_config = self.info.get_hf_config() vte_vocab_size = hf_config.visual_vocab_size return [ vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1 - for x in visual_indicators if x < -300 + for x in visual_indicators + if x < -300 ] def _call_hf_processor( @@ -388,7 +398,6 @@ def _apply_hf_processor_tokens_only( self, prompt_tokens: list[int], ) -> list[int]: - return prompt_tokens def _get_mm_fields_config( @@ -404,7 +413,6 @@ def _get_prompt_updates( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptReplacement]: - def get_replacement_ovis(item_idx, modality: str): if modality == "image": out_item = out_mm_kwargs["image"][item_idx] @@ -413,22 +421,26 @@ def get_replacement_ovis(item_idx, modality: str): out_item = out_mm_kwargs["video"][item_idx] grid = out_item["video_grids"].data hf_processor = self.info.get_hf_processor() - return hf_processor.construct_visual_placeholders(grid[0], ) + return hf_processor.construct_visual_placeholders( + grid[0], + ) return [ PromptReplacement( modality=modality, target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN, replacement=partial(get_replacement_ovis, modality=modality), - ) for modality in ("image", "video") + ) + for modality in ("image", "video") ] -@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor, - info=Ovis2_5ProcessingInfo, - dummy_inputs=Ovis2_5DummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + Ovis2_5MultiModalProcessor, + info=Ovis2_5ProcessingInfo, + dummy_inputs=Ovis2_5DummyInputsBuilder, +) class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -447,17 +459,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.visual_tokenizer", ) - self.vte = VisualEmbedding(config.visual_vocab_size, - config.hidden_size) + self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size) text_model_type = self.config.get_text_config().model_type self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] self.make_empty_intermediate_tensors = ( - self.get_language_model().make_empty_intermediate_tensors) + self.get_language_model().make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[OvisImagePatchInputs]: + self, **kwargs: object + ) -> Optional[OvisImagePatchInputs]: pixel_values = kwargs.pop("pixel_values", None) indicator_tokens = kwargs.pop("indicator_tokens", None) grids = kwargs.pop("grids", None) @@ -466,12 +479,15 @@ def _parse_and_validate_image_input( if pixel_values is not None and indicator_tokens is not None: if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + f"Incorrect type of pixel values. 
Got type: {type(pixel_values)}" + ) if not isinstance(indicator_tokens, (torch.Tensor, list)): - raise ValueError("Incorrect type of indicator_tokens. " - f"Got type: {type(indicator_tokens)}") + raise ValueError( + "Incorrect type of indicator_tokens. " + f"Got type: {type(indicator_tokens)}" + ) return OvisImagePatchInputs( type="image_patches", @@ -480,15 +496,15 @@ def _parse_and_validate_image_input( x.shape[0] // (self.config.vit_config.hidden_stride**2) for x in flatten_bn(pixel_values) ], - indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), - concat=True), + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True), grids=flatten_bn(flatten_bn(grids), concat=True), ) raise AssertionError("This line should be unreachable.") def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[OvisImagePatchInputs]: + self, **kwargs: object + ) -> Optional[OvisImagePatchInputs]: pixel_values = kwargs.pop("video_pixel_values", None) indicator_tokens = kwargs.pop("video_indicator_tokens", None) grids = kwargs.pop("video_grids", None) @@ -497,12 +513,15 @@ def _parse_and_validate_video_input( if pixel_values is not None and indicator_tokens is not None: if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + f"Incorrect type of pixel values. Got type: {type(pixel_values)}" + ) if not isinstance(indicator_tokens, (torch.Tensor, list)): - raise ValueError("Incorrect type of indicator_tokens. " - f"Got type: {type(indicator_tokens)}") + raise ValueError( + "Incorrect type of indicator_tokens. " + f"Got type: {type(indicator_tokens)}" + ) return OvisVideoPatchInputs( type="video_patches", @@ -511,8 +530,7 @@ def _parse_and_validate_video_input( x.shape[0] // (self.config.vit_config.hidden_stride**2) for x in flatten_bn(pixel_values) ], - indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), - concat=True), + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True), grids=flatten_bn(flatten_bn(grids), concat=True), ) @@ -527,30 +545,32 @@ def _process_image_input( grid_thws = image_input["grids"] indicator_per_image = list( - map(lambda x: 2 if x > 1 else x + 2, patches_per_image)) + map(lambda x: 2 if x > 1 else x + 2, patches_per_image) + ) target_dtype = self.visual_tokenizer.dtype visual_tokens = self.visual_tokenizer( - image_patches_flat.to(target_dtype), grid_thws) + image_patches_flat.to(target_dtype), grid_thws + ) visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. 
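# A minimal illustrative sketch (toy values, not taken from the diff) of the
# patch-merger arithmetic used in the Ovis2.5 code above: merging
# hidden_stride**2 neighbouring patches into one visual token is why both the
# tokenizer reshape and the patches_per_image count divide by hidden_stride**2.
import torch

hidden_stride = 2            # assumed toy value
seq_len, dim = 16, 8         # 16 flattened patches with 8-dim features
features = torch.randn(seq_len, dim)

# Mirrors features.reshape(seq_len // (hidden_stride**2), -1)
merged = features.reshape(seq_len // hidden_stride**2, -1)
assert merged.shape == (4, dim * hidden_stride**2)

# The same divisor yields the per-image token count used during input validation.
patches_per_image = seq_len // hidden_stride**2  # -> 4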
indicator_embeds = self.vte(indicator_tokens) visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) - indicator_embeds_per_image = indicator_embeds.split( - indicator_per_image) + indicator_embeds_per_image = indicator_embeds.split(indicator_per_image) vision_embeddings = [] - for indicator, visual in zip(indicator_embeds_per_image, - visual_embeds_per_image): + for indicator, visual in zip( + indicator_embeds_per_image, visual_embeds_per_image + ): vision_embeddings_per_image = [] visual = visual.unsqueeze(0) for i in range(visual.shape[0]): vision_embeddings_per_image.append( - torch.cat([indicator[i:i + 1], visual[i]], dim=0)) - vision_embeddings_per_image.append(indicator[i + 1:]) - vision_embeddings.append( - torch.cat(vision_embeddings_per_image, dim=0)) + torch.cat([indicator[i : i + 1], visual[i]], dim=0) + ) + vision_embeddings_per_image.append(indicator[i + 1 :]) + vision_embeddings.append(torch.cat(vision_embeddings_per_image, dim=0)) return tuple(vision_embeddings) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: @@ -559,20 +579,21 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. for input_key in kwargs: - if input_key in ("pixel_values", "indicator_tokens", - "grids") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("video_pixel_values", "video_indicator_tokens", - "video_grids") and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values", "indicator_tokens", "grids") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key + in ("video_pixel_values", "video_indicator_tokens", "video_grids") + and "videos" not in modalities + ): + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -620,8 +641,7 @@ def compute_logits( logits = self.llm.compute_logits(hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index ff6b8e4b9b4f..7bddfc5ee855 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -11,23 +11,39 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, - MultiModalUUIDDict) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptIndexTargets, - PromptInsertion, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + 
MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptIndexTargets, + PromptInsertion, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vision_encoder_info logger = init_logger(__name__) @@ -41,6 +57,7 @@ class PaliGemmaImagePixelInputs(TensorSchema): - h: Height - w: Width """ + type: Literal["pixel_values"] = "pixel_values" data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] @@ -52,16 +69,15 @@ class PaliGemmaImageEmbeddingInputs(TensorSchema): - ifs: Image feature size - hs: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, - PaliGemmaImageEmbeddingInputs] +PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, PaliGemmaImageEmbeddingInputs] class PaliGemmaMultiModalProjector(nn.Module): - def __init__(self, vision_hidden_size: int, projection_dim: int): super().__init__() @@ -73,7 +89,6 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: class PaliGemmaProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(PaliGemmaConfig) @@ -97,9 +112,7 @@ def get_num_image_tokens( ) -class PaliGemmaDummyInputsBuilder( - BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): - +class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -118,17 +131,16 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=max_image_size, - height=max_image_size, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=max_image_size, + height=max_image_size, + num_images=num_images, + overrides=image_overrides, + ) } -class PaliGemmaMultiModalProcessor( - BaseMultiModalProcessor[PaliGemmaProcessingInfo]): - +class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -171,7 +183,8 @@ def _get_prompt_updates( def get_insertion(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -196,7 +209,8 @@ def get_insertion(item_idx: int): PromptInsertion( modality="image", target=PromptIndexTargets.prefix( - [bos_token_id] if tokenizer.add_bos_token else []), + [bos_token_id] if tokenizer.add_bos_token else [] + ), insertion=get_insertion, ) ] @@ -209,11 +223,13 @@ def apply( tokenization_kwargs: Optional[Mapping[str, object]] = None, mm_uuids: 
Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: - mm_inputs = super().apply(prompt, - mm_data, - hf_processor_mm_kwargs, - tokenization_kwargs, - mm_uuids=mm_uuids) + mm_inputs = super().apply( + prompt, + mm_data, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) prompt_token_ids = mm_inputs["prompt_token_ids"] tokenizer = self.info.get_tokenizer() @@ -231,9 +247,9 @@ def apply( @MULTIMODAL_REGISTRY.register_processor( PaliGemmaMultiModalProcessor, info=PaliGemmaProcessingInfo, - dummy_inputs=PaliGemmaDummyInputsBuilder) -class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): + dummy_inputs=PaliGemmaDummyInputsBuilder, +) +class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -253,7 +269,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, "model.vision_tower.": "vision_tower.", "model.multi_modal_projector.": "multi_modal_projector.", "lm_head.": "language_model.lm_head.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -270,13 +287,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.multimodal_config = multimodal_config - self.vision_tower = SiglipVisionModel(config.vision_config, - quant_config, - prefix=maybe_prefix( - prefix, "vision_tower")) + self.vision_tower = SiglipVisionModel( + config.vision_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) self.multi_modal_projector = PaliGemmaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, - projection_dim=config.vision_config.projection_dim) + projection_dim=config.vision_config.projection_dim, + ) self.quant_config = quant_config @@ -293,10 +312,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.language_model.logits_processor.scale *= logit_scale self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: + self, **kwargs: object + ) -> Optional[PaliGemmaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -307,12 +328,11 @@ def _parse_and_validate_image_input( pixel_values = flatten_bn(pixel_values, concat=True) h = w = self.config.vision_config.image_size - return PaliGemmaImagePixelInputs(type="pixel_values", - data=pixel_values, - resolve_bindings={ - "h": h, - "w": w - }) + return PaliGemmaImagePixelInputs( + type="pixel_values", + data=pixel_values, + resolve_bindings={"h": h, "w": w}, + ) if image_embeds is not None: image_embeds = flatten_bn(image_embeds, concat=True) @@ -329,7 +349,6 @@ def _image_pixels_to_features( vision_tower: SiglipVisionModel, pixel_values: torch.Tensor, ) -> torch.Tensor: - target_dtype = vision_tower.get_input_embeddings().weight.dtype image_features = vision_tower(pixel_values.to(dtype=target_dtype)) @@ -339,7 +358,6 @@ def _process_image_input( self, image_input: PaliGemmaImageInputs, ) -> torch.Tensor: - if image_input["type"] == "image_embeds": return image_input["data"] @@ -355,8 +373,7 @@ def _process_image_input( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, 
**kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -365,19 +382,20 @@ def get_multimodal_embeddings(self, vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5) return vision_embeddings - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object) -> IntermediateTensors: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -387,7 +405,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 23fb7bb85215..d3df5f9a59b5 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
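# A small illustrative sketch (assumed toy sizes, not from the diff) of the
# embedding scaling seen in the PaliGemma code above: the projected vision
# embeddings are multiplied by hidden_size**-0.5 so that, after the
# Gemma-style sqrt(hidden_size) scaling applied to input embeddings inside
# the language model, image and text tokens enter the backbone on a
# comparable scale.
import torch

hidden_size = 2048                              # assumed toy value
vision_embeddings = torch.randn(16, hidden_size)
vision_embeddings = vision_embeddings * (hidden_size ** -0.5)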
"""Inference-only persimmon model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -35,35 +36,42 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class PersimmonMLP(nn.Module): - - def __init__(self, - config: PersimmonConfig, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, config: PersimmonConfig, quant_config: Optional[QuantizationConfig] = None + ): super().__init__() - self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - quant_config=quant_config) - self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, - config.hidden_size, - quant_config=quant_config) + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, config.intermediate_size, quant_config=quant_config + ) + self.dense_4h_to_h = RowParallelLinear( + config.intermediate_size, config.hidden_size, quant_config=quant_config + ) self.act = get_act_fn(config.hidden_act) def forward(self, hidden_states) -> torch.Tensor: @@ -74,12 +82,13 @@ def forward(self, hidden_states) -> torch.Tensor: class PersimmonAttention(nn.Module): - - def __init__(self, - config: PersimmonConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.config = config tensor_parallel_world_size = get_tensor_model_parallel_world_size() @@ -123,12 +132,14 @@ def __init__(self, partial_rotary_factor=self.partial_rotary_factor, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scale=self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def _split_heads(self, x: torch.Tensor) -> torch.Tensor: # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim] @@ -167,23 +178,28 @@ def forward( class PersimmonDecoderLayer(nn.Module): - - def __init__(self, - config: PersimmonConfig, - cache_config: 
Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = PersimmonAttention(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = PersimmonAttention( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.mlp = PersimmonMLP(config, quant_config=quant_config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) def forward( self, @@ -214,7 +230,6 @@ def forward( @support_torch_compile class PersimmonModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -224,18 +239,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size = config.vocab_size self.config = config - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: PersimmonDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.final_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + self.final_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -262,8 +281,7 @@ def forward( hidden_states = self.final_layernorm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: @@ -282,35 +300,38 @@ def load_weights(self, weights: Iterable[tuple[str, if output_dim is not None: loaded_weight_shape = loaded_weight.shape loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) + loaded_weight_shape[:output_dim] + + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1 :] + ) + loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1) loaded_weight = loaded_weight.reshape(loaded_weight_shape) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) 
weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class PersimmonForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config self.vocab_size = config.vocab_size - self.model = PersimmonModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - bias=False, - prefix=maybe_prefix(prefix, "lm_head")) + self.model = PersimmonModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + bias=False, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -337,7 +358,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 9cf288e85005..779b391008bb 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -37,6 +37,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
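# A toy illustrative sketch (assumed sizes, not from the diff) of the fused
# query_key_value reshuffle in the Persimmon weight loader above: the output
# axis is viewed as (num_heads, 3, head_dim), the first two axes are swapped,
# and the tensor is flattened back, i.e. per-head [q, k, v] blocks are
# reordered into contiguous q/k/v groups.
import torch

num_heads, head_dim, hidden = 4, 8, 32
w = torch.randn(num_heads * 3 * head_dim, hidden)   # per-head [q, k, v] layout

shape = w.shape
w = w.view((num_heads, 3, head_dim) + shape[1:])    # split the output axis
w = w.transpose(0, 1)                               # -> (3, num_heads, head_dim, hidden)
w = w.reshape(shape)                                # back to 2-D, grouped by q/k/v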
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -50,40 +51,47 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class PhiAttention(nn.Module): - - def __init__(self, - config: PhiConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.total_num_heads - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size # pylint: disable=C0103 self.qkv_proj = QKVParallelLinear( @@ -100,28 +108,31 @@ def __init__(self, ) scaling = self.head_size**-0.5 - rotary_dim = int(config.partial_rotary_factor * - (config.hidden_size // config.num_attention_heads)) + rotary_dim = int( + config.partial_rotary_factor + * (config.hidden_size // config.num_attention_heads) + ) assert rotary_dim % 2 == 0 # pylint: disable=C0301 # Refer to: # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518 rope_theta = getattr(config, "rope_theta", 10000.0) - max_position_embeddings = getattr(config, "max_position_embeddings", - 2048) + max_position_embeddings = getattr(config, "max_position_embeddings", 2048) self.rotary_emb = get_rope( self.head_size, rotary_dim=rotary_dim, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, - self.head_size, - scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -137,10 +148,9 @@ 
def forward( class PhiMLP(nn.Module): - - def __init__(self, - config: PhiConfig, - quant_config: Optional[QuantizationConfig] = None): + def __init__( + self, config: PhiConfig, quant_config: Optional[QuantizationConfig] = None + ): super().__init__() n_inner = getattr(config, "n_inner", None) @@ -166,19 +176,20 @@ def forward(self, hidden_states): class PhiLayer(nn.Module): - - def __init__(self, - config: PhiConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attn") + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.self_attn = PhiAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attn" + ) self.mlp = PhiMLP(config, quant_config) def forward( @@ -199,7 +210,6 @@ def forward( @support_torch_compile class PhiModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -209,18 +219,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: PhiLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.final_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + lambda prefix: PhiLayer(config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.final_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -250,13 +262,12 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v") + ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -265,7 +276,7 @@ def load_weights(self, weights: Iterable[tuple[str, if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -287,8 +298,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + 
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -315,17 +325,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config - self.model = PhiModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = PhiModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -337,8 +351,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states @@ -346,11 +361,9 @@ def compute_logits( self, hidden_states: torch.Tensor, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - self.lm_head.bias) + logits = self.logits_processor(self.lm_head, hidden_states, self.lm_head.bias) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py index f4e870c53030..56c8755123d3 100644 --- a/vllm/model_executor/models/phi3.py +++ b/vllm/model_executor/models/phi3.py @@ -8,7 +8,6 @@ class Phi3ForCausalLM(LlamaForCausalLM): - packed_modules_mapping = { "qkv_proj": [ "qkv_proj", diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index f5720e726c48..d972604db9cd 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -21,68 +21,90 @@ import regex as re import torch import torch.nn as nn -from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, - ProcessorMixin) +from transformers import ( + BatchFeature, + CLIPVisionConfig, + PretrainedConfig, + ProcessorMixin, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.multimodal.processing import 
(BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalPromptUpdates, - PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate, - ResolvedPromptUpdate) -# yapf: enable +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, + ResolvedPromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP, - SupportsQuant) -from .utils import (AutoWeightsLoader, WeightsMapper, - _merge_multimodal_embeddings, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) logger = init_logger(__name__) # Cannot find the following 2 numbers from hf config. _IMAGE_TOKEN_ID = 32044 -CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, - hidden_act="quick_gelu", - hidden_size=1024, - image_size=336, - intermediate_size=4096, - num_attention_heads=16, - num_channels=3, - num_hidden_layers=24, - patch_size=14, - projection_dim=768) - - -def _init_img_processor(hf_config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "") -> CLIPVisionModel: +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( + dropout=0.0, + hidden_act="quick_gelu", + hidden_size=1024, + image_size=336, + intermediate_size=4096, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, +) + + +def _init_img_processor( + hf_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", +) -> CLIPVisionModel: clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG - layer_idx = hf_config.img_processor.get('layer_idx', -2) + layer_idx = hf_config.img_processor.get("layer_idx", -2) # Initialize the CLIP only up to the required feature layer if layer_idx < 0: - num_hidden_layers = clip_config.num_hidden_layers + \ - layer_idx + 1 + num_hidden_layers = clip_config.num_hidden_layers + layer_idx + 1 else: num_hidden_layers = layer_idx + 1 @@ -109,10 +131,11 @@ class Phi3VImagePixelInputs(TensorSchema): type: Literal["pixel_values", "image_embeds"] = "pixel_values" # Supports either a stacked tensor or a list of (p, 3, h, w) tensors - data: Annotated[ + pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} - ), # 'p' may vary across items + TensorShape( + "bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # 'p' may vary across items ] # Stacked tensor with height and width for each image @@ -127,6 +150,7 @@ class Phi3VImageEmbeddingInputs(TensorSchema): - f: Image feature size (e.g., number of tokens per image) - h: Hidden size (must match language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[ Union[torch.Tensor, list[torch.Tensor]], @@ -138,15 
+162,13 @@ class Phi3VImageEmbeddingInputs(TensorSchema): class Phi3ImageEmbeddingBase(nn.Module): - def __init__(self) -> None: super().__init__() self.layer_idx: int self.type_feature: str self.img_processor: CLIPVisionModel - def get_img_features(self, - img_embeds: torch.FloatTensor) -> torch.FloatTensor: + def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: TYPE_FEATURE = self.type_feature # NOTE: we skip the step to select the vision feature layer since @@ -167,52 +189,51 @@ def get_img_features(self, class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): """Phi3 Image embedding with HD transform.""" - def __init__(self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "") -> None: + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + ) -> None: super().__init__() # n_embed or hidden_size - hidden_size = config.n_embd if hasattr( - config, 'n_embd') else config.hidden_size + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size self.img_processor = _init_img_processor( - config, quant_config, prefix=f"{prefix}.img_processor") + config, quant_config, prefix=f"{prefix}.img_processor" + ) - image_dim_out = config.img_processor['image_dim_out'] - self.num_img_tokens = config.img_processor['num_img_tokens'] + image_dim_out = config.img_processor["image_dim_out"] + self.num_img_tokens = config.img_processor["num_img_tokens"] self.image_dim_out = image_dim_out # global_gn and sub_gn for hd transform, serves as line separator - self.use_hd_transform = config.embd_layer.get('use_hd_transform', - False) + self.use_hd_transform = config.embd_layer.get("use_hd_transform", False) self.with_learnable_separator = config.embd_layer.get( - 'with_learnable_separator', False) - self.hd_transform_order = config.embd_layer.get( - 'hd_transform_order', 'glb_sub') + "with_learnable_separator", False + ) + self.hd_transform_order = config.embd_layer.get("hd_transform_order", "glb_sub") # with_hd_transform and with_learnable_separator should have same value assert self.use_hd_transform and self.with_learnable_separator # 1024 * 4, merge spatial to channel dimension self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4])) - self.sub_GN = nn.Parameter( - torch.empty([1, 1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4])) dim_projection = hidden_size depth = 2 layers = [nn.Linear(image_dim_out * 4, dim_projection)] for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) - self.type_feature = config.img_processor.get('type_feature', 'patch') + self.type_feature = config.img_processor.get("type_feature", "patch") - def forward(self, pixel_values: torch.FloatTensor, - image_sizes: torch.Tensor) -> torch.FloatTensor: + def forward( + self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor + ) -> torch.FloatTensor: """ process image and return vision embeddings. 
@@ -222,19 +243,19 @@ def forward(self, pixel_values: torch.FloatTensor, num_images, num_crops, c, h, w = pixel_values.shape pixel_values = pixel_values.flatten(0, 1) img_features = self.get_img_features(pixel_values) - img_features = img_features.reshape(num_images, num_crops, -1, - self.image_dim_out) - image_features_proj = self.hd_feature_transform( - img_features, image_sizes) + img_features = img_features.reshape( + num_images, num_crops, -1, self.image_dim_out + ) + image_features_proj = self.hd_feature_transform(img_features, image_sizes) return image_features_proj def hd_feature_transform(self, image_features, image_sizes): """ image_features: (num_images, num_crops+1, 24*24, 1024) """ - assert ( - self.hd_transform_order == 'sub_glb' - ), f'hd_transform_order `{self.hd_transform_order}` not implemented' + assert self.hd_transform_order == "sub_glb", ( + f"hd_transform_order `{self.hd_transform_order}` not implemented" + ) if isinstance(self.img_projection, nn.Sequential): target_device = self.img_projection[0].bias.device target_dtype = self.img_projection[0].bias.dtype @@ -242,13 +263,14 @@ def hd_feature_transform(self, image_features, image_sizes): target_device = self.img_projection.bias.device target_dtype = self.img_projection.bias.dtype - global_image_features = image_features[:, - 0] # (num_images, 24*24, 1024) + global_image_features = image_features[:, 0] # (num_images, 24*24, 1024) # global feature can be viewed as a special HD case with num_crops 1x1 global_image_features_hd = self.reshape_hd_patches_2x2merge( - global_image_features, 1, 1) + global_image_features, 1, 1 + ) global_image_features_hd_newline = self.add_image_newline( - global_image_features_hd) + global_image_features_hd + ) batch_image_features_proj = [] # need a for loop to process each image because of different image sizes @@ -261,21 +283,27 @@ def hd_feature_transform(self, image_features, image_sizes): # NOTE: real num_crops is padded # (num_crops, 24*24, 1024) - sub_image_features = image_features[i, 1:1 + num_crops] + sub_image_features = image_features[i, 1 : 1 + num_crops] sub_image_features_hd = self.reshape_hd_patches_2x2merge( - sub_image_features, h_crop, w_crop) + sub_image_features, h_crop, w_crop + ) sub_image_features_hd_newline = self.add_image_newline( - sub_image_features_hd) + sub_image_features_hd + ) # [sub features, separator, global features] - image_embeddings = torch.cat([ - sub_image_features_hd_newline.squeeze( - 0), # (h_crop*12*(w_crop*12+1), 4096) - self.glb_GN.squeeze(0), - global_image_features_hd_newline[i], - ]) + image_embeddings = torch.cat( + [ + sub_image_features_hd_newline.squeeze( + 0 + ), # (h_crop*12*(w_crop*12+1), 4096) + self.glb_GN.squeeze(0), + global_image_features_hd_newline[i], + ] + ) img_proj = self.img_projection( - image_embeddings.to(target_device, target_dtype)) + image_embeddings.to(target_device, target_dtype) + ) batch_image_features_proj.append(img_proj) return batch_image_features_proj @@ -295,11 +323,13 @@ def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop): .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 .reshape(N, -1, 4 * C) # N, 144, 4096 - .reshape(num_images, h_crop, w_crop, H // 2, H // 2, - -1) # n_img, h_crop, w_crop, 12, 12, 4096 + .reshape( + num_images, h_crop, w_crop, H // 2, H // 2, -1 + ) # n_img, h_crop, w_crop, 12, 12, 4096 .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 - .reshape(num_images, h_crop * H // 2, w_crop * H // 
2, - 4 * C) # n_img, h_crop*12, w_crop*12, 4096 + .reshape( + num_images, h_crop * H // 2, w_crop * H // 2, 4 * C + ) # n_img, h_crop*12, w_crop*12, 4096 ) return image_features_hd @@ -310,16 +340,16 @@ def add_image_newline(self, image_features_hd): """ num_images, h, w, hid_dim = image_features_hd.shape # add the newline token to the HD image feature patches - newline_embeddings = self.sub_GN.expand(num_images, h, -1, - -1) # (n_img, h, 1, hid_dim) + newline_embeddings = self.sub_GN.expand( + num_images, h, -1, -1 + ) # (n_img, h, 1, hid_dim) image_features_hd_newline = torch.cat( - [image_features_hd, newline_embeddings], - dim=2).reshape(num_images, -1, hid_dim) + [image_features_hd, newline_embeddings], dim=2 + ).reshape(num_images, -1, hid_dim) return image_features_hd_newline class Phi3VProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} @@ -344,7 +374,6 @@ def get_image_size_with_most_features(self) -> ImageSize: class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -361,22 +390,21 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): - def _call_hf_processor( self, prompt: str, @@ -423,7 +451,8 @@ def _get_prompt_updates( def get_replacement_phi3v(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -487,8 +516,7 @@ def _apply_prompt_updates( # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407 pattern = r"<\|image_\d+\|>" prompt_chunks = [ - tokenizer(chunk).input_ids - for chunk in re.split(pattern, text) + tokenizer(chunk).input_ids for chunk in re.split(pattern, text) ] image_tags = [ tokenizer(chunk, add_special_tokens=False).input_ids @@ -497,8 +525,10 @@ def _apply_prompt_updates( if len(prompt_chunks) > len(image_tags): image_tags.append([]) token_ids = [ - e for sublist in zip(prompt_chunks, image_tags) - for ele in sublist for e in ele + e + for sublist in zip(prompt_chunks, image_tags) + for ele in sublist + for e in ele ] token_ids, placeholders = super()._apply_prompt_updates( @@ -507,8 +537,9 @@ def _apply_prompt_updates( ) # Keep the behavior in line with HF processor - if token_ids[:2] == tokenizer.encode("<s> <|image|>", - add_special_tokens=False): + if len(mm_prompt_updates) and ( + token_ids[:2] == tokenizer.encode("<s> <|image|>", add_special_tokens=False) + ): token_ids = [token_ids[0], *token_ids[2:]] placeholders = { modality: [ @@ -518,7 +549,8 @@ def _apply_prompt_updates( start_idx=p.start_idx - 1, tokens=p.tokens, is_embed=p.is_embed, - ) for p in ps + ) + for p in ps ] for modality, ps in 
placeholders.items() } @@ -526,18 +558,20 @@ def _apply_prompt_updates( return token_ids, placeholders -@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor, - info=Phi3VProcessingInfo, - dummy_inputs=Phi3VDummyInputsBuilder) -class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, - SupportsQuant): +@MULTIMODAL_REGISTRY.register_processor( + Phi3VMultiModalProcessor, + info=Phi3VProcessingInfo, + dummy_inputs=Phi3VDummyInputsBuilder, +) +class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "model.vision_embed_tokens.wte": "embed_tokens", "model.vision_embed_tokens.": "vision_embed_tokens.", "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -566,7 +600,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vision_embed_tokens = Phi3HDImageEmbedding( config, self.quant_config, - prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) + prefix=maybe_prefix(prefix, "model.vision_embed_tokens"), + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -580,10 +615,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Phi3VImageInputs]: + self, **kwargs: object + ) -> Optional[Phi3VImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_sizes = kwargs.pop("image_sizes", None) image_embeds = kwargs.pop("image_embeds", None) @@ -594,12 +631,13 @@ def _parse_and_validate_image_input( if pixel_values is not None: return Phi3VImagePixelInputs( type="pixel_values", - data=flatten_bn(pixel_values), + pixel_values=flatten_bn(pixel_values), image_sizes=flatten_bn(image_sizes, concat=True), resolve_bindings={ "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, - "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size - }) + "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, + }, + ) if image_embeds is not None: return Phi3VImageEmbeddingInputs( @@ -613,7 +651,6 @@ def _process_image_input( self, image_input: Phi3VImageInputs, ) -> torch.Tensor: - if image_input["type"] == "image_embeds": image_data = image_input["data"] if is_list_of(image_data, torch.Tensor): @@ -628,16 +665,16 @@ def _process_image_input( ) assert self.vision_embed_tokens is not None - image_embeds = self.vision_embed_tokens(image_input["data"], - image_input["image_sizes"]) + image_embeds = self.vision_embed_tokens( + image_input["pixel_values"], image_input["image_sizes"] + ) return image_embeds def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -666,7 +703,8 @@ def get_input_embeddings( raise ValueError( "`get_input_embeddings` now requires `is_multimodal` arg, " "please update your model runner according to " - "https://github.com/vllm-project/vllm/pull/16229.") + "https://github.com/vllm-project/vllm/pull/16229." 
+ ) return _merge_multimodal_embeddings( inputs_embeds=inputs_embeds, @@ -674,20 +712,20 @@ def get_input_embeddings( is_multimodal=is_multimodal, ) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object): - + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ): if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -697,12 +735,9 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - autoloaded_weights = loader.load_weights(weights, - mapper=self.hf_to_vllm_mapper) + autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) # The HF config doesn't specify whether these are tied, # so we detect it this way diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index a5cc87d327b5..002233d0677b 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -8,35 +8,60 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import (BatchFeature, Phi4MultimodalAudioConfig, - Phi4MultimodalConfig, Phi4MultimodalFeatureExtractor, - Phi4MultimodalImageProcessorFast) +from transformers import ( + BatchFeature, + Phi4MultimodalAudioConfig, + Phi4MultimodalConfig, + Phi4MultimodalFeatureExtractor, + Phi4MultimodalImageProcessorFast, +) from transformers import Phi4MultimodalProcessor as Phi4MMProcessor from transformers.models.phi4_multimodal.modeling_phi4_multimodal import ( - Phi4MultimodalAudioConvModule, Phi4MultimodalAudioNemoConvSubsampling, - Phi4MultimodalAudioRelativeAttentionBias, adaptive_enc_mask, unfold_tensor) + Phi4MultimodalAudioConvModule, + Phi4MultimodalAudioNemoConvSubsampling, + Phi4MultimodalAudioRelativeAttentionBias, + adaptive_enc_mask, + unfold_tensor, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.distributed import (divide, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from 
vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, - ImageProcessorItems, ImageSize, - MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -44,14 +69,20 @@ from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) _AUDIO_MAX_SOUNDFILE_SIZE = 241_000 -def _get_padding_size(orig_width: int, orig_height: int, target_height: int, - target_width: int): +def _get_padding_size( + orig_width: int, orig_height: int, target_height: int, target_width: int +): ratio_width = target_width / orig_width ratio_height = target_height / orig_height @@ -65,7 +96,6 @@ def _get_padding_size(orig_width: int, orig_height: int, target_height: int, class Phi4MMProjector(nn.Module): - def __init__(self, input_size: int, hidden_size: int): super().__init__() self.up = ColumnParallelLinear(input_size, hidden_size) @@ -89,41 +119,44 @@ def __init__(self, config: Phi4MultimodalConfig): self.crop_size = config.vision_config.crop_size self.image_dim_out = config.vision_config.hidden_size - n_patches = (config.vision_config.image_size // - config.vision_config.patch_size) + n_patches = config.vision_config.image_size // config.vision_config.patch_size if n_patches % 2 != 0: self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) n_patches += 1 - self.num_img_tokens = (n_patches // 2)**2 + self.num_img_tokens = (n_patches // 2) ** 2 - num_hidden_layers = (config.vision_config.num_hidden_layers + - self.layer_idx + - 1 if self.layer_idx < 0 else self.layer_idx + 1) + num_hidden_layers = ( + config.vision_config.num_hidden_layers + self.layer_idx + 1 + if self.layer_idx < 0 + else self.layer_idx + 1 + ) self.img_processor = Idefics2VisionTransformer( config.vision_config, require_post_norm=False, - num_hidden_layers_override=num_hidden_layers) + num_hidden_layers_override=num_hidden_layers, + ) self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) - self.img_projection = Phi4MMProjector(self.image_dim_out, - config.hidden_size) + self.img_projection = Phi4MMProjector(self.image_dim_out, config.hidden_size) self.global_img_feature_extensor = nn.Parameter( - torch.zeros([1, 1, self.image_dim_out])) + torch.zeros([1, 1, self.image_dim_out]) + ) self.sub_img_feature_extensor = nn.Parameter( - torch.zeros([1, 1, 1, self.image_dim_out])) + torch.zeros([1, 1, 1, self.image_dim_out]) + ) def get_img_features( self, img_embeds: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = 
None, ) -> torch.FloatTensor: - img_feature = self.img_processor(img_embeds, - patch_attention_mask=attention_mask) + img_feature = self.img_processor( + img_embeds, patch_attention_mask=attention_mask + ) patch_feature = img_feature # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) - patch_feature = patch_feature.view(-1, width, width, - patch_feature.size(-1)) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) if getattr(self, "img_processor_padding", None) is not None: @@ -132,9 +165,8 @@ def get_img_features( # convert to NHWC patch_feature = patch_feature.permute(0, 2, 3, 1) patch_feature = patch_feature.view( - -1, - patch_feature.size(1) * patch_feature.size(2), - patch_feature.size(-1)) + -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1) + ) return patch_feature def forward( @@ -144,7 +176,8 @@ def forward( image_attention_mask: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: image_pixel_values = image_pixel_values.to( - self.img_processor.embeddings.patch_embedding.weight.dtype) + self.img_processor.embeddings.patch_embedding.weight.dtype + ) target_device = self.img_projection.up.bias.device target_dtype = self.img_projection.up.bias.dtype @@ -154,11 +187,13 @@ def forward( img_features = self.get_img_features( image_pixel_values.flatten(0, 1), attention_mask=image_attention_mask.flatten(0, 1).to( - dtype=bool, device=target_device), + dtype=bool, device=target_device + ), ) base_feat_size = int(np.sqrt(img_features.shape[1])) - img_features = img_features.view(batch_size, -1, base_feat_size**2, - self.image_dim_out) + img_features = img_features.view( + batch_size, -1, base_feat_size**2, self.image_dim_out + ) image_sizes = image_sizes.view(-1, 2) output_imgs = [] @@ -169,58 +204,70 @@ def forward( area_ratio = height_ratio * width_ratio global_img = img_features[idx, :1] - global_img = global_img.reshape(1, base_feat_size, base_feat_size, - self.image_dim_out).contiguous() + global_img = global_img.reshape( + 1, base_feat_size, base_feat_size, self.image_dim_out + ).contiguous() temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, base_feat_size, 1, 1) - global_img = torch.cat([global_img, temporary_extensor], - dim=2).reshape(1, -1, self.image_dim_out) + 1, base_feat_size, 1, 1 + ) + global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape( + 1, -1, self.image_dim_out + ) sub_img = img_features[idx, 1:] sub_img = sub_img[:area_ratio] - sub_img = (sub_img.reshape( - height_ratio, width_ratio, base_feat_size, base_feat_size, - self.image_dim_out).transpose(1, 2).reshape( - 1, height_ratio * base_feat_size, + sub_img = ( + sub_img.reshape( + height_ratio, + width_ratio, + base_feat_size, + base_feat_size, + self.image_dim_out, + ) + .transpose(1, 2) + .reshape( + 1, + height_ratio * base_feat_size, width_ratio * base_feat_size, - self.image_dim_out).contiguous()) + self.image_dim_out, + ) + .contiguous() + ) if image_attention_mask is not None: reshaped_image_attention_mask = ( - image_attention_mask[idx, 1:area_ratio + 1, - 0::2, 0::2].reshape( - height_ratio, width_ratio, - base_feat_size, - base_feat_size).transpose( - 1, 2).reshape( - 1, height_ratio * - base_feat_size, - width_ratio * - base_feat_size)) - useful_height = int( - reshaped_image_attention_mask[0, :, 0].sum().item()) - useful_width = int( - reshaped_image_attention_mask[0, 0, :].sum().item()) + image_attention_mask[idx, 1 : area_ratio + 1, 
0::2, 0::2] + .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) + .transpose(1, 2) + .reshape( + 1, height_ratio * base_feat_size, width_ratio * base_feat_size + ) + ) + useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) sub_img = sub_img[:, :useful_height, :useful_width] temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, useful_height, 1, 1) + 1, useful_height, 1, 1 + ) else: temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, height_ratio * base_feat_size, 1, 1) + 1, height_ratio * base_feat_size, 1, 1 + ) - sub_img = torch.cat([sub_img, temporary_extensor], - dim=2).reshape(1, -1, self.image_dim_out) + sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape( + 1, -1, self.image_dim_out + ) # Merge global and sub output_imgs.append( torch.cat( - [sub_img, self.global_img_feature_extensor, global_img], - dim=1)) + [sub_img, self.global_img_feature_extensor, global_img], dim=1 + ) + ) img_set_tensor = [] for output_img in output_imgs: - output_img = output_img.to(device=target_device, - dtype=target_dtype) + output_img = output_img.to(device=target_device, dtype=target_dtype) img_feature_proj = self.img_projection(output_img) img_set_tensor.append(img_feature_proj.flatten(0, 1)) @@ -228,7 +275,6 @@ def forward( class Phi4MultimodalAudioMLP(nn.Module): - def __init__( self, config: Phi4MultimodalAudioConfig, @@ -239,15 +285,19 @@ def __init__( self.layer_norm = nn.LayerNorm(config.hidden_size) self.act_fn = MulAndSilu() self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, [config.intermediate_size] * 2, + config.hidden_size, + [config.intermediate_size] * 2, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, bias=True, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.down_proj", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.layer_norm(hidden_states) @@ -258,7 +308,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Phi4MultimodalAudioAttention(nn.Module): - def __init__( self, config: Phi4MultimodalAudioConfig, @@ -274,7 +323,8 @@ def __init__( raise ValueError( "embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." 
+ ) self.scale = self.head_dim**-0.5 self.qkv_proj = QKVParallelLinear( @@ -331,7 +381,6 @@ def forward( class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): - def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() @@ -350,11 +399,9 @@ def forward( residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) hidden_states = self.layer_norm_att(residual) - hidden_states = residual + self.self_attn(hidden_states, - attention_mask) + hidden_states = residual + self.self_attn(hidden_states, attention_mask) hidden_states = hidden_states + self.conv(hidden_states) - hidden_states = hidden_states + 0.5 * self.feed_forward_out( - hidden_states) + hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) out = self.layer_norm(hidden_states) @@ -368,7 +415,7 @@ class Phi4MMAudioMeanVarianceNormLayer(nn.Module): Typically used as a very first layer in a model. Args: - config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig) + config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig) object containing model parameters. """ @@ -388,19 +435,21 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor: class Phi4MultimodalAudioModel(nn.Module): - def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.encoder_embedding = Phi4MMAudioMeanVarianceNormLayer(config) self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) - self.relative_attention_bias_layer = ( - Phi4MultimodalAudioRelativeAttentionBias(config)) - self.encoders = nn.ModuleList([ - Phi4MultimodalAudioConformerEncoderLayer(config) - for _ in range(config.num_blocks) - ]) + self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias( + config + ) + self.encoders = nn.ModuleList( + [ + Phi4MultimodalAudioConformerEncoderLayer(config) + for _ in range(config.num_blocks) + ] + ) def _streaming_mask( self, @@ -413,9 +462,11 @@ def _streaming_mask( # S stores start index. if chunksize is 18, s is [0,18,36,....] chunk_start_idx = np.arange(0, seq_len, chunk_size) - enc_streaming_mask = (adaptive_enc_mask( - seq_len, chunk_start_idx, - left_window=left_chunk).unsqueeze(0).expand([batch_size, -1, -1])) + enc_streaming_mask = ( + adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) return enc_streaming_mask def forward_embeddings( @@ -424,18 +475,18 @@ def forward_embeddings( masks: torch.Tensor, ): """Forwarding the inputs through the top embedding layers""" - seq_len = math.ceil(hidden_states.shape[1] / - self.config.time_reduction) + seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction) if seq_len <= 0: raise ValueError( f"Sequence length after time reduction is invalid: {seq_len}." - "Your input feature is too short.") + "Your input feature is too short." 
+ ) batch_size = hidden_states.shape[0] - enc_streaming_mask = self._streaming_mask(seq_len, batch_size, - self.config.chunk_size, - self.config.left_chunk) + enc_streaming_mask = self._streaming_mask( + seq_len, batch_size, self.config.chunk_size, self.config.left_chunk + ) enc_streaming_mask = enc_streaming_mask.to(hidden_states.device) hidden_states, masks = self.embed(hidden_states, masks) @@ -450,13 +501,14 @@ def forward_embeddings( return hidden_states, hs_mask, masks - def calculate_hs_mask(self, hidden_states: torch.Tensor, - device: torch.device, mask: torch.Tensor): + def calculate_hs_mask( + self, hidden_states: torch.Tensor, device: torch.device, mask: torch.Tensor + ): max_audio_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] - enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, - self.config.chunk_size, - self.config.left_chunk) + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk + ) enc_streaming_mask = enc_streaming_mask.to(device) if mask is None: return enc_streaming_mask @@ -464,17 +516,15 @@ def calculate_hs_mask(self, hidden_states: torch.Tensor, feature_lens = mask.sum(1) padding_length = feature_lens pad_mask = torch.arange(0, max_audio_length, device=device).expand( - padding_length.size(0), -1) < padding_length.unsqueeze(1) + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask - def forward(self, - hidden_states: torch.Tensor, - mask: Optional[torch.Tensor] = None): + def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor] = None): hidden_states = self.encoder_embedding(hidden_states) - hidden_states, hs_mask, mask = self.forward_embeddings( - hidden_states, mask) + hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) unfolded = False bs, seq_len, _ = hidden_states.shape @@ -490,9 +540,9 @@ def forward(self, else: chunk_pad_size = 0 if chunk_pad_size > 0: - hidden_states_pad = F.pad(hidden_states, - (0, 0, 0, chunk_pad_size), - "constant", 0) + hidden_states_pad = F.pad( + hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0 + ) hidden_states = hidden_states_pad.to(hidden_states.device) hidden_states = unfold_tensor(hidden_states, max_seq_len) @@ -500,24 +550,24 @@ def forward(self, if mask is not None: # revise hs_mask here because the previous calculated hs_mask # did not consider extra pad - subsampled_pad_mask = mask.squeeze( - 1) # [bz, subsampled_unmask_seq_len] + subsampled_pad_mask = mask.squeeze(1) # [bz, subsampled_unmask_seq_len] extra_padded_subsamlped_pad_mask = F.pad( - subsampled_pad_mask, (0, chunk_pad_size), "constant", - False) # extra padding to the pad mask + subsampled_pad_mask, (0, chunk_pad_size), "constant", False + ) # extra padding to the pad mask extra_padded_subsamlped_pad_mask = ( - extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()) + extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + ) masks_unfold = unfold_tensor( extra_padded_subsamlped_pad_mask, max_seq_len ) # unfold the pad mask like we did to the input tensor masks_unfold = masks_unfold.squeeze( - -1).bool() # unfold op does not support bool tensor + -1 + ).bool() # unfold op does not support bool tensor hs_mask = self.calculate_hs_mask( hidden_states, hidden_states.device, masks_unfold ) # calculate hs_mask based on the unfolded pad mask - relative_attention_bias = self.relative_attention_bias_layer( - 
hidden_states) + relative_attention_bias = self.relative_attention_bias_layer(hidden_states) attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias for layer in self.encoders: @@ -534,7 +584,6 @@ def forward(self, class Phi4MMAudioEmbedding(nn.Module): - def __init__(self, config: Phi4MultimodalConfig): super().__init__() self.config = config @@ -543,12 +592,11 @@ def __init__(self, config: Phi4MultimodalConfig): self.encoder = Phi4MultimodalAudioModel(config.audio_config) audio_config = config.audio_config - proj_input_size = (audio_config.hidden_size * - audio_config.downsample_rate) + proj_input_size = audio_config.hidden_size * audio_config.downsample_rate self.vision_speech_projection = Phi4MMProjector( - proj_input_size, config.hidden_size) - self.speech_projection = Phi4MMProjector(proj_input_size, - config.hidden_size) + proj_input_size, config.hidden_size + ) + self.speech_projection = Phi4MMProjector(proj_input_size, config.hidden_size) def get_projection( self, @@ -566,23 +614,23 @@ def forward( audio_attention_mask=None, audio_projection_mode="speech", ) -> torch.FloatTensor: - audio_projection = self.get_projection(audio_projection_mode) target_device = audio_projection.up.bias.device target_dtype = audio_projection.up.bias.dtype - audio_input_features = audio_input_features.to(device=target_device, - dtype=target_dtype) + audio_input_features = audio_input_features.to( + device=target_device, dtype=target_dtype + ) - audio_encoder_hidden_states = self.encoder(audio_input_features, - audio_attention_mask) + audio_encoder_hidden_states = self.encoder( + audio_input_features, audio_attention_mask + ) audio_embeds = audio_projection(audio_encoder_hidden_states) return audio_embeds.flatten(0, 1) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -603,8 +651,7 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -627,8 +674,9 @@ class Phi4MMImagePixelInputs(TensorSchema): data: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} - ), # may be different per batch and image + TensorShape( + "bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image ] image_sizes: Annotated[ @@ -705,9 +753,9 @@ def cat_with_pad(tensors, dim, padding_value=0): cat along dim, while pad to max for all other dims """ ndim = tensors[0].dim() - assert all( - t.dim() == ndim for t in - tensors[1:]), "All tensors must have the same number of dimensions" + assert all(t.dim() == ndim for t in tensors[1:]), ( + "All tensors must have the same number of dimensions" + ) out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] out_size[dim] = sum(t.shape[dim] for t in tensors) @@ -727,15 +775,13 @@ def cat_with_pad(tensors, dim, padding_value=0): class Phi4MMProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> Phi4MultimodalConfig: return self.ctx.get_hf_config(Phi4MultimodalConfig) def get_hf_processor(self, **kwargs: object) -> Phi4MMProcessor: return self.ctx.get_hf_processor(Phi4MMProcessor, **kwargs) - def 
get_feature_extractor( - self, **kwargs: object) -> Phi4MultimodalFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> Phi4MultimodalFeatureExtractor: return self.get_hf_processor(**kwargs).audio_processor def get_image_processor( @@ -769,9 +815,12 @@ def _find_target_aspect_ratio( aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio - target_ratios = set((i, j) for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num) + target_ratios = set( + (i, j) + for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num + ) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target @@ -804,49 +853,56 @@ def _compute_num_image_tokens( ): """ compute the number of tokens an image is expected to take up considering - the image encoder architecture and exclude output features containing + the image encoder architecture and exclude output features containing only padding pixels - for siglip, vit_image_size=448, vit_patch_size=14, so output will be + for siglip, vit_image_size=448, vit_patch_size=14, so output will be 32x32 feature map NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 """ assert vit_image_size % vit_patch_size == 0, ( - "vit_image_size must be divisible by vit_patch_size") - assert (vit_image_size // vit_patch_size % - token_compression_factor == 0), ( - "vit_image_size // vit_patch_size must be divisible by " - "token_compression_factor") + "vit_image_size must be divisible by vit_patch_size" + ) + assert vit_image_size // vit_patch_size % token_compression_factor == 0, ( + "vit_image_size // vit_patch_size must be divisible by " + "token_compression_factor" + ) target_aspect_ratio, target_height, target_width = ( - self._find_target_aspect_ratio(orig_width, - orig_height, - vit_image_size, - dynamic_hd_size, - min_num=1)) + self._find_target_aspect_ratio( + orig_width, orig_height, vit_image_size, dynamic_hd_size, min_num=1 + ) + ) assert target_aspect_ratio[0] * vit_image_size == target_width, ( - f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}") + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" + ) assert target_aspect_ratio[1] * vit_image_size == target_height, ( - f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}") - assert (target_height % vit_image_size == 0 - and target_width % vit_image_size == 0) + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" + ) + assert ( + target_height % vit_image_size == 0 and target_width % vit_image_size == 0 + ) padding_height, padding_width = _get_padding_size( - orig_width, orig_height, target_height, target_width) - assert padding_width == 0 or padding_height == 0, \ + orig_width, orig_height, target_height, target_width + ) + assert padding_width == 0 or padding_height == 0, ( "padding_width or padding_height must be 0" + ) target_feat_width = target_width // vit_patch_size target_feat_height = target_height // vit_patch_size if padding_width >= vit_patch_size: assert padding_height == 0, "padding_height not 0" non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size) + padding_width / vit_patch_size + ) non_pad_feat_height = target_feat_height elif padding_height >= vit_patch_size: assert padding_width == 0, "padding_width not 0" non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size) + padding_height / 
vit_patch_size + ) non_pad_feat_width = target_feat_width else: # small padding shorter than a vit patch @@ -863,15 +919,17 @@ def _compute_num_image_tokens( num_hd_patch_tokens = feat_width * feat_height num_hd_newline_tokens = feat_height vit_feature_size = vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // - token_compression_factor)**2 + num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2 num_sep_tokens = 1 - num_global_image_newline_tokens = \ - vit_feature_size // token_compression_factor - - return (num_global_image_tokens + num_sep_tokens + - num_hd_patch_tokens + num_hd_newline_tokens + - num_global_image_newline_tokens) + num_global_image_newline_tokens = vit_feature_size // token_compression_factor + + return ( + num_global_image_tokens + + num_sep_tokens + + num_hd_patch_tokens + + num_hd_newline_tokens + + num_global_image_newline_tokens + ) def get_num_image_tokens( self, @@ -966,7 +1024,6 @@ def _compute_audio_embed_size(self, audio_frames: int) -> int: class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -986,29 +1043,29 @@ def get_dummy_mm_data( num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None audio_overrides = mm_options.get("audio") if mm_options else None mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "audio": - self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, - num_audios=num_audios, - overrides=audio_overrides), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "audio": self._get_dummy_audios( + length=_AUDIO_MAX_SOUNDFILE_SIZE, + num_audios=num_audios, + overrides=audio_overrides, + ), } return mm_data class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -1027,29 +1084,29 @@ def _call_hf_processor( audio_data = mm_data.pop("audios", []) if audio_data: - mm_data['audio'] = audio_data + mm_data["audio"] = audio_data - processed_outputs = super()._call_hf_processor(prompt, mm_data, - mm_kwargs, tok_kwargs) + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs + ) if "image_pixel_values" in processed_outputs: num_img_tokens = [ - self.info.get_num_image_tokens(image_width=img_size[0], - image_height=img_size[1]) + self.info.get_num_image_tokens( + image_width=img_size[0], image_height=img_size[1] + ) for img_size in processed_outputs["image_sizes"] ] processed_outputs["num_img_tokens"] = num_img_tokens if audio_data: - audio_features = processed_outputs['audio_input_features'] + audio_features = processed_outputs["audio_input_features"] sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate feature_sizes = [ - self.info.get_audio_num_frames(len(audio), sr) - for audio in audio_data + self.info.get_audio_num_frames(len(audio), sr) for audio in 
audio_data ] - processed_outputs['audio_input_features'] = [ - audio_features[idx, :size] - for idx, size in enumerate(feature_sizes) + processed_outputs["audio_input_features"] = [ + audio_features[idx, :size] for idx, size in enumerate(feature_sizes) ] return processed_outputs @@ -1078,12 +1135,12 @@ def _get_prompt_updates( audio_token_id: int = tokenizer.vocab[tokenizer.audio_token] hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - audio_processor = self.info.get_feature_extractor( - **hf_processor_mm_kwargs) + audio_processor = self.info.get_feature_extractor(**hf_processor_mm_kwargs) def get_image_replacement_phi4mm(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -1102,9 +1159,9 @@ def get_audio_replacement_phi4mm(item_idx: int): # TODO(Isotr0py): support embedding inputs audio_len = audios.get_audio_length(item_idx) audio_frames = self.info.get_audio_num_frames( - audio_len, audio_processor.sampling_rate) - audio_embed_size = self.info._compute_audio_embed_size( - audio_frames) + audio_len, audio_processor.sampling_rate + ) + audio_embed_size = self.info._compute_audio_embed_size(audio_frames) return [audio_token_id] * audio_embed_size @@ -1131,6 +1188,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ Implements the Phi-4-multimodal-instruct model in vLLM. """ + packed_modules_mapping = { "qkv_proj": [ "qkv_proj", @@ -1190,12 +1248,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: + self, **kwargs: object + ) -> Optional[Phi4MMAudioInputs]: """ - Parse and validate the audio input to the model. This handles both + Parse and validate the audio input to the model. This handles both audio features and audio embeddings, but only the former is used for now. @@ -1212,17 +1272,18 @@ def _parse_and_validate_audio_input( return None if audio_features is not None: - return Phi4MMAudioFeatureInputs(type="audio_features", - data=flatten_bn(audio_features)) + return Phi4MMAudioFeatureInputs( + type="audio_features", data=flatten_bn(audio_features) + ) if audio_embeds is not None: - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", - data=audio_embeds) + return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) raise AssertionError("This line should be unreachable.") - def _process_audio_input(self, audio_input: Phi4MMAudioInputs, - audio_projection_mode: str) -> NestedTensors: + def _process_audio_input( + self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str + ) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input is pairs of audio features and audio embed lengths. 
The audio input is @@ -1246,12 +1307,14 @@ def _process_audio_input(self, audio_input: Phi4MMAudioInputs, self.audio_embed( features.unsqueeze(0).to(dtype), audio_projection_mode=audio_projection_mode, - ) for features in audio_features + ) + for features in audio_features ] return audio_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]: + self, **kwargs: object + ) -> Optional[Phi4MMImagePixelInputs]: image_pixel_values: NestedTensors = kwargs.get("image_pixel_values") if image_pixel_values is None: return None @@ -1259,12 +1322,16 @@ def _parse_and_validate_image_input( image_sizes = kwargs.get("image_sizes") image_attention_mask = kwargs.get("image_attention_mask") num_img_tokens = kwargs.get("num_img_tokens") - assert image_sizes is not None and image_attention_mask is not None\ - and num_img_tokens is not None, "Missing image inputs" + assert ( + image_sizes is not None + and image_attention_mask is not None + and num_img_tokens is not None + ), "Missing image inputs" if is_list_of(image_pixel_values, torch.Tensor): - assert all(p.dim() == 5 - for p in image_pixel_values), "Incorrect image inputs" + assert all(p.dim() == 5 for p in image_pixel_values), ( + "Incorrect image inputs" + ) # list len is batch_size. # each tensor has dimension: num_img_per_example, num_hd_patches, # channels, height, width. @@ -1297,8 +1364,7 @@ def _parse_and_validate_image_input( if isinstance(num_img_tokens, list): num_img_tokens = [ - n for num_tensor in num_img_tokens - for n in num_tensor.tolist() + n for num_tensor in num_img_tokens for n in num_tensor.tolist() ] elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() @@ -1319,33 +1385,35 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("image_pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("audio_input_features", - "audio_embeds") and "audios" not in modalities: - modalities["audios"] = self._parse_and_validate_audio_input( - **kwargs) + if ( + input_key in ("image_pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("audio_input_features", "audio_embeds") + and "audios" not in modalities + ): + modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) return modalities def _process_image_input( - self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]: + self, image_input: Phi4MMImagePixelInputs + ) -> list[torch.Tensor]: if image_input["type"] == "image_embeds": image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: dtype = next(self.image_embed.parameters()).dtype - pixel_values = image_input['data'].to(dtype) - image_sizes = image_input['image_sizes'] - image_attention_mask = image_input['image_attention_mask'] - image_embeds = self.image_embed(pixel_values, image_sizes, - image_attention_mask) + pixel_values = image_input["data"].to(dtype) + image_sizes = image_input["image_sizes"] + image_attention_mask = image_input["image_attention_mask"] + image_embeds = self.image_embed( + pixel_values, image_sizes, image_attention_mask + ) return image_embeds - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -1356,7 +1424,7 @@ def get_multimodal_embeddings(self, # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. 
- audio_projection_mode = 'speech' + audio_projection_mode = "speech" for modality in modalities: # make sure process images first if modality == "images": @@ -1367,7 +1435,8 @@ def get_multimodal_embeddings(self, if modality == "audios": audio_input = modalities["audios"] audio_embeddings = self._process_audio_input( - audio_input, audio_projection_mode=audio_projection_mode) + audio_input, audio_projection_mode=audio_projection_mode + ) multimodal_embeddings += tuple(audio_embeddings) return multimodal_embeddings @@ -1398,8 +1467,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) @@ -1410,8 +1478,9 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="language_model.", connector=[ - "img_projection", "vision_speech_projection", - "speech_projection" + "img_projection", + "vision_speech_projection", + "speech_projection", ], tower_model=["image_embed", "audio_embed"], ) diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index e3529dc393cf..981f9b37846f 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -7,8 +7,13 @@ import numpy as np import torch import torch.nn as nn -from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, - SequenceFeatureExtractor, SiglipVisionConfig) +from transformers import ( + BatchFeature, + PretrainedConfig, + ProcessorMixin, + SequenceFeatureExtractor, + SiglipVisionConfig, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -16,18 +21,33 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, +) from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, - ImageProcessorItems, ImageSize, - MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, ResolvedPromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + ResolvedPromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -47,16 +67,17 @@ SIGLIP_NAME = "siglip-so400m-patch14-448" VISION_ENCODER_TO_PROCESSING_CONFIG = { - 'siglip-so400m-patch14-448': { - 
'vit_image_size': 448, - 'vit_patch_size': 14, - 'token_compression_factor': 2, + "siglip-so400m-patch14-448": { + "vit_image_size": 448, + "vit_patch_size": 14, + "token_compression_factor": 2, }, } -def _get_padding_size(orig_width: int, orig_height: int, target_height: int, - target_width: int): +def _get_padding_size( + orig_width: int, orig_height: int, target_height: int, target_width: int +): ratio_width = target_width / orig_width ratio_height = target_height / orig_height @@ -82,8 +103,7 @@ def get_navit_vision_model(layer_idx: int = -1, **kwargs): model_config = SiglipVisionConfig(**vision_config, **kwargs) if layer_idx < 0: - num_hidden_layers = model_config.num_hidden_layers \ - + layer_idx + 1 + num_hidden_layers = model_config.num_hidden_layers + layer_idx + 1 else: num_hidden_layers = layer_idx + 1 @@ -99,38 +119,38 @@ def get_navit_vision_model(layer_idx: int = -1, **kwargs): class Phi4MMImageEncoder(nn.Module): """Image embedding.""" - def __init__(self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig], - prefix: str = "", - model_dir: str = "") -> None: + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + model_dir: str = "", + ) -> None: super().__init__() # n_embed or hidden_size - hidden_size = config.n_embd if hasattr( - config, 'n_embd') else config.hidden_size + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size # layer_idx to output the img features if isinstance(config.img_processor, dict): - self.layer_idx = config.img_processor.get('layer_idx', -2) - self.type_feature = config.img_processor.get( - 'type_feature', 'patch') + self.layer_idx = config.img_processor.get("layer_idx", -2) + self.type_feature = config.img_processor.get("type_feature", "patch") else: self.layer_idx = -2 - self.type_feature = 'patch' + self.type_feature = "patch" self.img_processor = get_navit_vision_model(layer_idx=self.layer_idx) pe_weight = self.img_processor.embeddings.position_embedding.weight L, D = pe_weight.size() H = int(math.sqrt(L)) - assert H**2 == L, f'position embedding size {L} is not square' + assert H**2 == L, f"position embedding size {L} is not square" if H % 2 != 0: self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) H += 1 image_dim_out = D # ((448/14)//2)**2 - self.num_img_tokens = (H // 2)**2 + self.num_img_tokens = (H // 2) ** 2 self.base_feat_height_target = H self.image_dim_out = image_dim_out @@ -145,37 +165,35 @@ def __init__(self, self.crop_size = 448 # image token compression - self.image_token_compression_cls = 'avg_pool_2d' + self.image_token_compression_cls = "avg_pool_2d" self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) self.base_feat_height_reduction = 1 self.base_feat_height_target = self.base_feat_height_target // 2 # with_hd_transform and with_learnable_separator should have same value - assert self.use_hd_transform == self.with_learnable_separator, \ - 'use_hd_transform and with_learnable_separator should have same value' - assert self.use_hd_transform, \ - 'learnable separator is only for hd transform' + assert self.use_hd_transform == self.with_learnable_separator, ( + "use_hd_transform and with_learnable_separator should have same value" + ) + assert self.use_hd_transform, "learnable separator is only for hd transform" # 1024 * 4, merge spatial to channel dimension self.glb_GN = nn.Parameter( - torch.zeros([ - 1, 1, self.image_dim_out * self.base_feat_height_reduction**2 - ])) + torch.zeros([1, 1, 
self.image_dim_out * self.base_feat_height_reduction**2]) + ) self.sub_GN = nn.Parameter( - torch.zeros([ - 1, 1, 1, - self.image_dim_out * self.base_feat_height_reduction**2 - ])) + torch.zeros( + [1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2] + ) + ) dim_projection = hidden_size depth = 2 layers = [ - nn.Linear(image_dim_out * self.base_feat_height_reduction**2, - dim_projection) + nn.Linear( + image_dim_out * self.base_feat_height_reduction**2, dim_projection + ) ] for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) self.vocab_size = config.vocab_size @@ -183,24 +201,24 @@ def __init__(self, self.use_out_place_operations = False - def get_img_features(self, - img_embeds: torch.FloatTensor, - attention_mask=None) -> torch.FloatTensor: - - img_feature = self.img_processor(img_embeds, - patch_attention_mask=attention_mask) + def get_img_features( + self, img_embeds: torch.FloatTensor, attention_mask=None + ) -> torch.FloatTensor: + img_feature = self.img_processor( + img_embeds, patch_attention_mask=attention_mask + ) if self.type_feature == "patch": patch_feature = img_feature use_token_compression = self.image_token_compression is not None - use_padding = getattr(self, 'img_processor_padding', - None) is not None + use_padding = getattr(self, "img_processor_padding", None) is not None if use_token_compression or use_padding: # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) - patch_feature = patch_feature.view(-1, width, width, - patch_feature.size(-1)) + patch_feature = patch_feature.view( + -1, width, width, patch_feature.size(-1) + ) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) @@ -214,15 +232,19 @@ def get_img_features(self, patch_feature = patch_feature.view( -1, patch_feature.size(1) * patch_feature.size(2), - patch_feature.size(-1)) + patch_feature.size(-1), + ) return patch_feature raise NotImplementedError - def forward(self, pixel_values: torch.FloatTensor, - image_sizes: torch.Tensor, - image_attention_mask: torch.Tensor) -> list[torch.FloatTensor]: + def forward( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + image_attention_mask: torch.Tensor, + ) -> list[torch.FloatTensor]: """ process image and return vision embeddings. 
@@ -251,25 +273,27 @@ def forward(self, pixel_values: torch.FloatTensor, img_features = self.get_img_features( pixel_values, - image_attention_mask.type(torch.BoolTensor).flatten( - 0, 1).to(target_device)) + image_attention_mask.type(torch.BoolTensor).flatten(0, 1).to(target_device), + ) base_feat_height_target = self.base_feat_height_target base_resolution = self.crop_size base_feat_height_reduction = self.base_feat_height_reduction - base_feat_height = base_feat_width = int(np.sqrt( - img_features.shape[1])) - assert base_feat_height == base_feat_height_target \ - and base_feat_width == base_feat_height_target, \ - (f"base_feat_height: {base_feat_height}, " - f"base_feat_width: {base_feat_width}, " - f"expect {base_feat_height_target} features for hd transform") + base_feat_height = base_feat_width = int(np.sqrt(img_features.shape[1])) + assert ( + base_feat_height == base_feat_height_target + and base_feat_width == base_feat_height_target + ), ( + f"base_feat_height: {base_feat_height}, " + f"base_feat_width: {base_feat_width}, " + f"expect {base_feat_height_target} features for hd transform" + ) # bs x max_num_crops x (24x24) x C - img_features = img_features.view(bs, -1, - base_feat_height * base_feat_width, - self.image_dim_out) + img_features = img_features.view( + bs, -1, base_feat_height * base_feat_width, self.image_dim_out + ) C = self.image_dim_out H = base_feat_height @@ -288,22 +312,32 @@ def forward(self, pixel_values: torch.FloatTensor, global_img_feature = img_features[_bs, :1] # 1 x 12 x 12 x 4096 - glb_img = global_img_feature.reshape(1, H, H, C).reshape( - 1, H // base_feat_height_reduction, base_feat_height_reduction, - H // base_feat_height_reduction, base_feat_height_reduction, - C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( - 1, H // base_feat_height_reduction, + glb_img = ( + global_img_feature.reshape(1, H, H, C) + .reshape( + 1, + H // base_feat_height_reduction, + base_feat_height_reduction, H // base_feat_height_reduction, - base_feat_height_reduction * base_feat_height_reduction * - C).contiguous() - temp_glb_GN = self.sub_GN.repeat(1, - H // base_feat_height_reduction, - 1, 1) + base_feat_height_reduction, + C, + ) + .contiguous() + .permute(0, 1, 3, 2, 4, 5) + .reshape( + 1, + H // base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * C, + ) + .contiguous() + ) + temp_glb_GN = self.sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1) # 1 x 156 x 4096 glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape( - 1, -1, - base_feat_height_reduction * base_feat_height_reduction * C) + 1, -1, base_feat_height_reduction * base_feat_height_reduction * C + ) # (max_num_crops-1) x (12x12) x C sub_img = img_features[_bs, 1:] @@ -313,79 +347,106 @@ def forward(self, pixel_values: torch.FloatTensor, # (num_crops, 12, 2, 12, 2, 1024) -> # (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) - sub_img = sub_img.reshape(B_, H, H, C).reshape( - B_, H // base_feat_height_reduction, - base_feat_height_reduction, H // base_feat_height_reduction, - base_feat_height_reduction, - C).contiguous().permute(0, 1, 3, 2, 4, 5).reshape( - B_, -1, base_feat_height_reduction * - base_feat_height_reduction * C).contiguous() - sub_img = sub_img.reshape( - 1, h, w, base_feat_height // base_feat_height_reduction, - base_feat_width // base_feat_height_reduction, - -1).permute(0, 1, 3, 2, 4, 5).reshape( - 1, h * base_feat_height // base_feat_height_reduction, + sub_img = ( + sub_img.reshape(B_, H, H, C) + 
.reshape( + B_, + H // base_feat_height_reduction, + base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction, + C, + ) + .contiguous() + .permute(0, 1, 3, 2, 4, 5) + .reshape( + B_, -1, base_feat_height_reduction * base_feat_height_reduction * C + ) + .contiguous() + ) + sub_img = ( + sub_img.reshape( + 1, + h, + w, + base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction, + -1, + ) + .permute(0, 1, 3, 2, 4, 5) + .reshape( + 1, + h * base_feat_height // base_feat_height_reduction, w * base_feat_width // base_feat_height_reduction, - base_feat_height_reduction * base_feat_height_reduction * - C) - - if image_attention_mask is not None and len( - image_attention_mask) > 0: - reshaped_image_attention_mask = image_attention_mask[ - _bs, 1:B_ + 1, 0::2, 0::2].reshape( - 1, h, w, + base_feat_height_reduction * base_feat_height_reduction * C, + ) + ) + + if image_attention_mask is not None and len(image_attention_mask) > 0: + reshaped_image_attention_mask = ( + image_attention_mask[_bs, 1 : B_ + 1, 0::2, 0::2] + .reshape( + 1, + h, + w, base_feat_height // base_feat_height_reduction, - base_feat_width // base_feat_height_reduction).permute( - 0, 1, 3, 2, 4).reshape( - 1, h * base_feat_height // - base_feat_height_reduction, w * - base_feat_width // base_feat_height_reduction) - useful_height = int( - reshaped_image_attention_mask[0, :, 0].sum().item()) - useful_width = int( - reshaped_image_attention_mask[0, 0, :].sum().item()) + base_feat_width // base_feat_height_reduction, + ) + .permute(0, 1, 3, 2, 4) + .reshape( + 1, + h * base_feat_height // base_feat_height_reduction, + w * base_feat_width // base_feat_height_reduction, + ) + ) + useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) sub_img = sub_img[:, :useful_height, :useful_width] temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1) - temp_len = int( - image_attention_mask[_bs, :B_ + 1, 0::2, 0::2].sum().item( - )) + (useful_height + - 1) + base_feat_height // base_feat_height_reduction + temp_len = ( + int(image_attention_mask[_bs, : B_ + 1, 0::2, 0::2].sum().item()) + + (useful_height + 1) + + base_feat_height // base_feat_height_reduction + ) else: temp_sub_GN = self.sub_GN.repeat( - 1, h * base_feat_height // base_feat_height_reduction, 1, - 1) - temp_len = int((h * w + 1) * self.num_img_tokens + 1 + - (h + 1) * base_feat_height // - base_feat_height_reduction) + 1, h * base_feat_height // base_feat_height_reduction, 1, 1 + ) + temp_len = int( + (h * w + 1) * self.num_img_tokens + + 1 + + (h + 1) * base_feat_height // base_feat_height_reduction + ) sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape( - 1, -1, - base_feat_height_reduction * base_feat_height_reduction * C) + 1, -1, base_feat_height_reduction * base_feat_height_reduction * C + ) # (1, num_img_tokens, 1024*4) # glb + sub - if self.hd_transform_order == 'glb_sub': - output_imgs.append( - torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) - elif self.hd_transform_order == 'sub_glb': - output_imgs.append( - torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + if self.hd_transform_order == "glb_sub": + output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == "sub_glb": + output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) else: raise NotImplementedError( f'hd_transform_order = {self.hd_transform_order}, "\ - "not implemented') + 
"not implemented' + ) - #temp_len = int((h*w+1)*144 + 1 + (h+1)*12) - assert temp_len == output_imgs[-1].shape[ - 1], f'temp_len: {temp_len}, output_imgs[-1].shape[1]: "\ + # temp_len = int((h*w+1)*144 + 1 + (h+1)*12) + assert temp_len == output_imgs[-1].shape[1], ( + f'temp_len: {temp_len}, output_imgs[-1].shape[1]: "\ "{output_imgs[-1].shape[1]}' + ) output_len.append(temp_len) img_set_tensor = [] for _output_img in output_imgs: img_feature_proj = self.img_projection( - _output_img.to(target_device).to(target_dtype)) + _output_img.to(target_device).to(target_dtype) + ) img_set_tensor.append(img_feature_proj.squeeze(0)) return img_set_tensor @@ -408,8 +469,9 @@ class Phi4MMImagePixelInputs(TensorSchema): data: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} - ), # may be different per batch and image + TensorShape( + "bn", "p", 3, "h", "w", dynamic_dims={"p"} + ), # may be different per batch and image ] image_sizes: Annotated[ @@ -451,6 +513,7 @@ class Phi4MMAudioEmbeddingInputs(TensorSchema): - f: Audio feature size - h: Hidden size (must match language model backbone) """ + type: Literal["audio_embeds"] data: Annotated[ NestedTensors, @@ -466,9 +529,9 @@ def cat_with_pad(tensors, dim, padding_value=0): cat along dim, while pad to max for all other dims """ ndim = tensors[0].dim() - assert all( - t.dim() == ndim for t in - tensors[1:]), "All tensors must have the same number of dimensions" + assert all(t.dim() == ndim for t in tensors[1:]), ( + "All tensors must have the same number of dimensions" + ) out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] out_size[dim] = sum(t.shape[dim] for t in tensors) @@ -488,14 +551,13 @@ def cat_with_pad(tensors, dim, padding_value=0): class Phi4MMProcessingInfo(BaseProcessingInfo): - @property def image_tokens(self) -> list[str]: - return [f"<|image_{i+1}|>" for i in range(100)] + return [f"<|image_{i + 1}|>" for i in range(100)] @property def audio_tokens(self) -> list[str]: - return [f"<|audio_{i+1}|>" for i in range(100)] + return [f"<|audio_{i + 1}|>" for i in range(100)] def get_dynamic_hd( self, @@ -506,8 +568,7 @@ def get_dynamic_hd( image_processor = processor.image_processor return image_processor.dynamic_hd - def get_feature_extractor(self, - **kwargs: object) -> SequenceFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> SequenceFeatureExtractor: return self.get_hf_processor(**kwargs).audio_processor def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: @@ -527,9 +588,12 @@ def _find_target_aspect_ratio( aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio - target_ratios = set((i, j) for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num) + target_ratios = set( + (i, j) + for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num and i * j >= min_num + ) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target @@ -562,49 +626,56 @@ def _compute_num_image_tokens( ): """ compute the number of tokens an image is expected to take up considering - the image encoder architecture and exclude output features containing + the image encoder architecture and exclude output features containing only padding pixels - for siglip, vit_image_size=448, vit_patch_size=14, so output will be + for siglip, vit_image_size=448, vit_patch_size=14, so output will be 32x32 feature map NOTE right 
now, Phi4MM uses hard-coded token_compression_factor=2 """ assert vit_image_size % vit_patch_size == 0, ( - "vit_image_size must be divisible by vit_patch_size") - assert (vit_image_size // vit_patch_size % - token_compression_factor == 0), ( - "vit_image_size // vit_patch_size must be divisible by " - "token_compression_factor") + "vit_image_size must be divisible by vit_patch_size" + ) + assert vit_image_size // vit_patch_size % token_compression_factor == 0, ( + "vit_image_size // vit_patch_size must be divisible by " + "token_compression_factor" + ) target_aspect_ratio, target_height, target_width = ( - self._find_target_aspect_ratio(orig_width, - orig_height, - vit_image_size, - dynamic_hd_size, - min_num=1)) + self._find_target_aspect_ratio( + orig_width, orig_height, vit_image_size, dynamic_hd_size, min_num=1 + ) + ) assert target_aspect_ratio[0] * vit_image_size == target_width, ( - f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}") + f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" + ) assert target_aspect_ratio[1] * vit_image_size == target_height, ( - f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}") - assert (target_height % vit_image_size == 0 - and target_width % vit_image_size == 0) + f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" + ) + assert ( + target_height % vit_image_size == 0 and target_width % vit_image_size == 0 + ) padding_height, padding_width = _get_padding_size( - orig_width, orig_height, target_height, target_width) - assert padding_width == 0 or padding_height == 0, \ + orig_width, orig_height, target_height, target_width + ) + assert padding_width == 0 or padding_height == 0, ( "padding_width or padding_height must be 0" + ) target_feat_width = target_width // vit_patch_size target_feat_height = target_height // vit_patch_size if padding_width >= vit_patch_size: assert padding_height == 0, "padding_height not 0" non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size) + padding_width / vit_patch_size + ) non_pad_feat_height = target_feat_height elif padding_height >= vit_patch_size: assert padding_width == 0, "padding_width not 0" non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size) + padding_height / vit_patch_size + ) non_pad_feat_width = target_feat_width else: # small padding shorter than a vit patch @@ -621,15 +692,17 @@ def _compute_num_image_tokens( num_hd_patch_tokens = feat_width * feat_height num_hd_newline_tokens = feat_height vit_feature_size = vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // - token_compression_factor)**2 + num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2 num_sep_tokens = 1 - num_global_image_newline_tokens = \ - vit_feature_size // token_compression_factor - - return (num_global_image_tokens + num_sep_tokens + - num_hd_patch_tokens + num_hd_newline_tokens + - num_global_image_newline_tokens) + num_global_image_newline_tokens = vit_feature_size // token_compression_factor + + return ( + num_global_image_tokens + + num_sep_tokens + + num_hd_patch_tokens + + num_hd_newline_tokens + + num_global_image_newline_tokens + ) def get_num_image_tokens( self, @@ -642,11 +715,10 @@ def get_num_image_tokens( vision_encoder_name = hf_config.img_processor if vision_encoder_name is None: vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ - vision_encoder_name] - vit_image_size = prepro_config['vit_image_size'] - 
vit_patch_size = prepro_config['vit_patch_size'] - token_compression_factor = prepro_config['token_compression_factor'] + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + vit_image_size = prepro_config["vit_image_size"] + vit_patch_size = prepro_config["vit_patch_size"] + token_compression_factor = prepro_config["token_compression_factor"] dynamic_hd_size = self.get_dynamic_hd(processor=processor) @@ -669,9 +741,8 @@ def get_image_size_with_most_features( vision_encoder_name = hf_config.img_processor if vision_encoder_name is None: vision_encoder_name = SIGLIP_NAME - prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[ - vision_encoder_name] - vit_image_size = prepro_config['vit_image_size'] + prepro_config = VISION_ENCODER_TO_PROCESSING_CONFIG[vision_encoder_name] + vit_image_size = prepro_config["vit_image_size"] max_side = vit_image_size * self.get_dynamic_hd(processor=processor) return ImageSize(height=max_side, width=vit_image_size) @@ -717,8 +788,7 @@ def _compute_audio_embed_size(self, audio_frames: int) -> int: compression rate. """ hf_config = self.get_hf_config() - compression_rate = hf_config.embd_layer['audio_embd_layer'][ - 'compression_rate'] + compression_rate = hf_config.embd_layer["audio_embd_layer"]["compression_rate"] # NOTE: this is a hard-coded value but might be configurable # in the future qformer_compression_rate = 1 @@ -736,7 +806,6 @@ def _compute_audio_embed_size(self, audio_frames: int) -> int: class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -755,33 +824,34 @@ def get_dummy_mm_data( num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None audio_overrides = mm_options.get("audio") if mm_options else None mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "audio": - self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, - num_audios=num_audios, - overrides=audio_overrides), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "audio": self._get_dummy_audios( + length=_AUDIO_MAX_SOUNDFILE_SIZE, + num_audios=num_audios, + overrides=audio_overrides, + ), } return mm_data class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() - return MultiModalDataParser(target_sr=feature_extractor.sampling_rate, - audio_resample_method="scipy") + return MultiModalDataParser( + target_sr=feature_extractor.sampling_rate, audio_resample_method="scipy" + ) def _call_hf_processor( self, @@ -796,27 +866,27 @@ def _call_hf_processor( return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate - if (audio_data := mm_data.get("audios", [])): - mm_data['audios'] = [(data, sr) for data in audio_data] + if audio_data := mm_data.get("audios", []): + mm_data["audios"] = [(data, sr) for data in audio_data] - processed_outputs = super()._call_hf_processor(prompt, mm_data, 
- mm_kwargs, tok_kwargs) + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs + ) num_img_tokens = [ - self.info.get_num_image_tokens(image_width=img_size[0], - image_height=img_size[1]) + self.info.get_num_image_tokens( + image_width=img_size[0], image_height=img_size[1] + ) for img_size in processed_outputs["image_sizes"] ] processed_outputs["num_img_tokens"] = num_img_tokens - audio_features = processed_outputs['input_audio_embeds'] + audio_features = processed_outputs["input_audio_embeds"] feature_sizes = [ - self.info.get_audio_num_frames(len(audio), sr) - for audio in audio_data + self.info.get_audio_num_frames(len(audio), sr) for audio in audio_data ] - processed_outputs['input_audio_embeds'] = [ - audio_features[idx, :size] - for idx, size in enumerate(feature_sizes) + processed_outputs["input_audio_embeds"] = [ + audio_features[idx, :size] for idx, size in enumerate(feature_sizes) ] return processed_outputs @@ -842,13 +912,13 @@ def _get_prompt_updates( ) -> Sequence[PromptUpdate]: image_tokens: list[str] = self.info.image_tokens # type: ignore audio_tokens: list[str] = self.info.audio_tokens # type: ignore - feature_extractor = self.info.get_feature_extractor( - **hf_processor_mm_kwargs) + feature_extractor = self.info.get_feature_extractor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) def get_image_replacement_phi4mm(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_image_tokens = images.get_feature_size(item_idx) @@ -867,9 +937,9 @@ def get_audio_replacement_phi4mm(item_idx: int): # TODO(Isotr0py): support embedding inputs audio_len = audios.get_audio_length(item_idx) audio_frames = self.info.get_audio_num_frames( - audio_len, feature_extractor.sampling_rate) - audio_embed_size = self.info._compute_audio_embed_size( - audio_frames) + audio_len, feature_extractor.sampling_rate + ) + audio_embed_size = self.info._compute_audio_embed_size(audio_frames) return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size @@ -915,6 +985,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): """ Implements the Phi-4-multimodal-instruct model in vLLM. """ + packed_modules_mapping = { "qkv_proj": [ "qkv_proj", @@ -929,10 +1000,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): "base_layer.": "", }, orig_to_new_prefix={ - "model.embed_tokens_extend.audio_embed.audio_projection.vision.": - "embed_tokens_extend.audio_projection_for_vision.", - "model.embed_tokens_extend.audio_embed.audio_projection.speech.": - "embed_tokens_extend.audio_projection.", + "model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.", # noqa: E501 + "model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.", # noqa: E501 "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.", "model.embed_tokens_extend.image_embed.": "vision_encoder.", }, @@ -961,19 +1030,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config # Tensor/Pipeline parallel not supported for now. 
- assert get_pp_group( - ).world_size == 1, "pipeline parallel is not supported" + assert get_pp_group().world_size == 1, "pipeline parallel is not supported" self.vision_encoder = Phi4MMImageEncoder( config, quant_config, prefix="model.vision_embed_tokens", - model_dir=config._name_or_path) + model_dir=config._name_or_path, + ) if isinstance(config.embd_layer["audio_embd_layer"], dict): embedding_config = { - "embedding_cls": - config.embd_layer["audio_embd_layer"]["embedding_cls"], + "embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"], **config.embd_layer["audio_embd_layer"], } else: @@ -982,8 +1050,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): } self.embed_tokens_extend = AudioEmbedding(config, **embedding_config) - self.model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = LlamaModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -999,13 +1068,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Phi4MMAudioInputs]: + self, **kwargs: object + ) -> Optional[Phi4MMAudioInputs]: """ - Parse and validate the audio input to the model. This handles both + Parse and validate the audio input to the model. This handles both audio features and audio embeddings, but only the former is used for now. @@ -1022,17 +1093,18 @@ def _parse_and_validate_audio_input( return None if audio_features is not None: - return Phi4MMAudioFeatureInputs(type="audio_features", - data=flatten_bn(audio_features)) + return Phi4MMAudioFeatureInputs( + type="audio_features", data=flatten_bn(audio_features) + ) if audio_embeds is not None: - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", - data=audio_embeds) + return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) raise AssertionError("This line should be unreachable.") - def _process_audio_input(self, audio_input: Phi4MMAudioInputs, - audio_projection_mode: str) -> NestedTensors: + def _process_audio_input( + self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str + ) -> NestedTensors: """ Create the audio embeddings from the audio input, where the audio input is pairs of audio features and audio embed lengths. 
The audio input is @@ -1056,12 +1128,14 @@ def _process_audio_input(self, audio_input: Phi4MMAudioInputs, self.embed_tokens_extend( features.to(dtype), audio_projection_mode=audio_projection_mode, - ) for features in audio_features + ) + for features in audio_features ] return audio_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]: + self, **kwargs: object + ) -> Optional[Phi4MMImagePixelInputs]: input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") if input_image_embeds is None: return None @@ -1069,12 +1143,16 @@ def _parse_and_validate_image_input( image_sizes = kwargs.get("image_sizes") image_attention_mask = kwargs.get("image_attention_mask") num_img_tokens = kwargs.get("num_img_tokens") - assert image_sizes is not None and image_attention_mask is not None\ - and num_img_tokens is not None, "Missing image inputs" + assert ( + image_sizes is not None + and image_attention_mask is not None + and num_img_tokens is not None + ), "Missing image inputs" if is_list_of(input_image_embeds, torch.Tensor): - assert all(p.dim() == 5 - for p in input_image_embeds), "Incorrect image inputs" + assert all(p.dim() == 5 for p in input_image_embeds), ( + "Incorrect image inputs" + ) # list len is batch_size. # each tensor has dimension: num_img_per_example, num_hd_patches, # channels, height, width. @@ -1107,8 +1185,7 @@ def _parse_and_validate_image_input( if isinstance(num_img_tokens, list): num_img_tokens = [ - n for num_tensor in num_img_tokens - for n in num_tensor.tolist() + n for num_tensor in num_img_tokens for n in num_tensor.tolist() ] elif isinstance(num_img_tokens, torch.Tensor): num_img_tokens = num_img_tokens.flatten(0, 1).tolist() @@ -1129,31 +1206,32 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("input_image_embeds", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("input_audio_embeds", - "audio_embeds") and "audios" not in modalities: - modalities["audios"] = self._parse_and_validate_audio_input( - **kwargs) + if ( + input_key in ("input_image_embeds", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("input_audio_embeds", "audio_embeds") + and "audios" not in modalities + ): + modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) return modalities def _process_image_input( - self, image_input: Phi4MMImagePixelInputs) -> list[torch.Tensor]: - + self, image_input: Phi4MMImagePixelInputs + ) -> list[torch.Tensor]: dtype = next(self.vision_encoder.parameters()).dtype - pixel_values = image_input['data'].to(dtype) - image_sizes = image_input['image_sizes'] - image_attention_mask = image_input['image_attention_mask'] - image_embeds = self.vision_encoder(pixel_values, image_sizes, - image_attention_mask) + pixel_values = image_input["data"].to(dtype) + image_sizes = image_input["image_sizes"] + image_attention_mask = image_input["image_attention_mask"] + image_embeds = self.vision_encoder( + pixel_values, image_sizes, image_attention_mask + ) return image_embeds - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -1164,7 +1242,7 @@ def get_multimodal_embeddings(self, # NOTE: It is important to iterate over the keys in this dictionary # to preserve the order of the modalities. 
- audio_projection_mode = 'speech' + audio_projection_mode = "speech" for modality in modalities: # make sure process images first if modality == "images": @@ -1175,7 +1253,8 @@ def get_multimodal_embeddings(self, if modality == "audios": audio_input = modalities["audios"] audio_embeddings = self._process_audio_input( - audio_input, audio_projection_mode=audio_projection_mode) + audio_input, audio_projection_mode=audio_projection_mode + ) multimodal_embeddings += tuple(audio_embeddings) return multimodal_embeddings @@ -1207,8 +1286,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> None: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: loader = AutoWeightsLoader(self, skip_substrs=["lora"]) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py index a1c452053ddd..d289e26efa10 100644 --- a/vllm/model_executor/models/phi4mm_audio.py +++ b/vllm/model_executor/models/phi4mm_audio.py @@ -14,15 +14,24 @@ import torch.nn.functional as F from torch import Tensor, nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - CheckpointWrapper) -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - FullyShardedDataParallel) + CheckpointWrapper, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from transformers import PretrainedConfig from vllm.model_executor.models.phi4mm_utils import ( - AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer, - MultiHeadedAttention, MultiSequential, NemoConvSubsampling, - T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor) + AbsolutePositionalEncoding, + ConvModule, + FeedForward, + MeanVarianceNormLayer, + MultiHeadedAttention, + MultiSequential, + NemoConvSubsampling, + T5RelativeAttentionLogitBias, + adaptive_enc_mask, + get_offset, + unfold_tensor, +) _AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|> @@ -40,9 +49,9 @@ class ConformerEncoderLayer(nn.Module): if > 0, ext_pw_out_channel is a dim channel size for the last pointwise conv after swish activation. depthwise_seperable_out_channel: int - if set different to 0, the number of + if set different to 0, the number of depthwise_seperable_out_channel will be used as a - channel_out of the second conv1d layer. + channel_out of the second conv1d layer. otherwise, it equals to 0, the second conv1d layer is skipped. depthwise_multiplier: int number of input_dim channels duplication. this value @@ -119,10 +128,10 @@ class ConformerEncoderLayer(nn.Module): and allow the onnx conversion for inference. default False. use_pt_scaled_dot_product_attention: bool, optional - if set to True, use pytorch's scaled dot product attention + if set to True, use pytorch's scaled dot product attention implementation in training. 
attn_group_sizes: int, optional - the number of groups to use for attention, default 1 + the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attn_group_sizes < attention_heads = Grouped-Query Attention @@ -173,8 +182,7 @@ def __init__( attention_inner_dim, attention_glu_type, bias_in_glu, - use_pt_scaled_dot_product_attention= - use_pt_scaled_dot_product_attention, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, group_size=attn_group_sizes, ) self.conv = ConvModule( @@ -296,7 +304,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module): (Q*K^T + B) implemented in cmb.basics.embedding. [T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see + additional method-specific arguments can be provided (see transformer_base.py) positional_dropout_rate: float, optional dropout rate after positional encoding. default 0.0 @@ -310,10 +318,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module): supraframe utts in batch. Default: none attention_group_size: int, optional - the number of groups to use for attention, default 1 + the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, - 1 < attention_group_size < attention_heads = Grouped-Query + 1 < attention_group_size < attention_heads = Grouped-Query Attention attention_group_size = attention_heads = Multi-Query Attention """ @@ -334,8 +342,7 @@ def __init__( relative_attention_bias_args: Optional[dict[str, Any]] = None, positional_dropout_rate: float = 0.0, nemo_conv_settings: Optional[dict[str, Any]] = None, - conv2d_extra_padding: Literal["feat", "feat_time", "none", - True] = "none", + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", attention_group_size: int = 1, encoder_embedding_config: Optional[dict[str, Any]] = None, ) -> None: @@ -366,70 +373,77 @@ def __init__( if nemo_conv_settings: default_nemo_conv_settings.update(nemo_conv_settings) for i in ["subsampling_factor", "feat_in", "feat_out"]: - assert ( - i not in nemo_conv_settings - ), "{i} should be specified outside of the NeMo dictionary" + assert i not in nemo_conv_settings, ( + "{i} should be specified outside of the NeMo dictionary" + ) - self.embed = NemoConvSubsampling(**default_nemo_conv_settings, ) + self.embed = NemoConvSubsampling( + **default_nemo_conv_settings, + ) else: raise ValueError("unknown input_layer: " + input_layer) - self.pos_emb = AbsolutePositionalEncoding(attention_dim, - positional_dropout_rate) + self.pos_emb = AbsolutePositionalEncoding( + attention_dim, positional_dropout_rate + ) self.relative_attention_bias_type = ( relative_attention_bias_args.get("type") - if relative_attention_bias_args else None) + if relative_attention_bias_args + else None + ) if self.relative_attention_bias_type == "t5": - assert (self.num_heads % self.attention_group_size == 0 - ), "attention_group_size must divide n_head" + assert self.num_heads % self.attention_group_size == 0, ( + "attention_group_size must divide n_head" + ) self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( self.num_heads // self.attention_group_size, max_distance=relative_attention_bias_args.get( - "t5_bias_max_distance", 1000), - symmetric=relative_attention_bias_args.get( - "t5_bias_symmetric", False), + "t5_bias_max_distance", 1000 + ), + symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), ) else: raise 
NotImplementedError self.encoder_embedding = MeanVarianceNormLayer( - self.encoder_embedding_config["input_size"]) + self.encoder_embedding_config["input_size"] + ) def compute_lens_change( - self, - feature_lens: Union[int, - torch.Tensor]) -> Union[int, torch.Tensor]: + self, feature_lens: Union[int, torch.Tensor] + ) -> Union[int, torch.Tensor]: """feature_lens: int return updated feature lens. - This used to return a different lambda function for each case that - computed the right thing. That does not work within Torchscript. + This used to return a different lambda function for each case that + computed the right thing. That does not work within Torchscript. If you really need this to be faster, create nn.Module()-s for all the cases and return one of them. Torchscript does support that. """ if self.input_layer == "nemo_conv": # Handle the special causal case subsampling_causal_cond = self.nemo_conv_settings.get( - "subsampling", "dw_striding") in [ - "dw_striding", - "striding", - "striding_conv1d", - ] + "subsampling", "dw_striding" + ) in [ + "dw_striding", + "striding", + "striding_conv1d", + ] is_causal = self.nemo_conv_settings.get("is_causal", False) if is_causal and subsampling_causal_cond: - lens_change = (torch.ceil(feature_lens / - self.time_reduction).long() - if isinstance(feature_lens, Tensor) else - math.ceil(feature_lens / self.time_reduction)) + lens_change = ( + torch.ceil(feature_lens / self.time_reduction).long() + if isinstance(feature_lens, Tensor) + else math.ceil(feature_lens / self.time_reduction) + ) feature_lens_remainder = feature_lens % self.time_reduction if isinstance(feature_lens, Tensor): lens_change[feature_lens_remainder != 1] += 1 elif feature_lens_remainder != 1: lens_change += 1 return lens_change - ceil_func = (math.ceil - if isinstance(feature_lens, int) else torch.ceil) + ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil return ceil_func(feature_lens / self.time_reduction) @abc.abstractmethod @@ -437,10 +451,10 @@ def forward(self) -> Any: """Abstract forward method implementation.""" def _chunk_size_selection( - self, - chunk_size: Optional[Union[int, list[int]]] = None, - left_chunk: Optional[Union[int, - list[int]]] = None) -> tuple[int, int]: + self, + chunk_size: Optional[Union[int, list[int]]] = None, + left_chunk: Optional[Union[int, list[int]]] = None, + ) -> tuple[int, int]: """If chunk size is a list, we will randomly select a chunk size.""" if chunk_size is None: @@ -450,15 +464,16 @@ def _chunk_size_selection( if isinstance(chunk_size, list): # Variable chunk size during training chunk_size_index = int( - torch.randint(low=0, high=len(chunk_size), size=(1, ))) + torch.randint(low=0, high=len(chunk_size), size=(1,)) + ) chunk_size_train_eff = chunk_size[chunk_size_index] if not isinstance(left_chunk, list): raise ValueError( - "Since chunk_size is a list, left_chunk must be a list") + "Since chunk_size is a list, left_chunk must be a list" + ) if len(left_chunk) != len(chunk_size): raise ValueError( - "The length of left_chunk must be the same as length of "\ - "chunk_size." + "The length of left_chunk must be the same as length of chunk_size." 
) left_chunk_train_eff = left_chunk[chunk_size_index] else: @@ -479,8 +494,8 @@ def _get_embed_class(self, embed: nn.Module) -> nn.Module: return embed_class def _forward_embeddings_core( - self, input_tensor: torch.Tensor, - masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, input_tensor: torch.Tensor, masks: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: embed_class = self._get_embed_class(self.embed) assert isinstance(embed_class, NemoConvSubsampling) input_tensor, masks = self.embed(input_tensor, masks) @@ -493,23 +508,32 @@ def _position_embedding( pos_v = None if self.relative_attention_bias_layer is None: input_tensor = self.pos_emb( - input_tensor) # default to add abs sinusoid embedding + input_tensor + ) # default to add abs sinusoid embedding return pos_k, pos_v - def _streaming_mask(self, seq_len: int, batch_size: int, - chunk_size: Union[int, list[int]], - left_chunk: Union[int, list[int]]) -> torch.Tensor: - chunk_size_train_eff, left_chunk_train_eff = \ - self._chunk_size_selection(chunk_size, left_chunk) + def _streaming_mask( + self, + seq_len: int, + batch_size: int, + chunk_size: Union[int, list[int]], + left_chunk: Union[int, list[int]], + ) -> torch.Tensor: + chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection( + chunk_size, left_chunk + ) # Create mask matrix for streaming # S stores start index. if chunksize is 18, s is [0,18,36,....] chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) - enc_streaming_mask = (adaptive_enc_mask( - seq_len, chunk_start_idx, - left_window=left_chunk_train_eff).unsqueeze(0).expand( - [batch_size, -1, -1])) + enc_streaming_mask = ( + adaptive_enc_mask( + seq_len, chunk_start_idx, left_window=left_chunk_train_eff + ) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) return enc_streaming_mask def forward_embeddings( @@ -517,12 +541,24 @@ def forward_embeddings( xs_pad: torch.Tensor, masks: torch.Tensor, chunk_size_nc: Optional[Union[int, list[int]]] = None, - left_chunk_nc: Optional[Union[int, list[int]]] = None - ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, torch.Tensor], - tuple[torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, torch.Tensor, - torch.Tensor]]: + left_chunk_nc: Optional[Union[int, list[int]]] = None, + ) -> Union[ + tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + ], + tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + ]: """Forwarding the inputs through the top embedding layers Args: @@ -530,7 +566,7 @@ def forward_embeddings( input tensor masks: torch.Tensor input mask - chunk_size_nc: (optional, default is None) chunk size for + chunk_size_nc: (optional, default is None) chunk size for non-causal layers left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers @@ -543,21 +579,21 @@ def forward_embeddings( f"""The sequence length after time reduction is invalid: {seq_len}. Your input feature is too short. 
Consider filtering out the very short sentence from data - loader""", ) + loader""", + ) batch_size = xs_pad.shape[0] - enc_streaming_mask = self._streaming_mask(seq_len, batch_size, - self.chunk_size, - self.left_chunk) + enc_streaming_mask = self._streaming_mask( + seq_len, batch_size, self.chunk_size, self.left_chunk + ) if xs_pad.is_cuda: enc_streaming_mask = enc_streaming_mask.cuda() xs_pad = xs_pad.cuda() input_tensor = xs_pad - input_tensor, masks = self._forward_embeddings_core( - input_tensor, masks) + input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) streaming_mask = enc_streaming_mask if streaming_mask is not None and masks is not None: @@ -569,7 +605,8 @@ def forward_embeddings( if chunk_size_nc is not None: enc_streaming_mask_nc = self._streaming_mask( - seq_len, batch_size, chunk_size_nc, left_chunk_nc) + seq_len, batch_size, chunk_size_nc, left_chunk_nc + ) if xs_pad.is_cuda: enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() if masks is not None: @@ -622,8 +659,8 @@ class ConformerEncoder(TransformerEncoderBase): left_chunk = 6 left_chunk = [12, 9, 6, 3] num_lang: int - This parameter is used to store the number of languages in the - lang_dict, only used for multiseed/multilingual models. + This parameter is used to store the number of languages in the + lang_dict, only used for multiseed/multilingual models. default None. attention_dim: int, optional attention dimension. default 256. @@ -721,16 +758,16 @@ class ConformerEncoder(TransformerEncoderBase): extra_layer_output_idx: int the layer index to be exposed. relative_attention_bias_args: dict, optional - use more efficient scalar bias-based relative multihead attention + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) implemented in cmb.basics.embedding. [T5/ALiBi]RelativeAttentionLogitBias usage: relative_attention_bias_args={"type": t5/alibi} - additional method-specific arguments can be provided (see + additional method-specific arguments can be provided (see transformer_base.py) time_reduction: int optional time reduction factor default 4 - use_pt_scaled_dot_product_attention: whether to use pytorch scaled + use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention in training. Default: False nemo_conv_settings: dict, optional @@ -748,12 +785,12 @@ class ConformerEncoder(TransformerEncoderBase): Add extra padding in conv2d subsampling layers. Choices are (feat, feat_time, none, True) Default: none - replication_pad_for_subsample_embedding: For batched-streaming + replication_pad_for_subsample_embedding: For batched-streaming decoding, use "replication" padding for the cache at start of utterance. 
Default: False attention_group_size: int, optional - the number of groups to use for attention, default 1 + the number of groups to use for attention, default 1 (Multi-Head Attention), 1 = typical Multi-Head Attention, 1 < attention_group_size < attention_heads = Grouped-Query @@ -799,8 +836,7 @@ def __init__( # pylint: disable-all time_reduction: int = 4, use_pt_scaled_dot_product_attention: bool = False, nemo_conv_settings: Optional[dict[str, Any]] = None, - conv2d_extra_padding: Literal["feat", "feat_time", "none", - True] = "none", + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", replication_pad_for_subsample_embedding: bool = False, attention_group_size: int = 1, encoder_embedding_config: Optional[dict[str, Any]] = None, @@ -827,39 +863,43 @@ def __init__( # pylint: disable-all self.num_lang = num_lang self.kernel_size = kernel_size self.replication_pad_for_subsample_embedding: bool = ( - replication_pad_for_subsample_embedding) - assert (self.num_heads % attention_group_size == 0 - ), "attention_group_size must divide n_head" + replication_pad_for_subsample_embedding + ) + assert self.num_heads % attention_group_size == 0, ( + "attention_group_size must divide n_head" + ) self.num_heads_k = self.num_heads // attention_group_size - self.encoders = MultiSequential(*[ - ConformerEncoderLayer( - d_model=attention_dim, - ext_pw_out_channel=ext_pw_out_channel, - depthwise_seperable_out_channel=depthwise_seperable_out_channel, - depthwise_multiplier=depthwise_multiplier, - n_head=attention_heads, - d_ffn=linear_units, - ext_pw_kernel_size=ext_pw_kernel_size, - kernel_size=kernel_size, - dropout_rate=dropout_rate, - causal=causal, - batch_norm=batch_norm, - activation=activation, - chunk_se=chunk_se, - chunk_size=chunk_size, - conv_activation=conv_activation, - conv_glu_type=conv_glu_type, - bias_in_glu=bias_in_glu, - linear_glu_in_convm=linear_glu_in_convm, - attention_glu_type=attention_glu_type, - activation_checkpointing=activation_checkpointing, - export=export, - use_pt_scaled_dot_product_attention= - use_pt_scaled_dot_product_attention, - attn_group_sizes=attention_group_size, - ) for _ in range(num_blocks) - ]) + self.encoders = MultiSequential( + *[ + ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel=depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=activation_checkpointing, + export=export, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + ) + for _ in range(num_blocks) + ] + ) self.extra_layer_output_idx = extra_layer_output_idx self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs # Make a zeros scalar we can use in get_initial_state to determine @@ -867,34 +907,36 @@ def __init__( # pylint: disable-all self.register_buffer("dev_type", torch.zeros(()), persistent=False) def init_relative_attention_bias( - self, input_tensor: torch.Tensor) -> Optional[torch.Tensor]: + self, input_tensor: torch.Tensor + ) -> 
Optional[torch.Tensor]: if self.relative_attention_bias_layer: return self.relative_attention_bias_layer(input_tensor) - def calculate_hs_mask(self, xs_pad: torch.Tensor, device: torch.device, - mask: Optional[torch.Tensor]) -> torch.Tensor: + def calculate_hs_mask( + self, xs_pad: torch.Tensor, device: torch.device, mask: Optional[torch.Tensor] + ) -> torch.Tensor: max_audio_length = xs_pad.shape[1] batch_size = xs_pad.shape[0] - enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, - self.chunk_size, - self.left_chunk) + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.chunk_size, self.left_chunk + ) enc_streaming_mask = enc_streaming_mask.to(device) if mask is None: return enc_streaming_mask feature_lens = mask.sum(1) padding_length = feature_lens - pad_mask = (torch.arange(0, max_audio_length, - device=device).expand(padding_length.size(0), - -1) - < padding_length.unsqueeze(1)) + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask @torch.jit.ignore - def forward(self, xs_pad: torch.Tensor, - masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward( + self, xs_pad: torch.Tensor, masks: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: """Conformer Forward function Args: @@ -905,11 +947,12 @@ def forward(self, xs_pad: torch.Tensor, """ xs_pad = self.encoder_embedding(xs_pad) input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings( - xs_pad, masks) + xs_pad, masks + ) unfolded = False ori_bz, seq_len, D = input_tensor.shape - max_seq_len = 500 #maximum position for absolute positional encoding + max_seq_len = 500 # maximum position for absolute positional encoding if seq_len > max_seq_len: # audio sequence is longer than max_seq_len, unfold it into chunks # of max_seq_len @@ -921,26 +964,29 @@ def forward(self, xs_pad: torch.Tensor, else: chunk_pad_size = 0 if chunk_pad_size > 0: - input_tensor_pad = F.pad(input_tensor, - (0, 0, 0, chunk_pad_size), "constant", - 0) + input_tensor_pad = F.pad( + input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0 + ) input_tensor = input_tensor_pad.to(input_tensor.device) input_tensor = unfold_tensor(input_tensor, max_seq_len) if masks is not None: # revise hs_mask here because the previous calculated hs_mask # did not consider extra pad subsampled_pad_mask = masks.squeeze( - 1) # [bz, subsampled_unmask_seq_len] + 1 + ) # [bz, subsampled_unmask_seq_len] extra_padded_subsamlped_pad_mask = F.pad( - subsampled_pad_mask, (0, chunk_pad_size), "constant", - False) # extra padding to the pad mask - extra_padded_subsamlped_pad_mask = \ + subsampled_pad_mask, (0, chunk_pad_size), "constant", False + ) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = ( extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + ) masks_unfold = unfold_tensor( extra_padded_subsamlped_pad_mask, max_seq_len ) # unfold the pad mask like we did to the input tensor masks_unfold = masks_unfold.squeeze( - -1).bool() # unfold op does not support bool tensor + -1 + ).bool() # unfold op does not support bool tensor else: masks_unfold = None hs_mask = self.calculate_hs_mask( @@ -949,15 +995,14 @@ def forward(self, xs_pad: torch.Tensor, # layer_emb = None - relative_attention_bias = self.init_relative_attention_bias( - input_tensor) + relative_attention_bias = self.init_relative_attention_bias(input_tensor) - _simplified_path = 
(self.extra_layer_output_idx == -1 - and relative_attention_bias is None) + _simplified_path = ( + self.extra_layer_output_idx == -1 and relative_attention_bias is None + ) if _simplified_path: - input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, - hs_mask) + input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask) else: for i, layer in enumerate(self.encoders): input_tensor, _, _, _ = layer( @@ -997,28 +1042,32 @@ def __init__( ): super().__init__() - self.decoders = nn.ModuleList([ - nn.TransformerDecoderLayer( - d_model=attention_dim, - nhead=attention_heads, - dim_feedforward=linear_units, - dropout=dropout_rate, - activation="relu", - batch_first=True, - norm_first=normalize_before, # TODO need to verify - ) for _ in range(num_blocks) - ]) + self.decoders = nn.ModuleList( + [ + nn.TransformerDecoderLayer( + d_model=attention_dim, + nhead=attention_heads, + dim_feedforward=linear_units, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=normalize_before, # TODO need to verify + ) + for _ in range(num_blocks) + ] + ) self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim)) - self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12) - if normalize_before else None) + self.after_norm = ( + nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None + ) self.window_size = window_size def forward( - self, - audio_embed: torch.Tensor, - mask: Optional[torch.Tensor], - embed_len: Optional[int] = None + self, + audio_embed: torch.Tensor, + mask: Optional[torch.Tensor], + embed_len: Optional[int] = None, ) -> tuple[torch.Tensor, Optional[int]]: """forward decoder""" # audio_embed: N x T x D => N x D x T @@ -1027,8 +1076,9 @@ def forward( # audio_embed: N x D x 1 x T => N x DK x T' padding = audio_embed.shape[-1] % self.window_size if padding > 0: - audio_embed = F.pad(audio_embed, (0, self.window_size - padding), - "constant", 0) + audio_embed = F.pad( + audio_embed, (0, self.window_size - padding), "constant", 0 + ) embed_chunk = F.unfold( audio_embed[..., None, :], @@ -1045,10 +1095,7 @@ def forward( # NT' x 1 x D q = self.queries.expand(bsz * slen, -1, -1) for layer in self.decoders: - q = layer(tgt=q, - memory=embed_chunk, - tgt_mask=None, - memory_mask=mask) + q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask) if self.after_norm is not None: q = self.after_norm(q) @@ -1068,8 +1115,7 @@ def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: super().__init__() self.config = config # n_embed or hidden_size for text LM - hidden_size = (config.n_embd - if hasattr(config, "n_embd") else config.hidden_size) + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size # self.wte = nn.Embedding(config.vocab_size, hidden_size) @@ -1078,8 +1124,10 @@ def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: ) self.layer_idx = -2 - if (isinstance(config.audio_processor, dict) - and config.audio_processor.get("name", None) == "cascades"): + if ( + isinstance(config.audio_processor, dict) + and config.audio_processor.get("name", None) == "cascades" + ): encoder_config = config.audio_processor.get("config", None) assert encoder_config is not None self.encoder = ConformerEncoder(**encoder_config) @@ -1089,13 +1137,11 @@ def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: else: raise NotImplementedError("") - assert (audio_dim_out - is not None), "Remember to set values for audio_dim_out" + assert audio_dim_out is not None, "Remember to set values for 
audio_dim_out" self.audio_dim_out = audio_dim_out self.audio_dim_in = n_mels - self.freeze_audio_processor = kwargs.get("freeze_audio_processor", - False) + self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False) self.downsample_rate = kwargs.get("downsample_rate", 1) @@ -1107,8 +1153,9 @@ def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: self.qformer = None if kwargs.get("use_conv_downsample", False): - assert (self.qformer is None - ), "don't support use qformer and conv downsample together" + assert self.qformer is None, ( + "don't support use qformer and conv downsample together" + ) nemo_conv_settings = kwargs.get("nemo_conv_settings", {}) default_nemo_conv_settings = { "subsampling": "dw_striding", @@ -1124,11 +1171,13 @@ def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: if nemo_conv_settings: default_nemo_conv_settings.update(nemo_conv_settings) for i in ["subsampling_factor", "feat_in", "feat_out"]: - assert ( - i not in nemo_conv_settings - ), "{i} should be specified outside of the NeMo dictionary" + assert i not in nemo_conv_settings, ( + "{i} should be specified outside of the NeMo dictionary" + ) - self.conv_ds = NemoConvSubsampling(**default_nemo_conv_settings, ) + self.conv_ds = NemoConvSubsampling( + **default_nemo_conv_settings, + ) else: self.conv_ds = None @@ -1140,30 +1189,26 @@ def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None: # (do not use image_projection and image_proj_norm) dim_projection = hidden_size depth = 2 - self.linear_downsample_rate = (1 if (self.qformer or self.conv_ds) - else self.downsample_rate) + self.linear_downsample_rate = ( + 1 if (self.qformer or self.conv_ds) else self.downsample_rate + ) layers = [ - nn.Linear(audio_dim_out * self.linear_downsample_rate, - dim_projection) + nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) ] for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.audio_projection = nn.Sequential(*layers) # NOTE vision-speech tasks use a separate projection layer layers = [ - nn.Linear(audio_dim_out * self.linear_downsample_rate, - dim_projection) + nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection) ] for _ in range(1, depth): - layers.extend( - [nn.GELU(), - nn.Linear(dim_projection, dim_projection)]) + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.audio_projection_for_vision = nn.Sequential(*layers) else: raise NotImplementedError( - f"projection_cls = {projection_cls}, not implemented") + f"projection_cls = {projection_cls}, not implemented" + ) # TODO: audio sequence compression - Qformer self.vocab_size = config.vocab_size @@ -1188,11 +1233,9 @@ def get_audio_features( """ if self.freeze_audio_processor: with torch.no_grad(): - audio_features, masks = self.encoder(input_embeds, - audio_attention_mask) + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) else: - audio_features, masks = self.encoder(input_embeds, - audio_attention_mask) + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) if self.qformer is not None: audio_features, _ = self.qformer(audio_features, mask=None) @@ -1221,14 +1264,13 @@ def get_audio_features( feat_dim * self.linear_downsample_rate, ) - if audio_projection_mode == 'speech': + if audio_projection_mode == "speech": audio_set_tensor = self.audio_projection(audio_features) - elif 
audio_projection_mode == 'vision': + elif audio_projection_mode == "vision": audio_set_tensor = self.audio_projection_for_vision(audio_features) else: raise ValueError( - f"audio_projection_mode = {audio_projection_mode} not "\ - "implemented" + f"audio_projection_mode = {audio_projection_mode} not implemented" ) return audio_set_tensor @@ -1242,7 +1284,7 @@ def forward( """ arguments: audio_features: audio features (T, D) - + returns: audio_embeds: audio embeddings (num_audio_tokens, hidden_dim) """ diff --git a/vllm/model_executor/models/phi4mm_utils.py b/vllm/model_executor/models/phi4mm_utils.py index 6fbfca619a42..d50547c199ac 100644 --- a/vllm/model_executor/models/phi4mm_utils.py +++ b/vllm/model_executor/models/phi4mm_utils.py @@ -43,18 +43,17 @@ def get_activation(name: str = "relu") -> torch.nn.Module: return nn.Identity() -def adaptive_enc_mask(x_len: int, - chunk_start_idx: list[int], - left_window: int = 0, - right_window: int = 0) -> torch.Tensor: +def adaptive_enc_mask( + x_len: int, chunk_start_idx: list[int], left_window: int = 0, right_window: int = 0 +) -> torch.Tensor: """ The function is very important for Transformer Transducer Streaming mode Args: x_len: sequence length - chunk_start_idx: first idx of each chunk, such as [0,18,36,48]. + chunk_start_idx: first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] left_window: how many left chunks can be seen - right_window: how many right chunks can be seen. It is used for + right_window: how many right chunks can be seen. It is used for chunk overlap model. Returns: mask (torch.Tensor): a mask tensor for streaming model @@ -67,21 +66,23 @@ def adaptive_enc_mask(x_len: int, [False., True., True., False.], [False., False., True., True.]]) """ - chunk_start_idx = torch.Tensor(chunk_start_idx).long( - ) # first idx of each chunk, such as [0,18,36,48]. + chunk_start_idx = torch.Tensor( + chunk_start_idx + ).long() # first idx of each chunk, such as [0,18,36,48]. 
start_pad = torch.nn.functional.pad( - chunk_start_idx, - (1, 0)) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + chunk_start_idx, (1, 0) + ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] end_pad = torch.nn.functional.pad( chunk_start_idx, (0, 1), value=x_len ) # append x_len to the end, so it becomes [0,18,36,48, x_len] - seq_range = torch.arange(0, - x_len).unsqueeze(-1) # seq_range size: [x_len, 1] - idx = ((seq_range < end_pad) & - (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len] + seq_range = torch.arange(0, x_len).unsqueeze(-1) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[ + :, 1 + ] # idx size: [x_len] # boundary = end_pad[idx] # boundary size: [x_len] - seq_range_expand = (torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) - ) # seq_range_expand size [x_len, x_len] + seq_range_expand = ( + torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + ) # seq_range_expand size [x_len, x_len] idx_left = idx - left_window idx_left[idx_left < 0] = 0 boundary_left = start_pad[idx_left] @@ -231,18 +232,23 @@ def forward(self, x: Tensor) -> Tensor: x = self.ext_pw_conv_1d(x) if self.glu_type == "bilinear": if self.bias_in_glu: - x = (x[:, 0:self.output_dim, :] + self.b1) * ( - x[:, self.output_dim:self.output_dim * 2, :] + self.b2) + x = (x[:, 0 : self.output_dim, :] + self.b1) * ( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) else: - x = (x[:, 0:self.output_dim, :]) * ( - x[:, self.output_dim:self.output_dim * 2, :]) + x = ( + (x[:, 0 : self.output_dim, :]) + * (x[:, self.output_dim : self.output_dim * 2, :]) + ) else: if self.bias_in_glu: - x = (x[:, 0:self.output_dim, :] + self.b1) * self.glu_act( - x[:, self.output_dim:self.output_dim * 2, :] + self.b2) + x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) else: - x = (x[:, 0:self.output_dim, :]) * self.glu_act( - x[:, self.output_dim:self.output_dim * 2, :]) + x = (x[:, 0 : self.output_dim, :]) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + ) x = x.permute([0, 2, 1]) return x @@ -257,7 +263,7 @@ class DepthWiseSeperableConv1d(nn.Module): input_dim: int input channel size. depthwise_seperable_out_channel: int - if set different to 0, the number of + if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. otherwise, it equals to 0, the second conv1d layer is skipped. @@ -327,7 +333,7 @@ class ConvModule(nn.Module): if > 0, ext_pw_out_channel is a dim channel size for the last pointwise conv after swish activation. depthwise_seperable_out_channel: int - if set different to 0, the number of + if set different to 0, the number of depthwise_seperable_out_channel will be used as a channel_out of the second conv1d layer. otherwise, it equal to 0, the second conv1d layer is skipped. @@ -431,12 +437,10 @@ def __init__( if depthwise_seperable_out_channel != 0: if input_dim != depthwise_seperable_out_channel: - self.ln2 = nn.Linear(depthwise_seperable_out_channel, - input_dim) + self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) else: if depthwise_multiplier != 1: - self.ln2 = nn.Linear(input_dim * depthwise_multiplier, - input_dim) + self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) def _add_ext_pw_layer(self) -> None: """ @@ -445,7 +449,8 @@ def _add_ext_pw_layer(self) -> None: of the conformer. 
""" self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = ( - nn.Identity()) # jit hacks. + nn.Identity() + ) # jit hacks. self.squeeze_excitation = nn.Identity() # jit. self.apply_ln1 = self.fix_len1 = False # jit. @@ -509,7 +514,7 @@ def forward(self, x: Tensor) -> Tensor: if self.ext_pw_out_channel != 0: x = self.glu(x) if self.causal and self.ext_pw_kernel_size > 1: - x = x[:, :-(self.ext_pw_kernel_size - 1), :] + x = x[:, : -(self.ext_pw_kernel_size - 1), :] if self.apply_ln1: x = self.ln1(x) else: @@ -521,7 +526,7 @@ def forward(self, x: Tensor) -> Tensor: x = self.dw_sep_conv_1d(x) if self.causal and self.kernel_size > 1: - x = x[:, :, :-(self.kernel_size - 1)] + x = x[:, :, : -(self.kernel_size - 1)] if hasattr(self, "ln2"): x = x.permute([0, 2, 1]) x = self.ln2(x) @@ -533,7 +538,7 @@ def forward(self, x: Tensor) -> Tensor: if self.ext_pw_out_channel != 0: x = self.ext_pw_conv_1d(x) if self.fix_len1: - x = x[:, :, :-(self.ext_pw_kernel_size - 1)] + x = x[:, :, : -(self.ext_pw_kernel_size - 1)] if self.apply_ln1: x = x.permute([0, 2, 1]) @@ -652,7 +657,7 @@ def _pre_hook( Note: We saved self.pe until v.0.5.2 but we have omitted it later. - Therefore, we remove the item "pe" from `state_dict` for backward + Therefore, we remove the item "pe" from `state_dict` for backward compatibility. """ @@ -663,7 +668,7 @@ def _pre_hook( class T5RelativeAttentionLogitBias(nn.Module): """ - This module implements the relative position bias described in Section + This module implements the relative position bias described in Section 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf The Huggingface implementation is used as a reference @@ -671,18 +676,18 @@ class T5RelativeAttentionLogitBias(nn.Module): transformers/models/t5/modeling_t5.py#L435 Modifies attention as Q*K^T + B, where B is a learned scalar bias based - on relative position of the query and key. It is HxNxN, where H is the + on relative position of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length. I've made these modifications to the original T5 bias: - - Skipping of the bucketing step. Original T5 bias converted rel - position distances into logarithmically increasing buckets. This is + - Skipping of the bucketing step. Original T5 bias converted rel + position distances into logarithmically increasing buckets. This is supposed to help with length generalization. - - I just directly use rel position index as bias values, as we don't - need length generalization (40s max is good enough for ASR encoder), + - I just directly use rel position index as bias values, as we don't + need length generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple. - - I've also extended it so that biases can be asymmetric, the default - implementation treats L->R and R->L the same. Asymmetric was found to + - I've also extended it so that biases can be asymmetric, the default + implementation treats L->R and R->L the same. Asymmetric was found to yield better results in my experiments. Args: @@ -690,26 +695,28 @@ class T5RelativeAttentionLogitBias(nn.Module): Number of attention heads num_buckets: int Number of buckets to use for relative attention bias. This is the - size of the learnable bias parameter. Bucketing is not yet + size of the learnable bias parameter. Bucketing is not yet supported, so this defaults to -1 which means no bucketing is used (max_distance determines size of bias param). max_distance: int - Maximum distance to use for relative attention bias. 
With - num_buckets=-1, this directly controls the max size of the bias - parameter. When num_buckets > 0 is supported, this will control - the maximum distance for logarithmic bucketing after which all + Maximum distance to use for relative attention bias. With + num_buckets=-1, this directly controls the max size of the bias + parameter. When num_buckets > 0 is supported, this will control + the maximum distance for logarithmic bucketing after which all positions are in the same bucket. symmetric: bool Whether to use symmetric or asymmetric biases. symmetric=False uses - 2x number of bias params to distinguish L->R from R->L. This was + 2x number of bias params to distinguish L->R from R->L. This was found to be better for the encoder. """ - def __init__(self, - num_heads: int, - num_buckets: int = -1, - max_distance: int = 1000, - symmetric: bool = False) -> None: + def __init__( + self, + num_heads: int, + num_buckets: int = -1, + max_distance: int = 1000, + symmetric: bool = False, + ) -> None: super().__init__() self.num_heads = num_heads self.num_buckets = num_buckets @@ -720,7 +727,8 @@ def __init__(self, self.num_buckets = max_distance else: raise NotImplementedError( - "T5 attention bias with bucketed positions is not yet tested") + "T5 attention bias with bucketed positions is not yet tested" + ) if not self.symmetric: self.num_buckets *= 2 self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) @@ -728,19 +736,21 @@ def __init__(self, def forward(self, x: Tensor) -> Tensor: # instantiate bias compatible with shape of x maxpos = x.size(1) - context_position = torch.arange(maxpos, - device=x.device, - dtype=torch.long)[:, None] - memory_position = torch.arange(maxpos, - device=x.device, - dtype=torch.long)[None, :] + context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[ + :, None + ] + memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[ + None, : + ] relative_position = memory_position - context_position # clipping to a maximum distance using ops that play well with ONNX # export relative_position = relative_position.masked_fill( - relative_position < -self.max_distance, -self.max_distance) + relative_position < -self.max_distance, -self.max_distance + ) relative_position = relative_position.masked_fill( - relative_position > self.max_distance - 1, self.max_distance - 1) + relative_position > self.max_distance - 1, self.max_distance - 1 + ) # mapping from relative position to index in the bias parameter if self._skip_bucketing: @@ -753,8 +763,7 @@ def forward(self, x: Tensor) -> Tensor: bias_idx += self.num_buckets // 2 t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] - t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze( - 0) # [1, H, L, L] + t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0) # [1, H, L, L] return t5_rel_att_bias @@ -766,11 +775,13 @@ def _bucket_relative_position(self, relative_position: Tensor) -> Tensor: if not self.causal: self.num_buckets //= 2 relative_buckets += (relative_position > 0).to( - torch.long) * self.num_buckets + torch.long + ) * self.num_buckets relative_position = torch.abs(relative_position) else: - relative_position = -torch.min(relative_position, - torch.zeros_like(relative_position)) + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) # now relative_position is in the range [0, inf) # half of the buckets are for exact increments in positions @@ -780,16 +791,18 @@ def _bucket_relative_position(self, 
relative_position: Tensor) -> Tensor: # The other half of the buckets are for logarithmically bigger bins in # positions up to max_distance relative_position_if_large = max_exact + ( - torch.log(relative_position.float() / max_exact) / - math.log(self.max_distance / max_exact) * - (self.num_buckets - max_exact)).to(torch.long) + torch.log(relative_position.float() / max_exact) + / math.log(self.max_distance / max_exact) + * (self.num_buckets - max_exact) + ).to(torch.long) relative_position_if_large = torch.min( relative_position_if_large, torch.full_like(relative_position_if_large, self.num_buckets - 1), ) - relative_buckets += torch.where(is_small, relative_position, - relative_position_if_large) + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) return relative_buckets @@ -808,10 +821,7 @@ class AbsolutePositionalEncoding(nn.Module): """ - def __init__(self, - d_model: int, - dropout_rate: float, - max_len: int = 5000) -> None: + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: """Construct an PositionalEncoding object.""" super().__init__() self.d_model = d_model @@ -834,8 +844,9 @@ def extend_pe(self, x: torch.Tensor) -> None: pe = torch.zeros(x.size(1), self.d_model) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) * - -(math.log(10000.0) / self.d_model)) + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) @@ -852,7 +863,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ self.extend_pe(x) - x = x * self.xscale + self.pe[:, :x.size(1)] + x = x * self.xscale + self.pe[:, : x.size(1)] return self.dropout(x) @@ -889,14 +900,14 @@ class CausalConv1D(nn.Conv1d): locations on its right or left All arguments are the same as nn.Conv1d except padding. - If padding is set None, then paddings are set automatically to make it a + If padding is set None, then paddings are set automatically to make it a causal convolution where each location would not see any steps on its right. - If padding is set as a list (size of 2), then padding[0] would be used as + If padding is set as a list (size of 2), then padding[0] would be used as left padding and padding[1] as right padding. It would make it possible to control the number of steps to be accessible on the right and left. - This mode is not supported when stride > 1. padding[0]+padding[1] should + This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1). 
""" @@ -920,13 +931,15 @@ def __init__( self._right_padding = stride - 1 else: if stride != 1 and padding != kernel_size - 1: - raise ValueError( - "No striding allowed for non-symmetric convolutions!") + raise ValueError("No striding allowed for non-symmetric convolutions!") if isinstance(padding, int): self._left_padding = padding self._right_padding = padding - elif (isinstance(padding, list) and len(padding) == 2 - and padding[0] + padding[1] == kernel_size - 1): + elif ( + isinstance(padding, list) + and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1 + ): self._left_padding = padding[0] self._right_padding = padding[1] else: @@ -949,9 +962,8 @@ def __init__( ) def update_cache( - self, - x: Tensor, - cache: Optional[Tensor] = None) -> tuple[Tensor, Optional[Tensor]]: + self, x: Tensor, cache: Optional[Tensor] = None + ) -> tuple[Tensor, Optional[Tensor]]: if cache is None: new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) next_cache = cache @@ -959,16 +971,14 @@ def update_cache( new_x = F.pad(x, pad=(0, self._right_padding)) new_x = torch.cat([cache, new_x], dim=-1) if self.cache_drop_size > 0: - next_cache = new_x[:, :, :-self.cache_drop_size] + next_cache = new_x[:, :, : -self.cache_drop_size] else: next_cache = new_x - next_cache = next_cache[:, :, -cache.size(-1):] + next_cache = next_cache[:, :, -cache.size(-1) :] return new_x, next_cache def forward( - self, - x: Tensor, - cache: Optional[Tensor] = None + self, x: Tensor, cache: Optional[Tensor] = None ) -> Union[Tensor, tuple[Tensor, Optional[Tensor]]]: x, cache = self.update_cache(x, cache=cache) x = super().forward(x) @@ -982,7 +992,7 @@ class CausalConv2D(nn.Conv2d): """ A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down - All arguments are the same as nn.Conv2d except padding which should be + All arguments are the same as nn.Conv2d except padding which should be set as None """ @@ -1001,8 +1011,7 @@ def __init__( dtype=None, ) -> None: if padding is not None: - raise ValueError( - "Argument padding should be set to None for CausalConv2D.") + raise ValueError("Argument padding should be set to None for CausalConv2D.") self._left_padding = kernel_size - 1 self._right_padding = stride - 1 @@ -1038,17 +1047,17 @@ class NemoConvSubsampling(torch.nn.Module): (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py) - Striding Subsampling: "Speech-Transformer: A No-Recurrence - Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong + Striding Subsampling: "Speech-Transformer: A No-Recurrence + Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506) - Compared with the EncoderConv2D (`input_layer: custom`), this is a + Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach, and uses no LayerNorm and far fewer Conv2Ds. Moreover, depthwise convolutions are used to reduce FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy. - `Striding` and `dw_striding` are the same except that the latter uses + `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions after the first layer, whereas the former does not. 
Args: @@ -1056,11 +1065,11 @@ class NemoConvSubsampling(torch.nn.Module): feat_in (int): size of the input features feat_out (int): size of the output features subsampling (str): The subsampling technique, choose from - {"striding", "dw-striding", "striding_conv1d", + {"striding", "dw-striding", "striding_conv1d", "dw_striding_conv1d"} - conv_channels (int): Number of channels for the convolution layers, + conv_channels (int): Number of channels for the convolution layers, default is 256. - subsampling_conv_chunking_factor (int): Input chunking factor which + subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking) 1 (auto) or a power of 2. Default is 1 activation (Module): activation function, default is nn.ReLU() is_causal (bool): whether to use causal Conv1/2D, where each step will @@ -1095,15 +1104,15 @@ def __init__( "striding_conv1d", ) - if (subsampling_conv_chunking_factor != -1 - and subsampling_conv_chunking_factor != 1 - and subsampling_conv_chunking_factor % 2 != 0): + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): raise ValueError( - "subsampling_conv_chunking_factor should be -1, 1, or a "\ - "power of 2" + "subsampling_conv_chunking_factor should be -1, 1, or a power of 2" ) - self.subsampling_conv_chunking_factor = \ - subsampling_conv_chunking_factor + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor in_channels = 1 layers = [] @@ -1131,7 +1140,8 @@ def __init__( kernel_size=self._kernel_size, stride=self._stride, padding=None, - )) + ) + ) else: layers.append( torch.nn.Conv2d( @@ -1140,7 +1150,8 @@ def __init__( kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, - )) + ) + ) in_channels = conv_channels layers.append(activation) @@ -1154,7 +1165,8 @@ def __init__( stride=self._stride, padding=None, groups=in_channels, - )) + ) + ) else: layers.append( torch.nn.Conv2d( @@ -1164,7 +1176,8 @@ def __init__( stride=self._stride, padding=self._left_padding, groups=in_channels, - )) + ) + ) layers.append( torch.nn.Conv2d( @@ -1174,7 +1187,8 @@ def __init__( stride=1, padding=0, groups=1, - )) + ) + ) layers.append(activation) in_channels = conv_channels @@ -1201,7 +1215,8 @@ def __init__( kernel_size=self._kernel_size, stride=self._stride, padding=None, - )) + ) + ) else: layers.append( torch.nn.Conv2d( @@ -1210,7 +1225,8 @@ def __init__( kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, - )) + ) + ) layers.append(activation) in_channels = conv_channels @@ -1235,22 +1251,30 @@ def __init__( layers.append( CausalConv1D( in_channels=in_channels, - out_channels=(feat_out if self._sampling_num == i + - 1 else conv_channels), + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), kernel_size=self._kernel_size, stride=self._stride, padding=None, - )) + ) + ) else: layers.append( torch.nn.Conv1d( in_channels=in_channels, - out_channels=(feat_out if self._sampling_num == i + - 1 else conv_channels), + out_channels=( + feat_out + if self._sampling_num == i + 1 + else conv_channels + ), kernel_size=self._kernel_size, stride=self._stride, padding=self._left_padding, - )) + ) + ) layers.append(activation) in_channels = conv_channels @@ -1265,30 +1289,8 @@ def __init__( self._right_padding = (self._kernel_size - 1) // 2 # Layer 1 - layers.extend([ - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=in_channels, - 
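The dw_striding branch above replaces full Conv2d layers (after the first) with a depthwise convolution followed by a 1x1 pointwise convolution, which is where the FLOP and parameter savings mentioned in the docstring come from. A quick parameter-count comparison (channel and kernel sizes below are illustrative, not the model's):

import torch.nn as nn

c, k = 256, 3

full = nn.Conv2d(c, c, kernel_size=k, stride=2, padding=1)
depthwise = nn.Conv2d(c, c, kernel_size=k, stride=2, padding=1, groups=c)
pointwise = nn.Conv2d(c, c, kernel_size=1)

def n_params(*mods):
    return sum(p.numel() for m in mods for p in m.parameters())

print("full conv:           ", n_params(full))                  # ~ c * c * k * k
print("depthwise + pointwise:", n_params(depthwise, pointwise)) # ~ c * k * k + c * c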
kernel_size=self._kernel_size, - stride=self._stride, - padding=self._left_padding, - groups=in_channels, - ), - torch.nn.Conv1d( - in_channels=in_channels, - out_channels=(feat_out if self._sampling_num == 1 else - conv_channels), - kernel_size=1, - stride=1, - padding=0, - groups=1, - ), - ]) - in_channels = conv_channels - layers.append(activation) - - for i in range(self._sampling_num - 1): - layers.extend([ + layers.extend( + [ torch.nn.Conv1d( in_channels=in_channels, out_channels=in_channels, @@ -1299,14 +1301,44 @@ def __init__( ), torch.nn.Conv1d( in_channels=in_channels, - out_channels=(feat_out if self._sampling_num == i + - 2 else conv_channels), + out_channels=( + feat_out if self._sampling_num == 1 else conv_channels + ), kernel_size=1, stride=1, padding=0, groups=1, ), - ]) + ] + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=( + feat_out + if self._sampling_num == i + 2 + else conv_channels + ), + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) layers.append(activation) in_channels = conv_channels @@ -1323,8 +1355,7 @@ def __init__( ceil_mode=self._ceil_mode, repeat_num=self._sampling_num, ) - self.out = torch.nn.Linear(conv_channels * int(out_length), - feat_out) + self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) self.conv2d_subsampling = True elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: self.out = None @@ -1340,8 +1371,9 @@ def get_sampling_frames(self) -> list[int]: def get_streaming_cache_size(self) -> list[int]: return [0, self.subsampling_factor + 1] - def forward(self, x: Tensor, - mask: Optional[Tensor]) -> tuple[Tensor, Optional[Tensor]]: + def forward( + self, x: Tensor, mask: Optional[Tensor] + ) -> tuple[Tensor, Optional[Tensor]]: """ Forward method for NeMo subsampling. @@ -1350,24 +1382,22 @@ def forward(self, x: Tensor, mask: input mask Returns: - x: Resulting tensor from subsampling (B, T // + x: Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) - pad_mask: tensor of padded hidden state sequences (B, 1, T // + pad_mask: tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) """ x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2) # split inputs if chunking_factor is set - if (self.subsampling_conv_chunking_factor != -1 - and self.conv2d_subsampling): + if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: if self.subsampling_conv_chunking_factor == 1: # if subsampling_conv_chunking_factor is 1, we split only # if needed. # avoiding a bug / feature limiting indexing of tensors # to 2**31. 
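The forward pass that begins above splits very large inputs before the conv stack because a single convolution call cannot index more than 2**31 elements; conv_split_by_batch further below simply runs the convolution on batch chunks and concatenates the results. The pattern in isolation (layer sizes and chunk size are illustrative):

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=1)
x = torch.randn(8, 1, 16, 16)

# Run the same convolution per batch chunk and concatenate along the batch dim.
chunked = torch.cat([conv(chunk) for chunk in torch.split(x, 2, dim=0)])
assert torch.allclose(chunked, conv(x), atol=1e-6)
print(chunked.shape)   # torch.Size([8, 4, 8, 8])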
# see https://github.com/pytorch/pytorch/issues/80020 - x_ceil = (2**31 / self._conv_channels * self._stride * - self._stride) + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride need_to_split = torch.numel(x) > x_ceil else: # if subsampling_conv_chunking_factor > 1 we always split @@ -1403,7 +1433,8 @@ def forward(self, x: Tensor, feature_lens_remainder = feature_lens % self.subsampling_factor padding_length[feature_lens_remainder != 1] += 1 pad_mask = torch.arange(0, max_audio_length, device=x.device).expand( - padding_length.size(0), -1) < padding_length.unsqueeze(1) + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) return x, pad_mask.unsqueeze(1) def reset_parameters(self) -> None: @@ -1412,27 +1443,22 @@ def reset_parameters(self) -> None: with torch.no_grad(): # init conv scale = 1.0 / self._kernel_size - dw_max = (self._kernel_size**2)**-0.5 + dw_max = (self._kernel_size**2) ** -0.5 pw_max = self._conv_channels**-0.5 torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) for idx in range(2, len(self.conv), 3): - torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, - dw_max) - torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, - dw_max) - torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, - pw_max) - torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, - pw_max) + torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max) # init fc (80 * 64 = 5120 from https://github.com/kssteven418/ # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/ # src/models/conformer_encoder.py#L487 - fc_scale = (self._feat_out * self._feat_in / - self._sampling_num)**-0.5 + fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5 torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) @@ -1456,15 +1482,14 @@ def conv_split_by_batch(self, x: Tensor) -> tuple[Tensor, bool]: return x, False return ( - torch.cat([ - self.conv(chunk) - for chunk in torch.split(x, new_batch_size, 0) - ]), + torch.cat( + [self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)] + ), True, ) def conv_split_by_channel(self, x: Tensor) -> Tensor: - """For dw convs, tries to split input by time, run conv and concat + """For dw convs, tries to split input by time, run conv and concat results""" x = self.conv[0](x) # full conv2D x = self.conv[1](x) # activation @@ -1489,22 +1514,21 @@ def conv_split_by_channel(self, x: Tensor) -> Tensor: if new_t == 0: new_t = 1 - x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, - x) # conv2D, depthwise + x = self.channel_chunked_conv( + self.conv[i * 3 + 2], new_c, x + ) # conv2D, depthwise # splitting pointwise convs by time x = torch.cat( - [ - self.conv[i * 3 + 3](chunk) - for chunk in torch.split(x, new_t, 2) - ], + [self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2, ) # conv2D, pointwise x = self.conv[i * 3 + 4](x) # activation return x - def channel_chunked_conv(self, conv: torch.nn.Module, chunk_size: int, - x: Tensor) -> Tensor: + def channel_chunked_conv( + self, conv: torch.nn.Module, chunk_size: int, x: Tensor + ) -> Tensor: """Performs channel chunked convolution""" ind = 0 @@ -1524,8 +1548,8 @@ def channel_chunked_conv(self, conv: torch.nn.Module, chunk_size: 
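The subsampled pad_mask above is built by comparing a broadcast arange against per-sequence lengths. The same pattern in isolation (lengths are illustrative):

import torch

lengths = torch.tensor([5, 3, 8])   # valid frames per sequence after subsampling
max_len = int(lengths.max())

# (B, T) boolean mask: True where the frame index falls inside the valid length.
pad_mask = torch.arange(max_len).expand(lengths.size(0), -1) < lengths.unsqueeze(1)
print(pad_mask.int())
# tensor([[1, 1, 1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1, 1]])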
int, ) ch_out = nn.functional.conv2d( chunk, - conv.weight[ind:ind + step, :, :, :], - bias=conv.bias[ind:ind + step], + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], stride=self._stride, padding=0, groups=step, @@ -1533,8 +1557,8 @@ def channel_chunked_conv(self, conv: torch.nn.Module, chunk_size: int, else: ch_out = nn.functional.conv2d( chunk, - conv.weight[ind:ind + step, :, :, :], - bias=conv.bias[ind:ind + step], + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], stride=self._stride, padding=self._left_padding, groups=step, @@ -1545,30 +1569,33 @@ def channel_chunked_conv(self, conv: torch.nn.Module, chunk_size: int, return torch.cat(out_chunks, 1) def change_subsampling_conv_chunking_factor( - self, subsampling_conv_chunking_factor: int) -> None: - if (subsampling_conv_chunking_factor != -1 - and subsampling_conv_chunking_factor != 1 - and subsampling_conv_chunking_factor % 2 != 0): + self, subsampling_conv_chunking_factor: int + ) -> None: + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): raise ValueError( - "subsampling_conv_chunking_factor should be -1, 1, or a "\ - "power of 2" + "subsampling_conv_chunking_factor should be -1, 1, or a power of 2" ) self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor -def calc_length(lengths: Tensor, - all_paddings: int, - kernel_size: int, - stride: int, - ceil_mode: bool, - repeat_num: int = 1) -> Tensor: +def calc_length( + lengths: Tensor, + all_paddings: int, + kernel_size: int, + stride: int, + ceil_mode: bool, + repeat_num: int = 1, +) -> Tensor: """Calculates the output length of a Tensor passed through a convolution or - max pooling layer""" + max pooling layer""" add_pad: float = all_paddings - kernel_size one: float = 1.0 for i in range(repeat_num): - lengths = (torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + - one) + lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths) return lengths.to(dtype=torch.int) @@ -1619,14 +1646,15 @@ def masked_softmax( mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) scores = scores.masked_fill(mask, -torch.inf) attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0) # (batch, head, time1, time2) + mask, 0.0 + ) # (batch, head, time1, time2) else: attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) return attn class MultiHeadedAttention(nn.Module): - """Multi-Head Attention layer with optional relative position embedding + """Multi-Head Attention layer with optional relative position embedding and GLU. Args: @@ -1642,12 +1670,12 @@ class MultiHeadedAttention(nn.Module): default: -1 (equal to n_feat). use_pt_scaled_dot_product_attention: bool, optional if set True, use pytorch scaled dot product attention in training. - NOTE: this will NOT be used in ONNX decoding due to a lack of - support. In that case, we use the original attention + NOTE: this will NOT be used in ONNX decoding due to a lack of + support. In that case, we use the original attention implementation, which shows no regression. default: False. n_value: int, optional - if set to values other than -1, use a different dimension for + if set to values other than -1, use a different dimension for value. With the default value (i.e. -1), it is backward compatible. group_size: int, optional. 
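calc_length above is the usual convolution output-length formula applied repeat_num times for stacked subsampling layers. A worked example mirroring the function for two stride-2 layers with kernel 3 and total padding 2 (the numbers are illustrative):

import torch

def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1):
    add_pad = all_paddings - kernel_size
    one = 1.0
    for _ in range(repeat_num):
        lengths = torch.div(lengths.to(torch.float) + add_pad, stride) + one
        lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
    return lengths.to(torch.int)

# 100 input frames, kernel 3, stride 2, padding 1 on each side, two layers:
# layer 1: floor((100 + 2 - 3) / 2) + 1 = 50;  layer 2: floor((50 + 2 - 3) / 2) + 1 = 25
print(calc_length(torch.tensor([100]), all_paddings=2, kernel_size=3,
                  stride=2, ceil_mode=False, repeat_num=2))  # tensor([25], dtype=torch.int32)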
must divide `n_head` if group_size > 1: GQA @@ -1695,8 +1723,7 @@ def __init__( self.attn = torch.jit.Attribute(None, Optional[Tensor]) self.dropout = nn.Dropout(p=dropout_rate) self.dropout_rate = dropout_rate - self.use_pt_scaled_dot_product_attention = ( - use_pt_scaled_dot_product_attention) + self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention if use_pt_scaled_dot_product_attention and group_size > 1: raise ValueError("Cannot use PT Scaled Attention with GQA") @@ -1728,25 +1755,24 @@ def forward( pos_k: key tensor used for relative positional embedding. pos_v: value tensor used for relative positional embedding. mask: mask tensor (batch, time1, time2) - relative_attention_bias: bias added to attention logits w.r.t. + relative_attention_bias: bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) """ n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, - self.d_k) # (b, t, d) - k = self.linear_k(key).view(n_batch, -1, self.h_k, - self.d_k) # (b, t, d) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) # (b, t, d) + k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k) # (b, t, d) v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) - q = (q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention - and not torch.jit.is_scripting() else q.transpose(1, 2) * - self.inv_sqrt_d_k) + q = ( + q.transpose(1, 2) + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting() + else q.transpose(1, 2) * self.inv_sqrt_d_k + ) k = k.transpose(1, 2) # (batch, head_k, time2, d_k) v = v.transpose(1, 2) # (batch, head_k, time2, d_k) - if (self.use_pt_scaled_dot_product_attention - and not torch.jit.is_scripting()): + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting(): attn_mask = None if mask is not None: mask = mask.unsqueeze(1) @@ -1757,12 +1783,14 @@ def forward( if mask.dtype != q.dtype: attn_mask = attn_mask.to(q.dtype) - with torch.nn.attention.sdpa_kernel([ + with torch.nn.attention.sdpa_kernel( + [ torch.nn.attention.SDPBackend.FLASH_ATTENTION, torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION, torch.nn.attention.SDPBackend.MATH, torch.nn.attention.SDPBackend.CUDNN_ATTENTION, - ]): + ] + ): x = torch.nn.functional.scaled_dot_product_attention( q, k, @@ -1780,14 +1808,17 @@ def forward( if self.h != self.h_k: B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) else: - reshape_q = (q.contiguous().view(n_batch * self.h, -1, - self.d_k).transpose(0, 1) - ) # (t1,nh,dk) - B = torch.matmul(reshape_q, - pos_k.transpose(-2, - -1)) # pos_k: (t1,dk,t2) - B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), - pos_k.size(1)) + reshape_q = ( + q.contiguous() + .view(n_batch * self.h, -1, self.d_k) + .transpose(0, 1) + ) # (t1,nh,dk) + B = torch.matmul( + reshape_q, pos_k.transpose(-2, -1) + ) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view( + n_batch, self.h, pos_k.size(0), pos_k.size(1) + ) scores = A + B else: scores = A @@ -1800,20 +1831,24 @@ def forward( self.attn = attn p_attn = self.dropout(attn) - x = torch.matmul(p_attn.to(v.dtype), - v) # (batch, head, time1, d_k) + x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) if pos_v is not None: - reshape_attn = (p_attn.contiguous().view( - n_batch * self.h, pos_v.size(0), - pos_v.size(1)).transpose(0, 1)) # (t1, bh, t2) - - attn_v = (torch.matmul(reshape_attn, pos_v).transpose( - 0, 1).contiguous().view(n_batch, self.h, pos_v.size(0), - self.d_k)) + reshape_attn = ( + 
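MultiHeadedAttention above switches between PyTorch's fused scaled_dot_product_attention and a manual masked-softmax path. A minimal sketch of the fused call with a boolean padding mask (shapes and the mask are illustrative; backend selection via sdpa_kernel is left to the defaults here):

import torch
import torch.nn.functional as F

batch, heads, t, d_k = 2, 4, 6, 16
q = torch.randn(batch, heads, t, d_k)
k = torch.randn(batch, heads, t, d_k)
v = torch.randn(batch, heads, t, d_k)

# True = attend, False = masked out; broadcast over heads and query positions.
valid = torch.tensor([[1, 1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1, 1]], dtype=torch.bool)
attn_mask = valid[:, None, None, :]            # (batch, 1, 1, time2)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)   # torch.Size([2, 4, 6, 16])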
p_attn.contiguous() + .view(n_batch * self.h, pos_v.size(0), pos_v.size(1)) + .transpose(0, 1) + ) # (t1, bh, t2) + + attn_v = ( + torch.matmul(reshape_attn, pos_v) + .transpose(0, 1) + .contiguous() + .view(n_batch, self.h, pos_v.size(0), self.d_k) + ) x = x + attn_v - x = (x.transpose(1, 2).contiguous().view(n_batch, -1, - self.h_k * self.d_k) - ) # (batch, time1, d_model) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k) + ) # (batch, time1, d_model) return self.linear_out(x) # (batch, time1, d_model) @@ -1830,7 +1865,7 @@ def forward(self, *args) -> tuple: def get_offset(input_layer: str, time_reduction: int) -> int: - """Get an offset. We will use the offset for determining #frames of a + """Get an offset. We will use the offset for determining #frames of a subsampled feature. Args: @@ -1841,7 +1876,7 @@ def get_offset(input_layer: str, time_reduction: int) -> int: """ if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: return 3 - if input_layer in ("conv2d", ) and time_reduction == 6: + if input_layer in ("conv2d",) and time_reduction == 6: return 1 if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: return 7 @@ -1850,8 +1885,8 @@ def get_offset(input_layer: str, time_reduction: int) -> int: def unfold_tensor(xs_pad: Tensor, max_seq_len: int) -> Tensor: """ - For a given tensor with shape of (N, T, D), if sequence length T is - longer than max_seq_len, this function unfold it to a + For a given tensor with shape of (N, T, D), if sequence length T is + longer than max_seq_len, this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len. Args: xs_pad: input tensor with shape (N, T, D) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 7308fef092b5..fee52edfe26c 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
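unfold_tensor above splits long (N, T, D) sequences into chunks of max_seq_len along the time axis, giving (N * T', max_seq_len, D). A simplified sketch of the reshape, assuming T is an exact multiple of max_seq_len (the helper in the file is the authoritative version; this is only the shape bookkeeping):

import torch

def unfold_tensor_simple(xs_pad: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    n, t, d = xs_pad.shape
    assert t % max_seq_len == 0, "sketch assumes T is a multiple of max_seq_len"
    # (N, T, D) -> (N, T', max_seq_len, D) -> (N * T', max_seq_len, D)
    return xs_pad.reshape(n, t // max_seq_len, max_seq_len, d).reshape(-1, max_seq_len, d)

x = torch.randn(2, 12, 8)
print(unfold_tensor_simple(x, max_seq_len=4).shape)   # torch.Size([6, 4, 8])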
"""Inference-only PhiMoE model.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -36,26 +37,36 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import (QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class PhiMoEConfig(PretrainedConfig): - model_type = "phimoe" keys_to_ignore_at_inference = ["past_key_values"] @@ -128,7 +139,6 @@ def __init__( class mp(torch.autograd.Function): - @staticmethod def forward( ctx, @@ -173,8 +183,9 @@ def sparsemixer(scores, jitter_eps=0.01): # compute mask for sparsity mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) factor = scores.abs().clamp(min=mask_logits_threshold) - mask_logits_threshold = ((mask_logits_threshold - scores) / - factor) > (2 * jitter_eps) + mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > ( + 2 * jitter_eps + ) # apply mask masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf")) @@ -195,24 +206,21 @@ def sparsemixer(scores, jitter_eps=0.01): ) with torch.no_grad(): # compute mask for sparsity - mask_logits_threshold, max_ind = masked_scores.max(dim=-1, - keepdim=True) + mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True) factor = scores.abs().clamp(min=mask_logits_threshold) - mask_logits_threshold = ((mask_logits_threshold - scores) / - factor) > (2 * jitter_eps) + mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > ( + 2 * jitter_eps + ) # apply mask - masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, - float("-inf")) + masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf")) selected_experts_top2 = max_ind # compute scores for gradients masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) - multiplier_top2 = masked_gates_top2.gather(dim=-1, - index=selected_experts_top2) + multiplier_top2 = masked_gates_top2.gather(dim=-1, index=selected_experts_top2) multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) - selected_experts = torch.concat((selected_experts, selected_experts_top2), - dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1) return ( multiplier, @@ -226,8 +234,7 @@ def phimoe_routing_function( topk: int, renormalize: bool, 
): - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" assert topk == 2, "Only top-2 routing is supported" assert renormalize is False, "Renormalization is not supported" @@ -278,7 +285,8 @@ def __init__( quant_config=quant_config, tp_size=tp_size, custom_routing_function=phimoe_routing_function, - prefix=f"{prefix}.experts") + prefix=f"{prefix}.experts", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -291,7 +299,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class PhiMoEAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -376,7 +383,6 @@ def forward( class PhiMoEDecoderLayer(nn.Module): - def __init__( self, config: PhiMoEConfig, @@ -393,8 +399,9 @@ def __init__( num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, - head_dim=getattr(config, "head_dim", - self.hidden_size // config.num_attention_heads), + head_dim=getattr( + config, "head_dim", self.hidden_size // config.num_attention_heads + ), rope_theta=rope_theta, cache_config=cache_config, quant_config=quant_config, @@ -409,12 +416,12 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.block_sparse_moe", ) - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps, - elementwise_affine=True) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps, - elementwise_affine=True) + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) def forward( self, @@ -444,7 +451,6 @@ def forward( @support_torch_compile class PhiMoEModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -453,8 +459,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size self.config = config @@ -468,15 +477,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: PhiMoEDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") - self.norm = nn.LayerNorm(config.hidden_size, - eps=config.rms_norm_eps, - elementwise_affine=True) + config, cache_config, quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers", + ) + self.norm = nn.LayerNorm( + config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True + ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -507,10 +518,9 
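The sparsemixer gate above first masks out experts whose gap to the per-token maximum, normalized by the clamped absolute score, exceeds 2 * jitter_eps, and only routes among the survivors. The masking step in isolation (the router scores and jitter_eps value are illustrative):

import torch

torch.manual_seed(0)
jitter_eps = 0.01
scores = torch.randn(4, 8)          # (tokens, experts) router logits

max_score, max_ind = scores.max(dim=-1, keepdim=True)
factor = scores.abs().clamp(min=max_score)
too_far = ((max_score - scores) / factor) > (2 * jitter_eps)

masked = scores.masked_fill(too_far, float("-inf"))
print("experts kept per token:", (~too_far).sum(dim=-1).tolist())
print("top-1 expert:", max_ind.squeeze(-1).tolist())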
@@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states = self.norm(hidden_states) return hidden_states @@ -523,8 +533,7 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: num_experts=self.config.num_local_experts, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -536,14 +545,15 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -594,8 +604,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -627,8 +638,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = vllm_config.quant_config - self.model = PhiMoEModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = PhiMoEModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -640,16 +652,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size), + if not lora_config + else lora_config.lora_vocab_padding_size + ), quant_config=None, bias=True, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -661,16 +677,16 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, 
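load_weights above folds separate q_proj / k_proj / v_proj checkpoint tensors into the fused qkv_proj parameter by rewriting names from a (param_name, shard_name, shard_id) table. The name-rewriting step in isolation (only the qkv entries are shown; the weight name below is a hypothetical example):

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def remap(name: str):
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None

print(remap("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')
print(remap("model.layers.0.mlp.w2.weight"))
# ('model.layers.0.mlp.w2.weight', None)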
inputs_embeds + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 1c6e3a31d985..65abebcf37de 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -10,17 +10,20 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk, - UserMessage) +from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.multimodal import ImageEncoder from PIL import Image from transformers import BatchFeature, PixtralVisionConfig, TensorType from transformers.image_utils import ImageInput from transformers.models.pixtral.image_processing_pixtral import ( - _num_image_tokens as _get_pixtral_hf_num_image_tokens) + _num_image_tokens as _get_pixtral_hf_num_image_tokens, +) from transformers.models.pixtral.modeling_pixtral import ( - PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) + PixtralRotaryEmbedding, + apply_rotary_pos_emb, + position_ids_in_meshgrid, +) from transformers.tokenization_utils_base import TextInput from vllm.config import VllmConfig @@ -28,37 +31,50 @@ from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalUUIDDict, NestedTensors) -from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, - MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalProcessingInfo, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalUUIDDict, + NestedTensors, +) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.tokenizer import (MistralTokenizer, - cached_tokenizer_from_config) +from vllm.transformers_utils.tokenizer import ( + MistralTokenizer, + cached_tokenizer_from_config, +) from vllm.utils.tensor_schema 
import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix -from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy, - resolve_visual_encoder_outputs) +from .utils import init_vllm_registered_model, maybe_prefix +from .vision import ( + VisionEncoderInfo, + VisionFeatureSelectStrategy, + resolve_visual_encoder_outputs, +) try: from xformers import ops as xops - if (current_platform.is_cuda() - and current_platform.has_device_capability(100)): + + if current_platform.is_cuda() and current_platform.has_device_capability(100): # Xformers FA is not compatible with B200 USE_XFORMERS_OPS = False else: @@ -76,13 +92,16 @@ class PixtralImagePixelInputs(TensorSchema): - c: Number of channels (3) - h: Height of each image - w: Width of each image - + The result of stacking `ImageEncoding.tokens` from each prompt. """ + type: Literal["pixel_values"] = "pixel_values" - images: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})] + images: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}), + ] class PixtralProcessorAdapter: @@ -150,7 +169,8 @@ def __call__( "Make sure to process your input via `mistral_common`'s " "tokenizer or pass a chat completion request. " "For more info, see: " - "https://github.com/vllm-project/vllm/issues/8411.") + "https://github.com/vllm-project/vllm/issues/8411." + ) images_processed = list[torch.Tensor]() images_tokens = list[torch.Tensor]() @@ -163,16 +183,15 @@ def __call__( images_processed.append(image_processed) images_tokens.append(image_tokens) - return BatchFeature({ - "input_ids": - torch.cat(images_tokens)[None].expand(len(text), -1), - "images": - images_processed, - }) + return BatchFeature( + { + "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1), + "images": images_processed, + } + ) class PixtralProcessingInfo(BaseProcessingInfo): - def get_tokenizer(self) -> MistralTokenizer: tokenizer = cached_tokenizer_from_config(self.ctx.model_config) if not isinstance(tokenizer, MistralTokenizer): @@ -209,7 +228,8 @@ def get_num_image_tokens( processor = self.get_hf_processor() ncols, nrows = processor.image_processor._image_to_num_tokens( - Image.new("RGB", (image_width, image_height))) + Image.new("RGB", (image_width, image_height)) + ) return ncols * nrows @@ -221,7 +241,6 @@ def get_image_size_with_most_features(self) -> ImageSize: class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -233,17 +252,17 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } def get_dummy_processor_inputs( @@ -259,23 +278,27 @@ def get_dummy_processor_inputs( dummy_images = dummy_mm_data.get("image", []) tokenization_kwargs = {"truncation": False} - request = 
ChatCompletionRequest(messages=[ - UserMessage(content=[ - TextChunk(text=dummy_text), - *(ImageChunk(image=image) for image in dummy_images), - ]), - ]) + request = ChatCompletionRequest( + messages=[ + UserMessage( + content=[ + TextChunk(text=dummy_text), + *(ImageChunk(image=image) for image in dummy_images), + ] + ), + ] + ) res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens - return ProcessorInputs(prompt=dummy_tokens, - mm_data=dummy_mm_data, - tokenization_kwargs=tokenization_kwargs) - + return ProcessorInputs( + prompt=dummy_tokens, + mm_data=dummy_mm_data, + tokenization_kwargs=tokenization_kwargs, + ) -class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] - ): +class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]): def _get_mm_fields_config( self, hf_inputs: Mapping[str, NestedTensors], @@ -300,7 +323,8 @@ def get_replacement(item_idx: int): image_size = images.get_image_size(item_idx) ncols, nrows = processor.image_processor._image_to_num_tokens( - Image.new("RGB", (image_size.width, image_size.height))) + Image.new("RGB", (image_size.width, image_size.height)) + ) tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens[-1] = image_end_id @@ -335,11 +359,13 @@ def _cached_apply_hf_processor( return prompt_ids, mm_info, True -@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor, - info=PixtralProcessingInfo, - dummy_inputs=PixtralDummyInputsBuilder) -class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + PixtralMultiModalProcessor, + info=PixtralProcessingInfo, + dummy_inputs=PixtralDummyInputsBuilder, +) +class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -374,8 +400,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vision_encoder = VisionTransformer(self.vision_args) if self.vision_args.add_pre_mm_projector_layer_norm: - self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, - eps=1e-5) + self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5) if self.vision_args.mm_projector_id == PATCH_MERGE: self.patch_merger = PatchMerger( @@ -385,20 +410,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.vision_language_adapter = VisionLanguageAdapter( - self.vision_args, dim=config.text_config.hidden_size) + self.vision_args, dim=config.text_config.hidden_size + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[PixtralImagePixelInputs]: + self, **kwargs: object + ) -> Optional[PixtralImagePixelInputs]: images = kwargs.pop("images", None) if images is None: return None return PixtralImagePixelInputs( type="pixel_values", - images=flatten_bn(images), + images=images, ) def _process_image_input( @@ -407,23 +435,24 @@ def _process_image_input( ) -> tuple[torch.Tensor, ...]: images = image_input["images"] image_features = self.vision_encoder(images) - feature_sizes = [ - image_feature.shape[0] for image_feature in image_features - ] + feature_sizes = [image_feature.shape[0] for image_feature in image_features] image_features = torch.cat(image_features) if 
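The prompt replacement above expands each image into a row-major grid of placeholder tokens, with a break token closing every row and an end token closing the whole image. A small illustration with made-up token ids (not Pixtral's real vocabulary ids):

IMAGE_TOKEN, IMAGE_BREAK, IMAGE_END = 10, 12, 13   # illustrative ids only

def image_placeholder_tokens(ncols: int, nrows: int) -> list[int]:
    tokens = ([IMAGE_TOKEN] * ncols + [IMAGE_BREAK]) * nrows
    tokens[-1] = IMAGE_END   # the final break becomes the end-of-image marker
    return tokens

print(image_placeholder_tokens(ncols=3, nrows=2))
# [10, 10, 10, 12, 10, 10, 10, 13]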
self.vision_args.add_pre_mm_projector_layer_norm: image_features = self.pre_mm_projector_norm(image_features) if self.vision_args.mm_projector_id == PATCH_MERGE: patch_size = self.vision_args.patch_size spatial_merge_size_square = self.vision_args.spatial_merge_size**2 - img_patch_dims = [(img.shape[1] // patch_size, - img.shape[2] // patch_size) for img in images] + img_patch_dims = [ + (img.shape[1] // patch_size, img.shape[2] // patch_size) + for img in images + ] feature_sizes = [ feature_size // spatial_merge_size_square for feature_size in feature_sizes ] - image_features = self.patch_merger(image_features, - image_sizes=img_patch_dims) + image_features = self.patch_merger( + image_features, image_sizes=img_patch_dims + ) image_embeds = self.vision_language_adapter(image_features) image_embeds = torch.split(image_embeds, feature_sizes) return image_embeds @@ -431,8 +460,7 @@ def _process_image_input( def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -451,10 +479,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states @@ -465,7 +492,6 @@ def compute_logits( return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith("vision_encoder") @@ -480,38 +506,42 @@ def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): # Get references to parameters for direct loading vision_encoder_dict = dict(self.vision_encoder.named_parameters()) - patch_merger_dict = dict(self.patch_merger.named_parameters( - )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict() - pre_mm_projector_norm_dict = dict( - self.pre_mm_projector_norm.named_parameters( - )) if self.vision_args.add_pre_mm_projector_layer_norm else dict() - vision_lang_adapter_dict = dict( - self.vision_language_adapter.named_parameters()) + patch_merger_dict = ( + dict(self.patch_merger.named_parameters()) + if self.vision_args.mm_projector_id == PATCH_MERGE + else dict() + ) + pre_mm_projector_norm_dict = ( + dict(self.pre_mm_projector_norm.named_parameters()) + if self.vision_args.add_pre_mm_projector_layer_norm + else dict() + ) + vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters()) def llm_weights_generator(): # Single pass over weights for name, w in weights: if is_vision_encoder_weights((name, w)): # Load vision encoder weights directly - trimmed_name = '.'.join(name.split(".")[1:]) + trimmed_name = ".".join(name.split(".")[1:]) param = vision_encoder_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_patch_merger((name, w)): # Load vision patch merger weights directly - trimmed_name = '.'.join(name.split(".")[1:]) + trimmed_name = ".".join(name.split(".")[1:]) param = patch_merger_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_pre_mm_projector_norm((name, w)): # Load vision pre_mm_projector_norm weights 
directly - trimmed_name = '.'.join(name.split(".")[1:]) + trimmed_name = ".".join(name.split(".")[1:]) param = pre_mm_projector_norm_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) elif is_vision_lang_adapter_weights((name, w)): # Load vision-language adapter weights directly - trimmed_name = '.'.join(name.split(".")[1:]) + trimmed_name = ".".join(name.split(".")[1:]) param = vision_lang_adapter_dict[trimmed_name] with torch.no_grad(): default_weight_loader(param, w) @@ -542,8 +572,7 @@ class VisionEncoderArgs: mm_projector_id: str = "" -def _reshape_for_broadcast(freqs_cis: torch.Tensor, - x: torch.Tensor) -> torch.Tensor: +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """ freqs_cis: complex - (seq_len, head_dim / 2) x: complex - (bsz, seq_len, head_dim / 2) @@ -554,9 +583,7 @@ def _reshape_for_broadcast(freqs_cis: torch.Tensor, freqs_cis.shape, (x.shape[1], x.shape[-1]), ) - shape = [ - d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape) - ] + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) @@ -571,7 +598,7 @@ def precompute_freqs_cis_2d( to be indexed by (height, width) position tuples """ # (dim / 2) frequency bases - freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim)) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) h = torch.arange(height, device=freqs.device) w = torch.arange(width, device=freqs.device) @@ -603,26 +630,18 @@ def apply_rotary_emb_vit( class FeedForward(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() assert args.intermediate_size is not None - self.w1 = nn.Linear(args.hidden_size, - args.intermediate_size, - bias=False) - self.w2 = nn.Linear(args.intermediate_size, - args.hidden_size, - bias=False) - self.w3 = nn.Linear(args.hidden_size, - args.intermediate_size, - bias=False) + self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) + self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Attention(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args @@ -656,10 +675,7 @@ def forward( q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) - out = nn.functional.scaled_dot_product_attention(q, - k, - v, - attn_mask=mask) + out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) out = out.transpose(1, 2) out = out.reshape(batch, patches, self.n_heads * self.head_dim) @@ -667,7 +683,6 @@ def forward( class TransformerBlock(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() self.attention = Attention(args) @@ -681,9 +696,9 @@ def forward( mask: torch.Tensor, freqs_cis: torch.Tensor, ) -> torch.Tensor: - r = self.attention.forward(self.attention_norm(x), - mask=mask, - freqs_cis=freqs_cis) + r = self.attention.forward( + self.attention_norm(x), mask=mask, freqs_cis=freqs_cis + ) h = x + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r @@ -691,7 +706,6 @@ def forward( class Transformer(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() self.layers = torch.nn.ModuleList() @@ -709,22 +723,26 @@ def forward( return x -def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor: - positions = torch.cat([ - 
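precompute_freqs_cis_2d above starts from the standard rotary inverse-frequency bases 1 / theta^(2i/d) and then indexes them by (height, width) positions. A minimal 1-D version of the same frequency and angle construction (the 2-D interleaving of height and width frequencies is specific to the function above and is not reproduced here; dim, theta, and seq_len are illustrative):

import torch

dim, theta, seq_len = 16, 10000.0, 8

# One frequency per pair of feature dimensions: 1 / theta^(2i / dim).
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))   # (dim / 2,)

# Angle for every (position, frequency) pair, stored as complex rotations.
positions = torch.arange(seq_len).float()
angles = torch.outer(positions, freqs)                              # (seq_len, dim / 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)            # e^{i * angle}

print(freqs_cis.shape)   # torch.Size([8, 8])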
torch.stack( - torch.meshgrid( - torch.arange(p.shape[-2]), - torch.arange(p.shape[-1]), - indexing="ij", - ), - dim=-1, - ).reshape(-1, 2) for p in patch_embeds_list - ]) +def position_meshgrid( + patch_embeds_list: list[torch.Tensor], +) -> torch.Tensor: + positions = torch.cat( + [ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + for p in patch_embeds_list + ] + ) return positions class VisionTransformer(nn.Module): - def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args @@ -786,9 +804,7 @@ def forward( self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images ] - patch_embeds = [ - p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list - ] + patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] embed_sizes = [p.shape[1] for p in patch_embeds] # flatten to a single sequence @@ -802,13 +818,16 @@ def forward( # pass through Transformer with a block diagonal mask delimiting images if USE_XFORMERS_OPS: mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + ) else: from transformers.models.pixtral.modeling_pixtral import ( - generate_block_attention_mask) + generate_block_attention_mask, + ) + mask = generate_block_attention_mask( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], - patch_embeds) + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) # squeeze dim 0 and split into separate tensors for each image @@ -816,7 +835,6 @@ def forward( class VisionLanguageAdapter(nn.Module): - def __init__(self, args: VisionEncoderArgs, dim: int): super().__init__() assert isinstance(args, VisionEncoderArgs) @@ -856,8 +874,9 @@ def __init__( bias=use_mlp_bias, ) - def forward(self, x: torch.Tensor, - image_sizes: list[tuple[int, int]]) -> torch.Tensor: + def forward( + self, x: torch.Tensor, image_sizes: list[tuple[int, int]] + ) -> torch.Tensor: # image_sizes specified in tokens assert sum([h * w for h, w in image_sizes]) == len(x) @@ -889,15 +908,14 @@ def permute( """ sub_grids = get_sub_grids( - x=x, - image_sizes=image_sizes, - spatial_merge_size=self.spatial_merge_size + x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size ) # list of [d x sub_grid_size x sub_grid_size x n_patches] permuted_tensor: list[torch.Tensor] = [] for grid in sub_grids: n_patches = grid.shape[-1] - permuted_tensor.append(grid.view(-1, n_patches).t( - )) # n_patches x d * sub_grid_size * sub_grid_size + permuted_tensor.append( + grid.view(-1, n_patches).t() + ) # n_patches x d * sub_grid_size * sub_grid_size return torch.cat( permuted_tensor, dim=0 ) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2) @@ -917,14 +935,15 @@ def get_sub_grids( for image_index, image_tokens in enumerate(x.split(tokens_per_image)): # Reshape image_tokens into a 2D grid h, w = image_sizes[image_index] - image_grid = image_tokens.view(h, w, d).permute( - 2, 0, 1)[None, :, :, :] # 1 x d x h x w - sub_grids = torch.nn.functional.unfold(image_grid, - kernel_size=sub_grid_size, - stride=sub_grid_size) + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1)[ + None, :, :, : + ] # 1 x d x h x w + sub_grids = torch.nn.functional.unfold( + image_grid, kernel_size=sub_grid_size, stride=sub_grid_size + ) sub_grids = sub_grids.view( - 1, d, sub_grid_size, 
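Because the per-image patch sequences are concatenated into one long sequence, the transformer above needs a block-diagonal mask so patches only attend within their own image; that is what BlockDiagonalMask.from_seqlens / generate_block_attention_mask produce. A plain boolean version of the same mask (the sequence lengths are illustrative):

import torch

seq_lens = [4, 2, 3]    # patches per image after flattening

# True = may attend; each image only attends to its own patches.
mask = torch.block_diag(*[torch.ones(n, n) for n in seq_lens]).bool()
print(mask.int())       # (sum(seq_lens), sum(seq_lens)) block-diagonal matrix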
sub_grid_size, - -1) # 1 x d x sub_grid_size x sub_grid_size x n_patches + 1, d, sub_grid_size, sub_grid_size, -1 + ) # 1 x d x sub_grid_size x sub_grid_size x n_patches all_img_sub_grids.append(sub_grids[0]) @@ -940,7 +959,6 @@ def get_sub_grids( class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): - def get_num_image_tokens( self, *, @@ -993,7 +1011,6 @@ def get_patch_grid_size( class PixtralHFMLP(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1009,12 +1026,15 @@ def __init__( output_sizes=[config.intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(input_size=config.intermediate_size, - output_size=config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) self.act_and_mul = get_act_and_mul_fn(config.hidden_act) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -1025,7 +1045,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PixtralHFAttention(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1081,14 +1100,12 @@ def forward( # Transpose q and k back for attention q = q.transpose(1, 2).contiguous() k = k.transpose(1, 2).contiguous() - out = xops.memory_efficient_attention(q, - k, - v, - attn_bias=attention_mask) + out = xops.memory_efficient_attention(q, k, v, attn_bias=attention_mask) else: v = v.transpose(1, 2) out = nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask) + q, k, v, attn_mask=attention_mask + ) out = out.transpose(1, 2) out = out.view(batch, patches, self.n_heads * self.head_dim) @@ -1098,7 +1115,6 @@ def forward( class PixtralHFTransformerBlock(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1109,12 +1125,12 @@ def __init__( super().__init__() self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) - self.attention = PixtralHFAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.attention") - self.feed_forward = PixtralHFMLP(config, - quant_config=quant_config, - prefix=f"{prefix}.feed_forward") + self.attention = PixtralHFAttention( + config, quant_config=quant_config, prefix=f"{prefix}.attention" + ) + self.feed_forward = PixtralHFMLP( + config, quant_config=quant_config, prefix=f"{prefix}.feed_forward" + ) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) def forward( @@ -1123,9 +1139,11 @@ def forward( attention_mask: torch.Tensor, position_embeddings: torch.Tensor, ) -> torch.Tensor: - r, _ = self.attention.forward(self.attention_norm(hidden_states), - attention_mask=attention_mask, - position_embeddings=position_embeddings) + r, _ = self.attention.forward( + self.attention_norm(hidden_states), + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) h = hidden_states + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r @@ -1133,7 +1151,6 @@ def forward( class PixtralHFTransformer(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1149,12 +1166,16 @@ def __init__( else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - PixtralHFTransformerBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + 
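PatchMerger's get_sub_grids above uses F.unfold with kernel_size == stride == spatial_merge_size to cut the (d, h, w) token grid into non-overlapping sub-grids, which are then flattened and projected. The shape bookkeeping in isolation (the dimensions are illustrative):

import torch
import torch.nn.functional as F

d, h, w, merge = 8, 4, 6, 2          # feature dim, grid height/width, spatial merge size

grid = torch.randn(1, d, h, w)       # one image's tokens laid out as a 2-D grid

# Non-overlapping merge x merge windows: (1, d * merge * merge, n_windows)
sub = F.unfold(grid, kernel_size=merge, stride=merge)
n_windows = (h // merge) * (w // merge)
print(sub.shape)                      # torch.Size([1, 32, 6])

# Each window becomes one merged token of size d * merge * merge.
merged_tokens = sub[0].t()            # (n_windows, d * merge * merge)
assert merged_tokens.shape == (n_windows, d * merge * merge)
print(merged_tokens.shape)            # torch.Size([6, 32])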
PixtralHFTransformerBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward( self, @@ -1177,7 +1198,6 @@ def forward( class PixtralHFVisionModel(nn.Module): - def __init__( self, config: PixtralVisionConfig, @@ -1211,7 +1231,8 @@ def __init__( raise ValueError( f"The original encoder only has {num_hidden_layers} " f"layers, but you requested {len(self.transformer.layers)} " - "layers.") + "layers." + ) if require_post_norm is True: msg = "PixtralHFVisionModel does not have post-layernorm" @@ -1219,8 +1240,7 @@ def __init__( self.dtype = next(self.parameters()).dtype self.device = next(self.parameters()).device - self.patch_positional_embedding = PixtralRotaryEmbedding( - config, self.device) + self.patch_positional_embedding = PixtralRotaryEmbedding(config, self.device) def forward( self, @@ -1245,13 +1265,10 @@ def forward( """ # pass images through initial convolution independently patch_embeds_list = [ - self.patch_conv(img.unsqueeze(0).to(self.dtype)) - for img in pixel_values + self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values ] - patch_embeds = [ - p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list - ] + patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list] embed_sizes = [p.shape[1] for p in patch_embeds] # flatten to a single sequence @@ -1261,20 +1278,22 @@ def forward( # positional embeddings position_ids = position_ids_in_meshgrid( patch_embeds_list, - max_width=self.config.image_size // self.config.patch_size).to( - self.device) - position_embedding = self.patch_positional_embedding( - patch_embeds, position_ids) + max_width=self.config.image_size // self.config.patch_size, + ).to(self.device) + position_embedding = self.patch_positional_embedding(patch_embeds, position_ids) if USE_XFORMERS_OPS: attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + ) else: from transformers.models.pixtral.modeling_pixtral import ( - generate_block_attention_mask) + generate_block_attention_mask, + ) + attention_mask = generate_block_attention_mask( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], - patch_embeds) + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds + ) out = self.transformer( patch_embeds, @@ -1296,8 +1315,7 @@ def forward( # (TODO) Add prefix argument for filtering out weights to be loaded # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -1317,7 +1335,7 @@ def load_weights(self, weights: Iterable[tuple[str, if layer_idx >= layer_count: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -1327,8 +1345,7 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return 
loaded_params diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 8234d40e94ab..278957e7cf6c 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only PLaMo2 model.""" + from collections.abc import Iterable from itertools import islice from typing import TYPE_CHECKING, Optional @@ -22,31 +23,45 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) -from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( - selective_state_update) + causal_conv1d_fn, + causal_conv1d_update, +) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update from vllm.model_executor.layers.mamba.ops.ssd_combined import ( - mamba_chunk_scan_combined_varlen) + mamba_chunk_scan_combined_varlen, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - composed_weight_loader, default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, - SupportsPP) + composed_weight_loader, + default_weight_loader, + sharded_weight_loader, +) +from vllm.model_executor.models.interfaces import HasInnerState, IsHybrid, SupportsPP from vllm.model_executor.models.utils import ( - is_pp_missing_parameter, make_empty_intermediate_tensors_factory, - make_layers, maybe_prefix) + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm.utils import direct_register_custom_op @@ -89,12 +104,7 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: # transformers.models.mamba.modeling_mamba.MambaMixer @CustomOp.register(name="plamo2_mamba_mixer") class Plamo2MambaMixer(MambaBase, CustomOp): - - def __init__(self, - vllm_config: VllmConfig, - *, - prefix: str = "", - **kwargs) -> None: + def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None: super().__init__() self.config = vllm_config.model_config.hf_config self.cache_config = vllm_config.cache_config @@ -103,8 +113,9 @@ def __init__(self, self.hidden_size = self.config.hidden_size self.ssm_state_size = self.config.mamba_d_state self.conv_kernel_size = 
self.config.mamba_d_conv - self.intermediate_size = (self.config.mamba_num_heads * - self.config.hidden_size_per_head) + self.intermediate_size = ( + self.config.mamba_num_heads * self.config.hidden_size_per_head + ) self.tp_size = get_tensor_model_parallel_world_size() self.head_dim = self.config.hidden_size_per_head self.num_heads = self.config.mamba_num_heads @@ -155,17 +166,17 @@ def __init__(self, torch.empty( divide(self.num_heads, self.tp_size), dtype=torch.float32, - )) + ) + ) self.D = nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size))) - self.dt_bias = nn.Parameter( - torch.ones(divide(self.num_heads, self.tp_size))) + self.dt_bias = nn.Parameter(torch.ones(divide(self.num_heads, self.tp_size))) set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) a_weight_loader = composed_weight_loader( - sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + sharded_weight_loader(0), lambda x: -torch.exp(x.float()) + ) set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) - set_weight_attrs(self.dt_bias, - {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) self.out_proj = RowParallelLinear( self.intermediate_size, @@ -179,12 +190,9 @@ def __init__(self, # The activation function is fixed to SiLU. self.activation = "silu" - self.dt_norm = RMSNorm(self.time_step_rank, - eps=self.config.rms_norm_eps) - self.B_norm = RMSNorm(self.ssm_state_size, - eps=self.config.rms_norm_eps) - self.C_norm = RMSNorm(self.ssm_state_size, - eps=self.config.rms_norm_eps) + self.dt_norm = RMSNorm(self.time_step_rank, eps=self.config.rms_norm_eps) + self.B_norm = RMSNorm(self.ssm_state_size, eps=self.config.rms_norm_eps) + self.C_norm = RMSNorm(self.ssm_state_size, eps=self.config.rms_norm_eps) self.chunk_size = self.config.mamba_chunk_size @@ -239,7 +247,6 @@ def forward_cuda( output: torch.Tensor, **kwargs, ): - forward_context = get_forward_context() # attn_metadata contains metadata necessary for the mamba2 triton # kernels to operate in continuous batching and in chunked prefill @@ -269,13 +276,15 @@ def forward_cuda( gate, hidden_states = projected_states.chunk(2, dim=-1) # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) if attn_metadata is None: # profile run - hidden_states = (hidden_states.transpose(0, 1).clone().transpose( - 0, 1)).contiguous() + hidden_states = ( + hidden_states.transpose(0, 1).clone().transpose(0, 1) + ).contiguous() output[:] = self.out_proj(hidden_states) return @@ -294,9 +303,9 @@ def forward_cuda( [num_decodes, num_prefill_tokens], dim=0, ) - gate_d, gate_p = torch.split(gate[:num_actual_tokens], - [num_decodes, num_prefill_tokens], - dim=0) + gate_d, gate_p = torch.split( + gate[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0 + ) # Split along batch dimension state_indices_tensor_d, state_indices_tensor_p = torch.split( state_indices_tensor, @@ -309,7 +318,7 @@ def forward_cuda( preallocated_ssm_out = torch.empty( [ num_prefill_tokens + num_decodes, - (self.num_heads // self.tp_size) * self.head_dim + (self.num_heads // self.tp_size) * self.head_dim, ], dtype=hidden_states.dtype, device=hidden_states.device, @@ -325,8 +334,7 @@ def forward_cuda( # 2. 
Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" - x = hidden_states_p.transpose( - 0, 1) # this is the form that causal-conv see + x = hidden_states_p.transpose(0, 1) # this is the form that causal-conv see hidden_states_p = causal_conv1d_fn( x, conv_weights, @@ -336,7 +344,8 @@ def forward_cuda( has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, metadata=attn_metadata, - query_start_loc=query_start_loc_p) + query_start_loc=query_start_loc_p, + ) hidden_states_p = hidden_states_p.transpose(0, 1) hidden_states_p = hidden_states_p[:num_prefill_tokens] # In some instances, the following `bcdt_proj` op @@ -352,20 +361,23 @@ def forward_cuda( # making a copy of the states initial_states = torch.where( has_initial_states_p[:, None, None, None], - ssm_state[state_indices_tensor_p], 0) + ssm_state[state_indices_tensor_p], + 0, + ) varlen_state = mamba_chunk_scan_combined_varlen( - hidden_states_p.view(num_prefill_tokens, - self.num_heads // self.tp_size, - self.head_dim), + hidden_states_p.view( + num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim + ), dt, self.A, B.view(num_prefill_tokens, 1, -1), C.view(num_prefill_tokens, 1, -1), chunk_size=chunk_size, D=self.D, - z=gate_p.view(num_prefill_tokens, - self.num_heads // self.tp_size, self.head_dim), + z=gate_p.view( + num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim + ), dt_bias=self.dt_bias, seq_idx=seq_idx_p, cu_seqlens=query_start_loc_p, @@ -374,8 +386,7 @@ def forward_cuda( initial_states=initial_states, dt_softplus=True, dt_limit=(0.0, float("inf")), - out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, - self.head_dim), + out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim), state_dtype=ssm_state.dtype, ) @@ -392,21 +403,23 @@ def forward_cuda( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor_d) + conv_state_indices=state_indices_tensor_d, + ) B, C, dt = self._project_ssm_parameters(hidden_states_d) # 3. State Space Model sequence transformation - A = self.A[:, None, ...][:, :, - None].expand(-1, self.head_dim, - self.config.mamba_d_state) + A = self.A[:, None, ...][:, :, None].expand( + -1, self.head_dim, self.config.mamba_d_state + ) dt = dt[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) D = self.D[:, None, ...].expand(-1, self.head_dim) B = B.unsqueeze(1) C = C.unsqueeze(1) hidden_states_d = hidden_states_d.view( - -1, self.num_heads // self.tp_size, self.head_dim) + -1, self.num_heads // self.tp_size, self.head_dim + ) # - the hidden is reshaped into (bs, num_heads, head_dim) # - ssm_state's slots will be selected @@ -425,8 +438,7 @@ def forward_cuda( dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices_tensor_d, - out=preallocated_ssm_out_d.view(num_decodes, -1, - self.head_dim), + out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) # 4. 
Final linear projection @@ -457,8 +469,8 @@ def mamba_type(self) -> str: return "mamba2" def get_attn_backend(self) -> type["AttentionBackend"]: - from vllm.v1.attention.backends.mamba2_attn import ( - Mamba2AttentionBackend) + from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend + return Mamba2AttentionBackend @@ -489,7 +501,6 @@ def plamo2_mamba_mixer_fake( class DenseMLP(nn.Module): - def __init__( self, config: Plamo2Config, @@ -508,12 +519,14 @@ def __init__( return_bias=False, ) self.act = SiluAndMul() - self.down_proj = RowParallelLinear(self.intermediate_size, - self.hidden_size, - bias=False, - prefix=f"{prefix}.down_proj", - quant_config=quant_config, - return_bias=False) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + prefix=f"{prefix}.down_proj", + quant_config=quant_config, + return_bias=False, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: h = self.gate_up_proj(hidden_states) @@ -522,12 +535,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Plamo2AttentionMixer(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - **kwargs) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> None: super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -560,20 +568,22 @@ def __init__(self, bias=False, quant_config=quant_config, ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) - self.rope_theta = config.rope_theta if hasattr(config, - "rope_theta") else 10000 - self.rope_scaling = config.rope_scaling if hasattr( - config, "rope_scaling") else None + self.rope_theta = config.rope_theta if hasattr(config, "rope_theta") else 10000 + self.rope_scaling = ( + config.rope_scaling if hasattr(config, "rope_scaling") else None + ) max_position = config.max_position_embeddings if hasattr(vllm_config.model_config, "max_model_len") and isinstance( - vllm_config.model_config.max_model_len, int): - max_position = min(max_position, - vllm_config.model_config.max_model_len) + vllm_config.model_config.max_model_len, int + ): + max_position = min(max_position, vllm_config.model_config.max_model_len) self.rotary_emb = get_rope( self.head_dim, @@ -582,22 +592,24 @@ def __init__(self, base=self.rope_theta, rope_scaling=self.rope_scaling, ) - self.q_norm = RMSNorm(config.hidden_size_per_head, - eps=config.rms_norm_eps) + self.q_norm = RMSNorm(config.hidden_size_per_head, eps=config.rms_norm_eps) self.q_norm.weight = torch.nn.Parameter( - torch.ones((self.num_heads, config.hidden_size_per_head))) - set_weight_attrs(self.q_norm.weight, - {"weight_loader": sharded_weight_loader(0)}) - self.k_norm = RMSNorm(config.hidden_size_per_head, - eps=config.rms_norm_eps) + torch.ones((self.num_heads, config.hidden_size_per_head)) + ) + set_weight_attrs( + self.q_norm.weight, {"weight_loader": sharded_weight_loader(0)} + ) + self.k_norm = RMSNorm(config.hidden_size_per_head, eps=config.rms_norm_eps) self.k_norm.weight = torch.nn.Parameter( - torch.ones((self.num_kv_heads, config.hidden_size_per_head))) + torch.ones((self.num_kv_heads, config.hidden_size_per_head)) + ) # Tensor-parallelism shards the K norm weights to the tp ranks # in a head-wise manner. 
This approach does not work if there is only # a single KV head, as is the case for PLaMo 2-1B. if self.total_num_kv_heads != 1: - set_weight_attrs(self.k_norm.weight, - {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs( + self.k_norm.weight, {"weight_loader": sharded_weight_loader(0)} + ) self.attn = Attention( self.num_heads, @@ -631,35 +643,30 @@ def forward( class Plamo2DecoderLayer(nn.Module): - - def __init__(self, - vllm_config: VllmConfig, - layer_idx: int, - prefix: str = "", - **kwargs) -> None: + def __init__( + self, vllm_config: VllmConfig, layer_idx: int, prefix: str = "", **kwargs + ) -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.is_mamba = is_mamba(config, layer_idx) if self.is_mamba: - self.mixer = Plamo2MambaMixer(vllm_config=vllm_config, - prefix=f"{prefix}.mixer") + self.mixer = Plamo2MambaMixer( + vllm_config=vllm_config, prefix=f"{prefix}.mixer" + ) else: - self.mixer = Plamo2AttentionMixer(vllm_config=vllm_config, - prefix=f"{prefix}.mixer") - - self.mlp = DenseMLP(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.pre_mixer_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mixer_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_mlp_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_mlp_norm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.mixer = Plamo2AttentionMixer( + vllm_config=vllm_config, prefix=f"{prefix}.mixer" + ) + + self.mlp = DenseMLP( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) + self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -672,8 +679,7 @@ def forward( residual = hidden_states hidden_states = self.pre_mixer_norm(hidden_states) else: - hidden_states, residual = self.pre_mixer_norm( - hidden_states, residual) + hidden_states, residual = self.pre_mixer_norm(hidden_states, residual) if self.is_mamba: # Plamo2MambaMixer writes output to this tensor @@ -700,7 +706,6 @@ def forward( class Plamo2Decoder(torch.nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config @@ -708,13 +713,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def get_layer(prefix: str): layer_idx = int(prefix.rsplit(".", 1)[1]) - return Plamo2DecoderLayer(vllm_config=vllm_config, - layer_idx=layer_idx, - prefix=prefix, - **extra_kwargs) + return Plamo2DecoderLayer( + vllm_config=vllm_config, + layer_idx=layer_idx, + prefix=prefix, + **extra_kwargs, + ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) def forward( self, @@ -733,7 +741,6 @@ def forward( @support_torch_compile class Plamo2Model(torch.nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -750,11 +757,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, prefix=f"{prefix}.embed_tokens", ) - self.make_empty_intermediate_tensors = ( - 
make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - self.layers = Plamo2Decoder(vllm_config=vllm_config, - prefix=f"{prefix}.layers") + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + self.layers = Plamo2Decoder(vllm_config=vllm_config, prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -784,10 +790,9 @@ def forward( residual=residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -816,8 +821,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # the case for PLaMo2, as indicated by the FIXME comment. self.config.head_dim = self.config.hidden_size_per_head - self.model = Plamo2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Plamo2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.vocab_size = self.config.vocab_size self.unpadded_vocab_size = self.config.vocab_size num_embeddings = ((self.vocab_size + 15) // 16) * 16 @@ -831,23 +837,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: if self.config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs): - - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states @classmethod @@ -855,7 +865,6 @@ def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -877,8 +886,7 @@ def get_mamba_state_shape_from_config( """ parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config - intermediate_size =\ - hf_config.mamba_num_heads * hf_config.hidden_size_per_head + intermediate_size = hf_config.mamba_num_heads * hf_config.hidden_size_per_head return MambaStateShapeCalculator.mamba2_state_shape( intermediate_size=intermediate_size, @@ -900,7 +908,6 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in 
weights: - # Both tie_word_embeddings=True and lm_head.weight in the safetensor # at the same time causes dict key access error. if name == "lm_head.weight" and self.config.tie_word_embeddings: @@ -932,10 +939,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # Also, in addition to the quantized weights, # the zero points and scales have to be reshaped as well. # Packing should not be affected by this. - if ".mixer.in_proj.weight" in name \ - or "mixer.in_proj.qweight" in name \ - or "mixer.in_proj.scales" in name \ - or "mixer.in_proj.qzeros" in name: + if ( + ".mixer.in_proj.weight" in name + or "mixer.in_proj.qweight" in name + or "mixer.in_proj.scales" in name + or "mixer.in_proj.qzeros" in name + ): if "mixer.in_proj.weight" in name: loaded_weight = loaded_weight.transpose(0, 1) # for weight: @@ -945,14 +954,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # for scales and qzeros: # loaded_weight.shape[0] == self.config.hidden_size // self.vllm_config.quant_config.group_size # noqa loaded_weight = loaded_weight.reshape( - loaded_weight.shape[0], self.config.mamba_num_heads, -1) - gate_weight, hidden_states_weight = loaded_weight.chunk(2, - dim=-1) + loaded_weight.shape[0], self.config.mamba_num_heads, -1 + ) + gate_weight, hidden_states_weight = loaded_weight.chunk(2, dim=-1) gate_weight = gate_weight.reshape(loaded_weight.shape[0], -1) hidden_states_weight = hidden_states_weight.reshape( - loaded_weight.shape[0], -1) - loaded_weight = torch.cat([gate_weight, hidden_states_weight], - dim=-1) + loaded_weight.shape[0], -1 + ) + loaded_weight = torch.cat([gate_weight, hidden_states_weight], dim=-1) if "mixer.in_proj.weight" in name: loaded_weight = loaded_weight.transpose(0, 1) @@ -973,6 +982,5 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e0c08a6a8827..6a12776b7f94 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -6,6 +6,7 @@ # Copyright (c) Alibaba Cloud. 
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" + import json from collections.abc import Iterable from itertools import islice @@ -21,21 +22,28 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class QWenMLP(nn.Module): @@ -51,16 +59,15 @@ def __init__( ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.c_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) + hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config + ) + self.c_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, quant_config=quant_config + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -71,7 +78,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class QWenAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -85,12 +91,10 @@ def __init__( ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = hidden_size // self.total_num_heads self.c_attn = QKVParallelLinear( hidden_size, @@ -114,12 +118,14 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -135,7 +141,6 @@ def forward( class QWenBlock(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -148,20 +153,22 @@ def __init__( rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - self.attn = QWenAttention(config.hidden_size, - config.num_attention_heads, - config.max_position_embeddings, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = QWenAttention( + config.hidden_size, + config.num_attention_heads, + config.max_position_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mlp = QWenMLP(config.hidden_size, - config.intermediate_size // 2, - quant_config=quant_config) + self.mlp = QWenMLP( + config.hidden_size, config.intermediate_size // 2, quant_config=quant_config + ) def forward( self, @@ -188,7 +195,6 @@ def forward( @support_torch_compile class QWenModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -205,13 +211,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: QWenBlock( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.h") + lambda prefix: QWenBlock(config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h", + ) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -241,16 +247,14 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, 
"residual": residual} + ) hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states class QWenBaseModel(nn.Module): - def __init__( self, *, @@ -265,18 +269,21 @@ def __init__( self.config = config self.multimodal_config = multimodal_config self.quant_config = quant_config - self.transformer = transformer_type(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "transformer")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.transformer = transformer_type( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.transformer.make_empty_intermediate_tensors) + self.transformer.make_empty_intermediate_tensors + ) def compute_logits( self, @@ -285,8 +292,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w2", 0), @@ -297,7 +303,7 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -319,8 +325,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -338,14 +343,13 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config if hasattr(config, "visual"): - hf_overrides = { - "architectures": ["QwenVLForConditionalGeneration"] - } + hf_overrides = {"architectures": ["QwenVLForConditionalGeneration"]} raise RuntimeError( "The configuration of this model indicates that it supports " "vision inputs, but you instantiated the text-only version " "of this model. 
Please use the vision model by setting " - f"`--hf-overrides '{json.dumps(hf_overrides)}'`") + f"`--hf-overrides '{json.dumps(hf_overrides)}'`" + ) super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -356,6 +360,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index c536b0f60c30..c8bc17dbfa0a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -24,6 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -39,28 +40,38 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import is_interleaved from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Qwen2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -85,8 +96,9 @@ def __init__( prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -97,7 +109,6 @@ def forward(self, x): class Qwen2Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -160,8 +171,11 @@ def __init__( rope_scaling=rope_scaling, dual_chunk_attention_config=dual_chunk_attention_config, ) - attn_cls = (EncoderOnlyAttention - if attn_type == AttentionType.ENCODER_ONLY else Attention) + attn_cls = ( + EncoderOnlyAttention + if attn_type == AttentionType.ENCODER_ONLY + else Attention + ) self.attn = attn_cls( self.num_heads, self.head_dim, @@ -174,7 +188,10 @@ def __init__( **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}) + } + if dual_chunk_attention_config + else {}, + ) def forward( self, @@ -190,7 +207,6 @@ def forward( class Qwen2DecoderLayer(nn.Module): - def __init__( self, config: Qwen2Config, @@ -203,9 +219,9 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) # By default, Qwen2 uses causal attention as it is a decoder-only model. # You can override the HF config with `is_causal=False` to enable @@ -236,10 +252,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -252,16 +268,14 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -274,14 +288,16 @@ def forward( "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class Qwen2Model(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config.get_text_config() @@ -297,14 +313,16 @@ def __init__(self, "to discuss this feature.".format( config.max_window_layers, config.num_hidden_layers, - )) + ) + ) self.config = config self.quant_config = quant_config self.vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -318,16 +336,18 @@ def __init__(self, decoder_layer_type = 
decoder_layer_type or Qwen2DecoderLayer self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: decoder_layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: decoder_layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -358,16 +378,16 @@ def forward( aux_hidden_states = [] for idx, layer in enumerate( - islice(self.layers, self.start_layer, self.end_layer)): + islice(self.layers, self.start_layer, self.end_layer) + ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) @@ -376,8 +396,7 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -391,18 +410,19 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -417,8 +437,7 @@ def load_weights(self, weights: Iterable[tuple[str, if name is None: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) else: @@ -435,8 +454,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -465,25 +483,28 @@ def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Qwen2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -502,8 +523,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -513,11 +535,9 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index b5c2aee7f231..1ab2f43c9d73 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -31,11 +31,15 @@ import torch.nn as nn from transformers.feature_extraction_utils import BatchFeature from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( - Qwen2_5OmniConfig, Qwen2_5OmniThinkerConfig) + Qwen2_5OmniConfig, + Qwen2_5OmniThinkerConfig, +) from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( - Qwen2_5OmniAudioEncoder) + Qwen2_5OmniAudioEncoder, +) from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import ( - Qwen2_5OmniProcessor) + Qwen2_5OmniProcessor, +) from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig @@ -44,33 +48,60 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( - Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs, - Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs, - Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs, - Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs) + Qwen2_5_VisionTransformer, + Qwen2_5_VLImageEmbeddingInputs, + Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLProcessingInfo, + Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, + 
Qwen2_5_VLVideoPixelInputs, +) from vllm.model_executor.models.qwen2_audio import ( - Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths) + Qwen2AudioProcessingInfo, + _get_feat_extract_output_lengths, +) from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) -from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalPromptUpdates, - PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate) +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + DictEmbeddingItems, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + MultiModalPromptUpdates, + PlaceholderFeaturesInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import encode_tokens from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) try: import flash_attn @@ -88,6 +119,7 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema): - msl: Maximum sequence length - tsl: Total sequence length """ + type: Literal["audio_features"] input_features: Annotated[ Union[torch.Tensor, list[torch.Tensor]], @@ -101,52 +133,55 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema): def create_qwen2_5_omni_thinker_field_factory( - spatial_merge_size: int -) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, - MultiModalFieldConfig]]: - - def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, - torch.Tensor]): - audio_feature_lengths = hf_inputs.get("audio_feature_lengths", - torch.empty((0, ))) + spatial_merge_size: int, +) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, MultiModalFieldConfig]]: + def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): + audio_feature_lengths = hf_inputs.get( + "audio_feature_lengths", torch.empty((0,)) + ) image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_pixel_grid_sizes = image_grid_thw.prod(-1) - image_embed_grid_sizes = (image_pixel_grid_sizes // - spatial_merge_size // spatial_merge_size) + image_embed_grid_sizes = ( + image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size + ) video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) video_grid_sizes = video_grid_thw.prod(-1) - video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // - spatial_merge_size) + video_embed_grid_sizes = ( + video_grid_sizes // spatial_merge_size // spatial_merge_size + ) num_videos = len(video_grid_sizes) return dict( 
input_audio_features=MultiModalFieldConfig.flat_from_sizes( - "audio", audio_feature_lengths, dim=1), + "audio", audio_feature_lengths, dim=1 + ), feature_attention_mask=MultiModalFieldConfig.batched("audio"), audio_feature_lengths=MultiModalFieldConfig.batched("audio"), pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_pixel_grid_sizes), + "image", image_pixel_grid_sizes + ), image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_embed_grid_sizes), + "image", image_embed_grid_sizes + ), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_embed_grid_sizes), + "video", video_embed_grid_sizes + ), video_grid_thw=MultiModalFieldConfig.batched("video"), second_per_grid_ts=MultiModalFieldConfig.batched("video"), - use_audio_in_video=MultiModalFieldConfig.shared( - "video", num_videos), + use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos), ) return _qwen2_5_omni_thinker_field_config class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser): - def __init__(self, spatial_merge_size: int, *args, **kwargs): self._spatial_merge_size = spatial_merge_size super().__init__(self._spatial_merge_size, *args, **kwargs) @@ -159,19 +194,18 @@ def _parse_audio_data( return DictEmbeddingItems( data, modality="audio", - required_fields={ - "input_audio_features", "audio_feature_lengths" - }, + required_fields={"input_audio_features", "audio_feature_lengths"}, fields_factory=create_qwen2_5_omni_thinker_field_factory( - self._spatial_merge_size), + self._spatial_merge_size + ), ) return super()._parse_audio_data(data) -class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo, - Qwen2_5_VLProcessingInfo): - +class Qwen2_5OmniThinkerProcessingInfo( + Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo +): def get_hf_config(self): return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config @@ -193,8 +227,8 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: class Qwen2_5OmniThinkerDummyInputsBuilder( - BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]): - + BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo] +): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) num_images = mm_counts.get("image", 0) @@ -206,8 +240,11 @@ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: image_token: str = hf_processor.image_token video_token: str = hf_processor.video_token - return (audio_token * num_audios + image_token * num_images + - video_token * num_videos) + return ( + audio_token * num_audios + + image_token * num_images + + video_token * num_videos + ) def get_dummy_mm_data( self, @@ -221,49 +258,55 @@ def get_dummy_mm_data( feature_extractor = self.info.get_feature_extractor() - target_audio_length = min( - feature_extractor.chunk_length, - 30, - ) * feature_extractor.sampling_rate - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_audio_length = ( + min( + feature_extractor.chunk_length, + 30, + ) + * feature_extractor.sampling_rate + ) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) image_overrides = 
mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None audio_overrides = mm_options.get("audio") if mm_options else None mm_data = { - "audio": - self._get_dummy_audios(length=target_audio_length, - num_audios=num_audios, - overrides=audio_overrides), - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "video": - self._get_dummy_videos(width=target_width, - height=target_height, - num_frames=target_num_frames, - num_videos=num_videos, - overrides=video_overrides), + "audio": self._get_dummy_audios( + length=target_audio_length, + num_audios=num_audios, + overrides=audio_overrides, + ), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + overrides=video_overrides, + ), } return mm_data class Qwen2_5OmniThinkerMultiModalProcessor( - BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]): - + BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo] +): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return Qwen2_5OmniThinkerMultiModalDataParser( - spatial_merge_size=self.info.get_hf_config( - ).vision_config.spatial_merge_size, - target_sr=feature_extractor.sampling_rate) + spatial_merge_size=self.info.get_hf_config().vision_config.spatial_merge_size, + target_sr=feature_extractor.sampling_rate, + ) def _call_hf_processor( self, @@ -279,7 +322,9 @@ def _call_hf_processor( if audios: # NOTE: Qwen2.5-Omni processor accept "audio" mm_data["audio"] = audios - mm_kwargs = dict(**mm_kwargs, ) + mm_kwargs = dict( + **mm_kwargs, + ) hf_inputs = super()._call_hf_processor( prompt=prompt, @@ -288,17 +333,19 @@ def _call_hf_processor( tok_kwargs=tok_kwargs, ) - input_features = hf_inputs.pop('input_features', None) - feature_attention_mask = hf_inputs.get('feature_attention_mask', None) - if ('input_audio_features' not in hf_inputs - and input_features is not None): + input_features = hf_inputs.pop("input_features", None) + feature_attention_mask = hf_inputs.get("feature_attention_mask", None) + if "input_audio_features" not in hf_inputs and input_features is not None: if feature_attention_mask is not None: - input_features = input_features.permute( - 0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) - hf_inputs['input_audio_features'] = input_features - if ('audio_feature_lengths' not in hf_inputs - and feature_attention_mask is not None): - hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1) + input_features = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) + hf_inputs["input_audio_features"] = input_features + if ( + "audio_feature_lengths" not in hf_inputs + and feature_attention_mask is not None + ): + hf_inputs["audio_feature_lengths"] = feature_attention_mask.sum(-1) video_second_per_grid = hf_inputs.get("video_second_per_grid", None) if video_second_per_grid is not None: @@ -315,8 +362,8 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return create_qwen2_5_omni_thinker_field_factory( - self.info.get_hf_config().vision_config.spatial_merge_size)( - hf_inputs) + self.info.get_hf_config().vision_config.spatial_merge_size + )(hf_inputs) def 
_maybe_apply_prompt_updates( self, @@ -331,16 +378,16 @@ def _maybe_apply_prompt_updates( """ mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + self._validate_mm_updates(mm_prompt_updates, mm_item_counts) use_audio_in_video = False if "video" in mm_kwargs: - video_items = [ - item for item in mm_kwargs["video"] if item is not None - ] + video_items = [item for item in mm_kwargs["video"] if item is not None] # only check video items (if there are any) if video_items: - use_audio_in_video = all(item["use_audio_in_video"].data - for item in video_items) + use_audio_in_video = all( + item["use_audio_in_video"].data for item in video_items + ) if is_update_applied: mm_placeholders = self._find_mm_placeholders( @@ -373,8 +420,7 @@ def _get_prompt_updates( ) -> Sequence[PromptUpdate]: processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) vocab = tokenizer.get_vocab() audio_token = processor.audio_token @@ -391,12 +437,14 @@ def _get_prompt_updates( audio_output_lengths = [] elif audio_feature_lengths is not None: _, audio_output_lens = _get_feat_extract_output_lengths( - audio_feature_lengths) + audio_feature_lengths + ) audio_output_lengths = audio_output_lens.tolist() elif feature_attention_mask is not None: assert isinstance(feature_attention_mask, torch.Tensor) _, audio_output_lens = _get_feat_extract_output_lengths( - feature_attention_mask.sum(-1)) + feature_attention_mask.sum(-1) + ) audio_output_lengths = audio_output_lens.tolist() # number of audios read from video. @@ -411,7 +459,8 @@ def get_replacement_qwen2_audio(item_idx: int): audio = audios.get(item_idx) raise ValueError( f"The audio {audio} (len={len(audio)}) is too short " - "to be represented inside the model") + "to be represented inside the model" + ) return [audio_token_id] * num_features @@ -423,21 +472,20 @@ def get_replacement_qwen2_vision(item_idx: int, modality: str): token_id = image_token_id if modality == "image" else video_token_id return [token_id] * (int(grid_thw.prod()) // merge_length) - use_audio_in_video = hf_processor_mm_kwargs.get( - "use_audio_in_video", False) + use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False) thinker_config = self.info.get_hf_config() def get_replacement_qwen2_use_audio_in_video(item_idx: int): nonlocal audio_in_video_item_idx - audio_num_features = audio_output_lengths[audio_in_video_item_idx + - item_idx] + audio_num_features = audio_output_lengths[ + audio_in_video_item_idx + item_idx + ] video_grid_thw = out_mm_data["video_grid_thw"][item_idx] audio_in_video_item_idx += 1 - second_per_grid_ts = hf_processor_mm_kwargs.get( - "second_per_grid_ts", None) + second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None) if second_per_grid_ts: video_second_per_grid_t = second_per_grid_ts[item_idx] else: @@ -451,8 +499,10 @@ def get_replacement_qwen2_use_audio_in_video(item_idx: int): ) video_replacement_fn = ( - get_replacement_qwen2_use_audio_in_video if use_audio_in_video else - partial(get_replacement_qwen2_vision, modality="video")) + get_replacement_qwen2_use_audio_in_video + if use_audio_in_video + else partial(get_replacement_qwen2_vision, modality="video") + ) return [ PromptReplacement( @@ -463,8 +513,7 @@ def get_replacement_qwen2_use_audio_in_video(item_idx: int): PromptReplacement( 
modality="image", target=image_token, - replacement=partial(get_replacement_qwen2_vision, - modality="image"), + replacement=partial(get_replacement_qwen2_vision, modality="image"), ), PromptReplacement( modality="video", @@ -517,8 +566,7 @@ def _apply_hf_processor_mm_only( """ mm_counts = mm_items.get_all_counts() - use_audio_in_video = hf_processor_mm_kwargs.get( - "use_audio_in_video", False) + use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False) if use_audio_in_video and "video" in mm_counts: assert "audio" in mm_counts mm_counts["audio"] -= mm_counts["video"] @@ -547,14 +595,11 @@ def _validate_mm_placeholders( class Qwen2_5OmniConditionalGenerationMixin: - - def _validate_and_reshape_mm_tensor(self, - mm_input: object, - name: str, - dim: int = 0) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str, dim: int = 0 + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): if dim == 0: return mm_input.reshape(-1, *mm_input.shape[2:]) @@ -563,25 +608,31 @@ def _validate_and_reshape_mm_tensor(self, return torch.concat(mm_input, dim=dim) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]: - input_audio_features = kwargs.pop('input_audio_features', None) - audio_feature_lengths = kwargs.pop('audio_feature_lengths', None) - feature_attention_mask = kwargs.pop('feature_attention_mask', None) + self, **kwargs: object + ) -> Optional[Qwen2_5OmniAudioFeatureInputs]: + input_audio_features = kwargs.pop("input_audio_features", None) + audio_feature_lengths = kwargs.pop("audio_feature_lengths", None) + feature_attention_mask = kwargs.pop("feature_attention_mask", None) if input_audio_features is None: return None input_audio_features = self._validate_and_reshape_mm_tensor( - input_audio_features, 'input_audio_features', dim=1) + input_audio_features, "input_audio_features", dim=1 + ) if feature_attention_mask is not None: feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, 'feature_attention_mask') + feature_attention_mask, "feature_attention_mask" + ) if not isinstance(input_audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio input features. " - f"Got type: {type(input_audio_features)}") + raise ValueError( + "Incorrect type of audio input features. " + f"Got type: {type(input_audio_features)}" + ) return Qwen2_5OmniAudioFeatureInputs( type="audio_features", input_features=input_audio_features, audio_feature_lengths=audio_feature_lengths, - feature_attention_mask=feature_attention_mask) + feature_attention_mask=feature_attention_mask, + ) def _parse_and_validate_image_input( self, @@ -596,31 +647,42 @@ def _parse_and_validate_image_input( if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") + pixel_values, "image pixel values" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + "Incorrect type of image pixel values. 
" + f"Got type: {type(pixel_values)}" + ) - return Qwen2_5_VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") + image_embeds, "image embeds" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") + raise ValueError( + "Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}" + ) return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( self, @@ -635,9 +697,11 @@ def _parse_and_validate_video_input( if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") + pixel_values_videos, "video pixel values" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", @@ -647,17 +711,22 @@ def _parse_and_validate_video_input( if video_embeds is not None: video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") + video_embeds, "video embeds" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") + raise ValueError( + "Incorrect type of video embeddings. 
" + f"Got type: {type(video_embeds)}" + ) return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + video_grid_thw=video_grid_thw, + ) def _process_audio_input( self, @@ -665,35 +734,35 @@ def _process_audio_input( audio_hashes: list[str] = None, cached_audio_features: torch.Tensor = None, ) -> torch.Tensor: - input_features = audio_input["input_features"] audio_feature_lengths = audio_input["audio_feature_lengths"] if input_features.ndim == 3: assert input_features.shape[0] == 1 input_features = input_features.squeeze(0) if audio_feature_lengths.ndim == 2: - assert audio_feature_lengths.shape[ - 0] == 1 or audio_feature_lengths.shape[1] == 1 + assert ( + audio_feature_lengths.shape[0] == 1 + or audio_feature_lengths.shape[1] == 1 + ) if audio_feature_lengths.shape[0] == 1: audio_feature_lengths = audio_feature_lengths.squeeze(0) else: audio_feature_lengths = audio_feature_lengths.squeeze(1) audio_feat_lengths, audio_output_lengths = ( - self.audio_tower._get_feat_extract_output_lengths( - audio_feature_lengths)) + self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths) + ) audio_outputs = self.audio_tower( input_features.to(self.audio_tower.dtype), feature_lens=audio_feature_lengths, aftercnn_lens=audio_feat_lengths, ) - return audio_outputs.last_hidden_state.split( - audio_output_lengths.tolist()) + return audio_outputs.last_hidden_state.split(audio_output_lengths.tolist()) def _process_image_input( - self, - image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + self, image_input: Qwen2_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": return image_input["image_embeds"].type(self.visual.dtype) @@ -709,18 +778,18 @@ def _process_image_input( return image_embeds.split(sizes.tolist()) def _process_video_input( - self, - video_input: Qwen2_5_VLVideoInputs, - video_hashes: list[str] = None, - cached_video_embeds: torch.Tensor = None) -> torch.Tensor: + self, + video_input: Qwen2_5_VLVideoInputs, + video_hashes: list[str] = None, + cached_video_embeds: torch.Tensor = None, + ) -> torch.Tensor: if video_input["type"] == "video_embeds": return video_input["video_embeds"].type(self.visual.dtype) grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) + pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) # Split concatenated embeddings for each video item. 
merge_size = self.visual.spatial_merge_size @@ -735,14 +804,19 @@ def _process_video_input( dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, ) class Qwen2_5OmniThinkerForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, - Qwen2_5OmniConditionalGenerationMixin): + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsLoRA, + Qwen2_5OmniConditionalGenerationMixin, +): hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "thinker.lm_head.": "language_model.lm_head.", "thinker.model.": "language_model.model.", "thinker.": "", - }) + } + ) packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -774,7 +848,8 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() thinker_config: Qwen2_5OmniThinkerConfig = ( - vllm_config.model_config.hf_config.thinker_config) + vllm_config.model_config.hf_config.thinker_config + ) quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = thinker_config @@ -790,20 +865,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): logger.warning( "flash_attn is not available, the model may not yield the " "exactly same result as the transformers implementation " - "in the audio tower part.") + "in the audio tower part." + ) if multimodal_config.get_limit_per_prompt("audio"): - self.audio_tower = Qwen2_5OmniAudioEncoder( - thinker_config.audio_config) + self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config) else: self.audio_tower = None if multimodal_config.get_limit_per_prompt( - "image") or multimodal_config.get_limit_per_prompt("video"): + "image" + ) or multimodal_config.get_limit_per_prompt("video"): self.visual = Qwen2_5_VisionTransformer( vision_config=thinker_config.vision_config, - norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", - 1e-6), + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), quant_config=quant_config, prefix=maybe_prefix(prefix, "visual"), ) @@ -819,7 +894,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} @@ -827,28 +903,34 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
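The comment above introduces the parsing loop (continued below) that keys each parsed input by modality while preserving the order in which the multimodal kwargs arrive: the first kwarg seen for a modality decides that modality's position. A standalone sketch of just that ordering behaviour; the key-to-modality mapping here is assumed for illustration only.

```python
# Toy sketch: the first kwarg seen for a modality fixes its position in the
# result dict. KEY_TO_MODALITY is illustrative, not the vLLM implementation.
KEY_TO_MODALITY = {
    "pixel_values": "image", "image_embeds": "image",
    "pixel_values_videos": "video", "video_embeds": "video",
    "input_audio_features": "audio",
}

def order_modalities(**kwargs):
    mm_input_by_modality = {}
    for input_key in kwargs:
        modality = KEY_TO_MODALITY.get(input_key)
        if modality is not None and modality not in mm_input_by_modality:
            mm_input_by_modality[modality] = kwargs[input_key]
    return mm_input_by_modality

print(list(order_modalities(pixel_values_videos=1, input_audio_features=2, pixel_values=3)))
# ['video', 'audio', 'image']
```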
for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key in ("pixel_values_videos", "video_embeds" - ) and "video" not in mm_input_by_modality: - mm_input_by_modality[ - "video"] = self._parse_and_validate_video_input(**kwargs) - if input_key in ("input_audio_features" - ) and "audio" not in mm_input_by_modality: - mm_input_by_modality[ - "audio"] = self._parse_and_validate_audio_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) + if ( + input_key in ("input_audio_features") + and "audio" not in mm_input_by_modality + ): + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input( + **kwargs + ) return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] @@ -892,8 +974,7 @@ def get_input_embeddings( handle_oov_mm_token=handle_oov_mm_token, ) - def get_multimodal_embeddings_v0( - self, **kwargs: object) -> Optional[NestedTensors]: + def get_multimodal_embeddings_v0(self, **kwargs: object) -> Optional[NestedTensors]: audio_input = self._parse_and_validate_audio_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs) @@ -925,10 +1006,9 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( @@ -937,8 +1017,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = ["talker.", "token2wav."] if self.audio_tower is None: skip_prefixes.extend(["audio_tower."]) @@ -949,8 +1028,7 @@ def load_weights(self, weights: Iterable[tuple[str, self, skip_prefixes=skip_prefixes, ) - loaded_weights = loader.load_weights(weights, - mapper=self.hf_to_vllm_mapper) + loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loaded_weights @@ -961,4 +1039,5 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="language_model", connector="merger.", - tower_model=["visual.", "audio_tower."]) + tower_model=["visual.", "audio_tower."], + ) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3c46516c7905..f7d2f6c584ca 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py 
+++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -25,6 +25,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" + from collections.abc import Iterable, Mapping, Sequence from functools import lru_cache, partial from typing import Annotated, Any, Callable, Literal, Optional, Union @@ -36,31 +37,37 @@ from transformers import BatchFeature from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( - Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) + Qwen2_5_VLConfig, + Qwen2_5_VLVisionConfig, +) from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import (check_upstream_fa_availability, - maybe_get_vit_flash_attn_backend) +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm -# yapf: disable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -# yapf: enable +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.evs import (compute_mrope_for_media, - compute_retained_tokens_count, - compute_retention_mask, - recompute_mrope_positions) +from vllm.multimodal.evs import ( + compute_mrope_for_media, + compute_retained_tokens_count, + compute_retention_mask, + recompute_mrope_positions, +) from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import PromptReplacement, PromptUpdate @@ -68,14 +75,28 @@ from vllm.utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsEagle3, SupportsLoRA, - SupportsMultiModal, SupportsMultiModalPruning, - SupportsPP, SupportsQuant) +from .interfaces import ( + MultiModalEmbeddings, + SupportsEagle3, + SupportsLoRA, + SupportsMultiModal, + SupportsMultiModalPruning, + SupportsPP, + SupportsQuant, +) from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder -from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, - apply_rotary_pos_emb_vision) -from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, - init_vllm_registered_model, maybe_prefix) +from .qwen2_vl import ( + Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo, + apply_rotary_pos_emb_vision, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + cast_overflow_tensors, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -96,6 +117,7 @@ class Qwen2_5_VLImagePixelInputs(TensorSchema): - image_grid_thw shape: 
(num_images, 3) in (grid_t, grid_h, grid_w) formatnum_channels * patch_size * patch_size """ + type: Literal["pixel_values"] pixel_values: Annotated[ @@ -124,6 +146,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["image_embeds"] image_embeds: Annotated[ @@ -137,8 +160,9 @@ class Qwen2_5_VLImageEmbeddingInputs(TensorSchema): ] -Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs, - Qwen2_5_VLImageEmbeddingInputs] +Qwen2_5_VLImageInputs = Union[ + Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs +] class Qwen2_5_VLVideoPixelInputs(TensorSchema): @@ -158,6 +182,7 @@ class Qwen2_5_VLVideoPixelInputs(TensorSchema): grid along the temporal dimension in the 3D position IDs. Returned when `videos` is not `None`. """ + type: Literal["pixel_values_videos"] pixel_values_videos: Annotated[ @@ -191,6 +216,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["video_embeds"] video_embeds: Annotated[ @@ -204,22 +230,24 @@ class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema): ] -Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs, - Qwen2_5_VLVideoEmbeddingInputs] +Qwen2_5_VLVideoInputs = Union[ + Qwen2_5_VLVideoPixelInputs, Qwen2_5_VLVideoEmbeddingInputs +] # === Vision Encoder === # class Qwen2_5_VisionMLP(nn.Module): - - def __init__(self, - in_features: int, - hidden_features: int, - bias: bool = False, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( input_size=in_features, @@ -227,14 +255,17 @@ def __init__(self, bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", - disable_tp=use_data_parallel) - - self.down_proj = RowParallelLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj", - disable_tp=use_data_parallel) + disable_tp=use_data_parallel, + ) + + self.down_proj = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, + ) self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -247,14 +278,14 @@ def forward(self, x: torch.Tensor): def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): """All-gather the input tensor interleavely across model parallel group.""" import torch.distributed as dist + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] - dist.all_gather(gathered_tensors, - local_tensor, - group=parallel_state.get_tp_group().device_group) + dist.all_gather( + gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group + ) gathered_tensors_split = [ - torch.split(tensor, hidden_size // tp_size, -1) - for tensor in gathered_tensors + torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors ] ordered_tensors = [ tensor for pair in zip(*gathered_tensors_split) for tensor in pair @@ -264,7 +295,6 @@ def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): 
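`all_gather_interleave` above restores the original column order after a tensor-parallel all-gather: each rank's gathered tensor is split into `tp_size` chunks along the hidden dimension and the chunks are re-interleaved. A single-process sketch of just that reordering, with plain tensors standing in for the all-gather output (no distributed group involved).

```python
# Single-process sketch of the interleaving in all_gather_interleave:
# per-rank chunks are re-ordered column-interleaved rather than rank-by-rank.
import torch

tp_size, hidden_size = 2, 8
gathered_tensors = [torch.arange(hidden_size) + 100 * rank for rank in range(tp_size)]

gathered_tensors_split = [
    torch.split(t, hidden_size // tp_size, -1) for t in gathered_tensors
]
ordered_tensors = [
    tensor for pair in zip(*gathered_tensors_split) for tensor in pair
]
result = torch.cat(ordered_tensors, dim=-1)
print(result)
# tensor([  0,   1,   2,   3, 100, 101, 102, 103,   4,   5,   6,   7, 104, 105, 106, 107])
```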
class Qwen2_5_VisionAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -278,13 +308,18 @@ def __init__( ) -> None: super().__init__() # Per attention head and per partition values. - self.tp_size = (1 if use_data_parallel else - parallel_state.get_tensor_model_parallel_world_size()) + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, self.tp_size) + num_heads, self.tp_size + ) self.qkv = QKVParallelLinear( hidden_size=embed_dim, @@ -294,55 +329,64 @@ def __init__( bias=True, quant_config=quant_config, prefix=f"{prefix}.qkv", - disable_tp=use_data_parallel) + disable_tp=use_data_parallel, + ) - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - disable_tp=use_data_parallel) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) self.attn_backend = attn_backend self.use_upstream_fa = use_upstream_fa - self.attn_backend, self.flash_attn_varlen_func \ - = maybe_get_vit_flash_attn_backend( + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, ) + ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape if self.tp_size > 1: - qkv = all_gather_interleave(qkv, self.qkv.hidden_size, - self.tp_size) + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) # 3 * [s, b, head * head_dim] if self.tp_size > 1: - splitter = partial(dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] v = splitter(v)[self.tp_rank] # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) q, k, v = (x.view(*new_shape) for x in (q, k, v)) return q, k, v def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -351,8 +395,7 @@ def forward( q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... 
-> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) @@ -360,22 +403,23 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = self.flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) - - context_layer = rearrange(output, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() + output = self.flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + ) + + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] @@ -385,36 +429,36 @@ def forward( q_i = q[:, start_idx:end_idx] k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() output, _ = self.proj(context_layer) return output class Qwen2_5_VisionBlock(nn.Module): - def __init__( self, dim: int, @@ -441,35 +485,39 @@ def __init__( prefix=f"{prefix}.attn", use_data_parallel=use_data_parallel, attn_backend=attn_backend, - use_upstream_fa=use_upstream_fa) - self.mlp = Qwen2_5_VisionMLP(dim, - mlp_hidden_dim, - act_fn=act_fn, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) + use_upstream_fa=use_upstream_fa, + ) + self.mlp = Qwen2_5_VisionMLP( + dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for 
xFormers ) -> torch.Tensor: - x_attn = self.attn(self.norm1(x), - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - max_seqlen=max_seqlen, - seqlens=seqlens) + x_attn = self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) x_fused_norm, residual = self.norm2(x, residual=x_attn) x = residual + self.mlp(x_fused_norm) return x class Qwen2_5_VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -483,22 +531,22 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d(in_channels, - hidden_size, - kernel_size=kernel_size, - stride=kernel_size, - bias=False) + self.proj = nn.Conv3d( + in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) x = self.proj(x).view(L, self.hidden_size) return x class Qwen2_5_VisionPatchMerger(nn.Module): - def __init__( self, d_model: int, @@ -545,13 +593,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2_5_VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta**( - torch.arange(0, dim, 2, dtype=torch.float, device='cpu') / dim)) + inv_freq = 1.0 / ( + theta ** (torch.arange(0, dim, 2, dtype=torch.float, device="cpu") / dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -560,12 +608,18 @@ def update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( - 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) - / self.dim)) - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -575,7 +629,6 @@ def forward(self, seqlen: int) -> torch.Tensor: class Qwen2_5_VisionTransformer(nn.Module): - def __init__( self, vision_config: Qwen2_5_VLVisionConfig, @@ -615,35 +668,43 @@ def __init__( use_upstream_fa = False self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype()) - if self.attn_backend != _Backend.FLASH_ATTN and \ - self.attn_backend != _Backend.ROCM_AITER_FA and \ - check_upstream_fa_availability( - torch.get_default_dtype()): + head_size=head_dim, dtype=torch.get_default_dtype() + ) + if ( + self.attn_backend != _Backend.FLASH_ATTN + and self.attn_backend != _Backend.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype()) + ): self.attn_backend = _Backend.FLASH_ATTN use_upstream_fa = True if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen2.5-VL does not support {self.attn_backend} backend now." 
) - self.blocks = nn.ModuleList([ - Qwen2_5_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=get_act_and_mul_fn(vision_config.hidden_act), - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend=self.attn_backend, - use_upstream_fa=use_upstream_fa) for layer_idx in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=get_act_and_mul_fn(vision_config.hidden_act), + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + ) + for layer_idx in range(depth) + ] + ) self.merger = Qwen2_5_VisionPatchMerger( d_model=vision_config.out_hidden_size, context_dim=self.hidden_size, @@ -665,48 +726,66 @@ def device(self) -> torch.device: def rotary_pos_emb_thw(self, t, h, w): hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1) max_size = max(h, w) rotary_pos_emb_full = self.rotary_pos_emb(max_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) rotary_pos_emb = rotary_pos_emb.reshape( rotary_pos_emb.shape[0] // self.spatial_merge_unit, - self.spatial_merge_unit, -1) + self.spatial_merge_unit, + -1, + ) return rotary_pos_emb def get_window_index_thw(self, grid_t, grid_h, grid_w): - vit_merger_window_size = (self.window_size // - self.spatial_merge_size // self.patch_size) + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) llm_grid_h = grid_h // self.spatial_merge_size llm_grid_w = grid_w // self.spatial_merge_size index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) + grid_t, llm_grid_h, llm_grid_w + ) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) - index_padded = index_padded.reshape(grid_t, num_windows_h, - vit_merger_window_size, - num_windows_w, - vit_merger_window_size) + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) index_padded = index_padded.permute(0, 
1, 3, 2, 4).reshape( - grid_t, num_windows_h * num_windows_w, vit_merger_window_size, - vit_merger_window_size) + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] @@ -718,23 +797,29 @@ def get_window_index_thw(self, grid_t, grid_h, grid_w): @lru_cache(maxsize=1024) # noqa: B019 def get_rope_by_thw(self, t, h, w): - window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw( - t, h, w) + window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w) rotary_pos_emb_thw = self.rotary_pos_emb_thw(t, h, w) rotary_pos_emb_thw = rotary_pos_emb_thw[window_index_thw, :, :] rotary_pos_emb_thw = rotary_pos_emb_thw.flatten(start_dim=0, end_dim=1) cu_seqlens_thw = torch.repeat_interleave( - torch.tensor([h * w], dtype=torch.int32), t) - return (rotary_pos_emb_thw, window_index_thw, cu_seqlens_window_thw, - cu_seqlens_thw) + torch.tensor([h * w], dtype=torch.int32), t + ) + return ( + rotary_pos_emb_thw, + window_index_thw, + cu_seqlens_window_thw, + cu_seqlens_thw, + ) def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -744,9 +829,7 @@ def compute_attn_mask_seqlen( def invert_permutation(perm: torch.Tensor) -> torch.Tensor: # building the inverse permutation in O(n) time inv = torch.empty_like(perm, pin_memory=is_pin_memory_available()) - inv[perm] = torch.arange(perm.numel(), - device=perm.device, - dtype=perm.dtype) + inv[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype) return inv def forward( @@ -779,10 +862,9 @@ def forward( ) = self.get_rope_by_thw(t, h, w) window_index.append(window_index_thw + window_index_id) - window_index_id += (t * llm_h * llm_w) + window_index_id += t * llm_h * llm_w - cu_seqlens_window_thw = (cu_seqlens_window_thw + - cu_window_seqlens_last) + cu_seqlens_window_thw = cu_seqlens_window_thw + cu_window_seqlens_last cu_window_seqlens_last = cu_seqlens_window_thw[-1] cu_window_seqlens.append(cu_seqlens_window_thw) @@ -802,23 +884,22 @@ def forward( # transformers # pre-compute seqlens for window/full attn to reduce cuMemcpy operations - max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen( - cu_seqlens) + max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens) max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( - cu_window_seqlens) + cu_window_seqlens + ) cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) - cu_window_seqlens = cu_window_seqlens.to(device=self.device, - non_blocking=True) - rotary_pos_emb = rotary_pos_emb.to(device=self.device, - non_blocking=True) - window_index = window_index.to(device=hidden_states.device, - non_blocking=True) - reverse_indices = reverse_indices.to(device=hidden_states.device, - non_blocking=True) + cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True) + rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True) + window_index = 
window_index.to(device=hidden_states.device, non_blocking=True) + reverse_indices = reverse_indices.to( + device=hidden_states.device, non_blocking=True + ) hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) @@ -852,8 +933,7 @@ def forward( hidden_states = hidden_states[reverse_indices, :] return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), @@ -866,7 +946,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -876,15 +956,13 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen2_5_VLConfig) @@ -897,7 +975,6 @@ def get_hf_processor(self, **kwargs: object) -> Qwen2_5_VLProcessor: class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -915,8 +992,7 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -935,10 +1011,12 @@ def get_replacement_qwen2vl(item_idx: int, modality: str): num_tokens = int(grid_thw.prod()) // merge_length # EVS-specific code - video_pruning_rate = self.info.ctx.get_mm_config( - ).video_pruning_rate - if (modality == "video" and video_pruning_rate is not None - and video_pruning_rate > 0.0): + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if ( + modality == "video" + and video_pruning_rate is not None + and video_pruning_rate > 0.0 + ): num_tokens = compute_retained_tokens_count( grid_thw, image_processor.merge_size, @@ -952,21 +1030,26 @@ def get_replacement_qwen2vl(item_idx: int, modality: str): PromptReplacement( modality=modality, target=[placeholder[modality]], - replacement=partial(get_replacement_qwen2vl, - modality=modality), - ) for modality in ("image", "video") + replacement=partial(get_replacement_qwen2vl, modality=modality), + ) + for modality in ("image", "video") ] @MULTIMODAL_REGISTRY.register_processor( Qwen2_5_VLMultiModalProcessor, info=Qwen2_5_VLProcessingInfo, - dummy_inputs=Qwen2_5_VLDummyInputsBuilder) -class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP, - SupportsQuant, SupportsEagle3, - SupportsMultiModalPruning): - + dummy_inputs=Qwen2_5_VLDummyInputsBuilder, +) 
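Two of the helpers above are easy to sanity-check in isolation: `compute_attn_mask_seqlen` turns cumulative sequence offsets into per-sequence lengths (a single max for the FlashAttention paths, a list for xFormers), and `invert_permutation` builds the inverse of the window permutation so the block outputs can be restored to the original token order. A minimal sketch with made-up offsets:

```python
# cu_seqlens -> per-sequence lengths, plus an O(n) inverse permutation,
# mirroring compute_attn_mask_seqlen and invert_permutation above.
import torch

cu_seqlens = torch.tensor([0, 4, 9, 12], dtype=torch.int32)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()          # [4, 5, 3]  (xFormers path)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()   # 5          (FlashAttention path)

perm = torch.tensor([2, 0, 3, 1])
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.numel(), dtype=perm.dtype)

x = torch.randn(4)
# indexing with perm and then with inv restores the original order
assert torch.equal(x[perm][inv], x)
```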
+class Qwen2_5_VLForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsLoRA, + SupportsPP, + SupportsQuant, + SupportsEagle3, + SupportsMultiModalPruning, +): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -981,7 +1064,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, # mapping for original checkpoint "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) supports_encoder_tp_data = True @@ -1004,10 +1088,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multimodal_config = multimodal_config self.video_pruning_rate = multimodal_config.video_pruning_rate self.is_multimodal_pruning_enabled = ( - multimodal_config.is_multimodal_pruning_enabled()) + multimodal_config.is_multimodal_pruning_enabled() + ) - if multimodal_config.get_limit_per_prompt("image") or \ - multimodal_config.get_limit_per_prompt("video"): + if multimodal_config.get_limit_per_prompt( + "image" + ) or multimodal_config.get_limit_per_prompt("video"): self.visual = Qwen2_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), @@ -1025,7 +1111,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.language_model.model.aux_hidden_state_layers = layers @@ -1034,24 +1121,27 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.language_model.model.layers) return (2, num_layers // 2, num_layers - 3) - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): if mm_input.ndim == 2: return mm_input if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") + raise ValueError( + f"{name} should be 2D or batched 3D tensor. 
" + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})" + ) return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + self, **kwargs: object + ) -> Optional[Qwen2_5_VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1061,27 +1151,35 @@ def _parse_and_validate_image_input( if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") + pixel_values, "image pixel values" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) - return Qwen2_5_VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") + image_embeds, "image embeds" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + self, **kwargs: object + ) -> Optional[Qwen2_5_VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1092,9 +1190,11 @@ def _parse_and_validate_video_input( if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") + pixel_values_videos, "video pixel values" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2: second_per_grid_ts = second_per_grid_ts.squeeze(-1) return Qwen2_5_VLVideoPixelInputs( @@ -1106,19 +1206,21 @@ def _parse_and_validate_video_input( if video_embeds is not None: video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") + video_embeds, "video embeds" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + video_grid_thw=video_grid_thw, + ) def _process_image_input( - self, - image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Qwen2_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1129,25 +1231,27 @@ def _process_image_input( pixel_values = image_input["pixel_values"] if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values, - grid_thw_list, - rope_type="rope_3d") + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" + ) else: - image_embeds = self.visual(pixel_values, - grid_thw=grid_thw_list) + 
image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return image_embeds.split(sizes) def _postprocess_image_embeds_evs( - self, image_embeds_split: tuple[torch.Tensor, ...], - image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + self, + image_embeds_split: tuple[torch.Tensor, ...], + image_input: Qwen2_5_VLImageInputs, + ) -> tuple[torch.Tensor, ...]: """ Append mrope positions for each for images. This is necessary to recover correct mrope @@ -1168,17 +1272,15 @@ def _postprocess_image_embeds_evs( grid_thw_list = grid_thw.tolist() image_embeds_out = [] for emb, size in zip(image_embeds_split, grid_thw_list): - positions = compute_mrope_for_media(size, - merge_size).to(emb.device) + positions = compute_mrope_for_media(size, merge_size).to(emb.device) emb = torch.cat([emb, positions], dim=1) image_embeds_out.append(emb) image_embeds_split = image_embeds_out return tuple(image_embeds_split) def _process_video_input( - self, - video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: - + self, video_input: Qwen2_5_VLVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1188,25 +1290,27 @@ def _process_video_input( else: pixel_values_videos = video_input["pixel_values_videos"] if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values_videos, - grid_thw_list, - rope_type="rope_3d") + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" + ) else: - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. 
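The `sizes` computation above recovers how many post-merge tokens each item contributed (`prod(t, h, w) // merge_size**2`), so the concatenated visual embeddings can be split back per image; essentially the same per-item count also drives the number of prompt placeholder tokens in the processor. A toy example with made-up grids:

```python
# Per-item token counts from the (t, h, w) patch grids, then splitting the
# concatenated embeddings back into per-image chunks. Grid values are made up.
import torch

merge_size = 2
grid_thw_list = [[1, 28, 28], [1, 14, 28]]       # two images' patch grids

sizes = (
    torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
    // (merge_size * merge_size)
).tolist()                                       # [196, 98]

image_embeds = torch.randn(sum(sizes), 32)       # concatenated encoder output
chunks = image_embeds.split(sizes)
assert [c.shape[0] for c in chunks] == sizes
```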
merge_size = self.visual.spatial_merge_size # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return video_embeds.split(sizes) def _postprocess_video_embeds_evs( - self, video_embeds_split: tuple[torch.Tensor, ...], - video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + self, + video_embeds_split: tuple[torch.Tensor, ...], + video_input: Qwen2_5_VLVideoInputs, + ) -> tuple[torch.Tensor, ...]: """ Prunes video embeddings via Efficient Video Sampling (EVS) and then appends mrope positions for each retained embeddings @@ -1231,9 +1335,9 @@ def _postprocess_video_embeds_evs( tokens_per_second = self.config.vision_config.tokens_per_second video_embeds_out = [] - for emb, size, video_second_per_grid_t in zip(video_embeds_split, - grid_thw_list, - second_per_grid_ts): + for emb, size, video_second_per_grid_t in zip( + video_embeds_split, grid_thw_list, second_per_grid_ts + ): # For each video, we compute retention mask using EVS retention_mask = compute_retention_mask( emb, @@ -1285,20 +1389,19 @@ def recompute_mrope_positions( vision_start_token_id = self.config.vision_start_token_id # Device - device = (multimodal_embeddings[0].device - if len(multimodal_embeddings) else mrope_positions.device) + device = ( + multimodal_embeddings[0].device + if len(multimodal_embeddings) + else mrope_positions.device + ) # Tensors - input_ids_t = torch.as_tensor(input_ids, - device=device, - dtype=torch.long) + input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) - # fmt: off - mm_embeddings_out = [mm[:, :-4] for mm in - multimodal_embeddings] - mm_embeddings_pos = [mm[:, -4:].permute(1, 0).long() for mm in - multimodal_embeddings] - # fmt: in + mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] + mm_embeddings_pos = [ + mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings + ] positions, mrope_positions_delta = recompute_mrope_positions( input_ids_t, @@ -1318,24 +1421,27 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
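With EVS enabled, the `_postprocess_*_embeds_evs` methods append per-token mrope positions as four extra channels on each embedding, and `recompute_mrope_positions` above strips them off again (`[:, :-4]` for the embeddings, `[:, -4:]` for the positions). A shape-level sketch of that packing convention, with toy sizes:

```python
# Pack 4 position channels onto each multimodal embedding, then split them
# off again before the embeddings reach the language model. Sizes are toy.
import torch

num_tokens, hidden = 10, 32
emb = torch.randn(num_tokens, hidden)
positions = torch.randint(0, 100, (num_tokens, 4)).to(emb.dtype)  # per-token mrope rows

packed = torch.cat([emb, positions], dim=1)          # (num_tokens, hidden + 4)

emb_out = packed[:, :-4]                             # what the language model consumes
pos_out = packed[:, -4:].permute(1, 0).long()        # (4, num_tokens) position layout
assert emb_out.shape == (num_tokens, hidden)
assert pos_out.shape == (4, num_tokens)
```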
for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key in ("pixel_values_videos", "video_embeds" - ) and "video" not in mm_input_by_modality: - mm_input_by_modality[ - "video"] = self._parse_and_validate_video_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return [] @@ -1399,9 +1505,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: skip_prefixes.extend(["visual."]) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index f407692e1151..e61a730f97bb 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -22,29 +22,44 @@ # See the License for the specific language governing permissions and # limitations under the License. 
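In the Qwen2-Audio processor changes that follow, each `<|AUDIO|>` placeholder expands to one audio token per encoder output frame (`[audio_token_id] * num_features`, where `num_features` comes from `_get_feat_extract_output_lengths` applied to the per-audio feature lengths). A toy expansion to show just the counting; the real replacement is driven by the `PromptReplacement` machinery, and the token id here is made up.

```python
# Toy expansion of an audio placeholder into per-frame audio tokens.
# AUDIO_TOKEN_ID is a placeholder value; only the counting logic matters.
AUDIO_TOKEN_ID = 42

def expand_audio_placeholder(prompt_ids, num_features):
    out = []
    for tok in prompt_ids:
        if tok == AUDIO_TOKEN_ID:
            out.extend([AUDIO_TOKEN_ID] * num_features)
        else:
            out.append(tok)
    return out

print(expand_audio_placeholder([1, 2, AUDIO_TOKEN_ID, 3], num_features=5))
# [1, 2, 42, 42, 42, 42, 42, 3]
```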
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" + from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Any, Literal, Optional, Union import torch import torch.nn as nn from transformers import BatchFeature -from transformers.models.qwen2_audio import (Qwen2AudioConfig, - Qwen2AudioEncoder, - Qwen2AudioProcessor) +from transformers.models.qwen2_audio import ( + Qwen2AudioConfig, + Qwen2AudioEncoder, + Qwen2AudioProcessor, +) from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (AudioItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + AudioItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + DictEmbeddingItems, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -60,6 +75,7 @@ class Qwen2AudioFeatureInputs(TensorSchema): - na: Number of audios - nmb: Number of mel bins """ + type: Literal["audio_features"] input_features: Annotated[ Union[torch.Tensor, list[torch.Tensor]], @@ -80,6 +96,7 @@ class Qwen2AudioEmbeddingInputs(TensorSchema): - hs: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["audio_embeds"] = "audio_embeds" audio_embeds: Annotated[ @@ -94,7 +111,6 @@ class Qwen2AudioEmbeddingInputs(TensorSchema): class Qwen2AudioMultiModalProjector(nn.Module): - def __init__(self, audio_hidden_size: int, text_hidden_size: int): super().__init__() self.linear = nn.Linear(audio_hidden_size, text_hidden_size, bias=True) @@ -112,15 +128,13 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): class Qwen2AudioProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen2AudioConfig) def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor: return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs) - def get_feature_extractor(self, - **kwargs: object) -> WhisperFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: hf_processor = self.get_hf_processor(**kwargs) feature_extractor = hf_processor.feature_extractor # type: ignore assert isinstance(feature_extractor, WhisperFeatureExtractor) @@ -130,9 +144,7 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} -class Qwen2AudioDummyInputsBuilder( - BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): - +class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) @@ -156,10 +168,9 @@ def get_dummy_mm_data( 
audio_overrides = mm_options.get("audio") if mm_options else None return { - "audio": - self._get_dummy_audios(length=audio_len, - num_audios=num_audios, - overrides=audio_overrides) + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) } @@ -172,7 +183,6 @@ def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]): class Qwen2AudioMultiModalDataParser(MultiModalDataParser): - def _parse_audio_data( self, data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], @@ -188,13 +198,10 @@ def _parse_audio_data( return super()._parse_audio_data(data) -class Qwen2AudioMultiModalProcessor( - BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): - +class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() - return Qwen2AudioMultiModalDataParser( - target_sr=feature_extractor.sampling_rate) + return Qwen2AudioMultiModalDataParser(target_sr=feature_extractor.sampling_rate) def _call_hf_processor( self, @@ -242,17 +249,14 @@ def _get_prompt_updates( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: - processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() # Use getattr with default to be compatible with transformers<4.48 audio_token = getattr(processor, "audio_token", "<|AUDIO|>") - audio_bos_token = getattr(processor, "audio_bos_token", - "<|audio_bos|>") - audio_eos_token = getattr(processor, "audio_eos_token", - "<|audio_eos|>") + audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>") + audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>") audio_token_id = vocab[audio_token] audio_bos_id = vocab[audio_bos_token] @@ -265,26 +269,27 @@ def _get_prompt_updates( else: assert isinstance(feature_attention_mask, torch.Tensor) _, audio_output_lens = _get_feat_extract_output_lengths( - feature_attention_mask.sum(-1)) + feature_attention_mask.sum(-1) + ) audio_output_lengths = audio_output_lens.tolist() def get_replacement_qwen2_audio(item_idx: int): - if audio_output_lengths: num_features = audio_output_lengths[item_idx] else: audio_embeds = out_mm_data["audio_embeds"][item_idx] - assert len(audio_embeds.shape - ) == 2, "audio_embeds must be a 2D tensor" + assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor" num_features = audio_embeds.shape[0] if num_features == 0: audios = mm_items.get_items("audio", AudioProcessorItems) audio_len = audios.get_audio_length(item_idx) - raise ValueError(f"The audio (len={audio_len}) is too short " - "to be represented inside the model") + raise ValueError( + f"The audio (len={audio_len}) is too short " + "to be represented inside the model" + ) audio_tokens = [audio_token_id] * num_features @@ -305,10 +310,9 @@ def get_replacement_qwen2_audio(item_idx: int): @MULTIMODAL_REGISTRY.register_processor( Qwen2AudioMultiModalProcessor, info=Qwen2AudioProcessingInfo, - dummy_inputs=Qwen2AudioDummyInputsBuilder) -class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): - + dummy_inputs=Qwen2AudioDummyInputsBuilder, +) +class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("audio"): @@ -326,7 +330,8 @@ def __init__(self, *, 
vllm_config: VllmConfig, prefix: str = ""): self.audio_tower = Qwen2AudioEncoder(config.audio_config) self.multi_modal_projector = Qwen2AudioMultiModalProjector( - config.audio_config.d_model, config.text_config.hidden_size) + config.audio_config.d_model, config.text_config.hidden_size + ) self.quant_config = quant_config @@ -338,45 +343,53 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): return mm_input.reshape(-1, *mm_input.shape[2:]) else: return torch.concat(mm_input) def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[Qwen2AudioInputs]: - input_features = kwargs.pop('input_features', None) - audio_embeds = kwargs.pop('audio_embeds', None) - feature_attention_mask = kwargs.pop('feature_attention_mask', None) + self, **kwargs: object + ) -> Optional[Qwen2AudioInputs]: + input_features = kwargs.pop("input_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + feature_attention_mask = kwargs.pop("feature_attention_mask", None) if input_features is None and audio_embeds is None: return None if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio embeds. " - f"Got type: {type(audio_embeds)}") + raise ValueError( + f"Incorrect type of audio embeds. 
Got type: {type(audio_embeds)}" + ) audio_embeds = self._validate_and_reshape_mm_tensor( - audio_embeds, "audio_embeds") - return Qwen2AudioEmbeddingInputs(type="audio_embeds", - audio_embeds=audio_embeds) + audio_embeds, "audio_embeds" + ) + return Qwen2AudioEmbeddingInputs( + type="audio_embeds", audio_embeds=audio_embeds + ) if input_features is not None: input_features = self._validate_and_reshape_mm_tensor( - input_features, 'input_features') + input_features, "input_features" + ) feature_attention_mask = self._validate_and_reshape_mm_tensor( - feature_attention_mask, 'feature_attention_mask') + feature_attention_mask, "feature_attention_mask" + ) return Qwen2AudioFeatureInputs( type="audio_features", input_features=input_features, - feature_attention_mask=feature_attention_mask) + feature_attention_mask=feature_attention_mask, + ) raise AssertionError("This line should be unreachable.") @@ -392,51 +405,62 @@ def _process_audio_input( audio_feat_lengths, audio_output_lengths = ( self.audio_tower._get_feat_extract_output_lengths( - feature_attention_mask.sum(-1))) + feature_attention_mask.sum(-1) + ) + ) batch_size, _, max_mel_seq_len = input_features.shape max_seq_len = (max_mel_seq_len - 2) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) - seq_range = (torch.arange( - 0, - max_seq_len, - dtype=audio_feat_lengths.dtype, - device=audio_feat_lengths.device).unsqueeze(0).expand( - batch_size, max_seq_len)) + seq_range = ( + torch.arange( + 0, + max_seq_len, + dtype=audio_feat_lengths.dtype, + device=audio_feat_lengths.device, + ) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) lengths_expand = audio_feat_lengths.unsqueeze(-1).expand( - batch_size, max_seq_len) + batch_size, max_seq_len + ) # Create mask padding_mask = seq_range >= lengths_expand - audio_attention_mask_ = padding_mask.view( - batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len, - max_seq_len) + audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) audio_attention_mask = audio_attention_mask_.to( dtype=self.audio_tower.conv1.weight.dtype, - device=self.audio_tower.conv1.weight.device) + device=self.audio_tower.conv1.weight.device, + ) audio_attention_mask[audio_attention_mask_] = float("-inf") - audio_outputs = self.audio_tower(input_features, - attention_mask=audio_attention_mask) + audio_outputs = self.audio_tower( + input_features, attention_mask=audio_attention_mask + ) selected_audio_feature = audio_outputs.last_hidden_state audio_features = self.multi_modal_projector(selected_audio_feature) num_audios, max_audio_tokens, embed_dim = audio_features.shape audio_output_lengths = audio_output_lengths.unsqueeze(1) - audio_features_mask = torch.arange(max_audio_tokens).expand( - num_audios, max_audio_tokens).to( - audio_output_lengths.device) < audio_output_lengths - masked_audio_features = audio_features[audio_features_mask].view( - -1, embed_dim) + audio_features_mask = ( + torch.arange(max_audio_tokens) + .expand(num_audios, max_audio_tokens) + .to(audio_output_lengths.device) + < audio_output_lengths + ) + masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim) # Split to tuple of embeddings for individual audio input. 
- return torch.split(masked_audio_features, - audio_output_lengths.flatten().tolist()) + return torch.split( + masked_audio_features, audio_output_lengths.flatten().tolist() + ) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] @@ -451,14 +475,12 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits( @@ -467,7 +489,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6a9acaf2c3fe..61b203a08349 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -24,6 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional, Union @@ -41,29 +42,36 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Qwen2MoeMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -75,19 +83,24 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, 
quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -98,7 +111,6 @@ def forward(self, x): class Qwen2MoeSparseMoeBlock(nn.Module): - def __init__( self, config: Qwen2MoeConfig, @@ -111,37 +123,39 @@ def __init__( if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - self.experts = FusedMoE(num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts") - - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + f"the number of experts {config.num_experts}." + ) + + self.experts = FusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + reduce_results=self.experts.must_reduce_shared_expert_outputs(), prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, - 1, - bias=False) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
@@ -152,24 +166,26 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.shared_expert is not None: shared_output = self.shared_expert(hidden_states) if self.shared_expert_gate is not None: - shared_output = F.sigmoid( - self.shared_expert_gate(hidden_states)) * shared_output + shared_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output + ) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states) + final_hidden_states + ) return final_hidden_states.view(orig_shape) class Qwen2MoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -207,19 +223,23 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.dual_chunk_attention_config = dual_chunk_attention_config - self.qkv_proj = QKVParallelLinear(hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) self.rotary_emb = get_rope( self.head_dim, @@ -240,7 +260,10 @@ def __init__( **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}) + } + if dual_chunk_attention_config + else {}, + ) def forward( self, @@ -256,7 +279,6 @@ def forward( class Qwen2MoeDecoderLayer(nn.Module): - def __init__( self, config: Qwen2MoeConfig, @@ -268,11 +290,10 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = Qwen2MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -289,24 +310,27 @@ def __init__( # Note: Qwen/Qwen2-57B-A14B-Instruct does not have # `mlp_only_layers` in the config. 
layer_idx = extract_layer_index(prefix) - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) if (layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (layer_idx + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen2MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen2MoeSparseMoeBlock( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) else: - self.mlp = Qwen2MoeMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.mlp = Qwen2MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -319,23 +343,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class Qwen2MoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -352,16 +373,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Qwen2MoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Qwen2MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -386,10 +409,9 @@ def forward( for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return 
hidden_states @@ -400,10 +422,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) + num_experts=self.config.num_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -417,7 +439,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -431,8 +453,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -455,21 +478,25 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -477,7 +504,8 @@ def load_weights(self, weights: Iterable[tuple[str, # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). 
kv_scale is not loaded.", # noqa: E501 @@ -488,15 +516,15 @@ def load_weights(self, weights: Iterable[tuple[str, else: name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): - fall_back_to_pt_during_load = False packed_modules_mapping = { "qkv_proj": [ @@ -516,17 +544,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Qwen2MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.model = Qwen2MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -538,8 +570,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -549,8 +582,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 2bd9d2b52628..75ed95477f78 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -6,6 +6,7 @@ # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. 
"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" + from collections.abc import Iterable from typing import Optional, Union @@ -13,8 +14,7 @@ from torch import nn from vllm.config import VllmConfig -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.sequence import IntermediateTensors @@ -25,7 +25,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): - is_pooling_model = True pooler: Pooler @@ -51,25 +50,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Qwen2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.head_dtype = vllm_config.model_config.head_dtype self.score = nn.Sequential( - ColumnParallelLinear(config.hidden_size, - config.hidden_size, - quant_config=quant_config, - params_dtype=self.head_dtype, - return_bias=False), + ColumnParallelLinear( + config.hidden_size, + config.hidden_size, + quant_config=quant_config, + params_dtype=self.head_dtype, + return_bias=False, + ), nn.ReLU(), - RowParallelLinear(config.hidden_size, - config.num_labels, - params_dtype=self.head_dtype, - quant_config=quant_config, - return_bias=False), + RowParallelLinear( + config.hidden_size, + config.num_labels, + params_dtype=self.head_dtype, + quant_config=quant_config, + return_bias=False, + ), ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -81,22 +86,20 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) hidden_states = hidden_states.to(self.head_dtype) logits = self.score(hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self, - ignore_unexpected_prefixes=["lm_head."]) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["lm_head."]) return loader.load_weights(weights) @default_pooling_type("ALL") class Qwen2ForRewardModel(Qwen2RewardBaseModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 1 super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -105,12 +108,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, ) + {"encode": Pooler.for_encode(pooler_config)}, + ) @default_pooling_type("STEP") class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 2 super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -118,5 +121,4 @@ def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}) + self.pooler = DispatchPooler({"encode": Pooler.for_encode(pooler_config)}) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index ab9bfe4d0f19..cb1bf3825c74 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -24,6 +24,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" + from collections.abc import Iterable, Mapping, Sequence from functools import partial from typing import Annotated, Any, Callable, Literal, Optional, Union @@ -33,49 +34,72 @@ import torch.nn.functional as F from einops import rearrange, repeat from transformers import AutoConfig, BatchFeature, PretrainedConfig -from transformers.models.qwen2_vl import (Qwen2VLImageProcessor, - Qwen2VLProcessor) +from transformers.models.qwen2_vl import Qwen2VLImageProcessor, Qwen2VLProcessor from transformers.models.qwen2_vl.configuration_qwen2_vl import ( - Qwen2VLConfig, Qwen2VLVisionConfig) + Qwen2VLConfig, + Qwen2VLVisionConfig, +) from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize -from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( - Qwen2VLVideoProcessor) +from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import (check_upstream_fa_availability, - maybe_get_vit_flash_attn_backend) +from vllm.attention.layer import ( + check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend, +) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding.common import ( - dispatch_rotary_emb_function) + dispatch_rotary_emb_function, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, VideoItem) -from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize, - ModalityDataItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ImageSize, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + 
BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMRoPE, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -93,13 +117,14 @@ class Qwen2VLImagePixelInputs(TensorSchema): the batch - ni: Number of images - cps: Number of channels * patch_size * patch_size - + Historical context: - - pixel_values shape: (num_patches, num_channels * patch_size * + - pixel_values shape: (num_patches, num_channels * patch_size * patch_size) - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["pixel_values"] pixel_values: Annotated[ @@ -119,7 +144,7 @@ class Qwen2VLImageEmbeddingInputs(TensorSchema): - nf: Number of image features - hs: Hidden size - ni: Number of images - + Historical context: - image_embeds shape: (num_image_features, hidden_size) - num_image_features varies based on the number and resolution of the @@ -128,6 +153,7 @@ class Qwen2VLImageEmbeddingInputs(TensorSchema): - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["image_embeds"] image_embeds: Annotated[ @@ -141,8 +167,7 @@ class Qwen2VLImageEmbeddingInputs(TensorSchema): ] -Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, - Qwen2VLImageEmbeddingInputs] +Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, Qwen2VLImageEmbeddingInputs] class Qwen2VLVideoPixelInputs(TensorSchema): @@ -150,16 +175,17 @@ class Qwen2VLVideoPixelInputs(TensorSchema): Dimensions: - np: The total number of patches over each video over each prompt in the batch - - ctps: Number of channels * temporal_patch_size * patch_size * + - ctps: Number of channels * temporal_patch_size * patch_size * patch_size - nv: Number of videos - + Historical context: - - pixel_values_videos shape: (num_patches, num_channels * + - pixel_values_videos shape: (num_patches, num_channels * temporal_patch_size * patch_size * patch_size) - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["pixel_values_videos"] pixel_values_videos: Annotated[ @@ -179,7 +205,7 @@ class Qwen2VLVideoEmbeddingInputs(TensorSchema): - nf: Number of video features - hs: Hidden size - nv: Number of videos - + Historical context: - video_embeds shape: (num_video_features, hidden_size) - num_video_features varies based on the number and resolution of the @@ -188,6 +214,7 @@ class Qwen2VLVideoEmbeddingInputs(TensorSchema): - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format """ + type: Literal["video_embeds"] video_embeds: Annotated[ @@ -201,14 +228,12 @@ class Qwen2VLVideoEmbeddingInputs(TensorSchema): ] -Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, - Qwen2VLVideoEmbeddingInputs] +Qwen2VLVideoInputs = Union[Qwen2VLVideoPixelInputs, Qwen2VLVideoEmbeddingInputs] # === Vision Encoder === # class Qwen2VisionMLP(nn.Module): - def __init__( self, 
in_features: int, @@ -219,17 +244,21 @@ def __init__( use_data_parallel: bool = False, ): super().__init__() - self.fc1 = ColumnParallelLinear(in_features, - hidden_features, - quant_config=quant_config, - prefix=f"{prefix}.fc1", - disable_tp=use_data_parallel) + self.fc1 = ColumnParallelLinear( + in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, + ) self.act = act_layer() - self.fc2 = RowParallelLinear(hidden_features, - in_features, - quant_config=quant_config, - prefix=f"{prefix}.fc2", - disable_tp=use_data_parallel) + self.fc2 = RowParallelLinear( + hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x_parallel, _ = self.fc1(x) @@ -244,15 +273,14 @@ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), - "... d two -> ... (d two)", - two=2) + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) -def apply_rotary_emb_torch(x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False) -> torch.Tensor: +def apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) @@ -260,24 +288,22 @@ def apply_rotary_emb_torch(x: torch.Tensor, ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] cos = repeat( - cos, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) sin = repeat( - sin, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) return torch.cat( [ - x[..., :ro_dim] * cos + - rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], ], dim=-1, ) -def apply_rotary_pos_emb_vision(t: torch.Tensor, - freqs: torch.Tensor) -> torch.Tensor: - rotary_emb_function = dispatch_rotary_emb_function( - default=apply_rotary_emb_torch) +def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch) t_ = t.float() cos = freqs.cos() sin = freqs.sin() @@ -286,7 +312,6 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, class Qwen2VisionAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -298,46 +323,61 @@ def __init__( ) -> None: super().__init__() # Per attention head and per partition values. 
- self.tp_size = (1 if use_data_parallel else - parallel_state.get_tensor_model_parallel_world_size()) + self.tp_size = ( + 1 + if use_data_parallel + else parallel_state.get_tensor_model_parallel_world_size() + ) self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) + projection_size, num_heads + ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, self.tp_size) - - self.qkv = ColumnParallelLinear(input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - prefix=f"{prefix}.qkv", - disable_tp=use_data_parallel) - self.proj = RowParallelLinear(input_size=projection_size, - output_size=embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.proj", - disable_tp=use_data_parallel) + num_heads, self.tp_size + ) + + self.qkv = ColumnParallelLinear( + input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + disable_tp=use_data_parallel, + ) + self.proj = RowParallelLinear( + input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + disable_tp=use_data_parallel, + ) # Detect attention implementation. self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, - dtype=torch.get_default_dtype()) + dtype=torch.get_default_dtype(), + ) self.use_upstream_fa = False - self.attn_backend, self.flash_attn_varlen_func \ - = maybe_get_vit_flash_attn_backend( + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, ) + ) if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( - f"Qwen2-VL does not support {self.attn_backend} backend now.") + f"Qwen2-VL does not support {self.attn_backend} backend now." 
+ ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -351,27 +391,31 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # 3 * [s, b, head * head_dim] if self.tp_size > 1: - splitter = partial(dist_utils.split_tensor_along_last_dim, - num_partitions=self.tp_size) + splitter = partial( + dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size + ) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] v = splitter(v)[self.tp_rank] # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) + new_shape = ( + seq_len, + bs, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) q, k, v = (x.view(*new_shape) for x in (q, k, v)) return q, k, v def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - # [s, b, c] --> [s, b, 3 * head * head_dim] x, _ = self.qkv(x) @@ -379,8 +423,7 @@ def forward( q, k, v = self.split_qkv(x) batch_size = q.shape[1] - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) if rotary_pos_emb is not None: # [2 * b, s, heads, head_dim] qk_concat = torch.cat([q, k], dim=0) @@ -388,22 +431,23 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = self.flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) - - context_layer = rearrange(output, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() + output = self.flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False, + ) + + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=batch_size + ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
outputs = [] @@ -413,36 +457,36 @@ def forward( q_i = q[:, start_idx:end_idx] k_i = k[:, start_idx:end_idx] v_i = v[:, start_idx:end_idx] - q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") - for x in [q_i, k_i, v_i]) - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in [q_i, k_i, v_i] + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) context_layer = torch.cat(outputs, dim=1) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask - attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, - kv_seqlen=None, - device=q.device) + attn_bias = BlockDiagonalMask.from_seqlens( + q_seqlen=seqlens, kv_seqlen=None, device=q.device + ) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None) - context_layer = rearrange(context_layer, - "b s h d -> s b (h d)").contiguous() + q, k, v, attn_bias=attn_bias, p=0, scale=None + ) + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() output, _ = self.proj(context_layer) return output class Qwen2VisionBlock(nn.Module): - def __init__( self, dim: int, @@ -461,26 +505,30 @@ def __init__( self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.attn = Qwen2VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_data_parallel=use_data_parallel) - self.mlp = Qwen2VisionMLP(dim, - mlp_hidden_dim, - act_layer=act_layer, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) + self.attn = Qwen2VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_data_parallel=use_data_parallel, + ) + self.mlp = Qwen2VisionMLP( + dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), @@ -495,7 +543,6 @@ def forward( class Qwen2VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -509,22 +556,22 @@ def __init__( self.embed_dim = embed_dim kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d(in_channels, - embed_dim, - kernel_size=kernel_size, - stride=kernel_size, - bias=False) + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, 
self.patch_size) x = self.proj(x).view(L, self.embed_dim) return x class Qwen2VisionPatchMerger(nn.Module): - def __init__( self, d_model: int, @@ -540,21 +587,27 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_q = norm_layer(context_dim) - self.mlp = nn.ModuleList([ - ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.0", - disable_tp=use_data_parallel), - nn.GELU(), - RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp.2", - disable_tp=use_data_parallel), - ]) + self.mlp = nn.ModuleList( + [ + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + disable_tp=use_data_parallel, + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + disable_tp=use_data_parallel, + ), + ] + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.ln_q(x) @@ -568,13 +621,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen2VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() self.dim = dim self.theta = theta - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._seq_len_cached = 0 self._freqs_cached = None @@ -583,12 +634,18 @@ def update_freqs_cache(self, seqlen: int) -> None: if seqlen > self._seq_len_cached: seqlen *= 2 self._seq_len_cached = seqlen - self.inv_freq = 1.0 / (self.theta**(torch.arange( - 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) - / self.dim)) - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + self.inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device + ) + / self.dim + ) + ) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) self._freqs_cached = freqs @@ -598,7 +655,6 @@ def forward(self, seqlen: int) -> torch.Tensor: class Qwen2VisionTransformer(nn.Module): - def __init__( self, vision_config: Qwen2VLVisionConfig, @@ -637,16 +693,20 @@ def __init__( head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList([ - Qwen2VisionBlock(dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel) - for layer_idx in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(depth) + ] + ) self.merger = Qwen2VisionPatchMerger( d_model=hidden_size, context_dim=embed_dim, @@ -656,10 +716,11 @@ def __init__( use_data_parallel=use_data_parallel, ) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype()) - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): + 
head_size=head_dim, dtype=torch.get_default_dtype() + ) + if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() + ): self.attn_backend = _Backend.FLASH_ATTN @property @@ -676,20 +737,27 @@ def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: for t, h, w in grid_thw: hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) - hpos_ids = hpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - wpos_ids = wpos_ids.reshape( - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ).permute(0, 2, 1, 3).flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + hpos_ids = ( + hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + wpos_ids = ( + wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + .permute(0, 2, 1, 3) + .flatten() + ) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) max_grid_size = max(max_grid_size, h, w) pos_ids = torch.cat(pos_ids, dim=0) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) @@ -697,11 +765,13 @@ def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: return rotary_pos_emb def compute_attn_mask_seqlen( - self, cu_seqlens: torch.Tensor + self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -720,10 +790,10 @@ def forward( rotary_pos_emb = self.rot_pos_emb(grid_thw) # compute cu_seqlens - grid_thw_ = torch.tensor(grid_thw) - cu_seqlens = torch.repeat_interleave(grid_thw_[:, 1] * grid_thw_[:, 2], - grid_thw_[:, 0]).cumsum( - dim=0, dtype=torch.int32) + grid_thw_ = torch.tensor(grid_thw, device=x.device, dtype=torch.long) + cu_seqlens = torch.repeat_interleave( + grid_thw_[:, 1] * grid_thw_[:, 2], grid_thw_[:, 0] + ).cumsum(dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) # transformers @@ -745,8 +815,7 @@ def forward( return x - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -757,7 +826,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -768,41 +837,45 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, 
"weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params def _create_qwen2vl_field_factory( - spatial_merge_size: int + spatial_merge_size: int, ) -> Callable[ [Mapping[str, torch.Tensor]], - Mapping[str, MultiModalFieldConfig], + Mapping[str, MultiModalFieldConfig], ]: - def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) image_pixel_grid_sizes = image_grid_thw.prod(-1) - image_embed_grid_sizes = (image_pixel_grid_sizes // - spatial_merge_size // spatial_merge_size) + image_embed_grid_sizes = ( + image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size + ) video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) video_grid_sizes = video_grid_thw.prod(-1) - video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size // - spatial_merge_size) + video_embed_grid_sizes = ( + video_grid_sizes // spatial_merge_size // spatial_merge_size + ) return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_pixel_grid_sizes), + "image", image_pixel_grid_sizes + ), image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_embed_grid_sizes), + "image", image_embed_grid_sizes + ), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_embed_grid_sizes), + "video", video_embed_grid_sizes + ), video_grid_thw=MultiModalFieldConfig.batched("video"), ) @@ -810,7 +883,6 @@ def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]): class Qwen2VLMultiModalDataParser(MultiModalDataParser): - def __init__(self, spatial_merge_size: int, *args, **kwargs): self._spatial_merge_size = spatial_merge_size super().__init__(*args, **kwargs) @@ -824,8 +896,7 @@ def _parse_image_data( data, modality="image", required_fields={"image_embeds", "image_grid_thw"}, - fields_factory=_create_qwen2vl_field_factory( - self._spatial_merge_size), + fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size), ) return super()._parse_image_data(data) @@ -839,15 +910,13 @@ def _parse_video_data( data, modality="video", required_fields={"video_embeds", "video_grid_thw"}, - fields_factory=_create_qwen2vl_field_factory( - self._spatial_merge_size), + fields_factory=_create_qwen2vl_field_factory(self._spatial_merge_size), ) return super()._parse_video_data(data) class Qwen2VLProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen2VLConfig) @@ -899,11 +968,9 @@ def _get_vision_info( min_pixels=image_processor.min_pixels, max_pixels=image_processor.max_pixels, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) # NOTE: Frames are padded to be divisible by `temporal_patch_size` # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 @@ -967,9 +1034,7 @@ def get_max_image_tokens(self) -> int: image_processor=None, ) - def _get_max_video_frames(self, - max_tokens: int, - start_num_frames: int = 1) -> int: + def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 
1) -> int: target_width, target_height = self.get_image_size_with_most_features() num_frames = start_num_frames @@ -999,8 +1064,9 @@ def get_num_frames_with_most_features( max_videos = mm_counts.get("video", 0) max_total_frames = self._get_max_video_frames(seq_len) - max_frames_per_video = min(max_total_frames // max(max_videos, 1), - max_frames_per_video) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), max_frames_per_video + ) return max(max_frames_per_video, 1) @@ -1014,14 +1080,12 @@ def get_max_video_tokens( return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -1041,37 +1105,36 @@ def get_dummy_mm_data( num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) - target_width, target_height = \ - self.info.get_image_size_with_most_features() - target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len, mm_counts) + target_width, target_height = self.info.get_image_size_with_most_features() + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( width=target_width, height=target_height, num_frames=target_num_frames, num_videos=num_videos, overrides=video_overrides, - ) + ), } -class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] - ): - +class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return Qwen2VLMultiModalDataParser( - self.info.get_hf_config().vision_config.spatial_merge_size) + self.info.get_hf_config().vision_config.spatial_merge_size + ) def _get_prompt_updates( self, @@ -1080,8 +1143,7 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() vocab = tokenizer.get_vocab() @@ -1104,9 +1166,9 @@ def get_replacement_qwen2vl(item_idx: int, modality: str): PromptReplacement( modality=modality, target=[placeholder[modality]], - replacement=partial(get_replacement_qwen2vl, - modality=modality), - ) for modality in ("image", "video") + replacement=partial(get_replacement_qwen2vl, modality=modality), + ) + for modality in ("image", "video") ] def _get_mm_fields_config( @@ -1115,16 +1177,18 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return _create_qwen2vl_field_factory( - 
self.info.get_hf_config().vision_config.spatial_merge_size)( - hf_inputs) - - -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, - info=Qwen2VLProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP, SupportsMRoPE): - + self.info.get_hf_config().vision_config.spatial_merge_size + )(hf_inputs) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2VLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder, +) +class Qwen2VLForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE +): # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ @@ -1134,7 +1198,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, # mapping for original checkpoint "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - }) + } + ) supports_encoder_tp_data = True @@ -1162,12 +1227,12 @@ def get_mrope_input_positions( video_token_id = hf_config.video_token_id vision_start_token_id = hf_config.vision_start_token_id spatial_merge_size = hf_config.vision_config.spatial_merge_size - tokens_per_second = getattr(hf_config.vision_config, - "tokens_per_second", 1.0) + tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0) input_tokens_tensor = torch.tensor(input_tokens) vision_start_indices = torch.argwhere( - input_tokens_tensor == vision_start_token_id).squeeze(1) + input_tokens_tensor == vision_start_token_id + ).squeeze(1) vision_tokens = input_tokens_tensor[vision_start_indices + 1] image_nums = (vision_tokens == image_token_id).sum() video_nums = (vision_tokens == video_token_id).sum() @@ -1215,37 +1280,56 @@ def get_mrope_input_positions( remain_videos -= 1 ed = ed_video - llm_grid_t, llm_grid_h, llm_grid_w = \ - t, h // spatial_merge_size, w // spatial_merge_size + llm_grid_t, llm_grid_h, llm_grid_w = ( + t, + h // spatial_merge_size, + w // spatial_merge_size, + ) text_len = ed - st - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) - t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * - tokens_per_second).long().flatten() + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + * video_second_per_grid_t + * tokens_per_second + ) + .long() + .flatten() + ) - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( - llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten() + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) llm_pos_ids_list.append( - torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len( - llm_pos_ids_list) > 0 else 0 + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 
text_len = len(input_tokens) - st llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - mrope_position_delta = (llm_positions.max() + 1 - - len(input_tokens)).item() + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() llm_positions = llm_positions[:, context_len:seq_len] return llm_positions, mrope_position_delta @@ -1269,8 +1353,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.multimodal_config = multimodal_config - if multimodal_config.get_limit_per_prompt("image") or \ - multimodal_config.get_limit_per_prompt("video"): + if multimodal_config.get_limit_per_prompt( + "image" + ) or multimodal_config.get_limit_per_prompt("video"): self.visual = Qwen2VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), @@ -1288,26 +1373,30 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): if mm_input.ndim == 2: return mm_input if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") + raise ValueError( + f"{name} should be 2D or batched 3D tensor. 
" + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})" + ) return mm_input.reshape(-1, mm_input.shape[-1]) else: return torch.concat(mm_input) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: + self, **kwargs: object + ) -> Optional[Qwen2VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1317,26 +1406,35 @@ def _parse_and_validate_image_input( if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") + pixel_values, "image pixel values" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) - return Qwen2VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") + image_embeds, "image embeds" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) - return Qwen2VLImageEmbeddingInputs(type="image_embeds", - image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + return Qwen2VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: + self, **kwargs: object + ) -> Optional[Qwen2VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1346,9 +1444,11 @@ def _parse_and_validate_video_input( if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") + pixel_values_videos, "video pixel values" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) return Qwen2VLVideoPixelInputs( type="pixel_values_videos", @@ -1358,17 +1458,21 @@ def _parse_and_validate_video_input( if video_embeds is not None: video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") + video_embeds, "video embeds" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) - return Qwen2VLVideoEmbeddingInputs(type="video_embeds", - video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + return Qwen2VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + ) def _process_image_input( - self, image_input: Qwen2VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Qwen2VLImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1379,24 +1483,24 @@ def _process_image_input( pixel_values = image_input["pixel_values"] if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values, - grid_thw_list, - rope_type="rope_3d") + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" + ) else: - 
image_embeds = self.visual(pixel_values, - grid_thw=grid_thw_list) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return image_embeds.split(sizes) def _process_video_input( - self, video_input: Qwen2VLVideoInputs) -> tuple[torch.Tensor, ...]: - + self, video_input: Qwen2VLVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1406,18 +1510,18 @@ def _process_video_input( else: pixel_values_videos = video_input["pixel_values_videos"] if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values_videos, - grid_thw_list, - rope_type="rope_3d") + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" + ) else: - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return video_embeds.split(sizes) @@ -1427,23 +1531,23 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: # Preserve the order of modalities if there are multiple of them # from the order of kwargs. 
for input_key in kwargs: - if input_key in ("pixel_values", - "image_embeds") and "images" not in modalities: - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if input_key in ("pixel_values_videos", - "video_embeds") and "videos" not in modalities: - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "images" not in modalities + ): + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "videos" not in modalities + ): + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: - + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] @@ -1505,9 +1609,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: skip_prefixes.extend(["visual."]) @@ -1530,7 +1632,6 @@ class Tarsier2MultiModalProcessor(Qwen2VLMultiModalProcessor): class Tarsier2ImageProcessor(Qwen2VLImageProcessor): - def __init__( self, size: Optional[dict[str, int]] = None, @@ -1540,7 +1641,7 @@ def __init__( # Remap if Tarsier2-specific format is provided remapped_size = { "shortest_edge": size["min_pixels"], - "longest_edge": size["max_pixels"] + "longest_edge": size["max_pixels"], } super().__init__(size=remapped_size, **kwargs) else: @@ -1548,7 +1649,6 @@ def __init__( class Tarsier2Processor(Qwen2VLProcessor): - def __init__( self, vision_config: dict, @@ -1561,11 +1661,11 @@ def __init__( tokenizer=tokenizer, video_processor=Qwen2VLVideoProcessor(**vision_config), chat_template=None, - **kwargs) + **kwargs, + ) class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo): - def get_hf_config(self) -> Qwen2VLConfig: model_path = self.ctx.model_config.model original_config = AutoConfig.from_pretrained(model_path) @@ -1582,17 +1682,20 @@ def get_hf_processor(self, **kwargs: object) -> Tarsier2Processor: ) def get_image_processor(self) -> Tarsier2ImageProcessor: - return Tarsier2ImageProcessor( - **self.ctx.get_hf_image_processor_config()) + return Tarsier2ImageProcessor(**self.ctx.get_hf_image_processor_config()) -@MULTIMODAL_REGISTRY.register_processor(Tarsier2MultiModalProcessor, - info=Tarsier2ProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + Tarsier2MultiModalProcessor, + info=Tarsier2ProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder, +) class Tarsier2ForConditionalGeneration(Qwen2VLForConditionalGeneration): - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "vision_tower.": "visual.", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "vision_tower.": "visual.", + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Tarsier2 uses llava as model_type, which will create a Qwen2VLConfig @@ -1603,9 +1706,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config = qwen2vl_config super().__init__(vllm_config=vllm_config, 
prefix=prefix) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: skip_prefixes.extend(["visual."]) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index ae72fd30c399..bcd4968ba5c4 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3 model compatible with HuggingFace weights.""" + from collections.abc import Iterable from typing import Any, Optional, Union @@ -35,8 +36,7 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope @@ -46,14 +46,12 @@ from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - maybe_prefix) +from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix logger = init_logger(__name__) class Qwen3Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -131,7 +129,9 @@ def __init__( **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}, + } + if dual_chunk_attention_config + else {}, ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -144,12 +144,10 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # Add qk-norm - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, - self.head_dim) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, - self.head_dim) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) q, k = self.rotary_emb(positions, q, k) @@ -159,7 +157,6 @@ def forward( class Qwen3DecoderLayer(nn.Module): - def __init__( self, config: Qwen3Config, @@ -172,9 +169,9 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) # By default, Qwen3 uses causal attention as it is a decoder-only model. 
# You can override the HF config with `is_causal=False` to enable @@ -192,8 +189,8 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'attention_bias', False), - head_dim=getattr(config, 'head_dim', None), + qkv_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, rope_scaling=rope_scaling, @@ -208,10 +205,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -224,16 +221,14 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -251,13 +246,13 @@ def forward( "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class Qwen3Model(Qwen2Model): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - decoder_layer_type=Qwen3DecoderLayer) + super().__init__( + vllm_config=vllm_config, prefix=prefix, decoder_layer_type=Qwen3DecoderLayer + ) class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): @@ -283,25 +278,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lora_config = lora_config self.quant_config = quant_config - self.model = Qwen3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Qwen3Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers @@ -320,8 +318,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -331,11 +330,9 @@ def 
compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 61f1abad72b6..34b5af846493 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" + import typing from collections.abc import Callable, Iterable from itertools import islice @@ -33,38 +34,51 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class Qwen3MoeMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -76,19 +90,24 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - 
reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -99,7 +118,6 @@ def forward(self, x): class Qwen3MoeSparseMoeBlock(nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -123,7 +141,8 @@ def __init__( if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") + f"the number of experts {config.num_experts}." + ) # Load balancing settings. vllm_config = get_current_vllm_config() @@ -132,36 +151,40 @@ def __init__( self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) - - self.experts = FusedMoE(num_experts=self.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=True, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel) - - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=True, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - assert hidden_states.dim( - ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" + assert hidden_states.dim() <= 2, ( + "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" + ) is_input_1d = hidden_states.dim() == 1 num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -171,21 +194,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) 
+ final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( - final_hidden_states, 0) + final_hidden_states, 0 + ) final_hidden_states = final_hidden_states[:num_tokens] # return to 1d if input is 1d - return final_hidden_states.squeeze(0) if is_input_1d else \ - final_hidden_states + return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states class Qwen3MoeAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -226,19 +249,23 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.dual_chunk_attention_config = dual_chunk_attention_config - self.qkv_proj = QKVParallelLinear(hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) self.rotary_emb = get_rope( self.head_dim, @@ -259,7 +286,9 @@ def __init__( **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}, + } + if dual_chunk_attention_config + else {}, ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) @@ -273,13 +302,11 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # Add qk-norm - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, - self.head_dim) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, - self.head_dim) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) q, k = self.rotary_emb(positions, q, k) @@ -289,7 +316,6 @@ def forward( class Qwen3MoeDecoderLayer(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() @@ -300,11 +326,10 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - dual_chunk_attention_config = getattr(config, - "dual_chunk_attention_config", - None) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + dual_chunk_attention_config = getattr( + config, "dual_chunk_attention_config", None + ) self.self_attn = Qwen3MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -313,8 +338,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, rms_norm_eps=config.rms_norm_eps, - qkv_bias=getattr(config, 'attention_bias', False), - head_dim=getattr(config, 
'head_dim', None), + qkv_bias=getattr(config, "attention_bias", False), + head_dim=getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -323,23 +348,27 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: # `mlp_only_layers` in the config. layer_idx = extract_layer_index(prefix) - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) if (layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (layer_idx + 1) % config.decoder_sparse_step == 0): - self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, - prefix=f"{prefix}.mlp") + config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 + ): + self.mlp = Qwen3MoeSparseMoeBlock( + vllm_config=vllm_config, prefix=f"{prefix}.mlp" + ) else: - self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.mlp = Qwen3MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -352,23 +381,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class Qwen3MoeModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -385,17 +411,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, - prefix=prefix), + lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -420,10 +446,9 @@ def forward( for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = 
layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -435,10 +460,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=self.num_redundant_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -449,15 +474,24 @@ def load_weights(self, weights: Iterable[tuple[str, ] # Skip loading extra parameters for GPTQ/modelopt models. - ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", - ".v_scale", "_v_scale", ".weight_scale", - "_weight_scale", ".input_scale", "_input_scale") + ignore_suffixes = ( + ".bias", + "_bias", + ".k_scale", + "_k_scale", + ".v_scale", + "_v_scale", + ".weight_scale", + "_weight_scale", + ".input_scale", + "_input_scale", + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue @@ -487,8 +521,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) else: @@ -513,23 +546,27 @@ def load_weights(self, weights: Iterable[tuple[str, continue # Skip loading extra parameters for GPTQ/modelopt models. - if name_mapped.endswith( - ignore_suffixes - ) and name_mapped not in params_dict: + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): continue param = params_dict[name_mapped] # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. - weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -541,8 +578,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue # Skip loading extra parameters for GPTQ/modelopt models. - if name.endswith( - ignore_suffixes) and name not in params_dict: + if name.endswith(ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -550,7 +586,8 @@ def load_weights(self, weights: Iterable[tuple[str, # Remapping the name of FP8 kv-scale. 
if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501 @@ -561,15 +598,15 @@ def load_weights(self, weights: Iterable[tuple[str, else: name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, - MixtureOfExperts): +class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -590,17 +627,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Qwen3MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.model = Qwen3MoeModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) # Set MoE hyperparameters self.expert_weights = [] @@ -652,8 +693,7 @@ def update_physical_experts_metadata( assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): moe = layer.mlp @@ -672,8 +712,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -683,8 +724,7 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 14d19874a51e..cea3faf45a14 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
"""Inference-only Qwen3Next model.""" + from collections.abc import Iterable from itertools import islice from typing import Optional @@ -13,40 +14,58 @@ from vllm.attention import Attention, AttentionBackend, AttentionMetadata from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, - VllmConfig, get_current_vllm_config) -from vllm.distributed import (divide, get_ep_group, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.config import ( + CacheConfig, + ModelConfig, + SpeculativeConfig, + VllmConfig, + get_current_vllm_config, +) +from vllm.distributed import ( + divide, + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.fla.ops import ( - RMSNormGated, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) + RMSNormGated, + chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, +) from vllm.model_executor.layers.fused_moe import FusedMoE -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.layers.layernorm import ( - GemmaRMSNorm as Qwen3NextRMSNorm) -# yapf: enable -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.layernorm import GemmaRMSNorm as Qwen3NextRMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - mamba_v2_sharded_weight_loader) +from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( - causal_conv1d_fn, causal_conv1d_update) + causal_conv1d_fn, + causal_conv1d_update, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, sharded_weight_loader) + default_weight_loader, + sharded_weight_loader, +) from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.utils import set_weight_attrs @@ -57,12 +76,22 @@ from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from .interfaces import (HasInnerState, IsHybrid, MixtureOfExperts, - SupportsLoRA, SupportsPP) -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from 
.interfaces import ( + HasInnerState, + IsHybrid, + MixtureOfExperts, + SupportsLoRA, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -70,7 +99,6 @@ class Qwen3NextSparseMoeBlock(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -90,7 +118,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") + f"the number of experts {config.num_experts}." + ) # Load balancing settings. vllm_config = get_current_vllm_config() @@ -99,32 +128,35 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts - self.n_physical_experts = (self.n_logical_experts + - self.n_redundant_experts) + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) - - self.experts = FusedMoE(num_experts=self.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts", - enable_eplb=self.enable_eplb, - num_redundant_experts=self.n_redundant_experts, - is_sequence_parallel=self.is_sequence_parallel) - - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.experts = FusedMoE( + num_experts=self.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate", + ) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen3NextMLP( @@ -132,15 +164,12 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), + reduce_results=self.experts.must_reduce_shared_expert_outputs(), prefix=f"{prefix}.shared_expert", ) else: self.shared_expert = None - self.shared_expert_gate = torch.nn.Linear(config.hidden_size, - 1, - bias=False) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
@@ -155,46 +184,57 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.shared_expert is not None: shared_output = self.shared_expert(hidden_states) if self.shared_expert_gate is not None: - shared_output = F.sigmoid( - self.shared_expert_gate(hidden_states)) * shared_output + shared_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output + ) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( - final_hidden_states, 0) + final_hidden_states, 0 + ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 - final_hidden_states) + final_hidden_states + ) return final_hidden_states.view(orig_shape) class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): - @property def mamba_type(self) -> str: return "linear_attention" def get_attn_backend(self) -> type["AttentionBackend"]: from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend + return GDNAttentionBackend def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - self.model_config.dtype, self.cache_config.mamba_cache_dtype) + self.model_config.dtype, self.cache_config.mamba_cache_dtype + ) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: return MambaStateShapeCalculator.gated_delta_net_state_shape( - self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim, - self.head_v_dim, self.conv_kernel_size, self.num_spec) + self.tp_size, + self.num_k_heads, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + self.conv_kernel_size, + self.num_spec, + ) def __init__( self, @@ -228,8 +268,11 @@ def __init__( self.cache_config = cache_config self.quant_config = quant_config self.speculative_config = speculative_config - self.num_spec = (self.speculative_config.num_speculative_tokens - if self.speculative_config else 0) + self.num_spec = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config + else 0 + ) # QKV self.conv_dim = self.key_dim * 2 + self.value_dim @@ -265,31 +308,36 @@ def __init__( delattr(self.conv1d.weight, "weight_loader") set_weight_attrs( - self.conv1d.weight, { - "weight_loader": - mamba_v2_sharded_weight_loader([ - query_key_settings, - query_key_settings, - value_settings, - ], self.tp_size, self.tp_rank) - }) + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + query_key_settings, + query_key_settings, + value_settings, + ], + self.tp_size, + self.tp_rank, + ) + }, + ) # selective projection used to make dt, B and C input dependant # time step projection (discretization) # instantiate once and copy inv_dt in init_weights of PretrainedModel self.dt_bias = nn.Parameter( - torch.ones(self.num_v_heads // self.tp_size), ) + torch.ones(self.num_v_heads // self.tp_size), + ) self.A_log = nn.Parameter( torch.empty( divide(self.num_v_heads, self.tp_size), dtype=torch.float32, - )) + ) + ) - set_weight_attrs(self.A_log, - {"weight_loader": sharded_weight_loader(0)}) - set_weight_attrs(self.dt_bias, - {"weight_loader": 
sharded_weight_loader(0)}) + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) self.norm = RMSNormGated( self.head_v_dim, @@ -300,12 +348,14 @@ def __init__( dtype=config.torch_dtype, ) - self.out_proj = RowParallelLinear(self.value_dim, - self.hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj") + self.out_proj = RowParallelLinear( + self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: @@ -322,9 +372,13 @@ def fix_query_key_value_ordering( """ new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, - (self.head_k_dim + self.head_k_dim + - (self.head_v_dim + self.head_v_dim) * self.num_v_heads // - self.num_k_heads), + ( + self.head_k_dim + + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) + * self.num_v_heads + // self.num_k_heads + ), ) new_tensor_shape_ba = mixed_qkvz.size()[:-1] + ( self.num_k_heads // self.tp_size, @@ -342,15 +396,13 @@ def fix_query_key_value_ordering( ] split_arg_list_ba = [ self.num_v_heads // self.num_k_heads, - self.num_v_heads // self.num_k_heads + self.num_v_heads // self.num_k_heads, ] # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], # [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] - (query, key, value, z) = torch.split(mixed_qkvz, - split_arg_list_qkvz, - dim=2) + (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] @@ -374,9 +426,10 @@ def rearrange_mixed_qkv(self, mixed_qkv): dim=-1, ) query, key = map( - lambda x: rearrange(x, 'l (h d) -> 1 l h d', d=self.head_k_dim), - (query, key)) - value = rearrange(value, 'l (h d) -> 1 l h d', d=self.head_v_dim) + lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), + (query, key), + ) + value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) return query, key, value def forward( @@ -421,23 +474,23 @@ def _forward( spec_token_masks = spec_token_masks[:num_actual_tokens] # 1. Set up dimensions for reshapes later - projected_states_qkvz, _ = self.in_proj_qkvz( - hidden_states[:num_actual_tokens]) - projected_states_ba, _ = self.in_proj_ba( - hidden_states[:num_actual_tokens]) + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens]) + projected_states_ba, _ = self.in_proj_ba(hidden_states[:num_actual_tokens]) query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba) - query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'), - (query, key, value)) + projected_states_qkvz, projected_states_ba + ) + query, key, value = map( + lambda x: rearrange(x, "l p d -> l (p d)"), (query, key, value) + ) mixed_qkv = torch.cat((query, key, value), dim=-1) # 2. 
Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), - self.conv1d.weight.size(2)) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) if spec_sequence_masks is not None: - if (attn_metadata.num_prefills == 0 - and attn_metadata.num_decodes == 0): + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: mixed_qkv_spec = mixed_qkv mixed_qkv_non_spec = None else: @@ -455,8 +508,9 @@ def _forward( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=spec_state_indices_tensor[:, 0] - [:attn_metadata.num_spec_decodes], + conv_state_indices=spec_state_indices_tensor[:, 0][ + : attn_metadata.num_spec_decodes + ], num_accepted_tokens=num_accepted_tokens, query_start_loc=spec_query_start_loc, max_query_len=spec_state_indices_tensor.size(-1), @@ -486,26 +540,26 @@ def _forward( conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=non_spec_state_indices_tensor[:attn_metadata - .num_decodes], + conv_state_indices=non_spec_state_indices_tensor[ + : attn_metadata.num_decodes + ], validate_data=True, ) else: mixed_qkv_non_spec = None - query_spec, key_spec, value_spec = self.rearrange_mixed_qkv( - mixed_qkv_spec) + query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec) query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv( - mixed_qkv_non_spec) + mixed_qkv_non_spec + ) beta = b.sigmoid() # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) g = fused_gdn_gating(self.A_log, a, self.dt_bias) - g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta)) + g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta)) if spec_sequence_masks is not None: - if (attn_metadata.num_prefills == 0 - and attn_metadata.num_decodes == 0): + if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: g_spec = g beta_spec = beta g_non_spec = None @@ -525,28 +579,25 @@ def _forward( # 3.1: process the mutlti-query part if spec_sequence_masks is not None: - core_attn_out_spec, last_recurrent_state = ( - fused_recurrent_gated_delta_rule( - q=query_spec, - k=key_spec, - v=value_spec, - g=g_spec, - beta=beta_spec, - initial_state=ssm_state, - inplace_final_state=True, - cu_seqlens=spec_query_start_loc[:attn_metadata. - num_spec_decodes + 1], - ssm_state_indices=spec_state_indices_tensor, - num_accepted_tokens=num_accepted_tokens, - use_qk_l2norm_in_kernel=True, - )) + core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( + q=query_spec, + k=key_spec, + v=value_spec, + g=g_spec, + beta=beta_spec, + initial_state=ssm_state, + inplace_final_state=True, + cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], + ssm_state_indices=spec_state_indices_tensor, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=True, + ) else: core_attn_out_spec, last_recurrent_state = None, None # 3.2: process the remaining part if attn_metadata.num_prefills > 0: - initial_state = ssm_state[ - non_spec_state_indices_tensor].contiguous() + initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] 
= 0 ( core_attn_out_non_spec, @@ -565,7 +616,8 @@ def _forward( ) # Init cache ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( - ssm_state.dtype) + ssm_state.dtype + ) elif attn_metadata.num_decodes > 0: core_attn_out_non_spec, last_recurrent_state = ( fused_recurrent_gated_delta_rule( @@ -576,17 +628,18 @@ def _forward( beta=beta_non_spec, initial_state=ssm_state, inplace_final_state=True, - cu_seqlens=non_spec_query_start_loc[:attn_metadata. - num_decodes + 1], + cu_seqlens=non_spec_query_start_loc[ + : attn_metadata.num_decodes + 1 + ], ssm_state_indices=non_spec_state_indices_tensor, use_qk_l2norm_in_kernel=True, - )) + ) + ) else: core_attn_out_non_spec, last_recurrent_state = None, None # Merge core attention output - if (spec_sequence_masks is not None - and core_attn_out_non_spec is not None): + if spec_sequence_masks is not None and core_attn_out_non_spec is not None: core_attn_out = torch.empty( (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), dtype=core_attn_out_non_spec.dtype, @@ -605,13 +658,12 @@ def _forward( z = z.reshape(-1, z.shape[-1]) core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)') + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") output[:num_actual_tokens], _ = self.out_proj(core_attn_out) class Qwen3NextAttention(nn.Module): - def __init__( self, config: Qwen3NextConfig, @@ -642,7 +694,8 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.dual_chunk_attention_config = getattr( - config, "dual_chunk_attention_config", None) + config, "dual_chunk_attention_config", None + ) self.attn_output_gate = getattr(config, "attn_output_gate", True) self.qkv_proj = QKVParallelLinear( @@ -683,9 +736,10 @@ def __init__( prefix=f"{prefix}.attn", **{ "layer_idx": extract_layer_index(prefix), - "dual_chunk_attention_config": - self.dual_chunk_attention_config, - } if self.dual_chunk_attention_config else {}, + "dual_chunk_attention_config": self.dual_chunk_attention_config, + } + if self.dual_chunk_attention_config + else {}, ) self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) @@ -701,20 +755,22 @@ def forward( if self.attn_output_gate: q_gate, k, v = qkv.split( - [self.q_size * 2, self.kv_size, self.kv_size], dim=-1) + [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 + ) orig_shape = q_gate.shape[:-1] q_gate = q_gate.view(*orig_shape, self.num_heads, -1) q, gate = torch.chunk(q_gate, 2, dim=-1) q = q.reshape(*orig_shape, -1) gate = gate.reshape(*orig_shape, -1) else: - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view( - -1, self.num_heads * self.head_dim) + -1, self.num_heads * self.head_dim + ) k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view( - -1, self.num_kv_heads * self.head_dim) + -1, self.num_kv_heads * self.head_dim + ) q, k = self.rotary_emb(positions, q, k) @@ -728,7 +784,6 @@ def forward( class Qwen3NextDecoderLayer(nn.Module): - def __init__( self, vllm_config: VllmConfig, @@ -753,23 +808,26 @@ def __init__( cache_config=cache_config, quant_config=quant_config, speculative_config=speculative_config, - prefix=f'{prefix}.linear_attn') + prefix=f"{prefix}.linear_attn", + ) elif self.layer_type == "full_attention": self.self_attn = Qwen3NextAttention( config, 
model_config=model_config, cache_config=cache_config, quant_config=quant_config, - prefix=f'{prefix}.self_attn', + prefix=f"{prefix}.self_attn", ) else: raise ValueError(f"Invalid layer_type {self.layer_type}") - mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else - config.mlp_only_layers) + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) if (self.layer_idx not in mlp_only_layers) and ( - config.num_experts > 0 and - (self.layer_idx + 1) % config.decoder_sparse_step == 0): + config.num_experts > 0 + and (self.layer_idx + 1) % config.decoder_sparse_step == 0 + ): self.mlp = Qwen3NextSparseMoeBlock( vllm_config=vllm_config, prefix=f"{prefix}.mlp", @@ -782,10 +840,12 @@ def __init__( quant_config=quant_config, ) - self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.post_attention_layernorm = Qwen3NextRMSNorm( - config.hidden_size, eps=config.rms_norm_eps) + config.hidden_size, eps=config.rms_norm_eps + ) self.layer_scale = getattr(config, "layer_scale", False) if self.layer_scale: @@ -795,14 +855,16 @@ def __init__( 1, config.hidden_size, dtype=config.torch_dtype, - ), ) + ), + ) self.ffn_layer_scale = torch.nn.Parameter( torch.zeros( 1, 1, config.hidden_size, dtype=config.torch_dtype, - ), ) + ), + ) def forward( self, @@ -815,8 +877,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) self_attention_output = torch.empty_like(hidden_states) if self.layer_type == "linear_attention": @@ -837,33 +898,36 @@ def forward( if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( - self.attn_layer_scale.to(hidden_states.dtype)[0] + 1) + self.attn_layer_scale.to(hidden_states.dtype)[0] + 1 + ) else: hidden_states = hidden_states * ( - self.attn_layer_scale.to(hidden_states.dtype) + 1) + self.attn_layer_scale.to(hidden_states.dtype) + 1 + ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) if self.layer_scale: if len(hidden_states.shape) == 2: hidden_states = hidden_states * ( - self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1) + self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1 + ) else: - assert len(hidden_states.shape) == len( - self.ffn_layer_scale.shape - ), f'shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}' # noqa: E501 + assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), ( + f"shape must be the same {len(hidden_states.shape)}, " + f"{len(self.ffn_layer_scale.shape)}" + ) hidden_states = hidden_states * ( - self.ffn_layer_scale.to(hidden_states.dtype) + 1) + self.ffn_layer_scale.to(hidden_states.dtype) + 1 + ) return hidden_states, residual @support_torch_compile class Qwen3NextModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -874,8 +938,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if 
lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.embed_tokens = VocabParallelEmbedding( @@ -892,14 +959,14 @@ def get_layer(prefix: str): ) self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers") - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: - self.norm = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() @@ -932,10 +999,9 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -947,10 +1013,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=self.num_redundant_experts, + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -1001,16 +1067,19 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. 
@@ -1019,15 +1088,17 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params -class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, - MixtureOfExperts, IsHybrid): +class Qwen3NextForCausalLM( + nn.Module, HasInnerState, SupportsLoRA, SupportsPP, MixtureOfExperts, IsHybrid +): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1044,15 +1115,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ + assert not cache_config.enable_prefix_caching, ( "Qwen3Next currently does not support prefix caching" + ) self.quant_config = vllm_config.quant_config super().__init__() self.config = config self.scheduler_config = scheduler_config - self.model = Qwen3NextModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Qwen3NextModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -1063,12 +1136,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, - prefix=maybe_prefix(prefix, "lm_head")) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + if not lora_config + else lora_config.lora_vocab_padding_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) # Set MoE hyperparameters self.expert_weights = [] @@ -1120,8 +1197,7 @@ def update_physical_experts_metadata( assert self.num_local_physical_experts == num_local_physical_experts self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = (num_physical_experts - - self.num_logical_experts) + self.num_redundant_experts = num_physical_experts - self.num_logical_experts for layer in self.model.layers: if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): moe = layer.mlp @@ -1141,8 +1217,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ): - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states @@ -1152,23 +1229,30 @@ def get_mamba_state_dtype_from_config( vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: return MambaStateDtypeCalculator.gated_delta_net_state_dtype( - vllm_config.model_config.dtype, - vllm_config.cache_config.mamba_cache_dtype) + vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype + ) @classmethod def get_mamba_state_shape_from_config( - cls, vllm_config: "VllmConfig" + cls, vllm_config: "VllmConfig" ) -> 
tuple[tuple[int, int], tuple[int, int]]: parallel_config = vllm_config.parallel_config hf_config = vllm_config.model_config.hf_config tp_size = parallel_config.tensor_parallel_size - num_spec = (vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config else 0) + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) return MambaStateShapeCalculator.gated_delta_net_state_shape( - tp_size, hf_config.linear_num_key_heads, - hf_config.linear_num_value_heads, hf_config.linear_key_head_dim, - hf_config.linear_value_head_dim, hf_config.linear_conv_kernel_dim, - num_spec) + tp_size, + hf_config.linear_num_key_heads, + hf_config.linear_num_value_heads, + hf_config.linear_key_head_dim, + hf_config.linear_value_head_dim, + hf_config.linear_conv_kernel_dim, + num_spec, + ) def compute_logits( self, @@ -1176,8 +1260,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.logits_processor(self.lm_head, hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=["mtp."], @@ -1236,8 +1319,9 @@ def fused_gdn_gating_kernel( blk_bias = tl.load(dt_bias + head_off, mask=mask) # If the model is loaded in fp16, without the .float() here, A might be -inf x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) - softplus_x = tl.where(beta * x <= threshold, - (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) @@ -1253,14 +1337,7 @@ def fused_gdn_gating( seq_len = 1 grid = (batch, seq_len, triton.cdiv(num_heads, 8)) g = torch.empty_like(a, dtype=torch.float32) - fused_gdn_gating_kernel[grid](g, - A_log, - a, - dt_bias, - seq_len, - num_heads, - beta, - threshold, - 8, - num_warps=1) + fused_gdn_gating_kernel[grid]( + g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1 + ) return g diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index e950699a0c49..828931716c8f 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Qwen3Next MTP model.""" + from collections.abc import Iterable from typing import Optional @@ -15,16 +16,25 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer, - Qwen3NextRMSNorm) +from vllm.model_executor.models.qwen3_next import ( + Qwen3NextDecoderLayer, + Qwen3NextRMSNorm, +) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - 
make_empty_intermediate_tensors_factory, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) logger = init_logger(__name__) @@ -33,7 +43,6 @@ @support_torch_compile class Qwen3NextMultiTokenPredictor(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -43,8 +52,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config: Qwen3NextConfig = model_config.hf_config self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -57,31 +69,36 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=config.vocab_size, ) - self.fc = ColumnParallelLinear(self.config.hidden_size * 2, - self.config.hidden_size, - gather_output=True, - bias=False, - return_bias=False, - quant_config=quant_config, - prefix=f'{prefix}.fc') + self.fc = ColumnParallelLinear( + self.config.hidden_size * 2, + self.config.hidden_size, + gather_output=True, + bias=False, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fc", + ) self.layers = torch.nn.ModuleList( Qwen3NextDecoderLayer( vllm_config, layer_type="full_attention", - prefix=f'{prefix}.layers.{idx}', - ) for idx in range(self.num_mtp_layers)) + prefix=f"{prefix}.layers.{idx}", + ) + for idx in range(self.num_mtp_layers) + ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) - self.norm = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_fc_norm_hidden = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_fc_norm_embedding = Qwen3NextRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -109,7 +126,7 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - current_step_idx = (spec_step_idx % self.num_mtp_layers) + current_step_idx = spec_step_idx % self.num_mtp_layers hidden_states, residual = self.layers[current_step_idx]( positions=positions, hidden_states=hidden_states, @@ -117,16 +134,14 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -142,7 +157,8 @@ def 
load_weights(self, weights: Iterable[tuple[str, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) + num_experts=self.config.num_experts, + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -180,16 +196,19 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. - if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: # Skip loading extra bias for GPTQ models. @@ -199,8 +218,9 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -214,33 +234,38 @@ class Qwen3NextMTP(nn.Module, SupportsPP): "k_proj", "v_proj", ], - "gate_up_proj": ["up_proj", "down_proj"] + "gate_up_proj": ["up_proj", "down_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config cache_config = vllm_config.cache_config - assert not cache_config.enable_prefix_caching, \ + assert not cache_config.enable_prefix_caching, ( "Qwen3NextMTP currently does not support prefix caching" + ) self.quant_config = vllm_config.quant_config super().__init__() self.config = config - self.model = Qwen3NextMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "mtp")) + self.model = Qwen3NextMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp") + ) self.unpadded_vocab_size = config.vocab_size - self.lm_head = ParallelLMHead(self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - prefix=maybe_prefix(prefix, "lm_head")) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -254,8 +279,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ): - hidden_states = self.model(input_ids, positions, hidden_states, - intermediate_tensors, inputs_embeds) + hidden_states = self.model( + input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -265,8 +291,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.logits_processor(self.lm_head, hidden_states) - def load_weights(self, weights: 
Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: shared_weight_names = ["embed_tokens", "lm_head"] def remap_weight_names(weights): diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 822c8d6d5f30..1c532376256d 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -23,6 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen3VL model compatible with HuggingFace weights.""" + from collections.abc import Iterable, Mapping, Sequence from functools import partial from typing import Any, Callable, Optional, Union @@ -34,13 +35,16 @@ from transformers import BatchFeature from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( - smart_resize as image_smart_resize) -from transformers.models.qwen3_vl import (Qwen3VLProcessor, - Qwen3VLVideoProcessor) + smart_resize as image_smart_resize, +) +from transformers.models.qwen3_vl import Qwen3VLProcessor, Qwen3VLVideoProcessor from transformers.models.qwen3_vl.configuration_qwen3_vl import ( - Qwen3VLConfig, Qwen3VLVisionConfig) + Qwen3VLConfig, + Qwen3VLVisionConfig, +) from transformers.models.qwen3_vl.video_processing_qwen3_vl import ( - smart_resize as video_smart_resize) + smart_resize as video_smart_resize, +) from transformers.video_utils import VideoMetadata from vllm.attention.backends.registry import _Backend @@ -51,38 +55,56 @@ from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItem, - MultiModalKwargsItems, VideoItem) -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItem, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .qwen2_5_vl import (Qwen2_5_VisionAttention, - Qwen2_5_VisionRotaryEmbedding, - Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, - Qwen2_5_VLImagePixelInputs, - Qwen2_5_VLVideoEmbeddingInputs, 
Qwen2_5_VLVideoInputs, - Qwen2_5_VLVideoPixelInputs) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .qwen2_5_vl import ( + Qwen2_5_VisionAttention, + Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLImageEmbeddingInputs, + Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLVideoEmbeddingInputs, + Qwen2_5_VLVideoInputs, + Qwen2_5_VLVideoPixelInputs, +) from .qwen2_vl import Qwen2VLProcessingInfo from .qwen3 import Qwen3ForCausalLM, Qwen3Model -from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - _merge_multimodal_embeddings, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model logger = init_logger(__name__) @@ -92,7 +114,6 @@ class Qwen3_VisionPatchEmbed(nn.Module): - def __init__( self, patch_size: int = 14, @@ -106,45 +127,51 @@ def __init__( self.hidden_size = hidden_size kernel_size = (temporal_patch_size, patch_size, patch_size) - self.proj = nn.Conv3d(in_channels, - hidden_size, - kernel_size=kernel_size, - stride=kernel_size, - bias=True) + self.proj = nn.Conv3d( + in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: L, C = x.shape - x = x.view(L, -1, self.temporal_patch_size, self.patch_size, - self.patch_size) + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size) x = self.proj(x).view(L, self.hidden_size) return x class Qwen3_VisionMLP(nn.Module): - - def __init__(self, - in_features: int, - hidden_features: int, - bias: bool = False, - act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() - self.linear_fc1 = ColumnParallelLinear(in_features, - hidden_features, - bias=bias, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.linear_fc1", - disable_tp=use_data_parallel) - self.linear_fc2 = RowParallelLinear(hidden_features, - in_features, - bias=bias, - quant_config=quant_config, - return_bias=False, - prefix=f"{prefix}.linear_fc2", - disable_tp=use_data_parallel) + self.linear_fc1 = ColumnParallelLinear( + in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc1", + disable_tp=use_data_parallel, + ) + self.linear_fc2 = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc2", + disable_tp=use_data_parallel, + ) self.act_fn = act_fn def forward(self, x: torch.Tensor): @@ -153,7 +180,6 @@ def forward(self, x: torch.Tensor): class Qwen3_VisionBlock(nn.Module): - def __init__( self, dim: int, @@ -180,35 +206,39 @@ def __init__( prefix=f"{prefix}.attn", use_data_parallel=use_data_parallel, attn_backend=attn_backend, - use_upstream_fa=use_upstream_fa) - self.mlp = Qwen3_VisionMLP(dim, - mlp_hidden_dim, - act_fn=act_fn, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) + use_upstream_fa=use_upstream_fa, 
+ ) + self.mlp = Qwen3_VisionMLP( + dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor, - max_seqlen: Optional[int] = None, # Only used for Flash Attention - seqlens: Optional[list[int]] = None, # Only used for xFormers + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - x = x + self.attn(self.norm1(x), - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - max_seqlen=max_seqlen, - seqlens=seqlens) + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) x = x + self.mlp(self.norm2(x)) return x class Qwen3_VisionPatchMerger(nn.Module): - def __init__( self, d_model: int, @@ -230,19 +260,23 @@ def __init__( if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm = norm_layer(context_dim) - self.linear_fc1 = ColumnParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.linear_fc1", - disable_tp=use_data_parallel) + self.linear_fc1 = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc1", + disable_tp=use_data_parallel, + ) self.act_fn = nn.GELU() - self.linear_fc2 = RowParallelLinear(self.hidden_size, - d_model, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.linear_fc2", - disable_tp=use_data_parallel) + self.linear_fc2 = RowParallelLinear( + self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc2", + disable_tp=use_data_parallel, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_postshuffle_norm: @@ -257,7 +291,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Qwen3_VisionTransformer(nn.Module): - def __init__( self, vision_config: Qwen3VLVisionConfig, @@ -280,8 +313,9 @@ def __init__( # NOTE: This is used for creating empty tensor for all_gather for # DP ViT. 
Here out_hidden_size is enlarged due to deepstack - self.out_hidden_size = (vision_config.out_hidden_size * - (1 + len(self.deepstack_visual_indexes))) + self.out_hidden_size = vision_config.out_hidden_size * ( + 1 + len(self.deepstack_visual_indexes) + ) self.patch_embed = Qwen3_VisionPatchEmbed( patch_size=self.patch_size, @@ -290,8 +324,7 @@ def __init__( hidden_size=self.hidden_size, ) - self.pos_embed = nn.Embedding(self.num_position_embeddings, - self.hidden_size) + self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads @@ -307,50 +340,61 @@ def __init__( use_data_parallel=use_data_parallel, ) - self.deepstack_merger_list = nn.ModuleList([ - Qwen3_VisionPatchMerger( - d_model=vision_config.out_hidden_size, - context_dim=self.hidden_size, - spatial_merge_size=self.spatial_merge_size, - use_postshuffle_norm=True, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", - use_data_parallel=use_data_parallel) - for layer_idx in range(len(self.deepstack_visual_indexes)) - ]) + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(len(self.deepstack_visual_indexes)) + ] + ) self.attn_backend = get_vit_attn_backend( - head_size=head_dim, dtype=torch.get_default_dtype()) + head_size=head_dim, dtype=torch.get_default_dtype() + ) use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - self.attn_backend != _Backend.ROCM_AITER_FA and \ - check_upstream_fa_availability( - torch.get_default_dtype()): + if ( + self.attn_backend != _Backend.FLASH_ATTN + and self.attn_backend != _Backend.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype()) + ): self.attn_backend = _Backend.FLASH_ATTN use_upstream_fa = True if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( - f"Qwen3-VL does not support {self.attn_backend} backend now.") - - self.blocks = nn.ModuleList([ - Qwen3_VisionBlock( - dim=self.hidden_size, - num_heads=self.num_heads, - mlp_hidden_dim=vision_config.intermediate_size, - act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}", - use_data_parallel=use_data_parallel, - attn_backend=self.attn_backend, - use_upstream_fa=use_upstream_fa) - for layer_idx in range(vision_config.depth) - ]) + f"Qwen3-VL does not support {self.attn_backend} backend now." 
+ ) + + self.blocks = nn.ModuleList( + [ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend=self.attn_backend, + use_upstream_fa=use_upstream_fa, + ) + for layer_idx in range(vision_config.depth) + ] + ) @property def dtype(self) -> torch.dtype: @@ -389,32 +433,25 @@ def rot_pos_emb(self, grid_thw): ) wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def fast_pos_embed_interpolate(self, - grid_thw: list[list[int]]) -> torch.Tensor: - + def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: num_grid_per_side = self.num_grid_per_side m_size = self.spatial_merge_size hidden_dim = self.pos_embed.embedding_dim outputs = [] for t, h, w in grid_thw: - h_idxs = torch.linspace(0, - num_grid_per_side - 1, - h, - dtype=torch.float32, - device=self.device) - w_idxs = torch.linspace(0, - num_grid_per_side - 1, - w, - dtype=torch.float32, - device=self.device) + h_idxs = torch.linspace( + 0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device + ) + w_idxs = torch.linspace( + 0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device + ) h_floor = h_idxs.to(torch.long) w_floor = w_idxs.to(torch.long) @@ -425,13 +462,9 @@ def fast_pos_embed_interpolate(self, dw = w_idxs - w_floor # Create meshgrid view for all h, w vars - dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij') - h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, - w_floor, - indexing='ij') - h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, - w_ceil, - indexing='ij') + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij") + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij") + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij") h_floor_grid_idx = h_floor_grid * num_grid_per_side h_ceil_grid_idx = h_ceil_grid * num_grid_per_side @@ -452,10 +485,8 @@ def fast_pos_embed_interpolate(self, idx10 = h_ceil_grid_idx + w_floor_grid idx11 = h_ceil_grid_idx + w_ceil_grid - indices = torch.stack([idx00, idx01, idx10, idx11], - dim=0).reshape(4, -1) - weights = torch.stack([w00, w01, w10, w11], - dim=0).reshape(4, -1, 1) + indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1) + weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1) weights = weights.to(dtype=self.dtype, device=self.device) embeds = self.pos_embed(indices) @@ -465,10 +496,10 @@ def fast_pos_embed_interpolate(self, combined = combined.view(h * w, hidden_dim) repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous() - repeated = repeated.view(t, h // m_size, m_size, w // m_size, - m_size, hidden_dim) - repeated = repeated.permute(0, 1, 3, 2, 4, - 5).reshape(-1, hidden_dim) + repeated = repeated.view( + t, h // m_size, m_size, w // m_size, m_size, hidden_dim + ) + repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim) outputs.append(repeated) return torch.cat(outputs, dim=0) @@ -478,8 +509,10 @@ def compute_attn_mask_seqlen( 
cu_seqlens: torch.Tensor, ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if (self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA): + if ( + self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA + ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -497,17 +530,14 @@ def forward( hidden_states = hidden_states + pos_embeds rotary_pos_emb = self.rot_pos_emb(grid_thw) - grid_thw_tensor = torch.tensor(grid_thw, - device=self.device, - dtype=torch.int32) + grid_thw_tensor = torch.tensor(grid_thw, device=self.device, dtype=torch.int32) cu_seqlens = torch.repeat_interleave( - grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], - grid_thw_tensor[:, 0]).cumsum( - dim=0, - dtype=grid_thw_tensor.dtype - if torch.jit.is_tracing() else torch.int32, - ) + grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], grid_thw_tensor[:, 0] + ).cumsum( + dim=0, + dtype=grid_thw_tensor.dtype if torch.jit.is_tracing() else torch.int32, + ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) hidden_states = hidden_states.unsqueeze(1) @@ -516,25 +546,26 @@ def forward( deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): - hidden_states = blk(hidden_states, - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - max_seqlen=max_seqlen, - seqlens=seqlens) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) if layer_num in self.deepstack_visual_indexes: - deepstack_merger_idx = self.deepstack_visual_indexes.index( - layer_num) - deepstack_feature = self.deepstack_merger_list[ - deepstack_merger_idx](hidden_states) + deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) + deepstack_feature = self.deepstack_merger_list[deepstack_merger_idx]( + hidden_states + ) deepstack_feature_lists.append(deepstack_feature) hidden_states = self.merger(hidden_states) hidden_states = torch.cat( - [hidden_states] + deepstack_feature_lists, - dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + [hidden_states] + deepstack_feature_lists, dim=1 + ) # [seq_len, hidden_size * (1 + depth_of_deepstack)] return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("attn.qkv.", "attn.q.", "q"), @@ -545,7 +576,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -556,15 +587,13 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen3VLConfig) @@ -578,8 +607,7 @@ def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor: def get_tokenizer(self): 
return self.ctx.tokenizer - def get_image_processor(self, - **kwargs: object) -> Qwen2VLImageProcessorFast: + def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessorFast: return self.get_hf_processor(**kwargs).image_processor def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor: @@ -592,8 +620,9 @@ def _get_vision_info( image_height: int, num_frames: int = 2, do_resize: bool = True, - image_processor: Optional[Union[Qwen2VLImageProcessorFast, - Qwen3VLVideoProcessor]], + image_processor: Optional[ + Union[Qwen2VLImageProcessorFast, Qwen3VLVideoProcessor] + ], ) -> tuple[ImageSize, int]: if image_processor is None and num_frames > 1: image_processor = self.get_video_processor() @@ -613,7 +642,7 @@ def _get_vision_info( smart_resize = video_smart_resize extra_kwargs = { "num_frames": num_frames, - "temporal_factor": temporal_patch_size + "temporal_factor": temporal_patch_size, } else: smart_resize = image_smart_resize @@ -626,11 +655,9 @@ def _get_vision_info( max_pixels=image_processor.size["longest_edge"], **extra_kwargs, ) - preprocessed_size = ImageSize(width=resized_width, - height=resized_height) + preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: - preprocessed_size = ImageSize(width=image_width, - height=image_height) + preprocessed_size = ImageSize(width=image_width, height=image_height) padded_num_frames = num_frames + num_frames % temporal_patch_size @@ -643,11 +670,10 @@ def _get_vision_info( return preprocessed_size, num_vision_tokens - def _get_max_video_frames(self, - max_tokens: int, - start_num_frames: int = 2) -> int: - return super()._get_max_video_frames(max_tokens, - start_num_frames=start_num_frames) + def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 2) -> int: + return super()._get_max_video_frames( + max_tokens, start_num_frames=start_num_frames + ) def get_num_frames_with_most_features( self, @@ -655,7 +681,8 @@ def get_num_frames_with_most_features( mm_counts: Mapping[str, int], ) -> int: return super().get_num_frames_with_most_features( - seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO) + seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO + ) def get_max_video_tokens( self, @@ -666,8 +693,7 @@ def get_max_video_tokens( video_soft_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), + num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), image_processor=None, ) @@ -676,25 +702,28 @@ def get_max_video_tokens( formatted_video_soft_tokens = video_soft_tokens * 12.5 return int(formatted_video_soft_tokens) - def _calculate_timestamps(self, indices: list[int] | torch.Tensor, - video_fps: float, merge_size: int): + def _calculate_timestamps( + self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int + ): if not isinstance(indices, list): indices = indices.tolist() if len(indices) % merge_size != 0: # don't update metadata's frames_indices directly - indices = indices + [indices[-1] - ] * (merge_size - len(indices) % merge_size) + indices = indices + [indices[-1]] * (merge_size - len(indices) % merge_size) timestamps = [idx / video_fps for idx in indices] - timestamps = [(timestamps[i] + timestamps[i + merge_size - 1]) / 2 - for i in range(0, len(timestamps), merge_size)] + timestamps = [ + (timestamps[i] + timestamps[i + merge_size - 1]) / 2 + for i in range(0, len(timestamps), merge_size) + ] return timestamps 
def _get_video_second_idx( - self, - metadata: dict[str, Any], - out_item: MultiModalKwargsItem, - do_sample_frames: Optional[bool] = None, - sampled_fps: Optional[float] = None) -> list[int]: + self, + metadata: dict[str, Any], + out_item: MultiModalKwargsItem, + do_sample_frames: Optional[bool] = None, + sampled_fps: Optional[float] = None, + ) -> list[int]: video_processor = self.get_video_processor() merge_size = video_processor.merge_size indices = metadata["frames_indices"] @@ -714,16 +743,23 @@ def _get_video_second_idx( total_num_frames = metadata["total_num_frames"] num_frames = int(total_num_frames / metadata["fps"] * video_fps) num_frames = min( - min(max(num_frames, video_processor.min_frames), - video_processor.max_frames), total_num_frames) - indices = np.linspace(0, total_num_frames - 1, - num_frames).round().astype(int).tolist() + min( + max(num_frames, video_processor.min_frames), + video_processor.max_frames, + ), + total_num_frames, + ) + indices = ( + np.linspace(0, total_num_frames - 1, num_frames) + .round() + .astype(int) + .tolist() + ) timestamps = self._calculate_timestamps(indices, video_fps, merge_size) return timestamps class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) @@ -744,10 +780,10 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None video_overrides = mm_options.get("video") if mm_options else None - target_width, target_height = ( - self.info.get_image_size_with_most_features()) + target_width, target_height = self.info.get_image_size_with_most_features() target_num_frames = self.info.get_num_frames_with_most_features( - seq_len, mm_counts) + seq_len, mm_counts + ) if video_overrides: assert isinstance(video_overrides, VideoDummyOptions) @@ -757,11 +793,15 @@ def get_dummy_mm_data( logger.warning( "video.num_frames override (%d) exceeds model's " "maximum number of frames (%d), will be ignored", - num_frames_override, target_num_frames) + num_frames_override, + target_num_frames, + ) if num_frames_override < 2: logger.warning( "video.num_frames override (%d) cannot be less " - "than 2, will be ignored", num_frames_override) + "than 2, will be ignored", + num_frames_override, + ) target_num_frames = min(target_num_frames, num_frames_override) target_num_frames = max(target_num_frames, 2) @@ -781,8 +821,10 @@ def get_dummy_mm_data( if width_override > width: logger.warning( "video.width override (%d) exceeds model's " - "maximum width (%d), will be ignored", width_override, - width) + "maximum width (%d), will be ignored", + width_override, + width, + ) width = min(width, width_override) height_override = video_overrides.height if height_override: @@ -790,17 +832,19 @@ def get_dummy_mm_data( logger.warning( "video.height override (%d) exceeds model's " "maximum height (%d), will be ignored", - height_override, height) + height_override, + height, + ) height = min(height, height_override) return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), - "video": - self._get_dummy_videos( + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), + "video": self._get_dummy_videos( width=width, height=height, num_frames=target_num_frames, @@ -832,9 +876,7 @@ def _get_dummy_videos( return 
video_items -class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo] - ): - +class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: return MultiModalDataParser(video_needs_metadata=True) @@ -850,8 +892,11 @@ def _call_hf_processor( # Separate video processing from image processing. Because the videos # are processed into serval image patches - if ("videos" in mm_data and isinstance(mm_data["videos"], list) - and len(mm_data["videos"]) > 0): + if ( + "videos" in mm_data + and isinstance(mm_data["videos"], list) + and len(mm_data["videos"]) > 0 + ): video_grid_thw_lst = [] pixel_values_videos_lst = [] @@ -870,12 +915,12 @@ def _call_hf_processor( # qwen_vl_utils already has "do_sample_frames" in # mm_kwargs, don't overwrite it. video_mm_kwargs["do_sample_frames"] = metadata.get( - "do_sample_frames", False) + "do_sample_frames", False + ) - metadata = VideoMetadata(**{ - k: metadata[k] - for k in metadata if k != "do_sample_frames" - }) + metadata = VideoMetadata( + **{k: metadata[k] for k in metadata if k != "do_sample_frames"} + ) video_mm_data = dict() video_mm_data["videos"] = [[video_array]] @@ -888,8 +933,7 @@ def _call_hf_processor( tok_kwargs=tok_kwargs, ) input_ids = video_outputs.pop("input_ids") - video_placeholder = processor.tokenizer.batch_decode( - input_ids)[0] + video_placeholder = processor.tokenizer.batch_decode(input_ids)[0] prompt = prompt.replace( "<|vision_start|><|video_pad|><|vision_end|>", video_placeholder, @@ -897,8 +941,7 @@ def _call_hf_processor( ) video_grid_thw_lst.append(video_outputs["video_grid_thw"]) - pixel_values_videos_lst.append( - video_outputs["pixel_values_videos"]) + pixel_values_videos_lst.append(video_outputs["pixel_values_videos"]) video_outputs = dict( pixel_values_videos=torch.cat(pixel_values_videos_lst), video_grid_thw=torch.cat(video_grid_thw_lst), @@ -931,14 +974,18 @@ def _get_mm_fields_config( return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), + "image", image_grid_sizes + ), image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), + "image", image_grid_sizes + ), image_grid_thw=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), + "video", video_grid_sizes + ), video_grid_thw=MultiModalFieldConfig.batched("video"), ) @@ -949,8 +996,7 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - image_processor = self.info.get_image_processor( - **hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs) tokenizer = self.info.get_tokenizer() hf_config = self.info.get_hf_config() @@ -979,26 +1025,28 @@ def get_video_replacement_qwen3vl(item_idx: int): if is_list_of(sampled_fps, float): sampled_fps = sampled_fps[item_idx] timestamps = self.info._get_video_second_idx( - metadata, out_item, do_sample_frames, sampled_fps) + metadata, out_item, do_sample_frames, sampled_fps + ) assert len(timestamps) == grid_thw[0], ( f"The timestamps length({len(timestamps)}) should be equal " - f"video length ({grid_thw[0]}).") + f"video length ({grid_thw[0]})." 
+ ) frames_idx_token = [ - tokenizer.encode(f"<{curr_time:.1f} seconds>", - add_special_tokens=False) + tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) for curr_time in timestamps ] num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length placeholder = [] for frame_idx in frames_idx_token: placeholder.extend(frame_idx) - placeholder.extend([vision_start_token_id] + - [video_token_id] * num_tokens_per_frame + - [vision_end_token_id]) - return PromptUpdateDetails.select_token_id(placeholder, - video_token_id) + placeholder.extend( + [vision_start_token_id] + + [video_token_id] * num_tokens_per_frame + + [vision_end_token_id] + ) + return PromptUpdateDetails.select_token_id(placeholder, video_token_id) return [ PromptReplacement( @@ -1006,7 +1054,6 @@ def get_video_replacement_qwen3vl(item_idx: int): target=hf_processor.image_token, replacement=get_image_replacement_qwen3vl, ), - # NOTE: We match string on purpose since searching sequence of # token ids takes more time. PromptReplacement( @@ -1026,18 +1073,19 @@ def get_video_replacement_qwen3vl(item_idx: int): "intermediate_tensors": 0, "inputs_embeds": 0, # the same shape as input_embeds - "deepstack_input_embeds": 0 - }) + "deepstack_input_embeds": 0, + } +) class Qwen3LLMModel(Qwen3Model): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) if not get_pp_group().is_first_rank: assert self.start_layer >= len( - vllm_config.model_config.hf_config.vision_config. - deepstack_visual_indexes), ( - "start_layer should be greater than or equal to " - "len(deepstack_visual_indexes)") + vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes + ), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)" + ) def forward( self, @@ -1059,7 +1107,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] for layer_idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + self.layers[self.start_layer : self.end_layer] + ): layer_idx = layer_idx + self.start_layer hidden_states, residual = layer( @@ -1068,22 +1117,23 @@ def forward( residual, ) - if deepstack_input_embeds is not None and \ - layer_idx in range(0, len(deepstack_input_embeds)): - hidden_states = hidden_states + deepstack_input_embeds[ - f"deepstack_input_embeds_{layer_idx}"] + if deepstack_input_embeds is not None and layer_idx in range( + 0, len(deepstack_input_embeds) + ): + hidden_states = ( + hidden_states + + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class Qwen3LLMForCausalLM(Qwen3ForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super(Qwen3ForCausalLM, self).__init__() config = vllm_config.model_config.hf_config.text_config @@ -1100,24 +1150,30 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix="lm_head") + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + 
prefix="lm_head", + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) -class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsLoRA, SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=Qwen3VLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) +class Qwen3VLForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP +): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1138,7 +1194,8 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, "model.visual.": "visual.", "lm_head.": "language_model.lm_head.", "model.language_model.": "language_model.model.", - }) + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -1158,8 +1215,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - if not multimodal_config.get_limit_per_prompt("image") and \ - not multimodal_config.get_limit_per_prompt("video"): + if not multimodal_config.get_limit_per_prompt( + "image" + ) and not multimodal_config.get_limit_per_prompt("video"): self.visual = None else: self.visual = Qwen3_VisionTransformer( @@ -1170,25 +1228,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): use_data_parallel=self.use_data_parallel, ) - self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, - "language_model")) + self.language_model = Qwen3LLMForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - self.use_deepstack = hasattr(config.vision_config, - 'deepstack_visual_indexes') - self.deepstack_num_level = len( - config.vision_config.deepstack_visual_indexes - ) if self.use_deepstack else 0 + self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes") + self.deepstack_num_level = ( + len(config.vision_config.deepstack_visual_indexes) + if self.use_deepstack + else 0 + ) # register buffer for deepstack if self.use_deepstack and self.visual is not None: self.deepstack_input_embeds = [ torch.zeros( vllm_config.scheduler_config.max_num_batched_tokens, - config.text_config.hidden_size) + config.text_config.hidden_size, + ) for _ in range(self.deepstack_num_level) ] else: @@ -1196,30 +1256,34 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.visual_dim = config.vision_config.out_hidden_size self.multiscale_dim = self.visual_dim * self.deepstack_num_level - def _get_deepstack_input_embeds(self, - num_tokens: int) -> IntermediateTensors: + def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors: # get deepstack_input_embeds from buffer, and clear the buffer - return IntermediateTensors({ - f"deepstack_input_embeds_{idx}": - self.deepstack_input_embeds[idx][:num_tokens] - for idx in range(self.deepstack_num_level) - }) - - def _set_deepstack_input_embeds( - self, deepstack_input_embeds: 
torch.Tensor) -> None: + return IntermediateTensors( + { + f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][ + :num_tokens + ] + for idx in range(self.deepstack_num_level) + } + ) + + def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None: # set deepstack_input_embeds to buffer num_tokens = deepstack_input_embeds.size(1) if num_tokens > self.deepstack_input_embeds[0].size(0): self.deepstack_input_embeds = [ - torch.zeros(num_tokens, - self.config.text_config.hidden_size, - device=self.deepstack_input_embeds[0].device, - dtype=self.deepstack_input_embeds[0].dtype) + torch.zeros( + num_tokens, + self.config.text_config.hidden_size, + device=self.deepstack_input_embeds[0].device, + dtype=self.deepstack_input_embeds[0].dtype, + ) for _ in range(self.deepstack_num_level) ] for idx in range(self.deepstack_num_level): self.deepstack_input_embeds[idx][:num_tokens].copy_( - deepstack_input_embeds[idx]) + deepstack_input_embeds[idx] + ) def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: # clear deepstack_input_embeds in buffer @@ -1227,24 +1291,27 @@ def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: for idx in range(self.deepstack_num_level): self.deepstack_input_embeds[idx][:num_tokens].zero_() - def _validate_and_reshape_mm_tensor(self, mm_input: object, - name: str) -> torch.Tensor: + def _validate_and_reshape_mm_tensor( + self, mm_input: object, name: str + ) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {name}. " - f"Got type: {type(mm_input)}") + raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}") if isinstance(mm_input, torch.Tensor): if mm_input.ndim == 2: return mm_input if mm_input.ndim != 3: - raise ValueError(f"{name} should be 2D or batched 3D tensor. " - f"Got ndim: {mm_input.ndim} " - f"(shape={mm_input.shape})") + raise ValueError( + f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})" + ) return torch.concat(list(mm_input)) else: return torch.concat(mm_input) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + self, **kwargs: object + ) -> Optional[Qwen2_5_VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) @@ -1254,34 +1321,46 @@ def _parse_and_validate_image_input( if pixel_values is not None: pixel_values = self._validate_and_reshape_mm_tensor( - pixel_values, "image pixel values") + pixel_values, "image pixel values" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of image pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + "Incorrect type of image pixel values. 
" + f"Got type: {type(pixel_values)}" + ) - return Qwen2_5_VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) + return Qwen2_5_VLImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) if image_embeds is not None: image_embeds = self._validate_and_reshape_mm_tensor( - image_embeds, "image embeds") + image_embeds, "image embeds" + ) image_grid_thw = self._validate_and_reshape_mm_tensor( - image_grid_thw, "image grid_thw") + image_grid_thw, "image grid_thw" + ) if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") + raise ValueError( + "Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}" + ) return Qwen2_5_VLImageEmbeddingInputs( type="image_embeds", image_embeds=image_embeds, - image_grid_thw=image_grid_thw) + image_grid_thw=image_grid_thw, + ) def _parse_and_validate_video_input( - self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + self, **kwargs: object + ) -> Optional[Qwen2_5_VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) @@ -1292,9 +1371,11 @@ def _parse_and_validate_video_input( if pixel_values_videos is not None: pixel_values_videos = self._validate_and_reshape_mm_tensor( - pixel_values_videos, "video pixel values") + pixel_values_videos, "video pixel values" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) return Qwen2_5_VLVideoPixelInputs( type="pixel_values_videos", @@ -1305,22 +1386,26 @@ def _parse_and_validate_video_input( if video_embeds is not None: video_embeds = self._validate_and_reshape_mm_tensor( - video_embeds, "video embeds") + video_embeds, "video embeds" + ) video_grid_thw = self._validate_and_reshape_mm_tensor( - video_grid_thw, "video grid_thw") + video_grid_thw, "video grid_thw" + ) if not isinstance(video_embeds, torch.Tensor): - raise ValueError("Incorrect type of video embeddings. " - f"Got type: {type(video_embeds)}") + raise ValueError( + "Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}" + ) return Qwen2_5_VLVideoEmbeddingInputs( type="video_embeds", video_embeds=video_embeds, - video_grid_thw=video_grid_thw) + video_grid_thw=video_grid_thw, + ) def _process_image_input( - self, - image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Qwen2_5_VLImageInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1330,25 +1415,24 @@ def _process_image_input( else: pixel_values = image_input["pixel_values"].type(self.visual.dtype) if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values, - grid_thw_list, - rope_type="rope_3d") + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, grid_thw_list, rope_type="rope_3d" + ) else: - image_embeds = self.visual(pixel_values, - grid_thw=grid_thw_list) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. 
# Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return image_embeds.split(sizes) def _process_video_input( - self, - video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: - + self, video_input: Qwen2_5_VLVideoInputs + ) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 grid_thw_list = grid_thw.tolist() @@ -1357,44 +1441,50 @@ def _process_video_input( video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) + self.visual.dtype + ) if self.use_data_parallel: - return run_dp_sharded_mrope_vision_model(self.visual, - pixel_values_videos, - grid_thw_list, - rope_type="rope_3d") + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d" + ) else: - video_embeds = self.visual(pixel_values_videos, - grid_thw=grid_thw_list) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list) # Split concatenated embeddings for each video item. # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync merge_size = self.visual.spatial_merge_size - sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // - (merge_size * merge_size)).tolist() + sizes = ( + torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) + // (merge_size * merge_size) + ).tolist() return video_embeds.split(sizes) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} for input_key in kwargs: - if input_key in ("pixel_values", "image_embeds" - ) and "image" not in mm_input_by_modality: - mm_input_by_modality[ - "image"] = self._parse_and_validate_image_input(**kwargs) - if input_key in ("pixel_values_videos", "video_embeds" - ) and "video" not in mm_input_by_modality: - mm_input_by_modality[ - "video"] = self._parse_and_validate_video_input(**kwargs) + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) return mm_input_by_modality def get_language_model(self) -> torch.nn.Module: return self.language_model def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - - mm_input_by_modality = self._parse_and_validate_multimodal_inputs( - **kwargs) + self, **kwargs: object + ) -> Optional[MultiModalEmbeddings]: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) if not mm_input_by_modality: return None @@ -1432,15 +1522,16 @@ def _compute_deepstack_embeds( dim=-1, ) - multimodal_embeddings = torch.split(multimodal_embeddings_main, - visual_lens, - dim=0) + multimodal_embeddings = torch.split( + multimodal_embeddings_main, visual_lens, dim=0 + ) multimodal_embeddings_multiscale = torch.split( - multimodal_embeddings_multiscale, visual_lens, dim=0) + multimodal_embeddings_multiscale, visual_lens, dim=0 + ) deepstack_input_embeds = inputs_embeds.new_zeros( - inputs_embeds.size(0), - 
self.deepstack_num_level * inputs_embeds.size(1)) + inputs_embeds.size(0), self.deepstack_num_level * inputs_embeds.size(1) + ) deepstack_input_embeds = _merge_multimodal_embeddings( inputs_embeds=deepstack_input_embeds, @@ -1448,7 +1539,8 @@ def _compute_deepstack_embeds( is_multimodal=is_multimodal, ) deepstack_input_embeds = deepstack_input_embeds.view( - inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim) + inputs_embeds.shape[0], self.deepstack_num_level, self.visual_dim + ) deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2) return deepstack_input_embeds, multimodal_embeddings @@ -1475,7 +1567,8 @@ def get_input_embeddings( raise ValueError( "`get_input_embeddings` now requires `is_multimodal` arg, " "please update your model runner according to " - "https://github.com/vllm-project/vllm/pull/16229.") + "https://github.com/vllm-project/vllm/pull/16229." + ) if self.use_deepstack: ( @@ -1496,8 +1589,12 @@ def get_input_embeddings( ) if deepstack_input_embeds is not None: - deepstack_input_embeds = torch.zeros_like(inputs_embeds).unsqueeze( - 0).repeat(self.deepstack_num_level, 1, 1).contiguous() + deepstack_input_embeds = ( + torch.zeros_like(inputs_embeds) + .unsqueeze(0) + .repeat(self.deepstack_num_level, 1, 1) + .contiguous() + ) self._set_deepstack_input_embeds(deepstack_input_embeds) return inputs_embeds @@ -1537,10 +1634,14 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - if self.use_deepstack and inputs_embeds is not None and get_pp_group( - ).is_first_rank: + if ( + self.use_deepstack + and inputs_embeds is not None + and get_pp_group().is_first_rank + ): deepstack_input_embeds = self._get_deepstack_input_embeds( - inputs_embeds.size(0)) + inputs_embeds.size(0) + ) else: deepstack_input_embeds = None @@ -1564,9 +1665,7 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [] if self.visual is None: skip_prefixes.extend(["visual."]) diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index bd4aae7404c6..cd8046d04248 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -23,13 +23,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights.""" + import typing from collections.abc import Iterable from typing import Callable, Optional, Union import torch -from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import ( - Qwen3VLMoeConfig) +from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -38,21 +38,26 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel -from .qwen3_vl import (Qwen3_VisionTransformer, Qwen3VLDummyInputsBuilder, - Qwen3VLForConditionalGeneration, - Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) +from .qwen3_vl import ( + Qwen3_VisionTransformer, + Qwen3VLDummyInputsBuilder, + Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, + Qwen3VLProcessingInfo, +) from .utils import is_pp_missing_parameter, maybe_prefix logger = init_logger(__name__) class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config(Qwen3VLMoeConfig) @@ -66,18 +71,19 @@ def get_hf_config(self): "intermediate_tensors": 0, "inputs_embeds": 0, # the same shape as input_embeds - "deepstack_input_embeds": 0 - }) + "deepstack_input_embeds": 0, + } +) class Qwen3MoeLLMModel(Qwen3MoeModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) if not get_pp_group().is_first_rank: assert self.start_layer >= len( - vllm_config.model_config.hf_config.vision_config. 
- deepstack_visual_indexes), ( - "start_layer should be greater than or equal to " - "len(deepstack_visual_indexes)") + vllm_config.model_config.hf_config.vision_config.deepstack_visual_indexes + ), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)" + ) def forward( self, @@ -98,7 +104,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] for layer_idx, layer in enumerate( - self.layers[self.start_layer:self.end_layer]): + self.layers[self.start_layer : self.end_layer] + ): layer_idx = layer_idx + self.start_layer hidden_states, residual = layer( @@ -107,40 +114,48 @@ def forward( residual, ) - if deepstack_input_embeds is not None and \ - layer_idx in range(0, len(deepstack_input_embeds)): - hidden_states = hidden_states + deepstack_input_embeds[ - f"deepstack_input_embeds_{layer_idx}"] + if deepstack_input_embeds is not None and layer_idx in range( + 0, len(deepstack_input_embeds) + ): + hidden_states = ( + hidden_states + + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"] + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_fused_expert_weights(self, name: str, params_dict: dict, - loaded_weight: torch.Tensor, shard_id: str, - num_experts: int) -> bool: + def load_fused_expert_weights( + self, + name: str, + params_dict: dict, + loaded_weight: torch.Tensor, + shard_id: str, + num_experts: int, + ) -> bool: param = params_dict[name] weight_loader = typing.cast(Callable[..., bool], param.weight_loader) loaded_local_expert = False for expert_id in range(num_experts): curr_expert_weight = loaded_weight[expert_id] - success = weight_loader(param, - curr_expert_weight, - name, - shard_id, - expert_id, - return_success=True) + success = weight_loader( + param, + curr_expert_weight, + name, + shard_id, + expert_id, + return_success=True, + ) if success: loaded_local_expert = True return loaded_local_expert - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -150,9 +165,18 @@ def load_weights(self, weights: Iterable[tuple[str, ("gate_up_proj", "up_proj", 1), ] # Skip loading extra parameters for GPTQ/modelopt models. 
- ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", - ".v_scale", "_v_scale", ".weight_scale", - "_weight_scale", ".input_scale", "_input_scale") + ignore_suffixes = ( + ".bias", + "_bias", + ".k_scale", + "_k_scale", + ".v_scale", + "_v_scale", + ".weight_scale", + "_weight_scale", + ".input_scale", + "_input_scale", + ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() @@ -163,9 +187,8 @@ def load_weights(self, weights: Iterable[tuple[str, ] num_experts = self.config.num_experts for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if ("experts.gate_up_proj" in name - or "experts.down_proj" in name): + for param_name, weight_name, shard_id in stacked_params_mapping: + if "experts.gate_up_proj" in name or "experts.down_proj" in name: is_fused_expert = True expert_params_mapping = fused_expert_params_mapping @@ -195,8 +218,7 @@ def load_weights(self, weights: Iterable[tuple[str, if name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) else: @@ -215,40 +237,55 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name_mapped, self): continue if is_fused_expert: - loaded_weight = loaded_weight.transpose(-1, - -2) # no bias + loaded_weight = loaded_weight.transpose(-1, -2) # no bias if "experts.gate_up_proj" in name: loaded_weight = loaded_weight.chunk(2, dim=-2) success_w1 = self.load_fused_expert_weights( - name_mapped, params_dict, loaded_weight[0], - "w1", num_experts) + name_mapped, + params_dict, + loaded_weight[0], + "w1", + num_experts, + ) success_w3 = self.load_fused_expert_weights( - name_mapped, params_dict, loaded_weight[1], - "w3", num_experts) + name_mapped, + params_dict, + loaded_weight[1], + "w3", + num_experts, + ) success = success_w1 and success_w3 else: # down_proj success = self.load_fused_expert_weights( - name_mapped, params_dict, loaded_weight, - shard_id, num_experts) + name_mapped, + params_dict, + loaded_weight, + shard_id, + num_experts, + ) else: # Skip loading extra parameters for GPTQ/modelopt models - if name_mapped.endswith( - ignore_suffixes - ) and name_mapped not in params_dict: + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): continue param = params_dict[name_mapped] # We should ask the weight loader to return success or # not here since otherwise we may skip experts with # other available replicas. - weight_loader = typing.cast(Callable[..., bool], - param.weight_loader) - success = weight_loader(param, - loaded_weight, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True) + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) if success: name = name_mapped break @@ -259,8 +296,7 @@ def load_weights(self, weights: Iterable[tuple[str, # So we simply skip it continue # Skip loading extra parameters for GPTQ/modelopt models. - if name.endswith( - ignore_suffixes) and name not in params_dict: + if name.endswith(ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): @@ -268,7 +304,8 @@ def load_weights(self, weights: Iterable[tuple[str, # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") + ".kv_scale", ".attn.kv_scale" + ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501 @@ -279,37 +316,42 @@ def load_weights(self, weights: Iterable[tuple[str, else: name = remapped_kv_scale_name param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super(Qwen3MoeForCausalLM, self).__init__() self.config = vllm_config.model_config.hf_config.text_config self.quant_config = vllm_config.quant_config - self.model = Qwen3MoeLLMModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size, - quant_config=self.quant_config, - prefix=maybe_prefix(prefix, "lm_head")) + self.model = Qwen3MoeLLMModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(self.config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) -@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, - info=Qwen3VLMoeProcessingInfo, - dummy_inputs=Qwen3VLDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + Qwen3VLMultiModalProcessor, + info=Qwen3VLMoeProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder, +) class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super(Qwen3VLForConditionalGeneration, self).__init__() config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config @@ -320,8 +362,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - if not multimodal_config.get_limit_per_prompt("image") and \ - not multimodal_config.get_limit_per_prompt("video"): + if not multimodal_config.get_limit_per_prompt( + "image" + ) and not multimodal_config.get_limit_per_prompt("video"): self.visual = None else: self.visual = Qwen3_VisionTransformer( @@ -332,25 +375,27 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): use_data_parallel=self.use_data_parallel, ) - self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, - "language_model")) + self.language_model = Qwen3MoeLLMForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model") + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - 
self.use_deepstack = hasattr(config.vision_config, - 'deepstack_visual_indexes') - self.deepstack_num_level = len( - config.vision_config.deepstack_visual_indexes - ) if self.use_deepstack else 0 + self.use_deepstack = hasattr(config.vision_config, "deepstack_visual_indexes") + self.deepstack_num_level = ( + len(config.vision_config.deepstack_visual_indexes) + if self.use_deepstack + else 0 + ) # register buffer for deepstack if self.use_deepstack and self.visual is not None: self.deepstack_input_embeds = [ torch.zeros( vllm_config.scheduler_config.max_num_batched_tokens, - config.text_config.hidden_size) + config.text_config.hidden_size, + ) for _ in range(self.deepstack_num_level) ] else: diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index a94e1e700c67..1786ea6a6878 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -18,33 +18,45 @@ from torch import nn from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer, - TensorType) +from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType from transformers.image_utils import ImageInput from transformers.tokenization_utils_base import TextInput from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) from .qwen import QWenBaseModel, QWenModel from .utils import flatten_bn @@ -56,11 +68,12 @@ class QwenImagePixelInputs(TensorSchema): - c: Number of channels (3) - h: Height - w: Width - + Note that image_size is the value in the vision config to which we resize the image to in the normalization transform. Currently multi-image support can only be leveraged by passing image embeddings directly. 
""" + type: Literal["pixel_values"] = "pixel_values" data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] @@ -71,10 +84,11 @@ class QwenImageEmbeddingInputs(TensorSchema): - bn: Batch size * number of images - ifs: Image feature size (256) - hs: Hidden size - + `hidden_size` must match the hidden size of the language model backbone and is stored in the visual config of the model if we have one. """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")] @@ -100,8 +114,7 @@ def __init__( self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim \ - and self.vdim == embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads @@ -112,8 +125,9 @@ def __init__( self.hidden_size_per_partition = embed_dim # Strided linear layer. - assert self._qkv_same_embed_dim, \ - 'Visual Attention implementation only supports self-attention' + assert self._qkv_same_embed_dim, ( + "Visual Attention implementation only supports self-attention" + ) self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim) self.out_proj = ReplicatedLinear(embed_dim, embed_dim) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -128,50 +142,63 @@ def forward( mixed_x_layer, _ = self.in_proj(x) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] query_layer, key_layer, value_layer = mixed_x_layer.split( - self.hidden_size_per_attention_head, dim=-1) + self.hidden_size_per_attention_head, dim=-1 + ) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view( - sq, b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + sq, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ).transpose(0, 1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view( - sq, b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + sq, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ).transpose(0, 1) q_scaled = query_layer / self.norm_factor if attn_mask is not None: - attention_probs = torch.baddbmm(attn_mask, q_scaled, - key_layer.transpose(-2, -1)) + attention_probs = torch.baddbmm( + attn_mask, q_scaled, key_layer.transpose(-2, -1) + ) else: attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) attention_probs = attention_probs.softmax(dim=-1) value_layer = value_layer.view( - sq, b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + sq, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ).transpose(0, 1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer) # change view [b, np, sq, hn] context_layer = context_layer.view( - b, self.num_attention_heads_per_partition, sq, - self.hidden_size_per_attention_head) + b, + self.num_attention_heads_per_partition, + sq, + self.hidden_size_per_attention_head, + ) # [b, np, sq, hn] --> [sq, 
b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size_per_partition, + ) context_layer = context_layer.view(*new_context_layer_shape) output, _ = self.out_proj(context_layer) @@ -189,10 +216,9 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.c_fc = ColumnParallelLinear(hidden_size, - intermediate_size, - bias=True, - quant_config=quant_config) + self.c_fc = ColumnParallelLinear( + hidden_size, intermediate_size, bias=True, quant_config=quant_config + ) self.act_fn = get_act_fn("gelu") self.c_proj = RowParallelLinear( intermediate_size, @@ -209,7 +235,6 @@ def forward(self, x): class VisualAttentionBlock(nn.Module): - def __init__( self, d_model: int, @@ -249,7 +274,6 @@ def forward( class TransformerBlock(nn.Module): - def __init__( self, width: int, @@ -263,14 +287,18 @@ def __init__( self.width = width self.layers = layers - self.resblocks = nn.ModuleList([ - VisualAttentionBlock(width, - heads, - mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config) - for _ in range(layers) - ]) + self.resblocks = nn.ModuleList( + [ + VisualAttentionBlock( + width, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + ) + for _ in range(layers) + ] + ) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype @@ -278,54 +306,57 @@ def get_cast_dtype(self) -> torch.dtype: def get_cast_device(self) -> torch.device: return self.resblocks[0].mlp.c_fc.weight.device - def forward(self, - x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: for r in self.resblocks: x = r(x, attn_mask=attn_mask) return x class VisionTransformer(nn.Module): - - def __init__(self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - n_queries: int = 256, - output_dim: int = 512, - image_start_id: int = 151857, - quant_config: Optional[QuantizationConfig] = None, - **kwargs): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + image_start_id: int = 151857, + quant_config: Optional[QuantizationConfig] = None, + **kwargs, + ): super().__init__() image_height, image_width = self.image_size = (image_size, image_size) patch_height, patch_width = self.patch_size = (patch_size, patch_size) - self.grid_size = (image_height // patch_height, - image_width // patch_width) + self.grid_size = (image_height // patch_height, image_width // patch_width) self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False) + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) # class embeddings and positional embeddings scale = width**-0.5 - self.positional_embedding = nn.Parameter(scale * - torch.randn(256, width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(256, width)) norm_layer = partial(nn.LayerNorm, eps=1e-6) self.ln_pre = norm_layer(width) - self.transformer = TransformerBlock(width, - layers, - heads, - mlp_ratio, - norm_layer=norm_layer, 
- quant_config=quant_config) + self.transformer = TransformerBlock( + width, + layers, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + ) self.attn_pool = Resampler2( grid_size=int(math.sqrt(n_queries)), @@ -342,7 +373,8 @@ def __init__(self, self.ln_post = norm_layer(output_dim) self.proj = nn.Parameter( - (output_dim**-0.5) * torch.randn(output_dim, output_dim)) + (output_dim**-0.5) * torch.randn(output_dim, output_dim) + ) self.image_start_id = image_start_id self.image_end_id = image_start_id + 1 @@ -356,12 +388,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # to patches x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], - -1) # shape = [*, width, grid ** 2] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = x + get_abs_pos(self.positional_embedding, int(math.sqrt( - x.size(1)))) + x = x + get_abs_pos(self.positional_embedding, int(math.sqrt(x.size(1)))) x = self.ln_pre(x) @@ -377,20 +407,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class QwenVLModel(QWenModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - self.visual = VisionTransformer(**config.visual, - quant_config=quant_config) + self.visual = VisionTransformer(**config.visual, quant_config=quant_config) @lru_cache(maxsize=1) def _get_tokenizer_without_image_pad( - tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: + tokenizer: PreTrainedTokenizer, +) -> PreTrainedTokenizer: """ The logic of adding image pad tokens should only be applied in [`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor], @@ -402,7 +431,6 @@ def _get_tokenizer_without_image_pad( new_tokenizer = copy.deepcopy(tokenizer) class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore - def tokenize( self, text: str, @@ -413,7 +441,8 @@ def tokenize( text = unicodedata.normalize("NFC", text) return [ - self.decoder[t] for t in self.tokenizer.encode( + self.decoder[t] + for t in self.tokenizer.encode( text, allowed_special=allowed_special, disallowed_special=disallowed_special, @@ -435,8 +464,7 @@ def _decode( errors=errors or self.errors, ) - TokenizerWithoutImagePad.__name__ = \ - f"{tokenizer.__class__.__name__}WithoutImagePad" + TokenizerWithoutImagePad.__name__ = f"{tokenizer.__class__.__name__}WithoutImagePad" new_tokenizer.__class__ = TokenizerWithoutImagePad return new_tokenizer @@ -467,17 +495,19 @@ def __init__( vision_config = config.visual image_size = vision_config["image_size"] - self.image_transform = transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - transforms.Normalize( - mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711), - ), - ]) + self.image_transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), + ), + ] + ) @property def image_start_tag(self) -> str: @@ -524,7 +554,6 @@ def __call__( class QwenVLProcessingInfo(BaseProcessingInfo): - def get_tokenizer(self) -> PreTrainedTokenizer: tokenizer = self.ctx.tokenizer assert isinstance(tokenizer, 
PreTrainedTokenizer) @@ -553,7 +582,6 @@ def get_num_image_tokens(self) -> int: class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -561,8 +589,9 @@ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: img_start = hf_processor.image_start_tag img_end = hf_processor.image_end_tag - return "".join(f"Picture {i}: {img_start}{img_end}\n" - for i in range(1, num_images + 1)) + return "".join( + f"Picture {i}: {img_start}{img_end}\n" for i in range(1, num_images + 1) + ) def get_dummy_mm_data( self, @@ -579,16 +608,16 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): - def _call_hf_processor( self, prompt: str, @@ -644,8 +673,7 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() - special_tokens: dict[str, - int] = tokenizer.special_tokens # type: ignore + special_tokens: dict[str, int] = tokenizer.special_tokens # type: ignore processor = self.info.get_hf_processor() img_start_id = special_tokens[processor.image_start_tag] @@ -667,11 +695,14 @@ def _get_prompt_updates( ] -@MULTIMODAL_REGISTRY.register_processor(QwenVLMultiModalProcessor, - info=QwenVLProcessingInfo, - dummy_inputs=QwenVLDummyInputsBuilder) -class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, - SupportsMultiModal): +@MULTIMODAL_REGISTRY.register_processor( + QwenVLMultiModalProcessor, + info=QwenVLProcessingInfo, + dummy_inputs=QwenVLDummyInputsBuilder, +) +class QwenVLForConditionalGeneration( + QWenBaseModel, SupportsPP, SupportsLoRA, SupportsMultiModal +): packed_modules_mapping = { "c_attn": ["c_attn"], "gate_up_proj": [ @@ -687,7 +718,8 @@ def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( language_model="transformer.h", connector="transformer.visual.attn_pool", - tower_model="transformer.visual.transformer") + tower_model="transformer.visual.transformer", + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -712,14 +744,16 @@ def __init__( self.transformer: QwenVLModel def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[QwenImageInputs]: + self, **kwargs: object + ) -> Optional[QwenImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is not None: if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") + raise ValueError( + f"Incorrect type of pixel values. Got type: {type(pixel_values)}" + ) expected_h = expected_w = self.config.visual["image_size"] resolve_bindings = {"h": expected_h, "w": expected_w} @@ -732,8 +766,10 @@ def _parse_and_validate_image_input( if image_embeds is not None: if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") + raise ValueError( + "Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}" + ) return QwenImageEmbeddingInputs( type="image_embeds", @@ -742,8 +778,7 @@ def _parse_and_validate_image_input( return None - def _process_image_input(self, - image_input: QwenImageInputs) -> torch.Tensor: + def _process_image_input(self, image_input: QwenImageInputs) -> torch.Tensor: if image_input["type"] == "image_embeds": return image_input["data"] @@ -752,8 +787,7 @@ def _process_image_input(self, def get_language_model(self) -> torch.nn.Module: return self.transformer - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -772,6 +806,7 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.transformer(input_ids, positions, - intermediate_tensors, inputs_embeds) + hidden_states = self.transformer( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states diff --git a/vllm/model_executor/models/radio.py b/vllm/model_executor/models/radio.py index 9cbf844ae9f8..2313b98348b7 100644 --- a/vllm/model_executor/models/radio.py +++ b/vllm/model_executor/models/radio.py @@ -28,7 +28,6 @@ def _ntuple(n): - def parse(x): if isinstance(x, Iterable) and not isinstance(x, str): return tuple(x) @@ -45,7 +44,6 @@ def parse(x): class InputConditioner(nn.Module): - def __init__( self, input_scale: float, @@ -72,7 +70,6 @@ def _to_tensor(v: norm_t): class ClsToken(nn.Module): - def __init__( self, ndim: int, @@ -91,12 +88,14 @@ def __init__( if num_registers: self.num_registers = num_registers elif register_multiple: - self.num_registers = register_multiple - (num_tokens % - register_multiple) + self.num_registers = register_multiple - ( + num_tokens % register_multiple + ) scale = ndim**-0.5 self.token = nn.Parameter( - torch.randn(num_tokens + self.num_registers, ndim) * scale) + torch.randn(num_tokens + self.num_registers, ndim) * scale + ) else: self.token = None @@ -108,16 +107,18 @@ def forward(self, x: torch.Tensor): return x token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1) - x = torch.cat([ - token, - x, - ], dim=1) + x = torch.cat( + [ + token, + x, + ], + dim=1, + ) return x class ViTPatchGenerator(nn.Module): - def __init__( self, # config: PretrainedConfig, @@ -147,8 +148,8 @@ def __init__( max_input_dims = (max_input_dims, max_input_dims) max_input_dims = tuple( - int(math.ceil(d / patch_size) * patch_size) - for d in max_input_dims) + int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims + ) self.cpe_mode = max_input_dims != input_dims self.pos_dropout = pos_dropout @@ -167,15 +168,15 @@ def __init__( self.max_input_dims = max_input_dims self.im_to_patches = Im2Patches(patch_size) - self.embedder = ViTPatchLinear(patch_size, - embed_dim, - bias=patch_bias, - **factory) + self.embedder = ViTPatchLinear( + patch_size, embed_dim, bias=patch_bias, **factory + ) if abs_pos: scale = embed_dim**-0.5 self.pos_embed = nn.Parameter( - torch.randn(1, self.num_patches, embed_dim, **factory) * scale) + torch.randn(1, self.num_patches, embed_dim, **factory) * scale + ) self.cls_token = ClsToken( embed_dim, @@ -185,8 +186,9 @@ def __init__( num_registers=num_registers, ) - self.patch_normalizer = nn.LayerNorm( - embed_dim) if normalize_patches else nn.Identity() + self.patch_normalizer = ( + nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity() + ) def 
forward(self, x: torch.Tensor) -> torch.Tensor: patches = self.embed_patches(x) @@ -221,42 +223,48 @@ def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter): if src_embed.shape != targ_embed.shape: src_size = int(math.sqrt(src_embed.shape[1])) - assert src_size**2 == src_embed.shape[ - 1], 'Unable to interpolate non-square embedding' - - src_embed = rearrange(src_embed, - 'b (h w) c -> b c h w', - h=src_size, - w=src_size) - src_embed = F.interpolate(src_embed, - size=(self.num_rows, self.num_cols), - mode='bicubic', - align_corners=True, - antialias=False) - src_embed = rearrange(src_embed, 'b c h w -> b (h w) c') + assert src_size**2 == src_embed.shape[1], ( + "Unable to interpolate non-square embedding" + ) + + src_embed = rearrange( + src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size + ) + src_embed = F.interpolate( + src_embed, + size=(self.num_rows, self.num_cols), + mode="bicubic", + align_corners=True, + antialias=False, + ) + src_embed = rearrange(src_embed, "b c h w -> b (h w) c") targ_embed.data.copy_(src_embed) - def _load_projection(self, src_proj_weight: torch.Tensor, - targ_proj_weight: torch.Tensor): + def _load_projection( + self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor + ): if src_proj_weight.shape != targ_proj_weight.shape: src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3)) - assert (src_patch_size**2) * 3 == src_proj_weight.shape[ - 1], 'Unable to interpolate non-square patch size' - - src_proj_weight = rearrange(src_proj_weight, - 'b (c h w) -> b c h w', - c=3, - h=src_patch_size, - w=src_patch_size) - src_proj_weight = F.interpolate(src_proj_weight, - size=(self.patch_size, - self.patch_size), - mode='bicubic', - align_corners=True, - antialias=False) - src_proj_weight = rearrange(src_proj_weight, - 'b c h w -> b (c h w)') + assert (src_patch_size**2) * 3 == src_proj_weight.shape[1], ( + "Unable to interpolate non-square patch size" + ) + + src_proj_weight = rearrange( + src_proj_weight, + "b (c h w) -> b c h w", + c=3, + h=src_patch_size, + w=src_patch_size, + ) + src_proj_weight = F.interpolate( + src_proj_weight, + size=(self.patch_size, self.patch_size), + mode="bicubic", + align_corners=True, + antialias=False, + ) + src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)") targ_proj_weight.data.copy_(src_proj_weight) def embed_patches(self, x: torch.Tensor) -> torch.Tensor: @@ -276,11 +284,12 @@ def apply_pos_enc( pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size) if self.training and self.pos_dropout > 0: - keeps = torch.rand(patches.shape[0], - 1, - 1, - dtype=pos_enc.dtype, - device=pos_enc.device) > self.pos_dropout + keeps = ( + torch.rand( + patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device + ) + > self.pos_dropout + ) pos_enc_drop = torch.where(keeps, pos_enc, 0) else: pos_enc_drop = pos_enc @@ -303,56 +312,58 @@ def get_pos_enc( if patch_idxs is None: return pos_embed - exp_patch_idxs = patch_idxs.unsqueeze(-1).expand( - -1, -1, pos_embed.shape[-1]) + exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]) - pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), - dim=1, - index=exp_patch_idxs) + pos_embed = torch.gather( + pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs + ) return pos_embed - def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, - int]): + def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]): if (self.num_rows, self.num_cols) == 
input_dims: return self.pos_embed - pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, - -1).permute(0, 3, 1, 2) + pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute( + 0, 3, 1, 2 + ) def window_select(pos_embed): if input_dims[0] < pos_embed.shape[-2]: - pos_embed = pos_embed[..., :input_dims[0], :] + pos_embed = pos_embed[..., : input_dims[0], :] if input_dims[1] < pos_embed.shape[-1]: - pos_embed = pos_embed[..., :, :input_dims[1]] + pos_embed = pos_embed[..., :, : input_dims[1]] return pos_embed if self.cpe_mode: if self.training: min_scale = math.sqrt(0.1) - scale = torch.rand(batch_size, 1, 1, device=pos_embed.device - ) * (1 - min_scale) + min_scale + scale = ( + torch.rand(batch_size, 1, 1, device=pos_embed.device) + * (1 - min_scale) + + min_scale + ) aspect_min = math.log(3 / 4) aspect_max = -aspect_min aspect = torch.exp( - torch.rand(batch_size, 1, 1, device=pos_embed.device) * - (aspect_max - aspect_min) + aspect_min) + torch.rand(batch_size, 1, 1, device=pos_embed.device) + * (aspect_max - aspect_min) + + aspect_min + ) scale_x = scale * aspect scale_y = scale * (1 / aspect) scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1) - pos_xy = torch.rand( - batch_size, 1, 1, 2, - device=pos_embed.device) * (1 - scale_xy) + pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * ( + 1 - scale_xy + ) lin_x = torch.linspace( - 0, 1, steps=input_dims[1], - device=pos_embed.device)[None, None].expand( - batch_size, input_dims[0], -1) + 0, 1, steps=input_dims[1], device=pos_embed.device + )[None, None].expand(batch_size, input_dims[0], -1) lin_y = torch.linspace( - 0, 1, steps=input_dims[0], - device=pos_embed.device)[None, :, None].expand( - batch_size, -1, input_dims[1]) + 0, 1, steps=input_dims[0], device=pos_embed.device + )[None, :, None].expand(batch_size, -1, input_dims[1]) lin_xy = torch.stack([lin_x, lin_y], dim=-1) @@ -364,26 +375,27 @@ def window_select(pos_embed): pos_embed = F.grid_sample( pos_embed.float().expand(batch_size, -1, -1, -1), grid=grid_xy, - mode='bilinear', - padding_mode='zeros', + mode="bilinear", + padding_mode="zeros", align_corners=True, ).to(pos_embed.dtype) else: max_dim = max(input_dims) - pos_embed = F.interpolate(pos_embed.float(), - size=(max_dim, max_dim), - align_corners=True, - mode='bilinear').to(pos_embed.dtype) + pos_embed = F.interpolate( + pos_embed.float(), + size=(max_dim, max_dim), + align_corners=True, + mode="bilinear", + ).to(pos_embed.dtype) pos_embed = window_select(pos_embed) else: pos_embed = window_select(pos_embed) if pos_embed.shape[-2:] != input_dims: - pos_embed = F.interpolate(pos_embed.float(), - size=input_dims, - align_corners=True, - mode='bilinear').to(pos_embed.dtype) + pos_embed = F.interpolate( + pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear" + ).to(pos_embed.dtype) pos_embed = pos_embed.flatten(2).permute(0, 2, 1) @@ -391,7 +403,6 @@ def window_select(pos_embed): class Im2Patches(nn.Module): - def __init__(self, patch_size: int): super().__init__() self.patch_size = patch_size @@ -406,7 +417,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: px = x.shape[-1] // self.patch_size patches = rearrange( x, - 'b c (py yy) (px xx) -> b (py px) (c yy xx)', + "b c (py yy) (px xx) -> b (py px) (c yy xx)", py=py, yy=self.patch_size, px=px, @@ -416,12 +427,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ViTPatchLinear(nn.Linear): - - def __init__(self, - patch_size: int, - embed_dim: int, - bias: bool = False, - 
**factory): + def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory): super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory) self.patch_size = patch_size @@ -444,16 +450,19 @@ def __init__( self.config = config self.img_size, self.grid_size, self.num_patches = self._init_img_size( - to_2tuple(config.patch_size), config.image_size) + to_2tuple(config.patch_size), config.image_size + ) max_img_size = int( - round(config.max_img_size / config.patch_size) * config.patch_size) + round(config.max_img_size / config.patch_size) * config.patch_size + ) self.patch_generator = ViTPatchGenerator( config.patch_size, config.hidden_size, input_dims=self.img_size, max_input_dims=max_img_size, cls_token=True, - register_multiple=config.reg_tokens) + register_multiple=config.reg_tokens, + ) self.encoder = InternVisionEncoder( config=config, @@ -463,8 +472,7 @@ def __init__( prefix=f"{prefix}.encoder", ) - def _init_img_size(self, patch_size, img_size: Union[int, tuple[int, - int]]): + def _init_img_size(self, patch_size, img_size: Union[int, tuple[int, int]]): if img_size is None: return None, None, None img_size = to_2tuple(img_size) @@ -509,7 +517,8 @@ def __init__( quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, num_dummy_heads=num_dummy_heads, - prefix=prefix) + prefix=prefix, + ) def forward( self, @@ -534,7 +543,7 @@ def load_weights(self, weights) -> set[str]: # Skip non-radio weights continue - sub = name[len("radio_model."):] # drop "radio_model." prefix + sub = name[len("radio_model.") :] # drop "radio_model." prefix # Skip buffers not used in vLLM if sub in {"summary_idxs"}: @@ -553,15 +562,13 @@ def load_weights(self, weights) -> set[str]: layer_idx = parts[2] suffix = ".".join(parts[3:]) # Skip layer-scale entries that vLLM doesn't use - if suffix in {"ls1", "ls2"} or suffix.startswith( - ("ls1.", "ls2.")): + if suffix in {"ls1", "ls2"} or suffix.startswith(("ls1.", "ls2.")): continue vllm_key = f"model.encoder.layers.{layer_idx}.{suffix}" if vllm_key and vllm_key in params_dict: param = params_dict[vllm_key] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) loaded_params.add(vllm_key) @@ -571,6 +578,6 @@ def _extract_final(self, y: torch.Tensor): # Remove CLS + REGISTERS tokens patch_gen = getattr(self.model, "patch_generator", None) if patch_gen is not None: - all_feat = y[:, patch_gen.num_skip:] + all_feat = y[:, patch_gen.num_skip :] return all_feat diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 94744fe558bd..7c324b7e7872 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -4,6 +4,7 @@ Whenever you add an architecture to this page, please also update `tests/models/registry.py` with example HuggingFace models for it. 
""" + import hashlib import importlib import json @@ -23,25 +24,36 @@ import transformers from vllm import envs -from vllm.config import (ModelConfig, iter_architecture_defaults, - try_match_architecture_defaults) +from vllm.config import ( + ModelConfig, + iter_architecture_defaults, + try_match_architecture_defaults, +) from vllm.logger import init_logger from vllm.logging_utils import logtime -from vllm.transformers_utils.dynamic_module import ( - try_get_class_from_dynamic_module) - -from .interfaces import (has_inner_state, has_noops, is_attention_free, - is_hybrid, supports_cross_encoding, - supports_multimodal, - supports_multimodal_encoder_tp_data, - supports_multimodal_raw_input_only, supports_pp, - supports_transcription, supports_v0_only) -from .interfaces_base import (get_default_pooling_type, is_pooling_model, - is_text_generation_model) +from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module + +from .interfaces import ( + has_inner_state, + has_noops, + is_attention_free, + is_hybrid, + supports_cross_encoding, + supports_multimodal, + supports_multimodal_encoder_tp_data, + supports_multimodal_raw_input_only, + supports_pp, + supports_transcription, + supports_v0_only, +) +from .interfaces_base import ( + get_default_pooling_type, + is_pooling_model, + is_text_generation_model, +) logger = init_logger(__name__) -# yapf: disable _TEXT_GENERATION_MODELS = { # [Decoder-only] "ApertusForCausalLM": ("apertus", "ApertusForCausalLM"), @@ -93,8 +105,8 @@ "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), - "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), # noqa: E501 - "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501 + "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), # noqa: E501 + "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501 "GritLM": ("gritlm", "GritLM"), "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"), @@ -108,13 +120,13 @@ "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), - "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # noqa: E501 + "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LongcatFlashForCausalLM": ("longcat_flash", "LongcatFlashForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"), - "FalconH1ForCausalLM":("falcon_h1", "FalconH1ForCausalLM"), + "FalconH1ForCausalLM": ("falcon_h1", "FalconH1ForCausalLM"), "Mamba2ForCausalLM": ("mamba2", "Mamba2ForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), @@ -171,7 +183,8 @@ "LlamaModel": ("llama", "LlamaForCausalLM"), **{ # Multiple models share the same architecture, so we include them all - k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() + k: (mod, arch) + for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() if arch == "LlamaForCausalLM" }, "MistralModel": ("llama", "LlamaForCausalLM"), @@ -187,7 +200,11 @@ "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), 
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), # [Multimodal] - "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 + "CLIPModel": ("clip", "CLIPEmbeddingModel"), + "LlavaNextForConditionalGeneration": ( + "llava_next", + "LlavaNextForConditionalGeneration", + ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 # Technically Terratorch models work on images, both in @@ -200,79 +217,150 @@ _CROSS_ENCODER_MODELS = { "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), "BertForTokenClassification": ("bert", "BertForTokenClassification"), - "GteNewForSequenceClassification": ("bert_with_rope", - "GteNewForSequenceClassification"), - "ModernBertForSequenceClassification": ("modernbert", - "ModernBertForSequenceClassification"), - "RobertaForSequenceClassification": ("roberta", - "RobertaForSequenceClassification"), - "XLMRobertaForSequenceClassification": ("roberta", - "RobertaForSequenceClassification"), + "GteNewForSequenceClassification": ( + "bert_with_rope", + "GteNewForSequenceClassification", + ), + "ModernBertForSequenceClassification": ( + "modernbert", + "ModernBertForSequenceClassification", + ), + "RobertaForSequenceClassification": ("roberta", "RobertaForSequenceClassification"), + "XLMRobertaForSequenceClassification": ( + "roberta", + "RobertaForSequenceClassification", + ), # [Auto-converted (see adapters.py)] - "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, + "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, } _MULTIMODAL_MODELS = { # [Decoder-only] "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), - "AyaVisionForConditionalGeneration": ("aya_vision", "AyaVisionForConditionalGeneration"), # noqa: E501 + "AyaVisionForConditionalGeneration": ( + "aya_vision", + "AyaVisionForConditionalGeneration", + ), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), - "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 - "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501 + "ChameleonForConditionalGeneration": ( + "chameleon", + "ChameleonForConditionalGeneration", + ), + "Cohere2VisionForConditionalGeneration": ( + "cohere2_vision", + "Cohere2VisionForConditionalGeneration", + ), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"), - "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501 + "Ernie4_5_VLMoeForConditionalGeneration": ( + "ernie45_vl", + "Ernie4_5_VLMoeForConditionalGeneration", + ), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 - "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501 + "Gemma3nForConditionalGeneration": ( + "gemma3n_mm", + "Gemma3nForConditionalGeneration", + ), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501 "Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501 - "GraniteSpeechForConditionalGeneration": ("granite_speech", 
"GraniteSpeechForConditionalGeneration"), # noqa: E501 + "GraniteSpeechForConditionalGeneration": ( + "granite_speech", + "GraniteSpeechForConditionalGeneration", + ), "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"), "InternVLChatModel": ("internvl", "InternVLChatModel"), "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"), - "InternS1ForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 - "InternVLForConditionalGeneration": ("interns1", "InternS1ForConditionalGeneration"), # noqa: E501 - "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), - "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 + "InternS1ForConditionalGeneration": ( + "interns1", + "InternS1ForConditionalGeneration", + ), + "InternVLForConditionalGeneration": ( + "interns1", + "InternS1ForConditionalGeneration", + ), + "Idefics3ForConditionalGeneration": ( + "idefics3", + "Idefics3ForConditionalGeneration", + ), + "SmolVLMForConditionalGeneration": ("smolvlm", "SmolVLMForConditionalGeneration"), # noqa: E501 "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), - "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"), # noqa: E501 + "KeyeVL1_5ForConditionalGeneration": ( + "keye_vl1_5", + "KeyeVL1_5ForConditionalGeneration", + ), "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), - "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 - "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 - "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 + "LlavaNextForConditionalGeneration": ( + "llava_next", + "LlavaNextForConditionalGeneration", + ), + "LlavaNextVideoForConditionalGeneration": ( + "llava_next_video", + "LlavaNextVideoForConditionalGeneration", + ), + "LlavaOnevisionForConditionalGeneration": ( + "llava_onevision", + "LlavaOnevisionForConditionalGeneration", + ), "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 "MiDashengLMModel": ("midashenglm", "MiDashengLMModel"), - "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"), # noqa: E501 + "MiniMaxVL01ForConditionalGeneration": ( + "minimax_vl_01", + "MiniMaxVL01ForConditionalGeneration", + ), "MiniCPMO": ("minicpmo", "MiniCPMO"), "MiniCPMV": ("minicpmv", "MiniCPMV"), - "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"), # noqa: E501 + "Mistral3ForConditionalGeneration": ( + "mistral3", + "Mistral3ForConditionalGeneration", + ), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), "Ovis2_5": ("ovis2_5", "Ovis2_5"), - "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 + "PaliGemmaForConditionalGeneration": ( + "paligemma", + "PaliGemmaForConditionalGeneration", + ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": 
("phi4mm", "Phi4MMForCausalLM"), "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"), # noqa: E501 "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 - "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 - "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 - "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 - "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 + "Qwen2_5_VLForConditionalGeneration": ( + "qwen2_5_vl", + "Qwen2_5_VLForConditionalGeneration", + ), + "Qwen2AudioForConditionalGeneration": ( + "qwen2_audio", + "Qwen2AudioForConditionalGeneration", + ), + "Qwen2_5OmniModel": ( + "qwen2_5_omni_thinker", + "Qwen2_5OmniThinkerForConditionalGeneration", + ), + "Qwen2_5OmniForConditionalGeneration": ( + "qwen2_5_omni_thinker", + "Qwen2_5OmniThinkerForConditionalGeneration", + ), "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501 - "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), # noqa: E501 + "Qwen3VLMoeForConditionalGeneration": ( + "qwen3_vl_moe", + "Qwen3VLMoeForConditionalGeneration", + ), "SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"), "Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501 "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 - "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 + "Tarsier2ForConditionalGeneration": ( + "qwen2_vl", + "Tarsier2ForConditionalGeneration", + ), "UltravoxModel": ("ultravox", "UltravoxModel"), "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] @@ -310,13 +398,27 @@ "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 "TransformersMoEForCausalLM": ("transformers_moe", "TransformersMoEForCausalLM"), # noqa: E501 - "TransformersMoEForMultimodalLM": ("transformers_moe", "TransformersMoEForMultimodalLM"), # noqa: E501 - "TransformersEmbeddingModel": ("transformers_pooling", "TransformersEmbeddingModel"), # noqa: E501 - "TransformersForSequenceClassification": ("transformers_pooling", "TransformersForSequenceClassification"), # noqa: E501 - "TransformersMoEForSequenceClassification": ("transformers_pooling", "TransformersMoEForSequenceClassification"), # noqa: E501 - "TransformersMoEEmbeddingModel": ("transformers_pooling", "TransformersMoEEmbeddingModel"), # noqa: E501 + "TransformersMoEForMultimodalLM": ( + "transformers_moe", + "TransformersMoEForMultimodalLM", + ), + "TransformersEmbeddingModel": ( + "transformers_pooling", + "TransformersEmbeddingModel", + ), + "TransformersForSequenceClassification": ( + "transformers_pooling", + "TransformersForSequenceClassification", + ), + "TransformersMoEForSequenceClassification": ( + "transformers_pooling", + "TransformersMoEForSequenceClassification", + ), + "TransformersMoEEmbeddingModel": ( + 
"transformers_pooling", + "TransformersMoEEmbeddingModel", + ), } -# yapf: enable _VLLM_MODELS = { **_TEXT_GENERATION_MODELS, @@ -332,9 +434,7 @@ # can modify this variable to alter the args if needed. e.g. # when we use par format to pack things together, sys.executable # might not be the target we want to run. -_SUBPROCESS_COMMAND = [ - sys.executable, "-m", "vllm.model_executor.models.registry" -] +_SUBPROCESS_COMMAND = [sys.executable, "-m", "vllm.model_executor.models.registry"] _PREVIOUSLY_SUPPORTED_MODELS = { "MotifForCausalLM": "0.10.2", @@ -379,24 +479,26 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": default_pooling_type=get_default_pooling_type(model), supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), - supports_multimodal_raw_input_only= - supports_multimodal_raw_input_only(model), - supports_multimodal_encoder_tp_data= - supports_multimodal_encoder_tp_data(model), + supports_multimodal_raw_input_only=supports_multimodal_raw_input_only( + model + ), + supports_multimodal_encoder_tp_data=supports_multimodal_encoder_tp_data( + model + ), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), is_hybrid=is_hybrid(model), supports_transcription=supports_transcription(model), - supports_transcription_only=(supports_transcription(model) and - model.supports_transcription_only), + supports_transcription_only=( + supports_transcription(model) and model.supports_transcription_only + ), supports_v0_only=supports_v0_only(model), has_noops=has_noops(model), ) class _BaseRegisteredModel(ABC): - @abstractmethod def inspect_model_cls(self) -> _ModelInfo: raise NotImplementedError @@ -434,6 +536,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel): """ Represents a model that has not been imported in the main process. """ + module_name: str class_name: str @@ -445,38 +548,42 @@ def _get_cache_filename(self) -> str: cls_name = f"{self.module_name}-{self.class_name}".replace(".", "-") return f"{cls_name}.json" - def _load_modelinfo_from_cache(self, - module_hash: str) -> _ModelInfo | None: + def _load_modelinfo_from_cache(self, module_hash: str) -> _ModelInfo | None: try: try: - modelinfo_path = self._get_cache_dir( - ) / self._get_cache_filename() + modelinfo_path = self._get_cache_dir() / self._get_cache_filename() with open(modelinfo_path, encoding="utf-8") as file: mi_dict = json.load(file) except FileNotFoundError: - logger.debug(("Cached model info file " - "for class %s.%s not found"), self.module_name, - self.class_name) + logger.debug( + ("Cached model info file for class %s.%s not found"), + self.module_name, + self.class_name, + ) return None if mi_dict["hash"] != module_hash: - logger.debug(("Cached model info file " - "for class %s.%s is stale"), self.module_name, - self.class_name) + logger.debug( + ("Cached model info file for class %s.%s is stale"), + self.module_name, + self.class_name, + ) return None # file not changed, use cached _ModelInfo properties return _ModelInfo(**mi_dict["modelinfo"]) except Exception: - logger.exception(("Cached model info " - "for class %s.%s error. "), self.module_name, - self.class_name) + logger.exception( + ("Cached model info for class %s.%s error. 
"), + self.module_name, + self.class_name, + ) return None - def _save_modelinfo_to_cache(self, mi: _ModelInfo, - module_hash: str) -> None: + def _save_modelinfo_to_cache(self, mi: _ModelInfo, module_hash: str) -> None: """save dictionary json file to cache""" from vllm.model_executor.model_loader.weight_utils import atomic_writer + try: modelinfo_dict = { "hash": module_hash, @@ -485,15 +592,14 @@ def _save_modelinfo_to_cache(self, mi: _ModelInfo, cache_dir = self._get_cache_dir() cache_dir.mkdir(parents=True, exist_ok=True) modelinfo_path = cache_dir / self._get_cache_filename() - with atomic_writer(modelinfo_path, encoding='utf-8') as f: + with atomic_writer(modelinfo_path, encoding="utf-8") as f: json.dump(modelinfo_dict, f, indent=2) except Exception: logger.exception("Error saving model info cache.") @logtime(logger=logger, msg="Registry inspect model class") def inspect_model_cls(self) -> _ModelInfo: - model_path = Path( - __file__).parent / f"{self.module_name.split('.')[-1]}.py" + model_path = Path(__file__).parent / f"{self.module_name.split('.')[-1]}.py" module_hash = None if model_path.exists(): @@ -502,21 +608,26 @@ def inspect_model_cls(self) -> _ModelInfo: mi = self._load_modelinfo_from_cache(module_hash) if mi is not None: - logger.debug(("Loaded model info " - "for class %s.%s from cache"), self.module_name, - self.class_name) + logger.debug( + ("Loaded model info for class %s.%s from cache"), + self.module_name, + self.class_name, + ) return mi else: - logger.debug(("Cache model info " - "for class %s.%s miss. " - "Loading model instead."), self.module_name, - self.class_name) + logger.debug( + ("Cache model info for class %s.%s miss. Loading model instead."), + self.module_name, + self.class_name, + ) # Performed in another process to avoid initializing CUDA mi = _run_in_subprocess( - lambda: _ModelInfo.from_model_cls(self.load_model_cls())) - logger.debug("Loaded model info for class %s.%s", self.module_name, - self.class_name) + lambda: _ModelInfo.from_model_cls(self.load_model_cls()) + ) + logger.debug( + "Loaded model info for class %s.%s", self.module_name, self.class_name + ) # save cache file if module_hash is not None: @@ -535,12 +646,12 @@ def _try_load_model_cls( model: _BaseRegisteredModel, ) -> Optional[type[nn.Module]]: from vllm.platforms import current_platform + current_platform.verify_model_arch(model_arch) try: return model.load_model_cls() except Exception: - logger.exception("Error in loading model architecture '%s'", - model_arch) + logger.exception("Error in loading model architecture '%s'", model_arch) return None @@ -552,8 +663,7 @@ def _try_inspect_model_cls( try: return model.inspect_model_cls() except Exception: - logger.exception("Error in inspecting model architecture '%s'", - model_arch) + logger.exception("Error in inspecting model architecture '%s'", model_arch) return None @@ -588,8 +698,10 @@ def register_model( if model_arch in self.models: logger.warning( "Model architecture %s is already registered, and will be " - "overwritten by the new model class %s.", model_arch, - model_cls) + "overwritten by the new model class %s.", + model_arch, + model_cls, + ) if isinstance(model_cls, str): split_str = model_cls.split(":") @@ -601,8 +713,10 @@ def register_model( elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module): model = _RegisteredModel.from_model_cls(model_cls) else: - msg = ("`model_cls` should be a string or PyTorch model class, " - f"not a {type(model_arch)}") + msg = ( + "`model_cls` should be a string or PyTorch 
model class, " + f"not a {type(model_arch)}" + ) raise TypeError(msg) self.models[model_arch] = model @@ -613,7 +727,8 @@ def _raise_for_unsupported(self, architectures: list[str]): if any(arch in all_supported_archs for arch in architectures): raise ValueError( f"Model architectures {architectures} failed " - "to be inspected. Please check the logs for more details.") + "to be inspected. Please check the logs for more details." + ) for arch in architectures: if arch in _PREVIOUSLY_SUPPORTED_MODELS: @@ -623,14 +738,15 @@ def _raise_for_unsupported(self, architectures: list[str]): f"Model architecture {arch} was supported in vLLM until " f"v{previous_version}, and is not supported anymore. " "Please use an older version of vLLM if you want to " - "use this model architecture.") + "use this model architecture." + ) raise ValueError( f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {all_supported_archs}") + f"Supported architectures: {all_supported_archs}" + ) - def _try_load_model_cls(self, - model_arch: str) -> Optional[type[nn.Module]]: + def _try_load_model_cls(self, model_arch: str) -> Optional[type[nn.Module]]: if model_arch not in self.models: return None @@ -650,8 +766,9 @@ def _try_resolve_transformers( if architecture in _TRANSFORMERS_BACKEND_MODELS: return architecture - auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", - None) or dict() + auto_map: dict[str, str] = ( + getattr(model_config.hf_config, "auto_map", None) or dict() + ) # Make sure that config class is always initialized before model class, # otherwise the model class won't be able to access the config class, @@ -693,7 +810,8 @@ def _try_resolve_transformers( "registered model in the Transformers library (only " "relevant if the model is meant to be in Transformers) " "and 'AutoModel' is not present in the model config's " - "'auto_map' (relevant if the model is custom).") + "'auto_map' (relevant if the model is custom)." + ) if not model_module.is_backend_compatible(): if model_config.model_impl != "transformers": @@ -701,7 +819,8 @@ def _try_resolve_transformers( raise ValueError( f"The Transformers implementation of {architecture!r} " - "is not compatible with vLLM.") + "is not compatible with vLLM." 
+ ) return model_config._get_transformers_backend_cls() @@ -743,8 +862,7 @@ def inspect_model_cls( # Require transformers impl if model_config.model_impl == "transformers": - arch = self._try_resolve_transformers(architectures[0], - model_config) + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_info = self._try_inspect_model_cls(arch) if model_info is not None: @@ -754,11 +872,12 @@ def inspect_model_cls( return (model_info, "Terratorch") # Fallback to transformers impl (after resolving convert_type) - if (all(arch not in self.models for arch in architectures) - and model_config.model_impl == "auto" - and getattr(model_config, "convert_type", "none") == "none"): - arch = self._try_resolve_transformers(architectures[0], - model_config) + if ( + all(arch not in self.models for arch in architectures) + and model_config.model_impl == "auto" + and getattr(model_config, "convert_type", "none") == "none" + ): + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_info = self._try_inspect_model_cls(arch) if model_info is not None: @@ -771,10 +890,11 @@ def inspect_model_cls( return (model_info, arch) # Fallback to transformers impl (before resolving runner_type) - if (all(arch not in self.models for arch in architectures) - and model_config.model_impl == "auto"): - arch = self._try_resolve_transformers(architectures[0], - model_config) + if ( + all(arch not in self.models for arch in architectures) + and model_config.model_impl == "auto" + ): + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_info = self._try_inspect_model_cls(arch) if model_info is not None: @@ -794,8 +914,7 @@ def resolve_model_cls( # Require transformers impl if model_config.model_impl == "transformers": - arch = self._try_resolve_transformers(architectures[0], - model_config) + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_cls = self._try_load_model_cls(arch) if model_cls is not None: @@ -807,11 +926,12 @@ def resolve_model_cls( return (model_cls, arch) # Fallback to transformers impl (after resolving convert_type) - if (all(arch not in self.models for arch in architectures) - and model_config.model_impl == "auto" - and getattr(model_config, "convert_type", "none") == "none"): - arch = self._try_resolve_transformers(architectures[0], - model_config) + if ( + all(arch not in self.models for arch in architectures) + and model_config.model_impl == "auto" + and getattr(model_config, "convert_type", "none") == "none" + ): + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_cls = self._try_load_model_cls(arch) if model_cls is not None: @@ -824,10 +944,11 @@ def resolve_model_cls( return (model_cls, arch) # Fallback to transformers impl (before resolving runner_type) - if (all(arch not in self.models for arch in architectures) - and model_config.model_impl == "auto"): - arch = self._try_resolve_transformers(architectures[0], - model_config) + if ( + all(arch not in self.models for arch in architectures) + and model_config.model_impl == "auto" + ): + arch = self._try_resolve_transformers(architectures[0], model_config) if arch is not None: model_cls = self._try_load_model_cls(arch) if model_cls is not None: @@ -940,14 +1061,15 @@ def is_v1_compatible( return not model_cls.supports_v0_only -ModelRegistry = _ModelRegistry({ - model_arch: - _LazyRegisteredModel( - 
module_name=f"vllm.model_executor.models.{mod_relname}", - class_name=cls_name, - ) - for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items() -}) +ModelRegistry = _ModelRegistry( + { + model_arch: _LazyRegisteredModel( + module_name=f"vllm.model_executor.models.{mod_relname}", + class_name=cls_name, + ) + for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items() + } +) _T = TypeVar("_T") @@ -960,21 +1082,23 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T: # `cloudpickle` allows pickling lambda functions directly import cloudpickle + input_bytes = cloudpickle.dumps((fn, output_filepath)) # cannot use `sys.executable __file__` here because the script # contains relative imports - returned = subprocess.run(_SUBPROCESS_COMMAND, - input=input_bytes, - capture_output=True) + returned = subprocess.run( + _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True + ) # check if the subprocess is successful try: returned.check_returncode() except Exception as e: # wrap raised exception to provide more information - raise RuntimeError(f"Error raised in subprocess:\n" - f"{returned.stderr.decode()}") from e + raise RuntimeError( + f"Error raised in subprocess:\n{returned.stderr.decode()}" + ) from e with open(output_filepath, "rb") as f: return pickle.load(f) @@ -983,6 +1107,7 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T: def _run() -> None: # Setup plugins from vllm.plugins import load_general_plugins + load_general_plugins() fn, output_file = pickle.loads(sys.stdin.buffer.read()) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index a13042a6367c..6408cf7937b2 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -9,16 +9,25 @@ from transformers import RobertaConfig from vllm.config import ModelConfig, VllmConfig -from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, - DispatchPooler, Pooler) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.models.bert import (TOKEN_TYPE_SHIFT, - BertEmbeddingModel, BertModel, - _decode_token_type_ids, - _encode_token_type_ids) -from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, - maybe_prefix) +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + CLSPool, + DispatchPooler, + Pooler, +) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.bert import ( + TOKEN_TYPE_SHIFT, + BertEmbeddingModel, + BertModel, + _decode_token_type_ids, + _encode_token_type_ids, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + maybe_prefix, +) from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel @@ -27,21 +36,23 @@ class RobertaEmbedding(nn.Module): - def __init__(self, config: RobertaConfig): super().__init__() self.size = config.hidden_size - self.word_embeddings = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) self.padding_idx = config.pad_token_id - self.position_embeddings = nn.Embedding(config.max_position_embeddings, - config.hidden_size, - padding_idx=self.padding_idx) - - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, - config.hidden_size) - self.LayerNorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + self.position_embeddings 
= nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + padding_idx=self.padding_idx, + ) + + self.token_type_embeddings = nn.Embedding( + config.type_vocab_size, config.hidden_size + ) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).unsqueeze(0), @@ -49,8 +60,9 @@ def __init__(self, config: RobertaConfig): self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": - raise ValueError("Only 'absolute' position_embedding_type" + - " is supported") + raise ValueError( + "Only 'absolute' position_embedding_type" + " is supported" + ) def forward( self, @@ -79,12 +91,10 @@ def __init__(self, model_config: "ModelConfig"): super().__init__() config = model_config.hf_config head_dtype = model_config.head_dtype - self.dense = nn.Linear(config.hidden_size, - config.hidden_size, - dtype=head_dtype) - self.out_proj = nn.Linear(config.hidden_size, - config.num_labels, - dtype=head_dtype) + self.dense = nn.Linear(config.hidden_size, config.hidden_size, dtype=head_dtype) + self.out_proj = nn.Linear( + config.hidden_size, config.num_labels, dtype=head_dtype + ) def forward(self, x: torch.Tensor) -> torch.Tensor: # CLSPool has already been applied in `pooling` @@ -98,13 +108,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class RobertaEmbeddingModel(BertEmbeddingModel): """A model that uses Roberta to provide embedding functionalities. - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. - Attributes: - model: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -117,34 +127,35 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # Fix Roberta positions here outside of the CUDA graph. # Because we need the to extract the sequences from # input_ids the control flow is data dependent. 
- replace_roberta_positions(input_ids=input_ids, - position_ids=positions, - padding_idx=self.padding_idx) - - return self.model(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) - - def _build_model(self, - vllm_config: VllmConfig, - prefix: str = "") -> Union[BertModel, BertWithRope]: - if (vllm_config.model_config.hf_config.position_embedding_type == - "rotary"): + replace_roberta_positions( + input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx + ) + + return self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) + + def _build_model( + self, vllm_config: VllmConfig, prefix: str = "" + ) -> Union[BertModel, BertWithRope]: + if vllm_config.model_config.hf_config.position_embedding_type == "rotary": return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix) else: - return BertModel(vllm_config=vllm_config, - prefix=prefix, - embedding_class=RobertaEmbedding) + return BertModel( + vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights_list = list(weights) has_roberta_prefix = any( - name.startswith("roberta.") for name, _ in weights_list) + name.startswith("roberta.") for name, _ in weights_list + ) if has_roberta_prefix: # For models with the `roberta.` prefix e.g. # `FacebookAI/roberta-base` @@ -162,26 +173,27 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities. - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. - Attributes: - roberta: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ + Attributes: + roberta: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
+ """ is_pooling_model = True jina_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ - 'emb_ln': "embeddings.LayerNorm", - 'layers': "layer", - 'mixer.Wqkv': "attention.self.qkv_proj", - 'mixer.out_proj': "attention.output.dense", - 'norm1': "attention.output.LayerNorm", - 'mlp.fc1': "intermediate.dense", - 'mlp.fc2': "output.dense", - 'norm2': "output.LayerNorm", - }) + "emb_ln": "embeddings.LayerNorm", + "layers": "layer", + "mixer.Wqkv": "attention.self.qkv_proj", + "mixer.out_proj": "attention.output.dense", + "norm1": "attention.output.LayerNorm", + "mlp.fc1": "intermediate.dense", + "mlp.fc2": "output.dense", + "norm2": "output.LayerNorm", + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -189,32 +201,35 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx: int = vllm_config.model_config.hf_config.pad_token_id self.num_labels = config.num_labels - self.roberta = BertModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "bert"), - embedding_class=RobertaEmbedding) + self.roberta = BertModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=RobertaEmbedding, + ) self.classifier = RobertaClassificationHead(vllm_config.model_config) pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) @@ -231,22 +246,24 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - replace_roberta_positions(input_ids=input_ids, - position_ids=positions, - padding_idx=self.padding_idx) + replace_roberta_positions( + input_ids=input_ids, position_ids=positions, padding_idx=self.padding_idx + ) if token_type_ids is not None: assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) assert input_ids is not None _encode_token_type_ids(input_ids, token_type_ids) - return self.roberta(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + return self.roberta( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + ) -def replace_roberta_positions(input_ids: torch.Tensor, - position_ids: torch.Tensor, - padding_idx: int) -> None: +def replace_roberta_positions( + input_ids: torch.Tensor, position_ids: torch.Tensor, padding_idx: int +) -> None: # Replace position ids because in RoBERTa models # they have to start at padding_idx + 1 and ignore # existing padding tokens diff --git a/vllm/model_executor/models/rvl.py 
b/vllm/model_executor/models/rvl.py index 594d018f6bb6..89150677f3ce 100644 --- a/vllm/model_executor/models/rvl.py +++ b/vllm/model_executor/models/rvl.py @@ -13,14 +13,16 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict -from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor, - LlavaNextProcessingInfo) +from .llava_next import ( + LlavaDummyInputsBuilder, + LlavaNextMultiModalProcessor, + LlavaNextProcessingInfo, +) from .llava_onevision import LlavaOnevisionForConditionalGeneration from .utils import WeightsMapper class RVLProcessingInfo(LlavaNextProcessingInfo): - def get_hf_config(self): return self.ctx.get_hf_config() @@ -29,7 +31,6 @@ def get_hf_processor(self, **kwargs: object): class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) image_token = "<image>" @@ -44,26 +45,24 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) - target_width, target_height = ( - self.info.get_image_size_with_most_features()) + target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), } class RVLMultiModalProjector(nn.Module): - def __init__(self, config): super().__init__() - self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, - eps=1e-06) + self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, eps=1e-06) self.linear_1 = nn.Linear( config.vision_config.hidden_size, config.text_config.hidden_size, @@ -91,7 +90,6 @@ def forward(self, image_feature: torch.Tensor) -> torch.Tensor: dummy_inputs=RVLDummyInputsBuilder, ) class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration): - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ # mapping for new names in checkpoint saved after transformers @@ -101,7 +99,8 @@ class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration): "model.multi_modal_projector.": "multi_modal_projector.", "model.image_newline": "image_newline", "lm_head.": "language_model.lm_head.", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index a217c820fedf..ca33a694a3b6 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only SeedOss model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -37,28 +38,38 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class SeedOssMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -83,8 +94,9 @@ def __init__( prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -95,7 +107,6 @@ def forward(self, x): class SeedOssAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -181,7 +192,6 @@ def forward( class SeedOssDecoderLayer(nn.Module): - def __init__( self, config: SeedOssConfig, @@ -224,10 +234,10 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -240,16 +250,14 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -260,14 +268,16 @@ def forward( "positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, - }) + } +) class SeedOssModel(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer, + ): super().__init__() config = vllm_config.model_config.hf_config @@ -275,8 +285,9 @@ def __init__(self, quant_config = vllm_config.quant_config # TODO (@robertgshaw2): see if this can be moved out - if (cache_config.sliding_window is not None - and hasattr(config, "max_window_layers")): + if cache_config.sliding_window is not None and hasattr( + config, "max_window_layers" + ): assert config.max_window_layers == config.num_hidden_layers, ( "Sliding window for some but all layers is not supported. 
" "This model uses sliding window but `max_window_layers` = {} " @@ -284,14 +295,16 @@ def __init__(self, "to discuss this feature.".format( config.max_window_layers, config.num_hidden_layers, - )) + ) + ) self.config = config self.quant_config = quant_config self.vocab_size = config.vocab_size - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -305,16 +318,18 @@ def __init__(self, decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: decoder_layer_type(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: decoder_layer_type( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -347,15 +362,13 @@ def forward( residual, ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -369,18 +382,19 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -404,8 +418,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -434,25 +447,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = 
""): self.lora_config = lora_config self.quant_config = quant_config - self.model = SeedOssModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = SeedOssModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) else: self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -464,8 +480,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -475,11 +492,9 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 4c60d96c77d7..ee21a03c8525 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -14,28 +14,33 @@ from vllm.attention.layer import MultiHeadAttention from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) -from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy, - resolve_visual_encoder_outputs) +from .vision import ( + VisionEncoderInfo, + VisionFeatureSelectStrategy, + resolve_visual_encoder_outputs, +) class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): - def get_num_image_tokens( self, *, image_width: int, image_height: int, ) -> int: - return self.get_patch_grid_length()**2 + return self.get_patch_grid_length() ** 2 def get_image_size(self) -> int: return self.vision_config.image_size @@ -50,7 +55,6 @@ 
def get_patch_grid_length(self) -> int: # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa class SiglipVisionEmbeddings(nn.Module): - def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config @@ -66,19 +70,20 @@ def __init__(self, config: SiglipVisionConfig): padding="valid", ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.position_embedding = VocabParallelEmbedding( - self.num_positions, self.embed_dim) + self.num_positions, self.embed_dim + ) self.register_buffer( "position_ids", - torch.arange(self.num_positions, dtype=torch.int64).expand( - (1, -1)), + torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)), persistent=False, ) - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, - width: int) -> torch.Tensor: + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: """ This method is an adapted method for SigLIP (due to SigLIP not having class embedding unlike other ViTs) that allows the model to interpolate @@ -103,8 +108,8 @@ class embedding unlike other ViTs) that allows the model to interpolate height, width = height + 0.1, width + 0.1 patch_pos_embed = position_embeddings.reshape( - 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), - dim) + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim + ) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, @@ -115,33 +120,36 @@ class embedding unlike other ViTs) that allows the model to interpolate mode="bicubic", align_corners=False, ) - if (int(height) != patch_pos_embed.shape[-2] - or int(width) != patch_pos_embed.shape[-1]): - raise ValueError("Width or height does not match with " - "the interpolated position embeddings") + if ( + int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1] + ): + raise ValueError( + "Width or height does not match with " + "the interpolated position embeddings" + ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed - def forward(self, - pixel_values: torch.Tensor, - interpolate_pos_encoding: bool = False) -> torch.Tensor: + def forward( + self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False + ) -> torch.Tensor: _, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to( - dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] embeddings = patch_embeds.flatten(2).transpose(1, 2) if interpolate_pos_encoding: - embeddings += self.interpolate_pos_encoding( - embeddings, height, width) + embeddings += self.interpolate_pos_encoding(embeddings, height, width) else: embeddings += self.position_embedding(self.position_ids) return embeddings class SiglipAttention(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -155,9 +163,11 @@ def __init__( self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError(f"embed_dim must be divisible by num_heads (got " - "`embed_dim`: {self.embed_dim} and `num_heads`:" 
- f" {self.num_heads}).") + raise ValueError( + f"embed_dim must be divisible by num_heads (got " + "`embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout @@ -179,8 +189,9 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = MultiHeadAttention(self.num_heads_per_partition, - self.head_dim, self.scale) + self.attn = MultiHeadAttention( + self.num_heads_per_partition, self.head_dim, self.scale + ) def forward( self, @@ -197,7 +208,6 @@ def forward( class SiglipMLP(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -209,15 +219,14 @@ def __init__( self.config = config self.activation_fn = get_act_fn(config.hidden_act) # Special handling for BNB and torchao quantization - if quant_config and quant_config.get_name() in [ - "bitsandbytes", "torchao" - ]: + if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]: quantizable = True else: # For other quantization, we require the hidden size to be a # multiple of 64 - quantizable = (config.hidden_size % 64 == 0 - and config.intermediate_size % 64 == 0) + quantizable = ( + config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0 + ) self.fc1 = ColumnParallelLinear( config.hidden_size, config.intermediate_size, @@ -239,7 +248,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class SiglipEncoderLayer(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -255,15 +263,13 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.self_attn", ) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( config, quant_config=quant_config, prefix=f"{prefix}.mlp", ) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, @@ -284,7 +290,6 @@ def forward( class SiglipEncoder(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -301,12 +306,16 @@ def __init__( else: num_hidden_layers = num_hidden_layers_override - self.layers = nn.ModuleList([ - SiglipEncoderLayer(config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + SiglipEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) def forward( self, @@ -341,12 +350,12 @@ def __init__( self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention self.attention = torch.nn.MultiheadAttention( - config.hidden_size, config.num_attention_heads, batch_first=True) - self.layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) - self.mlp = SiglipMLP(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + config.hidden_size, config.num_attention_heads, batch_first=True + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP( + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: batch_size = hidden_state.shape[0] @@ -363,7 +372,6 @@ def forward(self, hidden_state: torch.Tensor) -> 
torch.Tensor: class SiglipVisionTransformer(nn.Module): - def __init__( self, config: SiglipVisionConfig, @@ -399,13 +407,13 @@ def __init__( require_post_norm = len(self.encoder.layers) == num_hidden_layers if require_post_norm: - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) else: self.post_layernorm = None - self.use_head = (True if not hasattr(config, "vision_use_head") else - config.vision_use_head) + self.use_head = ( + True if not hasattr(config, "vision_use_head") else config.vision_use_head + ) if self.use_head: self.head = SiglipMultiheadAttentionPoolingHead( config=config, @@ -493,8 +501,7 @@ def forward( feature_select_strategy=feature_select_strategy, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -507,8 +514,10 @@ def load_weights(self, weights: Iterable[tuple[str, for name, loaded_weight in weights: # post_layernorm is optional in SiglipVisionModel - if (name.startswith("vision_model.post_layernorm") - and self.vision_model.post_layernorm is None): + if ( + name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None + ): continue # omit layers when num_hidden_layers_override is set @@ -518,21 +527,21 @@ def load_weights(self, weights: Iterable[tuple[str, continue # Check if this is a scale parameter that needs remapping first - if name.endswith( - (".k_scale", ".v_scale", ".q_scale", ".prob_scale")): + if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): # Try to remap the scale name first remapped_name = maybe_remap_kv_scale_name(name, params_dict) if remapped_name is not None and remapped_name in params_dict: # Successfully remapped, use the remapped name param = params_dict[remapped_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(remapped_name) continue # If remapping failed, continue with normal processing - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -543,8 +552,7 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index 5bea5b1daf4d..7cd133d9da1d 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -17,10 +17,13 @@ from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + 
QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -28,23 +31,20 @@ class VisionRotaryEmbedding(nn.Module): - def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - inv_freq = 1.0 / (theta - **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) return freqs class Siglip2VisionEmbeddings(nn.Module): - def __init__(self, config: PretrainedConfig): super().__init__() self.config = config @@ -58,15 +58,13 @@ def __init__(self, config: PretrainedConfig): # siglip2 naflex if self.num_patches > 0: self.patch_embedding = ReplicatedLinear( - input_size=config.num_channels * self.patch_size * - self.patch_size, + input_size=config.num_channels * self.patch_size * self.patch_size, output_size=self.embed_dim, return_bias=False, ) if self.preserve_original_pe: self.position_embedding_size = int(self.num_patches**0.5) - self.position_embedding = nn.Embedding(self.num_patches, - self.embed_dim) + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) else: self.patch_embedding = nn.Conv2d( @@ -77,15 +75,15 @@ def __init__(self, config: PretrainedConfig): padding="valid", ) if self.preserve_original_pe: - self.num_patches = (self.image_size // self.patch_size)**2 - self.position_embedding_size = (self.image_size // - self.patch_size) - self.position_embedding = nn.Embedding(self.num_patches, - self.embed_dim) - - def forward(self, - pixel_values: torch.FloatTensor, - grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor: + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.position_embedding_size = self.image_size // self.patch_size + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: """ Args: pixel_values (`torch.FloatTensor`): @@ -100,36 +98,48 @@ def forward(self, # Apply patch embeddings to already patchified pixel values target_dtype = self.patch_embedding.weight.dtype if isinstance(self.patch_embedding, LinearBase): - patch_embeds = self.patch_embedding( - pixel_values.to(dtype=target_dtype)) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) elif isinstance(self.patch_embedding, nn.Conv2d): pixel_values = pixel_values.view( - -1, self.config.num_channels * self.config.temporal_patch_size, - self.patch_size, self.patch_size) - patch_embeds = self.patch_embedding( - pixel_values.to(dtype=target_dtype)) + -1, + self.config.num_channels * self.config.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) patch_embeds = patch_embeds.reshape(-1, self.embed_dim) if self.preserve_original_pe: assert grid_thws is not None pos_embed_new = torch.zeros_like(patch_embeds) - positional_embeddings = self.position_embedding.weight.reshape( - self.position_embedding_size, self.position_embedding_size, - 
-1).unsqueeze(0).permute(0, 3, 1, 2) + positional_embeddings = ( + self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, -1 + ) + .unsqueeze(0) + .permute(0, 3, 1, 2) + ) cnt = 0 for t, h, w in grid_thws: volume = t * h * w - pe = F.interpolate(positional_embeddings, - size=(h, w), - mode='bicubic', - align_corners=False) + pe = F.interpolate( + positional_embeddings, + size=(h, w), + mode="bicubic", + align_corners=False, + ) pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1) pe = pe[0].repeat(t, 1) - pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, - w // self.hidden_stride, self.hidden_stride, - -1) + pe = pe.reshape( + t, + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + -1, + ) pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(volume, -1) - pos_embed_new[cnt:cnt + volume] = pe + pos_embed_new[cnt : cnt + volume] = pe cnt += volume patch_embeds = patch_embeds + pos_embed_new @@ -143,9 +153,9 @@ def rotate_half(x, interleaved=False): return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange(torch.stack((-x2, x1), dim=-1), - "... d two -> ... (d two)", - two=2) + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) def apply_rotary_emb_torch(x, cos, sin, interleaved=False): @@ -156,15 +166,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False): ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] cos = repeat( - cos, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) sin = repeat( - sin, - "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) return torch.cat( [ - x[..., :ro_dim] * cos + - rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], ], dim=-1, ) @@ -181,13 +191,12 @@ def apply_rotary_pos_emb( sin = sin.chunk(2, dim=-1)[0].contiguous() if is_flash_attn_backend: from flash_attn.layers.rotary import apply_rotary_emb + apply_rotary_emb_func = apply_rotary_emb else: apply_rotary_emb_func = apply_rotary_emb_torch - q_embed = apply_rotary_emb_func(q.float(), cos.float(), - sin.float()).type_as(q) - k_embed = apply_rotary_emb_func(k.float(), cos.float(), - sin.float()).type_as(k) + q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k) return q_embed, k_embed @@ -210,7 +219,8 @@ def __init__( raise ValueError( f"embed_dim must be divisible by num_heads " f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads}).") + f" {self.num_heads})." + ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False @@ -231,37 +241,41 @@ def __init__( prefix=f"{prefix}.out_proj", ) - self.tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) self.num_heads_per_partition = divide(self.num_heads, self.tp_size) self.use_rope = config.use_rope # Detect attention implementation. 
self.attn_backend = get_vit_attn_backend( - head_size=self.head_dim, dtype=torch.get_default_dtype()) + head_size=self.head_dim, dtype=torch.get_default_dtype() + ) self.use_upstream_fa = False - self.attn_backend, self.flash_attn_varlen_func \ - = maybe_get_vit_flash_attn_backend( + self.attn_backend, self.flash_attn_varlen_func = ( + maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, ) + ) if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.ROCM_AITER_FA, }: self.attn_backend = _Backend.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.ROCM_AITER_FA, } def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - position_embeddings: Optional[tuple[torch.Tensor, - torch.Tensor]] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" @@ -270,26 +284,27 @@ def forward( qkv_states, _ = self.qkv_proj(hidden_states) queries, keys, values = qkv_states.chunk(3, dim=-1) - queries = queries.view(seq_length, self.num_heads_per_partition, - self.head_dim) - keys = keys.view(seq_length, self.num_heads_per_partition, - self.head_dim) - values = values.view(seq_length, self.num_heads_per_partition, - self.head_dim) + queries = queries.view(seq_length, self.num_heads_per_partition, self.head_dim) + keys = keys.view(seq_length, self.num_heads_per_partition, self.head_dim) + values = values.view(seq_length, self.num_heads_per_partition, self.head_dim) if self.use_rope: cos, sin = position_embeddings - queries, keys = apply_rotary_pos_emb(queries.unsqueeze(0), - keys.unsqueeze(0), cos, sin, - self.is_flash_attn_backend) + queries, keys = apply_rotary_pos_emb( + queries.unsqueeze(0), + keys.unsqueeze(0), + cos, + sin, + self.is_flash_attn_backend, + ) queries = queries.squeeze(0) keys = keys.squeeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if self.is_flash_attn_backend: attn_output = self.flash_attn_varlen_func( - queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, - max_seqlen).reshape(seq_length, -1) + queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen + ).reshape(seq_length, -1) elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
batch_size = cu_seqlens.shape[0] - 1 @@ -308,13 +323,9 @@ def forward( # (1, num_heads, seq_len, head_dim) q_i, k_i, v_i = [x.transpose(1, 2) for x in (q_i, k_i, v_i)] - output_i = F.scaled_dot_product_attention(q_i, - k_i, - v_i, - dropout_p=0.0) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) # (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim) - output_i = output_i.transpose(1, 2).reshape( - end_idx - start_idx, -1) + output_i = output_i.transpose(1, 2).reshape(end_idx - start_idx, -1) outputs.append(output_i) attn_output = torch.cat(outputs, dim=0) @@ -323,7 +334,6 @@ def forward( class Siglip2MLP(nn.Module): - def __init__( self, config: Siglip2VisionConfig, @@ -357,7 +367,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Siglip2EncoderLayer(nn.Module): - def __init__( self, config: Siglip2VisionConfig, @@ -367,21 +376,27 @@ def __init__( ): super().__init__() self.embed_dim = config.hidden_size - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.self_attn = Siglip2Attention(config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - use_data_parallel=use_data_parallel) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.mlp = Siglip2MLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=use_data_parallel) - - def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]: + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Siglip2MLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor, + ) -> tuple[torch.FloatTensor]: """ Args: hidden_states: Input tensor of shape (batch, seq_len, embed_dim). @@ -391,9 +406,11 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn(hidden_states=hidden_states, - cu_seqlens=cu_seqlens, - position_embeddings=position_embeddings) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings, + ) hidden_states = residual + hidden_states residual = hidden_states @@ -405,7 +422,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, class Siglip2Encoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Siglip2EncoderLayer`]. 
Args: @@ -421,16 +438,21 @@ def __init__( ): super().__init__() self.config = config - self.layers = nn.ModuleList([ - Siglip2EncoderLayer(config, - quant_config=quant_config, - prefix=f"{prefix}.layers.{idx}", - use_data_parallel=use_data_parallel) - for idx in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Siglip2EncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}", + use_data_parallel=use_data_parallel, + ) + for idx in range(config.num_hidden_layers) + ] + ) self.rotary_pos_emb = VisionRotaryEmbedding( - config.hidden_size // config.num_attention_heads // 2) + config.hidden_size // config.num_attention_heads // 2 + ) self.patch_size = config.patch_size self.hidden_stride = config.hidden_stride self.window_size = config.window_size @@ -439,7 +461,7 @@ def __init__( self.fullatt_block_indexes = None else: self.fullatt_block_indexes = [ - int(i) for i in config.fullatt_block_indexes.split('|') + int(i) for i in config.fullatt_block_indexes.split("|") ] # copied from qwen2.5_vl @@ -465,8 +487,7 @@ def rot_pos_emb(self, grid_thw): ) wpos_ids = wpos_ids.permute(0, 2, 1, 3) wpos_ids = wpos_ids.flatten() - pos_ids.append( - torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) @@ -478,8 +499,9 @@ def get_window_index(self, grid_thw): cu_window_seqlens: list = [0] window_index_id = 0 # patch (after merge) number in each window - vit_merger_window_size = (self.window_size // self.hidden_stride // - self.patch_size) + vit_merger_window_size = ( + self.window_size // self.hidden_stride // self.patch_size + ) for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( @@ -487,7 +509,8 @@ def get_window_index(self, grid_thw): grid_w // self.hidden_stride, ) index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w) + grid_t, llm_grid_h, llm_grid_w + ) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size @@ -510,8 +533,9 @@ def get_window_index(self, grid_thw): index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum( - 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index = torch.cat(window_index, dim=0) @@ -525,10 +549,10 @@ def forward( ) -> torch.Tensor: r""" Args: - inputs_embeds: Input tensor of shape + inputs_embeds: Input tensor of shape (batch_size, sequence_length, hidden_size). Embedded representation of the input tokens. - grid_thws: Grid tensor of shape (num_patches, 3) + grid_thws: Grid tensor of shape (num_patches, 3) containing grid dimensions. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
@@ -544,11 +568,13 @@ def forward( seq_len, _ = inputs_embeds.size() inputs_embeds = inputs_embeds.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) inputs_embeds = inputs_embeds[window_index, :, :] inputs_embeds = inputs_embeds.reshape(seq_len, -1) rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) @@ -572,23 +598,21 @@ def forward( hidden_states = inputs_embeds for index, block in enumerate(self.layers): - if (not self.fullatt_block_indexes - or index in self.fullatt_block_indexes): + if not self.fullatt_block_indexes or index in self.fullatt_block_indexes: cu_seqlens_tmp = cu_seqlens else: cu_seqlens_tmp = cu_window_seqlens - hidden_states = block(hidden_states, cu_seqlens_tmp, - position_embeddings) + hidden_states = block(hidden_states, cu_seqlens_tmp, position_embeddings) hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1) return hidden_states class Siglip2VisionTransformer(nn.Module): - def __init__( self, config: Siglip2VisionConfig, @@ -601,12 +625,13 @@ def __init__( embed_dim = config.hidden_size self.embeddings = Siglip2VisionEmbeddings(config) - self.encoder = Siglip2Encoder(config, - quant_config=quant_config, - prefix=f"{prefix}.encoder", - use_data_parallel=use_data_parallel) - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + self.encoder = Siglip2Encoder( + config, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + use_data_parallel=use_data_parallel, + ) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) def forward( self, @@ -627,7 +652,6 @@ def forward( class Siglip2NavitModel(torch.nn.Module): - def __init__( self, config: Siglip2VisionConfig, @@ -641,7 +665,8 @@ def __init__( config, quant_config=quant_config, prefix=f"{prefix}.vision_model", - use_data_parallel=use_data_parallel) + use_data_parallel=use_data_parallel, + ) def forward( self, @@ -653,8 +678,7 @@ def forward( grid_thws=grid_thws, ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -665,7 +689,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -676,8 +700,7 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 
af99e4953b1a..f0f6917ddf91 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -21,17 +21,30 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.models.intern_vit import (InternVisionModel, - InternVisionPatchModel) +from vllm.model_executor.models.intern_vit import ( + InternVisionModel, + InternVisionPatchModel, +) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -40,9 +53,9 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix -IMG_START = '<img>' -IMG_END = '</img>' -IMG_CONTEXT = '<IMG_CONTEXT>' +IMG_START = "<img>" +IMG_END = "</img>" +IMG_CONTEXT = "<IMG_CONTEXT>" IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) @@ -57,6 +70,7 @@ class SkyworkR1VImagePixelInputs(TensorSchema): - w: Width - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" pixel_values_flat: Annotated[ @@ -75,9 +89,10 @@ class SkyworkR1VImageEmbeddingInputs(TensorSchema): Dimensions: - ni: Number of images - ifs: Image feature size - - hs: Hidden size (must match the hidden size of language model + - hs: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[ @@ -86,20 +101,24 @@ class SkyworkR1VImageEmbeddingInputs(TensorSchema): ] -SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs, - SkyworkR1VImageEmbeddingInputs] +SkyworkR1VImageInputs = Union[ + SkyworkR1VImagePixelInputs, SkyworkR1VImageEmbeddingInputs +] # adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/ def build_transform(input_size: int): MEAN, STD = IMAGENET_MEAN, IMAGENET_STD - return T.Compose([ - T.Lambda(lambda img: convert_image_mode(img, 'RGB')), - T.Resize((input_size, input_size), - interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=MEAN, std=STD) - ]) + return T.Compose( + [ + T.Lambda(lambda img: convert_image_mode(img, "RGB")), + T.Resize( + (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC + ), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) # adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/ @@ -111,7 +130,7 @@ def find_closest_aspect_ratio( height: int, image_size: int, ) -> tuple[int, int]: - best_ratio_diff = float('inf') + 
best_ratio_diff = float("inf") best_ratio = (1, 1) area = width * height for ratio in target_ratios: @@ -146,10 +165,13 @@ def get_skyworkr1v_target_ratios( min_num: int, max_num: int, ) -> list[tuple[int, int]]: - target_ratios = {(i, j) - for n in range(min_num, max_num + 1) - for i in range(1, n + 1) - for j in range(1, n + 1) if min_num <= i * j <= max_num} + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } return sorted(target_ratios, key=lambda x: x[0] * x[1]) @@ -206,10 +228,12 @@ def dynamic_preprocess_skyworkr1v( resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): - box = ((i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size) + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) @@ -285,7 +309,8 @@ def __init__( assert isinstance(dynamic_image_size, bool) self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.image_size = image_size self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch @@ -314,14 +339,18 @@ def resolve_min_max_num( dynamic_image_size: Optional[bool] = None, use_thumbnail: Optional[bool] = None, ) -> tuple[int, int]: - min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch - is None else min_dynamic_patch) - max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch - is None else max_dynamic_patch) - dynamic_image_size = (self.dynamic_image_size if dynamic_image_size - is None else dynamic_image_size) - use_thumbnail = (self.use_thumbnail - if use_thumbnail is None else use_thumbnail) + min_dynamic_patch = ( + self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch + ) + max_dynamic_patch = ( + self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch + ) + dynamic_image_size = ( + self.dynamic_image_size + if dynamic_image_size is None + else dynamic_image_size + ) + use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail return resolve_skyworkr1v_min_max_num( min_dynamic_patch=min_dynamic_patch, @@ -388,7 +417,8 @@ def _images_to_pixel_values_lst( min_num=min_num, max_num=max_num, use_thumbnail=self.use_thumbnail, - ) for image in images + ) + for image in images ] def __call__( @@ -419,10 +449,10 @@ def __call__( dynamic_image_size=dynamic_image_size, ) image_inputs = { - "pixel_values_flat": - torch.cat(pixel_values_lst), - "image_num_patches": - torch.tensor([len(item) for item in pixel_values_lst]), + "pixel_values_flat": torch.cat(pixel_values_lst), + "image_num_patches": torch.tensor( + [len(item) for item in pixel_values_lst] + ), } for pixel_values in pixel_values_lst: @@ -431,7 +461,7 @@ def __call__( image_repl = self.get_image_repl(feature_size, num_patches) - text = [t.replace('<image>', image_repl.full, 1) for t in text] + text = [t.replace("<image>", image_repl.full, 1) for t in text] text_inputs = self.tokenizer(text) @@ -441,7 +471,6 @@ def __call__( class 
SkyworkR1VProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor: return self.ctx.init_processor( SkyworkR1VProcessor, @@ -485,8 +514,7 @@ def get_image_size_with_most_features(self) -> ImageSize: ) if feat_size > largest_feature_size: largest_feature_size = feat_size - largest_feature_pinpoint = ImageSize(width=width, - height=height) + largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") @@ -494,9 +522,7 @@ def get_image_size_with_most_features(self) -> ImageSize: return largest_feature_pinpoint -class SkyworkR1VDummyInputsBuilder( - BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]): - +class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -508,24 +534,22 @@ def get_dummy_mm_data( mm_counts: Mapping[str, int], mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } -class SkyworkR1VMultiModalProcessor( - BaseMultiModalProcessor[SkyworkR1VProcessingInfo]): - +class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[SkyworkR1VProcessingInfo]): def _call_hf_processor( self, prompt: str, @@ -560,7 +584,8 @@ def _get_mm_fields_config( return dict( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_patches), + "image", image_num_patches + ), image_num_patches=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), image_token_id=MultiModalFieldConfig.shared("image", num_images), @@ -588,7 +613,8 @@ def _get_prompt_updates( def get_replacement_skyworkr1v(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): feature_size = images.get_feature_size(item_idx) @@ -618,7 +644,8 @@ def get_replacement_skyworkr1v(item_idx: int): @MULTIMODAL_REGISTRY.register_processor( SkyworkR1VMultiModalProcessor, info=SkyworkR1VProcessingInfo, - dummy_inputs=SkyworkR1VDummyInputsBuilder) + dummy_inputs=SkyworkR1VDummyInputsBuilder, +) class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): merge_by_field_config = True @@ -644,12 +671,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: patch_size = config.vision_config.patch_size self.patch_size = patch_size self.num_image_token = int( - (image_size // patch_size)**2 * (config.downsample_ratio**2)) + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) self.downsample_ratio = config.downsample_ratio self.ps_version = config.ps_version self.llm_arch_name = config.text_config.architectures[0] - self.is_mono = self.llm_arch_name == 'SkyworkLM2VEForCausalLM' + self.is_mono = self.llm_arch_name == 
"SkyworkLM2VEForCausalLM" self.vision_model = self._init_vision_model( config, quant_config=quant_config, @@ -668,18 +696,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.img_context_token_id = None self.visual_token_mask = None self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) - def _patch_quant_config(self, config: PretrainedConfig, - quant_config: QuantizationConfig): + def _patch_quant_config( + self, config: PretrainedConfig, quant_config: QuantizationConfig + ): # the awq models from OpenGVLab missing `modules_to_not_convert` # patch the quant_config to add `modules_to_not_convert` back if isinstance(quant_config, AWQConfig): text_config = config.text_config - llm_quant_config = getattr(text_config, "quantization_config", - None) - if (not quant_config.modules_to_not_convert) and \ - (llm_quant_config is not None): + llm_quant_config = getattr(text_config, "quantization_config", None) + if (not quant_config.modules_to_not_convert) and ( + llm_quant_config is not None + ): quant_config.modules_to_not_convert.append("vision_model") def _init_vision_model( @@ -693,8 +723,9 @@ def _init_vision_model( if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 + num_hidden_layers = ( + config.vision_config.num_hidden_layers + vision_feature_layer + 1 + ) else: num_hidden_layers = vision_feature_layer + 1 @@ -712,15 +743,14 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: llm_hidden_size = config.text_config.hidden_size return nn.Sequential( - nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), - ReplicatedLinear(vit_hidden_size * - int(1 / self.downsample_ratio)**2, - llm_hidden_size, - return_bias=False), + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), + ReplicatedLinear( + vit_hidden_size * int(1 / self.downsample_ratio) ** 2, + llm_hidden_size, + return_bias=False, + ), nn.GELU(), - ReplicatedLinear(llm_hidden_size, - llm_hidden_size, - return_bias=False), + ReplicatedLinear(llm_hidden_size, llm_hidden_size, return_bias=False), ) def pixel_shuffle(self, x, scale_factor=0.5): @@ -729,9 +759,13 @@ def pixel_shuffle(self, x, scale_factor=0.5): x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale x = x.permute(0, 2, 1, 3).contiguous() - x = x.view(n, int(h * scale_factor), int(w * scale_factor), - int(c / (scale_factor * scale_factor))) - if self.ps_version == 'v1': + x = x.view( + n, + int(h * scale_factor), + int(w * scale_factor), + int(c / (scale_factor * scale_factor)), + ) + if self.ps_version == "v1": pass else: x = x.permute(0, 2, 1, 3).contiguous() @@ -741,17 +775,16 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = self.vision_model(pixel_values=pixel_values) vit_embeds = vit_embeds[:, 1:, :] - h = w = int(vit_embeds.shape[1]**0.5) + h = w = int(vit_embeds.shape[1] ** 0.5) vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) - vit_embeds = self.pixel_shuffle(vit_embeds, - scale_factor=self.downsample_ratio) - vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, - vit_embeds.shape[-1]) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) 
vit_embeds = self.mlp1(vit_embeds) return vit_embeds def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: + self, **kwargs: object + ) -> Optional[SkyworkR1VImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) image_num_patches = kwargs.pop("image_num_patches", None) image_embeds = kwargs.pop("image_embeds", None) @@ -777,7 +810,8 @@ def _parse_and_validate_image_input( resolve_bindings={ "h": self.config.vision_config.image_size, "w": self.config.vision_config.image_size, - }) + }, + ) raise AssertionError("This line should be unreachable.") @@ -796,14 +830,14 @@ def _process_image_input( # Only one image in the current batch if len(num_patches) == 1: - return image_embeds.view( - -1, self.config.text_config.hidden_size).unsqueeze(0) + return image_embeds.view(-1, self.config.text_config.hidden_size).unsqueeze( + 0 + ) # NOTE: Image embeddings are split into separate tensors for each image # by the size of each embedding. feature_size = image_embeds.shape[1] - image_embeds = image_embeds.view(-1, - self.config.text_config.hidden_size) + image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size) image_feature_sizes = [ num_patches * feature_size for num_patches in num_patches ] @@ -811,16 +845,16 @@ def _process_image_input( def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: if self.is_mono: - self.visual_token_mask = ( - input_ids == self.img_context_token_id).reshape(-1, 1) + self.visual_token_mask = (input_ids == self.img_context_token_id).reshape( + -1, 1 + ) else: self.visual_token_mask = None def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -835,8 +869,7 @@ def get_input_embeddings( is_multimodal: Optional[torch.Tensor] = None, handle_oov_mm_token: bool = False, ) -> torch.Tensor: - if multimodal_embeddings is not None and len( - multimodal_embeddings) > 0: + if multimodal_embeddings is not None and len(multimodal_embeddings) > 0: self._set_visual_token_mask(input_ids) # This is to satisfy the type checker for each overload @@ -858,7 +891,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> IntermediateTensors: - if intermediate_tensors is not None: input_ids = None inputs_embeds = None @@ -872,8 +904,7 @@ def forward( # Only required if the model is mono-architecture if self.visual_token_mask is not None: - forward_kwargs.update( - {"visual_token_mask": self.visual_token_mask}) + forward_kwargs.update({"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None hidden_states = self.language_model.model(**forward_kwargs) @@ -885,13 +916,20 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: skip_prefixes = [ - "action_embed", "temporal_embed", "track_embed", - "track_embed_decoder", "box_token", "cg_criterion", "cg_model", - "loc_encoder", "loc_decoder", "sam", "temporal_token", - "track_token" + "action_embed", + "temporal_embed", + "track_embed", + "track_embed_decoder", + "box_token", + "cg_criterion", + "cg_model", + "loc_encoder", 
+ "loc_decoder", + "sam", + "temporal_token", + "track_token", ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/smolvlm.py b/vllm/model_executor/models/smolvlm.py index 2adfad67152b..1800330c8235 100644 --- a/vllm/model_executor/models/smolvlm.py +++ b/vllm/model_executor/models/smolvlm.py @@ -8,22 +8,18 @@ from vllm.config import VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY -# yapf: disable from .idefics3 import Idefics3DummyInputsBuilder as SmolVLMDummyInputsBuilder -from .idefics3 import Idefics3ForConditionalGeneration +from .idefics3 import Idefics3ForConditionalGeneration, Idefics3ProcessingInfo from .idefics3 import Idefics3MultiModalProcessor as SmolVLMMultiModalProcessor -from .idefics3 import Idefics3ProcessingInfo - -# yapf: enable class SmolVLMProcessingInfo(Idefics3ProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor: return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs) def _get_image_token( - self, processor: Optional[SmolVLMProcessor]) -> tuple[str, str]: + self, processor: Optional[SmolVLMProcessor] + ) -> tuple[str, str]: if processor is None: processor = self.get_hf_processor() image_token = processor.image_token @@ -32,11 +28,12 @@ def _get_image_token( return image_token, fake_image_token, global_image_token -@MULTIMODAL_REGISTRY.register_processor(SmolVLMMultiModalProcessor, - info=SmolVLMProcessingInfo, - dummy_inputs=SmolVLMDummyInputsBuilder) +@MULTIMODAL_REGISTRY.register_processor( + SmolVLMMultiModalProcessor, + info=SmolVLMProcessingInfo, + dummy_inputs=SmolVLMDummyInputsBuilder, +) class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__( vllm_config=vllm_config, diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index c5b82b0ca4a0..5abcb47c6e25 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -37,26 +37,37 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class 
SolarMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -82,8 +93,9 @@ def __init__( prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." + ) self.act_fn = SiluAndMul() def forward(self, x): @@ -94,7 +106,6 @@ def forward(self, x): class SolarAttention(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -183,7 +194,6 @@ def forward( class SolarDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -197,21 +207,24 @@ def __init__( rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): - rope_scaling["original_max_position_embeddings"] \ - = config.original_max_position_embeddings - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) + config, "bias", False + ) self.self_attn = SolarAttention( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -228,10 +241,10 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -244,23 +257,20 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile class SolarModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -271,12 +281,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size - if get_pp_group().is_first_rank or 
(config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -299,9 +313,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -328,8 +342,7 @@ def forward( bskcn_h_2 = None bskcn_r_1 = None bskcn_r_2 = None - bskcn_tv = (self.config.bskcn_tv[0] - if self.training else self.config.bskcn_tv[1]) + bskcn_tv = self.config.bskcn_tv[0] if self.training else self.config.bskcn_tv[1] for i in range(self.start_layer, self.end_layer): if i in self.config.bskcn_1: @@ -339,12 +352,10 @@ def forward( bskcn_h_2 = hidden_states.clone() bskcn_r_2 = residual.clone() if i in self.config.bskcn_3: - hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * ( - 1 - bskcn_tv) + hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * (1 - bskcn_tv) residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv) if i in self.config.bskcn_4: - hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * ( - 1 - bskcn_tv) + hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * (1 - bskcn_tv) residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv) layer = self.layers[i] hidden_states, residual = layer( @@ -354,16 +365,14 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -375,14 +384,15 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): # Loading kv cache quantization scales param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue @@ -415,8 +425,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -466,7 +475,8 @@ def __init__(self, *, 
vllm_config: VllmConfig, prefix: str = ""): padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) @@ -474,14 +484,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -493,15 +504,15 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index e4dfe8d5a9a3..79ed00183344 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -21,6 +21,7 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -33,43 +34,56 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + 
make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class StablelmMLP(nn.Module): - - def __init__(self, - config: StableLmConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: StableLmConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, [config.intermediate_size] * 2, + config.hidden_size, + [config.intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -80,12 +94,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class StablelmAttention(nn.Module): - - def __init__(self, - config: StableLmConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: StableLmConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -102,33 +117,39 @@ def __init__(self, # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_key_value_heads == 0 - self.num_key_value_heads = max( - 1, self.total_num_key_value_heads // tp_size) + self.num_key_value_heads = max(1, self.total_num_key_value_heads // tp_size) self.head_dim = self.hidden_size // self.total_num_heads self.max_position_embeddings = config.max_position_embeddings self.partial_rotary_factor = getattr( - config, "rope_pct", getattr(config, "partial_rotary_factor", 1)) + config, "rope_pct", getattr(config, "partial_rotary_factor", 1) + ) self.scaling = self.head_dim**-0.5 self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim self.qkv_bias = getattr(config, "use_qkv_bias", False) if (self.head_dim * self.num_heads * tp_size) != self.hidden_size: - raise ValueError(f"hidden_size must be divisible by num_heads " - f"(got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads}).") - - self.qkv_proj = QKVParallelLinear(self.hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_key_value_heads, - self.qkv_bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj") - self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + raise ValueError( + f"hidden_size must be divisible by num_heads " + f"(got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_key_value_heads, + self.qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -136,13 +157,15 @@ def __init__(self, base=self.config.rope_theta, partial_rotary_factor=self.partial_rotary_factor, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_key_value_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -158,7 +181,6 @@ def forward( class StablelmDecoderLayer(nn.Module): - def __init__( self, config: StableLmConfig, @@ -167,16 +189,13 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - self.self_attn = StablelmAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = StablelmAttention( + config, cache_config, quant_config, prefix=f"{prefix}.self_attn" + ) self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp") - norm_eps = getattr(config, "norm_eps", - getattr(config, "layer_norm_eps", 1e-05)) + norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) def forward( self, @@ -202,7 +221,6 @@ def forward( class StableLMEpochModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -219,15 +237,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: StablelmDecoderLayer( - config, cache_config, quant_config, prefix=prefix), + config, cache_config, quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) - norm_eps = getattr(config, "norm_eps", - getattr(config, "layer_norm_eps", 1e-05)) + norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -254,8 +272,7 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -267,7 +284,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - 
for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -287,32 +304,34 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class StablelmForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = StableLMEpochModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.lm_head") + self.model = StableLMEpochModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.lm_head", + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -324,8 +343,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -335,7 +355,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 7f379ab95a03..ec894140c3bf 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -19,7 +19,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" PyTorch Starcoder2 model.""" +"""PyTorch Starcoder2 model.""" + from collections.abc import Iterable from itertools import islice from typing import Optional, Union @@ -33,31 +34,43 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) + default_weight_loader, + maybe_remap_kv_scale_name, +) from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + AutoWeightsLoader, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) class Starcoder2Attention(nn.Module): - - def __init__(self, - config: Starcoder2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.config = config @@ -107,13 +120,15 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -129,11 +144,12 @@ def forward( class Starcoder2MLP(nn.Module): - - def __init__(self, - config: Starcoder2Config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Starcoder2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.c_fc = ColumnParallelLinear( config.hidden_size, @@ -159,25 +175,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Starcoder2DecoderLayer(nn.Module): - - def __init__(self, - config: Starcoder2Config, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, - cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") - self.mlp = 
Starcoder2MLP(config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_epsilon) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.norm_epsilon) + self.self_attn = Starcoder2Attention( + config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Starcoder2MLP( + config, quant_config=quant_config, prefix=f"{prefix}.mlp" + ) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.norm_epsilon + ) def forward( self, @@ -204,7 +223,6 @@ def forward( @support_torch_compile class Starcoder2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -219,7 +237,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vocab_size, config.hidden_size, quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") + prefix=f"{prefix}.embed_tokens", + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Starcoder2DecoderLayer( @@ -228,9 +247,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.layers", ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -257,8 +276,7 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -269,7 +287,7 @@ def load_weights(self, weights: Iterable[tuple[str, params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -286,22 +304,21 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Starcoder2ForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.model = Starcoder2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Starcoder2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: @@ -316,10 +333,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 
quant_config=quant_config, prefix=f"{prefix}.lm_head", ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -331,8 +350,9 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits( @@ -342,13 +362,13 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None), + skip_prefixes=( + ["lm_head.weight"] if self.config.tie_word_embeddings else None + ), ) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 960813822139..2099055e641c 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Inference-only Jurassic model.""" + from collections.abc import Iterable from itertools import islice from typing import Any, Optional @@ -11,60 +12,77 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import 
(PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) class FusedMoEBlock(nn.Module): - - def __init__(self, - config: ModelConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + config: ModelConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() if self.tp_size > config.moe_num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.moe_num_experts}.") - - self.experts = FusedMoE(num_experts=config.moe_num_experts, - top_k=config.moe_top_k, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_expert_weight, - quant_config=quant_config, - prefix=f"{prefix}.experts") - self.gate = ReplicatedLinear(config.hidden_size, - config.moe_num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + f"the number of experts {config.moe_num_experts}." + ) + + self.experts = FusedMoE( + num_experts=config.moe_num_experts, + top_k=config.moe_top_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_expert_weight, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + self.gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_shape = hidden_states.shape @@ -73,17 +91,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) class Step3TextMLP(nn.Module): - def __init__( self, hidden_size: int, @@ -94,18 +111,23 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() self.hidden_size = hidden_size @@ -117,7 +139,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Step3TextAttention(nn.Module): - def __init__( self, hidden_size: int, @@ -142,8 +163,9 @@ def __init__( self.num_heads = self.total_num_heads // tp_size if num_kv_heads != 1: - raise ValueError(f"Step3TextAttention num_kv_heads must be 1, " - f"but got {num_kv_heads}.") + raise ValueError( + f"Step3TextAttention num_kv_heads must be 1, but got {num_kv_heads}." + ) self.num_kv_heads = num_kv_heads self.head_dim = head_dim @@ -173,21 +195,26 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.wq", ) - self.rotary_emb = get_rope(self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embedding, - base=rope_theta, - rope_scaling=rope_scaling) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embedding, + base=rope_theta, + rope_scaling=rope_scaling, + ) scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, - self.head_dim, - scaling, - self.num_kv_heads, - cache_config=cache_config, - prefix=f"{prefix}.attn") - - def forward(self, positions: torch.Tensor, - hidden_states: torch.Tensor) -> torch.Tensor: + self.attn = Attention( + self.num_heads, + self.head_dim, + scaling, + self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, positions: torch.Tensor, hidden_states: torch.Tensor + ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q = self.inter_norm(q) @@ -199,12 +226,13 @@ def forward(self, positions: torch.Tensor, class Step3TextDecoderLayer(nn.Module): - - def __init__(self, - config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() config = config.hf_config self.hidden_size = config.hidden_size @@ -222,59 +250,61 @@ def __init__(self, share_q_dim=config.share_q_dim, rope_theta=config.rope_theta, rope_scaling=rope_scaling, - prefix=f"{prefix}.self_attn") + prefix=f"{prefix}.self_attn", + ) layer_idx = int(prefix.split("layers.")[1].split(".")[0]) moe_layers_enum = getattr(config, "moe_layers_enum", None) if moe_layers_enum is not None: - moe_layers_idx = [ - int(i) for i in moe_layers_enum.strip().split(',') - ] + moe_layers_idx = [int(i) for i in moe_layers_enum.strip().split(",")] else: # Default to 1dense. 
moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] if layer_idx in moe_layers_idx: - self.moe = FusedMoEBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.moe") + self.moe = FusedMoEBlock( + config=config, quant_config=quant_config, prefix=f"{prefix}.moe" + ) self.share_expert = Step3TextMLP( hidden_size=self.hidden_size, intermediate_size=config.share_expert_dim, hidden_act="silu", quant_config=quant_config, - prefix=f"{prefix}.share_expert") + prefix=f"{prefix}.share_expert", + ) self.use_moe = True else: - self.mlp = Step3TextMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act="silu", - quant_config=quant_config, - prefix=f"{prefix}.mlp") + self.mlp = Step3TextMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act="silu", + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) self.use_moe = False - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( - self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor] + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) if self.use_moe: share_output = self.share_expert(hidden_states) @@ -288,7 +318,6 @@ def forward( @support_torch_compile class Step3TextModel(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config @@ -297,8 +326,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.vocab_size = config.vocab_size self.config = config - if get_pp_group().is_first_rank or (config.tie_word_embeddings - and get_pp_group().is_last_rank): + if get_pp_group().is_first_rank or ( + config.tie_word_embeddings and get_pp_group().is_last_rank + ): self.embed_tokens = VocabParallelEmbedding( self.vocab_size, config.hidden_size, @@ -308,11 +338,12 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Step3TextDecoderLayer(config=vllm_config. 
- model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Step3TextDecoderLayer( + config=vllm_config.model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: @@ -320,9 +351,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: else: self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], config.hidden_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -349,17 +380,18 @@ def forward( hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual, - }) + return IntermediateTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class Step3TextForCausalLM(nn.Module, SupportsPP): - def __init__( self, *, @@ -383,48 +415,65 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None): - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ): + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: qkv_params_mapping = [ # (param_name, shard_name, relative_start_idx, relative_end_idx) - (".qkv_proj", ".q_proj", 0, self.config.share_q_dim / - (self.config.share_q_dim + self.config.head_dim * 2)), - (".qkv_proj", ".k_proj", self.config.share_q_dim / - (self.config.share_q_dim + self.config.head_dim * 2), - (self.config.share_q_dim + self.config.head_dim) / - (self.config.share_q_dim + self.config.head_dim * 2)), - (".qkv_proj", ".v_proj", - (self.config.share_q_dim + self.config.head_dim) / - (self.config.share_q_dim + self.config.head_dim * 2), - (self.config.share_q_dim + 
self.config.head_dim * 2) / - (self.config.share_q_dim + self.config.head_dim * 2)), + ( + ".qkv_proj", + ".q_proj", + 0, + self.config.share_q_dim + / (self.config.share_q_dim + self.config.head_dim * 2), + ), + ( + ".qkv_proj", + ".k_proj", + self.config.share_q_dim + / (self.config.share_q_dim + self.config.head_dim * 2), + (self.config.share_q_dim + self.config.head_dim) + / (self.config.share_q_dim + self.config.head_dim * 2), + ), + ( + ".qkv_proj", + ".v_proj", + (self.config.share_q_dim + self.config.head_dim) + / (self.config.share_q_dim + self.config.head_dim * 2), + (self.config.share_q_dim + self.config.head_dim * 2) + / (self.config.share_q_dim + self.config.head_dim * 2), + ), ] stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -437,20 +486,19 @@ def load_weights(self, weights: Iterable[tuple[str, expert_params_mapping = [ (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"), (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"), - (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2") + (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2"), ] - disable_moe_stacked_params = [ - data[1] for data in expert_params_mapping - ] + disable_moe_stacked_params = [data[1] for data in expert_params_mapping] for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if any(disable_moe_stacked_param in name - for disable_moe_stacked_param in - disable_moe_stacked_params): + if any( + disable_moe_stacked_param in name + for disable_moe_stacked_param in disable_moe_stacked_params + ): continue name = name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): @@ -470,23 +518,30 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue # Skip loading extra bias for GPTQ models. 
- if ((name.endswith(".bias") or name.endswith("_bias")) - and name not in params_dict): + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader for expert_id in range(loaded_weight.shape[0]): loaded_weight_expert = loaded_weight[expert_id] - weight_loader(param, - loaded_weight_expert, - name, - shard_id=shard_id, - expert_id=expert_id) + weight_loader( + param, + loaded_weight_expert, + name, + shard_id=shard_id, + expert_id=expert_id, + ) loaded_params.add(name) break else: - for (param_name, weight_name, start_idx, - end_idx) in qkv_params_mapping: + for ( + param_name, + weight_name, + start_idx, + end_idx, + ) in qkv_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -496,8 +551,9 @@ def load_weights(self, weights: Iterable[tuple[str, dim = param.shape[param.output_dim] begin_idx = int(start_idx * dim) end_idx = int(end_idx * dim) - param_slice = param.narrow(param.output_dim, begin_idx, - end_idx - begin_idx) + param_slice = param.narrow( + param.output_dim, begin_idx, end_idx - begin_idx + ) param_slice.copy_(loaded_weight) loaded_params.add(name) break @@ -505,8 +561,9 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index da507e0d9732..c4033dd12558 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -20,25 +20,39 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import ImageSize, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.tokenizer import AnyTokenizer from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) from .vision import 
run_dp_sharded_vision_model @@ -54,8 +68,7 @@ class Step3VLImageEmbeddingInputs(TypedDict): image_embeds: torch.Tensor -Step3VLImageInputs = Union[Step3VLImagePixelInputs, - Step3VLImageEmbeddingInputs] +Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs] ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None] @@ -63,31 +76,42 @@ class Step3VLImageEmbeddingInputs(TypedDict): class Step3VisionProcessor: - def __init__(self, size, interpolation_mode="bicubic", patch_size=None): mean = [0.48145466, 0.4578275, 0.40821073] std = [0.26862954, 0.26130258, 0.27577711] patch_size = patch_size if patch_size is not None else size - self.transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean, std), - transforms.Resize( - (size, size), - interpolation=InterpolationMode.BICUBIC if interpolation_mode - == "bicubic" else InterpolationMode.BILINEAR, - antialias=True), - ]) - - self.patch_transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(mean, std), - transforms.Resize( - (patch_size, patch_size), - interpolation=InterpolationMode.BICUBIC if interpolation_mode - == "bicubic" else InterpolationMode.BILINEAR, - antialias=True), - ]) if patch_size is not None else None + self.transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean, std), + transforms.Resize( + (size, size), + interpolation=InterpolationMode.BICUBIC + if interpolation_mode == "bicubic" + else InterpolationMode.BILINEAR, + antialias=True, + ), + ] + ) + + self.patch_transform = ( + transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean, std), + transforms.Resize( + (patch_size, patch_size), + interpolation=InterpolationMode.BICUBIC + if interpolation_mode == "bicubic" + else InterpolationMode.BILINEAR, + antialias=True, + ), + ] + ) + if patch_size is not None + else None + ) def __call__(self, image, is_patch=False): if is_patch: @@ -97,7 +121,6 @@ def __call__(self, image, is_patch=False): class ImagePatcher: - def determine_window_size(self, long: int, short: int) -> int: if long <= 728: return short if long / short > 1.5 else 0 @@ -118,14 +141,12 @@ def slide_window( size_w, size_h = size step_w, step_h = step - x_num = 1 if width <= size_w else ceil((width - size_w) / step_w + - 1) + x_num = 1 if width <= size_w else ceil((width - size_w) / step_w + 1) x_start = [step_w * i for i in range(x_num)] if len(x_start) > 1 and x_start[-1] + size_w > width: x_start[-1] = width - size_w - y_num = 1 if height <= size_h else ceil((height - size_h) / - step_h + 1) + y_num = 1 if height <= size_h else ceil((height - size_h) / step_h + 1) y_start = [step_h * i for i in range(y_num)] if len(y_start) > 1 and y_start[-1] + size_h > height: y_start[-1] = height - size_h @@ -135,8 +156,10 @@ def slide_window( windows.append(np.concatenate([start, start + size], axis=1)) windows = np.concatenate(windows, axis=0) - return [(int(box[0]), int(box[1]), int(box[2] - box[0]), - int(box[3] - box[1])) for box in windows], (x_num, y_num) + return [ + (int(box[0]), int(box[1]), int(box[2] - box[0]), int(box[3] - box[1])) + for box in windows + ], (x_num, y_num) def square_pad(self, img: Image.Image) -> Image.Image: w, h = img.size @@ -147,25 +170,27 @@ def square_pad(self, img: Image.Image) -> Image.Image: padded.paste(img, (0, 0)) return padded - def get_image_size_for_padding(self, img_width: int, - img_height: int) -> tuple[int, int]: + def get_image_size_for_padding( + self, img_width: int, img_height: 
int + ) -> tuple[int, int]: ratio = img_width / img_height if min(img_height, img_width) < 32 and (ratio > 4 or ratio < 1 / 4): new_size = max(img_height, img_width) return new_size, new_size return img_width, img_height - def get_image_size_for_preprocess(self, img_width: int, - img_height: int) -> tuple[int, int]: - + def get_image_size_for_preprocess( + self, img_width: int, img_height: int + ) -> tuple[int, int]: if max(img_height, img_width) > MAX_IMAGE_SIZE: scale_factor = MAX_IMAGE_SIZE / max(img_height, img_width) img_width = int(img_width * scale_factor) img_height = int(img_height * scale_factor) return img_width, img_height - def get_image_size_for_crop(self, img_width: int, img_height: int, - window_size: int): + def get_image_size_for_crop( + self, img_width: int, img_height: int, window_size: int + ): w_ratio = img_width / window_size h_ratio = img_height / window_size @@ -187,22 +212,26 @@ def patch_crop(self, img: Image.Image, i: int, j: int, th: int, tw: int): target = img.crop((j, i, j + tw, i + th)) return target - def get_num_patches(self, img_width: int, - img_height: int) -> tuple[int, int]: - img_width, img_height = self.get_image_size_for_padding( - img_width, img_height) + def get_num_patches(self, img_width: int, img_height: int) -> tuple[int, int]: + img_width, img_height = self.get_image_size_for_padding(img_width, img_height) img_width, img_height = self.get_image_size_for_preprocess( - img_width, img_height) - window_size = self.determine_window_size(max(img_height, img_width), - min(img_height, img_width)) + img_width, img_height + ) + window_size = self.determine_window_size( + max(img_height, img_width), min(img_height, img_width) + ) if window_size == 0: return 0, 0 else: img_width, img_height = self.get_image_size_for_crop( - img_width, img_height, window_size) + img_width, img_height, window_size + ) center_list, (x_num, y_num) = self.slide_window( - img_width, img_height, [(window_size, window_size)], - [(window_size, window_size)]) + img_width, + img_height, + [(window_size, window_size)], + [(window_size, window_size)], + ) full_rows = (len(center_list) - 1) // x_num + 1 if len(center_list) > 0 and len(center_list) % x_num == 0: full_rows -= 1 @@ -213,39 +242,44 @@ def __call__( ) -> tuple[Image.Image, list[Image.Image], list[bool] | None]: img_width, img_height = img.size new_img_width, new_img_height = self.get_image_size_for_padding( - img_width, img_height) + img_width, img_height + ) if new_img_width != img_width or new_img_height != img_height: img = self.square_pad(img) img_width, img_height = img.size new_img_width, new_img_height = self.get_image_size_for_preprocess( - img_width, img_height) - img = img.resize((new_img_width, new_img_height), - Image.Resampling.BILINEAR) + img_width, img_height + ) + img = img.resize((new_img_width, new_img_height), Image.Resampling.BILINEAR) window_size = self.determine_window_size( - max(new_img_height, new_img_width), - min(new_img_height, new_img_width)) + max(new_img_height, new_img_width), min(new_img_height, new_img_width) + ) if window_size == 0: return img, [], None else: new_img_width, new_img_height = self.get_image_size_for_crop( - new_img_width, new_img_height, window_size) + new_img_width, new_img_height, window_size + ) if (new_img_width, new_img_height) != (img_width, img_height): - img_for_crop = img.resize((new_img_width, new_img_height), - Image.Resampling.BILINEAR) + img_for_crop = img.resize( + (new_img_width, new_img_height), Image.Resampling.BILINEAR + ) else: img_for_crop = img patches 
= [] newlines = [] center_list, (x_num, y_num) = self.slide_window( - new_img_width, new_img_height, [(window_size, window_size)], - [(window_size, window_size)]) + new_img_width, + new_img_height, + [(window_size, window_size)], + [(window_size, window_size)], + ) for patch_id, center_lf_point in enumerate(center_list): x, y, patch_w, patch_h = center_lf_point - big_patch = self.patch_crop(img_for_crop, y, x, patch_h, - patch_w) + big_patch = self.patch_crop(img_for_crop, y, x, patch_h, patch_w) patches.append(big_patch) if (patch_id + 1) % x_num == 0: newlines.append(patch_id) @@ -253,12 +287,16 @@ def __call__( if newlines and newlines[-1] == len(patches) - 1: newlines.pop() - return img, patches, [i in newlines for i in range(len(patches)) - ] if len(patches) > 0 else None + return ( + img, + patches, + [i in newlines for i in range(len(patches))] + if len(patches) > 0 + else None, + ) class Step3VLProcessor: - def __init__( self, config: PretrainedConfig, @@ -271,17 +309,15 @@ def __init__( self.image_size = 728 self.patch_size = 504 - self.image_preprocessor = Step3VisionProcessor(self.image_size, - "bilinear", - self.patch_size) + self.image_preprocessor = Step3VisionProcessor( + self.image_size, "bilinear", self.patch_size + ) self.num_image_feature_size = 169 self.num_patch_feature_size = 81 self.image_token = "<im_patch>" - self.image_feature_placeholder = (self.image_token * - self.num_image_feature_size) - self.patch_feature_placeholder = (self.image_token * - self.num_patch_feature_size) + self.image_feature_placeholder = self.image_token * self.num_image_feature_size + self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size self.patcher = ImagePatcher() @@ -290,15 +326,16 @@ def image_token_id(self) -> int: return self.tokenizer.get_vocab()[self.image_token] def get_num_image_tokens(self, img_width: int, img_height: int) -> int: - num_patches, num_newlines = self.patcher.get_num_patches( - img_width, img_height) + num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height) - return num_patches * ( - self.num_patch_feature_size + - 2) + self.num_image_feature_size + 2 + num_newlines + return ( + num_patches * (self.num_patch_feature_size + 2) + + self.num_image_feature_size + + 2 + + num_newlines + ) - def _split_images(self, - images: list[Image.Image]) -> list[ImageWithPatches]: + def _split_images(self, images: list[Image.Image]) -> list[ImageWithPatches]: result = [] for img in images: result.append(self.patcher(img)) @@ -325,13 +362,15 @@ def _get_patch_repl( assert len(patch_newline_mask) == num_patches text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>" token_ids.extend( - [self.tokenizer.convert_tokens_to_ids("<patch_start>")] + - [self.image_token_id] * self.num_patch_feature_size + - [self.tokenizer.convert_tokens_to_ids("<patch_end>")]) + [self.tokenizer.convert_tokens_to_ids("<patch_start>")] + + [self.image_token_id] * self.num_patch_feature_size + + [self.tokenizer.convert_tokens_to_ids("<patch_end>")] + ) if patch_newline_mask and patch_newline_mask[i]: text += "<patch_newline>" token_ids.append( - self.tokenizer.convert_tokens_to_ids("<patch_newline>")) + self.tokenizer.convert_tokens_to_ids("<patch_newline>") + ) return text, token_ids def _get_image_repl( @@ -339,11 +378,11 @@ def _get_image_repl( num_images: int, ) -> tuple[str, list[int]]: text = f"<im_start>{self.image_feature_placeholder}<im_end>" - token_ids = [ - self.tokenizer.convert_tokens_to_ids("<im_start>") - ] + [self.image_token_id] * 
self.num_image_feature_size + [ - self.tokenizer.convert_tokens_to_ids("<im_end>") - ] + token_ids = ( + [self.tokenizer.convert_tokens_to_ids("<im_start>")] + + [self.image_token_id] * self.num_image_feature_size + + [self.tokenizer.convert_tokens_to_ids("<im_end>")] + ) return text * num_images, token_ids * num_images def _get_image_repl_features( @@ -354,15 +393,15 @@ def _get_image_repl_features( ) -> tuple[str, list[int]]: if num_patches > 0: patch_repl, patch_repl_ids = self._get_patch_repl( - num_patches, patch_new_line_idx) + num_patches, patch_new_line_idx + ) else: patch_repl = "" patch_repl_ids = [] image_repl, image_repl_ids = self._get_image_repl(num_images) return patch_repl + image_repl, patch_repl_ids + image_repl_ids - def replace_placeholder(self, text: str, placeholder: str, - repls: list[str]) -> str: + def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str: parts = text.split(placeholder) if len(parts) - 1 != len(repls): @@ -404,17 +443,17 @@ def __call__( image_repl_ids_lst = [] num_patches = [] for raw_img, img_patches, patch_newline_mask in splitted_images_data: # noqa: E501 - pixel_values_lst.extend( - self._convert_images_to_pixel_values([raw_img])) + pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img])) if len(img_patches) > 0: patch_pixel_values_lst.extend( - self._convert_images_to_pixel_values(img_patches, - is_patch=True)) + self._convert_images_to_pixel_values(img_patches, is_patch=True) + ) num_patches.append(len(img_patches)) image_repl_str, image_repl_ids = self._get_image_repl_features( - 1, len(img_patches), patch_newline_mask) + 1, len(img_patches), patch_newline_mask + ) image_repl_str_lst.append(image_repl_str) image_repl_ids_lst.extend(image_repl_ids) @@ -426,15 +465,15 @@ def __call__( "num_patches": num_patches, } if patch_pixel_values_lst: - image_inputs["patch_pixel_values"] = torch.cat( - patch_pixel_values_lst) + image_inputs["patch_pixel_values"] = torch.cat(patch_pixel_values_lst) if patch_newline_mask_lst: image_inputs["patch_newline_mask"] = torch.tensor( - patch_newline_mask_lst, dtype=torch.bool) + patch_newline_mask_lst, dtype=torch.bool + ) text = [ - self.replace_placeholder(t, self.image_token, - image_repl_str_lst) for t in text + self.replace_placeholder(t, self.image_token, image_repl_str_lst) + for t in text ] text_inputs = self.tokenizer(text) @@ -448,7 +487,6 @@ def __call__( class Step3VLProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self) -> Step3VLProcessor: return Step3VLProcessor( self.get_hf_config(), @@ -462,7 +500,8 @@ def get_max_image_tokens(self) -> int: hf_processor = self.get_hf_processor() return hf_processor.get_num_image_tokens( self.get_image_size_with_most_features().width, - self.get_image_size_with_most_features().height) + self.get_image_size_with_most_features().height, + ) def get_mm_max_tokens_per_item( self, @@ -476,19 +515,19 @@ def get_image_size_with_most_features(self) -> ImageSize: def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int: if len(mm_data) != 1 or "image" not in mm_data: - raise ValueError( - "mm_data could only contain one key 'image' for steo1o") + raise ValueError("mm_data could only contain one key 'image' for steo1o") image_data = mm_data["image"] if not isinstance(image_data, (list, tuple)): image_data = [image_data] - return sum(self.get_hf_processor().get_num_image_tokens( - img.width, img.height) for img in image_data) + return sum( + self.get_hf_processor().get_num_image_tokens(img.width, img.height) + for 
img in image_data + ) class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) return "<im_patch>" * num_images @@ -499,24 +538,22 @@ def get_dummy_mm_data( mm_counts: Mapping[str, int], mm_options: Optional[Mapping[str, BaseDummyOptions]] = None, ) -> MultiModalDataDict: - target_width, target_height = \ - self.info.get_image_size_with_most_features() + target_width, target_height = self.info.get_image_size_with_most_features() num_images = mm_counts.get("image", 0) image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides) + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) } -class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo] - ): - +class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo]): def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -532,10 +569,10 @@ def get_replacement_step1o(item_idx: int): if num_patches > 0: patch_newline_mask = out_item["patch_newline_mask"].data image_repl_ids = hf_processor._get_image_repl_features( - 1, num_patches, patch_newline_mask.tolist())[1] + 1, num_patches, patch_newline_mask.tolist() + )[1] else: - image_repl_ids = hf_processor._get_image_repl_features( - 1, 0, None)[1] + image_repl_ids = hf_processor._get_image_repl_features(1, 0, None)[1] return PromptUpdateDetails.select_token_id( seq=image_repl_ids, embed_token_id=image_placeholder_token_id, @@ -559,10 +596,12 @@ def _get_mm_fields_config( return dict( pixel_values=MultiModalFieldConfig.batched("image"), patch_pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + "image", num_patches + ), num_patches=MultiModalFieldConfig.batched("image"), patch_newline_mask=MultiModalFieldConfig.flat_from_sizes( - "image", num_patches), + "image", num_patches + ), ) @@ -576,29 +615,29 @@ def get_abs_pos(abs_pos, tgt_size): dtype = abs_pos.dtype if src_size != tgt_size: - old_pos_embed = old_pos_embed.view(1, src_size, src_size, - dim).permute(0, 3, 1, - 2).contiguous() + old_pos_embed = ( + old_pos_embed.view(1, src_size, src_size, dim) + .permute(0, 3, 1, 2) + .contiguous() + ) old_pos_embed = old_pos_embed.to(torch.float32) new_pos_embed = F.interpolate( old_pos_embed, size=(tgt_size, tgt_size), - mode='bicubic', + mode="bicubic", antialias=True, align_corners=False, ).to(dtype) new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim) vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0) - vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, - dim) + vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim) return vision_pos_embed else: return abs_pos class Step3VisionEmbeddings(nn.Module): - def __init__(self, config: Step3VisionEncoderConfig): super().__init__() self.config = config @@ -616,43 +655,51 @@ def __init__(self, config: Step3VisionEncoderConfig): bias=True, ) - self.num_patches = (self.image_size // self.patch_size)**2 + self.num_patches = (self.image_size // self.patch_size) ** 2 self.pad_tp_size = 4 # hard code for padding # To load the pretrained weights, we still use P+1 as the seqlen - self.position_embedding = 
torch.nn.Embedding(self.num_patches + 1, - self.embed_dim) - self.register_buffer("position_ids", - torch.arange(self.num_patches + 1).expand( - (1, -1)), - persistent=False) + self.position_embedding = torch.nn.Embedding( + self.num_patches + 1, self.embed_dim + ) + self.register_buffer( + "position_ids", + torch.arange(self.num_patches + 1).expand((1, -1)), + persistent=False, + ) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: batch_size = pixel_values.shape[0] patch_embeds = self.patch_embedding( - pixel_values) # shape = [*, width, grid, grid] + pixel_values + ) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) # pad class_embeds = self.class_embedding.expand(batch_size, 1, -1) embeddings = torch.cat([class_embeds, patch_embeds], dim=1) embeddings = embeddings + get_abs_pos( - self.position_embedding(self.position_ids), patch_embeds.size(1)) - embeddings = torch.cat([ - embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, - 1), embeddings - ], - dim=1) + self.position_embedding(self.position_ids), patch_embeds.size(1) + ) + embeddings = torch.cat( + [ + embeddings[:, 0, :].unsqueeze(1).repeat(1, self.pad_tp_size - 1, 1), + embeddings, + ], + dim=1, + ) return embeddings class Step3VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -661,8 +708,7 @@ def __init__(self, self.scale = self.head_dim**-0.5 - tp_size = (1 if use_data_parallel else - get_tensor_model_parallel_world_size()) + tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size @@ -677,16 +723,17 @@ def __init__(self, prefix=f"{prefix}.qkv_proj", disable_tp=use_data_parallel, ) - self.out_proj = RowParallelLinear(self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - disable_tp=use_data_parallel) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel, + ) # Use unified MultiHeadAttention with automatic backend selection - self.attn = MultiHeadAttention(self.num_heads, self.head_dim, - self.scale) + self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) def forward( self, @@ -708,27 +755,32 @@ def forward( class Step3VisionMLP(nn.Module): - - def __init__(self, - config, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear(config.hidden_size, - config.intermediate_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc1", - disable_tp=use_data_parallel) - self.fc2 = RowParallelLinear(config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.fc2", - 
disable_tp=use_data_parallel) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.fc1(hidden_states) @@ -738,12 +790,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Step3VisionEncoderLayer(nn.Module): - - def __init__(self, - config: Step3VisionEncoderConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.use_data_parallel = use_data_parallel self.embed_dim = config.hidden_size @@ -751,44 +804,48 @@ def __init__(self, config, quant_config, prefix=f"{prefix}.self_attn", - use_data_parallel=self.use_data_parallel) - self.layer_norm1 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) - self.mlp = Step3VisionMLP(config, - quant_config, - prefix=f"{prefix}.mlp", - use_data_parallel=self.use_data_parallel) - self.layer_norm2 = nn.LayerNorm(self.embed_dim, - eps=config.layer_norm_eps) + use_data_parallel=self.use_data_parallel, + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Step3VisionMLP( + config, + quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=self.use_data_parallel, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, ) -> torch.FloatTensor: - hidden_states = hidden_states + self.layer_norm1( - self.self_attn(hidden_states)) - hidden_states = hidden_states + self.layer_norm2( - self.mlp(hidden_states)) + hidden_states = hidden_states + self.layer_norm1(self.self_attn(hidden_states)) + hidden_states = hidden_states + self.layer_norm2(self.mlp(hidden_states)) return hidden_states class Step3VisionEncoder(nn.Module): - - def __init__(self, - config: Step3VisionEncoderConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): super().__init__() self.config = config self.use_data_parallel = use_data_parallel - self.layers = nn.ModuleList([ - Step3VisionEncoderLayer(config, - quant_config, - prefix=f"{prefix}.layers.{i}", - use_data_parallel=self.use_data_parallel) - for i in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + Step3VisionEncoderLayer( + config, + quant_config, + prefix=f"{prefix}.layers.{i}", + use_data_parallel=self.use_data_parallel, + ) + for i in range(config.num_hidden_layers) + ] + ) def forward( self, @@ -801,12 +858,13 @@ def forward( class Step3VisionTransformer(nn.Module): - - def __init__(self, - config: Step3VisionEncoderConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - use_data_parallel: bool = False): + def __init__( + self, + config: Step3VisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = 
False, + ): super().__init__() self.config = config self.use_data_parallel = use_data_parallel @@ -816,7 +874,8 @@ def __init__(self, config, quant_config, prefix=f"{prefix}.transformer", - use_data_parallel=self.use_data_parallel) + use_data_parallel=self.use_data_parallel, + ) def forward( self, @@ -824,23 +883,24 @@ def forward( ): hidden_states = self.embeddings(pixel_values) if self.use_data_parallel: - hidden_states = run_dp_sharded_vision_model( - hidden_states, self.transformer) + hidden_states = run_dp_sharded_vision_model(hidden_states, self.transformer) else: hidden_states = self.transformer(inputs_embeds=hidden_states) return hidden_states -@MULTIMODAL_REGISTRY.register_processor(Step3VLMultiModalProcessor, - info=Step3VLProcessingInfo, - dummy_inputs=Step3VLDummyInputsBuilder) -class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): - - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "model.": "language_model.model.", - "lm_head.": "language_model.lm_head.", - }) +@MULTIMODAL_REGISTRY.register_processor( + Step3VLMultiModalProcessor, + info=Step3VLProcessingInfo, + dummy_inputs=Step3VLDummyInputsBuilder, +) +class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "language_model.model.", + "lm_head.": "language_model.lm_head.", + } + ) supports_encoder_tp_data = True @@ -866,12 +926,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config.vision_config, None, prefix=maybe_prefix(prefix, "vision_model"), - use_data_parallel=self.use_data_parallel) + use_data_parallel=self.use_data_parallel, + ) self.vit_downsampler = nn.Conv2d( config.vision_config.hidden_size, config.vision_config.output_hidden_size, kernel_size=2, - stride=config.understand_projector_stride) + stride=config.understand_projector_stride, + ) self.vit_downsampler2 = nn.Conv2d( config.vision_config.output_hidden_size, config.vision_config.output_hidden_size * 2, @@ -893,10 +955,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.language_model = init_vllm_registered_model( vllm_config=vllm_config, hf_config=config.text_config, - prefix=maybe_prefix(prefix, "language_model")) + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) @property def device(self): @@ -907,7 +971,8 @@ def dtype(self): return next(self.parameters()).dtype def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[Step3VLImageInputs]: + self, **kwargs: object + ) -> Optional[Step3VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) patch_pixel_values = kwargs.pop("patch_pixel_values", None) num_patches = kwargs.pop("num_patches", None) @@ -921,10 +986,10 @@ def _parse_and_validate_image_input( if pixel_values.dim() >= 3: pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:]) if patch_pixel_values is not None: - patch_pixel_values = flatten_bn(patch_pixel_values, - concat=True) + patch_pixel_values = flatten_bn(patch_pixel_values, concat=True) patch_pixel_values = patch_pixel_values.view( - -1, *patch_pixel_values.shape[-3:]) + -1, *patch_pixel_values.shape[-3:] + ) # Handle empty patch_pixel_values by setting to None if patch_pixel_values.shape[0] == 0: patch_pixel_values = None @@ -933,8 +998,9 @@ def _parse_and_validate_image_input( return 
Step3VLImagePixelInputs( type="pixel_values", pixel_values=pixel_values.to(self.dtype).to(self.device), - patch_pixel_values=patch_pixel_values.to(self.dtype).to( - self.device) if patch_pixel_values is not None else None, + patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device) + if patch_pixel_values is not None + else None, num_patches=num_patches, ) @@ -943,7 +1009,8 @@ def _parse_and_validate_image_input( image_embeds = image_embeds.view(-1, image_embeds.shape[-1]) else: raise ValueError( - f"Unexpected shape for image_embeds: {image_embeds.shape}") + f"Unexpected shape for image_embeds: {image_embeds.shape}" + ) return Step3VLImageEmbeddingInputs( type="image_embeds", @@ -951,8 +1018,7 @@ def _parse_and_validate_image_input( ) return None - def _process_image_features(self, - image_features: torch.Tensor) -> torch.Tensor: + def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor: B, P = image_features.shape[:2] HW = int(sqrt(P)) image_features = image_features.permute(0, 2, 1).view(B, -1, HW, HW) @@ -963,26 +1029,29 @@ def _process_image_features(self, image_features = self.vit_large_projector(image_features) return image_features - def _get_vision_model_output(self, - input_tensor: torch.Tensor) -> torch.Tensor: + def _get_vision_model_output(self, input_tensor: torch.Tensor) -> torch.Tensor: return self.vision_model(input_tensor)[:, 4:] def _process_image_input( - self, image_input: Step3VLImageInputs) -> tuple[torch.Tensor, ...]: - + self, image_input: Step3VLImageInputs + ) -> tuple[torch.Tensor, ...]: if image_input["type"] == "image_embeds": image_features = image_input["image_embeds"] else: - image_features = self._get_vision_model_output( - image_input["pixel_values"]) - patch_image_features = self._get_vision_model_output( - image_input["patch_pixel_values"] - ) if image_input["patch_pixel_values"] is not None else None + image_features = self._get_vision_model_output(image_input["pixel_values"]) + patch_image_features = ( + self._get_vision_model_output(image_input["patch_pixel_values"]) + if image_input["patch_pixel_values"] is not None + else None + ) num_patches = image_input["num_patches"] image_features = self._process_image_features(image_features) - patch_image_features = self._process_image_features( - patch_image_features) if patch_image_features is not None else None + patch_image_features = ( + self._process_image_features(patch_image_features) + if patch_image_features is not None + else None + ) merged_image_features = [] cur_patch_idx = 0 @@ -990,14 +1059,14 @@ def _process_image_input( cur_feature = [] if num_patch > 0: patch_slice = patch_image_features[ - cur_patch_idx:cur_patch_idx + num_patch] + cur_patch_idx : cur_patch_idx + num_patch + ] cur_feature.append(patch_slice.view(-1, patch_slice.shape[-1])) - cur_feature.append(image_features[i].view( - -1, image_features.shape[-1])) + cur_feature.append(image_features[i].view(-1, image_features.shape[-1])) cur_patch_idx += num_patch merged_image_features.append( - torch.cat(cur_feature) if len(cur_feature) > - 1 else cur_feature[0]) + torch.cat(cur_feature) if len(cur_feature) > 1 else cur_feature[0] + ) return merged_image_features def get_language_model(self) -> torch.nn.Module: @@ -1049,10 +1118,9 @@ def forward( ) input_ids = None - hidden_states = self.language_model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return 
hidden_states @@ -1063,15 +1131,15 @@ def compute_logits( return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - skip_prefixes = [] if self.vision_model is None and self.vit_large_projector is None: skip_prefixes = [ - "vision_model.", "vit_downsampler.", "vit_downsampler2.", - "vit_large_projector." + "vision_model.", + "vit_downsampler.", + "vit_downsampler2.", + "vit_large_projector.", ] loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) - loaded_weights = loader.load_weights(weights, - mapper=self.hf_to_vllm_mapper) + loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loaded_weights diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py index 30b441f5b4df..485c008e830a 100644 --- a/vllm/model_executor/models/swin.py +++ b/vllm/model_executor/models/swin.py @@ -7,21 +7,21 @@ import torch import torch.nn as nn from transformers import SwinConfig -from transformers.models.swin.modeling_swin import SwinEmbeddings +from transformers.models.swin.modeling_swin import SwinEmbeddings, SwinPatchMerging from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer -from transformers.models.swin.modeling_swin import SwinPatchMerging from transformers.pytorch_utils import meshgrid from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader class SwinSelfAttention(nn.Module): - def __init__( self, config: SwinConfig, @@ -35,35 +35,40 @@ def __init__( if dim % num_heads != 0: raise ValueError( f"The hidden size ({dim}) is not a multiple of the number of " - f"attention heads ({num_heads})") + f"attention heads ({num_heads})" + ) self.num_attention_heads = num_heads self.attention_head_size = int(dim / num_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.window_size = (window_size if isinstance(window_size, Iterable) - else (window_size, window_size)) + self.window_size = ( + window_size + if isinstance(window_size, Iterable) + else (window_size, window_size) + ) self.scale = self.attention_head_size**-0.5 self.relative_position_bias_table = nn.Parameter( torch.zeros( - (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), - num_heads)) + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads + ) + ) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, - None, :] + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] += self.window_size[0] - 1 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) - self.relative_position_index = nn.Parameter(relative_position_index, - requires_grad=False) + 
self.relative_position_index = nn.Parameter( + relative_position_index, requires_grad=False + ) self.qkv = QKVParallelLinear( hidden_size=dim, @@ -75,19 +80,23 @@ def __init__( ) def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def _get_rel_pos_bias(self) -> torch.Tensor: relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1)] + self.relative_position_index.view(-1) + ] relative_position_bias = relative_position_bias.view( self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1) - relative_position_bias = relative_position_bias.permute( - 2, 0, 1).contiguous() + self.window_size[0] * self.window_size[1], + -1, + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() return relative_position_bias.unsqueeze(0) def forward( @@ -110,38 +119,38 @@ def forward( if attention_mask is not None: mask_shape = attention_mask.shape[0] attention_mask_expanded = attention_mask.view( - 1, mask_shape, 1, dim, - dim).expand(batch_size // mask_shape, mask_shape, - self.num_attention_heads, dim, dim) - attention_scores = attention_scores + \ - attention_mask_expanded.unsqueeze( - 1).unsqueeze(0) - attention_scores = attention_scores.view(-1, - self.num_attention_heads, - dim, dim) + 1, mask_shape, 1, dim, dim + ).expand( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask_expanded.unsqueeze( + 1 + ).unsqueeze(0) + attention_scores = attention_scores.view( + -1, self.num_attention_heads, dim, dim + ) context_layer = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, attn_mask=attention_scores, - dropout_p=0., + dropout_p=0.0, ) attention_probs = None context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + ( - self.all_head_size, ) + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, - attention_probs) if output_attentions else (context_layer, ) + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) return outputs class SwinSelfOutput(nn.Module): - def __init__( self, config: SwinConfig, @@ -157,33 +166,36 @@ def __init__( prefix=f"{prefix}.dense", ) - def forward(self, hidden_states: torch.Tensor, - input_tensor: torch.Tensor) -> torch.Tensor: + def forward( + self, hidden_states: torch.Tensor, input_tensor: torch.Tensor + ) -> torch.Tensor: hidden_states, _ = self.dense(hidden_states) return hidden_states class SwinAttention(nn.Module): - - def __init__(self, - config: SwinConfig, - dim: int, - num_heads: int, - window_size: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() - self.self = SwinSelfAttention(config, - dim, - num_heads, - window_size, - quant_config=quant_config, - prefix=f"{prefix}.self") - self.output = SwinSelfOutput(config, - dim, - quant_config=quant_config, - prefix=f"{prefix}.output") + self.self = 
SwinSelfAttention( + config, + dim, + num_heads, + window_size, + quant_config=quant_config, + prefix=f"{prefix}.self", + ) + self.output = SwinSelfOutput( + config, dim, quant_config=quant_config, prefix=f"{prefix}.output" + ) self.pruned_heads = set() def forward( @@ -193,25 +205,29 @@ def forward( head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: - self_outputs = self.self(hidden_states, attention_mask, head_mask, - output_attentions) + self_outputs = self.self( + hidden_states, attention_mask, head_mask, output_attentions + ) attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output, ) + self_outputs[1:] + outputs = (attention_output,) + self_outputs[1:] return outputs class SwinIntermediate(nn.Module): - - def __init__(self, - config: SwinConfig, - dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() - self.dense = ColumnParallelLinear(dim, - int(config.mlp_ratio * dim), - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.dense = ColumnParallelLinear( + dim, + int(config.mlp_ratio * dim), + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) self.intermediate_act_fn = get_act_fn(config.hidden_act) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -221,17 +237,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class SwinOutput(nn.Module): - - def __init__(self, - config: SwinConfig, - dim: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: super().__init__() - self.dense = RowParallelLinear(int(config.mlp_ratio * dim), - dim, - quant_config=quant_config, - prefix=f"{prefix}.dense") + self.dense = RowParallelLinear( + int(config.mlp_ratio * dim), + dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.dense(hidden_states) @@ -239,7 +258,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class SwinLayer(HFSwinLayer): - def __init__( self, config: SwinConfig, @@ -260,24 +278,23 @@ def __init__( shift_size=shift_size, ) - self.attention = SwinAttention(config, - dim, - num_heads, - window_size=self.window_size, - quant_config=quant_config, - prefix=f"{prefix}.attention") - self.intermediate = SwinIntermediate(config, - dim, - quant_config=quant_config, - prefix=f"{prefix}.intermediate") - self.output = SwinOutput(config, - dim, - quant_config=quant_config, - prefix=f"{prefix}.output") + self.attention = SwinAttention( + config, + dim, + num_heads, + window_size=self.window_size, + quant_config=quant_config, + prefix=f"{prefix}.attention", + ) + self.intermediate = SwinIntermediate( + config, dim, quant_config=quant_config, prefix=f"{prefix}.intermediate" + ) + self.output = SwinOutput( + config, dim, quant_config=quant_config, prefix=f"{prefix}.output" + ) class SwinStage(nn.Module): - def __init__( self, config: SwinConfig, @@ -293,24 +310,27 @@ def __init__( super().__init__() self.config = config self.dim = dim - self.blocks = nn.ModuleList([ - SwinLayer(config=config, - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - 
drop_path_rate=drop_path[layer_idx], - shift_size=0 if - (layer_idx % 2 == 0) else config.window_size // 2, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(depth) - ]) + self.blocks = nn.ModuleList( + [ + SwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[layer_idx], + shift_size=0 if (layer_idx % 2 == 0) else config.window_size // 2, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + ) + for layer_idx in range(depth) + ] + ) # patch merging layer if downsample is not None: - self.downsample = downsample(input_resolution, - dim=dim, - norm_layer=nn.LayerNorm) + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=nn.LayerNorm + ) else: self.downsample = None @@ -328,25 +348,31 @@ def forward( for i, layer_module in enumerate(self.blocks): layer_head_mask = head_mask[i] if head_mask is not None else None - layer_outputs = layer_module(hidden_states, input_dimensions, - layer_head_mask, output_attentions, - always_partition) + layer_outputs = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + always_partition, + ) hidden_states = layer_outputs[0] hidden_states_before_downsampling = hidden_states if self.downsample is not None: - height_downsampled, width_downsampled = (height + 1) // 2, (width + - 1) // 2 - output_dimensions = (height, width, height_downsampled, - width_downsampled) - hidden_states = self.downsample(hidden_states_before_downsampling, - input_dimensions) + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample( + hidden_states_before_downsampling, input_dimensions + ) else: output_dimensions = (height, width, height, width) - stage_outputs = (hidden_states, hidden_states_before_downsampling, - output_dimensions) + stage_outputs = ( + hidden_states, + hidden_states_before_downsampling, + output_dimensions, + ) if output_attentions: stage_outputs += layer_outputs[1:] @@ -354,7 +380,6 @@ def forward( class SwinEncoder(nn.Module): - def __init__( self, config: SwinConfig, @@ -366,24 +391,36 @@ def __init__( self.num_layers = len(config.depths) self.config = config dpr = [ - x.item() for x in torch.linspace( - 0, config.drop_path_rate, sum(config.depths), device="cpu") + x.item() + for x in torch.linspace( + 0, config.drop_path_rate, sum(config.depths), device="cpu" + ) ] - self.layers = nn.ModuleList([ - SwinStage(config=config, - dim=int(config.embed_dim * 2**layer_idx), - input_resolution=(grid_size[0] // (2**layer_idx), - grid_size[1] // (2**layer_idx)), - depth=config.depths[layer_idx], - num_heads=config.num_heads[layer_idx], - drop_path=dpr[sum(config.depths[:layer_idx] - ):sum(config.depths[:layer_idx + 1])], - downsample=SwinPatchMerging if - (layer_idx < self.num_layers - 1) else None, - quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") - for layer_idx in range(self.num_layers) - ]) + self.layers = nn.ModuleList( + [ + SwinStage( + config=config, + dim=int(config.embed_dim * 2**layer_idx), + input_resolution=( + grid_size[0] // (2**layer_idx), + grid_size[1] // (2**layer_idx), + ), + depth=config.depths[layer_idx], + num_heads=config.num_heads[layer_idx], + drop_path=dpr[ + sum(config.depths[:layer_idx]) : sum( + config.depths[: layer_idx + 1] + ) + ], + downsample=SwinPatchMerging + if (layer_idx < self.num_layers - 1) + 
else None, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(self.num_layers) + ] + ) def forward( self, @@ -396,9 +433,13 @@ def forward( for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - layer_outputs = layer_module(hidden_states, input_dimensions, - layer_head_mask, output_attentions, - always_partition) + layer_outputs = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + always_partition, + ) hidden_states = layer_outputs[0] output_dimensions = layer_outputs[2] @@ -420,13 +461,15 @@ def __init__( super().__init__() self.config = config self.num_layers = len(config.depths) - self.num_features = int(config.embed_dim * 2**(self.num_layers - 1)) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) self.embeddings = SwinEmbeddings(config) - self.encoder = SwinEncoder(config, - self.embeddings.patch_grid, - quant_config=quant_config, - prefix=f"{prefix}.encoder") + self.encoder = SwinEncoder( + config, + self.embeddings.patch_grid, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + ) def forward( self, @@ -445,8 +488,7 @@ def forward( return encoder_outputs - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ ("qkv", "query", "q"), ("qkv", "key", "k"), @@ -456,8 +498,7 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() for name, loaded_weight in weights: - - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -468,8 +509,7 @@ def load_weights(self, weights: Iterable[tuple[str, break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 8759c4ea4a64..482ad9cb7748 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -3,14 +3,17 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, - Union) +from typing import Annotated, Final, Literal, Optional, Protocol, TypeVar, Union import torch import torch.nn as nn -from transformers import BatchFeature, CLIPVisionConfig +from transformers import ( + BatchFeature, + CLIPVisionConfig, + PretrainedConfig, + SiglipVisionConfig, +) from transformers import LlavaConfig as HfLlavaConfig -from transformers import PretrainedConfig, SiglipVisionConfig from transformers.image_utils import ImageInput, get_image_size, to_numpy_array from transformers.models.llava import LlavaProcessor from transformers.processing_utils import ProcessingKwargs, Unpack @@ -18,19 +21,25 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig 
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems -from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, - ImageSize, MultiModalDataItems) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - InputProcessingContext, - PromptReplacement, PromptUpdate) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -38,10 +47,16 @@ from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel -from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, - maybe_prefix) -from .vision import (VisionEncoderInfo, get_num_selected_vision_tokens, - get_vision_encoder_info) +from .utils import ( + AutoWeightsLoader, + init_vllm_registered_model, + maybe_prefix, +) +from .vision import ( + VisionEncoderInfo, + get_num_selected_vision_tokens, + get_vision_encoder_info, +) class TarsierImagePixelInputs(TensorSchema): @@ -52,6 +67,7 @@ class TarsierImagePixelInputs(TensorSchema): - h: Height - w: Width """ + type: Literal["pixel_values"] = "pixel_values" pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] @@ -64,12 +80,12 @@ class TarsierImageEmbeddingInputs(TensorSchema): - hs: Hidden size (must match the hidden size of language model backbone) """ + type: Literal["image_embeds"] = "image_embeds" data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] -TarsierImageInputs = Union[TarsierImagePixelInputs, - TarsierImageEmbeddingInputs] +TarsierImageInputs = Union[TarsierImagePixelInputs, TarsierImageEmbeddingInputs] class TarsierHfConfig(Protocol): # Based on the Tarsier's LlavaConfig @@ -94,19 +110,18 @@ class TarsierProcessorKwargs(ProcessingKwargs, total=False): class TarsierProcessor(LlavaProcessor): - def __call__( self, images: ImageInput = None, - text: Union[TextInput, PreTokenizedInput, list[TextInput], - list[PreTokenizedInput]] = None, + text: Union[ + TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] + ] = None, audio=None, videos=None, **kwargs: Unpack[TarsierProcessorKwargs], ) -> BatchFeature: if images is None and text is None: - raise ValueError( - "You have to specify at least one of `images` or `text`.") + raise ValueError("You have to specify at least one of `images` or `text`.") output_kwargs = self._merge_kwargs( TarsierProcessorKwargs, @@ -115,15 +130,17 @@ def __call__( ) if images is not None: image_inputs = self.image_processor( - images, **output_kwargs["images_kwargs"]) + images, **output_kwargs["images_kwargs"] + ) else: image_inputs = {} if isinstance(text, str): text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string," - " or a list of strings") + raise ValueError( + "Invalid input text. 
Please provide a string, or a list of strings" + ) # try to expand inputs in processing if we have the necessary parts prompt_strings = text @@ -131,51 +148,55 @@ def __call__( # Replace the image token with the expanded image token sequence pixel_values = image_inputs["pixel_values"] height, width = get_image_size(to_numpy_array(pixel_values[0])) - num_image_tokens = (height // self.patch_size) * ( - width // self.patch_size + - 1) + self.num_additional_image_tokens + 1 + num_image_tokens = ( + (height // self.patch_size) * (width // self.patch_size + 1) + + self.num_additional_image_tokens + + 1 + ) if self.vision_feature_select_strategy == "default": num_image_tokens -= 1 prompt_strings = [] for sample in text: - sample = sample.replace(self.image_token, - self.image_token * num_image_tokens) + sample = sample.replace( + self.image_token, self.image_token * num_image_tokens + ) prompt_strings.append(sample) - return_tensors = output_kwargs["text_kwargs"].pop( - "return_tensors", None) - text_inputs = self.tokenizer(prompt_strings, - **output_kwargs["text_kwargs"]) - return BatchFeature(data={ - **text_inputs, - **image_inputs - }, - tensor_type=return_tensors) + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) + return BatchFeature( + data={**text_inputs, **image_inputs}, tensor_type=return_tensors + ) class TarsierMultiModalProjector(nn.Module): - - def __init__(self, - vision_hidden_size: int, - text_hidden_size: int, - projector_hidden_act: str, - multimodal_projector_bias: bool, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + multimodal_projector_bias: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): super().__init__() - self.linear_1 = ColumnParallelLinear(vision_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_1") + self.linear_1 = ColumnParallelLinear( + vision_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) self.act = get_act_fn(projector_hidden_act) - self.linear_2 = RowParallelLinear(text_hidden_size, - text_hidden_size, - bias=multimodal_projector_bias, - quant_config=quant_config, - prefix=f"{prefix}.linear_2") + self.linear_2 = RowParallelLinear( + text_hidden_size, + text_hidden_size, + bias=multimodal_projector_bias, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states, _ = self.linear_1(image_features) @@ -185,7 +206,6 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: class TarsierProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> TarsierHfConfig: return self.ctx.get_hf_config(HfLlavaConfig) @@ -227,12 +247,10 @@ def get_num_image_tokens( hf_config.vision_feature_select_strategy, ) if num_projected_patches_default <= 0: - raise ValueError( - "Could not determine a valid number of image patches.") + raise ValueError("Could not determine a valid number of image patches.") num_projected_patches = num_projected_patches_default num_height_patches = int(math.sqrt(num_projected_patches)) - total_image_tokens_for_llm = num_projected_patches \ - + num_height_patches + 1 + total_image_tokens_for_llm = num_projected_patches + 
num_height_patches + 1 return total_image_tokens_for_llm def get_image_size_with_most_features(self) -> ImageSize: @@ -258,12 +276,10 @@ def get_image_new_idx(self) -> int: class TarsierDummyInputsBuilder(LlavaDummyInputsBuilder[_I_Tarsier]): - pass class TarsierMultiModalProcessor(BaseMultiModalProcessor[_I_Tarsier]): - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -285,14 +301,14 @@ def _get_prompt_updates( def get_replacement(item_idx: int): images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems)) + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) if isinstance(images, ImageEmbeddingItems): num_projected_patches = images.get_feature_size(item_idx) # This assumes num_projected_patches is a perfect square num_height_patches = int(math.sqrt(num_projected_patches)) - num_final_image_tokens = num_projected_patches \ - + num_height_patches + 1 + num_final_image_tokens = num_projected_patches + num_height_patches + 1 else: image_size = images.get_image_size(item_idx) num_final_image_tokens = self.info.get_num_image_tokens( @@ -311,8 +327,7 @@ def get_replacement(item_idx: int): ] -def _build_tarsier_hf_info( - ctx: InputProcessingContext) -> TarsierProcessingInfo: +def _build_tarsier_hf_info(ctx: InputProcessingContext) -> TarsierProcessingInfo: return TarsierProcessingInfo(ctx) @@ -343,22 +358,23 @@ def init_vision_tower_for_tarsier( feature_layers = hf_config.vision_feature_layer base_num_hidden_layers = vision_config.num_hidden_layers - def _get_layer_index(feature_layer_index: int, - num_hidden_layers_total: int) -> int: + def _get_layer_index(feature_layer_index: int, num_hidden_layers_total: int) -> int: if feature_layer_index < 0: return num_hidden_layers_total + feature_layer_index + 1 return feature_layer_index if isinstance(feature_layers, int): - num_hidden_layers_to_init = _get_layer_index(feature_layers, - base_num_hidden_layers) + num_hidden_layers_to_init = _get_layer_index( + feature_layers, base_num_hidden_layers + ) elif isinstance(feature_layers, (list, tuple)): num_hidden_layers_to_init = max( - _get_layer_index(idx, base_num_hidden_layers) - for idx in feature_layers) + _get_layer_index(idx, base_num_hidden_layers) for idx in feature_layers + ) else: - raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" - " is not supported") + raise TypeError( + f"vision_layer_feature type: {type(feature_layers)} is not supported" + ) if isinstance(vision_config, CLIPVisionConfig): return CLIPVisionModel( @@ -381,14 +397,17 @@ def _get_layer_index(feature_layer_index: int, raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_processor(_build_tarsier_hf_processor, - info=_build_tarsier_hf_info, - dummy_inputs=TarsierDummyInputsBuilder) -class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP): +@MULTIMODAL_REGISTRY.register_processor( + _build_tarsier_hf_processor, + info=_build_tarsier_hf_info, + dummy_inputs=TarsierDummyInputsBuilder, +) +class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + merge_by_field_config = True + packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } @classmethod @@ -407,7 +426,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config, quant_config, require_post_norm=False, - prefix=maybe_prefix(prefix, "vision_tower")) + prefix=maybe_prefix(prefix, "vision_tower"), + ) projector_bias = getattr(config, 
"multimodal_projector_bias", True) self.multi_modal_projector = TarsierMultiModalProjector( @@ -416,27 +436,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: projector_hidden_act=config.projector_hidden_act, multimodal_projector_bias=projector_bias, quant_config=quant_config, - prefix=maybe_prefix(prefix, "multi_modal_projector")) + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, - hf_config=config. - text_config, # Use text_config from Tarsier's main config + hf_config=config.text_config, # Use text_config from Tarsier's main config prefix=maybe_prefix(prefix, "language_model"), ) - self.register_buffer('image_newline_idx_tensor', - torch.tensor([config.image_newline_idx], - dtype=torch.long), - persistent=False) - self.register_buffer('image_new_idx_tensor', - torch.tensor([config.image_new_idx], - dtype=torch.long), - persistent=False) + self.register_buffer( + "image_newline_idx_tensor", + torch.tensor([config.image_newline_idx], dtype=torch.long), + persistent=False, + ) + self.register_buffer( + "image_new_idx_tensor", + torch.tensor([config.image_new_idx], dtype=torch.long), + persistent=False, + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[TarsierImageInputs]: + self, **kwargs: object + ) -> Optional[TarsierImageInputs]: pixel_values = kwargs.pop("pixel_values", None) image_embeds = kwargs.pop("image_embeds", None) @@ -444,22 +468,15 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - return TarsierImagePixelInputs( type="pixel_values", - pixel_values=flatten_bn(pixel_values, concat=True), + pixel_values=pixel_values, ) if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") return TarsierImageEmbeddingInputs( type="image_embeds", - data=flatten_bn(image_embeds, concat=True), + data=image_embeds, ) raise AssertionError("This line should be unreachable.") @@ -476,23 +493,24 @@ def _image_pixels_to_features( ) def _add_tarsier_split_tokens( - self, projected_image_features: torch.Tensor) -> torch.Tensor: + self, projected_image_features: torch.Tensor + ) -> torch.Tensor: """ Implements Tarsier's `add_split_tokens` logic. 
""" - num_images, num_projected_patches, embed_dim = \ - projected_image_features.shape + num_images, num_projected_patches, embed_dim = projected_image_features.shape num_height_patches = int(math.sqrt(num_projected_patches)) num_width_patches = num_projected_patches // num_height_patches device = projected_image_features.device embedding_layer = self.language_model.model.embed_tokens image_newline_emb = embedding_layer( - self.image_newline_idx_tensor.to(device)).squeeze(0) - image_new_emb = embedding_layer( - self.image_new_idx_tensor.to(device)).squeeze(0) + self.image_newline_idx_tensor.to(device) + ).squeeze(0) + image_new_emb = embedding_layer(self.image_new_idx_tensor.to(device)).squeeze(0) try: current_image_features_grid = projected_image_features.view( - num_images, num_height_patches, num_width_patches, embed_dim) + num_images, num_height_patches, num_width_patches, embed_dim + ) except RuntimeError as e: raise RuntimeError( "Cannot reshape projected_image_features" @@ -502,22 +520,24 @@ def _add_tarsier_split_tokens( "Ensure num_projected_patches is compatible" " with a grid structure. " f"num_projected_patches={num_projected_patches}, " - f"derived num_height_patches={num_height_patches}. ") from e + f"derived num_height_patches={num_height_patches}. " + ) from e image_newline_expanded = image_newline_emb.expand( - (num_images, num_height_patches, 1, embed_dim)) + (num_images, num_height_patches, 1, embed_dim) + ) features_with_newlines = torch.cat( [current_image_features_grid, image_newline_expanded], - dim=2 # Concatenate along width dim + dim=2, # Concatenate along width dim ) - new_num_patches_after_newline = num_projected_patches \ - + num_height_patches + new_num_patches_after_newline = num_projected_patches + num_height_patches features_with_newlines_flat = features_with_newlines.view( - num_images, new_num_patches_after_newline, embed_dim) + num_images, new_num_patches_after_newline, embed_dim + ) image_new_expanded = image_new_emb.expand((num_images, 1, embed_dim)) final_image_features = torch.cat( [features_with_newlines_flat, image_new_expanded], - dim=1 # Concatenate along patch sequence dim + dim=1, # Concatenate along patch sequence dim ) return final_image_features @@ -528,16 +548,17 @@ def _process_image_pixels( assert self.vision_tower is not None pixel_values = inputs["pixel_values"] image_features_selected = self._image_pixels_to_features( - self.vision_tower, pixel_values) # type: ignore + self.vision_tower, pixel_values + ) # type: ignore if isinstance(image_features_selected, torch.Tensor): - projected_features = self.multi_modal_projector( - image_features_selected) + projected_features = self.multi_modal_projector(image_features_selected) final_features = self._add_tarsier_split_tokens(projected_features) return final_features else: raise TypeError( f"_image_pixels_to_features type:" - f" {type(image_features_selected)} is not supported") + f" {type(image_features_selected)} is not supported" + ) def _process_image_input( self, @@ -548,16 +569,17 @@ def _process_image_input( if isinstance(projected_features, torch.Tensor): return self._add_tarsier_split_tokens(projected_features) else: - raise ValueError("Incorrect type of image_embeds. " - f"Got type: {type(projected_features)}. ") + raise ValueError( + "Incorrect type of image_embeds. " + f"Got type: {type(projected_features)}. 
" + ) assert self.vision_tower is not None return self._process_image_pixels(image_input) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return [] @@ -585,7 +607,8 @@ def forward( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + inputs_embeds=inputs_embeds, + ) return hidden_states def compute_logits( @@ -594,7 +617,6 @@ def compute_logits( ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py index 49a7677151a9..113581d55ff5 100644 --- a/vllm/model_executor/models/telechat2.py +++ b/vllm/model_executor/models/telechat2.py @@ -30,12 +30,15 @@ from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel from .llama import LlamaDecoderLayer -from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - is_pp_missing_parameter) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + is_pp_missing_parameter, +) class TeleChat2Model(LlamaModel): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): hf_config = vllm_config.model_config.hf_config @@ -43,7 +46,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): "num_hidden_layers": "n_layer", "num_attention_heads": "n_head", "intermediate_size": "ffn_hidden_size", - "rms_norm_eps": "layer_norm_epsilon" + "rms_norm_eps": "layer_norm_epsilon", } vllm_config.model_config.hf_config.hidden_act = "silu" @@ -62,11 +65,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): layer.mlp.gate_up_proj.bias = None layer.mlp.gate_up_proj.skip_bias_add = True - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ - ('gate_up_proj', 'gate_proj', 0), - ('gate_up_proj', 'up_proj', 1), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -78,9 +80,10 @@ def load_weights(self, weights: Iterable[tuple[str, v_weight = [] for i in range(total_num_heads): start = i * head_dim * 2 - k_weight.append(loaded_weight[start:start + head_dim, :]) - v_weight.append(loaded_weight[start + head_dim:start + - 2 * head_dim:]) + k_weight.append(loaded_weight[start : start + head_dim, :]) + v_weight.append( + loaded_weight[start + head_dim : start + 2 * head_dim :] + ) k_weight = torch.cat(k_weight, dim=0) v_weight = torch.cat(v_weight, dim=0) name = name.replace("key_value", "qkv_proj") @@ -112,15 +115,15 @@ def load_weights(self, weights: Iterable[tuple[str, if is_pp_missing_parameter(name, self): continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return 
loaded_params class TeleChat2ForCausalLM(LlamaForCausalLM): - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ "transformer.": "model.", @@ -134,18 +137,17 @@ class TeleChat2ForCausalLM(LlamaForCausalLM): }, ) - def _init_model(self, - vllm_config: VllmConfig, - prefix: str = "", - layer_type: type[nn.Module] = LlamaDecoderLayer): + def _init_model( + self, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: type[nn.Module] = LlamaDecoderLayer, + ): return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/teleflm.py b/vllm/model_executor/models/teleflm.py index 3666f7011a99..4dfeddb0b28e 100644 --- a/vllm/model_executor/models/teleflm.py +++ b/vllm/model_executor/models/teleflm.py @@ -28,12 +28,14 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.models.llama import (LlamaDecoderLayer, - LlamaForCausalLM, LlamaModel) +from vllm.model_executor.models.llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, +) class TeleFLMModel(LlamaModel): - def __init__( self, *, @@ -41,9 +43,7 @@ def __init__( prefix: str = "", layer_type: type[nn.Module] = LlamaDecoderLayer, ): - super().__init__(vllm_config=vllm_config, - prefix=prefix, - layer_type=layer_type) + super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) """ This implementation is based on the µScaling paper presented at the ICLR 2025 Workshop: @@ -65,7 +65,6 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: class TeleFLMForCausalLM(LlamaForCausalLM): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) # mup @@ -74,6 +73,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.mup_scale_factor = self.config.mup_scale_factor self.output_mult = self.config.output_mult / self.mup_scale_factor logit_scale = self.output_mult - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.config.vocab_size, - logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, self.config.vocab_size, logit_scale + ) diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index cc71adbebd33..c7c82e9e10d1 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -23,8 +23,12 @@ import torch import torch.nn as nn -from terratorch.vllm import (DummyDataGenerator, InferenceRunner, - InputDefinition, InputTypeEnum) +from terratorch.vllm import ( + DummyDataGenerator, + InferenceRunner, + InputDefinition, + InputTypeEnum, +) from transformers import BatchFeature from vllm.config import VllmConfig @@ -35,19 +39,31 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import MultiModalProcessorOnlyCache -from vllm.multimodal.inputs import (ImageItem, ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, 
MultiModalKwargsItems, - MultiModalUUIDDict, PlaceholderRange) -from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, - MultiModalDataItems, MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptUpdate) +from vllm.multimodal.inputs import ( + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalUUIDDict, + PlaceholderRange, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from .interfaces import (IsAttentionFree, MultiModalEmbeddings, - SupportsMultiModal) +from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal from .interfaces_base import default_pooling_type logger = init_logger(__name__) @@ -59,12 +75,11 @@ def _terratorch_field_names(pretrained_cfg: dict): def _terratorch_field_factory( - pretrained_cfg: dict + pretrained_cfg: dict, ) -> Callable[ [Mapping[str, torch.Tensor]], - Mapping[str, MultiModalFieldConfig], + Mapping[str, MultiModalFieldConfig], ]: - def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]): input_definition = InputDefinition(**pretrained_cfg["input"]) fields = {} @@ -75,24 +90,24 @@ def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]): mm_fields_config = {} for field_name, field_modality in fields.items(): mm_fields_config[field_name] = MultiModalFieldConfig.shared( - batch_size=1, modality=field_modality) + batch_size=1, modality=field_modality + ) return mm_fields_config return _terratorch_field_config class TerratorchProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]): - def __init__(self, info: TerratorchProcessingInfo): super().__init__(info) self.dummy_data_generator = DummyDataGenerator( - self.info.get_hf_config().to_dict()["pretrained_cfg"]) + self.info.get_hf_config().to_dict()["pretrained_cfg"] + ) def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -107,15 +122,16 @@ def get_dummy_mm_data( # defined in the HF configuration file if mm_options: - logger.warning("Configurable multimodal profiling " - "options are not supported for Terratorch. " - "They are ignored for now.") + logger.warning( + "Configurable multimodal profiling " + "options are not supported for Terratorch. " + "They are ignored for now." 
+ ) return self.dummy_data_generator.get_dummy_mm_data() class TerratorchMultiModalDataParser(MultiModalDataParser): - def __init__(self, pretrained_cfg: dict, *args, **kwargs): self._pretrained_cfg = pretrained_cfg super().__init__(*args, **kwargs) @@ -125,7 +141,6 @@ def _parse_image_data( data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], ) -> Optional[ModalityDataItems[Any, Any]]: if isinstance(data, dict): - terratorch_fields = _terratorch_field_names(self._pretrained_cfg) return DictEmbeddingItems( @@ -139,20 +154,18 @@ def _parse_image_data( class TerratorchMultiModalProcessor(BaseMultiModalProcessor): - def __init__( - self, - info: TerratorchProcessingInfo, - dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]", - *, - cache: Optional[MultiModalProcessorOnlyCache] = None) -> None: - + self, + info: TerratorchProcessingInfo, + dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]", + *, + cache: Optional[MultiModalProcessorOnlyCache] = None, + ) -> None: self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"] super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache) def _get_data_parser(self) -> MultiModalDataParser: - return TerratorchMultiModalDataParser( - pretrained_cfg=self.pretrained_cfg) + return TerratorchMultiModalDataParser(pretrained_cfg=self.pretrained_cfg) def _get_mm_fields_config( self, @@ -185,18 +198,16 @@ def apply( mm_items = self._to_mm_items(mm_data) tokenization_kwargs = tokenization_kwargs or {} - mm_hashes = self._hash_mm_items(mm_items, - hf_processor_mm_kwargs, - tokenization_kwargs, - mm_uuids=mm_uuids) + mm_hashes = self._hash_mm_items( + mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids + ) mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} mm_processed_data = BatchFeature(image_data) mm_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_processed_data, - self._get_mm_fields_config(mm_processed_data, - hf_processor_mm_kwargs), + self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs), ) return MultiModalInputs( @@ -237,7 +248,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): assert pooler_config is not None self.pooler = DispatchPooler( - {"encode": Pooler.for_encode(pooler_config)}, ) + {"encode": Pooler.for_encode(pooler_config)}, + ) def get_input_embeddings( self, @@ -265,8 +277,7 @@ def forward( return model_output.output - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_list = [] model_buffers = dict(self.named_buffers()) loaded_buffers = [] @@ -289,8 +300,9 @@ def load_weights(self, weights: Iterable[tuple[str, if "_timm_module." in name: name = name.replace("_timm_module.", "") buffer = model_buffers[name] - weight_loader = getattr(buffer, "weight_loader", - default_weight_loader) + weight_loader = getattr( + buffer, "weight_loader", default_weight_loader + ) weight_loader(buffer, weight) loaded_buffers.append(name) else: diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 99114a39295a..5b92aa97eaf0 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Wrapper around `transformers` models""" + from collections.abc import Iterable, Mapping from contextlib import contextmanager from pathlib import Path @@ -25,42 +26,62 @@ import transformers from packaging.version import Version from torch import nn -from transformers import (AutoModel, BatchFeature, PretrainedConfig, - PreTrainedModel) +from transformers import AutoModel, BatchFeature, PretrainedConfig, PreTrainedModel from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, VllmConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + VllmConfig, +) from vllm.config.multimodal import BaseDummyOptions from vllm.config.utils import getattr_iter -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tp_group from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalUUIDDict, - PlaceholderRange) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalUUIDDict, + PlaceholderRange, +) from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo) +from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP, SupportsQuant) -from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - flatten_bn, make_empty_intermediate_tensors_factory, - maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + make_empty_intermediate_tensors_factory, + maybe_prefix, +) logger = init_logger(__name__) @@ -81,17 +102,18 @@ def get_feature_request_tip( def vllm_flash_attention_forward( - # Transformers args - module: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor, - # Transformers kwargs - scaling: Optional[float] = None, - # vLLM kwargs - attention_instances: Optional[dict[Attention]] = None, - **kwargs): + # Transformers args + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + # Transformers kwargs + scaling: Optional[float] = 
None, + # vLLM kwargs + attention_instances: Optional[dict[Attention]] = None, + **kwargs, +): self_attn = attention_instances[module.layer_idx] if scaling is not None: self_attn.impl.scale = float(scaling) @@ -125,8 +147,7 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: return enable -Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", - "replicate"] +Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] def replace_linear_class( @@ -148,18 +169,13 @@ def replace_linear_class( """ if not isinstance(style, str): - raise ValueError( - f"Unsupported parallel style type {type(style)}, expected str") + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") vllm_linear_cls, vllm_linear_kwargs = { "colwise": (ColumnParallelLinear, {}), - "colwise_rep": (ColumnParallelLinear, { - "gather_output": True - }), + "colwise_rep": (ColumnParallelLinear, {"gather_output": True}), "rowwise": (RowParallelLinear, {}), - "rowwise_rep": (RowParallelLinear, { - "input_is_parallel": False - }), + "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}), "replicate": (ReplicatedLinear, {}), }.get(style, (ReplicatedLinear, {})) @@ -187,7 +203,7 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: kwargs = { "hidden_size": hidden_size, "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6), - "has_weight": getattr(rms_norm, "with_scale", True) + "has_weight": getattr(rms_norm, "with_scale", True), } if (weight := getattr(rms_norm, "weight", None)) is not None: # If weight is a Parameter, get its data tensor @@ -221,12 +237,12 @@ def register_empty_parameter(module, name, param): kwargs = module._parameters[name].__dict__ kwargs["requires_grad"] = param.requires_grad module._parameters[name] = param_cls( - module._parameters[name].to(device), **kwargs) + module._parameters[name].to(device), **kwargs + ) tensor_constructors_to_patch = {} def patch_tensor_constructor(fn): - def wrapper(*args, **kwargs): kwargs["device"] = device return fn(*args, **kwargs) @@ -237,18 +253,21 @@ def wrapper(*args, **kwargs): nn.Module.register_parameter = register_empty_parameter for torch_function_name in tensor_constructors_to_patch: setattr( - torch, torch_function_name, - patch_tensor_constructor(getattr(torch, torch_function_name))) + torch, + torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name)), + ) yield finally: nn.Module.register_parameter = old_register_parameter - for torch_function_name, old_torch_function in ( - tensor_constructors_to_patch.items()): + for ( + torch_function_name, + old_torch_function, + ) in tensor_constructors_to_patch.items(): setattr(torch, torch_function_name, old_torch_function) class MultiModalProcessingInfo(BaseProcessingInfo): - def get_supported_mm_limits(self): return {"image": None} @@ -261,7 +280,8 @@ def get_max_image_tokens(self) -> int: multimodal_config = self.ctx.model_config.multimodal_config mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {} mm_tokens = processor._get_num_multimodal_tokens( - image_sizes=([height, width], ), **mm_processor_kwargs) + image_sizes=([height, width],), **mm_processor_kwargs + ) image_tokens = mm_tokens["num_image_tokens"][0] return image_tokens @@ -269,9 +289,7 @@ def get_max_image_size(self): return 10_000, 10_000 # hardcode for arbitrary very large size -class MultiModalDummyInputsBuilder( - BaseDummyInputsBuilder[MultiModalProcessingInfo]): - +class 
MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -295,16 +313,16 @@ def get_dummy_mm_data( image_overrides = mm_options.get("image") if mm_options else None return { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides), + "image": self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ), } class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): - def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -328,53 +346,37 @@ def _get_prompt_updates( def _get_mm_fields_config( self, - hf_inputs, - hf_processor_mm_kwargs, - num_image_patches: torch.Tensor = None, - ): + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: # HF Processors always return a mask but vLLM doesn't need it hf_inputs.pop("attention_mask", None) + num_image_patches = hf_inputs.get("num_image_patches") mm_fields = { - key: MultiModalFieldConfig.flat_from_sizes("image", - num_image_patches) + key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches) for key in hf_inputs } mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes( - "image", num_image_patches) + "image", num_image_patches + ) + + # Keep these as batched, as they always have batch size as first dim + mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image") + mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image") mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") return mm_fields - def _apply_hf_processor_text_mm( + def _get_hf_mm_data( self, - prompt_text: str, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], - tokenization_kwargs: Mapping[str, object], - ) -> tuple[list[int], BatchFeature, bool]: + ) -> tuple[Mapping[str, object], Mapping[str, object]]: """ - Apply the HF processor on the prompt text and multi-modal data - together. - - In addition, return whether prompt replacements have been applied. + In contrast to the base class, this method always adds + `return_mm_token_type_ids` to the processor data """ - processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + processor_data, passthrough_data = super()._get_hf_mm_data(mm_items) processor_data["return_mm_token_type_ids"] = True - - processed_data = self._call_hf_processor( - prompt=prompt_text, - mm_data=processor_data, - mm_kwargs=hf_processor_mm_kwargs, - tok_kwargs=tokenization_kwargs, - ) - processed_data.update(passthrough_data) - - prompt_ids, = processed_data.pop("input_ids").tolist() - mm_token_type_ids = processed_data.pop( - "mm_token_type_ids" - ) if "mm_token_type_ids" in processed_data else processed_data.pop( - "token_type_ids") # for gemma3 only - - return prompt_ids, processed_data, mm_token_type_ids + return processor_data, passthrough_data def apply( self, @@ -401,17 +403,28 @@ def apply( # into string prompt = hf_processor.decode(prompt) - (prompt_ids, processed_data, - mm_token_type_ids) = self._apply_hf_processor_text_mm( - prompt_text=prompt, - mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - tokenization_kwargs=tokenization_kwargs, - ) - - # HF processor will return `mm_token_type_ids` from which - # we can infer mm_placeholders. 
Until then hardcode to make code run - # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1 + # Bypass cached processor and always apply to the full set of mm inputs + # NOTE: we can't just set caching=False because base class method + # transforms outputs to `MultiModalKwargs` which is not going to + # work for Transformers. We have a lot of logic tied to + # `mm_tokens_per_modality` below + prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + + # For gemma3 we check `token_type_ids` as the key + token_type_key = ( + "mm_token_type_ids" + if "mm_token_type_ids" in processed_data + else "token_type_ids" + ) + mm_token_type_ids = processed_data.pop(token_type_key) + + # We can infer vLLM style placeholder from token type ids, if we split + # it for each input `mm_data`. mm_positions = torch.where(mm_token_type_ids == 1)[1] images = mm_items.get_items("image", ImageProcessorItems) multimodal_config = self.info.ctx.model_config.multimodal_config @@ -422,7 +435,8 @@ def apply( image_sizes.append((image_size.height, image_size.width)) mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( - image_sizes=image_sizes, **mm_processor_kwargs) + image_sizes=image_sizes, **mm_processor_kwargs + ) mm_placeholders = {} split_sizes = mm_tokens_per_modality["num_image_tokens"] @@ -434,27 +448,24 @@ def apply( PlaceholderRange( offset=positions[0].item(), length=positions.shape[0], - is_embed=(mm_tokens == hf_processor.image_token_id).bool()) - for positions, mm_tokens in zip(chunked_mm_positions, - chunked_mm_tokens) + is_embed=(mm_tokens == hf_processor.image_token_id).bool(), + ) + for positions, mm_tokens in zip(chunked_mm_positions, chunked_mm_tokens) ] mm_placeholders = {"image": ranges} - num_image_patches = torch.tensor( + processed_data["num_image_patches"] = torch.tensor( mm_tokens_per_modality["num_image_patches"] - ) if "num_image_patches" in mm_tokens_per_modality else None - processed_data['num_image_patches'] = num_image_patches + ) mm_kwargs = MultiModalKwargsItems.from_hf_inputs( processed_data, - self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, - num_image_patches), + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), ) # Use overrides if provided; fallback to data-dependent hashing. 
- mm_hashes = self._hash_mm_items(mm_items, - hf_processor_mm_kwargs, - tokenization_kwargs, - mm_uuids=mm_uuids) + mm_hashes = self._hash_mm_items( + mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids + ) return MultiModalInputs( type="multimodal", @@ -467,8 +478,7 @@ def apply( class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): embedding_padding_modules = ["lm_head"] - embedding_modules = ["embed_tokens" - ] # TODO transformers will have a util to get it + embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -480,13 +490,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.device_config: DeviceConfig = vllm_config.device_config self.model_config: ModelConfig = vllm_config.model_config self.parallel_config: ParallelConfig = vllm_config.parallel_config - self.quant_config: Optional[ - QuantizationConfig] = vllm_config.quant_config + self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config self.pp_group = get_pp_group() - self.pp_size = self.pp_group.world_size - self.pp_rank = self.pp_group.rank_in_group - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() # Weights to skip in `self.load_weights` self.skip_prefixes: list[str] = [] @@ -503,15 +510,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_method_name = self.quant_config.get_name() # Check for unsupported quantization methods. if quant_method_name == "mxfp4": - raise NotImplementedError("Transformers backend does not " - "support MXFP4 quantization yet.") + raise NotImplementedError( + "Transformers backend does not support MXFP4 quantization yet." + ) # Skip loading extra bias for GPTQ models. if "gptq" in quant_method_name: self.ignore_unexpected_suffixes.append(".bias") # Set correct attn and init on "meta" to delay allocating GPU tensors - # TODO: @raushan, use the public `model.set_attn_implementation()` - # method once its checks are fixed in Transformers. self.text_config._attn_implementation = "vllm" with init_on_device_without_buffers("meta"): self.model: PreTrainedModel = AutoModel.from_config( @@ -538,26 +544,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): embedding_dim=embedding_dim, org_num_embeddings=self.text_config.vocab_size, quant_config=self.quant_config, - )) + ) + ) # Initialize any parameters that have not had their modules replaced self.init_parameters(self.model) # Pipeline parallel intermediate tensors - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states"], self.text_config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.text_config.hidden_size + ) def pipeline_parallel(self): """ Apply the model's pipeline parallelization plan. """ - if self.pp_size <= 1: + if self.pp_group.world_size <= 1: return if not self.model.supports_pp_plan: - tip = get_feature_request_tip(self.model_config.model, - self.model_config.trust_remote_code) + tip = get_feature_request_tip( + self.model_config.model, self.model_config.trust_remote_code + ) raise ValueError( f"{type(self.model)} does not support pipeline parallel. 
{tip}" ) @@ -573,22 +581,25 @@ def pipeline_parallel(self): if len(module_lists) > 1: raise ValueError( "Pipeline parallel of models with multiple `ModuleList`s " - "in the base model are not supported yet!") + "in the base model are not supported yet!" + ) if module_list_idx is None: - raise ValueError( - f"Could not find `ModuleList` in {type(self.model)}") + raise ValueError(f"Could not find `ModuleList` in {type(self.model)}") # Layers before module list for name in pp_plan[:module_list_idx]: if self.pp_group.is_first_rank or ( - self.text_config.tie_word_embeddings - and self.pp_group.is_last_rank): + self.text_config.tie_word_embeddings and self.pp_group.is_last_rank + ): continue setattr(self.model, name, PPMissingLayer()) # Module list start_layer, end_layer = get_pp_indices( - self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) + self.text_config.num_hidden_layers, + self.pp_group.rank_in_group, + self.pp_group.world_size, + ) layers_name = pp_plan[module_list_idx] layers = getattr(self.model, layers_name) for i in range(len(layers)): @@ -597,7 +608,7 @@ def pipeline_parallel(self): layers[i] = PPMissingLayer() # Layers after module list - for name in pp_plan[module_list_idx + 1:]: + for name in pp_plan[module_list_idx + 1 :]: # Modules that should be on last rank if not self.pp_group.is_last_rank: setattr(self.model, name, PPMissingLayer()) @@ -612,11 +623,13 @@ def recursive_replace(self): """ tp_plan = self.model.tp_plan - if not tp_plan and self.tp_size > 1: - tip = get_feature_request_tip(self.model_config.model, - self.model_config.trust_remote_code) + if not tp_plan and self.tp_group.world_size > 1: + tip = get_feature_request_tip( + self.model_config.model, self.model_config.trust_remote_code + ) raise ValueError( - f"{type(self.model)} does not support tensor parallel. {tip}") + f"{type(self.model)} does not support tensor parallel. {tip}" + ) # Prefix the patterns because we always start from `self.model` tp_plan = {maybe_prefix("model", k): v for k, v in tp_plan.items()} @@ -632,10 +645,9 @@ def _recursive_replace(module: nn.Module, prefix: str): # LinearBase, so we set a default style which causes any # unspecified layers to be replaced with ReplicatedLinear style = tp_plan.get(pattern, "replicate") - new_module = replace_linear_class(child_module, - style, - self.quant_config, - prefix=qual_name) + new_module = replace_linear_class( + child_module, style, self.quant_config, prefix=qual_name + ) # TODO(hmellor): Enable RMSNorm replacement once we have a way # to choose RMSNorm vs GemmaRMSNorm # elif child_module.__class__.__name__.endswith("RMSNorm"): @@ -651,25 +663,28 @@ def _recursive_replace(module: nn.Module, prefix: str): _recursive_replace(self.model, prefix="model") def create_attention_instances( - self, - attn_type: AttentionType = AttentionType.DECODER + self, attn_type: AttentionType = AttentionType.DECODER ) -> dict[int, Attention]: """ Create `Attention` instances to inform KV cache allocation. 
""" - num_heads = self.model_config.get_num_attention_heads( - self.parallel_config) + num_heads = self.model_config.get_num_attention_heads(self.parallel_config) head_size = self.model_config.get_head_size() num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - start, end = get_pp_indices(self.text_config.num_hidden_layers, - self.pp_rank, self.pp_size) + start, end = get_pp_indices( + self.text_config.num_hidden_layers, + self.pp_group.rank_in_group, + self.pp_group.world_size, + ) attention_instances = {} for i in range(start, end): # Handle interleaved sliding window attention per_layer_sliding_window = None - if (hasattr(self.config, "layer_types") - and self.config.layer_types[i] == "sliding_attention"): + if ( + hasattr(self.config, "layer_types") + and self.config.layer_types[i] == "sliding_attention" + ): per_layer_sliding_window = self.config.sliding_window attention_instances[i] = Attention( @@ -683,12 +698,11 @@ def create_attention_instances( quant_config=self.quant_config, per_layer_sliding_window=per_layer_sliding_window, prefix=f"{i}.attn", - attn_type=attn_type) + attn_type=attn_type, + ) return attention_instances - def init_parameters(self, - module: nn.Module, - dtype: Optional[torch.dtype] = None): + def init_parameters(self, module: nn.Module, dtype: Optional[torch.dtype] = None): """ If a `parameter` is on the `meta` device, then its parent `module` is the original module created by: @@ -707,7 +721,8 @@ def _init_parameters(module: nn.Module, dtype: Optional[torch.dtype]): param.data, dtype=dtype or self.model_config.dtype, device=self.device_config.device, - )) + ) + ) setattr(module, name, new_param) for child in module.children(): _init_parameters(child, dtype) @@ -721,7 +736,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - if not get_pp_group().is_first_rank: + if not self.pp_group.is_first_rank: assert intermediate_tensors is not None input_ids = None inputs_embeds = intermediate_tensors["hidden_states"] @@ -742,9 +757,10 @@ def forward( use_cache=False, position_ids=position_ids, attention_instances=self.attention_instances, - return_dict=False)[0][0, ...] # we remove batch dimension for now + return_dict=False, + )[0][0, ...] 
# we remove batch dimension for now - if not get_pp_group().is_last_rank: + if not self.pp_group.is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) return hidden_states @@ -768,12 +784,12 @@ def check_version(self, min_version: str, feature: str): if installed < required: raise ImportError( f"Transformers backend requires transformers>={required} " - f"for {feature}, but got {installed}") + f"for {feature}, but got {installed}" + ) @support_torch_compile(enable_if=can_enable_torch_compile) class TransformersForCausalLM(TransformersBase): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -782,7 +798,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.text_config.tie_word_embeddings: self.skip_prefixes.append("lm_head.") - if get_pp_group().is_last_rank: + if self.pp_group.is_last_rank: self.unpadded_vocab_size = self.text_config.vocab_size self.lm_head = ParallelLMHead( self.text_config.vocab_size, @@ -792,12 +808,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) if self.text_config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights( - self.model.get_input_embeddings()) + self.model.get_input_embeddings() + ) logit_scale = getattr(self.text_config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( - self.unpadded_vocab_size, self.text_config.vocab_size, - logit_scale) + self.unpadded_vocab_size, self.text_config.vocab_size, logit_scale + ) else: self.lm_head = PPMissingLayer() @@ -812,21 +829,11 @@ def compute_logits( return logits -def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor: - """Flatten until a list of tensors can be concatenated then do concat""" - - def _can_concat(x: list[torch.Tensor]): - return len(set(map(lambda _x: _x.shape[1:], x))) == 1 - - if _can_concat(x): - return torch.concat(x) - return flatten_and_concat(flatten_bn(x)) - - @MULTIMODAL_REGISTRY.register_processor( MultiModalProcessor, info=MultiModalProcessingInfo, - dummy_inputs=MultiModalDummyInputsBuilder) + dummy_inputs=MultiModalDummyInputsBuilder, +) @support_torch_compile( # set `positions` to last dim to support Qwen-mrope dynamic_arg_dims={ @@ -835,7 +842,8 @@ def _can_concat(x: list[torch.Tensor]): "intermediate_tensors": 0, "inputs_embeds": 0, }, - enable_if=can_enable_torch_compile) + enable_if=can_enable_torch_compile, +) class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): merge_by_field_config = True # Backwards compatibility for prev released models. 
State dicts back then @@ -859,7 +867,8 @@ class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal): "model.embed_tokens": "model.language_model.embed_tokens", "model.layers": "model.language_model.layers", "model.norm": "model.language_model.norm", - }) + } + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -874,8 +883,9 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = super().forward(input_ids, positions, - intermediate_tensors, inputs_embeds) + model_output = super().forward( + input_ids, positions, intermediate_tensors, inputs_embeds + ) return model_output def get_language_model(self) -> torch.nn.Module: @@ -896,13 +906,9 @@ def get_multimodal_embeddings(self, **kwargs): num_image_patches = kwargs.pop("num_image_patches") if pixel_values is not None: - vision_embeddings = self.model.get_image_features( - pixel_values, **kwargs) + vision_embeddings = self.model.get_image_features(pixel_values, **kwargs) if isinstance(vision_embeddings, torch.Tensor): - if isinstance(num_image_patches, list): - num_image_patches = torch.cat(num_image_patches) - if vision_embeddings.ndim == 2: vision_embeddings = vision_embeddings.unsqueeze(0) @@ -910,8 +916,8 @@ def get_multimodal_embeddings(self, **kwargs): # but transformers returns concat tensors if each patch # is of different size. We split it back to make vLLM happy vision_embeddings = torch.split( - vision_embeddings, - num_image_patches.flatten().tolist()) + vision_embeddings, num_image_patches.flatten().tolist() + ) vision_embeddings = [ embed.flatten(start_dim=0, end_dim=-2) for embed in vision_embeddings @@ -954,7 +960,8 @@ def get_input_embeddings( raise ValueError( "`get_input_embeddings` now requires `is_multimodal` arg, " "please update your model runner according to " - "https://github.com/vllm-project/vllm/pull/16229.") + "https://github.com/vllm-project/vllm/pull/16229." + ) return _merge_multimodal_embeddings( inputs_embeds=inputs_embeds, diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py index cb966256b350..eb135f050e8c 100644 --- a/vllm/model_executor/models/transformers_moe.py +++ b/vllm/model_executor/models/transformers_moe.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Wrapper around `transformers` MoE models.""" + from typing import Any import torch @@ -22,15 +23,21 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config.utils import getattr_iter +from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op -from .transformers import (TransformersBase, TransformersForCausalLM, - TransformersForMultimodalLM, - can_enable_torch_compile, log_replacement) +from .interfaces import MixtureOfExperts +from .transformers import ( + TransformersBase, + TransformersForCausalLM, + TransformersForMultimodalLM, + can_enable_torch_compile, + log_replacement, +) from .utils import maybe_prefix @@ -40,43 +47,63 @@ class TransformersFusedMoE(FusedMoE): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._top_k_index: torch.Tensor = None - - def custom_routing_function(hidden_states, gating_output, topk, - renormalize): - """Return `top_k_weights` from `gating_output` and the - `top_k_index` we stored in the layer earlier.""" - return gating_output, self._top_k_index + self._topk_ids: torch.Tensor = None + + def custom_routing_function(hidden_states, gating_output, topk, renormalize): + """Return `topk_weights` from `gating_output` and the + `topk_ids` we stored in the layer earlier.""" + topk_weights = gating_output + topk_ids = self._topk_ids + # Handle all gather in expert parallel + if topk_ids.size(0) != hidden_states.size(0): + dp_metadata = get_forward_context().dp_metadata + sizes = dp_metadata.get_chunk_sizes_across_dp_rank() + is_sp = self.is_sequence_parallel + dist_group = get_ep_group() if is_sp else get_dp_group() + assert sizes[dist_group.rank_in_group] == topk_ids.shape[0] + (topk_ids,) = dist_group.all_gatherv([topk_ids], 0, sizes) + return topk_weights, topk_ids self.custom_routing_function = custom_routing_function - def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: """In Transformers `experts.forward` will have this signature. 
We discard any extra kwargs because we cannot use them here.""" - return torch.ops.vllm.transformers_moe_forward(hidden_states, - top_k_index, - top_k_weights, - self.layer_name) - - -def transformers_moe_forward(hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - layer_name: str) -> torch.Tensor: - """Store the `top_k_index` in the layer and call the actual forward.""" + return torch.ops.vllm.transformers_moe_forward( + hidden_states, + topk_ids.to(torch.int32), + topk_weights.to(torch.float32), + self.layer_name, + ) + + +def transformers_moe_forward( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + """Store the `topk_ids` in the layer and call the actual forward.""" forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self._top_k_index = top_k_index + self._topk_ids = topk_ids # Clone hidden_states because it will be mutated in-place in FusedMoE - return self.forward_impl(hidden_states.clone(), top_k_weights) + return self.forward_impl(hidden_states.clone(), topk_weights) -def transformers_moe_forward_fake(hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - layer_name: str) -> torch.Tensor: +def transformers_moe_forward_fake( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + layer_name: str, +) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -86,23 +113,44 @@ def transformers_moe_forward_fake(hidden_states: torch.Tensor, mutates_args=["hidden_states"], fake_impl=transformers_moe_forward_fake, dispatch_key=current_platform.dispatch_key, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(torch.Tag.needs_fixed_stride_order,), ) -class TransformersMoEBase(TransformersBase): - +class TransformersMoEBase(TransformersBase, MixtureOfExperts): def __init__(self, *, vllm_config, prefix=""): self.check_version("4.57.0.dev0", "MoE models support") + self.ep_group = get_ep_group() super().__init__(vllm_config=vllm_config, prefix=prefix) - if self.parallel_config.enable_expert_parallel: - raise NotImplementedError( - "Transformers backend does not support expert parallel yet.") - if self.parallel_config.enable_eplb: - raise NotImplementedError( - "Transformers backend does not support expert parallel load " - "balancing yet.") + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ): + for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers): + mlp_layer.experts.set_eplb_state( + moe_layer_idx=moe_layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ): + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for mlp in self.mlp_layers: + mlp.n_local_physical_experts = num_local_physical_experts + mlp.n_physical_experts = num_physical_experts + mlp.n_redundant_experts = self.num_redundant_experts + mlp.experts.update_expert_map() def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: """ @@ -115,6 +163,8 @@ def 
get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ("w1", "w2", "w3"), # Granite, Mixtral, Phi MoE style ("linear", "linear_1", "linear_v"), # Grok1 style ] + num_experts = self.model_config.get_num_experts() + num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts expert_mapping = [] for gate_proj, down_proj, up_proj in ckpt_names: expert_mapping.extend( @@ -122,9 +172,10 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: ckpt_gate_proj_name=gate_proj, ckpt_down_proj_name=down_proj, ckpt_up_proj_name=up_proj, - num_experts=self.model_config.get_num_experts(), - num_redundant_experts=0, # TODO: enable EPLB - )) + num_experts=num_experts, + num_redundant_experts=num_redundant_experts, + ) + ) return expert_mapping def recursive_replace(self): @@ -133,30 +184,33 @@ def recursive_replace(self): # Positional arguments num_experts = self.model_config.get_num_experts() - top_k = getattr_iter(text_config, ["num_experts_per_tok", "top_k"], - None) + top_k = getattr_iter(text_config, ["num_experts_per_tok", "top_k"], None) assert top_k is not None hidden_size = text_config.hidden_size intermediate_size = getattr_iter( - text_config, ["moe_intermediate_size", "intermediate_size"], None) + text_config, ["moe_intermediate_size", "intermediate_size"], None + ) assert intermediate_size is not None # If there are shared experts, the results are # reduced after mlp.forward() not inside FusedMoE - num_experts_shared = getattr_iter(text_config, [ - "num_experts_shared", "n_shared_experts", "moe_num_shared_experts" - ], 0) - reduce_results = num_experts_shared == 0 + num_shared_experts = getattr_iter( + text_config, + [ + "n_shared_experts", # DeepSeek, Docs, GLM + "moe_num_shared_experts", # Aria, Ernie + ], + 0, + ) + reduce_results = num_shared_experts == 0 def add_all_reduce(mlp: nn.Module): """Adds an all-reduce to the output of `mlp.forward()`.""" class MLPWithAllReduce(mlp.__class__): - def forward(self, *args, **kwargs): output = super().forward(*args, **kwargs) - return self.experts.maybe_all_reduce_tensor_model_parallel( - output) + return self.experts.maybe_all_reduce_tensor_model_parallel(output) mlp.__class__ = MLPWithAllReduce @@ -183,20 +237,29 @@ def forward(self, *args, **kwargs): # Expert mapping for `AutoWeightsLoader` expert_mapping = self.get_expert_mapping() - # Configs - parallel_config = self.parallel_config - eplb_config = parallel_config.eplb_config - # Expert parallel load balancing kwargs - enable_eplb = parallel_config.enable_eplb - num_redundant_experts = eplb_config.num_redundant_experts + enable_eplb = self.parallel_config.enable_eplb + num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts + + # MixtureOfExperts mixin settings + ep_size = self.ep_group.world_size + + self.mlp_layers = [] # Used for MixtureOfExperts methods + self.expert_weights = [] + self.num_moe_layers = 0 + self.num_expert_groups = 1 if num_expert_group is None else num_expert_group + self.num_logical_experts = num_experts + self.num_physical_experts = num_experts + num_redundant_experts + self.num_local_physical_experts = self.num_physical_experts // ep_size + self.num_routed_experts = num_experts + self.num_shared_experts = num_shared_experts + self.num_redundant_experts = num_redundant_experts # Recursively fuse MoE layers def _recursive_replace(module: nn.Module, prefix: str): for child_name, child_module in module.named_children(): qual_name = maybe_prefix(prefix, child_name) - if (child_name == "experts" - and 
isinstance(child_module, nn.ModuleList)): + if child_name == "experts" and isinstance(child_module, nn.ModuleList): # Alias for readability mlp = module experts = child_module @@ -212,6 +275,9 @@ def _recursive_replace(module: nn.Module, prefix: str): for mlp_param_name, _ in mlp.named_parameters(): if "shared_expert" in mlp_param_name: reduce_results = False + # If the config does not specify num_shared_experts, but + # the model has shared experts, we assume there is one. + self.num_shared_experts = 1 break # Replace experts module with FusedMoE fused_experts = TransformersFusedMoE( @@ -235,11 +301,16 @@ def _recursive_replace(module: nn.Module, prefix: str): ) mlp.experts = fused_experts log_replacement(qual_name, experts, fused_experts) + # Update MixtureOfExperts mixin state + self.mlp_layers.append(mlp) + self.expert_weights.append(fused_experts.get_expert_weights()) + self.num_moe_layers += 1 # If results are not all-reduced in FusedMoE, ensure they # are all-reduced at the end of mlp.forward() if tensor # parallel or expert parallel is enabled - if not reduce_results and (fused_experts.tp_size > 1 - or fused_experts.ep_size > 1): + if not reduce_results and ( + fused_experts.tp_size > 1 or fused_experts.ep_size > 1 + ): add_all_reduce(mlp) else: _recursive_replace(child_module, prefix=qual_name) @@ -262,7 +333,9 @@ class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM): "intermediate_tensors": 0, "inputs_embeds": 0, }, - enable_if=can_enable_torch_compile) -class TransformersMoEForMultimodalLM(TransformersMoEForCausalLM, - TransformersForMultimodalLM): + enable_if=can_enable_torch_compile, +) +class TransformersMoEForMultimodalLM( + TransformersMoEForCausalLM, TransformersForMultimodalLM +): pass diff --git a/vllm/model_executor/models/transformers_pooling.py b/vllm/model_executor/models/transformers_pooling.py index 27fd40999fe2..98d2611351c0 100644 --- a/vllm/model_executor/models/transformers_pooling.py +++ b/vllm/model_executor/models/transformers_pooling.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Wrapper around `transformers` models for pooling tasks.""" + from typing import Optional, Union import torch @@ -23,8 +24,12 @@ from vllm.attention import Attention, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, - DispatchPooler, Pooler) +from vllm.model_executor.layers.pooler import ( + ClassifierPooler, + CLSPool, + DispatchPooler, + Pooler, +) from vllm.sequence import IntermediateTensors from .interfaces_base import VllmModelForPooling @@ -52,16 +57,22 @@ class TransformersPoolingBase(TransformersBase, VllmModelForPooling): # Replace legacy suffixes used for norms ".gamma": ".weight", ".beta": ".bias", - }) + }, + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) # Skip unsupported/unwanted output embeddings layers - self.skip_prefixes.extend([ - "model.lm_head.", "model.predictions.", "model.qa_outputs.", - "model.embeddings_project.", "model.discriminator_predictions." - ]) + self.skip_prefixes.extend( + [ + "model.lm_head.", + "model.predictions.", + "model.qa_outputs.", + "model.embeddings_project.", + "model.discriminator_predictions.", + ] + ) # Some encoder models have the position_ids buffer in the checkpoint. 
# vLLM will always pass position_ids as an argument, so we skip loading @@ -80,8 +91,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = self.text_config.pad_token_id def create_attention_instances( - self, - attn_type: AttentionType = AttentionType.DECODER + self, attn_type: AttentionType = AttentionType.DECODER ) -> dict[int, Attention]: # TODO(hmellor): Better way to detect encoder models # In encoder models, the attention layers will have `is_causal=False` @@ -107,10 +117,12 @@ def forward( if self.is_roberta: # RoBERTa-specific positions padding positions += self.padding_idx + 1 - return super().forward(input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + return super().forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) @support_torch_compile(enable_if=can_enable_torch_compile) @@ -123,10 +135,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({ - "encode": Pooler.for_encode(pooler_config), - "embed": Pooler.for_embed(pooler_config), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + } + ) @support_torch_compile(enable_if=can_enable_torch_compile) @@ -158,12 +172,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if self.model.pooler is not None: raise ValueError( "Sequence classification models with pooling layers are not " - "supported yet in the Transformers backend.") + "supported yet in the Transformers backend." + ) # Unlike `lm_head`, `classifier` is not always `nn.Linear`. self.classifier = seq_cls_model.classifier - self.init_parameters(self.classifier, - dtype=self.model_config.head_dtype) + self.init_parameters(self.classifier, dtype=self.model_config.head_dtype) class ClassifierWithReshape(self.classifier.__class__): """CLSPool has already been applied in `pooling`. 
@@ -176,33 +190,34 @@ def forward(self, *args, **kwargs): self.classifier.__class__ = ClassifierWithReshape - self.pooler = DispatchPooler({ - "encode": - Pooler.for_encode(pooler_config), - "classify": - ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_seq_cls( - vllm_config.model_config), - ), - "score": - ClassifierPooler( - pooling=CLSPool(), - classifier=self.classifier, - act_fn=ClassifierPooler.act_fn_for_cross_encoder( - vllm_config.model_config), - ), - }) + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "classify": ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config + ), + ), + "score": ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config + ), + ), + } + ) @support_torch_compile(enable_if=can_enable_torch_compile) -class TransformersMoEEmbeddingModel(TransformersMoEBase, - TransformersEmbeddingModel): +class TransformersMoEEmbeddingModel(TransformersMoEBase, TransformersEmbeddingModel): pass @support_torch_compile(enable_if=can_enable_torch_compile) class TransformersMoEForSequenceClassification( - TransformersMoEBase, TransformersForSequenceClassification): + TransformersMoEBase, TransformersForSequenceClassification +): pass diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 7744a19946a2..8f071eac2201 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,6 +3,7 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" + from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Any, Literal, Optional, Union @@ -20,21 +21,37 @@ from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, NestedTensors) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + NestedTensors, +) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, PromptReplacement, - PromptUpdate) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsLoRA, - SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, - init_vllm_registered_model, maybe_prefix) +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + flatten_bn, + init_vllm_registered_model, + maybe_prefix, +) _AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>" _MAX_ENCODER_BATCH_SIZE = 16 @@ -48,15 +65,21 @@ class UltravoxAudioFeatureInputs(TensorSchema): - t: Time 
frames (M) - nmb: Number of mel bins """ + type: Literal["audio_features"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor], - list[list[torch.Tensor]]], - TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})] - lens: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("b", "n", dynamic_dims={"n"})] + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]], + TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"}), + ] + lens: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("b", "n", dynamic_dims={"n"}), + ] """Length of the audio frames. Used for attention mask in WhisperEncoder.""" - token_len: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("b", "n", dynamic_dims={"n"})] + token_len: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("b", "n", dynamic_dims={"n"}), + ] """Length of the audio tokens. Used for flattening the audio features.""" @@ -68,17 +91,17 @@ class UltravoxAudioEmbeddingInputs(TensorSchema): - afs: audio feature size - hs: hidden size """ + type: Literal["audio_embeds"] - data: Annotated[Union[torch.Tensor, list[torch.Tensor]], - TensorShape("b", "na", "afs", "hs")] + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], TensorShape("b", "na", "afs", "hs") + ] -UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, - UltravoxAudioEmbeddingInputs] +UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, UltravoxAudioEmbeddingInputs] class UltravoxProcessingInfo(BaseProcessingInfo): - def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: config = self.ctx.model_config.hf_config hf_processor = self.ctx.get_hf_processor(**kwargs) @@ -91,8 +114,7 @@ def get_hf_processor(self, **kwargs: object) -> ProcessorMixin: return hf_processor - def get_feature_extractor(self, - **kwargs: object) -> WhisperFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: hf_processor = self.get_hf_processor(**kwargs) audio_processor = hf_processor.audio_processor # type: ignore feature_extractor = audio_processor.feature_extractor # type: ignore @@ -103,9 +125,7 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} -class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] - ): - +class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) @@ -120,23 +140,21 @@ def get_dummy_mm_data( feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate - audio_len = (feature_extractor.chunk_length * sampling_rate * - _MAX_ENCODER_BATCH_SIZE) + audio_len = ( + feature_extractor.chunk_length * sampling_rate * _MAX_ENCODER_BATCH_SIZE + ) num_audios = mm_counts.get("audio", 0) audio_overrides = mm_options.get("audio") if mm_options else None return { - "audio": - self._get_dummy_audios(length=audio_len, - num_audios=num_audios, - overrides=audio_overrides) + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) } -class UltravoxMultiModalProcessor( - BaseMultiModalProcessor[UltravoxProcessingInfo]): - +class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -151,7 +169,8 @@ 
def _call_hf_processor( # Text-only input not supported in composite processor if not mm_data.get("audios", []): prompt_ids = self.info.get_tokenizer().encode( - prompt, add_special_tokens=False) + prompt, add_special_tokens=False + ) prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") @@ -178,7 +197,7 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, tok_kwargs=tok_kwargs, ) - output['audio_features'] = output.pop('audio_values') + output["audio_features"] = output.pop("audio_values") return output @@ -187,17 +206,14 @@ def _get_mm_fields_config( hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0)) + num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0)) return dict( # to handle longer than 30s audio, each audio might be split # into multiple chunks as such, their batch dimension can be # higher than the number of audio samples - audio_features=MultiModalFieldConfig.flat_from_sizes( - "audio", num_chunks), - audio_token_len=MultiModalFieldConfig.flat_from_sizes( - "audio", num_chunks), - audio_lens=MultiModalFieldConfig.flat_from_sizes( - "audio", num_chunks), + audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks), + audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks), + audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks), # num_chunks can convert audio_chunked to audio batch dimension audio_num_chunks=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"), @@ -218,11 +234,12 @@ def _get_prompt_updates( # belonging to the i-th audio. out_mm_data = out_mm_kwargs.get_data() num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0)) - chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks, - dim=0, - dtype=torch.int32) + chunks_start_idx: torch.Tensor = torch.cumsum( + num_chunks, dim=0, dtype=torch.int32 + ) chunks_start_idx = torch.cat( - [torch.tensor([0], dtype=torch.int32), chunks_start_idx]) + [torch.tensor([0], dtype=torch.int32), chunks_start_idx] + ) def get_replacement_ultravox(item_idx: int): start = chunks_start_idx[item_idx] @@ -251,17 +268,16 @@ def __init__(self, stack_factor: int = 8): def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: B, T, C = audio_embeds.shape - T_pad = (T + self.stack_factor - - 1) // self.stack_factor * self.stack_factor + T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T)) B, T, C = audio_embeds.shape - audio_embeds = audio_embeds.view(B, T // self.stack_factor, - C * self.stack_factor) + audio_embeds = audio_embeds.view( + B, T // self.stack_factor, C * self.stack_factor + ) return audio_embeds class UltravoxProjector(nn.Module): - def __init__(self, config: UltravoxConfig): super().__init__() self.hidden_dim = config.hidden_size @@ -325,12 +341,15 @@ def __init__(self, *args, **kwargs): @property def max_context_length(self): - return (self.config.max_source_positions * self.conv1.stride[0] * - self.conv2.stride[0]) + return ( + self.config.max_source_positions + * self.conv1.stride[0] + * self.conv2.stride[0] + ) - def get_attention_mask_by_audio_len(self, - audio_lens: Optional[torch.Tensor], - hidden_states: torch.Tensor): + def get_attention_mask_by_audio_len( + self, audio_lens: Optional[torch.Tensor], hidden_states: torch.Tensor + ): """ 
Create attention mask based on audio lengths to mask out padding tokens For each sample in batch: @@ -346,9 +365,9 @@ def get_attention_mask_by_audio_len(self, audio_feature_len = self._get_feat_extract_output_lengths(audio_lens) max_seq_len = hidden_states.shape[1] - attention_mask = torch.arange(max_seq_len, - device=hidden_states.device)[None, :].lt( - audio_feature_len.view(-1, 1)) + attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[ + None, : + ].lt(audio_feature_len.view(-1, 1)) attention_mask = self.get_extended_attention_mask( attention_mask, None, @@ -367,21 +386,21 @@ def forward( f"Whisper expects the mel input features to be of length " f"{expected_seq_length} or less, but found " f"{input_features.shape[-1]}. Make sure to pad the input mel " - f"features to {expected_seq_length}.") + f"features to {expected_seq_length}." + ) inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) - embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)] + embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)] hidden_states = inputs_embeds + embed_pos - hidden_states = nn.functional.dropout(hidden_states, - p=self.dropout, - training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) - attention_mask = self.get_attention_mask_by_audio_len( - audio_lens, hidden_states) + attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states) for encoder_layer in self.layers: layer_outputs = encoder_layer( @@ -399,16 +418,17 @@ def forward( @MULTIMODAL_REGISTRY.register_processor( UltravoxMultiModalProcessor, info=UltravoxProcessingInfo, - dummy_inputs=UltravoxDummyInputsBuilder) + dummy_inputs=UltravoxDummyInputsBuilder, +) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) + orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."} + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -435,7 +455,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_or_path=config.audio_model_id, revision=None, prefix="audio_tower.", - )) + ) + ) self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -446,12 +467,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # this prefix is not for initialization, but for loading weights # note the trailing dot self.secondary_weights.append( - DefaultModelLoader.Source(model_or_path=config.text_model_id, - revision=None, - prefix="language_model.")) + DefaultModelLoader.Source( + model_or_path=config.text_model_id, + revision=None, + prefix="language_model.", + ) + ) self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors) + self.language_model.make_empty_intermediate_tensors + ) def get_mm_mapping(self) -> MultiModelKeys: """ @@ -464,8 +489,8 @@ def get_mm_mapping(self) -> MultiModelKeys: ) def _audio_features_to_embeddings( - self, input_features: torch.Tensor, - audio_lens: torch.Tensor) -> torch.Tensor: + self, input_features: torch.Tensor, audio_lens: 
torch.Tensor + ) -> torch.Tensor: audio_features = input_features.to(self.audio_tower.dtype) batch_size = audio_features.size(0) audio_embeddings = [] @@ -474,8 +499,9 @@ def _audio_features_to_embeddings( for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE): end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size) # Process through audio tower - batch_features = self.audio_tower(audio_features[start:end], - audio_lens[start:end]) + batch_features = self.audio_tower( + audio_features[start:end], audio_lens[start:end] + ) batch_features = batch_features.to(self.audio_tower.dtype) # Process through projector @@ -487,7 +513,8 @@ def _audio_features_to_embeddings( return audio_embeddings def _parse_and_validate_audio_input( - self, **kwargs: object) -> Optional[UltravoxAudioInputs]: + self, **kwargs: object + ) -> Optional[UltravoxAudioInputs]: audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) audio_lens = kwargs.pop("audio_lens", None) @@ -497,14 +524,15 @@ def _parse_and_validate_audio_input( return None if audio_features is not None: - return UltravoxAudioFeatureInputs(type="audio_features", - data=audio_features, - lens=audio_lens, - token_len=audio_token_len) + return UltravoxAudioFeatureInputs( + type="audio_features", + data=audio_features, + lens=audio_lens, + token_len=audio_token_len, + ) if audio_embeds is not None: - return UltravoxAudioEmbeddingInputs(type="audio_embeds", - data=audio_embeds) + return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) raise AssertionError("This line should be unreachable.") @@ -520,11 +548,10 @@ def _process_audio_input( audio_features = pad_and_concat_to_dim3(audio_input["data"]) # [B1, B2] -> [B1+B2] - audio_lens = flatten_bn(audio_input['lens'], concat=True) - audio_token_len = flatten_bn(audio_input['token_len'], concat=True) + audio_lens = flatten_bn(audio_input["lens"], concat=True) + audio_token_len = flatten_bn(audio_input["token_len"], concat=True) - embeddings = self._audio_features_to_embeddings( - audio_features, audio_lens) + embeddings = self._audio_features_to_embeddings(audio_features, audio_lens) # We should flatten and concatenate embeddings based on token lengths # For example, with token_len = [4, 2, 3], flattened_embeddings will be @@ -533,23 +560,22 @@ def _process_audio_input( # Create a mask of valid indices based on token lengths max_len = embeddings.shape[1] indices = torch.arange(max_len, device=embeddings.device).expand( - embeddings.shape[0], -1) + embeddings.shape[0], -1 + ) mask = indices < audio_token_len[:, None] # Apply mask and flatten flattened_embeddings = embeddings[mask] # Return one tensor per input audio embed_lens = [ - token_len_item.sum().item() - for token_len_item in audio_input['token_len'] + token_len_item.sum().item() for token_len_item in audio_input["token_len"] ] return flattened_embeddings.split(embed_lens) def get_language_model(self) -> torch.nn.Module: return self.language_model - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is None: return [] @@ -576,12 +602,14 @@ def get_input_embeddings( handle_oov_mm_token=handle_oov_mm_token, ) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs) -> 
Union[torch.Tensor, IntermediateTensors]: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Ultravox One key thing to understand is the `input_ids` already accounts for the @@ -607,25 +635,21 @@ def forward(self, if hasattr(language_model, "language_model"): language_model = language_model.language_model - hidden_states = language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.language_model.compute_logits(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - - loader = AutoWeightsLoader(self, - ignore_unexpected_prefixes=["audio_tower."]) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."]) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def pad_and_concat_to_dim3( - features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]] + features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]], ) -> torch.Tensor: """ Pad and concatenate a list of tensors. diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index d6fa88f06e56..2a64f6865f12 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -14,15 +14,21 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import NestedTensors from vllm.sequence import IntermediateTensors -from vllm.utils import (cdiv, direct_register_custom_op, - get_cuda_view_from_cpu_tensor, is_pin_memory_available, - is_uva_available) +from vllm.utils import ( + cdiv, + direct_register_custom_op, + get_cuda_view_from_cpu_tensor, + is_pin_memory_available, + is_uva_available, +) logger = init_logger(__name__) @@ -65,12 +71,16 @@ def _map_name(self, key: str) -> Optional[str]: def apply( self, weights: Iterable[tuple[str, torch.Tensor]] ) -> Iterable[tuple[str, torch.Tensor]]: - return ((out_name, data) for name, data in weights - if (out_name := self._map_name(name)) is not None) + return ( + (out_name, data) + for name, data in weights + if (out_name := self._map_name(name)) is not None + ) def apply_list(self, values: list[str]) -> list[str]: return [ - out_name for name in values + out_name + for name in values if (out_name := self._map_name(name)) is not None ] @@ -129,17 +139,20 @@ def _groupby_prefix( self, weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]: - weights_by_parts = ((weight_name.split(".", 1), weight_data) - for weight_name, weight_data in weights) + weights_by_parts = ( + (weight_name.split(".", 1), weight_data) + for weight_name, weight_data in weights + ) - for prefix, group in 
itertools.groupby(weights_by_parts, - key=lambda x: x[0][0]): + for prefix, group in itertools.groupby(weights_by_parts, key=lambda x: x[0][0]): yield ( prefix, # Because maxsplit=1 in weight_name.split(...), # the length of `parts` must either be 1 or 2 - (("" if len(parts) == 1 else parts[1], weights_data) - for parts, weights_data in group), + ( + ("" if len(parts) == 1 else parts[1], weights_data) + for parts, weights_data in group + ), ) def _get_qualname(self, prefix: str, rest: str) -> str: @@ -151,8 +164,9 @@ def _get_qualname(self, prefix: str, rest: str) -> str: return ".".join((prefix, rest)) def _can_skip(self, qualname: str) -> bool: - return (any(qualname.startswith(p) for p in self.skip_prefixes) - or any(substr in qualname for substr in self.skip_substrs)) + return any(qualname.startswith(p) for p in self.skip_prefixes) or any( + substr in qualname for substr in self.skip_substrs + ) def _can_ignore_unexpected(self, qualname: str) -> bool: iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes) @@ -181,24 +195,26 @@ def _load_param( raise ValueError( f"Attempted to load nested weight '{weight_qualname}' " - f"into a single parameter '{base_prefix}'") + f"into a single parameter '{base_prefix}'" + ) - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight_data) - logger.debug("Loaded weight %s with shape %s", weight_qualname, - param.shape) + logger.debug("Loaded weight %s with shape %s", weight_qualname, param.shape) yield weight_qualname - def _add_loadable_non_param_tensors(self, module: nn.Module, - child_params: dict[str, torch.Tensor]): + def _add_loadable_non_param_tensors( + self, module: nn.Module, child_params: dict[str, torch.Tensor] + ): """ Add tensor names that are not in the model params that may be in the safetensors, e.g., batch normalization stats. 
""" - if isinstance(module, ( + if isinstance( + module, + ( nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, @@ -206,10 +222,10 @@ def _add_loadable_non_param_tensors(self, module: nn.Module, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, nn.SyncBatchNorm, - )): + ), + ): module_state_dict = module.state_dict() - for stat_name in ("running_mean", "running_var", - "num_batches_tracked"): + for stat_name in ("running_mean", "running_var", "num_batches_tracked"): child_params[stat_name] = module_state_dict[stat_name] def _load_module( @@ -229,8 +245,8 @@ def _load_module( loaded_params = module_load_weights(weights) if loaded_params is None: logger.warning( - "Unable to collect loaded parameters " - "for module %s", module) + "Unable to collect loaded parameters for module %s", module + ) else: yield from map( lambda x: self._get_qualname(base_prefix, x), @@ -253,17 +269,18 @@ def _load_module( continue - yield from self._load_module(prefix, - child_modules[child_prefix], - child_weights) + yield from self._load_module( + prefix, child_modules[child_prefix], child_weights + ) elif child_prefix in child_params: if self._can_skip(prefix): logger.debug("Skipping param %s", prefix) continue - yield from self._load_param(prefix, child_params[child_prefix], - child_weights) + yield from self._load_param( + prefix, child_params[child_prefix], child_weights + ) else: can_skip_module = self._can_skip(prefix + ".") can_skip_param = self._can_skip(prefix) @@ -279,8 +296,10 @@ def _load_module( continue - msg = (f"There is no module or parameter named '{prefix}' " - f"in {type(self.module).__name__}") + msg = ( + f"There is no module or parameter named '{prefix}' " + f"in {type(self.module).__name__}" + ) raise ValueError(msg) def load_weights( @@ -292,8 +311,9 @@ def load_weights( if mapper is not None: weights = mapper.apply(weights) # filter out weights with first-prefix/substr to skip in name - weights = ((name, weight) for name, weight in weights - if not self._can_skip(name)) + weights = ( + (name, weight) for name, weight in weights if not self._can_skip(name) + ) autoloaded_weights = set(self._load_module("", self.module, weights)) return autoloaded_weights @@ -317,20 +337,17 @@ def init_vllm_registered_model( hf_config = vllm_config.model_config.hf_config if hf_config is not None: - vllm_config = vllm_config.with_hf_config(hf_config, - architectures=architectures) + vllm_config = vllm_config.with_hf_config(hf_config, architectures=architectures) return initialize_model(vllm_config=vllm_config, prefix=prefix) @overload -def flatten_bn(x: torch.Tensor) -> torch.Tensor: - ... +def flatten_bn(x: torch.Tensor) -> torch.Tensor: ... @overload -def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: - ... +def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]: ... @overload @@ -338,8 +355,7 @@ def flatten_bn( x: Union[list[torch.Tensor], torch.Tensor], *, concat: Literal[True], -) -> torch.Tensor: - ... +) -> torch.Tensor: ... @overload @@ -347,8 +363,7 @@ def flatten_bn( x: Union[list[torch.Tensor], torch.Tensor], *, concat: bool = False, -) -> Union[list[torch.Tensor], torch.Tensor]: - ... +) -> Union[list[torch.Tensor], torch.Tensor]: ... 
def flatten_bn( @@ -392,8 +407,7 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str: if isinstance(embeddings, torch.Tensor): return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) - return " + ".join( - _embedding_count_expression(inner) for inner in embeddings) + return " + ".join(_embedding_count_expression(inner) for inner in embeddings) def _merge_multimodal_embeddings( @@ -421,8 +435,9 @@ def _merge_multimodal_embeddings( # NOTE: This can avoid D2H sync (#22105), but fails to # raise an error if is_multimodal.sum() < len(mm_embeds_flat) - inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), - mm_embeds_flat.to(dtype=input_dtype)) + inputs_embeds.masked_scatter_( + is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype) + ) except RuntimeError as e: num_actual_tokens = len(mm_embeds_flat) num_expected_tokens = is_multimodal.sum().item() @@ -440,9 +455,11 @@ def _merge_multimodal_embeddings( return inputs_embeds -@deprecated("`merge_multimodal_embeddings` has been replaced with " - "`SupportsMultiModal.get_input_embeddings` and will be " - "removed in v0.12.") +@deprecated( + "`merge_multimodal_embeddings` has been replaced with " + "`SupportsMultiModal.get_input_embeddings` and will be " + "removed in v0.12." +) def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, @@ -477,7 +494,7 @@ def merge_multimodal_embeddings( if isinstance(placeholder_token_id, list): is_multimodal = isin_list(input_ids, placeholder_token_id) else: - is_multimodal = (input_ids == placeholder_token_id) + is_multimodal = input_ids == placeholder_token_id return _merge_multimodal_embeddings( inputs_embeds, @@ -499,9 +516,7 @@ def isin_list( class LayerFn(Protocol): - - def __call__(self, prefix: str) -> torch.nn.Module: - ... + def __call__(self, prefix: str) -> torch.nn.Module: ... 
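flatten_bn itself only appears in this diff through its overload stubs and call sites (for example flatten_bn(audio_input["lens"], concat=True) next to the "[B1, B2] -> [B1+B2]" comment in the Ultravox hunk above); its body is untouched by the reformat and therefore not shown. The sketch below is an assumption about those semantics for illustration only, not the vLLM implementation: collapse the leading batch dimension, and with concat=True concatenate list inputs along dim 0. The helper name flatten_bn_sketch is made up.

```
from typing import Union

import torch


def flatten_bn_sketch(
    x: Union[torch.Tensor, list[torch.Tensor]],
    *,
    concat: bool = False,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Assumed sketch: collapse the leading batch dimension of a batched input."""
    if isinstance(x, torch.Tensor):
        # A stacked tensor of shape (B, N, ...) becomes (B * N, ...).
        return x.flatten(0, 1)
    # A list already enumerates the batch; optionally concatenate along dim 0.
    return torch.cat(list(x), dim=0) if concat else list(x)


lens = [torch.tensor([3, 5]), torch.tensor([2])]  # per-request lengths, [B1, B2]
print(flatten_bn_sketch(lens, concat=True))       # tensor([3, 5, 2]), i.e. [B1 + B2]
```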
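The masked_scatter_ call in _merge_multimodal_embeddings above is what splices the multimodal embeddings into the text embeddings: rows of inputs_embeds flagged by is_multimodal are overwritten, in order, by the rows of mm_embeds_flat. A minimal self-contained illustration with toy sizes (the values below are invented for demonstration and are not vLLM data):

```
import torch

hidden_size = 4
# Six token slots; positions 1, 2 and 4 are multimodal placeholders.
inputs_embeds = torch.full((6, hidden_size), -1.0)
is_multimodal = torch.tensor([False, True, True, False, True, False])

# Three multimodal embeddings, already flattened across the whole batch.
mm_embeds_flat = torch.arange(3 * hidden_size, dtype=torch.float32).view(3, hidden_size)

# Rows where the broadcast mask is True are overwritten, in order, by rows of
# mm_embeds_flat; the remaining rows are left untouched. masked_scatter_ does not
# complain if mm_embeds_flat has more rows than there are True positions, which is
# why the surrounding code wraps it in try/except and re-checks the counts.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat)

print(inputs_embeds[0])  # tensor([-1., -1., -1., -1.])  untouched text slot
print(inputs_embeds[1])  # tensor([0., 1., 2., 3.])      first multimodal embedding
```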
class PPMissingLayer(torch.nn.Identity): @@ -544,8 +559,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: uva_available = is_uva_available() if envs.VLLM_USE_V1: - assert uva_available, ("V1 CPU offloading requires" - " uva (pin memory) support") + assert uva_available, "V1 CPU offloading requires uva (pin memory) support" uva_offloading = True else: uva_offloading = False @@ -560,12 +574,14 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: break # `torch.empty_like` does not support `pin_memory` argument - cpu_data = torch.empty_strided(size=p.data.size(), - stride=p.data.stride(), - dtype=p.data.dtype, - layout=p.data.layout, - device='cpu', - pin_memory=pin_memory) + cpu_data = torch.empty_strided( + size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory, + ) cpu_data.copy_(p.data) if not uva_offloading: p.data = cpu_data @@ -587,10 +603,7 @@ def forward(*args, **kwargs): k: v.to(device, non_blocking=True) for k, v in module.state_dict().items() } - output = functional_call(module, - device_state, - args=args, - kwargs=kwargs) + output = functional_call(module, device_state, args=args, kwargs=kwargs) module.forward = forward return output @@ -609,14 +622,18 @@ def make_layers( """ from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.utils import get_pp_indices - start_layer, end_layer = get_pp_indices(num_hidden_layers, - get_pp_group().rank_in_group, - get_pp_group().world_size) + + start_layer, end_layer = get_pp_indices( + num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size + ) modules = torch.nn.ModuleList( - [PPMissingLayer() for _ in range(start_layer)] + [ + [PPMissingLayer() for _ in range(start_layer)] + + [ maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) for idx in range(start_layer, end_layer) - ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) + ] + + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)] + ) return start_layer, end_layer, modules @@ -636,7 +653,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]: # NOTE: the trailing dot is used to match the prefix of the layer. 
# without the dot, we could match a layer that is not missing, # e.g., 'encoder.layer.1' would match 'encoder.layer.11' - missing_layer_names.append(name + '.') + missing_layer_names.append(name + ".") _model_to_pp_missing_layer_names[model_id] = missing_layer_names return missing_layer_names @@ -649,21 +666,22 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: return any( name.startswith(missing_layer_name) - for missing_layer_name in get_pp_missing_layer_names(model)) + for missing_layer_name in get_pp_missing_layer_names(model) + ) def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int): - def make_empty_intermediate_tensors( batch_size: int, dtype: torch.dtype, device: torch.device, ) -> IntermediateTensors: - return IntermediateTensors({ - key: - torch.zeros((batch_size, hidden_size), dtype=dtype, device=device) - for key in keys - }) + return IntermediateTensors( + { + key: torch.zeros((batch_size, hidden_size), dtype=dtype, device=device) + for key in keys + } + ) return make_empty_intermediate_tensors @@ -698,15 +716,20 @@ def extract_layer_index(layer_name: str, num_attn_module: int = 1) -> int: except ValueError: continue if num_attn_module == 1 or "attn" not in layer_name: - assert len(int_vals) == 1, (f"layer name {layer_name} should" - " only contain one integer") + assert len(int_vals) == 1, ( + f"layer name {layer_name} should only contain one integer" + ) return int_vals[0] else: - assert len(int_vals) <= 2, (f"layer name {layer_name} should" - " contain most two integers") - layer_index = int_vals[0] * num_attn_module + int_vals[1] if len( - int_vals) == 2 else int_vals[0] + assert len(int_vals) <= 2, ( + f"layer name {layer_name} should contain most two integers" + ) + layer_index = ( + int_vals[0] * num_attn_module + int_vals[1] + if len(int_vals) == 2 + else int_vals[0] + ) return layer_index @@ -720,19 +743,20 @@ def cast_overflow_tensors( return tensors -def fast_topk(values: torch.Tensor, topk: int, - dim: int) -> tuple[torch.Tensor, torch.Tensor]: +def fast_topk( + values: torch.Tensor, topk: int, dim: int +) -> tuple[torch.Tensor, torch.Tensor]: """ Optimized topk implementation that uses torch.max for k=1 case. - + This function provides better performance for the common case of k=1 by using torch.max instead of the more general torch.topk. - + Args: values: Input tensor to find top-k values from topk: Number of top values to return (k). Must be > 0. 
dim: Dimension along which to compute topk - + Returns: Tuple of (values, indices) where values are the top-k values and indices are their corresponding indices in the input tensor @@ -791,5 +815,5 @@ def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor: op_name="sequence_parallel_chunk_impl", op_func=sequence_parallel_chunk_impl, fake_impl=sequence_parallel_chunk_impl_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(torch.Tag.needs_fixed_stride_order,), ) diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 2636942580fa..74262f8b94a6 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -4,16 +4,17 @@ import itertools import math from abc import ABC, abstractmethod -from typing import (Callable, Final, Generic, Literal, Optional, Protocol, - TypeVar, Union) +from typing import Callable, Final, Generic, Literal, Optional, Protocol, TypeVar, Union import torch from transformers import PretrainedConfig from vllm.attention.backends.registry import _Backend -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) from vllm.logger import init_logger from vllm.platforms import current_platform @@ -27,7 +28,6 @@ class _RootConfig(Protocol[_C]): class VisionEncoderInfo(ABC, Generic[_C]): - def __init__(self, hf_config: _RootConfig[_C]) -> None: super().__init__() @@ -60,8 +60,7 @@ class VisionLanguageConfig(Protocol): vision_config: Final[PretrainedConfig] -def get_vision_encoder_info( - hf_config: VisionLanguageConfig) -> VisionEncoderInfo: +def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInfo: # Avoid circular imports from .clip import CLIPEncoderInfo, CLIPVisionConfig from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig @@ -92,8 +91,10 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: return current_platform.get_vit_attn_backend(head_size, dtype) +VisionFeatureSelectStrategyStr = Literal["class", "default", "full"] + VisionFeatureSelectStrategy = Union[ - Literal["class", "default", "full"], + VisionFeatureSelectStrategyStr, Callable[[torch.Tensor], torch.Tensor], ] @@ -106,7 +107,7 @@ def _get_vision_feature_selector( # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762 if strategy == "class": - return lambda feats: feats[:, 0, :] + return lambda feats: feats[:, :1, :] # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196 if strategy == "default": @@ -162,12 +163,13 @@ def resolve_visual_encoder_outputs( """ if select_layers is None: if not isinstance(encoder_outputs, torch.Tensor): - raise ValueError("Expected only a single encoder output when " - "`select_layers` is not provided") + raise ValueError( + "Expected only a single encoder output when " + "`select_layers` is not provided" + ) if feature_select_strategy is not None: - select_features = _get_vision_feature_selector( - feature_select_strategy) + select_features = _get_vision_feature_selector(feature_select_strategy) encoder_outputs = select_features(encoder_outputs) if post_layer_norm is not None: @@ -176,8 +178,9 @@ def resolve_visual_encoder_outputs( 
return encoder_outputs if max_possible_layers is None: - raise ValueError("`max_possible_layers` must be provided " - "alongside `select_layers`") + raise ValueError( + "`max_possible_layers` must be provided alongside `select_layers`" + ) # Get the hidden states corresponding to the layer indices. # Negative values are relative to the full visual encoder, @@ -189,7 +192,8 @@ def resolve_visual_encoder_outputs( offset = max_possible_layers - num_loaded_layers hs_pool = [ encoder_outputs[layer_idx] - if layer_idx >= 0 else encoder_outputs[layer_idx + offset] + if layer_idx >= 0 + else encoder_outputs[layer_idx + offset] for layer_idx in select_layers ] @@ -205,9 +209,10 @@ def resolve_visual_encoder_outputs( return torch.cat(hs_pool, dim=-1) -def run_dp_sharded_vision_model(image_input: torch.Tensor, - vision_model: torch.nn.Module) -> torch.Tensor: - """Run a vision model with data parallelism (DP) sharding. The function +def run_dp_sharded_vision_model( + image_input: torch.Tensor, vision_model: torch.nn.Module +) -> torch.Tensor: + """Run a vision model with data parallelism (DP) sharding. The function will shard the input image tensor on the first dimension and run the vision model @@ -222,18 +227,17 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor, mp_world_size = get_tensor_model_parallel_world_size() num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks - pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) + pad = (0,) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks) image_input_padded = torch.nn.functional.pad(image_input, pad) rank = get_tensor_model_parallel_rank() - image_input_per_rank = image_input_padded[rank * - num_chunks_per_rank:(rank + 1) * - num_chunks_per_rank, ...] + image_input_per_rank = image_input_padded[ + rank * num_chunks_per_rank : (rank + 1) * num_chunks_per_rank, ... + ] vision_embeddings = vision_model(image_input_per_rank) # Ensure tensor is contiguous before all_gather vision_embeddings = vision_embeddings.contiguous() - vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, - dim=0) + vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, dim=0) vision_embeddings = vision_embeddings[:num_chunks, ...] return vision_embeddings @@ -243,27 +247,27 @@ def get_load_balance_assignment( num_gpus: int = 2, ) -> tuple[list[int], list[int], list[int]]: """ - Generate load balancing assignment and metadata + Generate load balancing assignment and metadata for distributing data across GPUs. The load is determined by the total image sizes, not the number of images. 
- + Args: sizes: The size of each image num_gpus: Number of GPUs to balance across - + Returns: - shuffle_indices: + shuffle_indices: Indices to reorder data for balanced loading - gpu_sample_counts: + gpu_sample_counts: Number of samples assigned to each GPU - grouped_sizes_per_gpu: + grouped_sizes_per_gpu: Total size assigned to each GPU - + Example: ``` sizes = [1000, 100, 200, 50] - num_gpus=2 + num_gpus = 2 ``` """ @@ -281,9 +285,9 @@ def get_load_balance_assignment( # Sort indices by size (largest first for better load balancing) # sizes = [1000, 100, 200, 50] # large_to_small_indices = [0, 2, 1, 3] - large_to_small_indices = sorted(range(n_samples), - key=lambda i: sizes[i], - reverse=True) + large_to_small_indices = sorted( + range(n_samples), key=lambda i: sizes[i], reverse=True + ) for idx in large_to_small_indices: # Find GPU with minimum current load (by total size) @@ -314,11 +318,11 @@ def run_dp_sharded_mrope_vision_model( *, rope_type: Literal["rope_3d", "rope_2d"], ) -> tuple[torch.Tensor, ...]: - """Run a vision model with data parallelism (DP) sharding. - The function will shard the input image tensor on the + """Run a vision model with data parallelism (DP) sharding. + The function will shard the input image tensor on the first dimension and run the vision model. This function is used to run the vision model with mrope. - + Args: vision_model (torch.nn.Module): Vision model. pixel_values (torch.Tensor): Image/Video input tensor. @@ -336,7 +340,7 @@ def run_dp_sharded_mrope_vision_model( vision_model.spatial_merge_size = 2 pixel_values.shape = (1350, channel) grid_thw_list = [[1, 10, 100], [1, 10, 10], [1, 10, 20], [1, 50]] - tp_size=2 + tp_size = 2 ``` """ @@ -355,51 +359,57 @@ def run_dp_sharded_mrope_vision_model( # image_to_tp_rank = [0, 2, 1, 3] # gpu_sample_counts = [1, 3] # grouped_pixel_values_len = [1000, 350] - (image_to_tp_rank, gpu_sample_counts, - grouped_pixel_values_len) = get_load_balance_assignment( - patches_per_image, tp_size) + (image_to_tp_rank, gpu_sample_counts, grouped_pixel_values_len) = ( + get_load_balance_assignment(patches_per_image, tp_size) + ) # cu_gpu_sample_counts = [0, 1, 4] cum_gpu_sample_counts = [0, *itertools.accumulate(gpu_sample_counts)] # GPU_0 image_idxs_local = [0] # GPU_1 image_idxs_local = [2, 1, 3] - image_idxs_local = image_to_tp_rank[cum_gpu_sample_counts[tp_rank_local]: - cum_gpu_sample_counts[tp_rank_local + - 1]] + image_idxs_local = image_to_tp_rank[ + cum_gpu_sample_counts[tp_rank_local] : cum_gpu_sample_counts[tp_rank_local + 1] + ] # Get the pixel values for the local images based on the image_idxs_local if len(image_idxs_local) > 0: - pixel_values_local = torch.cat([ - pixel_values[cum_patches_per_image[i]:cum_patches_per_image[i + 1]] - for i in image_idxs_local - ]) + pixel_values_local = torch.cat( + [ + pixel_values[cum_patches_per_image[i] : cum_patches_per_image[i + 1]] + for i in image_idxs_local + ] + ) else: # Handle case where this rank has no images - pixel_values_local = torch.empty((0, pixel_values.shape[1]), - device=pixel_values.device, - dtype=pixel_values.dtype) + pixel_values_local = torch.empty( + (0, pixel_values.shape[1]), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) # embed_dim_reduction_factor = 2 * 2 if rope_type == "rope_2d": - embed_dim_reduction_factor = (vision_model.merge_kernel_size[0] * - vision_model.merge_kernel_size[1]) + embed_dim_reduction_factor = ( + vision_model.merge_kernel_size[0] * vision_model.merge_kernel_size[1] + ) else: - embed_dim_reduction_factor = 
(vision_model.spatial_merge_size * - vision_model.spatial_merge_size) + embed_dim_reduction_factor = ( + vision_model.spatial_merge_size * vision_model.spatial_merge_size + ) # Find the max length across all ranks # The output embedding of every DP rank has to be # padded to this length for tensor_model_parallel_all_gather # to work - max_len_per_rank = max( - grouped_pixel_values_len) // embed_dim_reduction_factor + max_len_per_rank = max(grouped_pixel_values_len) // embed_dim_reduction_factor local_grid_thw_list = [grid_thw_list[i] for i in image_idxs_local] # Run the vision model on the local pixel_values_local if rope_type == "rope_2d": if pixel_values_local.shape[0] > 0: image_embeds_local = vision_model( - pixel_values_local, torch.tensor(local_grid_thw_list)) + pixel_values_local, torch.tensor(local_grid_thw_list) + ) if isinstance(image_embeds_local, list): image_embeds_local = torch.cat(image_embeds_local, dim=0) else: @@ -407,16 +417,18 @@ def run_dp_sharded_mrope_vision_model( image_embeds_local = torch.empty( (0, embed_dim_reduction_factor, out_dim), device=pixel_values.device, - dtype=pixel_values.dtype) + dtype=pixel_values.dtype, + ) else: if pixel_values_local.shape[0] > 0: - image_embeds_local = vision_model(pixel_values_local, - local_grid_thw_list) + image_embeds_local = vision_model(pixel_values_local, local_grid_thw_list) else: # Handle empty case - image_embeds_local = torch.empty((0, vision_model.out_hidden_size), - device=pixel_values.device, - dtype=pixel_values.dtype) + image_embeds_local = torch.empty( + (0, vision_model.out_hidden_size), + device=pixel_values.device, + dtype=pixel_values.dtype, + ) # Pad the output based on max_len_per_rank # for tensor_model_parallel_all_gather to work @@ -424,33 +436,40 @@ def run_dp_sharded_mrope_vision_model( if current_len < max_len_per_rank: padding_size = max_len_per_rank - current_len if rope_type == "rope_2d": - padding = torch.empty((padding_size, image_embeds_local.shape[1], - image_embeds_local.shape[2]), - dtype=image_embeds_local.dtype, - device=image_embeds_local.device) + padding = torch.empty( + ( + padding_size, + image_embeds_local.shape[1], + image_embeds_local.shape[2], + ), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device, + ) else: - padding = torch.empty((padding_size, image_embeds_local.shape[1]), - dtype=image_embeds_local.dtype, - device=image_embeds_local.device) - image_embeds_local_padded = torch.cat([image_embeds_local, padding], - dim=0) + padding = torch.empty( + (padding_size, image_embeds_local.shape[1]), + dtype=image_embeds_local.dtype, + device=image_embeds_local.device, + ) + image_embeds_local_padded = torch.cat([image_embeds_local, padding], dim=0) else: image_embeds_local_padded = image_embeds_local # Do all_gather to collect embeddings from all ranks - gathered_embeds = tensor_model_parallel_all_gather( - image_embeds_local_padded, dim=0) + gathered_embeds = tensor_model_parallel_all_gather(image_embeds_local_padded, dim=0) # Remove padding and reconstruct per-rank embeddings rank_embeddings = list[torch.Tensor]() for rank in range(tp_size): start_idx = rank * max_len_per_rank - end_idx = start_idx + (grouped_pixel_values_len[rank] // - embed_dim_reduction_factor) + end_idx = start_idx + ( + grouped_pixel_values_len[rank] // embed_dim_reduction_factor + ) rank_embeddings.append(gathered_embeds[start_idx:end_idx]) - patches_per_output_image = [(patch_size // embed_dim_reduction_factor) - for patch_size in patches_per_image] + patches_per_output_image = [ + (patch_size 
// embed_dim_reduction_factor) for patch_size in patches_per_image + ] # Reconstruct embeddings in the original order original_order_embeddings = [None] * len(grid_thw_list) @@ -461,7 +480,7 @@ def run_dp_sharded_mrope_vision_model( # Get images assigned to this rank in shuffled order # GPU_0 = image_idxs_local [0] # GPU_1 = image_idxs_local [2, 1, 3] - rank_images = image_to_tp_rank[current_idx:current_idx + count] + rank_images = image_to_tp_rank[current_idx : current_idx + count] rank_embed = rank_embeddings[rank] # Split rank embeddings back to individual images @@ -469,11 +488,14 @@ def run_dp_sharded_mrope_vision_model( for img_idx in rank_images: img_patches = patches_per_output_image[img_idx] original_order_embeddings[img_idx] = rank_embed[ - embed_start:embed_start + img_patches] + embed_start : embed_start + img_patches + ] embed_start += img_patches current_idx += count - out_embeddings = tuple(embed for embed in original_order_embeddings - if embed is not None) - assert len(out_embeddings) == len( - original_order_embeddings), "Found unassigned embeddings" + out_embeddings = tuple( + embed for embed in original_order_embeddings if embed is not None + ) + assert len(out_embeddings) == len(original_order_embeddings), ( + "Found unassigned embeddings" + ) return out_embeddings diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index ad494a7a7ec9..0d77b72675e2 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -12,8 +12,12 @@ import torch import torch.nn as nn from mistral_common.audio import mel_filter_bank -from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, - TextChunk, UserMessage) +from mistral_common.protocol.instruct.messages import ( + AudioChunk, + RawAudio, + TextChunk, + UserMessage, +) from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.protocol.transcription.request import TranscriptionRequest from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder @@ -28,23 +32,33 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import SupportsPP from vllm.model_executor.models.module_mapping import MultiModelKeys -# yapf: disable from vllm.model_executor.models.whisper import WhisperEncoder -# yapf: enable from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems, MultiModalUUIDDict, - NestedTensors) -from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, - MultiModalDataParser) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, - MultiModalProcessingInfo, - PromptReplacement, PromptUpdate) +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + MultiModalUUIDDict, + NestedTensors, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalProcessingInfo, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.tokenizer import (MistralTokenizer, - cached_tokenizer_from_config) +from vllm.transformers_utils.tokenizer import ( + MistralTokenizer, + 
cached_tokenizer_from_config, +) from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix @@ -109,7 +123,8 @@ def get_num_audio_tokens( audio_length: int, ) -> int: pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames( - audio_length, self.sampling_rate) + audio_length, self.sampling_rate + ) return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate)) def __call__( @@ -139,7 +154,8 @@ def __call__( "Make sure to process your input via `mistral_common`'s " "tokenizer or pass a chat completion request. " "For more info, see: " - "https://github.com/vllm-project/vllm/issues/8411.") + "https://github.com/vllm-project/vllm/issues/8411." + ) audios_tokens = list[torch.Tensor]() audios_processed = list[torch.Tensor]() @@ -150,23 +166,22 @@ def __call__( # pad if necessary audio = self._audio_processor.pad(audio, self.sampling_rate) - audio_tokens = [ - self.begin_audio_token_id - ] + [self.audio_token_id] * self.get_num_audio_tokens(len(audio)) + audio_tokens = [self.begin_audio_token_id] + [ + self.audio_token_id + ] * self.get_num_audio_tokens(len(audio)) audios_tokens.append(torch.tensor(audio_tokens)) audios_processed.append(torch.tensor(audio)) - return BatchFeature({ - "input_ids": - torch.cat(audios_tokens)[None].expand(len(text), -1), - "audio_arrays": - audios_processed, - }) + return BatchFeature( + { + "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1), + "audio_arrays": audios_processed, + } + ) class VoxtralProcessingInfo(BaseProcessingInfo): - def get_tokenizer(self) -> MistralTokenizer: tokenizer = cached_tokenizer_from_config(self.ctx.model_config) if not isinstance(tokenizer, MistralTokenizer): @@ -193,11 +208,11 @@ def get_max_audio_tokens(self) -> int: def get_max_audio_array_len(self) -> int: processor = self.get_hf_processor() return self.get_max_audio_tokens() * int( - processor.sampling_rate // processor.frame_rate) + processor.sampling_rate // processor.frame_rate + ) class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: return "" @@ -214,10 +229,9 @@ def get_dummy_mm_data( audio_overrides = mm_options.get("audio") if mm_options else None return { - "audio": - self._get_dummy_audios(length=target_length, - num_audios=num_audios, - overrides=audio_overrides) + "audio": self._get_dummy_audios( + length=target_length, num_audios=num_audios, overrides=audio_overrides + ) } def get_dummy_processor_inputs( @@ -243,9 +257,11 @@ def get_dummy_processor_inputs( chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item)) audio_chunks.append(chunk) - request = ChatCompletionRequest(messages=[ - UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]), - ]) + request = ChatCompletionRequest( + messages=[ + UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]), + ] + ) res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens # whixtral tokenizer adds padding to the audio @@ -255,9 +271,7 @@ def get_dummy_processor_inputs( return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data) -class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] - ): - +class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]): def _get_mm_fields_config( self, hf_inputs: Mapping[str, NestedTensors], @@ -315,17 +329,19 @@ def _get_data_parser(self) -> MultiModalDataParser: return 
MultiModalDataParser(target_sr=sampling_rate) -@MULTIMODAL_REGISTRY.register_processor(VoxtralMultiModalProcessor, - info=VoxtralProcessingInfo, - dummy_inputs=VoxtralDummyInputsBuilder) -class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, - SupportsPP, SupportsLoRA, - SupportsTranscription): +@MULTIMODAL_REGISTRY.register_processor( + VoxtralMultiModalProcessor, + info=VoxtralProcessingInfo, + dummy_inputs=VoxtralDummyInputsBuilder, +) +class VoxtralForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription +): supported_languages = ISO639_1_SUPPORTED_LANGS packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["gate_proj", "up_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -336,7 +352,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # match the vLLM model names if hasattr(vllm_config, "quant_config"): vllm_config.quant_config = self.maybe_update_quant_config( - vllm_config.quant_config) + vllm_config.quant_config + ) config = vllm_config.model_config.hf_config self.config = config @@ -378,17 +395,15 @@ def forward( if intermediate_tensors is not None: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) return hidden_states def get_multimodal_embeddings( self, **kwargs - ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...], - None]: + ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...], None]: audio_inputs = self._parse_and_validate_audio_arrays(**kwargs) if audio_inputs is None: return None @@ -399,34 +414,36 @@ def get_multimodal_embeddings( seq_len, dim = audio_embedding.shape # Pad such that seq_len is divisible by downsample_factor target_seq_len = self.downsample_factor * math.ceil( - seq_len / self.downsample_factor) + seq_len / self.downsample_factor + ) audio_embedding = torch.nn.functional.pad( audio_embedding, (0, 0, 0, target_seq_len - seq_len), ) audio_embeddings[i] = audio_embedding.reshape( - target_seq_len // self.downsample_factor, - dim * self.downsample_factor) + target_seq_len // self.downsample_factor, dim * self.downsample_factor + ) # Concat, project and resplit audio_embeddings_packed = torch.cat(audio_embeddings, dim=0) - audio_embeddings_packed = self.audio_language_adapter( - audio_embeddings_packed) - audio_embeddings = torch.split(audio_embeddings_packed, - [a.shape[0] for a in audio_embeddings], - dim=0) + audio_embeddings_packed = self.audio_language_adapter(audio_embeddings_packed) + audio_embeddings = torch.split( + audio_embeddings_packed, [a.shape[0] for a in audio_embeddings], dim=0 + ) return audio_embeddings def _parse_and_validate_audio_arrays( - self, **kwargs: object) -> Union[list[torch.Tensor], None]: + self, **kwargs: object + ) -> Union[list[torch.Tensor], None]: audio_arrays = kwargs.pop("audio_arrays", None) if audio_arrays is None: return None if not isinstance(audio_arrays, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_arrays. " - f"Got type: {type(audio_arrays)}") + raise ValueError( + f"Incorrect type of audio_arrays. 
Got type: {type(audio_arrays)}" + ) audio_arrays = flatten_bn(audio_arrays) if isinstance(audio_arrays, torch.Tensor): @@ -440,8 +457,9 @@ def compute_logits( return self.language_model.compute_logits(hidden_states) @classmethod - def get_speech_to_text_config(cls, model_config: ModelConfig, - task_type: str) -> SpeechToTextConfig: + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: tokenizer = cached_tokenizer_from_config(model_config) audio_config = tokenizer.instruct.audio_encoder.audio_config max_audio_clip_s = audio_config.chunk_length_s @@ -455,19 +473,23 @@ def get_speech_to_text_config(cls, model_config: ModelConfig, @classmethod # for speech-to-text transcription - def get_generation_prompt(cls, audio: np.ndarray, - model_config: ModelConfig, - stt_config: SpeechToTextConfig, - language: Optional[str], - task_type: Literal["transcribe", "translate"], - request_prompt: str, - to_language: Optional[str]) -> PromptType: + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, + stt_config: SpeechToTextConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: tokenizer = cached_tokenizer_from_config(model_config) - audio = Audio(audio, int(stt_config.sample_rate), - format="wav") # lossless - req = TranscriptionRequest(model=model_config.model, - audio=RawAudio.from_audio(audio), - language=language) + audio = Audio(audio, int(stt_config.sample_rate), format="wav") # lossless + req = TranscriptionRequest( + model=model_config.model, + audio=RawAudio.from_audio(audio), + language=language, + ) tokenized = tokenizer.instruct.encode_transcription(req) audio = (tokenized.audios[0].audio_array, stt_config.sample_rate) @@ -476,35 +498,44 @@ def get_generation_prompt(cls, audio: np.ndarray, return cast(PromptType, prompts_dict) @classmethod - def get_num_audio_tokens(cls, audio_duration_s: float, - stt_config: SpeechToTextConfig, - model_config: ModelConfig) -> Optional[int]: + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: """ - Map from audio duration to number of audio tokens produced by the ASR + Map from audio duration to number of audio tokens produced by the ASR model, without running a forward pass. This is used for estimating the amount of processing for this audio. 
""" tokenizer = cached_tokenizer_from_config(model_config) adapter = VoxtralProcessorAdapter(tokenizer) return adapter.get_num_audio_tokens( - int(audio_duration_s * stt_config.sample_rate)) + int(audio_duration_s * stt_config.sample_rate) + ) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - # fmt: off + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: remapping_rules = [ (r"mm_whisper_embeddings\.(.*)", r"\1"), (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"), - (r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"), # noqa: E501 - (r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"), # noqa: E501 + ( + r"audio_language_adapter\.0\.weight", + r"audio_language_adapter.w_in.weight", + ), + ( + r"audio_language_adapter\.2\.weight", + r"audio_language_adapter.w_out.weight", + ), ] - # fmt: on audio_params = dict( - nn.ModuleDict({ - "audio_language_adapter": - self.audio_language_adapter, - }).named_parameters()) + nn.ModuleDict( + { + "audio_language_adapter": self.audio_language_adapter, + } + ).named_parameters() + ) loaded_weights = set() @@ -512,10 +543,12 @@ def llm_weights_generator(): nonlocal loaded_weights for name, w in weights: is_encoder = ( - name.startswith("mm_whisper_embeddings") and - not name.startswith("mm_whisper_embeddings.tok_embeddings") + name.startswith("mm_whisper_embeddings") + and not name.startswith("mm_whisper_embeddings.tok_embeddings") and not name.startswith( - "mm_whisper_embeddings.audio_language_projection")) + "mm_whisper_embeddings.audio_language_projection" + ) + ) for pattern, repl in remapping_rules: if re.fullmatch(pattern, name): @@ -546,7 +579,8 @@ def llm_weights_generator(): return loaded_weights def maybe_update_quant_config( - self, quant_config: QuantizationConfig) -> QuantizationConfig: + self, quant_config: QuantizationConfig + ) -> QuantizationConfig: """ Update quant config to so that ignored module and target module names match the vLLM model names. 
@@ -555,32 +589,54 @@ def maybe_update_quant_config( """ remapping_rules = [ (r"output", r"language_model.lm_head"), - (r"layers\.(\d+)\.attention\.wo", - r"language_model.model.layers.\1.self_attn.out_proj"), - (r"layers\.(\d+)\.attention\.w(.*)", - r"language_model.model.layers.\1.self_attn.\2_proj"), - (r"layers\.(\d+)\.feed_forward\.w1", - r"language_model.model.layers.\1.mlp.gate_proj"), - (r"layers\.(\d+)\.feed_forward\.w2", - r"language_model.model.layers.\1.mlp.down_proj"), - (r"layers\.(\d+)\.feed_forward\.w3", - r"language_model.model.layers.\1.mlp.up_proj"), - (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)", - r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj" - ), - (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo", - r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj" - ), - (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)", - r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"), - (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0", - r"whisper_encoder.whisper_encoder.conv1"), - (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1", - r"whisper_encoder.whisper_encoder.conv2"), - (r"mm_whisper_embeddings\.audio_language_projection\.0", - r"audio_language_adapter.w_in"), - (r"mm_whisper_embeddings\.audio_language_projection\.2", - r"audio_language_adapter.w_out"), + ( + r"layers\.(\d+)\.attention\.wo", + r"language_model.model.layers.\1.self_attn.out_proj", + ), + ( + r"layers\.(\d+)\.attention\.w(.*)", + r"language_model.model.layers.\1.self_attn.\2_proj", + ), + ( + r"layers\.(\d+)\.feed_forward\.w1", + r"language_model.model.layers.\1.mlp.gate_proj", + ), + ( + r"layers\.(\d+)\.feed_forward\.w2", + r"language_model.model.layers.\1.mlp.down_proj", + ), + ( + r"layers\.(\d+)\.feed_forward\.w3", + r"language_model.model.layers.\1.mlp.up_proj", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo", + r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)", + r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0", + r"whisper_encoder.whisper_encoder.conv1", + ), + ( + r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1", + r"whisper_encoder.whisper_encoder.conv2", + ), + ( + r"mm_whisper_embeddings\.audio_language_projection\.0", + r"audio_language_adapter.w_in", + ), + ( + r"mm_whisper_embeddings\.audio_language_projection\.2", + r"audio_language_adapter.w_out", + ), ] # Update ignore list @@ -613,7 +669,6 @@ def maybe_update_quant_config( class AudioLanguageAdapter(nn.Module): - def __init__(self, hidden_size: int, dim: int) -> None: super().__init__() self.w_in = nn.Linear(hidden_size, dim, bias=False) @@ -627,19 +682,44 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VoxtralEncoderModel(nn.Module): packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} - # fmt: off mistral_remapping = [ - (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"), # noqa: E501 - (r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"), # noqa: 
E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"), # noqa: E501 - (r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"), # noqa: E501 + ( + r"whisper_encoder\.conv_layers\.0\.(weight|bias)", + r"whisper_encoder.conv1.\1", + ), + ( + r"whisper_encoder\.conv_layers\.1\.(weight|bias)", + r"whisper_encoder.conv2.\1", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.self_attn.\2_proj.\3", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.self_attn.out_proj.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.self_attn_layer_norm.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.mlp.fc1.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501 + r"whisper_encoder.layers.\1.mlp.fc2.\2", + ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", + r"whisper_encoder.layers.\1.final_layer_norm.\2", + ), + ( + r"whisper_encoder\.transformer\.norm\.(weight|bias)", + r"whisper_encoder.layer_norm.\1", + ), ] - # fmt: on def __init__( self, @@ -650,10 +730,11 @@ def __init__( super().__init__() self.config = cast(WhisperConfig, vllm_config.model_config.hf_config) self.dtype: torch.dtype = vllm_config.model_config.dtype - self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "whisper_encoder"), - init_in_fp32=True) + self.whisper_encoder = WhisperEncoder( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "whisper_encoder"), + init_in_fp32=True, + ) mel_filters = mel_filter_bank( num_frequency_bins=1 + self.config.window_size // 2, num_mel_bins=self.config.num_mel_bins, @@ -668,8 +749,7 @@ def compute_whisper_melspec( audio_waveforms: torch.Tensor, ) -> torch.Tensor: input_dtype = audio_waveforms.dtype - window = torch.hann_window(self.config.window_size).to( - audio_waveforms.device) + window = torch.hann_window(self.config.window_size).to(audio_waveforms.device) stft = torch.stft( audio_waveforms, self.config.window_size, @@ -677,7 +757,7 @@ def compute_whisper_melspec( window=window, return_complex=True, ) - magnitudes = stft[..., :-1].abs()**2 + magnitudes = stft[..., :-1].abs() ** 2 mel_spec = self.mel_filters.T @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) @@ -686,8 +766,9 @@ def compute_whisper_melspec( @property def 
downsample_factor(self) -> int: - return self.whisper_encoder.conv1.stride[ - 0] * self.whisper_encoder.conv2.stride[0] + return ( + self.whisper_encoder.conv1.stride[0] * self.whisper_encoder.conv2.stride[0] + ) @property def chunk_size(self) -> int: @@ -721,8 +802,7 @@ def forward( input_features = [input_features] # Split long inputs into chunks - input_embeds, chunks_per_example = ( - self.prepare_inputs_for_conv(input_features)) + input_embeds, chunks_per_example = self.prepare_inputs_for_conv(input_features) # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size] out = self.whisper_encoder([input_embeds]) @@ -731,7 +811,7 @@ def forward( chunk_idx = 0 results = [] for n_chunks in chunks_per_example: - result = out[chunk_idx:chunk_idx + n_chunks].flatten(0, 1) + result = out[chunk_idx : chunk_idx + n_chunks].flatten(0, 1) results.append(result) chunk_idx += n_chunks @@ -751,7 +831,7 @@ def load_weight(self, weight: tuple[str, torch.Tensor]) -> str: if re.fullmatch(pattern, name): name = re.sub(pattern, repl, name) - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -762,8 +842,7 @@ def load_weight(self, weight: tuple[str, torch.Tensor]) -> str: break else: param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) return name diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 84686b8b1941..ce9634935d24 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -9,43 +9,58 @@ import numpy as np import torch from torch import nn -from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, - WhisperProcessor) +from transformers import ( + BatchFeature, + WhisperConfig, + WhisperFeatureExtractor, + WhisperProcessor, +) from transformers.models.whisper.modeling_whisper import sinusoids from vllm.attention import Attention, AttentionType from vllm.attention.layer import MultiHeadAttention from vllm.attention.layers.cross_attention import CrossAttention -from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, - VllmConfig) +from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargsItems) +from vllm.multimodal.inputs 
import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser -from vllm.multimodal.processing import (BaseProcessingInfo, - EncDecMultiModalProcessor, - PromptReplacement, PromptUpdate) +from vllm.multimodal.processing import ( + BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptReplacement, + PromptUpdate, +) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.transformers_utils.processor import cached_get_processor from vllm.utils.tensor_schema import TensorSchema, TensorShape -from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, - SupportsTranscription) -from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, - make_layers, maybe_prefix) +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + cast_overflow_tensors, + make_layers, + maybe_prefix, +) logger = init_logger(__name__) @@ -108,7 +123,7 @@ "uk": "Ukrainian", "ur": "Urdu", "vi": "Vietnamese", - "cy": "Welsh" + "cy": "Welsh", } @@ -120,8 +135,7 @@ class WhisperAudioInputs(TensorSchema): - t: Time frames (M) """ - input_features: Annotated[Optional[NestedTensors], - TensorShape("b", "nmb", "t")] + input_features: Annotated[Optional[NestedTensors], TensorShape("b", "nmb", "t")] class WhisperEncoderAttention(MultiHeadAttention): @@ -153,7 +167,6 @@ def forward( class WhisperPositionalEmbedding(nn.Embedding): - def __init__(self, num_positions: int, embedding_dim: int): super().__init__(num_positions, embedding_dim) @@ -162,7 +175,6 @@ def forward(self, position_ids): class WhisperAttention(nn.Module): - def __init__( self, embed_dim: int, @@ -196,7 +208,8 @@ def __init__( if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: " - f"{self.embed_dim} and `num_heads`: {num_heads}).") + f"{self.embed_dim} and `num_heads`: {num_heads})." 
+ ) self.scaling = self.head_dim**-0.5 self._init_qkv(embed_dim, bias, quant_config, prefix=prefix) @@ -269,7 +282,6 @@ def forward( class WhisperCrossAttention(WhisperAttention): - def __init__( self, embed_dim: int, @@ -336,7 +348,6 @@ def forward( class WhisperMLP(nn.Module): - def __init__( self, embed_dim: int, @@ -369,7 +380,6 @@ def forward(self, hidden_states: torch.Tensor): class WhisperEncoderLayer(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -414,7 +424,6 @@ def forward( class WhisperDecoderLayer(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -474,48 +483,39 @@ def forward( class WhisperEncoder(nn.Module): - - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - init_in_fp32: bool = False): + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False + ): super().__init__() config = vllm_config.model_config.hf_config embed_dim = config.d_model self.num_mel_bins = config.num_mel_bins self.max_source_positions = config.max_source_positions - self.embed_scale = (math.sqrt(embed_dim) - if config.scale_embedding else 1.0) - - self.conv1 = nn.Conv1d(self.num_mel_bins, - embed_dim, - kernel_size=3, - padding=1) - self.conv2 = nn.Conv1d(embed_dim, - embed_dim, - kernel_size=3, - stride=2, - padding=1) + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, - lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, - prefix=f"{prefix}.layers"), + lambda prefix: WhisperEncoderLayer( + vllm_config=vllm_config, prefix=f"{prefix}.layers" + ), prefix=f"{prefix}.layers", ) self.layer_norm = nn.LayerNorm(config.d_model) - maybe_fp32_init_ctx = set_default_torch_dtype( - torch.float32) if init_in_fp32 else nullcontext() + maybe_fp32_init_ctx = ( + set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext() + ) with ( - torch.no_grad(), - maybe_fp32_init_ctx, + torch.no_grad(), + maybe_fp32_init_ctx, ): - self.embed_positions = nn.Embedding(self.max_source_positions, - embed_dim) + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) self.embed_positions.weight.copy_( - sinusoids(*self.embed_positions.weight.shape)) + sinusoids(*self.embed_positions.weight.shape) + ) def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]): hidden_states = [] @@ -523,9 +523,9 @@ def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]): embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv2(embeds)) embeds = embeds.transpose(-1, -2) - embeds = (embeds + - self.embed_positions.weight[:embeds.size(-2), :]).to( - embeds.dtype) + embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to( + embeds.dtype + ) hidden_states.append(embeds) hidden_states = torch.cat(hidden_states) @@ -537,7 +537,6 @@ def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]): class WhisperDecoder(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -545,17 +544,19 @@ def __init__(self, *, 
vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = config.pad_token_id self.max_target_positions = config.max_target_positions self.max_source_positions = config.max_source_positions - self.embed_scale = (math.sqrt(config.d_model) - if config.scale_embedding else 1.0) + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, - self.padding_idx) + self.embed_tokens = nn.Embedding( + config.vocab_size, config.d_model, self.padding_idx + ) self.embed_positions = WhisperPositionalEmbedding( - self.max_target_positions, config.d_model) + self.max_target_positions, config.d_model + ) self.start_layer, self.end_layer, self.layers = make_layers( config.decoder_layers, - lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config, - prefix=f"{prefix}.layers"), + lambda prefix: WhisperDecoderLayer( + vllm_config=vllm_config, prefix=f"{prefix}.layers" + ), prefix=f"{prefix}.layers", ) self.layer_norm = nn.LayerNorm(config.d_model) @@ -584,13 +585,14 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: class WhisperModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - self.encoder = WhisperEncoder(vllm_config=vllm_config, - prefix=f"{prefix}.encoder") - self.decoder = WhisperDecoder(vllm_config=vllm_config, - prefix=f"{prefix}.decoder") + self.encoder = WhisperEncoder( + vllm_config=vllm_config, prefix=f"{prefix}.encoder" + ) + self.decoder = WhisperDecoder( + vllm_config=vllm_config, prefix=f"{prefix}.decoder" + ) def forward( self, @@ -614,8 +616,7 @@ def get_encoder_outputs( return None return self.encoder(input_features) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), @@ -645,15 +646,13 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class WhisperProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> WhisperConfig: return self.ctx.get_hf_config(WhisperConfig) @@ -670,8 +669,7 @@ def get_hf_processor(self, **kwargs: object) -> WhisperProcessor: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": 1} - def get_feature_extractor(self, - **kwargs: object) -> WhisperFeatureExtractor: + def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor: hf_processor = self.get_hf_processor(**kwargs) feature_extractor = hf_processor.feature_extractor # type: ignore assert isinstance(feature_extractor, WhisperFeatureExtractor) @@ -682,7 +680,6 @@ def get_num_audio_tokens(self) -> int: class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_audios = mm_counts.get("audio", 0) @@ -703,16 +700,13 @@ def get_dummy_mm_data( audio_overrides = mm_options.get("audio") if mm_options else None return { - "audio": - self._get_dummy_audios(length=audio_len, - num_audios=num_audios, - overrides=audio_overrides) + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, 
overrides=audio_overrides + ) } -class WhisperMultiModalProcessor( - EncDecMultiModalProcessor[WhisperProcessingInfo]): - +class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.info.get_feature_extractor() return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) @@ -779,11 +773,14 @@ def _get_prompt_updates( ] -@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor, - info=WhisperProcessingInfo, - dummy_inputs=WhisperDummyInputsBuilder) -class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, - SupportsMultiModal): +@MULTIMODAL_REGISTRY.register_processor( + WhisperMultiModalProcessor, + info=WhisperProcessingInfo, + dummy_inputs=WhisperDummyInputsBuilder, +) +class WhisperForConditionalGeneration( + nn.Module, SupportsTranscription, SupportsMultiModal +): packed_modules_mapping = { "self_attn.qkv_proj": [ "self_attn.q_proj", @@ -793,10 +790,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"], } - hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ - ".fc1.": ".mlp.fc1.", - ".fc2.": ".mlp.fc2." - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."} + ) # Whisper only supports audio-conditioned generation. supports_transcription_only = True @@ -811,23 +807,26 @@ def validate_language(cls, language: Optional[str]) -> Optional[str]: logger.warning( "Defaulting to language='en'. If you wish to transcribe " "audio in a different language, pass the `language` field " - "in the TranscriptionRequest.") + "in the TranscriptionRequest." + ) language = "en" return super().validate_language(language) @classmethod def get_generation_prompt( - cls, - audio: np.ndarray, - model_config: ModelConfig, # not needed here - stt_config: SpeechToTextConfig, - language: Optional[str], - task_type: Literal["transcribe", "translate"], - request_prompt: str, - to_language: Optional[str]) -> PromptType: + cls, + audio: np.ndarray, + model_config: ModelConfig, # not needed here + stt_config: SpeechToTextConfig, + language: Optional[str], + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: Optional[str], + ) -> PromptType: if language is None: raise ValueError( - "Language must be specified when creating the Whisper prompt") + "Language must be specified when creating the Whisper prompt" + ) prompt = { "encoder_prompt": { # Whisper does not support encoder prompt. 
@@ -836,10 +835,11 @@ def get_generation_prompt( "audio": (audio, stt_config.sample_rate), }, }, - "decoder_prompt": - ((f"<|prev|>{request_prompt}" if request_prompt else "") + - f"<|startoftranscript|><|{language}|>" + - f"<|{task_type}|><|notimestamps|>") + "decoder_prompt": ( + (f"<|prev|>{request_prompt}" if request_prompt else "") + + f"<|startoftranscript|><|{language}|>" + + f"<|{task_type}|><|notimestamps|>" + ), } return cast(PromptType, prompt) @@ -851,8 +851,9 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: raise ValueError("Only audio modality is supported") @classmethod - def get_speech_to_text_config(cls, model_config: ModelConfig, - task_type: str) -> SpeechToTextConfig: + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: processor = cached_get_processor(model_config.model) return SpeechToTextConfig( @@ -861,9 +862,12 @@ def get_speech_to_text_config(cls, model_config: ModelConfig, ) @classmethod - def get_num_audio_tokens(cls, audio_duration_s: float, - stt_config: SpeechToTextConfig, - model_config: ModelConfig) -> Optional[int]: + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> Optional[int]: processor = cached_get_processor(model_config.model) hop_length = processor.feature_extractor.hop_length assert hop_length is not None @@ -871,8 +875,7 @@ def get_num_audio_tokens(cls, audio_duration_s: float, # prompts directly at least not to Whisper. # One indicator of the encoder amount of processing # is the log-mel spectogram length. - return math.ceil(audio_duration_s * stt_config.sample_rate / - hop_length) + return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -883,15 +886,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix) self.unpadded_vocab_size = config.vocab_size - self.proj_out = ParallelLMHead(config.vocab_size, - config.d_model, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "proj_out")) - self.proj_out = self.proj_out.tie_weights( - self.model.decoder.embed_tokens) + self.proj_out = ParallelLMHead( + config.vocab_size, + config.d_model, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "proj_out"), + ) + self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens) logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size, logit_scale + ) def forward( self, @@ -910,8 +915,7 @@ def forward( def get_language_model(self) -> torch.nn.Module: return self.model.decoder - def get_multimodal_embeddings(self, - **kwargs: object) -> MultiModalEmbeddings: + def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings: # Required as part of SupportsMultiModal interface. audio_input = self._parse_and_validate_audio_input(**kwargs) return [self.model.get_encoder_outputs(audio_input["input_features"])] @@ -928,16 +932,16 @@ def get_input_embeddings( # Whisper does not have encoder text tokens. 
return self.model.decoder.get_input_embeddings(input_ids) - def _parse_and_validate_audio_input( - self, **kwargs: object) -> WhisperAudioInputs: + def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs: input_features = kwargs.pop("input_features", None) if input_features is not None: if not isinstance(input_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio features. " - f"Got type: {type(input_features)}") - input_features = torch.cat( - [feat.to(self.dtype) for feat in input_features]) + raise ValueError( + "Incorrect type of audio features. " + f"Got type: {type(input_features)}" + ) + input_features = torch.cat([feat.to(self.dtype) for feat in input_features]) return WhisperAudioInputs(input_features=input_features) @@ -945,8 +949,7 @@ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.proj_out, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) # add fake zeros bias for k_proj to state_dict @@ -955,7 +958,7 @@ def load_weights(self, weights: Iterable[tuple[str, def _create_fake_bias_for_k_proj( - weights: Iterable[tuple[str, torch.Tensor]] + weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, torch.Tensor]]: """ Create full zeros bias for k_proj weight in self-attn and x-attn layers. diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 1d68320bd9b2..b69204d02096 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """PyTorch Zamba2 model implementation for vLLM. -This module implements the Zamba2 architecture from -https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer -architectures in a hybrid model optimized for efficient sequence modeling. The +This module implements the Zamba2 architecture from +https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer +architectures in a hybrid model optimized for efficient sequence modeling. The model alternates between state space model layers and attention-based layers. 
""" + from collections.abc import Iterable from itertools import cycle from typing import Any, Optional, Union @@ -21,19 +22,26 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.model_executor.layers.mamba.mamba_utils import ( - MambaStateDtypeCalculator, MambaStateShapeCalculator) + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors @@ -43,7 +51,7 @@ class Zamba2LoRA(nn.Module): """LoRA layer for the Zamba2 model. - + Implements a LoRA layer that is used in shared attention and gated MLP blocks. """ @@ -57,7 +65,7 @@ def __init__( prefix: str = "", ): """Initialize the attention layer. - + Args: input_dim: input dimension rank: LoRA rank @@ -66,20 +74,15 @@ def __init__( """ super().__init__() - self.A = ColumnParallelLinear(input_dim, - rank, - bias=False, - quant_config=quant_config, - gather_output=True) + self.A = ColumnParallelLinear( + input_dim, rank, bias=False, quant_config=quant_config, gather_output=True + ) if isinstance(output_dim, list): B_class = MergedColumnParallelLinear else: B_class = ColumnParallelLinear - self.B = B_class(rank, - output_dim, - bias=False, - quant_config=quant_config) + self.B = B_class(rank, output_dim, bias=False, quant_config=quant_config) def forward( self, @@ -92,8 +95,8 @@ def forward( class Zamba2Attention(nn.Module): """Multi-head attention mechanism for the Zamba2 model. - - Implements attention with parallel computation, QKV projections, optional + + Implements attention with parallel computation, QKV projections, optional adapters and rotary position embeddings. The attention is computed across distributed blocks for efficient processing. """ @@ -108,7 +111,7 @@ def __init__( prefix: str = "", ) -> None: """Initialize the attention layer. 
- + Args: config: The Zamba2 model configuration bare_block_idx: Index of the bare attention block @@ -129,15 +132,17 @@ def __init__( self.num_attention_heads = config.num_attention_heads // tp_size self.attention_head_dim = config.attention_head_dim self.qkv_size = self.attention_hidden_size // tp_size - self.scale = (self.attention_head_dim / 2)**-0.5 + self.scale = (self.attention_head_dim / 2) ** -0.5 - if (self.attention_head_dim * - self.total_num_attention_heads) != self.attention_hidden_size: + if ( + self.attention_head_dim * self.total_num_attention_heads + ) != self.attention_hidden_size: raise ValueError( f"attention_hidden_size must be divisible by" f" num_attention_heads" f" (got `attention_hidden_size`: {self.attention_hidden_size}" - f" and `num_heads`: {self.num_attention_heads}).") + f" and `num_heads`: {self.num_attention_heads})." + ) self.qkv_proj = QKVParallelLinear( self.attention_hidden_size, @@ -146,10 +151,12 @@ def __init__( bias=False, quant_config=quant_config, ) - self.o_proj = RowParallelLinear(self.attention_hidden_size, - config.hidden_size, - bias=False, - quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.attention_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) # Even though in Zamba2 weights are shared between attention layers, KV # cache is unique for every attention layer. Hence, we need to define @@ -158,8 +165,11 @@ def __init__( # Initialize attention blocks with proper indexing self.dpa_list = nn.ModuleList([]) - j = bare_block_idx * (self.num_hybrid_layers + config.num_mem_blocks - - 1) // config.num_mem_blocks + j = ( + bare_block_idx + * (self.num_hybrid_layers + config.num_mem_blocks - 1) + // config.num_mem_blocks + ) for block_idx in range(self.num_hybrid_layers): if block_idx % config.num_mem_blocks == bare_block_idx: dpa = Attention( @@ -226,18 +236,17 @@ def forward( position_ids: torch.Tensor, ) -> torch.Tensor: """Forward pass through the attention layer. - + Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] position_ids: Position IDs for positional embeddings block_idx: Current shared transformer block index - + Returns: Output tensor [batch_size, seq_len, hidden_size] """ qkv, _ = self.qkv_proj(hidden_states) - query_states, key_states, value_states = qkv.split([self.qkv_size] * 3, - dim=-1) + query_states, key_states, value_states = qkv.split([self.qkv_size] * 3, dim=-1) if self.config.use_shared_attention_adapter: # Apply adapter transformations to Q, K, V if enabled @@ -257,9 +266,9 @@ def forward( value_states = value_states + v_lora_output if self.config.use_mem_rope: - query_states, key_states = self.rotary_emb(position_ids, - query_states, - key_states) + query_states, key_states = self.rotary_emb( + position_ids, query_states, key_states + ) y = self.dpa_list[block_idx](query_states, key_states, value_states) y, _ = self.o_proj(y) @@ -268,9 +277,9 @@ def forward( class Zamba2MLP(nn.Module): """Feed-forward MLP layer for the Zamba2 model. - - Implements a gated feed-forward network that projects inputs to a larger - intermediate size, applies GELU activation with gating, then projects back + + Implements a gated feed-forward network that projects inputs to a larger + intermediate size, applies GELU activation with gating, then projects back to the original size. Includes optional adapter layers for model adaptation. """ @@ -283,7 +292,7 @@ def __init__( prefix: str = "", ) -> None: """Initialize the MLP layer. 
- + Args: config: The Zamba2 model configuration bare_block_idx: Index of the bare block in the model @@ -302,17 +311,22 @@ def __init__( self.hidden_size, 2 * [self.intermediate_size], # 2x for gate and input projections bias=self.config.add_bias_linear, - quant_config=quant_config) + quant_config=quant_config, + ) - self.down_proj = RowParallelLinear(self.intermediate_size, - self.hidden_size, - bias=self.config.add_bias_linear, - quant_config=quant_config) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.config.add_bias_linear, + quant_config=quant_config, + ) # Only allow GELU activations if config.hidden_act != "gelu": - raise ValueError(f"Only GELU activation is supported " - f"(got `hidden_act`: {config.hidden_act})") + raise ValueError( + f"Only GELU activation is supported " + f"(got `hidden_act`: {config.hidden_act})" + ) self.act_fn = GeluAndMul() # Initialize adapter layers @@ -329,14 +343,13 @@ def __init__( gate_up_proj_adapter = nn.Identity() self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) - def forward(self, hidden_states: torch.Tensor, - block_idx: int) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, block_idx: int) -> torch.Tensor: """Forward pass through the MLP layer. - + Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] block_idx: Current shared transformer block index - + Returns: Output tensor [batch_size, seq_len, hidden_size] after applying gated feed-forward transformation @@ -360,7 +373,7 @@ def forward(self, hidden_states: torch.Tensor, class Zamba2AttentionDecoderLayer(nn.Module): """Single decoder layer combining attention and feed-forward networks. - + This layer implements a standard transformer block with: - Input layer normalization - Multi-head self-attention @@ -378,7 +391,7 @@ def __init__( prefix: str = "", ) -> None: """Initialize the decoder layer. - + Args: config: The Zamba2 model configuration bare_block_idx: Index of the bare block @@ -409,11 +422,9 @@ def __init__( # Initialize layer normalizations # Input normalization operates on concatenated states - self.input_layernorm = RMSNorm(2 * config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(2 * config.hidden_size, eps=config.rms_norm_eps) # Pre-FF normalization operates on attention output - self.pre_ff_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -423,14 +434,14 @@ def forward( positions: torch.Tensor, ) -> torch.Tensor: """Forward pass through the decoder layer. - + Args: hidden_states: Input tensor from previous layer - original_hidden_states: Original input tensor for residual + original_hidden_states: Original input tensor for residual connection block_idx: Current shared transformer block index positions: IDs for positional embeddings - + Returns: Transformed hidden states after attention and feed-forward """ @@ -440,7 +451,8 @@ def forward( # The concatenated tensor is then used as input of the pre-attention # RMSNorm (see fig. 2 in https://arxiv.org/pdf/2405.16712). hidden_states = torch.concatenate( - [hidden_states, original_hidden_states], dim=-1) + [hidden_states, original_hidden_states], dim=-1 + ) # Layer norm before attention hidden_states = self.input_layernorm(hidden_states) @@ -463,20 +475,22 @@ def forward( class Zamba2MambaDecoderLayer(nn.Module): """Single Mamba decoder layer with normalization. - - This implements a Mamba block. 
It includes input normalization - and can process sequences using either chunked or full + + This implements a Mamba block. It includes input normalization + and can process sequences using either chunked or full computation depending on configuration. """ - def __init__(self, - config: Zamba2Config, - model_config: Optional[ModelConfig] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: + def __init__( + self, + config: Zamba2Config, + model_config: Optional[ModelConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: """Initialize the Mamba decoder layer. - + Args: config: The Zamba2 model configuration quant_config: Configuration for model quantization @@ -485,26 +499,26 @@ def __init__(self, # Initialize Mamba mixer with expanded intermediate size intermediate_size = config.mamba_expand * config.hidden_size - self.mamba = MambaMixer2(hidden_size=config.hidden_size, - ssm_state_size=config.mamba_d_state, - conv_kernel_size=config.mamba_d_conv, - intermediate_size=intermediate_size, - use_conv_bias=config.use_conv_bias, - use_bias=config.add_bias_linear, - n_groups=config.mamba_ngroups, - num_heads=config.n_mamba_heads, - head_dim=intermediate_size // - config.n_mamba_heads, - rms_norm_eps=config.rms_norm_eps, - activation="silu", - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.mixer") + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=intermediate_size, + use_conv_bias=config.use_conv_bias, + use_bias=config.add_bias_linear, + n_groups=config.mamba_ngroups, + num_heads=config.n_mamba_heads, + head_dim=intermediate_size // config.n_mamba_heads, + rms_norm_eps=config.rms_norm_eps, + activation="silu", + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.mixer", + ) # Input normalization - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -514,14 +528,14 @@ def forward( original_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass through the Mamba decoder layer. - + Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] transformer_hidden_states: Optional output from transformer path Added to input if provided (used in hybrid architecture) positions: Optional position IDs (unused in Mamba) original_hidden_states: Optional original inputs (unused in Mamba) - + Returns: Transformed hidden states with residual connection applied """ @@ -555,7 +569,7 @@ def forward( class Zamba2HybridLayer(nn.Module): """Hybrid layer combining Transformer and Mamba architectures. - + This layer implements the hybrid architecture described in the Zamba paper, where a shared transformer pathway processes input in parallel with a Mamba pathway. The transformer output is projected and added to the Mamba input @@ -573,22 +587,26 @@ def __init__( prefix: str = "", ) -> None: """Initialize the hybrid layer. 
- + Args: shared_transformer: Transformer decoder layer for attention pathway """ super().__init__() self.block_idx = block_idx self.shared_transformer = shared_transformer - self.linear = ReplicatedLinear(config.hidden_size, - config.hidden_size, - bias=False, - quant_config=quant_config) - self.mamba_decoder = Zamba2MambaDecoderLayer(config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix) + self.linear = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) + self.mamba_decoder = Zamba2MambaDecoderLayer( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) def forward( self, @@ -597,19 +615,19 @@ def forward( positions: torch.Tensor, ) -> torch.Tensor: """Forward pass through the hybrid layer. - + Processes input through parallel transformer and Mamba paths: 1. Transformer path processes input with attention 2. Transformer output is projected to match hidden size 3. Projected output is added to Mamba path input 4. Final output combines both paths' representations - + Args: hidden_states: Input tensor [batch_size, seq_len, hidden_size] - original_hidden_states: Original input for transformer residual + original_hidden_states: Original input for transformer residual connection positions: Position IDs for positional embeddings - + Returns: Output tensor combining transformer and Mamba representations """ @@ -636,16 +654,16 @@ def forward( @support_torch_compile class Zamba2Model(nn.Module): """Core Zamba2 model combining transformer and Mamba architectures. - - The model processes input through a sequence of hybrid and Mamba-only + + The model processes input through a sequence of hybrid and Mamba-only layers, using token embeddings and final layer normalization. """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: """Initialize the Zamba2 model. 
- + Args: - vllm_config: Configuration object containing model, cache, + vllm_config: Configuration object containing model, cache, quantization and LoRA settings prefix: Optional prefix for parameter names in state dict """ @@ -660,8 +678,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: assert not is_lora_enabled self.config = config - lora_vocab = ((lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0) + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) self.vocab_size = config.vocab_size + lora_vocab self.org_vocab_size = config.vocab_size @@ -679,15 +700,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: } # Create cyclic iterator of transformer blocks - blocks = cycle([ - Zamba2AttentionDecoderLayer(config, - bare_block_idx=idx, - num_hybrid_layers=len(layer2block_map), - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}") - for idx in range(config.num_mem_blocks) - ]) + blocks = cycle( + [ + Zamba2AttentionDecoderLayer( + config, + bare_block_idx=idx, + num_hybrid_layers=len(layer2block_map), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}", + ) + for idx in range(config.num_mem_blocks) + ] + ) # Initialize layers according to block type configuration layers = [] @@ -699,32 +724,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: block = next(blocks) block_idx = layer2block_map[layer_idx] layers.append( - Zamba2HybridLayer(block, - config, - block_idx, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix)) + Zamba2HybridLayer( + block, + config, + block_idx, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + ) else: layers.append( - Zamba2MambaDecoderLayer(config, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix)) + Zamba2MambaDecoderLayer( + config, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ) + ) self.layers = nn.ModuleList(layers) # Final layer normalization - self.final_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings. - + Args: input_ids: Tensor of input token IDs - + Returns: Embedded representation of the input tokens """ @@ -737,14 +767,14 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: """Forward pass through the model. 
- + Args: input_ids: Input token IDs positions: Position IDs for embeddings inputs_embeds: Optional pre-computed input embeddings - + Returns: - Either final hidden states or intermediate tensors for pipeline + Either final hidden states or intermediate tensors for pipeline parallelism """ # Handle pipeline parallelism for first rank @@ -765,8 +795,7 @@ def forward( hidden_states = self.final_layernorm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -780,8 +809,7 @@ def load_weights(self, weights: Iterable[tuple[str, for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in chkpt_weight_name: continue - chkpt_weight_name = chkpt_weight_name.replace( - weight_name, param_name) + chkpt_weight_name = chkpt_weight_name.replace(weight_name, param_name) param = params_dict[chkpt_weight_name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -790,8 +818,7 @@ def load_weights(self, weights: Iterable[tuple[str, if chkpt_weight_name not in params_dict: continue param = params_dict[chkpt_weight_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(chkpt_weight_name) return loaded_params @@ -799,26 +826,28 @@ def load_weights(self, weights: Iterable[tuple[str, class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): """Zamba2 model with causal language modeling head. - + This class wraps the core Zamba2 model and adds: - A language modeling head for next token prediction - Mamba state caching functionality - Support for model parallelism and quantization - Sampling capabilities for text generation """ + # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={ - "A_log": "A", - "0.weight": "A.weight", - "1.weight": "B.weight", - }) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "A_log": "A", + "0.weight": "A.weight", + "1.weight": "B.weight", + } + ) @classmethod def get_mamba_state_dtype_from_config( cls, vllm_config: "VllmConfig", ) -> tuple[torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.mamba2_state_dtype( vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype, @@ -857,22 +886,19 @@ def get_mamba_state_shape_from_config( def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: """Initialize the Zamba2 model for causal language modeling. 
- + Args: vllm_config: Configuration containing model, cache, quantization, LoRA and scheduler settings prefix: Optional prefix for parameter names - + Raises: - AssertionError: If prefix caching is enabled + AssertionError: If prefix caching is enabled (not supported by Mamba) """ config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config lora_config = vllm_config.lora_config scheduler_config = vllm_config.scheduler_config - assert not cache_config.enable_prefix_caching, \ - "Mamba does not support prefix caching" super().__init__() self.config = config @@ -884,8 +910,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size # Initialize core model - self.model = Zamba2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = Zamba2Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) # Initialize language modeling head self.lm_head = ParallelLMHead( @@ -895,15 +922,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + if not lora_config + else lora_config.lora_vocab_padding_size, prefix=maybe_prefix(prefix, "lm_head"), ) # Tie weights with input embeddings if using same dimensions self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) # Initialize logits processing and sampling - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, config.vocab_size + ) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Convert input token IDs to embeddings. @@ -914,19 +943,21 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """ return self.model.get_input_embeddings(input_ids) - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: Any) -> torch.Tensor: + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> torch.Tensor: """Forward pass through the model. 
- + Args: input_ids: Input token IDs positions: Position IDs for embeddings inputs_embeds: Optional pre-computed input embeddings **kwargs: Additional arguments passed to cache manager - + Returns: Output hidden states """ @@ -954,7 +985,6 @@ def compute_logits( logits = self.logits_processor(self.lm_head, hidden_states) return logits - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 66add98dab44..9341665f1bca 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -9,14 +9,21 @@ import torch from torch.nn import Parameter -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import init_logger __all__ = [ - "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", - "ModelWeightParameter", "ChannelQuantScaleParameter", - "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter" + "BasevLLMParameter", + "PackedvLLMParameter", + "PerTensorScaleParameter", + "ModelWeightParameter", + "ChannelQuantScaleParameter", + "GroupQuantScaleParameter", + "PackedColumnParameter", + "RowvLLMParameter", ] logger = init_logger(__name__) @@ -30,7 +37,6 @@ class BasevLLMParameter(Parameter): """ def __new__(cls, data: Optional[torch.Tensor], **kwargs): - return super().__new__(cls, data=data, requires_grad=False) def __init__(self, data: torch.Tensor, weight_loader: Callable): @@ -52,9 +58,9 @@ def __init__(self, data: torch.Tensor, weight_loader: Callable): # This sometimes causes OOM errors during model loading. To avoid this, # we sync the param tensor after its weight loader is called. from vllm.platforms import current_platform + if current_platform.use_sync_weight_loader(): - weight_loader = current_platform.make_synced_weight_loader( - weight_loader) + weight_loader = current_platform.make_synced_weight_loader(weight_loader) self._weight_loader = weight_loader self.tp_rank = get_tensor_model_parallel_rank() @@ -67,8 +73,9 @@ def weight_loader(self) -> Callable: # weight loading should be implemented via Model.load_weights. 
In the # meantime, support deleting and overriding `weight_loader`` attribute if self._weight_loader is None: - raise AttributeError(f"{self.__class__.__name__} weight_loader " - "attribute has been deleted") + raise AttributeError( + f"{self.__class__.__name__} weight_loader attribute has been deleted" + ) return self._weight_loader @weight_loader.setter @@ -82,11 +89,12 @@ def weight_loader(self): def _is_1d_and_scalar(self, loaded_weight: torch.Tensor): cond1 = self.data.ndim == 1 and self.data.numel() == 1 cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1 - return (cond1 and cond2) + return cond1 and cond2 def _assert_and_load(self, loaded_weight: torch.Tensor): - assert (self.data.shape == loaded_weight.shape - or self._is_1d_and_scalar(loaded_weight)) + assert self.data.shape == loaded_weight.shape or self._is_1d_and_scalar( + loaded_weight + ) self.data.copy_(loaded_weight) def load_column_parallel_weight(self, loaded_weight: torch.Tensor): @@ -121,11 +129,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): class _ColumnvLLMParameter(BasevLLMParameter): """ - Private class defining weight loading functionality + Private class defining weight loading functionality (load_merged_column_weight, load_qkv_weight) for parameters being loaded into linear layers with column parallelism. This includes QKV and MLP layers which are - not already fused on disk. Requires an output dimension + not already fused on disk. Requires an output dimension to be defined. Called within the weight loader of each of the column parallel linear layers. """ @@ -140,57 +148,55 @@ def output_dim(self): def load_column_parallel_weight(self, loaded_weight: torch.Tensor): shard_size = self.data.shape[self.output_dim] - loaded_weight = loaded_weight.narrow(self.output_dim, - self.tp_rank * shard_size, - shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, self.tp_rank * shard_size, shard_size + ) assert self.data.shape == loaded_weight.shape self.data.copy_(loaded_weight) def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): - shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") # TODO: move these to PackedColumnParameter and PackedvLLMParameter - if isinstance( - self, - (PackedColumnParameter, - PackedvLLMParameter)) and self.packed_dim == self.output_dim: + if ( + isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) + and self.packed_dim == self.output_dim + ): shard_size, shard_offset = self.adjust_shard_indexes_for_packing( - shard_offset=shard_offset, shard_size=shard_size) + shard_offset=shard_offset, shard_size=shard_size + ) param_data = self.data - param_data = param_data.narrow(self.output_dim, shard_offset, - shard_size) - loaded_weight = loaded_weight.narrow(self.output_dim, - self.tp_rank * shard_size, - shard_size) + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, self.tp_rank * shard_size, shard_size + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): - shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") shard_id = kwargs.get("shard_id") num_heads = kwargs.get("num_heads") # TODO: move these to PackedColumnParameter and PackedvLLMParameter - if isinstance( - self, - (PackedColumnParameter, - PackedvLLMParameter)) and self.output_dim == self.packed_dim: + if ( + isinstance(self, 
(PackedColumnParameter, PackedvLLMParameter)) + and self.output_dim == self.packed_dim + ): shard_size, shard_offset = self.adjust_shard_indexes_for_packing( - shard_offset=shard_offset, shard_size=shard_size) + shard_offset=shard_offset, shard_size=shard_size + ) param_data = self.data - shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank // - num_heads) - param_data = param_data.narrow(self.output_dim, shard_offset, - shard_size) - loaded_weight = loaded_weight.narrow(self.output_dim, - shard_id * shard_size, shard_size) + shard_id = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, shard_id * shard_size, shard_size + ) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -214,9 +220,9 @@ def input_dim(self): def load_row_parallel_weight(self, loaded_weight: torch.Tensor): shard_size = self.data.shape[self.input_dim] - loaded_weight = loaded_weight.narrow(self.input_dim, - self.tp_rank * shard_size, - shard_size) + loaded_weight = loaded_weight.narrow( + self.input_dim, self.tp_rank * shard_size, shard_size + ) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) @@ -230,6 +236,7 @@ class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): Parameter class for linear layer weights. Uses both column and row parallelism. """ + pass @@ -238,6 +245,7 @@ class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): Parameter class for weight scales loaded for weights with grouped quantization. Uses both column and row parallelism. """ + pass @@ -246,6 +254,7 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter): Parameter class for weight scales loaded for weights with channel-wise quantization. Equivalent to _ColumnvLLMParameter. """ + pass @@ -256,11 +265,11 @@ class PerTensorScaleParameter(BasevLLMParameter): layers (e.g. for QKV, there are 3 scales loaded from disk). This is relevant to weights with per-tensor quantization. Adds functionality to map the scalers to a shard during - weight loading. + weight loading. - Note: additional parameter manipulation may be handled - for each quantization config specifically, within - process_weights_after_loading + Note: additional parameter manipulation may be handled + for each quantization config specifically, within + process_weights_after_loading """ def __init__(self, **kwargs): @@ -280,10 +289,11 @@ def load_qkv_weight(self, *args, **kwargs): def load_column_parallel_weight(self, *args, **kwargs): super().load_row_parallel_weight(*args, **kwargs) - def _load_into_shard_id(self, loaded_weight: torch.Tensor, - shard_id: Union[str, int], **kwargs): + def _load_into_shard_id( + self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs + ): """ - Slice the parameter data based on the shard id for + Slice the parameter data based on the shard id for loading. """ @@ -308,12 +318,14 @@ class PackedColumnParameter(_ColumnvLLMParameter): for more details on the packed properties. 
""" - def __init__(self, - packed_factor: Union[int, Fraction], - packed_dim: int, - marlin_tile_size: Optional[int] = None, - bitblas_tile_size: Optional[int] = None, - **kwargs): + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, + **kwargs, + ): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size @@ -342,7 +354,8 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): shard_offset=shard_offset, packed_factor=self.packed_factor, marlin_tile_size=self.marlin_tile_size, - bitblas_tile_size=self.bitblas_tile_size) + bitblas_tile_size=self.bitblas_tile_size, + ) class PackedvLLMParameter(ModelWeightParameter): @@ -351,17 +364,19 @@ class PackedvLLMParameter(ModelWeightParameter): Example: GPTQ Marlin weights are int4 or int8, packed into int32. Extends the ModelWeightParameter to take in the packed factor, the packed dimension, and optionally, marlin - tile size for marlin kernels. Adjusts the shard_size and + tile size for marlin kernels. Adjusts the shard_size and shard_offset for fused linear layers model weight loading by accounting for packing and optionally, marlin tile size. """ - def __init__(self, - packed_factor: Union[int, Fraction], - packed_dim: int, - marlin_tile_size: Optional[int] = None, - bitblas_tile_size: Optional[int] = None, - **kwargs): + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, + **kwargs, + ): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size @@ -390,7 +405,8 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): shard_offset=shard_offset, packed_factor=self.packed_factor, marlin_tile_size=self.marlin_tile_size, - bitblas_tile_size=self.bitblas_tile_size) + bitblas_tile_size=self.bitblas_tile_size, + ) class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): @@ -410,6 +426,7 @@ class SharedWeightParameter(BasevLLMParameter): `MergedColumnParallelLinear`, the transform weights must stay separate tensors in order to allow for tensor memory sharing between layers. 
""" + # global registry for sharing tensors based on passed `data_key` # this dict holds weaksrefs to avoid memory leak after model cleanup tensors_registry: WeakValueDictionary = WeakValueDictionary() @@ -426,8 +443,7 @@ def __new__(cls, **kwargs): return super().__new__(cls, data=None, **kwargs) def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs): - weight_loader: Callable = kwargs.get( - "weight_loader") # type: ignore[assignment] + weight_loader: Callable = kwargs.get("weight_loader") # type: ignore[assignment] super().__init__(data=None, weight_loader=weight_loader) self.local_tensors = set() @@ -435,12 +451,14 @@ def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs): self.kwargs = { "input_dim": input_dim, "output_dim": output_dim, - "weight_loader": self._fake_weight_loader + "weight_loader": self._fake_weight_loader, } if self.tp_size > 1: - raise NotImplementedError(f"{self.__class__.__name__} does not " - "currently support tensor parallelism") + raise NotImplementedError( + f"{self.__class__.__name__} does not " + "currently support tensor parallelism" + ) def add_partition(self, index: int, data_key: Hashable, *args, **kwargs): """ @@ -460,8 +478,7 @@ def add_partition(self, index: int, data_key: Hashable, *args, **kwargs): data = self.tensors_registry[data_key] # create associated model parameter - self.partitions[index] = ModelWeightParameter( - data=data, **self.kwargs) # type: ignore[arg-type] + self.partitions[index] = ModelWeightParameter(data=data, **self.kwargs) # type: ignore[arg-type] # hold local reference, since ModelWeightParameter does not # see https://github.com/pytorch/pytorch/issues/75932 @@ -471,8 +488,7 @@ def load_column_parallel_weight(self, loaded_weight: torch.Tensor): assert len(self.partitions) == 1 and 0 in self.partitions partition = self.partitions[0] - ModelWeightParameter.load_column_parallel_weight( - partition, loaded_weight) + ModelWeightParameter.load_column_parallel_weight(partition, loaded_weight) def load_row_parallel_weight(self, loaded_weight: torch.Tensor): assert len(self.partitions) == 1 and 0 in self.partitions @@ -490,10 +506,8 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): shard_offset = self.tp_rank * shard_size ModelWeightParameter.load_merged_column_weight( - partition, - loaded_weight, - shard_offset=shard_offset, - shard_size=shard_size) + partition, loaded_weight, shard_offset=shard_offset, shard_size=shard_size + ) def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): partition_id = self._shard_id_as_int(kwargs.pop("shard_id")) @@ -517,33 +531,42 @@ def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): def process_weights_after_loading(self): for key in self.partitions: self.partitions[key] = torch.nn.Parameter( - data=self.partitions[key].data, requires_grad=False) + data=self.partitions[key].data, requires_grad=False + ) @property def data(self): - raise ValueError("Accessing `data` of a " - "`PartitionedModelWeightParameter` is not allowed. " - "Instead, use `get_partition` to get the weight of " - "the particular partition you want to access") + raise ValueError( + "Accessing `data` of a " + "`PartitionedModelWeightParameter` is not allowed. 
" + "Instead, use `get_partition` to get the weight of " + "the particular partition you want to access" + ) - def _fake_weight_loader(self, param: BasevLLMParameter, - loaded_weight: torch.Tensor, - loaded_weight_shard_id: Optional[Union[str, int]]): - raise ValueError("When loading partition weights of " - f"{self.__class__.__name__}, use methods provided by " - f"{self.__class__.__name__}, not partition loader") + def _fake_weight_loader( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_weight_shard_id: Optional[Union[str, int]], + ): + raise ValueError( + "When loading partition weights of " + f"{self.__class__.__name__}, use methods provided by " + f"{self.__class__.__name__}, not partition loader" + ) -def permute_param_layout_(param: BasevLLMParameter, input_dim: int, - output_dim: int, **kwargs) -> BasevLLMParameter: +def permute_param_layout_( + param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs +) -> BasevLLMParameter: """ - Permute a parameter's layout to the specified input and output dimensions, + Permute a parameter's layout to the specified input and output dimensions, useful for forcing the parameter into a known layout, for example, if I need - a packed (quantized) weight matrix to be in the layout + a packed (quantized) weight matrix to be in the layout {input_dim = 0, output_dim = 1, packed_dim = 0} then I can call: permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) - to ensure x is in the correct layout (permuting it to the correct layout if + to ensure x is in the correct layout (permuting it to the correct layout if required, asserting if it cannot get it to the correct layout) """ @@ -551,35 +574,34 @@ def permute_param_layout_(param: BasevLLMParameter, input_dim: int, curr_output_dim = getattr(param, "output_dim", None) if curr_input_dim is None or curr_output_dim is None: - assert param.data.dim() == 2,\ - "permute_param_layout_ only supports 2D parameters when either "\ + assert param.data.dim() == 2, ( + "permute_param_layout_ only supports 2D parameters when either " "input_dim or output_dim is not set" + ) # if one of the dimensions is not set, set it to the opposite of the other # we can only do this since we asserted the parameter is 2D above if curr_input_dim is None: - assert curr_output_dim is not None,\ - "either input or output dim must be set" + assert curr_output_dim is not None, "either input or output dim must be set" curr_input_dim = (curr_output_dim + 1) % 2 if curr_output_dim is None: - assert curr_input_dim is not None,\ - "either input or output dim must be set" + assert curr_input_dim is not None, "either input or output dim must be set" curr_output_dim = (curr_input_dim + 1) % 2 # create permutation from the current layout to the layout with # self.input_dim at input_dim and self.output_dim at output_dim preserving # other dimensions perm = [ - i for i in range(param.data.dim()) - if i not in [curr_input_dim, curr_output_dim] + i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim] ] perm.insert(input_dim, curr_input_dim) perm.insert(output_dim, curr_output_dim) if "packed_dim" in kwargs: - assert hasattr(param, "packed_dim") and\ - param.packed_dim == perm[kwargs["packed_dim"]],\ - "permute_param_layout_ currently doesn't support repacking" + assert ( + hasattr(param, "packed_dim") + and param.packed_dim == perm[kwargs["packed_dim"]] + ), "permute_param_layout_ currently doesn't support repacking" param.data = param.data.permute(*perm) if hasattr(param, 
"_input_dim"): @@ -592,29 +614,30 @@ def permute_param_layout_(param: BasevLLMParameter, input_dim: int, return param -def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, - marlin_tile_size): +def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset, - bitblas_tile_size): +def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset, bitblas_tile_size): return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size -def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, - marlin_tile_size, bitblas_tile_size): +def _adjust_shard_indexes_for_packing( + shard_size, shard_offset, packed_factor, marlin_tile_size, bitblas_tile_size +): shard_size = shard_size // packed_factor shard_offset = shard_offset // packed_factor if marlin_tile_size is not None: return _adjust_shard_indexes_for_marlin( shard_size=shard_size, shard_offset=shard_offset, - marlin_tile_size=marlin_tile_size) + marlin_tile_size=marlin_tile_size, + ) elif bitblas_tile_size is not None: return _adjust_shard_indexes_for_bitblas( shard_size=shard_size, shard_offset=shard_offset, - bitblas_tile_size=bitblas_tile_size) + bitblas_tile_size=bitblas_tile_size, + ) return shard_size, shard_offset diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 543918418953..4abd2625f806 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -30,8 +30,7 @@ def set_weight_attrs( if weight_attrs is None: return for key, value in weight_attrs.items(): - assert not hasattr( - weight, key), f"Overwriting existing tensor attribute: {key}" + assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}" # NOTE(woosuk): During weight loading, we often do something like: # narrowed_tensor = param.data.narrow(0, offset, len) @@ -44,8 +43,7 @@ def set_weight_attrs( # TODO(woosuk): Remove this hack once we have a better solution. from vllm.platforms import current_platform - if current_platform.use_sync_weight_loader( - ) and key == "weight_loader": + if current_platform.use_sync_weight_loader() and key == "weight_loader": value = current_platform.make_synced_weight_loader(value) setattr(weight, key, value) @@ -63,18 +61,19 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: child_map = getattr(child, "packed_modules_mapping", None) child_map = copy.deepcopy(child_map) if child_map is not None else {} - if any((k in parent_map and parent_map[k] != v) - for k, v in child_map.items()): + if any((k in parent_map and parent_map[k] != v) for k, v in child_map.items()): raise ValueError( f"Can't update {type(model).__name__}'s packed_modules_mapping " - f"safely because of conflicts from {type(child).__name__}.") + f"safely because of conflicts from {type(child).__name__}." 
+ ) else: parent_map.update(child_map) return parent_map def get_moe_expert_mapping( - model: torch.nn.Module, ) -> list[tuple[str, str, int, str]]: + model: torch.nn.Module, +) -> list[tuple[str, str, int, str]]: if parent_map := getattr(model, "get_expert_mapping", None): return parent_map() else: diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py index e495f9ee4472..1747caf26cef 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -13,19 +13,22 @@ from vllm.distributed.parallel_state import get_dp_group from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( - compute_aligned_M, deep_gemm_block_shape) + compute_aligned_M, + deep_gemm_block_shape, +) from vllm.model_executor.layers.fused_moe.layer import FusedMoE -from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( - TritonOrDeepGemmExperts) + TritonOrDeepGemmExperts, +) from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.utils.deep_gemm import fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous def _extract_data_from_linear_base_module( - m: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, list[int]]: + m: torch.nn.Module, +) -> tuple[torch.Tensor, torch.Tensor, list[int]]: """ Extract weights, weight scales and quantization block sizes from the given LinearBase module. @@ -46,18 +49,24 @@ def _extract_data_from_linear_base_module( def _extract_data_from_fused_moe_module( - m: torch.nn.Module + m: torch.nn.Module, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]: """ Extract weights, weight scales and num_topk from FusedMoE module. """ assert isinstance(m, FusedMoE) w13 = m.w13_weight - w13_s = m.w13_weight_scale_inv if hasattr( - m, "w13_weight_scale_inv") else m.w13_weight_scale + w13_s = ( + m.w13_weight_scale_inv + if hasattr(m, "w13_weight_scale_inv") + else m.w13_weight_scale + ) w2 = m.w2_weight - w2_s = m.w2_weight_scale_inv if hasattr( - m, "w2_weight_scale_inv") else m.w2_weight_scale + w2_s = ( + m.w2_weight_scale_inv + if hasattr(m, "w2_weight_scale_inv") + else m.w2_weight_scale + ) num_topk = m.top_k assert isinstance(w13, torch.Tensor) @@ -72,14 +81,20 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool: Return True if the input module/layer could be processed with DeepGEMM. 
""" block_size = deep_gemm_block_shape()[0] - if not (isinstance(module, LinearBase) - and isinstance(module.quant_method, Fp8LinearMethod) - and module.quant_method.block_quant): + if not ( + isinstance(module, LinearBase) + and isinstance(module.quant_method, Fp8LinearMethod) + and module.quant_method.block_quant + ): return False w, _, block_sizes = _extract_data_from_linear_base_module(module) - return (block_sizes == deep_gemm_block_shape() and w.ndim == 2 - and w.shape[0] % block_size == 0 and w.shape[1] % block_size == 0) + return ( + block_sizes == deep_gemm_block_shape() + and w.ndim == 2 + and w.shape[0] % block_size == 0 + and w.shape[1] % block_size == 0 + ) def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: @@ -88,27 +103,26 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: moe_quant_config = module.quant_method.get_fused_moe_quant_config(module) - if (moe_quant_config is None - or moe_quant_config.quant_dtype != torch.float8_e4m3fn - or moe_quant_config.block_shape != deep_gemm_block_shape()): + if ( + moe_quant_config is None + or moe_quant_config.quant_dtype != torch.float8_e4m3fn + or moe_quant_config.block_shape != deep_gemm_block_shape() + ): return False - if not isinstance(module.quant_method.fused_experts, - FusedMoEModularKernel): + if not isinstance(module.quant_method.fused_experts, FusedMoEModularKernel): # fused_experts could invoke deep_gemm_moe_fp8 return True mk: FusedMoEModularKernel = module.quant_method.fused_experts # Further check if the ModularKernel implementation uses the DeepGemmExperts - return isinstance(mk.fused_experts, - (DeepGemmExperts, TritonOrDeepGemmExperts)) + return isinstance(mk.fused_experts, (DeepGemmExperts, TritonOrDeepGemmExperts)) FP8_GEMM_NT_WARMUP_CACHE: set[torch.Size] = set() -def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, - max_tokens: int): +def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, max_tokens: int): if w.size() in FP8_GEMM_NT_WARMUP_CACHE: return @@ -116,20 +130,18 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, block_m = deep_gemm_block_shape()[0] device = w.device - a1q = torch.empty((max_tokens, k), - device=device, - dtype=torch.float8_e4m3fn) - a1q_scales = torch.empty((max_tokens, k // block_m), - device=device, - dtype=torch.float32) + a1q = torch.empty((max_tokens, k), device=device, dtype=torch.float8_e4m3fn) + a1q_scales = torch.empty( + (max_tokens, k // block_m), device=device, dtype=torch.float32 + ) out = torch.empty((max_tokens, n), device=device, dtype=torch.bfloat16) - pbar = tqdm(total=max_tokens, - desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") + pbar = tqdm(total=max_tokens, desc=f"DeepGemm(fp8_gemm_nt) warmup (W={w.size()})") num_tokens = max_tokens while num_tokens > 0: - fp8_gemm_nt((a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), - out[:num_tokens]) + fp8_gemm_nt( + (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, ws), out[:num_tokens] + ) pbar.update(1) num_tokens -= 1 @@ -140,14 +152,20 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor, def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( - w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor, - w2_scale: torch.Tensor, num_topk: int, max_tokens: int): - if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE - and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE): + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + num_topk: int, 
+ max_tokens: int, +): + if ( + w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE + and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE + ): return - assert w1.size(0) == w2.size(0), ( - "w1 and w2 must have the same number of experts") + assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" block_m = deep_gemm_block_shape()[0] num_experts = w1.size(0) @@ -159,39 +177,36 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( # This is the maximum GroupedGemm M size that we expect to run # the grouped_gemm with. - MAX_M = compute_aligned_M(max_tokens, - num_topk, - num_experts, - block_m, - expert_tokens_meta=None) + MAX_M = compute_aligned_M( + max_tokens, num_topk, num_experts, block_m, expert_tokens_meta=None + ) # Distribute expert-ids evenly. MAX_BLOCKS = MAX_M // block_m - expert_ids_block = torch.randint(low=0, - high=num_experts, - size=(MAX_BLOCKS, ), - device=device, - dtype=torch.int32) + expert_ids_block = torch.randint( + low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32 + ) expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) def _warmup(w: torch.Tensor, w_scale: torch.Tensor): - _, n, k = w.size() a1q = torch.empty((MAX_M, k), device=device, dtype=torch.float8_e4m3fn) - a1q_scales = torch.empty((MAX_M, k // block_m), - device=device, - dtype=torch.float32) + a1q_scales = torch.empty( + (MAX_M, k // block_m), device=device, dtype=torch.float32 + ) out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) pbar = tqdm( total=MAX_BLOCKS, - desc= - f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})" + desc=f"DeepGemm(m_grouped_fp8_gemm_nt_contiguous) warmup (W={w.size()})", ) num_tokens = MAX_M while num_tokens > 0: m_grouped_fp8_gemm_nt_contiguous( - (a1q[:num_tokens], a1q_scales[:num_tokens]), (w, w_scale), - out[:num_tokens], expert_ids[:num_tokens]) + (a1q[:num_tokens], a1q_scales[:num_tokens]), + (w, w_scale), + out[:num_tokens], + expert_ids[:num_tokens], + ) pbar.update(1) num_tokens = num_tokens - block_m @@ -202,27 +217,27 @@ def _warmup(w: torch.Tensor, w_scale: torch.Tensor): def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int): - dg_modules = [ - m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m) - ] + dg_modules = [m for m in model.modules() if _fp8_linear_may_use_deep_gemm(m)] for dgm in dg_modules: w, ws, _ = _extract_data_from_linear_base_module(dgm) _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens) -def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module, - max_tokens: int): +def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( + model: torch.nn.Module, max_tokens: int +): dg_modules = [ - m for m in model.modules() - if _fused_moe_grouped_gemm_may_use_deep_gemm(m) + m for m in model.modules() if _fused_moe_grouped_gemm_may_use_deep_gemm(m) ] for dgm in dg_modules: - w13, w13_scale, w2, w2_scale, num_topk = ( - _extract_data_from_fused_moe_module(dgm)) + w13, w13_scale, w2, w2_scale, num_topk = _extract_data_from_fused_moe_module( + dgm + ) _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup( - w13, w2, w13_scale, w2_scale, num_topk, max_tokens) + w13, w2, w13_scale, w2_scale, num_topk, max_tokens + ) def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int): diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 3f99340c2906..23227065ee95 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ 
b/vllm/model_executor/warmup/kernel_warmup.py @@ -5,6 +5,7 @@ This is useful specifically for JIT'ed kernels as we don't want JIT'ing to happen during model execution. """ + from typing import TYPE_CHECKING import torch @@ -25,9 +26,11 @@ def kernel_warmup(worker: "Worker"): # Deep GEMM warmup - do_deep_gemm_warmup = (envs.VLLM_USE_DEEP_GEMM - and is_deep_gemm_supported() - and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP) + do_deep_gemm_warmup = ( + envs.VLLM_USE_DEEP_GEMM + and is_deep_gemm_supported() + and not envs.VLLM_SKIP_DEEP_GEMM_WARMUP + ) if do_deep_gemm_warmup: model = worker.get_model() max_tokens = worker.scheduler_config.max_num_batched_tokens @@ -47,8 +50,10 @@ def _is_flashinfer_backend(backend): return False if not worker.model_runner.is_pooling_model and all( - _is_flashinfer_backend(group.backend) - for groups in worker.model_runner.attn_groups for group in groups): + _is_flashinfer_backend(group.backend) + for groups in worker.model_runner.attn_groups + for group in groups + ): logger.info("Warming up FlashInfer attention.") # Warmup with mixed batch containing both prefill and decode tokens # This is to warm up both prefill and decode attention kernels @@ -78,6 +83,8 @@ def flashinfer_autotune(runner: "GPUModelRunner") -> None: # When autotuning with number of tokens m, flashinfer will autotune # operations for all number of tokens up to m. # So we only need to run with the max number of tokens. - runner._dummy_run(runner.scheduler_config.max_num_batched_tokens, - skip_eplb=True, - is_profile=True) + runner._dummy_run( + runner.scheduler_config.max_num_batched_tokens, + skip_eplb=True, + is_profile=True, + ) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 8ea79078465e..b7cbb3bbc67e 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,10 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .hasher import MultiModalHasher -from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins, - MultiModalDataDict, MultiModalKwargs, - MultiModalKwargsItems, MultiModalPlaceholderDict, - MultiModalUUIDDict, NestedTensors) +from .inputs import ( + BatchedTensorInputs, + ModalityData, + MultiModalDataBuiltins, + MultiModalDataDict, + MultiModalKwargs, + MultiModalKwargsItems, + MultiModalPlaceholderDict, + MultiModalUUIDDict, + NestedTensors, +) from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index d7e9d402a1f9..d81354d9a399 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -66,23 +66,25 @@ def resample( orig_sr: float, ) -> npt.NDArray[np.floating]: if self.target_sr is None: - raise RuntimeError("Audio resampling is not supported when " - "`target_sr` is not provided") + raise RuntimeError( + "Audio resampling is not supported when `target_sr` is not provided" + ) if self.method == "librosa": - return resample_audio_librosa(audio, - orig_sr=orig_sr, - target_sr=self.target_sr) + return resample_audio_librosa( + audio, orig_sr=orig_sr, target_sr=self.target_sr + ) elif self.method == "scipy": - return resample_audio_scipy(audio, - orig_sr=orig_sr, - target_sr=self.target_sr) + return resample_audio_scipy( + audio, orig_sr=orig_sr, target_sr=self.target_sr + ) else: - raise ValueError(f"Invalid resampling method: {self.method}. 
" - "Supported methods are 'librosa' and 'scipy'.") + raise ValueError( + f"Invalid resampling method: {self.method}. " + "Supported methods are 'librosa' and 'scipy'." + ) class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): - def __init__(self, **kwargs) -> None: super().__init__() @@ -113,4 +115,4 @@ def encode_base64(self, media: tuple[npt.NDArray, int]) -> str: soundfile.write(buffer, audio, sr, format="WAV") data = buffer.getvalue() - return base64.b64encode(data).decode('utf-8') + return base64.b64encode(data).decode("utf-8") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index faffddd57199..fef118a93c6c 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -9,7 +9,6 @@ class MediaIO(ABC, Generic[_T]): - @abstractmethod def load_bytes(self, data: bytes) -> _T: raise NotImplementedError diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 642ec3fd7e3f..15aa91a04092 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -11,17 +11,24 @@ from typing_extensions import TypeAlias, override from vllm.distributed.device_communicators.shm_object_storage import ( - MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer) + MsgpackSerde, + SingleWriterShmObjectStorage, + SingleWriterShmRingBuffer, +) from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME from vllm.logger import init_logger from vllm.utils import GiB_bytes, LRUCache, MiB_bytes -from vllm.utils.jsontree import (json_count_leaves, json_map_leaves, - json_reduce_leaves) +from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves -from .inputs import (MultiModalBatchedField, MultiModalFeatureSpec, - MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, MultiModalKwargsItems, - NestedTensors) +from .inputs import ( + MultiModalBatchedField, + MultiModalFeatureSpec, + MultiModalFieldElem, + MultiModalKwargs, + MultiModalKwargsItem, + MultiModalKwargsItems, + NestedTensors, +) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -90,7 +97,6 @@ def __init__( class MultiModalCache: - @classmethod def get_leaf_size(cls, leaf: object) -> int: if isinstance(leaf, MultiModalProcessorCacheItem): @@ -99,8 +105,15 @@ def get_leaf_size(cls, leaf: object) -> int: return leaf.item_size # These are not subclasses of dict - if isinstance(leaf, (MultiModalKwargs, MultiModalKwargsItems, - MultiModalKwargsItem, MultiModalFieldElem)): + if isinstance( + leaf, + ( + MultiModalKwargs, + MultiModalKwargsItems, + MultiModalKwargsItem, + MultiModalFieldElem, + ), + ): return cls.get_item_size(leaf.data) # type: ignore # sys.getsizeof doesn't work for tensors @@ -116,8 +129,9 @@ def get_item_size( *, debug: bool = False, ) -> int: - size = json_reduce_leaves(operator.add, - json_map_leaves(cls.get_leaf_size, value)) + size = json_reduce_leaves( + operator.add, json_map_leaves(cls.get_leaf_size, value) + ) if debug: leaf_count = json_count_leaves(value) @@ -241,17 +255,19 @@ def clear_cache(self) -> None: raise NotImplementedError -MultiModalProcessorCacheInItem: TypeAlias = \ - Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]] +MultiModalProcessorCacheInItem: TypeAlias = Optional[ + tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] +] -MultiModalProcessorCacheOutItem: TypeAlias = \ - tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]] +MultiModalProcessorCacheOutItem: TypeAlias = tuple[ + Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"] +] class 
BaseMultiModalProcessorCache( - BaseMultiModalCache[MultiModalProcessorCacheInItem, - MultiModalProcessorCacheOutItem]): + BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem] +): """The required interface for caches on P0.""" @abstractmethod @@ -405,15 +421,13 @@ def __init__(self, vllm_config: "VllmConfig") -> None: create=True, # sender is the writer ) self._shm_cache = SingleWriterShmObjectStorage( - max_object_size=mm_config.mm_shm_cache_max_object_size_mb * - MiB_bytes, + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes, n_readers=self.world_size, ring_buffer=ring_buffer, serde_class=MsgpackSerde, ) # cache (prompt_updates, modality) for P0 only - self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], - str]] = {} + self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {} @override def is_cached_item(self, mm_hash: str) -> bool: @@ -425,12 +439,10 @@ def get_and_update_item( mm_item: MultiModalProcessorCacheInItem, mm_hash: str, ) -> MultiModalProcessorCacheOutItem: - if self._shm_cache.is_cached(mm_hash): address, monotonic_id = self._shm_cache.get_cached(mm_hash) prompt_updates, modality = self._p0_cache[mm_hash] - return self.address_as_item(address, monotonic_id, - modality), prompt_updates + return self.address_as_item(address, monotonic_id, modality), prompt_updates assert mm_item is not None, f"Expected a cached item for {mm_hash=}" @@ -440,15 +452,15 @@ def get_and_update_item( if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index): self.remove_dangling_items() self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality - address_item = self.address_as_item(address, monotonic_id, - mm_item[0].modality) + address_item = self.address_as_item( + address, monotonic_id, mm_item[0].modality + ) return address_item, mm_item[1] except (ValueError, MemoryError) as e: # put may fail if the object is too large or # the cache is full. # In this case we log the error and keep the original mm_input. 
- logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, - e) + logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e) return mm_item @override @@ -463,8 +475,9 @@ def remove_dangling_items(self) -> None: for mm_hash in dangling_hashes: del self._p0_cache[mm_hash] - def address_as_item(self, address: int, monotonic_id: int, - modality: str) -> MultiModalKwargsItem: + def address_as_item( + self, address: int, monotonic_id: int, modality: str + ) -> MultiModalKwargsItem: addr_elem = MultiModalFieldElem( modality=modality, key="address", @@ -494,9 +507,10 @@ def _enable_processor_cache( def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: parallel_config = vllm_config.parallel_config - supports_ipc_cache = ((parallel_config._api_process_count == 1 - and parallel_config.data_parallel_size == 1) - or parallel_config.data_parallel_external_lb) + supports_ipc_cache = ( + parallel_config._api_process_count == 1 + and parallel_config.data_parallel_size == 1 + ) or parallel_config.data_parallel_external_lb return supports_ipc_cache @@ -542,8 +556,8 @@ def processor_only_cache_from_config( class BaseMultiModalReceiverCache( - BaseMultiModalCache[Optional[MultiModalKwargsItem], - MultiModalKwargsItem]): + BaseMultiModalCache[Optional[MultiModalKwargsItem], MultiModalKwargsItem] +): """The required interface for caches on P1.""" def get_and_update_features( @@ -552,8 +566,7 @@ def get_and_update_features( ) -> list["MultiModalFeatureSpec"]: """Update multimodal features with cached encoder outputs.""" for feature in mm_features: - feature.data = self.get_and_update_item(feature.data, - feature.identifier) + feature.data = self.get_and_update_item(feature.data, feature.identifier) return mm_features @@ -623,8 +636,7 @@ def __init__( create=False, # Server is a reader ) self._shm_cache = SingleWriterShmObjectStorage( - max_object_size=mm_config.mm_shm_cache_max_object_size_mb * - MiB_bytes, + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes, n_readers=self.world_size, ring_buffer=ring_buffer, serde_class=MsgpackSerde, diff --git a/vllm/multimodal/evs.py b/vllm/multimodal/evs.py index 056f3d905968..3a706700da6a 100644 --- a/vllm/multimodal/evs.py +++ b/vllm/multimodal/evs.py @@ -13,8 +13,9 @@ import torch -def compute_retained_tokens_count(video_size_thw: torch.LongTensor, - spatial_merge_size: int, q: float) -> int: +def compute_retained_tokens_count( + video_size_thw: torch.LongTensor, spatial_merge_size: int, q: float +) -> int: """ Compute the number of retained tokens for a given video. 
Method ensures that we retain all the tokens from the first frame @@ -66,23 +67,21 @@ def compute_retention_mask( ) # Core EVS - similarity = torch.nn.functional.cosine_similarity(video_embeds[1:, ...], - video_embeds[:-1, ...], - dim=-1) + similarity = torch.nn.functional.cosine_similarity( + video_embeds[1:, ...], video_embeds[:-1, ...], dim=-1 + ) dissimilarity = 1 - similarity # Always ensure we include all tokens from the first frame dissimilarity = torch.cat( - [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], - dim=0) + [255 * torch.ones_like(video_embeds[:1, :, :, 0]), dissimilarity], dim=0 + ) dissimilarity_flat = dissimilarity.view(-1) - order = torch.argsort(dissimilarity_flat, - dim=-1, - descending=True, - stable=True) - retain_num_tokens = compute_retained_tokens_count(video_size_thw, - spatial_merge_size, q) + order = torch.argsort(dissimilarity_flat, dim=-1, descending=True, stable=True) + retain_num_tokens = compute_retained_tokens_count( + video_size_thw, spatial_merge_size, q + ) topk_indices = order[:retain_num_tokens] retention_mask = torch.zeros_like(dissimilarity_flat, dtype=torch.bool) @@ -119,18 +118,34 @@ def compute_mrope_for_media( llm_grid_h = video_size_thw[1] // spatial_merge_size llm_grid_w = video_size_thw[2] // spatial_merge_size - t_index = ((torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).mul( - tokens_per_second * video_second_per_grid)).long().flatten()) - h_index = (torch.arange(llm_grid_h).view(1, -1, - 1).expand(llm_grid_t, -1, - llm_grid_w).flatten()) - w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand( - llm_grid_t, llm_grid_h, -1).flatten()) - llm_grid_w = (torch.tensor([llm_grid_w - ]).view(1, 1, - 1).expand(llm_grid_t, llm_grid_h, - llm_grid_w).flatten()) + t_index = ( + ( + torch.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .mul(tokens_per_second * video_second_per_grid) + ) + .long() + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_grid_w = ( + torch.tensor([llm_grid_w]) + .view(1, 1, 1) + .expand(llm_grid_t, llm_grid_h, llm_grid_w) + .flatten() + ) positions = torch.stack([t_index, h_index, w_index, llm_grid_w], dim=1) return positions @@ -183,7 +198,8 @@ def recompute_mrope_positions( # Tensors positions: torch.LongTensor = typing.cast( - torch.LongTensor, mrope_positions.clone()) # (3, N) + torch.LongTensor, mrope_positions.clone() + ) # (3, N) N = input_ids.numel() image_mask = input_ids.eq(image_token_id) @@ -193,8 +209,7 @@ def recompute_mrope_positions( # Early exit: no media in this chunk if len(multimodal_positions) == 0: - delta = (int((positions.max().item() + 1) - - N) if positions.numel() else -N) + delta = int((positions.max().item() + 1) - N) if positions.numel() else -N return positions, delta total_mm_tokens = torch.count_nonzero(media_mask) @@ -203,12 +218,12 @@ def recompute_mrope_positions( # Early exit: we've updated positions for all media tokens # (and consequently - for all remaining text tokens) if seen_mm_tokens == total_mm_tokens: - delta = (int((positions.max().item() + 1) - - N) if positions.numel() else -N) + delta = int((positions.max().item() + 1) - N) if positions.numel() else -N return positions, delta - vision_start_indices = (input_ids == vision_start_token_id).nonzero( - as_tuple=True)[0] + vision_start_indices = (input_ids == 
vision_start_token_id).nonzero(as_tuple=True)[ + 0 + ] for mm_pos in multimodal_positions: # Each mm_pos can be a complete embedding for single media @@ -218,8 +233,9 @@ def recompute_mrope_positions( # - Current prefill chunk has no vision start indexes at all # - Vision start token appeared in previous prefill round # - Regular case - seen_vision_start_indices = vision_start_indices[vision_start_indices < - num_computed_tokens] + seen_vision_start_indices = vision_start_indices[ + vision_start_indices < num_computed_tokens + ] if len(seen_vision_start_indices): # If we have encountered some vision start indexes, @@ -228,19 +244,23 @@ def recompute_mrope_positions( # | TTTTTTTTTSVVVVVVVVVV|VVVVVVTTTTTTTTTTTTTTTT| last_vision_start_token = seen_vision_start_indices[-1] seem_mm_tokens_before_last_vision_start = torch.count_nonzero( - media_mask[:last_vision_start_token]) + media_mask[:last_vision_start_token] + ) in_the_middle_of_media = ( - seen_mm_tokens > seem_mm_tokens_before_last_vision_start) + seen_mm_tokens > seem_mm_tokens_before_last_vision_start + ) if in_the_middle_of_media: - mm_embeddings_seen = (seen_mm_tokens - - seem_mm_tokens_before_last_vision_start) + mm_embeddings_seen = ( + seen_mm_tokens - seem_mm_tokens_before_last_vision_start + ) global_mm_start = last_vision_start_token else: # We have completed previous mm_embedding part and # ready to start a new one next_vision_start_token = vision_start_indices[ - vision_start_indices >= num_computed_tokens][0] + vision_start_indices >= num_computed_tokens + ][0] mm_embeddings_seen = 0 global_mm_start = next_vision_start_token @@ -248,7 +268,8 @@ def recompute_mrope_positions( # If there were no vision start indexes so far, # let's find first vision start index next_vision_start_token = vision_start_indices[ - vision_start_indices >= num_computed_tokens][0] + vision_start_indices >= num_computed_tokens + ][0] mm_embeddings_seen = 0 global_mm_start = next_vision_start_token diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index df6c531d876a..91d86cd9a189 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -17,23 +17,23 @@ class MultiModalHasher: - @classmethod def serialize_item(cls, obj: object) -> Iterable[Union[bytes, memoryview]]: # Simple cases if isinstance(obj, (bytes, memoryview)): - return (obj, ) + return (obj,) if isinstance(obj, str): - return (obj.encode("utf-8"), ) + return (obj.encode("utf-8"),) if isinstance(obj, (int, float)): - return (np.array(obj).tobytes(), ) + return (np.array(obj).tobytes(),) if isinstance(obj, Image.Image): exif = obj.getexif() if Image.ExifTags.Base.ImageID in exif and isinstance( - exif[Image.ExifTags.Base.ImageID], uuid.UUID): + exif[Image.ExifTags.Base.ImageID], uuid.UUID + ): # If the image has exif ImageID tag, use that - return (exif[Image.ExifTags.Base.ImageID].bytes, ) + return (exif[Image.ExifTags.Base.ImageID].bytes,) data = {"mode": obj.mode, "data": np.asarray(obj)} if obj.palette is not None: data["palette"] = obj.palette.palette @@ -49,30 +49,35 @@ def serialize_item(cls, obj: object) -> Iterable[Union[bytes, memoryview]]: # Workaround: View the tensor as a contiguous 1D array of bytes if tensor_dtype == torch.bfloat16: tensor_obj = tensor_obj.contiguous() - tensor_obj = tensor_obj.view( - (tensor_obj.numel(), )).view(torch.uint8) + tensor_obj = tensor_obj.view((tensor_obj.numel(),)).view(torch.uint8) return cls.iter_item_to_bytes( - "tensor", { + "tensor", + { "original_dtype": str(tensor_dtype), "original_shape": 
tuple(tensor_shape), "data": tensor_obj.numpy(), - }) + }, + ) return cls.iter_item_to_bytes("tensor", tensor_obj.numpy()) if isinstance(obj, np.ndarray): # If the array is non-contiguous, we need to copy it first - arr_data = obj.view( - np.uint8).data if obj.flags.c_contiguous else obj.tobytes() - return cls.iter_item_to_bytes("ndarray", { - "dtype": obj.dtype.str, - "shape": obj.shape, - "data": arr_data, - }) + arr_data = ( + obj.view(np.uint8).data if obj.flags.c_contiguous else obj.tobytes() + ) + return cls.iter_item_to_bytes( + "ndarray", + { + "dtype": obj.dtype.str, + "shape": obj.shape, + "data": arr_data, + }, + ) logger.warning( - "No serialization method found for %s. " - "Falling back to pickle.", type(obj)) + "No serialization method found for %s. Falling back to pickle.", type(obj) + ) - return (pickle.dumps(obj), ) + return (pickle.dumps(obj),) @classmethod def iter_item_to_bytes( diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 1006c1ce4b24..f50ab1faebba 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -12,9 +12,9 @@ from .base import MediaIO -def rescale_image_size(image: Image.Image, - size_factor: float, - transpose: int = -1) -> Image.Image: +def rescale_image_size( + image: Image.Image, size_factor: float, transpose: int = -1 +) -> Image.Image: """Rescale the dimensions of an image by a constant factor.""" new_width = int(image.width * size_factor) new_height = int(image.height * size_factor) @@ -26,7 +26,7 @@ def rescale_image_size(image: Image.Image, def rgba_to_rgb( image: Image.Image, - background_color: Union[tuple[int, int, int], list[int]] = (255, 255, 255) + background_color: Union[tuple[int, int, int], list[int]] = (255, 255, 255), ) -> Image.Image: """Convert an RGBA image to RGB with filled background color.""" assert image.mode == "RGBA" @@ -45,7 +45,6 @@ def convert_image_mode(image: Image.Image, to_mode: str): class ImageMediaIO(MediaIO[Image.Image]): - def __init__(self, image_mode: str = "RGB", **kwargs) -> None: super().__init__() @@ -59,18 +58,21 @@ def __init__(self, image_mode: str = "RGB", **kwargs) -> None: # Extract RGBA background color from kwargs if provided # Default to white background for backward compatibility - rgba_bg = kwargs.get('rgba_background_color', (255, 255, 255)) + rgba_bg = kwargs.get("rgba_background_color", (255, 255, 255)) # Convert list to tuple for consistency if isinstance(rgba_bg, list): rgba_bg = tuple(rgba_bg) # Validate rgba_background_color format - if not (isinstance(rgba_bg, tuple) and len(rgba_bg) == 3 - and all(isinstance(c, int) and 0 <= c <= 255 - for c in rgba_bg)): + if not ( + isinstance(rgba_bg, tuple) + and len(rgba_bg) == 3 + and all(isinstance(c, int) and 0 <= c <= 255 for c in rgba_bg) + ): raise ValueError( "rgba_background_color must be a list or tuple of 3 integers " - "in the range [0, 255].") + "in the range [0, 255]." 
+ ) self.rgba_background_color = rgba_bg def _convert_image_mode(self, image: Image.Image) -> Image.Image: @@ -108,11 +110,10 @@ def encode_base64( image.save(buffer, image_format) data = buffer.getvalue() - return pybase64.b64encode(data).decode('utf-8') + return pybase64.b64encode(data).decode("utf-8") class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): - def __init__(self) -> None: super().__init__() @@ -127,4 +128,4 @@ def load_file(self, filepath: Path) -> torch.Tensor: return torch.load(filepath, weights_only=True) def encode_base64(self, media: torch.Tensor) -> str: - return pybase64.b64encode(media.numpy()).decode('utf-8') + return pybase64.b64encode(media.numpy()).decode("utf-8") diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 14d0c8dda78e..45e6ac2adaca 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -7,8 +7,7 @@ from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, - cast, final) +from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast, final import numpy as np from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated @@ -35,8 +34,9 @@ item, which can be passed to a HuggingFace `ImageProcessor`. """ -HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor", - list[np.ndarray], list["torch.Tensor"]] +HfVideoItem: TypeAlias = Union[ + list["Image"], np.ndarray, "torch.Tensor", list[np.ndarray], list["torch.Tensor"] +] """ A `transformers.image_utils.VideoInput` representing a single video item, which can be passed to a HuggingFace `VideoProcessor`. @@ -58,8 +58,9 @@ these are directly passed to the model without HF processing. """ -VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor", - tuple[HfVideoItem, dict[str, Any]]] +VideoItem: TypeAlias = Union[ + HfVideoItem, "torch.Tensor", tuple[HfVideoItem, dict[str, Any]] +] """ A `transformers.video_utils.VideoInput` representing a single video item. This can be passed to a HuggingFace `VideoProcessor` @@ -70,8 +71,7 @@ these are directly passed to the model without HF processing. """ -AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], - "torch.Tensor"] +AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float], "torch.Tensor"] """ Represents a single audio item, which can be passed to a HuggingFace `AudioProcessor`. @@ -177,8 +177,12 @@ def __eq__(self, other: object) -> bool: return nested_tensors_equal(self.is_embed, other.is_embed) -NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"], - "torch.Tensor", tuple["torch.Tensor", ...]] +NestedTensors: TypeAlias = Union[ + list["NestedTensors"], + list["torch.Tensor"], + "torch.Tensor", + tuple["torch.Tensor", ...], +] """ Uses a list instead of a tensor if the dimensions of each element do not match. 
""" @@ -193,11 +197,13 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: return isinstance(a, torch.Tensor) and torch.equal(b, a) if isinstance(a, list): - return (isinstance(b, list) - and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))) + return isinstance(b, list) and all( + nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b) + ) if isinstance(b, list): - return (isinstance(a, list) - and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))) + return isinstance(a, list) and all( + nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a) + ) # Both a and b are scalars return a == b @@ -214,7 +220,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: class MultiModalFeatureSpec: """ Represents a single multimodal input with its processed data and metadata. - + Used by the V1 engine to track multimodal data through processing and caching. A request containing multiple multimodal items will have one MultiModalFeatureSpec per item. @@ -280,9 +286,11 @@ def __eq__(self, other: object) -> bool: else: data_equal = nested_tensors_equal(self.data, other.data) - return ((self.modality, self.key) == (other.modality, other.key) - and data_equal - and type(self.field) == type(other.field)) # noqa: E721 + return ( + (self.modality, self.key) == (other.modality, other.key) + and data_equal + and type(self.field) is type(other.field) + ) # noqa: E721 @dataclass(frozen=True) @@ -385,10 +393,12 @@ def _reduce_data( return batch[0].unsqueeze(0).contiguous() first_shape = batch[0].shape if all(elem.shape == first_shape for elem in batch): - out = torch.empty((len(batch), *batch[0].shape), - dtype=batch[0].dtype, - device=batch[0].device, - pin_memory=pin_memory) + out = torch.empty( + (len(batch), *batch[0].shape), + dtype=batch[0].dtype, + device=batch[0].device, + pin_memory=pin_memory, + ) return torch.stack(batch, out=out) return batch @@ -401,6 +411,7 @@ class MultiModalFlatField(BaseMultiModalField): [`MultiModalFieldConfig.flat`][vllm.multimodal.inputs.MultiModalFieldConfig.flat] [`MultiModalFieldConfig.flat_from_sizes`][vllm.multimodal.inputs.MultiModalFieldConfig.flat_from_sizes] """ + slices: Union[Sequence[slice], Sequence[Sequence[slice]]] dim: int = 0 @@ -412,8 +423,9 @@ def build_elems( ) -> Sequence[MultiModalFieldElem]: field_factory = self._field_factory(modality=modality, key=key) if not is_list_of(self.slices, slice, check="all"): - assert isinstance(data, torch.Tensor), \ + assert isinstance(data, torch.Tensor), ( "torch.Tensor is required for multiple slices" + ) return [field_factory(data[cast(slice, s)]) for s in self.slices] def _reduce_data( @@ -433,17 +445,19 @@ def _reduce_data( dim = self.dim + (self.dim < 0) * len(batch[0].shape) def _shape_before_after(tensor: torch.Tensor): - return tensor.shape[:dim], tensor.shape[dim + 1:] + return tensor.shape[:dim], tensor.shape[dim + 1 :] first_shape = _shape_before_after(batch[0]) if all(_shape_before_after(elem) == first_shape for elem in batch): shape_before, shape_after = first_shape shape_concat = sum(item.shape[dim] for item in batch) - out = torch.empty((*shape_before, shape_concat, *shape_after), - dtype=batch[0].dtype, - device=batch[0].device, - pin_memory=pin_memory) + out = torch.empty( + (*shape_before, shape_concat, *shape_after), + dtype=batch[0].dtype, + device=batch[0].device, + pin_memory=pin_memory, + ) return torch.concat(batch, dim=self.dim, out=out) assert self.dim == 0, "dim == 0 is required for nested list" @@ -456,6 +470,7 @@ class 
MultiModalSharedField(BaseMultiModalField): Info: [`MultiModalFieldConfig.shared`][vllm.multimodal.inputs.MultiModalFieldConfig.shared] """ + batch_size: int def build_elems( @@ -477,7 +492,6 @@ def _reduce_data( class MultiModalFieldConfig: - @staticmethod def batched(modality: str): """ @@ -508,9 +522,11 @@ def batched(modality: str): ) @staticmethod - def flat(modality: str, - slices: Union[Sequence[slice], Sequence[Sequence[slice]]], - dim: int = 0): + def flat( + modality: str, + slices: Union[Sequence[slice], Sequence[Sequence[slice]]], + dim: int = 0, + ): """ Defines a field where an element in the batch is obtained by slicing along the first dimension of the underlying data. @@ -561,9 +577,7 @@ def flat(modality: str, ) @staticmethod - def flat_from_sizes(modality: str, - size_per_item: "torch.Tensor", - dim: int = 0): + def flat_from_sizes(modality: str, size_per_item: "torch.Tensor", dim: int = 0): """ Defines a field where an element in the batch is obtained by slicing along the first dimension of the underlying data. @@ -609,13 +623,17 @@ def flat_from_sizes(modality: str, """ if size_per_item.ndim != 1: - raise ValueError("size_per_item should be a 1-D tensor, " - f"but found shape: {size_per_item.shape}") + raise ValueError( + "size_per_item should be a 1-D tensor, " + f"but found shape: {size_per_item.shape}" + ) slice_idxs = [0, *accumulate(size_per_item)] - slices = [(slice(None, None, None), ) * dim + - (slice(slice_idxs[i], slice_idxs[i + 1]), ) - for i in range(len(size_per_item))] + slices = [ + (slice(None, None, None),) * dim + + (slice(slice_idxs[i], slice_idxs[i + 1]),) + for i in range(len(size_per_item)) + ] return MultiModalFieldConfig.flat(modality, slices, dim=dim) @@ -745,7 +763,8 @@ def from_hf_inputs( if len(set(batch_sizes.values())) > 1: raise ValueError( f"Cannot merge different batch sizes for {modality=}! " - f"Found: {batch_sizes=}") + f"Found: {batch_sizes=}" + ) batch_size = next(iter(batch_sizes.values())) for item_idx in range(batch_size): @@ -761,8 +780,10 @@ def from_seq(items: Sequence[MultiModalKwargsItem]): def __getitem__(self, modality: str) -> Sequence[_I]: if modality not in self: - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {set(self.keys())}") + raise KeyError( + f"Modality {modality!r} not found. 
" + f"Available modalities: {set(self.keys())}" + ) return super().__getitem__(modality) # type: ignore[return-value] @@ -770,8 +791,7 @@ def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]": for modality, items in self.items(): for i, item in enumerate(items): if item is None: - raise RuntimeError( - f"Found empty mm_items[{modality}][{i}]") + raise RuntimeError(f"Found empty mm_items[{modality}][{i}]") return self # type: ignore[return-value] @@ -780,17 +800,19 @@ def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": for modality, items in self.items(): for i, item in enumerate(items): if item is None: - raise RuntimeError("Cannot build data from empty " - f"mm_items[{modality}][{i}]") + raise RuntimeError( + f"Cannot build data from empty mm_items[{modality}][{i}]" + ) for key, elem in item.items(): elems_by_key[key].append(elem) - return MultiModalKwargs({ - key: - elems[0].field.reduce_data(elems, pin_memory=pin_memory) - for key, elems in elems_by_key.items() - }) + return MultiModalKwargs( + { + key: elems[0].field.reduce_data(elems, pin_memory=pin_memory) + for key, elems in elems_by_key.items() + } + ) MultiModalKwargsOptionalItems: TypeAlias = Union[ @@ -806,33 +828,36 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): """ @staticmethod - @deprecated("`MultiModalKwargs.from_hf_inputs` is deprecated and " - "will be removed in v0.13. " - "Please use `MultiModalKwargsItems.from_hf_inputs` and " - "access the tensor data using `.get_data()`.") + @deprecated( + "`MultiModalKwargs.from_hf_inputs` is deprecated and " + "will be removed in v0.13. " + "Please use `MultiModalKwargsItems.from_hf_inputs` and " + "access the tensor data using `.get_data()`." + ) def from_hf_inputs( hf_inputs: "BatchFeature", config_by_key: Mapping[str, MultiModalFieldConfig], ): - return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key) \ - .get_data() + return MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key).get_data() @staticmethod - @deprecated("`MultiModalKwargs.from_items` is deprecated and " - "will be removed in v0.13. " - "Please use `MultiModalKwargsItems.from_seq` and " - "access the tensor data using `.get_data()`.") + @deprecated( + "`MultiModalKwargs.from_items` is deprecated and " + "will be removed in v0.13. " + "Please use `MultiModalKwargsItems.from_seq` and " + "access the tensor data using `.get_data()`." + ) def from_items( items: Sequence[MultiModalKwargsItem], *, pin_memory: bool = False, ): - return MultiModalKwargsItems.from_seq(items) \ - .get_data(pin_memory=pin_memory) + return MultiModalKwargsItems.from_seq(items).get_data(pin_memory=pin_memory) @staticmethod - def _try_stack(nested_tensors: NestedTensors, - pin_memory: bool = False) -> NestedTensors: + def _try_stack( + nested_tensors: NestedTensors, pin_memory: bool = False + ) -> NestedTensors: """ Stack the inner dimensions that have the same shape in a nested list of tensors. @@ -849,9 +874,7 @@ def _try_stack(nested_tensors: NestedTensors, if isinstance(nested_tensors, (int, float)): return torch.tensor(nested_tensors) - stacked = [ - MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors - ] + stacked = [MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors] if not is_list_of(stacked, torch.Tensor, check="all"): # Only tensors (not lists) can be stacked. return stacked @@ -867,16 +890,19 @@ def _try_stack(nested_tensors: NestedTensors, # The tensors have incompatible shapes and can't be stacked. 
return tensors_ - outputs = torch.empty(len(tensors_), - *tensors_[0].shape, - dtype=tensors_[0].dtype, - device=tensors_[0].device, - pin_memory=pin_memory) + outputs = torch.empty( + len(tensors_), + *tensors_[0].shape, + dtype=tensors_[0].dtype, + device=tensors_[0].device, + pin_memory=pin_memory, + ) return torch.stack(tensors_, out=outputs) @staticmethod - def batch(inputs_list: list["MultiModalKwargs"], - pin_memory: bool = False) -> BatchedTensorInputs: + def batch( + inputs_list: list["MultiModalKwargs"], pin_memory: bool = False + ) -> BatchedTensorInputs: """ Batch multiple inputs together into a dictionary. @@ -915,8 +941,10 @@ def as_kwargs( def __getitem__(self, key: str): if key not in self: - raise KeyError(f"Keyword argument {key!r} not found. " - f"Available keys: {set(self.keys())}") + raise KeyError( + f"Keyword argument {key!r} not found. " + f"Available keys: {set(self.keys())}" + ) return super().__getitem__(key) diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 493dd3560a51..8fdc5cf721d0 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -4,8 +4,16 @@ from abc import ABC, abstractmethod from collections import UserDict from collections.abc import Callable, Iterator, Mapping, Sequence -from typing import (TYPE_CHECKING, Any, Generic, Literal, NamedTuple, Optional, - TypeVar, Union) +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + NamedTuple, + Optional, + TypeVar, + Union, +) import numpy as np import torch @@ -14,9 +22,18 @@ from vllm.utils import LazyLoader, is_list_of from .audio import AudioResampler -from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem, - ImageItem, ModalityData, MultiModalDataDict, - MultiModalFieldConfig, MultiModalKwargsItems, VideoItem) +from .inputs import ( + AudioItem, + HfAudioItem, + HfImageItem, + HfVideoItem, + ImageItem, + ModalityData, + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) _T = TypeVar("_T") _I = TypeVar("_I") @@ -40,8 +57,7 @@ def __init__(self, data: _T, modality: str) -> None: self.modality = modality def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r}, " - f"len={len(self)})") + return f"{type(self).__name__}(modality={self.modality!r}, len={len(self)})" def __len__(self) -> int: return self.get_count() @@ -51,8 +67,7 @@ def __getitem__(self, index: int) -> _I: if TYPE_CHECKING: # Auto-generated - def __iter__(self) -> Iterator[_I]: - ... + def __iter__(self) -> Iterator[_I]: ... @abstractmethod def get_count(self) -> int: @@ -95,8 +110,9 @@ def get_passthrough_data(self) -> Mapping[str, object]: return {} -class EmbeddingItems(ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], - torch.Tensor]): +class EmbeddingItems( + ModalityDataItems[Union[torch.Tensor, list[torch.Tensor]], torch.Tensor] +): """ Base class for data items that are expressed as a batched embedding tensor, or a list of embedding tensors (one per item). @@ -118,8 +134,9 @@ def get_feature_size(self, item_idx: int) -> int: return len(self.get(item_idx)) -class DictEmbeddingItems(ModalityDataItems[Mapping[str, torch.Tensor], - Mapping[str, torch.Tensor]]): +class DictEmbeddingItems( + ModalityDataItems[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor]] +): """ Base class for data items that are expressed as a dictionary of tensors. 
@@ -143,8 +160,10 @@ def __init__( missing_required_data_keys = required_fields - data.keys() if missing_required_data_keys: data_keys = set(data.keys()) - msg = (f"The data should contain the fields: {required_fields}, " - f"but only found the following keys: {data_keys}") + msg = ( + f"The data should contain the fields: {required_fields}, " + f"but only found the following keys: {data_keys}" + ) raise ValueError(msg) fields_config = fields_factory(data) @@ -176,7 +195,6 @@ def get_passthrough_data(self) -> Mapping[str, object]: class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): - def __init__(self, data: Optional[Sequence[HfAudioItem]]) -> None: if data is None: data = [None] @@ -188,7 +206,6 @@ def get_audio_length(self, item_idx: int) -> int: class AudioEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: super().__init__(data, "audio") @@ -199,7 +216,6 @@ class ImageSize(NamedTuple): class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): - def __init__(self, data: Optional[Sequence[HfImageItem]]) -> None: if data is None: data = [None] @@ -218,18 +234,17 @@ def get_image_size(self, item_idx: int) -> ImageSize: class ImageEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: super().__init__(data, "image") class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): - def __init__( self, data: Optional[Sequence[HfVideoItem]], - metadata: Optional[Union[dict[str, Any], - list[Optional[dict[str, Any]]]]] = None, + metadata: Optional[ + Union[dict[str, Any], list[Optional[dict[str, Any]]]] + ] = None, ) -> None: if data is None: data = [None] @@ -252,7 +267,6 @@ def get_frame_size(self, item_idx: int) -> ImageSize: class VideoEmbeddingItems(EmbeddingItems): - def __init__(self, data: Union[torch.Tensor, list[torch.Tensor]]) -> None: super().__init__(data, "video") @@ -276,8 +290,10 @@ def get_count(self, modality: str, *, strict: bool = True) -> int: if modality not in self: if strict: available_modalities = set(self.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") + raise KeyError( + f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}" + ) return 0 @@ -298,20 +314,25 @@ def get_items( """ if modality not in self: available_modalities = set(self.keys()) - raise KeyError(f"Modality {modality!r} not found. " - f"Available modalities: {available_modalities}") + raise KeyError( + f"Modality {modality!r} not found. " + f"Available modalities: {available_modalities}" + ) items = self[modality] if not isinstance(items, typ): - raise TypeError(f"Invalid type of data items for {modality=}. " - f"Expected type: {typ}, but " - f"found type: {type(items)}") + raise TypeError( + f"Invalid type of data items for {modality=}. 
" + f"Expected type: {typ}, but " + f"found type: {type(items)}" + ) return items # type: ignore[return-value] -ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]], - Optional[ModalityDataItems[Any, Any]]] +ModalityDataParser: TypeAlias = Callable[ + [ModalityData[Any]], Optional[ModalityDataItems[Any, Any]] +] class MultiModalDataParser: @@ -340,7 +361,7 @@ def __init__( self.video_needs_metadata = video_needs_metadata def _is_embeddings( - self, data: object + self, data: object ) -> TypeGuard[Union[torch.Tensor, list[torch.Tensor]]]: if isinstance(data, torch.Tensor): return data.ndim == 3 @@ -395,17 +416,20 @@ def _parse_audio_data( return AudioProcessorItems(None) # also check single audio item with sampling rate - if self._is_empty(data) or (isinstance(data, tuple) - and self._is_empty(data[0])): + if self._is_empty(data) or ( + isinstance(data, tuple) and self._is_empty(data[0]) + ): return None if self._is_embeddings(data): return AudioEmbeddingItems(data) - if (is_list_of(data, float) - or isinstance(data, - (np.ndarray, torch.Tensor)) and data.ndim == 1 - or isinstance(data, tuple)): + if ( + is_list_of(data, float) + or isinstance(data, (np.ndarray, torch.Tensor)) + and data.ndim == 1 + or isinstance(data, tuple) + ): data_items = [data] elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] @@ -418,8 +442,7 @@ def _parse_audio_data( if orig_sr is None: new_audio = audio else: - new_audio = self.audio_resampler.resample(audio, - orig_sr=orig_sr) + new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr) new_audios.append(new_audio) @@ -438,9 +461,11 @@ def _parse_image_data( if self._is_embeddings(data): return ImageEmbeddingItems(data) - if (isinstance(data, PILImage.Image) - or isinstance(data, - (np.ndarray, torch.Tensor)) and data.ndim == 3): + if ( + isinstance(data, PILImage.Image) + or isinstance(data, (np.ndarray, torch.Tensor)) + and data.ndim == 3 + ): data_items = [data] elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] @@ -462,9 +487,11 @@ def _parse_video_data( if self._is_embeddings(data): return VideoEmbeddingItems(data) - if (is_list_of(data, PILImage.Image) - or isinstance(data, - (np.ndarray, torch.Tensor)) and data.ndim == 4): + if ( + is_list_of(data, PILImage.Image) + or isinstance(data, (np.ndarray, torch.Tensor)) + and data.ndim == 4 + ): data_items = [data] elif isinstance(data, (np.ndarray, torch.Tensor)): data_items = [elem for elem in data] @@ -495,8 +522,7 @@ def _get_subparsers(self) -> Mapping[str, ModalityDataParser]: "video": self._parse_video_data, } - def parse_mm_data(self, - mm_data: MultiModalDataDict) -> MultiModalDataItems: + def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems: subparsers = self._get_subparsers() mm_items = MultiModalDataItems() diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index bc998dc2785f..5c3739e29d10 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,13 +3,21 @@ import time from abc import ABC, abstractmethod from collections import defaultdict -from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, - Sequence) +from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache -from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional, - Protocol, Union, cast, overload) +from typing 
import ( + TYPE_CHECKING, + Any, + Generic, + NamedTuple, + Optional, + Protocol, + Union, + cast, + overload, +) import regex as re import torch @@ -17,20 +25,28 @@ from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_processor_from_config -from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, - encode_tokens) -from vllm.utils import (flatten_2d_lists, full_groupby, - get_allowed_kwarg_only_overrides) +from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens +from vllm.utils import flatten_2d_lists, full_groupby, get_allowed_kwarg_only_overrides from vllm.utils.jsontree import JSONTree, json_map_leaves from .hasher import MultiModalHasher -from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalFieldConfig, MultiModalInputs, - MultiModalKwargsItem, MultiModalKwargsItems, - MultiModalKwargsOptionalItems, MultiModalUUIDDict, - PlaceholderRange) -from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, - MultiModalDataParser) +from .inputs import ( + MultiModalDataDict, + MultiModalEncDecInputs, + MultiModalFieldConfig, + MultiModalInputs, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalKwargsOptionalItems, + MultiModalUUIDDict, + PlaceholderRange, +) +from .parse import ( + DictEmbeddingItems, + EmbeddingItems, + MultiModalDataItems, + MultiModalDataParser, +) if TYPE_CHECKING: from transformers.configuration_utils import PretrainedConfig @@ -57,9 +73,7 @@ def _cached_encode( *, add_special_tokens: Optional[bool] = None, ) -> list[int]: - return encode_tokens(tokenizer, - text, - add_special_tokens=add_special_tokens) + return encode_tokens(tokenizer, text, add_special_tokens=add_special_tokens) @lru_cache(maxsize=2048) @@ -69,9 +83,9 @@ def _cached_decode( *, skip_special_tokens: Optional[bool] = None, ) -> str: - return decode_tokens(tokenizer, - list(token_ids), - skip_special_tokens=skip_special_tokens) + return decode_tokens( + tokenizer, list(token_ids), skip_special_tokens=skip_special_tokens + ) def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str: @@ -89,24 +103,22 @@ def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]: class _GetMatchIndex(Protocol): - def __call__( self, tokenizer: AnyTokenizer, prompt: PromptSeq, start_idx: int = 0, - ) -> Optional[int]: - ... + ) -> Optional[int]: ... 
@dataclass class PromptIndex: """Resolves to an index in the prompt.""" + get_match_index: _GetMatchIndex class PromptIndexTargets: - @staticmethod def start() -> PromptIndex: """ @@ -139,9 +151,7 @@ def get_match_index( else: if isinstance(prefix, str): # Make both `list[int]` - prefix = encode_tokens(tokenizer, - prefix, - add_special_tokens=False) + prefix = encode_tokens(tokenizer, prefix, add_special_tokens=False) match_idx = len(prefix) return match_idx if prompt[:match_idx] == prefix else None @@ -181,8 +191,7 @@ class PromptUpdateDetails(Generic[_S]): full: _S """The full content.""" - is_embed: Optional[Callable[[AnyTokenizer, PromptSeq], - torch.Tensor]] = None + is_embed: Optional[Callable[[AnyTokenizer, PromptSeq], torch.Tensor]] = None """ Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], return a boolean mask of shape `(len(full),)` indicating which positions @@ -203,7 +212,6 @@ def select_text( seq: _S, embed_text: str, ) -> "PromptUpdateDetails[_S]": - def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: embed_token_ids = encode_tokens(tokenizer, embed_text) token_ids = _seq2tokens(tokenizer, full) @@ -220,7 +228,6 @@ def select_token_id( seq: _S, embed_token_id: int, ) -> "PromptUpdateDetails[_S]": - def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: token_ids = _seq2tokens(tokenizer, full) @@ -238,8 +245,7 @@ def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: specify which part. """ -PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], - PromptUpdateInfo] +PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo], PromptUpdateInfo] """ Given the index of the processed item within [`modality`][vllm.multimodal.processing.PromptUpdate.modality], @@ -408,11 +414,13 @@ class PromptReplacement(PromptUpdate): modality="image", target="<image>", replacement=PromptUpdateDetails( - full="".join([ - "<image_bos>", - "<image>" * image_feature_size, - "<image_eos>", - ]), + full="".join( + [ + "<image_bos>", + "<image>" * image_feature_size, + "<image_eos>", + ] + ), features="<image>" * image_feature_size, ), ) @@ -426,8 +434,9 @@ class PromptReplacement(PromptUpdate): modality="image", target=[image_token_id], replacement=PromptUpdateDetails( - full=([image_bos_id] + [image_token_id] * image_feature_size - + [image_eos_id]), + full=( + [image_bos_id] + [image_token_id] * image_feature_size + [image_eos_id] + ), features=[image_token_id] * image_feature_size, ), ) @@ -459,10 +468,8 @@ class _HasModalityAttr(Protocol): class _HasModalityProp(Protocol): - @property - def modality(self) -> str: - ... + def modality(self) -> str: ... 
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp]) @@ -520,9 +527,7 @@ def iter_token_matches( target_token_ids = _seq2tokens(tokenizer, target) - for match in iter_token_matches(prompt, - target_token_ids, - start_idx=start_idx): + for match in iter_token_matches(prompt, target_token_ids, start_idx=start_idx): yield PromptTargetMatch(match.start_idx, match.end_idx) def iter_text_matches( @@ -544,8 +549,7 @@ def iter_text_matches( target_text = _seq2text(tokenizer, target) - for match in re.finditer(re.escape(target_text), prompt, - pos=start_idx): + for match in re.finditer(re.escape(target_text), prompt, pos=start_idx): yield PromptTargetMatch(match.start(), match.end()) def iter_matches( @@ -557,9 +561,7 @@ def iter_matches( ) -> Generator[PromptTargetMatch]: """Yield each instance of `self.target` found in `prompt`.""" if isinstance(prompt, str): - return self.iter_text_matches(prompt, - tokenizer, - start_idx=start_idx) + return self.iter_text_matches(prompt, tokenizer, start_idx=start_idx) return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx) @@ -680,9 +682,9 @@ def _find_matches( break # Already found a match for this item for match in update.iter_matches( - prompt, - tokenizer, - start_idx=prev_end_idx, + prompt, + tokenizer, + start_idx=prev_end_idx, ): # All matches should share the same mode if mode is None: @@ -723,8 +725,7 @@ def _apply_matches( out_seqs = list[Union[str, list[int]]]() out_result: MultiModalPromptUpdatesApplyResult = { - m: [None] * len(items) - for m, items in mm_prompt_updates.items() + m: [None] * len(items) for m, items in mm_prompt_updates.items() } start_idx = prev_end_idx = 0 @@ -743,8 +744,7 @@ def _apply_matches( for (modality, item_idx), (match, update_idx) in matches_to_apply: found = True - matched_update = mm_prompt_updates[modality][item_idx][ - update_idx] + matched_update = mm_prompt_updates[modality][item_idx][update_idx] matched_content = matched_update.content.full if mode == UpdateMode.INSERT: @@ -756,9 +756,10 @@ def _apply_matches( out_seqs.append(prompt[prev_end_idx:end_idx_to_insert]) out_seqs.append( - _seq2text(tokenizer, matched_content - ) if isinstance(prompt, str) else _seq2tokens( - tokenizer, matched_content)) + _seq2text(tokenizer, matched_content) + if isinstance(prompt, str) + else _seq2tokens(tokenizer, matched_content) + ) out_result[modality][item_idx] = update_idx # Exclude overlapping matches @@ -784,8 +785,7 @@ def apply_token_matches( the same placeholder tokens. In that case, the modality that appears earlier in `mm_prompt_updates` takes priority. """ - token_id_seqs, result = _apply_matches(prompt, mm_prompt_updates, - tokenizer) + token_id_seqs, result = _apply_matches(prompt, mm_prompt_updates, tokenizer) return flatten_2d_lists(token_id_seqs), result @@ -847,8 +847,7 @@ def _iter_placeholders( if prompt[start_idx:end_idx_full] == content_tokens_full: content_is_embed = content.is_embed if content_is_embed is not None: - content_is_embed = content_is_embed( - tokenizer, content.full) + content_is_embed = content_is_embed(tokenizer, content.full) yield PlaceholderFeaturesInfo( modality=modality, @@ -899,16 +898,14 @@ class InputProcessingContext: """The tokenizer used to tokenize the inputs.""" @overload - def get_hf_config(self, /) -> "PretrainedConfig": - ... + def get_hf_config(self, /) -> "PretrainedConfig": ... @overload def get_hf_config( self, typ: Union[type[_C], tuple[type[_C], ...]], /, - ) -> _C: - ... + ) -> _C: ... 
def get_hf_config( self, @@ -930,9 +927,11 @@ def get_hf_config( hf_config = self.model_config.hf_config if not isinstance(hf_config, typ): - raise TypeError("Invalid type of HuggingFace config. " - f"Expected type: {typ}, but " - f"found type: {type(hf_config)}") + raise TypeError( + "Invalid type of HuggingFace config. " + f"Expected type: {typ}, but " + f"found type: {type(hf_config)}" + ) return hf_config @@ -956,8 +955,7 @@ def get_mm_config(self): return mm_config @overload - def get_hf_processor(self, /, **kwargs: object) -> "ProcessorMixin": - ... + def get_hf_processor(self, /, **kwargs: object) -> "ProcessorMixin": ... @overload def get_hf_processor( @@ -965,8 +963,7 @@ def get_hf_processor( typ: Union[type[_P], tuple[type[_P], ...]], /, **kwargs: object, - ) -> _P: - ... + ) -> _P: ... def get_hf_processor( self, @@ -1017,7 +1014,6 @@ def _postprocess_output( self, output: JSONTree, ) -> JSONTree: - def _postprocess_one(x: object): if isinstance(x, torch.Tensor): # noqa: SIM102 # This mimics the behavior of transformers.BatchFeature @@ -1054,17 +1050,21 @@ def call_hf_processor( ) try: - output = hf_processor(**data, - **allowed_kwargs, - return_tensors="pt") + output = hf_processor(**data, **allowed_kwargs, return_tensors="pt") except Exception as exc: # See https://github.com/huggingface/tokenizers/issues/537 - if (isinstance(exc, RuntimeError) and exc - and exc.args[0] == "Already borrowed" - and num_tries < max_tries): + if ( + isinstance(exc, RuntimeError) + and exc + and exc.args[0] == "Already borrowed" + and num_tries < max_tries + ): logger.warning( "Failed to acquire tokenizer in current thread. " - "Retrying (%d/%d)...", num_tries, max_tries) + "Retrying (%d/%d)...", + num_tries, + max_tries, + ) time.sleep(0.5) return self.call_hf_processor( hf_processor, @@ -1074,8 +1074,10 @@ def call_hf_processor( max_tries=max_tries, ) - msg = (f"Failed to apply {type(hf_processor).__name__} " - f"on data={data} with kwargs={allowed_kwargs}") + msg = ( + f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={allowed_kwargs}" + ) raise ValueError(msg) from exc @@ -1142,8 +1144,11 @@ def get_allowed_mm_limits(self) -> Mapping[str, int]: for modality, supported_limit in supported_mm_limits.items(): user_limit = mm_config.get_limit_per_prompt(modality) - allowed_limits[modality] = (user_limit if supported_limit is None - else min(user_limit, supported_limit)) + allowed_limits[modality] = ( + user_limit + if supported_limit is None + else min(user_limit, supported_limit) + ) return allowed_limits @@ -1154,7 +1159,7 @@ def get_mm_max_tokens_per_item( ) -> Optional[Mapping[str, int]]: """ Return the maximum number of tokens per item of for each modality. - + When `None` (the default) is returned, vLLM will generate dummy inputs (images/videos) at maximum possible sizes and process them to determine the maximum token count per modality. @@ -1165,7 +1170,7 @@ def get_mm_max_tokens_per_item( counts, avoiding the need for dummy input generation and processing. 
Note: - The maximum number of tokens per item of each modality returned + The maximum number of tokens per item of each modality returned from this function should respect the model's maximum sequence length and the maximum number of items of each modality allowed, and agree with dummy inputs (images/videos) at maximum possible @@ -1245,10 +1250,7 @@ def __call__( *, mm_uuids: Optional[MultiModalUUIDDict] = None, ) -> MultiModalInputs: - return self.apply(prompt, - mm_data, - hf_processor_mm_kwargs, - mm_uuids=mm_uuids) + return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids) def _get_data_parser(self) -> MultiModalDataParser: """ @@ -1276,8 +1278,7 @@ def validate_num_items( limit = min(supported_limit, allowed_limit) if num_items > limit: - msg = (f"At most {limit} {modality}(s) may be provided in " - "one prompt.") + msg = f"At most {limit} {modality}(s) may be provided in one prompt." if num_items <= supported_limit: msg += " Set `--limit-mm-per-prompt` to increase this limit." @@ -1339,8 +1340,10 @@ def _bind_and_group_updates( mm_item_counts: Mapping[str, int], ) -> MultiModalPromptUpdates: return { - modality: [[update.resolve(item_idx) for update in updates] - for item_idx in range(mm_item_counts.get(modality, 0))] + modality: [ + [update.resolve(item_idx) for update in updates] + for item_idx in range(mm_item_counts.get(modality, 0)) + ] for modality, updates in full_groupby_modality(prompt_updates) } @@ -1385,8 +1388,7 @@ def _find_mm_placeholders( ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: tokenizer = self.info.get_tokenizer() - return find_mm_placeholders(new_token_ids, mm_prompt_updates, - tokenizer) + return find_mm_placeholders(new_token_ids, mm_prompt_updates, tokenizer) def _get_hf_mm_data( self, @@ -1436,7 +1438,8 @@ def _hf_processor_applies_updates( """ return not any( isinstance(items, (EmbeddingItems, DictEmbeddingItems)) - for items in mm_items.values()) + for items in mm_items.values() + ) def _apply_hf_processor_text_mm( self, @@ -1461,7 +1464,7 @@ def _apply_hf_processor_text_mm( ) processed_data.update(passthrough_data) - prompt_ids, = processed_data.pop("input_ids").tolist() + (prompt_ids,) = processed_data.pop("input_ids").tolist() is_update_applied = self._hf_processor_applies_updates( prompt_text=prompt_text, @@ -1564,8 +1567,7 @@ def _apply_hf_processor_main( tokenization_kwargs=tokenization_kwargs, ) - prompt_ids = self._apply_hf_processor_text_only( - prompt, tokenization_kwargs) + prompt_ids = self._apply_hf_processor_text_only(prompt, tokenization_kwargs) else: prompt_ids = self._apply_hf_processor_tokens_only(prompt) @@ -1611,10 +1613,11 @@ def _hash_mm_items( # hash if `hf_processor_mm_kwargs` or `tokenization_kwargs` # are provided. This is because the processed multimodal # inputs can be different depending on the processor kwargs. - if item_uuid is None or \ - hf_processor_mm_kwargs or \ - tokenization_kwargs: - + if ( + item_uuid is None + or hf_processor_mm_kwargs + or tokenization_kwargs + ): # NOTE: use provided hash string to hash with kwargs # if available for better performance. 
item = item_uuid if item_uuid is not None else item @@ -1623,16 +1626,20 @@ def _hash_mm_items( model_id=model_id, **{modality: item}, **hf_processor_mm_kwargs, - **tokenization_kwargs)) + **tokenization_kwargs, + ) + ) else: computed.append(item_uuid) hashes[modality] = computed else: hashes[modality] = [ - MultiModalHasher.hash_kwargs(model_id=model_id, - **{modality: item}, - **hf_processor_mm_kwargs, - **tokenization_kwargs) + MultiModalHasher.hash_kwargs( + model_id=model_id, + **{modality: item}, + **hf_processor_mm_kwargs, + **tokenization_kwargs, + ) for item in items ] @@ -1645,13 +1652,13 @@ def _get_cache_missing_items( mm_hashes: MultiModalHashes, ) -> MultiModalDataItems: mm_is_cached = { - modality: cache.is_cached(hashes) - for modality, hashes in mm_hashes.items() + modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items() } mm_missing_idxs = { modality: [ - idx for idx, item_is_cached in enumerate(items_is_cached) + idx + for idx, item_is_cached in enumerate(items_is_cached) if not item_is_cached ] for modality, items_is_cached in mm_is_cached.items() @@ -1664,7 +1671,8 @@ def _get_cache_missing_items( if data is None: raise ValueError( f"Cache miss for {modality} at index {idx} " - f"but data is not provided.") + f"but data is not provided." + ) else: missing_modality_data.append(data) mm_missing_data[modality] = missing_modality_data @@ -1692,20 +1700,18 @@ def _merge_mm_kwargs( # Need to calculate this at the beginning to avoid skipping cache logic # for subsequently repeated items in the same modality mm_is_cached = { - modality: cache.is_cached(hashes) - for modality, hashes in mm_hashes.items() + modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items() } mm_missing_next_idx = defaultdict[str, int](lambda: 0) - merged_kwargs = defaultdict[str, - list[Optional[MultiModalKwargsItem]]](list) - merged_prompt_updates = defaultdict[ - str, list[Sequence[ResolvedPromptUpdate]]](list) + merged_kwargs = defaultdict[str, list[Optional[MultiModalKwargsItem]]](list) + merged_prompt_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]]( + list + ) for modality, hashes in mm_hashes.items(): missing_kwargs = mm_missing_kwargs.get(modality, []) - missing_prompt_updates = mm_missing_prompt_updates.get( - modality, []) + missing_prompt_updates = mm_missing_prompt_updates.get(modality, []) for item_idx, item_hash in enumerate(hashes): kwargs: Optional[MultiModalKwargsItem] @@ -1723,10 +1729,12 @@ def _merge_mm_kwargs( kwargs, updates = cache.get_and_update_item(item, item_hash) merged_kwargs[modality].append(kwargs) - merged_prompt_updates[modality].append([ - self._recompute_cached_prompt_update(update, item_idx) - for update in updates - ]) + merged_prompt_updates[modality].append( + [ + self._recompute_cached_prompt_update(update, item_idx) + for update in updates + ] + ) mm_kwargs = MultiModalKwargsItems(merged_kwargs) mm_prompt_updates = dict(merged_prompt_updates) @@ -1756,15 +1764,16 @@ def _apply_hf_processor( mm_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_processed_data, - self._get_mm_fields_config(mm_processed_data, - hf_processor_mm_kwargs), + self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs), ) # Use overrides if provided; fallback to data-dependent hashing. 
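The `_hash_mm_items` logic shown above distinguishes two cache-key paths: a caller-supplied UUID is reused as-is only when no processor or tokenization kwargs are in play, since those kwargs can change the processed output; otherwise the UUID (or the raw item, when no UUID exists) is hashed together with the kwargs. A rough self-contained sketch of that decision using hashlib/json stand-ins rather than vLLM's `MultiModalHasher` (names here are illustrative):

import hashlib
import json
from typing import Optional

def cache_key(item_uuid: Optional[str], raw_item: bytes, processor_kwargs: dict) -> str:
    if item_uuid is not None and not processor_kwargs:
        # Cheap path: reuse the caller-provided UUID directly.
        return item_uuid
    # Kwargs that change the processed output must be folded into the key;
    # prefer the UUID over the raw bytes when one is available.
    basis = item_uuid.encode() if item_uuid is not None else raw_item
    payload = basis + json.dumps(processor_kwargs, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()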
- mm_hashes = self._hash_mm_items(mm_data_items, - hf_processor_mm_kwargs, - tokenization_kwargs, - mm_uuids=mm_uuids) + mm_hashes = self._hash_mm_items( + mm_data_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) mm_prompt_updates = self._get_mm_prompt_updates( mm_data_items, @@ -1805,10 +1814,12 @@ def _cached_apply_hf_processor( mm_uuids=mm_uuids, ) - mm_hashes = self._hash_mm_items(mm_data_items, - hf_processor_mm_kwargs, - tokenization_kwargs, - mm_uuids=mm_uuids) + mm_hashes = self._hash_mm_items( + mm_data_items, + hf_processor_mm_kwargs, + tokenization_kwargs, + mm_uuids=mm_uuids, + ) mm_missing_data_items = self._get_cache_missing_items( cache=cache, @@ -1833,8 +1844,9 @@ def _cached_apply_hf_processor( mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_missing_processed_data, - self._get_mm_fields_config(mm_missing_processed_data, - hf_processor_mm_kwargs), + self._get_mm_fields_config( + mm_missing_processed_data, hf_processor_mm_kwargs + ), ) mm_missing_prompt_updates = self._get_mm_prompt_updates( @@ -1897,8 +1909,9 @@ def _apply_prompt_updates( # of the search text in the prompt, we instead perform string-based # updates on the decoded token IDs, then encode them back. if not all( - all(update_idx is not None for update_idx in update_idxs) - for update_idxs in match_result.values()): + all(update_idx is not None for update_idx in update_idxs) + for update_idxs in match_result.values() + ): new_text, match_result = self._apply_text_matches( decode_tokens(tokenizer, token_ids), mm_prompt_updates, @@ -1910,16 +1923,17 @@ def _apply_prompt_updates( add_special_tokens=False, ) - matched_updates = defaultdict[ - str, list[Sequence[ResolvedPromptUpdate]]](list) + matched_updates = defaultdict[str, list[Sequence[ResolvedPromptUpdate]]](list) for modality, update_idxs in match_result.items(): for item_idx, update_idx in enumerate(update_idxs): assert update_idx is not None, ( "Failed to apply prompt replacement for " - f"mm_items[{modality!r}][{item_idx}]") + f"mm_items[{modality!r}][{item_idx}]" + ) matched_updates[modality].append( - [mm_prompt_updates[modality][item_idx][update_idx]]) + [mm_prompt_updates[modality][item_idx][update_idx]] + ) placeholders = self._find_mm_placeholders( new_token_ids, @@ -1944,20 +1958,18 @@ def _validate_mm_kwargs( "There is likely a problem with your " "implementation of merged multi-modal processor for this " "model (usually arising from an inconsistency between " - "`_call_hf_processor` and `_get_mm_fields_config`).") + "`_call_hf_processor` and `_get_mm_fields_config`)." + ) - def _validate_mm_placeholders( + def _validate_mm_updates( self, - mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_updates: MultiModalPromptUpdates, mm_item_counts: Mapping[str, int], ) -> None: for modality, item_count in mm_item_counts.items(): - placeholders = mm_placeholders.get(modality, []) + placeholders = mm_updates.get(modality, []) if len(placeholders) != item_count: - # NOTE: If you are a model developer, this can also arise from - # an inconsistency between `_call_hf_processor` and - # `_get_mm_fields_config` implementations raise RuntimeError( f"Expected there to be {item_count} prompt updates " f"corresponding to {item_count} {modality} items, but " @@ -1965,7 +1977,25 @@ def _validate_mm_placeholders( "This is likely because you forgot to include input " "placeholder tokens (e.g., `<image>`, `<|image_pad|>`) " "in the prompt. 
If the model has a chat template, make " - "sure you have applied it before calling `LLM.generate`.") + "sure you have applied it before calling `LLM.generate`." + ) + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_item_counts: Mapping[str, int], + ) -> None: + for modality, item_count in mm_item_counts.items(): + placeholders = mm_placeholders.get(modality, []) + + if len(placeholders) != item_count: + raise RuntimeError( + f"Expected there to be {item_count} prompt placeholders " + f"corresponding to {item_count} {modality} items, but " + f"instead found {len(placeholders)} prompt placeholders! " + "Make sure the implementation of `_call_hf_processor` and " + "`_get_mm_fields_config` are consistent with each other." + ) def _maybe_apply_prompt_updates( self, @@ -1977,6 +2007,7 @@ def _maybe_apply_prompt_updates( ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]: mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + self._validate_mm_updates(mm_prompt_updates, mm_item_counts) if is_update_applied: mm_placeholders = self._find_mm_placeholders( @@ -2056,7 +2087,6 @@ def apply( class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): - @abstractmethod def create_encoder_prompt( self, @@ -2090,15 +2120,16 @@ def _get_enc_dec_inputs( tokenizer = self.info.get_tokenizer() decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data) if isinstance(decoder_prompt_raw, str): - decoder_prompt_ids = encode_tokens(tokenizer, - decoder_prompt_raw, - add_special_tokens=False) + decoder_prompt_ids = encode_tokens( + tokenizer, decoder_prompt_raw, add_special_tokens=False + ) else: decoder_prompt_ids = decoder_prompt_raw mm_inputs = MultiModalEncDecInputs( encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], - **encoder_inputs) + **encoder_inputs, + ) mm_inputs["prompt_token_ids"] = decoder_prompt_ids return mm_inputs diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 74dc2314d2eb..05ba5a2abdd4 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -10,15 +10,26 @@ from PIL import Image import vllm.envs as envs -from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions, - ImageDummyOptions, VideoDummyOptions) +from vllm.config.multimodal import ( + AudioDummyOptions, + BaseDummyOptions, + ImageDummyOptions, + VideoDummyOptions, +) from vllm.logger import init_logger -from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs, MultiModalKwargsItems, - MultiModalPlaceholderDict) -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - EncDecMultiModalProcessor) +from .inputs import ( + MultiModalDataDict, + MultiModalEncDecInputs, + MultiModalInputs, + MultiModalKwargsItems, + MultiModalPlaceholderDict, +) +from .processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + EncDecMultiModalProcessor, +) logger = init_logger(__name__) @@ -29,6 +40,7 @@ class ProcessorInputs: Represents the keyword arguments to [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][]. """ + prompt: Union[str, list[int]] mm_data: MultiModalDataDict hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) @@ -86,7 +98,7 @@ def get_dummy_mm_data( mm_counts: Count of items per modality mm_options: Configurable options per modality (optional). If None, use model defaults for backward compatibility. 
- If provided, models can use these to customize dummy + If provided, models can use these to customize dummy data generation. """ raise NotImplementedError @@ -113,9 +125,11 @@ def get_dummy_processor_inputs( tokenization_kwargs = {"truncation": False} - return ProcessorInputs(prompt=dummy_text, - mm_data=dummy_mm_data, - tokenization_kwargs=tokenization_kwargs) + return ProcessorInputs( + prompt=dummy_text, + mm_data=dummy_mm_data, + tokenization_kwargs=tokenization_kwargs, + ) def _get_dummy_audios( self, @@ -130,10 +144,12 @@ def _get_dummy_audios( if overrides.length > length: logger.warning( "audio.length override (%d) exceeds model's " - "maximum length (%d), will be ignored", overrides.length, - length) + "maximum length (%d), will be ignored", + overrides.length, + length, + ) length = min(length, overrides.length) - audio = np.zeros((length, )) + audio = np.zeros((length,)) return [audio] * num_audios def _get_dummy_images( @@ -151,15 +167,19 @@ def _get_dummy_images( if overrides.width > width: logger.warning( "image.width override (%d) exceeds model's " - "maximum width (%d), will be ignored", overrides.width, - width) + "maximum width (%d), will be ignored", + overrides.width, + width, + ) width = min(width, overrides.width) if overrides.height: if overrides.height > height: logger.warning( "image.height override (%d) exceeds model's " "maximum height (%d), will be ignored", - overrides.height, height) + overrides.height, + height, + ) height = min(height, overrides.height) image = Image.new("RGB", (width, height), color=255) return [image] * num_images @@ -181,21 +201,27 @@ def _get_dummy_videos( logger.warning( "video.num_frames override (%d) exceeds model's " "maximum number of frames (%d), will be ignored", - overrides.num_frames, num_frames) + overrides.num_frames, + num_frames, + ) num_frames = min(num_frames, overrides.num_frames) if overrides.width: if overrides.width > width: logger.warning( "video.width override (%d) exceeds model's " - "maximum width (%d), will be ignored", overrides.width, - width) + "maximum width (%d), will be ignored", + overrides.width, + width, + ) width = min(width, overrides.width) if overrides.height: if overrides.height > height: logger.warning( "video.height override (%d) exceeds model's " "maximum height (%d), will be ignored", - overrides.height, height) + overrides.height, + height, + ) height = min(height, overrides.height) video = np.full((num_frames, width, height, 3), 255) return [video] * num_videos @@ -236,7 +262,8 @@ def _get_dummy_mm_inputs( factory = self.dummy_inputs processor_inputs = factory.get_dummy_processor_inputs( - seq_len, mm_counts, mm_options) + seq_len, mm_counts, mm_options + ) return self.processor.apply( prompt=processor_inputs.prompt, @@ -253,9 +280,10 @@ def _get_mm_num_tokens( placeholders_by_modality = mm_inputs["mm_placeholders"] return { - modality: - sum(item.get_num_embeds() if mm_embeddings_only else item.length - for item in placeholders) + modality: sum( + item.get_num_embeds() if mm_embeddings_only else item.length + for item in placeholders + ) for modality, placeholders in placeholders_by_modality.items() } @@ -330,8 +358,7 @@ def _get_mm_max_tokens( return max_tokens_per_item mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) - return self._get_mm_num_tokens(mm_inputs, - mm_embeddings_only=mm_embeddings_only) + return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only) def get_mm_max_contiguous_tokens( self, @@ -349,6 +376,4 @@ def get_mm_max_contiguous_tokens( 
initializing the encoder cache size. """ - return self._get_mm_max_tokens(seq_len, - mm_counts, - mm_embeddings_only=False) + return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 24d3baa9b4e7..a526eaff715a 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -8,15 +8,21 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - cached_tokenizer_from_config) +from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config from vllm.utils import ClassRegistry from .cache import BaseMultiModalProcessorCache -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - InputProcessingContext) -from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, - DummyEncoderData, MultiModalProfiler) +from .processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + InputProcessingContext, +) +from .profiling import ( + BaseDummyInputsBuilder, + DummyDecoderData, + DummyEncoderData, + MultiModalProfiler, +) if TYPE_CHECKING: from vllm.config import ModelConfig @@ -38,8 +44,7 @@ class ProcessingInfoFactory(Protocol[_I_co]): def __call__( self, ctx: InputProcessingContext, - ) -> _I_co: - ... + ) -> _I_co: ... class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc] @@ -49,8 +54,7 @@ class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc] instance from the context. """ - def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: - ... + def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]: ... class MultiModalProcessorFactory(Protocol[_I]): # type: ignore[misc] @@ -66,8 +70,7 @@ def __call__( dummy_inputs: BaseDummyInputsBuilder[_I], *, cache: Optional[BaseMultiModalProcessorCache] = None, - ) -> BaseMultiModalProcessor[_I]: - ... + ) -> BaseMultiModalProcessor[_I]: ... @dataclass(frozen=True) @@ -93,8 +96,7 @@ class MultiModalRegistry: """ def __init__(self) -> None: - self._processor_factories = ClassRegistry[nn.Module, - _ProcessorFactories]() + self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() def _extract_mm_options( self, @@ -112,8 +114,7 @@ def _extract_mm_options( mm_options = { m: opt for m in model_config.multimodal_config.limit_per_prompt - if (opt := model_config.multimodal_config.get_dummy_options(m) - ) is not None + if (opt := model_config.multimodal_config.get_dummy_options(m)) is not None } return mm_options if len(mm_options) > 0 else None @@ -121,8 +122,8 @@ def _extract_mm_options( def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: """ Checks if the model supports multimodal inputs. - Returns True if the model is multimodal with any non-zero supported - modalities, otherwise returns False, effectively running in + Returns True if the model is multimodal with any non-zero supported + modalities, otherwise returns False, effectively running in text-only mode. 
""" if not model_config.is_multimodal_model: @@ -135,11 +136,13 @@ def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: # Check if all supported modalities have limit == 0 if all( - mm_config.get_limit_per_prompt(modality) == 0 - for modality in supported_modalities): + mm_config.get_limit_per_prompt(modality) == 0 + for modality in supported_modalities + ): logger.info_once( "All limits of multimodal modalities supported by the model " - "are set to 0, running in text-only mode.") + "are set to 0, running in text-only mode." + ) return False return True @@ -165,10 +168,7 @@ def get_max_tokens_per_item_by_modality( return profiler.get_mm_max_contiguous_tokens( seq_len, - { - modality: 1 - for modality, limit in mm_limits.items() if limit > 0 - }, + {modality: 1 for modality, limit in mm_limits.items() if limit > 0}, ) def get_max_tokens_per_item_by_nonzero_modality( @@ -235,7 +235,9 @@ def wrapper(model_cls: N) -> N: logger.warning( "Model class %s already has a multi-modal processor " "registered to %s. It is overwritten by the new one.", - model_cls, self) + model_cls, + self, + ) self._processor_factories[model_cls] = _ProcessorFactories( info=info, @@ -315,15 +317,15 @@ def get_decoder_dummy_data( # count-only behavior remains unchanged. mm_options = self._extract_mm_options(model_config) - dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, - mm_options) + dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options) # Having more tokens is over-conservative but otherwise fine token_ids = dummy_data.prompt_token_ids if len(token_ids) < seq_len: raise AssertionError( f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(token_ids)} tokens instead.") + f"but found {len(token_ids)} tokens instead." + ) return dummy_data @@ -348,8 +350,7 @@ def get_encoder_dummy_data( # count-only behavior remains unchanged. mm_options = self._extract_mm_options(model_config) - dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, - mm_options) + dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts, mm_options) # Having more tokens is over-conservative but otherwise fine token_ids = dummy_data.prompt_token_ids @@ -368,15 +369,16 @@ def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: """ if not model_config.is_encoder_decoder: return 0 - max_tokens = self.\ - get_max_tokens_per_item_by_nonzero_modality(model_config) + max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config) if not max_tokens: # TODO - this function assumes encoder-decoder models are # multimodal. This will need to change when adding support for more # than whisper. return 0 - assert len(max_tokens) == 1, "Encoder-decoder models are expected \ + assert len(max_tokens) == 1, ( + "Encoder-decoder models are expected \ to implement the multimodal interface with at most one modality." 
+ ) first_modality = next(iter(max_tokens)) return max_tokens[first_modality] diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index bab12fd1681a..c9dc077d0385 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -15,7 +15,6 @@ import numpy.typing as npt import torch from PIL import Image, UnidentifiedImageError -from typing_extensions import deprecated import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection @@ -29,8 +28,12 @@ _M = TypeVar("_M") if TYPE_CHECKING: - from .inputs import (BatchedTensorInputs, MultiModalKwargsItem, - MultiModalKwargsItems, MultiModalPlaceholderDict) + from .inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalPlaceholderDict, + ) else: BatchedTensorInputs = Any MultiModalKwargsItem = Any @@ -38,12 +41,12 @@ MultiModalPlaceholderDict = Any global_thread_pool = ThreadPoolExecutor( - max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT) + max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT +) atexit.register(global_thread_pool.shutdown) class MediaConnector: - def __init__( self, media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None, @@ -54,9 +57,9 @@ def __init__( ) -> None: """ Args: - media_io_kwargs: Additional args passed to process media - inputs, keyed by modalities. For example, - to set num_frames for video, set + media_io_kwargs: Additional args passed to process media + inputs, keyed by modalities. For example, + to set num_frames for video, set `--media-io-kwargs '{"video":{"num_frames":40}}'` connection: HTTP connection client to download media contents. allowed_local_media_path: A local directory to load media files @@ -64,8 +67,9 @@ def __init__( """ super().__init__() - self.media_io_kwargs: dict[str, dict[ - str, Any]] = media_io_kwargs if media_io_kwargs else {} + self.media_io_kwargs: dict[str, dict[str, Any]] = ( + media_io_kwargs if media_io_kwargs else {} + ) self.connection = connection if allowed_local_media_path: @@ -74,11 +78,13 @@ def __init__( if not allowed_local_media_path_.exists(): raise ValueError( "Invalid `--allowed-local-media-path`: The path " - f"{allowed_local_media_path_} does not exist.") + f"{allowed_local_media_path_} does not exist." + ) if not allowed_local_media_path_.is_dir(): raise ValueError( "Invalid `--allowed-local-media-path`: The path " - f"{allowed_local_media_path_} must be a directory.") + f"{allowed_local_media_path_} must be a directory." + ) else: allowed_local_media_path_ = None @@ -108,24 +114,29 @@ def _load_file_url( ) -> _M: # type: ignore[type-var] allowed_local_media_path = self.allowed_local_media_path if allowed_local_media_path is None: - raise RuntimeError("Cannot load local files without " - "`--allowed-local-media-path`.") + raise RuntimeError( + "Cannot load local files without `--allowed-local-media-path`." + ) filepath = Path(url2pathname(url_spec.path)) if allowed_local_media_path not in filepath.resolve().parents: raise ValueError( f"The file path {filepath} must be a subpath " - f"of `--allowed-local-media-path` {allowed_local_media_path}.") + f"of `--allowed-local-media-path` {allowed_local_media_path}." 
+ ) return media_io.load_file(filepath) def _assert_url_in_allowed_media_domains(self, url_spec) -> None: - if self.allowed_media_domains and url_spec.hostname not in \ - self.allowed_media_domains: + if ( + self.allowed_media_domains + and url_spec.hostname not in self.allowed_media_domains + ): raise ValueError( f"The URL must be from one of the allowed domains: " f"{self.allowed_media_domains}. Input URL domain: " - f"{url_spec.hostname}") + f"{url_spec.hostname}" + ) def load_from_url( self, @@ -176,20 +187,19 @@ async def load_from_url_async( timeout=fetch_timeout, allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS, ) - future = loop.run_in_executor(global_thread_pool, - media_io.load_bytes, data) + future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data) return await future if url_spec.scheme == "data": - future = loop.run_in_executor(global_thread_pool, - self._load_data_url, url_spec, - media_io) + future = loop.run_in_executor( + global_thread_pool, self._load_data_url, url_spec, media_io + ) return await future if url_spec.scheme == "file": - future = loop.run_in_executor(global_thread_pool, - self._load_file_url, url_spec, - media_io) + future = loop.run_in_executor( + global_thread_pool, self._load_file_url, url_spec, media_io + ) return await future msg = "The URL must be either a HTTP, data or file URL." raise ValueError(msg) @@ -235,8 +245,9 @@ def fetch_image( By default, the image is converted into RGB format. """ - image_io = ImageMediaIO(image_mode=image_mode, - **self.media_io_kwargs.get("image", {})) + image_io = ImageMediaIO( + image_mode=image_mode, **self.media_io_kwargs.get("image", {}) + ) try: return self.load_from_url( @@ -259,8 +270,9 @@ async def fetch_image_async( By default, the image is converted into RGB format. """ - image_io = ImageMediaIO(image_mode=image_mode, - **self.media_io_kwargs.get("image", {})) + image_io = ImageMediaIO( + image_mode=image_mode, **self.media_io_kwargs.get("image", {}) + ) try: return await self.load_from_url_async( @@ -281,10 +293,10 @@ def fetch_video( """ Load video from an HTTP or base64 data URL. """ - image_io = ImageMediaIO(image_mode=image_mode, - **self.media_io_kwargs.get("image", {})) - video_io = VideoMediaIO(image_io, - **self.media_io_kwargs.get("video", {})) + image_io = ImageMediaIO( + image_mode=image_mode, **self.media_io_kwargs.get("image", {}) + ) + video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {})) return self.load_from_url( video_url, @@ -303,10 +315,10 @@ async def fetch_video_async( By default, the image is converted into RGB format. """ - image_io = ImageMediaIO(image_mode=image_mode, - **self.media_io_kwargs.get("image", {})) - video_io = VideoMediaIO(image_io, - **self.media_io_kwargs.get("video", {})) + image_io = ImageMediaIO( + image_mode=image_mode, **self.media_io_kwargs.get("image", {}) + ) + video_io = VideoMediaIO(image_io, **self.media_io_kwargs.get("video", {})) return await self.load_from_url_async( video_url, @@ -357,7 +369,8 @@ def encode_video_base64(frames: npt.NDArray) -> str: def argsort_mm_positions( - mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]: + mm_positions: MultiModalPlaceholderDict, +) -> list[tuple[str, int]]: """ Given a `MultiModalPlaceholderDict`, output a sequence of keys to sort the dictionary by `offset` (starting index in the input sequence) @@ -367,48 +380,23 @@ def argsort_mm_positions( A list of `(modality, idx)`, which can be used to access an item by `mm_positions[modality][idx]`. 
""" - flat_items = ((modality, idx, item) - for modality, items in mm_positions.items() - for idx, item in enumerate(items)) + flat_items = ( + (modality, idx, item) + for modality, items in mm_positions.items() + for idx, item in enumerate(items) + ) sorted_flat_items = sorted(flat_items, key=lambda x: x[2].offset) return [(modality, idx) for modality, idx, _ in sorted_flat_items] -# Temporary back-compatibility for plugins that define model runner -@deprecated("`group_mm_inputs_by_modality` is superseded by " - "`group_mm_kwargs_by_modality` and will be removed in v0.13. " - "Please use `group_mm_kwargs_by_modality` instead.") -def group_mm_inputs_by_modality( - mm_inputs: list[MultiModalKwargsItems] -) -> list[list[MultiModalKwargsItems]]: - if not mm_inputs: - return [] - - def modality_group_func( - mm_input: MultiModalKwargsItems) -> Union[str, int]: - # If the input has multiple modalities, return an id as the unique key - # for the mm_input input. - if len(mm_input) > 1: - return id(mm_input) - - elif len(mm_input) == 1: - return next(iter(mm_input.keys())) - - raise AssertionError("This line should be unreachable.") - - return [ - list(group) for _, group in groupby(mm_inputs, key=modality_group_func) - ] - - def group_mm_kwargs_by_modality( mm_kwargs: list[MultiModalKwargsItem], *, device: torch.types.Device = None, pin_memory: bool = False, - merge_by_field_config: bool = False, + merge_by_field_config: Optional[bool] = None, ) -> Iterable[tuple[str, int, BatchedTensorInputs]]: """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same modality together into the same `MultiModalKwargs` instance. @@ -421,19 +409,26 @@ def group_mm_kwargs_by_modality( Yields: A tuple `(modality, num_items, grouped_kwargs)`. """ + if merge_by_field_config is None: + raise RuntimeError( + "`group_mm_kwargs_by_modality` now requires " + "`merge_by_field_config` arg, please update your model runner " + "according to https://github.com/vllm-project/vllm/pull/25676." + ) + from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems for modality, items in groupby(mm_kwargs, key=lambda item: item.modality): items_lst = list(items) - # TODO: Enable `merge_by_field_config` for all models - # to avoid creating an extra batch dimension (except for fields - # that are meant to be stacked anyway). - # We will also need to update each model to remove `flatten_bn`. + # TODO: Deprecate `merge_by_field_config` once + # we have migrated all in-tree models if merge_by_field_config: mm_kwargs_group: BatchedTensorInputs = dict( MultiModalKwargsItems.from_seq(items_lst).get_data( - pin_memory=pin_memory)) + pin_memory=pin_memory + ) + ) if device is not None: mm_kwargs_group = json_map_leaves( @@ -464,9 +459,7 @@ def fetch_audio( audio_url: URL of the audio file to fetch. audio_io_kwargs: Additional kwargs passed to handle audio IO. """ - media_io_kwargs = None if not audio_io_kwargs else { - "audio": audio_io_kwargs - } + media_io_kwargs = None if not audio_io_kwargs else {"audio": audio_io_kwargs} media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) return media_connector.fetch_audio(audio_url) @@ -480,9 +473,7 @@ def fetch_image( image_url: URL of the image file to fetch. image_io_kwargs: Additional kwargs passed to handle image IO. 
""" - media_io_kwargs = None if not image_io_kwargs else { - "image": image_io_kwargs - } + media_io_kwargs = None if not image_io_kwargs else {"image": image_io_kwargs} media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) return media_connector.fetch_image(image_url) @@ -496,8 +487,6 @@ def fetch_video( video_url: URL of the video file to fetch. video_io_kwargs: Additional kwargs passed to handle video IO. """ - media_io_kwargs = None if not video_io_kwargs else { - "video": video_io_kwargs - } + media_io_kwargs = None if not video_io_kwargs else {"video": video_io_kwargs} media_connector = MediaConnector(media_io_kwargs=media_io_kwargs) return media_connector.fetch_video(video_url) diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 6981f2ce5623..400d6a6be9be 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -21,8 +21,9 @@ def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray: num_frames, _, _, channels = frames.shape new_height, new_width = size - resized_frames = np.empty((num_frames, new_height, new_width, channels), - dtype=frames.dtype) + resized_frames = np.empty( + (num_frames, new_height, new_width, channels), dtype=frames.dtype + ) # lazy import cv2 to avoid bothering users who only use text models import cv2 @@ -40,8 +41,7 @@ def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray: return resize_video(frames, (new_height, new_width)) -def sample_frames_from_video(frames: npt.NDArray, - num_frames: int) -> npt.NDArray: +def sample_frames_from_video(frames: npt.NDArray, num_frames: int) -> npt.NDArray: total_frames = frames.shape[0] if num_frames == -1: return frames @@ -52,23 +52,19 @@ def sample_frames_from_video(frames: npt.NDArray, class VideoLoader: - @classmethod @abstractmethod - def load_bytes(cls, - data: bytes, - num_frames: int = -1, - **kwargs) -> tuple[npt.NDArray, dict[str, Any]]: + def load_bytes( + cls, data: bytes, num_frames: int = -1, **kwargs + ) -> tuple[npt.NDArray, dict[str, Any]]: raise NotImplementedError class VideoLoaderRegistry: - def __init__(self) -> None: self.name2class: dict[str, type] = {} def register(self, name: str): - def wrap(cls_to_register): self.name2class[name] = cls_to_register return cls_to_register @@ -87,7 +83,6 @@ def load(cls_name: str) -> VideoLoader: @VIDEO_LOADER_REGISTRY.register("opencv") class OpenCVVideoBackend(VideoLoader): - def get_cv2_video_api(self): import cv2.videoio_registry as vr @@ -127,10 +122,9 @@ def load_bytes( num_frames = total_frames_num frame_idx = list(range(0, num_frames)) else: - uniform_sampled_frames = np.linspace(0, - total_frames_num - 1, - num_frames, - dtype=int) + uniform_sampled_frames = np.linspace( + 0, total_frames_num - 1, num_frames, dtype=int + ) frame_idx = uniform_sampled_frames.tolist() width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) @@ -148,8 +142,10 @@ def load_bytes( frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) i += 1 - assert i == num_frames, (f"Expected reading {num_frames} frames, " - f"but only loaded {i} frames from video.") + assert i == num_frames, ( + f"Expected reading {num_frames} frames, " + f"but only loaded {i} frames from video." 
+ ) # Use transformers transformers.video_utils.VideoMetadata format # NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata @@ -170,7 +166,6 @@ def load_bytes( @VIDEO_LOADER_REGISTRY.register("opencv_dynamic") class OpenCVDynamicVideoBackend(OpenCVVideoBackend): - @classmethod def load_bytes( cls, @@ -200,28 +195,28 @@ def load_bytes( frame_indices: Union[range, list[int]] if duration <= max_duration: n = int(math.floor(duration * fps)) - frame_indices = sorted({ - min(max_frame_idx, int(math.ceil(i * original_fps / fps))) - for i in range(n) - }) + frame_indices = sorted( + { + min(max_frame_idx, int(math.ceil(i * original_fps / fps))) + for i in range(n) + } + ) else: num_samples = int(max_duration * fps) if num_samples >= total_frames_num: frame_indices = range(total_frames_num) else: - target_seconds = np.linspace(0, - duration, - num_samples, - endpoint=True) - frame_indices = sorted({ - min(max_frame_idx, int(math.ceil(t * original_fps))) - for t in target_seconds - }) + target_seconds = np.linspace(0, duration, num_samples, endpoint=True) + frame_indices = sorted( + { + min(max_frame_idx, int(math.ceil(t * original_fps))) + for t in target_seconds + } + ) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - frames = np.empty((len(frame_indices), height, width, 3), - dtype=np.uint8) + frames = np.empty((len(frame_indices), height, width, 3), dtype=np.uint8) i = 0 for idx in range(total_frames_num): @@ -236,7 +231,8 @@ def load_bytes( assert i == len(frame_indices), ( f"Expected reading {len(frame_indices)} frames, " - f"but only loaded {i} frames from video.") + f"but only loaded {i} frames from video." + ) # Use transformers transformers.video_utils.VideoMetadata format metadata = { @@ -252,7 +248,6 @@ def load_bytes( class VideoMediaIO(MediaIO[npt.NDArray]): - def __init__( self, image_io: ImageMediaIO, @@ -273,22 +268,22 @@ def __init__( self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend) def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]: - return self.video_loader.load_bytes(data, - num_frames=self.num_frames, - **self.kwargs) + return self.video_loader.load_bytes( + data, num_frames=self.num_frames, **self.kwargs + ) - def load_base64(self, media_type: str, - data: str) -> tuple[npt.NDArray, dict[str, Any]]: + def load_base64( + self, media_type: str, data: str + ) -> tuple[npt.NDArray, dict[str, Any]]: if media_type.lower() == "video/jpeg": load_frame = partial( self.image_io.load_base64, "image/jpeg", ) - return np.stack([ - np.asarray(load_frame(frame_data)) - for frame_data in data.split(",") - ]), {} + return np.stack( + [np.asarray(load_frame(frame_data)) for frame_data in data.split(",")] + ), {} return self.load_bytes(base64.b64decode(data)) @@ -312,8 +307,7 @@ def encode_base64( image_format=video_format, ) - return ",".join( - encode_frame(Image.fromarray(frame)) for frame in video) + return ",".join(encode_frame(Image.fromarray(frame)) for frame in video) msg = "Only JPEG format is supported for now." 
raise NotImplementedError(msg) diff --git a/vllm/outputs.py b/vllm/outputs.py index 1ed20461def1..dc183bd8dbe9 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -51,13 +51,15 @@ def finished(self) -> bool: return self.finish_reason is not None def __repr__(self) -> str: - return (f"CompletionOutput(index={self.index}, " - f"text={self.text!r}, " - f"token_ids={self.token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"logprobs={self.logprobs}, " - f"finish_reason={self.finish_reason}, " - f"stop_reason={self.stop_reason})") + return ( + f"CompletionOutput(index={self.index}, " + f"text={self.text!r}, " + f"token_ids={self.token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"logprobs={self.logprobs}, " + f"finish_reason={self.finish_reason}, " + f"stop_reason={self.stop_reason})" + ) @dataclass @@ -67,14 +69,16 @@ class PoolingOutput: Args: data: The extracted hidden states. """ + data: torch.Tensor def __repr__(self) -> str: - return (f"PoolingOutput(data={self.data})") + return f"PoolingOutput(data={self.data})" def __eq__(self, other: object) -> bool: - return (isinstance(other, self.__class__) and bool( - (self.data == other.data).all())) + return isinstance(other, self.__class__) and bool( + (self.data == other.data).all() + ) class RequestOutput: @@ -122,8 +126,9 @@ def __init__( **kwargs: Any, ) -> None: if kwargs: - logger.warning_once("RequestOutput: Ignoring extra arguments: %s", - str(kwargs)) + logger.warning_once( + "RequestOutput: Ignoring extra arguments: %s", str(kwargs) + ) self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids @@ -150,16 +155,15 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None: if aggregate: # Merge outputs with same index completion.text += next_completion.text - if not isinstance(completion.token_ids, - MutableSequence): + if not isinstance(completion.token_ids, MutableSequence): completion.token_ids = list(completion.token_ids) completion.token_ids.extend(next_completion.token_ids) if next_completion.logprobs: assert completion.logprobs is not None - completion.logprobs.extend( - next_completion.logprobs) + completion.logprobs.extend(next_completion.logprobs) completion.cumulative_logprob = ( - next_completion.cumulative_logprob) + next_completion.cumulative_logprob + ) completion.finish_reason = next_completion.finish_reason completion.stop_reason = next_completion.stop_reason else: @@ -170,18 +174,20 @@ def add(self, next_output: "RequestOutput", aggregate: bool) -> None: self.outputs.append(next_completion) def __repr__(self) -> str: - return (f"RequestOutput(request_id={self.request_id}, " - f"prompt={self.prompt!r}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"encoder_prompt={self.encoder_prompt!r}, " - f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " - f"prompt_logprobs={self.prompt_logprobs}, " - f"outputs={self.outputs}, " - f"finished={self.finished}, " - f"metrics={self.metrics}, " - f"lora_request={self.lora_request}, " - f"num_cached_tokens={self.num_cached_tokens}, " - f"multi_modal_placeholders={self.multi_modal_placeholders})") + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"encoder_prompt={self.encoder_prompt!r}, " + f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"outputs={self.outputs}, " + f"finished={self.finished}, " + f"metrics={self.metrics}, " + 
f"lora_request={self.lora_request}, " + f"num_cached_tokens={self.num_cached_tokens}, " + f"multi_modal_placeholders={self.multi_modal_placeholders})" + ) _O = TypeVar("_O", default=PoolingOutput) @@ -198,18 +204,21 @@ class PoolingRequestOutput(Generic[_O]): finished (bool): A flag indicating whether the pooling is completed. """ - def __init__(self, request_id: str, outputs: _O, - prompt_token_ids: list[int], finished: bool): + def __init__( + self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool + ): self.request_id = request_id self.prompt_token_ids = prompt_token_ids self.finished = finished self.outputs = outputs def __repr__(self): - return (f"{type(self).__name__}(request_id={self.request_id!r}, " - f"outputs={self.outputs!r}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"finished={self.finished})") + return ( + f"{type(self).__name__}(request_id={self.request_id!r}, " + f"outputs={self.outputs!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})" + ) @dataclass @@ -220,6 +229,7 @@ class EmbeddingOutput: embedding: The embedding vector, which is a list of floats. Its length depends on the hidden dimension of the model. """ + embedding: list[float] @staticmethod @@ -239,7 +249,6 @@ def __repr__(self) -> str: class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]): - @staticmethod def from_base(request_output: PoolingRequestOutput): return EmbeddingRequestOutput( @@ -258,6 +267,7 @@ class ClassificationOutput: probs: The probability vector, which is a list of floats. Its length depends on the number of classes. """ + probs: list[float] @staticmethod @@ -278,7 +288,6 @@ def __repr__(self) -> str: class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]): - @staticmethod def from_base(request_output: PoolingRequestOutput): return ClassificationRequestOutput( @@ -296,6 +305,7 @@ class ScoringOutput: Args: score: The similarity score, which is a scalar value. """ + score: float @staticmethod @@ -314,7 +324,6 @@ def __repr__(self) -> str: class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]): - @staticmethod def from_base(request_output: PoolingRequestOutput): return ScoringRequestOutput( diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 7549de480ee6..5154b1cea782 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -19,12 +19,14 @@ def vllm_version_matches_substr(substr: str) -> bool: Check to see if the vLLM version matches a substring. """ from importlib.metadata import PackageNotFoundError, version + try: vllm_version = version("vllm") except PackageNotFoundError as e: logger.warning( "The vLLM package was not found, so its version could not be " - "inspected. This may cause platform detection to fail.") + "inspected. This may cause platform detection to fail." + ) raise e return substr in vllm_version @@ -45,6 +47,7 @@ def tpu_platform_plugin() -> Optional[str]: # has TPUs. import libtpu # noqa: F401 + logger.debug("Confirmed TPU platform is available.") return "vllm.platforms.tpu.TpuPlatform" except Exception as e: @@ -57,6 +60,7 @@ def cuda_platform_plugin() -> Optional[str]: logger.debug("Checking if CUDA platform is available.") try: from vllm.utils import import_pynvml + pynvml = import_pynvml() pynvml.nvmlInit() try: @@ -65,21 +69,22 @@ def cuda_platform_plugin() -> Optional[str]: # we need to check if vllm is built with cpu too. # Otherwise, vllm will always activate cuda plugin # on a GPU machine, even if in a cpu build. 
- is_cuda = (pynvml.nvmlDeviceGetCount() > 0 - and not vllm_version_matches_substr("cpu")) + is_cuda = ( + pynvml.nvmlDeviceGetCount() > 0 + and not vllm_version_matches_substr("cpu") + ) if pynvml.nvmlDeviceGetCount() <= 0: - logger.debug( - "CUDA platform is not available because no GPU is found.") + logger.debug("CUDA platform is not available because no GPU is found.") if vllm_version_matches_substr("cpu"): - logger.debug("CUDA platform is not available because" - " vLLM is built with CPU.") + logger.debug( + "CUDA platform is not available because vLLM is built with CPU." + ) if is_cuda: logger.debug("Confirmed CUDA platform is available.") finally: pynvml.nvmlShutdown() except Exception as e: - logger.debug("Exception happens when checking CUDA platform: %s", - str(e)) + logger.debug("Exception happens when checking CUDA platform: %s", str(e)) if "nvml" not in e.__class__.__name__.lower(): # If the error is not related to NVML, re-raise it. raise e @@ -88,8 +93,9 @@ def cuda_platform_plugin() -> Optional[str]: import os def cuda_is_jetson() -> bool: - return os.path.isfile("/etc/nv_tegra_release") \ - or os.path.exists("/sys/class/tegra-firmware") + return os.path.isfile("/etc/nv_tegra_release") or os.path.exists( + "/sys/class/tegra-firmware" + ) if cuda_is_jetson(): logger.debug("Confirmed CUDA platform is available on Jetson.") @@ -105,14 +111,14 @@ def rocm_platform_plugin() -> Optional[str]: logger.debug("Checking if ROCm platform is available.") try: import amdsmi + amdsmi.amdsmi_init() try: if len(amdsmi.amdsmi_get_processor_handles()) > 0: is_rocm = True logger.debug("Confirmed ROCm platform is available.") else: - logger.debug("ROCm platform is not available because" - " no GPU is found.") + logger.debug("ROCm platform is not available because no GPU is found.") finally: amdsmi.amdsmi_shut_down() except Exception as e: @@ -128,18 +134,19 @@ def xpu_platform_plugin() -> Optional[str]: # installed IPEX if the machine has XPUs. import intel_extension_for_pytorch # noqa: F401 import torch + if supports_xccl(): dist_backend = "xccl" else: dist_backend = "ccl" import oneccl_bindings_for_pytorch # noqa: F401 - if hasattr(torch, 'xpu') and torch.xpu.is_available(): + if hasattr(torch, "xpu") and torch.xpu.is_available(): is_xpu = True from vllm.platforms.xpu import XPUPlatform + XPUPlatform.dist_backend = dist_backend - logger.debug("Confirmed %s backend is available.", - XPUPlatform.dist_backend) + logger.debug("Confirmed %s backend is available.", XPUPlatform.dist_backend) logger.debug("Confirmed XPU platform is available.") except Exception as e: logger.debug("XPU platform is not available because: %s", str(e)) @@ -153,14 +160,17 @@ def cpu_platform_plugin() -> Optional[str]: try: is_cpu = vllm_version_matches_substr("cpu") if is_cpu: - logger.debug("Confirmed CPU platform is available because" - " vLLM is built with CPU.") + logger.debug( + "Confirmed CPU platform is available because vLLM is built with CPU." + ) if not is_cpu: import sys + is_cpu = sys.platform.startswith("darwin") if is_cpu: - logger.debug("Confirmed CPU platform is available" - " because the machine is MacOS.") + logger.debug( + "Confirmed CPU platform is available because the machine is MacOS." 
+ ) except Exception as e: logger.debug("CPU platform is not available because: %s", str(e)) @@ -169,21 +179,20 @@ def cpu_platform_plugin() -> Optional[str]: builtin_platform_plugins = { - 'tpu': tpu_platform_plugin, - 'cuda': cuda_platform_plugin, - 'rocm': rocm_platform_plugin, - 'xpu': xpu_platform_plugin, - 'cpu': cpu_platform_plugin, + "tpu": tpu_platform_plugin, + "cuda": cuda_platform_plugin, + "rocm": rocm_platform_plugin, + "xpu": xpu_platform_plugin, + "cpu": cpu_platform_plugin, } def resolve_current_platform_cls_qualname() -> str: - platform_plugins = load_plugins_by_group('vllm.platform_plugins') + platform_plugins = load_plugins_by_group("vllm.platform_plugins") activated_plugins = [] - for name, func in chain(builtin_platform_plugins.items(), - platform_plugins.items()): + for name, func in chain(builtin_platform_plugins.items(), platform_plugins.items()): try: assert callable(func) platform_cls_qualname = func() @@ -193,43 +202,41 @@ def resolve_current_platform_cls_qualname() -> str: pass activated_builtin_plugins = list( - set(activated_plugins) & set(builtin_platform_plugins.keys())) - activated_oot_plugins = list( - set(activated_plugins) & set(platform_plugins.keys())) + set(activated_plugins) & set(builtin_platform_plugins.keys()) + ) + activated_oot_plugins = list(set(activated_plugins) & set(platform_plugins.keys())) if len(activated_oot_plugins) >= 2: raise RuntimeError( "Only one platform plugin can be activated, but got: " - f"{activated_oot_plugins}") + f"{activated_oot_plugins}" + ) elif len(activated_oot_plugins) == 1: platform_cls_qualname = platform_plugins[activated_oot_plugins[0]]() - logger.info("Platform plugin %s is activated", - activated_oot_plugins[0]) + logger.info("Platform plugin %s is activated", activated_oot_plugins[0]) elif len(activated_builtin_plugins) >= 2: raise RuntimeError( "Only one platform plugin can be activated, but got: " - f"{activated_builtin_plugins}") + f"{activated_builtin_plugins}" + ) elif len(activated_builtin_plugins) == 1: - platform_cls_qualname = builtin_platform_plugins[ - activated_builtin_plugins[0]]() - logger.info("Automatically detected platform %s.", - activated_builtin_plugins[0]) + platform_cls_qualname = builtin_platform_plugins[activated_builtin_plugins[0]]() + logger.info("Automatically detected platform %s.", activated_builtin_plugins[0]) else: platform_cls_qualname = "vllm.platforms.interface.UnspecifiedPlatform" - logger.info( - "No platform detected, vLLM is running on UnspecifiedPlatform") + logger.info("No platform detected, vLLM is running on UnspecifiedPlatform") return platform_cls_qualname _current_platform = None -_init_trace: str = '' +_init_trace: str = "" if TYPE_CHECKING: current_platform: Platform def __getattr__(name: str): - if name == 'current_platform': + if name == "current_platform": # lazy init current_platform. # 1. out-of-tree platform plugins need `from vllm.platforms import # Platform` so that they can inherit `Platform` class. 
Therefore, @@ -244,19 +251,14 @@ def __getattr__(name: str): global _current_platform if _current_platform is None: platform_cls_qualname = resolve_current_platform_cls_qualname() - _current_platform = resolve_obj_by_qualname( - platform_cls_qualname)() + _current_platform = resolve_obj_by_qualname(platform_cls_qualname)() global _init_trace _init_trace = "".join(traceback.format_stack()) return _current_platform elif name in globals(): return globals()[name] else: - raise AttributeError( - f"No attribute named '{name}' exists in {__name__}.") + raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") -__all__ = [ - 'Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum', - "_init_trace" -] +__all__ = ["Platform", "PlatformEnum", "current_platform", "CpuArchEnum", "_init_trace"] diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 436e295e58e6..2f87664003dc 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -28,9 +28,9 @@ def get_max_threads(pid=0): - if hasattr(os, 'sched_getaffinity'): + if hasattr(os, "sched_getaffinity"): return len(os.sched_getaffinity(pid)) - elif platform.system() == 'Darwin': + elif platform.system() == "Darwin": return os.cpu_count() else: raise NotImplementedError("Unsupported OS") @@ -60,7 +60,8 @@ def json_decoder(obj_dict: dict): return LogicalCPUInfo( id=LogicalCPUInfo._int(id), physical_core=LogicalCPUInfo._int(physical_core), - numa_node=LogicalCPUInfo._int(numa_node)) + numa_node=LogicalCPUInfo._int(numa_node), + ) else: return obj_dict @@ -77,13 +78,42 @@ class CpuPlatform(Platform): def supported_dtypes(self) -> list[torch.dtype]: if self.get_cpu_architecture() == CpuArchEnum.POWERPC: return [torch.bfloat16, torch.float32] - elif (self.get_cpu_architecture() == CpuArchEnum.ARM - and sys.platform.startswith("darwin")): - if (subprocess.check_output( - ["sysctl -n hw.optional.arm.FEAT_BF16"], - shell=True).strip() == b"1"): + elif self.get_cpu_architecture() == CpuArchEnum.ARM and sys.platform.startswith( + "darwin" + ): + if ( + subprocess.check_output( + ["sysctl -n hw.optional.arm.FEAT_BF16"], shell=True + ).strip() + == b"1" + ): return [torch.bfloat16, torch.float16, torch.float32] return [torch.float16, torch.float32] + elif self.get_cpu_architecture() == CpuArchEnum.RISCV: + # Workaround for Issue #25655: RISC-V scheduler bug with float16 + # + # Background: + # - RISC-V currently uses scalar code path + # - There is a latent bug in the vLLM scheduler that provides + # invalid + # physical_block_idx values under certain conditions + # - This bug causes segmentation faults when using float16 + # dtype on RISC-V + # - Testing shows that forcing float32 successfully bypasses + # this issue + # + # Technical details: + # - The bug manifests as out-of-bounds physical_block_idx in + # block_tables + # - Only occurs on RISC-V hardware + # tested on Sophgo SG2044 + # - Does not reproduce on x86 or other architectures + # - Root cause is in Python-level scheduling logic, + # not C++ kernels + # + # This is a temporary workaround until the scheduler bug is fixed. + # See: https://github.com/vllm-project/vllm/issues/25655 + return [torch.float32] # x86/aarch64 CPU has supported both bf16 and fp16 natively. 
return [torch.bfloat16, torch.float16, torch.float32] @@ -92,18 +122,26 @@ def get_device_name(cls, device_id: int = 0) -> str: return "cpu" @classmethod - def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool, use_sparse: bool) -> str: + def get_attn_backend_cls( + cls, + selected_backend: "_Backend", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + ) -> str: from vllm.attention.backends.registry import _Backend + if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: raise NotImplementedError("MLA is not supported on CPU.") if use_sparse: - raise NotImplementedError( - "Sparse Attention is not supported on CPU.") + raise NotImplementedError("Sparse Attention is not supported on CPU.") logger.info("Using Torch SDPA backend.") if not use_v1: raise ValueError("CPU backend only supports V1.") @@ -119,7 +157,8 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: kv_cache_space = 4 * GiB_bytes # type: ignore logger.warning_once( "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) " - "for CPU backend is not set, using 4 by default.") + "for CPU backend is not set, using 4 by default." + ) else: kv_cache_space *= GiB_bytes @@ -153,53 +192,66 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if not ipex_available and cache_config.block_size != 16: raise RuntimeError( f"--block-size={cache_config.block_size} requires" - " intel_extension_for_pytorch") + " intel_extension_for_pytorch" + ) scheduler_config = vllm_config.scheduler_config - if ((scheduler_config.chunked_prefill_enabled - or cache_config.enable_prefix_caching) - and cache_config.cache_dtype != "auto"): - raise RuntimeError("Chunked-prefill and prefix-cache on the CPU " - "backend is not compatible with FP8 KV cache.") + if ( + scheduler_config.chunked_prefill_enabled + or cache_config.enable_prefix_caching + ) and cache_config.cache_dtype != "auto": + raise RuntimeError( + "Chunked-prefill and prefix-cache on the CPU " + "backend is not compatible with FP8 KV cache." + ) if cache_config.cache_dtype == "fp8_e4m3": cache_config.cache_dtype = "fp8_e5m2" logger.warning( - "CPU backend doesn't support fp8_e4m3 KV cache type, " - "cast to fp8_e5m2.") - - if (cache_config.cache_dtype != "auto" and model_config is not None - and model_config.dtype == torch.half): - logger.warning("FP8 KV cache on the CPU backend only does not" - " support fp16 for now, cast to bf16.") + "CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2." + ) + + if ( + cache_config.cache_dtype != "auto" + and model_config is not None + and model_config.dtype == torch.half + ): + logger.warning( + "FP8 KV cache on the CPU backend only does not" + " support fp16 for now, cast to bf16." 
+ ) model_config.dtype = torch.bfloat16 - cache_config.cpu_kvcache_space_bytes = \ - CpuPlatform.get_device_total_memory() + cache_config.cpu_kvcache_space_bytes = CpuPlatform.get_device_total_memory() parallel_config = vllm_config.parallel_config - if (parallel_config.world_size > 1 - and parallel_config.distributed_executor_backend is not None - and parallel_config.distributed_executor_backend != "mp"): - logger.warning(("%s is not supported on CPU, fallback to mp " - "distributed executor backend."), - parallel_config.distributed_executor_backend) + if ( + parallel_config.world_size > 1 + and parallel_config.distributed_executor_backend is not None + and parallel_config.distributed_executor_backend != "mp" + ): + logger.warning( + ( + "%s is not supported on CPU, fallback to mp " + "distributed executor backend." + ), + parallel_config.distributed_executor_backend, + ) parallel_config.distributed_executor_backend = "mp" if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker" # Disable DBO if parallel_config.enable_dbo: - logger.warning( - "Dual-Batch Overlap is not supported on CPU, disabled.") + logger.warning("Dual-Batch Overlap is not supported on CPU, disabled.") parallel_config.enable_dbo = False # Note: workaround for v1 gpu_model_runner from vllm.config import CompilationLevel + vllm_config.compilation_config.cudagraph_capture_sizes = [] compilation_config = vllm_config.compilation_config if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: - # Note: vLLM V1 is using PIECEWISE level compilation, which will # take time to compile kernels just-in-time with the inductor # backend. For CPU CI tests, most of them are executed fast and @@ -214,16 +266,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: compilation_config.level = CompilationLevel.DYNAMO_ONCE compilation_config.backend = backend - compilation_config.inductor_compile_config.update({ - "dce": - True, - "size_asserts": - False, - "nan_asserts": - False, - "epilogue_fusion": - True, - }) + compilation_config.inductor_compile_config.update( + { + "dce": True, + "size_asserts": False, + "nan_asserts": False, + "epilogue_fusion": True, + } + ) if compilation_config.use_inductor: compilation_config.custom_ops = ["none"] @@ -253,51 +303,56 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if "libiomp5.so" in ld_prealod_str: # The time(milliseconds) that a thread should wait after # completing the execution of a parallel region, before sleeping. - os.environ['KMP_BLOCKTIME'] = "1" + os.environ["KMP_BLOCKTIME"] = "1" # Prevents the CPU to run into low performance state - os.environ['KMP_TPAUSE'] = "0" + os.environ["KMP_TPAUSE"] = "0" # Provides fine granularity parallelism - os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist" - os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist" - os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist" + os.environ["KMP_FORKJOIN_BARRIER_PATTERN"] = "dist,dist" + os.environ["KMP_PLAIN_BARRIER_PATTERN"] = "dist,dist" + os.environ["KMP_REDUCTION_BARRIER_PATTERN"] = "dist,dist" # To hint IPEX uses shared memory based AllReduce os.environ["LOCAL_WORLD_SIZE"] = str( - vllm_config.parallel_config.tensor_parallel_size) + vllm_config.parallel_config.tensor_parallel_size + ) if model_config is not None and model_config.use_mla: logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") + "prefill and prefix caching to be disabled." 
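For reference, the OpenMP knobs set above when `libiomp5.so` is preloaded can be collected into one helper. The keys and values are the ones in the hunk; the helper itself (and its use of `setdefault` to respect user overrides) is illustrative:

```python
import os

# Values taken from the CPU-platform hunk above.
_KMP_DEFAULTS = {
    "KMP_BLOCKTIME": "1",  # ms a thread waits after a parallel region before sleeping
    "KMP_TPAUSE": "0",  # avoid dropping into low-performance pause states
    "KMP_FORKJOIN_BARRIER_PATTERN": "dist,dist",
    "KMP_PLAIN_BARRIER_PATTERN": "dist,dist",
    "KMP_REDUCTION_BARRIER_PATTERN": "dist,dist",
}


def apply_kmp_defaults() -> None:
    """Apply the same tuning as the CPU platform without clobbering
    values the user already exported."""
    for key, value in _KMP_DEFAULTS.items():
        os.environ.setdefault(key, value)
```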
+ ) vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) + DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) @classmethod - def get_allowed_cpu_core_node_list( - cls) -> tuple[list[int], list[LogicalCPUInfo]]: + def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]: assert platform.system() == "Linux" # Init LogicalCPUInfo from lscpu - lscpu_output = subprocess.check_output("lscpu -J -e=CPU,CORE,NODE", - shell=True, - text=True) + lscpu_output = subprocess.check_output( + "lscpu -J -e=CPU,CORE,NODE", shell=True, text=True + ) logical_cpu_list: list[LogicalCPUInfo] = json.loads( - lscpu_output, object_hook=LogicalCPUInfo.json_decoder)['cpus'] + lscpu_output, object_hook=LogicalCPUInfo.json_decoder + )["cpus"] # Filter CPUs with invalid attributes logical_cpu_list = [ - x for x in logical_cpu_list + x + for x in logical_cpu_list if -1 not in (x.id, x.physical_core, x.numa_node) ] # Filter allowed CPUs - allowed_cpu_id_list = os.sched_getaffinity(0) - logical_cpu_list = [ - x for x in logical_cpu_list if x.id in allowed_cpu_id_list - ] + if hasattr(os, "sched_getaffinity"): + allowed_cpu_id_list = os.sched_getaffinity(0) + else: + raise NotImplementedError("Unsupported OS") + logical_cpu_list = [x for x in logical_cpu_list if x.id in allowed_cpu_id_list] # Get allowed NUMA nodes allowed_numa_nodes = set() @@ -306,8 +361,8 @@ def get_allowed_cpu_core_node_list( allowed_numa_nodes_list = sorted(allowed_numa_nodes) env_key = CpuPlatform.device_control_env_var - if (env_key in os.environ and os.environ[env_key] != ""): - visible_nodes = [int(s) for s in os.environ[env_key].split(',')] + if env_key in os.environ and os.environ[env_key] != "": + visible_nodes = [int(s) for s in os.environ[env_key].split(",")] allowed_numa_nodes_list = [ x for x in visible_nodes if x in allowed_cpu_id_list ] diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b7baa614957e..20568e0d6c51 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -41,7 +41,6 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: - @wraps(fn) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: pynvml.nvmlInit() @@ -86,9 +85,7 @@ def set_device(cls, device: torch.device) -> None: _ = torch.zeros(1, device=device) @classmethod - def get_device_capability(cls, - device_id: int = 0 - ) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]: raise NotImplementedError @classmethod @@ -122,8 +119,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing if model_config is not None and model_config.use_mla: - use_sparse = hasattr(vllm_config.model_config.hf_config, - "index_topk") + use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk") # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. 
For each case, we force the @@ -146,43 +142,47 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: use_flashmla = True else: # Forced case - use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") - use_cutlass_mla = ( - envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") - use_flashinfer_mla = ( - envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA") + use_flashmla = envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" + use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" + use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" from vllm.attention.ops.flashmla import is_flashmla_supported - if use_flashmla and is_flashmla_supported()[0] \ - and cache_config.block_size != 64: + + if ( + use_flashmla + and is_flashmla_supported()[0] + and cache_config.block_size != 64 + ): cache_config.block_size = 64 - logger.info( - "Forcing kv cache block size to 64 for FlashMLA backend.") + logger.info("Forcing kv cache block size to 64 for FlashMLA backend.") if use_cutlass_mla and cache_config.block_size != 128: cache_config.block_size = 128 - logger.info("Forcing kv cache block size to 128 for " - "CUTLASS_MLA backend.") + logger.info( + "Forcing kv cache block size to 128 for CUTLASS_MLA backend." + ) if use_flashinfer_mla and cache_config.block_size not in [32, 64]: cache_config.block_size = 64 logger.info( - "Forcing kv cache block size to 64 for FlashInferMLA " - "backend.") + "Forcing kv cache block size to 64 for FlashInferMLA backend." + ) # TODO(Chen): remove this hacky code if use_sparse and cache_config.block_size != 64: cache_config.block_size = 64 logger.info( - "Forcing kv cache block size to 64 for FlashMLASparse " - "backend.") + "Forcing kv cache block size to 64 for FlashMLASparse backend." + ) # lazy import to avoid circular import from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config - if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" - and parallel_config.data_parallel_size > 1 - and compilation_config.cudagraph_mode != CUDAGraphMode.NONE): + if ( + envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + and parallel_config.data_parallel_size > 1 + and compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): # TODO: Piecewise Cuda graph might be enabled # if torch compile cache key issue fixed # See https://github.com/vllm-project/vllm/pull/25093 @@ -192,20 +192,20 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "CUDA Graphs. " "In order to use CUDA Graphs for decode-optimized workloads, " "set VLLM_ALL2ALL_BACKEND to another option, such as " - "deepep_low_latency, pplx, or allgather_reducescatter.") + "deepep_low_latency, pplx, or allgather_reducescatter." + ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats(device) return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> "_Backend": + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": from vllm.attention.backends.registry import _Backend # For Blackwell GPUs, force TORCH_SDPA for now. 
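The block-size coercions above reduce to a small per-backend rule set. Summarized as a lookup (illustrative only; the hunk above is the authoritative logic, and the `FLASHMLA_SPARSE` label is just a name used here):

```python
# KV-cache block sizes each MLA backend accepts, per the hunk above.
MLA_BLOCK_SIZES = {
    "FLASHMLA": (64,),
    "CUTLASS_MLA": (128,),
    "FLASHINFER_MLA": (32, 64),  # anything else is coerced to 64
    "FLASHMLA_SPARSE": (64,),
}


def coerce_block_size(backend: str, block_size: int) -> int:
    """Return a block size the chosen backend accepts (sketch of the forcing logic)."""
    allowed = MLA_BLOCK_SIZES.get(backend)
    if allowed is None or block_size in allowed:
        return block_size
    # FlashInfer MLA falls back to 64; the other backends have a single legal size.
    return 64 if backend == "FLASHINFER_MLA" else allowed[0]
```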
@@ -217,10 +217,14 @@ def get_vit_attn_backend(cls, head_size: int, return _Backend.XFORMERS if cls.has_device_capability(80): - FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + FLASH_ATTN_V1 = ( + "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + ) from vllm.attention.selector import is_attn_backend_supported + is_default_fa_supported = is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, allow_import_error=False) + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False + ) if is_default_fa_supported: return _Backend.FLASH_ATTN else: @@ -231,83 +235,109 @@ def get_vit_attn_backend(cls, head_size: int, return _Backend.XFORMERS @classmethod - def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla, - has_sink, use_sparse) -> str: + def get_attn_backend_cls( + cls, + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + ) -> str: from vllm.attention.backends.registry import _Backend + if use_mla: if not use_v1: raise RuntimeError( "MLA attention backends require the V1 engine. " - "Set VLLM_USE_V1=1 to enable them.") + "Set VLLM_USE_V1=1 to enable them." + ) from vllm.attention.ops.flashmla import is_flashmla_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla if use_sparse: logger.info_once("Using Sparse MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla.flashmla_sparse." - "FlashMLASparseBackend") + return ( + "vllm.v1.attention.backends.mla.flashmla_sparse." + "FlashMLASparseBackend" + ) use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( - selected_backend is None and cls.is_device_capability(100) - and block_size == 128) + selected_backend is None + and cls.is_device_capability(100) + and block_size == 128 + ) use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or ( - selected_backend is None and cls.is_device_capability(100) - and block_size in [32, 64]) + selected_backend is None + and cls.is_device_capability(100) + and block_size in [32, 64] + ) use_flashmla = selected_backend == _Backend.FLASHMLA or ( - selected_backend is None and is_flashmla_supported()[0]) + selected_backend is None and is_flashmla_supported()[0] + ) use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( - selected_backend is None and flash_attn_supports_mla()) + selected_backend is None and flash_attn_supports_mla() + ) use_triton = selected_backend == _Backend.TRITON_MLA or ( - selected_backend is None) + selected_backend is None + ) if use_cutlassmla: logger.info_once("Using Cutlass MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "cutlass_mla.CutlassMLABackend") + return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" if use_flashinfermla: - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) + from vllm.v1.attention.backends.utils import set_kv_cache_layout + set_kv_cache_layout("HND") logger.info_once("Using FlashInfer MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." 
- "flashinfer_mla.FlashInferMLABackend") + return ( + "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend" + ) if use_flashmla: if block_size != 64: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", - block_size) + block_size, + ) else: logger.info_once("Using FlashMLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashmla.FlashMLABackend") + return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" if use_flashattn: - logger.info_once( - "Using FlashAttention MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashattn_mla.FlashAttnMLABackend") + logger.info_once("Using FlashAttention MLA backend on V1 engine.") + return ( + "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" + ) if use_triton: logger.info_once("Using Triton MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") + return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" if use_v1: FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 - FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 - TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 - FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + FLEX_ATTENTION_V1 = ( + "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 + ) + TRITON_ATTN = ( + "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501 + ) + FLASH_ATTN_V1 = ( + "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + ) TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501 XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501 - use_fp8_kv_cache = (kv_cache_dtype is not None - and kv_cache_dtype.startswith("fp8")) + use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith( + "fp8" + ) if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") if cls.has_device_capability(100): - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) + from vllm.v1.attention.backends.utils import set_kv_cache_layout + set_kv_cache_layout("HND") return FLASHINFER_V1 elif selected_backend == _Backend.FLEX_ATTENTION: @@ -332,13 +362,14 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, # Prefer FlashInfer for Blackwell GPUs if installed if cls.is_device_capability(100): if is_default_backend_supported := is_attn_backend_supported( - FLASHINFER_V1, head_size, dtype): - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) + FLASHINFER_V1, head_size, dtype + ): + from vllm.v1.attention.backends.utils import set_kv_cache_layout logger.info_once( "Using FlashInfer backend with HND KV cache layout on " - "V1 engine by default for Blackwell (SM 10.0) GPUs.") + "V1 engine by default for Blackwell (SM 10.0) GPUs." + ) set_kv_cache_layout("HND") return FLASHINFER_V1 @@ -347,19 +378,18 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.warning_once( "FlashInfer failed to import for V1 engine on " "Blackwell (SM 10.0) GPUs; it is recommended to " - "install FlashInfer for better performance.") + "install FlashInfer for better performance." 
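Condensed, the default (non-MLA, V1) selection encoded above is: FlashInfer on SM 10.0 when it imports, Triton when attention sinks or an fp8 KV cache are in play on anything other than SM 9.0, FlashAttention on SM 8.0+, and FlexAttention as the fallback. A rough restatement that ignores the explicitly-selected-backend branches (the function and parameter names here are made up):

```python
def pick_v1_backend(
    sm: tuple[int, int],
    has_sink: bool,
    fp8_kv_cache: bool,
    flashinfer_importable: bool,
    flash_attn_supported: bool,
) -> str:
    """Rough restatement of the default CUDA V1 backend order above."""
    if sm == (10, 0) and flashinfer_importable:
        return "FLASHINFER"  # the HND KV-cache layout is forced alongside it
    if sm >= (8, 0):
        if (has_sink or fp8_kv_cache) and sm != (9, 0):
            return "TRITON_ATTN"
        if flash_attn_supported:
            return "FLASH_ATTN"
    return "FLEX_ATTENTION"
```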
+ ) # FlashAttention is the default for SM 8.0+ GPUs if cls.has_device_capability(80): - if (has_sink or - use_fp8_kv_cache) and not cls.is_device_capability(90): + if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90): logger.info_once("Using Triton backend on V1 engine.") return TRITON_ATTN elif is_default_backend_supported := is_attn_backend_supported( - FLASH_ATTN_V1, head_size, dtype, - allow_import_error=False): - logger.info_once("Using Flash Attention backend on " - "V1 engine.") + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False + ): + logger.info_once("Using Flash Attention backend on V1 engine.") return FLASH_ATTN_V1 # FlexAttention is the default for older GPUs @@ -377,14 +407,14 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info_once( "Using FlexAttention backend for %s on V1 engine.", - ", ".join(f"{k}={v}" - for k, v in use_flex_attention_reason.items()), + ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()), ) return FLEX_ATTENTION_V1 raise RuntimeError( "V0 attention backends have been removed. Set VLLM_USE_V1=1 " - "to select a supported backend.") + "to select a supported backend." + ) @classmethod def get_punica_wrapper(cls) -> str: @@ -392,7 +422,9 @@ def get_punica_wrapper(cls) -> str: @classmethod def get_device_communicator_cls(cls) -> str: - return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + return ( + "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + ) @classmethod def supports_fp8(cls) -> bool: @@ -430,8 +462,9 @@ def stateless_init_device_torch_dist_pg( backend_options = ProcessGroupNCCL.Options() backend_options._timeout = timeout - backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, - backend_options) + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, group_size, backend_options + ) backend_type = ProcessGroup.BackendType.NCCL device = torch.device("cuda") pg._set_default_backend(backend_type) @@ -445,8 +478,9 @@ def device_count(cls) -> int: return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: + def is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: fp8_attention = kv_cache_dtype.startswith("fp8") attention_backend = envs.VLLM_ATTENTION_BACKEND @@ -461,12 +495,10 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, attention_backend = "FLASHMLA" # Only FlashMLA and CUTLASS_MLA support fp8 - if attention_backend in [ - "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA" - ]: + if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]: supported = True else: - supported = (not fp8_attention) + supported = not fp8_attention else: # Default to FlashAttention if attention_backend is None: @@ -477,8 +509,8 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, supported = True elif attention_backend == "FLASH_ATTN": if fp8_attention: - from vllm.attention.utils.fa_utils import ( - flash_attn_supports_fp8) + from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 + supported = flash_attn_supports_fp8() else: supported = True @@ -506,7 +538,8 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype): "with compute capability of at least 8.0. " f"Your {gpu_name} GPU {compute_str}. 
" "You can use float16 instead by explicitly setting the " - "`dtype` flag in CLI, for example: --dtype=half.") + "`dtype` flag in CLI, for example: --dtype=half." + ) @classmethod def insert_blocks_to_device( @@ -546,13 +579,10 @@ def support_static_graph_mode(cls) -> bool: # all the related functions work on real physical device ids. # the major benefit of using NVML is that it will not initialize CUDA class NvmlCudaPlatform(CudaPlatformBase): - @classmethod @cache @with_nvml_context - def get_device_capability(cls, - device_id: int = 0 - ) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]: try: physical_device_id = cls.device_id_to_physical_device_id(device_id) handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) @@ -599,9 +629,7 @@ def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: """ query if the set of gpus are fully connected by nvlink (1 hop) """ - handles = [ - pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids - ] + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] for i, handle in enumerate(handles): for j, peer_handle in enumerate(handles): if i < j: @@ -616,7 +644,8 @@ def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: except pynvml.NVMLError: logger.exception( "NVLink detection failed. This is normal if" - " your machine has no NVLink equipped.") + " your machine has no NVLink equipped." + ) return False return True @@ -630,11 +659,11 @@ def _get_physical_device_name(cls, device_id: int = 0) -> str: def log_warnings(cls): device_ids: int = pynvml.nvmlDeviceGetCount() if device_ids > 1: - device_names = [ - cls._get_physical_device_name(i) for i in range(device_ids) - ] - if (len(set(device_names)) > 1 - and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"): + device_names = [cls._get_physical_device_name(i) for i in range(device_ids)] + if ( + len(set(device_names)) > 1 + and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID" + ): logger.warning( "Detected different devices in the system: %s. Please" " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " @@ -644,7 +673,6 @@ def log_warnings(cls): class NonNvmlCudaPlatform(CudaPlatformBase): - @classmethod @cache def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: @@ -664,7 +692,8 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: logger.exception( "NVLink detection not possible, as context support was" - " not found. Assuming no NVLink available.") + " not found. Assuming no NVLink available." + ) return False diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index df1395fa842a..59bc9173958c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import enum import os import platform @@ -155,8 +156,10 @@ def device_id_to_physical_device_id(cls, device_id: int): # Treat empty device control env var as unset. This is a valid # configuration in Ray setups where the engine is launched in # a CPU-only placement group located on a GPU node. 
- if cls.device_control_env_var in os.environ and os.environ[ - cls.device_control_env_var] != "": + if ( + cls.device_control_env_var in os.environ + and os.environ[cls.device_control_env_var] != "" + ): device_ids = os.environ[cls.device_control_env_var].split(",") physical_device_id = device_ids[device_id] return int(physical_device_id) @@ -164,16 +167,41 @@ def device_id_to_physical_device_id(cls, device_id: int): return device_id @classmethod - def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> "_Backend": + def import_core_kernels(cls) -> None: + """Import any platform-specific C kernels.""" + try: + import vllm._C # noqa: F401 + except ImportError as e: + logger.warning("Failed to import from vllm._C: %r", e) + + @classmethod + def try_import_moe_kernels(cls) -> bool: + """Import any platform-specific MoE kernels.""" + with contextlib.suppress(ImportError): + import vllm._moe_C # noqa: F401 + + return True + return False + + @classmethod + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": from vllm.attention.backends.registry import _Backend + return _Backend.TORCH_SDPA @classmethod - def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool, use_sparse: bool) -> str: + def get_attn_backend_cls( + cls, + selected_backend: "_Backend", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + ) -> str: """Get the attention backend class of a device.""" return "" @@ -279,9 +307,9 @@ def set_device(cls, device: torch.device) -> None: raise NotImplementedError @classmethod - def pre_register_and_update(cls, - parser: Optional[FlexibleArgumentParser] = None - ) -> None: + def pre_register_and_update( + cls, parser: Optional[FlexibleArgumentParser] = None + ) -> None: """ Do some pre-registration or update action for the current platform. @@ -324,11 +352,10 @@ def verify_quantization(cls, quant: str) -> None: """ Verify whether the quantization is supported by the current platform. """ - if cls.supported_quantization and \ - quant not in cls.supported_quantization: + if cls.supported_quantization and quant not in cls.supported_quantization: raise ValueError( - f"{quant} quantization is currently not supported in " - f"{cls.device_name}.") + f"{quant} quantization is currently not supported in {cls.device_name}." + ) @classmethod def get_cpu_architecture(cls) -> CpuArchEnum: @@ -357,15 +384,17 @@ def is_pin_memory_available(cls) -> bool: if in_wsl(): # Pinning memory in WSL is not supported. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications - logger.warning("Using 'pin_memory=False' as WSL is detected. " - "This may slow down the performance.") + logger.warning( + "Using 'pin_memory=False' as WSL is detected. " + "This may slow down the performance." + ) return False return True @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: """ Return the memory usage in bytes. 
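The env-var handling at the top of this hunk maps a logical device index to a physical one by indexing into the device-control variable (for CUDA, `CUDA_VISIBLE_DEVICES`), with an empty value treated as unset. A standalone illustration:

```python
import os


def logical_to_physical(device_id: int, env_var: str = "CUDA_VISIBLE_DEVICES") -> int:
    """With CUDA_VISIBLE_DEVICES="3,5", logical device 1 maps to physical device 5;
    an unset or empty variable leaves the id unchanged."""
    value = os.environ.get(env_var, "")
    if value == "":
        return device_id
    return int(value.split(",")[device_id])


os.environ["CUDA_VISIBLE_DEVICES"] = "3,5"
assert logical_to_physical(1) == 5
```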
""" @@ -452,9 +481,10 @@ def use_all_gather(cls) -> bool: from vllm.config import get_current_vllm_config parallel_config = get_current_vllm_config().parallel_config - return (envs.VLLM_USE_V1 - or parallel_config.distributed_executor_backend - == "external_launcher") + return ( + envs.VLLM_USE_V1 + or parallel_config.distributed_executor_backend == "external_launcher" + ) @classmethod def use_custom_allreduce(cls) -> bool: @@ -485,8 +515,11 @@ def __getattr__(self, key: str): if device is not None and hasattr(device, key): return getattr(device, key) else: - logger.warning("Current platform %s does not have '%s'" \ - " attribute.", self.device_type, key) + logger.warning( + "Current platform %s does not have '%s' attribute.", + self.device_type, + key, + ) return None def get_global_graph_pool(self) -> Any: @@ -527,8 +560,9 @@ def stateless_init_device_torch_dist_pg( raise RuntimeError(f"Unsupported torch distributed backend: {backend}") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: + def is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: """ Returns if the kv_cache_dtype is supported by the current platform. """ @@ -581,7 +615,7 @@ def _synced_weight_loader(param, *args, **kwargs): @classmethod def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]: """ - Returns a mapping from device_type to a tuple of supported + Returns a mapping from device_type to a tuple of supported kv_buffer_device for nixl. """ return {} diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index de3df03d1fa0..80e7b849c0ed 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -25,9 +25,14 @@ logger = init_logger(__name__) try: - from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info, - amdsmi_get_processor_handles, amdsmi_init, - amdsmi_shut_down, amdsmi_topo_get_link_type) + from amdsmi import ( + AmdSmiException, + amdsmi_get_gpu_asic_info, + amdsmi_get_processor_handles, + amdsmi_init, + amdsmi_shut_down, + amdsmi_topo_get_link_type, + ) except ImportError as e: logger.warning("Failed to import from amdsmi with %r", e) @@ -47,24 +52,24 @@ # Models partially supported by ROCm. # Architecture -> Reason. -_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " - "Triton flash attention. For half-precision SWA support, " - "please use CK flash attention by setting " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") +_ROCM_SWA_REASON = ( + "Sliding window attention (SWA) is not yet supported in " + "Triton flash attention. For half-precision SWA support, " + "please use CK flash attention by setting " + "`VLLM_USE_TRITON_FLASH_ATTN=0`" +) _ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = { - "Qwen2ForCausalLM": - _ROCM_SWA_REASON, - "MistralForCausalLM": - _ROCM_SWA_REASON, - "MixtralForCausalLM": - _ROCM_SWA_REASON, - "PaliGemmaForConditionalGeneration": - ("ROCm flash attention does not yet " - "fully support 32-bit precision on PaliGemma"), - "Phi3VForCausalLM": - ("ROCm Triton flash attention may run into compilation errors due to " - "excessive use of shared memory. 
If this happens, disable Triton FA " - "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") + "Qwen2ForCausalLM": _ROCM_SWA_REASON, + "MistralForCausalLM": _ROCM_SWA_REASON, + "MixtralForCausalLM": _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": ( + "ROCm flash attention does not yet fully support 32-bit precision on PaliGemma" + ), + "Phi3VForCausalLM": ( + "ROCm Triton flash attention may run into compilation errors due to " + "excessive use of shared memory. If this happens, disable Triton FA " + "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`" + ), } _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = { "0x74a0": "AMD_Instinct_MI300A", @@ -91,7 +96,6 @@ def with_amdsmi_context(fn): - @wraps(fn) def wrapper(*args, **kwargs): amdsmi_init() @@ -129,16 +133,16 @@ def on_gfx950() -> bool: @cache def use_rocm_custom_paged_attention( - qtype: torch.dtype, - head_size: int, - block_size: int, - gqa_ratio: int, - max_seq_len: int, - sliding_window: int, - kv_cache_dtype: str, - alibi_slopes: Optional[torch.Tensor] = None, - sinks: Optional[torch.Tensor] = None) -> bool: - + qtype: torch.dtype, + head_size: int, + block_size: int, + gqa_ratio: int, + max_seq_len: int, + sliding_window: int, + kv_cache_dtype: str, + alibi_slopes: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, +) -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) @@ -146,26 +150,36 @@ def use_rocm_custom_paged_attention( # custom paged attn always supported on V0. On V1, requires sliding window # disabled due to observed numerical discrepancy. if ON_GFX9: - return ((not envs.VLLM_USE_V1 or sliding_window == 0 - or sliding_window == (-1, -1)) - and (qtype == torch.half or qtype == torch.bfloat16) - and (head_size == 64 or head_size == 128) - and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) - and max_seq_len <= 128 * 1024 - and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) - and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN - and envs.VLLM_ROCM_USE_AITER) and sinks is None) + return ( + (not envs.VLLM_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1)) + and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 128 * 1024 + and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER) + and sinks is None + ) else: - return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 - or sliding_window == (-1, -1)) - and (qtype == torch.half or qtype == torch.bfloat16) - and head_size == 128 and block_size == 16 - and (gqa_ratio >= 3 and gqa_ratio <= 16) - and max_seq_len <= 128 * 1024 and alibi_slopes is None - and kv_cache_dtype == "auto" - and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None) + return ( + ON_GFX11_GFX12 + and ( + not envs.VLLM_USE_V1 + or sliding_window == 0 + or sliding_window == (-1, -1) + ) + and (qtype == torch.half or qtype == torch.bfloat16) + and head_size == 128 + and block_size == 16 + and (gqa_ratio >= 3 and gqa_ratio <= 16) + and max_seq_len <= 128 * 1024 + and alibi_slopes is None + and kv_cache_dtype == "auto" + and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN + and sinks is None + ) class RocmPlatform(Platform): @@ -179,86 +193,112 @@ class RocmPlatform(Platform): device_control_env_var: str = 
"CUDA_VISIBLE_DEVICES" supported_quantization: list[str] = [ - "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4", "torchao" + "awq", + "gptq", + "fp8", + "compressed-tensors", + "fbgemm_fp8", + "gguf", + "quark", + "ptpc_fp8", + "mxfp4", + "petit_nvfp4", + "torchao", ] @classmethod - def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> "_Backend": + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": from vllm.attention.backends.registry import _Backend - if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA - and on_gfx9()): + + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): return _Backend.ROCM_AITER_FA if on_gfx9(): return _Backend.FLASH_ATTN return _Backend.TORCH_SDPA @classmethod - def get_attn_backend_cls(cls, selected_backend, head_size, dtype, - kv_cache_dtype, block_size, use_v1, use_mla, - has_sink, use_sparse) -> str: + def get_attn_backend_cls( + cls, + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink, + use_sparse, + ) -> str: from vllm.attention.backends.registry import _Backend + if use_sparse: - raise NotImplementedError( - "Sparse Attention is not supported on ROCm.") + raise NotImplementedError("Sparse Attention is not supported on ROCm.") if use_mla: if not use_v1: raise RuntimeError( "MLA attention backends require the V1 engine. " - "Set VLLM_USE_V1=1 to enable them.") + "Set VLLM_USE_V1=1 to enable them." + ) from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( - is_aiter_mla_enabled) + is_aiter_mla_enabled, + ) if selected_backend is None: - selected_backend = (_Backend.ROCM_AITER_MLA if - is_aiter_mla_enabled() or block_size == 1 - else _Backend.TRITON_MLA) + selected_backend = ( + _Backend.ROCM_AITER_MLA + if is_aiter_mla_enabled() or block_size == 1 + else _Backend.TRITON_MLA + ) if selected_backend == _Backend.TRITON_MLA: if block_size != 1: logger.info_once("Using Triton MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") + return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" raise ValueError( f" The selected backend, {selected_backend.name}," - f"does not support block size {block_size}.") + f"does not support block size {block_size}." + ) if selected_backend == _Backend.ROCM_AITER_MLA: if block_size == 1: logger.info("Using AITER MLA backend on V1 engine.") - return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + return ( + "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + ) raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}." - "(currently only supports block size 1)") + "(currently only supports block size 1)" + ) raise ValueError( f" The selected backend, {selected_backend.name}," - f"is not MLA type while requested for MLA backend.") + f"is not MLA type while requested for MLA backend." + ) if envs.VLLM_USE_V1: - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \ - and on_gfx9(): + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): logger.info("Using Flash Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." 
- "rocm_aiter_fa.AiterFlashAttentionBackend") - elif (envs.VLLM_ROCM_USE_AITER and - envs.VLLM_USE_AITER_UNIFIED_ATTENTION) or \ - envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or \ - selected_backend == _Backend.ROCM_ATTN: + return ( + "vllm.v1.attention.backends." + "rocm_aiter_fa.AiterFlashAttentionBackend" + ) + elif ( + (envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION) + or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + or selected_backend == _Backend.ROCM_ATTN + ): # rocm specific backend, with aiter and/or # triton prefix-prefill logger.info("Using Rocm/Aiter Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." - "rocm_attn.RocmAttentionBackend") + return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" else: # default case, using triton unified attention logger.info("Using Triton Attention backend on V1 engine.") - return ("vllm.v1.attention.backends." - "triton_attn.TritonAttentionBackend") + return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" raise RuntimeError( "V0 attention backends have been removed. Set VLLM_USE_V1=1 " - "to select a supported backend.") + "to select a supported backend." + ) @classmethod def set_device(cls, device: torch.device) -> None: @@ -269,9 +309,7 @@ def set_device(cls, device: torch.device) -> None: @classmethod @lru_cache(maxsize=8) - def get_device_capability(cls, - device_id: int = 0 - ) -> Optional[DeviceCapability]: + def get_device_capability(cls, device_id: int = 0) -> Optional[DeviceCapability]: major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) @@ -281,21 +319,17 @@ def is_fully_connected(cls, physical_device_ids: list[int]) -> bool: """ Query if the set of gpus are fully connected by xgmi (1 hop) """ - handles = [ - amdsmi_get_processor_handles()[i] for i in physical_device_ids - ] + handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids] for i, handle in enumerate(handles): for j, peer_handle in enumerate(handles): if i < j: try: - link_type = amdsmi_topo_get_link_type( - handle, peer_handle) + link_type = amdsmi_topo_get_link_type(handle, peer_handle) # type is 2 for XGMI if link_type["hops"] != 1 or link_type["type"] != 2: return False except AmdSmiException as error: - logger.error("AMD 1 hop XGMI detection failed.", - exc_info=error) + logger.error("AMD 1 hop XGMI detection failed.", exc_info=error) return False return True @@ -326,8 +360,9 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: is_eager_execution = compilation_config == CUDAGraphMode.NONE use_v1 = envs.VLLM_USE_V1 - use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \ - envs.VLLM_ROCM_USE_AITER_RMSNORM + use_aiter_rms_norm = ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM + ) if cache_config and cache_config.block_size is None: cache_config.block_size = 16 @@ -335,21 +370,28 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: if parallel_config.worker_cls == "auto": parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" # Aiter rms norm perform best when CUDA Graph capture is enabled. 
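The `custom_ops` list above follows a `+name` / `-name` opt-in/opt-out convention, and the ROCm hunk only appends `+rms_norm` when the user has not opted out. A tiny helper capturing that idiom (illustrative; it also skips duplicates, which the hunk itself does not need to):

```python
def maybe_enable_custom_op(custom_ops: list[str], op: str) -> None:
    """Enable a custom op unless the user explicitly disabled or already enabled it."""
    if f"-{op}" not in custom_ops and f"+{op}" not in custom_ops:
        custom_ops.append(f"+{op}")


ops: list[str] = []
maybe_enable_custom_op(ops, "rms_norm")
assert ops == ["+rms_norm"]
```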
- if (use_v1 and use_aiter_rms_norm and not is_eager_execution - and "-rms_norm" not in compilation_config.custom_ops): + if ( + use_v1 + and use_aiter_rms_norm + and not is_eager_execution + and "-rms_norm" not in compilation_config.custom_ops + ): compilation_config.custom_ops.append("+rms_norm") @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: - raise ValueError(f"Model architecture '{model_arch}' is not " - "supported by ROCm for now.") + raise ValueError( + f"Model architecture '{model_arch}' is not supported by ROCm for now." + ) if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] logger.warning( - "Model architecture '%s' is partially " - "supported by ROCm: %s", model_arch, msg) + "Model architecture '%s' is partially supported by ROCm: %s", + model_arch, + msg, + ) @classmethod def verify_quantization(cls, quant: str) -> None: @@ -357,7 +399,8 @@ def verify_quantization(cls, quant: str) -> None: if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ: logger.warning( "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" - " is not set, enabling VLLM_USE_TRITON_AWQ.") + " is not set, enabling VLLM_USE_TRITON_AWQ." + ) envs.VLLM_USE_TRITON_AWQ = True @classmethod @@ -365,16 +408,17 @@ def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: torch.cuda.reset_peak_memory_stats(device) - return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info( - device)[0] + return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(device)[0] @classmethod def get_device_communicator_cls(cls) -> str: - return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + return ( + "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + ) @classmethod def supports_mx(cls) -> bool: @@ -384,12 +428,12 @@ def supports_mx(cls) -> bool: @classmethod def supports_fp8(cls) -> bool: gcn_arch = torch.cuda.get_device_properties(0).gcnArchName - return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12']) + return any(gfx in gcn_arch for gfx in ["gfx94", "gfx95", "gfx12"]) @classmethod def is_fp8_fnuz(cls) -> bool: # only device 0 is checked, this assumes MI300 platforms are homogeneous - return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName + return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName @classmethod def fp8_dtype(cls) -> torch.dtype: @@ -402,7 +446,7 @@ def fp8_dtype(cls) -> torch.dtype: def use_custom_allreduce(cls) -> bool: # We only enable custom allreduce for MI300 series gcn_arch = torch.cuda.get_device_properties(0).gcnArchName - supported_archs = ['gfx94', 'gfx95'] + supported_archs = ["gfx94", "gfx95"] return any(gfx in gcn_arch for gfx in supported_archs) @classmethod @@ -411,12 +455,11 @@ def opaque_attention_op(cls) -> bool: @classmethod def get_cu_count(cls, device_id: int = 0) -> int: - return torch.cuda.get_device_properties( - device_id).multi_processor_count + return torch.cuda.get_device_properties(device_id).multi_processor_count @classmethod def is_navi(cls) -> bool: - return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName + return "gfx1" in torch.cuda.get_device_properties(0).gcnArchName @classmethod def 
get_static_graph_wrapper_cls(cls) -> str: @@ -442,8 +485,9 @@ def stateless_init_device_torch_dist_pg( backend_options = ProcessGroupNCCL.Options() backend_options._timeout = timeout - backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, - backend_options) + backend_class = ProcessGroupNCCL( + prefix_store, group_rank, group_size, backend_options + ) backend_type = ProcessGroup.BackendType.NCCL device = torch.device("cuda") pg._set_default_backend(backend_type) @@ -457,8 +501,9 @@ def device_count(cls) -> int: return cuda_device_count_stateless() @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: + def is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: return True @classmethod @@ -479,7 +524,8 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype): "with compute capability of at least 8.0. " f"Your {gpu_name} GPU {compute_str}. " "You can use float16 instead by explicitly setting the " - "`dtype` flag in CLI, for example: --dtype=half.") + "`dtype` flag in CLI, for example: --dtype=half." + ) @classmethod def support_hybrid_kv_cache(cls) -> bool: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 91a01a4f4ee9..6be9ca1298a9 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -39,23 +39,31 @@ class TpuPlatform(Platform): device_control_env_var: str = "TPU_VISIBLE_CHIPS" simple_compile_backend: str = "openxla" - supported_quantization: list[str] = [ - "fp8", "tpu_int8", "compressed-tensors" - ] + supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"] - additional_env_vars: list[str] = [ - "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" - ] + additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"] @classmethod - def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, use_mla: bool, - has_sink, use_sparse) -> str: + def import_core_kernels(cls) -> None: + pass + + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "_Backend", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink, + use_sparse, + ) -> str: from vllm.attention.backends.registry import _Backend + if use_sparse: - raise NotImplementedError( - "Sparse Attention is not supported on TPU.") + raise NotImplementedError("Sparse Attention is not supported on TPU.") if selected_backend != _Backend.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) @@ -112,34 +120,43 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # TPU only supports DYNAMO_ONCE compilation level if compilation_config.level != CompilationLevel.DYNAMO_ONCE: - logger.info("[TPU] Forcing DYNAMO_ONCE compilation level, and " - "disabling cudagraph.") + logger.info( + "[TPU] Forcing DYNAMO_ONCE compilation level, and disabling cudagraph." 
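Both the CUDA and ROCm capability checks above point at the same escape hatch when bfloat16 is unavailable; for reference, the corresponding usage looks like this (the model name is a placeholder):

```python
# CLI:  vllm serve org/some-model --dtype=half
# Python API equivalent:
from vllm import LLM

llm = LLM(model="org/some-model", dtype="half")  # placeholder model name
```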
+ ) compilation_config.level = CompilationLevel.DYNAMO_ONCE - if compilation_config.cudagraph_mode is None or \ - compilation_config.cudagraph_mode.max_cudagraph_mode() \ - != CUDAGraphMode.NONE: - logger.info("[TPU] CUDA graph is not supported on TPU, " - "disabling cudagraphs.") + if ( + compilation_config.cudagraph_mode is None + or compilation_config.cudagraph_mode.max_cudagraph_mode() + != CUDAGraphMode.NONE + ): + logger.info( + "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs." + ) compilation_config.cudagraph_mode = CUDAGraphMode.NONE if compilation_config.backend == "": compilation_config.backend = "openxla" - assert vllm_config.speculative_config is None, \ + assert vllm_config.speculative_config is None, ( "TPU does not support speculative decoding" + ) model_config = vllm_config.model_config - if model_config is not None and model_config.dtype in (torch.float16, - torch.float32): + if model_config is not None and model_config.dtype in ( + torch.float16, + torch.float32, + ): logger.warning( "The TPU backend currently does not support %s. " - "Using bfloat16 instead.", model_config.dtype) + "Using bfloat16 instead.", + model_config.dtype, + ) model_config.dtype = torch.bfloat16 from vllm.v1.attention.backends.pallas import PallasAttentionBackend - cache_config.block_size = PallasAttentionBackend.get_page_size( - vllm_config) # type: ignore[assignment] + + cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config @@ -147,24 +164,31 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker" assert not vllm_config.speculative_config, ( - "Speculative decoding is not yet supported for TPU backend") + "Speculative decoding is not yet supported for TPU backend" + ) - if scheduler_config.is_multimodal_model and not \ - scheduler_config.disable_chunked_mm_input: - logger.warning("TPU does not support running Multimodal models"\ - " without setting `--disable_chunked_mm_input`. " \ - "Forcing --disable_chunked_mm_input.") + if ( + scheduler_config.is_multimodal_model + and not scheduler_config.disable_chunked_mm_input + ): + logger.warning( + "TPU does not support running Multimodal models" + " without setting `--disable_chunked_mm_input`. " + "Forcing --disable_chunked_mm_input." + ) scheduler_config.disable_chunked_mm_input = True if model_config and model_config.use_mla: logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") + "prefill and prefix caching to be disabled." 
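The CPU, TPU, and XPU platforms all apply this same MLA coercion: chunked prefill off and the batched-token budget raised to at least the model length. Factored out as a sketch (the helper name is invented; the config fields are the real ones used above):

```python
def force_mla_scheduler_settings(scheduler_config, default_max_batched_tokens: int) -> None:
    """Sketch of the shared coercion in the CPU/TPU/XPU hunks above."""
    scheduler_config.enable_chunked_prefill = False
    scheduler_config.chunked_prefill_enabled = False
    scheduler_config.max_num_batched_tokens = max(
        scheduler_config.max_model_len, default_max_batched_tokens
    )
```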
+ ) vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) + DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) @classmethod def is_pin_memory_available(cls): @@ -187,13 +211,16 @@ def validate_request( processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" - if (isinstance(params, SamplingParams) - and params.sampling_type == SamplingType.RANDOM_SEED): + if ( + isinstance(params, SamplingParams) + and params.sampling_type == SamplingType.RANDOM_SEED + ): raise ValueError("Torch XLA does not support per-request seed.") @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: + def is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: return True @classmethod @@ -206,8 +233,7 @@ def insert_blocks_to_device( dst_block_indices: torch.Tensor, ) -> None: torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True) - dst_cache[dst_block_indices] = src_cache[src_block_indices].to( - dst_cache.device) + dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device) @classmethod @torch.compile(backend="openxla") @@ -218,7 +244,7 @@ def swap_out_blocks_to_host( src_block_indices: torch.Tensor, dst_block_indices: torch.Tensor, ) -> None: - """ tpu blocks to cpu blocks""" + """tpu blocks to cpu blocks""" torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True) dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu() @@ -229,6 +255,7 @@ def use_sync_weight_loader(cls) -> bool: try: from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform + TpuPlatform = TpuCommonsPlatform # type: ignore USE_TPU_COMMONS = True except ImportError: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 3ccbae58726f..2f2f3ab8b9d9 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -35,14 +35,26 @@ class XPUPlatform(Platform): device_control_env_var: str = "ZE_AFFINITY_MASK" @classmethod - def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool, use_sparse) -> str: + def import_core_kernels(cls) -> None: + pass + + @classmethod + def get_attn_backend_cls( + cls, + selected_backend: "_Backend", + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + use_v1: bool, + use_mla: bool, + has_sink: bool, + use_sparse, + ) -> str: from vllm.attention.backends.registry import _Backend + if use_sparse: - raise NotImplementedError( - "Sparse Attention is not supported on XPU.") + raise NotImplementedError("Sparse Attention is not supported on XPU.") use_v1 = envs.VLLM_USE_V1 if not use_v1: raise ValueError("XPU backend only supports V1.") @@ -57,20 +69,24 @@ def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, elif selected_backend: raise ValueError( f"Invalid attention backend for {cls.device_name}, " - f"with use_v1: {use_v1} use_mla: {use_mla}") + f"with use_v1: {use_v1} use_mla: {use_mla}" + ) logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" @classmethod - def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, - model_config: "ModelConfig") -> bool: + def 
is_kv_cache_dtype_supported( + cls, kv_cache_dtype: str, model_config: "ModelConfig" + ) -> bool: """ Check if the kv_cache_dtype is supported. XPU only support fp8 kv cache with triton backend. """ - if envs.is_set("VLLM_ATTENTION_BACKEND") and \ - envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN": + if ( + envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN" + ): return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"] return False @@ -118,12 +134,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # lazy import to avoid circular import from vllm.config import CompilationLevel, CUDAGraphMode + compilation_config = vllm_config.compilation_config if compilation_config.compile_sizes is None: compilation_config.compile_sizes = [] - assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, \ + assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, ( "CUDA graph mode should be NONE on XPU" + ) if vllm_config.lora_config is not None: compilation_config.level = CompilationLevel.NO_COMPILATION @@ -144,31 +162,38 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn": os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" logger.warning( - "Please use spawn as start method if you want to use mp.") - elif (parallel_config.distributed_executor_backend != "ray" - and parallel_config.distributed_executor_backend != "uni" - and parallel_config.distributed_executor_backend - != "external_launcher"): + "Please use spawn as start method if you want to use mp." + ) + elif ( + parallel_config.distributed_executor_backend != "ray" + and parallel_config.distributed_executor_backend != "uni" + and parallel_config.distributed_executor_backend != "external_launcher" + ): logger.warning( "%s is not supported on XPU, fallback to ray distributed" " executor backend.", - parallel_config.distributed_executor_backend) + parallel_config.distributed_executor_backend, + ) parallel_config.distributed_executor_backend = "ray" if model_config and model_config.use_mla: logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") + "prefill and prefix caching to be disabled." + ) vllm_config.scheduler_config.enable_chunked_prefill = False vllm_config.scheduler_config.chunked_prefill_enabled = False vllm_config.scheduler_config.max_num_batched_tokens = max( vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) + DEFAULT_MAX_NUM_BATCHED_TOKENS, + ) from vllm.v1.attention.backends.utils import set_kv_cache_layout set_kv_cache_layout("NHD") - logger.info("Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; " - "only NHD layout is supported by XPU attention kernels.") + logger.info( + "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; " + "only NHD layout is supported by XPU attention kernels." + ) @classmethod def support_hybrid_kv_cache(cls) -> bool: @@ -183,9 +208,9 @@ def is_pin_memory_available(cls): return True @classmethod - def get_current_memory_usage(cls, - device: Optional[torch.types.Device] = None - ) -> float: + def get_current_memory_usage( + cls, device: Optional[torch.types.Device] = None + ) -> float: torch.xpu.reset_peak_memory_stats(device) return torch.xpu.max_memory_allocated(device) @@ -215,7 +240,8 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype): raise ValueError( "Intel Arc A770 have bfloat16 accuracy known issue. 
" "You can use float16 instead by explicitly setting the " - "`dtype` flag in CLI, for example: --dtype=half.") + "`dtype` flag in CLI, for example: --dtype=half." + ) @classmethod def opaque_attention_op(cls) -> bool: diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 1a1760df82c0..0c83d49c4593 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) -DEFAULT_PLUGINS_GROUP = 'vllm.general_plugins' +DEFAULT_PLUGINS_GROUP = "vllm.general_plugins" # make sure one process only loads plugins once plugins_loaded = False @@ -16,6 +16,7 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: import sys + if sys.version_info < (3, 10): from importlib_metadata import entry_points else: @@ -29,7 +30,7 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: return {} # Check if the only discovered plugin is the default one - is_default_group = (group == DEFAULT_PLUGINS_GROUP) + is_default_group = group == DEFAULT_PLUGINS_GROUP # Use INFO for non-default groups and DEBUG for the default group log_level = logger.debug if is_default_group else logger.info @@ -38,8 +39,10 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: log_level("- %s -> %s", plugin.name, plugin.value) if allowed_plugins is None: - log_level("All plugins in this group will be loaded. " - "Set `VLLM_PLUGINS` to control which plugins to load.") + log_level( + "All plugins in this group will be loaded. " + "Set `VLLM_PLUGINS` to control which plugins to load." + ) plugins = dict[str, Callable[[], Any]]() for plugin in discovered_plugins: diff --git a/vllm/plugins/io_processors/__init__.py b/vllm/plugins/io_processors/__init__.py index 3b17211b1b83..7a914442c4ab 100644 --- a/vllm/plugins/io_processors/__init__.py +++ b/vllm/plugins/io_processors/__init__.py @@ -4,7 +4,6 @@ from __future__ import annotations import logging -from typing import Optional from vllm.config import VllmConfig from vllm.plugins import load_plugins_by_group @@ -15,8 +14,8 @@ def get_io_processor( - vllm_config: VllmConfig, - plugin_from_init: Optional[str] = None) -> IOProcessor | None: + vllm_config: VllmConfig, plugin_from_init: str | None = None +) -> IOProcessor | None: # Input.Output processors are loaded as plugins under the # 'vllm.io_processor_plugins' group. Similar to platform # plugins, these plugins register a function that returns the class @@ -39,8 +38,9 @@ def get_io_processor( logger.debug("IOProcessor plugin to be loaded %s", model_plugin) # Load all installed plugin in the group - multimodal_data_processor_plugins = \ - load_plugins_by_group('vllm.io_processor_plugins') + multimodal_data_processor_plugins = load_plugins_by_group( + "vllm.io_processor_plugins" + ) loadable_plugins = {} for name, func in multimodal_data_processor_plugins.items(): @@ -54,14 +54,16 @@ def get_io_processor( num_available_plugins = len(loadable_plugins.keys()) if num_available_plugins == 0: - raise ValueError("No IOProcessor plugins installed" - f" but one is required ({model_plugin}).") + raise ValueError( + f"No IOProcessor plugins installed but one is required ({model_plugin})." + ) if model_plugin not in loadable_plugins: raise ValueError( f"The model requires the '{model_plugin}' IO Processor plugin " "but it is not installed. 
" - f"Available plugins: {list(loadable_plugins.keys())}") + f"Available plugins: {list(loadable_plugins.keys())}" + ) activated_plugin_cls = loadable_plugins[model_plugin] diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py index 62b224cac5e5..84af40d01c43 100644 --- a/vllm/plugins/io_processors/interface.py +++ b/vllm/plugins/io_processors/interface.py @@ -10,12 +10,11 @@ from vllm.inputs.data import PromptType from vllm.outputs import PoolingRequestOutput -IOProcessorInput = TypeVar('IOProcessorInput') -IOProcessorOutput = TypeVar('IOProcessorOutput') +IOProcessorInput = TypeVar("IOProcessorInput") +IOProcessorOutput = TypeVar("IOProcessorOutput") class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): - def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config @@ -37,10 +36,12 @@ async def pre_process_async( return self.pre_process(prompt, request_id, **kwargs) @abstractmethod - def post_process(self, - model_output: Sequence[PoolingRequestOutput], - request_id: Optional[str] = None, - **kwargs) -> IOProcessorOutput: + def post_process( + self, + model_output: Sequence[PoolingRequestOutput], + request_id: Optional[str] = None, + **kwargs, + ) -> IOProcessorOutput: raise NotImplementedError async def post_process_async( @@ -52,8 +53,9 @@ async def post_process_async( # We cannot guarantee outputs are returned in the same order they were # fed to vLLM. # Let's sort them by id before post_processing - sorted_output = sorted([(i, item) async for i, item in model_output], - key=lambda output: output[0]) + sorted_output = sorted( + [(i, item) async for i, item in model_output], key=lambda output: output[0] + ) collected_output = [output[1] for output in sorted_output] return self.post_process(collected_output, request_id, **kwargs) @@ -63,5 +65,6 @@ def parse_request(self, request: Any) -> IOProcessorInput: @abstractmethod def output_to_response( - self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: + self, plugin_output: IOProcessorOutput + ) -> IOProcessorResponse: raise NotImplementedError diff --git a/vllm/plugins/lora_resolvers/filesystem_resolver.py b/vllm/plugins/lora_resolvers/filesystem_resolver.py index b999d07a6eb7..c3255af45702 100644 --- a/vllm/plugins/lora_resolvers/filesystem_resolver.py +++ b/vllm/plugins/lora_resolvers/filesystem_resolver.py @@ -10,25 +10,29 @@ class FilesystemResolver(LoRAResolver): - def __init__(self, lora_cache_dir: str): self.lora_cache_dir = lora_cache_dir - async def resolve_lora(self, base_model_name: str, - lora_name: str) -> Optional[LoRARequest]: + async def resolve_lora( + self, base_model_name: str, lora_name: str + ) -> Optional[LoRARequest]: lora_path = os.path.join(self.lora_cache_dir, lora_name) if os.path.exists(lora_path): - adapter_config_path = os.path.join(self.lora_cache_dir, lora_name, - "adapter_config.json") + adapter_config_path = os.path.join( + self.lora_cache_dir, lora_name, "adapter_config.json" + ) if os.path.exists(adapter_config_path): with open(adapter_config_path) as file: adapter_config = json.load(file) - if adapter_config["peft_type"] == "LORA" and adapter_config[ - "base_model_name_or_path"] == base_model_name: - lora_request = LoRARequest(lora_name=lora_name, - lora_int_id=abs( - hash(lora_name)), - lora_path=lora_path) + if ( + adapter_config["peft_type"] == "LORA" + and adapter_config["base_model_name_or_path"] == base_model_name + ): + lora_request = LoRARequest( + lora_name=lora_name, + lora_int_id=abs(hash(lora_name)), + 
lora_path=lora_path, + ) return lora_request return None @@ -38,13 +42,12 @@ def register_filesystem_resolver(): lora_cache_dir = envs.VLLM_LORA_RESOLVER_CACHE_DIR if lora_cache_dir: - if not os.path.exists(lora_cache_dir) or not os.path.isdir( - lora_cache_dir): + if not os.path.exists(lora_cache_dir) or not os.path.isdir(lora_cache_dir): raise ValueError( "VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid directory \ - for Filesystem Resolver plugin to function") + for Filesystem Resolver plugin to function" + ) fs_resolver = FilesystemResolver(lora_cache_dir) - LoRAResolverRegistry.register_resolver("Filesystem Resolver", - fs_resolver) + LoRAResolverRegistry.register_resolver("Filesystem Resolver", fs_resolver) return diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index a6313367457a..f7a53503e584 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -14,16 +14,17 @@ class PoolingParams( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True, +): # type: ignore[call-arg] """API parameters for pooling models. Attributes: truncate_prompt_tokens: Controls prompt truncation. Set to -1 to use the model's default truncation size. Set to k to keep only the last k tokens (left truncation). - Set to None to disable truncation. + Set to None to disable truncation. normalize: Whether to normalize the embeddings outputs. dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. @@ -33,8 +34,7 @@ class PoolingParams( """ # --8<-- [start:common-pooling-params] - truncate_prompt_tokens: Optional[Annotated[int, - msgspec.Meta(ge=-1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None # --8<-- [end:common-pooling-params] ## for embeddings models @@ -67,8 +67,12 @@ class PoolingParams( @property def all_parameters(self) -> list[str]: return [ - "dimensions", "normalize", "activation", "softmax", "step_tag_id", - "returned_token_ids" + "dimensions", + "normalize", + "activation", + "softmax", + "step_tag_id", + "returned_token_ids", ] @property @@ -84,10 +88,9 @@ def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" return deepcopy(self) - def verify(self, - task: PoolingTask, - model_config: Optional["ModelConfig"] = None) -> None: - + def verify( + self, task: PoolingTask, model_config: Optional["ModelConfig"] = None + ) -> None: if self.task is None: self.task = task elif self.task != task: @@ -102,10 +105,9 @@ def verify(self, self._set_default_parameters(model_config) self._verify_valid_parameters() - def _merge_default_parameters(self, - model_config: Optional["ModelConfig"] = None - ) -> None: - + def _merge_default_parameters( + self, model_config: Optional["ModelConfig"] = None + ) -> None: if model_config is None: return @@ -132,8 +134,8 @@ def _set_default_parameters(self, model_config: Optional["ModelConfig"]): if not model_config.is_matryoshka: raise ValueError( f'Model "{model_config.served_model_name}" does not ' - f'support matryoshka representation, ' - f'changing output dimensions will lead to poor results.' + f"support matryoshka representation, " + f"changing output dimensions will lead to poor results." 
) mds = model_config.matryoshka_dimensions @@ -141,9 +143,10 @@ def _set_default_parameters(self, model_config: Optional["ModelConfig"]): if self.dimensions not in mds: raise ValueError( f'Model "{model_config.served_model_name}" ' - f'only supports {str(mds)} matryoshka dimensions, ' - f'use other output dimensions will ' - f'lead to poor results.') + f"only supports {str(mds)} matryoshka dimensions, " + f"use other output dimensions will " + f"lead to poor results." + ) elif self.dimensions < 1: raise ValueError("Dimensions must be greater than 0") @@ -172,20 +175,24 @@ def _verify_valid_parameters(self): raise ValueError( f"Task {self.task} only supports {valid_parameters} " f"parameters, does not support " - f"{invalid_parameters} parameters") + f"{invalid_parameters} parameters" + ) def __repr__(self) -> str: - return (f"PoolingParams(" - f"task={self.task}, " - f"normalize={self.normalize}, " - f"dimensions={self.dimensions}, " - f"activation={self.activation}, " - f"softmax={self.softmax}, " - f"step_tag_id={self.step_tag_id}, " - f"returned_token_ids={self.returned_token_ids}, " - f"requires_token_ids={self.requires_token_ids}, " - f"extra_kwargs={self.extra_kwargs})") + return ( + f"PoolingParams(" + f"task={self.task}, " + f"normalize={self.normalize}, " + f"dimensions={self.dimensions}, " + f"activation={self.activation}, " + f"softmax={self.softmax}, " + f"step_tag_id={self.step_tag_id}, " + f"returned_token_ids={self.returned_token_ids}, " + f"requires_token_ids={self.requires_token_ids}, " + f"extra_kwargs={self.extra_kwargs})" + ) def __post_init__(self) -> None: - assert self.output_kind == RequestOutputKind.FINAL_ONLY,\ + assert self.output_kind == RequestOutputKind.FINAL_ONLY, ( "For pooling output_kind has to be FINAL_ONLY" + ) diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py index 41136f738c28..fea299b287f9 100644 --- a/vllm/profiler/layerwise_profile.py +++ b/vllm/profiler/layerwise_profile.py @@ -12,21 +12,26 @@ from torch.autograd.profiler import FunctionEvent from torch.profiler import ProfilerActivity, profile -from vllm.profiler.utils import (TablePrinter, event_has_module, - event_is_torch_op, event_module_repr, - event_torch_op_stack_trace, indent_string) +from vllm.profiler.utils import ( + TablePrinter, + event_has_module, + event_is_torch_op, + event_module_repr, + event_torch_op_stack_trace, + indent_string, +) @dataclass class _ModuleTreeNode: event: _ProfilerEvent - parent: Optional['_ModuleTreeNode'] = None - children: list['_ModuleTreeNode'] = field(default_factory=list) + parent: Optional["_ModuleTreeNode"] = None + children: list["_ModuleTreeNode"] = field(default_factory=list) trace: str = "" @property def is_leaf(self): - return (self.event.children is None or len(self.event.children) == 0) + return self.event.children is None or len(self.event.children) == 0 @property def is_torch_op(self): @@ -34,8 +39,10 @@ def is_torch_op(self): @property def is_cuda(self): - return (self.event.tag == _EventType.Kineto - and self.event.typed[1].device_type == DeviceType.CUDA) + return ( + self.event.tag == _EventType.Kineto + and self.event.typed[1].device_type == DeviceType.CUDA + ) @dataclass @@ -68,8 +75,7 @@ class _StatsTreeNode: @dataclass class LayerwiseProfileResults(profile): _kineto_results: _ProfilerResult - _kineto_event_correlation_map: dict[int, - list[_KinetoEvent]] = field(init=False) + _kineto_event_correlation_map: dict[int, list[_KinetoEvent]] = field(init=False) _event_correlation_map: dict[int, 
list[FunctionEvent]] = field(init=False) _module_tree: list[_ModuleTreeNode] = field(init=False) _model_stats_tree: list[_StatsTreeNode] = field(init=False) @@ -84,11 +90,9 @@ def __post_init__(self): self._build_stats_trees() def print_model_table(self, column_widths: dict[str, int] = None): - _column_widths = dict(name=60, - cpu_time_us=12, - cuda_time_us=12, - pct_cuda_time=12, - trace=60) + _column_widths = dict( + name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60 + ) if column_widths: _column_widths.update(**column_widths) filtered_model_table = [ @@ -99,78 +103,76 @@ def print_model_table(self, column_widths: dict[str, int] = None): TablePrinter(ModelStatsEntry, _column_widths).print_table( self._indent_row_names_based_on_depth( filtered_model_table, - indent_style=lambda indent: "|" + "-" * indent + " ")) + indent_style=lambda indent: "|" + "-" * indent + " ", + ) + ) def print_summary_table(self, column_widths: dict[str, int] = None): - _column_widths = dict(name=80, - cuda_time_us=12, - pct_cuda_time=12, - invocations=15) + _column_widths = dict( + name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15 + ) if column_widths: _column_widths.update(**column_widths) - filtered_summary_table = [(depth, row) - for depth, row in self._flatten_stats_tree( - self._summary_stats_tree) - if row.cuda_time_us > 0] + filtered_summary_table = [ + (depth, row) + for depth, row in self._flatten_stats_tree(self._summary_stats_tree) + if row.cuda_time_us > 0 + ] TablePrinter(SummaryStatsEntry, _column_widths).print_table( self._indent_row_names_based_on_depth( filtered_summary_table, - indent_style=lambda indent: "|" + "-" * indent + " ")) + indent_style=lambda indent: "|" + "-" * indent + " ", + ) + ) def export_model_stats_table_csv(self, filename: str): - df = pd.DataFrame([ - asdict(row) - for _, row in self._flatten_stats_tree(self._model_stats_tree) - ]) + df = pd.DataFrame( + [asdict(row) for _, row in self._flatten_stats_tree(self._model_stats_tree)] + ) df.to_csv(filename) def export_summary_stats_table_csv(self, filename: str): - df = pd.DataFrame([ - asdict(row) - for _, row in self._flatten_stats_tree(self._summary_stats_tree) - ]) + df = pd.DataFrame( + [ + asdict(row) + for _, row in self._flatten_stats_tree(self._summary_stats_tree) + ] + ) df.to_csv(filename) def convert_stats_to_dict(self) -> dict[str, Any]: return { - "metadata": { - "num_running_seqs": self.num_running_seqs - }, - "summary_stats": - self._convert_stats_tree_to_dict(self._summary_stats_tree), - "model_stats": - self._convert_stats_tree_to_dict(self._model_stats_tree) + "metadata": {"num_running_seqs": self.num_running_seqs}, + "summary_stats": self._convert_stats_tree_to_dict(self._summary_stats_tree), + "model_stats": self._convert_stats_tree_to_dict(self._model_stats_tree), } @staticmethod - def _indent_row_names_based_on_depth(depths_rows: list[tuple[int, - StatsEntry]], - indent_style: Union[Callable[[int], - str], - str] = " "): + def _indent_row_names_based_on_depth( + depths_rows: list[tuple[int, StatsEntry]], + indent_style: Union[Callable[[int], str], str] = " ", + ): indented_rows = [] for depth, row in depths_rows: if row.cuda_time_us == 0: continue indented_row = copy.deepcopy(row) - indented_row.name = indent_string(indented_row.name, depth, - indent_style) + indented_row.name = indent_string(indented_row.name, depth, indent_style) indented_rows.append(indented_row) return indented_rows def _build_correlation_map(self): self._kineto_event_correlation_map = defaultdict(list) for 
event in self._kineto_results.events(): - self._kineto_event_correlation_map[event.correlation_id()].append( - event) + self._kineto_event_correlation_map[event.correlation_id()].append(event) def _build_module_tree(self): self._module_tree = [] event_tree = self._kineto_results.experimental_event_tree() - def _df_traversal(event: _ProfilerEvent, - curr_node: Optional[_ModuleTreeNode] = None): - + def _df_traversal( + event: _ProfilerEvent, curr_node: Optional[_ModuleTreeNode] = None + ): # For the tensor parallel case for now only look at task 1 if event.start_tid != 1: return @@ -183,13 +185,15 @@ def _df_traversal(event: _ProfilerEvent, self._module_tree.append(node) curr_node = node - is_leaf = (event.children is None or len(event.children) == 0) + is_leaf = event.children is None or len(event.children) == 0 if is_leaf and curr_node: node = _ModuleTreeNode( event=event, parent=curr_node, trace=event_torch_op_stack_trace( - event, until=lambda x: event_has_module(x))) + event, until=lambda x: event_has_module(x) + ), + ) curr_node.children.append(node) curr_node = node @@ -203,31 +207,31 @@ def _get_kineto_gpu_event(self, node: _ModuleTreeNode): if node.event.tag != _EventType.Kineto: return None correlated_kineto_events = self._kineto_event_correlation_map.get( - node.event.correlation_id, []) - iterator = (x for x in correlated_kineto_events - if x.device_type() == DeviceType.CUDA - and x.name() == node.event.name) + node.event.correlation_id, [] + ) + iterator = ( + x + for x in correlated_kineto_events + if x.device_type() == DeviceType.CUDA and x.name() == node.event.name + ) return next(iterator, None) def _cumulative_cuda_time(self, node: _ModuleTreeNode): - 'Return cuda time in microseconds' + "Return cuda time in microseconds" def _cumulative_cuda_time_recursive(node: _ModuleTreeNode): - if node.is_leaf and (gpu_kineto_event := - self._get_kineto_gpu_event(node)): + if node.is_leaf and (gpu_kineto_event := self._get_kineto_gpu_event(node)): return gpu_kineto_event.duration_ns() / 1000.0 else: cumulative_cuda_time = 0 for child in node.children: - cumulative_cuda_time += _cumulative_cuda_time_recursive( - child) + cumulative_cuda_time += _cumulative_cuda_time_recursive(child) return cumulative_cuda_time return _cumulative_cuda_time_recursive(node) def _total_cuda_time(self): - return sum( - [self._cumulative_cuda_time(root) for root in self._module_tree]) + return sum([self._cumulative_cuda_time(root) for root in self._module_tree]) def _build_stats_trees(self): summary_dict: dict[str, _StatsTreeNode] = {} @@ -239,38 +243,42 @@ def pct_cuda_time(cuda_time_us): def build_summary_stats_tree_df( node: _ModuleTreeNode, parent: Optional[_StatsTreeNode] = None, - summary_trace: tuple[str] = ()): - + summary_trace: tuple[str] = (), + ): if event_has_module(node.event): name = event_module_repr(node.event) cuda_time_us = self._cumulative_cuda_time(node) - elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + elif gpu_kineto_event := self._get_kineto_gpu_event(node): name = gpu_kineto_event.name() cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 else: return None - summary_trace = summary_trace + (name, ) + summary_trace = summary_trace + (name,) if summary_trace in summary_dict: entry = summary_dict[summary_trace].entry entry.cuda_time_us += cuda_time_us entry.invocations += 1 entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us) else: - new_node = _StatsTreeNode(entry=SummaryStatsEntry( - name=name, - cuda_time_us=cuda_time_us, - 
pct_cuda_time=pct_cuda_time(cuda_time_us), - invocations=1), - children=[], - parent=parent) + new_node = _StatsTreeNode( + entry=SummaryStatsEntry( + name=name, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + invocations=1, + ), + children=[], + parent=parent, + ) if parent: parent.children.append(new_node) summary_dict[summary_trace] = new_node for child in node.children: - build_summary_stats_tree_df(child, summary_dict[summary_trace], - summary_trace) + build_summary_stats_tree_df( + child, summary_dict[summary_trace], summary_trace + ) return summary_dict[summary_trace] @@ -278,14 +286,17 @@ def build_summary_stats_tree_df( for root in self._module_tree: self._summary_stats_tree.append(build_summary_stats_tree_df(root)) - def build_model_stats_tree_df(node: _ModuleTreeNode, - parent: Optional[_StatsTreeNode] = None): - if event_has_module(node.event, ): + def build_model_stats_tree_df( + node: _ModuleTreeNode, parent: Optional[_StatsTreeNode] = None + ): + if event_has_module( + node.event, + ): name = event_module_repr(node.event) cuda_time_us = self._cumulative_cuda_time(node) cpu_time_us = node.event.duration_time_ns / 1000 trace = "" - elif (gpu_kineto_event := self._get_kineto_gpu_event(node)): + elif gpu_kineto_event := self._get_kineto_gpu_event(node): name = gpu_kineto_event.name() cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0 cpu_time_us = 0 @@ -293,14 +304,17 @@ def build_model_stats_tree_df(node: _ModuleTreeNode, else: return None - new_node = _StatsTreeNode(entry=ModelStatsEntry( - name=name, - cpu_time_us=cpu_time_us, - cuda_time_us=cuda_time_us, - pct_cuda_time=pct_cuda_time(cuda_time_us), - trace=trace), - parent=parent, - children=[]) + new_node = _StatsTreeNode( + entry=ModelStatsEntry( + name=name, + cpu_time_us=cpu_time_us, + cuda_time_us=cuda_time_us, + pct_cuda_time=pct_cuda_time(cuda_time_us), + trace=trace, + ), + parent=parent, + children=[], + ) if parent: parent.children.append(new_node) @@ -314,7 +328,8 @@ def build_model_stats_tree_df(node: _ModuleTreeNode, self._model_stats_tree.append(build_model_stats_tree_df(root)) def _flatten_stats_tree( - self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]: + self, tree: list[_StatsTreeNode] + ) -> list[tuple[int, StatsEntry]]: entries: list[tuple[int, StatsEntry]] = [] def df_traversal(node: _StatsTreeNode, depth=0): @@ -327,15 +342,11 @@ def df_traversal(node: _StatsTreeNode, depth=0): return entries - def _convert_stats_tree_to_dict(self, - tree: list[_StatsTreeNode]) -> list[dict]: + def _convert_stats_tree_to_dict(self, tree: list[_StatsTreeNode]) -> list[dict]: root_dicts: list[dict] = [] def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]): - curr_json_list.append({ - "entry": asdict(node.entry), - "children": [] - }) + curr_json_list.append({"entry": asdict(node.entry), "children": []}) for child in node.children: df_traversal(child, curr_json_list[-1]["children"]) @@ -346,7 +357,6 @@ def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]): class layerwise_profile(profile): - def __init__(self, num_running_seqs: Optional[int] = None): """ layerwise profile constructor. 
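The hunks above and below only reformat `LayerwiseProfileResults` and the `layerwise_profile` context manager; the public surface is unchanged. As orientation, here is a minimal usage sketch — the dummy matmul workload and the CSV filename are illustrative assumptions, not part of this patch:

```python
# Sketch only: profile an arbitrary torch workload and dump the tables.
# The bare matmul is a stand-in; in practice this wraps real engine work.
import torch

from vllm.profiler.layerwise_profile import layerwise_profile

prof = layerwise_profile(num_running_seqs=1)
with prof:
    a = torch.randn(512, 512)
    (a @ a).sum().item()

# __exit__ populates `results` with a LayerwiseProfileResults instance;
# without GPU kernels in the trace the tables will simply be empty.
prof.results.print_summary_table()
prof.results.export_summary_stats_table_csv("summary_stats.csv")
```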
@@ -361,7 +371,8 @@ def __init__(self, num_running_seqs: Optional[int] = None): record_shapes=True, with_stack=True, with_modules=True, - experimental_config=_ExperimentalConfig(verbose=True)) + experimental_config=_ExperimentalConfig(verbose=True), + ) self.num_running_seqs = num_running_seqs @@ -371,5 +382,5 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) self.results = LayerwiseProfileResults( - self.profiler.kineto_results, - num_running_seqs=self.num_running_seqs) + self.profiler.kineto_results, num_running_seqs=self.num_running_seqs + ) diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py index 9f0f56a15fd5..b3607fbecde7 100644 --- a/vllm/profiler/utils.py +++ b/vllm/profiler/utils.py @@ -30,9 +30,9 @@ def trim_string_back(string, width): class TablePrinter: - - def __init__(self, row_cls: type[dataclasses.dataclass], - column_widths: dict[str, int]): + def __init__( + self, row_cls: type[dataclasses.dataclass], column_widths: dict[str, int] + ): self.row_cls = row_cls self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] self.column_widths = column_widths @@ -46,16 +46,18 @@ def print_table(self, rows: list[dataclasses.dataclass]): def _print_header(self): for i, f in enumerate(self.fieldnames): - last = (i == len(self.fieldnames) - 1) + last = i == len(self.fieldnames) - 1 col_width = self.column_widths[f] - print(trim_string_back(f, col_width).ljust(col_width), - end=" | " if not last else "\n") + print( + trim_string_back(f, col_width).ljust(col_width), + end=" | " if not last else "\n", + ) def _print_row(self, row): assert isinstance(row, self.row_cls) for i, f in enumerate(self.fieldnames): - last = (i == len(self.fieldnames) - 1) + last = i == len(self.fieldnames) - 1 col_width = self.column_widths[f] val = getattr(row, f) @@ -75,9 +77,9 @@ def _print_line(self): print("=" * (total_col_width + 3 * (len(self.column_widths) - 1))) -def indent_string(string: str, - indent: int, - indent_style: Union[Callable[[int], str], str] = " ") -> str: +def indent_string( + string: str, indent: int, indent_style: Union[Callable[[int], str], str] = " " +) -> str: if indent: if isinstance(indent_style, str): return indent_style * indent + string @@ -111,15 +113,14 @@ def event_arg_repr(arg) -> str: elif isinstance(arg, tuple): return f"({', '.join([event_arg_repr(x) for x in arg])})" else: - assert isinstance(arg, - _TensorMetadata), f"Unsupported type: {type(arg)}" - sizes_str = ', '.join([str(x) for x in arg.sizes]) + assert isinstance(arg, _TensorMetadata), f"Unsupported type: {type(arg)}" + sizes_str = ", ".join([str(x) for x in arg.sizes]) return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]" def event_torch_op_repr(event: _ProfilerEvent) -> str: assert event.tag == _EventType.TorchOp - args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs]) + args_str = ", ".join([event_arg_repr(x) for x in event.typed[1].inputs]) return f"{event.name}({args_str})".replace("aten::", "") @@ -127,15 +128,17 @@ def event_module_repr(event: _ProfilerEvent) -> str: assert event_has_module(event) module = event.typed[1].module if module.parameters and len(module.parameters) > 0: - args_str = ', '.join( - [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters]) + args_str = ", ".join( + [f"{x[0]}={event_arg_repr(x[1])}" for x in module.parameters] + ) return f"{module.cls_name}({args_str})" else: return module.cls_name -def event_torch_op_stack_trace(curr_event: _ProfilerEvent, - until: 
Callable[[_ProfilerEvent], bool]) -> str: +def event_torch_op_stack_trace( + curr_event: _ProfilerEvent, until: Callable[[_ProfilerEvent], bool] +) -> str: trace = "" curr_event = curr_event.parent while curr_event and not until(curr_event): diff --git a/vllm/ray/lazy_utils.py b/vllm/ray/lazy_utils.py index bb3535579cfd..64b5f51571a3 100644 --- a/vllm/ray/lazy_utils.py +++ b/vllm/ray/lazy_utils.py @@ -6,6 +6,7 @@ def is_ray_initialized(): """Check if Ray is initialized.""" try: import ray + return ray.is_initialized() except ImportError: return False @@ -16,7 +17,10 @@ def is_in_ray_actor(): try: import ray - return (ray.is_initialized() - and ray.get_runtime_context().get_actor_id() is not None) + + return ( + ray.is_initialized() + and ray.get_runtime_context().get_actor_id() is not None + ) except ImportError: return False diff --git a/vllm/ray/ray_env.py b/vllm/ray/ray_env.py index f6a994bb3c22..a89e55bd7e4b 100644 --- a/vllm/ray/ray_env.py +++ b/vllm/ray/ray_env.py @@ -14,7 +14,8 @@ # This file contains a list of env vars that should not be copied # from the driver to the Ray workers. RAY_NON_CARRY_OVER_ENV_VARS_FILE = os.path.join( - CONFIG_HOME, "ray_non_carry_over_env_vars.json") + CONFIG_HOME, "ray_non_carry_over_env_vars.json" +) try: if os.path.exists(RAY_NON_CARRY_OVER_ENV_VARS_FILE): @@ -25,13 +26,16 @@ except json.JSONDecodeError: logger.warning( "Failed to parse %s. Using an empty set for non-carry-over env vars.", - RAY_NON_CARRY_OVER_ENV_VARS_FILE) + RAY_NON_CARRY_OVER_ENV_VARS_FILE, + ) RAY_NON_CARRY_OVER_ENV_VARS = set() -def get_env_vars_to_copy(exclude_vars: Optional[set[str]] = None, - additional_vars: Optional[set[str]] = None, - destination: Optional[str] = None) -> set[str]: +def get_env_vars_to_copy( + exclude_vars: Optional[set[str]] = None, + additional_vars: Optional[set[str]] = None, + destination: Optional[str] = None, +) -> set[str]: """ Get the environment variables to copy to downstream Ray actors. 
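For context on the helper being reformatted here, a hedged calling sketch — the variable names passed as `exclude_vars`/`additional_vars` and the `destination` label are made-up examples, not values taken from vLLM:

```python
# Sketch: build the environment to forward to Ray workers from the set of
# variable names this helper selects. The inputs below are illustrative.
import os

from vllm.ray.ray_env import get_env_vars_to_copy

names_to_copy = get_env_vars_to_copy(
    exclude_vars={"SOME_HOST_LOCAL_VAR"},    # assumed example exclusion
    additional_vars={"MY_EXPERIMENT_FLAG"},  # assumed example addition
    destination="Ray workers",               # only used in the log message
)
worker_env = {k: v for k, v in os.environ.items() if k in names_to_copy}
```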
@@ -60,13 +64,17 @@ def get_env_vars_to_copy(exclude_vars: Optional[set[str]] = None, to_destination = " to " + destination if destination is not None else "" - logger.info("RAY_NON_CARRY_OVER_ENV_VARS from config: %s", - RAY_NON_CARRY_OVER_ENV_VARS) - logger.info("Copying the following environment variables%s: %s", - to_destination, - [v for v in env_vars_to_copy if v in os.environ]) logger.info( - "If certain env vars should NOT be copied, add them to " - "%s file", RAY_NON_CARRY_OVER_ENV_VARS_FILE) + "RAY_NON_CARRY_OVER_ENV_VARS from config: %s", RAY_NON_CARRY_OVER_ENV_VARS + ) + logger.info( + "Copying the following environment variables%s: %s", + to_destination, + [v for v in env_vars_to_copy if v in os.environ], + ) + logger.info( + "If certain env vars should NOT be copied, add them to %s file", + RAY_NON_CARRY_OVER_ENV_VARS_FILE, + ) return env_vars_to_copy diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index 3c8a9c6ae0d3..78d3bf35f2a3 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -9,6 +9,7 @@ from .granite_reasoning_parser import GraniteReasoningParser from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser from .mistral_reasoning_parser import MistralReasoningParser +from .olmo3_reasoning_parser import Olmo3ReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser from .seedoss_reasoning_parser import SeedOSSReasoningParser from .step3_reasoning_parser import Step3ReasoningParser @@ -23,6 +24,7 @@ "Qwen3ReasoningParser", "Glm4MoeModelReasoningParser", "MistralReasoningParser", + "Olmo3ReasoningParser", "Step3ReasoningParser", "GptOssReasoningParser", "SeedOSSReasoningParser", diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 320009c2611e..2d93f0702f72 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -13,9 +13,11 @@ from vllm.utils import import_from_path, is_list_of if TYPE_CHECKING: - from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, - ResponsesRequest) + from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ResponsesRequest, + ) from vllm.transformers_utils.tokenizer import AnyTokenizer else: ChatCompletionRequest = Any @@ -128,8 +130,7 @@ def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]: if name in cls.reasoning_parsers: return cls.reasoning_parsers[name] - raise KeyError( - f"reasoning helper: '{name}' not found in reasoning_parsers") + raise KeyError(f"reasoning helper: '{name}' not found in reasoning_parsers") @classmethod def _register_module( @@ -139,8 +140,9 @@ def _register_module( force: bool = True, ) -> None: if not issubclass(module, ReasoningParser): - raise TypeError("module must be subclass of ReasoningParser, " - f"but got {type(module)}") + raise TypeError( + f"module must be subclass of ReasoningParser, but got {type(module)}" + ) if module_name is None: module_name = module.__name__ if isinstance(module_name, str): @@ -148,8 +150,9 @@ def _register_module( for name in module_name: if not force and name in cls.reasoning_parsers: existed_module = cls.reasoning_parsers[name] - raise KeyError(f"{name} is already registered " - f"at {existed_module.__module__}") + raise KeyError( + f"{name} is already registered at {existed_module.__module__}" + ) cls.reasoning_parsers[name] = module @classmethod @@ -168,11 +171,11 @@ def register_module( raise TypeError(f"force must be a boolean, but got 
{type(force)}") # raise the error ahead of time - if not (name is None or isinstance(name, str) - or is_list_of(name, str)): + if not (name is None or isinstance(name, str) or is_list_of(name, str)): raise TypeError( "name must be None, an instance of str, or a sequence of str, " - f"but got {type(name)}") + f"but got {type(name)}" + ) # use it as a normal method: x.register_module(module=SomeClass) if module is not None: @@ -197,6 +200,7 @@ def import_reasoning_parser(cls, plugin_path: str) -> None: try: import_from_path(module_name, plugin_path) except Exception: - logger.exception("Failed to load module '%s' from %s.", - module_name, plugin_path) + logger.exception( + "Failed to load module '%s' from %s.", module_name, plugin_path + ) return diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py index cea4b8601ae7..b4106a4f5794 100644 --- a/vllm/reasoning/basic_parsers.py +++ b/vllm/reasoning/basic_parsers.py @@ -5,8 +5,11 @@ from collections.abc import Sequence from typing import Optional, Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage, ResponsesRequest) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ResponsesRequest, +) from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -14,11 +17,11 @@ class BaseThinkingReasoningParser(ReasoningParser): """ Base class for reasoning parsers that use thinking tokens. - + This class provides common functionality for parsers that use start and end tokens to delimit reasoning content ( e.g., <think>...</think>, <seed:think>...</seed:think>). - + Subclasses must implement the start and end tokens via abstract properties. """ @@ -41,18 +44,19 @@ def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") + "constructor during construction." + ) if not self.start_token or not self.end_token: - raise ValueError( - "start_token and end_token must be defined in subclasses") + raise ValueError("start_token and end_token must be defined in subclasses") self.start_token_id = self.vocab.get(self.start_token) self.end_token_id = self.vocab.get(self.end_token) if self.start_token_id is None or self.end_token_id is None: raise RuntimeError( f"{self.__class__.__name__} reasoning parser could not locate " - "think start/end tokens in the tokenizer!") + "think start/end tokens in the tokenizer!" + ) def is_reasoning_end(self, input_ids: list[int]) -> bool: return self.end_token_id in input_ids @@ -64,7 +68,7 @@ def extract_content_ids(self, input_ids: list[int]) -> list[int]: if self.end_token_id not in input_ids[:-1]: return [] else: - return input_ids[input_ids.index(self.end_token_id) + 1:] + return input_ids[input_ids.index(self.end_token_id) + 1 :] def extract_reasoning_content_streaming( self, @@ -81,9 +85,9 @@ def extract_reasoning_content_streaming( Uses token IDs for faster processing. """ # Skip single special tokens - if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.start_token_id, self.end_token_id - ]): + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] in [self.start_token_id, self.end_token_id] + ): return None # Check if start token is present in previous or delta. 
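The base class above expects concrete parsers to supply the delimiter tokens. A minimal subclass sketch, assuming the abstract members are simple string properties named `start_token` and `end_token` (the attribute names the constructor reads); the `<think>` tags mirror the example in the class docstring:

```python
# Sketch of a concrete parser built on BaseThinkingReasoningParser.
# The class name and tag strings are illustrative.
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser


class MyThinkReasoningParser(BaseThinkingReasoningParser):
    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

# Instantiating requires a tokenizer whose vocab contains both tags,
# otherwise the base __init__ raises the RuntimeError shown above.
```

Such a parser would still need to be registered through `ReasoningParserManager.register_module` before it can be selected by name.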
@@ -94,7 +98,7 @@ def extract_reasoning_content_streaming( # extract reasoning content end_index = delta_text.find(self.end_token) reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.end_token):] + content = delta_text[end_index + len(self.end_token) :] return DeltaMessage( reasoning_content=reasoning_content, content=content if content else None, @@ -113,9 +117,10 @@ def extract_reasoning_content_streaming( # extract reasoning content start_index = delta_text.find(self.start_token) end_index = delta_text.find(self.end_token) - reasoning_content = delta_text[start_index + - len(self.start_token):end_index] - content = delta_text[end_index + len(self.end_token):] + reasoning_content = delta_text[ + start_index + len(self.start_token) : end_index + ] + content = delta_text[end_index + len(self.end_token) :] return DeltaMessage( reasoning_content=reasoning_content, content=content if content else None, @@ -129,28 +134,27 @@ def extract_reasoning_content_streaming( return DeltaMessage(content=delta_text) def extract_reasoning_content( - self, model_output: str, request: Union[ChatCompletionRequest, - ResponsesRequest] + self, model_output: str, request: Union[ChatCompletionRequest, ResponsesRequest] ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from the model output. - + This is the base implementation that works for most models. Subclasses can override this method for specific behavior. """ # Check if the start token is present in the model output, remove it # if it is present. model_output_parts = model_output.partition(self.start_token) - model_output = model_output_parts[2] if model_output_parts[ - 1] else model_output_parts[0] + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) # For models that may not generate start token, # assume the reasoning content is always at the start. 
if self.end_token not in model_output: return model_output, None else: - reasoning_content, _, content = model_output.partition( - self.end_token) + reasoning_content, _, content = model_output.partition(self.end_token) # If generation stops right after end-of-think, return null content final_content = content or None return reasoning_content, final_content diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py index 76d2959e1c9a..264da54b4879 100644 --- a/vllm/reasoning/deepseek_r1_reasoning_parser.py +++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py @@ -45,14 +45,17 @@ def extract_reasoning_content_streaming( current_token_ids, delta_token_ids, ) - if (ret is not None and self.start_token_id not in previous_token_ids - and self.start_token_id not in delta_token_ids): + if ( + ret is not None + and self.start_token_id not in previous_token_ids + and self.start_token_id not in delta_token_ids + ): if self.end_token_id in delta_token_ids: # end token in delta with more tokens, # extract reasoning content and content end_index = delta_text.find(self.end_token) reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.end_token):] + content = delta_text[end_index + len(self.end_token) :] return DeltaMessage( reasoning_content=reasoning_content, content=content if content else None, diff --git a/vllm/reasoning/glm4_moe_reasoning_parser.py b/vllm/reasoning/glm4_moe_reasoning_parser.py index 8d7488afce68..da98515c7e62 100644 --- a/vllm/reasoning/glm4_moe_reasoning_parser.py +++ b/vllm/reasoning/glm4_moe_reasoning_parser.py @@ -6,8 +6,7 @@ from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -35,17 +34,21 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") + "constructor during construction." + ) self.think_start_token_id = self.vocab.get(self.think_start_token) self.think_end_token_id = self.vocab.get(self.think_end_token) self.assistant_token_id = self.vocab.get(self.assistant_token) - if (self.think_start_token_id is None - or self.think_end_token_id is None - or self.assistant_token_id is None): + if ( + self.think_start_token_id is None + or self.think_end_token_id is None + or self.assistant_token_id is None + ): raise RuntimeError( "Glm4MoeModel reasoning parser could not locate " - "think start/end or assistant tokens in the tokenizer!") + "think start/end or assistant tokens in the tokenizer!" 
+ ) def is_reasoning_end(self, input_ids: list[int]) -> bool: """ @@ -67,7 +70,7 @@ def extract_content_ids(self, input_ids: list[int]) -> list[int]: if self.think_end_token_id not in input_ids[:-1]: return [] else: - return input_ids[input_ids.index(self.think_end_token_id) + 1:] + return input_ids[input_ids.index(self.think_end_token_id) + 1 :] def extract_reasoning_content_streaming( self, @@ -87,9 +90,9 @@ def extract_reasoning_content_streaming( - 'xyz' goes to content """ # Skip single special tokens - if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.think_start_token_id, self.think_end_token_id - ]): + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id] + ): return None if self.think_start_token_id in previous_token_ids: @@ -98,9 +101,11 @@ def extract_reasoning_content_streaming( # extract reasoning content end_index = delta_text.find(self.think_end_token) reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) + content = delta_text[end_index + len(self.think_end_token) :] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) elif self.think_end_token_id in previous_token_ids: # <think> in previous, </think> in previous, # reasoning content continues @@ -114,12 +119,14 @@ def extract_reasoning_content_streaming( # <think> in delta, </think> in delta, extract reasoning content start_index = delta_text.find(self.think_start_token) end_index = delta_text.find(self.think_end_token) - reasoning_content = delta_text[start_index + - len(self.think_start_token - ):end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) + reasoning_content = delta_text[ + start_index + len(self.think_start_token) : end_index + ] + content = delta_text[end_index + len(self.think_end_token) :] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) else: # <think> in delta, no </think> in delta, # reasoning content continues @@ -129,7 +136,7 @@ def extract_reasoning_content_streaming( return DeltaMessage(content=delta_text) def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from the model output. @@ -143,22 +150,24 @@ def extract_reasoning_content( """ # Check if the model output contains the <think> and </think> tokens. - if (self.think_start_token not in model_output - or self.think_end_token not in model_output): + if ( + self.think_start_token not in model_output + or self.think_end_token not in model_output + ): return None, model_output # Check if the <think> is present in the model output, remove it # if it is present. model_output_parts = model_output.partition(self.think_start_token) - model_output = model_output_parts[2] if model_output_parts[ - 1] else model_output_parts[0] + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) # Check if the model output contains the </think> tokens. # If the end token is not found, return the model output as is. 
if self.think_end_token not in model_output: return None, model_output # Extract reasoning content from the model output. - reasoning_content, _, content = model_output.partition( - self.think_end_token) + reasoning_content, _, content = model_output.partition(self.think_end_token) final_content = content or None return reasoning_content, final_content diff --git a/vllm/reasoning/gptoss_reasoning_parser.py b/vllm/reasoning/gptoss_reasoning_parser.py index b0988d5d2618..738c7b51694a 100644 --- a/vllm/reasoning/gptoss_reasoning_parser.py +++ b/vllm/reasoning/gptoss_reasoning_parser.py @@ -7,8 +7,7 @@ from transformers import PreTrainedTokenizerBase from vllm.entrypoints.harmony_utils import parse_chat_output -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -27,7 +26,8 @@ class GptOssReasoningParser(ReasoningParser): def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): super().__init__(tokenizer, *args, **kwargs) self.reasoning_end_token_ids = self.model_tokenizer.encode( - "<|start|>assistant<|channel|>final<|message|>") + "<|start|>assistant<|channel|>final<|message|>" + ) def is_reasoning_end(self, input_ids: list[int]) -> bool: end_token_ids = self.reasoning_end_token_ids @@ -35,7 +35,7 @@ def is_reasoning_end(self, input_ids: list[int]) -> bool: # Check if the end sequence is present in the input_ids. # We search from the end of input_ids to find the last match. for i in range(len(input_ids) - len(end_token_ids), -1, -1): - if input_ids[i:i + len(end_token_ids)] == end_token_ids: + if input_ids[i : i + len(end_token_ids)] == end_token_ids: return True return False @@ -54,28 +54,25 @@ def extract_reasoning_content_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], ) -> Union[DeltaMessage, None]: - prev_reasoning, prev_content, _ = parse_chat_output( - list(previous_token_ids)) - cur_reasoning, cur_content, _ = parse_chat_output( - list(current_token_ids)) + prev_reasoning, prev_content, _ = parse_chat_output(list(previous_token_ids)) + cur_reasoning, cur_content, _ = parse_chat_output(list(current_token_ids)) reasoning_delta = None content_delta = None if cur_reasoning is not None: prev_r = prev_reasoning or "" if cur_reasoning.startswith(prev_r): - reasoning_delta = cur_reasoning[len(prev_r):] or None + reasoning_delta = cur_reasoning[len(prev_r) :] or None else: reasoning_delta = cur_reasoning if cur_content is not None: prev_c = prev_content or "" if cur_content.startswith(prev_c): - content_delta = cur_content[len(prev_c):] or None + content_delta = cur_content[len(prev_c) :] or None else: content_delta = cur_content if reasoning_delta is None and content_delta is None: return None - return DeltaMessage(reasoning_content=reasoning_delta, - content=content_delta) + return DeltaMessage(reasoning_content=reasoning_delta, content=content_delta) def extract_reasoning_content( self, diff --git a/vllm/reasoning/granite_reasoning_parser.py b/vllm/reasoning/granite_reasoning_parser.py index b76170f39f10..543b202989ee 100644 --- a/vllm/reasoning/granite_reasoning_parser.py +++ b/vllm/reasoning/granite_reasoning_parser.py @@ -7,8 +7,7 @@ import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol 
import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -34,15 +33,14 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): self.response_start_expr = r"(?:Here's|Here is) my response:" self.reasoning_regex = re.compile( - rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", - re.DOTALL) + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL + ) self.valid_think_starts = [ - "Here's my thought process:", "Here is my thought process:" - ] - self.valid_response_starts = [ - "Here's my response:", "Here is my response:" + "Here's my thought process:", + "Here is my thought process:", ] + self.valid_response_starts = ["Here's my response:", "Here is my response:"] # Substrings to match for sequence boundaries on raw text self.seq_boundary_end = ":" @@ -50,10 +48,11 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): # The longest any thinking / start of response message can be self.longest_think_start = max( - len(think_start) for think_start in self.valid_think_starts) + len(think_start) for think_start in self.valid_think_starts + ) def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: """Extract the reasoning content & content sections, respectively. If the sequence doesn't match what we expect, i.e., the model generates @@ -111,24 +110,27 @@ def extract_reasoning_content_streaming( DeltaMessage with either reasoning content or content, or None. """ reasoning_content, resp_seq_len, content = self._get_content_sections( - current_text) + current_text + ) # Either we haven't finished the start of the reasoning sequence, # or the model is generating something unexpected. if not reasoning_content: delta_message = self._get_delta_message_with_no_reasoning_bounds( - current_text, delta_text) + current_text, delta_text + ) # We have a start of reasoning message, but have not yet finished # the start of response sequence. elif not content: delta_message = self._get_delta_message_with_no_response_bounds( - current_text, reasoning_content, delta_text) + current_text, reasoning_content, delta_text + ) # We've finished both the start of reasoning and start of response seq. else: # This should never happen since we matched on the response assert resp_seq_len is not None delta_message = self._get_delta_message_with_both_bounds( - delta_text, reasoning_content, content, current_text, - resp_seq_len) + delta_text, reasoning_content, content, current_text, resp_seq_len + ) if not delta_message.content and not delta_message.reasoning_content: return None return delta_message @@ -139,26 +141,27 @@ def _is_reasoning_start_substr(self, text: str) -> bool: Args: text (str): Text to check for leading substr. - + Returns: bool: True if any of the possible reasoning start seqs match. """ return any( - think_start.startswith(text) - for think_start in self.valid_think_starts) + think_start.startswith(text) for think_start in self.valid_think_starts + ) def _is_response_start_substr(self, text: str) -> bool: """Check if a text matches one of the possible start response seqs. Args: text (str): Text to check for leading substr. - + Returns: bool: True if any of the possible response start seqs match. 
""" return any( response_start.startswith(text) - for response_start in self.valid_response_starts) + for response_start in self.valid_response_starts + ) def _get_delta_message_with_no_reasoning_bounds( self, @@ -177,8 +180,7 @@ def _get_delta_message_with_no_reasoning_bounds( """ prev_longest_length = len(current_text) - len(delta_text) is_substr = self._is_reasoning_start_substr(current_text) - was_substr = self._is_reasoning_start_substr( - current_text[:prev_longest_length]) + was_substr = self._is_reasoning_start_substr(current_text[:prev_longest_length]) # Check if we just generated something NOT in the special token seq; # if so, add everything that we previously skipped with this delta @@ -220,12 +222,13 @@ def _get_delta_message_with_no_response_bounds( # content and fully parse it out; we should not pass the : back. ends_with_start_response_seq = any( current_text.endswith(response_start) - for response_start in self.valid_response_starts) + for response_start in self.valid_response_starts + ) if reasoning_content is None or ends_with_start_response_seq: return DeltaMessage(reasoning_content=None, content=None) # Consider previous / current text only within context of the reasoning - previous_text = reasoning_content[:-len(delta_text)] + previous_text = reasoning_content[: -len(delta_text)] current_text = reasoning_content # We need to be careful about adding unfinished response sequences; @@ -234,12 +237,21 @@ def _get_delta_message_with_no_response_bounds( delta_idx = delta_text.rfind(self.seq_boundary_start) # Check the state of potential start of response substring matches. - prev_was_substr = self._is_response_start_substr( - previous_text[prev_idx:]) if prev_idx >= 0 else False - delta_continues_substr = self._is_response_start_substr( - current_text[prev_idx:]) if prev_idx >= 0 else False - delta_new_substr = self._is_response_start_substr( - delta_text[delta_idx:]) if delta_idx >= 0 else False + prev_was_substr = ( + self._is_response_start_substr(previous_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_continues_substr = ( + self._is_response_start_substr(current_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_new_substr = ( + self._is_response_start_substr(delta_text[delta_idx:]) + if delta_idx >= 0 + else False + ) # Delta only contains potential continued response sequence text. if delta_continues_substr: @@ -248,18 +260,17 @@ def _get_delta_message_with_no_response_bounds( if not prev_was_substr: # Delta may be starting a new response seq but has other text too. if delta_new_substr: - return DeltaMessage(reasoning_content=delta_text[:delta_idx], - content=None) + return DeltaMessage( + reasoning_content=delta_text[:delta_idx], content=None + ) # Normal case for most reasoning text (no potential special seqs). 
return DeltaMessage(reasoning_content=delta_text, content=None) # The substring that previously seemed to be a potential response # seq wasn't one; we need to add the content to the delta message, # and also slice off the potential response sequence elif delta_new_substr: - reasoning_content = previous_text[ - prev_idx:] + delta_text[:delta_idx] - return DeltaMessage(reasoning_content=reasoning_content, - content=None) + reasoning_content = previous_text[prev_idx:] + delta_text[:delta_idx] + return DeltaMessage(reasoning_content=reasoning_content, content=None) # No new substring yet, and we broke our old one; take the whole delta return DeltaMessage( reasoning_content=previous_text[prev_idx:] + delta_text, @@ -288,23 +299,21 @@ def _get_delta_message_with_both_bounds( DeltaMessage: Message containing the parsed content. """ # Always have content; take length to the end - delta_content = delta_text[-len(response_content):] - reasoning_end_idx = len(delta_text) - (len(response_content) + - response_seq_len) + delta_content = delta_text[-len(response_content) :] + reasoning_end_idx = len(delta_text) - (len(response_content) + response_seq_len) if reasoning_end_idx < 0: delta_reasoning_content = None else: # Get the starting offset - start_reasoning_content_idx = len( - reasoning_content) + response_seq_len + len( - response_content) - 1 + start_reasoning_content_idx = ( + len(reasoning_content) + response_seq_len + len(response_content) - 1 + ) delta_offset = len(current_text) - len(delta_text) start_offset = start_reasoning_content_idx - delta_offset if start_offset < 0: start_offset = 0 - delta_reasoning_content = delta_text[ - start_offset:reasoning_end_idx] + delta_reasoning_content = delta_text[start_offset:reasoning_end_idx] return DeltaMessage( reasoning_content=delta_reasoning_content, @@ -329,7 +338,8 @@ def _get_content_sections( start_reasoning_content = None parsed_content = False delimiter_idxs = [ - idx for idx, char in enumerate(current_text) + idx + for idx, char in enumerate(current_text) if char == self.seq_boundary_end ] @@ -346,17 +356,15 @@ def _get_content_sections( # Check to see if the start of response seq if complete elif not parsed_content: for response_start in self.valid_response_starts: - if current_chunk[-len(response_start) + - 1:] == response_start[:-1]: + if current_chunk[-len(response_start) + 1 :] == response_start[:-1]: # Mark end of reasoning and start response content # after the start of response sequence. 
- end_reasoning_content = current_chunk_end - len( - response_start) + end_reasoning_content = current_chunk_end - len(response_start) reasoning_content = current_text[ - start_reasoning_content:end_reasoning_content] - response_content = current_text[current_chunk_end + 1:] - return reasoning_content, len( - response_start), response_content + start_reasoning_content:end_reasoning_content + ] + response_content = current_text[current_chunk_end + 1 :] + return reasoning_content, len(response_start), response_content if start_reasoning_content and not parsed_content: return current_text[start_reasoning_content:], None, None diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py index 6e3b056d6b62..381f1b5f3466 100644 --- a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py +++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py @@ -7,8 +7,7 @@ import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -22,16 +21,16 @@ class HunyuanA13BReasoningParser(ReasoningParser): HunyuanReasoningParser - This class implements a reasoning parser specifically designed - for the Hunyuan A13B Model. It is responsible for parsing and - extracting structured reasoning and answer segments from model + This class implements a reasoning parser specifically designed + for the Hunyuan A13B Model. It is responsible for parsing and + extracting structured reasoning and answer segments from model outputs that follow a specific pattern. Key Features: - For non-stream output , Recognizes and extracts reasoning ("think") and answer ("answer") sections from text using regular expressions. - For stream process, it requires a token id sequences to change the - reasoning state and other state so it maintains internal state to + reasoning state and other state so it maintains internal state to manage parsing across multiple token. @@ -50,20 +49,19 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): self.full_match_reasoning_regex = re.compile( rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?(.*?){self.response_end_expr}", - re.DOTALL) + re.DOTALL, + ) self.half_match_reasoning_regex = re.compile( - rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", - re.DOTALL) + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL + ) self.think_start_ids = [14023, 771, 397] self.think_start_ids_fast = [14023, 771, 1363] self.response_start_ids = [198, 524, 27963, 397, 27, 9399, 397] self.response_start_ids_fast = [524, 27963, 397, 27, 9399, 397] self.response_end_ids = [198, 524, 9399, 29] - self.fast_think_ids = [ - 14023, 771, 1363, 524, 27963, 397, 27, 9399, 397 - ] + self.fast_think_ids = [14023, 771, 1363, 524, 27963, 397, 27, 9399, 397] # when state change, send out all the buffered text in last state self.buffered_text = [] @@ -91,7 +89,7 @@ def extract_content_ids(self, input_ids: list[int]) -> list[int]: return [] def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: """Extract the reasoning content & content sections, respectively. 
If the sequence doesn't match what we expect, i.e., the model generates @@ -121,8 +119,7 @@ def extract_reasoning_content( reasoning_content, response_content = fallback_match[0] if response_content.endswith(self.response_end_expr): - response_content = response_content[:-len(self. - response_end_expr)] + response_content = response_content[: -len(self.response_end_expr)] if len(reasoning_content) == 0: reasoning_content = None @@ -133,8 +130,9 @@ def extract_reasoning_content( return None, model_output - def _is_strict_increasing_subsequence(self, subsequence: Sequence[int], - sequence: Sequence[int]) -> bool: + def _is_strict_increasing_subsequence( + self, subsequence: Sequence[int], sequence: Sequence[int] + ) -> bool: if not subsequence: return False @@ -159,27 +157,27 @@ def extract_reasoning_content_streaming( response_start_sequence = self.response_start_ids response_end_sequence = self.response_end_ids - assert (len(delta_token_ids) == 1) + assert len(delta_token_ids) == 1 # Process each token in the delta token = delta_token_ids[0] def check_token_with_sequence(token): if self.current_state == "idle" or self.current_state == "think": - return (token == self.expected_sequence[self.sequence_index] - or token == \ - self.expected_sequence_side[self.sequence_index]) + return ( + token == self.expected_sequence[self.sequence_index] + or token == self.expected_sequence_side[self.sequence_index] + ) else: return token == self.expected_sequence[self.sequence_index] def check_last_token(token): if self.current_state == "idle" or self.current_state == "think": # only return true if it's judge using a side sequence. - if (self.sequence_index - 1 < len(self.expected_sequence_side) - and token - == self.expected_sequence_side[self.sequence_index - - 1]): - return self.sequence_index == len( - self.expected_sequence_side) + if ( + self.sequence_index - 1 < len(self.expected_sequence_side) + and token == self.expected_sequence_side[self.sequence_index - 1] + ): + return self.sequence_index == len(self.expected_sequence_side) else: return self.sequence_index == len(self.expected_sequence) else: @@ -227,19 +225,19 @@ def check_last_token(token): # Return content based on current state if self.current_state == "think": - return DeltaMessage(reasoning_content=buffered_content, - content=None) + return DeltaMessage( + reasoning_content=buffered_content, content=None + ) else: - return DeltaMessage(reasoning_content=None, - content=buffered_content) + return DeltaMessage( + reasoning_content=None, content=buffered_content + ) else: # No buffered content, send normally if self.current_state == "think": - return DeltaMessage(reasoning_content=delta_text, - content=None) + return DeltaMessage(reasoning_content=delta_text, content=None) else: - return DeltaMessage(reasoning_content=None, - content=delta_text) + return DeltaMessage(reasoning_content=None, content=delta_text) # If no content to send in this delta return None diff --git a/vllm/reasoning/mistral_reasoning_parser.py b/vllm/reasoning/mistral_reasoning_parser.py index ceda96ca6a6d..5658c372a264 100644 --- a/vllm/reasoning/mistral_reasoning_parser.py +++ b/vllm/reasoning/mistral_reasoning_parser.py @@ -5,8 +5,7 @@ from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager -from vllm.reasoning.deepseek_r1_reasoning_parser import ( - DeepSeekR1ReasoningParser) +from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from vllm.transformers_utils.tokenizers.mistral import 
MistralTokenizer logger = init_logger(__name__) @@ -23,34 +22,35 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser): def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs): if not isinstance(tokenizer, MistralTokenizer): - raise ValueError( - "The tokenizer must be an instance of MistralTokenizer.") + raise ValueError("The tokenizer must be an instance of MistralTokenizer.") ReasoningParser.__init__(self, tokenizer, *args, **kwargs) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") + "constructor during construction." + ) - self.start_token_id = tokenizer.tokenizer.get_control_token( - self.start_token) - self.end_token_id = tokenizer.tokenizer.get_control_token( - self.end_token) + self.start_token_id = tokenizer.tokenizer.get_control_token(self.start_token) + self.end_token_id = tokenizer.tokenizer.get_control_token(self.end_token) if self.start_token_id is None or self.end_token_id is None: raise RuntimeError( "Mistral reasoning parser could not locate think start/end " - "tokens in the tokenizer!") + "tokens in the tokenizer!" + ) @cached_property def start_token(self) -> str: """The token that starts reasoning content.""" from mistral_common.tokens.tokenizers.base import SpecialTokens + return SpecialTokens.begin_think @cached_property def end_token(self) -> str: """The token that ends reasoning content.""" from mistral_common.tokens.tokenizers.base import SpecialTokens + return SpecialTokens.end_think diff --git a/vllm/reasoning/olmo3_reasoning_parser.py b/vllm/reasoning/olmo3_reasoning_parser.py new file mode 100644 index 000000000000..b330e8b1fdd5 --- /dev/null +++ b/vllm/reasoning/olmo3_reasoning_parser.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses as dt +import enum +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, Union + +import regex as re + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer import AnyTokenizer + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaMessage, + ResponsesRequest, +) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +class Olmo3ReasoningState(enum.Enum): + REASONING = 1 + CONTENT = 2 + + +@dt.dataclass(frozen=True) +class Indices: + start: int + end: int + + def __len__(self): + return self.end - self.start + + +def string_overlap(a: str, b: str) -> tuple[Optional[Indices], Optional[Indices]]: + """ + Find the longest overlap where the end of string a matches the start + of string b. + + Args: + a: First string + b: Second string + + Returns: + Tuple of IndicesTuples representing the overlapping portions in each + string, or a tuple of None if no overlap exists + """ + + # swap so a is always the shorter string + a, b, swap = (a, b, False) if len(a) < len(b) else (b, a, True) + + # first check: is a fully contained in b? + if a in b: + ind_a = Indices(0, len(a)) + ind_b = Indices(b.index(a), b.index(a) + len(a)) + return (ind_b, ind_a) if swap else (ind_a, ind_b) + + # second check: does the end of a overlap with the + # beginning of b? 
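# A worked example of these overlap checks (a sketch, assuming the function
# exactly as written here): when a streamed delta begins with the tail of a
# tag that was split across deltas,
#   string_overlap("ink> answer", "</think>")
# returns (Indices(start=0, end=4), Indices(start=4, end=8)), i.e. the first
# four characters of the text line up with the last four characters of
# "</think>".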
+ for i in range(len(a) - 1, 0, -1): + if a[-i:] == b[:i]: + ind_a = Indices(len(a) - i, len(a)) + ind_b = Indices(0, i) + return (ind_b, ind_a) if swap else (ind_a, ind_b) + + # third check: does the beginning of a overlap with + # the end of b? + for i in range(len(a) - 1, 0, -1): + if b[-i:] == a[:i]: + ind_a = Indices(0, i) + ind_b = Indices(len(b) - i, len(b)) + return (ind_b, ind_a) if swap else (ind_a, ind_b) + + return None, None + + +@dt.dataclass +class Olmo3ReasoningBuffer: + think_start: str = "<think>" + think_end: str = "</think>" + buffer: str = "" + + # we start in reasoning state to support cases where we hardcode + # <think> as the start of the reasoning block. + # In those cases, the only token we will see is </think>, which + # is when we switch to content state. + state: Olmo3ReasoningState = Olmo3ReasoningState.REASONING + + def process_buffer(self) -> Optional[DeltaMessage]: + start_think_idx = self.buffer.find(self.think_start) + + if start_think_idx >= 0: + self.state = Olmo3ReasoningState.REASONING + pretext, self.buffer = ( + self.buffer[:start_think_idx], + self.buffer[start_think_idx + len(self.think_start) :], + ) + if start_think_idx > 0: + # this covers the case there's content before + # the start of the reasoning block + return DeltaMessage(content=pretext) + + end_think_idx = self.buffer.rfind(self.think_end) + + if end_think_idx >= 0: + self.state = Olmo3ReasoningState.CONTENT + pretext, self.buffer = ( + self.buffer[:end_think_idx], + self.buffer[end_think_idx + len(self.think_end) :], + ) + if end_think_idx > 0: + # this covers the case there's content before + # the end of the reasoning block + return DeltaMessage(reasoning_content=pretext) + + if self.state == Olmo3ReasoningState.REASONING: + # we are inside reasoning block, return and empty + # the text buffer + ( + text_buffer, + self.buffer, + ) = self.buffer, "" + return DeltaMessage(reasoning_content=text_buffer) + + if self.state == Olmo3ReasoningState.CONTENT: + # we are outside reasoning block, return and empty + # the text buffer + ( + text_buffer, + self.buffer, + ) = self.buffer, "" + return DeltaMessage(content=text_buffer) + + # nothing to return unless we are in reasoning or content state + return None + + def __len__(self): + # is the length of the text buffer + return len(self.buffer) + + def add_text(self, delta_text: str) -> Optional[DeltaMessage]: + # we start by adding the delta text to the buffer + self.buffer += delta_text + + # setting this to empty before starting + delta_message: Optional[DeltaMessage] = None + + # we start by computing the overlap between the delta_text + # and start/end of think tokens. 
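# For instance (an illustrative sketch, not an exhaustive trace): if the
# delta ends with a partially generated tag,
#   string_overlap("reasoning text <thi", "<think>")
# returns (Indices(start=15, end=19), Indices(start=0, end=4)); the overlap
# with "<think>" has length 4 < 7, so partial_overlap_start becomes True
# below and, unless a complete tag is already sitting in the buffer,
# add_text returns None and waits for the tag to finish.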
+ _, overlap_think_start = string_overlap(delta_text, self.think_start) + _, overlap_think_end = string_overlap(delta_text, self.think_end) + + partial_overlap_start = overlap_think_start is not None and len( + overlap_think_start + ) < len(self.think_start) + partial_overlap_end = overlap_think_end is not None and len( + overlap_think_end + ) < len(self.think_end) + + if ( + partial_overlap_start + and self.think_start in self.buffer + and not partial_overlap_end + ): + # we can only process the buffer if partial overlap + # is the last part of think token (thus causing + # text_buffer to contain the start of think token) + # and there are no partial overlaps with end think + delta_message = self.process_buffer() + + elif partial_overlap_end and self.think_end in self.buffer: + # same as before (partial overlap only allowed) + # if the buffer contains the end think token, + # but we don't have to check for partial overlap + # with start think token because they are handled + # by the previous condition + delta_message = self.process_buffer() + + elif partial_overlap_start or partial_overlap_end: + # in general, if there are overlaps, we don't + # process the buffer because we want to wait until + # the think token is fully completed. + return None + else: + # we process the buffer as normal + delta_message = self.process_buffer() + + return delta_message + + +@ReasoningParserManager.register_module("olmo3") +class Olmo3ReasoningParser(ReasoningParser): + """ + Reasoning parser for Olmo 3 model + + Olmo3ReasoningParser + + This class implements a reasoning parser specifically designed for the + Olmo 3 family of models. Olmo 3 models do not use special tokens to + indicate reasoning; rather, reasoning trace is wrapped in `<think>` and + `</think>`, which are tokenized using standard vocabulary entries. + Because of this, the parser operates in string space, accumulating the + characters in a buffer until it sees `<think>` or `</think>`. tokens + to switch modes. + + Key Features: + - For non-stream output, Recognizes and extracts reasoning (text + bracketed by `<think>` and `</think>`) and content (everything + after the first `</think>`). + - For stream process, it uses a buffer to accumulate delta text, + and output progressive delta messages as soon as thinking starts + or ends. + - For reliability, some Olmo 3 models may hardcode the first + `<think>` token is the input text (similar to Deepseek R1, + or reasoning-only Qwen models). To support such variants, the + parser can optionally work in cases where the first `<think>` + token is missing from generation. + """ + + def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + self.think_start = r"<think>" + self.think_end = r"</think>" + + # notice that the first think is optional; this allows template to + # work in cases when we hardcode a <think> at the beginning of the + # reasoning template. 
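# As a concrete illustration of the pattern defined just below (a sketch,
# not an additional test case): both
#   "<think>step by step</think>The answer is 4."
# and the same string without the leading "<think>" parse to
#   reasoning_content == "step by step" and content == "The answer is 4."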
+ reasoning_expr = ( + rf"^(?:{self.think_start})?(?P<reasoning>.*?)" + + rf"{self.think_end}(?P<content>.*)$" + ) + self.reasoning_regex = re.compile(reasoning_expr, re.DOTALL) + + self.buffer = Olmo3ReasoningBuffer( + think_start=self.think_start, think_end=self.think_end + ) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + text = self.model_tokenizer.decode(input_ids) + return self.think_end in text + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + # for Olmo 3 streaming reason parsing, the stream parse + # will call first, and the same token will be called in + # is_reasoning_end and extract_content_ids + # this id is not part of content, so just return [] here. + return [] + + def extract_reasoning_content( + self, + model_output: str, + request: Union[ChatCompletionRequest, ResponsesRequest], + ) -> tuple[Optional[str], Optional[str]]: + """Extract the reasoning content & content sections, respectively. + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + Args: + model_output (str): Output of the model to be parsed. + request (ChatCompletionRequest | ResponsesRequest): Request being + processed. + + Returns: + tuple[Optional[str], Optional[str]]: Tuple pair containing the + reasoning content and non-reasoning content. + """ + + re_match = self.reasoning_regex.match(model_output) + if re_match: + reasoning_content = re_match.group("reasoning") or None + content = re_match.group("content") or None + return reasoning_content, content + + # no reasoning content + return None, model_output + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """Extract content using token ID sequence state machine""" + + delta_message = self.buffer.add_text(delta_text) + if delta_message is None and self.buffer.think_end in self.buffer.buffer: + # this is a bit hacky, but, because of how the buffer is + # constructed, if the last delta_text contains characters that + # marks the end of thinking tokens, then messages in the buffer + # would never be processed because we get no other turn. To get + # around that, we check if the text buffer contains the end of + # thinking tokens, and, if so, we reprocess the buffer again. + delta_message = self.buffer.process_buffer() + + return delta_message diff --git a/vllm/reasoning/qwen3_reasoning_parser.py b/vllm/reasoning/qwen3_reasoning_parser.py index 3e3c7f32796b..160e8633a43f 100644 --- a/vllm/reasoning/qwen3_reasoning_parser.py +++ b/vllm/reasoning/qwen3_reasoning_parser.py @@ -3,8 +3,7 @@ from typing import Optional, Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - ResponsesRequest) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ResponsesRequest from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser @@ -32,12 +31,11 @@ def end_token(self) -> str: return "</think>" def extract_reasoning_content( - self, model_output: str, request: Union[ChatCompletionRequest, - ResponsesRequest] + self, model_output: str, request: Union[ChatCompletionRequest, ResponsesRequest] ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from the model output. 
- + Qwen3 has stricter requirements - it needs both start and end tokens to be present, unlike other models that work with just the end token. @@ -50,15 +48,15 @@ def extract_reasoning_content( """ # Check if the model output contains both <think> and </think> tokens. - if (self.start_token not in model_output - or self.end_token not in model_output): + if self.start_token not in model_output or self.end_token not in model_output: return None, model_output # Check if the <think> is present in the model output, remove it # if it is present. model_output_parts = model_output.partition(self.start_token) - model_output = model_output_parts[2] if model_output_parts[ - 1] else model_output_parts[0] + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) # Check if the model output contains the </think> tokens. # If the end token is not found, return the model output as is. diff --git a/vllm/reasoning/seedoss_reasoning_parser.py b/vllm/reasoning/seedoss_reasoning_parser.py index 5f4bbbf1557e..72f8dc54f1b3 100644 --- a/vllm/reasoning/seedoss_reasoning_parser.py +++ b/vllm/reasoning/seedoss_reasoning_parser.py @@ -10,10 +10,10 @@ class SeedOSSReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for SeedOSS model. - The SeedOSS model uses <seed:think>...</seed:think> tokens to - denote reasoning content text. This parser extracts + The SeedOSS model uses <seed:think>...</seed:think> tokens to + denote reasoning content text. This parser extracts the reasoning content from the model output. - Similar to DeepSeek R1, it supports cases + Similar to DeepSeek R1, it supports cases where the model doesn't generate the start token. """ diff --git a/vllm/reasoning/step3_reasoning_parser.py b/vllm/reasoning/step3_reasoning_parser.py index 6e5deb52d345..c9f580077b33 100644 --- a/vllm/reasoning/step3_reasoning_parser.py +++ b/vllm/reasoning/step3_reasoning_parser.py @@ -7,8 +7,7 @@ import regex as re from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.logger import init_logger from vllm.reasoning import ReasoningParser, ReasoningParserManager @@ -20,7 +19,7 @@ class Step3ReasoningParser(ReasoningParser): """ Reasoning parser for Step3 model. - The Step3 model uses </think> token to denote the end of reasoning + The Step3 model uses </think> token to denote the end of reasoning text. This parser extracts all content before </think> as reasoning content. """ @@ -28,19 +27,20 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): super().__init__(tokenizer, *args, **kwargs) self.think_end_token = "</think>" - self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}", - re.DOTALL) + self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}", re.DOTALL) if not self.model_tokenizer: raise ValueError( "The model tokenizer must be passed to the ReasoningParser " - "constructor during construction.") + "constructor during construction." + ) self.think_end_token_id = self.vocab.get(self.think_end_token) if self.think_end_token_id is None: raise RuntimeError( "Step3 reasoning parser could not locate think end " - "token in the tokenizer!") + "token in the tokenizer!" 
+ ) def extract_reasoning_content_streaming( self, @@ -60,17 +60,18 @@ def extract_reasoning_content_streaming( - 'xyz' goes to content """ # Skip single special token - if len(delta_token_ids - ) == 1 and delta_token_ids[0] == self.think_end_token_id: + if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id: return None if self.think_end_token_id in delta_token_ids: # </think> in delta, extract reasoning content and remaining content end_index = delta_text.find(self.think_end_token) reasoning_content = delta_text[:end_index] - content = delta_text[end_index + len(self.think_end_token):] - return DeltaMessage(reasoning_content=reasoning_content, - content=content if content else None) + content = delta_text[end_index + len(self.think_end_token) :] + return DeltaMessage( + reasoning_content=reasoning_content, + content=content if content else None, + ) elif self.think_end_token_id in previous_token_ids: # </think> already seen in previous text, everything is content return DeltaMessage(content=delta_text) @@ -79,9 +80,8 @@ def extract_reasoning_content_streaming( return DeltaMessage(reasoning_content=delta_text) def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, model_output: str, request: ChatCompletionRequest ) -> tuple[Optional[str], Optional[str]]: - # Check if the model output contains the </think> token if self.think_end_token not in model_output: # If no </think> token, everything is reasoning content @@ -92,7 +92,7 @@ def extract_reasoning_content( reasoning_content = model_output[:end_index] # Content after </think> token - content = model_output[end_index + len(self.think_end_token):] + content = model_output[end_index + len(self.think_end_token) :] if len(content) == 0: content = None @@ -106,4 +106,4 @@ def extract_content_ids(self, input_ids: list[int]) -> list[int]: if self.think_end_token_id not in input_ids[:-1]: return [] else: - return input_ids[input_ids.index(self.think_end_token_id) + 1:] + return input_ids[input_ids.index(self.think_end_token_id) + 1 :] diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index f424682f9dfa..a1ff4e5ff63b 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sampling parameters for text generation.""" + import copy import warnings from dataclasses import field @@ -50,26 +51,32 @@ class StructuredOutputsParams: def __post_init__(self): """Validate that some fields are mutually exclusive.""" - count = sum([ - self.json is not None, self.regex is not None, self.choice - is not None, self.grammar is not None, self.json_object is not None - ]) + count = sum( + [ + self.json is not None, + self.regex is not None, + self.choice is not None, + self.grammar is not None, + self.json_object is not None, + ] + ) if count > 1: raise ValueError( "You can only use one kind of structured outputs constraint " - f"but multiple are specified: {self.__dict__}") + f"but multiple are specified: {self.__dict__}" + ) @dataclass class GuidedDecodingParams(StructuredOutputsParams): - def __post_init__(self): warnings.warn( "GuidedDecodingParams is deprecated. This will be removed in " "v0.12.0 or v1.0.0, which ever is soonest. 
Please use " "StructuredOutputsParams instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) return super().__post_init__() @@ -83,10 +90,11 @@ class RequestOutputKind(Enum): class SamplingParams( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True, +): # type: ignore[call-arg] """Sampling parameters for text generation. Overall, we follow the sampling parameters from the OpenAI text completion @@ -178,8 +186,7 @@ class SamplingParams( optionally prompt tokens as a first argument.""" include_stop_str_in_output: bool = False """Whether to include the stop strings in output text.""" - truncate_prompt_tokens: Optional[Annotated[int, - msgspec.Meta(ge=-1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None """If set to -1, will use the truncation size supported by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is disabled.""" @@ -238,9 +245,7 @@ def from_optional( skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, logits_processors: Optional[list[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[Annotated[int, - msgspec.Meta( - ge=-1)]] = None, + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, structured_outputs: Optional[StructuredOutputsParams] = None, guided_decoding: Optional[GuidedDecodingParams] = None, @@ -261,19 +266,19 @@ def from_optional( "v0.12.0 or v1.0.0, which ever is soonest. Please use " "structured_outputs instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) structured_outputs = guided_decoding guided_decoding = None return SamplingParams( n=1 if n is None else n, best_of=best_of, - presence_penalty=0.0 - if presence_penalty is None else presence_penalty, - frequency_penalty=0.0 - if frequency_penalty is None else frequency_penalty, + presence_penalty=0.0 if presence_penalty is None else presence_penalty, + frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty, repetition_penalty=1.0 - if repetition_penalty is None else repetition_penalty, + if repetition_penalty is None + else repetition_penalty, temperature=1.0 if temperature is None else temperature, top_p=1.0 if top_p is None else top_p, top_k=top_k, @@ -311,7 +316,8 @@ def __post_init__(self) -> None: if self.best_of < self.n: raise ValueError( f"best_of must be greater than or equal to n, " - f"got n={self.n} and best_of={self.best_of}.") + f"got n={self.n} and best_of={self.best_of}." + ) if not self._real_n: self._real_n = self.n self.n = self.best_of @@ -320,7 +326,10 @@ def __post_init__(self) -> None: logger.warning( "temperature %s is less than %s, which may cause numerical " "errors nan or inf in tensors. We have maxed it out to %s.", - self.temperature, _MAX_TEMP, _MAX_TEMP) + self.temperature, + _MAX_TEMP, + _MAX_TEMP, + ) self.temperature = max(self.temperature, _MAX_TEMP) if self.seed == -1: @@ -366,101 +375,116 @@ def __post_init__(self) -> None: "v0.12.0 or v1.0.0, which ever is soonest. 
Please use " "structured_outputs instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) self.structured_outputs = self.guided_decoding self.guided_decoding = None def _verify_args(self) -> None: if not isinstance(self.n, int): - raise ValueError(f"n must be an int, but is of " - f"type {type(self.n)}") + raise ValueError(f"n must be an int, but is of type {type(self.n)}") if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") if self.best_of is not None: if not isinstance(self.best_of, int): raise ValueError( - f"best_of must be an integer, got {type(self.best_of)}") + f"best_of must be an integer, got {type(self.best_of)}" + ) if self.best_of < 1: - raise ValueError( - f"best_of must be at least 1, got {self.best_of}") + raise ValueError(f"best_of must be at least 1, got {self.best_of}") if self.best_of < self.n: raise ValueError( f"best_of must be greater than or equal to n, " - f"got n={self.n} and best_of={self.best_of}.") + f"got n={self.n} and best_of={self.best_of}." + ) if not -2.0 <= self.presence_penalty <= 2.0: - raise ValueError("presence_penalty must be in [-2, 2], got " - f"{self.presence_penalty}.") + raise ValueError( + f"presence_penalty must be in [-2, 2], got {self.presence_penalty}." + ) if not -2.0 <= self.frequency_penalty <= 2.0: - raise ValueError("frequency_penalty must be in [-2, 2], got " - f"{self.frequency_penalty}.") + raise ValueError( + f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}." + ) if self.repetition_penalty <= 0.0: raise ValueError( "repetition_penalty must be greater than zero, got " - f"{self.repetition_penalty}.") + f"{self.repetition_penalty}." + ) if self.temperature < 0.0: raise ValueError( - f"temperature must be non-negative, got {self.temperature}.") + f"temperature must be non-negative, got {self.temperature}." + ) if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") # quietly accept -1 as disabled, but prefer 0 if self.top_k < -1: - raise ValueError(f"top_k must be 0 (disable), or at least 1, " - f"got {self.top_k}.") + raise ValueError( + f"top_k must be 0 (disable), or at least 1, got {self.top_k}." + ) if not isinstance(self.top_k, int): raise TypeError( - f"top_k must be an integer, got {type(self.top_k).__name__}") + f"top_k must be an integer, got {type(self.top_k).__name__}" + ) if not 0.0 <= self.min_p <= 1.0: - raise ValueError("min_p must be in [0, 1], got " - f"{self.min_p}.") + raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.") if self.max_tokens is not None and self.max_tokens < 1: - raise ValueError( - f"max_tokens must be at least 1, got {self.max_tokens}.") + raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.") if self.min_tokens < 0: - raise ValueError(f"min_tokens must be greater than or equal to 0, " - f"got {self.min_tokens}.") + raise ValueError( + f"min_tokens must be greater than or equal to 0, got {self.min_tokens}." + ) if self.max_tokens is not None and self.min_tokens > self.max_tokens: raise ValueError( f"min_tokens must be less than or equal to " - f"max_tokens={self.max_tokens}, got {self.min_tokens}.") - if (self.logprobs is not None and self.logprobs != -1 - and self.logprobs < 0): + f"max_tokens={self.max_tokens}, got {self.min_tokens}." 
+ ) + if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0: raise ValueError( - f"logprobs must be non-negative or -1, got {self.logprobs}.") - if (self.prompt_logprobs is not None and self.prompt_logprobs != -1 - and self.prompt_logprobs < 0): + f"logprobs must be non-negative or -1, got {self.logprobs}." + ) + if ( + self.prompt_logprobs is not None + and self.prompt_logprobs != -1 + and self.prompt_logprobs < 0 + ): raise ValueError( f"prompt_logprobs must be non-negative or -1, got " - f"{self.prompt_logprobs}.") - if (self.truncate_prompt_tokens is not None - and (self.truncate_prompt_tokens == 0 - or self.truncate_prompt_tokens < -1)): + f"{self.prompt_logprobs}." + ) + if self.truncate_prompt_tokens is not None and ( + self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1 + ): raise ValueError( f"truncate_prompt_tokens must be an integer >= 1 or -1, " - f"got {self.truncate_prompt_tokens}") + f"got {self.truncate_prompt_tokens}" + ) assert isinstance(self.stop_token_ids, list) if not all(isinstance(st_id, int) for st_id in self.stop_token_ids): - raise ValueError(f"stop_token_ids must contain only integers, " - f"got {self.stop_token_ids}.") + raise ValueError( + f"stop_token_ids must contain only integers, got {self.stop_token_ids}." + ) assert isinstance(self.stop, list) if any(not stop_str for stop_str in self.stop): raise ValueError("stop cannot contain an empty string.") if self.stop and not self.detokenize: raise ValueError( "stop strings are only supported when detokenize is True. " - "Set detokenize=True to use stop.") + "Set detokenize=True to use stop." + ) if self.best_of != self._real_n and self.output_kind == ( - RequestOutputKind.DELTA): + RequestOutputKind.DELTA + ): raise ValueError("best_of must equal n to use output_kind=DELTA") def _verify_greedy_sampling(self) -> None: if self.n > 1: - raise ValueError("n must be 1 when using greedy sampling, " - f"got {self.n}.") + raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.") def update_from_generation_config( - self, - generation_config: dict[str, Any], - model_eos_token_id: Optional[int] = None) -> None: + self, + generation_config: dict[str, Any], + model_eos_token_id: Optional[int] = None, + ) -> None: """Update if there are non-default values from generation_config""" if model_eos_token_id is not None: @@ -494,30 +518,33 @@ def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None: for add_prefix_space in [False, True]: prefix = " " if add_prefix_space else "" prompt = prefix + bad_word.lstrip() - prompt_token_ids = tokenizer.encode(text=prompt, - add_special_tokens=False) + prompt_token_ids = tokenizer.encode( + text=prompt, add_special_tokens=False + ) # If no space at the beginning # or if prefix space produces a new word token if (not add_prefix_space) or ( - add_prefix_space and prompt_token_ids[0] - != self._bad_words_token_ids[-1][0] - and len(prompt_token_ids) == len( - self._bad_words_token_ids[-1])): + add_prefix_space + and prompt_token_ids[0] != self._bad_words_token_ids[-1][0] + and len(prompt_token_ids) == len(self._bad_words_token_ids[-1]) + ): self._bad_words_token_ids.append(prompt_token_ids) invalid_token_ids = [ - token_id for bad_words_token_ids in self._bad_words_token_ids + token_id + for bad_words_token_ids in self._bad_words_token_ids for token_id in bad_words_token_ids if token_id < 0 or token_id > tokenizer.max_token_id ] if len(invalid_token_ids) > 0: raise ValueError( - f"The model vocabulary size is 
{tokenizer.max_token_id+1}," + f"The model vocabulary size is {tokenizer.max_token_id + 1}," f" but the following tokens" f" were specified as bad: {invalid_token_ids}." f" All token id values should be integers satisfying:" - f" 0 <= token_id <= {tokenizer.max_token_id}.") + f" 0 <= token_id <= {tokenizer.max_token_id}." + ) @cached_property def sampling_type(self) -> SamplingType: @@ -545,10 +572,14 @@ def clone(self) -> "SamplingParams": See https://github.com/vllm-project/vllm/issues/3087 """ - logit_processor_refs = None if self.logits_processors is None else { - id(lp): lp.clone() if hasattr(lp, 'clone') else lp - for lp in self.logits_processors - } + logit_processor_refs = ( + None + if self.logits_processors is None + else { + id(lp): lp.clone() if hasattr(lp, "clone") else lp + for lp in self.logits_processors + } + ) return copy.deepcopy(self, memo=logit_processor_refs) def __repr__(self) -> str: @@ -576,15 +607,18 @@ def __repr__(self) -> str: f"{self.spaces_between_special_tokens}, " f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " f"structured_outputs={self.structured_outputs}, " - f"extra_args={self.extra_args})") + f"extra_args={self.extra_args})" + ) class BeamSearchParams( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True, +): # type: ignore[call-arg] """Beam search parameters for text generation.""" + beam_width: int max_tokens: int ignore_eos: bool = False diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index 055f28914ad5..fd0713dc0aa3 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -70,20 +70,19 @@ class ScalarType: """ def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" + assert self.mantissa <= 52 and self.exponent <= 11, ( + f"Cannot represent max/min as a double for type {self.__str__()}" + ) max_mantissa = (1 << self.mantissa) - 1 if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: max_mantissa = max_mantissa - 1 max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE: + assert self.exponent < 11, ( + f"Cannot represent max/min as a double for type {self.__str__()}" + ) max_exponent = max_exponent + 1 # adjust the exponent to match that of a double @@ -96,38 +95,39 @@ def _floating_point_max_int(self) -> int: exponent_bias = (1 << (self.exponent - 1)) - 1 exponent_bias_double = (1 << 10) - 1 # double e = 11 - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) + max_exponent_double = max_exponent - exponent_bias + exponent_bias_double # shift the mantissa and exponent into the proper positions for an # IEEE double and bitwise-or them together. 
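# Worked example (a sketch, using the float8_e4m3fn type defined later in
# this file: exponent=4, mantissa=3, extended-range NaN): max_mantissa is
# reduced from 7 to 6, max_exponent is extended from 14 to 15, and
# max_exponent_double = 15 - 7 + 1023 = 1031, so the assembled bit pattern
# (6 << 49) | (1031 << 52) decodes to the double 448.0, the familiar maximum
# of the FP8 E4M3 format.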
- return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) + return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52) def _floating_point_max(self) -> float: double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + return struct.unpack("!d", struct.pack("!Q", double_raw))[0] def _raw_max(self) -> Union[int, float]: if self.is_floating_point(): return self._floating_point_max() else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" + assert self.size_bits < 64 or self.size_bits == 64 and self.is_signed(), ( + "Cannot represent max as an int" + ) return (1 << self.mantissa) - 1 def _raw_min(self) -> Union[int, float]: if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" + assert self.is_signed(), ( + "We currently assume all floating point types are signed" + ) sign_bit_double = 1 << 63 max_raw = self._floating_point_max_int() min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + return struct.unpack("!d", struct.pack("!Q", min_raw))[0] else: - assert (not self.is_signed() or self.size_bits - <= 64), "Cannot represent min as a int64_t" + assert not self.is_signed() or self.size_bits <= 64, ( + "Cannot represent min as a int64_t" + ) if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -158,8 +158,7 @@ def or_and_advance(member, bit_width): or_and_advance(self._finite_values_only, 1) or_and_advance(self.nan_repr.value, 8) - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" + assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64" _SCALAR_TYPES_ID_MAP[val] = self @@ -215,8 +214,7 @@ def is_ieee_754(self) -> bool: If the type is a floating point type that follows IEEE 754 conventions """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only + return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only def __str__(self) -> str: """ @@ -232,8 +230,14 @@ def __str__(self) -> str: - if bias is not present it means its zero """ if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) + ret = ( + "float" + + str(self.size_bits) + + "_e" + + str(self.exponent) + + "m" + + str(self.mantissa) + ) if not self.is_ieee_754(): if self._finite_values_only: @@ -261,41 +265,43 @@ def __len__(self) -> int: # @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": "Create a signed integer scalar type (size_bits includes sign-bit)." ret = cls(0, size_bits - 1, True, bias if bias else 0) ret.id # noqa B018: make sure the id is cached return ret @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": """Create an unsigned integer scalar type.""" ret = cls(0, size_bits, False, bias if bias else 0) ret.id # noqa B018: make sure the id is cached return ret @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType": """ Create a standard floating point type (i.e. follows IEEE 754 conventions). 
""" - assert (mantissa > 0 and exponent > 0) + assert mantissa > 0 and exponent > 0 ret = cls(exponent, mantissa, True, 0) ret.id # noqa B018: make sure the id is cached return ret @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': + def float_( + cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr + ) -> "ScalarType": """ Create a non-standard floating point type (i.e. does not follow IEEE 754 conventions). """ - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( + assert mantissa > 0 and exponent > 0 + assert nan_repr != NanRepr.IEEE_754, ( "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") + "follow IEEE 754 conventions" + ) ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) ret.id # noqa B018: make sure the id is cached return ret @@ -303,8 +309,7 @@ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, @classmethod def from_id(cls, scalar_type_id: int): if scalar_type_id not in _SCALAR_TYPES_ID_MAP: - raise ValueError( - f"scalar_type_id {scalar_type_id} doesn't exists.") + raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.") return _SCALAR_TYPES_ID_MAP[scalar_type_id] @@ -327,8 +332,7 @@ class scalar_types: uint8 = ScalarType.uint(8, None) float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float8_e8m0fnu = ScalarType(8, 0, False, 0, True, - NanRepr.EXTD_RANGE_MAX_MIN) + float8_e8m0fnu = ScalarType(8, 0, False, 0, True, NanRepr.EXTD_RANGE_MAX_MIN) float16_e8m7 = ScalarType.float_IEEE754(8, 7) float16_e5m10 = ScalarType.float_IEEE754(5, 10) diff --git a/vllm/scripts.py b/vllm/scripts.py index 7a7fdccf0a32..f158860726be 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -10,6 +10,8 @@ # Backwards compatibility for the move from vllm.scripts to # vllm.entrypoints.cli.main def main(): - logger.warning("vllm.scripts.main() is deprecated. Please re-install " - "vllm or use vllm.entrypoints.cli.main.main() instead.") + logger.warning( + "vllm.scripts.main() is deprecated. Please re-install " + "vllm or use vllm.entrypoints.cli.main.main() instead." + ) vllm_main() diff --git a/vllm/sequence.py b/vllm/sequence.py index e5f23d47a660..7682b7f58305 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sequence and its related classes.""" + from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional, Union @@ -8,8 +9,7 @@ import torch if TYPE_CHECKING: - from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorOutput) + from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput else: KVConnectorOutput = Any @@ -36,6 +36,7 @@ class RequestMetrics: will include model forward, block/sync across workers, cpu-gpu sync time and sampling time. """ + arrival_time: float last_token_time: float first_scheduled_time: Optional[float] @@ -53,7 +54,7 @@ class IntermediateTensors: """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. - + Each stage also needs to handle its own kv_connector_output. 
""" @@ -87,17 +88,16 @@ def __eq__(self, other: object): return False if self.tensors.keys() != other.tensors.keys(): return False - return all( - torch.equal(self.tensors[k], other.tensors[k]) - for k in self.tensors) + return all(torch.equal(self.tensors[k], other.tensors[k]) for k in self.tensors) def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" class ExecuteModelRequest( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, +): # type: ignore[call-arg] # Placeholder. Remove. pass diff --git a/vllm/tracing.py b/vllm/tracing.py index 7537e9901a04..c9b595999fc7 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -17,12 +17,15 @@ try: from opentelemetry.context.context import Context from opentelemetry.sdk.environment_variables import ( - OTEL_EXPORTER_OTLP_TRACES_PROTOCOL) + OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, + ) from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider from opentelemetry.trace.propagation.tracecontext import ( - TraceContextTextMapPropagator) + TraceContextTextMapPropagator, + ) + _is_otel_imported = True except ImportError: # Capture and format traceback to provide detailed context for the import @@ -30,6 +33,7 @@ # memory leaks. # See https://github.com/vllm-project/vllm/pull/7266#discussion_r1707395458 import traceback + otel_import_error_traceback = traceback.format_exc() class Context: # type: ignore @@ -49,13 +53,15 @@ def is_otel_available() -> bool: return _is_otel_imported -def init_tracer(instrumenting_module_name: str, - otlp_traces_endpoint: str) -> Optional[Tracer]: +def init_tracer( + instrumenting_module_name: str, otlp_traces_endpoint: str +) -> Optional[Tracer]: if not is_otel_available(): raise ValueError( "OpenTelemetry is not available. Unable to initialize " "a tracer. Ensure OpenTelemetry packages are installed. 
" - f"Original error:\n{otel_import_error_traceback}") + f"Original error:\n{otel_import_error_traceback}" + ) trace_provider = TracerProvider() span_exporter = get_span_exporter(otlp_traces_endpoint) @@ -70,19 +76,19 @@ def get_span_exporter(endpoint): protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc") if protocol == "grpc": from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( - OTLPSpanExporter) + OTLPSpanExporter, + ) elif protocol == "http/protobuf": from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter) # type: ignore + OTLPSpanExporter, # type: ignore + ) else: - raise ValueError( - f"Unsupported OTLP protocol '{protocol}' is configured") + raise ValueError(f"Unsupported OTLP protocol '{protocol}' is configured") return OTLPSpanExporter(endpoint=endpoint) -def extract_trace_context( - headers: Optional[Mapping[str, str]]) -> Optional[Context]: +def extract_trace_context(headers: Optional[Mapping[str, str]]) -> Optional[Context]: if is_otel_available(): headers = headers or {} return TraceContextTextMapPropagator().extract(headers) @@ -91,7 +97,6 @@ def extract_trace_context( def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]: - return {h: headers[h] for h in TRACE_HEADERS if h in headers} @@ -113,17 +118,13 @@ class SpanAttributes: GEN_AI_LATENCY_E2E = "gen_ai.latency.e2e" GEN_AI_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler" # Time taken in the forward pass for this across all workers - GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = ( - "gen_ai.latency.time_in_model_forward") + GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward" # Time taken in the model execute function. This will include model # forward, block/sync across workers, cpu-gpu sync time and sampling time. 
- GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = ( - "gen_ai.latency.time_in_model_execute") - GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \ - "gen_ai.latency.time_in_model_prefill" + GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute" + GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = "gen_ai.latency.time_in_model_prefill" GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode" - GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \ - "gen_ai.latency.time_in_model_inference" + GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = "gen_ai.latency.time_in_model_inference" def contains_trace_headers(headers: Mapping[str, str]) -> bool: @@ -132,5 +133,4 @@ def contains_trace_headers(headers: Mapping[str, str]) -> bool: @run_once def log_tracing_disabled_warning() -> None: - logger.warning( - "Received a request with trace context but tracing is disabled") + logger.warning("Received a request with trace context but tracing is disabled") diff --git a/vllm/transformers_utils/__init__.py b/vllm/transformers_utils/__init__.py index 6d4231baca50..649df9a4f022 100644 --- a/vllm/transformers_utils/__init__.py +++ b/vllm/transformers_utils/__init__.py @@ -10,10 +10,11 @@ from packaging import version # patch_hub begins from modelscope>=1.18.1 - if version.parse(modelscope.__version__) <= version.parse('1.18.0'): + if version.parse(modelscope.__version__) <= version.parse("1.18.0"): raise ImportError( - 'Using vLLM with ModelScope needs modelscope>=1.18.1, please ' - 'install by `pip install modelscope -U`') + "Using vLLM with ModelScope needs modelscope>=1.18.1, please " + "install by `pip install modelscope -U`" + ) from modelscope.utils.hf_util import patch_hub # Patch hub to download models from modelscope to speed up. @@ -21,4 +22,5 @@ except ImportError as err: raise ImportError( "Please install modelscope>=1.18.1 via " - "`pip install modelscope>=1.18.1` to use ModelScope.") from err + "`pip install modelscope>=1.18.1` to use ModelScope." 
+ ) from err diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index 3a97f2c05618..b8d0cd8d2f20 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -12,16 +12,14 @@ ChatTemplatePath = Union[Path, Callable[[str], Optional[Path]]] -def _get_qwen_chat_template_fallback( - tokenizer_name_or_path: str) -> Optional[Path]: +def _get_qwen_chat_template_fallback(tokenizer_name_or_path: str) -> Optional[Path]: if tokenizer_name_or_path.endswith("-Chat"): return CHAT_TEMPLATES_DIR / "template_chatml.jinja" return CHAT_TEMPLATES_DIR / "template_basic.jinja" -def _get_minicpmv_chat_template_fallback( - tokenizer_name_or_path: str) -> Optional[Path]: +def _get_minicpmv_chat_template_fallback(tokenizer_name_or_path: str) -> Optional[Path]: # MiniCPM-V-4.5 version uses a dedicated template if "4.5" in tokenizer_name_or_path or "4_5" in tokenizer_name_or_path: return CHAT_TEMPLATES_DIR / "template_minicpmv45.jinja" @@ -30,9 +28,9 @@ def _get_minicpmv_chat_template_fallback( return CHAT_TEMPLATES_DIR / "template_chatml.jinja" -# yapf: disable _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", + "clip": CHAT_TEMPLATES_DIR / "template_basic.jinja", "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", @@ -40,7 +38,6 @@ def _get_minicpmv_chat_template_fallback( "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, } -# yapf: enable def register_chat_template_fallback_path( @@ -50,8 +47,10 @@ def register_chat_template_fallback_path( if model_type in _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: logger.warning( "Model type %s already has a chat template registered. 
" - "It will be overwritten by the new chat template %s.", model_type, - chat_template) + "It will be overwritten by the new chat template %s.", + model_type, + chat_template, + ) _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK[model_type] = chat_template diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8d340f88fa25..ab3eb6de4780 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -10,26 +10,32 @@ from typing import Any, Callable, Literal, Optional, TypeVar, Union import huggingface_hub -from huggingface_hub import get_safetensors_metadata, hf_hub_download +from huggingface_hub import ( + get_safetensors_metadata, + hf_hub_download, + try_to_load_from_cache, +) from huggingface_hub import list_repo_files as hf_list_repo_files -from huggingface_hub import try_to_load_from_cache -from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - LocalEntryNotFoundError, - RepositoryNotFoundError, - RevisionNotFoundError) +from huggingface_hub.utils import ( + EntryNotFoundError, + HfHubHTTPError, + LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) from transformers import GenerationConfig, PretrainedConfig -from transformers.models.auto.image_processing_auto import ( - get_image_processor_config) -from transformers.models.auto.modeling_auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +from transformers.models.auto.image_processing_auto import get_image_processor_config +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from transformers.models.auto.tokenization_auto import get_tokenizer_config from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger from vllm.transformers_utils.config_parser_base import ConfigParserBase -from vllm.transformers_utils.utils import (check_gguf_file, - parse_safetensors_file_metadata) +from vllm.transformers_utils.utils import ( + check_gguf_file, + parse_safetensors_file_metadata, +) if envs.VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -45,21 +51,21 @@ def _get_hf_token() -> Optional[str]: """ Get the HuggingFace token from environment variable. - Returns None if the token is not set, is an empty string, + Returns None if the token is not set, is an empty string, or contains only whitespace. This follows the same pattern as huggingface_hub library which treats empty string tokens as None to avoid authentication errors. 
""" - token = os.getenv('HF_TOKEN') + token = os.getenv("HF_TOKEN") if token and token.strip(): return token return None class LazyConfigDict(dict): - def __getitem__(self, key): import vllm.transformers_utils.configs as configs + return getattr(configs, super().__getitem__(key)) @@ -84,30 +90,28 @@ def __getitem__(self, key): ultravox="UltravoxConfig", step3_vl="Step3VLConfig", step3_text="Step3TextConfig", - qwen3_next="Qwen3NextConfig") + qwen3_next="Qwen3NextConfig", +) _CONFIG_ATTRS_MAPPING: dict[str, str] = { "llm_config": "text_config", } _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = { - "internvl_chat": { - "has_no_defaults_at_init": True - }, - "NVLM_D": { - "has_no_defaults_at_init": True - }, + "internvl_chat": {"has_no_defaults_at_init": True}, + "NVLM_D": {"has_no_defaults_at_init": True}, } class HFConfigParser(ConfigParserBase): - - def parse(self, - model: Union[str, Path], - trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - **kwargs) -> tuple[dict, PretrainedConfig]: + def parse( + self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE config_dict, _ = PretrainedConfig.get_config_dict( model, @@ -119,8 +123,11 @@ def parse(self, # Use custom model class if it's in our registry model_type = config_dict.get("model_type") if model_type is None: - model_type = "speculators" if config_dict.get( - "speculators_config") is not None else model_type + model_type = ( + "speculators" + if config_dict.get("speculators_config") is not None + else model_type + ) if model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[model_type] @@ -133,8 +140,7 @@ def parse(self, ) else: try: - kwargs = _maybe_update_auto_config_kwargs( - kwargs, model_type=model_type) + kwargs = _maybe_update_auto_config_kwargs(kwargs, model_type=model_type) config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, @@ -144,15 +150,17 @@ def parse(self, **kwargs, ) except ValueError as e: - if (not trust_remote_code - and "requires you to execute the configuration file" - in str(e)): + if ( + not trust_remote_code + and "requires you to execute the configuration file" in str(e) + ): err_msg = ( "Failed to load the model config. If the model " "is a custom model not yet available in the " "HuggingFace transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") + "`--trust-remote-code` flag in the CLI." 
+ ) raise RuntimeError(err_msg) from e else: raise e @@ -161,20 +169,23 @@ def parse(self, class MistralConfigParser(ConfigParserBase): - - def parse(self, - model: Union[str, Path], - trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - **kwargs) -> tuple[dict, PretrainedConfig]: + def parse( + self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: # This function loads a params.json config which # should be used when loading models in mistral format config_dict = _download_mistral_config_file(model, revision) - if (max_position_embeddings := - config_dict.get("max_position_embeddings")) is None: + if ( + max_position_embeddings := config_dict.get("max_position_embeddings") + ) is None: max_position_embeddings = _maybe_retrieve_max_pos_from_hf( - model, revision, **kwargs) + model, revision, **kwargs + ) config_dict["max_position_embeddings"] = max_position_embeddings from vllm.transformers_utils.configs.mistral import adapt_config_dict @@ -183,8 +194,9 @@ def parse(self, # Mistral configs may define sliding_window as list[int]. Convert it # to int and add the layer_types list[str] to make it HF compatible - if ((sliding_window := getattr(config, "sliding_window", None)) - and isinstance(sliding_window, list)): + if (sliding_window := getattr(config, "sliding_window", None)) and isinstance( + sliding_window, list + ): pattern_repeats = config.num_hidden_layers // len(sliding_window) layer_types = sliding_window * pattern_repeats config.layer_types = [ @@ -216,44 +228,51 @@ def get_config_parser(config_format: str) -> ConfigParserBase: def register_config_parser(config_format: str): - """Register a customized vllm config parser. - When a config format is not supported by vllm, you can register a customized - config parser to support it. - Args: - config_format (str): The config parser format name. - Examples: - - >>> from vllm.transformers_utils.config import (get_config_parser, - register_config_parser) - >>> from vllm.transformers_utils.config_parser_base import ConfigParserBase - >>> - >>> @register_config_parser("custom_config_parser") - ... class CustomConfigParser(ConfigParserBase): - ... def parse(self, - ... model: Union[str, Path], - ... trust_remote_code: bool, - ... revision: Optional[str] = None, - ... code_revision: Optional[str] = None, - ... **kwargs) -> tuple[dict, PretrainedConfig]: - ... raise NotImplementedError - >>> - >>> type(get_config_parser("custom_config_parser")) - <class 'CustomConfigParser'> + When a config format is not supported by vllm, you can register a customized + config parser to support it. + Args: + config_format (str): The config parser format name. + Examples: + + >>> from vllm.transformers_utils.config import (get_config_parser, + register_config_parser) + >>> from vllm.transformers_utils.config_parser_base import ConfigParserBase + >>> + >>> @register_config_parser("custom_config_parser") + ... class CustomConfigParser(ConfigParserBase): + ... def parse( + ... self, + ... model: Union[str, Path], + ... trust_remote_code: bool, + ... revision: Optional[str] = None, + ... code_revision: Optional[str] = None, + ... **kwargs, + ... ) -> tuple[dict, PretrainedConfig]: + ... 
raise NotImplementedError + >>> + >>> type(get_config_parser("custom_config_parser")) + <class 'CustomConfigParser'> """ # noqa: E501 def _wrapper(config_parser_cls): if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER: logger.warning( "Config format `%s` is already registered, and will be " - "overwritten by the new parser class `%s`.", config_format, - config_parser_cls) + "overwritten by the new parser class `%s`.", + config_format, + config_parser_cls, + ) if not issubclass(config_parser_cls, ConfigParserBase): - raise ValueError("The config parser must be a subclass of " - "`ConfigParserBase`.") + raise ValueError( + "The config parser must be a subclass of `ConfigParserBase`." + ) _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls - logger.info("Registered config parser `%s` with config format `%s`", - config_parser_cls, config_format) + logger.info( + "Registered config parser `%s` with config format `%s`", + config_parser_cls, + config_format, + ) return config_parser_cls return _wrapper @@ -275,8 +294,9 @@ def with_retry( if attempt == max_retries - 1: logger.error("%s: %s", log_msg, e) raise - logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1, - max_retries) + logger.error( + "%s: %s, retrying %d of %d", log_msg, e, attempt + 1, max_retries + ) time.sleep(retry_delay) retry_delay *= 2 @@ -292,28 +312,27 @@ def list_repo_files( repo_type: Optional[str] = None, token: Union[str, bool, None] = None, ) -> list[str]: - def lookup_files() -> list[str]: # directly list files if model is local if (local_path := Path(repo_id)).exists(): return [ str(file.relative_to(local_path)) - for file in local_path.rglob('*') if file.is_file() + for file in local_path.rglob("*") + if file.is_file() ] # if model is remote, use hf_hub api to list files try: if envs.VLLM_USE_MODELSCOPE: - from vllm.transformers_utils.utils import ( - modelscope_list_repo_files) - return modelscope_list_repo_files(repo_id, - revision=revision, - token=os.getenv( - "MODELSCOPE_API_TOKEN", - None)) - return hf_list_repo_files(repo_id, - revision=revision, - repo_type=repo_type, - token=token) + from vllm.transformers_utils.utils import modelscope_list_repo_files + + return modelscope_list_repo_files( + repo_id, + revision=revision, + token=os.getenv("MODELSCOPE_API_TOKEN", None), + ) + return hf_list_repo_files( + repo_id, revision=revision, repo_type=repo_type, token=token + ) except huggingface_hub.errors.OfflineModeIsEnabled: # Don't raise in offline mode, # all we know is that we don't have this @@ -331,23 +350,23 @@ def file_exists( revision: Optional[str] = None, token: Union[str, bool, None] = None, ) -> bool: - file_list = list_repo_files(repo_id, - repo_type=repo_type, - revision=revision, - token=token) + file_list = list_repo_files( + repo_id, repo_type=repo_type, revision=revision, token=token + ) return file_name in file_list # In offline mode the result can be a false negative -def file_or_path_exists(model: Union[str, Path], config_name: str, - revision: Optional[str]) -> bool: +def file_or_path_exists( + model: Union[str, Path], config_name: str, revision: Optional[str] +) -> bool: if (local_path := Path(model)).exists(): return (local_path / config_name).is_file() # Offline mode support: Check if config file is cached already - cached_filepath = try_to_load_from_cache(repo_id=model, - filename=config_name, - revision=revision) + cached_filepath = try_to_load_from_cache( + repo_id=model, filename=config_name, revision=revision + ) if isinstance(cached_filepath, str): # The 
config file exists in cache- we can continue trying to load return True @@ -356,10 +375,9 @@ def file_or_path_exists(model: Union[str, Path], config_name: str, # hf_hub. This will fail in offline mode. # Call HF to check if the file exists - return file_exists(str(model), - config_name, - revision=revision, - token=_get_hf_token()) + return file_exists( + str(model), config_name, revision=revision, token=_get_hf_token() + ) def patch_rope_scaling(config: PretrainedConfig) -> None: @@ -381,7 +399,8 @@ def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None: raise ValueError( f"Found conflicts between 'rope_type={rope_type}' (modern " f"field) and 'type={rope_type_legacy}' (legacy field). " - "You should only specify one of them.") + "You should only specify one of them." + ) if "rope_type" not in rope_scaling and "type" in rope_scaling: rope_scaling["rope_type"] = rope_scaling["type"] @@ -409,8 +428,11 @@ def _uses_mrope(config: PretrainedConfig) -> bool: def uses_mrope(config: PretrainedConfig) -> bool: """Detect if the model with this config uses M-ROPE.""" - return _uses_mrope(config) or _uses_mrope( - config.get_text_config()) or thinker_uses_mrope(config) + return ( + _uses_mrope(config) + or _uses_mrope(config.get_text_config()) + or thinker_uses_mrope(config) + ) def thinker_uses_mrope(config: PretrainedConfig) -> bool: @@ -432,8 +454,7 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool: def _is_encoder_decoder(config: PretrainedConfig) -> bool: return getattr(config, "is_encoder_decoder", False) - return (_is_encoder_decoder(config) - or _is_encoder_decoder(config.get_text_config())) + return _is_encoder_decoder(config) or _is_encoder_decoder(config.get_text_config()) def is_interleaved(config: PretrainedConfig) -> bool: @@ -462,8 +483,7 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig: if hasattr(config, old_attr): if not hasattr(config, new_attr): config.update({new_attr: getattr(config, old_attr)}) - logger.debug("Remapped config attribute '%s' to '%s'", old_attr, - new_attr) + logger.debug("Remapped config attribute '%s' to '%s'", old_attr, new_attr) return config @@ -512,11 +532,11 @@ def maybe_override_with_speculators( return model, tokenizer, vllm_speculative_config # Speculators format detected - process overrides - from vllm.transformers_utils.configs.speculators.base import ( - SpeculatorsConfig) + from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig speculative_config = SpeculatorsConfig.extract_vllm_speculative_config( - config_dict=config_dict) + config_dict=config_dict + ) # Set the draft model to the speculators model speculative_config["model"] = model @@ -535,8 +555,7 @@ def get_config( code_revision: Optional[str] = None, config_format: Union[str, ConfigFormat] = "auto", hf_overrides_kw: Optional[dict[str, Any]] = None, - hf_overrides_fn: Optional[Callable[[PretrainedConfig], - PretrainedConfig]] = None, + hf_overrides_fn: Optional[Callable[[PretrainedConfig], PretrainedConfig]] = None, **kwargs, ) -> PretrainedConfig: # Separate model folder from file path for GGUF models @@ -548,12 +567,9 @@ def get_config( if config_format == "auto": try: - if is_gguf or file_or_path_exists( - model, HF_CONFIG_NAME, revision=revision): + if is_gguf or file_or_path_exists(model, HF_CONFIG_NAME, revision=revision): config_format = "hf" - elif file_or_path_exists(model, - MISTRAL_CONFIG_NAME, - revision=revision): + elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): config_format = 
"mistral" else: raise ValueError( @@ -561,7 +577,8 @@ def get_config( "With config_format 'auto', ensure your model has either " "config.json (HF format) or params.json (Mistral format). " "Otherwise please specify your_custom_config_format " - "in engine args for customized config parser.") + "in engine args for customized config parser." + ) except Exception as e: error_message = ( @@ -576,7 +593,8 @@ def get_config( "'params.json'.\n" "3. For GGUF: pass the local path of the GGUF checkpoint.\n" " Loading GGUF from a remote repo directly is not yet " - "supported.\n").format(model=model) + "supported.\n" + ).format(model=model) raise ValueError(error_message) from e @@ -591,8 +609,7 @@ def get_config( # Special architecture mapping check for GGUF models if is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - raise RuntimeError( - f"Can't get gguf config for {config.model_type}.") + raise RuntimeError(f"Can't get gguf config for {config.model_type}.") model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) @@ -602,29 +619,35 @@ def get_config( # ModelOpt 0.29.0 and before saves the quantization config in a separate # "hf_quant_config.json" in the same directory as the model config file. - if quantization_config is None \ - and file_or_path_exists(model, "hf_quant_config.json", revision): - quantization_config = get_hf_file_to_dict("hf_quant_config.json", - model, revision) + if quantization_config is None and file_or_path_exists( + model, "hf_quant_config.json", revision + ): + quantization_config = get_hf_file_to_dict( + "hf_quant_config.json", model, revision + ) if quantization_config is not None: config.quantization_config = quantization_config # auto-enable DeepGEMM UE8M0 on Hopper if model config requests it scale_fmt = quantization_config.get("scale_fmt", None) - if scale_fmt in ("ue8m0", ): + if scale_fmt in ("ue8m0",): if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0_HOPPER"): os.environ["VLLM_USE_DEEP_GEMM_E8M0_HOPPER"] = "1" logger.info_once( - ("Detected quantization_config.scale_fmt=%s; " - "enabling Hopper UE8M0."), + ( + "Detected quantization_config.scale_fmt=%s; " + "enabling Hopper UE8M0." + ), scale_fmt, ) elif not envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: logger.warning_once( - ("Model config requests UE8M0 " - "(quantization_config.scale_fmt=%s), but " - "VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; " - "Hopper UE8M0 disabled."), + ( + "Model config requests UE8M0 " + "(quantization_config.scale_fmt=%s), but " + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; " + "Hopper UE8M0 disabled." 
+ ), scale_fmt, ) @@ -643,17 +666,17 @@ def get_config( return config -def try_get_local_file(model: Union[str, Path], - file_name: str, - revision: Optional[str] = 'main') -> Optional[Path]: +def try_get_local_file( + model: Union[str, Path], file_name: str, revision: Optional[str] = "main" +) -> Optional[Path]: file_path = Path(model) / file_name if file_path.is_file(): return file_path else: try: - cached_filepath = try_to_load_from_cache(repo_id=model, - filename=file_name, - revision=revision) + cached_filepath = try_to_load_from_cache( + repo_id=model, filename=file_name, revision=revision + ) if isinstance(cached_filepath, str): return Path(cached_filepath) except ValueError: @@ -661,9 +684,9 @@ def try_get_local_file(model: Union[str, Path], return None -def get_hf_file_to_dict(file_name: str, - model: Union[str, Path], - revision: Optional[str] = 'main'): +def get_hf_file_to_dict( + file_name: str, model: Union[str, Path], revision: Optional[str] = "main" +): """ Downloads a file from the Hugging Face Hub and returns its contents as a dictionary. @@ -678,25 +701,27 @@ def get_hf_file_to_dict(file_name: str, the contents of the downloaded file. """ - file_path = try_get_local_file(model=model, - file_name=file_name, - revision=revision) + file_path = try_get_local_file(model=model, file_name=file_name, revision=revision) if file_path is None: try: hf_hub_file = hf_hub_download(model, file_name, revision=revision) except huggingface_hub.errors.OfflineModeIsEnabled: return None - except (RepositoryNotFoundError, RevisionNotFoundError, - EntryNotFoundError, LocalEntryNotFoundError) as e: + except ( + RepositoryNotFoundError, + RevisionNotFoundError, + EntryNotFoundError, + LocalEntryNotFoundError, + ) as e: logger.debug("File or repository not found in hf_hub_download", e) return None except HfHubHTTPError as e: logger.warning( - "Cannot connect to Hugging Face Hub. Skipping file " - "download for '%s':", + "Cannot connect to Hugging Face Hub. Skipping file download for '%s':", file_name, - exc_info=e) + exc_info=e, + ) return None file_path = Path(hf_hub_file) @@ -708,8 +733,7 @@ def get_hf_file_to_dict(file_name: str, @cache -def get_pooling_config(model: str, - revision: Optional[str] = 'main') -> Optional[dict]: +def get_pooling_config(model: str, revision: Optional[str] = "main") -> Optional[dict]: """ This function gets the pooling and normalize config from the model - only applies to @@ -717,20 +741,20 @@ def get_pooling_config(model: str, Args: model: The name of the Hugging Face model. - revision: The specific version of the model to use. + revision: The specific version of the model to use. Defaults to 'main'. Returns: - A dictionary containing the pooling type and whether + A dictionary containing the pooling type and whether normalization is used, or None if no pooling configuration is found. 
""" modules_file_name = "modules.json" modules_dict = None - if file_or_path_exists(model=model, - config_name=modules_file_name, - revision=revision): + if file_or_path_exists( + model=model, config_name=modules_file_name, revision=revision + ): modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) if modules_dict is None: @@ -738,20 +762,31 @@ def get_pooling_config(model: str, logger.info("Found sentence-transformers modules configuration.") - pooling = next((item for item in modules_dict - if item["type"] == "sentence_transformers.models.Pooling"), - None) + pooling = next( + ( + item + for item in modules_dict + if item["type"] == "sentence_transformers.models.Pooling" + ), + None, + ) normalize = bool( - next((item for item in modules_dict - if item["type"] == "sentence_transformers.models.Normalize"), - False)) + next( + ( + item + for item in modules_dict + if item["type"] == "sentence_transformers.models.Normalize" + ), + False, + ) + ) if pooling: - pooling_file_name = "{}/config.json".format(pooling["path"]) pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) pooling_type_name = next( - (item for item, val in pooling_dict.items() if val is True), None) + (item for item, val in pooling_dict.items() if val is True), None + ) if pooling_type_name is not None: pooling_type_name = get_pooling_config_name(pooling_type_name) @@ -772,20 +807,19 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]: if "lasttoken" in pooling_name: pooling_name = "last" - supported_pooling_types = ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN'] + supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"] pooling_type_name = pooling_name.upper() if pooling_type_name in supported_pooling_types: return pooling_type_name - raise NotImplementedError( - f"Pooling type {pooling_type_name} not supported") + raise NotImplementedError(f"Pooling type {pooling_type_name} not supported") @cache -def get_sentence_transformer_tokenizer_config(model: Union[str, Path], - revision: Optional[str] = 'main' - ): +def get_sentence_transformer_tokenizer_config( + model: Union[str, Path], revision: Optional[str] = "main" +): """ Returns the tokenization configuration dictionary for a given Sentence Transformer BERT model. 
@@ -812,9 +846,10 @@ def get_sentence_transformer_tokenizer_config(model: Union[str, Path], encoder_dict = None for config_file in sentence_transformer_config_files: - if try_get_local_file(model=model, - file_name=config_file, - revision=revision) is not None: + if ( + try_get_local_file(model=model, file_name=config_file, revision=revision) + is not None + ): encoder_dict = get_hf_file_to_dict(config_file, model, revision) if encoder_dict: break @@ -822,16 +857,15 @@ def get_sentence_transformer_tokenizer_config(model: Union[str, Path], if not encoder_dict and not Path(model).is_absolute(): try: # If model is on HuggingfaceHub, get the repo files - repo_files = list_repo_files(model, - revision=revision, - token=_get_hf_token()) + repo_files = list_repo_files( + model, revision=revision, token=_get_hf_token() + ) except Exception: repo_files = [] for config_name in sentence_transformer_config_files: if config_name in repo_files: - encoder_dict = get_hf_file_to_dict(config_name, model, - revision) + encoder_dict = get_hf_file_to_dict(config_name, model, revision) if encoder_dict: break @@ -848,34 +882,39 @@ def get_sentence_transformer_tokenizer_config(model: Union[str, Path], def maybe_register_config_serialize_by_value() -> None: """Try to register HF model configuration class to serialize by value - If trust_remote_code is set, and the model's config file specifies an - `AutoConfig` class, then the config class is typically an instance of - a custom class imported from the HF modules cache. - - Examples: - - >>> from transformers import AutoConfig - >>> klass = AutoConfig.from_pretrained('meta-llama/Meta-Llama-3-8B', trust_remote_code=True) - >>> klass.__class__ # transformers.models.llama.configuration_llama.LlamaConfig - >>> import transformers_modules # error, not initialized - >>> klass = AutoConfig.from_pretrained('deepseek-ai/DeepSeek-V2.5', trust_remote_code=True) - >>> import transformers_modules # success, initialized - >>> klass.__class__ # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config + If trust_remote_code is set, and the model's config file specifies an + `AutoConfig` class, then the config class is typically an instance of + a custom class imported from the HF modules cache. - In the DeepSeek example, the config class is an instance of a custom - class that is not serializable by default. This class will not be - importable in spawned workers, and won't exist at all on - other nodes, which breaks serialization of the config. - - In this function we tell the cloudpickle serialization library to pass - instances of these generated classes by value instead of by reference, - i.e. the class definition is serialized along with its data so that the - class module does not need to be importable on the receiving end. + Examples: - See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs - """ # noqa + >>> from transformers import AutoConfig + >>> klass = AutoConfig.from_pretrained( + ... "meta-llama/Meta-Llama-3-8B", trust_remote_code=True + ... ) + >>> klass.__class__ # transformers.models.llama.configuration_llama.LlamaConfig + >>> import transformers_modules # error, not initialized + >>> klass = AutoConfig.from_pretrained( + ... "deepseek-ai/DeepSeek-V2.5", trust_remote_code=True + ... 
) + >>> import transformers_modules # success, initialized + >>> klass.__class__ # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config + + In the DeepSeek example, the config class is an instance of a custom + class that is not serializable by default. This class will not be + importable in spawned workers, and won't exist at all on + other nodes, which breaks serialization of the config. + + In this function we tell the cloudpickle serialization library to pass + instances of these generated classes by value instead of by reference, + i.e. the class definition is serialized along with its data so that the + class module does not need to be importable on the receiving end. + + See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs + """ # noqa try: import transformers_modules + transformers_modules_available = True except ImportError: transformers_modules_available = False @@ -892,7 +931,7 @@ class module does not need to be importable on the receiving end. # serialization of VllmConfig objects that may contain custom configs # from transformers_modules def _reduce_config(config: VllmConfig): - return (pickle.loads, (cloudpickle.dumps(config), )) + return (pickle.loads, (cloudpickle.dumps(config),)) multiprocessing.reducer.register(VllmConfig, _reduce_config) @@ -902,6 +941,7 @@ def _reduce_config(config: VllmConfig): # ray vendors its own version of cloudpickle from vllm.executor.ray_utils import ray + if ray: ray.cloudpickle.register_pickle_by_value(transformers_modules) @@ -911,7 +951,8 @@ def _reduce_config(config: VllmConfig): " trust_remote_code with by-value serialization. This may" " lead to a later error. If remote code is not needed" " remove `--trust-remote-code`", - exc_info=e) + exc_info=e, + ) def get_hf_image_processor_config( @@ -926,10 +967,9 @@ def get_hf_image_processor_config( # Separate model folder from file path for GGUF models if check_gguf_file(model): model = Path(model).parent - return get_image_processor_config(model, - token=hf_token, - revision=revision, - **kwargs) + return get_image_processor_config( + model, token=hf_token, revision=revision, **kwargs + ) def get_hf_text_config(config: PretrainedConfig): @@ -984,8 +1024,9 @@ def try_get_safetensors_metadata( ) try: - return with_retry(get_safetensors_metadata_partial, - "Error retrieving safetensors") + return with_retry( + get_safetensors_metadata_partial, "Error retrieving safetensors" + ) except Exception: return None @@ -1018,9 +1059,9 @@ def get_safetensors_params_metadata( safetensors_to_check = model_path.glob("*.safetensors") full_metadata = { param_name: info - for file_path in safetensors_to_check if file_path.is_file() - for param_name, info in parse_safetensors_file_metadata( - file_path).items() + for file_path in safetensors_to_check + if file_path.is_file() + for param_name, info in parse_safetensors_file_metadata(file_path).items() } else: repo_mt = try_get_safetensors_metadata(model, revision=revision) @@ -1040,7 +1081,8 @@ def _download_mistral_config_file(model, revision) -> dict: raise ValueError( f"Failed to load mistral '{config_file_name}' config for model " f"{model}. Please check if the model is a mistral-format model " - f"and if the config file exists.") + f"and if the config file exists." 
+ ) assert isinstance(config_dict, dict) return config_dict @@ -1049,10 +1091,12 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: max_position_embeddings = 128_000 try: trust_remote_code_val = kwargs.get("trust_remote_code", False) - hf_config = get_config(model=model, - trust_remote_code=trust_remote_code_val, - revision=revision, - config_format="hf") + hf_config = get_config( + model=model, + trust_remote_code=trust_remote_code_val, + revision=revision, + config_format="hf", + ) if hf_value := hf_config.get_text_config().max_position_embeddings: max_position_embeddings = hf_value except Exception as e: @@ -1060,7 +1104,8 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: "The params.json file is missing 'max_position_embeddings'" " and could not get a value from the HF config." " Defaulting to 128000", - exc_info=e) + exc_info=e, + ) return max_position_embeddings @@ -1076,29 +1121,28 @@ def get_model_path(model: Union[str, Path], revision: Optional[str] = None): if envs.VLLM_USE_MODELSCOPE: from modelscope.hub.snapshot_download import snapshot_download + return snapshot_download(model_id=model, **common_kwargs) from huggingface_hub import snapshot_download + return snapshot_download(repo_id=model, **common_kwargs) -def get_hf_file_bytes(file_name: str, - model: Union[str, Path], - revision: Optional[str] = 'main') -> Optional[bytes]: +def get_hf_file_bytes( + file_name: str, model: Union[str, Path], revision: Optional[str] = "main" +) -> Optional[bytes]: """Get file contents from HuggingFace repository as bytes.""" - file_path = try_get_local_file(model=model, - file_name=file_name, - revision=revision) + file_path = try_get_local_file(model=model, file_name=file_name, revision=revision) if file_path is None: - hf_hub_file = hf_hub_download(model, - file_name, - revision=revision, - token=_get_hf_token()) + hf_hub_file = hf_hub_download( + model, file_name, revision=revision, token=_get_hf_token() + ) file_path = Path(hf_hub_file) if file_path is not None and file_path.is_file(): - with open(file_path, 'rb') as file: + with open(file_path, "rb") as file: return file.read() return None diff --git a/vllm/transformers_utils/config_parser_base.py b/vllm/transformers_utils/config_parser_base.py index c27177f74d4b..0e1c49b428b0 100644 --- a/vllm/transformers_utils/config_parser_base.py +++ b/vllm/transformers_utils/config_parser_base.py @@ -9,12 +9,13 @@ class ConfigParserBase(ABC): - @abstractmethod - def parse(self, - model: Union[str, Path], - trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - **kwargs) -> tuple[dict, PretrainedConfig]: + def parse( + self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: raise NotImplementedError diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 1b33b5e70e0b..72c90e073131 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -12,6 +12,7 @@ from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig + # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. 
Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. @@ -30,9 +31,11 @@ from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig -from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, - Step3VisionEncoderConfig, - Step3VLConfig) +from vllm.transformers_utils.configs.step3_vl import ( + Step3TextConfig, + Step3VisionEncoderConfig, + Step3VLConfig, +) from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py index a789b93b5edf..1707e15285c8 100644 --- a/vllm/transformers_utils/configs/arctic.py +++ b/vllm/transformers_utils/configs/arctic.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # coding=utf-8 # Copied from # https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/configuration_arctic.py -""" Arctic model configuration""" +"""Arctic model configuration""" from dataclasses import asdict, dataclass from typing import Any diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py index 176d2b8f63fe..1d795b55c8bc 100644 --- a/vllm/transformers_utils/configs/chatglm.py +++ b/vllm/transformers_utils/configs/chatglm.py @@ -13,33 +13,35 @@ class ChatGLMConfig(PretrainedConfig): "n_head_kv": "multi_query_group_num", } - def __init__(self, - num_layers=28, - padded_vocab_size=65024, - hidden_size=4096, - ffn_hidden_size=13696, - kv_channels=128, - num_attention_heads=32, - seq_length=2048, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - rmsnorm=True, - apply_residual_connection_post_layernorm=False, - post_layer_norm=True, - add_bias_linear=False, - add_qkv_bias=False, - interleaved_qkv=False, - bias_dropout_fusion=True, - multi_query_attention=False, - multi_query_group_num=1, - apply_query_key_layer_scaling=True, - attention_softmax_in_fp32=True, - fp32_residual_connection=False, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs): + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs, + ): self.num_layers = num_layers self.vocab_size = padded_vocab_size self.padded_vocab_size = padded_vocab_size @@ -55,7 +57,8 @@ def __init__(self, self.layernorm_epsilon = layernorm_epsilon self.rmsnorm = rmsnorm self.apply_residual_connection_post_layernorm = ( - apply_residual_connection_post_layernorm) + apply_residual_connection_post_layernorm + ) self.post_layer_norm = post_layer_norm self.add_bias_linear = add_bias_linear self.add_qkv_bias = add_qkv_bias diff --git a/vllm/transformers_utils/configs/deepseek_v3.py 
b/vllm/transformers_utils/configs/deepseek_v3.py index 4b26cdfd94b5..91fbed79dd02 100644 --- a/vllm/transformers_utils/configs/deepseek_v3.py +++ b/vllm/transformers_utils/configs/deepseek_v3.py @@ -7,7 +7,6 @@ class DeepseekV3Config(PretrainedConfig): - model_type = "deepseek_v3" keys_to_ignore_at_inference = ["past_key_values"] @@ -30,14 +29,14 @@ def __init__( qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, - topk_method='noaux_tc', + topk_method="noaux_tc", n_group=8, topk_group=4, num_experts_per_tok=8, moe_layer_freq=1, first_k_dense_replace=3, norm_topk_prob=True, - scoring_func='sigmoid', + scoring_func="sigmoid", hidden_act="silu", max_position_embeddings=4096, initializer_range=0.02, diff --git a/vllm/transformers_utils/configs/deepseek_vl2.py b/vllm/transformers_utils/configs/deepseek_vl2.py index 957d63831841..7abfe6229842 100644 --- a/vllm/transformers_utils/configs/deepseek_vl2.py +++ b/vllm/transformers_utils/configs/deepseek_vl2.py @@ -25,20 +25,22 @@ class VisionEncoderConfig(PretrainedConfig): deterministic: bool = False num_recomputing_layers: int = 0 - def __init__(self, - model_name: str = "vit_so400m_patch14_siglip_384.webli", - image_size: int = 384, - patch_size: int = 16, - width: int = 1024, - layers: int = 24, - heads: int = 16, - mlp_ratio: int = 4, - global_pool: str = "map", - ignore_head: bool = True, - class_token: bool = False, - num_classes: int = 0, - use_checkpoint: bool = False, - **kwargs): + def __init__( + self, + model_name: str = "vit_so400m_patch14_siglip_384.webli", + image_size: int = 384, + patch_size: int = 16, + width: int = 1024, + layers: int = 24, + heads: int = 16, + mlp_ratio: int = 4, + global_pool: str = "map", + ignore_head: bool = True, + class_token: bool = False, + num_classes: int = 0, + use_checkpoint: bool = False, + **kwargs, + ): self.model_name = model_name self.image_size = image_size self.patch_size = patch_size @@ -65,14 +67,16 @@ class MlpProjectorConfig(PretrainedConfig): downsample_ratio: int = 2 token_pooling: bool = False - def __init__(self, - projector_type: str = "downsample_mlp_gelu", - input_dim: int = 1152, - n_embed: int = 2048, - depth: int = 2, - mlp_ratio: int = 1, - downsample_ratio: int = 2, - **kwargs): + def __init__( + self, + projector_type: str = "downsample_mlp_gelu", + input_dim: int = 1152, + n_embed: int = 2048, + depth: int = 2, + mlp_ratio: int = 1, + downsample_ratio: int = 2, + **kwargs, + ): self.projector_type = projector_type self.input_dim = input_dim self.n_embed = n_embed @@ -84,7 +88,6 @@ def __init__(self, class DeepseekV2Config(PretrainedConfig): - model_type = "deepseek_v2" keys_to_ignore_at_inference = ["past_key_values"] @@ -106,14 +109,14 @@ def __init__( qk_rope_head_dim=64, v_head_dim=128, qk_nope_head_dim=128, - topk_method='gready', + topk_method="gready", n_group=None, topk_group=None, num_experts_per_tok=None, moe_layer_freq=1, first_k_dense_replace=0, norm_topk_prob=False, - scoring_func='softmax', + scoring_func="softmax", aux_loss_alpha=0.001, seq_aux=True, hidden_act="silu", @@ -191,14 +194,15 @@ class DeepseekVLV2Config(PretrainedConfig): tile_tag: str = "2D" global_view_pos: str = "head" - candidate_resolutions: tuple[tuple[int, int]] = ((384, 384), ) - - def __init__(self, - tile_tag: str = "tile_tag", - global_view_pos: str = "head", - candidate_resolutions: tuple[tuple[int, - int]] = ((384, 384), ), - **kwargs): + candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),) + + def __init__( + self, + tile_tag: str = "tile_tag", + global_view_pos: str = 
"head", + candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),), + **kwargs, + ): super().__init__(**kwargs) vision_config = kwargs.get("vision_config", {}) diff --git a/vllm/transformers_utils/configs/dotsocr.py b/vllm/transformers_utils/configs/dotsocr.py index 6bb3c12d9c7e..446693b9a32e 100644 --- a/vllm/transformers_utils/configs/dotsocr.py +++ b/vllm/transformers_utils/configs/dotsocr.py @@ -53,12 +53,14 @@ def __init__( class DotsOCRConfig(Qwen2Config): model_type = "dots_ocr" - def __init__(self, - image_token_id=151665, - video_token_id=151656, - vision_config: Optional[dict] = None, - *args, - **kwargs): + def __init__( + self, + image_token_id=151665, + video_token_id=151656, + vision_config: Optional[dict] = None, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) self.image_token_id = image_token_id self.video_token_id = video_token_id diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 444ed70de3d0..6e18513d1234 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -12,12 +12,13 @@ class EAGLEConfig(PretrainedConfig): model_type = "eagle" - def __init__(self, - model: Union[PretrainedConfig, dict, None] = None, - truncated_vocab_size: Optional[int] = None, - method: Optional[str] = 'eagle', - **kwargs): - + def __init__( + self, + model: Union[PretrainedConfig, dict, None] = None, + truncated_vocab_size: Optional[int] = None, + method: Optional[str] = "eagle", + **kwargs, + ): model_config: Union[PretrainedConfig, DeepseekV2Config, None] if isinstance(model, dict): archs = model.get("architectures", []) @@ -31,8 +32,7 @@ def __init__(self, model_config = model for k, v in kwargs.items(): - if k != "architectures" and k != "model_type" and hasattr( - model_config, k): + if k != "architectures" and k != "model_type" and hasattr(model_config, k): setattr(model_config, k, v) self.model = model_config @@ -40,31 +40,39 @@ def __init__(self, if self.model is None: self.truncated_vocab_size = None else: - self.truncated_vocab_size = self.model.vocab_size if \ - truncated_vocab_size is None else truncated_vocab_size + self.truncated_vocab_size = ( + self.model.vocab_size + if truncated_vocab_size is None + else truncated_vocab_size + ) # Eagle model name should follow naming convention of # LlamaForCausalLM -> EagleLlamaForCausalLM # LlamaForCausalLM -> Eagle3LlamaForCausalLM # LlamaForCausalLMEagle3 -> LlamaForCausalLMEagle3 if method == "eagle": - assert self.model is not None, \ + assert self.model is not None, ( "model should not be None when method is eagle" + ) kwargs["architectures"] = [ - f"Eagle{arch}" if not arch.startswith("Eagle") \ - else arch for arch in self.model.architectures + f"Eagle{arch}" if not arch.startswith("Eagle") else arch + for arch in self.model.architectures ] elif method == "eagle3": - assert self.model is not None, \ + assert self.model is not None, ( "model should not be None when method is eagle3" + ) kwargs["architectures"] = [ - arch if arch.startswith("Eagle3") or arch.endswith("Eagle3") - else f"Eagle3{arch}" for arch in self.model.architectures + arch + if arch.startswith("Eagle3") or arch.endswith("Eagle3") + else f"Eagle3{arch}" + for arch in self.model.architectures ] else: - raise ValueError(f"Invalid method {method}. " - "Supported methods are eagle and eagle3.") + raise ValueError( + f"Invalid method {method}. Supported methods are eagle and eagle3." 
+ ) super().__init__(**kwargs) @@ -80,5 +88,6 @@ def from_pretrained( **kwargs, ) -> "EAGLEConfig": config_dict, kwargs = cls.get_config_dict( - pretrained_model_name_or_path, **kwargs) + pretrained_model_name_or_path, **kwargs + ) return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/configs/falcon.py b/vllm/transformers_utils/configs/falcon.py index 2f5400463d91..c646d241d4eb 100644 --- a/vllm/transformers_utils/configs/falcon.py +++ b/vllm/transformers_utils/configs/falcon.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Falcon configuration""" + from transformers.configuration_utils import PretrainedConfig @@ -77,9 +78,7 @@ def __init__( # Hack for falcon-40b self.new_decoder_architecture = True - super().__init__(bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs) + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @property def head_dim(self): diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 3f50638f16b5..6b581bf18775 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -75,7 +75,7 @@ class JAISConfig(PretrainedConfig): Whether or not the model should return the last key/values attentions (not used by all models). scale_attn_by_inverse_layer_idx (`bool`, *optional*, default `True`): - Whether to additionally scale attention weights + Whether to additionally scale attention weights by `1 / layer_idx + 1`. reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): Whether to scale keys (K) prior to computing attention @@ -209,29 +209,35 @@ def _alibi_scaling_validation(self): if self.alibi_scaling is None: return - if (not isinstance(self.alibi_scaling, dict) - or len(self.alibi_scaling) != 2): + if not isinstance(self.alibi_scaling, dict) or len(self.alibi_scaling) != 2: raise ValueError( "`alibi_scaling` must be a dictionary with two fields, " "`type` and `factor` or `type` and `train_seq_len`, " - f"got {self.alibi_scaling}") + f"got {self.alibi_scaling}" + ) alibi_scaling_type = self.alibi_scaling.get("type", None) alibi_scaling_factor = self.alibi_scaling.get("factor", None) alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) if alibi_scaling_type is None or alibi_scaling_type != "linear": - raise ValueError(f"`alibi_scaling`'s type field must be 'linear', " - f"got {alibi_scaling_type}") - if (alibi_scaling_factor is not None - and not isinstance(alibi_scaling_factor, float) - or (alibi_scaling_factor is not None - and alibi_scaling_factor <= 1.0)): + raise ValueError( + f"`alibi_scaling`'s type field must be 'linear', " + f"got {alibi_scaling_type}" + ) + if ( + alibi_scaling_factor is not None + and not isinstance(alibi_scaling_factor, float) + or (alibi_scaling_factor is not None and alibi_scaling_factor <= 1.0) + ): raise ValueError( f"`alibi_scaling`'s factor field must be a float > 1.0, " - f"got {alibi_scaling_factor}") - if (alibi_dynamic_scaling is not None - and not isinstance(alibi_dynamic_scaling, int) - or (alibi_dynamic_scaling is not None - and alibi_dynamic_scaling <= 1)): + f"got {alibi_scaling_factor}" + ) + if ( + alibi_dynamic_scaling is not None + and not isinstance(alibi_dynamic_scaling, int) + or (alibi_dynamic_scaling is not None and alibi_dynamic_scaling <= 1) + ): raise ValueError( f"`alibi_scaling`'s `train_seq_len` field must be an " - f"integer > 1, got 
{alibi_dynamic_scaling}") + f"integer > 1, got {alibi_dynamic_scaling}" + ) diff --git a/vllm/transformers_utils/configs/kimi_vl.py b/vllm/transformers_utils/configs/kimi_vl.py index ae8dac0f381d..89a8878465b6 100644 --- a/vllm/transformers_utils/configs/kimi_vl.py +++ b/vllm/transformers_utils/configs/kimi_vl.py @@ -12,13 +12,15 @@ class KimiVLConfig(PretrainedConfig): model_type = "kimi_vl" - def __init__(self, - vision_config: Optional[Union[dict, MoonViTConfig]] = None, - text_config: Optional[Union[dict, DeepseekV2Config]] = None, - ignore_index: int = -100, - media_placeholder_token_id: int = 163605, - pad_token_id: int = 0, - **kwargs): + def __init__( + self, + vision_config: Optional[Union[dict, MoonViTConfig]] = None, + text_config: Optional[Union[dict, DeepseekV2Config]] = None, + ignore_index: int = -100, + media_placeholder_token_id: int = 163605, + pad_token_id: int = 0, + **kwargs, + ): if vision_config is None: vision_config = MoonViTConfig() elif isinstance(vision_config, dict): diff --git a/vllm/transformers_utils/configs/medusa.py b/vllm/transformers_utils/configs/medusa.py index 9ba52956a8e8..7dcfd0cf26ae 100644 --- a/vllm/transformers_utils/configs/medusa.py +++ b/vllm/transformers_utils/configs/medusa.py @@ -10,16 +10,17 @@ class MedusaConfig(PretrainedConfig): model_type = "medusa" - def __init__(self, - hidden_size: int = 4096, - vocab_size: int = 32001, - num_heads: int = 5, - num_hidden_layers: int = 1, - max_paths: int = 64, - topk: int = 10, - truncated_vocab_size: Optional[int] = None, - **kwargs): - + def __init__( + self, + hidden_size: int = 4096, + vocab_size: int = 32001, + num_heads: int = 5, + num_hidden_layers: int = 1, + max_paths: int = 64, + topk: int = 10, + truncated_vocab_size: Optional[int] = None, + **kwargs, + ): self.hidden_size = hidden_size self.vocab_size = vocab_size self.num_heads = num_heads @@ -27,8 +28,9 @@ def __init__(self, self.max_paths = max_paths self.topk = topk self.max_seq_len = int(2**20) - self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\ - else truncated_vocab_size + self.truncated_vocab_size = ( + vocab_size if truncated_vocab_size is None else truncated_vocab_size + ) if "architectures" not in kwargs: kwargs["architectures"] = ["MedusaModel"] @@ -41,12 +43,13 @@ def from_pretrained( **kwargs, ) -> "MedusaConfig": config_dict, kwargs = cls.get_config_dict( - pretrained_model_name_or_path, **kwargs) + pretrained_model_name_or_path, **kwargs + ) for k in list(config_dict.keys()): - if 'num' in k: - if 'heads' in k: + if "num" in k: + if "heads" in k: config_dict["num_heads"] = config_dict.pop(k) - elif 'layers' in k: + elif "layers" in k: config_dict["num_hidden_layers"] = config_dict.pop(k) return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/configs/midashenglm.py b/vllm/transformers_utils/configs/midashenglm.py index 1c23202e23c8..5c9e72be8ebf 100644 --- a/vllm/transformers_utils/configs/midashenglm.py +++ b/vllm/transformers_utils/configs/midashenglm.py @@ -25,7 +25,8 @@ from transformers import PretrainedConfig from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( - Qwen2_5OmniTextConfig) + Qwen2_5OmniTextConfig, +) class DashengConfig(PretrainedConfig): @@ -91,11 +92,13 @@ def __init__( audio_token_id: Optional[int] = None, **kwargs, ): - self.audio_encoder_config = DashengConfig( - **(audio_encoder_config or {})) + self.audio_encoder_config = DashengConfig(**(audio_encoder_config or {})) self.subsample_factor = subsample_factor - self.text_config = 
(Qwen2_5OmniTextConfig( - **text_config) if text_config else Qwen2_5OmniTextConfig()) + self.text_config = ( + Qwen2_5OmniTextConfig(**text_config) + if text_config + else Qwen2_5OmniTextConfig() + ) self.text_config.rope_scaling = None # uses_mrope is false self.audio_token_id = audio_token_id super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index 5d9206e18832..d5bf79e01f95 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -9,8 +9,7 @@ logger = init_logger(__name__) -def adapt_config_dict(config_dict: dict[str, Any], - **kwargs) -> PretrainedConfig: +def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig: config_dict.update(kwargs) config_dict = _remap_general_mistral_args(config_dict) @@ -25,15 +24,16 @@ def adapt_config_dict(config_dict: dict[str, Any], if bool(config_dict.get("yarn")): config_dict = _remap_mistral_yarn_args(config_dict) - is_vision = ((config_dict.get("multimodal") - or {}).get("vision_encoder_args") - or config_dict.get("vision_encoder")) + is_vision = (config_dict.get("multimodal") or {}).get( + "vision_encoder_args" + ) or config_dict.get("vision_encoder") is_audio = bool( - ((config_dict.get("multimodal") or {}).get("whisper_model_args") - or {}).get("encoder_args")) + ((config_dict.get("multimodal") or {}).get("whisper_model_args") or {}).get( + "encoder_args" + ) + ) - assert not (is_vision and is_audio), \ - "Vision and audio are mutually exclusive" + assert not (is_vision and is_audio), "Vision and audio are mutually exclusive" if is_vision: config_dict = _remap_mistral_vision_args(config_dict) @@ -77,7 +77,7 @@ def _remap_mistral_yarn_args(config: dict) -> dict: config["rope_scaling"] = { "rope_type": "yarn", "mscale_all_dim": 1, # We hardcoded this to 1 - **renamed_yarn_config + **renamed_yarn_config, } return config @@ -105,8 +105,7 @@ def _remap_general_mistral_args(config: dict) -> dict: if key in config: config[new_key] = config.pop(key) - for new_key, (key, - default_value) in top_level_mapping_with_default.items(): + for new_key, (key, default_value) in top_level_mapping_with_default.items(): config[new_key] = config.pop(key, default_value) return config @@ -116,16 +115,12 @@ def _remap_mistral_quantization_args(config: dict) -> dict: quantization = config.get("quantization", {}) if quantization.get("qformat_weight") == "fp8_e4m3": # This maps to the FP8 static per-tensor quantization scheme - quantization_config = { - "quant_method": "fp8", - "activation_scheme": "static" - } + quantization_config = {"quant_method": "fp8", "activation_scheme": "static"} elif quantization.get("quant_method") == "compressed-tensors": # Pass through the quantization config to compressed-tensors quantization_config = quantization else: - raise ValueError( - f"Found unknown quantization='{quantization}' in config") + raise ValueError(f"Found unknown quantization='{quantization}' in config") config["quantization_config"] = quantization_config @@ -139,13 +134,10 @@ def _remap_mistral_audio_args(config: dict) -> dict: quant_config = config.get("quantization_config") config = { - "model_type": - "whixtral", + "model_type": "whixtral", "architectures": ["VoxtralForConditionalGeneration"], - "text_config": - PretrainedConfig.from_dict(config), - "audio_config": - WhisperConfig( + "text_config": PretrainedConfig.from_dict(config), + "audio_config": WhisperConfig( 
num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"], window_size=encoder_args["audio_encoding_args"]["window_size"], sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"], @@ -158,7 +150,7 @@ def _remap_mistral_audio_args(config: dict) -> dict: vocab_size=encoder_args["vocab_size"], max_source_positions=encoder_args["max_source_positions"], is_encoder_decoder=False, # Override WhisperConfig default - ) + ), } if quant_config: config["quantization_config"] = quant_config diff --git a/vllm/transformers_utils/configs/mlp_speculator.py b/vllm/transformers_utils/configs/mlp_speculator.py index 2fa284e5c9e8..45d76a8fdf26 100644 --- a/vllm/transformers_utils/configs/mlp_speculator.py +++ b/vllm/transformers_utils/configs/mlp_speculator.py @@ -13,16 +13,18 @@ class MLPSpeculatorConfig(PretrainedConfig): "hidden_size": "emb_dim", } - def __init__(self, - vocab_size: int = 32000, - emb_dim: int = 4096, - inner_dim: int = 0, - n_predict: int = 3, - top_k_tokens_per_head: Optional[list[int]] = None, - n_candidates: int = 5, - tie_weights: bool = False, - scale_input: bool = False, - **kwargs): + def __init__( + self, + vocab_size: int = 32000, + emb_dim: int = 4096, + inner_dim: int = 0, + n_predict: int = 3, + top_k_tokens_per_head: Optional[list[int]] = None, + n_candidates: int = 5, + tie_weights: bool = False, + scale_input: bool = False, + **kwargs, + ): """ Initialize an MLPSpeculatorConfig diff --git a/vllm/transformers_utils/configs/moonvit.py b/vllm/transformers_utils/configs/moonvit.py index a6f712f3d600..6e9b2897f4cc 100644 --- a/vllm/transformers_utils/configs/moonvit.py +++ b/vllm/transformers_utils/configs/moonvit.py @@ -8,16 +8,16 @@ class MoonViTConfig(PretrainedConfig): model_type = "moonvit" def __init__( - self, - patch_size: int = 14, - init_pos_emb_height: int = 64, - init_pos_emb_width: int = 64, - num_attention_heads: int = 16, - num_hidden_layers: int = 27, - hidden_size: int = 1152, - intermediate_size: int = 4304, - merge_kernel_size: tuple[int, int] = (2, 2), - **kwargs, + self, + patch_size: int = 14, + init_pos_emb_height: int = 64, + init_pos_emb_width: int = 64, + num_attention_heads: int = 16, + num_hidden_layers: int = 27, + hidden_size: int = 1152, + intermediate_size: int = 4304, + merge_kernel_size: tuple[int, int] = (2, 2), + **kwargs, ): super().__init__(**kwargs) self.patch_size = patch_size diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py index 090fefa14203..60eed549561f 100644 --- a/vllm/transformers_utils/configs/nemotron.py +++ b/vllm/transformers_utils/configs/nemotron.py @@ -62,7 +62,7 @@ class NemotronConfig(PretrainedConfig): (MQA) otherwise GQA is used. When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original - heads within that group. For more details checkout + heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. 
hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`): @@ -147,8 +147,9 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads head_dim = head_dim or kwargs.get("kv_channels") - self.head_dim = head_dim if head_dim is not None else ( - hidden_size // num_attention_heads) + self.head_dim = ( + head_dim if head_dim is not None else (hidden_size // num_attention_heads) + ) # for backward compatibility if num_key_value_heads is None: @@ -162,8 +163,11 @@ def __init__( self.rope_theta = rope_theta self.rope_scaling = rope_scaling # for backward compatibility - partial_rotary_factor = kwargs.get("rope_percent") or kwargs.get( - "rope_percentage") or partial_rotary_factor + partial_rotary_factor = ( + kwargs.get("rope_percent") + or kwargs.get("rope_percentage") + or partial_rotary_factor + ) self.partial_rotary_factor = partial_rotary_factor self._rope_scaling_validation() self.attention_bias = attention_bias @@ -185,21 +189,24 @@ def _rope_scaling_validation(self): if self.rope_scaling is None: return - if not isinstance(self.rope_scaling, dict) or len( - self.rope_scaling) != 2: + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( "`rope_scaling` must be a dictionary with two fields, " - f"`type` and `factor`, got {self.rope_scaling}") + f"`type` and `factor`, got {self.rope_scaling}" + ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in [ - "linear", "dynamic" - ]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( "`rope_scaling`'s type field must be one of ['linear', " - f"'dynamic'], got {rope_scaling_type}") - if rope_scaling_factor is None or not isinstance( - rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + f"'dynamic'], got {rope_scaling_type}" + ) + if ( + rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0 + ): raise ValueError( "`rope_scaling`'s factor field must be a float > 1, got " - f"{rope_scaling_factor}") \ No newline at end of file + f"{rope_scaling_factor}" + ) diff --git a/vllm/transformers_utils/configs/nemotron_h.py b/vllm/transformers_utils/configs/nemotron_h.py index 581bed5716c1..c8b6784d6a8e 100644 --- a/vllm/transformers_utils/configs/nemotron_h.py +++ b/vllm/transformers_utils/configs/nemotron_h.py @@ -203,11 +203,11 @@ def __init__( # Validate hybrid_override_pattern # M: Mamba2, *: Attention, -: MLP assert len(self.hybrid_override_pattern) == self.num_hidden_layers, ( - "hybrid_override_pattern must have same length as " - "num_hidden_layers") + "hybrid_override_pattern must have same length as num_hidden_layers" + ) assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), ( - "hybrid_override_pattern must only contain characters " - "'M', '*', or '-'") + "hybrid_override_pattern must only contain characters 'M', '*', or '-'" + ) # for backward compatibility if num_key_value_heads is None: @@ -253,7 +253,10 @@ def __init__( @property def layers_block_type(self): return [ - "mamba" if self.hybrid_override_pattern[i] == "M" else - "attention" if self.hybrid_override_pattern[i] == "*" else "mlp" + "mamba" + if self.hybrid_override_pattern[i] == "M" + else "attention" + if self.hybrid_override_pattern[i] == "*" + else "mlp" for i in range(self.num_hidden_layers) ] diff --git 
a/vllm/transformers_utils/configs/nemotron_vl.py b/vllm/transformers_utils/configs/nemotron_vl.py index 6a642f26b82a..6f98fbafbed5 100644 --- a/vllm/transformers_utils/configs/nemotron_vl.py +++ b/vllm/transformers_utils/configs/nemotron_vl.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # Adapted from # https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/configuration.py @@ -16,7 +15,7 @@ class Nemotron_Nano_VL_Config(PretrainedConfig): - model_type = 'Llama_Nemotron_Nano_VL' + model_type = "Llama_Nemotron_Nano_VL" is_composition = True def __init__( @@ -26,17 +25,22 @@ def __init__( force_image_size=None, downsample_ratio=0.5, template=None, - ps_version='v1', + ps_version="v1", image_tag_type="internvl", projector_hidden_size=4096, vit_hidden_size=1280, - **kwargs + **kwargs, ): super().__init__(**kwargs) if vision_config is not None: - assert "auto_map" in vision_config and "AutoConfig" in vision_config["auto_map"] - vision_auto_config = get_class_from_dynamic_module(*vision_config["auto_map"]["AutoConfig"].split("--")[::-1]) + assert ( + "auto_map" in vision_config + and "AutoConfig" in vision_config["auto_map"] + ) + vision_auto_config = get_class_from_dynamic_module( + *vision_config["auto_map"]["AutoConfig"].split("--")[::-1] + ) self.vision_config = vision_auto_config(**vision_config) else: self.vision_config = PretrainedConfig() @@ -51,6 +55,6 @@ def __init__( self.downsample_ratio = downsample_ratio self.template = template # TODO move out of here and into the tokenizer self.ps_version = ps_version # Pixel shuffle version - self.image_tag_type = image_tag_type # TODO: into the tokenizer too? + self.image_tag_type = image_tag_type # TODO: into the tokenizer too? 
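# Illustrative sketch of the `auto_map` resolution used in
# Nemotron_Nano_VL_Config above: an entry such as
# "org/custom-vit--configuration_vit.CustomViTConfig" (hypothetical repo and
# class names) is split on "--" and reversed so it lines up with
# get_class_from_dynamic_module(class_reference, repo_id).
auto_map_entry = "org/custom-vit--configuration_vit.CustomViTConfig"
class_ref, repo_id = auto_map_entry.split("--")[::-1]
assert class_ref == "configuration_vit.CustomViTConfig"
assert repo_id == "org/custom-vit"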
self.projector_hidden_size = projector_hidden_size self.vit_hidden_size = vit_hidden_size diff --git a/vllm/transformers_utils/configs/olmo3.py b/vllm/transformers_utils/configs/olmo3.py index 874507db43a7..f5a9a7cd36bd 100644 --- a/vllm/transformers_utils/configs/olmo3.py +++ b/vllm/transformers_utils/configs/olmo3.py @@ -5,7 +5,6 @@ class Olmo3Config(PretrainedConfig): - model_type = "olmo3" keys_to_ignore_at_inference = ["past_key_values"] diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py index 550f5e15dbcc..404fa700a26c 100644 --- a/vllm/transformers_utils/configs/ovis.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # adapted from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py # and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py @@ -70,34 +69,37 @@ def __init__( # Visual Tokenizer Configuration # ---------------------------------------------------------------------- class BaseVisualTokenizerConfig(PretrainedConfig): - - def __init__(self, - vocab_size=16384, - tokenize_function="softmax", - tau=1.0, - depths=None, - drop_cls_token=False, - backbone_config: Optional[Union[PretrainedConfig, - dict]] = None, - hidden_stride: int = 1, - **kwargs): + def __init__( + self, + vocab_size=16384, + tokenize_function="softmax", + tau=1.0, + depths=None, + drop_cls_token=False, + backbone_config: Optional[Union[PretrainedConfig, dict]] = None, + hidden_stride: int = 1, + **kwargs, + ): super().__init__(**kwargs) self.vocab_size = vocab_size self.tokenize_function = tokenize_function self.tau = tau if isinstance(depths, str): - depths = [int(x) for x in depths.split('|')] + depths = [int(x) for x in depths.split("|")] self.depths = depths self.backbone_kwargs = dict[str, Any]() self.drop_cls_token = drop_cls_token if backbone_config is not None: - assert isinstance(backbone_config, (PretrainedConfig, dict)), \ + assert isinstance(backbone_config, (PretrainedConfig, dict)), ( f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type" + ) if not isinstance(backbone_config, PretrainedConfig): - model_type = backbone_config['model_type'] + model_type = backbone_config["model_type"] if model_type != "aimv2": - backbone_config.pop('model_type') - backbone_config = AutoConfig.for_model(model_type, **backbone_config) + backbone_config.pop("model_type") + backbone_config = AutoConfig.for_model( + model_type, **backbone_config + ) else: backbone_config = AIMv2Config(**backbone_config) self.backbone_config = backbone_config @@ -113,7 +115,7 @@ def __init__(self, **kwargs): self.drop_cls_token = False if self.depths: assert len(self.depths) == 1 - self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + self.backbone_kwargs["num_hidden_layers"] = self.depths[0] class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig): @@ -125,7 +127,7 @@ def __init__(self, **kwargs): self.drop_cls_token = False if self.depths: assert len(self.depths) == 1 - self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + self.backbone_kwargs["num_hidden_layers"] = self.depths[0] AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig) @@ -138,35 +140,39 @@ def __init__(self, **kwargs): class OvisConfig(PretrainedConfig): model_type = "ovis" - def __init__(self, - llm_config: 
Optional[Union[PretrainedConfig, dict]] = None, - visual_tokenizer_config: Optional[Union[PretrainedConfig, - dict]] = None, - multimodal_max_length=8192, - hidden_size=None, - conversation_formatter_class=None, - llm_attn_implementation=None, - disable_tie_weight=False, - **kwargs): + def __init__( + self, + llm_config: Optional[Union[PretrainedConfig, dict]] = None, + visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None, + multimodal_max_length=8192, + hidden_size=None, + conversation_formatter_class=None, + llm_attn_implementation=None, + disable_tie_weight=False, + **kwargs, + ): super().__init__(**kwargs) if llm_config is not None: - assert isinstance(llm_config, (PretrainedConfig, dict)), \ + assert isinstance(llm_config, (PretrainedConfig, dict)), ( f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type" + ) if not isinstance(llm_config, PretrainedConfig): - model_type = llm_config['model_type'] - llm_config.pop('model_type') + model_type = llm_config["model_type"] + llm_config.pop("model_type") llm_config = AutoConfig.for_model(model_type, **llm_config) # map llm_config to text_config self.text_config = llm_config if visual_tokenizer_config is not None: - assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \ + assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), ( f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type" + ) if not isinstance(visual_tokenizer_config, PretrainedConfig): - model_type = visual_tokenizer_config['model_type'] - visual_tokenizer_config.pop('model_type') + model_type = visual_tokenizer_config["model_type"] + visual_tokenizer_config.pop("model_type") visual_tokenizer_config = AutoConfig.for_model( - model_type, **visual_tokenizer_config) + model_type, **visual_tokenizer_config + ) self.visual_tokenizer_config = visual_tokenizer_config self.multimodal_max_length = multimodal_max_length diff --git a/vllm/transformers_utils/configs/qwen3_next.py b/vllm/transformers_utils/configs/qwen3_next.py index c7af26acd1b9..21750bde2f87 100644 --- a/vllm/transformers_utils/configs/qwen3_next.py +++ b/vllm/transformers_utils/configs/qwen3_next.py @@ -16,8 +16,7 @@ # limitations under the License. 
"""Qwen3-Next model configuration""" -from transformers.configuration_utils import (PretrainedConfig, - layer_type_validation) +from transformers.configuration_utils import PretrainedConfig, layer_type_validation from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging diff --git a/vllm/transformers_utils/configs/radio.py b/vllm/transformers_utils/configs/radio.py index e1d96294d6ad..f13598034bae 100644 --- a/vllm/transformers_utils/configs/radio.py +++ b/vllm/transformers_utils/configs/radio.py @@ -81,11 +81,11 @@ def __init__( self.initializer_factor = initializer_factor self.hidden_act = hidden_act self.max_img_size = max_img_size - self.norm_mean = list(norm_mean) if isinstance(norm_mean, - (tuple, - list)) else norm_mean - self.norm_std = list(norm_std) if isinstance(norm_std, - (tuple, - list)) else norm_std + self.norm_mean = ( + list(norm_mean) if isinstance(norm_mean, (tuple, list)) else norm_mean + ) + self.norm_std = ( + list(norm_std) if isinstance(norm_std, (tuple, list)) else norm_std + ) self.reg_tokens = reg_tokens super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/speculators/algos.py b/vllm/transformers_utils/configs/speculators/algos.py index efc87b6bcf26..88bce3d4f79e 100644 --- a/vllm/transformers_utils/configs/speculators/algos.py +++ b/vllm/transformers_utils/configs/speculators/algos.py @@ -5,7 +5,6 @@ def register_speculator(name): - def decorator(fn): SUPPORTED_SPECULATORS_TYPES[name] = fn return fn @@ -17,16 +16,23 @@ def decorator(fn): def update_eagle3(config_dict: dict, vllm_config: dict) -> None: """ Apply Eagle-3 specific configuration transformations. - + Eagle-3 specific fields: - draft_vocab_size: Size of the draft model's vocabulary - target_hidden_size: Hidden size of the target model - norm_before_residual: Whether to apply norm before residual connection + - eagle_aux_hidden_state_layer_ids: List of layer indices from the base + model to use as auxiliary inputs for the Eagle3 drafter. These layers + provide intermediate hidden states that help the drafter make better + predictions. This is the standard field used in Eagle3 checkpoints. 
""" vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size") if config_dict.get("target_hidden_size") is not None: vllm_config["target_hidden_size"] = config_dict["target_hidden_size"] - vllm_config["norm_before_residual"] = config_dict.get( - "norm_before_residual", True) + vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True) vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"] + if config_dict.get("eagle_aux_hidden_state_layer_ids"): + vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ + "eagle_aux_hidden_state_layer_ids" + ] diff --git a/vllm/transformers_utils/configs/speculators/base.py b/vllm/transformers_utils/configs/speculators/base.py index 53128b4eecb0..1c415a43360e 100644 --- a/vllm/transformers_utils/configs/speculators/base.py +++ b/vllm/transformers_utils/configs/speculators/base.py @@ -6,7 +6,8 @@ from transformers import PretrainedConfig from vllm.transformers_utils.configs.speculators.algos import ( - SUPPORTED_SPECULATORS_TYPES) + SUPPORTED_SPECULATORS_TYPES, +) __all__ = ["SpeculatorsConfig"] @@ -21,27 +22,27 @@ def from_pretrained( **kwargs, ) -> "SpeculatorsConfig": """Load speculators Eagle config and convert to vLLM format.""" - config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, - **kwargs) + config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) vllm_config = cls.extract_vllm_speculative_config(config_dict) return cls(**vllm_config) @classmethod def extract_vllm_speculative_config( - cls, config_dict: dict[str, Any]) -> dict[str, Any]: + cls, config_dict: dict[str, Any] + ) -> dict[str, Any]: speculators_model_type = config_dict.get("speculators_model_type") if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES: raise ValueError( f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. " - "Please ensure you're loading a speculators-format model.") + "Please ensure you're loading a speculators-format model." + ) # validate fields # TODO: @dsikka - use speculators pydantic model to validate cls.validate_speculators_config(config_dict=config_dict) # Convert from speculators config -> format that can be ingested by vLLM - vllm_config = cls.build_vllm_speculative_config( - config_dict=config_dict) + vllm_config = cls.build_vllm_speculative_config(config_dict=config_dict) # Apply anything specific to the supported algorithm algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type] algo_updater(config_dict=config_dict, vllm_config=vllm_config) @@ -64,11 +65,13 @@ def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None: if not isinstance(config_dict["transformer_layer_config"], dict): raise TypeError( - "'transformer_layer_config' must be a dictionary if provided") + "'transformer_layer_config' must be a dictionary if provided" + ) @classmethod def build_vllm_speculative_config( - cls, config_dict: dict[str, Any]) -> dict[str, Any]: + cls, config_dict: dict[str, Any] + ) -> dict[str, Any]: """ Build vLLM-compatible speculative configuration from speculators format. @@ -94,14 +97,14 @@ def build_vllm_speculative_config( if num_speculative_tokens is None: raise ValueError( - "Missing 'speculative_tokens' in proposal method. " - f"Got: {first_method}") + f"Missing 'speculative_tokens' in proposal method. 
Got: {first_method}" + ) # Build base vLLM speculative configuration vllm_config = { "method": config_dict.get("speculators_model_type"), "num_speculative_tokens": num_speculative_tokens, - "target_model": spec_config.get("verifier")["name_or_path"] + "target_model": spec_config.get("verifier")["name_or_path"], } # Merge transformer layer configuration if present diff --git a/vllm/transformers_utils/configs/step3_vl.py b/vllm/transformers_utils/configs/step3_vl.py index fe3c72de69d2..36d39e828a93 100644 --- a/vllm/transformers_utils/configs/step3_vl.py +++ b/vllm/transformers_utils/configs/step3_vl.py @@ -59,13 +59,64 @@ def __init__( share_q_dim: int = 2048, head_dim: int = 256, norm_expert_weight: bool = False, - moe_layers_enum: tuple[int, - ...] = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, - 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, - 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, - 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59), + moe_layers_enum: tuple[int, ...] = ( + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + ), **kwargs, ) -> None: self.hidden_size = hidden_size diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py index aaf31d84d0c1..ac22304e9125 100644 --- a/vllm/transformers_utils/configs/ultravox.py +++ b/vllm/transformers_utils/configs/ultravox.py @@ -42,6 +42,7 @@ class UltravoxConfig(transformers.PretrainedConfig): projector or at the end. Versions v0.4.1 and below use `False`, but v0.5 and above use `True`. """ + wrapped_model_config: transformers.PretrainedConfig model_type = "ultravox" audio_token = "<|audio|>" @@ -76,15 +77,17 @@ def __init__( if text_model_id is None: text_config = text_config or {} self.wrapped_model_config = transformers.CONFIG_MAPPING[ - text_config.get("model_type", "llama")](**text_config) + text_config.get("model_type", "llama") + ](**text_config) # N.B. May set the audio_config below. 
self.audio_model_id = audio_model_id if audio_model_id is None: self.audio_model_id = None audio_config = audio_config or {} - self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( - "model_type", "whisper")](**audio_config) + self.audio_config = transformers.CONFIG_MAPPING[ + audio_config.get("model_type", "whisper") + ](**audio_config) super().__init__(**kwargs) @@ -99,8 +102,7 @@ def __setattr__(self, key, value): if key == "text_model_id" and value is not None: from vllm.transformers_utils.config import get_config - self.wrapped_model_config = get_config(value, - trust_remote_code=False) + self.wrapped_model_config = get_config(value, trust_remote_code=False) elif key == "audio_model_id" and value is not None: from vllm.transformers_utils.config import get_config diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index 101f31d39cc1..60742ae97d5d 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -30,8 +30,9 @@ def _convert_tokens_to_string_with_added_encoders( current_sub_text: list[str] = [] convert_tokens_to_string = tokenizer.convert_tokens_to_string added_vocab_set = set(tokenizer.get_added_vocab()) - all_special_tokens = set( - tokenizer.all_special_tokens) if skip_special_tokens else () + all_special_tokens = ( + set(tokenizer.all_special_tokens) if skip_special_tokens else () + ) for token in output_tokens: # Use precomputed set for skip-special check @@ -70,11 +71,11 @@ def convert_prompt_ids_to_tokens( # We do not need to convert the whole prompt to tokens. # Offset a little more in case we have special tokens. new_tokens = tokenizer.convert_ids_to_tokens( - prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], - skip_special_tokens=skip_special_tokens) + prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2 :], + skip_special_tokens=skip_special_tokens, + ) read_offset = len(new_tokens) - prefix_offset = max( - read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + prefix_offset = max(read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) # This is required to guard against out-of-vocab prompt token ids _replace_none_with_empty(new_tokens) # type: ignore[arg-type] return new_tokens, prefix_offset, read_offset @@ -92,7 +93,7 @@ def convert_ids_list_to_tokens( Returns: Python list of token string representations - + """ token_str_lst = [] for token_id in token_ids: @@ -144,18 +145,17 @@ def detokenize_incrementally( # This is the first iteration for this sequence is_first_iter = prev_tokens is None if is_first_iter: - (prev_tokens, prefix_offset, - read_offset) = convert_prompt_ids_to_tokens( - tokenizer, - all_input_ids[:-1], - skip_special_tokens=skip_special_tokens) + (prev_tokens, prefix_offset, read_offset) = convert_prompt_ids_to_tokens( + tokenizer, all_input_ids[:-1], skip_special_tokens=skip_special_tokens + ) assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. if 0 <= new_token_id < len(tokenizer): # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( - [new_token_id], skip_special_tokens=skip_special_tokens) + [new_token_id], skip_special_tokens=skip_special_tokens + ) if isinstance(new_tokens, str): new_tokens = [new_tokens] else: @@ -171,9 +171,9 @@ def detokenize_incrementally( # surrounding ids. 
if tokenizer.is_fast or not tokenizer.get_added_vocab(): prefix_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:read_offset]) - new_text = tokenizer.convert_tokens_to_string( - output_tokens[prefix_offset:]) + output_tokens[prefix_offset:read_offset] + ) + new_text = tokenizer.convert_tokens_to_string(output_tokens[prefix_offset:]) else: prefix_text = _convert_tokens_to_string_with_added_encoders( tokenizer, @@ -195,5 +195,5 @@ def detokenize_incrementally( # by the model return new_tokens, "", prefix_offset, read_offset - new_text = new_text[len(prefix_text):] + new_text = new_text[len(prefix_text) :] return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 51bcce6c10e2..81f9b76b5ef7 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -4,8 +4,12 @@ from functools import lru_cache from typing import TYPE_CHECKING, Any, Optional, Union, cast -from transformers import (AutoFeatureExtractor, AutoImageProcessor, - AutoProcessor, AutoVideoProcessor) +from transformers import ( + AutoFeatureExtractor, + AutoImageProcessor, + AutoProcessor, + AutoVideoProcessor, +) from transformers.feature_extraction_utils import FeatureExtractionMixin from transformers.image_processing_utils import BaseImageProcessor from transformers.processing_utils import ProcessorMixin @@ -121,15 +125,18 @@ def get_processor( "a custom processor not yet available in the HuggingFace " "transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") + "`--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e if not isinstance(processor, processor_cls): - raise TypeError("Invalid type of HuggingFace processor. " - f"Expected type: {processor_cls}, but " - f"found type: {type(processor)}") + raise TypeError( + "Invalid type of HuggingFace processor. " + f"Expected type: {processor_cls}, but " + f"found type: {type(processor)}" + ) return processor @@ -158,7 +165,7 @@ def get_feature_extractor( trust_remote_code: bool = False, **kwargs: Any, ): - """Load an audio feature extractor for the given model name + """Load an audio feature extractor for the given model name via HuggingFace.""" try: feature_extractor = AutoFeatureExtractor.from_pretrained( @@ -166,7 +173,8 @@ def get_feature_extractor( *args, revision=revision, trust_remote_code=trust_remote_code, - **kwargs) + **kwargs, + ) except ValueError as e: # If the error pertains to the processor class not existing or not # currently being imported, suggest using the --trust-remote-code flag. @@ -177,7 +185,8 @@ def get_feature_extractor( "extractor is a custom extractor not yet available in the " "HuggingFace transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") + "`--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e @@ -213,7 +222,8 @@ def get_image_processor( *args, revision=revision, trust_remote_code=trust_remote_code, - **kwargs) + **kwargs, + ) except ValueError as e: # If the error pertains to the processor class not existing or not # currently being imported, suggest using the --trust-remote-code flag. 
@@ -224,7 +234,8 @@ def get_image_processor( "a custom processor not yet available in the HuggingFace " "transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") + "`--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e @@ -263,7 +274,8 @@ def get_video_processor( *args, revision=revision, trust_remote_code=trust_remote_code, - **kwargs) + **kwargs, + ) except ValueError as e: # If the error pertains to the processor class not existing or not # currently being imported, suggest using the --trust-remote-code flag. @@ -274,7 +286,8 @@ def get_video_processor( "a custom processor not yet available in the HuggingFace " "transformers library, consider setting " "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") + "`--trust-remote-code` flag in the CLI." + ) raise RuntimeError(err_msg) from e else: raise e diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index 8a1ad226d99f..76b6d3dc9c99 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -8,8 +8,7 @@ - There is a need to override the existing processor to support vLLM. """ -from vllm.transformers_utils.processors.deepseek_vl2 import ( - DeepseekVLV2Processor) +from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor diff --git a/vllm/transformers_utils/processors/deepseek_vl2.py b/vllm/transformers_utils/processors/deepseek_vl2.py index d1d117b4e2cf..5ef258b9be29 100644 --- a/vllm/transformers_utils/processors/deepseek_vl2.py +++ b/vllm/transformers_utils/processors/deepseek_vl2.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # coding=utf-8 # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/ff23960c5cf9e6874b44be38af930cfb0ccbb620/deepseek_vl2/models/processing_deepseek_vl_v2.py @@ -35,11 +34,12 @@ class ImageTransform: - - def __init__(self, - mean: tuple[float, float, float] = (0.5, 0.5, 0.5), - std: tuple[float, float, float] = (0.5, 0.5, 0.5), - normalize: bool = True): + def __init__( + self, + mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + std: tuple[float, float, float] = (0.5, 0.5, 0.5), + normalize: bool = True, + ): self.mean = mean self.std = std self.normalize = normalize @@ -77,7 +77,6 @@ def __init__( ignore_id: int = -100, **kwargs, ): - self.candidate_resolutions = candidate_resolutions self.image_size = candidate_resolutions[0][0] self.patch_size = patch_size @@ -86,13 +85,15 @@ def __init__( self.normalize = normalize self.downsample_ratio = downsample_ratio - self.image_transform = ImageTransform(mean=image_mean, std=image_std, normalize=normalize) + self.image_transform = ImageTransform( + mean=image_mean, std=image_std, normalize=normalize + ) self.tokenizer = tokenizer - self.tokenizer.padding_side = 'left' # must set this,padding side with make a difference in batch inference + self.tokenizer.padding_side = "left" # must set this,padding side with make a difference in batch inference # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id' if tokenizer.pad_token is None: - self.tokenizer.add_special_tokens({'pad_token': 
pad_token}) + self.tokenizer.add_special_tokens({"pad_token": pad_token}) # add image token image_token_id = self.tokenizer.vocab.get(image_token) @@ -104,7 +105,7 @@ def __init__( # add five special tokens for grounding-related tasks # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|> - special_tokens = ['<|ref|>', '<|/ref|>', '<|det|>', '<|/det|>', '<|grounding|>'] + special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"] special_tokens_dict = {"additional_special_tokens": special_tokens} self.tokenizer.add_special_tokens(special_tokens_dict) @@ -134,15 +135,19 @@ def select_best_resolution(self, image_size): for width, height in self.candidate_resolutions: scale = min(width / original_width, height / original_height) - downscaled_width, downscaled_height = int( - original_width * scale), int(original_height * scale) - effective_resolution = min(downscaled_width * downscaled_height, - original_width * original_height) + downscaled_width, downscaled_height = ( + int(original_width * scale), + int(original_height * scale), + ) + effective_resolution = min( + downscaled_width * downscaled_height, original_width * original_height + ) wasted_resolution = (width * height) - effective_resolution if effective_resolution > max_effective_resolution or ( - effective_resolution == max_effective_resolution - and wasted_resolution < min_wasted_resolution): + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution + ): max_effective_resolution = effective_resolution min_wasted_resolution = wasted_resolution best_fit = (width, height) @@ -198,12 +203,20 @@ def process_one( - num_image_tokens (list[int]): the number of image tokens """ - assert (prompt is not None and images is not None - ), "prompt and images must be used at the same time." + assert prompt is not None and images is not None, ( + "prompt and images must be used at the same time." 
+ ) sft_format = prompt - tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens = self.tokenize_with_images( - sft_format, images, bos=True, eos=True, cropping=len(images) <= 2) + ( + tokenized_str, + images_list, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + ) = self.tokenize_with_images( + sft_format, images, bos=True, eos=True, cropping=len(images) <= 2 + ) masked_tokenized_str = [] for token_index in tokenized_str: if token_index != self.image_token_id: @@ -211,17 +224,21 @@ def process_one( else: masked_tokenized_str.append(self.ignore_id) - assert len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str), \ - (f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, " - f"imags_seq_mask's length {len(images_seq_mask)}, are not equal") + assert ( + len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str) + ), ( + f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, " + f"imags_seq_mask's length {len(images_seq_mask)}, are not equal" + ) input_ids = torch.LongTensor(tokenized_str) target_ids = torch.LongTensor(masked_tokenized_str) images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) # set input_ids < 0 | input_ids == self.image_token_id as ignore_id - target_ids[(input_ids < 0) | - (input_ids == self.image_token_id)] = self.ignore_id + target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = ( + self.ignore_id + ) input_ids[input_ids < 0] = self.pad_id if inference_mode: @@ -311,30 +328,50 @@ def tokenize_with_images( best_width, best_height = self.image_size, self.image_size """process the global view""" - global_view = ImageOps.pad(image, (self.image_size, self.image_size), - color=tuple(int(x * 255) for x in self.image_transform.mean)) + global_view = ImageOps.pad( + image, + (self.image_size, self.image_size), + color=tuple(int(x * 255) for x in self.image_transform.mean), + ) images_list.append(self.image_transform(global_view)) """process the local views""" - local_view = ImageOps.pad(image, (best_width, best_height), - color=tuple(int(x * 255) for x in self.image_transform.mean)) + local_view = ImageOps.pad( + image, + (best_width, best_height), + color=tuple(int(x * 255) for x in self.image_transform.mean), + ) for i in range(0, best_height, self.image_size): for j in range(0, best_width, self.image_size): images_list.append( - self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size)))) + self.image_transform( + local_view.crop( + (j, i, j + self.image_size, i + self.image_size) + ) + ) + ) """record height / width crop num""" - num_width_tiles, num_height_tiles = best_width // self.image_size, best_height // self.image_size + num_width_tiles, num_height_tiles = ( + best_width // self.image_size, + best_height // self.image_size, + ) images_spatial_crop.append([num_width_tiles, num_height_tiles]) """add image tokens""" - h = w = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio) + h = w = math.ceil( + (self.image_size // self.patch_size) / self.downsample_ratio + ) # global views tokens h * (w + 1), 1 is for line separator tokenized_image = [self.image_token_id] * h * (w + 1) # add a separator between global and local views tokenized_image += [self.image_token_id] # local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1) - tokenized_image += [self.image_token_id] * (num_height_tiles * h) * (num_width_tiles * w + 1) + tokenized_image += ( + 
[self.image_token_id] + * (num_height_tiles * h) + * (num_width_tiles * w + 1) + ) tokenized_str += tokenized_image images_seq_mask += [True] * len(tokenized_image) @@ -353,10 +390,17 @@ def tokenize_with_images( tokenized_str = tokenized_str + [self.eos_id] images_seq_mask = images_seq_mask + [False] - assert len(tokenized_str) == len( - images_seq_mask), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" + assert len(tokenized_str) == len(images_seq_mask), ( + f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" + ) - return tokenized_str, images_list, images_seq_mask, images_spatial_crop, num_image_tokens + return ( + tokenized_str, + images_list, + images_seq_mask, + images_spatial_crop, + num_image_tokens, + ) AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor) diff --git a/vllm/transformers_utils/processors/ovis.py b/vllm/transformers_utils/processors/ovis.py index 0077a7a8ce65..6d52ab48c970 100644 --- a/vllm/transformers_utils/processors/ovis.py +++ b/vllm/transformers_utils/processors/ovis.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# yapf: disable # ruff: noqa: E501 # coding=utf-8 # adapted from https://github.com/AIDC-AI/Ovis/blob/35ab51a1a1e3542fa6db260a1084cefbc8f164bb/ovis/vllm/processing_ovis.py @@ -30,29 +29,29 @@ import torch from transformers import AutoProcessor, BatchFeature from transformers.image_utils import ImageInput -from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, - Unpack) +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from vllm.multimodal.image import convert_image_mode -__all__ = ['OvisProcessor'] +__all__ = ["OvisProcessor"] IGNORE_ID = -100 -class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] + +class OvisProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] _defaults = { "text_kwargs": { "padding": False, }, "images_kwargs": { - 'max_partition':9, - 'covering_threshold':0.9, - 'convert_to_rgb':True, - 'return_tensors':'pt'}, + "max_partition": 9, + "covering_threshold": 0.9, + "convert_to_rgb": True, + "return_tensors": "pt", + }, } - class OvisProcessor(ProcessorMixin): r""" Constructs an Ovis processor which wraps an Ovis image processor and a Qwen2 tokenizer into a single processor. 
@@ -98,14 +97,16 @@ def extra_special_tokens(self): "image_col_sep": -303, "image_row_sep": -304, "image_end": -305, - 'image_pad': image_pad_token_id, + "image_pad": image_pad_token_id, } return extra_special_tokens def __call__( self, images: ImageInput = None, - text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + text: Union[ + TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] + ] = None, **kwargs: Unpack[OvisProcessorKwargs], ) -> BatchFeature: """ @@ -170,7 +171,6 @@ def __call__( # Process text input if text is not None: - if not isinstance(text, list): text = [text] @@ -179,7 +179,10 @@ def __call__( replaced_ids_list = [] idx = 0 for ids_tensor in tokenized_batched_text: - if image_token_id in ids_tensor and "image_placeholders" in image_features: + if ( + image_token_id in ids_tensor + and "image_placeholders" in image_features + ): if idx < len(image_features["image_placeholders"]): # Converts in list for ease of use ids_list = ids_tensor.tolist() @@ -189,7 +192,9 @@ def __call__( # replace placeholders for i, token_id in enumerate(ids_list): if token_id == image_token_id: - placeholder_ids = image_features["image_placeholders"][idx] + placeholder_ids = image_features["image_placeholders"][ + idx + ] new_ids.extend(placeholder_ids) idx += 1 else: @@ -199,7 +204,8 @@ def __call__( ids_tensor = torch.tensor(new_ids, dtype=torch.long) else: raise RuntimeError( - 'Mismatch between the images you provided and the number of placeholder present in the text') + "Mismatch between the images you provided and the number of placeholder present in the text" + ) replaced_ids_list.append(ids_tensor) @@ -218,7 +224,7 @@ def __call__( # Add image features if present if image_features: output["pixel_values"] = processed_images - output['grids'] = grids + output["grids"] = grids return output @@ -228,8 +234,10 @@ def __call__( def _tokenize_with_image_symbol(self, text_list: list[str]) -> torch.LongTensor: batch_token_ids = [] for text in text_list: - text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in - text.split(self.image_token)] + text_chunks = [ + self.tokenizer(chunk, add_special_tokens=False).input_ids + for chunk in text.split(self.image_token) + ] token_ids = [] num_chuck = len(text_chunks) for i, chunk in enumerate(text_chunks): @@ -241,50 +249,60 @@ def _tokenize_with_image_symbol(self, text_list: list[str]) -> torch.LongTensor: def get_image_size(self): size = self.image_processor.size - if 'shortest_edge' in size: - width = height = size['shortest_edge'] + if "shortest_edge" in size: + width = height = size["shortest_edge"] elif "height" in size and "width" in size: - width = size['width'] - height = size['height'] + width = size["width"] + height = size["height"] else: - raise ValueError( "Can't parse image size from image_processor config.") + raise ValueError("Can't parse image size from image_processor config.") return height, width def get_token_value(self, tok): return self.extra_special_tokens[tok] def construct_image_indicators(self, grid): - image_placeholders = [self.get_token_value('image_start'), - self.get_token_value('image_atom'), - self.get_token_value('image_prefix')] + image_placeholders = [ + self.get_token_value("image_start"), + self.get_token_value("image_atom"), + self.get_token_value("image_prefix"), + ] if grid[0] * grid[1] > 1: for r in range(grid[0]): for c in range(grid[1]): - image_placeholders.append(self.get_token_value('image_atom') ) + 
image_placeholders.append(self.get_token_value("image_atom")) if c < grid[1] - 1: - image_placeholders.append(self.get_token_value('image_col_sep')) + image_placeholders.append(self.get_token_value("image_col_sep")) if r < grid[0] - 1: - image_placeholders.append(self.get_token_value('image_row_sep')) - image_placeholders.append(self.get_token_value('image_end')) + image_placeholders.append(self.get_token_value("image_row_sep")) + image_placeholders.append(self.get_token_value("image_end")) return image_placeholders def construct_image_placeholders(self, grid): - image_placeholders = self.construct_image_indicators(grid) - image_atom_token_id = self.get_token_value('image_atom') + image_atom_token_id = self.get_token_value("image_atom") # Extract the padding token ID from tokenizer - image_padding_token_id = self.get_token_value('image_pad') + image_padding_token_id = self.get_token_value("image_pad") # Create a new list with padding tokens inserted padded_placeholder_tokens = [] for token in image_placeholders: padded_placeholder_tokens.append(image_padding_token_id) if token == image_atom_token_id: - padded_placeholder_tokens.extend([image_padding_token_id] * self.image_segment_len) + padded_placeholder_tokens.extend( + [image_padding_token_id] * self.image_segment_len + ) return padded_placeholder_tokens - def preprocess_image(self, image: PIL.Image.Image, max_partition, covering_threshold, convert_to_rgb, return_tensors): + def preprocess_image( + self, + image: PIL.Image.Image, + max_partition, + covering_threshold, + convert_to_rgb, + return_tensors, + ): def _preprocess(img: PIL.Image.Image, side): # first resize and preprocess w, h = img.size @@ -297,19 +315,27 @@ def _preprocess(img: PIL.Image.Image, side): new_height = side new_width = int(w / h * new_height) new_size = dict(height=new_height, width=new_width) - pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors=return_tensors)['pixel_values'] + pixel_values = self.image_processor.preprocess( + img, size=new_size, return_tensors=return_tensors + )["pixel_values"] # then pad to square - square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device) + square_values = torch.zeros( + [1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device + ) new_height, new_width = pixel_values.shape[2:] if new_height == new_width: square_values[:, :, :, :] = pixel_values elif new_height > new_width: from_index = (side - new_width) // 2 - square_values[:, :, :, from_index:from_index + new_width] = pixel_values + square_values[:, :, :, from_index : from_index + new_width] = ( + pixel_values + ) else: from_index = (side - new_height) // 2 - square_values[:, :, from_index:from_index + new_height, :] = pixel_values + square_values[:, :, from_index : from_index + new_height, :] = ( + pixel_values + ) return square_values @@ -351,7 +377,9 @@ def _get_best_grid(img, side): good_grids = [] for grid in candidate_grids: partition = _partition(img, grid) - covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area + covering_ratio = ( + sum([_covering_area(*p, side) for p in partition]) / img_area + ) assert covering_ratio <= 1.0 all_grids.append((grid, covering_ratio)) if covering_ratio > covering_threshold: @@ -359,18 +387,19 @@ def _get_best_grid(img, side): if len(good_grids) > 0: # pick the good partition with minimum #sub_images and break the tie using covering_ratio - return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0] + 
return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][ + 0 + ] else: # pick the partition with maximum covering_ratio and break the tie using #sub_images return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0] if convert_to_rgb: - image = convert_image_mode(image, 'RGB') - + image = convert_image_mode(image, "RGB") sides = self.get_image_size() if sides[0] != sides[1]: - raise ValueError('get_image_size() returns non-square size') + raise ValueError("get_image_size() returns non-square size") side = sides[0] grid = _get_best_grid(image, side) partition = _partition(image, grid) @@ -406,14 +435,18 @@ def post_process_image_text_to_text(self, generated_outputs): `list[str]`: The decoded text. """ return self.tokenizer.batch_decode( - generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False + generated_outputs, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, ) @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names image_processor_input_names = self.image_processor.model_input_names - names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + names_from_processor = list( + dict.fromkeys(tokenizer_input_names + image_processor_input_names) + ) return names_from_processor + ["second_per_grid_ts"] diff --git a/vllm/transformers_utils/processors/ovis2_5.py b/vllm/transformers_utils/processors/ovis2_5.py index 282e9cb2116e..fba26d1d0304 100644 --- a/vllm/transformers_utils/processors/ovis2_5.py +++ b/vllm/transformers_utils/processors/ovis2_5.py @@ -9,33 +9,31 @@ import torch from transformers import AutoProcessor, BatchFeature from transformers.image_utils import ImageInput -from transformers.processing_utils import (ProcessingKwargs, ProcessorMixin, - Unpack) +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -__all__ = ['Ovis2_5Processor'] +__all__ = ["Ovis2_5Processor"] IMAGE_TOKEN = "<image>" VIDEO_TOKEN = "<video>" MIN_PIXELS = 448 * 448 MAX_PIXELS = 1792 * 1792 -class Ovis2_5ProcessorKwargs(ProcessingKwargs, - total=False): # type: ignore[call-arg] +class Ovis2_5ProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg] _defaults = { "text_kwargs": { "padding": False, }, "images_kwargs": { - 'convert_to_rgb': True, - 'min_pixels': MIN_PIXELS, - 'max_pixels': MAX_PIXELS, + "convert_to_rgb": True, + "min_pixels": MIN_PIXELS, + "max_pixels": MAX_PIXELS, }, "videos_kwargs": { - 'convert_to_rgb': True, - 'min_pixels': MIN_PIXELS, - 'max_pixels': MAX_PIXELS, - } + "convert_to_rgb": True, + "min_pixels": MIN_PIXELS, + "max_pixels": MAX_PIXELS, + }, } @@ -43,8 +41,8 @@ class Ovis2_5Processor(ProcessorMixin): r""" Constructs an Ovis processor which wraps an Ovis image processor and a Qwen2 tokenizer into a single processor. - [`OvisProcessor`] offers all the functionalities of - [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. + [`OvisProcessor`] offers all the functionalities of + [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the [`~OvisProcessor.__call__`] and [`~OvisProcessor.decode`] for more information. 
Args: @@ -81,9 +79,7 @@ def __init__( self.patch_size = patch_size self.hidden_stride = hidden_stride self.temporal_patch_size = temporal_patch_size - super().__init__(image_processor, - tokenizer, - chat_template=chat_template) + super().__init__(image_processor, tokenizer, chat_template=chat_template) @cached_property def extra_special_tokens(self): @@ -96,7 +92,7 @@ def extra_special_tokens(self): "image_end": -302, "video_start": -303, "video_end": -304, - 'image_pad': image_pad_token_id, + "image_pad": image_pad_token_id, } return extra_special_tokens @@ -104,8 +100,9 @@ def __call__( self, images: ImageInput = None, videos: Union[np.ndarray, list[ImageInput]] = None, - text: Union[TextInput, PreTokenizedInput, list[TextInput], - list[PreTokenizedInput]] = None, + text: Union[ + TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput] + ] = None, **kwargs: Unpack[Ovis2_5ProcessorKwargs], ) -> BatchFeature: """ @@ -148,9 +145,9 @@ def __call__( [`BatchFeature`]: A [`BatchFeature`] with the following fields: - **input_ids** -- list of token ids to be fed to a model. Returned when `text` is not `None`. - - **attention_mask** -- list of indices specifying which tokens + - **attention_mask** -- list of indices specifying which tokens should be attended to by the model (when - `return_attention_mask=True` or if *"attention_mask"* + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. @@ -177,9 +174,9 @@ def __call__( grids = [] # Process each image for image in images if isinstance(images, list) else [images]: - pixel_values, image_placeholders, grid = ( - self.preprocess_multidata( - images=image, **output_kwargs["images_kwargs"])) + pixel_values, image_placeholders, grid = self.preprocess_multidata( + images=image, **output_kwargs["images_kwargs"] + ) processed_images.append(pixel_values) image_placeholders_list.append(image_placeholders) grids.append(grid) @@ -196,16 +193,15 @@ def __call__( grids = [] # Process each video for video in videos if isinstance(videos, list) else [videos]: - pixel_values, video_placeholders, grid = ( - self.preprocess_multidata( - video=video, **output_kwargs["videos_kwargs"])) + pixel_values, video_placeholders, grid = self.preprocess_multidata( + video=video, **output_kwargs["videos_kwargs"] + ) processed_videos.append(pixel_values) videos_placeholders_list.append(video_placeholders) grids.append(grid) # assign all processed videos if processed_videos: - visual_features[ - "video_placeholders"] = videos_placeholders_list + visual_features["video_placeholders"] = videos_placeholders_list output["video_pixel_values"] = processed_videos output["video_grids"] = grids @@ -220,14 +216,16 @@ def __call__( image_idx = 0 video_idx = 0 for ids_tensor in tokenized_batched_text: - has_image_tokens = (image_token_id in ids_tensor - and "image_placeholders" in visual_features - and image_idx < len( - visual_features["image_placeholders"])) - has_video_tokens = (video_token_id in ids_tensor - and "video_placeholders" in visual_features - and video_idx < len( - visual_features["video_placeholders"])) + has_image_tokens = ( + image_token_id in ids_tensor + and "image_placeholders" in visual_features + and image_idx < len(visual_features["image_placeholders"]) + ) + has_video_tokens = ( + video_token_id in ids_tensor + and "video_placeholders" in visual_features + and video_idx < 
len(visual_features["video_placeholders"]) + ) if has_image_tokens or has_video_tokens: # Convert to list for easier manipulation ids_list = ids_tensor.tolist() @@ -237,13 +235,13 @@ def __call__( for token_id in ids_list: if token_id == image_token_id: new_ids.extend( - visual_features["image_placeholders"] - [image_idx]) + visual_features["image_placeholders"][image_idx] + ) image_idx += 1 elif token_id == video_token_id: new_ids.extend( - visual_features["video_placeholders"] - [video_idx]) + visual_features["video_placeholders"][video_idx] + ) video_idx += 1 else: new_ids.append(token_id) @@ -260,8 +258,7 @@ def __call__( # If only images were provided return BatchFeature(data=visual_features) - def _tokenize_with_visual_symbol(self, - text_list: list[str]) -> torch.LongTensor: + def _tokenize_with_visual_symbol(self, text_list: list[str]) -> torch.LongTensor: batch_token_ids = [] for text in text_list: token_ids = [] @@ -288,21 +285,24 @@ def _tokenize_with_visual_symbol(self, return torch.tensor(batch_token_ids, dtype=torch.long) # Copied from qwen2_vl - def smart_resize(self, - height: int, - width: int, - factor: int = 28, - min_pixels: int = MIN_PIXELS, - max_pixels: int = MAX_PIXELS): + def smart_resize( + self, + height: int, + width: int, + factor: int = 28, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, + ): """Rescales the image so that the following conditions are met: 1. Both dimensions (height and width) are divisible by 'factor'. - 2. The total number of pixels is within the range + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. 3. The aspect ratio of the image is maintained as closely as possible. """ if height < factor or width < factor: - print(f"height:{height} or width:{width} must be " - f"larger than factor:{factor}") + print( + f"height:{height} or width:{width} must be larger than factor:{factor}" + ) if height < width: width = round(factor / height * width) height = factor @@ -311,8 +311,10 @@ def smart_resize(self, width = factor elif max(height, width) / min(height, width) > 200: - print(f"absolute aspect ratio must be smaller than 200, " - f"got {max(height, width) / min(height, width)}") + print( + f"absolute aspect ratio must be smaller than 200, " + f"got {max(height, width) / min(height, width)}" + ) if height > width: height = 200 * width else: @@ -335,29 +337,27 @@ def get_token_value(self, tok): def construct_visual_indicators(self, grid, is_video: bool = False): if is_video: - start_token = self.get_token_value('video_start') - end_token = self.get_token_value('video_end') + start_token = self.get_token_value("video_start") + end_token = self.get_token_value("video_end") else: - start_token = self.get_token_value('image_start') - end_token = self.get_token_value('image_end') + start_token = self.get_token_value("image_start") + end_token = self.get_token_value("image_end") - image_placeholders = [start_token, self.get_token_value('visual_atom')] + image_placeholders = [start_token, self.get_token_value("visual_atom")] if grid[0] * grid[1] > 1: for r in range(grid[0]): for c in range(grid[1]): - image_placeholders.append( - self.get_token_value('visual_atom')) + image_placeholders.append(self.get_token_value("visual_atom")) image_placeholders.append(end_token) return image_placeholders def construct_visual_placeholders(self, grid, is_video: bool = False): - visual_placeholders = self.construct_visual_indicators((1, 1), - is_video) + visual_placeholders = self.construct_visual_indicators((1, 1), 
is_video) - image_atom_token_id = self.get_token_value('visual_atom') + image_atom_token_id = self.get_token_value("visual_atom") # Extract the padding token ID from tokenizer - image_padding_token_id = self.get_token_value('image_pad') + image_padding_token_id = self.get_token_value("image_pad") num_image_atoms = grid[0] * grid[1] * grid[2] num_image_atoms //= self.hidden_stride**2 @@ -367,8 +367,9 @@ def construct_visual_placeholders(self, grid, is_video: bool = False): padded_placeholder_tokens = [] for token in visual_placeholders: if token == image_atom_token_id: - padded_placeholder_tokens.extend([image_padding_token_id] * - num_image_atoms) + padded_placeholder_tokens.extend( + [image_padding_token_id] * num_image_atoms + ) else: padded_placeholder_tokens.append(image_padding_token_id) return padded_placeholder_tokens @@ -380,7 +381,7 @@ def preprocess_multidata( convert_to_rgb: Optional[bool] = True, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS, - return_tensors: Optional[str] = 'pt', + return_tensors: Optional[str] = "pt", ): is_video = False if images is not None: @@ -396,11 +397,12 @@ def preprocess_multidata( images.append(image) elif isinstance(video, list): images = video - min_pixels = min(max_pixels if max_pixels is not None else MAX_PIXELS, - min_pixels if min_pixels is not None else MIN_PIXELS) + min_pixels = min( + max_pixels if max_pixels is not None else MAX_PIXELS, + min_pixels if min_pixels is not None else MIN_PIXELS, + ) images = [ - image.convert("RGB") - if convert_to_rgb and image.mode != 'RGB' else image + image.convert("RGB") if convert_to_rgb and image.mode != "RGB" else image for image in images ] @@ -417,14 +419,16 @@ def preprocess_multidata( ) new_size = dict(height=resized_height, width=resized_width) image_pt = self.image_processor.preprocess( - image, size=new_size, return_tensors="np")['pixel_values'][0] + image, size=new_size, return_tensors="np" + )["pixel_values"][0] processed_images.append(image_pt) patches = np.array(processed_images) if patches.shape[0] % self.temporal_patch_size != 0: - num_to_pad = self.temporal_patch_size - (patches.shape[0] % - self.temporal_patch_size) + num_to_pad = self.temporal_patch_size - ( + patches.shape[0] % self.temporal_patch_size + ) repeats = np.repeat(patches[-1][np.newaxis], num_to_pad, axis=0) patches = np.concatenate([patches, repeats], axis=0) channel = patches.shape[1] @@ -445,14 +449,18 @@ def preprocess_multidata( ) patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8) flatten_patches = patches.reshape( - grid_t * grid_h * grid_w, channel * self.temporal_patch_size * - self.patch_size * self.patch_size) + grid_t * grid_h * grid_w, + channel * self.temporal_patch_size * self.patch_size * self.patch_size, + ) visual_placeholders = self.construct_visual_placeholders( - [grid_t, grid_h, grid_w], is_video) - return torch.tensor( - flatten_patches), visual_placeholders, torch.tensor( - [[grid_t, grid_h, grid_w]]) + [grid_t, grid_h, grid_w], is_video + ) + return ( + torch.tensor(flatten_patches), + visual_placeholders, + torch.tensor([[grid_t, grid_h, grid_w]]), + ) AutoProcessor.register("Ovis2_5Processor", Ovis2_5Processor) diff --git a/vllm/transformers_utils/runai_utils.py b/vllm/transformers_utils/runai_utils.py index 355fd60e8da1..ec60d66e5cff 100644 --- a/vllm/transformers_utils/runai_utils.py +++ b/vllm/transformers_utils/runai_utils.py @@ -14,7 +14,7 @@ logger = init_logger(__name__) -SUPPORTED_SCHEMES = ['s3://', 'gs://'] +SUPPORTED_SCHEMES = ["s3://", "gs://"] try: from 
runai_model_streamer import list_safetensors as runai_list_safetensors @@ -22,11 +22,9 @@ except (ImportError, OSError): # see https://github.com/run-ai/runai-model-streamer/issues/26 # OSError will be raised on arm64 platform - runai_model_streamer = PlaceholderModule( - "runai_model_streamer") # type: ignore[assignment] + runai_model_streamer = PlaceholderModule("runai_model_streamer") # type: ignore[assignment] runai_pull_files = runai_model_streamer.placeholder_attr("pull_files") - runai_list_safetensors = runai_model_streamer.placeholder_attr( - "list_safetensors") + runai_list_safetensors = runai_model_streamer.placeholder_attr("list_safetensors") def list_safetensors(path: str = "") -> list[str]: @@ -65,8 +63,10 @@ def __init__(self, url: str) -> None: signal.signal(sig, self._close_by_signal(existing_handler)) dir_name = os.path.join( - get_cache_dir(), "model_streamer", - hashlib.sha256(str(url).encode()).hexdigest()[:8]) + get_cache_dir(), + "model_streamer", + hashlib.sha256(str(url).encode()).hexdigest()[:8], + ) if os.path.exists(dir_name): shutil.rmtree(dir_name) os.makedirs(dir_name) @@ -78,7 +78,6 @@ def _close(self) -> None: shutil.rmtree(self.dir) def _close_by_signal(self, existing_handler=None): - def new_handler(signum, frame): self._close() if existing_handler: @@ -86,10 +85,12 @@ def new_handler(signum, frame): return new_handler - def pull_files(self, - model_path: str = "", - allow_pattern: Optional[list[str]] = None, - ignore_pattern: Optional[list[str]] = None) -> None: + def pull_files( + self, + model_path: str = "", + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None, + ) -> None: """ Pull files from object storage into the temporary directory. diff --git a/vllm/transformers_utils/s3_utils.py b/vllm/transformers_utils/s3_utils.py index b848898ff6da..ef30efd80b1f 100644 --- a/vllm/transformers_utils/s3_utils.py +++ b/vllm/transformers_utils/s3_utils.py @@ -17,21 +17,25 @@ def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]: return [ - path for path in paths if any( - fnmatch.fnmatch(path, pattern) for pattern in patterns) + path + for path in paths + if any(fnmatch.fnmatch(path, pattern) for pattern in patterns) ] def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]: return [ - path for path in paths + path + for path in paths if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns) ] -def glob(s3: Optional["BaseClient"] = None, - path: str = "", - allow_pattern: Optional[list[str]] = None) -> list[str]: +def glob( + s3: Optional["BaseClient"] = None, + path: str = "", + allow_pattern: Optional[list[str]] = None, +) -> list[str]: """ List full file names from S3 path and filter by allow pattern. @@ -47,17 +51,15 @@ def glob(s3: Optional["BaseClient"] = None, s3 = boto3.client("s3") if not path.endswith("/"): path = path + "/" - bucket_name, _, paths = list_files(s3, - path=path, - allow_pattern=allow_pattern) + bucket_name, _, paths = list_files(s3, path=path, allow_pattern=allow_pattern) return [f"s3://{bucket_name}/{path}" for path in paths] def list_files( - s3: "BaseClient", - path: str, - allow_pattern: Optional[list[str]] = None, - ignore_pattern: Optional[list[str]] = None + s3: "BaseClient", + path: str, + allow_pattern: Optional[list[str]] = None, + ignore_pattern: Optional[list[str]] = None, ) -> tuple[str, str, list[str]]: """ List files from S3 path and filter by pattern. 
@@ -71,17 +73,17 @@ def list_files( Returns: tuple[str, str, list[str]]: A tuple where: - The first element is the bucket name - - The second element is string represent the bucket + - The second element is string represent the bucket and the prefix as a dir like string - - The third element is a list of files allowed or + - The third element is a list of files allowed or disallowed by pattern """ - parts = path.removeprefix('s3://').split('/') - prefix = '/'.join(parts[1:]) + parts = path.removeprefix("s3://").split("/") + prefix = "/".join(parts[1:]) bucket_name = parts[0] objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix) - paths = [obj['Key'] for obj in objects.get('Contents', [])] + paths = [obj["Key"] for obj in objects.get("Contents", [])] paths = _filter_ignore(paths, ["*/"]) if allow_pattern is not None: diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 9aaac6681739..9537295c6dcd 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -10,14 +10,12 @@ from typing import TYPE_CHECKING, Any, Optional, Union import huggingface_hub -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from typing_extensions import assert_never from vllm import envs from vllm.logger import init_logger -from vllm.transformers_utils.config import ( - get_sentence_transformer_tokenizer_config) +from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.utils import check_gguf_file @@ -32,8 +30,7 @@ logger = init_logger(__name__) -AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, - TokenizerBase] +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, TokenizerBase] def decode_tokens( @@ -50,8 +47,7 @@ def decode_tokens( settings. """ if skip_special_tokens is not None: - return tokenizer.decode(token_ids, - skip_special_tokens=skip_special_tokens) + return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) return tokenizer.decode(token_ids) @@ -95,8 +91,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: tokenizer_all_special_ids = tokenizer.all_special_ids tokenizer_all_special_tokens = tokenizer.all_special_tokens - tokenizer_all_special_tokens_extended = ( - tokenizer.all_special_tokens_extended) + tokenizer_all_special_tokens_extended = tokenizer.all_special_tokens_extended tokenizer_vocab = tokenizer.get_vocab() tokenizer_len = len(tokenizer) @@ -110,7 +105,6 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: max_token_id = max(max_token_id, tokenizer.vocab_size) class CachedTokenizer(tokenizer.__class__): # type: ignore - @property def all_special_ids(self) -> list[int]: return tokenizer_all_special_ids @@ -134,7 +128,7 @@ def __len__(self) -> int: return tokenizer_len def __reduce__(self): - return get_cached_tokenizer, (tokenizer, ) + return get_cached_tokenizer, (tokenizer,) CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" @@ -151,8 +145,7 @@ def get_tokenizer( download_dir: Optional[str] = None, **kwargs, ) -> AnyTokenizer: - """Gets a tokenizer for the given model name via HuggingFace or ModelScope. 
- """ + """Gets a tokenizer for the given model name via HuggingFace or ModelScope.""" if envs.VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. @@ -173,13 +166,13 @@ def get_tokenizer( revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, # Ignore weights - we only need the tokenizer. - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) tokenizer_name = tokenizer_path if tokenizer_mode == "slow": if kwargs.get("use_fast", False): - raise ValueError( - "Cannot use the fast tokenizer in slow tokenizer mode.") + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False if "truncation_side" not in kwargs: @@ -195,23 +188,28 @@ def get_tokenizer( is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai" if is_from_mistral_org and tokenizer_mode != "mistral": warnings.warn( - 'It is strongly recommended to run mistral models with ' + "It is strongly recommended to run mistral models with " '`--tokenizer-mode "mistral"` to ensure correct ' - 'encoding and decoding.', + "encoding and decoding.", FutureWarning, - stacklevel=2) + stacklevel=2, + ) tokenizer: AnyTokenizer if tokenizer_mode == "mistral": - tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), - revision=revision) + tokenizer = MistralTokenizer.from_pretrained( + str(tokenizer_name), revision=revision + ) elif tokenizer_mode == "custom": from vllm.transformers_utils.tokenizer_base import TokenizerRegistry - tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name), - *args, - revision=revision, - download_dir=download_dir, - **kwargs) + + tokenizer = TokenizerRegistry.get_tokenizer( + str(tokenizer_name), + *args, + revision=revision, + download_dir=download_dir, + **kwargs, + ) else: try: tokenizer = AutoTokenizer.from_pretrained( @@ -226,13 +224,16 @@ def get_tokenizer( # currently being imported, # suggest using the --trust-remote-code flag. if not trust_remote_code and ( - "does not exist or is not currently imported." in str(e) - or "requires you to execute the tokenizer file" in str(e)): - err_msg = ("Failed to load the tokenizer. If the tokenizer " - "is a custom tokenizer not yet available in the " - "HuggingFace transformers library, consider " - "setting `trust_remote_code=True` in LLM or using " - "the `--trust-remote-code` flag in the CLI.") + "does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e) + ): + err_msg = ( + "Failed to load the tokenizer. If the tokenizer " + "is a custom tokenizer not yet available in the " + "HuggingFace transformers library, consider " + "setting `trust_remote_code=True` in LLM or using " + "the `--trust-remote-code` flag in the CLI." 
+ ) raise RuntimeError(err_msg) from e else: raise e @@ -240,19 +241,21 @@ def get_tokenizer( # The special_tokens in tokenizer should also be # controlled by do_lower_case in encoder_config encoder_config = get_sentence_transformer_tokenizer_config( - tokenizer_name, revision) + tokenizer_name, revision + ) if isinstance(encoder_config, dict) and encoder_config.get( - "do_lower_case", False): + "do_lower_case", False + ): special_tokens_map = { - k: v.lower() - for k, v in tokenizer.special_tokens_map.items() + k: v.lower() for k, v in tokenizer.special_tokens_map.items() } tokenizer.add_special_tokens(special_tokens_map) if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.warning( "Using a slow tokenizer. This might cause a significant " - "slowdown. Consider using a fast tokenizer instead.") + "slowdown. Consider using a fast tokenizer instead." + ) tokenizer = get_cached_tokenizer(tokenizer) return tokenizer diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py index b1f84a023fc3..2d64265abbf2 100644 --- a/vllm/transformers_utils/tokenizer_base.py +++ b/vllm/transformers_utils/tokenizer_base.py @@ -10,7 +10,6 @@ class TokenizerBase(ABC): - @property @abstractmethod def all_special_tokens_extended(self) -> list[str]: @@ -98,18 +97,22 @@ def encode_one( raise NotImplementedError() @abstractmethod - def encode(self, - text: str, - truncation: Optional[bool] = None, - max_length: Optional[int] = None, - add_special_tokens: Optional[bool] = None) -> list[int]: + def encode( + self, + text: str, + truncation: Optional[bool] = None, + max_length: Optional[int] = None, + add_special_tokens: Optional[bool] = None, + ) -> list[int]: raise NotImplementedError() @abstractmethod - def apply_chat_template(self, - messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, - **kwargs) -> list[int]: + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs, + ) -> list[int]: raise NotImplementedError() @abstractmethod @@ -117,9 +120,9 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError() @abstractmethod - def decode(self, - ids: Union[list[int], int], - skip_special_tokens: bool = True) -> str: + def decode( + self, ids: Union[list[int], int], skip_special_tokens: bool = True + ) -> str: raise NotImplementedError() @abstractmethod diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index 941156c4bf50..b63cb26af46d 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .mistral import (MistralTokenizer, maybe_serialize_tool_calls, - truncate_tool_call_ids, validate_request_params) +from .mistral import ( + MistralTokenizer, + maybe_serialize_tool_calls, + truncate_tool_call_ids, + validate_request_params, +) __all__ = [ - "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids", - "validate_request_params" + "MistralTokenizer", + "maybe_serialize_tool_calls", + "truncate_tool_call_ids", + "validate_request_params", ] diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index ed9f28d54448..5633a31455e9 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ 
b/vllm/transformers_utils/tokenizers/mistral.py @@ -20,7 +20,8 @@ # will not be bothered by the dependency. from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import ( - MistralTokenizer as PublicMistralTokenizer) + MistralTokenizer as PublicMistralTokenizer, + ) from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @@ -51,7 +52,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): # - https://github.com/pydantic/pydantic/issues/9541 # TODO: remove when pydantic v2.11 is released for i, message in enumerate(request.messages): - if message.get("role") == 'assistant': + if message.get("role") == "assistant": tool_calls_validator = message.get("tool_calls", ().__iter__()) validated_tool_calls = [] while True: @@ -67,7 +68,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): def truncate_tool_call_ids(request: "ChatCompletionRequest"): """Truncates tool call IDs for Mistral's ID requirements.""" for i, message in enumerate(request.messages): - if message.get("role") == 'assistant': + if message.get("role") == "assistant": tool_calls = message.get("tool_calls", []) for tool_call in tool_calls: if len(tool_call["id"]) > 9: @@ -95,17 +96,19 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"): def validate_request_params(request: "ChatCompletionRequest"): - if (request.skip_special_tokens is not None - and not request.skip_special_tokens): - raise ValueError("skip_special_tokens=False is not supported " - "for Mistral tokenizers.") + if request.skip_special_tokens is not None and not request.skip_special_tokens: + raise ValueError( + "skip_special_tokens=False is not supported for Mistral tokenizers." + ) def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]: repo_cache = os.path.join( huggingface_hub.constants.HF_HUB_CACHE, huggingface_hub.constants.REPO_ID_SEPARATOR.join( - ["models", *repo_id.split("/")])) + ["models", *repo_id.split("/")] + ), + ) if revision is None: revision_file = os.path.join(repo_cache, "refs", "main") @@ -141,7 +144,8 @@ def find_tokenizer_file(files: list[str]): raise OSError( f"Found {len(matched_files)} files matching the " f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral " - f"tokenizer is present in {files}.") + f"tokenizer is present in {files}." 
+ ) return matched_files[0] @@ -149,22 +153,23 @@ def find_tokenizer_file(files: list[str]): def _aggregate_content(content: list) -> list[dict[str, Any]]: aggregated_content: list[dict[str, Any]] = [] for chunk in content: - if chunk.get("type" - ) == "text" and aggregated_content and aggregated_content[ - -1].get("type") == "text": + if ( + chunk.get("type") == "text" + and aggregated_content + and aggregated_content[-1].get("type") == "text" + ): aggregated_content[-1]["text"] += "\n\n" + chunk.get("text") else: aggregated_content.append(chunk) - if len(aggregated_content) == 1 and aggregated_content[0].get( - "type") == "text": + if len(aggregated_content) == 1 and aggregated_content[0].get("type") == "text": content = aggregated_content[0]["text"] return content def make_mistral_chat_completion_request( - messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, - Any]]] = None) -> "ChatCompletionRequest": + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, +) -> "ChatCompletionRequest": last_message = cast(dict[str, Any], messages[-1]) if last_message["role"] == "assistant": last_message["prefix"] = True @@ -188,8 +193,7 @@ def make_mistral_chat_completion_request( # even if they are empty. if tools: for function in [ - tool["function"] for tool in tools - if tool["type"] == "function" + tool["function"] for tool in tools if tool["type"] == "function" ]: if function.get("parameters") is None: function["parameters"] = {} @@ -197,12 +201,11 @@ def make_mistral_chat_completion_request( function["description"] = "" from mistral_common.protocol.instruct.request import ChatCompletionRequest - return ChatCompletionRequest(messages=messages, - tools=tools) # type: ignore[type-var] + return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var] -class MistralTokenizer(TokenizerBase): +class MistralTokenizer(TokenizerBase): def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: self.mistral = tokenizer self.instruct = tokenizer.instruct_tokenizer @@ -215,10 +218,13 @@ def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: self.is_tekken = isinstance(tokenizer_, Tekkenizer) from mistral_common.tokens.tokenizers.sentencepiece import ( - SentencePieceTokenizer) + SentencePieceTokenizer, + ) + self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer) - self._special_token_policy = (SpecialTokenPolicy.IGNORE - if self.is_tekken else None) + self._special_token_policy = ( + SpecialTokenPolicy.IGNORE if self.is_tekken else None + ) if not (self.is_tekken or self.is_spm): raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") @@ -226,57 +232,54 @@ def __init__(self, tokenizer: "PublicMistralTokenizer") -> None: # Convert to a dict[str, int] to match protocol, but this is a lossy # conversion. 
There may be multiple token ids that decode to the same # string due to partial UTF-8 byte sequences being converted to � - self._vocab_dict = { - token: idx - for idx, token in enumerate(self._vocab) - } + self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)} self.tokenizer = tokenizer_ self._max_token_id = self.vocab_size - 1 @classmethod - def from_pretrained(cls, - path_or_repo_id: str, - *, - revision: Optional[str] = None) -> "MistralTokenizer": + def from_pretrained( + cls, path_or_repo_id: str, *, revision: Optional[str] = None + ) -> "MistralTokenizer": if not Path(path_or_repo_id).exists(): assert len(path_or_repo_id.split("/")) == 2, ( "You have either provided a non-existent path: " - "{path_or_repo_id} or an invalid HF Hub repo id.") + "{path_or_repo_id} or an invalid HF Hub repo id." + ) tokenizer_file = cls._download_mistral_tokenizer_from_hf( - path_or_repo_id, revision) + path_or_repo_id, revision + ) elif Path(path_or_repo_id).is_dir(): - tokenizer_file_name = find_tokenizer_file( - os.listdir(path_or_repo_id)) + tokenizer_file_name = find_tokenizer_file(os.listdir(path_or_repo_id)) tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name) else: - assert Path( - path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}" + assert Path(path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}" tokenizer_file = str(Path(path_or_repo_id)) from mistral_common.tokens.tokenizers.mistral import ( - MistralTokenizer as PublicMistralTokenizer) + MistralTokenizer as PublicMistralTokenizer, + ) + mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file) return cls(mistral_tokenizer) @staticmethod - def _download_mistral_tokenizer_from_hf(tokenizer_name: str, - revision: Optional[str]) -> str: + def _download_mistral_tokenizer_from_hf( + tokenizer_name: str, revision: Optional[str] + ) -> str: try: hf_api = HfApi() - files = hf_api.list_repo_files(repo_id=tokenizer_name, - revision=revision) + files = hf_api.list_repo_files(repo_id=tokenizer_name, revision=revision) except ConnectionError as exc: - files = list_local_repo_files(repo_id=tokenizer_name, - revision=revision) + files = list_local_repo_files(repo_id=tokenizer_name, revision=revision) if len(files) == 0: raise exc filename = find_tokenizer_file(files) - tokenizer_file = hf_hub_download(tokenizer_name, - filename=filename, - revision=revision) + tokenizer_file = hf_hub_download( + tokenizer_name, filename=filename, revision=revision + ) return tokenizer_file # the following attributes are set to fit vLLM's design and are used @@ -290,10 +293,7 @@ def all_special_tokens_extended(self) -> list[str]: special_tokens = self.tokenizer.SPECIAL_TOKENS else: special_tokens = list(SpecialTokens) - return [ - s.value if isinstance(s, SpecialTokens) else s - for s in special_tokens - ] + return [s.value if isinstance(s, SpecialTokens) else s for s in special_tokens] @property def all_special_tokens(self) -> list[str]: @@ -301,9 +301,7 @@ def all_special_tokens(self) -> list[str]: @property def all_special_ids(self) -> list[int]: - return [ - self.all_special_tokens.index(t) for t in self.all_special_tokens - ] + return [self.all_special_tokens.index(t) for t in self.all_special_tokens] @property def bos_token_id(self) -> int: @@ -386,26 +384,29 @@ def encode_one( input_ids = input_ids[:max_length] return input_ids - def encode(self, - text: str, - truncation: Optional[bool] = None, - max_length: Optional[int] = None, - add_special_tokens: Optional[bool] = None) -> list[int]: + def encode( + 
self, + text: str, + truncation: Optional[bool] = None, + max_length: Optional[int] = None, + add_special_tokens: Optional[bool] = None, + ) -> list[int]: # `encode` should only be used for prompt completion # it should never be used for chat_completion. # For chat completion use `apply_chat_template` if add_special_tokens is not None: - return self.tokenizer.encode(text, - bos=add_special_tokens, - eos=add_special_tokens) + return self.tokenizer.encode( + text, bos=add_special_tokens, eos=add_special_tokens + ) else: return self.tokenizer.encode(text, bos=True, eos=False) - def apply_chat_template(self, - messages: list["ChatCompletionMessageParam"], - tools: Optional[list[dict[str, Any]]] = None, - **kwargs) -> list[int]: - + def apply_chat_template( + self, + messages: list["ChatCompletionMessageParam"], + tools: Optional[list[dict[str, Any]]] = None, + **kwargs, + ) -> list[int]: request = make_mistral_chat_completion_request(messages, tools) encoded = self.mistral.encode_chat_completion(request) @@ -414,11 +415,15 @@ def apply_chat_template(self, def convert_tokens_to_string(self, tokens: list[str]) -> str: from mistral_common.tokens.tokenizers.base import SpecialTokens + if self.is_tekken: tokens = [ - t for t in tokens - if (t is SpecialTokens.tool_calls - or t not in self.tokenizer._all_special_tokens) + t + for t in tokens + if ( + t is SpecialTokens.tool_calls + or t not in self.tokenizer._all_special_tokens + ) ] if any(isinstance(t, bytes) for t in tokens): @@ -426,20 +431,20 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str: shift = self.tokenizer.num_special_tokens def _token_to_id(t: str): - t_bytes = t.encode("utf-8") \ - if not isinstance(t, bytes) else t + t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t try: - return shift + \ - self.tokenizer._tekken_token2id_nospecial[t_bytes] + return ( + shift + self.tokenizer._tekken_token2id_nospecial[t_bytes] + ) except KeyError: logger.warning( - "Failed to convert token %s to id," - " replacing with <unk>", t_bytes) + "Failed to convert token %s to id, replacing with <unk>", + t_bytes, + ) return self.tokenizer.unk_id ids = [_token_to_id(t) for t in tokens] - decoded = self.tokenizer.decode(ids, - self._special_token_policy) + decoded = self.tokenizer.decode(ids, self._special_token_policy) else: decoded = "".join(tokens) else: @@ -453,8 +458,10 @@ def _token_to_id(t: str): if token in special_tokens: if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens, - self._special_token_policy)) + self.tokenizer.decode( + regular_tokens, self._special_token_policy + ) + ) regular_tokens = [] decoded_list.append(token) else: @@ -462,19 +469,19 @@ def _token_to_id(t: str): if regular_tokens: decoded_list.append( - self.tokenizer.decode(regular_tokens, - self._special_token_policy)) + self.tokenizer.decode(regular_tokens, self._special_token_policy) + ) - decoded = ''.join(decoded_list) + decoded = "".join(decoded_list) return decoded - def decode(self, - ids: Union[list[int], int], - skip_special_tokens: bool = True) -> str: - assert ( - skip_special_tokens - ), "skip_special_tokens=False is not supported for Mistral tokenizers." + def decode( + self, ids: Union[list[int], int], skip_special_tokens: bool = True + ) -> str: + assert skip_special_tokens, ( + "skip_special_tokens=False is not supported for Mistral tokenizers." 
+ ) if isinstance(ids, int): ids = [ids] @@ -486,13 +493,12 @@ def convert_ids_to_tokens( skip_special_tokens: bool = True, ) -> list[str]: from mistral_common.tokens.tokenizers.base import SpecialTokens - from mistral_common.tokens.tokenizers.instruct import ( - InstructTokenizerV13) + from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13 # TODO(Patrick) - potentially allow special tokens to not be skipped - assert ( - skip_special_tokens - ), "skip_special_tokens=False is not supported for Mistral tokenizers." + assert skip_special_tokens, ( + "skip_special_tokens=False is not supported for Mistral tokenizers." + ) assert self.is_tekken or self.is_spm, type(self.tokenizer) @@ -507,8 +513,9 @@ def convert_ids_to_tokens( if self.instruct.END_THINK: non_skip_special_tokens.add(self.instruct.END_THINK) ids = [ - i for i in ids if i > self.tokenizer.num_special_tokens - or i in non_skip_special_tokens + i + for i in ids + if i > self.tokenizer.num_special_tokens or i in non_skip_special_tokens ] tokens = [self.tokenizer.id_to_piece(id) for id in ids] diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index 2aaad8f949d0..8952a0b197d6 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -15,7 +15,7 @@ def is_s3(model_or_path: str) -> bool: - return model_or_path.lower().startswith('s3://') + return model_or_path.lower().startswith("s3://") def check_gguf_file(model: Union[str, PathLike]) -> bool: @@ -43,13 +43,16 @@ def modelscope_list_repo_files( ) -> list[str]: """List files in a modelscope repo.""" from modelscope.hub.api import HubApi + api = HubApi() api.login(token) # same as huggingface_hub.list_repo_files files = [ - file['Path'] for file in api.get_model_files( - model_id=repo_id, revision=revision, recursive=True) - if file['Type'] == 'blob' + file["Path"] + for file in api.get_model_files( + model_id=repo_id, revision=revision, recursive=True + ) + if file["Type"] == "blob" ] return files @@ -91,18 +94,18 @@ def maybe_model_redirect(model: str) -> str: if not Path(model_redirect_path).exists(): return model - redirect_dict = (_maybe_json_dict(model_redirect_path) - or _maybe_space_split_dict(model_redirect_path)) - if (redirect_model := redirect_dict.get(model)): + redirect_dict = _maybe_json_dict(model_redirect_path) or _maybe_space_split_dict( + model_redirect_path + ) + if redirect_model := redirect_dict.get(model): logger.info("model redirect: [ %s ] -> [ %s ]", model, redirect_model) return redirect_model return model -def parse_safetensors_file_metadata( - path: Union[str, PathLike]) -> dict[str, Any]: +def parse_safetensors_file_metadata(path: Union[str, PathLike]) -> dict[str, Any]: with open(path, "rb") as f: - length_of_metadata = struct.unpack('<Q', f.read(8))[0] - metadata = json.loads(f.read(length_of_metadata).decode('utf-8')) + length_of_metadata = struct.unpack("<Q", f.read(8))[0] + metadata = json.loads(f.read(length_of_metadata).decode("utf-8")) return metadata diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 828536e6408b..a475d0fa406b 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder, - TritonPlaceholder) +from vllm.triton_utils.importing import ( + HAS_TRITON, + TritonLanguagePlaceholder, + TritonPlaceholder, +) if 
HAS_TRITON: import triton diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index 95076a9a7c8f..e1a509a303c5 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -21,15 +21,15 @@ # an is_active method. # The `x.driver and` check adds a small layer of safety. active_drivers = [ - x.driver for x in backends.values() - if x.driver and x.driver.is_active() + x.driver for x in backends.values() if x.driver and x.driver.is_active() ] # Check if we're in a distributed environment where CUDA_VISIBLE_DEVICES # might be temporarily empty (e.g., Ray sets it to "" during actor init) cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") - is_distributed_env = (cuda_visible_devices is not None - and len(cuda_visible_devices.strip()) == 0) + is_distributed_env = ( + cuda_visible_devices is not None and len(cuda_visible_devices.strip()) == 0 + ) # Apply lenient driver check for distributed environments if is_distributed_env and len(active_drivers) == 0: @@ -37,35 +37,41 @@ # active later when CUDA context is properly initialized logger.debug( "Triton found 0 active drivers in distributed environment. " - "This is expected during initialization.") + "This is expected during initialization." + ) elif not is_distributed_env and len(active_drivers) != 1: # Strict check for non-distributed environments logger.info( "Triton is installed but %d active driver(s) found " "(expected 1). Disabling Triton to prevent runtime errors.", - len(active_drivers)) + len(active_drivers), + ) HAS_TRITON = False except ImportError: # This can occur if Triton is partially installed or triton.backends # is missing. logger.warning( "Triton is installed, but `triton.backends` could not be imported. " - "Disabling Triton.") + "Disabling Triton." + ) HAS_TRITON = False except Exception as e: # Catch any other unexpected errors during the check. logger.warning( "An unexpected error occurred while checking Triton active drivers:" - " %s. Disabling Triton.", e) + " %s. Disabling Triton.", + e, + ) HAS_TRITON = False if not HAS_TRITON: - logger.info("Triton not installed or not compatible; certain GPU-related" - " functions will not be available.") + logger.info( + "Triton not installed or not compatible; certain GPU-related" + " functions will not be available." 
+ ) class TritonPlaceholder(types.ModuleType): - def __init__(self): super().__init__("triton") self.__version__ = "3.4.0" @@ -76,7 +82,6 @@ def __init__(self): self.language = TritonLanguagePlaceholder() def _dummy_decorator(self, name): - def decorator(*args, **kwargs): if args and callable(args[0]): return args[0] @@ -86,7 +91,6 @@ def decorator(*args, **kwargs): class TritonLanguagePlaceholder(types.ModuleType): - def __init__(self): super().__init__("triton.language") self.constexpr = None diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 92245498de65..ed470ebe8892 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -68,8 +68,7 @@ def is_usage_stats_enabled(): no_usage_stats = envs.VLLM_NO_USAGE_STATS do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH) - _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats - or do_not_track_file) + _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats or do_not_track_file) return _USAGE_STATS_ENABLED @@ -80,9 +79,11 @@ def _get_current_timestamp_ns() -> int: def _detect_cloud_provider() -> str: # Try detecting through vendor file vendor_files = [ - "/sys/class/dmi/id/product_version", "/sys/class/dmi/id/bios_vendor", + "/sys/class/dmi/id/product_version", + "/sys/class/dmi/id/bios_vendor", "/sys/class/dmi/id/product_name", - "/sys/class/dmi/id/chassis_asset_tag", "/sys/class/dmi/id/sys_vendor" + "/sys/class/dmi/id/chassis_asset_tag", + "/sys/class/dmi/id/sys_vendor", ] # Mapping of identifiable strings to cloud providers cloud_identifiers = { @@ -152,39 +153,53 @@ def __init__(self) -> None: self.log_time: Optional[int] = None self.source: Optional[str] = None - def report_usage(self, - model_architecture: str, - usage_context: UsageContext, - extra_kvs: Optional[dict[str, Any]] = None) -> None: - t = Thread(target=self._report_usage_worker, - args=(model_architecture, usage_context, extra_kvs or {}), - daemon=True) + def report_usage( + self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: Optional[dict[str, Any]] = None, + ) -> None: + t = Thread( + target=self._report_usage_worker, + args=(model_architecture, usage_context, extra_kvs or {}), + daemon=True, + ) t.start() - def _report_usage_worker(self, model_architecture: str, - usage_context: UsageContext, - extra_kvs: dict[str, Any]) -> None: + def _report_usage_worker( + self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: dict[str, Any], + ) -> None: self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_continuous_usage() - def _report_usage_once(self, model_architecture: str, - usage_context: UsageContext, - extra_kvs: dict[str, Any]) -> None: + def _report_usage_once( + self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: dict[str, Any], + ) -> None: # Platform information from vllm.platforms import current_platform + if current_platform.is_cuda_alike(): self.gpu_count = cuda_device_count_stateless() - self.gpu_type, self.gpu_memory_per_device = ( - cuda_get_device_properties(0, ("name", "total_memory"))) + self.gpu_type, self.gpu_memory_per_device = cuda_get_device_properties( + 0, ("name", "total_memory") + ) if current_platform.is_cuda(): self.cuda_runtime = torch.version.cuda if current_platform.is_tpu(): try: import torch_xla + self.gpu_count = torch_xla.runtime.world_size() self.gpu_type = torch_xla.tpu.get_tpu_type() - self.gpu_memory_per_device = ( - torch_xla.core.xla_model.get_memory_info()["bytes_limit"]) + 
self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[ + "bytes_limit" + ] except Exception: logger.exception("Failed to collect TPU information") self.provider = _detect_cloud_provider() @@ -195,11 +210,13 @@ def _report_usage_once(self, model_architecture: str, info = cpuinfo.get_cpu_info() self.num_cpu = info.get("count", None) self.cpu_type = info.get("brand_raw", "") - self.cpu_family_model_stepping = ",".join([ - str(info.get("family", "")), - str(info.get("model", "")), - str(info.get("stepping", "")) - ]) + self.cpu_family_model_stepping = ",".join( + [ + str(info.get("family", "")), + str(info.get("model", "")), + str(info.get("stepping", "")), + ] + ) # vLLM information self.context = usage_context.value @@ -207,10 +224,9 @@ def _report_usage_once(self, model_architecture: str, self.model_architecture = model_architecture # Environment variables - self.env_var_json = json.dumps({ - env_var: getattr(envs, env_var) - for env_var in _USAGE_ENV_VARS_TO_COLLECT - }) + self.env_var_json = json.dumps( + {env_var: getattr(envs, env_var) for env_var in _USAGE_ENV_VARS_TO_COLLECT} + ) # Metadata self.log_time = _get_current_timestamp_ns() diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 6b208bca6986..c06bbbbb23ab 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -33,22 +33,47 @@ import uuid import warnings import weakref -from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, - ArgumentTypeError, RawDescriptionHelpFormatter, - _ArgumentGroup) +from argparse import ( + Action, + ArgumentDefaultsHelpFormatter, + ArgumentParser, + ArgumentTypeError, + RawDescriptionHelpFormatter, + _ArgumentGroup, +) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict -from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator, - Hashable, Iterable, Iterator, KeysView, Mapping, - Sequence) +from collections.abc import ( + AsyncGenerator, + Awaitable, + Collection, + Generator, + Hashable, + Iterable, + Iterator, + KeysView, + Mapping, + Sequence, +) from concurrent.futures import ThreadPoolExecutor from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps from pathlib import Path from types import MappingProxyType -from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, - Optional, TextIO, TypeVar, Union, cast, overload) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + NamedTuple, + TextIO, + TypeVar, + Union, + cast, + overload, +) from urllib.parse import urlparse from uuid import uuid4 @@ -117,8 +142,8 @@ """The number of bytes in one gibibyte (GiB).""" # ANSI color codes -CYAN = '\033[1;36m' -RESET = '\033[0;0m' +CYAN = "\033[1;36m" +RESET = "\033[0;0m" STR_DTYPE_TO_TORCH_DTYPE = { "float32": torch.float32, @@ -152,7 +177,7 @@ def set_default_torch_num_threads(num_threads: int): torch.set_num_threads(old_num_threads) -P = ParamSpec('P') +P = ParamSpec("P") T = TypeVar("T") U = TypeVar("U") @@ -161,8 +186,7 @@ def set_default_torch_num_threads(num_threads: int): _T = TypeVar("_T") -class _Sentinel: - ... +class _Sentinel: ... 
ALL_PINNED_SENTINEL = _Sentinel() @@ -179,7 +203,6 @@ class LayerBlockType(enum.Enum): class Counter: - def __init__(self, start: int = 0) -> None: self.counter = start @@ -193,7 +216,6 @@ def reset(self) -> None: class _MappingOrderCacheView(UserDict[_K, _V]): - def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]): super().__init__(data) self.ordered_keys = ordered_keys @@ -224,10 +246,7 @@ def __sub__(self, other: CacheInfo): class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]): - - def __init__(self, - capacity: float, - getsizeof: Optional[Callable[[_V], float]] = None): + def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None): super().__init__(capacity, getsizeof) self.pinned_items = set[_K]() @@ -247,8 +266,7 @@ def __getitem__(self, key: _K, *, update_info: bool = True) -> _V: def __delitem__(self, key: _K) -> None: run_on_remove = key in self - value = self.__getitem__(key, - update_info=False) # type: ignore[call-arg] + value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] super().__delitem__(key) if key in self.pinned_items: # Todo: add warning to inform that del pinned item @@ -261,7 +279,8 @@ def cache(self) -> Mapping[_K, _V]: """Return the internal cache dictionary in order (read-only).""" return _MappingOrderCacheView( self._Cache__data, # type: ignore - self.order) + self.order, + ) @property def order(self) -> Mapping[_K, None]: @@ -302,22 +321,17 @@ def touch(self, key: _K) -> None: self._LRUCache__order[key] = None # type: ignore @overload - def get(self, key: _K, /) -> Optional[_V]: - ... + def get(self, key: _K, /) -> _V | None: ... @overload - def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: - ... - - def get(self, - key: _K, - /, - default: Optional[Union[_V, - _T]] = None) -> Optional[Union[_V, _T]]: - value: Optional[Union[_V, _T]] + def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ... + + def get( + self, key: _K, /, default: Union[_V, _T] | None = None + ) -> Union[_V, _T] | None: + value: Union[_V, _T] | None if key in self: - value = self.__getitem__( - key, update_info=False) # type: ignore[call-arg] + value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] self._hits += 1 else: @@ -327,23 +341,19 @@ def get(self, return value @overload - def pop(self, key: _K) -> _V: - ... + def pop(self, key: _K) -> _V: ... @overload - def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: - ... - - def pop(self, - key: _K, - default: Optional[Union[_V, - _T]] = None) -> Optional[Union[_V, _T]]: - value: Optional[Union[_V, _T]] + def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ... 
+ + def pop( + self, key: _K, default: Union[_V, _T] | None = None + ) -> Union[_V, _T] | None: + value: Union[_V, _T] | None if key not in self: return default - value = self.__getitem__(key, - update_info=False) # type: ignore[call-arg] + value = self.__getitem__(key, update_info=False) # type: ignore[call-arg] self.__delitem__(key) return value @@ -366,7 +376,7 @@ def _unpin(self, key: _K) -> None: """ self.pinned_items.remove(key) - def _on_remove(self, key: _K, value: Optional[_V]) -> None: + def _on_remove(self, key: _K, value: _V | None) -> None: pass def remove_oldest(self, *, remove_pinned: bool = False) -> None: @@ -385,10 +395,12 @@ def popitem(self, remove_pinned: bool = False): # pop the oldest item in the cache that is not pinned lru_key = next( (key for key in self.order if key not in self.pinned_items), - ALL_PINNED_SENTINEL) + ALL_PINNED_SENTINEL, + ) if lru_key is ALL_PINNED_SENTINEL: - raise RuntimeError("All items are pinned, " - "cannot remove oldest from the cache.") + raise RuntimeError( + "All items are pinned, cannot remove oldest from the cache." + ) else: lru_key = next(iter(self.order)) value = self.pop(cast(_K, lru_key)) @@ -436,8 +448,7 @@ def get_object(self): return obj def reset(self): - """Makes all cached-objects available for the next scheduler iteration. - """ + """Makes all cached-objects available for the next scheduler iteration.""" self._index = 0 @@ -445,8 +456,8 @@ def reset(self): def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" from vllm import _custom_ops as ops - max_shared_mem = ( - ops.get_max_shared_memory_per_block_device_attribute(gpu)) + + max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu) # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # will fail assert max_shared_mem > 0, "max_shared_mem can not be zero" @@ -481,11 +492,14 @@ def __init__( self.batch_wait_timeout_s = batch_wait_timeout_s self._loop = asyncio.get_running_loop() - self._queues: dict[tuple, - asyncio.Queue[Union[tuple[str, dict, - asyncio.Future], - tuple[list[int], - asyncio.Future]]]] = {} + self._queues: dict[ + tuple, + asyncio.Queue[ + Union[ + tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future] + ] + ], + ] = {} self._batcher_tasks: list[asyncio.Task] = [] # Single-thread executor for blocking tokenizer calls. @@ -509,8 +523,9 @@ async def decode(self, token_ids, **kwargs): # === Internal helpers === def _get_queue( self, loop: asyncio.AbstractEventLoop, key: tuple - ) -> asyncio.Queue[Union[tuple[str, dict, asyncio.Future], tuple[ - list[int], asyncio.Future]]]: + ) -> asyncio.Queue[ + Union[tuple[str, dict, asyncio.Future], tuple[list[int], asyncio.Future]] + ]: """Get the request queue for the given operation key, creating a new queue and batcher task if needed.""" queue = self._queues.get(key) @@ -520,8 +535,7 @@ def _get_queue( can_batch = key[1] != "other" coro = self._batch_encode_loop(queue, can_batch) else: - assert key[0] == "decode", \ - f"Unknown operation type: {key[0]}." + assert key[0] == "decode", f"Unknown operation type: {key[0]}." 
coro = self._batch_decode_loop(queue) self._batcher_tasks.append(loop.create_task(coro)) return queue @@ -541,7 +555,8 @@ async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool): break try: prompt, kwargs, result_future = await asyncio.wait_for( - queue.get(), timeout) + queue.get(), timeout + ) prompts.append(prompt) result_futures.append(result_future) if not can_batch: @@ -553,10 +568,10 @@ async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool): # If every request uses identical kwargs we can run a single # batched tokenizer call for a big speed-up. if can_batch and len(prompts) > 1: - batch_encode_fn = partial(self.tokenizer, prompts, - **kwargs) + batch_encode_fn = partial(self.tokenizer, prompts, **kwargs) results = await self._loop.run_in_executor( - self._executor, batch_encode_fn) + self._executor, batch_encode_fn + ) for i, fut in enumerate(result_futures): if not fut.done(): @@ -564,11 +579,11 @@ async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool): fut.set_result(BatchEncoding(data)) else: encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ - self.tokenizer(p, **kw) - for p, kw in zip(prompts, kwargs) + self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs) ] results = await self._loop.run_in_executor( - self._executor, encode_fn) + self._executor, encode_fn + ) for fut, res in zip(result_futures, results): if not fut.done(): @@ -592,7 +607,8 @@ async def _batch_decode_loop(self, queue: asyncio.Queue): break try: token_ids, result_future = await asyncio.wait_for( - queue.get(), timeout) + queue.get(), timeout + ) token_ids_list.append(token_ids) result_futures.append(result_future) except asyncio.TimeoutError: @@ -601,8 +617,8 @@ async def _batch_decode_loop(self, queue: asyncio.Queue): try: # Perform a single batched decode call for all requests results = await self._loop.run_in_executor( - self._executor, self.tokenizer.batch_decode, - token_ids_list) + self._executor, self.tokenizer.batch_decode, token_ids_list + ) for fut, res in zip(result_futures, results): if not fut.done(): fut.set_result(res) @@ -631,7 +647,7 @@ def _queue_key(self, op: str, kwargs: dict) -> tuple: """ if op == "decode": - return ("decode", ) + return ("decode",) add_special_tokens = kwargs.get("add_special_tokens", True) truncation = kwargs.get("truncation", False) @@ -641,16 +657,17 @@ def _queue_key(self, op: str, kwargs: dict) -> tuple: return "encode", add_special_tokens, False, None model_max = getattr(self.tokenizer, "model_max_length", None) - if max_length is None or (model_max is not None - and max_length == model_max): + if max_length is None or (model_max is not None and max_length == model_max): return "encode", add_special_tokens, True, "model_max" return "encode", "other" def __del__(self): - if ((tasks := getattr(self, "_batcher_tasks", None)) - and (loop := getattr(self, "_loop", None)) - and not loop.is_closed()): + if ( + (tasks := getattr(self, "_batcher_tasks", None)) + and (loop := getattr(self, "_loop", None)) + and not loop.is_closed() + ): def cancel_tasks(): for task in tasks: @@ -685,8 +702,7 @@ def in_loop(event_loop: AbstractEventLoop) -> bool: def make_async( - func: Callable[P, T], - executor: Optional[concurrent.futures.Executor] = None + func: Callable[P, T], executor: concurrent.futures.Executor | None = None ) -> Callable[P, Awaitable[T]]: """Take a blocking function, and run it on in an executor thread. 
@@ -703,15 +719,14 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: return _async_wrapper -def _next_task(iterator: AsyncGenerator[T, None], - loop: AbstractEventLoop) -> Task: +def _next_task(iterator: AsyncGenerator[T, None], loop: AbstractEventLoop) -> Task: # Can use anext() in python >= 3.10 return loop.create_task(iterator.__anext__()) # type: ignore[arg-type] async def merge_async_iterators( - *iterators: AsyncGenerator[T, - None], ) -> AsyncGenerator[tuple[int, T], None]: + *iterators: AsyncGenerator[T, None], +) -> AsyncGenerator[tuple[int, T], None]: """Merge multiple asynchronous iterators into a single iterator. This method handle the case where some iterators finish before others. @@ -729,8 +744,7 @@ async def merge_async_iterators( awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)} try: while awaits: - done, _ = await asyncio.wait(awaits.keys(), - return_when=FIRST_COMPLETED) + done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED) for d in done: pair = awaits.pop(d) try: @@ -748,8 +762,7 @@ async def merge_async_iterators( await it.aclose() -async def collect_from_async_generator( - iterator: AsyncGenerator[T, None]) -> list[T]: +async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]: """Collect all items from an async generator into a list.""" items = [] async for item in iterator: @@ -765,7 +778,8 @@ def get_ip() -> str: " it is often used by Docker and other software to" " interact with the container's network stack. Please " "use VLLM_HOST_IP instead to set the IP address for vLLM processes" - " to communicate with each other.") + " to communicate with each other." + ) if host_ip: return host_ip @@ -793,7 +807,8 @@ def get_ip() -> str: "Failed to get the IP address, using 0.0.0.0 by default." "The value can be set by the environment variable" " VLLM_HOST_IP or HOST_IP.", - stacklevel=2) + stacklevel=2, + ) return "0.0.0.0" @@ -821,7 +836,8 @@ def get_loopback_ip() -> str: else: raise RuntimeError( "Neither 127.0.0.1 nor ::1 are bound to a local interface. " - "Set the VLLM_LOOPBACK_IP environment variable explicitly.") + "Set the VLLM_LOOPBACK_IP environment variable explicitly." + ) def is_valid_ipv6_address(address: str) -> bool: @@ -834,13 +850,13 @@ def is_valid_ipv6_address(address: str) -> bool: def split_host_port(host_port: str) -> tuple[str, int]: # ipv6 - if host_port.startswith('['): - host, port = host_port.rsplit(']', 1) + if host_port.startswith("["): + host, port = host_port.rsplit("]", 1) host = host[1:] - port = port.split(':')[1] + port = port.split(":")[1] return host, int(port) else: - host, port = host_port.split(':') + host, port = host_port.split(":") return host, int(port) @@ -908,8 +924,7 @@ def _get_open_port() -> int: return port except OSError: port += 1 # Increment port number if already in use - logger.info("Port %d is already in use, trying port %d", - port - 1, port) + logger.info("Port %d is already in use, trying port %d", port - 1, port) # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -922,7 +937,7 @@ def _get_open_port() -> int: return s.getsockname()[1] -def find_process_using_port(port: int) -> Optional[psutil.Process]: +def find_process_using_port(port: int) -> psutil.Process | None: # TODO: We can not check for running processes with network # port on macOS. Therefore, we can not have a full graceful shutdown # of vLLM. For now, let's not look for processes in this case. 
@@ -932,8 +947,7 @@ def find_process_using_port(port: int) -> Optional[psutil.Process]: our_pid = os.getpid() for conn in psutil.net_connections(): - if conn.laddr.port == port and (conn.pid is not None - and conn.pid != our_pid): + if conn.laddr.port == port and (conn.pid is not None and conn.pid != our_pid): try: return psutil.Process(conn.pid) except psutil.NoSuchProcess: @@ -945,15 +959,18 @@ def update_environment_variables(envs: dict[str, str]): for k, v in envs.items(): if k in os.environ and os.environ[k] != v: logger.warning( - "Overwriting environment variable %s " - "from '%s' to '%s'", k, os.environ[k], v) + "Overwriting environment variable %s from '%s' to '%s'", + k, + os.environ[k], + v, + ) os.environ[k] = v def chunk_list(lst: list[T], chunk_size: int): """Yield successive chunk_size chunks from lst.""" for i in range(0, len(lst), chunk_size): - yield lst[i:i + chunk_size] + yield lst[i : i + chunk_size] def cdiv(a: int, b: int) -> int: @@ -997,6 +1014,7 @@ def _generate_random_fp8( # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} from vllm import _custom_ops as ops + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) ops.convert_fp8(tensor, tensor_tmp) @@ -1004,12 +1022,12 @@ def _generate_random_fp8( def get_kv_cache_torch_dtype( - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: + cache_dtype: Union[str, torch.dtype] | None, + model_dtype: Union[str, torch.dtype] | None = None, +) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": - if isinstance(model_dtype, - str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: + if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] elif isinstance(model_dtype, torch.dtype): torch_dtype = model_dtype @@ -1032,39 +1050,37 @@ def create_kv_caches_with_random_flash( num_layers: int, num_heads: int, head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = None, - device: Optional[str] = "cuda", - cache_layout: Optional[str] = "NHD", + cache_dtype: Union[str, torch.dtype] | None, + model_dtype: Union[str, torch.dtype] | None = None, + seed: int | None = None, + device: str | None = "cuda", + cache_layout: str | None = "NHD", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: from vllm.platforms import current_platform + current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) assert cache_layout in ("NHD", "HND") - stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, - 4) + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) - kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] - for i in stride_order) + kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) scale = head_size**-0.5 key_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = [] for _ in range(num_layers): - key_value_cache = torch.empty(size=kv_cache_allocation_shape, - dtype=torch_dtype, - device=device).permute(*stride_order) + key_value_cache = torch.empty( + size=kv_cache_allocation_shape, dtype=torch_dtype, device=device + ).permute(*stride_order) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_value_cache.uniform_(-scale, scale) - elif 
cache_dtype == 'fp8': + elif cache_dtype == "fp8": _generate_random_fp8(key_value_cache, -scale, scale) else: - raise ValueError( - f"Does not support key cache of type {cache_dtype}") + raise ValueError(f"Does not support key cache of type {cache_dtype}") key_caches.append(key_value_cache[:, 0]) value_caches.append(key_value_cache[:, 1]) return key_caches, value_caches @@ -1076,16 +1092,17 @@ def create_kv_caches_with_random( num_layers: int, num_heads: int, head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = None, - device: Optional[str] = "cuda", + cache_dtype: Union[str, torch.dtype] | None, + model_dtype: Union[str, torch.dtype] | None = None, + seed: int | None = None, + device: str | None = "cuda", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: if cache_dtype == "fp8" and head_size % 16: raise ValueError( f"Does not support key cache of type fp8 with head_size {head_size}" ) from vllm.platforms import current_platform + current_platform.seed_everything(seed) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) @@ -1095,31 +1112,27 @@ def create_kv_caches_with_random( key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches: list[torch.Tensor] = [] for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, - dtype=torch_dtype, - device=device) + key_cache = torch.empty(size=key_cache_shape, dtype=torch_dtype, device=device) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_cache.uniform_(-scale, scale) - elif cache_dtype == 'fp8': + elif cache_dtype == "fp8": _generate_random_fp8(key_cache, -scale, scale) else: - raise ValueError( - f"Does not support key cache of type {cache_dtype}") + raise ValueError(f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_caches: list[torch.Tensor] = [] for _ in range(num_layers): - value_cache = torch.empty(size=value_cache_shape, - dtype=torch_dtype, - device=device) + value_cache = torch.empty( + size=value_cache_shape, dtype=torch_dtype, device=device + ) if cache_dtype in ["auto", "half", "bfloat16", "float"]: value_cache.uniform_(-scale, scale) - elif cache_dtype == 'fp8': + elif cache_dtype == "fp8": _generate_random_fp8(value_cache, -scale, scale) else: - raise ValueError( - f"Does not support value cache of type {cache_dtype}") + raise ValueError(f"Does not support value cache of type {cache_dtype}") value_caches.append(value_cache) return key_caches, value_caches @@ -1127,6 +1140,7 @@ def create_kv_caches_with_random( @cache def is_pin_memory_available() -> bool: from vllm.platforms import current_platform + return current_platform.is_pin_memory_available() @@ -1139,13 +1153,13 @@ def is_uva_available() -> bool: class DeviceMemoryProfiler: - - def __init__(self, device: Optional[torch.types.Device] = None): + def __init__(self, device: torch.types.Device | None = None): self.device = device def current_memory_usage(self) -> float: # Return the memory usage in bytes. from vllm.platforms import current_platform + gc.collect() return current_platform.get_current_memory_usage(self.device) @@ -1167,7 +1181,7 @@ def make_ndarray_with_pad( pad: T, dtype: npt.DTypeLike, *, - max_len: Optional[int] = None, + max_len: int | None = None, ) -> npt.NDArray: """ Make a padded array from 2D inputs. 
@@ -1182,7 +1196,7 @@ def make_ndarray_with_pad( padded_x = np.full((len(x), max_len), pad, dtype=dtype) for ind, blocktb in enumerate(x): assert len(blocktb) <= max_len - padded_x[ind, :len(blocktb)] = blocktb + padded_x[ind, : len(blocktb)] = blocktb return padded_x @@ -1192,8 +1206,8 @@ def make_tensor_with_pad( pad: T, dtype: torch.dtype, *, - max_len: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, + max_len: int | None = None, + device: Union[str, torch.device] | None = None, pin_memory: bool = False, ) -> torch.Tensor: """ @@ -1231,8 +1245,7 @@ def get_dtype_size(dtype: torch.dtype) -> int: # bool = 0, int = 1, float = 2, complex = 3 def _get_precision_level(dtype: torch.dtype) -> int: # NOTE: Complex dtypes return `is_floating_point=False` - return ((dtype != torch.bool) + dtype.is_floating_point + - dtype.is_complex * 2) + return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): @@ -1260,8 +1273,11 @@ def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): # Compare floating-point types src_info = torch.finfo(src_dtype) tgt_info = torch.finfo(tgt_dtype) - return (src_info.min >= tgt_info.min and src_info.max <= tgt_info.max - and src_info.resolution >= tgt_info.resolution) + return ( + src_info.min >= tgt_info.min + and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution + ) def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): @@ -1329,6 +1345,7 @@ def init_cached_hf_modules() -> None: Lazy initialization of the Hugging Face modules. """ from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() @@ -1372,8 +1389,8 @@ def find_nccl_library() -> str: # manually load the nccl library if so_file: logger.info( - "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", - so_file) + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file + ) else: if torch.version.cuda is not None: so_file = "libnccl.so.2" @@ -1385,11 +1402,11 @@ def find_nccl_library() -> str: return so_file -def find_nccl_include_paths() -> Optional[list[str]]: +def find_nccl_include_paths() -> list[str] | None: """ We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH` - environment variable, or we find the library file brought by - nvidia-nccl-cuXX. load_inline by default uses + environment variable, or we find the library file brought by + nvidia-nccl-cuXX. load_inline by default uses torch.utils.cpp_extension.include_paths """ paths: list[str] = [] @@ -1399,6 +1416,7 @@ def find_nccl_include_paths() -> Optional[list[str]]: try: import importlib.util + spec = importlib.util.find_spec("nvidia.nccl") if spec and getattr(spec, "submodule_search_locations", None): for loc in spec.submodule_search_locations: @@ -1431,7 +1449,6 @@ def _patched_set_stream(stream: torch.cuda.Stream) -> None: class _StreamPlaceholder: - def __init__(self): self.synchronize = lambda: None @@ -1448,8 +1465,8 @@ def current_stream() -> torch.cuda.Stream: from C/C++ code. """ from vllm.platforms import current_platform - if not hasattr(_current_stream_tls, - "value") or _current_stream_tls.value is None: + + if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: # when this function is called before any stream is set, # we return the default stream. 
# On ROCm using the default 0 stream in combination with RCCL @@ -1467,7 +1484,8 @@ def current_stream() -> torch.cuda.Stream: else: raise ValueError( "Fail to set current stream, current platform " - "may not support current_stream with torch API") + "may not support current_stream with torch API" + ) return _current_stream_tls.value @@ -1480,12 +1498,14 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: tmp_dir = tempfile.gettempdir() # add username to tmp_dir to avoid permission issues tmp_dir = os.path.join(tmp_dir, getpass.getuser()) - filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" - f"_thread_{threading.get_ident()}_" - f"at_{datetime.datetime.now()}.log").replace(" ", "_") - log_path = os.path.join(tmp_dir, "vllm", - f"vllm-instance-{vllm_config.instance_id}", - filename) + filename = ( + f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log" + ).replace(" ", "_") + log_path = os.path.join( + tmp_dir, "vllm", f"vllm-instance-{vllm_config.instance_id}", filename + ) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) @@ -1496,36 +1516,34 @@ def identity(value: T, **kwargs) -> T: return value -F = TypeVar('F', bound=Callable[..., Any]) +F = TypeVar("F", bound=Callable[..., Any]) def deprecate_args( start_index: int, is_deprecated: Union[bool, Callable[[], bool]] = True, - additional_message: Optional[str] = None, + additional_message: str | None = None, ) -> Callable[[F], F]: if not callable(is_deprecated): is_deprecated = partial(identity, is_deprecated) def wrapper(fn: F) -> F: - params = inspect.signature(fn).parameters pos_types = ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, ) - pos_kws = [ - kw for kw, param in params.items() if param.kind in pos_types - ] + pos_kws = [kw for kw, param in params.items() if param.kind in pos_types] @wraps(fn) def inner(*args, **kwargs): if is_deprecated(): - deprecated_args = pos_kws[start_index:len(args)] + deprecated_args = pos_kws[start_index : len(args)] if deprecated_args: msg = ( f"The positional arguments {deprecated_args} are " - "deprecated and will be removed in a future update.") + "deprecated and will be removed in a future update." + ) if additional_message is not None: msg += f" {additional_message}" @@ -1544,7 +1562,7 @@ def inner(*args, **kwargs): def deprecate_kwargs( *kws: str, is_deprecated: Union[bool, Callable[[], bool]] = True, - additional_message: Optional[str] = None, + additional_message: str | None = None, ) -> Callable[[F], F]: deprecated_kws = set(kws) @@ -1552,7 +1570,6 @@ def deprecate_kwargs( is_deprecated = partial(identity, is_deprecated) def wrapper(fn: F) -> F: - @wraps(fn) def inner(*args, **kwargs): if is_deprecated(): @@ -1560,7 +1577,8 @@ def inner(*args, **kwargs): if deprecated_kwargs: msg = ( f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update.") + "deprecated and will be removed in a future update." + ) if additional_message is not None: msg += f" {additional_message}" @@ -1577,8 +1595,7 @@ def inner(*args, **kwargs): @lru_cache(maxsize=8) -def _cuda_device_count_stateless( - cuda_visible_devices: Optional[str] = None) -> int: +def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: # Note: cuda_visible_devices is not used, but we keep it as an argument for # LRU Cache purposes. 
@@ -1590,13 +1607,17 @@ def _cuda_device_count_stateless( import torch.version from vllm.platforms import current_platform + if not torch.cuda._is_compiled(): return 0 if current_platform.is_rocm(): # ROCm uses amdsmi instead of nvml for stateless device count # This requires a sufficiently modern version of Torch 2.4.0 - raw_count = torch.cuda._device_count_amdsmi() if (hasattr( - torch.cuda, "_device_count_amdsmi")) else -1 + raw_count = ( + torch.cuda._device_count_amdsmi() + if (hasattr(torch.cuda, "_device_count_amdsmi")) + else -1 + ) else: raw_count = torch.cuda._device_count_nvml() r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count @@ -1630,9 +1651,9 @@ def xpu_is_initialized() -> bool: return torch.xpu.is_initialized() -def cuda_get_device_properties(device, - names: Sequence[str], - init_cuda=False) -> tuple[Any, ...]: +def cuda_get_device_properties( + device, names: Sequence[str], init_cuda=False +) -> tuple[Any, ...]: """Get specified CUDA device property values without initializing CUDA in the current process.""" if init_cuda or cuda_is_initialized(): @@ -1642,11 +1663,12 @@ def cuda_get_device_properties(device, # Run in subprocess to avoid initializing CUDA as a side effect. mp_ctx = multiprocessing.get_context("fork") with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor: - return executor.submit(cuda_get_device_properties, device, names, - True).result() + return executor.submit(cuda_get_device_properties, device, names, True).result() -def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: +def weak_bind( + bound_method: Callable[..., Any], +) -> Callable[..., None]: """Make an instance method that weakly references its associated instance and no-ops once that instance is collected.""" @@ -1661,7 +1683,6 @@ def weak_bound(*args, **kwargs) -> None: def run_once(f: Callable[P, None]) -> Callable[P, None]: - def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: if wrapper.has_run: # type: ignore[attr-defined] return @@ -1677,19 +1698,18 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: class StoreBoolean(Action): - def __call__(self, parser, namespace, values, option_string=None): if values.lower() == "true": setattr(namespace, self.dest, True) elif values.lower() == "false": setattr(namespace, self.dest, False) else: - raise ValueError(f"Invalid boolean value: {values}. " - "Expected 'true' or 'false'.") + raise ValueError( + f"Invalid boolean value: {values}. Expected 'true' or 'false'." 
+ ) -class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, - RawDescriptionHelpFormatter): +class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter): """SortedHelpFormatter that sorts arguments by their option strings.""" def _split_lines(self, text, width): @@ -1701,7 +1721,7 @@ def _split_lines(self, text, width): # The patterns also include whitespace after the newline single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*") multiple_newlines = re.compile(r"\n{2,}\s*") - text = single_newline.sub(' ', text) + text = single_newline.sub(" ", text) lines = re.split(multiple_newlines, text) return sum([textwrap.wrap(line, width) for line in lines], []) @@ -1721,8 +1741,9 @@ class FlexibleArgumentParser(ArgumentParser): " --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n" "Additionally, list elements can be passed individually using +:\n" ' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n' - " --json-arg.key4+ value3 --json-arg.key4+=\'value4,value5\'\n\n") - _search_keyword: Optional[str] = None + " --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n" + ) + _search_keyword: str | None = None def __init__(self, *args, **kwargs): # Set the default "formatter_class" to SortedHelpFormatter @@ -1742,11 +1763,14 @@ def parse_known_args(self, args=None, namespace=None): logger.warning_once( "argument '--disable-log-requests' is deprecated and " "replaced with '--enable-log-requests'. This will be " - "removed in v0.12.0.") + "removed in v0.12.0." + ) namespace, args = super().parse_known_args(args, namespace) for action in FlexibleArgumentParser._deprecated: - if (hasattr(namespace, dest := action.dest) - and getattr(namespace, dest) != action.default): + if ( + hasattr(namespace, dest := action.dest) + and getattr(namespace, dest) != action.default + ): logger.warning_once("argument '%s' is deprecated", dest) return namespace, args @@ -1758,7 +1782,6 @@ def add_argument(self, *args, **kwargs): return action class _FlexibleArgumentGroup(_ArgumentGroup): - def add_argument(self, *args, **kwargs): deprecated = kwargs.pop("deprecated", False) action = super().add_argument(*args, **kwargs) @@ -1783,7 +1806,7 @@ def format_help(self): # Normalise the search keyword search_keyword = search_keyword.lower().replace("_", "-") # Return full help if searching for 'all' - if search_keyword == 'all': + if search_keyword == "all": self.epilog = self._json_tip return super().format_help() @@ -1802,12 +1825,12 @@ def format_help(self): for group in self._action_groups: for action in group._group_actions: # search option name - if any(search_keyword in opt.lower() - for opt in action.option_strings): + if any( + search_keyword in opt.lower() for opt in action.option_strings + ): matched_actions.append(action) if matched_actions: - formatter.start_section( - f"Arguments matching '{search_keyword}'") + formatter.start_section(f"Arguments matching '{search_keyword}'") formatter.add_arguments(matched_actions) formatter.end_section() formatter.add_text(self._json_tip) @@ -1817,12 +1840,12 @@ def format_help(self): formatter.add_text( f"No group or arguments matching '{search_keyword}'.\n" "Use '--help' to see available groups or " - "'--help=all' to see all available parameters.") + "'--help=all' to see all available parameters." 
+ ) return formatter.format_help() # usage - formatter.add_usage(self.usage, self._actions, - self._mutually_exclusive_groups) + formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups) # description formatter.add_text(self.description) @@ -1857,13 +1880,16 @@ def parse_args( # type: ignore[override] if args and args[0] == "serve": try: model_idx = next( - i for i, arg in enumerate(args) - if arg == "--model" or arg.startswith("--model=")) + i + for i, arg in enumerate(args) + if arg == "--model" or arg.startswith("--model=") + ) logger.warning( "With `vllm serve`, you should provide the model as a " "positional argument or in a config file instead of via " "the `--model` option. " - "The `--model` option will be removed in v0.13.") + "The `--model` option will be removed in v0.13." + ) if args[model_idx] == "--model": model_tag = args[model_idx + 1] @@ -1887,7 +1913,7 @@ def parse_args( # type: ignore[override] except StopIteration: pass - if '--config' in args: + if "--config" in args: args = self._pull_args_from_config(args) def repl(match: re.Match) -> str: @@ -1901,28 +1927,29 @@ def repl(match: re.Match) -> str: processed_args = list[str]() for i, arg in enumerate(args): if arg.startswith("--help="): - FlexibleArgumentParser._search_keyword = arg.split( - '=', 1)[-1].lower() + FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower() processed_args.append("--help") - elif arg.startswith('--'): - if '=' in arg: - key, value = arg.split('=', 1) + elif arg.startswith("--"): + if "=" in arg: + key, value = arg.split("=", 1) key = pattern.sub(repl, key, count=1) - processed_args.append(f'{key}={value}') + processed_args.append(f"{key}={value}") else: key = pattern.sub(repl, arg, count=1) processed_args.append(key) - elif arg.startswith('-O') and arg != '-O' and arg[2] != '.': + elif arg.startswith("-O") and arg != "-O" and arg[2] != ".": # allow -O flag to be used without space, e.g. -O3 or -Odecode # -O.<...> handled later # also handle -O=<level> here - level = arg[3:] if arg[2] == '=' else arg[2:] - processed_args.append(f'-O.level={level}') - elif arg == '-O' and i + 1 < len(args) and args[i + 1] in { - "0", "1", "2", "3" - }: + level = arg[3:] if arg[2] == "=" else arg[2:] + processed_args.append(f"-O.level={level}") + elif ( + arg == "-O" + and i + 1 < len(args) + and args[i + 1] in {"0", "1", "2", "3"} + ): # Convert -O <n> to -O.level <n> - processed_args.append('-O.level') + processed_args.append("-O.level") else: processed_args.append(arg) @@ -1986,14 +2013,11 @@ def recursive_dict_update( # Merge all values with the same key into a single dict arg_dict = create_nested_dict(keys, value) - arg_duplicates = recursive_dict_update(dict_args[key], - arg_dict) - duplicates |= {f'{key}.{d}' for d in arg_duplicates} + arg_duplicates = recursive_dict_update(dict_args[key], arg_dict) + duplicates |= {f"{key}.{d}" for d in arg_duplicates} delete.add(i) # Filter out the dict args we set to None - processed_args = [ - a for i, a in enumerate(processed_args) if i not in delete - ] + processed_args = [a for i, a in enumerate(processed_args) if i not in delete] if duplicates: logger.warning("Found duplicate keys %s", ", ".join(duplicates)) @@ -2050,13 +2074,14 @@ def _pull_args_from_config(self, args: list[str]) -> list[str]: this way the order of priorities is maintained when these are args parsed by super(). """ - assert args.count( - '--config') <= 1, "More than one config file specified!" 
+ assert args.count("--config") <= 1, "More than one config file specified!" - index = args.index('--config') + index = args.index("--config") if index == len(args) - 1: - raise ValueError("No config file specified! \ - Please check your command-line arguments.") + raise ValueError( + "No config file specified! \ + Please check your command-line arguments." + ) file_path = args[index + 1] @@ -2068,29 +2093,33 @@ def _pull_args_from_config(self, args: list[str]) -> list[str]: # followed by rest of cli args. # maintaining this order will enforce the precedence # of cli > config > defaults - if args[0].startswith('-'): + if args[0].startswith("-"): # No sub command (e.g., api_server entry point) - args = config_args + args[0:index] + args[index + 2:] + args = config_args + args[0:index] + args[index + 2 :] elif args[0] == "serve": - model_in_cli = len(args) > 1 and not args[1].startswith('-') - model_in_config = any(arg == '--model' for arg in config_args) + model_in_cli = len(args) > 1 and not args[1].startswith("-") + model_in_config = any(arg == "--model" for arg in config_args) if not model_in_cli and not model_in_config: raise ValueError( "No model specified! Please specify model either " - "as a positional argument or in a config file.") + "as a positional argument or in a config file." + ) if model_in_cli: # Model specified as positional arg, keep CLI version - args = [args[0]] + [ - args[1] - ] + config_args + args[2:index] + args[index + 2:] + args = ( + [args[0]] + + [args[1]] + + config_args + + args[2:index] + + args[index + 2 :] + ) else: # No model in CLI, use config if available - args = [args[0] - ] + config_args + args[1:index] + args[index + 2:] + args = [args[0]] + config_args + args[1:index] + args[index + 2 :] else: - args = [args[0]] + config_args + args[1:index] + args[index + 2:] + args = [args[0]] + config_args + args[1:index] + args[index + 2 :] return args @@ -2107,11 +2136,13 @@ def load_config_file(self, file_path: str) -> list[str]: '--tensor-parallel-size': '4' ] """ - extension: str = file_path.split('.')[-1] - if extension not in ('yaml', 'yml'): + extension: str = file_path.split(".")[-1] + if extension not in ("yaml", "yml"): raise ValueError( "Config file must be of a yaml/yml type.\ - %s supplied", extension) + %s supplied", + extension, + ) # only expecting a flat dictionary of atomic types processed_args: list[str] = [] @@ -2123,32 +2154,32 @@ def load_config_file(self, file_path: str) -> list[str]: except Exception as ex: logger.error( "Unable to read the config file at %s. 
\ - Make sure path is correct", file_path) + Make sure path is correct", + file_path, + ) raise ex store_boolean_arguments = [ - action.dest for action in self._actions - if isinstance(action, StoreBoolean) + action.dest for action in self._actions if isinstance(action, StoreBoolean) ] for key, value in config.items(): if isinstance(value, bool) and key not in store_boolean_arguments: if value: - processed_args.append('--' + key) + processed_args.append("--" + key) elif isinstance(value, list): if value: - processed_args.append('--' + key) + processed_args.append("--" + key) for item in value: processed_args.append(str(item)) else: - processed_args.append('--' + key) + processed_args.append("--" + key) processed_args.append(str(value)) return processed_args -async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, - **kwargs): +async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) @@ -2172,19 +2203,26 @@ def supports_kw( param_val = params.get(kw_name) # Types where the it may be valid, i.e., explicitly defined & nonvariadic - passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY)) + passable_kw_types = set( + ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + ) if param_val: is_sig_param = param_val.kind in passable_kw_types # We want kwargs only, but this is passable as a positional arg - if (requires_kw_only and is_sig_param - and param_val.kind != inspect.Parameter.KEYWORD_ONLY): + if ( + requires_kw_only + and is_sig_param + and param_val.kind != inspect.Parameter.KEYWORD_ONLY + ): return False - if ((requires_kw_only - and param_val.kind == inspect.Parameter.KEYWORD_ONLY) - or (not requires_kw_only and is_sig_param)): + if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or ( + not requires_kw_only and is_sig_param + ): return True # If we're okay with var-kwargs, it's supported as long as @@ -2194,15 +2232,17 @@ def supports_kw( # mapping, but it wraps an ordered dict, and they appear in order. 
# Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters last_param = params[next(reversed(params))] # type: ignore - return (last_param.kind == inspect.Parameter.VAR_KEYWORD - and last_param.name != kw_name) + return ( + last_param.kind == inspect.Parameter.VAR_KEYWORD + and last_param.name != kw_name + ) return False def get_allowed_kwarg_only_overrides( callable: Callable[..., object], - overrides: Optional[Mapping[str, object]], + overrides: Mapping[str, object] | None, *, requires_kw_only: bool = True, allow_var_kwargs: bool = False, @@ -2234,10 +2274,12 @@ def get_allowed_kwarg_only_overrides( filtered_overrides = { kwarg_name: val for kwarg_name, val in overrides.items() - if supports_kw(callable, - kwarg_name, - requires_kw_only=requires_kw_only, - allow_var_kwargs=allow_var_kwargs) + if supports_kw( + callable, + kwarg_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) } # If anything is dropped, log a warning @@ -2246,11 +2288,15 @@ def get_allowed_kwarg_only_overrides( if requires_kw_only: logger.warning( "The following intended overrides are not keyword-only args " - "and will be dropped: %s", dropped_keys) + "and will be dropped: %s", + dropped_keys, + ) else: logger.warning( "The following intended overrides are not keyword args " - "and will be dropped: %s", dropped_keys) + "and will be dropped: %s", + dropped_keys, + ) return filtered_overrides @@ -2265,8 +2311,9 @@ def supports_dynamo() -> bool: # Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform def supports_xccl() -> bool: - return is_torch_equal_or_newer( - "2.8.0.dev") and torch.distributed.is_xccl_available() + return ( + is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() + ) # Some backends use pytorch version < 2.4.0 which doesn't @@ -2302,7 +2349,6 @@ def value(self): # Adapted from: https://stackoverflow.com/a/47212782/5082708 class LazyDict(Mapping[str, T], Generic[T]): - def __init__(self, factory: dict[str, Callable[[], T]]): self._factory = factory self._dict: dict[str, T] = {} @@ -2325,7 +2371,6 @@ def __len__(self): class ClassRegistry(UserDict[type[T], _V]): - def __getitem__(self, key: type[T]) -> _V: for cls in key.mro(): if cls in self.data: @@ -2359,8 +2404,9 @@ def weak_ref_tensor(tensor: Any) -> Any: def weak_ref_tensors( - tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor], - IntermediateTensors] + tensors: Union[ + torch.Tensor, list[torch.Tensor], tuple[torch.Tensor], IntermediateTensors + ], ) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: """ Convenience function to create weak references to tensors, @@ -2375,11 +2421,11 @@ def weak_ref_tensors( # For IntermediateTensors used in pipeline parallelism from vllm.sequence import IntermediateTensors + if isinstance(tensors, IntermediateTensors): - ret = IntermediateTensors({ - key: weak_ref_tensor(val) - for key, val in tensors.tensors.items() - }) + ret = IntermediateTensors( + {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} + ) return ret raise ValueError("Invalid type for tensors") @@ -2419,7 +2465,8 @@ def get_vllm_optional_dependencies(): return { extra: [ - re.split(r";|>=|<=|==", req)[0] for req in requirements + re.split(r";|>=|<=|==", req)[0] + for req in requirements if req.endswith(f'extra == "{extra}"') ] for extra in extras @@ -2612,12 +2659,13 @@ def __getattr__(self, key: str): raise exc - raise AssertionError("PlaceholderModule should not be used " - "when the original module can be 
imported") + raise AssertionError( + "PlaceholderModule should not be used " + "when the original module can be imported" + ) class _PlaceholderModuleAttr(_PlaceholderBase): - def __init__(self, module: PlaceholderModule, attr_path: str) -> None: super().__init__() @@ -2626,14 +2674,15 @@ def __init__(self, module: PlaceholderModule, attr_path: str) -> None: self.__attr_path = attr_path def placeholder_attr(self, attr_path: str): - return _PlaceholderModuleAttr(self.__module, - f"{self.__attr_path}.{attr_path}") + return _PlaceholderModuleAttr(self.__module, f"{self.__attr_path}.{attr_path}") def __getattr__(self, key: str): getattr(self.__module, f"{self.__attr_path}.{key}") - raise AssertionError("PlaceholderModule should not be used " - "when the original module can be imported") + raise AssertionError( + "PlaceholderModule should not be used " + "when the original module can be imported" + ) # create a library to hold the custom op @@ -2641,13 +2690,13 @@ def __getattr__(self, key: str): def direct_register_custom_op( - op_name: str, - op_func: Callable, - mutates_args: Optional[list[str]] = None, - fake_impl: Optional[Callable] = None, - target_lib: Optional[Library] = None, - dispatch_key: Optional[str] = None, - tags: tuple[torch.Tag, ...] = (), + op_name: str, + op_func: Callable, + mutates_args: list[str] | None = None, + fake_impl: Callable | None = None, + target_lib: Library | None = None, + dispatch_key: str | None = None, + tags: tuple[torch.Tag, ...] = (), ): """ `torch.library.custom_op` can have significant overhead because it @@ -2666,12 +2715,14 @@ def direct_register_custom_op( """ if not supports_custom_op(): from vllm.platforms import current_platform + assert not current_platform.is_cuda_alike(), ( "cuda platform needs torch>=2.4 to support custom op, " "chances are you are using an old version of pytorch " "or a custom build of pytorch. It is recommended to " "use vLLM in a fresh new environment and let it install " - "the required dependencies.") + "the required dependencies." + ) return if mutates_args is None: @@ -2679,15 +2730,17 @@ def direct_register_custom_op( if dispatch_key is None: from vllm.platforms import current_platform + dispatch_key = current_platform.dispatch_key import torch.library + if hasattr(torch.library, "infer_schema"): - schema_str = torch.library.infer_schema(op_func, - mutates_args=mutates_args) + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) else: # for pytorch 2.4 import torch._custom_op.impl + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str, tags=tags) @@ -2733,6 +2786,7 @@ def kill_process_tree(pid: int): @dataclass class MemorySnapshot: """Memory snapshot.""" + torch_peak: int = 0 free_memory: int = 0 total_memory: int = 0 @@ -2754,15 +2808,14 @@ def measure(self): # After `torch.cuda.reset_peak_memory_stats()`, # `torch.cuda.memory_reserved()` will keep growing, and only shrink # when we call `torch.cuda.empty_cache()` or OOM happens. 
- self.torch_peak = torch.cuda.memory_stats().get( - "allocated_bytes.all.peak", 0) + self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) self.free_memory, self.total_memory = torch.cuda.mem_get_info() - shared_sysmem_device_mem_sms = ( - (8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark - if current_platform.is_cuda() and \ - current_platform.get_device_capability() in \ - shared_sysmem_device_mem_sms: + shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark + if ( + current_platform.is_cuda() + and current_platform.get_device_capability() in shared_sysmem_device_mem_sms + ): # On UMA (Orin, Thor and Spark) platform, # where both CPU and GPU rely on system memory, # the cudaMemGetInfo function shows the amount of free system memory @@ -2801,8 +2854,8 @@ def __sub__(self, other: MemorySnapshot) -> MemorySnapshot: @dataclass class MemoryProfilingResult: - """Memory profiling result. All numbers are in bytes. - """ + """Memory profiling result. All numbers are in bytes.""" + non_kv_cache_memory: int = 0 torch_peak_increase: int = 0 non_torch_increase: int = 0 @@ -2813,20 +2866,22 @@ class MemoryProfilingResult: profile_time: float = 0.0 def __repr__(self) -> str: - return (f"Memory profiling takes {self.profile_time:.2f} seconds. " - f"Total non KV cache memory: " - f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " - f"torch peak memory increase: " - f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " - f"non-torch forward increase memory: " - f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " - f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB.") + return ( + f"Memory profiling takes {self.profile_time:.2f} seconds. " + f"Total non KV cache memory: " + f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " + f"torch peak memory increase: " + f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " + f"non-torch forward increase memory: " + f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " + f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB." + ) @contextlib.contextmanager def memory_profiling( - baseline_snapshot: MemorySnapshot, - weights_memory: int) -> Generator[MemoryProfilingResult, None, None]: + baseline_snapshot: MemorySnapshot, weights_memory: int +) -> Generator[MemoryProfilingResult, None, None]: """Memory profiling context manager. baseline_snapshot: the memory snapshot before the current vLLM instance. weights_memory: memory used by PyTorch when loading the model weights. 
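Reviewer note, not part of the patch: a minimal usage sketch for the memory-profiling helpers reformatted above, assuming a CUDA device is available. The import path and the workload inside the `with` block are illustrative assumptions, not part of this change.

import torch

# Assumed import location; adjust to wherever these utilities live in vLLM.
from vllm.utils import MemorySnapshot, memory_profiling

baseline = MemorySnapshot()
baseline.measure()  # record free/total/torch-peak memory before profiling

weights_bytes = 2 * 10**9  # illustrative: bytes already used by model weights

with memory_profiling(baseline, weights_memory=weights_bytes) as result:
    # Hypothetical workload standing in for the profiled dummy run; CUDA
    # allocations made here are attributed to this profiling window.
    x = torch.zeros(1024, 1024, device="cuda")
    del x

print(result)  # repr reports profile time plus torch/non-torch/weights memory in GiB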
@@ -2900,29 +2955,34 @@ def memory_profiling( non_torch_memory = result.non_torch_increase peak_activation_memory = result.torch_peak_increase - result.non_kv_cache_memory = non_torch_memory + peak_activation_memory + result.weights_memory # noqa + result.non_kv_cache_memory = ( + non_torch_memory + peak_activation_memory + result.weights_memory + ) # noqa # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 def set_ulimit(target_soft_limit=65535): - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): logger.info("Windows detected, skipping ulimit adjustment.") return import resource + resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) if current_soft < target_soft_limit: try: - resource.setrlimit(resource_type, - (target_soft_limit, current_hard)) + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) except ValueError as e: logger.warning( "Found ulimit of %s and failed to automatically increase " "with error %s. This can cause fd limit errors like " "`OSError: [Errno 24] Too many open files`. Consider " - "increasing with ulimit -n", current_soft, e) + "increasing with ulimit -n", + current_soft, + e, + ) # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501 @@ -2953,7 +3013,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]: return scheme, host, port -def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str: +def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str: """Make a ZMQ path from its parts. Args: @@ -2976,9 +3036,9 @@ def make_zmq_socket( ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] path: str, socket_type: Any, - bind: Optional[bool] = None, - identity: Optional[bytes] = None, - linger: Optional[int] = None, + bind: bool | None = None, + identity: bytes | None = None, + linger: int | None = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -2992,10 +3052,7 @@ def make_zmq_socket( # - Set a large 0.5GB buffer to improve throughput # For systems with less memory: # - Use system default (-1) to avoid excessive memory consumption - if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024**3) # 0.5GB in bytes - else: - buf_size = -1 # Use system default buffer size + buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1 if bind is None: bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) @@ -3035,19 +3092,15 @@ def make_zmq_socket( def zmq_socket_ctx( path: str, socket_type: Any, - bind: Optional[bool] = None, + bind: bool | None = None, linger: int = 0, - identity: Optional[bytes] = None, + identity: bytes | None = None, ) -> Iterator[zmq.Socket]: """Context manager for a ZMQ socket""" ctx = zmq.Context() # type: ignore[attr-defined] try: - yield make_zmq_socket(ctx, - path, - socket_type, - bind=bind, - identity=identity) + yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity) except KeyboardInterrupt: logger.debug("Got Keyboard Interrupt.") @@ -3068,6 +3121,7 @@ def _maybe_force_spawn(): # to the subprocess so that it knows how to connect to the ray cluster. # env vars are inherited by subprocesses, even if we use spawn. 
import ray + os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address reasons.append("In a Ray actor and can only be spawned") @@ -3082,7 +3136,9 @@ def _maybe_force_spawn(): "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " "See https://docs.vllm.ai/en/latest/usage/" "troubleshooting.html#python-multiprocessing " - "for more information. Reasons: %s", "; ".join(reasons)) + "for more information. Reasons: %s", + "; ".join(reasons), + ) os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -3101,7 +3157,7 @@ def get_mp_context(): def bind_kv_cache( ctx: dict[str, Any], kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] - shared_kv_cache_layers: Optional[dict[str, str]] = None + shared_kv_cache_layers: dict[str, str] | None = None, ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -3119,33 +3175,40 @@ def bind_kv_cache( shared_kv_cache_layers = {} from vllm.attention import AttentionType from vllm.model_executor.models.utils import extract_layer_index + layer_need_kv_cache = [ - layer_name for layer_name in ctx - if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type - in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \ - and ctx[layer_name].kv_sharing_target_layer_name is None + layer_name + for layer_name in ctx + if ( + hasattr(ctx[layer_name], "attn_type") + and ctx[layer_name].attn_type + in (AttentionType.DECODER, AttentionType.ENCODER_DECODER) + ) + and ctx[layer_name].kv_sharing_target_layer_name is None ] layer_index_sorted = sorted( - set( - extract_layer_index(layer_name) - for layer_name in layer_need_kv_cache)) + set(extract_layer_index(layer_name) for layer_name in layer_need_kv_cache) + ) for layer_name in layer_need_kv_cache: - kv_cache_idx = layer_index_sorted.index( - extract_layer_index(layer_name)) + kv_cache_idx = layer_index_sorted.index(extract_layer_index(layer_name)) forward_ctx = ctx[layer_name] assert len(forward_ctx.kv_cache) == len(kv_cache) for ve, ve_kv_cache in enumerate(kv_cache): forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] if shared_kv_cache_layers is not None: for layer_name, target_layer_name in shared_kv_cache_layers.items(): - assert extract_layer_index(target_layer_name) < \ - extract_layer_index(layer_name), \ - "v0 doesn't support interleaving kv sharing" + assert extract_layer_index(target_layer_name) < extract_layer_index( + layer_name + ), "v0 doesn't support interleaving kv sharing" ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache -def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], - kwargs: dict[str, Any]) -> Any: +def run_method( + obj: Any, + method: Union[str, bytes, Callable], + args: tuple[Any], + kwargs: dict[str, Any], +) -> Any: """ Run a method of an object with the given arguments and keyword arguments. If the method is string, it will be converted to a method using getattr. @@ -3159,8 +3222,9 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], try: func = getattr(obj, method) except AttributeError: - raise NotImplementedError(f"Method {method!r} is not" - " implemented.") from None + raise NotImplementedError( + f"Method {method!r} is not implemented." + ) from None else: func = partial(method, obj) # type: ignore return func(*args, **kwargs) @@ -3194,6 +3258,7 @@ def import_pynvml(): module to our codebase, and use it directly. 
""" import vllm.third_party.pynvml as pynvml + return pynvml @@ -3213,7 +3278,7 @@ def find_unimplemented_methods(self: object): unimplemented_methods = [] for attr_name in dir(self): # bypass inner method - if attr_name.startswith('_'): + if attr_name.startswith("_"): continue try: @@ -3227,8 +3292,8 @@ def find_unimplemented_methods(self: object): if "NotImplementedError" in src: unimplemented_methods.append(attr_name) if unimplemented_methods: - method_names = ','.join(unimplemented_methods) - msg = (f"Methods {method_names} not implemented in {self}") + method_names = ",".join(unimplemented_methods) + msg = f"Methods {method_names} not implemented in {self}" logger.debug(msg) @wraps(original_init) @@ -3236,7 +3301,7 @@ def wrapped_init(self, *args, **kwargs) -> None: original_init(self, *args, **kwargs) find_unimplemented_methods(self) - type.__setattr__(cls, '__init__', wrapped_init) + type.__setattr__(cls, "__init__", wrapped_init) return cls @@ -3308,7 +3373,7 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: @contextlib.contextmanager -def cprofile_context(save_file: Optional[str] = None): +def cprofile_context(save_file: str | None = None): """Run a cprofile Args: @@ -3330,7 +3395,7 @@ def cprofile_context(save_file: Optional[str] = None): prof.print_stats(sort="cumtime") -def cprofile(save_file: Optional[str] = None, enabled: bool = True): +def cprofile(save_file: str | None = None, enabled: bool = True): """Decorator to profile a Python method using cProfile. Args: @@ -3340,7 +3405,6 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True): """ def decorator(func: Callable): - @wraps(func) def wrapper(*args, **kwargs): if not enabled: @@ -3358,16 +3422,26 @@ def wrapper(*args, **kwargs): # Only relevant for models using ALiBi (e.g, MPT) def check_use_alibi(model_config: ModelConfig) -> bool: cfg = model_config.hf_text_config - return (getattr(cfg, "alibi", False) # Falcon - or ("BloomForCausalLM" in getattr(model_config.hf_config, - "architectures", [])) # Bloom - or getattr(cfg, "position_encoding_type", "") == - "alibi" # codellm_1b_alibi - or (hasattr(cfg, "attn_config") # MPT - and ((isinstance(cfg.attn_config, dict) - and cfg.attn_config.get("alibi", False)) or - (not isinstance(cfg.attn_config, dict) - and getattr(cfg.attn_config, "alibi", False))))) + return ( + getattr(cfg, "alibi", False) # Falcon + or ( + "BloomForCausalLM" in getattr(model_config.hf_config, "architectures", []) + ) # Bloom + or getattr(cfg, "position_encoding_type", "") == "alibi" # codellm_1b_alibi + or ( + hasattr(cfg, "attn_config") # MPT + and ( + ( + isinstance(cfg.attn_config, dict) + and cfg.attn_config.get("alibi", False) + ) + or ( + not isinstance(cfg.attn_config, dict) + and getattr(cfg.attn_config, "alibi", False) + ) + ) + ) + ) def sha256(input: Any) -> bytes: @@ -3435,7 +3509,7 @@ def is_torch_equal_or_newer(target: str) -> bool: return _is_torch_equal_or_newer(str(torch.__version__), target) except Exception: # Fallback to PKG-INFO to load the package info, needed by the doc gen. - return Version(importlib.metadata.version('torch')) >= Version(target) + return Version(importlib.metadata.version("torch")) >= Version(target) # Helper function used in testing. 
@@ -3484,9 +3558,9 @@ def has_tilelang() -> bool: return _has_module("tilelang") -def set_process_title(name: str, - suffix: str = "", - prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None: +def set_process_title( + name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX +) -> None: """ Set the current process title to a specific name with an optional suffix. @@ -3513,7 +3587,7 @@ def write_with_prefix(s: str): if file.start_new_line: # type: ignore[attr-defined] file_write(prefix) idx = 0 - while (next_idx := s.find('\n', idx)) != -1: + while (next_idx := s.find("\n", idx)) != -1: next_idx += 1 file_write(s[idx:next_idx]) if next_idx == len(s): @@ -3528,7 +3602,7 @@ def write_with_prefix(s: str): file.write = write_with_prefix # type: ignore[method-assign] -def decorate_logs(process_name: Optional[str] = None) -> None: +def decorate_logs(process_name: str | None = None) -> None: """ Adds a process-specific prefix to each line of output written to stdout and stderr. @@ -3551,29 +3625,26 @@ def decorate_logs(process_name: Optional[str] = None) -> None: def length_from_prompt_token_ids_or_embeds( - prompt_token_ids: Optional[list[int]], - prompt_embeds: Optional[torch.Tensor], + prompt_token_ids: list[int] | None, + prompt_embeds: torch.Tensor | None, ) -> int: """Calculate the request length (in number of tokens) give either prompt_token_ids or prompt_embeds. """ - prompt_token_len = None if prompt_token_ids is None else len( - prompt_token_ids) - prompt_embeds_len = \ - None if prompt_embeds is None else len(prompt_embeds) + prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids) + prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds) if prompt_token_len is None: if prompt_embeds_len is None: - raise ValueError( - "Neither prompt_token_ids nor prompt_embeds were defined.") + raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.") return prompt_embeds_len else: - if (prompt_embeds_len is not None - and prompt_embeds_len != prompt_token_len): + if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len: raise ValueError( "Prompt token ids and prompt embeds had different lengths" f" prompt_token_ids={prompt_token_len}" - f" prompt_embeds={prompt_embeds_len}") + f" prompt_embeds={prompt_embeds_len}" + ) return prompt_token_len diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 125508bc4a9f..1d7f05cf67bb 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -4,12 +4,13 @@ Users of vLLM should always import **only** these wrappers. 
""" + from __future__ import annotations import functools import importlib import os -from typing import Any, Callable, NoReturn, Optional +from typing import Any, Callable, NoReturn import torch @@ -26,9 +27,14 @@ def is_deep_gemm_supported() -> bool: """ is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) - or current_platform.is_device_capability(100)) - return (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch - and not envs.VLLM_USE_FLASHINFER_MOE_FP8) + or current_platform.is_device_capability(100) + ) + return ( + envs.VLLM_USE_DEEP_GEMM + and has_deep_gemm() + and is_supported_arch + and not envs.VLLM_USE_FLASHINFER_MOE_FP8 + ) @functools.cache @@ -38,7 +44,8 @@ def is_deep_gemm_e8m0_used() -> bool: """ if not is_deep_gemm_supported(): logger.debug_once( - "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.") + "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system." + ) return False _lazy_init() @@ -51,13 +58,14 @@ def is_deep_gemm_e8m0_used() -> bool: logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.") return False - if current_platform.is_device_capability(100) and \ - envs.VLLM_USE_DEEP_GEMM_E8M0: + if current_platform.is_device_capability(100) and envs.VLLM_USE_DEEP_GEMM_E8M0: logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.") return True - if current_platform.is_device_capability(90) and \ - envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + if ( + current_platform.is_device_capability(90) + and envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER + ): logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.") return True @@ -69,7 +77,8 @@ def _missing(*_: Any, **__: Any) -> NoReturn: """Placeholder for unavailable DeepGEMM backend.""" raise RuntimeError( "DeepGEMM backend is not available or outdated. Please install or " - "update the `deep_gemm` to a newer version to enable FP8 kernels.") + "update the `deep_gemm` to a newer version to enable FP8 kernels." 
+ ) _fp8_gemm_nt_impl: Callable[..., Any] | None = None @@ -89,21 +98,25 @@ def _lazy_init() -> None: global _get_mn_major_tma_aligned_tensor_impl # fast path - if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None - or _grouped_masked_impl is not None - or _fp8_mqa_logits_impl is not None - or _fp8_paged_mqa_logits_impl is not None - or _get_paged_mqa_logits_metadata_impl is not None): + if ( + _fp8_gemm_nt_impl is not None + or _grouped_impl is not None + or _grouped_masked_impl is not None + or _fp8_mqa_logits_impl is not None + or _fp8_paged_mqa_logits_impl is not None + or _get_paged_mqa_logits_metadata_impl is not None + ): return if not has_deep_gemm(): return # Set up deep_gemm cache path - DEEP_GEMM_JIT_CACHE_ENV_NAME = 'DG_JIT_CACHE_DIR' + DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR" if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None): os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join( - envs.VLLM_CACHE_ROOT, "deep_gemm") + envs.VLLM_CACHE_ROOT, "deep_gemm" + ) _dg = importlib.import_module("deep_gemm") @@ -113,9 +126,11 @@ def _lazy_init() -> None: _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) _get_paged_mqa_logits_metadata_impl = getattr( - _dg, "get_paged_mqa_logits_metadata", None) + _dg, "get_paged_mqa_logits_metadata", None + ) _get_mn_major_tma_aligned_tensor_impl = getattr( - _dg, "get_mn_major_tma_aligned_tensor", None) + _dg, "get_mn_major_tma_aligned_tensor", None + ) def get_num_sms() -> int: @@ -148,9 +163,9 @@ def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): _lazy_init() if _grouped_impl is None: return _missing(*args, **kwargs) - return _grouped_impl(*args, - disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), - **kwargs) + return _grouped_impl( + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): @@ -158,7 +173,8 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): if _grouped_masked_impl is None: return _missing(*args, **kwargs) return _grouped_masked_impl( - *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) def fp8_mqa_logits( @@ -191,8 +207,9 @@ def fp8_mqa_logits( return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) -def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int, - num_sms: int) -> torch.Tensor: +def get_paged_mqa_logits_metadata( + context_lens: torch.Tensor, block_size: int, num_sms: int +) -> torch.Tensor: """Build scheduling metadata for paged MQA logits. 
Args: @@ -208,8 +225,7 @@ def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int, _lazy_init() if _get_paged_mqa_logits_metadata_impl is None: return _missing() - return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, - num_sms) + return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) def fp8_paged_mqa_logits( @@ -245,14 +261,16 @@ def fp8_paged_mqa_logits( _lazy_init() if _fp8_paged_mqa_logits_impl is None: return _missing() - return _fp8_paged_mqa_logits_impl(q_fp8, - kv_cache_fp8, - weights, - context_lens, - block_tables, - schedule_metadata, - max_model_len, - clean_logits=True) + return _fp8_paged_mqa_logits_impl( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True, + ) def _ceil_to_ue8m0(x: torch.Tensor): @@ -269,15 +287,14 @@ def _align(x: int, y: int) -> int: # Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def per_block_cast_to_fp8( - x: torch.Tensor, - block_size: list[int] = DEFAULT_BLOCK_SIZE, - use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]: + x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape block_m, block_n = block_size - x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)), - dtype=x.dtype, - device=x.device) + x_padded = torch.zeros( + (_align(m, block_m), _align(n, block_n)), dtype=x.dtype, device=x.device + ) x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) @@ -285,7 +302,8 @@ def per_block_cast_to_fp8( sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( - x_view.size(0), x_view.size(2)) + x_view.size(0), x_view.size(2) + ) def calc_diff(x: torch.Tensor, y: torch.Tensor): @@ -305,13 +323,18 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): def should_use_deepgemm_for_fp8_linear( - output_dtype: torch.dtype, - weight: torch.Tensor, - supports_deep_gemm: Optional[bool] = None): + output_dtype: torch.dtype, + weight: torch.Tensor, + supports_deep_gemm: bool | None = None, +): if supports_deep_gemm is None: supports_deep_gemm = is_deep_gemm_supported() - return (supports_deep_gemm and output_dtype == torch.bfloat16 - and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + return ( + supports_deep_gemm + and output_dtype == torch.bfloat16 + and weight.shape[0] % 128 == 0 + and weight.shape[1] % 128 == 0 + ) __all__ = [ diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 734cd938792a..1d707d56daba 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -4,6 +4,7 @@ Users of vLLM should always import **only** these wrappers. """ + from __future__ import annotations import contextlib @@ -11,7 +12,7 @@ import importlib import importlib.util import os -from typing import Any, Callable, NoReturn, Optional +from typing import Any, Callable, NoReturn import requests import torch @@ -44,7 +45,8 @@ def _missing(*_: Any, **__: Any) -> NoReturn: raise RuntimeError( "FlashInfer backend is not available. 
Please install the package " "to enable FlashInfer kernels: " - "https://github.com/flashinfer-ai/flashinfer") + "https://github.com/flashinfer-ai/flashinfer" + ) def _get_submodule(module_name: str) -> Any | None: @@ -56,9 +58,9 @@ def _get_submodule(module_name: str) -> Any | None: # General lazy import wrapper -def _lazy_import_wrapper(module_name: str, - attr_name: str, - fallback_fn: Callable[..., Any] = _missing): +def _lazy_import_wrapper( + module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing +): """Create a lazy import wrapper for a specific function.""" @functools.cache @@ -79,29 +81,34 @@ def wrapper(*args, **kwargs): # Create lazy wrappers for each function flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper( - "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe") + "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe" +) flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper( - "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe") -flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", - "cutlass_fused_moe") + "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe" +) +flashinfer_cutlass_fused_moe = _lazy_import_wrapper( + "flashinfer.fused_moe", "cutlass_fused_moe" +) fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") nvfp4_block_scale_interleave = _lazy_import_wrapper( - "flashinfer", "nvfp4_block_scale_interleave") + "flashinfer", "nvfp4_block_scale_interleave" +) trtllm_fp4_block_scale_moe = _lazy_import_wrapper( - "flashinfer", "trtllm_fp4_block_scale_moe") + "flashinfer", "trtllm_fp4_block_scale_moe" +) # Special case for autotune since it returns a context manager autotune = _lazy_import_wrapper( "flashinfer.autotuner", "autotune", - fallback_fn=lambda *args, **kwargs: contextlib.nullcontext()) + fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(), +) @functools.cache def has_flashinfer_comm() -> bool: """Return ``True`` if FlashInfer comm module is available.""" - return has_flashinfer() and importlib.util.find_spec( - "flashinfer.comm") is not None + return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None @functools.cache @@ -128,8 +135,10 @@ def has_flashinfer_all2all() -> bool: @functools.cache def has_flashinfer_moe() -> bool: """Return ``True`` if FlashInfer MoE module is available.""" - return has_flashinfer() and importlib.util.find_spec( - "flashinfer.fused_moe") is not None + return ( + has_flashinfer() + and importlib.util.find_spec("flashinfer.fused_moe") is not None + ) @functools.cache @@ -174,7 +183,8 @@ def has_nvidia_artifactory() -> bool: else: logger.warning_once( "NVIDIA artifactory returned failed status code: %d", - response.status_code) + response.status_code, + ) return accessible except Exception as e: logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e) @@ -188,19 +198,18 @@ def supports_trtllm_attention() -> bool: NVIDIA artifactory is accessible """ # Requires SM100 and NVIDIA artifactory to be accessible to download cubins - return current_platform.is_device_capability( - 100) and has_nvidia_artifactory() + return current_platform.is_device_capability(100) and has_nvidia_artifactory() @functools.cache -def _force_use_trtllm_attention(env_value: Optional[bool]) -> Optional[bool]: +def _force_use_trtllm_attention(env_value: bool | None) -> bool | None: """Cache the env value for VLLM_USE_TRTLLM_ATTENTION""" if env_value is not None: logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value) 
return env_value -def force_use_trtllm_attention() -> Optional[bool]: +def force_use_trtllm_attention() -> bool | None: """ Return ``None`` if VLLM_USE_TRTLLM_ATTENTION is not set, return ``True`` if TRTLLM attention is forced to be used, @@ -238,7 +247,8 @@ def use_trtllm_attention( if force_use_trtllm: logger.warning_once( "TRTLLM attention is not supported on this platform, " - "but VLLM_USE_TRTLLM_ATTENTION is set to 1") + "but VLLM_USE_TRTLLM_ATTENTION is set to 1" + ) return False # The combination of query and key heads is not supported @@ -252,8 +262,7 @@ def use_trtllm_attention( if has_spec and not is_prefill: # Speculative decoding requires TRTLLM attention for decodes - logger.info_once( - "Using TRTLLM attention (enabled for speculative decoding).") + logger.info_once("Using TRTLLM attention (enabled for speculative decoding).") return True # Must use TRTLLM attention if query is FP8 quantized @@ -261,28 +270,35 @@ def use_trtllm_attention( if has_sinks: raise RuntimeError( "TRTLLM FP8-qkv kernel is not supported for attention sinks. " - "Use kv_cache_dtype=auto for now.") + "Use kv_cache_dtype=auto for now." + ) logger.info_once("Using TRTLLM attention (query is quantized).") return True # If sinks are being used, we must use TRTLLM attention as it's # the only backend that supports them if has_sinks: - logger.info_once( - "Using TRTLLM attention (required for attention sinks).") + logger.info_once("Using TRTLLM attention (required for attention sinks).") return True if force_use_trtllm is None: # Environment variable not set - use auto-detection - use_trtllm = (num_tokens <= 256 and max_seq_len <= 131072 - and kv_cache_dtype == "auto") - if use_trtllm: - logger.warning_once("Using TRTLLM attention (auto-detected).") + if is_prefill: + # Prefill auto-detection + use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto" + if use_trtllm: + logger.warning_once("Using TRTLLM prefill attention (auto-detected).") + else: + # Decode auto-detection + use_trtllm = ( + num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto" + ) + if use_trtllm: + logger.warning_once("Using TRTLLM decode attention (auto-detected).") return use_trtllm # Environment variable is set to 1 - respect it - logger.info_once( - "Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)") + logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)") return True @@ -303,16 +319,14 @@ def flashinfer_mm_fp4( backend: str, ) -> torch.Tensor: from flashinfer import mm_fp4 as flashinfer_mm_fp4_ - return flashinfer_mm_fp4_(A, - B, - A_scale, - B_scale, - g_scale, - dtype, - block_size=16, - backend=backend) - - @torch.library.register_fake("vllm::flashinfer_mm_fp4", ) + + return flashinfer_mm_fp4_( + A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend + ) + + @torch.library.register_fake( + "vllm::flashinfer_mm_fp4", + ) def flashinfer_mm_fp4_fake( A: torch.Tensor, B: torch.Tensor, @@ -322,10 +336,7 @@ def flashinfer_mm_fp4_fake( dtype: torch.dtype, backend: str, ) -> torch.Tensor: - return torch.empty(A.shape[0], - B.shape[1], - dtype=dtype, - device=A.device) + return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device) @torch.library.custom_op( "vllm::bmm_fp8", @@ -341,9 +352,12 @@ def bmm_fp8( backend: str, ) -> torch.Tensor: from flashinfer import bmm_fp8 as bmm_fp8_ + return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend) - @torch.library.register_fake("vllm::bmm_fp8", ) + @torch.library.register_fake( + "vllm::bmm_fp8", + ) 
def bmm_fp8_fake( A: torch.Tensor, B: torch.Tensor, @@ -352,18 +366,20 @@ def bmm_fp8_fake( dtype: torch.dtype, backend: str, ) -> torch.Tensor: - return torch.empty(A.shape[0], - A.shape[1], - B.shape[2], - dtype=dtype, - device=A.device) - - -def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, - block_scale_a: torch.Tensor, - block_scale_b: torch.Tensor, alpha: torch.Tensor, - out_dtype: torch.dtype, - backend: str) -> torch.Tensor: + return torch.empty( + A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device + ) + + +def flashinfer_scaled_fp4_mm( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype, + backend: str, +) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 assert a.stride(-1) == 1 and b.stride(-1) == 1 @@ -387,12 +403,13 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, def flashinfer_scaled_fp8_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: torch.Tensor | None = None, +) -> torch.Tensor: assert a.ndim == 2 and b.ndim == 2 assert a.shape[1] == b.shape[0] assert scale_a.numel() == 1 and scale_b.numel() == 1 diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py index 8ce2c200e299..e3b5b61dd364 100644 --- a/vllm/utils/gc_utils.py +++ b/vllm/utils/gc_utils.py @@ -36,8 +36,7 @@ def __init__(self, gc_debug_conf: Optional[str] = None) -> None: self.top_objects = json_conf.get("top_objects", -1) except Exception: self.enabled = False - logger.error("Failed to parse VLLM_GC_DEBUG(%s)", - VLLM_GC_DEBUG) + logger.error("Failed to parse VLLM_GC_DEBUG(%s)", VLLM_GC_DEBUG) logger.info("GC Debug Config. %s", str(self)) def __repr__(self) -> str: @@ -70,7 +69,8 @@ def handle(self, phase: str, info: dict[str, int]) -> None: # and top collected objects self.start_time_ns = time.monotonic_ns() self.gc_top_collected_objects = _compute_top_gc_collected_objects( - gc.get_objects(generation), self.config.top_objects) + gc.get_objects(generation), self.config.top_objects + ) elif phase == "stop": # After GC finished, Record GC elapsed time and # optionally top collected objects @@ -81,8 +81,11 @@ def handle(self, phase: str, info: dict[str, int]) -> None: elpased_ms, str(info.get("collected", "?")), generation, - (f" Top collected objects: \n{self.gc_top_collected_objects}" - if self.gc_top_collected_objects else ""), + ( + f" Top collected objects: \n{self.gc_top_collected_objects}" + if self.gc_top_collected_objects + else "" + ), ) @@ -125,4 +128,5 @@ def _compute_top_gc_collected_objects(objects: list[Any], top: int) -> str: object_types = [_compute_detailed_type(o) for o in objects] return "\n".join( f"{count:>5}:{object_type}" - for object_type, count in Counter(object_types).most_common(top)) + for object_type, count in Counter(object_types).most_common(top) + ) diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py index 7eb58b5f5cf8..dcdc6ccb4c63 100644 --- a/vllm/utils/jsontree.py +++ b/vllm/utils/jsontree.py @@ -52,40 +52,35 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]: def json_map_leaves( func: Callable[["torch.Tensor"], "torch.Tensor"], value: "BatchedTensorInputs", -) -> "BatchedTensorInputs": - ... 
+) -> "BatchedTensorInputs": ... @overload def json_map_leaves( func: Callable[[_T], _U], value: Union[_T, dict[str, _T]], -) -> Union[_U, dict[str, _U]]: - ... +) -> Union[_U, dict[str, _U]]: ... @overload def json_map_leaves( func: Callable[[_T], _U], value: Union[_T, list[_T]], -) -> Union[_U, list[_U]]: - ... +) -> Union[_U, list[_U]]: ... @overload def json_map_leaves( func: Callable[[_T], _U], value: Union[_T, tuple[_T, ...]], -) -> Union[_U, tuple[_U, ...]]: - ... +) -> Union[_U, tuple[_U, ...]]: ... @overload def json_map_leaves( func: Callable[[_T], _U], value: JSONTree[_T], -) -> JSONTree[_U]: - ... +) -> JSONTree[_U]: ... def json_map_leaves( @@ -111,8 +106,7 @@ def json_reduce_leaves( func: Callable[[_T, _T], _T], value: Union[_T, dict[str, _T]], /, -) -> _T: - ... +) -> _T: ... @overload @@ -120,8 +114,7 @@ def json_reduce_leaves( func: Callable[[_T, _T], _T], value: Union[_T, list[_T]], /, -) -> _T: - ... +) -> _T: ... @overload @@ -129,8 +122,7 @@ def json_reduce_leaves( func: Callable[[_T, _T], _T], value: Union[_T, tuple[_T, ...]], /, -) -> _T: - ... +) -> _T: ... @overload @@ -138,8 +130,7 @@ def json_reduce_leaves( func: Callable[[_T, _T], _T], value: JSONTree[_T], /, -) -> _T: - ... +) -> _T: ... @overload @@ -148,15 +139,14 @@ def json_reduce_leaves( value: JSONTree[_T], initial: _U, /, -) -> _U: - ... +) -> _U: ... def json_reduce_leaves( - func: Callable[..., Union[_T, _U]], - value: _JSONTree[_T], - initial: _U = cast(_U, ...), # noqa: B008 - /, + func: Callable[..., Union[_T, _U]], + value: _JSONTree[_T], + initial: _U = cast(_U, ...), # noqa: B008 + /, ) -> Union[_T, _U]: """ Apply a function of two arguments cumulatively to each leaf in a diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index d75dbcd5401b..e17676ccf7ef 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import (Annotated, Any, Optional, Union, get_args, get_origin, - get_type_hints) +from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints import torch @@ -11,7 +10,6 @@ class TensorShape: - def __init__( self, *dims: Union[int, str], @@ -37,8 +35,7 @@ def __str__(self) -> str: for dim in self.dims: if isinstance(dim, str): if dim in self.dynamic_dims: - dim_strs.append( - f"{dim}*") # Mark dynamic dimensions with * + dim_strs.append(f"{dim}*") # Mark dynamic dimensions with * else: dim_strs.append(dim) else: @@ -47,7 +44,6 @@ def __str__(self) -> str: class TensorSchema: - def __init__( self, *, @@ -94,34 +90,66 @@ def _match_shape_with_dynamic( return False return True - def _validate_nested_tensors( + def _fmt_indexer(self, idxs: tuple[int, ...]) -> str: + if not idxs: + return "" + + return str(list(idxs)) + + def _validate_field( self, - value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], + value: object, field_name: str, expected_shape: tuple[Union[int, str], ...], dynamic_dims: set[str], + leading_idxs: tuple[int, ...] = (), ) -> tuple[int, ...]: - """Validate a list/tuple of tensors and return the actual shape.""" + """Validate a field and return the actual shape.""" + if isinstance(value, (int, float)): + return () # Scalar + if isinstance(value, torch.Tensor): + return value.shape + + if not isinstance(value, (list, tuple)): + raise TypeError( + f"{field_name}{self._fmt_indexer(leading_idxs)} is not " + f"one of the expected types: int, float, Tensor, list, tuple. 
" + f"Got: {type(value)}" + ) + + if len(value) == 0: + raise ValueError( + f"{field_name}{self._fmt_indexer(leading_idxs)} is an empty sequence" + ) + # Ensure all tensors in the list have the same # shape, besides dynamic dimensions - first = value[0] for i, v in enumerate(value): - if not isinstance(v, torch.Tensor): - raise ValueError(f"{field_name}[{i}] is not a " - f"torch.Tensor") - if not self._match_shape_with_dynamic( - v.shape, - first.shape, - expected_shape, - dynamic_dims, + shape = self._validate_field( + v, + field_name, + expected_shape[1:], + dynamic_dims, + leading_idxs=leading_idxs + (i,), + ) + + if i == 0: + first_shape = shape + elif not self._match_shape_with_dynamic( + shape, + first_shape, + expected_shape, + dynamic_dims, ): - raise ValueError(f"{field_name} contains inconsistent " - f"shapes: {first.shape} vs {v.shape} " - f"at index {i}") + raise ValueError( + f"{field_name}{self._fmt_indexer(leading_idxs)} " + f"contains inconsistent shapes: {first_shape} " + f"(index 0) vs {shape} (index {i})" + ) # Treat the list as a stacked tensor: # shape = (len(list), *tensor.shape) - return (len(value), ) + first.shape + return (len(value),) + first_shape def _validate_tensor_shape_expected( self, @@ -134,27 +162,38 @@ def _validate_tensor_shape_expected( """Validate that the actual tensor shape matches the expected shape.""" if len(actual_shape) != len(expected_shape): - raise ValueError(f"{field_name} has rank {len(actual_shape)} " - f"but expected {len(expected_shape)}") + raise ValueError( + f"{field_name} has rank {len(actual_shape)} " + f"but expected {len(expected_shape)}. " + f"Expected shape: {expected_shape}, " + f"but got {actual_shape}" + ) for i, dim in enumerate(expected_shape): if dim in dynamic_dims: continue elif isinstance(dim, int): if actual_shape[i] != dim: - raise ValueError(f"{field_name} dim[{i}] expected " - f"{dim}, got {actual_shape[i]}") + raise ValueError( + f"{field_name} dim[{i}] expected " + f"{dim}, got {actual_shape[i]}. " + f"Expected shape: {expected_shape}, " + f"but got {actual_shape}" + ) elif isinstance(dim, str): if dim in shape_env: if actual_shape[i] != shape_env[dim]: - raise ValueError(f"{field_name} dim[{i}] expected " - f"'{dim}'={shape_env[dim]}, got " - f"{actual_shape[i]}") + raise ValueError( + f"{field_name} dim[{i}] expected " + f"'{dim}'={shape_env[dim]}, got " + f"{actual_shape[i]}" + ) else: shape_env[dim] = actual_shape[i] else: - raise TypeError(f"{field_name} dim[{i}] has unsupported " - f"type: {type(dim)}") + raise TypeError( + f"{field_name} dim[{i}] has unsupported type: {type(dim)}" + ) def validate(self) -> None: type_hints = get_type_hints(self.__class__, include_extras=True) @@ -162,8 +201,7 @@ def validate(self) -> None: for field_name, field_type in type_hints.items(): # Check if field is missing - if (not hasattr(self, field_name) - or getattr(self, field_name) is None): + if not hasattr(self, field_name) or getattr(self, field_name) is None: # Check if field is marked as optional actual_type = field_type if get_origin(field_type) is Annotated: @@ -187,40 +225,20 @@ def validate(self) -> None: for arg in args: if isinstance(arg, TensorShape): expected_shape = arg.resolve(**self._resolve_bindings) - if isinstance(value, (list, tuple)): - # list/tuple of Tensors → shape = (len(value), ...) 
- if value and isinstance(value[0], torch.Tensor): - actual_shape = self._validate_nested_tensors( - value, field_name, expected_shape, - arg.dynamic_dims) - elif value: - # list/tuple of scalars → shape = (len(value),) - actual_shape = (len(value), ) - else: - raise ValueError( - f"{field_name} is an empty list") - - # Tensor → shape = tensor.shape - elif isinstance(value, torch.Tensor): - actual_shape = value.shape - - # Otherwise, it's an unsupported type - else: - type_names = [] - for arg in args: - if hasattr(arg, "__name__"): - type_names.append(str(arg.__name__)) - else: - type_names.append(str(arg)) - - expected_types = ", ".join(type_names) - raise ValueError( - f"{field_name} is not one of the expected " - f"types: {expected_types}") + actual_shape = self._validate_field( + value, + field_name, + expected_shape, + arg.dynamic_dims, + ) self._validate_tensor_shape_expected( - actual_shape, expected_shape, field_name, - shape_env, arg.dynamic_dims) + actual_shape, + expected_shape, + field_name, + shape_env, + arg.dynamic_dims, + ) def print_shapes(self) -> None: """Print TensorShape annotations for debugging.""" diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 6ca0c63f6b59..6e27e93c9115 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -7,20 +7,26 @@ import torch from torch.nn.functional import scaled_dot_product_attention -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionType, + is_quantized_kv_cache, +) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) -from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.gpu_input_batch import InputBatch try: import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True # AttributeError is to handle a bug in ipex # https://github.com/intel/intel-extension-for-pytorch/pull/813 @@ -42,15 +48,15 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: @classmethod def validate_head_size(cls, head_size: int) -> None: attn_impl = _get_paged_attn_impl() - is_valid, supported_head_sizes = attn_impl.validate_head_size( - head_size) + is_valid, supported_head_sizes = attn_impl.validate_head_size(head_size) if not is_valid: attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @staticmethod def get_name() -> str: @@ -77,7 +83,8 @@ def get_kv_cache_shape( cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return _get_paged_attn_impl().get_kv_cache_shape( - num_blocks, block_size, num_kv_heads, head_size) + num_blocks, block_size, num_kv_heads, head_size + ) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: @@ -87,6 +94,7 @@ def use_cascade_attention(*args, **kwargs) -> bool: @dataclass class TorchSDPAMetadata(AttentionMetadata): """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. num_prefills: int # Number of prefill tokens. @@ -102,16 +110,16 @@ class TorchSDPAMetadata(AttentionMetadata): """Metadata for PagedAttention.""" # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. - seq_lens_tensor: Optional[torch.Tensor] + decode_seq_lens_tensor: Optional[torch.Tensor] # Maximum sequence length in the batch. 0 if it is prefill-only batch. - max_decode_seq_len: int + decode_max_seq_len: int # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks # in the kv cache. Each block can contain up to block_size tokens. # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. - block_tables: Optional[torch.Tensor] + decode_block_tables: Optional[torch.Tensor] """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts @@ -121,9 +129,9 @@ class TorchSDPAMetadata(AttentionMetadata): # For chunked prefill only max_query_len: Optional[int] = None - max_kv_len: Optional[int] = None + prefill_max_seq_len: Optional[int] = None prefill_query_start_loc: Optional[torch.Tensor] = None - kv_start_loc: Optional[torch.Tensor] = None + prefill_seq_start_loc: Optional[torch.Tensor] = None prefill_block_tables: Optional[torch.Tensor] = None # For V1 logits index only @@ -157,23 +165,27 @@ def __post_init__(self): @property def is_all_encoder_attn_metadata_set(self): - ''' + """ All attention metadata required for encoder attention is set. - ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) + """ + return ( + (self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None) + ) @property def is_all_cross_attn_metadata_set(self): - ''' + """ All attention metadata required for enc/dec cross-attention is set. Superset of encoder attention required metadata. - ''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) + """ + return ( + self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None) + ) @property def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: @@ -191,7 +203,7 @@ def get_seq_lens( self, attn_type: str, ): - ''' + """ Extract appropriate sequence lengths from attention metadata according to attention type. 
@@ -204,10 +216,12 @@ def get_seq_lens( Returns: * Appropriate sequence lengths tensor for query * Appropriate sequence lengths tensor for key & value - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): seq_lens_q = self.seq_lens seq_lens_kv = self.seq_lens elif attn_type == AttentionType.ENCODER: @@ -224,7 +238,7 @@ def get_attn_bias( self, attn_type: str, ) -> Optional[list[torch.Tensor]]: - ''' + """ Extract appropriate attention bias from attention metadata according to attention type. @@ -236,10 +250,12 @@ def get_attn_bias( Returns: * Appropriate attention bias value given the attention type - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): return self.attn_bias elif attn_type == AttentionType.ENCODER: return self.encoder_attn_bias @@ -253,7 +269,7 @@ def set_attn_bias( attn_bias: list[torch.Tensor], attn_type: str, ) -> None: - ''' + """ Update appropriate attention bias field of attention metadata, according to attention type. @@ -263,10 +279,12 @@ def set_attn_bias( * attn_bias: The desired attention bias value * attn_type: encoder attention, decoder self-attention, encoder/decoder cross-attention - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): self.attn_bias = attn_bias elif attn_type == AttentionType.ENCODER: self.encoder_attn_bias = attn_bias @@ -279,7 +297,7 @@ def get_seq_len_block_table_args( self, attn_type: str, ) -> tuple: - ''' + """ The particular choice of sequence-length- and block-table-related attributes which should be extracted from attn_metadata is dependent on the type of attention operation. 
@@ -301,41 +319,48 @@ def get_seq_len_block_table_args( * Appropriate sequence-lengths tensor * Appropriate max sequence-length scalar * Appropriate block tables (or None) - ''' + """ - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): + if ( + attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY + ): # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run - return (self.seq_lens_tensor, self.max_decode_seq_len, - self.block_tables) + return ( + self.decode_seq_lens_tensor, + self.decode_max_seq_len, + self.decode_block_tables, + ) elif attn_type == AttentionType.ENCODER_DECODER: # Enc/dec cross-attention KVs match encoder sequence length; # cross-attention utilizes special "cross" block tables - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - self.cross_block_tables) + return ( + self.encoder_seq_lens_tensor, + self.max_encoder_seq_len, + self.cross_block_tables, + ) elif attn_type == AttentionType.ENCODER: # No block tables associated with encoder attention - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - None) + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): + reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device) -> None: + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ) -> None: super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.scheduler_config = vllm_config.scheduler_config - - # For reorder - self.reorder_prompt_req_index_list = np.empty( - vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) - self.reorder_decode_req_index_list = np.empty( - vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) - self.num_prompt_req: int = 0 + self._init_reorder_batch_threshold(1, False) self.seq_start_loc_cpu = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, @@ -344,103 +369,70 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], ) self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() - def reorder_batch(self, input_batch: InputBatch, - scheduler_output: SchedulerOutput) -> bool: - prompt_list_idx = 0 - decode_list_idx = 0 - for req_index in range(input_batch.num_reqs): - if input_batch.num_computed_tokens_cpu[ - req_index] < input_batch.num_prompt_tokens[req_index]: - # prompt stage - self.reorder_prompt_req_index_list[prompt_list_idx] = req_index - prompt_list_idx += 1 - else: - # decode stage - self.reorder_decode_req_index_list[decode_list_idx] = req_index - decode_list_idx += 1 - assert decode_list_idx + prompt_list_idx == input_batch.num_reqs - - # Update prompt requests number - self.num_prompt_req = prompt_list_idx - - reorder_req_num = 0 - for req_index in range(decode_list_idx): - if self.reorder_decode_req_index_list[req_index] < prompt_list_idx: - reorder_req_num += 1 - else: - break - - if reorder_req_num == 0: - return False - - reorder_prompt_list = ( - self.reorder_prompt_req_index_list[:prompt_list_idx] - [-reorder_req_num:]) - reorder_decode_list = ( - self.reorder_decode_req_index_list[:decode_list_idx] - [:reorder_req_num]) - assert reorder_decode_list.size == reorder_prompt_list.size - - for idx in range(reorder_req_num): - 
prompt_req_index = reorder_prompt_list[idx].item() - decode_req_index = reorder_decode_list[idx].item() - input_batch.swap_states(prompt_req_index, decode_req_index) - - return True - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> TorchSDPAMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> TorchSDPAMetadata: num_reqs = common_attn_metadata.num_reqs max_query_len = common_attn_metadata.max_query_len seq_lens_cpu = common_attn_metadata.seq_lens_cpu seq_lens_np = seq_lens_cpu.numpy() - num_prompt_req = self.num_prompt_req - max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item( - ) if num_prompt_req > 0 else 0 - max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item( - ) if num_prompt_req < num_reqs else 0 - self.seq_start_loc_np[0] = 0 - np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1]) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item()) - num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() - - num_prefill_tokens) + query_start_loc_np = query_start_loc_cpu.numpy() + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) + + max_prefill_seq_len = ( + seq_lens_np[num_decodes:num_reqs].max().item() if num_prefills > 0 else 0 + ) + max_decode_seq_len = ( + seq_lens_np[:num_decodes].max().item() if num_prefills < num_reqs else 0 + ) + self.seq_start_loc_np[0] = 0 + np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1 : num_reqs + 1]) slot_mapping = common_attn_metadata.slot_mapping.long() block_table_tensor = common_attn_metadata.block_table_tensor + query_start_loc_np = query_start_loc_cpu.numpy() + query_start_loc_np[num_decodes : num_reqs + 1] -= num_decode_tokens attn_metadata = TorchSDPAMetadata( - num_prefills=num_prompt_req, + num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, # to ensure inference when chunked_prefill is disabled seq_lens=seq_lens_cpu.tolist(), - seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs], # decode - max_decode_seq_len=max_decode_seq_len, # decode - block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode + decode_seq_lens_tensor=seq_lens_cpu[:num_decodes], # decode + decode_max_seq_len=max_decode_seq_len, # decode + decode_block_tables=block_table_tensor[:num_decodes], # decode chunked_prefill=self.scheduler_config.chunked_prefill_enabled, max_query_len=max_query_len, - max_kv_len=max_prefill_seq_len, - prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req + - 1], # prefill - kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req + - 1], # prefill - prefill_block_tables=block_table_tensor[: - num_prompt_req], # prefill - query_start_loc=query_start_loc_cpu[:num_reqs + - 1], # for logits index + prefill_max_seq_len=max_prefill_seq_len, + prefill_query_start_loc=query_start_loc_cpu[ + num_decodes : num_reqs + 1 + ], # prefill + prefill_seq_start_loc=self.seq_start_loc_cpu[ + num_decodes : num_reqs + 1 + ], # prefill + prefill_block_tables=block_table_tensor[num_decodes:num_reqs], # prefill + query_start_loc=query_start_loc_cpu[: num_reqs + 1], # for logits index ) return attn_metadata class 
TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): - def __init__( self, num_heads: int, @@ -457,8 +449,10 @@ def __init__( if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") if logits_soft_cap is not None: - logger.warning_once("Torch SPDA does not support logits soft cap. " - "Outputs may be slightly off.") + logger.warning_once( + "Torch SPDA does not support logits soft cap. " + "Outputs may be slightly off." + ) self.paged_attn_impl = _get_paged_attn_impl() self.num_heads = num_heads self.head_size = head_size @@ -471,13 +465,15 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.need_mask = (self.alibi_slopes is not None - or self.sliding_window is not None) + self.need_mask = ( + self.alibi_slopes is not None or self.sliding_window is not None + ) if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: raise NotImplementedError( "Torch SDPA backend FP8 KV cache requires " - "intel_extension_for_pytorch support.") + "intel_extension_for_pytorch support." + ) self.attn_type = attn_type def forward( @@ -509,22 +505,28 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" - " for TorchSDPABackendImpl") + " for TorchSDPABackendImpl" + ) # For warming-up if attn_metadata is None: return query attn_type = self.attn_type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") + if attn_type == AttentionType.ENCODER and ( + not attn_metadata.is_all_encoder_attn_metadata_set + ): + raise AttributeError( + "Encoder attention requires setting encoder metadata attributes." + ) + elif attn_type == AttentionType.ENCODER_DECODER and ( + not attn_metadata.is_all_cross_attn_metadata_set + ): + raise AttributeError( + "Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes." + ) # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -535,7 +537,7 @@ def forward( else: assert value is None - if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: # KV-cache during decoder-self- or # encoder-decoder-cross-attention, but not # during encoder attention. @@ -544,7 +546,8 @@ def forward( # we still need to break out key_cache and value_cache # i.e. 
for later use by paged attention key_cache, value_cache = self.paged_attn_impl.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + kv_cache, self.num_kv_heads, self.head_size + ) if (key is not None) and (value is not None): if attn_type == AttentionType.ENCODER_DECODER: @@ -557,8 +560,15 @@ def forward( updated_slot_mapping = attn_metadata.slot_mapping self.paged_attn_impl.write_to_paged_cache( - key, value, key_cache, value_cache, updated_slot_mapping, - self.kv_cache_dtype, layer._k_scale, layer._v_scale) + key, + value, + key_cache, + value_cache, + updated_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if attn_type != AttentionType.ENCODER: # Decoder self-attention supports chunked prefill. @@ -584,26 +594,24 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore assert attn_metadata.seq_lens is not None - self._run_sdpa_forward(output, - query, - key, - value, - prefill_meta, - attn_type=attn_type) + self._run_sdpa_forward( + output, query, key, value, prefill_meta, attn_type=attn_type + ) else: # prefix-enabled attention assert not self.need_mask import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) ipex_modules.PagedAttention.flash_attn_varlen_func( - output[:prefill_meta.num_prefill_tokens, :, :], - query[:prefill_meta.num_prefill_tokens, :, :], + output[prefill_meta.num_decode_tokens :, :, :], + query[prefill_meta.num_decode_tokens :, :, :], key_cache, value_cache, prefill_meta.prefill_query_start_loc, - prefill_meta.kv_start_loc, + prefill_meta.prefill_seq_start_loc, prefill_meta.max_query_len, - prefill_meta.max_kv_len, + prefill_meta.prefill_max_seq_len, self.scale, True, prefill_meta.prefill_block_tables, @@ -612,7 +620,8 @@ def forward( if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have decode metadata.") + "Encoder-only models should not have decode metadata." + ) # Decoding run. 
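The slicing in this forward path (output[: num_decode_tokens] for decode, [num_decode_tokens :] for prefill) relies on the batch being reordered so that decode requests come first. A small sketch of how that boundary falls out of the packed layout, assuming a decode threshold of one query token; the real split is done by split_decodes_and_prefills, and the names below are illustrative:

import torch

# Packed queries for 4 requests with query lengths [1, 1, 5, 3]; the batch is
# reordered so decode requests (query length <= threshold) come first.
query_lens = torch.tensor([1, 1, 5, 3])
query_start_loc = torch.cat([torch.zeros(1, dtype=torch.long),
                             query_lens.cumsum(0)])
tokens = torch.randn(int(query_lens.sum()), 8, 64)  # [num_tokens, heads, head_dim]

decode_threshold = 1
num_decodes = int((query_lens <= decode_threshold).sum())
num_decode_tokens = int(query_start_loc[num_decodes])

decode_tokens = tokens[:num_decode_tokens]    # goes to the paged decode kernel
prefill_tokens = tokens[num_decode_tokens:]   # goes to the varlen prefill kernel
print(decode_tokens.shape, prefill_tokens.shape)  # (2, 8, 64) (8, 8, 64)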
( seq_lens_arg, @@ -621,8 +630,8 @@ def forward( ) = decode_meta.get_seq_len_block_table_args(attn_type) self.paged_attn_impl.forward_decode( - output[attn_metadata.num_prefill_tokens:, :, :], - query[attn_metadata.num_prefill_tokens:, :, :], + output[: attn_metadata.num_decode_tokens, :, :], + query[: attn_metadata.num_decode_tokens, :, :], key_cache, value_cache, block_tables_arg, @@ -652,13 +661,15 @@ def _run_sdpa_forward( if attn_masks is None: if self.alibi_slopes is not None: attn_masks = _make_alibi_bias( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, # type: ignore + ) elif self.sliding_window is not None: assert attn_metadata.seq_lens is not None attn_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, - query.dtype) # type: ignore + attn_metadata.seq_lens, self.sliding_window, query.dtype + ) else: seq_lens, _ = attn_metadata.get_seq_lens(attn_type) attn_masks = [None] * len(seq_lens) @@ -672,22 +683,26 @@ def _run_sdpa_forward( key = key.repeat_interleave(self.num_queries_per_kv, dim=-3) value = value.repeat_interleave(self.num_queries_per_kv, dim=-3) - causal_attn = (attn_type == AttentionType.DECODER) + causal_attn = attn_type == AttentionType.DECODER seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) start_q, start_kv = 0, 0 - for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, - attn_masks): + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, attn_masks): end_q = start_q + seq_len_q end_kv = start_kv + seq_len_kv - sub_out = scaled_dot_product_attention( - query[None, :, start_q:end_q, :], - key[None, :, start_kv:end_kv, :], - value[None, :, start_kv:end_kv, :], - attn_mask=mask, - dropout_p=0.0, - is_causal=causal_attn and mask is None, - scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + sub_out = ( + scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and mask is None, + scale=self.scale, + ) + .squeeze(0) + .movedim(query.dim() - 2, 0) + ) output[start_q:end_q, :, :] = sub_out start_q, start_kv = end_q, end_kv @@ -710,9 +725,11 @@ def _make_alibi_bias( num_heads = alibi_slopes.shape[0] bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + inf_mask = ( + torch.empty((1, seq_len, seq_len), dtype=bias.dtype) + .fill_(-torch.inf) + .triu_(diagonal=1) + ) attn_biases.append((bias + inf_mask).to(dtype)) return attn_biases @@ -741,7 +758,6 @@ def _make_sliding_window_bias( class _PagedAttention: - @staticmethod def validate_head_size(head_size: int) -> tuple[bool, list[int]]: SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] @@ -768,8 +784,7 @@ def split_kv_cache( num_blocks = kv_cache.shape[1] key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, - -1, x) + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x) value_cache = kv_cache[1] value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) return key_cache, value_cache @@ -845,7 +860,6 @@ def forward_decode( class _IPEXPagedAttention(_PagedAttention): - @staticmethod def validate_head_size(head_size: int) -> tuple[bool, list[int]]: return True, [] @@ -878,8 +892,8 @@ def 
write_to_paged_cache( *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, - slot_mapping.flatten().int()) + key, value, key_cache, value_cache, slot_mapping.flatten().int() + ) @staticmethod def forward_decode( @@ -899,17 +913,30 @@ def forward_decode( *args, ) -> None: block_size = value_cache.shape[2] - head_mapping = torch.arange( - 0, - num_kv_heads, - device="cpu", - dtype=torch.int32, - ).view(num_kv_heads, - 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + head_mapping = ( + torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ) + .view(num_kv_heads, 1) + .repeat_interleave(query.size(1) // num_kv_heads) + .flatten() + ) ipex_modules.PagedAttention.single_query_cached_kv_attention( - output, query.contiguous(), key_cache, value_cache, head_mapping, - scale, block_tables, context_lens, block_size, max_context_len, - alibi_slopes) + output, + query.contiguous(), + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) def _get_paged_attn_impl(): diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f0770f744146..bb3dcddba3e9 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" + from dataclasses import dataclass from typing import Optional @@ -8,34 +9,43 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, + is_quantized_kv_cache, +) from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version, - is_flash_attn_varlen_func_available) +from vllm.attention.utils.fa_utils import ( + flash_attn_supports_fp8, + get_flash_attn_version, + is_flash_attn_varlen_func_available, +) if is_flash_attn_varlen_func_available(): - from vllm.attention.utils.fa_utils import (flash_attn_varlen_func, - get_scheduler_metadata, - reshape_and_cache_flash) + from vllm.attention.utils.fa_utils import ( + flash_attn_varlen_func, + get_scheduler_metadata, + reshape_and_cache_flash, + ) from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_kv_cache_layout) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_kv_cache_layout, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) class FlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True supports_quant_query_input: bool = True @@ -56,7 +66,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. 
" "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -141,7 +152,8 @@ class FlashAttentionMetadata: def _get_sliding_window_configs( - vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: + vllm_config: VllmConfig, +) -> set[Optional[tuple[int, int]]]: """Get the set of all sliding window configs used in the model.""" sliding_window_configs: set[Optional[tuple[int, int]]] = set() layers = get_layers_from_vllm_config(vllm_config, Attention) @@ -151,8 +163,7 @@ def _get_sliding_window_configs( return sliding_window_configs -class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): +class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]): # FA3: # Supports full cudagraphs for all cases. # @@ -171,11 +182,19 @@ class FlashAttentionMetadataBuilder( # to FULL_AND_PIECEWISE. # TODO(luka, lucas): audit FA2 as part of: # https://github.com/vllm-project/vllm/issues/22945 - cudagraph_support = AttentionCGSupport.ALWAYS \ - if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH + cudagraph_support = ( + AttentionCGSupport.ALWAYS + if get_flash_attn_version() == 3 + else AttentionCGSupport.UNIFORM_BATCH + ) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config @@ -183,18 +202,19 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.compilation_config = vllm_config.compilation_config self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.kv_cache_dtype = kv_cache_spec.dtype self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.max_num_splits = 0 # No upper bound on the number of splits. - self.aot_schedule = (get_flash_attn_version() == 3) + self.aot_schedule = get_flash_attn_version() == 3 - self.use_full_cuda_graph = \ + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) self.max_cudagraph_size = self.compilation_config.max_capture_size if self.use_full_cuda_graph and self.aot_schedule: @@ -202,8 +222,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. raise ValueError( - "Capture size larger than 992 is not supported for " - "full cuda graph.") + "Capture size larger than 992 is not supported for full cuda graph." + ) self.scheduler_metadata = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, @@ -213,19 +233,20 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. 
- self.max_num_splits = ( - envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH) + self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. self.aot_sliding_window: Optional[tuple[int, int]] = None - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashAttentionMetadata: """ - fast_build disables AOT scheduling, used when there will be few + fast_build disables AOT scheduling, used when there will be few iterations i.e. spec-decode """ num_reqs = common_attn_metadata.num_reqs @@ -249,8 +270,7 @@ def build(self, # build() call so the layers are constructed (cannot populate) # in __init__. if aot_schedule: - sliding_window_configs = _get_sliding_window_configs( - self.vllm_config) + sliding_window_configs = _get_sliding_window_configs(self.vllm_config) if len(sliding_window_configs) == 1: sliding_window_config = sliding_window_configs.pop() if sliding_window_config is not None: @@ -260,20 +280,21 @@ def build(self, aot_schedule = False max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible - if self.use_full_cuda_graph and \ - num_actual_tokens <= self.max_cudagraph_size: + if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size: # NOTE(woosuk): Setting num_splits > 1 may increase the memory # usage, because the intermediate buffers of size [num_splits, # num_heads, num_tokens, head_size] are allocated. Therefore, # we only set num_splits when using cuda graphs. max_num_splits = self.max_num_splits - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( - cache_dtype) + cache_dtype + ) else: qkv_dtype = self.kv_cache_dtype if aot_schedule: @@ -297,39 +318,44 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, use_cascade = common_prefix_len > 0 if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) prefix_scheduler_metadata = schedule( batch_size=1, cu_query_lens=cu_prefix_query_lens, max_query_len=num_actual_tokens, seqlens=prefix_kv_lens, max_seq_len=common_prefix_len, - causal=False) - scheduler_metadata = schedule(batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=suffix_kv_lens, - max_seq_len=max_seq_len - - common_prefix_len, - causal=True) + causal=False, + ) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=suffix_kv_lens, + max_seq_len=max_seq_len - common_prefix_len, + causal=True, + ) else: cu_prefix_query_lens = 
None prefix_kv_lens = None suffix_kv_lens = None prefix_scheduler_metadata = None - scheduler_metadata = schedule(batch_size=num_reqs, - cu_query_lens=query_start_loc, - max_query_len=max_query_len, - seqlens=seq_lens, - max_seq_len=max_seq_len, - causal=causal) + scheduler_metadata = schedule( + batch_size=num_reqs, + cu_query_lens=query_start_loc, + max_query_len=max_query_len, + seqlens=seq_lens, + max_seq_len=max_seq_len, + causal=causal, + ) # For FA3 + full cudagraph if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] @@ -357,7 +383,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - causal=causal) + causal=causal, + ) return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: @@ -365,7 +392,6 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class FlashAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -406,18 +432,20 @@ def __init__( self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() - if is_quantized_kv_cache(self.kv_cache_dtype) \ - and not flash_attn_supports_fp8(): + if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): raise NotImplementedError( - "FlashAttention does not support fp8 kv-cache on this device.") + "FlashAttention does not support fp8 kv-cache on this device." + ) self.sinks = sinks if self.sinks is not None: assert self.vllm_flash_attn_version == 3, ( - "Sinks are only supported in FlashAttention 3") + "Sinks are only supported in FlashAttention 3" + ) assert self.sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " - "heads in the layer") + "heads in the layer" + ) def forward( self, @@ -450,8 +478,8 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") + "fused output quantization is not yet supported for FlashAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -474,11 +502,14 @@ def forward( if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, # we use direct Q, K, V tensors without caching - return self._forward_encoder_attention(query[:num_actual_tokens], - key[:num_actual_tokens], - value[:num_actual_tokens], - output[:num_actual_tokens], - attn_metadata, layer) + return self._forward_encoder_attention( + query[:num_actual_tokens], + key[:num_actual_tokens], + value[:num_actual_tokens], + output[:num_actual_tokens], + attn_metadata, + layer, + ) # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) @@ -486,8 +517,11 @@ def forward( # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached # in KV cache. - if (self.kv_sharing_target_layer_name is None and key is not None - and value is not None): + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
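The cache write referred to in that comment (reshape_and_cache_flash) scatters the new keys and values into the paged cache via slot_mapping, i.e. a flat block_id * block_size + offset per token. A pure-PyTorch sketch of the same scatter, with a toy cache layout and made-up sizes:

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# slot_mapping gives each new token a flat slot: block_id * block_size + offset.
key = torch.randn(3, num_kv_heads, head_size)           # 3 new tokens
slot_mapping = torch.tensor([5, 17, 33])                 # -> blocks 0, 1, 2

flat = key_cache.view(num_blocks * block_size, num_kv_heads, head_size)
flat[slot_mapping] = key                                 # scatter into the paged cache

block_ids, offsets = slot_mapping // block_size, slot_mapping % block_size
assert torch.equal(key_cache[block_ids, offsets], key)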
# NOTE(woosuk): Here, key and value are padded while slot_mapping is @@ -509,7 +543,8 @@ def forward( if self.kv_cache_dtype.startswith("fp8"): # queries are quantized in the attention layer dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( - self.kv_cache_dtype) + self.kv_cache_dtype + ) key_cache = key_cache.view(dtype) value_cache = value_cache.view(dtype) @@ -597,7 +632,8 @@ def _forward_encoder_attention( # For encoder attention, process FP8 quantization if needed if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError( - "quantization is not supported for encoder attention") + "quantization is not supported for encoder attention" + ) # Use encoder-specific metadata for sequence information cu_seqlens_q = attn_metadata.query_start_loc @@ -607,7 +643,8 @@ def _forward_encoder_attention( descale_shape = ( cu_seqlens_q.shape[0] - 1, # type: ignore[union-attr] - self.num_kv_heads) + self.num_kv_heads, + ) # Call flash attention directly on Q, K, V tensors flash_attn_varlen_func( @@ -670,8 +707,12 @@ def use_cascade_attention( num_queries_per_kv = num_query_heads // num_kv_heads # The criteria for using FlashDecoding can be found in the following link: # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 - use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window - and not use_alibi and np.all(query_lens == 1)) + use_flash_decoding = ( + num_queries_per_kv > 1 + and not use_sliding_window + and not use_alibi + and np.all(query_lens == 1) + ) if not use_flash_decoding: # Use cascade attention. return True @@ -693,8 +734,9 @@ def use_cascade_attention( cascade_waves = cdiv(cascade_ctas, num_sms) cascade_time = cascade_waves * num_prefix_tiles - flash_decoding_ctas = (num_reqs * num_kv_heads * - cdiv(num_queries_per_kv, q_tile_size)) + flash_decoding_ctas = ( + num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size) + ) flash_decoding_ctas *= num_prefix_tiles flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) @@ -726,10 +768,11 @@ def cascade_attention( k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert alibi_slopes is None, ("Cascade attention does not support ALiBi.") + assert alibi_slopes is None, "Cascade attention does not support ALiBi." # TODO: Support sliding window. assert sliding_window == (-1, -1), ( - "Cascade attention does not support sliding window.") + "Cascade attention does not support sliding window." 
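The two flash_attn_varlen_func calls in this function return softmax LSE values precisely so that merge_attn_states can combine the shared-prefix and per-request-suffix results at the end. A sketch of that log-sum-exp merge with illustrative shapes (the real kernel fuses this):

import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    # o*: [num_tokens, num_heads, head_dim]; lse*: [num_tokens, num_heads].
    # Each partial output is softmax-normalised over its own key set, so
    # reweight by the (stabilised) softmax denominators and renormalise.
    m = torch.maximum(lse1, lse2)
    z1, z2 = torch.exp(lse1 - m), torch.exp(lse2 - m)
    w1 = (z1 / (z1 + z2)).unsqueeze(-1)
    return w1 * o1 + (1.0 - w1) * o2

o1, o2 = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
lse1, lse2 = torch.randn(2, 4), torch.randn(2, 4)
print(merge_partial_attention(o1, lse1, o2, lse2).shape)  # torch.Size([2, 4, 8])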
+ ) num_tokens = query.shape[0] block_size = key_cache.shape[-3] @@ -755,12 +798,9 @@ def cascade_attention( return_softmax_lse=True, scheduler_metadata=prefix_scheduler_metadata, fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) - if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) - if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) - if v_descale is not None else None, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, ) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) @@ -782,14 +822,10 @@ def cascade_attention( return_softmax_lse=True, scheduler_metadata=suffix_scheduler_metadata, fa_version=fa_version, - q_descale=q_descale.expand(descale_shape) - if q_descale is not None else None, - k_descale=k_descale.expand(descale_shape) - if k_descale is not None else None, - v_descale=v_descale.expand(descale_shape) - if v_descale is not None else None, + q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, + k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, + v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, ) # Merge prefix and suffix outputs, and store the result in output. - merge_attn_states(output, prefix_output, prefix_lse, suffix_output, - suffix_lse) + merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 13f18d103b53..c7a826a67d7d 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1,50 +1,57 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashInfer.""" + from __future__ import annotations from dataclasses import dataclass -from typing import ClassVar, Optional, Union +from typing import ClassVar, Union import numpy as np import torch -from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - MultiLevelCascadeAttentionWrapper) +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + MultiLevelCascadeAttentionWrapper, +) from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache from flashinfer.utils import FP4Tensor from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionType, +) from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.batch_invariant import ( - vllm_kernel_override_batch_invariant) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym, kNvfp4Quant) + QuantKey, + kFp8StaticTensorSym, + kNvfp4Quant, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available -from vllm.utils.flashinfer import (can_use_trtllm_attention, - flashinfer_disable_q_quantization, - supports_trtllm_attention, - use_trtllm_attention) -# yapf 
conflicts with isort for this block -# yapf: disable -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - get_kv_cache_layout, - get_per_layer_parameters, - infer_global_hyperparameters, - split_decodes_and_prefills) -# yapf: enable +from vllm.utils.flashinfer import ( + can_use_trtllm_attention, + flashinfer_disable_q_quantization, + supports_trtllm_attention, + use_trtllm_attention, +) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_kv_cache_layout, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 -FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024 FP8_DTYPE = current_platform.fp8_dtype() FP4_DTYPE = torch.uint8 @@ -58,7 +65,8 @@ def _get_trtllm_gen_workspace_buffer(): global trtllm_gen_workspace_buffer if trtllm_gen_workspace_buffer is None: trtllm_gen_workspace_buffer = torch.zeros( - FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device='cuda') + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device="cuda" + ) return trtllm_gen_workspace_buffer @@ -75,9 +83,9 @@ def _trtllm_prefill_attn_kvfp8_dequant( ): batch_idx = tl.program_id(0).to(tl.int64) mock_block_table_idx = tl.program_id(1).to(tl.int64) - orig_page_num = tl.load(block_tables_prefill_ptr + - batch_idx * block_table_stride + - mock_block_table_idx).to(tl.int64) + orig_page_num = tl.load( + block_tables_prefill_ptr + batch_idx * block_table_stride + mock_block_table_idx + ).to(tl.int64) if orig_page_num <= 0: return dequant_dtype = mock_kv_cache_ptr.dtype.element_ty @@ -87,20 +95,24 @@ def _trtllm_prefill_attn_kvfp8_dequant( offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) fp8_vals = tl.load(kv_cache_ptr + offset) dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val - mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx - + 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + mock_cache_offset = ( + batch_idx * block_table_stride + mock_block_table_idx + 1 + ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) dequantized_vals = dequantized_vals.to(dequant_dtype) tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) # Dequantize V v_scale_val = tl.load(v_scale_ptr) - offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + - tl.arange(0, K_CACHE_STRIDE)) + offset = ( + orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) + ) fp8_vals = tl.load(kv_cache_ptr + offset) dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val mock_cache_offset = ( - (batch_idx * block_table_stride + mock_block_table_idx + 1) * - KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)) + (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE + + K_CACHE_STRIDE + + tl.arange(0, K_CACHE_STRIDE) + ) dequantized_vals = dequantized_vals.to(dequant_dtype) tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) @@ -120,9 +132,7 @@ def trtllm_prefill_attn_kvfp8_dequant( kv_cache_stride = k_cache_stride * s[1] new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) # mock kv cache contains just the pages needed by this prefill - mock_kv_cache = torch.empty(new_s, - dtype=dequant_dtype, - device=kv_cache.device) + mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, 
device=kv_cache.device) # we simply sequentially index the pages needed by this prefill mock_block_table = torch.arange( start=1, @@ -165,7 +175,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -243,22 +254,28 @@ class FlashInferMetadata: # For cascade attention (CPU for planning). use_cascade: bool - prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None - decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None - cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None + decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None + cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None - qo_indptr_gpu: Optional[torch.Tensor] = None - paged_kv_indptr_gpu: Optional[torch.Tensor] = None + qo_indptr_gpu: torch.Tensor | None = None + paged_kv_indptr_gpu: torch.Tensor | None = None class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.cache_config = vllm_config.cache_config self.model_config = vllm_config.model_config @@ -266,32 +283,28 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) - if vllm_kernel_override_batch_invariant(): - self.decode_fixed_split_size = 2048 - self.prefill_fixed_split_size = 4096 - self.disable_split_kv = True - else: - self.decode_fixed_split_size = -1 - self.prefill_fixed_split_size = -1 - self.disable_split_kv = False - self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(self.model_config.max_model_len, - self.kv_cache_spec.block_size) + max_num_pages_per_req = cdiv( + self.model_config.max_model_len, self.kv_cache_spec.block_size + ) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req - self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\ - decode_mode() == CUDAGraphMode.FULL) + self.enable_cuda_graph = ( + self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + ) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. 
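The _trtllm_prefill_attn_kvfp8_dequant kernel earlier in this file builds a dequantized "mock" cache holding only the pages a prefill actually references, addressed by a sequential block table. A rough PyTorch-level sketch of the same data movement, with toy sizes, float stand-ins for fp8 values, and the unused-page sentinel and +1 page offset left out:

import torch

# Toy paged KV cache laid out as [num_pages, 2, block_size, num_kv_heads, head_size].
num_pages, block_size, heads, head = 6, 4, 2, 8
kv_cache_q = torch.randn(num_pages, 2, block_size, heads, head)
k_scale, v_scale = 0.02, 0.03

block_table = torch.tensor([[3, 5], [2, 4]])  # pages used by two prefill requests

# Gather only the referenced pages, apply the per-tensor K/V scales, and address
# them through a sequential "mock" block table.
pages = kv_cache_q[block_table.flatten()]
pages[:, 0] *= k_scale
pages[:, 1] *= v_scale
mock_kv_cache = pages.to(torch.bfloat16)
mock_block_table = torch.arange(block_table.numel()).view(block_table.shape)
print(mock_kv_cache.shape)   # torch.Size([4, 2, 4, 2, 8])
print(mock_block_table)      # [[0, 1], [2, 3]]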
self._decode_wrappers_cudagraph: dict[ - int, BatchDecodeWithPagedKVCacheWrapper] = {} + int, BatchDecodeWithPagedKVCacheWrapper + ] = {} self._decode_cudagraph_max_bs = min( - max_num_reqs, self.compilation_config.max_capture_size) + max_num_reqs, self.compilation_config.max_capture_size + ) self.num_qo_heads = self.model_config.get_num_attention_heads( - self.vllm_config.parallel_config) + self.vllm_config.parallel_config + ) self.num_kv_heads = self.kv_cache_spec.num_kv_heads self.head_dim = self.kv_cache_spec.head_size FlashInferBackend.validate_head_size(self.head_dim) @@ -299,9 +312,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.cache_dtype = self.cache_config.cache_dtype if self.cache_dtype.startswith("fp8"): - self.kv_cache_dtype = ( - FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.cache_dtype)) + self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.cache_dtype + ) else: assert self.kv_cache_spec.dtype == self.model_config.dtype self.kv_cache_dtype = self.kv_cache_spec.dtype @@ -310,14 +323,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION is set to 1. Otherwise, try to # use fp8 q if kv cache is fp8, and will fall back to model dtype # if TRTLLM attention kernel is not used when building attn metadata - if supports_trtllm_attention() and \ - not flashinfer_disable_q_quantization(): + if supports_trtllm_attention() and not flashinfer_disable_q_quantization(): self.q_data_type = self.kv_cache_dtype else: self.q_data_type = self.model_config.dtype - supports_spec_as_decode = \ - can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) + supports_spec_as_decode = can_use_trtllm_attention( + self.num_qo_heads, self.num_kv_heads + ) self._init_reorder_batch_threshold(1, supports_spec_as_decode) self._cascade_wrapper = None # Wrapper for cascade attention @@ -325,7 +338,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # Global hyperparameters shared by all attention layers # TODO: discard this for trtllm-gen backend self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) + get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl) + ) self.sm_scale = self.global_hyperparameters.sm_scale self.window_left = self.global_hyperparameters.window_left self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap @@ -334,69 +348,62 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], raise NotImplementedError( "FlashInfer backend currently does not support attention " "sinks, please use trtllm on blackwell or flash attention on " - "earlier GPUs.") + "earlier GPUs." 
+ ) # Preparing persistent buffers (device-side) - self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device=self.device) + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=self.device + ) self.paged_kv_indices = torch.zeros( max_num_pages, # max num pages possible dtype=torch.int32, - device=self.device) - self.paged_kv_last_page_len = torch.zeros(max_num_reqs, - dtype=torch.int32, - device=self.device) + device=self.device, + ) + self.paged_kv_last_page_len = torch.zeros( + max_num_reqs, dtype=torch.int32, device=self.device + ) # host-side buffer pin_memory = is_pin_memory_available() - self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + self.paged_kv_indptr_cpu = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() self.paged_kv_indptr_buffer = torch.zeros_like( - self.paged_kv_indptr_cpu, pin_memory=pin_memory) - self.paged_kv_indices_cpu = torch.zeros(max_num_pages, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs, - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.paged_kv_last_page_len_np = ( - self.paged_kv_last_page_len_cpu.numpy()) + self.paged_kv_indptr_cpu, pin_memory=pin_memory + ) + self.paged_kv_indices_cpu = torch.zeros( + max_num_pages, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_cpu = torch.zeros( + max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) + self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy() def _get_workspace_buffer(self): if self._workspace_buffer is None: - buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE - if vllm_kernel_override_batch_invariant(): - buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT - self._workspace_buffer = torch.zeros(buffer_size, - dtype=torch.uint8, - device=self.device) + self._workspace_buffer = torch.zeros( + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=self.device + ) return self._workspace_buffer def _get_prefill_wrapper(self): if self._prefill_wrapper is None: self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( - self._get_workspace_buffer(), get_kv_cache_layout()) + self._get_workspace_buffer(), get_kv_cache_layout() + ) return self._prefill_wrapper - def _get_decode_wrapper(self, - batch_size: int, - use_cudagraph: bool = False): + def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): if use_cudagraph: - decode_wrapper = self._decode_wrappers_cudagraph.get( - batch_size, None) + decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None) else: decode_wrapper = self._decode_wrapper if decode_wrapper is None: if use_cudagraph: - paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] + paged_kv_indptr = self.paged_kv_indptr[: batch_size + 1] paged_kv_indices = self.paged_kv_indices - paged_kv_last_page_len = self.paged_kv_last_page_len[: - batch_size] + paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] else: paged_kv_indptr = None paged_kv_indices = None @@ -425,19 +432,25 @@ def _get_decode_wrapper(self, def _get_cascade_wrapper(self): if self._cascade_wrapper is None: self._cascade_wrapper = MultiLevelCascadeAttentionWrapper( - 2, self._get_workspace_buffer(), get_kv_cache_layout()) + 2, self._get_workspace_buffer(), get_kv_cache_layout() + ) 
return self._cascade_wrapper - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashInferMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashInferMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=self.reorder_batch_threshold, - require_uniform=True) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold, + require_uniform=True, + ) + ) page_size = self.page_size max_q_len = common_attn_metadata.max_query_len @@ -456,17 +469,16 @@ def build(self, num_common_kv_blocks = common_prefix_len // page_size # Create CPU versions directly for cascade (no GPU versions needed) - shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device='cpu') - shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks], - dtype=torch.int32, - device='cpu') - shared_kv_page_indices_cpu = block_table_tensor[ - 0, :num_common_kv_blocks] - shared_kv_last_page_len_cpu = torch.tensor([page_size], - dtype=torch.int32, - device='cpu') + shared_qo_indptr_cpu = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indptr_cpu = torch.tensor( + [0, num_common_kv_blocks], dtype=torch.int32, device="cpu" + ) + shared_kv_page_indices_cpu = block_table_tensor[0, :num_common_kv_blocks] + shared_kv_last_page_len_cpu = torch.tensor( + [page_size], dtype=torch.int32, device="cpu" + ) # Remove the blocks of the shared prefix from all requests. block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] @@ -481,22 +493,23 @@ def build(self, np.cumsum( num_blocks_np, dtype=np.int32, - out=self.paged_kv_indptr_np[1:num_reqs + 1], + out=self.paged_kv_indptr_np[1 : num_reqs + 1], ) # NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified # after this line (e.g., for cuda graphs), we need to copy the data to # self.paged_kv_indptr_buffer to avoid race condition. 
- self.paged_kv_indptr_buffer[:num_reqs + - 1] = (self.paged_kv_indptr_cpu[:num_reqs + - 1]) - paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1] - paged_kv_indptr.copy_(self.paged_kv_indptr_buffer[:num_reqs + 1], - non_blocking=True) + self.paged_kv_indptr_buffer[: num_reqs + 1] = self.paged_kv_indptr_cpu[ + : num_reqs + 1 + ] + paged_kv_indptr = self.paged_kv_indptr[: num_reqs + 1] + paged_kv_indptr.copy_( + self.paged_kv_indptr_buffer[: num_reqs + 1], non_blocking=True + ) # write self.paged_kv_indices inplace num_actual_pages = self.paged_kv_indptr_np[num_reqs] paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - _copy_page_indices_kernel[(num_reqs, )]( + _copy_page_indices_kernel[(num_reqs,)]( paged_kv_indices, block_table_tensor, block_table_tensor.stride(0), @@ -513,29 +526,34 @@ def build(self, ) uses_spec_reorder = self.reorder_batch_threshold > 1 - prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, - self.num_kv_heads, - num_prefill_tokens, - max_seq_len, - self.cache_dtype, - self.q_data_type, - is_prefill=True, - has_sinks=self.has_sinks, - has_spec=uses_spec_reorder) - decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, - self.num_kv_heads, - num_decode_tokens, - max_seq_len, - self.cache_dtype, - self.q_data_type, - is_prefill=False, - has_sinks=self.has_sinks, - has_spec=uses_spec_reorder) + prefill_use_trtllm = use_trtllm_attention( + self.num_qo_heads, + self.num_kv_heads, + num_prefill_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=True, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder, + ) + decode_use_trtllm = use_trtllm_attention( + self.num_qo_heads, + self.num_kv_heads, + num_decode_tokens, + max_seq_len, + self.cache_dtype, + self.q_data_type, + is_prefill=False, + has_sinks=self.has_sinks, + has_spec=uses_spec_reorder, + ) if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm): raise NotImplementedError( "FlashInfer backend currently does not support attention " "sinks, please use trtllm on blackwell or flash attention on " - "earlier GPUs.") + "earlier GPUs." + ) # If TRTLLM attention is not used, the q quantization is not supported. # Fall back to use model dtype. @@ -561,7 +579,7 @@ def build(self, ) qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu - paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs] + paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs] paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs] if attn_metadata.use_cascade: @@ -592,17 +610,17 @@ def build(self, # Decodes are first so prefills start after the last decode prefill_start = num_decodes attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert qo_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert paged_kv_indptr_cpu[prefill_start:].shape[ - 0] == num_prefills + 1 - assert paged_kv_last_page_len_cpu[prefill_start:].shape[ - 0] == num_prefills + assert qo_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert paged_kv_indptr_cpu[prefill_start:].shape[0] == num_prefills + 1 + assert ( + paged_kv_last_page_len_cpu[prefill_start:].shape[0] == num_prefills + ) # Since prefill_wrapper.run() will be called with # query[num_decode_tokens:] we need to adjust the qo_indptr # to be relative to the start of the prefill queries. 
- qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ - prefill_start] + qo_indptr_cpu = ( + qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start] + ) paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] # Recompute max_q_len for the slice of requests we are using @@ -610,8 +628,7 @@ def build(self, # we have a non-uniform batch with some short decodes offloaded # to the prefill pathway query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1] - attn_metadata.max_q_len_prefill = \ - int(query_lens_prefill.max().item()) + attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item()) if not attn_metadata.prefill_use_trtllm: attn_metadata.prefill_wrapper.plan( @@ -629,47 +646,53 @@ def build(self, logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype, - fixed_split_size=self.prefill_fixed_split_size, - disable_split_kv=self.disable_split_kv, ) else: attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) attn_metadata.paged_kv_indptr_gpu = paged_kv_indptr_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) if num_decodes > 0: pure_decode = num_prefills == 0 # possible required padding for cudagraph replay - use_cudagraph = (self.enable_cuda_graph and pure_decode and - num_decodes <= self._decode_cudagraph_max_bs) + use_cudagraph = ( + self.enable_cuda_graph + and pure_decode + and num_decodes <= self._decode_cudagraph_max_bs + ) if use_cudagraph: - num_input_tokens = ( - self.vllm_config.pad_for_cudagraph(num_decode_tokens)) + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_decode_tokens + ) # Carefully fulfill the padding region with reasonable value # on cpu. # Make sure paged_kv_indptr_cpu is not decreasing - self.paged_kv_indptr_cpu[1 + num_decodes:1 + - num_input_tokens].fill_( - paged_kv_indptr_cpu[-1]) + self.paged_kv_indptr_cpu[ + 1 + num_decodes : 1 + num_input_tokens + ].fill_(paged_kv_indptr_cpu[-1]) # Fill the remaining paged_kv_last_page_len_cpu with 1. # This is because flashinfer treats 0 as a full page # instead of empty. - self.paged_kv_last_page_len_cpu[ - num_decodes:num_input_tokens].fill_(1) + self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_( + 1 + ) else: num_input_tokens = num_decode_tokens attn_metadata.decode_wrapper = self._get_decode_wrapper( - num_input_tokens, use_cudagraph) + num_input_tokens, use_cudagraph + ) if not attn_metadata.decode_use_trtllm: # Use the persistent buffer with padding length, # instead of the same address but chunked version # in atten_metadata when using cudagraph. 
fast_plan_decode( attn_metadata.decode_wrapper, - self.paged_kv_indptr_cpu[:num_input_tokens + 1], + self.paged_kv_indptr_cpu[: num_input_tokens + 1], paged_kv_indices, self.paged_kv_last_page_len_cpu[:num_input_tokens], seq_lens_cpu[:num_input_tokens], @@ -684,8 +707,6 @@ def build(self, logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype, - fixed_split_size=self.decode_fixed_split_size, - disable_split_kv=self.disable_split_kv, ) return attn_metadata @@ -700,20 +721,19 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class FlashInferImpl(AttentionImpl): - def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], + alibi_slopes: list[float] | None, + sliding_window: int | None, kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, + logits_soft_cap: float | None = None, attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[int] = None, - sinks: Optional[torch.Tensor] = None, + kv_sharing_target_layer_name: int | None = None, + sinks: torch.Tensor | None = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -726,8 +746,9 @@ def __init__( self.sliding_window = (-1, -1) else: self.sliding_window = (sliding_window - 1, 0) - self.window_left = (self.sliding_window[0] - if self.sliding_window is not None else -1) + self.window_left = ( + self.sliding_window[0] if self.sliding_window is not None else -1 + ) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -735,30 +756,36 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl" + ) - self.sinks: Optional[torch.Tensor] = None + self.sinks: torch.Tensor | None = None if sinks is not None: if sinks.shape[0] != num_heads: raise ValueError( "Sinks must have the same number of heads as the number of " f"heads in the layer. Expected {num_heads}, but got " - f"{sinks.shape[0]}.") + f"{sinks.shape[0]}." 
+ ) self.sinks = sinks - self.support_trtllm_attn = (supports_trtllm_attention() - and num_heads % num_kv_heads == 0) - self.bmm1_scale: Optional[float] = None - self.bmm2_scale: Optional[float] = None - self.o_sf_scale: Optional[float] = None + self.support_trtllm_attn = ( + supports_trtllm_attention() and num_heads % num_kv_heads == 0 + ) + self.bmm1_scale: float | None = None + self.bmm2_scale: float | None = None + self.o_sf_scale: float | None = None def fused_output_quant_supported(self, quant_key: QuantKey): - return (self.support_trtllm_attn - and self.kv_cache_dtype.startswith("fp8") - and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)) + return ( + self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant) + ) def forward( self, @@ -768,9 +795,9 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashInfer. @@ -792,28 +819,32 @@ def forward( return output if self.bmm1_scale is None: - self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * - self.scale) + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale if self.bmm2_scale is None: self.bmm2_scale = layer._v_scale_float # The attn+quant fusion happens when output_scale is provided. if output_scale is None: - assert output_block_scale is None, "output_block_scale "\ - "is not supported when fusion has not happened" + assert output_block_scale is None, ( + "output_block_scale is not supported when fusion has not happened" + ) else: - assert attn_metadata.q_data_type == FP8_DTYPE, \ + assert attn_metadata.q_data_type == FP8_DTYPE, ( "Query must be FP8 when attn+quant fusion happened." - assert (attn_metadata.prefill_use_trtllm and - attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" + ) + assert ( + attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm + ), "Must use TRT-LLM attn" if output.dtype == FP8_DTYPE: - assert output_block_scale is None, \ + assert output_block_scale is None, ( "output_block_scale should not be provided for fp8 output" + ) elif output.dtype == FP4_DTYPE: - assert output_block_scale is not None, \ + assert output_block_scale is not None, ( "output_block_scale is required for nvfp4 output" + ) else: raise ValueError(f"Unsupported output dtype: {output.dtype}") @@ -831,9 +862,9 @@ def forward( if attn_metadata.q_data_type == FP8_DTYPE: num_tokens, num_heads, head_size = query.shape query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale, + ) query = query.reshape((num_tokens, num_heads, head_size)) # IMPORTANT! 
@@ -870,7 +901,8 @@ def forward( # to process the cache when the kv_cache_dtype is fp8 if self.kv_cache_dtype.startswith("fp8"): torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - self.kv_cache_dtype) + self.kv_cache_dtype + ) kv_cache = kv_cache.view(torch_dtype) # Inputs and outputs may be padded for CUDA graphs @@ -904,8 +936,7 @@ def forward( if not attn_metadata.prefill_use_trtllm: assert prefill_wrapper._causal assert prefill_wrapper._window_left == self.window_left - assert prefill_wrapper._logits_soft_cap == ( - self.logits_soft_cap or 0.0) + assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert prefill_wrapper._sm_scale == self.scale prefill_wrapper.run( prefill_query, @@ -918,8 +949,7 @@ def forward( # prefill_query may be non-contiguous prefill_query = prefill_query.contiguous() workspace_buffer = _get_trtllm_gen_workspace_buffer() - block_tables_prefill = attn_metadata.block_table_tensor[ - num_decodes:] + block_tables_prefill = attn_metadata.block_table_tensor[num_decodes:] seq_lens_prefill = attn_metadata.seq_lens[num_decodes:] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND @@ -932,28 +962,31 @@ def forward( if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None - out = FP4Tensor(data=output[num_decode_tokens:], - scale=output_block_scale, - scale_start_index=num_decode_tokens, - original_shape=prefill_query.shape) + out = FP4Tensor( + data=output[num_decode_tokens:], + scale=output_block_scale, + scale_start_index=num_decode_tokens, + original_shape=prefill_query.shape, + ) else: assert self.o_sf_scale is None out = output[num_decode_tokens:] - if attn_metadata.q_data_type != FP8_DTYPE \ - and self.kv_cache_dtype.startswith("fp8"): + if ( + attn_metadata.q_data_type != FP8_DTYPE + and self.kv_cache_dtype.startswith("fp8") + ): # TRTLLM prefill attention does not support BF16 Q # and fp8 kv cache. 
So to enable prefill attention # with fp8 kv cache, we can construct a mock block # and mock kv cache with BF16 KV involved in the prefill - mock_kv_cache, mock_block_table = ( - trtllm_prefill_attn_kvfp8_dequant( - kv_cache_permute, - block_tables_prefill, - layer._k_scale, - layer._v_scale, - attn_metadata.q_data_type, - )) + mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( + kv_cache_permute, + block_tables_prefill, + layer._k_scale, + layer._v_scale, + attn_metadata.q_data_type, + ) else: mock_kv_cache = kv_cache_permute mock_block_table = block_tables_prefill @@ -985,8 +1018,7 @@ def forward( if not attn_metadata.decode_use_trtllm: assert decode_wrapper._window_left == self.window_left - assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap - or 0.0) + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0) assert decode_wrapper._sm_scale == self.scale decode_wrapper.run( decode_query, @@ -999,8 +1031,9 @@ def forward( # decode_query may be non-contiguous decode_query = decode_query.contiguous() workspace_buffer = _get_trtllm_gen_workspace_buffer() - block_tables_decode = attn_metadata.\ - block_table_tensor[:num_decode_tokens] + block_tables_decode = attn_metadata.block_table_tensor[ + :num_decode_tokens + ] seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens] # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND @@ -1013,10 +1046,12 @@ def forward( if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None - out = FP4Tensor(data=output[:num_decode_tokens], - scale=output_block_scale, - scale_start_index=0, - original_shape=decode_query.shape) + out = FP4Tensor( + data=output[:num_decode_tokens], + scale=output_block_scale, + scale_start_index=0, + original_shape=decode_query.shape, + ) else: assert self.o_sf_scale is None out = output[:num_decode_tokens] @@ -1026,8 +1061,7 @@ def forward( # attention to be initialized with q_len = 0 q_len_per_req = 1 else: - q_len_per_req = \ - num_decode_tokens // attn_metadata.num_decodes + q_len_per_req = num_decode_tokens // attn_metadata.num_decodes trtllm_batch_decode_with_kv_cache( query=decode_query, @@ -1042,7 +1076,8 @@ def forward( sinks=self.sinks, o_sf_scale=self.o_sf_scale, out=out, - q_len_per_req=q_len_per_req) + q_len_per_req=q_len_per_req, + ) return output_padded @@ -1058,16 +1093,14 @@ def fast_plan_decode( page_size: int, pos_encoding_mode: str = "NONE", window_left: int = -1, - logits_soft_cap: Optional[float] = None, - q_data_type: Optional[Union[str, torch.dtype]] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, - data_type: Optional[Union[str, torch.dtype]] = None, - sm_scale: Optional[float] = None, - rope_scale: Optional[float] = None, - rope_theta: Optional[float] = None, + logits_soft_cap: float | None = None, + q_data_type: Union[str, torch.dtype] | None = "float16", + kv_data_type: Union[str, torch.dtype] | None = None, + data_type: Union[str, torch.dtype] | None = None, + sm_scale: float | None = None, + rope_scale: float | None = None, + rope_theta: float | None = None, non_blocking: bool = True, - fixed_split_size: int = -1, - disable_split_kv: bool = False, ) -> None: """ A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for @@ -1085,8 +1118,7 @@ def fast_plan_decode( # Warm up with the original plan if it is first call, and always run the # original plan if we run for dynamic shape. For fixed shape (cudagraph), # this warm up is to generate the _cached_module for the decode wrapper. 
- if not self.is_cuda_graph_enabled or \ - getattr(self, "vllm_first_call", True): + if not self.is_cuda_graph_enabled or getattr(self, "vllm_first_call", True): self.plan( indptr_cpu, indices, @@ -1105,10 +1137,6 @@ def fast_plan_decode( rope_scale, rope_theta, non_blocking, - None, # block_tables - None, # seq_lens - fixed_split_size, - disable_split_kv, ) self.vllm_first_call = False return @@ -1130,31 +1158,33 @@ def fast_plan_decode( if kv_data_type is None: kv_data_type = q_data_type - q_data_type = getattr(torch, q_data_type) if isinstance( - q_data_type, str) else q_data_type - kv_data_type = getattr(torch, kv_data_type) if isinstance( - kv_data_type, str) else kv_data_type + q_data_type = ( + getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + ) + kv_data_type = ( + getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type + ) if batch_size != self._fixed_batch_size: raise ValueError( "The batch size should be fixed in cudagraph mode, the runtime " "batch size {} mismatches the batch size set during " - "initialization {}".format(batch_size, self._fixed_batch_size)) + "initialization {}".format(batch_size, self._fixed_batch_size) + ) if len(indices) > len(self._paged_kv_indices_buf): raise ValueError( - "The size of indices should be less than or equal to the " - "allocated buffer") + "The size of indices should be less than or equal to the allocated buffer" + ) # host-to-device copy for the indptr buffer self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True) # host-to-device copy for the last_page_len buffer - self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, - non_blocking=True) + self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu, non_blocking=True) qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") try: - # Make sure we pass exactly 18 arguments for tensor core version + # Make sure we pass exactly 15 arguments for tensor core version self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, @@ -1171,9 +1201,6 @@ def fast_plan_decode( head_dim, head_dim, False, # causal - window_left, - fixed_split_size, - disable_split_kv, ) except Exception as e: raise RuntimeError(f"Error in tensor core plan: {e}") from e @@ -1203,6 +1230,8 @@ def _copy_page_indices_kernel( offset = tl.arange(0, BLOCK_SIZE) for i in tl.range(0, num_blocks, BLOCK_SIZE): block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) - tl.store(page_indices + start_idx + i + offset, - block_ids, - mask=i + offset < num_blocks) + tl.store( + page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks, + ) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index e548b51060d8..4640e62abfe6 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -8,21 +8,32 @@ import torch import torch._dynamo.decorators import torch.nn.functional as F -from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, - _score_mod_signature, and_masks, - create_block_mask, - flex_attention) - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) +from torch.nn.attention.flex_attention import ( + BlockMask, + _mask_mod_signature, + _score_mod_signature, + and_masks, + create_block_mask, + flex_attention, +) + +from vllm.attention.backends.abstract import ( + AttentionBackend, + 
AttentionImpl, + AttentionMetadata, + AttentionType, + is_quantized_kv_cache, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( - vllm_kernel_override_batch_invariant) + vllm_kernel_override_batch_invariant, +) from vllm.utils import cdiv, is_torch_equal_or_newer -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -31,9 +42,9 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch -create_block_mask_compiled = torch.compile(create_block_mask, - fullgraph=True, - mode="reduce-overhead") +create_block_mask_compiled = torch.compile( + create_block_mask, fullgraph=True, mode="reduce-overhead" +) flex_attention_compiled = torch.compile(flex_attention, fullgraph=True) @@ -41,7 +52,8 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: device = offsets.device counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave( - torch.arange(len(counts), device=device, dtype=torch.int32), counts) + torch.arange(len(counts), device=device, dtype=torch.int32), counts + ) def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): @@ -103,10 +115,13 @@ def use_cascade_attention(*args, **kwargs) -> bool: return False -#@torch.compile(fullgraph=True, mode="reduce-overhead") -def physical_to_logical_mapping(block_table: torch.Tensor, - seq_lens: torch.Tensor, block_size: int, - total_blocks: int) -> torch.Tensor: +# @torch.compile(fullgraph=True, mode="reduce-overhead") +def physical_to_logical_mapping( + block_table: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + total_blocks: int, +) -> torch.Tensor: """ Creates an inverse mapping from physical block locations to logical indices. 
@@ -176,35 +191,37 @@ def physical_to_logical_mapping(block_table: torch.Tensor, max_reqs, max_num_blocks = block_table.shape device = block_table.device - physical_to_logical = torch.full((max_reqs, total_blocks), - -1, - dtype=torch.long, - device=device) + physical_to_logical = torch.full( + (max_reqs, total_blocks), -1, dtype=torch.long, device=device + ) # Only process valid blocks to avoid garbage values num_blocks_per_seq = cdiv(seq_lens, block_size) - mask = torch.arange(max_num_blocks, - device=device)[None, :] < num_blocks_per_seq[:, None] + mask = ( + torch.arange(max_num_blocks, device=device)[None, :] + < num_blocks_per_seq[:, None] + ) valid_block_table = torch.where(mask, block_table, 0) valid_logical_indices = torch.where( - mask, - torch.arange(max_num_blocks, device=device)[None, :], 0) + mask, torch.arange(max_num_blocks, device=device)[None, :], 0 + ) - physical_to_logical.scatter_(-1, valid_block_table.to(torch.int64), - valid_logical_indices) + physical_to_logical.scatter_( + -1, valid_block_table.to(torch.int64), valid_logical_indices + ) # NB - Seems like block 0 is always empty so we reset it manually physical_to_logical[:, 0] = -1 return physical_to_logical def unique_static_unsorted( - x: torch.Tensor, - *, - M: int, # maximum positive value (0 is “skip me”) - dim: int = -1, # axis along which to deduplicate - ignored_val: int = 0, # value to ignore - pad_val: int = -1, # sentinel for unused slots + x: torch.Tensor, + *, + M: int, # maximum positive value (0 is “skip me”) + dim: int = -1, # axis along which to deduplicate + ignored_val: int = 0, # value to ignore + pad_val: int = -1, # sentinel for unused slots ) -> torch.Tensor: """ - Keeps the first occurrence of each non-zero value while preserving order, @@ -236,8 +253,7 @@ def unique_static_unsorted( first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin") # ── keep mask: first occurrence *and* value ≠ 0 ───────────────────── - keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat) - ) # [B, N] + keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)) # [B, N] # ── left-pack uniques into a fresh tensor ─────────────────────────── dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go @@ -251,8 +267,9 @@ def unique_static_unsorted( return packed -def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, - kv_idx: torch.Tensor): +def causal_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor +): return q_idx >= kv_idx @@ -317,8 +334,7 @@ def _convert_physical_to_logical( physical_kv_block = physical_kv_idx // self.block_size physical_kv_offset = physical_kv_idx % self.block_size logical_block_idx = self.physical_to_logical[q_req, physical_kv_block] - logical_kv_idx = (logical_block_idx * self.block_size + - physical_kv_offset) + logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # Determine valid kv indices live_block = logical_block_idx >= 0 @@ -352,9 +368,9 @@ def final_mask_mod( q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - (is_valid, logical_q_idx, - logical_kv_idx) = self._convert_physical_to_logical( - self.doc_ids, q_idx, physical_kv_idx) + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) # Apply mask modification only for valid indices return torch.where( is_valid, @@ -392,11 +408,11 @@ def get_sliding_window_mask_mod(self) -> _mask_mod_signature: """ if self.sliding_window is None: - 
raise ValueError( - "sliding_window must be set for sliding window attention") + raise ValueError("sliding_window must be set for sliding window attention") - def sliding_window_mask_mod(b: torch.Tensor, h: torch.Tensor, - q_idx: torch.Tensor, kv_idx: torch.Tensor): + def sliding_window_mask_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ): return torch.abs(q_idx - kv_idx) < self.sliding_window def final_mask_mod( @@ -405,9 +421,9 @@ def final_mask_mod( q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - (is_valid, logical_q_idx, - logical_kv_idx) = self._convert_physical_to_logical( - self.doc_ids, q_idx, physical_kv_idx) + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx) + ) return torch.where( is_valid, sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx), @@ -451,18 +467,19 @@ def transformed_score_mod( q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - (is_valid, logical_q_idx, - logical_kv_idx) = self._convert_physical_to_logical( - request_lookup, q_idx, physical_kv_idx) + (is_valid, logical_q_idx, logical_kv_idx) = ( + self._convert_physical_to_logical( + request_lookup, q_idx, physical_kv_idx + ) + ) return torch.where( is_valid, - user_score_mod(score, - b, - h, - logical_q_idx, - logical_kv_idx, - physical_q=q_idx), -float('inf')) + user_score_mod( + score, b, h, logical_q_idx, logical_kv_idx, physical_q=q_idx + ), + -float("inf"), + ) return transformed_score_mod @@ -493,18 +510,22 @@ def _build_block_mask_direct(self) -> BlockMask: f"FlexAttention currently requires the cache block size " f"({self.block_size}) to be equal to the kv_block_size " f"({self.kv_block_size}). Please check your model's " - f"configuration.") + f"configuration." 
+ ) used_pages = self.block_table[ - self.doc_ids, :cdiv(self.max_seq_len, self.block_size)] - used_pages_padded = pad_to_multiple(used_pages, - multiple=self.q_block_size, - dim=0) + self.doc_ids, : cdiv(self.max_seq_len, self.block_size) + ] + used_pages_padded = pad_to_multiple( + used_pages, multiple=self.q_block_size, dim=0 + ) used_pages_padded = used_pages_padded.reshape( - used_pages_padded.shape[0] // self.q_block_size, -1) + used_pages_padded.shape[0] // self.q_block_size, -1 + ) used_pages_padded = used_pages_padded // page_to_block_ratio - kv_indices = unique_static_unsorted((used_pages_padded.long()), - M=self.num_blocks).to(torch.int32) + kv_indices = unique_static_unsorted( + (used_pages_padded.long()), M=self.num_blocks + ).to(torch.int32) kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) block_mask_kwargs = { @@ -524,8 +545,7 @@ def _build_block_mask_direct(self) -> BlockMask: def build_block_mask(self) -> BlockMask: mask_mod = self.get_mask_mod() - kv_len = (self.total_cache_tokens - if self.causal else self.num_actual_tokens) + kv_len = self.total_cache_tokens if self.causal else self.num_actual_tokens return create_block_mask_compiled( mask_mod, None, @@ -555,11 +575,14 @@ def __post_init__(self): self.block_mask = self.build_block_mask() -class FlexAttentionMetadataBuilder( - AttentionMetadataBuilder[FlexAttentionMetadata]): - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): +class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config @@ -567,26 +590,27 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.cache_config = vllm_config.cache_config self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0") - self.q_block_size: int = 16 if is_torch_equal_or_newer( - "2.9.0.dev0") else 128 - self.kv_block_size: int = 16 if is_torch_equal_or_newer( - "2.9.0.dev0") else 128 + self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 + self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: return False - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlexAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlexAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -609,15 +633,18 @@ def build(self, max_possible_seq_len = self.model_config.max_model_len num_gpu_blocks 
= self.cache_config.num_gpu_blocks - assert num_gpu_blocks is not None, \ + assert num_gpu_blocks is not None, ( "FlexAttention requires num_gpu_blocks to be set" - total_cache_tokens = (num_gpu_blocks * block_size) + ) + total_cache_tokens = num_gpu_blocks * block_size inverse_block_table = physical_to_logical_mapping( - block_table_tensor, seq_lens, block_size, num_gpu_blocks) + block_table_tensor, seq_lens, block_size, num_gpu_blocks + ) offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) out = FlexAttentionMetadata( causal=common_attn_metadata.causal, @@ -675,14 +702,15 @@ def __init__( self.num_kv_heads = num_kv_heads self.attn_type = attn_type - if attn_type not in (AttentionType.ENCODER_ONLY, - AttentionType.DECODER): + if attn_type not in (AttentionType.ENCODER_ONLY, AttentionType.DECODER): raise NotImplementedError( - f"FlexAttention does not support {attn_type} attention") + f"FlexAttention does not support {attn_type} attention" + ) if alibi_slopes is not None: raise NotImplementedError( - "FlexAttention does not support alibi slopes yet.") + "FlexAttention does not support alibi slopes yet." + ) else: self.alibi_slopes = None @@ -692,19 +720,20 @@ def __init__( self.logits_soft_cap = logits_soft_cap if self.logits_soft_cap is not None: raise NotImplementedError( - "FlexAttention does not support logits soft cap yet.") + "FlexAttention does not support logits soft cap yet." + ) assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads if kv_sharing_target_layer_name is not None: - raise NotImplementedError( - "FlexAttention does not support kv sharing yet.") + raise NotImplementedError("FlexAttention does not support kv sharing yet.") FlexAttentionBackend.validate_head_size(head_size) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "FlexAttention does not support quantized kv-cache. Yet") + "FlexAttention does not support quantized kv-cache. Yet" + ) @staticmethod def view_as_4d(tensor: torch.Tensor) -> torch.Tensor: @@ -741,8 +770,8 @@ def forward( assert output is not None, "Output tensor must be provided." if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlexAttentionImpl") + "fused output quantization is not yet supported for FlexAttentionImpl" + ) enable_gqa = self.num_kv_heads != self.num_heads @@ -761,11 +790,11 @@ def forward( # in direct block mask building code path. logger.warning_once( "Using direct block mask building with sliding window, " - "which is suboptimal now. Performance may be degraded.") + "which is suboptimal now. Performance may be degraded." + ) # update mask mod in attention metadata attn_metadata.mask_mod = attn_metadata.get_mask_mod() - attn_metadata.block_mask = ( - attn_metadata._build_block_mask_direct()) + attn_metadata.block_mask = attn_metadata._build_block_mask_direct() else: attn_metadata.block_mask = attn_metadata.build_block_mask() @@ -778,8 +807,9 @@ def forward( ) query = query[:, :, :num_actual_tokens, :] - if ((key_tensor.size(-2) > num_actual_tokens) - or (value_tensor.size(-2) > num_actual_tokens)): + if (key_tensor.size(-2) > num_actual_tokens) or ( + value_tensor.size(-2) > num_actual_tokens + ): # In the encoder-only model with torch.compile, # qkv might be padded, which might cause exception. 
# see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290 @@ -803,8 +833,7 @@ def forward( # View out the block_size dim key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) - value_cache = value_cache.view(-1, self.num_kv_heads, - self.head_size) + value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size) query, key_tensor, value_tensor = map( lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), (query, key_cache, value_cache), @@ -818,8 +847,9 @@ def forward( assert attn_metadata.block_mask is not None block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE - kernel_options = get_kernel_options(query, block_m, block_n, - attn_metadata.direct_build) + kernel_options = get_kernel_options( + query, block_m, block_n, attn_metadata.direct_build + ) out = flex_attention_compiled( query, key_tensor, @@ -837,8 +867,9 @@ def forward( return output -def get_kernel_options(query, block_m, block_n, - use_direct_build: bool) -> dict[str, Union[int, bool]]: +def get_kernel_options( + query, block_m, block_n, use_direct_build: bool +) -> dict[str, Union[int, bool]]: kernel_options: dict[str, Union[int, bool]] = { "FORCE_USE_FLEX_ATTENTION": True, } diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 11f165d6cfc6..21fc2ab72768 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Backend for GatedDeltaNet attention.""" + from dataclasses import dataclass from typing import Optional @@ -9,16 +10,17 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - compute_causal_conv1d_metadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class GDNAttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: return GDNAttentionMetadataBuilder @@ -36,19 +38,21 @@ class GDNAttentionMetadata: has_initial_state: Optional[torch.Tensor] = None - spec_query_start_loc: Optional[ - torch.Tensor] = None # shape: [num_spec_decodes + 1,] - non_spec_query_start_loc: Optional[ - torch.Tensor] = None # shape: [batch - num_spec_decodes + 1,] - - spec_state_indices_tensor: Optional[ - torch.Tensor] = None # shape: [batch, num_spec] - non_spec_state_indices_tensor: Optional[ - torch.Tensor] = None # shape: [batch - num_spec_decodes,] + spec_query_start_loc: Optional[torch.Tensor] = ( + None # shape: [num_spec_decodes + 1,] + ) + non_spec_query_start_loc: Optional[torch.Tensor] = ( + None # shape: [batch - num_spec_decodes + 1,] + ) + + spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec] + non_spec_state_indices_tensor: Optional[torch.Tensor] = ( + None # shape: [batch - num_spec_decodes,] + ) spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,] - spec_token_masks: Optional[ - torch. 
- Tensor] = None # shape: [num_prefill_tokens + num_decode_tokens,] + spec_token_masks: Optional[torch.Tensor] = ( + None # shape: [num_prefill_tokens + num_decode_tokens,] + ) num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,] # The following attributes are for triton implementation of causal_conv1d @@ -57,32 +61,37 @@ class GDNAttentionMetadata: token_chunk_offset_ptr: Optional[torch.Tensor] = None -class GDNAttentionMetadataBuilder( - AttentionMetadataBuilder[GDNAttentionMetadata]): - +class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]): cudagraph_support = AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): assert isinstance(kv_cache_spec, MambaSpec) self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.speculative_config = vllm_config.speculative_config self.kv_cache_spec = kv_cache_spec if self.speculative_config: - self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501 + self.num_spec = self.speculative_config.num_speculative_tokens else: self.num_spec = 0 self.use_spec_decode = self.num_spec > 0 self._init_reorder_batch_threshold(1, self.use_spec_decode) - self.use_full_cuda_graph = \ + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) self.decode_cudagraph_max_bs = min( - self.vllm_config.scheduler_config.max_num_seqs * - (self.num_spec + 1), self.compilation_config.max_capture_size) + self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1), + self.compilation_config.max_capture_size, + ) self.spec_state_indices_tensor = torch.empty( (self.decode_cudagraph_max_bs, self.num_spec + 1), @@ -90,32 +99,32 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], device=device, ) self.non_spec_state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) self.spec_sequence_masks = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.bool, device=device, ) self.spec_token_masks = torch.empty( - (self.decode_cudagraph_max_bs * (self.num_spec + 1), ), + (self.decode_cudagraph_max_bs * (self.num_spec + 1),), dtype=torch.bool, device=device, ) self.spec_query_start_loc = torch.empty( - (self.decode_cudagraph_max_bs + 1, ), + (self.decode_cudagraph_max_bs + 1,), dtype=torch.int32, device=device, ) self.non_spec_query_start_loc = torch.empty( - (self.decode_cudagraph_max_bs + 1, ), + (self.decode_cudagraph_max_bs + 1,), dtype=torch.int32, device=device, ) self.num_accepted_tokens = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) @@ -135,9 +144,14 @@ def build( # type: ignore[override] context_lens_tensor = context_lens.to(query_start_loc.device) nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None - if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None - or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= - 0].sum().item() == 0): + if ( + not self.use_spec_decode + or num_decode_draft_tokens_cpu is None + or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0] + .sum() + .item() + == 0 + ): 
spec_sequence_masks = None num_spec_decodes = 0 else: @@ -147,11 +161,13 @@ def build( # type: ignore[override] spec_sequence_masks = None else: spec_sequence_masks = spec_sequence_masks.to( - query_start_loc.device, non_blocking=True) + query_start_loc.device, non_blocking=True + ) if spec_sequence_masks is None: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(m, decode_threshold=1)) + split_decodes_and_prefills(m, decode_threshold=1) + ) num_spec_decode_tokens = 0 spec_token_masks = None spec_state_indices_tensor = None @@ -166,45 +182,56 @@ def build( # type: ignore[override] num_decodes = (non_spec_query_lens == 1).sum().item() num_prefills = non_spec_query_lens.size(0) - num_decodes num_decode_tokens = num_decodes - num_prefill_tokens = non_spec_query_lens.sum().item( - ) - num_decode_tokens + num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens if num_prefills == 0 and num_decodes == 0: spec_token_masks = torch.ones( - (min(num_spec_decodes * - (self.num_spec + 1), query_start_loc[-1].item())), + ( + min( + num_spec_decodes * (self.num_spec + 1), + query_start_loc[-1].item(), + ) + ), dtype=torch.bool, - device=query_start_loc.device) - spec_state_indices_tensor = m.block_table_tensor[:, :self. - num_spec + 1] + device=query_start_loc.device, + ) + spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1] non_spec_state_indices_tensor = None spec_query_start_loc = query_start_loc non_spec_query_start_loc = None else: spec_token_masks = torch.repeat_interleave( - spec_sequence_masks, query_lens) + spec_sequence_masks, query_lens + ) spec_state_indices_tensor = m.block_table_tensor[ - spec_sequence_masks, :self.num_spec + 1] - non_spec_state_indices_tensor = \ - m.block_table_tensor[~spec_sequence_masks, 0] + spec_sequence_masks, : self.num_spec + 1 + ] + non_spec_state_indices_tensor = m.block_table_tensor[ + ~spec_sequence_masks, 0 + ] spec_query_start_loc = torch.zeros( num_spec_decodes + 1, dtype=torch.int32, - device=query_start_loc.device) - torch.cumsum(query_lens[spec_sequence_masks], - dim=0, - out=spec_query_start_loc[1:]) + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:] + ) non_spec_query_start_loc = torch.zeros( query_lens.size(0) - num_spec_decodes + 1, dtype=torch.int32, - device=query_start_loc.device) - torch.cumsum(query_lens[~spec_sequence_masks], - dim=0, - out=non_spec_query_start_loc[1:]) - - num_spec_decode_tokens = (query_lens.sum().item() - - num_prefill_tokens - num_decode_tokens) + device=query_start_loc.device, + ) + torch.cumsum( + query_lens[~spec_sequence_masks], + dim=0, + out=non_spec_query_start_loc[1:], + ) + + num_spec_decode_tokens = ( + query_lens.sum().item() - num_prefill_tokens - num_decode_tokens + ) assert num_accepted_tokens is not None num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] @@ -212,12 +239,14 @@ def build( # type: ignore[override] has_initial_state = context_lens_tensor > 0 if spec_sequence_masks is not None: has_initial_state = has_initial_state[~spec_sequence_masks] - nums_dict, batch_ptr, token_chunk_offset_ptr = \ + nums_dict, batch_ptr, token_chunk_offset_ptr = ( compute_causal_conv1d_metadata(non_spec_query_start_loc) + ) else: has_initial_state = None - num_actual_tokens = num_prefill_tokens + num_decode_tokens + \ - num_spec_decode_tokens + num_actual_tokens = ( + num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens + ) # prepare tensors for 
cudagraph # @@ -226,64 +255,71 @@ def build( # type: ignore[override] # # In above cases, the max possible batch size for n tokens, can be # min(n, cudagraph_max_bs). - if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0 - and num_spec_decodes <= self.decode_cudagraph_max_bs - and num_spec_decode_tokens <= self.decode_cudagraph_max_bs): - num_actual_tokens = self.vllm_config.pad_for_cudagraph( - m.num_actual_tokens) + if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_decodes == 0 + and num_spec_decodes <= self.decode_cudagraph_max_bs + and num_spec_decode_tokens <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens) self.spec_state_indices_tensor[:num_spec_decodes].copy_( - spec_state_indices_tensor, non_blocking=True) - spec_state_indices_tensor = self.spec_state_indices_tensor[: - batch_size] + spec_state_indices_tensor, non_blocking=True + ) + spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size] spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID) self.spec_sequence_masks[:num_spec_decodes].copy_( - spec_sequence_masks, non_blocking=True) + spec_sequence_masks, non_blocking=True + ) spec_sequence_masks = self.spec_sequence_masks[:batch_size] spec_sequence_masks[num_spec_decodes:].fill_(False) assert spec_token_masks is not None - self.spec_token_masks[:spec_token_masks.size(0)].copy_( - spec_token_masks, non_blocking=True) + self.spec_token_masks[: spec_token_masks.size(0)].copy_( + spec_token_masks, non_blocking=True + ) spec_token_masks = self.spec_token_masks[:num_actual_tokens] - spec_token_masks[spec_token_masks.size(0):].fill_(False) + spec_token_masks[spec_token_masks.size(0) :].fill_(False) - self.spec_query_start_loc[:num_spec_decodes + 1].copy_( - spec_query_start_loc, non_blocking=True) - spec_num_query_tokens = spec_query_start_loc[ - -1] # type: ignore[index] - spec_query_start_loc = self.spec_query_start_loc[:batch_size + 1] - spec_query_start_loc[num_spec_decodes + - 1:].fill_(spec_num_query_tokens) + self.spec_query_start_loc[: num_spec_decodes + 1].copy_( + spec_query_start_loc, non_blocking=True + ) + spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index] + spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1] + spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens) self.num_accepted_tokens[:num_spec_decodes].copy_( - num_accepted_tokens, non_blocking=True) + num_accepted_tokens, non_blocking=True + ) num_accepted_tokens = self.num_accepted_tokens[:batch_size] num_accepted_tokens[num_spec_decodes:].fill_(1) - if (self.use_full_cuda_graph and num_prefills == 0 - and num_spec_decodes == 0 - and num_decodes <= self.decode_cudagraph_max_bs): - num_actual_tokens = self.vllm_config.pad_for_cudagraph( - m.num_actual_tokens) + if ( + self.use_full_cuda_graph + and num_prefills == 0 + and num_spec_decodes == 0 + and num_decodes <= self.decode_cudagraph_max_bs + ): + num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens) batch_size = num_actual_tokens self.non_spec_state_indices_tensor[:num_decodes].copy_( - non_spec_state_indices_tensor, non_blocking=True) - non_spec_state_indices_tensor = \ - self.non_spec_state_indices_tensor[:batch_size] + non_spec_state_indices_tensor, non_blocking=True + ) + non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[ + :batch_size + ] 
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID) - self.non_spec_query_start_loc[:num_decodes + 1].copy_( - non_spec_query_start_loc, non_blocking=True) - non_spec_num_query_tokens = non_spec_query_start_loc[ - -1] # type: ignore[index] - non_spec_query_start_loc = \ - self.non_spec_query_start_loc[:batch_size + 1] - non_spec_query_start_loc[num_decodes + - 1:].fill_(non_spec_num_query_tokens) + self.non_spec_query_start_loc[: num_decodes + 1].copy_( + non_spec_query_start_loc, non_blocking=True + ) + non_spec_num_query_tokens = non_spec_query_start_loc[-1] # type: ignore[index] + non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1] + non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens) attn_metadata = GDNAttentionMetadata( num_prefills=num_prefills, @@ -308,7 +344,8 @@ def build( # type: ignore[override] return attn_metadata def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): + self, common_attn_metadata: CommonAttentionMetadata + ): """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba. @@ -317,16 +354,17 @@ def build_for_cudagraph_capture( assert ( m.num_reqs <= self.decode_cudagraph_max_bs - and m.num_actual_tokens <= self.decode_cudagraph_max_bs), ( - f"GDN only supports decode-only full CUDAGraph capture. " - f"Make sure batch size ({m.num_reqs}) <= " - f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), " - f"and number of tokens ({m.num_actual_tokens}) <= " - f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).") + and m.num_actual_tokens <= self.decode_cudagraph_max_bs + ), ( + f"GDN only supports decode-only full CUDAGraph capture. " + f"Make sure batch size ({m.num_reqs}) <= " + f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), " + f"and number of tokens ({m.num_actual_tokens}) <= " + f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})." 
+ ) num_accepted_tokens = torch.diff(m.query_start_loc) num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu() m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu() - return self.build(0, m, num_accepted_tokens, - num_decode_draft_tokens_cpu) + return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu) diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index 0dc62d668020..1900c50849ec 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -6,14 +6,15 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec class LinearAttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]: return LinearAttentionMetadataBuilder @@ -31,20 +32,25 @@ class LinearAttentionMetadata: state_indices_tensor: torch.Tensor # shape: [batch,] -class LinearAttentionMetadataBuilder( - AttentionMetadataBuilder[LinearAttentionMetadata]): - +class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]): reorder_batch_threshold: int = 1 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> LinearAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> LinearAttentionMetadata: query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -52,8 +58,9 @@ def build(self, num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) attn_metadata = LinearAttentionMetadata( num_prefills=num_prefills, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 7cbfa2c2c9a5..e305cb2d8702 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -8,14 +8,14 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + split_decodes_and_prefills, +) class Mamba1AttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: return 
Mamba1AttentionMetadataBuilder @@ -35,8 +35,8 @@ class Mamba1AttentionMetadata: class Mamba1AttentionMetadataBuilder( - BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]): - + BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata] +): def build( self, common_prefix_len: int, @@ -47,24 +47,30 @@ def build( state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to( - query_start_loc.device) + query_start_loc.device + ) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) has_initial_states = None padded_decodes = num_decodes if num_prefills > 0: has_initial_states = context_lens_tensor > 0 - elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs - and self.compilation_config.full_cuda_graph): + elif ( + num_decodes > 0 + and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph + ): state_indices_for_decode = state_indices_tensor[:num_decodes] padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes) self.state_indices_tensor[:num_decodes].copy_( - state_indices_for_decode, non_blocking=True) + state_indices_for_decode, non_blocking=True + ) state_indices_tensor = self.state_indices_tensor[:padded_decodes] state_indices_tensor[num_decodes:] = PAD_SLOT_ID diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 68b6ff73ba3f..10f09442d82e 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -9,12 +9,13 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig from vllm.utils import cdiv -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (PAD_SLOT_ID, - CommonAttentionMetadata, - compute_causal_conv1d_metadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + PAD_SLOT_ID, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec @@ -68,27 +69,26 @@ def compute_varlen_chunk_metadata( # Exclusive prefix sum over logical-chunk lengths if chunk_lens: - cu_chunk_seqlens = torch.tensor([0] + - list(itertools.accumulate(chunk_lens)), - device=device, - dtype=torch.int32) + cu_chunk_seqlens = torch.tensor( + [0] + list(itertools.accumulate(chunk_lens)), + device=device, + dtype=torch.int32, + ) # Final boundary must equal total tokens assert int(cu_chunk_seqlens[-1].item()) == total else: cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32) - last_chunk_indices_t = (torch.tensor( - last_chunk_indices, device=device, dtype=torch.int32) - if len(starts) > 0 else torch.empty( - (0, ), device=device, dtype=torch.int32)) - seq_idx_chunks_t = torch.tensor(seq_idx_chunks, - device=device, - dtype=torch.int32) + last_chunk_indices_t = ( + torch.tensor(last_chunk_indices, device=device, dtype=torch.int32) + if len(starts) > 0 + else torch.empty((0,), device=device, dtype=torch.int32) + ) + seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32) return cu_chunk_seqlens, last_chunk_indices_t, 
seq_idx_chunks_t class Mamba2AttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: return Mamba2AttentionMetadataBuilder @@ -122,6 +122,10 @@ class Mamba2AttentionMetadata: last_chunk_indices_p: Optional[torch.Tensor] state_indices_tensor: torch.Tensor # shape: [batch,] + block_idx_last_scheduled_token: torch.Tensor # shape: [batch,] + block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,] + block_idx_last_computed_token: torch.Tensor # shape: [batch,] + num_computed_tokens_p: torch.Tensor # shape: [batch,] # The following attributes are for triton implementation of causal_conv1d nums_dict: Optional[dict] = None @@ -130,19 +134,48 @@ class Mamba2AttentionMetadata: class Mamba2AttentionMetadataBuilder( - BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]): - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata] +): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() assert self.chunk_size is not None, ( - "chunk_size needs to be set in the model config for Mamba2 models") - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> Mamba2AttentionMetadata: + "chunk_size needs to be set in the model config for Mamba2 models" + ) + if self.vllm_config.cache_config.enable_prefix_caching: + self.state_indices_tensor = torch.empty( + ( + self.decode_cudagraph_max_bs, + cdiv( + vllm_config.model_config.max_model_len, kv_cache_spec.block_size + ), + ), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_scheduled_token = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + self.block_idx_last_computed_token = torch.empty( + (self.decode_cudagraph_max_bs,), + dtype=torch.int32, + device=device, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> Mamba2AttentionMetadata: num_reqs = common_attn_metadata.num_reqs seq_lens = common_attn_metadata.seq_lens @@ -158,31 +191,80 @@ def build(self, # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None - state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + num_computed_tokens, num_computed_tokens_p = None, None + block_idx_first_scheduled_token = None + block_idx_first_scheduled_token_p = None + + if self.vllm_config.cache_config.enable_prefix_caching: + # Return a tensor of shape (#requests, #max blocks) + state_indices_tensor = common_attn_metadata.block_table_tensor + # Additional cache-related varaiables: + mamba_block_size = self.kv_cache_spec.block_size + num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to( + self.device + ) + # Block index of the last computed token + block_idx_last_computed_token = ( + cdiv(num_computed_tokens, mamba_block_size) - 1 + ) + # which is <= block index for the first scheduled token + block_idx_first_scheduled_token = ( + cdiv(num_computed_tokens + 1, mamba_block_size) - 1 + ) + # which is <= block index of the last scheduled token + block_idx_last_scheduled_token = ( + cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1 + ) + # -1 
in case it's non-computed and causes later issues with indexing + block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0) + else: + # Always return just a single block per each request: + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + # Additional cache-related varaiables: + block_idx_last_scheduled_token = None + block_idx_last_computed_token = None num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) # Compute seq_idx for prefill only if num_prefills > 0: - #[batch,] + # [batch,] has_initial_states_cpu = ( - common_attn_metadata. - num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + > 0 + ) prep_initial_states = torch.any(has_initial_states_cpu).item() has_initial_states_p = has_initial_states_cpu.to( - common_attn_metadata.query_start_loc.device) - - query_start_loc_p = common_attn_metadata.query_start_loc[ - -num_prefills - 1:] - num_decode_tokens - - num_computed_tokens_p = \ - common_attn_metadata.num_computed_tokens_cpu[ - num_reqs - num_prefills:num_reqs] - query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ - -num_prefills - 1:] - num_decode_tokens + common_attn_metadata.query_start_loc.device + ) + + query_start_loc_p = ( + common_attn_metadata.query_start_loc[-num_prefills - 1 :] + - num_decode_tokens + ) + + if self.vllm_config.cache_config.enable_prefix_caching: + assert num_computed_tokens is not None + num_computed_tokens_p = num_computed_tokens[ + num_reqs - num_prefills : num_reqs + ] + assert block_idx_first_scheduled_token is not None + block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[ + num_reqs - num_prefills : num_reqs + ] + num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + query_start_loc_p_cpu = ( + common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] + - num_decode_tokens + ) # The code below carefully constructs the chunks such that: # 1. Chunks contain tokens from a *single* sequence only. @@ -199,9 +281,11 @@ def build(self, last_chunk_indices = [] seqlen_pos = 0 for req_idx in range(num_prefills): - this_num_computed = num_computed_tokens_p[req_idx].item() - this_new_tokens = query_start_loc_p_cpu[req_idx + 1].item( - ) - query_start_loc_p_cpu[req_idx].item() + this_num_computed = num_computed_tokens_p_cpu[req_idx].item() + this_new_tokens = ( + query_start_loc_p_cpu[req_idx + 1].item() + - query_start_loc_p_cpu[req_idx].item() + ) # if computed tokens are not chunk-aligned, use the first # chunk to finish it off @@ -209,8 +293,10 @@ def build(self, seq_idx.append(req_idx) cu_chunk_seqlen.append(seqlen_pos) # how many tokens to finish the chunk? 
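# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch above: the loop here answers
# "how many tokens to finish the chunk?" with a cdiv-based round-up to the
# next chunk boundary, capped by the new tokens scheduled for the request.
# A minimal standalone recreation assuming plain-int inputs (the helper
# names below are hypothetical, for illustration only):
def _cdiv(a: int, b: int) -> int:
    # Ceiling division, matching vllm.utils.cdiv.
    return -(-a // b)

def _first_chunk_len(num_computed: int, num_new: int, chunk_size: int) -> int:
    # Tokens needed to realign with the chunk grid, never more than the
    # tokens actually being prefilled in this step.
    to_boundary = _cdiv(num_computed, chunk_size) * chunk_size - num_computed
    return min(to_boundary, num_new)

# Example: chunk_size=256, 300 tokens already computed, 500 new tokens ->
# the first chunk carries 512 - 300 = 212 tokens, after which the remaining
# 288 tokens split into one full 256-token chunk plus a 32-token tail.
assert _first_chunk_len(300, 500, 256) == 212
# ---------------------------------------------------------------------------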
- chunk_len = cdiv(this_num_computed, self.chunk_size - ) * self.chunk_size - this_num_computed + chunk_len = ( + cdiv(this_num_computed, self.chunk_size) * self.chunk_size + - this_num_computed + ) # we can only use at most this_new_tokens chunk_len = min(chunk_len, this_new_tokens) seqlen_pos += chunk_len @@ -229,29 +315,49 @@ def build(self, cu_chunk_seqlen.append(seqlen_pos) - seq_idx_p = torch.as_tensor(seq_idx, - device=query_start_loc_p.device, - dtype=torch.int32) + seq_idx_p = torch.as_tensor( + seq_idx, device=query_start_loc_p.device, dtype=torch.int32 + ) cu_chunk_seqlen_p = torch.as_tensor( - cu_chunk_seqlen, - device=query_start_loc_p.device, - dtype=torch.int32) + cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32 + ) last_chunk_indices_p = torch.as_tensor( - last_chunk_indices, - device=query_start_loc_p.device, - dtype=torch.int32) + last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32 + ) - nums_dict, batch_ptr, token_chunk_offset_ptr = \ + nums_dict, batch_ptr, token_chunk_offset_ptr = ( compute_causal_conv1d_metadata(query_start_loc_p) + ) - elif num_decodes <= self.decode_cudagraph_max_bs: + elif ( + num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph + ): # Pad state tensor for CUDA graph num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) - self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor, - non_blocking=True) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_tensor, non_blocking=True + ) state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID + if self.vllm_config.cache_config.enable_prefix_caching: + self.block_idx_last_scheduled_token[:num_decodes].copy_( + block_idx_last_scheduled_token, non_blocking=True + ) + block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ + :num_input_tokens + ] + block_idx_last_scheduled_token[num_decodes:] = 0 + + self.block_idx_last_computed_token[:num_decodes].copy_( + block_idx_last_computed_token, non_blocking=True + ) + block_idx_last_computed_token = self.block_idx_last_computed_token[ + :num_input_tokens + ] + block_idx_last_computed_token[num_decodes:] = 0 + attn_metadata = Mamba2AttentionMetadata( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, @@ -269,5 +375,9 @@ def build(self, nums_dict=nums_dict, batch_ptr=batch_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr, + block_idx_last_scheduled_token=block_idx_last_scheduled_token, + block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, + block_idx_last_computed_token=block_idx_last_computed_token, + num_computed_tokens_p=num_computed_tokens_p, ) return attn_metadata diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index ef342ce421ae..5aafb9813df0 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -7,9 +7,11 @@ import torch from vllm.config import VllmConfig -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec M = TypeVar("M") @@ -17,35 +19,44 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): reorder_batch_threshold: int = 1 - cudagraph_support: 
ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + ) + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) assert isinstance(kv_cache_spec, MambaSpec) self.compilation_config = vllm_config.compilation_config self.decode_cudagraph_max_bs = min( self.vllm_config.scheduler_config.max_num_seqs, - self.compilation_config.max_capture_size) + self.compilation_config.max_capture_size, + ) self.state_indices_tensor = torch.empty( - (self.decode_cudagraph_max_bs, ), + (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba. """ m = common_attn_metadata - assert m.num_reqs == m.num_actual_tokens, \ - "Mamba only supports decode-only full CUDAGraph capture. " \ + assert m.num_reqs == m.num_actual_tokens, ( + "Mamba only supports decode-only full CUDAGraph capture. " "Make sure all cudagraph capture sizes <= max_num_seq." + ) m.max_query_len = 1 # decode-only diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 963f1c5abf2a..f7ec18f5e9f6 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -197,9 +197,12 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata, - MLAAttentionImpl) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + AttentionMetadata, + MLAAttentionImpl, +) from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.common import cp_lse_ag_out_rs from vllm.attention.ops.merge_attn_states import merge_attn_states @@ -207,21 +210,26 @@ from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + UnquantizedLinearMethod, +) from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm.utils.flashinfer import has_nvidia_artifactory -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - get_per_layer_parameters, - infer_global_hyperparameters, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, + get_per_layer_parameters, + infer_global_hyperparameters, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec try: from vllm.vllm_flash_attn import flash_attn_varlen_func + is_vllm_fa = True except ImportError: # For rocm use upstream flash attention @@ -231,26 +239,29 @@ try: from flashinfer import BatchPrefillWithRaggedKVCacheWrapper - from flashinfer.prefill import ( # noqa: F401 - 
cudnn_batch_prefill_with_kv_cache) + from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache # noqa: F401 + flashinfer_available = True except ImportError: flashinfer_available = False def is_rocm_aiter_fp8bmm_enabled() -> bool: - return current_platform.is_rocm() \ - and envs.VLLM_ROCM_USE_AITER_FP8BMM \ + return ( + current_platform.is_rocm() + and envs.VLLM_ROCM_USE_AITER_FP8BMM and envs.VLLM_ROCM_USE_AITER + ) if is_rocm_aiter_fp8bmm_enabled(): - from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant - as aiter_triton_fp8_bmm) + from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm, # noqa: E501 + ) def dynamic_per_batched_tensor_quant( - x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn): + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn + ): DTYPE_MAX = torch.finfo(dtype).max min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10) @@ -265,7 +276,6 @@ def dynamic_per_batched_tensor_quant( class MLACommonBackend(AttentionBackend): - accept_output_buffer: bool = True @staticmethod @@ -307,12 +317,13 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @dataclass class MLACommonPrefillMetadata: - """ Prefill Specific Metadata """ + """Prefill Specific Metadata""" @dataclass class ChunkedContextMetadata: @@ -340,16 +351,15 @@ class ChunkedContextMetadata: @dataclass class FlashInferPrefillMetadata(MLACommonPrefillMetadata): - prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None - prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field( - default_factory=list) + prefill_main: Optional["BatchPrefillWithRaggedKVCacheWrapper"] = None + prefill_chunks: list["BatchPrefillWithRaggedKVCacheWrapper"] = field( + default_factory=list + ) @dataclass class CudnnPrefillMetadata(MLACommonPrefillMetadata): - - class ChunkedContextMetadata( - MLACommonPrefillMetadata.ChunkedContextMetadata): + class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata): seq_lens: torch.Tensor query_seq_lens: Optional[torch.Tensor] = None @@ -372,6 +382,7 @@ class MLACommonMetadata(Generic[D]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -398,9 +409,9 @@ class MLACommonMetadata(Generic[D]): head_dim: Optional[int] = None decode: Optional[D] = None - prefill: Optional[Union[MLACommonPrefillMetadata, - FlashInferPrefillMetadata, - CudnnPrefillMetadata]] = None + prefill: Optional[ + Union[MLACommonPrefillMetadata, FlashInferPrefillMetadata, CudnnPrefillMetadata] + ] = None def __post_init__(self): if self.head_dim is not None: @@ -414,15 +425,21 @@ def __post_init__(self): def use_flashinfer_prefill() -> bool: # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. 
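# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch above: the two predicates that
# follow pick the MLA prefill path. Read together, cuDNN prefill is used only
# when explicitly requested (and FlashInfer plus an SM100 device plus the
# NVIDIA artifactory are available), FlashInfer prefill is the SM100 default
# unless disabled, and everything else falls back to the FlashAttention path.
# A simplified standalone recreation with the environment/platform checks
# passed in as plain booleans (function and parameter names are hypothetical):
def _pick_prefill_backend(
    flashinfer_available: bool,
    is_sm100: bool,
    use_cudnn_env: bool,
    disable_flashinfer_env: bool,
    has_artifactory: bool,
) -> str:
    if flashinfer_available and use_cudnn_env and is_sm100 and has_artifactory:
        return "cudnn"
    if (flashinfer_available and not disable_flashinfer_env
            and not use_cudnn_env and is_sm100):
        return "flashinfer"
    return "flash_attn"

# On a non-SM100 GPU the FlashAttention path is always chosen.
assert _pick_prefill_backend(True, False, False, False, True) == "flash_attn"
# ---------------------------------------------------------------------------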
- return (not envs.VLLM_DISABLE_FLASHINFER_PREFILL and flashinfer_available - and not envs.VLLM_USE_CUDNN_PREFILL - and current_platform.is_device_capability(100)) + return ( + not envs.VLLM_DISABLE_FLASHINFER_PREFILL + and flashinfer_available + and not envs.VLLM_USE_CUDNN_PREFILL + and current_platform.is_device_capability(100) + ) def use_cudnn_prefill() -> bool: - return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL - and current_platform.is_device_capability(100) - and has_nvidia_artifactory()) + return ( + flashinfer_available + and envs.VLLM_USE_CUDNN_PREFILL + and current_platform.is_device_capability(100) + and has_nvidia_artifactory() + ) # Currently 394MB, this can be tuned based on GEMM sizes used. @@ -436,19 +453,21 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): NOTE: Please read the comment at the top of the file before trying to understand this class """ + reorder_batch_threshold: int = 1 @staticmethod - def determine_chunked_prefill_workspace_size( - vllm_config: VllmConfig) -> int: + def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: scheduler_config = vllm_config.scheduler_config cache_config = vllm_config.cache_config model_config = vllm_config.model_config chunked_prefill_workspace_size = min( # Try for 8 full length request or at least 4 pages per-request - max(8 * model_config.max_model_len, - 4 * scheduler_config.max_num_seqs * cache_config.block_size), + max( + 8 * model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * cache_config.block_size, + ), # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens, # which would result in the workspace being: @@ -457,23 +476,28 @@ def determine_chunked_prefill_workspace_size( # which would result in up-projected context being # 2*(192*128)*(64*1024) = 3gb # (assuming 192 QK head dim, 128 heads, and fp16) - 64 * 1024) + 64 * 1024, + ) # Enforce that we enough for at least 1 page per request chunked_prefill_workspace_size = max( chunked_prefill_workspace_size, - scheduler_config.max_num_seqs * cache_config.block_size) + scheduler_config.max_num_seqs * cache_config.block_size, + ) return chunked_prefill_workspace_size - def __init__(self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[type[M]] = None): - self.metadata_cls = metadata_cls \ - if metadata_cls is not None else MLACommonMetadata + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[type[M]] = None, + ): + self.metadata_cls = ( + metadata_cls if metadata_cls is not None else MLACommonMetadata + ) self.kv_cache_spec = kv_cache_spec scheduler_config = vllm_config.scheduler_config self.model_config = vllm_config.model_config @@ -481,8 +505,7 @@ def __init__(self, self.compilation_config = vllm_config.compilation_config self.device = device - self.num_heads = self.model_config.get_num_attention_heads( - parallel_config) + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.aot_schedule = current_platform.is_cuda() try: @@ -497,27 +520,31 @@ def __init__(self, if self.aot_schedule: self.page_size = self.kv_cache_spec.block_size - self.chunked_prefill_workspace_size = \ + self.chunked_prefill_workspace_size = ( self.determine_chunked_prefill_workspace_size(vllm_config) + ) if self.dcp_world_size > 1: # 
Note(hc): The local kvcache is incomplete when DCP is triggered, # an additional kvcache allgather across the DCP group is therefore # required, so the workspace has to be enlarged by 1/DCP relative # to the original TP allocation. - assert self.chunked_prefill_workspace_size % \ - self.dcp_world_size == 0 + assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0 self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size + - self.chunked_prefill_workspace_size // self.dcp_world_size, - self.model_config.get_head_size()), + ( + self.chunked_prefill_workspace_size + + self.chunked_prefill_workspace_size // self.dcp_world_size, + self.model_config.get_head_size(), + ), dtype=self.model_config.dtype, device=device, ) else: self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), + ( + self.chunked_prefill_workspace_size, + self.model_config.get_head_size(), + ), dtype=self.model_config.dtype, device=device, ) @@ -526,23 +553,23 @@ def __init__(self, self._use_fi_prefill = use_flashinfer_prefill() self.prefill_metadata_cls = ( FlashInferPrefillMetadata - if self._use_fi_prefill else CudnnPrefillMetadata - if self._use_cudnn_prefill else MLACommonPrefillMetadata) + if self._use_fi_prefill + else CudnnPrefillMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata + ) if self._use_fi_prefill: self._workspace_buffer = torch.empty( - FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, - device=device) + FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device + ) - self._fi_prefill_main: Optional[ - BatchPrefillWithRaggedKVCacheWrapper] = None - self._fi_prefill_chunks: list[ - BatchPrefillWithRaggedKVCacheWrapper] = [] + self._fi_prefill_main: Optional[BatchPrefillWithRaggedKVCacheWrapper] = None + self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = [] self._global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(vllm_config, layer_names, - MLACommonImpl)) + get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) + ) if self._use_cudnn_prefill: self.cudnn_workspace = torch.empty( @@ -561,7 +588,8 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): if self._fi_prefill_main is None: self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper( - self._workspace_buffer, "NHD", backend="cutlass") + self._workspace_buffer, "NHD", backend="cutlass" + ) if has_context: num_chunks = chunked_context.cu_seq_lens.shape[0] @@ -570,7 +598,9 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): for _ in range(len(self._fi_prefill_chunks), num_chunks): self._fi_prefill_chunks.append( BatchPrefillWithRaggedKVCacheWrapper( - self._workspace_buffer, "NHD", backend="cutlass")) + self._workspace_buffer, "NHD", backend="cutlass" + ) + ) assert num_chunks <= len(self._fi_prefill_chunks) # In MLA, the non-latent num_qo_heads == num_kv_heads @@ -581,8 +611,7 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): assert self.kv_cache_spec.num_kv_heads == 1 # Get non-latent head_dim_qk and head_dim_vo - head_dim_qk = (self.mla_dims.qk_nope_head_dim + - self.mla_dims.qk_rope_head_dim) + head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim head_dim_vo = self.mla_dims.v_head_dim # For main run, qo_indptr == kv_indptr @@ -618,45 +647,50 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): causal=False, # This is context run 
sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, - logits_soft_cap=self._global_hyperparameters. - logits_soft_cap, + logits_soft_cap=self._global_hyperparameters.logits_soft_cap, q_data_type=self.model_config.dtype, ) prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> MLACommonDecodeMetadata: + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, ) def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with MLA. """ m = common_attn_metadata - assert m.num_reqs <= (m.num_actual_tokens * - self.reorder_batch_threshold), \ - "MLA only supports decode-only full CUDAGraph capture. " \ + assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), ( + "MLA only supports decode-only full CUDAGraph capture. " "Make sure all cudagraph capture sizes <= max_num_seq." + ) assert m.max_query_len <= self.reorder_batch_threshold # decode only return self.build(0, m) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -676,18 +710,19 @@ def build(self, query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - - query_seq_lens_cpu) + num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=self.reorder_batch_threshold) + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) # Note(hc): update seq_lens of decode reqs under DCP. if self.dcp_world_size > 1: - seq_lens[:num_decodes] = seq_lens[:num_decodes] \ - // self.dcp_world_size + (self.dcp_rank <= \ - (seq_lens[:num_decodes] - 1) % self.dcp_world_size) + seq_lens[:num_decodes] = seq_lens[:num_decodes] // self.dcp_world_size + ( + self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size + ) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -698,13 +733,15 @@ def build(self, context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] # Note(hc): The context lengths in the perspective of dcp rank0. 
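# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch above: under decode context
# parallelism (DCP) each rank holds only a 1/world_size shard of the cached
# KV, so the per-rank decode seq_len computed above is n // world_size plus
# one extra token on the ranks indexed up to (n - 1) % world_size. A minimal
# recreation with plain ints (function name is hypothetical):
def _dcp_local_seq_len(total_len: int, dcp_rank: int, dcp_world_size: int) -> int:
    return total_len // dcp_world_size + (dcp_rank <= (total_len - 1) % dcp_world_size)

# Example: 10 cached tokens on 4 ranks -> ranks 0 and 1 see 3 tokens,
# ranks 2 and 3 see 2 tokens, and the shares sum back to 10.
lens = [_dcp_local_seq_len(10, r, 4) for r in range(4)]
assert lens == [3, 3, 2, 2] and sum(lens) == 10
# ---------------------------------------------------------------------------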
- cp_context_lens_cpu = torch.ceil(context_lens_cpu.float() / - self.dcp_world_size).int() + cp_context_lens_cpu = torch.ceil( + context_lens_cpu.float() / self.dcp_world_size + ).int() origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] + prefill_query_start_loc = ( + query_start_loc[reqs_start:] - query_start_loc[reqs_start] + ) chunked_context_metadata = None if max_context_len_cpu > 0: @@ -716,16 +753,16 @@ def build(self, # prefill in the batch, we could probably use a more advanced # algorithm here and allocate more workspace to prefills with # longer context lengths - max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) + max_context_chunk = ( + self.chunked_prefill_workspace_size // num_prefills_with_context_cpu + ) if self.aot_schedule: # align max_context_chunk to page_size by rounding down, # currently the `gather_and_maybe_dequant_cache` kernel # cannot handle `context_chunk_starts` that are not aligned # to page_size - max_context_chunk = round_down(max_context_chunk, - self.page_size) + max_context_chunk = round_down(max_context_chunk, self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) @@ -736,22 +773,23 @@ def build(self, # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] # Note(simon): this is done in CPU because of downstream's # of `to_list`. - chunk_starts = \ - torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) \ + chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) * max_context_chunk - chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), - chunk_starts + max_context_chunk) + ) + chunk_ends = torch.min( + context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk + ) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(chunk_seq_lens, - dim=1, - out=cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) + cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32 + ) if self.dcp_world_size > 1: # Note(hc): The above max_context_chunk already enforces @@ -760,36 +798,37 @@ def build(self, # cp_gather_cache which not require `cp_chunk_starts` # aligned to page_size. 
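# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch above: the hunk here splits each
# prefill's cached context into fixed-size chunks that fit the shared
# workspace, then builds per-chunk cumulative sequence lengths for the gather
# kernels. A simplified standalone recreation (the tensor values are made up
# for illustration, and cu_seq_lens is filled with a plain cumsum instead of
# the in-place, pinned-memory variant used in the patch):
import torch

max_context_chunk = 4                      # workspace // num_prefills, page-aligned
context_lens = torch.tensor([6, 3])        # cached context per prefill request
num_chunks = -(-int(context_lens.max()) // max_context_chunk)   # cdiv -> 2

chunk_starts = (
    torch.arange(num_chunks, dtype=torch.int32)
    .unsqueeze(1)
    .expand(-1, len(context_lens))
    * max_context_chunk
)
chunk_ends = torch.min(context_lens.unsqueeze(0), chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

cu_seq_lens = torch.zeros(num_chunks, len(context_lens) + 1, dtype=torch.int32)
cu_seq_lens[:, 1:] = torch.cumsum(chunk_seq_lens, dim=1)

# Chunk 0 covers tokens [0, 4) of request 0 and [0, 3) of request 1;
# chunk 1 covers the remaining [4, 6) of request 0 and nothing of request 1.
assert chunk_seq_lens.tolist() == [[4, 3], [2, 0]]
assert cu_seq_lens.tolist() == [[0, 4, 7], [0, 2, 2]]
# ---------------------------------------------------------------------------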
assert max_context_chunk % self.dcp_world_size == 0 - cp_max_context_chunk = max_context_chunk // \ - self.dcp_world_size - cp_chunk_starts = \ - torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) \ + cp_max_context_chunk = max_context_chunk // self.dcp_world_size + cp_chunk_starts = ( + torch.arange(num_chunks, dtype=torch.int32) + .unsqueeze(1) + .expand(-1, num_prefills) * cp_max_context_chunk + ) cp_chunk_ends = torch.min( cp_context_lens_cpu.unsqueeze(0), - cp_chunk_starts + cp_max_context_chunk) - cp_chunk_seq_lens = (cp_chunk_ends - - cp_chunk_starts).clamp(min=0) - - cp_cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(cp_chunk_seq_lens, - dim=1, - out=cp_cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) - - chunked_context_metadata_cls = \ - CudnnPrefillMetadata.ChunkedContextMetadata \ - if self._use_cudnn_prefill else \ - MLACommonPrefillMetadata.ChunkedContextMetadata + cp_chunk_starts + cp_max_context_chunk, + ) + cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0) + + cp_cu_seq_lens_cpu = torch.zeros( + num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True + ) + torch.cumsum( + cp_chunk_seq_lens, + dim=1, + out=cp_cu_seq_lens_cpu[:, 1:], + dtype=torch.int32, + ) + + chunked_context_metadata_cls = ( + CudnnPrefillMetadata.ChunkedContextMetadata + if self._use_cudnn_prefill + else MLACommonPrefillMetadata.ChunkedContextMetadata + ) if self.dcp_world_size > 1: - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu \ - .to(device, non_blocking=True), + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=cp_chunk_starts.to(device, non_blocking=True), seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), @@ -797,16 +836,13 @@ def build(self, workspace=self.chunked_prefill_workspace, cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), origin_context_lens=origin_context_lens, - cp_cu_seq_lens=cp_cu_seq_lens_cpu \ - .to(device, non_blocking=True), + cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True), chunk_size=max_context_chunk, cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), ) else: - chunked_context_metadata = \ - chunked_context_metadata_cls( - cu_seq_lens=cu_seq_lens_cpu \ - .to(device, non_blocking=True), + chunked_context_metadata = chunked_context_metadata_cls( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), @@ -817,8 +853,10 @@ def build(self, if self._use_cudnn_prefill: chunked_context_metadata.seq_lens = chunk_seq_lens - assert max(chunked_context_metadata.max_seq_lens) <= \ - self.chunked_prefill_workspace_size + assert ( + max(chunked_context_metadata.max_seq_lens) + <= self.chunked_prefill_workspace_size + ) prefill_metadata = self.prefill_metadata_cls( block_table=block_table_tensor[reqs_start:, ...], @@ -829,8 +867,9 @@ def build(self, if self._use_cudnn_prefill: assert isinstance(prefill_metadata, CudnnPrefillMetadata) - prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \ - - prefill_query_start_loc[:-1] + prefill_metadata.query_seq_lens = ( + prefill_query_start_loc[1:] - prefill_query_start_loc[:-1] + ) prefill_metadata.cudnn_workspace = self.cudnn_workspace decode_metadata = None @@ -839,8 +878,8 @@ def 
build(self, block_table_tensor=block_table_tensor[:num_decodes, ...], seq_lens_cpu=seq_lens_cpu[:num_decodes], seq_lens_device=seq_lens[:num_decodes], - query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], - query_start_loc_device=query_start_loc[:num_decodes + 1], + query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], + query_start_loc_device=query_start_loc[: num_decodes + 1], num_decode_tokens=num_decode_tokens, ) @@ -897,12 +936,14 @@ def reorg_kvcache( k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 - for cp_chunk_seq_len, origin_context_len in zip(cp_chunk_seq_lens_lst, - origin_context_lens): + for cp_chunk_seq_len, origin_context_len in zip( + cp_chunk_seq_lens_lst, origin_context_lens + ): chunk_context_len = chunk_size if cp_chunk_seq_len != 0: chunk_context_len = min( - chunk_context_len, origin_context_len - chunk_size * chunk_idx) + chunk_context_len, origin_context_len - chunk_size * chunk_idx + ) cp_target_rank = (chunk_context_len - 1) % cp_world_size cur_seq_len = 0 for rank in range(cp_world_size): @@ -911,14 +952,16 @@ def reorg_kvcache( else: real_cp_chunk_seq_len = cp_chunk_seq_len if real_cp_chunk_seq_len: - kv_c_segment = allgatered_kv_c_normed[rank * toks + - src_token_idx:rank * - toks + src_token_idx + - real_cp_chunk_seq_len] - k_pe_segment = allgatered_k_pe[rank * toks + - src_token_idx:rank * toks + - src_token_idx + - real_cp_chunk_seq_len] + kv_c_segment = allgatered_kv_c_normed[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + real_cp_chunk_seq_len + ] + k_pe_segment = allgatered_k_pe[ + rank * toks + src_token_idx : rank * toks + + src_token_idx + + real_cp_chunk_seq_len + ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) cur_seq_len += real_cp_chunk_seq_len @@ -983,25 +1026,24 @@ def __init__( self.q_pad_num_heads = q_pad_num_heads def process_weights_after_loading(self, act_dtype: torch.dtype): - def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: if hasattr(layer, attr): return getattr(layer, attr) raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") + f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." 
+ ) def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) del eye # standardize to (output, input) return dequant_weights.T @@ -1013,12 +1055,14 @@ def get_and_maybe_dequant_weights(layer: LinearBase): kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, @@ -1026,15 +1070,18 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) if is_rocm_aiter_fp8bmm_enabled(): W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype()) + W_K, dtype=current_platform.fp8_dtype() + ) self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype()) + W_V, dtype=current_platform.fp8_dtype() + ) # The kernel operates on non-padded inputs. 
Hence, pre-compiling # triton kernel to avoid runtime compilation for unseen batch sizes @@ -1050,23 +1097,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) for m in pre_compilation_list: - x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device) - aiter_triton_fp8_bmm(x, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) - - x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device) - aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + aiter_triton_fp8_bmm( + x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) + + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + aiter_triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) else: # Convert from (L, N, V) to (N, L, V) self.W_UV = W_UV.transpose(0, 1) @@ -1078,11 +1125,9 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) if is_rocm_aiter_fp8bmm_enabled(): # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) + x = aiter_triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) # Convert from (B, N, V) to (B, N * V) x = x.reshape(-1, self.num_heads * self.v_head_dim) # Copy result @@ -1095,8 +1140,7 @@ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" # Convert from (N, B, V) to (B, N * V) - out_new = out.transpose(0, 1).reshape( - -1, self.num_heads * self.v_head_dim) + out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) # Adjust output buffer shape back to the original (B, N * V) N, B, V = out.shape @@ -1120,8 +1164,7 @@ def __init__(self, *args, **kwargs) -> None: self._pad_v = False elif use_cudnn_prefill(): logger.debug_once("Using CUDNN prefill for MLA") - self._run_prefill_context_chunk = \ - self._run_prefill_context_chunk_cudnn + self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn self._pad_v = False else: # Use FlashAttention @@ -1136,9 +1179,9 @@ def __init__(self, *args, **kwargs) -> None: self.flash_attn_varlen_func = flash_attn_varlen_func self.vllm_flash_attn_version = get_flash_attn_version() if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) + self.flash_attn_varlen_func = functools.partial( + flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version + ) # For MLA the v head dim is smaller than qk head dim so we pad out # v with 0s to match the qk head dim for attention backends that do @@ -1146,25 +1189,25 @@ def __init__(self, *args, **kwargs) -> None: # We don't need to pad V if we are on a hopper system with FA3 self._pad_v = self.vllm_flash_attn_version is None or not ( self.vllm_flash_attn_version == 3 - and current_platform.get_device_capability()[0] == 9) + and current_platform.get_device_capability()[0] == 9 + ) self.dcp_world_size: Optional[int] = None - self.chunked_prefill_workspace_size = \ + 
self.chunked_prefill_workspace_size = ( MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( - get_current_vllm_config()) - - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): + get_current_vllm_config() + ) + ) + + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): maybe_padded_v = v if self._pad_v: maybe_padded_v = torch.nn.functional.pad( - v, [0, q.shape[-1] - v.shape[-1]], value=0) + v, [0, q.shape[-1] - v.shape[-1]], value=0 + ) if is_vllm_fa: kwargs["return_softmax_lse"] = return_softmax_lse @@ -1192,8 +1235,9 @@ def _flash_attn_varlen_diff_headdims(self, return attn_out, lse return attn_out - def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q, - k, v, return_softmax_lse): + def _run_prefill_new_tokens_fa( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): return self._flash_attn_varlen_diff_headdims( q=q, k=k, @@ -1207,8 +1251,9 @@ def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q, return_softmax_lse=return_softmax_lse, ) - def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q, - k, v, return_softmax_lse): + def _run_prefill_new_tokens_fi( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): assert isinstance(prefill, FlashInferPrefillMetadata) assert prefill.prefill_main is not None ret = prefill.prefill_main.run( @@ -1223,8 +1268,9 @@ def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q, return ret[0], ret[1].transpose(0, 1).contiguous() return ret - def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, - q, k, v, return_softmax_lse): + def _run_prefill_new_tokens_cudnn( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.query_seq_lens is not None output, lse = cudnn_batch_prefill_with_kv_cache( @@ -1238,16 +1284,18 @@ def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1), causal=True, - return_lse=True, # do not support False for now - is_cuda_graph_compatible= - True, #Indicates actual_seq_lens are on GPU or CPU. + # Do not support False for now + return_lse=True, + # Indicates actual_seq_lens are on GPU or CPU. 
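# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch above: because MLA's value head
# dim (e.g. 128) is smaller than its QK head dim (e.g. 192), the _pad_v path
# above zero-pads V up to the QK head dim before the varlen attention kernel
# and later slices the padded tail off the output. A minimal recreation of
# just that padding/unpadding step (shapes below are made up for illustration):
import torch

q = torch.randn(4, 16, 192)                 # (tokens, heads, qk_head_dim)
v = torch.randn(4, 16, 128)                 # (tokens, heads, v_head_dim)

padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)
assert padded_v.shape[-1] == q.shape[-1]
assert torch.equal(padded_v[..., : v.shape[-1]], v)            # values intact
assert torch.count_nonzero(padded_v[..., v.shape[-1]:]) == 0   # zero-filled tail

# After attention runs on the padded tensors, the output is truncated back:
#   output = output[..., : v.shape[-1]]
# ---------------------------------------------------------------------------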
+ is_cuda_graph_compatible=True, ) if return_softmax_lse: return output, lse return output - def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_fa( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert prefill.chunked_context is not None return self._flash_attn_varlen_diff_headdims( q=q, @@ -1262,8 +1310,9 @@ def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata, return_softmax_lse=True, ) - def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_fi( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert isinstance(prefill, FlashInferPrefillMetadata) attn_out, lse = prefill.prefill_chunks[chunk_idx].run( q=q, @@ -1274,9 +1323,9 @@ def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, # Convert from (q_len, num_heads) to (num_heads, q_len) return attn_out, lse.transpose(0, 1).contiguous() - def _run_prefill_context_chunk_cudnn(self, - prefill: MLACommonPrefillMetadata, - chunk_idx: int, q, k, v): + def _run_prefill_context_chunk_cudnn( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): assert isinstance(prefill, CudnnPrefillMetadata) assert prefill.chunked_context is not None assert prefill.chunked_context.seq_lens[chunk_idx] is not None @@ -1290,34 +1339,34 @@ def _run_prefill_context_chunk_cudnn(self, max_token_per_sequence=prefill.max_query_len, max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx], actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), - actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx]. - view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view( + -1, 1, 1, 1 + ), causal=False, return_lse=True, - is_cuda_graph_compatible= - True, #Indicates actual_seq_lens are on GPU or CPU. + # Indicates actual_seq_lens are on GPU or CPU. + is_cuda_graph_compatible=True, ) def process_weights_after_loading(self, act_dtype: torch.dtype): - def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: if hasattr(layer, attr): return getattr(layer, attr) raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") + f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}." 
+ ) def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) + eye = torch.eye( + layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device, + ) + dequant_weights = layer.quant_method.apply(layer, eye, bias=None) del eye # standardize to (output, input) return dequant_weights.T @@ -1329,12 +1378,14 @@ def get_and_maybe_dequant_weights(layer: LinearBase): kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}" + ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, @@ -1342,15 +1393,18 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) if is_rocm_aiter_fp8bmm_enabled(): W_K = W_UK.transpose(0, 1) # 16 512 128 W_V = W_UV.permute(1, 2, 0) # 16 128 512 self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype()) + W_K, dtype=current_platform.fp8_dtype() + ) self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype()) + W_V, dtype=current_platform.fp8_dtype() + ) # The kernel operates on non-padded inputs. 
Hence, pre-compiling # triton kernel to avoid runtime compilation for unseen batch sizes @@ -1366,23 +1420,23 @@ def get_and_maybe_dequant_weights(layer: LinearBase): ) for m in pre_compilation_list: - x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device) - aiter_triton_fp8_bmm(x, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) - - x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device) - aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) + x = torch.empty( + (self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device, + ) + aiter_triton_fp8_bmm( + x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) + + x = torch.empty( + (self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device, + ) + aiter_triton_fp8_bmm( + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + ) else: # Convert from (L, N, V) to (N, L, V) self.W_UV = W_UV.transpose(0, 1) @@ -1418,18 +1472,15 @@ def _compute_prefill_context( seq_starts=prefill_metadata.chunked_context.starts[i], ) - kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank] - k_pe = workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) + kv_c_normed = workspace[:toks][..., : self.kv_lora_rank] + k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1494,44 +1545,45 @@ def _context_parallel_compute_prefill_context( # |------- N tokens --------|--------- N*dcp_size tokens ----------| # |<- use for loca_gather ->|<--------- use for allgather -------->| allgather_offset = workspace.shape[0] // (dcp_world_size + 1) - assert allgather_offset * (dcp_world_size + - 1) == workspace.shape[0] + assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0] assert toks <= allgather_offset local_gathered_kvcache = workspace[:toks] cur_allgather_workspace = workspace[ - allgather_offset:allgather_offset * (1 + dcp_world_size)] + allgather_offset : allgather_offset * (1 + dcp_world_size) + ] assert toks * dcp_world_size <= cur_allgather_workspace.shape[0] - cur_allgather_kvcache = cur_allgather_workspace[:toks * - dcp_world_size] - cur_allgather_kvcache.copy_(get_dcp_group().all_gather( - local_gathered_kvcache, dim=0)) - assert cur_allgather_kvcache.shape[ - -1] == self.kv_lora_rank + self.qk_rope_head_dim - allgatered_kv_c_normed, allgatered_k_pe = \ - cur_allgather_kvcache.unsqueeze( - 1).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size] + cur_allgather_kvcache.copy_( + get_dcp_group().all_gather(local_gathered_kvcache, dim=0) + ) + assert ( + cur_allgather_kvcache.shape[-1] + == self.kv_lora_rank + self.qk_rope_head_dim + ) + allgatered_kv_c_normed, allgatered_k_pe = 
cur_allgather_kvcache.unsqueeze( + 1 + ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed, k_pe = reorg_kvcache( allgatered_kv_c_normed, allgatered_k_pe, - cp_chunk_seq_lens_lst=prefill_metadata.chunked_context. - cp_chunk_seq_lens[i], - origin_context_lens=prefill_metadata.chunked_context. - origin_context_lens, + cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ + i + ], + origin_context_lens=prefill_metadata.chunked_context.origin_context_lens, cp_world_size=dcp_world_size, - sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i] - [-1], + sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], chunk_size=prefill_metadata.chunked_context.chunk_size, chunk_idx=i, - toks=toks) + toks=toks, + ) - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -1574,10 +1626,10 @@ def _forward_prefill( assert self.dcp_world_size is not None has_context = attn_metadata.prefill.chunked_context is not None - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim + ) + k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) @@ -1592,14 +1644,19 @@ def _forward_prefill( if has_context: suffix_output, suffix_lse = output if self.dcp_world_size > 1: - context_output, context_lse = \ + context_output, context_lse = ( self._context_parallel_compute_prefill_context( - q, kv_c_and_k_pe_cache, attn_metadata, - k_scale=None, dcp_world_size=self.dcp_world_size) + q, + kv_c_and_k_pe_cache, + attn_metadata, + k_scale=None, + dcp_world_size=self.dcp_world_size, + ) + ) else: - context_output, context_lse = \ - self._compute_prefill_context( - q, kv_c_and_k_pe_cache, attn_metadata, k_scale) + context_output, context_lse = self._compute_prefill_context( + q, kv_c_and_k_pe_cache, attn_metadata, k_scale + ) output = torch.empty_like(suffix_output) merge_attn_states( @@ -1612,7 +1669,7 @@ def _forward_prefill( # unpad if necessary if self._pad_v: - output = output[..., :v.shape[-1]] + output = output[..., : v.shape[-1]] return output.flatten(start_dim=-2) @@ -1642,16 +1699,19 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLACommonImpl") + "fused output quantization is not yet supported for MLACommonImpl" + ) if attn_metadata is None: # During the profile run try to simulate to worse case output size # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` # since this can be large _ = torch.empty( - (self.chunked_prefill_workspace_size, self.num_heads, - 
self.qk_nope_head_dim + self.v_head_dim), + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), device=k_c_normed.device, dtype=k_c_normed.dtype, ) @@ -1675,9 +1735,11 @@ def forward( k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] - assert attn_metadata.num_decodes is not None and \ - attn_metadata.num_prefills is not None and \ - attn_metadata.num_decode_tokens is not None + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 @@ -1705,39 +1767,47 @@ def forward( if has_prefill: output[num_decode_tokens:] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) + prefill_q, + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + attn_metadata, + layer._k_scale, + ) if has_decode: assert attn_metadata.decode is not None decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) # Pads the head_dim if necessary (for the underlying kernel) if self.q_pad_num_heads is not None: B, N, L = decode_q_pe.shape - decode_pe_padded = decode_q_pe.new_empty( - (B, self.q_pad_num_heads, L)) + decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) decode_pe_padded.resize_((B, N, L)) decode_pe_padded.copy_(decode_q_pe) decode_q_pe = decode_pe_padded if is_rocm_aiter_fp8bmm_enabled(): # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) - decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) + decode_ql_nope = aiter_triton_fp8_bmm( + decode_q_nope, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True, + ) else: # Pads the head_dim if necessary (for the underlying kernel) N, B, P = decode_q_nope.shape _, _, L = self.W_UK_T.shape if self.q_pad_num_heads is not None: decode_ql_nope = decode_q_nope.new_empty( - (self.q_pad_num_heads, B, L)) + (self.q_pad_num_heads, B, L) + ) decode_ql_nope.resize_((N, B, L)) else: @@ -1751,15 +1821,17 @@ def forward( if fp8_attention: ql_nope_shape = decode_ql_nope.shape decode_ql_nope, _ = ops.scaled_fp8_quant( - decode_ql_nope.reshape([ - ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2] - ]), layer._q_scale) + decode_ql_nope.reshape( + [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]] + ), + layer._q_scale, + ) decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape) q_pe_shape = decode_q_pe.shape decode_q_pe, _ = ops.scaled_fp8_quant( - decode_q_pe.reshape( - [q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), - layer._q_scale) + decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), + layer._q_scale, + ) decode_q_pe = decode_q_pe.reshape(q_pe_shape) decode_q = (decode_ql_nope, decode_q_pe) @@ -1771,8 +1843,9 @@ def forward( decode_q = get_dcp_group().all_gather(decode_q, dim=1) # call decode attn - attn_out, lse = self._forward_decode(decode_q, kv_cache, - attn_metadata, layer) + attn_out, lse = self._forward_decode( + decode_q, kv_cache, attn_metadata, layer + ) # recorect dcp attn_out with lse. 
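# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch above: one standard way to
# combine partial attention outputs computed over disjoint KV shards (as in
# the DCP correction referenced by the comment here) is to reweight each
# shard's output by exp(lse_i - global_lse). This is shown as background
# only; the patch itself relies on vLLM's own helpers (cp_lse_ag_out_rs /
# merge_attn_states), whose exact behavior is not reproduced here.
import torch

def _merge_partial_attn(outs, lses):
    # outs: list of (heads, q_len, v_dim) partial outputs over disjoint KV shards
    # lses: list of (heads, q_len) log-sum-exp values from the same shards
    lse = torch.logsumexp(torch.stack(lses), dim=0)
    return sum(torch.exp(l - lse).unsqueeze(-1) * o for l, o in zip(lses, outs))

# Check against full-softmax attention on a tiny random example.
heads, q_len, kv_len, d = 2, 3, 8, 4
q, k, v = (torch.randn(heads, q_len, d), torch.randn(heads, kv_len, d),
           torch.randn(heads, kv_len, d))
scores = q @ k.transpose(-1, -2)
full = torch.softmax(scores, dim=-1) @ v

outs, lses = [], []
for shard in (slice(0, 5), slice(5, 8)):
    part = scores[..., shard]
    lses.append(torch.logsumexp(part, dim=-1))
    outs.append(torch.softmax(part, dim=-1) @ v[:, shard])
assert torch.allclose(_merge_partial_attn(outs, lses), full, atol=1e-5)
# ---------------------------------------------------------------------------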
if self.dcp_world_size > 1: diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index d44e20f2cb6b..a3c677ca2108 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -7,13 +7,18 @@ import torch import vllm._custom_ops as ops -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + is_quantized_kv_cache, +) from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport logger = init_logger(__name__) @@ -21,12 +26,12 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): # enable full CUDA Graph support for decode-only capture - cudagraph_support: ClassVar[ - AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + cudagraph_support: ClassVar[AttentionCGSupport] = ( + AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) class CutlassMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "CUTLASS_MLA" @@ -41,11 +46,10 @@ def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: class SM100Workspace: - def __init__(self, initial_workspace_size): - self._workspace_buf = torch.empty(initial_workspace_size, - device="cuda", - dtype=torch.uint8) + self._workspace_buf = torch.empty( + initial_workspace_size, device="cuda", dtype=torch.uint8 + ) self._block_size = 128 # Forced to 128 @@ -57,8 +61,7 @@ def __init__(self, initial_workspace_size): def get_buf(self): return self._workspace_buf - def ensure_size(self, attn_metadata: MLACommonMetadata, - num_kv_splits: int): + def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int): batch_size = attn_metadata.num_reqs max_seq_len = attn_metadata.max_query_len @@ -66,7 +69,8 @@ def ensure_size(self, attn_metadata: MLACommonMetadata, max_seq_len * self._block_size, batch_size, self._sm_count, - num_kv_splits=num_kv_splits) + num_kv_splits=num_kv_splits, + ) if self._workspace_buf.shape[0] < workspace_size: self._workspace_buf.resize_(workspace_size) @@ -81,51 +85,56 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, - head_size, - scale, - num_kv_heads, - alibi_slopes, - sliding_window, - kv_cache_dtype, - logits_soft_cap, - attn_type, - kv_sharing_target_layer_name, - q_pad_num_heads=MAX_HEADS, - **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + 
alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + q_pad_num_heads=MAX_HEADS, + **mla_args, + ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "CutlassMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "CutlassMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "CutlassMLAImpl" + ) # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging # issues. In case the code hangs, use: # FORCE_NUM_KV_SPLITS=1 force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None) if force_num_kv_splits: - logger.warning_once("Forcing num_kv_splits to %d", - int(force_num_kv_splits)) + logger.warning_once("Forcing num_kv_splits to %d", int(force_num_kv_splits)) self._num_kv_splits = int(force_num_kv_splits) else: self._num_kv_splits = -1 # => Auto-detect @@ -144,14 +153,13 @@ def _sm100_cutlass_mla_decode( sm_scale: float, num_kv_splits: int, ) -> tuple[torch.Tensor, torch.Tensor]: - assert (q_nope.ndim == 3 - ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}" - assert ( - q_pe.ndim == 3), f"q_pe must be a 3D tensor, but got {q_pe.ndim}" - assert ( - kv_c_and_k_pe_cache.ndim == 3 - ), "kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format( - kv_c_and_k_pe_cache.ndim) + assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}" + assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}" + assert kv_c_and_k_pe_cache.ndim == 3, ( + "kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format( + kv_c_and_k_pe_cache.ndim + ) + ) B_q, H, D_q_nope = q_nope.shape B_q_2, H_2, D_q_pe = q_pe.shape @@ -171,28 +179,31 @@ def _sm100_cutlass_mla_decode( assert len(page_table.shape) == 2 B_block_table, block_num = page_table.shape assert B_block_table == B_q - assert (block_num - > 0), f"block num must be greater than 0, got {block_num}" + assert block_num > 0, f"block num must be greater than 0, got {block_num}" assert block_num % (128 / PAGE_SIZE) == 0 - assert q_nope.dtype in ( - torch.float16, torch.bfloat16, torch.float8_e4m3fn), ( - f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got " - f"{q_nope.dtype}.") + assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), ( + f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}." + ) assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype - assert ( - seq_lens.dtype == torch.int32 - ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." - assert ( - page_table.dtype == torch.int32 - ), f"page_table.dtype needs to be int32 but got {page_table.dtype}." - - dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype) - else q_nope.dtype) + assert seq_lens.dtype == torch.int32, ( + f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." + ) + assert page_table.dtype == torch.int32, ( + f"page_table.dtype needs to be int32 but got {page_table.dtype}." 
+ ) + + dtype = ( + torch.bfloat16 + if is_quantized_kv_cache(self.kv_cache_dtype) + else q_nope.dtype + ) out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype) - lse = (torch.empty( - (B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) - if self.need_to_return_lse_for_decode else torch.Tensor()) + lse = ( + torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) + if self.need_to_return_lse_for_decode + else torch.Tensor() + ) ops.sm100_cutlass_mla_decode( out, @@ -228,7 +239,8 @@ def _forward_decode( q_nope, q_pe = q else: q_nope, q_pe = torch.split( - q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) # Adjust workspace size (if necessary) self._workspace.ensure_size(attn_metadata, self._num_kv_splits) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 652b1cdb6b76..c0c2dbe1f961 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -7,18 +7,25 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) -from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, - get_flash_attn_version) +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + is_quantized_kv_cache, +) +from vllm.attention.utils.fa_utils import ( + flash_attn_supports_mla, + get_flash_attn_version, +) from vllm.config import VllmConfig from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata @@ -27,7 +34,6 @@ class FlashAttnMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "FLASH_ATTN_MLA" @@ -59,22 +65,27 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): pass -class FlashAttnMLAMetadataBuilder( - MLACommonMetadataBuilder[FlashAttnMLAMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_BATCH +class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH reorder_batch_threshold: int = 512 - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - FlashAttnMLAMetadata) + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata + ) self.max_num_splits = 0 # No upper bound on the number of splits. 
- self.fa_aot_schedule = (get_flash_attn_version() == 3) + self.fa_aot_schedule = get_flash_attn_version() == 3 - self.use_full_cuda_graph = \ + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) if self.use_full_cuda_graph and self.fa_aot_schedule: self.max_cudagraph_size = self.compilation_config.max_capture_size @@ -83,8 +94,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. raise ValueError( - "Capture size larger than 992 is not supported for " - "full cuda graph.") + "Capture size larger than 992 is not supported for full cuda graph." + ) self.scheduler_metadata = torch.zeros( vllm_config.scheduler_config.max_num_seqs + 1, @@ -94,16 +105,17 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are # pre-allocated during capture. - self.max_num_splits = ( - envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH) + self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH # TODO(lucas): Until we add support for the DCP custom masking we need # to restrict decodes to q_len == 1 when DCP is enabled. - self.reorder_batch_threshold = 1 \ - if get_dcp_group().world_size > 1 else self.reorder_batch_threshold + self.reorder_batch_threshold = ( + 1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold + ) - def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + def _schedule_decode( + self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): if self.fa_aot_schedule: return get_scheduler_metadata( batch_size=num_reqs, @@ -122,13 +134,16 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, ) return None - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> FlashAttnMLADecodeMetadata: - query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + ) -> FlashAttnMLADecodeMetadata: + query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] max_query_len = query_lens_cpu.max().item() max_seq_len = seq_lens_cpu.max().item() @@ -146,9 +161,10 @@ def _build_decode(self, block_table_tensor: torch.Tensor, if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] # Ensure the persistent buffer is large enough - assert n <= self.scheduler_metadata.shape[0], \ - f"Scheduler metadata size {n} exceeds buffer size " + \ - f"{self.scheduler_metadata.shape[0]}" + assert n <= self.scheduler_metadata.shape[0], ( + f"Scheduler metadata size {n} exceeds buffer size " + + f"{self.scheduler_metadata.shape[0]}" + ) self.scheduler_metadata[:n] = scheduler_metadata # NOTE(woosuk): We should zero out the rest of the scheduler # metadata to guarantee the correctness. 
Otherwise, some thread @@ -179,42 +195,55 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - assert flash_attn_supports_mla(), \ - "FlashAttnMLA is not supported on this device" + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + + assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device" unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashAttnMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttnMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttnMLAImpl" + ) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "FlashAttnMLA V1 with FP8 KV cache not yet supported") + "FlashAttnMLA V1 with FP8 KV cache not yet supported" + ) def _forward_decode( self, @@ -230,14 +259,14 @@ def _forward_decode( q_nope, q_pe = q else: q_nope, q_pe = torch.split( - q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError( - "FP8 FlashAttention MLA not yet supported") + raise NotImplementedError("FP8 FlashAttention MLA not yet supported") - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] - k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] + k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :] # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the # kernel uses this to calculate grid dimensions. 
Ensure it's at least 1 diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 701248670f72..f0ea1d653c3e 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -8,9 +8,11 @@ from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, +) logger = init_logger(__name__) @@ -18,7 +20,6 @@ class FlashInferMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "FLASHINFER_MLA" @@ -36,37 +37,49 @@ def get_impl_cls() -> type["FlashInferMLAImpl"]: class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashInferMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferMLAImpl" + ) self._workspace_buffer = g_fi_workspace self.bmm1_scale: Optional[float] = None @@ -90,8 +103,7 @@ def _forward_decode( q = q.unsqueeze(1) if self.bmm1_scale is None: - self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float * - self.scale) + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale if self.bmm2_scale is None: self.bmm2_scale = layer._v_scale_float diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 67c21f83cf5d..56480832bcd1 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -7,16 +7,20 @@ import torch from vllm.attention.backends.abstract import AttentionLayer, AttentionType -from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) +from vllm.attention.ops.flashmla import ( + flash_mla_with_kvcache, + get_mla_metadata, 
+ is_flashmla_supported, +) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec @@ -24,7 +28,6 @@ class FlashMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "FLASHMLA" @@ -54,16 +57,22 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_BATCH + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - FlashMLAMetadata) + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata + ) self.num_q_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None @@ -82,19 +91,22 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.cg_buf_num_splits = torch.empty( (vllm_config.scheduler_config.max_num_seqs + 1), device=self.device, - dtype=torch.int32) - - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> FlashMLADecodeMetadata: - tile_scheduler_metadata, num_splits = \ - get_mla_metadata( + dtype=torch.int32, + ) + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + ) -> FlashMLADecodeMetadata: + tile_scheduler_metadata, num_splits = get_mla_metadata( seq_lens_device, self.num_q_heads, - 1, # MQA for the decode path + 1, # MQA for the decode path ) # TODO: we can disambiguate between decode and mixed-prefill decode here @@ -107,8 +119,9 @@ def _build_decode(self, block_table_tensor: torch.Tensor, sm_parts = tile_scheduler_metadata.size(0) # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize) assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0) - tile_scheduler_metadata_view = \ - self.cg_buf_tile_scheduler_metadata[:sm_parts] + tile_scheduler_metadata_view = self.cg_buf_tile_scheduler_metadata[ + :sm_parts + ] tile_scheduler_metadata_view.copy_(tile_scheduler_metadata) tile_scheduler_metadata = tile_scheduler_metadata_view @@ -133,27 +146,36 @@ def _build_decode(self, block_table_tensor: torch.Tensor, class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): - can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - 
sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) is_supported, reason = is_flashmla_supported() assert is_supported, reason @@ -162,13 +184,16 @@ def __init__( if any(unsupported_features): raise NotImplementedError( "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl" + ) def _forward_decode( self, @@ -191,8 +216,7 @@ def _forward_decode( block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=attn_metadata.decode. 
- tile_scheduler_metadata, + tile_scheduler_metadata=attn_metadata.decode.tile_scheduler_metadata, num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 36c3c188042c..21d67f832b7b 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -8,21 +8,28 @@ import torch from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + AttentionMetadata, +) from vllm.attention.backends.utils import get_mla_dims -from vllm.attention.ops.flashmla import (flash_mla_sparse_prefill, - flash_mla_with_kvcache, - get_mla_metadata) +from vllm.attention.ops.flashmla import ( + flash_mla_sparse_prefill, + flash_mla_with_kvcache, + get_mla_metadata, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: @@ -47,11 +54,10 @@ def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor: # Convert base-2 LSE to natural-log LSE # Keep FP32 for numerical stability during the merge. - return (lse_base2.to(torch.float32) * math.log(2.0)) + return lse_base2.to(torch.float32) * math.log(2.0) class FlashMLASparseBackend(AttentionBackend): - accept_output_buffer: bool = True @staticmethod @@ -113,13 +119,14 @@ class FlashMLASparseDecodeAndContextMetadata: dummy_block_table: torch.Tensor = None def filter_prefill_indices( - self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self, indices: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: assert self.prefill_context_lengths is not None prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1) - context_indices = torch.where(indices < prefill_context_lengths, - indices, -1) - new_token_indices = torch.where(indices >= prefill_context_lengths, - indices - prefill_context_lengths, -1) + context_indices = torch.where(indices < prefill_context_lengths, indices, -1) + new_token_indices = torch.where( + indices >= prefill_context_lengths, indices - prefill_context_lengths, -1 + ) return context_indices, new_token_indices @@ -194,8 +201,9 @@ def _convert_req_index_to_global_index_kernel( base = tl.load(bt_ptr, mask=valid_block, other=0) # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset - out_val = tl.where(is_invalid_tok | (~valid_block), -1, - base * BLOCK_SIZE + inblock_off) + out_val = tl.where( + is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off + ) # Store results out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 @@ -203,31 +211,30 @@ def _convert_req_index_to_global_index_kernel( def triton_convert_req_index_to_global_index( - req_id: torch.Tensor, # int32 [num_tokens] - block_table: torch. 
- Tensor, # int32 [num_requests, max_num_blocks_per_req] - token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] - BLOCK_SIZE: int = 64, - NUM_TOPK_TOKENS: int = 2048, - BLOCK_N: int = 128, # tile width along columns + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns ): """ out[token_id, indice_id] = - block_table[req_id[token_id], + block_table[req_id[token_id], token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + token_indices[token_id, indice_id] % BLOCK_SIZE Only when token_indices[token_id, indice_id] == -1 do we output -1. - For safety, we also output -1 if the derived block_id would be + For safety, we also output -1 if the derived block_id would be out-of-bounds. """ assert req_id.dtype == torch.int32 assert block_table.dtype == torch.int32 assert token_indices.dtype == torch.int32 assert token_indices.shape[1] == NUM_TOPK_TOKENS - assert NUM_TOPK_TOKENS % BLOCK_N == 0, \ - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \ - f"BLOCK_N ({BLOCK_N})" + assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" + ) num_tokens = req_id.shape[0] num_requests, max_num_blocks_per_req = block_table.shape @@ -268,14 +275,16 @@ def triton_convert_req_index_to_global_index( @dataclass -class FlashMLASparseMetadataBuilder( - AttentionMetadataBuilder[FlashMLASparseMetadata]): - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_BATCH - - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): +class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec self.model_config = vllm_config.model_config @@ -285,28 +294,27 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], props = torch.cuda.get_device_properties(device) sm_count = props.multi_processor_count - self.num_heads = self.model_config.get_num_attention_heads( - parallel_config) + self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.topk_tokens = vllm_config.model_config.hf_config.index_topk self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" - self.topk_tokens_tensor = torch.tensor([self.topk_tokens], - device=device, - dtype=torch.int32) + self.topk_tokens_tensor = torch.tensor( + [self.topk_tokens], device=device, dtype=torch.int32 + ) self.max_model_len_tensor = torch.tensor( - [self.model_config.max_model_len], - device=device, - dtype=torch.int32) + [self.model_config.max_model_len], device=device, dtype=torch.int32 + ) # this is ignored by `flash_mla_with_kvcache` if indices not None - self.dummy_block_table = torch.empty((1, 1), - dtype=torch.int32, - device=self.device) + self.dummy_block_table = torch.empty( + (1, 1), dtype=torch.int32, device=self.device + ) # Equation taken from FlashMLA/csrc/pybind.cpp h_q, h_k = self.num_heads, 1 s_q = 1 # inversely proportional to
s_q, so s_q = 1 is the largest max_num_sm_parts = int( - max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)) + max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1) + ) if current_platform.is_device_capability(100): max_num_sm_parts *= 2 self.tile_scheduler_metadata_buffer = torch.empty( @@ -314,34 +322,38 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # see: FlashMLA/csrc/params.h (max_num_sm_parts, 8), dtype=torch.int32, - device=device) + device=device, + ) self.num_splits_buffer = torch.empty( # We pack all the tokens into one batch for sparse attention. # Otherwise, we can exceed the sm of `get_mla_metadata`. - ( - 2, ), + (2,), dtype=torch.int32, - device=device) + device=device, + ) self.req_id_per_token_buffer = torch.empty( - (vllm_config.scheduler_config.max_num_batched_tokens, ), + (vllm_config.scheduler_config.max_num_batched_tokens,), dtype=torch.int32, - device=device) - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> FlashMLASparseMetadata: + device=device, + ) + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashMLASparseMetadata: num_tokens = common_attn_metadata.num_actual_tokens - starts = np.asarray(common_attn_metadata.query_start_loc_cpu, - dtype=np.int32) + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) seg_lengths = np.diff(starts) req_id_per_token = np.repeat( - np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths) + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths + ) # Zero-fill for cudagraphs self.req_id_per_token_buffer.fill_(0) - self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\ - .copy_(torch.from_numpy(req_id_per_token), non_blocking=True) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True + ) req_id_per_token = self.req_id_per_token_buffer[:num_tokens] fp8_extra_metadata = None @@ -357,8 +369,9 @@ def build(self, num_sm_parts = tile_scheduler_metadata.size(0) # Copy to persistent buffer for full-CG support - tile_scheduler_metadata_buffer = \ - self.tile_scheduler_metadata_buffer[:num_sm_parts] + tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[ + :num_sm_parts + ] tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) self.num_splits_buffer.copy_(num_splits) @@ -371,7 +384,8 @@ def build(self, # accidentally mark indices invalid, we will use -1 exclusively # to mark invalid indices cache_lens=self.max_model_len_tensor, - dummy_block_table=self.dummy_block_table) + dummy_block_table=self.dummy_block_table, + ) metadata = FlashMLASparseMetadata( num_reqs=common_attn_metadata.num_reqs, @@ -390,62 +404,79 @@ def build(self, class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - topk_indice_buffer: Optional[torch.Tensor] = None, - indexer: Optional["Indexer"] = None, - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, 
**mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + topk_indice_buffer: Optional[torch.Tensor] = None, + indexer: Optional["Indexer"] = None, + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer = indexer.topk_indices_buffer - self.padding = 128 if current_platform.is_device_capability( - 100) else 64 + self.padding = 128 if current_platform.is_device_capability(100) else 64 def _forward_bf16_kv( - self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: num_tokens = q.shape[0] kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( - -1, 1, kv_c_and_k_pe_cache.shape[-1]) + -1, 1, kv_c_and_k_pe_cache.shape[-1] + ) # NOTE(Chen): kernel requires num_local_head to be a multiple of # 64 on hopper and 128 on blackwell if self.num_heads % self.padding != 0: assert self.padding % self.num_heads == 0 - logger.warning_once(f"padding num_heads to {self.padding} \ - due to sparse attn kernel requirement") + logger.warning_once( + f"padding num_heads to {self.padding} \ + due to sparse attn kernel requirement" + ) q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2])) - q_padded[:, :self.num_heads, :] = q + q_padded[:, : self.num_heads, :] = q q = q_padded topk_indices = topk_indices.view(num_tokens, 1, -1) - output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, topk_indices, - self.softmax_scale)[0] - output = output[:, :self.num_heads, :] + output = flash_mla_sparse_prefill( + q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale + )[0] + output = output[:, : self.num_heads, :] return output - def _forward_fp8_kv(self, q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - topk_indices: torch.Tensor, - attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: - + def _forward_fp8_kv( + self, + q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + ) -> torch.Tensor: assert attn_metadata.fp8_extra_metadata is not None extra_metadata = attn_metadata.fp8_extra_metadata @@ -483,8 +514,8 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLACommonImpl") + "fused output quantization is not yet supported for MLACommonImpl" + ) if attn_metadata is None: # The zero fill is required when used with DP + EP @@ -500,8 +531,7 @@ def forward( k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] 
- q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # Convert from (B, N, P) to (N, B, P) q_nope = q_nope.transpose(0, 1) # Multiply (N, B, P) x (N, P, L) -> (N, B, L) @@ -534,11 +564,13 @@ def forward( ) if self.kv_cache_dtype != "fp8_ds_mla": - attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices_global, - attn_metadata) + attn_out = self._forward_bf16_kv( + q, kv_cache, topk_indices_global, attn_metadata + ) else: - attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global, - attn_metadata) + attn_out = self._forward_fp8_kv( + q, kv_cache, topk_indices_global, attn_metadata + ) self._v_up_proj(attn_out, out=output[:num_actual_toks]) return output diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 94b963f34e4a..1344840af6a5 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -5,21 +5,21 @@ import torch -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills, +) logger = init_logger(__name__) class DeepseekV32IndexerBackend(AttentionBackend): - @staticmethod def get_metadata_cls() -> type["AttentionMetadata"]: return DeepseekV32IndexerMetadata @@ -76,7 +76,6 @@ class DeepSeekV32IndexerDecodeMetadata: @dataclass class DeepseekV32IndexerMetadata: - # FIXME (zyongye) # hacky way to access the data now, need to be in chunked meta seq_lens: torch.Tensor @@ -104,27 +103,27 @@ class DeepseekV32IndexerMetadata: # TODO (zyongye) optimize this, this is now vibe coded def kv_spans_from_batches( - start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, - device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: """ Args: - start_seq_loc: 1D long tensor [B+1], cumulative counts of + start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch. - Example: [0, 2, 4, 7] -> + Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total. - seq_len_per_batch: 1D long tensor [B], + seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch. Example: [5, 9, 4]. Returns: - start_tensor: 1D long tensor [N], start offset in the + start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch. - end_location: 1D long tensor [N], + end_location: 1D long tensor [N], **exclusive** end = start + token's local position. (So the attended KV slice is kv[start:end].) - Assumes each batch contributes its full `seq_len_per_batch[i]` - keys to the KV cache, andthe selected tokens within a batch + Assumes each batch contributes its full `seq_len_per_batch[i]` + keys to the KV cache, andthe selected tokens within a batch are the **last** `counts[i]` positions of that sequence. 
""" q = start_seq_loc.to(dtype=torch.long) @@ -138,8 +137,10 @@ def kv_spans_from_batches( B = L.numel() if N == 0: - return (torch.empty(0, dtype=torch.long, device=device), - torch.empty(0, dtype=torch.long, device=device)) + return ( + torch.empty(0, dtype=torch.long, device=device), + torch.empty(0, dtype=torch.long, device=device), + ) # KV start offsets per batch in the concatenated KV cache kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] @@ -155,8 +156,9 @@ def kv_spans_from_batches( L_expand = torch.repeat_interleave(L, counts) # [N] m_expand = torch.repeat_interleave(counts, counts) # [N] # position within the selected block: 1..counts[b] - pos_within = (torch.arange(N, dtype=torch.long) - - torch.repeat_interleave(q[:-1], counts) + 1) + pos_within = ( + torch.arange(N, dtype=torch.long) - torch.repeat_interleave(q[:-1], counts) + 1 + ) local_pos = L_expand - m_expand + pos_within # [N], 1-based end_location = start_tensor + local_pos # exclusive end @@ -171,9 +173,9 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig): return max_model_len * 2 -def split_prefill_chunks(seq_lens_cpu: torch.Tensor, - max_prefill_buffer_size: int, - reqs_start: int) -> list[tuple[int, int]]: +def split_prefill_chunks( + seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int +) -> list[tuple[int, int]]: """ Split the prefill chunks into a list of tuples of (reqs_start, reqs_end) such that the total sequence length of each chunk is less than the @@ -183,7 +185,7 @@ def split_prefill_chunks(seq_lens_cpu: torch.Tensor, seq_lens_cpu: The sequence lengths of the prefill requests. max_prefill_buffer_size: The maximum prefill buffer size. reqs_start: The start index of the prefill requests. - + Returns: A list of tuples of (reqs_start, reqs_end). """ @@ -203,20 +205,22 @@ def split_prefill_chunks(seq_lens_cpu: torch.Tensor, class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) reorder_batch_threshold: int = 1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) scheduler_config = self.vllm_config.scheduler_config - #NOTE(Chen):an estimated max size of flattened_kv. Need to double check. - self.max_prefill_buffer_size = get_max_prefill_buffer_size( - self.vllm_config) + # NOTE(Chen):an estimated max size of flattened_kv. Need to double check. 
+ self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config) self.num_speculative_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0) + if self.vllm_config.speculative_config + else 0 + ) # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2 self.reorder_batch_threshold += min(self.num_speculative_tokens, 1) @@ -225,31 +229,38 @@ def __init__(self, *args, **kwargs): self.num_sms = sm_count self.decode_lens_buffer = torch.empty( - (scheduler_config.max_num_seqs, ), - dtype=torch.int32, - device=self.device) + (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device + ) # See: DeepGMM/csrc/apis/attention.hpp - self.scheduler_metadata_buffer = torch.empty((self.num_sms + 1, 2), - dtype=torch.int32, - device=self.device) - - def build_one_prefill_chunk(self, reqs_start, reqs_end, - query_start_loc_cpu, seq_lens_cpu, - block_table): - prefill_query_start_loc = query_start_loc_cpu[ - reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start] + self.scheduler_metadata_buffer = torch.empty( + (self.num_sms + 1, 2), dtype=torch.int32, device=self.device + ) + + def build_one_prefill_chunk( + self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table + ): + prefill_query_start_loc = ( + query_start_loc_cpu[reqs_start : reqs_end + 1] + - query_start_loc_cpu[reqs_start] + ) cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( - prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], - self.device) + prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device + ) token_start = query_start_loc_cpu[reqs_start].item() token_end = query_start_loc_cpu[reqs_end].item() total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum() assert total_seq_lens <= self.max_prefill_buffer_size - cu_seq_lens = torch.cat([ - torch.zeros(1, dtype=torch.int32), - seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0) - ]).to(torch.int32).to(self.device) + cu_seq_lens = ( + torch.cat( + [ + torch.zeros(1, dtype=torch.int32), + seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0), + ] + ) + .to(torch.int32) + .to(self.device) + ) return DeepseekV32IndexerPrefillChunkMetadata( cu_seqlen_ks=cu_seqlen_ks, cu_seqlen_ke=cu_seqlen_ke, @@ -261,19 +272,21 @@ def build_one_prefill_chunk(self, reqs_start, reqs_end, num_reqs=reqs_end - reqs_start, ) - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> DeepseekV32IndexerMetadata: - + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> DeepseekV32IndexerMetadata: num_reqs = common_attn_metadata.num_reqs num_tokens = common_attn_metadata.num_actual_tokens query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -287,33 +300,39 @@ def build(self, ) chunks = [ self.build_one_prefill_chunk( - reqs_start, reqs_end, query_start_loc_cpu, + reqs_start, + reqs_end, + query_start_loc_cpu, common_attn_metadata.seq_lens_cpu, - common_attn_metadata.block_table_tensor) + 
common_attn_metadata.block_table_tensor, + ) for reqs_start, reqs_end in chunk_seq_ids ] prefill_metadata = DeepseekV32IndexerPrefillMetadata( - chunks=chunks, ) + chunks=chunks, + ) decode_metadata = None if num_decodes > 0: - torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1], - out=self.decode_lens_buffer[:num_decodes]) + torch.diff( + common_attn_metadata.query_start_loc[: num_decodes + 1], + out=self.decode_lens_buffer[:num_decodes], + ) decode_lens = self.decode_lens_buffer[:num_decodes] decode_lens_cpu = torch.diff( - common_attn_metadata.query_start_loc_cpu[:num_decodes + 1]) + common_attn_metadata.query_start_loc_cpu[: num_decodes + 1] + ) # Use CPU to avoid GPU sync; breaking async scheduling - requires_padding = (decode_lens_cpu.max() - > decode_lens_cpu.min()).item() + requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() seq_lens = common_attn_metadata.seq_lens[:num_decodes] self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( - seq_lens, self.kv_cache_spec.block_size, self.num_sms) + seq_lens, self.kv_cache_spec.block_size, self.num_sms + ) decode_metadata = DeepSeekV32IndexerDecodeMetadata( - block_table=common_attn_metadata. - block_table_tensor[:num_decodes, ...], + block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...], seq_lens=common_attn_metadata.seq_lens[:num_decodes], decode_lens=decode_lens, requires_padding=requires_padding, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 79247e569b1c..54ebf071d96f 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -11,26 +11,22 @@ from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd from vllm.config import VllmConfig from vllm.utils import cdiv -# yapf conflicts with isort for this docstring -# yapf: disable -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonDecodeMetadata, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, +) from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec -# yapf: enable - def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_ROCM_USE_AITER_MLA + return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MLA class AiterMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "ROCM_AITER_MLA" @@ -68,19 +64,28 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # TODO(luka, lucas): audit this as part of: # https://github.com/vllm-project/vllm/issues/22945 - cudagraph_support: ClassVar[AttentionCGSupport] = \ + cudagraph_support: ClassVar[AttentionCGSupport] = ( AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + ) - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): - super().__init__(kv_cache_spec, layer_names, vllm_config, device, - AiterMLAMetadata) - assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ - "only supports block size 1." 
+ def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__( + kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata + ) + assert self.kv_cache_spec.block_size == 1, ( + "AITER MLA only supports block size 1." + ) self.compilation_config = vllm_config.compilation_config - max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, - self.kv_cache_spec.block_size) + max_num_pages_per_req = cdiv( + vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size + ) max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_pages = max_num_reqs * max_num_pages_per_req @@ -89,74 +94,78 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # so we can only use the persistent buffer if a cudagraph is actually # being used. if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, - dtype=torch.int32, - device=device) - self.paged_kv_indices = torch.zeros(max_num_pages, - dtype=torch.int32, - device=device) - self.paged_kv_last_page_len = torch.zeros(max_num_reqs, - dtype=torch.int32, - device=device) - - self.qo_indptr = torch.arange(0, - max_num_reqs + 1, - dtype=torch.int32, - device=device) - - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> AiterMLADecodeMetadata: + self.paged_kv_indptr = torch.zeros( + max_num_reqs + 1, dtype=torch.int32, device=device + ) + self.paged_kv_indices = torch.zeros( + max_num_pages, dtype=torch.int32, device=device + ) + self.paged_kv_last_page_len = torch.zeros( + max_num_reqs, dtype=torch.int32, device=device + ) + + self.qo_indptr = torch.arange( + 0, max_num_reqs + 1, dtype=torch.int32, device=device + ) + + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + ) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device num_reqs = seq_lens_device.size(0) - mask = (torch.arange(block_table_tensor.size(1), - dtype=block_table_tensor.dtype, - device=device).unsqueeze(0) - < block_table_bounds.unsqueeze(1)) + mask = torch.arange( + block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device + ).unsqueeze(0) < block_table_bounds.unsqueeze(1) paged_kv_indices = block_table_tensor[mask] paged_kv_last_page_len = seq_lens_device % page_size - paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, - page_size, paged_kv_last_page_len) + paged_kv_last_page_len = torch.where( + paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len + ) - paged_kv_indptr = torch.cat([ - torch.zeros(1, dtype=block_table_bounds.dtype, device=device), - block_table_bounds.cumsum(dim=0, dtype=torch.int32) - ]) + paged_kv_indptr = torch.cat( + [ + torch.zeros(1, dtype=block_table_bounds.dtype, device=device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32), + ] + ) if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - num_actual_pages = paged_kv_indices.size(0) - self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, - non_blocking=True) +
self.paged_kv_indices[:num_actual_pages].copy_( + paged_kv_indices, non_blocking=True + ) self.paged_kv_indices[num_actual_pages:].fill_(-1) paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr, - non_blocking=True) - self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1]) - paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs] + self.paged_kv_indptr[: 1 + num_reqs].copy_( + paged_kv_indptr, non_blocking=True + ) + self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1]) + paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs] self.paged_kv_last_page_len[:num_reqs].copy_( - paged_kv_last_page_len, non_blocking=True) + paged_kv_last_page_len, non_blocking=True + ) self.paged_kv_last_page_len[num_reqs:].fill_(1) paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] - qo_indptr = self.qo_indptr[:1 + num_reqs] + qo_indptr = self.qo_indptr[: 1 + num_reqs] else: - qo_indptr = torch.arange(0, - num_reqs + 1, - step=1, - dtype=torch.int32, - device=device) + qo_indptr = torch.arange( + 0, num_reqs + 1, step=1, dtype=torch.int32, device=device + ) attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, @@ -164,51 +173,60 @@ def _build_decode(self, block_table_tensor: torch.Tensor, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, - qo_indptr=qo_indptr) + qo_indptr=qo_indptr, + ) return attn_metadata class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - assert (num_heads == 16 or num_heads == 128), ( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) + assert num_heads == 16 or num_heads == 128, ( f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Provided {num_heads} number of heads.\n" - "Try adjusting tensor_parallel_size value.") + "Try adjusting tensor_parallel_size value." 
+ ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "Aiter MLA does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): output = self.flash_attn_varlen_func( q=q, k=k, @@ -235,21 +253,25 @@ def _forward_decode( assert isinstance(q, torch.Tensor) B = q.shape[0] - o = torch.zeros(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) + o = torch.zeros( + B, self.num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device + ) kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) # max_seqlen_qo must be 1 except for MTP # TODO: Find the best value for MTP max_seqlen_qo = 1 - aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, - attn_metadata.decode.qo_indptr, max_seqlen_qo, - attn_metadata.decode.paged_kv_indptr, - attn_metadata.decode.paged_kv_indices, - attn_metadata.decode.paged_kv_last_page_len) + aiter_mla_decode_fwd( + q, + kv_buffer, + o, + self.scale, + attn_metadata.decode.qo_indptr, + max_seqlen_qo, + attn_metadata.decode.paged_kv_indptr, + attn_metadata.decode.paged_kv_indices, + attn_metadata.decode.paged_kv_last_page_len, + ) return o, None diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 076152061d50..3b6718c48d09 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -6,22 +6,26 @@ import torch from vllm import envs -from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, - is_quantized_kv_cache) +from vllm.attention.backends.abstract import ( + AttentionLayer, + AttentionType, + is_quantized_kv_cache, +) from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON -from vllm.v1.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) +from vllm.v1.attention.backends.mla.common import ( + MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, +) logger = init_logger(__name__) class TritonMLABackend(MLACommonBackend): - @staticmethod def get_name() -> str: return "TRITON_MLA" @@ -35,54 +39,64 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): can_return_lse_for_decode: bool = True def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: 
Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + **mla_args, + ) unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap" + ) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonMLAImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonMLAImpl" + ) if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( - "TritonMLA V1 with FP8 KV cache not yet supported") + "TritonMLA V1 with FP8 KV cache not yet supported" + ) self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN self.triton_fa_func = triton_attention if HAS_TRITON else None - def _flash_attn_varlen_diff_headdims_rocm(self, - q, - k, - v, - softmax_scale=None, - **kwargs): + def _flash_attn_varlen_diff_headdims_rocm( + self, q, k, v, softmax_scale=None, **kwargs + ): assert self.triton_fa_func is not None # Triton Attention requires a padded V - padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) + padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0) # The output of triton_attention is a tuple of # [output_tensor, encoded_softmax] where encoded_softmax is always None output_tensor, _ = self.triton_fa_func( @@ -101,18 +115,17 @@ def _flash_attn_varlen_diff_headdims_rocm(self, return output_tensor - def _flash_attn_varlen_diff_headdims(self, - q, - k, - v, - return_softmax_lse=False, - softmax_scale=None, - **kwargs): - if current_platform.is_rocm() \ - and self.use_triton_flash_attn \ - and not return_softmax_lse: + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): + if ( + current_platform.is_rocm() + and self.use_triton_flash_attn + and not return_softmax_lse + ): return self._flash_attn_varlen_diff_headdims_rocm( - q, k, v, softmax_scale=softmax_scale, **kwargs) + q, k, v, softmax_scale=softmax_scale, **kwargs + ) else: return super()._flash_attn_varlen_diff_headdims( q, @@ -120,7 +133,8 @@ def _flash_attn_varlen_diff_headdims(self, v, return_softmax_lse=return_softmax_lse, softmax_scale=softmax_scale, - **kwargs) + **kwargs, + ) def _forward_decode( self, @@ -141,11 +155,9 @@ def _forward_decode( assert isinstance(q, torch.Tensor) B = q.shape[0] q_num_heads = q.shape[1] - o = torch.zeros(B, - q_num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) + o = torch.zeros( + B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device + ) lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) num_kv_splits = 4 # TODO: heuristic @@ -165,13 +177,22 @@ def _forward_decode( # Add a head dim of 1 kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] PAGE_SIZE = kv_c_and_k_pe_cache.size(1) # Run MQA - decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, 
o, lse, - attn_metadata.decode.block_table, - attn_metadata.decode.seq_lens, attn_logits, - num_kv_splits, self.scale, PAGE_SIZE) + decode_attention_fwd( + q, + kv_c_and_k_pe_cache, + kv_c_cache, + o, + lse, + attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, + attn_logits, + num_kv_splits, + self.scale, + PAGE_SIZE, + ) return o, lse diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 7ac1a063f565..7e83e7a681f4 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -6,8 +6,12 @@ import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionLayer, + AttentionType, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, next_power_of_2 @@ -41,49 +45,62 @@ from torch_xla.experimental.custom_kernel import XLA_LIB @requires_jax - def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, num_slices_per_block: int): + def kv_cache_update_op_impl( + kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int, + ): from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + new_kv_cache = xb.call_jax( kv_cache_update, - (kv, slot_mapping, kv_cache, num_kv_update_slices), { - "page_size": page_size, - "num_slices_per_block": num_slices_per_block - }) + (kv, slot_mapping, kv_cache, num_kv_update_slices), + {"page_size": page_size, "num_slices_per_block": num_slices_per_block}, + ) return new_kv_cache - XLA_LIB.define( - "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \ - "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \ - "int num_slices_per_block)" \ - "-> Tensor", ) + "kv_cache_update_op(Tensor kv, Tensor slot_mapping," + "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," + "int num_slices_per_block)" + "-> Tensor", + ) @impl(XLA_LIB, "kv_cache_update_op", "XLA") - def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: - new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, - num_kv_update_slices, page_size, - num_slices_per_block) + def kv_cache_update_op_xla( + kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int, + ) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl( + kv, + slot_mapping, + kv_cache, + num_kv_update_slices, + page_size, + num_slices_per_block, + ) return new_kv_cache @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") - def kv_cache_update_op_non_xla(kv: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: + def kv_cache_update_op_non_xla( + kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int, + ) -> torch.Tensor: return kv_cache class PallasAttentionBackend(AttentionBackend): - @staticmethod def get_name() -> str: return "PALLAS" @@ 
-104,8 +121,9 @@ def get_kv_cache_shape( head_size: int, cache_dtype_str: str = "auto", ) -> tuple[int, ...]: - padded_head_size = cdiv( - head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) return (num_blocks, block_size, num_kv_heads * 2, padded_head_size) @staticmethod @@ -122,10 +140,12 @@ def swap_blocks( # we simply make sure that the size is smaller than half of SMEM capacity. @staticmethod def get_min_page_size(vllm_config: VllmConfig) -> int: - max_num_page_per_req = (1024 * 1024 // 2 // - vllm_config.scheduler_config.max_num_seqs // 4) - min_page_size = cdiv(vllm_config.model_config.max_model_len, - max_num_page_per_req) + max_num_page_per_req = ( + 1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4 + ) + min_page_size = cdiv( + vllm_config.model_config.max_model_len, max_num_page_per_req + ) min_page_size = 1 << (min_page_size - 1).bit_length() return min_page_size @@ -146,8 +166,7 @@ def get_page_size(vllm_config: VllmConfig) -> int: # handle VREG spills. if vllm_config.model_config.max_model_len > 8192: return 16 - page_size = next_power_of_2( - vllm_config.model_config.max_model_len) // 16 + page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16 if page_size <= 16: return 16 if page_size >= 256: @@ -176,7 +195,6 @@ class PallasMetadata: class PallasAttentionBackendImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -203,15 +221,18 @@ def __init__( raise NotImplementedError("Alibi slopes is not supported.") if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl" + ) self.kv_cache_quantized_dtype = None if kv_cache_dtype != "auto": self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get( - kv_cache_dtype.lower().strip()) + kv_cache_dtype.lower().strip() + ) def forward( self, @@ -240,7 +261,8 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" - " for PallasAttentionBackendImpl") + " for PallasAttentionBackendImpl" + ) # For determine_available_memory case. if kv_cache.numel() == 0: @@ -253,15 +275,18 @@ def forward( key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - padded_head_size = cdiv( - self.head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) query = torch.nn.functional.pad( - query, (0, padded_head_size - self.head_size), value=0.0) + query, (0, padded_head_size - self.head_size), value=0.0 + ) key = torch.nn.functional.pad( - key, (0, padded_head_size - self.head_size), value=0.0) + key, (0, padded_head_size - self.head_size), value=0.0 + ) value = torch.nn.functional.pad( - value, (0, padded_head_size - self.head_size), value=0.0) + value, (0, padded_head_size - self.head_size), value=0.0 + ) if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0: # Write input keys and values to the KV cache. 
@@ -280,9 +305,9 @@ def forward( ) if self.kv_cache_quantized_dtype is not None and ( - layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0): - raise ValueError( - "k_scale_float and v_scale_float must be non-zero") + layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0 + ): + raise ValueError("k_scale_float and v_scale_float must be non-zero") output = torch.ops.xla.ragged_paged_attention( query, kv_cache, @@ -305,7 +330,7 @@ def forward( ) if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - output = output[:, :, :self.head_size] + output = output[:, :, : self.head_size] return output.reshape(num_tokens, hidden_size) @@ -321,7 +346,7 @@ def write_to_kv_cache( k_scale: float = 1.0, v_scale: float = 1.0, ) -> None: - """ Write the key and values to the KV cache. + """Write the key and values to the KV cache. Args: key: shape = [num_tokens, num_kv_heads, head_size] @@ -330,8 +355,7 @@ def write_to_kv_cache( num_slices_per_kv_cache_update_block: int """ _, page_size, num_combined_kv_heads, head_size = kv_cache.shape - head_size = cdiv(head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT if kv_cache_quantized_dtype is not None: dtype_info = torch.finfo(kv_cache_quantized_dtype) @@ -343,15 +367,19 @@ def write_to_kv_cache( value = torch.clamp(value, dtype_info.min, dtype_info.max) value = value.to(kv_cache_quantized_dtype) - kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, - head_size) + kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size) torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) kv_cache = kv_cache.flatten(0, 1) new_kv_cache = torch.ops.xla.kv_cache_update_op( - kv, slot_mapping, kv_cache, num_kv_update_slices, page_size, - num_slices_per_kv_cache_update_block) + kv, + slot_mapping, + kv_cache, + num_kv_update_slices, + page_size, + num_slices_per_kv_cache_update_block, + ) # NOTE: the in-place copy will be optimized away by XLA compiler. 
kv_cache.copy_(new_kv_cache) @@ -389,15 +417,18 @@ def get_dtype_packing(dtype): if 32 % bits != 0: raise ValueError( f"The bit width must be divisible by 32, but got bits={bits}, " - "dtype={dtype}") + "dtype={dtype}" + ) return 32 // bits -def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, - kv_cache_dtype: torch.dtype) -> int: +def get_page_size_bytes( + block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype +) -> int: """Returns the size in bytes of one page of the KV cache.""" - padded_head_size = cdiv(head_size, - TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) num_combined_kv_heads = num_kv_heads * 2 # NOTE: for the implicit padding in XLA @@ -405,5 +436,6 @@ def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing kv_cache_dtype_bits = dtype_bits(kv_cache_dtype) - return (block_size * num_combined_kv_heads * padded_head_size * - kv_cache_dtype_bits // 8) + return ( + block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8 + ) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index ed63c7b1bda6..348eca55eefb 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -1,19 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" + from dataclasses import dataclass from typing import Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec _PARTITION_SIZE_ROCM = 256 @@ -43,55 +50,63 @@ def _vllm_layout_trans_kernel( batch_idx = tl.program_id(0) block_idx = tl.program_id(1) - batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + - tl.arange(0, 2)) + batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + tl.arange(0, 2)) batch_query_start, batch_query_end = tl.split(batch_query_indexes) query_len = batch_query_end - batch_query_start if query_len <= 1: return - batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + - tl.arange(0, 2)) + batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + tl.arange(0, 2)) batch_token_start, batch_token_end = tl.split(batch_token_indexes) seq_len = batch_token_end - batch_token_start if block_idx * BLOCK_SIZE < seq_len: - block_mask = (block_idx * BLOCK_SIZE + - tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len - - kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 + - block_idx).to(tl.int64) - - kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange( - 0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :] - k_vals = tl.load(k_buffer_ptr + kv_buffer_off, - mask=block_mask, - other=0.0) + block_mask = ( + block_idx * BLOCK_SIZE + 
tl.arange(0, BLOCK_SIZE)[:, None] + ) < seq_len + + kv_idx = tl.load( + block_table + batch_idx * block_table_stride_0 + block_idx + ).to(tl.int64) + + kv_buffer_off = ( + kv_idx * BLOCK_SIZE * E_DIM + + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + + tl.arange(0, E_DIM)[None, :] + ) + k_vals = tl.load(k_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0) if k_vals.dtype.is_fp8(): - k_vals = (k_vals.to(tl.float32) * - tl.load(k_scale)).to(output_dtype) + k_vals = (k_vals.to(tl.float32) * tl.load(k_scale)).to(output_dtype) else: k_vals = k_vals.to(output_dtype) - v_vals = tl.load(v_buffer_ptr + kv_buffer_off, - mask=block_mask, - other=0.0) + v_vals = tl.load(v_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0) if v_vals.dtype.is_fp8(): - v_vals = (v_vals.to(tl.float32) * - tl.load(v_scale)).to(output_dtype) + v_vals = (v_vals.to(tl.float32) * tl.load(v_scale)).to(output_dtype) else: v_vals = v_vals.to(output_dtype) - kv_values_off = batch_token_start * E_DIM + \ - block_idx * BLOCK_SIZE * E_DIM + \ - tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \ - tl.arange(0, E_DIM)[None, :] + kv_values_off = ( + batch_token_start * E_DIM + + block_idx * BLOCK_SIZE * E_DIM + + tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + + tl.arange(0, E_DIM)[None, :] + ) tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask) tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask) - def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table, - k_cache, v_cache, max_seq_len, k_scale, v_scale, - output_dtype, total_tokens): + def vllm_layout_trans( + b_query_lens_loc, + b_seq_lens_loc, + block_table, + k_cache, + v_cache, + max_seq_len, + k_scale, + v_scale, + output_dtype, + total_tokens, + ): H_KV = v_cache.shape[2] D = v_cache.shape[3] BLOCK_SIZE = v_cache.shape[1] @@ -107,8 +122,7 @@ def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table, device=v_cache.device, ) - grid = (block_table.shape[0], - (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) + grid = (block_table.shape[0], (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE) if output_dtype == torch.float16: output_dtype = tl.float16 @@ -117,19 +131,21 @@ def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table, else: raise ValueError(f"Unsupported output dtype: {output_dtype}") - _vllm_layout_trans_kernel[grid](k_cache, - v_cache, - k_values, - v_values, - b_query_lens_loc, - b_seq_lens_loc, - block_table, - block_table.stride(0), - k_scale, - v_scale, - output_dtype=output_dtype, - E_DIM=H_KV * D, - BLOCK_SIZE=BLOCK_SIZE) + _vllm_layout_trans_kernel[grid]( + k_cache, + v_cache, + k_values, + v_values, + b_query_lens_loc, + b_seq_lens_loc, + block_table, + block_table.stride(0), + k_scale, + v_scale, + output_dtype=output_dtype, + E_DIM=H_KV * D, + BLOCK_SIZE=BLOCK_SIZE, + ) return k_values, v_values @@ -152,9 +168,18 @@ def flash_attn_varlen_func_impl( ) -> torch.Tensor: if total_tokens == 0: total_tokens = int(cu_seqlens_k[-1].item()) - k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table, - k_cache, v_cache, max_seqlen_k, k_scale, - v_scale, q.dtype, total_tokens) + k, v = vllm_layout_trans( + cu_seqlens_q, + cu_seqlens_k, + block_table, + k_cache, + v_cache, + max_seqlen_k, + k_scale, + v_scale, + q.dtype, + total_tokens, + ) output = aiter.flash_attn_varlen_func( q=q, @@ -190,16 +215,17 @@ def flash_attn_varlen_func_fake( v_scale: torch.Tensor, total_tokens: int = 0, ) -> torch.Tensor: - return torch.empty(q.shape[0], - q.shape[1], - v_cache.shape[-2], - dtype=q.dtype, - device=q.device) + return 
torch.empty( + q.shape[0], q.shape[1], v_cache.shape[-2], dtype=q.dtype, device=q.device + ) - direct_register_custom_op("flash_attn_varlen_func", - flash_attn_varlen_func_impl, ["out"], - flash_attn_varlen_func_fake, - dispatch_key=current_platform.dispatch_key) + direct_register_custom_op( + "flash_attn_varlen_func", + flash_attn_varlen_func_impl, + ["out"], + flash_attn_varlen_func_fake, + dispatch_key=current_platform.dispatch_key, + ) logger = init_logger(__name__) @@ -231,11 +257,17 @@ class AiterFlashAttentionMetadata: class AiterFlashAttentionMetadataBuilder( - AttentionMetadataBuilder[AiterFlashAttentionMetadata]): + AttentionMetadataBuilder[AiterFlashAttentionMetadata] +): cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.model_config = vllm_config.model_config @@ -243,9 +275,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.cache_config = vllm_config.cache_config self.num_heads_q = self.model_config.get_num_attention_heads( - self.parallel_config) - self.num_heads_kv = self.model_config.get_num_kv_heads( - self.parallel_config) + self.parallel_config + ) + self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size # Sliding window size to be used with the AOT scheduler will be @@ -254,19 +286,22 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.total_tokens: int = 0 def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata): - self.total_tokens = self.model_config.max_model_len \ + self, common_attn_metadata: CommonAttentionMetadata + ): + self.total_tokens = ( + self.model_config.max_model_len * self.vllm_config.scheduler_config.max_num_partial_prefills - res = self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + ) + res = self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata) self.total_tokens = 0 return res - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> 'AiterFlashAttentionMetadata': - + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> "AiterFlashAttentionMetadata": num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len max_seq_len = common_attn_metadata.max_seq_len @@ -277,20 +312,18 @@ def build(self, if max_query_len > 1: # We pre-compute cumulative seq len needed for prefill attention # here to avoid recomputing it for every layer - cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, - dtype=torch.int32, - device=seq_lens.device) - torch.cumsum(seq_lens, - dim=0, - dtype=cu_seq_lens.dtype, - out=cu_seq_lens[1:]) + cu_seq_lens = torch.zeros( + seq_lens.shape[0] + 1, dtype=torch.int32, device=seq_lens.device + ) + torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]) num_actual_kv_tokens = int(cu_seq_lens[-1].item()) else: cu_seq_lens = None num_actual_kv_tokens = 0 - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): + 
def schedule( + batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal + ): return None use_cascade = common_prefix_len > 0 @@ -316,7 +349,6 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class AiterFlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -336,7 +368,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -368,7 +401,6 @@ def get_kv_cache_shape( class AiterFlashAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -396,7 +428,7 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0. + logits_soft_cap = 0.0 self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name @@ -406,10 +438,12 @@ def __init__( AiterFlashAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl" + ) def forward( self, @@ -442,8 +476,8 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") + "fused output quantization is not yet supported for FlashAttentionImpl" + ) if attn_metadata is None: # Profiling run. 
@@ -512,13 +546,14 @@ def forward( _, num_heads, head_size = query.shape nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 num_seqs = seqused_k.shape[0] - max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM - - 1) // _PARTITION_SIZE_ROCM + max_num_partitions = ( + max_seqlen_k + _PARTITION_SIZE_ROCM - 1 + ) // _PARTITION_SIZE_ROCM workspace_buffer = torch.empty( - (num_seqs * num_heads * max_num_partitions * head_size) * - nbytes_per_qo_elem + 2 * - (num_seqs * num_heads * max_num_partitions) * 4, + (num_seqs * num_heads * max_num_partitions * head_size) + * nbytes_per_qo_elem + + 2 * (num_seqs * num_heads * max_num_partitions) * 4, dtype=torch.uint8, device=output.device, ) @@ -546,4 +581,5 @@ def forward( return output else: raise NotImplementedError( - "Cascade attention is not implemented for ROCM AITER") + "Cascade attention is not implemented for ROCM AITER" + ) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 1748a48168d4..4c24770aa22c 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" + from dataclasses import dataclass from functools import cache from typing import ClassVar, Optional @@ -9,20 +10,27 @@ from vllm import _custom_ops as ops from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) -from vllm.attention.ops.chunked_prefill_paged_decode import ( - chunked_prefill_paged_decode) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) +from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym) + QuantKey, + kFp8StaticTensorSym, +) from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -58,21 +66,25 @@ class RocmAttentionMetadata: prefix_scheduler_metadata: Optional[torch.Tensor] = None -class RocmAttentionMetadataBuilder( - AttentionMetadataBuilder[RocmAttentionMetadata]): +class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.block_size = kv_cache_spec.block_size model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( - vllm_config.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - 
vllm_config.parallel_config) + vllm_config.parallel_config + ) + self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() def build_for_cudagraph_capture( @@ -93,10 +105,12 @@ def build_for_cudagraph_capture( return attn_metadata - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> RocmAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> RocmAttentionMetadata: num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -109,14 +123,13 @@ def build(self, use_cascade = common_prefix_len > 0 if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - - common_prefix_len) + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) + suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len suffix_kv_lens = suffix_kv_lens.to(self.device) else: cu_prefix_query_lens = None @@ -143,7 +156,6 @@ def build(self, class RocmAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -163,7 +175,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -203,12 +216,10 @@ def use_aiter_unified_attention() -> bool: """Check if aiter unified attention should be used.""" # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set # to 1 as default - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_USE_AITER_UNIFIED_ATTENTION + return envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION class RocmAttentionImpl(AttentionImpl): - def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym @@ -249,29 +260,30 @@ def __init__( RocmAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "RocmAttentionImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "RocmAttentionImpl" + ) self.fp8_dtype = current_platform.fp8_dtype() - self.force_prefill_decode_attn = \ - envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + self.force_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION if not self.force_prefill_decode_attn: # If not using prefill decode attention, we use the Triton # unified attention implementation. 
if use_aiter_unified_attention(): - logger.info_once( - "Using aiter unified attention for RocmAttentionImpl") - from aiter.ops.triton.unified_attention import ( - unified_attention) + logger.info_once("Using aiter unified attention for RocmAttentionImpl") + from aiter.ops.triton.unified_attention import unified_attention + self.unified_attention = unified_attention else: - logger.info_once( - "Using vllm unified attention for RocmAttentionImpl") + logger.info_once("Using vllm unified attention for RocmAttentionImpl") from vllm.attention.ops.triton_unified_attention import ( - unified_attention) + unified_attention, + ) + self.unified_attention = unified_attention self.sinks = sinks @@ -279,7 +291,8 @@ def __init__( assert sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " f"heads in the layer. Sinks shape: {sinks.shape}, " - f"num_heads: {num_heads}.") + f"num_heads: {num_heads}." + ) def forward( self, @@ -310,7 +323,8 @@ def forward( if output_block_scale is not None: raise NotImplementedError( "fused block_scale output quantization is not yet supported" - " for RocmAttentionImpl") + " for RocmAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -332,7 +346,8 @@ def forward( if use_prefill_decode_attn: key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + kv_cache, self.num_kv_heads, self.head_size + ) else: key_cache, value_cache = kv_cache.unbind(0) @@ -366,16 +381,17 @@ def forward( key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape - assert layer._q_scale_float == 1.0, \ + assert layer._q_scale_float == 1.0, ( "A non 1.0 q_scale is not currently supported." + ) if current_platform.is_cuda(): # Skip Q quantization on ROCm and XPU, enable this on cuda # only, since dequantizing back to f32 in the attention kernel # is not supported. 
query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale, + ) query = query.reshape((num_tokens, num_heads, head_size)) cu_seqlens_q = attn_metadata.query_start_loc @@ -430,6 +446,7 @@ def forward( k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), sinks=self.sinks, - output_scale=output_scale) + output_scale=output_scale, + ) return output diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index ba0fba4281e5..74cfecca764e 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -6,16 +6,16 @@ import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.mamba_attn import ( - BaseMambaAttentionMetadataBuilder) -from vllm.v1.attention.backends.utils import (PAD_SLOT_ID, - CommonAttentionMetadata, - compute_causal_conv1d_metadata, - split_decodes_and_prefills) +from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + PAD_SLOT_ID, + CommonAttentionMetadata, + compute_causal_conv1d_metadata, + split_decodes_and_prefills, +) class ShortConvAttentionBackend(AttentionBackend): - @staticmethod def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: return ShortConvAttentionMetadataBuilder @@ -39,12 +39,14 @@ class ShortConvAttentionMetadata: class ShortConvAttentionMetadataBuilder( - BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]): - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> ShortConvAttentionMetadata: + BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata] +): + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> ShortConvAttentionMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] @@ -54,28 +56,38 @@ def build(self, num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) has_initial_states_p = None if num_prefills > 0: has_initial_states_cpu = ( - common_attn_metadata. 
- num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) - has_initial_states_p = has_initial_states_cpu.to( - query_start_loc.device) - - query_start_loc_p = common_attn_metadata.query_start_loc[ - -num_prefills - 1:] - num_decode_tokens - - nums_dict, batch_ptr, token_chunk_offset_ptr = \ + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills : num_reqs + ] + > 0 + ) + has_initial_states_p = has_initial_states_cpu.to(query_start_loc.device) + + query_start_loc_p = ( + common_attn_metadata.query_start_loc[-num_prefills - 1 :] + - num_decode_tokens + ) + + nums_dict, batch_ptr, token_chunk_offset_ptr = ( compute_causal_conv1d_metadata(query_start_loc_p) + ) - elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs - and self.compilation_config.full_cuda_graph): + elif ( + num_decodes > 0 + and num_decodes <= self.decode_cudagraph_max_bs + and self.compilation_config.full_cuda_graph + ): num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes) - self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor, - non_blocking=True) + self.state_indices_tensor[:num_decodes].copy_( + state_indices_tensor, non_blocking=True + ) state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor[num_decodes:] = PAD_SLOT_ID diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 583756129a29..2a7770c87d24 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -8,14 +8,21 @@ import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) + AttentionMetadataBuilder, + CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec if TYPE_CHECKING: @@ -28,7 +35,6 @@ class TreeAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -48,7 +54,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." 
+ ) @staticmethod def get_name() -> str: @@ -114,9 +121,9 @@ def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]: # metadata structure return self._cached_prefill_metadata - q_start_loc = self.query_start_loc[self.num_decodes:] + q_start_loc = self.query_start_loc[self.num_decodes :] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[self.num_decodes:] + kv_seqlens = self.seq_lens[self.num_decodes :] # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_prefill_tokens, @@ -124,8 +131,8 @@ def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]: query_start_loc=q_start_loc - q_start_loc[0], max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[self.num_decodes:], - slot_mapping=self.slot_mapping[self.num_decode_tokens:], + block_table=self.block_table[self.num_decodes :], + slot_mapping=self.slot_mapping[self.num_decode_tokens :], ) return self._cached_prefill_metadata @@ -139,9 +146,9 @@ def decode_metadata(self) -> Optional["TreeAttentionMetadata"]: # metadata structure return self._cached_decode_metadata - q_start_loc = self.query_start_loc[:self.num_decodes + 1] + q_start_loc = self.query_start_loc[: self.num_decodes + 1] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[:self.num_decodes] + kv_seqlens = self.seq_lens[: self.num_decodes] # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = TreeAttentionMetadata( num_actual_tokens=self.num_decode_tokens, @@ -149,16 +156,14 @@ def decode_metadata(self) -> Optional["TreeAttentionMetadata"]: query_start_loc=q_start_loc, max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[:self.num_decodes], - slot_mapping=self.slot_mapping[:self.num_decode_tokens], + block_table=self.block_table[: self.num_decodes], + slot_mapping=self.slot_mapping[: self.num_decode_tokens], tree_attn_bias=self.tree_attn_bias, ) return self._cached_decode_metadata -class TreeAttentionMetadataBuilder( - AttentionMetadataBuilder[TreeAttentionMetadata]): - +class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]): def __init__( self, kv_cache_spec: AttentionSpec, @@ -172,10 +177,9 @@ def __init__( spec_config = vllm_config.speculative_config spec_token_tree = (spec := spec_config) and spec.speculative_token_tree - tree_choices: list[tuple[int, - ...]] = (ast.literal_eval(spec_token_tree) - if spec_token_tree is not None else - [(0, )]) + tree_choices: list[tuple[int, ...]] = ( + ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)] + ) # Construct the tree attention bias. 
depth_counts = _get_depth_counts(tree_choices) self.tree_attn_bias = _prepare_tree_attn_bias( @@ -185,12 +189,12 @@ def __init__( device=device, ) - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: return reorder_batch_to_split_decodes_and_prefills( - input_batch, - scheduler_output, - decode_threshold=self.tree_attn_bias.shape[0]) + input_batch, scheduler_output, decode_threshold=self.tree_attn_bias.shape[0] + ) def build( self, @@ -200,8 +204,10 @@ def build( ) -> TreeAttentionMetadata: decode_threshold = self.tree_attn_bias.shape[0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=decode_threshold)) + split_decodes_and_prefills( + common_attn_metadata, decode_threshold=decode_threshold + ) + ) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc @@ -241,8 +247,7 @@ def build_for_drafting( # Slice the tree attention bias for drafting. Exclude # the root level. start, end = 1, 1 + common_attn_metadata.max_query_len - self.tree_attn_bias = self.tree_attn_bias[start:end, - start:end].contiguous() + self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous() # Build attention bias. attn_metadata = self.build(0, common_attn_metadata, fast_build=True) @@ -273,10 +278,9 @@ def _prepare_tree_attn_bias( ) -> torch.Tensor: # +1 comes from the additional root node. tree_len = len(sorted_tree_choices) + 1 - tree_attn_mask = torch.full((tree_len, tree_len), - -torch.inf, - device=device, - dtype=dtype) + tree_attn_mask = torch.full( + (tree_len, tree_len), -torch.inf, device=device, dtype=dtype + ) # Set diagonal to all zeros. Each token should # attend to itself. @@ -298,14 +302,14 @@ def _prepare_tree_attn_bias( ancestor_idx = [] for c in range(len(cur_tree_choice) - 1): ancestor_idx.append( - sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1) + sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1 + ) tree_attn_mask[j + start + 1, ancestor_idx] = mask_val start += depth_counts[i] return tree_attn_mask class TreeAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -341,10 +345,12 @@ def __init__( TreeAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TreeAttentionImpl.") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TreeAttentionImpl." + ) def forward( self, @@ -374,8 +380,8 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( - "fused output quantization is not yet supported" - " for TreeAttentionImpl") + "fused output quantization is not yet supported for TreeAttentionImpl" + ) if attn_metadata is None: # Profiling run. 
@@ -404,8 +410,7 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens num_decode_tokens = attn_metadata.num_decode_tokens - descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, - key.shape[1]) + descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1]) if prefill_meta := attn_metadata.prefill_metadata: unified_attention( q=query[num_decode_tokens:num_actual_tokens], diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 3983c5edc76f..9997ed16bed1 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -1,24 +1,34 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """High-Performance Triton-only Attention layer.""" + from dataclasses import dataclass from typing import ClassVar, Optional import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) from vllm.attention.ops.triton_reshape_and_cache_flash import ( - triton_reshape_and_cache_flash) + triton_reshape_and_cache_flash, +) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym) + QuantKey, + kFp8StaticTensorSym, +) from vllm.platforms import current_platform -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import AttentionSpec if current_platform.is_cuda_alike(): @@ -59,21 +69,25 @@ class TritonAttentionMetadata: prefix_scheduler_metadata: Optional[torch.Tensor] = None -class TritonAttentionMetadataBuilder( - AttentionMetadataBuilder[TritonAttentionMetadata]): +class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) self.block_size = kv_cache_spec.block_size model_config = vllm_config.model_config self.num_heads_q = model_config.get_num_attention_heads( - vllm_config.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) + self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config) self.headdim = model_config.get_head_size() def build_for_cudagraph_capture( @@ -86,10 +100,12 @@ def build_for_cudagraph_capture( attn_metadata.seq_lens.fill_(1) return attn_metadata - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> TritonAttentionMetadata: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> TritonAttentionMetadata: num_actual_tokens = 
common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -102,14 +118,13 @@ def build(self, use_cascade = common_prefix_len > 0 if use_cascade: - cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=self.device) - prefix_kv_lens = torch.tensor([common_prefix_len], - dtype=torch.int32, - device=self.device) - suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - - common_prefix_len) + cu_prefix_query_lens = torch.tensor( + [0, num_actual_tokens], dtype=torch.int32, device=self.device + ) + prefix_kv_lens = torch.tensor( + [common_prefix_len], dtype=torch.int32, device=self.device + ) + suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len suffix_kv_lens = suffix_kv_lens.to(self.device) else: cu_prefix_query_lens = None @@ -136,7 +151,6 @@ def build(self, class TritonAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -151,7 +165,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by TritonAttention." f"Head sizes need to be larger or equal 32 for this backend. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -187,7 +202,6 @@ def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]: class TritonAttentionImpl(AttentionImpl): - def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym @@ -228,10 +242,12 @@ def __init__( TritonAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonAttentionImpl") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "TritonAttentionImpl" + ) self.fp8_dtype = current_platform.fp8_dtype() @@ -240,7 +256,8 @@ def __init__( assert sinks.shape[0] == num_heads, ( "Sinks must have the same number of heads as the number of " f"heads in the layer. Sinks shape: {sinks.shape}, " - f"num_heads: {num_heads}.") + f"num_heads: {num_heads}." + ) def forward( self, @@ -271,7 +288,8 @@ def forward( if output_block_scale is not None: raise NotImplementedError( "fused block_scale output quantization is not yet supported" - " for TritonAttentionImpl") + " for TritonAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -316,16 +334,17 @@ def forward( key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) num_tokens, num_heads, head_size = query.shape - assert layer._q_scale_float == 1.0, \ + assert layer._q_scale_float == 1.0, ( "A non 1.0 q_scale is not currently supported." + ) if current_platform.is_cuda(): # Skip Q quantization on ROCm and XPU, enable this on cuda # only, since dequantizing back to f32 in the attention kernel # is not supported. 
query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) + query.reshape((num_tokens, num_heads * head_size)).contiguous(), + layer._q_scale, + ) query = query.reshape((num_tokens, num_heads, head_size)) cu_seqlens_q = attn_metadata.query_start_loc diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index f37a829f401c..bddb2f22f0dc 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -5,8 +5,18 @@ import functools from abc import abstractmethod from dataclasses import dataclass, fields, make_dataclass -from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional, - Protocol, TypeVar, Union, get_args) +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Literal, + Optional, + Protocol, + TypeVar, + Union, + get_args, +) import numpy as np import torch @@ -21,11 +31,11 @@ from vllm.v1.worker.gpu_input_batch import InputBatch import vllm.envs as envs -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.layer import Attention from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_kv_connector_cache_layout) + get_kv_connector_cache_layout, +) from vllm.logger import init_logger from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.ubatch_utils import UBatchSlice @@ -46,7 +56,7 @@ class CommonAttentionMetadata: """ Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. - + For many of the tensors we keep both GPU and CPU versions. """ @@ -89,26 +99,27 @@ def slice_query_start_locs( request_slice: slice, ) -> torch.Tensor: """ - Creates a new query_start_loc that corresponds to the requests in + Creates a new query_start_loc that corresponds to the requests in request_slice. Note: This function creates a new tensor to hold the new query_start_locs. This will break cudagraph compatibility. 
""" - return query_start_loc[request_slice.start: request_slice.stop + 1] -\ - query_start_loc[request_slice.start] + return ( + query_start_loc[request_slice.start : request_slice.stop + 1] + - query_start_loc[request_slice.start] + ) def _make_metadata_with_slice( - ubatch_slice: UBatchSlice, - attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: + ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata +) -> CommonAttentionMetadata: """ - This function creates a new CommonAttentionMetadata that corresponds to + This function creates a new CommonAttentionMetadata that corresponds to the requests included in ubatch_slice """ - assert not ubatch_slice.is_empty(), ( - f"Ubatch slice {ubatch_slice} is empty") + assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty" request_slice = ubatch_slice.request_slice token_slice = ubatch_slice.token_slice @@ -119,10 +130,12 @@ def _make_metadata_with_slice( last_req = request_slice.stop - 1 last_tok = token_slice.stop - 1 - assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \ + assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], ( "Token slice start outside of first request" - assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \ + ) + assert start_locs[last_req] <= last_tok < start_locs[last_req + 1], ( "Token slice end outside of last request" + ) # If the "middle" request has tokens in both ubatches, we have to split it. # If ubatch_slice is the first ubatch then we will be splitting the last @@ -132,12 +145,13 @@ def _make_metadata_with_slice( splits_last_request = last_tok < start_locs[last_req + 1] - 1 query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice) - query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc, - request_slice) + query_start_loc = slice_query_start_locs( + attn_metadata.query_start_loc, request_slice + ) assert len(query_start_loc) >= 2, ( - f"query_start_loc must have at least 2 elements, " - f"got {len(query_start_loc)}") + f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}" + ) if splits_first_request: tokens_skipped = first_tok - start_locs[first_req] @@ -159,14 +173,13 @@ def _make_metadata_with_slice( seq_lens_cpu[-1] -= tokens_skipped max_seq_len = int(seq_lens_cpu.max()) - num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ - request_slice] + num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice] num_requests = request_slice.stop - request_slice.start num_actual_tokens = token_slice.stop - token_slice.start max_query_len = int( - torch.max(torch.abs(query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1])).item()) + torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item() + ) # This is to account for the case where we are in a dummy # run and query_start_loc_cpu is full of 0s @@ -196,15 +209,14 @@ def split_attn_metadata( common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: """ - Creates a new CommonAttentionMetadata instance that corresponds to the + Creates a new CommonAttentionMetadata instance that corresponds to the requests for each UBatchSlice in ubatch_slices. 
Note: This function does not modify common_attn_metadata """ results = [] for ubatch_slice in ubatch_slices: - results.append( - _make_metadata_with_slice(ubatch_slice, common_attn_metadata)) + results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata)) return results @@ -213,7 +225,7 @@ def split_attn_metadata( class AttentionCGSupport(enum.Enum): - """ Constants for the cudagraph support of the attention backend + """Constants for the cudagraph support of the attention backend Here we do not consider the cascade attention, as currently it is never cudagraph supported.""" @@ -231,46 +243,53 @@ class AttentionCGSupport(enum.Enum): class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention (default: no). - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.NEVER + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. reorder_batch_threshold: Optional[int] = None @abstractmethod - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): self.kv_cache_spec = kv_cache_spec self.layer_names = layer_names self.vllm_config = vllm_config self.device = device def _init_reorder_batch_threshold( - self, - reorder_batch_threshold: int = 1, - supports_spec_as_decode: bool = False) -> None: + self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False + ) -> None: self.reorder_batch_threshold = reorder_batch_threshold - if self.reorder_batch_threshold is not None \ - and supports_spec_as_decode: + if self.reorder_batch_threshold is not None and supports_spec_as_decode: # If the backend supports spec-as-decode kernels, then we can set # the reorder_batch_threshold based on the number of speculative # tokens from the config. speculative_config = self.vllm_config.speculative_config - if (speculative_config is not None - and speculative_config.num_speculative_tokens is not None): - self.reorder_batch_threshold = \ + if ( + speculative_config is not None + and speculative_config.num_speculative_tokens is not None + ): + self.reorder_batch_threshold = ( 1 + speculative_config.num_speculative_tokens + ) @abstractmethod - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. - + Args: common_prefix_len: The length of the common prefix of the batch. common_attn_metadata: The common attention metadata. @@ -280,8 +299,9 @@ def build(self, """ raise NotImplementedError - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: """ Update the order of requests in the batch based on the attention backend's needs. 
For example, some attention backends (namely MLA) may @@ -298,14 +318,16 @@ def reorder_batch(self, input_batch: "InputBatch", raise NotImplementedError def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ Build attention metadata for CUDA graph capture. Uses build by default. Subclasses that override this method should call self.build or super().build_for_cudagraph_capture. """ - return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + return self.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) def build_for_drafting( self, @@ -314,7 +336,7 @@ def build_for_drafting( ) -> M: """ Build attention metadata for draft model. Uses build by default. - + Args: common_attn_metadata: The common attention metadata. draft_index: The index of the current draft operation. @@ -323,9 +345,11 @@ def build_for_drafting( For tree-based attention, this index instead refers to the draft attempt for the i-th level in the tree of tokens. """ - return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - fast_build=True) + return self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + fast_build=True, + ) def use_cascade_attention( self, @@ -348,8 +372,11 @@ def get_kv_cache_layout(): if _KV_CACHE_LAYOUT_OVERRIDE is not None: cache_layout = _KV_CACHE_LAYOUT_OVERRIDE - logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \ - "Setting KV cache layout to %s.", cache_layout) + logger.info_once( + "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " + "Setting KV cache layout to %s.", + cache_layout, + ) return cache_layout # Format specified by the user. @@ -359,8 +386,11 @@ def get_kv_cache_layout(): cache_layout = get_kv_connector_cache_layout() else: assert is_valid_kv_cache_layout(cache_layout) - logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ - "detected. Setting KV cache layout to %s.", cache_layout) + logger.info_once( + "`VLLM_KV_CACHE_LAYOUT` environment variable " + "detected. Setting KV cache layout to %s.", + cache_layout, + ) return cache_layout @@ -385,8 +415,8 @@ class PerLayerParameters: def get_per_layer_parameters( - vllm_config: VllmConfig, layer_names: list[str], - cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]: + vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"] +) -> dict[str, PerLayerParameters]: """ Scan layers in `layer_names` and determine some hyperparameters to use during `plan`. 
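To make the per-layer hyperparameter scan introduced here concrete, the following is a minimal, self-contained sketch of the same idea: collect one record per attention layer and require that every layer agrees before a single global plan is built. The names LayerParams and infer_global_params are illustrative only and are not part of vLLM.

from dataclasses import dataclass
from typing import Optional


@dataclass
class LayerParams:
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float
    has_sinks: bool = False


def infer_global_params(per_layer: dict[str, LayerParams]) -> LayerParams:
    """Return the parameter set shared by all layers, or raise if they differ."""
    params = list(per_layer.values())
    global_params = params[0]
    for p in params[1:]:
        if p.window_left != global_params.window_left:
            raise ValueError(
                "window_left is not the same for all layers; one potential fix "
                "is to disable sliding window attention"
            )
        if p != global_params:
            raise ValueError(f"Layer hyperparameters differ: {p} vs {global_params}")
    return global_params


# Example with two hypothetical layers that agree on every field:
layers = {
    "model.layers.0.attn": LayerParams(window_left=-1, logits_soft_cap=None, sm_scale=0.125),
    "model.layers.1.attn": LayerParams(window_left=-1, logits_soft_cap=None, sm_scale=0.125),
}
assert infer_global_params(layers) == layers["model.layers.0.attn"]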
@@ -406,17 +436,18 @@ def get_per_layer_parameters( sm_scale = impl.scale has_sinks = getattr(impl, "sinks", None) is not None - per_layer_params[key] = PerLayerParameters(window_left, - logits_soft_cap, sm_scale, - has_sinks) + per_layer_params[key] = PerLayerParameters( + window_left, logits_soft_cap, sm_scale, has_sinks + ) return per_layer_params def infer_global_hyperparameters( - per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: + per_layer_params: dict[str, PerLayerParameters], +) -> PerLayerParameters: """ - Currently, FlashInfer backend other than trtllm-gen + Currently, FlashInfer backend other than trtllm-gen only support models in which all layers share the same values for the following hyperparameters: - `window_left` @@ -437,13 +468,15 @@ def infer_global_hyperparameters( for params in param_sets: if params.window_left != global_params.window_left: raise ValueError( - "Window left is not the same for all layers. " \ - "One potential fix is to set disable_sliding_window=True") + "Window left is not the same for all layers. " + "One potential fix is to set disable_sliding_window=True" + ) assert params == global_params, ( "FlashInfer backend currently only supports models in which all" "layers share the same values " "for the following hyperparameters:" - "`window_left`, `logits_soft_cap`, `sm_scale`.") + "`window_left`, `logits_soft_cap`, `sm_scale`." + ) return global_params @@ -525,11 +558,10 @@ def make_local_attention_virtual_batches( # new_tokens_in_first_block = [2, 1, 4] # local_blocks = [2, 4, 2] q_tokens_in_first_block = np.minimum( - attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), - q_seqlens).astype(np.int32) + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens + ).astype(np.int32) tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) - local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, - attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size) # Once we know the number of local blocks we can compute the request spans # for each batch idx, we can figure out the number of "virtual" requests we @@ -550,14 +582,13 @@ def make_local_attention_virtual_batches( rarange = np.repeat(local_blocks, local_blocks) - arange - 1 # Then we can compute the seqlens_q_local, handling the fact that the # first and last blocks could be partial - seqlens_q_local = \ - np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) # set the first block since this may be a partial block seqlens_q_local[arange == 0] = q_tokens_in_first_block # set the remaining blocks seqlens_q_local[arange > 0] = np.minimum( - seqlens_q_local - attn_chunk_size * (arange - 1), - attn_chunk_size)[arange > 0] + seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size + )[arange > 0] # convert from q_seqlens to cu_seqlens_q cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32) @@ -569,22 +600,20 @@ def make_local_attention_virtual_batches( # batch # For our example this will be: # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] - seqlens_k_local = np.full(cu_num_blocks[-1], - attn_chunk_size, - dtype=np.int32) + seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block num_computed_tokens_local = seqlens_k_local - seqlens_q_local - k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ - (rarange * 
attn_chunk_size + \ - np.repeat(tokens_in_last_block, local_blocks)) + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( + rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) + ) # For the example the local attention blocks start at: # _b0_ _____b1_____ _b2_ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] block_starts = k_seqstarts_absolute // block_size - assert attn_chunk_size % block_size == 0, \ - f"attn_chunk_size {attn_chunk_size} is not " \ - f"divisible by block_size {block_size}" + assert attn_chunk_size % block_size == 0, ( + f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}" + ) pages_per_local_batch = attn_chunk_size // block_size # Create a block_table for the local attention blocks @@ -605,12 +634,14 @@ def make_local_attention_virtual_batches( # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) # ] - block_indices = (block_starts[:, None] + - np.arange(pages_per_local_batch, dtype=np.int32)) - block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - - 1) - batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), - local_blocks * pages_per_local_batch) + block_indices = block_starts[:, None] + np.arange( + pages_per_local_batch, dtype=np.int32 + ) + block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - 1) + batch_indices = np.repeat( + np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch, + ) # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance # regression when using numpy arrays (batch and block indices) to index into @@ -618,8 +649,9 @@ def make_local_attention_virtual_batches( # tensor first, which recovers perf. 
batch_indices_torch = torch.from_numpy(batch_indices) block_indices_torch = torch.from_numpy(block_indices) - block_table_local = block_table[batch_indices_torch, block_indices_torch]\ - .view(virtual_batches, -1) + block_table_local = block_table[batch_indices_torch, block_indices_torch].view( + virtual_batches, -1 + ) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) @@ -627,8 +659,7 @@ def make_local_attention_virtual_batches( return CommonAttentionMetadata( query_start_loc_cpu=query_start_loc_cpu, - query_start_loc=query_start_loc_cpu.to(device=device, - non_blocking=True), + query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True), seq_lens_cpu=seq_lens_cpu, seq_lens=seq_lens_cpu.to(device=device, non_blocking=True), num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), @@ -668,9 +699,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( # Find how many decode indices belong to each request # request_ids: [0, 1, 1, 2] - request_ids = torch.bucketize(logits_indices, - query_start_loc[1:], - right=True) + request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True) # Figure out how many tokens are in each request # num_decode_tokens: [1, 2, 1] @@ -678,9 +707,9 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( # Calculate new query_start_loc with tokens in generation_indices # decode_query_start_loc: [0, 1, 3, 4] - decode_query_start_loc = torch.empty(num_reqs + 1, - device=query_start_loc.device, - dtype=query_start_loc.dtype) + decode_query_start_loc = torch.empty( + num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype + ) decode_query_start_loc[0] = 0 decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) @@ -689,8 +718,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( common_attn_metadata = CommonAttentionMetadata( query_start_loc=decode_query_start_loc, - query_start_loc_cpu=decode_query_start_loc.to("cpu", - non_blocking=True), + query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True), seq_lens=seq_lens, seq_lens_cpu=seq_lens.to("cpu", non_blocking=True), num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, @@ -706,22 +734,25 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( def subclass_attention_backend( - name_prefix: str, attention_backend_cls: type[AttentionBackend], - builder_cls: type[AttentionMetadataBuilder[M]] + name_prefix: str, + attention_backend_cls: type[AttentionBackend], + builder_cls: type[AttentionMetadataBuilder[M]], ) -> type[AttentionBackend]: """ Return a new subclass where `get_builder_cls` returns `builder_cls`. """ name: str = name_prefix + attention_backend_cls.__name__ # type: ignore - return type(name, (attention_backend_cls, ), - {"get_builder_cls": lambda: builder_cls}) + return type( + name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls} + ) def split_decodes_and_prefills( - common_attn_metadata: CommonAttentionMetadata, - decode_threshold: int = 1, - require_uniform: bool = False) -> tuple[int, int, int, int]: + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, + require_uniform: bool = False, +) -> tuple[int, int, int, int]: """ Assuming a reordered batch, finds the boundary between prefill and decode requests. 
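As a rough illustration of the boundary search described in the docstring above, this standalone sketch (simplified, not the vLLM implementation; it ignores the require_uniform case) scans per-request query lengths against a decode threshold on an already reordered batch:

import torch


def split_decodes_and_prefills_sketch(
    query_start_loc: torch.Tensor,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
    # query_start_loc has shape [num_reqs + 1]; entry i is the cumulative
    # number of query tokens before request i.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    num_reqs = int(query_lens.numel())
    is_prefill = query_lens > decode_threshold
    if bool(is_prefill.any()):
        # Index of the first request whose query is longer than the threshold.
        num_decodes = int(is_prefill.int().argmax().item())
    else:
        num_decodes = num_reqs
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = int(query_start_loc[num_decodes].item())
    num_prefill_tokens = int(query_start_loc[-1].item()) - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Example: three single-token decode requests followed by one 5-token prefill.
loc = torch.tensor([0, 1, 2, 3, 8], dtype=torch.int32)
assert split_decodes_and_prefills_sketch(loc) == (3, 1, 3, 5)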
@@ -745,8 +776,9 @@ def split_decodes_and_prefills( num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu - if max_query_len <= decode_threshold and \ - (not require_uniform or decode_threshold <= 1): + if max_query_len <= decode_threshold and ( + not require_uniform or decode_threshold <= 1 + ): return num_reqs, 0, num_tokens, 0 query_lens = query_start_loc[1:] - query_start_loc[:-1] @@ -779,7 +811,7 @@ def reorder_batch_to_split_decodes_and_prefills( """ Reorders the batch to split into prefill and decode requests; places all requests with <= decode_threshold tokens at the front of the batch. - + Returns: True if the batch was modified, False otherwise. """ @@ -834,8 +866,7 @@ def reorder_batch_to_split_decodes_and_prefills( return modified_batch -def reshape_query_for_spec_decode(query: torch.Tensor, - batch_size: int) -> torch.Tensor: +def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor: """ Reshapes the query tensor for the specified batch size, so that it has shape (batch_size, seq_len, num_heads, head_dim). @@ -845,13 +876,13 @@ def reshape_query_for_spec_decode(query: torch.Tensor, num_heads = query.shape[1] head_dim = query.shape[2] assert total_tokens % batch_size == 0, ( - f"{total_tokens=} is not divisible by {batch_size=}") + f"{total_tokens=} is not divisible by {batch_size=}" + ) seq_len = total_tokens // batch_size return query.view(batch_size, seq_len, num_heads, head_dim) -def reshape_attn_output_for_spec_decode( - attn_output: torch.Tensor) -> torch.Tensor: +def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor: """ Reshapes the attention output tensor, so that the batch_size and seq_len dimensions are combined. 
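The two spec-decode reshape helpers above are easiest to see with concrete shapes; the snippet below mirrors their behavior for a uniform speculative-decoding batch (the shapes are chosen arbitrarily for illustration):

import torch

batch_size, seq_len, num_heads, head_dim = 4, 3, 8, 64

# reshape_query_for_spec_decode: the flat (total_tokens, num_heads, head_dim)
# query is viewed as (batch_size, seq_len, num_heads, head_dim), assuming every
# request contributes the same number of tokens.
flat_query = torch.randn(batch_size * seq_len, num_heads, head_dim)
assert flat_query.shape[0] % batch_size == 0
query_4d = flat_query.view(batch_size, seq_len, num_heads, head_dim)

# reshape_attn_output_for_spec_decode: fold batch_size and seq_len back into a
# single token dimension before writing the attention output.
attn_out_4d = torch.randn(batch_size, seq_len, num_heads, head_dim)
attn_out_3d = attn_out_4d.view(batch_size * seq_len, num_heads, head_dim)

assert query_4d.shape == (batch_size, seq_len, num_heads, head_dim)
assert attn_out_3d.shape == (batch_size * seq_len, num_heads, head_dim)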
@@ -859,16 +890,14 @@ def reshape_attn_output_for_spec_decode( if attn_output.dim() == 3: # Already in the correct shape return attn_output - assert attn_output.dim() == 4, \ - f"attn_output must be 4D, got {attn_output.dim()}D" + assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D" total_tokens = attn_output.shape[0] * attn_output.shape[1] - return attn_output.view(total_tokens, attn_output.shape[2], - attn_output.shape[3]) + return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3]) KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ - ('logits_indices_padded', Optional[torch.Tensor], None), - ('num_logits_indices', int, 0), + ("logits_indices_padded", Optional[torch.Tensor], None), + ("num_logits_indices", int, 0), ] @@ -881,7 +910,7 @@ def subclass_attention_metadata( Return a new subclass of `metadata_cls` with additional fields """ name: str = name_prefix + metadata_cls.__name__ # type: ignore - Wrapped = make_dataclass(name, fields, bases=(metadata_cls, )) + Wrapped = make_dataclass(name, fields, bases=(metadata_cls,)) return Wrapped @@ -895,55 +924,55 @@ def create_fast_prefill_custom_backend( prefix: str, underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: - underlying_builder = underlying_attn_backend.get_builder_cls() class FastPrefillAttentionBuilder(underlying_builder): # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> AttentionMetadata: - new_common_attn_metadata =\ - make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) - metadata = super().build(common_prefix_len, - new_common_attn_metadata, fast_build) + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + new_common_attn_metadata = ( + make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) + ) + metadata = super().build( + common_prefix_len, new_common_attn_metadata, fast_build + ) class KVSharingFastPrefillAttentionMetadata( - metadata.__class__, # type: ignore - KVSharingFastPrefillMetadata): - + metadata.__class__, # type: ignore + KVSharingFastPrefillMetadata, + ): def __init__(self, metadata, common_attn_metadata): # Shallow copy all fields in metadata cls for field in fields(metadata.__class__): - setattr(self, field.name, - getattr(metadata, field.name)) + setattr(self, field.name, getattr(metadata, field.name)) # Set additional fields that will be used in model code - assert (common_attn_metadata.logits_indices_padded - is not None - and common_attn_metadata.num_logits_indices - is not None) - self.logits_indices_padded = \ + assert ( + common_attn_metadata.logits_indices_padded is not None + and common_attn_metadata.num_logits_indices is not None + ) + self.logits_indices_padded = ( common_attn_metadata.logits_indices_padded - self.num_logits_indices = \ - common_attn_metadata.num_logits_indices + ) + self.num_logits_indices = common_attn_metadata.num_logits_indices - return KVSharingFastPrefillAttentionMetadata( - metadata, common_attn_metadata) + return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=FastPrefillAttentionBuilder) + builder_cls=FastPrefillAttentionBuilder, + ) return attn_backend def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): - # Needed for 
causal_conv1d - seqlens = query_start_loc_p.diff().to('cpu') + seqlens = query_start_loc_p.diff().to("cpu") nums_dict = {} # type: ignore batch_ptr = None token_chunk_offset_ptr = None @@ -951,40 +980,39 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): for BLOCK_M in [8]: # cover all BLOCK_M values nums = -(-seqlens // BLOCK_M) nums_dict[BLOCK_M] = {} - nums_dict[BLOCK_M]['nums'] = nums - nums_dict[BLOCK_M]['tot'] = nums.sum().item() + nums_dict[BLOCK_M]["nums"] = nums + nums_dict[BLOCK_M]["tot"] = nums.sum().item() mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) - nums_dict[BLOCK_M]['mlist'] = mlist - mlist_len = len(nums_dict[BLOCK_M]['mlist']) - nums_dict[BLOCK_M]['mlist_len'] = mlist_len + nums_dict[BLOCK_M]["mlist"] = mlist + mlist_len = len(nums_dict[BLOCK_M]["mlist"]) + nums_dict[BLOCK_M]["mlist_len"] = mlist_len MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 offsetlist = [] # type: ignore for idx, num in enumerate(nums): offsetlist.extend(range(num)) offsetlist = torch.tensor(offsetlist, dtype=torch.int32) - nums_dict[BLOCK_M]['offsetlist'] = offsetlist + nums_dict[BLOCK_M]["offsetlist"] = offsetlist if batch_ptr is None: # Update default value after class definition - batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) - token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) + batch_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device + ) + token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device + ) else: if batch_ptr.nelement() < MAX_NUM_PROGRAMS: batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) token_chunk_offset_ptr.resize_( # type: ignore - MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + MAX_NUM_PROGRAMS + ).fill_(PAD_SLOT_ID) batch_ptr[0:mlist_len].copy_(mlist) token_chunk_offset_ptr[ # type: ignore - 0:mlist_len].copy_(offsetlist) - nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr - nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr - ) # type: ignore + 0:mlist_len + ].copy_(offsetlist) + nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr + nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore return nums_dict, batch_ptr, token_chunk_offset_ptr diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 9d667ee04f75..17e752277c66 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -7,20 +7,29 @@ import torch -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionType, +) from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) + AttentionMetadataBuilder, + CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, + split_decodes_and_prefills, +) from vllm.v1.kv_cache_interface import AttentionSpec try: from xformers import ops as xops from xformers.ops.fmha.attn_bias import ( - AttentionBias, PagedBlockDiagonalCausalWithOffsetPaddedKeysMask) + AttentionBias, + 
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + ) XFORMERS_AVAILABLE = True except ImportError: @@ -36,7 +45,6 @@ class XFormersAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True @classmethod @@ -86,7 +94,8 @@ def validate_head_size(cls, head_size: int) -> None: f"Head size {head_size} is not supported by {attn_type}. " f"Supported head sizes are: {supported_head_sizes}. " "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use " - "FlexAttention backend which supports all head sizes.") + "FlexAttention backend which supports all head sizes." + ) @staticmethod def get_name() -> str: @@ -153,9 +162,9 @@ def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]: # metadata structure return self._cached_prefill_metadata - q_start_loc = self.query_start_loc[self.num_decodes:] + q_start_loc = self.query_start_loc[self.num_decodes :] q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[self.num_decodes:] + kv_seqlens = self.seq_lens[self.num_decodes :] # Construct & cache prefill-phase attention metadata structure self._cached_prefill_metadata = XFormersAttentionMetadata( num_actual_tokens=self.num_prefill_tokens, @@ -163,8 +172,8 @@ def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]: query_start_loc=q_start_loc - q_start_loc[0], max_seq_len=int(kv_seqlens.max().item()), seq_lens=kv_seqlens, - block_table=self.block_table[self.num_decodes:], - slot_mapping=self.slot_mapping[self.num_decode_tokens:], + block_table=self.block_table[self.num_decodes :], + slot_mapping=self.slot_mapping[self.num_decode_tokens :], ) return self._cached_prefill_metadata @@ -180,24 +189,24 @@ def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]: q_start_loc = self.query_start_loc q_seqlens = torch.diff(q_start_loc) - decode_kv_seqlens = self.seq_lens[:self.num_decodes] + decode_kv_seqlens = self.seq_lens[: self.num_decodes] # Construct & cache decode-phase attention metadata structure self._cached_decode_metadata = XFormersAttentionMetadata( num_actual_tokens=self.num_decode_tokens, - max_query_len=int(q_seqlens[:self.num_decodes].max().item()), - query_start_loc=q_start_loc[:self.num_decodes + 1], + max_query_len=int(q_seqlens[: self.num_decodes].max().item()), + query_start_loc=q_start_loc[: self.num_decodes + 1], max_seq_len=int(decode_kv_seqlens.max().item()), seq_lens=decode_kv_seqlens, - block_table=self.block_table[:self.num_decodes], - slot_mapping=self.slot_mapping[:self.num_decode_tokens], + block_table=self.block_table[: self.num_decodes], + slot_mapping=self.slot_mapping[: self.num_decode_tokens], attn_bias=self.attn_bias, ) return self._cached_decode_metadata class XFormersAttentionMetadataBuilder( - AttentionMetadataBuilder[XFormersAttentionMetadata]): - + AttentionMetadataBuilder[XFormersAttentionMetadata] +): reorder_batch_threshold: int = 1 def __init__( @@ -214,12 +223,12 @@ def __init__( self._num_decodes = 0 self._num_decode_tokens = 0 - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: return reorder_batch_to_split_decodes_and_prefills( - input_batch, - scheduler_output, - decode_threshold=self.reorder_batch_threshold) + input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold + ) def build( self, @@ -229,8 +238,9 @@ def build( ) -> XFormersAttentionMetadata: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( - 
common_attn_metadata, - decode_threshold=self.reorder_batch_threshold)) + common_attn_metadata, decode_threshold=self.reorder_batch_threshold + ) + ) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc @@ -246,14 +256,13 @@ def build( # Construct the decoder bias. decode_q_seqlens = q_seqlens[:num_decodes] decode_kv_seqlens = kv_seqlens[:num_decodes] - bias = ( - PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( - q_seqlen=decode_q_seqlens.tolist(), - kv_seqlen=decode_kv_seqlens.tolist(), - page_size=self.block_size, - block_tables=block_table[:num_decodes], - device=block_table.device, - )) + bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=decode_q_seqlens.tolist(), + kv_seqlen=decode_kv_seqlens.tolist(), + page_size=self.block_size, + block_tables=block_table[:num_decodes], + device=block_table.device, + ) return XFormersAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -272,7 +281,6 @@ def build( class XFormersAttentionImpl(AttentionImpl): - def __init__( self, num_heads: int, @@ -289,8 +297,7 @@ def __init__( if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported in V0.") if alibi_slopes is not None: - raise NotImplementedError( - "XFormers does not support alibi slopes yet.") + raise NotImplementedError("XFormers does not support alibi slopes yet.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -313,10 +320,12 @@ def __init__( XFormersAttentionBackend.validate_head_size(head_size) if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "XFormersAttentionImpl.") + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "XFormersAttentionImpl." + ) def forward( self, @@ -347,7 +356,8 @@ def forward( if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" - " for XFormersAttentionImpl") + " for XFormersAttentionImpl" + ) if attn_metadata is None: # Profiling run. @@ -377,8 +387,7 @@ def forward( num_actual_tokens = attn_metadata.num_actual_tokens num_decode_tokens = attn_metadata.num_decode_tokens if prefill_meta := attn_metadata.prefill_metadata: - descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, - key.shape[1]) + descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1]) unified_attention( q=query[num_decode_tokens:num_actual_tokens], k=key_cache, @@ -403,36 +412,38 @@ def forward( # Query for decode. KV is not needed because it is already cached. decode_query = query[:num_decode_tokens] # Reshape query to [1, B_T, G, H, D]. 
- q = decode_query.view(1, -1, self.num_kv_heads, - self.num_queries_per_kv, self.head_size) + q = decode_query.view( + 1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size + ) # Reshape the k and v caches to [1, Bkv_T, G, H, D] - cache_k = key_cache.view(1, -1, self.num_kv_heads, 1, - self.head_size).expand( - 1, - -1, - self.num_kv_heads, - self.num_queries_per_kv, - self.head_size, - ) - cache_v = value_cache.view(1, -1, self.num_kv_heads, 1, - self.head_size).expand( - 1, - -1, - self.num_kv_heads, - self.num_queries_per_kv, - self.head_size, - ) + cache_k = key_cache.view( + 1, -1, self.num_kv_heads, 1, self.head_size + ).expand( + 1, + -1, + self.num_kv_heads, + self.num_queries_per_kv, + self.head_size, + ) + cache_v = value_cache.view( + 1, -1, self.num_kv_heads, 1, self.head_size + ).expand( + 1, + -1, + self.num_kv_heads, + self.num_queries_per_kv, + self.head_size, + ) attn_bias = decode_meta.attn_bias - output[: - num_decode_tokens] = xops.memory_efficient_attention_forward( - q, - cache_k, - cache_v, - attn_bias=attn_bias, - p=0.0, - scale=self.scale, - ).view(decode_query.shape) + output[:num_decode_tokens] = xops.memory_efficient_attention_forward( + q, + cache_k, + cache_v, + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + ).view(decode_query.shape) # Reshape the output tensor. return output diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 617a724a1ad2..ddfd94322737 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -3,16 +3,24 @@ from collections.abc import Iterable from typing import Any, Optional, Union -from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, - BlockRemoved, BlockStored, - KVCacheEvent) +from vllm.distributed.kv_events import ( + MEDIUM_GPU, + AllBlocksCleared, + BlockRemoved, + BlockStored, + KVCacheEvent, +) from vllm.logger import init_logger -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - ExternalBlockHash, - FreeKVCacheBlockQueue, KVCacheBlock, - get_block_hash, - make_block_hash_with_group_id, - maybe_convert_block_hash) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashWithGroupId, + ExternalBlockHash, + FreeKVCacheBlockQueue, + KVCacheBlock, + get_block_hash, + make_block_hash_with_group_id, + maybe_convert_block_hash, +) from vllm.v1.request import Request logger = init_logger(__name__) @@ -20,7 +28,7 @@ class BlockHashToBlockMap: """ - Cache of blocks that are used for prefix caching. It caches blocks + Cache of blocks that are used for prefix caching. It caches blocks from hash directly to a block or multiple blocks (i.e. {block_hash: KVCacheBlocks}) - Mostly block_hash maps to a single KVCacheBlock, and KVCacheBlocks @@ -42,11 +50,11 @@ class BlockHashToBlockMap: """ def __init__(self): - self._cache: dict[BlockHashWithGroupId, - Union[KVCacheBlock, dict[int, KVCacheBlock]]] = {} + self._cache: dict[ + BlockHashWithGroupId, Union[KVCacheBlock, dict[int, KVCacheBlock]] + ] = {} - def get_one_block(self, - key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: + def get_one_block(self, key: BlockHashWithGroupId) -> Optional[KVCacheBlock]: """ Gets any block with the given block hash key. 
""" @@ -77,8 +85,7 @@ def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None: else: self._unexpected_blocks_type(blocks) - def pop(self, key: BlockHashWithGroupId, - block_id: int) -> Optional[KVCacheBlock]: + def pop(self, key: BlockHashWithGroupId, block_id: int) -> Optional[KVCacheBlock]: """ Checks if block_hash exists and pop block_id from the cache """ @@ -148,8 +155,7 @@ def __init__( self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) # Cache for block lookup - self.cached_block_hash_to_block: BlockHashToBlockMap = \ - BlockHashToBlockMap() + self.cached_block_hash_to_block: BlockHashToBlockMap = BlockHashToBlockMap() # To represent a placeholder block with block_id=0. # The ref_cnt of null_block is not maintained, needs special care to @@ -161,9 +167,9 @@ def __init__( self.kv_event_queue: list[KVCacheEvent] = [] def get_cached_block( - self, block_hash: BlockHash, - kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]: - """Get the cached block by the block hash for each group in + self, block_hash: BlockHash, kv_cache_group_ids: list[int] + ) -> Optional[list[KVCacheBlock]]: + """Get the cached block by the block hash for each group in `kv_cache_group_ids`, or None if cache miss for any group. If there are duplicated blocks, we return the first block in the cache. @@ -177,9 +183,11 @@ def get_cached_block( cached_blocks = [] for group_id in kv_cache_group_ids: block_hash_with_group_id = make_block_hash_with_group_id( - block_hash, group_id) + block_hash, group_id + ) block = self.cached_block_hash_to_block.get_one_block( - block_hash_with_group_id) + block_hash_with_group_id + ) if not block: return None cached_blocks.append(block) @@ -218,17 +226,18 @@ def cache_full_blocks( new_block_hashes = request.block_hashes[num_cached_blocks:] new_hashes: Optional[list[ExternalBlockHash]] = ( - [] if self.enable_kv_cache_events else None) + [] if self.enable_kv_cache_events else None + ) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None block_hash = new_block_hashes[i] # Update and added the full block to the cache. block_hash_with_group_id = make_block_hash_with_group_id( - block_hash, kv_cache_group_id) + block_hash, kv_cache_group_id + ) blk.block_hash = block_hash_with_group_id - self.cached_block_hash_to_block.insert(block_hash_with_group_id, - blk) + self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk) if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) @@ -239,20 +248,21 @@ def cache_full_blocks( parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None parent_block_hash = maybe_convert_block_hash( - get_block_hash(parent_block.block_hash)) + get_block_hash(parent_block.block_hash) + ) self.kv_event_queue.append( BlockStored( block_hashes=new_hashes, parent_block_hash=parent_block_hash, - token_ids=request. - all_token_ids[num_cached_blocks * - block_size:num_full_blocks * block_size], + token_ids=request.all_token_ids[ + num_cached_blocks * block_size : num_full_blocks * block_size + ], block_size=block_size, - lora_id=request.lora_request.id - if request.lora_request else None, + lora_id=request.lora_request.id if request.lora_request else None, medium=MEDIUM_GPU, - )) + ) + ) def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: """Get new blocks from the free block pool. @@ -266,8 +276,7 @@ def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: A list of new block. 
""" if num_blocks > self.get_num_free_blocks(): - raise ValueError( - f"Cannot get {num_blocks} free blocks from the pool") + raise ValueError(f"Cannot get {num_blocks} free blocks from the pool") ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks) @@ -299,8 +308,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: # The block doesn't have hash, eviction is not needed return False - if self.cached_block_hash_to_block.pop(block_hash, - block.block_id) is None: + if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None: # block not found in cached_block_hash_to_block, # eviction is not needed return False @@ -313,10 +321,11 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[ - maybe_convert_block_hash(get_block_hash(block_hash)) - ], - medium=MEDIUM_GPU)) + BlockRemoved( + block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))], + medium=MEDIUM_GPU, + ) + ) return True def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: @@ -347,10 +356,9 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 - self.free_block_queue.append_n([ - block for block in blocks_list - if block.ref_cnt == 0 and not block.is_null - ]) + self.free_block_queue.append_n( + [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] + ) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -365,7 +373,9 @@ def reset_prefix_cache(self) -> bool: if num_used_blocks != 1: # The null block is always marked as used logger.warning( "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks - 1) + "blocks (%d) are not freed yet", + num_used_blocks - 1, + ) return False # Remove all hashes so that no new blocks will hit. @@ -405,7 +415,7 @@ def get_usage(self) -> float: def take_events(self) -> list[KVCacheEvent]: """Atomically takes all events and clears the queue. - + Returns: A list of KV cache events. """ diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index eadea15a2e5e..c70025992e70 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -33,12 +33,12 @@ class EncoderCacheManager: within requests, allowing for fine-grained memory management and enabling chunked processing of multimodal inputs. - Cache is enabled to share embeddings of same multimodal data - item (identified by their hash value) between different requests, - and eviction takes place at allocation time when there's no free + Cache is enabled to share embeddings of same multimodal data + item (identified by their hash value) between different requests, + and eviction takes place at allocation time when there's no free space for new embeddings. Oldest cached embeddings with no request referenced will be first evicted. - + Args: cache_size: Limit the size of the cache, measured by the number of tokens from the input sequence. 
@@ -99,27 +99,31 @@ def check_and_update_cache(self, request: Request, input_id: int) -> bool: self.cached[mm_hash].add(request.request_id) return True - def can_allocate(self, request: Request, input_id: int, - encoder_compute_budget: int, - num_tokens_to_schedule: int) -> bool: - """Check if there's sufficient cache space for a multimodal input. + def can_allocate( + self, + request: Request, + input_id: int, + encoder_compute_budget: int, + num_tokens_to_schedule: int, + ) -> bool: + """Check if there's sufficient cache space for a multimodal input. If there is, return True and update EncoderCacheManager state. If there is not enough free space in `num_free_slots` but there is enough reclaimable space in `num_freeable_slots`, entries will be evicted from `freeable` (their mm_hash appended to `freed`) until - enough space is available, and then this method returns True. + enough space is available, and then this method returns True. Older entries are evicted first. - - Returns False only if the requested number of tokens exceeds both + + Returns False only if the requested number of tokens exceeds both the free and reclaimable capacities combined. Args: request: The request containing the multimodal input. input_id: Index of the multimodal input within the request. - encoder_compute_budget: Number of encoder tokens allowed to be + encoder_compute_budget: Number of encoder tokens allowed to be computed when this method is invoked. - num_tokens_to_schedule: Number of tokens already scheduled to be + num_tokens_to_schedule: Number of tokens already scheduled to be allocated with cache space when this method is invoked. Returns: @@ -127,7 +131,7 @@ def can_allocate(self, request: Request, input_id: int, input (possibly after reclaiming `freeable` entries); otherwise False. - Note: This method does not allocate physical memory for the encoder + Note: This method does not allocate physical memory for the encoder output but only the state of EncoderCacheManager. """ num_tokens = request.get_num_encoder_tokens(input_id) @@ -202,7 +206,7 @@ def free_encoder_input(self, request: Request, input_id: int) -> None: When the reference set for the corresponding `mm_hash` becomes empty, the entry is appended to `freeable` and `num_freeable_slots` is - increased by the number of encoder tokens for that input. + increased by the number of encoder tokens for that input. The entry is NOT physically freed until capacity is needed (e.g., by `can_allocate`). @@ -221,8 +225,8 @@ def free_encoder_input(self, request: Request, input_id: int) -> None: def free(self, request: Request) -> None: """Free all encoder input cache reference held by *request*. - For each cached input ID, `free_encoder_input` is invoked. - The data stays in memory until eviction is triggered by a future + For each cached input ID, `free_encoder_input` is invoked. + The data stays in memory until eviction is triggered by a future attempt allocation called by 'can_allocate'. Typically called when a request is finished, cancelled, or aborted. @@ -236,9 +240,9 @@ def get_freed_mm_hashes(self) -> list[str]: Returns: List of mm_hash strings that were actually evicted since the last - call to be used by the scheduler to notify workers about which - encoder outputs can be removed from their caches. The internal - list is cleared after this call. + call to be used by the scheduler to notify workers about which + encoder outputs can be removed from their caches. The internal + list is cleared after this call. 
""" freed = self.freed self.freed = [] @@ -250,7 +254,7 @@ def compute_encoder_budget( scheduler_config: "SchedulerConfig", mm_registry: MultiModalRegistry, ) -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler + """Compute the encoder cache budget based on the model and scheduler configurations. Returns: @@ -260,8 +264,9 @@ def compute_encoder_budget( from the input sequence. """ if mm_registry.supports_multimodal_inputs(model_config): - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) + max_tokens_by_modality = ( + mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config) + ) return compute_mm_encoder_budget( scheduler_config, @@ -271,18 +276,17 @@ def compute_encoder_budget( return compute_text_encoder_budget(scheduler_config) -def compute_text_encoder_budget( - scheduler_config: "SchedulerConfig") -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler +def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]: + """Compute the encoder cache budget based on the model and scheduler configurations for a text-only model. Args: scheduler_config: Scheduler configuration. Returns: - - Compute budget for encoder execution, in unit of number of tokens + - Compute budget for encoder execution, in unit of number of tokens in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens + - Space budget for encoder cache size, in unit of number of tokens in the input sequence. """ # Currently text-only encoder-decoder models are not supported @@ -293,7 +297,7 @@ def compute_mm_encoder_budget( scheduler_config: "SchedulerConfig", max_tokens_by_modality: Mapping[str, int], ) -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler + """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. Args: @@ -312,22 +316,28 @@ def compute_mm_encoder_budget( logger.warning( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. Encoder cache will " - "not be initialized.") + "not be initialized." + ) return 0, 0 max_tokens_per_mm_item = max(max_tokens_by_modality.values()) - if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item - > scheduler_config.max_num_batched_tokens): + if ( + scheduler_config.disable_chunked_mm_input + and max_tokens_per_mm_item > scheduler_config.max_num_batched_tokens + ): raise ValueError( "Chunked MM input disabled but max_tokens_per_mm_item " f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens" f" ({scheduler_config.max_num_batched_tokens}). Please increase " - "max_num_batched_tokens.") + "max_num_batched_tokens." 
+ ) - encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens, - max_tokens_per_mm_item) - encoder_cache_size = max(scheduler_config.encoder_cache_size, - max_tokens_per_mm_item) + encoder_compute_budget = max( + scheduler_config.max_num_encoder_input_tokens, max_tokens_per_mm_item + ) + encoder_cache_size = max( + scheduler_config.encoder_cache_size, max_tokens_per_mm_item + ) return encoder_compute_budget, encoder_cache_size diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 86771060c409..37e1b7ca3932 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -6,9 +6,11 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec) -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) + CrossAttentionManager, + FullAttentionManager, + get_manager_for_kv_cache_spec, +) +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.request import Request @@ -30,8 +32,9 @@ def __init__( self.max_model_len = max_model_len self.enable_caching = enable_caching - self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, - enable_kv_cache_events) + self.block_pool = BlockPool( + kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events + ) # Needs special handling for find_longest_cache_hit if eagle is enabled self.use_eagle = use_eagle @@ -41,19 +44,23 @@ def __init__( block_pool=self.block_pool, kv_cache_group_id=i, dcp_world_size=dcp_world_size, - ) for i, kv_cache_group in enumerate( - self.kv_cache_config.kv_cache_groups)) + ) + for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) + ) - def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int, - new_computed_blocks: tuple[ - list[KVCacheBlock], ...], - num_encoder_tokens: int) -> int: + def get_num_blocks_to_allocate( + self, + request_id: str, + num_tokens: int, + new_computed_blocks: tuple[list[KVCacheBlock], ...], + num_encoder_tokens: int, + ) -> int: """ Get the number of blocks needed to be allocated for the request. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. @@ -69,15 +76,17 @@ def get_num_blocks_to_allocate(self, request_id: str, num_tokens: int, # For cross-attention, we issue a single static allocation # of blocks based on the number of encoder input tokens. num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_encoder_tokens, []) + request_id, num_encoder_tokens, [] + ) else: num_blocks_to_allocate += manager.get_num_blocks_to_allocate( - request_id, num_tokens, new_computed_blocks[i]) + request_id, num_tokens, new_computed_blocks[i] + ) return num_blocks_to_allocate def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None: + self, request_id: str, new_computed_blocks: tuple[list[KVCacheBlock], ...] + ) -> None: """ Add the new computed blocks to the request. @@ -87,21 +96,18 @@ def save_new_computed_blocks( prefix cache. 
""" for i, manager in enumerate(self.single_type_managers): - manager.save_new_computed_blocks(request_id, - new_computed_blocks[i]) + manager.save_new_computed_blocks(request_id, new_computed_blocks[i]) def allocate_new_blocks( - self, - request_id: str, - num_tokens: int, - num_encoder_tokens: int = 0) -> tuple[list[KVCacheBlock], ...]: + self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0 + ) -> tuple[list[KVCacheBlock], ...]: """ - Allocate new blocks for the request to give it at least `num_tokens` + Allocate new blocks for the request to give it at least `num_tokens` token slots. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). num_encoder_tokens: The number of encoder tokens for allocating blocks for cross-attention. @@ -111,9 +117,13 @@ def allocate_new_blocks( """ return tuple( manager.allocate_new_blocks( - request_id, num_encoder_tokens if isinstance( - manager, CrossAttentionManager) else num_tokens) - for manager in self.single_type_managers) + request_id, + num_encoder_tokens + if isinstance(manager, CrossAttentionManager) + else num_tokens, + ) + for manager in self.single_type_managers + ) def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """ @@ -138,8 +148,9 @@ def free(self, request_id: str) -> None: for manager in self.single_type_managers: manager.free(request_id) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> list[int]: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> list[int]: """ Get the number of common prefix blocks for all requests in the RUNNING state for each kv cache group. @@ -154,16 +165,14 @@ def get_num_common_prefix_blocks(self, request_id: str, the RUNNING state for each kv cache group. """ num_blocks_per_group = [ - manager.get_num_common_prefix_blocks(request_id, - num_running_requests) + manager.get_num_common_prefix_blocks(request_id, num_running_requests) for manager in self.single_type_managers ] return num_blocks_per_group - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks` and replace + Remove the blocks that are no longer needed from `blocks` and replace the removed blocks with null_block. Args: @@ -179,7 +188,8 @@ def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: """ return tuple( manager.req_to_blocks.get(request_id) or [] - for manager in self.single_type_managers) + for manager in self.single_type_managers + ) @abstractmethod def find_longest_cache_hit( @@ -198,19 +208,27 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): Does not implement any features related to prefix caching. 
""" - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_kv_cache_events: bool, - dcp_world_size: int): - super().__init__(kv_cache_config, - max_model_len, - use_eagle, - False, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + False, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) self.num_single_type_manager = len(self.single_type_managers) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> list[int]: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> list[int]: return [0] * self.num_single_type_manager def find_longest_cache_hit( @@ -219,7 +237,8 @@ def find_longest_cache_hit( max_cache_hit_length: int, ) -> tuple[tuple[list[KVCacheBlock], ...], int]: blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(self.num_single_type_manager)) + [] for _ in range(self.num_single_type_manager) + ) return blocks, 0 @@ -230,23 +249,31 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): full attention or all attention layers use sliding window attention. """ - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool, dcp_world_size: int): - super().__init__(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) - self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) + self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size self.dcp_world_size = dcp_world_size if dcp_world_size > 1: self.block_size *= dcp_world_size assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "UnitaryKVCacheCoordinator assumes only one kv cache group") + "UnitaryKVCacheCoordinator assumes only one kv cache group" + ) def find_longest_cache_hit( self, @@ -269,26 +296,34 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): """ KV cache coordinator for hybrid models with multiple KV cache types, and thus multiple kv cache groups. - To simplify `find_longest_cache_hit`, it only supports the combination of + To simplify `find_longest_cache_hit`, it only supports the combination of two types of KV cache groups, and one of them must be full attention. May extend to more general cases in the future. 
""" - def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, enable_caching: bool, - enable_kv_cache_events: bool, dcp_world_size: int): - super().__init__(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) + def __init__( + self, + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, + ): + super().__init__( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) assert dcp_world_size == 1, "DCP not support hybrid attn now." self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: """ - Verifies that the model has exactly two types of KV cache groups, and + Verifies that the model has exactly two types of KV cache groups, and one of them is full attention. Then, split the kv cache groups into full attention groups and other groups. """ @@ -303,7 +338,8 @@ def verify_and_split_kv_cache_groups(self) -> None: else: assert full_attention_spec == g.kv_cache_spec, ( "HybridKVCacheCoordinator assumes exactly one type of " - "full attention groups now.") + "full attention groups now." + ) self.full_attention_group_ids.append(i) else: if other_spec is None: @@ -311,19 +347,22 @@ def verify_and_split_kv_cache_groups(self) -> None: else: assert other_spec == g.kv_cache_spec, ( "HybridKVCacheCoordinator assumes " - "exactly one other type of groups now.") + "exactly one other type of groups now." + ) self.other_group_ids.append(i) assert full_attention_spec is not None, ( "HybridKVCacheCoordinator assumes exactly one type of full " - "attention groups now.") + "attention groups now." + ) assert other_spec is not None, ( - "HybridKVCacheCoordinator assumes exactly one type of other " - "groups now.") + "HybridKVCacheCoordinator assumes exactly one type of other groups now." + ) self.full_attention_manager_cls = FullAttentionManager self.other_attention_cls = self.single_type_managers[ - self.other_group_ids[0]].__class__ + self.other_group_ids[0] + ].__class__ self.full_attention_spec = full_attention_spec self.other_spec = other_spec self.full_attention_block_size = self.full_attention_spec.block_size @@ -334,7 +373,8 @@ def verify_and_split_kv_cache_groups(self) -> None: divisible = self.other_block_size % self.full_attention_block_size assert divisible == 0, ( "KVCacheCoordinator assumes the block_size of full " - "attention layers is divisible by other layers now.") + "attention layers is divisible by other layers now." + ) if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True @@ -347,7 +387,8 @@ def verify_and_split_kv_cache_groups(self) -> None: "do not interleave, either full attention group ids " "are before other attention group ids or vice versa." "This is for simplifying merging hit_blocks_full_attn and " - "hit_blocks_other_attn to hit_blocks.") + "hit_blocks_other_attn to hit_blocks." + ) def find_longest_cache_hit( self, @@ -367,29 +408,26 @@ def find_longest_cache_hit( - The number of tokens of the longest cache hit. """ # First, find the longest cache hit for full attention. 
- hit_blocks_full_attn = ( - self.full_attention_manager_cls.find_longest_cache_hit( - block_hashes=block_hashes, - max_length=max_cache_hit_length, - kv_cache_group_ids=self.full_attention_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.full_attention_spec, - use_eagle=self.use_eagle, - )) - hit_length = len( - hit_blocks_full_attn[0]) * self.full_attention_block_size + hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=self.full_attention_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.full_attention_spec, + use_eagle=self.use_eagle, + ) + hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. - hit_blocks_other_attn = ( - self.other_attention_cls.find_longest_cache_hit( - block_hashes=block_hashes, - max_length=hit_length, - kv_cache_group_ids=self.other_group_ids, - block_pool=self.block_pool, - kv_cache_spec=self.other_spec, - use_eagle=self.use_eagle, - )) + hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=hit_length, + kv_cache_group_ids=self.other_group_ids, + block_pool=self.block_pool, + kv_cache_spec=self.other_spec, + use_eagle=self.use_eagle, + ) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size # NOTE: the prefix cache hit length must be a multiple of block_size as @@ -404,7 +442,7 @@ def find_longest_cache_hit( # Truncate the full attention cache hit to the length of the # cache hit of the other attention. for group_hit_blocks in hit_blocks_full_attn: - del group_hit_blocks[hit_length // self.full_attention_block_size:] + del group_hit_blocks[hit_length // self.full_attention_block_size :] # Merge the hit blocks of full attention and other attention. 
if self.full_attn_first: @@ -414,27 +452,36 @@ def find_longest_cache_hit( return hit_blocks, hit_length -def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig, - max_model_len: int, use_eagle: bool, - enable_caching: bool, - enable_kv_cache_events: bool, - dcp_world_size: int) -> KVCacheCoordinator: +def get_kv_cache_coordinator( + kv_cache_config: KVCacheConfig, + max_model_len: int, + use_eagle: bool, + enable_caching: bool, + enable_kv_cache_events: bool, + dcp_world_size: int, +) -> KVCacheCoordinator: if not enable_caching: - return KVCacheCoordinatorNoPrefixCache(kv_cache_config, - max_model_len, - use_eagle, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) + return KVCacheCoordinatorNoPrefixCache( + kv_cache_config, + max_model_len, + use_eagle, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) if len(kv_cache_config.kv_cache_groups) == 1: - return UnitaryKVCacheCoordinator(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) - return HybridKVCacheCoordinator(kv_cache_config, - max_model_len, - use_eagle, - enable_caching, - enable_kv_cache_events, - dcp_world_size=dcp_world_size) + return UnitaryKVCacheCoordinator( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) + return HybridKVCacheCoordinator( + kv_cache_config, + max_model_len, + use_eagle, + enable_caching, + enable_kv_cache_events, + dcp_world_size=dcp_world_size, + ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 0af98e7ba2d8..3e1a83a8a220 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -22,6 +22,7 @@ class KVCacheBlocks: Scheduler and KVCacheManager, to hide KVCacheManager's internal data structure from the Scheduler. """ + blocks: tuple[list[KVCacheBlock], ...] """ `blocks[i][j]` refers to the i-th kv_cache_group @@ -35,22 +36,20 @@ class KVCacheBlocks: def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" return KVCacheBlocks( - tuple(blk1 + blk2 - for blk1, blk2 in zip(self.blocks, other.blocks))) + tuple(blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)) + ) @overload def get_block_ids( self, allow_none: Literal[False] = False, - ) -> tuple[list[int], ...]: - ... + ) -> tuple[list[int], ...]: ... @overload def get_block_ids( self, allow_none: Literal[True] = True, - ) -> Optional[tuple[list[int], ...]]: - ... + ) -> Optional[tuple[list[int], ...]]: ... 
def get_block_ids( self, @@ -72,10 +71,7 @@ def get_block_ids( def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" assert len(self.blocks) == 1, "Only one group is supported" - return [ - block.block_id for block in self.blocks[0] - if block.block_hash is None - ] + return [block.block_id for block in self.blocks[0] if block.block_hash is None] def new_empty(self) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" @@ -83,7 +79,6 @@ def new_empty(self) -> "KVCacheBlocks": class KVCacheManager: - def __init__( self, kv_cache_config: KVCacheConfig, @@ -104,12 +99,18 @@ def __init__( self.block_size: Optional[int] = None if self.enable_caching: - assert len( - set(g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups) - ) == 1, "Only one block size is supported for now" + assert ( + len( + set( + g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups + ) + ) + == 1 + ), "Only one block size is supported for now" self.block_size = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size + 0 + ].kv_cache_spec.block_size if dcp_world_size > 1: assert len(kv_cache_config.kv_cache_groups) == 1 @@ -151,8 +152,7 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks(self, - request: Request) -> tuple[KVCacheBlocks, int]: + def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -166,9 +166,10 @@ def get_computed_blocks(self, """ # Prefix caching is disabled or # When the request requires prompt logprobs, we skip prefix caching. - if (not self.enable_caching - or (request.sampling_params is not None - and request.sampling_params.prompt_logprobs is not None)): + if not self.enable_caching or ( + request.sampling_params is not None + and request.sampling_params.prompt_logprobs is not None + ): return self.create_empty_block_list(), 0 # NOTE: When all tokens hit the cache, we must recompute the last token @@ -179,8 +180,10 @@ def get_computed_blocks(self, # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(request.block_hashes, - max_cache_hit_length)) + self.coordinator.find_longest_cache_hit( + request.block_hashes, max_cache_hit_length + ) + ) if self.log_stats: assert self.prefix_cache_stats is not None @@ -188,8 +191,7 @@ def get_computed_blocks(self, # Previously preempted request self.prefix_cache_stats.preempted_requests += 1 self.prefix_cache_stats.preempted_queries += request.num_tokens - self.prefix_cache_stats.preempted_hits += ( - num_new_computed_tokens) + self.prefix_cache_stats.preempted_hits += num_new_computed_tokens else: # New request self.prefix_cache_stats.requests += 1 @@ -250,7 +252,8 @@ def allocate_slots( new_computed_block_list = new_computed_blocks.blocks else: new_computed_block_list = tuple( - [] for _ in range(len(self.kv_cache_config.kv_cache_groups))) + [] for _ in range(len(self.kv_cache_config.kv_cache_groups)) + ) # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). @@ -258,16 +261,17 @@ def allocate_slots( # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. 
- self.coordinator.remove_skipped_blocks(request.request_id, - request.num_computed_tokens) + self.coordinator.remove_skipped_blocks( + request.request_id, request.num_computed_tokens + ) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits - num_computed_tokens = (request.num_computed_tokens + - num_new_computed_tokens) + num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens num_tokens_need_slot = min( num_computed_tokens + num_new_tokens + num_lookahead_tokens, - self.max_model_len) + self.max_model_len, + ) num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( request_id=request.request_id, @@ -285,16 +289,18 @@ def allocate_slots( self.block_pool.touch(new_computed_block_list) else: assert not any(new_computed_block_list), ( - "Computed blocks should be empty when " - "prefix caching is disabled") + "Computed blocks should be empty when prefix caching is disabled" + ) # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - self.coordinator.save_new_computed_blocks(request.request_id, - new_computed_block_list) + self.coordinator.save_new_computed_blocks( + request.request_id, new_computed_block_list + ) new_blocks = self.coordinator.allocate_new_blocks( - request.request_id, num_tokens_need_slot, num_encoder_tokens) + request.request_id, num_tokens_need_slot, num_encoder_tokens + ) # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. @@ -305,8 +311,9 @@ def allocate_slots( # num_new_tokens, but must exclude "non-committable" tokens (e.g., # draft tokens that could be rejected). Therefore, we cap the number # at `request.num_tokens`, ensuring only "finalized" tokens are cached. - num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, - request.num_tokens) + num_tokens_to_cache = min( + num_computed_tokens + num_new_tokens, request.num_tokens + ) self.coordinator.cache_blocks(request, num_tokens_to_cache) return KVCacheBlocks(new_blocks) @@ -378,7 +385,8 @@ def get_num_common_prefix_blocks( """ assert request.status == RequestStatus.RUNNING return self.coordinator.get_num_common_prefix_blocks( - request.request_id, num_running_requests) + request.request_id, num_running_requests + ) def take_events(self) -> list[KVCacheEvent]: """Take the KV cache events from the block pool. 
@@ -403,5 +411,4 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks(tuple([] - for _ in range(self.num_kv_cache_groups))) + return KVCacheBlocks(tuple([] for _ in range(self.num_kv_cache_groups))) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2ff1bb681d80..4683ad62981f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -13,11 +13,16 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import GiB_bytes, cdiv, sha256_cbor -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec, - UniformTypeKVCacheSpecs) +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + KVCacheTensor, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -37,16 +42,16 @@ ExternalBlockHash = Union[bytes, int] -def make_block_hash_with_group_id(block_hash: BlockHash, - group_id: int) -> BlockHashWithGroupId: +def make_block_hash_with_group_id( + block_hash: BlockHash, group_id: int +) -> BlockHashWithGroupId: """Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``. The group id is encoded using 4 bytes in big-endian order and appended to the block hash bytes. This representation avoids creating tuples while still allowing us to recover both components when needed. """ - return BlockHashWithGroupId(block_hash + - group_id.to_bytes(4, "big", signed=False)) + return BlockHashWithGroupId(block_hash + group_id.to_bytes(4, "big", signed=False)) def get_block_hash(key: BlockHashWithGroupId) -> BlockHash: @@ -87,7 +92,8 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]): "PYTHONHASHSEED is not set. This will lead to non-reproducible " "block-hashes when using sha256_cbor as the hash function." "Consider setting PYTHONHASHSEED to a fixed value for " - "reproducibility.") + "reproducibility." + ) if hash_seed is None: NONE_HASH = BlockHash(os.urandom(32)) @@ -143,9 +149,10 @@ def observe(self, stats: PrefixCacheStats): # Remove the oldest stats until number of requests does not exceed # the limit. # NOTE: We preserve the latest added stats regardless. - while len( - self.query_queue - ) > 1 and self.aggregated_requests > self.max_recent_requests: + while ( + len(self.query_queue) > 1 + and self.aggregated_requests > self.max_recent_requests + ): old_requests, old_queries, old_hits = self.query_queue.popleft() self.aggregated_requests -= old_requests self.aggregated_query_total -= old_queries @@ -169,6 +176,7 @@ def hit_rate(self) -> float: @dataclass class KVCacheBlock: """KV-cache block metadata.""" + # Block ID, ranging from 0 to num_gpu_blocks - 1. block_id: int # Reference count. @@ -192,7 +200,8 @@ def block_hash(self) -> Optional[BlockHashWithGroupId]: @block_hash.setter def block_hash(self, block_hash: BlockHashWithGroupId): assert self.block_hash is None, ( - "The block already has a hash. This should not happen.") + "The block already has a hash. This should not happen." 
+ ) self._block_hash = block_hash def reset_hash(self): @@ -202,15 +211,15 @@ def reset_hash(self): def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ # on KVCacheBlock object recursively. - prev_block_id = (self.prev_free_block.block_id - if self.prev_free_block else None) - next_block_id = (self.next_free_block.block_id - if self.next_free_block else None) - return (f"KVCacheBlock(block_id={self.block_id}, " - f"ref_cnt={self.ref_cnt}, " - f"_block_hash={self._block_hash!r}, " - f"prev_free_block={prev_block_id}, " - f"next_free_block={next_block_id})") + prev_block_id = self.prev_free_block.block_id if self.prev_free_block else None + next_block_id = self.next_free_block.block_id if self.next_free_block else None + return ( + f"KVCacheBlock(block_id={self.block_id}, " + f"ref_cnt={self.ref_cnt}, " + f"_block_hash={self._block_hash!r}, " + f"prev_free_block={prev_block_id}, " + f"next_free_block={next_block_id})" + ) class FreeKVCacheBlockQueue: @@ -271,12 +280,14 @@ def popleft(self) -> KVCacheBlock: Returns: The first free block. """ - if (self.fake_free_list_head.next_free_block - is self.fake_free_list_tail - or self.fake_free_list_head.next_free_block is None): + if ( + self.fake_free_list_head.next_free_block is self.fake_free_list_tail + or self.fake_free_list_head.next_free_block is None + ): assert self.num_free_blocks == 0, ( f"num_free_blocks ({self.num_free_blocks}) is out of sync " - "with the free list.") + "with the free list." + ) raise ValueError("No free blocks available") first_block: KVCacheBlock = self.fake_free_list_head.next_free_block @@ -284,8 +295,10 @@ def popleft(self) -> KVCacheBlock: if first_block.next_free_block is None: # This should not happen if the block is from the free list. # It indicates a bug in the caller's logic. - raise RuntimeError("Invalid block found in popleft() " - "which doesn't have a valid next_free_block") + raise RuntimeError( + "Invalid block found in popleft() " + "which doesn't have a valid next_free_block" + ) # Connect fake_head and the next block of first_block (i.e. second block # or fake tail). @@ -360,7 +373,8 @@ def append(self, block: KVCacheBlock) -> None: """ if self.fake_free_list_tail.prev_free_block is None: raise RuntimeError( - "prev_free_block of fake_free_list_tail should always exist") + "prev_free_block of fake_free_list_tail should always exist" + ) last_block: KVCacheBlock = self.fake_free_list_tail.prev_free_block # Connect the new block after the last block. @@ -384,7 +398,8 @@ def append_n(self, blocks: list[KVCacheBlock]) -> None: last_block = self.fake_free_list_tail.prev_free_block assert last_block is not None, ( - "prev_free_block of fake_free_list_tail should always exist") + "prev_free_block of fake_free_list_tail should always exist" + ) # Add inter-connections between consecutive blocks for block in blocks: block.prev_free_block = last_block @@ -406,7 +421,8 @@ def get_all_free_blocks(self) -> list[KVCacheBlock]: ret = [] if self.fake_free_list_head.next_free_block is None: raise RuntimeError( - "next_free_block of fake_free_list_head should always exist") + "next_free_block of fake_free_list_head should always exist" + ) # Start from the first block curr_block: KVCacheBlock = self.fake_free_list_head.next_free_block # As long as next_free_block is available, we haven't reached to @@ -430,14 +446,16 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. 
# Request with provided cache salt need to include the salt. - return bool(request.mm_features) or (request.lora_request - is not None) or (request.cache_salt - is not None) + return ( + bool(request.mm_features) + or (request.lora_request is not None) + or (request.cache_salt is not None) + ) -def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, - end_token_idx: int, - start_mm_idx: int) -> tuple[list[Any], int]: +def _gen_mm_extra_hash_keys( + request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int +) -> tuple[list[Any], int]: """Generate extra keys related to MultiModal request for block hash computation. For multi-modal inputs, the extra keys are (mm_hash, start_offset) that indicate a mm input contained in the @@ -515,8 +533,8 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[int]: def generate_block_hash_extra_keys( - request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]: + request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int +) -> tuple[Optional[tuple[Any, ...]], int]: """Generate extra keys for the block hash. The extra keys can come from the multi-modal inputs and request specific metadata (e.g., LoRA ID). @@ -531,10 +549,12 @@ def generate_block_hash_extra_keys( """ mm_extra_keys: list[Any] mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( - request, start_token_idx, end_token_idx, start_mm_idx) + request, start_token_idx, end_token_idx, start_mm_idx + ) lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) - cache_salt_keys: list[str] = [request.cache_salt] if ( - start_token_idx == 0 and request.cache_salt) else [] + cache_salt_keys: list[str] = ( + [request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else [] + ) extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys @@ -545,10 +565,11 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable[[Any], bytes], - parent_block_hash: Optional[BlockHash], - curr_block_token_ids: Sequence[int], - extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: + hash_function: Callable[[Any], bytes], + parent_block_hash: Optional[BlockHash], + curr_block_token_ids: Sequence[int], + extra_keys: Optional[tuple[Any, ...]] = None, +) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing @@ -569,8 +590,8 @@ def hash_block_tokens( curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHash( - hash_function( - (parent_block_hash, curr_block_token_ids_tuple, extra_keys))) + hash_function((parent_block_hash, curr_block_token_ids_tuple, extra_keys)) + ) def get_request_block_hasher( @@ -585,6 +606,10 @@ def request_block_hasher(request: Request) -> list[BlockHash]: start_token_idx = len(request.block_hashes) * block_size num_tokens = request.num_tokens + if start_token_idx + block_size > num_tokens: + # Early stop when there no new full blocks created. + return [] + curr_mm_idx = 0 if start_token_idx > 0: # Set curr_mm_idx = -1 to indicate the last mm input. @@ -593,8 +618,9 @@ def request_block_hasher(request: Request) -> list[BlockHash]: # last mm input. 
curr_mm_idx = -1 - prev_block_hash_value = (request.block_hashes[-1] - if request.block_hashes else None) + prev_block_hash_value = ( + request.block_hashes[-1] if request.block_hashes else None + ) new_block_hashes: list[BlockHash] = [] while True: end_token_idx = start_token_idx + block_size @@ -604,13 +630,14 @@ def request_block_hasher(request: Request) -> list[BlockHash]: # MM and LoRA requests need extra keys for block-hash computation. extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, curr_mm_idx) + request, start_token_idx, end_token_idx, curr_mm_idx + ) # Compute the hash of the current block block_tokens = request.all_token_ids[start_token_idx:end_token_idx] - block_hash = hash_block_tokens(caching_hash_fn, - prev_block_hash_value, block_tokens, - extra_keys) + block_hash = hash_block_tokens( + caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys + ) new_block_hashes.append(block_hash) start_token_idx += block_size @@ -621,18 +648,20 @@ def request_block_hasher(request: Request) -> list[BlockHash]: return request_block_hasher -def max_memory_usage_bytes(vllm_config: VllmConfig, - kv_cache_specs: Iterable[KVCacheSpec]) -> int: +def max_memory_usage_bytes( + vllm_config: VllmConfig, kv_cache_specs: Iterable[KVCacheSpec] +) -> int: """ Get the maximum memory usage in bytes for the given KV cache specs. """ - return sum( - spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs) + return sum(spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs) -def estimate_max_model_len(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> int: +def estimate_max_model_len( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +) -> int: """ Estimates the maximum model length that can fit in the available memory using binary search. @@ -651,8 +680,7 @@ def fits_in_memory(model_len: int) -> bool: # Modify the max_model_len for this calculation vllm_config.model_config.max_model_len = model_len # Calculate memory needed for the given model length - memory_needed = max_memory_usage_bytes(vllm_config, - kv_cache_spec.values()) + memory_needed = max_memory_usage_bytes(vllm_config, kv_cache_spec.values()) return memory_needed <= available_memory # Binary search for the maximum model length @@ -675,9 +703,11 @@ def fits_in_memory(model_len: int) -> bool: return result -def check_enough_kv_cache_memory(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int): +def check_enough_kv_cache_memory( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int, +): """ Checks whether `available_memory` is enough for the KV cache to hold at least one request with the model's max_model_len. @@ -696,36 +726,41 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, return if available_memory <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine." 
+ ) max_model_len = vllm_config.model_config.max_model_len needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values()) if needed_memory > available_memory: # Estimate the maximum model length that can fit in the available memory - estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, - available_memory) + estimated_max_len = estimate_max_model_len( + vllm_config, kv_cache_spec, available_memory + ) estimated_msg = "" if estimated_max_len > 0: estimated_msg = ( "Based on the available memory, " - f"the estimated maximum model length is {estimated_max_len}.") + f"the estimated maximum model length is {estimated_max_len}." + ) raise ValueError( f"To serve at least one request with the models's max seq len " - f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV " + f"({max_model_len}), ({needed_memory / GiB_bytes:.2f} GiB KV " f"cache is needed, which is larger than the available KV cache " - f"memory ({available_memory/GiB_bytes:.2f} GiB). " + f"memory ({available_memory / GiB_bytes:.2f} GiB). " f"{estimated_msg} " f"Try increasing `gpu_memory_utilization` or decreasing " - f"`max_model_len` when initializing the engine.") + f"`max_model_len` when initializing the engine." + ) def create_kv_cache_group_specs( - kv_cache_spec: dict[str, KVCacheSpec], - grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: + kv_cache_spec: dict[str, KVCacheSpec], grouped_layer_names: list[list[str]] +) -> list[KVCacheGroupSpec]: """ Create KVCacheGroupSpec object for each kv cache group layer. The layers in the same group should share the same @@ -748,7 +783,8 @@ def create_kv_cache_group_specs( ] merged_layer_spec = layer_specs[0].merge(layer_specs) kv_cache_groups.append( - KVCacheGroupSpec(layer_names_one_group, merged_layer_spec)) + KVCacheGroupSpec(layer_names_one_group, merged_layer_spec) + ) return kv_cache_groups @@ -778,19 +814,22 @@ def is_kv_cache_spec_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: def get_max_concurrency_for_kv_cache_config( - vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float: + vllm_config: VllmConfig, kv_cache_config: KVCacheConfig +) -> float: """ Get the maximum concurrency for the given KV cache configuration. """ num_layer_per_group = max( - len(group.layer_names) for group in kv_cache_config.kv_cache_groups) + len(group.layer_names) for group in kv_cache_config.kv_cache_groups + ) max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes( - vllm_config, - (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups)) - memory_per_block = kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.page_size_bytes * num_layer_per_group - num_block_per_request = cdiv(max_memory_usage_per_request, - memory_per_block) + vllm_config, (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups) + ) + memory_per_block = ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.page_size_bytes + * num_layer_per_group + ) + num_block_per_request = cdiv(max_memory_usage_per_request, memory_per_block) max_concurrency = kv_cache_config.num_blocks / num_block_per_request return max_concurrency @@ -800,18 +839,20 @@ def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int: Override the number of kv cache blocks if `num_gpu_blocks_override` is set. 
""" if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = \ - vllm_config.cache_config.num_gpu_blocks_override + num_gpu_blocks_override = vllm_config.cache_config.num_gpu_blocks_override logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + "Overriding num_gpu_blocks=%d with num_gpu_blocks_override=%d", + num_blocks, + num_gpu_blocks_override, + ) num_blocks = num_gpu_blocks_override return num_blocks -def get_num_blocks(vllm_config: VllmConfig, num_layers: int, - available_memory: int, page_size: int) -> int: +def get_num_blocks( + vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int +) -> int: """ Get the number of kv cache blocks. @@ -837,9 +878,10 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: def _get_kv_cache_groups_uniform_spec( - kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: + kv_cache_specs: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for a model with the same KV cache + Generates the KV cache configuration for a model with the same KV cache spec for all layers. Args: @@ -849,12 +891,12 @@ def _get_kv_cache_groups_uniform_spec( The generated KVCacheGroupSpecs """ - return create_kv_cache_group_specs(kv_cache_specs, - [list(kv_cache_specs.keys())]) + return create_kv_cache_group_specs(kv_cache_specs, [list(kv_cache_specs.keys())]) def _get_kv_cache_groups_uniform_type( - spec: UniformTypeKVCacheSpecs) -> list[KVCacheGroupSpec]: + spec: UniformTypeKVCacheSpecs, +) -> list[KVCacheGroupSpec]: """ Generates the KV cache configuration for a model with one type of KV cache but different hidden sizes. All layers are merged into one group. @@ -869,8 +911,7 @@ def _get_kv_cache_groups_uniform_type( return [KVCacheGroupSpec(list(spec.kv_cache_specs.keys()), spec)] -def is_kv_cache_page_size_uniform( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: +def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same page size. Args: @@ -884,70 +925,69 @@ def is_kv_cache_page_size_uniform( return len(page_sizes) == 1 -def is_kv_cache_type_attention_free( - kv_cache_spec: dict[str, KVCacheSpec]) -> bool: - +def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: # kv_cache_spec is an empty dict for attention free models return not kv_cache_spec def _get_kv_cache_groups_uniform_page_size( - kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: + kv_cache_spec: dict[str, KVCacheSpec], +) -> list[KVCacheGroupSpec]: """ - Generates the KV cache groups for hybrid models with multiple - attention types but still with a uniform page size (physical memory per + Generates the KV cache groups for hybrid models with multiple + attention types but still with a uniform page size (physical memory per block per layer) for all layers. Detailed explanation about kv cache management of hybrid models: The layers in the models are repeated with some patterns, e.g., a model with 10 full attention layers and 20 sliding window attention layers can be - regarded as repeating the pattern (1 * full, 2 * sw) 10 times. + regarded as repeating the pattern (1 * full, 2 * sw) 10 times. 
The KVCacheManager allocates different block tables for each of the 3 layers - in the pattern, and repeats each of them 10 times to generate the + in the pattern, and repeats each of them 10 times to generate the block_table for the 30 layers in the model. Therefore, we can group the layers in the model into 3 kv_cache_groups, each of which contains 10 layers in the model. The KVCacheManager allocates the block_table for each group based on its - kv_cache spec, and the model runner applies the block table to each layer + kv_cache spec, and the model runner applies the block table to each layer in the group. For example: - 1. A model only uses full attention. The pattern is - (num_hidden_layers * full), so there is only one group and the block table - is shared by all layers. It is already handled by + 1. A model only uses full attention. The pattern is + (num_hidden_layers * full), so there is only one group and the block table + is shared by all layers. It is already handled by `_get_kv_cache_config_uniform_type`. - 2. A model with 10 full attention layers and 20 sliding window - attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so + 2. A model with 10 full attention layers and 20 sliding window + attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so there are 3 kv_cache_groups, each of which represents 10 layers. To simplify the implementation, we make the following assumptions: - 1. Physical memory per block: Must be the same across all KV cache groups. + 1. Physical memory per block: Must be the same across all KV cache groups. Breaking this assumption is non-trivial due to memory fragmentation concerns when allocating blocks of different sizes. - 2. Tokens per block (block_size): Currently, we directly use - `CacheConfig.block_size` for all layers. It can be extended to vary by KV - cache group, but within each KV cache group, all layers must share the same + 2. Tokens per block (block_size): Currently, we directly use + `CacheConfig.block_size` for all layers. It can be extended to vary by KV + cache group, but within each KV cache group, all layers must share the same block size. - 3. Physical memory per token per layer: This property is decided by model - config. Currently we only support models that have the same physical memory - per token per layer for all layers. Can be relaxed with a simple extension, + 3. Physical memory per token per layer: This property is decided by model + config. Currently we only support models that have the same physical memory + per token per layer for all layers. Can be relaxed with a simple extension, but still need to keep physical memory per block the same for all groups. - 4. Number of layers per group: Currently assumed the same for all layers. - Can be relaxed with a simple extension, but still need to keep physical + 4. Number of layers per group: Currently assumed the same for all layers. + Can be relaxed with a simple extension, but still need to keep physical memory per block the same for all groups. 5. Attention type within groups: All layers in a group must share the same - attention type. One exception is that, when - `--disable-hybrid-kv-cache-manager` is true, the single group for full - attention layers may also include attention layers using sliding window or + attention type. One exception is that, when + `--disable-hybrid-kv-cache-manager` is true, the single group for full + attention layers may also include attention layers using sliding window or LLaMA 4 local attention. 
See `unify_hybrid_kv_cache_specs` for more details. - 6. Support for multiple attention types: The design for most components is - general to an arbitrary number of attention types. But - `find_longest_cache_hit` only supports one attention type or two + 6. Support for multiple attention types: The design for most components is + general to an arbitrary number of attention types. But + `find_longest_cache_hit` only supports one attention type or two types of full-attention plus exactly one another type. The general - implementation of this function is feasible but we don't know how to + implementation of this function is feasible but we don't know how to implement it cleanly yet. - As we assume tokens per block, physical memory per token per layer, and - number of layers per group are the same now, we can ensure that physical + As we assume tokens per block, physical memory per token per layer, and + number of layers per group are the same now, we can ensure that physical memory per block is the same for all groups. Args: @@ -1001,10 +1041,12 @@ def _get_kv_cache_groups_uniform_page_size( return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) -def get_kv_cache_config_from_groups(vllm_config: VllmConfig, - kv_cache_groups: list[KVCacheGroupSpec], - kv_cache_specs: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def get_kv_cache_config_from_groups( + vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], + kv_cache_specs: dict[str, KVCacheSpec], + available_memory: int, +) -> KVCacheConfig: """ Generate the KV cache configuration from the KV cache groups and spec of each layer. @@ -1027,19 +1069,22 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, ) # Determine how model runners should initialize the KV cache tensors. - if len(kv_cache_groups) == 1 and \ - isinstance(kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs): + if len(kv_cache_groups) == 1 and isinstance( + kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs + ): # Special case: all layers have the same type of KV cache but with # different hidden size. Allocate different amount of memory for each # layer based on its hidden size. 
- num_blocks = available_memory // kv_cache_groups[ - 0].kv_cache_spec.page_size_bytes + num_blocks = ( + available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes + ) num_blocks = may_override_num_blocks(vllm_config, num_blocks) per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs kv_cache_tensors = [ - KVCacheTensor(size=per_layer_specs[layer_name].page_size_bytes * - num_blocks, - shared_by=[layer_name]) + KVCacheTensor( + size=per_layer_specs[layer_name].page_size_bytes * num_blocks, + shared_by=[layer_name], + ) for layer_name in kv_cache_groups[0].layer_names ] else: @@ -1055,8 +1100,9 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, page_size = get_uniform_page_size(kv_cache_specs) assert group_size > 0, "group_size must be greater than 0" - num_blocks = get_num_blocks(vllm_config, group_size, available_memory, - page_size) + num_blocks = get_num_blocks( + vllm_config, group_size, available_memory, page_size + ) kv_cache_tensors = [] for i in range(group_size): shared_by = [] @@ -1064,8 +1110,8 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, if i < len(kv_cache_groups[j].layer_names): shared_by.append(kv_cache_groups[j].layer_names[i]) kv_cache_tensors.append( - KVCacheTensor(size=page_size * num_blocks, - shared_by=shared_by)) + KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by) + ) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, @@ -1073,8 +1119,7 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, kv_cache_groups=kv_cache_groups, ) - min_block_size = min( - [group.kv_cache_spec.block_size for group in kv_cache_groups]) + min_block_size = min([group.kv_cache_spec.block_size for group in kv_cache_groups]) # Print the KV cache size and maximum concurrency. num_tokens = num_blocks // len(kv_cache_groups) * min_block_size @@ -1082,14 +1127,19 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, num_tokens *= vllm_config.parallel_config.decode_context_parallel_size logger.info( "Multiplying the GPU KV cache size by the dcp_world_size %d.", - vllm_config.parallel_config.decode_context_parallel_size) + vllm_config.parallel_config.decode_context_parallel_size, + ) num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" max_concurrency = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, max_concurrency) + vllm_config, kv_cache_config + ) + logger.info( + "Maximum concurrency for %s tokens per request: %.2fx", + max_model_len_str, + max_concurrency, + ) return kv_cache_config @@ -1104,25 +1154,27 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ if is_kv_cache_spec_uniform( - kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type( - kv_cache_spec): + kv_cache_spec + ) or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec): return logger.warning( "Hybrid KV cache manager is disabled for this hybrid model, " "This means we do not enable any optimizations for saving KV cache " "memory (e.g., dropping the KV cache outside the sliding window). " - "The compute of layers like sliding window is still saved.") + "The compute of layers like sliding window is still saved." 
+ ) has_full_attention = any( - isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) + isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values() + ) has_sliding_window = any( - isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values()) + isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values() + ) has_chunked_local_attention = any( - isinstance(spec, ChunkedLocalAttentionSpec) - for spec in kv_cache_spec.values()) - if has_full_attention and (has_sliding_window - or has_chunked_local_attention): + isinstance(spec, ChunkedLocalAttentionSpec) for spec in kv_cache_spec.values() + ) + if has_full_attention and (has_sliding_window or has_chunked_local_attention): for layer_name, spec in kv_cache_spec.items(): if isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( @@ -1141,15 +1193,19 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): attention_chunk_size=spec.attention_chunk_size, ) - if not (is_kv_cache_spec_uniform(kv_cache_spec) - or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)): - raise ValueError("Hybrid KV cache manager is disabled but failed to " - "convert the KV cache specs to one unified type.") + if not ( + is_kv_cache_spec_uniform(kv_cache_spec) + or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec) + ): + raise ValueError( + "Hybrid KV cache manager is disabled but failed to " + "convert the KV cache specs to one unified type." + ) def get_kv_cache_groups( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec] +) -> list[KVCacheGroupSpec]: """ Split the layers in the model into groups with the same KV cache spec. @@ -1188,14 +1244,14 @@ def get_kv_cache_groups( def generate_scheduler_kv_cache_config( - kv_cache_configs: list[KVCacheConfig]) -> KVCacheConfig: + kv_cache_configs: list[KVCacheConfig], +) -> KVCacheConfig: """ Generate the KV cache configuration for the scheduler. """ - assert all([ - cfg.num_blocks == kv_cache_configs[0].num_blocks - for cfg in kv_cache_configs - ]) + assert all( + [cfg.num_blocks == kv_cache_configs[0].num_blocks for cfg in kv_cache_configs] + ) # All workers have the same kv_cache_config except layer names, so use # an arbitrary one to initialize the scheduler. cfg = copy.deepcopy(kv_cache_configs[0]) @@ -1204,15 +1260,18 @@ def generate_scheduler_kv_cache_config( # All layers in the UniformTypeKVCacheSpecs have the same type, # so use an arbitrary one to initialize the scheduler. group.kv_cache_spec = next( - iter(group.kv_cache_spec.kv_cache_specs.values())) + iter(group.kv_cache_spec.kv_cache_specs.values()) + ) return cfg -def get_kv_cache_configs(vllm_config: VllmConfig, - kv_cache_specs: list[dict[str, KVCacheSpec]], - available_memory: list[int]) -> list[KVCacheConfig]: +def get_kv_cache_configs( + vllm_config: VllmConfig, + kv_cache_specs: list[dict[str, KVCacheSpec]], + available_memory: list[int], +) -> list[KVCacheConfig]: """ - Generates the KV cache configurations for a model. + Generates the KV cache configurations for a model. Since we use a shared centralized controller for all workers, we need the `kv_cache_config` to be consistent across all workers to make sure the KV cache allocation can be applied to all workers. 
However, different @@ -1231,7 +1290,7 @@ def get_kv_cache_configs(vllm_config: VllmConfig, vllm_config: The global VllmConfig kv_cache_specs: List of dict[layer_name, KVCacheSpec] for each worker. available_memory: Memory available for KV cache in bytes for each - worker. + worker. Returns: The generated KVCacheConfigs for each worker. @@ -1239,9 +1298,11 @@ def get_kv_cache_configs(vllm_config: VllmConfig, # Check if the available memory is enough for each worker. for kv_cache_spec_one_worker, available_memory_one_worker in zip( - kv_cache_specs, available_memory): - check_enough_kv_cache_memory(vllm_config, kv_cache_spec_one_worker, - available_memory_one_worker) + kv_cache_specs, available_memory + ): + check_enough_kv_cache_memory( + vllm_config, kv_cache_spec_one_worker, available_memory_one_worker + ) # Merge the KV cache specs of all workers. Different PP stages may have # different layer names, and different TP ranks of the same PP stage should @@ -1254,37 +1315,42 @@ def get_kv_cache_configs(vllm_config: VllmConfig, else: assert merged_kv_cache_specs[layer_name] == layer_spec, ( "The KV cache specs for the same layer are different " - "across workers. This is not supported yet.") - global_kv_cache_groups = get_kv_cache_groups(vllm_config, - merged_kv_cache_specs) + "across workers. This is not supported yet." + ) + global_kv_cache_groups = get_kv_cache_groups(vllm_config, merged_kv_cache_specs) kv_cache_configs: list[KVCacheConfig] = [] for kv_cache_spec_one_worker, available_memory_one_worker in zip( - kv_cache_specs, available_memory): + kv_cache_specs, available_memory + ): kv_cache_groups_one_worker: list[KVCacheGroupSpec] = [] for group in global_kv_cache_groups: group_layer_names_one_worker = [ - layer_name for layer_name in group.layer_names + layer_name + for layer_name in group.layer_names if layer_name in kv_cache_spec_one_worker ] kv_cache_groups_one_worker.append( - KVCacheGroupSpec(group_layer_names_one_worker, - group.kv_cache_spec)) + KVCacheGroupSpec(group_layer_names_one_worker, group.kv_cache_spec) + ) assert sum( - len(group.layer_names) for group in - kv_cache_groups_one_worker) == len(kv_cache_spec_one_worker), ( - "Some layers are not assigned to any group.") + len(group.layer_names) for group in kv_cache_groups_one_worker + ) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group." kv_cache_configs.append( - get_kv_cache_config_from_groups(vllm_config, - kv_cache_groups_one_worker, - kv_cache_spec_one_worker, - available_memory_one_worker)) + get_kv_cache_config_from_groups( + vllm_config, + kv_cache_groups_one_worker, + kv_cache_spec_one_worker, + available_memory_one_worker, + ) + ) # Change the num_blocks of each rank to the smallest among all ranks. We # do not need to shrink the tensor size because it is valid to only use the # first `num_blocks` blocks of the tensor. 
- min_num_blocks = min(kv_cache_config.num_blocks - for kv_cache_config in kv_cache_configs) + min_num_blocks = min( + kv_cache_config.num_blocks for kv_cache_config in kv_cache_configs + ) for kv_cache_config in kv_cache_configs: kv_cache_config.num_blocks = min_num_blocks diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 74ff6261732c..968b4db530bf 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -12,7 +12,6 @@ class AsyncScheduler(Scheduler): - def _update_after_schedule( self, scheduler_output: SchedulerOutput, @@ -20,8 +19,10 @@ def _update_after_schedule( super()._update_after_schedule(scheduler_output) for req_id in scheduler_output.num_scheduled_tokens: request = self.requests[req_id] - if (request.num_computed_tokens == request.num_tokens + - request.num_output_placeholders): + if ( + request.num_computed_tokens + == request.num_tokens + request.num_output_placeholders + ): # The request will generate a new token in this scheduling step. # TODO(woosuk): Support speculative decoding. request.num_output_placeholders += 1 @@ -33,7 +34,8 @@ def _update_request_with_output( ) -> tuple[list[int], bool]: status_before_update = request.status new_token_ids, stopped = super()._update_request_with_output( - request, new_token_ids) + request, new_token_ids + ) # Update the number of output placeholders. request.num_output_placeholders -= len(new_token_ids) @@ -42,6 +44,6 @@ def _update_request_with_output( # Cache the new tokens. Preempted requests should be skipped. if status_before_update == RequestStatus.RUNNING: self.kv_cache_manager.cache_blocks( - request, - request.num_computed_tokens - request.num_output_placeholders) + request, request.num_computed_tokens - request.num_output_placeholders + ) return new_token_ids, stopped diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 5b1de3a66ceb..b92ef395e9b7 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -14,7 +14,6 @@ class SchedulerInterface(ABC): - @abstractmethod def schedule(self) -> "SchedulerOutput": """Schedule the requests to process in this scheduling step. @@ -72,7 +71,7 @@ def update_draft_token_ids( @abstractmethod def add_request(self, request: "Request") -> None: """Add a new request to the scheduler's internal queue. - + Args: request: The new request being added. """ @@ -91,7 +90,7 @@ def finish_requests( 1. When the request is aborted by the client. 2. When the frontend process detects a stop string of the request after de-tokenizing its generated tokens. - + Args: request_ids: A single or a list of request IDs. finished_status: The finished status of the given requests. 
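# --- Illustrative sketch (reviewer note, not part of the patch) -------------
# The truncation step in HybridKVCacheCoordinator.find_longest_cache_hit
# above, with made-up block sizes and hit lengths; plain ints stand in for
# KVCacheBlock objects.
full_attention_block_size = 16
other_block_size = 32

hit_blocks_full_attn = [[0, 1, 2, 3, 4]]  # 5 blocks * 16 tokens = 80-token hit
hit_blocks_other_attn = [[10, 11]]        # 2 blocks * 32 tokens = 64-token hit

# The final hit length is bounded by the other-attention hit ...
hit_length = len(hit_blocks_other_attn[0]) * other_block_size  # 64
# ... so the full-attention hit is truncated to the same number of tokens.
for group_hit_blocks in hit_blocks_full_attn:
    del group_hit_blocks[hit_length // full_attention_block_size :]
print(hit_blocks_full_attn, hit_length)  # [[0, 1, 2, 3]] 64
# ---------------------------------------------------------------------------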
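# --- Illustrative sketch (reviewer note, not part of the patch) -------------
# A simplified stand-in for the prefix-hash chaining in hash_block_tokens /
# request_block_hasher above, including the newly added early return when no
# new full block has formed. vLLM uses a configurable hash function (e.g.
# sha256_cbor) and mixes in extra keys (mm hashes, LoRA id, cache salt); here
# a plain repr() is hashed just to show the chaining, and BLOCK_SIZE is an
# assumed value.
import hashlib
from typing import Optional

BLOCK_SIZE = 4  # tokens per block; vLLM takes this from the cache config


def toy_hash_block(parent: Optional[bytes], tokens: tuple[int, ...]) -> bytes:
    # Each block hash covers the parent block hash plus this block's token
    # ids, so identical prefixes yield identical hash chains (prefix caching).
    return hashlib.sha256(repr((parent, tokens)).encode()).digest()


def toy_block_hasher(token_ids: list[int], existing: list[bytes]) -> list[bytes]:
    start = len(existing) * BLOCK_SIZE
    if start + BLOCK_SIZE > len(token_ids):
        # Mirrors the new early return: no new full block yet, nothing to hash.
        return []
    parent = existing[-1] if existing else None
    new_hashes: list[bytes] = []
    while start + BLOCK_SIZE <= len(token_ids):
        parent = toy_hash_block(parent, tuple(token_ids[start : start + BLOCK_SIZE]))
        new_hashes.append(parent)
        start += BLOCK_SIZE
    return new_hashes


hashes: list[bytes] = []
hashes += toy_block_hasher(list(range(6)), hashes)   # one full block hashed
hashes += toy_block_hasher(list(range(7)), hashes)   # early return, still one
hashes += toy_block_hasher(list(range(10)), hashes)  # second full block hashed
print(len(hashes))  # 2
# ---------------------------------------------------------------------------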
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 6874e713aff3..cbce91b990a1 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from vllm._bc_linter import bc_linter_include @@ -13,8 +13,7 @@ import numpy.typing as npt import torch - from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorMetadata) + from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams @@ -25,16 +24,15 @@ @bc_linter_include @dataclass class NewRequestData: - req_id: str - prompt_token_ids: Optional[list[int]] + prompt_token_ids: list[int] | None mm_features: list[MultiModalFeatureSpec] - sampling_params: Optional[SamplingParams] - pooling_params: Optional[PoolingParams] + sampling_params: SamplingParams | None + pooling_params: PoolingParams | None block_ids: tuple[list[int], ...] num_computed_tokens: int - lora_request: Optional[LoRARequest] - prompt_embeds: Optional[torch.Tensor] = None + lora_request: LoRARequest | None + prompt_embeds: torch.Tensor | None = None @classmethod def from_request( @@ -55,42 +53,43 @@ def from_request( ) def __repr__(self) -> str: - prompt_embeds_shape = (self.prompt_embeds.shape - if self.prompt_embeds else None) - return (f"NewRequestData(" - f"req_id={self.req_id}," - f"prompt_token_ids={self.prompt_token_ids}," - f"mm_features={self.mm_features}," - f"sampling_params={self.sampling_params}," - f"block_ids={self.block_ids}," - f"num_computed_tokens={self.num_computed_tokens}," - f"lora_request={self.lora_request}," - f"prompt_embeds_shape={prompt_embeds_shape}" - ")") + prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None + return ( + f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids={self.prompt_token_ids}," + f"mm_features={self.mm_features}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}," + f"prompt_embeds_shape={prompt_embeds_shape}" + ")" + ) # Version of __repr__ with the prompt data obfuscated def anon_repr(self) -> str: - prompt_token_ids_len = len( - self.prompt_token_ids - ) if self.prompt_token_ids is not None else None - prompt_embeds_shape = (self.prompt_embeds.shape - if self.prompt_embeds else None) - return (f"NewRequestData(" - f"req_id={self.req_id}," - f"prompt_token_ids_len={prompt_token_ids_len}," - f"mm_features={self.mm_features}," - f"sampling_params={self.sampling_params}," - f"block_ids={self.block_ids}," - f"num_computed_tokens={self.num_computed_tokens}," - f"lora_request={self.lora_request}," - f"prompt_embeds_shape={prompt_embeds_shape}" - ")") + prompt_token_ids_len = ( + len(self.prompt_token_ids) if self.prompt_token_ids is not None else None + ) + prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds else None + return ( + f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids_len={prompt_token_ids_len}," + f"mm_features={self.mm_features}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}," + f"prompt_embeds_shape={prompt_embeds_shape}" + ")" 
+ ) @bc_linter_include @dataclass class CachedRequestData: - req_ids: list[str] # If resumed_from_preemption is False, new_block_ids will be appended to # the request's block IDs. If True, new_block_ids will be used as the @@ -99,7 +98,7 @@ class CachedRequestData: # NOTE(woosuk): new_token_ids is only used for pipeline parallelism. # When PP is not used, new_token_ids will be empty. new_token_ids: list[list[int]] - new_block_ids: list[Optional[tuple[list[int], ...]]] + new_block_ids: list[tuple[list[int], ...] | None] num_computed_tokens: list[int] num_output_tokens: list[int] @@ -122,7 +121,6 @@ def make_empty(cls) -> CachedRequestData: @bc_linter_include @dataclass class SchedulerOutput: - # list of the requests that are scheduled for the first time. # We cache the request's data in each worker process, so that we don't # need to re-send it every scheduling step. @@ -162,7 +160,7 @@ class SchedulerOutput: # for filling the next token bitmask structured_output_request_ids: dict[str, int] # the bitmask for the whole batch - grammar_bitmask: Optional[npt.NDArray[np.int32]] + grammar_bitmask: npt.NDArray[np.int32] | None # KV Cache Connector metadata. - kv_connector_metadata: Optional[KVConnectorMetadata] = None + kv_connector_metadata: KVConnectorMetadata | None = None diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index fc2bc30b9a5f..33e5ec72ebd7 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -14,6 +14,7 @@ class SchedulingPolicy(Enum): """Enum for scheduling policies.""" + FCFS = "fcfs" PRIORITY = "priority" @@ -111,9 +112,7 @@ def remove_request(self, request: Request) -> None: def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the queue.""" requests_to_remove = set(requests) - filtered_requests = [ - req for req in self if req not in requests_to_remove - ] + filtered_requests = [req for req in self if req not in requests_to_remove] # deque does not support in-place filtering, so we need to clear # and extend self.clear() @@ -150,8 +149,7 @@ def __init__(self) -> None: def add_request(self, request: Request) -> None: """Add a request to the queue according to priority policy.""" - heapq.heappush(self._heap, - (request.priority, request.arrival_time, request)) + heapq.heappush(self._heap, (request.priority, request.arrival_time, request)) def pop_request(self) -> Request: """Pop a request from the queue according to priority policy.""" @@ -169,15 +167,15 @@ def peek_request(self) -> Request: def prepend_request(self, request: Request) -> None: """Add a request to the queue according to priority policy. - - Note: In a priority queue, there is no concept of prepending to the + + Note: In a priority queue, there is no concept of prepending to the front. Requests are ordered by (priority, arrival_time).""" self.add_request(request) def prepend_requests(self, requests: RequestQueue) -> None: """Add all requests from another queue according to priority policy. - - Note: In a priority queue, there is no concept of prepending to the + + Note: In a priority queue, there is no concept of prepending to the front. 
Requests are ordered by (priority, arrival_time).""" for request in requests: self.add_request(request) @@ -190,8 +188,9 @@ def remove_request(self, request: Request) -> None: def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the queue.""" requests_to_remove = set(requests) - self._heap = [(p, t, r) for p, t, r in self._heap - if r not in requests_to_remove] + self._heap = [ + (p, t, r) for p, t, r in self._heap if r not in requests_to_remove + ] heapq.heapify(self._heap) def __bool__(self) -> bool: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d4be1b06b3b2..d9a0ff1aa5c9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,29 +7,28 @@ import time from collections import defaultdict from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, - compute_encoder_budget) +from vllm.v1.core.encoder_cache_manager import ( + EncoderCacheManager, + compute_encoder_budget, +) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, - SchedulerOutput) -from vllm.v1.core.sched.request_queue import (SchedulingPolicy, - create_request_queue) +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all -from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, - EngineCoreOutputs) +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput @@ -41,7 +40,6 @@ class Scheduler(SchedulerInterface): - def __init__( self, vllm_config: VllmConfig, @@ -66,17 +64,18 @@ def __init__( # request ids should be included in the EngineCoreOutputs returned # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. - self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( - defaultdict(set) if include_finished_set else None) + self.finished_req_ids_dict: dict[int, set[str]] | None = ( + defaultdict(set) if include_finished_set else None + ) # Scheduling constraints. 
self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens + self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events) + and self.kv_events_config.enable_kv_cache_events + ) # Create KVConnector for the Scheduler. Note that each Worker # will have a corresponding KVConnector with Role=WORKER. @@ -85,12 +84,14 @@ def __init__( if self.vllm_config.kv_transfer_config is not None: assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " - "with KV connectors") + "with KV connectors" + ) assert not self.is_encoder_decoder, ( - "Encoder-decoder models are not currently supported " - "with KV connectors") + "Encoder-decoder models are not currently supported with KV connectors" + ) self.connector = KVConnectorFactory.create_connector( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + config=self.vllm_config, role=KVConnectorRole.SCHEDULER + ) self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -102,8 +103,7 @@ def __init__( self.block_size = self.cache_config.block_size - self.dcp_world_size = \ - vllm_config.parallel_config.decode_context_parallel_size + self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size # Note(hc): The scheduler’s block_size must be multiplied # by dcp_world_size, since block hashes are computed on the # original full token sequence at a granularity of @@ -120,7 +120,8 @@ def __init__( self.policy = SchedulingPolicy.FCFS else: raise ValueError( - f"Unknown scheduling policy: {self.scheduler_config.policy}") + f"Unknown scheduling policy: {self.scheduler_config.policy}" + ) # Priority queues for requests. self.waiting = create_request_queue(self.policy) self.running: list[Request] = [] @@ -153,8 +154,7 @@ def __init__( # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. - self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_size) + self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config self.use_eagle = False @@ -211,30 +211,35 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. num_new_tokens = min( - num_new_tokens, - self.max_model_len - 1 - request.num_computed_tokens) + num_new_tokens, self.max_model_len - request.num_computed_tokens + ) # Schedule encoder inputs. 
encoder_inputs_to_schedule = None new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, request.num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -257,7 +262,8 @@ def schedule(self) -> SchedulerOutput: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) + num_lookahead_tokens=self.num_lookahead_tokens, + ) if new_blocks is not None: # The request can be scheduled. @@ -282,8 +288,9 @@ def schedule(self) -> SchedulerOutput: preempted_req.num_computed_tokens = 0 preempted_req.num_preemptions += 1 if self.log_stats: - preempted_req.record_event(EngineCoreEventType.PREEMPTED, - scheduled_timestamp) + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) self.waiting.prepend_request(preempted_req) preempted_reqs.append(preempted_req) @@ -304,19 +311,21 @@ def schedule(self) -> SchedulerOutput: # Speculative decode related. if request.spec_token_ids: - num_scheduled_spec_tokens = (num_new_tokens + - request.num_computed_tokens - - request.num_tokens) + num_scheduled_spec_tokens = ( + num_new_tokens + request.num_computed_tokens - request.num_tokens + ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + request.spec_token_ids + ) # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -326,8 +335,10 @@ def schedule(self) -> SchedulerOutput: scheduled_loras: set[int] = set() if self.lora_config: scheduled_loras = set( - req.lora_request.lora_int_id for req in scheduled_running_reqs - if req.lora_request and req.lora_request.lora_int_id > 0) + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) assert len(scheduled_loras) <= self.lora_config.max_loras # Use a temporary RequestQueue to collect requests that need to be @@ -350,7 +361,8 @@ def schedule(self) -> SchedulerOutput: else: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) + request.request_id, + ) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -368,9 +380,14 @@ def schedule(self) -> SchedulerOutput: # Check that adding the request still respects the max_loras # constraint. - if (self.lora_config and request.lora_request and - (len(scheduled_loras) == self.lora_config.max_loras and - request.lora_request.lora_int_id not in scheduled_loras)): + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): # Scheduling would exceed max_loras, skip. 
self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) @@ -382,15 +399,17 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: num_external_computed_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens)) + request, num_new_local_computed_tokens + ) + ) if num_external_computed_tokens is None: # The request cannot be scheduled because @@ -401,13 +420,15 @@ def schedule(self) -> SchedulerOutput: continue # Total computed tokens (local + external). - num_computed_tokens = (num_new_local_computed_tokens + - num_external_computed_tokens) + num_computed_tokens = ( + num_new_local_computed_tokens + num_external_computed_tokens + ) # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. else: new_computed_blocks = ( - self.kv_cache_manager.create_empty_block_list()) + self.kv_cache_manager.create_empty_block_list() + ) num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens @@ -424,15 +445,21 @@ def schedule(self) -> SchedulerOutput: # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): + if ( + 0 + < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens + ): num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + self.scheduler_config.long_prefill_token_threshold + ) # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if not self.scheduler_config.chunked_prefill_enabled and \ - num_new_tokens > token_budget: + if ( + not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget + ): self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -442,11 +469,16 @@ def schedule(self) -> SchedulerOutput: # Schedule encoder inputs. if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -456,9 +488,9 @@ def schedule(self) -> SchedulerOutput: # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. - effective_lookahead_tokens = (0 if request.num_computed_tokens - == 0 else - self.num_lookahead_tokens) + effective_lookahead_tokens = ( + 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens + ) # Determine if we need to allocate cross-attention blocks. if self.is_encoder_decoder and request.has_encoder_inputs: @@ -466,8 +498,9 @@ def schedule(self) -> SchedulerOutput: # always padded to the maximum length. 
If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. - num_encoder_tokens =\ + num_encoder_tokens = ( self.scheduler_config.max_num_encoder_input_tokens + ) else: num_encoder_tokens = 0 @@ -509,20 +542,21 @@ def schedule(self) -> SchedulerOutput: req_index += 1 self.running.append(request) if self.log_stats: - request.record_event(EngineCoreEventType.SCHEDULED, - scheduled_timestamp) + request.record_event( + EngineCoreEventType.SCHEDULED, scheduled_timestamp + ) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: - raise RuntimeError( - f"Invalid request status: {request.status}") + raise RuntimeError(f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_blocks[request.request_id] = ( - self.kv_cache_manager.get_blocks(request.request_id)) + self.kv_cache_manager.get_blocks(request.request_id) + ) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -533,7 +567,8 @@ def schedule(self) -> SchedulerOutput: # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -551,23 +586,26 @@ def schedule(self) -> SchedulerOutput: # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) <= len(self.running)) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs + ) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] num_common_prefix_blocks = ( self.kv_cache_manager.get_num_common_prefix_blocks( - any_request, len(self.running))) + any_request, len(self.running) + ) + ) # Construct the scheduler output. 
new_reqs_data = [ NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) + req, req_to_new_blocks[req.request_id].get_block_ids() + ) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -577,11 +615,12 @@ def schedule(self) -> SchedulerOutput: scheduled_spec_decode_tokens, req_to_new_blocks, ) - scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + - scheduled_resumed_reqs) - structured_output_request_ids, grammar_bitmask = ( - self.get_grammar_bitmask(scheduled_requests, - scheduled_spec_decode_tokens)) + scheduled_requests = ( + scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs + ) + structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( + scheduled_requests, scheduled_spec_decode_tokens + ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -595,8 +634,7 @@ def schedule(self) -> SchedulerOutput: # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_mm_hashes=self.encoder_cache_manager. - get_freed_mm_hashes(), + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -670,7 +708,7 @@ def _make_cached_request_data( ) -> CachedRequestData: req_ids: list[str] = [] new_token_ids: list[list[int]] = [] - new_block_ids: list[Optional[tuple[list[int], ...]]] = [] + new_block_ids: list[tuple[list[int], ...] | None] = [] num_computed_tokens: list[int] = [] num_output_tokens: list[int] = [] @@ -678,16 +716,18 @@ def _make_cached_request_data( for req in itertools.chain(running_reqs, resumed_reqs): req_id = req.request_id req_ids.append(req_id) - num_tokens = (num_scheduled_tokens[req_id] - - len(spec_decode_tokens.get(req_id, ()))) + num_tokens = num_scheduled_tokens[req_id] - len( + spec_decode_tokens.get(req_id, ()) + ) if self.use_pp: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. Otherwise, we don't # need to send the sampled tokens back because the model runner # will cache them. - token_ids = req.all_token_ids[req.num_computed_tokens:req. - num_computed_tokens + num_tokens] + token_ids = req.all_token_ids[ + req.num_computed_tokens : req.num_computed_tokens + num_tokens + ] new_token_ids.append(token_ids) elif use_connector: # When using a KVConnector, we add a placeholder to avoid index @@ -695,7 +735,8 @@ def _make_cached_request_data( # is updated to handle token IDs properly. new_token_ids.append([]) new_block_ids.append( - req_to_new_blocks[req_id].get_block_ids(allow_none=True)) + req_to_new_blocks[req_id].get_block_ids(allow_none=True) + ) num_computed_tokens.append(req.num_computed_tokens) num_output_tokens.append(len(req.output_token_ids)) # Because resumed_reqs is usually empty, it is more efficient to do @@ -764,7 +805,8 @@ def _try_schedule_encoder_inputs( if self.is_encoder_decoder and num_computed_tokens > 0: assert start_pos == 0, ( "Encoder input should be processed at the beginning of " - "the sequence when encoder-decoder models are used.") + "the sequence when encoder-decoder models are used." + ) # Encoder input has already been computed # The calculation here is a bit different. 
We don't turn encoder # output into tokens that get processed by the decoder and @@ -788,8 +830,7 @@ def _try_schedule_encoder_inputs( # current step. continue - if self.encoder_cache_manager.check_and_update_cache( - request, i): + if self.encoder_cache_manager.check_and_update_cache(request, i): # The encoder input is already computed and cached from a # previous step. continue @@ -797,16 +838,18 @@ def _try_schedule_encoder_inputs( # If no encoder input chunking is allowed, we do not want to # partially schedule a multimodal item. If the scheduled range would # only cover part of the mm input, roll back to before the mm item. - if (self.scheduler_config.disable_chunked_mm_input - and num_computed_tokens < start_pos - and (num_computed_tokens + num_new_tokens) - < (start_pos + num_encoder_tokens)): + if ( + self.scheduler_config.disable_chunked_mm_input + and num_computed_tokens < start_pos + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens) + ): num_new_tokens = start_pos - num_computed_tokens break if not self.encoder_cache_manager.can_allocate( - request, i, encoder_compute_budget, - num_tokens_to_schedule): + request, i, encoder_compute_budget, num_tokens_to_schedule + ): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses @@ -878,9 +921,10 @@ def update_from_output( kv_connector_output = model_runner_output.kv_connector_output outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) - spec_decoding_stats: Optional[SpecDecodingStats] = None - kv_connector_stats = (kv_connector_output.kv_connector_stats - if kv_connector_output else None) + spec_decoding_stats: SpecDecodingStats | None = None + kv_connector_stats = ( + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) failed_kv_load_req_ids = None if kv_connector_output and kv_connector_output.invalid_block_ids: @@ -888,7 +932,8 @@ def update_from_output( # load. Identify affected requests and adjust their computed token # count to trigger recomputation of the invalid blocks. failed_kv_load_req_ids = self._handle_invalid_blocks( - kv_connector_output.invalid_block_ids) + kv_connector_output.invalid_block_ids + ) # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. We should do our best @@ -908,11 +953,13 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[ - req_index] if sampled_token_ids else [] + generated_token_ids = ( + sampled_token_ids[req_index] if sampled_token_ids else [] + ) scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + scheduler_output.scheduled_spec_decode_tokens.get(req_id) + ) if scheduled_spec_token_ids: num_draft_tokens = len(scheduled_spec_token_ids) num_accepted = len(generated_token_ids) - 1 @@ -926,7 +973,8 @@ def update_from_output( spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted) + num_accepted_tokens=num_accepted, + ) stopped = False new_logprobs = None @@ -937,14 +985,14 @@ def update_from_output( # Check for stop and update request status. if new_token_ids: new_token_ids, stopped = self._update_request_with_output( - request, new_token_ids) + request, new_token_ids + ) # Stop checking for pooler models. 
pooler_output = None if pooler_outputs: pooler_output = pooler_outputs[req_index] - stopped = check_stop(request, self.max_model_len, - pooler_output) + stopped = check_stop(request, self.max_model_len, pooler_output) if stopped: kv_transfer_params = self._free_request(request) @@ -954,28 +1002,29 @@ def update_from_output( stopped_preempted_reqs.add(request) # Extract sample logprobs if needed. - if request.sampling_params is not None \ - and request.sampling_params.logprobs is not None and logprobs: + if ( + request.sampling_params is not None + and request.sampling_params.logprobs is not None + and logprobs + ): # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and self.structured_output_manager.should_advance( - request): + if new_token_ids and self.structured_output_manager.should_advance(request): # NOTE: structured_output_request # should not be None if use_structured_output, we have # checked above, so safe to ignore type warning request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + req_id, new_token_ids + ) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ - or kv_transfer_params: - + if new_token_ids or pooler_output is not None or kv_transfer_params: # Add EngineCoreOutput for this Request. outputs[request.client_index].append( EngineCoreOutput( @@ -990,7 +1039,8 @@ def update_from_output( kv_transfer_params=kv_transfer_params, trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, - )) + ) + ) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -1023,11 +1073,13 @@ def update_from_output( eco.finished_requests = finished_set else: engine_core_outputs[client_index] = EngineCoreOutputs( - finished_requests=finished_set) + finished_requests=finished_set + ) finished_req_ids.clear() - if (stats := self.make_stats(spec_decoding_stats, - kv_connector_stats)) is not None: + if ( + stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + ) is not None: # Return stats to only one of the front-ends. if (eco := next(iter(engine_core_outputs.values()), None)) is None: # We must return the stats even if there are no request @@ -1058,8 +1110,9 @@ def _update_request_with_output( return new_token_ids, stopped def _free_encoder_inputs(self, request: Request) -> None: - cached_encoder_input_ids = ( - self.encoder_cache_manager.get_cached_input_ids(request)) + cached_encoder_input_ids = self.encoder_cache_manager.get_cached_input_ids( + request + ) # OPTIMIZATION: Avoid list(set) if the set is empty. if not cached_encoder_input_ids: return @@ -1074,21 +1127,19 @@ def _free_encoder_inputs(self, request: Request) -> None: # With Whisper, as soon as we've generated a single token, # we know we're done with the encoder input. Cross Attention # KVs have been calculated and cached already. - self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) elif start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. 
- self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, ) -> None: for req_id, spec_token_ids in zip( - draft_token_ids.req_ids, - draft_token_ids.draft_token_ids, + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, ): request = self.requests.get(req_id) if request is None or request.is_finished(): @@ -1102,7 +1153,8 @@ def update_draft_token_ids( elif self.structured_output_manager.should_advance(request): metadata = request.structured_output_request request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids) + spec_token_ids + ) else: request.spec_token_ids = spec_token_ids @@ -1128,7 +1180,7 @@ def finish_requests( """ assert RequestStatus.is_finished(finished_status) if isinstance(request_ids, str): - request_ids = (request_ids, ) + request_ids = (request_ids,) else: request_ids = set(request_ids) @@ -1160,7 +1212,7 @@ def finish_requests( request.status = finished_status self._free_request(request) - def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + def _free_request(self, request: Request) -> dict[str, Any] | None: assert request.is_finished() delay_free_blocks, kv_xfer_params = self._connector_finished(request) @@ -1191,36 +1243,36 @@ def reset_prefix_cache(self) -> bool: def make_stats( self, - spec_decoding_stats: Optional[SpecDecodingStats] = None, - kv_connector_stats: Optional[KVConnectorStats] = None, - ) -> Optional[SchedulerStats]: + spec_decoding_stats: SpecDecodingStats | None = None, + kv_connector_stats: KVConnectorStats | None = None, + ) -> SchedulerStats | None: if not self.log_stats: return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None - return SchedulerStats(num_running_reqs=len(self.running), - num_waiting_reqs=len(self.waiting), - kv_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=prefix_cache_stats, - spec_decoding_stats=spec_decoding_stats, - num_corrupted_reqs=sum(req.is_output_corrupted - for req in self.running), - kv_connector_stats=kv_connector_stats.data - if kv_connector_stats else None) + return SchedulerStats( + num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + kv_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=prefix_cache_stats, + spec_decoding_stats=spec_decoding_stats, + num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running), + kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, + ) def make_spec_decoding_stats( self, - spec_decoding_stats: Optional[SpecDecodingStats], + spec_decoding_stats: SpecDecodingStats | None, num_draft_tokens: int, num_accepted_tokens: int, - ) -> Optional[SpecDecodingStats]: + ) -> SpecDecodingStats | None: if not self.log_stats: return None if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) spec_decoding_stats.observe_draft( - num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens + ) return spec_decoding_stats def shutdown(self) -> None: @@ -1233,11 +1285,12 @@ def shutdown(self) -> None: # KV Connector Related Methods ######################################################################## - def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: + def get_kv_connector(self) -> 
KVConnectorBase_V1 | None: return self.connector def _connector_finished( - self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: + self, request: Request + ) -> tuple[bool, dict[str, Any] | None]: """ Invoke the KV connector request_finished() method if applicable. @@ -1247,7 +1300,7 @@ def _connector_finished( if self.connector is None: return False, None - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: @@ -1271,8 +1324,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: # updated in _update_requests_with_invalid_blocks if request.num_computed_tokens: # Cache any valid computed tokens. - self.kv_cache_manager.cache_blocks(request, - request.num_computed_tokens) + self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens) else: # No valid computed tokens, release allocated blocks. # There may be a local cache hit on retry. @@ -1281,8 +1333,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: self.failed_recving_kv_req_ids.remove(request.request_id) else: # Now that the blocks are ready, actually cache them. - (block_ids, ) = self.kv_cache_manager.get_block_ids( - request.request_id) + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) num_computed_tokens = len(block_ids) * self.block_size # Handle the case where num request tokens less than one block. num_computed_tokens = min(num_computed_tokens, request.num_tokens) @@ -1298,8 +1349,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: self.finished_recving_kv_req_ids.remove(request.request_id) return True - def _update_from_kv_xfer_finished(self, - kv_connector_output: KVConnectorOutput): + def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """ KV Connector: update the scheduler state based on the output. @@ -1314,21 +1364,23 @@ def _update_from_kv_xfer_finished(self, self.connector.update_connector_output(kv_connector_output) # KV Connector:: update recv and send status from last step. - for req_id in (kv_connector_output.finished_recving or ()): + for req_id in kv_connector_output.finished_recving or (): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id) - for req_id in (kv_connector_output.finished_sending or ()): + for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) if req_id not in self.requests: logger.warning( "Got finished sending KV transfer for request %s," - "but the request is already freed.", req_id) + "but the request is already freed.", + req_id, + ) else: self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( - self, requests: Iterable[Request], - invalid_block_ids: set[int]) -> tuple[set[str], int]: + self, requests: Iterable[Request], invalid_block_ids: set[int] + ) -> tuple[set[str], int]: """ Identify and update requests affected by invalid KV cache blocks. 
@@ -1359,25 +1411,25 @@ def _update_requests_with_invalid_blocks( marked_invalid_block = False req_id = request.request_id # TODO (davidb): add support for hybrid memory allocator - (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id) + (req_block_ids,) = self.kv_cache_manager.get_block_ids(req_id) # We iterate only over blocks that may contain externally computed # tokens if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: # Async loading. If num_computed_tokens is set it implies we # already processed some block failures for it in a prior step req_num_computed_tokens = ( - request.num_computed_tokens if req_id - in self.failed_recving_kv_req_ids else len(req_block_ids) * - self.block_size) + request.num_computed_tokens + if req_id in self.failed_recving_kv_req_ids + else len(req_block_ids) * self.block_size + ) else: # Sync loading. num_computed_tokens includes new tokens req_num_computed_tokens = request.num_cached_tokens - req_num_computed_blocks = (req_num_computed_tokens + - self.block_size - 1) // self.block_size - for idx, block_id in zip(range(req_num_computed_blocks), - req_block_ids): - + req_num_computed_blocks = ( + req_num_computed_tokens + self.block_size - 1 + ) // self.block_size + for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids): if block_id not in invalid_block_ids: continue @@ -1402,8 +1454,9 @@ def _update_requests_with_invalid_blocks( marked_invalid_block = True # Truncate the computed tokens at the first failed block request.num_computed_tokens = idx * self.block_size - total_affected_tokens += (req_num_computed_tokens - - request.num_computed_tokens) + total_affected_tokens += ( + req_num_computed_tokens - request.num_computed_tokens + ) if is_affected: if not marked_invalid_block: @@ -1412,8 +1465,9 @@ def _update_requests_with_invalid_blocks( # Revert to considering only cached tokens as computed. 
# Currently this only applies to sync loading; Async # loading does not yet support block sharing - total_affected_tokens += (request.num_computed_tokens - - request.num_cached_tokens) + total_affected_tokens += ( + request.num_computed_tokens - request.num_cached_tokens + ) request.num_computed_tokens = request.num_cached_tokens affected_req_ids.add(request.request_id) @@ -1426,11 +1480,15 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- async_load_reqs = ( - req for req in self.waiting - if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + req + for req in self.waiting + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + ) async_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(async_load_reqs, - invalid_block_ids)) + self._update_requests_with_invalid_blocks( + async_load_reqs, invalid_block_ids + ) + ) total_requests_to_reschedule += len(async_affected_req_ids) total_tokens_to_reschedule += num_tokens_to_reschedule @@ -1441,8 +1499,8 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: # --- Handle sync KV loads (running requests) --- sync_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(self.running, - invalid_block_ids)) + self._update_requests_with_invalid_blocks(self.running, invalid_block_ids) + ) total_requests_to_reschedule += len(sync_affected_req_ids) total_tokens_to_reschedule += num_tokens_to_reschedule @@ -1451,7 +1509,9 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: logger.warning( "Recovered from KV load failure: " "%d request(s) rescheduled (%d tokens affected).", - total_requests_to_reschedule, total_tokens_to_reschedule) + total_requests_to_reschedule, + total_tokens_to_reschedule, + ) # Return the IDs of affected running requests to skip in # update_from_output. 
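Editor's note: the KV-load-failure recovery path in the scheduler diff above maps a request's computed tokens to blocks with a ceiling division and truncates num_computed_tokens at the first invalid block so those tokens are recomputed. A minimal standalone sketch of that bookkeeping follows; the helper name and the literal values in the usage example are illustrative only, not part of the patch (the real logic lives in Scheduler._update_requests_with_invalid_blocks).

def truncate_at_first_invalid_block(
    num_computed_tokens: int,
    req_block_ids: list[int],
    invalid_block_ids: set[int],
    block_size: int,
) -> int:
    """Return the adjusted computed-token count after dropping invalid blocks."""
    # Ceiling division: how many blocks hold the computed tokens.
    num_computed_blocks = (num_computed_tokens + block_size - 1) // block_size
    for idx, block_id in zip(range(num_computed_blocks), req_block_ids):
        if block_id in invalid_block_ids:
            # Everything from the first invalid block onward must be recomputed.
            return idx * block_size
    return num_computed_tokens

# 300 computed tokens span 3 blocks of size 128; block id 8 failed to load,
# so only the first full block (128 tokens) remains counted as computed.
assert truncate_at_first_invalid_block(300, [7, 8, 9], {8}, 128) == 128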
diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index c431843de6ba..0979100ed325 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -40,11 +40,13 @@ def remove_all(lst: list, items_to_remove: set) -> list: return [item for item in lst if item not in items_to_remove] -def check_stop(request: Request, - max_model_len: int, - pooler_output: Optional[torch.Tensor] = None) -> bool: - if (request.num_tokens >= max_model_len - or request.num_output_tokens >= request.max_tokens): +def check_stop( + request: Request, max_model_len: int, pooler_output: Optional[torch.Tensor] = None +) -> bool: + if ( + request.num_tokens > max_model_len + or request.num_output_tokens >= request.max_tokens + ): request.status = RequestStatus.FINISHED_LENGTH_CAPPED return True @@ -57,8 +59,7 @@ def check_stop(request: Request, sampling_params = request.sampling_params assert sampling_params is not None last_token_id = request.output_token_ids[-1] - if (not sampling_params.ignore_eos - and last_token_id == request.eos_token_id): + if not sampling_params.ignore_eos and last_token_id == request.eos_token_id: request.status = RequestStatus.FINISHED_STOPPED return True diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 27ea1c4db2a5..d624ff1b3dcc 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -7,16 +7,21 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock -from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, - CrossAttentionSpec, FullAttentionSpec, - KVCacheSpec, MambaSpec, - MLAAttentionSpec, SlidingWindowSpec) +from vllm.v1.kv_cache_interface import ( + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + FullAttentionSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SlidingWindowSpec, +) from vllm.v1.request import Request class SingleTypeKVCacheManager(ABC): """ - An abstract base class for a manager that handle the kv cache management + An abstract base class for a manager that handle the kv cache management logic of one specific type of attention layer. """ @@ -44,8 +49,7 @@ def __init__( # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request # is finished. - self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) + self.req_to_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list) # {req_id: The number of cached blocks for this given request} # This is used to track the number of cached blocks for each request. @@ -57,14 +61,14 @@ def __init__( self._null_block = block_pool.null_block def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: list[KVCacheBlock]) -> int: + self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + ) -> int: """ Get the number of blocks needed to be allocated for the request. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). new_computed_blocks: The new computed blocks just hitting the prefix caching. 
@@ -74,20 +78,23 @@ def get_num_blocks_to_allocate( """ num_required_blocks = cdiv(num_tokens, self.block_size) - num_new_blocks = (num_required_blocks - len(new_computed_blocks) - - len(self.req_to_blocks[request_id])) + num_new_blocks = ( + num_required_blocks + - len(new_computed_blocks) + - len(self.req_to_blocks[request_id]) + ) # If a computed block of a request is an eviction candidate (in the # free queue and ref_cnt == 0), it will be changed from a free block # to a computed block when the request is allocated, so we also count # it as needed to be allocated. num_evictable_computed_blocks = sum( - blk.ref_cnt == 0 and not blk.is_null - for blk in new_computed_blocks) + blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks + ) return num_new_blocks + num_evictable_computed_blocks def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: + self, request_id: str, new_computed_blocks: list[KVCacheBlock] + ) -> None: """ Add the new computed blocks to the request. @@ -106,15 +113,16 @@ def save_new_computed_blocks( # A running request. Should not have new computed blocks. assert len(new_computed_blocks) == 0 - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: + def allocate_new_blocks( + self, request_id: str, num_tokens: int + ) -> list[KVCacheBlock]: """ - Allocate new blocks for the request to give it at least `num_tokens` + Allocate new blocks for the request to give it at least `num_tokens` token slots. Args: request_id: The request ID. - num_tokens: The total number of tokens that need a slot (including + num_tokens: The total number of tokens that need a slot (including tokens that are already allocated). Returns: @@ -136,7 +144,7 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None: Args: request: The request. - num_tokens: The total number of tokens that need to be cached + num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ num_cached_blocks = self.num_cached_block[request.request_id] @@ -174,8 +182,9 @@ def free(self, request_id: str) -> None: self.num_cached_block.pop(request_id, None) @abstractmethod - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ Get the number of common prefix blocks for all requests in the RUNNING state. @@ -205,12 +214,12 @@ def find_longest_cache_hit( dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: """ - Get the longest cache hit prefix of the blocks that is not longer than - `max_length`. The prefix should be a common prefix hit for all the - kv cache groups in `kv_cache_group_ids`. If no cache hit is found, - return an empty list. - If eagle is enabled, drop the last matched block to force recompute the - last block to get the required hidden states for eagle drafting head. + Get the longest cache hit prefix of the blocks that is not longer than + `max_length`. The prefix should be a common prefix hit for all the + kv cache groups in `kv_cache_group_ids`. If no cache hit is found, + return an empty list. + If eagle is enabled, drop the last matched block to force recompute the + last block to get the required hidden states for eagle drafting head. Need to be customized for each attention type. 
Args: @@ -235,10 +244,9 @@ def find_longest_cache_hit( raise NotImplementedError @abstractmethod - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: """ - Remove the blocks that are no longer needed from `blocks` and free the + Remove the blocks that are no longer needed from `blocks` and free the blocks. The removed blocks should be replaced by null_block. Need to be customized for each attention type. @@ -250,7 +258,6 @@ def remove_skipped_blocks(self, request_id: str, class FullAttentionManager(SingleTypeKVCacheManager): - @classmethod def find_longest_cache_hit( cls, @@ -264,10 +271,13 @@ def find_longest_cache_hit( ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) - ), "FullAttentionManager can only be used for full attention " \ + ), ( + "FullAttentionManager can only be used for full attention " "and chunked local attention groups" + ) computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(len(kv_cache_group_ids))) + [] for _ in range(len(kv_cache_group_ids)) + ) block_size = kv_cache_spec.block_size if dcp_world_size > 1: block_size *= dcp_world_size @@ -277,7 +287,8 @@ def find_longest_cache_hit( # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: @@ -287,13 +298,13 @@ def find_longest_cache_hit( computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # No need to remove blocks for full attention. pass - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: blocks = self.req_to_blocks[request_id] num_common_blocks = 0 for block in blocks: @@ -305,9 +316,9 @@ def get_num_common_prefix_blocks(self, request_id: str, class SlidingWindowManager(SingleTypeKVCacheManager): - - def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, - **kwargs) -> None: + def __init__( + self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs + ) -> None: super().__init__(kv_cache_spec, block_pool, **kwargs) self.sliding_window = kv_cache_spec.sliding_window self._null_block = block_pool.null_block @@ -324,13 +335,15 @@ def find_longest_cache_hit( dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( - "SlidingWindowManager can only be used for sliding window groups") + "SlidingWindowManager can only be used for sliding window groups" + ) assert dcp_world_size == 1, "DCP not support sliding window attn now." # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window sliding_window_contiguous_blocks = cdiv( - kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size) + kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size + ) if use_eagle: # Need to drop the last matched block if eagle is enabled. 
For # sliding window layer, we achieve this by increasing the number of @@ -344,14 +357,17 @@ def find_longest_cache_hit( # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. max_num_blocks = max_length // kv_cache_spec.block_size - computed_blocks = tuple([block_pool.null_block] * max_num_blocks - for _ in range(len(kv_cache_group_ids))) + computed_blocks = tuple( + [block_pool.null_block] * max_num_blocks + for _ in range(len(kv_cache_group_ids)) + ) num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): if cached_block := block_pool.get_cached_block( - block_hashes[i], kv_cache_group_ids): + block_hashes[i], kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed[i] = cached num_contiguous_blocks += 1 @@ -360,7 +376,7 @@ def find_longest_cache_hit( # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] # when sliding_window_contiguous_blocks=2. for computed in computed_blocks: - del computed[i + num_contiguous_blocks:] + del computed[i + num_contiguous_blocks :] match_found = True break else: @@ -375,8 +391,7 @@ def find_longest_cache_hit( computed.pop() return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the sliding window and # skipped during the attention computation. last_useful_token = num_computed_tokens - self.sliding_window + 1 @@ -393,21 +408,22 @@ def remove_skipped_blocks(self, request_id: str, blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ NOTE(Chen): The prefix blocks are null blocks for sliding window layers. - So it's not correct to count ref_cnt like FullAttentionManager. Return - 0 here for correctness. Need to support cascade attention + sliding + So it's not correct to count ref_cnt like FullAttentionManager. Return + 0 here for correctness. Need to support cascade attention + sliding window in the future. """ return 0 class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): - - def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, - block_pool: BlockPool, **kwargs) -> None: + def __init__( + self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs + ) -> None: super().__init__(kv_cache_spec, block_pool, **kwargs) self.attention_chunk_size = kv_cache_spec.attention_chunk_size self._null_block = block_pool.null_block @@ -428,19 +444,19 @@ def find_longest_cache_hit( prefix of the blocks that is not longer than `max_length`. The prefix should be a common prefix hit for all the kv cache groups in `kv_cache_group_ids`. If no cache hit is found, return an empty list. - note we mark as computed if the whole block is outside of the local + note we mark as computed if the whole block is outside of the local window, and set the block as null. Examples: 1. Attention chunk size of 8, block size of 4, max length of 15 - for next token at 15th (zero-indexed), 8th - 14th tokens are in - the window(needs lookup), 0th - 7th are not in the window, - so they are already marked as computed. 
We check the complete - block3 (8th - 11th tokens), Assume block 3 is hit, we will return + for next token at 15th (zero-indexed), 8th - 14th tokens are in + the window(needs lookup), 0th - 7th are not in the window, + so they are already marked as computed. We check the complete + block3 (8th - 11th tokens), Assume block 3 is hit, we will return [null, null, block 3], otherwise, we return [null, null] 2. Attention chunk size of 8, block size of 4, max length of 16 - for next token at 16th (zero-indexed), 0th - 15th tokens are not - in the window, so they are already marked as computed. + for next token at 16th (zero-indexed), 0th - 15th tokens are not + in the window, so they are already marked as computed. we return 4 blocks[null, null, null, null] Args: @@ -455,39 +471,45 @@ def find_longest_cache_hit( A list of cached blocks """ assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), ( - "ChunkedLocalAttentionManager can only be used for " + - "chunked local attention groups") - assert use_eagle is False, ("Hybrid KV cache is not supported for " + - "eagle + chunked local attention.") + "ChunkedLocalAttentionManager can only be used for " + + "chunked local attention groups" + ) + assert use_eagle is False, ( + "Hybrid KV cache is not supported for " + "eagle + chunked local attention." + ) assert dcp_world_size == 1, "DCP not support chunked local attn now." max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: - local_attention_start_idx = (max_length // - kv_cache_spec.attention_chunk_size * - kv_cache_spec.attention_chunk_size) + local_attention_start_idx = ( + max_length + // kv_cache_spec.attention_chunk_size + * kv_cache_spec.attention_chunk_size + ) else: local_attention_start_idx = 0 # we marked blocks out of window as computed # with null blocks, and blocks inside window based on cache lookup # result [null] [null] ... [null] [hit block 1 (1st block contain # last window)] [hit block 2] ... [hit block x] - local_attention_start_block_idx = (local_attention_start_idx // - kv_cache_spec.block_size) + local_attention_start_block_idx = ( + local_attention_start_idx // kv_cache_spec.block_size + ) computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [block_pool.null_block] * local_attention_start_block_idx - for _ in range(len(kv_cache_group_ids))) + for _ in range(len(kv_cache_group_ids)) + ) for i in range(local_attention_start_block_idx, max_num_blocks): block_hash = block_hashes[i] if cached_block := block_pool.get_cached_block( - block_hash, kv_cache_group_ids): + block_hash, kv_cache_group_ids + ): for computed, cached in zip(computed_blocks, cached_block): computed.append(cached) else: break return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the chunked attention # window and skipped during the attention computation. @@ -499,13 +521,14 @@ def remove_skipped_blocks(self, request_id: str, # is 1024. for 1023, it will be 0. 
num_cached_block = self.num_cached_block.get(request_id, 0) local_attention_start_idx = ( - num_computed_tokens - ) // self.attention_chunk_size * self.attention_chunk_size + (num_computed_tokens) + // self.attention_chunk_size + * self.attention_chunk_size + ) first_useful_block_idx = local_attention_start_idx // self.block_size if num_cached_block > 0: # Make sure we don't delete the last cached block - first_useful_block_idx = min(first_useful_block_idx, - num_cached_block - 1) + first_useful_block_idx = min(first_useful_block_idx, num_cached_block - 1) # if block size = 128, 0 -> block 0, 1024 (= 128 * 8) -> # block 8, 372 (= 128 * 2 + 116) -> block 2 blocks = self.req_to_blocks[request_id] @@ -521,8 +544,9 @@ def remove_skipped_blocks(self, request_id: str, blocks[i] = self._null_block self.block_pool.free_blocks(removed_blocks) - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: """ cascade attention is not supported by chunked local attention. """ @@ -530,7 +554,6 @@ def get_num_common_prefix_blocks(self, request_id: str, class MambaManager(SingleTypeKVCacheManager): - @classmethod def find_longest_cache_hit( cls, @@ -542,46 +565,71 @@ def find_longest_cache_hit( use_eagle: bool, dcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: - assert isinstance( - kv_cache_spec, - MambaSpec), ("MambaManager can only be used for mamba groups") + assert isinstance(kv_cache_spec, MambaSpec), ( + "MambaManager can only be used for mamba groups" + ) assert dcp_world_size == 1, "DCP not support mamba now." - # Prefix caching is not supported for mamba now. Always return empty - # list. computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( - [] for _ in range(len(kv_cache_group_ids))) + [] for _ in range(len(kv_cache_group_ids)) + ) + + max_num_blocks = max_length // kv_cache_spec.block_size + # Search from right to left and early stop when a match is found. + for i in range(max_num_blocks - 1, -1, -1): + if cached_block := block_pool.get_cached_block( + block_hashes[i], kv_cache_group_ids + ): + for computed, cached in zip(computed_blocks, cached_block): + # the hit length logic later assumes: + # hit_length = len(hit_blocks_other_attn[0]) + # * self.other_block_size + # so we insert dummy blocks at the beginning: + computed.extend([block_pool.null_block] * i) + computed.append(cached) + break # we just need the last match - early stopping + return computed_blocks - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: - # Each request will always have 1 block at this moment, so no need to - # remove blocks. + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: + # Here unused blocks may be freed up for running requests. 
+ # TODO(@s3woz) Free up all blocks that aren't needed by Mamba2 + # (for which find_longest_cache_hit returns block_pool.null_block) pass - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: + """ + cascade attention is not supported by mamba + """ return 0 def get_num_blocks_to_allocate( - self, request_id: str, num_tokens: int, - new_computed_blocks: list[KVCacheBlock]) -> int: + self, request_id: str, num_tokens: int, new_computed_blocks: list[KVCacheBlock] + ) -> int: # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. assert isinstance(self.kv_cache_spec, MambaSpec) if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += (self.kv_cache_spec.block_size * - self.kv_cache_spec.num_speculative_blocks) - return super().get_num_blocks_to_allocate(request_id, num_tokens, - new_computed_blocks) + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) + return super().get_num_blocks_to_allocate( + request_id, num_tokens, new_computed_blocks + ) - def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[KVCacheBlock]: + def allocate_new_blocks( + self, request_id: str, num_tokens: int + ) -> list[KVCacheBlock]: # Allocate extra `num_speculative_blocks` blocks for # speculative decoding (MTP/EAGLE) with linear attention. assert isinstance(self.kv_cache_spec, MambaSpec) if self.kv_cache_spec.num_speculative_blocks > 0: - num_tokens += (self.kv_cache_spec.block_size * - self.kv_cache_spec.num_speculative_blocks) + num_tokens += ( + self.kv_cache_spec.block_size + * self.kv_cache_spec.num_speculative_blocks + ) return super().allocate_new_blocks(request_id, num_tokens) @@ -589,8 +637,8 @@ class CrossAttentionManager(SingleTypeKVCacheManager): """Manager for cross-attention KV cache in encoder-decoder models.""" def save_new_computed_blocks( - self, request_id: str, - new_computed_blocks: list[KVCacheBlock]) -> None: + self, request_id: str, new_computed_blocks: list[KVCacheBlock] + ) -> None: # We do not cache blocks for cross-attention to be shared between # requests, so `new_computed_blocks` should always be empty. assert len(new_computed_blocks) == 0 @@ -600,8 +648,9 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None: # requests, so this method is not relevant. raise ValueError("Should not be called as prefix caching is disabled.") - def get_num_common_prefix_blocks(self, request_id: str, - num_running_requests: int) -> int: + def get_num_common_prefix_blocks( + self, request_id: str, num_running_requests: int + ) -> int: # Cross-attention blocks contain request-specific encoder states # and are not shared between different requests return 0 @@ -626,11 +675,9 @@ def find_longest_cache_hit( # 2. Encoder states are computed once per request, not incrementally # 3. 
No reusable prefix exists between different multimodal inputs # Return empty blocks to indicate no cache hits - raise NotImplementedError( - "CrossAttentionManager does not support caching") + raise NotImplementedError("CrossAttentionManager does not support caching") - def remove_skipped_blocks(self, request_id: str, - num_computed_tokens: int) -> None: + def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None: # Cross-attention blocks represent encoder states which are needed # for the entire decoding process, so no blocks should be skipped pass @@ -646,8 +693,9 @@ def remove_skipped_blocks(self, request_id: str, } -def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec, - **kwargs) -> SingleTypeKVCacheManager: +def get_manager_for_kv_cache_spec( + kv_cache_spec: KVCacheSpec, **kwargs +) -> SingleTypeKVCacheManager: manager_class = spec_manager_map[type(kv_cache_spec)] manager = manager_class(kv_cache_spec, **kwargs) return manager diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 29bb220760c0..ce4714702869 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -12,14 +12,14 @@ class CudagraphDispatcher: cudagraphs. The dispatcher stores two sets of dispatch keys, one for PIECEWISE and one - for FULL cudagraph runtime mode. The keys are initialized depending on - attention support and what cudagraph mode is set in CompilationConfig. The + for FULL cudagraph runtime mode. The keys are initialized depending on + attention support and what cudagraph mode is set in CompilationConfig. The keys stored in dispatcher are the only source of truth for valid cudagraphs that can be dispatched at runtime. - At runtime, the dispatch method generates the runtime cudagraph mode (FULL, + At runtime, the dispatch method generates the runtime cudagraph mode (FULL, PIECEWISE, or NONE for no cudagraph) and the valid key (batch descriptor) - based on the input key. After dispatching (communicated via forward + based on the input key. After dispatching (communicated via forward context), the cudagraph wrappers will trust the dispatch key to either capture or replay (if the mode matches), or pass through to the underlying runnable without cudagraph (if the mode does not match or mode is NONE). @@ -37,28 +37,35 @@ def __init__(self, vllm_config: VllmConfig): } not_use_piecewise_compilation = ( - not self.cudagraph_mode.requires_piecewise_compilation()) - - assert not_use_piecewise_compilation or \ - self.compilation_config.is_attention_compiled_piecewise(), \ - "Compilation level should be CompilationLevel.PIECEWISE when "\ - "cudagraph_mode piecewise cudagraphs is used, "\ - "and attention should be in splitting_ops or "\ - "inductor splitting should be used. " \ - f"cudagraph_mode={self.cudagraph_mode}, "\ - f"compilation_level={self.compilation_config.level}, "\ + not self.cudagraph_mode.requires_piecewise_compilation() + ) + + assert ( + not_use_piecewise_compilation + or self.compilation_config.is_attention_compiled_piecewise() + ), ( + "Compilation level should be CompilationLevel.PIECEWISE when " + "cudagraph_mode piecewise cudagraphs is used, " + "and attention should be in splitting_ops or " + "inductor splitting should be used. 
" + f"cudagraph_mode={self.cudagraph_mode}, " + f"compilation_level={self.compilation_config.level}, " f"splitting_ops={self.compilation_config.splitting_ops}" + ) self.keys_initialized = False - def add_cudagraph_key(self, runtime_mode: CUDAGraphMode, - batch_descriptor: BatchDescriptor): - assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \ + def add_cudagraph_key( + self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor + ): + assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( f"Invalid cudagraph runtime mode for keys: {runtime_mode}" + ) self.cudagraph_keys[runtime_mode].add(batch_descriptor) - def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, - uniform_decode_query_len: int): + def initialize_cudagraph_keys( + self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int + ): # This should be called only after attention backend is initialized. # Note: we create all valid keys for cudagraph here but do not @@ -68,33 +75,38 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode, for bs in self.compilation_config.cudagraph_capture_sizes: self.add_cudagraph_key( cudagraph_mode.mixed_mode(), - BatchDescriptor(num_tokens=bs, uniform_decode=False)) + BatchDescriptor(num_tokens=bs, uniform_decode=False), + ) # if decode cudagraph mode is FULL, and we don't already have mixed # mode full cudagraphs then add them here. - if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL \ - and cudagraph_mode.separate_routine(): - max_num_tokens = uniform_decode_query_len * \ - self.vllm_config.scheduler_config.max_num_seqs + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + uniform_decode_query_len + * self.vllm_config.scheduler_config.max_num_seqs + ) cudagraph_capture_sizes_for_decode = [ - x for x in self.compilation_config.cudagraph_capture_sizes + x + for x in self.compilation_config.cudagraph_capture_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] for bs in cudagraph_capture_sizes_for_decode: self.add_cudagraph_key( CUDAGraphMode.FULL, - BatchDescriptor(num_tokens=bs, uniform_decode=True)) + BatchDescriptor(num_tokens=bs, uniform_decode=True), + ) self.keys_initialized = True def dispatch( - self, - batch_descriptor: BatchDescriptor, - use_cascade_attn: bool = False + self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]: """ Given conditions(e.g.,batch descriptor and if using cascade attention), dispatch to a cudagraph runtime mode and the valid batch descriptor. - A new batch descriptor is returned as we might dispatch a uniform batch + A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform). """ # if not initialized, just skip dispatching. 
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 345f5a464c2c..163c050e559e 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -32,6 +32,7 @@ class FinishReason(enum.IntEnum): abort - aborted for another reason """ + STOP = 0 LENGTH = 1 ABORT = 2 @@ -41,11 +42,11 @@ def __str__(self): class EngineCoreRequest( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] request_id: str prompt_token_ids: Optional[list[int]] mm_features: Optional[list[MultiModalFeatureSpec]] @@ -73,6 +74,7 @@ class EngineCoreRequest( class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" + QUEUED = 1 SCHEDULED = 2 PREEMPTED = 3 @@ -85,23 +87,24 @@ class EngineCoreEvent(msgspec.Struct): frontend to calculate intervals between engine core events. These timestamps should not be compared with timestamps from other processes. """ + type: EngineCoreEventType timestamp: float @classmethod - def new_event(cls, - event_type: EngineCoreEventType, - timestamp: Optional[float] = None) -> "EngineCoreEvent": + def new_event( + cls, event_type: EngineCoreEventType, timestamp: Optional[float] = None + ) -> "EngineCoreEvent": timestamp = time.monotonic() if timestamp is None else timestamp return cls(event_type, timestamp) class EngineCoreOutput( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] request_id: str new_token_ids: list[int] @@ -132,10 +135,10 @@ def __init__(self, r: Any = None): class UtilityOutput( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] call_id: int # Non-None implies the call failed, result should be None. @@ -144,11 +147,11 @@ class UtilityOutput( class EngineCoreOutputs( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - gc=False): # type: ignore[call-arg] - + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, +): # type: ignore[call-arg] # NOTE(Nick): We could consider ways to make this more compact, # e.g. columnwise layout @@ -179,12 +182,13 @@ class EngineCoreRequestType(enum.Enum): Request types defined as hex byte strings, so it can be sent over sockets without separate encoding step. """ - ADD = b'\x00' - ABORT = b'\x01' - START_DP_WAVE = b'\x02' - UTILITY = b'\x03' + + ADD = b"\x00" + ABORT = b"\x01" + START_DP_WAVE = b"\x02" + UTILITY = b"\x03" # Sentinel used within EngineCoreProc. - EXECUTOR_FAILED = b'\x04' + EXECUTOR_FAILED = b"\x04" class ReconfigureDistributedRequest(msgspec.Struct): @@ -199,5 +203,6 @@ class ReconfigureRankType(enum.IntEnum): """ Rank type for reconfiguring distributed request. 
""" + KEEP_CURRENT_RANK = -1 SHUTDOWN_CURRENT_RANK = -2 diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ab3a4e5e6fe5..5be1f833e3f6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -27,18 +27,14 @@ from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.tracing import init_tracer -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - init_tokenizer_from_configs) +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import (Device, as_list, cancel_task_threadsafe, cdiv, - deprecate_kwargs) +from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError -from vllm.v1.engine.output_processor import (OutputProcessor, - RequestOutputCollector) +from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -50,7 +46,6 @@ class AsyncLLM(EngineClient): - def __init__( self, vllm_config: VllmConfig, @@ -91,7 +86,8 @@ def __init__( "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " "This should not happen. As a workaround, try using " "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") + "VLLM_USE_V1=0 or 1 and report this issue on Github." + ) # Ensure we can serialize custom transformer configs maybe_register_config_serialize_by_value() @@ -105,29 +101,20 @@ def __init__( if not log_stats and stat_loggers is not None: logger.info( "AsyncLLM created with log_stats=False and non-empty custom " - "logger list; enabling logging without default stat loggers") - - if self.model_config.skip_tokenizer_init: - self.tokenizer = None - else: - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config) + "logger list; enabling logging without default stat loggers" + ) # Processor (converts Inputs --> EngineCoreRequests). - self.processor = Processor( - vllm_config=vllm_config, - tokenizer=self.tokenizer, - mm_registry=mm_registry, - ) + self.processor = Processor(vllm_config, mm_registry=mm_registry) # OutputProcessor (converts EngineCoreOutputs --> RequestOutput). - self.output_processor = OutputProcessor(self.tokenizer, - log_stats=self.log_stats) + self.output_processor = OutputProcessor( + self.tokenizer, log_stats=self.log_stats + ) if self.observability_config.otlp_traces_endpoint is not None: tracer = init_tracer( - "vllm.llm_engine", - self.observability_config.otlp_traces_endpoint) + "vllm.llm_engine", self.observability_config.otlp_traces_endpoint + ) self.output_processor.tracer = tracer # EngineCore (starts the engine in background process). @@ -163,7 +150,8 @@ def __init__( if envs.VLLM_TORCH_PROFILER_DIR: logger.info( "Torch profiler enabled. 
AsyncLLM CPU traces will be collected under %s", # noqa: E501 - envs.VLLM_TORCH_PROFILER_DIR) + envs.VLLM_TORCH_PROFILER_DIR, + ) worker_name = f"{socket.gethostname()}_{os.getpid()}.async_llm" self.profiler = torch.profiler.profile( activities=[ @@ -171,37 +159,39 @@ def __init__( ], with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, on_trace_ready=torch.profiler.tensorboard_trace_handler( - envs.VLLM_TORCH_PROFILER_DIR, - worker_name=worker_name, - use_gzip=True)) + envs.VLLM_TORCH_PROFILER_DIR, worker_name=worker_name, use_gzip=True + ), + ) else: self.profiler = None @classmethod @deprecate_kwargs( "disable_log_requests", - additional_message=("This argument will have no effect. " - "Use `enable_log_requests` instead."), + additional_message=( + "This argument will have no effect. Use `enable_log_requests` instead." + ), ) def from_vllm_config( - cls, - vllm_config: VllmConfig, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[StatLoggerFactory]] = None, - enable_log_requests: bool = False, - disable_log_stats: bool = False, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0, - disable_log_requests: bool = True, # Deprecated, will be removed + cls, + vllm_config: VllmConfig, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[list[StatLoggerFactory]] = None, + enable_log_requests: bool = False, + disable_log_stats: bool = False, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + disable_log_requests: bool = True, # Deprecated, will be removed ) -> "AsyncLLM": if not envs.VLLM_USE_V1: raise ValueError( "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " "This should not happen. As a workaround, try using " "AsyncLLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") + "VLLM_USE_V1=0 or 1 and report this issue on Github." + ) # Create the LLMEngine. return cls( @@ -255,6 +245,10 @@ def shutdown(self): cancel_task_threadsafe(getattr(self, "output_handler", None)) + @property + def tokenizer(self) -> Optional[AnyTokenizer]: + return self.processor.tokenizer + async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return await self.engine_core.get_supported_tasks_async() @@ -288,14 +282,20 @@ async def add_request( assert prompt_text is None logger.warning_once( "Processor has been moved under OpenAIServing and will " - "be removed from AsyncLLM in v0.13.") - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - tokenization_kwargs, - trace_headers, priority, - data_parallel_rank) - prompt_text = (prompt if isinstance(prompt, str) else - prompt.get("prompt")) + "be removed from AsyncLLM in v0.13." 
+ ) + request = self.processor.process_inputs( + request_id, + prompt, + params, + arrival_time, + lora_request, + tokenization_kwargs, + trace_headers, + priority, + data_parallel_rank, + ) + prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") if is_pooling or params.n == 1: await self._add_request(request, prompt_text, None, 0, queue) @@ -310,22 +310,24 @@ async def add_request( parent_request = ParentRequest(request_id, parent_params) for idx in range(parent_params.n): request_id, child_params = parent_request.get_child_info(idx) - child_request = request if idx == parent_params.n - 1 else copy( - request) + child_request = request if idx == parent_params.n - 1 else copy(request) child_request.request_id = request_id child_request.sampling_params = child_params - await self._add_request(child_request, prompt_text, parent_request, - idx, queue) + await self._add_request( + child_request, prompt_text, parent_request, idx, queue + ) return queue - async def _add_request(self, request: EngineCoreRequest, - prompt: Optional[str], - parent_req: Optional[ParentRequest], index: int, - queue: RequestOutputCollector): - + async def _add_request( + self, + request: EngineCoreRequest, + prompt: Optional[str], + parent_req: Optional[ParentRequest], + index: int, + queue: RequestOutputCollector, + ): # Add the request to OutputProcessor (this process). - self.output_processor.add_request(request, prompt, parent_req, index, - queue) + self.output_processor.add_request(request, prompt, parent_req, index, queue) # Add the EngineCoreRequest to EngineCore (separate process). await self.engine_core.add_request_async(request) @@ -366,12 +368,15 @@ async def generate( returning the RequestOutput back to the caller. """ - if (self.vllm_config.cache_config.kv_sharing_fast_prefill - and sampling_params.prompt_logprobs): + if ( + self.vllm_config.cache_config.kv_sharing_fast_prefill + and sampling_params.prompt_logprobs + ): raise ValueError( "--kv-sharing-fast-prefill produces incorrect logprobs for " "prompt tokens, please disable it when the requests need " - "prompt logprobs") + "prompt logprobs" + ) try: # We start the output_handler on the first call to generate() so @@ -389,15 +394,17 @@ async def generate( tokenization_kwargs, ) - q = await self.add_request(request_id, - prompt, - sampling_params, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - trace_headers=trace_headers, - priority=priority, - data_parallel_rank=data_parallel_rank, - prompt_text=prompt_text) + q = await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + trace_headers=trace_headers, + priority=priority, + data_parallel_rank=data_parallel_rank, + prompt_text=prompt_text, + ) # The output_handler task pushes items into the queue. # This task pulls from the queue and yields to caller. @@ -460,23 +467,26 @@ async def output_handler(): outputs = await engine_core.get_output_async() num_outputs = len(outputs.outputs) - iteration_stats = IterationStats() if ( - log_stats and num_outputs) else None + iteration_stats = ( + IterationStats() if (log_stats and num_outputs) else None + ) # Split outputs into chunks of at most # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the # event loop for too long. 
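The comment above (and the np.array_split call that follows) picks the number of slices with a ceiling division so that no slice exceeds the chunk size. A minimal sketch, with an assumed chunk size of 128 standing in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:

import numpy as np

def cdiv(a: int, b: int) -> int:
    # Ceiling division used to choose the number of output slices.
    return -(a // -b)

CHUNK_SIZE = 128  # illustrative stand-in for VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
outputs = list(range(300))
slices = np.array_split(outputs, cdiv(len(outputs), CHUNK_SIZE))
# 300 outputs with a chunk size of 128 -> 3 slices of 100 items each, so the
# event loop can yield between slices instead of processing all 300 at once.
assert len(slices) == 3 and all(len(s) <= CHUNK_SIZE for s in slices)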
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: - slices = (outputs.outputs, ) + slices = (outputs.outputs,) else: slices = np.array_split( outputs.outputs, - cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE)) + cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE), + ) for i, outputs_slice in enumerate(slices): # 2) Process EngineCoreOutputs. processed_outputs = output_processor.process_outputs( - outputs_slice, outputs.timestamp, iteration_stats) + outputs_slice, outputs.timestamp, iteration_stats + ) # NOTE: RequestOutputs are pushed to their queues. assert not processed_outputs.request_outputs @@ -486,7 +496,8 @@ async def output_handler(): # 3) Abort any reqs that finished due to stop strings. await engine_core.abort_requests_async( - processed_outputs.reqs_to_abort) + processed_outputs.reqs_to_abort + ) # 4) Logging. # TODO(rob): make into a coroutine and launch it in @@ -506,8 +517,9 @@ async def output_handler(): async def abort(self, request_id: Union[str, Iterable[str]]) -> None: """Abort RequestId in OutputProcessor and EngineCore.""" - request_ids = (request_id, ) if isinstance( - request_id, str) else as_list(request_id) + request_ids = ( + (request_id,) if isinstance(request_id, str) else as_list(request_id) + ) all_request_ids = self.output_processor.abort_requests(request_ids) await self.engine_core.abort_requests_async(all_request_ids) @@ -614,8 +626,9 @@ async def get_input_preprocessor(self) -> InputPreprocessor: async def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") + raise ValueError( + "Unable to get tokenizer because skip_tokenizer_init is True" + ) return self.tokenizer @@ -647,8 +660,7 @@ async def reset_mm_cache(self) -> None: self.processor.clear_cache() await self.engine_core.reset_mm_cache_async() - async def reset_prefix_cache(self, - device: Optional[Device] = None) -> None: + async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: if device == Device.CPU: raise ValueError("Not supported on CPU.") await self.engine_core.reset_prefix_cache_async() @@ -679,16 +691,19 @@ async def pin_lora(self, lora_id: int) -> bool: """Prevent an adapter from being evicted.""" return await self.engine_core.pin_lora_async(lora_id) - async def collective_rpc(self, - method: str, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None): + async def collective_rpc( + self, + method: str, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + ): """ Perform a collective RPC call to the given path. """ return await self.engine_core.collective_rpc_async( - method, timeout, args, kwargs) + method, timeout, args, kwargs + ) async def wait_for_requests_to_drain(self, drain_timeout: int = 300): """Wait for all requests to be drained.""" @@ -698,16 +713,17 @@ async def wait_for_requests_to_drain(self, drain_timeout: int = 300): logger.info("Engines are idle, requests have been drained") return - logger.info( - "Engines are still running, waiting for requests to drain...") + logger.info("Engines are still running, waiting for requests to drain...") await asyncio.sleep(1) # Wait 1 second before checking again - raise TimeoutError(f"Timeout reached after {drain_timeout} seconds " - "waiting for requests to drain.") + raise TimeoutError( + f"Timeout reached after {drain_timeout} seconds " + "waiting for requests to drain." 
+ ) - async def scale_elastic_ep(self, - new_data_parallel_size: int, - drain_timeout: int = 300): + async def scale_elastic_ep( + self, new_data_parallel_size: int, drain_timeout: int = 300 + ): """ Scale up or down the data parallel size by adding or removing engine cores. @@ -716,22 +732,24 @@ async def scale_elastic_ep(self, drain_timeout: Maximum time to wait for requests to drain (seconds) """ - old_data_parallel_size = \ - self.vllm_config.parallel_config.data_parallel_size + old_data_parallel_size = self.vllm_config.parallel_config.data_parallel_size if old_data_parallel_size == new_data_parallel_size: - logger.info("Data parallel size is already %s, skipping scale", - new_data_parallel_size) + logger.info( + "Data parallel size is already %s, skipping scale", + new_data_parallel_size, + ) return logger.info( - "Waiting for requests to drain before " - "scaling up to %s engines...", new_data_parallel_size) + "Waiting for requests to drain before scaling up to %s engines...", + new_data_parallel_size, + ) await self.wait_for_requests_to_drain(drain_timeout) logger.info( - "Requests have been drained, proceeding with scale " - "to %s engines", new_data_parallel_size) + "Requests have been drained, proceeding with scale to %s engines", + new_data_parallel_size, + ) await self.engine_core.scale_elastic_ep(new_data_parallel_size) - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # recreate stat loggers if new_data_parallel_size > old_data_parallel_size and self.log_stats: diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 596edfdbe24f..9bb08e6db7be 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -56,7 +56,6 @@ class DPCoordinator: """ def __init__(self, parallel_config: ParallelConfig): - dp_size = parallel_config.data_parallel_size assert dp_size > 1, "Coordinator only used for data parallel" @@ -68,7 +67,8 @@ def __init__(self, parallel_config: ParallelConfig): # either external or hybrid DP LB mode. 
local_only = not (external_lb or hybrid_lb) front_publish_address = get_engine_client_zmq_addr( - local_only=local_only, host=host) + local_only=local_only, host=host + ) local_only_eng = dp_size == parallel_config.data_parallel_size_local back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) @@ -84,7 +84,8 @@ def __init__(self, parallel_config: ParallelConfig): "back_output_address": back_output_address, "back_publish_address": back_publish_address, }, - daemon=True) + daemon=True, + ) self.proc.start() self.stats_publish_address = front_publish_address @@ -104,16 +105,12 @@ def close(self): class EngineState: - def __init__(self): self.request_counts = [0, 0] # [waiting, running] class DPCoordinatorProc: - - def __init__(self, - engine_count: int, - min_stats_update_interval_ms: int = 100): + def __init__(self, engine_count: int, min_stats_update_interval_ms: int = 100): set_process_title("DPCoordinator") self.ctx = zmq.Context() @@ -131,7 +128,8 @@ def run_coordinator( ): coordinator = DPCoordinatorProc( engine_count=engine_count, - min_stats_update_interval_ms=min_stats_update_interval_ms) + min_stats_update_interval_ms=min_stats_update_interval_ms, + ) try: coordinator.process_input_socket( front_publish_address, @@ -141,10 +139,12 @@ def run_coordinator( except KeyboardInterrupt: logger.info("DP Coordinator process exiting") - def process_input_socket(self, front_publish_address: str, - back_output_address: str, - back_publish_address: str): - + def process_input_socket( + self, + front_publish_address: str, + back_output_address: str, + back_publish_address: str, + ): decoder = MsgpackDecoder(EngineCoreOutputs) # For tracking request wave progression. @@ -157,29 +157,33 @@ def process_input_socket(self, front_publish_address: str, last_stats_wave = -1 last_step_counts: Optional[list[list[int]]] = None - with make_zmq_socket( + with ( + make_zmq_socket( path=front_publish_address, # IPC ctx=self.ctx, socket_type=zmq.XPUB, bind=True, - ) as publish_front, make_zmq_socket( + ) as publish_front, + make_zmq_socket( path=back_output_address, # IPC or TCP ctx=self.ctx, socket_type=zmq.PULL, bind=True, - ) as output_back, make_zmq_socket( + ) as output_back, + make_zmq_socket( path=back_publish_address, # IPC or TCP ctx=self.ctx, socket_type=zmq.XPUB, bind=True, - ) as publish_back: - + ) as publish_back, + ): # Wait until all engines subscribe. for _ in self.engines: - if publish_back.recv() != b'\x01': + if publish_back.recv() != b"\x01": logger.error( "DP Coordinator received unexpected message while " - "waiting for engines to subscribe") + "waiting for engines to subscribe" + ) return # Send ready message to engines. publish_back.send(b"READY") @@ -194,15 +198,13 @@ def process_input_socket(self, front_publish_address: str, elapsed = int(time.time() * 1000) - last_publish_time # Send at stats_update_interval_ms interval if the stats have # changed, or otherwise every 5 seconds. - wait_for = (self.stats_update_interval_ms - if stats_changed else 5000) + wait_for = self.stats_update_interval_ms if stats_changed else 5000 # Wait at least 50ms to ensure we've received all stats for # the current step. min_timeout = 50 if last_step_counts is None else 0 - events = poller.poll(timeout=max(min_timeout, wait_for - - elapsed)) + events = poller.poll(timeout=max(min_timeout, wait_for - elapsed)) if not events: # Poller timeout - publish current stats to front-ends. 
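A tiny sketch of the poll-timeout choice used in the coordinator loop above. The 100 ms default mirrors min_stats_update_interval_ms in the diff; the sample numbers are illustrative:

def poll_timeout_ms(stats_changed: bool,
                    have_step_counts: bool,
                    elapsed_ms: int,
                    stats_update_interval_ms: int = 100) -> int:
    # Publish quickly when stats changed, otherwise at most every 5 seconds,
    # and wait at least 50 ms until the first per-step counts have arrived.
    wait_for = stats_update_interval_ms if stats_changed else 5000
    min_timeout = 50 if not have_step_counts else 0
    return max(min_timeout, wait_for - elapsed_ms)

# Stats changed 30 ms ago and step counts were already seen -> poll ~70 ms.
assert poll_timeout_ms(True, True, 30) == 70
# Nothing changed and no step counts yet -> fall back toward the 5 s cadence.
assert poll_timeout_ms(False, False, 1000) == 4000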
if last_step_counts is not None: @@ -212,8 +214,7 @@ def process_input_socket(self, front_publish_address: str, engine_req_counts_list = self._get_engine_counts() stats_changed = False - to_publish = (engine_req_counts_list, current_wave, - engines_running) + to_publish = (engine_req_counts_list, current_wave, engines_running) publish_front.send(msgspec.msgpack.encode(to_publish)) last_publish_time = int(time.time() * 1000) continue @@ -223,13 +224,16 @@ def process_input_socket(self, front_publish_address: str, if publish_front in events: buffer = publish_front.recv() - if buffer in (b'\x01', b'\x00'): + if buffer in (b"\x01", b"\x00"): # Ignore subscription messages. continue decoded = msgspec.msgpack.decode(buffer) - if isinstance(decoded, (list, tuple)) and len( - decoded) == 2 and decoded[0] == "SCALE_ELASTIC_EP": + if ( + isinstance(decoded, (list, tuple)) + and len(decoded) == 2 + and decoded[0] == "SCALE_ELASTIC_EP" + ): # Handle scale up notification new_engine_count = decoded[1] current_count = len(self.engines) @@ -248,13 +252,17 @@ def process_input_socket(self, front_publish_address: str, # engine engines_running = False logger.info( - "DPCoordinator scaled up from %s to %s " - "engines", current_count, new_engine_count) + "DPCoordinator scaled up from %s to %s engines", + current_count, + new_engine_count, + ) else: self.engines = self.engines[:new_engine_count] logger.info( - "DPCoordinator scaled down from %s to %s " - "engines", current_count, new_engine_count) + "DPCoordinator scaled down from %s to %s engines", + current_count, + new_engine_count, + ) continue # Skip normal engine notification processing # We received a message on the front-end XPUB socket, @@ -270,8 +278,9 @@ def process_input_socket(self, front_publish_address: str, engines_running = True wave_state_changed = True - self._send_start_wave(publish_back, current_wave, - engine_to_exclude) + self._send_start_wave( + publish_back, current_wave, engine_to_exclude + ) if output_back in events: # We received a message from one of the engines. @@ -290,21 +299,28 @@ def process_input_socket(self, front_publish_address: str, stats = self.engines[eng_index].request_counts stats_step = scheduler_stats.step_counter stats_wave = scheduler_stats.current_wave - if (stats_wave > last_stats_wave - or stats_wave == last_stats_wave - and stats_step > last_stats_step): + if ( + stats_wave > last_stats_wave + or stats_wave == last_stats_wave + and stats_step > last_stats_step + ): if stats_changed: - last_step_counts = self._get_engine_counts( - do_copy=True) + last_step_counts = self._get_engine_counts(do_copy=True) last_stats_step = stats_step last_stats_wave = stats_wave elif stats_wave != last_stats_wave or ( - stats_step != last_stats_step): + stats_step != last_stats_step + ): logger.warning( "Received stats for out-of-order " "step (%d, %d) from engine %d (expected " - "> (%d, %d))", stats_wave, stats_step, - eng_index, last_stats_wave, last_stats_step) + "> (%d, %d))", + stats_wave, + stats_step, + eng_index, + last_stats_wave, + last_stats_step, + ) stats[0] = scheduler_stats.num_waiting_reqs stats[1] = scheduler_stats.num_running_reqs stats_changed = True @@ -315,20 +331,24 @@ def process_input_socket(self, front_publish_address: str, # (engines_running==False). 
if current_wave <= wave: new_wave = wave + 1 - logger.debug("Moving DP wave from %d to %d.", - current_wave, new_wave) + logger.debug( + "Moving DP wave from %d to %d.", current_wave, new_wave + ) current_wave = new_wave engines_running = False wave_state_changed = True elif (wave := outputs.start_wave) is not None and ( - wave > current_wave or - (wave == current_wave and not engines_running)): + wave > current_wave + or (wave == current_wave and not engines_running) + ): # 3. The engine received request for a non-current wave # so we must ensure that other engines progress to the # next wave (race condition handling). logger.debug( "Starting wave %d after notification of " - "stale wave request from engine.", wave) + "stale wave request from engine.", + wave, + ) current_wave = wave engines_running = True wave_state_changed = True @@ -339,16 +359,16 @@ def process_input_socket(self, front_publish_address: str, publish_front.send(msgspec.msgpack.encode(message)) @staticmethod - def _send_start_wave(socket: zmq.Socket, wave: int, - exclude_engine_index: Optional[int]): + def _send_start_wave( + socket: zmq.Socket, wave: int, exclude_engine_index: Optional[int] + ): """Broadcast the START_DP_WAVE message to all the engines. It includes the current wave number and index of engine which has already received a request with this wave number and so doesn't require additional notification. """ wave_encoded = msgspec.msgpack.encode((wave, exclude_engine_index)) - socket.send_multipart( - (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) + socket.send_multipart((EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) def _get_engine_counts(self, do_copy=False) -> list[list[int]]: """Return list of [waiting, running] count lists for each engine.""" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 3ee804f10c17..4826d7c589a7 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -25,25 +25,39 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask -from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) -from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, - resolve_obj_by_qualname, set_process_title) +from vllm.transformers_utils.config import maybe_register_config_serialize_by_value +from vllm.utils import ( + decorate_logs, + get_hash_fn_by_name, + make_zmq_socket, + resolve_obj_by_qualname, + set_process_title, +) from vllm.utils.gc_utils import maybe_attach_gc_debug_callback -from vllm.v1.core.kv_cache_utils import (BlockHash, - generate_scheduler_kv_cache_config, - get_kv_cache_configs, - get_request_block_hasher, - init_none_hash) +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + generate_scheduler_kv_cache_config, + get_kv_cache_configs, + get_request_block_hasher, + init_none_hash, +) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, - ReconfigureDistributedRequest, ReconfigureRankType, - UtilityOutput, UtilityResult) -from vllm.v1.engine.utils import (EngineHandshakeMetadata, EngineZmqAddresses, - get_device_indices) +from vllm.v1.engine import ( + EngineCoreOutputs, + EngineCoreRequest, + EngineCoreRequestType, + ReconfigureDistributedRequest, + 
ReconfigureRankType, + UtilityOutput, + UtilityResult, +) +from vllm.v1.engine.utils import ( + EngineHandshakeMetadata, + EngineZmqAddresses, + get_device_indices, +) from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats @@ -58,51 +72,56 @@ POLLING_TIMEOUT_S = 2.5 HANDSHAKE_TIMEOUT_MINS = 5 -_R = TypeVar('_R') # Return type for collective_rpc +_R = TypeVar("_R") # Return type for collective_rpc class EngineCore: """Inner loop of vLLM's Engine.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - executor_fail_callback: Optional[Callable] = None): - + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + executor_fail_callback: Optional[Callable] = None, + ): # plugins need to be loaded at the engine/scheduler level too from vllm.plugins import load_general_plugins + load_general_plugins() self.vllm_config = vllm_config - logger.info("Initializing a V1 LLM engine (v%s) with config: %s", - VLLM_VERSION, vllm_config) + logger.info( + "Initializing a V1 LLM engine (v%s) with config: %s", + VLLM_VERSION, + vllm_config, + ) self.log_stats = log_stats # Setup Model. self.model_executor = executor_class(vllm_config) if executor_fail_callback is not None: - self.model_executor.register_failure_callback( - executor_fail_callback) + self.model_executor.register_failure_callback(executor_fail_callback) self.available_gpu_memory_for_kv_cache = -1 # Setup KV Caches and update CacheConfig after profiling. - num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ - self._initialize_kv_caches(vllm_config) + num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( + vllm_config + ) vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks - self.collective_rpc("initialize_cache", - args=(num_gpu_blocks, num_cpu_blocks)) + self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. if isinstance(vllm_config.scheduler_config.scheduler_cls, str): Scheduler = resolve_obj_by_qualname( - vllm_config.scheduler_config.scheduler_cls) + vllm_config.scheduler_config.scheduler_cls + ) else: Scheduler = vllm_config.scheduler_config.scheduler_cls @@ -114,7 +133,8 @@ def __init__(self, "Using configured V1 scheduler class %s. 
" "This scheduler interface is not public and " "compatibility may not be maintained.", - vllm_config.scheduler_config.scheduler_cls) + vllm_config.scheduler_config.scheduler_cls, + ) if len(kv_cache_config.kv_cache_groups) == 0: # Encoder models without KV cache don't support @@ -126,49 +146,54 @@ def __init__(self, vllm_config=vllm_config, kv_cache_config=kv_cache_config, structured_output_manager=self.structured_output_manager, - include_finished_set=vllm_config.parallel_config.data_parallel_size - > 1, + include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, log_stats=self.log_stats, ) self.use_spec_decode = vllm_config.speculative_config is not None if self.scheduler.connector is not None: # type: ignore self.model_executor.init_kv_output_aggregator( - self.scheduler.connector.get_finished_count()) # type: ignore + self.scheduler.connector.get_finished_count() # type: ignore + ) self.mm_registry = mm_registry = MULTIMODAL_REGISTRY self.mm_receiver_cache = engine_receiver_cache_from_config( - vllm_config, mm_registry) + vllm_config, mm_registry + ) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously # schedule and execute batches, and is required by pipeline parallelism # to eliminate pipeline bubbles. self.batch_queue_size = self.model_executor.max_concurrent_batches - self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput], - SchedulerOutput]]] = None + self.batch_queue: Optional[ + deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] + ] = None if self.batch_queue_size > 1: - logger.info("Batch queue is enabled with size %d", - self.batch_queue_size) + logger.info("Batch queue is enabled with size %d", self.batch_queue_size) self.batch_queue = deque(maxlen=self.batch_queue_size) - self.request_block_hasher: Optional[Callable[[Request], - list[BlockHash]]] = None - if (self.vllm_config.cache_config.enable_prefix_caching - or self.scheduler.get_kv_connector() is not None): - + self.request_block_hasher: Optional[Callable[[Request], list[BlockHash]]] = None + if ( + self.vllm_config.cache_config.enable_prefix_caching + or self.scheduler.get_kv_connector() is not None + ): block_size = vllm_config.cache_config.block_size caching_hash_fn = get_hash_fn_by_name( - vllm_config.cache_config.prefix_caching_hash_algo) + vllm_config.cache_config.prefix_caching_hash_algo + ) init_none_hash(caching_hash_fn) self.request_block_hasher = get_request_block_hasher( - block_size, caching_hash_fn) + block_size, caching_hash_fn + ) - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) + self.step_fn = ( + self.step if self.batch_queue is None else self.step_with_batch_queue + ) def _initialize_kv_caches( - self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: + self, vllm_config: VllmConfig + ) -> tuple[int, int, KVCacheConfig]: start = time.time() # Get all kv cache needed by the model @@ -179,28 +204,27 @@ def _initialize_kv_caches( if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": dp_group = getattr(self, "dp_group", None) assert dp_group is not None - self.available_gpu_memory_for_kv_cache = \ + self.available_gpu_memory_for_kv_cache = ( ParallelConfig.sync_kv_cache_memory_size(dp_group, -1) - available_gpu_memory = [ - self.available_gpu_memory_for_kv_cache - ] * len(kv_cache_specs) + ) + available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len( + kv_cache_specs + ) else: # Profiles the peak memory usage of the model to determine how 
# much memory can be allocated for kv cache. - available_gpu_memory = ( - self.model_executor.determine_available_memory()) - self.available_gpu_memory_for_kv_cache = \ - available_gpu_memory[0] + available_gpu_memory = self.model_executor.determine_available_memory() + self.available_gpu_memory_for_kv_cache = available_gpu_memory[0] else: # Attention free models don't need memory for kv cache available_gpu_memory = [0] * len(kv_cache_specs) assert len(kv_cache_specs) == len(available_gpu_memory) - kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs, - available_gpu_memory) - scheduler_kv_cache_config = generate_scheduler_kv_cache_config( - kv_cache_configs) + kv_cache_configs = get_kv_cache_configs( + vllm_config, kv_cache_specs, available_gpu_memory + ) + scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) num_gpu_blocks = scheduler_kv_cache_config.num_blocks num_cpu_blocks = 0 @@ -208,8 +232,10 @@ def _initialize_kv_caches( self.model_executor.initialize_from_config(kv_cache_configs) elapsed = time.time() - start - logger.info(("init engine (profile, create kv cache, " - "warmup model) took %.2f seconds"), elapsed) + logger.info( + ("init engine (profile, create kv cache, warmup model) took %.2f seconds"), + elapsed, + ) return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config def get_supported_tasks(self) -> tuple[SupportedTask, ...]: @@ -224,22 +250,27 @@ def add_request(self, request: Request, request_wave: int = 0): # Validate the request_id type. if not isinstance(request.request_id, str): raise TypeError( - f"request_id must be a string, got {type(request.request_id)}") + f"request_id must be a string, got {type(request.request_id)}" + ) if pooling_params := request.pooling_params: supported_pooling_tasks = [ - task for task in self.get_supported_tasks() - if task in POOLING_TASKS + task for task in self.get_supported_tasks() if task in POOLING_TASKS ] if pooling_params.task not in supported_pooling_tasks: - raise ValueError(f"Unsupported task: {pooling_params.task!r} " - f"Supported tasks: {supported_pooling_tasks}") + raise ValueError( + f"Unsupported task: {pooling_params.task!r} " + f"Supported tasks: {supported_pooling_tasks}" + ) if request.kv_transfer_params is not None and ( - not self.scheduler.get_kv_connector()): - logger.warning("Got kv_transfer_params, but no KVConnector found. " - "Disabling KVTransfer for this request.") + not self.scheduler.get_kv_connector() + ): + logger.warning( + "Got kv_transfer_params, but no KVConnector found. " + "Disabling KVTransfer for this request." + ) self.scheduler.add_request(request) @@ -249,8 +280,7 @@ def abort_requests(self, request_ids: list[str]): # TODO: The scheduler doesn't really need to know the # specific finish reason, TBD whether we propagate that # (i.e. client-aborted vs stop criteria met). - self.scheduler.finish_requests(request_ids, - RequestStatus.FINISHED_ABORTED) + self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) def execute_model_with_error_logging( self, @@ -266,8 +296,9 @@ def execute_model_with_error_logging( # error from execute_model itself. 
# NOTE: This method is exception-free - dump_engine_exception(self.vllm_config, scheduler_output, - self.scheduler.make_stats()) + dump_engine_exception( + self.vllm_config, scheduler_output, self.scheduler.make_stats() + ) raise err def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: @@ -284,12 +315,13 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: scheduler_output = self.scheduler.schedule() model_output = self.execute_model_with_error_logging( self.model_executor.execute_model, # type: ignore - scheduler_output) + scheduler_output, + ) engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) # type: ignore + scheduler_output, model_output + ) # type: ignore - return (engine_core_outputs, - scheduler_output.total_num_scheduled_tokens > 0) + return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) def post_step(self, model_executed: bool) -> None: if self.use_spec_decode and model_executed: @@ -299,7 +331,8 @@ def post_step(self, model_executed: bool) -> None: self.scheduler.update_draft_token_ids(draft_token_ids) def step_with_batch_queue( - self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: + self, + ) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: """Schedule and execute batches with the batch queue. Note that if nothing to output in this step, None is returned. @@ -324,14 +357,15 @@ def step_with_batch_queue( model_executed = False if self.scheduler.has_requests(): scheduler_output = self.scheduler.schedule() - future = self.model_executor.execute_model(scheduler_output, - non_block=True) - batch_queue.appendleft( - (future, scheduler_output)) # type: ignore[arg-type] + future = self.model_executor.execute_model(scheduler_output, non_block=True) + batch_queue.appendleft((future, scheduler_output)) # type: ignore[arg-type] model_executed = scheduler_output.total_num_scheduled_tokens > 0 - if model_executed and len(batch_queue) < self.batch_queue_size \ - and not batch_queue[-1][0].done(): + if ( + model_executed + and len(batch_queue) < self.batch_queue_size + and not batch_queue[-1][0].done() + ): # Don't block on next worker response unless the queue is full # or there are no more requests to schedule. return None, True @@ -345,10 +379,12 @@ def step_with_batch_queue( # Block until the next result is available. future, scheduler_output = batch_queue.pop() model_output = self.execute_model_with_error_logging( - lambda _: future.result(), scheduler_output) + lambda _: future.result(), scheduler_output + ) engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output) + scheduler_output, model_output + ) return engine_core_outputs, model_executed @@ -366,8 +402,10 @@ def reset_mm_cache(self): # NOTE: Since this is mainly for debugging, we don't attempt to # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror) if self.scheduler.has_unfinished_requests(): - logger.warning("Resetting the multi-modal cache when requests are " - "in progress may lead to desynced internal caches.") + logger.warning( + "Resetting the multi-modal cache when requests are " + "in progress may lead to desynced internal caches." 
+ ) if self.mm_receiver_cache is not None: self.mm_receiver_cache.clear_cache() @@ -405,27 +443,28 @@ def save_sharded_state( pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: - self.model_executor.save_sharded_state(path=path, - pattern=pattern, - max_size=max_size) - - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return self.model_executor.collective_rpc(method, timeout, args, - kwargs) + self.model_executor.save_sharded_state( + path=path, pattern=pattern, max_size=max_size + ) + + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: + return self.model_executor.collective_rpc(method, timeout, args, kwargs) def save_tensorized_model( self, tensorizer_config, ) -> None: self.model_executor.save_tensorized_model( - tensorizer_config=tensorizer_config, ) + tensorizer_config=tensorizer_config, + ) - def preprocess_add_request( - self, request: EngineCoreRequest) -> tuple[Request, int]: + def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]: """Preprocess the request. This function could be directly used in input processing thread to allow @@ -435,12 +474,11 @@ def preprocess_add_request( # `mm_receiver_cache` is reset at the end of LLMEngine init, # and will only be accessed in the input processing thread afterwards. if self.mm_receiver_cache is not None and request.mm_features: - request.mm_features = ( - self.mm_receiver_cache.get_and_update_features( - request.mm_features)) + request.mm_features = self.mm_receiver_cache.get_and_update_features( + request.mm_features + ) - req = Request.from_engine_core_request(request, - self.request_block_hasher) + req = Request.from_engine_core_request(request, self.request_block_hasher) if req.use_structured_output: # Note on thread safety: no race condition. # `grammar_init` is only invoked in input processing thread. For @@ -454,7 +492,7 @@ def preprocess_add_request( class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" - ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD' + ENGINE_CORE_DEAD = b"ENGINE_CORE_DEAD" def __init__( self, @@ -467,37 +505,46 @@ def __init__( engine_index: int = 0, ): self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]() - self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], - bytes]]() + self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs], bytes]]() executor_fail_callback = lambda: self.input_queue.put_nowait( - (EngineCoreRequestType.EXECUTOR_FAILED, b'')) + (EngineCoreRequestType.EXECUTOR_FAILED, b"") + ) self.engine_index = engine_index identity = self.engine_index.to_bytes(length=2, byteorder="little") self.engines_running = False - with self._perform_handshakes(handshake_address, identity, - local_client, vllm_config, - client_handshake_address) as addresses: + with self._perform_handshakes( + handshake_address, + identity, + local_client, + vllm_config, + client_handshake_address, + ) as addresses: self.client_count = len(addresses.outputs) # Set up data parallel environment. 
self.has_coordinator = addresses.coordinator_output is not None self.frontend_stats_publish_address = ( - addresses.frontend_stats_publish_address) - logger.debug("Has DP Coordinator: %s, stats publish address: %s", - self.has_coordinator, - self.frontend_stats_publish_address) + addresses.frontend_stats_publish_address + ) + logger.debug( + "Has DP Coordinator: %s, stats publish address: %s", + self.has_coordinator, + self.frontend_stats_publish_address, + ) # Only publish request queue stats to coordinator for "internal" # and "hybrid" LB modes . self.publish_dp_lb_stats = ( self.has_coordinator - and not vllm_config.parallel_config.data_parallel_external_lb) + and not vllm_config.parallel_config.data_parallel_external_lb + ) self._init_data_parallel(vllm_config) - super().__init__(vllm_config, executor_class, log_stats, - executor_fail_callback) + super().__init__( + vllm_config, executor_class, log_stats, executor_fail_callback + ) # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, @@ -505,26 +552,34 @@ def __init__( # model forward pass. # Threads handle Socket <-> Queues and core_busy_loop uses Queue. ready_event = threading.Event() - input_thread = threading.Thread(target=self.process_input_sockets, - args=(addresses.inputs, - addresses.coordinator_input, - identity, ready_event), - daemon=True) + input_thread = threading.Thread( + target=self.process_input_sockets, + args=( + addresses.inputs, + addresses.coordinator_input, + identity, + ready_event, + ), + daemon=True, + ) input_thread.start() self.output_thread = threading.Thread( target=self.process_output_sockets, - args=(addresses.outputs, addresses.coordinator_output, - self.engine_index), - daemon=True) + args=( + addresses.outputs, + addresses.coordinator_output, + self.engine_index, + ), + daemon=True, + ) self.output_thread.start() # Don't complete handshake until DP coordinator ready message is # received. 
while not ready_event.wait(timeout=10): if not input_thread.is_alive(): - raise RuntimeError( - "Input socket thread died during startup") + raise RuntimeError("Input socket thread died during startup") assert addresses.coordinator_input is not None logger.info("Waiting for READY message from DP Coordinator...") @@ -570,18 +625,23 @@ def _perform_handshakes( input_ctx = zmq.Context() is_local = local_client and client_handshake_address is None headless = not local_client - handshake = self._perform_handshake(input_ctx, handshake_address, - identity, is_local, headless, - vllm_config, - vllm_config.parallel_config) + handshake = self._perform_handshake( + input_ctx, + handshake_address, + identity, + is_local, + headless, + vllm_config, + vllm_config.parallel_config, + ) if client_handshake_address is None: with handshake as addresses: yield addresses else: assert local_client local_handshake = self._perform_handshake( - input_ctx, client_handshake_address, identity, True, False, - vllm_config) + input_ctx, client_handshake_address, identity, True, False, vllm_config + ) with handshake as addresses, local_handshake as client_addresses: addresses.inputs = client_addresses.inputs addresses.outputs = client_addresses.outputs @@ -601,16 +661,18 @@ def _perform_handshake( vllm_config: VllmConfig, parallel_config_to_update: Optional[ParallelConfig] = None, ) -> Generator[EngineZmqAddresses, None, None]: - with make_zmq_socket(ctx, - handshake_address, - zmq.DEALER, - identity=identity, - linger=5000, - bind=False) as handshake_socket: + with make_zmq_socket( + ctx, + handshake_address, + zmq.DEALER, + identity=identity, + linger=5000, + bind=False, + ) as handshake_socket: # Register engine with front-end. - addresses = self.startup_handshake(handshake_socket, local_client, - headless, - parallel_config_to_update) + addresses = self.startup_handshake( + handshake_socket, local_client, headless, parallel_config_to_update + ) yield addresses # Send ready message. @@ -620,13 +682,16 @@ def _perform_handshake( # only runs with rank 0). dp_stats_address = self.frontend_stats_publish_address handshake_socket.send( - msgspec.msgpack.encode({ - "status": "READY", - "local": local_client, - "headless": headless, - "num_gpu_blocks": num_gpu_blocks, - "dp_stats_address": dp_stats_address, - })) + msgspec.msgpack.encode( + { + "status": "READY", + "local": local_client, + "headless": headless, + "num_gpu_blocks": num_gpu_blocks, + "dp_stats_address": dp_stats_address, + } + ) + ) @staticmethod def startup_handshake( @@ -635,24 +700,29 @@ def startup_handshake( headless: bool, parallel_config: Optional[ParallelConfig] = None, ) -> EngineZmqAddresses: - # Send registration message. handshake_socket.send( - msgspec.msgpack.encode({ - "status": "HELLO", - "local": local_client, - "headless": headless, - })) + msgspec.msgpack.encode( + { + "status": "HELLO", + "local": local_client, + "headless": headless, + } + ) + ) # Receive initialization message. 
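The handshake above exchanges small msgpack payloads over a ZMQ DEALER socket. A self-contained sketch of the encode/decode round trip; HandshakeMeta below is a simplified stand-in for EngineHandshakeMetadata, not the real definition:

import msgspec

class HandshakeMeta(msgspec.Struct):
    # Simplified stand-in for the address metadata the front-end sends back.
    inputs: list[str]
    outputs: list[str]

hello = msgspec.msgpack.encode({"status": "HELLO", "local": True, "headless": False})
assert isinstance(hello, bytes)  # this is what goes out on the DEALER socket
# The front-end replies with its address metadata, decoded into a typed struct:
reply_bytes = msgspec.msgpack.encode(
    HandshakeMeta(inputs=["ipc:///tmp/in"], outputs=["ipc:///tmp/out"])
)
meta = msgspec.msgpack.decode(reply_bytes, type=HandshakeMeta)
assert meta.inputs == ["ipc:///tmp/in"]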
logger.info("Waiting for init message from front-end.") if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000): - raise RuntimeError("Did not receive response from front-end " - f"process within {HANDSHAKE_TIMEOUT_MINS} " - f"minutes") + raise RuntimeError( + "Did not receive response from front-end " + f"process within {HANDSHAKE_TIMEOUT_MINS} " + f"minutes" + ) init_bytes = handshake_socket.recv() init_message: EngineHandshakeMetadata = msgspec.msgpack.decode( - init_bytes, type=EngineHandshakeMetadata) + init_bytes, type=EngineHandshakeMetadata + ) logger.debug("Received init message: %s", init_message) if parallel_config is not None: @@ -662,10 +732,7 @@ def startup_handshake( return init_message.addresses @staticmethod - def run_engine_core(*args, - dp_rank: int = 0, - local_dp_rank: int = 0, - **kwargs): + def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): """Launch EngineCore busy loop in background process.""" # Signal handler used for graceful termination. @@ -688,8 +755,7 @@ def signal_handler(signum, frame): engine_core: Optional[EngineCoreProc] = None try: - parallel_config: ParallelConfig = kwargs[ - "vllm_config"].parallel_config + parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config if parallel_config.data_parallel_size > 1 or dp_rank > 0: set_process_title("EngineCore", f"DP{dp_rank}") decorate_logs() @@ -735,8 +801,11 @@ def _process_input_queue(self): """Exits when an engine step needs to be performed.""" waited = False - while not self.engines_running and not self.scheduler.has_requests() \ - and not self.batch_queue: + while ( + not self.engines_running + and not self.scheduler.has_requests() + and not self.batch_queue + ): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -757,15 +826,16 @@ def _process_engine_step(self) -> bool: # Step the engine core. outputs, model_executed = self.step_fn() # Put EngineCoreOutputs into the output queue. - for output in (outputs.items() if outputs else ()): + for output in outputs.items() if outputs else (): self.output_queue.put_nowait(output) # Post-step hook. 
self.post_step(model_executed) return model_executed - def _handle_client_request(self, request_type: EngineCoreRequestType, - request: Any) -> None: + def _handle_client_request( + self, request_type: EngineCoreRequestType, request: Any + ) -> None: """Dispatch request from client.""" if request_type == EngineCoreRequestType.ADD: @@ -782,29 +852,35 @@ def _handle_client_request(self, request_type: EngineCoreRequestType, output.result = UtilityResult(result) except BaseException as e: logger.exception("Invocation of %s method failed", method_name) - output.failure_message = (f"Call to {method_name} method" - f" failed: {str(e)}") + output.failure_message = ( + f"Call to {method_name} method failed: {str(e)}" + ) self.output_queue.put_nowait( - (client_idx, EngineCoreOutputs(utility_output=output))) + (client_idx, EngineCoreOutputs(utility_output=output)) + ) elif request_type == EngineCoreRequestType.EXECUTOR_FAILED: raise RuntimeError("Executor failed.") else: - logger.error("Unrecognized input request type encountered: %s", - request_type) + logger.error( + "Unrecognized input request type encountered: %s", request_type + ) @staticmethod def _convert_msgspec_args(method, args): """If a provided arg type doesn't match corresponding target method - arg type, try converting to msgspec object.""" + arg type, try converting to msgspec object.""" if not args: return args arg_types = signature(method).parameters.values() assert len(args) <= len(arg_types) return tuple( - msgspec.convert(v, type=p.annotation) if isclass(p.annotation) + msgspec.convert(v, type=p.annotation) + if isclass(p.annotation) and issubclass(p.annotation, msgspec.Struct) - and not isinstance(v, p.annotation) else v - for v, p in zip(args, arg_types)) + and not isinstance(v, p.annotation) + else v + for v, p in zip(args, arg_types) + ) def _send_engine_dead(self): """Send EngineDead status to the EngineCoreClient.""" @@ -815,12 +891,18 @@ def _send_engine_dead(self): # Wait until msg sent by the daemon before shutdown. self.output_thread.join(timeout=5.0) if self.output_thread.is_alive(): - logger.fatal("vLLM shutdown signal from EngineCore failed " - "to send. Please report this issue.") + logger.fatal( + "vLLM shutdown signal from EngineCore failed " + "to send. Please report this issue." + ) - def process_input_sockets(self, input_addresses: list[str], - coord_input_address: Optional[str], - identity: bytes, ready_event: threading.Event): + def process_input_sockets( + self, + input_addresses: list[str], + coord_input_address: Optional[str], + identity: bytes, + ready_event: threading.Event, + ): """Input socket IO thread.""" # Msgpack serialization decoding. @@ -830,24 +912,26 @@ def process_input_sockets(self, input_addresses: list[str], with ExitStack() as stack, zmq.Context() as ctx: input_sockets = [ stack.enter_context( - make_zmq_socket(ctx, - input_address, - zmq.DEALER, - identity=identity, - bind=False)) + make_zmq_socket( + ctx, input_address, zmq.DEALER, identity=identity, bind=False + ) + ) for input_address in input_addresses ] if coord_input_address is None: coord_socket = None else: coord_socket = stack.enter_context( - make_zmq_socket(ctx, - coord_input_address, - zmq.XSUB, - identity=identity, - bind=False)) + make_zmq_socket( + ctx, + coord_input_address, + zmq.XSUB, + identity=identity, + bind=False, + ) + ) # Send subscription message to coordinator. - coord_socket.send(b'\x01') + coord_socket.send(b"\x01") # Register sockets with poller. 
poller = zmq.Poller() @@ -855,7 +939,7 @@ def process_input_sockets(self, input_addresses: list[str], # Send initial message to each input socket - this is required # before the front-end ROUTER socket can send input messages # back to us. - input_socket.send(b'') + input_socket.send(b"") poller.register(input_socket, zmq.POLLIN) if coord_socket is not None: @@ -868,10 +952,8 @@ def process_input_sockets(self, input_addresses: list[str], while True: for input_socket, _ in poller.poll(): # (RequestType, RequestData) - type_frame, *data_frames = input_socket.recv_multipart( - copy=False) - request_type = EngineCoreRequestType( - bytes(type_frame.buffer)) + type_frame, *data_frames = input_socket.recv_multipart(copy=False) + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) # Deserialize the request data. if request_type == EngineCoreRequestType.ADD: @@ -883,9 +965,12 @@ def process_input_sockets(self, input_addresses: list[str], # Push to input queue for core busy loop. self.input_queue.put_nowait((request_type, request)) - def process_output_sockets(self, output_paths: list[str], - coord_output_path: Optional[str], - engine_index: int): + def process_output_sockets( + self, + output_paths: list[str], + coord_output_path: Optional[str], + engine_index: int, + ): """Output socket IO thread.""" # Msgpack serialization encoding. @@ -902,13 +987,19 @@ def process_output_sockets(self, output_paths: list[str], with ExitStack() as stack, zmq.Context() as ctx: sockets = [ stack.enter_context( - make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000)) + make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000) + ) for output_path in output_paths ] - coord_socket = stack.enter_context( - make_zmq_socket( - ctx, coord_output_path, zmq.PUSH, bind=False, - linger=4000)) if coord_output_path is not None else None + coord_socket = ( + stack.enter_context( + make_zmq_socket( + ctx, coord_output_path, zmq.PUSH, bind=False, linger=4000 + ) + ) + if coord_output_path is not None + else None + ) max_reuse_bufs = len(sockets) + 1 while True: @@ -934,9 +1025,9 @@ def process_output_sockets(self, output_paths: list[str], buffer = reuse_buffers.pop() if reuse_buffers else bytearray() buffers = encoder.encode_into(outputs, buffer) - tracker = sockets[client_index].send_multipart(buffers, - copy=False, - track=True) + tracker = sockets[client_index].send_multipart( + buffers, copy=False, track=True + ) if not tracker.done: ref = outputs if len(buffers) > 1 else None pending.appendleft((tracker, ref, buffer)) @@ -966,12 +1057,17 @@ def __init__( # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank - super().__init__(vllm_config, local_client, handshake_address, - executor_class, log_stats, client_handshake_address, - dp_rank) + super().__init__( + vllm_config, + local_client, + handshake_address, + executor_class, + log_stats, + client_handshake_address, + dp_rank, + ) def _init_data_parallel(self, vllm_config: VllmConfig): - # Configure GPUs and stateless process group for data parallel. 
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_size = vllm_config.parallel_config.data_parallel_size @@ -986,8 +1082,10 @@ def _init_data_parallel(self, vllm_config: VllmConfig): vllm_config.kv_transfer_config.engine_id = ( f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}" ) - logger.debug("Setting kv_transfer_config.engine_id to %s", - vllm_config.kv_transfer_config.engine_id) + logger.debug( + "Setting kv_transfer_config.engine_id to %s", + vllm_config.kv_transfer_config.engine_id, + ) self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() @@ -1005,20 +1103,22 @@ def add_request(self, request: Request, request_wave: int = 0): # Request received for an already-completed wave, notify # front-end that we need to start the next one. self.output_queue.put_nowait( - (-1, EngineCoreOutputs(start_wave=self.current_wave))) + (-1, EngineCoreOutputs(start_wave=self.current_wave)) + ) super().add_request(request, request_wave) - def _handle_client_request(self, request_type: EngineCoreRequestType, - request: Any) -> None: + def _handle_client_request( + self, request_type: EngineCoreRequestType, request: Any + ) -> None: if request_type == EngineCoreRequestType.START_DP_WAVE: new_wave, exclude_eng_index = request if exclude_eng_index != self.engine_index and ( - new_wave >= self.current_wave): + new_wave >= self.current_wave + ): self.current_wave = new_wave if not self.engines_running: - logger.debug("EngineCore starting idle loop for wave %d.", - new_wave) + logger.debug("EngineCore starting idle loop for wave %d.", new_wave) self.engines_running = True else: super()._handle_client_request(request_type, request) @@ -1031,11 +1131,10 @@ def _maybe_publish_request_counts(self): counts = self.scheduler.get_request_counts() if counts != self.last_counts: self.last_counts = counts - stats = SchedulerStats(*counts, - step_counter=self.step_counter, - current_wave=self.current_wave) - self.output_queue.put_nowait( - (-1, EngineCoreOutputs(scheduler_stats=stats))) + stats = SchedulerStats( + *counts, step_counter=self.step_counter, current_wave=self.current_wave + ) + self.output_queue.put_nowait((-1, EngineCoreOutputs(scheduler_stats=stats))) def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" @@ -1061,58 +1160,65 @@ def run_busy_loop(self): # 3) All-reduce operation to determine global unfinished reqs. self.engines_running = self._has_global_unfinished_reqs( - local_unfinished_reqs) + local_unfinished_reqs + ) if not self.engines_running: if self.dp_rank == 0 or not self.has_coordinator: # Notify client that we are pausing the loop. - logger.debug("Wave %d finished, pausing engine loop.", - self.current_wave) + logger.debug( + "Wave %d finished, pausing engine loop.", self.current_wave + ) # In the coordinator case, dp rank 0 sends updates to the # coordinator. Otherwise (offline spmd case), each rank # sends the update to its colocated front-end process. client_index = -1 if self.has_coordinator else 0 self.output_queue.put_nowait( - (client_index, - EngineCoreOutputs(wave_complete=self.current_wave))) + ( + client_index, + EngineCoreOutputs(wave_complete=self.current_wave), + ) + ) # Increment wave count and reset step counter. self.current_wave += 1 self.step_counter = 0 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: - # Optimization - only perform finish-sync all-reduce every 32 steps. 
self.step_counter += 1 if self.step_counter % 32 != 0: return True - return ParallelConfig.has_unfinished_dp(self.dp_group, - local_unfinished) + return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: stateless_destroy_torch_distributed_process_group(self.dp_group) self.shutdown() parallel_config = self.vllm_config.parallel_config old_dp_size = parallel_config.data_parallel_size - parallel_config.data_parallel_size = \ - reconfig_request.new_data_parallel_size + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size if reconfig_request.new_data_parallel_rank != -1: - parallel_config.data_parallel_rank = \ - reconfig_request.new_data_parallel_rank + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank # local rank specifies device visibility, it should not be changed - assert reconfig_request.new_data_parallel_rank_local == \ - ReconfigureRankType.KEEP_CURRENT_RANK - parallel_config.data_parallel_master_ip = \ + assert ( + reconfig_request.new_data_parallel_rank_local + == ReconfigureRankType.KEEP_CURRENT_RANK + ) + parallel_config.data_parallel_master_ip = ( reconfig_request.new_data_parallel_master_ip - parallel_config.data_parallel_master_port = \ + ) + parallel_config.data_parallel_master_port = ( reconfig_request.new_data_parallel_master_port + ) if reconfig_request.new_data_parallel_rank != -2: self.dp_rank = parallel_config.data_parallel_rank self.dp_group = parallel_config.stateless_init_dp_group() - reconfig_request.new_data_parallel_master_port = \ + reconfig_request.new_data_parallel_master_port = ( parallel_config.data_parallel_master_port + ) self.model_executor.reinitialize_distributed(reconfig_request) if reconfig_request.new_data_parallel_size > old_dp_size: @@ -1121,17 +1227,21 @@ def reinitialize_distributed( # engine-cores to new engine-cores so they can directly # use it in _initialize_kv_caches() rather than profiling. ParallelConfig.sync_kv_cache_memory_size( - self.dp_group, self.available_gpu_memory_for_kv_cache) + self.dp_group, self.available_gpu_memory_for_kv_cache + ) # NOTE(yongji): newly joined workers require dummy_run even # CUDA graph is not used self.model_executor.collective_rpc("compile_or_warm_up_model") - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): self.shutdown() logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) else: - logger.info("Distributed environment reinitialized for DP rank %s", - self.dp_rank) + logger.info( + "Distributed environment reinitialized for DP rank %s", self.dp_rank + ) class DPEngineCoreActor(DPEngineCoreProc): @@ -1151,8 +1261,7 @@ def __init__( ): self.addresses = addresses vllm_config.parallel_config.data_parallel_rank = dp_rank - vllm_config.parallel_config.data_parallel_rank_local = \ - local_dp_rank + vllm_config.parallel_config.data_parallel_rank_local = local_dp_rank # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time, @@ -1173,39 +1282,46 @@ def __init__( # of ray. 
self._set_visible_devices(vllm_config, local_dp_rank) - super().__init__(vllm_config, local_client, "", executor_class, - log_stats) + super().__init__(vllm_config, local_client, "", executor_class, log_stats) - def _set_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int): + def _set_visible_devices(self, vllm_config: VllmConfig, local_dp_rank: int): from vllm.platforms import current_platform + if current_platform.is_xpu(): pass else: device_control_env_var = current_platform.device_control_env_var - self._set_cuda_visible_devices(vllm_config, local_dp_rank, - device_control_env_var) + self._set_cuda_visible_devices( + vllm_config, local_dp_rank, device_control_env_var + ) - def _set_cuda_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int, - device_control_env_var: str): + def _set_cuda_visible_devices( + self, vllm_config: VllmConfig, local_dp_rank: int, device_control_env_var: str + ): world_size = vllm_config.parallel_config.world_size # Set CUDA_VISIBLE_DEVICES or equivalent. try: - value = get_device_indices(device_control_env_var, local_dp_rank, - world_size) + value = get_device_indices( + device_control_env_var, local_dp_rank, world_size + ) os.environ[device_control_env_var] = value except IndexError as e: raise Exception( f"Error setting {device_control_env_var}: " f"local range: [{local_dp_rank * world_size}, " f"{(local_dp_rank + 1) * world_size}) " - f"base value: \"{os.getenv(device_control_env_var)}\"") from e + f'base value: "{os.getenv(device_control_env_var)}"' + ) from e @contextmanager - def _perform_handshakes(self, handshake_address: str, identity: bytes, - local_client: bool, vllm_config: VllmConfig, - client_handshake_address: Optional[str]): + def _perform_handshakes( + self, + handshake_address: str, + identity: bytes, + local_client: bool, + vllm_config: VllmConfig, + client_handshake_address: Optional[str], + ): """ For Ray, we don't need to actually perform handshake. All addresses information is known before the actor creation. 
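The handshake path reformatted above boils down to a msgpack dialogue over a ZMQ DEALER socket talking to the front-end's ROUTER socket. As a standalone sketch of that pattern (separate from the patch itself; the inproc address, identity bytes, and payload fields are illustrative placeholders, not vLLM's actual values):

# Illustrative sketch, not part of the diff: the DEALER <-> ROUTER + msgpack
# handshake pattern that the startup_handshake code above relies on.
import msgspec
import zmq

ADDR = "inproc://handshake-demo"  # placeholder address for the example

ctx = zmq.Context()

# Front-end side: a ROUTER socket that sees each peer's ZMQ identity.
router = ctx.socket(zmq.ROUTER)
router.bind(ADDR)

# Engine side: a DEALER socket with an explicit identity (the real code uses
# the engine rank encoded as bytes).
dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt(zmq.IDENTITY, b"engine-0")
dealer.connect(ADDR)

# Engine registers itself with a msgpack-encoded HELLO message.
dealer.send(msgspec.msgpack.encode({"status": "HELLO", "local": True}))

# Front-end receives (identity, payload) and replies to that identity.
identity, payload = router.recv_multipart()
hello = msgspec.msgpack.decode(payload)
assert hello["status"] == "HELLO"
router.send_multipart([identity, msgspec.msgpack.encode({"status": "READY"})])

# Engine blocks until the reply arrives (the real code polls with a timeout).
reply = msgspec.msgpack.decode(dealer.recv())
print(reply)  # {'status': 'READY'}

dealer.close()
router.close()
ctx.term()

The production code layers timeouts, coordinator addresses, and EngineZmqAddresses on top of this exchange, but the framing is the same: identity-prefixed frames on the ROUTER side, plain frames on the DEALER side.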
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index a84b0e55105b..27283411eada 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -23,17 +23,29 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask -from vllm.utils import (close_sockets, get_open_port, get_open_zmq_inproc_path, - in_loop, make_zmq_socket) -from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, - ReconfigureDistributedRequest, ReconfigureRankType, - UtilityOutput) +from vllm.utils import ( + close_sockets, + get_open_port, + get_open_zmq_inproc_path, + in_loop, + make_zmq_socket, +) +from vllm.v1.engine import ( + EngineCoreOutputs, + EngineCoreRequest, + EngineCoreRequestType, + ReconfigureDistributedRequest, + ReconfigureRankType, + UtilityOutput, +) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager, launch_core_engines) +from vllm.v1.engine.utils import ( + CoreEngineActorManager, + CoreEngineProcManager, + launch_core_engines, +) from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr @@ -41,14 +53,14 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]] -_R = TypeVar('_R') # Return type for collective_rpc +_R = TypeVar("_R") # Return type for collective_rpc EngineIdentity = bytes class EngineCoreClient(ABC): """ - EngineCoreClient: subclasses handle different methods for pushing + EngineCoreClient: subclasses handle different methods for pushing and pulling from the EngineCore for asyncio / multiprocessing. Subclasses: @@ -65,16 +77,17 @@ def make_client( executor_class: type[Executor], log_stats: bool, ) -> "EngineCoreClient": - # TODO: support this for debugging purposes. if asyncio_mode and not multiprocess_mode: raise NotImplementedError( "Running EngineCore in asyncio without multiprocessing " - "is not currently supported.") + "is not currently supported." + ) if multiprocess_mode and asyncio_mode: return EngineCoreClient.make_async_mp_client( - vllm_config, executor_class, log_stats) + vllm_config, executor_class, log_stats + ) if multiprocess_mode and not asyncio_mode: return SyncMPClient(vllm_config, executor_class, log_stats) @@ -91,8 +104,14 @@ def make_async_mp_client( client_index: int = 0, ) -> "MPClient": parallel_config = vllm_config.parallel_config - client_args = (vllm_config, executor_class, log_stats, - client_addresses, client_count, client_index) + client_args = ( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) if parallel_config.data_parallel_size > 1: if parallel_config.data_parallel_external_lb: # External load balancer - client per DP rank. @@ -102,8 +121,7 @@ def make_async_mp_client( return AsyncMPClient(*client_args) @abstractmethod - def shutdown(self): - ... + def shutdown(self): ... 
def get_output(self) -> EngineCoreOutputs: raise NotImplementedError @@ -153,17 +171,18 @@ def list_loras(self) -> set[int]: def pin_lora(self, lora_id: int) -> bool: raise NotImplementedError - def save_sharded_state(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: raise NotImplementedError - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: raise NotImplementedError def dp_engines_running(self) -> bool: @@ -216,24 +235,24 @@ async def list_loras_async(self) -> set[int]: async def pin_lora_async(self, lora_id: int) -> bool: raise NotImplementedError - async def save_sharded_state_async(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + async def save_sharded_state_async( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: raise NotImplementedError async def collective_rpc_async( - self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: raise NotImplementedError class InprocClient(EngineCoreClient): """ - InprocClient: client for in-process EngineCore. Intended + InprocClient: client for in-process EngineCore. Intended for use in LLMEngine for V0-style add_request() and step() EngineCore setup in this process (no busy loop). @@ -295,17 +314,18 @@ def list_loras(self) -> set[int]: def pin_lora(self, lora_id: int) -> bool: return self.engine_core.pin_lora(lora_id) - def save_sharded_state(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: self.engine_core.save_sharded_state(path, pattern, max_size) - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: return self.engine_core.collective_rpc(method, timeout, args, kwargs) def dp_engines_running(self) -> bool: @@ -320,8 +340,9 @@ class BackgroundResources: ctx: zmq.Context # If CoreEngineProcManager, it manages local engines; # if CoreEngineActorManager, it manages all engines. - engine_manager: Optional[Union[CoreEngineProcManager, - CoreEngineActorManager]] = None + engine_manager: Optional[Union[CoreEngineProcManager, CoreEngineActorManager]] = ( + None + ) coordinator: Optional[DPCoordinator] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None @@ -347,12 +368,15 @@ def __call__(self): if isinstance(self.output_socket, zmq.asyncio.Socket): # Async case. 
- loop = self.output_queue_task._loop \ - if self.output_queue_task else None + loop = self.output_queue_task._loop if self.output_queue_task else None - sockets = (self.output_socket, self.input_socket, - self.first_req_send_socket, self.first_req_rcv_socket, - self.stats_update_socket) + sockets = ( + self.output_socket, + self.input_socket, + self.first_req_send_socket, + self.first_req_rcv_socket, + self.stats_update_socket, + ) tasks = (self.output_queue_task, self.stats_update_task) @@ -387,11 +411,10 @@ def close_sockets_and_tasks(): with self.ctx.socket(zmq.PAIR) as shutdown_sender: shutdown_sender.connect(self.shutdown_path) # Send shutdown signal. - shutdown_sender.send(b'') + shutdown_sender.send(b"") def validate_alive(self, frames: Sequence[zmq.Frame]): - if len(frames) == 1 and (frames[0].buffer - == EngineCoreProc.ENGINE_CORE_DEAD): + if len(frames) == 1 and (frames[0].buffer == EngineCoreProc.ENGINE_CORE_DEAD): self.engine_dead = True raise EngineDeadError() @@ -404,7 +427,7 @@ class MPClient(EngineCoreClient): * pushes EngineCoreRequests via input_socket * pulls EngineCoreOutputs via output_socket - + * AsyncMPClient subclass for AsyncLLM usage * SyncMPClient subclass for LLM usage """ @@ -441,30 +464,32 @@ def __init__( # Engines are managed externally to this client. input_address = client_addresses["input_address"] output_address = client_addresses["output_address"] - self.stats_update_address = client_addresses.get( - "stats_update_address") + self.stats_update_address = client_addresses.get("stats_update_address") else: # Engines are managed by this client. - with launch_core_engines(vllm_config, executor_class, - log_stats) as (engine_manager, - coordinator, - addresses): + with launch_core_engines(vllm_config, executor_class, log_stats) as ( + engine_manager, + coordinator, + addresses, + ): self.resources.coordinator = coordinator self.resources.engine_manager = engine_manager - (input_address, ) = addresses.inputs - (output_address, ) = addresses.outputs - self.stats_update_address = ( - addresses.frontend_stats_publish_address) + (input_address,) = addresses.inputs + (output_address,) = addresses.outputs + self.stats_update_address = addresses.frontend_stats_publish_address if coordinator is not None: assert self.stats_update_address == ( - coordinator.get_stats_publish_address()) + coordinator.get_stats_publish_address() + ) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( - self.ctx, input_address, zmq.ROUTER, bind=True) + self.ctx, input_address, zmq.ROUTER, bind=True + ) self.resources.output_socket = make_zmq_socket( - self.ctx, output_address, zmq.PULL) + self.ctx, output_address, zmq.PULL + ) parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size @@ -473,19 +498,22 @@ def __init__( offline_mode = parallel_config.data_parallel_rank_local is not None # Client manages local+remote EngineCores in pure internal LB case. # Client manages local EngineCores in hybrid and external LB case. 
- local_engines_only = (parallel_config.data_parallel_hybrid_lb - or parallel_config.data_parallel_external_lb) + local_engines_only = ( + parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb + ) num_ranks = dp_local_size if local_engines_only else dp_size - self.engine_ranks_managed = [dp_rank] if offline_mode else list( - range(dp_rank, dp_rank + num_ranks)) + self.engine_ranks_managed = ( + [dp_rank] if offline_mode else list(range(dp_rank, dp_rank + num_ranks)) + ) assert parallel_config.data_parallel_size_local <= len( - self.engine_ranks_managed) + self.engine_ranks_managed + ) # ZMQ identity of each engine that this client will talk to. self.core_engines: list[EngineIdentity] = [ - rank.to_bytes(2, "little") - for rank in self.engine_ranks_managed + rank.to_bytes(2, "little") for rank in self.engine_ranks_managed ] # Wait for ready messages from each engine on the input socket. @@ -493,8 +521,10 @@ def __init__( sync_input_socket = zmq.Socket.shadow(self.input_socket) while identities: if not sync_input_socket.poll(timeout=600_000): - raise TimeoutError("Timed out waiting for engines to send" - "initial message on input socket.") + raise TimeoutError( + "Timed out waiting for engines to send" + "initial message on input socket." + ) identity, _ = sync_input_socket.recv_multipart() identities.remove(identity) @@ -520,8 +550,9 @@ def shutdown(self): def _format_exception(self, e: Exception) -> Exception: """If errored, use EngineDeadError so root cause is clear.""" - return EngineDeadError( - suppress_context=True) if self.resources.engine_dead else e + return ( + EngineDeadError(suppress_context=True) if self.resources.engine_dead else e + ) def ensure_alive(self): if self.resources.engine_dead: @@ -541,8 +572,11 @@ def dp_engines_running(self) -> bool: def start_engine_core_monitor(self): """Start a monitor thread for engine core processes.""" engine_manager = self.resources.engine_manager - if (engine_manager is None or not hasattr(engine_manager, 'processes') - or not engine_manager.processes): + if ( + engine_manager is None + or not hasattr(engine_manager, "processes") + or not engine_manager.processes + ): # No engine processes to monitor return @@ -559,23 +593,26 @@ def monitor_engine_cores(): if not _self or _self.resources.engine_dead: return _self.resources.engine_dead = True - proc_name = next(proc.name for proc in engine_processes - if proc.sentinel == died[0]) + proc_name = next( + proc.name for proc in engine_processes if proc.sentinel == died[0] + ) logger.error( - "Engine core proc %s died unexpectedly, " - "shutting down client.", proc_name) + "Engine core proc %s died unexpectedly, shutting down client.", + proc_name, + ) _self.shutdown() # Note: For MPClient, we don't have a failure callback mechanism # like MultiprocExecutor, but we set engine_dead flag which will # cause subsequent operations to raise EngineDeadError - Thread(target=monitor_engine_cores, - daemon=True, - name="MPClientEngineMonitor").start() + Thread( + target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor" + ).start() -def _process_utility_output(output: UtilityOutput, - utility_results: dict[int, AnyFuture]): +def _process_utility_output( + output: UtilityOutput, utility_results: dict[int, AnyFuture] +): """Set the result from a utility method in the waiting future.""" future = utility_results.pop(output.call_id) failure_message = output.failure_message @@ -590,15 +627,17 @@ def _process_utility_output(output: UtilityOutput, # original calling 
task being cancelled. if failure_message is not None: logger.error( - "Cancelled call to utility method failed " - "with error: %s", failure_message) + "Cancelled call to utility method failed with error: %s", + failure_message, + ) class SyncMPClient(MPClient): """Synchronous client for multi-proc EngineCore.""" - def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], - log_stats: bool): + def __init__( + self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool + ): super().__init__( asyncio_mode=False, vllm_config=vllm_config, @@ -641,8 +680,7 @@ def process_outputs_socket(): resources.validate_alive(frames) outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) + _process_utility_output(outputs.utility_output, utility_results) else: outputs_queue.put_nowait(outputs) except Exception as e: @@ -653,9 +691,11 @@ def process_outputs_socket(): out_socket.close(linger=0) # Process outputs from engine in separate thread. - self.output_queue_thread = Thread(target=process_outputs_socket, - name="EngineCoreOutputQueueThread", - daemon=True) + self.output_queue_thread = Thread( + target=process_outputs_socket, + name="EngineCoreOutputQueueThread", + daemon=True, + ) self.output_queue_thread.start() # The thread takes on responsibility for closing the socket. @@ -676,8 +716,7 @@ def _send_input(self, request_type: EngineCoreRequestType, request: Any): self.ensure_alive() self.free_pending_messages() # (Identity, RequestType, SerializedRequest) - msg = (self.core_engine, request_type.value, - *self.encoder.encode(request)) + msg = (self.core_engine, request_type.value, *self.encoder.encode(request)) if len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. 
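The utility-call plumbing touched in the hunks above pairs each request with a Future stored under a 64-bit call id, which _process_utility_output later resolves. A rough standalone sketch of that pattern (separate from the patch; a queue and a thread stand in for the ZMQ transport, and all names here are placeholders):

# Illustrative sketch, not part of the diff: call-id -> Future dispatch,
# reduced to plain Python.
import uuid
from concurrent.futures import Future
from queue import Queue
from threading import Thread

utility_results: dict[int, Future] = {}
requests: Queue = Queue()

def serve():
    # Stand-in "engine" loop: execute the requested method and report the
    # result back under its call id.
    while True:
        call_id, method, args = requests.get()
        if method is None:
            break
        result = method(*args)
        utility_results.pop(call_id).set_result(result)

def call_utility(method, *args):
    # Client side: register a Future under a fresh call id, send the request,
    # then block on the Future until the engine responds.
    call_id = uuid.uuid1().int >> 64
    future: Future = Future()
    utility_results[call_id] = future
    requests.put((call_id, method, args))
    return future.result()

worker = Thread(target=serve, daemon=True)
worker.start()
print(call_utility(sum, [1, 2, 3]))  # 6
requests.put((0, None, ()))          # sentinel to stop the worker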
@@ -691,8 +730,7 @@ def call_utility(self, method: str, *args) -> Any: call_id = uuid.uuid1().int >> 64 future: Future[Any] = Future() self.utility_results[call_id] = future - self._send_input(EngineCoreRequestType.UTILITY, - (0, call_id, method, args)) + self._send_input(EngineCoreRequestType.UTILITY, (0, call_id, method, args)) return future.result() @@ -741,31 +779,33 @@ def is_sleeping(self) -> bool: def execute_dummy_batch(self) -> None: self.call_utility("execute_dummy_batch") - def collective_rpc(self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return self.call_utility("collective_rpc", method, timeout, args, - kwargs) - - def save_sharded_state(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: + def collective_rpc( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: + return self.call_utility("collective_rpc", method, timeout, args, kwargs) + + def save_sharded_state( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: self.call_utility("save_sharded_state", path, pattern, max_size) class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0): + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + ): super().__init__( asyncio_mode=True, vllm_config=vllm_config, @@ -776,8 +816,7 @@ def __init__(self, self.client_count = client_count self.client_index = client_index - self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, - Exception]]() + self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs, Exception]]() try: # If we are running in an asyncio event loop, start the queue task. # Otherwise, it will be started lazily. 
If it is not started here, @@ -798,10 +837,9 @@ def _ensure_output_queue_task(self): decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue - output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs], - Awaitable[None]]] = getattr( - self.__class__, - "process_engine_outputs", None) + output_handler: Optional[ + Callable[[AsyncMPClient, EngineCoreOutputs], Awaitable[None]] + ] = getattr(self.__class__, "process_engine_outputs", None) _self_ref = weakref.ref(self) if output_handler else None output_socket = resources.output_socket assert output_socket is not None @@ -813,8 +851,7 @@ async def process_outputs_socket(): resources.validate_alive(frames) outputs: EngineCoreOutputs = decoder.decode(frames) if outputs.utility_output: - _process_utility_output(outputs.utility_output, - utility_results) + _process_utility_output(outputs.utility_output, utility_results) continue if output_handler is not None: @@ -833,7 +870,8 @@ async def process_outputs_socket(): outputs_queue.put_nowait(EngineDeadError()) resources.output_queue_task = asyncio.create_task( - process_outputs_socket(), name="EngineCoreOutputQueueTask") + process_outputs_socket(), name="EngineCoreOutputQueueTask" + ) async def get_output_async(self) -> EngineCoreOutputs: self._ensure_output_queue_task() @@ -846,19 +884,21 @@ async def get_output_async(self) -> EngineCoreOutputs: raise self._format_exception(outputs) from None return outputs - def _send_input(self, - request_type: EngineCoreRequestType, - request: Any, - engine: Optional[EngineIdentity] = None) -> Awaitable[Any]: + def _send_input( + self, + request_type: EngineCoreRequestType, + request: Any, + engine: Optional[EngineIdentity] = None, + ) -> Awaitable[Any]: if engine is None: engine = self.core_engine message = (request_type.value, *self.encoder.encode(request)) return self._send_input_message(message, engine, request) - def _send_input_message(self, message: tuple[bytestr, - ...], engine: EngineIdentity, - objects: Any) -> Awaitable[Any]: + def _send_input_message( + self, message: tuple[bytestr, ...], engine: EngineIdentity, objects: Any + ) -> Awaitable[Any]: """ objects is a reference to retain until zmq is finished with the buffers, in case they were extracted from tensors in the request. @@ -866,7 +906,7 @@ def _send_input_message(self, message: tuple[bytestr, self.ensure_alive() self.free_pending_messages() - msg = (engine, ) + message + msg = (engine,) + message if not objects or len(msg) <= 3: # No auxiliary buffers => no tensor backing buffers in request. 
return self.input_socket.send_multipart(msg, copy=False) @@ -882,17 +922,18 @@ def add_pending(f: asyncio.Future[zmq.MessageTracker]): return future async def call_utility_async(self, method: str, *args) -> Any: - return await self._call_utility_async(method, - *args, - engine=self.core_engine) + return await self._call_utility_async(method, *args, engine=self.core_engine) - async def _call_utility_async(self, method: str, *args, - engine: EngineIdentity) -> Any: + async def _call_utility_async( + self, method: str, *args, engine: EngineIdentity + ) -> Any: call_id = uuid.uuid1().int >> 64 future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future - message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (self.client_index, call_id, method, args))) + message = ( + EngineCoreRequestType.UTILITY.value, + *self.encoder.encode((self.client_index, call_id, method, args)), + ) await self._send_input_message(message, engine, args) self._ensure_output_queue_task() return await future @@ -942,38 +983,46 @@ async def list_loras_async(self) -> set[int]: async def pin_lora_async(self, lora_id: int) -> bool: return await self.call_utility_async("pin_lora", lora_id) - async def save_sharded_state_async(self, - path: str, - pattern: Optional[str] = None, - max_size: Optional[int] = None) -> None: - await self.call_utility_async("save_sharded_state", path, pattern, - max_size) + async def save_sharded_state_async( + self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None + ) -> None: + await self.call_utility_async("save_sharded_state", path, pattern, max_size) async def collective_rpc_async( - self, - method: Union[str, Callable[..., _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: - return await self.call_utility_async("collective_rpc", method, timeout, - args, kwargs) + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: + return await self.call_utility_async( + "collective_rpc", method, timeout, args, kwargs + ) class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore. Assumes external load-balancing by default.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0): + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + ): self.current_wave = 0 - super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_count, client_index) + super().__init__( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) # List of [waiting, running] pair per engine. # Used only by DPLBAsyncMPClient subclass. @@ -981,10 +1030,8 @@ def __init__(self, self.first_req_sock_addr = get_open_zmq_inproc_path() self.first_req_send_socket = self.resources.first_req_send_socket = ( - make_zmq_socket(self.ctx, - self.first_req_sock_addr, - zmq.PAIR, - bind=True)) + make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True) + ) try: # If we are running in an asyncio event loop, start the stats task. # Otherwise, it will be started lazily. 
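AsyncMPClient's output path, reformatted in the hunks above, is a background asyncio task that drains the output socket and feeds an asyncio.Queue, with failures surfaced to whoever awaits get_output_async. A simplified standalone sketch of that shape (separate from the patch; an in-memory queue stands in for the ZMQ socket, and the names are placeholders):

# Illustrative sketch, not part of the diff: background task + asyncio.Queue
# with exception propagation to the consumer.
import asyncio

async def main():
    outputs_queue: asyncio.Queue = asyncio.Queue()
    source: asyncio.Queue = asyncio.Queue()  # stands in for the output socket

    async def process_outputs():
        # Drain the "socket"; push results (or the failure itself) onto the
        # queue the consumer awaits.
        try:
            while True:
                item = await source.get()
                if isinstance(item, Exception):
                    raise item
                outputs_queue.put_nowait(item)
        except Exception as e:
            outputs_queue.put_nowait(e)

    task = asyncio.create_task(process_outputs(), name="OutputQueueTask")

    async def get_output_async():
        out = await outputs_queue.get()
        if isinstance(out, Exception):
            raise out
        return out

    await source.put({"request_id": "r0", "text": "hello"})
    print(await get_output_async())

    await source.put(RuntimeError("engine died"))
    try:
        await get_output_async()
    except RuntimeError as e:
        print("caught:", e)

    task.cancel()

asyncio.run(main())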
@@ -1003,25 +1050,25 @@ def _ensure_stats_update_task(self): # NOTE: running and waiting counts are all global from # the Coordinator include all global EngineCores. This # slice includes just the cores managed by this client. - count_slice = slice(self.engine_ranks_managed[0], - self.engine_ranks_managed[-1] + 1) + count_slice = slice( + self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1 + ) async def run_engine_stats_update_task(): - with (make_zmq_socket(self.ctx, - self.stats_update_address, - zmq.XSUB, - linger=0) as socket, - make_zmq_socket(self.ctx, - self.first_req_sock_addr, - zmq.PAIR, - bind=False, - linger=0) as first_req_rcv_socket): + with ( + make_zmq_socket( + self.ctx, self.stats_update_address, zmq.XSUB, linger=0 + ) as socket, + make_zmq_socket( + self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0 + ) as first_req_rcv_socket, + ): assert isinstance(socket, zmq.asyncio.Socket) assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket) self.resources.stats_update_socket = socket self.resources.first_req_rcv_socket = first_req_rcv_socket # Send subscription message. - await socket.send(b'\x01') + await socket.send(b"\x01") poller = zmq.asyncio.Poller() poller.register(socket, zmq.POLLIN) @@ -1029,23 +1076,27 @@ async def run_engine_stats_update_task(): while True: events = await poller.poll() - if not self.engines_running and len(events) == 2 or ( - events[0][0] == first_req_rcv_socket): + if ( + not self.engines_running + and len(events) == 2 + or (events[0][0] == first_req_rcv_socket) + ): # Check if this is a regular request notification or # scale up notification - buf = first_req_rcv_socket.recv( - flags=zmq.NOBLOCK).result() + buf = first_req_rcv_socket.recv(flags=zmq.NOBLOCK).result() decoded = msgspec.msgpack.decode(buf) - if isinstance( - decoded, - (list, tuple)) and len(decoded) == 2 and decoded[ - 0] == "SCALE_ELASTIC_EP": + if ( + isinstance(decoded, (list, tuple)) + and len(decoded) == 2 + and decoded[0] == "SCALE_ELASTIC_EP" + ): # Extract new engine count from the decoded message new_engine_count = decoded[1] # Send scale up notification to coordinator scale_msg = msgspec.msgpack.encode( - ("SCALE_ELASTIC_EP", new_engine_count)) + ("SCALE_ELASTIC_EP", new_engine_count) + ) await socket.send(scale_msg) continue @@ -1056,14 +1107,14 @@ async def run_engine_stats_update_task(): target_eng_index = decoded[1] self.engines_running = True msg = msgspec.msgpack.encode( - (target_eng_index, self.current_wave)) + (target_eng_index, self.current_wave) + ) await socket.send(msg) buf = None while True: # Drain all stats events (we only care about latest). 
- future: asyncio.Future[bytes] = socket.recv( - flags=zmq.NOBLOCK) + future: asyncio.Future[bytes] = socket.recv(flags=zmq.NOBLOCK) if isinstance(future.exception(), zmq.Again): break buf = future.result() @@ -1077,11 +1128,13 @@ async def run_engine_stats_update_task(): if counts is not None: sliced_counts = counts[count_slice] self.lb_engines = sliced_counts - logger.debug("Received counts: %s (%s)", sliced_counts, - count_slice) + logger.debug( + "Received counts: %s (%s)", sliced_counts, count_slice + ) resources.stats_update_task = asyncio.create_task( - run_engine_stats_update_task()) + run_engine_stats_update_task() + ) async def add_request_async(self, request: EngineCoreRequest) -> None: self._ensure_stats_update_task() @@ -1090,8 +1143,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None: request.client_index = self.client_index chosen_engine = self.get_core_engine_for_request(request) - to_await = self._send_input(EngineCoreRequestType.ADD, request, - chosen_engine) + to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine) if not self.engines_running: # Notify coordinator that we're sending a request req_msg = msgspec.msgpack.encode(("FIRST_REQ", chosen_engine)) @@ -1109,29 +1161,36 @@ class DPLBAsyncMPClient(DPAsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) EngineCore. Load-balances between multiple engine processes.""" - def __init__(self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_count: int = 1, - client_index: int = 0): - + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_count: int = 1, + client_index: int = 0, + ): self.client_count = client_count # To route aborts to the correct engine. self.reqs_in_flight: dict[str, EngineIdentity] = {} - super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_count, client_index) + super().__init__( + vllm_config, + executor_class, + log_stats, + client_addresses, + client_count, + client_index, + ) assert len(self.core_engines) > 1 - self.eng_start_index = (len(self.core_engines) * - self.client_index) // client_count + self.eng_start_index = ( + len(self.core_engines) * self.client_index + ) // client_count - def get_core_engine_for_request( - self, request: EngineCoreRequest) -> EngineIdentity: + def get_core_engine_for_request(self, request: EngineCoreRequest) -> EngineIdentity: # Engines are in rank order. if (eng_index := request.data_parallel_rank) is None: current_counts = self.lb_engines @@ -1159,14 +1218,19 @@ def get_core_engine_for_request( async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. 
- return (await asyncio.gather(*[ - self._call_utility_async(method, *args, engine=engine) - for engine in self.core_engines - ]))[0] + return ( + await asyncio.gather( + *[ + self._call_utility_async(method, *args, engine=engine) + for engine in self.core_engines + ] + ) + )[0] @staticmethod - async def process_engine_outputs(self: "DPLBAsyncMPClient", - outputs: EngineCoreOutputs): + async def process_engine_outputs( + self: "DPLBAsyncMPClient", outputs: EngineCoreOutputs + ): if outputs.finished_requests and self.reqs_in_flight: for req_id in outputs.finished_requests: self.reqs_in_flight.pop(req_id, None) @@ -1188,10 +1252,10 @@ async def abort_requests_async(self, request_ids: list[str]) -> None: for engine, req_ids in by_engine.items(): await self._abort_requests(req_ids, engine) - async def _abort_requests(self, request_ids: list[str], - engine: EngineIdentity) -> None: - await self._send_input(EngineCoreRequestType.ABORT, request_ids, - engine) + async def _abort_requests( + self, request_ids: list[str], engine: EngineIdentity + ) -> None: + await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: """Scale elastic EP data parallel size""" @@ -1199,22 +1263,27 @@ async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: assert new_data_parallel_size != cur_data_parallel_size, ( f"new_data_parallel_size {new_data_parallel_size} must be " - f"different from cur_data_parallel_size {cur_data_parallel_size}") + f"different from cur_data_parallel_size {cur_data_parallel_size}" + ) - assert self.vllm_config.parallel_config.data_parallel_backend == \ - "ray", "Only ray DP backend supports scaling elastic EP" + assert self.vllm_config.parallel_config.data_parallel_backend == "ray", ( + "Only ray DP backend supports scaling elastic EP" + ) scale_up = new_data_parallel_size > cur_data_parallel_size if scale_up: - await self._scale_up_elastic_ep(cur_data_parallel_size, - new_data_parallel_size) + await self._scale_up_elastic_ep( + cur_data_parallel_size, new_data_parallel_size + ) else: - await self._scale_down_elastic_ep(cur_data_parallel_size, - new_data_parallel_size) + await self._scale_down_elastic_ep( + cur_data_parallel_size, new_data_parallel_size + ) - async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, - new_data_parallel_size: int) -> None: + async def _scale_up_elastic_ep( + self, cur_data_parallel_size: int, new_data_parallel_size: int + ) -> None: """Scale up the data parallel size by creating new engine cores and reconfiguring existing ones.""" cur_data_parallel_size = len(self.core_engines) @@ -1222,21 +1291,18 @@ async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, # Phase 1: Send reconfigure messages to all existing engines and wait # for them to be sent reconfig_futures = [] - self.vllm_config.parallel_config.data_parallel_master_port = \ - get_open_port() + self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() for engine in self.core_engines: reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_rank_local=\ - ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config. - data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config. 
- data_parallel_master_port) - coro = self._call_utility_async("reinitialize_distributed", - reconfig_request, - engine=engine) + new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + ) + coro = self._call_utility_async( + "reinitialize_distributed", reconfig_request, engine=engine + ) reconfig_futures.append(asyncio.create_task(coro)) logger.info("All reconfigure messages sent, starting engine creation") @@ -1244,10 +1310,10 @@ async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, # Phase 2: Create new engines now that reconfig messages have been sent # self.resources.engine_manager is guaranteed to be # CoreEngineActorManager for RayDPClient - assert isinstance(self.resources.engine_manager, - CoreEngineActorManager) + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) self.resources.engine_manager.scale_up_elastic_ep( - self.vllm_config, new_data_parallel_size) + self.vllm_config, new_data_parallel_size + ) # Create new CoreEngine objects for the new engines new_engine_identities = set() @@ -1262,7 +1328,8 @@ async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, if not sync_input_socket.poll(timeout=600_000): raise TimeoutError( "Timed out waiting for new engines to send initial " - "message on input socket.") + "message on input socket." + ) identity, _ = sync_input_socket.recv_multipart() new_engine_identities.discard(identity) @@ -1274,42 +1341,42 @@ async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, # stats_update_task connection self._ensure_stats_update_task() scale_up_marker = msgspec.msgpack.encode( - ("SCALE_ELASTIC_EP", new_data_parallel_size)) + ("SCALE_ELASTIC_EP", new_data_parallel_size) + ) await self.first_req_send_socket.send(scale_up_marker) # Update the parallel config - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size logger.info( "[Elastic EP] Scale up completed, new data parallel size: %s", - new_data_parallel_size) + new_data_parallel_size, + ) - async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, - new_data_parallel_size: int) -> None: + async def _scale_down_elastic_ep( + self, cur_data_parallel_size: int, new_data_parallel_size: int + ) -> None: """Scale down the data parallel size by shutting down and reconfiguring existing engine cores.""" cur_data_parallel_size = len(self.core_engines) - self.vllm_config.parallel_config.data_parallel_master_port = \ - get_open_port() + self.vllm_config.parallel_config.data_parallel_master_port = get_open_port() reconfig_futures = [] for cur_dp_rank, engine in enumerate(self.core_engines): reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_rank_local=\ - ReconfigureRankType.KEEP_CURRENT_RANK, - new_data_parallel_master_ip=self.vllm_config.parallel_config. - data_parallel_master_ip, - new_data_parallel_master_port=self.vllm_config.parallel_config. 
- data_parallel_master_port) + new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port, + ) if cur_dp_rank >= new_data_parallel_size: - reconfig_request.new_data_parallel_rank = \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK - coro = self._call_utility_async("reinitialize_distributed", - reconfig_request, - engine=engine) + reconfig_request.new_data_parallel_rank = ( + ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ) + coro = self._call_utility_async( + "reinitialize_distributed", reconfig_request, engine=engine + ) reconfig_futures.append(asyncio.create_task(coro)) for _ in range(new_data_parallel_size, cur_data_parallel_size): @@ -1317,18 +1384,19 @@ async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, await asyncio.gather(*reconfig_futures) - assert isinstance(self.resources.engine_manager, - CoreEngineActorManager) + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) self.resources.engine_manager.scale_down_elastic_ep( - cur_data_parallel_size, new_data_parallel_size) + cur_data_parallel_size, new_data_parallel_size + ) self._ensure_stats_update_task() scale_down_marker = msgspec.msgpack.encode( - ("SCALE_ELASTIC_EP", new_data_parallel_size)) + ("SCALE_ELASTIC_EP", new_data_parallel_size) + ) await self.first_req_send_socket.send(scale_down_marker) - self.vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size logger.info( "[Elastic EP] Scale down completed, new data parallel size: %s", - new_data_parallel_size) + new_data_parallel_size, + ) diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 0f993a74c810..9d1d7558b1ed 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -11,7 +11,10 @@ from vllm.logger import init_logger from vllm.transformers_utils.detokenizer_utils import ( - AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) + AnyTokenizer, + convert_prompt_ids_to_tokens, + detokenize_incrementally, +) from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest @@ -19,15 +22,13 @@ # Only tokenizers >= 0.21.1 supports DecodeStream used for # FastIncrementalDetokenizer. -USE_FAST_DETOKENIZER = version.parse( - tokenizers.__version__) >= version.parse("0.21.1") +USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1") # Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042 INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered" class IncrementalDetokenizer: - def __init__(self): self.token_ids: list[int] = [] @@ -35,8 +36,7 @@ def __init__(self): def output_token_ids(self) -> list[int]: return self.token_ids - def update(self, new_token_ids: list[int], - stop_terminated: bool) -> Optional[str]: + def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: self.token_ids.extend(new_token_ids) return None @@ -49,15 +49,13 @@ def from_new_request( tokenizer: Optional[AnyTokenizer], request: EngineCoreRequest, ) -> "IncrementalDetokenizer": - assert request.sampling_params is not None if tokenizer is None: # No tokenizer => skipping detokenization. 
return IncrementalDetokenizer() - if USE_FAST_DETOKENIZER and isinstance(tokenizer, - PreTrainedTokenizerFast): + if USE_FAST_DETOKENIZER and isinstance(tokenizer, PreTrainedTokenizerFast): # Fast tokenizer => use tokenizers library DecodeStream. return FastIncrementalDetokenizer(tokenizer, request) @@ -66,7 +64,6 @@ def from_new_request( class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): - def __init__(self, request: EngineCoreRequest): super().__init__() @@ -88,8 +85,7 @@ def __init__(self, request: EngineCoreRequest): # Generation data self.output_text = "" - def update(self, new_token_ids: list[int], - stop_terminated: bool) -> Optional[str]: + def update(self, new_token_ids: list[int], stop_terminated: bool) -> Optional[str]: """ Update RequestState for the request_id by: 1) Detokenize the new token ids incrementally. @@ -117,8 +113,7 @@ def update(self, new_token_ids: list[int], self.token_ids.append(new_token_id) self.output_text += self.decode_next(new_token_id) # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014 - if self.min_tokens and len( - self.output_token_ids) <= self.min_tokens: + if self.min_tokens and len(self.output_token_ids) <= self.min_tokens: stop_check_offset = len(self.output_text) if skipped_stop_token_id is not None: @@ -152,8 +147,11 @@ def get_next_output_text(self, finished: bool, delta: bool) -> str: # We return the full output text if the sequence is finished. buffer_length = 0 if finished else self.stop_buffer_length if not delta: - return self.output_text[:-buffer_length] if buffer_length else ( - self.output_text) + return ( + self.output_text[:-buffer_length] + if buffer_length + else (self.output_text) + ) length = len(self.output_text) - buffer_length last_offset = self._last_output_text_offset if last_offset < length: @@ -163,9 +161,7 @@ def get_next_output_text(self, finished: bool, delta: bool) -> str: class FastIncrementalDetokenizer(BaseIncrementalDetokenizer): - - def __init__(self, tokenizer: PreTrainedTokenizerFast, - request: EngineCoreRequest): + def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest): super().__init__(request) sampling_params = request.sampling_params @@ -173,8 +169,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast, self.request_id = request.request_id self.skip_special_tokens = sampling_params.skip_special_tokens - self.stream = DecodeStream( - skip_special_tokens=self.skip_special_tokens) + self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens) self.tokenizer: Tokenizer = tokenizer._tokenizer @@ -185,7 +180,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast, if prompt_len > 4: for i in range(4, min(prompt_len + 1, 24)): suffix = prompt_token_ids[-i:] - if '�' not in self.tokenizer.decode(suffix): + if "�" not in self.tokenizer.decode(suffix): prompt_suffix = suffix break @@ -195,17 +190,18 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast, self.spaces_between_special_tokens = ( sampling_params.skip_special_tokens - or sampling_params.spaces_between_special_tokens) + or sampling_params.spaces_between_special_tokens + ) if not self.spaces_between_special_tokens: # Store dict of added token ids so that we can suppress # the spaces between them. 
- if (added_token_ids := getattr(self.tokenizer, "added_token_ids", - None)) is None: + if ( + added_token_ids := getattr(self.tokenizer, "added_token_ids", None) + ) is None: self.tokenizer.added_token_ids = added_token_ids = { tid: tok.content - for tid, tok in - self.tokenizer.get_added_tokens_decoder().items() + for tid, tok in self.tokenizer.get_added_tokens_decoder().items() } if added_token_ids: @@ -245,15 +241,15 @@ def _protected_step(self, next_token_id: int) -> Optional[str]: # See https://github.com/vllm-project/vllm/issues/17448. logger.warning( "Encountered invalid prefix detokenization error" - " for request %s, resetting decode stream.", self.request_id) - self.stream = DecodeStream( - skip_special_tokens=self.skip_special_tokens) + " for request %s, resetting decode stream.", + self.request_id, + ) + self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens) token = self.stream.step(self.tokenizer, next_token_id) return token class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer): - def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest): super().__init__(request) @@ -262,7 +258,8 @@ def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest): assert params is not None self.prompt_len = length_from_prompt_token_ids_or_embeds( - request.prompt_token_ids, request.prompt_embeds) + request.prompt_token_ids, request.prompt_embeds + ) # Metadata for incremental detokenization. if request.prompt_token_ids is not None: @@ -271,37 +268,37 @@ def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest): tokenizer=tokenizer, prompt_ids=request.prompt_token_ids, skip_special_tokens=params.skip_special_tokens, - )) + ) + ) else: # Prompt embedding requests cannot be detokenized, in general. self.tokens = [""] * self.prompt_len self.prefix_offset = 0 self.read_offest = 0 - self.token_ids.extend(request.prompt_token_ids - or [0] * self.prompt_len) + self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len) self.skip_special_tokens = params.skip_special_tokens - self.spaces_between_special_tokens = ( - params.spaces_between_special_tokens) + self.spaces_between_special_tokens = params.spaces_between_special_tokens @property def output_token_ids(self) -> list[int]: - return self.token_ids if not self.prompt_len else ( - self.token_ids[self.prompt_len:]) + return ( + self.token_ids + if not self.prompt_len + else (self.token_ids[self.prompt_len :]) + ) def decode_next(self, next_token_id: int) -> str: - new_tokens, decoded_text, prefix_offset, read_offset = ( - detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=self.token_ids, - prev_tokens=self.tokens, - prefix_offset=self.prefix_offset, - read_offset=self.read_offset, - skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=self. - spaces_between_special_tokens, - )) + new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally( + tokenizer=self.tokenizer, + all_input_ids=self.token_ids, + prev_tokens=self.tokens, + prefix_offset=self.prefix_offset, + read_offset=self.read_offset, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + ) self.tokens.extend(new_tokens) self.prefix_offset = prefix_offset @@ -331,8 +328,7 @@ def check_stop_strings( for stop_str in stop: stop_string_len = len(stop_str) # Avoid searching already-searched text. 
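+            # The negative start only rescans enough text to catch a straddling match.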
- stop_index = output_text.find(stop_str, - 1 - new_char_count - stop_string_len) + stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len) if stop_index == -1: continue diff --git a/vllm/v1/engine/exceptions.py b/vllm/v1/engine/exceptions.py index 692ba9dc840f..d9f79a019e2d 100644 --- a/vllm/v1/engine/exceptions.py +++ b/vllm/v1/engine/exceptions.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project class EngineGenerateError(Exception): """Raised when a AsyncLLM.generate() fails. Recoverable.""" + pass diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 3734c208004a..701a62580562 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -23,8 +23,7 @@ from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.tracing import init_tracer -from vllm.transformers_utils.tokenizer import (AnyTokenizer, - init_tokenizer_from_configs) +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import Device from vllm.v1.engine import EngineCoreRequest @@ -62,12 +61,14 @@ def __init__( "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. " "This should not happen. As a workaround, try using " "LLMEngine.from_vllm_config(...) or explicitly set " - "VLLM_USE_V1=0 or 1 and report this issue on Github.") + "VLLM_USE_V1=0 or 1 and report this issue on Github." + ) if stat_loggers is not None: raise NotImplementedError( "Passing StatLoggers to LLMEngine in V1 is not yet supported. " - "Set VLLM_USE_V1=0 and file and issue on Github.") + "Set VLLM_USE_V1=0 and file and issue on Github." + ) self.vllm_config = vllm_config self.observability_config = vllm_config.observability_config @@ -76,39 +77,35 @@ def __init__( self.log_stats = log_stats - executor_backend = ( - self.vllm_config.parallel_config.distributed_executor_backend) + executor_backend = self.vllm_config.parallel_config.distributed_executor_backend parallel_config = vllm_config.parallel_config - self.external_launcher_dp = (parallel_config.data_parallel_size > 1 and - executor_backend == "external_launcher") + self.external_launcher_dp = ( + parallel_config.data_parallel_size > 1 + and executor_backend == "external_launcher" + ) # important: init dp group before init the engine_core # In the decoupled engine case this is handled in EngineCoreProc. - if not multiprocess_mode and parallel_config.data_parallel_size > 1 \ - and not self.external_launcher_dp: + if ( + not multiprocess_mode + and parallel_config.data_parallel_size > 1 + and not self.external_launcher_dp + ): self.dp_group = parallel_config.stateless_init_dp_group() else: self.dp_group = None self.should_execute_dummy_batch = False - if self.model_config.skip_tokenizer_init: - self.tokenizer = None - else: - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config) - # Processor (convert Inputs --> EngineCoreRequests) - self.processor = Processor(vllm_config=vllm_config, - tokenizer=self.tokenizer, - mm_registry=mm_registry) + self.processor = Processor(vllm_config, mm_registry=mm_registry) # OutputProcessor (convert EngineCoreOutputs --> RequestOutput). 
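+        # self.tokenizer delegates to the Processor's tokenizer (see property below).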
- self.output_processor = OutputProcessor(self.tokenizer, - log_stats=self.log_stats) + self.output_processor = OutputProcessor( + self.tokenizer, log_stats=self.log_stats + ) if self.observability_config.otlp_traces_endpoint is not None: tracer = init_tracer( - "vllm.llm_engine", - self.observability_config.otlp_traces_endpoint) + "vllm.llm_engine", self.observability_config.otlp_traces_endpoint + ) self.output_processor.tracer = tracer # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) @@ -149,12 +146,14 @@ def from_vllm_config( stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_stats: bool = False, ) -> "LLMEngine": - return cls(vllm_config=vllm_config, - executor_class=Executor.get_class(vllm_config), - log_stats=(not disable_log_stats), - usage_context=usage_context, - stat_loggers=stat_loggers, - multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING) + return cls( + vllm_config=vllm_config, + executor_class=Executor.get_class(vllm_config), + log_stats=(not disable_log_stats), + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING, + ) @classmethod def from_engine_args( @@ -175,12 +174,14 @@ def from_engine_args( enable_multiprocessing = True # Create the LLMEngine. - return cls(vllm_config=vllm_config, - executor_class=executor_class, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context, - stat_loggers=stat_loggers, - multiprocess_mode=enable_multiprocessing) + return cls( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + multiprocess_mode=enable_multiprocessing, + ) def get_num_unfinished_requests(self) -> int: return self.output_processor.get_num_unfinished_requests() @@ -193,7 +194,8 @@ def has_unfinished_requests(self) -> bool: def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool: aggregated_has_unfinished = ParallelConfig.has_unfinished_dp( - self.dp_group, has_unfinished) + self.dp_group, has_unfinished + ) if not has_unfinished and aggregated_has_unfinished: self.should_execute_dummy_batch = True return aggregated_has_unfinished @@ -202,6 +204,14 @@ def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool: def validate_outputs(cls, outputs, output_type): return outputs + @property + def tokenizer(self) -> Optional[AnyTokenizer]: + return self.processor.tokenizer + + @tokenizer.setter + def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None: + self.processor.tokenizer = tokenizer + def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return self.engine_core.get_supported_tasks() @@ -225,22 +235,28 @@ def add_request( ) -> None: # Validate the request_id type. if not isinstance(request_id, str): - raise TypeError( - f"request_id must be a string, got {type(request_id)}") + raise TypeError(f"request_id must be a string, got {type(request_id)}") # Process raw inputs into the request. 
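+        # Pre-built EngineCoreRequests bypass the Processor and are used as-is.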
if isinstance(prompt, EngineCoreRequest): request = prompt else: assert prompt_text is None - logger.warning_once("Processor has been moved under LLM and will " - "be removed from LLMEngine in v0.13.") - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - tokenization_kwargs, - trace_headers, priority) - prompt_text = (prompt if isinstance(prompt, str) else - prompt.get("prompt")) + logger.warning_once( + "Processor has been moved under LLM and will " + "be removed from LLMEngine in v0.13." + ) + request = self.processor.process_inputs( + request_id, + prompt, + params, + arrival_time, + lora_request, + tokenization_kwargs, + trace_headers, + priority, + ) + prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") n = params.n if isinstance(params, SamplingParams) else 1 @@ -260,13 +276,13 @@ def add_request( child_request.sampling_params = params # Make a new RequestState and queue. - self.output_processor.add_request(child_request, prompt_text, - parent_req, idx) + self.output_processor.add_request( + child_request, prompt_text, parent_req, idx + ) # Add the request to EngineCore. self.engine_core.add_request(child_request) def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]: - if self.should_execute_dummy_batch: self.should_execute_dummy_batch = False self.engine_core.execute_dummy_batch() @@ -280,7 +296,8 @@ def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]: processed_outputs = self.output_processor.process_outputs( outputs.outputs, engine_core_timestamp=outputs.timestamp, - iteration_stats=iteration_stats) + iteration_stats=iteration_stats, + ) # 3) Abort any reqs that finished due to stop strings. self.engine_core.abort_requests(processed_outputs.reqs_to_abort) @@ -330,8 +347,9 @@ def get_metrics(self) -> list[Metric]: def get_tokenizer(self) -> AnyTokenizer: if self.tokenizer is None: - raise ValueError("Unable to get tokenizer because " - "skip_tokenizer_init is True") + raise ValueError( + "Unable to get tokenizer because skip_tokenizer_init is True" + ) return self.tokenizer @@ -365,17 +383,21 @@ def pin_lora(self, lora_id: int) -> bool: """Prevent an adapter from being evicted.""" return self.engine_core.pin_lora(lora_id) - def collective_rpc(self, - method: Union[str, Callable[[WorkerBase], _R]], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + def collective_rpc( + self, + method: Union[str, Callable[[WorkerBase], _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> list[_R]: return self.engine_core.collective_rpc(method, timeout, args, kwargs) def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: - return self.collective_rpc("apply_model", args=(func, )) + return self.collective_rpc("apply_model", args=(func,)) def __del__(self): - if dp_group := getattr(self, "dp_group", - None) and not self.external_launcher_dp: + if ( + dp_group := getattr(self, "dp_group", None) + and not self.external_launcher_dp + ): stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index 133122b6fcc0..ab0e44fce155 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -9,7 +9,9 @@ from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from vllm.transformers_utils.detokenizer_utils import ( - AnyTokenizer, 
convert_ids_list_to_tokens) + AnyTokenizer, + convert_ids_list_to_tokens, +) from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest from vllm.v1.outputs import LogprobsLists, LogprobsTensors @@ -20,7 +22,6 @@ @dataclass class LogprobsProcessor: - # Tokenizer for this request, # None if detokenization is disabled. tokenizer: Optional[AnyTokenizer] @@ -43,7 +44,7 @@ def from_new_request( num_prompt_logprobs = request.sampling_params.prompt_logprobs return cls( tokenizer=tokenizer, - cumulative_logprob=(None if num_logprobs is None else 0.), + cumulative_logprob=(None if num_logprobs is None else 0.0), logprobs=(None if num_logprobs is None else []), # NOTE: logprob of first prompt token is None. prompt_logprobs=(None if num_prompt_logprobs is None else [None]), @@ -68,12 +69,13 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists - for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, - token_ids_lst): - + for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): # Detokenize (non-incrementally). - decoded_tokens = NONES if self.tokenizer is None else ( - convert_ids_list_to_tokens(self.tokenizer, token_ids)) + decoded_tokens = ( + NONES + if self.tokenizer is None + else (convert_ids_list_to_tokens(self.tokenizer, token_ids)) + ) # Sampler puts the sampled logprob in first. sampled_token_logprob = logprobs[0] @@ -87,7 +89,8 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None: decoded_tokens, rank, self.num_logprobs, - )) + ) + ) def _update_prompt_logprobs( self, @@ -109,9 +112,13 @@ def _update_prompt_logprobs( # Detokenize non-incrementally. # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] - decoded_tokens = None if self.tokenizer is None else ( - convert_ids_list_to_tokens(self.tokenizer, - token_ids.flatten().tolist())) + decoded_tokens = ( + None + if self.tokenizer is None + else ( + convert_ids_list_to_tokens(self.tokenizer, token_ids.flatten().tolist()) + ) + ) # Recover shapes. num_prompt_tokens, num_logprobs = logprobs.shape @@ -126,15 +133,20 @@ def _update_prompt_logprobs( # Handle flattening. offset = pos * num_logprobs offset_end = offset + num_logprobs - decoded_tokens_for_pos = NONES \ - if decoded_tokens is None else decoded_tokens[offset:offset_end] + decoded_tokens_for_pos = ( + NONES if decoded_tokens is None else decoded_tokens[offset:offset_end] + ) # Update with the Logprob dictionary for this pos. self.prompt_logprobs.append( - self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos], - decoded_tokens_for_pos, - prompt_token_ranks[pos], - self.num_prompt_logprobs)) + self._make_logprob_dict( + prompt_logprobs[pos], + token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + self.num_prompt_logprobs, + ) + ) def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]: """Pop and return all request prompt logprobs @@ -182,7 +194,7 @@ def _make_logprob_dict( # being in the topk, since inserting duplicated data # into a dictionary twice is the same as doing it once. 
topk_ranks = range(1, num_logprobs + 1) - ranks = itertools.chain((rank, ), topk_ranks) + ranks = itertools.chain((rank,), topk_ranks) return { token_id: Logprob( @@ -191,7 +203,8 @@ def _make_logprob_dict( decoded_token=token, ) for token_id, logprob, rank, token in zip( - logprob_token_ids, logprobs, ranks, decoded_tokens) + logprob_token_ids, logprobs, ranks, decoded_tokens + ) } def update_from_output(self, output: EngineCoreOutput) -> None: diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 46cb97d4e7b5..eb65b68969e3 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -8,19 +8,21 @@ import torch -from vllm.outputs import (CompletionOutput, PoolingOutput, - PoolingRequestOutput, RequestOutput) +from vllm.outputs import ( + CompletionOutput, + PoolingOutput, + PoolingRequestOutput, + RequestOutput, +) from vllm.sampling_params import RequestOutputKind -from vllm.tracing import (SpanAttributes, SpanKind, Tracer, - extract_trace_context) +from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.parallel_sampling import ParentRequest -from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, - RequestStateStats) +from vllm.v1.metrics.stats import IterationStats, LoRARequestStates, RequestStateStats class RequestOutputCollector: @@ -34,12 +36,14 @@ class RequestOutputCollector: def __init__(self, output_kind: RequestOutputKind): self.aggregate = output_kind == RequestOutputKind.DELTA - self.output: Optional[Union[RequestOutput, PoolingRequestOutput, - Exception]] = None + self.output: Optional[Union[RequestOutput, PoolingRequestOutput, Exception]] = ( + None + ) self.ready = asyncio.Event() - def put(self, output: Union[RequestOutput, PoolingRequestOutput, - Exception]) -> None: + def put( + self, output: Union[RequestOutput, PoolingRequestOutput, Exception] + ) -> None: """Non-blocking put operation.""" if self.output is None or isinstance(output, Exception): self.output = output @@ -59,8 +63,7 @@ async def get(self) -> Union[RequestOutput, PoolingRequestOutput]: raise output return output - def get_nowait( - self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: + def get_nowait(self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: """Non-blocking get operation.""" output = self.output if output is not None: @@ -78,7 +81,6 @@ class OutputProcessorOutput: class RequestState: - def __init__( self, request_id: str, @@ -108,7 +110,8 @@ def __init__( self.prompt_token_ids = prompt_token_ids self.prompt_embeds = prompt_embeds self.prompt_len = length_from_prompt_token_ids_or_embeds( - self.prompt_token_ids, self.prompt_embeds) + self.prompt_token_ids, self.prompt_embeds + ) self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer self.max_tokens_param = max_tokens_param @@ -119,8 +122,7 @@ def __init__( self.queue = queue self.num_cached_tokens = 0 - self.stats = RequestStateStats( - arrival_time=arrival_time) if log_stats else None + self.stats = RequestStateStats(arrival_time=arrival_time) if log_stats else None @classmethod def from_new_request( @@ -133,7 +135,6 @@ def from_new_request( queue: 
Optional[RequestOutputCollector], log_stats: bool, ) -> "RequestState": - if sampling_params := request.sampling_params: if not sampling_params.detokenize: tokenizer = None @@ -164,8 +165,9 @@ def from_new_request( request_id=request.request_id, parent_req=parent_req, request_index=request_index, - lora_name=(request.lora_request.name - if request.lora_request is not None else None), + lora_name=( + request.lora_request.name if request.lora_request is not None else None + ), output_kind=output_kind, prompt=prompt, prompt_token_ids=request.prompt_token_ids, @@ -189,7 +191,6 @@ def make_request_output( stop_reason: Union[int, str, None], kv_transfer_params: Optional[dict[str, Any]] = None, ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: - finished = finish_reason is not None final_only = self.output_kind == RequestOutputKind.FINAL_ONLY @@ -200,22 +201,23 @@ def make_request_output( request_id = self.request_id if pooling_output is not None: return self._new_request_output( - request_id, [self._new_pooling_output(pooling_output)], - finished) + request_id, [self._new_pooling_output(pooling_output)], finished + ) - output = self._new_completion_output(new_token_ids, finish_reason, - stop_reason) + output = self._new_completion_output(new_token_ids, finish_reason, stop_reason) if self.parent_req is None: outputs = [output] else: request_id, outputs, finished = self.parent_req.get_outputs( - request_id, output) + request_id, output + ) if not outputs: return None - return self._new_request_output(request_id, outputs, finished, - kv_transfer_params) + return self._new_request_output( + request_id, outputs, finished, kv_transfer_params + ) def _new_request_output( self, @@ -224,7 +226,6 @@ def _new_request_output( finished: bool, kv_transfer_params: Optional[dict[str, Any]] = None, ) -> Union[RequestOutput, PoolingRequestOutput]: - first_output = outputs[0] if isinstance(first_output, PoolingOutput): assert len(outputs) == 1 @@ -248,15 +249,17 @@ def _new_request_output( if prompt_token_ids is None and self.prompt_embeds is not None: prompt_token_ids = [0] * len(self.prompt_embeds) - return RequestOutput(request_id=request_id, - prompt=self.prompt, - prompt_token_ids=prompt_token_ids, - prompt_logprobs=prompt_logprobs, - outputs=cast(list[CompletionOutput], outputs), - finished=finished, - kv_transfer_params=kv_transfer_params, - num_cached_tokens=self.num_cached_tokens, - metrics=self.stats) + return RequestOutput( + request_id=request_id, + prompt=self.prompt, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=prompt_logprobs, + outputs=cast(list[CompletionOutput], outputs), + finished=finished, + kv_transfer_params=kv_transfer_params, + num_cached_tokens=self.num_cached_tokens, + metrics=self.stats, + ) def _new_completion_output( self, @@ -264,7 +267,6 @@ def _new_completion_output( finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], ) -> CompletionOutput: - assert self.detokenizer is not None assert self.logprobs_processor is not None finished = finish_reason is not None @@ -278,7 +280,7 @@ def _new_completion_output( # Prepare logprobs, based on delta mode logprobs = self.logprobs_processor.logprobs if delta and logprobs: - logprobs = logprobs[-len(token_ids):] + logprobs = logprobs[-len(token_ids) :] return CompletionOutput( index=self.request_index, @@ -287,13 +289,13 @@ def _new_completion_output( logprobs=logprobs, cumulative_logprob=self.logprobs_processor.cumulative_logprob, finish_reason=str(finish_reason) if finished else None, - 
stop_reason=stop_reason if finished else None) + stop_reason=stop_reason if finished else None, + ) def _new_pooling_output( self, pooling_output: torch.Tensor, ) -> PoolingOutput: - return PoolingOutput(data=pooling_output) @@ -333,15 +335,18 @@ def abort_requests( request_ids_to_abort.append(request_id) # Produce final abort output. if req_state.queue is not None and ( - request_output := req_state.make_request_output( - new_token_ids=[], - # Set pooling_output is not None to - # correctly enter the abort pooling branch - pooling_output=torch.randn(0, device="cpu") - if req_state.detokenizer is None else None, - finish_reason=FinishReason.ABORT, - stop_reason=None, - kv_transfer_params=None)): + request_output := req_state.make_request_output( + new_token_ids=[], + # Set pooling_output is not None to + # correctly enter the abort pooling branch + pooling_output=torch.randn(0, device="cpu") + if req_state.detokenizer is None + else None, + finish_reason=FinishReason.ABORT, + stop_reason=None, + kv_transfer_params=None, + ) + ): req_state.queue.put(request_output) elif parent := self.parent_requests.get(request_id): # Abort children prior to removing the parent. @@ -364,13 +369,15 @@ def add_request( if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") - req_state = RequestState.from_new_request(tokenizer=self.tokenizer, - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + req_state = RequestState.from_new_request( + tokenizer=self.tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats, + ) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: @@ -404,8 +411,7 @@ def process_outputs( within the loop below. """ - request_outputs: Union[list[RequestOutput], - list[PoolingRequestOutput]] = [] + request_outputs: Union[list[RequestOutput], list[PoolingRequestOutput]] = [] reqs_to_abort: list[str] = [] for engine_core_output in engine_core_outputs: req_id = engine_core_output.request_id @@ -415,9 +421,9 @@ def process_outputs( continue # 1) Compute stats for this iteration. - self._update_stats_from_output(req_state, engine_core_output, - engine_core_timestamp, - iteration_stats) + self._update_stats_from_output( + req_state, engine_core_output, engine_core_timestamp, iteration_stats + ) new_token_ids = engine_core_output.new_token_ids pooling_output = engine_core_output.pooling_output @@ -432,20 +438,24 @@ def process_outputs( assert req_state.logprobs_processor is not None # 2) Detokenize the token ids into text and perform stop checks. stop_string = req_state.detokenizer.update( - new_token_ids, finish_reason == FinishReason.STOP) + new_token_ids, finish_reason == FinishReason.STOP + ) if stop_string: finish_reason = FinishReason.STOP stop_reason = stop_string # 3) Compute sample and prompt logprobs for request, # if required. - req_state.logprobs_processor.update_from_output( - engine_core_output) + req_state.logprobs_processor.update_from_output(engine_core_output) # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, pooling_output, finish_reason, stop_reason, - kv_transfer_params): + new_token_ids, + pooling_output, + finish_reason, + stop_reason, + kv_transfer_params, + ): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). 
req_state.queue.put(request_output) @@ -466,11 +476,11 @@ def process_outputs( reqs_to_abort.append(req_id) # Track per-request stats - self._update_stats_from_finished(req_state, finish_reason, - iteration_stats) + self._update_stats_from_finished( + req_state, finish_reason, iteration_stats + ) if self.tracer: - self.do_tracing(engine_core_output, req_state, - iteration_stats) + self.do_tracing(engine_core_output, req_state, iteration_stats) self.lora_states.update_iteration_stats(iteration_stats) return OutputProcessorOutput( @@ -478,9 +488,12 @@ def process_outputs( reqs_to_abort=reqs_to_abort, ) - def do_tracing(self, engine_core_output: EngineCoreOutput, - req_state: RequestState, - iteration_stats: Optional[IterationStats]) -> None: + def do_tracing( + self, + engine_core_output: EngineCoreOutput, + req_state: RequestState, + iteration_stats: Optional[IterationStats], + ) -> None: assert req_state.stats is not None assert iteration_stats is not None assert self.tracer is not None @@ -488,59 +501,63 @@ def do_tracing(self, engine_core_output: EngineCoreOutput, arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9) trace_context = extract_trace_context(engine_core_output.trace_headers) prompt_length = length_from_prompt_token_ids_or_embeds( - req_state.prompt_token_ids, req_state.prompt_embeds) - with (self.tracer.start_as_current_span( - "llm_request", - kind=SpanKind.SERVER, - context=trace_context, - start_time=arrival_time_nano_seconds) as span): + req_state.prompt_token_ids, req_state.prompt_embeds + ) + with self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds, + ) as span: metrics = req_state.stats - e2e_time = iteration_stats.iteration_timestamp - \ - metrics.arrival_time + e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time queued_time = metrics.scheduled_ts - metrics.queued_ts prefill_time = metrics.first_token_ts - metrics.scheduled_ts decode_time = metrics.last_token_ts - metrics.first_token_ts inference_time = metrics.last_token_ts - metrics.scheduled_ts span.set_attribute( SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, - metrics.first_token_latency) + metrics.first_token_latency, + ) span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time) - span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, - queued_time) - span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, - prompt_length) - span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, - metrics.num_generation_tokens) + span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time) + span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, prompt_length) span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, - prefill_time) + SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, + metrics.num_generation_tokens, + ) span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, - decode_time) + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL, prefill_time + ) span.set_attribute( - SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, - inference_time) + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE, decode_time + ) + span.set_attribute( + SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE, inference_time + ) # meta - span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, - req_state.request_id) + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID, req_state.request_id) if req_state.top_p: - 
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, - req_state.top_p) + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P, req_state.top_p) if req_state.max_tokens_param: - span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, - req_state.max_tokens_param) + span.set_attribute( + SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS, req_state.max_tokens_param + ) if req_state.temperature: - span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, - req_state.temperature) + span.set_attribute( + SpanAttributes.GEN_AI_REQUEST_TEMPERATURE, req_state.temperature + ) if req_state.n: - span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, - req_state.n) + span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N, req_state.n) - def _update_stats_from_output(self, req_state: RequestState, - engine_core_output: EngineCoreOutput, - engine_core_timestamp: Optional[float], - iteration_stats: Optional[IterationStats]): + def _update_stats_from_output( + self, + req_state: RequestState, + engine_core_output: EngineCoreOutput, + engine_core_timestamp: Optional[float], + iteration_stats: Optional[IterationStats], + ): if iteration_stats is None: return @@ -548,15 +565,21 @@ def _update_stats_from_output(self, req_state: RequestState, assert engine_core_timestamp is not None assert req_state.stats is not None - iteration_stats.update_from_output(engine_core_output, - engine_core_timestamp, - req_state.is_prefilling, - req_state.prompt_len, - req_state.stats, lora_stats) - - def _update_stats_from_finished(self, req_state: RequestState, - finish_reason: Optional[FinishReason], - iteration_stats: Optional[IterationStats]): + iteration_stats.update_from_output( + engine_core_output, + engine_core_timestamp, + req_state.is_prefilling, + req_state.prompt_len, + req_state.stats, + lora_stats, + ) + + def _update_stats_from_finished( + self, + req_state: RequestState, + finish_reason: Optional[FinishReason], + iteration_stats: Optional[IterationStats], + ): if iteration_stats is None: return @@ -565,11 +588,13 @@ def _update_stats_from_finished(self, req_state: RequestState, iteration_stats.update_from_finished_request( finish_reason=finish_reason, num_prompt_tokens=length_from_prompt_token_ids_or_embeds( - req_state.prompt_token_ids, req_state.prompt_embeds), + req_state.prompt_token_ids, req_state.prompt_embeds + ), max_tokens_param=req_state.max_tokens_param, - req_stats=req_state.stats) + req_stats=req_state.stats, + ) self.lora_states.finish_request(req_state) ParentRequest.observe_finished_request( - req_state.parent_req, iteration_stats, - req_state.stats.num_generation_tokens) + req_state.parent_req, iteration_stats, req_state.stats.num_generation_tokens + ) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 1e9911152c6d..daf115c0325f 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -31,15 +31,16 @@ class ParentRequest: # To efficiently obtain child sampling params cached_child_sampling_params: Optional[SamplingParams] - def __init__(self, request_id: str, - sampling_params: SamplingParams) -> None: + def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: self.request_id = request_id self.sampling_params = sampling_params self.child_requests = set() - self.output_aggregator = [None] * sampling_params.n if ( - sampling_params.output_kind - == RequestOutputKind.FINAL_ONLY) else [] + self.output_aggregator = ( + [None] * sampling_params.n + if (sampling_params.output_kind == 
RequestOutputKind.FINAL_ONLY) + else [] + ) self.max_num_generation_tokens = 0 self.cached_child_sampling_params = None @@ -49,7 +50,7 @@ def _get_child_sampling_params( ) -> SamplingParams: """Efficiently obtain child `sampling_params` - If `sampling_params.seed` is not `None` then + If `sampling_params.seed` is not `None` then each child request requires a unique clone of parent `sampling_params` with a unique seed. @@ -76,10 +77,10 @@ def _get_child_sampling_params( def get_child_info(self, index: int) -> tuple[str, SamplingParams]: """Get child request ID and sampling params. - + Args: index: index within `n` child requests. - + Returns: (request ID, sampling_params) tuple """ @@ -111,23 +112,25 @@ def get_outputs( return self.request_id, outputs, finished def observe_num_generation_tokens(self, num_generation_tokens: int): - self.max_num_generation_tokens = max(num_generation_tokens, - self.max_num_generation_tokens) + self.max_num_generation_tokens = max( + num_generation_tokens, self.max_num_generation_tokens + ) return self.max_num_generation_tokens @staticmethod - def observe_finished_request(parent_req: Optional['ParentRequest'], - iteration_stats: IterationStats, - num_generation_tokens: int): - + def observe_finished_request( + parent_req: Optional["ParentRequest"], + iteration_stats: IterationStats, + num_generation_tokens: int, + ): n_param = parent_req.n if parent_req is not None else 1 if parent_req is not None: num_generation_tokens = parent_req.observe_num_generation_tokens( - num_generation_tokens) + num_generation_tokens + ) # Child requests finished, we can now record to iteration stats if parent_req is None or not parent_req.child_requests: - iteration_stats.max_num_generation_tokens_iter.append( - num_generation_tokens) + iteration_stats.max_num_generation_tokens_iter.append(num_generation_tokens) iteration_stats.n_params_iter.append(n_param) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index c30ceb96a5e0..f39e9c1eea7d 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -21,48 +21,49 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest -from vllm.v1.structured_output.backend_guidance import ( - validate_guidance_grammar) +from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar from vllm.v1.structured_output.backend_lm_format_enforcer import ( - validate_structured_output_request_lm_format_enforcer) + validate_structured_output_request_lm_format_enforcer, +) from vllm.v1.structured_output.backend_outlines import ( - validate_structured_output_request_outlines) -from vllm.v1.structured_output.backend_xgrammar import ( - validate_xgrammar_grammar) + validate_structured_output_request_outlines, +) +from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar logger = init_logger(__name__) class Processor: - def __init__( self, vllm_config: VllmConfig, - tokenizer: AnyTokenizer, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, - ): - + ) -> None: self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config self.structured_outputs_config = vllm_config.structured_outputs_config - self.tokenizer = tokenizer - self.generation_config_fields = ( - self.model_config.try_get_generation_config()) + self.generation_config_fields = 
self.model_config.try_get_generation_config() self.mm_registry = mm_registry - self.mm_processor_cache = processor_cache_from_config( - vllm_config, mm_registry) + self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry) self.input_preprocessor = InputPreprocessor( self.model_config, - self.tokenizer, mm_registry, mm_processor_cache=self.mm_processor_cache, ) + @property + def tokenizer(self) -> Optional[AnyTokenizer]: + return self.input_preprocessor.tokenizer + + @tokenizer.setter + def tokenizer(self, tokenizer: Optional[AnyTokenizer]) -> None: + self.input_preprocessor.tokenizer = tokenizer + def _validate_logprobs( self, params: SamplingParams, @@ -79,7 +80,8 @@ def _validate_logprobs( if num_logprobs > max_logprobs: raise ValueError( f"Requested sample logprobs of {num_logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + f"which is greater than max allowed: {max_logprobs}" + ) # Validate prompt logprobs. if params.prompt_logprobs: @@ -89,7 +91,8 @@ def _validate_logprobs( if num_prompt_logprobs > max_logprobs: raise ValueError( f"Requested prompt logprobs of {num_prompt_logprobs}, " - f"which is greater than max allowed: {max_logprobs}") + f"which is greater than max allowed: {max_logprobs}" + ) def _validate_sampling_params( self, @@ -108,8 +111,7 @@ def _validate_sampling_params( return vocab_size = len(self.tokenizer) if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids): - raise ValueError( - "allowed_token_ids contains out-of-vocab token id!") + raise ValueError("allowed_token_ids contains out-of-vocab token id!") def _validate_logit_bias( self, @@ -129,7 +131,8 @@ def _validate_logit_bias( if invalid_token_ids: raise ValueError( f"token_id(s) {invalid_token_ids} in logit_bias contain " - f"out-of-vocab token ids. Vocabulary size: {vocab_size}") + f"out-of-vocab token ids. Vocabulary size: {vocab_size}" + ) def _validate_supported_sampling_params( self, @@ -140,8 +143,9 @@ def _validate_supported_sampling_params( raise ValueError("vLLM V1 does not yet support best_of.") # Logits processors not supported. if params.logits_processors: - raise ValueError("vLLM V1 does not support per request " - "user provided logits processors.") + raise ValueError( + "vLLM V1 does not support per request user provided logits processors." + ) def _validate_params( self, @@ -178,18 +182,23 @@ def _validate_single_prompt(single_prompt: Union[dict, str]) -> None: for modality, items in mm_data.items(): if modality in mm_uuids: data_len = len(items) if isinstance(items, list) else 1 - uuid_len = len(mm_uuids[modality]) if isinstance( - mm_uuids[modality], list) else 1 + uuid_len = ( + len(mm_uuids[modality]) + if isinstance(mm_uuids[modality], list) + else 1 + ) if uuid_len != data_len: raise ValueError( f"multi_modal_uuids for modality '{modality}' " "must have same length as data: got " f"{uuid_len} uuids vs " - f"{data_len} items.") + f"{data_len} items." + ) else: raise ValueError( f"multi_modal_uuids for modality '{modality}' must " - "be provided if multi_modal_data is provided.") + "be provided if multi_modal_data is provided." 
+ ) # Handle explicit encoder/decoder prompts or singleton prompt if isinstance(prompt, dict) and "encoder_prompt" in prompt: @@ -208,8 +217,9 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: # LoRA request passed in while LoRA is not enabled if not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") + raise ValueError( + f"Got lora_request {lora_request} but LoRA is not enabled!" + ) if self.tokenizer is not None: logger.warning_once( @@ -217,7 +227,8 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: "tokenizers for different LoRAs. By default, vLLM uses base " "model's tokenizer. If you are using a LoRA " "with its own tokenizer, consider specifying `--tokenizer " - "[lora_path]` to use the LoRA tokenizer.") + "[lora_path]` to use the LoRA tokenizer." + ) def _validate_structured_output(self, params: SamplingParams) -> None: if not params.structured_outputs or not self.structured_outputs_config: @@ -235,20 +246,23 @@ def _validate_structured_output(self, params: SamplingParams) -> None: # to a specific backend based on `auto` behavior in a previous # request. We remember that it was set as a result of `auto` # using the `_backend_was_auto` field set in the params. - if (backend != _backend - and not (backend == "auto" - and params.structured_outputs._backend_was_auto)): + if backend != _backend and not ( + backend == "auto" and params.structured_outputs._backend_was_auto + ): raise ValueError( "Request-level structured output backend selection is not " f"supported. The request specified '{_backend}', but vLLM " f"was initialised with '{backend}'. This error can be " - "resolved by removing '_backend' from the request.") + "resolved by removing '_backend' from the request." + ) else: params.structured_outputs._backend = backend # Request content validation - if (isinstance(params.structured_outputs.choice, list) - and not params.structured_outputs.choice): + if ( + isinstance(params.structured_outputs.choice, list) + and not params.structured_outputs.choice + ): # It is invalid for choice to be an empty list raise ValueError( f"Choice '{params.structured_outputs.choice}' cannot be an empty list" # noqa: E501 @@ -318,9 +332,7 @@ def _extract_mm_data(p: PromptType): mm_uuids: MultiModalUUIDDict = {} for modality, data in mm_data.items(): n = len(data) if isinstance(data, list) else 1 - mm_uuids[modality] = [ - f"{request_id}-{modality}-{i}" for i in range(n) - ] + mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] return mm_uuids def process_inputs( @@ -339,10 +351,13 @@ def process_inputs( self._validate_params(params) data_parallel_size = self.vllm_config.parallel_config.data_parallel_size - if data_parallel_rank is not None and not (0 <= data_parallel_rank < - data_parallel_size): - raise ValueError(f"data_parallel_rank {data_parallel_rank} " - f"is out of range [0, {data_parallel_size}).") + if data_parallel_rank is not None and not ( + 0 <= data_parallel_rank < data_parallel_size + ): + raise ValueError( + f"data_parallel_rank {data_parallel_rank} " + f"is out of range [0, {data_parallel_size})." + ) if arrival_time is None: arrival_time = time.time() @@ -355,9 +370,11 @@ def process_inputs( # reused across requests, therefore identifying multimodal data items # by their content is no longer necessary, and we create uuids with # request id-modality-index as multimodal hash overrides. 
- if (self.model_config.multimodal_config and - self.model_config.multimodal_config.mm_processor_cache_gb == 0 - and not self.cache_config.enable_prefix_caching): + if ( + self.model_config.multimodal_config + and self.model_config.multimodal_config.mm_processor_cache_gb == 0 + and not self.cache_config.enable_prefix_caching + ): mm_uuids = self._maybe_build_mm_uuids(request_id, prompt) else: # Otherwise, use user-provided uuids as multimodal hash overrides @@ -378,6 +395,7 @@ def process_inputs( mm_uuids=mm_uuids, ) from vllm.platforms import current_platform + current_platform.validate_request( prompt=prompt, params=params, @@ -393,10 +411,16 @@ def process_inputs( # discriminated unions of TypedDicts, because of how it handles # inheritance of TypedDict. If we explicitly extract the items we want # we can avoid type errors from using `dict.get` later in the method. - prompt_token_ids = decoder_inputs[ - "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None - prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[ - "type"] == "embeds" else None + prompt_token_ids = ( + decoder_inputs["prompt_token_ids"] + if decoder_inputs["type"] != "embeds" + else None + ) + prompt_embeds = ( + decoder_inputs["prompt_embeds"] + if decoder_inputs["type"] == "embeds" + else None + ) sampling_params = None pooling_params = None @@ -406,11 +430,12 @@ def process_inputs( # If unset max tokens, then generate up to the max_model_len. if sampling_params.max_tokens is None: seq_len = length_from_prompt_token_ids_or_embeds( - prompt_token_ids, prompt_embeds) - sampling_params.max_tokens = \ - self.model_config.max_model_len - seq_len + prompt_token_ids, prompt_embeds + ) + sampling_params.max_tokens = self.model_config.max_model_len - seq_len sampling_params.update_from_generation_config( - self.generation_config_fields, eos_token_id) + self.generation_config_fields, eos_token_id + ) if self.tokenizer is not None: sampling_params.update_from_tokenizer(self.tokenizer) else: @@ -436,7 +461,9 @@ def process_inputs( data=decoder_mm_inputs[modality][idx], modality=modality, identifier=decoder_mm_hashes[modality][idx], - mm_position=decoder_mm_positions[modality][idx])) + mm_position=decoder_mm_positions[modality][idx], + ) + ) return EngineCoreRequest( request_id=request_id, @@ -454,8 +481,9 @@ def process_inputs( trace_headers=trace_headers, ) - def _validate_model_inputs(self, encoder_inputs: Optional[SingletonInputs], - decoder_inputs: SingletonInputs): + def _validate_model_inputs( + self, encoder_inputs: Optional[SingletonInputs], decoder_inputs: SingletonInputs + ): if encoder_inputs is not None: self._validate_model_input(encoder_inputs, prompt_type="encoder") @@ -469,12 +497,17 @@ def _validate_model_input( ): model_config = self.model_config - prompt_ids = None if prompt_inputs[ - "type"] == "embeds" else prompt_inputs["prompt_token_ids"] - prompt_embeds = prompt_inputs["prompt_embeds"] if prompt_inputs[ - "type"] == "embeds" else None - prompt_len = length_from_prompt_token_ids_or_embeds( - prompt_ids, prompt_embeds) + prompt_ids = ( + None + if prompt_inputs["type"] == "embeds" + else prompt_inputs["prompt_token_ids"] + ) + prompt_embeds = ( + prompt_inputs["prompt_embeds"] + if prompt_inputs["type"] == "embeds" + else None + ) + prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds) if not prompt_ids: if prompt_type == "encoder" and model_config.is_multimodal_model: pass # Mllama may have empty encoder inputs for text-only data @@ -483,10 +516,8 @@ def 
_validate_model_input( else: raise ValueError(f"The {prompt_type} prompt cannot be empty") - if self.model_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = self.tokenizer + tokenizer = self.tokenizer + if tokenizer is not None: max_input_id = max(prompt_ids or [], default=0) # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while @@ -499,10 +530,10 @@ def _validate_model_input( # Here we take the max of the two to determine if a token id is # truly out-of-vocabulary. - if max_input_id > max(tokenizer.max_token_id, - self.model_config.get_vocab_size() - 1): - raise ValueError( - f"Token id {max_input_id} is out of vocabulary") + if max_input_id > max( + tokenizer.max_token_id, self.model_config.get_vocab_size() - 1 + ): + raise ValueError(f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len if prompt_len > max_prompt_len: @@ -522,16 +553,19 @@ def _validate_model_input( "Make sure that `max_model_len` is no smaller than the " "number of text tokens plus multimodal tokens. For image " "inputs, the number of image tokens depends on the number " - "of images, and possibly their aspect ratios as well.") + "of images, and possibly their aspect ratios as well." + ) else: suggestion = ( "Make sure that `max_model_len` is no smaller than the " - "number of text tokens.") + "number of text tokens." + ) raise ValueError( f"The {prompt_type} prompt (length {prompt_len}) is " f"longer than the maximum model length of {max_prompt_len}. " - f"{suggestion}") + f"{suggestion}" + ) # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 18ef25ceb6f5..5f23cf80d5df 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -70,6 +70,7 @@ class EngineHandshakeMetadata: including addresses of the front-end ZMQ queues that they should connect to. """ + addresses: EngineZmqAddresses parallel_config: dict[str, Union[int, str, list[int]]] @@ -103,8 +104,7 @@ def __init__( } if client_handshake_address: - common_kwargs[ - "client_handshake_address"] = client_handshake_address + common_kwargs["client_handshake_address"] = client_handshake_address self.processes: list[BaseProcess] = [] local_dp_ranks = [] @@ -115,21 +115,27 @@ def __init__( # Start EngineCore in background process. local_dp_ranks.append(local_index) self.processes.append( - context.Process(target=target_fn, - name=f"EngineCore_DP{global_index}", - kwargs=common_kwargs | { - "dp_rank": global_index, - "local_dp_rank": local_index, - })) + context.Process( + target=target_fn, + name=f"EngineCore_DP{global_index}", + kwargs=common_kwargs + | { + "dp_rank": global_index, + "local_dp_rank": local_index, + }, + ) + ) self._finalizer = weakref.finalize(self, shutdown, self.processes) data_parallel = vllm_config.parallel_config.data_parallel_size > 1 try: for proc, local_dp_rank in zip(self.processes, local_dp_ranks): - with set_device_control_env_var( - vllm_config, local_dp_rank) if ( - data_parallel) else contextlib.nullcontext(): + with ( + set_device_control_env_var(vllm_config, local_dp_rank) + if (data_parallel) + else contextlib.nullcontext() + ): proc.start() finally: # Kill other procs if not all are running. 
@@ -151,13 +157,15 @@ def finished_procs(self) -> dict[str, int]: """Returns dict of proc name -> exit code for any finished procs.""" return { proc.name: proc.exitcode - for proc in self.processes if proc.exitcode is not None + for proc in self.processes + if proc.exitcode is not None } @contextlib.contextmanager -def set_device_control_env_var(vllm_config: VllmConfig, - local_dp_rank: int) -> Iterator[None]: +def set_device_control_env_var( + vllm_config: VllmConfig, local_dp_rank: int +) -> Iterator[None]: """ Temporarily set CUDA_VISIBLE_DEVICES or equivalent for engine subprocess. @@ -166,12 +174,13 @@ def set_device_control_env_var(vllm_config: VllmConfig, evar = current_platform.device_control_env_var value = get_device_indices(evar, local_dp_rank, world_size) - with patch.dict(os.environ, values=((evar, value), )): + with patch.dict(os.environ, values=((evar, value),)): yield -def get_device_indices(device_control_env_var: str, local_dp_rank: int, - world_size: int): +def get_device_indices( + device_control_env_var: str, local_dp_rank: int, world_size: int +): """ Returns a comma-separated string of device indices for the specified data parallel rank. @@ -182,14 +191,16 @@ def get_device_indices(device_control_env_var: str, local_dp_rank: int, try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * - world_size)) + for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size) + ) except IndexError as e: - raise Exception(f"Error setting {device_control_env_var}: " - f"local range: [{local_dp_rank * world_size}, " - f"{(local_dp_rank + 1) * world_size}) " - "base value: " - f"\"{os.getenv(device_control_env_var)}\"") from e + raise Exception( + f"Error setting {device_control_env_var}: " + f"local range: [{local_dp_rank * world_size}, " + f"{(local_dp_rank + 1) * world_size}) " + "base value: " + f'"{os.getenv(device_control_env_var)}"' + ) from e return value @@ -215,8 +226,7 @@ def __init__( import ray from ray.runtime_env import RuntimeEnv - from ray.util.scheduling_strategies import ( - PlacementGroupSchedulingStrategy) + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm.v1.engine.core import DPEngineCoreActor @@ -225,8 +235,7 @@ def __init__( env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor") self.env_vars_dict = { - name: os.environ[name] - for name in env_vars_list if name in os.environ + name: os.environ[name] for name in env_vars_list if name in os.environ } runtime_env = RuntimeEnv(env_vars=self.env_vars_dict) @@ -234,37 +243,38 @@ def __init__( self.executor_class = executor_class self.log_stats = log_stats dp_size = vllm_config.parallel_config.data_parallel_size - local_engine_count = \ - vllm_config.parallel_config.data_parallel_size_local + local_engine_count = vllm_config.parallel_config.data_parallel_size_local world_size = vllm_config.parallel_config.world_size if ray.is_initialized(): - logger.info( - "Ray is already initialized. Skipping Ray initialization.") + logger.info("Ray is already initialized. 
Skipping Ray initialization.") else: ray.init() if placement_groups is not None: assert local_dp_ranks is not None, ( - "local_dp_ranks must be provided if " - "placement_groups is provided") + "local_dp_ranks must be provided if placement_groups is provided" + ) assert len(placement_groups) == len(local_dp_ranks), ( - "placement_groups and local_dp_ranks must " - "have the same length") + "placement_groups and local_dp_ranks must have the same length" + ) logger.info("Using provided placement groups") # TODO(rui): validate passed-in placement groups self.created_placement_groups = [] else: - placement_groups, local_dp_ranks = \ + placement_groups, local_dp_ranks = ( CoreEngineActorManager.create_dp_placement_groups(vllm_config) + ) self.created_placement_groups = placement_groups assert len(placement_groups) == dp_size, ( - "Number of placement groups must match data parallel size") + "Number of placement groups must match data parallel size" + ) self.placement_group_is_local = [] refs = [] - for index, local_index, pg in zip(range(dp_size), local_dp_ranks, - placement_groups): + for index, local_index, pg in zip( + range(dp_size), local_dp_ranks, placement_groups + ): dp_vllm_config = copy.deepcopy(vllm_config) dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count @@ -275,24 +285,32 @@ def __init__( # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501 if current_platform.is_xpu(): device_evar = current_platform.device_control_env_var - device_indices = get_device_indices(device_evar, local_index, - world_size) + device_indices = get_device_indices( + device_evar, local_index, world_size + ) actor_env_vars = self.env_vars_dict.copy() actor_env_vars[device_evar] = device_indices runtime_env = RuntimeEnv(env_vars=actor_env_vars) - actor = ray.remote(DPEngineCoreActor).options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=world_size, - ), - runtime_env=runtime_env).remote(vllm_config=dp_vllm_config, - executor_class=executor_class, - log_stats=log_stats, - local_client=local_client, - addresses=addresses, - dp_rank=index, - local_dp_rank=local_index) + actor = ( + ray.remote(DPEngineCoreActor) + .options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + ), + runtime_env=runtime_env, + ) + .remote( + vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + local_client=local_client, + addresses=addresses, + dp_rank=index, + local_dp_rank=local_index, + ) + ) if local_client: self.local_engine_actors.append(actor) else: @@ -307,7 +325,7 @@ def __init__( @staticmethod def create_dp_placement_groups( - vllm_config: VllmConfig + vllm_config: VllmConfig, ) -> tuple[list["PlacementGroup"], list[int]]: """ Create placement groups for data parallel. 
@@ -317,23 +335,23 @@ def create_dp_placement_groups( from ray._private.state import available_resources_per_node logger.info("Creating placement groups for data parallel") - dp_master_ip = \ - vllm_config.parallel_config.data_parallel_master_ip + dp_master_ip = vllm_config.parallel_config.data_parallel_master_ip num_pg_to_create = vllm_config.parallel_config.data_parallel_size - local_engine_count = \ - vllm_config.parallel_config.data_parallel_size_local + local_engine_count = vllm_config.parallel_config.data_parallel_size_local available_resources = available_resources_per_node() world_size = vllm_config.parallel_config.world_size placement_groups: list[PlacementGroup] = [] local_dp_ranks: list[int] = [] - dp_master_ip_key = f'node:{dp_master_ip}' - nodes = sorted(available_resources.values(), - key=lambda x: dp_master_ip_key not in x) - assert len(nodes) > 0, ( - "No nodes with resources found in Ray cluster.") + dp_master_ip_key = f"node:{dp_master_ip}" + nodes = sorted( + available_resources.values(), key=lambda x: dp_master_ip_key not in x + ) + assert len(nodes) > 0, "No nodes with resources found in Ray cluster." assert dp_master_ip_key in nodes[0], ( - "The DP master node (ip: %s) is missing or dead", dp_master_ip) + "The DP master node (ip: %s) is missing or dead", + dp_master_ip, + ) device_str = current_platform.ray_device_key for node_resources in nodes: if device_str not in node_resources: @@ -341,19 +359,16 @@ def create_dp_placement_groups( # For now, each DP rank can only be assigned to one node # TODO(rui): support allocating a single DP rank # to multiple nodes - available_engine_count = int( - node_resources[device_str]) // world_size + available_engine_count = int(node_resources[device_str]) // world_size if dp_master_ip_key in node_resources: assert available_engine_count >= local_engine_count, ( "Not enough resources to allocate DP ranks " - f"on DP master node {dp_master_ip}") + f"on DP master node {dp_master_ip}" + ) for i in range(local_engine_count): - bundles = [{ - device_str: 1.0, - "node:" + dp_master_ip: 0.001 - }] * world_size + [{ - "CPU": 1.0 - }] + bundles = [ + {device_str: 1.0, "node:" + dp_master_ip: 0.001} + ] * world_size + [{"CPU": 1.0}] pg = ray.util.placement_group( name=f"dp_rank_{len(placement_groups)}", strategy="STRICT_PACK", @@ -379,7 +394,8 @@ def create_dp_placement_groups( "placement groups, only created " f"{len(placement_groups)} placement groups. " "Available resources: " - f"{available_resources}") + f"{available_resources}" + ) return placement_groups, local_dp_ranks @staticmethod @@ -390,8 +406,10 @@ def add_dp_placement_groups( Add placement groups for new data parallel size. 
""" import ray - from ray._private.state import (available_resources_per_node, - total_resources_per_node) + from ray._private.state import ( + available_resources_per_node, + total_resources_per_node, + ) from ray.util.state import list_nodes old_dp_size = old_vllm_config.parallel_config.data_parallel_size @@ -405,10 +423,10 @@ def add_dp_placement_groups( nodes = list_nodes() nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip) - assert nodes[0].node_ip == dp_master_ip, ( - "The first node must be the head node") + assert nodes[0].node_ip == dp_master_ip, "The first node must be the head node" assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( - "There can only be one head node") + "There can only be one head node" + ) available_resources = available_resources_per_node() total_resources = total_resources_per_node() @@ -446,12 +464,9 @@ def add_dp_placement_groups( # Create bundles with node constraint for master node if node_ip == dp_master_ip: - bundles = [{ - device_str: 1.0, - "node:" + dp_master_ip: 0.001 - }] * world_size + [{ - "CPU": 1.0 - }] + bundles = [ + {device_str: 1.0, "node:" + dp_master_ip: 0.001} + ] * world_size + [{"CPU": 1.0}] else: bundles = [{device_str: 1.0}] * world_size + [{"CPU": 1.0}] @@ -470,69 +485,76 @@ def add_dp_placement_groups( return placement_groups, local_dp_ranks - def scale_up_elastic_ep(self, cur_vllm_config: VllmConfig, - new_data_parallel_size: int) -> None: + def scale_up_elastic_ep( + self, cur_vllm_config: VllmConfig, new_data_parallel_size: int + ) -> None: import copy import ray from ray.runtime_env import RuntimeEnv - from ray.util.scheduling_strategies import ( - PlacementGroupSchedulingStrategy) + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm.v1.engine.core import DPEngineCoreActor - cur_data_parallel_size = len(self.local_engine_actors) + \ - len(self.remote_engine_actors) + cur_data_parallel_size = len(self.local_engine_actors) + len( + self.remote_engine_actors + ) assert new_data_parallel_size > cur_data_parallel_size, ( f"New data parallel size {new_data_parallel_size} must be greater " f"than current data parallel size {cur_data_parallel_size} " - "for scale up") + "for scale up" + ) - placement_groups, local_dp_ranks = \ - self.add_dp_placement_groups( - cur_vllm_config, new_data_parallel_size) + placement_groups, local_dp_ranks = self.add_dp_placement_groups( + cur_vllm_config, new_data_parallel_size + ) world_size = cur_vllm_config.parallel_config.world_size dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip new_local_engines = 0 - runtime_env = RuntimeEnv(env_vars=self.env_vars_dict - | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"}) - for i, (pg, - local_rank) in enumerate(zip(placement_groups, - local_dp_ranks)): + runtime_env = RuntimeEnv( + env_vars=self.env_vars_dict | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"} + ) + for i, (pg, local_rank) in enumerate(zip(placement_groups, local_dp_ranks)): rank = cur_data_parallel_size + i dp_vllm_config = copy.deepcopy(cur_vllm_config) - dp_vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + dp_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size dp_vllm_config.parallel_config.placement_group = pg # Check if this placement group is on the head node local_client = any( - bundle.get("node:" + dp_master_ip, 0) > 0 - for bundle in pg.bundle_specs) + bundle.get("node:" + dp_master_ip, 0) > 0 for bundle in pg.bundle_specs + ) if local_client: new_local_engines += 1 # 
Update data_parallel_size_local dp_vllm_config.parallel_config.data_parallel_size_local = ( - cur_vllm_config.parallel_config.data_parallel_size_local + - new_local_engines) - - actor = ray.remote(DPEngineCoreActor).options( - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=world_size, - ), - runtime_env=runtime_env).remote( + cur_vllm_config.parallel_config.data_parallel_size_local + + new_local_engines + ) + + actor = ( + ray.remote(DPEngineCoreActor) + .options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + ), + runtime_env=runtime_env, + ) + .remote( vllm_config=dp_vllm_config, executor_class=self.executor_class, log_stats=self.log_stats, local_client=local_client, addresses=self.addresses, dp_rank=rank, - local_dp_rank=local_rank) + local_dp_rank=local_rank, + ) + ) if local_client: self.local_engine_actors.append(actor) @@ -541,37 +563,47 @@ def scale_up_elastic_ep(self, cur_vllm_config: VllmConfig, self.created_placement_groups.append(pg) self.placement_group_is_local.append(local_client) - ray.get([ - actor.wait_for_init.remote() - for actor in (self.local_engine_actors[-new_local_engines:] - if new_local_engines > 0 else []) + - self.remote_engine_actors[-(len(placement_groups) - - new_local_engines):] - ]) + ray.get( + [ + actor.wait_for_init.remote() + for actor in ( + self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 + else [] + ) + + self.remote_engine_actors[ + -(len(placement_groups) - new_local_engines) : + ] + ] + ) - actors = (self.local_engine_actors[-new_local_engines:] - if new_local_engines > 0 else []) + \ - self.remote_engine_actors[-(len(placement_groups) - - new_local_engines):] + actors = ( + self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 + else [] + ) + self.remote_engine_actors[-(len(placement_groups) - new_local_engines) :] for actor in actors: self.run_refs.append(actor.run.remote()) - cur_vllm_config.parallel_config.data_parallel_size = \ - new_data_parallel_size + cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # Update old_vllm_config with new data_parallel_size_local if any new # local engines were added if new_local_engines > 0: - cur_vllm_config.parallel_config.data_parallel_size_local += \ + cur_vllm_config.parallel_config.data_parallel_size_local += ( new_local_engines + ) - def scale_down_elastic_ep(self, cur_data_parallel_size: int, - new_data_parallel_size: int) -> None: + def scale_down_elastic_ep( + self, cur_data_parallel_size: int, new_data_parallel_size: int + ) -> None: import ray + assert cur_data_parallel_size > new_data_parallel_size, ( f"cur_data_parallel_size {cur_data_parallel_size} must be greater " f"than new_data_parallel_size {new_data_parallel_size} " - "for scale down") + "for scale down" + ) for _ in range(cur_data_parallel_size - new_data_parallel_size): pg = self.created_placement_groups.pop() is_local = self.placement_group_is_local.pop() @@ -586,6 +618,7 @@ def get_run_refs(self): def close(self): import ray + for actor in self.local_engine_actors + self.remote_engine_actors: ray.kill(actor) for pg in self.created_placement_groups: @@ -598,11 +631,13 @@ def launch_core_engines( executor_class: type[Executor], log_stats: bool, num_api_servers: int = 1, -) -> Iterator[tuple[ +) -> Iterator[ + tuple[ Optional[Union[CoreEngineProcManager, CoreEngineActorManager]], Optional[DPCoordinator], EngineZmqAddresses, -]]: + ] +]: 
"""Launch engine and DP coordinator processes as needed.""" parallel_config = vllm_config.parallel_config @@ -611,8 +646,10 @@ def launch_core_engines( local_start_index = parallel_config.data_parallel_rank_local dp_rank = parallel_config.data_parallel_rank host = parallel_config.data_parallel_master_ip - local_engines_only = (parallel_config.data_parallel_hybrid_lb - or parallel_config.data_parallel_external_lb) + local_engines_only = ( + parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb + ) # In offline mode there is an LLM instance per DP rank and # one core engine per LLM, see @@ -621,8 +658,9 @@ def launch_core_engines( # client_local_only = True for cases where this front-end # sends requests only to colocated engines. - client_local_only = (offline_mode or local_engines_only - or (local_engine_count == dp_size)) + client_local_only = ( + offline_mode or local_engines_only or (local_engine_count == dp_size) + ) # Set up input and output addresses. addresses = EngineZmqAddresses( @@ -644,12 +682,13 @@ def launch_core_engines( coordinator = DPCoordinator(parallel_config) addresses.coordinator_input, addresses.coordinator_output = ( - coordinator.get_engine_socket_addresses()) + coordinator.get_engine_socket_addresses() + ) addresses.frontend_stats_publish_address = ( - coordinator.get_stats_publish_address()) + coordinator.get_stats_publish_address() + ) - logger.info("Started DP Coordinator process (PID: %d)", - coordinator.proc.pid) + logger.info("Started DP Coordinator process (PID: %d)", coordinator.proc.pid) else: coordinator = None @@ -675,14 +714,14 @@ def launch_core_engines( # Note this also covers the case where we have zero local engines # and rank 0 is headless. engines_to_handshake = [ - CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(dp_size) + CoreEngine(index=i, local=(i < local_engine_count)) for i in range(dp_size) ] else: # Rank > 0 handshakes with just the local cores it is managing. assert local_engines_only, ( "Attempting to launch core_engines from dp_rank > 0, but " - "found internal DPLB, which is incompatible.") + "found internal DPLB, which is incompatible." + ) engines_to_handshake = [ CoreEngine(index=i, local=True) for i in range(dp_rank, dp_rank + local_engine_count) @@ -695,7 +734,8 @@ def launch_core_engines( handshake_local_only = offline_mode or local_engine_count == dp_size handshake_address = get_engine_client_zmq_addr( - handshake_local_only, host, parallel_config.data_parallel_rpc_port) + handshake_local_only, host, parallel_config.data_parallel_rpc_port + ) if local_engines_only and dp_rank > 0: assert not handshake_local_only @@ -705,9 +745,9 @@ def launch_core_engines( local_handshake_address = handshake_address client_handshake_address = None - with zmq_socket_ctx(local_handshake_address, zmq.ROUTER, - bind=True) as handshake_socket: - + with zmq_socket_ctx( + local_handshake_address, zmq.ROUTER, bind=True + ) as handshake_socket: from vllm.v1.engine.core import EngineCoreProc # Start local engines. 
@@ -722,7 +762,8 @@ def launch_core_engines( local_client=True, local_engine_count=local_engine_count, start_index=dp_rank, - local_start_index=local_start_index or 0) + local_start_index=local_start_index or 0, + ) else: local_engine_manager = None @@ -757,8 +798,10 @@ def wait_for_engine_startup( poller = zmq.Poller() poller.register(handshake_socket, zmq.POLLIN) - remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \ + remote_should_be_headless = ( + not parallel_config.data_parallel_hybrid_lb and not parallel_config.data_parallel_external_lb + ) if proc_manager is not None: for sentinel in proc_manager.sentinels(): @@ -770,67 +813,76 @@ def wait_for_engine_startup( if not events: if any(conn_pending): logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to connect.", *conn_pending) + "Waiting for %d local, %d remote core engine proc(s) to connect.", + *conn_pending, + ) if any(start_pending): logger.debug( - "Waiting for %d local, %d remote core engine proc(s) " - "to start.", *start_pending) + "Waiting for %d local, %d remote core engine proc(s) to start.", + *start_pending, + ) continue if len(events) > 1 or events[0][0] != handshake_socket: # One of the local core processes exited. finished = proc_manager.finished_procs() if proc_manager else {} if coord_process is not None and coord_process.exitcode is not None: finished[coord_process.name] = coord_process.exitcode - raise RuntimeError("Engine core initialization failed. " - "See root cause above. " - f"Failed core proc(s): {finished}") + raise RuntimeError( + "Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}" + ) # Receive HELLO and READY messages from the input socket. eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() eng_index = int.from_bytes(eng_identity, "little") - engine = next((e for e in core_engines if e.identity == eng_identity), - None) + engine = next((e for e in core_engines if e.identity == eng_identity), None) if engine is None: - raise RuntimeError(f"Message from engine with unexpected data " - f"parallel rank: {eng_index}") + raise RuntimeError( + f"Message from engine with unexpected data parallel rank: {eng_index}" + ) msg = msgspec.msgpack.decode(ready_msg_bytes) status, local, headless = msg["status"], msg["local"], msg["headless"] if local != engine.local: - raise RuntimeError(f"{status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}, expected it to be " - f"{'local' if engine.local else 'remote'}") + raise RuntimeError( + f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}" + ) # Remote engines must be headless iff we aren't in hybrid dp lb mode. if not local and headless != remote_should_be_headless: if headless: - raise RuntimeError(f"Remote engine {eng_index} must not use " - f"--headless in external or hybrid dp lb " - f"mode") + raise RuntimeError( + f"Remote engine {eng_index} must not use " + f"--headless in external or hybrid dp lb " + f"mode" + ) else: - raise RuntimeError(f"Remote engine {eng_index} must use " - f"--headless unless in external or hybrid " - f"dp lb mode") + raise RuntimeError( + f"Remote engine {eng_index} must use " + f"--headless unless in external or hybrid " + f"dp lb mode" + ) if status == "HELLO" and engine.state == CoreEngineState.NEW: - # Send init message with DP config info. 
init_message = msgspec.msgpack.encode( EngineHandshakeMetadata( addresses=addresses, parallel_config={ - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "_data_parallel_master_port_list": - parallel_config._data_parallel_master_port_list, - "data_parallel_size": - parallel_config.data_parallel_size, - })) - handshake_socket.send_multipart((eng_identity, init_message), - copy=False) + k: getattr(parallel_config, k) + for k in ( + "data_parallel_master_ip", + "data_parallel_master_port", + "_data_parallel_master_port_list", + "data_parallel_size", + ) + }, + ) + ) + handshake_socket.send_multipart((eng_identity, init_message), copy=False) conn_pending[0 if local else 1] -= 1 start_pending[0 if local else 1] += 1 engine.state = CoreEngineState.CONNECTED @@ -846,15 +898,20 @@ def wait_for_engine_startup( # one of the engine handshakes, and passed to the local # front-end process in the response from the other. if addresses.frontend_stats_publish_address is None: - addresses.frontend_stats_publish_address = msg.get( - "dp_stats_address") + addresses.frontend_stats_publish_address = msg.get("dp_stats_address") start_pending[0 if local else 1] -= 1 engine.state = CoreEngineState.READY else: - raise RuntimeError(f"Unexpected {status} message for " - f"{'local' if local else 'remote'} engine " - f"{eng_index} in {engine.state} state.") - - logger.debug("%s from %s core engine process %s.", status, - "local" if local else "remote", eng_index) + raise RuntimeError( + f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state." + ) + + logger.debug( + "%s from %s core engine process %s.", + status, + "local" if local else "remote", + eng_index, + ) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 625017d52fff..064e4b2bbf18 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -10,9 +10,9 @@ from vllm.config import VllmConfig from vllm.executor.executor_base import ExecutorBase from vllm.executor.uniproc_executor import ( # noqa - ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0) -from vllm.executor.uniproc_executor import ( # noqa - UniProcExecutor as UniProcExecutorV0) + ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0, +) +from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa from vllm.utils import resolve_obj_by_qualname from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec @@ -30,21 +30,24 @@ class Executor(ExecutorBase): def get_class(vllm_config: VllmConfig) -> type["Executor"]: executor_class: type[Executor] parallel_config = vllm_config.parallel_config - distributed_executor_backend = ( - parallel_config.distributed_executor_backend) + distributed_executor_backend = parallel_config.distributed_executor_backend # distributed_executor_backend must be set in VllmConfig.__post_init__ if isinstance(distributed_executor_backend, type): if not issubclass(distributed_executor_backend, ExecutorBase): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {distributed_executor_backend}.") + f"ExecutorBase. Got {distributed_executor_backend}." 
+ ) executor_class = distributed_executor_backend elif distributed_executor_backend == "ray": from vllm.v1.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor) + RayDistributedExecutor, + ) + executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor + executor_class = MultiprocExecutor elif distributed_executor_backend == "uni": executor_class = UniProcExecutor @@ -53,25 +56,24 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: # to support external launcher executor_class = ExecutorWithExternalLauncher elif isinstance(distributed_executor_backend, str): - executor_class = resolve_obj_by_qualname( - distributed_executor_backend) + executor_class = resolve_obj_by_qualname(distributed_executor_backend) if not issubclass(executor_class, ExecutorBase): raise TypeError( "distributed_executor_backend must be a subclass of " - f"ExecutorBase. Got {executor_class}.") + f"ExecutorBase. Got {executor_class}." + ) else: - raise ValueError("Unknown distributed executor backend: " - f"{distributed_executor_backend}") + raise ValueError( + f"Unknown distributed executor backend: {distributed_executor_backend}" + ) return executor_class - def initialize_from_config(self, - kv_cache_configs: list[KVCacheConfig]) -> None: + def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the underlying workers. """ - self.collective_rpc("initialize_from_config", - args=(kv_cache_configs, )) + self.collective_rpc("initialize_from_config", args=(kv_cache_configs,)) self.collective_rpc("compile_or_warm_up_model") def register_failure_callback(self, callback: FailureCallback): @@ -87,12 +89,14 @@ def determine_available_memory(self) -> list[int]: # in bytes def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: return self.collective_rpc("get_kv_cache_spec") - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None, - non_block: bool = False) -> list[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + ) -> list[Any]: raise NotImplementedError def execute_model( @@ -100,9 +104,9 @@ def execute_model( scheduler_output: SchedulerOutput, non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - output = self.collective_rpc("execute_model", - args=(scheduler_output, ), - non_block=non_block) + output = self.collective_rpc( + "execute_model", args=(scheduler_output,), non_block=non_block + ) return output[0] def execute_dummy_batch(self) -> None: @@ -117,7 +121,7 @@ def max_concurrent_batches(self) -> int: return 1 def profile(self, is_start: bool = True): - self.collective_rpc("profile", args=(is_start, )) + self.collective_rpc("profile", args=(is_start,)) class UniProcExecutor(UniProcExecutorV0, Executor): @@ -125,12 +129,12 @@ class UniProcExecutor(UniProcExecutorV0, Executor): class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): - def determine_available_memory(self) -> list[int]: # in bytes # same as determine_num_available_blocks in v0, # we need to get the min across all ranks. 
memory = super().determine_available_memory() from vllm.distributed.parallel_state import get_world_group + cpu_group = get_world_group().cpu_group memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64) dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index eecdf8def6de..062b6042693b 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -24,30 +24,36 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel) -from vllm.distributed.device_communicators.shm_broadcast import (Handle, - MessageQueue) -from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, - get_pp_group, get_tp_group) +from vllm.distributed import destroy_distributed_environment, destroy_model_parallel +from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue +from vllm.distributed.parallel_state import ( + get_dp_group, + get_ep_group, + get_pp_group, + get_tp_group, +) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import worker_receiver_cache_from_config -from vllm.utils import (_maybe_force_spawn, decorate_logs, - get_distributed_init_method, get_loopback_ip, - get_mp_context, get_open_port, set_process_title) +from vllm.utils import ( + _maybe_force_spawn, + decorate_logs, + get_distributed_init_method, + get_loopback_ip, + get_mp_context, + get_open_port, + set_process_title, +) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.executor.utils import get_and_update_mm_cache -from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, - ModelRunnerOutput) +from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) class MultiprocExecutor(Executor): - supports_pp: bool = True def _init_executor(self) -> None: @@ -65,7 +71,8 @@ def _init_executor(self) -> None: assert self.world_size == tensor_parallel_size * pp_parallel_size, ( f"world_size ({self.world_size}) must be equal to the " f"tensor_parallel_size ({tensor_parallel_size}) x pipeline" - f"_parallel_size ({pp_parallel_size}). ") + f"_parallel_size ({pp_parallel_size}). " + ) # Set multiprocessing envs set_multiprocessing_worker_envs() @@ -74,14 +81,15 @@ def _init_executor(self) -> None: # Since it only works for single node, we can use the loopback address # get_loopback_ip() for communication. 
distributed_init_method = get_distributed_init_method( - get_loopback_ip(), get_open_port()) + get_loopback_ip(), get_open_port() + ) # Initialize worker and set up message queues for SchedulerOutputs # and ModelRunnerOutputs max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 - self.rpc_broadcast_mq = MessageQueue(self.world_size, - self.world_size, - max_chunk_bytes=max_chunk_bytes) + self.rpc_broadcast_mq = MessageQueue( + self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes + ) scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers @@ -99,7 +107,8 @@ def _init_executor(self) -> None: distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, shared_worker_lock=shared_worker_lock, - )) + ) + ) # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. @@ -120,8 +129,7 @@ def _init_executor(self) -> None: for uw in unready_workers: if uw.death_writer is not None: uw.death_writer.close() - self._ensure_worker_termination( - [uw.proc for uw in unready_workers]) + self._ensure_worker_termination([uw.proc for uw in unready_workers]) # For pipeline parallel, we use a thread pool for asynchronous # execute_model. @@ -130,7 +138,8 @@ def _init_executor(self) -> None: # from the response queue # _async_aggregate_workers_output also assumes a single IO thread self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io") + max_workers=1, thread_name_prefix="mp_exec_io" + ) self.output_rank = self._get_output_rank() self.has_connector = self.vllm_config.kv_transfer_config is not None @@ -146,23 +155,22 @@ def monitor_workers(): sentinels = [h.proc.sentinel for h in workers] died = multiprocessing.connection.wait(sentinels) _self = self_ref() - if not _self or getattr(_self, 'shutting_down', False): + if not _self or getattr(_self, "shutting_down", False): return _self.is_failed = True - proc_name = next(h.proc.name for h in workers - if h.proc.sentinel == died[0]) + proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0]) logger.error( - "Worker proc %s died unexpectedly, " - "shutting down executor.", proc_name) + "Worker proc %s died unexpectedly, shutting down executor.", proc_name + ) _self.shutdown() callback = _self.failure_callback if callback is not None: _self.failure_callback = None callback() - Thread(target=monitor_workers, - daemon=True, - name="MultiprocWorkerMonitor").start() + Thread( + target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor" + ).start() def register_failure_callback(self, callback: FailureCallback): if self.is_failed: @@ -175,47 +183,49 @@ def execute_model( scheduler_output: SchedulerOutput, non_block: bool = False, ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: - if not self.has_connector: # get output only from a single worker (output_rank) - (output, ) = self.collective_rpc( + (output,) = self.collective_rpc( "execute_model", - args=(scheduler_output, ), + args=(scheduler_output,), unique_reply_rank=self.output_rank, non_block=non_block, - timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) + timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, + ) return output # get output from all workers outputs = self.collective_rpc( "execute_model", - args=(scheduler_output, ), + args=(scheduler_output,), non_block=non_block, - timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) + timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS, + ) # aggregate all workers output to a single 
output if non_block: - return self.kv_output_aggregator.async_aggregate( - outputs, self.output_rank) + return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank) def execute_dummy_batch(self) -> None: - self.collective_rpc("execute_dummy_batch", - unique_reply_rank=self.output_rank) + self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank) def take_draft_token_ids(self) -> Optional[DraftTokenIds]: # OPTIMIZATION: Get output only from a single worker (output_rank) - outputs = self.collective_rpc("take_draft_token_ids", - unique_reply_rank=self.output_rank) + outputs = self.collective_rpc( + "take_draft_token_ids", unique_reply_rank=self.output_rank + ) return outputs[0] - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict] = None, - non_block: bool = False, - unique_reply_rank: Optional[int] = None) -> list[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + non_block: bool = False, + unique_reply_rank: Optional[int] = None, + ) -> list[Any]: if self.is_failed: raise RuntimeError("Executor failed.") @@ -230,42 +240,53 @@ def collective_rpc(self, send_method = method else: send_method = cloudpickle.dumps( - method, protocol=pickle.HIGHEST_PROTOCOL) + method, protocol=pickle.HIGHEST_PROTOCOL + ) self.rpc_broadcast_mq.enqueue( - (send_method, args, kwargs, unique_reply_rank)) - - workers = (self.workers[unique_reply_rank], - ) if unique_reply_rank is not None else self.workers + (send_method, args, kwargs, unique_reply_rank) + ) + + workers = ( + (self.workers[unique_reply_rank],) + if unique_reply_rank is not None + else self.workers + ) responses = [] - def get_response(w: WorkerProcHandle, - dequeue_timeout: Optional[float] = None, - cancel_event: Optional[threading.Event] = None): + def get_response( + w: WorkerProcHandle, + dequeue_timeout: Optional[float] = None, + cancel_event: Optional[threading.Event] = None, + ): status, result = w.worker_response_mq.dequeue( - timeout=dequeue_timeout, cancel=cancel_event) + timeout=dequeue_timeout, cancel=cancel_event + ) if status != WorkerProc.ResponseStatus.SUCCESS: raise RuntimeError( f"Worker failed with error '{result}', please check the" - " stack trace above for the root cause") + " stack trace above for the root cause" + ) return result for w in workers: - dequeue_timeout = None if deadline is None else ( - deadline - time.monotonic()) + dequeue_timeout = ( + None if deadline is None else (deadline - time.monotonic()) + ) if self.io_thread_pool is not None: # We must consume worker_response_mq from a single thread. 
result = self.io_thread_pool.submit( # type: ignore - get_response, w, dequeue_timeout, self.shutdown_event) + get_response, w, dequeue_timeout, self.shutdown_event + ) if not non_block: result = result.result() elif not non_block: - result = get_response(w, dequeue_timeout, - self.shutdown_event) + result = get_response(w, dequeue_timeout, self.shutdown_event) else: - raise RuntimeError("non_block can only be used when" - " max_concurrent_batches > 1") + raise RuntimeError( + "non_block can only be used when max_concurrent_batches > 1" + ) responses.append(result) return responses @@ -302,11 +323,11 @@ def wait_for_termination(procs, timeout): def shutdown(self): """Properly shut down the executor and its workers""" - if not getattr(self, 'shutting_down', False): + if not getattr(self, "shutting_down", False): self.shutting_down = True # Make sure all the worker processes are terminated first. - if workers := getattr(self, 'workers', None): + if workers := getattr(self, "workers", None): for w in workers: # Close death_writer to signal child processes to exit if w.death_writer is not None: @@ -348,6 +369,7 @@ def _get_output_rank(self) -> int: @dataclass class UnreadyWorkerProcHandle: """WorkerProcess handle before READY.""" + proc: BaseProcess rank: int ready_pipe: Connection @@ -363,8 +385,8 @@ class WorkerProcHandle: @classmethod def from_unready_handle( - cls, unready_handle: UnreadyWorkerProcHandle, - worker_response_mq: MessageQueue) -> "WorkerProcHandle": + cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue + ) -> "WorkerProcHandle": return cls( proc=unready_handle.proc, rank=unready_handle.rank, @@ -393,8 +415,7 @@ def __init__( all_kwargs: list[dict] = [ {} for _ in range(vllm_config.parallel_config.world_size) ] - is_driver_worker = ( - rank % vllm_config.parallel_config.tensor_parallel_size == 0) + is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0 all_kwargs[rank] = { "vllm_config": vllm_config, "local_rank": local_rank, @@ -407,7 +428,8 @@ def __init__( # Initialize MessageQueue for receiving SchedulerOutput self.rpc_broadcast_mq = MessageQueue.create_from_handle( - input_shm_handle, self.worker.rank) + input_shm_handle, self.worker.rank + ) # Initializes a message queue for sending the model output self.worker_response_mq = MessageQueue(1, 1) @@ -419,19 +441,22 @@ def __init__( self.async_output_copy_thread = Thread( target=self.async_output_busy_loop, daemon=True, - name="WorkerAsyncOutputCopy") + name="WorkerAsyncOutputCopy", + ) self.async_output_copy_thread.start() # Initialize multimodal receiver cache if needed self.mm_receiver_cache = worker_receiver_cache_from_config( - vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock) + vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock + ) # Initialize device self.worker.init_device() # Set process title and log prefix self.setup_proc_title_and_log_prefix( - enable_ep=vllm_config.parallel_config.enable_expert_parallel) + enable_ep=vllm_config.parallel_config.enable_expert_parallel + ) # Load model self.worker.load_model() @@ -463,10 +488,12 @@ def make_worker_process( "shared_worker_lock": shared_worker_lock, } # Run EngineCore busy loop in background process. 
- proc = context.Process(target=WorkerProc.worker_main, - kwargs=process_kwargs, - name=f"VllmWorker-{rank}", - daemon=True) + proc = context.Process( + target=WorkerProc.worker_main, + kwargs=process_kwargs, + name=f"VllmWorker-{rank}", + daemon=True, + ) proc.start() writer.close() @@ -476,16 +503,18 @@ def make_worker_process( @staticmethod def wait_for_ready( - unready_proc_handles: list[UnreadyWorkerProcHandle] + unready_proc_handles: list[UnreadyWorkerProcHandle], ) -> list[WorkerProcHandle]: - - e = Exception("WorkerProc initialization failed due to " - "an exception in a background process. " - "See stack trace for root cause.") + e = Exception( + "WorkerProc initialization failed due to " + "an exception in a background process. " + "See stack trace for root cause." + ) pipes = {handle.ready_pipe: handle for handle in unready_proc_handles} - ready_proc_handles: list[Optional[WorkerProcHandle]] = ( - [None] * len(unready_proc_handles)) + ready_proc_handles: list[Optional[WorkerProcHandle]] = [None] * len( + unready_proc_handles + ) while pipes: ready = multiprocessing.connection.wait(pipes.keys()) for pipe in ready: @@ -499,10 +528,13 @@ def wait_for_ready( # Extract the message queue handle. worker_response_mq = MessageQueue.create_from_handle( - response["handle"], 0) + response["handle"], 0 + ) ready_proc_handles[unready_proc_handle.rank] = ( WorkerProcHandle.from_unready_handle( - unready_proc_handle, worker_response_mq)) + unready_proc_handle, worker_response_mq + ) + ) except EOFError: e.__suppress_context__ = True @@ -523,8 +555,8 @@ def shutdown(self): @staticmethod def worker_main(*args, **kwargs): - """ Worker initialization and execution loops. - This runs a background process """ + """Worker initialization and execution loops. + This runs a background process""" # Signal handler used for graceful termination. # SystemExit exception is only raised once to allow this and worker @@ -561,9 +593,9 @@ def monitor_parent_death(): except Exception as e: logger.warning("Death monitoring error: %s", e) - death_monitor = Thread(target=monitor_parent_death, - daemon=True, - name="WorkerDeathMonitor") + death_monitor = Thread( + target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor" + ) death_monitor.start() try: @@ -571,12 +603,12 @@ def monitor_parent_death(): worker = WorkerProc(*args, **kwargs) # Send READY once we know everything is loaded - ready_writer.send({ - "status": - WorkerProc.READY_STR, - "handle": - worker.worker_response_mq.export_handle(), - }) + ready_writer.send( + { + "status": WorkerProc.READY_STR, + "handle": worker.worker_response_mq.export_handle(), + } + ) # Ensure message queues are ready. Will deadlock if re-ordered. 
# Must be kept consistent with the Executor @@ -653,15 +685,18 @@ def worker_busy_loop(self, cancel: Optional[threading.Event] = None): """Main busy loop for Multiprocessing Workers""" while True: method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( - cancel=cancel, indefinite=True) + cancel=cancel, indefinite=True + ) try: if isinstance(method, str): func = getattr(self.worker, method) elif isinstance(method, bytes): func = partial(cloudpickle.loads(method), self.worker) # retrieve from shm cache if available - if self.mm_receiver_cache is not None \ - and func.__name__ == "execute_model": + if ( + self.mm_receiver_cache is not None + and func.__name__ == "execute_model" + ): get_and_update_mm_cache(self.mm_receiver_cache, args) output = func(*args, **kwargs) except Exception as e: @@ -701,7 +736,7 @@ def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: def set_multiprocessing_worker_envs(): - """ Set up environment variables that should be used when there are workers + """Set up environment variables that should be used when there are workers in a multiprocessing environment. This should be called by the parent process before worker processes are created""" @@ -714,13 +749,16 @@ def set_multiprocessing_worker_envs(): # impact on performance. The contention is amplified when running in a # container where CPU limits can cause throttling. default_omp_num_threads = 1 - if "OMP_NUM_THREADS" not in os.environ and ( - current_parallelism := - torch.get_num_threads()) > default_omp_num_threads: + if ( + "OMP_NUM_THREADS" not in os.environ + and (current_parallelism := torch.get_num_threads()) > default_omp_num_threads + ): logger.warning( "Reducing Torch parallelism from %d threads to %d to avoid " "unnecessary CPU contention. Set OMP_NUM_THREADS in the " "external environment to tune this value as needed.", - current_parallelism, default_omp_num_threads) + current_parallelism, + default_omp_num_threads, + ) os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) torch.set_num_threads(default_omp_num_threads) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index aadb5fd1dddd..e2c2bfd45d7b 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -6,7 +6,8 @@ from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.executor.ray_distributed_executor import ( # noqa - RayDistributedExecutor as RayDistributedExecutorV0) + RayDistributedExecutor as RayDistributedExecutorV0, +) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType @@ -18,10 +19,10 @@ class FutureWrapper(Future): """A wrapper around Ray output reference to meet the interface - of .execute_model(): The top level (core busy loop) expects .result() api + of .execute_model(): The top level (core busy loop) expects .result() api to block and return a single output. - - If aggregator is provided, the outputs from all workers are aggregated upon + + If aggregator is provided, the outputs from all workers are aggregated upon the result() call. If not only the first worker's output is returned. 
""" @@ -101,8 +102,11 @@ def execute_model( return FutureWrapper(refs, self.kv_output_aggregator) def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: self._run_workers("reinitialize_distributed", reconfig_request) - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): self.shutdown() diff --git a/vllm/v1/executor/utils.py b/vllm/v1/executor/utils.py index 1855bc996381..884068a43882 100644 --- a/vllm/v1/executor/utils.py +++ b/vllm/v1/executor/utils.py @@ -20,4 +20,5 @@ def get_and_update_mm_cache( scheduler_output = args[0] for request_data in scheduler_output.scheduled_new_reqs: request_data.mm_features = receiver_cache.get_and_update_features( - request_data.mm_features) + request_data.mm_features + ) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 281816653540..9c28eb92c17a 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -50,7 +50,8 @@ def merge(cls, specs: list[Self]) -> Self: Merge a list of KVCacheSpec objects into a single KVCacheSpec object. """ assert all(spec == specs[0] for spec in specs[1:]), ( - "All layers in the same KV cache group must be the same.") + "All layers in the same KV cache group must be the same." + ) return copy.deepcopy(specs[0]) @@ -62,8 +63,13 @@ class AttentionSpec(KVCacheSpec): @property def page_size_bytes(self) -> int: - return 2 * self.block_size * self.num_kv_heads * self.head_size \ - * get_dtype_size(self.dtype) + return ( + 2 + * self.block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) @dataclass(frozen=True) @@ -82,8 +88,7 @@ class FullAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len - dcp_world_size = \ - vllm_config.parallel_config.decode_context_parallel_size + dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size # Note(hc): each dcp rank only need save # (max_model_len//dcp_world_size) tokens locally. if dcp_world_size > 1: @@ -99,24 +104,30 @@ def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]: else: raise ValueError( "All attention layers in the same KV cache group must have the " - "same window size.") + "same window size." + ) @classmethod def merge(cls, specs: list[Self]) -> Self: """ - Merge a list of FullAttentionSpec objects into a single + Merge a list of FullAttentionSpec objects into a single FullAttentionSpec object. """ assert all(isinstance(spec, FullAttentionSpec) for spec in specs), ( - "All attention layers in the same KV cache group must be " - "FullAttentionSpec.") + "All attention layers in the same KV cache group must be FullAttentionSpec." 
+ ) - sliding_window = set(spec.sliding_window for spec in specs - if spec.sliding_window is not None) - attention_chunk_size = set(spec.attention_chunk_size for spec in specs - if spec.attention_chunk_size is not None) + sliding_window = set( + spec.sliding_window for spec in specs if spec.sliding_window is not None + ) + attention_chunk_size = set( + spec.attention_chunk_size + for spec in specs + if spec.attention_chunk_size is not None + ) assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( - "MLAAttentionSpec should be merged in MLAAttentionSpec.merge") + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge" + ) merged_spec = cls( block_size=specs[0].block_size, num_kv_heads=specs[0].num_kv_heads, @@ -129,12 +140,14 @@ def merge(cls, specs: list[Self]) -> Self: for f in fields(AttentionSpec): assert getattr(spec, f.name) == getattr(merged_spec, f.name), ( "All attention layers in the same KV cache group must have " - "the same attention spec.") - assert ( - (merged_spec.sliding_window is not None) + - (merged_spec.attention_chunk_size is not None) <= 1 - ), ("Model with both sliding window layers and chunked local attention " - "layers is not supported.") + "the same attention spec." + ) + assert (merged_spec.sliding_window is not None) + ( + merged_spec.attention_chunk_size is not None + ) <= 1, ( + "Model with both sliding window layers and chunked local attention " + "layers is not supported." + ) return merged_spec @@ -149,18 +162,23 @@ def page_size_bytes(self) -> int: # See `vllm/v1/attention/backends/mla/flashmla_sparse.py` # for details. return self.block_size * 656 - return self.block_size * self.num_kv_heads * self.head_size \ - * get_dtype_size(self.dtype) + return ( + self.block_size + * self.num_kv_heads + * self.head_size + * get_dtype_size(self.dtype) + ) @classmethod def merge(cls, specs: list[Self]) -> Self: assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), ( - "All attention layers in the same KV cache group must be " - "MLAAttentionSpec.") + "All attention layers in the same KV cache group must be MLAAttentionSpec." + ) cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) assert len(cache_dtype_str_set) == 1, ( "All attention layers in the same KV cache group must use the same " - "quantization method.") + "quantization method." + ) return cls( block_size=specs[0].block_size, num_kv_heads=specs[0].num_kv_heads, @@ -176,15 +194,15 @@ class ChunkedLocalAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len - max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens # During chunked prefill, we allocate KV cache for at most # `self.attention_chunk_size` computed tokens plus the newly scheduled # tokens. And we won't allocate KV cache for more than `max_model_len` # tokens. 
- num_tokens = min(self.attention_chunk_size + max_num_batched_tokens, - max_model_len) + num_tokens = min( + self.attention_chunk_size + max_num_batched_tokens, max_model_len + ) return cdiv(num_tokens, self.block_size) * self.page_size_bytes @@ -194,18 +212,19 @@ class SlidingWindowSpec(AttentionSpec): sliding_window: int def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - assert vllm_config.parallel_config.decode_context_parallel_size == 1, \ + assert vllm_config.parallel_config.decode_context_parallel_size == 1, ( "DCP not support sliding window." + ) max_model_len = vllm_config.model_config.max_model_len - max_num_batched_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens # During chunked prefill, we allocate KV cache for the last # `self.sliding_window-1` computed tokens plus the newly scheduled # tokens. And we won't allocate KV cache for more than `max_model_len` # tokens. - num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens, - max_model_len) + num_tokens = min( + self.sliding_window - 1 + max_num_batched_tokens, max_model_len + ) # +1 here because the sliding window may not start from the beginning # of the block. For example, if the block size is 4 and num_token @@ -226,22 +245,20 @@ class MambaSpec(KVCacheSpec): def page_size_bytes(self) -> int: page_size = sum( prod(shape) * get_dtype_size(dtype) - for (shape, dtype) in zip(self.shapes, self.dtypes)) + for (shape, dtype) in zip(self.shapes, self.dtypes) + ) if self.page_size_padded is not None: assert self.page_size_padded >= page_size return self.page_size_padded return page_size def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: - # We allocate 1 block for each request now, so max_memory_usage_bytes is - # the same as page_size_bytes. - # Need to update this when supporting prefix caching. - return self.page_size_bytes + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes @dataclass(frozen=True) class EncoderOnlyAttentionSpec(AttentionSpec): - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # Encoder-only layers do not need KV cache return 0 @@ -256,8 +273,7 @@ class CrossAttentionSpec(AttentionSpec): def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: # For cross-attention, we need to cache encoder states # Get encoder length (e.g., 1500 for Whisper). - max_encoder_len = vllm_config.scheduler_config.\ - max_num_encoder_input_tokens + max_encoder_len = vllm_config.scheduler_config.max_num_encoder_input_tokens return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes @@ -269,18 +285,18 @@ class UniformTypeKVCacheSpecs(KVCacheSpec): sliding window attentions with different window sizes are not the same type and should not be merged into one UniformTypeKVCacheSpecs. 
""" + kv_cache_specs: dict[str, KVCacheSpec] @property def page_size_bytes(self) -> int: - return sum(spec.page_size_bytes - for spec in self.kv_cache_specs.values()) + return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values()) def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_num_pages = max( - cdiv(spec.max_memory_usage_bytes(vllm_config), - spec.page_size_bytes) - for spec in self.kv_cache_specs.values()) + cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes) + for spec in self.kv_cache_specs.values() + ) return max_num_pages * self.page_size_bytes @classmethod @@ -295,35 +311,38 @@ def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool: one_spec = next(iter(kv_cache_specs.values())) if isinstance(one_spec, FullAttentionSpec): return all( - isinstance(spec, FullAttentionSpec) - for spec in kv_cache_specs.values()) + isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values() + ) elif isinstance(one_spec, CrossAttentionSpec): return all( - isinstance(spec, CrossAttentionSpec) - for spec in kv_cache_specs.values()) + isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values() + ) elif isinstance(one_spec, SlidingWindowSpec): return all( isinstance(spec, SlidingWindowSpec) and spec.sliding_window == one_spec.sliding_window - for spec in kv_cache_specs.values()) + for spec in kv_cache_specs.values() + ) elif isinstance(one_spec, ChunkedLocalAttentionSpec): return all( isinstance(spec, ChunkedLocalAttentionSpec) and spec.attention_chunk_size == one_spec.attention_chunk_size - for spec in kv_cache_specs.values()) + for spec in kv_cache_specs.values() + ) elif isinstance(one_spec, MambaSpec): return all( - isinstance(spec, MambaSpec) and spec.num_speculative_blocks == - one_spec.num_speculative_blocks - for spec in kv_cache_specs.values()) + isinstance(spec, MambaSpec) + and spec.num_speculative_blocks == one_spec.num_speculative_blocks + for spec in kv_cache_specs.values() + ) else: # NOTE(Chen): Please add new branches for new KV cache spec types. raise NotImplementedError( - f"Unsupported KV cache spec type: {type(one_spec)}") + f"Unsupported KV cache spec type: {type(one_spec)}" + ) @classmethod - def from_specs(cls, kv_cache_specs: dict[str, - KVCacheSpec]) -> Optional[Self]: + def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Optional[Self]: """ Return a SameTypeKVCacheSpecs object if all layers have the same type of KV cache spec. Return None if not. @@ -340,6 +359,7 @@ class KVCacheTensor: """ A class for specifying how the workers should initialize the KV cache. """ + size: int # size of the KV cache tensor in bytes shared_by: list[str] # layer names that share the same KV cache tensor @@ -350,6 +370,7 @@ class KVCacheGroupSpec: Represents a group of model layers that share the same KV cache block table. These layers are regarded as one layer in the KV cache manager. """ + # The names of model layers in this group layer_names: list[str] # The KV cache spec of this manager layer @@ -361,6 +382,7 @@ class KVCacheConfig: """ The KV cache configuration of a model. 
""" + """The number of KV cache blocks""" num_blocks: int """How should model runner initialize the KV cache tensors for each layer""" diff --git a/vllm/v1/kv_offload/abstract.py b/vllm/v1/kv_offload/abstract.py index 9f9c044ea1c5..ce2d0dffc0ff 100644 --- a/vllm/v1/kv_offload/abstract.py +++ b/vllm/v1/kv_offload/abstract.py @@ -68,7 +68,6 @@ class OffloadingEvent: class OffloadingManager(ABC): - @abstractmethod def lookup(self, block_hashes: Iterable[BlockHash]) -> int: """ @@ -122,8 +121,8 @@ def complete_load(self, block_hashes: Iterable[BlockHash]): @abstractmethod def prepare_store( - self, - block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]: + self, block_hashes: Iterable[BlockHash] + ) -> Optional[PrepareStoreOutput]: """ Prepare the given blocks to be offloaded. The given blocks will be protected from eviction until @@ -140,9 +139,7 @@ def prepare_store( """ pass - def complete_store(self, - block_hashes: Iterable[BlockHash], - success: bool = True): + def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True): """ Marks blocks which were previously prepared to be stored, as stored. Following this call, the blocks become loadable. diff --git a/vllm/v1/kv_offload/backend.py b/vllm/v1/kv_offload/backend.py index 87a74200116b..538f7bf0584b 100644 --- a/vllm/v1/kv_offload/backend.py +++ b/vllm/v1/kv_offload/backend.py @@ -18,6 +18,7 @@ class BlockStatus(ctypes.Structure): load_store_spec - backend-specific information on how to actually read/write the block. """ + _fields_ = [("ref_cnt", ctypes.c_int32)] def __init__(self): @@ -51,8 +52,7 @@ def get_num_free_blocks(self): pass @abstractmethod - def allocate_blocks(self, - block_hashes: list[BlockHash]) -> list[BlockStatus]: + def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]: """ Allocate space for writing blocks. This method assumes there is enough space for allocation. @@ -80,8 +80,9 @@ def free(self, block: BlockStatus): """ pass - def get_load_store_spec(self, block_hashes: Iterable[BlockHash], - blocks: Iterable[BlockStatus]) -> LoadStoreSpec: + def get_load_store_spec( + self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus] + ) -> LoadStoreSpec: """ Get backend-specific information on how to read/write blocks. 
diff --git a/vllm/v1/kv_offload/backends/cpu.py b/vllm/v1/kv_offload/backends/cpu.py index eb1123d1d83a..736cf37853cd 100644 --- a/vllm/v1/kv_offload/backends/cpu.py +++ b/vllm/v1/kv_offload/backends/cpu.py @@ -10,8 +10,7 @@ class CPUBlockStatus(BlockStatus): - _fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64) - ] # type: ignore + _fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64)] # type: ignore def __init__(self, block_id: int): super().__init__() @@ -19,23 +18,24 @@ def __init__(self, block_id: int): class CPUBackend(Backend): - def __init__(self, block_size: int, num_blocks: int): - super().__init__(block_size=block_size, - medium=CPULoadStoreSpec.medium()) + super().__init__(block_size=block_size, medium=CPULoadStoreSpec.medium()) self.num_blocks: int = num_blocks self.num_allocated_blocks: int = 0 self.allocated_blocks_free_list: list[int] = [] def get_num_free_blocks(self): - return (len(self.allocated_blocks_free_list) + self.num_blocks - - self.num_allocated_blocks) - - def allocate_blocks(self, - block_hashes: list[BlockHash]) -> list[BlockStatus]: - num_fresh_blocks = min(len(block_hashes), - self.num_blocks - self.num_allocated_blocks) + return ( + len(self.allocated_blocks_free_list) + + self.num_blocks + - self.num_allocated_blocks + ) + + def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]: + num_fresh_blocks = min( + len(block_hashes), self.num_blocks - self.num_allocated_blocks + ) num_reused_blocks = len(block_hashes) - num_fresh_blocks assert len(self.allocated_blocks_free_list) >= num_reused_blocks @@ -56,6 +56,7 @@ def free(self, block: BlockStatus): assert isinstance(block, CPUBlockStatus) self.allocated_blocks_free_list.append(block.block_id) - def get_load_store_spec(self, block_hashes: Iterable[BlockHash], - blocks: Iterable[BlockStatus]) -> LoadStoreSpec: + def get_load_store_spec( + self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus] + ) -> LoadStoreSpec: return CPULoadStoreSpec([block.block_id for block in blocks]) diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py index b85d375fe63e..0c1cf64a237c 100644 --- a/vllm/v1/kv_offload/cpu.py +++ b/vllm/v1/kv_offload/cpu.py @@ -18,14 +18,14 @@ class CPUOffloadingSpec(OffloadingSpec): - def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) num_cpu_blocks = self.extra_config.get("num_cpu_blocks") if not num_cpu_blocks: - raise Exception("num_cpu_blocks must be specified " - "in kv_connector_extra_config") + raise Exception( + "num_cpu_blocks must be specified in kv_connector_extra_config" + ) self.num_cpu_blocks: int = num_cpu_blocks # scheduler-side @@ -37,27 +37,30 @@ def __init__(self, vllm_config: VllmConfig): def get_manager(self) -> OffloadingManager: if not self._manager: kv_events_config = self.vllm_config.kv_events_config - enable_events = (kv_events_config is not None - and kv_events_config.enable_kv_cache_events) - self._manager = LRUOffloadingManager(CPUBackend( - block_size=self.offloaded_block_size, - num_blocks=self.num_cpu_blocks), - enable_events=enable_events) + enable_events = ( + kv_events_config is not None and kv_events_config.enable_kv_cache_events + ) + self._manager = LRUOffloadingManager( + CPUBackend( + block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks + ), + enable_events=enable_events, + ) return self._manager def get_handlers( self, kv_caches: dict[str, torch.Tensor] - ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], - OffloadingHandler]]: + ) -> 
Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: if not self._handler: if not current_platform.is_cuda(): - raise Exception("CPU Offloading is currently only supported" - " on CUDA GPUs") + raise Exception( + "CPU Offloading is currently only supported on CUDA GPUs" + ) layer_names = list(kv_caches.keys()) - layers = get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase, - layer_names) + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, layer_names + ) attn_backends = { layer_name: layers[layer_name].get_attn_backend() for layer_name in layer_names @@ -68,7 +71,8 @@ def get_handlers( gpu_block_size=self.gpu_block_size, cpu_block_size=self.offloaded_block_size, num_cpu_blocks=self.num_cpu_blocks, - gpu_caches=kv_caches) + gpu_caches=kv_caches, + ) assert self._handler is not None yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py index f9bef6cea903..e0a53460e840 100644 --- a/vllm/v1/kv_offload/factory.py +++ b/vllm/v1/kv_offload/factory.py @@ -16,8 +16,7 @@ class OffloadingSpecFactory: _registry: dict[str, Callable[[], type[OffloadingSpec]]] = {} @classmethod - def register_spec(cls, name: str, module_path: str, - class_name: str) -> None: + def register_spec(cls, name: str, module_path: str, class_name: str) -> None: """Register a spec with a lazy-loading module and class name.""" if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") @@ -51,6 +50,6 @@ def create_spec( # Register various specs here. -OffloadingSpecFactory.register_spec("CPUOffloadingSpec", - "vllm.v1.kv_offload.cpu", - "CPUOffloadingSpec") +OffloadingSpecFactory.register_spec( + "CPUOffloadingSpec", "vllm.v1.kv_offload.cpu", "CPUOffloadingSpec" +) diff --git a/vllm/v1/kv_offload/lru_manager.py b/vllm/v1/kv_offload/lru_manager.py index 18d3b1d637b3..36f5eb4a0abd 100644 --- a/vllm/v1/kv_offload/lru_manager.py +++ b/vllm/v1/kv_offload/lru_manager.py @@ -5,8 +5,12 @@ from typing import Optional from vllm.v1.core.kv_cache_utils import BlockHash -from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent, - OffloadingManager, PrepareStoreOutput) +from vllm.v1.kv_offload.abstract import ( + LoadStoreSpec, + OffloadingEvent, + OffloadingManager, + PrepareStoreOutput, +) from vllm.v1.kv_offload.backend import Backend, BlockStatus @@ -19,8 +23,7 @@ def __init__(self, backend: Backend, enable_events: bool = False): self.backend: Backend = backend # block_hash -> BlockStatus self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict() - self.events: Optional[list[OffloadingEvent]] = \ - [] if enable_events else None + self.events: Optional[list[OffloadingEvent]] = [] if enable_events else None def lookup(self, block_hashes: Iterable[BlockHash]) -> int: hit_count = 0 @@ -53,16 +56,16 @@ def complete_load(self, block_hashes: Iterable[BlockHash]): block.ref_cnt -= 1 def prepare_store( - self, - block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]: + self, block_hashes: Iterable[BlockHash] + ) -> Optional[PrepareStoreOutput]: # filter out blocks that are already stored block_hashes_to_store = [ - block_hash for block_hash in block_hashes - if block_hash not in self.blocks + block_hash for block_hash in block_hashes if block_hash not in self.blocks ] - num_blocks_to_evict = (len(block_hashes_to_store) - - self.backend.get_num_free_blocks()) + num_blocks_to_evict = ( + len(block_hashes_to_store) - 
self.backend.get_num_free_blocks() + ) # build list of blocks to evict to_evict = [] @@ -83,10 +86,13 @@ def prepare_store( if to_evict and self.events is not None: self.events.append( - OffloadingEvent(block_hashes=to_evict, - block_size=self.backend.block_size, - medium=self.backend.medium, - removed=True)) + OffloadingEvent( + block_hashes=to_evict, + block_size=self.backend.block_size, + medium=self.backend.medium, + removed=True, + ) + ) blocks = self.backend.allocate_blocks(block_hashes_to_store) assert len(blocks) == len(block_hashes_to_store) @@ -95,16 +101,15 @@ def prepare_store( self.blocks[block_hash] = block # build store specs for allocated blocks - store_spec = self.backend.get_load_store_spec(block_hashes_to_store, - blocks) + store_spec = self.backend.get_load_store_spec(block_hashes_to_store, blocks) - return PrepareStoreOutput(block_hashes_to_store=block_hashes_to_store, - store_spec=store_spec, - block_hashes_evicted=to_evict) + return PrepareStoreOutput( + block_hashes_to_store=block_hashes_to_store, + store_spec=store_spec, + block_hashes_evicted=to_evict, + ) - def complete_store(self, - block_hashes: Iterable[BlockHash], - success: bool = True): + def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True): stored_block_hashes: list[BlockHash] = [] if success: for block_hash in block_hashes: @@ -121,10 +126,13 @@ def complete_store(self, if stored_block_hashes and self.events is not None: self.events.append( - OffloadingEvent(block_hashes=stored_block_hashes, - block_size=self.backend.block_size, - medium=self.backend.medium, - removed=False)) + OffloadingEvent( + block_hashes=stored_block_hashes, + block_size=self.backend.block_size, + medium=self.backend.medium, + removed=False, + ) + ) def take_events(self) -> Iterable[OffloadingEvent]: if self.events is not None: diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py index ed23d5e51934..a3c539a47d45 100644 --- a/vllm/v1/kv_offload/spec.py +++ b/vllm/v1/kv_offload/spec.py @@ -22,7 +22,8 @@ class OffloadingSpec(ABC): def __init__(self, vllm_config: "VllmConfig"): logger.warning( "Initializing OffloadingSpec. This API is experimental and " - "subject to change in the future as we iterate the design.") + "subject to change in the future as we iterate the design." + ) self.vllm_config = vllm_config kv_transfer_config = vllm_config.kv_transfer_config @@ -31,7 +32,8 @@ def __init__(self, vllm_config: "VllmConfig"): self.gpu_block_size = vllm_config.cache_config.block_size self.offloaded_block_size = int( - self.extra_config.get("block_size", self.gpu_block_size)) + self.extra_config.get("block_size", self.gpu_block_size) + ) assert self.offloaded_block_size % self.gpu_block_size == 0 @@ -47,8 +49,7 @@ def get_manager(self) -> OffloadingManager: @abstractmethod def get_handlers( self, kv_caches: dict[str, torch.Tensor] - ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], - OffloadingHandler]]: + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: """ Get offloading handlers along with their respective src and dst types. 
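
OffloadingSpecFactory.register_spec() keys a spec by name and loads its module lazily, as the CPUOffloadingSpec registration above shows; an OffloadingSpec subclass then supplies get_manager() and get_handlers() on top of it. A sketch of registering a custom spec the same way (the module path and class name below are hypothetical, not part of vLLM):

# Illustrative sketch only; "MyOffloadingSpec" and its module path are made up.
from vllm.v1.kv_offload.factory import OffloadingSpecFactory

OffloadingSpecFactory.register_spec(
    "MyOffloadingSpec",               # name used to look the spec up
    "my_project.offloading.my_spec",  # module imported lazily on first use
    "MyOffloadingSpec",               # class name inside that module
)

A custom spec would typically read its own settings from self.extra_config, the same way CPUOffloadingSpec requires num_cpu_blocks via kv_connector_extra_config.
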
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py index 556c29247e5e..eb7117a400b9 100644 --- a/vllm/v1/kv_offload/worker/cpu_gpu.py +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -9,16 +9,21 @@ from vllm.logger import init_logger from vllm.utils import is_pin_memory_available from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec -from vllm.v1.kv_offload.worker.worker import (OffloadingHandler, - TransferResult, TransferSpec) +from vllm.v1.kv_offload.worker.worker import ( + OffloadingHandler, + TransferResult, + TransferSpec, +) logger = init_logger(__name__) -def expand_block_ids(block_ids: np.ndarray, - block_size_factor: int, - output: np.ndarray, - skip_count: int = 0): +def expand_block_ids( + block_ids: np.ndarray, + block_size_factor: int, + output: np.ndarray, + skip_count: int = 0, +): """ Convert a list of block IDs to a list of matching block ids, assuming each block is composed of actual block_size_factor blocks. @@ -47,10 +52,14 @@ def expand_block_ids(block_ids: np.ndarray, class CpuGpuOffloadingHandler(OffloadingHandler): - - def __init__(self, gpu_block_size: int, cpu_block_size: int, - num_cpu_blocks: int, gpu_caches: dict[str, torch.Tensor], - attn_backends: dict[str, type[AttentionBackend]]): + def __init__( + self, + gpu_block_size: int, + cpu_block_size: int, + num_cpu_blocks: int, + gpu_caches: dict[str, torch.Tensor], + attn_backends: dict[str, type[AttentionBackend]], + ): assert cpu_block_size % gpu_block_size == 0 self.block_size_factor = cpu_block_size // gpu_block_size @@ -75,7 +84,8 @@ def __init__(self, gpu_block_size: int, cpu_block_size: int, gpu_shape = gpu_tensor.shape test_shape = attn_backends[layer_name].get_kv_cache_shape( - num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256) + num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 + ) if test_shape[0] == 1234: # shape is (num_blocks, ...) 
num_blocks_idx = 0 @@ -94,10 +104,13 @@ def __init__(self, gpu_block_size: int, cpu_block_size: int, logger.debug("Allocating CPU tensor of shape %r", cpu_shape) self.cpu_tensors.append( - torch.zeros(cpu_shape, - dtype=gpu_tensor.dtype, - device="cpu", - pin_memory=pin_memory)) + torch.zeros( + cpu_shape, + dtype=gpu_tensor.dtype, + device="cpu", + pin_memory=pin_memory, + ) + ) def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: src_spec, dst_spec = spec @@ -122,35 +135,36 @@ def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: assert src_blocks.ndim == 1 assert dst_blocks.ndim == 1 - dst_sub_blocks_to_skip = (-src_blocks.size % dst_block_size_factor) + dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor src_sub_block_count = src_blocks.size * src_block_size_factor assert ( - src_sub_block_count == dst_blocks.size * dst_block_size_factor - - dst_sub_blocks_to_skip) + src_sub_block_count + == dst_blocks.size * dst_block_size_factor - dst_sub_blocks_to_skip + ) src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64) expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0]) - expand_block_ids(dst_blocks, - dst_block_size_factor, - src_to_dst[:, 1], - skip_count=dst_sub_blocks_to_skip) + expand_block_ids( + dst_blocks, + dst_block_size_factor, + src_to_dst[:, 1], + skip_count=dst_sub_blocks_to_skip, + ) src_to_dst_tensor = torch.from_numpy(src_to_dst) - event = self.events_pool.pop() if self.events_pool \ - else torch.cuda.Event() + event = self.events_pool.pop() if self.events_pool else torch.cuda.Event() with torch.cuda.stream(stream): for src_tensor, dst_tensor, kv_dim in zip( - src_tensors, dst_tensors, self.kv_dim_before_num_blocks): + src_tensors, dst_tensors, self.kv_dim_before_num_blocks + ): if kv_dim: src_key_cache = src_tensor[0] dst_key_cache = dst_tensor[0] - ops.swap_blocks(src_key_cache, dst_key_cache, - src_to_dst_tensor) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor) src_value_cache = src_tensor[1] dst_value_cache = dst_tensor[1] - ops.swap_blocks(src_value_cache, dst_value_cache, - src_to_dst_tensor) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor) else: ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor) event.record(stream) diff --git a/vllm/v1/kv_offload/worker/worker.py b/vllm/v1/kv_offload/worker/worker.py index b7a52a088fb9..58ba082497fa 100644 --- a/vllm/v1/kv_offload/worker/worker.py +++ b/vllm/v1/kv_offload/worker/worker.py @@ -74,12 +74,14 @@ class OffloadingWorker: def __init__(self): self.handlers: set[OffloadingHandler] = set() - self.transfer_type_to_handler: dict[TransferType, - OffloadingHandler] = {} - - def register_handler(self, src_cls: type[LoadStoreSpec], - dst_cls: type[LoadStoreSpec], - handler: OffloadingHandler) -> None: + self.transfer_type_to_handler: dict[TransferType, OffloadingHandler] = {} + + def register_handler( + self, + src_cls: type[LoadStoreSpec], + dst_cls: type[LoadStoreSpec], + handler: OffloadingHandler, + ) -> None: """ Registers a new handler. 
@@ -113,19 +115,19 @@ def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: try: success = handler.transfer_async(job_id, spec) except Exception as e: - logger.warning("Exception in %r transfer %d: %r", - transfer_type, - job_id, - e, - exc_info=True) + logger.warning( + "Exception in %r transfer %d: %r", + transfer_type, + job_id, + e, + exc_info=True, + ) return False if not success: - logger.warning("Failed to submit %r transfer %d", transfer_type, - job_id) + logger.warning("Failed to submit %r transfer %d", transfer_type, job_id) else: - logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, - spec) + logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, spec) return success diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index ef95f03e8882..541af7af1725 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -9,8 +9,7 @@ import prometheus_client from vllm.config import SupportsMetricsInfo, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorLogging) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics from vllm.v1.engine import FinishReason @@ -32,26 +31,24 @@ class StatLoggerBase(ABC): """ @abstractmethod - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - ... + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ... @abstractmethod - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): - ... + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ): ... @abstractmethod - def log_engine_initialized(self): - ... + def log_engine_initialized(self): ... 
def log(self): # noqa pass class LoggingStatLogger(StatLoggerBase): - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self.vllm_config = vllm_config @@ -85,21 +82,21 @@ def _get_throughput(self, tracked_stats: int, now: float) -> float: return 0.0 return float(tracked_stats / delta_time) - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ): """Log Stats to standard output.""" if iteration_stats: self._track_iteration_stats(iteration_stats) if scheduler_stats is not None: - self.prefix_caching_metrics.observe( - scheduler_stats.prefix_cache_stats) + self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.observe( - scheduler_stats.spec_decoding_stats) + self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats) if kv_connector_stats := scheduler_stats.kv_connector_stats: self.kv_connector_logging.observe(kv_connector_stats) self.last_scheduler_stats = scheduler_stats @@ -107,8 +104,7 @@ def record(self, def log(self): now = time.monotonic() prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) - generation_throughput = self._get_throughput( - self.num_generation_tokens, now) + generation_throughput = self._get_throughput(self.num_generation_tokens, now) self._reset(now) @@ -116,8 +112,13 @@ def log(self): log_fn = logger.info if not any( - (prompt_throughput, generation_throughput, - self.last_prompt_throughput, self.last_generation_throughput)): + ( + prompt_throughput, + generation_throughput, + self.last_prompt_throughput, + self.last_generation_throughput, + ) + ): # Avoid log noise on an idle production system log_fn = logger.debug self.last_generation_throughput = generation_throughput @@ -146,8 +147,10 @@ def log_engine_initialized(self): if self.vllm_config.cache_config.num_gpu_blocks: logger.info( "Engine %03d: vllm cache_config_info with initialization " - "after num_gpu_blocks is: %d", self.engine_index, - self.vllm_config.cache_config.num_gpu_blocks) + "after num_gpu_blocks is: %d", + self.engine_index, + self.vllm_config.cache_config.num_gpu_blocks, + ) class PrometheusStatLogger(StatLoggerBase): @@ -156,9 +159,9 @@ class PrometheusStatLogger(StatLoggerBase): _histogram_cls = prometheus_client.Histogram _spec_decoding_cls = SpecDecodingProm - def __init__(self, - vllm_config: VllmConfig, - engine_indexes: Optional[list[int]] = None): + def __init__( + self, vllm_config: VllmConfig, engine_indexes: Optional[list[int]] = None + ): if engine_indexes is None: engine_indexes = [0] self.engine_indexes = engine_indexes @@ -167,21 +170,19 @@ def __init__(self, self.vllm_config = vllm_config # Use this flag to hide metrics that were deprecated in # a previous release and which will be removed future - self.show_hidden_metrics = \ - vllm_config.observability_config.show_hidden_metrics + self.show_hidden_metrics = vllm_config.observability_config.show_hidden_metrics labelnames = ["model_name", "engine"] model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len spec_decode_labelvalues: dict[int, list[str]] = { - idx: [model_name, str(idx)] - for idx in engine_indexes + idx: [model_name, str(idx)] for idx in engine_indexes } self.spec_decoding_prom 
= self._spec_decoding_cls( - vllm_config.speculative_config, labelnames, - spec_decode_labelvalues) + vllm_config.speculative_config, labelnames, spec_decode_labelvalues + ) # # Scheduler state @@ -190,19 +191,21 @@ def __init__(self, name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", multiprocess_mode="mostrecent", - labelnames=labelnames) - self.gauge_scheduler_running = make_per_engine(gauge_scheduler_running, - engine_indexes, - model_name) + labelnames=labelnames, + ) + self.gauge_scheduler_running = make_per_engine( + gauge_scheduler_running, engine_indexes, model_name + ) gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", multiprocess_mode="mostrecent", - labelnames=labelnames) - self.gauge_scheduler_waiting = make_per_engine(gauge_scheduler_waiting, - engine_indexes, - model_name) + labelnames=labelnames, + ) + self.gauge_scheduler_waiting = make_per_engine( + gauge_scheduler_waiting, engine_indexes, model_name + ) # # GPU cache @@ -215,11 +218,14 @@ def __init__(self, name="vllm:gpu_cache_usage_perc", documentation=( "GPU KV-cache usage. 1 means 100 percent usage." - "DEPRECATED: Use vllm:kv_cache_usage_perc instead."), + "DEPRECATED: Use vllm:kv_cache_usage_perc instead." + ), multiprocess_mode="mostrecent", - labelnames=labelnames) + labelnames=labelnames, + ) self.gauge_gpu_cache_usage = make_per_engine( - gauge_gpu_cache_usage, engine_indexes, model_name) + gauge_gpu_cache_usage, engine_indexes, model_name + ) # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_queries # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 @@ -231,9 +237,11 @@ def __init__(self, "GPU prefix cache queries, in terms of number of queried" "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead." ), - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_gpu_prefix_cache_queries = make_per_engine( - counter_gpu_prefix_cache_queries, engine_indexes, model_name) + counter_gpu_prefix_cache_queries, engine_indexes, model_name + ) # Deprecated in 0.9.2 - Renamed as vllm:prefix_cache_hits # With 0.11.x you can enable with --show-hidden-metrics-for-version=0.10 @@ -243,33 +251,42 @@ def __init__(self, name="vllm:gpu_prefix_cache_hits", documentation=( "GPU prefix cache hits, in terms of number of cached " - "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."), - labelnames=labelnames) + "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead." + ), + labelnames=labelnames, + ) self.counter_gpu_prefix_cache_hits = make_per_engine( - counter_gpu_prefix_cache_hits, engine_indexes, model_name) + counter_gpu_prefix_cache_hits, engine_indexes, model_name + ) gauge_kv_cache_usage = self._gauge_cls( name="vllm:kv_cache_usage_perc", documentation="KV-cache usage. 1 means 100 percent usage.", - labelnames=labelnames) - self.gauge_kv_cache_usage = make_per_engine(gauge_kv_cache_usage, - engine_indexes, model_name) + labelnames=labelnames, + ) + self.gauge_kv_cache_usage = make_per_engine( + gauge_kv_cache_usage, engine_indexes, model_name + ) counter_prefix_cache_queries = self._counter_cls( name="vllm:prefix_cache_queries", documentation=( - "Prefix cache queries, in terms of number of queried tokens."), - labelnames=labelnames) + "Prefix cache queries, in terms of number of queried tokens." 
+ ), + labelnames=labelnames, + ) self.counter_prefix_cache_queries = make_per_engine( - counter_prefix_cache_queries, engine_indexes, model_name) + counter_prefix_cache_queries, engine_indexes, model_name + ) counter_prefix_cache_hits = self._counter_cls( name="vllm:prefix_cache_hits", - documentation=( - "Prefix cache hits, in terms of number of cached tokens."), - labelnames=labelnames) + documentation=("Prefix cache hits, in terms of number of cached tokens."), + labelnames=labelnames, + ) self.counter_prefix_cache_hits = make_per_engine( - counter_prefix_cache_hits, engine_indexes, model_name) + counter_prefix_cache_hits, engine_indexes, model_name + ) # # Counters @@ -277,36 +294,43 @@ def __init__(self, counter_num_preempted_reqs = self._counter_cls( name="vllm:num_preemptions", documentation="Cumulative number of preemption from the engine.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_num_preempted_reqs = make_per_engine( - counter_num_preempted_reqs, engine_indexes, model_name) + counter_num_preempted_reqs, engine_indexes, model_name + ) counter_prompt_tokens = self._counter_cls( name="vllm:prompt_tokens", documentation="Number of prefill tokens processed.", - labelnames=labelnames) - self.counter_prompt_tokens = make_per_engine(counter_prompt_tokens, - engine_indexes, - model_name) + labelnames=labelnames, + ) + self.counter_prompt_tokens = make_per_engine( + counter_prompt_tokens, engine_indexes, model_name + ) counter_generation_tokens = self._counter_cls( name="vllm:generation_tokens", documentation="Number of generation tokens processed.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_generation_tokens = make_per_engine( - counter_generation_tokens, engine_indexes, model_name) + counter_generation_tokens, engine_indexes, model_name + ) - self.counter_request_success: dict[FinishReason, dict[ - int, prometheus_client.Counter]] = {} + self.counter_request_success: dict[ + FinishReason, dict[int, prometheus_client.Counter] + ] = {} counter_request_success_base = self._counter_cls( name="vllm:request_success", documentation="Count of successfully processed requests.", - labelnames=labelnames + ["finished_reason"]) + labelnames=labelnames + ["finished_reason"], + ) for reason in FinishReason: self.counter_request_success[reason] = { - idx: - counter_request_success_base.labels(model_name, str(idx), - str(reason)) + idx: counter_request_success_base.labels( + model_name, str(idx), str(reason) + ) for idx in engine_indexes } @@ -317,18 +341,21 @@ def __init__(self, name="vllm:request_prompt_tokens", documentation="Number of prefill tokens processed.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_num_prompt_tokens_request = make_per_engine( - histogram_num_prompt_tokens_request, engine_indexes, model_name) + histogram_num_prompt_tokens_request, engine_indexes, model_name + ) histogram_num_generation_tokens_request = self._histogram_cls( name="vllm:request_generation_tokens", documentation="Number of generation tokens processed.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_num_generation_tokens_request = make_per_engine( - histogram_num_generation_tokens_request, engine_indexes, - model_name) + histogram_num_generation_tokens_request, engine_indexes, model_name + ) # TODO: This metric might be incorrect in case of using multiple # api_server counts which uses prometheus mp. 
@@ -336,38 +363,42 @@ def __init__(self, histogram_iteration_tokens = self._histogram_cls( name="vllm:iteration_tokens_total", documentation="Histogram of number of tokens per engine_step.", - buckets=[ - 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 - ], - labelnames=labelnames) + buckets=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + labelnames=labelnames, + ) self.histogram_iteration_tokens = make_per_engine( - histogram_iteration_tokens, engine_indexes, model_name) + histogram_iteration_tokens, engine_indexes, model_name + ) histogram_max_num_generation_tokens_request = self._histogram_cls( name="vllm:request_max_num_generation_tokens", - documentation= - "Histogram of maximum number of requested generation tokens.", + documentation="Histogram of maximum number of requested generation tokens.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_max_num_generation_tokens_request = make_per_engine( - histogram_max_num_generation_tokens_request, engine_indexes, - model_name) + histogram_max_num_generation_tokens_request, engine_indexes, model_name + ) histogram_n_request = self._histogram_cls( name="vllm:request_params_n", documentation="Histogram of the n request parameter.", buckets=[1, 2, 5, 10, 20], - labelnames=labelnames) - self.histogram_n_request = make_per_engine(histogram_n_request, - engine_indexes, model_name) + labelnames=labelnames, + ) + self.histogram_n_request = make_per_engine( + histogram_n_request, engine_indexes, model_name + ) histogram_max_tokens_request = self._histogram_cls( name="vllm:request_params_max_tokens", documentation="Histogram of the max_tokens request parameter.", buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_max_tokens_request = make_per_engine( - histogram_max_tokens_request, engine_indexes, model_name) + histogram_max_tokens_request, engine_indexes, model_name + ) # # Histogram of timing intervals @@ -376,13 +407,34 @@ def __init__(self, name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", buckets=[ - 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, - 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, - 2560.0 + 0.001, + 0.005, + 0.01, + 0.02, + 0.04, + 0.06, + 0.08, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, + 160.0, + 640.0, + 2560.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_time_to_first_token = make_per_engine( - histogram_time_to_first_token, engine_indexes, model_name) + histogram_time_to_first_token, engine_indexes, model_name + ) # Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds # TODO: in 0.12, only enable if show_hidden_metrics=True @@ -390,86 +442,167 @@ def __init__(self, name="vllm:time_per_output_token_seconds", documentation=( "Histogram of time per output token in seconds." - "DEPRECATED: Use vllm:inter_token_latency_seconds instead."), + "DEPRECATED: Use vllm:inter_token_latency_seconds instead." 
+ ), buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_time_per_output_token = make_per_engine( - histogram_time_per_output_token, engine_indexes, model_name) + histogram_time_per_output_token, engine_indexes, model_name + ) histogram_inter_token_latency = self._histogram_cls( name="vllm:inter_token_latency_seconds", documentation="Histogram of inter-token latency in seconds.", buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_inter_token_latency = make_per_engine( - histogram_inter_token_latency, engine_indexes, model_name) + histogram_inter_token_latency, engine_indexes, model_name + ) histogram_request_time_per_output_token = self._histogram_cls( name="vllm:request_time_per_output_token_seconds", - documentation= - "Histogram of time_per_output_token_seconds per request.", + documentation="Histogram of time_per_output_token_seconds per request.", buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, - 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.15, + 0.2, + 0.3, + 0.4, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + 20.0, + 40.0, + 80.0, ], - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_request_time_per_output_token = make_per_engine( - histogram_request_time_per_output_token, engine_indexes, - model_name) + histogram_request_time_per_output_token, engine_indexes, model_name + ) request_latency_buckets = [ - 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, - 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 + 0.3, + 0.5, + 0.8, + 1.0, + 1.5, + 2.0, + 2.5, + 5.0, + 10.0, + 15.0, + 20.0, + 30.0, + 40.0, + 50.0, + 60.0, + 120.0, + 240.0, + 480.0, + 960.0, + 1920.0, + 7680.0, ] histogram_e2e_time_request = self._histogram_cls( name="vllm:e2e_request_latency_seconds", documentation="Histogram of e2e request latency in seconds.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_e2e_time_request = make_per_engine( - histogram_e2e_time_request, engine_indexes, model_name) + histogram_e2e_time_request, engine_indexes, model_name + ) histogram_queue_time_request = self._histogram_cls( name="vllm:request_queue_time_seconds", - documentation= - "Histogram of time spent in WAITING phase for request.", + documentation="Histogram of time spent in WAITING phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_queue_time_request = make_per_engine( - histogram_queue_time_request, engine_indexes, model_name) + histogram_queue_time_request, engine_indexes, model_name + ) histogram_inference_time_request = self._histogram_cls( name="vllm:request_inference_time_seconds", - documentation= - "Histogram of time spent in RUNNING phase for request.", + documentation="Histogram of time spent in RUNNING phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) 
self.histogram_inference_time_request = make_per_engine( - histogram_inference_time_request, engine_indexes, model_name) + histogram_inference_time_request, engine_indexes, model_name + ) histogram_prefill_time_request = self._histogram_cls( name="vllm:request_prefill_time_seconds", - documentation= - "Histogram of time spent in PREFILL phase for request.", + documentation="Histogram of time spent in PREFILL phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_prefill_time_request = make_per_engine( - histogram_prefill_time_request, engine_indexes, model_name) + histogram_prefill_time_request, engine_indexes, model_name + ) histogram_decode_time_request = self._histogram_cls( name="vllm:request_decode_time_seconds", - documentation= - "Histogram of time spent in DECODE phase for request.", + documentation="Histogram of time spent in DECODE phase for request.", buckets=request_latency_buckets, - labelnames=labelnames) + labelnames=labelnames, + ) self.histogram_decode_time_request = make_per_engine( - histogram_decode_time_request, engine_indexes, model_name) + histogram_decode_time_request, engine_indexes, model_name + ) # # LoRA metrics @@ -480,23 +613,21 @@ def __init__(self, self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: if len(self.engine_indexes) > 1: - raise NotImplementedError( - "LoRA in DP mode is not supported yet.") + raise NotImplementedError("LoRA in DP mode is not supported yet.") self.labelname_max_lora = "max_lora" self.labelname_waiting_lora_adapters = "waiting_lora_adapters" self.labelname_running_lora_adapters = "running_lora_adapters" self.max_lora = vllm_config.lora_config.max_loras - self.gauge_lora_info = \ - self._gauge_cls( - name="vllm:lora_requests_info", - documentation="Running stats on lora requests.", - multiprocess_mode="sum", - labelnames=[ - self.labelname_max_lora, - self.labelname_waiting_lora_adapters, - self.labelname_running_lora_adapters, - ], - ) + self.gauge_lora_info = self._gauge_cls( + name="vllm:lora_requests_info", + documentation="Running stats on lora requests.", + multiprocess_mode="sum", + labelnames=[ + self.labelname_max_lora, + self.labelname_waiting_lora_adapters, + self.labelname_running_lora_adapters, + ], + ) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): metrics_info = config_obj.metrics_info() @@ -522,54 +653,65 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): metrics_info["engine"] = str(engine_index) info_gauge.labels(**metrics_info).set(1) - def record(self, - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - engine_idx: int = 0): + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0, + ): """Log to prometheus.""" if scheduler_stats is not None: self.gauge_scheduler_running[engine_idx].set( - scheduler_stats.num_running_reqs) + scheduler_stats.num_running_reqs + ) self.gauge_scheduler_waiting[engine_idx].set( - scheduler_stats.num_waiting_reqs) + scheduler_stats.num_waiting_reqs + ) if self.show_hidden_metrics: self.gauge_gpu_cache_usage[engine_idx].set( - scheduler_stats.kv_cache_usage) - self.gauge_kv_cache_usage[engine_idx].set( - scheduler_stats.kv_cache_usage) + scheduler_stats.kv_cache_usage + ) + self.gauge_kv_cache_usage[engine_idx].set(scheduler_stats.kv_cache_usage) if self.show_hidden_metrics: 
self.counter_gpu_prefix_cache_queries[engine_idx].inc( - scheduler_stats.prefix_cache_stats.queries) + scheduler_stats.prefix_cache_stats.queries + ) self.counter_gpu_prefix_cache_hits[engine_idx].inc( - scheduler_stats.prefix_cache_stats.hits) + scheduler_stats.prefix_cache_stats.hits + ) self.counter_prefix_cache_queries[engine_idx].inc( - scheduler_stats.prefix_cache_stats.queries) + scheduler_stats.prefix_cache_stats.queries + ) self.counter_prefix_cache_hits[engine_idx].inc( - scheduler_stats.prefix_cache_stats.hits) + scheduler_stats.prefix_cache_stats.hits + ) if scheduler_stats.spec_decoding_stats is not None: self.spec_decoding_prom.observe( - scheduler_stats.spec_decoding_stats, engine_idx) + scheduler_stats.spec_decoding_stats, engine_idx + ) if iteration_stats is None: return self.counter_num_preempted_reqs[engine_idx].inc( - iteration_stats.num_preempted_reqs) - self.counter_prompt_tokens[engine_idx].inc( - iteration_stats.num_prompt_tokens) + iteration_stats.num_preempted_reqs + ) + self.counter_prompt_tokens[engine_idx].inc(iteration_stats.num_prompt_tokens) self.counter_generation_tokens[engine_idx].inc( - iteration_stats.num_generation_tokens) + iteration_stats.num_generation_tokens + ) self.histogram_iteration_tokens[engine_idx].observe( - iteration_stats.num_prompt_tokens + \ - iteration_stats.num_generation_tokens) + iteration_stats.num_prompt_tokens + iteration_stats.num_generation_tokens + ) for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter: - self.histogram_max_num_generation_tokens_request[ - engine_idx].observe(max_gen_tokens) + self.histogram_max_num_generation_tokens_request[engine_idx].observe( + max_gen_tokens + ) for n_param in iteration_stats.n_params_iter: self.histogram_n_request[engine_idx].observe(n_param) for ttft in iteration_stats.time_to_first_tokens_iter: @@ -579,40 +721,51 @@ def record(self, self.histogram_time_per_output_token[engine_idx].observe(itl) for finished_request in iteration_stats.finished_requests: - self.counter_request_success[ - finished_request.finish_reason][engine_idx].inc() + self.counter_request_success[finished_request.finish_reason][ + engine_idx + ].inc() self.histogram_e2e_time_request[engine_idx].observe( - finished_request.e2e_latency) + finished_request.e2e_latency + ) self.histogram_queue_time_request[engine_idx].observe( - finished_request.queued_time) + finished_request.queued_time + ) self.histogram_prefill_time_request[engine_idx].observe( - finished_request.prefill_time) + finished_request.prefill_time + ) self.histogram_inference_time_request[engine_idx].observe( - finished_request.inference_time) + finished_request.inference_time + ) self.histogram_decode_time_request[engine_idx].observe( - finished_request.decode_time) + finished_request.decode_time + ) self.histogram_num_prompt_tokens_request[engine_idx].observe( - finished_request.num_prompt_tokens) + finished_request.num_prompt_tokens + ) self.histogram_num_generation_tokens_request[engine_idx].observe( - finished_request.num_generation_tokens) + finished_request.num_generation_tokens + ) self.histogram_request_time_per_output_token[engine_idx].observe( - finished_request.mean_time_per_output_token) + finished_request.mean_time_per_output_token + ) if finished_request.max_tokens_param: self.histogram_max_tokens_request[engine_idx].observe( - finished_request.max_tokens_param) + finished_request.max_tokens_param + ) if self.gauge_lora_info is not None: - running_lora_adapters = \ - ",".join(iteration_stats.running_lora_adapters.keys()) - 
waiting_lora_adapters = \ - ",".join(iteration_stats.waiting_lora_adapters.keys()) + running_lora_adapters = ",".join( + iteration_stats.running_lora_adapters.keys() + ) + waiting_lora_adapters = ",".join( + iteration_stats.waiting_lora_adapters.keys() + ) lora_info_labels = { self.labelname_running_lora_adapters: running_lora_adapters, self.labelname_waiting_lora_adapters: waiting_lora_adapters, self.labelname_max_lora: self.max_lora, } - self.gauge_lora_info.labels(**lora_info_labels)\ - .set_to_current_time() + self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time() def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) @@ -625,8 +778,9 @@ def log_engine_initialized(self): ] -def make_per_engine(metric: PromMetric, engine_idxs: list[int], - model_name: str) -> dict[int, PromMetric]: +def make_per_engine( + metric: PromMetric, engine_idxs: list[int], model_name: str +) -> dict[int, PromMetric]: return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs} @@ -688,7 +842,8 @@ def __init__( if client_count > 1: logger.warning( "AsyncLLM created with api_server_count more than 1; " - "disabling stats logging to avoid incomplete stats.") + "disabling stats logging to avoid incomplete stats." + ) else: factories.append(LoggingStatLogger) @@ -700,12 +855,12 @@ def __init__( for logger_factory in factories: # If we get a custom prometheus logger, use that # instead. This is typically used for the ray case. - if (isinstance(logger_factory, type) - and issubclass(logger_factory, PrometheusStatLogger)): + if isinstance(logger_factory, type) and issubclass( + logger_factory, PrometheusStatLogger + ): prometheus_factory = logger_factory continue - loggers.append(logger_factory(vllm_config, - engine_idx)) # type: ignore + loggers.append(logger_factory(vllm_config, engine_idx)) # type: ignore self.per_engine_logger_dict[engine_idx] = loggers # For Prometheus, need to share the metrics between EngineCores. @@ -725,8 +880,7 @@ def record( for logger in per_engine_loggers: logger.record(scheduler_stats, iteration_stats, engine_idx) - self.prometheus_logger.record(scheduler_stats, iteration_stats, - engine_idx) + self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx) def log(self): for per_engine_loggers in self.per_engine_logger_dict.values(): diff --git a/vllm/v1/metrics/prometheus.py b/vllm/v1/metrics/prometheus.py index a43cf9ce255e..5823737968f9 100644 --- a/vllm/v1/metrics/prometheus.py +++ b/vllm/v1/metrics/prometheus.py @@ -16,9 +16,7 @@ def setup_multiprocess_prometheus(): - """Set up prometheus multiprocessing directory if not already configured. - - """ + """Set up prometheus multiprocessing directory if not already configured.""" global _prometheus_multiproc_dir if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: @@ -27,19 +25,22 @@ def setup_multiprocess_prometheus(): # cleaned up upon exit. _prometheus_multiproc_dir = tempfile.TemporaryDirectory() os.environ["PROMETHEUS_MULTIPROC_DIR"] = _prometheus_multiproc_dir.name - logger.debug("Created PROMETHEUS_MULTIPROC_DIR at %s", - _prometheus_multiproc_dir.name) + logger.debug( + "Created PROMETHEUS_MULTIPROC_DIR at %s", _prometheus_multiproc_dir.name + ) else: - logger.warning("Found PROMETHEUS_MULTIPROC_DIR was set by user. " - "This directory must be wiped between vLLM runs or " - "you will find inaccurate metrics. Unset the variable " - "and vLLM will properly handle cleanup.") + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. 
" + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup." + ) def get_prometheus_registry() -> CollectorRegistry: - """Get the appropriate prometheus registry based on multiprocessing + """Get the appropriate prometheus registry based on multiprocessing configuration. - + Returns: Registry: A prometheus registry """ @@ -54,11 +55,11 @@ def get_prometheus_registry() -> CollectorRegistry: def unregister_vllm_metrics(): """Unregister any existing vLLM collectors from the prometheus registry. - + This is useful for testing and CI/CD where metrics may be registered multiple times across test runs. - - Also, in case of multiprocess, we need to unregister the metrics from the + + Also, in case of multiprocess, we need to unregister the metrics from the global registry. """ registry = REGISTRY diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index 609185753860..a6fe2062f70c 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -15,11 +15,9 @@ class RayPrometheusMetric: - def __init__(self): if ray_metrics is None: - raise ImportError( - "RayPrometheusMetric requires Ray to be installed.") + raise ImportError("RayPrometheusMetric requires Ray to be installed.") self.metric: Metric = None @@ -38,15 +36,14 @@ def labels(self, *labels, **labelskwargs): f"Expected {len(self.metric._tag_keys)}, got {len(labels)}" ) - self.metric.set_default_tags( - dict(zip(self.metric._tag_keys, labels))) + self.metric.set_default_tags(dict(zip(self.metric._tag_keys, labels))) return self @staticmethod def _get_sanitized_opentelemetry_name(name: str) -> str: """ - For compatibility with Ray + OpenTelemetry, the metric name must be + For compatibility with Ray + OpenTelemetry, the metric name must be sanitized. In particular, this replaces disallowed character (e.g., ':') with '_' in the metric name. Allowed characters: a-z, A-Z, 0-9, _ @@ -63,21 +60,22 @@ class RayGaugeWrapper(RayPrometheusMetric): """Wraps around ray.util.metrics.Gauge to provide same API as prometheus_client.Gauge""" - def __init__(self, - name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None, - multiprocess_mode: Optional[str] = ""): - + def __init__( + self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + multiprocess_mode: Optional[str] = "", + ): # All Ray metrics are keyed by WorkerId, so multiprocess modes like # "mostrecent", "all", "sum" do not apply. This logic can be manually # implemented at the observability layer (Prometheus/Grafana). 
del multiprocess_mode labelnames_tuple = tuple(labelnames) if labelnames else None name = self._get_sanitized_opentelemetry_name(name) - self.metric = ray_metrics.Gauge(name=name, - description=documentation, - tag_keys=labelnames_tuple) + self.metric = ray_metrics.Gauge( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def set(self, value: Union[int, float]): return self.metric.set(value) @@ -91,15 +89,17 @@ class RayCounterWrapper(RayPrometheusMetric): """Wraps around ray.util.metrics.Counter to provide same API as prometheus_client.Counter""" - def __init__(self, - name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None): + def __init__( + self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + ): labelnames_tuple = tuple(labelnames) if labelnames else None name = self._get_sanitized_opentelemetry_name(name) - self.metric = ray_metrics.Counter(name=name, - description=documentation, - tag_keys=labelnames_tuple) + self.metric = ray_metrics.Counter( + name=name, description=documentation, tag_keys=labelnames_tuple + ) def inc(self, value: Union[int, float] = 1.0): if value == 0: @@ -111,18 +111,22 @@ class RayHistogramWrapper(RayPrometheusMetric): """Wraps around ray.util.metrics.Histogram to provide same API as prometheus_client.Histogram""" - def __init__(self, - name: str, - documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None, - buckets: Optional[list[float]] = None): + def __init__( + self, + name: str, + documentation: Optional[str] = "", + labelnames: Optional[list[str]] = None, + buckets: Optional[list[float]] = None, + ): labelnames_tuple = tuple(labelnames) if labelnames else None name = self._get_sanitized_opentelemetry_name(name) boundaries = buckets if buckets else [] - self.metric = ray_metrics.Histogram(name=name, - description=documentation, - tag_keys=labelnames_tuple, - boundaries=boundaries) + self.metric = ray_metrics.Histogram( + name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=boundaries, + ) def observe(self, value: Union[int, float]): return self.metric.observe(value) diff --git a/vllm/v1/metrics/reader.py b/vllm/v1/metrics/reader.py index 4d6e59984154..5d50fa9461d0 100644 --- a/vllm/v1/metrics/reader.py +++ b/vllm/v1/metrics/reader.py @@ -17,6 +17,7 @@ class Metric: in some cases a single vLLM instance may have multiple metrics with the same name but different sets of labels. """ + name: str labels: dict[str, str] @@ -24,6 +25,7 @@ class Metric: @dataclass class Counter(Metric): """A monotonically increasing integer counter.""" + value: int @@ -34,12 +36,14 @@ class Vector(Metric): This type - which doesn't exist in Prometheus - models one very specific metric, vllm:spec_decode_num_accepted_tokens_per_pos. """ + values: list[int] @dataclass class Gauge(Metric): """A numerical value that can go up or down.""" + value: float @@ -58,6 +62,7 @@ class Histogram(Metric): The sum property is the total sum of all observed values. 
""" + count: int sum: float buckets: dict[str, int] @@ -87,7 +92,8 @@ def get_metrics_snapshot() -> list[Metric]: samples = _get_samples(metric) for s in samples: collected.append( - Gauge(name=metric.name, labels=s.labels, value=s.value)) + Gauge(name=metric.name, labels=s.labels, value=s.value) + ) elif metric.type == "counter": samples = _get_samples(metric, "_total") if metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos": @@ -99,16 +105,15 @@ def get_metrics_snapshot() -> list[Metric]: # accepted tokens using a Counter labeled with 'position'. # We convert these into a vector of integer values. # - for labels, values in _digest_num_accepted_by_pos_samples( - samples): + for labels, values in _digest_num_accepted_by_pos_samples(samples): collected.append( - Vector(name=metric.name, labels=labels, values=values)) + Vector(name=metric.name, labels=labels, values=values) + ) else: for s in samples: collected.append( - Counter(name=metric.name, - labels=s.labels, - value=int(s.value))) + Counter(name=metric.name, labels=s.labels, value=int(s.value)) + ) elif metric.type == "histogram": # @@ -122,21 +127,24 @@ def get_metrics_snapshot() -> list[Metric]: count_samples = _get_samples(metric, "_count") sum_samples = _get_samples(metric, "_sum") for labels, buckets, count_value, sum_value in _digest_histogram( - bucket_samples, count_samples, sum_samples): + bucket_samples, count_samples, sum_samples + ): collected.append( - Histogram(name=metric.name, - labels=labels, - buckets=buckets, - count=count_value, - sum=sum_value)) + Histogram( + name=metric.name, + labels=labels, + buckets=buckets, + count=count_value, + sum=sum_value, + ) + ) else: raise AssertionError(f"Unknown metric type {metric.type}") return collected -def _get_samples(metric: PromMetric, - suffix: Optional[str] = None) -> list[Sample]: +def _get_samples(metric: PromMetric, suffix: Optional[str] = None) -> list[Sample]: name = (metric.name + suffix) if suffix is not None else metric.name return [s for s in metric.samples if s.name == name] @@ -148,8 +156,7 @@ def _strip_label(labels: dict[str, str], key_to_remove: str) -> dict[str, str]: def _digest_histogram( - bucket_samples: list[Sample], count_samples: list[Sample], - sum_samples: list[Sample] + bucket_samples: list[Sample], count_samples: list[Sample], sum_samples: list[Sample] ) -> list[tuple[dict[str, str], dict[str, int], int, float]]: # # In the case of DP, we have an indigestable @@ -192,20 +199,25 @@ def _digest_histogram( labels_key = frozenset(s.labels.items()) sums_by_labels[labels_key] = s.value - assert set(buckets_by_labels.keys()) == set( - counts_by_labels.keys()) == set(sums_by_labels.keys()) + assert ( + set(buckets_by_labels.keys()) + == set(counts_by_labels.keys()) + == set(sums_by_labels.keys()) + ) output = [] label_keys = list(buckets_by_labels.keys()) for k in label_keys: labels = dict(k) - output.append((labels, buckets_by_labels[k], counts_by_labels[k], - sums_by_labels[k])) + output.append( + (labels, buckets_by_labels[k], counts_by_labels[k], sums_by_labels[k]) + ) return output def _digest_num_accepted_by_pos_samples( - samples: list[Sample]) -> list[tuple[dict[str, str], list[int]]]: + samples: list[Sample], +) -> list[tuple[dict[str, str], list[int]]]: # # In the case of DP, we have an indigestable # per-position-per-engine count as a list of diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index a0d571318ba0..5564718d5165 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -15,6 +15,7 @@ @dataclass 
class PrefixCacheStats: """Stores prefix cache hit statistics.""" + # Whether reset_prefix_cache was invoked. reset: bool = False # The number of new requests in this update. @@ -45,8 +46,7 @@ class SchedulerStats: kv_cache_usage: float = 0.0 - prefix_cache_stats: PrefixCacheStats = field( - default_factory=PrefixCacheStats) + prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats) spec_decoding_stats: Optional[SpecDecodingStats] = None kv_connector_stats: Optional[dict[str, Any]] = None @@ -111,14 +111,23 @@ def __init__(self): self.waiting_lora_adapters: dict[str, int] = {} self.running_lora_adapters: dict[str, int] = {} + def __repr__(self) -> str: + field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items()) + return f"{self.__class__.__name__}({field_to_value_str})" + def _time_since(self, start: float) -> float: """Calculate an interval relative to this iteration's timestamp.""" return self.iteration_timestamp - start - def update_from_output(self, output: "EngineCoreOutput", - engine_core_timestamp: float, is_prefilling: bool, - prompt_len: int, req_stats: RequestStateStats, - lora_stats: Optional[LoRAStats]): + def update_from_output( + self, + output: "EngineCoreOutput", + engine_core_timestamp: float, + is_prefilling: bool, + prompt_len: int, + req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats], + ): num_new_generation_tokens = len(output.new_token_ids) self.num_generation_tokens += num_new_generation_tokens @@ -133,8 +142,9 @@ def update_from_output(self, output: "EngineCoreOutput", # Process request-level engine core events if output.events is not None: - self.update_from_events(output.request_id, output.events, - is_prefilling, req_stats, lora_stats) + self.update_from_events( + output.request_id, output.events, is_prefilling, req_stats, lora_stats + ) # Process the batch-level "new tokens" engine core event if is_prefilling: @@ -145,11 +155,17 @@ def update_from_output(self, output: "EngineCoreOutput", req_stats.last_token_ts = engine_core_timestamp - def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], - is_prefilling: bool, req_stats: RequestStateStats, - lora_stats: Optional[LoRAStats]): + def update_from_events( + self, + req_id: str, + events: list["EngineCoreEvent"], + is_prefilling: bool, + req_stats: RequestStateStats, + lora_stats: Optional[LoRAStats], + ): # Avoid circular dependency from vllm.v1.engine import EngineCoreEventType + for event in events: if event.type == EngineCoreEventType.QUEUED: req_stats.queued_ts = event.timestamp @@ -163,10 +179,13 @@ def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], self.num_preempted_reqs += 1 LoRARequestStates.preempted_request(lora_stats, req_id) - def update_from_finished_request(self, finish_reason: "FinishReason", - num_prompt_tokens: int, - max_tokens_param: Optional[int], - req_stats: RequestStateStats): + def update_from_finished_request( + self, + finish_reason: "FinishReason", + num_prompt_tokens: int, + max_tokens_param: Optional[int], + req_stats: RequestStateStats, + ): e2e_latency = self._time_since(req_stats.arrival_time) # Queued interval is from first QUEUED event to first SCHEDULED @@ -185,22 +204,24 @@ def update_from_finished_request(self, finish_reason: "FinishReason", inference_time = req_stats.last_token_ts - req_stats.scheduled_ts # Do not count the token generated by the prefill phase - mean_time_per_output_token = (decode_time / - (req_stats.num_generation_tokens - 1) - if req_stats.num_generation_tokens - 
- 1 > 0 else 0) - - finished_req = \ - FinishedRequestStats(finish_reason=finish_reason, - e2e_latency=e2e_latency, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=req_stats.num_generation_tokens, - max_tokens_param=max_tokens_param, - queued_time=queued_time, - prefill_time=prefill_time, - inference_time=inference_time, - decode_time=decode_time, - mean_time_per_output_token=mean_time_per_output_token) + mean_time_per_output_token = ( + decode_time / (req_stats.num_generation_tokens - 1) + if req_stats.num_generation_tokens - 1 > 0 + else 0 + ) + + finished_req = FinishedRequestStats( + finish_reason=finish_reason, + e2e_latency=e2e_latency, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=req_stats.num_generation_tokens, + max_tokens_param=max_tokens_param, + queued_time=queued_time, + prefill_time=prefill_time, + inference_time=inference_time, + decode_time=decode_time, + mean_time_per_output_token=mean_time_per_output_token, + ) self.finished_requests.append(finished_req) @@ -210,24 +231,24 @@ class LoRARequestStates: def __init__(self): self.lora_name_to_stats: dict[str, LoRAStats] = {} - def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]: + def get_stats(self, req_state: "RequestState") -> Optional[LoRAStats]: if req_state.lora_name is None: return None if req_state.lora_name not in self.lora_name_to_stats: self.lora_name_to_stats[req_state.lora_name] = LoRAStats() return self.lora_name_to_stats[req_state.lora_name] - def add_request(self, req_state: 'RequestState'): + def add_request(self, req_state: "RequestState"): if (lora_stats := self.get_stats(req_state)) is not None: lora_stats.waiting_requests.add(req_state.request_id) - def finish_request(self, req_state: 'RequestState'): + def finish_request(self, req_state: "RequestState"): if req_state.lora_name is None: return lora_stats = self.lora_name_to_stats[req_state.lora_name] lora_stats.running_requests.remove(req_state.request_id) - def abort_request(self, req_state: 'RequestState'): + def abort_request(self, req_state: "RequestState"): if req_state.lora_name is None: return lora_stats = self.lora_name_to_stats[req_state.lora_name] @@ -250,14 +271,15 @@ def preempted_request(lora_stats: Optional[LoRAStats], request_id: str): lora_stats.running_requests.remove(request_id) lora_stats.waiting_requests.add(request_id) - def update_iteration_stats(self, - iteration_stats: Optional[IterationStats]): + def update_iteration_stats(self, iteration_stats: Optional[IterationStats]): if iteration_stats is None: return for lora_name, stats in self.lora_name_to_stats.items(): if stats.waiting_requests: - iteration_stats.waiting_lora_adapters[lora_name] = \ - len(stats.waiting_requests) + iteration_stats.waiting_lora_adapters[lora_name] = len( + stats.waiting_requests + ) if stats.running_requests: - iteration_stats.running_lora_adapters[lora_name] = \ - len(stats.running_requests) + iteration_stats.running_lora_adapters[lora_name] = len( + stats.running_requests + ) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index d15cdf365962..d647b207575c 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -8,12 +8,10 @@ import torch if TYPE_CHECKING: - from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats class LogprobsLists(NamedTuple): - # [num_reqs, max_num_logprobs + 1] logprob_token_ids: list[list[int]] # [num_reqs, max_num_logprobs + 1] @@ -30,7 +28,6 @@ def slice(self, 
start: int, end: int): class LogprobsTensors(NamedTuple): - # [num_reqs, max_num_logprobs + 1] logprob_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] @@ -46,18 +43,18 @@ def tolists(self): ) @staticmethod - def empty_cpu(num_positions: int, - num_tokens_per_position: int) -> "LogprobsTensors": + def empty_cpu( + num_positions: int, num_tokens_per_position: int + ) -> "LogprobsTensors": """Create empty LogprobsTensors on CPU.""" logprob_token_ids = torch.empty( - (num_positions, num_tokens_per_position), - dtype=torch.int32, - device="cpu") + (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu" + ) logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32) - selected_token_ranks = torch.empty(num_positions, - dtype=torch.int32, - device="cpu") + selected_token_ranks = torch.empty( + num_positions, dtype=torch.int32, device="cpu" + ) return LogprobsTensors( logprob_token_ids=logprob_token_ids, logprobs=logprobs, @@ -72,7 +69,6 @@ def empty_cpu(num_positions: int, @dataclass class SamplerOutput: - # [num_reqs, max_num_generated_tokens] # Different requests can have different number of generated tokens. # All requests are padded to max_num_generated_tokens. @@ -92,15 +88,18 @@ class KVConnectorOutput: invalid_block_ids: set[int] = field(default_factory=set) def is_empty(self): - return (not self.finished_sending and not self.finished_recving - and not self.kv_connector_stats and not self.invalid_block_ids) + return ( + not self.finished_sending + and not self.finished_recving + and not self.kv_connector_stats + and not self.invalid_block_ids + ) # ModelRunnerOutput is serialized and sent to the scheduler process. # This is expensive for torch.Tensor so prefer to use list instead. @dataclass class ModelRunnerOutput: - # [num_reqs] req_ids: list[str] # req_id -> index @@ -134,11 +133,10 @@ class ModelRunnerOutput: # ModelRunnerOutput wrapper for async scheduling. class AsyncModelRunnerOutput(ABC): - @abstractmethod def get_output(self) -> ModelRunnerOutput: """Get the ModelRunnerOutput for this async output. - + This is a blocking call that waits until the results are ready, which might involve copying device tensors to the host. This method should only be called once per AsyncModelRunnerOutput. 
@@ -148,17 +146,18 @@ def get_output(self) -> ModelRunnerOutput: @dataclass class DraftTokenIds: - # [num_reqs] req_ids: list[str] # num_reqs x num_draft_tokens draft_token_ids: list[list[int]] -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - num_nans_in_logits=None) +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + num_nans_in_logits=None, +) diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 46506d272e90..36ae5b40a313 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -29,13 +29,13 @@ def __getitem__(self, indices: slice): ) def is_partial_prefill(self): - return not torch.all( - self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) + return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) @dataclass class PoolingMetadata: """Tensors for pooling.""" + prompt_lens: torch.Tensor # CPU Tensor prompt_token_ids: Optional[torch.Tensor] pooling_params: list[PoolingParams] @@ -44,34 +44,40 @@ class PoolingMetadata: def __getitem__(self, indices: slice): return PoolingMetadata( prompt_lens=self.prompt_lens[indices], - prompt_token_ids=None if self.prompt_token_ids is None else - self.prompt_token_ids[indices], + prompt_token_ids=None + if self.prompt_token_ids is None + else self.prompt_token_ids[indices], pooling_params=self.pooling_params[indices], pooling_cursor=None - if self.pooling_cursor is None else self.pooling_cursor[indices], + if self.pooling_cursor is None + else self.pooling_cursor[indices], ) - def build_pooling_cursor(self, num_scheduled_tokens: list[int], - device: torch.device): - self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, - self.prompt_lens, device) + def build_pooling_cursor( + self, num_scheduled_tokens: list[int], device: torch.device + ): + self.pooling_cursor = build_pooling_cursor( + num_scheduled_tokens, self.prompt_lens, device + ) -def build_pooling_cursor(num_scheduled_tokens: list[int], - prompt_lens: torch.Tensor, device: torch.device): +def build_pooling_cursor( + num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device +): assert len(prompt_lens) == len(num_scheduled_tokens) n_seq = len(num_scheduled_tokens) index = list(range(n_seq)) num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu") - cumsum = torch.zeros(n_seq + 1, - dtype=torch.int64, - pin_memory=pin_memory, - device="cpu") + cumsum = torch.zeros( + n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu" + ) torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:]) cumsum = cumsum.to(device, non_blocking=True) - return PoolingCursor(index=index, - first_token_indices_gpu=cumsum[:n_seq], - last_token_indices_gpu=cumsum[1:] - 1, - prompt_lens_cpu=prompt_lens, - num_scheduled_tokens_cpu=num_scheduled_tokens) + return PoolingCursor( + index=index, + first_token_indices_gpu=cumsum[:n_seq], + last_token_indices_gpu=cumsum[1:] - 1, + prompt_lens_cpu=prompt_lens, + num_scheduled_tokens_cpu=num_scheduled_tokens, + ) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index dd0aea645d74..ac6e583099bc 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -13,8 +13,12 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import length_from_prompt_token_ids_or_embeds -from 
vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, - EngineCoreRequest, FinishReason) +from vllm.v1.engine import ( + EngineCoreEvent, + EngineCoreEventType, + EngineCoreRequest, + FinishReason, +) from vllm.v1.structured_output.request import StructuredOutputRequest from vllm.v1.utils import ConstantList @@ -24,7 +28,6 @@ class Request: - def __init__( self, request_id: str, @@ -41,8 +44,7 @@ def __init__( cache_salt: Optional[str] = None, priority: int = 0, trace_headers: Optional[Mapping[str, str]] = None, - block_hasher: Optional[Callable[["Request"], - list["BlockHash"]]] = None, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] = None, ) -> None: self.request_id = request_id self.client_index = client_index @@ -53,8 +55,7 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.structured_output_request = structured_output_request - self.arrival_time = arrival_time if arrival_time is not None else \ - time.time() + self.arrival_time = arrival_time if arrival_time is not None else time.time() self.status = RequestStatus.WAITING self.use_structured_output = False @@ -76,20 +77,23 @@ def __init__( self.use_structured_output = True if sampling_params.extra_args is not None: - self.kv_transfer_params = \ - sampling_params.extra_args.get("kv_transfer_params") + self.kv_transfer_params = sampling_params.extra_args.get( + "kv_transfer_params" + ) else: - raise ValueError( - "sampling_params and pooling_params can't both be unset") + raise ValueError("sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids self.prompt_embeds = prompt_embeds self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - prompt_token_ids, prompt_embeds) + prompt_token_ids, prompt_embeds + ) self._output_token_ids: list[int] = [] - self._all_token_ids: list[int] = self.prompt_token_ids.copy( - ) if self.prompt_token_ids is not None else [0 - ] * self.num_prompt_tokens + self._all_token_ids: list[int] = ( + self.prompt_token_ids.copy() + if self.prompt_token_ids is not None + else [0] * self.num_prompt_tokens + ) self.num_output_placeholders = 0 # Used in async scheduling. 
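Editorial aside, not part of the diff: the metrics hunk at the top of this section computes mean time per output token by dividing decode time by the number of inter-token gaps, `num_generation_tokens - 1`, and defines it as 0 when only one token was generated. A minimal worked sketch of that formula (illustrative values only):

```python
# Sketch of the mean_time_per_output_token guard from the metrics hunk above.
def mean_time_per_output_token(decode_time: float, num_generation_tokens: int) -> float:
    gaps = num_generation_tokens - 1          # time is measured between tokens
    return decode_time / gaps if gaps > 0 else 0.0

assert mean_time_per_output_token(2.0, 5) == 0.5   # 4 gaps over 2 seconds
assert mean_time_per_output_token(2.0, 1) == 0.0   # single token: no gaps
```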
self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 @@ -119,16 +123,16 @@ def __init__( self.num_preemptions = 0 self.block_hashes: list[BlockHash] = [] - self.get_hash_new_full_blocks: Optional[Callable[ - [], list[BlockHash]]] = None + self.get_hash_new_full_blocks: Optional[Callable[[], list[BlockHash]]] = None if block_hasher is not None: self.get_hash_new_full_blocks = partial(block_hasher, self) self.block_hashes = self.get_hash_new_full_blocks() @classmethod def from_engine_core_request( - cls, request: EngineCoreRequest, - block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] + cls, + request: EngineCoreRequest, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]], ) -> "Request": return cls( request_id=request.request_id, @@ -142,8 +146,10 @@ def from_engine_core_request( arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( - sampling_params=request.sampling_params) \ - if request.sampling_params else None, + sampling_params=request.sampling_params + ) + if request.sampling_params + else None, cache_salt=request.cache_salt, priority=request.priority, trace_headers=request.trace_headers, @@ -207,6 +213,7 @@ def take_events(self) -> Optional[list[EngineCoreEvent]]: class RequestStatus(enum.IntEnum): """Status of a request.""" + WAITING = enum.auto() WAITING_FOR_FSM = enum.auto() WAITING_FOR_REMOTE_KVS = enum.auto() @@ -227,8 +234,7 @@ def is_finished(status: "RequestStatus") -> bool: return status > RequestStatus.PREEMPTED @staticmethod - def get_finished_reason( - status: "RequestStatus") -> Union[FinishReason, None]: + def get_finished_reason(status: "RequestStatus") -> Union[FinishReason, None]: return _FINISHED_REASON_MAP.get(status) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 10cad5b53071..98c4d8bad02d 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -13,15 +13,18 @@ from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor from vllm.sampling_params import SamplingParams -from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, - MinPLogitsProcessor, - MinTokensLogitsProcessor, - process_dict_updates) -from vllm.v1.sample.logits_processor.interface import (BatchUpdate, - LogitsProcessor, - MoveDirectionality) -from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder, - LogitsProcessors) +from vllm.v1.sample.logits_processor.builtin import ( + LogitBiasLogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + process_dict_updates, +) +from vllm.v1.sample.logits_processor.interface import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) +from vllm.v1.sample.logits_processor.state import BatchUpdateBuilder, LogitsProcessors if TYPE_CHECKING: from vllm.config import VllmConfig @@ -30,10 +33,11 @@ # Error message when the user tries to initialize vLLM with a pooling model # and custom logitsproces -STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" - " logits processors.") +STR_POOLING_REJECTS_LOGITSPROCS = ( + "Pooling models do not support custom logits processors." 
+) -LOGITSPROCS_GROUP = 'vllm.logits_processors' +LOGITSPROCS_GROUP = "vllm.logits_processors" BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ MinTokensLogitsProcessor, @@ -54,27 +58,29 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP) if len(installed_logitsprocs_plugins) == 0: - logger.debug("No logitsprocs plugins installed (group %s).", - LOGITSPROCS_GROUP) + logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP) return [] # Load logitsprocs plugins - logger.debug("Loading installed logitsprocs plugins (group %s):", - LOGITSPROCS_GROUP) + logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP) classes: list[type[LogitsProcessor]] = [] for entrypoint in installed_logitsprocs_plugins: try: - logger.debug("- Loading logitproc plugin entrypoint=%s target=%s", - entrypoint.name, entrypoint.value) + logger.debug( + "- Loading logitproc plugin entrypoint=%s target=%s", + entrypoint.name, + entrypoint.value, + ) classes.append(entrypoint.load()) except Exception as e: raise RuntimeError( - f"Failed to load LogitsProcessor plugin {entrypoint}") from e + f"Failed to load LogitsProcessor plugin {entrypoint}" + ) from e return classes def _load_logitsprocs_by_fqcns( - logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]] + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]], ) -> list[type[LogitsProcessor]]: """Load logit processor types, identifying them by fully-qualified class names (FQCNs). @@ -99,13 +105,14 @@ def _load_logitsprocs_by_fqcns( logger.debug( "%s additional custom logits processors specified, checking whether " - "they need to be loaded.", len(logits_processors)) + "they need to be loaded.", + len(logits_processors), + ) classes: list[type[LogitsProcessor]] = [] for ldx, logitproc in enumerate(logits_processors): if isinstance(logitproc, type): - logger.debug(" - Already-loaded logit processor: %s", - logitproc.__name__) + logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__) if not issubclass(logitproc, LogitsProcessor): raise ValueError( f"{logitproc.__name__} is not a subclass of LogitsProcessor" @@ -131,8 +138,7 @@ def _load_logitsprocs_by_fqcns( if not isinstance(obj, type): raise ValueError("Loaded logit processor must be a type.") if not issubclass(obj, LogitsProcessor): - raise ValueError( - f"{obj.__name__} must be a subclass of LogitsProcessor") + raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor") classes.append(obj) return classes @@ -155,13 +161,13 @@ def _load_custom_logitsprocs( A list of all loaded logitproc types """ from vllm.platforms import current_platform + if current_platform.is_tpu(): # No logitsprocs specified by caller # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs return [] - return (_load_logitsprocs_plugins() + - _load_logitsprocs_by_fqcns(logits_processors)) + return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors) def build_logitsprocs( @@ -174,23 +180,28 @@ def build_logitsprocs( if is_pooling_model: if custom_logitsprocs: raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) - logger.debug("Skipping logits processor loading because pooling models" - " do not support logits processors.") + logger.debug( + "Skipping logits processor loading because pooling models" + " do not support logits processors." 
+ ) return LogitsProcessors() custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) return LogitsProcessors( - ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( - BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) + ctor(vllm_config, device, is_pin_memory) + for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes + ) + ) class AdapterLogitsProcessor(LogitsProcessor): """Wrapper for per-request logits processors - + To wrap a specific per-request logits processor, * Subclass `AdapterLogitsProcessor` * Implement `self.is_argmax_invariant()` base-class method * Implement `self.new_req_logits_processor(params)` - + `self.__init__(vllm_config, device, is_pin_memory)` does not need to be overridden in general. However, to implement custom constructor behavior - especially any logic which operates on or stores `vllm_config`, `device`, @@ -199,8 +210,9 @@ class AdapterLogitsProcessor(LogitsProcessor): `super().__init__(vllm_config, device, is_pin_memory)` """ - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): """Subclass must invoke `super().__init__(vllm_config, device, is_pin_memory)`. @@ -236,7 +248,7 @@ def new_req_logits_processor( Returns: None if logits processor should not be applied to request; otherwise returns a `RequestLogitsProcessor` instance - + """ raise NotImplementedError @@ -257,11 +269,14 @@ def _new_state( Returns: logits processor partial[Tensor] or None - + """ if req_lp := self.new_req_logits_processor(params): - args = [prompt_ids, output_ids] if (len( - inspect.signature(req_lp).parameters) == 3) else [output_ids] + args = ( + [prompt_ids, output_ids] + if (len(inspect.signature(req_lp).parameters) == 3) + else [output_ids] + ) return partial(req_lp, *args) return None @@ -286,9 +301,16 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: __all__ = [ - "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor", - "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder", - "MoveDirectionality", "LogitsProcessors", "build_logitsprocs", - "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP", - "AdapterLogitsProcessor" + "LogitsProcessor", + "LogitBiasLogitsProcessor", + "MinPLogitsProcessor", + "MinTokensLogitsProcessor", + "BatchUpdate", + "BatchUpdateBuilder", + "MoveDirectionality", + "LogitsProcessors", + "build_logitsprocs", + "STR_POOLING_REJECTS_LOGITSPROCS", + "LOGITSPROCS_GROUP", + "AdapterLogitsProcessor", ] diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index fc655d993cb4..3c3ddda7fb3e 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -6,9 +6,11 @@ import torch from vllm import SamplingParams -from vllm.v1.sample.logits_processor.interface import (BatchUpdate, - LogitsProcessor, - MoveDirectionality) +from vllm.v1.sample.logits_processor.interface import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) if TYPE_CHECKING: from vllm.config import VllmConfig @@ -17,25 +19,24 @@ class MinPLogitsProcessor(LogitsProcessor): - - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): max_num_reqs = vllm_config.scheduler_config.max_num_seqs self.min_p_count: int = 0 - 
self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=is_pin_memory) + self.min_p_cpu_tensor = torch.zeros( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory + ) self.min_p_cpu = self.min_p_cpu_tensor.numpy() self.use_double_tensor = torch.device(device).type != "cpu" if self.use_double_tensor: # Pre-allocated device tensor - self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) + self.min_p_device: torch.Tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) else: self.min_p_device = self.min_p_cpu_tensor # Current slice of the device tensor @@ -93,8 +94,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if self.min_p_count and (needs_update or self.min_p.shape[0] != size): self.min_p = self.min_p_device[:size] if self.use_double_tensor: - self.min_p.copy_(self.min_p_cpu_tensor[:size], - non_blocking=True) + self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True) self.min_p.unsqueeze_(1) def apply(self, logits: torch.Tensor) -> torch.Tensor: @@ -104,28 +104,27 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: # Convert logits to probability distribution probability_values = torch.nn.functional.softmax(logits, dim=-1) # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) # Adjust min_p adjusted_min_p = max_probabilities.mul_(self.min_p) # Identify valid tokens using threshold comparison invalid_token_mask = probability_values < adjusted_min_p # Apply mask using boolean indexing - logits[invalid_token_mask] = -float('inf') + logits[invalid_token_mask] = -float("inf") return logits class LogitBiasLogitsProcessor(LogitsProcessor): - def __init__(self, _, device: torch.device, is_pin_memory: bool): self.device = device self.pin_memory = is_pin_memory self.biases: dict[int, dict[int, float]] = {} self.bias_tensor: torch.Tensor = torch.tensor(()) - self.logits_slice = (self._device_tensor([], torch.int32), - self._device_tensor([], torch.int32)) + self.logits_slice = ( + self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32), + ) def is_argmax_invariant(self) -> bool: """Logit bias can rebalance token probabilities and change the @@ -134,8 +133,8 @@ def is_argmax_invariant(self) -> bool: def update_state(self, batch_update: Optional[BatchUpdate]): needs_update = process_dict_updates( - self.biases, batch_update, - lambda params, _, __: params.logit_bias or None) + self.biases, batch_update, lambda params, _, __: params.logit_bias or None + ) # Update tensors if needed. 
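Editorial aside before the diff continues below: the `MinPLogitsProcessor.apply` hunk above masks every token whose probability falls below `max_prob * min_p`. A minimal standalone sketch of that rule (illustrative only, not vLLM's API):

```python
# Min-p filtering: drop tokens whose probability is below a per-request
# fraction of the most likely token's probability.
import torch

def apply_min_p_sketch(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)            # [num_reqs, vocab]
    max_probs = probs.amax(dim=-1, keepdim=True)     # [num_reqs, 1]
    threshold = max_probs * min_p.unsqueeze(1)       # per-request cutoff
    return logits.masked_fill(probs < threshold, float("-inf"))

logits = torch.randn(2, 8)
min_p = torch.tensor([0.1, 0.5])
filtered = apply_min_p_sketch(logits, min_p)
```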
if needs_update: @@ -148,15 +147,15 @@ def update_state(self, batch_update: Optional[BatchUpdate]): biases.extend(lb.values()) self.bias_tensor = self._device_tensor(biases, torch.float32) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) + self.logits_slice = ( + self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32), + ) def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) + return torch.tensor( + data, device="cpu", dtype=dtype, pin_memory=self.pin_memory + ).to(device=self.device, non_blocking=True) def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.biases: @@ -165,20 +164,19 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class MinTokensLogitsProcessor(LogitsProcessor): - - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): # index -> (min_toks, output_token_ids, stop_token_ids) self.device = device self.pin_memory = is_pin_memory self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} # (req_idx_tensor,eos_tok_id_tensor) - self.logits_slice: tuple[torch.Tensor, - torch.Tensor] = (self._device_tensor( - [], torch.int32), - self._device_tensor( - [], torch.int32)) + self.logits_slice: tuple[torch.Tensor, torch.Tensor] = ( + self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32), + ) def is_argmax_invariant(self) -> bool: """By censoring stop tokens, min-tokens can change the outcome @@ -187,8 +185,7 @@ def is_argmax_invariant(self) -> bool: @staticmethod def add_request( - params: SamplingParams, _: Optional[list[int]], - output_tok_ids: list[int] + params: SamplingParams, _: Optional[list[int]], output_tok_ids: list[int] ) -> Optional[tuple[int, Sequence[int], set[int]]]: min_tokens = params.min_tokens if not min_tokens or len(output_tok_ids) >= min_tokens: @@ -196,13 +193,16 @@ def add_request( return min_tokens, output_tok_ids, params.all_stop_token_ids def update_state(self, batch_update: Optional[BatchUpdate]): - needs_update = process_dict_updates(self.min_toks, batch_update, - self.add_request) + needs_update = process_dict_updates( + self.min_toks, batch_update, self.add_request + ) if self.min_toks: # Check for any requests that have attained their min tokens. 
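Editorial aside before the diff continues below: `MinTokensLogitsProcessor` keeps `(min_toks, output_token_ids, stop_token_ids)` per request and suppresses stop tokens until the request has produced at least `min_tokens` outputs. A single-request sketch of that effect (names are illustrative):

```python
# Force stop-token logits to -inf while the output is still shorter than
# min_tokens, so the request cannot terminate early.
import torch

def censor_stop_tokens(logits: torch.Tensor, output_len: int, min_tokens: int,
                       stop_token_ids: set[int]) -> torch.Tensor:
    if output_len < min_tokens:
        logits[list(stop_token_ids)] = float("-inf")
    return logits

logits = torch.randn(32)
logits = censor_stop_tokens(logits, output_len=1, min_tokens=4,
                            stop_token_ids={0, 31})
```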
- to_remove = tuple(index for index, (min_toks, out_tok_ids, - _) in self.min_toks.items() - if len(out_tok_ids) >= min_toks) + to_remove = tuple( + index + for index, (min_toks, out_tok_ids, _) in self.min_toks.items() + if len(out_tok_ids) >= min_toks + ) if to_remove: needs_update = True for index in to_remove: @@ -216,15 +216,15 @@ def update_state(self, batch_update: Optional[BatchUpdate]): reqs.extend([req] * len(stop_tok_ids)) tok_ids.extend(stop_tok_ids) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) + self.logits_slice = ( + self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32), + ) def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) + return torch.tensor( + data, device="cpu", dtype=dtype, pin_memory=self.pin_memory + ).to(device=self.device, non_blocking=True) def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.min_toks: @@ -234,9 +234,9 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: def process_dict_updates( - req_entries: dict[int, T], batch_update: Optional[BatchUpdate], - new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], - Optional[T]] + req_entries: dict[int, T], + batch_update: Optional[BatchUpdate], + new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], Optional[T]], ) -> bool: """Utility function to update dict state for sparse LogitsProcessors.""" @@ -246,8 +246,7 @@ def process_dict_updates( updated = False for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: - if (state := new_state(params, prompt_tok_ids, - output_tok_ids)) is not None: + if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None: req_entries[index] = state updated = True elif req_entries.pop(index, None) is not None: diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py index a84afc2f347a..713bd21d3855 100644 --- a/vllm/v1/sample/logits_processor/interface.py +++ b/vllm/v1/sample/logits_processor/interface.py @@ -36,6 +36,7 @@ class MoveDirectionality(Enum): @dataclass(frozen=True) class BatchUpdate: """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch # Metadata for requests added to, removed from, and moved @@ -57,10 +58,10 @@ class BatchUpdate: class LogitsProcessor(ABC): - @abstractmethod - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool) -> None: + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ) -> None: raise NotImplementedError @abstractmethod diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py index 0a1196559d3e..a601f6641581 100644 --- a/vllm/v1/sample/logits_processor/state.py +++ b/vllm/v1/sample/logits_processor/state.py @@ -4,10 +4,12 @@ from itertools import chain from typing import TYPE_CHECKING, Optional -from vllm.v1.sample.logits_processor.interface import (AddedRequest, - BatchUpdate, - MovedRequest, - RemovedRequest) +from vllm.v1.sample.logits_processor.interface import ( + AddedRequest, + BatchUpdate, + MovedRequest, + RemovedRequest, +) if TYPE_CHECKING: from vllm.v1.sample.logits_processor.interface import LogitsProcessor @@ -81,8 +83,9 @@ def removed_append(self, index: int) -> None: index: request index 
""" if self._is_removed_sorted: - raise RuntimeError("Cannot register new removed request after" - " self.removed has been read.") + raise RuntimeError( + "Cannot register new removed request after self.removed has been read." + ) self._removed.append(index) self.batch_changed = True @@ -116,7 +119,7 @@ def reset(self) -> bool: def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: """Generate a logitsprocs batch update data structure and reset internal batch update builder state. - + Args: batch_size: current persistent batch size @@ -146,14 +149,17 @@ class LogitsProcessors: """Encapsulates initialized logitsproc objects.""" def __init__( - self, - logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None: + self, logitsprocs: Optional[Iterator["LogitsProcessor"]] = None + ) -> None: self.argmax_invariant: list[LogitsProcessor] = [] self.non_argmax_invariant: list[LogitsProcessor] = [] if logitsprocs: for logitproc in logitsprocs: - (self.argmax_invariant if logitproc.is_argmax_invariant() else - self.non_argmax_invariant).append(logitproc) + ( + self.argmax_invariant + if logitproc.is_argmax_invariant() + else self.non_argmax_invariant + ).append(logitproc) @property def all(self) -> Iterator["LogitsProcessor"]: diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 9d6a87cea3d0..14895db1bd55 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -11,7 +11,6 @@ @dataclass class SamplingMetadata: - temperature: Optional[torch.Tensor] all_greedy: bool all_random: bool diff --git a/vllm/v1/sample/ops/bad_words.py b/vllm/v1/sample/ops/bad_words.py index 1b699565f26f..faa4c33cc793 100644 --- a/vllm/v1/sample/ops/bad_words.py +++ b/vllm/v1/sample/ops/bad_words.py @@ -17,10 +17,7 @@ def _apply_bad_words_single_batch( prefix_length = len(bad_word_ids) - 1 last_token_id = bad_word_ids[-1] - if prefix_length > 0: - actual_prefix = past_tokens_ids[-prefix_length:] - else: - actual_prefix = [] + actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else [] expected_prefix = bad_word_ids[:prefix_length] assert len(actual_prefix) == len(expected_prefix) @@ -35,5 +32,4 @@ def apply_bad_words( past_tokens_ids: list[list[int]], ) -> None: for i, bad_words_ids in bad_words_token_ids.items(): - _apply_bad_words_single_batch(logits[i], bad_words_ids, - past_tokens_ids[i]) + _apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i]) diff --git a/vllm/v1/sample/ops/logprobs.py b/vllm/v1/sample/ops/logprobs.py index 82875b7c8452..cf36d46e13fd 100644 --- a/vllm/v1/sample/ops/logprobs.py +++ b/vllm/v1/sample/ops/logprobs.py @@ -8,8 +8,7 @@ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def batched_count_greater_than(x: torch.Tensor, - values: torch.Tensor) -> torch.Tensor: +def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor: """ Counts elements in each row of x that are greater than the corresponding value in values. Use torch.compile to generate an optimized kernel for diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index 5d54f6679a1a..e49b8db47800 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -19,15 +19,20 @@ def apply_all_penalties( Applies presence, frequency and repetition penalties to the logits. 
""" _, vocab_size = logits.shape - output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, - logits.device) - return apply_penalties(logits, prompt_token_ids, output_tokens_t, - presence_penalties, frequency_penalties, - repetition_penalties) + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device) + return apply_penalties( + logits, + prompt_token_ids, + output_tokens_t, + presence_penalties, + frequency_penalties, + repetition_penalties, + ) -def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int, - device: torch.device) -> torch.Tensor: +def _convert_to_tensors( + output_token_ids: list[list[int]], vocab_size: int, device: torch.device +) -> torch.Tensor: """ Convert the different list data structures to tensors. """ diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 5bcf1b585441..dbcdad07e4de 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -16,6 +16,7 @@ try: import flashinfer.sampling + is_flashinfer_available = True except ImportError: is_flashinfer_available = False @@ -34,14 +35,17 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: self.logprobs_mode = logprobs_mode # flashinfer optimization does not apply if intermediate # logprobs/logits after top_k/top_p need to be returned - if logprobs_mode not in ("processed_logits", "processed_logprobs" - ) and current_platform.is_cuda(): + if ( + logprobs_mode not in ("processed_logits", "processed_logprobs") + and current_platform.is_cuda() + ): if is_flashinfer_available: flashinfer_version = flashinfer.__version__ if version.parse(flashinfer_version) < version.parse("0.2.3"): logger.warning_once( "FlashInfer version >= 0.2.3 required. " - "Falling back to default sampling implementation.") + "Falling back to default sampling implementation." + ) self.forward = self.forward_native elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False: # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for @@ -52,21 +56,22 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: # None means False, while in V1, None means True. This is # why we use the condition # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. - logger.info_once( - "Using FlashInfer for top-p & top-k sampling.") + logger.info_once("Using FlashInfer for top-p & top-k sampling.") self.forward = self.forward_cuda else: logger.warning_once( "FlashInfer is available, but it is not enabled. " "Falling back to the PyTorch-native implementation of " "top-p & top-k sampling. For the best performance, " - "please set VLLM_USE_FLASHINFER_SAMPLER=1.") + "please set VLLM_USE_FLASHINFER_SAMPLER=1." + ) self.forward = self.forward_native else: logger.warning_once( "FlashInfer is not available. Falling back to the PyTorch-" "native implementation of top-p & top-k sampling. For the " - "best performance, please install FlashInfer.") + "best performance, please install FlashInfer." + ) self.forward = self.forward_native elif current_platform.is_cpu(): self.forward = self.forward_cpu @@ -109,13 +114,15 @@ def forward_cuda( # CPU-GPU synchronization while `flashinfer_sample` does. if (k is None and p is None) or generators: if generators: - logger.debug_once("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") + logger.debug_once( + "FlashInfer 0.2.3+ does not support " + "per-request generators. 
Falling back to " + "PyTorch-native implementation." + ) return self.forward_native(logits, generators, k, p) - assert self.logprobs_mode not in ( - "processed_logits", "processed_logprobs" - ), "FlashInfer does not support returning logits/logprobs" + assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), ( + "FlashInfer does not support returning logits/logprobs" + ) # flashinfer sampling functions expect contiguous logits. # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # because of slicing operation in logits_processor. @@ -278,15 +285,18 @@ def flashinfer_sample( # Top-p only. probs = logits.softmax(dim=-1, dtype=torch.float32) next_token_ids = flashinfer.sampling.top_p_sampling_from_probs( - probs, p, deterministic=True) + probs, p, deterministic=True + ) elif p is None: # Top-k only. probs = logits.softmax(dim=-1, dtype=torch.float32) next_token_ids = flashinfer.sampling.top_k_sampling_from_probs( - probs, k, deterministic=True) + probs, k, deterministic=True + ) else: # Both top-k and top-p. next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits( - logits, k, p, deterministic=True) + logits, k, p, deterministic=True + ) return next_token_ids.view(-1) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 37ce5bef8403..5f1dbf07d1f0 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -54,7 +54,7 @@ def forward( bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - ''' + """ Args: metadata: Metadata for spec decoding. @@ -81,7 +81,7 @@ def forward( Returns: output_token_ids (torch.Tensor): A tensor containing the final output token IDs. - ''' + """ assert metadata.max_spec_len <= MAX_SPEC_LEN # [num_tokens, vocab_size] # NOTE(woosuk): `target_logits` can be updated in place inside the @@ -123,11 +123,11 @@ def parse_output( """ output_token_ids_np = output_token_ids.cpu().numpy() # Create mask for valid tokens. - valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & - (output_token_ids_np < vocab_size)) + valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & ( + output_token_ids_np < vocab_size + ) outputs = [ - row[valid_mask[i]].tolist() - for i, row in enumerate(output_token_ids_np) + row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np) ] return outputs @@ -178,7 +178,7 @@ def rejection_sample( if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. target_argmax = target_probs.argmax(dim=-1) - rejection_greedy_sample_kernel[(batch_size, )]( + rejection_greedy_sample_kernel[(batch_size,)]( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -213,7 +213,7 @@ def rejection_sample( ) # Rejection sampling for random sampling requests. - rejection_random_sample_kernel[(batch_size, )]( + rejection_random_sample_kernel[(batch_size,)]( output_token_ids, cu_num_draft_tokens, draft_token_ids, @@ -320,7 +320,7 @@ def expand_batch_to_tokens( batch_size = x.shape[0] assert cu_num_tokens.shape[0] == batch_size expanded_x = x.new_empty(num_tokens) - expand_kernel[(batch_size, )]( + expand_kernel[(batch_size,)]( expanded_x, x, cu_num_tokens, @@ -368,7 +368,7 @@ def generate_uniform_probs( # https://github.com/pytorch/pytorch/issues/16706. Using float64 # mitigates the issue. 
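Editorial aside before the diff continues below: the greedy path of rejection sampling, launched in the `rejection_sample` hunk above and implemented in `rejection_greedy_sample_kernel` a little further down, keeps draft tokens only while they match the target model's argmax, replaces the first mismatch with the target argmax, and appends the bonus token when every draft token was accepted. A per-request sketch of that rule (not the Triton kernel):

```python
# Greedy rejection for one request: emit the target argmax at each position,
# stop after the first draft/target mismatch, append the bonus token if all match.
def greedy_rejection(draft: list[int], target_argmax: list[int], bonus: int) -> list[int]:
    out: list[int] = []
    for d, t in zip(draft, target_argmax):
        out.append(t)       # the target argmax is always emitted at this position
        if d != t:          # reject: everything after the mismatch is dropped
            return out
    out.append(bonus)       # all draft tokens accepted: emit the bonus token
    return out

assert greedy_rejection([5, 7, 9], [5, 7, 9], bonus=11) == [5, 7, 9, 11]
assert greedy_rejection([5, 8, 9], [5, 7, 9], bonus=11) == [5, 7]
```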
uniform_probs = torch.rand( - (num_tokens, ), + (num_tokens,), dtype=torch.float64, device=device, ) @@ -444,18 +444,12 @@ def rejection_greedy_sample_kernel( req_idx = tl.program_id(0) # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run, # re-compilation may happen during runtime when is_greedy_ptr is None. - if is_greedy_ptr is None: - is_greedy = True - else: - is_greedy = tl.load(is_greedy_ptr + req_idx) + is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx) if not is_greedy: # Early exit for non-greedy sampling requests. return - if req_idx == 0: - start_idx = 0 - else: - start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx @@ -464,8 +458,10 @@ def rejection_greedy_sample_kernel( if not rejected: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) - tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, - target_argmax_id) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + target_argmax_id, + ) if draft_token_id != target_argmax_id: # Reject. rejected = True @@ -474,8 +470,9 @@ def rejection_greedy_sample_kernel( # If all tokens are accepted, append the bonus token. bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - num_draft_tokens, bonus_token_id) + output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, + bonus_token_id, + ) # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @@ -500,10 +497,7 @@ def rejection_random_sample_kernel( # Early exit for greedy sampling requests. return - if req_idx == 0: - start_idx = 0 - else: - start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx @@ -514,12 +508,12 @@ def rejection_random_sample_kernel( if NO_DRAFT_PROBS: draft_prob = 1 else: - draft_prob = tl.load(draft_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) - target_prob = tl.load(target_probs_ptr + - (start_idx + pos) * vocab_size + - draft_token_id) + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id + ) uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) # NOTE(woosuk): While the draft probability should never be 0, # we check it to avoid NaNs. If it happens to be 0, we reject. @@ -530,15 +524,17 @@ def rejection_random_sample_kernel( # Reject. Use recovered token. rejected = True token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) - tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, - token_id) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id + ) if not rejected: # If all tokens are accepted, append the bonus token. bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( - output_token_ids_ptr + req_idx * (max_spec_len + 1) + - num_draft_tokens, bonus_token_id) + output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, + bonus_token_id, + ) # NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. 
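Editorial aside before the next hunk: `rejection_random_sample_kernel` above accepts a draft token when `target_prob / draft_prob >= u` for a uniform draw `u`, and otherwise falls back to a pre-sampled "recovered" token drawn from the residual distribution `max(target - draft, 0)` (see `sample_recovered_tokens_kernel` further down). A sketch of the per-token accept test only, not the Triton kernel:

```python
# Accept test for one speculative token under random sampling.
def accept_draft_token(target_prob: float, draft_prob: float, u: float) -> bool:
    if draft_prob == 0.0:
        return False        # mirrors the kernel's guard: reject if draft prob is 0
    return target_prob / draft_prob >= u

assert accept_draft_token(0.30, 0.20, u=0.9)       # ratio 1.5 >= 0.9 -> accept
assert not accept_draft_token(0.05, 0.50, u=0.2)   # ratio 0.1 <  0.2 -> reject
```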
@@ -562,9 +558,7 @@ def expand_kernel( src_val = tl.load(input_ptr + req_idx) src_val = tl.where(src_val == replace_from, replace_to, src_val) offset = tl.arange(0, MAX_NUM_TOKENS) - tl.store(output_ptr + start_idx + offset, - src_val, - mask=offset < num_tokens) + tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens) @triton.jit @@ -580,10 +574,7 @@ def sample_recovered_tokens_kernel( NO_DRAFT_PROBS: tl.constexpr, ): req_idx = tl.program_id(0) - if req_idx == 0: - start_idx = 0 - else: - start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) num_draft_tokens = end_idx - start_idx @@ -595,26 +586,30 @@ def sample_recovered_tokens_kernel( vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + - vocab_offset, - mask=((vocab_offset < vocab_size) & - (vocab_offset != draft_token_id)), - other=0) + prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)), + other=0, + ) else: - draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + - vocab_offset, - mask=vocab_offset < vocab_size, - other=0) - target_prob = tl.load(target_probs_ptr + - (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=0) + draft_prob = tl.load( + draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) + target_prob = tl.load( + target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=0, + ) prob = tl.maximum(target_prob - draft_prob, 0) # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because # `tl.argmax` will select the maximum value. - q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, - other=float("-inf")) + q = tl.load( + q_ptr + req_idx * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size, + other=float("-inf"), + ) recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 83ea766b1b4a..d4d3fb029599 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -24,39 +24,39 @@ class Sampler(nn.Module): A layer that samples the next tokens from the model's outputs with the following steps in order: - 1. If logprobs are requested: + 1. If logprobs are requested: a) If `logprobs_mode` is `raw_logprobs`, compute logprobs - as the final logprobs to return. + as the final logprobs to return. b) If `logprobs_mode` is `raw_logits`, clone the logits - as the final logprobs to return. - 2. Convert logits to float32. - 3. Apply allowed token ids whitelist. - 4. Apply bad words exclusion. + as the final logprobs to return. + 2. Convert logits to float32. + 3. Apply allowed token ids whitelist. + 4. Apply bad words exclusion. 5. Apply logit processors which are not argmax-invariant, - i.e. that can impact greedy sampling. - a) Min tokens processor - b) Logit bias processor - 6. Apply penalties - a) Repetition penalty - b) Frequency penalty - c) Presence penalty - 7. Sample the next tokens. `sample` method performs the following steps: + i.e. that can impact greedy sampling. 
+ a) Min tokens processor + b) Logit bias processor + 6. Apply penalties + a) Repetition penalty + b) Frequency penalty + c) Presence penalty + 7. Sample the next tokens. `sample` method performs the following steps: a) If not `all_random`, perform greedy sampling. If `all_greedy`, - return the greedily sampled tokens and final logprobs if requested. - b) Apply temperature. + return the greedily sampled tokens and final logprobs if requested. + b) Apply temperature. c) Apply logit processors which are argmax-invariant, by default - the min_p processor. - d) Apply top_k and/or top_p. - e) Sample the next tokens with the probability distribution. + the min_p processor. + d) Apply top_k and/or top_p. + e) Sample the next tokens with the probability distribution. f) If `all_random` or temperature >= epsilon (1e-5), return the randomly sampled tokens and final logprobs if requested. Else, - return the greedily sampled tokens and logprobs if requested. + return the greedily sampled tokens and logprobs if requested. 8. Gather the logprobs of the top `max_num_logprobs` and sampled token (if requested). Note that if the sampled token is within the top `max_num_logprobs`, the logprob will be eventually merged in `LogprobsProcessor` during output processing. Therefore, the final output may contain either `max_num_logprobs + 1` or - `max_num_logprobs` logprobs. + `max_num_logprobs` logprobs. 9. Return the final `SamplerOutput`. """ @@ -108,8 +108,11 @@ def forward( # Gather the logprobs of the topk and sampled token (if requested). # Get logprobs and rank tensors (if requested) - logprobs_tensors = None if num_logprobs is None else \ - self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + logprobs_tensors = ( + None + if num_logprobs is None + else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) + ) # Use int32 to reduce the tensor size. sampled = sampled.to(torch.int32) @@ -150,8 +153,7 @@ def sample( may update the logits tensor in-place. """ - assert not (sampling_metadata.all_greedy - and sampling_metadata.all_random) + assert not (sampling_metadata.all_greedy and sampling_metadata.all_random) if sampling_metadata.all_random: greedy_sampled = None else: @@ -168,8 +170,9 @@ def sample( assert sampling_metadata.temperature is not None # Apply temperature. - logits = self.apply_temperature(logits, sampling_metadata.temperature, - sampling_metadata.all_random) + logits = self.apply_temperature( + logits, sampling_metadata.temperature, sampling_metadata.all_random + ) # Apply logits processors that only apply to random sampling # (argmax invariant) @@ -224,9 +227,7 @@ def gather_logprobs( """ assert token_ids.dtype == torch.int64 # Find the topK values. - topk_logprobs, topk_indices = torch.topk(logprobs, - num_logprobs, - dim=-1) + topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1) # Get with the logprob of the prompt or sampled token. 
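Editorial aside before the gather continues below: `Sampler.gather_logprobs` collects the top-k logprobs together with the logprob and rank of the sampled (or prompt) token. A standalone sketch of that shape of computation; the rank convention of vLLM's `batched_count_greater_than` helper may differ from the strict-greater count used here:

```python
# Gather top-k logprobs plus the sampled token's own logprob and a simple rank
# (number of tokens with a strictly higher logprob, so the argmax gets rank 0).
import torch

def gather_logprobs_sketch(logprobs: torch.Tensor, k: int, token_ids: torch.Tensor):
    topk_logprobs, topk_indices = torch.topk(logprobs, k, dim=-1)
    token_ids = token_ids.to(torch.int64).unsqueeze(-1)     # [num_reqs, 1]
    token_logprobs = logprobs.gather(-1, token_ids)         # sampled-token logprob
    ranks = (logprobs > token_logprobs).sum(dim=-1)
    return topk_logprobs, topk_indices, token_logprobs, ranks
```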
token_ids = token_ids.unsqueeze(-1) @@ -267,8 +268,7 @@ def apply_allowed_token_ids( sampling_metadata: SamplingMetadata, ) -> torch.Tensor: if sampling_metadata.allowed_token_ids_mask is not None: - logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, - float("-inf")) + logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf")) return logits def apply_bad_words( diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 6491c84f6076..b58a94d0bf7d 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -48,15 +48,13 @@ class TPUSupportedSamplingMetadata: min_tokens = None # impl is not vectorized - logit_bias: list[Optional[dict[int, float]]] = field( - default_factory=lambda: list()) + logit_bias: list[Optional[dict[int, float]]] = field(default_factory=lambda: list()) allowed_token_ids_mask = None bad_words_token_ids = None # Generator not supported by xla - _generators: dict[int, - torch.Generator] = field(default_factory=lambda: dict()) + _generators: dict[int, torch.Generator] = field(default_factory=lambda: dict()) @property def generators(self) -> dict[int, torch.Generator]: @@ -69,13 +67,13 @@ def from_input_batch( input_batch: InputBatch, padded_num_reqs: int, xla_device: torch.device, - generate_params_if_all_greedy: bool = False + generate_params_if_all_greedy: bool = False, ) -> "TPUSupportedSamplingMetadata": """ Copy sampling tensors slices from `input_batch` to on device tensors. - `InputBatch._make_sampling_metadata` causes recompilation on XLA as it - slices dynamic shapes on device tensors. This impl moves the dynamic + `InputBatch._make_sampling_metadata` causes recompilation on XLA as it + slices dynamic shapes on device tensors. This impl moves the dynamic ops to CPU and produces tensors of fixed `padded_num_reqs` size. Args: @@ -87,11 +85,11 @@ def from_input_batch( we want to pre-compile a graph with sampling parameters, even if they are not strictly needed for greedy decoding. """ - needs_logprobs = input_batch.max_num_logprobs>0 if \ - input_batch.max_num_logprobs else False + needs_logprobs = ( + input_batch.max_num_logprobs > 0 if input_batch.max_num_logprobs else False + ) # Early return to avoid unnecessary cpu to tpu copy - if (input_batch.all_greedy is True - and generate_params_if_all_greedy is False): + if input_batch.all_greedy is True and generate_params_if_all_greedy is False: return cls(all_greedy=True, logprobs=needs_logprobs) num_reqs = input_batch.num_reqs @@ -100,25 +98,22 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor: # Pad value is the default one. cpu_tensor[num_reqs:padded_num_reqs] = fill_val - fill_slice(input_batch.temperature_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["temperature"]) - fill_slice(input_batch.min_p_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["min_p"]) - fill_slice(input_batch.top_k_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["top_k"]) - fill_slice(input_batch.top_p_cpu_tensor, - DEFAULT_SAMPLING_PARAMS["top_p"]) + fill_slice( + input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"] + ) + fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"]) + fill_slice(input_batch.top_k_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_k"]) + fill_slice(input_batch.top_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["top_p"]) # Slice persistent device tensors to a fixed pre-compiled padded shape. return cls( - temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs]. 
- to(xla_device), + temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs].to( + xla_device + ), all_greedy=input_batch.all_greedy, # TODO enable more and avoid returning None values - top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to( - xla_device), - top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to( - xla_device), - min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( - xla_device), - logprobs=needs_logprobs) + top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(xla_device), + top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(xla_device), + min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(xla_device), + logprobs=needs_logprobs, + ) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 17b83a4ba074..ccef283a8182 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -14,7 +14,6 @@ class Sampler(nn.Module): - def __init__(self): # TODO(houseroad): Add support for logprobs_mode. super().__init__() @@ -35,7 +34,8 @@ def forward( # [num_requests, 1], where each row represents one generated # token per request. sampled_token_ids=sampled.unsqueeze(-1), - logprobs_tensors=None) + logprobs_tensors=None, + ) return sampler_output def apply_temperature( @@ -73,11 +73,13 @@ def sample( # Random sample. probs = logits.softmax(dim=-1, dtype=torch.float32) - random_sampled = self.random_sample(probs, - sampling_metadata.generators) + random_sampled = self.random_sample(probs, sampling_metadata.generators) - sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS, - greedy_sampled, random_sampled) + sampled = torch.where( + sampling_metadata.temperature < _SAMPLING_EPS, + greedy_sampled, + random_sampled, + ) return sampled def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor: @@ -107,9 +109,7 @@ def gather_logprobs( Sampled token rank tensor, (num tokens) """ # Find the topK values. - topk_logprobs, topk_indices = torch.topk(logprobs, - num_logprobs, - dim=-1) + topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1) # Get with the logprob of the prompt or sampled token. 
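Editorial aside before the diff continues below: `TPUSupportedSamplingMetadata.from_input_batch` above avoids XLA recompilation by padding CPU tensors to a fixed `padded_num_reqs` with default values and only then copying fixed-size slices to the device. A minimal sketch of that pad-then-slice trick (fill value and sizes are illustrative, not vLLM's defaults):

```python
# Keep device tensor shapes static: pad the CPU-side tail with a default value,
# then move a fixed-size slice to the device.
import torch

def pad_and_move(cpu_tensor: torch.Tensor, num_reqs: int, padded_num_reqs: int,
                 fill_value: float, device: torch.device) -> torch.Tensor:
    cpu_tensor[num_reqs:padded_num_reqs] = fill_value    # pad tail with defaults
    return cpu_tensor[:padded_num_reqs].to(device)       # fixed-size device copy

temperature_cpu = torch.empty(64)
temperature_cpu[:3] = torch.tensor([0.7, 1.0, 0.0])      # 3 live requests
temperature = pad_and_move(temperature_cpu, num_reqs=3, padded_num_reqs=8,
                           fill_value=1.0, device=torch.device("cpu"))
```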
token_ids = token_ids.unsqueeze(-1) @@ -138,9 +138,7 @@ def apply_min_p( # Convert logits to probability distribution probability_values = torch.nn.functional.softmax(logits, dim=-1) # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) # Reshape min_p for broadcasting adjusted_min_p = min_p.unsqueeze(1) * max_probabilities # Identify valid tokens using threshold comparison diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 876838084b9a..747d08dcd367 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -18,15 +18,18 @@ from vllm import envs from vllm.logger import init_logger -# yapf: disable -from vllm.multimodal.inputs import (BaseMultiModalField, - MultiModalBatchedField, - MultiModalFieldConfig, MultiModalFieldElem, - MultiModalFlatField, MultiModalKwargs, - MultiModalKwargsItem, - MultiModalKwargsItems, - MultiModalSharedField, NestedTensors) -# yapf: enable +from vllm.multimodal.inputs import ( + BaseMultiModalField, + MultiModalBatchedField, + MultiModalFieldConfig, + MultiModalFieldElem, + MultiModalFlatField, + MultiModalKwargs, + MultiModalKwargsItem, + MultiModalKwargsItems, + MultiModalSharedField, + NestedTensors, +) from vllm.v1.engine import UtilityResult logger = init_logger(__name__) @@ -48,8 +51,10 @@ def _log_insecure_serialization_warning(): - logger.warning_once("Allowing insecure serialization using pickle due to " - "VLLM_ALLOW_INSECURE_SERIALIZATION=1") + logger.warning_once( + "Allowing insecure serialization using pickle due to " + "VLLM_ALLOW_INSECURE_SERIALIZATION=1" + ) def _typestr(val: Any) -> Optional[tuple[str, str]]: @@ -72,8 +77,8 @@ def _encode_type_info_recursive(obj: Any) -> Any: def _decode_type_info_recursive( - type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], - Any]) -> Any: + type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any] +) -> Any: """Recursively decode type information for nested structures of lists/dicts.""" if type_info is None: @@ -85,8 +90,9 @@ def _decode_type_info_recursive( for k in type_info } if isinstance(type_info, list) and ( - # Exclude serialized tensors/numpy arrays. - len(type_info) != 2 or not isinstance(type_info[0], str)): + # Exclude serialized tensors/numpy arrays. + len(type_info) != 2 or not isinstance(type_info[0], str) + ): assert isinstance(data, list) return [ _decode_type_info_recursive(ti, d, convert_fn) @@ -101,7 +107,7 @@ class MsgpackEncoder: Note that unlike vanilla `msgspec` Encoders, this interface is generally not thread-safe when encoding tensors / numpy arrays. - By default, arrays below 256B are serialized inline Larger will get sent + By default, arrays below 256B are serialized inline Larger will get sent via dedicated messages. Note that this is a per-tensor limit. """ @@ -119,7 +125,7 @@ def __init__(self, size_threshold: Optional[int] = None): def encode(self, obj: Any) -> Sequence[bytestr]: try: - self.aux_buffers = bufs = [b''] + self.aux_buffers = bufs = [b""] bufs[0] = self.encoder.encode(obj) # This `bufs` list allows us to collect direct pointers to backing # buffers of tensors and np arrays, and return them along with the @@ -143,14 +149,15 @@ def enc_hook(self, obj: Any) -> Any: return self._encode_tensor(obj) # Fall back to pickle for object or void kind ndarrays. 
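Editorial aside before the diff continues below: the `serial_utils` hunks above implement an inline-vs-out-of-band scheme in which small arrays are embedded in the message while larger ones are replaced by an integer index into a side list of buffers (`aux_buffers`) shipped alongside the main message, which is how `_decode_tensor` can accept either bytes or an int. A sketch of the scheme only, not vLLM's `MsgpackEncoder` API:

```python
# Small payloads stay inline; large payloads go out-of-band and are referenced
# by their index in a side list of zero-copy buffers.
from typing import Union

SIZE_THRESHOLD = 256  # bytes; mirrors the "arrays below 256B" note above

def encode_buffer(data: bytes, aux_buffers: list[bytes]) -> Union[bytes, int]:
    if len(data) < SIZE_THRESHOLD:
        return data                      # inline: stored in the message itself
    aux_buffers.append(data)             # out-of-band: sent as its own buffer
    return len(aux_buffers) - 1          # ...and referenced by index

def decode_buffer(ref: Union[bytes, int], aux_buffers: list[bytes]) -> bytes:
    return aux_buffers[ref] if isinstance(ref, int) else ref

aux: list[bytes] = [b""]                 # slot 0 reserved for the main message
small = encode_buffer(b"x" * 16, aux)    # stays inline
big = encode_buffer(b"x" * 1024, aux)    # becomes an index
assert decode_buffer(big, aux) == b"x" * 1024
```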
- if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): + if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"): return self._encode_ndarray(obj) if isinstance(obj, slice): # We are assuming only int-based values will be used here. return tuple( int(v) if v is not None else None - for v in (obj.start, obj.stop, obj.step)) + for v in (obj.start, obj.stop, obj.step) + ) if isinstance(obj, MultiModalKwargsItem): return self._encode_mm_item(obj) @@ -171,17 +178,20 @@ def enc_hook(self, obj: Any) -> Any: return _encode_type_info_recursive(result), result if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: - raise TypeError(f"Object of type {type(obj)} is not serializable" - "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow " - "fallback to pickle-based serialization.") + raise TypeError( + f"Object of type {type(obj)} is not serializable" + "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow " + "fallback to pickle-based serialization." + ) if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have # problems serializing methods. return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) - return msgpack.Ext(CUSTOM_TYPE_PICKLE, - pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) + return msgpack.Ext( + CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + ) def _encode_ndarray( self, obj: np.ndarray @@ -225,27 +235,22 @@ def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]: for modality, itemlist in items.items() } - def _encode_mm_item(self, - item: MultiModalKwargsItem) -> list[dict[str, Any]]: + def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]: return [self._encode_mm_field_elem(elem) for elem in item.values()] - def _encode_mm_field_elem(self, - elem: MultiModalFieldElem) -> dict[str, Any]: + def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]: return { - "modality": - elem.modality, - "key": - elem.key, - "data": (None if elem.data is None else - self._encode_nested_tensors(elem.data)), - "field": - self._encode_mm_field(elem.field), + "modality": elem.modality, + "key": elem.key, + "data": ( + None if elem.data is None else self._encode_nested_tensors(elem.data) + ), + "field": self._encode_mm_field(elem.field), } def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]: return { - modality: self._encode_nested_tensors(data) - for modality, data in kw.items() + modality: self._encode_nested_tensors(data) for modality, data in kw.items() } def _encode_nested_tensors(self, nt: NestedTensors) -> Any: @@ -264,8 +269,7 @@ def _encode_mm_field(self, field: BaseMultiModalField): raise TypeError(f"Unsupported field type: {field.__class__}") # We just need to copy all of the field values in order # which will be then used to reconstruct the field. 
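Editorial aside before the diff continues below: `_encode_mm_field` above serializes a field by emitting its dataclass values in declaration order after a type name, so the decoder can rebuild it positionally. A generic sketch of that pattern; the dataclass and registry names here are stand-ins, not vLLM's actual field classes:

```python
# Encode a dataclass as (type_name, *field_values) and rebuild it by passing
# the values positionally to the registered constructor.
import dataclasses

@dataclasses.dataclass
class FlatField:                         # stand-in for a MultiModal*Field dataclass
    slices: list
    dim: int = 0

FIELD_TYPES = {"flat": FlatField}

def encode_field(name: str, field) -> tuple:
    values = (getattr(field, f.name) for f in dataclasses.fields(field))
    return (name, *values)

def decode_field(encoded: tuple):
    name, *values = encoded
    return FIELD_TYPES[name](*values)

wire = encode_field("flat", FlatField(slices=[slice(0, 4)], dim=1))
assert decode_field(wire) == FlatField(slices=[slice(0, 4)], dim=1)
```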
- field_values = (getattr(field, f.name) - for f in dataclasses.fields(field)) + field_values = (getattr(field, f.name) for f in dataclasses.fields(field)) return name, *field_values @@ -277,10 +281,10 @@ class MsgpackDecoder: """ def __init__(self, t: Optional[Any] = None): - args = () if t is None else (t, ) - self.decoder = msgpack.Decoder(*args, - ext_hook=self.ext_hook, - dec_hook=self.dec_hook) + args = () if t is None else (t,) + self.decoder = msgpack.Decoder( + *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook + ) self.aux_buffers: Sequence[bytestr] = () if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: _log_insecure_serialization_warning() @@ -320,11 +324,14 @@ def _decode_utility_result(self, obj: Any) -> UtilityResult: result_type, result = obj if result_type is not None: if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: - raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must " - "be set to use custom utility result types") + raise TypeError( + "VLLM_ALLOW_INSECURE_SERIALIZATION must " + "be set to use custom utility result types" + ) # Use recursive decoding to handle nested structures - result = _decode_type_info_recursive(result_type, result, - self._convert_result) + result = _decode_type_info_recursive( + result_type, result, self._convert_result + ) return UtilityResult(result) def _convert_result(self, result_type: Sequence[str], result: Any) -> Any: @@ -347,8 +354,7 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: # Copy from inline representation, to decouple the memory storage # of the message from the original buffer. And also make Torch # not complain about a readonly memoryview. - buffer = self.aux_buffers[data] if isinstance(data, int) \ - else bytearray(data) + buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data) torch_dtype = getattr(torch, dtype) assert isinstance(torch_dtype, torch.dtype) if not buffer: # torch.frombuffer doesn't like empty buffers @@ -360,17 +366,19 @@ def _decode_tensor(self, arr: Any) -> torch.Tensor: return arr.view(torch_dtype).view(shape) def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems: - return MultiModalKwargsItems({ - modality: [self._decode_mm_item(item) for item in itemlist] - for modality, itemlist in obj.items() - }) + return MultiModalKwargsItems( + { + modality: [self._decode_mm_item(item) for item in itemlist] + for modality, itemlist in obj.items() + } + ) def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem: return MultiModalKwargsItem.from_elems( - [self._decode_mm_field_elem(v) for v in obj]) + [self._decode_mm_field_elem(v) for v in obj] + ) - def _decode_mm_field_elem(self, obj: dict[str, - Any]) -> MultiModalFieldElem: + def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem: if obj["data"] is not None: obj["data"] = self._decode_nested_tensors(obj["data"]) @@ -387,10 +395,12 @@ def _decode_mm_field_elem(self, obj: dict[str, return MultiModalFieldElem(**obj) def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs: - return MultiModalKwargs({ - modality: self._decode_nested_tensors(data) - for modality, data in obj.items() - }) + return MultiModalKwargs( + { + modality: self._decode_nested_tensors(data) + for modality, data in obj.items() + } + ) def _decode_nested_tensors(self, obj: Any) -> NestedTensors: if isinstance(obj, (int, float)): @@ -419,5 +429,4 @@ def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_CLOUDPICKLE: return cloudpickle.loads(data) - raise NotImplementedError( - f"Extension 
type code {code} is not supported") + raise NotImplementedError(f"Extension type code {code} is not supported") diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index dc6db0138806..5d4822a6279b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -10,8 +10,7 @@ import torch.nn as nn from vllm.attention.layer import Attention -from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config) +from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger @@ -23,11 +22,15 @@ from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, - TreeAttentionMetadataBuilder) +from vllm.v1.attention.backends.tree_attn import ( + TreeAttentionMetadata, + TreeAttentionMetadataBuilder, +) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -41,7 +44,6 @@ class EagleProposer: - def __init__( self, vllm_config: VllmConfig, @@ -59,10 +61,8 @@ def __init__( self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size - self.num_speculative_tokens = ( - self.speculative_config.num_speculative_tokens) - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's @@ -72,62 +72,64 @@ def __init__( # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - vllm_config.model_config) + vllm_config.model_config + ) self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None - self.draft_indexer_metadata_builder: Optional[ - AttentionMetadataBuilder] = None + self.draft_indexer_metadata_builder: Optional[AttentionMetadataBuilder] = None self.attn_layer_names: list[str] = [] self.indexer_layer_names: list[str] = [] - self.use_cuda_graph = (not current_platform.is_xpu() - and self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not self.vllm_config.model_config.enforce_eager - and not self.speculative_config.enforce_eager) - self.cudagraph_batch_sizes = list( - reversed(self.vllm_config.compilation_config. 
- cudagraph_capture_sizes)) if self.use_cuda_graph else [] + self.use_cuda_graph = ( + not current_platform.is_xpu() + and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE + and not self.vllm_config.model_config.enforce_eager + and not self.speculative_config.enforce_eager + ) + self.cudagraph_batch_sizes = ( + list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes)) + if self.use_cuda_graph + else [] + ) # persistent buffers for cuda graph - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=device) + self.input_ids = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device=device + ) self.uses_mrope = self.vllm_config.model_config.uses_mrope if self.uses_mrope: # M-RoPE need (3, max_num_tokens) - self.mrope_positions = torch.zeros((3, self.max_num_tokens), - dtype=torch.int64, - device=device) + self.mrope_positions = torch.zeros( + (3, self.max_num_tokens), dtype=torch.int64, device=device + ) else: # RoPE need (max_num_tokens,) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=device) + self.positions = torch.zeros( + self.max_num_tokens, dtype=torch.int64, device=device + ) self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. max_batch_size = vllm_config.scheduler_config.max_num_seqs max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) - self.arange = torch.arange(max_num_slots_for_arange, - device=device, - dtype=torch.int32) + self.arange = torch.arange( + max_num_slots_for_arange, device=device, dtype=torch.int32 + ) self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) self.backup_next_token_ids = CpuGpuBuffer( max_batch_size, dtype=torch.int32, pin_memory=is_pin_memory_available(), device=device, - with_numpy=True) + with_numpy=True, + ) # Determine allowed attention backends once during initialization. self.allowed_attn_types: Optional[tuple] = None @@ -136,14 +138,15 @@ def __init__( # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) + AiterFlashAttentionMetadata, + ) + rocm_types.append(AiterFlashAttentionMetadata) self.allowed_attn_types = tuple(rocm_types) # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree - self.tree_choices: list[tuple[int, - ...]] = ast.literal_eval(spec_token_tree) + self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) tree_depth = len(self.tree_choices[-1]) # Precompute per-level properties of the tree. 
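For intuition about the loop just below this comment: a standalone sketch of how the per-level counters could be derived from a token-tree literal such as `"[(0,), (1,), (0, 0), (0, 1)]"`. The per-depth counting step is an assumption (that part of the constructor is outside this hunk); only the cumulative and per-child arithmetic mirrors the code shown.

```python
import ast

# Hypothetical tree: two root drafts, plus two children under the first root.
spec_token_tree = "[(0,), (1,), (0, 0), (0, 1)]"
tree_choices = ast.literal_eval(spec_token_tree)
tree_depth = len(tree_choices[-1])

# Assumed: count how many drafts live at each depth of the tree.
num_drafts_per_level = [0] * tree_depth
for choice in tree_choices:
    num_drafts_per_level[len(choice) - 1] += 1

# Mirrors the cumulative / per-child bookkeeping in the constructor.
cu_drafts_per_level = [num_drafts_per_level[0]]
child_drafts_per_level = [num_drafts_per_level[0]]
for level in range(1, tree_depth):
    cu_drafts_per_level.append(cu_drafts_per_level[-1] + num_drafts_per_level[level])
    child_drafts_per_level.append(
        num_drafts_per_level[level] // num_drafts_per_level[level - 1]
    )

print(num_drafts_per_level)    # [2, 2]
print(cu_drafts_per_level)     # [2, 4]
print(child_drafts_per_level)  # [2, 1]
```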
num_drafts_per_level = [0] * tree_depth @@ -152,10 +155,12 @@ def __init__( self.cu_drafts_per_level = [num_drafts_per_level[0]] self.child_drafts_per_level = [num_drafts_per_level[0]] for level in range(1, tree_depth): - self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + - num_drafts_per_level[level]) - self.child_drafts_per_level.append(num_drafts_per_level[level] // - num_drafts_per_level[level - 1]) + self.cu_drafts_per_level.append( + self.cu_drafts_per_level[-1] + num_drafts_per_level[level] + ) + self.child_drafts_per_level.append( + num_drafts_per_level[level] // num_drafts_per_level[level - 1] + ) # Precompute draft position offsets in flattened tree. self.tree_draft_pos_offsets = torch.arange( 1, @@ -188,8 +193,7 @@ def propose( last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embed_inputs: Optional[tuple[list[torch.Tensor], - torch.Tensor]] = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -200,11 +204,12 @@ def propose( if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) target_hidden_states = self.model.combine_hidden_states( - target_hidden_states) + target_hidden_states + ) assert target_hidden_states.shape[-1] == self.hidden_size # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids @@ -213,17 +218,20 @@ def propose( # FIXME: need to consider multiple kv_cache_groups ubatch_id = dbo_current_ubatch_id() - attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ + ubatch_id + ] attn_metadata = attn_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0) + common_attn_metadata=common_attn_metadata, draft_index=0 + ) # FIXME: support hybrid kv for draft model (remove separate indexer) if self.draft_indexer_metadata_builder: draft_indexer_metadata = ( self.draft_indexer_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0, - )) + ) + ) else: draft_indexer_metadata = None # At this moment, we assume all eagle layers belong to the same KV @@ -235,8 +243,7 @@ def propose( assert draft_indexer_metadata is not None per_layer_attn_metadata[layer_name] = draft_indexer_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens @@ -259,9 +266,9 @@ def propose( input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens + ): ret_hidden_states = self.model( input_ids=input_ids, positions=self._get_positions(num_input_tokens), @@ -304,28 +311,30 @@ def propose( draft_token_ids = logits.argmax(dim=-1) - if self.allowed_attn_types is 
not None and \ - not isinstance(attn_metadata, self.allowed_attn_types): + if self.allowed_attn_types is not None and not isinstance( + attn_metadata, self.allowed_attn_types + ): raise ValueError( f"Unsupported attention metadata type for speculative " "decoding with num_speculative_tokens > 1: " f"{type(attn_metadata)}. Supported types are: " - f"{self.allowed_attn_types}") + f"{self.allowed_attn_types}" + ) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 - common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[:batch_size + 1]).clone() + self.token_arange_np[: batch_size + 1] + ).clone() for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. @@ -344,14 +353,15 @@ def propose( exceeds_max_model_len = positions[0] >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where\ - (exceeds_max_model_len.unsqueeze(0), \ - torch.zeros_like(positions), positions) + clamped_positions = torch.where( + exceeds_max_model_len.unsqueeze(0), + torch.zeros_like(positions), + positions, + ) else: positions += 1 exceeds_max_model_len = positions >= self.max_model_len - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + clamped_positions = torch.where(exceeds_max_model_len, 0, positions) # Increment the sequence lengths. common_attn_metadata.seq_lens += 1 @@ -359,11 +369,11 @@ def propose( # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, - 1) + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) - common_attn_metadata.num_computed_tokens_cpu = \ + common_attn_metadata.num_computed_tokens_cpu = ( common_attn_metadata.seq_lens_cpu - 1 + ) # Compute the slot mapping. if self.uses_mrope: @@ -372,26 +382,28 @@ def propose( else: block_numbers = clamped_positions // self.block_size block_ids = common_attn_metadata.block_table_tensor.gather( - dim=1, index=block_numbers.view(-1, 1)) + dim=1, index=block_numbers.view(-1, 1) + ) block_ids = block_ids.view(-1) if self.uses_mrope: common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions[0] % self.block_size) + block_ids * self.block_size + clamped_positions[0] % self.block_size + ) else: common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions % self.block_size) + block_ids * self.block_size + clamped_positions % self.block_size + ) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
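The slot-mapping arithmetic above (clamp overflowing positions, look up the physical block id, add the in-block offset, then redirect overflowing requests to a padding slot) can be shown with a small self-contained sketch. The value `PADDING_SLOT_ID = -1` is an assumption made for the example, not taken from this diff.

```python
import torch

PADDING_SLOT_ID = -1  # assumed sentinel meaning "do not write this KV entry"
block_size = 4
max_model_len = 16

# One block-table row per request: logical block index -> physical block id.
block_table = torch.tensor([[7, 3, 9, 2], [5, 1, 8, 6]])
positions = torch.tensor([5, 17])  # request 1 already exceeds max_model_len

exceeds_max_model_len = positions >= max_model_len
# Clamp so RoPE / gather stay in range, as in the drafting loop.
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)

block_numbers = clamped_positions // block_size
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped_positions % block_size

# Overflowing requests write to the padding slot instead of real KV cache.
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
print(slot_mapping)  # tensor([13, -1]): position 5 -> block 3, offset 1 -> slot 13
```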
common_attn_metadata.slot_mapping.masked_fill_( - exceeds_max_model_len, PADDING_SLOT_ID) + exceeds_max_model_len, PADDING_SLOT_ID + ) # Rebuild attention metadata attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore - common_attn_metadata=common_attn_metadata, - draft_index=token_index + 1) + common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 + ) for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata @@ -400,8 +412,9 @@ def propose( self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states if self.supports_mm_inputs: - self.inputs_embeds[:batch_size] = \ - self.model.get_input_embeddings(input_ids) + self.inputs_embeds[:batch_size] = self.model.get_input_embeddings( + input_ids + ) input_ids = None inputs_embeds = self.inputs_embeds[:input_batch_size] @@ -410,9 +423,9 @@ def propose( inputs_embeds = None # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): + with set_forward_context( + per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size + ): ret_hidden_states = self.model( input_ids=input_ids, positions=self._get_positions(input_batch_size), @@ -434,10 +447,12 @@ def propose( return draft_token_ids def prepare_next_token_ids_cpu( - self, sampled_token_ids: list[list[int]], - requests: dict[str, - CachedRequestState], gpu_input_batch: InputBatch, - num_scheduled_tokens: dict[str, int]) -> torch.Tensor: + self, + sampled_token_ids: list[list[int]], + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> torch.Tensor: """ This function is used to prepare the inputs for speculative decoding. It calculates the next token ids for each request based on the sampled @@ -456,23 +471,23 @@ def prepare_next_token_ids_cpu( # Get the next token id from the request state. req_id = req_ids[i] req_state = requests[req_id] - seq_len = (req_state.num_computed_tokens + - num_scheduled_tokens[req_id]) + seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] next_token_id = req_state.get_token_id(seq_len) next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.input_ids.device) + next_token_ids = torch.tensor( + next_token_ids, dtype=torch.int32, device=self.input_ids.device + ) return next_token_ids - def prepare_next_token_ids_padded(self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: torch.Tensor, - requests: dict[str, CachedRequestState], - gpu_input_batch: InputBatch, - discard_request_indices: torch.Tensor, - num_discarded_requests: int) -> \ - tuple[torch.Tensor, torch.Tensor]: + def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int, + ) -> tuple[torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding. 
It calculates the next token ids and the number of valid sampled tokens @@ -486,30 +501,34 @@ def prepare_next_token_ids_padded(self, # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs - self.backup_next_token_ids.np[:num_reqs] = np.array([ - requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item()) - for i in range(num_reqs) - ]) + self.backup_next_token_ids.np[:num_reqs] = np.array( + [ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item() + ) + for i in range(num_reqs) + ] + ) self.backup_next_token_ids.copy_to_gpu(num_reqs) # Mask out the sampled tokens indices that should not be sampled. - discard_sampled_tokens_req_indices = \ - discard_request_indices[:num_discarded_requests] + discard_sampled_tokens_req_indices = discard_request_indices[ + :num_discarded_requests + ] valid_sampled_token_ids_gpu = sampled_token_ids.clone() valid_sampled_token_ids_gpu.index_fill_( - 0, discard_sampled_tokens_req_indices, -1) + 0, discard_sampled_tokens_req_indices, -1 + ) # Generate a mask for all valid tokens within those requests max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: - valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, - dtype=torch.bool) + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool) else: - valid_mask = ( - (valid_sampled_token_ids_gpu != -1) & - (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)) + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) # Count the number of valid tokens in each request valid_sampled_tokens_count = valid_mask.sum(dim=1) @@ -521,22 +540,25 @@ def prepare_next_token_ids_padded(self, # Get last valid token from each row # (assume undefined state where there is no valid token) selected_tokens = torch.gather( - valid_sampled_token_ids_gpu, 1, - last_valid_indices_safe.unsqueeze(1)).squeeze(1) + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) # Use last token if valid, pre-computed backup if not batch_size = valid_sampled_token_ids_gpu.shape[0] next_token_ids = torch.where( - last_valid_indices != -1, selected_tokens, - self.backup_next_token_ids.gpu[:batch_size]) + last_valid_indices != -1, + selected_tokens, + self.backup_next_token_ids.gpu[:batch_size], + ) return next_token_ids, valid_sampled_tokens_count - def prepare_inputs_padded(self, - common_attn_metadata: CommonAttentionMetadata, - spec_decode_metadata: SpecDecodeMetadata, - valid_sampled_tokens_count: torch.Tensor) -> \ - tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + def prepare_inputs_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor, + ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding It updates the common_attn_metadata for speculative decoding, @@ -545,21 +567,23 @@ def prepare_inputs_padded(self, used as padding and filtered out later by `token_indices_to_sample`. No blocking CPU operations should be introduced in this function. 
""" - num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1] - ]) + num_draft_tokens_gpu = torch.cat( + [ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] + - spec_decode_metadata.cu_num_draft_tokens[:-1], + ] + ) num_rejected_tokens_gpu = torch.where( num_draft_tokens_gpu > 0, num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu)) + torch.zeros_like(num_draft_tokens_gpu), + ) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] total_num_tokens = query_start_loc_cpu[-1].item() token_indices = self.arange[:total_num_tokens] @@ -569,8 +593,7 @@ def prepare_inputs_padded(self, seq_lens=common_attn_metadata.seq_lens, query_start_loc_cpu=query_start_loc_cpu, seq_lens_cpu=common_attn_metadata.seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. - num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), @@ -580,8 +603,9 @@ def prepare_inputs_padded(self, causal=True, ) - token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ - - num_rejected_tokens_gpu + token_indices_to_sample = ( + common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu + ) return spec_common_attn_metadata, token_indices, token_indices_to_sample @@ -596,10 +620,10 @@ def propose_tree( hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: - tree_attn_metadata_builder = \ - self.runner.attn_groups[0][0].get_metadata_builder() - assert isinstance(tree_attn_metadata_builder, - TreeAttentionMetadataBuilder) + tree_attn_metadata_builder = self.runner.attn_groups[0][ + 0 + ].get_metadata_builder() + assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) total_num_drafts = self.cu_drafts_per_level[0] level_num_drafts = total_num_drafts @@ -608,31 +632,31 @@ def propose_tree( if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: - draft_token_ids = torch.topk(logits, num_children, - dim=-1).indices.view(batch_size, -1) + draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view( + batch_size, -1 + ) draft_token_ids_list = [draft_token_ids] draft_hidden_states = hidden_states.view(batch_size, 1, -1) # Initialize empty tensors for concatenation with the level outputs. - tree_input_ids = torch.empty(0, - device=self.input_ids.device, - dtype=self.input_ids.dtype) - tree_positions = torch.empty(0, - device=self.positions.device, - dtype=self.positions.dtype) - tree_hidden_states = torch.empty(0, - device=self.hidden_states.device, - dtype=self.hidden_states.dtype) + tree_input_ids = torch.empty( + 0, device=self.input_ids.device, dtype=self.input_ids.dtype + ) + tree_positions = torch.empty( + 0, device=self.positions.device, dtype=self.positions.dtype + ) + tree_hidden_states = torch.empty( + 0, device=self.hidden_states.device, dtype=self.hidden_states.dtype + ) # Precompute the draft token positions. 
flattened_draft_positions = ( - positions.view(batch_size, -1) + - self.tree_draft_pos_offsets[:batch_size, :]) + positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :] + ) tree_depth = len(self.cu_drafts_per_level) for level in range(tree_depth - 1): # Get draft positions for RoPE. draft_positions = positions + (level + 1) - exceeds_max_model_len = (positions + - total_num_drafts) >= self.max_model_len + exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. draft_positions = torch.where( @@ -644,27 +668,28 @@ def propose_tree( if level_num_drafts > 1: # Repeat the positions for each draft at this level. draft_positions = draft_positions.repeat_interleave( - level_num_drafts, dim=1) + level_num_drafts, dim=1 + ) if num_children > 1: # Repeat draft hidden states for each child. draft_hidden_states = draft_hidden_states.repeat_interleave( - num_children, dim=1) + num_children, dim=1 + ) # Concatenate the draft tokens, positions, and hidden states. - tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], - dim=1) - tree_positions = torch.cat([tree_positions, draft_positions], - dim=1) + tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1) + tree_positions = torch.cat([tree_positions, draft_positions], dim=1) tree_hidden_states = torch.cat( - [tree_hidden_states, draft_hidden_states], dim=1) + [tree_hidden_states, draft_hidden_states], dim=1 + ) # Build new attention metadata for the next level of drafts. # This is necessary to support tree attention. query_len = total_num_drafts common_attn_metadata = replace( common_attn_metadata, - query_start_loc=query_len * self.arange[:batch_size + 1], + query_start_loc=query_len * self.arange[: batch_size + 1], seq_lens=common_attn_metadata.seq_lens + level_num_drafts, num_actual_tokens=batch_size * query_len, max_query_len=query_len, @@ -680,20 +705,20 @@ def propose_tree( per_layer_attn_metadata[layer_name] = attn_metadata # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) + attn_metadata.max_seq_len = min( + attn_metadata.max_seq_len, self.max_model_len + ) # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. - query_positions = flattened_draft_positions[:, level:level + - query_len] + query_positions = flattened_draft_positions[:, level : level + query_len] block_numbers = query_positions // self.block_size - block_ids = attn_metadata.block_table.gather(dim=1, - index=block_numbers) - slot_mapping = (block_ids * self.block_size + - query_positions % self.block_size) + block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers) + slot_mapping = ( + block_ids * self.block_size + query_positions % self.block_size + ) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
@@ -705,19 +730,16 @@ def propose_tree( input_ids = tree_input_ids.view(-1) self.input_ids[:num_tokens] = input_ids self.positions[:num_tokens] = tree_positions.view(-1) - self.hidden_states[:num_tokens] = tree_hidden_states.view( - num_tokens, -1) + self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1) - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_tokens) + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens + ): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], @@ -727,28 +749,29 @@ def propose_tree( # Get the output hidden states for the draft tokens. draft_hidden_states = hidden_states[:num_tokens].view( - batch_size, query_len, -1)[:, -level_num_drafts:] + batch_size, query_len, -1 + )[:, -level_num_drafts:] draft_last_hidden_states = last_hidden_states[:num_tokens].view( - batch_size, query_len, -1)[:, -level_num_drafts:] + batch_size, query_len, -1 + )[:, -level_num_drafts:] # Get the output logits for the draft tokens. logits = self.model.compute_logits( - draft_last_hidden_states.reshape(batch_size * level_num_drafts, - -1)) + draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1) + ) # Sample a draft token for each child at the next tree level. num_children = self.child_drafts_per_level[level + 1] if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: - draft_token_ids = torch.topk(logits, num_children, - dim=-1).indices.view( - batch_size, -1) + draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view( + batch_size, -1 + ) draft_token_ids_list.append(draft_token_ids) # Update the # drafts counters for the next tree level. 
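The argmax/topk branch above selects one draft token per child at the next tree level; a tiny sketch with made-up logits shows the two shapes involved (greedy single child versus top-k children per parent draft).

```python
import torch

batch_size, vocab_size = 2, 10
torch.manual_seed(0)
# One logits row per draft position being expanded (here: one per request).
logits = torch.randn(batch_size, vocab_size)

# num_children == 1: greedy pick, flattened to (batch_size, drafts_at_level).
one_child = logits.argmax(dim=-1).view(batch_size, -1)
print(one_child.shape)  # torch.Size([2, 1])

# num_children > 1: keep the top-k token ids per parent draft.
num_children = 3
many_children = torch.topk(logits, num_children, dim=-1).indices.view(batch_size, -1)
print(many_children.shape)  # torch.Size([2, 3])
```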
- level_num_drafts = self.cu_drafts_per_level[level + - 1] - total_num_drafts + level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts total_num_drafts = self.cu_drafts_per_level[level + 1] return draft_token_ids_list @@ -784,17 +807,14 @@ def prepare_inputs( n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, - dtype=torch.int32) + num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32) device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - - num_rejected_tokens + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() @@ -804,7 +824,8 @@ def prepare_inputs( new_query_start_loc_cpu = torch.zeros( query_start_loc_cpu.shape, dtype=torch.int32, - pin_memory=is_pin_memory_available()) + pin_memory=is_pin_memory_available(), + ) new_query_start_loc_np = new_query_start_loc_cpu.numpy() np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) @@ -814,36 +835,36 @@ def prepare_inputs( # [0, 2, 6, 9] -> # [0, 0, 2, 2, 2, 2, 6, 6, 6] # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) + new_query_start_locs_expanded = np.repeat( + new_query_start_loc_np[:-1], new_num_tokens_per_req_np + ) # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> # [0, 1, 0, 1, 2, 3, 0, 1, 2] # _r1_ ____r2____ ___r3__ - token_offests = self.token_arange_np[:total_num_tokens] \ - - new_query_start_locs_expanded + token_offests = ( + self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded + ) # Expand starting positions to match token pattern # [0, q1, q1 + q2] -> # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] # _r1_ _____r2_______ ___________r3____________ old_query_start_locs_expanded = np.repeat( - query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np + ) # Final token indices are: # [0, 1, // req 1 # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to( - device, non_blocking=True) + token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True) spec_common_attn_metadata = CommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, - non_blocking=True), + query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), query_start_loc_cpu=new_query_start_loc_cpu, seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. 
- num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), @@ -856,45 +877,52 @@ def prepare_inputs( return spec_common_attn_metadata, token_indices def get_model_name(self, model: nn.Module) -> str: - if hasattr(model, 'module'): # multi-GPU + if hasattr(model, "module"): # multi-GPU model = model.module return model.__class__.__name__ def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + draft_model_config = self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + ) # FIXME: support hybrid kv for draft model target_indexer_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, - DeepseekV32IndexerCache).keys()) + get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache + ).keys() + ) from vllm.compilation.backends import set_model_tag + with set_model_tag("eagle_head"): - self.model = get_model(vllm_config=self.vllm_config, - model_config=draft_model_config) + self.model = get_model( + vllm_config=self.vllm_config, model_config=draft_model_config + ) draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names) - indexer_layers = get_layers_from_vllm_config(self.vllm_config, - DeepseekV32IndexerCache) - draft_indexer_layer_names = (indexer_layers.keys() - - target_indexer_layer_names) + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + - target_attn_layer_names + ) + indexer_layers = get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache + ) + draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names self.attn_layer_names = list(draft_attn_layer_names) self.indexer_layer_names = list(draft_indexer_layer_names) if self.indexer_layer_names: first_layer = self.indexer_layer_names[0] self.draft_indexer_metadata_builder = ( - indexer_layers[first_layer].get_attn_backend().get_builder_cls( - )( + indexer_layers[first_layer] + .get_attn_backend() + .get_builder_cls()( indexer_layers[first_layer].get_kv_cache_spec(), self.indexer_layer_names, self.vllm_config, self.device, - )) + ) + ) else: self.draft_indexer_metadata_builder = None @@ -902,38 +930,41 @@ def load_model(self, target_model: nn.Module) -> None: # Even if the target model is multimodal, we can also use # text-only draft models try: - dummy_input_ids = torch.tensor([[1]], - device=self.input_ids.device) - self.model.get_input_embeddings(dummy_input_ids, - multimodal_embeddings=None) + dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device) + self.model.get_input_embeddings( + dummy_input_ids, multimodal_embeddings=None + ) except (NotImplementedError, AttributeError, TypeError): logger.warning( "Draft model does not support multimodal inputs, " - "falling back to text-only mode") + "falling back to text-only mode" + ) self.supports_mm_inputs = False if supports_multimodal(target_model): # handle multimodality - if (self.get_model_name(target_model) == - "Qwen2_5_VLForConditionalGeneration"): - self.model.config.image_token_index = ( - target_model.config.image_token_id) + if ( + self.get_model_name(target_model) + == "Qwen2_5_VLForConditionalGeneration" + ): + 
self.model.config.image_token_index = target_model.config.image_token_id else: self.model.config.image_token_index = ( - target_model.config.image_token_index) + target_model.config.image_token_index + ) target_language_model = target_model.get_language_model() else: target_language_model = target_model # share embed_tokens with the target model if needed if get_pp_group().world_size == 1: - if hasattr(target_language_model.model, 'embed_tokens'): + if hasattr(target_language_model.model, "embed_tokens"): target_embed_tokens = target_language_model.model.embed_tokens - elif hasattr(target_language_model.model, 'embedding'): + elif hasattr(target_language_model.model, "embedding"): target_embed_tokens = target_language_model.model.embedding else: raise AttributeError( - "Target model does not have 'embed_tokens' or 'embedding' " - "attribute") + "Target model does not have 'embed_tokens' or 'embedding' attribute" + ) # Check if shapes match and we found the embedding eagle_shape = self.model.model.embed_tokens.weight.shape @@ -941,47 +972,53 @@ def load_model(self, target_model: nn.Module) -> None: if eagle_shape == target_shape: logger.info( "Assuming the EAGLE head shares the same vocab embedding" - " with the target model.") + " with the target model." + ) del self.model.model.embed_tokens self.model.model.embed_tokens = target_embed_tokens else: logger.info( "The EAGLE head's vocab embedding will be loaded separately" - " from the target model.") + " from the target model." + ) else: logger.info( "The EAGLE head's vocab embedding will be loaded separately" - " from the target model.") + " from the target model." + ) # share lm_head with the target model if needed # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM if self.vllm_config.speculative_config.method != "eagle3": if hasattr(target_language_model, "lm_head"): - logger.info( - "Loading EAGLE LM head weights from the target model.") + logger.info("Loading EAGLE LM head weights from the target model.") self.model.lm_head = target_language_model.lm_head else: - if (hasattr(self.model, "lm_head") - and hasattr(target_language_model, "lm_head") - and self.model.lm_head.weight.shape - == target_language_model.lm_head.weight.shape): - logger.info("Assuming the EAGLE head shares the same lm_head" - " with the target model.") + if ( + hasattr(self.model, "lm_head") + and hasattr(target_language_model, "lm_head") + and self.model.lm_head.weight.shape + == target_language_model.lm_head.weight.shape + ): + logger.info( + "Assuming the EAGLE head shares the same lm_head" + " with the target model." + ) del self.model.lm_head self.model.lm_head = target_language_model.lm_head else: logger.info( "The EAGLE head's lm_head will be loaded separately" - " from the target model.") + " from the target model." + ) @torch.inference_mode() def dummy_run( self, num_tokens: int, ) -> None: - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -996,8 +1033,7 @@ def dummy_run( inputs_embeds=inputs_embeds, ) - def _get_attention_metadata_builder( - self) -> list[AttentionMetadataBuilder]: + def _get_attention_metadata_builder(self) -> list[AttentionMetadataBuilder]: """Find and return the attention metadata builders for EAGLE layers. 
Returns: @@ -1018,11 +1054,11 @@ def _get_attention_metadata_builder( break assert builder is not None, ( - "Failed to find attention metadata builder for EAGLE layers.") + "Failed to find attention metadata builder for EAGLE layers." + ) return builder - def validate_same_kv_cache_group(self, - kv_cache_config: KVCacheConfig) -> None: + def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ Validate that all eagle layers belong to the same KVCacheGroup. Need this assumption to ensure all eagle layers can use the @@ -1033,12 +1069,17 @@ def validate_same_kv_cache_group(self, for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): for layer_name in kv_cache_group.layer_names: kv_cache_groups[layer_name] = id - assert len( - set([ - kv_cache_groups[layer_name] - for layer_name in self.attn_layer_names - ]) - ) == 1, "All eagle layers should belong to the same kv cache group" + assert ( + len( + set( + [ + kv_cache_groups[layer_name] + for layer_name in self.attn_layer_names + ] + ) + ) + == 1 + ), "All eagle layers should belong to the same kv cache group" # NOTE(woosuk): Currently, the below code is not used and we always use argmax diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 70b29c05c2a5..150dde177ce8 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -27,10 +27,9 @@ def __init__( # Save config parameters self.vllm_config = vllm_config self.device = device - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) - self.hidden_size = vllm_config.speculative_config.\ - draft_model_config.get_hidden_size( + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.hidden_size = ( + vllm_config.speculative_config.draft_model_config.get_hidden_size() ) self.dtype = vllm_config.model_config.dtype @@ -51,16 +50,19 @@ def propose( def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag + with set_model_tag("medusa_head"): - self.model = get_model(vllm_config=self.vllm_config, - model_config=self.vllm_config. 
- speculative_config.draft_model_config) + self.model = get_model( + vllm_config=self.vllm_config, + model_config=self.vllm_config.speculative_config.draft_model_config, + ) @torch.inference_mode() def dummy_run(self, num_tokens: int) -> None: - hidden_states = torch.zeros((self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device, + ) + with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): self.model(hidden_states) diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py index b1efb40612d5..d0695244cb16 100644 --- a/vllm/v1/spec_decode/metadata.py +++ b/vllm/v1/spec_decode/metadata.py @@ -8,7 +8,6 @@ @dataclass class SpecDecodeMetadata: - # [num_tokens] draft_token_ids: torch.Tensor # [batch_size] @@ -36,22 +35,19 @@ def make_dummy( flattened_draft_token_ids = sum(draft_token_ids, []) num_tokens = len(flattened_draft_token_ids) - draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids, - dtype=torch.int32, - device=device) + draft_token_ids_tensor = torch.tensor( + flattened_draft_token_ids, dtype=torch.int32, device=device + ) cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) - cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to( - device) + cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device) - target_logits_indices = torch.zeros(num_tokens, - dtype=torch.int32, - device=device) - bonus_logits_indices = torch.zeros(batch_size, - dtype=torch.int32, - device=device) - logits_indices = torch.zeros(num_tokens + batch_size, - dtype=torch.int32, - device=device) + target_logits_indices = torch.zeros( + num_tokens, dtype=torch.int32, device=device + ) + bonus_logits_indices = torch.zeros(batch_size, dtype=torch.int32, device=device) + logits_indices = torch.zeros( + num_tokens + batch_size, dtype=torch.int32, device=device + ) return cls( draft_token_ids=draft_token_ids_tensor, num_draft_tokens=num_draft_tokens, diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 282e6f65e7ab..89a8a11a3d56 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -31,8 +31,10 @@ class SpecDecodingStats: @classmethod def new(cls, num_spec_tokens: int) -> "SpecDecodingStats": - return cls(num_spec_tokens=num_spec_tokens, - num_accepted_tokens_per_pos=[0] * num_spec_tokens) + return cls( + num_spec_tokens=num_spec_tokens, + num_accepted_tokens_per_pos=[0] * num_spec_tokens, + ) def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): self.num_drafts += 1 @@ -64,10 +66,10 @@ def reset(self): def observe(self, spec_decoding_stats: SpecDecodingStats): self.num_drafts.append(spec_decoding_stats.num_drafts) self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens) - self.num_accepted_tokens.append( - spec_decoding_stats.num_accepted_tokens) + self.num_accepted_tokens.append(spec_decoding_stats.num_accepted_tokens) self.accepted_tokens_per_pos_lists.append( - spec_decoding_stats.num_accepted_tokens_per_pos) + spec_decoding_stats.num_accepted_tokens_per_pos + ) def log(self, log_fn=logger.info): if not self.num_drafts: @@ -83,8 +85,11 @@ def log(self, log_fn=logger.info): draft_throughput = num_draft_tokens / elapsed_time accepted_throughput = num_accepted_tokens / elapsed_time - draft_acceptance_rate = 
(num_accepted_tokens / num_draft_tokens * - 100 if num_draft_tokens > 0 else float("nan")) + draft_acceptance_rate = ( + num_accepted_tokens / num_draft_tokens * 100 + if num_draft_tokens > 0 + else float("nan") + ) # Conventionally, mean acceptance length includes the bonus token mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts) @@ -149,27 +154,36 @@ def __init__( counter_drafts = self._counter_cls( name="vllm:spec_decode_num_drafts", documentation="Number of spec decoding drafts.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_spec_decode_num_drafts = make_per_engine( - counter_drafts, per_engine_labelvalues) + counter_drafts, per_engine_labelvalues + ) counter_draft_tokens = self._counter_cls( name="vllm:spec_decode_num_draft_tokens", documentation="Number of draft tokens.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_spec_decode_num_draft_tokens = make_per_engine( - counter_draft_tokens, per_engine_labelvalues) + counter_draft_tokens, per_engine_labelvalues + ) counter_accepted_tokens = self._counter_cls( name="vllm:spec_decode_num_accepted_tokens", documentation="Number of accepted tokens.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_spec_decode_num_accepted_tokens = make_per_engine( - counter_accepted_tokens, per_engine_labelvalues) + counter_accepted_tokens, per_engine_labelvalues + ) assert speculative_config is not None - num_spec_tokens = (speculative_config.num_speculative_tokens - if self.spec_decoding_enabled else 0) + num_spec_tokens = ( + speculative_config.num_speculative_tokens + if self.spec_decoding_enabled + else 0 + ) pos_labelnames = labelnames + ["position"] base_counter = self._counter_cls( name="vllm:spec_decode_num_accepted_tokens_per_pos", @@ -177,33 +191,33 @@ def __init__( labelnames=pos_labelnames, ) self.counter_spec_decode_num_accepted_tokens_per_pos: dict[ - int, list[prometheus_client.Counter]] = { - idx: [ - base_counter.labels(*lv, str(pos)) - for pos in range(num_spec_tokens) - ] - for idx, lv in per_engine_labelvalues.items() - } - - def observe(self, - spec_decoding_stats: SpecDecodingStats, - engine_idx: int = 0): + int, list[prometheus_client.Counter] + ] = { + idx: [base_counter.labels(*lv, str(pos)) for pos in range(num_spec_tokens)] + for idx, lv in per_engine_labelvalues.items() + } + + def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0): if not self.spec_decoding_enabled: return self.counter_spec_decode_num_drafts[engine_idx].inc( - spec_decoding_stats.num_drafts) + spec_decoding_stats.num_drafts + ) self.counter_spec_decode_num_draft_tokens[engine_idx].inc( - spec_decoding_stats.num_draft_tokens) + spec_decoding_stats.num_draft_tokens + ) self.counter_spec_decode_num_accepted_tokens[engine_idx].inc( - spec_decoding_stats.num_accepted_tokens) + spec_decoding_stats.num_accepted_tokens + ) for pos, counter in enumerate( - self. 
- counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]): + self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx] + ): counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) -def make_per_engine(counter: prometheus_client.Counter, - per_engine_labelvalues: dict[int, list[str]]): +def make_per_engine( + counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[str]] +): """Create a counter for each label value.""" return { idx: counter.labels(*labelvalues) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index aed050a3540c..e2f83cb24aa9 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -9,7 +9,6 @@ class NgramProposer: - def __init__(self, vllm_config: VllmConfig): assert vllm_config.speculative_config is not None assert vllm_config.speculative_config.prompt_lookup_min is not None @@ -28,8 +27,7 @@ def __init__(self, vllm_config: VllmConfig): # Pre-allocate buffers for numba batch propose. max_num_seqs = vllm_config.scheduler_config.max_num_seqs - self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), - dtype=np.int32) + self.valid_ngram_draft = np.zeros((max_num_seqs, self.k), dtype=np.int32) self.valid_ngram_num_drafts = np.zeros((max_num_seqs), dtype=np.int32) # Threshold of total number of tokens in the batch to enable @@ -55,9 +53,13 @@ def __init__(self, vllm_config: VllmConfig): # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. - self.propose([[]] * 1024, [""] * 1024, np.zeros(1024, dtype=np.int32), - np.zeros((1024, self.max_model_len), dtype=np.int32), - set()) + self.propose( + [[]] * 1024, + [""] * 1024, + np.zeros(1024, dtype=np.int32), + np.zeros((1024, self.max_model_len), dtype=np.int32), + set(), + ) def batch_propose( self, @@ -67,20 +69,20 @@ def batch_propose( token_ids_cpu: np.ndarray, ) -> list[list[int]]: """Batch version of ngram proposer using numba for acceleration. - + Args: - valid_ngram_requests: + valid_ngram_requests: Set of indices of requests that need ngram proposals. - num_tokens_no_spec: - Numpy array of shape (batch_size,) representing the number + num_tokens_no_spec: + Numpy array of shape (batch_size,) representing the number of tokens without speculative tokens for each request. - token_ids_cpu: - Numpy array of shape (batch_size, max_model_len) + token_ids_cpu: + Numpy array of shape (batch_size, max_model_len) representing the token IDs for each request. Returns: - list[list[int]]: - A list where each element is a list of proposed + list[list[int]]: + A list where each element is a list of proposed token IDs for the corresponding request. """ draft_token_ids: list[list[int]] = [] @@ -96,26 +98,32 @@ def batch_propose( total_tokens = np.sum(num_tokens_no_spec) if total_tokens >= self.num_tokens_threshold: final_num_threads = max( - 1, min(self.num_numba_thread_available, - num_ngram_requests)) + 1, min(self.num_numba_thread_available, num_ngram_requests) + ) set_num_threads(final_num_threads) else: set_num_threads(1) - batch_propose_numba(valid_ngram_requests, num_tokens_no_spec, - token_ids_cpu, self.min_n, self.max_n, - self.max_model_len, self.k, - self.valid_ngram_draft, - self.valid_ngram_num_drafts) + batch_propose_numba( + valid_ngram_requests, + num_tokens_no_spec, + token_ids_cpu, + self.min_n, + self.max_n, + self.max_model_len, + self.k, + self.valid_ngram_draft, + self.valid_ngram_num_drafts, + ) # Restore original number of threads. 
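The acceptance statistics summarized in the metrics hunk above reduce to simple ratios; a short sketch with illustrative counts (not measured numbers) shows the arithmetic behind the logged acceptance rate and mean acceptance length.

```python
# Illustrative counters, in the spirit of SpecDecodingStats.
num_drafts = 100            # how many drafting rounds ran
num_draft_tokens = 300      # total draft tokens proposed (3 per draft here)
num_accepted_tokens = 210   # draft tokens the verifier accepted

draft_acceptance_rate = (
    num_accepted_tokens / num_draft_tokens * 100
    if num_draft_tokens > 0
    else float("nan")
)
# Mean acceptance length conventionally counts the bonus token as well.
mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)

print(f"acceptance rate: {draft_acceptance_rate:.1f}%")         # 70.0%
print(f"mean acceptance length: {mean_acceptance_length:.2f}")  # 3.10
```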
set_num_threads(original_num_numba_threads) for i in range(num_requests): - if i in valid_ngram_requests and \ - self.valid_ngram_num_drafts[i] > 0: - draft_token_ids.append(self.valid_ngram_draft[ - i, :self.valid_ngram_num_drafts[i]].tolist()) + if i in valid_ngram_requests and self.valid_ngram_num_drafts[i] > 0: + draft_token_ids.append( + self.valid_ngram_draft[i, : self.valid_ngram_num_drafts[i]].tolist() + ) else: draft_token_ids.append([]) @@ -129,7 +137,6 @@ def propose( token_ids_cpu: np.ndarray, spec_decode_unsupported_reqs: set, ) -> list[list[int]]: - # find which requests need ngram proposals valid_ngram_requests = [] for i, sampled_ids in enumerate(sampled_token_ids): @@ -166,12 +173,17 @@ def load_model(self, *args, **kwargs): @njit(parallel=True) -def batch_propose_numba(valid_ngram_requests: list, - num_tokens_no_spec: np.ndarray, - token_ids_cpu: np.ndarray, min_n: int, max_n: int, - max_model_len: int, k: int, - valid_ngram_draft: np.ndarray, - valid_ngram_num_drafts: np.ndarray): +def batch_propose_numba( + valid_ngram_requests: list, + num_tokens_no_spec: np.ndarray, + token_ids_cpu: np.ndarray, + min_n: int, + max_n: int, + max_model_len: int, + k: int, + valid_ngram_draft: np.ndarray, + valid_ngram_num_drafts: np.ndarray, +): for i in prange(len(valid_ngram_requests)): idx = valid_ngram_requests[i] num_tokens = num_tokens_no_spec[idx] @@ -181,19 +193,22 @@ def batch_propose_numba(valid_ngram_requests: list, min_ngram=min_n, max_ngram=max_n, max_model_len=max_model_len, - k=k) + k=k, + ) valid_ngram_num_drafts[i] = drafter_output.shape[0] if len(drafter_output): - valid_ngram_draft[i, :drafter_output.shape[0]] = drafter_output + valid_ngram_draft[i, : drafter_output.shape[0]] = drafter_output @jit(nopython=True) -def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, - min_ngram: int, - max_ngram: int, - max_model_len: int, - k: int) -> np.ndarray: +def _find_longest_matched_ngram_and_propose_tokens( + origin_tokens: np.ndarray, + min_ngram: int, + max_ngram: int, + max_model_len: int, + k: int, +) -> np.ndarray: """ Find the longest n-gram which matches the suffix of the given tokens whose length is within [min_ngram, max_ngram] (inclusive). @@ -203,12 +218,12 @@ def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, # Do not generate draft tokens is context is shorter than minimum n-gram total_token = origin_tokens.shape[0] if total_token < min_ngram: - return np.empty((0, ), dtype=origin_tokens.dtype) + return np.empty((0,), dtype=origin_tokens.dtype) # Do not generate draft tokens beyond the max model length. 
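As a plain-NumPy illustration of the search this function performs (without the numba JIT or the reversed-array matching machinery of the real kernel), the following brute-force sketch finds the longest n-gram in [min_ngram, max_ngram] that matches the suffix of the context and proposes the k tokens that followed its earlier occurrence. It is a behavioral sketch under those assumptions, not the optimized implementation.

```python
import numpy as np


def propose_by_ngram(tokens: np.ndarray, min_ngram: int, max_ngram: int,
                     max_model_len: int, k: int) -> np.ndarray:
    """Brute-force n-gram lookup: match the suffix, copy what followed it."""
    total = tokens.shape[0]
    k = min(k, max_model_len - total)
    if total < min_ngram or k <= 0:
        return np.empty((0,), dtype=tokens.dtype)

    # Prefer the longest n-gram; among matches, take the rightmost one.
    for n in range(min(max_ngram, total - 1), min_ngram - 1, -1):
        suffix = tokens[total - n:]
        for start in range(total - n - 1, -1, -1):
            if np.array_equal(tokens[start:start + n], suffix):
                return tokens[start + n: start + n + k]
    return np.empty((0,), dtype=tokens.dtype)


context = np.array([1, 2, 3, 4, 9, 9, 2, 3, 4], dtype=np.int32)
# Suffix [2, 3, 4] matched earlier at index 1, so the tokens that followed it
# ([9, 9, ...]) are proposed as the draft.
print(propose_by_ngram(context, min_ngram=2, max_ngram=3, max_model_len=32, k=2))
# -> [9 9]
```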
k = min(k, max_model_len - total_token) if k <= 0: - return np.empty((0, ), dtype=origin_tokens.dtype) + return np.empty((0,), dtype=origin_tokens.dtype) # Flip tokens, and the goal become to find longest ngram # on the rightmost position which matches the prefix with @@ -265,7 +280,7 @@ def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, if longest_ngram < min_ngram: # No valid ngram is found - return np.empty((0, ), dtype=origin_tokens.dtype) + return np.empty((0,), dtype=origin_tokens.dtype) # Flip the position back, so in origin_tokens, # origin_tokens[total_token-1-position:total_token-1-position+longest_ngram] @@ -273,4 +288,4 @@ def _find_longest_matched_ngram_and_propose_tokens(origin_tokens: np.ndarray, # total_token-1-position+longest_ngram start_position = total_token - 1 - position + longest_ngram k = min(k, total_token - start_position) - return origin_tokens[start_position:start_position + k] + return origin_tokens[start_position : start_position + k] diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 1116179dc5b6..1901c6fc9f14 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -7,8 +7,10 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool: """True if request is incompatible with speculative decoding""" - return (sampling_params.frequency_penalty != 0.0 - or sampling_params.presence_penalty != 0.0 - or sampling_params.repetition_penalty != 1.0 - or sampling_params.min_p > _SAMPLING_EPS - or sampling_params.logprobs is not None) + return ( + sampling_params.frequency_penalty != 0.0 + or sampling_params.presence_penalty != 0.0 + or sampling_params.repetition_penalty != 1.0 + or sampling_params.min_p > _SAMPLING_EPS + or sampling_params.logprobs is not None + ) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 13c33d3edf14..1b5e75313d89 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -4,7 +4,7 @@ import multiprocessing from concurrent.futures import Future, ThreadPoolExecutor -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from vllm.config import VllmConfig from vllm.logger import init_logger @@ -12,8 +12,10 @@ from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, +) from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend if TYPE_CHECKING: @@ -33,11 +35,11 @@ class StructuredOutputManager: """Engine-level manager for structured output requests.""" def __init__(self, vllm_config: VllmConfig): - self.backend: Optional[StructuredOutputBackend] = None - self.reasoner: Optional[ReasoningParser] = None + self.backend: StructuredOutputBackend | None = None + self.reasoner: ReasoningParser | None = None self.vllm_config = vllm_config - self._grammar_bitmask: Optional[torch.Tensor] = None + self._grammar_bitmask: torch.Tensor | None = None self._full_mask = torch.tensor(-1, dtype=torch.int32) max_batch_size = self.vllm_config.scheduler_config.max_num_seqs @@ -48,8 +50,7 @@ def __init__(self, vllm_config: VllmConfig): # - at least 1 CPU # - at most half the number of CPUs or 8, whichever is less max_workers = 
max(1, min(multiprocessing.cpu_count() // 2, 8)) - self.executor_for_fillmask = ThreadPoolExecutor( - max_workers=max_workers) + self.executor_for_fillmask = ThreadPoolExecutor(max_workers=max_workers) if not self.vllm_config.model_config.skip_tokenizer_init: # The default max_workers if not specified is the number of @@ -60,12 +61,15 @@ def __init__(self, vllm_config: VllmConfig): max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) self.tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config) - reasoning_parser = \ - self.vllm_config.structured_outputs_config.reasoning_parser + model_config=self.vllm_config.model_config + ) + reasoning_parser = ( + self.vllm_config.structured_outputs_config.reasoning_parser + ) if reasoning_parser: reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_parser) + reasoning_parser + ) self.reasoner = reasoner_cls(tokenizer=self.tokenizer) def grammar_init(self, request: Request) -> None: @@ -73,8 +77,10 @@ def grammar_init(self, request: Request) -> None: return if TYPE_CHECKING: - assert request.sampling_params is not None and \ - request.sampling_params.structured_outputs is not None + assert ( + request.sampling_params is not None + and request.sampling_params.structured_outputs is not None + ) # Initialize the backend the first time it is needed. # @@ -98,8 +104,7 @@ def grammar_init(self, request: Request) -> None: vocab_size=vocab_size, ) elif backend == "outlines": - from vllm.v1.structured_output.backend_outlines import ( - OutlinesBackend) + from vllm.v1.structured_output.backend_outlines import OutlinesBackend self.backend = OutlinesBackend( self.vllm_config, @@ -108,15 +113,16 @@ def grammar_init(self, request: Request) -> None: ) elif backend == "lm-format-enforcer": from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501 - LMFormatEnforcerBackend) + LMFormatEnforcerBackend, + ) + self.backend = LMFormatEnforcerBackend( self.vllm_config, tokenizer=self.tokenizer, vocab_size=vocab_size, ) else: - raise ValueError( - f"Unsupported structured output backend: {backend}") + raise ValueError(f"Unsupported structured output backend: {backend}") grammar = self.executor.submit(self._async_create_grammar, request) request.structured_output_request.grammar = grammar # type: ignore[assignment] @@ -162,15 +168,16 @@ def grammar_bitmask( requests: dict[str, Request], structured_output_request_ids: dict[str, int], scheduled_spec_decode_tokens: dict[str, list[int]], - ) -> Optional[npt.NDArray[np.int32]]: + ) -> npt.NDArray[np.int32] | None: # Prepare the structured output bitmask for this batch. if not structured_output_request_ids: return None max_num_spec_tokens = 0 if self.vllm_config.speculative_config is not None: - max_num_spec_tokens = \ + max_num_spec_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens + ) if self._grammar_bitmask is None: assert self.backend is not None @@ -179,22 +186,23 @@ def grammar_bitmask( # Allocate a bitmask for each token needing to be checked: # one for each speculative position, and one more for the # bonus token / non-speculative token. - self._grammar_bitmask = \ - self.backend.allocate_token_bitmask( - max_batch_size * (1 + max_num_spec_tokens)) + self._grammar_bitmask = self.backend.allocate_token_bitmask( + max_batch_size * (1 + max_num_spec_tokens) + ) # Generate a batched bitmask for all structured output requests. 
# When speculative decoding is enabled, we need to include multiple # masks for each request, one for each possible bonus token position. # These are stored inline in the tensor and unpacked by the gpu runner. cumulative_index = 0 - ordered_seq = sorted(structured_output_request_ids.items(), - key=lambda x: x[1]) + ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1]) # Optimized parallel filling of bitmasks for # non-spec, large-batch-size cases - if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \ - max_num_spec_tokens == 0: + if ( + len(ordered_seq) > self.fill_bitmask_parallel_threshold + and max_num_spec_tokens == 0 + ): promises = [] batch = [] for req_id, _ in ordered_seq: @@ -205,8 +213,9 @@ def grammar_bitmask( assert structured_output_request.grammar is not None apply_bitmask = self.should_fill_bitmask(request) - batch.append((structured_output_request.grammar, - cumulative_index, apply_bitmask)) + batch.append( + (structured_output_request.grammar, cumulative_index, apply_bitmask) + ) if len(batch) == self.fill_bitmask_parallel_batch_size: promises.append(self._async_submit_fill_bitmask(batch)) batch = [] @@ -232,18 +241,28 @@ def grammar_bitmask( state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, []) for i, token in enumerate(req_tokens + [None]): - self._fill_bitmasks([(structured_output_request.grammar, - cumulative_index, apply_bitmask)]) - - if apply_bitmask and token is not None and \ - not structured_output_request.grammar.is_terminated(): + self._fill_bitmasks( + [ + ( + structured_output_request.grammar, + cumulative_index, + apply_bitmask, + ) + ] + ) + + if ( + apply_bitmask + and token is not None + and not structured_output_request.grammar.is_terminated() + ): assert structured_output_request.grammar.accept_tokens( - req_id, [token]) + req_id, [token] + ) state_advancements += 1 cumulative_index += 1 if state_advancements > 0: - structured_output_request.grammar.rollback( - state_advancements) + structured_output_request.grammar.rollback(state_advancements) bitmask_tensor = self._grammar_bitmask if cumulative_index < bitmask_tensor.shape[0]: @@ -258,8 +277,9 @@ def should_fill_bitmask(self, request: Request) -> bool: if self.reasoner is not None: assert request.structured_output_request is not None if request.structured_output_request.reasoning_ended is None: - request.structured_output_request.reasoning_ended = \ + request.structured_output_request.reasoning_ended = ( self.reasoner.is_reasoning_end(request.prompt_token_ids) + ) return request.structured_output_request.reasoning_ended return True diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index e06ab6377de3..081cdfdc9932 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -7,16 +7,18 @@ import json import os from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import torch from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) from vllm.v1.structured_output.request import get_structured_output_key if TYPE_CHECKING: @@ 
-26,8 +28,7 @@ else: llguidance = LazyLoader("llguidance", globals(), "llguidance") llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf") - llguidance_torch = LazyLoader("llguidance.torch", globals(), - "llguidance.torch") + llguidance_torch = LazyLoader("llguidance.torch", globals(), "llguidance.torch") logger = init_logger(__name__) @@ -36,16 +37,18 @@ def _walk_json_for_additional_properties(data: object): if isinstance(data, dict): for value in data.values(): _walk_json_for_additional_properties(value) - if 'additionalProperties' not in data and \ - ('properties' in data or 'patternProperties' in data): - data['additionalProperties'] = False + if "additionalProperties" not in data and ( + "properties" in data or "patternProperties" in data + ): + data["additionalProperties"] = False elif isinstance(data, list): for item in data: _walk_json_for_additional_properties(item) def process_for_additional_properties( - guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]: + guide_json: Union[str, dict[str, Any]], +) -> dict[str, Any]: if isinstance(guide_json, str): guide_json_obj = json.loads(guide_json) else: @@ -57,21 +60,27 @@ def process_for_additional_properties( @dataclass class GuidanceBackend(StructuredOutputBackend): - def __post_init__(self): - self.disable_any_whitespace = \ + self.disable_any_whitespace = ( self.vllm_config.structured_outputs_config.disable_any_whitespace - self.disable_additional_properties = \ + ) + self.disable_additional_properties = ( self.vllm_config.structured_outputs_config.disable_additional_properties + ) self.ll_tokenizer = llguidance_hf.from_tokenizer( - self.tokenizer, self.vocab_size) + self.tokenizer, self.vocab_size + ) - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: self.serialized_grammar = serialize_guidance_grammar( - request_type, grammar_spec, self.disable_any_whitespace, - self.disable_additional_properties) + request_type, + grammar_spec, + self.disable_any_whitespace, + self.disable_additional_properties, + ) ll_matcher = llguidance.LLMatcher( self.ll_tokenizer, @@ -90,7 +99,8 @@ def compile_grammar(self, request_type: StructuredOutputOptions, def allocate_token_bitmask(self, max_num_seqs: int): return llguidance_torch.allocate_token_bitmask( - max_num_seqs, self.ll_tokenizer.vocab_size) + max_num_seqs, self.ll_tokenizer.vocab_size + ) def destroy(self): pass @@ -178,15 +188,17 @@ def serialize_guidance_grammar( disable_any_whitespace: bool = False, disable_additional_properties: bool = False, ) -> str: - - def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str: + def _process_schema( + grammar_spec: Union[str, dict[str, Any]], + ) -> str: if disable_additional_properties: grammar_spec = process_for_additional_properties(grammar_spec) return llguidance.LLMatcher.grammar_from_json_schema( grammar_spec, defaults={ "whitespace_flexible": not disable_any_whitespace, - }) + }, + ) if request_type == StructuredOutputOptions.JSON: return _process_schema(grammar_spec) @@ -195,7 +207,8 @@ def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str: '{"type": "object"}', defaults={ "whitespace_flexible": not disable_any_whitespace, - }) + }, + ) else: if request_type == StructuredOutputOptions.REGEX: tp = "regex" @@ -215,29 +228,32 @@ def _process_schema(grammar_spec: Union[str, dict[str, Any]], ) -> str: trig = 
next((t for t in triggers if begin.startswith(t)), None) if trig is None: raise ValueError( - f"Trigger {begin} not found in triggers {triggers}") + f"Trigger {begin} not found in triggers {triggers}" + ) tags.append( llguidance.StructTag( trigger=trig, begin=s["begin"], grammar=_process_schema(s["schema"]), end=s["end"], - )) + ) + ) if not tags: - raise ValueError( - "No structural tags found in the grammar spec.") + raise ValueError("No structural tags found in the grammar spec.") return llguidance.StructTag.to_grammar(tags) else: - logger.error("Validation should have already occurred. " - "Please file an issue.") - raise ValueError("grammar is not of valid supported types. " - f"({request_type!s})") + logger.error( + "Validation should have already occurred. Please file an issue." + ) + raise ValueError( + f"grammar is not of valid supported types. ({request_type!s})" + ) return llguidance.grammar_from(tp, grammar_spec) def validate_guidance_grammar( - sampling_params: SamplingParams, - tokenizer: Optional[llguidance.LLTokenizer] = None) -> None: + sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None +) -> None: tp, grm = get_structured_output_key(sampling_params) guidance_grm = serialize_guidance_grammar(tp, grm) err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer) diff --git a/vllm/v1/structured_output/backend_lm_format_enforcer.py b/vllm/v1/structured_output/backend_lm_format_enforcer.py index 465b2428f893..d9e484092d6a 100644 --- a/vllm/v1/structured_output/backend_lm_format_enforcer.py +++ b/vllm/v1/structured_output/backend_lm_format_enforcer.py @@ -13,26 +13,31 @@ from vllm.sampling_params import SamplingParams from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) if TYPE_CHECKING: import lmformatenforcer import lmformatenforcer.integrations.vllm as lmfe_vllm else: - lmformatenforcer = LazyLoader("lmformatenforcer", globals(), - "lmformatenforcer") - lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(), - "lmformatenforcer.integrations.vllm") + lmformatenforcer = LazyLoader("lmformatenforcer", globals(), "lmformatenforcer") + lmfe_vllm = LazyLoader( + "lmformatenforcer.integrations.vllm", + globals(), + "lmformatenforcer.integrations.vllm", + ) @lru_cache def _cached_build_vllm_token_enforcer_tokenizer_data( - tokenizer: PreTrainedTokenizerBase, - vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData: + tokenizer: PreTrainedTokenizerBase, vocab_size: int +) -> lmfe_vllm.TokenEnforcerTokenizerData: return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data( - tokenizer, use_bitmask=True, vocab_size=vocab_size) + tokenizer, use_bitmask=True, vocab_size=vocab_size + ) @dataclass @@ -44,7 +49,8 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: original_len = len(self.current_tokens_prefix) for token in tokens: if not self.token_enforcer.get_allowed_tokens( - self.current_tokens_prefix).is_token_allowed(token): + self.current_tokens_prefix + ).is_token_allowed(token): # Rollback partial updates to ensure atomicity. 
del self.current_tokens_prefix[original_len:] return False @@ -56,8 +62,8 @@ def validate_tokens(self, tokens: list[int]) -> list[int]: prefix = tokens[:prefix_length] next_token = tokens[prefix_length] if not self.token_enforcer.get_allowed_tokens( - self.current_tokens_prefix + - prefix).is_token_allowed(next_token): + self.current_tokens_prefix + prefix + ).is_token_allowed(next_token): break else: return tokens @@ -69,14 +75,16 @@ def rollback(self, num_tokens: int) -> None: def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: allowed_tokens = self.token_enforcer.get_allowed_tokens( - self.current_tokens_prefix) + self.current_tokens_prefix + ) bitmask[batch_index] = allowed_tokens.allowed_tokens def is_terminated(self) -> bool: # We are considered terminated if the prefix ends with eos_token_id - return_value = len( - self.current_tokens_prefix) > 0 and self.current_tokens_prefix[ - -1] == self.token_enforcer.eos_token_id + return_value = ( + len(self.current_tokens_prefix) > 0 + and self.current_tokens_prefix[-1] == self.token_enforcer.eos_token_id + ) return return_value def reset(self): @@ -85,18 +93,18 @@ def reset(self): @dataclass class LMFormatEnforcerBackend(StructuredOutputBackend): - def __post_init__(self): self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( - self.tokenizer, self.vocab_size) + self.tokenizer, self.vocab_size + ) - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: character_level_parser: lmformatenforcer.CharacterLevelParser if request_type == StructuredOutputOptions.JSON: spec_dict = json.loads(grammar_spec) - character_level_parser = lmformatenforcer.JsonSchemaParser( - spec_dict) + character_level_parser = lmformatenforcer.JsonSchemaParser(spec_dict) elif request_type == StructuredOutputOptions.JSON_OBJECT: character_level_parser = lmformatenforcer.JsonSchemaParser(None) elif request_type == StructuredOutputOptions.REGEX: @@ -104,14 +112,17 @@ def compile_grammar(self, request_type: StructuredOutputOptions, elif request_type == StructuredOutputOptions.CHOICE: choices = ast.literal_eval(grammar_spec) character_level_parser = lmformatenforcer.UnionParser( - [lmformatenforcer.StringParser(choice) for choice in choices]) + [lmformatenforcer.StringParser(choice) for choice in choices] + ) else: raise ValueError( - "Invalid request type for LM Format Enforcer backend" - f"({request_type!s})") + f"Invalid request type for LM Format Enforcer backend({request_type!s})" + ) max_rollback_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config is not None else 0) + if self.vllm_config.speculative_config is not None + else 0 + ) if max_rollback_tokens > 0: raise ValueError( @@ -136,8 +147,7 @@ def destroy(self): pass -def validate_structured_output_request_lm_format_enforcer( - params: SamplingParams): +def validate_structured_output_request_lm_format_enforcer(params: SamplingParams): if params.structured_outputs is None: return @@ -163,5 +173,7 @@ def validate_structured_output_request_lm_format_enforcer( elif so_params.choice: return elif so_params.grammar: - raise ValueError("LM Format Enforcer structured outputs backend " - "does not support grammar specifications") + raise ValueError( + "LM Format Enforcer structured outputs backend " + "does not support grammar specifications" + ) diff 
--git a/vllm/v1/structured_output/backend_outlines.py b/vllm/v1/structured_output/backend_outlines.py index e5e638a6ad76..c9875337179e 100644 --- a/vllm/v1/structured_output/backend_outlines.py +++ b/vllm/v1/structured_output/backend_outlines.py @@ -15,20 +15,23 @@ from vllm.sampling_params import SamplingParams from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) -from vllm.v1.structured_output.utils import (OutlinesVocabulary, - get_outlines_cache, - get_outlines_vocabulary) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) +from vllm.v1.structured_output.utils import ( + OutlinesVocabulary, + get_outlines_cache, + get_outlines_vocabulary, +) if TYPE_CHECKING: import outlines_core as oc import outlines_core.json_schema as json_schema else: oc = LazyLoader("oc", globals(), "outlines_core") - json_schema = LazyLoader("json_schema", globals(), - "outlines_core.json_schema") + json_schema = LazyLoader("json_schema", globals(), "outlines_core.json_schema") # Python 3.11+ sre_parse and sre_constants # are deprecated, so we must import them from re @@ -46,13 +49,13 @@ @dataclass class OutlinesBackend(StructuredOutputBackend): - def __post_init__(self): self.vocabulary = get_outlines_vocabulary(self.tokenizer) self.cache = get_outlines_cache() - def _compile_index(self, regex_string: str, - vocabulary: OutlinesVocabulary) -> oc.Index: + def _compile_index( + self, regex_string: str, vocabulary: OutlinesVocabulary + ) -> oc.Index: cache_key = f"{vocabulary._hash}_{regex_string}" if cache_key in self.cache: return self.cache[cache_key] @@ -62,8 +65,9 @@ def _compile_index(self, regex_string: str, return index - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: if request_type == StructuredOutputOptions.JSON: regex = json_schema.build_regex_from_schema(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: @@ -79,10 +83,13 @@ def compile_grammar(self, request_type: StructuredOutputOptions, index = self._compile_index(regex, self.vocabulary) max_rollback_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config is not None else 0) - return OutlinesGrammar(vocab_size=self.vocab_size, - guide=oc.Guide( - index, max_rollback=max_rollback_tokens)) + if self.vllm_config.speculative_config is not None + else 0 + ) + return OutlinesGrammar( + vocab_size=self.vocab_size, + guide=oc.Guide(index, max_rollback=max_rollback_tokens), + ) def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: return torch.full( @@ -98,20 +105,15 @@ def destroy(self): @dataclass class OutlinesGrammar(StructuredOutputGrammar): - vocab_size: int guide: oc.Guide = field(hash=False) - num_processed_tokens: int = field(default_factory=lambda: 0, - repr=False, - hash=False, - init=False) + num_processed_tokens: int = field( + default_factory=lambda: 0, repr=False, hash=False, init=False + ) # outlines_core signals done on DFA accept; vLLM expects done after EOS. # We delay the finished flag by one step so EOS can still be emitted. 
- _prev_finished: bool = field(default=False, - init=False, - repr=False, - hash=False) + _prev_finished: bool = field(default=False, init=False, repr=False, hash=False) def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: """Accepts a list of tokens and advances the FSM. @@ -142,8 +144,7 @@ def validate_tokens(self, tokens: list[int]) -> list[int]: def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: mask = bitmask[idx] - self.guide.write_mask_into(mask.data_ptr(), mask.numel(), - mask.element_size()) + self.guide.write_mask_into(mask.data_ptr(), mask.numel(), mask.element_size()) def is_terminated(self) -> bool: curr = self.guide.is_finished() @@ -187,8 +188,10 @@ def validate_structured_output_request_outlines(params: SamplingParams): regex = "(" + "|".join(choices) + ")" validate_regex_is_buildable(regex) elif so_params.grammar: - raise ValueError("Outlines structured outputs backend " - "does not support grammar specifications") + raise ValueError( + "Outlines structured outputs backend " + "does not support grammar specifications" + ) def _prefix_needs_context(parsed) -> bool: @@ -196,7 +199,7 @@ def _prefix_needs_context(parsed) -> bool: def subpattern_consumes(parsed) -> bool: """Return True if subpattern can consume at least one character.""" - tokens = parsed.data if hasattr(parsed, 'data') else parsed + tokens = parsed.data if hasattr(parsed, "data") else parsed for ttype, tval in tokens: # literal, character class, or dot always consumes if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY): @@ -212,17 +215,18 @@ def subpattern_consumes(parsed) -> bool: if any(subpattern_consumes(br) for br in branches): return True # grouped subpattern: recurse into its contents - elif ttype == sre_parse.SUBPATTERN and subpattern_consumes( - tval[3]): + elif ttype == sre_parse.SUBPATTERN and subpattern_consumes(tval[3]): return True # No consumers, return False return False - tokens = parsed.data if hasattr(parsed, 'data') else parsed + tokens = parsed.data if hasattr(parsed, "data") else parsed for ttype, tval in tokens: # Direct anchors or look-around - if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT, - sre_constants.ASSERT_NOT): + if ttype == sre_parse.AT or ttype in ( + sre_constants.ASSERT, + sre_constants.ASSERT_NOT, + ): return True # Nested subpattern: check @@ -261,9 +265,8 @@ def subpattern_consumes(parsed) -> bool: def _check_unsupported(parsed) -> None: """Check for regex features unsupported by regex-automata""" - tokens = parsed.data if hasattr(parsed, 'data') else parsed + tokens = parsed.data if hasattr(parsed, "data") else parsed for ttype, tval in tokens: - # backreference if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS): raise ValueError("Backreferences are unsupported.") @@ -274,8 +277,7 @@ def _check_unsupported(parsed) -> None: # unicode word boundaries elif ttype == sre_parse.AT: - if tval in (sre_constants.AT_BOUNDARY, - sre_constants.AT_NON_BOUNDARY): + if tval in (sre_constants.AT_BOUNDARY, sre_constants.AT_NON_BOUNDARY): raise ValueError("Unicode word boundaries are unsupported.") elif ttype == sre_parse.BRANCH: @@ -308,7 +310,8 @@ def validate_regex_is_buildable(pattern: str) -> None: raise ValueError( f"Regex uses unsupported feature for structured outputs: {e}. " "Only basic matching constructs are supported—lookarounds, " - "backreferences, and unicode boundaries are not.") from e + "backreferences, and unicode boundaries are not." 
+ ) from e if _prefix_needs_context(parsed): raise ValueError( @@ -317,4 +320,5 @@ def validate_regex_is_buildable(pattern: str) -> None: "in a way which requires context before any token is matched." "structured outputs needs regexes that can match without needing " "that context. Try rewriting the pattern without using these " - f"constructs. Pattern:\n{pattern}") + f"constructs. Pattern:\n{pattern}" + ) diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 9a53aa7a1ad1..2051b336e5bf 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -103,8 +103,9 @@ class StructuredOutputBackend(ABC): vocab_size: int @abstractmethod - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: """ Compiles a grammar specification into a structured output grammar. diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index a853e6540719..9f81d09633d7 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -14,12 +14,16 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader -from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, - StructuredOutputGrammar, - StructuredOutputOptions) -from vllm.v1.structured_output.utils import (choice_as_grammar, - convert_lark_to_ebnf, - grammar_is_likely_lark) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions, +) +from vllm.v1.structured_output.utils import ( + choice_as_grammar, + convert_lark_to_ebnf, + grammar_is_likely_lark, +) if TYPE_CHECKING: import xgrammar as xgr @@ -31,10 +35,10 @@ @dataclass class XgrammarBackend(StructuredOutputBackend): - def __post_init__(self): - self.disable_any_whitespace = \ + self.disable_any_whitespace = ( self.vllm_config.structured_outputs_config.disable_any_whitespace + ) if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. @@ -44,27 +48,33 @@ def __post_init__(self): encoded_vocab = self.tokenizer._vocab else: encoded_vocab = [ - token for token, _ in sorted( + token + for token, _ in sorted( self.tokenizer.get_vocab().items(), key=lambda x: x[1], ) ] stop_token_ids = None - if (hasattr( + if ( + hasattr( self.tokenizer, "eos_token_id", - ) and self.tokenizer.eos_token_id is not None): + ) + and self.tokenizer.eos_token_id is not None + ): stop_token_ids = [self.tokenizer.eos_token_id] except AttributeError as e: raise ValueError( f"Cannot get the vocabulary of the tokenizer " f"{type(self.tokenizer)}. The tokenizer should have a " - "get_vocab method.") from e + "get_vocab method." 
+ ) from e tokenizer_info = xgr.TokenizerInfo( # type: ignore encoded_vocab=encoded_vocab, # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 vocab_type=xgr.VocabType.RAW - if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, + if self.tokenizer.is_tekken + else xgr.VocabType.BYTE_FALLBACK, vocab_size=self.vocab_size, stop_token_ids=stop_token_ids, add_prefix_space=True, @@ -83,18 +93,21 @@ def __post_init__(self): self.num_speculative_tokens = 0 if self.vllm_config.speculative_config is not None: - self.num_speculative_tokens = \ + self.num_speculative_tokens = ( self.vllm_config.speculative_config.num_speculative_tokens + ) - def compile_grammar(self, request_type: StructuredOutputOptions, - grammar_spec: str) -> StructuredOutputGrammar: + def compile_grammar( + self, request_type: StructuredOutputOptions, grammar_spec: str + ) -> StructuredOutputGrammar: if request_type == StructuredOutputOptions.JSON: ctx = self.compiler.compile_json_schema( - grammar_spec, any_whitespace=not self.disable_any_whitespace) + grammar_spec, any_whitespace=not self.disable_any_whitespace + ) elif request_type == StructuredOutputOptions.JSON_OBJECT: ctx = self.compiler.compile_json_schema( - '{"type": "object"}', - any_whitespace=not self.disable_any_whitespace) + '{"type": "object"}', any_whitespace=not self.disable_any_whitespace + ) elif request_type == StructuredOutputOptions.GRAMMAR: ctx = self.compiler.compile_grammar(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: @@ -106,17 +119,20 @@ def compile_grammar(self, request_type: StructuredOutputOptions, begin=s["begin"], schema=json.dumps(s["schema"]), end=s["end"], - ) for s in s_tag["structures"] + ) + for s in s_tag["structures"] ] structural_tag = xgr.StructuralTag.from_legacy_structural_tag( - tags, s_tag["triggers"]) + tags, s_tag["triggers"] + ) ctx = self.compiler.compile_structural_tag(structural_tag) else: logger.error( "Validation should have already occurred. Please file an issue." ) raise ValueError( - f"grammar is not of valid supported types. ({request_type!s})") + f"grammar is not of valid supported types. ({request_type!s})" + ) return XgrammarGrammar( matcher=xgr.GrammarMatcher( @@ -146,10 +162,9 @@ class XgrammarGrammar(StructuredOutputGrammar): vocab_size: int matcher: xgr.GrammarMatcher = field(hash=False) ctx: xgr.CompiledGrammar = field(hash=False) - num_processed_tokens: int = field(default_factory=lambda: 0, - repr=False, - hash=False, - init=False) + num_processed_tokens: int = field( + default_factory=lambda: 0, repr=False, hash=False, init=False + ) _is_terminated: bool = field(default=False, repr=False, hash=False) def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: @@ -164,7 +179,10 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: if not self.matcher.accept_token(token): logger.error( "Failed to advance FSM for request %s " - "for tokens %s. Please file an issue.", request_id, token) + "for tokens %s. 
Please file an issue.", + request_id, + token, + ) return False self.num_processed_tokens += 1 self._is_terminated = self.matcher.is_terminated() @@ -216,8 +234,9 @@ def check_object(obj: dict[str, Any]) -> bool: # Check for array unsupported keywords if obj.get("type") == "array" and any( - key in obj for key in ("uniqueItems", "contains", - "minContains", "maxContains")): + key in obj + for key in ("uniqueItems", "contains", "minContains", "maxContains") + ): return True # Unsupported keywords for strings @@ -226,8 +245,14 @@ def check_object(obj: dict[str, Any]) -> bool: # Unsupported keywords for objects if obj.get("type") == "object" and any( - key in obj for key in ("minProperties", "maxProperties", - "propertyNames", "patternProperties")): + key in obj + for key in ( + "minProperties", + "maxProperties", + "propertyNames", + "patternProperties", + ) + ): return True # Recursively check all nested objects and arrays @@ -259,16 +284,18 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: try: xgr.Grammar.from_regex(so_params.regex) except Exception as err: - raise ValueError("Failed to transform regex into a grammar: " - f"{err}") from err + raise ValueError( + f"Failed to transform regex into a grammar: {err}" + ) from err if so_params.choice: choice_grammar = choice_as_grammar(so_params.choice) try: xgr.Grammar.from_ebnf(choice_grammar) except Exception as err: - raise ValueError("Failed to transform choices into a grammar: " - "{err}") from err + raise ValueError( + "Failed to transform choices into a grammar: {err}" + ) from err so_params.choice = None so_params.grammar = choice_grammar return @@ -285,12 +312,14 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: try: xgr.Grammar.from_json_schema(schema) except Exception as err: - raise ValueError("Failed to transform json schema into a grammar: " - f"{err}") from err + raise ValueError( + f"Failed to transform json schema into a grammar: {err}" + ) from err if has_xgrammar_unsupported_json_features(schema): - raise ValueError("The provided JSON schema contains features not " - "supported by xgrammar.") + raise ValueError( + "The provided JSON schema contains features not supported by xgrammar." + ) return if so_params.grammar: @@ -300,7 +329,8 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: so_params.grammar = convert_lark_to_ebnf(so_params.grammar) except ValueError as e: raise ValueError( - "Failed to convert the grammar from Lark to EBNF. ") from e + "Failed to convert the grammar from Lark to EBNF. 
" + ) from e # Test parsing EBNF grammar, possibly already converted from Lark try: @@ -318,10 +348,12 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: begin=s["begin"], schema=json.dumps(s["schema"]), end=s["end"], - ) for s in s_tag["structures"] + ) + for s in s_tag["structures"] ] structural_tag = xgr.StructuralTag.from_legacy_structural_tag( - tags, s_tag["triggers"]) + tags, s_tag["triggers"] + ) xgr.Grammar.from_structural_tag(structural_tag) except Exception as e: raise ValueError("Invalid structural tag specification.") from e diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index 99974ef46ecd..233c7c1e7805 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -10,18 +10,20 @@ from typing import Optional, Union, cast from vllm.sampling_params import SamplingParams -from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar, - StructuredOutputKey, - StructuredOutputOptions) +from vllm.v1.structured_output.backend_types import ( + StructuredOutputGrammar, + StructuredOutputKey, + StructuredOutputOptions, +) @dataclasses.dataclass class StructuredOutputRequest: - sampling_params: SamplingParams - _grammar: Optional[Union[Future[StructuredOutputGrammar], - StructuredOutputGrammar]] = None - reasoning_ended: Optional[bool] = None + _grammar: Union[Future[StructuredOutputGrammar], StructuredOutputGrammar] | None = ( + None + ) + reasoning_ended: bool | None = None def _check_grammar_completion(self) -> bool: # NOTE: We have to lazy import to gate circular imports @@ -41,15 +43,17 @@ def is_grammar_ready(self) -> bool: return self._check_grammar_completion() @property - def grammar(self) -> Optional[StructuredOutputGrammar]: + def grammar(self) -> StructuredOutputGrammar | None: completed = self._check_grammar_completion() - return cast(Optional[StructuredOutputGrammar], - self._grammar) if completed else None + return ( + cast(Optional[StructuredOutputGrammar], self._grammar) + if completed + else None + ) @grammar.setter def grammar( - self, grammar: Union[StructuredOutputGrammar, - Future[StructuredOutputGrammar]] + self, grammar: Union[StructuredOutputGrammar, Future[StructuredOutputGrammar]] ) -> None: self._grammar = grammar @@ -58,8 +62,7 @@ def structured_output_key(self) -> StructuredOutputKey: return get_structured_output_key(self.sampling_params) -def get_structured_output_key( - sampling_params: SamplingParams) -> StructuredOutputKey: +def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey: params = sampling_params.structured_outputs assert params is not None, "params can't be None." if params.json is not None: diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index b9b09bea1e80..b7326847d016 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -76,27 +76,31 @@ def apply_grammar_bitmask( for req_id, batch_index in seq: logit_index = batch_index + cumulative_offset cumulative_offset += len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + ) if req_id in scheduler_output.structured_output_request_ids: struct_out_req_batch_indices[req_id] = logit_index out_indices = [] # Reorder the bitmask to match the order of the requests in the batch. 
- sorted_bitmask = np.full(shape=(logits.shape[0], grammar_bitmask.shape[1]), - fill_value=-1, - dtype=grammar_bitmask.dtype) + sorted_bitmask = np.full( + shape=(logits.shape[0], grammar_bitmask.shape[1]), + fill_value=-1, + dtype=grammar_bitmask.dtype, + ) cumulative_index = 0 - seq = sorted(scheduler_output.structured_output_request_ids.items(), - key=lambda x: x[1]) + seq = sorted( + scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1] + ) for req_id, _ in seq: num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) + scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) + ) if req_id in struct_out_req_batch_indices: logit_index = struct_out_req_batch_indices[req_id] for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + i] = \ - grammar_bitmask[cumulative_index + i] + sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i] out_indices.append(logit_index + i) cumulative_index += 1 + num_spec_tokens grammar_bitmask = sorted_bitmask @@ -128,8 +132,7 @@ def __init__(self, vocabulary: oc.Vocabulary) -> None: self.inner = vocabulary # Have to do abs(hash()) because python hashes can # be negative, and we are using hash as a cache key. - hex_str = hashlib.sha256( - vocabulary.__repr__().encode('utf-8')).hexdigest() + hex_str = hashlib.sha256(vocabulary.__repr__().encode("utf-8")).hexdigest() hash_int = int(hex_str, 16) self._hash = hash_int @@ -165,16 +168,18 @@ def get_outlines_cache(): cache_dir = get_outlines_cache_path() if envs.VLLM_V1_USE_OUTLINES_CACHE: - logger.warning("Enabling outlines cache. This is an unbounded on-disk " - "cache. It may consume a lot of disk space and should " - "not be used with untrusted clients.") + logger.warning( + "Enabling outlines cache. This is an unbounded on-disk " + "cache. It may consume a lot of disk space and should " + "not be used with untrusted clients." + ) cache = Cache(cache_dir, eviction_policy="none", cull_limit=0) outlines_version = importlib.metadata.version("outlines_core") - cached_version = cache.get('__version__', None) + cached_version = cache.get("__version__", None) if cached_version != outlines_version: cache.clear() - cache.set('__version__', outlines_version) + cache.set("__version__", outlines_version) return cache else: return LRUCache(maxsize=128) @@ -194,19 +199,17 @@ def _reduced_vocabulary( A Dict of token string -> equivalent token ids """ - unicode_to_bytes = { - v: k - for k, v in tokenization_gpt2.bytes_to_unicode().items() - } + unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()} def convert_token_to_string(token: str) -> str: - string = tokenizer.convert_tokens_to_string([token]) # A hack to handle missing spaces to HF's Llama tokenizers - if (type(token) is str - and token.startswith(file_utils.SPIECE_UNDERLINE) - or token == "<0x20>"): + if ( + type(token) is str + and token.startswith(file_utils.SPIECE_UNDERLINE) + or token == "<0x20>" + ): return " " + string return string @@ -226,8 +229,7 @@ def convert_token_to_string(token: str) -> str: # by this point. token_bytes = bytes(token_str) # type: ignore[arg-type] - elif "\ufffd" in token_str and not re_replacement_seq.match( - token_str): + elif "\ufffd" in token_str and not re_replacement_seq.match(token_str): # Handle tokens with invalid UTF-8 sequences. if re_llama_byte_token.match(token): # Llama-like tokenizers use <0xXX> for incomplete sequences. 
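An aside on the byte-mapping logic in this hunk: the short sketch below is not part of the patch; it only illustrates how a token string can be mapped back to raw bytes the way _reduced_vocabulary does, by inverting the GPT-2 bytes_to_unicode table and special-casing Llama-style "<0xXX>" byte tokens. The regex and the sample tokens are illustrative assumptions, not values taken from vLLM.

# Illustrative sketch, not part of the diff. It reuses transformers' GPT-2
# byte-to-unicode helper (the same one the hunk above calls) and assumes a
# hypothetical pattern for Llama-style byte tokens such as "<0x20>".
import re
from transformers.models.gpt2 import tokenization_gpt2

unicode_to_bytes = {v: k for k, v in tokenization_gpt2.bytes_to_unicode().items()}
llama_byte_token = re.compile(r"^<0x[0-9A-Fa-f]{2}>$")  # assumed pattern

def token_to_bytes(token: str) -> bytes:
    if llama_byte_token.match(token):
        # "<0x20>" encodes the single byte 0x20 (a space).
        return bytes([int(token[3:5], 16)])
    byte_vals = [unicode_to_bytes.get(ch) for ch in token]
    if None in byte_vals:
        raise RuntimeError(f"Cannot convert token `{token}` to bytes")
    return bytes(byte_vals)  # type: ignore[arg-type]

# "\u0120" ("Ġ") is the GPT-2 stand-in for byte 0x20, i.e. a leading space.
assert token_to_bytes("\u0120") == b" "
assert token_to_bytes("<0x0A>") == b"\n"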
@@ -238,12 +240,13 @@ def convert_token_to_string(token: str) -> str: if None in byte_vals: raise RuntimeError( f"Cannot convert token `{token}`" - f" ({token_idx}) to bytes: {token_str}") + f" ({token_idx}) to bytes: {token_str}" + ) # safe to ignore, since if None in byte_vals, # an error is thrown. token_bytes = bytes(byte_vals) # type: ignore[arg-type] else: - token_bytes = token_str.encode('utf-8') + token_bytes = token_str.encode("utf-8") if token_idx != eos_token_id: vocabulary.setdefault(token_bytes, []).append(token_idx) @@ -254,16 +257,18 @@ def convert_token_to_string(token: str) -> str: def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary: - """Get the `Vocabulary` object for a given tokenizer. - """ + """Get the `Vocabulary` object for a given tokenizer.""" if hasattr(tokenizer, "_outlines_vocabulary"): return tokenizer._outlines_vocabulary # type: ignore try: - if hasattr( + if ( + hasattr( tokenizer, "eos_token_id", - ) and tokenizer.eos_token_id is not None: + ) + and tokenizer.eos_token_id is not None + ): eos_token_id = tokenizer.eos_token_id else: raise ValueError( @@ -272,17 +277,18 @@ def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary: reduced_vocab = _reduced_vocabulary( tokenizer, - eos_token_id #type: ignore + eos_token_id, # type: ignore ) - vocabulary = OutlinesVocabulary( - oc.Vocabulary(eos_token_id, reduced_vocab)) + vocabulary = OutlinesVocabulary(oc.Vocabulary(eos_token_id, reduced_vocab)) tokenizer._outlines_vocabulary = vocabulary # type: ignore return vocabulary except AttributeError as e: - raise ValueError(f"Cannot get the vocabulary of the tokenizer " - f"({type(tokenizer)}). The tokenizer should have a " - "get_vocab method.") from e + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"({type(tokenizer)}). The tokenizer should have a " + "get_vocab method." 
+ ) from e def grammar_is_likely_lark(grammar_str: str) -> bool: @@ -304,14 +310,14 @@ def grammar_is_likely_lark(grammar_str: str) -> bool: if not grammar_str or not isinstance(grammar_str, str): return False - for line in grammar_str.split('\n'): + for line in grammar_str.split("\n"): # Remove both comment styles - line = re.sub(r'(#|//).*$', '', line).strip() + line = re.sub(r"(#|//).*$", "", line).strip() if not line: continue # Look for EBNF rule definition - if '::=' in line: + if "::=" in line: return False return True @@ -348,40 +354,41 @@ def convert_lark_to_ebnf(grammar_str: str) -> str: def clean_line(line: str) -> str: """Remove comments and whitespace from line.""" - return re.sub(r'(#|//).*$', '', line).strip() + return re.sub(r"(#|//).*$", "", line).strip() def check_quotes(text: str, rule_name: str, line_num: int) -> None: """Validate quote matching in text.""" if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: - raise ValueError( - f"Mismatched quotes in {rule_name} on line {line_num}") + raise ValueError(f"Mismatched quotes in {rule_name} on line {line_num}") def extract_references(text: str) -> set[str]: """Extract rule references from text.""" # Remove quoted strings and special characters - text = re.sub(r'"[^"]*"', '', text) - text = re.sub(r'[+*?()|\[\]{}]', ' ', text) - return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + text = re.sub(r'"[^"]*"', "", text) + text = re.sub(r"[+*?()|\[\]{}]", " ", text) + return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", text)) # First pass: Find root rule and validate rule definitions - lines = [clean_line(line) for line in grammar_str.split('\n')] + lines = [clean_line(line) for line in grammar_str.split("\n")] first_rule = None for line_num, line in enumerate(lines, 1): - if not line or line.startswith('|'): + if not line or line.startswith("|"): continue - if ':' in line: + if ":" in line: try: - name = line.split(':', 1)[0].strip().strip('?') + name = line.split(":", 1)[0].strip().strip("?") defined_rules.add(name) if first_rule is None: first_rule = name - if name == 'start': - first_rule = 'start' + if name == "start": + first_rule = "start" except IndexError as e: - raise ValueError(f"Invalid rule format on line {line_num}. " - "Expected 'rule_name: definition'") from e + raise ValueError( + f"Invalid rule format on line {line_num}. 
" + "Expected 'rule_name: definition'" + ) from e if not defined_rules: raise ValueError("No valid rules found in grammar") @@ -398,29 +405,33 @@ def extract_references(text: str) -> set[str]: continue try: - if ':' in line and not line.startswith('|'): + if ":" in line and not line.startswith("|"): # Save previous rule if exists if current_rule: output_lines.append( - f"{current_rule} ::= {' | '.join(current_definition)}") + f"{current_rule} ::= {' | '.join(current_definition)}" + ) # Process new rule - name, definition = line.split(':', 1) - current_rule = name.strip().strip('?') + name, definition = line.split(":", 1) + current_rule = name.strip().strip("?") check_quotes(definition, f"rule '{current_rule}'", line_num) definition = re.sub(r"'([^']*)'", r'"\1"', definition) referenced_rules.update(extract_references(definition)) current_definition = [definition.strip()] - elif line.startswith('|'): + elif line.startswith("|"): if not current_rule: - raise ValueError(f"Alternative '|' on line {line_num} " - "without a preceding rule definition") + raise ValueError( + f"Alternative '|' on line {line_num} " + "without a preceding rule definition" + ) alt_def = line[1:].strip() - check_quotes(alt_def, f"alternative for rule '{current_rule}'", - line_num) + check_quotes( + alt_def, f"alternative for rule '{current_rule}'", line_num + ) alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) referenced_rules.update(extract_references(alt_def)) current_definition.append(alt_def) @@ -430,25 +441,24 @@ def extract_references(text: str) -> set[str]: # Add final rule if exists if current_rule: - output_lines.append( - f"{current_rule} ::= {' | '.join(current_definition)}") + output_lines.append(f"{current_rule} ::= {' | '.join(current_definition)}") # Validate all rules are defined - undefined_rules = referenced_rules - defined_rules - {'root'} + undefined_rules = referenced_rules - defined_rules - {"root"} if undefined_rules: - raise ValueError("Referenced rules are not defined: " - f"{', '.join(sorted(undefined_rules))}") + raise ValueError( + f"Referenced rules are not defined: {', '.join(sorted(undefined_rules))}" + ) - return '\n'.join(output_lines) + return "\n".join(output_lines) def choice_as_grammar(choice: list[str]) -> str: - def escape_ebnf_string(s: str) -> str: """Escape special characters in a EBNF string.""" # Escape double quotes and backslashes - return re.sub(r'(["\\])', r'\\\1', s) + return re.sub(r'(["\\])', r"\\\1", s) escaped_choices = (escape_ebnf_string(c) for c in choice) - grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices)) + grammar = "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices) return grammar diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index ee0c1168f3cd..925943262894 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -9,25 +9,35 @@ from contextlib import AbstractContextManager from multiprocessing import connection from multiprocessing.process import BaseProcess -from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, overload) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Optional, + TypeVar, + Union, + overload, +) import torch from torch.autograd.profiler import record_function import vllm.envs as envs from vllm.logger import init_logger -from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) -from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri, - kill_process_tree) +from vllm.usage.usage_lib import UsageContext, 
is_usage_stats_enabled, usage_message +from vllm.utils import ( + get_open_port, + get_open_zmq_ipc_path, + get_tcp_uri, + kill_process_tree, +) if TYPE_CHECKING: import numpy as np from vllm.v1.engine.coordinator import DPCoordinator - from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager) + from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager logger = init_logger(__name__) @@ -35,7 +45,6 @@ class ConstantList(Generic[T], Sequence): - def __init__(self, x: list[T]) -> None: self._x = x @@ -57,31 +66,23 @@ def remove(self, item): def clear(self): raise TypeError("Cannot clear a constant list") - def index(self, - item: T, - start: int = 0, - stop: Optional[int] = None) -> int: - return self._x.index(item, start, - stop if stop is not None else len(self._x)) + def index(self, item: T, start: int = 0, stop: Optional[int] = None) -> int: + return self._x.index(item, start, stop if stop is not None else len(self._x)) @overload - def __getitem__(self, item: int) -> T: - ... + def __getitem__(self, item: int) -> T: ... @overload - def __getitem__(self, s: slice, /) -> list[T]: - ... + def __getitem__(self, s: slice, /) -> list[T]: ... def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]: return self._x[item] @overload - def __setitem__(self, item: int, value: T): - ... + def __setitem__(self, item: int, value: T): ... @overload - def __setitem__(self, s: slice, value: T, /): - ... + def __setitem__(self, s: slice, value: T, /): ... def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]): raise TypeError("Cannot set item in a constant list") @@ -113,10 +114,7 @@ def __init__( pin_memory: bool, with_numpy: bool = True, ) -> None: - self.cpu = torch.zeros(*size, - dtype=dtype, - device="cpu", - pin_memory=pin_memory) + self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory) self.gpu = torch.zeros_like(self.cpu, device=device) self.np: np.ndarray # To keep type hints simple (avoiding generics and subclasses), we @@ -126,7 +124,8 @@ def __init__( if dtype == torch.bfloat16: raise ValueError( "Bfloat16 torch tensors cannot be directly cast to a " - "numpy array, so call CpuGpuBuffer with with_numpy=False") + "numpy array, so call CpuGpuBuffer with with_numpy=False" + ) self.np = self.cpu.numpy() def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor: @@ -142,9 +141,7 @@ def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor: return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True) -def get_engine_client_zmq_addr(local_only: bool, - host: str, - port: int = 0) -> str: +def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str: """Assign a new ZMQ socket address. 
If local_only is True, participants are colocated and so a unique IPC @@ -153,8 +150,11 @@ def get_engine_client_zmq_addr(local_only: bool, Otherwise, the provided host and port will be used to construct a TCP address (port == 0 means assign an available port).""" - return get_open_zmq_ipc_path() if local_only else (get_tcp_uri( - host, port or get_open_port())) + return ( + get_open_zmq_ipc_path() + if local_only + else (get_tcp_uri(host, port or get_open_port())) + ) class APIServerProcessManager: @@ -195,21 +195,23 @@ def __init__( spawn_context = multiprocessing.get_context("spawn") self.processes: list[BaseProcess] = [] - for i, in_addr, out_addr in zip(range(num_servers), input_addresses, - output_addresses): + for i, in_addr, out_addr in zip( + range(num_servers), input_addresses, output_addresses + ): client_config = { "input_address": in_addr, "output_address": out_addr, "client_count": num_servers, - "client_index": i + "client_index": i, } if stats_update_address is not None: client_config["stats_update_address"] = stats_update_address - proc = spawn_context.Process(target=target_server_fn, - name=f"ApiServer_{i}", - args=(listen_address, sock, args, - client_config)) + proc = spawn_context.Process( + target=target_server_fn, + name=f"ApiServer_{i}", + args=(listen_address, sock, args, client_config), + ) self.processes.append(proc) proc.start() @@ -224,10 +226,12 @@ def close(self) -> None: def wait_for_completion_or_failure( - api_server_manager: APIServerProcessManager, - engine_manager: Optional[Union["CoreEngineProcManager", - "CoreEngineActorManager"]] = None, - coordinator: Optional["DPCoordinator"] = None) -> None: + api_server_manager: APIServerProcessManager, + engine_manager: Optional[ + Union["CoreEngineProcManager", "CoreEngineActorManager"] + ] = None, + coordinator: Optional["DPCoordinator"] = None, +) -> None: """Wait for all processes to complete or detect if any fail. Raises an exception if any process exits with a non-zero status. @@ -240,16 +244,14 @@ def wait_for_completion_or_failure( coordinator: The coordinator for data parallel. 
""" - from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager) + from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager try: logger.info("Waiting for API servers to complete ...") # Create a mapping of sentinels to their corresponding processes # for efficient lookup sentinel_to_proc: dict[Any, BaseProcess] = { - proc.sentinel: proc - for proc in api_server_manager.processes + proc.sentinel: proc for proc in api_server_manager.processes } if coordinator: @@ -265,8 +267,7 @@ def wait_for_completion_or_failure( # Check if any process terminates while sentinel_to_proc or actor_run_refs: # Wait for any process to terminate - ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, - timeout=5) + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5) # Process any terminated processes for sentinel in ready_sentinels: @@ -276,17 +277,18 @@ def wait_for_completion_or_failure( if proc.exitcode != 0: raise RuntimeError( f"Process {proc.name} (PID: {proc.pid}) " - f"died with exit code {proc.exitcode}") + f"died with exit code {proc.exitcode}" + ) if actor_run_refs: import ray + _, actor_run_refs = ray.wait(actor_run_refs, timeout=5) except KeyboardInterrupt: logger.info("Received KeyboardInterrupt, shutting down API servers...") except Exception as e: - logger.exception("Exception occurred while running API servers: %s", - str(e)) + logger.exception("Exception occurred while running API servers: %s", str(e)) raise finally: logger.info("Terminating remaining processes ...") @@ -319,8 +321,9 @@ def shutdown(procs: list[BaseProcess]): kill_process_tree(pid) -def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, - length: int) -> torch.Tensor: +def copy_slice( + from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int +) -> torch.Tensor: """ Copy the first length elements of a tensor into another tensor in a non-blocking manner. 
@@ -333,8 +336,8 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, def report_usage_stats( - vllm_config, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None: + vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT +) -> None: """Report usage statistics if enabled.""" if not is_usage_stats_enabled(): @@ -342,37 +345,28 @@ def report_usage_stats( from vllm.model_executor.model_loader import get_architecture_class_name + parallel_config = vllm_config.parallel_config + usage_message.report_usage( get_architecture_class_name(vllm_config.model_config), usage_context, extra_kvs={ # Common configuration - "dtype": - str(vllm_config.model_config.dtype), - "tensor_parallel_size": - vllm_config.parallel_config.tensor_parallel_size, - "block_size": - vllm_config.cache_config.block_size, - "gpu_memory_utilization": - vllm_config.cache_config.gpu_memory_utilization, - "kv_cache_memory_bytes": - vllm_config.cache_config.kv_cache_memory_bytes, + "dtype": str(vllm_config.model_config.dtype), + "tensor_parallel_size": parallel_config.tensor_parallel_size, + "block_size": vllm_config.cache_config.block_size, + "gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization, + "kv_cache_memory_bytes": vllm_config.cache_config.kv_cache_memory_bytes, # Quantization - "quantization": - vllm_config.model_config.quantization, - "kv_cache_dtype": - str(vllm_config.cache_config.cache_dtype), - + "quantization": vllm_config.model_config.quantization, + "kv_cache_dtype": str(vllm_config.cache_config.cache_dtype), # Feature flags - "enable_lora": - bool(vllm_config.lora_config), - "enable_prefix_caching": - vllm_config.cache_config.enable_prefix_caching, - "enforce_eager": - vllm_config.model_config.enforce_eager, - "disable_custom_all_reduce": - vllm_config.parallel_config.disable_custom_all_reduce, - }) + "enable_lora": bool(vllm_config.lora_config), + "enable_prefix_caching": vllm_config.cache_config.enable_prefix_caching, + "enforce_eager": vllm_config.model_config.enforce_eager, + "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, + }, + ) _PROFILER_FUNC = None @@ -390,6 +384,7 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager: func = record_function elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING: import nvtx + func = nvtx.annotate _PROFILER_FUNC = func diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 82b6d1b514d5..4d3688453cb9 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -14,7 +14,6 @@ class BlockTable: - def __init__( self, block_size: int, @@ -31,13 +30,14 @@ def __init__( self.pin_memory = pin_memory self.device = device - self.block_table = self._make_buffer(max_num_reqs, - max_num_blocks_per_req, - dtype=torch.int32) + self.block_table = self._make_buffer( + max_num_reqs, max_num_blocks_per_req, dtype=torch.int32 + ) self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) - self.slot_mapping = self._make_buffer(self.max_num_batched_tokens, - dtype=torch.int64) + self.slot_mapping = self._make_buffer( + self.max_num_batched_tokens, dtype=torch.int64 + ) try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group @@ -56,7 +56,7 @@ def append_row( num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] self.num_blocks_per_row[row_idx] += num_blocks - self.block_table.np[row_idx, start:start + num_blocks] = block_ids + self.block_table.np[row_idx, start : start + num_blocks] = 
block_ids def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 @@ -73,8 +73,9 @@ def swap_row(self, src: int, tgt: int) -> None: self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src] self.block_table.np[src_tgt] = self.block_table.np[tgt_src] - def compute_slot_mapping(self, req_indices: np.ndarray, - positions: np.ndarray) -> None: + def compute_slot_mapping( + self, req_indices: np.ndarray, positions: np.ndarray + ) -> None: # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] # where K is the max_num_blocks_per_req and the block size is 2. @@ -89,8 +90,10 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # Use a "virtual block" which equals to world_size * block_size # for block_table_indices calculation. virtual_block_size = self.block_size * self.dcp_world_size - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions // virtual_block_size) + block_table_indices = ( + req_indices * self.max_num_blocks_per_req + + positions // virtual_block_size + ) block_numbers = self.block_table.np.ravel()[block_table_indices] # Use virtual_block_size for mask calculation, which marks local # tokens. @@ -101,16 +104,20 @@ def compute_slot_mapping(self, req_indices: np.ndarray, # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local - self.slot_mapping.np[:req_indices.shape[0]] = np.where( - mask, slot_mapping, -1) + self.slot_mapping.np[: req_indices.shape[0]] = np.where( + mask, slot_mapping, -1 + ) else: - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions // self.block_size) + block_table_indices = ( + req_indices * self.max_num_blocks_per_req + positions // self.block_size + ) block_numbers = self.block_table.np.ravel()[block_table_indices] block_offsets = positions % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping.np[:req_indices.shape[0]]) + np.add( + block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping.np[: req_indices.shape[0]], + ) def commit_block_table(self, num_reqs: int) -> None: self.block_table.copy_to_gpu(num_reqs) @@ -134,25 +141,27 @@ def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" return self.block_table.np - def _make_buffer(self, *size: Union[int, torch.SymInt], - dtype: torch.dtype) -> CpuGpuBuffer: - return CpuGpuBuffer(*size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory) + def _make_buffer( + self, *size: Union[int, torch.SymInt], dtype: torch.dtype + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, dtype=dtype, device=self.device, pin_memory=self.pin_memory + ) class MultiGroupBlockTable: """The BlockTables for each KV cache group.""" - def __init__(self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - pin_memory: bool, - device: torch.device, - block_sizes: list[int], - num_speculative_tokens: int = 0) -> None: + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + pin_memory: bool, + device: torch.device, + block_sizes: list[int], + num_speculative_tokens: int = 0, + ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, # so the block_size which used for calc max_num_blocks_per_req @@ -165,14 +174,20 @@ def __init__(self, self.block_tables = [ BlockTable( - block_size, 
max_num_reqs, - max(cdiv(max_model_len, block_size * dcp_world_size), - 1 + num_speculative_tokens), max_num_batched_tokens, - pin_memory, device) for block_size in block_sizes + block_size, + max_num_reqs, + max( + cdiv(max_model_len, block_size * dcp_world_size), + 1 + num_speculative_tokens, + ), + max_num_batched_tokens, + pin_memory, + device, + ) + for block_size in block_sizes ] - def append_row(self, block_ids: tuple[list[int], ...], - row_idx: int) -> None: + def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.append_row(block_ids[i], row_idx) @@ -188,8 +203,9 @@ def swap_row(self, src: int, tgt: int) -> None: for block_table in self.block_tables: block_table.swap_row(src, tgt) - def compute_slot_mapping(self, req_indices: np.ndarray, - positions: np.ndarray) -> None: + def compute_slot_mapping( + self, req_indices: np.ndarray, positions: np.ndarray + ) -> None: for block_table in self.block_tables: block_table.compute_slot_mapping(req_indices, positions) diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 6a97f7ebc3fc..f48b354e8a7d 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -9,7 +9,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1 from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -20,7 +19,6 @@ class CPUModelRunner(GPUModelRunner): - def __init__(self, vllm_config: VllmConfig, device: torch.device): with _torch_cuda_wrapper(): super().__init__(vllm_config, device) @@ -33,55 +31,18 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self._postprocess_tensors() + # Note: Remove the override after new attention backend finished def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: - """ - Update the order of requests in the batch based on the attention - backend's needs. For example, some attention backends (namely MLA) may - want to separate requests based on if the attention computation will be - compute-bound or memory-bound. - - Args: - scheduler_output: The scheduler output. - """ - # Attention free models have zero kv_cache_groups, however models - # like Mamba are also attention free but use the kv_cache for - # keeping its internal state. This is why we check the number - # of kv_cache groups instead of solely checking - # for self.model_config.is_attention_free. - if len(self.kv_cache_config.kv_cache_groups) == 0: - return - if len(self.kv_cache_config.kv_cache_groups) > 1: - raise ValueError("Multiple KVCacheGroups is not" - "currently supported with CPU model runner.") - - # Guard against encoder-only / pooling models where `attn_groups` - # may be empty or lack the expected metadata_builder. - # Without this check, accessing `attn_groups[0][0]` would trigger - # an AssertionError on CPU backend. 
-        if not hasattr(self, "attn_groups") or not self.attn_groups:
-            return
-        if not self.attn_groups[0]:
-            return
-
-        mb = getattr(self.attn_groups[0][0], "metadata_builders", None)
-        if isinstance(mb, list):
-            if not isinstance(mb[0], TorchSDPAMetadataBuilderV1):
-                return
-            mb[0].reorder_batch(self.input_batch, scheduler_output)
-            return
-        elif not isinstance(mb, TorchSDPAMetadataBuilderV1):
-            # Encoder-only / rerank models do not benefit from reordering,
-            # so we safely skip here.
-            return
-
-        # Safe path for decoder/attention-heavy models
-        mb.reorder_batch(self.input_batch, scheduler_output)
+        if len(self.kv_cache_config.kv_cache_groups) > 1:
+            raise ValueError(
+                "Multiple KVCacheGroups is not "
+                "currently supported with CPU model runner."
+            )
+        super()._may_reorder_batch(scheduler_output)
 
     def _postprocess_tensors(self) -> None:
         # Note: replace device tensors with cpu tensors
-        def replace_tensor(obj: Any, cpu_attr_name: str,
-                           device_attr_name) -> None:
+        def replace_tensor(obj: Any, cpu_attr_name: str, device_attr_name) -> None:
             cpu_tensor = getattr(obj, cpu_attr_name, None)
             device_tensor = getattr(obj, device_attr_name, None)
             if cpu_tensor is not None and device_tensor is not None:
@@ -107,8 +68,7 @@ def load_model(self, eep_scale_up: bool = False) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
 
         if self.lora_config:
-            self.model = self.load_lora_model(self.model, self.vllm_config,
-                                              self.device)
+            self.model = self.load_lora_model(self.model, self.vllm_config, self.device)
 
     def get_model(self) -> nn.Module:
         return self.model
@@ -129,23 +89,19 @@ def _sync_device(self) -> None:
     def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
         return sampled_token_ids.tolist()
 
-    def get_dp_padding(self,
-                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
+    def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
         # Note: For CPU backend, dp padding is not required for now.
return 0, None @contextmanager def _torch_cuda_wrapper(): - class _EventPlaceholder: - def __init__(self, *args, **kwargs) -> None: self.record = lambda: None self.synchronize = lambda: None class _StreamPlaceholder: - def __init__(self, *args, **kwargs) -> None: pass diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index c6a686d6b75e..ee865ec8e649 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -13,25 +13,27 @@ from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo from vllm.v1.worker.cpu_model_runner import CPUModelRunner -from vllm.v1.worker.gpu_worker import (Worker, - init_worker_distributed_environment) +from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment logger = init_logger(__name__) class CPUWorker(Worker): - - def __init__(self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False): - super().__init__(vllm_config, - local_rank, - rank, - distributed_init_method, - is_driver_worker=is_driver_worker) + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + super().__init__( + vllm_config, + local_rank, + rank, + distributed_init_method, + is_driver_worker=is_driver_worker, + ) self.parallel_config.disable_custom_all_reduce = True @@ -43,11 +45,13 @@ def init_device(self): if cpu_arch in (CpuArchEnum.POWERPC, CpuArchEnum.S390X): # For S390X/POWERPC SMT-8/4/2 self.local_omp_cpuid = self._get_autobind_cpu_ids( - lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]) + lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4] + ) elif current_platform.get_cpu_architecture() == CpuArchEnum.X86: # For x86 SMT-2, use 1 CPU per core self.local_omp_cpuid = self._get_autobind_cpu_ids( - lambda cpus: cpus[-1:]) + lambda cpus: cpus[-1:] + ) else: self.local_omp_cpuid = "all" else: @@ -55,9 +59,9 @@ def init_device(self): omp_cpuids = omp_cpuids.split("|") if local_dp_rank is not None: world_size = self.parallel_config.world_size - omp_cpuids = omp_cpuids[local_dp_rank * - world_size:(local_dp_rank + 1) * - world_size] + omp_cpuids = omp_cpuids[ + local_dp_rank * world_size : (local_dp_rank + 1) * world_size + ] self.local_omp_cpuid = omp_cpuids[self.rank] if self.local_omp_cpuid != "all": @@ -66,19 +70,22 @@ def init_device(self): logger.info(ret) # Note: unique identifier for creating allreduce shared memory - os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split( - ":")[-1] + os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1] # Initialize the distributed environment. - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) # Set random seed. 
set_random_seed(self.model_config.seed) # Construct the model runner self.model_runner: CPUModelRunner = CPUModelRunner( - self.vllm_config, torch.device("cpu")) + self.vllm_config, torch.device("cpu") + ) def sleep(self, level: int = 1) -> None: logger.warning("sleep mode is not supported on CPU, ignore it.") @@ -98,31 +105,31 @@ def compile_or_warm_up_model(self) -> None: self.model_runner.warming_up_model() def _get_autobind_cpu_ids( - self, cpu_selector: Callable[[list[LogicalCPUInfo]], - list[LogicalCPUInfo]] + self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]] ) -> str: """ - Return CPU ids to bind based on NUMA nodes. - Currently for rank N, only CPU ids on the N-th node in available NUMA + Return CPU ids to bind based on NUMA nodes. + Currently for rank N, only CPU ids on the N-th node in available NUMA node list will be selected. Args: - cpu_selector: a callable object to select CPUs from a CPU list + cpu_selector: a callable object to select CPUs from a CPU list of a physical core. The input is a LogicalCPUInfo list, sorted by - the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be + the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be returned. """ - allowed_numa_nodes, logical_cpu_list = \ + allowed_numa_nodes, logical_cpu_list = ( CpuPlatform.get_allowed_cpu_core_node_list() + ) assert len(allowed_numa_nodes) >= self.parallel_config.world_size, ( f"No enough allowed NUMA nodes to bind threads of " f"{self.parallel_config.world_size} CPUWorkers. " f"Allowed NUMA nodes are {allowed_numa_nodes}. " - "Please try to bind threads manually.") + "Please try to bind threads manually." + ) # Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`` - selected_numa_node = allowed_numa_nodes[ - self.local_rank] # type: ignore + selected_numa_node = allowed_numa_nodes[self.local_rank] # type: ignore logical_cpu_list = [ x for x in logical_cpu_list if x.numa_node == selected_numa_node ] @@ -142,15 +149,20 @@ def _get_autobind_cpu_ids( # Reserve CPUs for other processes reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU if reserve_cpu_num is None: - need_reserve = (self.parallel_config.world_size > 1 or - self.parallel_config.data_parallel_size_local > 1) + need_reserve = ( + self.parallel_config.world_size > 1 + or self.parallel_config.data_parallel_size_local > 1 + ) reserve_cpu_num = 1 if need_reserve else 0 assert len(logical_cpu_list) > reserve_cpu_num, ( f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) " - f"should less than {len(logical_cpu_list)}.") + f"should less than {len(logical_cpu_list)}." 
+ ) if reserve_cpu_num != 0: logical_cpu_list = logical_cpu_list[:-reserve_cpu_num] - logger.info("auto thread-binding list (id, physical core): %s", - [(x.id, x.physical_core) for x in logical_cpu_list]) + logger.info( + "auto thread-binding list (id, physical core): %s", + [(x.id, x.physical_core) for x in logical_cpu_list], + ) return ",".join([str(x.id) for x in logical_cpu_list]) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 67fb9864b19c..06f935423662 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -7,18 +7,19 @@ import numpy as np import torch -from typing_extensions import deprecated from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, - LogitsProcessors, - MoveDirectionality) +from vllm.v1.sample.logits_processor import ( + BatchUpdateBuilder, + LogitsProcessors, + MoveDirectionality, +) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice @@ -27,7 +28,6 @@ @dataclass class CachedRequestState: - req_id: str prompt_token_ids: Optional[list[int]] mm_features: list[MultiModalFeatureSpec] @@ -47,28 +47,20 @@ class CachedRequestState: def __post_init__(self): self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - self.prompt_token_ids, self.prompt_embeds) + self.prompt_token_ids, self.prompt_embeds + ) @property def num_tokens(self) -> int: return self.num_prompt_tokens + len(self.output_token_ids) - # Temporary back-compatibility for plugins that define model runner - @property - @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " - "removed in v0.13. Please use `mm_kwargs` instead.") - def mm_inputs(self) -> list[MultiModalKwargsItems]: - return [ - MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features - if f.data is not None - ] - def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: if self.prompt_token_ids is None: raise ValueError( f"Tried to access token index {idx}, but that token was " - "provided via prompt_embeds, and its ID is unknown.") + "provided via prompt_embeds, and its ID is unknown." + ) return self.prompt_token_ids[idx] elif idx - self.num_prompt_tokens < len(self.output_token_ids): return self.output_token_ids[idx - self.num_prompt_tokens] @@ -77,7 +69,6 @@ def get_token_id(self, idx: int) -> int: class InputBatch: - def __init__( self, max_num_reqs: int, @@ -115,10 +106,9 @@ def __init__( pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() - self.is_token_ids = torch.zeros((max_num_reqs, max_model_len), - device="cpu", - dtype=bool, - pin_memory=False) + self.is_token_ids = torch.zeros( + (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False + ) # Store prompt embeddings per request to avoid OOM from large upfront # allocation if max_model_len is big. 
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) @@ -127,13 +117,12 @@ def __init__( self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( - (max_num_reqs, ), + (max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory, ) - self.num_computed_tokens_cpu = \ - self.num_computed_tokens_cpu_tensor.numpy() + self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy() # Block table. self.block_table = MultiGroupBlockTable( @@ -147,34 +136,27 @@ def __init__( ) # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.temperature = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) + self.temperature_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.temperature_cpu = self.temperature_cpu_tensor.numpy() self.greedy_reqs: set[str] = set() self.random_reqs: set[str] = set() - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device) + self.top_p_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.top_p_cpu = self.top_p_cpu_tensor.numpy() self.top_p_reqs: set[str] = set() - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device) + self.top_k_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: set[str] = set() @@ -182,54 +164,43 @@ def __init__( self.spec_decode_unsupported_reqs: set[str] = set() # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.frequency_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: set[str] = set() # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + self.presence_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) + self.presence_penalties_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory ) + self.presence_penalties_cpu = 
self.presence_penalties_cpu_tensor.numpy() self.presence_penalties_reqs: set[str] = set() # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.repetition_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() # Speculative decoding - self.num_accepted_tokens_cpu_tensor = torch.ones((max_num_reqs, ), - dtype=torch.int64, - device="cpu", - pin_memory=pin_memory) - self.num_accepted_tokens_cpu = \ - self.num_accepted_tokens_cpu_tensor.numpy() + self.num_accepted_tokens_cpu_tensor = torch.ones( + (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory + ) + self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() # lora related - self.request_lora_mapping = np.zeros((self.max_num_reqs, ), - dtype=np.int32) + self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32) self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {} @@ -261,8 +232,7 @@ def __init__( # req_index -> bad_words_token_ids self.bad_words_token_ids: dict[int, list[list[int]]] = {} - self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, - dtype=bool) + self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool) self.req_output_token_ids: list[Optional[list[int]]] = [] @@ -302,8 +272,13 @@ def _register_add_request(self, request: "CachedRequestState") -> int: # Detailed added request metadata is only required for non-pooling # models, to support logitsprocs. self.batch_update_builder.added.append( - (new_req_index, request.sampling_params, - request.prompt_token_ids, request.output_token_ids)) + ( + new_req_index, + request.sampling_params, + request.prompt_token_ids, + request.output_token_ids, + ) + ) return new_req_index @@ -325,20 +300,19 @@ def add_request( # Copy the prompt token ids and output token ids. num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - request.prompt_token_ids, request.prompt_embeds) + request.prompt_token_ids, request.prompt_embeds + ) self.num_prompt_tokens[req_index] = num_prompt_tokens start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) if request.prompt_token_ids is not None: - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids + self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids self.is_token_ids[req_index, :num_prompt_tokens] = True else: self.is_token_ids[req_index, :num_prompt_tokens] = False if request.prompt_embeds is not None: self.req_prompt_embeds[req_index] = request.prompt_embeds - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids + self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids self.is_token_ids[req_index, start_idx:end_idx] = True # Number of token ids in prompt (token_ids_cpu or prompt_embeds). # NOTE(woosuk): This may include spec decode tokens. 
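The sampling state above (temperature, top_p/top_k, penalties, accepted token counts) consistently keeps a CPU torch tensor alongside a NumPy view obtained via .numpy(), so cheap per-request scalar writes go through NumPy and the populated slice is later copied to the matching GPU tensor. A small sketch of that aliasing, with hypothetical names and pin_memory omitted, not taken from this patch:

import torch

max_num_reqs = 8
temperature_cpu_tensor = torch.empty(max_num_reqs, dtype=torch.float32)
temperature_cpu = temperature_cpu_tensor.numpy()  # a view: shares storage, no copy

temperature_cpu[0] = 0.5                        # scalar write through NumPy
assert temperature_cpu_tensor[0].item() == 0.5  # the same memory, seen from torch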
@@ -350,8 +324,7 @@ def add_request( self.block_table.add_row(request.block_ids, req_index) if sampling_params := request.sampling_params: - if (self.is_spec_decode - and is_spec_decode_unsupported(sampling_params)): + if self.is_spec_decode and is_spec_decode_unsupported(sampling_params): self.spec_decode_unsupported_reqs.add(req_id) if sampling_params.sampling_type == SamplingType.GREEDY: # Should avoid division by zero later when apply_temperature. @@ -370,16 +343,15 @@ def add_request( else: top_k = self.vocab_size self.top_k_cpu[req_index] = top_k - self.frequency_penalties_cpu[ - req_index] = sampling_params.frequency_penalty + self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[ - req_index] = sampling_params.presence_penalty + self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty if sampling_params.presence_penalty != 0.0: self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[ - req_index] = sampling_params.repetition_penalty + self.repetition_penalties_cpu[req_index] = ( + sampling_params.repetition_penalty + ) if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) @@ -389,13 +361,17 @@ def add_request( self.generators[req_index] = request.generator if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = (self.vocab_size - if sampling_params.logprobs == -1 - else sampling_params.logprobs) + self.num_logprobs[req_id] = ( + self.vocab_size + if sampling_params.logprobs == -1 + else sampling_params.logprobs + ) if sampling_params.prompt_logprobs is not None: self.num_prompt_logprobs[req_id] = ( - self.vocab_size if sampling_params.prompt_logprobs == -1 - else sampling_params.prompt_logprobs) + self.vocab_size + if sampling_params.prompt_logprobs == -1 + else sampling_params.prompt_logprobs + ) if sampling_params.allowed_token_ids: self.has_allowed_token_ids.add(req_id) @@ -406,24 +382,29 @@ def add_request( self.max_num_reqs, self.vocab_size, dtype=torch.bool, - device=self.device) + device=self.device, + ) self.allowed_token_ids_mask_cpu_tensor = torch.zeros( self.max_num_reqs, self.vocab_size, dtype=torch.bool, - device="cpu") + device="cpu", + ) self.allowed_token_ids_mask_cpu_tensor[req_index] = True # False means we don't fill with -inf. 
self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = False + sampling_params.allowed_token_ids + ] = False if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids + self.bad_words_token_ids[req_index] = ( + sampling_params.bad_words_token_ids + ) elif pooling_params := request.pooling_params: self.pooling_params[req_id] = pooling_params self.logits_processing_needs_token_ids[req_index] = ( - pooling_params.requires_token_ids) + pooling_params.requires_token_ids + ) else: raise NotImplementedError("Unrecognized request type") @@ -500,21 +481,32 @@ def remove_request(self, req_id: str) -> Optional[int]: def swap_states(self, i1: int, i2: int) -> None: old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] - self._req_ids[i1], self._req_ids[i2] =\ - self._req_ids[i2], self._req_ids[i1] # noqa - self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ - self.req_output_token_ids[i2], self.req_output_token_ids[i1] + self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa + self.req_output_token_ids[i1], self.req_output_token_ids[i2] = ( + self.req_output_token_ids[i2], + self.req_output_token_ids[i1], + ) assert old_id_i1 is not None and old_id_i2 is not None - self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ - self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] - self.num_tokens[i1], self.num_tokens[i2] =\ - self.num_tokens[i2], self.num_tokens[i1] - self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ - self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] - self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ - self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] - self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ - self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] + self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = ( + self.req_id_to_index[old_id_i2], + self.req_id_to_index[old_id_i1], + ) + self.num_tokens[i1], self.num_tokens[i2] = ( + self.num_tokens[i2], + self.num_tokens[i1], + ) + self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = ( + self.num_tokens_no_spec[i2], + self.num_tokens_no_spec[i1], + ) + self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = ( + self.num_prompt_tokens[i2], + self.num_prompt_tokens[i1], + ) + self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = ( + self.num_computed_tokens_cpu[i2], + self.num_computed_tokens_cpu[i1], + ) # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -541,8 +533,10 @@ def swap_states(self, i1: int, i2: int) -> None: self.block_table.swap_row(i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = ( + self.request_lora_mapping[i2], + self.request_lora_mapping[i1], + ) if self.is_pooling_model: # Sampling and logits parameters don't apply to pooling models. @@ -550,32 +544,42 @@ def swap_states(self, i1: int, i2: int) -> None: # For autoregressive models, track detailed request reordering info # to support logitsprocs. 
- self.batch_update_builder.moved.append( - (i1, i2, MoveDirectionality.SWAP)) - - self.temperature_cpu[i1], self.temperature_cpu[i2] = \ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] = \ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] = \ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] - self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\ - self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1] + self.batch_update_builder.moved.append((i1, i2, MoveDirectionality.SWAP)) + + self.temperature_cpu[i1], self.temperature_cpu[i2] = ( + self.temperature_cpu[i2], + self.temperature_cpu[i1], + ) + self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = ( + self.frequency_penalties_cpu[i2], + self.frequency_penalties_cpu[i1], + ) + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = ( + self.presence_penalties_cpu[i2], + self.presence_penalties_cpu[i1], + ) + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = ( + self.repetition_penalties_cpu[i2], + self.repetition_penalties_cpu[i1], + ) + self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] = ( + self.num_accepted_tokens_cpu[i2], + self.num_accepted_tokens_cpu[i1], + ) swap_dict_values(self.generators, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[i1], \ - self.allowed_token_ids_mask_cpu_tensor[i2] =\ - self.allowed_token_ids_mask_cpu_tensor[i2], \ - self.allowed_token_ids_mask_cpu_tensor[i1] + ( + self.allowed_token_ids_mask_cpu_tensor[i1], + self.allowed_token_ids_mask_cpu_tensor[i2], + ) = ( + self.allowed_token_ids_mask_cpu_tensor[i2], + self.allowed_token_ids_mask_cpu_tensor[i1], + ) def condense(self) -> None: """Slide non-empty requests down into lower, empty indices. 
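swap_states and condense above keep the persistent batch arrays compact by moving rows rather than reallocating: when requests are removed, the last live row is slid down into each freed slot. A toy, hypothetical version of that compaction (simplified to the case where the trailing rows are still live; not vLLM's actual implementation):

import numpy as np

req_ids = ["a", "b", "c", "d"]         # row index -> request id
num_tokens = np.array([3, 5, 2, 7])    # one of several per-row attributes
empty_rows = [1]                       # request "b" was removed
num_active = len(req_ids) - len(empty_rows)

last = len(req_ids) - 1
for hole in sorted(empty_rows):
    if hole >= num_active:
        break                           # hole already lies past the active region
    req_ids[hole] = req_ids[last]       # slide the last live row into the hole
    num_tokens[hole] = num_tokens[last]
    last -= 1

print(req_ids[:num_active], num_tokens[:num_active])  # ['a', 'd', 'c'] [3 7 2]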
@@ -627,23 +631,28 @@ def condense(self) -> None: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ - last_req_index, :num_tokens] + last_req_index, :num_tokens + ] self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ - last_req_index, :num_tokens] + last_req_index, :num_tokens + ] if last_req_index in self.req_prompt_embeds: - self.req_prompt_embeds[ - empty_index] = self.req_prompt_embeds.pop(last_req_index) + self.req_prompt_embeds[empty_index] = self.req_prompt_embeds.pop( + last_req_index + ) self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ - last_req_index] - self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] + last_req_index + ] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index] + self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[ + last_req_index + ] self.block_table.move_row(last_req_index, empty_index) self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] + last_req_index + ] if self.is_pooling_model: last_req_index -= 1 @@ -653,33 +662,35 @@ def condense(self) -> None: # Autoregressive models require detailed tracking of condense # operations to support logitsprocs self.batch_update_builder.moved.append( - (last_req_index, empty_index, - MoveDirectionality.UNIDIRECTIONAL)) + (last_req_index, empty_index, MoveDirectionality.UNIDIRECTIONAL) + ) - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[ - empty_index] = self.frequency_penalties_cpu[last_req_index] - self.presence_penalties_cpu[ - empty_index] = self.presence_penalties_cpu[last_req_index] - self.repetition_penalties_cpu[ - empty_index] = self.repetition_penalties_cpu[last_req_index] - self.num_accepted_tokens_cpu[ - empty_index] = self.num_accepted_tokens_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[ + last_req_index + ] + self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[ + last_req_index + ] + self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[ + last_req_index + ] + self.num_accepted_tokens_cpu[empty_index] = self.num_accepted_tokens_cpu[ + last_req_index + ] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator # TODO convert these to LogitsProcessors if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[ - empty_index] = self.allowed_token_ids_mask_cpu_tensor[ - last_req_index] + self.allowed_token_ids_mask_cpu_tensor[empty_index] = ( + self.allowed_token_ids_mask_cpu_tensor[last_req_index] + ) - bad_words_token_ids = self.bad_words_token_ids.pop( - last_req_index, None) + bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None) if bad_words_token_ids is not None: self.bad_words_token_ids[empty_index] = bad_words_token_ids @@ -711,8 +722,9 @@ def refresh_metadata(self): def _make_sampling_metadata(self) -> SamplingMetadata: num_reqs = self.num_reqs if not self.all_greedy: - temperature = 
copy_slice(self.temperature_cpu_tensor, - self.temperature, num_reqs) + temperature = copy_slice( + self.temperature_cpu_tensor, self.temperature, num_reqs + ) else: temperature = None if not self.no_top_p: @@ -724,16 +736,22 @@ def _make_sampling_metadata(self) -> SamplingMetadata: # Since syncing these tensors is expensive only copy them # if necessary i.e. if there are requests which require # penalties to be applied during sampling. - copy_slice(self.frequency_penalties_cpu_tensor, - self.frequency_penalties, num_reqs) - copy_slice(self.presence_penalties_cpu_tensor, - self.presence_penalties, num_reqs) - copy_slice(self.repetition_penalties_cpu_tensor, - self.repetition_penalties, num_reqs) + copy_slice( + self.frequency_penalties_cpu_tensor, self.frequency_penalties, num_reqs + ) + copy_slice( + self.presence_penalties_cpu_tensor, self.presence_penalties, num_reqs + ) + copy_slice( + self.repetition_penalties_cpu_tensor, + self.repetition_penalties, + num_reqs, + ) needs_prompt_token_ids = ( not self.no_penalties - or self.logits_processing_needs_token_ids[:num_reqs].any()) + or self.logits_processing_needs_token_ids[:num_reqs].any() + ) if needs_prompt_token_ids: # The prompt tokens are used only for applying penalties or # step pooling during the sampling/pooling process. @@ -746,8 +764,11 @@ def _make_sampling_metadata(self) -> SamplingMetadata: allowed_token_ids_mask: Optional[torch.Tensor] = None if not self.no_allowed_token_ids: assert self.allowed_token_ids_mask is not None - copy_slice(self.allowed_token_ids_mask_cpu_tensor, - self.allowed_token_ids_mask, num_reqs) + copy_slice( + self.allowed_token_ids_mask_cpu_tensor, + self.allowed_token_ids_mask, + num_reqs, + ) allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] return SamplingMetadata( @@ -777,8 +798,7 @@ def get_pooling_metadata(self) -> PoolingMetadata: pooling_params = self.get_pooling_params() return PoolingMetadata( - prompt_lens=torch.from_numpy( - self.num_prompt_tokens[:self.num_reqs]), + prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]), prompt_token_ids=self.sampling_metadata.prompt_token_ids, pooling_params=pooling_params, ) @@ -797,9 +817,8 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: # Use the value of vocab_size as a pad since we don't have a # token_id of this value. for i in range(num_reqs): - prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size - return prompt_token_ids_cpu_tensor.to(device=self.device, - non_blocking=True) + prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) def make_lora_inputs( self, num_scheduled_tokens: np.ndarray @@ -815,12 +834,12 @@ def make_lora_inputs( 3. lora_requests: Set of relevant LoRA requests. 
""" - req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + req_lora_mapping = self.request_lora_mapping[: self.num_reqs] prompt_lora_mapping = tuple(req_lora_mapping) - token_lora_mapping = tuple( - req_lora_mapping.repeat(num_scheduled_tokens)) + token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens)) active_lora_requests: set[LoRARequest] = set( - self.lora_id_to_lora_request.values()) + self.lora_id_to_lora_request.values() + ) return prompt_lora_mapping, token_lora_mapping, active_lora_requests @@ -846,9 +865,11 @@ def no_top_k(self) -> bool: @property def no_penalties(self) -> bool: - return (len(self.presence_penalties_reqs) == 0 - and len(self.frequency_penalties_reqs) == 0 - and len(self.repetition_penalties_reqs) == 0) + return ( + len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0 + ) @property def max_num_logprobs(self) -> Optional[int]: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ff95acf0c016..5cbbe435a789 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -24,70 +24,102 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config, update_config) +from vllm.config import ( + CompilationLevel, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, + update_config, +) from vllm.distributed.eplb.eplb_state import EplbState -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, graph_capture, is_global_first_rank, - prepare_communication_buffer_for_model) -from vllm.forward_context import (BatchDescriptor, DPMetadata, - set_forward_context) + get_pp_group, + get_tp_group, + graph_capture, + is_global_first_rank, + prepare_communication_buffer_for_model, +) +from vllm.forward_context import BatchDescriptor, DPMetadata, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - is_mixture_of_experts, - supports_eagle3, - supports_mrope, - supports_multimodal_pruning, - supports_transcription) -# yapf: enable +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + is_mixture_of_experts, + supports_eagle3, + supports_mrope, + supports_multimodal_pruning, + supports_transcription, +) from vllm.model_executor.models.interfaces_base import ( - VllmModelForPooling, is_pooling_model, is_text_generation_model) + VllmModelForPooling, + is_pooling_model, + is_text_generation_model, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, 
MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, cdiv, check_use_alibi, get_dtype_size, - is_pin_memory_available, - length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo) +from vllm.utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + DeviceMemoryProfiler, + GiB_bytes, + cdiv, + check_use_alibi, + get_dtype_size, + is_pin_memory_available, + length_from_prompt_token_ids_or_embeds, + round_up, + supports_dynamo, +) from vllm.utils.jsontree import json_map_leaves from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, create_fast_prefill_custom_backend, - reorder_batch_to_split_decodes_and_prefills, split_attn_metadata) + reorder_batch_to_split_decodes_and_prefills, + split_attn_metadata, +) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher -# yapf conflicts with isort for this block -# yapf: disable -from vllm.v1.kv_cache_interface import (AttentionSpec, - ChunkedLocalAttentionSpec, - CrossAttentionSpec, - EncoderOnlyAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - MambaSpec, MLAAttentionSpec, - SlidingWindowSpec, - UniformTypeKVCacheSpecs) -# yapf: enable -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - DraftTokenIds, LogprobsLists, LogprobsTensors, - ModelRunnerOutput, PoolerOutput, SamplerOutput) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + DraftTokenIds, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, + PoolerOutput, + SamplerOutput, +) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -101,18 +133,21 @@ from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper -from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin) +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds, - ubatch_split) +from vllm.v1.worker.ubatch_splitting import check_ubatch_thresholds, ubatch_split from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices from vllm.v1.worker.utils import is_residual_scattered_for_sp -from .utils import (AttentionGroup, MultiModalBudget, - 
add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, - gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders) +from .utils import ( + AttentionGroup, + MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, + gather_mm_placeholders, + sanity_check_mm_encoder_outputs, + scatter_mm_placeholders, +) if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -122,13 +157,11 @@ AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled -PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], - AttnMetadataDict] +PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], AttnMetadataDict] # Wrapper for ModelRunnerOutput to support overlapped execution. class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): - def __init__( self, model_runner_output: ModelRunnerOutput, @@ -151,12 +184,13 @@ def __init__( with torch.cuda.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self._sampled_token_ids_cpu = self._sampled_token_ids.to( - 'cpu', non_blocking=True) + "cpu", non_blocking=True + ) self._async_copy_ready_event.record() def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. - + This function blocks until the copy is finished. """ self._async_copy_ready_event.synchronize() @@ -174,7 +208,6 @@ def get_output(self) -> ModelRunnerOutput: class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( self, vllm_config: VllmConfig, @@ -192,10 +225,10 @@ def __init__( self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) - from vllm.model_executor.layers.batch_invariant import ( - init_batch_invariance) + + set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) + from vllm.model_executor.layers.batch_invariant import init_batch_invariance + init_batch_invariance() model_config = self.model_config @@ -208,13 +241,13 @@ def __init__( if cache_config.cache_dtype == "auto": self.kv_cache_dtype = self.dtype else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - self.is_pooling_model = (model_config.runner_type == 'pooling') + self.is_pooling_model = model_config.runner_type == "pooling" self.enable_prompt_embeds = model_config.enable_prompt_embeds self.is_multimodal_raw_input_only_model = ( - model_config.is_multimodal_raw_input_only_model) + model_config.is_multimodal_raw_input_only_model + ) # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len @@ -227,12 +260,12 @@ def __init__( # TODO: Support overlapping mirco-batches # https://github.com/vllm-project/vllm/issues/18019 self.broadcast_pp_output = ( - self.parallel_config.distributed_executor_backend - == "external_launcher" and len(get_pp_group().ranks) > 0) + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) # Model-related. 
- self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size # Only relevant for models using ALiBi (e.g, MPT) @@ -244,13 +277,13 @@ def __init__( self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config) + model_config + ) if self.model_config.is_encoder_decoder: # Maximum length of the encoder input, only for encoder-decoder # models. - self.max_encoder_len = scheduler_config.\ - max_num_encoder_input_tokens + self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens else: self.max_encoder_len = 0 @@ -284,17 +317,18 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, - self) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( - vllm_config=self.vllm_config, - device=self.device) # type: ignore + vllm_config=self.vllm_config, device=self.device + ) # type: ignore else: - raise ValueError("Unknown speculative decoding method: " - f"{self.speculative_config.method}") + raise ValueError( + "Unknown speculative decoding method: " + f"{self.speculative_config.method}" + ) self.rejection_sampler = RejectionSampler() # Request states. @@ -322,58 +356,64 @@ def __init__( block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( - self.vllm_config, self.device, self.pin_memory, + self.vllm_config, + self.device, + self.pin_memory, self.is_pooling_model, - self.vllm_config.model_config.logits_processors), + self.vllm_config.model_config.logits_processors, + ), is_pooling_model=self.is_pooling_model, ) self.use_async_scheduling = self.scheduler_config.async_scheduling - self.async_output_copy_stream = torch.cuda.Stream() if \ - self.use_async_scheduling else None + self.async_output_copy_stream = ( + torch.cuda.Stream() if self.use_async_scheduling else None + ) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. - if self.compilation_config.cudagraph_capture_sizes and \ - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + if ( + self.compilation_config.cudagraph_capture_sizes + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): self.cudagraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes)) + reversed(self.compilation_config.cudagraph_capture_sizes) + ) # Cache the device properties. self._init_device_properties() # Persistent buffers for CUDA graphs. 
- self.input_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, - dtype=torch.int64) - self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, - dtype=torch.int32) + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.query_start_loc = self._make_buffer( + self.max_num_reqs + 1, dtype=torch.int32 + ) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. - self.inputs_embeds = self._make_buffer(self.max_num_tokens, - self.hidden_size, - dtype=self.dtype, - numpy=False) - self.is_token_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.bool) - self.discard_request_indices = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + self.inputs_embeds = self._make_buffer( + self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False + ) + self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.discard_request_indices = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) self.num_discarded_requests = 0 - self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) - self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + self.num_decode_draft_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.num_accepted_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) # Only relevant for multimodal models if self.supports_mm_inputs: - self.is_mm_embed = self._make_buffer(self.max_num_tokens, - dtype=torch.bool) + self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -388,7 +428,8 @@ def __init__( # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = self._make_buffer( - (3, self.max_num_tokens + 1), dtype=torch.int64) + (3, self.max_num_tokens + 1), dtype=torch.int64 + ) # CUDA event to synchronize use of reused CPU tensors between steps # when async scheduling is enabled. @@ -403,10 +444,10 @@ def __init__( # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(max(self.max_num_reqs + 1, - self.max_model_len, - self.max_num_tokens), - dtype=np.int64) + self.arange_np = np.arange( + max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + dtype=np.int64, + ) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -418,19 +459,27 @@ def __init__( self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: self.kv_sharing_fast_prefill_logits_indices = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=self.device) + self.max_num_tokens, dtype=torch.int32, device=self.device + ) - self.uniform_decode_query_len = 1 if not self.speculative_config else \ - 1 + self.speculative_config.num_speculative_tokens + self.uniform_decode_query_len = ( + 1 + if not self.speculative_config + else 1 + self.speculative_config.num_speculative_tokens + ) # Cudagraph dispatcher for runtime cudagraph dispatching. 
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) - self.mm_budget = MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) if self.supports_mm_inputs else None + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) self.reorder_batch_threshold: Optional[int] = None @@ -440,14 +489,14 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. - self._draft_token_ids: Optional[Union[list[list[int]], - torch.Tensor]] = None + self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_model_len, 1), dtype=torch.int64, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) def _get_positions(self, num_tokens: Any): if isinstance(num_tokens, int): @@ -459,15 +508,16 @@ def _get_positions(self, num_tokens: Any): return self.mrope_positions.gpu[:, num_tokens] return self.positions.gpu[num_tokens] - def _make_buffer(self, - *size: Union[int, torch.SymInt], - dtype: torch.dtype, - numpy: bool = True) -> CpuGpuBuffer: - return CpuGpuBuffer(*size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory, - with_numpy=numpy) + def _make_buffer( + self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy, + ) def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() @@ -480,9 +530,11 @@ def _init_model_kwargs(self, num_tokens: int): token_type_id_requests = dict[int, Any]() for i, param in enumerate(pooling_params): - if param.extra_kwargs is not None and \ - (token_types := param.extra_kwargs.get( - "compressed_token_type_ids")) is not None: + if ( + param.extra_kwargs is not None + and (token_types := param.extra_kwargs.get("compressed_token_type_ids")) + is not None + ): token_type_id_requests[i] = token_types if len(token_type_id_requests) == 0: @@ -497,7 +549,8 @@ def _init_model_kwargs(self, num_tokens: int): token_type_ids.append(ids) model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( - device=self.device) + device=self.device + ) return model_kwargs def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: @@ -523,17 +576,18 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # required for DCP with q_len > 1, so we assert here. Remove this # assert once the custom mask is support is added to FA3. if self.dcp_world_size > 1: - assert self.reorder_batch_threshold == 1, \ + assert self.reorder_batch_threshold == 1, ( "DCP not support reorder_batch_threshold > 1 now." + ) reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, - decode_threshold=self.reorder_batch_threshold) + decode_threshold=self.reorder_batch_threshold, + ) # Note: used for model runner override. 
def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties - """ + """Initialize attributes from torch.cuda.get_device_properties""" self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count @@ -589,8 +643,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if sampling_params and \ - sampling_params.sampling_type == SamplingType.RANDOM_SEED: + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: @@ -647,14 +703,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: new_token_ids = req_data.new_token_ids[i] # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec tokens. - num_new_tokens = (num_computed_tokens + len(new_token_ids) - - req_state.num_tokens) + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) if num_new_tokens == 1: # Avoid slicing list in most common case. req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) elif num_output_tokens < len(req_state.output_token_ids): # Some output tokens were discarded due to a sync-KV-load # failure. Align the cached state. @@ -662,21 +718,22 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is not None: - old_end_idx = self.input_batch.num_tokens_no_spec[ - req_index] - end_idx = self.input_batch.num_prompt_tokens[ - req_index] + num_output_tokens + old_end_idx = self.input_batch.num_tokens_no_spec[req_index] + end_idx = ( + self.input_batch.num_prompt_tokens[req_index] + + num_output_tokens + ) self.input_batch.num_tokens[req_index] = end_idx self.input_batch.num_tokens_no_spec[req_index] = end_idx - self.input_batch.is_token_ids[req_index, - end_idx:old_end_idx] = False + self.input_batch.is_token_ids[req_index, end_idx:old_end_idx] = ( + False + ) # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: assert new_block_ids is not None @@ -693,11 +750,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. 
@@ -706,21 +761,22 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = new_token_ids - self.input_batch.num_tokens_no_spec[ - req_index] = end_token_index + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, () + ) if spec_token_ids: num_spec_tokens = len(spec_token_ids) start_index = self.input_batch.num_tokens_no_spec[req_index] end_token_index = start_index + num_spec_tokens self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index] = spec_token_ids + req_index, start_index:end_token_index + ] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec tokens. self.input_batch.num_tokens[req_index] += num_spec_tokens @@ -737,7 +793,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _update_states_after_model_execute( - self, output_token_ids: torch.Tensor) -> None: + self, output_token_ids: torch.Tensor + ) -> None: """Update the cached states after model execution. This is used for MTP/EAGLE for hybrid models, as in linear attention, @@ -750,14 +807,26 @@ def _update_states_after_model_execute( return # Find the number of accepted tokens for each sequence. - num_accepted_tokens = (torch.cat( - [ - output_token_ids, - torch.full((output_token_ids.size(0), 1), - -1, - device=output_token_ids.device), - ], - dim=1) == -1).int().argmax(-1).cpu().numpy() + num_accepted_tokens = ( + ( + torch.cat( + [ + output_token_ids, + torch.full( + (output_token_ids.size(0), 1), + -1, + device=output_token_ids.device, + ), + ], + dim=1, + ) + == -1 + ) + .int() + .argmax(-1) + .cpu() + .numpy() + ) for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens @@ -784,7 +853,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): use_audio_in_video = True if supports_mrope(self.model): - req_state.mrope_positions, req_state.mrope_position_delta = \ + req_state.mrope_positions, req_state.mrope_position_delta = ( self.model.get_mrope_input_positions( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, @@ -794,8 +863,9 @@ def _init_mrope_positions(self, req_state: CachedRequestState): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + ) else: - req_state.mrope_positions, req_state.mrope_position_delta = \ + req_state.mrope_positions, req_state.mrope_position_delta = ( MRotaryEmbedding.get_input_positions_tensor( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, @@ -805,6 +875,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + ) def _extract_mm_kwargs( self, @@ -823,10 +894,10 @@ def _extract_mm_kwargs( model = cast(SupportsMultiModal, self.model) mm_kwargs_combined: BatchedTensorInputs = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - 
merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): mm_kwargs_combined.update(mm_kwargs_group) @@ -862,10 +933,11 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange - def _prepare_input_ids(self, total_num_scheduled_tokens: int, - cu_num_tokens: np.ndarray) -> None: + def _prepare_input_ids( + self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray + ) -> None: """Prepare the input IDs for the current batch. - + Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the GPU need to be copied into the corresponding slots into input_ids.""" @@ -894,7 +966,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # last token in each common request. flattened_index = cu_num_tokens[cur_index].item() - 1 flattened_indices.append(flattened_index) - indices_match &= (prev_index == flattened_index) + indices_match &= prev_index == flattened_index max_flattened_index = max(max_flattened_index, flattened_index) num_commmon_tokens = len(flattened_indices) if num_commmon_tokens < total_num_scheduled_tokens: @@ -914,28 +986,27 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # The indices are both the same permutation of 0..N-1 so # we can copy directly using a single slice. self.input_ids.gpu[:num_commmon_tokens].copy_( - self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, - 0], - non_blocking=True) + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], + non_blocking=True, + ) if self.enable_prompt_embeds: self.is_token_ids.gpu[:num_commmon_tokens] = True return # Upload the index tensors asynchronously # so the scatter can be non-blocking. - input_ids_index_tensor = torch.tensor(flattened_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to( - self.device, - non_blocking=True) + input_ids_index_tensor = torch.tensor( + flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( - prev_common_req_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to(self.device, non_blocking=True) + prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, index=input_ids_index_tensor, src=self.input_batch.prev_sampled_token_ids[ - prev_common_req_indices_tensor, 0]) + prev_common_req_indices_tensor, 0 + ], + ) def _get_encoder_seq_lens( self, @@ -957,10 +1028,17 @@ def _get_encoder_seq_lens( def _prepare_inputs( self, scheduler_output: "SchedulerOutput" - ) -> tuple[PerLayerAttnMetadata, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray, - Optional[CommonAttentionMetadata], int, Optional[UBatchSlices], - Optional[torch.Tensor], bool]: + ) -> tuple[ + PerLayerAttnMetadata, + torch.Tensor, + Optional[SpecDecodeMetadata], + np.ndarray, + Optional[CommonAttentionMetadata], + int, + Optional[UBatchSlices], + Optional[torch.Tensor], + bool, + ]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -986,19 +1064,19 @@ def _prepare_inputs( # Get request indices. 
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -1009,24 +1087,28 @@ def _prepare_inputs( # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) token_indices_tensor = torch.from_numpy(token_indices) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - token_indices_tensor, - out=self.input_ids.cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + token_indices_tensor, + out=self.input_ids.cpu[:total_num_scheduled_tokens], + ) if self.enable_prompt_embeds: is_token_ids = self.input_batch.is_token_ids.flatten() torch.index_select( is_token_ids, 0, token_indices_tensor, - out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + out=self.is_token_ids.cpu[:total_num_scheduled_tokens], + ) # Because we did not pre-allocate a massive prompt_embeds CPU tensor on # the InputBatch, we need to fill in the prompt embeds into the expected @@ -1060,52 +1142,49 @@ def _prepare_inputs( actual_num_sched = actual_end - start_pos if actual_num_sched > 0: - self.inputs_embeds.cpu[output_idx:output_idx + - actual_num_sched].copy_( - req_embeds[start_pos:actual_end] - ) + self.inputs_embeds.cpu[ + output_idx : output_idx + actual_num_sched + ].copy_(req_embeds[start_pos:actual_end]) output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) # Prepare the attention metadata. 
self.query_start_loc.np[0] = 0 - self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens # Note: pad query_start_loc to be non-decreasing, as kernels # like FlashAttention requires that - self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) self.query_start_loc.copy_to_gpu() - query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] + query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens num_tokens_padded = num_tokens_unpadded + self.get_local_padding( - num_tokens_unpadded) - uniform_decode = \ - (max_num_scheduled_tokens == self.uniform_decode_query_len) and \ - (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - ubatch_slices, num_tokens_after_padding = \ - ubatch_split(num_scheduled_tokens, - num_tokens_unpadded, - num_tokens_padded, - uniform_decode=uniform_decode, - vllm_config=self.vllm_config) + num_tokens_unpadded + ) + uniform_decode = ( + max_num_scheduled_tokens == self.uniform_decode_query_len + ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + ubatch_slices, num_tokens_after_padding = ubatch_split( + num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded, + uniform_decode=uniform_decode, + vllm_config=self.vllm_config, + ) self.seq_lens.np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) + self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + ) # Fill unused with 0 for full cuda graph mode. self.seq_lens.np[num_reqs:].fill(0) self.seq_lens.copy_to_gpu() seq_lens = self.seq_lens.gpu[:num_reqs] max_seq_len = self.seq_lens.np[:num_reqs].max().item() - num_tokens = [ - self.requests[r].num_tokens for r in self.input_batch.req_ids - ] + num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) # Record the index of requests that should not be sampled, @@ -1113,8 +1192,9 @@ def _prepare_inputs( discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) - self.discard_request_indices.np[:self.num_discarded_requests] = ( - discard_request_indices) + self.discard_request_indices.np[: self.num_discarded_requests] = ( + discard_request_indices + ) self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) @@ -1125,13 +1205,13 @@ def _prepare_inputs( # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self.mrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True) + non_blocking=True, + ) else: # Common case (1D positions) self.positions.copy_to_gpu(total_num_scheduled_tokens) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -1149,27 +1229,35 @@ def _prepare_inputs( # For chunked prefills, use -1 as mask rather than 0, as guided # decoding may rollback speculative tokens. 
num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) - num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if ( - self.input_batch.num_computed_tokens_cpu[req_idx] - >= self.input_batch.num_prompt_tokens[req_idx]) else -1) + num_decode_draft_tokens[req_idx] = ( + len(draft_token_ids) + if ( + self.input_batch.num_computed_tokens_cpu[req_idx] + >= self.input_batch.num_prompt_tokens[req_idx] + ) + else -1 + ) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) + num_draft_tokens, cu_num_tokens + ) logits_indices = spec_decode_metadata.logits_indices # For DECODE only cuda graph of some attention backends (e.g., GDN). - self.num_decode_draft_tokens.np[: - num_reqs] = num_decode_draft_tokens + self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.copy_to_gpu() logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: logits_indices_padded = self._prepare_kv_sharing_fast_prefill( - logits_indices) + logits_indices + ) attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: @@ -1177,26 +1265,29 @@ def _prepare_inputs( use_cascade_attn = False # Used in the below loop. - query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1] seq_lens_cpu = self.seq_lens.cpu[:num_reqs] - num_computed_tokens_cpu = ( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ] spec_decode_common_attn_metadata = None if use_spec_decode: self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): encoder_seq_lens = self._get_encoder_seq_lens( - scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs) + scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs + ) - if isinstance(kv_cache_group_spec.kv_cache_spec, - EncoderOnlyAttentionSpec): + if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to # create a dummy block table and slot mapping for them. blk_table_tensor = torch.zeros( @@ -1205,7 +1296,7 @@ def _prepare_inputs( device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens, ), + (total_num_scheduled_tokens,), dtype=torch.int64, device=self.device, ) @@ -1213,16 +1304,14 @@ def _prepare_inputs( else: blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor(num_reqs) - slot_mapping = blk_table.slot_mapping.gpu[: - total_num_scheduled_tokens] + slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. 
- blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_( - -1) - num_common_prefix_blocks = ( - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id]) + blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[ + kv_cache_group_id + ] common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -1242,11 +1331,12 @@ def _prepare_inputs( encoder_seq_lens=encoder_seq_lens, ) - if (self.speculative_config - and spec_decode_common_attn_metadata is None): + if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): - if (self.drafter.attn_layer_names[0] - in kv_cache_group_spec.layer_names): + if ( + self.drafter.attn_layer_names[0] + in kv_cache_group_spec.layer_names + ): spec_decode_common_attn_metadata = common_attn_metadata else: spec_decode_common_attn_metadata = common_attn_metadata @@ -1264,24 +1354,27 @@ def _prepare_inputs( ) extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, - GDNAttentionMetadataBuilder): + if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens. - gpu[:num_reqs], - num_decode_draft_tokens_cpu=self. - num_decode_draft_tokens.cpu[:num_reqs], + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ + :num_reqs + ], ) if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata) + ubatch_slices, common_attn_metadata + ) for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list): - attn_metadata_i = (attn_group.get_metadata_builder( - ubatch_id=ubid).build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata)) + common_attn_metadata_list + ): + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + ) for layer_name in kv_cache_group_spec.layer_names: assert type(attn_metadata) is list attn_metadata[ubid][layer_name] = attn_metadata_i @@ -1290,9 +1383,9 @@ def _prepare_inputs( attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", - False) + **extra_attn_metadata_args, + ) + use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1304,10 +1397,17 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens, ubatch_slices, - num_tokens_after_padding, use_cascade_attn) + return ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens, + spec_decode_common_attn_metadata, + max_num_scheduled_tokens, + ubatch_slices, + num_tokens_after_padding, + use_cascade_attn, + ) def _compute_cascade_attn_prefix_len( self, @@ -1379,18 +1479,20 @@ def _compute_cascade_attn_prefix_len( # this case. 
num_reqs = len(num_scheduled_tokens) common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min() + ) # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * - kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - use_local_attention = ( - isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) - or (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None)) + common_prefix_len = ( + common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size + ) + use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None + ) + use_local_attention = isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None + ) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -1410,18 +1512,15 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): req = self.requests[req_id] assert req.mrope_positions is not None - num_computed_tokens = \ - self.input_batch.num_computed_tokens_cpu[index] - num_scheduled_tokens = \ - scheduler_output.num_scheduled_tokens[req_id] + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - req.prompt_token_ids, req.prompt_embeds) + req.prompt_token_ids, req.prompt_embeds + ) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: - prompt_part_len = max(0, - num_prompt_tokens - num_computed_tokens) - completion_part_len = max( - 0, num_scheduled_tokens - prompt_part_len) + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, num_scheduled_tokens - prompt_part_len) else: prompt_part_len = num_scheduled_tokens completion_part_len = 0 @@ -1435,8 +1534,9 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions.cpu[:, dst_start:dst_end] = ( - req.mrope_positions[:, src_start:src_end]) + self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[ + :, src_start:src_end + ] mrope_pos_ptr += prompt_part_len if completion_part_len > 0: @@ -1476,10 +1576,12 @@ def _calc_spec_decode_metadata( # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( - num_sampled_tokens, cumsum_dtype=np.int32) + num_sampled_tokens, cumsum_dtype=np.int32 + ) # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( - cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens + ) # Step 3. 
[0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange @@ -1490,22 +1592,28 @@ def _calc_spec_decode_metadata( # cu_num_draft_tokens: [3, 3, 5, 5, 6] # arange: [0, 1, 2, 0, 1, 0] cu_num_draft_tokens, arange = self._get_cumsum_and_arange( - num_draft_tokens, cumsum_dtype=np.int32) + num_draft_tokens, cumsum_dtype=np.int32 + ) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens + ) # [0, 1, 2, 5, 6, 9] target_logits_indices += arange # TODO: Optimize the CPU -> GPU copy. cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( - self.device, non_blocking=True) - logits_indices = torch.from_numpy(logits_indices).to(self.device, - non_blocking=True) + self.device, non_blocking=True + ) + logits_indices = torch.from_numpy(logits_indices).to( + self.device, non_blocking=True + ) target_logits_indices = torch.from_numpy(target_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] @@ -1529,23 +1637,26 @@ def _prepare_kv_sharing_fast_prefill( assert self.kv_sharing_fast_prefill_logits_indices is not None num_logits = logits_indices.shape[0] assert num_logits > 0 - self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( - logits_indices) + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(logits_indices) # There might have leftover indices in logits_indices[num_logits:] # from previous iterations, whose values may be greater than the # batch size in the current iteration. To ensure indices are always # valid, we fill the padded indices with the last index. self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item()) - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1]): + logits_indices[-1].item() + ) + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_logits <= self.cudagraph_batch_sizes[-1] + ): # Use piecewise CUDA graphs. # Add padding to the batch size. num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) else: num_logits_padded = num_logits - logits_indices_padded = ( - self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) + logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ + :num_logits_padded + ] return logits_indices_padded def _batch_mm_kwargs_from_scheduler( @@ -1584,7 +1695,8 @@ def _batch_mm_kwargs_from_scheduler( def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # Batch the multi-modal inputs using the helper method. 
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( - scheduler_output) + scheduler_output + ) if not mm_kwargs: return @@ -1599,10 +1711,10 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): model = cast(SupportsMultiModal, self.model) encoder_outputs = [] for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # (ekhvedchenia): Temporary hack to limit peak memory usage when # processing multimodal data.This solves the issue with scheduler @@ -1616,11 +1728,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): micro_batch_size = 1 for i in range(0, num_items, micro_batch_size): micro_batch_mm_inputs = dict( - (k, v[i:i + micro_batch_size]) - for k, v in mm_kwargs_group.items()) + (k, v[i : i + micro_batch_size]) + for k, v in mm_kwargs_group.items() + ) micro_batch_outputs = model.get_multimodal_embeddings( - **micro_batch_mm_inputs) + **micro_batch_mm_inputs + ) curr_group_outputs.extend(micro_batch_outputs) else: @@ -1631,8 +1745,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.get_multimodal_embeddings( - **mm_kwargs_group) + curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -1664,11 +1777,9 @@ def _gather_mm_embeddings( for req_id in self.input_batch.req_ids: mm_embeds_req: list[torch.Tensor] = [] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] req_state = self.requests[req_id] - num_computed_tokens = \ - req_state.num_computed_tokens + shift_computed_tokens + num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position @@ -1696,15 +1807,15 @@ def _gather_mm_embeddings( mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None,\ - f"Encoder cache miss for {mm_hash}." + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." 
if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ - = True if is_embed is None else is_embed + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True if is_embed is None else is_embed + ) mm_embeds_item = gather_mm_placeholders( encoder_output[start_idx:end_idx], @@ -1721,7 +1832,8 @@ def _gather_mm_embeddings( multimodal_embeddings=mm_embeds_req, mrope_positions=req_state.mrope_positions, num_computed_tokens=req_state.num_computed_tokens, - )) + ) + ) req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_position_delta = new_delta @@ -1755,10 +1867,10 @@ def _extract_encoder_inputs( model = cast(SupportsMultiModal, self.model) encoder_features = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # Add the grouped features to encoder_features dict # This allows the model to receive them as kwargs (e.g., @@ -1795,21 +1907,24 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: supported_tasks = list(model.pooler.get_supported_tasks()) - if (self.scheduler_config.chunked_prefill_enabled - and "encode" in supported_tasks): + if ( + self.scheduler_config.chunked_prefill_enabled + and "encode" in supported_tasks + ): supported_tasks.remove("encode") - logger.debug_once("Chunked prefill is not supported with " - "encode task which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it.") + logger.debug_once( + "Chunked prefill is not supported with " + "encode task which using ALL pooling. " + "Please turn off chunked prefill by " + "`--no-enable-chunked-prefill` before using it." 
+ ) if "score" in supported_tasks: num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: supported_tasks.remove("score") - logger.debug_once( - "Score API is only enabled for num_labels == 1.") + logger.debug_once("Score API is only enabled for num_labels == 1.") return supported_tasks @@ -1824,9 +1939,11 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return tuple(tasks) def sync_and_slice_intermediate_tensors( - self, num_tokens: int, intermediate_tensors: IntermediateTensors, - sync_self: bool) -> IntermediateTensors: - + self, + num_tokens: int, + intermediate_tensors: IntermediateTensors, + sync_self: bool, + ) -> IntermediateTensors: assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size @@ -1838,21 +1955,21 @@ def sync_and_slice_intermediate_tensors( assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): is_scattered = k == "residual" and is_rs - copy_len = num_tokens // tp if is_scattered else \ - num_tokens + copy_len = num_tokens // tp if is_scattered else num_tokens self.intermediate_tensors[k][:copy_len].copy_( - v[:copy_len], non_blocking=True) - - return IntermediateTensors({ - k: - v[:num_tokens // - tp] if k == "residual" and is_rs else v[:num_tokens] - for k, v in self.intermediate_tensors.items() - }) - - def eplb_step(self, - is_dummy: bool = False, - is_profile: bool = False) -> None: + v[:copy_len], non_blocking=True + ) + + return IntermediateTensors( + { + k: v[: num_tokens // tp] + if k == "residual" and is_rs + else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + } + ) + + def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. """ @@ -1869,8 +1986,7 @@ def eplb_step(self, log_stats=self.parallel_config.eplb_config.log_balancedness, ) - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: """ Determines the total number of tokens that each rank will run. All ranks will be padded out so that they run with the same number @@ -1897,31 +2013,33 @@ def get_dp_padding(self, return 0, None num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank) + num_tokens, dp_size, dp_rank + ) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) + num_tokens_after_padding = torch.tensor( + [max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32 + ) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding def get_local_padding(self, num_tokens_unpadded: int) -> int: - num_tokens_padded = num_tokens_unpadded - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1] + ): # Use piecewise CUDA graphs. # Add padding to the batch size. - num_tokens_padded = self.vllm_config.pad_for_cudagraph( - num_tokens_unpadded) + num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded) else: # Eager mode. 
# Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.vllm_config.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: + if ( + self.vllm_config.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): num_tokens_padded = round_up(num_tokens_unpadded, tp_size) num_pad_tokens = num_tokens_padded - num_tokens_unpadded @@ -1931,12 +2049,13 @@ def get_local_padding(self, num_tokens_unpadded: int) -> int: # Should be called after attention metadata creation. This just pads # the second ubatch slice out to the total number of tokens # (num_tokens + padding) - def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, - num_total_tokens: int): - padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, - num_total_tokens) - ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, - padded_second_ubatch_slice) + def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int): + padded_second_ubatch_slice = slice( + ubatch_slices[1].token_slice.start, num_total_tokens + ) + ubatch_slices[1] = UBatchSlice( + padded_second_ubatch_slice, padded_second_ubatch_slice + ) def _pool( self, @@ -1944,16 +2063,16 @@ def _pool( num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs ==\ - len(self.input_batch.pooling_params), \ - "Either all or none of the requests in" \ - " a batch must be pooling request" + assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), ( + "Either all or none of the requests in a batch must be pooling request" + ) hidden_states = hidden_states[:num_scheduled_tokens] pooling_metadata = self.input_batch.get_pooling_metadata() - pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), - device=hidden_states.device) - seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + pooling_metadata.build_pooling_cursor( + num_scheduled_tokens_np.tolist(), device=hidden_states.device + ) + seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] model = cast(VllmModelForPooling, self.model) raw_pooler_output: PoolerOutput = model.pooler( @@ -1968,8 +2087,8 @@ def _pool( pooler_output: list[Optional[torch.Tensor]] = [] for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens + ): output = raw_output if seq_len == prompt_len else None pooler_output.append(output) @@ -1983,11 +2102,13 @@ def _pool( ) def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH - and hasattr(self, "cudagraph_batch_sizes") - and self.cudagraph_batch_sizes - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH + and hasattr(self, "cudagraph_batch_sizes") + and self.cudagraph_batch_sizes + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] + ): # Use CUDA graphs. # Add padding to the batch size. 
return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) @@ -1996,8 +2117,10 @@ def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if (self.compilation_config.pass_config.enable_sequence_parallelism - and tp_size > 1): + if ( + self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens @@ -2007,10 +2130,16 @@ def _preprocess( intermediate_tensors: Optional[IntermediateTensors] = None, ubatch_slices: Optional[UBatchSlices] = None, num_tokens_after_padding: Optional[torch.Tensor] = None, - ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, - Optional[IntermediateTensors], dict[str, Any]]: - + ) -> tuple[ + int, + int, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + Optional[IntermediateTensors], + dict[str, Any], + ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if ubatch_slices: assert num_tokens_after_padding is not None @@ -2018,18 +2147,19 @@ def _preprocess( self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) elif ubatch_slices is None: num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) - num_pad, num_tokens_after_padding = self.get_dp_padding( - num_input_tokens) + num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens) num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if (self.supports_mm_inputs and get_pp_group().is_first_rank - and not self.model_config.is_encoder_decoder): + if ( + self.supports_mm_inputs + and get_pp_group().is_first_rank + and not self.model_config.is_encoder_decoder + ): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds, is_mm_embed = self._gather_mm_embeddings( - scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) @@ -2041,8 +2171,7 @@ def _preprocess( ) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds.gpu[:num_scheduled_tokens].copy_( - inputs_embeds_scheduled) + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2063,14 +2192,15 @@ def _preprocess( # If a batch only has token ids, then including the embedding layer # in the CUDA graph will be more performant (like in the else case # below). 
- token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \ - .nonzero(as_tuple=False) \ + token_ids_idx = ( + self.is_token_ids.gpu[:num_scheduled_tokens] + .nonzero(as_tuple=False) .squeeze(1) + ) # Some tokens ids may need to become embeds if token_ids_idx.numel() > 0: token_ids = self.input_ids.gpu[token_ids_idx] - tokens_to_embeds = self.model.get_input_embeddings( - input_ids=token_ids) + tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids) self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2093,10 +2223,13 @@ def _preprocess( intermediate_tensors = None else: intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True) + num_input_tokens, intermediate_tensors, True + ) - if (self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs): + if ( + self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs + ): encoder_inputs = self._extract_encoder_inputs(scheduler_output) model_kwargs.update(encoder_inputs) @@ -2112,8 +2245,9 @@ def _preprocess( ) def _sample( - self, logits: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata] + self, + logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata @@ -2152,24 +2286,28 @@ def _sample( return sampler_output def _bookkeeping_sync( - self, scheduler_output: "SchedulerOutput", - sampler_output: SamplerOutput, logits: Optional[torch.Tensor], - hidden_states: torch.Tensor, num_scheduled_tokens: int + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, + num_scheduled_tokens: int, ) -> tuple[ - dict[str, int], - Optional[LogprobsLists], - list[list[int]], - dict[str, Optional[LogprobsTensors]], - list[str], - dict[str, int], - list[int], + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], ]: num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] + discard_sampled_tokens_req_indices = self.discard_request_indices.np[ + : self.num_discarded_requests + ] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -2178,14 +2316,14 @@ def _bookkeeping_sync( # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None + logprobs_lists = ( + logprobs_tensors.tolists() if logprobs_tensors is not None else None + ) # Compute prompt logprobs if needed. 
prompt_logprobs_dict = self._get_prompt_logprobs_dict( @@ -2220,10 +2358,10 @@ def _bookkeeping_sync( # Cache the sampled tokens on the GPU and avoid CPU sync. # These will be copied into input_ids in the next step # when preparing inputs. - self.input_batch.prev_sampled_token_ids = \ - sampled_token_ids - self.input_batch.prev_sampled_token_ids_invalid_indices = \ + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = ( invalid_req_indices_set + ) self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) @@ -2238,8 +2376,7 @@ def _bookkeeping_sync( req_ids = self.input_batch.req_ids for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: - sampled_ids = [-1] if \ - req_idx not in invalid_req_indices_set else None + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: @@ -2250,7 +2387,8 @@ def _bookkeeping_sync( assert end_idx <= self.max_model_len + 1, ( "Sampled token IDs exceed the max model length + 1. " f"Total number of tokens: {end_idx} > max_model_len + 1: " - f"{self.max_model_len + 1}") + f"{self.max_model_len + 1}" + ) n_tokens_cache = len(sampled_ids) @@ -2263,11 +2401,12 @@ def _bookkeeping_sync( if end_idx == self.max_model_len + 1: n_tokens_cache -= 1 - self.input_batch.token_ids_cpu[req_idx, start_idx:( - start_idx + n_tokens_cache)] = sampled_ids[:n_tokens_cache] - self.input_batch.is_token_ids[req_idx, - start_idx:(start_idx + - n_tokens_cache)] = True + self.input_batch.token_ids_cpu[ + req_idx, start_idx : (start_idx + n_tokens_cache) + ] = sampled_ids[:n_tokens_cache] + self.input_batch.is_token_ids[ + req_idx, start_idx : (start_idx + n_tokens_cache) + ] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -2312,7 +2451,7 @@ def _model_forward( """Helper method to call the model forward pass. This method can be overridden by subclasses for model execution. - Motivation: We can inspect only this method versus + Motivation: We can inspect only this method versus the whole execute_model, which has additional logic. Args: @@ -2349,18 +2488,27 @@ def execute_model( # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward( - scheduler_output, self.vllm_config) + scheduler_output, self.vllm_config + ) if self.cache_config.kv_sharing_fast_prefill: assert not self.input_batch.num_prompt_logprobs, ( "--kv-sharing-fast-prefill produces incorrect " "logprobs for prompt tokens, tokens, please disable " - "it when the requests need prompt logprobs") + "it when the requests need prompt logprobs" + ) # Prepare the decoder inputs. 
- (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len, ubatch_slices, num_tokens_after_padding, - use_cascade_attn) = self._prepare_inputs(scheduler_output) + ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens_np, + spec_decode_common_attn_metadata, + max_query_len, + ubatch_slices, + num_tokens_after_padding, + use_cascade_attn, + ) = self._prepare_inputs(scheduler_output) ( num_scheduled_tokens, @@ -2371,26 +2519,33 @@ def execute_model( positions, intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, intermediate_tensors, - ubatch_slices, num_tokens_after_padding) - - uniform_decode = (max_query_len - == self.uniform_decode_query_len) and ( - num_scheduled_tokens - == self.input_batch.num_reqs * max_query_len) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor, - use_cascade_attn) + ) = self._preprocess( + scheduler_output, + intermediate_tensors, + ubatch_slices, + num_tokens_after_padding, + ) + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, uniform_decode=uniform_decode + ) + cudagraph_runtime_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) + ) # Set cudagraph mode to none if calc_kv_scales is true. if attn_metadata is not None: - metadata_list = (attn_metadata.values() if isinstance( - attn_metadata, dict) else [attn_metadata]) + metadata_list = ( + attn_metadata.values() + if isinstance(attn_metadata, dict) + else [attn_metadata] + ) if any( - getattr(m, 'enable_kv_scales_calculation', False) - for m in metadata_list): + getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list + ): cudagraph_runtime_mode = CUDAGraphMode.NONE # This is currently to get around the assert in the DPMetadata @@ -2400,7 +2555,8 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - with (set_forward_context( + with ( + set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_input_tokens, @@ -2408,9 +2564,10 @@ def execute_model( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, - ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as - kv_connector_output): + ), + record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): model_output = self._model_forward( input_ids=input_ids, positions=positions, @@ -2438,8 +2595,9 @@ def execute_model( if self.is_pooling_model: # Return the pooling output. 
- output = self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np) + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + ) output.kv_connector_output = kv_connector_output return output @@ -2451,14 +2609,15 @@ def execute_model( if not get_pp_group().is_last_rank: all_gather_tensors = { - "residual": - not is_residual_scattered_for_sp( - self.vllm_config, num_input_tokens) + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) } get_pp_group().send_tensor_dict( hidden_states.tensors, all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors) + all_gather_tensors=all_gather_tensors, + ) logits = None else: sample_hidden_states = hidden_states[logits_indices] @@ -2468,16 +2627,17 @@ def execute_model( if logits is not None: model_output_broadcast_data["logits"] = logits.contiguous() - model_output_broadcast_data = get_pp_group( - ).broadcast_tensor_dict(model_output_broadcast_data, - src=len(get_pp_group().ranks) - 1) + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: - apply_grammar_bitmask(scheduler_output, self.input_batch, - logits, self.device) + apply_grammar_bitmask( + scheduler_output, self.input_batch, logits, self.device + ) with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) @@ -2496,22 +2656,27 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_common_attn_metadata, ) - use_padded_batch_for_eagle = self.speculative_config and \ - self.speculative_config.use_eagle() and \ - not self.speculative_config.disable_padded_drafter_batch + use_padded_batch_for_eagle = ( + self.speculative_config + and self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) effective_drafter_max_model_len = self.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len - if (self.speculative_config - and self.speculative_config.draft_model_config is not None - and self.speculative_config.draft_model_config.max_model_len - is not None): + if ( + self.speculative_config + and self.speculative_config.draft_model_config is not None + and self.speculative_config.draft_model_config.max_model_len is not None + ): effective_drafter_max_model_len = ( - self.speculative_config.draft_model_config.max_model_len) + self.speculative_config.draft_model_config.max_model_len + ) input_fits_in_drafter = spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.max_seq_len + - self.speculative_config.num_speculative_tokens - <= effective_drafter_max_model_len) + spec_decode_common_attn_metadata.max_seq_len + + self.speculative_config.num_speculative_tokens + <= effective_drafter_max_model_len + ) if use_padded_batch_for_eagle and input_fits_in_drafter: # EAGLE speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. 
@@ -2526,12 +2691,19 @@ def propose_draft_token_ids(sampled_token_ids): req_ids_output_copy, req_id_to_index_output_copy, invalid_req_indices, - ) = self._bookkeeping_sync(scheduler_output, sampler_output, - logits, hidden_states, - num_scheduled_tokens) + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + num_scheduled_tokens, + ) - if (self.speculative_config and not use_padded_batch_for_eagle - and input_fits_in_drafter): + if ( + self.speculative_config + and not use_padded_batch_for_eagle + and input_fits_in_drafter + ): # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) @@ -2587,10 +2759,12 @@ def propose_draft_token_ids( assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( - sampled_token_ids, self.input_batch.req_ids, + sampled_token_ids, + self.input_batch.req_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs) + self.input_batch.spec_decode_unsupported_reqs, + ) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -2603,8 +2777,8 @@ def propose_draft_token_ids( offset = 0 assert spec_decode_metadata is not None for num_draft, tokens in zip( - spec_decode_metadata.num_draft_tokens, - sampled_token_ids): + spec_decode_metadata.num_draft_tokens, sampled_token_ids + ): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) @@ -2621,29 +2795,35 @@ def propose_draft_token_ids( # When padded-batch is disabled, the sampled_token_ids should be # the cpu-side list[list[int]] of valid sampled tokens for each # request, with invalid requests having empty lists. - assert isinstance(sampled_token_ids, list), \ - "sampled_token_ids should be a python list when" \ + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when" "padded-batch is disabled." + ) next_token_ids = self.drafter.prepare_next_token_ids_cpu( - sampled_token_ids, self.requests, self.input_batch, - scheduler_output.num_scheduled_tokens) + sampled_token_ids, + self.requests, + self.input_batch, + scheduler_output.num_scheduled_tokens, + ) else: # When using padded-batch, the sampled_token_ids should be # the gpu tensor of sampled tokens for each request, of shape # (num_reqs, num_spec_tokens + 1) with rejected tokens having # value -1. - assert isinstance(sampled_token_ids, torch.Tensor), \ - "sampled_token_ids should be a torch.Tensor when" \ + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor when" "padded-batch is enabled." 
- next_token_ids, valid_sampled_tokens_count = \ + ) + next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( common_attn_metadata, sampled_token_ids, self.requests, self.input_batch, self.discard_request_indices.gpu, - self.num_discarded_requests + self.num_discarded_requests, ) + ) if spec_decode_metadata is None: token_indices_to_sample = None @@ -2653,32 +2833,34 @@ def propose_draft_token_ids( if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) + [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.speculative_config.disable_padded_drafter_batch: token_indices_to_sample = None - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, - sampled_token_ids, - spec_decode_metadata.num_draft_tokens) + common_attn_metadata, token_indices = self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens, + ) else: - common_attn_metadata, token_indices, \ - token_indices_to_sample =\ + common_attn_metadata, token_indices, token_indices_to_sample = ( self.drafter.prepare_inputs_padded( common_attn_metadata, spec_decode_metadata, - valid_sampled_tokens_count) + valid_sampled_tokens_count, + ) + ) target_token_ids = self.input_ids.gpu[token_indices] target_positions = self._get_positions(token_indices) if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) + [h[token_indices] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[token_indices] @@ -2706,9 +2888,10 @@ def propose_draft_token_ids( def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. 
" f"Allowed configs: {allowed_config_names}" + ) config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -2721,26 +2904,24 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) + + num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") + torch.distributed.broadcast( + num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 + ) num_local_physical_experts = int(num_local_physical_experts.item()) new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = ( - EplbState.recv_state()) + global_expert_load, old_global_expert_indices = EplbState.recv_state() num_logical_experts = global_expert_load.shape[1] self.parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts) - assert old_global_expert_indices.shape[ - 1] % num_local_physical_experts == 0 - old_ep_size = old_global_expert_indices.shape[ - 1] // num_local_physical_experts + num_local_physical_experts * new_ep_size - num_logical_experts + ) + assert old_global_expert_indices.shape[1] % num_local_physical_experts == 0 + old_ep_size = ( + old_global_expert_indices.shape[1] // num_local_physical_experts + ) rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) + old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) } else: global_expert_load = None @@ -2752,36 +2933,50 @@ def load_model(self, eep_scale_up: bool = False) -> None: model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") self.model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config) + vllm_config=self.vllm_config, model_config=self.model_config + ) if self.lora_config: - self.model = self.load_lora_model(self.model, self.vllm_config, - self.device) + self.model = self.load_lora_model( + self.model, self.vllm_config, self.device + ) if hasattr(self, "drafter"): logger.info("Loading drafter model...") self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: - if supports_eagle3(self.model): - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) - else: + if not supports_eagle3(self.model): raise RuntimeError( "Model does not support EAGLE3 interface but " - "aux_hidden_state_outputs was requested") + "aux_hidden_state_outputs was requested" + ) + + # Try to get auxiliary layers from speculative config, + # otherwise use model's default layers + aux_layers = self._get_eagle3_aux_layers_from_config() + if aux_layers: + logger.info( + "Using auxiliary layers from speculative config: %s", + aux_layers, + ) + else: + aux_layers = self.model.get_eagle3_aux_hidden_state_layers() + + self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load) + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - 
time_before_load, + ) prepare_communication_buffer_for_model(self.model) - self.is_multimodal_pruning_enabled = (supports_multimodal_pruning( - self.model) and self.model_config.multimodal_config. - is_multimodal_pruning_enabled()) + self.is_multimodal_pruning_enabled = ( + supports_multimodal_pruning(self.model) + and self.model_config.multimodal_config.is_multimodal_pruning_enabled() + ) - if is_mixture_of_experts( - self.model) and self.parallel_config.enable_eplb: - logger.info("EPLB is enabled for model %s.", - self.model_config.model) + if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + logger.info("EPLB is enabled for model %s.", self.model_config.model) self.eplb_state = EplbState.build( self.model, self.device, @@ -2792,11 +2987,10 @@ def load_model(self, eep_scale_up: bool = False) -> None: ) if ( - self.vllm_config.compilation_config.level == \ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo() + self.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS + and supports_dynamo() ): - backend = self.vllm_config.compilation_config.init_backend( - self.vllm_config) + backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) compilation_counter.dynamo_as_is_count += 1 self.model.compile(fullgraph=True, backend=backend) return @@ -2804,26 +2998,54 @@ def load_model(self, eep_scale_up: bool = False) -> None: # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. - if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \ - and not self.parallel_config.enable_dbo: - self.model = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.enable_dbo + ): + self.model = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) elif self.parallel_config.enable_dbo: if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.FULL, self.device) + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.FULL, self.device + ) else: - self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.NONE, self.device) + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.NONE, self.device + ) + + def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]: + """Extract Eagle3 auxiliary layer indices from speculative config. + + These indices specify which hidden states from the base model should + be used as auxiliary inputs for the Eagle3 drafter model during + speculative decoding. + + Returns: + Tuple of layer indices if found in draft model config, + None otherwise. + """ + if not (self.speculative_config and self.speculative_config.draft_model_config): + return None + + hf_config = self.speculative_config.draft_model_config.hf_config + if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"): + return None + + layer_ids = hf_config.eagle_aux_hidden_state_layer_ids + if layer_ids and isinstance(layer_ids, (list, tuple)): + return tuple(layer_ids) + + return None def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." 
+ ) model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") - model_loader.load_weights(self.get_model(), - model_config=self.model_config) + model_loader.load_weights(self.get_model(), model_config=self.model_config) def save_tensorized_model( self, @@ -2861,7 +3083,8 @@ def _get_prompt_logprobs_dict( num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Set up target LogprobsTensors object. logprobs_tensors = in_progress_dict.get(req_id) @@ -2869,7 +3092,8 @@ def _get_prompt_logprobs_dict( # Create empty logprobs CPU tensors for the entire prompt. # If chunked, we'll copy in slice by slice. logprobs_tensors = LogprobsTensors.empty_cpu( - num_prompt_tokens - 1, num_prompt_logprobs + 1) + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) in_progress_dict[req_id] = logprobs_tensors # Determine number of logits to retrieve. @@ -2899,27 +3123,29 @@ def _get_prompt_logprobs_dict( # then there is prompt logprob generated for each index. req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc.np[req_idx].item() - prompt_hidden_states = hidden_states[offset:offset + num_logits] + prompt_hidden_states = hidden_states[offset : offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want # to gather the logprob for. - tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] # Compute prompt logprobs. logprobs = self.sampler.compute_logprobs(logits) token_ids, logprobs, ranks = self.sampler.gather_logprobs( - logprobs, num_prompt_logprobs, tgt_token_ids) + logprobs, num_prompt_logprobs, tgt_token_ids + ) # Transfer GPU->CPU async. chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_( - token_ids, non_blocking=True) - logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, - non_blocking=True) + token_ids, non_blocking=True + ) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) logprobs_tensors.selected_token_ranks[chunk_slice].copy_( - ranks, non_blocking=True) + ranks, non_blocking=True + ) # Remove requests that have completed prefill from the batch # num_prompt_logprobs_dict. 
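# --- Illustrative aside (not part of the patch): a minimal sketch of the
# chunked prompt-logprob flow above. For prompt index i the "target" token is
# the one at index i+1; its logprob is gathered from the chunk's logits and
# staged into a pre-allocated CPU slice with a non-blocking copy. Toy sizes;
# CPU tensors stand in for the GPU-resident ones.
import torch

num_logits, vocab_size = 4, 32
chunk_logits = torch.randn(num_logits, vocab_size)
tgt_token_ids = torch.randint(0, vocab_size, (num_logits,))  # tokens at i + 1

logprobs = torch.log_softmax(chunk_logits, dim=-1)
tgt_logprobs = logprobs.gather(-1, tgt_token_ids.unsqueeze(-1)).squeeze(-1)

# Pre-allocated buffer covering the whole prompt, filled slice by slice.
prompt_logprobs_cpu = torch.empty(16, pin_memory=torch.cuda.is_available())
start_idx = 0
prompt_logprobs_cpu[start_idx : start_idx + num_logits].copy_(
    tgt_logprobs, non_blocking=True
)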
@@ -2947,8 +3173,9 @@ def _get_nans_in_logits( req_index = self.input_batch.req_id_to_index[req_id] num_nans_in_logits[req_id] = ( int(num_nans_for_index[req_index]) - if num_nans_for_index is not None - and req_index < logits.shape[0] else 0) + if num_nans_for_index is not None and req_index < logits.shape[0] + else 0 + ) return num_nans_in_logits except IndexError: return {} @@ -2974,11 +3201,11 @@ def rand_input_ids() -> torch.Tensor: self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype) + dtype=input_ids.dtype, + ) logger.debug_once("Randomizing dummy data for DP Rank") - input_ids.copy_(rand_input_ids()[:input_ids.size(0)], - non_blocking=True) + input_ids.copy_(rand_input_ids()[: input_ids.size(0)], non_blocking=True) yield input_ids.fill_(0) @@ -3003,13 +3230,15 @@ def _get_mm_dummy_batch( dummy_mm_items = [dummy_mm_item] * max_items_per_batch model = cast(SupportsMultiModal, self.model) - return next(mm_kwargs_group - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - )) + return next( + mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) + ) @torch.inference_mode() def _dummy_run( @@ -3046,8 +3275,10 @@ def _dummy_run( (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run """ - assert cudagraph_runtime_mode is None or \ - cudagraph_runtime_mode.valid_runtime_modes() + assert ( + cudagraph_runtime_mode is None + or cudagraph_runtime_mode.valid_runtime_modes() + ) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -3062,8 +3293,7 @@ def _dummy_run( # When setting max_query_len = 1, we switch to and capture the optimized # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. 
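# --- Illustrative aside (not part of the patch): how the dummy run splits
# num_tokens into per-request token counts in the branches below. Hypothetical
# numbers; the uniform-decode branch body is paraphrased, not quoted.
num_tokens, uniform_decode_query_len = 10, 2

# Mixed prefill/decode case: N decode requests with 1 token each, followed by
# a single prefill request that absorbs the remaining tokens.
num_decode_tokens = 4
mixed = [1] * num_decode_tokens + [num_tokens - num_decode_tokens]
assert mixed == [1, 1, 1, 1, 6] and sum(mixed) == num_tokens

# Uniform decode case: every request gets uniform_decode_query_len tokens
# (e.g. 1 verified + 1 draft token); the last request takes the remainder.
num_reqs = -(-num_tokens // uniform_decode_query_len)  # ceil division
uniform = [uniform_decode_query_len] * num_reqs
uniform[-1] = num_tokens - uniform_decode_query_len * (num_reqs - 1)
assert uniform == [2, 2, 2, 2, 2] and sum(uniform) == num_tokens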
- max_query_len = self.uniform_decode_query_len if uniform_decode else \ - num_tokens + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -3079,9 +3309,7 @@ def _dummy_run( num_reqs = num_decode_tokens + 1 # Create decode requests (1 token each) followed by prefill request - num_scheduled_tokens_list = [1] * num_decode_tokens + [ - num_prefill_tokens - ] + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] # Note: Overriding max_query_len to be the prefill tokens max_query_len = num_prefill_tokens elif uniform_decode: @@ -3098,8 +3326,7 @@ def _dummy_run( assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - num_scheduled_tokens = np.array(num_scheduled_tokens_list, - dtype=np.int32) + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) ubatch_slices = None @@ -3153,56 +3380,62 @@ def _dummy_run( self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() - cum_num_tokens, _ = self._get_cumsum_and_arange( - num_scheduled_tokens) - self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + - 1], + query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], seq_lens=self.seq_lens.gpu[:num_reqs], seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ], num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, max_seq_len=self.max_model_len, - block_table_tensor=self.input_batch. 
- block_table[kv_cache_group_id].get_device_tensor(num_reqs), + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id + ].get_device_tensor(num_reqs), slot_mapping=self.input_batch.block_table[ - kv_cache_group_id].slot_mapping.gpu[:num_tokens], - causal=True) + kv_cache_group_id + ].slot_mapping.gpu[:num_tokens], + causal=True, + ) for attn_group in self.attn_groups[kv_cache_group_id]: if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata) + ubatch_slices, common_attn_metadata + ) for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list): + common_attn_metadata_list + ): assert common_attn_metadata.max_query_len == 1 - attn_metadata_i = (attn_group\ - .get_metadata_builder(ubatch_id=ubid)\ - .build_for_cudagraph_capture(common_attn_metadata)) + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build_for_cudagraph_capture(common_attn_metadata) for layer_name in attn_group.layer_names: assert type(attn_metadata) is list - attn_metadata[ubid][ - layer_name] = attn_metadata_i + attn_metadata[ubid][layer_name] = attn_metadata_i else: assert type(attn_metadata) is dict - attn_metadata_i = attn_group.get_metadata_builder()\ - .build_for_cudagraph_capture(common_attn_metadata) + metadata_builder = attn_group.get_metadata_builder() + attn_metadata_i = metadata_builder.build_for_cudagraph_capture( + common_attn_metadata + ) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i - with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens, remove_lora): + with self.maybe_dummy_run_with_lora( + self.lora_config, num_scheduled_tokens, remove_lora + ): model_kwargs = self._init_model_kwargs(num_tokens) - if (self.supports_mm_inputs - and not self.model_config.is_encoder_decoder): + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens] model_kwargs = { @@ -3230,23 +3463,35 @@ def _dummy_run( self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, - device=self.device)) + device=self.device, + ) + ) intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False) + num_tokens, None, False + ) # filter out the valid batch descriptor - _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens_after_padding, - uniform_decode=uniform_decode)) \ - if not is_profile else (CUDAGraphMode.NONE, None) + _cg_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch( + BatchDescriptor( + num_tokens=num_tokens_after_padding, + uniform_decode=uniform_decode, + ) + ) + if not is_profile + else (CUDAGraphMode.NONE, None) + ) if cudagraph_runtime_mode is not None: # we allow forcing NONE when the dispatcher disagrees to support # warm ups for cudagraph capture - assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \ - cudagraph_runtime_mode == _cg_mode, ( + assert ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode == _cg_mode + ), ( f"Cudagraph runtime mode mismatch at dummy_run. " - f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." 
+ ) else: cudagraph_runtime_mode = _cg_mode @@ -3258,14 +3503,18 @@ def _dummy_run( if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_after_padding - with self.maybe_randomize_inputs(input_ids), set_forward_context( + with ( + self.maybe_randomize_inputs(input_ids), + set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens_after_padding, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices): + ubatch_slices=ubatch_slices, + ), + ): outputs = self.model( input_ids=input_ids, positions=positions, @@ -3309,8 +3558,7 @@ def _dummy_sampler_run( logits = self.model.compute_logits(hidden_states) num_reqs = logits.size(0) - dummy_tensors = lambda v: torch.full( - (num_reqs, ), v, device=self.device) + dummy_tensors = lambda v: torch.full((num_reqs,), v, device=self.device) dummy_metadata = SamplingMetadata( temperature=dummy_tensors(0.5), @@ -3331,37 +3579,39 @@ def _dummy_sampler_run( logitsprocs=LogitsProcessors(), ) try: - sampler_output = self.sampler(logits=logits, - sampling_metadata=dummy_metadata) + sampler_output = self.sampler( + logits=logits, sampling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up sampler with " f"{num_reqs} dummy requests. Please try lowering " "`max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e if self.speculative_config: draft_token_ids = [[0] for _ in range(num_reqs)] dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids, self.device) + draft_token_ids, self.device + ) num_tokens = sum(len(ids) for ids in draft_token_ids) # draft_probs = torch.randn( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn(num_tokens, - logits.shape[-1], - device=self.device, - dtype=logits.dtype) + target_logits = torch.randn( + num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype + ) # NOTE(woosuk): Here, we should use int32 because the sampler uses # int32 for bonus_token_ids. If the dtype mismatches, re-compilation # will occur at runtime. 
- bonus_token_ids = torch.zeros(num_reqs, - device=self.device, - dtype=torch.int32) + bonus_token_ids = torch.zeros( + num_reqs, device=self.device, dtype=torch.int32 + ) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, @@ -3391,9 +3641,9 @@ def _dummy_pooler_run_task( num_scheduled_tokens_list, device="cpu", ) - dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), - dtype=torch.int32, - device=self.device) + dummy_token_ids = torch.zeros( + (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device + ) model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) @@ -3407,19 +3657,22 @@ def _dummy_pooler_run_task( pooling_params=[dummy_pooling_params] * num_reqs, ) - dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, - device=hidden_states.device) + dummy_metadata.build_pooling_cursor( + num_scheduled_tokens_list, device=hidden_states.device + ) try: - return model.pooler(hidden_states=hidden_states, - pooling_metadata=dummy_metadata) + return model.pooler( + hidden_states=hidden_states, pooling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " "lowering `max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e @@ -3445,7 +3698,8 @@ def profile_run(self) -> None: if self.model_config.multimodal_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " - "encoder cache.") + "encoder cache." + ) else: mm_budget = self.mm_budget assert mm_budget is not None @@ -3455,8 +3709,9 @@ def profile_run(self) -> None: # modality with the max possible input tokens even when # it supports multiple. dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + dummy_modality + ] logger.info( "Encoder cache will be initialized with a budget of " @@ -3474,9 +3729,9 @@ def profile_run(self) -> None: ) # Run multimodal encoder. - dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, @@ -3493,7 +3748,8 @@ def profile_run(self) -> None: expanded_outputs = [] for output in dummy_encoder_outputs: expanded = output.new_zeros( - (encoder_budget, encoder_output_shape[-1])) + (encoder_budget, encoder_output_shape[-1]) + ) num_tokens = output.shape[0] expanded[:num_tokens].copy_(output) expanded_outputs.append(expanded) @@ -3501,12 +3757,12 @@ def profile_run(self) -> None: dummy_encoder_outputs = expanded_outputs # Cache the dummy encoder outputs. 
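# --- Illustrative aside (not part of the patch): the profile run pads each
# dummy encoder output up to the full encoder-cache budget so that profiling
# sees a fully occupied cache. Minimal sketch with toy shapes.
import torch

encoder_budget, hidden_size = 12, 8
output = torch.randn(5, hidden_size)            # one dummy encoder output
expanded = output.new_zeros((encoder_budget, hidden_size))
expanded[: output.shape[0]].copy_(output)       # real tokens first, zeros after
assert expanded.shape == (encoder_budget, hidden_size)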
- self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states \ - = self._dummy_run(self.max_num_tokens, is_profile=True) + hidden_states, last_hidden_states = self._dummy_run( + self.max_num_tokens, is_profile=True + ) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -3523,7 +3779,8 @@ def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "ensure `cudagraph_mode` was not manually set to `NONE`") + "ensure `cudagraph_mode` was not manually set to `NONE`" + ) return 0 else: self.initialize_cudagraph_capture() @@ -3563,24 +3820,29 @@ def freeze_gc(): self._capture_cudagraphs( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False) + uniform_decode=False, + ) # Capture full cudagraph for uniform decode batches if we # don't already have full mixed prefill-decode cudagraphs. - if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - cudagraph_mode.separate_routine(): - max_num_tokens = self.scheduler_config.max_num_seqs * \ - self.uniform_decode_query_len + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + self.scheduler_config.max_num_seqs * self.uniform_decode_query_len + ) decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if - x <= max_num_tokens and x >= self.uniform_decode_query_len + x + for x in self.cudagraph_batch_sizes + if x <= max_num_tokens and x >= self.uniform_decode_query_len ] - compilation_cases_decode = list( - reversed(decode_cudagraph_batch_sizes)) + compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes)) self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True) + uniform_decode=True, + ) torch.cuda.synchronize() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -3596,16 +3858,23 @@ def freeze_gc(): elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. 
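# --- Illustrative aside (not part of the patch): which batch sizes get a
# dedicated uniform-decode FULL cudagraph, per the filter above. Hypothetical
# numbers; with speculative decoding the per-request query length exceeds 1.
uniform_decode_query_len = 2            # e.g. 1 verified + 1 draft token
max_num_seqs = 8
cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32]

max_num_tokens = max_num_seqs * uniform_decode_query_len
decode_sizes = [
    x for x in cudagraph_batch_sizes
    if uniform_decode_query_len <= x <= max_num_tokens
]
# The capture loop walks the sizes from largest to smallest.
assert list(reversed(decode_sizes)) == [16, 8, 4, 2]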
- logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) + logger.info( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) return cuda_graph_size - def _capture_cudagraphs(self, compilation_cases: list[int], - cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool): - assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ - cudagraph_runtime_mode.valid_runtime_modes(), \ - f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" + def _capture_cudagraphs( + self, + compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool, + ): + assert ( + cudagraph_runtime_mode != CUDAGraphMode.NONE + and cudagraph_runtime_mode.valid_runtime_modes() + ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" # Only rank 0 should print progress bar during capture if is_global_first_rank(): @@ -3614,7 +3883,9 @@ def _capture_cudagraphs(self, compilation_cases: list[int], disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", - cudagraph_runtime_mode.name)) + cudagraph_runtime_mode.name, + ), + ) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: @@ -3622,14 +3893,16 @@ def _capture_cudagraphs(self, compilation_cases: list[int], # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched # version of the graph - allow_microbatching = self.parallel_config.enable_dbo \ - and cudagraph_runtime_mode == CUDAGraphMode.FULL \ - and uniform_decode \ + allow_microbatching = ( + self.parallel_config.enable_dbo + and cudagraph_runtime_mode == CUDAGraphMode.FULL + and uniform_decode and check_ubatch_thresholds( config=self.vllm_config.parallel_config, num_tokens=num_tokens, uniform_decode=uniform_decode, ) + ) for _ in range(self.compilation_config.cudagraph_num_of_warmups): # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. @@ -3637,29 +3910,31 @@ def _capture_cudagraphs(self, compilation_cases: list[int], # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. - force_attention = ( - cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. 
""" - assert len(self.attn_groups) == 0, \ - "Attention backends are already initialized" + assert len(self.attn_groups) == 0, "Attention backends are already initialized" class AttentionGroupKey(NamedTuple): attn_backend: type[AttentionBackend] @@ -3669,8 +3944,8 @@ def get_attn_backends_for_group( kv_cache_group_spec: KVCacheGroupSpec, ) -> dict[AttentionGroupKey, list[str]]: layers = get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase, - kv_cache_group_spec.layer_names) + self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names + ) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than @@ -3690,23 +3965,19 @@ def get_attn_backends_for_group( full_cls_name = attn_backend.full_cls_name() layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): - layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ - layer_name] + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] key = (full_cls_name, layer_kv_cache_spec) - attn_backends[key] = AttentionGroupKey(attn_backend, - layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey( + attn_backend, layer_kv_cache_spec + ) attn_backend_layers[key].append(layer_name) - return { - attn_backends[k]: v - for k, v in attn_backend_layers.items() - } + return {attn_backends[k]: v for k, v in attn_backend_layers.items()} def create_attn_groups( attn_backends_map: dict[AttentionGroupKey, list[str]], ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] - for (attn_backend, - kv_cache_spec), layer_names in attn_backends_map.items(): + for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): attn_group = AttentionGroup.create_with_metadata_builders( attn_backend, layer_names, @@ -3714,7 +3985,8 @@ def create_attn_groups( self.vllm_config, self.device, num_metadata_builders=1 - if not self.parallel_config.enable_dbo else 2, + if not self.parallel_config.enable_dbo + else 2, ) attn_groups.append(attn_group) @@ -3729,7 +4001,7 @@ def create_attn_groups( def initialize_cudagraph_capture(self) -> None: """ - Resolve the cudagraph_mode when there are multiple attention + Resolve the cudagraph_mode when there are multiple attention backends with potential conflicting CUDA graph support. Then initialize the cudagraph_dispatcher based on the resolved cudagraph_mode. @@ -3745,81 +4017,110 @@ def initialize_cudagraph_capture(self) -> None: # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported - if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ - and min_cg_support != AttentionCGSupport.ALWAYS: - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") + if ( + cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL + and min_cg_support != AttentionCGSupport.ALWAYS + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) if min_cg_support == AttentionCGSupport.NEVER: # if not supported any full cudagraphs, just raise it. 
- msg += "; please try cudagraph_mode=PIECEWISE, and "\ + msg += ( + "; please try cudagraph_mode=PIECEWISE, and " "make sure compilation level is piecewise" + ) raise ValueError(msg) # attempt to resolve the full cudagraph related mode if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE + ) else: msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_DECODE_ONLY + ) logger.warning(msg) # check that if we are doing decode full-cudagraphs it is supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and min_cg_support == AttentionCGSupport.NEVER): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") - if (self.compilation_config.level == CompilationLevel.PIECEWISE and - (self.compilation_config.splitting_ops_contain_attention() - or self.compilation_config.use_inductor_graph_partition)): - msg += "; setting cudagraph_mode=PIECEWISE because "\ + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and min_cg_support == AttentionCGSupport.NEVER + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) + if self.compilation_config.level == CompilationLevel.PIECEWISE and ( + self.compilation_config.splitting_ops_contain_attention() + or self.compilation_config.use_inductor_graph_partition + ): + msg += ( + "; setting cudagraph_mode=PIECEWISE because " "attention is compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: - msg += "; setting cudagraph_mode=NONE because "\ + msg += ( + "; setting cudagraph_mode=NONE because " "attention is not compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) logger.warning(msg) # check that if we are doing spec-decode + decode full-cudagraphs it is # supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and self.uniform_decode_query_len > 1 and min_cg_support.value - < AttentionCGSupport.UNIFORM_BATCH.value): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported" - f" with spec-decode for attention backend " - f"{min_cg_builder_name} (support: {min_cg_support})") + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 + and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_cg_builder_name} (support: {min_cg_support})" + ) if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: msg += "; setting cudagraph_mode=NONE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) 
logger.warning(msg) # double check that we can support full cudagraph if they are requested # even after automatic downgrades - if cudagraph_mode.has_full_cudagraphs() \ - and min_cg_support == AttentionCGSupport.NEVER: - raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " - f"supported with {min_cg_builder_name} backend (" - f"support:{min_cg_support}) " - "; please try cudagraph_mode=PIECEWISE, " - "and make sure compilation level is piecewise") + if ( + cudagraph_mode.has_full_cudagraphs() + and min_cg_support == AttentionCGSupport.NEVER + ): + raise ValueError( + f"CUDAGraphMode.{cudagraph_mode.name} is not " + f"supported with {min_cg_builder_name} backend (" + f"support:{min_cg_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise" + ) # Trigger cudagraph dispatching keys initialization here (after # initializing attn backends). self.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, - self.uniform_decode_query_len) + self.compilation_config.cudagraph_mode, self.uniform_decode_query_len + ) def calculate_reorder_batch_threshold(self) -> None: """ @@ -3831,22 +4132,20 @@ def calculate_reorder_batch_threshold(self) -> None: # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) - reorder_batch_threshold_i = ( - attn_metadata_builder_i.reorder_batch_threshold) + reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold if reorder_batch_threshold_i is not None: if self.reorder_batch_threshold is not None: - if reorder_batch_threshold_i != \ - self.reorder_batch_threshold: + if reorder_batch_threshold_i != self.reorder_batch_threshold: raise ValueError( f"Attention backend reorders decodes with " f"threshold {reorder_batch_threshold_i} but other " f"backend uses threshold " - f"{self.reorder_batch_threshold}") + f"{self.reorder_batch_threshold}" + ) else: self.reorder_batch_threshold = reorder_batch_threshold_i - def may_reinitialize_input_batch(self, - kv_cache_config: KVCacheConfig) -> None: + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ Re-initialize the input batch if the block sizes are different from `[self.cache_config.block_size]`. This usually happens when there @@ -3863,7 +4162,8 @@ def may_reinitialize_input_batch(self, assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 - "for more details.") + "for more details." + ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=max(self.max_model_len, self.max_encoder_len), @@ -3877,11 +4177,14 @@ def may_reinitialize_input_batch(self, is_pooling_model=self.is_pooling_model, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0), + if self.vllm_config.speculative_config + else 0 + ), ) def _allocate_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initializes the KV cache buffer with the correct size. The buffer needs to be reshaped to the desired shape before being used by the models. @@ -3891,12 +4194,12 @@ def _allocate_kv_cache_tensors( Returns: dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. 
- """ + """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros(kv_cache_tensor.size, - dtype=torch.int8, - device=self.device) + tensor = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=self.device + ) for layer_name in kv_cache_tensor.shared_by: kv_cache_raw_tensors[layer_name] = tensor @@ -3906,8 +4209,9 @@ def _allocate_kv_cache_tensors( if layer_name in self.runner_only_attn_layers: continue layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys( - )), "Some layers are not correctly initialized" + assert layer_names == set(kv_cache_raw_tensors.keys()), ( + "Some layers are not correctly initialized" + ) return kv_cache_raw_tensors def _attn_group_iterator(self) -> Iterator[AttentionGroup]: @@ -3945,8 +4249,7 @@ def _reshape_kv_cache_tensors( continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_cache_shape = attn_backend.get_kv_cache_shape( @@ -3954,41 +4257,43 @@ def _reshape_kv_cache_tensors( kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype) + cache_dtype_str=self.cache_config.cache_dtype, + ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = \ - attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len( - kv_cache_shape) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple( - range(len(kv_cache_shape))) + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) # The allocation respects the backend-defined stride order # to ensure the semantic remains consistent for each # backend. We first obtain the generic kv cache shape and # then permute it according to the stride order which could # result in a non-contiguous tensor. - kv_cache_shape = tuple(kv_cache_shape[i] - for i in kv_cache_stride_order) + kv_cache_shape = tuple( + kv_cache_shape[i] for i in kv_cache_stride_order + ) # Maintain original KV shape view. 
inv_order = [ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = kv_cache_raw_tensors[ - layer_name].view(dtype).view(kv_cache_shape).permute( - *inv_order) + kv_caches[layer_name] = ( + kv_cache_raw_tensors[layer_name] + .view(dtype) + .view(kv_cache_shape) + .permute(*inv_order) + ) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] state_tensors = [] storage_offset_bytes = 0 - for (shape, dtype) in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): + for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): dtype_size = get_dtype_size(dtype) num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size) + kv_cache_spec.page_size_bytes // dtype_size + ) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) @@ -4012,7 +4317,8 @@ def _reshape_kv_cache_tensors( return kv_caches def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor]) -> None: + self, kv_caches: dict[str, torch.Tensor] + ) -> None: """ Update the layout of attention layers from (2, num_blocks, ...) to (num_blocks, 2, ...). @@ -4025,19 +4331,21 @@ def _update_hybrid_attention_mamba_layout( kv_cache_spec = group.kv_cache_spec for layer_name in group.layer_names: kv_cache = kv_caches[layer_name] - if (isinstance(kv_cache_spec, AttentionSpec) - and kv_cache.shape[0] == 2): - assert kv_cache.shape[1] != 2, \ - "Fail to determine whether the layout is " \ - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \ + if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2: + assert kv_cache.shape[1] != 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " f"a tensor of shape {kv_cache.shape}" + ) hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_(size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, - *kv_cache.stride()[2:])) + kv_cache.as_strided_( + size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), + ) def initialize_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initialize the memory buffer for KV cache. 
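# --- Illustrative aside (not part of the patch): each kv_cache_tensor is
# allocated as one flat int8 buffer (shared by every layer in `shared_by`),
# then re-viewed with the backend dtype and shape and permuted into the
# backend's preferred stride order while still exposing the generic
# (2, num_blocks, block_size, num_kv_heads, head_size) view. Toy numbers and a
# hypothetical stride order below.
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 16, 2, 64
dtype = torch.float16
dtype_size = torch.tensor([], dtype=dtype).element_size()
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size

raw = torch.zeros(num_blocks * page_size_bytes, dtype=torch.int8)  # flat buffer
kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
stride_order = (1, 0, 2, 3, 4)                 # hypothetical backend order
alloc_shape = tuple(kv_cache_shape[i] for i in stride_order)
inv_order = [stride_order.index(i) for i in range(len(stride_order))]

kv_cache = raw.view(dtype).view(alloc_shape).permute(*inv_order)
assert kv_cache.shape == kv_cache_shape        # original logical view preserved
assert not kv_cache.is_contiguous()            # memory follows the stride order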
@@ -4050,25 +4358,29 @@ def initialize_kv_cache_tensors( # Initialize the memory buffer for KV cache kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) # Change the memory buffer to the desired shape - kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, - kv_cache_raw_tensors) + kv_caches = self._reshape_kv_cache_tensors( + kv_cache_config, kv_cache_raw_tensors + ) # Set up cross-layer KV cache sharing - for layer_name, target_layer_name in self.shared_kv_cache_layers.items( - ): - logger.debug("%s reuses KV cache of %s", layer_name, - target_layer_name) + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) kv_caches[layer_name] = kv_caches[target_layer_name] - num_attn_module = 2 \ - if self.model_config.hf_config.model_type == "longcat_flash" else 1 - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches, num_attn_module) + num_attn_module = ( + 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 + ) + bind_kv_cache( + kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches, + num_attn_module, + ) return kv_caches def maybe_add_kv_sharing_layers_to_kv_cache_groups( - self, kv_cache_config: KVCacheConfig) -> None: + self, kv_cache_config: KVCacheConfig + ) -> None: """ Add layers that re-use KV cache to KV cache group of its target layer. Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` @@ -4087,12 +4399,10 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups( # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other # similar KV sharing setups, only the layers that generate KV caches # are involved in the prefill phase, enabling prefill to early exit. - attn_layers = get_layers_from_vllm_config(self.vllm_config, - Attention) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: - self.kv_sharing_fast_prefill_eligible_layers.add( - layer_name) + self.kv_sharing_fast_prefill_eligible_layers.add(layer_name) else: break @@ -4124,23 +4434,23 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if self.dcp_world_size > 1: layer_names = self.attn_groups[0][0].layer_names - layers = get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase, - layer_names) + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, layer_names + ) for layer in layers.values(): assert layer.impl.need_to_return_lse_for_decode, ( "DCP requires attention impls to return" " the softmax lse for decode, but the impl " f"{layer.impl.__class__.__name__} " - "does not return the softmax lse for decode.") + "does not return the softmax lse for decode." + ) def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
""" block_size = self.vllm_config.cache_config.block_size - encoder_only_attn_specs: dict[AttentionSpec, - list[str]] = defaultdict(list) + encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: @@ -4148,16 +4458,18 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) + dtype=self.kv_cache_dtype, + ) encoder_only_attn_specs[attn_spec].append(layer_name) self.runner_only_attn_layers.add(layer_name) if len(encoder_only_attn_specs) > 0: - assert len( - encoder_only_attn_specs - ) == 1, "Only support one encoder-only attention spec now" + assert len(encoder_only_attn_specs) == 1, ( + "Only support one encoder-only attention spec now" + ) spec, layer_names = encoder_only_attn_specs.popitem() self.kv_cache_config.kv_cache_groups.append( - KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec) + ) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -4174,8 +4486,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: + if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of # the target layer. We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does @@ -4190,79 +4501,84 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # the attention backends if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: - assert not use_mla, "MLA is not supported for sliding" \ - "window" + assert not use_mla, "MLA is not supported for slidingwindow" kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window) + sliding_window=attn_module.sliding_window, + ) elif use_mla: kv_cache_spec[layer_name] = MLAAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - cache_dtype_str=cache_dtype_str) - elif self.attention_chunk_size is not None \ - and isinstance(attn_module, ChunkedLocalAttention): + cache_dtype_str=cache_dtype_str, + ) + elif self.attention_chunk_size is not None and isinstance( + attn_module, ChunkedLocalAttention + ): kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - attention_chunk_size=self.attention_chunk_size) + attention_chunk_size=self.attention_chunk_size, + ) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) + dtype=self.kv_cache_dtype, + ) elif attn_module.attn_type == AttentionType.ENCODER_DECODER: kv_cache_spec[layer_name] = CrossAttentionSpec( block_size=block_size, 
num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): + dtype=self.kv_cache_dtype, + ) + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): # encoder-only attention does not need KV cache. continue else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) if len(mamba_layers) > 0: - if (self.vllm_config.speculative_config is not None - and self.vllm_config.model_config.hf_config.model_type - not in ["qwen3_next"]): - raise NotImplementedError( - "Mamba with speculative decoding is not supported yet.") - if self.vllm_config.cache_config.enable_prefix_caching: + if ( + self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"] + ): raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") - max_model_len = self.vllm_config.model_config.max_model_len - - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) + "Mamba with speculative decoding is not supported yet." + ) + mamba_block_size = self.vllm_config.cache_config.mamba_block_size + page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded - # Set block_size to max_model_len, so that mamba model will always - # have only one block in the KV cache. for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( shapes=mamba_module.get_state_shape(), dtypes=mamba_module.get_state_dtype(), - block_size=max_model_len, + block_size=mamba_block_size, page_size_padded=page_size_padded, mamba_type=mamba_module.mamba_type, num_speculative_blocks=( self.speculative_config.num_speculative_tokens - if self.speculative_config else 0), + if self.speculative_config + else 0 + ), ) ds_indexer_layers = get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache) + self.vllm_config, DeepseekV32IndexerCache + ) for layer_name, ds_indexer_module in ds_indexer_layers.items(): kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec() @@ -4277,7 +4593,7 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: # this is in the critical path of every single model # forward loop, this has caused perf issue for a disagg # setup. 
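# --- Illustrative aside (not part of the patch): the sampled token ids are
# staged through a persistent pinned-CPU buffer with an explicit non-blocking
# copy and a CUDA event, instead of calling .tolist() on the GPU tensor, which
# would force a slower pageable device-to-host copy on every step. Minimal
# sketch; it degrades to plain CPU tensors when CUDA is unavailable.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
sampled_token_ids = torch.randint(0, 100, (4, 3), device=device)

pinned = torch.empty(
    8, 3, dtype=sampled_token_ids.dtype, pin_memory=torch.cuda.is_available()
)
pinned[: sampled_token_ids.shape[0]].copy_(sampled_token_ids, non_blocking=True)
if torch.cuda.is_available():
    transfer_event = torch.cuda.Event()
    transfer_event.record()       # marks the point after the copy on this stream
    transfer_event.synchronize()  # host blocks until the copy has completed
token_lists = pinned[: sampled_token_ids.shape[0]].tolist()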
- pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned = self.sampled_token_ids_pinned_cpu[: sampled_token_ids.shape[0]] pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 39be8c74102e..3bd7c9d538de 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -11,10 +11,12 @@ from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import get_ep_group -from vllm.distributed.device_communicators.pynccl_allocator import ( - set_graph_pool_id) -from vllm.forward_context import (create_forward_context, get_forward_context, - override_forward_context) +from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id +from vllm.forward_context import ( + create_forward_context, + get_forward_context, + override_forward_context, +) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -42,27 +44,31 @@ class CUDAGraphMetaData: class SMControlContextManager: - - def __init__(self, comm_sms: int, set_comm_sms: Callable[[int], None], - set_compute_sms: Callable[[int], None]): + def __init__( + self, + comm_sms: int, + set_comm_sms: Callable[[int], None], + set_compute_sms: Callable[[int], None], + ): """ - Context manager for controlling SM (Streaming Multiprocessor) + Context manager for controlling SM (Streaming Multiprocessor) allocation. Upon entering the context, it sets the number of SMs allocated for communication and computation to comm_sms and total_sms - comm_sms respectively. Upon exiting, it restores the allocation to use all available SMs (i.e. total_sms). Args: - comm_sms (int): The number of SMs to allocate for communication. + comm_sms (int): The number of SMs to allocate for communication. (The remainder will be used for computation.) - set_comm_sms (Callable[[int], None]): + set_comm_sms (Callable[[int], None]): A function that sets the number of SMs for communication. - set_compute_sms (Callable[[int], None]): + set_compute_sms (Callable[[int], None]): A function that sets the number of SMs for computation. 
""" - assert current_platform.is_cuda(), \ + assert current_platform.is_cuda(), ( "SM control is currently only supported on CUDA" + ) props = torch.cuda.get_device_properties(torch.cuda.current_device()) total_sms = props.multi_processor_count @@ -84,9 +90,13 @@ def __exit__(self, exc_type, exc_value, traceback): class UBatchWrapper: - - def __init__(self, runnable: Callable, vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, device: torch.cuda.device): + def __init__( + self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + device: torch.cuda.device, + ): self.runnable = runnable self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config @@ -100,7 +110,8 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig, self.graph_pool = None if runtime_mode is not CUDAGraphMode.NONE: self.cudagraph_wrapper = CUDAGraphWrapper( - runnable, vllm_config, runtime_mode=runtime_mode) + runnable, vllm_config, runtime_mode=runtime_mode + ) self.graph_pool = current_platform.get_global_graph_pool() self.sm_control = self._create_sm_control_context(vllm_config) @@ -114,8 +125,7 @@ def _create_sm_control_context(vllm_config: VllmConfig): if vllm_config.parallel_config.enable_expert_parallel: # Currently only DeepEP highthroughput supports SM control so this # only affects that case. - all2all_manager = get_ep_group( - ).device_communicator.all2all_manager + all2all_manager = get_ep_group().device_communicator.all2all_manager if all2all_manager.max_sms_used() is not None: comm_sms = min(comm_sms, all2all_manager.max_sms_used()) @@ -127,18 +137,23 @@ def _create_sm_control_context(vllm_config: VllmConfig): set_compute_sms = lambda sms: None if has_deep_gemm() and comm_sms > 0: import deep_gemm as dg + set_compute_sms = lambda sms: dg.set_num_sms(sms) - return SMControlContextManager(comm_sms=comm_sms, - set_comm_sms=set_comm_sms, - set_compute_sms=set_compute_sms) + return SMControlContextManager( + comm_sms=comm_sms, + set_comm_sms=set_comm_sms, + set_compute_sms=set_compute_sms, + ) def __getattr__(self, key: str): # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): return getattr(self.runnable, key) - raise AttributeError(f"Attribute {key} not exists in the runnable of " - f"cudagraph wrapper: {self.runnable}") + raise AttributeError( + f"Attribute {key} not exists in the runnable of " + f"cudagraph wrapper: {self.runnable}" + ) def unwrap(self) -> Callable: # in case we need to access the original runnable. @@ -153,14 +168,14 @@ def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor: the graph capture. The flow is as follows: - 1. The main thread starts up each ubatch thread. Each thread will + 1. The main thread starts up each ubatch thread. Each thread will initialize its cuda context (torch.cuda.current_blas_handle()) before going to sleep upon entering the ubatch_context. - 2. The main thread starts the graph capture and wakes up the first + 2. The main thread starts the graph capture and wakes up the first ubatch thread. - 3. Each ubatch thread runs the model to completion and returns the + 3. Each ubatch thread runs the model to completion and returns the completed output tensors back to the main thread. 4. 
The main thread stores the captured cudagraph along with its metadata @@ -187,36 +202,38 @@ def _capture_ubatch_thread(results, ubatch_metadata): results: list[tuple[int, torch.Tensor]] = [] compute_stream = ubatch_metadata[0].context.compute_stream - num_tokens = ubatch_metadata[0].num_tokens + \ - ubatch_metadata[1].num_tokens + num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens # Ubatches will manually manage the forward context, so we override # it to None here so we can have it restored correctly later with override_forward_context(None): ubatch_threads = [] for metadata in ubatch_metadata: - thread = threading.Thread(target=_capture_ubatch_thread, - args=( - results, - metadata, - )) + thread = threading.Thread( + target=_capture_ubatch_thread, + args=( + results, + metadata, + ), + ) ubatch_threads.append(thread) thread.start() self.ready_barrier.wait() # Wait for both threads to be ready # Capture the cudagraph - cudagraph_metadata = \ - CUDAGraphMetaData( - cudagraph=torch.cuda.CUDAGraph(), - ubatch_metadata=ubatch_metadata, - ) + cudagraph_metadata = CUDAGraphMetaData( + cudagraph=torch.cuda.CUDAGraph(), + ubatch_metadata=ubatch_metadata, + ) if self.graph_pool is not None: set_graph_pool_id(self.graph_pool) else: set_graph_pool_id(current_platform.graph_pool_handle()) - with torch.cuda.graph(cudagraph_metadata.cudagraph, - stream=compute_stream, - pool=self.graph_pool): + with torch.cuda.graph( + cudagraph_metadata.cudagraph, + stream=compute_stream, + pool=self.graph_pool, + ): ubatch_metadata[0].context.cpu_wait_event.set() for thread in ubatch_threads: thread.join() @@ -227,7 +244,6 @@ def _capture_ubatch_thread(results, ubatch_metadata): return cudagraph_metadata.outputs def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor: - @torch.inference_mode() def _ubatch_thread(results, model, ubatch_metadata): with ubatch_metadata.context: @@ -247,12 +263,14 @@ def _ubatch_thread(results, model, ubatch_metadata): with override_forward_context(None): ubatch_threads = [] for metadata in ubatch_metadata: - thread = threading.Thread(target=_ubatch_thread, - args=( - results, - model, - metadata, - )) + thread = threading.Thread( + target=_ubatch_thread, + args=( + results, + model, + metadata, + ), + ) ubatch_threads.append(thread) thread.start() self.ready_barrier.wait() # Wait for both threads to be ready @@ -263,11 +281,19 @@ def _ubatch_thread(results, model, ubatch_metadata): result = torch.cat(sorted_results, dim=0) return result - def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, - positions, inputs_embeds, intermediate_tensors, - compute_stream, dp_metadata, batch_descriptor, - cudagraph_runtime_mode) -> list[UbatchMetadata]: - + def _make_ubatch_metadata( + self, + ubatch_slices, + attn_metadata, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + compute_stream, + dp_metadata, + batch_descriptor, + cudagraph_runtime_mode, + ) -> list[UbatchMetadata]: # Create one forward context per ubatch forward_contexts = [] for i, ubatch_slice in enumerate(ubatch_slices): @@ -277,22 +303,32 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, self.vllm_config, dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, - cudagraph_runtime_mode=cudagraph_runtime_mode)) + cudagraph_runtime_mode=cudagraph_runtime_mode, + ) + ) ubatch_ctxs = make_ubatch_contexts( num_micro_batches=len(ubatch_slices), comm_stream=self.comm_stream, compute_stream=compute_stream, forward_contexts=forward_contexts, - 
ready_barrier=self.ready_barrier) + ready_barrier=self.ready_barrier, + ) ubatch_metadata: list[UbatchMetadata] = [] for i, ubatch_slice in enumerate(ubatch_slices): - sliced_input_ids, sliced_positions, sliced_inputs_embeds, \ - sliced_intermediate_tensors = \ - self._slice_model_inputs( - ubatch_slice.token_slice, input_ids, positions, - inputs_embeds, intermediate_tensors) + ( + sliced_input_ids, + sliced_positions, + sliced_inputs_embeds, + sliced_intermediate_tensors, + ) = self._slice_model_inputs( + ubatch_slice.token_slice, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + ) ubatch_metadata.append( UbatchMetadata( context=ubatch_ctxs[i], @@ -300,13 +336,21 @@ def _make_ubatch_metadata(self, ubatch_slices, attn_metadata, input_ids, positions=sliced_positions, inputs_embeds=sliced_inputs_embeds, intermediate_tensors=sliced_intermediate_tensors, - num_tokens=ubatch_slice.token_slice.stop - - ubatch_slice.token_slice.start)) + num_tokens=ubatch_slice.token_slice.stop + - ubatch_slice.token_slice.start, + ) + ) return ubatch_metadata - def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions, - inputs_embeds, intermediate_tensors): + def _slice_model_inputs( + self, + tokens_slice: slice, + input_ids, + positions, + inputs_embeds, + intermediate_tensors, + ): sliced_input_ids = input_ids[tokens_slice] # if we are using mrope. Mrope adds an additional dimension to the # positions tensor @@ -314,13 +358,17 @@ def _slice_model_inputs(self, tokens_slice: slice, input_ids, positions, sliced_positions = positions[:, tokens_slice] else: sliced_positions = positions[tokens_slice] - sliced_inputs_embeds = inputs_embeds[ - tokens_slice] if inputs_embeds else None - sliced_intermediate_tensors = intermediate_tensors[ - tokens_slice] if intermediate_tensors else None - - return (sliced_input_ids, sliced_positions, sliced_inputs_embeds, - sliced_intermediate_tensors) + sliced_inputs_embeds = inputs_embeds[tokens_slice] if inputs_embeds else None + sliced_intermediate_tensors = ( + intermediate_tensors[tokens_slice] if intermediate_tensors else None + ) + + return ( + sliced_input_ids, + sliced_positions, + sliced_inputs_embeds, + sliced_intermediate_tensors, + ) def __call__(self, *args, **kwargs): forward_context = get_forward_context() @@ -330,7 +378,6 @@ def __call__(self, *args, **kwargs): # If there's no ubatching, just run the runnable object if ubatch_slices is None: - # This is to account for the case where ubatching was aborted. 
# When we capture full graphs we only capture one graph per shape, # meaning that if we have a ubatched cudagraph for the current @@ -342,20 +389,20 @@ def __call__(self, *args, **kwargs): if batch_descriptor.num_tokens in self.cudagraphs: cudagraph_runtime_mode = CUDAGraphMode.NONE - if cudagraph_runtime_mode in (CUDAGraphMode.NONE, - CUDAGraphMode.PIECEWISE): + if cudagraph_runtime_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE): return self.runnable(*args, **kwargs) else: assert self.cudagraph_wrapper is not None return self.cudagraph_wrapper(*args, **kwargs) attn_metadata = forward_context.attn_metadata - num_tokens = (ubatch_slices[0].token_slice.stop - - ubatch_slices[0].token_slice.start) * 2 - input_ids = kwargs['input_ids'] - positions = kwargs['positions'] - intermediate_tensors = kwargs['intermediate_tensors'] - inputs_embeds = kwargs['inputs_embeds'] + num_tokens = ( + ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start + ) * 2 + input_ids = kwargs["input_ids"] + positions = kwargs["positions"] + intermediate_tensors = kwargs["intermediate_tensors"] + inputs_embeds = kwargs["inputs_embeds"] compute_stream = torch.cuda.current_stream() dp_metadata = forward_context.dp_metadata @@ -363,8 +410,10 @@ def __call__(self, *args, **kwargs): # We shouldn't be here unless we are running with multiple DP ranks assert dp_metadata is not None - if num_tokens not in self.cudagraphs \ - and cudagraph_runtime_mode is CUDAGraphMode.FULL: + if ( + num_tokens not in self.cudagraphs + and cudagraph_runtime_mode is CUDAGraphMode.FULL + ): ubatch_metadata = self._make_ubatch_metadata( ubatch_slices=ubatch_slices, attn_metadata=attn_metadata, @@ -375,11 +424,14 @@ def __call__(self, *args, **kwargs): compute_stream=compute_stream, dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, - cudagraph_runtime_mode=CUDAGraphMode.NONE) + cudagraph_runtime_mode=CUDAGraphMode.NONE, + ) with self.sm_control: return self._capture_ubatches(ubatch_metadata, self.model) - elif num_tokens in self.cudagraphs \ - and cudagraph_runtime_mode is CUDAGraphMode.FULL: + elif ( + num_tokens in self.cudagraphs + and cudagraph_runtime_mode is CUDAGraphMode.FULL + ): cudagraph_metadata = self.cudagraphs[num_tokens] cudagraph_metadata.cudagraph.replay() return cudagraph_metadata.outputs @@ -394,6 +446,7 @@ def __call__(self, *args, **kwargs): compute_stream=compute_stream, dp_metadata=dp_metadata, batch_descriptor=batch_descriptor, - cudagraph_runtime_mode=CUDAGraphMode.NONE) + cudagraph_runtime_mode=CUDAGraphMode.NONE, + ) with self.sm_control: return self._run_ubatches(ubatch_metadata, self.model) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index a135a594ac6f..271aabb9e227 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A GPU worker class.""" + import copy import gc import os @@ -13,9 +14,11 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) +from vllm.distributed import ( + ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce, +) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger @@ -28,8 +31,12 @@ from vllm.utils 
import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - DraftTokenIds, ModelRunnerOutput) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + DraftTokenIds, + ModelRunnerOutput, +) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp @@ -43,7 +50,6 @@ class Worker(WorkerBase): - def __init__( self, vllm_config: VllmConfig, @@ -52,16 +58,18 @@ def __init__( distributed_init_method: str, is_driver_worker: bool = False, ): - - super().__init__(vllm_config=vllm_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker) + super().__init__( + vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() # Buffers saved before sleep @@ -71,8 +79,10 @@ def __init__( # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) + logger.info( + "Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir, + ) logger.debug( "Profiler config: record_shapes=%s," "profile_memory=%s,with_stack=%s,with_flops=%s", @@ -91,7 +101,9 @@ def __init__( with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) + torch_profiler_trace_dir, use_gzip=True + ), + ) else: self.profiler = None @@ -104,20 +116,20 @@ def sleep(self, level: int = 1) -> None: if level == 2: model = self.model_runner.model self._sleep_saved_buffers = { - name: buffer.cpu().clone() - for name, buffer in model.named_buffers() + name: buffer.cpu().clone() for name, buffer in model.named_buffers() } allocator = CuMemAllocator.get_instance() - allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) + allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) free_bytes_after_sleep, total = torch.cuda.mem_get_info() freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep used_bytes = total - free_bytes_after_sleep assert freed_bytes >= 0, "Memory usage increased after sleeping." 
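The `sleep()` body being reformatted here measures how much device memory was actually released by comparing `torch.cuda.mem_get_info()` readings taken around the sleep call. A minimal sketch of that accounting follows, with a throwaway 1 GiB tensor standing in for the weights and KV cache that sleep mode would offload or discard; all names and sizes are illustrative.

```python
# Freed/used memory accounting around a sleep-style release, on a dummy buffer.
import torch

GiB = 1 << 30

if torch.cuda.is_available():
    # Memory standing in for weights / KV cache that sleep() would release.
    scratch = torch.empty(256 * 1024 * 1024, dtype=torch.float32, device="cuda")

    free_before_sleep, total = torch.cuda.mem_get_info()

    # "Sleep": drop the allocation (the real worker goes through CuMemAllocator).
    del scratch
    torch.cuda.empty_cache()

    free_after_sleep, total = torch.cuda.mem_get_info()
    freed_bytes = free_after_sleep - free_before_sleep
    used_bytes = total - free_after_sleep
    assert freed_bytes >= 0, "Memory usage increased after sleeping."
    print(
        f"Sleep freed {freed_bytes / GiB:.2f} GiB, "
        f"{used_bytes / GiB:.2f} GiB still in use."
    )
```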
logger.info( - "Sleep mode freed %.2f GiB memory, " - "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, - used_bytes / GiB_bytes) + "Sleep mode freed %.2f GiB memory, %.2f GiB memory is still in use.", + freed_bytes / GiB_bytes, + used_bytes / GiB_bytes, + ) def wake_up(self, tags: Optional[list[str]] = None) -> None: from vllm.device_allocator.cumem import CuMemAllocator @@ -133,23 +145,21 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None: buffer.data.copy_(self._sleep_saved_buffers[name].data) self._sleep_saved_buffers = {} - def _maybe_get_memory_pool_context(self, - tag: str) -> AbstractContextManager: + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: if self.vllm_config.model_config.enable_sleep_mode: from vllm.device_allocator.cumem import CuMemAllocator allocator = CuMemAllocator.get_instance() if tag == "weights": assert allocator.get_current_usage() == 0, ( - "Sleep mode can only be " - "used for one instance per process.") + "Sleep mode can only be used for one instance per process." + ) context = allocator.use_memory_pool(tag=tag) else: context = nullcontext() return context - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -166,10 +176,13 @@ def init_device(self): # memory snapshot # This ensures NCCL buffers are allocated before we measure # available memory - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) # Set random seed. set_random_seed(self.model_config.seed) @@ -180,8 +193,10 @@ def init_device(self): # take current memory snapshot self.init_snapshot = MemorySnapshot() - self.requested_memory = (self.init_snapshot.total_memory * - self.cache_config.gpu_memory_utilization) + self.requested_memory = ( + self.init_snapshot.total_memory + * self.cache_config.gpu_memory_utilization + ) if self.init_snapshot.free_memory < self.requested_memory: GiB = lambda b: round(b / GiB_bytes, 2) raise ValueError( @@ -194,12 +209,12 @@ def init_device(self): f"utilization or reduce GPU memory used by other processes." ) else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") + raise RuntimeError(f"Not support device type: {self.device_config.device}") # Construct the model runner self.model_runner: GPUModelRunner = GPUModelRunner( - self.vllm_config, self.device) + self.vllm_config, self.device + ) if self.rank == 0: # If usage stat is enabled, collect relevant info. @@ -247,7 +262,8 @@ def determine_available_memory(self) -> int: "size. If OOM'ed, check the difference of initial free " "memory between the current run and the previous run " "where kv_cache_memory_bytes is suggested and update it " - "correspondingly.") + "correspondingly." + ) logger.info(msg) return kv_cache_memory_bytes @@ -257,8 +273,8 @@ def determine_available_memory(self) -> int: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. 
with memory_profiling( - self.init_snapshot, - weights_memory=int(self.model_runner.model_memory_usage), + self.init_snapshot, + weights_memory=int(self.model_runner.model_memory_usage), ) as profile_result: self.model_runner.profile_run() @@ -275,15 +291,15 @@ def determine_available_memory(self) -> int: "This happens when other processes sharing the same container " "release GPU memory while vLLM is profiling during initialization. " "To fix this, ensure consistent GPU memory allocation or " - "isolate vLLM in its own container.") - self.available_kv_cache_memory_bytes = self.requested_memory \ - - profile_result.non_kv_cache_memory + "isolate vLLM in its own container." + ) + self.available_kv_cache_memory_bytes = ( + self.requested_memory - profile_result.non_kv_cache_memory + ) - unrequested_memory = self.init_snapshot.free_memory \ - - self.requested_memory + unrequested_memory = self.init_snapshot.free_memory - self.requested_memory logger.debug( - "Initial free memory: %.2f GiB; " - "Requested memory: %.2f (util), %.2f GiB", + "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB", GiB(self.init_snapshot.free_memory), self.cache_config.gpu_memory_utilization, GiB(self.requested_memory), @@ -295,8 +311,10 @@ def determine_available_memory(self) -> int: GiB(free_gpu_memory - unrequested_memory), ) logger.debug(profile_result) - logger.info("Available KV cache memory: %.2f GiB", - GiB(self.available_kv_cache_memory_bytes)) + logger.info( + "Available KV cache memory: %.2f GiB", + GiB(self.available_kv_cache_memory_bytes), + ) gc.collect() return int(self.available_kv_cache_memory_bytes) @@ -324,15 +342,14 @@ def compile_or_warm_up_model(self) -> None: warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() if not self.model_config.enforce_eager: warmup_sizes = [ - x for x in warmup_sizes if x not in - self.vllm_config.compilation_config.cudagraph_capture_sizes + x + for x in warmup_sizes + if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size, - skip_eplb=True, - remove_lora=False) + self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False) self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config) # Warmup and tune the kernels used during model execution before @@ -343,8 +360,9 @@ def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: cuda_graph_memory_bytes = self.model_runner.capture_model() - if (self.cache_config.kv_cache_memory_bytes is None - and hasattr(self, "peak_activation_memory")): + if self.cache_config.kv_cache_memory_bytes is None and hasattr( + self, "peak_activation_memory" + ): # Suggests optimal kv cache memory size if we rely on # memory_profiling to guess the kv cache memory size which # provides peak_activation_memory and a few other memory @@ -358,16 +376,22 @@ def compile_or_warm_up_model(self) -> None: # slightly underestimate the memory consumption. # So leave a small buffer (=150MiB) to avoid OOM. 
redundancy_buffer_memory = 150 * (1 << 20) - non_kv_cache_memory = (self.model_runner.model_memory_usage + - self.peak_activation_memory + - self.non_torch_memory + - cuda_graph_memory_bytes) + non_kv_cache_memory = ( + self.model_runner.model_memory_usage + + self.peak_activation_memory + + self.non_torch_memory + + cuda_graph_memory_bytes + ) kv_cache_memory_bytes_to_gpu_limit = ( - self.init_snapshot.free_memory - non_kv_cache_memory - - redundancy_buffer_memory) + self.init_snapshot.free_memory + - non_kv_cache_memory + - redundancy_buffer_memory + ) kv_cache_memory_bytes_to_requested_limit = ( - int(self.requested_memory) - non_kv_cache_memory - - redundancy_buffer_memory) + int(self.requested_memory) + - non_kv_cache_memory + - redundancy_buffer_memory + ) msg = ( f"Free memory on device " @@ -388,7 +412,8 @@ def compile_or_warm_up_model(self) -> None: f"{kv_cache_memory_bytes_to_gpu_limit}` " f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully " f"utilize gpu memory. Current kv cache memory in use is " - f"{GiB(self.available_kv_cache_memory_bytes)} GiB.") + f"{GiB(self.available_kv_cache_memory_bytes)} GiB." + ) logger.debug(msg) @@ -398,20 +423,20 @@ def compile_or_warm_up_model(self) -> None: # NOTE: This is called after `capture_model` on purpose to prevent # memory buffers from being cleared by `torch.cuda.empty_cache`. if get_pp_group().is_last_rank: - max_num_reqs = min(self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens) + max_num_reqs = min( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + ) # We skip EPLB here since we don't want to record dummy metrics - hidden_states, last_hidden_states = \ - self.model_runner._dummy_run( - num_tokens=max_num_reqs, - skip_eplb=True, - ) + hidden_states, last_hidden_states = self.model_runner._dummy_run( + num_tokens=max_num_reqs, + skip_eplb=True, + ) if self.model_runner.is_pooling_model: self.model_runner._dummy_pooler_run(hidden_states) else: - self.model_runner._dummy_sampler_run( - hidden_states=last_hidden_states) + self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
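The suggestion logged above is plain arithmetic over the profiled memory components: everything that is not KV cache is summed, a 150 MiB safety buffer is subtracted, and the remainder against either the requested budget or the full free memory becomes the recommended `kv_cache_memory_bytes`. A worked example of that arithmetic with made-up byte counts (not measurements):

```python
# Worked example of the kv-cache memory suggestion computed after warm-up.
# Every input below is an illustrative number, not a real measurement.
GiB = 1 << 30

total_memory = 80 * GiB                  # init_snapshot.total_memory
free_memory_at_init = 78 * GiB           # init_snapshot.free_memory
gpu_memory_utilization = 0.9
requested_memory = int(total_memory * gpu_memory_utilization)  # 72 GiB

model_memory = 30 * GiB                  # model_runner.model_memory_usage
peak_activation_memory = 6 * GiB
non_torch_memory = 2 * GiB
cuda_graph_memory = 1 * GiB
redundancy_buffer = 150 * (1 << 20)      # 150 MiB headroom against OOM

non_kv_cache_memory = (
    model_memory + peak_activation_memory + non_torch_memory + cuda_graph_memory
)
kv_cache_to_requested_limit = (
    requested_memory - non_kv_cache_memory - redundancy_buffer
)
kv_cache_to_gpu_limit = (
    free_memory_at_init - non_kv_cache_memory - redundancy_buffer
)

print(f"within requested budget: {kv_cache_to_requested_limit / GiB:.2f} GiB")
print(f"to fully use the GPU:    {kv_cache_to_gpu_limit / GiB:.2f} GiB")
```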
@@ -431,32 +456,36 @@ def execute_model( intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - num_input_tokens = self.model_runner._get_num_input_tokens( - num_scheduled_tokens) + num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens) all_gather_tensors = { - "residual": - not is_residual_scattered_for_sp(self.vllm_config, - num_input_tokens) + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) } if forward_pass and not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors)) + all_gather_tensors=all_gather_tensors, + ) + ) - output = self.model_runner.execute_model(scheduler_output, - intermediate_tensors) + output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)): return output assert isinstance(output, IntermediateTensors) parallel_config = self.vllm_config.parallel_config - assert parallel_config.distributed_executor_backend != ( - "external_launcher") and not get_pp_group().is_last_rank + assert ( + parallel_config.distributed_executor_backend != ("external_launcher") + and not get_pp_group().is_last_rank + ) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors) + get_pp_group().send_tensor_dict( + output.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) kv_connector_output = output.kv_connector_output if not kv_connector_output: @@ -483,8 +512,9 @@ def profile(self, is_start: bool = True): self.profiler.stop() # only print profiler results on rank 0 if self.local_rank == 0: - print(self.profiler.key_averages().table( - sort_by="self_cuda_time_total")) + print( + self.profiler.key_averages().table(sort_by="self_cuda_time_total") + ) def execute_dummy_batch(self) -> None: self.model_runner._dummy_run(1, uniform_decode=True) @@ -505,68 +535,79 @@ def check_health(self) -> None: # worker will always be healthy as long as it's running. return - def _eplb_before_scale_down(self, old_ep_size: int, - new_ep_size: int) -> None: + def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None: from vllm.distributed.parallel_state import get_ep_group + if get_ep_group().rank == 0: - logger.info("[Elastic EP] Starting expert resharding " - "before scaling down...") + logger.info( + "[Elastic EP] Starting expert resharding before scaling down..." 
+ ) rank_mapping = { old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 for old_ep_rank in range(old_ep_size) } assert self.model_runner.eplb_state is not None - self.model_runner.eplb_state.rearrange(self.model_runner.model, - execute_shuffle=True, - global_expert_load=None, - rank_mapping=rank_mapping) + self.model_runner.eplb_state.rearrange( + self.model_runner.model, + execute_shuffle=True, + global_expert_load=None, + rank_mapping=rank_mapping, + ) torch.cuda.synchronize() if get_ep_group().rank == 0: logger.info("[Elastic EP] Expert resharding completed!") def _eplb_after_scale_up( - self, old_ep_size: int, new_ep_size: int, - global_expert_load: Optional[torch.Tensor]) -> None: + self, + old_ep_size: int, + new_ep_size: int, + global_expert_load: Optional[torch.Tensor], + ) -> None: from vllm.distributed.parallel_state import get_ep_group + if get_ep_group().rank == 0: - logger.info("[Elastic EP] Starting expert resharding " - "after scaling up...") - rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) - } + logger.info("[Elastic EP] Starting expert resharding after scaling up...") + rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=True, global_expert_load=global_expert_load, - rank_mapping=rank_mapping) + rank_mapping=rank_mapping, + ) if get_ep_group().rank == 0: logger.info("[Elastic EP] Expert resharding completed!") def _reconfigure_parallel_config( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: """ Update parallel config with provided reconfig_request """ parallel_config = self.vllm_config.parallel_config - parallel_config.data_parallel_size = \ - reconfig_request.new_data_parallel_size - if reconfig_request.new_data_parallel_rank != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank = \ - reconfig_request.new_data_parallel_rank - if reconfig_request.new_data_parallel_rank_local != \ - ReconfigureRankType.KEEP_CURRENT_RANK: - parallel_config.data_parallel_rank_local = \ + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( reconfig_request.new_data_parallel_rank_local - parallel_config.data_parallel_master_ip = \ + ) + parallel_config.data_parallel_master_ip = ( reconfig_request.new_data_parallel_master_ip - parallel_config.data_parallel_master_port = \ + ) + parallel_config.data_parallel_master_port = ( reconfig_request.new_data_parallel_master_port + ) - def _reconfigure_moe(self, old_ep_size: int, - new_ep_size: int) -> Optional[torch.Tensor]: + def _reconfigure_moe( + self, old_ep_size: int, new_ep_size: int + ) -> Optional[torch.Tensor]: """ Reconfigure MoE modules with provided reconfig_request @@ -574,20 +615,26 @@ def _reconfigure_moe(self, old_ep_size: int, otherwise None """ from vllm.distributed.parallel_state import ( - get_dp_group, get_ep_group, prepare_communication_buffer_for_model) - from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoEParallelConfig) + get_dp_group, + get_ep_group, + 
prepare_communication_buffer_for_model, + ) + from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig parallel_config = self.vllm_config.parallel_config moe_modules = [ - module for module in self.model_runner.model.modules() - if (module.__class__.__name__ == "FusedMoE" - or module.__class__.__name__ == "SharedFusedMoE") + module + for module in self.model_runner.model.modules() + if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) ] num_local_experts = moe_modules[0].moe_config.num_local_experts - assert all(module.moe_config.num_local_experts == num_local_experts - for module in moe_modules), ( - "All MoE modules must have the same number of experts") + assert all( + module.moe_config.num_local_experts == num_local_experts + for module in moe_modules + ), "All MoE modules must have the same number of experts" for module in moe_modules: module.moe_config.num_experts = num_local_experts * new_ep_size module.global_num_experts = module.moe_config.num_experts @@ -600,49 +647,62 @@ def _reconfigure_moe(self, old_ep_size: int, if new_ep_size < old_ep_size: num_local_physical_experts = num_local_experts assert self.model_runner.eplb_state is not None - new_physical_experts = \ + new_physical_experts = ( self.model_runner.eplb_state.physical_to_logical_map.shape[1] + ) parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - - self.model_runner.eplb_state.logical_replica_count.shape[1]) + new_physical_experts + - self.model_runner.eplb_state.logical_replica_count.shape[1] + ) global_expert_load = None else: - num_local_physical_experts = torch.tensor([num_local_experts], - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) + num_local_physical_experts = torch.tensor( + [num_local_experts], dtype=torch.int32, device="cpu" + ) + torch.distributed.broadcast( + num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 + ) num_local_physical_experts = num_local_physical_experts.item() new_physical_experts = num_local_physical_experts * new_ep_size assert self.model_runner.eplb_state is not None global_expert_load = self.model_runner.eplb_state.rearrange( - self.model_runner.model, execute_shuffle=False) + self.model_runner.model, execute_shuffle=False + ) parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - global_expert_load.shape[1]) + new_physical_experts - global_expert_load.shape[1] + ) prepare_communication_buffer_for_model(self.model_runner.model) self.model_runner.model.update_physical_experts_metadata( num_physical_experts=new_physical_experts, - num_local_physical_experts=num_local_physical_experts) + num_local_physical_experts=num_local_physical_experts, + ) return global_expert_load def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: from vllm.config import set_current_vllm_config from vllm.distributed.parallel_state import ( - cleanup_dist_env_and_memory, get_ep_group) + cleanup_dist_env_and_memory, + get_ep_group, + ) old_ep_size = get_ep_group().world_size old_ep_rank = get_ep_group().rank - new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group( - ).world_size * get_pp_group().world_size + new_ep_size = ( + reconfig_request.new_data_parallel_size + * get_tp_group().world_size + * get_pp_group().world_size + ) if new_ep_size < old_ep_size: 
self._eplb_before_scale_down(old_ep_size, new_ep_size) cleanup_dist_env_and_memory() - if reconfig_request.new_data_parallel_rank == \ - ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + if ( + reconfig_request.new_data_parallel_rank + == ReconfigureRankType.SHUTDOWN_CURRENT_RANK + ): assert old_ep_rank >= new_ep_size # shutdown return @@ -650,16 +710,18 @@ def reinitialize_distributed( self._reconfigure_parallel_config(reconfig_request) with set_current_vllm_config(self.vllm_config): - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + ) global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size) if new_ep_size > old_ep_size: assert global_expert_load is not None - self._eplb_after_scale_up(old_ep_size, new_ep_size, - global_expert_load) + self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load) def save_sharded_state( self, @@ -668,6 +730,7 @@ def save_sharded_state( max_size: Optional[int] = None, ) -> None: from vllm.model_executor.model_loader import ShardedStateLoader + ShardedStateLoader.save_model( self.model_runner.model, path, @@ -680,7 +743,8 @@ def save_tensorized_model( tensorizer_config: "TensorizerConfig", ) -> None: self.model_runner.save_tensorized_model( - tensorizer_config=tensorizer_config, ) + tensorizer_config=tensorizer_config, + ) def shutdown(self) -> None: if runner := getattr(self, "model_runner", None): @@ -698,12 +762,14 @@ def init_worker_distributed_environment( parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank, backend) + init_distributed_environment( + parallel_config.world_size, rank, distributed_init_method, local_rank, backend + ) ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, - parallel_config.decode_context_parallel_size) + parallel_config.decode_context_parallel_size, + ) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index cdc0d317fffb..473982bebb12 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -3,22 +3,30 @@ """ Define KV connector functionality mixin for model runners. 
""" + import copy +from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import Generator # noqa: UP035 -from typing import TYPE_CHECKING, Optional +from typing import ( + TYPE_CHECKING, # noqa: UP035 + Optional, +) from vllm.config import VllmConfig -from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown, - get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_shutdown, + get_kv_transfer_group, + has_kv_transfer_group, +) from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, - ModelRunnerOutput) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + KVConnectorOutput, + ModelRunnerOutput, +) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -28,7 +36,6 @@ # Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU) class KVConnectorModelRunnerMixin: - @staticmethod def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): # Update KVConnector with the KVConnector metadata forward(). @@ -36,8 +43,7 @@ def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): kv_connector = get_kv_transfer_group() assert isinstance(kv_connector, KVConnectorBase) assert scheduler_output.kv_connector_metadata is not None - kv_connector.bind_connector_metadata( - scheduler_output.kv_connector_metadata) + kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata) # Background KV cache transfers happen here. # These transfers are designed to be async and the requests @@ -62,17 +68,21 @@ def get_finished_kv_transfers( ) -> tuple[Optional[set[str]], Optional[set[str]]]: if has_kv_transfer_group(): return get_kv_transfer_group().get_finished( - scheduler_output.finished_req_ids) + scheduler_output.finished_req_ids + ) return None, None @staticmethod - def kv_connector_no_forward(scheduler_output: "SchedulerOutput", - vllm_config: VllmConfig) -> ModelRunnerOutput: + def kv_connector_no_forward( + scheduler_output: "SchedulerOutput", vllm_config: VllmConfig + ) -> ModelRunnerOutput: # KV send/recv even if no work to do. - with set_forward_context( - None, vllm_config - ), KVConnectorModelRunnerMixin._get_kv_connector_output( - scheduler_output, wait_for_save=False) as kv_connector_output: + with ( + set_forward_context(None, vllm_config), + KVConnectorModelRunnerMixin._get_kv_connector_output( + scheduler_output, wait_for_save=False + ) as kv_connector_output, + ): pass if kv_connector_output.is_empty(): @@ -84,18 +94,20 @@ def kv_connector_no_forward(scheduler_output: "SchedulerOutput", @staticmethod def maybe_get_kv_connector_output( - scheduler_output: "SchedulerOutput" + scheduler_output: "SchedulerOutput", ) -> AbstractContextManager[Optional[KVConnectorOutput]]: - return KVConnectorModelRunnerMixin._get_kv_connector_output( - scheduler_output) if has_kv_transfer_group() else nullcontext() + return ( + KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output) + if has_kv_transfer_group() + else nullcontext() + ) # This context manager must be used within an active forward context. 
# It encapsulates the entire KV connector lifecycle within execute_model @staticmethod @contextmanager def _get_kv_connector_output( - scheduler_output: "SchedulerOutput", - wait_for_save: bool = True + scheduler_output: "SchedulerOutput", wait_for_save: bool = True ) -> Generator[KVConnectorOutput, None, None]: output = KVConnectorOutput() @@ -103,8 +115,7 @@ def _get_kv_connector_output( kv_connector = get_kv_transfer_group() assert isinstance(kv_connector, KVConnectorBase) assert scheduler_output.kv_connector_metadata is not None - kv_connector.bind_connector_metadata( - scheduler_output.kv_connector_metadata) + kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata) # Background KV cache transfers happen here. # These transfers are designed to be async and the requests @@ -118,12 +129,13 @@ def _get_kv_connector_output( kv_connector.wait_for_save() output.finished_sending, output.finished_recving = ( - kv_connector.get_finished(scheduler_output.finished_req_ids)) - output.invalid_block_ids = ( - kv_connector.get_block_ids_with_load_errors()) + kv_connector.get_finished(scheduler_output.finished_req_ids) + ) + output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors() - output.kv_connector_stats = KVConnectorModelRunnerMixin.\ - get_kv_connector_stats() + output.kv_connector_stats = ( + KVConnectorModelRunnerMixin.get_kv_connector_stats() + ) kv_connector.clear_connector_metadata() @staticmethod diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index e416f50322f4..e7358c4271ce 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -28,19 +28,19 @@ # Defined as a mixin for GPUModelRunner class LoRAModelRunnerMixin: - LORA_WARMUP_RANK = 8 - def load_lora_model(self, model: nn.Module, vllm_config: VllmConfig, - device: torch.device) -> nn.Module: - + def load_lora_model( + self, model: nn.Module, vllm_config: VllmConfig, device: torch.device + ) -> nn.Module: if not supports_lora(model): - raise ValueError( - f"{model.__class__.__name__} does not support LoRA yet.") + raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.") if supports_multimodal(model): - logger.warning("Regarding multimodal models, vLLM currently " - "only supports adding LoRA to language model.") + logger.warning( + "Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model." + ) # Add LoRA Manager to the Model Runner self.lora_manager = LRUCacheWorkerLoRAManager( @@ -51,41 +51,44 @@ def load_lora_model(self, model: nn.Module, vllm_config: VllmConfig, ) return self.lora_manager.create_lora_manager(model) - def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...], - token_lora_mapping: tuple[int, ...], - lora_requests: set[LoRARequest]) -> None: + def _set_active_loras( + self, + prompt_lora_mapping: tuple[int, ...], + token_lora_mapping: tuple[int, ...], + lora_requests: set[LoRARequest], + ) -> None: self._ensure_lora_enabled() # Set is_prefill to True, so we always use the SGMV kernels on # non-cuda platforms. # On cuda platforms we use the same kernels for prefill and # decode and this flag is generally ignored. 
- lora_mapping = LoRAMapping(token_lora_mapping, - prompt_lora_mapping, - is_prefill=True) + lora_mapping = LoRAMapping( + token_lora_mapping, prompt_lora_mapping, is_prefill=True + ) self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def _ensure_lora_enabled(self) -> None: if not hasattr(self, "lora_manager"): - raise RuntimeError( - "LoRA is not enabled. Use --enable-lora to enable LoRA.") - - def set_active_loras(self, input_batch: InputBatch, - num_scheduled_tokens: np.ndarray) -> None: + raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.") + def set_active_loras( + self, input_batch: InputBatch, num_scheduled_tokens: np.ndarray + ) -> None: prompt_lora_mapping: tuple[int, ...] # of size input_batch.num_reqs - token_lora_mapping: tuple[int, - ...] # of size np.sum(num_scheduled_tokens) + token_lora_mapping: tuple[int, ...] # of size np.sum(num_scheduled_tokens) lora_requests: set[LoRARequest] - prompt_lora_mapping, token_lora_mapping, lora_requests = \ - input_batch.make_lora_inputs(num_scheduled_tokens) - return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, - lora_requests) + prompt_lora_mapping, token_lora_mapping, lora_requests = ( + input_batch.make_lora_inputs(num_scheduled_tokens) + ) + return self._set_active_loras( + prompt_lora_mapping, token_lora_mapping, lora_requests + ) @contextmanager - def maybe_setup_dummy_loras(self, - lora_config: Optional[LoRAConfig], - remove_lora: bool = True): + def maybe_setup_dummy_loras( + self, lora_config: Optional[LoRAConfig], remove_lora: bool = True + ): if lora_config is None: yield else: @@ -96,9 +99,11 @@ def maybe_setup_dummy_loras(self, # Make dummy lora requests lora_requests: set[LoRARequest] = { - LoRARequest(lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path") + LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) for lora_id in range(1, num_loras + 1) } @@ -106,8 +111,7 @@ def maybe_setup_dummy_loras(self, # Add the dummy LoRAs here so _set_active_loras doesn't try to # load from disk. for lr in lora_requests: - self.lora_manager.add_dummy_lora( - lr, rank=self.LORA_WARMUP_RANK) + self.lora_manager.add_dummy_lora(lr, rank=self.LORA_WARMUP_RANK) yield @@ -116,8 +120,9 @@ def maybe_setup_dummy_loras(self, self.lora_manager.remove_all_adapters() @contextmanager - def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig], - num_scheduled_tokens: np.ndarray): + def maybe_select_dummy_loras( + self, lora_config: Optional[LoRAConfig], num_scheduled_tokens: np.ndarray + ): if lora_config is None: yield else: @@ -129,35 +134,37 @@ def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig], # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. 
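The comment above describes the worst-case warm-up assignment: request i gets adapter ID (i mod num_loras) + 1, and every scheduled token inherits its request's ID. A tiny standalone illustration with made-up request counts:

```python
# Cyclic worst-case LoRA assignment used for warm-up, shown on toy inputs.
import numpy as np

num_reqs = 5
num_loras = 3  # stands in for lora_config.max_loras
num_scheduled_tokens = np.array([4, 2, 3, 1, 2], dtype=np.int32)

# One adapter ID per request, cycling through 1..num_loras.
prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1
# One adapter ID per scheduled token, repeating each request's ID.
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)

print(prompt_lora_mapping)  # [1 2 3 1 2]
print(token_lora_mapping)   # [1 1 1 1 2 2 3 3 3 1 2 2]
```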
- prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % - num_loras) + 1 + prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % num_loras) + 1 # Make token lora mapping - token_lora_mapping = np.repeat(prompt_lora_mapping, - num_scheduled_tokens) + token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) # Make dummy lora requests lora_requests: set[LoRARequest] = { - LoRARequest(lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path") + LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) for lora_id in range(1, num_loras + 1) } - self._set_active_loras(tuple(prompt_lora_mapping), - tuple(token_lora_mapping), lora_requests) + self._set_active_loras( + tuple(prompt_lora_mapping), tuple(token_lora_mapping), lora_requests + ) yield @contextmanager - def maybe_dummy_run_with_lora(self, - lora_config: Optional[LoRAConfig], - num_scheduled_tokens: np.ndarray, - remove_lora: bool = True): + def maybe_dummy_run_with_lora( + self, + lora_config: Optional[LoRAConfig], + num_scheduled_tokens: np.ndarray, + remove_lora: bool = True, + ): with ( - self.maybe_setup_dummy_loras(lora_config, remove_lora), - self.maybe_select_dummy_loras(lora_config, - num_scheduled_tokens), + self.maybe_setup_dummy_loras(lora_config, remove_lora), + self.maybe_select_dummy_loras(lora_config, num_scheduled_tokens), ): yield diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 4cd0ac352de0..34fed8f96467 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -18,16 +18,15 @@ class InputBatch: - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - device: torch.device, - pin_memory: bool, - vocab_size: int, - block_sizes: list[int], # The block_size of each kv cache group + self, + max_num_reqs: int, + max_model_len: int, + max_num_batched_tokens: int, + device: torch.device, + pin_memory: bool, + vocab_size: int, + block_sizes: list[int], # The block_size of each kv cache group ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len @@ -54,13 +53,12 @@ def __init__( self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros( - (max_num_reqs, ), + (max_num_reqs,), device="cpu", dtype=torch.int32, pin_memory=pin_memory, ) - self.num_computed_tokens_cpu = \ - self.num_computed_tokens_cpu_tensor.numpy() + self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy() # Block table. self.block_table = MultiGroupBlockTable( @@ -73,91 +71,72 @@ def __init__( ) # Sampling-related. 
- self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.temperature = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) + self.temperature_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.temperature_cpu = self.temperature_cpu_tensor.numpy() self.greedy_reqs: set[str] = set() self.random_reqs: set[str] = set() - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.top_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device) + self.top_p_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.top_p_cpu = self.top_p_cpu_tensor.numpy() self.top_p_reqs: set[str] = set() - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) + self.top_k = torch.empty((max_num_reqs,), dtype=torch.int32, device=device) + self.top_k_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory + ) self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: set[str] = set() - self.min_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.min_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) + self.min_p = torch.empty((max_num_reqs,), dtype=torch.float32, device=device) + self.min_p_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory + ) self.min_p_cpu = self.min_p_cpu_tensor.numpy() self.min_p_reqs: set[str] = set() # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.frequency_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device + ) self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_reqs: set[str] = set() # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( + self.presence_penalties = torch.empty( + (max_num_reqs,), dtype=torch.float, device=device ) + self.presence_penalties_cpu_tensor = torch.empty( + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy() self.presence_penalties_reqs: set[str] = set() # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) + self.repetition_penalties = torch.empty( + 
(max_num_reqs,), dtype=torch.float, device=device + ) self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + (max_num_reqs,), dtype=torch.float, device="cpu", pin_memory=pin_memory + ) + self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() # req_index -> (min_tokens, stop_token_ids) self.min_tokens: dict[int, tuple[int, set[int]]] = {} # lora related - self.request_lora_mapping = np.zeros((self.max_num_reqs, ), - dtype=np.int32) + self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int32) self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {} @@ -174,8 +153,7 @@ def __init__( # To accumulate prompt logprobs tensor chunks across prefill steps. self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} - self.logit_bias: list[Optional[dict[int, - float]]] = [None] * max_num_reqs + self.logit_bias: list[Optional[dict[int, float]]] = [None] * max_num_reqs self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, # the value is False. Since we use masked_fill_ to set -inf. @@ -214,15 +192,14 @@ def add_request( # Copy the prompt token ids and output token ids. num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - request.prompt_token_ids, request.prompt_embeds) + request.prompt_token_ids, request.prompt_embeds + ) # TODO: copy prompt_embeds self.num_prompt_tokens[req_index] = num_prompt_tokens - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids + self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids + self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids # Number of token ids in token_ids_cpu. # NOTE(woosuk): This may include spec decode tokens. self.num_tokens[req_index] = request.num_tokens @@ -252,23 +229,22 @@ def add_request( top_k = self.vocab_size self.top_k_cpu[req_index] = top_k self.min_p_cpu[req_index] = sampling_params.min_p - self.frequency_penalties_cpu[ - req_index] = sampling_params.frequency_penalty + self.frequency_penalties_cpu[req_index] = sampling_params.frequency_penalty if sampling_params.min_p > _SAMPLING_EPS: self.min_p_reqs.add(req_id) if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[ - req_index] = sampling_params.presence_penalty + self.presence_penalties_cpu[req_index] = sampling_params.presence_penalty if sampling_params.presence_penalty != 0.0: self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[ - req_index] = sampling_params.repetition_penalty + self.repetition_penalties_cpu[req_index] = sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) if sampling_params.min_tokens: - self.min_tokens[req_index] = (sampling_params.min_tokens, - sampling_params.all_stop_token_ids) + self.min_tokens[req_index] = ( + sampling_params.min_tokens, + sampling_params.all_stop_token_ids, + ) # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. 
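Every sampling parameter in this batch follows the same three-piece layout: a device tensor, a CPU staging tensor (pinned when CUDA is available), and a NumPy view of the CPU tensor for cheap per-request scalar writes. A minimal sketch of that layout on a single toy field; the names and sizes are illustrative, not taken from the real batch.

```python
# The per-request parameter layout used throughout InputBatch, on one toy field.
import torch

max_num_reqs = 4
device = "cuda" if torch.cuda.is_available() else "cpu"
pin_memory = device == "cuda"

temperature = torch.empty((max_num_reqs,), dtype=torch.float32, device=device)
temperature_cpu_tensor = torch.empty(
    (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=pin_memory
)
# The NumPy view shares storage with the CPU tensor, so scalar writes are cheap.
temperature_cpu = temperature_cpu_tensor.numpy()

# Per-request update path: write through the NumPy view...
num_reqs = 3
temperature_cpu[:num_reqs] = [1.0, 0.5, 0.0]
# ...then stage everything to the device tensor in a single copy.
temperature[:num_reqs].copy_(temperature_cpu_tensor[:num_reqs], non_blocking=True)

assert temperature_cpu_tensor[1].item() == 0.5  # same memory as the NumPy view
```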
@@ -287,23 +263,23 @@ def add_request( if self.allowed_token_ids_mask_cpu_tensor is None: # Lazy allocation for this tensor, which can be large. # False means we don't fill with -inf. - self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device=self.device) - self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.allowed_token_ids_mask = torch.zeros( self.max_num_reqs, self.vocab_size, dtype=torch.bool, - device="cpu") + device=self.device, + ) + self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.max_num_reqs, self.vocab_size, dtype=torch.bool, device="cpu" + ) self.allowed_token_ids_mask_cpu_tensor[req_index] = True # False means we don't fill with -inf. self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = False + sampling_params.allowed_token_ids + ] = False if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids + self.bad_words_token_ids[req_index] = sampling_params.bad_words_token_ids # Add request lora ID if request.lora_request: @@ -361,35 +337,51 @@ def remove_request(self, req_id: str) -> Optional[int]: def swap_states(self, i1: int, i2: int) -> None: old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] - self._req_ids[i1], self._req_ids[i2] =\ - self._req_ids[i2], self._req_ids[i1] # noqa - self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ - self.req_output_token_ids[i2], self.req_output_token_ids[i1] + self._req_ids[i1], self._req_ids[i2] = self._req_ids[i2], self._req_ids[i1] # noqa + self.req_output_token_ids[i1], self.req_output_token_ids[i2] = ( + self.req_output_token_ids[i2], + self.req_output_token_ids[i1], + ) assert old_id_i1 is not None and old_id_i2 is not None - self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ - self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] - self.num_tokens[i1], self.num_tokens[i2] =\ - self.num_tokens[i2], self.num_tokens[i1] - self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ - self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] - self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ - self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] - self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ - self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] - self.temperature_cpu[i1], self.temperature_cpu[i2] =\ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] =\ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] =\ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] - self.min_p_cpu[i1], self.min_p_cpu[i2] =\ - self.min_p_cpu[i2], self.min_p_cpu[i1] + self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] = ( + self.req_id_to_index[old_id_i2], + self.req_id_to_index[old_id_i1], + ) + self.num_tokens[i1], self.num_tokens[i2] = ( + self.num_tokens[i2], + self.num_tokens[i1], + ) + self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] = ( + self.num_tokens_no_spec[i2], + 
self.num_tokens_no_spec[i1], + ) + self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] = ( + self.num_prompt_tokens[i2], + self.num_prompt_tokens[i1], + ) + self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = ( + self.num_computed_tokens_cpu[i2], + self.num_computed_tokens_cpu[i1], + ) + self.temperature_cpu[i1], self.temperature_cpu[i2] = ( + self.temperature_cpu[i2], + self.temperature_cpu[i1], + ) + self.top_p_cpu[i1], self.top_p_cpu[i2] = self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = ( + self.frequency_penalties_cpu[i2], + self.frequency_penalties_cpu[i1], + ) + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = ( + self.presence_penalties_cpu[i2], + self.presence_penalties_cpu[i1], + ) + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = ( + self.repetition_penalties_cpu[i2], + self.repetition_penalties_cpu[i1], + ) + self.min_p_cpu[i1], self.min_p_cpu[i2] = self.min_p_cpu[i2], self.min_p_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -404,21 +396,28 @@ def swap_states(self, i1: int, i2: int) -> None: swap_dict_values(self.min_tokens, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] - self.logit_bias[i1], self.logit_bias[i2] =\ - self.logit_bias[i2], self.logit_bias[i1] + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = ( + self.request_lora_mapping[i2], + self.request_lora_mapping[i1], + ) + self.logit_bias[i1], self.logit_bias[i2] = ( + self.logit_bias[i2], + self.logit_bias[i1], + ) if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[i1], \ - self.allowed_token_ids_mask_cpu_tensor[i2] =\ - self.allowed_token_ids_mask_cpu_tensor[i2], \ - self.allowed_token_ids_mask_cpu_tensor[i1] + ( + self.allowed_token_ids_mask_cpu_tensor[i1], + self.allowed_token_ids_mask_cpu_tensor[i2], + ) = ( + self.allowed_token_ids_mask_cpu_tensor[i2], + self.allowed_token_ids_mask_cpu_tensor[i1], + ) self.block_table.swap_row(i1, i2) def condense(self, empty_req_indices: list[int]) -> None: """Move non-empty requests down into lower, empty indices. - + Args: empty_req_indices: empty batch indices, sorted descending. 
""" @@ -454,25 +453,29 @@ def condense(self, empty_req_indices: list[int]) -> None: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ - last_req_index, :num_tokens] + last_req_index, :num_tokens + ] self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ - last_req_index] - self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] + last_req_index + ] + self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[last_req_index] + self.num_computed_tokens_cpu[empty_index] = self.num_computed_tokens_cpu[ + last_req_index + ] self.block_table.move_row(last_req_index, empty_index) - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] + self.temperature_cpu[empty_index] = self.temperature_cpu[last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[ - empty_index] = self.frequency_penalties_cpu[last_req_index] - self.presence_penalties_cpu[ - empty_index] = self.presence_penalties_cpu[last_req_index] - self.repetition_penalties_cpu[ - empty_index] = self.repetition_penalties_cpu[last_req_index] + self.frequency_penalties_cpu[empty_index] = self.frequency_penalties_cpu[ + last_req_index + ] + self.presence_penalties_cpu[empty_index] = self.presence_penalties_cpu[ + last_req_index + ] + self.repetition_penalties_cpu[empty_index] = self.repetition_penalties_cpu[ + last_req_index + ] self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: @@ -483,28 +486,28 @@ def condense(self, empty_req_indices: list[int]) -> None: self.min_tokens[empty_index] = min_token self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] + last_req_index + ] self.logit_bias[empty_index] = self.logit_bias[last_req_index] if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[ - empty_index] = self.allowed_token_ids_mask_cpu_tensor[ - last_req_index] + self.allowed_token_ids_mask_cpu_tensor[empty_index] = ( + self.allowed_token_ids_mask_cpu_tensor[last_req_index] + ) - bad_words_token_ids = self.bad_words_token_ids.pop( - last_req_index, None) + bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None) if bad_words_token_ids is not None: self.bad_words_token_ids[empty_index] = bad_words_token_ids # Decrement last_req_index since it is now empty. last_req_index -= 1 # Trim lists to the batch size. - del self._req_ids[self.num_reqs:] - del self.req_output_token_ids[self.num_reqs:] + del self._req_ids[self.num_reqs :] + del self.req_output_token_ids[self.num_reqs :] def _make_prompt_token_ids_tensor(self) -> torch.Tensor: - max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + max_prompt_len = self.num_prompt_tokens[: self.num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( (self.num_reqs, max_prompt_len), device="cpu", @@ -512,14 +515,12 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: pin_memory=self.pin_memory, ) prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() - prompt_token_ids[:] = self.token_ids_cpu[:self. 
- num_reqs, :max_prompt_len] + prompt_token_ids[:] = self.token_ids_cpu[: self.num_reqs, :max_prompt_len] # Use the value of vocab_size as a pad since we don't have a # token_id of this value. for i in range(self.num_reqs): - prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size - return prompt_token_ids_cpu_tensor.to(device=self.device, - non_blocking=True) + prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size + return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) def make_lora_inputs( self, num_scheduled_tokens: np.ndarray @@ -535,12 +536,12 @@ def make_lora_inputs( 3. lora_requests: Set of relevant LoRA requests. """ - req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + req_lora_mapping = self.request_lora_mapping[: self.num_reqs] prompt_lora_mapping = tuple(req_lora_mapping) - token_lora_mapping = tuple( - req_lora_mapping.repeat(num_scheduled_tokens)) + token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens)) active_lora_requests: set[LoRARequest] = set( - self.lora_id_to_lora_request.values()) + self.lora_id_to_lora_request.values() + ) return prompt_lora_mapping, token_lora_mapping, active_lora_requests @@ -570,9 +571,11 @@ def no_min_p(self) -> bool: @property def no_penalties(self) -> bool: - return (len(self.presence_penalties_reqs) == 0 - and len(self.frequency_penalties_reqs) == 0 - and len(self.repetition_penalties_reqs) == 0) + return ( + len(self.presence_penalties_reqs) == 0 + and len(self.frequency_penalties_reqs) == 0 + and len(self.repetition_penalties_reqs) == 0 + ) @property def max_num_logprobs(self) -> Optional[int]: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0b1c3d7c0e88..1d53fa954a7f 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -9,6 +9,7 @@ import numpy as np import torch import torch.nn as nn + # TPU XLA related import torch_xla import torch_xla.core.xla_model as xm @@ -20,46 +21,71 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import (ParallelConfig, VllmConfig, - get_layers_from_vllm_config, update_config) -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.config import ( + ParallelConfig, + VllmConfig, + get_layers_from_vllm_config, + update_config, +) +from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.tpu import TPUModelLoader -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - supports_transcription) +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + supports_transcription, +) from vllm.model_executor.models.interfaces_base import ( - is_pooling_model, is_text_generation_model) + is_pooling_model, + is_text_generation_model, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + 
PlaceholderRange, +) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available, - prev_power_of_2) -from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE, - PallasAttentionBackend, - PallasMetadata, - get_page_size_bytes) -from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheSpec, - SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists, - LogprobsTensors, ModelRunnerOutput) +from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available, prev_power_of_2 +from vllm.v1.attention.backends.pallas import ( + TPU_STR_DTYPE_TO_TORCH_DTYPE, + PallasAttentionBackend, + PallasMetadata, + get_page_size_bytes, +) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheSpec, + SlidingWindowSpec, +) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, +) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin, KVConnectorOutput) + KVConnectorModelRunnerMixin, + KVConnectorOutput, +) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch -from .utils import (MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, - bind_kv_cache, sanity_check_mm_encoder_outputs) +from .utils import ( + MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, + sanity_check_mm_encoder_outputs, +) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -107,7 +133,6 @@ # branch predictions are included as subgraph inputs to facilitate # pre-compilation. class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( self, vllm_config: VllmConfig, @@ -139,7 +164,7 @@ def __init__( num_devices = xr.global_runtime_device_count() mesh_shape = (num_devices, 1) device_ids = np.array(range(num_devices)) - self.mesh = xs.Mesh(device_ids, mesh_shape, ('x', 'y')) + self.mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y")) self.enforce_eager = model_config.enforce_eager @@ -155,8 +180,7 @@ def __init__( else: self.kv_cache_dtype = model_dtype else: - self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] + self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] self._hidden_states_dtype = self.dtype self.sliding_window = model_config.get_sliding_window() @@ -164,25 +188,28 @@ def __init__( self.max_model_len = model_config.max_model_len self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.num_blocks_per_most_len_req = cdiv( - self.most_model_len, - self.block_size) if self.most_model_len is not None else None + self.num_blocks_per_most_len_req = ( + cdiv(self.most_model_len, self.block_size) + if self.most_model_len is not None + else None + ) # InputBatch needs to work with sampling tensors greater than padding # to avoid dynamic shapes. Also, avoid suboptimal alignment. 
self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) self.num_tokens_paddings = _get_token_paddings( min_token_size=16, max_token_size=scheduler_config.max_num_batched_tokens, - padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP, + ) # In case `max_num_tokens < max(num_tokens_paddings)` use the actual # padded max value to pre-allocate data structures and pre-compile. self.max_num_tokens = self.num_tokens_paddings[-1] # Model-related. self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) + parallel_config, LayerBlockType.attention + ) + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() @@ -195,17 +222,21 @@ def __init__( self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config) + model_config + ) # TODO: Support M-RoPE (e.g, Qwen2-VL) assert not self.uses_mrope, "TPU does not support M-RoPE yet." - self._num_slices_per_kv_cache_update_block = \ - _get_num_slices_per_kv_cache_update_block(get_page_size_bytes( - block_size=self.block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - kv_cache_dtype=self.kv_cache_dtype, - )) + self._num_slices_per_kv_cache_update_block = ( + _get_num_slices_per_kv_cache_update_block( + get_page_size_bytes( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + kv_cache_dtype=self.kv_cache_dtype, + ) + ) + ) # Lazy initialization self.model: nn.Module # Set after load_model @@ -230,52 +261,68 @@ def __init__( # Cached torch/numpy tensor # The pytorch tensor and numpy array share the same buffer. # Sometimes the numpy op is faster so we create both. - self.input_ids_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu") + self.input_ids_cpu = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device="cpu" + ) - self.positions_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu") + self.positions_cpu = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device="cpu" + ) self.positions_np = self.positions_cpu.numpy() self.block_table_cpu = torch.zeros( (self.max_num_reqs, self.max_num_blocks_per_req), dtype=torch.int32, - device="cpu") + device="cpu", + ) # adjust num_reqs to avoid SMEM OOM. 
- self.num_reqs_most_model_len = min( - PallasAttentionBackend.get_max_num_seqs(self.most_model_len, - self.block_size), - self.max_num_reqs) if self.most_model_len is not None else None + self.num_reqs_most_model_len = ( + min( + PallasAttentionBackend.get_max_num_seqs( + self.most_model_len, self.block_size + ), + self.max_num_reqs, + ) + if self.most_model_len is not None + else None + ) self.num_reqs_max_model_len = min( - PallasAttentionBackend.get_max_num_seqs(self.max_model_len, - self.block_size), - self.max_num_reqs) - self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + PallasAttentionBackend.get_max_num_seqs( + self.max_model_len, self.block_size + ), + self.max_num_reqs, + ) + self.query_start_loc_cpu = torch.zeros( + self.max_num_tokens + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) self.query_start_loc_np = self.query_start_loc_cpu.numpy() - self.seq_lens_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) + self.seq_lens_cpu = torch.zeros( + self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory, + ) self.seq_lens_np = self.seq_lens_cpu.numpy() # Only relevant for multimodal models if self.supports_mm_inputs: - self.is_mm_embed_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.bool, - device="cpu", - pin_memory=self.pin_memory) + self.is_mm_embed_cpu = torch.zeros( + self.max_num_tokens, + dtype=torch.bool, + device="cpu", + pin_memory=self.pin_memory, + ) # Range tensor with values [0 .. self.max_num_tokens - 1]. # Used to initialize positions / context_lens / seq_lens # Keep in int64 to avoid overflow with long context self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64) self.num_reqs_paddings = _get_req_paddings( - min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) + min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs + ) # Layer pairings for cross-layer KV sharing. 
# If an Attention layer `layer_name` is in the keys of this dict, it @@ -288,27 +335,35 @@ def __init__( (self.max_num_reqs, cdiv(self.vocab_size, 32)), dtype=torch.int32, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) self.require_structured_out_cpu = torch.zeros( (self.max_num_reqs, 1), dtype=torch.bool, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) self.structured_decode_arange = torch.arange( - 0, 32, device="cpu", pin_memory=self.pin_memory) + 0, 32, device="cpu", pin_memory=self.pin_memory + ) - self.mm_budget = (MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) if self.supports_mm_inputs else None) + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) if not self.use_spmd: self.sample_from_logits_func = torch.compile( self.sample_from_logits, backend="openxla", fullgraph=True, - dynamic=False) + dynamic=False, + ) else: self.sample_from_logits_func = self.sample_from_logits @@ -322,8 +377,9 @@ def _update_num_xla_graphs(self, case_str): if new_compiled_graphs == 0: return - logger.info("Add new %d compiled XLA graphs due to %s", - new_compiled_graphs, case_str) + logger.info( + "Add new %d compiled XLA graphs due to %s", new_compiled_graphs, case_str + ) self.num_xla_graphs += new_compiled_graphs def _verify_num_xla_graphs(self, case_str): @@ -335,7 +391,9 @@ def _verify_num_xla_graphs(self, case_str): assert self.num_xla_graphs == curr_cached_graph, ( "Recompilation after warm up is detected during {}." " num_xla_graphs = {} curr_cached_graph = {}".format( - case_str, self.num_xla_graphs, curr_cached_graph)) + case_str, self.num_xla_graphs, curr_cached_graph + ) + ) def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: """Update the cached states and the persistent batch with the scheduler @@ -388,8 +446,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: req_ids_to_add: list[str] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: - assert new_req_data.sampling_params is not None,\ + assert new_req_data.sampling_params is not None, ( "Pooling is not supported in TPU yet" + ) req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params @@ -422,8 +481,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: assert new_block_ids is not None @@ -440,23 +498,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. removed_req_indices = sorted(removed_req_indices, reverse=True) for req_id in req_ids_to_add: req_state = self.requests[req_id] - if removed_req_indices: - # Fill the empty index. 
- req_index = removed_req_indices.pop() - else: - # Append to the end. - req_index = None + # Fill the empty index or append to the end + req_index = removed_req_indices.pop() if removed_req_indices else None self.input_batch.add_request(req_state, req_index) # Condense the batched states if there are empty indices. @@ -513,8 +565,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size = self.vllm_config.cache_config.block_size kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: + if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of # the target layer. We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does @@ -529,7 +580,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: if isinstance(attn_module, ChunkedLocalAttention): logger.warning_once( "Using irope in Pallas is not supported yet, it " - "will fall back to global attention for long context.") + "will fall back to global attention for long context." + ) if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, @@ -545,20 +597,22 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, ) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): # encoder-only attention does not need KV cache. continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: raise NotImplementedError else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") return kv_cache_spec - def _get_slot_mapping_metadata(self, num_reqs, - num_scheduled_tokens_per_req) -> np.ndarray: + def _get_slot_mapping_metadata( + self, num_reqs, num_scheduled_tokens_per_req + ) -> np.ndarray: """ Computes metadata for mapping slots to blocks in the key-value (KV) cache for a batch of requests. @@ -583,14 +637,16 @@ def _get_slot_mapping_metadata(self, num_reqs, - slice_len (int): The length of the slice. 
""" slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] - slices_end = self.input_batch.num_computed_tokens_cpu[:num_reqs] + \ - num_scheduled_tokens_per_req + slices_end = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req + ) local_block_start_idx = slices_start // self.block_size local_block_end_idx = (slices_end - 1) // self.block_size no_repeat_req_indices = self.arange_np[:num_reqs] global_block_start_idx = ( - no_repeat_req_indices * self.max_num_blocks_per_req + - local_block_start_idx) + no_repeat_req_indices * self.max_num_blocks_per_req + local_block_start_idx + ) block_lens = local_block_end_idx - local_block_start_idx + 1 global_block_start_idx = np.repeat(global_block_start_idx, block_lens) slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens]) @@ -598,30 +654,31 @@ def _get_slot_mapping_metadata(self, num_reqs, block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() block_numbers = block_table_cpu.flatten()[global_block_indices].numpy() total_block_len = np.sum(block_lens) - slot_mapping_slices = np.repeat(np.array([[0, self.block_size]], - dtype=np.int32), - total_block_len, - axis=0) + slot_mapping_slices = np.repeat( + np.array([[0, self.block_size]], dtype=np.int32), total_block_len, axis=0 + ) cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32) np.cumsum(block_lens, out=cu_block_lens[1:]) for req_idx in range(num_reqs): - slot_mapping_slices[cu_block_lens[req_idx]][ - 0] = slices_start[req_idx] % self.block_size - slot_mapping_slices[ - cu_block_lens[req_idx + 1] - - 1][1] = (slices_end[req_idx] - 1) % self.block_size + 1 + slot_mapping_slices[cu_block_lens[req_idx]][0] = ( + slices_start[req_idx] % self.block_size + ) + slot_mapping_slices[cu_block_lens[req_idx + 1] - 1][1] = ( + slices_end[req_idx] - 1 + ) % self.block_size + 1 slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0] cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32) np.cumsum(slice_lens, out=cu_slices_lens[1:]) - kv_cache_start_indices = slot_mapping_slices[:, 0] + \ - (block_numbers * self.block_size) + kv_cache_start_indices = slot_mapping_slices[:, 0] + ( + block_numbers * self.block_size + ) new_kv_start_indices = cu_slices_lens[:-1] slot_mapping_metadata = np.stack( - [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1) + [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1 + ) return slot_mapping_metadata - def _prepare_inputs(self, scheduler_output: "SchedulerOutput", - start_index: int): + def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int): assert scheduler_output.total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs assert num_reqs > 0 @@ -643,22 +700,24 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", num_scheduled_tokens_per_req.append(num_tokens) if use_max_model_len: if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len: - num_scheduled_tokens_per_req = \ - num_scheduled_tokens_per_req[:self.num_reqs_max_model_len] + num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ + : self.num_reqs_max_model_len + ] end_index = start_index + self.num_reqs_max_model_len else: end_index = num_reqs else: - if len(num_scheduled_tokens_per_req - ) > self.num_reqs_most_model_len: - num_scheduled_tokens_per_req = \ - num_scheduled_tokens_per_req[:self.num_reqs_most_model_len] + if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len: + num_scheduled_tokens_per_req = 
num_scheduled_tokens_per_req[ + : self.num_reqs_most_model_len + ] end_index = start_index + self.num_reqs_most_model_len else: end_index = num_reqs max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req) - num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req, - dtype=np.int32) + num_scheduled_tokens_per_req = np.array( + num_scheduled_tokens_per_req, dtype=np.int32 + ) total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req) assert max_num_scheduled_tokens_all_reqs > 0 @@ -667,121 +726,130 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] # For each scheduled token, what are the corresponding req index. - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens_per_req) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req) # Get batched arange. # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # For each scheduled token, what is its position in corresponding req. arange = np.concatenate( - [self.arange_np[:n] for n in num_scheduled_tokens_per_req]) + [self.arange_np[:n] for n in num_scheduled_tokens_per_req] + ) # Get positions. positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids_cpu[:total_num_scheduled_tokens], + ) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 - np.cumsum(num_scheduled_tokens_per_req, - out=self.query_start_loc_np[1:num_reqs + 1]) - self.query_start_loc_np[num_reqs + 1:] = 1 + np.cumsum( + num_scheduled_tokens_per_req, out=self.query_start_loc_np[1 : num_reqs + 1] + ) + self.query_start_loc_np[num_reqs + 1 :] = 1 self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens_per_req) + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens_per_req + ) # Do the padding and copy the tensors to the TPU. 
padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens) + self.num_tokens_paddings, total_num_scheduled_tokens + ) # Zero out to avoid spurious values from prev iteration (last cp chunk) self.input_ids_cpu[ - total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0 - self.input_ids = self.input_ids_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) - self.position_ids = self.positions_cpu[: - padded_total_num_scheduled_tokens].to( - self.device) + total_num_scheduled_tokens:padded_total_num_scheduled_tokens + ] = 0 + self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to( + self.device + ) + self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to( + self.device + ) if use_max_model_len: - block_tables = self.block_table_cpu[:self.num_reqs_max_model_len, : - self.max_num_blocks_per_req] - block_tables[:num_reqs, :self.max_num_blocks_per_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]) - query_start_loc = self.query_start_loc_cpu[:self. - num_reqs_max_model_len + - 1].to(self.device) - seq_lens = self.seq_lens_cpu[:self.num_reqs_max_model_len].to( - self.device) + block_tables = self.block_table_cpu[ + : self.num_reqs_max_model_len, : self.max_num_blocks_per_req + ] + block_tables[:num_reqs, : self.max_num_blocks_per_req] = ( + self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs] + ) + query_start_loc = self.query_start_loc_cpu[ + : self.num_reqs_max_model_len + 1 + ].to(self.device) + seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device) else: - block_tables = self.block_table_cpu[:self. - num_reqs_most_model_len, :self. - num_blocks_per_most_len_req] - block_tables[:num_reqs, :self.num_blocks_per_most_len_req] = ( - self.input_batch.block_table[0].get_cpu_tensor() - [:num_reqs, :self.num_blocks_per_most_len_req]) - query_start_loc = self.query_start_loc_cpu[:self. 
- num_reqs_most_model_len + - 1].to(self.device) - seq_lens = self.seq_lens_cpu[:self.num_reqs_most_model_len].to( - self.device) + block_tables = self.block_table_cpu[ + : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req + ] + block_tables[:num_reqs, : self.num_blocks_per_most_len_req] = ( + self.input_batch.block_table[0].get_cpu_tensor()[ + :num_reqs, : self.num_blocks_per_most_len_req + ] + ) + query_start_loc = self.query_start_loc_cpu[ + : self.num_reqs_most_model_len + 1 + ].to(self.device) + seq_lens = self.seq_lens_cpu[: self.num_reqs_most_model_len].to(self.device) block_tables = block_tables.to(self.device) # Calculate the slot mapping slot_mapping_metadata = self._get_slot_mapping_metadata( - num_reqs, num_scheduled_tokens_per_req) + num_reqs, num_scheduled_tokens_per_req + ) num_kv_update_slices = slot_mapping_metadata.shape[0] padded_num_slices = _get_padded_num_kv_cache_update_slices( - padded_total_num_scheduled_tokens, self.max_num_reqs, - self.block_size) + padded_total_num_scheduled_tokens, self.max_num_reqs, self.block_size + ) slot_mapping_metadata = np.pad( slot_mapping_metadata, [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], - constant_values=0) + constant_values=0, + ) slot_mapping_metadata = np.transpose(slot_mapping_metadata) - slot_mapping_metadata = torch.tensor(slot_mapping_metadata, - device=self.device) + slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device) if self.lora_config is not None: # We need to respect padding when activating LoRA adapters padded_num_scheduled_tokens_per_req = np.copy( num_scheduled_tokens_per_req ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += \ + padded_num_scheduled_tokens_per_req[-1] += ( padded_total_num_scheduled_tokens - total_num_scheduled_tokens + ) - self.set_active_loras(self.input_batch, - padded_num_scheduled_tokens_per_req) + self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( slot_mapping=slot_mapping_metadata, block_tables=block_tables, context_lens=seq_lens, query_start_loc=query_start_loc, - num_seqs=torch.tensor([num_reqs], - dtype=torch.int32, - device=self.device), - num_kv_update_slices=torch.tensor([num_kv_update_slices], - dtype=torch.int32, - device=self.device), - num_slices_per_kv_cache_update_block=self. - _num_slices_per_kv_cache_update_block, + num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), + num_kv_update_slices=torch.tensor( + [num_kv_update_slices], dtype=torch.int32, device=self.device + ), + num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -789,10 +857,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", # token from the partial request. # TODO: Support prompt logprobs. padded_num_reqs = _get_padded_num_reqs_with_upper_limit( - num_reqs, self.max_num_reqs) + num_reqs, self.max_num_reqs + ) # Indices at which we sample (positions of last token in the sequence). # Padded to avoid recompiling when `num_reqs` varies. 
- logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 + logits_indices = self.query_start_loc_cpu[1 : padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) if self.lora_config is not None: @@ -800,20 +869,23 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", padded_num_scheduled_tokens_per_req = np.copy( num_scheduled_tokens_per_req ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += \ + padded_num_scheduled_tokens_per_req[-1] += ( padded_total_num_scheduled_tokens - total_num_scheduled_tokens + ) - self.set_active_loras(self.input_batch, - padded_num_scheduled_tokens_per_req) + self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) - layer_names = get_layers_from_vllm_config(self.vllm_config, - Attention).keys() + layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() per_layer_attn_metadata = { - layer_name: attn_metadata - for layer_name in layer_names + layer_name: attn_metadata for layer_name in layer_names } - return per_layer_attn_metadata, logits_indices, padded_num_reqs,\ - num_reqs, end_index + return ( + per_layer_attn_metadata, + logits_indices, + padded_num_reqs, + num_reqs, + end_index, + ) def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs @@ -843,10 +915,10 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): model = cast(SupportsMultiModal, self.model) encoder_outputs = [] for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # Run the encoder. # `curr_group_outputs` is either of the following: @@ -856,8 +928,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # (feature_size, hidden_size) in case the feature size is dynamic # depending on the input multimodal items. torch_xla.sync(wait=False) - curr_group_outputs = model.get_multimodal_embeddings( - **mm_kwargs_group) + curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group) torch_xla.sync(wait=False) sanity_check_mm_encoder_outputs( @@ -877,8 +948,9 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # assume to only have whole mm items to process. Hence we avoid the # intrinsic dynamism that `scatter_mm_placeholders` introduces. for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): - assert pos_info.is_embed is None, "Expected all positions to be"\ - " contiguous and embeddings." + assert pos_info.is_embed is None, ( + "Expected all positions to be contiguous and embeddings." 
+ ) self.encoder_cache[mm_hash] = output def _gather_mm_embeddings( @@ -887,7 +959,8 @@ def _gather_mm_embeddings( ) -> tuple[list[torch.Tensor], torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens) + self.num_tokens_paddings, total_num_scheduled_tokens + ) is_mm_embed = self.is_mm_embed_cpu is_mm_embed[:padded_total_num_scheduled_tokens] = False @@ -895,8 +968,7 @@ def _gather_mm_embeddings( req_start_idx = 0 for req_id in self.input_batch.req_ids: - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] req_state = self.requests[req_id] num_computed_tokens = req_state.num_computed_tokens @@ -930,23 +1002,21 @@ def _gather_mm_embeddings( mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None,\ - f"Encoder cache miss for {mm_hash}." + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." - assert pos_info.is_embed is None, "Expected all positions to"\ - " be contiguous and embeddings." + assert pos_info.is_embed is None, ( + "Expected all positions to be contiguous and embeddings." + ) req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ - = True + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = True # Only whole mm items are processed mm_embeds.append(encoder_output) req_start_idx += num_scheduled_tokens - is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens] \ - .to(self.device) + is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens].to(self.device) return mm_embeds, is_mm_embed @@ -988,8 +1058,7 @@ def execute_model( # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) + return self.kv_connector_no_forward(scheduler_output, self.vllm_config) if self.supports_mm_inputs: # Run the multimodal encoder if any. 
@@ -1011,41 +1080,48 @@ def execute_model( self.maybe_setup_kv_connector(scheduler_output) while start_index < self.input_batch.num_reqs: - attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ - end_index = self._prepare_inputs(scheduler_output, start_index) + attn_metadata, logits_indices, padded_num_reqs, num_reqs, end_index = ( + self._prepare_inputs(scheduler_output, start_index) + ) input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embed_inputs) + self.input_ids, mm_embed_inputs + ) torch_xla.sync(wait=False) # Run the decoder with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=scheduler_output.total_num_scheduled_tokens): + attn_metadata, + self.vllm_config, + num_tokens=scheduler_output.total_num_scheduled_tokens, + ): hidden_states = self.model( input_ids=input_ids, positions=self.position_ids, inputs_embeds=inputs_embeds, ) - hidden_states = self.select_hidden_states(hidden_states, - logits_indices) + hidden_states = self.select_hidden_states(hidden_states, logits_indices) logits = self.compute_logits(hidden_states) - tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ - from_input_batch(self.input_batch, padded_num_reqs, self.device) + tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( + self.input_batch, padded_num_reqs, self.device + ) if scheduler_output.grammar_bitmask is not None: - require_struct_decoding, grammar_bitmask_padded, arange = \ - self.prepare_structured_decoding_input(logits, - scheduler_output) - logits = self.structured_decode(require_struct_decoding, - grammar_bitmask_padded, logits, - arange) + require_struct_decoding, grammar_bitmask_padded, arange = ( + self.prepare_structured_decoding_input(logits, scheduler_output) + ) + logits = self.structured_decode( + require_struct_decoding, grammar_bitmask_padded, logits, arange + ) selected_token_ids = self.sample_from_logits_func( - logits, tpu_sampling_metadata) + logits, tpu_sampling_metadata + ) # NOTE (NickLucche) Use the original logits (before any penalties or # temperature scaling) for the top-k logprobs. We can't enforce it # due to recompilations outside torch.compiled code, so just make # sure `sample_from_logits` does not modify the logits in-place. - logprobs = self.gather_logprobs(logits, selected_token_ids) \ - if tpu_sampling_metadata.logprobs else None + logprobs = ( + self.gather_logprobs(logits, selected_token_ids) + if tpu_sampling_metadata.logprobs + else None + ) # Remove padding on cpu and keep dynamic op outside of xla graph. selected_token_ids = selected_token_ids.cpu()[:num_reqs] @@ -1061,8 +1137,9 @@ def execute_model( # should be called right after each single forward pass, # instead of the forwards of the entire input batch. 
self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) + finished_sending, finished_recving = self.get_finished_kv_transfers( + scheduler_output + ) selected_token_ids = torch.cat(combined_selected_tokens, dim=0) if tpu_sampling_metadata.logprobs: @@ -1073,16 +1150,15 @@ def concat_lists(input_lists): result.extend(input_list) return result - logprobs_lists = LogprobsLists(logprob_token_ids=concat_lists( - [lp.logprob_token_ids for lp in combined_logprobs]), - logprobs=concat_lists([ - lp.logprobs - for lp in combined_logprobs - ]), - sampled_token_ranks=concat_lists([ - lp.sampled_token_ranks - for lp in combined_logprobs - ])) + logprobs_lists = LogprobsLists( + logprob_token_ids=concat_lists( + [lp.logprob_token_ids for lp in combined_logprobs] + ), + logprobs=concat_lists([lp.logprobs for lp in combined_logprobs]), + sampled_token_ranks=concat_lists( + [lp.sampled_token_ranks for lp in combined_logprobs] + ), + ) else: logprobs_lists = None @@ -1094,8 +1170,10 @@ def concat_lists(input_lists): for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): assert req_id is not None req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) + seq_len = ( + req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id] + ) if seq_len >= req_state.num_tokens: request_seq_lens.append((i, req_state, seq_len)) else: @@ -1111,8 +1189,8 @@ def concat_lists(input_lists): discard_sampled_tokens_req_indices.append(i) assert all( - req_id is not None for req_id in - self.input_batch.req_ids[:num_reqs]), "req_ids contains None" + req_id is not None for req_id in self.input_batch.req_ids[:num_reqs] + ), "req_ids contains None" req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} @@ -1140,22 +1218,24 @@ def concat_lists(input_lists): valid_mask = selected_token_ids != INVALID_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() valid_sampled_token_ids = [ - seq.tolist() - for seq in selected_token_ids[valid_mask].split(gen_lens) + seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens) ] self.input_batch.num_tokens[:num_reqs] += gen_lens for i, req_state, seq_len in request_seq_lens: target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) - self.input_batch.token_ids_cpu[ - i, target_slice] = valid_sampled_token_ids[i] + self.input_batch.token_ids_cpu[i, target_slice] = ( + valid_sampled_token_ids[i] + ) req_state.output_token_ids.extend(valid_sampled_token_ids[i]) - kv_connector_output = None if ( - finished_sending is None - and finished_recving is None) else KVConnectorOutput( + kv_connector_output = ( + None + if (finished_sending is None and finished_recving is None) + else KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, ) + ) model_runner_output = ModelRunnerOutput( req_ids=req_ids, @@ -1178,9 +1258,10 @@ def update_config(self, overrides: dict[str, Any]) -> None: # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754 allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. 
" f"Allowed configs: {allowed_config_names}" + ) config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -1199,30 +1280,34 @@ def load_model(self) -> None: # the embedding weights. xm_tp_rank = xr.global_ordinal() with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." - "get_tensor_model_parallel_rank", - return_value=xm_tp_rank): + "vllm.model_executor.layers.vocab_parallel_embedding." + "get_tensor_model_parallel_rank", + return_value=xm_tp_rank, + ): try: if self.use_spmd: tpu_loader = TPUModelLoader( - load_config=self.vllm_config.load_config) + load_config=self.vllm_config.load_config + ) model = tpu_loader.load_model( vllm_config=self.vllm_config, model_config=self.vllm_config.model_config, - mesh=self.mesh) + mesh=self.mesh, + ) else: model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") model = model_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.model_config) + vllm_config=self.vllm_config, model_config=self.model_config + ) except RuntimeError as e: raise RuntimeError( f"Unable to load model, a likely reason is the model is " "too large for the current device's HBM memory. " "Consider switching to a smaller model " "or sharding the weights on more chips. " - f"See the detailed error: {e}") from e + f"See the detailed error: {e}" + ) from e if self.lora_config is not None: model = self.load_lora_model(model, self.vllm_config, self.device) replace_set_lora(model) @@ -1236,44 +1321,43 @@ def load_model(self) -> None: self.sampler = TPUSampler() def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." 
+ ) model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") model_loader.load_weights(self.model, model_config=self.model_config) @torch.no_grad() - def _dummy_run(self, num_tokens: int, num_reqs: int, - num_blocks: int) -> None: + def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None: if self.supports_mm_inputs: input_ids = None - inputs_embeds = torch.zeros((num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + inputs_embeds = torch.zeros( + (num_tokens, self.hidden_size), dtype=self.dtype, device=self.device + ) else: - input_ids = torch.zeros((num_tokens), - dtype=torch.int32).to(self.device) + input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device) inputs_embeds = None actual_num_reqs = min(num_tokens, num_reqs) - position_ids = torch.zeros(num_tokens, - dtype=torch.int32).to(self.device) + position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) padded_num_slices = _get_padded_num_kv_cache_update_slices( - num_tokens, self.max_num_reqs, self.block_size) - num_kv_update_slices = torch.tensor([padded_num_slices], - dtype=torch.int32).to(self.device) - slot_mapping = torch.zeros((3, padded_num_slices), - dtype=torch.int32).to(self.device) - block_tables = torch.zeros((num_reqs, num_blocks), - dtype=torch.int32).to(self.device) + num_tokens, self.max_num_reqs, self.block_size + ) + num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to( + self.device + ) + slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to( + self.device + ) + block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to( + self.device + ) query_lens = [1] * num_reqs - query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, - dtype=torch.int32), - dim=0, - dtype=torch.int32).to(self.device) - context_lens = torch.ones((num_reqs, ), - dtype=torch.int32).to(self.device) - num_seqs = torch.tensor([actual_num_reqs], - dtype=torch.int32).to(self.device) + query_start_loc = torch.cumsum( + torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32 + ).to(self.device) + context_lens = torch.ones((num_reqs,), dtype=torch.int32).to(self.device) + num_seqs = torch.tensor([actual_num_reqs], dtype=torch.int32).to(self.device) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=block_tables, @@ -1281,8 +1365,7 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, query_start_loc=query_start_loc, num_seqs=num_seqs, num_kv_update_slices=num_kv_update_slices, - num_slices_per_kv_cache_update_block=self. 
- _num_slices_per_kv_cache_update_block, + num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, ) if self.supports_mm_inputs: @@ -1295,27 +1378,29 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) - layer_names = get_layers_from_vllm_config(self.vllm_config, - Attention).keys() + layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() per_layer_attn_metadata = { - layer_name: attn_metadata - for layer_name in layer_names + layer_name: attn_metadata for layer_name in layer_names } - with self.maybe_select_dummy_loras( - self.lora_config, - np.array([num_tokens], dtype=np.int32)), set_forward_context( - per_layer_attn_metadata, self.vllm_config, 0): - out = self.model(input_ids=input_ids, - positions=position_ids, - inputs_embeds=inputs_embeds) + with ( + self.maybe_select_dummy_loras( + self.lora_config, np.array([num_tokens], dtype=np.int32) + ), + set_forward_context(per_layer_attn_metadata, self.vllm_config, 0), + ): + out = self.model( + input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds + ) self._hidden_states_dtype = out.dtype - def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, - lora_requests) -> None: + def _set_active_loras( + self, prompt_lora_mapping, token_lora_mapping, lora_requests + ) -> None: torch_xla.sync(wait=False) # Captures input updates - super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, - lora_requests) + super()._set_active_loras( + prompt_lora_mapping, token_lora_mapping, lora_requests + ) torch_xla.sync(wait=False) # Captures metadata updates def _precompile_mm_encoder(self) -> None: @@ -1332,8 +1417,8 @@ def _precompile_mm_encoder(self) -> None: for mode, max_items_per_seq in max_items_per_seq_by_modality.items(): logger.info( - "Compiling Multimodal %s Encoder with different input" - " shapes.", mode) + "Compiling Multimodal %s Encoder with different input shapes.", mode + ) start = time.perf_counter() # No padding for MM encoder just yet. for num_items in range(1, max_items_per_seq + 1): @@ -1345,7 +1430,8 @@ def _precompile_mm_encoder(self) -> None: # Run multimodal encoder. torch_xla.sync(wait=False) mm_embeds = self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + **batched_dummy_mm_inputs + ) torch_xla.sync(wait=False) num_patches = mm_embeds[0].shape[0] items_size = num_patches * num_items @@ -1359,12 +1445,11 @@ def _precompile_mm_encoder(self) -> None: # XLA Workaround: if torch.zeros(..device) is used, XLA # compiles a scalar+expansion op, which won't match # the graph generated at runtime. CPU->TPU must be used - placeholders_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device="cpu") + placeholders_ids = torch.zeros( + num_tokens, dtype=torch.int32, device="cpu" + ) # Align placeholders and actual num mm_embeddings. - placeholders_ids[:items_size] = \ - hf_config.image_token_index + placeholders_ids[:items_size] = hf_config.image_token_index placeholders_ids = placeholders_ids.to(self.device) @@ -1382,9 +1467,9 @@ def _precompile_mm_encoder(self) -> None: # Pre-compile `get_input_embeddings` when mm_embeddings are not # present. Chunk is only made of text, no mm_placeholders. 
for num_tokens in self.num_tokens_paddings: - placeholders_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device="cpu") + placeholders_ids = torch.zeros( + num_tokens, dtype=torch.int32, device="cpu" + ) placeholders_ids = placeholders_ids.to(self.device) a, b = self._get_model_inputs( placeholders_ids, @@ -1396,19 +1481,25 @@ def _precompile_mm_encoder(self) -> None: xm.wait_device_ops() end = time.perf_counter() logger.info( - "Multimodal %s Encoder compilation finished in in %.2f " - "[secs].", mode, end - start) + "Multimodal %s Encoder compilation finished in in %.2f [secs].", + mode, + end - start, + ) def _precompile_backbone(self) -> None: logger.info("Compiling the model with different input shapes.") start = time.perf_counter() for num_tokens in self.num_tokens_paddings: logger.info(" -- num_tokens: %d", num_tokens) - self._dummy_run(num_tokens, self.num_reqs_max_model_len, - self.max_num_blocks_per_req) + self._dummy_run( + num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req + ) if self.most_model_len is not None: - self._dummy_run(num_tokens, self.num_reqs_most_model_len, - self.num_blocks_per_most_len_req) + self._dummy_run( + num_tokens, + self.num_reqs_most_model_len, + self.num_blocks_per_most_len_req, + ) xm.wait_device_ops() end = time.perf_counter() logger.info("Compilation finished in %.2f [secs].", end - start) @@ -1417,23 +1508,19 @@ def _precompile_backbone(self) -> None: def _precompile_select_hidden_states(self) -> None: # Compile hidden state selection function for bucketed # n_tokens x max_num_reqs. Graph is really small so this is fine. - logger.info( - "Compiling select_hidden_states with different input shapes.") + logger.info("Compiling select_hidden_states with different input shapes.") start = time.perf_counter() hsize = self.model_config.get_hidden_size() for num_tokens in self.num_tokens_paddings: - dummy_hidden = torch.zeros((num_tokens, hsize), - device=self.device, - dtype=self._hidden_states_dtype) + dummy_hidden = torch.zeros( + (num_tokens, hsize), device=self.device, dtype=self._hidden_states_dtype + ) torch._dynamo.mark_dynamic(dummy_hidden, 0) for num_reqs in self.num_reqs_paddings: - indices = torch.zeros(num_reqs, - dtype=torch.int32, - device=self.device) + indices = torch.zeros(num_reqs, dtype=torch.int32, device=self.device) torch._dynamo.mark_dynamic(indices, 0) self.select_hidden_states(dummy_hidden, indices) - logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, - num_reqs) + logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs) # Requests can't be more than tokens. But do compile for the # next bigger value in case num_tokens uses bucketed padding. 
if num_reqs >= min(num_tokens, self.max_num_reqs): @@ -1448,9 +1535,9 @@ def _precompile_compute_logits(self) -> None: start = time.perf_counter() hsize = self.model_config.get_hidden_size() for num_reqs in self.num_reqs_paddings: - dummy_hidden = torch.zeros((num_reqs, hsize), - device=self.device, - dtype=self._hidden_states_dtype) + dummy_hidden = torch.zeros( + (num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype + ) torch._dynamo.mark_dynamic(dummy_hidden, 0) self.compute_logits(dummy_hidden) logger.info(" -- num_seqs: %d", num_reqs) @@ -1460,23 +1547,28 @@ def _precompile_compute_logits(self) -> None: self._update_num_xla_graphs("compute_logits") def _precompile_structured_decoding(self) -> None: - logger.info( - "Compiling structured_decoding with different input shapes.") + logger.info("Compiling structured_decoding with different input shapes.") start = time.perf_counter() for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) - dummy_require_struct_decoding = \ - self.require_structured_out_cpu[:num_reqs].to(self.device) - dummy_grammar_bitmask = \ - self.grammar_bitmask_cpu[:num_reqs].to(self.device) + dummy_logits = torch.zeros( + (num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype, + ) + dummy_require_struct_decoding = self.require_structured_out_cpu[ + :num_reqs + ].to(self.device) + dummy_grammar_bitmask = self.grammar_bitmask_cpu[:num_reqs].to(self.device) # The first dimension of the above 3 dummy tensors cannot be # mark_dynamic because some operations in structured_decode require # them to be static. arange = self.structured_decode_arange.to(self.device) - self.structured_decode(dummy_require_struct_decoding, - dummy_grammar_bitmask, dummy_logits, arange) + self.structured_decode( + dummy_require_struct_decoding, + dummy_grammar_bitmask, + dummy_logits, + arange, + ) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -1484,30 +1576,29 @@ def _precompile_structured_decoding(self) -> None: self._update_num_xla_graphs("structured_decoding") def _precompile_sample_from_logits(self) -> None: - logger.info( - "Compiling sample_from_logits with different input shapes.") + logger.info("Compiling sample_from_logits with different input shapes.") start = time.perf_counter() for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) + dummy_logits = torch.zeros( + (num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype, + ) # The first dimension of dummy_logits cannot be mark_dynamic # because some operations in the sampler require it to be static. 
for all_greedy in [False, True]: generate_params_if_all_greedy = not all_greedy - sampling_metadata = ( - TPUSupportedSamplingMetadata.from_input_batch( - self.input_batch, - num_reqs, - self.device, - generate_params_if_all_greedy, - )) + sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( + self.input_batch, + num_reqs, + self.device, + generate_params_if_all_greedy, + ) sampling_metadata.all_greedy = all_greedy with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], - dtype=np.int32)): - self.sample_from_logits_func(dummy_logits, - sampling_metadata) + self.lora_config, np.array([num_reqs], dtype=np.int32) + ): + self.sample_from_logits_func(dummy_logits, sampling_metadata) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -1518,13 +1609,15 @@ def _precompile_gather_logprobs(self) -> None: logger.info("Compiling gather_logprobs with different input shapes.") start = time.perf_counter() for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros((num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype) - dummy_tokens = torch.zeros((num_reqs, 1), - dtype=torch.int64).to(self.device) + dummy_logits = torch.zeros( + (num_reqs, self.vocab_size), + device=self.device, + dtype=self._hidden_states_dtype, + ) + dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(self.device) with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], dtype=np.int32)): + self.lora_config, np.array([num_reqs], dtype=np.int32) + ): self.gather_logprobs(dummy_logits, dummy_tokens) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() @@ -1554,7 +1647,8 @@ def profile_run( if self.model_config.multimodal_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " - "encoder cache.") + "encoder cache." + ) else: mm_budget = self.mm_budget assert mm_budget is not None @@ -1565,8 +1659,9 @@ def profile_run( # modality with the max possible input tokens even when # it supports multiple. dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + dummy_modality + ] logger.info( "Encoder cache will be initialized with a budget of " @@ -1588,15 +1683,16 @@ def profile_run( # impact of recompilation until it's fixed. start = time.perf_counter() torch_xla.sync(wait=False) - dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs + ) torch_xla.sync(wait=False) xm.wait_device_ops() end = time.perf_counter() logger.info( "Multimodal Encoder profiling finished in %.2f [secs].", - end - start) + end - start, + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, @@ -1604,15 +1700,18 @@ def profile_run( ) # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. 
- self._dummy_run(num_tokens, self.num_reqs_max_model_len, - self.max_num_blocks_per_req) + self._dummy_run( + num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req + ) if self.most_model_len is not None: - self._dummy_run(num_tokens, self.num_reqs_most_model_len, - self.num_blocks_per_most_len_req) + self._dummy_run( + num_tokens, + self.num_reqs_most_model_len, + self.num_blocks_per_most_len_req, + ) torch_xla.sync(wait=False) xm.wait_device_ops() @@ -1637,10 +1736,8 @@ def maybe_setup_cross_layer_kv_sharing( kv_cache_config.kv_cache_groups, ) - for layer_name, target_layer_name in self.shared_kv_cache_layers.items( - ): - logger.debug("%s reuses KV cache of %s", layer_name, - target_layer_name) + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) kv_caches[layer_name] = kv_caches[target_layer_name] def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: @@ -1652,11 +1749,13 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ if len(kv_cache_config.kv_cache_groups) > 1: raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") + "Hybrid models with more than one KV cache type are not supported yet." + ) - if kv_cache_config.kv_cache_groups[ - 0].kv_cache_spec.block_size != self.block_size: + if ( + kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + != self.block_size + ): self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, @@ -1669,14 +1768,16 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: ], ) # Verify dtype compatibility between block_table_cpu and input_batch - assert self.block_table_cpu.dtype == self.input_batch.block_table[ - 0].get_cpu_tensor().dtype + assert ( + self.block_table_cpu.dtype + == self.input_batch.block_table[0].get_cpu_tensor().dtype + ) kv_cache_sizes = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: assert len(kv_cache_tensor.shared_by) == 1, ( - "KV cache tensor shared by multiple layers is not supported in " - "TPU.") + "KV cache tensor shared by multiple layers is not supported in TPU." + ) kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size kv_caches: dict[str, torch.Tensor] = {} @@ -1690,19 +1791,23 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if self.use_spmd: num_kv_heads = kv_cache_spec.num_kv_heads assert self.original_parallel_config is not None - tp_size = \ - self.original_parallel_config.tensor_parallel_size + tp_size = self.original_parallel_config.tensor_parallel_size # TODO: Handle kv cache duplication under SPMD mode. 
assert num_kv_heads % tp_size == 0, ( f"num_kv_heads {num_kv_heads} must be divisible by " - f"tp_size {tp_size} under SPMD mode") + f"tp_size {tp_size} under SPMD mode" + ) kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + ) dtype = kv_cache_spec.dtype - tpu_kv_cache = torch.zeros(kv_cache_shape, - dtype=dtype).to(self.device) + tpu_kv_cache = torch.zeros(kv_cache_shape, dtype=dtype).to( + self.device + ) kv_caches[layer_name] = tpu_kv_cache else: @@ -1714,19 +1819,19 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + self.kv_caches, + ) if self.use_spmd: # Shard KV Cache for cache in self.kv_caches: - xs.mark_sharding(cache, self.mesh, (None, 'x', None, None)) + xs.mark_sharding(cache, self.mesh, (None, "x", None, None)) if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks) def reset_dynamo_cache(self): - # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs` # since the compiled model object of the language backbone of a # multimodal model needs to be extracted via `get_language_model`. @@ -1737,7 +1842,8 @@ def reset_dynamo_cache(self): if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher): logger.info("Clear dynamo cache and cached dynamo bytecode.") torch._dynamo.eval_frame.remove_from_cache( - compiled_model.original_code_object) + compiled_model.original_code_object + ) compiled_model.compiled_codes.clear() @torch.compile(backend="openxla", fullgraph=True, dynamic=False) @@ -1745,30 +1851,29 @@ def select_hidden_states(self, hidden_states, indices_do_sample): return hidden_states[indices_do_sample] @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def compute_logits(self, - sample_hidden_states: torch.Tensor) -> torch.Tensor: + def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor: return self.model.compute_logits(sample_hidden_states) # TODO: Under SPMD mode, sample_from_logits has correctness issue. # Re-enable the torch.compile once the issue is fixed in torchxla. # @torch.compile(backend="openxla", fullgraph=True, dynamic=False) def sample_from_logits( - self, logits: torch.Tensor, - sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor: + self, logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata + ) -> torch.Tensor: """ - Sample with xla-friendly function. This function is to be traced + Sample with xla-friendly function. This function is to be traced separately from `forward` for lighter compilation overhead. """ if sampling_metadata.all_greedy: out_tokens = torch.argmax(logits, dim=-1, keepdim=True) else: - out_tokens = self.sampler(logits, - sampling_metadata).sampled_token_ids + out_tokens = self.sampler(logits, sampling_metadata).sampled_token_ids return out_tokens @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def gather_logprobs(self, logits: torch.Tensor, - sampled_tokens: torch.Tensor) -> LogprobsTensors: + def gather_logprobs( + self, logits: torch.Tensor, sampled_tokens: torch.Tensor + ) -> LogprobsTensors: """ Gather the top_logprobs with corresponding tokens. 
Use a fixed number of logprobs as an alternative to having multiple pre-compiled graphs. @@ -1778,28 +1883,37 @@ def gather_logprobs(self, logits: torch.Tensor, return self.sampler.gather_logprobs( logprobs, self.model_config.max_logprobs, - token_ids=sampled_tokens.squeeze(-1)) + token_ids=sampled_tokens.squeeze(-1), + ) @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def structured_decode(self, require_struct_decoding: torch.Tensor, - grammar_bitmask: torch.Tensor, logits: torch.Tensor, - arange: torch.Tensor) -> torch.Tensor: + def structured_decode( + self, + require_struct_decoding: torch.Tensor, + grammar_bitmask: torch.Tensor, + logits: torch.Tensor, + arange: torch.Tensor, + ) -> torch.Tensor: return torch.where( require_struct_decoding, self.apply_grammar_bitmask(logits, grammar_bitmask, arange), - logits) + logits, + ) - def apply_grammar_bitmask(self, logits: torch.Tensor, - grammar_bitmask: torch.Tensor, - arange: torch.Tensor): - assert (logits.shape[0] == grammar_bitmask.shape[0]) + def apply_grammar_bitmask( + self, logits: torch.Tensor, grammar_bitmask: torch.Tensor, arange: torch.Tensor + ): + assert logits.shape[0] == grammar_bitmask.shape[0] logits_cloned = logits.clone() for i in range(logits.shape[0]): - unpacked_bitmask = (torch.bitwise_right_shift( - grammar_bitmask[i][:, None], arange[None, :]) & 1) == 0 - unpacked_bitmask = unpacked_bitmask.reshape(-1)[:self.vocab_size] + unpacked_bitmask = ( + torch.bitwise_right_shift(grammar_bitmask[i][:, None], arange[None, :]) + & 1 + ) == 0 + unpacked_bitmask = unpacked_bitmask.reshape(-1)[: self.vocab_size] logits_cloned[i] = logits_cloned[i].masked_fill( - unpacked_bitmask, -float("inf")) + unpacked_bitmask, -float("inf") + ) return logits_cloned def get_multimodal_embeddings(self, *args, **kwargs): @@ -1821,23 +1935,27 @@ def prepare_structured_decoding_input( sorted_struct_requests = sorted( scheduler_output.structured_output_request_ids.items(), - key=lambda item: item[1]) + key=lambda item: item[1], + ) cumulative_mask_idx = 0 for req_id, _ in sorted_struct_requests: if req_id not in self.input_batch.req_id_to_index: continue batch_index = self.input_batch.req_id_to_index[req_id] self.grammar_bitmask_cpu[batch_index] = torch.from_numpy( - grammar_bitmask[cumulative_mask_idx]) + grammar_bitmask[cumulative_mask_idx] + ) # It's not guaranteed that all requests in this batch require # structured output, so create a bool tensor to represent # the requests that need structured output. 
self.require_structured_out_cpu[batch_index] = True cumulative_mask_idx += 1 - return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ - self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ - self.structured_decode_arange.to(logits.device) + return ( + self.require_structured_out_cpu[:num_reqs].to(logits.device), + self.grammar_bitmask_cpu[:num_reqs].to(logits.device), + self.structured_decode_arange.to(logits.device), + ) def _get_mm_dummy_batch( self, @@ -1860,13 +1978,15 @@ def _get_mm_dummy_batch( dummy_mm_items = [dummy_mm_item] * max_items_per_batch model = cast(SupportsMultiModal, self.model) - return next(grouped_mm_kwargs - for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - )) + return next( + grouped_mm_kwargs + for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) + ) def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: @@ -1887,9 +2007,10 @@ def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int: return min(res, upper_limit) -def _get_token_paddings(min_token_size: int, max_token_size: int, - padding_gap: int) -> list[int]: - """Generate a list of padding size, starting from min_token_size, +def _get_token_paddings( + min_token_size: int, max_token_size: int, padding_gap: int +) -> list[int]: + """Generate a list of padding size, starting from min_token_size, ending with a number that can cover max_token_size If padding_gap == 0 then: @@ -1927,15 +2048,15 @@ def _get_token_paddings(min_token_size: int, max_token_size: int, def _get_padded_token_len(paddings: list[int], x: int) -> int: - """Return the first element in paddings list greater or equal to x. 
- """ + """Return the first element in paddings list greater or equal to x.""" index = bisect.bisect_left(paddings, x) assert index < len(paddings) return paddings[index] -def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, - page_size: int) -> int: +def _get_padded_num_kv_cache_update_slices( + num_tokens: int, max_num_reqs: int, page_size: int +) -> int: """Calculates the padded number of KV cache update slices to avoid recompilation.""" # NOTE(chengjiyao): let's say R_i is the token num for i-th request, @@ -1971,7 +2092,6 @@ def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: def replace_set_lora(model): - def _tpu_set_lora( self, index: int, @@ -1995,5 +2115,4 @@ def _tpu_reset_lora(self, index: int): module._original_set_lora = module.set_lora module._original_reset_lora = module.reset_lora module.set_lora = _tpu_set_lora.__get__(module, module.__class__) - module.reset_lora = _tpu_reset_lora.__get__( - module, module.__class__) + module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index d4f0a65f2a16..66515c7e5786 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -11,10 +11,14 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, - has_kv_transfer_group) +from vllm.distributed import ( + ensure_model_parallel_initialized, + init_distributed_environment, +) +from vllm.distributed.kv_transfer import ( + ensure_kv_transfer_initialized, + has_kv_transfer_group, +) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed @@ -23,8 +27,7 @@ from vllm.tasks import SupportedTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, - KVCacheSpec) +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import report_usage_stats from vllm.v1.worker.utils import bind_kv_cache @@ -44,7 +47,6 @@ class TPUWorker: - def __init__( self, vllm_config: VllmConfig, @@ -82,12 +84,12 @@ def __init__( if self.cache_config.cache_dtype == "auto": self.cache_dtype = self.model_config.dtype else: - self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - self.cache_config.cache_dtype] + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype] if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() # Delay profiler initialization to the start of the profiling. @@ -100,14 +102,14 @@ def __init__( # For TPU, we can only have 1 active profiler session for 1 profiler # server. So we only profile on rank0. self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - self.profile_dir) + logger.info( + "Profiling enabled. 
Traces will be saved to: %s", self.profile_dir + ) if self.model_config.seed is None: self.model_config.seed = 0 - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -118,9 +120,10 @@ def init_device(self): # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to # fix this. It will be removed after the bug in XLA compiler is fixed. os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") + - " --xla_tpu_force_1d_allreduce_at_chunk_count=1" - " --xla_jf_conv_input_fusion=False") + os.environ.get("LIBTPU_INIT_ARGS", "") + + " --xla_tpu_force_1d_allreduce_at_chunk_count=1" + " --xla_jf_conv_input_fusion=False" + ) # --xla_jf_conv_input_fusion=False is used to improve the perf of # quantized matmul. torch.set_grad_enabled(False) @@ -128,8 +131,8 @@ def init_device(self): # Initialize the distributed environment. self._init_tpu_worker_distributed_environment( - self.vllm_config, self.rank, self.distributed_init_method, - self.local_rank) + self.vllm_config, self.rank, self.distributed_init_method, self.local_rank + ) # Device initialization should happen after initializing # the distributed runtime. @@ -158,14 +161,15 @@ def init_device(self): # cache during development is recommended.We can disable it by # `export VLLM_XLA_CACHE_PATH=` if envs.VLLM_XLA_CACHE_PATH: - per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{rank}") + per_rank_path = os.path.join( + envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}" + ) xr.initialize_cache(per_rank_path, readonly=False) # Init ModelRunner here, so that we have access to self.device. - self.model_runner = \ - TPUModelRunner(self.vllm_config, self.device, - self.original_parallel_config) + self.model_runner = TPUModelRunner( + self.vllm_config, self.device, self.original_parallel_config + ) if rank == 0: # If usage stat is enabled, collect relevant info. @@ -184,13 +188,15 @@ def determine_available_memory(self) -> int: kv_caches[layer_name] = tpu_kv_cache else: raise NotImplementedError( - f"Unsupported KV cache spec '{type(layer_spec)}'") + f"Unsupported KV cache spec '{type(layer_spec)}'" + ) runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, - runner_kv_caches) + runner_kv_caches, + ) # `max_num_tokens >= max_num_batched_tokens` due to padding. with self.model_runner.maybe_setup_dummy_loras(self.lora_config): @@ -215,6 +221,7 @@ def determine_available_memory(self) -> int: # TODO: use xm.get_memory_info for SPMD once it's supported in # PyTorch/XLA. import tpu_info + chip_type, _ = tpu_info.device.get_local_chips() device_usage = tpu_info.metrics.get_chip_usage(chip_type) total_memory_size = device_usage[0].total_memory @@ -231,20 +238,20 @@ def determine_available_memory(self) -> int: profiled = current_mem * 1.02 # Calculate the TPU KV cache size based on profiling. 
- usable_memory_size = int(total_memory_size * - self.cache_config.gpu_memory_utilization) + usable_memory_size = int( + total_memory_size * self.cache_config.gpu_memory_utilization + ) tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) head_size = self.model_config.get_head_size() if head_size > 0: - padded_head_size = cdiv( - head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + padded_head_size = ( + cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + ) if padded_head_size != head_size: - logger.warning_once("head size is padded to %d", - padded_head_size) + logger.warning_once("head size is padded to %d", padded_head_size) # We adjust the usable memory size for the KV cache to prevent OOM # errors, even after padding the head_size. - tpu_kv_cache_bytes = (tpu_kv_cache_bytes * head_size // - padded_head_size) + tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size return int(tpu_kv_cache_bytes) def execute_model( @@ -253,8 +260,7 @@ def execute_model( ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) # every worker's output is needed when kv_transfer_group is set up - return output if self.is_driver_worker or has_kv_transfer_group( - ) else None + return output if self.is_driver_worker or has_kv_transfer_group() else None def profile(self, is_start: bool = True): if self.rank < 1: @@ -327,8 +333,8 @@ def _init_tpu_worker_distributed_environment( backend=current_platform.dist_backend, ) ensure_model_parallel_initialized( - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size + ) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/v1/worker/ubatch_splitting.py b/vllm/v1/worker/ubatch_splitting.py index 7767750aa604..6723239e8495 100644 --- a/vllm/v1/worker/ubatch_splitting.py +++ b/vllm/v1/worker/ubatch_splitting.py @@ -10,8 +10,11 @@ from vllm.forward_context import DPMetadata from vllm.logger import init_logger from vllm.utils import round_up -from vllm.v1.worker.ubatch_utils import (UBatchSlice, UBatchSlices, - is_second_ubatch_empty) +from vllm.v1.worker.ubatch_utils import ( + UBatchSlice, + UBatchSlices, + is_second_ubatch_empty, +) logger = init_logger(__name__) @@ -24,14 +27,18 @@ def should_ubatch_with_num_tokens( ) -> tuple[bool, Optional[torch.Tensor]]: dp_size = vllm_config.parallel_config.data_parallel_size dp_rank = vllm_config.parallel_config.data_parallel_rank - return DPMetadata.should_ubatch_across_dp(should_ubatch, - orig_num_tokens_per_ubatch, - padded_num_tokens_per_ubatch, - dp_size, dp_rank) + return DPMetadata.should_ubatch_across_dp( + should_ubatch, + orig_num_tokens_per_ubatch, + padded_num_tokens_per_ubatch, + dp_size, + dp_rank, + ) -def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int, - uniform_decode: bool) -> bool: +def check_ubatch_thresholds( + config: ParallelConfig, num_tokens: int, uniform_decode: bool +) -> bool: if not config.enable_dbo: return False if uniform_decode: @@ -41,9 +48,11 @@ def check_ubatch_thresholds(config: ParallelConfig, num_tokens: int, def get_dp_padding_ubatch( - num_tokens_unpadded: int, num_tokens_padded: int, - should_attempt_ubatching: bool, - vllm_config: VllmConfig) -> tuple[bool, Optional[torch.Tensor]]: + num_tokens_unpadded: int, + num_tokens_padded: int, + should_attempt_ubatching: bool, + vllm_config: VllmConfig, +) -> tuple[bool, Optional[torch.Tensor]]: """ 1. 
Decides if each DP rank is going to microbatch. Either all ranks run with microbatching or none of them do. If this function decides @@ -71,7 +80,8 @@ def get_dp_padding_ubatch( # If this DP rank doesn't want to attempt microbatching if not should_attempt_ubatching: (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens( - False, 0, 0, vllm_config) + False, 0, 0, vllm_config + ) assert should_ubatch is False assert num_tokens_across_dp is None return should_ubatch, num_tokens_across_dp @@ -85,14 +95,16 @@ def get_dp_padding_ubatch( # ubatch. Abort if so if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded): logger.debug( - "Empty second µbatch detected: unpadded tokens: %s, padded " - "tokens: %s", num_tokens_unpadded, num_tokens_padded) + "Empty second µbatch detected: unpadded tokens: %s, padded tokens: %s", + num_tokens_unpadded, + num_tokens_padded, + ) should_ubatch = False # Note that we compute the number of padded tokens per ubatch (should_ubatch, num_tokens_across_dp) = should_ubatch_with_num_tokens( - should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch, - vllm_config) + should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch, vllm_config + ) if not should_ubatch: assert num_tokens_across_dp is None return should_ubatch, num_tokens_across_dp @@ -100,14 +112,15 @@ def get_dp_padding_ubatch( assert num_tokens_across_dp is not None max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item()) - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) + num_tokens_after_padding = torch.tensor( + [max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32 + ) return should_ubatch, num_tokens_after_padding -def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \ - -> UBatchSlices: + +def create_ubatch_slices( + num_scheduled_tokens: np.ndarray, split_point: int +) -> UBatchSlices: # TODO(lucas): Refactor the gpu_model_runner.py so we can pass # in cu_num_tokens directly (i.e. query_start_loc) cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32) @@ -119,19 +132,20 @@ def create_ubatch_slices(num_scheduled_tokens: np.ndarray, split_point: int) \ # Determine request slices using exclusive stop semantics # First ubatch includes requests whose tokens overlap [0, split_point) first_ubatch_req_stop = int( - np.searchsorted(cu_num_tokens, split_point, side="left")) + np.searchsorted(cu_num_tokens, split_point, side="left") + ) first_ubatch_req_slice = slice(0, first_ubatch_req_stop) # Second ubatch starts at the request that contains the split_point # or the request starting exactly at split_point (if on boundary) second_ubatch_req_start = int( - np.searchsorted(cu_num_tokens, split_point, side="right") - 1) - second_ubatch_req_slice = slice(second_ubatch_req_start, - len(cu_num_tokens) - 1) + np.searchsorted(cu_num_tokens, split_point, side="right") - 1 + ) + second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1) return [ UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice), - UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice) + UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice), ] @@ -147,7 +161,7 @@ def ubatch_split( should be split into microbatches. 
Returns: tuple[ - ubatch_slices: if this is set then all DP ranks have agreed to + ubatch_slices: if this is set then all DP ranks have agreed to microbatch num_tokens_after_padding: A tensor containing the total number of tokens per-microbatch for each DP rank including padding. Will be @@ -186,7 +200,8 @@ def ubatch_split( assert num_tokens_after_padding is not None token_split_point = int(num_tokens_after_padding[0].item()) - ubatch_slices = create_ubatch_slices(num_scheduled_tokens_per_request, - token_split_point) + ubatch_slices = create_ubatch_slices( + num_scheduled_tokens_per_request, token_split_point + ) return (ubatch_slices, num_tokens_after_padding) diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index 33d58aa94843..2deba16f8a49 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -11,8 +11,10 @@ class UBatchSlice: token_slice: slice def is_empty(self) -> bool: - return self.request_slice.start == self.request_slice.stop \ + return ( + self.request_slice.start == self.request_slice.stop or self.token_slice.start == self.token_slice.stop + ) @property def num_tokens(self) -> int: @@ -22,6 +24,7 @@ def num_tokens(self) -> int: UBatchSlices: TypeAlias = list[UBatchSlice] -def is_second_ubatch_empty(orig_num_tokens_per_ubatch: int, - padded_num_tokens_per_ubatch: int) -> bool: +def is_second_ubatch_empty( + orig_num_tokens_per_ubatch: int, padded_num_tokens_per_ubatch: int +) -> bool: return padded_num_tokens_per_ubatch >= 2 * orig_num_tokens_per_ubatch diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index c26cb07123a5..867ce2b93036 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -10,7 +10,7 @@ from vllm.utils import current_stream _THREAD_ID_TO_CONTEXT: dict = {} -_CURRENT_CONTEXTS: list[Optional['UBatchContext']] = [None, None] +_CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None] class UBatchContext: @@ -18,17 +18,19 @@ class UBatchContext: Context manager for micro-batching synchronization using threading events. 
""" - def __init__(self, - id: int, - comm_stream: torch.cuda.Stream, - compute_stream: torch.cuda.Stream, - forward_context: ForwardContext, - ready_barrier: threading.Barrier, - cpu_wait_event: threading.Event, - cpu_signal_event: threading.Event, - gpu_comm_done_event: torch.cuda.Event, - gpu_compute_done_event: torch.cuda.Event, - schedule: str = "default"): + def __init__( + self, + id: int, + comm_stream: torch.cuda.Stream, + compute_stream: torch.cuda.Stream, + forward_context: ForwardContext, + ready_barrier: threading.Barrier, + cpu_wait_event: threading.Event, + cpu_signal_event: threading.Event, + gpu_comm_done_event: torch.cuda.Event, + gpu_compute_done_event: torch.cuda.Event, + schedule: str = "default", + ): self.id = id self.comm_stream = comm_stream self.compute_stream = compute_stream @@ -151,7 +153,6 @@ def dbo_current_ubatch_id() -> int: def _register_ubatch_function(func): - def wrapper(*args, **kwargs): if len(_THREAD_ID_TO_CONTEXT) > 0: ctx_idx = _THREAD_ID_TO_CONTEXT[threading.get_ident()] @@ -161,20 +162,20 @@ def wrapper(*args, **kwargs): return wrapper -dbo_maybe_run_recv_hook = _register_ubatch_function( - UBatchContext.maybe_run_recv_hook) +dbo_maybe_run_recv_hook = _register_ubatch_function(UBatchContext.maybe_run_recv_hook) dbo_yield = _register_ubatch_function(UBatchContext.yield_) dbo_yield_and_switch_from_compute_to_comm = _register_ubatch_function( - UBatchContext.yield_and_switch_from_compute_to_comm) + UBatchContext.yield_and_switch_from_compute_to_comm +) dbo_yield_and_switch_from_comm_to_compute = _register_ubatch_function( - UBatchContext.yield_and_switch_from_comm_to_compute) + UBatchContext.yield_and_switch_from_comm_to_compute +) dbo_switch_to_comm = _register_ubatch_function(UBatchContext.switch_to_comm) -dbo_switch_to_compute = _register_ubatch_function( - UBatchContext.switch_to_compute) -dbo_switch_to_comm_sync = _register_ubatch_function( - UBatchContext.switch_to_comm_sync) +dbo_switch_to_compute = _register_ubatch_function(UBatchContext.switch_to_compute) +dbo_switch_to_comm_sync = _register_ubatch_function(UBatchContext.switch_to_comm_sync) dbo_switch_to_compute_sync = _register_ubatch_function( - UBatchContext.switch_to_compute_sync) + UBatchContext.switch_to_compute_sync +) def dbo_register_recv_hook(recv_hook): @@ -197,28 +198,25 @@ def make_ubatch_contexts( Create a context manager for micro-batching synchronization. 
""" cpu_events = [threading.Event() for _ in range(num_micro_batches)] - gpu_comm_done_events = [ - torch.cuda.Event() for _ in range(num_micro_batches) - ] - gpu_compute_done_events = [ - torch.cuda.Event() for _ in range(num_micro_batches) - ] + gpu_comm_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)] + gpu_compute_done_events = [torch.cuda.Event() for _ in range(num_micro_batches)] assert len(forward_contexts) == 2 ctxs = [] for i in range(num_micro_batches): - ctx = UBatchContext(id=i, - compute_stream=compute_stream, - comm_stream=comm_stream, - forward_context=forward_contexts[i], - ready_barrier=ready_barrier, - cpu_wait_event=cpu_events[i], - cpu_signal_event=cpu_events[(i + 1) % - num_micro_batches], - gpu_comm_done_event=gpu_comm_done_events[i], - gpu_compute_done_event=gpu_compute_done_events[i], - schedule=schedule) + ctx = UBatchContext( + id=i, + compute_stream=compute_stream, + comm_stream=comm_stream, + forward_context=forward_contexts[i], + ready_barrier=ready_barrier, + cpu_wait_event=cpu_events[i], + cpu_signal_event=cpu_events[(i + 1) % num_micro_batches], + gpu_comm_done_event=gpu_comm_done_events[i], + gpu_compute_done_event=gpu_compute_done_events[i], + schedule=schedule, + ) ctxs.append(ctx) return ctxs diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 3e0dbda59435..c3d16827f10e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -35,18 +35,18 @@ def __init__( self.model_config = model_config self.scheduler_config = scheduler_config self.mm_registry = mm_registry - self.cache = cache = processor_only_cache_from_config( - model_config, mm_registry) + self.cache = cache = processor_only_cache_from_config(model_config, mm_registry) self.max_model_len = model_config.max_model_len self.max_num_reqs = scheduler_config.max_num_seqs - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, - cache=cache) + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache) - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config, - cache=cache) + max_tokens_by_modality = ( + mm_registry.get_max_tokens_per_item_by_nonzero_modality( + model_config, cache=cache + ) + ) encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( scheduler_config, @@ -145,17 +145,14 @@ def create_with_metadata_builders( vllm_config: VllmConfig, device: torch.device, num_metadata_builders: int = 1, - ) -> 'AttentionGroup': + ) -> "AttentionGroup": metadata_builders = [ - backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, - device) + backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device) for _ in range(num_metadata_builders) ] - return AttentionGroup(backend, metadata_builders, layer_names, - kv_cache_spec) + return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec) - def get_metadata_builder(self, - ubatch_id: int = 0) -> AttentionMetadataBuilder: + def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder: assert len(self.metadata_builders) > ubatch_id return self.metadata_builders[ubatch_id] @@ -172,19 +169,22 @@ def sanity_check_mm_encoder_outputs( "Expected multimodal embeddings to be a list/tuple of 2D tensors, " f"or a single 3D tensor, but got {type(mm_embeddings)} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." 
+ ) assert len(mm_embeddings) == expected_num_items, ( "Expected number of multimodal embeddings to match number of " f"input items: {expected_num_items}, but got {len(mm_embeddings)=} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) assert all(e.ndim == 2 for e in mm_embeddings), ( "Expected multimodal embeddings to be a sequence of 2D tensors, " f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) def scatter_mm_placeholders( @@ -290,8 +290,7 @@ def bind_kv_cache( # Convert kv_caches dict to a list of tensors in the order of layer_index. index2name = defaultdict(list) for layer_name in kv_caches: - index2name[extract_layer_index(layer_name, - num_attn_module)].append(layer_name) + index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name) for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] @@ -319,16 +318,16 @@ def bind_kv_cache( forward_context[layer_name].kv_cache = [kv_cache] -def is_residual_scattered_for_sp(vllm_config: VllmConfig, - num_input_tokens: int) -> bool: +def is_residual_scattered_for_sp( + vllm_config: VllmConfig, num_input_tokens: int +) -> bool: """Check if the residual tensor is scattered for sequence parallelism. The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled, and the number of input tokens is one of the compilation sizes. """ - if not vllm_config.compilation_config.pass_config.\ - enable_sequence_parallelism: + if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism: return False tp = vllm_config.parallel_config.tensor_parallel_size @@ -341,4 +340,4 @@ def is_residual_scattered_for_sp(vllm_config: VllmConfig, assert num_input_tokens % tp == 0 # Currently, SP is only enabled for static size fx graphs. 
- return (num_input_tokens in vllm_config.compilation_config.compile_sizes) + return num_input_tokens in vllm_config.compilation_config.compile_sizes diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 5b393ee6bf3e..dc9bb3910fbc 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -4,7 +4,7 @@ from __future__ import annotations import os -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, TypeVar, Union import torch import torch.nn as nn @@ -13,10 +13,13 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import (enable_trace_function_call_for_thread, - resolve_obj_by_qualname, run_method, - update_environment_variables, - warn_for_unimplemented_methods) +from vllm.utils import ( + enable_trace_function_call_for_thread, + resolve_obj_by_qualname, + run_method, + update_environment_variables, + warn_for_unimplemented_methods, +) from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.outputs import SamplerOutput @@ -65,6 +68,7 @@ def __init__( self.compilation_config = vllm_config.compilation_config from vllm.platforms import current_platform + self.current_platform = current_platform self.parallel_config.rank = rank @@ -74,8 +78,8 @@ def __init__( self.is_driver_worker = is_driver_worker # Device and model state - self.device: Optional[torch.device] = None - self.model_runner: Optional[nn.Module] = None + self.device: torch.device | None = None + self.model_runner: nn.Module | None = None def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """Get specifications for KV cache implementation.""" @@ -95,10 +99,8 @@ def init_device(self) -> None: """ raise NotImplementedError - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache with the given size in blocks. - """ + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks.""" raise NotImplementedError def get_model(self) -> nn.Module: @@ -113,9 +115,8 @@ def load_model(self) -> None: raise NotImplementedError def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[list[SamplerOutput]]: + self, execute_model_req: ExecuteModelRequest | None = None + ) -> list[SamplerOutput] | None: raise NotImplementedError def start_worker_execution_loop(self) -> None: @@ -197,8 +198,8 @@ def __init__( group. """ self.rpc_rank = rpc_rank - self.worker: Optional[WorkerBase] = None - self.vllm_config: Optional[VllmConfig] = None + self.worker: WorkerBase | None = None + self.vllm_config: VllmConfig | None = None # do not store this `vllm_config`, `init_worker` will set the final # one. 
TODO: investigate if we can remove this field in # `WorkerWrapperBase`, `init_cached_hf_modules` should be @@ -209,6 +210,7 @@ def __init__( if trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() def shutdown(self) -> None: @@ -229,7 +231,7 @@ def update_environment_variables( envs_list: list[dict[str, str]], ) -> None: envs = envs_list[self.rpc_rank] - key = 'CUDA_VISIBLE_DEVICES' + key = "CUDA_VISIBLE_DEVICES" if key in envs and key in os.environ: # overwriting CUDA_VISIBLE_DEVICES is desired behavior # suppress the warning in `update_environment_variables` @@ -244,22 +246,26 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None: kwargs = all_kwargs[self.rpc_rank] self.vllm_config = kwargs.get("vllm_config") assert self.vllm_config is not None, ( - "vllm_config is required to initialize the worker") + "vllm_config is required to initialize the worker" + ) enable_trace_function_call_for_thread(self.vllm_config) from vllm.plugins import load_general_plugins + load_general_plugins() if isinstance(self.vllm_config.parallel_config.worker_cls, str): worker_class = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_cls) + self.vllm_config.parallel_config.worker_cls + ) else: raise ValueError( "passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." # noqa: E501 ) if self.vllm_config.parallel_config.worker_extension_cls: worker_extension_cls = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_extension_cls) + self.vllm_config.parallel_config.worker_extension_cls + ) extended_calls = [] if worker_extension_cls not in worker_class.__bases__: # check any conflicts between worker and worker_extension_cls @@ -269,15 +275,20 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None: assert not hasattr(worker_class, attr), ( f"Worker class {worker_class} already has an attribute" f" {attr}, which conflicts with the worker" - f" extension class {worker_extension_cls}.") + f" extension class {worker_extension_cls}." + ) if callable(getattr(worker_extension_cls, attr)): extended_calls.append(attr) # dynamically inherit the worker extension class worker_class.__bases__ = worker_class.__bases__ + ( - worker_extension_cls, ) + worker_extension_cls, + ) logger.info( "Injected %s into %s for extended collective_rpc calls %s", - worker_extension_cls, worker_class, extended_calls) + worker_extension_cls, + worker_class, + extended_calls, + ) with set_current_vllm_config(self.vllm_config): # To make vLLM config available during worker initialization self.worker = worker_class(**kwargs) @@ -305,8 +316,10 @@ def execute_method(self, method: Union[str, bytes], *args, **kwargs): # exceptions in the rest worker may cause deadlock in rpc like ray # see https://github.com/vllm-project/vllm/issues/3455 # print the error and inform the user to solve the error - msg = (f"Error executing method {method!r}. " - "This might cause deadlock in distributed execution.") + msg = ( + f"Error executing method {method!r}. " + "This might cause deadlock in distributed execution." 
+ ) logger.exception(msg) raise e diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 7becdd392498..4f82c18da73a 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -37,9 +37,7 @@ def _sync_device(self) -> None: @contextmanager def _torch_cuda_wrapper(): - class _EventPlaceholder: - def __init__(self, *args, **kwargs) -> None: self.record = lambda: None self.synchronize = lambda: None diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 7355206f30f5..a1e54628d9ed 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -11,8 +11,7 @@ from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform -from vllm.v1.worker.gpu_worker import (Worker, - init_worker_distributed_environment) +from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment from vllm.v1.worker.xpu_model_runner import XPUModelRunner logger = init_logger(__name__) @@ -29,8 +28,9 @@ def __init__( distributed_init_method: str, is_driver_worker: bool = False, ): - super().__init__(vllm_config, local_rank, rank, - distributed_init_method, is_driver_worker) + super().__init__( + vllm_config, local_rank, rank, distributed_init_method, is_driver_worker + ) device_config = self.device_config assert device_config.device_type == "xpu" assert current_platform.is_xpu() @@ -39,8 +39,10 @@ def __init__( # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) + logger.info( + "Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir, + ) logger.debug( "Profiler config: record_shapes=%s," "profile_memory=%s,with_stack=%s,with_flops=%s", @@ -59,7 +61,9 @@ def __init__( with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK, with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS, on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) + torch_profiler_trace_dir, use_gzip=True + ), + ) else: self.profiler = None @@ -75,8 +79,7 @@ def xpu_get_mem_info(self): # and we don't have any API to get it. so we mark it as 128MB. used_memory = torch.xpu.memory_allocated() non_torch_allocations = 128 * 1024 * 1024 - free_gpu_memory = total_gpu_memory - (used_memory + - non_torch_allocations) + free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations) return free_gpu_memory, total_gpu_memory @torch.inference_mode() @@ -97,10 +100,12 @@ def determine_available_memory(self) -> int: free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info() current_allocated_bytes = torch.xpu.memory_allocated() - msg = ("Before memory profiling run, " - f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " - f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " - f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + msg = ( + "Before memory profiling run, " + f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " + f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB." + ) logger.info(msg) # Execute a forward pass with dummy inputs to profile the memory usage # of the model. @@ -113,67 +118,73 @@ def determine_available_memory(self) -> int: "Error in memory profiling. 
" f"Initial free memory {self.init_gpu_memory}, current free memory" f" {free_gpu_memory}. This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") + "not properly cleaned up before initializing the vLLM instance." + ) # Get the peak memory allocation recorded by torch peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"] torch.xpu.empty_cache() - torch_allocated_bytes = torch.xpu.memory_stats( - )["allocated_bytes.all.current"] - total_allocated_bytes = self.xpu_get_mem_info( - )[1] - self.xpu_get_mem_info()[0] + torch_allocated_bytes = torch.xpu.memory_stats()["allocated_bytes.all.current"] + total_allocated_bytes = self.xpu_get_mem_info()[1] - self.xpu_get_mem_info()[0] non_torch_allocations = total_allocated_bytes - torch_allocated_bytes if non_torch_allocations > 0: peak_memory += non_torch_allocations available_kv_cache_memory = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) - - msg = ("After memory profiling run, " - f"peak memory usage is {peak_memory / 1024**2:.2f} MB," - f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " - f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " - f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory + ) + + msg = ( + "After memory profiling run, " + f"peak memory usage is {peak_memory / 1024**2:.2f} MB," + f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " + f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB." + ) logger.info(msg) return int(available_kv_cache_memory) def init_device(self): - if self.device_config.device.type == "xpu" and current_platform.is_xpu( - ): + if self.device_config.device.type == "xpu" and current_platform.is_xpu(): self.device = torch.device(f"xpu:{self.local_rank}") current_platform.set_device(self.device) current_platform.check_if_supports_dtype(self.model_config.dtype) torch.xpu.empty_cache() self.init_gpu_memory = torch.xpu.get_device_properties( - self.local_rank).total_memory + self.local_rank + ).total_memory else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") + raise RuntimeError(f"Not support device type: {self.device_config.device}") ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "pidfd") ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") - ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", - str(self.parallel_config.world_size)) + ENV_LOCAL_WORLD_SIZE = os.getenv( + "LOCAL_WORLD_SIZE", str(self.parallel_config.world_size) + ) os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE os.environ["LOCAL_RANK"] = str(self.local_rank) - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) + init_worker_distributed_environment( + self.vllm_config, + self.rank, + self.distributed_init_method, + self.local_rank, + current_platform.dist_backend, + ) # global all_reduce needed for overall oneccl warm up - torch.distributed.all_reduce(torch.zeros(1).xpu(), - group=get_world_group().device_group) + torch.distributed.all_reduce( + torch.zeros(1).xpu(), group=get_world_group().device_group + ) # Set random seed. 
set_random_seed(self.model_config.seed) # Construct the model runner self.model_runner = XPUModelRunner( # type: ignore - self.vllm_config, self.device) + self.vllm_config, self.device + ) diff --git a/vllm/version.py b/vllm/version.py index 6c88b1b5a3bf..63095f8bce1e 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -6,9 +6,7 @@ except Exception as e: import warnings - warnings.warn(f"Failed to read commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) + warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2) __version__ = "dev" __version_tuple__ = (0, 0, __version__)
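
The reformatted `create_ubatch_slices` in `vllm/v1/worker/ubatch_splitting.py` splits a batch into two microbatches around a token split point by searching an exclusive prefix sum of per-request token counts. The standalone sketch below restates that splitting logic for readability; it is not part of the patch, `split_batch` is a hypothetical name, and the construction of the two token ranges as `[0, split_point)` and `[split_point, total_tokens)` is an assumption, since those lines fall outside the hunk shown above.

import numpy as np


def split_batch(num_scheduled_tokens: np.ndarray, split_point: int):
    """Sketch of the searchsorted-based request split used by create_ubatch_slices."""
    # Exclusive prefix sum of per-request token counts (i.e. query_start_loc).
    cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
    np.cumsum(num_scheduled_tokens, out=cu_num_tokens[1:])
    total_tokens = int(cu_num_tokens[-1])

    # Assumed token ranges for the two microbatches.
    first_token_slice = slice(0, split_point)
    second_token_slice = slice(split_point, total_tokens)

    # First microbatch: requests whose tokens overlap [0, split_point).
    first_req_stop = int(np.searchsorted(cu_num_tokens, split_point, side="left"))
    # Second microbatch: starts at the request containing split_point, or at the
    # request beginning exactly there when the split lands on a boundary.
    second_req_start = int(np.searchsorted(cu_num_tokens, split_point, side="right")) - 1

    return [
        (slice(0, first_req_stop), first_token_slice),
        (slice(second_req_start, len(cu_num_tokens) - 1), second_token_slice),
    ]


# Three requests with 4, 4 and 2 tokens, split at token 5: the middle request
# straddles the split point, so it appears in both microbatches.
print(split_batch(np.array([4, 4, 2], dtype=np.int32), split_point=5))
# -> [(slice(0, 2), slice(0, 5)), (slice(1, 3), slice(5, 10))]

Note that a request straddling the split point is intentionally assigned to both slices, which is why the second request slice starts one position before the first one ends.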